This commit is contained in:
2026-05-16 00:18:06 -03:00
commit 92941e68a2
66 changed files with 10352 additions and 0 deletions

57
internal/app/app.go Normal file
View File

@@ -0,0 +1,57 @@
package app
import (
"os"
"path/filepath"
"socksrevivepc/internal/config"
"socksrevivepc/internal/engine"
)
type App struct {
Root string
Store *config.Store
Engine *engine.Manager
}
func New() (*App, error) {
root, err := os.Getwd()
if err != nil {
return nil, err
}
for _, d := range []string{"configs", "logs", filepath.Join("tools", "dnstt"), filepath.Join("tools", "xray"), filepath.Join("tools", "wintun")} {
if err := os.MkdirAll(filepath.Join(root, d), 0o755); err != nil {
return nil, err
}
}
store, err := config.NewStore(filepath.Join(root, "profiles"))
if err != nil {
return nil, err
}
mgr := engine.NewManager(root)
if err := removeBundledExamples(store); err != nil {
return nil, err
}
app := &App{Root: root, Store: store, Engine: mgr}
return app, nil
}
func removeBundledExamples(store *config.Store) error {
list, err := store.List()
if err != nil {
return err
}
exampleNames := map[string]bool{
"Example - Direct SSH": true,
"Example - Payload + SSL": true,
"Example - Xray Core": true,
}
for _, p := range list {
if exampleNames[p.Name] {
if err := store.Delete(p.ID); err != nil {
return err
}
}
}
return nil
}

477
internal/config/config.go Normal file
View File

@@ -0,0 +1,477 @@
package config
import (
"bytes"
"compress/gzip"
"crypto/rand"
"encoding/gob"
"encoding/hex"
"errors"
"fmt"
"io"
"net"
"os"
"path/filepath"
"runtime"
"sort"
"strings"
"time"
)
type Mode string
const (
ModeDirect Mode = "direct"
ModePayload Mode = "payload"
ModeSSL Mode = "ssl"
ModePayloadSSL Mode = "payload_ssl"
ModeDNSTT Mode = "dnstt"
ModeXray Mode = "xray"
)
const (
ProfileExtension = ".srpc"
profileMagic = "SRPC\x01"
)
type Profile struct {
ID string
Name string
Mode Mode
CreatedAt time.Time
UpdatedAt time.Time
SSH SSHConfig
Proxy ProxyConfig
Payload PayloadConfig
TLS TLSConfig
DNSTT DNSTTConfig
Xray XrayConfig
UDPGW UDPGWConfig
Reconnect ReconnectConfig
Local LocalConfig
Tun TunConfig
}
type SSHConfig struct {
Host string
Port int
Username string
Password string
KeepAliveSeconds int
HandshakeTimeoutMs int
}
type ProxyConfig struct {
Host string
Port int
}
type PayloadConfig struct {
Text string
WaitForResponse bool
AcceptAnyStatus bool
ResponseTimeoutMs int
SplitDelayMs int
}
type TLSConfig struct {
Enabled bool
Host string
Port int
ServerName string
InsecureSkipVerify bool
}
type DNSTTConfig struct {
Enabled bool
UseEmbedded bool
ResolverType string
ResolverAddress string
Domain string
PublicKey string
UTLSDistribution string
Executable string
Args []string
LocalSSHHost string
LocalSSHPort int
StartupTimeoutMs int
}
type XrayConfig struct {
Executable string
Args []string
ConfigPath string
LocalSocksHost string
LocalSocksPort int
StartupTimeoutMs int
}
type UDPGWConfig struct {
Enabled bool
Host string
Port int
Protocol string
}
type ReconnectConfig struct {
Enabled bool
DelaySeconds int
MaxRetries int
CheckIntervalSeconds int
}
type LocalConfig struct {
SocksHost string
SocksPort int
}
type TunConfig struct {
Enabled bool
Device string
InterfaceName string
MTU int
Gateway string
CIDR string
DNS []string
RouteAll bool
// IPv6 is optional because many SSH/DNSTT/UDPGW servers only have IPv4
// egress. When IPv6Enabled is false, AllowIPv6Leak controls whether the
// app should leave the normal Windows/Linux IPv6 route untouched.
IPv6Enabled bool
IPv6CIDR string
IPv6DNS []string
AllowIPv6Leak bool
}
type Store struct {
Dir string
}
func NewStore(dir string) (*Store, error) {
if err := os.MkdirAll(dir, 0o755); err != nil {
return nil, err
}
return &Store{Dir: dir}, nil
}
func (s *Store) List() ([]Profile, error) {
entries, err := os.ReadDir(s.Dir)
if err != nil {
return nil, err
}
var out []Profile
for _, e := range entries {
if e.IsDir() || !strings.HasSuffix(strings.ToLower(e.Name()), ProfileExtension) {
continue
}
p, err := s.Load(strings.TrimSuffix(e.Name(), ProfileExtension))
if err == nil {
out = append(out, p)
}
}
sort.Slice(out, func(i, j int) bool {
return strings.ToLower(out[i].Name) < strings.ToLower(out[j].Name)
})
return out, nil
}
func (s *Store) Load(id string) (Profile, error) {
var p Profile
b, err := os.ReadFile(filepath.Join(s.Dir, safeID(id)+ProfileExtension))
if err != nil {
return p, err
}
p, err = DecodeProfileFile(b)
if err != nil {
return p, err
}
ApplyDefaults(&p)
return p, nil
}
func (s *Store) Save(p Profile) (Profile, error) {
if p.ID == "" {
p.ID = randomID()
}
now := time.Now().UTC()
if p.CreatedAt.IsZero() {
p.CreatedAt = now
}
p.UpdatedAt = now
ApplyDefaults(&p)
if err := Validate(p); err != nil {
return p, err
}
b, err := EncodeProfileFile(p)
if err != nil {
return p, err
}
return p, os.WriteFile(filepath.Join(s.Dir, safeID(p.ID)+ProfileExtension), b, 0o600)
}
func (s *Store) Delete(id string) error {
return os.Remove(filepath.Join(s.Dir, safeID(id)+ProfileExtension))
}
func EncodeProfileFile(p Profile) ([]byte, error) {
ApplyDefaults(&p)
var buf bytes.Buffer
buf.WriteString(profileMagic)
gz := gzip.NewWriter(&buf)
if err := gob.NewEncoder(gz).Encode(p); err != nil {
_ = gz.Close()
return nil, err
}
if err := gz.Close(); err != nil {
return nil, err
}
return buf.Bytes(), nil
}
func DecodeProfileFile(b []byte) (Profile, error) {
var p Profile
if !bytes.HasPrefix(b, []byte(profileMagic)) {
return p, errors.New("invalid SocksRevive PC profile file")
}
gz, err := gzip.NewReader(bytes.NewReader(b[len(profileMagic):]))
if err != nil {
return p, err
}
defer gz.Close()
payload, err := io.ReadAll(gz)
if err != nil {
return p, err
}
if err := gob.NewDecoder(bytes.NewReader(payload)).Decode(&p); err != nil {
return p, err
}
ApplyDefaults(&p)
return p, Validate(p)
}
func Validate(p Profile) error {
if strings.TrimSpace(p.Name) == "" {
return errors.New("profile name is required")
}
switch p.Mode {
case ModeDirect, ModePayload, ModeSSL, ModePayloadSSL, ModeDNSTT, ModeXray:
default:
return fmt.Errorf("unknown mode %q", p.Mode)
}
if p.Mode != ModeXray {
if p.SSH.Username == "" || p.SSH.Password == "" {
return errors.New("ssh username/password are required")
}
if p.Mode != ModeDNSTT && (p.SSH.Host == "" || p.SSH.Port <= 0) {
return errors.New("ssh host/port are required")
}
}
if (p.Mode == ModePayload || p.Mode == ModePayloadSSL) && strings.TrimSpace(p.Payload.Text) == "" {
return errors.New("payload text is required for payload modes")
}
if p.Mode == ModeXray && strings.TrimSpace(p.Xray.Executable) == "" {
return errors.New("xray executable path is required")
}
if p.Mode == ModeDNSTT {
if p.DNSTT.UseEmbedded {
if strings.TrimSpace(p.DNSTT.ResolverType) == "" || strings.TrimSpace(p.DNSTT.ResolverAddress) == "" {
return errors.New("dnstt resolver type/address are required")
}
if strings.TrimSpace(p.DNSTT.Domain) == "" {
return errors.New("dnstt domain is required")
}
if strings.TrimSpace(p.DNSTT.PublicKey) == "" {
return errors.New("dnstt public key is required")
}
} else if strings.TrimSpace(p.DNSTT.Executable) == "" {
return errors.New("dnstt executable path is required when embedded dnstt is disabled")
}
}
if p.UDPGW.Enabled {
if strings.TrimSpace(p.UDPGW.Host) == "" {
return errors.New("udpgw host is required when udpgw is enabled")
}
if p.UDPGW.Port <= 0 || p.UDPGW.Port > 65535 {
return errors.New("udpgw port must be between 1 and 65535")
}
switch strings.ToLower(strings.TrimSpace(p.UDPGW.Protocol)) {
case "", "badvpn", "legacy":
default:
return errors.New("udpgw protocol must be badvpn or legacy")
}
}
if p.Local.SocksPort <= 0 || p.Local.SocksPort > 65535 {
return errors.New("local socks port must be between 1 and 65535")
}
if p.Tun.Enabled && p.Tun.IPv6Enabled {
ip, _, err := net.ParseCIDR(strings.TrimSpace(p.Tun.IPv6CIDR))
if err != nil || ip == nil || ip.To4() != nil {
return errors.New("IPv6 CIDR must be a valid IPv6 CIDR, for example fd00:534f:434b::1/64")
}
}
return nil
}
func ApplyDefaults(p *Profile) {
if p.ID == "" {
p.ID = randomID()
}
if p.Mode == "" {
p.Mode = ModeDirect
}
if p.SSH.Port == 0 {
p.SSH.Port = 22
}
if p.SSH.KeepAliveSeconds == 0 {
p.SSH.KeepAliveSeconds = 20
}
if p.SSH.HandshakeTimeoutMs == 0 {
p.SSH.HandshakeTimeoutMs = 15000
}
if p.Payload.Text == "" {
p.Payload.Text = "CONNECT [host]:[port] HTTP/1.1[crlf]Host: [host][crlf]User-Agent: SocksRevivePC[crlf][crlf]"
}
if p.Payload.ResponseTimeoutMs == 0 {
p.Payload.ResponseTimeoutMs = 8000
}
if p.Payload.SplitDelayMs == 0 {
p.Payload.SplitDelayMs = 120
}
if p.TLS.Port == 0 {
p.TLS.Port = 443
}
if p.Reconnect.DelaySeconds == 0 {
p.Reconnect.DelaySeconds = 3
}
if p.Reconnect.CheckIntervalSeconds == 0 {
p.Reconnect.CheckIntervalSeconds = 10
}
if p.Local.SocksHost == "" {
p.Local.SocksHost = "127.0.0.1"
}
if p.Local.SocksPort == 0 {
p.Local.SocksPort = 10809
}
if p.DNSTT.LocalSSHHost == "" {
p.DNSTT.LocalSSHHost = "127.0.0.1"
}
if p.DNSTT.LocalSSHPort == 0 {
p.DNSTT.LocalSSHPort = 2222
}
if p.DNSTT.StartupTimeoutMs == 0 {
p.DNSTT.StartupTimeoutMs = 5000
}
if p.DNSTT.ResolverType == "" {
p.DNSTT.ResolverType = "doh"
}
if p.DNSTT.UTLSDistribution == "" {
p.DNSTT.UTLSDistribution = "4*random,3*Firefox_120,1*Firefox_105,3*Chrome_120,1*Chrome_102,1*iOS_14,1*iOS_13"
}
if !p.DNSTT.UseEmbedded && p.DNSTT.Executable == "" && len(p.DNSTT.Args) == 0 {
// New profiles use the embedded DNSTT client. Old imported profiles can
// still disable UseEmbedded and point to an external executable.
p.DNSTT.UseEmbedded = true
}
if p.DNSTT.Executable == "" {
p.DNSTT.Executable = defaultToolPath("dnstt", "dnstt-client")
}
if p.Xray.LocalSocksHost == "" {
p.Xray.LocalSocksHost = "127.0.0.1"
}
if p.Xray.LocalSocksPort == 0 {
p.Xray.LocalSocksPort = 10808
}
if p.Xray.ConfigPath == "" {
p.Xray.ConfigPath = "configs/xray.json"
}
if len(p.Xray.Args) == 0 {
p.Xray.Args = []string{"run", "-config", p.Xray.ConfigPath}
}
if p.Xray.StartupTimeoutMs == 0 {
p.Xray.StartupTimeoutMs = 2500
}
if p.Xray.Executable == "" {
p.Xray.Executable = defaultToolPath("xray", "xray")
}
if p.UDPGW.Host == "" {
p.UDPGW.Host = "127.0.0.1"
}
if p.UDPGW.Port == 0 {
p.UDPGW.Port = 7400
}
if strings.TrimSpace(p.UDPGW.Protocol) == "" {
// BadVPN is the Android-compatible UDPGW framing. The old PC-only
// experimental frame is still available as "legacy" for older tests.
p.UDPGW.Protocol = "badvpn"
}
if p.Tun.Device == "" {
if runtime.GOOS == "windows" {
p.Tun.Device = "wintun"
p.Tun.InterfaceName = "wintun"
} else {
p.Tun.Device = "tun://socksrevive0"
p.Tun.InterfaceName = "socksrevive0"
}
}
if runtime.GOOS == "windows" {
// tun2socks v2 expects the Windows device model to be just "wintun".
// The old value "wintun://SocksRevive" made the engine treat
// "SocksRevive" as a network interface and crash with
// "route ip+net: no such network interface" on many PCs.
p.Tun.Device = "wintun"
p.Tun.InterfaceName = "wintun"
}
if p.Tun.InterfaceName == "" {
if strings.HasPrefix(p.Tun.Device, "tun://") {
p.Tun.InterfaceName = strings.TrimPrefix(p.Tun.Device, "tun://")
} else if strings.HasPrefix(p.Tun.Device, "wintun://") {
p.Tun.InterfaceName = strings.TrimPrefix(p.Tun.Device, "wintun://")
} else {
p.Tun.InterfaceName = p.Tun.Device
}
}
if p.Tun.MTU == 0 {
p.Tun.MTU = 1500
}
if p.Tun.Gateway == "" {
p.Tun.Gateway = "198.18.0.1"
}
if p.Tun.CIDR == "" {
p.Tun.CIDR = "198.18.0.1/15"
}
if len(p.Tun.DNS) == 0 {
p.Tun.DNS = []string{"1.1.1.1", "8.8.8.8"}
}
if p.Tun.IPv6CIDR == "" {
p.Tun.IPv6CIDR = "fd00:534f:434b::1/64"
}
if len(p.Tun.IPv6DNS) == 0 {
p.Tun.IPv6DNS = []string{"2606:4700:4700::1111", "2001:4860:4860::8888"}
}
}
func defaultToolPath(folder, base string) string {
exe := base
if runtime.GOOS == "windows" {
exe += ".exe"
}
return filepath.Join("tools", folder, exe)
}
func randomID() string {
b := make([]byte, 8)
if _, err := rand.Read(b); err != nil {
return fmt.Sprintf("p%d", time.Now().UnixNano())
}
return hex.EncodeToString(b)
}
func safeID(id string) string {
id = strings.TrimSpace(id)
id = strings.ReplaceAll(id, "/", "_")
id = strings.ReplaceAll(id, "\\", "_")
id = strings.ReplaceAll(id, "..", "_")
return id
}

73
internal/crash/crash.go Normal file
View File

@@ -0,0 +1,73 @@
package crash
import (
"fmt"
"log"
"os"
"path/filepath"
"runtime/debug"
"sync"
"time"
)
var fileMu sync.Mutex
// AttachLog mirrors the standard logger into logs/runtime.log. It is useful for
// GUI builds on Windows where there is no console attached.
func AttachLog(root string) func() {
if root == "" {
root = "."
}
logDir := filepath.Join(root, "logs")
_ = os.MkdirAll(logDir, 0o755)
path := filepath.Join(logDir, "runtime.log")
f, err := os.OpenFile(path, os.O_CREATE|os.O_WRONLY|os.O_APPEND, 0o644)
if err != nil {
return func() {}
}
log.SetOutput(f)
log.SetFlags(log.LstdFlags | log.Lmicroseconds | log.Lshortfile)
log.Printf("SocksRevivePC started")
return func() {
log.Printf("SocksRevivePC stopped")
_ = f.Close()
}
}
// Recover writes any Go panic to logs/crash.log. This does not hide the crash;
// it only makes GUI builds debuggable when Windows closes the app silently.
func Recover(root string) {
if v := recover(); v != nil {
Write(root, "panic", v)
}
}
// Go starts a goroutine with panic logging. Use this for goroutines owned by the
// app so background failures do not vanish without a stack trace.
func Go(root string, fn func()) {
go func() {
defer Recover(root)
fn()
}()
}
func Write(root, kind string, value any) {
if root == "" {
root = "."
}
fileMu.Lock()
defer fileMu.Unlock()
debug.SetTraceback("all")
logDir := filepath.Join(root, "logs")
_ = os.MkdirAll(logDir, 0o755)
path := filepath.Join(logDir, "crash.log")
f, err := os.OpenFile(path, os.O_CREATE|os.O_WRONLY|os.O_APPEND, 0o644)
if err != nil {
return
}
defer f.Close()
_, _ = fmt.Fprintf(f, "\n===== %s =====\n", time.Now().Format(time.RFC3339Nano))
_, _ = fmt.Fprintf(f, "%s: %v\n\n", kind, value)
_, _ = f.Write(debug.Stack())
_, _ = f.WriteString("\n")
}

276
internal/dnsttclient/api.go Normal file
View File

@@ -0,0 +1,276 @@
package dnsttclient
import (
"context"
"crypto/tls"
"fmt"
"io"
"log"
"net"
"net/http"
"strings"
"time"
utls "github.com/refraction-networking/utls"
"github.com/xtaci/kcp-go/v5"
"github.com/xtaci/smux"
"socksrevivepc/internal/dnsttcore/dns"
"socksrevivepc/internal/dnsttcore/noise"
"socksrevivepc/internal/dnsttcore/turbotunnel"
)
// Options configures the embedded DNSTT client.
type Options struct {
ResolverType string
ResolverAddress string
PublicKeyHex string
Domain string
LocalAddress string
UTLSDistribution string
StartupTimeout time.Duration
LogWriter io.Writer
}
// Client is a running embedded DNSTT client instance.
type Client struct {
cancel context.CancelFunc
done chan struct{}
}
// Stop shuts down the listener and DNS transport.
func (c *Client) Stop() {
if c == nil {
return
}
c.cancel()
select {
case <-c.done:
case <-time.After(3 * time.Second):
}
}
// Start starts an embedded DNSTT client and waits until the local TCP listener
// is open. It does not require dnstt-client.exe.
func Start(parent context.Context, opts Options) (*Client, error) {
if opts.LogWriter != nil {
log.SetOutput(opts.LogWriter)
log.SetFlags(log.LstdFlags | log.LUTC)
}
if opts.StartupTimeout <= 0 {
opts.StartupTimeout = 5 * time.Second
}
ctx, cancel := context.WithCancel(parent)
client := &Client{cancel: cancel, done: make(chan struct{})}
ready := make(chan struct{})
errCh := make(chan error, 1)
go func() {
defer close(client.done)
errCh <- runOptions(ctx, opts, ready)
}()
select {
case <-ready:
return client, nil
case err := <-errCh:
cancel()
if err == nil {
err = fmt.Errorf("embedded dnstt stopped during startup")
}
return nil, err
case <-time.After(opts.StartupTimeout):
// Keep the client alive. The local listener usually opens first, but on
// slow networks the Noise/smux session can need a little more time.
return client, nil
case <-parent.Done():
cancel()
return nil, parent.Err()
}
}
func runOptions(ctx context.Context, opts Options, ready chan<- struct{}) error {
resolverType := strings.ToLower(strings.TrimSpace(opts.ResolverType))
resolverAddress := strings.TrimSpace(opts.ResolverAddress)
if resolverType == "" {
resolverType = "doh"
}
if resolverAddress == "" {
return fmt.Errorf("dnstt resolver is required")
}
if opts.PublicKeyHex == "" {
return fmt.Errorf("dnstt public key is required")
}
if opts.Domain == "" {
return fmt.Errorf("dnstt domain is required")
}
if opts.LocalAddress == "" {
opts.LocalAddress = "127.0.0.1:2222"
}
if opts.UTLSDistribution == "" {
opts.UTLSDistribution = "4*random,3*Firefox_120,1*Firefox_105,3*Chrome_120,1*Chrome_102,1*iOS_14,1*iOS_13"
}
domain, err := dns.ParseName(opts.Domain)
if err != nil {
return fmt.Errorf("invalid dnstt domain: %w", err)
}
localAddr, err := net.ResolveTCPAddr("tcp", opts.LocalAddress)
if err != nil {
return fmt.Errorf("invalid dnstt local address: %w", err)
}
pubkey, err := noise.DecodeKey(opts.PublicKeyHex)
if err != nil {
return fmt.Errorf("dnstt public key format error: %w", err)
}
tlsConfig, err := loadTLSConfig()
if err != nil {
return fmt.Errorf("dnstt TLS config error: %w", err)
}
utlsClientHelloID, err := sampleUTLSDistribution(opts.UTLSDistribution)
if err != nil {
return fmt.Errorf("dnstt uTLS profile error: %w", err)
}
if utlsClientHelloID != nil {
log.Printf("dnstt uTLS fingerprint %s %s", utlsClientHelloID.Client, utlsClientHelloID.Version)
}
remoteAddr, pconn, err := makePacketConn(resolverType, resolverAddress, tlsConfig, utlsClientHelloID)
if err != nil {
return err
}
pconn = NewDNSPacketConn(pconn, remoteAddr, domain)
return runContext(ctx, pubkey, domain, localAddr, remoteAddr, pconn, ready)
}
func makePacketConn(resolverType, resolverAddress string, tlsConfig *tls.Config, utlsClientHelloID *utls.ClientHelloID) (net.Addr, net.PacketConn, error) {
switch resolverType {
case "doh", "https":
addr := turbotunnel.DummyAddr{}
var rt http.RoundTripper
if utlsClientHelloID == nil {
transport := http.DefaultTransport.(*http.Transport).Clone()
transport.Proxy = nil
transport.TLSClientConfig = tlsConfig.Clone()
baseDialContext := transport.DialContext
if baseDialContext == nil {
baseDialContext = (&net.Dialer{}).DialContext
}
transport.DialContext = func(ctx context.Context, network, addr string) (net.Conn, error) {
if network == "tcp" {
resolvedAddr, _, err := resolveAddrIPv4(ctx, addr)
if err != nil {
return nil, err
}
addr = resolvedAddr
network = "tcp4"
}
return baseDialContext(ctx, network, addr)
}
rt = transport
} else {
utlsConfig := &utls.Config{RootCAs: tlsConfig.RootCAs, MinVersion: tlsConfig.MinVersion}
rt = NewUTLSRoundTripper(utlsConfig, utlsClientHelloID, true)
}
pconn, err := NewHTTPPacketConn(rt, resolverAddress, 32)
return addr, pconn, err
case "dot", "tls":
addr := turbotunnel.DummyAddr{}
var dialTLSContext func(ctx context.Context, network, addr string) (net.Conn, error)
if utlsClientHelloID == nil {
dialTLSContext = (&tls.Dialer{Config: tlsConfig.Clone()}).DialContext
} else {
utlsConfig := &utls.Config{RootCAs: tlsConfig.RootCAs, MinVersion: tlsConfig.MinVersion}
dialTLSContext = func(ctx context.Context, network, addr string) (net.Conn, error) {
return utlsDialContext(ctx, network, addr, utlsConfig, utlsClientHelloID)
}
}
pconn, err := NewTLSPacketConn(resolverAddress, dialTLSContext)
return addr, pconn, err
case "udp", "dns":
addr, err := net.ResolveUDPAddr("udp", resolverAddress)
if err != nil {
return nil, nil, err
}
pconn, err := net.ListenUDP("udp", nil)
return addr, pconn, err
default:
return nil, nil, fmt.Errorf("unknown dnstt resolver type %q", resolverType)
}
}
func runContext(ctx context.Context, pubkey []byte, domain dns.Name, localAddr *net.TCPAddr, remoteAddr net.Addr, pconn net.PacketConn, ready chan<- struct{}) error {
defer pconn.Close()
ln, err := net.ListenTCP("tcp", localAddr)
if err != nil {
return fmt.Errorf("opening dnstt local listener: %w", err)
}
defer ln.Close()
go func() {
<-ctx.Done()
_ = ln.Close()
_ = pconn.Close()
}()
close(ready)
log.Printf("dnstt local listener ready at %s", ln.Addr())
mtu := dnsNameCapacity(domain) - 8 - 1 - numPadding - 1
if mtu < 80 {
return fmt.Errorf("domain %s leaves only %d bytes for payload", domain, mtu)
}
log.Printf("dnstt effective MTU %d", mtu)
conn, err := kcp.NewConn2(remoteAddr, nil, 0, 0, pconn)
if err != nil {
return fmt.Errorf("opening dnstt KCP connection: %w", err)
}
defer func() {
log.Printf("end dnstt session %08x", conn.GetConv())
conn.Close()
}()
log.Printf("begin dnstt session %08x", conn.GetConv())
conn.SetStreamMode(true)
conn.SetNoDelay(0, 0, 0, 1)
conn.SetWindowSize(turbotunnel.QueueSize/2, turbotunnel.QueueSize/2)
if rc := conn.SetMtu(mtu); !rc {
return fmt.Errorf("setting dnstt MTU failed")
}
rw, err := noise.NewClient(conn, pubkey)
if err != nil {
return err
}
smuxConfig := smux.DefaultConfig()
smuxConfig.Version = 2
smuxConfig.KeepAliveTimeout = idleTimeout
smuxConfig.MaxStreamBuffer = 1 * 1024 * 1024
sess, err := smux.Client(rw, smuxConfig)
if err != nil {
return fmt.Errorf("opening dnstt smux session: %w", err)
}
defer sess.Close()
for {
local, err := ln.Accept()
if err != nil {
select {
case <-ctx.Done():
return nil
default:
}
if err, ok := err.(net.Error); ok && err.Temporary() {
continue
}
return err
}
go func() {
defer local.Close()
err := handle(local.(*net.TCPConn), sess, conn.GetConv())
if err != nil {
log.Printf("dnstt handle: %v", err)
}
}()
}
}

View File

@@ -0,0 +1,22 @@
package dnsttclient
import (
"crypto/tls"
"crypto/x509"
)
func loadTLSConfig() (*tls.Config, error) {
pool, err := x509.SystemCertPool()
if err != nil {
pool = nil
}
config := &tls.Config{
MinVersion: tls.VersionTLS12,
}
if pool != nil {
config.RootCAs = pool
}
return config, nil
}

407
internal/dnsttclient/dns.go Normal file
View File

@@ -0,0 +1,407 @@
package dnsttclient
import (
"bytes"
"crypto/rand"
"encoding/base32"
"encoding/binary"
"fmt"
"io"
"log"
"net"
"time"
"socksrevivepc/internal/dnsttcore/dns"
"socksrevivepc/internal/dnsttcore/turbotunnel"
)
const (
// How many bytes of random padding to insert into queries.
numPadding = 3
// In an otherwise empty polling query, insert even more random padding,
// to reduce the chance of a cache hit. Cannot be greater than 31,
// because the prefix codes indicating padding start at 224.
numPaddingForPoll = 8
// sendLoop has a poll timer that automatically sends an empty polling
// query when a certain amount of time has elapsed without a send. The
// poll timer is initially set to initPollDelay. It increases by a
// factor of pollDelayMultiplier every time the poll timer expires, up
// to a maximum of maxPollDelay. The poll timer is reset to
// initPollDelay whenever an a send occurs that is not the result of the
// poll timer expiring.
initPollDelay = 500 * time.Millisecond
maxPollDelay = 10 * time.Second
pollDelayMultiplier = 2.0
// A limit on the number of empty poll requests we may send in a burst
// as a result of receiving data.
pollLimit = 16
)
// base32Encoding is a base32 encoding without padding.
var base32Encoding = base32.StdEncoding.WithPadding(base32.NoPadding)
// DNSPacketConn provides a packet-sending and -receiving interface over various
// forms of DNS. It handles the details of how packets and padding are encoded
// as a DNS name in the Question section of an upstream query, and as a TXT RR
// in downstream responses.
//
// DNSPacketConn does not handle the mechanics of actually sending and receiving
// encoded DNS messages. That is rather the responsibility of some other
// net.PacketConn such as net.UDPConn, HTTPPacketConn, or TLSPacketConn, one of
// which must be provided to NewDNSPacketConn.
//
// We don't have a need to match up a query and a response by ID. Queries and
// responses are vehicles for carrying data and for our purposes don't need to
// be correlated. When sending a query, we generate a random ID, and when
// receiving a response, we ignore the ID.
type DNSPacketConn struct {
clientID turbotunnel.ClientID
domain dns.Name
// Sending on pollChan permits sendLoop to send an empty polling query.
// sendLoop also does its own polling according to a time schedule.
pollChan chan struct{}
// QueuePacketConn is the direct receiver of ReadFrom and WriteTo calls.
// recvLoop and sendLoop take the messages out of the receive and send
// queues and actually put them on the network.
*turbotunnel.QueuePacketConn
}
// NewDNSPacketConn creates a new DNSPacketConn. transport, through its WriteTo
// and ReadFrom methods, handles the actual sending and receiving the DNS
// messages encoded by DNSPacketConn. addr is the address to be passed to
// transport.WriteTo whenever a message needs to be sent.
func NewDNSPacketConn(transport net.PacketConn, addr net.Addr, domain dns.Name) *DNSPacketConn {
// Generate a new random ClientID.
clientID := turbotunnel.NewClientID()
c := &DNSPacketConn{
clientID: clientID,
domain: domain,
pollChan: make(chan struct{}, pollLimit),
QueuePacketConn: turbotunnel.NewQueuePacketConn(clientID, 0),
}
go func() {
err := c.recvLoop(transport)
if err != nil {
log.Printf("recvLoop: %v", err)
}
}()
go func() {
err := c.sendLoop(transport, addr)
if err != nil {
log.Printf("sendLoop: %v", err)
}
}()
return c
}
// dnsResponsePayload extracts the downstream payload of a DNS response, encoded
// into the RDATA of a TXT RR. It returns nil if the message doesn't pass format
// checks, or if the name in its Question entry is not a subdomain of domain.
func dnsResponsePayload(resp *dns.Message, domain dns.Name) []byte {
if resp.Flags&0x8000 != 0x8000 {
// QR != 1, this is not a response.
return nil
}
if resp.Flags&0x000f != dns.RcodeNoError {
return nil
}
if len(resp.Answer) != 1 {
return nil
}
answer := resp.Answer[0]
_, ok := answer.Name.TrimSuffix(domain)
if !ok {
// Not the name we are expecting.
return nil
}
if answer.Type != dns.RRTypeTXT {
// We only support TYPE == TXT.
return nil
}
payload, err := dns.DecodeRDataTXT(answer.Data)
if err != nil {
return nil
}
return payload
}
// nextPacket reads the next length-prefixed packet from r. It returns a nil
// error only when a complete packet was read. It returns io.EOF only when there
// were 0 bytes remaining to read from r. It returns io.ErrUnexpectedEOF when
// EOF occurs in the middle of an encoded packet.
func nextPacket(r *bytes.Reader) ([]byte, error) {
for {
var n uint16
err := binary.Read(r, binary.BigEndian, &n)
if err != nil {
// We may return a real io.EOF only here.
return nil, err
}
p := make([]byte, n)
_, err = io.ReadFull(r, p)
// Here we must change io.EOF to io.ErrUnexpectedEOF.
if err == io.EOF {
err = io.ErrUnexpectedEOF
}
return p, err
}
}
// recvLoop repeatedly calls transport.ReadFrom to receive a DNS message,
// extracts its payload and breaks it into packets, and stores the packets in a
// queue to be returned from a future call to c.ReadFrom.
//
// Whenever we receive a DNS response containing at least one data packet, we
// send on c.pollChan to permit sendLoop to send an immediate polling queries.
// KCP itself will also send an ACK packet for incoming data, which is
// effectively a second poll. Therefore, each time we receive data, we send up
// to 2 polling queries (or 1 + f polling queries, if KCP only ACKs an f
// fraction of incoming data). We say "up to" because sendLoop will discard an
// empty polling query if it has an organic non-empty packet to send (this goes
// also for KCP's organic ACK packets).
//
// The intuition behind polling immediately after receiving is that if server
// has just had something to send, it may have more to send, and in order for
// the server to send anything, we must give it a query to respond to. The
// intuition behind polling *2 times* (or 1 + f times) is similar to TCP slow
// start: we want to maintain some number of queries "in flight", and the faster
// the server is sending, the higher that number should be. If we polled only
// once for each received packet, we would tend to have only one query in flight
// at a time, ping-pong style. The first polling query replaces the in-flight
// query that has just finished its duty in returning data to us; the second
// grows the effective in-flight window proportional to the rate at which
// data-carrying responses are being received. Compare to Eq. (2) of
// https://tools.ietf.org/html/rfc5681#section-3.1. The differences are that we
// count messages, not bytes, and we don't maintain an explicit window. If a
// response comes back without data, or if a query or response is dropped by the
// network, then we don't poll again, which decreases the effective in-flight
// window.
func (c *DNSPacketConn) recvLoop(transport net.PacketConn) error {
for {
var buf [4096]byte
n, addr, err := transport.ReadFrom(buf[:])
if err != nil {
if err, ok := err.(net.Error); ok && err.Temporary() {
log.Printf("ReadFrom temporary error: %v", err)
continue
}
return err
}
// Got a response. Try to parse it as a DNS message.
resp, err := dns.MessageFromWireFormat(buf[:n])
if err != nil {
log.Printf("MessageFromWireFormat: %v", err)
continue
}
payload := dnsResponsePayload(&resp, c.domain)
// Pull out the packets contained in the payload.
r := bytes.NewReader(payload)
any := false
for {
p, err := nextPacket(r)
if err != nil {
break
}
any = true
c.QueuePacketConn.QueueIncoming(p, addr)
}
// If the payload contained one or more packets, permit sendLoop
// to poll immediately. ACKs on received data will effectively
// serve as another stream of polls whose rate is proportional
// to the rate of incoming packets.
if any {
select {
case c.pollChan <- struct{}{}:
default:
}
}
}
}
// chunks breaks p into non-empty subslices of at most n bytes, greedily so that
// only final subslice has length < n.
func chunks(p []byte, n int) [][]byte {
var result [][]byte
for len(p) > 0 {
sz := len(p)
if sz > n {
sz = n
}
result = append(result, p[:sz])
p = p[sz:]
}
return result
}
// send sends p as a single packet encoded into a DNS query, using
// transport.WriteTo(query, addr). The length of p must be less than 224 bytes.
//
// Here is an example of how a packet is encoded into a DNS name, using
//
// p = "supercalifragilisticexpialidocious"
// c.clientID = "CLIENTID"
// domain = "t.example.com"
//
// as the input.
//
// 0. Start with the raw packet contents.
//
// supercalifragilisticexpialidocious
//
// 1. Length-prefix the packet and add random padding. A length prefix L < 0xe0
// means a data packet of L bytes. A length prefix L ≥ 0xe0 means padding
// of L 0xe0 bytes (not counting the length of the length prefix itself).
//
// \xe3\xd9\xa3\x15\x22supercalifragilisticexpialidocious
//
// 2. Prefix the ClientID.
//
// CLIENTID\xe3\xd9\xa3\x15\x22supercalifragilisticexpialidocious
//
// 3. Base32-encode, without padding and in lower case.
//
// ingesrkokreujy6zumkse43vobsxey3bnruwm4tbm5uwy2ltoruwgzlyobuwc3djmrxwg2lpovzq
//
// 4. Break into labels of at most 63 octets.
//
// ingesrkokreujy6zumkse43vobsxey3bnruwm4tbm5uwy2ltoruwgzlyobuwc3d.jmrxwg2lpovzq
//
// 5. Append the domain.
//
// ingesrkokreujy6zumkse43vobsxey3bnruwm4tbm5uwy2ltoruwgzlyobuwc3d.jmrxwg2lpovzq.t.example.com
func (c *DNSPacketConn) send(transport net.PacketConn, p []byte, addr net.Addr) error {
var decoded []byte
{
if len(p) >= 224 {
return fmt.Errorf("too long")
}
var buf bytes.Buffer
// ClientID
buf.Write(c.clientID[:])
n := numPadding
if len(p) == 0 {
n = numPaddingForPoll
}
// Padding / cache inhibition
buf.WriteByte(byte(224 + n))
io.CopyN(&buf, rand.Reader, int64(n))
// Packet contents
if len(p) > 0 {
buf.WriteByte(byte(len(p)))
buf.Write(p)
}
decoded = buf.Bytes()
}
encoded := make([]byte, base32Encoding.EncodedLen(len(decoded)))
base32Encoding.Encode(encoded, decoded)
encoded = bytes.ToLower(encoded)
labels := chunks(encoded, 63)
labels = append(labels, c.domain...)
name, err := dns.NewName(labels)
if err != nil {
return err
}
var id uint16
binary.Read(rand.Reader, binary.BigEndian, &id)
query := &dns.Message{
ID: id,
Flags: 0x0100, // QR = 0, RD = 1
Question: []dns.Question{
{
Name: name,
Type: dns.RRTypeTXT,
Class: dns.ClassIN,
},
},
// EDNS(0)
Additional: []dns.RR{
{
Name: dns.Name{},
Type: dns.RRTypeOPT,
Class: 4096, // requester's UDP payload size
TTL: 0, // extended RCODE and flags
Data: []byte{},
},
},
}
buf, err := query.WireFormat()
if err != nil {
return err
}
_, err = transport.WriteTo(buf, addr)
return err
}
// sendLoop takes packets that have been written using c.WriteTo, and sends them
// on the network using send. It also does polling with empty packets when
// requested by pollChan or after a timeout.
func (c *DNSPacketConn) sendLoop(transport net.PacketConn, addr net.Addr) error {
pollDelay := initPollDelay
pollTimer := time.NewTimer(pollDelay)
for {
var p []byte
outgoing := c.QueuePacketConn.OutgoingQueue(addr)
pollTimerExpired := false
// Prioritize sending an actual data packet from outgoing. Only
// consider a poll when outgoing is empty.
select {
case p = <-outgoing:
default:
select {
case p = <-outgoing:
case <-c.pollChan:
case <-pollTimer.C:
pollTimerExpired = true
}
}
if len(p) > 0 {
// A data-carrying packet displaces one pending poll
// opportunity, if any.
select {
case <-c.pollChan:
default:
}
}
if pollTimerExpired {
// We're polling because it's been a while since we last
// polled. Increase the poll delay.
pollDelay = time.Duration(float64(pollDelay) * pollDelayMultiplier)
if pollDelay > maxPollDelay {
pollDelay = maxPollDelay
}
} else {
// We're sending an actual data packet, or we're polling
// in response to a received packet. Reset the poll
// delay to initial.
if !pollTimer.Stop() {
<-pollTimer.C
}
pollDelay = initPollDelay
}
pollTimer.Reset(pollDelay)
// Unlike in the server, in the client we assume that because
// the data capacity of queries is so limited, it's not worth
// trying to send more than one packet per query.
err := c.send(transport, p, addr)
if err != nil {
log.Printf("send: %v", err)
continue
}
}
}

View File

@@ -0,0 +1,176 @@
package dnsttclient
import (
"bytes"
"fmt"
"io"
"io/ioutil"
"log"
"net/http"
"strconv"
"sync"
"time"
"socksrevivepc/internal/dnsttcore/turbotunnel"
)
// A default Retry-After delay to use when there is no explicit Retry-After
// header in an HTTP response.
const defaultRetryAfter = 10 * time.Second
// HTTPPacketConn is an HTTP-based transport for DNS messages, used for DNS over
// HTTPS (DoH). Its WriteTo and ReadFrom methods exchange DNS messages over HTTP
// requests and responses.
//
// HTTPPacketConn deals only with already formatted DNS messages. It does not
// handle encoding information into the messages. That is rather the
// responsibility of DNSPacketConn.
//
// https://tools.ietf.org/html/rfc8484
type HTTPPacketConn struct {
// client is the http.Client used to make requests. We use this instead
// of http.DefaultClient in order to support setting a timeout and a
// uTLS fingerprint.
client *http.Client
// urlString is the URL to which HTTP requests will be sent, for example
// "https://doh.example/dns-query".
urlString string
// notBefore, if not zero, is a time before which we may not send any
// queries; queries are buffered or dropped until that time. notBefore
// is set when we get a 429 Too Many Requests HTTP response or other
// unexpected status code that causes us to need to slow down. It is set
// according to the Retry-After header if available, otherwise it is set
// to defaultRetryAfter in the future. notBeforeLock controls access to
// notBefore.
notBefore time.Time
notBeforeLock sync.RWMutex
// QueuePacketConn is the direct receiver of ReadFrom and WriteTo calls.
// sendLoop, via send, removes messages from the outgoing queue that
// were placed there by WriteTo, and inserts messages into the incoming
// queue to be returned from ReadFrom.
*turbotunnel.QueuePacketConn
}
// NewHTTPPacketConn creates a new HTTPPacketConn configured to use the HTTP
// server at urlString as a DNS over HTTP resolver. client is the http.Client
// that will be used to make requests. urlString should include any necessary
// path components; e.g., "/dns-query". numSenders is the number of concurrent
// sender-receiver goroutines to run.
func NewHTTPPacketConn(rt http.RoundTripper, urlString string, numSenders int) (*HTTPPacketConn, error) {
c := &HTTPPacketConn{
client: &http.Client{
Transport: rt,
Timeout: 1 * time.Minute,
},
urlString: urlString,
QueuePacketConn: turbotunnel.NewQueuePacketConn(turbotunnel.DummyAddr{}, 0),
}
for i := 0; i < numSenders; i++ {
go c.sendLoop()
}
return c, nil
}
// send sends a message in an HTTP request, and queues the body HTTP response to
// be returned from a future call to ReadFrom.
func (c *HTTPPacketConn) send(p []byte) error {
req, err := http.NewRequest("POST", c.urlString, bytes.NewReader(p))
if err != nil {
return err
}
req.Header.Set("Accept", "application/dns-message")
req.Header.Set("Content-Type", "application/dns-message")
req.Header.Set("User-Agent", "") // Disable default "Go-http-client/1.1".
resp, err := c.client.Do(req)
if err != nil {
return err
}
defer resp.Body.Close()
switch resp.StatusCode {
case http.StatusOK:
if ct := resp.Header.Get("Content-Type"); ct != "application/dns-message" {
return fmt.Errorf("unknown HTTP response Content-Type %+q", ct)
}
body, err := ioutil.ReadAll(io.LimitReader(resp.Body, 64000))
if err == nil {
c.QueuePacketConn.QueueIncoming(body, turbotunnel.DummyAddr{})
}
// Ignore err != nil; don't report an error if we at least
// managed to send.
default:
// We primarily are thinking of 429 Too Many Requests here, but
// any other unexpected response codes will also cause us to
// rate-limit ourselves and emit a log message.
// https://developers.google.com/speed/public-dns/docs/doh/#errors
now := time.Now()
var retryAfter time.Time
if value := resp.Header.Get("Retry-After"); value != "" {
var err error
retryAfter, err = parseRetryAfter(value, now)
if err != nil {
log.Printf("cannot parse Retry-After value %+q", value)
}
}
if retryAfter.IsZero() {
// Supply a default.
retryAfter = now.Add(defaultRetryAfter)
}
if retryAfter.Before(now) {
log.Printf("got %+q, but Retry-After is %v in the past",
resp.Status, now.Sub(retryAfter))
} else {
c.notBeforeLock.Lock()
if retryAfter.Before(c.notBefore) {
log.Printf("got %+q, but Retry-After is %v earlier than already received Retry-After",
resp.Status, c.notBefore.Sub(retryAfter))
} else {
log.Printf("got %+q; ceasing sending for %v",
resp.Status, retryAfter.Sub(now))
c.notBefore = retryAfter
}
c.notBeforeLock.Unlock()
}
}
return nil
}
// sendLoop loops over the contents of the outgoing queue and passes them to
// send. It drops packets while c.notBefore is in the future.
func (c *HTTPPacketConn) sendLoop() {
for p := range c.QueuePacketConn.OutgoingQueue(turbotunnel.DummyAddr{}) {
// Stop sending while we are rate-limiting ourselves (as a
// result of a Retry-After response header, for example).
c.notBeforeLock.RLock()
notBefore := c.notBefore
c.notBeforeLock.RUnlock()
if wait := notBefore.Sub(time.Now()); wait > 0 {
// Drop it.
continue
}
err := c.send(p)
if err != nil {
log.Printf("sendLoop: %v", err)
}
}
}
// parseRetryAfter parses the value of a Retry-After header as an absolute
// time.Time.
func parseRetryAfter(value string, now time.Time) (time.Time, error) {
// May be a date string or an integer number of seconds.
// https://tools.ietf.org/html/rfc7231#section-7.1.3
if t, err := http.ParseTime(value); err == nil {
return t, nil
}
i, err := strconv.ParseUint(value, 10, 32)
if err != nil {
return time.Time{}, err
}
return now.Add(time.Duration(i) * time.Second), nil
}

View File

@@ -0,0 +1,445 @@
// dnstt-client is the client end of a DNS tunnel.
//
// Usage:
//
// dnstt-client [-doh URL|-dot ADDR|-udp ADDR] -pubkey-file PUBKEYFILE DOMAIN LOCALADDR
//
// Examples:
//
// dnstt-client -doh https://resolver.example/dns-query -pubkey-file server.pub t.example.com 127.0.0.1:7000
// dnstt-client -dot resolver.example:853 -pubkey-file server.pub t.example.com 127.0.0.1:7000
//
// The program supports DNS over HTTPS (DoH), DNS over TLS (DoT), and UDP DNS.
// Use one of these options:
//
// -doh https://resolver.example/dns-query
// -dot resolver.example:853
// -udp resolver.example:53
//
// You can give the server's public key as a file or as a hex string. Use
// "dnstt-server -gen-key" to get the public key.
//
// -pubkey-file server.pub
// -pubkey 0000111122223333444455556666777788889999aaaabbbbccccddddeeeeffff
//
// DOMAIN is the root of the DNS zone reserved for the tunnel. See README for
// instructions on setting it up.
//
// LOCALADDR is the TCP address that will listen for connections and forward
// them over the tunnel.
//
// In -doh and -dot modes, the program's TLS fingerprint is camouflaged with
// uTLS by default. The specific TLS fingerprint is selected randomly from a
// weighted distribution. You can set your own distribution (or specific single
// fingerprint) using the -utls option. The special value "none" disables uTLS.
//
// -utls '3*Firefox,2*Chrome,1*iOS'
// -utls Firefox
// -utls none
package dnsttclient
import (
"context"
"crypto/tls"
"errors"
"flag"
"fmt"
"io"
"log"
"net"
"net/http"
"os"
"strings"
"sync"
"time"
utls "github.com/refraction-networking/utls"
"github.com/xtaci/kcp-go/v5"
"github.com/xtaci/smux"
"socksrevivepc/internal/dnsttcore/dns"
"socksrevivepc/internal/dnsttcore/noise"
"socksrevivepc/internal/dnsttcore/turbotunnel"
)
// smux streams will be closed after this much time without receiving data.
const idleTimeout = 2 * time.Minute
// dnsNameCapacity returns the number of bytes remaining for encoded data after
// including domain in a DNS name.
func dnsNameCapacity(domain dns.Name) int {
// Names must be 255 octets or shorter in total length.
// https://tools.ietf.org/html/rfc1035#section-2.3.4
capacity := 255
// Subtract the length of the null terminator.
capacity -= 1
for _, label := range domain {
// Subtract the length of the label and the length octet.
capacity -= len(label) + 1
}
// Each label may be up to 63 bytes long and requires 64 bytes to
// encode.
capacity = capacity * 63 / 64
// Base32 expands every 5 bytes to 8.
capacity = capacity * 5 / 8
return capacity
}
// readKeyFromFile reads a key from a named file.
func readKeyFromFile(filename string) ([]byte, error) {
f, err := os.Open(filename)
if err != nil {
return nil, err
}
defer f.Close()
return noise.ReadKey(f)
}
// sampleUTLSDistribution parses a weighted uTLS Client Hello ID distribution
// string of the form "3*Firefox,2*Chrome,1*iOS", matches each label to a
// utls.ClientHelloID from utlsClientHelloIDMap, and randomly samples one
// utls.ClientHelloID from the distribution.
func sampleUTLSDistribution(spec string) (*utls.ClientHelloID, error) {
weights, labels, err := parseWeightedList(spec)
if err != nil {
return nil, err
}
ids := make([]*utls.ClientHelloID, 0, len(labels))
for _, label := range labels {
var id *utls.ClientHelloID
if label == "none" {
id = nil
} else {
id = utlsLookup(label)
if id == nil {
return nil, fmt.Errorf("unknown TLS fingerprint %q", label)
}
}
ids = append(ids, id)
}
return ids[sampleWeighted(weights)], nil
}
func handle(local *net.TCPConn, sess *smux.Session, conv uint32) error {
stream, err := sess.OpenStream()
if err != nil {
return fmt.Errorf("session %08x opening stream: %v", conv, err)
}
defer func() {
log.Printf("end stream %08x:%d", conv, stream.ID())
stream.Close()
}()
log.Printf("begin stream %08x:%d", conv, stream.ID())
var wg sync.WaitGroup
wg.Add(2)
go func() {
defer wg.Done()
_, err := io.Copy(stream, local)
if err == io.EOF {
// smux Stream.Write may return io.EOF.
err = nil
}
if err != nil && !errors.Is(err, io.ErrClosedPipe) {
log.Printf("stream %08x:%d copy stream←local: %v", conv, stream.ID(), err)
}
local.CloseRead()
stream.Close()
}()
go func() {
defer wg.Done()
_, err := io.Copy(local, stream)
if err == io.EOF {
// smux Stream.WriteTo may return io.EOF.
err = nil
}
if err != nil && !errors.Is(err, io.ErrClosedPipe) {
log.Printf("stream %08x:%d copy local←stream: %v", conv, stream.ID(), err)
}
local.CloseWrite()
}()
wg.Wait()
return err
}
func run(pubkey []byte, domain dns.Name, localAddr *net.TCPAddr, remoteAddr net.Addr, pconn net.PacketConn) error {
defer pconn.Close()
ln, err := net.ListenTCP("tcp", localAddr)
if err != nil {
return fmt.Errorf("opening local listener: %v", err)
}
defer ln.Close()
mtu := dnsNameCapacity(domain) - 8 - 1 - numPadding - 1 // clientid + padding length prefix + padding + data length prefix
if mtu < 80 {
return fmt.Errorf("domain %s leaves only %d bytes for payload", domain, mtu)
}
log.Printf("effective MTU %d", mtu)
// Open a KCP conn on the PacketConn.
conn, err := kcp.NewConn2(remoteAddr, nil, 0, 0, pconn)
if err != nil {
return fmt.Errorf("opening KCP conn: %v", err)
}
defer func() {
log.Printf("end session %08x", conn.GetConv())
conn.Close()
}()
log.Printf("begin session %08x", conn.GetConv())
// Permit coalescing the payloads of consecutive sends.
conn.SetStreamMode(true)
// Disable the dynamic congestion window (limit only by the maximum of
// local and remote static windows).
conn.SetNoDelay(
0, // default nodelay
0, // default interval
0, // default resend
1, // nc=1 => congestion window off
)
conn.SetWindowSize(turbotunnel.QueueSize/2, turbotunnel.QueueSize/2)
if rc := conn.SetMtu(mtu); !rc {
panic(rc)
}
// Put a Noise channel on top of the KCP conn.
rw, err := noise.NewClient(conn, pubkey)
if err != nil {
return err
}
// Start a smux session on the Noise channel.
smuxConfig := smux.DefaultConfig()
smuxConfig.Version = 2
smuxConfig.KeepAliveTimeout = idleTimeout
smuxConfig.MaxStreamBuffer = 1 * 1024 * 1024 // default is 65536
sess, err := smux.Client(rw, smuxConfig)
if err != nil {
return fmt.Errorf("opening smux session: %v", err)
}
defer sess.Close()
for {
local, err := ln.Accept()
if err != nil {
if err, ok := err.(net.Error); ok && err.Temporary() {
continue
}
return err
}
go func() {
defer local.Close()
err := handle(local.(*net.TCPConn), sess, conn.GetConv())
if err != nil {
log.Printf("handle: %v", err)
}
}()
}
}
func main() {
var dohURL string
var dotAddr string
var pubkeyFilename string
var pubkeyString string
var udpAddr string
var utlsDistribution string
flag.Usage = func() {
fmt.Fprintf(flag.CommandLine.Output(), `Usage:
%[1]s [-doh URL|-dot ADDR|-udp ADDR] -pubkey-file PUBKEYFILE DOMAIN LOCALADDR
Examples:
%[1]s -doh https://resolver.example/dns-query -pubkey-file server.pub t.example.com 127.0.0.1:7000
%[1]s -dot resolver.example:853 -pubkey-file server.pub t.example.com 127.0.0.1:7000
`, os.Args[0])
flag.PrintDefaults()
labels := make([]string, 0, len(utlsClientHelloIDMap))
labels = append(labels, "none")
for _, entry := range utlsClientHelloIDMap {
labels = append(labels, entry.Label)
}
fmt.Fprintf(flag.CommandLine.Output(), `
Known TLS fingerprints for -utls are:
`)
i := 0
for i < len(labels) {
var line strings.Builder
fmt.Fprintf(&line, " %s", labels[i])
w := 2 + len(labels[i])
i++
for i < len(labels) && w+1+len(labels[i]) <= 72 {
fmt.Fprintf(&line, " %s", labels[i])
w += 1 + len(labels[i])
i++
}
fmt.Fprintln(flag.CommandLine.Output(), line.String())
}
}
flag.StringVar(&dohURL, "doh", "", "URL of DoH resolver")
flag.StringVar(&dotAddr, "dot", "", "address of DoT resolver")
flag.StringVar(&pubkeyString, "pubkey", "", fmt.Sprintf("server public key (%d hex digits)", noise.KeyLen*2))
flag.StringVar(&pubkeyFilename, "pubkey-file", "", "read server public key from file")
flag.StringVar(&udpAddr, "udp", "", "address of UDP DNS resolver")
flag.StringVar(&utlsDistribution, "utls",
"4*random,3*Firefox_120,1*Firefox_105,3*Chrome_120,1*Chrome_102,1*iOS_14,1*iOS_13",
"choose TLS fingerprint from weighted distribution")
flag.Parse()
log.SetFlags(log.LstdFlags | log.LUTC)
if flag.NArg() != 2 {
flag.Usage()
os.Exit(1)
}
domain, err := dns.ParseName(flag.Arg(0))
if err != nil {
fmt.Fprintf(os.Stderr, "invalid domain %+q: %v\n", flag.Arg(0), err)
os.Exit(1)
}
localAddr, err := net.ResolveTCPAddr("tcp", flag.Arg(1))
if err != nil {
fmt.Fprintln(os.Stderr, err)
os.Exit(1)
}
var pubkey []byte
if pubkeyFilename != "" && pubkeyString != "" {
fmt.Fprintf(os.Stderr, "only one of -pubkey and -pubkey-file may be used\n")
os.Exit(1)
} else if pubkeyFilename != "" {
var err error
pubkey, err = readKeyFromFile(pubkeyFilename)
if err != nil {
fmt.Fprintf(os.Stderr, "cannot read pubkey from file: %v\n", err)
os.Exit(1)
}
} else if pubkeyString != "" {
var err error
pubkey, err = noise.DecodeKey(pubkeyString)
if err != nil {
fmt.Fprintf(os.Stderr, "pubkey format error: %v\n", err)
os.Exit(1)
}
}
if len(pubkey) == 0 {
fmt.Fprintf(os.Stderr, "the -pubkey or -pubkey-file option is required\n")
os.Exit(1)
}
tlsConfig, err := loadTLSConfig()
if err != nil {
fmt.Fprintf(os.Stderr, "TLS config error: %v\n", err)
os.Exit(1)
}
utlsClientHelloID, err := sampleUTLSDistribution(utlsDistribution)
if err != nil {
fmt.Fprintf(os.Stderr, "parsing -utls: %v\n", err)
os.Exit(1)
}
if utlsClientHelloID != nil {
log.Printf("uTLS fingerprint %s %s", utlsClientHelloID.Client, utlsClientHelloID.Version)
}
// Iterate over the remote resolver address options and select one and
// only one.
var remoteAddr net.Addr
var pconn net.PacketConn
for _, opt := range []struct {
s string
f func(string) (net.Addr, net.PacketConn, error)
}{
// -doh
{dohURL, func(s string) (net.Addr, net.PacketConn, error) {
addr := turbotunnel.DummyAddr{}
var rt http.RoundTripper
if utlsClientHelloID == nil {
transport := http.DefaultTransport.(*http.Transport).Clone()
// Disable DefaultTransport's default Proxy =
// ProxyFromEnvironment setting, for conformity
// with utlsRoundTripper and with DoT mode,
// which do not take a proxy from the
// environment.
transport.Proxy = nil
transport.TLSClientConfig = tlsConfig.Clone()
baseDialContext := transport.DialContext
if baseDialContext == nil {
baseDialContext = (&net.Dialer{}).DialContext
}
transport.DialContext = func(ctx context.Context, network, addr string) (net.Conn, error) {
if network == "tcp" {
resolvedAddr, _, err := resolveAddrIPv4(ctx, addr)
if err != nil {
return nil, err
}
addr = resolvedAddr
network = "tcp4"
}
return baseDialContext(ctx, network, addr)
}
rt = transport
} else {
utlsConfig := &utls.Config{
RootCAs: tlsConfig.RootCAs,
MinVersion: tlsConfig.MinVersion,
}
rt = NewUTLSRoundTripper(utlsConfig, utlsClientHelloID, true)
}
pconn, err := NewHTTPPacketConn(rt, dohURL, 32)
return addr, pconn, err
}},
// -dot
{dotAddr, func(s string) (net.Addr, net.PacketConn, error) {
addr := turbotunnel.DummyAddr{}
var dialTLSContext func(ctx context.Context, network, addr string) (net.Conn, error)
if utlsClientHelloID == nil {
dialTLSContext = (&tls.Dialer{Config: tlsConfig.Clone()}).DialContext
} else {
utlsConfig := &utls.Config{
RootCAs: tlsConfig.RootCAs,
MinVersion: tlsConfig.MinVersion,
}
dialTLSContext = func(ctx context.Context, network, addr string) (net.Conn, error) {
return utlsDialContext(ctx, network, addr, utlsConfig, utlsClientHelloID)
}
}
pconn, err := NewTLSPacketConn(dotAddr, dialTLSContext)
return addr, pconn, err
}},
// -udp
{udpAddr, func(s string) (net.Addr, net.PacketConn, error) {
addr, err := net.ResolveUDPAddr("udp", s)
if err != nil {
return nil, nil, err
}
pconn, err := net.ListenUDP("udp", nil)
return addr, pconn, err
}},
} {
if opt.s == "" {
continue
}
if pconn != nil {
fmt.Fprintf(os.Stderr, "only one of -doh, -dot, and -udp may be given\n")
os.Exit(1)
}
var err error
remoteAddr, pconn, err = opt.f(opt.s)
if err != nil {
fmt.Fprintln(os.Stderr, err)
os.Exit(1)
}
}
if pconn == nil {
fmt.Fprintf(os.Stderr, "one of -doh, -dot, or -udp is required\n")
os.Exit(1)
}
pconn = NewDNSPacketConn(pconn, remoteAddr, domain)
err = run(pubkey, domain, localAddr, remoteAddr, pconn)
if err != nil {
log.Fatal(err)
}
}

134
internal/dnsttclient/tls.go Normal file
View File

@@ -0,0 +1,134 @@
package dnsttclient
import (
"bufio"
"context"
"encoding/binary"
"io"
"log"
"net"
"sync"
"time"
"socksrevivepc/internal/dnsttcore/turbotunnel"
)
const dialTimeout = 30 * time.Second
// TLSPacketConn is a TLS- and TCP-based transport for DNS messages, used for
// DNS over TLS (DoT). Its WriteTo and ReadFrom methods exchange DNS messages
// over a TLS channel, prefixing each message with a two-octet length field as
// in DNS over TCP.
//
// TLSPacketConn deals only with already formatted DNS messages. It does not
// handle encoding information into the messages. That is rather the
// responsibility of DNSPacketConn.
//
// https://tools.ietf.org/html/rfc7858
type TLSPacketConn struct {
// QueuePacketConn is the direct receiver of ReadFrom and WriteTo calls.
// recvLoop and sendLoop take the messages out of the receive and send
// queues and actually put them on the network.
*turbotunnel.QueuePacketConn
}
// NewTLSPacketConn creates a new TLSPacketConn configured to use the TLS
// server at addr as a DNS over TLS resolver. It maintains a TLS connection to
// the resolver, reconnecting as necessary. It closes the connection if any
// reconnection attempt fails.
func NewTLSPacketConn(addr string, dialTLSContext func(ctx context.Context, network, addr string) (net.Conn, error)) (*TLSPacketConn, error) {
dial := func() (net.Conn, error) {
ctx, cancel := context.WithTimeout(context.Background(), dialTimeout)
defer cancel()
return dialTLSContext(ctx, "tcp", addr)
}
// We maintain one TLS connection at a time, redialing it whenever it
// becomes disconnected. We do the first dial here, outside the
// goroutine, so that any immediate and permanent connection errors are
// reported directly to the caller of NewTLSPacketConn.
conn, err := dial()
if err != nil {
return nil, err
}
c := &TLSPacketConn{
QueuePacketConn: turbotunnel.NewQueuePacketConn(turbotunnel.DummyAddr{}, 0),
}
go func() {
defer c.Close()
for {
var wg sync.WaitGroup
wg.Add(2)
go func() {
err := c.recvLoop(conn)
if err != nil {
log.Printf("recvLoop: %v", err)
}
wg.Done()
}()
go func() {
err := c.sendLoop(conn)
if err != nil {
log.Printf("sendLoop: %v", err)
}
wg.Done()
}()
wg.Wait()
conn.Close()
// Whenever the TLS connection dies, redial a new one.
conn, err = dial()
if err != nil {
log.Printf("dial tls: %v", err)
break
}
}
}()
return c, nil
}
// recvLoop reads length-prefixed messages from conn and passes them to the
// incoming queue.
func (c *TLSPacketConn) recvLoop(conn net.Conn) error {
br := bufio.NewReader(conn)
for {
var length uint16
err := binary.Read(br, binary.BigEndian, &length)
if err != nil {
if err == io.EOF {
err = nil
}
return err
}
p := make([]byte, int(length))
_, err = io.ReadFull(br, p)
if err != nil {
return err
}
c.QueuePacketConn.QueueIncoming(p, turbotunnel.DummyAddr{})
}
}
// sendLoop reads messages from the outgoing queue and writes them,
// length-prefixed, to conn.
func (c *TLSPacketConn) sendLoop(conn net.Conn) error {
bw := bufio.NewWriter(conn)
for p := range c.QueuePacketConn.OutgoingQueue(turbotunnel.DummyAddr{}) {
length := uint16(len(p))
if int(length) != len(p) {
panic(len(p))
}
err := binary.Write(bw, binary.BigEndian, &length)
if err != nil {
return err
}
_, err = bw.Write(p)
if err != nil {
return err
}
err = bw.Flush()
if err != nil {
return err
}
}
return nil
}

View File

@@ -0,0 +1,337 @@
package dnsttclient
// Support code for TLS camouflage using uTLS.
import (
"context"
"crypto/tls"
"fmt"
"net"
"net/http"
"net/url"
"strings"
"sync"
"time"
utls "github.com/refraction-networking/utls"
"golang.org/x/net/http2"
)
// utlsClientHelloIDMap is a correspondence between human-readable labels and
// supported utls.ClientHelloIDs.
var utlsClientHelloIDMap = []struct {
Label string
ID *utls.ClientHelloID
}{
{"random", &utls.HelloRandomizedALPN},
{"Firefox", &utls.HelloFirefox_Auto},
{"Firefox_55", &utls.HelloFirefox_55},
{"Firefox_56", &utls.HelloFirefox_56},
{"Firefox_63", &utls.HelloFirefox_63},
{"Firefox_65", &utls.HelloFirefox_65},
{"Firefox_99", &utls.HelloFirefox_99},
{"Firefox_102", &utls.HelloFirefox_102},
{"Firefox_105", &utls.HelloFirefox_105},
{"Firefox_120", &utls.HelloFirefox_120},
{"Chrome", &utls.HelloChrome_Auto},
{"Chrome_58", &utls.HelloChrome_58},
{"Chrome_62", &utls.HelloChrome_62},
{"Chrome_70", &utls.HelloChrome_70},
{"Chrome_72", &utls.HelloChrome_72},
{"Chrome_83", &utls.HelloChrome_83},
{"Chrome_87", &utls.HelloChrome_87},
{"Chrome_96", &utls.HelloChrome_96},
{"Chrome_100", &utls.HelloChrome_100},
{"Chrome_102", &utls.HelloChrome_102},
{"Chrome_120", &utls.HelloChrome_120},
{"iOS", &utls.HelloIOS_Auto},
{"iOS_11_1", &utls.HelloIOS_11_1},
{"iOS_12_1", &utls.HelloIOS_12_1},
{"iOS_13", &utls.HelloIOS_13},
{"iOS_14", &utls.HelloIOS_14},
}
// utlsLookup returns a *utls.ClientHelloID from utlsClientHelloIDMap by a
// case-insensitive label match, or nil if there is no match.
func utlsLookup(label string) *utls.ClientHelloID {
for _, entry := range utlsClientHelloIDMap {
if strings.ToLower(label) == strings.ToLower(entry.Label) {
return entry.ID
}
}
return nil
}
var bootstrapResolverAddrs = []string{
"8.8.8.8:53",
"1.1.1.1:53",
"8.8.4.4:53",
"1.0.0.1:53",
}
func lookupHostIPv4(ctx context.Context, host string) ([]net.IP, error) {
var lastErr error
for _, resolverAddr := range bootstrapResolverAddrs {
resolverAddr := resolverAddr
r := &net.Resolver{
PreferGo: true,
Dial: func(ctx context.Context, network, address string) (net.Conn, error) {
d := &net.Dialer{Timeout: 5 * time.Second}
return d.DialContext(ctx, "udp4", resolverAddr)
},
}
ips, err := r.LookupIP(ctx, "ip4", host)
if err == nil && len(ips) > 0 {
return ips, nil
}
if err != nil {
lastErr = err
}
}
if lastErr == nil {
lastErr = fmt.Errorf("no A records returned")
}
return nil, fmt.Errorf("bootstrap IPv4 lookup failed for %s: %w", host, lastErr)
}
func resolveAddrIPv4(ctx context.Context, addr string) (string, string, error) {
host, port, err := net.SplitHostPort(addr)
if err != nil {
return "", "", err
}
if ip := net.ParseIP(host); ip != nil {
ip4 := ip.To4()
if ip4 == nil {
return "", "", fmt.Errorf("IPv6 address %s cannot be used for forced IPv4 dial", host)
}
return net.JoinHostPort(ip4.String(), port), host, nil
}
ips, err := lookupHostIPv4(ctx, host)
if err != nil {
return "", host, err
}
return net.JoinHostPort(ips[0].String(), port), host, nil
}
// utlsDialContext connects to the given network address and initiates a TLS
// handshake with the provided ClientHelloID, and returns the resulting TLS
// connection.
func utlsDialContext(ctx context.Context, network, addr string, config *utls.Config, id *utls.ClientHelloID) (*utls.UConn, error) {
return utlsDialContextWithOptions(ctx, network, addr, config, id, false)
}
func utlsDialContextWithOptions(ctx context.Context, network, addr string, config *utls.Config, id *utls.ClientHelloID, forceIPv4 bool) (*utls.UConn, error) {
originalHost, _, err := net.SplitHostPort(addr)
if err != nil {
return nil, err
}
if forceIPv4 {
addr, _, err = resolveAddrIPv4(ctx, addr)
if err != nil {
return nil, err
}
}
// Set the SNI from the original addr host, if not already set.
if config == nil {
config = &utls.Config{}
}
if config.ServerName == "" {
config = config.Clone()
config.ServerName = originalHost
}
dialer := &net.Dialer{}
if forceIPv4 && network == "tcp" {
network = "tcp4"
}
conn, err := dialer.DialContext(ctx, network, addr)
if err != nil {
return nil, err
}
uconn := utls.UClient(conn, config, *id)
// We must call Handshake before returning, or else the UConn may not
// actually use the selected ClientHelloID. It depends on whether a Read
// or a Write happens first. If a Read happens first, the connection
// will use the normal crypto/tls fingerprint. If a Write happens first,
// it will use the selected fingerprint as expected.
// https://github.com/refraction-networking/utls/issues/75
err = uconn.Handshake()
if err != nil {
uconn.Close()
return nil, err
}
return uconn, nil
}
// The goal of utlsRoundTripper is: provide an http.RoundTripper abstraction
// that retains the features of http.Transport (e.g., persistent connections and
// HTTP/2 support), while making TLS connections using uTLS in place of
// crypto/tls. The challenge is: while http.Transport provides a DialTLSContext
// hook, setting it to non-nil disables automatic HTTP/2 support in the client.
// Most of the uTLS fingerprints contain an ALPN extension containing "h2";
// i.e., they declare support for HTTP/2. If the server also supports HTTP/2,
// then uTLS may negotiate an HTTP/2 connection without the http.Transport
// knowing it, which leads to an HTTP/1.1 client speaking to an HTTP/2 server, a
// protocol error.
//
// The code here uses an idea adapted from meek_lite in obfs4proxy:
// https://gitlab.com/yawning/obfs4/commit/4d453dab2120082b00bf6e63ab4aaeeda6b8d8a3
// Instead of setting DialTLSContext on an http.Transport and exposing it
// directly, we expose a wrapper type, utlsRoundTripper, which contains within
// it either an http.Transport or an http2.Transport. The first time a caller
// calls RoundTrip on the wrapper, we initiate a uTLS connection
// (bootstrapConn), then peek at the ALPN-negotiated protocol: if "h2", create
// an internal http2.Transport; otherwise, create an internal http.Transport. In
// either case, set DialTLSContext (or DialTLS for http2.Transport) on the
// created Transport to a function that dials using uTLS. As a special case, the
// first time the DialTLS callback is called, it reuses bootstrapConn (the one
// made to peek at the ALPN), rather than make a new connection.
//
// Subsequent calls to RoundTripper on the wrapper just pass the requests though
// the previously created http.Transport or http2.Transport. We assume that in
// future RoundTrips, the ALPN-negotiated protocol will remain the same as it
// was in the initial RoundTrip. At this point it is the http.Transport or
// http2.Transport calling DialTLSContext, not us, so we cannot dynamically swap
// the underlying transport based on the ALPN.
//
// https://bugs.torproject.org/tpo/anti-censorship/pluggable-transports/meek/29077
// https://github.com/refraction-networking/utls/issues/16
// utlsRoundTripper is an http.RoundTripper that uses uTLS (with a specified
// ClientHelloID) to make TLS connections.
//
// Can only be reused among servers which negotiate the same ALPN.
type utlsRoundTripper struct {
clientHelloID *utls.ClientHelloID
config *utls.Config
forceIPv4 bool
innerLock sync.Mutex
inner http.RoundTripper
}
// NewUTLSRoundTripper creates a utlsRoundTripper with the given TLS
// configuration and ClientHelloID.
func NewUTLSRoundTripper(config *utls.Config, id *utls.ClientHelloID, forceIPv4 bool) *utlsRoundTripper {
return &utlsRoundTripper{
clientHelloID: id,
config: config,
forceIPv4: forceIPv4,
// inner will be set in the first call to RoundTrip.
}
}
func (rt *utlsRoundTripper) RoundTrip(req *http.Request) (*http.Response, error) {
switch req.URL.Scheme {
case "http":
// If http, don't invoke uTLS; just pass it to an ordinary http.Transport.
return http.DefaultTransport.RoundTrip(req)
case "https":
default:
return nil, fmt.Errorf("unsupported URL scheme %q", req.URL.Scheme)
}
var err error
rt.innerLock.Lock()
if rt.inner == nil {
// On the first call, make an http.Transport or http2.Transport
// as appropriate.
rt.inner, err = makeRoundTripper(req, rt.config, rt.clientHelloID, rt.forceIPv4)
}
rt.innerLock.Unlock()
if err != nil {
return nil, err
}
// Forward the request to the inner http.Transport or http2.Transport.
return rt.inner.RoundTrip(req)
}
// makeRoundTripper makes a bootstrap TLS configuration using the given TLS
// configuration and ClientHelloID, and creates an http.Transport or
// http2.Transport, depending on the negotated ALPN. The Transport is set up to
// make future TLS connections using the same TLS configuration and
// ClientHelloID.
func makeRoundTripper(req *http.Request, config *utls.Config, id *utls.ClientHelloID, forceIPv4 bool) (http.RoundTripper, error) {
addr, err := addrForDial(req.URL)
if err != nil {
return nil, err
}
bootstrapConn, err := utlsDialContextWithOptions(req.Context(), "tcp", addr, config, id, forceIPv4)
if err != nil {
return nil, err
}
// Peek at the ALPN-negotiated protocol.
protocol := bootstrapConn.ConnectionState().NegotiatedProtocol
// Protects bootstrapConn.
var lock sync.Mutex
// This is the callback for future dials done by the inner
// http.Transport or http2.Transport.
dialTLSContext := func(ctx context.Context, network, addr string) (net.Conn, error) {
lock.Lock()
defer lock.Unlock()
// On the first dial, reuse bootstrapConn.
if bootstrapConn != nil {
uconn := bootstrapConn
bootstrapConn = nil
return uconn, nil
}
// Later dials make a new connection.
uconn, err := utlsDialContextWithOptions(ctx, "tcp", addr, config, id, forceIPv4)
if err != nil {
return nil, err
}
if uconn.ConnectionState().NegotiatedProtocol != protocol {
return nil, fmt.Errorf("unexpected switch from ALPN %q to %q",
protocol, uconn.ConnectionState().NegotiatedProtocol)
}
return uconn, nil
}
// Construct an http.Transport or http2.Transport depending on ALPN.
switch protocol {
case http2.NextProtoTLS:
// Unfortunately http2.Transport does not expose the same
// configuration options as http.Transport with regard to
// timeouts, etc., so we are at the mercy of the defaults.
// https://github.com/golang/go/issues/16581
return &http2.Transport{
DialTLS: func(network, addr string, _ *tls.Config) (net.Conn, error) {
// Ignore the *tls.Config parameter; use our
// static config instead.
return dialTLSContext(context.Background(), network, addr)
},
}, nil
default:
// With http.Transport, copy important default fields from
// http.DefaultTransport, such as TLSHandshakeTimeout and
// IdleConnTimeout, before overriding DialTLSContext.
tr := http.DefaultTransport.(*http.Transport).Clone()
tr.DialTLSContext = dialTLSContext
return tr, nil
}
}
// addrForDial extracts a host:port address from a URL, suitable for dialing.
func addrForDial(url *url.URL) (string, error) {
host := url.Hostname()
// net/http would use golang.org/x/net/idna here, to convert a possible
// internationalized domain name to ASCII.
port := url.Port()
if port == "" {
// No port? Use the default for the scheme.
switch url.Scheme {
case "http":
port = "80"
case "https":
port = "443"
default:
return "", fmt.Errorf("unsupported URL scheme %q", url.Scheme)
}
}
return net.JoinHostPort(host, port), nil
}

View File

@@ -0,0 +1,201 @@
package dnsttclient
// Random selection from weighted distributions, and strings for specifying such
// distributions.
import (
cryptorand "crypto/rand"
"encoding/binary"
"fmt"
mathrand "math/rand"
"strconv"
"strings"
)
// parseWeightedList parses a list of text labels with optional numeric weights,
// and returns parallel slices of weights and labels. If a weight is omitted for
// a label, the weight is 1.
//
// An example weighted list string is "2*apple,orange,10*cookie". This example
// results in the slices [2, 1, 10] and ["apple", "orange", "cookie"].
// Bytes may be escaped by backslashes.
//
// list ::= entry ("," entry)*
// entry ::= (weight "*")? label
func parseWeightedList(s string) ([]uint32, []string, error) {
const (
kindEOF = iota
kindComma
kindAsterisk
kindText
kindError
)
type token struct {
Kind int
Text string
}
var i int
// nextToken incrementally consumes s and returns tokens.
nextToken := func() token {
if !(i < len(s)) {
return token{Kind: kindEOF}
}
if s[i] == ',' {
i++
return token{Kind: kindComma}
}
if s[i] == '*' {
i++
return token{Kind: kindAsterisk}
}
var text strings.Builder
for i < len(s) && s[i] != ',' && s[i] != '*' {
if s[i] == '\\' {
i++
if !(i < len(s)) {
return token{Kind: kindError, Text: fmt.Sprintf("%q at end of string", s[i])}
}
}
text.WriteByte(s[i])
i++
}
return token{Kind: kindText, Text: text.String()}
}
peekToken := func() token {
saved := i
t := nextToken()
i = saved
return t
}
const (
stateBeginEntry = iota
stateLabel
stateEndEntry
stateDone
stateUnexpected
)
var weights []uint32
var labels []string
var weightString, label string
var t token
for state := stateBeginEntry; state != stateDone; {
switch state {
// Beginning of a new entry (at the beginning of the input or
// after a comma).
case stateBeginEntry:
t = nextToken()
switch t.Kind {
case kindText:
// If the next token is an asterisk, this text
// represents a weight; otherwise it represents
// a label (with a weight of "1").
switch peekToken().Kind {
case kindAsterisk:
nextToken() // Consume the asterisk token.
weightString = t.Text
state = stateLabel
default:
weightString = "1"
label = t.Text
state = stateEndEntry
}
default:
state = stateUnexpected
}
// weightString is assigned and we have seen an asterisk, now
// expect a text label.
case stateLabel:
t = nextToken()
switch t.Kind {
case kindText:
label = t.Text
state = stateEndEntry
default:
state = stateUnexpected
}
// weightString and label are assigned, now emit the entry and
// expect a comma or EOF.
case stateEndEntry:
w, err := strconv.ParseUint(weightString, 10, 32)
if err != nil {
return nil, nil, err
}
weights = append(weights, uint32(w))
labels = append(labels, label)
t = nextToken()
switch t.Kind {
case kindEOF:
state = stateDone
case kindComma:
state = stateBeginEntry
default:
state = stateUnexpected
}
case stateUnexpected:
if t.Kind == kindError {
return nil, nil, fmt.Errorf("%s", t.Text)
} else {
var ttext string
switch t.Kind {
case kindEOF:
ttext = "end of string"
case kindComma:
ttext = "\",\""
case kindAsterisk:
ttext = "\"*\""
case kindText:
ttext = fmt.Sprintf("%+q", t.Text)
}
return nil, nil, fmt.Errorf("unexpected %s", ttext)
}
default:
panic(state)
}
}
return weights, labels, nil
}
// cryptoSource is a math/rand Source that reads from the crypto/rand Reader.
// The Seed method does not affect the sequence of numbers returned from the
// Int63 method.
type cryptoSource struct{}
func (s cryptoSource) Seed(_ int64) {}
func (s cryptoSource) Int63() int64 {
var n int64
err := binary.Read(cryptorand.Reader, binary.BigEndian, &n)
if err != nil {
panic(err)
}
n &= (1 << 63) - 1
return n
}
// sampleWeighted returns the index of a randomly selected element of the
// weights slice, weighted by the values stored in the slice. Panics if
// the sum of the weights is zero or does not fit in an int64.
func sampleWeighted(weights []uint32) int {
var sum int64 = 0
for _, w := range weights {
sum += int64(w)
if sum < int64(w) {
panic("weights overflow")
}
}
if sum == 0 {
panic("total weight is zero")
}
r := uint64(mathrand.New(&cryptoSource{}).Int63n(sum))
for i, w := range weights {
if r < uint64(w) {
return i
}
r -= uint64(w)
}
panic("impossible")
}

View File

@@ -0,0 +1,575 @@
// Package dns deals with encoding and decoding DNS wire format.
package dns
import (
"bytes"
"encoding/binary"
"errors"
"fmt"
"io"
"strings"
)
// The maximum number of DNS name compression pointers we are willing to follow.
// Without something like this, infinite loops are possible.
const compressionPointerLimit = 10
var (
// ErrZeroLengthLabel is the error returned for names that contain a
// zero-length label, like "example..com".
ErrZeroLengthLabel = errors.New("name contains a zero-length label")
// ErrLabelTooLong is the error returned for labels that are longer than
// 63 octets.
ErrLabelTooLong = errors.New("name contains a label longer than 63 octets")
// ErrNameTooLong is the error returned for names whose encoded
// representation is longer than 255 octets.
ErrNameTooLong = errors.New("name is longer than 255 octets")
// ErrReservedLabelType is the error returned when reading a label type
// prefix whose two most significant bits are not 00 or 11.
ErrReservedLabelType = errors.New("reserved label type")
// ErrTooManyPointers is the error returned when reading a compressed
// name that has too many compression pointers.
ErrTooManyPointers = errors.New("too many compression pointers")
// ErrTrailingBytes is the error returned when bytes remain in the parse
// buffer after parsing a message.
ErrTrailingBytes = errors.New("trailing bytes after message")
// ErrIntegerOverflow is the error returned when trying to encode an
// integer greater than 65535 into a 16-bit field.
ErrIntegerOverflow = errors.New("integer overflow")
)
const (
// https://tools.ietf.org/html/rfc1035#section-3.2.2
RRTypeTXT = 16
// https://tools.ietf.org/html/rfc6891#section-6.1.1
RRTypeOPT = 41
// https://tools.ietf.org/html/rfc1035#section-3.2.4
ClassIN = 1
// https://tools.ietf.org/html/rfc1035#section-4.1.1
RcodeNoError = 0 // a.k.a. NOERROR
RcodeFormatError = 1 // a.k.a. FORMERR
RcodeNameError = 3 // a.k.a. NXDOMAIN
RcodeNotImplemented = 4 // a.k.a. NOTIMPL
// https://tools.ietf.org/html/rfc6891#section-9
ExtendedRcodeBadVers = 16 // a.k.a. BADVERS
)
// Name represents a domain name, a sequence of labels each of which is 63
// octets or less in length.
//
// https://tools.ietf.org/html/rfc1035#section-3.1
type Name [][]byte
// NewName returns a Name from a slice of labels, after checking the labels for
// validity. Does not include a zero-length label at the end of the slice.
func NewName(labels [][]byte) (Name, error) {
name := Name(labels)
// https://tools.ietf.org/html/rfc1035#section-2.3.4
// Various objects and parameters in the DNS have size limits.
// labels 63 octets or less
// names 255 octets or less
for _, label := range labels {
if len(label) == 0 {
return nil, ErrZeroLengthLabel
}
if len(label) > 63 {
return nil, ErrLabelTooLong
}
}
// Check the total length.
builder := newMessageBuilder()
builder.WriteName(name)
if len(builder.Bytes()) > 255 {
return nil, ErrNameTooLong
}
return name, nil
}
// ParseName returns a new Name from a string of labels separated by dots, after
// checking the name for validity. A single dot at the end of the string is
// ignored.
func ParseName(s string) (Name, error) {
b := bytes.TrimSuffix([]byte(s), []byte("."))
if len(b) == 0 {
// bytes.Split(b, ".") would return [""] in this case
return NewName([][]byte{})
} else {
return NewName(bytes.Split(b, []byte(".")))
}
}
// String returns a reversible string representation of name. Labels are
// separated by dots, and any bytes in a label that are outside the set
// [0-9A-Za-z-] are replaced with a \xXX hex escape sequence.
func (name Name) String() string {
if len(name) == 0 {
return "."
}
var buf strings.Builder
for i, label := range name {
if i > 0 {
buf.WriteByte('.')
}
for _, b := range label {
if b == '-' ||
('0' <= b && b <= '9') ||
('A' <= b && b <= 'Z') ||
('a' <= b && b <= 'z') {
buf.WriteByte(b)
} else {
fmt.Fprintf(&buf, "\\x%02x", b)
}
}
}
return buf.String()
}
// TrimSuffix returns a Name with the given suffix removed, if it was present.
// The second return value indicates whether the suffix was present. If the
// suffix was not present, the first return value is nil.
func (name Name) TrimSuffix(suffix Name) (Name, bool) {
if len(name) < len(suffix) {
return nil, false
}
split := len(name) - len(suffix)
fore, aft := name[:split], name[split:]
for i := 0; i < len(aft); i++ {
if !bytes.Equal(bytes.ToLower(aft[i]), bytes.ToLower(suffix[i])) {
return nil, false
}
}
return fore, true
}
// Message represents a DNS message.
//
// https://tools.ietf.org/html/rfc1035#section-4.1
type Message struct {
ID uint16
Flags uint16
Question []Question
Answer []RR
Authority []RR
Additional []RR
}
// Opcode extracts the OPCODE part of the Flags field.
//
// https://tools.ietf.org/html/rfc1035#section-4.1.1
func (message *Message) Opcode() uint16 {
return (message.Flags >> 11) & 0xf
}
// Rcode extracts the RCODE part of the Flags field.
//
// https://tools.ietf.org/html/rfc1035#section-4.1.1
func (message *Message) Rcode() uint16 {
return message.Flags & 0x000f
}
// Question represents an entry in the question section of a message.
//
// https://tools.ietf.org/html/rfc1035#section-4.1.2
type Question struct {
Name Name
Type uint16
Class uint16
}
// RR represents a resource record.
//
// https://tools.ietf.org/html/rfc1035#section-4.1.3
type RR struct {
Name Name
Type uint16
Class uint16
TTL uint32
Data []byte
}
// readName parses a DNS name from r. It leaves r positioned just after the
// parsed name.
func readName(r io.ReadSeeker) (Name, error) {
var labels [][]byte
// We limit the number of compression pointers we are willing to follow.
numPointers := 0
// If we followed any compression pointers, we must finally seek to just
// past the first pointer.
var seekTo int64
loop:
for {
var labelType byte
err := binary.Read(r, binary.BigEndian, &labelType)
if err != nil {
return nil, err
}
switch labelType & 0xc0 {
case 0x00:
// This is an ordinary label.
// https://tools.ietf.org/html/rfc1035#section-3.1
length := int(labelType & 0x3f)
if length == 0 {
break loop
}
label := make([]byte, length)
_, err := io.ReadFull(r, label)
if err != nil {
return nil, err
}
labels = append(labels, label)
case 0xc0:
// This is a compression pointer.
// https://tools.ietf.org/html/rfc1035#section-4.1.4
upper := labelType & 0x3f
var lower byte
err := binary.Read(r, binary.BigEndian, &lower)
if err != nil {
return nil, err
}
offset := (uint16(upper) << 8) | uint16(lower)
if numPointers == 0 {
// The first time we encounter a pointer,
// remember our position so we can seek back to
// it when done.
seekTo, err = r.Seek(0, io.SeekCurrent)
if err != nil {
return nil, err
}
}
numPointers++
if numPointers > compressionPointerLimit {
return nil, ErrTooManyPointers
}
// Follow the pointer and continue.
_, err = r.Seek(int64(offset), io.SeekStart)
if err != nil {
return nil, err
}
default:
// "The 10 and 01 combinations are reserved for future
// use."
return nil, ErrReservedLabelType
}
}
// If we followed any pointers, then seek back to just after the first
// one.
if numPointers > 0 {
_, err := r.Seek(seekTo, io.SeekStart)
if err != nil {
return nil, err
}
}
return NewName(labels)
}
// readQuestion parses one entry from the Question section. It leaves r
// positioned just after the parsed entry.
//
// https://tools.ietf.org/html/rfc1035#section-4.1.2
func readQuestion(r io.ReadSeeker) (Question, error) {
var question Question
var err error
question.Name, err = readName(r)
if err != nil {
return question, err
}
for _, ptr := range []*uint16{&question.Type, &question.Class} {
err := binary.Read(r, binary.BigEndian, ptr)
if err != nil {
return question, err
}
}
return question, nil
}
// readRR parses one resource record. It leaves r positioned just after the
// parsed resource record.
//
// https://tools.ietf.org/html/rfc1035#section-4.1.3
func readRR(r io.ReadSeeker) (RR, error) {
var rr RR
var err error
rr.Name, err = readName(r)
if err != nil {
return rr, err
}
for _, ptr := range []*uint16{&rr.Type, &rr.Class} {
err := binary.Read(r, binary.BigEndian, ptr)
if err != nil {
return rr, err
}
}
err = binary.Read(r, binary.BigEndian, &rr.TTL)
if err != nil {
return rr, err
}
var rdLength uint16
err = binary.Read(r, binary.BigEndian, &rdLength)
if err != nil {
return rr, err
}
rr.Data = make([]byte, rdLength)
_, err = io.ReadFull(r, rr.Data)
if err != nil {
return rr, err
}
return rr, nil
}
// readMessage parses a complete DNS message. It leaves r positioned just after
// the parsed message.
func readMessage(r io.ReadSeeker) (Message, error) {
var message Message
// Header section
// https://tools.ietf.org/html/rfc1035#section-4.1.1
var qdCount, anCount, nsCount, arCount uint16
for _, ptr := range []*uint16{
&message.ID, &message.Flags,
&qdCount, &anCount, &nsCount, &arCount,
} {
err := binary.Read(r, binary.BigEndian, ptr)
if err != nil {
return message, err
}
}
// Question section
// https://tools.ietf.org/html/rfc1035#section-4.1.2
for i := 0; i < int(qdCount); i++ {
question, err := readQuestion(r)
if err != nil {
return message, err
}
message.Question = append(message.Question, question)
}
// Answer, Authority, and Additional sections
// https://tools.ietf.org/html/rfc1035#section-4.1.3
for _, rec := range []struct {
ptr *[]RR
count uint16
}{
{&message.Answer, anCount},
{&message.Authority, nsCount},
{&message.Additional, arCount},
} {
for i := 0; i < int(rec.count); i++ {
rr, err := readRR(r)
if err != nil {
return message, err
}
*rec.ptr = append(*rec.ptr, rr)
}
}
return message, nil
}
// MessageFromWireFormat parses a message from buf and returns a Message object.
// It returns ErrTrailingBytes if there are bytes remaining in buf after parsing
// is done.
func MessageFromWireFormat(buf []byte) (Message, error) {
r := bytes.NewReader(buf)
message, err := readMessage(r)
if err == io.EOF {
err = io.ErrUnexpectedEOF
} else if err == nil {
// Check for trailing bytes.
_, err = r.ReadByte()
if err == io.EOF {
err = nil
} else if err == nil {
err = ErrTrailingBytes
}
}
return message, err
}
// messageBuilder manages the state of serializing a DNS message. Its main
// function is to keep track of names already written for the purpose of name
// compression.
type messageBuilder struct {
w bytes.Buffer
nameCache map[string]int
}
// newMessageBuilder creates a new messageBuilder with an empty name cache.
func newMessageBuilder() *messageBuilder {
return &messageBuilder{
nameCache: make(map[string]int),
}
}
// Bytes returns the serialized DNS message as a slice of bytes.
func (builder *messageBuilder) Bytes() []byte {
return builder.w.Bytes()
}
// WriteName appends name to the in-progress messageBuilder, employing
// compression pointers to previously written names if possible.
func (builder *messageBuilder) WriteName(name Name) {
// https://tools.ietf.org/html/rfc1035#section-3.1
for i := range name {
// Has this suffix already been encoded in the message?
if ptr, ok := builder.nameCache[name[i:].String()]; ok && ptr&0x3fff == ptr {
// If so, we can write a compression pointer.
binary.Write(&builder.w, binary.BigEndian, uint16(0xc000|ptr))
return
}
// Not cached; we must encode this label verbatim. Store a cache
// entry pointing to the beginning of it.
builder.nameCache[name[i:].String()] = builder.w.Len()
length := len(name[i])
if length == 0 || length > 63 {
panic(length)
}
builder.w.WriteByte(byte(length))
builder.w.Write(name[i])
}
builder.w.WriteByte(0)
}
// WriteQuestion appends a Question section entry to the in-progress
// messageBuilder.
func (builder *messageBuilder) WriteQuestion(question *Question) {
// https://tools.ietf.org/html/rfc1035#section-4.1.2
builder.WriteName(question.Name)
binary.Write(&builder.w, binary.BigEndian, question.Type)
binary.Write(&builder.w, binary.BigEndian, question.Class)
}
// WriteRR appends a resource record to the in-progress messageBuilder. It
// returns ErrIntegerOverflow if the length of rr.Data does not fit in 16 bits.
func (builder *messageBuilder) WriteRR(rr *RR) error {
// https://tools.ietf.org/html/rfc1035#section-4.1.3
builder.WriteName(rr.Name)
binary.Write(&builder.w, binary.BigEndian, rr.Type)
binary.Write(&builder.w, binary.BigEndian, rr.Class)
binary.Write(&builder.w, binary.BigEndian, rr.TTL)
rdLength := uint16(len(rr.Data))
if int(rdLength) != len(rr.Data) {
return ErrIntegerOverflow
}
binary.Write(&builder.w, binary.BigEndian, rdLength)
builder.w.Write(rr.Data)
return nil
}
// WriteMessage appends a complete DNS message to the in-progress
// messageBuilder. It returns ErrIntegerOverflow if the number of entries in any
// section, or the length of the data in any resource record, does not fit in 16
// bits.
func (builder *messageBuilder) WriteMessage(message *Message) error {
// Header section
// https://tools.ietf.org/html/rfc1035#section-4.1.1
binary.Write(&builder.w, binary.BigEndian, message.ID)
binary.Write(&builder.w, binary.BigEndian, message.Flags)
for _, count := range []int{
len(message.Question),
len(message.Answer),
len(message.Authority),
len(message.Additional),
} {
count16 := uint16(count)
if int(count16) != count {
return ErrIntegerOverflow
}
binary.Write(&builder.w, binary.BigEndian, count16)
}
// Question section
// https://tools.ietf.org/html/rfc1035#section-4.1.2
for _, question := range message.Question {
builder.WriteQuestion(&question)
}
// Answer, Authority, and Additional sections
// https://tools.ietf.org/html/rfc1035#section-4.1.3
for _, rrs := range [][]RR{message.Answer, message.Authority, message.Additional} {
for _, rr := range rrs {
err := builder.WriteRR(&rr)
if err != nil {
return err
}
}
}
return nil
}
// WireFormat encodes a Message as a slice of bytes in DNS wire format. It
// returns ErrIntegerOverflow if the number of entries in any section, or the
// length of the data in any resource record, does not fit in 16 bits.
func (message *Message) WireFormat() ([]byte, error) {
builder := newMessageBuilder()
err := builder.WriteMessage(message)
if err != nil {
return nil, err
}
return builder.Bytes(), nil
}
// DecodeRDataTXT decodes TXT-DATA (as found in the RDATA for a resource record
// with TYPE=TXT) as a raw byte slice, by concatenating all the
// <character-string>s it contains.
//
// https://tools.ietf.org/html/rfc1035#section-3.3.14
func DecodeRDataTXT(p []byte) ([]byte, error) {
var buf bytes.Buffer
for {
if len(p) == 0 {
return nil, io.ErrUnexpectedEOF
}
n := int(p[0])
p = p[1:]
if len(p) < n {
return nil, io.ErrUnexpectedEOF
}
buf.Write(p[:n])
p = p[n:]
if len(p) == 0 {
break
}
}
return buf.Bytes(), nil
}
// EncodeRDataTXT encodes a slice of bytes as TXT-DATA, as appropriate for the
// RDATA of a resource record with TYPE=TXT. No length restriction is enforced
// here; that must be checked at a higher level.
//
// https://tools.ietf.org/html/rfc1035#section-3.3.14
func EncodeRDataTXT(p []byte) []byte {
// https://tools.ietf.org/html/rfc1035#section-3.3
// https://tools.ietf.org/html/rfc1035#section-3.3.14
// TXT data is a sequence of one or more <character-string>s, where
// <character-string> is a length octet followed by that number of
// octets.
var buf bytes.Buffer
for len(p) > 255 {
buf.WriteByte(255)
buf.Write(p[:255])
p = p[255:]
}
// Must write here, even if len(p) == 0, because it's "*one or more*
// <character-string>s".
buf.WriteByte(byte(len(p)))
buf.Write(p)
return buf.Bytes()
}

View File

@@ -0,0 +1,592 @@
package dns
import (
"bytes"
"fmt"
"io"
"strconv"
"strings"
"testing"
)
func namesEqual(a, b Name) bool {
if len(a) != len(b) {
return false
}
for i := 0; i < len(a); i++ {
if !bytes.Equal(a[i], b[i]) {
return false
}
}
return true
}
func TestName(t *testing.T) {
for _, test := range []struct {
labels [][]byte
err error
s string
}{
{[][]byte{}, nil, "."},
{[][]byte{[]byte("test")}, nil, "test"},
{[][]byte{[]byte("a"), []byte("b"), []byte("c")}, nil, "a.b.c"},
{[][]byte{{}}, ErrZeroLengthLabel, ""},
{[][]byte{[]byte("a"), {}, []byte("c")}, ErrZeroLengthLabel, ""},
// 63 octets.
{[][]byte{[]byte("0123456789abcdef0123456789ABCDEF0123456789abcdef0123456789ABCDE")}, nil,
"0123456789abcdef0123456789ABCDEF0123456789abcdef0123456789ABCDE"},
// 64 octets.
{[][]byte{[]byte("0123456789abcdef0123456789ABCDEF0123456789abcdef0123456789ABCDEF")}, ErrLabelTooLong, ""},
// 64+64+64+62 octets.
{[][]byte{
[]byte("0123456789abcdef0123456789ABCDEF0123456789abcdef0123456789ABCDE"),
[]byte("0123456789abcdef0123456789ABCDEF0123456789abcdef0123456789ABCDE"),
[]byte("0123456789abcdef0123456789ABCDEF0123456789abcdef0123456789ABCDE"),
[]byte("0123456789abcdef0123456789ABCDEF0123456789abcdef0123456789ABC"),
}, nil,
"0123456789abcdef0123456789ABCDEF0123456789abcdef0123456789ABCDE.0123456789abcdef0123456789ABCDEF0123456789abcdef0123456789ABCDE.0123456789abcdef0123456789ABCDEF0123456789abcdef0123456789ABCDE.0123456789abcdef0123456789ABCDEF0123456789abcdef0123456789ABC"},
// 64+64+64+63 octets.
{[][]byte{
[]byte("0123456789abcdef0123456789ABCDEF0123456789abcdef0123456789ABCDE"),
[]byte("0123456789abcdef0123456789ABCDEF0123456789abcdef0123456789ABCDE"),
[]byte("0123456789abcdef0123456789ABCDEF0123456789abcdef0123456789ABCDE"),
[]byte("0123456789abcdef0123456789ABCDEF0123456789abcdef0123456789ABCD"),
}, ErrNameTooLong, ""},
// 127 one-octet labels.
{[][]byte{
{'0'}, {'1'}, {'2'}, {'3'}, {'4'}, {'5'}, {'6'}, {'7'}, {'8'}, {'9'}, {'a'}, {'b'}, {'c'}, {'d'}, {'e'}, {'f'},
{'0'}, {'1'}, {'2'}, {'3'}, {'4'}, {'5'}, {'6'}, {'7'}, {'8'}, {'9'}, {'A'}, {'B'}, {'C'}, {'D'}, {'E'}, {'F'},
{'0'}, {'1'}, {'2'}, {'3'}, {'4'}, {'5'}, {'6'}, {'7'}, {'8'}, {'9'}, {'a'}, {'b'}, {'c'}, {'d'}, {'e'}, {'f'},
{'0'}, {'1'}, {'2'}, {'3'}, {'4'}, {'5'}, {'6'}, {'7'}, {'8'}, {'9'}, {'A'}, {'B'}, {'C'}, {'D'}, {'E'}, {'F'},
{'0'}, {'1'}, {'2'}, {'3'}, {'4'}, {'5'}, {'6'}, {'7'}, {'8'}, {'9'}, {'a'}, {'b'}, {'c'}, {'d'}, {'e'}, {'f'},
{'0'}, {'1'}, {'2'}, {'3'}, {'4'}, {'5'}, {'6'}, {'7'}, {'8'}, {'9'}, {'A'}, {'B'}, {'C'}, {'D'}, {'E'}, {'F'},
{'0'}, {'1'}, {'2'}, {'3'}, {'4'}, {'5'}, {'6'}, {'7'}, {'8'}, {'9'}, {'a'}, {'b'}, {'c'}, {'d'}, {'e'}, {'f'},
{'0'}, {'1'}, {'2'}, {'3'}, {'4'}, {'5'}, {'6'}, {'7'}, {'8'}, {'9'}, {'A'}, {'B'}, {'C'}, {'D'}, {'E'},
}, nil,
"0.1.2.3.4.5.6.7.8.9.a.b.c.d.e.f.0.1.2.3.4.5.6.7.8.9.A.B.C.D.E.F.0.1.2.3.4.5.6.7.8.9.a.b.c.d.e.f.0.1.2.3.4.5.6.7.8.9.A.B.C.D.E.F.0.1.2.3.4.5.6.7.8.9.a.b.c.d.e.f.0.1.2.3.4.5.6.7.8.9.A.B.C.D.E.F.0.1.2.3.4.5.6.7.8.9.a.b.c.d.e.f.0.1.2.3.4.5.6.7.8.9.A.B.C.D.E"},
// 128 one-octet labels.
{[][]byte{
{'0'}, {'1'}, {'2'}, {'3'}, {'4'}, {'5'}, {'6'}, {'7'}, {'8'}, {'9'}, {'a'}, {'b'}, {'c'}, {'d'}, {'e'}, {'f'},
{'0'}, {'1'}, {'2'}, {'3'}, {'4'}, {'5'}, {'6'}, {'7'}, {'8'}, {'9'}, {'A'}, {'B'}, {'C'}, {'D'}, {'E'}, {'F'},
{'0'}, {'1'}, {'2'}, {'3'}, {'4'}, {'5'}, {'6'}, {'7'}, {'8'}, {'9'}, {'a'}, {'b'}, {'c'}, {'d'}, {'e'}, {'f'},
{'0'}, {'1'}, {'2'}, {'3'}, {'4'}, {'5'}, {'6'}, {'7'}, {'8'}, {'9'}, {'A'}, {'B'}, {'C'}, {'D'}, {'E'}, {'F'},
{'0'}, {'1'}, {'2'}, {'3'}, {'4'}, {'5'}, {'6'}, {'7'}, {'8'}, {'9'}, {'a'}, {'b'}, {'c'}, {'d'}, {'e'}, {'f'},
{'0'}, {'1'}, {'2'}, {'3'}, {'4'}, {'5'}, {'6'}, {'7'}, {'8'}, {'9'}, {'A'}, {'B'}, {'C'}, {'D'}, {'E'}, {'F'},
{'0'}, {'1'}, {'2'}, {'3'}, {'4'}, {'5'}, {'6'}, {'7'}, {'8'}, {'9'}, {'a'}, {'b'}, {'c'}, {'d'}, {'e'}, {'f'},
{'0'}, {'1'}, {'2'}, {'3'}, {'4'}, {'5'}, {'6'}, {'7'}, {'8'}, {'9'}, {'A'}, {'B'}, {'C'}, {'D'}, {'E'}, {'F'},
}, ErrNameTooLong, ""},
} {
// Test that NewName returns proper error codes, and otherwise
// returns an equal slice of labels.
name, err := NewName(test.labels)
if err != test.err || (err == nil && !namesEqual(name, test.labels)) {
t.Errorf("%+q returned (%+q, %v), expected (%+q, %v)",
test.labels, name, err, test.labels, test.err)
continue
}
if test.err != nil {
continue
}
// Test that the string version of the name comes out as
// expected.
s := name.String()
if s != test.s {
t.Errorf("%+q became string %+q, expected %+q", test.labels, s, test.s)
continue
}
// Test that parsing from a string back to a Name results in the
// original slice of labels.
name, err = ParseName(s)
if err != nil || !namesEqual(name, test.labels) {
t.Errorf("%+q parsing %+q returned (%+q, %v), expected (%+q, %v)",
test.labels, s, name, err, test.labels, nil)
continue
}
// A trailing dot should be ignored.
if !strings.HasSuffix(s, ".") {
dotName, dotErr := ParseName(s + ".")
if dotErr != err || !namesEqual(dotName, name) {
t.Errorf("%+q parsing %+q returned (%+q, %v), expected (%+q, %v)",
test.labels, s+".", dotName, dotErr, name, err)
continue
}
}
}
}
func TestParseName(t *testing.T) {
for _, test := range []struct {
s string
name Name
err error
}{
// This case can't be tested by TestName above because String
// will never produce "" (it produces "." instead).
{"", [][]byte{}, nil},
} {
name, err := ParseName(test.s)
if err != test.err || (err == nil && !namesEqual(name, test.name)) {
t.Errorf("%+q returned (%+q, %v), expected (%+q, %v)",
test.s, name, err, test.name, test.err)
continue
}
}
}
func unescapeString(s string) ([][]byte, error) {
if s == "." {
return [][]byte{}, nil
}
var result [][]byte
for _, label := range strings.Split(s, ".") {
var buf bytes.Buffer
i := 0
for i < len(label) {
switch label[i] {
case '\\':
if i+3 >= len(label) {
return nil, fmt.Errorf("truncated escape sequence at index %v", i)
}
if label[i+1] != 'x' {
return nil, fmt.Errorf("malformed escape sequence at index %v", i)
}
b, err := strconv.ParseUint(string(label[i+2:i+4]), 16, 8)
if err != nil {
return nil, fmt.Errorf("malformed hex sequence at index %v", i+2)
}
buf.WriteByte(byte(b))
i += 4
default:
buf.WriteByte(label[i])
i++
}
}
result = append(result, buf.Bytes())
}
return result, nil
}
func TestNameString(t *testing.T) {
for _, test := range []struct {
name Name
s string
}{
{[][]byte{}, "."},
{[][]byte{[]byte("\x00"), []byte("a.b"), []byte("c\nd\\")}, "\\x00.a\\x2eb.c\\x0ad\\x5c"},
{[][]byte{
[]byte("\x00\x01\x02\x03\x04\x05\x06\x07\x08\t\n\x0b\x0c\r\x0e\x0f\x10\x11\x12\x13\x14\x15\x16\x17\x18\x19\x1a\x1b\x1c\x1d\x1e\x1f !\"#$%&'()*+,-./0123456789:;<=>"),
[]byte("?@ABCDEFGHIJKLMNOPQRSTUVWXYZ[\\]^_`abcdefghijklmnopqrstuvwxyz{|}"),
[]byte("~\x7f\x80\x81\x82\x83\x84\x85\x86\x87\x88\x89\x8a\x8b\x8c\x8d\x8e\x8f\x90\x91\x92\x93\x94\x95\x96\x97\x98\x99\x9a\x9b\x9c\x9d\x9e\x9f\xa0\xa1\xa2\xa3\xa4\xa5\xa6\xa7\xa8\xa9\xaa\xab\xac\xad\xae\xaf\xb0\xb1\xb2\xb3\xb4\xb5\xb6\xb7\xb8\xb9\xba\xbb\xbc"),
[]byte("\xbd\xbe\xbf\xc0\xc1\xc2\xc3\xc4\xc5\xc6\xc7\xc8\xc9\xca\xcb\xcc\xcd\xce\xcf\xd0\xd1\xd2\xd3\xd4\xd5\xd6\xd7\xd8\xd9\xda\xdb\xdc\xdd\xde\xdf\xe0\xe1\xe2\xe3\xe4\xe5\xe6\xe7\xe8\xe9\xea\xeb\xec\xed\xee\xef\xf0\xf1\xf2\xf3\xf4\xf5\xf6\xf7\xf8\xf9\xfa\xfb"),
[]byte("\xfc\xfd\xfe\xff"),
}, "\\x00\\x01\\x02\\x03\\x04\\x05\\x06\\x07\\x08\\x09\\x0a\\x0b\\x0c\\x0d\\x0e\\x0f\\x10\\x11\\x12\\x13\\x14\\x15\\x16\\x17\\x18\\x19\\x1a\\x1b\\x1c\\x1d\\x1e\\x1f\\x20\\x21\\x22\\x23\\x24\\x25\\x26\\x27\\x28\\x29\\x2a\\x2b\\x2c-\\x2e\\x2f0123456789\\x3a\\x3b\\x3c\\x3d\\x3e.\\x3f\\x40ABCDEFGHIJKLMNOPQRSTUVWXYZ\\x5b\\x5c\\x5d\\x5e\\x5f\\x60abcdefghijklmnopqrstuvwxyz\\x7b\\x7c\\x7d.\\x7e\\x7f\\x80\\x81\\x82\\x83\\x84\\x85\\x86\\x87\\x88\\x89\\x8a\\x8b\\x8c\\x8d\\x8e\\x8f\\x90\\x91\\x92\\x93\\x94\\x95\\x96\\x97\\x98\\x99\\x9a\\x9b\\x9c\\x9d\\x9e\\x9f\\xa0\\xa1\\xa2\\xa3\\xa4\\xa5\\xa6\\xa7\\xa8\\xa9\\xaa\\xab\\xac\\xad\\xae\\xaf\\xb0\\xb1\\xb2\\xb3\\xb4\\xb5\\xb6\\xb7\\xb8\\xb9\\xba\\xbb\\xbc.\\xbd\\xbe\\xbf\\xc0\\xc1\\xc2\\xc3\\xc4\\xc5\\xc6\\xc7\\xc8\\xc9\\xca\\xcb\\xcc\\xcd\\xce\\xcf\\xd0\\xd1\\xd2\\xd3\\xd4\\xd5\\xd6\\xd7\\xd8\\xd9\\xda\\xdb\\xdc\\xdd\\xde\\xdf\\xe0\\xe1\\xe2\\xe3\\xe4\\xe5\\xe6\\xe7\\xe8\\xe9\\xea\\xeb\\xec\\xed\\xee\\xef\\xf0\\xf1\\xf2\\xf3\\xf4\\xf5\\xf6\\xf7\\xf8\\xf9\\xfa\\xfb.\\xfc\\xfd\\xfe\\xff"},
} {
s := test.name.String()
if s != test.s {
t.Errorf("%+q escaped to %+q, expected %+q", test.name, s, test.s)
continue
}
unescaped, err := unescapeString(s)
if err != nil {
t.Errorf("%+q unescaping %+q resulted in error %v", test.name, s, err)
continue
}
if !namesEqual(Name(unescaped), test.name) {
t.Errorf("%+q roundtripped through %+q to %+q", test.name, s, unescaped)
continue
}
}
}
func TestNameTrimSuffix(t *testing.T) {
for _, test := range []struct {
name, suffix string
trimmed string
ok bool
}{
{"", "", ".", true},
{".", ".", ".", true},
{"abc", "", "abc", true},
{"abc", ".", "abc", true},
{"", "abc", ".", false},
{".", "abc", ".", false},
{"example.com", "com", "example", true},
{"example.com", "net", ".", false},
{"example.com", "example.com", ".", true},
{"example.com", "test.com", ".", false},
{"example.com", "xample.com", ".", false},
{"example.com", "example", ".", false},
{"example.com", "COM", "example", true},
{"EXAMPLE.COM", "com", "EXAMPLE", true},
} {
tmp, ok := mustParseName(test.name).TrimSuffix(mustParseName(test.suffix))
trimmed := tmp.String()
if ok != test.ok || trimmed != test.trimmed {
t.Errorf("TrimSuffix %+q %+q returned (%+q, %v), expected (%+q, %v)",
test.name, test.suffix, trimmed, ok, test.trimmed, test.ok)
continue
}
}
}
func TestReadName(t *testing.T) {
// Good tests.
for _, test := range []struct {
start int64
end int64
input string
s string
}{
// Empty name.
{0, 1, "\x00abcd", "."},
// No pointers.
{12, 25, "AAAABBBBCCCC\x07example\x03com\x00", "example.com"},
// Backward pointer.
{25, 31, "AAAABBBBCCCC\x07example\x03com\x00\x03sub\xc0\x0c", "sub.example.com"},
// Forward pointer.
{0, 4, "\x01a\xc0\x04\x03bcd\x00", "a.bcd"},
// Two backwards pointers.
{31, 38, "AAAABBBBCCCC\x07example\x03com\x00\x03sub\xc0\x0c\x04sub2\xc0\x19", "sub2.sub.example.com"},
// Forward then backward pointer.
{25, 31, "AAAABBBBCCCC\x07example\x03com\x00\x03sub\xc0\x1f\x04sub2\xc0\x0c", "sub.sub2.example.com"},
// Overlapping codons.
{0, 4, "\x01a\xc0\x03bcd\x00", "a.bcd"},
// Pointer to empty label.
{0, 10, "\x07example\xc0\x0a\x00", "example"},
{1, 11, "\x00\x07example\xc0\x00", "example"},
// Pointer to pointer to empty label.
{0, 10, "\x07example\xc0\x0a\xc0\x0c\x00", "example"},
{1, 11, "\x00\x07example\xc0\x0c\xc0\x00", "example"},
} {
r := bytes.NewReader([]byte(test.input))
_, err := r.Seek(test.start, io.SeekStart)
if err != nil {
panic(err)
}
name, err := readName(r)
if err != nil {
t.Errorf("%+q returned error %s", test.input, err)
continue
}
s := name.String()
if s != test.s {
t.Errorf("%+q returned %+q, expected %+q", test.input, s, test.s)
continue
}
cur, _ := r.Seek(0, io.SeekCurrent)
if cur != test.end {
t.Errorf("%+q left offset %d, expected %d", test.input, cur, test.end)
continue
}
}
// Bad tests.
for _, test := range []struct {
start int64
input string
err error
}{
{0, "", io.ErrUnexpectedEOF},
// Reserved label type.
{0, "\x80example", ErrReservedLabelType},
// Reserved label type.
{0, "\x40example", ErrReservedLabelType},
// No Terminating empty label.
{0, "\x07example\x03com", io.ErrUnexpectedEOF},
// Pointer past end of buffer.
{0, "\x07example\xc0\xff", io.ErrUnexpectedEOF},
// Pointer to self.
{0, "\x07example\x03com\xc0\x0c", ErrTooManyPointers},
// Pointer to self with intermediate label.
{0, "\x07example\x03com\xc0\x08", ErrTooManyPointers},
// Two pointers that point to each other.
{0, "\xc0\x02\xc0\x00", ErrTooManyPointers},
// Two pointers that point to each other, with intermediate labels.
{0, "\x01a\xc0\x04\x01b\xc0\x00", ErrTooManyPointers},
// EOF while reading label.
{0, "\x0aexample", io.ErrUnexpectedEOF},
// EOF before second byte of pointer.
{0, "\xc0", io.ErrUnexpectedEOF},
{0, "\x07example\xc0", io.ErrUnexpectedEOF},
} {
r := bytes.NewReader([]byte(test.input))
_, err := r.Seek(test.start, io.SeekStart)
if err != nil {
panic(err)
}
name, err := readName(r)
if err == io.EOF {
err = io.ErrUnexpectedEOF
}
if err != test.err {
t.Errorf("%+q returned (%+q, %v), expected %v", test.input, name, err, test.err)
continue
}
}
}
func mustParseName(s string) Name {
name, err := ParseName(s)
if err != nil {
panic(err)
}
return name
}
func questionsEqual(a, b *Question) bool {
if !namesEqual(a.Name, b.Name) {
return false
}
if a.Type != b.Type || a.Class != b.Class {
return false
}
return true
}
func rrsEqual(a, b *RR) bool {
if !namesEqual(a.Name, b.Name) {
return false
}
if a.Type != b.Type || a.Class != b.Class || a.TTL != b.TTL {
return false
}
if !bytes.Equal(a.Data, b.Data) {
return false
}
return true
}
func messagesEqual(a, b *Message) bool {
if a.ID != b.ID || a.Flags != b.Flags {
return false
}
if len(a.Question) != len(b.Question) {
return false
}
for i := 0; i < len(a.Question); i++ {
if !questionsEqual(&a.Question[i], &b.Question[i]) {
return false
}
}
for _, rec := range []struct{ rrA, rrB []RR }{
{a.Answer, b.Answer},
{a.Authority, b.Authority},
{a.Additional, b.Additional},
} {
if len(rec.rrA) != len(rec.rrB) {
return false
}
for i := 0; i < len(rec.rrA); i++ {
if !rrsEqual(&rec.rrA[i], &rec.rrB[i]) {
return false
}
}
}
return true
}
func TestMessageFromWireFormat(t *testing.T) {
for _, test := range []struct {
buf string
expected Message
err error
}{
{
"\x12\x34",
Message{},
io.ErrUnexpectedEOF,
},
{
"\x12\x34\x01\x00\x00\x01\x00\x00\x00\x00\x00\x00\x03www\x07example\x03com\x00\x00\x01\x00\x01",
Message{
ID: 0x1234,
Flags: 0x0100,
Question: []Question{
{
Name: mustParseName("www.example.com"),
Type: 1,
Class: 1,
},
},
Answer: []RR{},
Authority: []RR{},
Additional: []RR{},
},
nil,
},
{
"\x12\x34\x01\x00\x00\x01\x00\x00\x00\x00\x00\x00\x03www\x07example\x03com\x00\x00\x01\x00\x01X",
Message{},
ErrTrailingBytes,
},
{
"\x12\x34\x81\x80\x00\x01\x00\x01\x00\x00\x00\x00\x03www\x07example\x03com\x00\x00\x01\x00\x01\x03www\x07example\x03com\x00\x00\x01\x00\x01\x00\x00\x00\x80\x00\x04\xc0\x00\x02\x01",
Message{
ID: 0x1234,
Flags: 0x8180,
Question: []Question{
{
Name: mustParseName("www.example.com"),
Type: 1,
Class: 1,
},
},
Answer: []RR{
{
Name: mustParseName("www.example.com"),
Type: 1,
Class: 1,
TTL: 128,
Data: []byte{192, 0, 2, 1},
},
},
Authority: []RR{},
Additional: []RR{},
},
nil,
},
} {
message, err := MessageFromWireFormat([]byte(test.buf))
if err != test.err || (err == nil && !messagesEqual(&message, &test.expected)) {
t.Errorf("%+q\nreturned (%+v, %v)\nexpected (%+v, %v)",
test.buf, message, err, test.expected, test.err)
continue
}
}
}
func TestMessageWireFormatRoundTrip(t *testing.T) {
for _, message := range []Message{
{
ID: 0x1234,
Flags: 0x0100,
Question: []Question{
{
Name: mustParseName("www.example.com"),
Type: 1,
Class: 1,
},
{
Name: mustParseName("www2.example.com"),
Type: 2,
Class: 2,
},
},
Answer: []RR{
{
Name: mustParseName("abc"),
Type: 2,
Class: 3,
TTL: 0xffffffff,
Data: []byte{1},
},
{
Name: mustParseName("xyz"),
Type: 2,
Class: 3,
TTL: 255,
Data: []byte{},
},
},
Authority: []RR{
{
Name: mustParseName("."),
Type: 65535,
Class: 65535,
TTL: 0,
Data: []byte("XXXXXXXXXXXXXXXXXXX"),
},
},
Additional: []RR{},
},
} {
buf, err := message.WireFormat()
if err != nil {
t.Errorf("%+v cannot make wire format: %v", message, err)
continue
}
message2, err := MessageFromWireFormat(buf)
if err != nil {
t.Errorf("%+q cannot parse wire format: %v", buf, err)
continue
}
if !messagesEqual(&message, &message2) {
t.Errorf("messages unequal\nbefore: %+v\n after: %+v", message, message2)
continue
}
}
}
func TestDecodeRDataTXT(t *testing.T) {
for _, test := range []struct {
p []byte
decoded []byte
err error
}{
{[]byte{}, nil, io.ErrUnexpectedEOF},
{[]byte("\x00"), []byte{}, nil},
{[]byte("\x01"), nil, io.ErrUnexpectedEOF},
} {
decoded, err := DecodeRDataTXT(test.p)
if err != test.err || (err == nil && !bytes.Equal(decoded, test.decoded)) {
t.Errorf("%+q\nreturned (%+q, %v)\nexpected (%+q, %v)",
test.p, decoded, err, test.decoded, test.err)
continue
}
}
}
func TestEncodeRDataTXT(t *testing.T) {
// Encoding 0 bytes needs to return at least a single length octet of
// zero, not an empty slice.
p := make([]byte, 0)
encoded := EncodeRDataTXT(p)
if len(encoded) < 0 {
t.Errorf("EncodeRDataTXT(%v) returned %v", p, encoded)
}
// 255 bytes should be able to be encoded into 256 bytes.
p = make([]byte, 255)
encoded = EncodeRDataTXT(p)
if len(encoded) > 256 {
t.Errorf("EncodeRDataTXT(%d bytes) returned %d bytes", len(p), len(encoded))
}
}
func TestRDataTXTRoundTrip(t *testing.T) {
for _, p := range [][]byte{
{},
[]byte("\x00"),
{
0x00, 0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, 0x08, 0x09, 0x0a, 0x0b, 0x0c, 0x0d, 0x0e, 0x0f,
0x10, 0x11, 0x12, 0x13, 0x14, 0x15, 0x16, 0x17, 0x18, 0x19, 0x1a, 0x1b, 0x1c, 0x1d, 0x1e, 0x1f,
0x20, 0x21, 0x22, 0x23, 0x24, 0x25, 0x26, 0x27, 0x28, 0x29, 0x2a, 0x2b, 0x2c, 0x2d, 0x2e, 0x2f,
0x30, 0x31, 0x32, 0x33, 0x34, 0x35, 0x36, 0x37, 0x38, 0x39, 0x3a, 0x3b, 0x3c, 0x3d, 0x3e, 0x3f,
0x40, 0x41, 0x42, 0x43, 0x44, 0x45, 0x46, 0x47, 0x48, 0x49, 0x4a, 0x4b, 0x4c, 0x4d, 0x4e, 0x4f,
0x50, 0x51, 0x52, 0x53, 0x54, 0x55, 0x56, 0x57, 0x58, 0x59, 0x5a, 0x5b, 0x5c, 0x5d, 0x5e, 0x5f,
0x60, 0x61, 0x62, 0x63, 0x64, 0x65, 0x66, 0x67, 0x68, 0x69, 0x6a, 0x6b, 0x6c, 0x6d, 0x6e, 0x6f,
0x70, 0x71, 0x72, 0x73, 0x74, 0x75, 0x76, 0x77, 0x78, 0x79, 0x7a, 0x7b, 0x7c, 0x7d, 0x7e, 0x7f,
0x80, 0x81, 0x82, 0x83, 0x84, 0x85, 0x86, 0x87, 0x88, 0x89, 0x8a, 0x8b, 0x8c, 0x8d, 0x8e, 0x8f,
0x90, 0x91, 0x92, 0x93, 0x94, 0x95, 0x96, 0x97, 0x98, 0x99, 0x9a, 0x9b, 0x9c, 0x9d, 0x9e, 0x9f,
0xa0, 0xa1, 0xa2, 0xa3, 0xa4, 0xa5, 0xa6, 0xa7, 0xa8, 0xa9, 0xaa, 0xab, 0xac, 0xad, 0xae, 0xaf,
0xb0, 0xb1, 0xb2, 0xb3, 0xb4, 0xb5, 0xb6, 0xb7, 0xb8, 0xb9, 0xba, 0xbb, 0xbc, 0xbd, 0xbe, 0xbf,
0xc0, 0xc1, 0xc2, 0xc3, 0xc4, 0xc5, 0xc6, 0xc7, 0xc8, 0xc9, 0xca, 0xcb, 0xcc, 0xcd, 0xce, 0xcf,
0xd0, 0xd1, 0xd2, 0xd3, 0xd4, 0xd5, 0xd6, 0xd7, 0xd8, 0xd9, 0xda, 0xdb, 0xdc, 0xdd, 0xde, 0xdf,
0xe0, 0xe1, 0xe2, 0xe3, 0xe4, 0xe5, 0xe6, 0xe7, 0xe8, 0xe9, 0xea, 0xeb, 0xec, 0xed, 0xee, 0xef,
0xf0, 0xf1, 0xf2, 0xf3, 0xf4, 0xf5, 0xf6, 0xf7, 0xf8, 0xf9, 0xfa, 0xfb, 0xfc, 0xfd, 0xfe, 0xff,
},
} {
rdata := EncodeRDataTXT(p)
decoded, err := DecodeRDataTXT(rdata)
if err != nil || !bytes.Equal(decoded, p) {
t.Errorf("%+q returned (%+q, %v)", p, decoded, err)
continue
}
}
}

View File

@@ -0,0 +1,23 @@
//go:build gofuzz
// +build gofuzz
// Fuzzing driver for https://github.com/dvyukov/go-fuzz.
// go get -u github.com/dvyukov/go-fuzz/go-fuzz github.com/dvyukov/go-fuzz/go-fuzz-build
// $GOPATH/bin/go-fuzz-build
// $GOPATH/bin/go-fuzz
//
// Related link: https://blog.cloudflare.com/dns-parser-meet-go-fuzzer/
package dns
func Fuzz(data []byte) int {
msg, err := MessageFromWireFormat(data)
if err != nil {
return 0
}
_, err = msg.WireFormat()
if err != nil {
panic(err)
}
return 1 // prioritize this input
}

View File

@@ -0,0 +1,276 @@
// Package noise provides a net.Conn-like interface for a
// Noise_NK_25519_ChaChaPoly_BLAKE2s. It encodes Noise messages onto a reliable
// stream using 16-bit length prefixes.
//
// https://noiseprotocol.org/noise.html
package noise
import (
"bufio"
"crypto/rand"
"encoding/binary"
"encoding/hex"
"errors"
"fmt"
"io"
"strings"
"github.com/flynn/noise"
"golang.org/x/crypto/curve25519"
)
// The length of public and private keys as returned by GeneratePrivkey.
const KeyLen = 32
// cipherSuite represents 25519_ChaChaPoly_BLAKE2s.
var cipherSuite = noise.NewCipherSuite(noise.DH25519, noise.CipherChaChaPoly, noise.HashBLAKE2s)
// readMessage reads a length-prefixed message from r. It returns a nil error
// only when a complete message was read. It returns io.EOF only when there were
// 0 bytes remaining to read from r. It returns io.ErrUnexpectedEOF when EOF
// occurs in the middle of an encoded message.
func readMessage(r io.Reader) ([]byte, error) {
var length uint16
err := binary.Read(r, binary.BigEndian, &length)
if err != nil {
// We may return a real io.EOF only here.
return nil, err
}
msg := make([]byte, int(length))
_, err = io.ReadFull(r, msg)
// Here we must change io.EOF to io.ErrUnexpectedEOF.
if err == io.EOF {
err = io.ErrUnexpectedEOF
}
return msg, err
}
// writeMessage writes msg as a length-prefixed message to w. It panics if the
// length of msg cannot be represented in 16 bits.
func writeMessage(w io.Writer, msg []byte) error {
length := uint16(len(msg))
if int(length) != len(msg) {
panic(len(msg))
}
err := binary.Write(w, binary.BigEndian, length)
if err != nil {
return err
}
_, err = w.Write(msg)
return err
}
// socket is the internal type that represents a Noise-wrapped
// io.ReadWriteCloser.
type socket struct {
recvPipe *io.PipeReader
sendCipher *noise.CipherState
io.ReadWriteCloser
}
func newSocket(rwc io.ReadWriteCloser, recvCipher, sendCipher *noise.CipherState) *socket {
pr, pw := io.Pipe()
// This loop calls readMessage, decrypts the messages, and feeds them
// into recvPipe where they will be returned from Read.
go func() (err error) {
defer func() {
pw.CloseWithError(err)
}()
for {
msg, err := readMessage(rwc)
if err != nil {
return err
}
p, err := recvCipher.Decrypt(nil, nil, msg)
if err != nil {
return err
}
_, err = pw.Write(p)
if err != nil {
return err
}
}
}()
return &socket{
sendCipher: sendCipher,
recvPipe: pr,
ReadWriteCloser: rwc,
}
}
// Read reads decrypted data from the wrapped io.Reader.
func (s *socket) Read(p []byte) (int, error) {
return s.recvPipe.Read(p)
}
// Write writes encrypted data from the wrapped io.Writer.
func (s *socket) Write(p []byte) (int, error) {
total := 0
for len(p) > 0 {
n := len(p)
if n > 4096 {
n = 4096
}
msg, err := s.sendCipher.Encrypt(nil, nil, p[:n])
if err != nil {
return total, err
}
err = writeMessage(s.ReadWriteCloser, msg)
if err != nil {
return total, err
}
total += n
p = p[n:]
}
return total, nil
}
// newConfig instantiates configuration settings that are common to clients and
// servers.
func newConfig() noise.Config {
return noise.Config{
CipherSuite: cipherSuite,
Pattern: noise.HandshakeNK,
Prologue: []byte("dnstt 2020-04-13"),
}
}
// NewClient wraps an io.ReadWriteCloser in a Noise protocol as a client, and
// returns after completing the handshake. It returns a non-nil error if there
// is an error during the handshake.
func NewClient(rwc io.ReadWriteCloser, serverPubkey []byte) (io.ReadWriteCloser, error) {
config := newConfig()
config.Initiator = true
config.PeerStatic = serverPubkey
handshakeState, err := noise.NewHandshakeState(config)
if err != nil {
return nil, err
}
// -> e, es
msg, _, _, err := handshakeState.WriteMessage(nil, nil)
if err != nil {
return nil, err
}
err = writeMessage(rwc, msg)
if err != nil {
return nil, err
}
// <- e, es
msg, err = readMessage(rwc)
if err != nil {
return nil, err
}
payload, sendCipher, recvCipher, err := handshakeState.ReadMessage(nil, msg)
if err != nil {
return nil, err
}
if len(payload) != 0 {
return nil, errors.New("unexpected server payload")
}
return newSocket(rwc, recvCipher, sendCipher), nil
}
// NewClient wraps an io.ReadWriteCloser in a Noise protocol as a server, and
// returns after completing the handshake. It returns a non-nil error if there
// is an error during the handshake.
func NewServer(rwc io.ReadWriteCloser, serverPrivkey []byte) (io.ReadWriteCloser, error) {
config := newConfig()
config.Initiator = false
config.StaticKeypair = noise.DHKey{
Private: serverPrivkey,
Public: PubkeyFromPrivkey(serverPrivkey),
}
handshakeState, err := noise.NewHandshakeState(config)
if err != nil {
return nil, err
}
// -> e, es
msg, err := readMessage(rwc)
if err != nil {
return nil, err
}
payload, _, _, err := handshakeState.ReadMessage(nil, msg)
if err != nil {
return nil, err
}
if len(payload) != 0 {
return nil, errors.New("unexpected client payload")
}
// <- e, es
msg, recvCipher, sendCipher, err := handshakeState.WriteMessage(nil, nil)
if err != nil {
return nil, err
}
err = writeMessage(rwc, msg)
if err != nil {
return nil, err
}
return newSocket(rwc, recvCipher, sendCipher), nil
}
// GeneratePrivkey generates a private key. The corresponding public key can be
// derived using PubkeyFromPrivkey.
func GeneratePrivkey() ([]byte, error) {
pair, err := noise.DH25519.GenerateKeypair(rand.Reader)
return pair.Private, err
}
// PubkeyFromPrivkey returns the public key that corresponds to privkey.
func PubkeyFromPrivkey(privkey []byte) []byte {
pubkey, err := curve25519.X25519(privkey, curve25519.Basepoint)
if err != nil {
panic(err)
}
return pubkey
}
// ReadKey reads a hex-encoded key from r. r must consist of a single line, with
// or without a '\n' line terminator. The line must consist of KeyLen
// hex-encoded bytes.
func ReadKey(r io.Reader) ([]byte, error) {
br := bufio.NewReader(io.LimitReader(r, 100))
line, err := br.ReadString('\n')
if err == io.EOF {
err = nil
}
if err == nil {
// Check that we're at EOF.
_, err = br.ReadByte()
if err == io.EOF {
err = nil
} else if err == nil {
err = fmt.Errorf("file contains more than one line")
}
}
if err != nil {
return nil, err
}
line = strings.TrimSuffix(line, "\n")
return DecodeKey(line)
}
// WriteKey writes the hex-encoded key in a single line to w.
func WriteKey(w io.Writer, key []byte) error {
_, err := fmt.Fprintf(w, "%x\n", key)
return err
}
// DecodeKey decodes a hex-encoded private or public key.
func DecodeKey(s string) ([]byte, error) {
key, err := hex.DecodeString(s)
if err == nil && len(key) != KeyLen {
err = fmt.Errorf("length is %d, expected %d", len(key), KeyLen)
}
return key, err
}
// EncodeKey encodes a hex-encoded private or public key.
func EncodeKey(key []byte) string {
return hex.EncodeToString(key)
}

View File

@@ -0,0 +1,218 @@
package noise
import (
"bytes"
"io"
"net"
"testing"
"github.com/flynn/noise"
)
func allMessages(buf []byte) ([][]byte, error) {
var messages [][]byte
r := bytes.NewReader(buf)
for {
msg, err := readMessage(r)
if err != nil {
return messages, err
}
messages = append(messages, msg)
}
}
func messagesEqual(a, b [][]byte) bool {
if len(a) != len(b) {
return false
}
for i := range a {
if !bytes.Equal(a[i], b[i]) {
return false
}
}
return true
}
func TestReadMessage(t *testing.T) {
for _, test := range []struct {
input string
messages [][]byte
err error
}{
{"", [][]byte{}, io.EOF},
{"\x00", [][]byte{}, io.ErrUnexpectedEOF},
{"\x00\x00", [][]byte{{}}, io.EOF},
{"\x00\x00\x00", [][]byte{{}}, io.ErrUnexpectedEOF},
{"\x00\x01", [][]byte{}, io.ErrUnexpectedEOF},
{"\x00\x05hello\x00\x05world", [][]byte{[]byte("hello"), []byte("world")}, io.EOF},
} {
packets, err := allMessages([]byte(test.input))
if !messagesEqual(packets, test.messages) || err != test.err {
t.Errorf("%x\nreturned %x %v\nexpected %x %v",
test.input, packets, err, test.messages, test.err)
}
}
}
func TestMessageRoundTrip(t *testing.T) {
for _, messages := range [][][]byte{
{},
} {
var buf bytes.Buffer
for _, msg := range messages {
err := writeMessage(&buf, msg)
if err != nil {
panic(err)
}
}
output, err := allMessages(buf.Bytes())
if !messagesEqual(output, messages) || err != io.EOF {
t.Errorf("%x roundtripped to %x %v",
messages, output, err)
}
}
}
func TestReadKey(t *testing.T) {
for _, test := range []struct {
input string
output []byte
}{
{"", nil},
{"0123456789abcdef0123456789abcdef0123456789abcdef0123456789abcde", nil},
{"0123456789abcdef0123456789abcdef0123456789abcdef0123456789abcdef", []byte("\x01\x23\x45\x67\x89\xab\xcd\xef\x01\x23\x45\x67\x89\xab\xcd\xef\x01\x23\x45\x67\x89\xab\xcd\xef\x01\x23\x45\x67\x89\xab\xcd\xef")},
{"0123456789abcdef0123456789abcdef0123456789abcdef0123456789abcdef\n", []byte("\x01\x23\x45\x67\x89\xab\xcd\xef\x01\x23\x45\x67\x89\xab\xcd\xef\x01\x23\x45\x67\x89\xab\xcd\xef\x01\x23\x45\x67\x89\xab\xcd\xef")},
{"0123456789abcdef0123456789abcdef0123456789abcdef0123456789abcdef0", nil},
{"0123456789abcdef0123456789abcdef0123456789abcdef0123456789abcdef\nX", nil},
{"0123456789abcdef0123456789abcdef0123456789abcdef0123456789abcdef\n\n", nil},
{"\n0123456789abcdef0123456789abcdef0123456789abcdef0123456789abcdef", nil},
{"X123456789abcdef0123456789abcdef0123456789abcdef0123456789abcdef", nil},
{"aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa", nil},
} {
output, err := ReadKey(bytes.NewReader([]byte(test.input)))
if test.output == nil {
if err == nil {
t.Errorf("%+q expected error", test.input)
}
} else {
if err != nil {
t.Errorf("%+q returned error %v", test.input, err)
} else if !bytes.Equal(output, test.output) {
t.Errorf("%+q got %x, expected %x", test.input, output, test.output)
}
}
}
}
func TestUnexpectedPayload(t *testing.T) {
privkey, err := GeneratePrivkey()
if err != nil {
panic(err)
}
pubkey := PubkeyFromPrivkey(privkey)
// Test the client sending an unexpected payload.
clientWithPayload := func(rwc io.ReadWriteCloser) error {
config := newConfig()
config.Initiator = true
config.PeerStatic = pubkey
handshakeState, err := noise.NewHandshakeState(config)
if err != nil {
return err
}
// -> e, es
msg, _, _, err := handshakeState.WriteMessage(nil, []byte("payload"))
if err != nil {
return err
}
err = writeMessage(rwc, msg)
if err != nil {
return err
}
// <- e, es
// Return nil for all errors after this point, because we expect
// the server to have failed, but we want to keep up the game
// just in case the server did not fail.
msg, err = readMessage(rwc)
if err != nil {
return nil
}
_, _, _, err = handshakeState.ReadMessage(nil, msg)
if err != nil {
return nil
}
return nil
}
func() {
c, s := net.Pipe()
defer s.Close()
// Fake a client side that sends a payload.
go func() {
defer c.Close()
err := clientWithPayload(c)
if err != nil {
t.Fatal(err)
}
}()
server, err := NewServer(s, privkey)
if err == nil || err.Error() != "unexpected client payload" || server != nil {
t.Errorf("NewServer got (%T, %v)", server, err)
}
}()
// Test the server sending an unexpected payload.
serverWithPayload := func(rwc io.ReadWriteCloser) error {
config := newConfig()
config.Initiator = false
config.StaticKeypair = noise.DHKey{Private: privkey, Public: pubkey}
handshakeState, err := noise.NewHandshakeState(config)
if err != nil {
return err
}
// -> e, es
msg, err := readMessage(rwc)
if err != nil {
return err
}
_, _, _, err = handshakeState.ReadMessage(nil, msg)
if err != nil {
return err
}
// <- e, es
msg, _, _, err = handshakeState.WriteMessage(nil, []byte("payload"))
if err != nil {
return err
}
err = writeMessage(rwc, msg)
if err != nil {
return err
}
return nil
}
func() {
c, s := net.Pipe()
defer c.Close()
// Fake a server side that sends a payload.
go func() {
defer s.Close()
err := serverWithPayload(s)
if err != nil {
t.Fatal(err)
}
}()
client, err := NewClient(c, pubkey)
if err == nil || err.Error() != "unexpected server payload" || client != nil {
t.Errorf("NewClient got (%T, %v)", client, err)
}
}()
}

View File

@@ -0,0 +1,28 @@
package turbotunnel
import (
"crypto/rand"
"encoding/hex"
)
// ClientID is an abstract identifier that binds together all the communications
// belonging to a single client session, even though those communications may
// arrive from multiple IP addresses or over multiple lower-level connections.
// It plays the same role that an (IP address, port number) tuple plays in a
// net.UDPConn: it's the return address pertaining to a long-lived abstract
// client session. The client attaches its ClientID to each of its
// communications, enabling the server to disambiguate requests among its many
// clients. ClientID implements the net.Addr interface.
type ClientID [8]byte
func NewClientID() ClientID {
var id ClientID
_, err := rand.Read(id[:])
if err != nil {
panic(err)
}
return id
}
func (id ClientID) Network() string { return "clientid" }
func (id ClientID) String() string { return hex.EncodeToString(id[:]) }

View File

@@ -0,0 +1,22 @@
// Package turbotunnel is facilities for embedding packet-based reliability
// protocols inside other protocols.
//
// https://github.com/net4people/bbs/issues/9
package turbotunnel
import "errors"
// QueueSize is the size of send and receive queues in QueuePacketConn and
// RemoteMap.
const QueueSize = 128
var errClosedPacketConn = errors.New("operation on closed connection")
var errNotImplemented = errors.New("not implemented")
// DummyAddr is a placeholder net.Addr, for when a programming interface
// requires a net.Addr but there is none relevant. All DummyAddrs compare equal
// to each other.
type DummyAddr struct{}
func (addr DummyAddr) Network() string { return "dummy" }
func (addr DummyAddr) String() string { return "dummy" }

View File

@@ -0,0 +1,162 @@
package turbotunnel
import (
"net"
"sync"
"sync/atomic"
"time"
)
// taggedPacket is a combination of a []byte and a net.Addr, encapsulating the
// return type of PacketConn.ReadFrom.
type taggedPacket struct {
P []byte
Addr net.Addr
}
// QueuePacketConn implements net.PacketConn by storing queues of packets. There
// is one incoming queue (where packets are additionally tagged by the source
// address of the peer that sent them). There are many outgoing queues, one for
// each remote peer address that has been recently seen. The QueueIncoming
// method inserts a packet into the incoming queue, to eventually be returned by
// ReadFrom. WriteTo inserts a packet into an address-specific outgoing queue,
// which can later by accessed through the OutgoingQueue method.
//
// Besides the outgoing queues, there is also a one-element "stash" for each
// remote peer address. You can stash a packet using the Stash method, and get
// it back later by receiving from the channel returned by Unstash. The stash is
// meant as a convenient place to temporarily store a single packet, such as
// when you've read one too many packets from the send queue and need to store
// the extra packet to be processed first in the next pass. It's the caller's
// responsibility to Unstash what they have Stashed. Calling Stash does not put
// the packet at the head of the send queue; if there is the possibility that a
// packet has been stashed, it must be checked for by calling Unstash in
// addition to OutgoingQueue.
type QueuePacketConn struct {
remotes *RemoteMap
localAddr net.Addr
recvQueue chan taggedPacket
closeOnce sync.Once
closed chan struct{}
// What error to return when the QueuePacketConn is closed.
err atomic.Value
}
// NewQueuePacketConn makes a new QueuePacketConn, set to track recent peers
// for at least a duration of timeout.
func NewQueuePacketConn(localAddr net.Addr, timeout time.Duration) *QueuePacketConn {
return &QueuePacketConn{
remotes: NewRemoteMap(timeout),
localAddr: localAddr,
recvQueue: make(chan taggedPacket, QueueSize),
closed: make(chan struct{}),
}
}
// QueueIncoming queues and incoming packet and its source address, to be
// returned in a future call to ReadFrom.
func (c *QueuePacketConn) QueueIncoming(p []byte, addr net.Addr) {
select {
case <-c.closed:
// If we're closed, silently drop it.
return
default:
}
// Copy the slice so that the caller may reuse it.
buf := make([]byte, len(p))
copy(buf, p)
select {
case c.recvQueue <- taggedPacket{buf, addr}:
default:
// Drop the incoming packet if the receive queue is full.
}
}
// OutgoingQueue returns the queue of outgoing packets corresponding to addr,
// creating it if necessary. The contents of the queue will be packets that are
// written to the address in question using WriteTo.
func (c *QueuePacketConn) OutgoingQueue(addr net.Addr) <-chan []byte {
return c.remotes.SendQueue(addr)
}
// Stash places p in the stash for addr, if the stash is not already occupied.
// Returns true if the packet was placed in the stash, or false if the stash was
// already occupied. This method is similar to WriteTo, except that it puts the
// packet in the stash queue (accessible via Unstash), rather than the outgoing
// queue (accessible via OutgoingQueue).
func (c *QueuePacketConn) Stash(p []byte, addr net.Addr) bool {
return c.remotes.Stash(addr, p)
}
// Unstash returns the channel that represents the stash for addr.
func (c *QueuePacketConn) Unstash(addr net.Addr) <-chan []byte {
return c.remotes.Unstash(addr)
}
// ReadFrom returns a packet and address previously stored by QueueIncoming.
func (c *QueuePacketConn) ReadFrom(p []byte) (int, net.Addr, error) {
select {
case <-c.closed:
return 0, nil, &net.OpError{Op: "read", Net: c.LocalAddr().Network(), Addr: c.LocalAddr(), Err: c.err.Load().(error)}
default:
}
select {
case <-c.closed:
return 0, nil, &net.OpError{Op: "read", Net: c.LocalAddr().Network(), Addr: c.LocalAddr(), Err: c.err.Load().(error)}
case packet := <-c.recvQueue:
return copy(p, packet.P), packet.Addr, nil
}
}
// WriteTo queues an outgoing packet for the given address. The queue can later
// be retrieved using the OutgoingQueue method.
func (c *QueuePacketConn) WriteTo(p []byte, addr net.Addr) (int, error) {
select {
case <-c.closed:
return 0, &net.OpError{Op: "write", Net: c.LocalAddr().Network(), Addr: c.LocalAddr(), Err: c.err.Load().(error)}
default:
}
// Copy the slice so that the caller may reuse it.
buf := make([]byte, len(p))
copy(buf, p)
select {
case c.remotes.SendQueue(addr) <- buf:
return len(buf), nil
default:
// Drop the outgoing packet if the send queue is full.
return len(buf), nil
}
}
// closeWithError unblocks pending operations and makes future operations fail
// with the given error. If err is nil, it becomes errClosedPacketConn.
func (c *QueuePacketConn) closeWithError(err error) error {
var newlyClosed bool
c.closeOnce.Do(func() {
newlyClosed = true
// Store the error to be returned by future PacketConn
// operations.
if err == nil {
err = errClosedPacketConn
}
c.err.Store(err)
close(c.closed)
})
if !newlyClosed {
return &net.OpError{Op: "close", Net: c.LocalAddr().Network(), Addr: c.LocalAddr(), Err: c.err.Load().(error)}
}
return nil
}
// Close unblocks pending operations and makes future operations fail with a
// "closed connection" error.
func (c *QueuePacketConn) Close() error {
return c.closeWithError(nil)
}
// LocalAddr returns the localAddr value that was passed to NewQueuePacketConn.
func (c *QueuePacketConn) LocalAddr() net.Addr { return c.localAddr }
func (c *QueuePacketConn) SetDeadline(t time.Time) error { return errNotImplemented }
func (c *QueuePacketConn) SetReadDeadline(t time.Time) error { return errNotImplemented }
func (c *QueuePacketConn) SetWriteDeadline(t time.Time) error { return errNotImplemented }

View File

@@ -0,0 +1,177 @@
package turbotunnel
import (
"container/heap"
"net"
"sync"
"time"
)
// remoteRecord is a record of a recently seen remote peer, with the time it was
// last seen and queues of outgoing packets.
type remoteRecord struct {
Addr net.Addr
LastSeen time.Time
SendQueue chan []byte
Stash chan []byte
}
// RemoteMap manages a mapping of live remote peers, keyed by address, to their
// respective send queues. Each peer has two queues: a primary send queue, and a
// "stash". The primary send queue is returned by the SendQueue method. The
// stash is an auxiliary one-element queue accessed using the Stash and Unstash
// methods. The stash is meant for use by callers that need to "unread" a packet
// that's already been removed from the primary send queue.
//
// RemoteMap's functions are safe to call from multiple goroutines.
type RemoteMap struct {
// We use an inner structure to avoid exposing public heap.Interface
// functions to users of remoteMap.
inner remoteMapInner
// Synchronizes access to inner.
lock sync.Mutex
}
// NewRemoteMap creates a RemoteMap that expires peers after a timeout.
//
// If the timeout is 0, peers never expire.
//
// The timeout does not have to be kept in sync with smux's idle timeout. If a
// peer is removed from the map while the smux session is still live, the worst
// that can happen is a loss of whatever packets were in the send queue at the
// time. If smux later decides to send more packets to the same peer, we'll
// instantiate a new send queue, and if the peer is ever seen again with a
// matching address, we'll deliver them.
func NewRemoteMap(timeout time.Duration) *RemoteMap {
m := &RemoteMap{
inner: remoteMapInner{
byAge: make([]*remoteRecord, 0),
byAddr: make(map[net.Addr]int),
},
}
if timeout > 0 {
go func() {
for {
time.Sleep(timeout / 2)
now := time.Now()
m.lock.Lock()
m.inner.removeExpired(now, timeout)
m.lock.Unlock()
}
}()
}
return m
}
// SendQueue returns the send queue corresponding to addr, creating it if
// necessary.
func (m *RemoteMap) SendQueue(addr net.Addr) chan []byte {
m.lock.Lock()
defer m.lock.Unlock()
return m.inner.Lookup(addr, time.Now()).SendQueue
}
// Stash places p in the stash corresponding to addr, if the stash is not
// already occupied. Returns true if the p was placed in the stash, false
// otherwise.
func (m *RemoteMap) Stash(addr net.Addr, p []byte) bool {
m.lock.Lock()
defer m.lock.Unlock()
select {
case m.inner.Lookup(addr, time.Now()).Stash <- p:
return true
default:
return false
}
}
// Unstash returns the channel that reads from the stash for addr.
func (m *RemoteMap) Unstash(addr net.Addr) <-chan []byte {
m.lock.Lock()
defer m.lock.Unlock()
return m.inner.Lookup(addr, time.Now()).Stash
}
// remoteMapInner is the inner type of RemoteMap, implementing heap.Interface.
// byAge is the backing store, a heap ordered by LastSeen time, to facilitate
// expiring old records. byAddr is a map from addresses to heap indices, to
// allow looking up by address. Unlike RemoteMap, remoteMapInner requires
// external synchonization.
type remoteMapInner struct {
byAge []*remoteRecord
byAddr map[net.Addr]int
}
// removeExpired removes all records whose LastSeen timestamp is more than
// timeout in the past.
func (inner *remoteMapInner) removeExpired(now time.Time, timeout time.Duration) {
for len(inner.byAge) > 0 && now.Sub(inner.byAge[0].LastSeen) >= timeout {
record := heap.Pop(inner).(*remoteRecord)
close(record.SendQueue)
}
}
// Lookup finds the existing record corresponding to addr, or creates a new
// one if none exists yet. It updates the record's LastSeen time and returns the
// record.
func (inner *remoteMapInner) Lookup(addr net.Addr, now time.Time) *remoteRecord {
var record *remoteRecord
i, ok := inner.byAddr[addr]
if ok {
// Found one, update its LastSeen.
record = inner.byAge[i]
record.LastSeen = now
heap.Fix(inner, i)
} else {
// Not found, create a new one.
record = &remoteRecord{
Addr: addr,
LastSeen: now,
SendQueue: make(chan []byte, QueueSize),
Stash: make(chan []byte, 1),
}
heap.Push(inner, record)
}
return record
}
// heap.Interface for remoteMapInner.
func (inner *remoteMapInner) Len() int {
if len(inner.byAge) != len(inner.byAddr) {
panic("inconsistent remoteMap")
}
return len(inner.byAge)
}
func (inner *remoteMapInner) Less(i, j int) bool {
return inner.byAge[i].LastSeen.Before(inner.byAge[j].LastSeen)
}
func (inner *remoteMapInner) Swap(i, j int) {
inner.byAge[i], inner.byAge[j] = inner.byAge[j], inner.byAge[i]
inner.byAddr[inner.byAge[i].Addr] = i
inner.byAddr[inner.byAge[j].Addr] = j
}
func (inner *remoteMapInner) Push(x interface{}) {
record := x.(*remoteRecord)
if _, ok := inner.byAddr[record.Addr]; ok {
panic("duplicate address in remoteMap")
}
// Insert into byAddr map.
inner.byAddr[record.Addr] = len(inner.byAge)
// Insert into byAge slice.
inner.byAge = append(inner.byAge, record)
}
func (inner *remoteMapInner) Pop() interface{} {
n := len(inner.byAddr)
// Remove from byAge slice.
record := inner.byAge[n-1]
inner.byAge[n-1] = nil
inner.byAge = inner.byAge[:n-1]
// Remove from byAddr map.
delete(inner.byAddr, record.Addr)
return record
}

45
internal/engine/logger.go Normal file
View File

@@ -0,0 +1,45 @@
package engine
import (
"fmt"
"sync"
"time"
)
type LogEntry struct {
ID int64 `json:"id"`
Time string `json:"time"`
Level string `json:"level"`
Message string `json:"message"`
}
type Logger struct {
mu sync.Mutex
nextID int64
entries []LogEntry
}
func NewLogger() *Logger { return &Logger{} }
func (l *Logger) Add(level, format string, args ...any) {
l.mu.Lock()
defer l.mu.Unlock()
l.nextID++
entry := LogEntry{ID: l.nextID, Time: time.Now().Format("15:04:05"), Level: level, Message: fmt.Sprintf(format, args...)}
l.entries = append(l.entries, entry)
if len(l.entries) > 600 {
l.entries = l.entries[len(l.entries)-600:]
}
}
func (l *Logger) Since(id int64) []LogEntry {
l.mu.Lock()
defer l.mu.Unlock()
out := make([]LogEntry, 0)
for _, e := range l.entries {
if e.ID > id {
out = append(out, e)
}
}
return out
}

525
internal/engine/manager.go Normal file
View File

@@ -0,0 +1,525 @@
package engine
import (
"context"
"fmt"
"net"
"net/url"
"strings"
"sync"
"time"
"socksrevivepc/internal/config"
"socksrevivepc/internal/dnsttclient"
"socksrevivepc/internal/routes"
"socksrevivepc/internal/tun"
)
type Status struct {
Running bool `json:"running"`
Connecting bool `json:"connecting"`
ProfileID string `json:"profile_id"`
Mode string `json:"mode"`
SocksAddr string `json:"socks_addr"`
Tun bool `json:"tun"`
StartedAt string `json:"started_at"`
}
type Manager struct {
root string
mu sync.Mutex
logger *Logger
status Status
cancel context.CancelFunc
ssh *sshBundle
socks *SocksServer
dnstt *ManagedProcess
embeddedDNSTT *dnsttclient.Client
xray *ManagedProcess
tun *tun.Runner
routeCleanup *routes.Cleanup
manualStop bool
reconnecting bool
}
func NewManager(root string) *Manager {
lg := NewLogger()
return &Manager{root: root, logger: lg, tun: tun.NewRunner(lg)}
}
func (m *Manager) LogsSince(id int64) []LogEntry { return m.logger.Since(id) }
func (m *Manager) Status() Status {
m.mu.Lock()
defer m.mu.Unlock()
return m.status
}
func (m *Manager) Start(p config.Profile) error {
config.ApplyDefaults(&p)
if err := config.Validate(p); err != nil {
return err
}
ctx, cancel := context.WithCancel(context.Background())
m.mu.Lock()
if m.status.Running || m.status.Connecting {
m.mu.Unlock()
cancel()
return fmt.Errorf("another profile is already running or connecting")
}
m.cancel = cancel
m.manualStop = false
m.status = Status{Connecting: true, ProfileID: p.ID, Mode: string(p.Mode), Tun: p.Tun.Enabled, StartedAt: time.Now().Format(time.RFC3339)}
m.mu.Unlock()
m.logger.Add("info", "starting profile: %s", p.Name)
fail := func(format string, args ...any) error {
err := fmt.Errorf(format, args...)
m.logger.Add("error", "%v", err)
m.stop(false)
return err
}
ensureCurrent := func() error {
if err := ctx.Err(); err != nil {
return err
}
m.mu.Lock()
defer m.mu.Unlock()
if m.cancel == nil || !m.status.Connecting || m.status.ProfileID != p.ID {
return context.Canceled
}
return nil
}
var socksAddr string
var bypass []string
if p.Mode == config.ModeXray {
proc, err := StartProcess(ctx, m.root, "xray", p.Xray.Executable, p.Xray.Args, m.logger)
if err != nil {
return fail("start xray failed: %w", err)
}
m.mu.Lock()
m.xray = proc
m.mu.Unlock()
time.Sleep(time.Duration(p.Xray.StartupTimeoutMs) * time.Millisecond)
if err := ensureCurrent(); err != nil {
proc.Stop()
return fail("connection cancelled: %w", err)
}
socksAddr = net.JoinHostPort(p.Xray.LocalSocksHost, fmt.Sprint(p.Xray.LocalSocksPort))
if exited, procErr := proc.Exited(); exited {
m.stop(false)
if procErr != nil {
return fmt.Errorf("xray exited before opening local SOCKS on %s: %w", socksAddr, procErr)
}
return fmt.Errorf("xray exited before opening local SOCKS on %s", socksAddr)
}
if err := waitForTCP(ctx, socksAddr, time.Duration(p.Xray.StartupTimeoutMs)*time.Millisecond); err != nil {
return fail("xray local SOCKS is not listening on %s: %w", socksAddr, err)
}
bypass = nil
} else {
if p.Mode == config.ModeDNSTT {
if p.DNSTT.UseEmbedded {
client, err := dnsttclient.Start(ctx, dnsttclient.Options{
ResolverType: p.DNSTT.ResolverType,
ResolverAddress: p.DNSTT.ResolverAddress,
PublicKeyHex: p.DNSTT.PublicKey,
Domain: p.DNSTT.Domain,
LocalAddress: net.JoinHostPort(p.DNSTT.LocalSSHHost, fmt.Sprint(p.DNSTT.LocalSSHPort)),
UTLSDistribution: p.DNSTT.UTLSDistribution,
StartupTimeout: time.Duration(p.DNSTT.StartupTimeoutMs) * time.Millisecond,
LogWriter: dnsttLogWriter{logger: m.logger},
})
if err != nil {
return fail("start embedded dnstt failed: %w", err)
}
m.mu.Lock()
m.embeddedDNSTT = client
m.mu.Unlock()
} else {
proc, err := StartProcess(ctx, m.root, "dnstt", p.DNSTT.Executable, p.DNSTT.Args, m.logger)
if err != nil {
return fail("start dnstt failed: %w", err)
}
m.mu.Lock()
m.dnstt = proc
m.mu.Unlock()
time.Sleep(time.Duration(p.DNSTT.StartupTimeoutMs) * time.Millisecond)
}
if err := ensureCurrent(); err != nil {
return fail("connection cancelled: %w", err)
}
}
sshc, err := connectSSH(ctx, p, m.logger)
if err != nil {
return fail("%w", err)
}
if err := ensureCurrent(); err != nil {
_ = sshc.Client.Close()
_ = sshc.Conn.Close()
return fail("connection cancelled: %w", err)
}
m.mu.Lock()
m.ssh = sshc
m.mu.Unlock()
socksAddr = net.JoinHostPort(p.Local.SocksHost, fmt.Sprint(p.Local.SocksPort))
ss := &SocksServer{Addr: socksAddr, SSH: sshc.Client, Logger: m.logger, DNS: profileDNSServers(p), UDPGW: p.UDPGW}
if err := ss.Start(); err != nil {
return fail("%w", err)
}
m.mu.Lock()
m.socks = ss
m.status.SocksAddr = socksAddr
m.mu.Unlock()
if err := waitForTCP(ctx, socksAddr, 2500*time.Millisecond); err != nil {
return fail("local SOCKS is not listening on %s: %w", socksAddr, err)
}
bypass = effectiveBypassHosts(p, sshc.ControlHosts)
}
if err := ensureCurrent(); err != nil {
return fail("connection cancelled: %w", err)
}
if p.Tun.Enabled {
if err := m.tun.Start(&p, socksAddr); err != nil {
return fail("%w", err)
}
cleanup, err := routes.Apply(p, bypass, m.logger)
if err != nil {
m.logger.Add("warn", "route setup error: %v", err)
}
m.mu.Lock()
m.routeCleanup = cleanup
m.mu.Unlock()
}
m.mu.Lock()
current := m.cancel != nil && m.status.Connecting && m.status.ProfileID == p.ID
err := ctx.Err()
if err == nil && current {
m.status = Status{Running: true, ProfileID: p.ID, Mode: string(p.Mode), SocksAddr: socksAddr, Tun: p.Tun.Enabled, StartedAt: time.Now().Format(time.RFC3339)}
}
m.mu.Unlock()
if err != nil || !current {
m.stop(false)
if err != nil {
return err
}
return context.Canceled
}
m.logger.Add("info", "profile is connected; local socks=%s tun=%v", socksAddr, p.Tun.Enabled)
m.startMonitor(ctx, p)
return nil
}
func profileDNSServers(p config.Profile) []string {
out := append([]string{}, p.Tun.DNS...)
if p.Tun.IPv6Enabled {
out = append(out, p.Tun.IPv6DNS...)
}
return out
}
func waitForTCP(ctx context.Context, addr string, timeout time.Duration) error {
if timeout <= 0 {
timeout = 5 * time.Second
}
deadline := time.Now().Add(timeout)
var lastErr error
for time.Now().Before(deadline) {
d := net.Dialer{Timeout: 350 * time.Millisecond}
c, err := d.DialContext(ctx, "tcp", addr)
if err == nil {
_ = c.Close()
return nil
}
lastErr = err
select {
case <-ctx.Done():
return ctx.Err()
case <-time.After(150 * time.Millisecond):
}
}
if lastErr != nil {
return lastErr
}
return fmt.Errorf("timeout waiting for %s", addr)
}
func effectiveBypassHosts(p config.Profile, activeControlHosts []string) []string {
seen := map[string]bool{}
out := make([]string, 0, len(activeControlHosts)+1)
add := func(host string) {
host = strings.TrimSpace(host)
if host == "" {
return
}
if h, _, err := net.SplitHostPort(host); err == nil {
host = strings.Trim(h, "[]")
}
if isLocalBypassHost(host) || seen[host] {
return
}
seen[host] = true
out = append(out, host)
}
for _, host := range activeControlHosts {
add(host)
}
if len(out) == 0 && p.Mode != config.ModeDNSTT {
// Fallback for old profile formats or unexpected transports. This still only
// adds the direct SSH host, not every rotated proxy from the profile.
add(p.SSH.Host)
}
if p.Mode == config.ModeDNSTT && p.DNSTT.UseEmbedded {
add(dnsttResolverHost(p.DNSTT.ResolverAddress))
}
return out
}
func isLocalBypassHost(host string) bool {
host = strings.Trim(strings.ToLower(strings.TrimSpace(host)), "[]")
if host == "" || host == "localhost" {
return true
}
ip := net.ParseIP(host)
if ip == nil {
return false
}
return ip.IsLoopback() || ip.IsUnspecified()
}
func dnsttResolverHost(s string) string {
s = strings.TrimSpace(s)
if s == "" {
return ""
}
if strings.Contains(s, "://") {
u, err := url.Parse(s)
if err == nil {
return u.Hostname()
}
}
host, _, err := net.SplitHostPort(s)
if err == nil {
return host
}
return s
}
func (m *Manager) Stop() {
m.stop(true)
}
func (m *Manager) stop(manual bool) {
m.mu.Lock()
defer m.mu.Unlock()
if manual {
m.manualStop = true
}
m.stopLocked()
}
func (m *Manager) startMonitor(ctx context.Context, p config.Profile) {
if !p.Reconnect.Enabled {
return
}
interval := time.Duration(p.Reconnect.CheckIntervalSeconds) * time.Second
if interval <= 0 {
interval = 10 * time.Second
}
go func() {
ticker := time.NewTicker(interval)
defer ticker.Stop()
for {
select {
case <-ctx.Done():
return
case <-ticker.C:
if m.isManualStopped() {
return
}
if err := m.probeConnection(p); err != nil {
m.logger.Add("warn", "connection monitor detected tunnel loss: %v", err)
m.reconnectLoop(p, err)
return
}
}
}
}()
}
func (m *Manager) isManualStopped() bool {
m.mu.Lock()
defer m.mu.Unlock()
return m.manualStop
}
func (m *Manager) markReconnecting() bool {
m.mu.Lock()
defer m.mu.Unlock()
if m.manualStop || m.reconnecting {
return false
}
m.reconnecting = true
return true
}
func (m *Manager) clearReconnecting() {
m.mu.Lock()
m.reconnecting = false
m.mu.Unlock()
}
func (m *Manager) probeConnection(p config.Profile) error {
m.mu.Lock()
sshc := m.ssh
xray := m.xray
socksAddr := m.status.SocksAddr
running := m.status.Running
m.mu.Unlock()
if !running {
return nil
}
if xray != nil {
if exited, err := xray.Exited(); exited {
if err != nil {
return fmt.Errorf("xray exited: %w", err)
}
return fmt.Errorf("xray exited")
}
if socksAddr != "" {
ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
defer cancel()
return waitForTCP(ctx, socksAddr, 2*time.Second)
}
return nil
}
if sshc != nil && sshc.Client != nil {
if sshc.Conn != nil {
_ = sshc.Conn.SetDeadline(time.Now().Add(5 * time.Second))
defer sshc.Conn.SetDeadline(time.Time{})
}
_, _, err := sshc.Client.SendRequest("keepalive@openssh.com", true, nil)
if err != nil {
return fmt.Errorf("ssh keepalive failed: %w", err)
}
}
return nil
}
func (m *Manager) reconnectLoop(p config.Profile, cause error) {
if !m.markReconnecting() {
return
}
defer m.clearReconnecting()
delay := time.Duration(p.Reconnect.DelaySeconds) * time.Second
if delay <= 0 {
delay = 3 * time.Second
}
maxRetries := p.Reconnect.MaxRetries
m.logger.Add("warn", "connection lost (%v); auto reconnect is enabled", cause)
if p.Tun.Enabled {
m.logger.Add("info", "destroying TUN before reconnect")
}
m.stop(false)
if p.Tun.Enabled {
time.Sleep(1200 * time.Millisecond)
}
for attempt := 1; maxRetries <= 0 || attempt <= maxRetries; attempt++ {
if m.isManualStopped() {
m.logger.Add("info", "auto reconnect cancelled by user")
return
}
m.logger.Add("info", "reconnect attempt %d%s in %s", attempt, reconnectLimitSuffix(maxRetries), delay)
select {
case <-time.After(delay):
}
if m.isManualStopped() {
m.logger.Add("info", "auto reconnect cancelled by user")
return
}
if err := m.Start(p); err != nil {
if m.isManualStopped() {
m.logger.Add("info", "auto reconnect cancelled by user")
return
}
m.logger.Add("warn", "reconnect attempt %d failed: %v", attempt, err)
if p.Tun.Enabled {
m.logger.Add("info", "destroying TUN after failed reconnect attempt")
m.stop(false)
time.Sleep(1200 * time.Millisecond)
}
continue
}
m.logger.Add("info", "reconnected successfully")
return
}
m.logger.Add("error", "auto reconnect stopped after %d failed attempt(s)", maxRetries)
}
func reconnectLimitSuffix(maxRetries int) string {
if maxRetries <= 0 {
return " (unlimited)"
}
return fmt.Sprintf("/%d", maxRetries)
}
func (m *Manager) stopLocked() {
wasActive := m.status.Running || m.status.Connecting
if m.cancel != nil {
m.cancel()
m.cancel = nil
}
if m.routeCleanup != nil {
m.routeCleanup.Run()
m.routeCleanup = nil
}
if m.tun != nil {
m.tun.Stop()
}
if m.socks != nil {
m.socks.Stop()
m.socks = nil
}
if m.ssh != nil {
_ = m.ssh.Client.Close()
_ = m.ssh.Conn.Close()
m.ssh = nil
}
if m.xray != nil {
m.xray.Stop()
m.xray = nil
}
if m.embeddedDNSTT != nil {
m.embeddedDNSTT.Stop()
m.embeddedDNSTT = nil
}
if m.dnstt != nil {
m.dnstt.Stop()
m.dnstt = nil
}
if wasActive {
m.logger.Add("info", "disconnected")
}
m.status = Status{}
}
type dnsttLogWriter struct {
logger *Logger
}
func (w dnsttLogWriter) Write(p []byte) (int, error) {
line := strings.TrimSpace(string(p))
if line != "" && w.logger != nil {
w.logger.Add("dnstt", "%s", line)
}
return len(p), nil
}

399
internal/engine/payload.go Normal file
View File

@@ -0,0 +1,399 @@
package engine
import (
"bytes"
"fmt"
"io"
"math/rand"
"net"
"strconv"
"strings"
"time"
"socksrevivepc/internal/config"
)
const (
payloadHTTPStatusPeekTimeoutMs = 1500
payloadHTTPStatusLineLimit = 4096
maxPayloadHTTPResponsesToSkip = 12
maxPayloadHTTPHeaderLines = 80
maxPayloadHTTPBodyDiscardBytes = 1024 * 1024
)
type PayloadResult struct {
StatusLine string
StatusCode int
}
type httpResponseInfo struct {
contentLength int64
chunked bool
}
type preloadedConn struct {
net.Conn
preloaded *bytes.Reader
}
func (c *preloadedConn) Read(p []byte) (int, error) {
if c.preloaded != nil && c.preloaded.Len() > 0 {
return c.preloaded.Read(p)
}
return c.Conn.Read(p)
}
func wrapConnWithPreloadedBytes(conn net.Conn, b []byte) net.Conn {
if len(b) == 0 {
return conn
}
return &preloadedConn{Conn: conn, preloaded: bytes.NewReader(b)}
}
func WritePayload(conn net.Conn, p config.Profile, targetHost string, targetPort int, logger *Logger) (PayloadResult, net.Conn, error) {
payload := buildPayload(p.Payload.Text, targetHost, targetPort)
parts, instant := splitPayload(payload)
for i, part := range parts {
if part == "" {
continue
}
if _, err := io.WriteString(conn, part); err != nil {
return PayloadResult{}, conn, err
}
if i < len(parts)-1 && !instant {
time.Sleep(time.Duration(p.Payload.SplitDelayMs) * time.Millisecond)
}
}
logger.Add("debug", "payload sent (%d bytes)", len(payload))
if !p.Payload.WaitForResponse {
return PayloadResult{}, conn, nil
}
return consumePayloadHTTPNegotiation(conn, p, payloadSourceLabel(p.Mode), logger)
}
func consumePayloadHTTPNegotiation(conn net.Conn, p config.Profile, source string, logger *Logger) (PayloadResult, net.Conn, error) {
defer conn.SetReadDeadline(time.Time{})
var last PayloadResult
var captured *bytes.Buffer
var sawSuccess bool
for attempt := 0; attempt < maxPayloadHTTPResponsesToSkip; attempt++ {
setPayloadReadDeadline(conn, p, attempt)
captured = &bytes.Buffer{}
line, err := readPayloadLinePreserveBytes(conn, captured, payloadHTTPStatusLineLimit)
if err != nil {
if isTimeoutErr(err) {
if last.StatusCode >= 400 && !p.Payload.AcceptAnyStatus && !sawSuccess {
return last, conn, fmt.Errorf("payload rejected with final status %d", last.StatusCode)
}
return last, wrapConnWithPreloadedBytes(conn, captured.Bytes()), nil
}
if err == io.EOF && captured.Len() > 0 {
return last, wrapConnWithPreloadedBytes(conn, captured.Bytes()), nil
}
if last.StatusCode > 0 {
return last, conn, nil
}
return PayloadResult{}, conn, fmt.Errorf("payload response read failed: %w", err)
}
cleanLine := strings.TrimSpace(line)
if cleanLine == "" {
return last, wrapConnWithPreloadedBytes(conn, captured.Bytes()), nil
}
if strings.HasPrefix(cleanLine, "SSH-") || !isHTTPStatusLine(cleanLine) {
return last, wrapConnWithPreloadedBytes(conn, captured.Bytes()), nil
}
code := parseStatusCode(cleanLine)
last = PayloadResult{StatusLine: cleanLine, StatusCode: code}
logProxyStatus(logger, source, code, cleanLine)
logHTTPCompatibilityStatus(logger, code, cleanLine)
if code == 101 || (code >= 200 && code < 400) {
sawSuccess = true
}
// The current bytes are confirmed HTTP/proxy negotiation bytes. Do not replay
// them to the SSH transport. Only replay bytes when we detect SSH/non-HTTP
// data or a partial line after timeout.
captured = nil
if err := consumePayloadHTTPHeadersAndBody(conn); err != nil {
if isTimeoutErr(err) {
return last, conn, nil
}
return last, conn, fmt.Errorf("payload response consume failed: %w", err)
}
// Keep peeking for another immediate HTTP status block. Some payload/proxy
// chains return several statuses (for example 403 -> 403 -> 101). Returning
// after the first status can make SSH read HTTP text instead of SSH-2.0.
}
if last.StatusCode >= 400 && !p.Payload.AcceptAnyStatus && !sawSuccess {
return last, conn, fmt.Errorf("payload rejected with final status %d", last.StatusCode)
}
return last, conn, nil
}
func setPayloadReadDeadline(conn net.Conn, p config.Profile, attempt int) {
timeoutMs := payloadHTTPStatusPeekTimeoutMs
if attempt == 0 && p.Payload.ResponseTimeoutMs > 0 {
timeoutMs = p.Payload.ResponseTimeoutMs
if timeoutMs < payloadHTTPStatusPeekTimeoutMs {
timeoutMs = payloadHTTPStatusPeekTimeoutMs
}
}
_ = conn.SetReadDeadline(time.Now().Add(time.Duration(timeoutMs) * time.Millisecond))
}
func readPayloadLinePreserveBytes(conn net.Conn, captured *bytes.Buffer, limit int) (string, error) {
var line bytes.Buffer
buf := make([]byte, 1)
for line.Len() < limit {
n, err := conn.Read(buf)
if n > 0 {
b := buf[0]
_ = captured.WriteByte(b)
_ = line.WriteByte(b)
if b == '\n' {
break
}
}
if err != nil {
if line.Len() > 0 && err == io.EOF {
return line.String(), nil
}
return "", err
}
if n == 0 {
continue
}
}
if line.Len() == 0 {
return "", io.EOF
}
return line.String(), nil
}
func consumePayloadHTTPHeadersAndBody(conn net.Conn) error {
info := httpResponseInfo{contentLength: -1}
for i := 0; i < maxPayloadHTTPHeaderLines; i++ {
ignored := &bytes.Buffer{}
line, err := readPayloadLinePreserveBytes(conn, ignored, payloadHTTPStatusLineLimit)
if err != nil {
return err
}
clean := strings.TrimSpace(line)
if clean == "" {
break
}
lower := strings.ToLower(clean)
if strings.HasPrefix(lower, "content-length:") {
if n, err := strconv.ParseInt(strings.TrimSpace(clean[strings.Index(clean, ":")+1:]), 10, 64); err == nil {
info.contentLength = n
}
} else if strings.HasPrefix(lower, "transfer-encoding:") && strings.Contains(lower, "chunked") {
info.chunked = true
}
}
if info.chunked {
return discardPayloadChunkedBody(conn)
}
if info.contentLength > 0 {
return discardPayloadFixedLengthBody(conn, info.contentLength)
}
return nil
}
func discardPayloadFixedLengthBody(conn net.Conn, contentLength int64) error {
remaining := contentLength
if remaining > maxPayloadHTTPBodyDiscardBytes {
remaining = maxPayloadHTTPBodyDiscardBytes
}
buf := make([]byte, 4096)
for remaining > 0 {
toRead := int64(len(buf))
if remaining < toRead {
toRead = remaining
}
n, err := conn.Read(buf[:int(toRead)])
if n > 0 {
remaining -= int64(n)
}
if err != nil {
return err
}
}
return nil
}
func discardPayloadChunkedBody(conn net.Conn) error {
for i := 0; i < maxPayloadHTTPHeaderLines; i++ {
ignored := &bytes.Buffer{}
sizeLine, err := readPayloadLinePreserveBytes(conn, ignored, payloadHTTPStatusLineLimit)
if err != nil {
return err
}
cleanSize := strings.TrimSpace(sizeLine)
if semi := strings.Index(cleanSize, ";"); semi >= 0 {
cleanSize = strings.TrimSpace(cleanSize[:semi])
}
chunkSize, err := strconv.ParseInt(cleanSize, 16, 64)
if err != nil {
return nil
}
if chunkSize == 0 {
return consumePayloadTrailingHeaders(conn)
}
if err := discardPayloadFixedLengthBody(conn, chunkSize); err != nil {
return err
}
crlf := &bytes.Buffer{}
_, _ = readPayloadLinePreserveBytes(conn, crlf, payloadHTTPStatusLineLimit)
}
return nil
}
func consumePayloadTrailingHeaders(conn net.Conn) error {
for i := 0; i < maxPayloadHTTPHeaderLines; i++ {
ignored := &bytes.Buffer{}
line, err := readPayloadLinePreserveBytes(conn, ignored, payloadHTTPStatusLineLimit)
if err != nil {
return err
}
if strings.TrimSpace(line) == "" {
return nil
}
}
return nil
}
func isTimeoutErr(err error) bool {
if err == nil {
return false
}
if ne, ok := err.(net.Error); ok && ne.Timeout() {
return true
}
return false
}
func isHTTPStatusLine(statusLine string) bool {
clean := strings.ToUpper(strings.TrimSpace(statusLine))
return strings.HasPrefix(clean, "HTTP/1.") || strings.HasPrefix(clean, "HTTP/2") || strings.HasPrefix(clean, "HTTP/3")
}
func logProxyStatus(logger *Logger, source string, responseCode int, statusLine string) {
cleanLine := strings.TrimSpace(statusLine)
if cleanLine == "" {
return
}
if source == "" {
source = "PROXY"
}
logger.Add("info", "Proxy Status [%s]: %s", source, cleanLine)
}
func logHTTPCompatibilityStatus(logger *Logger, responseCode int, firstLine string) {
switch responseCode {
case 200:
logger.Add("info", "Status: 200 (Connection established) Successful")
case 101:
logger.Add("info", "replace 200 OK")
logger.Add("info", "HTTP/1.1 101 Websocket")
case 100:
logger.Add("info", "HTTP/1.1 100 Continue")
case 301, 302, 400, 401, 403, 404, 407, 429, 500, 502, 503, 504:
if strings.TrimSpace(firstLine) != "" {
logger.Add("info", "%s", strings.TrimSpace(firstLine))
} else {
logger.Add("info", "HTTP/1.1 %d", responseCode)
}
logger.Add("info", "replace 200 OK")
logger.Add("info", "Dragon Try!")
}
}
func payloadSourceLabel(mode config.Mode) string {
switch mode {
case config.ModePayload:
return "HTTP_PROXY"
case config.ModePayloadSSL:
return "SSL_PAYLOAD"
default:
return "PROXY"
}
}
func buildPayload(tpl, host string, port int) string {
portStr := strconv.Itoa(port)
repl := map[string]string{
"[host]": host,
"[port]": portStr,
"[host_port]": net.JoinHostPort(host, portStr),
"[crlf]": "\r\n",
"[cr]": "\r",
"[lf]": "\n",
"[protocol]": "HTTP/1.1",
"[method]": "CONNECT",
}
out := tpl
for k, v := range repl {
out = strings.ReplaceAll(out, k, v)
}
out = replaceRotate(out)
return out
}
func replaceRotate(s string) string {
for {
start := strings.Index(s, "[rotate=")
if start < 0 {
return s
}
end := strings.Index(s[start:], "]")
if end < 0 {
return s
}
end += start
body := strings.TrimPrefix(s[start:end+1], "[rotate=")
body = strings.TrimSuffix(body, "]")
choices := splitRotateChoices(body)
choice := ""
if len(choices) > 0 {
choice = strings.TrimSpace(choices[rand.Intn(len(choices))])
}
s = s[:start] + choice + s[end+1:]
}
}
func splitRotateChoices(body string) []string {
return strings.FieldsFunc(body, func(r rune) bool {
switch r {
case ';', '#', ',', '\n', '\r', '\t':
return true
default:
return false
}
})
}
func splitPayload(s string) ([]string, bool) {
instant := strings.Contains(s, "[instant_split]")
s = strings.ReplaceAll(s, "[instant_split]", "[split]")
parts := strings.Split(s, "[split]")
return parts, instant
}
func parseStatusCode(status string) int {
fields := strings.Fields(status)
if len(fields) < 2 {
return 0
}
code, _ := strconv.Atoi(fields[1])
return code
}

View File

@@ -0,0 +1,91 @@
package engine
import (
"bufio"
"context"
"errors"
"io"
"os/exec"
"path/filepath"
"runtime"
"strings"
"sync"
"time"
"socksrevivepc/internal/oscmd"
)
type ManagedProcess struct {
cmd *exec.Cmd
name string
logger *Logger
mu sync.Mutex
done chan error
}
func StartProcess(ctx context.Context, root, name, exe string, args []string, logger *Logger) (*ManagedProcess, error) {
if strings.TrimSpace(exe) == "" {
return nil, errors.New(name + " executable path is empty")
}
if !filepath.IsAbs(exe) {
exe = filepath.Join(root, exe)
}
cmd := oscmd.CommandContext(ctx, exe, args...)
cmd.Dir = root
stdout, _ := cmd.StdoutPipe()
stderr, _ := cmd.StderrPipe()
p := &ManagedProcess{cmd: cmd, name: name, logger: logger, done: make(chan error, 1)}
if err := cmd.Start(); err != nil {
return nil, err
}
logger.Add("info", "%s started: %s %s", name, exe, strings.Join(args, " "))
go p.pipe(stdout, "info")
go p.pipe(stderr, "warn")
go func() {
err := cmd.Wait()
p.done <- err
if err != nil {
logger.Add("warn", "%s stopped: %v", name, err)
} else {
logger.Add("info", "%s stopped", name)
}
}()
return p, nil
}
func (p *ManagedProcess) Exited() (bool, error) {
if p == nil || p.done == nil {
return true, nil
}
select {
case err := <-p.done:
return true, err
default:
return false, nil
}
}
func (p *ManagedProcess) pipe(r io.Reader, level string) {
s := bufio.NewScanner(r)
for s.Scan() {
line := strings.TrimSpace(s.Text())
if line != "" {
p.logger.Add(level, "%s: %s", p.name, line)
}
}
}
func (p *ManagedProcess) Stop() {
p.mu.Lock()
defer p.mu.Unlock()
if p == nil || p.cmd == nil || p.cmd.Process == nil {
return
}
if runtime.GOOS == "windows" {
_ = p.cmd.Process.Kill()
} else {
_ = p.cmd.Process.Signal(ioSignalInterrupt())
time.Sleep(500 * time.Millisecond)
_ = p.cmd.Process.Kill()
}
}

View File

@@ -0,0 +1,7 @@
//go:build !windows
package engine
import "os"
func ioSignalInterrupt() os.Signal { return os.Interrupt }

View File

@@ -0,0 +1,7 @@
//go:build windows
package engine
import "os"
func ioSignalInterrupt() os.Signal { return os.Interrupt }

462
internal/engine/socks5.go Normal file
View File

@@ -0,0 +1,462 @@
package engine
import (
"encoding/binary"
"errors"
"fmt"
"io"
"net"
"strconv"
"strings"
"sync"
"time"
"golang.org/x/crypto/ssh"
"socksrevivepc/internal/config"
)
const (
socksVersion = 0x05
socksCmdConnect = 0x01
socksCmdUDPAssoc = 0x03
socksAtypIPv4 = 0x01
socksAtypDomain = 0x03
socksAtypIPv6 = 0x04
socksRepOK = 0x00
socksRepFail = 0x01
socksRepUnsupported = 0x07
)
type SocksServer struct {
Addr string
SSH *ssh.Client
Logger *Logger
DNS []string
UDPGW config.UDPGWConfig
listener net.Listener
stopOnce sync.Once
}
type socksRequest struct {
Command byte
Host string
Port int
}
func (r socksRequest) Addr() string {
return net.JoinHostPort(r.Host, strconv.Itoa(r.Port))
}
func (s *SocksServer) Start() error {
ln, err := net.Listen("tcp", s.Addr)
if err != nil {
return err
}
s.listener = ln
s.Logger.Add("info", "local SOCKS5 listening on %s", s.Addr)
go s.acceptLoop()
return nil
}
func (s *SocksServer) Stop() {
s.stopOnce.Do(func() {
if s.listener != nil {
_ = s.listener.Close()
}
})
}
func (s *SocksServer) acceptLoop() {
for {
c, err := s.listener.Accept()
if err != nil {
return
}
go s.handle(c)
}
}
func (s *SocksServer) handle(c net.Conn) {
defer c.Close()
_ = c.SetDeadline(time.Now().Add(30 * time.Second))
if err := s.handshake(c); err != nil {
if !isExpectedProbeError(err) {
s.Logger.Add("debug", "socks handshake failed: %v", err)
}
return
}
req, err := readSocksRequest(c)
if err != nil {
s.Logger.Add("debug", "socks request failed: %v", err)
return
}
switch req.Command {
case socksCmdConnect:
s.handleConnect(c, req)
case socksCmdUDPAssoc:
if s.UDPGW.Enabled {
s.handleUDPGWAssociate(c)
} else {
s.handleDNSUDPAssociate(c)
}
default:
_ = writeReply(c, socksRepUnsupported)
s.Logger.Add("debug", "socks unsupported command %d", req.Command)
}
}
func (s *SocksServer) handleConnect(c net.Conn, req socksRequest) {
dest := req.Addr()
remote, err := dialSSHDirectTCP(s.SSH, req, c.RemoteAddr(), s.Logger)
if err != nil {
_ = writeReply(c, 0x05)
s.Logger.Add("warn", "ssh dial %s failed: %v", dest, err)
return
}
defer remote.Close()
_ = writeReply(c, socksRepOK)
_ = c.SetDeadline(time.Time{})
s.Logger.Add("debug", "socks connected %s", dest)
proxyCopy(c, remote)
}
func (s *SocksServer) handleDNSUDPAssociate(c net.Conn) {
udp, err := net.ListenUDP("udp", &net.UDPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 0})
if err != nil {
_ = writeReply(c, socksRepFail)
s.Logger.Add("warn", "socks UDP associate failed: %v", err)
return
}
defer udp.Close()
if err := writeReplyWithAddr(c, socksRepOK, udp.LocalAddr().(*net.UDPAddr)); err != nil {
s.Logger.Add("debug", "socks UDP associate reply failed: %v", err)
return
}
_ = c.SetDeadline(time.Time{})
s.Logger.Add("info", "local SOCKS5 UDP associate listening on %s for DNS over SSH", udp.LocalAddr().String())
done := make(chan struct{})
go func() {
_, _ = io.Copy(io.Discard, c)
close(done)
}()
buf := make([]byte, 64*1024)
for {
_ = udp.SetReadDeadline(time.Now().Add(500 * time.Millisecond))
n, clientAddr, err := udp.ReadFromUDP(buf)
if err != nil {
select {
case <-done:
return
default:
}
if ne, ok := err.(net.Error); ok && ne.Timeout() {
continue
}
s.Logger.Add("debug", "socks UDP read failed: %v", err)
return
}
resp, err := s.handleSocksUDPDatagram(buf[:n])
if err != nil {
if shouldLogUDPError(err) {
s.Logger.Add("debug", "%v", err)
}
continue
}
_, _ = udp.WriteToUDP(resp, clientAddr)
}
}
func (s *SocksServer) handleSocksUDPDatagram(packet []byte) ([]byte, error) {
addr, payload, err := parseSocksUDP(packet)
if err != nil {
return nil, err
}
if addr.Port != 53 {
return nil, fmt.Errorf("dropping unsupported UDP target %s; only DNS/UDP port 53 is proxied for SSH modes", addr.Addr())
}
answer, err := s.resolveDNSOverSSHTCP(addr, payload)
if err != nil {
return nil, fmt.Errorf("DNS over SSH failed for %s: %w", addr.Addr(), err)
}
return buildSocksUDP(addr, answer)
}
func (s *SocksServer) resolveDNSOverSSHTCP(addr socksRequest, query []byte) ([]byte, error) {
servers := s.dnsServers(addr)
var lastErr error
for _, server := range servers {
remote, err := dialSSHDirectTCPAddr(s.SSH, server, nil, s.Logger)
if err != nil {
lastErr = err
continue
}
_ = remote.SetDeadline(time.Now().Add(8 * time.Second))
resp, err := exchangeDNSTCP(remote, query)
_ = remote.Close()
if err == nil {
return resp, nil
}
lastErr = err
}
if lastErr != nil {
return nil, lastErr
}
return nil, errors.New("no DNS server configured")
}
func (s *SocksServer) dnsServers(requested socksRequest) []string {
seen := map[string]bool{}
var out []string
add := func(v string) {
v = normalizeHostPort(v, "53")
if v == "" {
return
}
if !seen[v] {
seen[v] = true
out = append(out, v)
}
}
if requested.Host != "" {
add(net.JoinHostPort(requested.Host, strconv.Itoa(requested.Port)))
}
for _, dns := range s.DNS {
add(dns)
}
add("1.1.1.1:53")
add("8.8.8.8:53")
return out
}
func normalizeHostPort(v, defaultPort string) string {
v = strings.TrimSpace(v)
if v == "" {
return ""
}
if host, port, err := net.SplitHostPort(v); err == nil {
return net.JoinHostPort(host, port)
}
if ip := net.ParseIP(v); ip != nil {
return net.JoinHostPort(ip.String(), defaultPort)
}
if strings.Count(v, ":") == 1 {
host, port, ok := strings.Cut(v, ":")
if ok && host != "" && port != "" {
return net.JoinHostPort(host, port)
}
}
return net.JoinHostPort(v, defaultPort)
}
func exchangeDNSTCP(conn net.Conn, query []byte) ([]byte, error) {
if len(query) == 0 || len(query) > 65535 {
return nil, fmt.Errorf("invalid DNS query length %d", len(query))
}
prefix := make([]byte, 2)
binary.BigEndian.PutUint16(prefix, uint16(len(query)))
if _, err := conn.Write(append(prefix, query...)); err != nil {
return nil, err
}
if _, err := io.ReadFull(conn, prefix); err != nil {
return nil, err
}
ln := int(binary.BigEndian.Uint16(prefix))
if ln <= 0 || ln > 65535 {
return nil, fmt.Errorf("invalid DNS response length %d", ln)
}
resp := make([]byte, ln)
if _, err := io.ReadFull(conn, resp); err != nil {
return nil, err
}
return resp, nil
}
func (s *SocksServer) handshake(c net.Conn) error {
header := make([]byte, 2)
if _, err := io.ReadFull(c, header); err != nil {
return err
}
if header[0] != socksVersion {
return errors.New("not socks5")
}
methods := make([]byte, int(header[1]))
if _, err := io.ReadFull(c, methods); err != nil {
return err
}
_, err := c.Write([]byte{socksVersion, 0x00})
return err
}
func readSocksRequest(c net.Conn) (socksRequest, error) {
var req socksRequest
h := make([]byte, 4)
if _, err := io.ReadFull(c, h); err != nil {
return req, err
}
if h[0] != socksVersion {
return req, fmt.Errorf("invalid socks version %d", h[0])
}
req.Command = h[1]
host, err := readSocksHost(c, h[3])
if err != nil {
return req, err
}
pb := make([]byte, 2)
if _, err := io.ReadFull(c, pb); err != nil {
return req, err
}
req.Host = host
req.Port = int(binary.BigEndian.Uint16(pb))
return req, nil
}
func readSocksHost(r io.Reader, atyp byte) (string, error) {
switch atyp {
case socksAtypIPv4:
b := make([]byte, 4)
if _, err := io.ReadFull(r, b); err != nil {
return "", err
}
return net.IP(b).String(), nil
case socksAtypDomain:
l := []byte{0}
if _, err := io.ReadFull(r, l); err != nil {
return "", err
}
b := make([]byte, int(l[0]))
if _, err := io.ReadFull(r, b); err != nil {
return "", err
}
return string(b), nil
case socksAtypIPv6:
b := make([]byte, 16)
if _, err := io.ReadFull(r, b); err != nil {
return "", err
}
return net.IP(b).String(), nil
default:
return "", fmt.Errorf("unsupported address type %d", atyp)
}
}
func parseSocksUDP(packet []byte) (socksRequest, []byte, error) {
var req socksRequest
if len(packet) < 4 {
return req, nil, errors.New("short socks UDP packet")
}
if packet[0] != 0 || packet[1] != 0 {
return req, nil, errors.New("invalid socks UDP reserved field")
}
if packet[2] != 0 {
return req, nil, errors.New("fragmented socks UDP packets are not supported")
}
atyp := packet[3]
off := 4
switch atyp {
case socksAtypIPv4:
if len(packet) < off+4+2 {
return req, nil, errors.New("short socks UDP ipv4 packet")
}
req.Host = net.IP(packet[off : off+4]).String()
off += 4
case socksAtypDomain:
if len(packet) < off+1 {
return req, nil, errors.New("short socks UDP domain packet")
}
ln := int(packet[off])
off++
if len(packet) < off+ln+2 {
return req, nil, errors.New("short socks UDP domain payload")
}
req.Host = string(packet[off : off+ln])
off += ln
case socksAtypIPv6:
if len(packet) < off+16+2 {
return req, nil, errors.New("short socks UDP ipv6 packet")
}
req.Host = net.IP(packet[off : off+16]).String()
off += 16
default:
return req, nil, fmt.Errorf("unsupported socks UDP address type %d", atyp)
}
req.Port = int(binary.BigEndian.Uint16(packet[off : off+2]))
off += 2
return req, packet[off:], nil
}
func buildSocksUDP(addr socksRequest, payload []byte) ([]byte, error) {
var out []byte
out = append(out, 0, 0, 0)
ip := net.ParseIP(addr.Host)
if v4 := ip.To4(); v4 != nil {
out = append(out, socksAtypIPv4)
out = append(out, v4...)
} else if v6 := ip.To16(); v6 != nil {
out = append(out, socksAtypIPv6)
out = append(out, v6...)
} else {
if len(addr.Host) > 255 {
return nil, errors.New("socks UDP domain is too long")
}
out = append(out, socksAtypDomain, byte(len(addr.Host)))
out = append(out, []byte(addr.Host)...)
}
pb := make([]byte, 2)
binary.BigEndian.PutUint16(pb, uint16(addr.Port))
out = append(out, pb...)
out = append(out, payload...)
return out, nil
}
func writeReply(c net.Conn, code byte) error {
return writeReplyWithAddr(c, code, &net.UDPAddr{IP: net.IPv4zero, Port: 0})
}
func writeReplyWithAddr(c net.Conn, code byte, addr *net.UDPAddr) error {
if addr == nil {
addr = &net.UDPAddr{IP: net.IPv4zero, Port: 0}
}
ip := addr.IP
if ip == nil || ip.IsUnspecified() {
ip = net.IPv4(127, 0, 0, 1)
}
var resp []byte
resp = append(resp, socksVersion, code, 0x00)
if v4 := ip.To4(); v4 != nil {
resp = append(resp, socksAtypIPv4)
resp = append(resp, v4...)
} else if v6 := ip.To16(); v6 != nil {
resp = append(resp, socksAtypIPv6)
resp = append(resp, v6...)
} else {
resp = append(resp, socksAtypIPv4, 127, 0, 0, 1)
}
pb := make([]byte, 2)
binary.BigEndian.PutUint16(pb, uint16(addr.Port))
resp = append(resp, pb...)
_, err := c.Write(resp)
return err
}
func proxyCopy(a net.Conn, b net.Conn) {
done := make(chan struct{}, 2)
go func() { _, _ = io.Copy(a, b); done <- struct{}{} }()
go func() { _, _ = io.Copy(b, a); done <- struct{}{} }()
<-done
}
func isExpectedProbeError(err error) bool {
return errors.Is(err, io.EOF) || strings.Contains(err.Error(), "connection reset")
}
func shouldLogUDPError(err error) bool {
msg := err.Error()
return !strings.Contains(msg, "unsupported UDP target")
}

View File

@@ -0,0 +1,151 @@
package engine
import (
"fmt"
"net"
"strconv"
"strings"
"time"
"golang.org/x/crypto/ssh"
)
type directTCPIPPayload struct {
DestAddr string
DestPort uint32
OriginAddr string
OriginPort uint32
}
type sshChannelConn struct {
ssh.Channel
local net.Addr
remote net.Addr
}
func (c *sshChannelConn) LocalAddr() net.Addr {
return c.local
}
func (c *sshChannelConn) RemoteAddr() net.Addr {
return c.remote
}
func (c *sshChannelConn) SetDeadline(time.Time) error {
return nil
}
func (c *sshChannelConn) SetReadDeadline(time.Time) error {
return nil
}
func (c *sshChannelConn) SetWriteDeadline(time.Time) error {
return nil
}
func dialSSHDirectTCP(client *ssh.Client, dest socksRequest, origin net.Addr, logger *Logger) (net.Conn, error) {
addr := dest.Addr()
remote, err := client.Dial("tcp", addr)
if err == nil {
return remote, nil
}
if !shouldRetrySSHIPv6Bracket(dest.Host, err) {
return nil, err
}
if logger != nil {
logger.Add("debug", "retrying SSH direct-tcpip IPv6 target with bracketed host [%s]:%d", strings.Trim(dest.Host, "[]"), dest.Port)
}
remote, retryErr := openSSHDirectTCPIP(client, bracketIPv6Host(dest.Host), dest.Port, origin)
if retryErr == nil {
return remote, nil
}
return nil, fmt.Errorf("%w; bracketed IPv6 retry failed: %v", err, retryErr)
}
func dialSSHDirectTCPAddr(client *ssh.Client, addr string, origin net.Addr, logger *Logger) (net.Conn, error) {
host, portText, err := net.SplitHostPort(addr)
if err != nil {
return client.Dial("tcp", addr)
}
port, err := strconv.Atoi(portText)
if err != nil {
return nil, err
}
return dialSSHDirectTCP(client, socksRequest{Host: host, Port: port}, origin, logger)
}
func openSSHDirectTCPIP(client *ssh.Client, destHost string, destPort int, origin net.Addr) (net.Conn, error) {
originHost, originPort := splitOriginAddr(origin)
payload := ssh.Marshal(directTCPIPPayload{
DestAddr: destHost,
DestPort: uint32(destPort),
OriginAddr: originHost,
OriginPort: uint32(originPort),
})
ch, reqs, err := client.OpenChannel("direct-tcpip", payload)
if err != nil {
return nil, err
}
go ssh.DiscardRequests(reqs)
return &sshChannelConn{
Channel: ch,
local: tcpAddrOrDummy(originHost, originPort),
remote: tcpAddrOrDummy(destHost, destPort),
}, nil
}
func shouldRetrySSHIPv6Bracket(host string, err error) bool {
if err == nil || !isIPv6Literal(host) {
return false
}
msg := strings.ToLower(err.Error())
return strings.Contains(msg, "too many colons in address") || strings.Contains(msg, "missing port in address")
}
func isIPv6Literal(host string) bool {
host = strings.Trim(host, "[]")
ip := net.ParseIP(host)
return ip != nil && ip.To4() == nil && ip.To16() != nil
}
func bracketIPv6Host(host string) string {
host = strings.TrimSpace(host)
if strings.HasPrefix(host, "[") && strings.HasSuffix(host, "]") {
return host
}
return "[" + strings.Trim(host, "[]") + "]"
}
func splitOriginAddr(addr net.Addr) (string, int) {
if addr == nil {
return "127.0.0.1", 0
}
host, portText, err := net.SplitHostPort(addr.String())
if err != nil {
return "127.0.0.1", 0
}
port, err := strconv.Atoi(portText)
if err != nil {
port = 0
}
if host == "" || host == "::" || host == "0.0.0.0" {
host = "127.0.0.1"
}
return host, port
}
func tcpAddrOrDummy(host string, port int) net.Addr {
h := strings.Trim(host, "[]")
ip := net.ParseIP(h)
if ip != nil {
return &net.TCPAddr{IP: ip, Port: port}
}
return dummyAddr(net.JoinHostPort(h, strconv.Itoa(port)))
}
type dummyAddr string
func (d dummyAddr) Network() string { return "tcp" }
func (d dummyAddr) String() string { return string(d) }

View File

@@ -0,0 +1,345 @@
package engine
import (
"context"
"crypto/tls"
"fmt"
"net"
"strconv"
"strings"
"time"
"golang.org/x/crypto/ssh"
"socksrevivepc/internal/config"
)
type sshBundle struct {
Client *ssh.Client
Conn net.Conn
ControlHosts []string
}
type transportAttempt struct {
Label string
ProxyHost string
ProxyPort int
TLSHost string
TLSPort int
ControlHost string
}
func connectSSH(ctx context.Context, p config.Profile, logger *Logger) (*sshBundle, error) {
targetHost := p.SSH.Host
targetPort := p.SSH.Port
if p.Mode == config.ModeDNSTT {
targetHost = p.DNSTT.LocalSSHHost
targetPort = p.DNSTT.LocalSSHPort
}
attempts := buildTransportAttempts(p, targetHost, targetPort)
if len(attempts) == 0 {
attempts = []transportAttempt{{Label: "default"}}
}
logger.Add("info", "connecting SSH %s:%d using mode %s", targetHost, targetPort, p.Mode)
var lastErr error
for i, attempt := range attempts {
if ctx.Err() != nil {
return nil, ctx.Err()
}
if len(attempts) > 1 {
logger.Add("info", "connection attempt %d/%d via %s", i+1, len(attempts), attempt.Label)
}
conn, err := dialTransportAttempt(ctx, p, targetHost, targetPort, attempt, logger)
if err != nil {
lastErr = err
if len(attempts) > 1 {
logger.Add("warn", "connection attempt %d/%d failed before SSH handshake: %v", i+1, len(attempts), err)
}
continue
}
bundle, err := finishSSHHandshake(conn, p, targetHost, targetPort, logger)
if err == nil {
bundle.ControlHosts = controlHostsForAttempt(p, targetHost, attempt)
if len(attempts) > 1 {
logger.Add("info", "connection attempt %d/%d succeeded via %s", i+1, len(attempts), attempt.Label)
}
return bundle, nil
}
_ = conn.Close()
lastErr = err
if len(attempts) > 1 {
logger.Add("warn", "connection attempt %d/%d failed during SSH handshake: %v", i+1, len(attempts), err)
}
}
if lastErr == nil {
lastErr = fmt.Errorf("no transport attempts were available")
}
if len(attempts) > 1 {
return nil, fmt.Errorf("all %d connection attempts failed; last error: %w", len(attempts), lastErr)
}
return nil, lastErr
}
func finishSSHHandshake(conn net.Conn, p config.Profile, targetHost string, targetPort int, logger *Logger) (*sshBundle, error) {
sshCfg := &ssh.ClientConfig{
User: p.SSH.Username,
Auth: []ssh.AuthMethod{ssh.Password(p.SSH.Password)},
HostKeyCallback: ssh.InsecureIgnoreHostKey(),
Timeout: time.Duration(p.SSH.HandshakeTimeoutMs) * time.Millisecond,
ClientVersion: "SSH-2.0-SocksRevivePC",
}
addr := net.JoinHostPort(targetHost, fmt.Sprint(targetPort))
cc, chans, reqs, err := ssh.NewClientConn(conn, addr, sshCfg)
if err != nil {
return nil, fmt.Errorf("ssh handshake failed: %w", err)
}
client := ssh.NewClient(cc, chans, reqs)
logger.Add("info", "ssh authenticated as %s", p.SSH.Username)
return &sshBundle{Client: client, Conn: conn}, nil
}
func buildTransportAttempts(p config.Profile, targetHost string, targetPort int) []transportAttempt {
switch p.Mode {
case config.ModePayload:
return payloadProxyAttempts(p, targetHost, targetPort)
case config.ModeSSL, config.ModePayloadSSL:
return tlsHostAttempts(p, targetHost, targetPort)
default:
return []transportAttempt{{Label: net.JoinHostPort(targetHost, fmt.Sprint(targetPort))}}
}
}
func controlHostsForAttempt(p config.Profile, targetHost string, attempt transportAttempt) []string {
switch p.Mode {
case config.ModePayload:
if attempt.ProxyHost != "" {
return []string{attempt.ProxyHost}
}
return []string{targetHost}
case config.ModeSSL, config.ModePayloadSSL:
if attempt.TLSHost != "" {
return []string{attempt.TLSHost}
}
if p.TLS.Host != "" {
host, _ := hostPortWithDefault(p.TLS.Host, p.TLS.Port)
if host != "" {
return []string{host}
}
}
return []string{targetHost}
default:
return []string{targetHost}
}
}
func payloadProxyAttempts(p config.Profile, targetHost string, targetPort int) []transportAttempt {
proxyText := strings.TrimSpace(p.Proxy.Host)
if proxyText == "" || p.Proxy.Port <= 0 {
return []transportAttempt{{Label: "direct payload transport"}}
}
rawHosts := splitHostList(proxyText)
if len(rawHosts) == 0 {
return []transportAttempt{{Label: "direct payload transport"}}
}
out := make([]transportAttempt, 0, len(rawHosts))
for _, raw := range rawHosts {
host, port := hostPortWithDefault(raw, p.Proxy.Port)
if host == "" || port <= 0 {
continue
}
out = append(out, transportAttempt{ProxyHost: host, ProxyPort: port, Label: "proxy " + net.JoinHostPort(host, fmt.Sprint(port))})
}
if len(out) == 0 {
return []transportAttempt{{Label: net.JoinHostPort(targetHost, fmt.Sprint(targetPort))}}
}
return out
}
func tlsHostAttempts(p config.Profile, targetHost string, targetPort int) []transportAttempt {
hostText := strings.TrimSpace(p.TLS.Host)
defaultPort := p.TLS.Port
if defaultPort <= 0 {
defaultPort = targetPort
}
if hostText == "" {
return []transportAttempt{{TLSHost: targetHost, TLSPort: defaultPort, Label: "TLS " + net.JoinHostPort(targetHost, fmt.Sprint(defaultPort))}}
}
rawHosts := splitHostList(hostText)
if len(rawHosts) == 0 {
return []transportAttempt{{TLSHost: targetHost, TLSPort: defaultPort, Label: "TLS " + net.JoinHostPort(targetHost, fmt.Sprint(defaultPort))}}
}
out := make([]transportAttempt, 0, len(rawHosts))
for _, raw := range rawHosts {
host, port := hostPortWithDefault(raw, defaultPort)
if host == "" || port <= 0 {
continue
}
out = append(out, transportAttempt{TLSHost: host, TLSPort: port, Label: "TLS " + net.JoinHostPort(host, fmt.Sprint(port))})
}
if len(out) == 0 {
return []transportAttempt{{TLSHost: targetHost, TLSPort: defaultPort, Label: "TLS " + net.JoinHostPort(targetHost, fmt.Sprint(defaultPort))}}
}
return out
}
func dialTransport(ctx context.Context, p config.Profile, targetHost string, targetPort int, logger *Logger) (net.Conn, error) {
attempts := buildTransportAttempts(p, targetHost, targetPort)
if len(attempts) == 0 {
attempts = []transportAttempt{{Label: "default"}}
}
return dialTransportAttempt(ctx, p, targetHost, targetPort, attempts[0], logger)
}
func dialTransportAttempt(ctx context.Context, p config.Profile, targetHost string, targetPort int, attempt transportAttempt, logger *Logger) (net.Conn, error) {
d := &net.Dialer{Timeout: time.Duration(p.SSH.HandshakeTimeoutMs) * time.Millisecond}
addr := net.JoinHostPort(targetHost, fmt.Sprint(targetPort))
switch p.Mode {
case config.ModeDirect, config.ModeDNSTT:
return d.DialContext(ctx, "tcp", addr)
case config.ModeSSL:
return dialTLSAttempt(ctx, d, p, targetHost, targetPort, attempt)
case config.ModePayload:
connectHost, connectPort := targetHost, targetPort
if attempt.ProxyHost != "" && attempt.ProxyPort > 0 {
connectHost, connectPort = attempt.ProxyHost, attempt.ProxyPort
} else if p.Proxy.Host != "" && p.Proxy.Port > 0 {
connectHost, connectPort = p.Proxy.Host, p.Proxy.Port
}
logger.Add("debug", "payload transport dialing %s", net.JoinHostPort(connectHost, fmt.Sprint(connectPort)))
conn, err := d.DialContext(ctx, "tcp", net.JoinHostPort(connectHost, fmt.Sprint(connectPort)))
if err != nil {
return nil, err
}
if _, wrappedConn, err := WritePayload(conn, p, targetHost, targetPort, logger); err != nil {
_ = conn.Close()
return nil, err
} else {
conn = wrappedConn
}
return conn, nil
case config.ModePayloadSSL:
conn, err := dialTLSAttempt(ctx, d, p, targetHost, targetPort, attempt)
if err != nil {
return nil, err
}
if _, wrappedConn, err := WritePayload(conn, p, targetHost, targetPort, logger); err != nil {
_ = conn.Close()
return nil, err
} else {
conn = wrappedConn
}
return conn, nil
default:
return nil, fmt.Errorf("unsupported mode %s", p.Mode)
}
}
func dialTLS(ctx context.Context, d *net.Dialer, p config.Profile, targetHost string, targetPort int) (net.Conn, error) {
return dialTLSAttempt(ctx, d, p, targetHost, targetPort, transportAttempt{})
}
func dialTLSAttempt(ctx context.Context, d *net.Dialer, p config.Profile, targetHost string, targetPort int, attempt transportAttempt) (net.Conn, error) {
host := strings.TrimSpace(attempt.TLSHost)
port := attempt.TLSPort
if host == "" {
host = p.TLS.Host
}
if port <= 0 {
port = p.TLS.Port
}
if host == "" {
host = targetHost
}
if port == 0 {
port = targetPort
}
serverName := p.TLS.ServerName
if serverName == "" {
serverName = host
}
raw, err := d.DialContext(ctx, "tcp", net.JoinHostPort(host, fmt.Sprint(port)))
if err != nil {
return nil, err
}
cfg := &tls.Config{ServerName: serverName, InsecureSkipVerify: p.TLS.InsecureSkipVerify, MinVersion: tls.VersionTLS12}
conn := tls.Client(raw, cfg)
if err := conn.HandshakeContext(ctx); err != nil {
_ = conn.Close()
return nil, err
}
return conn, nil
}
func splitHostList(s string) []string {
s = strings.TrimSpace(s)
if s == "" {
return nil
}
parts := strings.FieldsFunc(s, func(r rune) bool {
switch r {
case '#', ',', ';', '\n', '\r', '\t':
return true
default:
return false
}
})
out := make([]string, 0, len(parts))
seen := map[string]struct{}{}
for _, part := range parts {
part = strings.TrimSpace(part)
if part == "" {
continue
}
if _, ok := seen[part]; ok {
continue
}
seen[part] = struct{}{}
out = append(out, part)
}
return out
}
func hostPortWithDefault(raw string, defaultPort int) (string, int) {
raw = strings.TrimSpace(raw)
if raw == "" {
return "", defaultPort
}
if strings.Contains(raw, "://") {
// These fields are host fields, not URLs. Strip the scheme if a user pastes one.
if i := strings.Index(raw, "://"); i >= 0 {
raw = raw[i+3:]
}
}
if h, p, err := net.SplitHostPort(raw); err == nil {
pi, _ := strconv.Atoi(p)
if pi > 0 {
return strings.Trim(h, "[]"), pi
}
}
if strings.HasPrefix(raw, "[") && strings.Contains(raw, "]:") {
if h, p, err := net.SplitHostPort(raw); err == nil {
pi, _ := strconv.Atoi(p)
if pi > 0 {
return strings.Trim(h, "[]"), pi
}
}
}
// host:port for IPv4/domain. IPv6 without brackets has multiple colons and
// must keep the profile/default port.
if strings.Count(raw, ":") == 1 {
host, portText, ok := strings.Cut(raw, ":")
if ok {
if pi, err := strconv.Atoi(strings.TrimSpace(portText)); err == nil && pi > 0 {
return strings.TrimSpace(host), pi
}
}
}
return strings.Trim(raw, "[]"), defaultPort
}

View File

@@ -0,0 +1,353 @@
package engine
import (
"bufio"
"encoding/binary"
"errors"
"fmt"
"io"
"net"
"strconv"
"strings"
"sync"
"time"
)
const (
udpgwMaxFrame = 64 * 1024
udpgwProtocolBadVPN = "badvpn"
udpgwProtocolLegacy = "legacy"
udpgwFlagKeepAlive = 1 << 0
udpgwFlagRebind = 1 << 1
udpgwFlagDNS = 1 << 2
udpgwFlagIPv6 = 1 << 3
)
type udpgwSession struct {
mu sync.Mutex
nextID uint16
pairID map[string]uint16
idClient map[uint16]*net.UDPAddr
}
func (s *SocksServer) handleUDPGWAssociate(c net.Conn) {
udp, err := net.ListenUDP("udp", &net.UDPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 0})
if err != nil {
_ = writeReply(c, socksRepFail)
s.Logger.Add("warn", "socks UDP associate failed: %v", err)
return
}
defer udp.Close()
gwAddr := net.JoinHostPort(s.UDPGW.Host, strconv.Itoa(s.UDPGW.Port))
gwConn, err := dialSSHDirectTCPAddr(s.SSH, gwAddr, c.RemoteAddr(), s.Logger)
if err != nil {
_ = writeReply(c, socksRepFail)
s.Logger.Add("warn", "udpgw dial through SSH failed %s: %v", gwAddr, err)
return
}
defer gwConn.Close()
if err := writeReplyWithAddr(c, socksRepOK, udp.LocalAddr().(*net.UDPAddr)); err != nil {
s.Logger.Add("debug", "socks UDP associate reply failed: %v", err)
return
}
_ = c.SetDeadline(time.Time{})
proto := normalizeUDPGWProtocol(s.UDPGW.Protocol)
s.Logger.Add("info", "SOCKS5 UDP associate using UDPGW %s over SSH; protocol=%s local UDP=%s", gwAddr, proto, udp.LocalAddr().String())
done := make(chan struct{})
closeOnce := sync.Once{}
stop := func() {
closeOnce.Do(func() {
_ = c.Close()
_ = gwConn.Close()
_ = udp.Close()
close(done)
})
}
go func() {
_, _ = io.Copy(io.Discard, c)
stop()
}()
sess := &udpgwSession{
nextID: 1,
pairID: make(map[string]uint16),
idClient: make(map[uint16]*net.UDPAddr),
}
go s.readUDPGWReplies(done, proto, gwConn, udp, sess)
s.forwardUDPToUDPGW(done, proto, gwConn, udp, sess)
stop()
}
func normalizeUDPGWProtocol(v string) string {
switch strings.ToLower(strings.TrimSpace(v)) {
case udpgwProtocolLegacy:
return udpgwProtocolLegacy
default:
return udpgwProtocolBadVPN
}
}
func (s *SocksServer) forwardUDPToUDPGW(done <-chan struct{}, proto string, gwConn net.Conn, udp *net.UDPConn, sess *udpgwSession) {
buf := make([]byte, 64*1024)
for {
_ = udp.SetReadDeadline(time.Now().Add(500 * time.Millisecond))
n, clientAddr, err := udp.ReadFromUDP(buf)
if err != nil {
select {
case <-done:
return
default:
}
if ne, ok := err.(net.Error); ok && ne.Timeout() {
continue
}
s.Logger.Add("debug", "udpgw local UDP read failed: %v", err)
return
}
addr, payload, err := parseSocksUDP(buf[:n])
if err != nil {
if shouldLogUDPError(err) {
s.Logger.Add("debug", "udpgw parse SOCKS UDP failed: %v", err)
}
continue
}
connID, isNew, err := sess.idFor(clientAddr, addr)
if err != nil {
s.Logger.Add("debug", "udpgw session id failed for %s: %v", addr.Addr(), err)
continue
}
frame, err := udpgwBuildClientFrame(proto, connID, isNew, addr, payload)
if err != nil {
if shouldLogUDPError(err) {
s.Logger.Add("debug", "udpgw frame build failed for %s: %v", addr.Addr(), err)
}
continue
}
_ = gwConn.SetWriteDeadline(time.Now().Add(15 * time.Second))
if _, err := gwConn.Write(frame); err != nil {
s.Logger.Add("warn", "udpgw TCP write failed: %v", err)
return
}
}
}
func (s *SocksServer) readUDPGWReplies(done <-chan struct{}, proto string, gwConn net.Conn, udp *net.UDPConn, sess *udpgwSession) {
br := bufio.NewReaderSize(gwConn, 256*1024)
for {
select {
case <-done:
return
default:
}
payload, err := udpgwReadFrame(br, udpgwMaxFrame)
if err != nil {
select {
case <-done:
return
default:
}
if !errors.Is(err, io.EOF) {
s.Logger.Add("debug", "udpgw TCP read failed: %v", err)
}
return
}
connID, src, data, err := udpgwParseReplyPayload(proto, payload)
if err != nil {
s.Logger.Add("debug", "udpgw bad reply frame: %v", err)
continue
}
clientAddr := sess.clientFor(connID)
if clientAddr == nil {
continue
}
packet, err := buildSocksUDP(src, data)
if err != nil {
s.Logger.Add("debug", "udpgw SOCKS UDP reply build failed: %v", err)
continue
}
_, _ = udp.WriteToUDP(packet, clientAddr)
}
}
func (s *udpgwSession) idFor(client *net.UDPAddr, dest socksRequest) (uint16, bool, error) {
if client == nil {
return 0, false, errors.New("missing local UDP client address")
}
key := client.String() + "|" + dest.Addr()
s.mu.Lock()
defer s.mu.Unlock()
if id, ok := s.pairID[key]; ok {
s.idClient[id] = cloneUDPAddr(client)
return id, false, nil
}
id := s.nextID
if id == 0 {
id = 1
}
s.nextID = id + 1
if s.nextID == 0 {
s.nextID = 1
}
s.pairID[key] = id
s.idClient[id] = cloneUDPAddr(client)
return id, true, nil
}
func (s *udpgwSession) clientFor(id uint16) *net.UDPAddr {
s.mu.Lock()
defer s.mu.Unlock()
return cloneUDPAddr(s.idClient[id])
}
func cloneUDPAddr(a *net.UDPAddr) *net.UDPAddr {
if a == nil {
return nil
}
out := *a
if a.IP != nil {
out.IP = append(net.IP(nil), a.IP...)
}
return &out
}
func udpgwReadFrame(r *bufio.Reader, max int) ([]byte, error) {
var lenBuf [2]byte
if _, err := io.ReadFull(r, lenBuf[:]); err != nil {
return nil, err
}
n := int(binary.LittleEndian.Uint16(lenBuf[:]))
if n <= 0 || n > max {
return nil, fmt.Errorf("invalid udpgw frame length %d", n)
}
b := make([]byte, n)
if _, err := io.ReadFull(r, b); err != nil {
return nil, err
}
return b, nil
}
func udpgwBuildClientFrame(proto string, connID uint16, isNew bool, dest socksRequest, data []byte) ([]byte, error) {
if normalizeUDPGWProtocol(proto) == udpgwProtocolLegacy {
return udpgwBuildLegacyFrame(connID, 0, dest, data)
}
return udpgwBuildBadVPNFrame(connID, isNew, dest, data)
}
func udpgwParseReplyPayload(proto string, payload []byte) (uint16, socksRequest, []byte, error) {
if normalizeUDPGWProtocol(proto) == udpgwProtocolLegacy {
return udpgwParseLegacyReplyPayload(payload)
}
return udpgwParseBadVPNReplyPayload(payload)
}
// udpgwBuildBadVPNFrame implements the normal badvpn-udpgw PacketProto frame:
// length(little-endian uint16) + flags(1) + conid(little-endian uint16) +
// IPv4/IPv6 destination + destination port(network byte order) + UDP payload.
// This is the same framing used by the Android badvpn UDPGW client and supports
// IPv6 targets with UDPGW_CLIENT_FLAG_IPV6.
func udpgwBuildBadVPNFrame(connID uint16, isNew bool, dest socksRequest, data []byte) ([]byte, error) {
ip := net.ParseIP(dest.Host)
if ip == nil {
return nil, fmt.Errorf("badvpn UDPGW requires an IP target, got %q", dest.Host)
}
flags := byte(0)
if isNew {
flags |= udpgwFlagRebind
}
var addr []byte
if v4 := ip.To4(); v4 != nil {
addr = append(addr, v4...)
} else if v6 := ip.To16(); v6 != nil {
flags |= udpgwFlagIPv6
addr = append(addr, v6...)
} else {
return nil, fmt.Errorf("invalid UDPGW target IP %q", dest.Host)
}
payloadLen := 1 + 2 + len(addr) + 2 + len(data)
if payloadLen <= 0 || payloadLen > 65535 {
return nil, fmt.Errorf("UDPGW payload too large: %d", payloadLen)
}
out := make([]byte, 2+payloadLen)
binary.LittleEndian.PutUint16(out[0:2], uint16(payloadLen))
out[2] = flags
binary.LittleEndian.PutUint16(out[3:5], connID)
copy(out[5:5+len(addr)], addr)
portOff := 5 + len(addr)
binary.BigEndian.PutUint16(out[portOff:portOff+2], uint16(dest.Port))
copy(out[portOff+2:], data)
return out, nil
}
func udpgwParseBadVPNReplyPayload(payload []byte) (uint16, socksRequest, []byte, error) {
var src socksRequest
if len(payload) < 1+2+4+2 {
return 0, src, nil, fmt.Errorf("short badvpn udpgw payload %d", len(payload))
}
flags := payload[0]
if flags&udpgwFlagKeepAlive != 0 {
return 0, src, nil, errors.New("unexpected udpgw keepalive reply")
}
connID := binary.LittleEndian.Uint16(payload[1:3])
off := 3
if flags&udpgwFlagIPv6 != 0 {
if len(payload) < off+16+2 {
return 0, src, nil, fmt.Errorf("short badvpn udpgw ipv6 payload %d", len(payload))
}
src.Host = net.IP(payload[off : off+16]).String()
off += 16
} else {
if len(payload) < off+4+2 {
return 0, src, nil, fmt.Errorf("short badvpn udpgw ipv4 payload %d", len(payload))
}
src.Host = net.IP(payload[off : off+4]).String()
off += 4
}
src.Port = int(binary.BigEndian.Uint16(payload[off : off+2]))
off += 2
return connID, src, payload[off:], nil
}
// Legacy frame kept only for older experimental PC builds. It is IPv4-only.
func udpgwBuildLegacyFrame(connID uint16, x byte, dest socksRequest, data []byte) ([]byte, error) {
ip := net.ParseIP(dest.Host)
v4 := ip.To4()
if v4 == nil {
return nil, fmt.Errorf("legacy UDPGW only supports IPv4 UDP targets; got %s", dest.Addr())
}
payloadLen := 2 + 1 + 4 + 2 + len(data)
if payloadLen <= 0 || payloadLen > 65535 {
return nil, fmt.Errorf("UDPGW payload too large: %d", payloadLen)
}
out := make([]byte, 2+payloadLen)
binary.LittleEndian.PutUint16(out[0:2], uint16(payloadLen))
binary.BigEndian.PutUint16(out[2:4], connID)
out[4] = x
copy(out[5:9], v4)
binary.BigEndian.PutUint16(out[9:11], uint16(dest.Port))
copy(out[11:], data)
return out, nil
}
func udpgwParseLegacyReplyPayload(payload []byte) (uint16, socksRequest, []byte, error) {
var src socksRequest
if len(payload) < 2+1+4+2 {
return 0, src, nil, fmt.Errorf("short legacy udpgw payload %d", len(payload))
}
connID := binary.BigEndian.Uint16(payload[0:2])
src.Host = net.IP(payload[3:7]).String()
src.Port = int(binary.BigEndian.Uint16(payload[7:9]))
return connID, src, payload[9:], nil
}

1043
internal/nativeui/ui.go Normal file

File diff suppressed because it is too large Load Diff

23
internal/oscmd/command.go Normal file
View File

@@ -0,0 +1,23 @@
package oscmd
import (
"context"
"os/exec"
)
// Command creates an exec.Cmd with platform-specific defaults that are safe for
// the GUI application. On Windows, the implementation hides child console
// windows so helpers like powershell.exe, netsh.exe, route.exe, xray.exe, and
// dnstt-client.exe do not flash a terminal window.
func Command(name string, args ...string) *exec.Cmd {
cmd := exec.Command(name, args...)
applyPlatformOptions(cmd)
return cmd
}
// CommandContext is the context-aware version of Command.
func CommandContext(ctx context.Context, name string, args ...string) *exec.Cmd {
cmd := exec.CommandContext(ctx, name, args...)
applyPlatformOptions(cmd)
return cmd
}

View File

@@ -0,0 +1,7 @@
//go:build !windows
package oscmd
import "os/exec"
func applyPlatformOptions(cmd *exec.Cmd) {}

View File

@@ -0,0 +1,17 @@
//go:build windows
package oscmd
import (
"os/exec"
"syscall"
)
const createNoWindow = 0x08000000
func applyPlatformOptions(cmd *exec.Cmd) {
cmd.SysProcAttr = &syscall.SysProcAttr{
HideWindow: true,
CreationFlags: createNoWindow,
}
}

View File

@@ -0,0 +1,15 @@
//go:build !windows
package platformtun
import "strings"
type Logger interface {
Add(level, format string, args ...any)
}
func Prepare(device, interfaceName string, mtu int, logger Logger) (string, string, func(), error) {
device = strings.TrimSpace(device)
interfaceName = strings.TrimSpace(interfaceName)
return device, interfaceName, nil, nil
}

View File

@@ -0,0 +1,57 @@
//go:build windows
package platformtun
import (
"strings"
"socksrevivepc/internal/wintunloader"
)
type Logger interface {
Add(level, format string, args ...any)
}
// Prepare validates and loads Wintun for the current process, then normalizes
// the device string used by tun2socks. It intentionally does not pre-open a
// WireGuard TUN adapter here: tun2socks must own the live adapter/session.
// Pre-opening and closing the adapter before tun2socks can make some Windows
// builds close/crash with no useful GUI error.
func Prepare(device, interfaceName string, mtu int, logger Logger) (string, string, func(), error) {
if mtu <= 0 {
mtu = 1500
}
if err := wintunloader.Prepare(logger); err != nil {
return "", "", nil, err
}
device = normalizeWindowsDevice(device)
interfaceName = normalizeWindowsInterface(interfaceName)
if logger != nil {
logger.Add("info", "Windows Wintun is ready: adapter=%s device=%s mtu=%d", interfaceName, device, mtu)
}
return device, interfaceName, nil, nil
}
func normalizeWindowsDevice(device string) string {
device = strings.TrimSpace(device)
// tun2socks v2 uses the Windows device model name "wintun". Values
// like "wintun://SocksRevive" are Linux-style URL device strings and
// can make the engine search for a non-existent network interface.
if device == "" || strings.EqualFold(device, "tun") || strings.EqualFold(device, "wintun") {
return "wintun"
}
if strings.HasPrefix(strings.ToLower(device), "wintun://") || strings.EqualFold(device, "tun://wintun") {
return "wintun"
}
return "wintun"
}
func normalizeWindowsInterface(interfaceName string) string {
interfaceName = strings.TrimSpace(interfaceName)
if interfaceName == "" || strings.EqualFold(interfaceName, "SocksRevive") {
return "wintun"
}
return interfaceName
}

574
internal/routes/routes.go Normal file
View File

@@ -0,0 +1,574 @@
package routes
import (
"encoding/json"
"fmt"
"net"
"runtime"
"strconv"
"strings"
"time"
"socksrevivepc/internal/config"
"socksrevivepc/internal/oscmd"
)
type Logger interface {
Add(level, format string, args ...any)
}
type Cleanup struct {
commands [][]string
logger Logger
}
func Apply(p config.Profile, proxyHosts []string, logger Logger) (*Cleanup, error) {
if !p.Tun.Enabled || !p.Tun.RouteAll {
return &Cleanup{logger: logger}, nil
}
logger.Add("info", "applying TUN routes; admin/root permission may be required")
gw, iface, err := defaultGateway()
if err != nil {
logger.Add("warn", "cannot detect default gateway: %v", err)
}
cleanup := &Cleanup{logger: logger}
switch runtime.GOOS {
case "windows":
applyWindowsRoutes(p, proxyHosts, gw, cleanup, logger)
case "linux":
dev := p.Tun.InterfaceName
if dev == "" {
dev = "socksrevive0"
}
_ = run(logger, "ip", "addr", "add", p.Tun.CIDR, "dev", dev)
_ = run(logger, "ip", "link", "set", dev, "up")
addBypassLinux(proxyHosts, gw, iface, cleanup, logger)
if err := run(logger, "ip", "route", "replace", "0.0.0.0/1", "via", p.Tun.Gateway, "dev", dev); err == nil {
cleanup.commands = append(cleanup.commands, []string{"ip", "route", "del", "0.0.0.0/1"})
}
if err := run(logger, "ip", "route", "replace", "128.0.0.0/1", "via", p.Tun.Gateway, "dev", dev); err == nil {
cleanup.commands = append(cleanup.commands, []string{"ip", "route", "del", "128.0.0.0/1"})
}
applyLinuxIPv6Routes(p, proxyHosts, dev, cleanup, logger)
case "darwin":
dev := p.Tun.InterfaceName
_ = run(logger, "ifconfig", dev, p.Tun.Gateway, p.Tun.Gateway, "up")
addBypassDarwin(proxyHosts, gw, cleanup, logger)
if err := run(logger, "route", "add", "0.0.0.0/1", p.Tun.Gateway); err == nil {
cleanup.commands = append(cleanup.commands, []string{"route", "delete", "0.0.0.0/1"})
}
if err := run(logger, "route", "add", "128.0.0.0/1", p.Tun.Gateway); err == nil {
cleanup.commands = append(cleanup.commands, []string{"route", "delete", "128.0.0.0/1"})
}
applyDarwinIPv6Routes(p, proxyHosts, dev, cleanup, logger)
default:
logger.Add("warn", "route automation not implemented for %s", runtime.GOOS)
}
return cleanup, nil
}
func (c *Cleanup) Run() {
if c == nil {
return
}
for i := len(c.commands) - 1; i >= 0; i-- {
cmd := c.commands[i]
if len(cmd) == 0 {
continue
}
_ = run(c.logger, cmd[0], cmd[1:]...)
}
}
func applyWindowsRoutes(p config.Profile, proxyHosts []string, gw string, cleanup *Cleanup, logger Logger) {
alias, ifIndex, err := waitWindowsTunInterface(p.Tun.InterfaceName, 6*time.Second, logger)
if err != nil {
logger.Add("warn", "cannot find Windows Wintun adapter for routes: %v", err)
return
}
ip, mask := ipv4CIDRToAddressAndMask(p.Tun.CIDR, p.Tun.Gateway)
logger.Add("info", "Windows TUN route adapter: alias=%s ifIndex=%d ip=%s mask=%s", alias, ifIndex, ip, mask)
_ = run(logger, "netsh", "interface", "ipv4", "set", "interface", fmt.Sprint(ifIndex), "metric=1")
if err := run(logger, "netsh", "interface", "ipv4", "set", "address", "name="+alias, "static", ip, mask); err != nil {
logger.Add("warn", "failed to set Wintun IPv4 address; routes may not work until Windows finishes creating the adapter")
}
if len(p.Tun.DNS) > 0 {
_ = run(logger, "netsh", "interface", "ipv4", "set", "dnsservers", "name="+alias, "static", p.Tun.DNS[0], "validate=no")
for i, dns := range p.Tun.DNS[1:] {
dns = strings.TrimSpace(dns)
if dns == "" {
continue
}
_ = run(logger, "netsh", "interface", "ipv4", "add", "dnsservers", "name="+alias, dns, fmt.Sprint(i+2), "validate=no")
}
}
// Bypass routes must stay on the original physical gateway so the SSH/Xray/DNSTT
// control connection never loops back into the TUN default route.
addBypassWindows(proxyHosts, gw, cleanup, logger)
addWindowsIPv4SplitRoute(alias, ifIndex, "0.0.0.0/1", "0.0.0.0", "128.0.0.0", cleanup, logger)
addWindowsIPv4SplitRoute(alias, ifIndex, "128.0.0.0/1", "128.0.0.0", "128.0.0.0", cleanup, logger)
applyWindowsIPv6Routes(p, proxyHosts, alias, ifIndex, cleanup, logger)
}
func applyWindowsIPv6Routes(p config.Profile, proxyHosts []string, alias string, ifIndex int, cleanup *Cleanup, logger Logger) {
if !p.Tun.IPv6Enabled && p.Tun.AllowIPv6Leak {
logger.Add("info", "IPv6 TUN is disabled and IPv6 leak blocking is off; leaving normal IPv6 routes untouched")
return
}
_ = run(logger, "netsh", "interface", "ipv6", "set", "interface", fmt.Sprint(ifIndex), "metric=1")
ipv6Addr, prefixLen := ipv6CIDRToAddressAndPrefix(p.Tun.IPv6CIDR)
if p.Tun.IPv6Enabled {
logger.Add("info", "IPv6 TUN routing enabled: address=%s/%s", ipv6Addr, prefixLen)
_ = run(logger, "netsh", "interface", "ipv6", "add", "address", "interface="+alias, "address="+ipv6Addr, "store=active")
setWindowsIPv6DNS(alias, p.Tun.IPv6DNS, logger)
if gw, idxStr, err := defaultGatewayIPv6(); err == nil {
if idx, convErr := strconv.Atoi(idxStr); convErr == nil && idx > 0 {
addBypassWindowsIPv6(proxyHosts, gw, idx, cleanup, logger)
}
}
} else {
logger.Add("info", "IPv6 TUN is disabled; adding IPv6 split routes as leak protection")
}
addWindowsIPv6SplitRoute(alias, "::/1", cleanup, logger)
addWindowsIPv6SplitRoute(alias, "8000::/1", cleanup, logger)
}
func setWindowsIPv6DNS(alias string, dns []string, logger Logger) {
if len(dns) == 0 {
return
}
first := strings.TrimSpace(dns[0])
if first == "" {
return
}
_ = run(logger, "netsh", "interface", "ipv6", "set", "dnsservers", "name="+alias, "static", first, "validate=no")
for i, server := range dns[1:] {
server = strings.TrimSpace(server)
if server == "" {
continue
}
_ = run(logger, "netsh", "interface", "ipv6", "add", "dnsservers", "name="+alias, server, fmt.Sprint(i+2), "validate=no")
}
}
func applyLinuxIPv6Routes(p config.Profile, proxyHosts []string, dev string, cleanup *Cleanup, logger Logger) {
if !p.Tun.IPv6Enabled && p.Tun.AllowIPv6Leak {
logger.Add("info", "IPv6 TUN is disabled and IPv6 leak blocking is off; leaving normal IPv6 routes untouched")
return
}
if p.Tun.IPv6Enabled {
logger.Add("info", "IPv6 TUN routing enabled: cidr=%s", p.Tun.IPv6CIDR)
_ = run(logger, "ip", "-6", "addr", "add", p.Tun.IPv6CIDR, "dev", dev)
if gw, iface, err := defaultGatewayIPv6(); err == nil {
addBypassLinuxIPv6(proxyHosts, gw, iface, cleanup, logger)
}
} else {
logger.Add("info", "IPv6 TUN is disabled; adding IPv6 split routes as leak protection")
}
if err := run(logger, "ip", "-6", "route", "replace", "::/1", "dev", dev, "metric", "1"); err == nil {
cleanup.commands = append(cleanup.commands, []string{"ip", "-6", "route", "del", "::/1"})
}
if err := run(logger, "ip", "-6", "route", "replace", "8000::/1", "dev", dev, "metric", "1"); err == nil {
cleanup.commands = append(cleanup.commands, []string{"ip", "-6", "route", "del", "8000::/1"})
}
}
func applyDarwinIPv6Routes(p config.Profile, proxyHosts []string, dev string, cleanup *Cleanup, logger Logger) {
if !p.Tun.IPv6Enabled && p.Tun.AllowIPv6Leak {
logger.Add("info", "IPv6 TUN is disabled and IPv6 leak blocking is off; leaving normal IPv6 routes untouched")
return
}
ipv6Addr, prefixLen := ipv6CIDRToAddressAndPrefix(p.Tun.IPv6CIDR)
if p.Tun.IPv6Enabled {
logger.Add("info", "IPv6 TUN routing enabled: address=%s/%s", ipv6Addr, prefixLen)
_ = run(logger, "ifconfig", dev, "inet6", ipv6Addr, "prefixlen", prefixLen, "up")
if gw, _, err := defaultGatewayIPv6(); err == nil {
addBypassDarwinIPv6(proxyHosts, gw, cleanup, logger)
}
} else {
logger.Add("info", "IPv6 TUN is disabled; adding IPv6 split routes as leak protection")
}
if err := run(logger, "route", "add", "-inet6", "::/1", "-interface", dev); err == nil {
cleanup.commands = append(cleanup.commands, []string{"route", "delete", "-inet6", "::/1"})
}
if err := run(logger, "route", "add", "-inet6", "8000::/1", "-interface", dev); err == nil {
cleanup.commands = append(cleanup.commands, []string{"route", "delete", "-inet6", "8000::/1"})
}
}
func addWindowsIPv4SplitRoute(alias string, ifIndex int, prefix, legacyDest, legacyMask string, cleanup *Cleanup, logger Logger) {
args := []string{"interface", "ipv4", "add", "route", "prefix=" + prefix, "interface=" + alias, "nexthop=0.0.0.0", "metric=1", "store=active"}
if err := run(logger, "netsh", args...); err == nil {
cleanup.commands = append(cleanup.commands, []string{"netsh", "interface", "ipv4", "delete", "route", "prefix=" + prefix, "interface=" + alias, "nexthop=0.0.0.0", "store=active"})
return
}
// Fallback for older Windows/netsh variants.
if err := run(logger, "route", "add", legacyDest, "mask", legacyMask, "0.0.0.0", "metric", "1", "if", fmt.Sprint(ifIndex)); err == nil {
cleanup.commands = append(cleanup.commands, []string{"route", "delete", legacyDest, "mask", legacyMask})
}
}
func addWindowsIPv6SplitRoute(alias, prefix string, cleanup *Cleanup, logger Logger) {
args := []string{"interface", "ipv6", "add", "route", "prefix=" + prefix, "interface=" + alias, "nexthop=::", "metric=1", "store=active"}
if err := run(logger, "netsh", args...); err == nil {
cleanup.commands = append(cleanup.commands, []string{"netsh", "interface", "ipv6", "delete", "route", "prefix=" + prefix, "interface=" + alias, "nexthop=::", "store=active"})
}
}
func ipv4CIDRToAddressAndMask(cidr, fallbackIP string) (string, string) {
ip, ipnet, err := net.ParseCIDR(strings.TrimSpace(cidr))
if err == nil && ip.To4() != nil {
return ip.String(), net.IP(ipnet.Mask).String()
}
if parsed := net.ParseIP(strings.TrimSpace(fallbackIP)); parsed != nil && parsed.To4() != nil {
return parsed.String(), "255.254.0.0"
}
return "198.18.0.1", "255.254.0.0"
}
func ipv6CIDRToAddressAndPrefix(cidr string) (string, string) {
ip, ipnet, err := net.ParseCIDR(strings.TrimSpace(cidr))
if err == nil && ip.To4() == nil && ip.To16() != nil {
ones, _ := ipnet.Mask.Size()
return ip.String(), fmt.Sprint(ones)
}
return "fd00:534f:434b::1", "64"
}
type windowsAdapter struct {
Name string `json:"Name"`
InterfaceAlias string `json:"InterfaceAlias"`
InterfaceIndex int `json:"InterfaceIndex"`
IfIndex int `json:"ifIndex"`
Status string `json:"Status"`
InterfaceDesc string `json:"InterfaceDescription"`
}
func waitWindowsTunInterface(preferred string, timeout time.Duration, logger Logger) (string, int, error) {
deadline := time.Now().Add(timeout)
var lastErr error
for time.Now().Before(deadline) {
alias, idx, err := windowsTunInterface(preferred)
if err == nil && alias != "" && idx > 0 {
return alias, idx, nil
}
lastErr = err
time.Sleep(300 * time.Millisecond)
}
if lastErr != nil {
return "", 0, lastErr
}
return "", 0, fmt.Errorf("not found")
}
func windowsTunInterface(preferred string) (string, int, error) {
preferred = strings.TrimSpace(preferred)
if preferred == "" {
preferred = "wintun"
}
ps := `$preferred = ` + strconv.Quote(preferred) + `
$adapters = @(Get-NetAdapter -ErrorAction SilentlyContinue | Where-Object {
$_.Status -ne 'Disabled' -and (
$_.Name -ieq $preferred -or
$_.InterfaceDescription -like '*Wintun*' -or
$_.InterfaceDescription -like '*WireGuard*' -or
$_.Name -like '*wintun*'
)
} | Sort-Object @{Expression={if ($_.Name -ieq $preferred) {0} else {1}}}, InterfaceIndex | Select-Object -First 1 Name,InterfaceDescription,InterfaceIndex,ifIndex,Status)
if ($adapters.Count -eq 0) { exit 2 }
$adapters[0] | ConvertTo-Json -Compress`
out, err := oscmd.Command("powershell", "-NoProfile", "-ExecutionPolicy", "Bypass", "-Command", ps).Output()
if err != nil {
return "", 0, err
}
var a windowsAdapter
if err := json.Unmarshal(out, &a); err != nil {
return "", 0, err
}
idx := a.InterfaceIndex
if idx == 0 {
idx = a.IfIndex
}
name := a.Name
if name == "" {
name = a.InterfaceAlias
}
if name == "" || idx == 0 {
return "", 0, fmt.Errorf("invalid adapter json %q", strings.TrimSpace(string(out)))
}
return name, idx, nil
}
func addBypassWindows(hosts []string, gw string, cleanup *Cleanup, logger Logger) {
if gw == "" {
return
}
for _, ip := range resolveIPv4(hosts) {
if err := run(logger, "route", "add", ip, "mask", "255.255.255.255", gw, "metric", "1"); err == nil {
cleanup.commands = append(cleanup.commands, []string{"route", "delete", ip, "mask", "255.255.255.255"})
}
}
}
func addBypassLinux(hosts []string, gw, iface string, cleanup *Cleanup, logger Logger) {
if gw == "" {
return
}
for _, ip := range resolveIPv4(hosts) {
args := []string{"route", "add", ip + "/32", "via", gw}
if iface != "" {
args = append(args, "dev", iface)
}
if err := run(logger, "ip", args...); err == nil {
cleanup.commands = append(cleanup.commands, []string{"ip", "route", "del", ip + "/32"})
}
}
}
func addBypassDarwin(hosts []string, gw string, cleanup *Cleanup, logger Logger) {
if gw == "" {
return
}
for _, ip := range resolveIPv4(hosts) {
if err := run(logger, "route", "add", "-host", ip, gw); err == nil {
cleanup.commands = append(cleanup.commands, []string{"route", "delete", "-host", ip})
}
}
}
func addBypassWindowsIPv6(hosts []string, gw string, ifIndex int, cleanup *Cleanup, logger Logger) {
if gw == "" || ifIndex == 0 {
return
}
for _, ip := range resolveIPv6(hosts) {
prefix := ip + "/128"
if err := run(logger, "netsh", "interface", "ipv6", "add", "route", "prefix="+prefix, "interface="+fmt.Sprint(ifIndex), "nexthop="+gw, "metric=1", "store=active"); err == nil {
cleanup.commands = append(cleanup.commands, []string{"netsh", "interface", "ipv6", "delete", "route", "prefix=" + prefix, "interface=" + fmt.Sprint(ifIndex), "nexthop=" + gw, "store=active"})
}
}
}
func addBypassLinuxIPv6(hosts []string, gw, iface string, cleanup *Cleanup, logger Logger) {
for _, ip := range resolveIPv6(hosts) {
args := []string{"-6", "route", "add", ip + "/128"}
if gw != "" {
args = append(args, "via", gw)
}
if iface != "" {
args = append(args, "dev", iface)
}
if err := run(logger, "ip", args...); err == nil {
cleanup.commands = append(cleanup.commands, []string{"ip", "-6", "route", "del", ip + "/128"})
}
}
}
func addBypassDarwinIPv6(hosts []string, gw string, cleanup *Cleanup, logger Logger) {
if gw == "" {
return
}
for _, ip := range resolveIPv6(hosts) {
if err := run(logger, "route", "add", "-inet6", "-host", ip, gw); err == nil {
cleanup.commands = append(cleanup.commands, []string{"route", "delete", "-inet6", "-host", ip})
}
}
}
func resolveIPv6(hosts []string) []string {
seen := map[string]bool{}
var out []string
for _, host := range hosts {
host = strings.TrimSpace(host)
if host == "" {
continue
}
if h, _, err := net.SplitHostPort(host); err == nil {
host = h
}
if ip := net.ParseIP(host); ip != nil {
if ip.To4() == nil && ip.To16() != nil && !seen[ip.String()] {
seen[ip.String()] = true
out = append(out, ip.String())
}
continue
}
ips, err := net.LookupIP(host)
if err != nil {
continue
}
for _, ip := range ips {
if ip.To4() == nil && ip.To16() != nil && !seen[ip.String()] {
seen[ip.String()] = true
out = append(out, ip.String())
}
}
}
return out
}
func resolveIPv4(hosts []string) []string {
seen := map[string]bool{}
var out []string
for _, host := range hosts {
host = strings.TrimSpace(host)
if host == "" || net.ParseIP(host) != nil && net.ParseIP(host).To4() == nil {
continue
}
ips, err := net.LookupIP(host)
if err != nil {
if ip := net.ParseIP(host); ip != nil && ip.To4() != nil && !seen[ip.String()] {
seen[ip.String()] = true
out = append(out, ip.String())
}
continue
}
for _, ip := range ips {
v4 := ip.To4()
if v4 != nil && !seen[v4.String()] {
seen[v4.String()] = true
out = append(out, v4.String())
}
}
}
return out
}
func defaultGatewayIPv6() (gateway string, iface string, err error) {
switch runtime.GOOS {
case "windows":
ps := `Get-NetRoute -AddressFamily IPv6 -DestinationPrefix "::/0" -ErrorAction SilentlyContinue |
Where-Object { $_.NextHop -and $_.NextHop -ne "::" } |
Sort-Object RouteMetric, InterfaceMetric |
Select-Object -First 1 NextHop, InterfaceIndex | ConvertTo-Json -Compress`
out, err := oscmd.Command("powershell", "-NoProfile", "-ExecutionPolicy", "Bypass", "-Command", ps).Output()
if err != nil {
return "", "", err
}
var row struct {
NextHop string `json:"NextHop"`
InterfaceIndex int `json:"InterfaceIndex"`
}
if err := json.Unmarshal(out, &row); err != nil {
return "", "", err
}
if row.NextHop == "" || row.InterfaceIndex == 0 {
return "", "", fmt.Errorf("IPv6 default gateway not found")
}
return row.NextHop, fmt.Sprint(row.InterfaceIndex), nil
case "linux":
out, err := oscmd.Command("ip", "-6", "route", "show", "default").Output()
if err != nil {
return "", "", err
}
fields := strings.Fields(string(out))
for i, f := range fields {
if f == "via" && i+1 < len(fields) {
gateway = fields[i+1]
}
if f == "dev" && i+1 < len(fields) {
iface = fields[i+1]
}
}
if gateway == "" && iface == "" {
return "", "", fmt.Errorf("IPv6 default gateway not found")
}
return gateway, iface, nil
case "darwin":
out, err := oscmd.Command("route", "-n", "get", "-inet6", "default").Output()
if err != nil {
return "", "", err
}
for _, line := range strings.Split(string(out), "\n") {
line = strings.TrimSpace(line)
if strings.HasPrefix(line, "gateway:") {
gateway = strings.TrimSpace(strings.TrimPrefix(line, "gateway:"))
}
if strings.HasPrefix(line, "interface:") {
iface = strings.TrimSpace(strings.TrimPrefix(line, "interface:"))
}
}
if gateway == "" && iface == "" {
return "", "", fmt.Errorf("IPv6 default gateway not found")
}
return gateway, iface, nil
default:
return "", "", fmt.Errorf("IPv6 default gateway detection not implemented for %s", runtime.GOOS)
}
}
func run(logger Logger, name string, args ...string) error {
cmd := oscmd.Command(name, args...)
out, err := cmd.CombinedOutput()
line := strings.TrimSpace(string(out))
if err != nil {
logger.Add("warn", "%s %s failed: %v %s", name, strings.Join(args, " "), err, line)
return err
}
if line != "" {
logger.Add("debug", "%s: %s", name, line)
} else if strings.EqualFold(name, "netsh") || strings.EqualFold(name, "route") || strings.EqualFold(name, "ip") {
logger.Add("debug", "%s %s: OK", name, strings.Join(args, " "))
}
return nil
}
func defaultGateway() (gateway string, iface string, err error) {
switch runtime.GOOS {
case "linux":
out, err := oscmd.Command("ip", "route", "show", "default").Output()
if err != nil {
return "", "", err
}
fields := strings.Fields(string(out))
for i, f := range fields {
if f == "via" && i+1 < len(fields) {
gateway = fields[i+1]
}
if f == "dev" && i+1 < len(fields) {
iface = fields[i+1]
}
}
return gateway, iface, nil
case "darwin":
out, err := oscmd.Command("route", "-n", "get", "default").Output()
if err != nil {
return "", "", err
}
for _, line := range strings.Split(string(out), "\n") {
line = strings.TrimSpace(line)
if strings.HasPrefix(line, "gateway:") {
gateway = strings.TrimSpace(strings.TrimPrefix(line, "gateway:"))
}
if strings.HasPrefix(line, "interface:") {
iface = strings.TrimSpace(strings.TrimPrefix(line, "interface:"))
}
}
return gateway, iface, nil
case "windows":
ps := `Get-NetRoute -DestinationPrefix '0.0.0.0/0' -ErrorAction SilentlyContinue |
Where-Object { $_.NextHop -and $_.NextHop -ne '0.0.0.0' } |
Sort-Object RouteMetric, InterfaceMetric |
Select-Object -First 1 NextHop,InterfaceAlias |
ConvertTo-Json -Compress`
out, err := oscmd.Command("powershell", "-NoProfile", "-ExecutionPolicy", "Bypass", "-Command", ps).Output()
if err != nil {
return "", "", err
}
var r struct {
NextHop string
InterfaceAlias string
}
if err := json.Unmarshal(out, &r); err != nil {
return "", "", err
}
return r.NextHop, r.InterfaceAlias, nil
default:
return "", "", fmt.Errorf("unsupported OS")
}
}

96
internal/tun/tun2socks.go Normal file
View File

@@ -0,0 +1,96 @@
package tun
import (
"fmt"
"net"
"runtime"
"time"
"github.com/xjasonlyu/tun2socks/v2/engine"
"socksrevivepc/internal/config"
"socksrevivepc/internal/platformtun"
)
type Logger interface {
Add(level, format string, args ...any)
}
type Runner struct {
active bool
logger Logger
cleanup func()
}
func NewRunner(logger Logger) *Runner { return &Runner{logger: logger} }
func (r *Runner) Start(p *config.Profile, socksAddr string) error {
if p == nil || !p.Tun.Enabled {
return nil
}
if r.active {
return fmt.Errorf("tun is already active")
}
if err := waitForProxyPort(socksAddr, 2500*time.Millisecond); err != nil {
return fmt.Errorf("cannot start TUN because upstream SOCKS is not reachable at %s: %w", socksAddr, err)
}
device, iface, cleanup, err := platformtun.Prepare(p.Tun.Device, p.Tun.InterfaceName, p.Tun.MTU, r.logger)
if err != nil {
return err
}
p.Tun.Device = device
p.Tun.InterfaceName = iface
r.cleanup = cleanup
key := &engine.Key{
MTU: p.Tun.MTU,
Device: p.Tun.Device,
Proxy: "socks5://" + socksAddr,
LogLevel: "error",
}
if p.Tun.InterfaceName != "" && runtime.GOOS != "windows" {
key.Interface = p.Tun.InterfaceName
}
if runtime.GOOS == "windows" && r.logger != nil {
r.logger.Add("info", "Windows TUN: leaving tun2socks interface binding empty; the TUN adapter name is %s", p.Tun.InterfaceName)
}
engine.Insert(key)
// tun2socks/v2 engine.Start() initializes the default netstack and then
// returns; it does not block for the lifetime of the tunnel. The stack stays
// alive globally until engine.Stop() is called. Older builds treated this
// normal return as "engine exited immediately", which made TUN mode fail even
// after the stack was correctly created.
engine.Start()
r.active = true
r.logger.Add("info", "tun2socks started: device=%s interface=%s proxy=socks5://%s", p.Tun.Device, p.Tun.InterfaceName, socksAddr)
return nil
}
func (r *Runner) Stop() {
if !r.active {
return
}
engine.Stop()
if r.cleanup != nil {
r.cleanup()
r.cleanup = nil
}
r.active = false
r.logger.Add("info", "tun2socks stopped")
}
func waitForProxyPort(addr string, timeout time.Duration) error {
deadline := time.Now().Add(timeout)
var lastErr error
for time.Now().Before(deadline) {
c, err := net.DialTimeout("tcp", addr, 350*time.Millisecond)
if err == nil {
_ = c.Close()
return nil
}
lastErr = err
time.Sleep(150 * time.Millisecond)
}
if lastErr != nil {
return lastErr
}
return fmt.Errorf("timeout waiting for %s", addr)
}

View File

@@ -0,0 +1,6 @@
Put the official amd64 wintun.dll here and rebuild if you want SocksRevivePC.exe to embed Wintun internally.
Expected filename:
wintun.dll
Use the signed DLL from the official Wintun package.

View File

@@ -0,0 +1,6 @@
Put the official arm64 wintun.dll here and rebuild if you build GOARCH=arm64.
Expected filename:
wintun.dll
Use the signed DLL from the official Wintun package.

View File

@@ -0,0 +1,318 @@
//go:build windows
package wintunloader
import (
"embed"
"encoding/binary"
"errors"
"fmt"
"io"
"io/fs"
"os"
"path/filepath"
"runtime"
"strings"
"syscall"
"unsafe"
)
//go:embed assets/wintun/windows/amd64/* assets/wintun/windows/arm64/*
var embeddedWintun embed.FS
type Logger interface {
Add(level, format string, args ...any)
}
// Prepare makes Wintun available before tun2socks/WireGuard tries to load it.
// Windows cannot create a real TUN adapter without the signed wintun.dll layer,
// so this function validates the DLL architecture, extracts embedded assets when
// present, and updates the current process DLL search path.
func Prepare(logger Logger) error {
arch := runtime.GOARCH
if arch != "amd64" && arch != "arm64" {
return fmt.Errorf("Wintun embedded loader only supports amd64/arm64, current GOARCH=%s", arch)
}
exeDir := executableDir()
workDir := workingDir()
candidateDirs := uniqueNonEmpty([]string{
exeDir,
workDir,
filepath.Join(exeDir, "tools", "wintun", arch),
filepath.Join(workDir, "tools", "wintun", arch),
filepath.Join(exeDir, "tools", "wintun"),
filepath.Join(workDir, "tools", "wintun"),
})
for _, dir := range candidateDirs {
path := filepath.Join(dir, "wintun.dll")
if ok, reason := validDLLForArch(path, arch); ok {
installDir, installPath, err := ensureProcessDLL(path, exeDir, arch)
if err != nil {
return err
}
activateDLLDirectory(installDir, append(candidateDirs, installDir)...)
if logger != nil {
logger.Add("info", "Wintun DLL ready: %s", installPath)
}
return nil
} else if fileExists(path) && logger != nil {
logger.Add("warn", "Ignoring invalid Wintun DLL at %s: %s", path, reason)
}
}
embeddedPath := filepath.ToSlash(filepath.Join("assets", "wintun", "windows", arch, "wintun.dll"))
data, err := embeddedWintun.ReadFile(embeddedPath)
if err == nil {
if err := validateDLLBytes(data, arch); err != nil {
return fmt.Errorf("embedded wintun.dll is invalid for %s: %w", arch, err)
}
installDir, installPath, err := writeEmbeddedDLL(data, exeDir, arch)
if err != nil {
return err
}
activateDLLDirectory(installDir, append(candidateDirs, installDir)...)
if logger != nil {
logger.Add("info", "Embedded Wintun extracted: %s", installPath)
}
return nil
}
if err != nil && !errors.Is(err, fs.ErrNotExist) {
return fmt.Errorf("read embedded wintun.dll failed: %w", err)
}
activateDLLDirectory("", candidateDirs...)
return fmt.Errorf("wintun.dll was not found. The TUN engine is compiled into the app, but Windows still requires the signed Wintun DLL. Put the official %s wintun.dll at tools/wintun/%s/wintun.dll, run scripts/embed_wintun_from_tools.ps1, and rebuild", arch, arch)
}
func ensureProcessDLL(source, exeDir, arch string) (string, string, error) {
if ok, reason := validDLLForArch(source, arch); !ok {
return "", "", fmt.Errorf("source Wintun DLL is invalid: %s: %s", source, reason)
}
if source == filepath.Join(exeDir, "wintun.dll") {
return exeDir, source, nil
}
installDir := preferredInstallDir(exeDir, arch)
installPath := filepath.Join(installDir, "wintun.dll")
if samePath(source, installPath) {
return installDir, installPath, nil
}
if ok, _ := validDLLForArch(installPath, arch); ok {
return installDir, installPath, nil
}
if err := copyFile(source, installPath); err != nil {
// Fall back to the source directory if we cannot copy beside the exe.
return filepath.Dir(source), source, nil
}
return installDir, installPath, nil
}
func writeEmbeddedDLL(data []byte, exeDir, arch string) (string, string, error) {
if err := validateDLLBytes(data, arch); err != nil {
return "", "", err
}
installDir := preferredInstallDir(exeDir, arch)
installPath := filepath.Join(installDir, "wintun.dll")
if ok, _ := validDLLForArch(installPath, arch); ok {
return installDir, installPath, nil
}
if err := os.MkdirAll(installDir, 0o755); err != nil {
return "", "", fmt.Errorf("create Wintun install directory failed: %w", err)
}
if err := os.WriteFile(installPath, data, 0o644); err != nil {
cacheDir := filepath.Join(userCacheDir(), "SocksRevivePC", "wintun", arch)
cachePath := filepath.Join(cacheDir, "wintun.dll")
if err2 := os.MkdirAll(cacheDir, 0o755); err2 != nil {
return "", "", fmt.Errorf("write embedded Wintun failed: %w; cache fallback mkdir failed: %v", err, err2)
}
if err2 := os.WriteFile(cachePath, data, 0o644); err2 != nil {
return "", "", fmt.Errorf("write embedded Wintun failed: %w; cache fallback write failed: %v", err, err2)
}
return cacheDir, cachePath, nil
}
return installDir, installPath, nil
}
func preferredInstallDir(exeDir, arch string) string {
if exeDir != "" {
return exeDir
}
return filepath.Join(userCacheDir(), "SocksRevivePC", "wintun", arch)
}
func executableDir() string {
exe, err := os.Executable()
if err != nil || exe == "" {
return ""
}
return filepath.Dir(exe)
}
func workingDir() string {
wd, err := os.Getwd()
if err != nil {
return ""
}
return wd
}
func userCacheDir() string {
dir, err := os.UserCacheDir()
if err != nil || dir == "" {
return os.TempDir()
}
return dir
}
func activateDLLDirectory(primary string, dirs ...string) {
if primary != "" {
_ = setDLLDirectory(primary)
}
prependPath(dirs...)
}
func setDLLDirectory(dir string) error {
if strings.TrimSpace(dir) == "" {
return nil
}
kernel32 := syscall.NewLazyDLL("kernel32.dll")
proc := kernel32.NewProc("SetDllDirectoryW")
ptr, err := syscall.UTF16PtrFromString(dir)
if err != nil {
return err
}
r1, _, callErr := proc.Call(uintptr(unsafe.Pointer(ptr)))
if r1 == 0 {
return callErr
}
return nil
}
func prependPath(dirs ...string) {
dirs = uniqueNonEmpty(dirs)
if len(dirs) == 0 {
return
}
current := os.Getenv("PATH")
parts := strings.Split(current, string(os.PathListSeparator))
var cleaned []string
for _, p := range parts {
p = strings.TrimSpace(p)
if p != "" && !containsPath(dirs, p) {
cleaned = append(cleaned, p)
}
}
os.Setenv("PATH", strings.Join(append(dirs, cleaned...), string(os.PathListSeparator)))
}
func uniqueNonEmpty(in []string) []string {
out := make([]string, 0, len(in))
for _, s := range in {
s = strings.TrimSpace(s)
if s == "" {
continue
}
if !containsPath(out, s) {
out = append(out, s)
}
}
return out
}
func containsPath(paths []string, p string) bool {
for _, existing := range paths {
if samePath(existing, p) {
return true
}
}
return false
}
func samePath(a, b string) bool {
if a == "" || b == "" {
return false
}
return strings.EqualFold(filepath.Clean(a), filepath.Clean(b))
}
func fileExists(path string) bool {
st, err := os.Stat(path)
return err == nil && !st.IsDir()
}
func validDLLForArch(path, arch string) (bool, string) {
st, err := os.Stat(path)
if err != nil {
return false, err.Error()
}
if st.IsDir() {
return false, "path is a directory"
}
if st.Size() <= 1024 {
return false, "file is too small"
}
data, err := os.ReadFile(path)
if err != nil {
return false, err.Error()
}
if err := validateDLLBytes(data, arch); err != nil {
return false, err.Error()
}
return true, ""
}
func validateDLLBytes(data []byte, arch string) error {
if len(data) <= 1024 {
return fmt.Errorf("file is too small")
}
if len(data) < 0x40 || data[0] != 'M' || data[1] != 'Z' {
return fmt.Errorf("not a Windows PE DLL")
}
peOffset := int(binary.LittleEndian.Uint32(data[0x3c:0x40]))
if peOffset <= 0 || len(data) < peOffset+6 {
return fmt.Errorf("invalid PE header")
}
if string(data[peOffset:peOffset+4]) != "PE\x00\x00" {
return fmt.Errorf("missing PE signature")
}
machine := binary.LittleEndian.Uint16(data[peOffset+4 : peOffset+6])
want := uint16(0x8664) // IMAGE_FILE_MACHINE_AMD64
if arch == "arm64" {
want = 0xaa64 // IMAGE_FILE_MACHINE_ARM64
}
if machine != want {
return fmt.Errorf("wrong architecture machine=0x%04x expected=0x%04x for %s", machine, want, arch)
}
return nil
}
func copyFile(source, target string) error {
if ok, reason := validDLLForArch(source, runtime.GOARCH); !ok {
return fmt.Errorf("source DLL is missing or invalid: %s: %s", source, reason)
}
if err := os.MkdirAll(filepath.Dir(target), 0o755); err != nil {
return err
}
in, err := os.Open(source)
if err != nil {
return err
}
defer in.Close()
out, err := os.Create(target)
if err != nil {
return err
}
_, copyErr := io.Copy(out, in)
closeErr := out.Close()
if copyErr != nil {
_ = os.Remove(target)
return copyErr
}
if closeErr != nil {
_ = os.Remove(target)
return closeErr
}
return nil
}