Launch
This commit is contained in:
345
internal/engine/sshclient.go
Normal file
345
internal/engine/sshclient.go
Normal file
@@ -0,0 +1,345 @@
|
||||
package engine
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/tls"
|
||||
"fmt"
|
||||
"net"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"golang.org/x/crypto/ssh"
|
||||
"socksrevivepc/internal/config"
|
||||
)
|
||||
|
||||
type sshBundle struct {
|
||||
Client *ssh.Client
|
||||
Conn net.Conn
|
||||
ControlHosts []string
|
||||
}
|
||||
|
||||
type transportAttempt struct {
|
||||
Label string
|
||||
ProxyHost string
|
||||
ProxyPort int
|
||||
TLSHost string
|
||||
TLSPort int
|
||||
ControlHost string
|
||||
}
|
||||
|
||||
func connectSSH(ctx context.Context, p config.Profile, logger *Logger) (*sshBundle, error) {
|
||||
targetHost := p.SSH.Host
|
||||
targetPort := p.SSH.Port
|
||||
if p.Mode == config.ModeDNSTT {
|
||||
targetHost = p.DNSTT.LocalSSHHost
|
||||
targetPort = p.DNSTT.LocalSSHPort
|
||||
}
|
||||
|
||||
attempts := buildTransportAttempts(p, targetHost, targetPort)
|
||||
if len(attempts) == 0 {
|
||||
attempts = []transportAttempt{{Label: "default"}}
|
||||
}
|
||||
|
||||
logger.Add("info", "connecting SSH %s:%d using mode %s", targetHost, targetPort, p.Mode)
|
||||
var lastErr error
|
||||
for i, attempt := range attempts {
|
||||
if ctx.Err() != nil {
|
||||
return nil, ctx.Err()
|
||||
}
|
||||
if len(attempts) > 1 {
|
||||
logger.Add("info", "connection attempt %d/%d via %s", i+1, len(attempts), attempt.Label)
|
||||
}
|
||||
|
||||
conn, err := dialTransportAttempt(ctx, p, targetHost, targetPort, attempt, logger)
|
||||
if err != nil {
|
||||
lastErr = err
|
||||
if len(attempts) > 1 {
|
||||
logger.Add("warn", "connection attempt %d/%d failed before SSH handshake: %v", i+1, len(attempts), err)
|
||||
}
|
||||
continue
|
||||
}
|
||||
|
||||
bundle, err := finishSSHHandshake(conn, p, targetHost, targetPort, logger)
|
||||
if err == nil {
|
||||
bundle.ControlHosts = controlHostsForAttempt(p, targetHost, attempt)
|
||||
if len(attempts) > 1 {
|
||||
logger.Add("info", "connection attempt %d/%d succeeded via %s", i+1, len(attempts), attempt.Label)
|
||||
}
|
||||
return bundle, nil
|
||||
}
|
||||
_ = conn.Close()
|
||||
lastErr = err
|
||||
if len(attempts) > 1 {
|
||||
logger.Add("warn", "connection attempt %d/%d failed during SSH handshake: %v", i+1, len(attempts), err)
|
||||
}
|
||||
}
|
||||
|
||||
if lastErr == nil {
|
||||
lastErr = fmt.Errorf("no transport attempts were available")
|
||||
}
|
||||
if len(attempts) > 1 {
|
||||
return nil, fmt.Errorf("all %d connection attempts failed; last error: %w", len(attempts), lastErr)
|
||||
}
|
||||
return nil, lastErr
|
||||
}
|
||||
|
||||
func finishSSHHandshake(conn net.Conn, p config.Profile, targetHost string, targetPort int, logger *Logger) (*sshBundle, error) {
|
||||
sshCfg := &ssh.ClientConfig{
|
||||
User: p.SSH.Username,
|
||||
Auth: []ssh.AuthMethod{ssh.Password(p.SSH.Password)},
|
||||
HostKeyCallback: ssh.InsecureIgnoreHostKey(),
|
||||
Timeout: time.Duration(p.SSH.HandshakeTimeoutMs) * time.Millisecond,
|
||||
ClientVersion: "SSH-2.0-SocksRevivePC",
|
||||
}
|
||||
addr := net.JoinHostPort(targetHost, fmt.Sprint(targetPort))
|
||||
cc, chans, reqs, err := ssh.NewClientConn(conn, addr, sshCfg)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("ssh handshake failed: %w", err)
|
||||
}
|
||||
client := ssh.NewClient(cc, chans, reqs)
|
||||
logger.Add("info", "ssh authenticated as %s", p.SSH.Username)
|
||||
return &sshBundle{Client: client, Conn: conn}, nil
|
||||
}
|
||||
|
||||
func buildTransportAttempts(p config.Profile, targetHost string, targetPort int) []transportAttempt {
|
||||
switch p.Mode {
|
||||
case config.ModePayload:
|
||||
return payloadProxyAttempts(p, targetHost, targetPort)
|
||||
case config.ModeSSL, config.ModePayloadSSL:
|
||||
return tlsHostAttempts(p, targetHost, targetPort)
|
||||
default:
|
||||
return []transportAttempt{{Label: net.JoinHostPort(targetHost, fmt.Sprint(targetPort))}}
|
||||
}
|
||||
}
|
||||
|
||||
func controlHostsForAttempt(p config.Profile, targetHost string, attempt transportAttempt) []string {
|
||||
switch p.Mode {
|
||||
case config.ModePayload:
|
||||
if attempt.ProxyHost != "" {
|
||||
return []string{attempt.ProxyHost}
|
||||
}
|
||||
return []string{targetHost}
|
||||
case config.ModeSSL, config.ModePayloadSSL:
|
||||
if attempt.TLSHost != "" {
|
||||
return []string{attempt.TLSHost}
|
||||
}
|
||||
if p.TLS.Host != "" {
|
||||
host, _ := hostPortWithDefault(p.TLS.Host, p.TLS.Port)
|
||||
if host != "" {
|
||||
return []string{host}
|
||||
}
|
||||
}
|
||||
return []string{targetHost}
|
||||
default:
|
||||
return []string{targetHost}
|
||||
}
|
||||
}
|
||||
|
||||
func payloadProxyAttempts(p config.Profile, targetHost string, targetPort int) []transportAttempt {
|
||||
proxyText := strings.TrimSpace(p.Proxy.Host)
|
||||
if proxyText == "" || p.Proxy.Port <= 0 {
|
||||
return []transportAttempt{{Label: "direct payload transport"}}
|
||||
}
|
||||
|
||||
rawHosts := splitHostList(proxyText)
|
||||
if len(rawHosts) == 0 {
|
||||
return []transportAttempt{{Label: "direct payload transport"}}
|
||||
}
|
||||
out := make([]transportAttempt, 0, len(rawHosts))
|
||||
for _, raw := range rawHosts {
|
||||
host, port := hostPortWithDefault(raw, p.Proxy.Port)
|
||||
if host == "" || port <= 0 {
|
||||
continue
|
||||
}
|
||||
out = append(out, transportAttempt{ProxyHost: host, ProxyPort: port, Label: "proxy " + net.JoinHostPort(host, fmt.Sprint(port))})
|
||||
}
|
||||
if len(out) == 0 {
|
||||
return []transportAttempt{{Label: net.JoinHostPort(targetHost, fmt.Sprint(targetPort))}}
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
func tlsHostAttempts(p config.Profile, targetHost string, targetPort int) []transportAttempt {
|
||||
hostText := strings.TrimSpace(p.TLS.Host)
|
||||
defaultPort := p.TLS.Port
|
||||
if defaultPort <= 0 {
|
||||
defaultPort = targetPort
|
||||
}
|
||||
if hostText == "" {
|
||||
return []transportAttempt{{TLSHost: targetHost, TLSPort: defaultPort, Label: "TLS " + net.JoinHostPort(targetHost, fmt.Sprint(defaultPort))}}
|
||||
}
|
||||
|
||||
rawHosts := splitHostList(hostText)
|
||||
if len(rawHosts) == 0 {
|
||||
return []transportAttempt{{TLSHost: targetHost, TLSPort: defaultPort, Label: "TLS " + net.JoinHostPort(targetHost, fmt.Sprint(defaultPort))}}
|
||||
}
|
||||
out := make([]transportAttempt, 0, len(rawHosts))
|
||||
for _, raw := range rawHosts {
|
||||
host, port := hostPortWithDefault(raw, defaultPort)
|
||||
if host == "" || port <= 0 {
|
||||
continue
|
||||
}
|
||||
out = append(out, transportAttempt{TLSHost: host, TLSPort: port, Label: "TLS " + net.JoinHostPort(host, fmt.Sprint(port))})
|
||||
}
|
||||
if len(out) == 0 {
|
||||
return []transportAttempt{{TLSHost: targetHost, TLSPort: defaultPort, Label: "TLS " + net.JoinHostPort(targetHost, fmt.Sprint(defaultPort))}}
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
func dialTransport(ctx context.Context, p config.Profile, targetHost string, targetPort int, logger *Logger) (net.Conn, error) {
|
||||
attempts := buildTransportAttempts(p, targetHost, targetPort)
|
||||
if len(attempts) == 0 {
|
||||
attempts = []transportAttempt{{Label: "default"}}
|
||||
}
|
||||
return dialTransportAttempt(ctx, p, targetHost, targetPort, attempts[0], logger)
|
||||
}
|
||||
|
||||
func dialTransportAttempt(ctx context.Context, p config.Profile, targetHost string, targetPort int, attempt transportAttempt, logger *Logger) (net.Conn, error) {
|
||||
d := &net.Dialer{Timeout: time.Duration(p.SSH.HandshakeTimeoutMs) * time.Millisecond}
|
||||
addr := net.JoinHostPort(targetHost, fmt.Sprint(targetPort))
|
||||
|
||||
switch p.Mode {
|
||||
case config.ModeDirect, config.ModeDNSTT:
|
||||
return d.DialContext(ctx, "tcp", addr)
|
||||
case config.ModeSSL:
|
||||
return dialTLSAttempt(ctx, d, p, targetHost, targetPort, attempt)
|
||||
case config.ModePayload:
|
||||
connectHost, connectPort := targetHost, targetPort
|
||||
if attempt.ProxyHost != "" && attempt.ProxyPort > 0 {
|
||||
connectHost, connectPort = attempt.ProxyHost, attempt.ProxyPort
|
||||
} else if p.Proxy.Host != "" && p.Proxy.Port > 0 {
|
||||
connectHost, connectPort = p.Proxy.Host, p.Proxy.Port
|
||||
}
|
||||
logger.Add("debug", "payload transport dialing %s", net.JoinHostPort(connectHost, fmt.Sprint(connectPort)))
|
||||
conn, err := d.DialContext(ctx, "tcp", net.JoinHostPort(connectHost, fmt.Sprint(connectPort)))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if _, wrappedConn, err := WritePayload(conn, p, targetHost, targetPort, logger); err != nil {
|
||||
_ = conn.Close()
|
||||
return nil, err
|
||||
} else {
|
||||
conn = wrappedConn
|
||||
}
|
||||
return conn, nil
|
||||
case config.ModePayloadSSL:
|
||||
conn, err := dialTLSAttempt(ctx, d, p, targetHost, targetPort, attempt)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if _, wrappedConn, err := WritePayload(conn, p, targetHost, targetPort, logger); err != nil {
|
||||
_ = conn.Close()
|
||||
return nil, err
|
||||
} else {
|
||||
conn = wrappedConn
|
||||
}
|
||||
return conn, nil
|
||||
default:
|
||||
return nil, fmt.Errorf("unsupported mode %s", p.Mode)
|
||||
}
|
||||
}
|
||||
|
||||
func dialTLS(ctx context.Context, d *net.Dialer, p config.Profile, targetHost string, targetPort int) (net.Conn, error) {
|
||||
return dialTLSAttempt(ctx, d, p, targetHost, targetPort, transportAttempt{})
|
||||
}
|
||||
|
||||
func dialTLSAttempt(ctx context.Context, d *net.Dialer, p config.Profile, targetHost string, targetPort int, attempt transportAttempt) (net.Conn, error) {
|
||||
host := strings.TrimSpace(attempt.TLSHost)
|
||||
port := attempt.TLSPort
|
||||
if host == "" {
|
||||
host = p.TLS.Host
|
||||
}
|
||||
if port <= 0 {
|
||||
port = p.TLS.Port
|
||||
}
|
||||
if host == "" {
|
||||
host = targetHost
|
||||
}
|
||||
if port == 0 {
|
||||
port = targetPort
|
||||
}
|
||||
serverName := p.TLS.ServerName
|
||||
if serverName == "" {
|
||||
serverName = host
|
||||
}
|
||||
raw, err := d.DialContext(ctx, "tcp", net.JoinHostPort(host, fmt.Sprint(port)))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
cfg := &tls.Config{ServerName: serverName, InsecureSkipVerify: p.TLS.InsecureSkipVerify, MinVersion: tls.VersionTLS12}
|
||||
conn := tls.Client(raw, cfg)
|
||||
if err := conn.HandshakeContext(ctx); err != nil {
|
||||
_ = conn.Close()
|
||||
return nil, err
|
||||
}
|
||||
return conn, nil
|
||||
}
|
||||
|
||||
func splitHostList(s string) []string {
|
||||
s = strings.TrimSpace(s)
|
||||
if s == "" {
|
||||
return nil
|
||||
}
|
||||
parts := strings.FieldsFunc(s, func(r rune) bool {
|
||||
switch r {
|
||||
case '#', ',', ';', '\n', '\r', '\t':
|
||||
return true
|
||||
default:
|
||||
return false
|
||||
}
|
||||
})
|
||||
out := make([]string, 0, len(parts))
|
||||
seen := map[string]struct{}{}
|
||||
for _, part := range parts {
|
||||
part = strings.TrimSpace(part)
|
||||
if part == "" {
|
||||
continue
|
||||
}
|
||||
if _, ok := seen[part]; ok {
|
||||
continue
|
||||
}
|
||||
seen[part] = struct{}{}
|
||||
out = append(out, part)
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
func hostPortWithDefault(raw string, defaultPort int) (string, int) {
|
||||
raw = strings.TrimSpace(raw)
|
||||
if raw == "" {
|
||||
return "", defaultPort
|
||||
}
|
||||
if strings.Contains(raw, "://") {
|
||||
// These fields are host fields, not URLs. Strip the scheme if a user pastes one.
|
||||
if i := strings.Index(raw, "://"); i >= 0 {
|
||||
raw = raw[i+3:]
|
||||
}
|
||||
}
|
||||
if h, p, err := net.SplitHostPort(raw); err == nil {
|
||||
pi, _ := strconv.Atoi(p)
|
||||
if pi > 0 {
|
||||
return strings.Trim(h, "[]"), pi
|
||||
}
|
||||
}
|
||||
if strings.HasPrefix(raw, "[") && strings.Contains(raw, "]:") {
|
||||
if h, p, err := net.SplitHostPort(raw); err == nil {
|
||||
pi, _ := strconv.Atoi(p)
|
||||
if pi > 0 {
|
||||
return strings.Trim(h, "[]"), pi
|
||||
}
|
||||
}
|
||||
}
|
||||
// host:port for IPv4/domain. IPv6 without brackets has multiple colons and
|
||||
// must keep the profile/default port.
|
||||
if strings.Count(raw, ":") == 1 {
|
||||
host, portText, ok := strings.Cut(raw, ":")
|
||||
if ok {
|
||||
if pi, err := strconv.Atoi(strings.TrimSpace(portText)); err == nil && pi > 0 {
|
||||
return strings.TrimSpace(host), pi
|
||||
}
|
||||
}
|
||||
}
|
||||
return strings.Trim(raw, "[]"), defaultPort
|
||||
}
|
||||
Reference in New Issue
Block a user