219 lines
5.4 KiB
Go
219 lines
5.4 KiB
Go
package noise
|
|
|
|
import (
|
|
"bytes"
|
|
"io"
|
|
"net"
|
|
"testing"
|
|
|
|
"github.com/flynn/noise"
|
|
)
|
|
|
|
// allMessages decodes the length-prefixed messages in buf, in order, by
// calling readMessage repeatedly on a reader over buf. It stops at the first
// error readMessage reports and returns every message decoded before that
// point together with that error. The loop only exits on an error, so the
// returned error is never nil.
func allMessages(buf []byte) ([][]byte, error) {
	r := bytes.NewReader(buf)
	var out [][]byte
	msg, err := readMessage(r)
	for err == nil {
		out = append(out, msg)
		msg, err = readMessage(r)
	}
	return out, err
}
|
|
|
|
// messagesEqual reports whether a and b contain the same number of messages
// and every pair of corresponding messages has identical bytes. Contents are
// compared with bytes.Equal, so a nil message equals an empty one. Lengths
// are compared with len, so a nil list equals an empty one.
func messagesEqual(a, b [][]byte) bool {
	same := len(a) == len(b)
	for i := 0; same && i < len(a); i++ {
		same = bytes.Equal(a[i], b[i])
	}
	return same
}
|
|
|
|
func TestReadMessage(t *testing.T) {
|
|
for _, test := range []struct {
|
|
input string
|
|
messages [][]byte
|
|
err error
|
|
}{
|
|
{"", [][]byte{}, io.EOF},
|
|
{"\x00", [][]byte{}, io.ErrUnexpectedEOF},
|
|
{"\x00\x00", [][]byte{{}}, io.EOF},
|
|
{"\x00\x00\x00", [][]byte{{}}, io.ErrUnexpectedEOF},
|
|
{"\x00\x01", [][]byte{}, io.ErrUnexpectedEOF},
|
|
{"\x00\x05hello\x00\x05world", [][]byte{[]byte("hello"), []byte("world")}, io.EOF},
|
|
} {
|
|
packets, err := allMessages([]byte(test.input))
|
|
if !messagesEqual(packets, test.messages) || err != test.err {
|
|
t.Errorf("%x\nreturned %x %v\nexpected %x %v",
|
|
test.input, packets, err, test.messages, test.err)
|
|
}
|
|
}
|
|
}
|
|
|
|
// TestMessageRoundTrip checks that a sequence of messages written one after
// another with writeMessage decodes back to the same sequence with
// allMessages, and that decoding then ends with a clean io.EOF.
//
// The cases cover:
//   - no messages at all;
//   - empty messages;
//   - ordinary payloads;
//   - lengths on either side of the low byte of the two-byte length prefix
//     (255 and 256);
//   - the largest length such a prefix can express (65535).
//
// Previously the table held only the empty sequence, so writeMessage was
// never actually called.
func TestMessageRoundTrip(t *testing.T) {
	for _, messages := range [][][]byte{
		{},
		{{}},
		{{}, {}},
		{[]byte("hello")},
		{[]byte("hello"), []byte("world")},
		{bytes.Repeat([]byte{0x55}, 255), bytes.Repeat([]byte{0xaa}, 256)},
		{bytes.Repeat([]byte{0x55}, 65535)},
	} {
		var buf bytes.Buffer
		for _, msg := range messages {
			// Report failures through t rather than panicking, so that one
			// bad case doesn't abort the whole test binary.
			if err := writeMessage(&buf, msg); err != nil {
				t.Fatalf("writeMessage(%d-byte message): %v", len(msg), err)
			}
		}
		output, err := allMessages(buf.Bytes())
		if !messagesEqual(output, messages) || err != io.EOF {
			t.Errorf("%x roundtripped to %x %v",
				messages, output, err)
		}
	}
}
|
|
|
|
// TestReadKey runs ReadKey over a table of inputs.
//   - A case whose expected value is nil must be rejected with an error.
//   - Every other case must be accepted and decode to exactly the expected
//     bytes.
//
// The accepted inputs are 64 hex digits, optionally followed by a single
// newline, and decode to 32 bytes. The rejected inputs are empty, truncated,
// overlong, newline-padded, or non-hex variants of that form.
func TestReadKey(t *testing.T) {
	// key is the 32-byte value encoded by the well-formed hex inputs below.
	key := []byte("\x01\x23\x45\x67\x89\xab\xcd\xef\x01\x23\x45\x67\x89\xab\xcd\xef\x01\x23\x45\x67\x89\xab\xcd\xef\x01\x23\x45\x67\x89\xab\xcd\xef")

	cases := []struct {
		input string
		want  []byte // nil means ReadKey must return an error
	}{
		{"", nil},
		{"0123456789abcdef0123456789abcdef0123456789abcdef0123456789abcde", nil},
		{"0123456789abcdef0123456789abcdef0123456789abcdef0123456789abcdef", key},
		{"0123456789abcdef0123456789abcdef0123456789abcdef0123456789abcdef\n", key},
		{"0123456789abcdef0123456789abcdef0123456789abcdef0123456789abcdef0", nil},
		{"0123456789abcdef0123456789abcdef0123456789abcdef0123456789abcdef\nX", nil},
		{"0123456789abcdef0123456789abcdef0123456789abcdef0123456789abcdef\n\n", nil},
		{"\n0123456789abcdef0123456789abcdef0123456789abcdef0123456789abcdef", nil},
		{"X123456789abcdef0123456789abcdef0123456789abcdef0123456789abcdef", nil},
		{"aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa", nil},
	}
	for _, tc := range cases {
		got, err := ReadKey(bytes.NewReader([]byte(tc.input)))
		expectErr := tc.want == nil
		switch {
		case expectErr && err == nil:
			t.Errorf("%+q expected error", tc.input)
		case expectErr:
			// Rejected, as required.
		case err != nil:
			t.Errorf("%+q returned error %v", tc.input, err)
		case !bytes.Equal(got, tc.want):
			t.Errorf("%+q got %x, expected %x", tc.input, got, tc.want)
		}
	}
}
|
|
|
|
// TestUnexpectedPayload checks that a handshake is aborted when the peer
// attaches a payload to its handshake message:
//   - NewServer must fail with "unexpected client payload" when the
//     initiator's first message carries one;
//   - NewClient must fail with "unexpected server payload" when the
//     responder's reply carries one.
//
// In each case the misbehaving peer is built directly with
// github.com/flynn/noise and connected over an in-memory net.Pipe.
//
// The fake peers run in their own goroutines. They hand their errors back
// over a channel instead of calling t.Fatal, because FailNow (and therefore
// Fatal) may only be called from the goroutine running the test. The test
// also waits for each goroutine to finish, so none of them can outlive the
// test and report into a completed *testing.T.
func TestUnexpectedPayload(t *testing.T) {
	privkey, err := GeneratePrivkey()
	if err != nil {
		t.Fatal(err)
	}
	pubkey := PubkeyFromPrivkey(privkey)

	// Test the client sending an unexpected payload.
	//
	// clientWithPayload plays the initiator. It puts a non-empty payload in
	// its first handshake message.
	clientWithPayload := func(rwc io.ReadWriteCloser) error {
		config := newConfig()
		config.Initiator = true
		config.PeerStatic = pubkey
		handshakeState, err := noise.NewHandshakeState(config)
		if err != nil {
			return err
		}

		// -> e, es
		msg, _, _, err := handshakeState.WriteMessage(nil, []byte("payload"))
		if err != nil {
			return err
		}
		err = writeMessage(rwc, msg)
		if err != nil {
			return err
		}

		// <- e, ee (the responder's message; assumes newConfig selects the
		// Noise NK pattern, as the "-> e, es" first message suggests).
		// Errors from here on are deliberately ignored: the server is
		// expected to have rejected the handshake already. We only keep
		// playing in case it (wrongly) did not.
		msg, err = readMessage(rwc)
		if err != nil {
			return nil
		}
		_, _, _, _ = handshakeState.ReadMessage(nil, msg)
		return nil
	}
	func() {
		c, s := net.Pipe()

		// Fake a client side that sends a payload.
		clientErr := make(chan error, 1)
		go func() {
			defer c.Close()
			clientErr <- clientWithPayload(c)
		}()

		server, err := NewServer(s, privkey)
		// Close our end before waiting. A fake client that is blocked
		// reading the server's reply (or writing into the synchronous pipe)
		// then gets an error instead of hanging forever.
		s.Close()
		if err == nil || err.Error() != "unexpected client payload" || server != nil {
			t.Errorf("NewServer got (%T, %v)", server, err)
		}
		if err := <-clientErr; err != nil {
			t.Errorf("fake client: %v", err)
		}
	}()

	// Test the server sending an unexpected payload.
	//
	// serverWithPayload plays the responder. It puts a non-empty payload in
	// its reply to the initiator's first handshake message.
	serverWithPayload := func(rwc io.ReadWriteCloser) error {
		config := newConfig()
		config.Initiator = false
		config.StaticKeypair = noise.DHKey{Private: privkey, Public: pubkey}
		handshakeState, err := noise.NewHandshakeState(config)
		if err != nil {
			return err
		}

		// -> e, es
		msg, err := readMessage(rwc)
		if err != nil {
			return err
		}
		_, _, _, err = handshakeState.ReadMessage(nil, msg)
		if err != nil {
			return err
		}

		// <- e, ee (assumes the Noise NK pattern; see clientWithPayload)
		msg, _, _, err = handshakeState.WriteMessage(nil, []byte("payload"))
		if err != nil {
			return err
		}
		err = writeMessage(rwc, msg)
		if err != nil {
			return err
		}

		return nil
	}
	func() {
		c, s := net.Pipe()

		// Fake a server side that sends a payload.
		serverErr := make(chan error, 1)
		go func() {
			defer s.Close()
			serverErr <- serverWithPayload(s)
		}()

		client, err := NewClient(c, pubkey)
		// Close our end before waiting, so that a fake server still blocked
		// on the pipe is released instead of deadlocking the test.
		c.Close()
		if err == nil || err.Error() != "unexpected server payload" || client != nil {
			t.Errorf("NewClient got (%T, %v)", client, err)
		}
		if err := <-serverErr; err != nil {
			t.Errorf("fake server: %v", err)
		}
	}()
}
|