Launch
This commit is contained in:
575
internal/dnsttcore/dns/dns.go
Normal file
575
internal/dnsttcore/dns/dns.go
Normal file
@@ -0,0 +1,575 @@
|
||||
// Package dns deals with encoding and decoding DNS wire format.
|
||||
package dns
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/binary"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"strings"
|
||||
)
|
||||
|
||||
// The maximum number of DNS name compression pointers we are willing to follow.
|
||||
// Without something like this, infinite loops are possible.
|
||||
const compressionPointerLimit = 10
|
||||
|
||||
var (
|
||||
// ErrZeroLengthLabel is the error returned for names that contain a
|
||||
// zero-length label, like "example..com".
|
||||
ErrZeroLengthLabel = errors.New("name contains a zero-length label")
|
||||
|
||||
// ErrLabelTooLong is the error returned for labels that are longer than
|
||||
// 63 octets.
|
||||
ErrLabelTooLong = errors.New("name contains a label longer than 63 octets")
|
||||
|
||||
// ErrNameTooLong is the error returned for names whose encoded
|
||||
// representation is longer than 255 octets.
|
||||
ErrNameTooLong = errors.New("name is longer than 255 octets")
|
||||
|
||||
// ErrReservedLabelType is the error returned when reading a label type
|
||||
// prefix whose two most significant bits are not 00 or 11.
|
||||
ErrReservedLabelType = errors.New("reserved label type")
|
||||
|
||||
// ErrTooManyPointers is the error returned when reading a compressed
|
||||
// name that has too many compression pointers.
|
||||
ErrTooManyPointers = errors.New("too many compression pointers")
|
||||
|
||||
// ErrTrailingBytes is the error returned when bytes remain in the parse
|
||||
// buffer after parsing a message.
|
||||
ErrTrailingBytes = errors.New("trailing bytes after message")
|
||||
|
||||
// ErrIntegerOverflow is the error returned when trying to encode an
|
||||
// integer greater than 65535 into a 16-bit field.
|
||||
ErrIntegerOverflow = errors.New("integer overflow")
|
||||
)
|
||||
|
||||
const (
|
||||
// https://tools.ietf.org/html/rfc1035#section-3.2.2
|
||||
RRTypeTXT = 16
|
||||
// https://tools.ietf.org/html/rfc6891#section-6.1.1
|
||||
RRTypeOPT = 41
|
||||
|
||||
// https://tools.ietf.org/html/rfc1035#section-3.2.4
|
||||
ClassIN = 1
|
||||
|
||||
// https://tools.ietf.org/html/rfc1035#section-4.1.1
|
||||
RcodeNoError = 0 // a.k.a. NOERROR
|
||||
RcodeFormatError = 1 // a.k.a. FORMERR
|
||||
RcodeNameError = 3 // a.k.a. NXDOMAIN
|
||||
RcodeNotImplemented = 4 // a.k.a. NOTIMPL
|
||||
// https://tools.ietf.org/html/rfc6891#section-9
|
||||
ExtendedRcodeBadVers = 16 // a.k.a. BADVERS
|
||||
)
|
||||
|
||||
// Name represents a domain name, a sequence of labels each of which is 63
|
||||
// octets or less in length.
|
||||
//
|
||||
// https://tools.ietf.org/html/rfc1035#section-3.1
|
||||
type Name [][]byte
|
||||
|
||||
// NewName returns a Name from a slice of labels, after checking the labels for
|
||||
// validity. Does not include a zero-length label at the end of the slice.
|
||||
func NewName(labels [][]byte) (Name, error) {
|
||||
name := Name(labels)
|
||||
// https://tools.ietf.org/html/rfc1035#section-2.3.4
|
||||
// Various objects and parameters in the DNS have size limits.
|
||||
// labels 63 octets or less
|
||||
// names 255 octets or less
|
||||
for _, label := range labels {
|
||||
if len(label) == 0 {
|
||||
return nil, ErrZeroLengthLabel
|
||||
}
|
||||
if len(label) > 63 {
|
||||
return nil, ErrLabelTooLong
|
||||
}
|
||||
}
|
||||
// Check the total length.
|
||||
builder := newMessageBuilder()
|
||||
builder.WriteName(name)
|
||||
if len(builder.Bytes()) > 255 {
|
||||
return nil, ErrNameTooLong
|
||||
}
|
||||
return name, nil
|
||||
}
|
||||
|
||||
// ParseName returns a new Name from a string of labels separated by dots, after
|
||||
// checking the name for validity. A single dot at the end of the string is
|
||||
// ignored.
|
||||
func ParseName(s string) (Name, error) {
|
||||
b := bytes.TrimSuffix([]byte(s), []byte("."))
|
||||
if len(b) == 0 {
|
||||
// bytes.Split(b, ".") would return [""] in this case
|
||||
return NewName([][]byte{})
|
||||
} else {
|
||||
return NewName(bytes.Split(b, []byte(".")))
|
||||
}
|
||||
}
|
||||
|
||||
// String returns a reversible string representation of name. Labels are
|
||||
// separated by dots, and any bytes in a label that are outside the set
|
||||
// [0-9A-Za-z-] are replaced with a \xXX hex escape sequence.
|
||||
func (name Name) String() string {
|
||||
if len(name) == 0 {
|
||||
return "."
|
||||
}
|
||||
|
||||
var buf strings.Builder
|
||||
for i, label := range name {
|
||||
if i > 0 {
|
||||
buf.WriteByte('.')
|
||||
}
|
||||
for _, b := range label {
|
||||
if b == '-' ||
|
||||
('0' <= b && b <= '9') ||
|
||||
('A' <= b && b <= 'Z') ||
|
||||
('a' <= b && b <= 'z') {
|
||||
buf.WriteByte(b)
|
||||
} else {
|
||||
fmt.Fprintf(&buf, "\\x%02x", b)
|
||||
}
|
||||
}
|
||||
}
|
||||
return buf.String()
|
||||
}
|
||||
|
||||
// TrimSuffix returns a Name with the given suffix removed, if it was present.
|
||||
// The second return value indicates whether the suffix was present. If the
|
||||
// suffix was not present, the first return value is nil.
|
||||
func (name Name) TrimSuffix(suffix Name) (Name, bool) {
|
||||
if len(name) < len(suffix) {
|
||||
return nil, false
|
||||
}
|
||||
split := len(name) - len(suffix)
|
||||
fore, aft := name[:split], name[split:]
|
||||
for i := 0; i < len(aft); i++ {
|
||||
if !bytes.Equal(bytes.ToLower(aft[i]), bytes.ToLower(suffix[i])) {
|
||||
return nil, false
|
||||
}
|
||||
}
|
||||
return fore, true
|
||||
}
|
||||
|
||||
// Message represents a DNS message.
|
||||
//
|
||||
// https://tools.ietf.org/html/rfc1035#section-4.1
|
||||
type Message struct {
|
||||
ID uint16
|
||||
Flags uint16
|
||||
|
||||
Question []Question
|
||||
Answer []RR
|
||||
Authority []RR
|
||||
Additional []RR
|
||||
}
|
||||
|
||||
// Opcode extracts the OPCODE part of the Flags field.
|
||||
//
|
||||
// https://tools.ietf.org/html/rfc1035#section-4.1.1
|
||||
func (message *Message) Opcode() uint16 {
|
||||
return (message.Flags >> 11) & 0xf
|
||||
}
|
||||
|
||||
// Rcode extracts the RCODE part of the Flags field.
|
||||
//
|
||||
// https://tools.ietf.org/html/rfc1035#section-4.1.1
|
||||
func (message *Message) Rcode() uint16 {
|
||||
return message.Flags & 0x000f
|
||||
}
|
||||
|
||||
// Question represents an entry in the question section of a message.
|
||||
//
|
||||
// https://tools.ietf.org/html/rfc1035#section-4.1.2
|
||||
type Question struct {
|
||||
Name Name
|
||||
Type uint16
|
||||
Class uint16
|
||||
}
|
||||
|
||||
// RR represents a resource record.
|
||||
//
|
||||
// https://tools.ietf.org/html/rfc1035#section-4.1.3
|
||||
type RR struct {
|
||||
Name Name
|
||||
Type uint16
|
||||
Class uint16
|
||||
TTL uint32
|
||||
Data []byte
|
||||
}
|
||||
|
||||
// readName parses a DNS name from r. It leaves r positioned just after the
|
||||
// parsed name.
|
||||
func readName(r io.ReadSeeker) (Name, error) {
|
||||
var labels [][]byte
|
||||
// We limit the number of compression pointers we are willing to follow.
|
||||
numPointers := 0
|
||||
// If we followed any compression pointers, we must finally seek to just
|
||||
// past the first pointer.
|
||||
var seekTo int64
|
||||
loop:
|
||||
for {
|
||||
var labelType byte
|
||||
err := binary.Read(r, binary.BigEndian, &labelType)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
switch labelType & 0xc0 {
|
||||
case 0x00:
|
||||
// This is an ordinary label.
|
||||
// https://tools.ietf.org/html/rfc1035#section-3.1
|
||||
length := int(labelType & 0x3f)
|
||||
if length == 0 {
|
||||
break loop
|
||||
}
|
||||
label := make([]byte, length)
|
||||
_, err := io.ReadFull(r, label)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
labels = append(labels, label)
|
||||
case 0xc0:
|
||||
// This is a compression pointer.
|
||||
// https://tools.ietf.org/html/rfc1035#section-4.1.4
|
||||
upper := labelType & 0x3f
|
||||
var lower byte
|
||||
err := binary.Read(r, binary.BigEndian, &lower)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
offset := (uint16(upper) << 8) | uint16(lower)
|
||||
|
||||
if numPointers == 0 {
|
||||
// The first time we encounter a pointer,
|
||||
// remember our position so we can seek back to
|
||||
// it when done.
|
||||
seekTo, err = r.Seek(0, io.SeekCurrent)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
numPointers++
|
||||
if numPointers > compressionPointerLimit {
|
||||
return nil, ErrTooManyPointers
|
||||
}
|
||||
|
||||
// Follow the pointer and continue.
|
||||
_, err = r.Seek(int64(offset), io.SeekStart)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
default:
|
||||
// "The 10 and 01 combinations are reserved for future
|
||||
// use."
|
||||
return nil, ErrReservedLabelType
|
||||
}
|
||||
}
|
||||
// If we followed any pointers, then seek back to just after the first
|
||||
// one.
|
||||
if numPointers > 0 {
|
||||
_, err := r.Seek(seekTo, io.SeekStart)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
return NewName(labels)
|
||||
}
|
||||
|
||||
// readQuestion parses one entry from the Question section. It leaves r
|
||||
// positioned just after the parsed entry.
|
||||
//
|
||||
// https://tools.ietf.org/html/rfc1035#section-4.1.2
|
||||
func readQuestion(r io.ReadSeeker) (Question, error) {
|
||||
var question Question
|
||||
var err error
|
||||
question.Name, err = readName(r)
|
||||
if err != nil {
|
||||
return question, err
|
||||
}
|
||||
for _, ptr := range []*uint16{&question.Type, &question.Class} {
|
||||
err := binary.Read(r, binary.BigEndian, ptr)
|
||||
if err != nil {
|
||||
return question, err
|
||||
}
|
||||
}
|
||||
|
||||
return question, nil
|
||||
}
|
||||
|
||||
// readRR parses one resource record. It leaves r positioned just after the
|
||||
// parsed resource record.
|
||||
//
|
||||
// https://tools.ietf.org/html/rfc1035#section-4.1.3
|
||||
func readRR(r io.ReadSeeker) (RR, error) {
|
||||
var rr RR
|
||||
var err error
|
||||
rr.Name, err = readName(r)
|
||||
if err != nil {
|
||||
return rr, err
|
||||
}
|
||||
for _, ptr := range []*uint16{&rr.Type, &rr.Class} {
|
||||
err := binary.Read(r, binary.BigEndian, ptr)
|
||||
if err != nil {
|
||||
return rr, err
|
||||
}
|
||||
}
|
||||
err = binary.Read(r, binary.BigEndian, &rr.TTL)
|
||||
if err != nil {
|
||||
return rr, err
|
||||
}
|
||||
var rdLength uint16
|
||||
err = binary.Read(r, binary.BigEndian, &rdLength)
|
||||
if err != nil {
|
||||
return rr, err
|
||||
}
|
||||
rr.Data = make([]byte, rdLength)
|
||||
_, err = io.ReadFull(r, rr.Data)
|
||||
if err != nil {
|
||||
return rr, err
|
||||
}
|
||||
|
||||
return rr, nil
|
||||
}
|
||||
|
||||
// readMessage parses a complete DNS message. It leaves r positioned just after
|
||||
// the parsed message.
|
||||
func readMessage(r io.ReadSeeker) (Message, error) {
|
||||
var message Message
|
||||
|
||||
// Header section
|
||||
// https://tools.ietf.org/html/rfc1035#section-4.1.1
|
||||
var qdCount, anCount, nsCount, arCount uint16
|
||||
for _, ptr := range []*uint16{
|
||||
&message.ID, &message.Flags,
|
||||
&qdCount, &anCount, &nsCount, &arCount,
|
||||
} {
|
||||
err := binary.Read(r, binary.BigEndian, ptr)
|
||||
if err != nil {
|
||||
return message, err
|
||||
}
|
||||
}
|
||||
|
||||
// Question section
|
||||
// https://tools.ietf.org/html/rfc1035#section-4.1.2
|
||||
for i := 0; i < int(qdCount); i++ {
|
||||
question, err := readQuestion(r)
|
||||
if err != nil {
|
||||
return message, err
|
||||
}
|
||||
message.Question = append(message.Question, question)
|
||||
}
|
||||
|
||||
// Answer, Authority, and Additional sections
|
||||
// https://tools.ietf.org/html/rfc1035#section-4.1.3
|
||||
for _, rec := range []struct {
|
||||
ptr *[]RR
|
||||
count uint16
|
||||
}{
|
||||
{&message.Answer, anCount},
|
||||
{&message.Authority, nsCount},
|
||||
{&message.Additional, arCount},
|
||||
} {
|
||||
for i := 0; i < int(rec.count); i++ {
|
||||
rr, err := readRR(r)
|
||||
if err != nil {
|
||||
return message, err
|
||||
}
|
||||
*rec.ptr = append(*rec.ptr, rr)
|
||||
}
|
||||
}
|
||||
|
||||
return message, nil
|
||||
}
|
||||
|
||||
// MessageFromWireFormat parses a message from buf and returns a Message object.
|
||||
// It returns ErrTrailingBytes if there are bytes remaining in buf after parsing
|
||||
// is done.
|
||||
func MessageFromWireFormat(buf []byte) (Message, error) {
|
||||
r := bytes.NewReader(buf)
|
||||
message, err := readMessage(r)
|
||||
if err == io.EOF {
|
||||
err = io.ErrUnexpectedEOF
|
||||
} else if err == nil {
|
||||
// Check for trailing bytes.
|
||||
_, err = r.ReadByte()
|
||||
if err == io.EOF {
|
||||
err = nil
|
||||
} else if err == nil {
|
||||
err = ErrTrailingBytes
|
||||
}
|
||||
}
|
||||
return message, err
|
||||
}
|
||||
|
||||
// messageBuilder manages the state of serializing a DNS message. Its main
|
||||
// function is to keep track of names already written for the purpose of name
|
||||
// compression.
|
||||
type messageBuilder struct {
|
||||
w bytes.Buffer
|
||||
nameCache map[string]int
|
||||
}
|
||||
|
||||
// newMessageBuilder creates a new messageBuilder with an empty name cache.
|
||||
func newMessageBuilder() *messageBuilder {
|
||||
return &messageBuilder{
|
||||
nameCache: make(map[string]int),
|
||||
}
|
||||
}
|
||||
|
||||
// Bytes returns the serialized DNS message as a slice of bytes.
|
||||
func (builder *messageBuilder) Bytes() []byte {
|
||||
return builder.w.Bytes()
|
||||
}
|
||||
|
||||
// WriteName appends name to the in-progress messageBuilder, employing
|
||||
// compression pointers to previously written names if possible.
|
||||
func (builder *messageBuilder) WriteName(name Name) {
|
||||
// https://tools.ietf.org/html/rfc1035#section-3.1
|
||||
for i := range name {
|
||||
// Has this suffix already been encoded in the message?
|
||||
if ptr, ok := builder.nameCache[name[i:].String()]; ok && ptr&0x3fff == ptr {
|
||||
// If so, we can write a compression pointer.
|
||||
binary.Write(&builder.w, binary.BigEndian, uint16(0xc000|ptr))
|
||||
return
|
||||
}
|
||||
// Not cached; we must encode this label verbatim. Store a cache
|
||||
// entry pointing to the beginning of it.
|
||||
builder.nameCache[name[i:].String()] = builder.w.Len()
|
||||
length := len(name[i])
|
||||
if length == 0 || length > 63 {
|
||||
panic(length)
|
||||
}
|
||||
builder.w.WriteByte(byte(length))
|
||||
builder.w.Write(name[i])
|
||||
}
|
||||
builder.w.WriteByte(0)
|
||||
}
|
||||
|
||||
// WriteQuestion appends a Question section entry to the in-progress
|
||||
// messageBuilder.
|
||||
func (builder *messageBuilder) WriteQuestion(question *Question) {
|
||||
// https://tools.ietf.org/html/rfc1035#section-4.1.2
|
||||
builder.WriteName(question.Name)
|
||||
binary.Write(&builder.w, binary.BigEndian, question.Type)
|
||||
binary.Write(&builder.w, binary.BigEndian, question.Class)
|
||||
}
|
||||
|
||||
// WriteRR appends a resource record to the in-progress messageBuilder. It
|
||||
// returns ErrIntegerOverflow if the length of rr.Data does not fit in 16 bits.
|
||||
func (builder *messageBuilder) WriteRR(rr *RR) error {
|
||||
// https://tools.ietf.org/html/rfc1035#section-4.1.3
|
||||
builder.WriteName(rr.Name)
|
||||
binary.Write(&builder.w, binary.BigEndian, rr.Type)
|
||||
binary.Write(&builder.w, binary.BigEndian, rr.Class)
|
||||
binary.Write(&builder.w, binary.BigEndian, rr.TTL)
|
||||
rdLength := uint16(len(rr.Data))
|
||||
if int(rdLength) != len(rr.Data) {
|
||||
return ErrIntegerOverflow
|
||||
}
|
||||
binary.Write(&builder.w, binary.BigEndian, rdLength)
|
||||
builder.w.Write(rr.Data)
|
||||
return nil
|
||||
}
|
||||
|
||||
// WriteMessage appends a complete DNS message to the in-progress
|
||||
// messageBuilder. It returns ErrIntegerOverflow if the number of entries in any
|
||||
// section, or the length of the data in any resource record, does not fit in 16
|
||||
// bits.
|
||||
func (builder *messageBuilder) WriteMessage(message *Message) error {
|
||||
// Header section
|
||||
// https://tools.ietf.org/html/rfc1035#section-4.1.1
|
||||
binary.Write(&builder.w, binary.BigEndian, message.ID)
|
||||
binary.Write(&builder.w, binary.BigEndian, message.Flags)
|
||||
for _, count := range []int{
|
||||
len(message.Question),
|
||||
len(message.Answer),
|
||||
len(message.Authority),
|
||||
len(message.Additional),
|
||||
} {
|
||||
count16 := uint16(count)
|
||||
if int(count16) != count {
|
||||
return ErrIntegerOverflow
|
||||
}
|
||||
binary.Write(&builder.w, binary.BigEndian, count16)
|
||||
}
|
||||
|
||||
// Question section
|
||||
// https://tools.ietf.org/html/rfc1035#section-4.1.2
|
||||
for _, question := range message.Question {
|
||||
builder.WriteQuestion(&question)
|
||||
}
|
||||
|
||||
// Answer, Authority, and Additional sections
|
||||
// https://tools.ietf.org/html/rfc1035#section-4.1.3
|
||||
for _, rrs := range [][]RR{message.Answer, message.Authority, message.Additional} {
|
||||
for _, rr := range rrs {
|
||||
err := builder.WriteRR(&rr)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// WireFormat encodes a Message as a slice of bytes in DNS wire format. It
|
||||
// returns ErrIntegerOverflow if the number of entries in any section, or the
|
||||
// length of the data in any resource record, does not fit in 16 bits.
|
||||
func (message *Message) WireFormat() ([]byte, error) {
|
||||
builder := newMessageBuilder()
|
||||
err := builder.WriteMessage(message)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return builder.Bytes(), nil
|
||||
}
|
||||
|
||||
// DecodeRDataTXT decodes TXT-DATA (as found in the RDATA for a resource record
|
||||
// with TYPE=TXT) as a raw byte slice, by concatenating all the
|
||||
// <character-string>s it contains.
|
||||
//
|
||||
// https://tools.ietf.org/html/rfc1035#section-3.3.14
|
||||
func DecodeRDataTXT(p []byte) ([]byte, error) {
|
||||
var buf bytes.Buffer
|
||||
for {
|
||||
if len(p) == 0 {
|
||||
return nil, io.ErrUnexpectedEOF
|
||||
}
|
||||
n := int(p[0])
|
||||
p = p[1:]
|
||||
if len(p) < n {
|
||||
return nil, io.ErrUnexpectedEOF
|
||||
}
|
||||
buf.Write(p[:n])
|
||||
p = p[n:]
|
||||
if len(p) == 0 {
|
||||
break
|
||||
}
|
||||
}
|
||||
return buf.Bytes(), nil
|
||||
}
|
||||
|
||||
// EncodeRDataTXT encodes a slice of bytes as TXT-DATA, as appropriate for the
|
||||
// RDATA of a resource record with TYPE=TXT. No length restriction is enforced
|
||||
// here; that must be checked at a higher level.
|
||||
//
|
||||
// https://tools.ietf.org/html/rfc1035#section-3.3.14
|
||||
func EncodeRDataTXT(p []byte) []byte {
|
||||
// https://tools.ietf.org/html/rfc1035#section-3.3
|
||||
// https://tools.ietf.org/html/rfc1035#section-3.3.14
|
||||
// TXT data is a sequence of one or more <character-string>s, where
|
||||
// <character-string> is a length octet followed by that number of
|
||||
// octets.
|
||||
var buf bytes.Buffer
|
||||
for len(p) > 255 {
|
||||
buf.WriteByte(255)
|
||||
buf.Write(p[:255])
|
||||
p = p[255:]
|
||||
}
|
||||
// Must write here, even if len(p) == 0, because it's "*one or more*
|
||||
// <character-string>s".
|
||||
buf.WriteByte(byte(len(p)))
|
||||
buf.Write(p)
|
||||
return buf.Bytes()
|
||||
}
|
||||
592
internal/dnsttcore/dns/dns_test.go
Normal file
592
internal/dnsttcore/dns/dns_test.go
Normal file
@@ -0,0 +1,592 @@
|
||||
package dns
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"fmt"
|
||||
"io"
|
||||
"strconv"
|
||||
"strings"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func namesEqual(a, b Name) bool {
|
||||
if len(a) != len(b) {
|
||||
return false
|
||||
}
|
||||
for i := 0; i < len(a); i++ {
|
||||
if !bytes.Equal(a[i], b[i]) {
|
||||
return false
|
||||
}
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
func TestName(t *testing.T) {
|
||||
for _, test := range []struct {
|
||||
labels [][]byte
|
||||
err error
|
||||
s string
|
||||
}{
|
||||
{[][]byte{}, nil, "."},
|
||||
{[][]byte{[]byte("test")}, nil, "test"},
|
||||
{[][]byte{[]byte("a"), []byte("b"), []byte("c")}, nil, "a.b.c"},
|
||||
|
||||
{[][]byte{{}}, ErrZeroLengthLabel, ""},
|
||||
{[][]byte{[]byte("a"), {}, []byte("c")}, ErrZeroLengthLabel, ""},
|
||||
|
||||
// 63 octets.
|
||||
{[][]byte{[]byte("0123456789abcdef0123456789ABCDEF0123456789abcdef0123456789ABCDE")}, nil,
|
||||
"0123456789abcdef0123456789ABCDEF0123456789abcdef0123456789ABCDE"},
|
||||
// 64 octets.
|
||||
{[][]byte{[]byte("0123456789abcdef0123456789ABCDEF0123456789abcdef0123456789ABCDEF")}, ErrLabelTooLong, ""},
|
||||
|
||||
// 64+64+64+62 octets.
|
||||
{[][]byte{
|
||||
[]byte("0123456789abcdef0123456789ABCDEF0123456789abcdef0123456789ABCDE"),
|
||||
[]byte("0123456789abcdef0123456789ABCDEF0123456789abcdef0123456789ABCDE"),
|
||||
[]byte("0123456789abcdef0123456789ABCDEF0123456789abcdef0123456789ABCDE"),
|
||||
[]byte("0123456789abcdef0123456789ABCDEF0123456789abcdef0123456789ABC"),
|
||||
}, nil,
|
||||
"0123456789abcdef0123456789ABCDEF0123456789abcdef0123456789ABCDE.0123456789abcdef0123456789ABCDEF0123456789abcdef0123456789ABCDE.0123456789abcdef0123456789ABCDEF0123456789abcdef0123456789ABCDE.0123456789abcdef0123456789ABCDEF0123456789abcdef0123456789ABC"},
|
||||
// 64+64+64+63 octets.
|
||||
{[][]byte{
|
||||
[]byte("0123456789abcdef0123456789ABCDEF0123456789abcdef0123456789ABCDE"),
|
||||
[]byte("0123456789abcdef0123456789ABCDEF0123456789abcdef0123456789ABCDE"),
|
||||
[]byte("0123456789abcdef0123456789ABCDEF0123456789abcdef0123456789ABCDE"),
|
||||
[]byte("0123456789abcdef0123456789ABCDEF0123456789abcdef0123456789ABCD"),
|
||||
}, ErrNameTooLong, ""},
|
||||
// 127 one-octet labels.
|
||||
{[][]byte{
|
||||
{'0'}, {'1'}, {'2'}, {'3'}, {'4'}, {'5'}, {'6'}, {'7'}, {'8'}, {'9'}, {'a'}, {'b'}, {'c'}, {'d'}, {'e'}, {'f'},
|
||||
{'0'}, {'1'}, {'2'}, {'3'}, {'4'}, {'5'}, {'6'}, {'7'}, {'8'}, {'9'}, {'A'}, {'B'}, {'C'}, {'D'}, {'E'}, {'F'},
|
||||
{'0'}, {'1'}, {'2'}, {'3'}, {'4'}, {'5'}, {'6'}, {'7'}, {'8'}, {'9'}, {'a'}, {'b'}, {'c'}, {'d'}, {'e'}, {'f'},
|
||||
{'0'}, {'1'}, {'2'}, {'3'}, {'4'}, {'5'}, {'6'}, {'7'}, {'8'}, {'9'}, {'A'}, {'B'}, {'C'}, {'D'}, {'E'}, {'F'},
|
||||
{'0'}, {'1'}, {'2'}, {'3'}, {'4'}, {'5'}, {'6'}, {'7'}, {'8'}, {'9'}, {'a'}, {'b'}, {'c'}, {'d'}, {'e'}, {'f'},
|
||||
{'0'}, {'1'}, {'2'}, {'3'}, {'4'}, {'5'}, {'6'}, {'7'}, {'8'}, {'9'}, {'A'}, {'B'}, {'C'}, {'D'}, {'E'}, {'F'},
|
||||
{'0'}, {'1'}, {'2'}, {'3'}, {'4'}, {'5'}, {'6'}, {'7'}, {'8'}, {'9'}, {'a'}, {'b'}, {'c'}, {'d'}, {'e'}, {'f'},
|
||||
{'0'}, {'1'}, {'2'}, {'3'}, {'4'}, {'5'}, {'6'}, {'7'}, {'8'}, {'9'}, {'A'}, {'B'}, {'C'}, {'D'}, {'E'},
|
||||
}, nil,
|
||||
"0.1.2.3.4.5.6.7.8.9.a.b.c.d.e.f.0.1.2.3.4.5.6.7.8.9.A.B.C.D.E.F.0.1.2.3.4.5.6.7.8.9.a.b.c.d.e.f.0.1.2.3.4.5.6.7.8.9.A.B.C.D.E.F.0.1.2.3.4.5.6.7.8.9.a.b.c.d.e.f.0.1.2.3.4.5.6.7.8.9.A.B.C.D.E.F.0.1.2.3.4.5.6.7.8.9.a.b.c.d.e.f.0.1.2.3.4.5.6.7.8.9.A.B.C.D.E"},
|
||||
// 128 one-octet labels.
|
||||
{[][]byte{
|
||||
{'0'}, {'1'}, {'2'}, {'3'}, {'4'}, {'5'}, {'6'}, {'7'}, {'8'}, {'9'}, {'a'}, {'b'}, {'c'}, {'d'}, {'e'}, {'f'},
|
||||
{'0'}, {'1'}, {'2'}, {'3'}, {'4'}, {'5'}, {'6'}, {'7'}, {'8'}, {'9'}, {'A'}, {'B'}, {'C'}, {'D'}, {'E'}, {'F'},
|
||||
{'0'}, {'1'}, {'2'}, {'3'}, {'4'}, {'5'}, {'6'}, {'7'}, {'8'}, {'9'}, {'a'}, {'b'}, {'c'}, {'d'}, {'e'}, {'f'},
|
||||
{'0'}, {'1'}, {'2'}, {'3'}, {'4'}, {'5'}, {'6'}, {'7'}, {'8'}, {'9'}, {'A'}, {'B'}, {'C'}, {'D'}, {'E'}, {'F'},
|
||||
{'0'}, {'1'}, {'2'}, {'3'}, {'4'}, {'5'}, {'6'}, {'7'}, {'8'}, {'9'}, {'a'}, {'b'}, {'c'}, {'d'}, {'e'}, {'f'},
|
||||
{'0'}, {'1'}, {'2'}, {'3'}, {'4'}, {'5'}, {'6'}, {'7'}, {'8'}, {'9'}, {'A'}, {'B'}, {'C'}, {'D'}, {'E'}, {'F'},
|
||||
{'0'}, {'1'}, {'2'}, {'3'}, {'4'}, {'5'}, {'6'}, {'7'}, {'8'}, {'9'}, {'a'}, {'b'}, {'c'}, {'d'}, {'e'}, {'f'},
|
||||
{'0'}, {'1'}, {'2'}, {'3'}, {'4'}, {'5'}, {'6'}, {'7'}, {'8'}, {'9'}, {'A'}, {'B'}, {'C'}, {'D'}, {'E'}, {'F'},
|
||||
}, ErrNameTooLong, ""},
|
||||
} {
|
||||
// Test that NewName returns proper error codes, and otherwise
|
||||
// returns an equal slice of labels.
|
||||
name, err := NewName(test.labels)
|
||||
if err != test.err || (err == nil && !namesEqual(name, test.labels)) {
|
||||
t.Errorf("%+q returned (%+q, %v), expected (%+q, %v)",
|
||||
test.labels, name, err, test.labels, test.err)
|
||||
continue
|
||||
}
|
||||
if test.err != nil {
|
||||
continue
|
||||
}
|
||||
|
||||
// Test that the string version of the name comes out as
|
||||
// expected.
|
||||
s := name.String()
|
||||
if s != test.s {
|
||||
t.Errorf("%+q became string %+q, expected %+q", test.labels, s, test.s)
|
||||
continue
|
||||
}
|
||||
|
||||
// Test that parsing from a string back to a Name results in the
|
||||
// original slice of labels.
|
||||
name, err = ParseName(s)
|
||||
if err != nil || !namesEqual(name, test.labels) {
|
||||
t.Errorf("%+q parsing %+q returned (%+q, %v), expected (%+q, %v)",
|
||||
test.labels, s, name, err, test.labels, nil)
|
||||
continue
|
||||
}
|
||||
// A trailing dot should be ignored.
|
||||
if !strings.HasSuffix(s, ".") {
|
||||
dotName, dotErr := ParseName(s + ".")
|
||||
if dotErr != err || !namesEqual(dotName, name) {
|
||||
t.Errorf("%+q parsing %+q returned (%+q, %v), expected (%+q, %v)",
|
||||
test.labels, s+".", dotName, dotErr, name, err)
|
||||
continue
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseName(t *testing.T) {
|
||||
for _, test := range []struct {
|
||||
s string
|
||||
name Name
|
||||
err error
|
||||
}{
|
||||
// This case can't be tested by TestName above because String
|
||||
// will never produce "" (it produces "." instead).
|
||||
{"", [][]byte{}, nil},
|
||||
} {
|
||||
name, err := ParseName(test.s)
|
||||
if err != test.err || (err == nil && !namesEqual(name, test.name)) {
|
||||
t.Errorf("%+q returned (%+q, %v), expected (%+q, %v)",
|
||||
test.s, name, err, test.name, test.err)
|
||||
continue
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func unescapeString(s string) ([][]byte, error) {
|
||||
if s == "." {
|
||||
return [][]byte{}, nil
|
||||
}
|
||||
|
||||
var result [][]byte
|
||||
for _, label := range strings.Split(s, ".") {
|
||||
var buf bytes.Buffer
|
||||
i := 0
|
||||
for i < len(label) {
|
||||
switch label[i] {
|
||||
case '\\':
|
||||
if i+3 >= len(label) {
|
||||
return nil, fmt.Errorf("truncated escape sequence at index %v", i)
|
||||
}
|
||||
if label[i+1] != 'x' {
|
||||
return nil, fmt.Errorf("malformed escape sequence at index %v", i)
|
||||
}
|
||||
b, err := strconv.ParseUint(string(label[i+2:i+4]), 16, 8)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("malformed hex sequence at index %v", i+2)
|
||||
}
|
||||
buf.WriteByte(byte(b))
|
||||
i += 4
|
||||
default:
|
||||
buf.WriteByte(label[i])
|
||||
i++
|
||||
}
|
||||
}
|
||||
result = append(result, buf.Bytes())
|
||||
}
|
||||
return result, nil
|
||||
}
|
||||
|
||||
func TestNameString(t *testing.T) {
|
||||
for _, test := range []struct {
|
||||
name Name
|
||||
s string
|
||||
}{
|
||||
{[][]byte{}, "."},
|
||||
{[][]byte{[]byte("\x00"), []byte("a.b"), []byte("c\nd\\")}, "\\x00.a\\x2eb.c\\x0ad\\x5c"},
|
||||
{[][]byte{
|
||||
[]byte("\x00\x01\x02\x03\x04\x05\x06\x07\x08\t\n\x0b\x0c\r\x0e\x0f\x10\x11\x12\x13\x14\x15\x16\x17\x18\x19\x1a\x1b\x1c\x1d\x1e\x1f !\"#$%&'()*+,-./0123456789:;<=>"),
|
||||
[]byte("?@ABCDEFGHIJKLMNOPQRSTUVWXYZ[\\]^_`abcdefghijklmnopqrstuvwxyz{|}"),
|
||||
[]byte("~\x7f\x80\x81\x82\x83\x84\x85\x86\x87\x88\x89\x8a\x8b\x8c\x8d\x8e\x8f\x90\x91\x92\x93\x94\x95\x96\x97\x98\x99\x9a\x9b\x9c\x9d\x9e\x9f\xa0\xa1\xa2\xa3\xa4\xa5\xa6\xa7\xa8\xa9\xaa\xab\xac\xad\xae\xaf\xb0\xb1\xb2\xb3\xb4\xb5\xb6\xb7\xb8\xb9\xba\xbb\xbc"),
|
||||
[]byte("\xbd\xbe\xbf\xc0\xc1\xc2\xc3\xc4\xc5\xc6\xc7\xc8\xc9\xca\xcb\xcc\xcd\xce\xcf\xd0\xd1\xd2\xd3\xd4\xd5\xd6\xd7\xd8\xd9\xda\xdb\xdc\xdd\xde\xdf\xe0\xe1\xe2\xe3\xe4\xe5\xe6\xe7\xe8\xe9\xea\xeb\xec\xed\xee\xef\xf0\xf1\xf2\xf3\xf4\xf5\xf6\xf7\xf8\xf9\xfa\xfb"),
|
||||
[]byte("\xfc\xfd\xfe\xff"),
|
||||
}, "\\x00\\x01\\x02\\x03\\x04\\x05\\x06\\x07\\x08\\x09\\x0a\\x0b\\x0c\\x0d\\x0e\\x0f\\x10\\x11\\x12\\x13\\x14\\x15\\x16\\x17\\x18\\x19\\x1a\\x1b\\x1c\\x1d\\x1e\\x1f\\x20\\x21\\x22\\x23\\x24\\x25\\x26\\x27\\x28\\x29\\x2a\\x2b\\x2c-\\x2e\\x2f0123456789\\x3a\\x3b\\x3c\\x3d\\x3e.\\x3f\\x40ABCDEFGHIJKLMNOPQRSTUVWXYZ\\x5b\\x5c\\x5d\\x5e\\x5f\\x60abcdefghijklmnopqrstuvwxyz\\x7b\\x7c\\x7d.\\x7e\\x7f\\x80\\x81\\x82\\x83\\x84\\x85\\x86\\x87\\x88\\x89\\x8a\\x8b\\x8c\\x8d\\x8e\\x8f\\x90\\x91\\x92\\x93\\x94\\x95\\x96\\x97\\x98\\x99\\x9a\\x9b\\x9c\\x9d\\x9e\\x9f\\xa0\\xa1\\xa2\\xa3\\xa4\\xa5\\xa6\\xa7\\xa8\\xa9\\xaa\\xab\\xac\\xad\\xae\\xaf\\xb0\\xb1\\xb2\\xb3\\xb4\\xb5\\xb6\\xb7\\xb8\\xb9\\xba\\xbb\\xbc.\\xbd\\xbe\\xbf\\xc0\\xc1\\xc2\\xc3\\xc4\\xc5\\xc6\\xc7\\xc8\\xc9\\xca\\xcb\\xcc\\xcd\\xce\\xcf\\xd0\\xd1\\xd2\\xd3\\xd4\\xd5\\xd6\\xd7\\xd8\\xd9\\xda\\xdb\\xdc\\xdd\\xde\\xdf\\xe0\\xe1\\xe2\\xe3\\xe4\\xe5\\xe6\\xe7\\xe8\\xe9\\xea\\xeb\\xec\\xed\\xee\\xef\\xf0\\xf1\\xf2\\xf3\\xf4\\xf5\\xf6\\xf7\\xf8\\xf9\\xfa\\xfb.\\xfc\\xfd\\xfe\\xff"},
|
||||
} {
|
||||
s := test.name.String()
|
||||
if s != test.s {
|
||||
t.Errorf("%+q escaped to %+q, expected %+q", test.name, s, test.s)
|
||||
continue
|
||||
}
|
||||
unescaped, err := unescapeString(s)
|
||||
if err != nil {
|
||||
t.Errorf("%+q unescaping %+q resulted in error %v", test.name, s, err)
|
||||
continue
|
||||
}
|
||||
if !namesEqual(Name(unescaped), test.name) {
|
||||
t.Errorf("%+q roundtripped through %+q to %+q", test.name, s, unescaped)
|
||||
continue
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestNameTrimSuffix(t *testing.T) {
|
||||
for _, test := range []struct {
|
||||
name, suffix string
|
||||
trimmed string
|
||||
ok bool
|
||||
}{
|
||||
{"", "", ".", true},
|
||||
{".", ".", ".", true},
|
||||
{"abc", "", "abc", true},
|
||||
{"abc", ".", "abc", true},
|
||||
{"", "abc", ".", false},
|
||||
{".", "abc", ".", false},
|
||||
{"example.com", "com", "example", true},
|
||||
{"example.com", "net", ".", false},
|
||||
{"example.com", "example.com", ".", true},
|
||||
{"example.com", "test.com", ".", false},
|
||||
{"example.com", "xample.com", ".", false},
|
||||
{"example.com", "example", ".", false},
|
||||
{"example.com", "COM", "example", true},
|
||||
{"EXAMPLE.COM", "com", "EXAMPLE", true},
|
||||
} {
|
||||
tmp, ok := mustParseName(test.name).TrimSuffix(mustParseName(test.suffix))
|
||||
trimmed := tmp.String()
|
||||
if ok != test.ok || trimmed != test.trimmed {
|
||||
t.Errorf("TrimSuffix %+q %+q returned (%+q, %v), expected (%+q, %v)",
|
||||
test.name, test.suffix, trimmed, ok, test.trimmed, test.ok)
|
||||
continue
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestReadName(t *testing.T) {
|
||||
// Good tests.
|
||||
for _, test := range []struct {
|
||||
start int64
|
||||
end int64
|
||||
input string
|
||||
s string
|
||||
}{
|
||||
// Empty name.
|
||||
{0, 1, "\x00abcd", "."},
|
||||
// No pointers.
|
||||
{12, 25, "AAAABBBBCCCC\x07example\x03com\x00", "example.com"},
|
||||
// Backward pointer.
|
||||
{25, 31, "AAAABBBBCCCC\x07example\x03com\x00\x03sub\xc0\x0c", "sub.example.com"},
|
||||
// Forward pointer.
|
||||
{0, 4, "\x01a\xc0\x04\x03bcd\x00", "a.bcd"},
|
||||
// Two backwards pointers.
|
||||
{31, 38, "AAAABBBBCCCC\x07example\x03com\x00\x03sub\xc0\x0c\x04sub2\xc0\x19", "sub2.sub.example.com"},
|
||||
// Forward then backward pointer.
|
||||
{25, 31, "AAAABBBBCCCC\x07example\x03com\x00\x03sub\xc0\x1f\x04sub2\xc0\x0c", "sub.sub2.example.com"},
|
||||
// Overlapping codons.
|
||||
{0, 4, "\x01a\xc0\x03bcd\x00", "a.bcd"},
|
||||
// Pointer to empty label.
|
||||
{0, 10, "\x07example\xc0\x0a\x00", "example"},
|
||||
{1, 11, "\x00\x07example\xc0\x00", "example"},
|
||||
// Pointer to pointer to empty label.
|
||||
{0, 10, "\x07example\xc0\x0a\xc0\x0c\x00", "example"},
|
||||
{1, 11, "\x00\x07example\xc0\x0c\xc0\x00", "example"},
|
||||
} {
|
||||
r := bytes.NewReader([]byte(test.input))
|
||||
_, err := r.Seek(test.start, io.SeekStart)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
name, err := readName(r)
|
||||
if err != nil {
|
||||
t.Errorf("%+q returned error %s", test.input, err)
|
||||
continue
|
||||
}
|
||||
s := name.String()
|
||||
if s != test.s {
|
||||
t.Errorf("%+q returned %+q, expected %+q", test.input, s, test.s)
|
||||
continue
|
||||
}
|
||||
cur, _ := r.Seek(0, io.SeekCurrent)
|
||||
if cur != test.end {
|
||||
t.Errorf("%+q left offset %d, expected %d", test.input, cur, test.end)
|
||||
continue
|
||||
}
|
||||
}
|
||||
|
||||
// Bad tests.
|
||||
for _, test := range []struct {
|
||||
start int64
|
||||
input string
|
||||
err error
|
||||
}{
|
||||
{0, "", io.ErrUnexpectedEOF},
|
||||
// Reserved label type.
|
||||
{0, "\x80example", ErrReservedLabelType},
|
||||
// Reserved label type.
|
||||
{0, "\x40example", ErrReservedLabelType},
|
||||
// No Terminating empty label.
|
||||
{0, "\x07example\x03com", io.ErrUnexpectedEOF},
|
||||
// Pointer past end of buffer.
|
||||
{0, "\x07example\xc0\xff", io.ErrUnexpectedEOF},
|
||||
// Pointer to self.
|
||||
{0, "\x07example\x03com\xc0\x0c", ErrTooManyPointers},
|
||||
// Pointer to self with intermediate label.
|
||||
{0, "\x07example\x03com\xc0\x08", ErrTooManyPointers},
|
||||
// Two pointers that point to each other.
|
||||
{0, "\xc0\x02\xc0\x00", ErrTooManyPointers},
|
||||
// Two pointers that point to each other, with intermediate labels.
|
||||
{0, "\x01a\xc0\x04\x01b\xc0\x00", ErrTooManyPointers},
|
||||
// EOF while reading label.
|
||||
{0, "\x0aexample", io.ErrUnexpectedEOF},
|
||||
// EOF before second byte of pointer.
|
||||
{0, "\xc0", io.ErrUnexpectedEOF},
|
||||
{0, "\x07example\xc0", io.ErrUnexpectedEOF},
|
||||
} {
|
||||
r := bytes.NewReader([]byte(test.input))
|
||||
_, err := r.Seek(test.start, io.SeekStart)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
name, err := readName(r)
|
||||
if err == io.EOF {
|
||||
err = io.ErrUnexpectedEOF
|
||||
}
|
||||
if err != test.err {
|
||||
t.Errorf("%+q returned (%+q, %v), expected %v", test.input, name, err, test.err)
|
||||
continue
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func mustParseName(s string) Name {
|
||||
name, err := ParseName(s)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
return name
|
||||
}
|
||||
|
||||
func questionsEqual(a, b *Question) bool {
|
||||
if !namesEqual(a.Name, b.Name) {
|
||||
return false
|
||||
}
|
||||
if a.Type != b.Type || a.Class != b.Class {
|
||||
return false
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
func rrsEqual(a, b *RR) bool {
|
||||
if !namesEqual(a.Name, b.Name) {
|
||||
return false
|
||||
}
|
||||
if a.Type != b.Type || a.Class != b.Class || a.TTL != b.TTL {
|
||||
return false
|
||||
}
|
||||
if !bytes.Equal(a.Data, b.Data) {
|
||||
return false
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
func messagesEqual(a, b *Message) bool {
|
||||
if a.ID != b.ID || a.Flags != b.Flags {
|
||||
return false
|
||||
}
|
||||
if len(a.Question) != len(b.Question) {
|
||||
return false
|
||||
}
|
||||
for i := 0; i < len(a.Question); i++ {
|
||||
if !questionsEqual(&a.Question[i], &b.Question[i]) {
|
||||
return false
|
||||
}
|
||||
}
|
||||
for _, rec := range []struct{ rrA, rrB []RR }{
|
||||
{a.Answer, b.Answer},
|
||||
{a.Authority, b.Authority},
|
||||
{a.Additional, b.Additional},
|
||||
} {
|
||||
if len(rec.rrA) != len(rec.rrB) {
|
||||
return false
|
||||
}
|
||||
for i := 0; i < len(rec.rrA); i++ {
|
||||
if !rrsEqual(&rec.rrA[i], &rec.rrB[i]) {
|
||||
return false
|
||||
}
|
||||
}
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
func TestMessageFromWireFormat(t *testing.T) {
|
||||
for _, test := range []struct {
|
||||
buf string
|
||||
expected Message
|
||||
err error
|
||||
}{
|
||||
{
|
||||
"\x12\x34",
|
||||
Message{},
|
||||
io.ErrUnexpectedEOF,
|
||||
},
|
||||
{
|
||||
"\x12\x34\x01\x00\x00\x01\x00\x00\x00\x00\x00\x00\x03www\x07example\x03com\x00\x00\x01\x00\x01",
|
||||
Message{
|
||||
ID: 0x1234,
|
||||
Flags: 0x0100,
|
||||
Question: []Question{
|
||||
{
|
||||
Name: mustParseName("www.example.com"),
|
||||
Type: 1,
|
||||
Class: 1,
|
||||
},
|
||||
},
|
||||
Answer: []RR{},
|
||||
Authority: []RR{},
|
||||
Additional: []RR{},
|
||||
},
|
||||
nil,
|
||||
},
|
||||
{
|
||||
"\x12\x34\x01\x00\x00\x01\x00\x00\x00\x00\x00\x00\x03www\x07example\x03com\x00\x00\x01\x00\x01X",
|
||||
Message{},
|
||||
ErrTrailingBytes,
|
||||
},
|
||||
{
|
||||
"\x12\x34\x81\x80\x00\x01\x00\x01\x00\x00\x00\x00\x03www\x07example\x03com\x00\x00\x01\x00\x01\x03www\x07example\x03com\x00\x00\x01\x00\x01\x00\x00\x00\x80\x00\x04\xc0\x00\x02\x01",
|
||||
Message{
|
||||
ID: 0x1234,
|
||||
Flags: 0x8180,
|
||||
Question: []Question{
|
||||
{
|
||||
Name: mustParseName("www.example.com"),
|
||||
Type: 1,
|
||||
Class: 1,
|
||||
},
|
||||
},
|
||||
Answer: []RR{
|
||||
{
|
||||
Name: mustParseName("www.example.com"),
|
||||
Type: 1,
|
||||
Class: 1,
|
||||
TTL: 128,
|
||||
Data: []byte{192, 0, 2, 1},
|
||||
},
|
||||
},
|
||||
Authority: []RR{},
|
||||
Additional: []RR{},
|
||||
},
|
||||
nil,
|
||||
},
|
||||
} {
|
||||
message, err := MessageFromWireFormat([]byte(test.buf))
|
||||
if err != test.err || (err == nil && !messagesEqual(&message, &test.expected)) {
|
||||
t.Errorf("%+q\nreturned (%+v, %v)\nexpected (%+v, %v)",
|
||||
test.buf, message, err, test.expected, test.err)
|
||||
continue
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestMessageWireFormatRoundTrip(t *testing.T) {
|
||||
for _, message := range []Message{
|
||||
{
|
||||
ID: 0x1234,
|
||||
Flags: 0x0100,
|
||||
Question: []Question{
|
||||
{
|
||||
Name: mustParseName("www.example.com"),
|
||||
Type: 1,
|
||||
Class: 1,
|
||||
},
|
||||
{
|
||||
Name: mustParseName("www2.example.com"),
|
||||
Type: 2,
|
||||
Class: 2,
|
||||
},
|
||||
},
|
||||
Answer: []RR{
|
||||
{
|
||||
Name: mustParseName("abc"),
|
||||
Type: 2,
|
||||
Class: 3,
|
||||
TTL: 0xffffffff,
|
||||
Data: []byte{1},
|
||||
},
|
||||
{
|
||||
Name: mustParseName("xyz"),
|
||||
Type: 2,
|
||||
Class: 3,
|
||||
TTL: 255,
|
||||
Data: []byte{},
|
||||
},
|
||||
},
|
||||
Authority: []RR{
|
||||
{
|
||||
Name: mustParseName("."),
|
||||
Type: 65535,
|
||||
Class: 65535,
|
||||
TTL: 0,
|
||||
Data: []byte("XXXXXXXXXXXXXXXXXXX"),
|
||||
},
|
||||
},
|
||||
Additional: []RR{},
|
||||
},
|
||||
} {
|
||||
buf, err := message.WireFormat()
|
||||
if err != nil {
|
||||
t.Errorf("%+v cannot make wire format: %v", message, err)
|
||||
continue
|
||||
}
|
||||
message2, err := MessageFromWireFormat(buf)
|
||||
if err != nil {
|
||||
t.Errorf("%+q cannot parse wire format: %v", buf, err)
|
||||
continue
|
||||
}
|
||||
if !messagesEqual(&message, &message2) {
|
||||
t.Errorf("messages unequal\nbefore: %+v\n after: %+v", message, message2)
|
||||
continue
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestDecodeRDataTXT(t *testing.T) {
|
||||
for _, test := range []struct {
|
||||
p []byte
|
||||
decoded []byte
|
||||
err error
|
||||
}{
|
||||
{[]byte{}, nil, io.ErrUnexpectedEOF},
|
||||
{[]byte("\x00"), []byte{}, nil},
|
||||
{[]byte("\x01"), nil, io.ErrUnexpectedEOF},
|
||||
} {
|
||||
decoded, err := DecodeRDataTXT(test.p)
|
||||
if err != test.err || (err == nil && !bytes.Equal(decoded, test.decoded)) {
|
||||
t.Errorf("%+q\nreturned (%+q, %v)\nexpected (%+q, %v)",
|
||||
test.p, decoded, err, test.decoded, test.err)
|
||||
continue
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestEncodeRDataTXT(t *testing.T) {
|
||||
// Encoding 0 bytes needs to return at least a single length octet of
|
||||
// zero, not an empty slice.
|
||||
p := make([]byte, 0)
|
||||
encoded := EncodeRDataTXT(p)
|
||||
if len(encoded) < 0 {
|
||||
t.Errorf("EncodeRDataTXT(%v) returned %v", p, encoded)
|
||||
}
|
||||
|
||||
// 255 bytes should be able to be encoded into 256 bytes.
|
||||
p = make([]byte, 255)
|
||||
encoded = EncodeRDataTXT(p)
|
||||
if len(encoded) > 256 {
|
||||
t.Errorf("EncodeRDataTXT(%d bytes) returned %d bytes", len(p), len(encoded))
|
||||
}
|
||||
}
|
||||
|
||||
func TestRDataTXTRoundTrip(t *testing.T) {
|
||||
for _, p := range [][]byte{
|
||||
{},
|
||||
[]byte("\x00"),
|
||||
{
|
||||
0x00, 0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, 0x08, 0x09, 0x0a, 0x0b, 0x0c, 0x0d, 0x0e, 0x0f,
|
||||
0x10, 0x11, 0x12, 0x13, 0x14, 0x15, 0x16, 0x17, 0x18, 0x19, 0x1a, 0x1b, 0x1c, 0x1d, 0x1e, 0x1f,
|
||||
0x20, 0x21, 0x22, 0x23, 0x24, 0x25, 0x26, 0x27, 0x28, 0x29, 0x2a, 0x2b, 0x2c, 0x2d, 0x2e, 0x2f,
|
||||
0x30, 0x31, 0x32, 0x33, 0x34, 0x35, 0x36, 0x37, 0x38, 0x39, 0x3a, 0x3b, 0x3c, 0x3d, 0x3e, 0x3f,
|
||||
0x40, 0x41, 0x42, 0x43, 0x44, 0x45, 0x46, 0x47, 0x48, 0x49, 0x4a, 0x4b, 0x4c, 0x4d, 0x4e, 0x4f,
|
||||
0x50, 0x51, 0x52, 0x53, 0x54, 0x55, 0x56, 0x57, 0x58, 0x59, 0x5a, 0x5b, 0x5c, 0x5d, 0x5e, 0x5f,
|
||||
0x60, 0x61, 0x62, 0x63, 0x64, 0x65, 0x66, 0x67, 0x68, 0x69, 0x6a, 0x6b, 0x6c, 0x6d, 0x6e, 0x6f,
|
||||
0x70, 0x71, 0x72, 0x73, 0x74, 0x75, 0x76, 0x77, 0x78, 0x79, 0x7a, 0x7b, 0x7c, 0x7d, 0x7e, 0x7f,
|
||||
0x80, 0x81, 0x82, 0x83, 0x84, 0x85, 0x86, 0x87, 0x88, 0x89, 0x8a, 0x8b, 0x8c, 0x8d, 0x8e, 0x8f,
|
||||
0x90, 0x91, 0x92, 0x93, 0x94, 0x95, 0x96, 0x97, 0x98, 0x99, 0x9a, 0x9b, 0x9c, 0x9d, 0x9e, 0x9f,
|
||||
0xa0, 0xa1, 0xa2, 0xa3, 0xa4, 0xa5, 0xa6, 0xa7, 0xa8, 0xa9, 0xaa, 0xab, 0xac, 0xad, 0xae, 0xaf,
|
||||
0xb0, 0xb1, 0xb2, 0xb3, 0xb4, 0xb5, 0xb6, 0xb7, 0xb8, 0xb9, 0xba, 0xbb, 0xbc, 0xbd, 0xbe, 0xbf,
|
||||
0xc0, 0xc1, 0xc2, 0xc3, 0xc4, 0xc5, 0xc6, 0xc7, 0xc8, 0xc9, 0xca, 0xcb, 0xcc, 0xcd, 0xce, 0xcf,
|
||||
0xd0, 0xd1, 0xd2, 0xd3, 0xd4, 0xd5, 0xd6, 0xd7, 0xd8, 0xd9, 0xda, 0xdb, 0xdc, 0xdd, 0xde, 0xdf,
|
||||
0xe0, 0xe1, 0xe2, 0xe3, 0xe4, 0xe5, 0xe6, 0xe7, 0xe8, 0xe9, 0xea, 0xeb, 0xec, 0xed, 0xee, 0xef,
|
||||
0xf0, 0xf1, 0xf2, 0xf3, 0xf4, 0xf5, 0xf6, 0xf7, 0xf8, 0xf9, 0xfa, 0xfb, 0xfc, 0xfd, 0xfe, 0xff,
|
||||
},
|
||||
} {
|
||||
rdata := EncodeRDataTXT(p)
|
||||
decoded, err := DecodeRDataTXT(rdata)
|
||||
if err != nil || !bytes.Equal(decoded, p) {
|
||||
t.Errorf("%+q returned (%+q, %v)", p, decoded, err)
|
||||
continue
|
||||
}
|
||||
}
|
||||
}
|
||||
23
internal/dnsttcore/dns/fuzz.go
Normal file
23
internal/dnsttcore/dns/fuzz.go
Normal file
@@ -0,0 +1,23 @@
|
||||
//go:build gofuzz
|
||||
// +build gofuzz
|
||||
|
||||
// Fuzzing driver for https://github.com/dvyukov/go-fuzz.
|
||||
// go get -u github.com/dvyukov/go-fuzz/go-fuzz github.com/dvyukov/go-fuzz/go-fuzz-build
|
||||
// $GOPATH/bin/go-fuzz-build
|
||||
// $GOPATH/bin/go-fuzz
|
||||
//
|
||||
// Related link: https://blog.cloudflare.com/dns-parser-meet-go-fuzzer/
|
||||
|
||||
package dns
|
||||
|
||||
func Fuzz(data []byte) int {
|
||||
msg, err := MessageFromWireFormat(data)
|
||||
if err != nil {
|
||||
return 0
|
||||
}
|
||||
_, err = msg.WireFormat()
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
return 1 // prioritize this input
|
||||
}
|
||||
276
internal/dnsttcore/noise/noise.go
Normal file
276
internal/dnsttcore/noise/noise.go
Normal file
@@ -0,0 +1,276 @@
|
||||
// Package noise provides a net.Conn-like interface for a
|
||||
// Noise_NK_25519_ChaChaPoly_BLAKE2s. It encodes Noise messages onto a reliable
|
||||
// stream using 16-bit length prefixes.
|
||||
//
|
||||
// https://noiseprotocol.org/noise.html
|
||||
package noise
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"crypto/rand"
|
||||
"encoding/binary"
|
||||
"encoding/hex"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"strings"
|
||||
|
||||
"github.com/flynn/noise"
|
||||
"golang.org/x/crypto/curve25519"
|
||||
)
|
||||
|
||||
// The length of public and private keys as returned by GeneratePrivkey.
|
||||
const KeyLen = 32
|
||||
|
||||
// cipherSuite represents 25519_ChaChaPoly_BLAKE2s.
|
||||
var cipherSuite = noise.NewCipherSuite(noise.DH25519, noise.CipherChaChaPoly, noise.HashBLAKE2s)
|
||||
|
||||
// readMessage reads a length-prefixed message from r. It returns a nil error
|
||||
// only when a complete message was read. It returns io.EOF only when there were
|
||||
// 0 bytes remaining to read from r. It returns io.ErrUnexpectedEOF when EOF
|
||||
// occurs in the middle of an encoded message.
|
||||
func readMessage(r io.Reader) ([]byte, error) {
|
||||
var length uint16
|
||||
err := binary.Read(r, binary.BigEndian, &length)
|
||||
if err != nil {
|
||||
// We may return a real io.EOF only here.
|
||||
return nil, err
|
||||
}
|
||||
msg := make([]byte, int(length))
|
||||
_, err = io.ReadFull(r, msg)
|
||||
// Here we must change io.EOF to io.ErrUnexpectedEOF.
|
||||
if err == io.EOF {
|
||||
err = io.ErrUnexpectedEOF
|
||||
}
|
||||
return msg, err
|
||||
}
|
||||
|
||||
// writeMessage writes msg as a length-prefixed message to w. It panics if the
|
||||
// length of msg cannot be represented in 16 bits.
|
||||
func writeMessage(w io.Writer, msg []byte) error {
|
||||
length := uint16(len(msg))
|
||||
if int(length) != len(msg) {
|
||||
panic(len(msg))
|
||||
}
|
||||
err := binary.Write(w, binary.BigEndian, length)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
_, err = w.Write(msg)
|
||||
return err
|
||||
}
|
||||
|
||||
// socket is the internal type that represents a Noise-wrapped
|
||||
// io.ReadWriteCloser.
|
||||
type socket struct {
|
||||
recvPipe *io.PipeReader
|
||||
sendCipher *noise.CipherState
|
||||
io.ReadWriteCloser
|
||||
}
|
||||
|
||||
func newSocket(rwc io.ReadWriteCloser, recvCipher, sendCipher *noise.CipherState) *socket {
|
||||
pr, pw := io.Pipe()
|
||||
// This loop calls readMessage, decrypts the messages, and feeds them
|
||||
// into recvPipe where they will be returned from Read.
|
||||
go func() (err error) {
|
||||
defer func() {
|
||||
pw.CloseWithError(err)
|
||||
}()
|
||||
for {
|
||||
msg, err := readMessage(rwc)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
p, err := recvCipher.Decrypt(nil, nil, msg)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
_, err = pw.Write(p)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
}()
|
||||
return &socket{
|
||||
sendCipher: sendCipher,
|
||||
recvPipe: pr,
|
||||
ReadWriteCloser: rwc,
|
||||
}
|
||||
}
|
||||
|
||||
// Read reads decrypted data from the wrapped io.Reader.
|
||||
func (s *socket) Read(p []byte) (int, error) {
|
||||
return s.recvPipe.Read(p)
|
||||
}
|
||||
|
||||
// Write writes encrypted data from the wrapped io.Writer.
|
||||
func (s *socket) Write(p []byte) (int, error) {
|
||||
total := 0
|
||||
for len(p) > 0 {
|
||||
n := len(p)
|
||||
if n > 4096 {
|
||||
n = 4096
|
||||
}
|
||||
msg, err := s.sendCipher.Encrypt(nil, nil, p[:n])
|
||||
if err != nil {
|
||||
return total, err
|
||||
}
|
||||
err = writeMessage(s.ReadWriteCloser, msg)
|
||||
if err != nil {
|
||||
return total, err
|
||||
}
|
||||
total += n
|
||||
p = p[n:]
|
||||
}
|
||||
return total, nil
|
||||
}
|
||||
|
||||
// newConfig instantiates configuration settings that are common to clients and
|
||||
// servers.
|
||||
func newConfig() noise.Config {
|
||||
return noise.Config{
|
||||
CipherSuite: cipherSuite,
|
||||
Pattern: noise.HandshakeNK,
|
||||
Prologue: []byte("dnstt 2020-04-13"),
|
||||
}
|
||||
}
|
||||
|
||||
// NewClient wraps an io.ReadWriteCloser in a Noise protocol as a client, and
|
||||
// returns after completing the handshake. It returns a non-nil error if there
|
||||
// is an error during the handshake.
|
||||
func NewClient(rwc io.ReadWriteCloser, serverPubkey []byte) (io.ReadWriteCloser, error) {
|
||||
config := newConfig()
|
||||
config.Initiator = true
|
||||
config.PeerStatic = serverPubkey
|
||||
handshakeState, err := noise.NewHandshakeState(config)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// -> e, es
|
||||
msg, _, _, err := handshakeState.WriteMessage(nil, nil)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
err = writeMessage(rwc, msg)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// <- e, es
|
||||
msg, err = readMessage(rwc)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
payload, sendCipher, recvCipher, err := handshakeState.ReadMessage(nil, msg)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if len(payload) != 0 {
|
||||
return nil, errors.New("unexpected server payload")
|
||||
}
|
||||
|
||||
return newSocket(rwc, recvCipher, sendCipher), nil
|
||||
}
|
||||
|
||||
// NewClient wraps an io.ReadWriteCloser in a Noise protocol as a server, and
|
||||
// returns after completing the handshake. It returns a non-nil error if there
|
||||
// is an error during the handshake.
|
||||
func NewServer(rwc io.ReadWriteCloser, serverPrivkey []byte) (io.ReadWriteCloser, error) {
|
||||
config := newConfig()
|
||||
config.Initiator = false
|
||||
config.StaticKeypair = noise.DHKey{
|
||||
Private: serverPrivkey,
|
||||
Public: PubkeyFromPrivkey(serverPrivkey),
|
||||
}
|
||||
handshakeState, err := noise.NewHandshakeState(config)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// -> e, es
|
||||
msg, err := readMessage(rwc)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
payload, _, _, err := handshakeState.ReadMessage(nil, msg)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if len(payload) != 0 {
|
||||
return nil, errors.New("unexpected client payload")
|
||||
}
|
||||
|
||||
// <- e, es
|
||||
msg, recvCipher, sendCipher, err := handshakeState.WriteMessage(nil, nil)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
err = writeMessage(rwc, msg)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return newSocket(rwc, recvCipher, sendCipher), nil
|
||||
}
|
||||
|
||||
// GeneratePrivkey generates a private key. The corresponding public key can be
|
||||
// derived using PubkeyFromPrivkey.
|
||||
func GeneratePrivkey() ([]byte, error) {
|
||||
pair, err := noise.DH25519.GenerateKeypair(rand.Reader)
|
||||
return pair.Private, err
|
||||
}
|
||||
|
||||
// PubkeyFromPrivkey returns the public key that corresponds to privkey.
|
||||
func PubkeyFromPrivkey(privkey []byte) []byte {
|
||||
pubkey, err := curve25519.X25519(privkey, curve25519.Basepoint)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
return pubkey
|
||||
}
|
||||
|
||||
// ReadKey reads a hex-encoded key from r. r must consist of a single line, with
|
||||
// or without a '\n' line terminator. The line must consist of KeyLen
|
||||
// hex-encoded bytes.
|
||||
func ReadKey(r io.Reader) ([]byte, error) {
|
||||
br := bufio.NewReader(io.LimitReader(r, 100))
|
||||
line, err := br.ReadString('\n')
|
||||
if err == io.EOF {
|
||||
err = nil
|
||||
}
|
||||
if err == nil {
|
||||
// Check that we're at EOF.
|
||||
_, err = br.ReadByte()
|
||||
if err == io.EOF {
|
||||
err = nil
|
||||
} else if err == nil {
|
||||
err = fmt.Errorf("file contains more than one line")
|
||||
}
|
||||
}
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
line = strings.TrimSuffix(line, "\n")
|
||||
return DecodeKey(line)
|
||||
}
|
||||
|
||||
// WriteKey writes the hex-encoded key in a single line to w.
|
||||
func WriteKey(w io.Writer, key []byte) error {
|
||||
_, err := fmt.Fprintf(w, "%x\n", key)
|
||||
return err
|
||||
}
|
||||
|
||||
// DecodeKey decodes a hex-encoded private or public key.
|
||||
func DecodeKey(s string) ([]byte, error) {
|
||||
key, err := hex.DecodeString(s)
|
||||
if err == nil && len(key) != KeyLen {
|
||||
err = fmt.Errorf("length is %d, expected %d", len(key), KeyLen)
|
||||
}
|
||||
return key, err
|
||||
}
|
||||
|
||||
// EncodeKey encodes a hex-encoded private or public key.
|
||||
func EncodeKey(key []byte) string {
|
||||
return hex.EncodeToString(key)
|
||||
}
|
||||
218
internal/dnsttcore/noise/noise_test.go
Normal file
218
internal/dnsttcore/noise/noise_test.go
Normal file
@@ -0,0 +1,218 @@
|
||||
package noise
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"io"
|
||||
"net"
|
||||
"testing"
|
||||
|
||||
"github.com/flynn/noise"
|
||||
)
|
||||
|
||||
func allMessages(buf []byte) ([][]byte, error) {
|
||||
var messages [][]byte
|
||||
r := bytes.NewReader(buf)
|
||||
for {
|
||||
msg, err := readMessage(r)
|
||||
if err != nil {
|
||||
return messages, err
|
||||
}
|
||||
messages = append(messages, msg)
|
||||
}
|
||||
}
|
||||
|
||||
func messagesEqual(a, b [][]byte) bool {
|
||||
if len(a) != len(b) {
|
||||
return false
|
||||
}
|
||||
for i := range a {
|
||||
if !bytes.Equal(a[i], b[i]) {
|
||||
return false
|
||||
}
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
func TestReadMessage(t *testing.T) {
|
||||
for _, test := range []struct {
|
||||
input string
|
||||
messages [][]byte
|
||||
err error
|
||||
}{
|
||||
{"", [][]byte{}, io.EOF},
|
||||
{"\x00", [][]byte{}, io.ErrUnexpectedEOF},
|
||||
{"\x00\x00", [][]byte{{}}, io.EOF},
|
||||
{"\x00\x00\x00", [][]byte{{}}, io.ErrUnexpectedEOF},
|
||||
{"\x00\x01", [][]byte{}, io.ErrUnexpectedEOF},
|
||||
{"\x00\x05hello\x00\x05world", [][]byte{[]byte("hello"), []byte("world")}, io.EOF},
|
||||
} {
|
||||
packets, err := allMessages([]byte(test.input))
|
||||
if !messagesEqual(packets, test.messages) || err != test.err {
|
||||
t.Errorf("%x\nreturned %x %v\nexpected %x %v",
|
||||
test.input, packets, err, test.messages, test.err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestMessageRoundTrip(t *testing.T) {
|
||||
for _, messages := range [][][]byte{
|
||||
{},
|
||||
} {
|
||||
var buf bytes.Buffer
|
||||
for _, msg := range messages {
|
||||
err := writeMessage(&buf, msg)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
}
|
||||
output, err := allMessages(buf.Bytes())
|
||||
if !messagesEqual(output, messages) || err != io.EOF {
|
||||
t.Errorf("%x roundtripped to %x %v",
|
||||
messages, output, err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestReadKey(t *testing.T) {
|
||||
for _, test := range []struct {
|
||||
input string
|
||||
output []byte
|
||||
}{
|
||||
{"", nil},
|
||||
{"0123456789abcdef0123456789abcdef0123456789abcdef0123456789abcde", nil},
|
||||
{"0123456789abcdef0123456789abcdef0123456789abcdef0123456789abcdef", []byte("\x01\x23\x45\x67\x89\xab\xcd\xef\x01\x23\x45\x67\x89\xab\xcd\xef\x01\x23\x45\x67\x89\xab\xcd\xef\x01\x23\x45\x67\x89\xab\xcd\xef")},
|
||||
{"0123456789abcdef0123456789abcdef0123456789abcdef0123456789abcdef\n", []byte("\x01\x23\x45\x67\x89\xab\xcd\xef\x01\x23\x45\x67\x89\xab\xcd\xef\x01\x23\x45\x67\x89\xab\xcd\xef\x01\x23\x45\x67\x89\xab\xcd\xef")},
|
||||
{"0123456789abcdef0123456789abcdef0123456789abcdef0123456789abcdef0", nil},
|
||||
{"0123456789abcdef0123456789abcdef0123456789abcdef0123456789abcdef\nX", nil},
|
||||
{"0123456789abcdef0123456789abcdef0123456789abcdef0123456789abcdef\n\n", nil},
|
||||
{"\n0123456789abcdef0123456789abcdef0123456789abcdef0123456789abcdef", nil},
|
||||
{"X123456789abcdef0123456789abcdef0123456789abcdef0123456789abcdef", nil},
|
||||
{"aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa", nil},
|
||||
} {
|
||||
output, err := ReadKey(bytes.NewReader([]byte(test.input)))
|
||||
if test.output == nil {
|
||||
if err == nil {
|
||||
t.Errorf("%+q expected error", test.input)
|
||||
}
|
||||
} else {
|
||||
if err != nil {
|
||||
t.Errorf("%+q returned error %v", test.input, err)
|
||||
} else if !bytes.Equal(output, test.output) {
|
||||
t.Errorf("%+q got %x, expected %x", test.input, output, test.output)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestUnexpectedPayload(t *testing.T) {
|
||||
privkey, err := GeneratePrivkey()
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
pubkey := PubkeyFromPrivkey(privkey)
|
||||
|
||||
// Test the client sending an unexpected payload.
|
||||
clientWithPayload := func(rwc io.ReadWriteCloser) error {
|
||||
config := newConfig()
|
||||
config.Initiator = true
|
||||
config.PeerStatic = pubkey
|
||||
handshakeState, err := noise.NewHandshakeState(config)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// -> e, es
|
||||
msg, _, _, err := handshakeState.WriteMessage(nil, []byte("payload"))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
err = writeMessage(rwc, msg)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// <- e, es
|
||||
// Return nil for all errors after this point, because we expect
|
||||
// the server to have failed, but we want to keep up the game
|
||||
// just in case the server did not fail.
|
||||
msg, err = readMessage(rwc)
|
||||
if err != nil {
|
||||
return nil
|
||||
}
|
||||
_, _, _, err = handshakeState.ReadMessage(nil, msg)
|
||||
if err != nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
func() {
|
||||
c, s := net.Pipe()
|
||||
defer s.Close()
|
||||
|
||||
// Fake a client side that sends a payload.
|
||||
go func() {
|
||||
defer c.Close()
|
||||
err := clientWithPayload(c)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
}()
|
||||
|
||||
server, err := NewServer(s, privkey)
|
||||
if err == nil || err.Error() != "unexpected client payload" || server != nil {
|
||||
t.Errorf("NewServer got (%T, %v)", server, err)
|
||||
}
|
||||
}()
|
||||
|
||||
// Test the server sending an unexpected payload.
|
||||
serverWithPayload := func(rwc io.ReadWriteCloser) error {
|
||||
config := newConfig()
|
||||
config.Initiator = false
|
||||
config.StaticKeypair = noise.DHKey{Private: privkey, Public: pubkey}
|
||||
handshakeState, err := noise.NewHandshakeState(config)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// -> e, es
|
||||
msg, err := readMessage(rwc)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
_, _, _, err = handshakeState.ReadMessage(nil, msg)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// <- e, es
|
||||
msg, _, _, err = handshakeState.WriteMessage(nil, []byte("payload"))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
err = writeMessage(rwc, msg)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
func() {
|
||||
c, s := net.Pipe()
|
||||
defer c.Close()
|
||||
|
||||
// Fake a server side that sends a payload.
|
||||
go func() {
|
||||
defer s.Close()
|
||||
err := serverWithPayload(s)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
}()
|
||||
|
||||
client, err := NewClient(c, pubkey)
|
||||
if err == nil || err.Error() != "unexpected server payload" || client != nil {
|
||||
t.Errorf("NewClient got (%T, %v)", client, err)
|
||||
}
|
||||
}()
|
||||
}
|
||||
28
internal/dnsttcore/turbotunnel/clientid.go
Normal file
28
internal/dnsttcore/turbotunnel/clientid.go
Normal file
@@ -0,0 +1,28 @@
|
||||
package turbotunnel
|
||||
|
||||
import (
|
||||
"crypto/rand"
|
||||
"encoding/hex"
|
||||
)
|
||||
|
||||
// ClientID is an abstract identifier that binds together all the communications
|
||||
// belonging to a single client session, even though those communications may
|
||||
// arrive from multiple IP addresses or over multiple lower-level connections.
|
||||
// It plays the same role that an (IP address, port number) tuple plays in a
|
||||
// net.UDPConn: it's the return address pertaining to a long-lived abstract
|
||||
// client session. The client attaches its ClientID to each of its
|
||||
// communications, enabling the server to disambiguate requests among its many
|
||||
// clients. ClientID implements the net.Addr interface.
|
||||
type ClientID [8]byte
|
||||
|
||||
func NewClientID() ClientID {
|
||||
var id ClientID
|
||||
_, err := rand.Read(id[:])
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
return id
|
||||
}
|
||||
|
||||
func (id ClientID) Network() string { return "clientid" }
|
||||
func (id ClientID) String() string { return hex.EncodeToString(id[:]) }
|
||||
22
internal/dnsttcore/turbotunnel/consts.go
Normal file
22
internal/dnsttcore/turbotunnel/consts.go
Normal file
@@ -0,0 +1,22 @@
|
||||
// Package turbotunnel is facilities for embedding packet-based reliability
|
||||
// protocols inside other protocols.
|
||||
//
|
||||
// https://github.com/net4people/bbs/issues/9
|
||||
package turbotunnel
|
||||
|
||||
import "errors"
|
||||
|
||||
// QueueSize is the size of send and receive queues in QueuePacketConn and
|
||||
// RemoteMap.
|
||||
const QueueSize = 128
|
||||
|
||||
var errClosedPacketConn = errors.New("operation on closed connection")
|
||||
var errNotImplemented = errors.New("not implemented")
|
||||
|
||||
// DummyAddr is a placeholder net.Addr, for when a programming interface
|
||||
// requires a net.Addr but there is none relevant. All DummyAddrs compare equal
|
||||
// to each other.
|
||||
type DummyAddr struct{}
|
||||
|
||||
func (addr DummyAddr) Network() string { return "dummy" }
|
||||
func (addr DummyAddr) String() string { return "dummy" }
|
||||
162
internal/dnsttcore/turbotunnel/queuepacketconn.go
Normal file
162
internal/dnsttcore/turbotunnel/queuepacketconn.go
Normal file
@@ -0,0 +1,162 @@
|
||||
package turbotunnel
|
||||
|
||||
import (
|
||||
"net"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
)
|
||||
|
||||
// taggedPacket is a combination of a []byte and a net.Addr, encapsulating the
|
||||
// return type of PacketConn.ReadFrom.
|
||||
type taggedPacket struct {
|
||||
P []byte
|
||||
Addr net.Addr
|
||||
}
|
||||
|
||||
// QueuePacketConn implements net.PacketConn by storing queues of packets. There
|
||||
// is one incoming queue (where packets are additionally tagged by the source
|
||||
// address of the peer that sent them). There are many outgoing queues, one for
|
||||
// each remote peer address that has been recently seen. The QueueIncoming
|
||||
// method inserts a packet into the incoming queue, to eventually be returned by
|
||||
// ReadFrom. WriteTo inserts a packet into an address-specific outgoing queue,
|
||||
// which can later by accessed through the OutgoingQueue method.
|
||||
//
|
||||
// Besides the outgoing queues, there is also a one-element "stash" for each
|
||||
// remote peer address. You can stash a packet using the Stash method, and get
|
||||
// it back later by receiving from the channel returned by Unstash. The stash is
|
||||
// meant as a convenient place to temporarily store a single packet, such as
|
||||
// when you've read one too many packets from the send queue and need to store
|
||||
// the extra packet to be processed first in the next pass. It's the caller's
|
||||
// responsibility to Unstash what they have Stashed. Calling Stash does not put
|
||||
// the packet at the head of the send queue; if there is the possibility that a
|
||||
// packet has been stashed, it must be checked for by calling Unstash in
|
||||
// addition to OutgoingQueue.
|
||||
type QueuePacketConn struct {
|
||||
remotes *RemoteMap
|
||||
localAddr net.Addr
|
||||
recvQueue chan taggedPacket
|
||||
closeOnce sync.Once
|
||||
closed chan struct{}
|
||||
// What error to return when the QueuePacketConn is closed.
|
||||
err atomic.Value
|
||||
}
|
||||
|
||||
// NewQueuePacketConn makes a new QueuePacketConn, set to track recent peers
|
||||
// for at least a duration of timeout.
|
||||
func NewQueuePacketConn(localAddr net.Addr, timeout time.Duration) *QueuePacketConn {
|
||||
return &QueuePacketConn{
|
||||
remotes: NewRemoteMap(timeout),
|
||||
localAddr: localAddr,
|
||||
recvQueue: make(chan taggedPacket, QueueSize),
|
||||
closed: make(chan struct{}),
|
||||
}
|
||||
}
|
||||
|
||||
// QueueIncoming queues and incoming packet and its source address, to be
|
||||
// returned in a future call to ReadFrom.
|
||||
func (c *QueuePacketConn) QueueIncoming(p []byte, addr net.Addr) {
|
||||
select {
|
||||
case <-c.closed:
|
||||
// If we're closed, silently drop it.
|
||||
return
|
||||
default:
|
||||
}
|
||||
// Copy the slice so that the caller may reuse it.
|
||||
buf := make([]byte, len(p))
|
||||
copy(buf, p)
|
||||
select {
|
||||
case c.recvQueue <- taggedPacket{buf, addr}:
|
||||
default:
|
||||
// Drop the incoming packet if the receive queue is full.
|
||||
}
|
||||
}
|
||||
|
||||
// OutgoingQueue returns the queue of outgoing packets corresponding to addr,
|
||||
// creating it if necessary. The contents of the queue will be packets that are
|
||||
// written to the address in question using WriteTo.
|
||||
func (c *QueuePacketConn) OutgoingQueue(addr net.Addr) <-chan []byte {
|
||||
return c.remotes.SendQueue(addr)
|
||||
}
|
||||
|
||||
// Stash places p in the stash for addr, if the stash is not already occupied.
|
||||
// Returns true if the packet was placed in the stash, or false if the stash was
|
||||
// already occupied. This method is similar to WriteTo, except that it puts the
|
||||
// packet in the stash queue (accessible via Unstash), rather than the outgoing
|
||||
// queue (accessible via OutgoingQueue).
|
||||
func (c *QueuePacketConn) Stash(p []byte, addr net.Addr) bool {
|
||||
return c.remotes.Stash(addr, p)
|
||||
}
|
||||
|
||||
// Unstash returns the channel that represents the stash for addr.
|
||||
func (c *QueuePacketConn) Unstash(addr net.Addr) <-chan []byte {
|
||||
return c.remotes.Unstash(addr)
|
||||
}
|
||||
|
||||
// ReadFrom returns a packet and address previously stored by QueueIncoming.
|
||||
func (c *QueuePacketConn) ReadFrom(p []byte) (int, net.Addr, error) {
|
||||
select {
|
||||
case <-c.closed:
|
||||
return 0, nil, &net.OpError{Op: "read", Net: c.LocalAddr().Network(), Addr: c.LocalAddr(), Err: c.err.Load().(error)}
|
||||
default:
|
||||
}
|
||||
select {
|
||||
case <-c.closed:
|
||||
return 0, nil, &net.OpError{Op: "read", Net: c.LocalAddr().Network(), Addr: c.LocalAddr(), Err: c.err.Load().(error)}
|
||||
case packet := <-c.recvQueue:
|
||||
return copy(p, packet.P), packet.Addr, nil
|
||||
}
|
||||
}
|
||||
|
||||
// WriteTo queues an outgoing packet for the given address. The queue can later
|
||||
// be retrieved using the OutgoingQueue method.
|
||||
func (c *QueuePacketConn) WriteTo(p []byte, addr net.Addr) (int, error) {
|
||||
select {
|
||||
case <-c.closed:
|
||||
return 0, &net.OpError{Op: "write", Net: c.LocalAddr().Network(), Addr: c.LocalAddr(), Err: c.err.Load().(error)}
|
||||
default:
|
||||
}
|
||||
// Copy the slice so that the caller may reuse it.
|
||||
buf := make([]byte, len(p))
|
||||
copy(buf, p)
|
||||
select {
|
||||
case c.remotes.SendQueue(addr) <- buf:
|
||||
return len(buf), nil
|
||||
default:
|
||||
// Drop the outgoing packet if the send queue is full.
|
||||
return len(buf), nil
|
||||
}
|
||||
}
|
||||
|
||||
// closeWithError unblocks pending operations and makes future operations fail
|
||||
// with the given error. If err is nil, it becomes errClosedPacketConn.
|
||||
func (c *QueuePacketConn) closeWithError(err error) error {
|
||||
var newlyClosed bool
|
||||
c.closeOnce.Do(func() {
|
||||
newlyClosed = true
|
||||
// Store the error to be returned by future PacketConn
|
||||
// operations.
|
||||
if err == nil {
|
||||
err = errClosedPacketConn
|
||||
}
|
||||
c.err.Store(err)
|
||||
close(c.closed)
|
||||
})
|
||||
if !newlyClosed {
|
||||
return &net.OpError{Op: "close", Net: c.LocalAddr().Network(), Addr: c.LocalAddr(), Err: c.err.Load().(error)}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// Close unblocks pending operations and makes future operations fail with a
|
||||
// "closed connection" error.
|
||||
func (c *QueuePacketConn) Close() error {
|
||||
return c.closeWithError(nil)
|
||||
}
|
||||
|
||||
// LocalAddr returns the localAddr value that was passed to NewQueuePacketConn.
|
||||
func (c *QueuePacketConn) LocalAddr() net.Addr { return c.localAddr }
|
||||
|
||||
func (c *QueuePacketConn) SetDeadline(t time.Time) error { return errNotImplemented }
|
||||
func (c *QueuePacketConn) SetReadDeadline(t time.Time) error { return errNotImplemented }
|
||||
func (c *QueuePacketConn) SetWriteDeadline(t time.Time) error { return errNotImplemented }
|
||||
177
internal/dnsttcore/turbotunnel/remotemap.go
Normal file
177
internal/dnsttcore/turbotunnel/remotemap.go
Normal file
@@ -0,0 +1,177 @@
|
||||
package turbotunnel
|
||||
|
||||
import (
|
||||
"container/heap"
|
||||
"net"
|
||||
"sync"
|
||||
"time"
|
||||
)
|
||||
|
||||
// remoteRecord is a record of a recently seen remote peer, with the time it was
|
||||
// last seen and queues of outgoing packets.
|
||||
type remoteRecord struct {
|
||||
Addr net.Addr
|
||||
LastSeen time.Time
|
||||
SendQueue chan []byte
|
||||
Stash chan []byte
|
||||
}
|
||||
|
||||
// RemoteMap manages a mapping of live remote peers, keyed by address, to their
|
||||
// respective send queues. Each peer has two queues: a primary send queue, and a
|
||||
// "stash". The primary send queue is returned by the SendQueue method. The
|
||||
// stash is an auxiliary one-element queue accessed using the Stash and Unstash
|
||||
// methods. The stash is meant for use by callers that need to "unread" a packet
|
||||
// that's already been removed from the primary send queue.
|
||||
//
|
||||
// RemoteMap's functions are safe to call from multiple goroutines.
|
||||
type RemoteMap struct {
|
||||
// We use an inner structure to avoid exposing public heap.Interface
|
||||
// functions to users of remoteMap.
|
||||
inner remoteMapInner
|
||||
// Synchronizes access to inner.
|
||||
lock sync.Mutex
|
||||
}
|
||||
|
||||
// NewRemoteMap creates a RemoteMap that expires peers after a timeout.
|
||||
//
|
||||
// If the timeout is 0, peers never expire.
|
||||
//
|
||||
// The timeout does not have to be kept in sync with smux's idle timeout. If a
|
||||
// peer is removed from the map while the smux session is still live, the worst
|
||||
// that can happen is a loss of whatever packets were in the send queue at the
|
||||
// time. If smux later decides to send more packets to the same peer, we'll
|
||||
// instantiate a new send queue, and if the peer is ever seen again with a
|
||||
// matching address, we'll deliver them.
|
||||
func NewRemoteMap(timeout time.Duration) *RemoteMap {
|
||||
m := &RemoteMap{
|
||||
inner: remoteMapInner{
|
||||
byAge: make([]*remoteRecord, 0),
|
||||
byAddr: make(map[net.Addr]int),
|
||||
},
|
||||
}
|
||||
if timeout > 0 {
|
||||
go func() {
|
||||
for {
|
||||
time.Sleep(timeout / 2)
|
||||
now := time.Now()
|
||||
m.lock.Lock()
|
||||
m.inner.removeExpired(now, timeout)
|
||||
m.lock.Unlock()
|
||||
}
|
||||
}()
|
||||
}
|
||||
return m
|
||||
}
|
||||
|
||||
// SendQueue returns the send queue corresponding to addr, creating it if
|
||||
// necessary.
|
||||
func (m *RemoteMap) SendQueue(addr net.Addr) chan []byte {
|
||||
m.lock.Lock()
|
||||
defer m.lock.Unlock()
|
||||
return m.inner.Lookup(addr, time.Now()).SendQueue
|
||||
}
|
||||
|
||||
// Stash places p in the stash corresponding to addr, if the stash is not
|
||||
// already occupied. Returns true if the p was placed in the stash, false
|
||||
// otherwise.
|
||||
func (m *RemoteMap) Stash(addr net.Addr, p []byte) bool {
|
||||
m.lock.Lock()
|
||||
defer m.lock.Unlock()
|
||||
select {
|
||||
case m.inner.Lookup(addr, time.Now()).Stash <- p:
|
||||
return true
|
||||
default:
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
// Unstash returns the channel that reads from the stash for addr.
|
||||
func (m *RemoteMap) Unstash(addr net.Addr) <-chan []byte {
|
||||
m.lock.Lock()
|
||||
defer m.lock.Unlock()
|
||||
return m.inner.Lookup(addr, time.Now()).Stash
|
||||
}
|
||||
|
||||
// remoteMapInner is the inner type of RemoteMap, implementing heap.Interface.
|
||||
// byAge is the backing store, a heap ordered by LastSeen time, to facilitate
|
||||
// expiring old records. byAddr is a map from addresses to heap indices, to
|
||||
// allow looking up by address. Unlike RemoteMap, remoteMapInner requires
|
||||
// external synchonization.
|
||||
type remoteMapInner struct {
|
||||
byAge []*remoteRecord
|
||||
byAddr map[net.Addr]int
|
||||
}
|
||||
|
||||
// removeExpired removes all records whose LastSeen timestamp is more than
|
||||
// timeout in the past.
|
||||
func (inner *remoteMapInner) removeExpired(now time.Time, timeout time.Duration) {
|
||||
for len(inner.byAge) > 0 && now.Sub(inner.byAge[0].LastSeen) >= timeout {
|
||||
record := heap.Pop(inner).(*remoteRecord)
|
||||
close(record.SendQueue)
|
||||
}
|
||||
}
|
||||
|
||||
// Lookup finds the existing record corresponding to addr, or creates a new
|
||||
// one if none exists yet. It updates the record's LastSeen time and returns the
|
||||
// record.
|
||||
func (inner *remoteMapInner) Lookup(addr net.Addr, now time.Time) *remoteRecord {
|
||||
var record *remoteRecord
|
||||
i, ok := inner.byAddr[addr]
|
||||
if ok {
|
||||
// Found one, update its LastSeen.
|
||||
record = inner.byAge[i]
|
||||
record.LastSeen = now
|
||||
heap.Fix(inner, i)
|
||||
} else {
|
||||
// Not found, create a new one.
|
||||
record = &remoteRecord{
|
||||
Addr: addr,
|
||||
LastSeen: now,
|
||||
SendQueue: make(chan []byte, QueueSize),
|
||||
Stash: make(chan []byte, 1),
|
||||
}
|
||||
heap.Push(inner, record)
|
||||
}
|
||||
return record
|
||||
}
|
||||
|
||||
// heap.Interface for remoteMapInner.
|
||||
|
||||
func (inner *remoteMapInner) Len() int {
|
||||
if len(inner.byAge) != len(inner.byAddr) {
|
||||
panic("inconsistent remoteMap")
|
||||
}
|
||||
return len(inner.byAge)
|
||||
}
|
||||
|
||||
func (inner *remoteMapInner) Less(i, j int) bool {
|
||||
return inner.byAge[i].LastSeen.Before(inner.byAge[j].LastSeen)
|
||||
}
|
||||
|
||||
func (inner *remoteMapInner) Swap(i, j int) {
|
||||
inner.byAge[i], inner.byAge[j] = inner.byAge[j], inner.byAge[i]
|
||||
inner.byAddr[inner.byAge[i].Addr] = i
|
||||
inner.byAddr[inner.byAge[j].Addr] = j
|
||||
}
|
||||
|
||||
func (inner *remoteMapInner) Push(x interface{}) {
|
||||
record := x.(*remoteRecord)
|
||||
if _, ok := inner.byAddr[record.Addr]; ok {
|
||||
panic("duplicate address in remoteMap")
|
||||
}
|
||||
// Insert into byAddr map.
|
||||
inner.byAddr[record.Addr] = len(inner.byAge)
|
||||
// Insert into byAge slice.
|
||||
inner.byAge = append(inner.byAge, record)
|
||||
}
|
||||
|
||||
func (inner *remoteMapInner) Pop() interface{} {
|
||||
n := len(inner.byAddr)
|
||||
// Remove from byAge slice.
|
||||
record := inner.byAge[n-1]
|
||||
inner.byAge[n-1] = nil
|
||||
inner.byAge = inner.byAge[:n-1]
|
||||
// Remove from byAddr map.
|
||||
delete(inner.byAddr, record.Addr)
|
||||
return record
|
||||
}
|
||||
Reference in New Issue
Block a user