This commit is contained in:
2026-05-16 00:18:06 -03:00
commit 92941e68a2
66 changed files with 10352 additions and 0 deletions

45
internal/engine/logger.go Normal file
View File

@@ -0,0 +1,45 @@
package engine
import (
"fmt"
"sync"
"time"
)
type LogEntry struct {
ID int64 `json:"id"`
Time string `json:"time"`
Level string `json:"level"`
Message string `json:"message"`
}
type Logger struct {
mu sync.Mutex
nextID int64
entries []LogEntry
}
func NewLogger() *Logger { return &Logger{} }
func (l *Logger) Add(level, format string, args ...any) {
l.mu.Lock()
defer l.mu.Unlock()
l.nextID++
entry := LogEntry{ID: l.nextID, Time: time.Now().Format("15:04:05"), Level: level, Message: fmt.Sprintf(format, args...)}
l.entries = append(l.entries, entry)
if len(l.entries) > 600 {
l.entries = l.entries[len(l.entries)-600:]
}
}
func (l *Logger) Since(id int64) []LogEntry {
l.mu.Lock()
defer l.mu.Unlock()
out := make([]LogEntry, 0)
for _, e := range l.entries {
if e.ID > id {
out = append(out, e)
}
}
return out
}

525
internal/engine/manager.go Normal file
View File

@@ -0,0 +1,525 @@
package engine
import (
"context"
"fmt"
"net"
"net/url"
"strings"
"sync"
"time"
"socksrevivepc/internal/config"
"socksrevivepc/internal/dnsttclient"
"socksrevivepc/internal/routes"
"socksrevivepc/internal/tun"
)
type Status struct {
Running bool `json:"running"`
Connecting bool `json:"connecting"`
ProfileID string `json:"profile_id"`
Mode string `json:"mode"`
SocksAddr string `json:"socks_addr"`
Tun bool `json:"tun"`
StartedAt string `json:"started_at"`
}
type Manager struct {
root string
mu sync.Mutex
logger *Logger
status Status
cancel context.CancelFunc
ssh *sshBundle
socks *SocksServer
dnstt *ManagedProcess
embeddedDNSTT *dnsttclient.Client
xray *ManagedProcess
tun *tun.Runner
routeCleanup *routes.Cleanup
manualStop bool
reconnecting bool
}
func NewManager(root string) *Manager {
lg := NewLogger()
return &Manager{root: root, logger: lg, tun: tun.NewRunner(lg)}
}
func (m *Manager) LogsSince(id int64) []LogEntry { return m.logger.Since(id) }
func (m *Manager) Status() Status {
m.mu.Lock()
defer m.mu.Unlock()
return m.status
}
func (m *Manager) Start(p config.Profile) error {
config.ApplyDefaults(&p)
if err := config.Validate(p); err != nil {
return err
}
ctx, cancel := context.WithCancel(context.Background())
m.mu.Lock()
if m.status.Running || m.status.Connecting {
m.mu.Unlock()
cancel()
return fmt.Errorf("another profile is already running or connecting")
}
m.cancel = cancel
m.manualStop = false
m.status = Status{Connecting: true, ProfileID: p.ID, Mode: string(p.Mode), Tun: p.Tun.Enabled, StartedAt: time.Now().Format(time.RFC3339)}
m.mu.Unlock()
m.logger.Add("info", "starting profile: %s", p.Name)
fail := func(format string, args ...any) error {
err := fmt.Errorf(format, args...)
m.logger.Add("error", "%v", err)
m.stop(false)
return err
}
ensureCurrent := func() error {
if err := ctx.Err(); err != nil {
return err
}
m.mu.Lock()
defer m.mu.Unlock()
if m.cancel == nil || !m.status.Connecting || m.status.ProfileID != p.ID {
return context.Canceled
}
return nil
}
var socksAddr string
var bypass []string
if p.Mode == config.ModeXray {
proc, err := StartProcess(ctx, m.root, "xray", p.Xray.Executable, p.Xray.Args, m.logger)
if err != nil {
return fail("start xray failed: %w", err)
}
m.mu.Lock()
m.xray = proc
m.mu.Unlock()
time.Sleep(time.Duration(p.Xray.StartupTimeoutMs) * time.Millisecond)
if err := ensureCurrent(); err != nil {
proc.Stop()
return fail("connection cancelled: %w", err)
}
socksAddr = net.JoinHostPort(p.Xray.LocalSocksHost, fmt.Sprint(p.Xray.LocalSocksPort))
if exited, procErr := proc.Exited(); exited {
m.stop(false)
if procErr != nil {
return fmt.Errorf("xray exited before opening local SOCKS on %s: %w", socksAddr, procErr)
}
return fmt.Errorf("xray exited before opening local SOCKS on %s", socksAddr)
}
if err := waitForTCP(ctx, socksAddr, time.Duration(p.Xray.StartupTimeoutMs)*time.Millisecond); err != nil {
return fail("xray local SOCKS is not listening on %s: %w", socksAddr, err)
}
bypass = nil
} else {
if p.Mode == config.ModeDNSTT {
if p.DNSTT.UseEmbedded {
client, err := dnsttclient.Start(ctx, dnsttclient.Options{
ResolverType: p.DNSTT.ResolverType,
ResolverAddress: p.DNSTT.ResolverAddress,
PublicKeyHex: p.DNSTT.PublicKey,
Domain: p.DNSTT.Domain,
LocalAddress: net.JoinHostPort(p.DNSTT.LocalSSHHost, fmt.Sprint(p.DNSTT.LocalSSHPort)),
UTLSDistribution: p.DNSTT.UTLSDistribution,
StartupTimeout: time.Duration(p.DNSTT.StartupTimeoutMs) * time.Millisecond,
LogWriter: dnsttLogWriter{logger: m.logger},
})
if err != nil {
return fail("start embedded dnstt failed: %w", err)
}
m.mu.Lock()
m.embeddedDNSTT = client
m.mu.Unlock()
} else {
proc, err := StartProcess(ctx, m.root, "dnstt", p.DNSTT.Executable, p.DNSTT.Args, m.logger)
if err != nil {
return fail("start dnstt failed: %w", err)
}
m.mu.Lock()
m.dnstt = proc
m.mu.Unlock()
time.Sleep(time.Duration(p.DNSTT.StartupTimeoutMs) * time.Millisecond)
}
if err := ensureCurrent(); err != nil {
return fail("connection cancelled: %w", err)
}
}
sshc, err := connectSSH(ctx, p, m.logger)
if err != nil {
return fail("%w", err)
}
if err := ensureCurrent(); err != nil {
_ = sshc.Client.Close()
_ = sshc.Conn.Close()
return fail("connection cancelled: %w", err)
}
m.mu.Lock()
m.ssh = sshc
m.mu.Unlock()
socksAddr = net.JoinHostPort(p.Local.SocksHost, fmt.Sprint(p.Local.SocksPort))
ss := &SocksServer{Addr: socksAddr, SSH: sshc.Client, Logger: m.logger, DNS: profileDNSServers(p), UDPGW: p.UDPGW}
if err := ss.Start(); err != nil {
return fail("%w", err)
}
m.mu.Lock()
m.socks = ss
m.status.SocksAddr = socksAddr
m.mu.Unlock()
if err := waitForTCP(ctx, socksAddr, 2500*time.Millisecond); err != nil {
return fail("local SOCKS is not listening on %s: %w", socksAddr, err)
}
bypass = effectiveBypassHosts(p, sshc.ControlHosts)
}
if err := ensureCurrent(); err != nil {
return fail("connection cancelled: %w", err)
}
if p.Tun.Enabled {
if err := m.tun.Start(&p, socksAddr); err != nil {
return fail("%w", err)
}
cleanup, err := routes.Apply(p, bypass, m.logger)
if err != nil {
m.logger.Add("warn", "route setup error: %v", err)
}
m.mu.Lock()
m.routeCleanup = cleanup
m.mu.Unlock()
}
m.mu.Lock()
current := m.cancel != nil && m.status.Connecting && m.status.ProfileID == p.ID
err := ctx.Err()
if err == nil && current {
m.status = Status{Running: true, ProfileID: p.ID, Mode: string(p.Mode), SocksAddr: socksAddr, Tun: p.Tun.Enabled, StartedAt: time.Now().Format(time.RFC3339)}
}
m.mu.Unlock()
if err != nil || !current {
m.stop(false)
if err != nil {
return err
}
return context.Canceled
}
m.logger.Add("info", "profile is connected; local socks=%s tun=%v", socksAddr, p.Tun.Enabled)
m.startMonitor(ctx, p)
return nil
}
func profileDNSServers(p config.Profile) []string {
out := append([]string{}, p.Tun.DNS...)
if p.Tun.IPv6Enabled {
out = append(out, p.Tun.IPv6DNS...)
}
return out
}
func waitForTCP(ctx context.Context, addr string, timeout time.Duration) error {
if timeout <= 0 {
timeout = 5 * time.Second
}
deadline := time.Now().Add(timeout)
var lastErr error
for time.Now().Before(deadline) {
d := net.Dialer{Timeout: 350 * time.Millisecond}
c, err := d.DialContext(ctx, "tcp", addr)
if err == nil {
_ = c.Close()
return nil
}
lastErr = err
select {
case <-ctx.Done():
return ctx.Err()
case <-time.After(150 * time.Millisecond):
}
}
if lastErr != nil {
return lastErr
}
return fmt.Errorf("timeout waiting for %s", addr)
}
func effectiveBypassHosts(p config.Profile, activeControlHosts []string) []string {
seen := map[string]bool{}
out := make([]string, 0, len(activeControlHosts)+1)
add := func(host string) {
host = strings.TrimSpace(host)
if host == "" {
return
}
if h, _, err := net.SplitHostPort(host); err == nil {
host = strings.Trim(h, "[]")
}
if isLocalBypassHost(host) || seen[host] {
return
}
seen[host] = true
out = append(out, host)
}
for _, host := range activeControlHosts {
add(host)
}
if len(out) == 0 && p.Mode != config.ModeDNSTT {
// Fallback for old profile formats or unexpected transports. This still only
// adds the direct SSH host, not every rotated proxy from the profile.
add(p.SSH.Host)
}
if p.Mode == config.ModeDNSTT && p.DNSTT.UseEmbedded {
add(dnsttResolverHost(p.DNSTT.ResolverAddress))
}
return out
}
func isLocalBypassHost(host string) bool {
host = strings.Trim(strings.ToLower(strings.TrimSpace(host)), "[]")
if host == "" || host == "localhost" {
return true
}
ip := net.ParseIP(host)
if ip == nil {
return false
}
return ip.IsLoopback() || ip.IsUnspecified()
}
func dnsttResolverHost(s string) string {
s = strings.TrimSpace(s)
if s == "" {
return ""
}
if strings.Contains(s, "://") {
u, err := url.Parse(s)
if err == nil {
return u.Hostname()
}
}
host, _, err := net.SplitHostPort(s)
if err == nil {
return host
}
return s
}
func (m *Manager) Stop() {
m.stop(true)
}
func (m *Manager) stop(manual bool) {
m.mu.Lock()
defer m.mu.Unlock()
if manual {
m.manualStop = true
}
m.stopLocked()
}
func (m *Manager) startMonitor(ctx context.Context, p config.Profile) {
if !p.Reconnect.Enabled {
return
}
interval := time.Duration(p.Reconnect.CheckIntervalSeconds) * time.Second
if interval <= 0 {
interval = 10 * time.Second
}
go func() {
ticker := time.NewTicker(interval)
defer ticker.Stop()
for {
select {
case <-ctx.Done():
return
case <-ticker.C:
if m.isManualStopped() {
return
}
if err := m.probeConnection(p); err != nil {
m.logger.Add("warn", "connection monitor detected tunnel loss: %v", err)
m.reconnectLoop(p, err)
return
}
}
}
}()
}
func (m *Manager) isManualStopped() bool {
m.mu.Lock()
defer m.mu.Unlock()
return m.manualStop
}
func (m *Manager) markReconnecting() bool {
m.mu.Lock()
defer m.mu.Unlock()
if m.manualStop || m.reconnecting {
return false
}
m.reconnecting = true
return true
}
func (m *Manager) clearReconnecting() {
m.mu.Lock()
m.reconnecting = false
m.mu.Unlock()
}
func (m *Manager) probeConnection(p config.Profile) error {
m.mu.Lock()
sshc := m.ssh
xray := m.xray
socksAddr := m.status.SocksAddr
running := m.status.Running
m.mu.Unlock()
if !running {
return nil
}
if xray != nil {
if exited, err := xray.Exited(); exited {
if err != nil {
return fmt.Errorf("xray exited: %w", err)
}
return fmt.Errorf("xray exited")
}
if socksAddr != "" {
ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
defer cancel()
return waitForTCP(ctx, socksAddr, 2*time.Second)
}
return nil
}
if sshc != nil && sshc.Client != nil {
if sshc.Conn != nil {
_ = sshc.Conn.SetDeadline(time.Now().Add(5 * time.Second))
defer sshc.Conn.SetDeadline(time.Time{})
}
_, _, err := sshc.Client.SendRequest("keepalive@openssh.com", true, nil)
if err != nil {
return fmt.Errorf("ssh keepalive failed: %w", err)
}
}
return nil
}
func (m *Manager) reconnectLoop(p config.Profile, cause error) {
if !m.markReconnecting() {
return
}
defer m.clearReconnecting()
delay := time.Duration(p.Reconnect.DelaySeconds) * time.Second
if delay <= 0 {
delay = 3 * time.Second
}
maxRetries := p.Reconnect.MaxRetries
m.logger.Add("warn", "connection lost (%v); auto reconnect is enabled", cause)
if p.Tun.Enabled {
m.logger.Add("info", "destroying TUN before reconnect")
}
m.stop(false)
if p.Tun.Enabled {
time.Sleep(1200 * time.Millisecond)
}
for attempt := 1; maxRetries <= 0 || attempt <= maxRetries; attempt++ {
if m.isManualStopped() {
m.logger.Add("info", "auto reconnect cancelled by user")
return
}
m.logger.Add("info", "reconnect attempt %d%s in %s", attempt, reconnectLimitSuffix(maxRetries), delay)
select {
case <-time.After(delay):
}
if m.isManualStopped() {
m.logger.Add("info", "auto reconnect cancelled by user")
return
}
if err := m.Start(p); err != nil {
if m.isManualStopped() {
m.logger.Add("info", "auto reconnect cancelled by user")
return
}
m.logger.Add("warn", "reconnect attempt %d failed: %v", attempt, err)
if p.Tun.Enabled {
m.logger.Add("info", "destroying TUN after failed reconnect attempt")
m.stop(false)
time.Sleep(1200 * time.Millisecond)
}
continue
}
m.logger.Add("info", "reconnected successfully")
return
}
m.logger.Add("error", "auto reconnect stopped after %d failed attempt(s)", maxRetries)
}
func reconnectLimitSuffix(maxRetries int) string {
if maxRetries <= 0 {
return " (unlimited)"
}
return fmt.Sprintf("/%d", maxRetries)
}
func (m *Manager) stopLocked() {
wasActive := m.status.Running || m.status.Connecting
if m.cancel != nil {
m.cancel()
m.cancel = nil
}
if m.routeCleanup != nil {
m.routeCleanup.Run()
m.routeCleanup = nil
}
if m.tun != nil {
m.tun.Stop()
}
if m.socks != nil {
m.socks.Stop()
m.socks = nil
}
if m.ssh != nil {
_ = m.ssh.Client.Close()
_ = m.ssh.Conn.Close()
m.ssh = nil
}
if m.xray != nil {
m.xray.Stop()
m.xray = nil
}
if m.embeddedDNSTT != nil {
m.embeddedDNSTT.Stop()
m.embeddedDNSTT = nil
}
if m.dnstt != nil {
m.dnstt.Stop()
m.dnstt = nil
}
if wasActive {
m.logger.Add("info", "disconnected")
}
m.status = Status{}
}
type dnsttLogWriter struct {
logger *Logger
}
func (w dnsttLogWriter) Write(p []byte) (int, error) {
line := strings.TrimSpace(string(p))
if line != "" && w.logger != nil {
w.logger.Add("dnstt", "%s", line)
}
return len(p), nil
}

399
internal/engine/payload.go Normal file
View File

@@ -0,0 +1,399 @@
package engine
import (
"bytes"
"fmt"
"io"
"math/rand"
"net"
"strconv"
"strings"
"time"
"socksrevivepc/internal/config"
)
const (
payloadHTTPStatusPeekTimeoutMs = 1500
payloadHTTPStatusLineLimit = 4096
maxPayloadHTTPResponsesToSkip = 12
maxPayloadHTTPHeaderLines = 80
maxPayloadHTTPBodyDiscardBytes = 1024 * 1024
)
type PayloadResult struct {
StatusLine string
StatusCode int
}
type httpResponseInfo struct {
contentLength int64
chunked bool
}
type preloadedConn struct {
net.Conn
preloaded *bytes.Reader
}
func (c *preloadedConn) Read(p []byte) (int, error) {
if c.preloaded != nil && c.preloaded.Len() > 0 {
return c.preloaded.Read(p)
}
return c.Conn.Read(p)
}
func wrapConnWithPreloadedBytes(conn net.Conn, b []byte) net.Conn {
if len(b) == 0 {
return conn
}
return &preloadedConn{Conn: conn, preloaded: bytes.NewReader(b)}
}
func WritePayload(conn net.Conn, p config.Profile, targetHost string, targetPort int, logger *Logger) (PayloadResult, net.Conn, error) {
payload := buildPayload(p.Payload.Text, targetHost, targetPort)
parts, instant := splitPayload(payload)
for i, part := range parts {
if part == "" {
continue
}
if _, err := io.WriteString(conn, part); err != nil {
return PayloadResult{}, conn, err
}
if i < len(parts)-1 && !instant {
time.Sleep(time.Duration(p.Payload.SplitDelayMs) * time.Millisecond)
}
}
logger.Add("debug", "payload sent (%d bytes)", len(payload))
if !p.Payload.WaitForResponse {
return PayloadResult{}, conn, nil
}
return consumePayloadHTTPNegotiation(conn, p, payloadSourceLabel(p.Mode), logger)
}
func consumePayloadHTTPNegotiation(conn net.Conn, p config.Profile, source string, logger *Logger) (PayloadResult, net.Conn, error) {
defer conn.SetReadDeadline(time.Time{})
var last PayloadResult
var captured *bytes.Buffer
var sawSuccess bool
for attempt := 0; attempt < maxPayloadHTTPResponsesToSkip; attempt++ {
setPayloadReadDeadline(conn, p, attempt)
captured = &bytes.Buffer{}
line, err := readPayloadLinePreserveBytes(conn, captured, payloadHTTPStatusLineLimit)
if err != nil {
if isTimeoutErr(err) {
if last.StatusCode >= 400 && !p.Payload.AcceptAnyStatus && !sawSuccess {
return last, conn, fmt.Errorf("payload rejected with final status %d", last.StatusCode)
}
return last, wrapConnWithPreloadedBytes(conn, captured.Bytes()), nil
}
if err == io.EOF && captured.Len() > 0 {
return last, wrapConnWithPreloadedBytes(conn, captured.Bytes()), nil
}
if last.StatusCode > 0 {
return last, conn, nil
}
return PayloadResult{}, conn, fmt.Errorf("payload response read failed: %w", err)
}
cleanLine := strings.TrimSpace(line)
if cleanLine == "" {
return last, wrapConnWithPreloadedBytes(conn, captured.Bytes()), nil
}
if strings.HasPrefix(cleanLine, "SSH-") || !isHTTPStatusLine(cleanLine) {
return last, wrapConnWithPreloadedBytes(conn, captured.Bytes()), nil
}
code := parseStatusCode(cleanLine)
last = PayloadResult{StatusLine: cleanLine, StatusCode: code}
logProxyStatus(logger, source, code, cleanLine)
logHTTPCompatibilityStatus(logger, code, cleanLine)
if code == 101 || (code >= 200 && code < 400) {
sawSuccess = true
}
// The current bytes are confirmed HTTP/proxy negotiation bytes. Do not replay
// them to the SSH transport. Only replay bytes when we detect SSH/non-HTTP
// data or a partial line after timeout.
captured = nil
if err := consumePayloadHTTPHeadersAndBody(conn); err != nil {
if isTimeoutErr(err) {
return last, conn, nil
}
return last, conn, fmt.Errorf("payload response consume failed: %w", err)
}
// Keep peeking for another immediate HTTP status block. Some payload/proxy
// chains return several statuses (for example 403 -> 403 -> 101). Returning
// after the first status can make SSH read HTTP text instead of SSH-2.0.
}
if last.StatusCode >= 400 && !p.Payload.AcceptAnyStatus && !sawSuccess {
return last, conn, fmt.Errorf("payload rejected with final status %d", last.StatusCode)
}
return last, conn, nil
}
func setPayloadReadDeadline(conn net.Conn, p config.Profile, attempt int) {
timeoutMs := payloadHTTPStatusPeekTimeoutMs
if attempt == 0 && p.Payload.ResponseTimeoutMs > 0 {
timeoutMs = p.Payload.ResponseTimeoutMs
if timeoutMs < payloadHTTPStatusPeekTimeoutMs {
timeoutMs = payloadHTTPStatusPeekTimeoutMs
}
}
_ = conn.SetReadDeadline(time.Now().Add(time.Duration(timeoutMs) * time.Millisecond))
}
func readPayloadLinePreserveBytes(conn net.Conn, captured *bytes.Buffer, limit int) (string, error) {
var line bytes.Buffer
buf := make([]byte, 1)
for line.Len() < limit {
n, err := conn.Read(buf)
if n > 0 {
b := buf[0]
_ = captured.WriteByte(b)
_ = line.WriteByte(b)
if b == '\n' {
break
}
}
if err != nil {
if line.Len() > 0 && err == io.EOF {
return line.String(), nil
}
return "", err
}
if n == 0 {
continue
}
}
if line.Len() == 0 {
return "", io.EOF
}
return line.String(), nil
}
func consumePayloadHTTPHeadersAndBody(conn net.Conn) error {
info := httpResponseInfo{contentLength: -1}
for i := 0; i < maxPayloadHTTPHeaderLines; i++ {
ignored := &bytes.Buffer{}
line, err := readPayloadLinePreserveBytes(conn, ignored, payloadHTTPStatusLineLimit)
if err != nil {
return err
}
clean := strings.TrimSpace(line)
if clean == "" {
break
}
lower := strings.ToLower(clean)
if strings.HasPrefix(lower, "content-length:") {
if n, err := strconv.ParseInt(strings.TrimSpace(clean[strings.Index(clean, ":")+1:]), 10, 64); err == nil {
info.contentLength = n
}
} else if strings.HasPrefix(lower, "transfer-encoding:") && strings.Contains(lower, "chunked") {
info.chunked = true
}
}
if info.chunked {
return discardPayloadChunkedBody(conn)
}
if info.contentLength > 0 {
return discardPayloadFixedLengthBody(conn, info.contentLength)
}
return nil
}
func discardPayloadFixedLengthBody(conn net.Conn, contentLength int64) error {
remaining := contentLength
if remaining > maxPayloadHTTPBodyDiscardBytes {
remaining = maxPayloadHTTPBodyDiscardBytes
}
buf := make([]byte, 4096)
for remaining > 0 {
toRead := int64(len(buf))
if remaining < toRead {
toRead = remaining
}
n, err := conn.Read(buf[:int(toRead)])
if n > 0 {
remaining -= int64(n)
}
if err != nil {
return err
}
}
return nil
}
func discardPayloadChunkedBody(conn net.Conn) error {
for i := 0; i < maxPayloadHTTPHeaderLines; i++ {
ignored := &bytes.Buffer{}
sizeLine, err := readPayloadLinePreserveBytes(conn, ignored, payloadHTTPStatusLineLimit)
if err != nil {
return err
}
cleanSize := strings.TrimSpace(sizeLine)
if semi := strings.Index(cleanSize, ";"); semi >= 0 {
cleanSize = strings.TrimSpace(cleanSize[:semi])
}
chunkSize, err := strconv.ParseInt(cleanSize, 16, 64)
if err != nil {
return nil
}
if chunkSize == 0 {
return consumePayloadTrailingHeaders(conn)
}
if err := discardPayloadFixedLengthBody(conn, chunkSize); err != nil {
return err
}
crlf := &bytes.Buffer{}
_, _ = readPayloadLinePreserveBytes(conn, crlf, payloadHTTPStatusLineLimit)
}
return nil
}
func consumePayloadTrailingHeaders(conn net.Conn) error {
for i := 0; i < maxPayloadHTTPHeaderLines; i++ {
ignored := &bytes.Buffer{}
line, err := readPayloadLinePreserveBytes(conn, ignored, payloadHTTPStatusLineLimit)
if err != nil {
return err
}
if strings.TrimSpace(line) == "" {
return nil
}
}
return nil
}
func isTimeoutErr(err error) bool {
if err == nil {
return false
}
if ne, ok := err.(net.Error); ok && ne.Timeout() {
return true
}
return false
}
func isHTTPStatusLine(statusLine string) bool {
clean := strings.ToUpper(strings.TrimSpace(statusLine))
return strings.HasPrefix(clean, "HTTP/1.") || strings.HasPrefix(clean, "HTTP/2") || strings.HasPrefix(clean, "HTTP/3")
}
func logProxyStatus(logger *Logger, source string, responseCode int, statusLine string) {
cleanLine := strings.TrimSpace(statusLine)
if cleanLine == "" {
return
}
if source == "" {
source = "PROXY"
}
logger.Add("info", "Proxy Status [%s]: %s", source, cleanLine)
}
func logHTTPCompatibilityStatus(logger *Logger, responseCode int, firstLine string) {
switch responseCode {
case 200:
logger.Add("info", "Status: 200 (Connection established) Successful")
case 101:
logger.Add("info", "replace 200 OK")
logger.Add("info", "HTTP/1.1 101 Websocket")
case 100:
logger.Add("info", "HTTP/1.1 100 Continue")
case 301, 302, 400, 401, 403, 404, 407, 429, 500, 502, 503, 504:
if strings.TrimSpace(firstLine) != "" {
logger.Add("info", "%s", strings.TrimSpace(firstLine))
} else {
logger.Add("info", "HTTP/1.1 %d", responseCode)
}
logger.Add("info", "replace 200 OK")
logger.Add("info", "Dragon Try!")
}
}
func payloadSourceLabel(mode config.Mode) string {
switch mode {
case config.ModePayload:
return "HTTP_PROXY"
case config.ModePayloadSSL:
return "SSL_PAYLOAD"
default:
return "PROXY"
}
}
func buildPayload(tpl, host string, port int) string {
portStr := strconv.Itoa(port)
repl := map[string]string{
"[host]": host,
"[port]": portStr,
"[host_port]": net.JoinHostPort(host, portStr),
"[crlf]": "\r\n",
"[cr]": "\r",
"[lf]": "\n",
"[protocol]": "HTTP/1.1",
"[method]": "CONNECT",
}
out := tpl
for k, v := range repl {
out = strings.ReplaceAll(out, k, v)
}
out = replaceRotate(out)
return out
}
func replaceRotate(s string) string {
for {
start := strings.Index(s, "[rotate=")
if start < 0 {
return s
}
end := strings.Index(s[start:], "]")
if end < 0 {
return s
}
end += start
body := strings.TrimPrefix(s[start:end+1], "[rotate=")
body = strings.TrimSuffix(body, "]")
choices := splitRotateChoices(body)
choice := ""
if len(choices) > 0 {
choice = strings.TrimSpace(choices[rand.Intn(len(choices))])
}
s = s[:start] + choice + s[end+1:]
}
}
func splitRotateChoices(body string) []string {
return strings.FieldsFunc(body, func(r rune) bool {
switch r {
case ';', '#', ',', '\n', '\r', '\t':
return true
default:
return false
}
})
}
func splitPayload(s string) ([]string, bool) {
instant := strings.Contains(s, "[instant_split]")
s = strings.ReplaceAll(s, "[instant_split]", "[split]")
parts := strings.Split(s, "[split]")
return parts, instant
}
func parseStatusCode(status string) int {
fields := strings.Fields(status)
if len(fields) < 2 {
return 0
}
code, _ := strconv.Atoi(fields[1])
return code
}

View File

@@ -0,0 +1,91 @@
package engine
import (
"bufio"
"context"
"errors"
"io"
"os/exec"
"path/filepath"
"runtime"
"strings"
"sync"
"time"
"socksrevivepc/internal/oscmd"
)
type ManagedProcess struct {
cmd *exec.Cmd
name string
logger *Logger
mu sync.Mutex
done chan error
}
func StartProcess(ctx context.Context, root, name, exe string, args []string, logger *Logger) (*ManagedProcess, error) {
if strings.TrimSpace(exe) == "" {
return nil, errors.New(name + " executable path is empty")
}
if !filepath.IsAbs(exe) {
exe = filepath.Join(root, exe)
}
cmd := oscmd.CommandContext(ctx, exe, args...)
cmd.Dir = root
stdout, _ := cmd.StdoutPipe()
stderr, _ := cmd.StderrPipe()
p := &ManagedProcess{cmd: cmd, name: name, logger: logger, done: make(chan error, 1)}
if err := cmd.Start(); err != nil {
return nil, err
}
logger.Add("info", "%s started: %s %s", name, exe, strings.Join(args, " "))
go p.pipe(stdout, "info")
go p.pipe(stderr, "warn")
go func() {
err := cmd.Wait()
p.done <- err
if err != nil {
logger.Add("warn", "%s stopped: %v", name, err)
} else {
logger.Add("info", "%s stopped", name)
}
}()
return p, nil
}
func (p *ManagedProcess) Exited() (bool, error) {
if p == nil || p.done == nil {
return true, nil
}
select {
case err := <-p.done:
return true, err
default:
return false, nil
}
}
func (p *ManagedProcess) pipe(r io.Reader, level string) {
s := bufio.NewScanner(r)
for s.Scan() {
line := strings.TrimSpace(s.Text())
if line != "" {
p.logger.Add(level, "%s: %s", p.name, line)
}
}
}
func (p *ManagedProcess) Stop() {
p.mu.Lock()
defer p.mu.Unlock()
if p == nil || p.cmd == nil || p.cmd.Process == nil {
return
}
if runtime.GOOS == "windows" {
_ = p.cmd.Process.Kill()
} else {
_ = p.cmd.Process.Signal(ioSignalInterrupt())
time.Sleep(500 * time.Millisecond)
_ = p.cmd.Process.Kill()
}
}

View File

@@ -0,0 +1,7 @@
//go:build !windows
package engine
import "os"
func ioSignalInterrupt() os.Signal { return os.Interrupt }

View File

@@ -0,0 +1,7 @@
//go:build windows
package engine
import "os"
func ioSignalInterrupt() os.Signal { return os.Interrupt }

462
internal/engine/socks5.go Normal file
View File

@@ -0,0 +1,462 @@
package engine
import (
"encoding/binary"
"errors"
"fmt"
"io"
"net"
"strconv"
"strings"
"sync"
"time"
"golang.org/x/crypto/ssh"
"socksrevivepc/internal/config"
)
const (
socksVersion = 0x05
socksCmdConnect = 0x01
socksCmdUDPAssoc = 0x03
socksAtypIPv4 = 0x01
socksAtypDomain = 0x03
socksAtypIPv6 = 0x04
socksRepOK = 0x00
socksRepFail = 0x01
socksRepUnsupported = 0x07
)
type SocksServer struct {
Addr string
SSH *ssh.Client
Logger *Logger
DNS []string
UDPGW config.UDPGWConfig
listener net.Listener
stopOnce sync.Once
}
type socksRequest struct {
Command byte
Host string
Port int
}
func (r socksRequest) Addr() string {
return net.JoinHostPort(r.Host, strconv.Itoa(r.Port))
}
func (s *SocksServer) Start() error {
ln, err := net.Listen("tcp", s.Addr)
if err != nil {
return err
}
s.listener = ln
s.Logger.Add("info", "local SOCKS5 listening on %s", s.Addr)
go s.acceptLoop()
return nil
}
func (s *SocksServer) Stop() {
s.stopOnce.Do(func() {
if s.listener != nil {
_ = s.listener.Close()
}
})
}
func (s *SocksServer) acceptLoop() {
for {
c, err := s.listener.Accept()
if err != nil {
return
}
go s.handle(c)
}
}
func (s *SocksServer) handle(c net.Conn) {
defer c.Close()
_ = c.SetDeadline(time.Now().Add(30 * time.Second))
if err := s.handshake(c); err != nil {
if !isExpectedProbeError(err) {
s.Logger.Add("debug", "socks handshake failed: %v", err)
}
return
}
req, err := readSocksRequest(c)
if err != nil {
s.Logger.Add("debug", "socks request failed: %v", err)
return
}
switch req.Command {
case socksCmdConnect:
s.handleConnect(c, req)
case socksCmdUDPAssoc:
if s.UDPGW.Enabled {
s.handleUDPGWAssociate(c)
} else {
s.handleDNSUDPAssociate(c)
}
default:
_ = writeReply(c, socksRepUnsupported)
s.Logger.Add("debug", "socks unsupported command %d", req.Command)
}
}
func (s *SocksServer) handleConnect(c net.Conn, req socksRequest) {
dest := req.Addr()
remote, err := dialSSHDirectTCP(s.SSH, req, c.RemoteAddr(), s.Logger)
if err != nil {
_ = writeReply(c, 0x05)
s.Logger.Add("warn", "ssh dial %s failed: %v", dest, err)
return
}
defer remote.Close()
_ = writeReply(c, socksRepOK)
_ = c.SetDeadline(time.Time{})
s.Logger.Add("debug", "socks connected %s", dest)
proxyCopy(c, remote)
}
func (s *SocksServer) handleDNSUDPAssociate(c net.Conn) {
udp, err := net.ListenUDP("udp", &net.UDPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 0})
if err != nil {
_ = writeReply(c, socksRepFail)
s.Logger.Add("warn", "socks UDP associate failed: %v", err)
return
}
defer udp.Close()
if err := writeReplyWithAddr(c, socksRepOK, udp.LocalAddr().(*net.UDPAddr)); err != nil {
s.Logger.Add("debug", "socks UDP associate reply failed: %v", err)
return
}
_ = c.SetDeadline(time.Time{})
s.Logger.Add("info", "local SOCKS5 UDP associate listening on %s for DNS over SSH", udp.LocalAddr().String())
done := make(chan struct{})
go func() {
_, _ = io.Copy(io.Discard, c)
close(done)
}()
buf := make([]byte, 64*1024)
for {
_ = udp.SetReadDeadline(time.Now().Add(500 * time.Millisecond))
n, clientAddr, err := udp.ReadFromUDP(buf)
if err != nil {
select {
case <-done:
return
default:
}
if ne, ok := err.(net.Error); ok && ne.Timeout() {
continue
}
s.Logger.Add("debug", "socks UDP read failed: %v", err)
return
}
resp, err := s.handleSocksUDPDatagram(buf[:n])
if err != nil {
if shouldLogUDPError(err) {
s.Logger.Add("debug", "%v", err)
}
continue
}
_, _ = udp.WriteToUDP(resp, clientAddr)
}
}
func (s *SocksServer) handleSocksUDPDatagram(packet []byte) ([]byte, error) {
addr, payload, err := parseSocksUDP(packet)
if err != nil {
return nil, err
}
if addr.Port != 53 {
return nil, fmt.Errorf("dropping unsupported UDP target %s; only DNS/UDP port 53 is proxied for SSH modes", addr.Addr())
}
answer, err := s.resolveDNSOverSSHTCP(addr, payload)
if err != nil {
return nil, fmt.Errorf("DNS over SSH failed for %s: %w", addr.Addr(), err)
}
return buildSocksUDP(addr, answer)
}
func (s *SocksServer) resolveDNSOverSSHTCP(addr socksRequest, query []byte) ([]byte, error) {
servers := s.dnsServers(addr)
var lastErr error
for _, server := range servers {
remote, err := dialSSHDirectTCPAddr(s.SSH, server, nil, s.Logger)
if err != nil {
lastErr = err
continue
}
_ = remote.SetDeadline(time.Now().Add(8 * time.Second))
resp, err := exchangeDNSTCP(remote, query)
_ = remote.Close()
if err == nil {
return resp, nil
}
lastErr = err
}
if lastErr != nil {
return nil, lastErr
}
return nil, errors.New("no DNS server configured")
}
func (s *SocksServer) dnsServers(requested socksRequest) []string {
seen := map[string]bool{}
var out []string
add := func(v string) {
v = normalizeHostPort(v, "53")
if v == "" {
return
}
if !seen[v] {
seen[v] = true
out = append(out, v)
}
}
if requested.Host != "" {
add(net.JoinHostPort(requested.Host, strconv.Itoa(requested.Port)))
}
for _, dns := range s.DNS {
add(dns)
}
add("1.1.1.1:53")
add("8.8.8.8:53")
return out
}
func normalizeHostPort(v, defaultPort string) string {
v = strings.TrimSpace(v)
if v == "" {
return ""
}
if host, port, err := net.SplitHostPort(v); err == nil {
return net.JoinHostPort(host, port)
}
if ip := net.ParseIP(v); ip != nil {
return net.JoinHostPort(ip.String(), defaultPort)
}
if strings.Count(v, ":") == 1 {
host, port, ok := strings.Cut(v, ":")
if ok && host != "" && port != "" {
return net.JoinHostPort(host, port)
}
}
return net.JoinHostPort(v, defaultPort)
}
func exchangeDNSTCP(conn net.Conn, query []byte) ([]byte, error) {
if len(query) == 0 || len(query) > 65535 {
return nil, fmt.Errorf("invalid DNS query length %d", len(query))
}
prefix := make([]byte, 2)
binary.BigEndian.PutUint16(prefix, uint16(len(query)))
if _, err := conn.Write(append(prefix, query...)); err != nil {
return nil, err
}
if _, err := io.ReadFull(conn, prefix); err != nil {
return nil, err
}
ln := int(binary.BigEndian.Uint16(prefix))
if ln <= 0 || ln > 65535 {
return nil, fmt.Errorf("invalid DNS response length %d", ln)
}
resp := make([]byte, ln)
if _, err := io.ReadFull(conn, resp); err != nil {
return nil, err
}
return resp, nil
}
func (s *SocksServer) handshake(c net.Conn) error {
header := make([]byte, 2)
if _, err := io.ReadFull(c, header); err != nil {
return err
}
if header[0] != socksVersion {
return errors.New("not socks5")
}
methods := make([]byte, int(header[1]))
if _, err := io.ReadFull(c, methods); err != nil {
return err
}
_, err := c.Write([]byte{socksVersion, 0x00})
return err
}
func readSocksRequest(c net.Conn) (socksRequest, error) {
var req socksRequest
h := make([]byte, 4)
if _, err := io.ReadFull(c, h); err != nil {
return req, err
}
if h[0] != socksVersion {
return req, fmt.Errorf("invalid socks version %d", h[0])
}
req.Command = h[1]
host, err := readSocksHost(c, h[3])
if err != nil {
return req, err
}
pb := make([]byte, 2)
if _, err := io.ReadFull(c, pb); err != nil {
return req, err
}
req.Host = host
req.Port = int(binary.BigEndian.Uint16(pb))
return req, nil
}
func readSocksHost(r io.Reader, atyp byte) (string, error) {
switch atyp {
case socksAtypIPv4:
b := make([]byte, 4)
if _, err := io.ReadFull(r, b); err != nil {
return "", err
}
return net.IP(b).String(), nil
case socksAtypDomain:
l := []byte{0}
if _, err := io.ReadFull(r, l); err != nil {
return "", err
}
b := make([]byte, int(l[0]))
if _, err := io.ReadFull(r, b); err != nil {
return "", err
}
return string(b), nil
case socksAtypIPv6:
b := make([]byte, 16)
if _, err := io.ReadFull(r, b); err != nil {
return "", err
}
return net.IP(b).String(), nil
default:
return "", fmt.Errorf("unsupported address type %d", atyp)
}
}
func parseSocksUDP(packet []byte) (socksRequest, []byte, error) {
var req socksRequest
if len(packet) < 4 {
return req, nil, errors.New("short socks UDP packet")
}
if packet[0] != 0 || packet[1] != 0 {
return req, nil, errors.New("invalid socks UDP reserved field")
}
if packet[2] != 0 {
return req, nil, errors.New("fragmented socks UDP packets are not supported")
}
atyp := packet[3]
off := 4
switch atyp {
case socksAtypIPv4:
if len(packet) < off+4+2 {
return req, nil, errors.New("short socks UDP ipv4 packet")
}
req.Host = net.IP(packet[off : off+4]).String()
off += 4
case socksAtypDomain:
if len(packet) < off+1 {
return req, nil, errors.New("short socks UDP domain packet")
}
ln := int(packet[off])
off++
if len(packet) < off+ln+2 {
return req, nil, errors.New("short socks UDP domain payload")
}
req.Host = string(packet[off : off+ln])
off += ln
case socksAtypIPv6:
if len(packet) < off+16+2 {
return req, nil, errors.New("short socks UDP ipv6 packet")
}
req.Host = net.IP(packet[off : off+16]).String()
off += 16
default:
return req, nil, fmt.Errorf("unsupported socks UDP address type %d", atyp)
}
req.Port = int(binary.BigEndian.Uint16(packet[off : off+2]))
off += 2
return req, packet[off:], nil
}
func buildSocksUDP(addr socksRequest, payload []byte) ([]byte, error) {
var out []byte
out = append(out, 0, 0, 0)
ip := net.ParseIP(addr.Host)
if v4 := ip.To4(); v4 != nil {
out = append(out, socksAtypIPv4)
out = append(out, v4...)
} else if v6 := ip.To16(); v6 != nil {
out = append(out, socksAtypIPv6)
out = append(out, v6...)
} else {
if len(addr.Host) > 255 {
return nil, errors.New("socks UDP domain is too long")
}
out = append(out, socksAtypDomain, byte(len(addr.Host)))
out = append(out, []byte(addr.Host)...)
}
pb := make([]byte, 2)
binary.BigEndian.PutUint16(pb, uint16(addr.Port))
out = append(out, pb...)
out = append(out, payload...)
return out, nil
}
func writeReply(c net.Conn, code byte) error {
return writeReplyWithAddr(c, code, &net.UDPAddr{IP: net.IPv4zero, Port: 0})
}
func writeReplyWithAddr(c net.Conn, code byte, addr *net.UDPAddr) error {
if addr == nil {
addr = &net.UDPAddr{IP: net.IPv4zero, Port: 0}
}
ip := addr.IP
if ip == nil || ip.IsUnspecified() {
ip = net.IPv4(127, 0, 0, 1)
}
var resp []byte
resp = append(resp, socksVersion, code, 0x00)
if v4 := ip.To4(); v4 != nil {
resp = append(resp, socksAtypIPv4)
resp = append(resp, v4...)
} else if v6 := ip.To16(); v6 != nil {
resp = append(resp, socksAtypIPv6)
resp = append(resp, v6...)
} else {
resp = append(resp, socksAtypIPv4, 127, 0, 0, 1)
}
pb := make([]byte, 2)
binary.BigEndian.PutUint16(pb, uint16(addr.Port))
resp = append(resp, pb...)
_, err := c.Write(resp)
return err
}
func proxyCopy(a net.Conn, b net.Conn) {
done := make(chan struct{}, 2)
go func() { _, _ = io.Copy(a, b); done <- struct{}{} }()
go func() { _, _ = io.Copy(b, a); done <- struct{}{} }()
<-done
}
func isExpectedProbeError(err error) bool {
return errors.Is(err, io.EOF) || strings.Contains(err.Error(), "connection reset")
}
func shouldLogUDPError(err error) bool {
msg := err.Error()
return !strings.Contains(msg, "unsupported UDP target")
}

View File

@@ -0,0 +1,151 @@
package engine
import (
"fmt"
"net"
"strconv"
"strings"
"time"
"golang.org/x/crypto/ssh"
)
type directTCPIPPayload struct {
DestAddr string
DestPort uint32
OriginAddr string
OriginPort uint32
}
type sshChannelConn struct {
ssh.Channel
local net.Addr
remote net.Addr
}
func (c *sshChannelConn) LocalAddr() net.Addr {
return c.local
}
func (c *sshChannelConn) RemoteAddr() net.Addr {
return c.remote
}
func (c *sshChannelConn) SetDeadline(time.Time) error {
return nil
}
func (c *sshChannelConn) SetReadDeadline(time.Time) error {
return nil
}
func (c *sshChannelConn) SetWriteDeadline(time.Time) error {
return nil
}
func dialSSHDirectTCP(client *ssh.Client, dest socksRequest, origin net.Addr, logger *Logger) (net.Conn, error) {
addr := dest.Addr()
remote, err := client.Dial("tcp", addr)
if err == nil {
return remote, nil
}
if !shouldRetrySSHIPv6Bracket(dest.Host, err) {
return nil, err
}
if logger != nil {
logger.Add("debug", "retrying SSH direct-tcpip IPv6 target with bracketed host [%s]:%d", strings.Trim(dest.Host, "[]"), dest.Port)
}
remote, retryErr := openSSHDirectTCPIP(client, bracketIPv6Host(dest.Host), dest.Port, origin)
if retryErr == nil {
return remote, nil
}
return nil, fmt.Errorf("%w; bracketed IPv6 retry failed: %v", err, retryErr)
}
func dialSSHDirectTCPAddr(client *ssh.Client, addr string, origin net.Addr, logger *Logger) (net.Conn, error) {
host, portText, err := net.SplitHostPort(addr)
if err != nil {
return client.Dial("tcp", addr)
}
port, err := strconv.Atoi(portText)
if err != nil {
return nil, err
}
return dialSSHDirectTCP(client, socksRequest{Host: host, Port: port}, origin, logger)
}
func openSSHDirectTCPIP(client *ssh.Client, destHost string, destPort int, origin net.Addr) (net.Conn, error) {
originHost, originPort := splitOriginAddr(origin)
payload := ssh.Marshal(directTCPIPPayload{
DestAddr: destHost,
DestPort: uint32(destPort),
OriginAddr: originHost,
OriginPort: uint32(originPort),
})
ch, reqs, err := client.OpenChannel("direct-tcpip", payload)
if err != nil {
return nil, err
}
go ssh.DiscardRequests(reqs)
return &sshChannelConn{
Channel: ch,
local: tcpAddrOrDummy(originHost, originPort),
remote: tcpAddrOrDummy(destHost, destPort),
}, nil
}
func shouldRetrySSHIPv6Bracket(host string, err error) bool {
if err == nil || !isIPv6Literal(host) {
return false
}
msg := strings.ToLower(err.Error())
return strings.Contains(msg, "too many colons in address") || strings.Contains(msg, "missing port in address")
}
func isIPv6Literal(host string) bool {
host = strings.Trim(host, "[]")
ip := net.ParseIP(host)
return ip != nil && ip.To4() == nil && ip.To16() != nil
}
func bracketIPv6Host(host string) string {
host = strings.TrimSpace(host)
if strings.HasPrefix(host, "[") && strings.HasSuffix(host, "]") {
return host
}
return "[" + strings.Trim(host, "[]") + "]"
}
func splitOriginAddr(addr net.Addr) (string, int) {
if addr == nil {
return "127.0.0.1", 0
}
host, portText, err := net.SplitHostPort(addr.String())
if err != nil {
return "127.0.0.1", 0
}
port, err := strconv.Atoi(portText)
if err != nil {
port = 0
}
if host == "" || host == "::" || host == "0.0.0.0" {
host = "127.0.0.1"
}
return host, port
}
func tcpAddrOrDummy(host string, port int) net.Addr {
h := strings.Trim(host, "[]")
ip := net.ParseIP(h)
if ip != nil {
return &net.TCPAddr{IP: ip, Port: port}
}
return dummyAddr(net.JoinHostPort(h, strconv.Itoa(port)))
}
type dummyAddr string
func (d dummyAddr) Network() string { return "tcp" }
func (d dummyAddr) String() string { return string(d) }

View File

@@ -0,0 +1,345 @@
package engine
import (
"context"
"crypto/tls"
"fmt"
"net"
"strconv"
"strings"
"time"
"golang.org/x/crypto/ssh"
"socksrevivepc/internal/config"
)
type sshBundle struct {
Client *ssh.Client
Conn net.Conn
ControlHosts []string
}
type transportAttempt struct {
Label string
ProxyHost string
ProxyPort int
TLSHost string
TLSPort int
ControlHost string
}
func connectSSH(ctx context.Context, p config.Profile, logger *Logger) (*sshBundle, error) {
targetHost := p.SSH.Host
targetPort := p.SSH.Port
if p.Mode == config.ModeDNSTT {
targetHost = p.DNSTT.LocalSSHHost
targetPort = p.DNSTT.LocalSSHPort
}
attempts := buildTransportAttempts(p, targetHost, targetPort)
if len(attempts) == 0 {
attempts = []transportAttempt{{Label: "default"}}
}
logger.Add("info", "connecting SSH %s:%d using mode %s", targetHost, targetPort, p.Mode)
var lastErr error
for i, attempt := range attempts {
if ctx.Err() != nil {
return nil, ctx.Err()
}
if len(attempts) > 1 {
logger.Add("info", "connection attempt %d/%d via %s", i+1, len(attempts), attempt.Label)
}
conn, err := dialTransportAttempt(ctx, p, targetHost, targetPort, attempt, logger)
if err != nil {
lastErr = err
if len(attempts) > 1 {
logger.Add("warn", "connection attempt %d/%d failed before SSH handshake: %v", i+1, len(attempts), err)
}
continue
}
bundle, err := finishSSHHandshake(conn, p, targetHost, targetPort, logger)
if err == nil {
bundle.ControlHosts = controlHostsForAttempt(p, targetHost, attempt)
if len(attempts) > 1 {
logger.Add("info", "connection attempt %d/%d succeeded via %s", i+1, len(attempts), attempt.Label)
}
return bundle, nil
}
_ = conn.Close()
lastErr = err
if len(attempts) > 1 {
logger.Add("warn", "connection attempt %d/%d failed during SSH handshake: %v", i+1, len(attempts), err)
}
}
if lastErr == nil {
lastErr = fmt.Errorf("no transport attempts were available")
}
if len(attempts) > 1 {
return nil, fmt.Errorf("all %d connection attempts failed; last error: %w", len(attempts), lastErr)
}
return nil, lastErr
}
func finishSSHHandshake(conn net.Conn, p config.Profile, targetHost string, targetPort int, logger *Logger) (*sshBundle, error) {
sshCfg := &ssh.ClientConfig{
User: p.SSH.Username,
Auth: []ssh.AuthMethod{ssh.Password(p.SSH.Password)},
HostKeyCallback: ssh.InsecureIgnoreHostKey(),
Timeout: time.Duration(p.SSH.HandshakeTimeoutMs) * time.Millisecond,
ClientVersion: "SSH-2.0-SocksRevivePC",
}
addr := net.JoinHostPort(targetHost, fmt.Sprint(targetPort))
cc, chans, reqs, err := ssh.NewClientConn(conn, addr, sshCfg)
if err != nil {
return nil, fmt.Errorf("ssh handshake failed: %w", err)
}
client := ssh.NewClient(cc, chans, reqs)
logger.Add("info", "ssh authenticated as %s", p.SSH.Username)
return &sshBundle{Client: client, Conn: conn}, nil
}
func buildTransportAttempts(p config.Profile, targetHost string, targetPort int) []transportAttempt {
switch p.Mode {
case config.ModePayload:
return payloadProxyAttempts(p, targetHost, targetPort)
case config.ModeSSL, config.ModePayloadSSL:
return tlsHostAttempts(p, targetHost, targetPort)
default:
return []transportAttempt{{Label: net.JoinHostPort(targetHost, fmt.Sprint(targetPort))}}
}
}
func controlHostsForAttempt(p config.Profile, targetHost string, attempt transportAttempt) []string {
switch p.Mode {
case config.ModePayload:
if attempt.ProxyHost != "" {
return []string{attempt.ProxyHost}
}
return []string{targetHost}
case config.ModeSSL, config.ModePayloadSSL:
if attempt.TLSHost != "" {
return []string{attempt.TLSHost}
}
if p.TLS.Host != "" {
host, _ := hostPortWithDefault(p.TLS.Host, p.TLS.Port)
if host != "" {
return []string{host}
}
}
return []string{targetHost}
default:
return []string{targetHost}
}
}
func payloadProxyAttempts(p config.Profile, targetHost string, targetPort int) []transportAttempt {
proxyText := strings.TrimSpace(p.Proxy.Host)
if proxyText == "" || p.Proxy.Port <= 0 {
return []transportAttempt{{Label: "direct payload transport"}}
}
rawHosts := splitHostList(proxyText)
if len(rawHosts) == 0 {
return []transportAttempt{{Label: "direct payload transport"}}
}
out := make([]transportAttempt, 0, len(rawHosts))
for _, raw := range rawHosts {
host, port := hostPortWithDefault(raw, p.Proxy.Port)
if host == "" || port <= 0 {
continue
}
out = append(out, transportAttempt{ProxyHost: host, ProxyPort: port, Label: "proxy " + net.JoinHostPort(host, fmt.Sprint(port))})
}
if len(out) == 0 {
return []transportAttempt{{Label: net.JoinHostPort(targetHost, fmt.Sprint(targetPort))}}
}
return out
}
func tlsHostAttempts(p config.Profile, targetHost string, targetPort int) []transportAttempt {
hostText := strings.TrimSpace(p.TLS.Host)
defaultPort := p.TLS.Port
if defaultPort <= 0 {
defaultPort = targetPort
}
if hostText == "" {
return []transportAttempt{{TLSHost: targetHost, TLSPort: defaultPort, Label: "TLS " + net.JoinHostPort(targetHost, fmt.Sprint(defaultPort))}}
}
rawHosts := splitHostList(hostText)
if len(rawHosts) == 0 {
return []transportAttempt{{TLSHost: targetHost, TLSPort: defaultPort, Label: "TLS " + net.JoinHostPort(targetHost, fmt.Sprint(defaultPort))}}
}
out := make([]transportAttempt, 0, len(rawHosts))
for _, raw := range rawHosts {
host, port := hostPortWithDefault(raw, defaultPort)
if host == "" || port <= 0 {
continue
}
out = append(out, transportAttempt{TLSHost: host, TLSPort: port, Label: "TLS " + net.JoinHostPort(host, fmt.Sprint(port))})
}
if len(out) == 0 {
return []transportAttempt{{TLSHost: targetHost, TLSPort: defaultPort, Label: "TLS " + net.JoinHostPort(targetHost, fmt.Sprint(defaultPort))}}
}
return out
}
func dialTransport(ctx context.Context, p config.Profile, targetHost string, targetPort int, logger *Logger) (net.Conn, error) {
attempts := buildTransportAttempts(p, targetHost, targetPort)
if len(attempts) == 0 {
attempts = []transportAttempt{{Label: "default"}}
}
return dialTransportAttempt(ctx, p, targetHost, targetPort, attempts[0], logger)
}
func dialTransportAttempt(ctx context.Context, p config.Profile, targetHost string, targetPort int, attempt transportAttempt, logger *Logger) (net.Conn, error) {
d := &net.Dialer{Timeout: time.Duration(p.SSH.HandshakeTimeoutMs) * time.Millisecond}
addr := net.JoinHostPort(targetHost, fmt.Sprint(targetPort))
switch p.Mode {
case config.ModeDirect, config.ModeDNSTT:
return d.DialContext(ctx, "tcp", addr)
case config.ModeSSL:
return dialTLSAttempt(ctx, d, p, targetHost, targetPort, attempt)
case config.ModePayload:
connectHost, connectPort := targetHost, targetPort
if attempt.ProxyHost != "" && attempt.ProxyPort > 0 {
connectHost, connectPort = attempt.ProxyHost, attempt.ProxyPort
} else if p.Proxy.Host != "" && p.Proxy.Port > 0 {
connectHost, connectPort = p.Proxy.Host, p.Proxy.Port
}
logger.Add("debug", "payload transport dialing %s", net.JoinHostPort(connectHost, fmt.Sprint(connectPort)))
conn, err := d.DialContext(ctx, "tcp", net.JoinHostPort(connectHost, fmt.Sprint(connectPort)))
if err != nil {
return nil, err
}
if _, wrappedConn, err := WritePayload(conn, p, targetHost, targetPort, logger); err != nil {
_ = conn.Close()
return nil, err
} else {
conn = wrappedConn
}
return conn, nil
case config.ModePayloadSSL:
conn, err := dialTLSAttempt(ctx, d, p, targetHost, targetPort, attempt)
if err != nil {
return nil, err
}
if _, wrappedConn, err := WritePayload(conn, p, targetHost, targetPort, logger); err != nil {
_ = conn.Close()
return nil, err
} else {
conn = wrappedConn
}
return conn, nil
default:
return nil, fmt.Errorf("unsupported mode %s", p.Mode)
}
}
func dialTLS(ctx context.Context, d *net.Dialer, p config.Profile, targetHost string, targetPort int) (net.Conn, error) {
return dialTLSAttempt(ctx, d, p, targetHost, targetPort, transportAttempt{})
}
func dialTLSAttempt(ctx context.Context, d *net.Dialer, p config.Profile, targetHost string, targetPort int, attempt transportAttempt) (net.Conn, error) {
host := strings.TrimSpace(attempt.TLSHost)
port := attempt.TLSPort
if host == "" {
host = p.TLS.Host
}
if port <= 0 {
port = p.TLS.Port
}
if host == "" {
host = targetHost
}
if port == 0 {
port = targetPort
}
serverName := p.TLS.ServerName
if serverName == "" {
serverName = host
}
raw, err := d.DialContext(ctx, "tcp", net.JoinHostPort(host, fmt.Sprint(port)))
if err != nil {
return nil, err
}
cfg := &tls.Config{ServerName: serverName, InsecureSkipVerify: p.TLS.InsecureSkipVerify, MinVersion: tls.VersionTLS12}
conn := tls.Client(raw, cfg)
if err := conn.HandshakeContext(ctx); err != nil {
_ = conn.Close()
return nil, err
}
return conn, nil
}
func splitHostList(s string) []string {
s = strings.TrimSpace(s)
if s == "" {
return nil
}
parts := strings.FieldsFunc(s, func(r rune) bool {
switch r {
case '#', ',', ';', '\n', '\r', '\t':
return true
default:
return false
}
})
out := make([]string, 0, len(parts))
seen := map[string]struct{}{}
for _, part := range parts {
part = strings.TrimSpace(part)
if part == "" {
continue
}
if _, ok := seen[part]; ok {
continue
}
seen[part] = struct{}{}
out = append(out, part)
}
return out
}
func hostPortWithDefault(raw string, defaultPort int) (string, int) {
raw = strings.TrimSpace(raw)
if raw == "" {
return "", defaultPort
}
if strings.Contains(raw, "://") {
// These fields are host fields, not URLs. Strip the scheme if a user pastes one.
if i := strings.Index(raw, "://"); i >= 0 {
raw = raw[i+3:]
}
}
if h, p, err := net.SplitHostPort(raw); err == nil {
pi, _ := strconv.Atoi(p)
if pi > 0 {
return strings.Trim(h, "[]"), pi
}
}
if strings.HasPrefix(raw, "[") && strings.Contains(raw, "]:") {
if h, p, err := net.SplitHostPort(raw); err == nil {
pi, _ := strconv.Atoi(p)
if pi > 0 {
return strings.Trim(h, "[]"), pi
}
}
}
// host:port for IPv4/domain. IPv6 without brackets has multiple colons and
// must keep the profile/default port.
if strings.Count(raw, ":") == 1 {
host, portText, ok := strings.Cut(raw, ":")
if ok {
if pi, err := strconv.Atoi(strings.TrimSpace(portText)); err == nil && pi > 0 {
return strings.TrimSpace(host), pi
}
}
}
return strings.Trim(raw, "[]"), defaultPort
}

View File

@@ -0,0 +1,353 @@
package engine
import (
"bufio"
"encoding/binary"
"errors"
"fmt"
"io"
"net"
"strconv"
"strings"
"sync"
"time"
)
const (
udpgwMaxFrame = 64 * 1024
udpgwProtocolBadVPN = "badvpn"
udpgwProtocolLegacy = "legacy"
udpgwFlagKeepAlive = 1 << 0
udpgwFlagRebind = 1 << 1
udpgwFlagDNS = 1 << 2
udpgwFlagIPv6 = 1 << 3
)
type udpgwSession struct {
mu sync.Mutex
nextID uint16
pairID map[string]uint16
idClient map[uint16]*net.UDPAddr
}
func (s *SocksServer) handleUDPGWAssociate(c net.Conn) {
udp, err := net.ListenUDP("udp", &net.UDPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 0})
if err != nil {
_ = writeReply(c, socksRepFail)
s.Logger.Add("warn", "socks UDP associate failed: %v", err)
return
}
defer udp.Close()
gwAddr := net.JoinHostPort(s.UDPGW.Host, strconv.Itoa(s.UDPGW.Port))
gwConn, err := dialSSHDirectTCPAddr(s.SSH, gwAddr, c.RemoteAddr(), s.Logger)
if err != nil {
_ = writeReply(c, socksRepFail)
s.Logger.Add("warn", "udpgw dial through SSH failed %s: %v", gwAddr, err)
return
}
defer gwConn.Close()
if err := writeReplyWithAddr(c, socksRepOK, udp.LocalAddr().(*net.UDPAddr)); err != nil {
s.Logger.Add("debug", "socks UDP associate reply failed: %v", err)
return
}
_ = c.SetDeadline(time.Time{})
proto := normalizeUDPGWProtocol(s.UDPGW.Protocol)
s.Logger.Add("info", "SOCKS5 UDP associate using UDPGW %s over SSH; protocol=%s local UDP=%s", gwAddr, proto, udp.LocalAddr().String())
done := make(chan struct{})
closeOnce := sync.Once{}
stop := func() {
closeOnce.Do(func() {
_ = c.Close()
_ = gwConn.Close()
_ = udp.Close()
close(done)
})
}
go func() {
_, _ = io.Copy(io.Discard, c)
stop()
}()
sess := &udpgwSession{
nextID: 1,
pairID: make(map[string]uint16),
idClient: make(map[uint16]*net.UDPAddr),
}
go s.readUDPGWReplies(done, proto, gwConn, udp, sess)
s.forwardUDPToUDPGW(done, proto, gwConn, udp, sess)
stop()
}
func normalizeUDPGWProtocol(v string) string {
switch strings.ToLower(strings.TrimSpace(v)) {
case udpgwProtocolLegacy:
return udpgwProtocolLegacy
default:
return udpgwProtocolBadVPN
}
}
func (s *SocksServer) forwardUDPToUDPGW(done <-chan struct{}, proto string, gwConn net.Conn, udp *net.UDPConn, sess *udpgwSession) {
buf := make([]byte, 64*1024)
for {
_ = udp.SetReadDeadline(time.Now().Add(500 * time.Millisecond))
n, clientAddr, err := udp.ReadFromUDP(buf)
if err != nil {
select {
case <-done:
return
default:
}
if ne, ok := err.(net.Error); ok && ne.Timeout() {
continue
}
s.Logger.Add("debug", "udpgw local UDP read failed: %v", err)
return
}
addr, payload, err := parseSocksUDP(buf[:n])
if err != nil {
if shouldLogUDPError(err) {
s.Logger.Add("debug", "udpgw parse SOCKS UDP failed: %v", err)
}
continue
}
connID, isNew, err := sess.idFor(clientAddr, addr)
if err != nil {
s.Logger.Add("debug", "udpgw session id failed for %s: %v", addr.Addr(), err)
continue
}
frame, err := udpgwBuildClientFrame(proto, connID, isNew, addr, payload)
if err != nil {
if shouldLogUDPError(err) {
s.Logger.Add("debug", "udpgw frame build failed for %s: %v", addr.Addr(), err)
}
continue
}
_ = gwConn.SetWriteDeadline(time.Now().Add(15 * time.Second))
if _, err := gwConn.Write(frame); err != nil {
s.Logger.Add("warn", "udpgw TCP write failed: %v", err)
return
}
}
}
func (s *SocksServer) readUDPGWReplies(done <-chan struct{}, proto string, gwConn net.Conn, udp *net.UDPConn, sess *udpgwSession) {
br := bufio.NewReaderSize(gwConn, 256*1024)
for {
select {
case <-done:
return
default:
}
payload, err := udpgwReadFrame(br, udpgwMaxFrame)
if err != nil {
select {
case <-done:
return
default:
}
if !errors.Is(err, io.EOF) {
s.Logger.Add("debug", "udpgw TCP read failed: %v", err)
}
return
}
connID, src, data, err := udpgwParseReplyPayload(proto, payload)
if err != nil {
s.Logger.Add("debug", "udpgw bad reply frame: %v", err)
continue
}
clientAddr := sess.clientFor(connID)
if clientAddr == nil {
continue
}
packet, err := buildSocksUDP(src, data)
if err != nil {
s.Logger.Add("debug", "udpgw SOCKS UDP reply build failed: %v", err)
continue
}
_, _ = udp.WriteToUDP(packet, clientAddr)
}
}
func (s *udpgwSession) idFor(client *net.UDPAddr, dest socksRequest) (uint16, bool, error) {
if client == nil {
return 0, false, errors.New("missing local UDP client address")
}
key := client.String() + "|" + dest.Addr()
s.mu.Lock()
defer s.mu.Unlock()
if id, ok := s.pairID[key]; ok {
s.idClient[id] = cloneUDPAddr(client)
return id, false, nil
}
id := s.nextID
if id == 0 {
id = 1
}
s.nextID = id + 1
if s.nextID == 0 {
s.nextID = 1
}
s.pairID[key] = id
s.idClient[id] = cloneUDPAddr(client)
return id, true, nil
}
func (s *udpgwSession) clientFor(id uint16) *net.UDPAddr {
s.mu.Lock()
defer s.mu.Unlock()
return cloneUDPAddr(s.idClient[id])
}
func cloneUDPAddr(a *net.UDPAddr) *net.UDPAddr {
if a == nil {
return nil
}
out := *a
if a.IP != nil {
out.IP = append(net.IP(nil), a.IP...)
}
return &out
}
func udpgwReadFrame(r *bufio.Reader, max int) ([]byte, error) {
var lenBuf [2]byte
if _, err := io.ReadFull(r, lenBuf[:]); err != nil {
return nil, err
}
n := int(binary.LittleEndian.Uint16(lenBuf[:]))
if n <= 0 || n > max {
return nil, fmt.Errorf("invalid udpgw frame length %d", n)
}
b := make([]byte, n)
if _, err := io.ReadFull(r, b); err != nil {
return nil, err
}
return b, nil
}
func udpgwBuildClientFrame(proto string, connID uint16, isNew bool, dest socksRequest, data []byte) ([]byte, error) {
if normalizeUDPGWProtocol(proto) == udpgwProtocolLegacy {
return udpgwBuildLegacyFrame(connID, 0, dest, data)
}
return udpgwBuildBadVPNFrame(connID, isNew, dest, data)
}
func udpgwParseReplyPayload(proto string, payload []byte) (uint16, socksRequest, []byte, error) {
if normalizeUDPGWProtocol(proto) == udpgwProtocolLegacy {
return udpgwParseLegacyReplyPayload(payload)
}
return udpgwParseBadVPNReplyPayload(payload)
}
// udpgwBuildBadVPNFrame implements the normal badvpn-udpgw PacketProto frame:
// length(little-endian uint16) + flags(1) + conid(little-endian uint16) +
// IPv4/IPv6 destination + destination port(network byte order) + UDP payload.
// This is the same framing used by the Android badvpn UDPGW client and supports
// IPv6 targets with UDPGW_CLIENT_FLAG_IPV6.
func udpgwBuildBadVPNFrame(connID uint16, isNew bool, dest socksRequest, data []byte) ([]byte, error) {
ip := net.ParseIP(dest.Host)
if ip == nil {
return nil, fmt.Errorf("badvpn UDPGW requires an IP target, got %q", dest.Host)
}
flags := byte(0)
if isNew {
flags |= udpgwFlagRebind
}
var addr []byte
if v4 := ip.To4(); v4 != nil {
addr = append(addr, v4...)
} else if v6 := ip.To16(); v6 != nil {
flags |= udpgwFlagIPv6
addr = append(addr, v6...)
} else {
return nil, fmt.Errorf("invalid UDPGW target IP %q", dest.Host)
}
payloadLen := 1 + 2 + len(addr) + 2 + len(data)
if payloadLen <= 0 || payloadLen > 65535 {
return nil, fmt.Errorf("UDPGW payload too large: %d", payloadLen)
}
out := make([]byte, 2+payloadLen)
binary.LittleEndian.PutUint16(out[0:2], uint16(payloadLen))
out[2] = flags
binary.LittleEndian.PutUint16(out[3:5], connID)
copy(out[5:5+len(addr)], addr)
portOff := 5 + len(addr)
binary.BigEndian.PutUint16(out[portOff:portOff+2], uint16(dest.Port))
copy(out[portOff+2:], data)
return out, nil
}
func udpgwParseBadVPNReplyPayload(payload []byte) (uint16, socksRequest, []byte, error) {
var src socksRequest
if len(payload) < 1+2+4+2 {
return 0, src, nil, fmt.Errorf("short badvpn udpgw payload %d", len(payload))
}
flags := payload[0]
if flags&udpgwFlagKeepAlive != 0 {
return 0, src, nil, errors.New("unexpected udpgw keepalive reply")
}
connID := binary.LittleEndian.Uint16(payload[1:3])
off := 3
if flags&udpgwFlagIPv6 != 0 {
if len(payload) < off+16+2 {
return 0, src, nil, fmt.Errorf("short badvpn udpgw ipv6 payload %d", len(payload))
}
src.Host = net.IP(payload[off : off+16]).String()
off += 16
} else {
if len(payload) < off+4+2 {
return 0, src, nil, fmt.Errorf("short badvpn udpgw ipv4 payload %d", len(payload))
}
src.Host = net.IP(payload[off : off+4]).String()
off += 4
}
src.Port = int(binary.BigEndian.Uint16(payload[off : off+2]))
off += 2
return connID, src, payload[off:], nil
}
// Legacy frame kept only for older experimental PC builds. It is IPv4-only.
func udpgwBuildLegacyFrame(connID uint16, x byte, dest socksRequest, data []byte) ([]byte, error) {
ip := net.ParseIP(dest.Host)
v4 := ip.To4()
if v4 == nil {
return nil, fmt.Errorf("legacy UDPGW only supports IPv4 UDP targets; got %s", dest.Addr())
}
payloadLen := 2 + 1 + 4 + 2 + len(data)
if payloadLen <= 0 || payloadLen > 65535 {
return nil, fmt.Errorf("UDPGW payload too large: %d", payloadLen)
}
out := make([]byte, 2+payloadLen)
binary.LittleEndian.PutUint16(out[0:2], uint16(payloadLen))
binary.BigEndian.PutUint16(out[2:4], connID)
out[4] = x
copy(out[5:9], v4)
binary.BigEndian.PutUint16(out[9:11], uint16(dest.Port))
copy(out[11:], data)
return out, nil
}
func udpgwParseLegacyReplyPayload(payload []byte) (uint16, socksRequest, []byte, error) {
var src socksRequest
if len(payload) < 2+1+4+2 {
return 0, src, nil, fmt.Errorf("short legacy udpgw payload %d", len(payload))
}
connID := binary.BigEndian.Uint16(payload[0:2])
src.Host = net.IP(payload[3:7]).String()
src.Port = int(binary.BigEndian.Uint16(payload[7:9]))
return connID, src, payload[9:], nil
}