551
vendor/github.com/bifurcation/mint/handshake-layer.go
generated
vendored
Normal file
551
vendor/github.com/bifurcation/mint/handshake-layer.go
generated
vendored
Normal file
@@ -0,0 +1,551 @@
|
||||
package mint
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
)
|
||||
|
||||
const (
|
||||
handshakeHeaderLenTLS = 4 // handshake message header length
|
||||
handshakeHeaderLenDTLS = 12 // handshake message header length
|
||||
maxHandshakeMessageLen = 1 << 24 // max handshake message length
|
||||
)
|
||||
|
||||
// struct {
|
||||
// HandshakeType msg_type; /* handshake type */
|
||||
// uint24 length; /* bytes in message */
|
||||
// select (HandshakeType) {
|
||||
// ...
|
||||
// } body;
|
||||
// } Handshake;
|
||||
//
|
||||
// We do the select{...} part in a different layer, so we treat the
|
||||
// actual message body as opaque:
|
||||
//
|
||||
// struct {
|
||||
// HandshakeType msg_type;
|
||||
// opaque msg<0..2^24-1>
|
||||
// } Handshake;
|
||||
//
|
||||
type HandshakeMessage struct {
|
||||
msgType HandshakeType
|
||||
seq uint32
|
||||
body []byte
|
||||
datagram bool
|
||||
offset uint32 // Used for DTLS
|
||||
length uint32
|
||||
cipher *cipherState
|
||||
}
|
||||
|
||||
// Note: This could be done with the `syntax` module, using the simplified
|
||||
// syntax as discussed above. However, since this is so simple, there's not
|
||||
// much benefit to doing so.
|
||||
// When datagram is set, we marshal this as a whole DTLS record.
|
||||
func (hm *HandshakeMessage) Marshal() []byte {
|
||||
if hm == nil {
|
||||
return []byte{}
|
||||
}
|
||||
|
||||
fragLen := len(hm.body)
|
||||
var data []byte
|
||||
|
||||
if hm.datagram {
|
||||
data = make([]byte, handshakeHeaderLenDTLS+fragLen)
|
||||
} else {
|
||||
data = make([]byte, handshakeHeaderLenTLS+fragLen)
|
||||
}
|
||||
tmp := data
|
||||
tmp = encodeUint(uint64(hm.msgType), 1, tmp)
|
||||
tmp = encodeUint(uint64(hm.length), 3, tmp)
|
||||
if hm.datagram {
|
||||
tmp = encodeUint(uint64(hm.seq), 2, tmp)
|
||||
tmp = encodeUint(uint64(hm.offset), 3, tmp)
|
||||
tmp = encodeUint(uint64(fragLen), 3, tmp)
|
||||
}
|
||||
copy(tmp, hm.body)
|
||||
return data
|
||||
}
|
||||
|
||||
func (hm HandshakeMessage) ToBody() (HandshakeMessageBody, error) {
|
||||
logf(logTypeHandshake, "HandshakeMessage.toBody [%d] [%x]", hm.msgType, hm.body)
|
||||
|
||||
var body HandshakeMessageBody
|
||||
switch hm.msgType {
|
||||
case HandshakeTypeClientHello:
|
||||
body = new(ClientHelloBody)
|
||||
case HandshakeTypeServerHello:
|
||||
body = new(ServerHelloBody)
|
||||
case HandshakeTypeEncryptedExtensions:
|
||||
body = new(EncryptedExtensionsBody)
|
||||
case HandshakeTypeCertificate:
|
||||
body = new(CertificateBody)
|
||||
case HandshakeTypeCertificateRequest:
|
||||
body = new(CertificateRequestBody)
|
||||
case HandshakeTypeCertificateVerify:
|
||||
body = new(CertificateVerifyBody)
|
||||
case HandshakeTypeFinished:
|
||||
body = &FinishedBody{VerifyDataLen: len(hm.body)}
|
||||
case HandshakeTypeNewSessionTicket:
|
||||
body = new(NewSessionTicketBody)
|
||||
case HandshakeTypeKeyUpdate:
|
||||
body = new(KeyUpdateBody)
|
||||
case HandshakeTypeEndOfEarlyData:
|
||||
body = new(EndOfEarlyDataBody)
|
||||
default:
|
||||
return body, fmt.Errorf("tls.handshakemessage: Unsupported body type")
|
||||
}
|
||||
|
||||
err := safeUnmarshal(body, hm.body)
|
||||
return body, err
|
||||
}
|
||||
|
||||
func (h *HandshakeLayer) HandshakeMessageFromBody(body HandshakeMessageBody) (*HandshakeMessage, error) {
|
||||
data, err := body.Marshal()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
m := &HandshakeMessage{
|
||||
msgType: body.Type(),
|
||||
body: data,
|
||||
seq: h.msgSeq,
|
||||
datagram: h.datagram,
|
||||
length: uint32(len(data)),
|
||||
}
|
||||
h.msgSeq++
|
||||
return m, nil
|
||||
}
|
||||
|
||||
type HandshakeLayer struct {
|
||||
ctx *HandshakeContext // The handshake we are attached to
|
||||
nonblocking bool // Should we operate in nonblocking mode
|
||||
conn *RecordLayer // Used for reading/writing records
|
||||
frame *frameReader // The buffered frame reader
|
||||
datagram bool // Is this DTLS?
|
||||
msgSeq uint32 // The DTLS message sequence number
|
||||
queued []*HandshakeMessage // In/out queue
|
||||
sent []*HandshakeMessage // Sent messages for DTLS
|
||||
recvdRecords []uint64 // Records we have received.
|
||||
maxFragmentLen int
|
||||
}
|
||||
|
||||
type handshakeLayerFrameDetails struct {
|
||||
datagram bool
|
||||
}
|
||||
|
||||
func (d handshakeLayerFrameDetails) headerLen() int {
|
||||
if d.datagram {
|
||||
return handshakeHeaderLenDTLS
|
||||
}
|
||||
return handshakeHeaderLenTLS
|
||||
}
|
||||
|
||||
func (d handshakeLayerFrameDetails) defaultReadLen() int {
|
||||
return d.headerLen() + maxFragmentLen
|
||||
}
|
||||
|
||||
func (d handshakeLayerFrameDetails) frameLen(hdr []byte) (int, error) {
|
||||
logf(logTypeIO, "Header=%x", hdr)
|
||||
// The length of this fragment (as opposed to the message)
|
||||
// is always the last three bytes for both TLS and DTLS
|
||||
val, _ := decodeUint(hdr[len(hdr)-3:], 3)
|
||||
return int(val), nil
|
||||
}
|
||||
|
||||
func NewHandshakeLayerTLS(c *HandshakeContext, r *RecordLayer) *HandshakeLayer {
|
||||
h := HandshakeLayer{}
|
||||
h.ctx = c
|
||||
h.conn = r
|
||||
h.datagram = false
|
||||
h.frame = newFrameReader(&handshakeLayerFrameDetails{false})
|
||||
h.maxFragmentLen = maxFragmentLen
|
||||
return &h
|
||||
}
|
||||
|
||||
func NewHandshakeLayerDTLS(c *HandshakeContext, r *RecordLayer) *HandshakeLayer {
|
||||
h := HandshakeLayer{}
|
||||
h.ctx = c
|
||||
h.conn = r
|
||||
h.datagram = true
|
||||
h.frame = newFrameReader(&handshakeLayerFrameDetails{true})
|
||||
h.maxFragmentLen = initialMtu // Not quite right
|
||||
return &h
|
||||
}
|
||||
|
||||
func (h *HandshakeLayer) readRecord() error {
|
||||
logf(logTypeVerbose, "Trying to read record")
|
||||
pt, err := h.conn.readRecordAnyEpoch()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
switch pt.contentType {
|
||||
case RecordTypeHandshake, RecordTypeAlert, RecordTypeAck:
|
||||
default:
|
||||
return fmt.Errorf("tls.handshakelayer: Unexpected record type %d", pt.contentType)
|
||||
}
|
||||
|
||||
if pt.contentType == RecordTypeAck {
|
||||
if !h.datagram {
|
||||
return fmt.Errorf("tls.handshakelayer: can't have ACK with TLS")
|
||||
}
|
||||
logf(logTypeIO, "read ACK")
|
||||
return h.ctx.processAck(pt.fragment)
|
||||
}
|
||||
|
||||
if pt.contentType == RecordTypeAlert {
|
||||
logf(logTypeIO, "read alert %v", pt.fragment[1])
|
||||
if len(pt.fragment) < 2 {
|
||||
h.sendAlert(AlertUnexpectedMessage)
|
||||
return io.EOF
|
||||
}
|
||||
return Alert(pt.fragment[1])
|
||||
}
|
||||
|
||||
assert(h.ctx.hIn.conn != nil)
|
||||
if pt.epoch != h.ctx.hIn.conn.cipher.epoch {
|
||||
// This is out of order but we're dropping it.
|
||||
// TODO(ekr@rtfm.com): If server, need to retransmit Finished.
|
||||
if pt.epoch == EpochClear || pt.epoch == EpochHandshakeData {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Anything else shouldn't happen.
|
||||
return AlertIllegalParameter
|
||||
}
|
||||
|
||||
h.recvdRecords = append(h.recvdRecords, pt.seq)
|
||||
h.frame.addChunk(pt.fragment)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// sendAlert sends a TLS alert message.
|
||||
func (h *HandshakeLayer) sendAlert(err Alert) error {
|
||||
tmp := make([]byte, 2)
|
||||
tmp[0] = AlertLevelError
|
||||
tmp[1] = byte(err)
|
||||
h.conn.WriteRecord(&TLSPlaintext{
|
||||
contentType: RecordTypeAlert,
|
||||
fragment: tmp},
|
||||
)
|
||||
|
||||
// closeNotify is a special case in that it isn't an error:
|
||||
if err != AlertCloseNotify {
|
||||
return &net.OpError{Op: "local error", Err: err}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (h *HandshakeLayer) noteMessageDelivered(seq uint32) {
|
||||
h.msgSeq = seq + 1
|
||||
var i int
|
||||
var m *HandshakeMessage
|
||||
for i, m = range h.queued {
|
||||
if m.seq > seq {
|
||||
break
|
||||
}
|
||||
}
|
||||
h.queued = h.queued[i:]
|
||||
}
|
||||
|
||||
func (h *HandshakeLayer) newFragmentReceived(hm *HandshakeMessage) (*HandshakeMessage, error) {
|
||||
if hm.seq < h.msgSeq {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
// TODO(ekr@rtfm.com): Send an ACK immediately if we got something
|
||||
// out of order.
|
||||
h.ctx.receivedHandshakeMessage()
|
||||
|
||||
if hm.seq == h.msgSeq && hm.offset == 0 && hm.length == uint32(len(hm.body)) {
|
||||
// TODO(ekr@rtfm.com): Check the length?
|
||||
// This is complete.
|
||||
h.noteMessageDelivered(hm.seq)
|
||||
return hm, nil
|
||||
}
|
||||
|
||||
// Now insert sorted.
|
||||
var i int
|
||||
for i = 0; i < len(h.queued); i++ {
|
||||
f := h.queued[i]
|
||||
if hm.seq < f.seq {
|
||||
break
|
||||
}
|
||||
if hm.offset < f.offset {
|
||||
break
|
||||
}
|
||||
}
|
||||
tmp := make([]*HandshakeMessage, 0, len(h.queued)+1)
|
||||
tmp = append(tmp, h.queued[:i]...)
|
||||
tmp = append(tmp, hm)
|
||||
tmp = append(tmp, h.queued[i:]...)
|
||||
h.queued = tmp
|
||||
|
||||
return h.checkMessageAvailable()
|
||||
}
|
||||
|
||||
func (h *HandshakeLayer) checkMessageAvailable() (*HandshakeMessage, error) {
|
||||
if len(h.queued) == 0 {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
hm := h.queued[0]
|
||||
if hm.seq != h.msgSeq {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
if hm.seq == h.msgSeq && hm.offset == 0 && hm.length == uint32(len(hm.body)) {
|
||||
// TODO(ekr@rtfm.com): Check the length?
|
||||
// This is complete.
|
||||
h.noteMessageDelivered(hm.seq)
|
||||
return hm, nil
|
||||
}
|
||||
|
||||
// OK, this at least might complete the message.
|
||||
end := uint32(0)
|
||||
buf := make([]byte, hm.length)
|
||||
|
||||
for _, f := range h.queued {
|
||||
// Out of fragments
|
||||
if f.seq > hm.seq {
|
||||
break
|
||||
}
|
||||
|
||||
if f.length != uint32(len(buf)) {
|
||||
return nil, fmt.Errorf("Mismatched DTLS length")
|
||||
}
|
||||
|
||||
if f.offset > end {
|
||||
break
|
||||
}
|
||||
|
||||
if f.offset+uint32(len(f.body)) > end {
|
||||
// OK, this is adding something we don't know about
|
||||
copy(buf[f.offset:], f.body)
|
||||
end = f.offset + uint32(len(f.body))
|
||||
if end == hm.length {
|
||||
h2 := *hm
|
||||
h2.offset = 0
|
||||
h2.body = buf
|
||||
h.noteMessageDelivered(hm.seq)
|
||||
return &h2, nil
|
||||
}
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
func (h *HandshakeLayer) ReadMessage() (*HandshakeMessage, error) {
|
||||
var hdr, body []byte
|
||||
var err error
|
||||
|
||||
hm, err := h.checkMessageAvailable()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if hm != nil {
|
||||
return hm, nil
|
||||
}
|
||||
for {
|
||||
logf(logTypeVerbose, "ReadMessage() buffered=%v", len(h.frame.remainder))
|
||||
if h.frame.needed() > 0 {
|
||||
logf(logTypeVerbose, "Trying to read a new record")
|
||||
err = h.readRecord()
|
||||
|
||||
if err != nil && (h.nonblocking || err != AlertWouldBlock) {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
hdr, body, err = h.frame.process()
|
||||
if err == nil {
|
||||
break
|
||||
}
|
||||
if err != nil && (h.nonblocking || err != AlertWouldBlock) {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
logf(logTypeHandshake, "read handshake message")
|
||||
|
||||
hm = &HandshakeMessage{}
|
||||
hm.msgType = HandshakeType(hdr[0])
|
||||
hm.datagram = h.datagram
|
||||
hm.body = make([]byte, len(body))
|
||||
copy(hm.body, body)
|
||||
logf(logTypeHandshake, "Read message with type: %v", hm.msgType)
|
||||
if h.datagram {
|
||||
tmp, hdr := decodeUint(hdr[1:], 3)
|
||||
hm.length = uint32(tmp)
|
||||
tmp, hdr = decodeUint(hdr, 2)
|
||||
hm.seq = uint32(tmp)
|
||||
tmp, hdr = decodeUint(hdr, 3)
|
||||
hm.offset = uint32(tmp)
|
||||
|
||||
return h.newFragmentReceived(hm)
|
||||
}
|
||||
|
||||
hm.length = uint32(len(body))
|
||||
return hm, nil
|
||||
}
|
||||
|
||||
func (h *HandshakeLayer) QueueMessage(hm *HandshakeMessage) error {
|
||||
hm.cipher = h.conn.cipher
|
||||
h.queued = append(h.queued, hm)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (h *HandshakeLayer) SendQueuedMessages() (int, error) {
|
||||
logf(logTypeHandshake, "Sending outgoing messages")
|
||||
count, err := h.WriteMessages(h.queued)
|
||||
if !h.datagram {
|
||||
h.ClearQueuedMessages()
|
||||
}
|
||||
return count, err
|
||||
}
|
||||
|
||||
func (h *HandshakeLayer) ClearQueuedMessages() {
|
||||
logf(logTypeHandshake, "Clearing outgoing hs message queue")
|
||||
h.queued = nil
|
||||
}
|
||||
|
||||
func (h *HandshakeLayer) writeFragment(hm *HandshakeMessage, start int, room int) (bool, int, error) {
|
||||
var buf []byte
|
||||
|
||||
// Figure out if we're going to want the full header or just
|
||||
// the body
|
||||
hdrlen := 0
|
||||
if hm.datagram {
|
||||
hdrlen = handshakeHeaderLenDTLS
|
||||
} else if start == 0 {
|
||||
hdrlen = handshakeHeaderLenTLS
|
||||
}
|
||||
|
||||
// Compute the amount of body we can fit in
|
||||
room -= hdrlen
|
||||
if room == 0 {
|
||||
// This works because we are doing one record per
|
||||
// message
|
||||
panic("Too short max fragment len")
|
||||
}
|
||||
bodylen := len(hm.body) - start
|
||||
if bodylen > room {
|
||||
bodylen = room
|
||||
}
|
||||
body := hm.body[start : start+bodylen]
|
||||
|
||||
// Now see if this chunk has been ACKed. This doesn't produce ideal
|
||||
// retransmission but is simple.
|
||||
if h.ctx.fragmentAcked(hm.seq, start, bodylen) {
|
||||
logf(logTypeHandshake, "Fragment %v %v(%v) already acked. Skipping", hm.seq, start, bodylen)
|
||||
return false, start + bodylen, nil
|
||||
}
|
||||
|
||||
// Encode the data.
|
||||
if hdrlen > 0 {
|
||||
hm2 := *hm
|
||||
hm2.offset = uint32(start)
|
||||
hm2.body = body
|
||||
buf = hm2.Marshal()
|
||||
hm = &hm2
|
||||
} else {
|
||||
buf = body
|
||||
}
|
||||
|
||||
if h.datagram {
|
||||
// Remember that we sent this.
|
||||
h.ctx.sentFragments = append(h.ctx.sentFragments, &SentHandshakeFragment{
|
||||
hm.seq,
|
||||
start,
|
||||
len(body),
|
||||
h.conn.cipher.combineSeq(true),
|
||||
false,
|
||||
})
|
||||
}
|
||||
return true, start + bodylen, h.conn.writeRecordWithPadding(
|
||||
&TLSPlaintext{
|
||||
contentType: RecordTypeHandshake,
|
||||
fragment: buf,
|
||||
},
|
||||
hm.cipher, 0)
|
||||
}
|
||||
|
||||
func (h *HandshakeLayer) WriteMessage(hm *HandshakeMessage) (int, error) {
|
||||
start := int(0)
|
||||
|
||||
if len(hm.body) > maxHandshakeMessageLen {
|
||||
return 0, fmt.Errorf("Tried to write a handshake message that's too long")
|
||||
}
|
||||
|
||||
written := 0
|
||||
wrote := false
|
||||
|
||||
// Always make one pass through to allow EOED (which is empty).
|
||||
for {
|
||||
var err error
|
||||
wrote, start, err = h.writeFragment(hm, start, h.maxFragmentLen)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
if wrote {
|
||||
written++
|
||||
}
|
||||
if start >= len(hm.body) {
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
return written, nil
|
||||
}
|
||||
|
||||
func (h *HandshakeLayer) WriteMessages(hms []*HandshakeMessage) (int, error) {
|
||||
written := 0
|
||||
for _, hm := range hms {
|
||||
logf(logTypeHandshake, "WriteMessage [%d] %x", hm.msgType, hm.body)
|
||||
|
||||
wrote, err := h.WriteMessage(hm)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
written += wrote
|
||||
}
|
||||
return written, nil
|
||||
}
|
||||
|
||||
func encodeUint(v uint64, size int, out []byte) []byte {
|
||||
for i := size - 1; i >= 0; i-- {
|
||||
out[i] = byte(v & 0xff)
|
||||
v >>= 8
|
||||
}
|
||||
return out[size:]
|
||||
}
|
||||
|
||||
func decodeUint(in []byte, size int) (uint64, []byte) {
|
||||
val := uint64(0)
|
||||
|
||||
for i := 0; i < size; i++ {
|
||||
val <<= 8
|
||||
val += uint64(in[i])
|
||||
}
|
||||
return val, in[size:]
|
||||
}
|
||||
|
||||
type marshalledPDU interface {
|
||||
Marshal() ([]byte, error)
|
||||
Unmarshal(data []byte) (int, error)
|
||||
}
|
||||
|
||||
func safeUnmarshal(pdu marshalledPDU, data []byte) error {
|
||||
read, err := pdu.Unmarshal(data)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if len(data) != read {
|
||||
return fmt.Errorf("Invalid encoding: Extra data not consumed")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
Reference in New Issue
Block a user