This commit is contained in:
2026-05-16 00:18:06 -03:00
commit 92941e68a2
66 changed files with 10352 additions and 0 deletions

View File

@@ -0,0 +1,345 @@
package engine
import (
"context"
"crypto/tls"
"fmt"
"net"
"strconv"
"strings"
"time"
"golang.org/x/crypto/ssh"
"socksrevivepc/internal/config"
)
type sshBundle struct {
Client *ssh.Client
Conn net.Conn
ControlHosts []string
}
type transportAttempt struct {
Label string
ProxyHost string
ProxyPort int
TLSHost string
TLSPort int
ControlHost string
}
func connectSSH(ctx context.Context, p config.Profile, logger *Logger) (*sshBundle, error) {
targetHost := p.SSH.Host
targetPort := p.SSH.Port
if p.Mode == config.ModeDNSTT {
targetHost = p.DNSTT.LocalSSHHost
targetPort = p.DNSTT.LocalSSHPort
}
attempts := buildTransportAttempts(p, targetHost, targetPort)
if len(attempts) == 0 {
attempts = []transportAttempt{{Label: "default"}}
}
logger.Add("info", "connecting SSH %s:%d using mode %s", targetHost, targetPort, p.Mode)
var lastErr error
for i, attempt := range attempts {
if ctx.Err() != nil {
return nil, ctx.Err()
}
if len(attempts) > 1 {
logger.Add("info", "connection attempt %d/%d via %s", i+1, len(attempts), attempt.Label)
}
conn, err := dialTransportAttempt(ctx, p, targetHost, targetPort, attempt, logger)
if err != nil {
lastErr = err
if len(attempts) > 1 {
logger.Add("warn", "connection attempt %d/%d failed before SSH handshake: %v", i+1, len(attempts), err)
}
continue
}
bundle, err := finishSSHHandshake(conn, p, targetHost, targetPort, logger)
if err == nil {
bundle.ControlHosts = controlHostsForAttempt(p, targetHost, attempt)
if len(attempts) > 1 {
logger.Add("info", "connection attempt %d/%d succeeded via %s", i+1, len(attempts), attempt.Label)
}
return bundle, nil
}
_ = conn.Close()
lastErr = err
if len(attempts) > 1 {
logger.Add("warn", "connection attempt %d/%d failed during SSH handshake: %v", i+1, len(attempts), err)
}
}
if lastErr == nil {
lastErr = fmt.Errorf("no transport attempts were available")
}
if len(attempts) > 1 {
return nil, fmt.Errorf("all %d connection attempts failed; last error: %w", len(attempts), lastErr)
}
return nil, lastErr
}
func finishSSHHandshake(conn net.Conn, p config.Profile, targetHost string, targetPort int, logger *Logger) (*sshBundle, error) {
sshCfg := &ssh.ClientConfig{
User: p.SSH.Username,
Auth: []ssh.AuthMethod{ssh.Password(p.SSH.Password)},
HostKeyCallback: ssh.InsecureIgnoreHostKey(),
Timeout: time.Duration(p.SSH.HandshakeTimeoutMs) * time.Millisecond,
ClientVersion: "SSH-2.0-SocksRevivePC",
}
addr := net.JoinHostPort(targetHost, fmt.Sprint(targetPort))
cc, chans, reqs, err := ssh.NewClientConn(conn, addr, sshCfg)
if err != nil {
return nil, fmt.Errorf("ssh handshake failed: %w", err)
}
client := ssh.NewClient(cc, chans, reqs)
logger.Add("info", "ssh authenticated as %s", p.SSH.Username)
return &sshBundle{Client: client, Conn: conn}, nil
}
func buildTransportAttempts(p config.Profile, targetHost string, targetPort int) []transportAttempt {
switch p.Mode {
case config.ModePayload:
return payloadProxyAttempts(p, targetHost, targetPort)
case config.ModeSSL, config.ModePayloadSSL:
return tlsHostAttempts(p, targetHost, targetPort)
default:
return []transportAttempt{{Label: net.JoinHostPort(targetHost, fmt.Sprint(targetPort))}}
}
}
func controlHostsForAttempt(p config.Profile, targetHost string, attempt transportAttempt) []string {
switch p.Mode {
case config.ModePayload:
if attempt.ProxyHost != "" {
return []string{attempt.ProxyHost}
}
return []string{targetHost}
case config.ModeSSL, config.ModePayloadSSL:
if attempt.TLSHost != "" {
return []string{attempt.TLSHost}
}
if p.TLS.Host != "" {
host, _ := hostPortWithDefault(p.TLS.Host, p.TLS.Port)
if host != "" {
return []string{host}
}
}
return []string{targetHost}
default:
return []string{targetHost}
}
}
func payloadProxyAttempts(p config.Profile, targetHost string, targetPort int) []transportAttempt {
proxyText := strings.TrimSpace(p.Proxy.Host)
if proxyText == "" || p.Proxy.Port <= 0 {
return []transportAttempt{{Label: "direct payload transport"}}
}
rawHosts := splitHostList(proxyText)
if len(rawHosts) == 0 {
return []transportAttempt{{Label: "direct payload transport"}}
}
out := make([]transportAttempt, 0, len(rawHosts))
for _, raw := range rawHosts {
host, port := hostPortWithDefault(raw, p.Proxy.Port)
if host == "" || port <= 0 {
continue
}
out = append(out, transportAttempt{ProxyHost: host, ProxyPort: port, Label: "proxy " + net.JoinHostPort(host, fmt.Sprint(port))})
}
if len(out) == 0 {
return []transportAttempt{{Label: net.JoinHostPort(targetHost, fmt.Sprint(targetPort))}}
}
return out
}
func tlsHostAttempts(p config.Profile, targetHost string, targetPort int) []transportAttempt {
hostText := strings.TrimSpace(p.TLS.Host)
defaultPort := p.TLS.Port
if defaultPort <= 0 {
defaultPort = targetPort
}
if hostText == "" {
return []transportAttempt{{TLSHost: targetHost, TLSPort: defaultPort, Label: "TLS " + net.JoinHostPort(targetHost, fmt.Sprint(defaultPort))}}
}
rawHosts := splitHostList(hostText)
if len(rawHosts) == 0 {
return []transportAttempt{{TLSHost: targetHost, TLSPort: defaultPort, Label: "TLS " + net.JoinHostPort(targetHost, fmt.Sprint(defaultPort))}}
}
out := make([]transportAttempt, 0, len(rawHosts))
for _, raw := range rawHosts {
host, port := hostPortWithDefault(raw, defaultPort)
if host == "" || port <= 0 {
continue
}
out = append(out, transportAttempt{TLSHost: host, TLSPort: port, Label: "TLS " + net.JoinHostPort(host, fmt.Sprint(port))})
}
if len(out) == 0 {
return []transportAttempt{{TLSHost: targetHost, TLSPort: defaultPort, Label: "TLS " + net.JoinHostPort(targetHost, fmt.Sprint(defaultPort))}}
}
return out
}
func dialTransport(ctx context.Context, p config.Profile, targetHost string, targetPort int, logger *Logger) (net.Conn, error) {
attempts := buildTransportAttempts(p, targetHost, targetPort)
if len(attempts) == 0 {
attempts = []transportAttempt{{Label: "default"}}
}
return dialTransportAttempt(ctx, p, targetHost, targetPort, attempts[0], logger)
}
func dialTransportAttempt(ctx context.Context, p config.Profile, targetHost string, targetPort int, attempt transportAttempt, logger *Logger) (net.Conn, error) {
d := &net.Dialer{Timeout: time.Duration(p.SSH.HandshakeTimeoutMs) * time.Millisecond}
addr := net.JoinHostPort(targetHost, fmt.Sprint(targetPort))
switch p.Mode {
case config.ModeDirect, config.ModeDNSTT:
return d.DialContext(ctx, "tcp", addr)
case config.ModeSSL:
return dialTLSAttempt(ctx, d, p, targetHost, targetPort, attempt)
case config.ModePayload:
connectHost, connectPort := targetHost, targetPort
if attempt.ProxyHost != "" && attempt.ProxyPort > 0 {
connectHost, connectPort = attempt.ProxyHost, attempt.ProxyPort
} else if p.Proxy.Host != "" && p.Proxy.Port > 0 {
connectHost, connectPort = p.Proxy.Host, p.Proxy.Port
}
logger.Add("debug", "payload transport dialing %s", net.JoinHostPort(connectHost, fmt.Sprint(connectPort)))
conn, err := d.DialContext(ctx, "tcp", net.JoinHostPort(connectHost, fmt.Sprint(connectPort)))
if err != nil {
return nil, err
}
if _, wrappedConn, err := WritePayload(conn, p, targetHost, targetPort, logger); err != nil {
_ = conn.Close()
return nil, err
} else {
conn = wrappedConn
}
return conn, nil
case config.ModePayloadSSL:
conn, err := dialTLSAttempt(ctx, d, p, targetHost, targetPort, attempt)
if err != nil {
return nil, err
}
if _, wrappedConn, err := WritePayload(conn, p, targetHost, targetPort, logger); err != nil {
_ = conn.Close()
return nil, err
} else {
conn = wrappedConn
}
return conn, nil
default:
return nil, fmt.Errorf("unsupported mode %s", p.Mode)
}
}
func dialTLS(ctx context.Context, d *net.Dialer, p config.Profile, targetHost string, targetPort int) (net.Conn, error) {
return dialTLSAttempt(ctx, d, p, targetHost, targetPort, transportAttempt{})
}
func dialTLSAttempt(ctx context.Context, d *net.Dialer, p config.Profile, targetHost string, targetPort int, attempt transportAttempt) (net.Conn, error) {
host := strings.TrimSpace(attempt.TLSHost)
port := attempt.TLSPort
if host == "" {
host = p.TLS.Host
}
if port <= 0 {
port = p.TLS.Port
}
if host == "" {
host = targetHost
}
if port == 0 {
port = targetPort
}
serverName := p.TLS.ServerName
if serverName == "" {
serverName = host
}
raw, err := d.DialContext(ctx, "tcp", net.JoinHostPort(host, fmt.Sprint(port)))
if err != nil {
return nil, err
}
cfg := &tls.Config{ServerName: serverName, InsecureSkipVerify: p.TLS.InsecureSkipVerify, MinVersion: tls.VersionTLS12}
conn := tls.Client(raw, cfg)
if err := conn.HandshakeContext(ctx); err != nil {
_ = conn.Close()
return nil, err
}
return conn, nil
}
func splitHostList(s string) []string {
s = strings.TrimSpace(s)
if s == "" {
return nil
}
parts := strings.FieldsFunc(s, func(r rune) bool {
switch r {
case '#', ',', ';', '\n', '\r', '\t':
return true
default:
return false
}
})
out := make([]string, 0, len(parts))
seen := map[string]struct{}{}
for _, part := range parts {
part = strings.TrimSpace(part)
if part == "" {
continue
}
if _, ok := seen[part]; ok {
continue
}
seen[part] = struct{}{}
out = append(out, part)
}
return out
}
func hostPortWithDefault(raw string, defaultPort int) (string, int) {
raw = strings.TrimSpace(raw)
if raw == "" {
return "", defaultPort
}
if strings.Contains(raw, "://") {
// These fields are host fields, not URLs. Strip the scheme if a user pastes one.
if i := strings.Index(raw, "://"); i >= 0 {
raw = raw[i+3:]
}
}
if h, p, err := net.SplitHostPort(raw); err == nil {
pi, _ := strconv.Atoi(p)
if pi > 0 {
return strings.Trim(h, "[]"), pi
}
}
if strings.HasPrefix(raw, "[") && strings.Contains(raw, "]:") {
if h, p, err := net.SplitHostPort(raw); err == nil {
pi, _ := strconv.Atoi(p)
if pi > 0 {
return strings.Trim(h, "[]"), pi
}
}
}
// host:port for IPv4/domain. IPv6 without brackets has multiple colons and
// must keep the profile/default port.
if strings.Count(raw, ":") == 1 {
host, portText, ok := strings.Cut(raw, ":")
if ok {
if pi, err := strconv.Atoi(strings.TrimSpace(portText)); err == nil && pi > 0 {
return strings.TrimSpace(host), pi
}
}
}
return strings.Trim(raw, "[]"), defaultPort
}