This commit is contained in:
2026-05-16 00:18:06 -03:00
commit 92941e68a2
66 changed files with 10352 additions and 0 deletions

View File

@@ -0,0 +1,575 @@
// Package dns deals with encoding and decoding DNS wire format.
package dns
import (
"bytes"
"encoding/binary"
"errors"
"fmt"
"io"
"strings"
)
// The maximum number of DNS name compression pointers we are willing to follow.
// Without something like this, infinite loops are possible.
const compressionPointerLimit = 10
var (
// ErrZeroLengthLabel is the error returned for names that contain a
// zero-length label, like "example..com".
ErrZeroLengthLabel = errors.New("name contains a zero-length label")
// ErrLabelTooLong is the error returned for labels that are longer than
// 63 octets.
ErrLabelTooLong = errors.New("name contains a label longer than 63 octets")
// ErrNameTooLong is the error returned for names whose encoded
// representation is longer than 255 octets.
ErrNameTooLong = errors.New("name is longer than 255 octets")
// ErrReservedLabelType is the error returned when reading a label type
// prefix whose two most significant bits are not 00 or 11.
ErrReservedLabelType = errors.New("reserved label type")
// ErrTooManyPointers is the error returned when reading a compressed
// name that has too many compression pointers.
ErrTooManyPointers = errors.New("too many compression pointers")
// ErrTrailingBytes is the error returned when bytes remain in the parse
// buffer after parsing a message.
ErrTrailingBytes = errors.New("trailing bytes after message")
// ErrIntegerOverflow is the error returned when trying to encode an
// integer greater than 65535 into a 16-bit field.
ErrIntegerOverflow = errors.New("integer overflow")
)
const (
// https://tools.ietf.org/html/rfc1035#section-3.2.2
RRTypeTXT = 16
// https://tools.ietf.org/html/rfc6891#section-6.1.1
RRTypeOPT = 41
// https://tools.ietf.org/html/rfc1035#section-3.2.4
ClassIN = 1
// https://tools.ietf.org/html/rfc1035#section-4.1.1
RcodeNoError = 0 // a.k.a. NOERROR
RcodeFormatError = 1 // a.k.a. FORMERR
RcodeNameError = 3 // a.k.a. NXDOMAIN
RcodeNotImplemented = 4 // a.k.a. NOTIMPL
// https://tools.ietf.org/html/rfc6891#section-9
ExtendedRcodeBadVers = 16 // a.k.a. BADVERS
)
// Name represents a domain name, a sequence of labels each of which is 63
// octets or less in length.
//
// https://tools.ietf.org/html/rfc1035#section-3.1
type Name [][]byte
// NewName returns a Name from a slice of labels, after checking the labels for
// validity. Does not include a zero-length label at the end of the slice.
func NewName(labels [][]byte) (Name, error) {
name := Name(labels)
// https://tools.ietf.org/html/rfc1035#section-2.3.4
// Various objects and parameters in the DNS have size limits.
// labels 63 octets or less
// names 255 octets or less
for _, label := range labels {
if len(label) == 0 {
return nil, ErrZeroLengthLabel
}
if len(label) > 63 {
return nil, ErrLabelTooLong
}
}
// Check the total length.
builder := newMessageBuilder()
builder.WriteName(name)
if len(builder.Bytes()) > 255 {
return nil, ErrNameTooLong
}
return name, nil
}
// ParseName returns a new Name from a string of labels separated by dots, after
// checking the name for validity. A single dot at the end of the string is
// ignored.
func ParseName(s string) (Name, error) {
b := bytes.TrimSuffix([]byte(s), []byte("."))
if len(b) == 0 {
// bytes.Split(b, ".") would return [""] in this case
return NewName([][]byte{})
} else {
return NewName(bytes.Split(b, []byte(".")))
}
}
// String returns a reversible string representation of name. Labels are
// separated by dots, and any bytes in a label that are outside the set
// [0-9A-Za-z-] are replaced with a \xXX hex escape sequence.
func (name Name) String() string {
if len(name) == 0 {
return "."
}
var buf strings.Builder
for i, label := range name {
if i > 0 {
buf.WriteByte('.')
}
for _, b := range label {
if b == '-' ||
('0' <= b && b <= '9') ||
('A' <= b && b <= 'Z') ||
('a' <= b && b <= 'z') {
buf.WriteByte(b)
} else {
fmt.Fprintf(&buf, "\\x%02x", b)
}
}
}
return buf.String()
}
// TrimSuffix returns a Name with the given suffix removed, if it was present.
// The second return value indicates whether the suffix was present. If the
// suffix was not present, the first return value is nil.
func (name Name) TrimSuffix(suffix Name) (Name, bool) {
if len(name) < len(suffix) {
return nil, false
}
split := len(name) - len(suffix)
fore, aft := name[:split], name[split:]
for i := 0; i < len(aft); i++ {
if !bytes.Equal(bytes.ToLower(aft[i]), bytes.ToLower(suffix[i])) {
return nil, false
}
}
return fore, true
}
// Message represents a DNS message.
//
// https://tools.ietf.org/html/rfc1035#section-4.1
type Message struct {
ID uint16
Flags uint16
Question []Question
Answer []RR
Authority []RR
Additional []RR
}
// Opcode extracts the OPCODE part of the Flags field.
//
// https://tools.ietf.org/html/rfc1035#section-4.1.1
func (message *Message) Opcode() uint16 {
return (message.Flags >> 11) & 0xf
}
// Rcode extracts the RCODE part of the Flags field.
//
// https://tools.ietf.org/html/rfc1035#section-4.1.1
func (message *Message) Rcode() uint16 {
return message.Flags & 0x000f
}
// Question represents an entry in the question section of a message.
//
// https://tools.ietf.org/html/rfc1035#section-4.1.2
type Question struct {
Name Name
Type uint16
Class uint16
}
// RR represents a resource record.
//
// https://tools.ietf.org/html/rfc1035#section-4.1.3
type RR struct {
Name Name
Type uint16
Class uint16
TTL uint32
Data []byte
}
// readName parses a DNS name from r. It leaves r positioned just after the
// parsed name.
func readName(r io.ReadSeeker) (Name, error) {
var labels [][]byte
// We limit the number of compression pointers we are willing to follow.
numPointers := 0
// If we followed any compression pointers, we must finally seek to just
// past the first pointer.
var seekTo int64
loop:
for {
var labelType byte
err := binary.Read(r, binary.BigEndian, &labelType)
if err != nil {
return nil, err
}
switch labelType & 0xc0 {
case 0x00:
// This is an ordinary label.
// https://tools.ietf.org/html/rfc1035#section-3.1
length := int(labelType & 0x3f)
if length == 0 {
break loop
}
label := make([]byte, length)
_, err := io.ReadFull(r, label)
if err != nil {
return nil, err
}
labels = append(labels, label)
case 0xc0:
// This is a compression pointer.
// https://tools.ietf.org/html/rfc1035#section-4.1.4
upper := labelType & 0x3f
var lower byte
err := binary.Read(r, binary.BigEndian, &lower)
if err != nil {
return nil, err
}
offset := (uint16(upper) << 8) | uint16(lower)
if numPointers == 0 {
// The first time we encounter a pointer,
// remember our position so we can seek back to
// it when done.
seekTo, err = r.Seek(0, io.SeekCurrent)
if err != nil {
return nil, err
}
}
numPointers++
if numPointers > compressionPointerLimit {
return nil, ErrTooManyPointers
}
// Follow the pointer and continue.
_, err = r.Seek(int64(offset), io.SeekStart)
if err != nil {
return nil, err
}
default:
// "The 10 and 01 combinations are reserved for future
// use."
return nil, ErrReservedLabelType
}
}
// If we followed any pointers, then seek back to just after the first
// one.
if numPointers > 0 {
_, err := r.Seek(seekTo, io.SeekStart)
if err != nil {
return nil, err
}
}
return NewName(labels)
}
// readQuestion parses one entry from the Question section. It leaves r
// positioned just after the parsed entry.
//
// https://tools.ietf.org/html/rfc1035#section-4.1.2
func readQuestion(r io.ReadSeeker) (Question, error) {
var question Question
var err error
question.Name, err = readName(r)
if err != nil {
return question, err
}
for _, ptr := range []*uint16{&question.Type, &question.Class} {
err := binary.Read(r, binary.BigEndian, ptr)
if err != nil {
return question, err
}
}
return question, nil
}
// readRR parses one resource record. It leaves r positioned just after the
// parsed resource record.
//
// https://tools.ietf.org/html/rfc1035#section-4.1.3
func readRR(r io.ReadSeeker) (RR, error) {
var rr RR
var err error
rr.Name, err = readName(r)
if err != nil {
return rr, err
}
for _, ptr := range []*uint16{&rr.Type, &rr.Class} {
err := binary.Read(r, binary.BigEndian, ptr)
if err != nil {
return rr, err
}
}
err = binary.Read(r, binary.BigEndian, &rr.TTL)
if err != nil {
return rr, err
}
var rdLength uint16
err = binary.Read(r, binary.BigEndian, &rdLength)
if err != nil {
return rr, err
}
rr.Data = make([]byte, rdLength)
_, err = io.ReadFull(r, rr.Data)
if err != nil {
return rr, err
}
return rr, nil
}
// readMessage parses a complete DNS message. It leaves r positioned just after
// the parsed message.
func readMessage(r io.ReadSeeker) (Message, error) {
var message Message
// Header section
// https://tools.ietf.org/html/rfc1035#section-4.1.1
var qdCount, anCount, nsCount, arCount uint16
for _, ptr := range []*uint16{
&message.ID, &message.Flags,
&qdCount, &anCount, &nsCount, &arCount,
} {
err := binary.Read(r, binary.BigEndian, ptr)
if err != nil {
return message, err
}
}
// Question section
// https://tools.ietf.org/html/rfc1035#section-4.1.2
for i := 0; i < int(qdCount); i++ {
question, err := readQuestion(r)
if err != nil {
return message, err
}
message.Question = append(message.Question, question)
}
// Answer, Authority, and Additional sections
// https://tools.ietf.org/html/rfc1035#section-4.1.3
for _, rec := range []struct {
ptr *[]RR
count uint16
}{
{&message.Answer, anCount},
{&message.Authority, nsCount},
{&message.Additional, arCount},
} {
for i := 0; i < int(rec.count); i++ {
rr, err := readRR(r)
if err != nil {
return message, err
}
*rec.ptr = append(*rec.ptr, rr)
}
}
return message, nil
}
// MessageFromWireFormat parses a message from buf and returns a Message object.
// It returns ErrTrailingBytes if there are bytes remaining in buf after parsing
// is done.
func MessageFromWireFormat(buf []byte) (Message, error) {
r := bytes.NewReader(buf)
message, err := readMessage(r)
if err == io.EOF {
err = io.ErrUnexpectedEOF
} else if err == nil {
// Check for trailing bytes.
_, err = r.ReadByte()
if err == io.EOF {
err = nil
} else if err == nil {
err = ErrTrailingBytes
}
}
return message, err
}
// messageBuilder manages the state of serializing a DNS message. Its main
// function is to keep track of names already written for the purpose of name
// compression.
type messageBuilder struct {
w bytes.Buffer
nameCache map[string]int
}
// newMessageBuilder creates a new messageBuilder with an empty name cache.
func newMessageBuilder() *messageBuilder {
return &messageBuilder{
nameCache: make(map[string]int),
}
}
// Bytes returns the serialized DNS message as a slice of bytes.
func (builder *messageBuilder) Bytes() []byte {
return builder.w.Bytes()
}
// WriteName appends name to the in-progress messageBuilder, employing
// compression pointers to previously written names if possible.
func (builder *messageBuilder) WriteName(name Name) {
// https://tools.ietf.org/html/rfc1035#section-3.1
for i := range name {
// Has this suffix already been encoded in the message?
if ptr, ok := builder.nameCache[name[i:].String()]; ok && ptr&0x3fff == ptr {
// If so, we can write a compression pointer.
binary.Write(&builder.w, binary.BigEndian, uint16(0xc000|ptr))
return
}
// Not cached; we must encode this label verbatim. Store a cache
// entry pointing to the beginning of it.
builder.nameCache[name[i:].String()] = builder.w.Len()
length := len(name[i])
if length == 0 || length > 63 {
panic(length)
}
builder.w.WriteByte(byte(length))
builder.w.Write(name[i])
}
builder.w.WriteByte(0)
}
// WriteQuestion appends a Question section entry to the in-progress
// messageBuilder.
func (builder *messageBuilder) WriteQuestion(question *Question) {
// https://tools.ietf.org/html/rfc1035#section-4.1.2
builder.WriteName(question.Name)
binary.Write(&builder.w, binary.BigEndian, question.Type)
binary.Write(&builder.w, binary.BigEndian, question.Class)
}
// WriteRR appends a resource record to the in-progress messageBuilder. It
// returns ErrIntegerOverflow if the length of rr.Data does not fit in 16 bits.
func (builder *messageBuilder) WriteRR(rr *RR) error {
// https://tools.ietf.org/html/rfc1035#section-4.1.3
builder.WriteName(rr.Name)
binary.Write(&builder.w, binary.BigEndian, rr.Type)
binary.Write(&builder.w, binary.BigEndian, rr.Class)
binary.Write(&builder.w, binary.BigEndian, rr.TTL)
rdLength := uint16(len(rr.Data))
if int(rdLength) != len(rr.Data) {
return ErrIntegerOverflow
}
binary.Write(&builder.w, binary.BigEndian, rdLength)
builder.w.Write(rr.Data)
return nil
}
// WriteMessage appends a complete DNS message to the in-progress
// messageBuilder. It returns ErrIntegerOverflow if the number of entries in any
// section, or the length of the data in any resource record, does not fit in 16
// bits.
func (builder *messageBuilder) WriteMessage(message *Message) error {
// Header section
// https://tools.ietf.org/html/rfc1035#section-4.1.1
binary.Write(&builder.w, binary.BigEndian, message.ID)
binary.Write(&builder.w, binary.BigEndian, message.Flags)
for _, count := range []int{
len(message.Question),
len(message.Answer),
len(message.Authority),
len(message.Additional),
} {
count16 := uint16(count)
if int(count16) != count {
return ErrIntegerOverflow
}
binary.Write(&builder.w, binary.BigEndian, count16)
}
// Question section
// https://tools.ietf.org/html/rfc1035#section-4.1.2
for _, question := range message.Question {
builder.WriteQuestion(&question)
}
// Answer, Authority, and Additional sections
// https://tools.ietf.org/html/rfc1035#section-4.1.3
for _, rrs := range [][]RR{message.Answer, message.Authority, message.Additional} {
for _, rr := range rrs {
err := builder.WriteRR(&rr)
if err != nil {
return err
}
}
}
return nil
}
// WireFormat encodes a Message as a slice of bytes in DNS wire format. It
// returns ErrIntegerOverflow if the number of entries in any section, or the
// length of the data in any resource record, does not fit in 16 bits.
func (message *Message) WireFormat() ([]byte, error) {
builder := newMessageBuilder()
err := builder.WriteMessage(message)
if err != nil {
return nil, err
}
return builder.Bytes(), nil
}
// DecodeRDataTXT decodes TXT-DATA (as found in the RDATA for a resource record
// with TYPE=TXT) as a raw byte slice, by concatenating all the
// <character-string>s it contains.
//
// https://tools.ietf.org/html/rfc1035#section-3.3.14
func DecodeRDataTXT(p []byte) ([]byte, error) {
var buf bytes.Buffer
for {
if len(p) == 0 {
return nil, io.ErrUnexpectedEOF
}
n := int(p[0])
p = p[1:]
if len(p) < n {
return nil, io.ErrUnexpectedEOF
}
buf.Write(p[:n])
p = p[n:]
if len(p) == 0 {
break
}
}
return buf.Bytes(), nil
}
// EncodeRDataTXT encodes a slice of bytes as TXT-DATA, as appropriate for the
// RDATA of a resource record with TYPE=TXT. No length restriction is enforced
// here; that must be checked at a higher level.
//
// https://tools.ietf.org/html/rfc1035#section-3.3.14
func EncodeRDataTXT(p []byte) []byte {
// https://tools.ietf.org/html/rfc1035#section-3.3
// https://tools.ietf.org/html/rfc1035#section-3.3.14
// TXT data is a sequence of one or more <character-string>s, where
// <character-string> is a length octet followed by that number of
// octets.
var buf bytes.Buffer
for len(p) > 255 {
buf.WriteByte(255)
buf.Write(p[:255])
p = p[255:]
}
// Must write here, even if len(p) == 0, because it's "*one or more*
// <character-string>s".
buf.WriteByte(byte(len(p)))
buf.Write(p)
return buf.Bytes()
}

View File

@@ -0,0 +1,592 @@
package dns
import (
"bytes"
"fmt"
"io"
"strconv"
"strings"
"testing"
)
func namesEqual(a, b Name) bool {
if len(a) != len(b) {
return false
}
for i := 0; i < len(a); i++ {
if !bytes.Equal(a[i], b[i]) {
return false
}
}
return true
}
func TestName(t *testing.T) {
for _, test := range []struct {
labels [][]byte
err error
s string
}{
{[][]byte{}, nil, "."},
{[][]byte{[]byte("test")}, nil, "test"},
{[][]byte{[]byte("a"), []byte("b"), []byte("c")}, nil, "a.b.c"},
{[][]byte{{}}, ErrZeroLengthLabel, ""},
{[][]byte{[]byte("a"), {}, []byte("c")}, ErrZeroLengthLabel, ""},
// 63 octets.
{[][]byte{[]byte("0123456789abcdef0123456789ABCDEF0123456789abcdef0123456789ABCDE")}, nil,
"0123456789abcdef0123456789ABCDEF0123456789abcdef0123456789ABCDE"},
// 64 octets.
{[][]byte{[]byte("0123456789abcdef0123456789ABCDEF0123456789abcdef0123456789ABCDEF")}, ErrLabelTooLong, ""},
// 64+64+64+62 octets.
{[][]byte{
[]byte("0123456789abcdef0123456789ABCDEF0123456789abcdef0123456789ABCDE"),
[]byte("0123456789abcdef0123456789ABCDEF0123456789abcdef0123456789ABCDE"),
[]byte("0123456789abcdef0123456789ABCDEF0123456789abcdef0123456789ABCDE"),
[]byte("0123456789abcdef0123456789ABCDEF0123456789abcdef0123456789ABC"),
}, nil,
"0123456789abcdef0123456789ABCDEF0123456789abcdef0123456789ABCDE.0123456789abcdef0123456789ABCDEF0123456789abcdef0123456789ABCDE.0123456789abcdef0123456789ABCDEF0123456789abcdef0123456789ABCDE.0123456789abcdef0123456789ABCDEF0123456789abcdef0123456789ABC"},
// 64+64+64+63 octets.
{[][]byte{
[]byte("0123456789abcdef0123456789ABCDEF0123456789abcdef0123456789ABCDE"),
[]byte("0123456789abcdef0123456789ABCDEF0123456789abcdef0123456789ABCDE"),
[]byte("0123456789abcdef0123456789ABCDEF0123456789abcdef0123456789ABCDE"),
[]byte("0123456789abcdef0123456789ABCDEF0123456789abcdef0123456789ABCD"),
}, ErrNameTooLong, ""},
// 127 one-octet labels.
{[][]byte{
{'0'}, {'1'}, {'2'}, {'3'}, {'4'}, {'5'}, {'6'}, {'7'}, {'8'}, {'9'}, {'a'}, {'b'}, {'c'}, {'d'}, {'e'}, {'f'},
{'0'}, {'1'}, {'2'}, {'3'}, {'4'}, {'5'}, {'6'}, {'7'}, {'8'}, {'9'}, {'A'}, {'B'}, {'C'}, {'D'}, {'E'}, {'F'},
{'0'}, {'1'}, {'2'}, {'3'}, {'4'}, {'5'}, {'6'}, {'7'}, {'8'}, {'9'}, {'a'}, {'b'}, {'c'}, {'d'}, {'e'}, {'f'},
{'0'}, {'1'}, {'2'}, {'3'}, {'4'}, {'5'}, {'6'}, {'7'}, {'8'}, {'9'}, {'A'}, {'B'}, {'C'}, {'D'}, {'E'}, {'F'},
{'0'}, {'1'}, {'2'}, {'3'}, {'4'}, {'5'}, {'6'}, {'7'}, {'8'}, {'9'}, {'a'}, {'b'}, {'c'}, {'d'}, {'e'}, {'f'},
{'0'}, {'1'}, {'2'}, {'3'}, {'4'}, {'5'}, {'6'}, {'7'}, {'8'}, {'9'}, {'A'}, {'B'}, {'C'}, {'D'}, {'E'}, {'F'},
{'0'}, {'1'}, {'2'}, {'3'}, {'4'}, {'5'}, {'6'}, {'7'}, {'8'}, {'9'}, {'a'}, {'b'}, {'c'}, {'d'}, {'e'}, {'f'},
{'0'}, {'1'}, {'2'}, {'3'}, {'4'}, {'5'}, {'6'}, {'7'}, {'8'}, {'9'}, {'A'}, {'B'}, {'C'}, {'D'}, {'E'},
}, nil,
"0.1.2.3.4.5.6.7.8.9.a.b.c.d.e.f.0.1.2.3.4.5.6.7.8.9.A.B.C.D.E.F.0.1.2.3.4.5.6.7.8.9.a.b.c.d.e.f.0.1.2.3.4.5.6.7.8.9.A.B.C.D.E.F.0.1.2.3.4.5.6.7.8.9.a.b.c.d.e.f.0.1.2.3.4.5.6.7.8.9.A.B.C.D.E.F.0.1.2.3.4.5.6.7.8.9.a.b.c.d.e.f.0.1.2.3.4.5.6.7.8.9.A.B.C.D.E"},
// 128 one-octet labels.
{[][]byte{
{'0'}, {'1'}, {'2'}, {'3'}, {'4'}, {'5'}, {'6'}, {'7'}, {'8'}, {'9'}, {'a'}, {'b'}, {'c'}, {'d'}, {'e'}, {'f'},
{'0'}, {'1'}, {'2'}, {'3'}, {'4'}, {'5'}, {'6'}, {'7'}, {'8'}, {'9'}, {'A'}, {'B'}, {'C'}, {'D'}, {'E'}, {'F'},
{'0'}, {'1'}, {'2'}, {'3'}, {'4'}, {'5'}, {'6'}, {'7'}, {'8'}, {'9'}, {'a'}, {'b'}, {'c'}, {'d'}, {'e'}, {'f'},
{'0'}, {'1'}, {'2'}, {'3'}, {'4'}, {'5'}, {'6'}, {'7'}, {'8'}, {'9'}, {'A'}, {'B'}, {'C'}, {'D'}, {'E'}, {'F'},
{'0'}, {'1'}, {'2'}, {'3'}, {'4'}, {'5'}, {'6'}, {'7'}, {'8'}, {'9'}, {'a'}, {'b'}, {'c'}, {'d'}, {'e'}, {'f'},
{'0'}, {'1'}, {'2'}, {'3'}, {'4'}, {'5'}, {'6'}, {'7'}, {'8'}, {'9'}, {'A'}, {'B'}, {'C'}, {'D'}, {'E'}, {'F'},
{'0'}, {'1'}, {'2'}, {'3'}, {'4'}, {'5'}, {'6'}, {'7'}, {'8'}, {'9'}, {'a'}, {'b'}, {'c'}, {'d'}, {'e'}, {'f'},
{'0'}, {'1'}, {'2'}, {'3'}, {'4'}, {'5'}, {'6'}, {'7'}, {'8'}, {'9'}, {'A'}, {'B'}, {'C'}, {'D'}, {'E'}, {'F'},
}, ErrNameTooLong, ""},
} {
// Test that NewName returns proper error codes, and otherwise
// returns an equal slice of labels.
name, err := NewName(test.labels)
if err != test.err || (err == nil && !namesEqual(name, test.labels)) {
t.Errorf("%+q returned (%+q, %v), expected (%+q, %v)",
test.labels, name, err, test.labels, test.err)
continue
}
if test.err != nil {
continue
}
// Test that the string version of the name comes out as
// expected.
s := name.String()
if s != test.s {
t.Errorf("%+q became string %+q, expected %+q", test.labels, s, test.s)
continue
}
// Test that parsing from a string back to a Name results in the
// original slice of labels.
name, err = ParseName(s)
if err != nil || !namesEqual(name, test.labels) {
t.Errorf("%+q parsing %+q returned (%+q, %v), expected (%+q, %v)",
test.labels, s, name, err, test.labels, nil)
continue
}
// A trailing dot should be ignored.
if !strings.HasSuffix(s, ".") {
dotName, dotErr := ParseName(s + ".")
if dotErr != err || !namesEqual(dotName, name) {
t.Errorf("%+q parsing %+q returned (%+q, %v), expected (%+q, %v)",
test.labels, s+".", dotName, dotErr, name, err)
continue
}
}
}
}
func TestParseName(t *testing.T) {
for _, test := range []struct {
s string
name Name
err error
}{
// This case can't be tested by TestName above because String
// will never produce "" (it produces "." instead).
{"", [][]byte{}, nil},
} {
name, err := ParseName(test.s)
if err != test.err || (err == nil && !namesEqual(name, test.name)) {
t.Errorf("%+q returned (%+q, %v), expected (%+q, %v)",
test.s, name, err, test.name, test.err)
continue
}
}
}
func unescapeString(s string) ([][]byte, error) {
if s == "." {
return [][]byte{}, nil
}
var result [][]byte
for _, label := range strings.Split(s, ".") {
var buf bytes.Buffer
i := 0
for i < len(label) {
switch label[i] {
case '\\':
if i+3 >= len(label) {
return nil, fmt.Errorf("truncated escape sequence at index %v", i)
}
if label[i+1] != 'x' {
return nil, fmt.Errorf("malformed escape sequence at index %v", i)
}
b, err := strconv.ParseUint(string(label[i+2:i+4]), 16, 8)
if err != nil {
return nil, fmt.Errorf("malformed hex sequence at index %v", i+2)
}
buf.WriteByte(byte(b))
i += 4
default:
buf.WriteByte(label[i])
i++
}
}
result = append(result, buf.Bytes())
}
return result, nil
}
func TestNameString(t *testing.T) {
for _, test := range []struct {
name Name
s string
}{
{[][]byte{}, "."},
{[][]byte{[]byte("\x00"), []byte("a.b"), []byte("c\nd\\")}, "\\x00.a\\x2eb.c\\x0ad\\x5c"},
{[][]byte{
[]byte("\x00\x01\x02\x03\x04\x05\x06\x07\x08\t\n\x0b\x0c\r\x0e\x0f\x10\x11\x12\x13\x14\x15\x16\x17\x18\x19\x1a\x1b\x1c\x1d\x1e\x1f !\"#$%&'()*+,-./0123456789:;<=>"),
[]byte("?@ABCDEFGHIJKLMNOPQRSTUVWXYZ[\\]^_`abcdefghijklmnopqrstuvwxyz{|}"),
[]byte("~\x7f\x80\x81\x82\x83\x84\x85\x86\x87\x88\x89\x8a\x8b\x8c\x8d\x8e\x8f\x90\x91\x92\x93\x94\x95\x96\x97\x98\x99\x9a\x9b\x9c\x9d\x9e\x9f\xa0\xa1\xa2\xa3\xa4\xa5\xa6\xa7\xa8\xa9\xaa\xab\xac\xad\xae\xaf\xb0\xb1\xb2\xb3\xb4\xb5\xb6\xb7\xb8\xb9\xba\xbb\xbc"),
[]byte("\xbd\xbe\xbf\xc0\xc1\xc2\xc3\xc4\xc5\xc6\xc7\xc8\xc9\xca\xcb\xcc\xcd\xce\xcf\xd0\xd1\xd2\xd3\xd4\xd5\xd6\xd7\xd8\xd9\xda\xdb\xdc\xdd\xde\xdf\xe0\xe1\xe2\xe3\xe4\xe5\xe6\xe7\xe8\xe9\xea\xeb\xec\xed\xee\xef\xf0\xf1\xf2\xf3\xf4\xf5\xf6\xf7\xf8\xf9\xfa\xfb"),
[]byte("\xfc\xfd\xfe\xff"),
}, "\\x00\\x01\\x02\\x03\\x04\\x05\\x06\\x07\\x08\\x09\\x0a\\x0b\\x0c\\x0d\\x0e\\x0f\\x10\\x11\\x12\\x13\\x14\\x15\\x16\\x17\\x18\\x19\\x1a\\x1b\\x1c\\x1d\\x1e\\x1f\\x20\\x21\\x22\\x23\\x24\\x25\\x26\\x27\\x28\\x29\\x2a\\x2b\\x2c-\\x2e\\x2f0123456789\\x3a\\x3b\\x3c\\x3d\\x3e.\\x3f\\x40ABCDEFGHIJKLMNOPQRSTUVWXYZ\\x5b\\x5c\\x5d\\x5e\\x5f\\x60abcdefghijklmnopqrstuvwxyz\\x7b\\x7c\\x7d.\\x7e\\x7f\\x80\\x81\\x82\\x83\\x84\\x85\\x86\\x87\\x88\\x89\\x8a\\x8b\\x8c\\x8d\\x8e\\x8f\\x90\\x91\\x92\\x93\\x94\\x95\\x96\\x97\\x98\\x99\\x9a\\x9b\\x9c\\x9d\\x9e\\x9f\\xa0\\xa1\\xa2\\xa3\\xa4\\xa5\\xa6\\xa7\\xa8\\xa9\\xaa\\xab\\xac\\xad\\xae\\xaf\\xb0\\xb1\\xb2\\xb3\\xb4\\xb5\\xb6\\xb7\\xb8\\xb9\\xba\\xbb\\xbc.\\xbd\\xbe\\xbf\\xc0\\xc1\\xc2\\xc3\\xc4\\xc5\\xc6\\xc7\\xc8\\xc9\\xca\\xcb\\xcc\\xcd\\xce\\xcf\\xd0\\xd1\\xd2\\xd3\\xd4\\xd5\\xd6\\xd7\\xd8\\xd9\\xda\\xdb\\xdc\\xdd\\xde\\xdf\\xe0\\xe1\\xe2\\xe3\\xe4\\xe5\\xe6\\xe7\\xe8\\xe9\\xea\\xeb\\xec\\xed\\xee\\xef\\xf0\\xf1\\xf2\\xf3\\xf4\\xf5\\xf6\\xf7\\xf8\\xf9\\xfa\\xfb.\\xfc\\xfd\\xfe\\xff"},
} {
s := test.name.String()
if s != test.s {
t.Errorf("%+q escaped to %+q, expected %+q", test.name, s, test.s)
continue
}
unescaped, err := unescapeString(s)
if err != nil {
t.Errorf("%+q unescaping %+q resulted in error %v", test.name, s, err)
continue
}
if !namesEqual(Name(unescaped), test.name) {
t.Errorf("%+q roundtripped through %+q to %+q", test.name, s, unescaped)
continue
}
}
}
func TestNameTrimSuffix(t *testing.T) {
for _, test := range []struct {
name, suffix string
trimmed string
ok bool
}{
{"", "", ".", true},
{".", ".", ".", true},
{"abc", "", "abc", true},
{"abc", ".", "abc", true},
{"", "abc", ".", false},
{".", "abc", ".", false},
{"example.com", "com", "example", true},
{"example.com", "net", ".", false},
{"example.com", "example.com", ".", true},
{"example.com", "test.com", ".", false},
{"example.com", "xample.com", ".", false},
{"example.com", "example", ".", false},
{"example.com", "COM", "example", true},
{"EXAMPLE.COM", "com", "EXAMPLE", true},
} {
tmp, ok := mustParseName(test.name).TrimSuffix(mustParseName(test.suffix))
trimmed := tmp.String()
if ok != test.ok || trimmed != test.trimmed {
t.Errorf("TrimSuffix %+q %+q returned (%+q, %v), expected (%+q, %v)",
test.name, test.suffix, trimmed, ok, test.trimmed, test.ok)
continue
}
}
}
func TestReadName(t *testing.T) {
// Good tests.
for _, test := range []struct {
start int64
end int64
input string
s string
}{
// Empty name.
{0, 1, "\x00abcd", "."},
// No pointers.
{12, 25, "AAAABBBBCCCC\x07example\x03com\x00", "example.com"},
// Backward pointer.
{25, 31, "AAAABBBBCCCC\x07example\x03com\x00\x03sub\xc0\x0c", "sub.example.com"},
// Forward pointer.
{0, 4, "\x01a\xc0\x04\x03bcd\x00", "a.bcd"},
// Two backwards pointers.
{31, 38, "AAAABBBBCCCC\x07example\x03com\x00\x03sub\xc0\x0c\x04sub2\xc0\x19", "sub2.sub.example.com"},
// Forward then backward pointer.
{25, 31, "AAAABBBBCCCC\x07example\x03com\x00\x03sub\xc0\x1f\x04sub2\xc0\x0c", "sub.sub2.example.com"},
// Overlapping codons.
{0, 4, "\x01a\xc0\x03bcd\x00", "a.bcd"},
// Pointer to empty label.
{0, 10, "\x07example\xc0\x0a\x00", "example"},
{1, 11, "\x00\x07example\xc0\x00", "example"},
// Pointer to pointer to empty label.
{0, 10, "\x07example\xc0\x0a\xc0\x0c\x00", "example"},
{1, 11, "\x00\x07example\xc0\x0c\xc0\x00", "example"},
} {
r := bytes.NewReader([]byte(test.input))
_, err := r.Seek(test.start, io.SeekStart)
if err != nil {
panic(err)
}
name, err := readName(r)
if err != nil {
t.Errorf("%+q returned error %s", test.input, err)
continue
}
s := name.String()
if s != test.s {
t.Errorf("%+q returned %+q, expected %+q", test.input, s, test.s)
continue
}
cur, _ := r.Seek(0, io.SeekCurrent)
if cur != test.end {
t.Errorf("%+q left offset %d, expected %d", test.input, cur, test.end)
continue
}
}
// Bad tests.
for _, test := range []struct {
start int64
input string
err error
}{
{0, "", io.ErrUnexpectedEOF},
// Reserved label type.
{0, "\x80example", ErrReservedLabelType},
// Reserved label type.
{0, "\x40example", ErrReservedLabelType},
// No Terminating empty label.
{0, "\x07example\x03com", io.ErrUnexpectedEOF},
// Pointer past end of buffer.
{0, "\x07example\xc0\xff", io.ErrUnexpectedEOF},
// Pointer to self.
{0, "\x07example\x03com\xc0\x0c", ErrTooManyPointers},
// Pointer to self with intermediate label.
{0, "\x07example\x03com\xc0\x08", ErrTooManyPointers},
// Two pointers that point to each other.
{0, "\xc0\x02\xc0\x00", ErrTooManyPointers},
// Two pointers that point to each other, with intermediate labels.
{0, "\x01a\xc0\x04\x01b\xc0\x00", ErrTooManyPointers},
// EOF while reading label.
{0, "\x0aexample", io.ErrUnexpectedEOF},
// EOF before second byte of pointer.
{0, "\xc0", io.ErrUnexpectedEOF},
{0, "\x07example\xc0", io.ErrUnexpectedEOF},
} {
r := bytes.NewReader([]byte(test.input))
_, err := r.Seek(test.start, io.SeekStart)
if err != nil {
panic(err)
}
name, err := readName(r)
if err == io.EOF {
err = io.ErrUnexpectedEOF
}
if err != test.err {
t.Errorf("%+q returned (%+q, %v), expected %v", test.input, name, err, test.err)
continue
}
}
}
func mustParseName(s string) Name {
name, err := ParseName(s)
if err != nil {
panic(err)
}
return name
}
func questionsEqual(a, b *Question) bool {
if !namesEqual(a.Name, b.Name) {
return false
}
if a.Type != b.Type || a.Class != b.Class {
return false
}
return true
}
func rrsEqual(a, b *RR) bool {
if !namesEqual(a.Name, b.Name) {
return false
}
if a.Type != b.Type || a.Class != b.Class || a.TTL != b.TTL {
return false
}
if !bytes.Equal(a.Data, b.Data) {
return false
}
return true
}
func messagesEqual(a, b *Message) bool {
if a.ID != b.ID || a.Flags != b.Flags {
return false
}
if len(a.Question) != len(b.Question) {
return false
}
for i := 0; i < len(a.Question); i++ {
if !questionsEqual(&a.Question[i], &b.Question[i]) {
return false
}
}
for _, rec := range []struct{ rrA, rrB []RR }{
{a.Answer, b.Answer},
{a.Authority, b.Authority},
{a.Additional, b.Additional},
} {
if len(rec.rrA) != len(rec.rrB) {
return false
}
for i := 0; i < len(rec.rrA); i++ {
if !rrsEqual(&rec.rrA[i], &rec.rrB[i]) {
return false
}
}
}
return true
}
func TestMessageFromWireFormat(t *testing.T) {
for _, test := range []struct {
buf string
expected Message
err error
}{
{
"\x12\x34",
Message{},
io.ErrUnexpectedEOF,
},
{
"\x12\x34\x01\x00\x00\x01\x00\x00\x00\x00\x00\x00\x03www\x07example\x03com\x00\x00\x01\x00\x01",
Message{
ID: 0x1234,
Flags: 0x0100,
Question: []Question{
{
Name: mustParseName("www.example.com"),
Type: 1,
Class: 1,
},
},
Answer: []RR{},
Authority: []RR{},
Additional: []RR{},
},
nil,
},
{
"\x12\x34\x01\x00\x00\x01\x00\x00\x00\x00\x00\x00\x03www\x07example\x03com\x00\x00\x01\x00\x01X",
Message{},
ErrTrailingBytes,
},
{
"\x12\x34\x81\x80\x00\x01\x00\x01\x00\x00\x00\x00\x03www\x07example\x03com\x00\x00\x01\x00\x01\x03www\x07example\x03com\x00\x00\x01\x00\x01\x00\x00\x00\x80\x00\x04\xc0\x00\x02\x01",
Message{
ID: 0x1234,
Flags: 0x8180,
Question: []Question{
{
Name: mustParseName("www.example.com"),
Type: 1,
Class: 1,
},
},
Answer: []RR{
{
Name: mustParseName("www.example.com"),
Type: 1,
Class: 1,
TTL: 128,
Data: []byte{192, 0, 2, 1},
},
},
Authority: []RR{},
Additional: []RR{},
},
nil,
},
} {
message, err := MessageFromWireFormat([]byte(test.buf))
if err != test.err || (err == nil && !messagesEqual(&message, &test.expected)) {
t.Errorf("%+q\nreturned (%+v, %v)\nexpected (%+v, %v)",
test.buf, message, err, test.expected, test.err)
continue
}
}
}
func TestMessageWireFormatRoundTrip(t *testing.T) {
for _, message := range []Message{
{
ID: 0x1234,
Flags: 0x0100,
Question: []Question{
{
Name: mustParseName("www.example.com"),
Type: 1,
Class: 1,
},
{
Name: mustParseName("www2.example.com"),
Type: 2,
Class: 2,
},
},
Answer: []RR{
{
Name: mustParseName("abc"),
Type: 2,
Class: 3,
TTL: 0xffffffff,
Data: []byte{1},
},
{
Name: mustParseName("xyz"),
Type: 2,
Class: 3,
TTL: 255,
Data: []byte{},
},
},
Authority: []RR{
{
Name: mustParseName("."),
Type: 65535,
Class: 65535,
TTL: 0,
Data: []byte("XXXXXXXXXXXXXXXXXXX"),
},
},
Additional: []RR{},
},
} {
buf, err := message.WireFormat()
if err != nil {
t.Errorf("%+v cannot make wire format: %v", message, err)
continue
}
message2, err := MessageFromWireFormat(buf)
if err != nil {
t.Errorf("%+q cannot parse wire format: %v", buf, err)
continue
}
if !messagesEqual(&message, &message2) {
t.Errorf("messages unequal\nbefore: %+v\n after: %+v", message, message2)
continue
}
}
}
func TestDecodeRDataTXT(t *testing.T) {
for _, test := range []struct {
p []byte
decoded []byte
err error
}{
{[]byte{}, nil, io.ErrUnexpectedEOF},
{[]byte("\x00"), []byte{}, nil},
{[]byte("\x01"), nil, io.ErrUnexpectedEOF},
} {
decoded, err := DecodeRDataTXT(test.p)
if err != test.err || (err == nil && !bytes.Equal(decoded, test.decoded)) {
t.Errorf("%+q\nreturned (%+q, %v)\nexpected (%+q, %v)",
test.p, decoded, err, test.decoded, test.err)
continue
}
}
}
func TestEncodeRDataTXT(t *testing.T) {
// Encoding 0 bytes needs to return at least a single length octet of
// zero, not an empty slice.
p := make([]byte, 0)
encoded := EncodeRDataTXT(p)
if len(encoded) < 0 {
t.Errorf("EncodeRDataTXT(%v) returned %v", p, encoded)
}
// 255 bytes should be able to be encoded into 256 bytes.
p = make([]byte, 255)
encoded = EncodeRDataTXT(p)
if len(encoded) > 256 {
t.Errorf("EncodeRDataTXT(%d bytes) returned %d bytes", len(p), len(encoded))
}
}
func TestRDataTXTRoundTrip(t *testing.T) {
for _, p := range [][]byte{
{},
[]byte("\x00"),
{
0x00, 0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, 0x08, 0x09, 0x0a, 0x0b, 0x0c, 0x0d, 0x0e, 0x0f,
0x10, 0x11, 0x12, 0x13, 0x14, 0x15, 0x16, 0x17, 0x18, 0x19, 0x1a, 0x1b, 0x1c, 0x1d, 0x1e, 0x1f,
0x20, 0x21, 0x22, 0x23, 0x24, 0x25, 0x26, 0x27, 0x28, 0x29, 0x2a, 0x2b, 0x2c, 0x2d, 0x2e, 0x2f,
0x30, 0x31, 0x32, 0x33, 0x34, 0x35, 0x36, 0x37, 0x38, 0x39, 0x3a, 0x3b, 0x3c, 0x3d, 0x3e, 0x3f,
0x40, 0x41, 0x42, 0x43, 0x44, 0x45, 0x46, 0x47, 0x48, 0x49, 0x4a, 0x4b, 0x4c, 0x4d, 0x4e, 0x4f,
0x50, 0x51, 0x52, 0x53, 0x54, 0x55, 0x56, 0x57, 0x58, 0x59, 0x5a, 0x5b, 0x5c, 0x5d, 0x5e, 0x5f,
0x60, 0x61, 0x62, 0x63, 0x64, 0x65, 0x66, 0x67, 0x68, 0x69, 0x6a, 0x6b, 0x6c, 0x6d, 0x6e, 0x6f,
0x70, 0x71, 0x72, 0x73, 0x74, 0x75, 0x76, 0x77, 0x78, 0x79, 0x7a, 0x7b, 0x7c, 0x7d, 0x7e, 0x7f,
0x80, 0x81, 0x82, 0x83, 0x84, 0x85, 0x86, 0x87, 0x88, 0x89, 0x8a, 0x8b, 0x8c, 0x8d, 0x8e, 0x8f,
0x90, 0x91, 0x92, 0x93, 0x94, 0x95, 0x96, 0x97, 0x98, 0x99, 0x9a, 0x9b, 0x9c, 0x9d, 0x9e, 0x9f,
0xa0, 0xa1, 0xa2, 0xa3, 0xa4, 0xa5, 0xa6, 0xa7, 0xa8, 0xa9, 0xaa, 0xab, 0xac, 0xad, 0xae, 0xaf,
0xb0, 0xb1, 0xb2, 0xb3, 0xb4, 0xb5, 0xb6, 0xb7, 0xb8, 0xb9, 0xba, 0xbb, 0xbc, 0xbd, 0xbe, 0xbf,
0xc0, 0xc1, 0xc2, 0xc3, 0xc4, 0xc5, 0xc6, 0xc7, 0xc8, 0xc9, 0xca, 0xcb, 0xcc, 0xcd, 0xce, 0xcf,
0xd0, 0xd1, 0xd2, 0xd3, 0xd4, 0xd5, 0xd6, 0xd7, 0xd8, 0xd9, 0xda, 0xdb, 0xdc, 0xdd, 0xde, 0xdf,
0xe0, 0xe1, 0xe2, 0xe3, 0xe4, 0xe5, 0xe6, 0xe7, 0xe8, 0xe9, 0xea, 0xeb, 0xec, 0xed, 0xee, 0xef,
0xf0, 0xf1, 0xf2, 0xf3, 0xf4, 0xf5, 0xf6, 0xf7, 0xf8, 0xf9, 0xfa, 0xfb, 0xfc, 0xfd, 0xfe, 0xff,
},
} {
rdata := EncodeRDataTXT(p)
decoded, err := DecodeRDataTXT(rdata)
if err != nil || !bytes.Equal(decoded, p) {
t.Errorf("%+q returned (%+q, %v)", p, decoded, err)
continue
}
}
}

View File

@@ -0,0 +1,23 @@
//go:build gofuzz
// +build gofuzz
// Fuzzing driver for https://github.com/dvyukov/go-fuzz.
// go get -u github.com/dvyukov/go-fuzz/go-fuzz github.com/dvyukov/go-fuzz/go-fuzz-build
// $GOPATH/bin/go-fuzz-build
// $GOPATH/bin/go-fuzz
//
// Related link: https://blog.cloudflare.com/dns-parser-meet-go-fuzzer/
package dns
func Fuzz(data []byte) int {
msg, err := MessageFromWireFormat(data)
if err != nil {
return 0
}
_, err = msg.WireFormat()
if err != nil {
panic(err)
}
return 1 // prioritize this input
}

View File

@@ -0,0 +1,276 @@
// Package noise provides a net.Conn-like interface for a
// Noise_NK_25519_ChaChaPoly_BLAKE2s. It encodes Noise messages onto a reliable
// stream using 16-bit length prefixes.
//
// https://noiseprotocol.org/noise.html
package noise
import (
"bufio"
"crypto/rand"
"encoding/binary"
"encoding/hex"
"errors"
"fmt"
"io"
"strings"
"github.com/flynn/noise"
"golang.org/x/crypto/curve25519"
)
// The length of public and private keys as returned by GeneratePrivkey.
const KeyLen = 32
// cipherSuite represents 25519_ChaChaPoly_BLAKE2s.
var cipherSuite = noise.NewCipherSuite(noise.DH25519, noise.CipherChaChaPoly, noise.HashBLAKE2s)
// readMessage reads a length-prefixed message from r. It returns a nil error
// only when a complete message was read. It returns io.EOF only when there were
// 0 bytes remaining to read from r. It returns io.ErrUnexpectedEOF when EOF
// occurs in the middle of an encoded message.
func readMessage(r io.Reader) ([]byte, error) {
var length uint16
err := binary.Read(r, binary.BigEndian, &length)
if err != nil {
// We may return a real io.EOF only here.
return nil, err
}
msg := make([]byte, int(length))
_, err = io.ReadFull(r, msg)
// Here we must change io.EOF to io.ErrUnexpectedEOF.
if err == io.EOF {
err = io.ErrUnexpectedEOF
}
return msg, err
}
// writeMessage writes msg as a length-prefixed message to w. It panics if the
// length of msg cannot be represented in 16 bits.
func writeMessage(w io.Writer, msg []byte) error {
length := uint16(len(msg))
if int(length) != len(msg) {
panic(len(msg))
}
err := binary.Write(w, binary.BigEndian, length)
if err != nil {
return err
}
_, err = w.Write(msg)
return err
}
// socket is the internal type that represents a Noise-wrapped
// io.ReadWriteCloser.
type socket struct {
recvPipe *io.PipeReader
sendCipher *noise.CipherState
io.ReadWriteCloser
}
func newSocket(rwc io.ReadWriteCloser, recvCipher, sendCipher *noise.CipherState) *socket {
pr, pw := io.Pipe()
// This loop calls readMessage, decrypts the messages, and feeds them
// into recvPipe where they will be returned from Read.
go func() (err error) {
defer func() {
pw.CloseWithError(err)
}()
for {
msg, err := readMessage(rwc)
if err != nil {
return err
}
p, err := recvCipher.Decrypt(nil, nil, msg)
if err != nil {
return err
}
_, err = pw.Write(p)
if err != nil {
return err
}
}
}()
return &socket{
sendCipher: sendCipher,
recvPipe: pr,
ReadWriteCloser: rwc,
}
}
// Read reads decrypted data from the wrapped io.Reader.
func (s *socket) Read(p []byte) (int, error) {
return s.recvPipe.Read(p)
}
// Write writes encrypted data from the wrapped io.Writer.
func (s *socket) Write(p []byte) (int, error) {
total := 0
for len(p) > 0 {
n := len(p)
if n > 4096 {
n = 4096
}
msg, err := s.sendCipher.Encrypt(nil, nil, p[:n])
if err != nil {
return total, err
}
err = writeMessage(s.ReadWriteCloser, msg)
if err != nil {
return total, err
}
total += n
p = p[n:]
}
return total, nil
}
// newConfig instantiates configuration settings that are common to clients and
// servers.
func newConfig() noise.Config {
return noise.Config{
CipherSuite: cipherSuite,
Pattern: noise.HandshakeNK,
Prologue: []byte("dnstt 2020-04-13"),
}
}
// NewClient wraps an io.ReadWriteCloser in a Noise protocol as a client, and
// returns after completing the handshake. It returns a non-nil error if there
// is an error during the handshake.
func NewClient(rwc io.ReadWriteCloser, serverPubkey []byte) (io.ReadWriteCloser, error) {
config := newConfig()
config.Initiator = true
config.PeerStatic = serverPubkey
handshakeState, err := noise.NewHandshakeState(config)
if err != nil {
return nil, err
}
// -> e, es
msg, _, _, err := handshakeState.WriteMessage(nil, nil)
if err != nil {
return nil, err
}
err = writeMessage(rwc, msg)
if err != nil {
return nil, err
}
// <- e, es
msg, err = readMessage(rwc)
if err != nil {
return nil, err
}
payload, sendCipher, recvCipher, err := handshakeState.ReadMessage(nil, msg)
if err != nil {
return nil, err
}
if len(payload) != 0 {
return nil, errors.New("unexpected server payload")
}
return newSocket(rwc, recvCipher, sendCipher), nil
}
// NewClient wraps an io.ReadWriteCloser in a Noise protocol as a server, and
// returns after completing the handshake. It returns a non-nil error if there
// is an error during the handshake.
func NewServer(rwc io.ReadWriteCloser, serverPrivkey []byte) (io.ReadWriteCloser, error) {
config := newConfig()
config.Initiator = false
config.StaticKeypair = noise.DHKey{
Private: serverPrivkey,
Public: PubkeyFromPrivkey(serverPrivkey),
}
handshakeState, err := noise.NewHandshakeState(config)
if err != nil {
return nil, err
}
// -> e, es
msg, err := readMessage(rwc)
if err != nil {
return nil, err
}
payload, _, _, err := handshakeState.ReadMessage(nil, msg)
if err != nil {
return nil, err
}
if len(payload) != 0 {
return nil, errors.New("unexpected client payload")
}
// <- e, es
msg, recvCipher, sendCipher, err := handshakeState.WriteMessage(nil, nil)
if err != nil {
return nil, err
}
err = writeMessage(rwc, msg)
if err != nil {
return nil, err
}
return newSocket(rwc, recvCipher, sendCipher), nil
}
// GeneratePrivkey generates a private key. The corresponding public key can be
// derived using PubkeyFromPrivkey.
func GeneratePrivkey() ([]byte, error) {
pair, err := noise.DH25519.GenerateKeypair(rand.Reader)
return pair.Private, err
}
// PubkeyFromPrivkey returns the public key that corresponds to privkey.
func PubkeyFromPrivkey(privkey []byte) []byte {
pubkey, err := curve25519.X25519(privkey, curve25519.Basepoint)
if err != nil {
panic(err)
}
return pubkey
}
// ReadKey reads a hex-encoded key from r. r must consist of a single line, with
// or without a '\n' line terminator. The line must consist of KeyLen
// hex-encoded bytes.
func ReadKey(r io.Reader) ([]byte, error) {
br := bufio.NewReader(io.LimitReader(r, 100))
line, err := br.ReadString('\n')
if err == io.EOF {
err = nil
}
if err == nil {
// Check that we're at EOF.
_, err = br.ReadByte()
if err == io.EOF {
err = nil
} else if err == nil {
err = fmt.Errorf("file contains more than one line")
}
}
if err != nil {
return nil, err
}
line = strings.TrimSuffix(line, "\n")
return DecodeKey(line)
}
// WriteKey writes the hex-encoded key in a single line to w.
func WriteKey(w io.Writer, key []byte) error {
_, err := fmt.Fprintf(w, "%x\n", key)
return err
}
// DecodeKey decodes a hex-encoded private or public key.
func DecodeKey(s string) ([]byte, error) {
key, err := hex.DecodeString(s)
if err == nil && len(key) != KeyLen {
err = fmt.Errorf("length is %d, expected %d", len(key), KeyLen)
}
return key, err
}
// EncodeKey encodes a hex-encoded private or public key.
func EncodeKey(key []byte) string {
return hex.EncodeToString(key)
}

View File

@@ -0,0 +1,218 @@
package noise
import (
"bytes"
"io"
"net"
"testing"
"github.com/flynn/noise"
)
func allMessages(buf []byte) ([][]byte, error) {
var messages [][]byte
r := bytes.NewReader(buf)
for {
msg, err := readMessage(r)
if err != nil {
return messages, err
}
messages = append(messages, msg)
}
}
func messagesEqual(a, b [][]byte) bool {
if len(a) != len(b) {
return false
}
for i := range a {
if !bytes.Equal(a[i], b[i]) {
return false
}
}
return true
}
func TestReadMessage(t *testing.T) {
for _, test := range []struct {
input string
messages [][]byte
err error
}{
{"", [][]byte{}, io.EOF},
{"\x00", [][]byte{}, io.ErrUnexpectedEOF},
{"\x00\x00", [][]byte{{}}, io.EOF},
{"\x00\x00\x00", [][]byte{{}}, io.ErrUnexpectedEOF},
{"\x00\x01", [][]byte{}, io.ErrUnexpectedEOF},
{"\x00\x05hello\x00\x05world", [][]byte{[]byte("hello"), []byte("world")}, io.EOF},
} {
packets, err := allMessages([]byte(test.input))
if !messagesEqual(packets, test.messages) || err != test.err {
t.Errorf("%x\nreturned %x %v\nexpected %x %v",
test.input, packets, err, test.messages, test.err)
}
}
}
func TestMessageRoundTrip(t *testing.T) {
for _, messages := range [][][]byte{
{},
} {
var buf bytes.Buffer
for _, msg := range messages {
err := writeMessage(&buf, msg)
if err != nil {
panic(err)
}
}
output, err := allMessages(buf.Bytes())
if !messagesEqual(output, messages) || err != io.EOF {
t.Errorf("%x roundtripped to %x %v",
messages, output, err)
}
}
}
func TestReadKey(t *testing.T) {
for _, test := range []struct {
input string
output []byte
}{
{"", nil},
{"0123456789abcdef0123456789abcdef0123456789abcdef0123456789abcde", nil},
{"0123456789abcdef0123456789abcdef0123456789abcdef0123456789abcdef", []byte("\x01\x23\x45\x67\x89\xab\xcd\xef\x01\x23\x45\x67\x89\xab\xcd\xef\x01\x23\x45\x67\x89\xab\xcd\xef\x01\x23\x45\x67\x89\xab\xcd\xef")},
{"0123456789abcdef0123456789abcdef0123456789abcdef0123456789abcdef\n", []byte("\x01\x23\x45\x67\x89\xab\xcd\xef\x01\x23\x45\x67\x89\xab\xcd\xef\x01\x23\x45\x67\x89\xab\xcd\xef\x01\x23\x45\x67\x89\xab\xcd\xef")},
{"0123456789abcdef0123456789abcdef0123456789abcdef0123456789abcdef0", nil},
{"0123456789abcdef0123456789abcdef0123456789abcdef0123456789abcdef\nX", nil},
{"0123456789abcdef0123456789abcdef0123456789abcdef0123456789abcdef\n\n", nil},
{"\n0123456789abcdef0123456789abcdef0123456789abcdef0123456789abcdef", nil},
{"X123456789abcdef0123456789abcdef0123456789abcdef0123456789abcdef", nil},
{"aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa", nil},
} {
output, err := ReadKey(bytes.NewReader([]byte(test.input)))
if test.output == nil {
if err == nil {
t.Errorf("%+q expected error", test.input)
}
} else {
if err != nil {
t.Errorf("%+q returned error %v", test.input, err)
} else if !bytes.Equal(output, test.output) {
t.Errorf("%+q got %x, expected %x", test.input, output, test.output)
}
}
}
}
func TestUnexpectedPayload(t *testing.T) {
privkey, err := GeneratePrivkey()
if err != nil {
panic(err)
}
pubkey := PubkeyFromPrivkey(privkey)
// Test the client sending an unexpected payload.
clientWithPayload := func(rwc io.ReadWriteCloser) error {
config := newConfig()
config.Initiator = true
config.PeerStatic = pubkey
handshakeState, err := noise.NewHandshakeState(config)
if err != nil {
return err
}
// -> e, es
msg, _, _, err := handshakeState.WriteMessage(nil, []byte("payload"))
if err != nil {
return err
}
err = writeMessage(rwc, msg)
if err != nil {
return err
}
// <- e, es
// Return nil for all errors after this point, because we expect
// the server to have failed, but we want to keep up the game
// just in case the server did not fail.
msg, err = readMessage(rwc)
if err != nil {
return nil
}
_, _, _, err = handshakeState.ReadMessage(nil, msg)
if err != nil {
return nil
}
return nil
}
func() {
c, s := net.Pipe()
defer s.Close()
// Fake a client side that sends a payload.
go func() {
defer c.Close()
err := clientWithPayload(c)
if err != nil {
t.Fatal(err)
}
}()
server, err := NewServer(s, privkey)
if err == nil || err.Error() != "unexpected client payload" || server != nil {
t.Errorf("NewServer got (%T, %v)", server, err)
}
}()
// Test the server sending an unexpected payload.
serverWithPayload := func(rwc io.ReadWriteCloser) error {
config := newConfig()
config.Initiator = false
config.StaticKeypair = noise.DHKey{Private: privkey, Public: pubkey}
handshakeState, err := noise.NewHandshakeState(config)
if err != nil {
return err
}
// -> e, es
msg, err := readMessage(rwc)
if err != nil {
return err
}
_, _, _, err = handshakeState.ReadMessage(nil, msg)
if err != nil {
return err
}
// <- e, es
msg, _, _, err = handshakeState.WriteMessage(nil, []byte("payload"))
if err != nil {
return err
}
err = writeMessage(rwc, msg)
if err != nil {
return err
}
return nil
}
func() {
c, s := net.Pipe()
defer c.Close()
// Fake a server side that sends a payload.
go func() {
defer s.Close()
err := serverWithPayload(s)
if err != nil {
t.Fatal(err)
}
}()
client, err := NewClient(c, pubkey)
if err == nil || err.Error() != "unexpected server payload" || client != nil {
t.Errorf("NewClient got (%T, %v)", client, err)
}
}()
}

View File

@@ -0,0 +1,28 @@
package turbotunnel
import (
"crypto/rand"
"encoding/hex"
)
// ClientID is an abstract identifier that binds together all the communications
// belonging to a single client session, even though those communications may
// arrive from multiple IP addresses or over multiple lower-level connections.
// It plays the same role that an (IP address, port number) tuple plays in a
// net.UDPConn: it's the return address pertaining to a long-lived abstract
// client session. The client attaches its ClientID to each of its
// communications, enabling the server to disambiguate requests among its many
// clients. ClientID implements the net.Addr interface.
type ClientID [8]byte
func NewClientID() ClientID {
var id ClientID
_, err := rand.Read(id[:])
if err != nil {
panic(err)
}
return id
}
func (id ClientID) Network() string { return "clientid" }
func (id ClientID) String() string { return hex.EncodeToString(id[:]) }

View File

@@ -0,0 +1,22 @@
// Package turbotunnel is facilities for embedding packet-based reliability
// protocols inside other protocols.
//
// https://github.com/net4people/bbs/issues/9
package turbotunnel
import "errors"
// QueueSize is the size of send and receive queues in QueuePacketConn and
// RemoteMap.
const QueueSize = 128
var errClosedPacketConn = errors.New("operation on closed connection")
var errNotImplemented = errors.New("not implemented")
// DummyAddr is a placeholder net.Addr, for when a programming interface
// requires a net.Addr but there is none relevant. All DummyAddrs compare equal
// to each other.
type DummyAddr struct{}
func (addr DummyAddr) Network() string { return "dummy" }
func (addr DummyAddr) String() string { return "dummy" }

View File

@@ -0,0 +1,162 @@
package turbotunnel
import (
"net"
"sync"
"sync/atomic"
"time"
)
// taggedPacket is a combination of a []byte and a net.Addr, encapsulating the
// return type of PacketConn.ReadFrom.
type taggedPacket struct {
P []byte
Addr net.Addr
}
// QueuePacketConn implements net.PacketConn by storing queues of packets. There
// is one incoming queue (where packets are additionally tagged by the source
// address of the peer that sent them). There are many outgoing queues, one for
// each remote peer address that has been recently seen. The QueueIncoming
// method inserts a packet into the incoming queue, to eventually be returned by
// ReadFrom. WriteTo inserts a packet into an address-specific outgoing queue,
// which can later by accessed through the OutgoingQueue method.
//
// Besides the outgoing queues, there is also a one-element "stash" for each
// remote peer address. You can stash a packet using the Stash method, and get
// it back later by receiving from the channel returned by Unstash. The stash is
// meant as a convenient place to temporarily store a single packet, such as
// when you've read one too many packets from the send queue and need to store
// the extra packet to be processed first in the next pass. It's the caller's
// responsibility to Unstash what they have Stashed. Calling Stash does not put
// the packet at the head of the send queue; if there is the possibility that a
// packet has been stashed, it must be checked for by calling Unstash in
// addition to OutgoingQueue.
type QueuePacketConn struct {
remotes *RemoteMap
localAddr net.Addr
recvQueue chan taggedPacket
closeOnce sync.Once
closed chan struct{}
// What error to return when the QueuePacketConn is closed.
err atomic.Value
}
// NewQueuePacketConn makes a new QueuePacketConn, set to track recent peers
// for at least a duration of timeout.
func NewQueuePacketConn(localAddr net.Addr, timeout time.Duration) *QueuePacketConn {
return &QueuePacketConn{
remotes: NewRemoteMap(timeout),
localAddr: localAddr,
recvQueue: make(chan taggedPacket, QueueSize),
closed: make(chan struct{}),
}
}
// QueueIncoming queues and incoming packet and its source address, to be
// returned in a future call to ReadFrom.
func (c *QueuePacketConn) QueueIncoming(p []byte, addr net.Addr) {
select {
case <-c.closed:
// If we're closed, silently drop it.
return
default:
}
// Copy the slice so that the caller may reuse it.
buf := make([]byte, len(p))
copy(buf, p)
select {
case c.recvQueue <- taggedPacket{buf, addr}:
default:
// Drop the incoming packet if the receive queue is full.
}
}
// OutgoingQueue returns the queue of outgoing packets corresponding to addr,
// creating it if necessary. The contents of the queue will be packets that are
// written to the address in question using WriteTo.
func (c *QueuePacketConn) OutgoingQueue(addr net.Addr) <-chan []byte {
return c.remotes.SendQueue(addr)
}
// Stash places p in the stash for addr, if the stash is not already occupied.
// Returns true if the packet was placed in the stash, or false if the stash was
// already occupied. This method is similar to WriteTo, except that it puts the
// packet in the stash queue (accessible via Unstash), rather than the outgoing
// queue (accessible via OutgoingQueue).
func (c *QueuePacketConn) Stash(p []byte, addr net.Addr) bool {
return c.remotes.Stash(addr, p)
}
// Unstash returns the channel that represents the stash for addr.
func (c *QueuePacketConn) Unstash(addr net.Addr) <-chan []byte {
return c.remotes.Unstash(addr)
}
// ReadFrom returns a packet and address previously stored by QueueIncoming.
func (c *QueuePacketConn) ReadFrom(p []byte) (int, net.Addr, error) {
select {
case <-c.closed:
return 0, nil, &net.OpError{Op: "read", Net: c.LocalAddr().Network(), Addr: c.LocalAddr(), Err: c.err.Load().(error)}
default:
}
select {
case <-c.closed:
return 0, nil, &net.OpError{Op: "read", Net: c.LocalAddr().Network(), Addr: c.LocalAddr(), Err: c.err.Load().(error)}
case packet := <-c.recvQueue:
return copy(p, packet.P), packet.Addr, nil
}
}
// WriteTo queues an outgoing packet for the given address. The queue can later
// be retrieved using the OutgoingQueue method.
func (c *QueuePacketConn) WriteTo(p []byte, addr net.Addr) (int, error) {
select {
case <-c.closed:
return 0, &net.OpError{Op: "write", Net: c.LocalAddr().Network(), Addr: c.LocalAddr(), Err: c.err.Load().(error)}
default:
}
// Copy the slice so that the caller may reuse it.
buf := make([]byte, len(p))
copy(buf, p)
select {
case c.remotes.SendQueue(addr) <- buf:
return len(buf), nil
default:
// Drop the outgoing packet if the send queue is full.
return len(buf), nil
}
}
// closeWithError unblocks pending operations and makes future operations fail
// with the given error. If err is nil, it becomes errClosedPacketConn.
func (c *QueuePacketConn) closeWithError(err error) error {
var newlyClosed bool
c.closeOnce.Do(func() {
newlyClosed = true
// Store the error to be returned by future PacketConn
// operations.
if err == nil {
err = errClosedPacketConn
}
c.err.Store(err)
close(c.closed)
})
if !newlyClosed {
return &net.OpError{Op: "close", Net: c.LocalAddr().Network(), Addr: c.LocalAddr(), Err: c.err.Load().(error)}
}
return nil
}
// Close unblocks pending operations and makes future operations fail with a
// "closed connection" error.
func (c *QueuePacketConn) Close() error {
return c.closeWithError(nil)
}
// LocalAddr returns the localAddr value that was passed to NewQueuePacketConn.
func (c *QueuePacketConn) LocalAddr() net.Addr { return c.localAddr }
func (c *QueuePacketConn) SetDeadline(t time.Time) error { return errNotImplemented }
func (c *QueuePacketConn) SetReadDeadline(t time.Time) error { return errNotImplemented }
func (c *QueuePacketConn) SetWriteDeadline(t time.Time) error { return errNotImplemented }

View File

@@ -0,0 +1,177 @@
package turbotunnel
import (
"container/heap"
"net"
"sync"
"time"
)
// remoteRecord is a record of a recently seen remote peer, with the time it was
// last seen and queues of outgoing packets.
type remoteRecord struct {
Addr net.Addr
LastSeen time.Time
SendQueue chan []byte
Stash chan []byte
}
// RemoteMap manages a mapping of live remote peers, keyed by address, to their
// respective send queues. Each peer has two queues: a primary send queue, and a
// "stash". The primary send queue is returned by the SendQueue method. The
// stash is an auxiliary one-element queue accessed using the Stash and Unstash
// methods. The stash is meant for use by callers that need to "unread" a packet
// that's already been removed from the primary send queue.
//
// RemoteMap's functions are safe to call from multiple goroutines.
type RemoteMap struct {
// We use an inner structure to avoid exposing public heap.Interface
// functions to users of remoteMap.
inner remoteMapInner
// Synchronizes access to inner.
lock sync.Mutex
}
// NewRemoteMap creates a RemoteMap that expires peers after a timeout.
//
// If the timeout is 0, peers never expire.
//
// The timeout does not have to be kept in sync with smux's idle timeout. If a
// peer is removed from the map while the smux session is still live, the worst
// that can happen is a loss of whatever packets were in the send queue at the
// time. If smux later decides to send more packets to the same peer, we'll
// instantiate a new send queue, and if the peer is ever seen again with a
// matching address, we'll deliver them.
func NewRemoteMap(timeout time.Duration) *RemoteMap {
m := &RemoteMap{
inner: remoteMapInner{
byAge: make([]*remoteRecord, 0),
byAddr: make(map[net.Addr]int),
},
}
if timeout > 0 {
go func() {
for {
time.Sleep(timeout / 2)
now := time.Now()
m.lock.Lock()
m.inner.removeExpired(now, timeout)
m.lock.Unlock()
}
}()
}
return m
}
// SendQueue returns the send queue corresponding to addr, creating it if
// necessary.
func (m *RemoteMap) SendQueue(addr net.Addr) chan []byte {
m.lock.Lock()
defer m.lock.Unlock()
return m.inner.Lookup(addr, time.Now()).SendQueue
}
// Stash places p in the stash corresponding to addr, if the stash is not
// already occupied. Returns true if the p was placed in the stash, false
// otherwise.
func (m *RemoteMap) Stash(addr net.Addr, p []byte) bool {
m.lock.Lock()
defer m.lock.Unlock()
select {
case m.inner.Lookup(addr, time.Now()).Stash <- p:
return true
default:
return false
}
}
// Unstash returns the channel that reads from the stash for addr.
func (m *RemoteMap) Unstash(addr net.Addr) <-chan []byte {
m.lock.Lock()
defer m.lock.Unlock()
return m.inner.Lookup(addr, time.Now()).Stash
}
// remoteMapInner is the inner type of RemoteMap, implementing heap.Interface.
// byAge is the backing store, a heap ordered by LastSeen time, to facilitate
// expiring old records. byAddr is a map from addresses to heap indices, to
// allow looking up by address. Unlike RemoteMap, remoteMapInner requires
// external synchonization.
type remoteMapInner struct {
byAge []*remoteRecord
byAddr map[net.Addr]int
}
// removeExpired removes all records whose LastSeen timestamp is more than
// timeout in the past.
func (inner *remoteMapInner) removeExpired(now time.Time, timeout time.Duration) {
for len(inner.byAge) > 0 && now.Sub(inner.byAge[0].LastSeen) >= timeout {
record := heap.Pop(inner).(*remoteRecord)
close(record.SendQueue)
}
}
// Lookup finds the existing record corresponding to addr, or creates a new
// one if none exists yet. It updates the record's LastSeen time and returns the
// record.
func (inner *remoteMapInner) Lookup(addr net.Addr, now time.Time) *remoteRecord {
var record *remoteRecord
i, ok := inner.byAddr[addr]
if ok {
// Found one, update its LastSeen.
record = inner.byAge[i]
record.LastSeen = now
heap.Fix(inner, i)
} else {
// Not found, create a new one.
record = &remoteRecord{
Addr: addr,
LastSeen: now,
SendQueue: make(chan []byte, QueueSize),
Stash: make(chan []byte, 1),
}
heap.Push(inner, record)
}
return record
}
// heap.Interface for remoteMapInner.
func (inner *remoteMapInner) Len() int {
if len(inner.byAge) != len(inner.byAddr) {
panic("inconsistent remoteMap")
}
return len(inner.byAge)
}
func (inner *remoteMapInner) Less(i, j int) bool {
return inner.byAge[i].LastSeen.Before(inner.byAge[j].LastSeen)
}
func (inner *remoteMapInner) Swap(i, j int) {
inner.byAge[i], inner.byAge[j] = inner.byAge[j], inner.byAge[i]
inner.byAddr[inner.byAge[i].Addr] = i
inner.byAddr[inner.byAge[j].Addr] = j
}
func (inner *remoteMapInner) Push(x interface{}) {
record := x.(*remoteRecord)
if _, ok := inner.byAddr[record.Addr]; ok {
panic("duplicate address in remoteMap")
}
// Insert into byAddr map.
inner.byAddr[record.Addr] = len(inner.byAge)
// Insert into byAge slice.
inner.byAge = append(inner.byAge, record)
}
func (inner *remoteMapInner) Pop() interface{} {
n := len(inner.byAddr)
// Remove from byAge slice.
record := inner.byAge[n-1]
inner.byAge[n-1] = nil
inner.byAge = inner.byAge[:n-1]
// Remove from byAddr map.
delete(inner.byAddr, record.Addr)
return record
}