Launch
This commit is contained in:
45
internal/engine/logger.go
Normal file
45
internal/engine/logger.go
Normal file
@@ -0,0 +1,45 @@
|
||||
package engine
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"sync"
|
||||
"time"
|
||||
)
|
||||
|
||||
type LogEntry struct {
|
||||
ID int64 `json:"id"`
|
||||
Time string `json:"time"`
|
||||
Level string `json:"level"`
|
||||
Message string `json:"message"`
|
||||
}
|
||||
|
||||
type Logger struct {
|
||||
mu sync.Mutex
|
||||
nextID int64
|
||||
entries []LogEntry
|
||||
}
|
||||
|
||||
func NewLogger() *Logger { return &Logger{} }
|
||||
|
||||
func (l *Logger) Add(level, format string, args ...any) {
|
||||
l.mu.Lock()
|
||||
defer l.mu.Unlock()
|
||||
l.nextID++
|
||||
entry := LogEntry{ID: l.nextID, Time: time.Now().Format("15:04:05"), Level: level, Message: fmt.Sprintf(format, args...)}
|
||||
l.entries = append(l.entries, entry)
|
||||
if len(l.entries) > 600 {
|
||||
l.entries = l.entries[len(l.entries)-600:]
|
||||
}
|
||||
}
|
||||
|
||||
func (l *Logger) Since(id int64) []LogEntry {
|
||||
l.mu.Lock()
|
||||
defer l.mu.Unlock()
|
||||
out := make([]LogEntry, 0)
|
||||
for _, e := range l.entries {
|
||||
if e.ID > id {
|
||||
out = append(out, e)
|
||||
}
|
||||
}
|
||||
return out
|
||||
}
|
||||
525
internal/engine/manager.go
Normal file
525
internal/engine/manager.go
Normal file
@@ -0,0 +1,525 @@
|
||||
package engine
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net"
|
||||
"net/url"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"socksrevivepc/internal/config"
|
||||
"socksrevivepc/internal/dnsttclient"
|
||||
"socksrevivepc/internal/routes"
|
||||
"socksrevivepc/internal/tun"
|
||||
)
|
||||
|
||||
type Status struct {
|
||||
Running bool `json:"running"`
|
||||
Connecting bool `json:"connecting"`
|
||||
ProfileID string `json:"profile_id"`
|
||||
Mode string `json:"mode"`
|
||||
SocksAddr string `json:"socks_addr"`
|
||||
Tun bool `json:"tun"`
|
||||
StartedAt string `json:"started_at"`
|
||||
}
|
||||
|
||||
type Manager struct {
|
||||
root string
|
||||
mu sync.Mutex
|
||||
logger *Logger
|
||||
status Status
|
||||
cancel context.CancelFunc
|
||||
ssh *sshBundle
|
||||
socks *SocksServer
|
||||
dnstt *ManagedProcess
|
||||
embeddedDNSTT *dnsttclient.Client
|
||||
xray *ManagedProcess
|
||||
tun *tun.Runner
|
||||
routeCleanup *routes.Cleanup
|
||||
manualStop bool
|
||||
reconnecting bool
|
||||
}
|
||||
|
||||
func NewManager(root string) *Manager {
|
||||
lg := NewLogger()
|
||||
return &Manager{root: root, logger: lg, tun: tun.NewRunner(lg)}
|
||||
}
|
||||
|
||||
func (m *Manager) LogsSince(id int64) []LogEntry { return m.logger.Since(id) }
|
||||
|
||||
func (m *Manager) Status() Status {
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
return m.status
|
||||
}
|
||||
|
||||
func (m *Manager) Start(p config.Profile) error {
|
||||
config.ApplyDefaults(&p)
|
||||
if err := config.Validate(p); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
m.mu.Lock()
|
||||
if m.status.Running || m.status.Connecting {
|
||||
m.mu.Unlock()
|
||||
cancel()
|
||||
return fmt.Errorf("another profile is already running or connecting")
|
||||
}
|
||||
m.cancel = cancel
|
||||
m.manualStop = false
|
||||
m.status = Status{Connecting: true, ProfileID: p.ID, Mode: string(p.Mode), Tun: p.Tun.Enabled, StartedAt: time.Now().Format(time.RFC3339)}
|
||||
m.mu.Unlock()
|
||||
|
||||
m.logger.Add("info", "starting profile: %s", p.Name)
|
||||
|
||||
fail := func(format string, args ...any) error {
|
||||
err := fmt.Errorf(format, args...)
|
||||
m.logger.Add("error", "%v", err)
|
||||
m.stop(false)
|
||||
return err
|
||||
}
|
||||
|
||||
ensureCurrent := func() error {
|
||||
if err := ctx.Err(); err != nil {
|
||||
return err
|
||||
}
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
if m.cancel == nil || !m.status.Connecting || m.status.ProfileID != p.ID {
|
||||
return context.Canceled
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
var socksAddr string
|
||||
var bypass []string
|
||||
if p.Mode == config.ModeXray {
|
||||
proc, err := StartProcess(ctx, m.root, "xray", p.Xray.Executable, p.Xray.Args, m.logger)
|
||||
if err != nil {
|
||||
return fail("start xray failed: %w", err)
|
||||
}
|
||||
m.mu.Lock()
|
||||
m.xray = proc
|
||||
m.mu.Unlock()
|
||||
time.Sleep(time.Duration(p.Xray.StartupTimeoutMs) * time.Millisecond)
|
||||
if err := ensureCurrent(); err != nil {
|
||||
proc.Stop()
|
||||
return fail("connection cancelled: %w", err)
|
||||
}
|
||||
socksAddr = net.JoinHostPort(p.Xray.LocalSocksHost, fmt.Sprint(p.Xray.LocalSocksPort))
|
||||
if exited, procErr := proc.Exited(); exited {
|
||||
m.stop(false)
|
||||
if procErr != nil {
|
||||
return fmt.Errorf("xray exited before opening local SOCKS on %s: %w", socksAddr, procErr)
|
||||
}
|
||||
return fmt.Errorf("xray exited before opening local SOCKS on %s", socksAddr)
|
||||
}
|
||||
if err := waitForTCP(ctx, socksAddr, time.Duration(p.Xray.StartupTimeoutMs)*time.Millisecond); err != nil {
|
||||
return fail("xray local SOCKS is not listening on %s: %w", socksAddr, err)
|
||||
}
|
||||
bypass = nil
|
||||
} else {
|
||||
if p.Mode == config.ModeDNSTT {
|
||||
if p.DNSTT.UseEmbedded {
|
||||
client, err := dnsttclient.Start(ctx, dnsttclient.Options{
|
||||
ResolverType: p.DNSTT.ResolverType,
|
||||
ResolverAddress: p.DNSTT.ResolverAddress,
|
||||
PublicKeyHex: p.DNSTT.PublicKey,
|
||||
Domain: p.DNSTT.Domain,
|
||||
LocalAddress: net.JoinHostPort(p.DNSTT.LocalSSHHost, fmt.Sprint(p.DNSTT.LocalSSHPort)),
|
||||
UTLSDistribution: p.DNSTT.UTLSDistribution,
|
||||
StartupTimeout: time.Duration(p.DNSTT.StartupTimeoutMs) * time.Millisecond,
|
||||
LogWriter: dnsttLogWriter{logger: m.logger},
|
||||
})
|
||||
if err != nil {
|
||||
return fail("start embedded dnstt failed: %w", err)
|
||||
}
|
||||
m.mu.Lock()
|
||||
m.embeddedDNSTT = client
|
||||
m.mu.Unlock()
|
||||
} else {
|
||||
proc, err := StartProcess(ctx, m.root, "dnstt", p.DNSTT.Executable, p.DNSTT.Args, m.logger)
|
||||
if err != nil {
|
||||
return fail("start dnstt failed: %w", err)
|
||||
}
|
||||
m.mu.Lock()
|
||||
m.dnstt = proc
|
||||
m.mu.Unlock()
|
||||
time.Sleep(time.Duration(p.DNSTT.StartupTimeoutMs) * time.Millisecond)
|
||||
}
|
||||
if err := ensureCurrent(); err != nil {
|
||||
return fail("connection cancelled: %w", err)
|
||||
}
|
||||
}
|
||||
sshc, err := connectSSH(ctx, p, m.logger)
|
||||
if err != nil {
|
||||
return fail("%w", err)
|
||||
}
|
||||
if err := ensureCurrent(); err != nil {
|
||||
_ = sshc.Client.Close()
|
||||
_ = sshc.Conn.Close()
|
||||
return fail("connection cancelled: %w", err)
|
||||
}
|
||||
m.mu.Lock()
|
||||
m.ssh = sshc
|
||||
m.mu.Unlock()
|
||||
socksAddr = net.JoinHostPort(p.Local.SocksHost, fmt.Sprint(p.Local.SocksPort))
|
||||
ss := &SocksServer{Addr: socksAddr, SSH: sshc.Client, Logger: m.logger, DNS: profileDNSServers(p), UDPGW: p.UDPGW}
|
||||
if err := ss.Start(); err != nil {
|
||||
return fail("%w", err)
|
||||
}
|
||||
m.mu.Lock()
|
||||
m.socks = ss
|
||||
m.status.SocksAddr = socksAddr
|
||||
m.mu.Unlock()
|
||||
if err := waitForTCP(ctx, socksAddr, 2500*time.Millisecond); err != nil {
|
||||
return fail("local SOCKS is not listening on %s: %w", socksAddr, err)
|
||||
}
|
||||
bypass = effectiveBypassHosts(p, sshc.ControlHosts)
|
||||
}
|
||||
|
||||
if err := ensureCurrent(); err != nil {
|
||||
return fail("connection cancelled: %w", err)
|
||||
}
|
||||
if p.Tun.Enabled {
|
||||
if err := m.tun.Start(&p, socksAddr); err != nil {
|
||||
return fail("%w", err)
|
||||
}
|
||||
cleanup, err := routes.Apply(p, bypass, m.logger)
|
||||
if err != nil {
|
||||
m.logger.Add("warn", "route setup error: %v", err)
|
||||
}
|
||||
m.mu.Lock()
|
||||
m.routeCleanup = cleanup
|
||||
m.mu.Unlock()
|
||||
}
|
||||
|
||||
m.mu.Lock()
|
||||
current := m.cancel != nil && m.status.Connecting && m.status.ProfileID == p.ID
|
||||
err := ctx.Err()
|
||||
if err == nil && current {
|
||||
m.status = Status{Running: true, ProfileID: p.ID, Mode: string(p.Mode), SocksAddr: socksAddr, Tun: p.Tun.Enabled, StartedAt: time.Now().Format(time.RFC3339)}
|
||||
}
|
||||
m.mu.Unlock()
|
||||
if err != nil || !current {
|
||||
m.stop(false)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return context.Canceled
|
||||
}
|
||||
m.logger.Add("info", "profile is connected; local socks=%s tun=%v", socksAddr, p.Tun.Enabled)
|
||||
m.startMonitor(ctx, p)
|
||||
return nil
|
||||
}
|
||||
|
||||
func profileDNSServers(p config.Profile) []string {
|
||||
out := append([]string{}, p.Tun.DNS...)
|
||||
if p.Tun.IPv6Enabled {
|
||||
out = append(out, p.Tun.IPv6DNS...)
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
func waitForTCP(ctx context.Context, addr string, timeout time.Duration) error {
|
||||
if timeout <= 0 {
|
||||
timeout = 5 * time.Second
|
||||
}
|
||||
deadline := time.Now().Add(timeout)
|
||||
var lastErr error
|
||||
for time.Now().Before(deadline) {
|
||||
d := net.Dialer{Timeout: 350 * time.Millisecond}
|
||||
c, err := d.DialContext(ctx, "tcp", addr)
|
||||
if err == nil {
|
||||
_ = c.Close()
|
||||
return nil
|
||||
}
|
||||
lastErr = err
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return ctx.Err()
|
||||
case <-time.After(150 * time.Millisecond):
|
||||
}
|
||||
}
|
||||
if lastErr != nil {
|
||||
return lastErr
|
||||
}
|
||||
return fmt.Errorf("timeout waiting for %s", addr)
|
||||
}
|
||||
|
||||
func effectiveBypassHosts(p config.Profile, activeControlHosts []string) []string {
|
||||
seen := map[string]bool{}
|
||||
out := make([]string, 0, len(activeControlHosts)+1)
|
||||
add := func(host string) {
|
||||
host = strings.TrimSpace(host)
|
||||
if host == "" {
|
||||
return
|
||||
}
|
||||
if h, _, err := net.SplitHostPort(host); err == nil {
|
||||
host = strings.Trim(h, "[]")
|
||||
}
|
||||
if isLocalBypassHost(host) || seen[host] {
|
||||
return
|
||||
}
|
||||
seen[host] = true
|
||||
out = append(out, host)
|
||||
}
|
||||
for _, host := range activeControlHosts {
|
||||
add(host)
|
||||
}
|
||||
if len(out) == 0 && p.Mode != config.ModeDNSTT {
|
||||
// Fallback for old profile formats or unexpected transports. This still only
|
||||
// adds the direct SSH host, not every rotated proxy from the profile.
|
||||
add(p.SSH.Host)
|
||||
}
|
||||
if p.Mode == config.ModeDNSTT && p.DNSTT.UseEmbedded {
|
||||
add(dnsttResolverHost(p.DNSTT.ResolverAddress))
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
func isLocalBypassHost(host string) bool {
|
||||
host = strings.Trim(strings.ToLower(strings.TrimSpace(host)), "[]")
|
||||
if host == "" || host == "localhost" {
|
||||
return true
|
||||
}
|
||||
ip := net.ParseIP(host)
|
||||
if ip == nil {
|
||||
return false
|
||||
}
|
||||
return ip.IsLoopback() || ip.IsUnspecified()
|
||||
}
|
||||
|
||||
func dnsttResolverHost(s string) string {
|
||||
s = strings.TrimSpace(s)
|
||||
if s == "" {
|
||||
return ""
|
||||
}
|
||||
if strings.Contains(s, "://") {
|
||||
u, err := url.Parse(s)
|
||||
if err == nil {
|
||||
return u.Hostname()
|
||||
}
|
||||
}
|
||||
host, _, err := net.SplitHostPort(s)
|
||||
if err == nil {
|
||||
return host
|
||||
}
|
||||
return s
|
||||
}
|
||||
|
||||
func (m *Manager) Stop() {
|
||||
m.stop(true)
|
||||
}
|
||||
|
||||
func (m *Manager) stop(manual bool) {
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
if manual {
|
||||
m.manualStop = true
|
||||
}
|
||||
m.stopLocked()
|
||||
}
|
||||
|
||||
func (m *Manager) startMonitor(ctx context.Context, p config.Profile) {
|
||||
if !p.Reconnect.Enabled {
|
||||
return
|
||||
}
|
||||
interval := time.Duration(p.Reconnect.CheckIntervalSeconds) * time.Second
|
||||
if interval <= 0 {
|
||||
interval = 10 * time.Second
|
||||
}
|
||||
go func() {
|
||||
ticker := time.NewTicker(interval)
|
||||
defer ticker.Stop()
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return
|
||||
case <-ticker.C:
|
||||
if m.isManualStopped() {
|
||||
return
|
||||
}
|
||||
if err := m.probeConnection(p); err != nil {
|
||||
m.logger.Add("warn", "connection monitor detected tunnel loss: %v", err)
|
||||
m.reconnectLoop(p, err)
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
func (m *Manager) isManualStopped() bool {
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
return m.manualStop
|
||||
}
|
||||
|
||||
func (m *Manager) markReconnecting() bool {
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
if m.manualStop || m.reconnecting {
|
||||
return false
|
||||
}
|
||||
m.reconnecting = true
|
||||
return true
|
||||
}
|
||||
|
||||
func (m *Manager) clearReconnecting() {
|
||||
m.mu.Lock()
|
||||
m.reconnecting = false
|
||||
m.mu.Unlock()
|
||||
}
|
||||
|
||||
func (m *Manager) probeConnection(p config.Profile) error {
|
||||
m.mu.Lock()
|
||||
sshc := m.ssh
|
||||
xray := m.xray
|
||||
socksAddr := m.status.SocksAddr
|
||||
running := m.status.Running
|
||||
m.mu.Unlock()
|
||||
|
||||
if !running {
|
||||
return nil
|
||||
}
|
||||
if xray != nil {
|
||||
if exited, err := xray.Exited(); exited {
|
||||
if err != nil {
|
||||
return fmt.Errorf("xray exited: %w", err)
|
||||
}
|
||||
return fmt.Errorf("xray exited")
|
||||
}
|
||||
if socksAddr != "" {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
|
||||
defer cancel()
|
||||
return waitForTCP(ctx, socksAddr, 2*time.Second)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
if sshc != nil && sshc.Client != nil {
|
||||
if sshc.Conn != nil {
|
||||
_ = sshc.Conn.SetDeadline(time.Now().Add(5 * time.Second))
|
||||
defer sshc.Conn.SetDeadline(time.Time{})
|
||||
}
|
||||
_, _, err := sshc.Client.SendRequest("keepalive@openssh.com", true, nil)
|
||||
if err != nil {
|
||||
return fmt.Errorf("ssh keepalive failed: %w", err)
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *Manager) reconnectLoop(p config.Profile, cause error) {
|
||||
if !m.markReconnecting() {
|
||||
return
|
||||
}
|
||||
defer m.clearReconnecting()
|
||||
|
||||
delay := time.Duration(p.Reconnect.DelaySeconds) * time.Second
|
||||
if delay <= 0 {
|
||||
delay = 3 * time.Second
|
||||
}
|
||||
maxRetries := p.Reconnect.MaxRetries
|
||||
m.logger.Add("warn", "connection lost (%v); auto reconnect is enabled", cause)
|
||||
if p.Tun.Enabled {
|
||||
m.logger.Add("info", "destroying TUN before reconnect")
|
||||
}
|
||||
m.stop(false)
|
||||
if p.Tun.Enabled {
|
||||
time.Sleep(1200 * time.Millisecond)
|
||||
}
|
||||
|
||||
for attempt := 1; maxRetries <= 0 || attempt <= maxRetries; attempt++ {
|
||||
if m.isManualStopped() {
|
||||
m.logger.Add("info", "auto reconnect cancelled by user")
|
||||
return
|
||||
}
|
||||
m.logger.Add("info", "reconnect attempt %d%s in %s", attempt, reconnectLimitSuffix(maxRetries), delay)
|
||||
select {
|
||||
case <-time.After(delay):
|
||||
}
|
||||
if m.isManualStopped() {
|
||||
m.logger.Add("info", "auto reconnect cancelled by user")
|
||||
return
|
||||
}
|
||||
if err := m.Start(p); err != nil {
|
||||
if m.isManualStopped() {
|
||||
m.logger.Add("info", "auto reconnect cancelled by user")
|
||||
return
|
||||
}
|
||||
m.logger.Add("warn", "reconnect attempt %d failed: %v", attempt, err)
|
||||
if p.Tun.Enabled {
|
||||
m.logger.Add("info", "destroying TUN after failed reconnect attempt")
|
||||
m.stop(false)
|
||||
time.Sleep(1200 * time.Millisecond)
|
||||
}
|
||||
continue
|
||||
}
|
||||
m.logger.Add("info", "reconnected successfully")
|
||||
return
|
||||
}
|
||||
m.logger.Add("error", "auto reconnect stopped after %d failed attempt(s)", maxRetries)
|
||||
}
|
||||
|
||||
func reconnectLimitSuffix(maxRetries int) string {
|
||||
if maxRetries <= 0 {
|
||||
return " (unlimited)"
|
||||
}
|
||||
return fmt.Sprintf("/%d", maxRetries)
|
||||
}
|
||||
|
||||
func (m *Manager) stopLocked() {
|
||||
wasActive := m.status.Running || m.status.Connecting
|
||||
if m.cancel != nil {
|
||||
m.cancel()
|
||||
m.cancel = nil
|
||||
}
|
||||
if m.routeCleanup != nil {
|
||||
m.routeCleanup.Run()
|
||||
m.routeCleanup = nil
|
||||
}
|
||||
if m.tun != nil {
|
||||
m.tun.Stop()
|
||||
}
|
||||
if m.socks != nil {
|
||||
m.socks.Stop()
|
||||
m.socks = nil
|
||||
}
|
||||
if m.ssh != nil {
|
||||
_ = m.ssh.Client.Close()
|
||||
_ = m.ssh.Conn.Close()
|
||||
m.ssh = nil
|
||||
}
|
||||
if m.xray != nil {
|
||||
m.xray.Stop()
|
||||
m.xray = nil
|
||||
}
|
||||
if m.embeddedDNSTT != nil {
|
||||
m.embeddedDNSTT.Stop()
|
||||
m.embeddedDNSTT = nil
|
||||
}
|
||||
if m.dnstt != nil {
|
||||
m.dnstt.Stop()
|
||||
m.dnstt = nil
|
||||
}
|
||||
if wasActive {
|
||||
m.logger.Add("info", "disconnected")
|
||||
}
|
||||
m.status = Status{}
|
||||
}
|
||||
|
||||
type dnsttLogWriter struct {
|
||||
logger *Logger
|
||||
}
|
||||
|
||||
func (w dnsttLogWriter) Write(p []byte) (int, error) {
|
||||
line := strings.TrimSpace(string(p))
|
||||
if line != "" && w.logger != nil {
|
||||
w.logger.Add("dnstt", "%s", line)
|
||||
}
|
||||
return len(p), nil
|
||||
}
|
||||
399
internal/engine/payload.go
Normal file
399
internal/engine/payload.go
Normal file
@@ -0,0 +1,399 @@
|
||||
package engine
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"fmt"
|
||||
"io"
|
||||
"math/rand"
|
||||
"net"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"socksrevivepc/internal/config"
|
||||
)
|
||||
|
||||
const (
|
||||
payloadHTTPStatusPeekTimeoutMs = 1500
|
||||
payloadHTTPStatusLineLimit = 4096
|
||||
maxPayloadHTTPResponsesToSkip = 12
|
||||
maxPayloadHTTPHeaderLines = 80
|
||||
maxPayloadHTTPBodyDiscardBytes = 1024 * 1024
|
||||
)
|
||||
|
||||
type PayloadResult struct {
|
||||
StatusLine string
|
||||
StatusCode int
|
||||
}
|
||||
|
||||
type httpResponseInfo struct {
|
||||
contentLength int64
|
||||
chunked bool
|
||||
}
|
||||
|
||||
type preloadedConn struct {
|
||||
net.Conn
|
||||
preloaded *bytes.Reader
|
||||
}
|
||||
|
||||
func (c *preloadedConn) Read(p []byte) (int, error) {
|
||||
if c.preloaded != nil && c.preloaded.Len() > 0 {
|
||||
return c.preloaded.Read(p)
|
||||
}
|
||||
return c.Conn.Read(p)
|
||||
}
|
||||
|
||||
func wrapConnWithPreloadedBytes(conn net.Conn, b []byte) net.Conn {
|
||||
if len(b) == 0 {
|
||||
return conn
|
||||
}
|
||||
return &preloadedConn{Conn: conn, preloaded: bytes.NewReader(b)}
|
||||
}
|
||||
|
||||
func WritePayload(conn net.Conn, p config.Profile, targetHost string, targetPort int, logger *Logger) (PayloadResult, net.Conn, error) {
|
||||
payload := buildPayload(p.Payload.Text, targetHost, targetPort)
|
||||
parts, instant := splitPayload(payload)
|
||||
for i, part := range parts {
|
||||
if part == "" {
|
||||
continue
|
||||
}
|
||||
if _, err := io.WriteString(conn, part); err != nil {
|
||||
return PayloadResult{}, conn, err
|
||||
}
|
||||
if i < len(parts)-1 && !instant {
|
||||
time.Sleep(time.Duration(p.Payload.SplitDelayMs) * time.Millisecond)
|
||||
}
|
||||
}
|
||||
logger.Add("debug", "payload sent (%d bytes)", len(payload))
|
||||
|
||||
if !p.Payload.WaitForResponse {
|
||||
return PayloadResult{}, conn, nil
|
||||
}
|
||||
|
||||
return consumePayloadHTTPNegotiation(conn, p, payloadSourceLabel(p.Mode), logger)
|
||||
}
|
||||
|
||||
func consumePayloadHTTPNegotiation(conn net.Conn, p config.Profile, source string, logger *Logger) (PayloadResult, net.Conn, error) {
|
||||
defer conn.SetReadDeadline(time.Time{})
|
||||
|
||||
var last PayloadResult
|
||||
var captured *bytes.Buffer
|
||||
var sawSuccess bool
|
||||
|
||||
for attempt := 0; attempt < maxPayloadHTTPResponsesToSkip; attempt++ {
|
||||
setPayloadReadDeadline(conn, p, attempt)
|
||||
captured = &bytes.Buffer{}
|
||||
line, err := readPayloadLinePreserveBytes(conn, captured, payloadHTTPStatusLineLimit)
|
||||
if err != nil {
|
||||
if isTimeoutErr(err) {
|
||||
if last.StatusCode >= 400 && !p.Payload.AcceptAnyStatus && !sawSuccess {
|
||||
return last, conn, fmt.Errorf("payload rejected with final status %d", last.StatusCode)
|
||||
}
|
||||
return last, wrapConnWithPreloadedBytes(conn, captured.Bytes()), nil
|
||||
}
|
||||
if err == io.EOF && captured.Len() > 0 {
|
||||
return last, wrapConnWithPreloadedBytes(conn, captured.Bytes()), nil
|
||||
}
|
||||
if last.StatusCode > 0 {
|
||||
return last, conn, nil
|
||||
}
|
||||
return PayloadResult{}, conn, fmt.Errorf("payload response read failed: %w", err)
|
||||
}
|
||||
|
||||
cleanLine := strings.TrimSpace(line)
|
||||
if cleanLine == "" {
|
||||
return last, wrapConnWithPreloadedBytes(conn, captured.Bytes()), nil
|
||||
}
|
||||
|
||||
if strings.HasPrefix(cleanLine, "SSH-") || !isHTTPStatusLine(cleanLine) {
|
||||
return last, wrapConnWithPreloadedBytes(conn, captured.Bytes()), nil
|
||||
}
|
||||
|
||||
code := parseStatusCode(cleanLine)
|
||||
last = PayloadResult{StatusLine: cleanLine, StatusCode: code}
|
||||
logProxyStatus(logger, source, code, cleanLine)
|
||||
logHTTPCompatibilityStatus(logger, code, cleanLine)
|
||||
if code == 101 || (code >= 200 && code < 400) {
|
||||
sawSuccess = true
|
||||
}
|
||||
|
||||
// The current bytes are confirmed HTTP/proxy negotiation bytes. Do not replay
|
||||
// them to the SSH transport. Only replay bytes when we detect SSH/non-HTTP
|
||||
// data or a partial line after timeout.
|
||||
captured = nil
|
||||
if err := consumePayloadHTTPHeadersAndBody(conn); err != nil {
|
||||
if isTimeoutErr(err) {
|
||||
return last, conn, nil
|
||||
}
|
||||
return last, conn, fmt.Errorf("payload response consume failed: %w", err)
|
||||
}
|
||||
|
||||
// Keep peeking for another immediate HTTP status block. Some payload/proxy
|
||||
// chains return several statuses (for example 403 -> 403 -> 101). Returning
|
||||
// after the first status can make SSH read HTTP text instead of SSH-2.0.
|
||||
}
|
||||
|
||||
if last.StatusCode >= 400 && !p.Payload.AcceptAnyStatus && !sawSuccess {
|
||||
return last, conn, fmt.Errorf("payload rejected with final status %d", last.StatusCode)
|
||||
}
|
||||
return last, conn, nil
|
||||
}
|
||||
|
||||
func setPayloadReadDeadline(conn net.Conn, p config.Profile, attempt int) {
|
||||
timeoutMs := payloadHTTPStatusPeekTimeoutMs
|
||||
if attempt == 0 && p.Payload.ResponseTimeoutMs > 0 {
|
||||
timeoutMs = p.Payload.ResponseTimeoutMs
|
||||
if timeoutMs < payloadHTTPStatusPeekTimeoutMs {
|
||||
timeoutMs = payloadHTTPStatusPeekTimeoutMs
|
||||
}
|
||||
}
|
||||
_ = conn.SetReadDeadline(time.Now().Add(time.Duration(timeoutMs) * time.Millisecond))
|
||||
}
|
||||
|
||||
func readPayloadLinePreserveBytes(conn net.Conn, captured *bytes.Buffer, limit int) (string, error) {
|
||||
var line bytes.Buffer
|
||||
buf := make([]byte, 1)
|
||||
for line.Len() < limit {
|
||||
n, err := conn.Read(buf)
|
||||
if n > 0 {
|
||||
b := buf[0]
|
||||
_ = captured.WriteByte(b)
|
||||
_ = line.WriteByte(b)
|
||||
if b == '\n' {
|
||||
break
|
||||
}
|
||||
}
|
||||
if err != nil {
|
||||
if line.Len() > 0 && err == io.EOF {
|
||||
return line.String(), nil
|
||||
}
|
||||
return "", err
|
||||
}
|
||||
if n == 0 {
|
||||
continue
|
||||
}
|
||||
}
|
||||
if line.Len() == 0 {
|
||||
return "", io.EOF
|
||||
}
|
||||
return line.String(), nil
|
||||
}
|
||||
|
||||
func consumePayloadHTTPHeadersAndBody(conn net.Conn) error {
|
||||
info := httpResponseInfo{contentLength: -1}
|
||||
for i := 0; i < maxPayloadHTTPHeaderLines; i++ {
|
||||
ignored := &bytes.Buffer{}
|
||||
line, err := readPayloadLinePreserveBytes(conn, ignored, payloadHTTPStatusLineLimit)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
clean := strings.TrimSpace(line)
|
||||
if clean == "" {
|
||||
break
|
||||
}
|
||||
lower := strings.ToLower(clean)
|
||||
if strings.HasPrefix(lower, "content-length:") {
|
||||
if n, err := strconv.ParseInt(strings.TrimSpace(clean[strings.Index(clean, ":")+1:]), 10, 64); err == nil {
|
||||
info.contentLength = n
|
||||
}
|
||||
} else if strings.HasPrefix(lower, "transfer-encoding:") && strings.Contains(lower, "chunked") {
|
||||
info.chunked = true
|
||||
}
|
||||
}
|
||||
if info.chunked {
|
||||
return discardPayloadChunkedBody(conn)
|
||||
}
|
||||
if info.contentLength > 0 {
|
||||
return discardPayloadFixedLengthBody(conn, info.contentLength)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func discardPayloadFixedLengthBody(conn net.Conn, contentLength int64) error {
|
||||
remaining := contentLength
|
||||
if remaining > maxPayloadHTTPBodyDiscardBytes {
|
||||
remaining = maxPayloadHTTPBodyDiscardBytes
|
||||
}
|
||||
buf := make([]byte, 4096)
|
||||
for remaining > 0 {
|
||||
toRead := int64(len(buf))
|
||||
if remaining < toRead {
|
||||
toRead = remaining
|
||||
}
|
||||
n, err := conn.Read(buf[:int(toRead)])
|
||||
if n > 0 {
|
||||
remaining -= int64(n)
|
||||
}
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func discardPayloadChunkedBody(conn net.Conn) error {
|
||||
for i := 0; i < maxPayloadHTTPHeaderLines; i++ {
|
||||
ignored := &bytes.Buffer{}
|
||||
sizeLine, err := readPayloadLinePreserveBytes(conn, ignored, payloadHTTPStatusLineLimit)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
cleanSize := strings.TrimSpace(sizeLine)
|
||||
if semi := strings.Index(cleanSize, ";"); semi >= 0 {
|
||||
cleanSize = strings.TrimSpace(cleanSize[:semi])
|
||||
}
|
||||
chunkSize, err := strconv.ParseInt(cleanSize, 16, 64)
|
||||
if err != nil {
|
||||
return nil
|
||||
}
|
||||
if chunkSize == 0 {
|
||||
return consumePayloadTrailingHeaders(conn)
|
||||
}
|
||||
if err := discardPayloadFixedLengthBody(conn, chunkSize); err != nil {
|
||||
return err
|
||||
}
|
||||
crlf := &bytes.Buffer{}
|
||||
_, _ = readPayloadLinePreserveBytes(conn, crlf, payloadHTTPStatusLineLimit)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func consumePayloadTrailingHeaders(conn net.Conn) error {
|
||||
for i := 0; i < maxPayloadHTTPHeaderLines; i++ {
|
||||
ignored := &bytes.Buffer{}
|
||||
line, err := readPayloadLinePreserveBytes(conn, ignored, payloadHTTPStatusLineLimit)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if strings.TrimSpace(line) == "" {
|
||||
return nil
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func isTimeoutErr(err error) bool {
|
||||
if err == nil {
|
||||
return false
|
||||
}
|
||||
if ne, ok := err.(net.Error); ok && ne.Timeout() {
|
||||
return true
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func isHTTPStatusLine(statusLine string) bool {
|
||||
clean := strings.ToUpper(strings.TrimSpace(statusLine))
|
||||
return strings.HasPrefix(clean, "HTTP/1.") || strings.HasPrefix(clean, "HTTP/2") || strings.HasPrefix(clean, "HTTP/3")
|
||||
}
|
||||
|
||||
func logProxyStatus(logger *Logger, source string, responseCode int, statusLine string) {
|
||||
cleanLine := strings.TrimSpace(statusLine)
|
||||
if cleanLine == "" {
|
||||
return
|
||||
}
|
||||
if source == "" {
|
||||
source = "PROXY"
|
||||
}
|
||||
logger.Add("info", "Proxy Status [%s]: %s", source, cleanLine)
|
||||
}
|
||||
|
||||
func logHTTPCompatibilityStatus(logger *Logger, responseCode int, firstLine string) {
|
||||
switch responseCode {
|
||||
case 200:
|
||||
logger.Add("info", "Status: 200 (Connection established) Successful")
|
||||
case 101:
|
||||
logger.Add("info", "replace 200 OK")
|
||||
logger.Add("info", "HTTP/1.1 101 Websocket")
|
||||
case 100:
|
||||
logger.Add("info", "HTTP/1.1 100 Continue")
|
||||
case 301, 302, 400, 401, 403, 404, 407, 429, 500, 502, 503, 504:
|
||||
if strings.TrimSpace(firstLine) != "" {
|
||||
logger.Add("info", "%s", strings.TrimSpace(firstLine))
|
||||
} else {
|
||||
logger.Add("info", "HTTP/1.1 %d", responseCode)
|
||||
}
|
||||
logger.Add("info", "replace 200 OK")
|
||||
logger.Add("info", "Dragon Try!")
|
||||
}
|
||||
}
|
||||
|
||||
func payloadSourceLabel(mode config.Mode) string {
|
||||
switch mode {
|
||||
case config.ModePayload:
|
||||
return "HTTP_PROXY"
|
||||
case config.ModePayloadSSL:
|
||||
return "SSL_PAYLOAD"
|
||||
default:
|
||||
return "PROXY"
|
||||
}
|
||||
}
|
||||
|
||||
func buildPayload(tpl, host string, port int) string {
|
||||
portStr := strconv.Itoa(port)
|
||||
repl := map[string]string{
|
||||
"[host]": host,
|
||||
"[port]": portStr,
|
||||
"[host_port]": net.JoinHostPort(host, portStr),
|
||||
"[crlf]": "\r\n",
|
||||
"[cr]": "\r",
|
||||
"[lf]": "\n",
|
||||
"[protocol]": "HTTP/1.1",
|
||||
"[method]": "CONNECT",
|
||||
}
|
||||
out := tpl
|
||||
for k, v := range repl {
|
||||
out = strings.ReplaceAll(out, k, v)
|
||||
}
|
||||
out = replaceRotate(out)
|
||||
return out
|
||||
}
|
||||
|
||||
func replaceRotate(s string) string {
|
||||
for {
|
||||
start := strings.Index(s, "[rotate=")
|
||||
if start < 0 {
|
||||
return s
|
||||
}
|
||||
end := strings.Index(s[start:], "]")
|
||||
if end < 0 {
|
||||
return s
|
||||
}
|
||||
end += start
|
||||
body := strings.TrimPrefix(s[start:end+1], "[rotate=")
|
||||
body = strings.TrimSuffix(body, "]")
|
||||
choices := splitRotateChoices(body)
|
||||
choice := ""
|
||||
if len(choices) > 0 {
|
||||
choice = strings.TrimSpace(choices[rand.Intn(len(choices))])
|
||||
}
|
||||
s = s[:start] + choice + s[end+1:]
|
||||
}
|
||||
}
|
||||
|
||||
func splitRotateChoices(body string) []string {
|
||||
return strings.FieldsFunc(body, func(r rune) bool {
|
||||
switch r {
|
||||
case ';', '#', ',', '\n', '\r', '\t':
|
||||
return true
|
||||
default:
|
||||
return false
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func splitPayload(s string) ([]string, bool) {
|
||||
instant := strings.Contains(s, "[instant_split]")
|
||||
s = strings.ReplaceAll(s, "[instant_split]", "[split]")
|
||||
parts := strings.Split(s, "[split]")
|
||||
return parts, instant
|
||||
}
|
||||
|
||||
func parseStatusCode(status string) int {
|
||||
fields := strings.Fields(status)
|
||||
if len(fields) < 2 {
|
||||
return 0
|
||||
}
|
||||
code, _ := strconv.Atoi(fields[1])
|
||||
return code
|
||||
}
|
||||
91
internal/engine/process.go
Normal file
91
internal/engine/process.go
Normal file
@@ -0,0 +1,91 @@
|
||||
package engine
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"context"
|
||||
"errors"
|
||||
"io"
|
||||
"os/exec"
|
||||
"path/filepath"
|
||||
"runtime"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"socksrevivepc/internal/oscmd"
|
||||
)
|
||||
|
||||
type ManagedProcess struct {
|
||||
cmd *exec.Cmd
|
||||
name string
|
||||
logger *Logger
|
||||
mu sync.Mutex
|
||||
done chan error
|
||||
}
|
||||
|
||||
func StartProcess(ctx context.Context, root, name, exe string, args []string, logger *Logger) (*ManagedProcess, error) {
|
||||
if strings.TrimSpace(exe) == "" {
|
||||
return nil, errors.New(name + " executable path is empty")
|
||||
}
|
||||
if !filepath.IsAbs(exe) {
|
||||
exe = filepath.Join(root, exe)
|
||||
}
|
||||
cmd := oscmd.CommandContext(ctx, exe, args...)
|
||||
cmd.Dir = root
|
||||
stdout, _ := cmd.StdoutPipe()
|
||||
stderr, _ := cmd.StderrPipe()
|
||||
p := &ManagedProcess{cmd: cmd, name: name, logger: logger, done: make(chan error, 1)}
|
||||
if err := cmd.Start(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
logger.Add("info", "%s started: %s %s", name, exe, strings.Join(args, " "))
|
||||
go p.pipe(stdout, "info")
|
||||
go p.pipe(stderr, "warn")
|
||||
go func() {
|
||||
err := cmd.Wait()
|
||||
p.done <- err
|
||||
if err != nil {
|
||||
logger.Add("warn", "%s stopped: %v", name, err)
|
||||
} else {
|
||||
logger.Add("info", "%s stopped", name)
|
||||
}
|
||||
}()
|
||||
return p, nil
|
||||
}
|
||||
|
||||
func (p *ManagedProcess) Exited() (bool, error) {
|
||||
if p == nil || p.done == nil {
|
||||
return true, nil
|
||||
}
|
||||
select {
|
||||
case err := <-p.done:
|
||||
return true, err
|
||||
default:
|
||||
return false, nil
|
||||
}
|
||||
}
|
||||
|
||||
func (p *ManagedProcess) pipe(r io.Reader, level string) {
|
||||
s := bufio.NewScanner(r)
|
||||
for s.Scan() {
|
||||
line := strings.TrimSpace(s.Text())
|
||||
if line != "" {
|
||||
p.logger.Add(level, "%s: %s", p.name, line)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (p *ManagedProcess) Stop() {
|
||||
p.mu.Lock()
|
||||
defer p.mu.Unlock()
|
||||
if p == nil || p.cmd == nil || p.cmd.Process == nil {
|
||||
return
|
||||
}
|
||||
if runtime.GOOS == "windows" {
|
||||
_ = p.cmd.Process.Kill()
|
||||
} else {
|
||||
_ = p.cmd.Process.Signal(ioSignalInterrupt())
|
||||
time.Sleep(500 * time.Millisecond)
|
||||
_ = p.cmd.Process.Kill()
|
||||
}
|
||||
}
|
||||
7
internal/engine/signal_unix.go
Normal file
7
internal/engine/signal_unix.go
Normal file
@@ -0,0 +1,7 @@
|
||||
//go:build !windows
|
||||
|
||||
package engine
|
||||
|
||||
import "os"
|
||||
|
||||
func ioSignalInterrupt() os.Signal { return os.Interrupt }
|
||||
7
internal/engine/signal_windows.go
Normal file
7
internal/engine/signal_windows.go
Normal file
@@ -0,0 +1,7 @@
|
||||
//go:build windows
|
||||
|
||||
package engine
|
||||
|
||||
import "os"
|
||||
|
||||
func ioSignalInterrupt() os.Signal { return os.Interrupt }
|
||||
462
internal/engine/socks5.go
Normal file
462
internal/engine/socks5.go
Normal file
@@ -0,0 +1,462 @@
|
||||
package engine
|
||||
|
||||
import (
|
||||
"encoding/binary"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
"strconv"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"golang.org/x/crypto/ssh"
|
||||
|
||||
"socksrevivepc/internal/config"
|
||||
)
|
||||
|
||||
const (
|
||||
socksVersion = 0x05
|
||||
socksCmdConnect = 0x01
|
||||
socksCmdUDPAssoc = 0x03
|
||||
socksAtypIPv4 = 0x01
|
||||
socksAtypDomain = 0x03
|
||||
socksAtypIPv6 = 0x04
|
||||
socksRepOK = 0x00
|
||||
socksRepFail = 0x01
|
||||
socksRepUnsupported = 0x07
|
||||
)
|
||||
|
||||
type SocksServer struct {
|
||||
Addr string
|
||||
SSH *ssh.Client
|
||||
Logger *Logger
|
||||
DNS []string
|
||||
UDPGW config.UDPGWConfig
|
||||
listener net.Listener
|
||||
stopOnce sync.Once
|
||||
}
|
||||
|
||||
type socksRequest struct {
|
||||
Command byte
|
||||
Host string
|
||||
Port int
|
||||
}
|
||||
|
||||
func (r socksRequest) Addr() string {
|
||||
return net.JoinHostPort(r.Host, strconv.Itoa(r.Port))
|
||||
}
|
||||
|
||||
func (s *SocksServer) Start() error {
|
||||
ln, err := net.Listen("tcp", s.Addr)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
s.listener = ln
|
||||
s.Logger.Add("info", "local SOCKS5 listening on %s", s.Addr)
|
||||
go s.acceptLoop()
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *SocksServer) Stop() {
|
||||
s.stopOnce.Do(func() {
|
||||
if s.listener != nil {
|
||||
_ = s.listener.Close()
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func (s *SocksServer) acceptLoop() {
|
||||
for {
|
||||
c, err := s.listener.Accept()
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
go s.handle(c)
|
||||
}
|
||||
}
|
||||
|
||||
func (s *SocksServer) handle(c net.Conn) {
|
||||
defer c.Close()
|
||||
_ = c.SetDeadline(time.Now().Add(30 * time.Second))
|
||||
if err := s.handshake(c); err != nil {
|
||||
if !isExpectedProbeError(err) {
|
||||
s.Logger.Add("debug", "socks handshake failed: %v", err)
|
||||
}
|
||||
return
|
||||
}
|
||||
req, err := readSocksRequest(c)
|
||||
if err != nil {
|
||||
s.Logger.Add("debug", "socks request failed: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
switch req.Command {
|
||||
case socksCmdConnect:
|
||||
s.handleConnect(c, req)
|
||||
case socksCmdUDPAssoc:
|
||||
if s.UDPGW.Enabled {
|
||||
s.handleUDPGWAssociate(c)
|
||||
} else {
|
||||
s.handleDNSUDPAssociate(c)
|
||||
}
|
||||
default:
|
||||
_ = writeReply(c, socksRepUnsupported)
|
||||
s.Logger.Add("debug", "socks unsupported command %d", req.Command)
|
||||
}
|
||||
}
|
||||
|
||||
func (s *SocksServer) handleConnect(c net.Conn, req socksRequest) {
|
||||
dest := req.Addr()
|
||||
remote, err := dialSSHDirectTCP(s.SSH, req, c.RemoteAddr(), s.Logger)
|
||||
if err != nil {
|
||||
_ = writeReply(c, 0x05)
|
||||
s.Logger.Add("warn", "ssh dial %s failed: %v", dest, err)
|
||||
return
|
||||
}
|
||||
defer remote.Close()
|
||||
_ = writeReply(c, socksRepOK)
|
||||
_ = c.SetDeadline(time.Time{})
|
||||
s.Logger.Add("debug", "socks connected %s", dest)
|
||||
proxyCopy(c, remote)
|
||||
}
|
||||
|
||||
func (s *SocksServer) handleDNSUDPAssociate(c net.Conn) {
|
||||
udp, err := net.ListenUDP("udp", &net.UDPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 0})
|
||||
if err != nil {
|
||||
_ = writeReply(c, socksRepFail)
|
||||
s.Logger.Add("warn", "socks UDP associate failed: %v", err)
|
||||
return
|
||||
}
|
||||
defer udp.Close()
|
||||
|
||||
if err := writeReplyWithAddr(c, socksRepOK, udp.LocalAddr().(*net.UDPAddr)); err != nil {
|
||||
s.Logger.Add("debug", "socks UDP associate reply failed: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
_ = c.SetDeadline(time.Time{})
|
||||
s.Logger.Add("info", "local SOCKS5 UDP associate listening on %s for DNS over SSH", udp.LocalAddr().String())
|
||||
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
_, _ = io.Copy(io.Discard, c)
|
||||
close(done)
|
||||
}()
|
||||
|
||||
buf := make([]byte, 64*1024)
|
||||
for {
|
||||
_ = udp.SetReadDeadline(time.Now().Add(500 * time.Millisecond))
|
||||
n, clientAddr, err := udp.ReadFromUDP(buf)
|
||||
if err != nil {
|
||||
select {
|
||||
case <-done:
|
||||
return
|
||||
default:
|
||||
}
|
||||
if ne, ok := err.(net.Error); ok && ne.Timeout() {
|
||||
continue
|
||||
}
|
||||
s.Logger.Add("debug", "socks UDP read failed: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
resp, err := s.handleSocksUDPDatagram(buf[:n])
|
||||
if err != nil {
|
||||
if shouldLogUDPError(err) {
|
||||
s.Logger.Add("debug", "%v", err)
|
||||
}
|
||||
continue
|
||||
}
|
||||
_, _ = udp.WriteToUDP(resp, clientAddr)
|
||||
}
|
||||
}
|
||||
|
||||
func (s *SocksServer) handleSocksUDPDatagram(packet []byte) ([]byte, error) {
|
||||
addr, payload, err := parseSocksUDP(packet)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if addr.Port != 53 {
|
||||
return nil, fmt.Errorf("dropping unsupported UDP target %s; only DNS/UDP port 53 is proxied for SSH modes", addr.Addr())
|
||||
}
|
||||
answer, err := s.resolveDNSOverSSHTCP(addr, payload)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("DNS over SSH failed for %s: %w", addr.Addr(), err)
|
||||
}
|
||||
return buildSocksUDP(addr, answer)
|
||||
}
|
||||
|
||||
func (s *SocksServer) resolveDNSOverSSHTCP(addr socksRequest, query []byte) ([]byte, error) {
|
||||
servers := s.dnsServers(addr)
|
||||
var lastErr error
|
||||
for _, server := range servers {
|
||||
remote, err := dialSSHDirectTCPAddr(s.SSH, server, nil, s.Logger)
|
||||
if err != nil {
|
||||
lastErr = err
|
||||
continue
|
||||
}
|
||||
_ = remote.SetDeadline(time.Now().Add(8 * time.Second))
|
||||
resp, err := exchangeDNSTCP(remote, query)
|
||||
_ = remote.Close()
|
||||
if err == nil {
|
||||
return resp, nil
|
||||
}
|
||||
lastErr = err
|
||||
}
|
||||
if lastErr != nil {
|
||||
return nil, lastErr
|
||||
}
|
||||
return nil, errors.New("no DNS server configured")
|
||||
}
|
||||
|
||||
func (s *SocksServer) dnsServers(requested socksRequest) []string {
|
||||
seen := map[string]bool{}
|
||||
var out []string
|
||||
add := func(v string) {
|
||||
v = normalizeHostPort(v, "53")
|
||||
if v == "" {
|
||||
return
|
||||
}
|
||||
if !seen[v] {
|
||||
seen[v] = true
|
||||
out = append(out, v)
|
||||
}
|
||||
}
|
||||
if requested.Host != "" {
|
||||
add(net.JoinHostPort(requested.Host, strconv.Itoa(requested.Port)))
|
||||
}
|
||||
for _, dns := range s.DNS {
|
||||
add(dns)
|
||||
}
|
||||
add("1.1.1.1:53")
|
||||
add("8.8.8.8:53")
|
||||
return out
|
||||
}
|
||||
|
||||
func normalizeHostPort(v, defaultPort string) string {
|
||||
v = strings.TrimSpace(v)
|
||||
if v == "" {
|
||||
return ""
|
||||
}
|
||||
if host, port, err := net.SplitHostPort(v); err == nil {
|
||||
return net.JoinHostPort(host, port)
|
||||
}
|
||||
if ip := net.ParseIP(v); ip != nil {
|
||||
return net.JoinHostPort(ip.String(), defaultPort)
|
||||
}
|
||||
if strings.Count(v, ":") == 1 {
|
||||
host, port, ok := strings.Cut(v, ":")
|
||||
if ok && host != "" && port != "" {
|
||||
return net.JoinHostPort(host, port)
|
||||
}
|
||||
}
|
||||
return net.JoinHostPort(v, defaultPort)
|
||||
}
|
||||
|
||||
func exchangeDNSTCP(conn net.Conn, query []byte) ([]byte, error) {
|
||||
if len(query) == 0 || len(query) > 65535 {
|
||||
return nil, fmt.Errorf("invalid DNS query length %d", len(query))
|
||||
}
|
||||
prefix := make([]byte, 2)
|
||||
binary.BigEndian.PutUint16(prefix, uint16(len(query)))
|
||||
if _, err := conn.Write(append(prefix, query...)); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if _, err := io.ReadFull(conn, prefix); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
ln := int(binary.BigEndian.Uint16(prefix))
|
||||
if ln <= 0 || ln > 65535 {
|
||||
return nil, fmt.Errorf("invalid DNS response length %d", ln)
|
||||
}
|
||||
resp := make([]byte, ln)
|
||||
if _, err := io.ReadFull(conn, resp); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return resp, nil
|
||||
}
|
||||
|
||||
func (s *SocksServer) handshake(c net.Conn) error {
|
||||
header := make([]byte, 2)
|
||||
if _, err := io.ReadFull(c, header); err != nil {
|
||||
return err
|
||||
}
|
||||
if header[0] != socksVersion {
|
||||
return errors.New("not socks5")
|
||||
}
|
||||
methods := make([]byte, int(header[1]))
|
||||
if _, err := io.ReadFull(c, methods); err != nil {
|
||||
return err
|
||||
}
|
||||
_, err := c.Write([]byte{socksVersion, 0x00})
|
||||
return err
|
||||
}
|
||||
|
||||
func readSocksRequest(c net.Conn) (socksRequest, error) {
|
||||
var req socksRequest
|
||||
h := make([]byte, 4)
|
||||
if _, err := io.ReadFull(c, h); err != nil {
|
||||
return req, err
|
||||
}
|
||||
if h[0] != socksVersion {
|
||||
return req, fmt.Errorf("invalid socks version %d", h[0])
|
||||
}
|
||||
req.Command = h[1]
|
||||
host, err := readSocksHost(c, h[3])
|
||||
if err != nil {
|
||||
return req, err
|
||||
}
|
||||
pb := make([]byte, 2)
|
||||
if _, err := io.ReadFull(c, pb); err != nil {
|
||||
return req, err
|
||||
}
|
||||
req.Host = host
|
||||
req.Port = int(binary.BigEndian.Uint16(pb))
|
||||
return req, nil
|
||||
}
|
||||
|
||||
func readSocksHost(r io.Reader, atyp byte) (string, error) {
|
||||
switch atyp {
|
||||
case socksAtypIPv4:
|
||||
b := make([]byte, 4)
|
||||
if _, err := io.ReadFull(r, b); err != nil {
|
||||
return "", err
|
||||
}
|
||||
return net.IP(b).String(), nil
|
||||
case socksAtypDomain:
|
||||
l := []byte{0}
|
||||
if _, err := io.ReadFull(r, l); err != nil {
|
||||
return "", err
|
||||
}
|
||||
b := make([]byte, int(l[0]))
|
||||
if _, err := io.ReadFull(r, b); err != nil {
|
||||
return "", err
|
||||
}
|
||||
return string(b), nil
|
||||
case socksAtypIPv6:
|
||||
b := make([]byte, 16)
|
||||
if _, err := io.ReadFull(r, b); err != nil {
|
||||
return "", err
|
||||
}
|
||||
return net.IP(b).String(), nil
|
||||
default:
|
||||
return "", fmt.Errorf("unsupported address type %d", atyp)
|
||||
}
|
||||
}
|
||||
|
||||
func parseSocksUDP(packet []byte) (socksRequest, []byte, error) {
|
||||
var req socksRequest
|
||||
if len(packet) < 4 {
|
||||
return req, nil, errors.New("short socks UDP packet")
|
||||
}
|
||||
if packet[0] != 0 || packet[1] != 0 {
|
||||
return req, nil, errors.New("invalid socks UDP reserved field")
|
||||
}
|
||||
if packet[2] != 0 {
|
||||
return req, nil, errors.New("fragmented socks UDP packets are not supported")
|
||||
}
|
||||
atyp := packet[3]
|
||||
off := 4
|
||||
switch atyp {
|
||||
case socksAtypIPv4:
|
||||
if len(packet) < off+4+2 {
|
||||
return req, nil, errors.New("short socks UDP ipv4 packet")
|
||||
}
|
||||
req.Host = net.IP(packet[off : off+4]).String()
|
||||
off += 4
|
||||
case socksAtypDomain:
|
||||
if len(packet) < off+1 {
|
||||
return req, nil, errors.New("short socks UDP domain packet")
|
||||
}
|
||||
ln := int(packet[off])
|
||||
off++
|
||||
if len(packet) < off+ln+2 {
|
||||
return req, nil, errors.New("short socks UDP domain payload")
|
||||
}
|
||||
req.Host = string(packet[off : off+ln])
|
||||
off += ln
|
||||
case socksAtypIPv6:
|
||||
if len(packet) < off+16+2 {
|
||||
return req, nil, errors.New("short socks UDP ipv6 packet")
|
||||
}
|
||||
req.Host = net.IP(packet[off : off+16]).String()
|
||||
off += 16
|
||||
default:
|
||||
return req, nil, fmt.Errorf("unsupported socks UDP address type %d", atyp)
|
||||
}
|
||||
req.Port = int(binary.BigEndian.Uint16(packet[off : off+2]))
|
||||
off += 2
|
||||
return req, packet[off:], nil
|
||||
}
|
||||
|
||||
func buildSocksUDP(addr socksRequest, payload []byte) ([]byte, error) {
|
||||
var out []byte
|
||||
out = append(out, 0, 0, 0)
|
||||
ip := net.ParseIP(addr.Host)
|
||||
if v4 := ip.To4(); v4 != nil {
|
||||
out = append(out, socksAtypIPv4)
|
||||
out = append(out, v4...)
|
||||
} else if v6 := ip.To16(); v6 != nil {
|
||||
out = append(out, socksAtypIPv6)
|
||||
out = append(out, v6...)
|
||||
} else {
|
||||
if len(addr.Host) > 255 {
|
||||
return nil, errors.New("socks UDP domain is too long")
|
||||
}
|
||||
out = append(out, socksAtypDomain, byte(len(addr.Host)))
|
||||
out = append(out, []byte(addr.Host)...)
|
||||
}
|
||||
pb := make([]byte, 2)
|
||||
binary.BigEndian.PutUint16(pb, uint16(addr.Port))
|
||||
out = append(out, pb...)
|
||||
out = append(out, payload...)
|
||||
return out, nil
|
||||
}
|
||||
|
||||
func writeReply(c net.Conn, code byte) error {
|
||||
return writeReplyWithAddr(c, code, &net.UDPAddr{IP: net.IPv4zero, Port: 0})
|
||||
}
|
||||
|
||||
func writeReplyWithAddr(c net.Conn, code byte, addr *net.UDPAddr) error {
|
||||
if addr == nil {
|
||||
addr = &net.UDPAddr{IP: net.IPv4zero, Port: 0}
|
||||
}
|
||||
ip := addr.IP
|
||||
if ip == nil || ip.IsUnspecified() {
|
||||
ip = net.IPv4(127, 0, 0, 1)
|
||||
}
|
||||
var resp []byte
|
||||
resp = append(resp, socksVersion, code, 0x00)
|
||||
if v4 := ip.To4(); v4 != nil {
|
||||
resp = append(resp, socksAtypIPv4)
|
||||
resp = append(resp, v4...)
|
||||
} else if v6 := ip.To16(); v6 != nil {
|
||||
resp = append(resp, socksAtypIPv6)
|
||||
resp = append(resp, v6...)
|
||||
} else {
|
||||
resp = append(resp, socksAtypIPv4, 127, 0, 0, 1)
|
||||
}
|
||||
pb := make([]byte, 2)
|
||||
binary.BigEndian.PutUint16(pb, uint16(addr.Port))
|
||||
resp = append(resp, pb...)
|
||||
_, err := c.Write(resp)
|
||||
return err
|
||||
}
|
||||
|
||||
func proxyCopy(a net.Conn, b net.Conn) {
|
||||
done := make(chan struct{}, 2)
|
||||
go func() { _, _ = io.Copy(a, b); done <- struct{}{} }()
|
||||
go func() { _, _ = io.Copy(b, a); done <- struct{}{} }()
|
||||
<-done
|
||||
}
|
||||
|
||||
func isExpectedProbeError(err error) bool {
|
||||
return errors.Is(err, io.EOF) || strings.Contains(err.Error(), "connection reset")
|
||||
}
|
||||
|
||||
func shouldLogUDPError(err error) bool {
|
||||
msg := err.Error()
|
||||
return !strings.Contains(msg, "unsupported UDP target")
|
||||
}
|
||||
151
internal/engine/ssh_direct_tcp.go
Normal file
151
internal/engine/ssh_direct_tcp.go
Normal file
@@ -0,0 +1,151 @@
|
||||
package engine
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"golang.org/x/crypto/ssh"
|
||||
)
|
||||
|
||||
type directTCPIPPayload struct {
|
||||
DestAddr string
|
||||
DestPort uint32
|
||||
OriginAddr string
|
||||
OriginPort uint32
|
||||
}
|
||||
|
||||
type sshChannelConn struct {
|
||||
ssh.Channel
|
||||
local net.Addr
|
||||
remote net.Addr
|
||||
}
|
||||
|
||||
func (c *sshChannelConn) LocalAddr() net.Addr {
|
||||
return c.local
|
||||
}
|
||||
|
||||
func (c *sshChannelConn) RemoteAddr() net.Addr {
|
||||
return c.remote
|
||||
}
|
||||
|
||||
func (c *sshChannelConn) SetDeadline(time.Time) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *sshChannelConn) SetReadDeadline(time.Time) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *sshChannelConn) SetWriteDeadline(time.Time) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func dialSSHDirectTCP(client *ssh.Client, dest socksRequest, origin net.Addr, logger *Logger) (net.Conn, error) {
|
||||
addr := dest.Addr()
|
||||
remote, err := client.Dial("tcp", addr)
|
||||
if err == nil {
|
||||
return remote, nil
|
||||
}
|
||||
|
||||
if !shouldRetrySSHIPv6Bracket(dest.Host, err) {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if logger != nil {
|
||||
logger.Add("debug", "retrying SSH direct-tcpip IPv6 target with bracketed host [%s]:%d", strings.Trim(dest.Host, "[]"), dest.Port)
|
||||
}
|
||||
remote, retryErr := openSSHDirectTCPIP(client, bracketIPv6Host(dest.Host), dest.Port, origin)
|
||||
if retryErr == nil {
|
||||
return remote, nil
|
||||
}
|
||||
return nil, fmt.Errorf("%w; bracketed IPv6 retry failed: %v", err, retryErr)
|
||||
}
|
||||
|
||||
func dialSSHDirectTCPAddr(client *ssh.Client, addr string, origin net.Addr, logger *Logger) (net.Conn, error) {
|
||||
host, portText, err := net.SplitHostPort(addr)
|
||||
if err != nil {
|
||||
return client.Dial("tcp", addr)
|
||||
}
|
||||
port, err := strconv.Atoi(portText)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return dialSSHDirectTCP(client, socksRequest{Host: host, Port: port}, origin, logger)
|
||||
}
|
||||
|
||||
func openSSHDirectTCPIP(client *ssh.Client, destHost string, destPort int, origin net.Addr) (net.Conn, error) {
|
||||
originHost, originPort := splitOriginAddr(origin)
|
||||
payload := ssh.Marshal(directTCPIPPayload{
|
||||
DestAddr: destHost,
|
||||
DestPort: uint32(destPort),
|
||||
OriginAddr: originHost,
|
||||
OriginPort: uint32(originPort),
|
||||
})
|
||||
ch, reqs, err := client.OpenChannel("direct-tcpip", payload)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
go ssh.DiscardRequests(reqs)
|
||||
return &sshChannelConn{
|
||||
Channel: ch,
|
||||
local: tcpAddrOrDummy(originHost, originPort),
|
||||
remote: tcpAddrOrDummy(destHost, destPort),
|
||||
}, nil
|
||||
}
|
||||
|
||||
func shouldRetrySSHIPv6Bracket(host string, err error) bool {
|
||||
if err == nil || !isIPv6Literal(host) {
|
||||
return false
|
||||
}
|
||||
msg := strings.ToLower(err.Error())
|
||||
return strings.Contains(msg, "too many colons in address") || strings.Contains(msg, "missing port in address")
|
||||
}
|
||||
|
||||
func isIPv6Literal(host string) bool {
|
||||
host = strings.Trim(host, "[]")
|
||||
ip := net.ParseIP(host)
|
||||
return ip != nil && ip.To4() == nil && ip.To16() != nil
|
||||
}
|
||||
|
||||
func bracketIPv6Host(host string) string {
|
||||
host = strings.TrimSpace(host)
|
||||
if strings.HasPrefix(host, "[") && strings.HasSuffix(host, "]") {
|
||||
return host
|
||||
}
|
||||
return "[" + strings.Trim(host, "[]") + "]"
|
||||
}
|
||||
|
||||
func splitOriginAddr(addr net.Addr) (string, int) {
|
||||
if addr == nil {
|
||||
return "127.0.0.1", 0
|
||||
}
|
||||
host, portText, err := net.SplitHostPort(addr.String())
|
||||
if err != nil {
|
||||
return "127.0.0.1", 0
|
||||
}
|
||||
port, err := strconv.Atoi(portText)
|
||||
if err != nil {
|
||||
port = 0
|
||||
}
|
||||
if host == "" || host == "::" || host == "0.0.0.0" {
|
||||
host = "127.0.0.1"
|
||||
}
|
||||
return host, port
|
||||
}
|
||||
|
||||
func tcpAddrOrDummy(host string, port int) net.Addr {
|
||||
h := strings.Trim(host, "[]")
|
||||
ip := net.ParseIP(h)
|
||||
if ip != nil {
|
||||
return &net.TCPAddr{IP: ip, Port: port}
|
||||
}
|
||||
return dummyAddr(net.JoinHostPort(h, strconv.Itoa(port)))
|
||||
}
|
||||
|
||||
type dummyAddr string
|
||||
|
||||
func (d dummyAddr) Network() string { return "tcp" }
|
||||
func (d dummyAddr) String() string { return string(d) }
|
||||
345
internal/engine/sshclient.go
Normal file
345
internal/engine/sshclient.go
Normal file
@@ -0,0 +1,345 @@
|
||||
package engine
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/tls"
|
||||
"fmt"
|
||||
"net"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"golang.org/x/crypto/ssh"
|
||||
"socksrevivepc/internal/config"
|
||||
)
|
||||
|
||||
type sshBundle struct {
|
||||
Client *ssh.Client
|
||||
Conn net.Conn
|
||||
ControlHosts []string
|
||||
}
|
||||
|
||||
type transportAttempt struct {
|
||||
Label string
|
||||
ProxyHost string
|
||||
ProxyPort int
|
||||
TLSHost string
|
||||
TLSPort int
|
||||
ControlHost string
|
||||
}
|
||||
|
||||
func connectSSH(ctx context.Context, p config.Profile, logger *Logger) (*sshBundle, error) {
|
||||
targetHost := p.SSH.Host
|
||||
targetPort := p.SSH.Port
|
||||
if p.Mode == config.ModeDNSTT {
|
||||
targetHost = p.DNSTT.LocalSSHHost
|
||||
targetPort = p.DNSTT.LocalSSHPort
|
||||
}
|
||||
|
||||
attempts := buildTransportAttempts(p, targetHost, targetPort)
|
||||
if len(attempts) == 0 {
|
||||
attempts = []transportAttempt{{Label: "default"}}
|
||||
}
|
||||
|
||||
logger.Add("info", "connecting SSH %s:%d using mode %s", targetHost, targetPort, p.Mode)
|
||||
var lastErr error
|
||||
for i, attempt := range attempts {
|
||||
if ctx.Err() != nil {
|
||||
return nil, ctx.Err()
|
||||
}
|
||||
if len(attempts) > 1 {
|
||||
logger.Add("info", "connection attempt %d/%d via %s", i+1, len(attempts), attempt.Label)
|
||||
}
|
||||
|
||||
conn, err := dialTransportAttempt(ctx, p, targetHost, targetPort, attempt, logger)
|
||||
if err != nil {
|
||||
lastErr = err
|
||||
if len(attempts) > 1 {
|
||||
logger.Add("warn", "connection attempt %d/%d failed before SSH handshake: %v", i+1, len(attempts), err)
|
||||
}
|
||||
continue
|
||||
}
|
||||
|
||||
bundle, err := finishSSHHandshake(conn, p, targetHost, targetPort, logger)
|
||||
if err == nil {
|
||||
bundle.ControlHosts = controlHostsForAttempt(p, targetHost, attempt)
|
||||
if len(attempts) > 1 {
|
||||
logger.Add("info", "connection attempt %d/%d succeeded via %s", i+1, len(attempts), attempt.Label)
|
||||
}
|
||||
return bundle, nil
|
||||
}
|
||||
_ = conn.Close()
|
||||
lastErr = err
|
||||
if len(attempts) > 1 {
|
||||
logger.Add("warn", "connection attempt %d/%d failed during SSH handshake: %v", i+1, len(attempts), err)
|
||||
}
|
||||
}
|
||||
|
||||
if lastErr == nil {
|
||||
lastErr = fmt.Errorf("no transport attempts were available")
|
||||
}
|
||||
if len(attempts) > 1 {
|
||||
return nil, fmt.Errorf("all %d connection attempts failed; last error: %w", len(attempts), lastErr)
|
||||
}
|
||||
return nil, lastErr
|
||||
}
|
||||
|
||||
func finishSSHHandshake(conn net.Conn, p config.Profile, targetHost string, targetPort int, logger *Logger) (*sshBundle, error) {
|
||||
sshCfg := &ssh.ClientConfig{
|
||||
User: p.SSH.Username,
|
||||
Auth: []ssh.AuthMethod{ssh.Password(p.SSH.Password)},
|
||||
HostKeyCallback: ssh.InsecureIgnoreHostKey(),
|
||||
Timeout: time.Duration(p.SSH.HandshakeTimeoutMs) * time.Millisecond,
|
||||
ClientVersion: "SSH-2.0-SocksRevivePC",
|
||||
}
|
||||
addr := net.JoinHostPort(targetHost, fmt.Sprint(targetPort))
|
||||
cc, chans, reqs, err := ssh.NewClientConn(conn, addr, sshCfg)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("ssh handshake failed: %w", err)
|
||||
}
|
||||
client := ssh.NewClient(cc, chans, reqs)
|
||||
logger.Add("info", "ssh authenticated as %s", p.SSH.Username)
|
||||
return &sshBundle{Client: client, Conn: conn}, nil
|
||||
}
|
||||
|
||||
func buildTransportAttempts(p config.Profile, targetHost string, targetPort int) []transportAttempt {
|
||||
switch p.Mode {
|
||||
case config.ModePayload:
|
||||
return payloadProxyAttempts(p, targetHost, targetPort)
|
||||
case config.ModeSSL, config.ModePayloadSSL:
|
||||
return tlsHostAttempts(p, targetHost, targetPort)
|
||||
default:
|
||||
return []transportAttempt{{Label: net.JoinHostPort(targetHost, fmt.Sprint(targetPort))}}
|
||||
}
|
||||
}
|
||||
|
||||
func controlHostsForAttempt(p config.Profile, targetHost string, attempt transportAttempt) []string {
|
||||
switch p.Mode {
|
||||
case config.ModePayload:
|
||||
if attempt.ProxyHost != "" {
|
||||
return []string{attempt.ProxyHost}
|
||||
}
|
||||
return []string{targetHost}
|
||||
case config.ModeSSL, config.ModePayloadSSL:
|
||||
if attempt.TLSHost != "" {
|
||||
return []string{attempt.TLSHost}
|
||||
}
|
||||
if p.TLS.Host != "" {
|
||||
host, _ := hostPortWithDefault(p.TLS.Host, p.TLS.Port)
|
||||
if host != "" {
|
||||
return []string{host}
|
||||
}
|
||||
}
|
||||
return []string{targetHost}
|
||||
default:
|
||||
return []string{targetHost}
|
||||
}
|
||||
}
|
||||
|
||||
func payloadProxyAttempts(p config.Profile, targetHost string, targetPort int) []transportAttempt {
|
||||
proxyText := strings.TrimSpace(p.Proxy.Host)
|
||||
if proxyText == "" || p.Proxy.Port <= 0 {
|
||||
return []transportAttempt{{Label: "direct payload transport"}}
|
||||
}
|
||||
|
||||
rawHosts := splitHostList(proxyText)
|
||||
if len(rawHosts) == 0 {
|
||||
return []transportAttempt{{Label: "direct payload transport"}}
|
||||
}
|
||||
out := make([]transportAttempt, 0, len(rawHosts))
|
||||
for _, raw := range rawHosts {
|
||||
host, port := hostPortWithDefault(raw, p.Proxy.Port)
|
||||
if host == "" || port <= 0 {
|
||||
continue
|
||||
}
|
||||
out = append(out, transportAttempt{ProxyHost: host, ProxyPort: port, Label: "proxy " + net.JoinHostPort(host, fmt.Sprint(port))})
|
||||
}
|
||||
if len(out) == 0 {
|
||||
return []transportAttempt{{Label: net.JoinHostPort(targetHost, fmt.Sprint(targetPort))}}
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
func tlsHostAttempts(p config.Profile, targetHost string, targetPort int) []transportAttempt {
|
||||
hostText := strings.TrimSpace(p.TLS.Host)
|
||||
defaultPort := p.TLS.Port
|
||||
if defaultPort <= 0 {
|
||||
defaultPort = targetPort
|
||||
}
|
||||
if hostText == "" {
|
||||
return []transportAttempt{{TLSHost: targetHost, TLSPort: defaultPort, Label: "TLS " + net.JoinHostPort(targetHost, fmt.Sprint(defaultPort))}}
|
||||
}
|
||||
|
||||
rawHosts := splitHostList(hostText)
|
||||
if len(rawHosts) == 0 {
|
||||
return []transportAttempt{{TLSHost: targetHost, TLSPort: defaultPort, Label: "TLS " + net.JoinHostPort(targetHost, fmt.Sprint(defaultPort))}}
|
||||
}
|
||||
out := make([]transportAttempt, 0, len(rawHosts))
|
||||
for _, raw := range rawHosts {
|
||||
host, port := hostPortWithDefault(raw, defaultPort)
|
||||
if host == "" || port <= 0 {
|
||||
continue
|
||||
}
|
||||
out = append(out, transportAttempt{TLSHost: host, TLSPort: port, Label: "TLS " + net.JoinHostPort(host, fmt.Sprint(port))})
|
||||
}
|
||||
if len(out) == 0 {
|
||||
return []transportAttempt{{TLSHost: targetHost, TLSPort: defaultPort, Label: "TLS " + net.JoinHostPort(targetHost, fmt.Sprint(defaultPort))}}
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
func dialTransport(ctx context.Context, p config.Profile, targetHost string, targetPort int, logger *Logger) (net.Conn, error) {
|
||||
attempts := buildTransportAttempts(p, targetHost, targetPort)
|
||||
if len(attempts) == 0 {
|
||||
attempts = []transportAttempt{{Label: "default"}}
|
||||
}
|
||||
return dialTransportAttempt(ctx, p, targetHost, targetPort, attempts[0], logger)
|
||||
}
|
||||
|
||||
func dialTransportAttempt(ctx context.Context, p config.Profile, targetHost string, targetPort int, attempt transportAttempt, logger *Logger) (net.Conn, error) {
|
||||
d := &net.Dialer{Timeout: time.Duration(p.SSH.HandshakeTimeoutMs) * time.Millisecond}
|
||||
addr := net.JoinHostPort(targetHost, fmt.Sprint(targetPort))
|
||||
|
||||
switch p.Mode {
|
||||
case config.ModeDirect, config.ModeDNSTT:
|
||||
return d.DialContext(ctx, "tcp", addr)
|
||||
case config.ModeSSL:
|
||||
return dialTLSAttempt(ctx, d, p, targetHost, targetPort, attempt)
|
||||
case config.ModePayload:
|
||||
connectHost, connectPort := targetHost, targetPort
|
||||
if attempt.ProxyHost != "" && attempt.ProxyPort > 0 {
|
||||
connectHost, connectPort = attempt.ProxyHost, attempt.ProxyPort
|
||||
} else if p.Proxy.Host != "" && p.Proxy.Port > 0 {
|
||||
connectHost, connectPort = p.Proxy.Host, p.Proxy.Port
|
||||
}
|
||||
logger.Add("debug", "payload transport dialing %s", net.JoinHostPort(connectHost, fmt.Sprint(connectPort)))
|
||||
conn, err := d.DialContext(ctx, "tcp", net.JoinHostPort(connectHost, fmt.Sprint(connectPort)))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if _, wrappedConn, err := WritePayload(conn, p, targetHost, targetPort, logger); err != nil {
|
||||
_ = conn.Close()
|
||||
return nil, err
|
||||
} else {
|
||||
conn = wrappedConn
|
||||
}
|
||||
return conn, nil
|
||||
case config.ModePayloadSSL:
|
||||
conn, err := dialTLSAttempt(ctx, d, p, targetHost, targetPort, attempt)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if _, wrappedConn, err := WritePayload(conn, p, targetHost, targetPort, logger); err != nil {
|
||||
_ = conn.Close()
|
||||
return nil, err
|
||||
} else {
|
||||
conn = wrappedConn
|
||||
}
|
||||
return conn, nil
|
||||
default:
|
||||
return nil, fmt.Errorf("unsupported mode %s", p.Mode)
|
||||
}
|
||||
}
|
||||
|
||||
func dialTLS(ctx context.Context, d *net.Dialer, p config.Profile, targetHost string, targetPort int) (net.Conn, error) {
|
||||
return dialTLSAttempt(ctx, d, p, targetHost, targetPort, transportAttempt{})
|
||||
}
|
||||
|
||||
func dialTLSAttempt(ctx context.Context, d *net.Dialer, p config.Profile, targetHost string, targetPort int, attempt transportAttempt) (net.Conn, error) {
|
||||
host := strings.TrimSpace(attempt.TLSHost)
|
||||
port := attempt.TLSPort
|
||||
if host == "" {
|
||||
host = p.TLS.Host
|
||||
}
|
||||
if port <= 0 {
|
||||
port = p.TLS.Port
|
||||
}
|
||||
if host == "" {
|
||||
host = targetHost
|
||||
}
|
||||
if port == 0 {
|
||||
port = targetPort
|
||||
}
|
||||
serverName := p.TLS.ServerName
|
||||
if serverName == "" {
|
||||
serverName = host
|
||||
}
|
||||
raw, err := d.DialContext(ctx, "tcp", net.JoinHostPort(host, fmt.Sprint(port)))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
cfg := &tls.Config{ServerName: serverName, InsecureSkipVerify: p.TLS.InsecureSkipVerify, MinVersion: tls.VersionTLS12}
|
||||
conn := tls.Client(raw, cfg)
|
||||
if err := conn.HandshakeContext(ctx); err != nil {
|
||||
_ = conn.Close()
|
||||
return nil, err
|
||||
}
|
||||
return conn, nil
|
||||
}
|
||||
|
||||
func splitHostList(s string) []string {
|
||||
s = strings.TrimSpace(s)
|
||||
if s == "" {
|
||||
return nil
|
||||
}
|
||||
parts := strings.FieldsFunc(s, func(r rune) bool {
|
||||
switch r {
|
||||
case '#', ',', ';', '\n', '\r', '\t':
|
||||
return true
|
||||
default:
|
||||
return false
|
||||
}
|
||||
})
|
||||
out := make([]string, 0, len(parts))
|
||||
seen := map[string]struct{}{}
|
||||
for _, part := range parts {
|
||||
part = strings.TrimSpace(part)
|
||||
if part == "" {
|
||||
continue
|
||||
}
|
||||
if _, ok := seen[part]; ok {
|
||||
continue
|
||||
}
|
||||
seen[part] = struct{}{}
|
||||
out = append(out, part)
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
func hostPortWithDefault(raw string, defaultPort int) (string, int) {
|
||||
raw = strings.TrimSpace(raw)
|
||||
if raw == "" {
|
||||
return "", defaultPort
|
||||
}
|
||||
if strings.Contains(raw, "://") {
|
||||
// These fields are host fields, not URLs. Strip the scheme if a user pastes one.
|
||||
if i := strings.Index(raw, "://"); i >= 0 {
|
||||
raw = raw[i+3:]
|
||||
}
|
||||
}
|
||||
if h, p, err := net.SplitHostPort(raw); err == nil {
|
||||
pi, _ := strconv.Atoi(p)
|
||||
if pi > 0 {
|
||||
return strings.Trim(h, "[]"), pi
|
||||
}
|
||||
}
|
||||
if strings.HasPrefix(raw, "[") && strings.Contains(raw, "]:") {
|
||||
if h, p, err := net.SplitHostPort(raw); err == nil {
|
||||
pi, _ := strconv.Atoi(p)
|
||||
if pi > 0 {
|
||||
return strings.Trim(h, "[]"), pi
|
||||
}
|
||||
}
|
||||
}
|
||||
// host:port for IPv4/domain. IPv6 without brackets has multiple colons and
|
||||
// must keep the profile/default port.
|
||||
if strings.Count(raw, ":") == 1 {
|
||||
host, portText, ok := strings.Cut(raw, ":")
|
||||
if ok {
|
||||
if pi, err := strconv.Atoi(strings.TrimSpace(portText)); err == nil && pi > 0 {
|
||||
return strings.TrimSpace(host), pi
|
||||
}
|
||||
}
|
||||
}
|
||||
return strings.Trim(raw, "[]"), defaultPort
|
||||
}
|
||||
353
internal/engine/udpgw_client.go
Normal file
353
internal/engine/udpgw_client.go
Normal file
@@ -0,0 +1,353 @@
|
||||
package engine
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"encoding/binary"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
"strconv"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
)
|
||||
|
||||
const (
|
||||
udpgwMaxFrame = 64 * 1024
|
||||
|
||||
udpgwProtocolBadVPN = "badvpn"
|
||||
udpgwProtocolLegacy = "legacy"
|
||||
|
||||
udpgwFlagKeepAlive = 1 << 0
|
||||
udpgwFlagRebind = 1 << 1
|
||||
udpgwFlagDNS = 1 << 2
|
||||
udpgwFlagIPv6 = 1 << 3
|
||||
)
|
||||
|
||||
type udpgwSession struct {
|
||||
mu sync.Mutex
|
||||
nextID uint16
|
||||
pairID map[string]uint16
|
||||
idClient map[uint16]*net.UDPAddr
|
||||
}
|
||||
|
||||
func (s *SocksServer) handleUDPGWAssociate(c net.Conn) {
|
||||
udp, err := net.ListenUDP("udp", &net.UDPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 0})
|
||||
if err != nil {
|
||||
_ = writeReply(c, socksRepFail)
|
||||
s.Logger.Add("warn", "socks UDP associate failed: %v", err)
|
||||
return
|
||||
}
|
||||
defer udp.Close()
|
||||
|
||||
gwAddr := net.JoinHostPort(s.UDPGW.Host, strconv.Itoa(s.UDPGW.Port))
|
||||
gwConn, err := dialSSHDirectTCPAddr(s.SSH, gwAddr, c.RemoteAddr(), s.Logger)
|
||||
if err != nil {
|
||||
_ = writeReply(c, socksRepFail)
|
||||
s.Logger.Add("warn", "udpgw dial through SSH failed %s: %v", gwAddr, err)
|
||||
return
|
||||
}
|
||||
defer gwConn.Close()
|
||||
|
||||
if err := writeReplyWithAddr(c, socksRepOK, udp.LocalAddr().(*net.UDPAddr)); err != nil {
|
||||
s.Logger.Add("debug", "socks UDP associate reply failed: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
_ = c.SetDeadline(time.Time{})
|
||||
proto := normalizeUDPGWProtocol(s.UDPGW.Protocol)
|
||||
s.Logger.Add("info", "SOCKS5 UDP associate using UDPGW %s over SSH; protocol=%s local UDP=%s", gwAddr, proto, udp.LocalAddr().String())
|
||||
|
||||
done := make(chan struct{})
|
||||
closeOnce := sync.Once{}
|
||||
stop := func() {
|
||||
closeOnce.Do(func() {
|
||||
_ = c.Close()
|
||||
_ = gwConn.Close()
|
||||
_ = udp.Close()
|
||||
close(done)
|
||||
})
|
||||
}
|
||||
|
||||
go func() {
|
||||
_, _ = io.Copy(io.Discard, c)
|
||||
stop()
|
||||
}()
|
||||
|
||||
sess := &udpgwSession{
|
||||
nextID: 1,
|
||||
pairID: make(map[string]uint16),
|
||||
idClient: make(map[uint16]*net.UDPAddr),
|
||||
}
|
||||
|
||||
go s.readUDPGWReplies(done, proto, gwConn, udp, sess)
|
||||
s.forwardUDPToUDPGW(done, proto, gwConn, udp, sess)
|
||||
stop()
|
||||
}
|
||||
|
||||
func normalizeUDPGWProtocol(v string) string {
|
||||
switch strings.ToLower(strings.TrimSpace(v)) {
|
||||
case udpgwProtocolLegacy:
|
||||
return udpgwProtocolLegacy
|
||||
default:
|
||||
return udpgwProtocolBadVPN
|
||||
}
|
||||
}
|
||||
|
||||
func (s *SocksServer) forwardUDPToUDPGW(done <-chan struct{}, proto string, gwConn net.Conn, udp *net.UDPConn, sess *udpgwSession) {
|
||||
buf := make([]byte, 64*1024)
|
||||
for {
|
||||
_ = udp.SetReadDeadline(time.Now().Add(500 * time.Millisecond))
|
||||
n, clientAddr, err := udp.ReadFromUDP(buf)
|
||||
if err != nil {
|
||||
select {
|
||||
case <-done:
|
||||
return
|
||||
default:
|
||||
}
|
||||
if ne, ok := err.(net.Error); ok && ne.Timeout() {
|
||||
continue
|
||||
}
|
||||
s.Logger.Add("debug", "udpgw local UDP read failed: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
addr, payload, err := parseSocksUDP(buf[:n])
|
||||
if err != nil {
|
||||
if shouldLogUDPError(err) {
|
||||
s.Logger.Add("debug", "udpgw parse SOCKS UDP failed: %v", err)
|
||||
}
|
||||
continue
|
||||
}
|
||||
|
||||
connID, isNew, err := sess.idFor(clientAddr, addr)
|
||||
if err != nil {
|
||||
s.Logger.Add("debug", "udpgw session id failed for %s: %v", addr.Addr(), err)
|
||||
continue
|
||||
}
|
||||
|
||||
frame, err := udpgwBuildClientFrame(proto, connID, isNew, addr, payload)
|
||||
if err != nil {
|
||||
if shouldLogUDPError(err) {
|
||||
s.Logger.Add("debug", "udpgw frame build failed for %s: %v", addr.Addr(), err)
|
||||
}
|
||||
continue
|
||||
}
|
||||
_ = gwConn.SetWriteDeadline(time.Now().Add(15 * time.Second))
|
||||
if _, err := gwConn.Write(frame); err != nil {
|
||||
s.Logger.Add("warn", "udpgw TCP write failed: %v", err)
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (s *SocksServer) readUDPGWReplies(done <-chan struct{}, proto string, gwConn net.Conn, udp *net.UDPConn, sess *udpgwSession) {
|
||||
br := bufio.NewReaderSize(gwConn, 256*1024)
|
||||
for {
|
||||
select {
|
||||
case <-done:
|
||||
return
|
||||
default:
|
||||
}
|
||||
payload, err := udpgwReadFrame(br, udpgwMaxFrame)
|
||||
if err != nil {
|
||||
select {
|
||||
case <-done:
|
||||
return
|
||||
default:
|
||||
}
|
||||
if !errors.Is(err, io.EOF) {
|
||||
s.Logger.Add("debug", "udpgw TCP read failed: %v", err)
|
||||
}
|
||||
return
|
||||
}
|
||||
connID, src, data, err := udpgwParseReplyPayload(proto, payload)
|
||||
if err != nil {
|
||||
s.Logger.Add("debug", "udpgw bad reply frame: %v", err)
|
||||
continue
|
||||
}
|
||||
clientAddr := sess.clientFor(connID)
|
||||
if clientAddr == nil {
|
||||
continue
|
||||
}
|
||||
packet, err := buildSocksUDP(src, data)
|
||||
if err != nil {
|
||||
s.Logger.Add("debug", "udpgw SOCKS UDP reply build failed: %v", err)
|
||||
continue
|
||||
}
|
||||
_, _ = udp.WriteToUDP(packet, clientAddr)
|
||||
}
|
||||
}
|
||||
|
||||
func (s *udpgwSession) idFor(client *net.UDPAddr, dest socksRequest) (uint16, bool, error) {
|
||||
if client == nil {
|
||||
return 0, false, errors.New("missing local UDP client address")
|
||||
}
|
||||
key := client.String() + "|" + dest.Addr()
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
if id, ok := s.pairID[key]; ok {
|
||||
s.idClient[id] = cloneUDPAddr(client)
|
||||
return id, false, nil
|
||||
}
|
||||
id := s.nextID
|
||||
if id == 0 {
|
||||
id = 1
|
||||
}
|
||||
s.nextID = id + 1
|
||||
if s.nextID == 0 {
|
||||
s.nextID = 1
|
||||
}
|
||||
s.pairID[key] = id
|
||||
s.idClient[id] = cloneUDPAddr(client)
|
||||
return id, true, nil
|
||||
}
|
||||
|
||||
func (s *udpgwSession) clientFor(id uint16) *net.UDPAddr {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
return cloneUDPAddr(s.idClient[id])
|
||||
}
|
||||
|
||||
func cloneUDPAddr(a *net.UDPAddr) *net.UDPAddr {
|
||||
if a == nil {
|
||||
return nil
|
||||
}
|
||||
out := *a
|
||||
if a.IP != nil {
|
||||
out.IP = append(net.IP(nil), a.IP...)
|
||||
}
|
||||
return &out
|
||||
}
|
||||
|
||||
func udpgwReadFrame(r *bufio.Reader, max int) ([]byte, error) {
|
||||
var lenBuf [2]byte
|
||||
if _, err := io.ReadFull(r, lenBuf[:]); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
n := int(binary.LittleEndian.Uint16(lenBuf[:]))
|
||||
if n <= 0 || n > max {
|
||||
return nil, fmt.Errorf("invalid udpgw frame length %d", n)
|
||||
}
|
||||
b := make([]byte, n)
|
||||
if _, err := io.ReadFull(r, b); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return b, nil
|
||||
}
|
||||
|
||||
func udpgwBuildClientFrame(proto string, connID uint16, isNew bool, dest socksRequest, data []byte) ([]byte, error) {
|
||||
if normalizeUDPGWProtocol(proto) == udpgwProtocolLegacy {
|
||||
return udpgwBuildLegacyFrame(connID, 0, dest, data)
|
||||
}
|
||||
return udpgwBuildBadVPNFrame(connID, isNew, dest, data)
|
||||
}
|
||||
|
||||
func udpgwParseReplyPayload(proto string, payload []byte) (uint16, socksRequest, []byte, error) {
|
||||
if normalizeUDPGWProtocol(proto) == udpgwProtocolLegacy {
|
||||
return udpgwParseLegacyReplyPayload(payload)
|
||||
}
|
||||
return udpgwParseBadVPNReplyPayload(payload)
|
||||
}
|
||||
|
||||
// udpgwBuildBadVPNFrame implements the normal badvpn-udpgw PacketProto frame:
|
||||
// length(little-endian uint16) + flags(1) + conid(little-endian uint16) +
|
||||
// IPv4/IPv6 destination + destination port(network byte order) + UDP payload.
|
||||
// This is the same framing used by the Android badvpn UDPGW client and supports
|
||||
// IPv6 targets with UDPGW_CLIENT_FLAG_IPV6.
|
||||
func udpgwBuildBadVPNFrame(connID uint16, isNew bool, dest socksRequest, data []byte) ([]byte, error) {
|
||||
ip := net.ParseIP(dest.Host)
|
||||
if ip == nil {
|
||||
return nil, fmt.Errorf("badvpn UDPGW requires an IP target, got %q", dest.Host)
|
||||
}
|
||||
flags := byte(0)
|
||||
if isNew {
|
||||
flags |= udpgwFlagRebind
|
||||
}
|
||||
|
||||
var addr []byte
|
||||
if v4 := ip.To4(); v4 != nil {
|
||||
addr = append(addr, v4...)
|
||||
} else if v6 := ip.To16(); v6 != nil {
|
||||
flags |= udpgwFlagIPv6
|
||||
addr = append(addr, v6...)
|
||||
} else {
|
||||
return nil, fmt.Errorf("invalid UDPGW target IP %q", dest.Host)
|
||||
}
|
||||
|
||||
payloadLen := 1 + 2 + len(addr) + 2 + len(data)
|
||||
if payloadLen <= 0 || payloadLen > 65535 {
|
||||
return nil, fmt.Errorf("UDPGW payload too large: %d", payloadLen)
|
||||
}
|
||||
out := make([]byte, 2+payloadLen)
|
||||
binary.LittleEndian.PutUint16(out[0:2], uint16(payloadLen))
|
||||
out[2] = flags
|
||||
binary.LittleEndian.PutUint16(out[3:5], connID)
|
||||
copy(out[5:5+len(addr)], addr)
|
||||
portOff := 5 + len(addr)
|
||||
binary.BigEndian.PutUint16(out[portOff:portOff+2], uint16(dest.Port))
|
||||
copy(out[portOff+2:], data)
|
||||
return out, nil
|
||||
}
|
||||
|
||||
func udpgwParseBadVPNReplyPayload(payload []byte) (uint16, socksRequest, []byte, error) {
|
||||
var src socksRequest
|
||||
if len(payload) < 1+2+4+2 {
|
||||
return 0, src, nil, fmt.Errorf("short badvpn udpgw payload %d", len(payload))
|
||||
}
|
||||
flags := payload[0]
|
||||
if flags&udpgwFlagKeepAlive != 0 {
|
||||
return 0, src, nil, errors.New("unexpected udpgw keepalive reply")
|
||||
}
|
||||
connID := binary.LittleEndian.Uint16(payload[1:3])
|
||||
off := 3
|
||||
if flags&udpgwFlagIPv6 != 0 {
|
||||
if len(payload) < off+16+2 {
|
||||
return 0, src, nil, fmt.Errorf("short badvpn udpgw ipv6 payload %d", len(payload))
|
||||
}
|
||||
src.Host = net.IP(payload[off : off+16]).String()
|
||||
off += 16
|
||||
} else {
|
||||
if len(payload) < off+4+2 {
|
||||
return 0, src, nil, fmt.Errorf("short badvpn udpgw ipv4 payload %d", len(payload))
|
||||
}
|
||||
src.Host = net.IP(payload[off : off+4]).String()
|
||||
off += 4
|
||||
}
|
||||
src.Port = int(binary.BigEndian.Uint16(payload[off : off+2]))
|
||||
off += 2
|
||||
return connID, src, payload[off:], nil
|
||||
}
|
||||
|
||||
// Legacy frame kept only for older experimental PC builds. It is IPv4-only.
|
||||
func udpgwBuildLegacyFrame(connID uint16, x byte, dest socksRequest, data []byte) ([]byte, error) {
|
||||
ip := net.ParseIP(dest.Host)
|
||||
v4 := ip.To4()
|
||||
if v4 == nil {
|
||||
return nil, fmt.Errorf("legacy UDPGW only supports IPv4 UDP targets; got %s", dest.Addr())
|
||||
}
|
||||
payloadLen := 2 + 1 + 4 + 2 + len(data)
|
||||
if payloadLen <= 0 || payloadLen > 65535 {
|
||||
return nil, fmt.Errorf("UDPGW payload too large: %d", payloadLen)
|
||||
}
|
||||
out := make([]byte, 2+payloadLen)
|
||||
binary.LittleEndian.PutUint16(out[0:2], uint16(payloadLen))
|
||||
binary.BigEndian.PutUint16(out[2:4], connID)
|
||||
out[4] = x
|
||||
copy(out[5:9], v4)
|
||||
binary.BigEndian.PutUint16(out[9:11], uint16(dest.Port))
|
||||
copy(out[11:], data)
|
||||
return out, nil
|
||||
}
|
||||
|
||||
func udpgwParseLegacyReplyPayload(payload []byte) (uint16, socksRequest, []byte, error) {
|
||||
var src socksRequest
|
||||
if len(payload) < 2+1+4+2 {
|
||||
return 0, src, nil, fmt.Errorf("short legacy udpgw payload %d", len(payload))
|
||||
}
|
||||
connID := binary.BigEndian.Uint16(payload[0:2])
|
||||
src.Host = net.IP(payload[3:7]).String()
|
||||
src.Port = int(binary.BigEndian.Uint16(payload[7:9]))
|
||||
return connID, src, payload[9:], nil
|
||||
}
|
||||
Reference in New Issue
Block a user