Files
kubesphere/vendor/github.com/bifurcation/mint/record-layer.go
jeff 4ac20ffc2b add service mesh controller
add service mesh metrics

remove unused circle yaml

fix travis misconfiguration

fix travis misconfiguration

fix travis misconfiguration
2019-03-17 17:28:52 +08:00

508 lines
13 KiB
Go

package mint
import (
"crypto/cipher"
"fmt"
"io"
"sync"
)
// Record-layer framing sizes, in bytes.
const (
	// sequenceNumberLen is the width of the explicit sequence number field
	// carried in DTLS record headers.
	sequenceNumberLen = 8

	// recordHeaderLenTLS is a TLS record header: type(1) + version(2) + length(2).
	recordHeaderLenTLS = 1 + 2 + 2

	// recordHeaderLenDTLS is a DTLS record header:
	// type(1) + version(2) + epoch/sequence(8) + length(2).
	recordHeaderLenDTLS = 1 + 2 + sequenceNumberLen + 2

	// maxFragmentLen is the largest plaintext fragment one record may carry (2^14).
	maxFragmentLen = 1 << 14
)
// DecryptError is returned when a protected record cannot be decrypted or
// authenticated. Its underlying string is the error message.
type DecryptError string

// Error implements the error interface by returning the message text.
func (e DecryptError) Error() string {
	msg := string(e)
	return msg
}
// Direction says whether a record layer handles outgoing or incoming records.
type Direction uint8

// The two record-layer directions.
const (
	// DirectionWrite marks a layer that protects and sends records.
	DirectionWrite Direction = 1
	// DirectionRead marks a layer that receives and unprotects records.
	DirectionRead Direction = 2
)
// TLSPlaintext is one record as seen above record protection. It follows the
// TLS record structure
//
//	struct {
//	    ContentType type;
//	    ProtocolVersion record_version [0301 for CH, 0303 for others]
//	    uint16 length;
//	    opaque fragment[TLSPlaintext.length];
//	} TLSPlaintext;
//
// record_version is not stored here (the layer writes its own configured
// version into headers), and length is implied by len(fragment).
type TLSPlaintext struct {
	contentType RecordType // record content type
	epoch       Epoch      // epoch of the cipher state associated with the record
	seq         uint64     // sequence number the record was decrypted with (read side)
	fragment    []byte     // record payload
}
// NewTLSPlaintext builds a record of content type ct for the given epoch.
// The fragment slice is stored as-is (not copied) and seq is left at zero.
func NewTLSPlaintext(ct RecordType, epoch Epoch, fragment []byte) *TLSPlaintext {
	pt := new(TLSPlaintext)
	pt.contentType = ct
	pt.epoch = epoch
	pt.fragment = fragment
	return pt
}
// Fragment returns the record payload. The returned slice aliases the
// record's own storage; it is not a copy.
func (t TLSPlaintext) Fragment() []byte {
	payload := t.fragment
	return payload
}
// cipherState is the protection state for one epoch of one direction of the
// record layer. A nil cipher means records are carried unprotected.
//
// Field order matters: the constructors build this type with positional
// composite literals.
type cipherState struct {
	epoch    Epoch       // epoch these keys belong to (used for DTLS)
	ivLength int         // length of iv in bytes
	seq      uint64      // next record sequence number under this state
	iv       []byte      // static IV that per-record nonces are derived from
	cipher   cipher.AEAD // AEAD used to seal/open records; nil when unprotected
}
// RecordLayerFactory produces record layers.
type RecordLayerFactory interface {
	// NewLayer returns a record layer operating over conn in direction dir.
	NewLayer(conn io.ReadWriter, dir Direction) RecordLayer
}
// RecordLayer frames, protects and transports TLS/DTLS records over a
// connection. The method descriptions below reflect RecordLayerImpl.
type RecordLayer interface {
	// Lock and Unlock expose a mutex (RecordLayerImpl embeds sync.Mutex);
	// presumably callers hold it around the other calls — confirm at call sites.
	Lock()
	Unlock()

	// SetVersion sets the protocol version written into record headers.
	SetVersion(v uint16)
	// SetLabel sets the prefix used in this layer's log messages.
	SetLabel(s string)

	// Rekey installs AEAD keys for epoch, replacing the current cipher state.
	Rekey(epoch Epoch, factory AeadFactory, keys *KeySet) error
	// ResetClear returns the layer to unprotected records starting at seq.
	ResetClear(seq uint64)
	// DiscardReadKey drops the stored read keys for epoch (DTLS only).
	DiscardReadKey(epoch Epoch)

	// PeekRecordType reports the content type of the next record without
	// consuming it; with block set it retries while none is available.
	PeekRecordType(block bool) (RecordType, error)
	// ReadRecord returns and consumes the next record.
	ReadRecord() (*TLSPlaintext, error)
	// WriteRecord protects pt with the current keys and sends it.
	WriteRecord(pt *TLSPlaintext) error

	// Epoch returns the epoch of the current cipher state.
	Epoch() Epoch
}
// RecordLayerImpl implements RecordLayer for both TLS (stream) and DTLS
// (datagram) framing. The embedded mutex supplies Lock/Unlock; none of the
// methods in this file acquire it themselves.
type RecordLayerImpl struct {
	sync.Mutex

	label        string                 // prefix for log messages
	direction    Direction              // read or write side
	version      uint16                 // protocol version written into headers
	conn         io.ReadWriter          // underlying transport
	frame        *frameReader           // reassembles whole records from conn reads
	nextData     []byte                 // the next record to send (not referenced in this file)
	cachedRecord *TLSPlaintext          // last record read, kept so it can be peeked
	cachedError  error                  // error paired with cachedRecord
	cipher       *cipherState           // current protection state
	readCiphers  map[Epoch]*cipherState // DTLS read states, keyed by epoch
	datagram     bool                   // true when using DTLS framing
}
// Impl returns the receiver itself, giving code that holds a value through
// an interface access to the concrete *RecordLayerImpl.
func (r *RecordLayerImpl) Impl() *RecordLayerImpl {
	impl := r
	return impl
}
// recordLayerFrameDetails describes record framing to the frame reader:
// header size, default read size and how to extract a record's length.
type recordLayerFrameDetails struct {
	datagram bool // true for DTLS headers, false for TLS
}
// headerLen returns the record header size for this framing: the DTLS
// header length when datagram is set, otherwise the TLS header length.
func (d recordLayerFrameDetails) headerLen() int {
	n := recordHeaderLenTLS
	if d.datagram {
		n = recordHeaderLenDTLS
	}
	return n
}
// defaultReadLen is the size of one maximal record: a full header followed
// by a maxFragmentLen-byte fragment.
func (d recordLayerFrameDetails) defaultReadLen() int {
	hdr := d.headerLen()
	return maxFragmentLen + hdr
}
// frameLen reads the record body length from a header: the big-endian
// 16-bit value in the last two bytes of the header (positions are taken
// from headerLen, not len(hdr)). It never returns an error.
func (d recordLayerFrameDetails) frameLen(hdr []byte) (int, error) {
	end := d.headerLen()
	hi, lo := int(hdr[end-2]), int(hdr[end-1])
	return hi<<8 | lo, nil
}
// newCipherStateNull returns an unprotected cipher state for EpochClear with
// sequence number 0, no IV and no AEAD.
func newCipherStateNull() *cipherState {
	return &cipherState{
		epoch: EpochClear,
	}
}
// newCipherStateAead builds a cipher state for epoch whose AEAD is created by
// calling factory with key. iv is stored without copying and the sequence
// number starts at 0. Any error from factory is returned unchanged.
func newCipherStateAead(epoch Epoch, factory AeadFactory, key []byte, iv []byte) (*cipherState, error) {
	aead, err := factory(key)
	if err != nil {
		return nil, err
	}
	state := &cipherState{
		epoch:    epoch,
		ivLength: len(iv),
		iv:       iv,
		cipher:   aead,
	}
	return state, nil
}
// NewRecordLayerTLS returns a stream (TLS) record layer over conn for
// direction dir. It starts unprotected with sequence number 0, an empty log
// label, and tls10Version as the header version until SetVersion is called.
func NewRecordLayerTLS(conn io.ReadWriter, dir Direction) *RecordLayerImpl {
	frame := newFrameReader(recordLayerFrameDetails{false})
	return &RecordLayerImpl{
		label:     "",
		direction: dir,
		conn:      conn,
		frame:     frame,
		cipher:    newCipherStateNull(),
		version:   tls10Version,
	}
}
// NewRecordLayerDTLS returns a datagram (DTLS) record layer over conn for
// direction dir. It starts unprotected, and that initial cipher state is also
// registered as the read state for epoch 0. The header version is left at
// zero; callers presumably set it with SetVersion (confirm at call sites).
func NewRecordLayerDTLS(conn io.ReadWriter, dir Direction) *RecordLayerImpl {
	frame := newFrameReader(recordLayerFrameDetails{true})
	initial := newCipherStateNull()
	return &RecordLayerImpl{
		label:       "",
		direction:   dir,
		conn:        conn,
		frame:       frame,
		cipher:      initial,
		readCiphers: map[Epoch]*cipherState{0: initial},
		datagram:    true,
	}
}
// SetVersion sets the protocol version written into outgoing record headers
// (converted to its DTLS form for datagram layers).
func (r *RecordLayerImpl) SetVersion(v uint16) {
	r.version = uint16(v)
}
// ResetClear replaces the current cipher state with an unprotected one whose
// sequence number starts at seq. DTLS readCiphers are not modified.
func (r *RecordLayerImpl) ResetClear(seq uint64) {
	fresh := newCipherStateNull()
	fresh.seq = seq
	r.cipher = fresh
}
// Epoch returns the epoch of the layer's current cipher state.
func (r *RecordLayerImpl) Epoch() Epoch {
	current := r.cipher
	return current.epoch
}
// SetLabel sets the prefix that appears in this layer's log messages.
func (r *RecordLayerImpl) SetLabel(s string) {
	label := s
	r.label = label
}
// Rekey installs AEAD protection for epoch using keys.Key and keys.Iv, making
// it the current cipher state. DTLS read layers also record it in readCiphers
// so nextRecord can find it by epoch. On error the layer is left unchanged.
func (r *RecordLayerImpl) Rekey(epoch Epoch, factory AeadFactory, keys *KeySet) error {
	next, err := newCipherStateAead(epoch, factory, keys.Key, keys.Iv)
	if err != nil {
		return err
	}
	r.cipher = next
	if isDTLSReader := r.datagram && r.direction == DirectionRead; isDTLSReader {
		r.readCiphers[epoch] = next
	}
	return nil
}
// DiscardReadKey forgets the DTLS read cipher state for epoch. It is a no-op
// for TLS layers; for DTLS the epoch must currently be present (asserted).
//
// TODO(ekr@rtfm.com): This is never used, which is a bug.
func (r *RecordLayerImpl) DiscardReadKey(epoch Epoch) {
	if r.datagram {
		_, present := r.readCiphers[epoch]
		assert(present)
		delete(r.readCiphers, epoch)
	}
}
// combineSeq returns the sequence value placed in records and used for the
// nonce. For TLS it is just c.seq. For DTLS the epoch is ORed into the top 16
// bits above the 48-bit sequence number.
func (c *cipherState) combineSeq(datagram bool) uint64 {
	if !datagram {
		return c.seq
	}
	return uint64(c.epoch)<<48 | c.seq
}
// computeNonce derives the per-record AEAD nonce: a copy of c.iv with seq,
// as a big-endian 64-bit integer, XORed into its last 8 bytes. c.iv itself
// is not modified. Panics if c.iv is shorter than 8 bytes.
func (c *cipherState) computeNonce(seq uint64) []byte {
	nonce := append([]byte(nil), c.iv...)
	last := len(nonce) - 1
	for k, rest := 0, seq; k < 8; k, rest = k+1, rest>>8 {
		nonce[last-k] ^= byte(rest)
	}
	logf(logTypeCrypto, "Computing nonce for sequence # %x -> %x", seq, nonce)
	return nonce
}
// incrementSequenceNumber advances c.seq by one. Sequence numbers are capped
// at 2^48-1 (the DTLS limit, applied to both protocols). Reaching the cap
// panics instead of wrapping. A real implementation would have to rekey
// first, but that is considered too unlikely to handle.
func (c *cipherState) incrementSequenceNumber() {
	const seqLimit = 1<<48 - 1
	if c.seq < seqLimit {
		c.seq++
		return
	}
	panic("TLS: sequence number wraparound")
}
// overhead returns the number of bytes the AEAD adds to each sealed record,
// or 0 when the state is unprotected.
func (c *cipherState) overhead() int {
	if aead := c.cipher; aead != nil {
		return aead.Overhead()
	}
	return 0
}
// encrypt seals pt with cipher and returns the protected fragment. The inner
// plaintext is pt.fragment || contentType || padLen zero bytes. It is sealed
// with the nonce derived from seq and with header as additional data. The
// result is the inner plaintext's length plus the AEAD overhead.
// pt.fragment itself is not modified. Must only be called on a write layer
// whose cipher has an AEAD installed.
func (r *RecordLayerImpl) encrypt(cipher *cipherState, seq uint64, header []byte, pt *TLSPlaintext, padLen int) []byte {
	assert(r.direction == DirectionWrite)
	logf(logTypeIO, "%s Encrypt seq=[%x]", r.label, seq)

	bodyLen := len(pt.fragment)
	innerLen := bodyLen + 1 + padLen

	// One buffer sized for the sealed output so Seal can work in place.
	// make zero-fills it, which already provides the padding bytes.
	buf := make([]byte, innerLen+cipher.overhead())
	copy(buf, pt.fragment)
	buf[bodyLen] = byte(pt.contentType)

	inner := buf[:innerLen]
	cipher.cipher.Seal(inner[:0], cipher.computeNonce(seq), inner, header)
	return buf
}
// decrypt authenticates and decrypts the protected record pt with the layer's
// current cipher state (r.cipher). The nonce is derived from seq and header
// is used as additional authenticated data.
//
// The decrypted inner plaintext is content || contentType || zero padding.
// On success decrypt returns a new record whose fragment is the content,
// whose contentType is the inner content type, and whose seq is seq. The
// second result is the number of zero padding bytes removed.
//
// It returns a DecryptError when the record is shorter than the AEAD
// overhead, fails authentication, or decrypts to an inner plaintext with no
// non-zero byte (and therefore no content type).
func (r *RecordLayerImpl) decrypt(seq uint64, header []byte, pt *TLSPlaintext) (*TLSPlaintext, int, error) {
	assert(r.direction == DirectionRead)
	logf(logTypeIO, "%s Decrypt seq=[%x]", r.label, seq)

	overhead := r.cipher.overhead()
	if len(pt.fragment) < overhead {
		msg := fmt.Sprintf("tls.record.decrypt: Record too short [%d] < [%d]", len(pt.fragment), overhead)
		return nil, 0, DecryptError(msg)
	}

	// Open into a buffer sized for the inner plaintext, and use the slice it
	// returns rather than assuming it wrote into that buffer in place.
	decryptLen := len(pt.fragment) - overhead
	inner, err := r.cipher.cipher.Open(make([]byte, 0, decryptLen), r.cipher.computeNonce(seq), pt.fragment, header)
	if err != nil {
		logf(logTypeIO, "%s AEAD decryption failure [%x]", r.label, pt)
		return nil, 0, DecryptError("tls.record.decrypt: AEAD decrypt failed")
	}

	// Scan back over the zero padding to the last non-zero byte, which is the
	// content type. An inner plaintext that is empty or all zeros has no
	// content type (RFC 8446, section 5.4). It must be rejected here; indexing
	// before the start of the buffer would panic.
	typeIdx := len(inner) - 1
	for typeIdx >= 0 && inner[typeIdx] == 0 {
		typeIdx--
	}
	if typeIdx < 0 {
		return nil, 0, DecryptError("tls.record.decrypt: No content type in plaintext")
	}
	padLen := len(inner) - 1 - typeIdx

	out := &TLSPlaintext{
		contentType: RecordType(inner[typeIdx]),
		seq:         seq,
		fragment:    inner[:typeIdx],
	}
	return out, padLen, nil
}
// PeekRecordType returns the content type of the next record without
// consuming it: nextRecord leaves the record cached until ReadRecord or
// ReadRecordAnyEpoch clears it. When block is true, AlertWouldBlock is retried
// immediately in a loop. Any other error, or AlertWouldBlock when block is
// false, is returned with a zero type.
func (r *RecordLayerImpl) PeekRecordType(block bool) (RecordType, error) {
	for {
		pt, err := r.nextRecord(false)
		switch {
		case err == nil:
			return pt.contentType, nil
		case block && err == AlertWouldBlock:
			continue
		default:
			return 0, err
		}
	}
}
// ReadRecord returns the next record from the current epoch and consumes it,
// clearing any record left cached by PeekRecordType.
func (r *RecordLayerImpl) ReadRecord() (*TLSPlaintext, error) {
	pt, err := r.nextRecord(false)
	r.cachedRecord, r.cachedError = nil, nil
	return pt, err
}
// ReadRecordAnyEpoch is like ReadRecord but also accepts DTLS records from a
// known, non-current epoch. The returned record is consumed, and any cached
// peeked record is cleared.
func (r *RecordLayerImpl) ReadRecordAnyEpoch() (*TLSPlaintext, error) {
	pt, err := r.nextRecord(true)
	r.cachedRecord, r.cachedError = nil, nil
	return pt, err
}
// nextRecord returns the next record from the connection, decrypting it when
// a cipher is installed. A record still cached from an earlier peek is
// returned instead of reading a new one. A successfully read record is cached
// and the chosen cipher state's sequence number is advanced.
//
// For DTLS (r.datagram), the sequence number and epoch are taken from the
// record header. The epoch is the top 16 bits of the 64-bit field. Records
// from an epoch with no read keys are dropped and reported as
// AlertWouldBlock. Records from a known but non-current epoch are also
// dropped unless allowOldEpoch is set; in that case they are decrypted with
// that epoch's cipher state.
//
// Connection and frame-processing errors are returned as-is. Malformed or
// oversized records produce descriptive errors, and decryption failures come
// from decrypt.
func (r *RecordLayerImpl) nextRecord(allowOldEpoch bool) (*TLSPlaintext, error) {
	cipher := r.cipher
	if r.cachedRecord != nil {
		logf(logTypeIO, "%s Returning cached record", r.label)
		return r.cachedRecord, r.cachedError
	}

	// Loop until one of three things happens:
	//
	// 1. We get a frame
	// 2. We try to read off the socket and get nothing, in which case
	//    return AlertWouldBlock
	// 3. We get an error.
	var err error = AlertWouldBlock
	var header, body []byte
	for err != nil {
		if r.frame.needed() > 0 {
			buf := make([]byte, r.frame.details.headerLen()+maxFragmentLen)
			n, readErr := r.conn.Read(buf)
			if readErr != nil {
				logf(logTypeIO, "%s Error reading, %v", r.label, readErr)
				return nil, readErr
			}
			if n == 0 {
				return nil, AlertWouldBlock
			}
			logf(logTypeIO, "%s Read %v bytes", r.label, n)
			r.frame.addChunk(buf[:n])
		}

		header, body, err = r.frame.process()
		// Loop around on AlertWouldBlock to see if some
		// data is now available.
		if err != nil && err != AlertWouldBlock {
			return nil, err
		}
	}

	pt := &TLSPlaintext{}

	// Validate content type
	switch RecordType(header[0]) {
	default:
		return nil, fmt.Errorf("tls.record: Unknown content type %02x", header[0])
	case RecordTypeAlert, RecordTypeHandshake, RecordTypeApplicationData, RecordTypeAck:
		pt.contentType = RecordType(header[0])
	}

	// Validate version
	if !allowWrongVersionNumber && (header[1] != 0x03 || header[2] != 0x01) {
		return nil, fmt.Errorf("tls.record: Invalid version %02x%02x", header[1], header[2])
	}

	// Validate size < max
	size := (int(header[len(header)-2]) << 8) + int(header[len(header)-1])
	if size > maxFragmentLen+256 {
		return nil, fmt.Errorf("tls.record: Ciphertext size too big")
	}

	pt.fragment = make([]byte, size)
	copy(pt.fragment, body)

	// TODO(ekr@rtfm.com): Enforce that for epoch > 0, the content type is app data.

	// Attempt to decrypt fragment
	seq := cipher.seq
	if r.datagram {
		// TODO(ekr@rtfm.com): Handle duplicates.
		seq, _ = decodeUint(header[3:11], 8)
		epoch := Epoch(seq >> 48)

		// Look up the cipher suite from the epoch
		c, ok := r.readCiphers[epoch]
		if !ok {
			logf(logTypeIO, "%s Message from unknown epoch: [%v]", r.label, epoch)
			return nil, AlertWouldBlock
		}

		if epoch != cipher.epoch {
			logf(logTypeIO, "%s Message from non-current epoch: [%v != %v] out-of-epoch reads=%v", r.label, epoch,
				cipher.epoch, allowOldEpoch)
			if !allowOldEpoch {
				return nil, AlertWouldBlock
			}
			cipher = c
		}
	}

	if cipher.cipher != nil {
		logf(logTypeIO, "%s RecordLayer.ReadRecord epoch=[%s] seq=[%x] [%d] ciphertext=[%x]", r.label, cipher.epoch.label(), seq, pt.contentType, pt.fragment)
		// decrypt always works on r.cipher. For an accepted old-epoch DTLS
		// record the selected state differs from r.cipher. Point r.cipher at
		// the selected state for this one call so the record is opened with
		// its own epoch's keys, then restore the current state.
		current := r.cipher
		r.cipher = cipher
		pt, _, err = r.decrypt(seq, header, pt)
		r.cipher = current
		if err != nil {
			logf(logTypeIO, "%s Decryption failed", r.label)
			return nil, err
		}
	}
	pt.epoch = cipher.epoch

	// Check that plaintext length is not too long
	if len(pt.fragment) > maxFragmentLen {
		return nil, fmt.Errorf("tls.record: Plaintext size too big")
	}

	logf(logTypeIO, "%s RecordLayer.ReadRecord [%d] [%x]", r.label, pt.contentType, pt.fragment)

	r.cachedRecord = pt
	cipher.incrementSequenceNumber()
	return pt, nil
}
// WriteRecord sends pt as one record protected by the current cipher state,
// without padding.
func (r *RecordLayerImpl) WriteRecord(pt *TLSPlaintext) error {
	const noPadding = 0
	return r.writeRecordWithPadding(pt, r.cipher, noPadding)
}
// WriteRecordWithPadding sends pt as one record protected by the current
// cipher state, with padLen zero bytes of padding inside the encryption.
// Padding is rejected for unprotected records.
func (r *RecordLayerImpl) WriteRecordWithPadding(pt *TLSPlaintext, padLen int) error {
	current := r.cipher
	return r.writeRecordWithPadding(pt, current, padLen)
}
// writeRecordWithPadding frames pt as a single record protected by cipher and
// writes it to the connection.
//
// With an AEAD installed, the outer content type is application_data. The
// real type and padLen zero bytes are encrypted along with the fragment.
// Without one, pt is sent in the clear and padLen must be 0. TLS records get
// a 5-byte header carrying r.version. DTLS records get a 13-byte header
// carrying the converted DTLS version and the epoch-qualified sequence number.
// The cipher's sequence number is advanced before the write is attempted,
// even if the write then fails.
//
// Size limits follow RFC 8446, section 5: the plaintext fragment may not
// exceed maxFragmentLen (2^14) bytes, and the protected fragment may not
// exceed maxFragmentLen+256 bytes. These are the same limits nextRecord
// enforces on the read side.
func (r *RecordLayerImpl) writeRecordWithPadding(pt *TLSPlaintext, cipher *cipherState, padLen int) error {
	seq := cipher.combineSeq(r.datagram)
	length := len(pt.fragment)
	var contentType RecordType
	if cipher.cipher != nil {
		length += 1 + padLen + cipher.cipher.Overhead()
		contentType = RecordTypeApplicationData
	} else {
		contentType = pt.contentType
	}

	var header []byte
	if !r.datagram {
		header = []byte{byte(contentType),
			byte(r.version >> 8), byte(r.version & 0xff),
			byte(length >> 8), byte(length)}
	} else {
		header = make([]byte, recordHeaderLenDTLS)
		version := dtlsConvertVersion(r.version)
		copy(header, []byte{byte(contentType),
			byte(version >> 8), byte(version & 0xff),
		})
		encodeUint(seq, 8, header[3:])
		encodeUint(uint64(length), 2, header[11:])
	}

	var ciphertext []byte
	if cipher.cipher != nil {
		logf(logTypeIO, "%s RecordLayer.WriteRecord epoch=[%s] seq=[%x] [%d] plaintext=[%x]", r.label, cipher.epoch.label(), cipher.seq, pt.contentType, pt.fragment)
		ciphertext = r.encrypt(cipher, seq, header, pt, padLen)
	} else {
		if padLen > 0 {
			return fmt.Errorf("tls.record: Padding can only be done on encrypted records")
		}
		ciphertext = pt.fragment
	}

	// The 2^14 limit applies to the plaintext. The protected fragment also
	// carries the content type, padding and AEAD tag and may be up to
	// 2^14+256 bytes. Applying 2^14 to the ciphertext would reject full-size
	// encrypted records that the reader accepts.
	if len(pt.fragment) > maxFragmentLen || len(ciphertext) > maxFragmentLen+256 {
		return fmt.Errorf("tls.record: Record size too big")
	}

	record := append(header, ciphertext...)

	logf(logTypeIO, "%s RecordLayer.WriteRecord epoch=[%s] seq=[%x] [%d] ciphertext=[%x]", r.label, cipher.epoch.label(), cipher.seq, contentType, ciphertext)

	cipher.incrementSequenceNumber()
	_, err := r.conn.Write(record)
	return err
}