Merge branch 'master' into monitoring

This commit is contained in:
Benjamin Huo
2019-05-14 11:22:56 +08:00
committed by GitHub
721 changed files with 87960 additions and 53512 deletions

274
Gopkg.lock generated
View File

@@ -4,8 +4,8 @@
[[projects]]
name = "cloud.google.com/go"
packages = ["compute/metadata"]
revision = "f52f9bc132541d2aa914f42100c36d10b1ef7e0c"
version = "v0.37.0"
revision = "8c41231e01b2085512d98153bcffb847ff9b4b9f"
version = "v0.38.0"
[[projects]]
name = "github.com/Microsoft/go-winio"
@@ -34,17 +34,8 @@
[[projects]]
name = "github.com/Sirupsen/logrus"
packages = ["."]
revision = "dae0fa8d5b0c810a8ab733fbd5510c7cae84eca4"
version = "v1.4.0"
[[projects]]
branch = "master"
name = "github.com/aead/chacha20"
packages = [
".",
"chacha"
]
revision = "8b13a72661dae6e9e5dea04f344f0dc95ea29547"
revision = "8bdbc7bcc01dcbb8ec23dc8a28e332258d25251f"
version = "v1.4.1"
[[projects]]
name = "github.com/andybalholm/cascadia"
@@ -71,24 +62,16 @@
version = "v1.1.0"
[[projects]]
branch = "master"
name = "github.com/beorn7/perks"
packages = ["quantile"]
revision = "3a771d992973f24aa725d07868b467d1ddfceafb"
[[projects]]
name = "github.com/bifurcation/mint"
packages = [
".",
"syntax"
]
revision = "824af65410658916142a7600349144e1289f2110"
revision = "4b2b341e8d7715fae06375aa633dbb6e91b3fb46"
version = "v1.0.0"
[[projects]]
name = "github.com/cenkalti/backoff"
packages = ["."]
revision = "1e4cf3da559842a91afcb6ea6141451e6c30c618"
version = "v2.1.1"
revision = "4b4cebaf850ec58f1bb1fec5bdebdf8501c2bc3f"
version = "v3.0.0"
[[projects]]
name = "github.com/cheekybits/genny"
@@ -115,7 +98,7 @@
"digestset",
"reference"
]
revision = "6d62eb1d4a3515399431b713fde3ce5a9b40e8d5"
revision = "3226863cbcba6dbc2f6c83a37b28126c934af3f8"
[[projects]]
name = "github.com/docker/docker"
@@ -145,20 +128,20 @@
version = "v17.05.0-ce-rc3"
[[projects]]
branch = "master"
name = "github.com/docker/go-connections"
packages = [
"nat",
"sockets",
"tlsconfig"
]
revision = "97c2040d34dfae1d1b1275fa3a78dbdd2f41cf7e"
revision = "7395e3f8aa162843a74ed6d48e79627d9792ac55"
version = "v0.4.0"
[[projects]]
name = "github.com/docker/go-units"
packages = ["."]
revision = "47565b4f722fb6ceae66b95f853feed578a4a51c"
version = "v0.3.3"
revision = "519db1ee28dcc9fd2474ae59fca29a810482bfb1"
version = "v0.4.0"
[[projects]]
branch = "master"
@@ -187,8 +170,8 @@
".",
"log"
]
revision = "85d198d05a92d31823b852b4a5928114912e8949"
version = "v2.9.0"
revision = "103c9496ad8f7e687b8291b56750190012091a96"
version = "v2.9.4"
[[projects]]
name = "github.com/emicklei/go-restful-openapi"
@@ -212,8 +195,8 @@
[[projects]]
name = "github.com/evanphx/json-patch"
packages = ["."]
revision = "72bf35d0ff611848c1dc9df0f976c81192392fa5"
version = "v4.1.0"
revision = "5858425f75500d40c52783dce87d085a483ce135"
version = "v4.2.0"
[[projects]]
name = "github.com/fatih/structs"
@@ -233,11 +216,34 @@
revision = "0ca9ea5df5451ffdf184b4428c902747c2c11cd7"
version = "v1.0.0"
[[projects]]
name = "github.com/go-acme/lego"
packages = [
"acme",
"acme/api",
"acme/api/internal/nonces",
"acme/api/internal/secure",
"acme/api/internal/sender",
"certcrypto",
"certificate",
"challenge",
"challenge/dns01",
"challenge/http01",
"challenge/resolver",
"challenge/tlsalpn01",
"lego",
"log",
"platform/wait",
"registration"
]
revision = "3d13faf68920543a393ad6cdfdea429627af2d34"
version = "v2.5.0"
[[projects]]
name = "github.com/go-ldap/ldap"
packages = ["."]
revision = "729c20c2694d870bcd631f0dadaecd088bd7ccbc"
version = "v3.0.2"
revision = "9f0d712775a0973b7824a1585a86a4ea1d5263d9"
version = "v3.0.3"
[[projects]]
name = "github.com/go-logr/logr"
@@ -255,25 +261,25 @@
name = "github.com/go-openapi/jsonpointer"
packages = ["."]
revision = "ef5f0afec364d3b9396b7b77b43dbe26bf1f8004"
version = "v0.18.0"
version = "v0.19.0"
[[projects]]
name = "github.com/go-openapi/jsonreference"
packages = ["."]
revision = "8483a886a90412cd6858df4ea3483dce9c8e35a3"
version = "v0.18.0"
version = "v0.19.0"
[[projects]]
name = "github.com/go-openapi/spec"
packages = ["."]
revision = "5b6cdde3200976e3ecceb2868706ee39b6aff3e4"
version = "v0.18.0"
revision = "53d776530bf78a11b03a7b52dd8a083086b045e5"
version = "v0.19.0"
[[projects]]
name = "github.com/go-openapi/swag"
packages = ["."]
revision = "1d29f06aebd59ccdf11ae04aa0334ded96e2d909"
version = "v0.18.0"
revision = "b3e2804c8535ee0d1b89320afd98474d5b8e9e3b"
version = "v0.19.0"
[[projects]]
name = "github.com/go-redis/redis"
@@ -298,8 +304,8 @@
[[projects]]
name = "github.com/gobuffalo/envy"
packages = ["."]
revision = "fa0dfdc10b5366ce365b7d9d1755a03e4e797bc5"
version = "v1.6.15"
revision = "043cb4b8af871b49563291e32c66bb84378a60ac"
version = "v1.7.0"
[[projects]]
name = "github.com/gocraft/dbr"
@@ -349,10 +355,10 @@
version = "v1.3.1"
[[projects]]
branch = "master"
name = "github.com/google/btree"
packages = ["."]
revision = "4030bb1f1f0c35b30ca7009e9ebd06849dd45306"
version = "v1.0.0"
[[projects]]
name = "github.com/google/go-querystring"
@@ -361,10 +367,10 @@
version = "v1.0.0"
[[projects]]
branch = "master"
name = "github.com/google/gofuzz"
packages = ["."]
revision = "24818f796faf91cd76ec7bddd72458fbced7a6c1"
revision = "f140a6486e521aad38f5917de355cbf147cc0496"
version = "v1.0.0"
[[projects]]
name = "github.com/google/uuid"
@@ -385,8 +391,8 @@
[[projects]]
name = "github.com/gorilla/mux"
packages = ["."]
revision = "a7962380ca08b5a188038c69871b8d3fbdf31e89"
version = "v1.7.0"
revision = "c5c6c98bc25355028a63748a498942a6398ccd22"
version = "v1.7.1"
[[projects]]
name = "github.com/gorilla/websocket"
@@ -412,8 +418,8 @@
[[projects]]
name = "github.com/hashicorp/go-version"
packages = ["."]
revision = "d40cf49b3a77bba84a7afdbd7f1dc295d114efb1"
version = "v1.1.0"
revision = "ac23dc3fea5d1a983c43f6a0f6e2c13f0195d8bd"
version = "v1.2.0"
[[projects]]
name = "github.com/hashicorp/golang-lru"
@@ -510,8 +516,8 @@
[[projects]]
name = "github.com/klauspost/cpuid"
packages = ["."]
revision = "e7e905edc00ea8827e58662220139109efea09db"
version = "v1.2.0"
revision = "05a8198c0f5a27739aec358908d7e12c64ce6eb7"
version = "v1.2.1"
[[projects]]
name = "github.com/knative/pkg"
@@ -547,6 +553,14 @@
name = "github.com/kubernetes-sigs/application"
packages = [
"pkg/apis/app/v1beta1",
"pkg/client/clientset/versioned",
"pkg/client/clientset/versioned/scheme",
"pkg/client/clientset/versioned/typed/app/v1beta1",
"pkg/client/informers/externalversions",
"pkg/client/informers/externalversions/app",
"pkg/client/informers/externalversions/app/v1beta1",
"pkg/client/informers/externalversions/internalinterfaces",
"pkg/client/listers/app/v1beta1",
"pkg/component",
"pkg/customresource",
"pkg/finalizer",
@@ -554,7 +568,7 @@
"pkg/kbcontroller",
"pkg/resource"
]
revision = "c06442ed33338d8a2eece32e25b869f728a28cbc"
revision = "1be8f5eada07fe5b17804e4b91fc2f4c4fc4ecb9"
source = "https://github.com/kubesphere/application"
[[projects]]
@@ -579,12 +593,6 @@
revision = "313ae64d680ed14ef12e806ab0dd3777240e908d"
version = "v0.0.2"
[[projects]]
branch = "master"
name = "github.com/lucas-clemente/aes12"
packages = ["."]
revision = "cd47fb39b79f867c6e4e5cd39cf7abd799f71670"
[[projects]]
name = "github.com/lucas-clemente/quic-go"
packages = [
@@ -592,22 +600,15 @@
"h2quic",
"internal/ackhandler",
"internal/congestion",
"internal/crypto",
"internal/flowcontrol",
"internal/handshake",
"internal/protocol",
"internal/qerr",
"internal/utils",
"internal/wire",
"qerr"
"internal/wire"
]
revision = "714f38d5d0aff85894fd890718b991e361f03e7d"
version = "v0.10.1"
[[projects]]
branch = "master"
name = "github.com/lucas-clemente/quic-go-certificates"
packages = ["."]
revision = "d2f86524cced5186554df90d92529757d22c1cb6"
revision = "8dcdf12ff78def6cfd36107d8bf42aa38110fa94"
version = "v0.11.1"
[[projects]]
branch = "master"
@@ -617,7 +618,7 @@
"jlexer",
"jwriter"
]
revision = "1de009706dbeb9d05f18586f0735fcdb7c524481"
revision = "1ea4449da9834f4d333f1cc461c374aea217d249"
[[projects]]
name = "github.com/markbates/inflect"
@@ -625,6 +626,12 @@
revision = "24b83195037b3bc61fcda2d28b7b0518bce293b6"
version = "v1.0.4"
[[projects]]
name = "github.com/marten-seemann/qtls"
packages = ["."]
revision = "65ca381cd298d7e0aef0de8ba523a870ec5a96fe"
version = "v0.2.3"
[[projects]]
name = "github.com/matttproud/golang_protobuf_extensions"
packages = ["pbutil"]
@@ -673,20 +680,20 @@
"onevent/hook",
"telemetry"
]
revision = "80dfb8b2a7f89b120a627bc4d866a1dc5ed3d92f"
version = "v0.11.5"
revision = "15fecbc16151308959674a5ce3843efb9e66bd5f"
version = "v1.0.0"
[[projects]]
branch = "master"
name = "github.com/mholt/certmagic"
packages = ["."]
revision = "e3e89d1096d76d61680f8eeb8f67649baa6c54b8"
revision = "0030c3ed9a43567fe81571f0750141725f468c55"
version = "v0.5.1"
[[projects]]
name = "github.com/miekg/dns"
packages = ["."]
revision = "cc8cd02140663157ce797c6650488d6c8563f31f"
version = "v1.1.6"
revision = "8aa92d4e02c501ba21e26fb92cf2fb9f23f56917"
version = "v1.1.9"
[[projects]]
name = "github.com/mitchellh/go-homedir"
@@ -803,8 +810,8 @@
[[projects]]
name = "github.com/peterbourgon/diskv"
packages = ["."]
revision = "5f041e8faa004a95c88a202771f4cc3e991971e6"
version = "v2.0.1"
revision = "0be1b92a6df0e4f5cb0a5d15fb7f643d0ad93ce6"
version = "v3.0.0"
[[projects]]
name = "github.com/pkg/errors"
@@ -843,20 +850,17 @@
"internal/bitbucket.org/ww/goautoneg",
"model"
]
revision = "cfeb6f9992ffa54aaa4f2170ade4067ee478b250"
version = "v0.2.0"
revision = "1ba88736f028e37bc17328369e94a537ae9e0234"
version = "v0.4.0"
[[projects]]
branch = "master"
name = "github.com/prometheus/procfs"
packages = [
".",
"internal/util",
"iostats",
"nfs",
"xfs"
"internal/fs"
]
revision = "e56f2e22fc761e82a34aca553f6725e2aff4fe6c"
revision = "5867b95ac084bbfee6ea16595c4e05ab009021da"
[[projects]]
name = "github.com/rogpeppe/go-internal"
@@ -865,8 +869,8 @@
"module",
"semver"
]
revision = "1cf9852c553c5b7da2d5a4a091129a7822fed0c9"
version = "v1.2.2"
revision = "438578804ca6f31be148c27683afc419ce47c06e"
version = "v1.3.0"
[[projects]]
name = "github.com/russross/blackfriday"
@@ -898,8 +902,8 @@
".",
"mem"
]
revision = "f4711e4db9e9a1d3887343acb72b2bbfc2f686f5"
version = "v1.2.1"
revision = "588a75ec4f32903aa5e39a2619ba6a4631e28424"
version = "v1.2.2"
[[projects]]
name = "github.com/spf13/cobra"
@@ -945,34 +949,11 @@
revision = "6a3e2ff9e7c564f36873c2e36413f634534f1c44"
version = "v0.2.1"
[[projects]]
name = "github.com/xenolf/lego"
packages = [
"acme",
"acme/api",
"acme/api/internal/nonces",
"acme/api/internal/secure",
"acme/api/internal/sender",
"certcrypto",
"certificate",
"challenge",
"challenge/dns01",
"challenge/http01",
"challenge/resolver",
"challenge/tlsalpn01",
"lego",
"log",
"platform/wait",
"registration"
]
revision = "2952cdaebd4da7cd560e195343bdd3cb78a67643"
version = "v2.3.0"
[[projects]]
name = "go.uber.org/atomic"
packages = ["."]
revision = "1ea20fb1cbb1cc08cbd0d913a96dead89aa18289"
version = "v1.3.2"
revision = "df976f2515e274675050de7b3f42545de80594fd"
version = "v1.4.0"
[[projects]]
name = "go.uber.org/multierr"
@@ -990,14 +971,17 @@
"internal/exit",
"zapcore"
]
revision = "ff33455a0e382e8a81d14dd7c922020b6b5e7982"
version = "v1.9.1"
revision = "27376062155ad36be76b0f12cf1572a221d3a48c"
version = "v1.10.0"
[[projects]]
branch = "master"
name = "golang.org/x/crypto"
packages = [
"cast5",
"chacha20poly1305",
"cryptobyte",
"cryptobyte/asn1",
"curve25519",
"ed25519",
"ed25519/internal/edwards25519",
@@ -1018,7 +1002,7 @@
"ssh/knownhosts",
"ssh/terminal"
]
revision = "a1f597ede03a7bef967a422b5b3a5bd08805a01e"
revision = "cbcb750295291b33242907a04be40e80801d0cfc"
[[projects]]
branch = "master"
@@ -1041,7 +1025,7 @@
"ipv6",
"proxy"
]
revision = "9f648a60d9775ef5c977e7669d1673a7a67bef33"
revision = "a4d6f7feada510cc50e69a37b484cb0fdc6b7876"
[[projects]]
branch = "master"
@@ -1053,7 +1037,7 @@
"jws",
"jwt"
]
revision = "e64efc72b421e893cbf63f17ba2221e7d6d0b0f3"
revision = "9f3314589c9a9136388751d9adae6b0ed400978a"
[[projects]]
branch = "master"
@@ -1063,7 +1047,7 @@
"unix",
"windows"
]
revision = "fead79001313d15903fb4605b4a1b781532cd93e"
revision = "a5b02f93d862f065920dd6a40dddc66b60d0dec4"
[[projects]]
name = "golang.org/x/text"
@@ -1082,6 +1066,8 @@
"encoding/unicode",
"internal/colltab",
"internal/gen",
"internal/language",
"internal/language/compact",
"internal/tag",
"internal/triegen",
"internal/ucd",
@@ -1096,8 +1082,8 @@
"unicode/rangetable",
"width"
]
revision = "f21a4dfb5e38f5895301dc265a8def02365cc3d0"
version = "v0.3.0"
revision = "342b2e1fbaa52c93f31447ad2c6abc048c63e475"
version = "v0.3.2"
[[projects]]
branch = "master"
@@ -1111,7 +1097,6 @@
packages = [
"go/ast/astutil",
"go/gcexportdata",
"go/internal/cgo",
"go/internal/gcimporter",
"go/internal/packagesdriver",
"go/packages",
@@ -1122,7 +1107,7 @@
"internal/module",
"internal/semver"
]
revision = "8b67d361bba210f5fbb3c1a0fc121e0847b10b57"
revision = "99f201b6807eb28f750a1966316bb0d4417b6020"
[[projects]]
name = "google.golang.org/appengine"
@@ -1139,8 +1124,8 @@
"internal/urlfetch",
"urlfetch"
]
revision = "e9657d882bb81064595ca3b56cbe2546bbabf7b1"
version = "v1.4.0"
revision = "54a98f90d1c46b7731eb8fb305d2a321c30ef610"
version = "v1.5.0"
[[projects]]
name = "gopkg.in/asn1-ber.v1"
@@ -1181,8 +1166,8 @@
"json",
"jwt"
]
revision = "628223f44a71f715d2881ea69afc795a1e9c01be"
version = "v2.3.0"
revision = "730df5f748271903322feb182be83b43ebbbe27d"
version = "v2.3.1"
[[projects]]
name = "gopkg.in/src-d/go-billy.v4"
@@ -1241,8 +1226,8 @@
"utils/merkletrie/internal/frame",
"utils/merkletrie/noder"
]
revision = "db6c41c156481962abf9a55a324858674c25ab08"
version = "v4.10.0"
revision = "aa6f288c256ff8baf8a7745546a9752323dc0d89"
version = "v4.11.0"
[[projects]]
branch = "v1"
@@ -1372,7 +1357,7 @@
"third_party/forked/golang/reflect"
]
revision = "2b1284ed4c93a43499e781493253e2ac5959c4fd"
version = "kubernetes-1.13.1"
version = "kubernetes-1.13.0"
[[projects]]
name = "k8s.io/apiserver"
@@ -1545,6 +1530,7 @@
version = "kubernetes-1.13.1"
[[projects]]
branch = "release-1.13"
name = "k8s.io/code-generator"
packages = [
"cmd/client-gen",
@@ -1558,7 +1544,6 @@
"pkg/util"
]
revision = "c2090bec4d9b1fb25de3812f868accc2bc9ecbae"
version = "kubernetes-1.13.1"
[[projects]]
branch = "master"
@@ -1573,19 +1558,19 @@
"parser",
"types"
]
revision = "b90029ef6cd877cb3f422d75b3a07707e3aac6b7"
revision = "e17681d19d3ac4837a019ece36c2a0ec31ffe985"
[[projects]]
name = "k8s.io/klog"
packages = ["."]
revision = "a5bc97fbc634d635061f3146511332c7e313a55a"
version = "v0.1.0"
revision = "e531227889390a39d9533dde61f590fe9f4b0035"
version = "v0.3.0"
[[projects]]
branch = "master"
name = "k8s.io/kube-openapi"
packages = ["pkg/util/proto"]
revision = "15615b16d372105f0c69ff47dfe7402926a65aaa"
revision = "a01b7d5d6c2258c80a4a10070f3dee9cd575d9c7"
[[projects]]
name = "k8s.io/kubernetes"
@@ -1619,20 +1604,21 @@
"pkg/util/parsers",
"pkg/util/taints"
]
revision = "c27b913fddd1a6c480c229191a087698aa92f0b1"
version = "v1.13.4"
revision = "abdda3f9fefa29172298a2e42f5102e777a8ec25"
version = "v1.13.6"
[[projects]]
branch = "master"
name = "k8s.io/utils"
packages = ["pointer"]
revision = "21c4ce38f2a793ec01e925ddc31216500183b773"
revision = "8fab8cb257d50c8cf94ec9771e74826edbb68fb5"
[[projects]]
branch = "master"
branch = "kubesphere"
name = "sigs.k8s.io/application"
packages = ["pkg/controller/application"]
revision = "4ead7f1b87048b7717b3e474a21fdc07e6bce636"
revision = "1be8f5eada07fe5b17804e4b91fc2f4c4fc4ecb9"
source = "https://github.com/kubesphere/application"
[[projects]]
name = "sigs.k8s.io/controller-runtime"
@@ -1713,6 +1699,6 @@
[solve-meta]
analyzer-name = "dep"
analyzer-version = 1
inputs-digest = "5af8bb8e719bbe2eca43012a2fa6c77adfd90b211a352e09ab95de7a2865b46f"
inputs-digest = "84050360b125f37a31888a849dfe1298bc825cd2035af10c31a4013d03244d9a"
solver-name = "gps-cdcl"
solver-version = 1

View File

@@ -89,7 +89,7 @@ required = [
# vendor/github.com/docker/docker/registry/service_v2.go:11: cannot call non-function tlsconfig.ServerDefault (type tls.Config)
[[override]]
name = "github.com/docker/go-connections"
branch = "master"
version = "0.4.0"
# For dependency below: Refer to issue https://github.com/golang/dep/issues/1799
[[override]]
@@ -101,9 +101,12 @@ required = [
name = "github.com/russross/blackfriday"
version = "v1.5.2"
# offical application controller doesn't limit observe scope to namespace
# use our own version instead
[[constraint]]
branch = "master"
name = "sigs.k8s.io/application"
source = "https://github.com/kubesphere/application"
branch = "kubesphere"
[[constraint]]
name = "github.com/kiali/kiali"
@@ -115,10 +118,6 @@ required = [
source = "https://github.com/kubesphere/application"
branch = "kubesphere"
[[constraint]]
name = "github.com/gorilla/mux"
version = "1.7.0"
[[constraint]]
name = "github.com/knative/pkg"
revision = "cd278f2d3394c865fda66bca12459e879e0279b8"

View File

@@ -21,8 +21,11 @@ import (
"k8s.io/client-go/informers"
"k8s.io/client-go/kubernetes"
"k8s.io/client-go/rest"
"kubesphere.io/kubesphere/pkg/controller/application"
"kubesphere.io/kubesphere/pkg/controller/destinationrule"
"kubesphere.io/kubesphere/pkg/controller/job"
//"kubesphere.io/kubesphere/pkg/controller/job"
"kubesphere.io/kubesphere/pkg/controller/virtualservice"
"sigs.k8s.io/controller-runtime/pkg/manager"
"time"
@@ -31,6 +34,8 @@ import (
istioclientset "github.com/knative/pkg/client/clientset/versioned"
istioinformers "github.com/knative/pkg/client/informers/externalversions"
applicationclientset "github.com/kubernetes-sigs/application/pkg/client/clientset/versioned"
applicationinformers "github.com/kubernetes-sigs/application/pkg/client/informers/externalversions"
servicemeshclientset "kubesphere.io/kubesphere/pkg/client/clientset/versioned"
servicemeshinformers "kubesphere.io/kubesphere/pkg/client/informers/externalversions"
)
@@ -52,8 +57,15 @@ func AddControllers(mgr manager.Manager, cfg *rest.Config, stopCh <-chan struct{
return err
}
applicationClient, err := applicationclientset.NewForConfig(cfg)
if err != nil {
log.Error(err, "create application client failed")
return err
}
informerFactory := informers.NewSharedInformerFactory(kubeClient, defaultResync)
istioInformer := istioinformers.NewSharedInformerFactory(istioclient, defaultResync)
applicationInformer := applicationinformers.NewSharedInformerFactory(applicationClient, defaultResync)
servicemeshclient, err := servicemeshclientset.NewForConfig(cfg)
if err != nil {
@@ -61,12 +73,12 @@ func AddControllers(mgr manager.Manager, cfg *rest.Config, stopCh <-chan struct{
return err
}
servicemeshinformer := servicemeshinformers.NewSharedInformerFactory(servicemeshclient, defaultResync)
servicemeshInformer := servicemeshinformers.NewSharedInformerFactory(servicemeshclient, defaultResync)
vsController := virtualservice.NewVirtualServiceController(informerFactory.Core().V1().Services(),
istioInformer.Networking().V1alpha3().VirtualServices(),
istioInformer.Networking().V1alpha3().DestinationRules(),
servicemeshinformer.Servicemesh().V1alpha2().Strategies(),
servicemeshInformer.Servicemesh().V1alpha2().Strategies(),
kubeClient,
istioclient,
servicemeshclient)
@@ -74,19 +86,30 @@ func AddControllers(mgr manager.Manager, cfg *rest.Config, stopCh <-chan struct{
drController := destinationrule.NewDestinationRuleController(informerFactory.Apps().V1().Deployments(),
istioInformer.Networking().V1alpha3().DestinationRules(),
informerFactory.Core().V1().Services(),
servicemeshinformer.Servicemesh().V1alpha2().ServicePolicies(),
servicemeshInformer.Servicemesh().V1alpha2().ServicePolicies(),
kubeClient,
istioclient)
apController := application.NewApplicationController(informerFactory.Core().V1().Services(),
informerFactory.Apps().V1().Deployments(),
informerFactory.Apps().V1().StatefulSets(),
servicemeshInformer.Servicemesh().V1alpha2().Strategies(),
servicemeshInformer.Servicemesh().V1alpha2().ServicePolicies(),
applicationInformer.App().V1beta1().Applications(),
kubeClient,
applicationClient)
jobController := job.NewJobController(informerFactory.Batch().V1().Jobs(), kubeClient)
servicemeshinformer.Start(stopCh)
servicemeshInformer.Start(stopCh)
istioInformer.Start(stopCh)
informerFactory.Start(stopCh)
applicationInformer.Start(stopCh)
controllers := map[string]manager.Runnable{
"virtualservice-controller": vsController,
"destinationrule-controller": drController,
"application-controller": apController,
"job-controller": jobController,
}

View File

@@ -32,6 +32,7 @@ import (
"kubesphere.io/kubesphere/pkg/informers"
"kubesphere.io/kubesphere/pkg/models/devops"
logging "kubesphere.io/kubesphere/pkg/models/log"
"kubesphere.io/kubesphere/pkg/server"
"kubesphere.io/kubesphere/pkg/signals"
"kubesphere.io/kubesphere/pkg/simple/client/admin_jenkins"
"kubesphere.io/kubesphere/pkg/simple/client/devops_mysql"
@@ -74,7 +75,7 @@ func Run(s *options.ServerRunOptions) error {
container := runtime.Container
container.DoNotRecover(false)
container.Filter(filter.Logging)
container.RecoverHandler(server.LogStackOnRecover)
for _, webservice := range container.RegisteredWebServices() {
for _, route := range webservice.Routes() {
log.Println(route.Method, route.Path)

View File

@@ -28,6 +28,7 @@ import (
"kubesphere.io/kubesphere/pkg/filter"
"kubesphere.io/kubesphere/pkg/informers"
"kubesphere.io/kubesphere/pkg/models/iam"
"kubesphere.io/kubesphere/pkg/server"
"kubesphere.io/kubesphere/pkg/signals"
"kubesphere.io/kubesphere/pkg/simple/client/admin_jenkins"
"kubesphere.io/kubesphere/pkg/simple/client/devops_mysql"
@@ -84,6 +85,7 @@ func Run(s *options.ServerRunOptions) error {
container := runtime.Container
container.Filter(filter.Logging)
container.DoNotRecover(false)
container.RecoverHandler(server.LogStackOnRecover)
for _, webservice := range container.RegisteredWebServices() {
for _, route := range webservice.Routes() {

View File

@@ -18,11 +18,7 @@
package controller
import (
"kubesphere.io/kubesphere/pkg/controller/clusterrolebinding"
)
func init() {
// AddToManagerFuncs is a list of functions to create controllers and add them to a manager.
AddToManagerFuncs = append(AddToManagerFuncs, clusterrolebinding.Add)
//AddToManagerFuncs = append(AddToManagerFuncs, clusterrolebinding.Add)
}

View File

@@ -18,9 +18,7 @@
package controller
import (
"kubesphere.io/kubesphere/pkg/controller/namespace"
)
import "kubesphere.io/kubesphere/pkg/controller/namespace"
func init() {
// AddToManagerFuncs is a list of functions to create controllers and add them to a manager.

View File

@@ -21,9 +21,6 @@ import (
)
func init() {
// AddToManagerFuncs is a list of functions to create controllers and add them to a manager.
//AddToManagerFuncs = append(AddToManagerFuncs, strategy.Add)
// Add application to manager functions
AddToManagerFuncs = append(AddToManagerFuncs, application.Add)

View File

@@ -0,0 +1,262 @@
package application
import (
"fmt"
applicationclient "github.com/kubernetes-sigs/application/pkg/client/clientset/versioned"
applicationinformers "github.com/kubernetes-sigs/application/pkg/client/informers/externalversions/app/v1beta1"
applicationlister "github.com/kubernetes-sigs/application/pkg/client/listers/app/v1beta1"
v1 "k8s.io/api/core/v1"
"k8s.io/apimachinery/pkg/api/errors"
metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
utilruntime "k8s.io/apimachinery/pkg/util/runtime"
"k8s.io/apimachinery/pkg/util/wait"
informersv1 "k8s.io/client-go/informers/apps/v1"
coreinformers "k8s.io/client-go/informers/core/v1"
clientset "k8s.io/client-go/kubernetes"
"k8s.io/client-go/kubernetes/scheme"
corev1 "k8s.io/client-go/kubernetes/typed/core/v1"
listersv1 "k8s.io/client-go/listers/apps/v1"
corelisters "k8s.io/client-go/listers/core/v1"
"k8s.io/client-go/tools/cache"
"k8s.io/client-go/tools/record"
"k8s.io/client-go/util/workqueue"
"k8s.io/kubernetes/pkg/controller"
"k8s.io/kubernetes/pkg/util/metrics"
servicemeshinformers "kubesphere.io/kubesphere/pkg/client/informers/externalversions/servicemesh/v1alpha2"
servicemeshlisters "kubesphere.io/kubesphere/pkg/client/listers/servicemesh/v1alpha2"
"kubesphere.io/kubesphere/pkg/controller/virtualservice/util"
logf "sigs.k8s.io/controller-runtime/pkg/runtime/log"
"time"
)
const (
// maxRetries is the number of times a service will be retried before it is dropped out of the queue.
// With the current rate-limiter in use (5ms*2^(maxRetries-1)) the following numbers represent the
// sequence of delays between successive queuings of a service.
//
// 5ms, 10ms, 20ms, 40ms, 80ms, 160ms, 320ms, 640ms, 1.3s, 2.6s, 5.1s, 10.2s, 20.4s, 41s, 82s
maxRetries = 15
)
var log = logf.Log.WithName("application-controller")
type ApplicationController struct {
client clientset.Interface
applicationClient applicationclient.Interface
eventBroadcaster record.EventBroadcaster
eventRecorder record.EventRecorder
applicationLister applicationlister.ApplicationLister
applicationSynced cache.InformerSynced
serviceLister corelisters.ServiceLister
serviceSynced cache.InformerSynced
deploymentLister listersv1.DeploymentLister
deploymentSynced cache.InformerSynced
statefulSetLister listersv1.StatefulSetLister
statefulSetSynced cache.InformerSynced
strategyLister servicemeshlisters.StrategyLister
strategySynced cache.InformerSynced
servicePolicyLister servicemeshlisters.ServicePolicyLister
servicePolicySynced cache.InformerSynced
queue workqueue.RateLimitingInterface
workerLoopPeriod time.Duration
}
func NewApplicationController(serviceInformer coreinformers.ServiceInformer,
deploymentInformer informersv1.DeploymentInformer,
statefulSetInformer informersv1.StatefulSetInformer,
strategyInformer servicemeshinformers.StrategyInformer,
servicePolicyInformer servicemeshinformers.ServicePolicyInformer,
applicationInformer applicationinformers.ApplicationInformer,
client clientset.Interface,
applicationClient applicationclient.Interface) *ApplicationController {
broadcaster := record.NewBroadcaster()
broadcaster.StartLogging(func(format string, args ...interface{}) {
log.Info(fmt.Sprintf(format, args))
})
broadcaster.StartRecordingToSink(&corev1.EventSinkImpl{Interface: client.CoreV1().Events("")})
recorder := broadcaster.NewRecorder(scheme.Scheme, v1.EventSource{Component: "application-controller"})
if client != nil && client.CoreV1().RESTClient().GetRateLimiter() != nil {
metrics.RegisterMetricAndTrackRateLimiterUsage("virtualservice_controller", client.CoreV1().RESTClient().GetRateLimiter())
}
v := &ApplicationController{
client: client,
applicationClient: applicationClient,
queue: workqueue.NewNamedRateLimitingQueue(workqueue.DefaultControllerRateLimiter(), "application"),
workerLoopPeriod: time.Second,
}
v.deploymentLister = deploymentInformer.Lister()
v.deploymentSynced = deploymentInformer.Informer().HasSynced
deploymentInformer.Informer().AddEventHandler(cache.ResourceEventHandlerFuncs{
AddFunc: v.enqueueObject,
DeleteFunc: v.enqueueObject,
})
v.statefulSetLister = statefulSetInformer.Lister()
v.statefulSetSynced = statefulSetInformer.Informer().HasSynced
statefulSetInformer.Informer().AddEventHandler(cache.ResourceEventHandlerFuncs{
AddFunc: v.enqueueObject,
DeleteFunc: v.enqueueObject,
})
v.serviceLister = serviceInformer.Lister()
v.serviceSynced = serviceInformer.Informer().HasSynced
serviceInformer.Informer().AddEventHandler(cache.ResourceEventHandlerFuncs{
AddFunc: v.enqueueObject,
DeleteFunc: v.enqueueObject,
})
v.strategyLister = strategyInformer.Lister()
v.strategySynced = strategyInformer.Informer().HasSynced
strategyInformer.Informer().AddEventHandler(cache.ResourceEventHandlerFuncs{
AddFunc: v.enqueueObject,
DeleteFunc: v.enqueueObject,
})
v.servicePolicyLister = servicePolicyInformer.Lister()
v.servicePolicySynced = servicePolicyInformer.Informer().HasSynced
servicePolicyInformer.Informer().AddEventHandler(cache.ResourceEventHandlerFuncs{
AddFunc: v.enqueueObject,
DeleteFunc: v.enqueueObject,
})
v.applicationLister = applicationInformer.Lister()
v.applicationSynced = applicationInformer.Informer().HasSynced
v.eventBroadcaster = broadcaster
v.eventRecorder = recorder
return v
}
func (v *ApplicationController) Start(stopCh <-chan struct{}) error {
v.Run(5, stopCh)
return nil
}
func (v *ApplicationController) Run(workers int, stopCh <-chan struct{}) {
defer utilruntime.HandleCrash()
defer v.queue.ShutDown()
log.Info("starting application controller")
defer log.Info("shutting down application controller")
if !controller.WaitForCacheSync("application-controller", stopCh, v.deploymentSynced, v.statefulSetSynced, v.serviceSynced, v.strategySynced, v.servicePolicySynced, v.applicationSynced) {
return
}
for i := 0; i < workers; i++ {
go wait.Until(v.worker, v.workerLoopPeriod, stopCh)
}
<-stopCh
}
func (v *ApplicationController) worker() {
for v.processNextWorkItem() {
}
}
func (v *ApplicationController) processNextWorkItem() bool {
eKey, quit := v.queue.Get()
if quit {
return false
}
defer v.queue.Done(eKey)
err := v.syncApplication(eKey.(string))
v.handleErr(err, eKey)
return true
}
func (v *ApplicationController) syncApplication(key string) error {
startTime := time.Now()
namespace, name, err := cache.SplitMetaNamespaceKey(key)
if err != nil {
log.Error(err, "not a valid controller key", "key", key)
return err
}
defer func() {
log.V(4).Info("Finished updating application.", "namespace", namespace, "name", name, "duration", time.Since(startTime))
}()
application, err := v.applicationLister.Applications(namespace).Get(name)
if err != nil {
if errors.IsNotFound(err) {
// application has been deleted
return nil
}
log.Error(err, "get application failed")
}
annotations := application.GetAnnotations()
annotations["kubesphere.io/last-updated"] = time.Now().String()
application.SetAnnotations(annotations)
_, err = v.applicationClient.AppV1beta1().Applications(namespace).Update(application)
if err != nil {
if errors.IsNotFound(err) {
log.V(4).Info("application has been deleted during update")
return nil
}
log.Error(err, "failed to update application", "namespace", namespace, "name", name)
return err
}
return nil
}
func (v *ApplicationController) enqueueObject(obj interface{}) {
var resource = obj.(metav1.Object)
if resource.GetLabels() == nil || !util.IsApplicationComponent(resource.GetLabels()) {
return
}
applicationName := util.GetApplictionName(resource.GetLabels())
if len(applicationName) > 0 {
key := resource.GetNamespace() + "/" + applicationName
v.queue.Add(key)
}
}
func (v *ApplicationController) handleErr(err error, key interface{}) {
if err != nil {
v.queue.Forget(key)
return
}
if v.queue.NumRequeues(key) < maxRetries {
log.V(2).Info("Error syncing virtualservice for service retrying.", "key", key, "error", err)
v.queue.AddRateLimited(key)
return
}
log.V(4).Info("Dropping service out of the queue.", "key", key, "error", err)
v.queue.Forget(key)
utilruntime.HandleError(err)
}

View File

@@ -0,0 +1 @@
package application

View File

@@ -332,9 +332,9 @@ func (v *DestinationRuleController) syncService(key string) error {
}
if createDestinationRule {
_, err = v.destinationRuleClient.NetworkingV1alpha3().DestinationRules(namespace).Create(newDestinationRule)
newDestinationRule, err = v.destinationRuleClient.NetworkingV1alpha3().DestinationRules(namespace).Create(newDestinationRule)
} else {
_, err = v.destinationRuleClient.NetworkingV1alpha3().DestinationRules(namespace).Update(newDestinationRule)
newDestinationRule, err = v.destinationRuleClient.NetworkingV1alpha3().DestinationRules(namespace).Update(newDestinationRule)
}
if err != nil {

View File

@@ -18,6 +18,7 @@ const (
// resource with these following labels considered as part of servicemesh
var ApplicationLabels = [...]string{
ApplicationNameLabel,
ApplicationVersionLabel,
AppLabel,
}
@@ -32,6 +33,14 @@ func NormalizeVersionName(version string) string {
return version
}
func GetApplictionName(lbs map[string]string) string {
if name, ok := lbs[ApplicationNameLabel]; ok {
return name
}
return ""
}
func GetComponentName(meta *metav1.ObjectMeta) string {
if len(meta.Labels[AppLabel]) > 0 {
return meta.Labels[AppLabel]

View File

@@ -387,9 +387,9 @@ func (v *VirtualServiceController) syncService(key string) error {
}
if createVirtualService {
_, err = v.virtualServiceClient.NetworkingV1alpha3().VirtualServices(namespace).Create(newVirtualService)
newVirtualService, err = v.virtualServiceClient.NetworkingV1alpha3().VirtualServices(namespace).Create(newVirtualService)
} else {
_, err = v.virtualServiceClient.NetworkingV1alpha3().VirtualServices(namespace).Update(newVirtualService)
newVirtualService, err = v.virtualServiceClient.NetworkingV1alpha3().VirtualServices(namespace).Update(newVirtualService)
}
if err != nil {

View File

@@ -19,9 +19,11 @@
package routers
import (
"fmt"
"github.com/golang/glog"
"io/ioutil"
"k8s.io/apimachinery/pkg/api/errors"
"k8s.io/apimachinery/pkg/runtime"
"kubesphere.io/kubesphere/pkg/simple/client/k8s"
"sort"
@@ -48,6 +50,37 @@ const (
SIDECAR_INJECT = "sidecar.istio.io/inject"
)
var routerTemplates map[string]runtime.Object
// Load yamls
func init() {
yamls, err := LoadYamls()
routerTemplates = make(map[string]runtime.Object, 2)
if err != nil {
glog.Error(err)
return
}
for _, f := range yamls {
decode := scheme.Codecs.UniversalDeserializer().Decode
obj, _, err := decode([]byte(f), nil, nil)
if err != nil {
glog.Error(err)
continue
}
switch obj.(type) {
case *corev1.Service:
routerTemplates["SERVICE"] = obj
case *extensionsv1beta1.Deployment:
routerTemplates["DEPLOYMENT"] = obj
}
}
}
// get master node ip, if there are multiple master nodes,
// choose first one according by their names alphabetically
func getMasterNodeIp() string {
@@ -115,6 +148,12 @@ func GetAllRouters() ([]*corev1.Service, error) {
// Get router from a namespace
func GetRouter(namespace string) (*corev1.Service, error) {
service, err := getRouterService(namespace)
addLoadBalancerIp(service)
return service, err
}
func getRouterService(namespace string) (*corev1.Service, error) {
serviceName := constants.IngressControllerPrefix + namespace
serviceLister := informers.SharedInformerFactory().Core().V1().Services().Lister()
@@ -127,8 +166,6 @@ func GetRouter(namespace string) (*corev1.Service, error) {
glog.Error(err)
return nil, err
}
addLoadBalancerIp(service)
return service, nil
}
@@ -163,12 +200,6 @@ func LoadYamls() ([]string, error) {
// Create a ingress controller in a namespace
func CreateRouter(namespace string, routerType corev1.ServiceType, annotations map[string]string) (*corev1.Service, error) {
k8sClient := k8s.Client()
var router *corev1.Service
yamls, err := LoadYamls()
injectSidecar := false
if enabled, ok := annotations[SERVICEMESH_ENABLED]; ok {
if enabled == "true" {
@@ -176,74 +207,17 @@ func CreateRouter(namespace string, routerType corev1.ServiceType, annotations m
}
}
err := createOrUpdateRouterWorkload(namespace, routerType == corev1.ServiceTypeLoadBalancer, injectSidecar)
if err != nil {
glog.Error(err)
return nil, err
}
for _, f := range yamls {
decode := scheme.Codecs.UniversalDeserializer().Decode
obj, _, err := decode([]byte(f), nil, nil)
if err != nil {
glog.Error(err)
return router, err
}
switch obj.(type) {
case *corev1.Service:
service := obj.(*corev1.Service)
service.SetAnnotations(annotations)
service.Spec.Type = routerType
service.Name = constants.IngressControllerPrefix + namespace
// Add project selector
service.Labels["project"] = namespace
service.Spec.Selector["project"] = namespace
service, err := k8sClient.CoreV1().Services(constants.IngressControllerNamespace).Create(service)
if err != nil {
glog.Error(err)
return nil, err
}
router = service
case *extensionsv1beta1.Deployment:
deployment := obj.(*extensionsv1beta1.Deployment)
deployment.Name = constants.IngressControllerPrefix + namespace
// Add project label
deployment.Spec.Selector.MatchLabels["project"] = namespace
deployment.Spec.Template.Labels["project"] = namespace
if injectSidecar {
if deployment.Spec.Template.Annotations == nil {
deployment.Spec.Template.Annotations = make(map[string]string, 0)
}
deployment.Spec.Template.Annotations[SIDECAR_INJECT] = "true"
}
// Isolate namespace
deployment.Spec.Template.Spec.Containers[0].Args = append(deployment.Spec.Template.Spec.Containers[0].Args, "--watch-namespace="+namespace)
// Choose self as master
deployment.Spec.Template.Spec.Containers[0].Args = append(deployment.Spec.Template.Spec.Containers[0].Args, "--election-id="+deployment.Name)
if routerType == corev1.ServiceTypeLoadBalancer {
deployment.Spec.Template.Spec.Containers[0].Args = append(deployment.Spec.Template.Spec.Containers[0].Args, "--publish-service="+constants.IngressControllerNamespace+"/"+constants.IngressControllerPrefix+namespace)
} else {
deployment.Spec.Template.Spec.Containers[0].Args = append(deployment.Spec.Template.Spec.Containers[0].Args, "--report-node-internal-ip-address")
}
deployment, err := k8sClient.ExtensionsV1beta1().Deployments(constants.IngressControllerNamespace).Create(deployment)
if err != nil {
glog.Error(err)
}
default:
//glog.Info("Default resource")
}
router, err := createRouterService(namespace, routerType, annotations)
if err != nil {
glog.Error(err)
_ = deleteRouterWorkload(namespace)
return nil, err
}
addLoadBalancerIp(router)
@@ -253,10 +227,85 @@ func CreateRouter(namespace string, routerType corev1.ServiceType, annotations m
// DeleteRouter is used to delete ingress controller related resources in namespace
// It will not delete ClusterRole resource cause it maybe used by other controllers
func DeleteRouter(namespace string) (*corev1.Service, error) {
err := deleteRouterWorkload(namespace)
if err != nil {
glog.Error(err)
}
router, err := deleteRouterService(namespace)
if err != nil {
glog.Error(err)
return router, err
}
return router, nil
}
func createRouterService(namespace string, routerType corev1.ServiceType, annotations map[string]string) (*corev1.Service, error) {
obj, ok := routerTemplates["SERVICE"]
if !ok {
glog.Error("service template not loaded")
return nil, fmt.Errorf("service template not loaded")
}
k8sClient := k8s.Client()
var err error
var router *corev1.Service
service := obj.(*corev1.Service)
service.SetAnnotations(annotations)
service.Spec.Type = routerType
service.Name = constants.IngressControllerPrefix + namespace
// Add project selector
service.Labels["project"] = namespace
service.Spec.Selector["project"] = namespace
service, err := k8sClient.CoreV1().Services(constants.IngressControllerNamespace).Create(service)
if err != nil {
glog.Error(err)
return nil, err
}
return service, nil
}
func updateRouterService(namespace string, routerType corev1.ServiceType, annotations map[string]string) (*corev1.Service, error) {
k8sClient := k8s.Client()
service, err := getRouterService(namespace)
if err != nil {
glog.Error(err, "get router failed")
return service, err
}
service.Spec.Type = routerType
originalAnnotations := service.GetAnnotations()
for key, val := range annotations {
originalAnnotations[key] = val
}
service.SetAnnotations(originalAnnotations)
service, err = k8sClient.CoreV1().Services(constants.IngressControllerNamespace).Update(service)
return service, err
}
func deleteRouterService(namespace string) (*corev1.Service, error) {
service, err := getRouterService(namespace)
if err != nil {
glog.Error(err)
return service, err
}
k8sClient := k8s.Client()
// delete controller service
serviceName := constants.IngressControllerPrefix + namespace
@@ -265,11 +314,98 @@ func DeleteRouter(namespace string) (*corev1.Service, error) {
err = k8sClient.CoreV1().Services(constants.IngressControllerNamespace).Delete(serviceName, &deleteOptions)
if err != nil {
glog.Error(err)
return service, err
}
return service, nil
}
func createOrUpdateRouterWorkload(namespace string, publishService bool, servicemeshEnabled bool) error {
obj, ok := routerTemplates["DEPLOYMENT"]
if !ok {
glog.Error("Deployment template file not loaded")
return fmt.Errorf("deployment template file not loaded")
}
deployName := constants.IngressControllerPrefix + namespace
k8sClient := k8s.Client()
deployment, err := k8sClient.ExtensionsV1beta1().Deployments(constants.IngressControllerNamespace).Get(deployName, meta_v1.GetOptions{})
createDeployment := true
if err != nil {
if errors.IsNotFound(err) {
deployment = obj.(*extensionsv1beta1.Deployment)
deployment.Name = constants.IngressControllerPrefix + namespace
// Add project label
deployment.Spec.Selector.MatchLabels["project"] = namespace
deployment.Spec.Template.Labels["project"] = namespace
// Isolate namespace
deployment.Spec.Template.Spec.Containers[0].Args = append(deployment.Spec.Template.Spec.Containers[0].Args, "--watch-namespace="+namespace)
// Choose self as master
deployment.Spec.Template.Spec.Containers[0].Args = append(deployment.Spec.Template.Spec.Containers[0].Args, "--election-id="+deployment.Name)
}
} else {
createDeployment = false
for i := range deployment.Spec.Template.Spec.Containers {
if deployment.Spec.Template.Spec.Containers[i].Name == "nginx-ingress-controller" {
var args []string
for j := range deployment.Spec.Template.Spec.Containers[i].Args {
if strings.HasPrefix("--publish-service", deployment.Spec.Template.Spec.Containers[i].Args[j]) ||
strings.HasPrefix("--report-node-internal-ip-address", deployment.Spec.Template.Spec.Containers[i].Args[j]) {
continue
}
args = append(args, deployment.Spec.Template.Spec.Containers[i].Args[j])
}
deployment.Spec.Template.Spec.Containers[i].Args = args
}
}
}
if deployment.Spec.Template.Annotations == nil {
deployment.Spec.Template.Annotations = make(map[string]string, 0)
}
if servicemeshEnabled {
deployment.Spec.Template.Annotations[SIDECAR_INJECT] = "true"
} else {
deployment.Spec.Template.Annotations[SIDECAR_INJECT] = "false"
}
if publishService {
deployment.Spec.Template.Spec.Containers[0].Args = append(deployment.Spec.Template.Spec.Containers[0].Args, "--publish-service="+constants.IngressControllerNamespace+"/"+constants.IngressControllerPrefix+namespace)
} else {
deployment.Spec.Template.Spec.Containers[0].Args = append(deployment.Spec.Template.Spec.Containers[0].Args, "--report-node-internal-ip-address")
}
if createDeployment {
deployment, err = k8sClient.ExtensionsV1beta1().Deployments(constants.IngressControllerNamespace).Create(deployment)
} else {
deployment, err = k8sClient.ExtensionsV1beta1().Deployments(constants.IngressControllerNamespace).Update(deployment)
}
if err != nil {
glog.Error(err)
return err
}
return nil
}
func deleteRouterWorkload(namespace string) error {
k8sClient := k8s.Client()
deleteOptions := meta_v1.DeleteOptions{}
// delete controller deployment
deploymentName := constants.IngressControllerPrefix + namespace
err = k8sClient.ExtensionsV1beta1().Deployments(constants.IngressControllerNamespace).Delete(deploymentName, &deleteOptions)
err := k8sClient.ExtensionsV1beta1().Deployments(constants.IngressControllerNamespace).Delete(deploymentName, &deleteOptions)
if err != nil {
glog.Error(err)
}
@@ -280,45 +416,50 @@ func DeleteRouter(namespace string) (*corev1.Service, error) {
"app": "kubesphere",
"component": "ks-router",
"tier": "backend",
"project": deploymentName,
"project": namespace,
})
replicaSetLister := informers.SharedInformerFactory().Apps().V1().ReplicaSets().Lister()
replicaSets, err := replicaSetLister.ReplicaSets(constants.IngressControllerNamespace).List(selector)
if err == nil {
if err != nil {
glog.Error(err)
}
for i := range replicaSets {
err = k8sClient.AppsV1().ReplicaSets(constants.IngressControllerNamespace).Delete(replicaSets[i].Name, &deleteOptions)
glog.Error(err)
if err != nil {
glog.Error(err)
}
}
return router, nil
return nil
}
// Update Ingress Controller Service, change type from NodePort to Loadbalancer or vice versa.
// Update Ingress Controller Service, change type from NodePort to loadbalancer or vice versa.
func UpdateRouter(namespace string, routerType corev1.ServiceType, annotations map[string]string) (*corev1.Service, error) {
var router *corev1.Service
router, err := GetRouter(namespace)
router, err := getRouterService(namespace)
if err != nil {
glog.Error(err)
return router, nil
return router, err
}
router, err = DeleteRouter(namespace)
enableServicemesh := annotations[SERVICEMESH_ENABLED] == "true"
err = createOrUpdateRouterWorkload(namespace, routerType == corev1.ServiceTypeLoadBalancer, enableServicemesh)
if err != nil {
glog.Error(err)
return router, err
}
newRouter, err := updateRouterService(namespace, routerType, annotations)
if err != nil {
glog.Error(err)
return newRouter, err
}
router, err = CreateRouter(namespace, routerType, annotations)
if err != nil {
glog.Error(err)
}
return router, nil
return newRouter, nil
}

24
pkg/server/server.go Normal file
View File

@@ -0,0 +1,24 @@
package server
import (
"bytes"
"fmt"
"github.com/golang/glog"
"net/http"
"runtime"
)
func LogStackOnRecover(panicReason interface{}, httpWriter http.ResponseWriter) {
var buffer bytes.Buffer
buffer.WriteString(fmt.Sprintf("recover from panic situation: - %v\r\n", panicReason))
for i := 2; ; i += 1 {
_, file, line, ok := runtime.Caller(i)
if !ok {
break
}
buffer.WriteString(fmt.Sprintf(" %s:%d\r\n", file, line))
}
glog.Error(buffer.String())
httpWriter.WriteHeader(http.StatusInternalServerError)
httpWriter.Write([]byte("recover from panic situation"))
}

15
vendor/cloud.google.com/go/AUTHORS generated vendored
View File

@@ -1,15 +0,0 @@
# This is the official list of cloud authors for copyright purposes.
# This file is distinct from the CONTRIBUTORS files.
# See the latter for an explanation.
# Names should be added to this file as:
# Name or Organization <email address>
# The email address is not required for organizations.
Filippo Valsorda <hi@filippo.io>
Google Inc.
Ingo Oeser <nightlyone@googlemail.com>
Palm Stone Games, Inc.
Paweł Knap <pawelknap88@gmail.com>
Péter Szilágyi <peterke@gmail.com>
Tyler Treat <ttreat31@gmail.com>

View File

@@ -1,40 +0,0 @@
# People who have agreed to one of the CLAs and can contribute patches.
# The AUTHORS file lists the copyright holders; this file
# lists people. For example, Google employees are listed here
# but not in AUTHORS, because Google holds the copyright.
#
# https://developers.google.com/open-source/cla/individual
# https://developers.google.com/open-source/cla/corporate
#
# Names should be added to this file as:
# Name <email address>
# Keep the list alphabetically sorted.
Alexis Hunt <lexer@google.com>
Andreas Litt <andreas.litt@gmail.com>
Andrew Gerrand <adg@golang.org>
Brad Fitzpatrick <bradfitz@golang.org>
Burcu Dogan <jbd@google.com>
Dave Day <djd@golang.org>
David Sansome <me@davidsansome.com>
David Symonds <dsymonds@golang.org>
Filippo Valsorda <hi@filippo.io>
Glenn Lewis <gmlewis@google.com>
Ingo Oeser <nightlyone@googlemail.com>
James Hall <james.hall@shopify.com>
Johan Euphrosine <proppy@google.com>
Jonathan Amsterdam <jba@google.com>
Kunpei Sakai <namusyaka@gmail.com>
Luna Duclos <luna.duclos@palmstonegames.com>
Magnus Hiie <magnus.hiie@gmail.com>
Mario Castro <mariocaster@gmail.com>
Michael McGreevy <mcgreevy@golang.org>
Omar Jarjur <ojarjur@google.com>
Paweł Knap <pawelknap88@gmail.com>
Péter Szilágyi <peterke@gmail.com>
Sarah Adams <shadams@google.com>
Thanatat Tamtan <acoshift@gmail.com>
Toby Burress <kurin@google.com>
Tuo Shan <shantuo@google.com>
Tyler Treat <ttreat31@gmail.com>

View File

@@ -103,8 +103,7 @@ func (entry *Entry) WithError(err error) *Entry {
// Add a context to the Entry.
func (entry *Entry) WithContext(ctx context.Context) *Entry {
entry.Context = ctx
return entry
return &Entry{Logger: entry.Logger, Data: entry.Data, Time: entry.Time, err: entry.err, Context: ctx}
}
// Add a single field to the Entry.

View File

@@ -1,9 +0,0 @@
// +build !appengine,!js,!windows,aix
package logrus
import "io"
func checkIfTerminal(w io.Writer) bool {
return false
}

View File

@@ -0,0 +1,13 @@
// +build darwin dragonfly freebsd netbsd openbsd
package logrus
import "golang.org/x/sys/unix"
const ioctlReadTermios = unix.TIOCGETA
func isTerminal(fd int) bool {
_, err := unix.IoctlGetTermios(fd, ioctlReadTermios)
return err == nil
}

View File

@@ -1,18 +1,16 @@
// +build !appengine,!js,!windows,!aix
// +build !appengine,!js,!windows
package logrus
import (
"io"
"os"
"golang.org/x/crypto/ssh/terminal"
)
func checkIfTerminal(w io.Writer) bool {
switch v := w.(type) {
case *os.File:
return terminal.IsTerminal(int(v.Fd()))
return isTerminal(int(v.Fd()))
default:
return false
}

View File

@@ -0,0 +1,13 @@
// +build linux aix
package logrus
import "golang.org/x/sys/unix"
const ioctlReadTermios = unix.TCGETS
func isTerminal(fd int) bool {
_, err := unix.IoctlGetTermios(fd, ioctlReadTermios)
return err == nil
}

View File

@@ -73,9 +73,9 @@ type TextFormatter struct {
FieldMap FieldMap
// CallerPrettyfier can be set by the user to modify the content
// of the function and file keys in the json data when ReportCaller is
// of the function and file keys in the data when ReportCaller is
// activated. If any of the returned value is the empty string the
// corresponding key will be removed from json fields.
// corresponding key will be removed from fields.
CallerPrettyfier func(*runtime.Frame) (function string, file string)
terminalInitOnce sync.Once
@@ -133,14 +133,19 @@ func (f *TextFormatter) Format(entry *Entry) ([]byte, error) {
fixedKeys = append(fixedKeys, f.FieldMap.resolve(FieldKeyLogrusError))
}
if entry.HasCaller() {
fixedKeys = append(fixedKeys,
f.FieldMap.resolve(FieldKeyFunc), f.FieldMap.resolve(FieldKeyFile))
if f.CallerPrettyfier != nil {
funcVal, fileVal = f.CallerPrettyfier(entry.Caller)
} else {
funcVal = entry.Caller.Function
fileVal = fmt.Sprintf("%s:%d", entry.Caller.File, entry.Caller.Line)
}
if funcVal != "" {
fixedKeys = append(fixedKeys, f.FieldMap.resolve(FieldKeyFunc))
}
if fileVal != "" {
fixedKeys = append(fixedKeys, f.FieldMap.resolve(FieldKeyFile))
}
}
if !f.DisableSorting {
@@ -225,7 +230,6 @@ func (f *TextFormatter) printColored(b *bytes.Buffer, entry *Entry, keys []strin
entry.Message = strings.TrimSuffix(entry.Message, "\n")
caller := ""
if entry.HasCaller() {
funcVal := fmt.Sprintf("%s()", entry.Caller.Function)
fileVal := fmt.Sprintf("%s:%d", entry.Caller.File, entry.Caller.Line)
@@ -233,7 +237,14 @@ func (f *TextFormatter) printColored(b *bytes.Buffer, entry *Entry, keys []strin
if f.CallerPrettyfier != nil {
funcVal, fileVal = f.CallerPrettyfier(entry.Caller)
}
caller = fileVal + " " + funcVal
if fileVal == "" {
caller = funcVal
} else if funcVal == "" {
caller = fileVal
} else {
caller = fileVal + " " + funcVal
}
}
if f.DisableTimestamp {

View File

@@ -1,21 +0,0 @@
The MIT License (MIT)
Copyright (c) 2016 Andreas Auernhammer
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.

View File

@@ -1,197 +0,0 @@
// Copyright (c) 2016 Andreas Auernhammer. All rights reserved.
// Use of this source code is governed by a license that can be
// found in the LICENSE file.
// Package chacha implements some low-level functions of the
// ChaCha cipher family.
package chacha // import "github.com/aead/chacha20/chacha"
import (
"encoding/binary"
"errors"
"math"
)
const (
// NonceSize is the size of the ChaCha20 nonce in bytes.
NonceSize = 8
// INonceSize is the size of the IETF-ChaCha20 nonce in bytes.
INonceSize = 12
// XNonceSize is the size of the XChaCha20 nonce in bytes.
XNonceSize = 24
// KeySize is the size of the key in bytes.
KeySize = 32
)
var (
useSSE2 bool
useSSSE3 bool
useAVX bool
useAVX2 bool
)
var (
errKeySize = errors.New("chacha20/chacha: bad key length")
errInvalidNonce = errors.New("chacha20/chacha: bad nonce length")
)
func setup(state *[64]byte, nonce, key []byte) (err error) {
if len(key) != KeySize {
err = errKeySize
return
}
var Nonce [16]byte
switch len(nonce) {
case NonceSize:
copy(Nonce[8:], nonce)
initialize(state, key, &Nonce)
case INonceSize:
copy(Nonce[4:], nonce)
initialize(state, key, &Nonce)
case XNonceSize:
var tmpKey [32]byte
var hNonce [16]byte
copy(hNonce[:], nonce[:16])
copy(tmpKey[:], key)
HChaCha20(&tmpKey, &hNonce, &tmpKey)
copy(Nonce[8:], nonce[16:])
initialize(state, tmpKey[:], &Nonce)
// BUG(aead): A "good" compiler will remove this (optimizations)
// But using the provided key instead of tmpKey,
// will change the key (-> probably confuses users)
for i := range tmpKey {
tmpKey[i] = 0
}
default:
err = errInvalidNonce
}
return
}
// XORKeyStream crypts bytes from src to dst using the given nonce and key.
// The length of the nonce determinds the version of ChaCha20:
// - NonceSize: ChaCha20/r with a 64 bit nonce and a 2^64 * 64 byte period.
// - INonceSize: ChaCha20/r as defined in RFC 7539 and a 2^32 * 64 byte period.
// - XNonceSize: XChaCha20/r with a 192 bit nonce and a 2^64 * 64 byte period.
// The rounds argument specifies the number of rounds performed for keystream
// generation - valid values are 8, 12 or 20. The src and dst may be the same slice
// but otherwise should not overlap. If len(dst) < len(src) this function panics.
// If the nonce is neither 64, 96 nor 192 bits long, this function panics.
func XORKeyStream(dst, src, nonce, key []byte, rounds int) {
if rounds != 20 && rounds != 12 && rounds != 8 {
panic("chacha20/chacha: bad number of rounds")
}
if len(dst) < len(src) {
panic("chacha20/chacha: dst buffer is to small")
}
if len(nonce) == INonceSize && uint64(len(src)) > (1<<38) {
panic("chacha20/chacha: src is too large")
}
var block, state [64]byte
if err := setup(&state, nonce, key); err != nil {
panic(err)
}
xorKeyStream(dst, src, &block, &state, rounds)
}
// Cipher implements ChaCha20/r (XChaCha20/r) for a given number of rounds r.
type Cipher struct {
state, block [64]byte
off int
rounds int // 20 for ChaCha20
noncesize int
}
// NewCipher returns a new *chacha.Cipher implementing the ChaCha20/r or XChaCha20/r
// (r = 8, 12 or 20) stream cipher. The nonce must be unique for one key for all time.
// The length of the nonce determinds the version of ChaCha20:
// - NonceSize: ChaCha20/r with a 64 bit nonce and a 2^64 * 64 byte period.
// - INonceSize: ChaCha20/r as defined in RFC 7539 and a 2^32 * 64 byte period.
// - XNonceSize: XChaCha20/r with a 192 bit nonce and a 2^64 * 64 byte period.
// If the nonce is neither 64, 96 nor 192 bits long, a non-nil error is returned.
func NewCipher(nonce, key []byte, rounds int) (*Cipher, error) {
if rounds != 20 && rounds != 12 && rounds != 8 {
panic("chacha20/chacha: bad number of rounds")
}
c := new(Cipher)
if err := setup(&(c.state), nonce, key); err != nil {
return nil, err
}
c.rounds = rounds
if len(nonce) == INonceSize {
c.noncesize = INonceSize
} else {
c.noncesize = NonceSize
}
return c, nil
}
// XORKeyStream crypts bytes from src to dst. Src and dst may be the same slice
// but otherwise should not overlap. If len(dst) < len(src) the function panics.
func (c *Cipher) XORKeyStream(dst, src []byte) {
if len(dst) < len(src) {
panic("chacha20/chacha: dst buffer is to small")
}
if c.off > 0 {
n := len(c.block[c.off:])
if len(src) <= n {
for i, v := range src {
dst[i] = v ^ c.block[c.off]
c.off++
}
if c.off == 64 {
c.off = 0
}
return
}
for i, v := range c.block[c.off:] {
dst[i] = src[i] ^ v
}
src = src[n:]
dst = dst[n:]
c.off = 0
}
// check for counter overflow
blocksToXOR := len(src) / 64
if len(src)%64 != 0 {
blocksToXOR++
}
var overflow bool
if c.noncesize == INonceSize {
overflow = binary.LittleEndian.Uint32(c.state[48:]) > math.MaxUint32-uint32(blocksToXOR)
} else {
overflow = binary.LittleEndian.Uint64(c.state[48:]) > math.MaxUint64-uint64(blocksToXOR)
}
if overflow {
panic("chacha20/chacha: counter overflow")
}
c.off += xorKeyStream(dst, src, &(c.block), &(c.state), c.rounds)
}
// SetCounter skips ctr * 64 byte blocks. SetCounter(0) resets the cipher.
// This function always skips the unused keystream of the current 64 byte block.
func (c *Cipher) SetCounter(ctr uint64) {
if c.noncesize == INonceSize {
binary.LittleEndian.PutUint32(c.state[48:], uint32(ctr))
} else {
binary.LittleEndian.PutUint64(c.state[48:], ctr)
}
c.off = 0
}
// HChaCha20 generates 32 pseudo-random bytes from a 128 bit nonce and a 256 bit secret key.
// It can be used as a key-derivation-function (KDF).
func HChaCha20(out *[32]byte, nonce *[16]byte, key *[32]byte) { hChaCha20(out, nonce, key) }

View File

@@ -1,406 +0,0 @@
// Copyright (c) 2016 Andreas Auernhammer. All rights reserved.
// Use of this source code is governed by a license that can be
// found in the LICENSE file.
// +build amd64,!gccgo,!appengine,!nacl
#include "const.s"
#include "macro.s"
#define TWO 0(SP)
#define C16 32(SP)
#define C8 64(SP)
#define STATE_0 96(SP)
#define STATE_1 128(SP)
#define STATE_2 160(SP)
#define STATE_3 192(SP)
#define TMP_0 224(SP)
#define TMP_1 256(SP)
// func xorKeyStreamAVX(dst, src []byte, block, state *[64]byte, rounds int) int
TEXT ·xorKeyStreamAVX2(SB), 4, $320-80
MOVQ dst_base+0(FP), DI
MOVQ src_base+24(FP), SI
MOVQ block+48(FP), BX
MOVQ state+56(FP), AX
MOVQ rounds+64(FP), DX
MOVQ src_len+32(FP), CX
MOVQ SP, R8
ADDQ $32, SP
ANDQ $-32, SP
VMOVDQU 0(AX), Y2
VMOVDQU 32(AX), Y3
VPERM2I128 $0x22, Y2, Y0, Y0
VPERM2I128 $0x33, Y2, Y1, Y1
VPERM2I128 $0x22, Y3, Y2, Y2
VPERM2I128 $0x33, Y3, Y3, Y3
TESTQ CX, CX
JZ done
VMOVDQU ·one_AVX2<>(SB), Y4
VPADDD Y4, Y3, Y3
VMOVDQA Y0, STATE_0
VMOVDQA Y1, STATE_1
VMOVDQA Y2, STATE_2
VMOVDQA Y3, STATE_3
VMOVDQU ·rol16_AVX2<>(SB), Y4
VMOVDQU ·rol8_AVX2<>(SB), Y5
VMOVDQU ·two_AVX2<>(SB), Y6
VMOVDQA Y4, Y14
VMOVDQA Y5, Y15
VMOVDQA Y4, C16
VMOVDQA Y5, C8
VMOVDQA Y6, TWO
CMPQ CX, $64
JBE between_0_and_64
CMPQ CX, $192
JBE between_64_and_192
CMPQ CX, $320
JBE between_192_and_320
CMPQ CX, $448
JBE between_320_and_448
at_least_512:
VMOVDQA Y0, Y4
VMOVDQA Y1, Y5
VMOVDQA Y2, Y6
VPADDQ TWO, Y3, Y7
VMOVDQA Y0, Y8
VMOVDQA Y1, Y9
VMOVDQA Y2, Y10
VPADDQ TWO, Y7, Y11
VMOVDQA Y0, Y12
VMOVDQA Y1, Y13
VMOVDQA Y2, Y14
VPADDQ TWO, Y11, Y15
MOVQ DX, R9
chacha_loop_512:
VMOVDQA Y8, TMP_0
CHACHA_QROUND_AVX(Y0, Y1, Y2, Y3, Y8, C16, C8)
CHACHA_QROUND_AVX(Y4, Y5, Y6, Y7, Y8, C16, C8)
VMOVDQA TMP_0, Y8
VMOVDQA Y0, TMP_0
CHACHA_QROUND_AVX(Y8, Y9, Y10, Y11, Y0, C16, C8)
CHACHA_QROUND_AVX(Y12, Y13, Y14, Y15, Y0, C16, C8)
CHACHA_SHUFFLE_AVX(Y1, Y2, Y3)
CHACHA_SHUFFLE_AVX(Y5, Y6, Y7)
CHACHA_SHUFFLE_AVX(Y9, Y10, Y11)
CHACHA_SHUFFLE_AVX(Y13, Y14, Y15)
CHACHA_QROUND_AVX(Y12, Y13, Y14, Y15, Y0, C16, C8)
CHACHA_QROUND_AVX(Y8, Y9, Y10, Y11, Y0, C16, C8)
VMOVDQA TMP_0, Y0
VMOVDQA Y8, TMP_0
CHACHA_QROUND_AVX(Y4, Y5, Y6, Y7, Y8, C16, C8)
CHACHA_QROUND_AVX(Y0, Y1, Y2, Y3, Y8, C16, C8)
VMOVDQA TMP_0, Y8
CHACHA_SHUFFLE_AVX(Y3, Y2, Y1)
CHACHA_SHUFFLE_AVX(Y7, Y6, Y5)
CHACHA_SHUFFLE_AVX(Y11, Y10, Y9)
CHACHA_SHUFFLE_AVX(Y15, Y14, Y13)
SUBQ $2, R9
JA chacha_loop_512
VMOVDQA Y12, TMP_0
VMOVDQA Y13, TMP_1
VPADDD STATE_0, Y0, Y0
VPADDD STATE_1, Y1, Y1
VPADDD STATE_2, Y2, Y2
VPADDD STATE_3, Y3, Y3
XOR_AVX2(DI, SI, 0, Y0, Y1, Y2, Y3, Y12, Y13)
VMOVDQA STATE_0, Y0
VMOVDQA STATE_1, Y1
VMOVDQA STATE_2, Y2
VMOVDQA STATE_3, Y3
VPADDQ TWO, Y3, Y3
VPADDD Y0, Y4, Y4
VPADDD Y1, Y5, Y5
VPADDD Y2, Y6, Y6
VPADDD Y3, Y7, Y7
XOR_AVX2(DI, SI, 128, Y4, Y5, Y6, Y7, Y12, Y13)
VPADDQ TWO, Y3, Y3
VPADDD Y0, Y8, Y8
VPADDD Y1, Y9, Y9
VPADDD Y2, Y10, Y10
VPADDD Y3, Y11, Y11
XOR_AVX2(DI, SI, 256, Y8, Y9, Y10, Y11, Y12, Y13)
VPADDQ TWO, Y3, Y3
VPADDD TMP_0, Y0, Y12
VPADDD TMP_1, Y1, Y13
VPADDD Y2, Y14, Y14
VPADDD Y3, Y15, Y15
VPADDQ TWO, Y3, Y3
CMPQ CX, $512
JB less_than_512
XOR_AVX2(DI, SI, 384, Y12, Y13, Y14, Y15, Y4, Y5)
VMOVDQA Y3, STATE_3
ADDQ $512, SI
ADDQ $512, DI
SUBQ $512, CX
CMPQ CX, $448
JA at_least_512
TESTQ CX, CX
JZ done
VMOVDQA C16, Y14
VMOVDQA C8, Y15
CMPQ CX, $64
JBE between_0_and_64
CMPQ CX, $192
JBE between_64_and_192
CMPQ CX, $320
JBE between_192_and_320
JMP between_320_and_448
less_than_512:
XOR_UPPER_AVX2(DI, SI, 384, Y12, Y13, Y14, Y15, Y4, Y5)
EXTRACT_LOWER(BX, Y12, Y13, Y14, Y15, Y4)
ADDQ $448, SI
ADDQ $448, DI
SUBQ $448, CX
JMP finalize
between_320_and_448:
VMOVDQA Y0, Y4
VMOVDQA Y1, Y5
VMOVDQA Y2, Y6
VPADDQ TWO, Y3, Y7
VMOVDQA Y0, Y8
VMOVDQA Y1, Y9
VMOVDQA Y2, Y10
VPADDQ TWO, Y7, Y11
MOVQ DX, R9
chacha_loop_384:
CHACHA_QROUND_AVX(Y0, Y1, Y2, Y3, Y13, Y14, Y15)
CHACHA_QROUND_AVX(Y4, Y5, Y6, Y7, Y13, Y14, Y15)
CHACHA_QROUND_AVX(Y8, Y9, Y10, Y11, Y13, Y14, Y15)
CHACHA_SHUFFLE_AVX(Y1, Y2, Y3)
CHACHA_SHUFFLE_AVX(Y5, Y6, Y7)
CHACHA_SHUFFLE_AVX(Y9, Y10, Y11)
CHACHA_QROUND_AVX(Y0, Y1, Y2, Y3, Y13, Y14, Y15)
CHACHA_QROUND_AVX(Y4, Y5, Y6, Y7, Y13, Y14, Y15)
CHACHA_QROUND_AVX(Y8, Y9, Y10, Y11, Y13, Y14, Y15)
CHACHA_SHUFFLE_AVX(Y3, Y2, Y1)
CHACHA_SHUFFLE_AVX(Y7, Y6, Y5)
CHACHA_SHUFFLE_AVX(Y11, Y10, Y9)
SUBQ $2, R9
JA chacha_loop_384
VPADDD STATE_0, Y0, Y0
VPADDD STATE_1, Y1, Y1
VPADDD STATE_2, Y2, Y2
VPADDD STATE_3, Y3, Y3
XOR_AVX2(DI, SI, 0, Y0, Y1, Y2, Y3, Y12, Y13)
VMOVDQA STATE_0, Y0
VMOVDQA STATE_1, Y1
VMOVDQA STATE_2, Y2
VMOVDQA STATE_3, Y3
VPADDQ TWO, Y3, Y3
VPADDD Y0, Y4, Y4
VPADDD Y1, Y5, Y5
VPADDD Y2, Y6, Y6
VPADDD Y3, Y7, Y7
XOR_AVX2(DI, SI, 128, Y4, Y5, Y6, Y7, Y12, Y13)
VPADDQ TWO, Y3, Y3
VPADDD Y0, Y8, Y8
VPADDD Y1, Y9, Y9
VPADDD Y2, Y10, Y10
VPADDD Y3, Y11, Y11
VPADDQ TWO, Y3, Y3
CMPQ CX, $384
JB less_than_384
XOR_AVX2(DI, SI, 256, Y8, Y9, Y10, Y11, Y12, Y13)
SUBQ $384, CX
TESTQ CX, CX
JE done
ADDQ $384, SI
ADDQ $384, DI
JMP between_0_and_64
less_than_384:
XOR_UPPER_AVX2(DI, SI, 256, Y8, Y9, Y10, Y11, Y12, Y13)
EXTRACT_LOWER(BX, Y8, Y9, Y10, Y11, Y12)
ADDQ $320, SI
ADDQ $320, DI
SUBQ $320, CX
JMP finalize
between_192_and_320:
VMOVDQA Y0, Y4
VMOVDQA Y1, Y5
VMOVDQA Y2, Y6
VMOVDQA Y3, Y7
VMOVDQA Y0, Y8
VMOVDQA Y1, Y9
VMOVDQA Y2, Y10
VPADDQ TWO, Y3, Y11
MOVQ DX, R9
chacha_loop_256:
CHACHA_QROUND_AVX(Y4, Y5, Y6, Y7, Y13, Y14, Y15)
CHACHA_QROUND_AVX(Y8, Y9, Y10, Y11, Y13, Y14, Y15)
CHACHA_SHUFFLE_AVX(Y5, Y6, Y7)
CHACHA_SHUFFLE_AVX(Y9, Y10, Y11)
CHACHA_QROUND_AVX(Y4, Y5, Y6, Y7, Y13, Y14, Y15)
CHACHA_QROUND_AVX(Y8, Y9, Y10, Y11, Y13, Y14, Y15)
CHACHA_SHUFFLE_AVX(Y7, Y6, Y5)
CHACHA_SHUFFLE_AVX(Y11, Y10, Y9)
SUBQ $2, R9
JA chacha_loop_256
VPADDD Y0, Y4, Y4
VPADDD Y1, Y5, Y5
VPADDD Y2, Y6, Y6
VPADDD Y3, Y7, Y7
VPADDQ TWO, Y3, Y3
XOR_AVX2(DI, SI, 0, Y4, Y5, Y6, Y7, Y12, Y13)
VPADDD Y0, Y8, Y8
VPADDD Y1, Y9, Y9
VPADDD Y2, Y10, Y10
VPADDD Y3, Y11, Y11
VPADDQ TWO, Y3, Y3
CMPQ CX, $256
JB less_than_256
XOR_AVX2(DI, SI, 128, Y8, Y9, Y10, Y11, Y12, Y13)
SUBQ $256, CX
TESTQ CX, CX
JE done
ADDQ $256, SI
ADDQ $256, DI
JMP between_0_and_64
less_than_256:
XOR_UPPER_AVX2(DI, SI, 128, Y8, Y9, Y10, Y11, Y12, Y13)
EXTRACT_LOWER(BX, Y8, Y9, Y10, Y11, Y12)
ADDQ $192, SI
ADDQ $192, DI
SUBQ $192, CX
JMP finalize
between_64_and_192:
VMOVDQA Y0, Y4
VMOVDQA Y1, Y5
VMOVDQA Y2, Y6
VMOVDQA Y3, Y7
MOVQ DX, R9
chacha_loop_128:
CHACHA_QROUND_AVX(Y4, Y5, Y6, Y7, Y13, Y14, Y15)
CHACHA_SHUFFLE_AVX(Y5, Y6, Y7)
CHACHA_QROUND_AVX(Y4, Y5, Y6, Y7, Y13, Y14, Y15)
CHACHA_SHUFFLE_AVX(Y7, Y6, Y5)
SUBQ $2, R9
JA chacha_loop_128
VPADDD Y0, Y4, Y4
VPADDD Y1, Y5, Y5
VPADDD Y2, Y6, Y6
VPADDD Y3, Y7, Y7
VPADDQ TWO, Y3, Y3
CMPQ CX, $128
JB less_than_128
XOR_AVX2(DI, SI, 0, Y4, Y5, Y6, Y7, Y12, Y13)
SUBQ $128, CX
TESTQ CX, CX
JE done
ADDQ $128, SI
ADDQ $128, DI
JMP between_0_and_64
less_than_128:
XOR_UPPER_AVX2(DI, SI, 0, Y4, Y5, Y6, Y7, Y12, Y13)
EXTRACT_LOWER(BX, Y4, Y5, Y6, Y7, Y13)
ADDQ $64, SI
ADDQ $64, DI
SUBQ $64, CX
JMP finalize
between_0_and_64:
VMOVDQA X0, X4
VMOVDQA X1, X5
VMOVDQA X2, X6
VMOVDQA X3, X7
MOVQ DX, R9
chacha_loop_64:
CHACHA_QROUND_AVX(X4, X5, X6, X7, X13, X14, X15)
CHACHA_SHUFFLE_AVX(X5, X6, X7)
CHACHA_QROUND_AVX(X4, X5, X6, X7, X13, X14, X15)
CHACHA_SHUFFLE_AVX(X7, X6, X5)
SUBQ $2, R9
JA chacha_loop_64
VPADDD X0, X4, X4
VPADDD X1, X5, X5
VPADDD X2, X6, X6
VPADDD X3, X7, X7
VMOVDQU ·one<>(SB), X0
VPADDQ X0, X3, X3
CMPQ CX, $64
JB less_than_64
XOR_AVX(DI, SI, 0, X4, X5, X6, X7, X13)
SUBQ $64, CX
JMP done
less_than_64:
VMOVDQU X4, 0(BX)
VMOVDQU X5, 16(BX)
VMOVDQU X6, 32(BX)
VMOVDQU X7, 48(BX)
finalize:
XORQ R11, R11
XORQ R12, R12
MOVQ CX, BP
xor_loop:
MOVB 0(SI), R11
MOVB 0(BX), R12
XORQ R11, R12
MOVB R12, 0(DI)
INCQ SI
INCQ BX
INCQ DI
DECQ BP
JA xor_loop
done:
VMOVDQU X3, 48(AX)
VZEROUPPER
MOVQ R8, SP
MOVQ CX, ret+72(FP)
RET

View File

@@ -1,60 +0,0 @@
// Copyright (c) 2016 Andreas Auernhammer. All rights reserved.
// Use of this source code is governed by a license that can be
// found in the LICENSE file.
// +build 386,!gccgo,!appengine,!nacl
package chacha
import (
"encoding/binary"
"golang.org/x/sys/cpu"
)
func init() {
useSSE2 = cpu.X86.HasSSE2
useSSSE3 = cpu.X86.HasSSSE3
useAVX = false
useAVX2 = false
}
func initialize(state *[64]byte, key []byte, nonce *[16]byte) {
binary.LittleEndian.PutUint32(state[0:], sigma[0])
binary.LittleEndian.PutUint32(state[4:], sigma[1])
binary.LittleEndian.PutUint32(state[8:], sigma[2])
binary.LittleEndian.PutUint32(state[12:], sigma[3])
copy(state[16:], key[:])
copy(state[48:], nonce[:])
}
// This function is implemented in chacha_386.s
//go:noescape
func hChaCha20SSE2(out *[32]byte, nonce *[16]byte, key *[32]byte)
// This function is implemented in chacha_386.s
//go:noescape
func hChaCha20SSSE3(out *[32]byte, nonce *[16]byte, key *[32]byte)
// This function is implemented in chacha_386.s
//go:noescape
func xorKeyStreamSSE2(dst, src []byte, block, state *[64]byte, rounds int) int
func hChaCha20(out *[32]byte, nonce *[16]byte, key *[32]byte) {
switch {
case useSSSE3:
hChaCha20SSSE3(out, nonce, key)
case useSSE2:
hChaCha20SSE2(out, nonce, key)
default:
hChaCha20Generic(out, nonce, key)
}
}
func xorKeyStream(dst, src []byte, block, state *[64]byte, rounds int) int {
if useSSE2 {
return xorKeyStreamSSE2(dst, src, block, state, rounds)
} else {
return xorKeyStreamGeneric(dst, src, block, state, rounds)
}
}

View File

@@ -1,163 +0,0 @@
// Copyright (c) 2016 Andreas Auernhammer. All rights reserved.
// Use of this source code is governed by a license that can be
// found in the LICENSE file.
// +build 386,!gccgo,!appengine,!nacl
#include "const.s"
#include "macro.s"
// FINALIZE xors len bytes from src and block using
// the temp. registers t0 and t1 and writes the result
// to dst.
#define FINALIZE(dst, src, block, len, t0, t1) \
XORL t0, t0; \
XORL t1, t1; \
FINALIZE_LOOP:; \
MOVB 0(src), t0; \
MOVB 0(block), t1; \
XORL t0, t1; \
MOVB t1, 0(dst); \
INCL src; \
INCL block; \
INCL dst; \
DECL len; \
JG FINALIZE_LOOP \
#define Dst DI
#define Nonce AX
#define Key BX
#define Rounds DX
// func hChaCha20SSE2(out *[32]byte, nonce *[16]byte, key *[32]byte)
TEXT ·hChaCha20SSE2(SB), 4, $0-12
MOVL out+0(FP), Dst
MOVL nonce+4(FP), Nonce
MOVL key+8(FP), Key
MOVOU ·sigma<>(SB), X0
MOVOU 0*16(Key), X1
MOVOU 1*16(Key), X2
MOVOU 0*16(Nonce), X3
MOVL $20, Rounds
chacha_loop:
CHACHA_QROUND_SSE2(X0, X1, X2, X3, X4)
CHACHA_SHUFFLE_SSE(X1, X2, X3)
CHACHA_QROUND_SSE2(X0, X1, X2, X3, X4)
CHACHA_SHUFFLE_SSE(X3, X2, X1)
SUBL $2, Rounds
JNZ chacha_loop
MOVOU X0, 0*16(Dst)
MOVOU X3, 1*16(Dst)
RET
// func hChaCha20SSSE3(out *[32]byte, nonce *[16]byte, key *[32]byte)
TEXT ·hChaCha20SSSE3(SB), 4, $0-12
MOVL out+0(FP), Dst
MOVL nonce+4(FP), Nonce
MOVL key+8(FP), Key
MOVOU ·sigma<>(SB), X0
MOVOU 0*16(Key), X1
MOVOU 1*16(Key), X2
MOVOU 0*16(Nonce), X3
MOVL $20, Rounds
MOVOU ·rol16<>(SB), X5
MOVOU ·rol8<>(SB), X6
chacha_loop:
CHACHA_QROUND_SSSE3(X0, X1, X2, X3, X4, X5, X6)
CHACHA_SHUFFLE_SSE(X1, X2, X3)
CHACHA_QROUND_SSSE3(X0, X1, X2, X3, X4, X5, X6)
CHACHA_SHUFFLE_SSE(X3, X2, X1)
SUBL $2, Rounds
JNZ chacha_loop
MOVOU X0, 0*16(Dst)
MOVOU X3, 1*16(Dst)
RET
#undef Dst
#undef Nonce
#undef Key
#undef Rounds
#define State AX
#define Dst DI
#define Src SI
#define Len DX
#define Tmp0 BX
#define Tmp1 BP
// func xorKeyStreamSSE2(dst, src []byte, block, state *[64]byte, rounds int) int
TEXT ·xorKeyStreamSSE2(SB), 4, $0-40
MOVL dst_base+0(FP), Dst
MOVL src_base+12(FP), Src
MOVL state+28(FP), State
MOVL src_len+16(FP), Len
MOVL $0, ret+36(FP) // Number of bytes written to the keystream buffer - 0 iff len mod 64 == 0
MOVOU 0*16(State), X0
MOVOU 1*16(State), X1
MOVOU 2*16(State), X2
MOVOU 3*16(State), X3
TESTL Len, Len
JZ DONE
GENERATE_KEYSTREAM:
MOVO X0, X4
MOVO X1, X5
MOVO X2, X6
MOVO X3, X7
MOVL rounds+32(FP), Tmp0
CHACHA_LOOP:
CHACHA_QROUND_SSE2(X4, X5, X6, X7, X0)
CHACHA_SHUFFLE_SSE(X5, X6, X7)
CHACHA_QROUND_SSE2(X4, X5, X6, X7, X0)
CHACHA_SHUFFLE_SSE(X7, X6, X5)
SUBL $2, Tmp0
JA CHACHA_LOOP
MOVOU 0*16(State), X0 // Restore X0 from state
PADDL X0, X4
PADDL X1, X5
PADDL X2, X6
PADDL X3, X7
MOVOU ·one<>(SB), X0
PADDQ X0, X3
CMPL Len, $64
JL BUFFER_KEYSTREAM
XOR_SSE(Dst, Src, 0, X4, X5, X6, X7, X0)
MOVOU 0*16(State), X0 // Restore X0 from state
ADDL $64, Src
ADDL $64, Dst
SUBL $64, Len
JZ DONE
JMP GENERATE_KEYSTREAM // There is at least one more plaintext byte
BUFFER_KEYSTREAM:
MOVL block+24(FP), State
MOVOU X4, 0(State)
MOVOU X5, 16(State)
MOVOU X6, 32(State)
MOVOU X7, 48(State)
MOVL Len, ret+36(FP) // Number of bytes written to the keystream buffer - 0 < Len < 64
FINALIZE(Dst, Src, State, Len, Tmp0, Tmp1)
DONE:
MOVL state+28(FP), State
MOVOU X3, 3*16(State)
RET
#undef State
#undef Dst
#undef Src
#undef Len
#undef Tmp0
#undef Tmp1

View File

@@ -1,76 +0,0 @@
// Copyright (c) 2017 Andreas Auernhammer. All rights reserved.
// Use of this source code is governed by a license that can be
// found in the LICENSE file.
// +build go1.7,amd64,!gccgo,!appengine,!nacl
package chacha
import "golang.org/x/sys/cpu"
func init() {
useSSE2 = cpu.X86.HasSSE2
useSSSE3 = cpu.X86.HasSSSE3
useAVX = cpu.X86.HasAVX
useAVX2 = cpu.X86.HasAVX2
}
// This function is implemented in chacha_amd64.s
//go:noescape
func initialize(state *[64]byte, key []byte, nonce *[16]byte)
// This function is implemented in chacha_amd64.s
//go:noescape
func hChaCha20SSE2(out *[32]byte, nonce *[16]byte, key *[32]byte)
// This function is implemented in chacha_amd64.s
//go:noescape
func hChaCha20SSSE3(out *[32]byte, nonce *[16]byte, key *[32]byte)
// This function is implemented in chachaAVX2_amd64.s
//go:noescape
func hChaCha20AVX(out *[32]byte, nonce *[16]byte, key *[32]byte)
// This function is implemented in chacha_amd64.s
//go:noescape
func xorKeyStreamSSE2(dst, src []byte, block, state *[64]byte, rounds int) int
// This function is implemented in chacha_amd64.s
//go:noescape
func xorKeyStreamSSSE3(dst, src []byte, block, state *[64]byte, rounds int) int
// This function is implemented in chacha_amd64.s
//go:noescape
func xorKeyStreamAVX(dst, src []byte, block, state *[64]byte, rounds int) int
// This function is implemented in chachaAVX2_amd64.s
//go:noescape
func xorKeyStreamAVX2(dst, src []byte, block, state *[64]byte, rounds int) int
func hChaCha20(out *[32]byte, nonce *[16]byte, key *[32]byte) {
switch {
case useAVX:
hChaCha20AVX(out, nonce, key)
case useSSSE3:
hChaCha20SSSE3(out, nonce, key)
case useSSE2:
hChaCha20SSE2(out, nonce, key)
default:
hChaCha20Generic(out, nonce, key)
}
}
func xorKeyStream(dst, src []byte, block, state *[64]byte, rounds int) int {
switch {
case useAVX2:
return xorKeyStreamAVX2(dst, src, block, state, rounds)
case useAVX:
return xorKeyStreamAVX(dst, src, block, state, rounds)
case useSSSE3:
return xorKeyStreamSSSE3(dst, src, block, state, rounds)
case useSSE2:
return xorKeyStreamSSE2(dst, src, block, state, rounds)
default:
return xorKeyStreamGeneric(dst, src, block, state, rounds)
}
}

File diff suppressed because it is too large Load Diff

View File

@@ -1,319 +0,0 @@
// Copyright (c) 2016 Andreas Auernhammer. All rights reserved.
// Use of this source code is governed by a license that can be
// found in the LICENSE file.
package chacha
import "encoding/binary"
var sigma = [4]uint32{0x61707865, 0x3320646e, 0x79622d32, 0x6b206574}
func xorKeyStreamGeneric(dst, src []byte, block, state *[64]byte, rounds int) int {
for len(src) >= 64 {
chachaGeneric(block, state, rounds)
for i, v := range block {
dst[i] = src[i] ^ v
}
src = src[64:]
dst = dst[64:]
}
n := len(src)
if n > 0 {
chachaGeneric(block, state, rounds)
for i, v := range src {
dst[i] = v ^ block[i]
}
}
return n
}
func chachaGeneric(dst *[64]byte, state *[64]byte, rounds int) {
v00 := binary.LittleEndian.Uint32(state[0:])
v01 := binary.LittleEndian.Uint32(state[4:])
v02 := binary.LittleEndian.Uint32(state[8:])
v03 := binary.LittleEndian.Uint32(state[12:])
v04 := binary.LittleEndian.Uint32(state[16:])
v05 := binary.LittleEndian.Uint32(state[20:])
v06 := binary.LittleEndian.Uint32(state[24:])
v07 := binary.LittleEndian.Uint32(state[28:])
v08 := binary.LittleEndian.Uint32(state[32:])
v09 := binary.LittleEndian.Uint32(state[36:])
v10 := binary.LittleEndian.Uint32(state[40:])
v11 := binary.LittleEndian.Uint32(state[44:])
v12 := binary.LittleEndian.Uint32(state[48:])
v13 := binary.LittleEndian.Uint32(state[52:])
v14 := binary.LittleEndian.Uint32(state[56:])
v15 := binary.LittleEndian.Uint32(state[60:])
s00, s01, s02, s03, s04, s05, s06, s07 := v00, v01, v02, v03, v04, v05, v06, v07
s08, s09, s10, s11, s12, s13, s14, s15 := v08, v09, v10, v11, v12, v13, v14, v15
for i := 0; i < rounds; i += 2 {
v00 += v04
v12 ^= v00
v12 = (v12 << 16) | (v12 >> 16)
v08 += v12
v04 ^= v08
v04 = (v04 << 12) | (v04 >> 20)
v00 += v04
v12 ^= v00
v12 = (v12 << 8) | (v12 >> 24)
v08 += v12
v04 ^= v08
v04 = (v04 << 7) | (v04 >> 25)
v01 += v05
v13 ^= v01
v13 = (v13 << 16) | (v13 >> 16)
v09 += v13
v05 ^= v09
v05 = (v05 << 12) | (v05 >> 20)
v01 += v05
v13 ^= v01
v13 = (v13 << 8) | (v13 >> 24)
v09 += v13
v05 ^= v09
v05 = (v05 << 7) | (v05 >> 25)
v02 += v06
v14 ^= v02
v14 = (v14 << 16) | (v14 >> 16)
v10 += v14
v06 ^= v10
v06 = (v06 << 12) | (v06 >> 20)
v02 += v06
v14 ^= v02
v14 = (v14 << 8) | (v14 >> 24)
v10 += v14
v06 ^= v10
v06 = (v06 << 7) | (v06 >> 25)
v03 += v07
v15 ^= v03
v15 = (v15 << 16) | (v15 >> 16)
v11 += v15
v07 ^= v11
v07 = (v07 << 12) | (v07 >> 20)
v03 += v07
v15 ^= v03
v15 = (v15 << 8) | (v15 >> 24)
v11 += v15
v07 ^= v11
v07 = (v07 << 7) | (v07 >> 25)
v00 += v05
v15 ^= v00
v15 = (v15 << 16) | (v15 >> 16)
v10 += v15
v05 ^= v10
v05 = (v05 << 12) | (v05 >> 20)
v00 += v05
v15 ^= v00
v15 = (v15 << 8) | (v15 >> 24)
v10 += v15
v05 ^= v10
v05 = (v05 << 7) | (v05 >> 25)
v01 += v06
v12 ^= v01
v12 = (v12 << 16) | (v12 >> 16)
v11 += v12
v06 ^= v11
v06 = (v06 << 12) | (v06 >> 20)
v01 += v06
v12 ^= v01
v12 = (v12 << 8) | (v12 >> 24)
v11 += v12
v06 ^= v11
v06 = (v06 << 7) | (v06 >> 25)
v02 += v07
v13 ^= v02
v13 = (v13 << 16) | (v13 >> 16)
v08 += v13
v07 ^= v08
v07 = (v07 << 12) | (v07 >> 20)
v02 += v07
v13 ^= v02
v13 = (v13 << 8) | (v13 >> 24)
v08 += v13
v07 ^= v08
v07 = (v07 << 7) | (v07 >> 25)
v03 += v04
v14 ^= v03
v14 = (v14 << 16) | (v14 >> 16)
v09 += v14
v04 ^= v09
v04 = (v04 << 12) | (v04 >> 20)
v03 += v04
v14 ^= v03
v14 = (v14 << 8) | (v14 >> 24)
v09 += v14
v04 ^= v09
v04 = (v04 << 7) | (v04 >> 25)
}
v00 += s00
v01 += s01
v02 += s02
v03 += s03
v04 += s04
v05 += s05
v06 += s06
v07 += s07
v08 += s08
v09 += s09
v10 += s10
v11 += s11
v12 += s12
v13 += s13
v14 += s14
v15 += s15
s12++
binary.LittleEndian.PutUint32(state[48:], s12)
if s12 == 0 { // indicates overflow
s13++
binary.LittleEndian.PutUint32(state[52:], s13)
}
binary.LittleEndian.PutUint32(dst[0:], v00)
binary.LittleEndian.PutUint32(dst[4:], v01)
binary.LittleEndian.PutUint32(dst[8:], v02)
binary.LittleEndian.PutUint32(dst[12:], v03)
binary.LittleEndian.PutUint32(dst[16:], v04)
binary.LittleEndian.PutUint32(dst[20:], v05)
binary.LittleEndian.PutUint32(dst[24:], v06)
binary.LittleEndian.PutUint32(dst[28:], v07)
binary.LittleEndian.PutUint32(dst[32:], v08)
binary.LittleEndian.PutUint32(dst[36:], v09)
binary.LittleEndian.PutUint32(dst[40:], v10)
binary.LittleEndian.PutUint32(dst[44:], v11)
binary.LittleEndian.PutUint32(dst[48:], v12)
binary.LittleEndian.PutUint32(dst[52:], v13)
binary.LittleEndian.PutUint32(dst[56:], v14)
binary.LittleEndian.PutUint32(dst[60:], v15)
}
func hChaCha20Generic(out *[32]byte, nonce *[16]byte, key *[32]byte) {
v00 := sigma[0]
v01 := sigma[1]
v02 := sigma[2]
v03 := sigma[3]
v04 := binary.LittleEndian.Uint32(key[0:])
v05 := binary.LittleEndian.Uint32(key[4:])
v06 := binary.LittleEndian.Uint32(key[8:])
v07 := binary.LittleEndian.Uint32(key[12:])
v08 := binary.LittleEndian.Uint32(key[16:])
v09 := binary.LittleEndian.Uint32(key[20:])
v10 := binary.LittleEndian.Uint32(key[24:])
v11 := binary.LittleEndian.Uint32(key[28:])
v12 := binary.LittleEndian.Uint32(nonce[0:])
v13 := binary.LittleEndian.Uint32(nonce[4:])
v14 := binary.LittleEndian.Uint32(nonce[8:])
v15 := binary.LittleEndian.Uint32(nonce[12:])
for i := 0; i < 20; i += 2 {
v00 += v04
v12 ^= v00
v12 = (v12 << 16) | (v12 >> 16)
v08 += v12
v04 ^= v08
v04 = (v04 << 12) | (v04 >> 20)
v00 += v04
v12 ^= v00
v12 = (v12 << 8) | (v12 >> 24)
v08 += v12
v04 ^= v08
v04 = (v04 << 7) | (v04 >> 25)
v01 += v05
v13 ^= v01
v13 = (v13 << 16) | (v13 >> 16)
v09 += v13
v05 ^= v09
v05 = (v05 << 12) | (v05 >> 20)
v01 += v05
v13 ^= v01
v13 = (v13 << 8) | (v13 >> 24)
v09 += v13
v05 ^= v09
v05 = (v05 << 7) | (v05 >> 25)
v02 += v06
v14 ^= v02
v14 = (v14 << 16) | (v14 >> 16)
v10 += v14
v06 ^= v10
v06 = (v06 << 12) | (v06 >> 20)
v02 += v06
v14 ^= v02
v14 = (v14 << 8) | (v14 >> 24)
v10 += v14
v06 ^= v10
v06 = (v06 << 7) | (v06 >> 25)
v03 += v07
v15 ^= v03
v15 = (v15 << 16) | (v15 >> 16)
v11 += v15
v07 ^= v11
v07 = (v07 << 12) | (v07 >> 20)
v03 += v07
v15 ^= v03
v15 = (v15 << 8) | (v15 >> 24)
v11 += v15
v07 ^= v11
v07 = (v07 << 7) | (v07 >> 25)
v00 += v05
v15 ^= v00
v15 = (v15 << 16) | (v15 >> 16)
v10 += v15
v05 ^= v10
v05 = (v05 << 12) | (v05 >> 20)
v00 += v05
v15 ^= v00
v15 = (v15 << 8) | (v15 >> 24)
v10 += v15
v05 ^= v10
v05 = (v05 << 7) | (v05 >> 25)
v01 += v06
v12 ^= v01
v12 = (v12 << 16) | (v12 >> 16)
v11 += v12
v06 ^= v11
v06 = (v06 << 12) | (v06 >> 20)
v01 += v06
v12 ^= v01
v12 = (v12 << 8) | (v12 >> 24)
v11 += v12
v06 ^= v11
v06 = (v06 << 7) | (v06 >> 25)
v02 += v07
v13 ^= v02
v13 = (v13 << 16) | (v13 >> 16)
v08 += v13
v07 ^= v08
v07 = (v07 << 12) | (v07 >> 20)
v02 += v07
v13 ^= v02
v13 = (v13 << 8) | (v13 >> 24)
v08 += v13
v07 ^= v08
v07 = (v07 << 7) | (v07 >> 25)
v03 += v04
v14 ^= v03
v14 = (v14 << 16) | (v14 >> 16)
v09 += v14
v04 ^= v09
v04 = (v04 << 12) | (v04 >> 20)
v03 += v04
v14 ^= v03
v14 = (v14 << 8) | (v14 >> 24)
v09 += v14
v04 ^= v09
v04 = (v04 << 7) | (v04 >> 25)
}
binary.LittleEndian.PutUint32(out[0:], v00)
binary.LittleEndian.PutUint32(out[4:], v01)
binary.LittleEndian.PutUint32(out[8:], v02)
binary.LittleEndian.PutUint32(out[12:], v03)
binary.LittleEndian.PutUint32(out[16:], v12)
binary.LittleEndian.PutUint32(out[20:], v13)
binary.LittleEndian.PutUint32(out[24:], v14)
binary.LittleEndian.PutUint32(out[28:], v15)
}

View File

@@ -1,33 +0,0 @@
// Copyright (c) 2016 Andreas Auernhammer. All rights reserved.
// Use of this source code is governed by a license that can be
// found in the LICENSE file.
// +build !amd64,!386 gccgo appengine nacl
package chacha
import "encoding/binary"
func init() {
useSSE2 = false
useSSSE3 = false
useAVX = false
useAVX2 = false
}
func initialize(state *[64]byte, key []byte, nonce *[16]byte) {
binary.LittleEndian.PutUint32(state[0:], sigma[0])
binary.LittleEndian.PutUint32(state[4:], sigma[1])
binary.LittleEndian.PutUint32(state[8:], sigma[2])
binary.LittleEndian.PutUint32(state[12:], sigma[3])
copy(state[16:], key[:])
copy(state[48:], nonce[:])
}
func xorKeyStream(dst, src []byte, block, state *[64]byte, rounds int) int {
return xorKeyStreamGeneric(dst, src, block, state, rounds)
}
func hChaCha20(out *[32]byte, nonce *[16]byte, key *[32]byte) {
hChaCha20Generic(out, nonce, key)
}

View File

@@ -1,53 +0,0 @@
// Copyright (c) 2018 Andreas Auernhammer. All rights reserved.
// Use of this source code is governed by a license that can be
// found in the LICENSE file.
// +build 386,!gccgo,!appengine,!nacl amd64,!gccgo,!appengine,!nacl
#include "textflag.h"
DATA ·sigma<>+0x00(SB)/4, $0x61707865
DATA ·sigma<>+0x04(SB)/4, $0x3320646e
DATA ·sigma<>+0x08(SB)/4, $0x79622d32
DATA ·sigma<>+0x0C(SB)/4, $0x6b206574
GLOBL ·sigma<>(SB), (NOPTR+RODATA), $16 // The 4 ChaCha initialization constants
// SSE2/SSE3/AVX constants
DATA ·one<>+0x00(SB)/8, $1
DATA ·one<>+0x08(SB)/8, $0
GLOBL ·one<>(SB), (NOPTR+RODATA), $16 // The constant 1 as 128 bit value
DATA ·rol16<>+0x00(SB)/8, $0x0504070601000302
DATA ·rol16<>+0x08(SB)/8, $0x0D0C0F0E09080B0A
GLOBL ·rol16<>(SB), (NOPTR+RODATA), $16 // The PSHUFB 16 bit left rotate constant
DATA ·rol8<>+0x00(SB)/8, $0x0605040702010003
DATA ·rol8<>+0x08(SB)/8, $0x0E0D0C0F0A09080B
GLOBL ·rol8<>(SB), (NOPTR+RODATA), $16 // The PSHUFB 8 bit left rotate constant
// AVX2 constants
DATA ·one_AVX2<>+0x00(SB)/8, $0
DATA ·one_AVX2<>+0x08(SB)/8, $0
DATA ·one_AVX2<>+0x10(SB)/8, $1
DATA ·one_AVX2<>+0x18(SB)/8, $0
GLOBL ·one_AVX2<>(SB), (NOPTR+RODATA), $32 // The constant 1 as 256 bit value
DATA ·two_AVX2<>+0x00(SB)/8, $2
DATA ·two_AVX2<>+0x08(SB)/8, $0
DATA ·two_AVX2<>+0x10(SB)/8, $2
DATA ·two_AVX2<>+0x18(SB)/8, $0
GLOBL ·two_AVX2<>(SB), (NOPTR+RODATA), $32
DATA ·rol16_AVX2<>+0x00(SB)/8, $0x0504070601000302
DATA ·rol16_AVX2<>+0x08(SB)/8, $0x0D0C0F0E09080B0A
DATA ·rol16_AVX2<>+0x10(SB)/8, $0x0504070601000302
DATA ·rol16_AVX2<>+0x18(SB)/8, $0x0D0C0F0E09080B0A
GLOBL ·rol16_AVX2<>(SB), (NOPTR+RODATA), $32 // The VPSHUFB 16 bit left rotate constant
DATA ·rol8_AVX2<>+0x00(SB)/8, $0x0605040702010003
DATA ·rol8_AVX2<>+0x08(SB)/8, $0x0E0D0C0F0A09080B
DATA ·rol8_AVX2<>+0x10(SB)/8, $0x0605040702010003
DATA ·rol8_AVX2<>+0x18(SB)/8, $0x0E0D0C0F0A09080B
GLOBL ·rol8_AVX2<>(SB), (NOPTR+RODATA), $32 // The VPSHUFB 8 bit left rotate constant

View File

@@ -1,163 +0,0 @@
// Copyright (c) 2018 Andreas Auernhammer. All rights reserved.
// Use of this source code is governed by a license that can be
// found in the LICENSE file.
// +build 386,!gccgo,!appengine,!nacl amd64,!gccgo,!appengine,!nacl
// ROTL_SSE rotates all 4 32 bit values of the XMM register v
// left by n bits using SSE2 instructions (0 <= n <= 32).
// The XMM register t is used as a temp. register.
#define ROTL_SSE(n, t, v) \
MOVO v, t; \
PSLLL $n, t; \
PSRLL $(32-n), v; \
PXOR t, v
// ROTL_AVX rotates all 4/8 32 bit values of the AVX/AVX2 register v
// left by n bits using AVX/AVX2 instructions (0 <= n <= 32).
// The AVX/AVX2 register t is used as a temp. register.
#define ROTL_AVX(n, t, v) \
VPSLLD $n, v, t; \
VPSRLD $(32-n), v, v; \
VPXOR v, t, v
// CHACHA_QROUND_SSE2 performs a ChaCha quarter-round using the
// 4 XMM registers v0, v1, v2 and v3. It uses only ROTL_SSE2 for
// rotations. The XMM register t is used as a temp. register.
#define CHACHA_QROUND_SSE2(v0, v1, v2, v3, t) \
PADDL v1, v0; \
PXOR v0, v3; \
ROTL_SSE(16, t, v3); \
PADDL v3, v2; \
PXOR v2, v1; \
ROTL_SSE(12, t, v1); \
PADDL v1, v0; \
PXOR v0, v3; \
ROTL_SSE(8, t, v3); \
PADDL v3, v2; \
PXOR v2, v1; \
ROTL_SSE(7, t, v1)
// CHACHA_QROUND_SSSE3 performs a ChaCha quarter-round using the
// 4 XMM registers v0, v1, v2 and v3. It uses PSHUFB for 8/16 bit
// rotations. The XMM register t is used as a temp. register.
//
// r16 holds the PSHUFB constant for a 16 bit left rotate.
// r8 holds the PSHUFB constant for a 8 bit left rotate.
#define CHACHA_QROUND_SSSE3(v0, v1, v2, v3, t, r16, r8) \
PADDL v1, v0; \
PXOR v0, v3; \
PSHUFB r16, v3; \
PADDL v3, v2; \
PXOR v2, v1; \
ROTL_SSE(12, t, v1); \
PADDL v1, v0; \
PXOR v0, v3; \
PSHUFB r8, v3; \
PADDL v3, v2; \
PXOR v2, v1; \
ROTL_SSE(7, t, v1)
// CHACHA_QROUND_AVX performs a ChaCha quarter-round using the
// 4 AVX/AVX2 registers v0, v1, v2 and v3. It uses VPSHUFB for 8/16 bit
// rotations. The AVX/AVX2 register t is used as a temp. register.
//
// r16 holds the VPSHUFB constant for a 16 bit left rotate.
// r8 holds the VPSHUFB constant for a 8 bit left rotate.
#define CHACHA_QROUND_AVX(v0, v1, v2, v3, t, r16, r8) \
VPADDD v0, v1, v0; \
VPXOR v3, v0, v3; \
VPSHUFB r16, v3, v3; \
VPADDD v2, v3, v2; \
VPXOR v1, v2, v1; \
ROTL_AVX(12, t, v1); \
VPADDD v0, v1, v0; \
VPXOR v3, v0, v3; \
VPSHUFB r8, v3, v3; \
VPADDD v2, v3, v2; \
VPXOR v1, v2, v1; \
ROTL_AVX(7, t, v1)
// CHACHA_SHUFFLE_SSE performs a ChaCha shuffle using the
// 3 XMM registers v1, v2 and v3. The inverse shuffle is
// performed by switching v1 and v3: CHACHA_SHUFFLE_SSE(v3, v2, v1).
#define CHACHA_SHUFFLE_SSE(v1, v2, v3) \
PSHUFL $0x39, v1, v1; \
PSHUFL $0x4E, v2, v2; \
PSHUFL $0x93, v3, v3
// CHACHA_SHUFFLE_AVX performs a ChaCha shuffle using the
// 3 AVX/AVX2 registers v1, v2 and v3. The inverse shuffle is
// performed by switching v1 and v3: CHACHA_SHUFFLE_AVX(v3, v2, v1).
#define CHACHA_SHUFFLE_AVX(v1, v2, v3) \
VPSHUFD $0x39, v1, v1; \
VPSHUFD $0x4E, v2, v2; \
VPSHUFD $0x93, v3, v3
// XOR_SSE extracts 4x16 byte vectors from src at
// off, xors all vectors with the corresponding XMM
// register (v0 - v3) and writes the result to dst
// at off.
// The XMM register t is used as a temp. register.
#define XOR_SSE(dst, src, off, v0, v1, v2, v3, t) \
MOVOU 0+off(src), t; \
PXOR v0, t; \
MOVOU t, 0+off(dst); \
MOVOU 16+off(src), t; \
PXOR v1, t; \
MOVOU t, 16+off(dst); \
MOVOU 32+off(src), t; \
PXOR v2, t; \
MOVOU t, 32+off(dst); \
MOVOU 48+off(src), t; \
PXOR v3, t; \
MOVOU t, 48+off(dst)
// XOR_AVX extracts 4x16 byte vectors from src at
// off, xors all vectors with the corresponding AVX
// register (v0 - v3) and writes the result to dst
// at off.
// The XMM register t is used as a temp. register.
#define XOR_AVX(dst, src, off, v0, v1, v2, v3, t) \
VPXOR 0+off(src), v0, t; \
VMOVDQU t, 0+off(dst); \
VPXOR 16+off(src), v1, t; \
VMOVDQU t, 16+off(dst); \
VPXOR 32+off(src), v2, t; \
VMOVDQU t, 32+off(dst); \
VPXOR 48+off(src), v3, t; \
VMOVDQU t, 48+off(dst)
#define XOR_AVX2(dst, src, off, v0, v1, v2, v3, t0, t1) \
VMOVDQU (0+off)(src), t0; \
VPERM2I128 $32, v1, v0, t1; \
VPXOR t0, t1, t0; \
VMOVDQU t0, (0+off)(dst); \
VMOVDQU (32+off)(src), t0; \
VPERM2I128 $32, v3, v2, t1; \
VPXOR t0, t1, t0; \
VMOVDQU t0, (32+off)(dst); \
VMOVDQU (64+off)(src), t0; \
VPERM2I128 $49, v1, v0, t1; \
VPXOR t0, t1, t0; \
VMOVDQU t0, (64+off)(dst); \
VMOVDQU (96+off)(src), t0; \
VPERM2I128 $49, v3, v2, t1; \
VPXOR t0, t1, t0; \
VMOVDQU t0, (96+off)(dst)
#define XOR_UPPER_AVX2(dst, src, off, v0, v1, v2, v3, t0, t1) \
VMOVDQU (0+off)(src), t0; \
VPERM2I128 $32, v1, v0, t1; \
VPXOR t0, t1, t0; \
VMOVDQU t0, (0+off)(dst); \
VMOVDQU (32+off)(src), t0; \
VPERM2I128 $32, v3, v2, t1; \
VPXOR t0, t1, t0; \
VMOVDQU t0, (32+off)(dst); \
#define EXTRACT_LOWER(dst, v0, v1, v2, v3, t0) \
VPERM2I128 $49, v1, v0, t0; \
VMOVDQU t0, 0(dst); \
VPERM2I128 $49, v3, v2, t0; \
VMOVDQU t0, 32(dst)

View File

@@ -1,41 +0,0 @@
// Copyright (c) 2016 Andreas Auernhammer. All rights reserved.
// Use of this source code is governed by a license that can be
// found in the LICENSE file.
// Package chacha20 implements the ChaCha20 / XChaCha20 stream chipher.
// Notice that one specific key-nonce combination must be unique for all time.
//
// There are three versions of ChaCha20:
// - ChaCha20 with a 64 bit nonce (en/decrypt up to 2^64 * 64 bytes for one key-nonce combination)
// - ChaCha20 with a 96 bit nonce (en/decrypt up to 2^32 * 64 bytes (~256 GB) for one key-nonce combination)
// - XChaCha20 with a 192 bit nonce (en/decrypt up to 2^64 * 64 bytes for one key-nonce combination)
package chacha20 // import "github.com/aead/chacha20"
import (
"crypto/cipher"
"github.com/aead/chacha20/chacha"
)
// XORKeyStream crypts bytes from src to dst using the given nonce and key.
// The length of the nonce determinds the version of ChaCha20:
// - 8 bytes: ChaCha20 with a 64 bit nonce and a 2^64 * 64 byte period.
// - 12 bytes: ChaCha20 as defined in RFC 7539 and a 2^32 * 64 byte period.
// - 24 bytes: XChaCha20 with a 192 bit nonce and a 2^64 * 64 byte period.
// Src and dst may be the same slice but otherwise should not overlap.
// If len(dst) < len(src) this function panics.
// If the nonce is neither 64, 96 nor 192 bits long, this function panics.
func XORKeyStream(dst, src, nonce, key []byte) {
chacha.XORKeyStream(dst, src, nonce, key, 20)
}
// NewCipher returns a new cipher.Stream implementing a ChaCha20 version.
// The nonce must be unique for one key for all time.
// The length of the nonce determinds the version of ChaCha20:
// - 8 bytes: ChaCha20 with a 64 bit nonce and a 2^64 * 64 byte period.
// - 12 bytes: ChaCha20 as defined in RFC 7539 and a 2^32 * 64 byte period.
// - 24 bytes: XChaCha20 with a 192 bit nonce and a 2^64 * 64 byte period.
// If the nonce is neither 64, 96 nor 192 bits long, a non-nil error is returned.
func NewCipher(nonce, key []byte) (cipher.Stream, error) {
return chacha.NewCipher(nonce, key, 20)
}

View File

@@ -1,21 +0,0 @@
The MIT License (MIT)
Copyright (c) 2016 Richard Barnes
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in
all copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
THE SOFTWARE.

View File

@@ -1,101 +0,0 @@
// Copyright 2009 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package mint
import "strconv"
type Alert uint8
const (
// alert level
AlertLevelWarning = 1
AlertLevelError = 2
)
const (
AlertCloseNotify Alert = 0
AlertUnexpectedMessage Alert = 10
AlertBadRecordMAC Alert = 20
AlertDecryptionFailed Alert = 21
AlertRecordOverflow Alert = 22
AlertDecompressionFailure Alert = 30
AlertHandshakeFailure Alert = 40
AlertBadCertificate Alert = 42
AlertUnsupportedCertificate Alert = 43
AlertCertificateRevoked Alert = 44
AlertCertificateExpired Alert = 45
AlertCertificateUnknown Alert = 46
AlertIllegalParameter Alert = 47
AlertUnknownCA Alert = 48
AlertAccessDenied Alert = 49
AlertDecodeError Alert = 50
AlertDecryptError Alert = 51
AlertProtocolVersion Alert = 70
AlertInsufficientSecurity Alert = 71
AlertInternalError Alert = 80
AlertInappropriateFallback Alert = 86
AlertUserCanceled Alert = 90
AlertNoRenegotiation Alert = 100
AlertMissingExtension Alert = 109
AlertUnsupportedExtension Alert = 110
AlertCertificateUnobtainable Alert = 111
AlertUnrecognizedName Alert = 112
AlertBadCertificateStatsResponse Alert = 113
AlertBadCertificateHashValue Alert = 114
AlertUnknownPSKIdentity Alert = 115
AlertNoApplicationProtocol Alert = 120
AlertStatelessRetry Alert = 253
AlertWouldBlock Alert = 254
AlertNoAlert Alert = 255
)
var alertText = map[Alert]string{
AlertCloseNotify: "close notify",
AlertUnexpectedMessage: "unexpected message",
AlertBadRecordMAC: "bad record MAC",
AlertDecryptionFailed: "decryption failed",
AlertRecordOverflow: "record overflow",
AlertDecompressionFailure: "decompression failure",
AlertHandshakeFailure: "handshake failure",
AlertBadCertificate: "bad certificate",
AlertUnsupportedCertificate: "unsupported certificate",
AlertCertificateRevoked: "revoked certificate",
AlertCertificateExpired: "expired certificate",
AlertCertificateUnknown: "unknown certificate",
AlertIllegalParameter: "illegal parameter",
AlertUnknownCA: "unknown certificate authority",
AlertAccessDenied: "access denied",
AlertDecodeError: "error decoding message",
AlertDecryptError: "error decrypting message",
AlertProtocolVersion: "protocol version not supported",
AlertInsufficientSecurity: "insufficient security level",
AlertInternalError: "internal error",
AlertInappropriateFallback: "inappropriate fallback",
AlertUserCanceled: "user canceled",
AlertMissingExtension: "missing extension",
AlertUnsupportedExtension: "unsupported extension",
AlertCertificateUnobtainable: "certificate unobtainable",
AlertUnrecognizedName: "unrecognized name",
AlertBadCertificateStatsResponse: "bad certificate status response",
AlertBadCertificateHashValue: "bad certificate hash value",
AlertUnknownPSKIdentity: "unknown PSK identity",
AlertNoApplicationProtocol: "no application protocol",
AlertNoRenegotiation: "no renegotiation",
AlertStatelessRetry: "stateless retry",
AlertWouldBlock: "would have blocked",
AlertNoAlert: "no alert",
}
func (e Alert) String() string {
s, ok := alertText[e]
if ok {
return s
}
return "alert(" + strconv.Itoa(int(e)) + ")"
}
func (e Alert) Error() string {
return e.String()
}

File diff suppressed because it is too large Load Diff

View File

@@ -1,266 +0,0 @@
package mint
import (
"fmt"
"strconv"
)
const (
tls13Version uint16 = 0x0304
tls12Version uint16 = 0x0303
tls10Version uint16 = 0x0301
dtls12WireVersion uint16 = 0xfefd
)
var (
// Flags for some minor compat issues
allowWrongVersionNumber = true
allowPKCS1 = true
)
// enum {...} ContentType;
type RecordType byte
const (
RecordTypeAlert RecordType = 21
RecordTypeHandshake RecordType = 22
RecordTypeApplicationData RecordType = 23
RecordTypeAck RecordType = 25
)
// enum {...} HandshakeType;
type HandshakeType byte
const (
// Omitted: *_RESERVED
HandshakeTypeClientHello HandshakeType = 1
HandshakeTypeServerHello HandshakeType = 2
HandshakeTypeNewSessionTicket HandshakeType = 4
HandshakeTypeEndOfEarlyData HandshakeType = 5
HandshakeTypeHelloRetryRequest HandshakeType = 6
HandshakeTypeEncryptedExtensions HandshakeType = 8
HandshakeTypeCertificate HandshakeType = 11
HandshakeTypeCertificateRequest HandshakeType = 13
HandshakeTypeCertificateVerify HandshakeType = 15
HandshakeTypeServerConfiguration HandshakeType = 17
HandshakeTypeFinished HandshakeType = 20
HandshakeTypeKeyUpdate HandshakeType = 24
HandshakeTypeMessageHash HandshakeType = 254
)
var hrrRandomSentinel = [32]byte{
0xcf, 0x21, 0xad, 0x74, 0xe5, 0x9a, 0x61, 0x11,
0xbe, 0x1d, 0x8c, 0x02, 0x1e, 0x65, 0xb8, 0x91,
0xc2, 0xa2, 0x11, 0x16, 0x7a, 0xbb, 0x8c, 0x5e,
0x07, 0x9e, 0x09, 0xe2, 0xc8, 0xa8, 0x33, 0x9c,
}
// uint8 CipherSuite[2];
type CipherSuite uint16
const (
// XXX: Actually TLS_NULL_WITH_NULL_NULL, but we need a way to label the zero
// value for this type so that we can detect when a field is set.
CIPHER_SUITE_UNKNOWN CipherSuite = 0x0000
TLS_AES_128_GCM_SHA256 CipherSuite = 0x1301
TLS_AES_256_GCM_SHA384 CipherSuite = 0x1302
TLS_CHACHA20_POLY1305_SHA256 CipherSuite = 0x1303
TLS_AES_128_CCM_SHA256 CipherSuite = 0x1304
TLS_AES_256_CCM_8_SHA256 CipherSuite = 0x1305
)
func (c CipherSuite) String() string {
switch c {
case CIPHER_SUITE_UNKNOWN:
return "unknown"
case TLS_AES_128_GCM_SHA256:
return "TLS_AES_128_GCM_SHA256"
case TLS_AES_256_GCM_SHA384:
return "TLS_AES_256_GCM_SHA384"
case TLS_CHACHA20_POLY1305_SHA256:
return "TLS_CHACHA20_POLY1305_SHA256"
case TLS_AES_128_CCM_SHA256:
return "TLS_AES_128_CCM_SHA256"
case TLS_AES_256_CCM_8_SHA256:
return "TLS_AES_256_CCM_8_SHA256"
}
// cannot use %x here, since it calls String(), leading to infinite recursion
return fmt.Sprintf("invalid CipherSuite value: 0x%s", strconv.FormatUint(uint64(c), 16))
}
// enum {...} SignatureScheme
type SignatureScheme uint16
const (
// RSASSA-PKCS1-v1_5 algorithms
RSA_PKCS1_SHA1 SignatureScheme = 0x0201
RSA_PKCS1_SHA256 SignatureScheme = 0x0401
RSA_PKCS1_SHA384 SignatureScheme = 0x0501
RSA_PKCS1_SHA512 SignatureScheme = 0x0601
// ECDSA algorithms
ECDSA_P256_SHA256 SignatureScheme = 0x0403
ECDSA_P384_SHA384 SignatureScheme = 0x0503
ECDSA_P521_SHA512 SignatureScheme = 0x0603
// RSASSA-PSS algorithms
RSA_PSS_SHA256 SignatureScheme = 0x0804
RSA_PSS_SHA384 SignatureScheme = 0x0805
RSA_PSS_SHA512 SignatureScheme = 0x0806
// EdDSA algorithms
Ed25519 SignatureScheme = 0x0807
Ed448 SignatureScheme = 0x0808
)
// enum {...} ExtensionType
type ExtensionType uint16
const (
ExtensionTypeServerName ExtensionType = 0
ExtensionTypeSupportedGroups ExtensionType = 10
ExtensionTypeSignatureAlgorithms ExtensionType = 13
ExtensionTypeALPN ExtensionType = 16
ExtensionTypeKeyShare ExtensionType = 51
ExtensionTypePreSharedKey ExtensionType = 41
ExtensionTypeEarlyData ExtensionType = 42
ExtensionTypeSupportedVersions ExtensionType = 43
ExtensionTypeCookie ExtensionType = 44
ExtensionTypePSKKeyExchangeModes ExtensionType = 45
ExtensionTypeTicketEarlyDataInfo ExtensionType = 46
)
// enum {...} NamedGroup
type NamedGroup uint16
const (
// Elliptic Curve Groups.
P256 NamedGroup = 23
P384 NamedGroup = 24
P521 NamedGroup = 25
// ECDH functions.
X25519 NamedGroup = 29
X448 NamedGroup = 30
// Finite field groups.
FFDHE2048 NamedGroup = 256
FFDHE3072 NamedGroup = 257
FFDHE4096 NamedGroup = 258
FFDHE6144 NamedGroup = 259
FFDHE8192 NamedGroup = 260
)
// enum {...} PskKeyExchangeMode;
type PSKKeyExchangeMode uint8
const (
PSKModeKE PSKKeyExchangeMode = 0
PSKModeDHEKE PSKKeyExchangeMode = 1
)
// enum {
// update_not_requested(0), update_requested(1), (255)
// } KeyUpdateRequest;
type KeyUpdateRequest uint8
const (
KeyUpdateNotRequested KeyUpdateRequest = 0
KeyUpdateRequested KeyUpdateRequest = 1
)
type State uint8
const (
StateInit = 0
// states valid for the client
StateClientStart State = iota
StateClientWaitSH
StateClientWaitEE
StateClientWaitCert
StateClientWaitCV
StateClientWaitFinished
StateClientWaitCertCR
StateClientConnected
// states valid for the server
StateServerStart State = iota
StateServerRecvdCH
StateServerNegotiated
StateServerReadPastEarlyData
StateServerWaitEOED
StateServerWaitFlight2
StateServerWaitCert
StateServerWaitCV
StateServerWaitFinished
StateServerConnected
)
func (s State) String() string {
switch s {
case StateClientStart:
return "Client START"
case StateClientWaitSH:
return "Client WAIT_SH"
case StateClientWaitEE:
return "Client WAIT_EE"
case StateClientWaitCert:
return "Client WAIT_CERT"
case StateClientWaitCV:
return "Client WAIT_CV"
case StateClientWaitFinished:
return "Client WAIT_FINISHED"
case StateClientWaitCertCR:
return "Client WAIT_CERT_CR"
case StateClientConnected:
return "Client CONNECTED"
case StateServerStart:
return "Server START"
case StateServerRecvdCH:
return "Server RECVD_CH"
case StateServerNegotiated:
return "Server NEGOTIATED"
case StateServerReadPastEarlyData:
return "Server READ_PAST_EARLY_DATA"
case StateServerWaitEOED:
return "Server WAIT_EOED"
case StateServerWaitFlight2:
return "Server WAIT_FLIGHT2"
case StateServerWaitCert:
return "Server WAIT_CERT"
case StateServerWaitCV:
return "Server WAIT_CV"
case StateServerWaitFinished:
return "Server WAIT_FINISHED"
case StateServerConnected:
return "Server CONNECTED"
default:
return fmt.Sprintf("unknown state: %d", s)
}
}
// Epochs for DTLS (also used for key phase labelling)
type Epoch uint16
const (
EpochClear Epoch = 0
EpochEarlyData Epoch = 1
EpochHandshakeData Epoch = 2
EpochApplicationData Epoch = 3
EpochUpdate Epoch = 4
)
func (e Epoch) label() string {
switch e {
case EpochClear:
return "clear"
case EpochEarlyData:
return "early data"
case EpochHandshakeData:
return "handshake"
case EpochApplicationData:
return "application data"
}
return "Application data (updated)"
}
func assert(b bool) {
if !b {
panic("Assertion failed")
}
}

View File

@@ -1,928 +0,0 @@
package mint
import (
"crypto"
"crypto/x509"
"encoding/hex"
"errors"
"fmt"
"io"
"net"
"reflect"
"sync"
"time"
)
type Certificate struct {
Chain []*x509.Certificate
PrivateKey crypto.Signer
}
type PreSharedKey struct {
CipherSuite CipherSuite
IsResumption bool
Identity []byte
Key []byte
NextProto string
ReceivedAt time.Time
ExpiresAt time.Time
TicketAgeAdd uint32
}
type PreSharedKeyCache interface {
Get(string) (PreSharedKey, bool)
Put(string, PreSharedKey)
Size() int
}
// A CookieHandler can be used to give the application more fine-grained control over Cookies.
// Generate receives the Conn as an argument, so the CookieHandler can decide when to send the cookie based on that, and offload state to the client by encoding that into the Cookie.
// When the client echoes the Cookie, Validate is called. The application can then recover the state from the cookie.
type CookieHandler interface {
// Generate a byte string that is sent as a part of a cookie to the client in the HelloRetryRequest
// If Generate returns nil, mint will not send a HelloRetryRequest.
Generate(*Conn) ([]byte, error)
// Validate is called when receiving a ClientHello containing a Cookie.
// If validation failed, the handshake is aborted.
Validate(*Conn, []byte) bool
}
type PSKMapCache map[string]PreSharedKey
func (cache PSKMapCache) Get(key string) (psk PreSharedKey, ok bool) {
psk, ok = cache[key]
return
}
func (cache *PSKMapCache) Put(key string, psk PreSharedKey) {
(*cache)[key] = psk
}
func (cache PSKMapCache) Size() int {
return len(cache)
}
// Config is the struct used to pass configuration settings to a TLS client or
// server instance. The settings for client and server are pretty different,
// but we just throw them all in here.
type Config struct {
// Client fields
ServerName string
// Server fields
SendSessionTickets bool
TicketLifetime uint32
TicketLen int
EarlyDataLifetime uint32
AllowEarlyData bool
// Require the client to echo a cookie.
RequireCookie bool
// A CookieHandler can be used to set and validate a cookie.
// The cookie returned by the CookieHandler will be part of the cookie sent on the wire, and encoded using the CookieProtector.
// If no CookieHandler is set, mint will always send a cookie.
// The CookieHandler can be used to decide on a per-connection basis, if a cookie should be sent.
CookieHandler CookieHandler
// The CookieProtector is used to encrypt / decrypt cookies.
// It should make sure that the Cookie cannot be read and tampered with by the client.
// If non-blocking mode is used, and cookies are required, this field has to be set.
// In blocking mode, a default cookie protector is used, if this is unused.
CookieProtector CookieProtector
// The ExtensionHandler is used to add custom extensions.
ExtensionHandler AppExtensionHandler
RequireClientAuth bool
// Time returns the current time as the number of seconds since the epoch.
// If Time is nil, TLS uses time.Now.
Time func() time.Time
// RootCAs defines the set of root certificate authorities
// that clients use when verifying server certificates.
// If RootCAs is nil, TLS uses the host's root CA set.
RootCAs *x509.CertPool
// InsecureSkipVerify controls whether a client verifies the
// server's certificate chain and host name.
// If InsecureSkipVerify is true, TLS accepts any certificate
// presented by the server and any host name in that certificate.
// In this mode, TLS is susceptible to man-in-the-middle attacks.
// This should be used only for testing.
InsecureSkipVerify bool
// Shared fields
Certificates []*Certificate
// VerifyPeerCertificate, if not nil, is called after normal
// certificate verification by either a TLS client or server. It
// receives the raw ASN.1 certificates provided by the peer and also
// any verified chains that normal processing found. If it returns a
// non-nil error, the handshake is aborted and that error results.
//
// If normal verification fails then the handshake will abort before
// considering this callback. If normal verification is disabled by
// setting InsecureSkipVerify then this callback will be considered but
// the verifiedChains argument will always be nil.
VerifyPeerCertificate func(rawCerts [][]byte, verifiedChains [][]*x509.Certificate) error
CipherSuites []CipherSuite
Groups []NamedGroup
SignatureSchemes []SignatureScheme
NextProtos []string
PSKs PreSharedKeyCache
PSKModes []PSKKeyExchangeMode
NonBlocking bool
UseDTLS bool
RecordLayer RecordLayerFactory
// The same config object can be shared among different connections, so it
// needs its own mutex
mutex sync.RWMutex
}
// Clone returns a shallow clone of c. It is safe to clone a Config that is
// being used concurrently by a TLS client or server.
func (c *Config) Clone() *Config {
c.mutex.Lock()
defer c.mutex.Unlock()
return &Config{
ServerName: c.ServerName,
SendSessionTickets: c.SendSessionTickets,
TicketLifetime: c.TicketLifetime,
TicketLen: c.TicketLen,
EarlyDataLifetime: c.EarlyDataLifetime,
AllowEarlyData: c.AllowEarlyData,
RequireCookie: c.RequireCookie,
CookieHandler: c.CookieHandler,
CookieProtector: c.CookieProtector,
ExtensionHandler: c.ExtensionHandler,
RequireClientAuth: c.RequireClientAuth,
Time: c.Time,
RootCAs: c.RootCAs,
InsecureSkipVerify: c.InsecureSkipVerify,
Certificates: c.Certificates,
VerifyPeerCertificate: c.VerifyPeerCertificate,
CipherSuites: c.CipherSuites,
Groups: c.Groups,
SignatureSchemes: c.SignatureSchemes,
NextProtos: c.NextProtos,
PSKs: c.PSKs,
PSKModes: c.PSKModes,
NonBlocking: c.NonBlocking,
UseDTLS: c.UseDTLS,
}
}
func (c *Config) Init(isClient bool) error {
c.mutex.Lock()
defer c.mutex.Unlock()
// Set defaults
if len(c.CipherSuites) == 0 {
c.CipherSuites = defaultSupportedCipherSuites
}
if len(c.Groups) == 0 {
c.Groups = defaultSupportedGroups
}
if len(c.SignatureSchemes) == 0 {
c.SignatureSchemes = defaultSignatureSchemes
}
if c.TicketLen == 0 {
c.TicketLen = defaultTicketLen
}
if !reflect.ValueOf(c.PSKs).IsValid() {
c.PSKs = &PSKMapCache{}
}
if len(c.PSKModes) == 0 {
c.PSKModes = defaultPSKModes
}
return nil
}
func (c *Config) ValidForServer() bool {
return (reflect.ValueOf(c.PSKs).IsValid() && c.PSKs.Size() > 0) ||
(len(c.Certificates) > 0 &&
len(c.Certificates[0].Chain) > 0 &&
c.Certificates[0].PrivateKey != nil)
}
func (c *Config) ValidForClient() bool {
return len(c.ServerName) > 0
}
func (c *Config) time() time.Time {
t := c.Time
if t == nil {
t = time.Now
}
return t()
}
var (
defaultSupportedCipherSuites = []CipherSuite{
TLS_AES_128_GCM_SHA256,
TLS_AES_256_GCM_SHA384,
}
defaultSupportedGroups = []NamedGroup{
P256,
P384,
FFDHE2048,
X25519,
}
defaultSignatureSchemes = []SignatureScheme{
RSA_PSS_SHA256,
RSA_PSS_SHA384,
RSA_PSS_SHA512,
ECDSA_P256_SHA256,
ECDSA_P384_SHA384,
ECDSA_P521_SHA512,
}
defaultTicketLen = 16
defaultPSKModes = []PSKKeyExchangeMode{
PSKModeKE,
PSKModeDHEKE,
}
)
type ConnectionState struct {
HandshakeState State
CipherSuite CipherSuiteParams // cipher suite in use (TLS_RSA_WITH_RC4_128_SHA, ...)
PeerCertificates []*x509.Certificate // certificate chain presented by remote peer
VerifiedChains [][]*x509.Certificate // verified chains built from PeerCertificates
NextProto string // Selected ALPN proto
UsingPSK bool // Are we using PSK.
UsingEarlyData bool // Did we negotiate 0-RTT.
}
// Conn implements the net.Conn interface, as with "crypto/tls"
// * Read, Write, and Close are provided locally
// * LocalAddr, RemoteAddr, and Set*Deadline are forwarded to the inner Conn
type Conn struct {
config *Config
conn net.Conn
isClient bool
state stateConnected
hState HandshakeState
handshakeMutex sync.Mutex
handshakeAlert Alert
handshakeComplete bool
readBuffer []byte
in, out RecordLayer
hsCtx *HandshakeContext
}
func NewConn(conn net.Conn, config *Config, isClient bool) *Conn {
c := &Conn{conn: conn, config: config, isClient: isClient, hsCtx: &HandshakeContext{}}
if !config.UseDTLS {
if config.RecordLayer == nil {
c.in = NewRecordLayerTLS(c.conn, DirectionRead)
c.out = NewRecordLayerTLS(c.conn, DirectionWrite)
} else {
c.in = config.RecordLayer.NewLayer(c.conn, DirectionRead)
c.out = config.RecordLayer.NewLayer(c.conn, DirectionWrite)
}
c.hsCtx.hIn = NewHandshakeLayerTLS(c.hsCtx, c.in)
c.hsCtx.hOut = NewHandshakeLayerTLS(c.hsCtx, c.out)
} else {
c.in = NewRecordLayerDTLS(c.conn, DirectionRead)
c.out = NewRecordLayerDTLS(c.conn, DirectionWrite)
c.hsCtx.hIn = NewHandshakeLayerDTLS(c.hsCtx, c.in)
c.hsCtx.hOut = NewHandshakeLayerDTLS(c.hsCtx, c.out)
c.hsCtx.timeoutMS = initialTimeout
c.hsCtx.timers = newTimerSet()
c.hsCtx.waitingNextFlight = true
}
c.in.SetLabel(c.label())
c.out.SetLabel(c.label())
c.hsCtx.hIn.nonblocking = c.config.NonBlocking
return c
}
// Read up
func (c *Conn) consumeRecord() error {
pt, err := c.in.ReadRecord()
if pt == nil {
logf(logTypeIO, "extendBuffer returns error %v", err)
return err
}
switch pt.contentType {
case RecordTypeHandshake:
logf(logTypeHandshake, "Received post-handshake message")
// We do not support fragmentation of post-handshake handshake messages.
// TODO: Factor this more elegantly; coalesce with handshakeLayer.ReadMessage()
start := 0
headerLen := handshakeHeaderLenTLS
if c.config.UseDTLS {
headerLen = handshakeHeaderLenDTLS
}
for start < len(pt.fragment) {
if len(pt.fragment[start:]) < headerLen {
return fmt.Errorf("Post-handshake handshake message too short for header")
}
hm := &HandshakeMessage{}
hm.msgType = HandshakeType(pt.fragment[start])
hmLen := (int(pt.fragment[start+1]) << 16) + (int(pt.fragment[start+2]) << 8) + int(pt.fragment[start+3])
if len(pt.fragment[start+headerLen:]) < hmLen {
return fmt.Errorf("Post-handshake handshake message too short for body")
}
hm.body = pt.fragment[start+headerLen : start+headerLen+hmLen]
// XXX: If we want to support more advanced cases, e.g., post-handshake
// authentication, we'll need to allow transitions other than
// Connected -> Connected
state, actions, alert := c.state.ProcessMessage(hm)
if alert != AlertNoAlert {
logf(logTypeHandshake, "Error in state transition: %v", alert)
c.sendAlert(alert)
return io.EOF
}
for _, action := range actions {
alert = c.takeAction(action)
if alert != AlertNoAlert {
logf(logTypeHandshake, "Error during handshake actions: %v", alert)
c.sendAlert(alert)
return io.EOF
}
}
var connected bool
c.state, connected = state.(stateConnected)
if !connected {
logf(logTypeHandshake, "Disconnected after state transition: %v", alert)
c.sendAlert(alert)
return io.EOF
}
start += headerLen + hmLen
}
case RecordTypeAlert:
logf(logTypeIO, "extended buffer (for alert): [%d] %x", len(c.readBuffer), c.readBuffer)
if len(pt.fragment) != 2 {
c.sendAlert(AlertUnexpectedMessage)
return io.EOF
}
if Alert(pt.fragment[1]) == AlertCloseNotify {
return io.EOF
}
switch pt.fragment[0] {
case AlertLevelWarning:
// drop on the floor
case AlertLevelError:
return Alert(pt.fragment[1])
default:
c.sendAlert(AlertUnexpectedMessage)
return io.EOF
}
case RecordTypeAck:
if !c.hsCtx.hIn.datagram {
logf(logTypeHandshake, "Received ACK in TLS mode")
return AlertUnexpectedMessage
}
return c.hsCtx.processAck(pt.fragment)
case RecordTypeApplicationData:
c.readBuffer = append(c.readBuffer, pt.fragment...)
logf(logTypeIO, "extended buffer: [%d] %x", len(c.readBuffer), c.readBuffer)
}
return err
}
func readPartial(in *[]byte, buffer []byte) int {
logf(logTypeIO, "conn.Read input buffer now has len %d", len((*in)))
read := copy(buffer, *in)
*in = (*in)[read:]
logf(logTypeVerbose, "Returning %v", string(buffer))
return read
}
// Read application data up to the size of buffer. Handshake and alert records
// are consumed by the Conn object directly.
func (c *Conn) Read(buffer []byte) (int, error) {
if _, connected := c.hState.(stateConnected); !connected {
// Clients can't call Read prior to handshake completion.
if c.isClient {
return 0, errors.New("Read called before the handshake completed")
}
// Neither can servers that don't allow early data.
if !c.config.AllowEarlyData {
return 0, errors.New("Read called before the handshake completed")
}
// If there's no early data, then return WouldBlock
if len(c.hsCtx.earlyData) == 0 {
return 0, AlertWouldBlock
}
return readPartial(&c.hsCtx.earlyData, buffer), nil
}
// The handshake is now connected.
logf(logTypeHandshake, "conn.Read with buffer = %d", len(buffer))
if alert := c.Handshake(); alert != AlertNoAlert {
return 0, alert
}
if len(buffer) == 0 {
return 0, nil
}
// Run our timers.
if c.config.UseDTLS {
if err := c.hsCtx.timers.check(time.Now()); err != nil {
return 0, AlertInternalError
}
}
// Lock the input channel
c.in.Lock()
defer c.in.Unlock()
for len(c.readBuffer) == 0 {
err := c.consumeRecord()
// err can be nil if consumeRecord processed a non app-data
// record.
if err != nil {
if c.config.NonBlocking || err != AlertWouldBlock {
logf(logTypeIO, "conn.Read returns err=%v", err)
return 0, err
}
}
}
return readPartial(&c.readBuffer, buffer), nil
}
// Write application data
func (c *Conn) Write(buffer []byte) (int, error) {
// Lock the output channel
c.out.Lock()
defer c.out.Unlock()
if !c.Writable() {
return 0, errors.New("Write called before the handshake completed (and early data not in use)")
}
// Send full-size fragments
var start int
sent := 0
for start = 0; len(buffer)-start >= maxFragmentLen; start += maxFragmentLen {
err := c.out.WriteRecord(&TLSPlaintext{
contentType: RecordTypeApplicationData,
fragment: buffer[start : start+maxFragmentLen],
})
if err != nil {
return sent, err
}
sent += maxFragmentLen
}
// Send a final partial fragment if necessary
if start < len(buffer) {
err := c.out.WriteRecord(&TLSPlaintext{
contentType: RecordTypeApplicationData,
fragment: buffer[start:],
})
if err != nil {
return sent, err
}
sent += len(buffer[start:])
}
return sent, nil
}
// sendAlert sends a TLS alert message.
// c.out.Mutex <= L.
func (c *Conn) sendAlert(err Alert) error {
c.handshakeMutex.Lock()
defer c.handshakeMutex.Unlock()
var level int
switch err {
case AlertNoRenegotiation, AlertCloseNotify:
level = AlertLevelWarning
default:
level = AlertLevelError
}
buf := []byte{byte(err), byte(level)}
c.out.WriteRecord(&TLSPlaintext{
contentType: RecordTypeAlert,
fragment: buf,
})
// close_notify and end_of_early_data are not actually errors
if level == AlertLevelWarning {
return &net.OpError{Op: "local error", Err: err}
}
return c.Close()
}
// Close closes the connection.
func (c *Conn) Close() error {
// XXX crypto/tls has an interlock with Write here. Do we need that?
return c.conn.Close()
}
// LocalAddr returns the local network address.
func (c *Conn) LocalAddr() net.Addr {
return c.conn.LocalAddr()
}
// RemoteAddr returns the remote network address.
func (c *Conn) RemoteAddr() net.Addr {
return c.conn.RemoteAddr()
}
// SetDeadline sets the read and write deadlines associated with the connection.
// A zero value for t means Read and Write will not time out.
// After a Write has timed out, the TLS state is corrupt and all future writes will return the same error.
func (c *Conn) SetDeadline(t time.Time) error {
return c.conn.SetDeadline(t)
}
// SetReadDeadline sets the read deadline on the underlying connection.
// A zero value for t means Read will not time out.
func (c *Conn) SetReadDeadline(t time.Time) error {
return c.conn.SetReadDeadline(t)
}
// SetWriteDeadline sets the write deadline on the underlying connection.
// A zero value for t means Write will not time out.
// After a Write has timed out, the TLS state is corrupt and all future writes will return the same error.
func (c *Conn) SetWriteDeadline(t time.Time) error {
return c.conn.SetWriteDeadline(t)
}
func (c *Conn) takeAction(actionGeneric HandshakeAction) Alert {
label := "[server]"
if c.isClient {
label = "[client]"
}
switch action := actionGeneric.(type) {
case QueueHandshakeMessage:
logf(logTypeHandshake, "%s queuing handshake message type=%v", label, action.Message.msgType)
err := c.hsCtx.hOut.QueueMessage(action.Message)
if err != nil {
logf(logTypeHandshake, "%s Error writing handshake message: %v", label, err)
return AlertInternalError
}
case SendQueuedHandshake:
_, err := c.hsCtx.hOut.SendQueuedMessages()
if err != nil {
logf(logTypeHandshake, "%s Error writing handshake message: %v", label, err)
return AlertInternalError
}
if c.config.UseDTLS {
c.hsCtx.timers.start(retransmitTimerLabel,
c.hsCtx.handshakeRetransmit,
c.hsCtx.timeoutMS)
}
case RekeyIn:
logf(logTypeHandshake, "%s Rekeying in to %s: %+v", label, action.epoch.label(), action.KeySet)
// Check that we don't have an input data in the handshake frame parser.
if len(c.hsCtx.hIn.frame.remainder) > 0 {
logf(logTypeHandshake, "%s Rekey with data still in handshake buffers", label)
return AlertDecodeError
}
err := c.in.Rekey(action.epoch, action.KeySet.Cipher, &action.KeySet)
if err != nil {
logf(logTypeHandshake, "%s Unable to rekey inbound: %v", label, err)
return AlertInternalError
}
case RekeyOut:
logf(logTypeHandshake, "%s Rekeying out to %s: %+v", label, action.epoch.label(), action.KeySet)
err := c.out.Rekey(action.epoch, action.KeySet.Cipher, &action.KeySet)
if err != nil {
logf(logTypeHandshake, "%s Unable to rekey outbound: %v", label, err)
return AlertInternalError
}
case ResetOut:
logf(logTypeHandshake, "%s Rekeying out to %s seq=%v", label, EpochClear, action.seq)
c.out.ResetClear(action.seq)
case StorePSK:
logf(logTypeHandshake, "%s Storing new session ticket with identity [%x]", label, action.PSK.Identity)
if c.isClient {
// Clients look up PSKs based on server name
c.config.PSKs.Put(c.config.ServerName, action.PSK)
} else {
// Servers look them up based on the identity in the extension
c.config.PSKs.Put(hex.EncodeToString(action.PSK.Identity), action.PSK)
}
default:
logf(logTypeHandshake, "%s Unknown action type", label)
assert(false)
return AlertInternalError
}
return AlertNoAlert
}
func (c *Conn) HandshakeSetup() Alert {
var state HandshakeState
var actions []HandshakeAction
var alert Alert
if err := c.config.Init(c.isClient); err != nil {
logf(logTypeHandshake, "Error initializing config: %v", err)
return AlertInternalError
}
opts := ConnectionOptions{
ServerName: c.config.ServerName,
NextProtos: c.config.NextProtos,
}
if c.isClient {
state, actions, alert = clientStateStart{Config: c.config, Opts: opts, hsCtx: c.hsCtx}.Next(nil)
if alert != AlertNoAlert {
logf(logTypeHandshake, "Error initializing client state: %v", alert)
return alert
}
for _, action := range actions {
alert = c.takeAction(action)
if alert != AlertNoAlert {
logf(logTypeHandshake, "Error during handshake actions: %v", alert)
return alert
}
}
} else {
if c.config.RequireCookie && c.config.CookieProtector == nil {
logf(logTypeHandshake, "RequireCookie set, but no CookieProtector provided. Using default cookie protector. Stateless Retry not possible.")
if c.config.NonBlocking {
logf(logTypeHandshake, "Not possible in non-blocking mode.")
return AlertInternalError
}
var err error
c.config.CookieProtector, err = NewDefaultCookieProtector()
if err != nil {
logf(logTypeHandshake, "Error initializing cookie source: %v", alert)
return AlertInternalError
}
}
state = serverStateStart{Config: c.config, conn: c, hsCtx: c.hsCtx}
}
c.hState = state
return AlertNoAlert
}
type handshakeMessageReader interface {
ReadMessage() (*HandshakeMessage, Alert)
}
type handshakeMessageReaderImpl struct {
hsCtx *HandshakeContext
}
var _ handshakeMessageReader = &handshakeMessageReaderImpl{}
func (r *handshakeMessageReaderImpl) ReadMessage() (*HandshakeMessage, Alert) {
var hm *HandshakeMessage
var err error
for {
hm, err = r.hsCtx.hIn.ReadMessage()
if err == AlertWouldBlock {
return nil, AlertWouldBlock
}
if err != nil {
logf(logTypeHandshake, "Error reading message: %v", err)
return nil, AlertCloseNotify
}
if hm != nil {
break
}
}
return hm, AlertNoAlert
}
// Handshake causes a TLS handshake on the connection. The `isClient` member
// determines whether a client or server handshake is performed. If a
// handshake has already been performed, then its result will be returned.
func (c *Conn) Handshake() Alert {
label := "[server]"
if c.isClient {
label = "[client]"
}
// TODO Lock handshakeMutex
// TODO Remove CloseNotify hack
if c.handshakeAlert != AlertNoAlert && c.handshakeAlert != AlertCloseNotify {
logf(logTypeHandshake, "Pre-existing handshake error: %v", c.handshakeAlert)
return c.handshakeAlert
}
if c.handshakeComplete {
return AlertNoAlert
}
if c.hState == nil {
logf(logTypeHandshake, "%s First time through handshake (or after stateless retry), setting up", label)
alert := c.HandshakeSetup()
if alert != AlertNoAlert || (c.isClient && c.config.NonBlocking) {
return alert
}
}
logf(logTypeHandshake, "(Re-)entering handshake, state=%v", c.hState)
state := c.hState
_, connected := state.(stateConnected)
hmr := &handshakeMessageReaderImpl{hsCtx: c.hsCtx}
for !connected {
var alert Alert
var actions []HandshakeAction
// Advance the state machine
state, actions, alert = state.Next(hmr)
if alert == AlertWouldBlock {
logf(logTypeHandshake, "%s Would block reading message: %s", label, alert)
// If we blocked, then run our timers to see if any have expired.
if c.hsCtx.hIn.datagram {
if err := c.hsCtx.timers.check(time.Now()); err != nil {
return AlertInternalError
}
}
return AlertWouldBlock
}
if alert == AlertCloseNotify {
logf(logTypeHandshake, "%s Error reading message: %s", label, alert)
c.sendAlert(AlertCloseNotify)
return AlertCloseNotify
}
if alert != AlertNoAlert && alert != AlertStatelessRetry {
logf(logTypeHandshake, "Error in state transition: %v", alert)
return alert
}
for index, action := range actions {
logf(logTypeHandshake, "%s taking next action (%d)", label, index)
if alert := c.takeAction(action); alert != AlertNoAlert {
logf(logTypeHandshake, "Error during handshake actions: %v", alert)
c.sendAlert(alert)
return alert
}
}
c.hState = state
logf(logTypeHandshake, "state is now %s", c.GetHsState())
_, connected = state.(stateConnected)
if connected {
c.state = state.(stateConnected)
c.handshakeComplete = true
if !c.isClient {
// Send NewSessionTicket if configured to
if c.config.SendSessionTickets {
actions, alert := c.state.NewSessionTicket(
c.config.TicketLen,
c.config.TicketLifetime,
c.config.EarlyDataLifetime)
for _, action := range actions {
alert = c.takeAction(action)
if alert != AlertNoAlert {
logf(logTypeHandshake, "Error during handshake actions: %v", alert)
c.sendAlert(alert)
return alert
}
}
}
// If there is early data, move it into the main buffer
if c.hsCtx.earlyData != nil {
c.readBuffer = c.hsCtx.earlyData
c.hsCtx.earlyData = nil
}
} else {
assert(c.hsCtx.earlyData == nil)
}
}
if c.config.NonBlocking {
if alert == AlertStatelessRetry {
return AlertStatelessRetry
}
return AlertNoAlert
}
}
return AlertNoAlert
}
func (c *Conn) SendKeyUpdate(requestUpdate bool) error {
if !c.handshakeComplete {
return fmt.Errorf("Cannot update keys until after handshake")
}
request := KeyUpdateNotRequested
if requestUpdate {
request = KeyUpdateRequested
}
// Create the key update and update state
actions, alert := c.state.KeyUpdate(request)
if alert != AlertNoAlert {
c.sendAlert(alert)
return fmt.Errorf("Alert while generating key update: %v", alert)
}
// Take actions (send key update and rekey)
for _, action := range actions {
alert = c.takeAction(action)
if alert != AlertNoAlert {
c.sendAlert(alert)
return fmt.Errorf("Alert during key update actions: %v", alert)
}
}
return nil
}
func (c *Conn) GetHsState() State {
if c.hState == nil {
return StateInit
}
return c.hState.State()
}
func (c *Conn) ComputeExporter(label string, context []byte, keyLength int) ([]byte, error) {
_, connected := c.hState.(stateConnected)
if !connected {
return nil, fmt.Errorf("Cannot compute exporter when state is not connected")
}
if c.state.exporterSecret == nil {
return nil, fmt.Errorf("Internal error: no exporter secret")
}
h0 := c.state.cryptoParams.Hash.New().Sum(nil)
tmpSecret := deriveSecret(c.state.cryptoParams, c.state.exporterSecret, label, h0)
hc := c.state.cryptoParams.Hash.New().Sum(context)
return HkdfExpandLabel(c.state.cryptoParams.Hash, tmpSecret, "exporter", hc, keyLength), nil
}
func (c *Conn) ConnectionState() ConnectionState {
state := ConnectionState{
HandshakeState: c.GetHsState(),
}
if c.handshakeComplete {
state.CipherSuite = cipherSuiteMap[c.state.Params.CipherSuite]
state.NextProto = c.state.Params.NextProto
state.VerifiedChains = c.state.verifiedChains
state.PeerCertificates = c.state.peerCertificates
state.UsingPSK = c.state.Params.UsingPSK
state.UsingEarlyData = c.state.Params.UsingEarlyData
}
return state
}
func (c *Conn) Writable() bool {
// If we're connected, we're writable.
if _, connected := c.hState.(stateConnected); connected {
return true
}
// If we're a client in 0-RTT, then we're writable.
if c.isClient && c.out.Epoch() == EpochEarlyData {
return true
}
return false
}
func (c *Conn) label() string {
if c.isClient {
return "client"
}
return "server"
}

View File

@@ -1,86 +0,0 @@
package mint
import (
"crypto/aes"
"crypto/cipher"
"crypto/rand"
"crypto/sha256"
"fmt"
"io"
"golang.org/x/crypto/hkdf"
)
// CookieProtector is used to create and verify a cookie
type CookieProtector interface {
// NewToken creates a new token
NewToken([]byte) ([]byte, error)
// DecodeToken decodes a token
DecodeToken([]byte) ([]byte, error)
}
const cookieSecretSize = 32
const cookieNonceSize = 32
// The DefaultCookieProtector is a simple implementation for the CookieProtector.
type DefaultCookieProtector struct {
secret []byte
}
var _ CookieProtector = &DefaultCookieProtector{}
// NewDefaultCookieProtector creates a source for source address tokens
func NewDefaultCookieProtector() (CookieProtector, error) {
secret := make([]byte, cookieSecretSize)
if _, err := rand.Read(secret); err != nil {
return nil, err
}
return &DefaultCookieProtector{secret: secret}, nil
}
// NewToken encodes data into a new token.
func (s *DefaultCookieProtector) NewToken(data []byte) ([]byte, error) {
nonce := make([]byte, cookieNonceSize)
if _, err := rand.Read(nonce); err != nil {
return nil, err
}
aead, aeadNonce, err := s.createAEAD(nonce)
if err != nil {
return nil, err
}
return append(nonce, aead.Seal(nil, aeadNonce, data, nil)...), nil
}
// DecodeToken decodes a token.
func (s *DefaultCookieProtector) DecodeToken(p []byte) ([]byte, error) {
if len(p) < cookieNonceSize {
return nil, fmt.Errorf("Token too short: %d", len(p))
}
nonce := p[:cookieNonceSize]
aead, aeadNonce, err := s.createAEAD(nonce)
if err != nil {
return nil, err
}
return aead.Open(nil, aeadNonce, p[cookieNonceSize:], nil)
}
func (s *DefaultCookieProtector) createAEAD(nonce []byte) (cipher.AEAD, []byte, error) {
h := hkdf.New(sha256.New, s.secret, nonce, []byte("mint cookie source"))
key := make([]byte, 32) // use a 32 byte key, in order to select AES-256
if _, err := io.ReadFull(h, key); err != nil {
return nil, nil, err
}
aeadNonce := make([]byte, 12)
if _, err := io.ReadFull(h, aeadNonce); err != nil {
return nil, nil, err
}
c, err := aes.NewCipher(key)
if err != nil {
return nil, nil, err
}
aead, err := cipher.NewGCM(c)
if err != nil {
return nil, nil, err
}
return aead, aeadNonce, nil
}

View File

@@ -1,671 +0,0 @@
package mint
import (
"bytes"
"crypto"
"crypto/aes"
"crypto/cipher"
"crypto/ecdsa"
"crypto/elliptic"
"crypto/hmac"
"crypto/rand"
"crypto/rsa"
"crypto/x509"
"crypto/x509/pkix"
"encoding/asn1"
"fmt"
"math/big"
"time"
"golang.org/x/crypto/curve25519"
// Blank includes to ensure hash support
_ "crypto/sha1"
_ "crypto/sha256"
_ "crypto/sha512"
)
var prng = rand.Reader
type AeadFactory func(key []byte) (cipher.AEAD, error)
type CipherSuiteParams struct {
Suite CipherSuite
Cipher AeadFactory // Cipher factory
Hash crypto.Hash // Hash function
KeyLen int // Key length in octets
IvLen int // IV length in octets
}
type signatureAlgorithm uint8
const (
signatureAlgorithmUnknown = iota
signatureAlgorithmRSA_PKCS1
signatureAlgorithmRSA_PSS
signatureAlgorithmECDSA
)
var (
hashMap = map[SignatureScheme]crypto.Hash{
RSA_PKCS1_SHA1: crypto.SHA1,
RSA_PKCS1_SHA256: crypto.SHA256,
RSA_PKCS1_SHA384: crypto.SHA384,
RSA_PKCS1_SHA512: crypto.SHA512,
ECDSA_P256_SHA256: crypto.SHA256,
ECDSA_P384_SHA384: crypto.SHA384,
ECDSA_P521_SHA512: crypto.SHA512,
RSA_PSS_SHA256: crypto.SHA256,
RSA_PSS_SHA384: crypto.SHA384,
RSA_PSS_SHA512: crypto.SHA512,
}
sigMap = map[SignatureScheme]signatureAlgorithm{
RSA_PKCS1_SHA1: signatureAlgorithmRSA_PKCS1,
RSA_PKCS1_SHA256: signatureAlgorithmRSA_PKCS1,
RSA_PKCS1_SHA384: signatureAlgorithmRSA_PKCS1,
RSA_PKCS1_SHA512: signatureAlgorithmRSA_PKCS1,
ECDSA_P256_SHA256: signatureAlgorithmECDSA,
ECDSA_P384_SHA384: signatureAlgorithmECDSA,
ECDSA_P521_SHA512: signatureAlgorithmECDSA,
RSA_PSS_SHA256: signatureAlgorithmRSA_PSS,
RSA_PSS_SHA384: signatureAlgorithmRSA_PSS,
RSA_PSS_SHA512: signatureAlgorithmRSA_PSS,
}
curveMap = map[SignatureScheme]NamedGroup{
ECDSA_P256_SHA256: P256,
ECDSA_P384_SHA384: P384,
ECDSA_P521_SHA512: P521,
}
newAESGCM = func(key []byte) (cipher.AEAD, error) {
block, err := aes.NewCipher(key)
if err != nil {
return nil, err
}
// TLS always uses 12-byte nonces
return cipher.NewGCMWithNonceSize(block, 12)
}
cipherSuiteMap = map[CipherSuite]CipherSuiteParams{
TLS_AES_128_GCM_SHA256: {
Suite: TLS_AES_128_GCM_SHA256,
Cipher: newAESGCM,
Hash: crypto.SHA256,
KeyLen: 16,
IvLen: 12,
},
TLS_AES_256_GCM_SHA384: {
Suite: TLS_AES_256_GCM_SHA384,
Cipher: newAESGCM,
Hash: crypto.SHA384,
KeyLen: 32,
IvLen: 12,
},
}
x509AlgMap = map[SignatureScheme]x509.SignatureAlgorithm{
RSA_PKCS1_SHA1: x509.SHA1WithRSA,
RSA_PKCS1_SHA256: x509.SHA256WithRSA,
RSA_PKCS1_SHA384: x509.SHA384WithRSA,
RSA_PKCS1_SHA512: x509.SHA512WithRSA,
ECDSA_P256_SHA256: x509.ECDSAWithSHA256,
ECDSA_P384_SHA384: x509.ECDSAWithSHA384,
ECDSA_P521_SHA512: x509.ECDSAWithSHA512,
}
defaultRSAKeySize = 2048
)
func curveFromNamedGroup(group NamedGroup) (crv elliptic.Curve) {
switch group {
case P256:
crv = elliptic.P256()
case P384:
crv = elliptic.P384()
case P521:
crv = elliptic.P521()
}
return
}
func namedGroupFromECDSAKey(key *ecdsa.PublicKey) (g NamedGroup) {
switch key.Curve.Params().Name {
case elliptic.P256().Params().Name:
g = P256
case elliptic.P384().Params().Name:
g = P384
case elliptic.P521().Params().Name:
g = P521
}
return
}
func keyExchangeSizeFromNamedGroup(group NamedGroup) (size int) {
size = 0
switch group {
case X25519:
size = 32
case P256:
size = 65
case P384:
size = 97
case P521:
size = 133
case FFDHE2048:
size = 256
case FFDHE3072:
size = 384
case FFDHE4096:
size = 512
case FFDHE6144:
size = 768
case FFDHE8192:
size = 1024
}
return
}
func primeFromNamedGroup(group NamedGroup) (p *big.Int) {
switch group {
case FFDHE2048:
p = finiteFieldPrime2048
case FFDHE3072:
p = finiteFieldPrime3072
case FFDHE4096:
p = finiteFieldPrime4096
case FFDHE6144:
p = finiteFieldPrime6144
case FFDHE8192:
p = finiteFieldPrime8192
}
return
}
func schemeValidForKey(alg SignatureScheme, key crypto.Signer) bool {
sigType := sigMap[alg]
switch key.(type) {
case *rsa.PrivateKey:
return sigType == signatureAlgorithmRSA_PKCS1 || sigType == signatureAlgorithmRSA_PSS
case *ecdsa.PrivateKey:
return sigType == signatureAlgorithmECDSA
default:
return false
}
}
func ffdheKeyShareFromPrime(p *big.Int) (priv, pub *big.Int, err error) {
primeLen := len(p.Bytes())
for {
// g = 2 for all ffdhe groups
priv, err = rand.Int(prng, p)
if err != nil {
return
}
pub = big.NewInt(0)
pub.Exp(big.NewInt(2), priv, p)
if len(pub.Bytes()) == primeLen {
return
}
}
}
func newKeyShare(group NamedGroup) (pub []byte, priv []byte, err error) {
switch group {
case P256, P384, P521:
var x, y *big.Int
crv := curveFromNamedGroup(group)
priv, x, y, err = elliptic.GenerateKey(crv, prng)
if err != nil {
return
}
pub = elliptic.Marshal(crv, x, y)
return
case FFDHE2048, FFDHE3072, FFDHE4096, FFDHE6144, FFDHE8192:
p := primeFromNamedGroup(group)
x, X, err2 := ffdheKeyShareFromPrime(p)
if err2 != nil {
err = err2
return
}
priv = x.Bytes()
pubBytes := X.Bytes()
numBytes := keyExchangeSizeFromNamedGroup(group)
pub = make([]byte, numBytes)
copy(pub[numBytes-len(pubBytes):], pubBytes)
return
case X25519:
var private, public [32]byte
_, err = prng.Read(private[:])
if err != nil {
return
}
curve25519.ScalarBaseMult(&public, &private)
priv = private[:]
pub = public[:]
return
default:
return nil, nil, fmt.Errorf("tls.newkeyshare: Unsupported group %v", group)
}
}
func keyAgreement(group NamedGroup, pub []byte, priv []byte) ([]byte, error) {
switch group {
case P256, P384, P521:
if len(pub) != keyExchangeSizeFromNamedGroup(group) {
return nil, fmt.Errorf("tls.keyagreement: Wrong public key size")
}
crv := curveFromNamedGroup(group)
pubX, pubY := elliptic.Unmarshal(crv, pub)
x, _ := crv.Params().ScalarMult(pubX, pubY, priv)
xBytes := x.Bytes()
numBytes := len(crv.Params().P.Bytes())
ret := make([]byte, numBytes)
copy(ret[numBytes-len(xBytes):], xBytes)
return ret, nil
case FFDHE2048, FFDHE3072, FFDHE4096, FFDHE6144, FFDHE8192:
numBytes := keyExchangeSizeFromNamedGroup(group)
if len(pub) != numBytes {
return nil, fmt.Errorf("tls.keyagreement: Wrong public key size")
}
p := primeFromNamedGroup(group)
x := big.NewInt(0).SetBytes(priv)
Y := big.NewInt(0).SetBytes(pub)
ZBytes := big.NewInt(0).Exp(Y, x, p).Bytes()
ret := make([]byte, numBytes)
copy(ret[numBytes-len(ZBytes):], ZBytes)
return ret, nil
case X25519:
if len(pub) != keyExchangeSizeFromNamedGroup(group) {
return nil, fmt.Errorf("tls.keyagreement: Wrong public key size")
}
var private, public, ret [32]byte
copy(private[:], priv)
copy(public[:], pub)
curve25519.ScalarMult(&ret, &private, &public)
return ret[:], nil
default:
return nil, fmt.Errorf("tls.keyagreement: Unsupported group %v", group)
}
}
func newSigningKey(sig SignatureScheme) (crypto.Signer, error) {
switch sig {
case RSA_PKCS1_SHA1, RSA_PKCS1_SHA256,
RSA_PKCS1_SHA384, RSA_PKCS1_SHA512,
RSA_PSS_SHA256, RSA_PSS_SHA384,
RSA_PSS_SHA512:
return rsa.GenerateKey(prng, defaultRSAKeySize)
case ECDSA_P256_SHA256:
return ecdsa.GenerateKey(elliptic.P256(), prng)
case ECDSA_P384_SHA384:
return ecdsa.GenerateKey(elliptic.P384(), prng)
case ECDSA_P521_SHA512:
return ecdsa.GenerateKey(elliptic.P521(), prng)
default:
return nil, fmt.Errorf("tls.newsigningkey: Unsupported signature algorithm [%04x]", sig)
}
}
// XXX(rlb): Copied from crypto/x509
type ecdsaSignature struct {
R, S *big.Int
}
func sign(alg SignatureScheme, privateKey crypto.Signer, sigInput []byte) ([]byte, error) {
var opts crypto.SignerOpts
hash := hashMap[alg]
if hash == crypto.SHA1 {
return nil, fmt.Errorf("tls.crypt.sign: Use of SHA-1 is forbidden")
}
sigType := sigMap[alg]
var realInput []byte
switch key := privateKey.(type) {
case *rsa.PrivateKey:
switch {
case allowPKCS1 && sigType == signatureAlgorithmRSA_PKCS1:
logf(logTypeCrypto, "signing with PKCS1, hashSize=[%d]", hash.Size())
opts = hash
case !allowPKCS1 && sigType == signatureAlgorithmRSA_PKCS1:
fallthrough
case sigType == signatureAlgorithmRSA_PSS:
logf(logTypeCrypto, "signing with PSS, hashSize=[%d]", hash.Size())
opts = &rsa.PSSOptions{SaltLength: hash.Size(), Hash: hash}
default:
return nil, fmt.Errorf("tls.crypto.sign: Unsupported algorithm for RSA key")
}
h := hash.New()
h.Write(sigInput)
realInput = h.Sum(nil)
case *ecdsa.PrivateKey:
if sigType != signatureAlgorithmECDSA {
return nil, fmt.Errorf("tls.crypto.sign: Unsupported algorithm for ECDSA key")
}
algGroup := curveMap[alg]
keyGroup := namedGroupFromECDSAKey(key.Public().(*ecdsa.PublicKey))
if algGroup != keyGroup {
return nil, fmt.Errorf("tls.crypto.sign: Unsupported hash/curve combination")
}
h := hash.New()
h.Write(sigInput)
realInput = h.Sum(nil)
default:
return nil, fmt.Errorf("tls.crypto.sign: Unsupported private key type")
}
sig, err := privateKey.Sign(prng, realInput, opts)
logf(logTypeCrypto, "signature: %x", sig)
return sig, err
}
func verify(alg SignatureScheme, publicKey crypto.PublicKey, sigInput []byte, sig []byte) error {
hash := hashMap[alg]
if hash == crypto.SHA1 {
return fmt.Errorf("tls.crypt.sign: Use of SHA-1 is forbidden")
}
sigType := sigMap[alg]
switch pub := publicKey.(type) {
case *rsa.PublicKey:
switch {
case allowPKCS1 && sigType == signatureAlgorithmRSA_PKCS1:
logf(logTypeCrypto, "verifying with PKCS1, hashSize=[%d]", hash.Size())
h := hash.New()
h.Write(sigInput)
realInput := h.Sum(nil)
return rsa.VerifyPKCS1v15(pub, hash, realInput, sig)
case !allowPKCS1 && sigType == signatureAlgorithmRSA_PKCS1:
fallthrough
case sigType == signatureAlgorithmRSA_PSS:
logf(logTypeCrypto, "verifying with PSS, hashSize=[%d]", hash.Size())
opts := &rsa.PSSOptions{SaltLength: hash.Size(), Hash: hash}
h := hash.New()
h.Write(sigInput)
realInput := h.Sum(nil)
return rsa.VerifyPSS(pub, hash, realInput, sig, opts)
default:
return fmt.Errorf("tls.verify: Unsupported algorithm for RSA key")
}
case *ecdsa.PublicKey:
if sigType != signatureAlgorithmECDSA {
return fmt.Errorf("tls.verify: Unsupported algorithm for ECDSA key")
}
if curveMap[alg] != namedGroupFromECDSAKey(pub) {
return fmt.Errorf("tls.verify: Unsupported curve for ECDSA key")
}
ecdsaSig := new(ecdsaSignature)
if rest, err := asn1.Unmarshal(sig, ecdsaSig); err != nil {
return err
} else if len(rest) != 0 {
return fmt.Errorf("tls.verify: trailing data after ECDSA signature")
}
if ecdsaSig.R.Sign() <= 0 || ecdsaSig.S.Sign() <= 0 {
return fmt.Errorf("tls.verify: ECDSA signature contained zero or negative values")
}
h := hash.New()
h.Write(sigInput)
realInput := h.Sum(nil)
if !ecdsa.Verify(pub, realInput, ecdsaSig.R, ecdsaSig.S) {
return fmt.Errorf("tls.verify: ECDSA verification failure")
}
return nil
default:
return fmt.Errorf("tls.verify: Unsupported key type")
}
}
// 0
// |
// v
// PSK -> HKDF-Extract = Early Secret
// |
// +-----> Derive-Secret(.,
// | "ext binder" |
// | "res binder",
// | "")
// | = binder_key
// |
// +-----> Derive-Secret(., "c e traffic",
// | ClientHello)
// | = client_early_traffic_secret
// |
// +-----> Derive-Secret(., "e exp master",
// | ClientHello)
// | = early_exporter_master_secret
// v
// Derive-Secret(., "derived", "")
// |
// v
// (EC)DHE -> HKDF-Extract = Handshake Secret
// |
// +-----> Derive-Secret(., "c hs traffic",
// | ClientHello...ServerHello)
// | = client_handshake_traffic_secret
// |
// +-----> Derive-Secret(., "s hs traffic",
// | ClientHello...ServerHello)
// | = server_handshake_traffic_secret
// v
// Derive-Secret(., "derived", "")
// |
// v
// 0 -> HKDF-Extract = Master Secret
// |
// +-----> Derive-Secret(., "c ap traffic",
// | ClientHello...server Finished)
// | = client_application_traffic_secret_0
// |
// +-----> Derive-Secret(., "s ap traffic",
// | ClientHello...server Finished)
// | = server_application_traffic_secret_0
// |
// +-----> Derive-Secret(., "exp master",
// | ClientHello...server Finished)
// | = exporter_master_secret
// |
// +-----> Derive-Secret(., "res master",
// ClientHello...client Finished)
// = resumption_master_secret
// From RFC 5869
// PRK = HMAC-Hash(salt, IKM)
func HkdfExtract(hash crypto.Hash, saltIn, input []byte) []byte {
salt := saltIn
// if [salt is] not provided, it is set to a string of HashLen zeros
if salt == nil {
salt = bytes.Repeat([]byte{0}, hash.Size())
}
h := hmac.New(hash.New, salt)
h.Write(input)
out := h.Sum(nil)
logf(logTypeCrypto, "HKDF Extract:\n")
logf(logTypeCrypto, "Salt [%d]: %x\n", len(salt), salt)
logf(logTypeCrypto, "Input [%d]: %x\n", len(input), input)
logf(logTypeCrypto, "Output [%d]: %x\n", len(out), out)
return out
}
const (
labelExternalBinder = "ext binder"
labelResumptionBinder = "res binder"
labelEarlyTrafficSecret = "c e traffic"
labelEarlyExporterSecret = "e exp master"
labelClientHandshakeTrafficSecret = "c hs traffic"
labelServerHandshakeTrafficSecret = "s hs traffic"
labelClientApplicationTrafficSecret = "c ap traffic"
labelServerApplicationTrafficSecret = "s ap traffic"
labelExporterSecret = "exp master"
labelResumptionSecret = "res master"
labelDerived = "derived"
labelFinished = "finished"
labelResumption = "resumption"
)
// struct HkdfLabel {
// uint16 length;
// opaque label<9..255>;
// opaque hash_value<0..255>;
// };
var HkdfLabelPrefix = "tls13 "
func hkdfEncodeLabel(labelIn string, hashValue []byte, outLen int) []byte {
label := HkdfLabelPrefix + labelIn
labelLen := len(label)
hashLen := len(hashValue)
hkdfLabel := make([]byte, 2+1+labelLen+1+hashLen)
hkdfLabel[0] = byte(outLen >> 8)
hkdfLabel[1] = byte(outLen)
hkdfLabel[2] = byte(labelLen)
copy(hkdfLabel[3:3+labelLen], []byte(label))
hkdfLabel[3+labelLen] = byte(hashLen)
copy(hkdfLabel[3+labelLen+1:], hashValue)
return hkdfLabel
}
func HkdfExpand(hash crypto.Hash, prk, info []byte, outLen int) []byte {
out := []byte{}
T := []byte{}
i := byte(1)
for len(out) < outLen {
block := append(T, info...)
block = append(block, i)
h := hmac.New(hash.New, prk)
h.Write(block)
T = h.Sum(nil)
out = append(out, T...)
i++
}
return out[:outLen]
}
func HkdfExpandLabel(hash crypto.Hash, secret []byte, label string, hashValue []byte, outLen int) []byte {
info := hkdfEncodeLabel(label, hashValue, outLen)
derived := HkdfExpand(hash, secret, info, outLen)
logf(logTypeCrypto, "HKDF Expand: label=[tls13 ] + '%s',requested length=%d\n", label, outLen)
logf(logTypeCrypto, "PRK [%d]: %x\n", len(secret), secret)
logf(logTypeCrypto, "Hash [%d]: %x\n", len(hashValue), hashValue)
logf(logTypeCrypto, "Info [%d]: %x\n", len(info), info)
logf(logTypeCrypto, "Derived key [%d]: %x\n", len(derived), derived)
return derived
}
func deriveSecret(params CipherSuiteParams, secret []byte, label string, messageHash []byte) []byte {
return HkdfExpandLabel(params.Hash, secret, label, messageHash, params.Hash.Size())
}
func computeFinishedData(params CipherSuiteParams, baseKey []byte, input []byte) []byte {
macKey := HkdfExpandLabel(params.Hash, baseKey, labelFinished, []byte{}, params.Hash.Size())
mac := hmac.New(params.Hash.New, macKey)
mac.Write(input)
return mac.Sum(nil)
}
type KeySet struct {
Cipher AeadFactory
Key []byte
Iv []byte
Pn []byte
}
func makeTrafficKeys(params CipherSuiteParams, secret []byte) KeySet {
logf(logTypeCrypto, "making traffic keys: secret=%x", secret)
return KeySet{
Cipher: params.Cipher,
Key: HkdfExpandLabel(params.Hash, secret, "key", []byte{}, params.KeyLen),
Iv: HkdfExpandLabel(params.Hash, secret, "iv", []byte{}, params.IvLen),
Pn: HkdfExpandLabel(params.Hash, secret, "pn", []byte{}, params.KeyLen),
}
}
func MakeNewSelfSignedCert(name string, alg SignatureScheme) (crypto.Signer, *x509.Certificate, error) {
priv, err := newSigningKey(alg)
if err != nil {
return nil, nil, err
}
cert, err := newSelfSigned(name, alg, priv)
if err != nil {
return nil, nil, err
}
return priv, cert, nil
}
func newSelfSigned(name string, alg SignatureScheme, priv crypto.Signer) (*x509.Certificate, error) {
sigAlg, ok := x509AlgMap[alg]
if !ok {
return nil, fmt.Errorf("tls.selfsigned: Unknown signature algorithm [%04x]", alg)
}
if len(name) == 0 {
return nil, fmt.Errorf("tls.selfsigned: No name provided")
}
serial, err := rand.Int(rand.Reader, big.NewInt(0xA0A0A0A0))
if err != nil {
return nil, err
}
template := &x509.Certificate{
SerialNumber: serial,
NotBefore: time.Now(),
NotAfter: time.Now().AddDate(0, 0, 1),
SignatureAlgorithm: sigAlg,
Subject: pkix.Name{CommonName: name},
DNSNames: []string{name},
KeyUsage: x509.KeyUsageDigitalSignature | x509.KeyUsageKeyAgreement | x509.KeyUsageKeyEncipherment,
ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth},
}
der, err := x509.CreateCertificate(prng, template, template, priv.Public(), priv)
if err != nil {
return nil, err
}
// It is safe to ignore the error here because we're parsing known-good data
cert, _ := x509.ParseCertificate(der)
return cert, nil
}

View File

@@ -1,222 +0,0 @@
package mint
import (
"fmt"
"github.com/bifurcation/mint/syntax"
"time"
)
const (
initialMtu = 1200
initialTimeout = 100
)
// labels for timers
const (
retransmitTimerLabel = "handshake retransmit"
ackTimerLabel = "ack timer"
)
type SentHandshakeFragment struct {
seq uint32
offset int
fragLength int
record uint64
acked bool
}
type DtlsAck struct {
RecordNumbers []uint64 `tls:"head=2"`
}
func wireVersion(h *HandshakeLayer) uint16 {
if h.datagram {
return dtls12WireVersion
}
return tls12Version
}
func dtlsConvertVersion(version uint16) uint16 {
if version == tls12Version {
return dtls12WireVersion
}
if version == tls10Version {
return 0xfeff
}
panic(fmt.Sprintf("Internal error, unexpected version=%d", version))
}
// TODO(ekr@rtfm.com): Move these to state-machine.go
func (h *HandshakeContext) handshakeRetransmit() error {
if _, err := h.hOut.SendQueuedMessages(); err != nil {
return err
}
h.timers.start(retransmitTimerLabel,
h.handshakeRetransmit,
h.timeoutMS)
// TODO(ekr@rtfm.com): Back off timer
return nil
}
func (h *HandshakeContext) sendAck() error {
toack := h.hIn.recvdRecords
count := (initialMtu - 2) / 8 // TODO(ekr@rtfm.com): Current MTU
if len(toack) > count {
toack = toack[:count]
}
logf(logTypeHandshake, "Sending ACK: [%x]", toack)
ack := &DtlsAck{toack}
body, err := syntax.Marshal(&ack)
if err != nil {
return err
}
err = h.hOut.conn.WriteRecord(&TLSPlaintext{
contentType: RecordTypeAck,
fragment: body,
})
if err != nil {
return err
}
return nil
}
func (h *HandshakeContext) processAck(data []byte) error {
// Cancel the retransmit timer because we will be resending
// and possibly re-arming later.
h.timers.cancel(retransmitTimerLabel)
ack := &DtlsAck{}
read, err := syntax.Unmarshal(data, &ack)
if err != nil {
return err
}
if len(data) != read {
return fmt.Errorf("Invalid encoding: Extra data not consumed")
}
logf(logTypeHandshake, "ACK: [%x]", ack.RecordNumbers)
for _, r := range ack.RecordNumbers {
for _, m := range h.sentFragments {
if r == m.record {
logf(logTypeHandshake, "Marking %v %v(%v) as acked",
m.seq, m.offset, m.fragLength)
m.acked = true
}
}
}
count, err := h.hOut.SendQueuedMessages()
if err != nil {
return err
}
if count == 0 {
logf(logTypeHandshake, "All messages ACKed")
h.hOut.ClearQueuedMessages()
return nil
}
// Reset the timer
h.timers.start(retransmitTimerLabel,
h.handshakeRetransmit,
h.timeoutMS)
return nil
}
func (c *Conn) GetDTLSTimeout() (bool, time.Duration) {
return c.hsCtx.timers.remaining()
}
func (h *HandshakeContext) receivedHandshakeMessage() {
logf(logTypeHandshake, "%p Received handshake, waiting for start of flight = %v", h, h.waitingNextFlight)
// This just enables tests.
if h.hIn == nil {
return
}
if !h.hIn.datagram {
return
}
if h.waitingNextFlight {
logf(logTypeHandshake, "Received the start of the flight")
// Clear the outgoing DTLS queue and terminate the retransmit timer
h.hOut.ClearQueuedMessages()
h.timers.cancel(retransmitTimerLabel)
// OK, we're not waiting any more.
h.waitingNextFlight = false
}
// Now pre-emptively arm the ACK timer if it's not armed already.
// We'll automatically dis-arm it at the end of the handshake.
if h.timers.getTimer(ackTimerLabel) == nil {
h.timers.start(ackTimerLabel, h.sendAck, h.timeoutMS/4)
}
}
func (h *HandshakeContext) receivedEndOfFlight() {
logf(logTypeHandshake, "%p Received the end of the flight", h)
if !h.hIn.datagram {
return
}
// Empty incoming queue
h.hIn.queued = nil
// Note that we are waiting for the next flight.
h.waitingNextFlight = true
// Clear the ACK queue.
h.hIn.recvdRecords = nil
// Disarm the ACK timer
h.timers.cancel(ackTimerLabel)
}
func (h *HandshakeContext) receivedFinalFlight() {
logf(logTypeHandshake, "%p Received final flight", h)
if !h.hIn.datagram {
return
}
// Disarm the ACK timer
h.timers.cancel(ackTimerLabel)
// But send an ACK immediately.
h.sendAck()
}
func (h *HandshakeContext) fragmentAcked(seq uint32, offset int, fraglen int) bool {
logf(logTypeHandshake, "Looking to see if fragment %v %v(%v) was acked", seq, offset, fraglen)
for _, f := range h.sentFragments {
if !f.acked {
continue
}
if f.seq != seq {
continue
}
if f.offset > offset {
continue
}
// At this point, we know that the stored fragment starts
// at or before what we want to send, so check where the end
// is.
if f.offset+f.fragLength < offset+fraglen {
continue
}
return true
}
return false
}

View File

@@ -1,626 +0,0 @@
package mint
import (
"bytes"
"fmt"
"github.com/bifurcation/mint/syntax"
)
type ExtensionBody interface {
Type() ExtensionType
Marshal() ([]byte, error)
Unmarshal(data []byte) (int, error)
}
// struct {
// ExtensionType extension_type;
// opaque extension_data<0..2^16-1>;
// } Extension;
type Extension struct {
ExtensionType ExtensionType
ExtensionData []byte `tls:"head=2"`
}
func (ext Extension) Marshal() ([]byte, error) {
return syntax.Marshal(ext)
}
func (ext *Extension) Unmarshal(data []byte) (int, error) {
return syntax.Unmarshal(data, ext)
}
type ExtensionList []Extension
type extensionListInner struct {
List []Extension `tls:"head=2"`
}
func (el ExtensionList) Marshal() ([]byte, error) {
return syntax.Marshal(extensionListInner{el})
}
func (el *ExtensionList) Unmarshal(data []byte) (int, error) {
var list extensionListInner
read, err := syntax.Unmarshal(data, &list)
if err != nil {
return 0, err
}
*el = list.List
return read, nil
}
func (el *ExtensionList) Add(src ExtensionBody) error {
data, err := src.Marshal()
if err != nil {
return err
}
if el == nil {
el = new(ExtensionList)
}
// If one already exists with this type, replace it
for i := range *el {
if (*el)[i].ExtensionType == src.Type() {
(*el)[i].ExtensionData = data
return nil
}
}
// Otherwise append
*el = append(*el, Extension{
ExtensionType: src.Type(),
ExtensionData: data,
})
return nil
}
func (el ExtensionList) Parse(dsts []ExtensionBody) (map[ExtensionType]bool, error) {
found := make(map[ExtensionType]bool)
for _, dst := range dsts {
for _, ext := range el {
if ext.ExtensionType == dst.Type() {
if found[dst.Type()] {
return nil, fmt.Errorf("Duplicate extension of type [%v]", dst.Type())
}
err := safeUnmarshal(dst, ext.ExtensionData)
if err != nil {
return nil, err
}
found[dst.Type()] = true
}
}
}
return found, nil
}
func (el ExtensionList) Find(dst ExtensionBody) (bool, error) {
for _, ext := range el {
if ext.ExtensionType == dst.Type() {
err := safeUnmarshal(dst, ext.ExtensionData)
if err != nil {
return true, err
}
return true, nil
}
}
return false, nil
}
// struct {
// NameType name_type;
// select (name_type) {
// case host_name: HostName;
// } name;
// } ServerName;
//
// enum {
// host_name(0), (255)
// } NameType;
//
// opaque HostName<1..2^16-1>;
//
// struct {
// ServerName server_name_list<1..2^16-1>
// } ServerNameList;
//
// But we only care about the case where there's a single DNS hostname. We
// will never create anything else, and throw if we receive something else
//
// 2 1 2
// | listLen | NameType | nameLen | name |
type ServerNameExtension string
type serverNameInner struct {
NameType uint8
HostName []byte `tls:"head=2,min=1"`
}
type serverNameListInner struct {
ServerNameList []serverNameInner `tls:"head=2,min=1"`
}
func (sni ServerNameExtension) Type() ExtensionType {
return ExtensionTypeServerName
}
func (sni ServerNameExtension) Marshal() ([]byte, error) {
list := serverNameListInner{
ServerNameList: []serverNameInner{{
NameType: 0x00, // host_name
HostName: []byte(sni),
}},
}
return syntax.Marshal(list)
}
func (sni *ServerNameExtension) Unmarshal(data []byte) (int, error) {
var list serverNameListInner
read, err := syntax.Unmarshal(data, &list)
if err != nil {
return 0, err
}
// Syntax requires at least one entry
// Entries beyond the first are ignored
if nameType := list.ServerNameList[0].NameType; nameType != 0x00 {
return 0, fmt.Errorf("tls.servername: Unsupported name type [%x]", nameType)
}
*sni = ServerNameExtension(list.ServerNameList[0].HostName)
return read, nil
}
// struct {
// NamedGroup group;
// opaque key_exchange<1..2^16-1>;
// } KeyShareEntry;
//
// struct {
// select (Handshake.msg_type) {
// case client_hello:
// KeyShareEntry client_shares<0..2^16-1>;
//
// case hello_retry_request:
// NamedGroup selected_group;
//
// case server_hello:
// KeyShareEntry server_share;
// };
// } KeyShare;
type KeyShareEntry struct {
Group NamedGroup
KeyExchange []byte `tls:"head=2,min=1"`
}
func (kse KeyShareEntry) SizeValid() bool {
return len(kse.KeyExchange) == keyExchangeSizeFromNamedGroup(kse.Group)
}
type KeyShareExtension struct {
HandshakeType HandshakeType
SelectedGroup NamedGroup
Shares []KeyShareEntry
}
type KeyShareClientHelloInner struct {
ClientShares []KeyShareEntry `tls:"head=2,min=0"`
}
type KeyShareHelloRetryInner struct {
SelectedGroup NamedGroup
}
type KeyShareServerHelloInner struct {
ServerShare KeyShareEntry
}
func (ks KeyShareExtension) Type() ExtensionType {
return ExtensionTypeKeyShare
}
func (ks KeyShareExtension) Marshal() ([]byte, error) {
switch ks.HandshakeType {
case HandshakeTypeClientHello:
for _, share := range ks.Shares {
if !share.SizeValid() {
return nil, fmt.Errorf("tls.keyshare: Key share has wrong size for group")
}
}
return syntax.Marshal(KeyShareClientHelloInner{ks.Shares})
case HandshakeTypeHelloRetryRequest:
if len(ks.Shares) > 0 {
return nil, fmt.Errorf("tls.keyshare: Key shares not allowed for HelloRetryRequest")
}
return syntax.Marshal(KeyShareHelloRetryInner{ks.SelectedGroup})
case HandshakeTypeServerHello:
if len(ks.Shares) != 1 {
return nil, fmt.Errorf("tls.keyshare: Server must send exactly one key share")
}
if !ks.Shares[0].SizeValid() {
return nil, fmt.Errorf("tls.keyshare: Key share has wrong size for group")
}
return syntax.Marshal(KeyShareServerHelloInner{ks.Shares[0]})
default:
return nil, fmt.Errorf("tls.keyshare: Handshake type not allowed")
}
}
func (ks *KeyShareExtension) Unmarshal(data []byte) (int, error) {
switch ks.HandshakeType {
case HandshakeTypeClientHello:
var inner KeyShareClientHelloInner
read, err := syntax.Unmarshal(data, &inner)
if err != nil {
return 0, err
}
for _, share := range inner.ClientShares {
if !share.SizeValid() {
return 0, fmt.Errorf("tls.keyshare: Key share has wrong size for group")
}
}
ks.Shares = inner.ClientShares
return read, nil
case HandshakeTypeHelloRetryRequest:
var inner KeyShareHelloRetryInner
read, err := syntax.Unmarshal(data, &inner)
if err != nil {
return 0, err
}
ks.SelectedGroup = inner.SelectedGroup
return read, nil
case HandshakeTypeServerHello:
var inner KeyShareServerHelloInner
read, err := syntax.Unmarshal(data, &inner)
if err != nil {
return 0, err
}
if !inner.ServerShare.SizeValid() {
return 0, fmt.Errorf("tls.keyshare: Key share has wrong size for group")
}
ks.Shares = []KeyShareEntry{inner.ServerShare}
return read, nil
default:
return 0, fmt.Errorf("tls.keyshare: Handshake type not allowed")
}
}
// struct {
// NamedGroup named_group_list<2..2^16-1>;
// } NamedGroupList;
type SupportedGroupsExtension struct {
Groups []NamedGroup `tls:"head=2,min=2"`
}
func (sg SupportedGroupsExtension) Type() ExtensionType {
return ExtensionTypeSupportedGroups
}
func (sg SupportedGroupsExtension) Marshal() ([]byte, error) {
return syntax.Marshal(sg)
}
func (sg *SupportedGroupsExtension) Unmarshal(data []byte) (int, error) {
return syntax.Unmarshal(data, sg)
}
// struct {
// SignatureScheme supported_signature_algorithms<2..2^16-2>;
// } SignatureSchemeList
type SignatureAlgorithmsExtension struct {
Algorithms []SignatureScheme `tls:"head=2,min=2"`
}
func (sa SignatureAlgorithmsExtension) Type() ExtensionType {
return ExtensionTypeSignatureAlgorithms
}
func (sa SignatureAlgorithmsExtension) Marshal() ([]byte, error) {
return syntax.Marshal(sa)
}
func (sa *SignatureAlgorithmsExtension) Unmarshal(data []byte) (int, error) {
return syntax.Unmarshal(data, sa)
}
// struct {
// opaque identity<1..2^16-1>;
// uint32 obfuscated_ticket_age;
// } PskIdentity;
//
// opaque PskBinderEntry<32..255>;
//
// struct {
// select (Handshake.msg_type) {
// case client_hello:
// PskIdentity identities<7..2^16-1>;
// PskBinderEntry binders<33..2^16-1>;
//
// case server_hello:
// uint16 selected_identity;
// };
//
// } PreSharedKeyExtension;
type PSKIdentity struct {
Identity []byte `tls:"head=2,min=1"`
ObfuscatedTicketAge uint32
}
type PSKBinderEntry struct {
Binder []byte `tls:"head=1,min=32"`
}
type PreSharedKeyExtension struct {
HandshakeType HandshakeType
Identities []PSKIdentity
Binders []PSKBinderEntry
SelectedIdentity uint16
}
type preSharedKeyClientInner struct {
Identities []PSKIdentity `tls:"head=2,min=7"`
Binders []PSKBinderEntry `tls:"head=2,min=33"`
}
type preSharedKeyServerInner struct {
SelectedIdentity uint16
}
func (psk PreSharedKeyExtension) Type() ExtensionType {
return ExtensionTypePreSharedKey
}
func (psk PreSharedKeyExtension) Marshal() ([]byte, error) {
switch psk.HandshakeType {
case HandshakeTypeClientHello:
return syntax.Marshal(preSharedKeyClientInner{
Identities: psk.Identities,
Binders: psk.Binders,
})
case HandshakeTypeServerHello:
if len(psk.Identities) > 0 || len(psk.Binders) > 0 {
return nil, fmt.Errorf("tls.presharedkey: Server can only provide an index")
}
return syntax.Marshal(preSharedKeyServerInner{psk.SelectedIdentity})
default:
return nil, fmt.Errorf("tls.presharedkey: Handshake type not supported")
}
}
func (psk *PreSharedKeyExtension) Unmarshal(data []byte) (int, error) {
switch psk.HandshakeType {
case HandshakeTypeClientHello:
var inner preSharedKeyClientInner
read, err := syntax.Unmarshal(data, &inner)
if err != nil {
return 0, err
}
if len(inner.Identities) != len(inner.Binders) {
return 0, fmt.Errorf("Lengths of identities and binders not equal")
}
psk.Identities = inner.Identities
psk.Binders = inner.Binders
return read, nil
case HandshakeTypeServerHello:
var inner preSharedKeyServerInner
read, err := syntax.Unmarshal(data, &inner)
if err != nil {
return 0, err
}
psk.SelectedIdentity = inner.SelectedIdentity
return read, nil
default:
return 0, fmt.Errorf("tls.presharedkey: Handshake type not supported")
}
}
func (psk PreSharedKeyExtension) HasIdentity(id []byte) ([]byte, bool) {
for i, localID := range psk.Identities {
if bytes.Equal(localID.Identity, id) {
return psk.Binders[i].Binder, true
}
}
return nil, false
}
// enum { psk_ke(0), psk_dhe_ke(1), (255) } PskKeyExchangeMode;
//
// struct {
// PskKeyExchangeMode ke_modes<1..255>;
// } PskKeyExchangeModes;
type PSKKeyExchangeModesExtension struct {
KEModes []PSKKeyExchangeMode `tls:"head=1,min=1"`
}
func (pkem PSKKeyExchangeModesExtension) Type() ExtensionType {
return ExtensionTypePSKKeyExchangeModes
}
func (pkem PSKKeyExchangeModesExtension) Marshal() ([]byte, error) {
return syntax.Marshal(pkem)
}
func (pkem *PSKKeyExchangeModesExtension) Unmarshal(data []byte) (int, error) {
return syntax.Unmarshal(data, pkem)
}
// struct {
// } EarlyDataIndication;
type EarlyDataExtension struct{}
func (ed EarlyDataExtension) Type() ExtensionType {
return ExtensionTypeEarlyData
}
func (ed EarlyDataExtension) Marshal() ([]byte, error) {
return []byte{}, nil
}
func (ed *EarlyDataExtension) Unmarshal(data []byte) (int, error) {
return 0, nil
}
// struct {
// uint32 max_early_data_size;
// } TicketEarlyDataInfo;
type TicketEarlyDataInfoExtension struct {
MaxEarlyDataSize uint32
}
func (tedi TicketEarlyDataInfoExtension) Type() ExtensionType {
return ExtensionTypeTicketEarlyDataInfo
}
func (tedi TicketEarlyDataInfoExtension) Marshal() ([]byte, error) {
return syntax.Marshal(tedi)
}
func (tedi *TicketEarlyDataInfoExtension) Unmarshal(data []byte) (int, error) {
return syntax.Unmarshal(data, tedi)
}
// opaque ProtocolName<1..2^8-1>;
//
// struct {
// ProtocolName protocol_name_list<2..2^16-1>
// } ProtocolNameList;
type ALPNExtension struct {
Protocols []string
}
type protocolNameInner struct {
Name []byte `tls:"head=1,min=1"`
}
type alpnExtensionInner struct {
Protocols []protocolNameInner `tls:"head=2,min=2"`
}
func (alpn ALPNExtension) Type() ExtensionType {
return ExtensionTypeALPN
}
func (alpn ALPNExtension) Marshal() ([]byte, error) {
protocols := make([]protocolNameInner, len(alpn.Protocols))
for i, protocol := range alpn.Protocols {
protocols[i] = protocolNameInner{[]byte(protocol)}
}
return syntax.Marshal(alpnExtensionInner{protocols})
}
func (alpn *ALPNExtension) Unmarshal(data []byte) (int, error) {
var inner alpnExtensionInner
read, err := syntax.Unmarshal(data, &inner)
if err != nil {
return 0, err
}
alpn.Protocols = make([]string, len(inner.Protocols))
for i, protocol := range inner.Protocols {
alpn.Protocols[i] = string(protocol.Name)
}
return read, nil
}
// struct {
// ProtocolVersion versions<2..254>;
// } SupportedVersions;
type SupportedVersionsExtension struct {
HandshakeType HandshakeType
Versions []uint16
}
type SupportedVersionsClientHelloInner struct {
Versions []uint16 `tls:"head=1,min=2,max=254"`
}
type SupportedVersionsServerHelloInner struct {
Version uint16
}
func (sv SupportedVersionsExtension) Type() ExtensionType {
return ExtensionTypeSupportedVersions
}
func (sv SupportedVersionsExtension) Marshal() ([]byte, error) {
switch sv.HandshakeType {
case HandshakeTypeClientHello:
return syntax.Marshal(SupportedVersionsClientHelloInner{sv.Versions})
case HandshakeTypeServerHello, HandshakeTypeHelloRetryRequest:
return syntax.Marshal(SupportedVersionsServerHelloInner{sv.Versions[0]})
default:
return nil, fmt.Errorf("tls.supported_versions: Handshake type not allowed")
}
}
func (sv *SupportedVersionsExtension) Unmarshal(data []byte) (int, error) {
switch sv.HandshakeType {
case HandshakeTypeClientHello:
var inner SupportedVersionsClientHelloInner
read, err := syntax.Unmarshal(data, &inner)
if err != nil {
return 0, err
}
sv.Versions = inner.Versions
return read, nil
case HandshakeTypeServerHello, HandshakeTypeHelloRetryRequest:
var inner SupportedVersionsServerHelloInner
read, err := syntax.Unmarshal(data, &inner)
if err != nil {
return 0, err
}
sv.Versions = []uint16{inner.Version}
return read, nil
default:
return 0, fmt.Errorf("tls.supported_versions: Handshake type not allowed")
}
}
// struct {
// opaque cookie<1..2^16-1>;
// } Cookie;
type CookieExtension struct {
Cookie []byte `tls:"head=2,min=1"`
}
func (c CookieExtension) Type() ExtensionType {
return ExtensionTypeCookie
}
func (c CookieExtension) Marshal() ([]byte, error) {
return syntax.Marshal(c)
}
func (c *CookieExtension) Unmarshal(data []byte) (int, error) {
return syntax.Unmarshal(data, c)
}

View File

@@ -1,147 +0,0 @@
package mint
import (
"encoding/hex"
"math/big"
)
var (
finiteFieldPrime2048hex = "FFFFFFFFFFFFFFFFADF85458A2BB4A9AAFDC5620273D3CF1" +
"D8B9C583CE2D3695A9E13641146433FBCC939DCE249B3EF9" +
"7D2FE363630C75D8F681B202AEC4617AD3DF1ED5D5FD6561" +
"2433F51F5F066ED0856365553DED1AF3B557135E7F57C935" +
"984F0C70E0E68B77E2A689DAF3EFE8721DF158A136ADE735" +
"30ACCA4F483A797ABC0AB182B324FB61D108A94BB2C8E3FB" +
"B96ADAB760D7F4681D4F42A3DE394DF4AE56EDE76372BB19" +
"0B07A7C8EE0A6D709E02FCE1CDF7E2ECC03404CD28342F61" +
"9172FE9CE98583FF8E4F1232EEF28183C3FE3B1B4C6FAD73" +
"3BB5FCBC2EC22005C58EF1837D1683B2C6F34A26C1B2EFFA" +
"886B423861285C97FFFFFFFFFFFFFFFF"
finiteFieldPrime2048bytes, _ = hex.DecodeString(finiteFieldPrime2048hex)
finiteFieldPrime2048 = big.NewInt(0).SetBytes(finiteFieldPrime2048bytes)
finiteFieldPrime3072hex = "FFFFFFFFFFFFFFFFADF85458A2BB4A9AAFDC5620273D3CF1" +
"D8B9C583CE2D3695A9E13641146433FBCC939DCE249B3EF9" +
"7D2FE363630C75D8F681B202AEC4617AD3DF1ED5D5FD6561" +
"2433F51F5F066ED0856365553DED1AF3B557135E7F57C935" +
"984F0C70E0E68B77E2A689DAF3EFE8721DF158A136ADE735" +
"30ACCA4F483A797ABC0AB182B324FB61D108A94BB2C8E3FB" +
"B96ADAB760D7F4681D4F42A3DE394DF4AE56EDE76372BB19" +
"0B07A7C8EE0A6D709E02FCE1CDF7E2ECC03404CD28342F61" +
"9172FE9CE98583FF8E4F1232EEF28183C3FE3B1B4C6FAD73" +
"3BB5FCBC2EC22005C58EF1837D1683B2C6F34A26C1B2EFFA" +
"886B4238611FCFDCDE355B3B6519035BBC34F4DEF99C0238" +
"61B46FC9D6E6C9077AD91D2691F7F7EE598CB0FAC186D91C" +
"AEFE130985139270B4130C93BC437944F4FD4452E2D74DD3" +
"64F2E21E71F54BFF5CAE82AB9C9DF69EE86D2BC522363A0D" +
"ABC521979B0DEADA1DBF9A42D5C4484E0ABCD06BFA53DDEF" +
"3C1B20EE3FD59D7C25E41D2B66C62E37FFFFFFFFFFFFFFFF"
finiteFieldPrime3072bytes, _ = hex.DecodeString(finiteFieldPrime3072hex)
finiteFieldPrime3072 = big.NewInt(0).SetBytes(finiteFieldPrime3072bytes)
finiteFieldPrime4096hex = "FFFFFFFFFFFFFFFFADF85458A2BB4A9AAFDC5620273D3CF1" +
"D8B9C583CE2D3695A9E13641146433FBCC939DCE249B3EF9" +
"7D2FE363630C75D8F681B202AEC4617AD3DF1ED5D5FD6561" +
"2433F51F5F066ED0856365553DED1AF3B557135E7F57C935" +
"984F0C70E0E68B77E2A689DAF3EFE8721DF158A136ADE735" +
"30ACCA4F483A797ABC0AB182B324FB61D108A94BB2C8E3FB" +
"B96ADAB760D7F4681D4F42A3DE394DF4AE56EDE76372BB19" +
"0B07A7C8EE0A6D709E02FCE1CDF7E2ECC03404CD28342F61" +
"9172FE9CE98583FF8E4F1232EEF28183C3FE3B1B4C6FAD73" +
"3BB5FCBC2EC22005C58EF1837D1683B2C6F34A26C1B2EFFA" +
"886B4238611FCFDCDE355B3B6519035BBC34F4DEF99C0238" +
"61B46FC9D6E6C9077AD91D2691F7F7EE598CB0FAC186D91C" +
"AEFE130985139270B4130C93BC437944F4FD4452E2D74DD3" +
"64F2E21E71F54BFF5CAE82AB9C9DF69EE86D2BC522363A0D" +
"ABC521979B0DEADA1DBF9A42D5C4484E0ABCD06BFA53DDEF" +
"3C1B20EE3FD59D7C25E41D2B669E1EF16E6F52C3164DF4FB" +
"7930E9E4E58857B6AC7D5F42D69F6D187763CF1D55034004" +
"87F55BA57E31CC7A7135C886EFB4318AED6A1E012D9E6832" +
"A907600A918130C46DC778F971AD0038092999A333CB8B7A" +
"1A1DB93D7140003C2A4ECEA9F98D0ACC0A8291CDCEC97DCF" +
"8EC9B55A7F88A46B4DB5A851F44182E1C68A007E5E655F6A" +
"FFFFFFFFFFFFFFFF"
finiteFieldPrime4096bytes, _ = hex.DecodeString(finiteFieldPrime4096hex)
finiteFieldPrime4096 = big.NewInt(0).SetBytes(finiteFieldPrime4096bytes)
finiteFieldPrime6144hex = "FFFFFFFFFFFFFFFFADF85458A2BB4A9AAFDC5620273D3CF1" +
"D8B9C583CE2D3695A9E13641146433FBCC939DCE249B3EF9" +
"7D2FE363630C75D8F681B202AEC4617AD3DF1ED5D5FD6561" +
"2433F51F5F066ED0856365553DED1AF3B557135E7F57C935" +
"984F0C70E0E68B77E2A689DAF3EFE8721DF158A136ADE735" +
"30ACCA4F483A797ABC0AB182B324FB61D108A94BB2C8E3FB" +
"B96ADAB760D7F4681D4F42A3DE394DF4AE56EDE76372BB19" +
"0B07A7C8EE0A6D709E02FCE1CDF7E2ECC03404CD28342F61" +
"9172FE9CE98583FF8E4F1232EEF28183C3FE3B1B4C6FAD73" +
"3BB5FCBC2EC22005C58EF1837D1683B2C6F34A26C1B2EFFA" +
"886B4238611FCFDCDE355B3B6519035BBC34F4DEF99C0238" +
"61B46FC9D6E6C9077AD91D2691F7F7EE598CB0FAC186D91C" +
"AEFE130985139270B4130C93BC437944F4FD4452E2D74DD3" +
"64F2E21E71F54BFF5CAE82AB9C9DF69EE86D2BC522363A0D" +
"ABC521979B0DEADA1DBF9A42D5C4484E0ABCD06BFA53DDEF" +
"3C1B20EE3FD59D7C25E41D2B669E1EF16E6F52C3164DF4FB" +
"7930E9E4E58857B6AC7D5F42D69F6D187763CF1D55034004" +
"87F55BA57E31CC7A7135C886EFB4318AED6A1E012D9E6832" +
"A907600A918130C46DC778F971AD0038092999A333CB8B7A" +
"1A1DB93D7140003C2A4ECEA9F98D0ACC0A8291CDCEC97DCF" +
"8EC9B55A7F88A46B4DB5A851F44182E1C68A007E5E0DD902" +
"0BFD64B645036C7A4E677D2C38532A3A23BA4442CAF53EA6" +
"3BB454329B7624C8917BDD64B1C0FD4CB38E8C334C701C3A" +
"CDAD0657FCCFEC719B1F5C3E4E46041F388147FB4CFDB477" +
"A52471F7A9A96910B855322EDB6340D8A00EF092350511E3" +
"0ABEC1FFF9E3A26E7FB29F8C183023C3587E38DA0077D9B4" +
"763E4E4B94B2BBC194C6651E77CAF992EEAAC0232A281BF6" +
"B3A739C1226116820AE8DB5847A67CBEF9C9091B462D538C" +
"D72B03746AE77F5E62292C311562A846505DC82DB854338A" +
"E49F5235C95B91178CCF2DD5CACEF403EC9D1810C6272B04" +
"5B3B71F9DC6B80D63FDD4A8E9ADB1E6962A69526D43161C1" +
"A41D570D7938DAD4A40E329CD0E40E65FFFFFFFFFFFFFFFF"
finiteFieldPrime6144bytes, _ = hex.DecodeString(finiteFieldPrime6144hex)
finiteFieldPrime6144 = big.NewInt(0).SetBytes(finiteFieldPrime6144bytes)
finiteFieldPrime8192hex = "FFFFFFFFFFFFFFFFADF85458A2BB4A9AAFDC5620273D3CF1" +
"D8B9C583CE2D3695A9E13641146433FBCC939DCE249B3EF9" +
"7D2FE363630C75D8F681B202AEC4617AD3DF1ED5D5FD6561" +
"2433F51F5F066ED0856365553DED1AF3B557135E7F57C935" +
"984F0C70E0E68B77E2A689DAF3EFE8721DF158A136ADE735" +
"30ACCA4F483A797ABC0AB182B324FB61D108A94BB2C8E3FB" +
"B96ADAB760D7F4681D4F42A3DE394DF4AE56EDE76372BB19" +
"0B07A7C8EE0A6D709E02FCE1CDF7E2ECC03404CD28342F61" +
"9172FE9CE98583FF8E4F1232EEF28183C3FE3B1B4C6FAD73" +
"3BB5FCBC2EC22005C58EF1837D1683B2C6F34A26C1B2EFFA" +
"886B4238611FCFDCDE355B3B6519035BBC34F4DEF99C0238" +
"61B46FC9D6E6C9077AD91D2691F7F7EE598CB0FAC186D91C" +
"AEFE130985139270B4130C93BC437944F4FD4452E2D74DD3" +
"64F2E21E71F54BFF5CAE82AB9C9DF69EE86D2BC522363A0D" +
"ABC521979B0DEADA1DBF9A42D5C4484E0ABCD06BFA53DDEF" +
"3C1B20EE3FD59D7C25E41D2B669E1EF16E6F52C3164DF4FB" +
"7930E9E4E58857B6AC7D5F42D69F6D187763CF1D55034004" +
"87F55BA57E31CC7A7135C886EFB4318AED6A1E012D9E6832" +
"A907600A918130C46DC778F971AD0038092999A333CB8B7A" +
"1A1DB93D7140003C2A4ECEA9F98D0ACC0A8291CDCEC97DCF" +
"8EC9B55A7F88A46B4DB5A851F44182E1C68A007E5E0DD902" +
"0BFD64B645036C7A4E677D2C38532A3A23BA4442CAF53EA6" +
"3BB454329B7624C8917BDD64B1C0FD4CB38E8C334C701C3A" +
"CDAD0657FCCFEC719B1F5C3E4E46041F388147FB4CFDB477" +
"A52471F7A9A96910B855322EDB6340D8A00EF092350511E3" +
"0ABEC1FFF9E3A26E7FB29F8C183023C3587E38DA0077D9B4" +
"763E4E4B94B2BBC194C6651E77CAF992EEAAC0232A281BF6" +
"B3A739C1226116820AE8DB5847A67CBEF9C9091B462D538C" +
"D72B03746AE77F5E62292C311562A846505DC82DB854338A" +
"E49F5235C95B91178CCF2DD5CACEF403EC9D1810C6272B04" +
"5B3B71F9DC6B80D63FDD4A8E9ADB1E6962A69526D43161C1" +
"A41D570D7938DAD4A40E329CCFF46AAA36AD004CF600C838" +
"1E425A31D951AE64FDB23FCEC9509D43687FEB69EDD1CC5E" +
"0B8CC3BDF64B10EF86B63142A3AB8829555B2F747C932665" +
"CB2C0F1CC01BD70229388839D2AF05E454504AC78B758282" +
"2846C0BA35C35F5C59160CC046FD8251541FC68C9C86B022" +
"BB7099876A460E7451A8A93109703FEE1C217E6C3826E52C" +
"51AA691E0E423CFC99E9E31650C1217B624816CDAD9A95F9" +
"D5B8019488D9C0A0A1FE3075A577E23183F81D4A3F2FA457" +
"1EFC8CE0BA8A4FE8B6855DFE72B0A66EDED2FBABFBE58A30" +
"FAFABE1C5D71A87E2F741EF8C1FE86FEA6BBFDE530677F0D" +
"97D11D49F7A8443D0822E506A9F4614E011E2A94838FF88C" +
"D68C8BB7C5C6424CFFFFFFFFFFFFFFFF"
finiteFieldPrime8192bytes, _ = hex.DecodeString(finiteFieldPrime8192hex)
finiteFieldPrime8192 = big.NewInt(0).SetBytes(finiteFieldPrime8192bytes)
)

View File

@@ -1,98 +0,0 @@
// Read a generic "framed" packet consisting of a header and a
// This is used for both TLS Records and TLS Handshake Messages
package mint
type framing interface {
headerLen() int
defaultReadLen() int
frameLen(hdr []byte) (int, error)
}
const (
kFrameReaderHdr = 0
kFrameReaderBody = 1
)
type frameNextAction func(f *frameReader) error
type frameReader struct {
details framing
state uint8
header []byte
body []byte
working []byte
writeOffset int
remainder []byte
}
func newFrameReader(d framing) *frameReader {
hdr := make([]byte, d.headerLen())
return &frameReader{
d,
kFrameReaderHdr,
hdr,
nil,
hdr,
0,
nil,
}
}
func dup(a []byte) []byte {
r := make([]byte, len(a))
copy(r, a)
return r
}
func (f *frameReader) needed() int {
tmp := (len(f.working) - f.writeOffset) - len(f.remainder)
if tmp < 0 {
return 0
}
return tmp
}
func (f *frameReader) addChunk(in []byte) {
// Append to the buffer.
logf(logTypeFrameReader, "Appending %v", len(in))
f.remainder = append(f.remainder, in...)
}
func (f *frameReader) process() (hdr []byte, body []byte, err error) {
for f.needed() == 0 {
logf(logTypeFrameReader, "%v bytes needed for next block", len(f.working)-f.writeOffset)
// Fill out our working block
copied := copy(f.working[f.writeOffset:], f.remainder)
f.remainder = f.remainder[copied:]
f.writeOffset += copied
if f.writeOffset < len(f.working) {
logf(logTypeVerbose, "Read would have blocked 1")
return nil, nil, AlertWouldBlock
}
// Reset the write offset, because we are now full.
f.writeOffset = 0
// We have read a full frame
if f.state == kFrameReaderBody {
logf(logTypeFrameReader, "Returning frame hdr=%#x len=%d buffered=%d", f.header, len(f.body), len(f.remainder))
f.state = kFrameReaderHdr
f.working = f.header
return dup(f.header), dup(f.body), nil
}
// We have read the header
bodyLen, err := f.details.frameLen(f.header)
if err != nil {
return nil, nil, err
}
logf(logTypeFrameReader, "Processed header, body len = %v", bodyLen)
f.body = make([]byte, bodyLen)
f.working = f.body
f.writeOffset = 0
f.state = kFrameReaderBody
}
logf(logTypeVerbose, "Read would have blocked 2")
return nil, nil, AlertWouldBlock
}

View File

@@ -1,570 +0,0 @@
package mint
import (
"fmt"
"io"
"net"
)
const (
handshakeHeaderLenTLS = 4 // handshake message header length
handshakeHeaderLenDTLS = 12 // handshake message header length
maxHandshakeMessageLen = 1 << 24 // max handshake message length
)
// struct {
// HandshakeType msg_type; /* handshake type */
// uint24 length; /* bytes in message */
// select (HandshakeType) {
// ...
// } body;
// } Handshake;
//
// We do the select{...} part in a different layer, so we treat the
// actual message body as opaque:
//
// struct {
// HandshakeType msg_type;
// opaque msg<0..2^24-1>
// } Handshake;
//
type HandshakeMessage struct {
msgType HandshakeType
seq uint32
body []byte
datagram bool
offset uint32 // Used for DTLS
length uint32
cipher *cipherState
}
// Note: This could be done with the `syntax` module, using the simplified
// syntax as discussed above. However, since this is so simple, there's not
// much benefit to doing so.
// When datagram is set, we marshal this as a whole DTLS record.
func (hm *HandshakeMessage) Marshal() []byte {
if hm == nil {
return []byte{}
}
fragLen := len(hm.body)
var data []byte
if hm.datagram {
data = make([]byte, handshakeHeaderLenDTLS+fragLen)
} else {
data = make([]byte, handshakeHeaderLenTLS+fragLen)
}
tmp := data
tmp = encodeUint(uint64(hm.msgType), 1, tmp)
tmp = encodeUint(uint64(hm.length), 3, tmp)
if hm.datagram {
tmp = encodeUint(uint64(hm.seq), 2, tmp)
tmp = encodeUint(uint64(hm.offset), 3, tmp)
tmp = encodeUint(uint64(fragLen), 3, tmp)
}
copy(tmp, hm.body)
return data
}
func (hm HandshakeMessage) ToBody() (HandshakeMessageBody, error) {
logf(logTypeHandshake, "HandshakeMessage.toBody [%d] [%x]", hm.msgType, hm.body)
var body HandshakeMessageBody
switch hm.msgType {
case HandshakeTypeClientHello:
body = new(ClientHelloBody)
case HandshakeTypeServerHello:
body = new(ServerHelloBody)
case HandshakeTypeEncryptedExtensions:
body = new(EncryptedExtensionsBody)
case HandshakeTypeCertificate:
body = new(CertificateBody)
case HandshakeTypeCertificateRequest:
body = new(CertificateRequestBody)
case HandshakeTypeCertificateVerify:
body = new(CertificateVerifyBody)
case HandshakeTypeFinished:
body = &FinishedBody{VerifyDataLen: len(hm.body)}
case HandshakeTypeNewSessionTicket:
body = new(NewSessionTicketBody)
case HandshakeTypeKeyUpdate:
body = new(KeyUpdateBody)
case HandshakeTypeEndOfEarlyData:
body = new(EndOfEarlyDataBody)
default:
return body, fmt.Errorf("tls.handshakemessage: Unsupported body type")
}
err := safeUnmarshal(body, hm.body)
return body, err
}
func (h *HandshakeLayer) HandshakeMessageFromBody(body HandshakeMessageBody) (*HandshakeMessage, error) {
data, err := body.Marshal()
if err != nil {
return nil, err
}
m := &HandshakeMessage{
msgType: body.Type(),
body: data,
seq: h.msgSeq,
datagram: h.datagram,
length: uint32(len(data)),
}
h.msgSeq++
return m, nil
}
type HandshakeLayer struct {
ctx *HandshakeContext // The handshake we are attached to
nonblocking bool // Should we operate in nonblocking mode
conn RecordLayer // Used for reading/writing records
frame *frameReader // The buffered frame reader
datagram bool // Is this DTLS?
msgSeq uint32 // The DTLS message sequence number
queued []*HandshakeMessage // In/out queue
sent []*HandshakeMessage // Sent messages for DTLS
recvdRecords []uint64 // Records we have received.
maxFragmentLen int
}
type handshakeLayerFrameDetails struct {
datagram bool
}
func (d handshakeLayerFrameDetails) headerLen() int {
if d.datagram {
return handshakeHeaderLenDTLS
}
return handshakeHeaderLenTLS
}
func (d handshakeLayerFrameDetails) defaultReadLen() int {
return d.headerLen() + maxFragmentLen
}
func (d handshakeLayerFrameDetails) frameLen(hdr []byte) (int, error) {
logf(logTypeIO, "Header=%x", hdr)
// The length of this fragment (as opposed to the message)
// is always the last three bytes for both TLS and DTLS
val, _ := decodeUint(hdr[len(hdr)-3:], 3)
return int(val), nil
}
func NewHandshakeLayerTLS(c *HandshakeContext, r RecordLayer) *HandshakeLayer {
h := HandshakeLayer{}
h.ctx = c
h.conn = r
h.datagram = false
h.frame = newFrameReader(&handshakeLayerFrameDetails{false})
h.maxFragmentLen = maxFragmentLen
return &h
}
func NewHandshakeLayerDTLS(c *HandshakeContext, r RecordLayer) *HandshakeLayer {
h := HandshakeLayer{}
h.ctx = c
h.conn = r
h.datagram = true
h.frame = newFrameReader(&handshakeLayerFrameDetails{true})
h.maxFragmentLen = initialMtu // Not quite right
return &h
}
func (h *HandshakeLayer) readRecord() error {
var pt *TLSPlaintext
var err error
if h.datagram {
logf(logTypeVerbose, "Trying to read record")
pt, err = h.conn.(*RecordLayerImpl).ReadRecordAnyEpoch()
} else {
pt, err = h.conn.ReadRecord()
}
if err != nil {
return err
}
switch pt.contentType {
case RecordTypeHandshake, RecordTypeAlert, RecordTypeAck:
default:
return fmt.Errorf("tls.handshakelayer: Unexpected record type %d", pt.contentType)
}
if pt.contentType == RecordTypeAck {
if !h.datagram {
return fmt.Errorf("tls.handshakelayer: can't have ACK with TLS")
}
logf(logTypeIO, "read ACK")
return h.ctx.processAck(pt.fragment)
}
if pt.contentType == RecordTypeAlert {
logf(logTypeIO, "read alert %v", pt.fragment[1])
if len(pt.fragment) < 2 {
h.sendAlert(AlertUnexpectedMessage)
return io.EOF
}
return Alert(pt.fragment[1])
}
assert(h.ctx.hIn.conn != nil)
if pt.epoch != h.ctx.hIn.conn.Epoch() {
// This is out of order but we're dropping it.
// TODO(ekr@rtfm.com): If server, need to retransmit Finished.
if pt.epoch == EpochClear || pt.epoch == EpochHandshakeData {
return nil
}
// Anything else shouldn't happen.
return AlertIllegalParameter
}
h.recvdRecords = append(h.recvdRecords, pt.seq)
h.frame.addChunk(pt.fragment)
return nil
}
// sendAlert sends a TLS alert message.
func (h *HandshakeLayer) sendAlert(err Alert) error {
tmp := make([]byte, 2)
tmp[0] = AlertLevelError
tmp[1] = byte(err)
h.conn.WriteRecord(&TLSPlaintext{
contentType: RecordTypeAlert,
fragment: tmp},
)
// closeNotify is a special case in that it isn't an error:
if err != AlertCloseNotify {
return &net.OpError{Op: "local error", Err: err}
}
return nil
}
func (h *HandshakeLayer) noteMessageDelivered(seq uint32) {
h.msgSeq = seq + 1
var i int
var m *HandshakeMessage
for i, m = range h.queued {
if m.seq > seq {
break
}
}
h.queued = h.queued[i:]
}
func (h *HandshakeLayer) newFragmentReceived(hm *HandshakeMessage) (*HandshakeMessage, error) {
if hm.seq < h.msgSeq {
return nil, nil
}
// TODO(ekr@rtfm.com): Send an ACK immediately if we got something
// out of order.
h.ctx.receivedHandshakeMessage()
if hm.seq == h.msgSeq && hm.offset == 0 && hm.length == uint32(len(hm.body)) {
// TODO(ekr@rtfm.com): Check the length?
// This is complete.
h.noteMessageDelivered(hm.seq)
return hm, nil
}
// Now insert sorted.
var i int
for i = 0; i < len(h.queued); i++ {
f := h.queued[i]
if hm.seq < f.seq {
break
}
if hm.offset < f.offset {
break
}
}
tmp := make([]*HandshakeMessage, 0, len(h.queued)+1)
tmp = append(tmp, h.queued[:i]...)
tmp = append(tmp, hm)
tmp = append(tmp, h.queued[i:]...)
h.queued = tmp
return h.checkMessageAvailable()
}
func (h *HandshakeLayer) checkMessageAvailable() (*HandshakeMessage, error) {
if len(h.queued) == 0 {
return nil, nil
}
hm := h.queued[0]
if hm.seq != h.msgSeq {
return nil, nil
}
if hm.seq == h.msgSeq && hm.offset == 0 && hm.length == uint32(len(hm.body)) {
// TODO(ekr@rtfm.com): Check the length?
// This is complete.
h.noteMessageDelivered(hm.seq)
return hm, nil
}
// OK, this at least might complete the message.
end := uint32(0)
buf := make([]byte, hm.length)
for _, f := range h.queued {
// Out of fragments
if f.seq > hm.seq {
break
}
if f.length != uint32(len(buf)) {
return nil, fmt.Errorf("Mismatched DTLS length")
}
if f.offset > end {
break
}
if f.offset+uint32(len(f.body)) > end {
// OK, this is adding something we don't know about
copy(buf[f.offset:], f.body)
end = f.offset + uint32(len(f.body))
if end == hm.length {
h2 := *hm
h2.offset = 0
h2.body = buf
h.noteMessageDelivered(hm.seq)
return &h2, nil
}
}
}
return nil, nil
}
func (h *HandshakeLayer) ReadMessage() (*HandshakeMessage, error) {
var hdr, body []byte
var err error
hm, err := h.checkMessageAvailable()
if err != nil {
return nil, err
}
if hm != nil {
return hm, nil
}
for {
logf(logTypeVerbose, "ReadMessage() buffered=%v", len(h.frame.remainder))
if h.frame.needed() > 0 {
logf(logTypeVerbose, "Trying to read a new record")
err = h.readRecord()
if err != nil && (h.nonblocking || err != AlertWouldBlock) {
return nil, err
}
}
hdr, body, err = h.frame.process()
if err == nil {
break
}
if err != nil && (h.nonblocking || err != AlertWouldBlock) {
return nil, err
}
}
logf(logTypeHandshake, "read handshake message")
hm = &HandshakeMessage{}
hm.msgType = HandshakeType(hdr[0])
hm.datagram = h.datagram
hm.body = make([]byte, len(body))
copy(hm.body, body)
logf(logTypeHandshake, "Read message with type: %v", hm.msgType)
if h.datagram {
tmp, hdr := decodeUint(hdr[1:], 3)
hm.length = uint32(tmp)
tmp, hdr = decodeUint(hdr, 2)
hm.seq = uint32(tmp)
tmp, hdr = decodeUint(hdr, 3)
hm.offset = uint32(tmp)
return h.newFragmentReceived(hm)
}
hm.length = uint32(len(body))
return hm, nil
}
func (h *HandshakeLayer) QueueMessage(hm *HandshakeMessage) error {
if h.datagram {
hm.cipher = h.conn.(*RecordLayerImpl).cipher
h.queued = append(h.queued, hm)
return nil
}
_, err := h.WriteMessages([]*HandshakeMessage{hm})
return err
}
func (h *HandshakeLayer) SendQueuedMessages() (int, error) {
logf(logTypeHandshake, "Sending outgoing messages")
count, err := h.WriteMessages(h.queued)
if !h.datagram {
h.ClearQueuedMessages()
}
return count, err
}
func (h *HandshakeLayer) ClearQueuedMessages() {
logf(logTypeHandshake, "Clearing outgoing hs message queue")
h.queued = nil
}
func (h *HandshakeLayer) writeFragment(hm *HandshakeMessage, start int, room int) (bool, int, error) {
var buf []byte
// Figure out if we're going to want the full header or just
// the body
hdrlen := 0
if hm.datagram {
hdrlen = handshakeHeaderLenDTLS
} else if start == 0 {
hdrlen = handshakeHeaderLenTLS
}
// Compute the amount of body we can fit in
room -= hdrlen
if room == 0 {
// This works because we are doing one record per
// message
panic("Too short max fragment len")
}
bodylen := len(hm.body) - start
if bodylen > room {
bodylen = room
}
body := hm.body[start : start+bodylen]
// Now see if this chunk has been ACKed. This doesn't produce ideal
// retransmission but is simple.
if h.ctx.fragmentAcked(hm.seq, start, bodylen) {
logf(logTypeHandshake, "Fragment %v %v(%v) already acked. Skipping", hm.seq, start, bodylen)
return false, start + bodylen, nil
}
// Encode the data.
if hdrlen > 0 {
hm2 := *hm
hm2.offset = uint32(start)
hm2.body = body
buf = hm2.Marshal()
hm = &hm2
} else {
buf = body
}
var err error
if h.datagram {
// Remember that we sent this.
h.ctx.sentFragments = append(h.ctx.sentFragments, &SentHandshakeFragment{
hm.seq,
start,
len(body),
h.conn.(*RecordLayerImpl).cipher.combineSeq(true),
false,
})
err = h.conn.(*RecordLayerImpl).writeRecordWithPadding(
&TLSPlaintext{
contentType: RecordTypeHandshake,
fragment: buf,
},
hm.cipher, 0)
} else {
err = h.conn.WriteRecord(
&TLSPlaintext{
contentType: RecordTypeHandshake,
fragment: buf,
})
}
return true, start + bodylen, err
}
func (h *HandshakeLayer) WriteMessage(hm *HandshakeMessage) (int, error) {
start := int(0)
if len(hm.body) > maxHandshakeMessageLen {
return 0, fmt.Errorf("Tried to write a handshake message that's too long")
}
written := 0
wrote := false
// Always make one pass through to allow EOED (which is empty).
for {
var err error
wrote, start, err = h.writeFragment(hm, start, h.maxFragmentLen)
if err != nil {
return 0, err
}
if wrote {
written++
}
if start >= len(hm.body) {
break
}
}
return written, nil
}
func (h *HandshakeLayer) WriteMessages(hms []*HandshakeMessage) (int, error) {
written := 0
for _, hm := range hms {
logf(logTypeHandshake, "WriteMessage [%d] %x", hm.msgType, hm.body)
wrote, err := h.WriteMessage(hm)
if err != nil {
return 0, err
}
written += wrote
}
return written, nil
}
func encodeUint(v uint64, size int, out []byte) []byte {
for i := size - 1; i >= 0; i-- {
out[i] = byte(v & 0xff)
v >>= 8
}
return out[size:]
}
func decodeUint(in []byte, size int) (uint64, []byte) {
val := uint64(0)
for i := 0; i < size; i++ {
val <<= 8
val += uint64(in[i])
}
return val, in[size:]
}
type marshalledPDU interface {
Marshal() ([]byte, error)
Unmarshal(data []byte) (int, error)
}
func safeUnmarshal(pdu marshalledPDU, data []byte) error {
read, err := pdu.Unmarshal(data)
if err != nil {
return err
}
if len(data) != read {
return fmt.Errorf("Invalid encoding: Extra data not consumed")
}
return nil
}

View File

@@ -1,481 +0,0 @@
package mint
import (
"bytes"
"crypto"
"crypto/x509"
"encoding/binary"
"fmt"
"github.com/bifurcation/mint/syntax"
)
type HandshakeMessageBody interface {
Type() HandshakeType
Marshal() ([]byte, error)
Unmarshal(data []byte) (int, error)
}
// struct {
// ProtocolVersion legacy_version = 0x0303; /* TLS v1.2 */
// Random random;
// opaque legacy_session_id<0..32>;
// CipherSuite cipher_suites<2..2^16-2>;
// opaque legacy_compression_methods<1..2^8-1>;
// Extension extensions<0..2^16-1>;
// } ClientHello;
type ClientHelloBody struct {
LegacyVersion uint16
Random [32]byte
LegacySessionID []byte
CipherSuites []CipherSuite
Extensions ExtensionList
}
type clientHelloBodyInnerTLS struct {
LegacyVersion uint16
Random [32]byte
LegacySessionID []byte `tls:"head=1,max=32"`
CipherSuites []CipherSuite `tls:"head=2,min=2"`
LegacyCompressionMethods []byte `tls:"head=1,min=1"`
Extensions []Extension `tls:"head=2"`
}
type clientHelloBodyInnerDTLS struct {
LegacyVersion uint16
Random [32]byte
LegacySessionID []byte `tls:"head=1,max=32"`
EmptyCookie uint8
CipherSuites []CipherSuite `tls:"head=2,min=2"`
LegacyCompressionMethods []byte `tls:"head=1,min=1"`
Extensions []Extension `tls:"head=2"`
}
func (ch ClientHelloBody) Type() HandshakeType {
return HandshakeTypeClientHello
}
func (ch ClientHelloBody) Marshal() ([]byte, error) {
if ch.LegacyVersion == tls12Version {
return syntax.Marshal(clientHelloBodyInnerTLS{
LegacyVersion: ch.LegacyVersion,
Random: ch.Random,
LegacySessionID: []byte{},
CipherSuites: ch.CipherSuites,
LegacyCompressionMethods: []byte{0},
Extensions: ch.Extensions,
})
} else {
return syntax.Marshal(clientHelloBodyInnerDTLS{
LegacyVersion: ch.LegacyVersion,
Random: ch.Random,
LegacySessionID: []byte{},
CipherSuites: ch.CipherSuites,
LegacyCompressionMethods: []byte{0},
Extensions: ch.Extensions,
})
}
}
func (ch *ClientHelloBody) Unmarshal(data []byte) (int, error) {
var read int
var err error
// Note that this might be 0, in which case we do TLS. That
// makes the tests easier.
if ch.LegacyVersion != dtls12WireVersion {
var inner clientHelloBodyInnerTLS
read, err = syntax.Unmarshal(data, &inner)
if err != nil {
return 0, err
}
if len(inner.LegacyCompressionMethods) != 1 || inner.LegacyCompressionMethods[0] != 0 {
return 0, fmt.Errorf("tls.clienthello: Invalid compression method")
}
ch.LegacyVersion = inner.LegacyVersion
ch.Random = inner.Random
ch.LegacySessionID = inner.LegacySessionID
ch.CipherSuites = inner.CipherSuites
ch.Extensions = inner.Extensions
} else {
var inner clientHelloBodyInnerDTLS
read, err = syntax.Unmarshal(data, &inner)
if err != nil {
return 0, err
}
if inner.EmptyCookie != 0 {
return 0, fmt.Errorf("tls.clienthello: Invalid cookie")
}
if len(inner.LegacyCompressionMethods) != 1 || inner.LegacyCompressionMethods[0] != 0 {
return 0, fmt.Errorf("tls.clienthello: Invalid compression method")
}
ch.LegacyVersion = inner.LegacyVersion
ch.Random = inner.Random
ch.LegacySessionID = inner.LegacySessionID
ch.CipherSuites = inner.CipherSuites
ch.Extensions = inner.Extensions
}
return read, nil
}
// TODO: File a spec bug to clarify this
func (ch ClientHelloBody) Truncated() ([]byte, error) {
if len(ch.Extensions) == 0 {
return nil, fmt.Errorf("tls.clienthello.truncate: No extensions")
}
pskExt := ch.Extensions[len(ch.Extensions)-1]
if pskExt.ExtensionType != ExtensionTypePreSharedKey {
return nil, fmt.Errorf("tls.clienthello.truncate: Last extension is not PSK")
}
body, err := ch.Marshal()
if err != nil {
return nil, err
}
chm := &HandshakeMessage{
msgType: ch.Type(),
body: body,
length: uint32(len(body)),
}
chData := chm.Marshal()
psk := PreSharedKeyExtension{
HandshakeType: HandshakeTypeClientHello,
}
_, err = psk.Unmarshal(pskExt.ExtensionData)
if err != nil {
return nil, err
}
// Marshal just the binders so that we know how much to truncate
binders := struct {
Binders []PSKBinderEntry `tls:"head=2,min=33"`
}{Binders: psk.Binders}
binderData, _ := syntax.Marshal(binders)
binderLen := len(binderData)
chLen := len(chData)
return chData[:chLen-binderLen], nil
}
// struct {
// ProtocolVersion legacy_version = 0x0303; /* TLS v1.2 */
// Random random;
// opaque legacy_session_id_echo<0..32>;
// CipherSuite cipher_suite;
// uint8 legacy_compression_method = 0;
// Extension extensions<6..2^16-1>;
// } ServerHello;
type ServerHelloBody struct {
Version uint16
Random [32]byte
LegacySessionID []byte `tls:"head=1,max=32"`
CipherSuite CipherSuite
LegacyCompressionMethod uint8
Extensions ExtensionList `tls:"head=2"`
}
func (sh ServerHelloBody) Type() HandshakeType {
return HandshakeTypeServerHello
}
func (sh ServerHelloBody) Marshal() ([]byte, error) {
return syntax.Marshal(sh)
}
func (sh *ServerHelloBody) Unmarshal(data []byte) (int, error) {
return syntax.Unmarshal(data, sh)
}
// struct {
// opaque verify_data[verify_data_length];
// } Finished;
//
// verifyDataLen is not a field in the TLS struct, but we add it here so
// that calling code can tell us how much data to expect when we marshal /
// unmarshal. (We could add this to the marshal/unmarshal methods, but let's
// try to keep the signature consistent for now.)
//
// For similar reasons, we don't use the `syntax` module here, because this
// struct doesn't map well to standard TLS presentation language concepts.
//
// TODO: File a spec bug
type FinishedBody struct {
VerifyDataLen int
VerifyData []byte
}
func (fin FinishedBody) Type() HandshakeType {
return HandshakeTypeFinished
}
func (fin FinishedBody) Marshal() ([]byte, error) {
if len(fin.VerifyData) != fin.VerifyDataLen {
return nil, fmt.Errorf("tls.finished: data length mismatch")
}
body := make([]byte, len(fin.VerifyData))
copy(body, fin.VerifyData)
return body, nil
}
func (fin *FinishedBody) Unmarshal(data []byte) (int, error) {
if len(data) < fin.VerifyDataLen {
return 0, fmt.Errorf("tls.finished: Malformed finished; too short")
}
fin.VerifyData = make([]byte, fin.VerifyDataLen)
copy(fin.VerifyData, data[:fin.VerifyDataLen])
return fin.VerifyDataLen, nil
}
// struct {
// Extension extensions<0..2^16-1>;
// } EncryptedExtensions;
//
// Marshal() and Unmarshal() are handled by ExtensionList
type EncryptedExtensionsBody struct {
Extensions ExtensionList `tls:"head=2"`
}
func (ee EncryptedExtensionsBody) Type() HandshakeType {
return HandshakeTypeEncryptedExtensions
}
func (ee EncryptedExtensionsBody) Marshal() ([]byte, error) {
return syntax.Marshal(ee)
}
func (ee *EncryptedExtensionsBody) Unmarshal(data []byte) (int, error) {
return syntax.Unmarshal(data, ee)
}
// opaque ASN1Cert<1..2^24-1>;
//
// struct {
// ASN1Cert cert_data;
// Extension extensions<0..2^16-1>
// } CertificateEntry;
//
// struct {
// opaque certificate_request_context<0..2^8-1>;
// CertificateEntry certificate_list<0..2^24-1>;
// } Certificate;
type CertificateEntry struct {
CertData *x509.Certificate
Extensions ExtensionList
}
type CertificateBody struct {
CertificateRequestContext []byte
CertificateList []CertificateEntry
}
type certificateEntryInner struct {
CertData []byte `tls:"head=3,min=1"`
Extensions ExtensionList `tls:"head=2"`
}
type certificateBodyInner struct {
CertificateRequestContext []byte `tls:"head=1"`
CertificateList []certificateEntryInner `tls:"head=3"`
}
func (c CertificateBody) Type() HandshakeType {
return HandshakeTypeCertificate
}
func (c CertificateBody) Marshal() ([]byte, error) {
inner := certificateBodyInner{
CertificateRequestContext: c.CertificateRequestContext,
CertificateList: make([]certificateEntryInner, len(c.CertificateList)),
}
for i, entry := range c.CertificateList {
inner.CertificateList[i] = certificateEntryInner{
CertData: entry.CertData.Raw,
Extensions: entry.Extensions,
}
}
return syntax.Marshal(inner)
}
func (c *CertificateBody) Unmarshal(data []byte) (int, error) {
inner := certificateBodyInner{}
read, err := syntax.Unmarshal(data, &inner)
if err != nil {
return read, err
}
c.CertificateRequestContext = inner.CertificateRequestContext
c.CertificateList = make([]CertificateEntry, len(inner.CertificateList))
for i, entry := range inner.CertificateList {
c.CertificateList[i].CertData, err = x509.ParseCertificate(entry.CertData)
if err != nil {
return 0, fmt.Errorf("tls:certificate: Certificate failed to parse: %v", err)
}
c.CertificateList[i].Extensions = entry.Extensions
}
return read, nil
}
// struct {
// SignatureScheme algorithm;
// opaque signature<0..2^16-1>;
// } CertificateVerify;
type CertificateVerifyBody struct {
Algorithm SignatureScheme
Signature []byte `tls:"head=2"`
}
func (cv CertificateVerifyBody) Type() HandshakeType {
return HandshakeTypeCertificateVerify
}
func (cv CertificateVerifyBody) Marshal() ([]byte, error) {
return syntax.Marshal(cv)
}
func (cv *CertificateVerifyBody) Unmarshal(data []byte) (int, error) {
return syntax.Unmarshal(data, cv)
}
func (cv *CertificateVerifyBody) EncodeSignatureInput(data []byte) []byte {
// TODO: Change context for client auth
// TODO: Put this in a const
const context = "TLS 1.3, server CertificateVerify"
sigInput := bytes.Repeat([]byte{0x20}, 64)
sigInput = append(sigInput, []byte(context)...)
sigInput = append(sigInput, []byte{0}...)
sigInput = append(sigInput, data...)
return sigInput
}
func (cv *CertificateVerifyBody) Sign(privateKey crypto.Signer, handshakeHash []byte) (err error) {
sigInput := cv.EncodeSignatureInput(handshakeHash)
cv.Signature, err = sign(cv.Algorithm, privateKey, sigInput)
logf(logTypeHandshake, "Signed: alg=[%04x] sigInput=[%x], sig=[%x]", cv.Algorithm, sigInput, cv.Signature)
return
}
func (cv *CertificateVerifyBody) Verify(publicKey crypto.PublicKey, handshakeHash []byte) error {
sigInput := cv.EncodeSignatureInput(handshakeHash)
logf(logTypeHandshake, "About to verify: alg=[%04x] sigInput=[%x], sig=[%x]", cv.Algorithm, sigInput, cv.Signature)
return verify(cv.Algorithm, publicKey, sigInput, cv.Signature)
}
// struct {
// opaque certificate_request_context<0..2^8-1>;
// Extension extensions<2..2^16-1>;
// } CertificateRequest;
type CertificateRequestBody struct {
CertificateRequestContext []byte `tls:"head=1"`
Extensions ExtensionList `tls:"head=2"`
}
func (cr CertificateRequestBody) Type() HandshakeType {
return HandshakeTypeCertificateRequest
}
func (cr CertificateRequestBody) Marshal() ([]byte, error) {
return syntax.Marshal(cr)
}
func (cr *CertificateRequestBody) Unmarshal(data []byte) (int, error) {
return syntax.Unmarshal(data, cr)
}
// struct {
// uint32 ticket_lifetime;
// uint32 ticket_age_add;
// opaque ticket_nonce<1..255>;
// opaque ticket<1..2^16-1>;
// Extension extensions<0..2^16-2>;
// } NewSessionTicket;
type NewSessionTicketBody struct {
TicketLifetime uint32
TicketAgeAdd uint32
TicketNonce []byte `tls:"head=1"`
Ticket []byte `tls:"head=2,min=1"`
Extensions ExtensionList `tls:"head=2"`
}
const ticketNonceLen = 16
func NewSessionTicket(ticketLen int, ticketLifetime uint32) (*NewSessionTicketBody, error) {
buf := make([]byte, 4+ticketNonceLen+ticketLen)
_, err := prng.Read(buf)
if err != nil {
return nil, err
}
tkt := &NewSessionTicketBody{
TicketLifetime: ticketLifetime,
TicketAgeAdd: binary.BigEndian.Uint32(buf[:4]),
TicketNonce: buf[4 : 4+ticketNonceLen],
Ticket: buf[4+ticketNonceLen:],
}
return tkt, err
}
func (tkt NewSessionTicketBody) Type() HandshakeType {
return HandshakeTypeNewSessionTicket
}
func (tkt NewSessionTicketBody) Marshal() ([]byte, error) {
return syntax.Marshal(tkt)
}
func (tkt *NewSessionTicketBody) Unmarshal(data []byte) (int, error) {
return syntax.Unmarshal(data, tkt)
}
// enum {
// update_not_requested(0), update_requested(1), (255)
// } KeyUpdateRequest;
//
// struct {
// KeyUpdateRequest request_update;
// } KeyUpdate;
type KeyUpdateBody struct {
KeyUpdateRequest KeyUpdateRequest
}
func (ku KeyUpdateBody) Type() HandshakeType {
return HandshakeTypeKeyUpdate
}
func (ku KeyUpdateBody) Marshal() ([]byte, error) {
return syntax.Marshal(ku)
}
func (ku *KeyUpdateBody) Unmarshal(data []byte) (int, error) {
return syntax.Unmarshal(data, ku)
}
// struct {} EndOfEarlyData;
type EndOfEarlyDataBody struct{}
func (eoed EndOfEarlyDataBody) Type() HandshakeType {
return HandshakeTypeEndOfEarlyData
}
func (eoed EndOfEarlyDataBody) Marshal() ([]byte, error) {
return []byte{}, nil
}
func (eoed *EndOfEarlyDataBody) Unmarshal(data []byte) (int, error) {
return 0, nil
}

View File

@@ -1,55 +0,0 @@
package mint
import (
"fmt"
"log"
"os"
"strings"
)
// We use this environment variable to control logging. It should be a
// comma-separated list of log tags (see below) or "*" to enable all logging.
const logConfigVar = "MINT_LOG"
// Pre-defined log types
const (
logTypeCrypto = "crypto"
logTypeHandshake = "handshake"
logTypeNegotiation = "negotiation"
logTypeIO = "io"
logTypeFrameReader = "frame"
logTypeVerbose = "verbose"
)
var (
logFunction = log.Printf
logAll = false
logSettings = map[string]bool{}
)
func init() {
parseLogEnv(os.Environ())
}
func parseLogEnv(env []string) {
for _, stmt := range env {
if strings.HasPrefix(stmt, logConfigVar+"=") {
val := stmt[len(logConfigVar)+1:]
if val == "*" {
logAll = true
} else {
for _, t := range strings.Split(val, ",") {
logSettings[t] = true
}
}
}
}
}
func logf(tag string, format string, args ...interface{}) {
if logAll || logSettings[tag] {
fullFormat := fmt.Sprintf("[%s] %s", tag, format)
logFunction(fullFormat, args...)
}
}

View File

@@ -1,218 +0,0 @@
package mint
import (
"bytes"
"encoding/hex"
"fmt"
"time"
)
func VersionNegotiation(offered, supported []uint16) (bool, uint16) {
for _, offeredVersion := range offered {
for _, tls13Version := range supported {
logf(logTypeHandshake, "[server] version offered by client [%04x] <> [%04x]", offeredVersion, tls13Version)
if offeredVersion == tls13Version {
// XXX: Should probably be highest supported version, but for now, we
// only support one version, so it doesn't really matter.
return true, offeredVersion
}
}
}
return false, 0
}
func DHNegotiation(keyShares []KeyShareEntry, groups []NamedGroup) (bool, NamedGroup, []byte, []byte) {
for _, share := range keyShares {
for _, group := range groups {
if group != share.Group {
continue
}
pub, priv, err := newKeyShare(share.Group)
if err != nil {
// If we encounter an error, just keep looking
continue
}
dhSecret, err := keyAgreement(share.Group, share.KeyExchange, priv)
if err != nil {
// If we encounter an error, just keep looking
continue
}
return true, group, pub, dhSecret
}
}
return false, 0, nil, nil
}
const (
ticketAgeTolerance uint32 = 5 * 1000 // five seconds in milliseconds
)
func PSKNegotiation(identities []PSKIdentity, binders []PSKBinderEntry, context []byte, psks PreSharedKeyCache) (bool, int, *PreSharedKey, CipherSuiteParams, error) {
logf(logTypeNegotiation, "Negotiating PSK offered=[%d] supported=[%d]", len(identities), psks.Size())
for i, id := range identities {
identityHex := hex.EncodeToString(id.Identity)
psk, ok := psks.Get(identityHex)
if !ok {
logf(logTypeNegotiation, "No PSK for identity %x", identityHex)
continue
}
// For resumption, make sure the ticket age is correct
if psk.IsResumption {
extTicketAge := id.ObfuscatedTicketAge - psk.TicketAgeAdd
knownTicketAge := uint32(time.Since(psk.ReceivedAt) / time.Millisecond)
ticketAgeDelta := knownTicketAge - extTicketAge
if knownTicketAge < extTicketAge {
ticketAgeDelta = extTicketAge - knownTicketAge
}
if ticketAgeDelta > ticketAgeTolerance {
logf(logTypeNegotiation, "WARNING potential replay [%x]", psk.Identity)
logf(logTypeNegotiation, "Ticket age exceeds tolerance |%d - %d| = [%d] > [%d]",
extTicketAge, knownTicketAge, ticketAgeDelta, ticketAgeTolerance)
return false, 0, nil, CipherSuiteParams{}, fmt.Errorf("WARNING Potential replay for identity %x", psk.Identity)
}
}
params, ok := cipherSuiteMap[psk.CipherSuite]
if !ok {
err := fmt.Errorf("tls.cryptoinit: Unsupported ciphersuite from PSK [%04x]", psk.CipherSuite)
return false, 0, nil, CipherSuiteParams{}, err
}
// Compute binder
binderLabel := labelExternalBinder
if psk.IsResumption {
binderLabel = labelResumptionBinder
}
h0 := params.Hash.New().Sum(nil)
zero := bytes.Repeat([]byte{0}, params.Hash.Size())
earlySecret := HkdfExtract(params.Hash, zero, psk.Key)
binderKey := deriveSecret(params, earlySecret, binderLabel, h0)
// context = ClientHello[truncated]
// context = ClientHello1 + HelloRetryRequest + ClientHello2[truncated]
ctxHash := params.Hash.New()
ctxHash.Write(context)
binder := computeFinishedData(params, binderKey, ctxHash.Sum(nil))
if !bytes.Equal(binder, binders[i].Binder) {
logf(logTypeNegotiation, "Binder check failed for identity %x; [%x] != [%x]", psk.Identity, binder, binders[i].Binder)
return false, 0, nil, CipherSuiteParams{}, fmt.Errorf("Binder check failed identity %x", psk.Identity)
}
logf(logTypeNegotiation, "Using PSK with identity %x", psk.Identity)
return true, i, &psk, params, nil
}
logf(logTypeNegotiation, "Failed to find a usable PSK")
return false, 0, nil, CipherSuiteParams{}, nil
}
func PSKModeNegotiation(canDoDH, canDoPSK bool, modes []PSKKeyExchangeMode) (bool, bool) {
logf(logTypeNegotiation, "Negotiating PSK modes [%v] [%v] [%+v]", canDoDH, canDoPSK, modes)
dhAllowed := false
dhRequired := true
for _, mode := range modes {
dhAllowed = dhAllowed || (mode == PSKModeDHEKE)
dhRequired = dhRequired && (mode == PSKModeDHEKE)
}
// Use PSK if we can meet DH requirement and modes were provided
usingPSK := canDoPSK && (!dhRequired || canDoDH) && (len(modes) > 0)
// Use DH if allowed
usingDH := canDoDH && (dhAllowed || !usingPSK)
logf(logTypeNegotiation, "Results of PSK mode negotiation: usingDH=[%v] usingPSK=[%v]", usingDH, usingPSK)
return usingDH, usingPSK
}
func CertificateSelection(serverName *string, signatureSchemes []SignatureScheme, certs []*Certificate) (*Certificate, SignatureScheme, error) {
// Select for server name if provided
candidates := certs
if serverName != nil {
candidatesByName := []*Certificate{}
for _, cert := range certs {
for _, name := range cert.Chain[0].DNSNames {
if len(*serverName) > 0 && name == *serverName {
candidatesByName = append(candidatesByName, cert)
}
}
}
if len(candidatesByName) == 0 {
return nil, 0, fmt.Errorf("No certificates available for server name: %s", *serverName)
}
candidates = candidatesByName
}
// Select for signature scheme
for _, cert := range candidates {
for _, scheme := range signatureSchemes {
if !schemeValidForKey(scheme, cert.PrivateKey) {
continue
}
return cert, scheme, nil
}
}
return nil, 0, fmt.Errorf("No certificates compatible with signature schemes")
}
func EarlyDataNegotiation(usingPSK, gotEarlyData, allowEarlyData bool) (using bool, rejected bool) {
using = gotEarlyData && usingPSK && allowEarlyData
rejected = gotEarlyData && !using
logf(logTypeNegotiation, "Early data negotiation (%v, %v, %v) => %v, %v", usingPSK, gotEarlyData, allowEarlyData, using, rejected)
return
}
func CipherSuiteNegotiation(psk *PreSharedKey, offered, supported []CipherSuite) (CipherSuite, error) {
for _, s1 := range offered {
if psk != nil {
if s1 == psk.CipherSuite {
return s1, nil
}
continue
}
for _, s2 := range supported {
if s1 == s2 {
return s1, nil
}
}
}
return 0, fmt.Errorf("No overlap between offered and supproted ciphersuites (psk? [%v])", psk != nil)
}
func ALPNNegotiation(psk *PreSharedKey, offered, supported []string) (string, error) {
for _, p1 := range offered {
if psk != nil {
if p1 != psk.NextProto {
continue
}
}
for _, p2 := range supported {
if p1 == p2 {
return p1, nil
}
}
}
// If the client offers ALPN on resumption, it must match the earlier one
var err error
if psk != nil && psk.IsResumption && (len(offered) > 0) {
err = fmt.Errorf("ALPN for PSK not provided")
}
return "", err
}

View File

@@ -1,507 +0,0 @@
package mint
import (
"crypto/cipher"
"fmt"
"io"
"sync"
)
const (
sequenceNumberLen = 8 // sequence number length
recordHeaderLenTLS = 5 // record header length (TLS)
recordHeaderLenDTLS = 13 // record header length (DTLS)
maxFragmentLen = 1 << 14 // max number of bytes in a record
)
type DecryptError string
func (err DecryptError) Error() string {
return string(err)
}
type Direction uint8
const (
DirectionWrite = Direction(1)
DirectionRead = Direction(2)
)
// struct {
// ContentType type;
// ProtocolVersion record_version [0301 for CH, 0303 for others]
// uint16 length;
// opaque fragment[TLSPlaintext.length];
// } TLSPlaintext;
type TLSPlaintext struct {
// Omitted: record_version (static)
// Omitted: length (computed from fragment)
contentType RecordType
epoch Epoch
seq uint64
fragment []byte
}
func NewTLSPlaintext(ct RecordType, epoch Epoch, fragment []byte) *TLSPlaintext {
return &TLSPlaintext{
contentType: ct,
epoch: epoch,
fragment: fragment,
}
}
func (t TLSPlaintext) Fragment() []byte {
return t.fragment
}
type cipherState struct {
epoch Epoch // DTLS epoch
ivLength int // Length of the seq and nonce fields
seq uint64 // Zero-padded sequence number
iv []byte // Buffer for the IV
cipher cipher.AEAD // AEAD cipher
}
type RecordLayerFactory interface {
NewLayer(conn io.ReadWriter, dir Direction) RecordLayer
}
type RecordLayer interface {
Lock()
Unlock()
SetVersion(v uint16)
SetLabel(s string)
Rekey(epoch Epoch, factory AeadFactory, keys *KeySet) error
ResetClear(seq uint64)
DiscardReadKey(epoch Epoch)
PeekRecordType(block bool) (RecordType, error)
ReadRecord() (*TLSPlaintext, error)
WriteRecord(pt *TLSPlaintext) error
Epoch() Epoch
}
type RecordLayerImpl struct {
sync.Mutex
label string
direction Direction
version uint16 // The current version number
conn io.ReadWriter // The underlying connection
frame *frameReader // The buffered frame reader
nextData []byte // The next record to send
cachedRecord *TLSPlaintext // Last record read, cached to enable "peek"
cachedError error // Error on the last record read
cipher *cipherState
readCiphers map[Epoch]*cipherState
datagram bool
}
func (r *RecordLayerImpl) Impl() *RecordLayerImpl {
return r
}
type recordLayerFrameDetails struct {
datagram bool
}
func (d recordLayerFrameDetails) headerLen() int {
if d.datagram {
return recordHeaderLenDTLS
}
return recordHeaderLenTLS
}
func (d recordLayerFrameDetails) defaultReadLen() int {
return d.headerLen() + maxFragmentLen
}
func (d recordLayerFrameDetails) frameLen(hdr []byte) (int, error) {
return (int(hdr[d.headerLen()-2]) << 8) | int(hdr[d.headerLen()-1]), nil
}
func newCipherStateNull() *cipherState {
return &cipherState{EpochClear, 0, 0, nil, nil}
}
func newCipherStateAead(epoch Epoch, factory AeadFactory, key []byte, iv []byte) (*cipherState, error) {
cipher, err := factory(key)
if err != nil {
return nil, err
}
return &cipherState{epoch, len(iv), 0, iv, cipher}, nil
}
func NewRecordLayerTLS(conn io.ReadWriter, dir Direction) *RecordLayerImpl {
r := RecordLayerImpl{}
r.label = ""
r.direction = dir
r.conn = conn
r.frame = newFrameReader(recordLayerFrameDetails{false})
r.cipher = newCipherStateNull()
r.version = tls10Version
return &r
}
func NewRecordLayerDTLS(conn io.ReadWriter, dir Direction) *RecordLayerImpl {
r := RecordLayerImpl{}
r.label = ""
r.direction = dir
r.conn = conn
r.frame = newFrameReader(recordLayerFrameDetails{true})
r.cipher = newCipherStateNull()
r.readCiphers = make(map[Epoch]*cipherState, 0)
r.readCiphers[0] = r.cipher
r.datagram = true
return &r
}
func (r *RecordLayerImpl) SetVersion(v uint16) {
r.version = v
}
func (r *RecordLayerImpl) ResetClear(seq uint64) {
r.cipher = newCipherStateNull()
r.cipher.seq = seq
}
func (r *RecordLayerImpl) Epoch() Epoch {
return r.cipher.epoch
}
func (r *RecordLayerImpl) SetLabel(s string) {
r.label = s
}
func (r *RecordLayerImpl) Rekey(epoch Epoch, factory AeadFactory, keys *KeySet) error {
cipher, err := newCipherStateAead(epoch, factory, keys.Key, keys.Iv)
if err != nil {
return err
}
r.cipher = cipher
if r.datagram && r.direction == DirectionRead {
r.readCiphers[epoch] = cipher
}
return nil
}
// TODO(ekr@rtfm.com): This is never used, which is a bug.
func (r *RecordLayerImpl) DiscardReadKey(epoch Epoch) {
if !r.datagram {
return
}
_, ok := r.readCiphers[epoch]
assert(ok)
delete(r.readCiphers, epoch)
}
func (c *cipherState) combineSeq(datagram bool) uint64 {
seq := c.seq
if datagram {
seq |= uint64(c.epoch) << 48
}
return seq
}
func (c *cipherState) computeNonce(seq uint64) []byte {
nonce := make([]byte, len(c.iv))
copy(nonce, c.iv)
s := seq
offset := len(c.iv)
for i := 0; i < 8; i++ {
nonce[(offset-i)-1] ^= byte(s & 0xff)
s >>= 8
}
logf(logTypeCrypto, "Computing nonce for sequence # %x -> %x", seq, nonce)
return nonce
}
func (c *cipherState) incrementSequenceNumber() {
if c.seq >= (1<<48 - 1) {
// Not allowed to let sequence number wrap.
// Instead, must renegotiate before it does.
// Not likely enough to bother. This is the
// DTLS limit.
panic("TLS: sequence number wraparound")
}
c.seq++
}
func (c *cipherState) overhead() int {
if c.cipher == nil {
return 0
}
return c.cipher.Overhead()
}
func (r *RecordLayerImpl) encrypt(cipher *cipherState, seq uint64, header []byte, pt *TLSPlaintext, padLen int) []byte {
assert(r.direction == DirectionWrite)
logf(logTypeIO, "%s Encrypt seq=[%x]", r.label, seq)
// Expand the fragment to hold contentType, padding, and overhead
originalLen := len(pt.fragment)
plaintextLen := originalLen + 1 + padLen
ciphertextLen := plaintextLen + cipher.overhead()
ciphertext := make([]byte, ciphertextLen)
copy(ciphertext, pt.fragment)
ciphertext[originalLen] = byte(pt.contentType)
for i := 1; i <= padLen; i++ {
ciphertext[originalLen+i] = 0
}
// Encrypt the fragment
payload := ciphertext[:plaintextLen]
cipher.cipher.Seal(payload[:0], cipher.computeNonce(seq), payload, header)
return ciphertext
}
func (r *RecordLayerImpl) decrypt(seq uint64, header []byte, pt *TLSPlaintext) (*TLSPlaintext, int, error) {
assert(r.direction == DirectionRead)
logf(logTypeIO, "%s Decrypt seq=[%x]", r.label, seq)
if len(pt.fragment) < r.cipher.overhead() {
msg := fmt.Sprintf("tls.record.decrypt: Record too short [%d] < [%d]", len(pt.fragment), r.cipher.overhead())
return nil, 0, DecryptError(msg)
}
decryptLen := len(pt.fragment) - r.cipher.overhead()
out := &TLSPlaintext{
contentType: pt.contentType,
fragment: make([]byte, decryptLen),
}
// Decrypt
_, err := r.cipher.cipher.Open(out.fragment[:0], r.cipher.computeNonce(seq), pt.fragment, header)
if err != nil {
logf(logTypeIO, "%s AEAD decryption failure [%x]", r.label, pt)
return nil, 0, DecryptError("tls.record.decrypt: AEAD decrypt failed")
}
// Find the padding boundary
padLen := 0
for ; padLen < decryptLen+1 && out.fragment[decryptLen-padLen-1] == 0; padLen++ {
}
// Transfer the content type
newLen := decryptLen - padLen - 1
out.contentType = RecordType(out.fragment[newLen])
// Truncate the message to remove contentType, padding, overhead
out.fragment = out.fragment[:newLen]
out.seq = seq
return out, padLen, nil
}
func (r *RecordLayerImpl) PeekRecordType(block bool) (RecordType, error) {
var pt *TLSPlaintext
var err error
for {
pt, err = r.nextRecord(false)
if err == nil {
break
}
if !block || err != AlertWouldBlock {
return 0, err
}
}
return pt.contentType, nil
}
func (r *RecordLayerImpl) ReadRecord() (*TLSPlaintext, error) {
pt, err := r.nextRecord(false)
// Consume the cached record if there was one
r.cachedRecord = nil
r.cachedError = nil
return pt, err
}
func (r *RecordLayerImpl) ReadRecordAnyEpoch() (*TLSPlaintext, error) {
pt, err := r.nextRecord(true)
// Consume the cached record if there was one
r.cachedRecord = nil
r.cachedError = nil
return pt, err
}
func (r *RecordLayerImpl) nextRecord(allowOldEpoch bool) (*TLSPlaintext, error) {
cipher := r.cipher
if r.cachedRecord != nil {
logf(logTypeIO, "%s Returning cached record", r.label)
return r.cachedRecord, r.cachedError
}
// Loop until one of three things happens:
//
// 1. We get a frame
// 2. We try to read off the socket and get nothing, in which case
// returnAlertWouldBlock
// 3. We get an error.
var err error
err = AlertWouldBlock
var header, body []byte
for err != nil {
if r.frame.needed() > 0 {
buf := make([]byte, r.frame.details.headerLen()+maxFragmentLen)
n, err := r.conn.Read(buf)
if err != nil {
logf(logTypeIO, "%s Error reading, %v", r.label, err)
return nil, err
}
if n == 0 {
return nil, AlertWouldBlock
}
logf(logTypeIO, "%s Read %v bytes", r.label, n)
buf = buf[:n]
r.frame.addChunk(buf)
}
header, body, err = r.frame.process()
// Loop around onAlertWouldBlock to see if some
// data is now available.
if err != nil && err != AlertWouldBlock {
return nil, err
}
}
pt := &TLSPlaintext{}
// Validate content type
switch RecordType(header[0]) {
default:
return nil, fmt.Errorf("tls.record: Unknown content type %02x", header[0])
case RecordTypeAlert, RecordTypeHandshake, RecordTypeApplicationData, RecordTypeAck:
pt.contentType = RecordType(header[0])
}
// Validate version
if !allowWrongVersionNumber && (header[1] != 0x03 || header[2] != 0x01) {
return nil, fmt.Errorf("tls.record: Invalid version %02x%02x", header[1], header[2])
}
// Validate size < max
size := (int(header[len(header)-2]) << 8) + int(header[len(header)-1])
if size > maxFragmentLen+256 {
return nil, fmt.Errorf("tls.record: Ciphertext size too big")
}
pt.fragment = make([]byte, size)
copy(pt.fragment, body)
// TODO(ekr@rtfm.com): Enforce that for epoch > 0, the content type is app data.
// Attempt to decrypt fragment
seq := cipher.seq
if r.datagram {
// TODO(ekr@rtfm.com): Handle duplicates.
seq, _ = decodeUint(header[3:11], 8)
epoch := Epoch(seq >> 48)
// Look up the cipher suite from the epoch
c, ok := r.readCiphers[epoch]
if !ok {
logf(logTypeIO, "%s Message from unknown epoch: [%v]", r.label, epoch)
return nil, AlertWouldBlock
}
if epoch != cipher.epoch {
logf(logTypeIO, "%s Message from non-current epoch: [%v != %v] out-of-epoch reads=%v", r.label, epoch,
cipher.epoch, allowOldEpoch)
if !allowOldEpoch {
return nil, AlertWouldBlock
}
cipher = c
}
}
if cipher.cipher != nil {
logf(logTypeIO, "%s RecordLayer.ReadRecord epoch=[%s] seq=[%x] [%d] ciphertext=[%x]", r.label, cipher.epoch.label(), seq, pt.contentType, pt.fragment)
pt, _, err = r.decrypt(seq, header, pt)
if err != nil {
logf(logTypeIO, "%s Decryption failed", r.label)
return nil, err
}
}
pt.epoch = cipher.epoch
// Check that plaintext length is not too long
if len(pt.fragment) > maxFragmentLen {
return nil, fmt.Errorf("tls.record: Plaintext size too big")
}
logf(logTypeIO, "%s RecordLayer.ReadRecord [%d] [%x]", r.label, pt.contentType, pt.fragment)
r.cachedRecord = pt
cipher.incrementSequenceNumber()
return pt, nil
}
func (r *RecordLayerImpl) WriteRecord(pt *TLSPlaintext) error {
return r.writeRecordWithPadding(pt, r.cipher, 0)
}
func (r *RecordLayerImpl) WriteRecordWithPadding(pt *TLSPlaintext, padLen int) error {
return r.writeRecordWithPadding(pt, r.cipher, padLen)
}
func (r *RecordLayerImpl) writeRecordWithPadding(pt *TLSPlaintext, cipher *cipherState, padLen int) error {
seq := cipher.combineSeq(r.datagram)
length := len(pt.fragment)
var contentType RecordType
if cipher.cipher != nil {
length += 1 + padLen + cipher.cipher.Overhead()
contentType = RecordTypeApplicationData
} else {
contentType = pt.contentType
}
var header []byte
if !r.datagram {
header = []byte{byte(contentType),
byte(r.version >> 8), byte(r.version & 0xff),
byte(length >> 8), byte(length)}
} else {
header = make([]byte, 13)
version := dtlsConvertVersion(r.version)
copy(header, []byte{byte(contentType),
byte(version >> 8), byte(version & 0xff),
})
encodeUint(seq, 8, header[3:])
encodeUint(uint64(length), 2, header[11:])
}
var ciphertext []byte
if cipher.cipher != nil {
logf(logTypeIO, "%s RecordLayer.WriteRecord epoch=[%s] seq=[%x] [%d] plaintext=[%x]", r.label, cipher.epoch.label(), cipher.seq, pt.contentType, pt.fragment)
ciphertext = r.encrypt(cipher, seq, header, pt, padLen)
} else {
if padLen > 0 {
return fmt.Errorf("tls.record: Padding can only be done on encrypted records")
}
ciphertext = pt.fragment
}
if len(ciphertext) > maxFragmentLen {
return fmt.Errorf("tls.record: Record size too big")
}
record := append(header, ciphertext...)
logf(logTypeIO, "%s RecordLayer.WriteRecord epoch=[%s] seq=[%x] [%d] ciphertext=[%x]", r.label, cipher.epoch.label(), cipher.seq, contentType, ciphertext)
cipher.incrementSequenceNumber()
_, err := r.conn.Write(record)
return err
}

File diff suppressed because it is too large Load Diff

View File

@@ -1,247 +0,0 @@
package mint
import (
"crypto/x509"
"time"
)
// Marker interface for actions that an implementation should take based on
// state transitions.
type HandshakeAction interface{}
type QueueHandshakeMessage struct {
Message *HandshakeMessage
}
type SendQueuedHandshake struct{}
type SendEarlyData struct{}
type RekeyIn struct {
epoch Epoch
KeySet KeySet
}
type RekeyOut struct {
epoch Epoch
KeySet KeySet
}
type ResetOut struct {
seq uint64
}
type StorePSK struct {
PSK PreSharedKey
}
type HandshakeState interface {
Next(handshakeMessageReader) (HandshakeState, []HandshakeAction, Alert)
State() State
}
type AppExtensionHandler interface {
Send(hs HandshakeType, el *ExtensionList) error
Receive(hs HandshakeType, el *ExtensionList) error
}
// ConnectionOptions objects represent per-connection settings for a client
// initiating a connection
type ConnectionOptions struct {
ServerName string
NextProtos []string
}
// ConnectionParameters objects represent the parameters negotiated for a
// connection.
type ConnectionParameters struct {
UsingPSK bool
UsingDH bool
ClientSendingEarlyData bool
UsingEarlyData bool
RejectedEarlyData bool
UsingClientAuth bool
CipherSuite CipherSuite
ServerName string
NextProto string
}
// Working state for the handshake.
type HandshakeContext struct {
timeoutMS uint32
timers *timerSet
recvdRecords []uint64
sentFragments []*SentHandshakeFragment
hIn, hOut *HandshakeLayer
waitingNextFlight bool
earlyData []byte
}
func (hc *HandshakeContext) SetVersion(version uint16) {
if hc.hIn.conn != nil {
hc.hIn.conn.SetVersion(version)
}
if hc.hOut.conn != nil {
hc.hOut.conn.SetVersion(version)
}
}
// stateConnected is symmetric between client and server
type stateConnected struct {
Params ConnectionParameters
hsCtx *HandshakeContext
isClient bool
cryptoParams CipherSuiteParams
resumptionSecret []byte
clientTrafficSecret []byte
serverTrafficSecret []byte
exporterSecret []byte
peerCertificates []*x509.Certificate
verifiedChains [][]*x509.Certificate
}
var _ HandshakeState = &stateConnected{}
func (state stateConnected) State() State {
if state.isClient {
return StateClientConnected
}
return StateServerConnected
}
func (state *stateConnected) KeyUpdate(request KeyUpdateRequest) ([]HandshakeAction, Alert) {
var trafficKeys KeySet
if state.isClient {
state.clientTrafficSecret = HkdfExpandLabel(state.cryptoParams.Hash, state.clientTrafficSecret,
labelClientApplicationTrafficSecret, []byte{}, state.cryptoParams.Hash.Size())
trafficKeys = makeTrafficKeys(state.cryptoParams, state.clientTrafficSecret)
} else {
state.serverTrafficSecret = HkdfExpandLabel(state.cryptoParams.Hash, state.serverTrafficSecret,
labelServerApplicationTrafficSecret, []byte{}, state.cryptoParams.Hash.Size())
trafficKeys = makeTrafficKeys(state.cryptoParams, state.serverTrafficSecret)
}
kum, err := state.hsCtx.hOut.HandshakeMessageFromBody(&KeyUpdateBody{KeyUpdateRequest: request})
if err != nil {
logf(logTypeHandshake, "[StateConnected] Error marshaling key update message: %v", err)
return nil, AlertInternalError
}
toSend := []HandshakeAction{
QueueHandshakeMessage{kum},
SendQueuedHandshake{},
RekeyOut{epoch: EpochUpdate, KeySet: trafficKeys},
}
return toSend, AlertNoAlert
}
func (state *stateConnected) NewSessionTicket(length int, lifetime, earlyDataLifetime uint32) ([]HandshakeAction, Alert) {
tkt, err := NewSessionTicket(length, lifetime)
if err != nil {
logf(logTypeHandshake, "[StateConnected] Error generating NewSessionTicket: %v", err)
return nil, AlertInternalError
}
err = tkt.Extensions.Add(&TicketEarlyDataInfoExtension{earlyDataLifetime})
if err != nil {
logf(logTypeHandshake, "[StateConnected] Error adding extension to NewSessionTicket: %v", err)
return nil, AlertInternalError
}
resumptionKey := HkdfExpandLabel(state.cryptoParams.Hash, state.resumptionSecret,
labelResumption, tkt.TicketNonce, state.cryptoParams.Hash.Size())
newPSK := PreSharedKey{
CipherSuite: state.cryptoParams.Suite,
IsResumption: true,
Identity: tkt.Ticket,
Key: resumptionKey,
NextProto: state.Params.NextProto,
ReceivedAt: time.Now(),
ExpiresAt: time.Now().Add(time.Duration(tkt.TicketLifetime) * time.Second),
TicketAgeAdd: tkt.TicketAgeAdd,
}
tktm, err := state.hsCtx.hOut.HandshakeMessageFromBody(tkt)
if err != nil {
logf(logTypeHandshake, "[StateConnected] Error marshaling NewSessionTicket: %v", err)
return nil, AlertInternalError
}
toSend := []HandshakeAction{
StorePSK{newPSK},
QueueHandshakeMessage{tktm},
SendQueuedHandshake{},
}
return toSend, AlertNoAlert
}
// Next does nothing for this state.
func (state stateConnected) Next(hr handshakeMessageReader) (HandshakeState, []HandshakeAction, Alert) {
return state, nil, AlertNoAlert
}
func (state stateConnected) ProcessMessage(hm *HandshakeMessage) (HandshakeState, []HandshakeAction, Alert) {
if hm == nil {
logf(logTypeHandshake, "[StateConnected] Unexpected message")
return nil, nil, AlertUnexpectedMessage
}
bodyGeneric, err := hm.ToBody()
if err != nil {
logf(logTypeHandshake, "[StateConnected] Error decoding message: %v", err)
return nil, nil, AlertDecodeError
}
switch body := bodyGeneric.(type) {
case *KeyUpdateBody:
var trafficKeys KeySet
if !state.isClient {
state.clientTrafficSecret = HkdfExpandLabel(state.cryptoParams.Hash, state.clientTrafficSecret,
labelClientApplicationTrafficSecret, []byte{}, state.cryptoParams.Hash.Size())
trafficKeys = makeTrafficKeys(state.cryptoParams, state.clientTrafficSecret)
} else {
state.serverTrafficSecret = HkdfExpandLabel(state.cryptoParams.Hash, state.serverTrafficSecret,
labelServerApplicationTrafficSecret, []byte{}, state.cryptoParams.Hash.Size())
trafficKeys = makeTrafficKeys(state.cryptoParams, state.serverTrafficSecret)
}
toSend := []HandshakeAction{RekeyIn{epoch: EpochUpdate, KeySet: trafficKeys}}
// If requested, roll outbound keys and send a KeyUpdate
if body.KeyUpdateRequest == KeyUpdateRequested {
logf(logTypeHandshake, "Received key update, update requested", body.KeyUpdateRequest)
moreToSend, alert := state.KeyUpdate(KeyUpdateNotRequested)
if alert != AlertNoAlert {
return nil, nil, alert
}
toSend = append(toSend, moreToSend...)
}
return state, toSend, AlertNoAlert
case *NewSessionTicketBody:
// XXX: Allow NewSessionTicket in both directions?
if !state.isClient {
return nil, nil, AlertUnexpectedMessage
}
resumptionKey := HkdfExpandLabel(state.cryptoParams.Hash, state.resumptionSecret,
labelResumption, body.TicketNonce, state.cryptoParams.Hash.Size())
psk := PreSharedKey{
CipherSuite: state.cryptoParams.Suite,
IsResumption: true,
Identity: body.Ticket,
Key: resumptionKey,
NextProto: state.Params.NextProto,
ReceivedAt: time.Now(),
ExpiresAt: time.Now().Add(time.Duration(body.TicketLifetime) * time.Second),
TicketAgeAdd: body.TicketAgeAdd,
}
toSend := []HandshakeAction{StorePSK{psk}}
return state, toSend, AlertNoAlert
}
logf(logTypeHandshake, "[StateConnected] Unexpected message type %v", hm.msgType)
return nil, nil, AlertUnexpectedMessage
}

View File

@@ -1,344 +0,0 @@
package syntax
import (
"bytes"
"fmt"
"reflect"
"runtime"
)
func Unmarshal(data []byte, v interface{}) (int, error) {
// Check for well-formedness.
// Avoids filling out half a data structure
// before discovering a JSON syntax error.
d := decodeState{}
d.Write(data)
return d.unmarshal(v)
}
// Unmarshaler is the interface implemented by types that can
// unmarshal a TLS description of themselves. Note that unlike the
// JSON unmarshaler interface, it is not known a priori how much of
// the input data will be consumed. So the Unmarshaler must state
// how much of the input data it consumed.
type Unmarshaler interface {
UnmarshalTLS([]byte) (int, error)
}
// These are the options that can be specified in the struct tag. Right now,
// all of them apply to variable-length vectors and nothing else
type decOpts struct {
head uint // length of length in bytes
min uint // minimum size in bytes
max uint // maximum size in bytes
varint bool // whether to decode as a varint
}
type decodeState struct {
bytes.Buffer
}
func (d *decodeState) unmarshal(v interface{}) (read int, err error) {
defer func() {
if r := recover(); r != nil {
if _, ok := r.(runtime.Error); ok {
panic(r)
}
if s, ok := r.(string); ok {
panic(s)
}
err = r.(error)
}
}()
rv := reflect.ValueOf(v)
if rv.Kind() != reflect.Ptr || rv.IsNil() {
return 0, fmt.Errorf("Invalid unmarshal target (non-pointer or nil)")
}
read = d.value(rv)
return read, nil
}
func (e *decodeState) value(v reflect.Value) int {
return valueDecoder(v)(e, v, decOpts{})
}
type decoderFunc func(e *decodeState, v reflect.Value, opts decOpts) int
func valueDecoder(v reflect.Value) decoderFunc {
return typeDecoder(v.Type().Elem())
}
func typeDecoder(t reflect.Type) decoderFunc {
// Note: Omits the caching / wait-group things that encoding/json uses
return newTypeDecoder(t)
}
var (
unmarshalerType = reflect.TypeOf(new(Unmarshaler)).Elem()
)
func newTypeDecoder(t reflect.Type) decoderFunc {
if t.Kind() != reflect.Ptr && reflect.PtrTo(t).Implements(unmarshalerType) {
return unmarshalerDecoder
}
switch t.Kind() {
case reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
return uintDecoder
case reflect.Array:
return newArrayDecoder(t)
case reflect.Slice:
return newSliceDecoder(t)
case reflect.Struct:
return newStructDecoder(t)
case reflect.Ptr:
return newPointerDecoder(t)
default:
panic(fmt.Errorf("Unsupported type (%s)", t))
}
}
///// Specific decoders below
func unmarshalerDecoder(d *decodeState, v reflect.Value, opts decOpts) int {
um, ok := v.Interface().(Unmarshaler)
if !ok {
panic(fmt.Errorf("Non-Unmarshaler passed to unmarshalerEncoder"))
}
read, err := um.UnmarshalTLS(d.Bytes())
if err != nil {
panic(err)
}
if read > d.Len() {
panic(fmt.Errorf("Invalid return value from UnmarshalTLS"))
}
d.Next(read)
return read
}
//////////
func uintDecoder(d *decodeState, v reflect.Value, opts decOpts) int {
if opts.varint {
return varintDecoder(d, v, opts)
}
uintLen := int(v.Elem().Type().Size())
buf := d.Next(uintLen)
if len(buf) != uintLen {
panic(fmt.Errorf("Insufficient data to read uint"))
}
return setUintFromBuffer(v, buf)
}
func varintDecoder(d *decodeState, v reflect.Value, opts decOpts) int {
l, val := readVarint(d)
uintLen := int(v.Elem().Type().Size())
if uintLen < l {
panic(fmt.Errorf("Uint too small to fit varint: %d < %d", uintLen, l))
}
v.Elem().SetUint(val)
return l
}
func readVarint(d *decodeState) (int, uint64) {
// Read the first octet and decide the size of the presented varint
first := d.Next(1)
if len(first) != 1 {
panic(fmt.Errorf("Insufficient data to read varint length"))
}
twoBits := uint(first[0] >> 6)
varintLen := 1 << twoBits
rest := d.Next(varintLen - 1)
if len(rest) != varintLen-1 {
panic(fmt.Errorf("Insufficient data to read varint"))
}
buf := append(first, rest...)
buf[0] &= 0x3f
return len(buf), decodeUintFromBuffer(buf)
}
func decodeUintFromBuffer(buf []byte) uint64 {
val := uint64(0)
for _, b := range buf {
val = (val << 8) + uint64(b)
}
return val
}
func setUintFromBuffer(v reflect.Value, buf []byte) int {
v.Elem().SetUint(decodeUintFromBuffer(buf))
return len(buf)
}
//////////
type arrayDecoder struct {
elemDec decoderFunc
}
func (ad *arrayDecoder) decode(d *decodeState, v reflect.Value, opts decOpts) int {
n := v.Elem().Type().Len()
read := 0
for i := 0; i < n; i += 1 {
read += ad.elemDec(d, v.Elem().Index(i).Addr(), opts)
}
return read
}
func newArrayDecoder(t reflect.Type) decoderFunc {
dec := &arrayDecoder{typeDecoder(t.Elem())}
return dec.decode
}
//////////
type sliceDecoder struct {
elementType reflect.Type
elementDec decoderFunc
}
func (sd *sliceDecoder) decode(d *decodeState, v reflect.Value, opts decOpts) int {
var length uint64
var read int
var data []byte
if opts.head == 0 {
panic(fmt.Errorf("Cannot decode a slice without a header length"))
}
// If the caller indicated there is no header, then read everything from the buffer
if opts.head == headValueNoHead {
for {
chunk := d.Next(1024)
data = append(data, chunk...)
if len(chunk) != 1024 {
break
}
}
length = uint64(len(data))
if opts.max > 0 && length > uint64(opts.max) {
panic(fmt.Errorf("Length of vector exceeds declared max"))
}
if length < uint64(opts.min) {
panic(fmt.Errorf("Length of vector below declared min"))
}
} else {
if opts.head != headValueVarint {
lengthBytes := d.Next(int(opts.head))
if len(lengthBytes) != int(opts.head) {
panic(fmt.Errorf("Not enough data to read header"))
}
read = len(lengthBytes)
length = decodeUintFromBuffer(lengthBytes)
} else {
read, length = readVarint(d)
}
if opts.max > 0 && length > uint64(opts.max) {
panic(fmt.Errorf("Length of vector exceeds declared max"))
}
if length < uint64(opts.min) {
panic(fmt.Errorf("Length of vector below declared min"))
}
data = d.Next(int(length))
if len(data) != int(length) {
panic(fmt.Errorf("Available data less than declared length [%d < %d]", len(data), length))
}
}
elemBuf := &decodeState{}
elemBuf.Write(data)
elems := []reflect.Value{}
for elemBuf.Len() > 0 {
elem := reflect.New(sd.elementType)
read += sd.elementDec(elemBuf, elem, opts)
elems = append(elems, elem)
}
v.Elem().Set(reflect.MakeSlice(v.Elem().Type(), len(elems), len(elems)))
for i := 0; i < len(elems); i += 1 {
v.Elem().Index(i).Set(elems[i].Elem())
}
return read
}
func newSliceDecoder(t reflect.Type) decoderFunc {
dec := &sliceDecoder{
elementType: t.Elem(),
elementDec: typeDecoder(t.Elem()),
}
return dec.decode
}
//////////
type structDecoder struct {
fieldOpts []decOpts
fieldDecs []decoderFunc
}
func (sd *structDecoder) decode(d *decodeState, v reflect.Value, opts decOpts) int {
read := 0
for i := range sd.fieldDecs {
read += sd.fieldDecs[i](d, v.Elem().Field(i).Addr(), sd.fieldOpts[i])
}
return read
}
func newStructDecoder(t reflect.Type) decoderFunc {
n := t.NumField()
sd := structDecoder{
fieldOpts: make([]decOpts, n),
fieldDecs: make([]decoderFunc, n),
}
for i := 0; i < n; i += 1 {
f := t.Field(i)
tag := f.Tag.Get("tls")
tagOpts := parseTag(tag)
sd.fieldOpts[i] = decOpts{
head: tagOpts["head"],
max: tagOpts["max"],
min: tagOpts["min"],
varint: tagOpts[varintOption] > 0,
}
sd.fieldDecs[i] = typeDecoder(f.Type)
}
return sd.decode
}
//////////
type pointerDecoder struct {
base decoderFunc
}
func (pd *pointerDecoder) decode(d *decodeState, v reflect.Value, opts decOpts) int {
v.Elem().Set(reflect.New(v.Elem().Type().Elem()))
return pd.base(d, v.Elem(), opts)
}
func newPointerDecoder(t reflect.Type) decoderFunc {
baseDecoder := typeDecoder(t.Elem())
pd := pointerDecoder{base: baseDecoder}
return pd.decode
}

View File

@@ -1,276 +0,0 @@
package syntax
import (
"bytes"
"fmt"
"reflect"
"runtime"
)
func Marshal(v interface{}) ([]byte, error) {
e := &encodeState{}
err := e.marshal(v, encOpts{})
if err != nil {
return nil, err
}
return e.Bytes(), nil
}
// Marshaler is the interface implemented by types that
// have a defined TLS encoding.
type Marshaler interface {
MarshalTLS() ([]byte, error)
}
// These are the options that can be specified in the struct tag. Right now,
// all of them apply to variable-length vectors and nothing else
type encOpts struct {
head uint // length of length in bytes
min uint // minimum size in bytes
max uint // maximum size in bytes
varint bool // whether to encode as a varint
}
type encodeState struct {
bytes.Buffer
}
func (e *encodeState) marshal(v interface{}, opts encOpts) (err error) {
defer func() {
if r := recover(); r != nil {
if _, ok := r.(runtime.Error); ok {
panic(r)
}
if s, ok := r.(string); ok {
panic(s)
}
err = r.(error)
}
}()
e.reflectValue(reflect.ValueOf(v), opts)
return nil
}
func (e *encodeState) reflectValue(v reflect.Value, opts encOpts) {
valueEncoder(v)(e, v, opts)
}
type encoderFunc func(e *encodeState, v reflect.Value, opts encOpts)
func valueEncoder(v reflect.Value) encoderFunc {
if !v.IsValid() {
panic(fmt.Errorf("Cannot encode an invalid value"))
}
return typeEncoder(v.Type())
}
func typeEncoder(t reflect.Type) encoderFunc {
// Note: Omits the caching / wait-group things that encoding/json uses
return newTypeEncoder(t)
}
var (
marshalerType = reflect.TypeOf(new(Marshaler)).Elem()
)
func newTypeEncoder(t reflect.Type) encoderFunc {
if t.Implements(marshalerType) {
return marshalerEncoder
}
switch t.Kind() {
case reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
return uintEncoder
case reflect.Array:
return newArrayEncoder(t)
case reflect.Slice:
return newSliceEncoder(t)
case reflect.Struct:
return newStructEncoder(t)
case reflect.Ptr:
return newPointerEncoder(t)
default:
panic(fmt.Errorf("Unsupported type (%s)", t))
}
}
///// Specific encoders below
func marshalerEncoder(e *encodeState, v reflect.Value, opts encOpts) {
if v.Kind() == reflect.Ptr && v.IsNil() {
panic(fmt.Errorf("Cannot encode nil pointer"))
}
m, ok := v.Interface().(Marshaler)
if !ok {
panic(fmt.Errorf("Non-Marshaler passed to marshalerEncoder"))
}
b, err := m.MarshalTLS()
if err == nil {
_, err = e.Write(b)
}
if err != nil {
panic(err)
}
}
//////////
func uintEncoder(e *encodeState, v reflect.Value, opts encOpts) {
if opts.varint {
varintEncoder(e, v, opts)
return
}
writeUint(e, v.Uint(), int(v.Type().Size()))
}
func varintEncoder(e *encodeState, v reflect.Value, opts encOpts) {
writeVarint(e, v.Uint())
}
func writeVarint(e *encodeState, u uint64) {
if (u >> 62) > 0 {
panic(fmt.Errorf("uint value is too big for varint"))
}
var varintLen int
for _, len := range []uint{1, 2, 4, 8} {
if u < (uint64(1) << (8*len - 2)) {
varintLen = int(len)
break
}
}
twoBits := map[int]uint64{1: 0x00, 2: 0x01, 4: 0x02, 8: 0x03}[varintLen]
shift := uint(8*varintLen - 2)
writeUint(e, u|(twoBits<<shift), varintLen)
}
func writeUint(e *encodeState, u uint64, len int) {
data := make([]byte, len)
for i := 0; i < len; i += 1 {
data[i] = byte(u >> uint(8*(len-i-1)))
}
e.Write(data)
}
//////////
type arrayEncoder struct {
elemEnc encoderFunc
}
func (ae *arrayEncoder) encode(e *encodeState, v reflect.Value, opts encOpts) {
n := v.Len()
for i := 0; i < n; i += 1 {
ae.elemEnc(e, v.Index(i), opts)
}
}
func newArrayEncoder(t reflect.Type) encoderFunc {
enc := &arrayEncoder{typeEncoder(t.Elem())}
return enc.encode
}
//////////
type sliceEncoder struct {
ae *arrayEncoder
}
func (se *sliceEncoder) encode(e *encodeState, v reflect.Value, opts encOpts) {
arrayState := &encodeState{}
se.ae.encode(arrayState, v, opts)
n := uint(arrayState.Len())
if opts.head == 0 {
panic(fmt.Errorf("Cannot encode a slice without a header length"))
}
if opts.max > 0 && n > opts.max {
panic(fmt.Errorf("Encoded length more than max [%d > %d]", n, opts.max))
}
if n < opts.min {
panic(fmt.Errorf("Encoded length less than min [%d < %d]", n, opts.min))
}
switch opts.head {
case headValueNoHead:
// None.
case headValueVarint:
writeVarint(e, uint64(n))
default:
if n>>(8*opts.head) > 0 {
panic(fmt.Errorf("Encoded length too long for header length [%d, %d]", n, opts.head))
}
writeUint(e, uint64(n), int(opts.head))
}
e.Write(arrayState.Bytes())
}
func newSliceEncoder(t reflect.Type) encoderFunc {
enc := &sliceEncoder{&arrayEncoder{typeEncoder(t.Elem())}}
return enc.encode
}
//////////
type structEncoder struct {
fieldOpts []encOpts
fieldEncs []encoderFunc
}
func (se *structEncoder) encode(e *encodeState, v reflect.Value, opts encOpts) {
for i := range se.fieldEncs {
se.fieldEncs[i](e, v.Field(i), se.fieldOpts[i])
}
}
func newStructEncoder(t reflect.Type) encoderFunc {
n := t.NumField()
se := structEncoder{
fieldOpts: make([]encOpts, n),
fieldEncs: make([]encoderFunc, n),
}
for i := 0; i < n; i += 1 {
f := t.Field(i)
tag := f.Tag.Get("tls")
tagOpts := parseTag(tag)
se.fieldOpts[i] = encOpts{
head: tagOpts["head"],
max: tagOpts["max"],
min: tagOpts["min"],
varint: tagOpts[varintOption] > 0,
}
se.fieldEncs[i] = typeEncoder(f.Type)
}
return se.encode
}
//////////
type pointerEncoder struct {
base encoderFunc
}
func (pe pointerEncoder) encode(e *encodeState, v reflect.Value, opts encOpts) {
if v.IsNil() {
panic(fmt.Errorf("Cannot marshal a struct containing a nil pointer"))
}
pe.base(e, v.Elem(), opts)
}
func newPointerEncoder(t reflect.Type) encoderFunc {
baseEncoder := typeEncoder(t.Elem())
pe := pointerEncoder{base: baseEncoder}
return pe.encode
}

View File

@@ -1,50 +0,0 @@
package syntax
import (
"strconv"
"strings"
)
// `tls:"head=2,min=2,max=255,varint"`
type tagOptions map[string]uint
var (
varintOption = "varint"
headOptionNone = "none"
headOptionVarint = "varint"
headValueNoHead = uint(255)
headValueVarint = uint(254)
)
// parseTag parses a struct field's "tls" tag as a comma-separated list of
// name=value pairs, where the values MUST be unsigned integers, or in
// the special case of head, "none" or "varint"
func parseTag(tag string) tagOptions {
opts := tagOptions{}
for _, token := range strings.Split(tag, ",") {
if token == varintOption {
opts[varintOption] = 1
continue
}
parts := strings.Split(token, "=")
if len(parts[0]) == 0 {
continue
}
if len(parts) == 1 {
continue
}
if parts[0] == "head" && parts[1] == headOptionNone {
opts[parts[0]] = headValueNoHead
} else if parts[0] == "head" && parts[1] == headOptionVarint {
opts[parts[0]] = headValueVarint
} else if val, err := strconv.Atoi(parts[1]); err == nil && val >= 0 {
opts[parts[0]] = uint(val)
}
}
return opts
}

View File

@@ -1,122 +0,0 @@
package mint
import (
"time"
)
// This is a simple timer implementation. Timers are stored in a sorted
// list.
// TODO(ekr@rtfm.com): Add a way to uncouple these from the system
// clock.
type timerCb func() error
type timer struct {
label string
cb timerCb
deadline time.Time
duration uint32
}
type timerSet struct {
ts []*timer
}
func newTimerSet() *timerSet {
return &timerSet{}
}
func (ts *timerSet) start(label string, cb timerCb, delayMs uint32) *timer {
now := time.Now()
t := timer{
label,
cb,
now.Add(time.Millisecond * time.Duration(delayMs)),
delayMs,
}
logf(logTypeHandshake, "Timer %s set [%v -> %v]", t.label, now, t.deadline)
var i int
ntimers := len(ts.ts)
for i = 0; i < ntimers; i++ {
if t.deadline.Before(ts.ts[i].deadline) {
break
}
}
tmp := make([]*timer, 0, ntimers+1)
tmp = append(tmp, ts.ts[:i]...)
tmp = append(tmp, &t)
tmp = append(tmp, ts.ts[i:]...)
ts.ts = tmp
return &t
}
// TODO(ekr@rtfm.com): optimize this now that the list is sorted.
// We should be able to do just one list manipulation, as long
// as we're careful about how we handle inserts during callbacks.
func (ts *timerSet) check(now time.Time) error {
for i, t := range ts.ts {
if now.After(t.deadline) {
ts.ts = append(ts.ts[:i], ts.ts[:i+1]...)
if t.cb != nil {
logf(logTypeHandshake, "Timer %s expired [%v > %v]", t.label, now, t.deadline)
cb := t.cb
t.cb = nil
err := cb()
if err != nil {
return err
}
}
} else {
break
}
}
return nil
}
// Returns the next time any of the timers would fire.
func (ts *timerSet) remaining() (bool, time.Duration) {
for _, t := range ts.ts {
if t.cb != nil {
return true, time.Until(t.deadline)
}
}
return false, time.Duration(0)
}
func (ts *timerSet) cancel(label string) {
for _, t := range ts.ts {
if t.label == label {
t.cancel()
}
}
}
func (ts *timerSet) getTimer(label string) *timer {
for _, t := range ts.ts {
if t.label == label && t.cb != nil {
return t
}
}
return nil
}
func (ts *timerSet) getAllTimers() []string {
var ret []string
for _, t := range ts.ts {
if t.cb != nil {
ret = append(ret, t.label)
}
}
return ret
}
func (t *timer) cancel() {
logf(logTypeHandshake, "Timer %s cancelled", t.label)
t.cb = nil
t.label = ""
}

View File

@@ -1,179 +0,0 @@
package mint
// XXX(rlb): This file is borrowed pretty much wholesale from crypto/tls
import (
"errors"
"net"
"strings"
"time"
)
// Server returns a new TLS server side connection
// using conn as the underlying transport.
// The configuration config must be non-nil and must include
// at least one certificate or else set GetCertificate.
func Server(conn net.Conn, config *Config) *Conn {
return NewConn(conn, config, false)
}
// Client returns a new TLS client side connection
// using conn as the underlying transport.
// The config cannot be nil: users must set either ServerName or
// InsecureSkipVerify in the config.
func Client(conn net.Conn, config *Config) *Conn {
return NewConn(conn, config, true)
}
// A listener implements a network listener (net.Listener) for TLS connections.
type Listener struct {
net.Listener
config *Config
}
// Accept waits for and returns the next incoming TLS connection.
// The returned connection c is a *tls.Conn.
func (l *Listener) Accept() (c net.Conn, err error) {
c, err = l.Listener.Accept()
if err != nil {
return
}
server := Server(c, l.config)
err = server.Handshake()
if err == AlertNoAlert {
err = nil
}
c = server
return
}
// NewListener creates a Listener which accepts connections from an inner
// Listener and wraps each connection with Server.
// The configuration config must be non-nil and must include
// at least one certificate or else set GetCertificate.
func NewListener(inner net.Listener, config *Config) (net.Listener, error) {
if config != nil && config.NonBlocking {
return nil, errors.New("listening not possible in non-blocking mode")
}
l := new(Listener)
l.Listener = inner
l.config = config
return l, nil
}
// Listen creates a TLS listener accepting connections on the
// given network address using net.Listen.
// The configuration config must be non-nil and must include
// at least one certificate or else set GetCertificate.
func Listen(network, laddr string, config *Config) (net.Listener, error) {
if config == nil || !config.ValidForServer() {
return nil, errors.New("tls: neither Certificates nor GetCertificate set in Config")
}
l, err := net.Listen(network, laddr)
if err != nil {
return nil, err
}
return NewListener(l, config)
}
type TimeoutError struct{}
func (TimeoutError) Error() string { return "tls: DialWithDialer timed out" }
func (TimeoutError) Timeout() bool { return true }
func (TimeoutError) Temporary() bool { return true }
// DialWithDialer connects to the given network address using dialer.Dial and
// then initiates a TLS handshake, returning the resulting TLS connection. Any
// timeout or deadline given in the dialer apply to connection and TLS
// handshake as a whole.
//
// DialWithDialer interprets a nil configuration as equivalent to the zero
// configuration; see the documentation of Config for the defaults.
func DialWithDialer(dialer *net.Dialer, network, addr string, config *Config) (*Conn, error) {
if config != nil && config.NonBlocking {
return nil, errors.New("dialing not possible in non-blocking mode")
}
// We want the Timeout and Deadline values from dialer to cover the
// whole process: TCP connection and TLS handshake. This means that we
// also need to start our own timers now.
timeout := dialer.Timeout
if !dialer.Deadline.IsZero() {
deadlineTimeout := dialer.Deadline.Sub(time.Now())
if timeout == 0 || deadlineTimeout < timeout {
timeout = deadlineTimeout
}
}
var errChannel chan error
if timeout != 0 {
errChannel = make(chan error, 2)
time.AfterFunc(timeout, func() {
errChannel <- TimeoutError{}
})
}
rawConn, err := dialer.Dial(network, addr)
if err != nil {
return nil, err
}
colonPos := strings.LastIndex(addr, ":")
if colonPos == -1 {
colonPos = len(addr)
}
hostname := addr[:colonPos]
if config == nil {
config = &Config{}
} else {
config = config.Clone()
}
// If no ServerName is set, infer the ServerName
// from the hostname we're connecting to.
if config.ServerName == "" {
config.ServerName = hostname
}
// Set up DTLS as needed.
config.UseDTLS = (network == "udp")
conn := Client(rawConn, config)
if timeout == 0 {
err = conn.Handshake()
if err == AlertNoAlert {
err = nil
}
} else {
go func() {
errChannel <- conn.Handshake()
}()
err = <-errChannel
if err == AlertNoAlert {
err = nil
}
}
if err != nil {
rawConn.Close()
return nil, err
}
return conn, nil
}
// Dial connects to the given network address using net.Dial
// and then initiates a TLS handshake, returning the resulting
// TLS connection.
// Dial interprets a nil configuration as equivalent to
// the zero configuration; see the documentation of Config
// for the defaults.
func Dial(network, addr string, config *Config) (*Conn, error) {
return DialWithDialer(new(net.Dialer), network, addr, config)
}

View File

@@ -1,51 +1,5 @@
// +build !windows
/*
Package sockets is a simple unix domain socket wrapper.
Usage
For example:
import(
"fmt"
"net"
"os"
"github.com/docker/go-connections/sockets"
)
func main() {
l, err := sockets.NewUnixSocketWithOpts("/path/to/sockets",
sockets.WithChown(0,0),sockets.WithChmod(0660))
if err != nil {
panic(err)
}
echoStr := "hello"
go func() {
for {
conn, err := l.Accept()
if err != nil {
return
}
conn.Write([]byte(echoStr))
conn.Close()
}
}()
conn, err := net.Dial("unix", path)
if err != nil {
t.Fatal(err)
}
buf := make([]byte, 5)
if _, err := conn.Read(buf); err != nil {
panic(err)
} else if string(buf) != echoStr {
panic(fmt.Errorf("Msg may lost"))
}
}
*/
package sockets
import (
@@ -54,31 +8,8 @@ import (
"syscall"
)
// SockOption sets up socket file's creating option
type SockOption func(string) error
// WithChown modifies the socket file's uid and gid
func WithChown(uid, gid int) SockOption {
return func(path string) error {
if err := os.Chown(path, uid, gid); err != nil {
return err
}
return nil
}
}
// WithChmod modifies socket file's access mode
func WithChmod(mask os.FileMode) SockOption {
return func(path string) error {
if err := os.Chmod(path, mask); err != nil {
return err
}
return nil
}
}
// NewUnixSocketWithOpts creates a unix socket with the specified options
func NewUnixSocketWithOpts(path string, opts ...SockOption) (net.Listener, error) {
// NewUnixSocket creates a unix socket with the specified path and group.
func NewUnixSocket(path string, gid int) (net.Listener, error) {
if err := syscall.Unlink(path); err != nil && !os.IsNotExist(err) {
return nil, err
}
@@ -89,18 +20,13 @@ func NewUnixSocketWithOpts(path string, opts ...SockOption) (net.Listener, error
if err != nil {
return nil, err
}
for _, op := range opts {
if err := op(path); err != nil {
l.Close()
return nil, err
}
if err := os.Chown(path, 0, gid); err != nil {
l.Close()
return nil, err
}
if err := os.Chmod(path, 0660); err != nil {
l.Close()
return nil, err
}
return l, nil
}
// NewUnixSocket creates a unix socket with the specified path and group.
func NewUnixSocket(path string, gid int) (net.Listener, error) {
return NewUnixSocketWithOpts(path, WithChown(0, gid), WithChmod(0660))
}

View File

@@ -18,7 +18,7 @@ func HumanDuration(d time.Duration) string {
return fmt.Sprintf("%d seconds", seconds)
} else if minutes := int(d.Minutes()); minutes == 1 {
return "About a minute"
} else if minutes < 46 {
} else if minutes < 60 {
return fmt.Sprintf("%d minutes", minutes)
} else if hours := int(d.Hours() + 0.5); hours == 1 {
return "About an hour"

View File

@@ -96,8 +96,13 @@ func ParseUlimit(val string) (*Ulimit, error) {
return nil, fmt.Errorf("too many limit value arguments - %s, can only have up to two, `soft[:hard]`", parts[1])
}
if soft > *hard {
return nil, fmt.Errorf("ulimit soft limit must be less than or equal to hard limit: %d > %d", soft, *hard)
if *hard != -1 {
if soft == -1 {
return nil, fmt.Errorf("ulimit soft limit must be less than or equal to hard limit: soft: -1 (unlimited), hard: %d", *hard)
}
if soft > *hard {
return nil, fmt.Errorf("ulimit soft limit must be less than or equal to hard limit: %d > %d", soft, *hard)
}
}
return &Ulimit{Name: parts[0], Soft: soft, Hard: *hard}, nil

View File

@@ -45,14 +45,14 @@ func (c CurlyRouter) SelectRoute(
// selectRoutes return a collection of Route from a WebService that matches the path tokens from the request.
func (c CurlyRouter) selectRoutes(ws *WebService, requestTokens []string) sortableCurlyRoutes {
candidates := sortableCurlyRoutes{}
candidates := make(sortableCurlyRoutes, 0, 8)
for _, each := range ws.routes {
matches, paramCount, staticCount := c.matchesRouteByPathTokens(each.pathParts, requestTokens)
if matches {
candidates.add(curlyRoute{each, paramCount, staticCount}) // TODO make sure Routes() return pointers?
}
}
sort.Sort(sort.Reverse(candidates))
sort.Sort(candidates)
return candidates
}

View File

@@ -11,6 +11,7 @@ type curlyRoute struct {
staticCount int
}
// sortableCurlyRoutes orders by most parameters and path elements first.
type sortableCurlyRoutes []curlyRoute
func (s *sortableCurlyRoutes) add(route curlyRoute) {
@@ -18,6 +19,7 @@ func (s *sortableCurlyRoutes) add(route curlyRoute) {
}
func (s sortableCurlyRoutes) routes() (routes []Route) {
routes = make([]Route, 0, len(s))
for _, each := range s {
routes = append(routes, each.route) // TODO change return type
}
@@ -31,22 +33,22 @@ func (s sortableCurlyRoutes) Swap(i, j int) {
s[i], s[j] = s[j], s[i]
}
func (s sortableCurlyRoutes) Less(i, j int) bool {
ci := s[i]
cj := s[j]
a := s[j]
b := s[i]
// primary key
if ci.staticCount < cj.staticCount {
if a.staticCount < b.staticCount {
return true
}
if ci.staticCount > cj.staticCount {
if a.staticCount > b.staticCount {
return false
}
// secundary key
if ci.paramCount < cj.paramCount {
if a.paramCount < b.paramCount {
return true
}
if ci.paramCount > cj.paramCount {
if a.paramCount > b.paramCount {
return false
}
return ci.route.Path < cj.route.Path
return a.route.Path < b.route.Path
}

View File

@@ -66,8 +66,8 @@ func (RouterJSR311) extractParams(pathExpr *pathExpression, matches []string) ma
// http://jsr311.java.net/nonav/releases/1.1/spec/spec3.html#x3-360003.7.2
func (r RouterJSR311) detectRoute(routes []Route, httpRequest *http.Request) (*Route, error) {
ifOk := []Route{}
for _, each := range routes {
candidates := make([]*Route, 0, 8)
for i, each := range routes {
ok := true
for _, fn := range each.If {
if !fn(httpRequest) {
@@ -76,10 +76,10 @@ func (r RouterJSR311) detectRoute(routes []Route, httpRequest *http.Request) (*R
}
}
if ok {
ifOk = append(ifOk, each)
candidates = append(candidates, &routes[i])
}
}
if len(ifOk) == 0 {
if len(candidates) == 0 {
if trace {
traceLogger.Printf("no Route found (from %d) that passes conditional checks", len(routes))
}
@@ -87,53 +87,58 @@ func (r RouterJSR311) detectRoute(routes []Route, httpRequest *http.Request) (*R
}
// http method
methodOk := []Route{}
for _, each := range ifOk {
previous := candidates
candidates = candidates[:0]
for _, each := range previous {
if httpRequest.Method == each.Method {
methodOk = append(methodOk, each)
candidates = append(candidates, each)
}
}
if len(methodOk) == 0 {
if len(candidates) == 0 {
if trace {
traceLogger.Printf("no Route found (in %d routes) that matches HTTP method %s\n", len(routes), httpRequest.Method)
traceLogger.Printf("no Route found (in %d routes) that matches HTTP method %s\n", len(previous), httpRequest.Method)
}
return nil, NewError(http.StatusMethodNotAllowed, "405: Method Not Allowed")
}
// content-type
contentType := httpRequest.Header.Get(HEADER_ContentType)
inputMediaOk := []Route{}
for _, each := range methodOk {
previous = candidates
candidates = candidates[:0]
for _, each := range previous {
if each.matchesContentType(contentType) {
inputMediaOk = append(inputMediaOk, each)
candidates = append(candidates, each)
}
}
if len(inputMediaOk) == 0 {
if len(candidates) == 0 {
if trace {
traceLogger.Printf("no Route found (from %d) that matches HTTP Content-Type: %s\n", len(methodOk), contentType)
traceLogger.Printf("no Route found (from %d) that matches HTTP Content-Type: %s\n", len(previous), contentType)
}
if httpRequest.ContentLength > 0 {
return nil, NewError(http.StatusUnsupportedMediaType, "415: Unsupported Media Type")
}
return nil, NewError(http.StatusUnsupportedMediaType, "415: Unsupported Media Type")
}
// accept
outputMediaOk := []Route{}
previous = candidates
candidates = candidates[:0]
accept := httpRequest.Header.Get(HEADER_Accept)
if len(accept) == 0 {
accept = "*/*"
}
for _, each := range inputMediaOk {
for _, each := range previous {
if each.matchesAccept(accept) {
outputMediaOk = append(outputMediaOk, each)
candidates = append(candidates, each)
}
}
if len(outputMediaOk) == 0 {
if len(candidates) == 0 {
if trace {
traceLogger.Printf("no Route found (from %d) that matches HTTP Accept: %s\n", len(inputMediaOk), accept)
traceLogger.Printf("no Route found (from %d) that matches HTTP Accept: %s\n", len(previous), accept)
}
return nil, NewError(http.StatusNotAcceptable, "406: Not Acceptable")
}
// return r.bestMatchByMedia(outputMediaOk, contentType, accept), nil
return &outputMediaOk[0], nil
return candidates[0], nil
}
// http://jsr311.java.net/nonav/releases/1.1/spec/spec3.html#x3-360003.7.2

View File

@@ -22,7 +22,10 @@ func insertMime(l []mime, e mime) []mime {
return append(l, e)
}
const qFactorWeightingKey = "q"
// sortedMimes returns a list of mime sorted (desc) by its specified quality.
// e.g. text/html,application/xhtml+xml,application/xml;q=0.9,image/webp,image/apng,*/*;q=0.8,application/signed-exchange;v=b3
func sortedMimes(accept string) (sorted []mime) {
for _, each := range strings.Split(accept, ",") {
typeAndQuality := strings.Split(strings.Trim(each, " "), ";")
@@ -30,14 +33,16 @@ func sortedMimes(accept string) (sorted []mime) {
sorted = insertMime(sorted, mime{typeAndQuality[0], 1.0})
} else {
// take factor
parts := strings.Split(typeAndQuality[1], "=")
if len(parts) == 2 {
f, err := strconv.ParseFloat(parts[1], 64)
qAndWeight := strings.Split(typeAndQuality[1], "=")
if len(qAndWeight) == 2 && strings.Trim(qAndWeight[0], " ") == qFactorWeightingKey {
f, err := strconv.ParseFloat(qAndWeight[1], 64)
if err != nil {
traceLogger.Printf("unable to parse quality in %s, %v", each, err)
} else {
sorted = insertMime(sorted, mime{typeAndQuality[0], f})
}
} else {
sorted = insertMime(sorted, mime{typeAndQuality[0], 1.0})
}
}
}

View File

@@ -38,6 +38,7 @@ type Route struct {
Operation string
ParameterDocs []*Parameter
ResponseErrors map[int]ResponseError
DefaultResponse *ResponseError
ReadSample, WriteSample interface{} // structs that model an example request or response payload
// Extra information used to store custom information about the route.
@@ -77,28 +78,36 @@ func (r *Route) dispatchWithFilters(wrappedRequest *Request, wrappedResponse *Re
}
}
func stringTrimSpaceCutset(r rune) bool {
return r == ' '
}
// Return whether the mimeType matches to what this Route can produce.
func (r Route) matchesAccept(mimeTypesWithQuality string) bool {
parts := strings.Split(mimeTypesWithQuality, ",")
for _, each := range parts {
var withoutQuality string
if strings.Contains(each, ";") {
withoutQuality = strings.Split(each, ";")[0]
remaining := mimeTypesWithQuality
for {
var mimeType string
if end := strings.Index(remaining, ","); end == -1 {
mimeType, remaining = remaining, ""
} else {
withoutQuality = each
mimeType, remaining = remaining[:end], remaining[end+1:]
}
// trim before compare
withoutQuality = strings.Trim(withoutQuality, " ")
if withoutQuality == "*/*" {
if quality := strings.Index(mimeType, ";"); quality != -1 {
mimeType = mimeType[:quality]
}
mimeType = strings.TrimFunc(mimeType, stringTrimSpaceCutset)
if mimeType == "*/*" {
return true
}
for _, producibleType := range r.Produces {
if producibleType == "*/*" || producibleType == withoutQuality {
if producibleType == "*/*" || producibleType == mimeType {
return true
}
}
if len(remaining) == 0 {
return false
}
}
return false
}
// Return whether this Route can consume content with a type specified by mimeTypes (can be empty).
@@ -119,29 +128,33 @@ func (r Route) matchesContentType(mimeTypes string) bool {
mimeTypes = MIME_OCTET
}
parts := strings.Split(mimeTypes, ",")
for _, each := range parts {
var contentType string
if strings.Contains(each, ";") {
contentType = strings.Split(each, ";")[0]
remaining := mimeTypes
for {
var mimeType string
if end := strings.Index(remaining, ","); end == -1 {
mimeType, remaining = remaining, ""
} else {
contentType = each
mimeType, remaining = remaining[:end], remaining[end+1:]
}
// trim before compare
contentType = strings.Trim(contentType, " ")
if quality := strings.Index(mimeType, ";"); quality != -1 {
mimeType = mimeType[:quality]
}
mimeType = strings.TrimFunc(mimeType, stringTrimSpaceCutset)
for _, consumeableType := range r.Consumes {
if consumeableType == "*/*" || consumeableType == contentType {
if consumeableType == "*/*" || consumeableType == mimeType {
return true
}
}
if len(remaining) == 0 {
return false
}
}
return false
}
// Tokenize an URL path using the slash separator ; the result does not have empty tokens
func tokenizePath(path string) []string {
if "/" == path {
return []string{}
return nil
}
return strings.Split(strings.Trim(path, "/"), "/")
}

View File

@@ -35,8 +35,10 @@ type RouteBuilder struct {
readSample, writeSample interface{}
parameters []*Parameter
errorMap map[int]ResponseError
defaultResponse *ResponseError
metadata map[string]interface{}
deprecated bool
contentEncodingEnabled *bool
}
// Do evaluates each argument with the RouteBuilder itself.
@@ -164,7 +166,7 @@ func (b *RouteBuilder) Returns(code int, message string, model interface{}) *Rou
Code: code,
Message: message,
Model: model,
IsDefault: false,
IsDefault: false, // this field is deprecated, use default response instead.
}
// lazy init because there is no NewRouteBuilder (yet)
if b.errorMap == nil {
@@ -174,17 +176,11 @@ func (b *RouteBuilder) Returns(code int, message string, model interface{}) *Rou
return b
}
// DefaultReturns is a special Returns call that sets the default of the response ; the code is zero.
// DefaultReturns is a special Returns call that sets the default of the response.
func (b *RouteBuilder) DefaultReturns(message string, model interface{}) *RouteBuilder {
b.Returns(0, message, model)
// Modify the ResponseError just added/updated
re := b.errorMap[0]
// errorMap is initialized
b.errorMap[0] = ResponseError{
Code: re.Code,
Message: re.Message,
Model: re.Model,
IsDefault: true,
b.defaultResponse = &ResponseError{
Message: message,
Model: model,
}
return b
}
@@ -238,6 +234,12 @@ func (b *RouteBuilder) If(condition RouteSelectionConditionFunction) *RouteBuild
return b
}
// ContentEncodingEnabled allows you to override the Containers value for auto-compressing this route response.
func (b *RouteBuilder) ContentEncodingEnabled(enabled bool) *RouteBuilder {
b.contentEncodingEnabled = &enabled
return b
}
// If no specific Route path then set to rootPath
// If no specific Produces then set to rootProduces
// If no specific Consumes then set to rootConsumes
@@ -274,24 +276,27 @@ func (b *RouteBuilder) Build() Route {
operationName = nameOfFunction(b.function)
}
route := Route{
Method: b.httpMethod,
Path: concatPath(b.rootPath, b.currentPath),
Produces: b.produces,
Consumes: b.consumes,
Function: b.function,
Filters: b.filters,
If: b.conditions,
relativePath: b.currentPath,
pathExpr: pathExpr,
Doc: b.doc,
Notes: b.notes,
Operation: operationName,
ParameterDocs: b.parameters,
ResponseErrors: b.errorMap,
ReadSample: b.readSample,
WriteSample: b.writeSample,
Metadata: b.metadata,
Deprecated: b.deprecated}
Method: b.httpMethod,
Path: concatPath(b.rootPath, b.currentPath),
Produces: b.produces,
Consumes: b.consumes,
Function: b.function,
Filters: b.filters,
If: b.conditions,
relativePath: b.currentPath,
pathExpr: pathExpr,
Doc: b.doc,
Notes: b.notes,
Operation: operationName,
ParameterDocs: b.parameters,
ResponseErrors: b.errorMap,
DefaultResponse: b.defaultResponse,
ReadSample: b.readSample,
WriteSample: b.writeSample,
Metadata: b.metadata,
Deprecated: b.deprecated,
contentEncodingEnabled: b.contentEncodingEnabled,
}
route.postBuild()
return route
}

38
vendor/github.com/evanphx/json-patch/errors.go generated vendored Normal file
View File

@@ -0,0 +1,38 @@
package jsonpatch
import "fmt"
// AccumulatedCopySizeError is an error type returned when the accumulated size
// increase caused by copy operations in a patch operation has exceeded the
// limit.
type AccumulatedCopySizeError struct {
limit int64
accumulated int64
}
// NewAccumulatedCopySizeError returns an AccumulatedCopySizeError.
func NewAccumulatedCopySizeError(l, a int64) *AccumulatedCopySizeError {
return &AccumulatedCopySizeError{limit: l, accumulated: a}
}
// Error implements the error interface.
func (a *AccumulatedCopySizeError) Error() string {
return fmt.Sprintf("Unable to complete the copy, the accumulated size increase of copy is %d, exceeding the limit %d", a.accumulated, a.limit)
}
// ArraySizeError is an error type returned when the array size has exceeded
// the limit.
type ArraySizeError struct {
limit int
size int
}
// NewArraySizeError returns an ArraySizeError.
func NewArraySizeError(l, s int) *ArraySizeError {
return &ArraySizeError{limit: l, size: s}
}
// Error implements the error interface.
func (a *ArraySizeError) Error() string {
return fmt.Sprintf("Unable to create array of size %d, limit is %d", a.size, a.limit)
}

View File

@@ -14,7 +14,15 @@ const (
eAry
)
var SupportNegativeIndices bool = true
var (
// SupportNegativeIndices decides whether to support non-standard practice of
// allowing negative indices to mean indices starting at the end of an array.
// Default to true.
SupportNegativeIndices bool = true
// AccumulatedCopySizeLimit limits the total size increase in bytes caused by
// "copy" operations in a patch.
AccumulatedCopySizeLimit int64 = 0
)
type lazyNode struct {
raw *json.RawMessage
@@ -63,6 +71,20 @@ func (n *lazyNode) UnmarshalJSON(data []byte) error {
return nil
}
func deepCopy(src *lazyNode) (*lazyNode, int, error) {
if src == nil {
return nil, 0, nil
}
a, err := src.MarshalJSON()
if err != nil {
return nil, 0, err
}
sz := len(a)
ra := make(json.RawMessage, sz)
copy(ra, a)
return newLazyNode(&ra), sz, nil
}
func (n *lazyNode) intoDoc() (*partialDoc, error) {
if n.which == eDoc {
return &n.doc, nil
@@ -344,35 +366,14 @@ func (d *partialDoc) remove(key string) error {
return nil
}
// set should only be used to implement the "replace" operation, so "key" must
// be an already existing index in "d".
func (d *partialArray) set(key string, val *lazyNode) error {
if key == "-" {
*d = append(*d, val)
return nil
}
idx, err := strconv.Atoi(key)
if err != nil {
return err
}
sz := len(*d)
if idx+1 > sz {
sz = idx + 1
}
ary := make([]*lazyNode, sz)
cur := *d
copy(ary, cur)
if idx >= len(ary) {
return fmt.Errorf("Unable to access invalid index: %d", idx)
}
ary[idx] = val
*d = ary
(*d)[idx] = val
return nil
}
@@ -387,7 +388,9 @@ func (d *partialArray) add(key string, val *lazyNode) error {
return err
}
ary := make([]*lazyNode, len(*d)+1)
sz := len(*d) + 1
ary := make([]*lazyNode, sz)
cur := *d
@@ -527,7 +530,7 @@ func (p Patch) move(doc *container, op operation) error {
return fmt.Errorf("jsonpatch move operation does not apply: doc is missing destination path: %s", path)
}
return con.set(key, val)
return con.add(key, val)
}
func (p Patch) test(doc *container, op operation) error {
@@ -561,7 +564,7 @@ func (p Patch) test(doc *container, op operation) error {
return fmt.Errorf("Testing value %s failed", path)
}
func (p Patch) copy(doc *container, op operation) error {
func (p Patch) copy(doc *container, op operation, accumulatedCopySize *int64) error {
from := op.from()
con, key := findObject(doc, from)
@@ -583,7 +586,16 @@ func (p Patch) copy(doc *container, op operation) error {
return fmt.Errorf("jsonpatch copy operation does not apply: doc is missing destination path: %s", path)
}
return con.set(key, val)
valCopy, sz, err := deepCopy(val)
if err != nil {
return err
}
(*accumulatedCopySize) += int64(sz)
if AccumulatedCopySizeLimit > 0 && *accumulatedCopySize > AccumulatedCopySizeLimit {
return NewAccumulatedCopySizeError(AccumulatedCopySizeLimit, *accumulatedCopySize)
}
return con.add(key, valCopy)
}
// Equal indicates if 2 JSON documents have the same structural equality.
@@ -636,6 +648,8 @@ func (p Patch) ApplyIndent(doc []byte, indent string) ([]byte, error) {
err = nil
var accumulatedCopySize int64
for _, op := range p {
switch op.kind() {
case "add":
@@ -649,7 +663,7 @@ func (p Patch) ApplyIndent(doc []byte, indent string) ([]byte, error) {
case "test":
err = p.test(&pd, op)
case "copy":
err = p.copy(&pd, op)
err = p.copy(&pd, op, &accumulatedCopySize)
default:
err = fmt.Errorf("Unexpected kind: %s", op.kind())
}

View File

@@ -1,11 +1,11 @@
package api // import "github.com/xenolf/lego/acme/api"
package api
import (
"encoding/base64"
"errors"
"fmt"
"github.com/xenolf/lego/acme"
"github.com/go-acme/lego/acme"
)
type AccountService service

View File

@@ -1,4 +1,4 @@
package api // import "github.com/xenolf/lego/acme/api"
package api
import (
"bytes"
@@ -11,11 +11,11 @@ import (
"time"
"github.com/cenkalti/backoff"
"github.com/xenolf/lego/acme"
"github.com/xenolf/lego/acme/api/internal/nonces"
"github.com/xenolf/lego/acme/api/internal/secure"
"github.com/xenolf/lego/acme/api/internal/sender"
"github.com/xenolf/lego/log"
"github.com/go-acme/lego/acme"
"github.com/go-acme/lego/acme/api/internal/nonces"
"github.com/go-acme/lego/acme/api/internal/secure"
"github.com/go-acme/lego/acme/api/internal/sender"
"github.com/go-acme/lego/log"
)
// Core ACME/LE core API.

View File

@@ -1,9 +1,9 @@
package api // import "github.com/xenolf/lego/acme/api"
package api
import (
"errors"
"github.com/xenolf/lego/acme"
"github.com/go-acme/lego/acme"
)
type AuthorizationService service

View File

@@ -1,4 +1,4 @@
package api // import "github.com/xenolf/lego/acme/api"
package api
import (
"crypto/x509"
@@ -7,9 +7,9 @@ import (
"io/ioutil"
"net/http"
"github.com/xenolf/lego/acme"
"github.com/xenolf/lego/certcrypto"
"github.com/xenolf/lego/log"
"github.com/go-acme/lego/acme"
"github.com/go-acme/lego/certcrypto"
"github.com/go-acme/lego/log"
)
// maxBodySize is the maximum size of body that we will read.

View File

@@ -1,9 +1,9 @@
package api // import "github.com/xenolf/lego/acme/api"
package api
import (
"errors"
"github.com/xenolf/lego/acme"
"github.com/go-acme/lego/acme"
)
type ChallengeService service

View File

@@ -1,4 +1,4 @@
package nonces // import "github.com/xenolf/lego/acme/api/internal/nonces"
package nonces
import (
"errors"
@@ -6,7 +6,7 @@ import (
"net/http"
"sync"
"github.com/xenolf/lego/acme/api/internal/sender"
"github.com/go-acme/lego/acme/api/internal/sender"
)
// Manager Manages nonces.

View File

@@ -1,4 +1,4 @@
package secure // import "github.com/xenolf/lego/acme/api/internal/secure"
package secure
import (
"crypto"
@@ -6,10 +6,9 @@ import (
"crypto/elliptic"
"crypto/rsa"
"encoding/base64"
"errors"
"fmt"
"github.com/xenolf/lego/acme/api/internal/nonces"
"github.com/go-acme/lego/acme/api/internal/nonces"
jose "gopkg.in/square/go-jose.v2"
)
@@ -118,9 +117,6 @@ func (j *JWS) GetKeyAuthorization(token string) (string, error) {
// Generate the Key Authorization for the challenge
jwk := &jose.JSONWebKey{Key: publicKey}
if jwk == nil {
return "", errors.New("could not generate JWK from key")
}
thumbBytes, err := jwk.Thumbprint(crypto.SHA256)
if err != nil {

View File

@@ -1,4 +1,4 @@
package sender // import "github.com/xenolf/lego/acme/api/internal/sender"
package sender
import (
"encoding/json"
@@ -9,7 +9,7 @@ import (
"runtime"
"strings"
"github.com/xenolf/lego/acme"
"github.com/go-acme/lego/acme"
)
type RequestOption func(*http.Request) error

View File

@@ -5,7 +5,7 @@ package sender
const (
// ourUserAgent is the User-Agent of this underlying library package.
ourUserAgent = "xenolf-acme/2.3.0"
ourUserAgent = "xenolf-acme/2.5.0"
// ourUserAgentComment is part of the UA comment linked to the version status of this underlying library package.
// values: detach|release

View File

@@ -1,10 +1,10 @@
package api // import "github.com/xenolf/lego/acme/api"
package api
import (
"encoding/base64"
"errors"
"github.com/xenolf/lego/acme"
"github.com/go-acme/lego/acme"
)
type OrderService service

View File

@@ -1,4 +1,4 @@
package api // import "github.com/xenolf/lego/acme/api"
package api
import (
"net/http"

View File

@@ -1,6 +1,6 @@
// Package acme contains all objects related the ACME endpoints.
// https://tools.ietf.org/html/draft-ietf-acme-acme-16
package acme // import "github.com/xenolf/lego/acme"
package acme
import (
"encoding/json"

View File

@@ -1,4 +1,4 @@
package acme // import "github.com/xenolf/lego/acme"
package acme
import (
"fmt"

View File

@@ -1,4 +1,4 @@
package certcrypto // import "github.com/xenolf/lego/certcrypto"
package certcrypto
import (
"crypto"

View File

@@ -1,10 +1,10 @@
package certificate // import "github.com/xenolf/lego/certificate"
package certificate
import (
"time"
"github.com/xenolf/lego/acme"
"github.com/xenolf/lego/log"
"github.com/go-acme/lego/acme"
"github.com/go-acme/lego/log"
)
const (

View File

@@ -1,4 +1,4 @@
package certificate // import "github.com/xenolf/lego/certificate"
package certificate
import (
"bytes"
@@ -12,12 +12,12 @@ import (
"strings"
"time"
"github.com/xenolf/lego/acme"
"github.com/xenolf/lego/acme/api"
"github.com/xenolf/lego/certcrypto"
"github.com/xenolf/lego/challenge"
"github.com/xenolf/lego/log"
"github.com/xenolf/lego/platform/wait"
"github.com/go-acme/lego/acme"
"github.com/go-acme/lego/acme/api"
"github.com/go-acme/lego/certcrypto"
"github.com/go-acme/lego/challenge"
"github.com/go-acme/lego/log"
"github.com/go-acme/lego/platform/wait"
"golang.org/x/crypto/ocsp"
"golang.org/x/net/idna"
)
@@ -114,6 +114,7 @@ func (c *Certifier) Obtain(request ObtainRequest) (*Resource, error) {
err = c.resolver.Solve(authz)
if err != nil {
// If any challenge fails, return. Do not generate partial SAN certificates.
c.deactivateAuthorizations(order)
return nil, err
}
@@ -170,6 +171,7 @@ func (c *Certifier) ObtainForCSR(csr x509.CertificateRequest, bundle bool) (*Res
err = c.resolver.Solve(authz)
if err != nil {
// If any challenge fails, return. Do not generate partial SAN certificates.
c.deactivateAuthorizations(order)
return nil, err
}

View File

@@ -1,4 +1,4 @@
package certificate // import "github.com/xenolf/lego/certificate"
package certificate
import (
"bytes"

View File

@@ -1,9 +1,9 @@
package challenge // import "github.com/xenolf/lego/challenge"
package challenge
import (
"fmt"
"github.com/xenolf/lego/acme"
"github.com/go-acme/lego/acme"
)
// Type is a string that identifies a particular challenge type and version of ACME challenge.

View File

@@ -1,4 +1,4 @@
package dns01 // import "github.com/xenolf/lego/challenge/dns01"
package dns01
import "github.com/miekg/dns"

View File

@@ -1,4 +1,4 @@
package dns01 // import "github.com/xenolf/lego/challenge/dns01"
package dns01
import (
"crypto/sha256"
@@ -8,12 +8,12 @@ import (
"strconv"
"time"
"github.com/go-acme/lego/acme"
"github.com/go-acme/lego/acme/api"
"github.com/go-acme/lego/challenge"
"github.com/go-acme/lego/log"
"github.com/go-acme/lego/platform/wait"
"github.com/miekg/dns"
"github.com/xenolf/lego/acme"
"github.com/xenolf/lego/acme/api"
"github.com/xenolf/lego/challenge"
"github.com/xenolf/lego/log"
"github.com/xenolf/lego/platform/wait"
)
const (

View File

@@ -1,4 +1,4 @@
package dns01 // import "github.com/xenolf/lego/challenge/dns01"
package dns01
import (
"bufio"

View File

@@ -1,4 +1,4 @@
package dns01 // import "github.com/xenolf/lego/challenge/dns01"
package dns01
// ToFqdn converts the name into a fqdn appending a trailing dot.
func ToFqdn(name string) string {

View File

@@ -1,4 +1,4 @@
package dns01 // import "github.com/xenolf/lego/challenge/dns01"
package dns01
import (
"fmt"

View File

@@ -1,4 +1,4 @@
package dns01 // import "github.com/xenolf/lego/challenge/dns01"
package dns01
import (
"errors"

View File

@@ -1,12 +1,12 @@
package http01 // import "github.com/xenolf/lego/challenge/http01"
package http01
import (
"fmt"
"github.com/xenolf/lego/acme"
"github.com/xenolf/lego/acme/api"
"github.com/xenolf/lego/challenge"
"github.com/xenolf/lego/log"
"github.com/go-acme/lego/acme"
"github.com/go-acme/lego/acme/api"
"github.com/go-acme/lego/challenge"
"github.com/go-acme/lego/log"
)
type ValidateFunc func(core *api.Core, domain string, chlng acme.Challenge) error

View File

@@ -1,4 +1,4 @@
package http01 // import "github.com/xenolf/lego/challenge/http01"
package http01
import (
"fmt"
@@ -6,7 +6,7 @@ import (
"net/http"
"strings"
"github.com/xenolf/lego/log"
"github.com/go-acme/lego/log"
)
// ProviderServer implements ChallengeProvider for `http-01` challenge

View File

@@ -1,4 +1,4 @@
package challenge // import "github.com/xenolf/lego/challenge"
package challenge
import "time"

View File

@@ -1,4 +1,4 @@
package resolver // import "github.com/xenolf/lego/challenge/resolver"
package resolver
import (
"bytes"

View File

@@ -1,12 +1,12 @@
package resolver // import "github.com/xenolf/lego/challenge/resolver"
package resolver
import (
"fmt"
"time"
"github.com/xenolf/lego/acme"
"github.com/xenolf/lego/challenge"
"github.com/xenolf/lego/log"
"github.com/go-acme/lego/acme"
"github.com/go-acme/lego/challenge"
"github.com/go-acme/lego/log"
)
// Interface for all challenge solvers to implement.

View File

@@ -1,4 +1,4 @@
package resolver // import "github.com/xenolf/lego/challenge/resolver"
package resolver
import (
"context"
@@ -9,13 +9,13 @@ import (
"time"
"github.com/cenkalti/backoff"
"github.com/xenolf/lego/acme"
"github.com/xenolf/lego/acme/api"
"github.com/xenolf/lego/challenge"
"github.com/xenolf/lego/challenge/dns01"
"github.com/xenolf/lego/challenge/http01"
"github.com/xenolf/lego/challenge/tlsalpn01"
"github.com/xenolf/lego/log"
"github.com/go-acme/lego/acme"
"github.com/go-acme/lego/acme/api"
"github.com/go-acme/lego/challenge"
"github.com/go-acme/lego/challenge/dns01"
"github.com/go-acme/lego/challenge/http01"
"github.com/go-acme/lego/challenge/tlsalpn01"
"github.com/go-acme/lego/log"
)
type byType []acme.Challenge

Some files were not shown because too many files have changed in this diff Show More