526 lines
13 KiB
Go
526 lines
13 KiB
Go
package engine
|
|
|
|
import (
|
|
"context"
|
|
"fmt"
|
|
"net"
|
|
"net/url"
|
|
"strings"
|
|
"sync"
|
|
"time"
|
|
|
|
"socksrevivepc/internal/config"
|
|
"socksrevivepc/internal/dnsttclient"
|
|
"socksrevivepc/internal/routes"
|
|
"socksrevivepc/internal/tun"
|
|
)
|
|
|
|
// Status is a point-in-time snapshot of the manager's connection state, as
// returned by Manager.Status. The zero value means no profile is active
// (stopLocked resets the manager's status to Status{}).
type Status struct {
	// Running is true once Start has brought the profile fully up.
	Running    bool   `json:"running"`
	// Connecting is true while Start is still bringing the profile up.
	Connecting bool   `json:"connecting"`
	// ProfileID is the ID of the profile that is connecting or running.
	ProfileID  string `json:"profile_id"`
	// Mode is the string form of the profile's transport mode.
	Mode       string `json:"mode"`
	// SocksAddr is the local SOCKS host:port, filled in once it is known.
	SocksAddr  string `json:"socks_addr"`
	// Tun reports whether the profile has TUN enabled.
	Tun        bool   `json:"tun"`
	// StartedAt is a time.Now() timestamp formatted with time.RFC3339. Start
	// sets it when connecting begins and again when the session is Running.
	StartedAt  string `json:"started_at"`
}
|
|
|
|
// Manager owns at most one profile session at a time (Start refuses to run
// while another profile is running or connecting) together with every
// component started for it.
type Manager struct {
	// root is passed to StartProcess when launching xray/dnstt executables.
	root          string
	// mu guards status, cancel, the component handles below, routeCleanup,
	// manualStop and reconnecting.
	mu            sync.Mutex
	// logger collects the manager's log; set once by NewManager.
	logger        *Logger
	// status is the snapshot returned by Status.
	status        Status
	// cancel cancels the current session's context; nil when no session.
	cancel        context.CancelFunc
	// ssh is the SSH connection used by non-xray modes.
	ssh           *sshBundle
	// socks is the local SOCKS server served over ssh.
	socks         *SocksServer
	// dnstt is the external dnstt process (DNSTT mode, UseEmbedded false).
	dnstt         *ManagedProcess
	// embeddedDNSTT is the in-process dnstt client (DNSTT mode, UseEmbedded true).
	embeddedDNSTT *dnsttclient.Client
	// xray is the xray process (xray mode).
	xray          *ManagedProcess
	// tun is created once by NewManager and reused across sessions.
	tun           *tun.Runner
	// routeCleanup undoes the routes applied by routes.Apply when TUN is on.
	routeCleanup  *routes.Cleanup
	// manualStop is set by Stop and cleared by Start; the monitor and the
	// reconnect loop stop working once it is set.
	manualStop    bool
	// reconnecting is true while a reconnectLoop is active.
	reconnecting  bool
}
|
|
|
|
func NewManager(root string) *Manager {
|
|
lg := NewLogger()
|
|
return &Manager{root: root, logger: lg, tun: tun.NewRunner(lg)}
|
|
}
|
|
|
|
// LogsSince returns whatever the manager's logger reports from Since(id).
// NOTE(review): presumably the entries recorded after id — confirm against Logger.Since.
func (m *Manager) LogsSince(id int64) []LogEntry {
	lg := m.logger
	return lg.Since(id)
}
|
|
|
|
func (m *Manager) Status() Status {
|
|
m.mu.Lock()
|
|
defer m.mu.Unlock()
|
|
return m.status
|
|
}
|
|
|
|
func (m *Manager) Start(p config.Profile) error {
|
|
config.ApplyDefaults(&p)
|
|
if err := config.Validate(p); err != nil {
|
|
return err
|
|
}
|
|
|
|
ctx, cancel := context.WithCancel(context.Background())
|
|
m.mu.Lock()
|
|
if m.status.Running || m.status.Connecting {
|
|
m.mu.Unlock()
|
|
cancel()
|
|
return fmt.Errorf("another profile is already running or connecting")
|
|
}
|
|
m.cancel = cancel
|
|
m.manualStop = false
|
|
m.status = Status{Connecting: true, ProfileID: p.ID, Mode: string(p.Mode), Tun: p.Tun.Enabled, StartedAt: time.Now().Format(time.RFC3339)}
|
|
m.mu.Unlock()
|
|
|
|
m.logger.Add("info", "starting profile: %s", p.Name)
|
|
|
|
fail := func(format string, args ...any) error {
|
|
err := fmt.Errorf(format, args...)
|
|
m.logger.Add("error", "%v", err)
|
|
m.stop(false)
|
|
return err
|
|
}
|
|
|
|
ensureCurrent := func() error {
|
|
if err := ctx.Err(); err != nil {
|
|
return err
|
|
}
|
|
m.mu.Lock()
|
|
defer m.mu.Unlock()
|
|
if m.cancel == nil || !m.status.Connecting || m.status.ProfileID != p.ID {
|
|
return context.Canceled
|
|
}
|
|
return nil
|
|
}
|
|
|
|
var socksAddr string
|
|
var bypass []string
|
|
if p.Mode == config.ModeXray {
|
|
proc, err := StartProcess(ctx, m.root, "xray", p.Xray.Executable, p.Xray.Args, m.logger)
|
|
if err != nil {
|
|
return fail("start xray failed: %w", err)
|
|
}
|
|
m.mu.Lock()
|
|
m.xray = proc
|
|
m.mu.Unlock()
|
|
time.Sleep(time.Duration(p.Xray.StartupTimeoutMs) * time.Millisecond)
|
|
if err := ensureCurrent(); err != nil {
|
|
proc.Stop()
|
|
return fail("connection cancelled: %w", err)
|
|
}
|
|
socksAddr = net.JoinHostPort(p.Xray.LocalSocksHost, fmt.Sprint(p.Xray.LocalSocksPort))
|
|
if exited, procErr := proc.Exited(); exited {
|
|
m.stop(false)
|
|
if procErr != nil {
|
|
return fmt.Errorf("xray exited before opening local SOCKS on %s: %w", socksAddr, procErr)
|
|
}
|
|
return fmt.Errorf("xray exited before opening local SOCKS on %s", socksAddr)
|
|
}
|
|
if err := waitForTCP(ctx, socksAddr, time.Duration(p.Xray.StartupTimeoutMs)*time.Millisecond); err != nil {
|
|
return fail("xray local SOCKS is not listening on %s: %w", socksAddr, err)
|
|
}
|
|
bypass = nil
|
|
} else {
|
|
if p.Mode == config.ModeDNSTT {
|
|
if p.DNSTT.UseEmbedded {
|
|
client, err := dnsttclient.Start(ctx, dnsttclient.Options{
|
|
ResolverType: p.DNSTT.ResolverType,
|
|
ResolverAddress: p.DNSTT.ResolverAddress,
|
|
PublicKeyHex: p.DNSTT.PublicKey,
|
|
Domain: p.DNSTT.Domain,
|
|
LocalAddress: net.JoinHostPort(p.DNSTT.LocalSSHHost, fmt.Sprint(p.DNSTT.LocalSSHPort)),
|
|
UTLSDistribution: p.DNSTT.UTLSDistribution,
|
|
StartupTimeout: time.Duration(p.DNSTT.StartupTimeoutMs) * time.Millisecond,
|
|
LogWriter: dnsttLogWriter{logger: m.logger},
|
|
})
|
|
if err != nil {
|
|
return fail("start embedded dnstt failed: %w", err)
|
|
}
|
|
m.mu.Lock()
|
|
m.embeddedDNSTT = client
|
|
m.mu.Unlock()
|
|
} else {
|
|
proc, err := StartProcess(ctx, m.root, "dnstt", p.DNSTT.Executable, p.DNSTT.Args, m.logger)
|
|
if err != nil {
|
|
return fail("start dnstt failed: %w", err)
|
|
}
|
|
m.mu.Lock()
|
|
m.dnstt = proc
|
|
m.mu.Unlock()
|
|
time.Sleep(time.Duration(p.DNSTT.StartupTimeoutMs) * time.Millisecond)
|
|
}
|
|
if err := ensureCurrent(); err != nil {
|
|
return fail("connection cancelled: %w", err)
|
|
}
|
|
}
|
|
sshc, err := connectSSH(ctx, p, m.logger)
|
|
if err != nil {
|
|
return fail("%w", err)
|
|
}
|
|
if err := ensureCurrent(); err != nil {
|
|
_ = sshc.Client.Close()
|
|
_ = sshc.Conn.Close()
|
|
return fail("connection cancelled: %w", err)
|
|
}
|
|
m.mu.Lock()
|
|
m.ssh = sshc
|
|
m.mu.Unlock()
|
|
socksAddr = net.JoinHostPort(p.Local.SocksHost, fmt.Sprint(p.Local.SocksPort))
|
|
ss := &SocksServer{Addr: socksAddr, SSH: sshc.Client, Logger: m.logger, DNS: profileDNSServers(p), UDPGW: p.UDPGW}
|
|
if err := ss.Start(); err != nil {
|
|
return fail("%w", err)
|
|
}
|
|
m.mu.Lock()
|
|
m.socks = ss
|
|
m.status.SocksAddr = socksAddr
|
|
m.mu.Unlock()
|
|
if err := waitForTCP(ctx, socksAddr, 2500*time.Millisecond); err != nil {
|
|
return fail("local SOCKS is not listening on %s: %w", socksAddr, err)
|
|
}
|
|
bypass = effectiveBypassHosts(p, sshc.ControlHosts)
|
|
}
|
|
|
|
if err := ensureCurrent(); err != nil {
|
|
return fail("connection cancelled: %w", err)
|
|
}
|
|
if p.Tun.Enabled {
|
|
if err := m.tun.Start(&p, socksAddr); err != nil {
|
|
return fail("%w", err)
|
|
}
|
|
cleanup, err := routes.Apply(p, bypass, m.logger)
|
|
if err != nil {
|
|
m.logger.Add("warn", "route setup error: %v", err)
|
|
}
|
|
m.mu.Lock()
|
|
m.routeCleanup = cleanup
|
|
m.mu.Unlock()
|
|
}
|
|
|
|
m.mu.Lock()
|
|
current := m.cancel != nil && m.status.Connecting && m.status.ProfileID == p.ID
|
|
err := ctx.Err()
|
|
if err == nil && current {
|
|
m.status = Status{Running: true, ProfileID: p.ID, Mode: string(p.Mode), SocksAddr: socksAddr, Tun: p.Tun.Enabled, StartedAt: time.Now().Format(time.RFC3339)}
|
|
}
|
|
m.mu.Unlock()
|
|
if err != nil || !current {
|
|
m.stop(false)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
return context.Canceled
|
|
}
|
|
m.logger.Add("info", "profile is connected; local socks=%s tun=%v", socksAddr, p.Tun.Enabled)
|
|
m.startMonitor(ctx, p)
|
|
return nil
|
|
}
|
|
|
|
func profileDNSServers(p config.Profile) []string {
|
|
out := append([]string{}, p.Tun.DNS...)
|
|
if p.Tun.IPv6Enabled {
|
|
out = append(out, p.Tun.IPv6DNS...)
|
|
}
|
|
return out
|
|
}
|
|
|
|
// waitForTCP polls addr with TCP dials until one succeeds, ctx is done, or
// timeout elapses. Each dial is capped at 350ms, with a 150ms pause between
// attempts. A non-positive timeout means 5s.
//
// It returns nil on success and ctx.Err() on cancellation. On timeout it
// returns the last dial error, or a generic timeout error if none occurred.
func waitForTCP(ctx context.Context, addr string, timeout time.Duration) error {
	if timeout <= 0 {
		timeout = 5 * time.Second
	}
	var (
		dialer  = net.Dialer{Timeout: 350 * time.Millisecond}
		giveUp  = time.Now().Add(timeout)
		lastErr error
	)
	for time.Now().Before(giveUp) {
		conn, err := dialer.DialContext(ctx, "tcp", addr)
		if err == nil {
			// Only reachability matters; the probe connection is discarded.
			_ = conn.Close()
			return nil
		}
		lastErr = err
		select {
		case <-time.After(150 * time.Millisecond):
		case <-ctx.Done():
			return ctx.Err()
		}
	}
	if lastErr == nil {
		return fmt.Errorf("timeout waiting for %s", addr)
	}
	return lastErr
}
|
|
|
|
func effectiveBypassHosts(p config.Profile, activeControlHosts []string) []string {
|
|
seen := map[string]bool{}
|
|
out := make([]string, 0, len(activeControlHosts)+1)
|
|
add := func(host string) {
|
|
host = strings.TrimSpace(host)
|
|
if host == "" {
|
|
return
|
|
}
|
|
if h, _, err := net.SplitHostPort(host); err == nil {
|
|
host = strings.Trim(h, "[]")
|
|
}
|
|
if isLocalBypassHost(host) || seen[host] {
|
|
return
|
|
}
|
|
seen[host] = true
|
|
out = append(out, host)
|
|
}
|
|
for _, host := range activeControlHosts {
|
|
add(host)
|
|
}
|
|
if len(out) == 0 && p.Mode != config.ModeDNSTT {
|
|
// Fallback for old profile formats or unexpected transports. This still only
|
|
// adds the direct SSH host, not every rotated proxy from the profile.
|
|
add(p.SSH.Host)
|
|
}
|
|
if p.Mode == config.ModeDNSTT && p.DNSTT.UseEmbedded {
|
|
add(dnsttResolverHost(p.DNSTT.ResolverAddress))
|
|
}
|
|
return out
|
|
}
|
|
|
|
// isLocalBypassHost reports whether host needs no bypass route. That is the
// case when, ignoring case, surrounding whitespace and IPv6 brackets, host is
// empty, "localhost", or a loopback or unspecified IP address.
func isLocalBypassHost(host string) bool {
	normalized := strings.Trim(strings.ToLower(strings.TrimSpace(host)), "[]")
	switch normalized {
	case "", "localhost":
		return true
	}
	if ip := net.ParseIP(normalized); ip != nil {
		return ip.IsLoopback() || ip.IsUnspecified()
	}
	return false
}
|
|
|
|
// dnsttResolverHost extracts the bare host from a dnstt resolver address.
//
// Accepted forms are URLs ("https://dns.google/dns-query",
// "tls://[2001:db8::1]:853"), host:port ("8.8.8.8:53"), bracketed IPv6 with
// or without a port, and plain hosts or IPs. Surrounding whitespace is
// ignored, and an empty input yields "".
func dnsttResolverHost(s string) string {
	s = strings.TrimSpace(s)
	if s == "" {
		return ""
	}
	if strings.Contains(s, "://") {
		if u, err := url.Parse(s); err == nil {
			return u.Hostname()
		}
	}
	if host, _, err := net.SplitHostPort(s); err == nil {
		return host
	}
	// SplitHostPort rejects a bracketed IPv6 literal without a port
	// ("[2001:db8::53]"); unwrap it so callers get a plain IP.
	if strings.HasPrefix(s, "[") && strings.HasSuffix(s, "]") {
		return s[1 : len(s)-1]
	}
	return s
}
|
|
|
|
func (m *Manager) Stop() {
|
|
m.stop(true)
|
|
}
|
|
|
|
func (m *Manager) stop(manual bool) {
|
|
m.mu.Lock()
|
|
defer m.mu.Unlock()
|
|
if manual {
|
|
m.manualStop = true
|
|
}
|
|
m.stopLocked()
|
|
}
|
|
|
|
func (m *Manager) startMonitor(ctx context.Context, p config.Profile) {
|
|
if !p.Reconnect.Enabled {
|
|
return
|
|
}
|
|
interval := time.Duration(p.Reconnect.CheckIntervalSeconds) * time.Second
|
|
if interval <= 0 {
|
|
interval = 10 * time.Second
|
|
}
|
|
go func() {
|
|
ticker := time.NewTicker(interval)
|
|
defer ticker.Stop()
|
|
for {
|
|
select {
|
|
case <-ctx.Done():
|
|
return
|
|
case <-ticker.C:
|
|
if m.isManualStopped() {
|
|
return
|
|
}
|
|
if err := m.probeConnection(p); err != nil {
|
|
m.logger.Add("warn", "connection monitor detected tunnel loss: %v", err)
|
|
m.reconnectLoop(p, err)
|
|
return
|
|
}
|
|
}
|
|
}
|
|
}()
|
|
}
|
|
|
|
func (m *Manager) isManualStopped() bool {
|
|
m.mu.Lock()
|
|
defer m.mu.Unlock()
|
|
return m.manualStop
|
|
}
|
|
|
|
func (m *Manager) markReconnecting() bool {
|
|
m.mu.Lock()
|
|
defer m.mu.Unlock()
|
|
if m.manualStop || m.reconnecting {
|
|
return false
|
|
}
|
|
m.reconnecting = true
|
|
return true
|
|
}
|
|
|
|
func (m *Manager) clearReconnecting() {
|
|
m.mu.Lock()
|
|
m.reconnecting = false
|
|
m.mu.Unlock()
|
|
}
|
|
|
|
func (m *Manager) probeConnection(p config.Profile) error {
|
|
m.mu.Lock()
|
|
sshc := m.ssh
|
|
xray := m.xray
|
|
socksAddr := m.status.SocksAddr
|
|
running := m.status.Running
|
|
m.mu.Unlock()
|
|
|
|
if !running {
|
|
return nil
|
|
}
|
|
if xray != nil {
|
|
if exited, err := xray.Exited(); exited {
|
|
if err != nil {
|
|
return fmt.Errorf("xray exited: %w", err)
|
|
}
|
|
return fmt.Errorf("xray exited")
|
|
}
|
|
if socksAddr != "" {
|
|
ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
|
|
defer cancel()
|
|
return waitForTCP(ctx, socksAddr, 2*time.Second)
|
|
}
|
|
return nil
|
|
}
|
|
if sshc != nil && sshc.Client != nil {
|
|
if sshc.Conn != nil {
|
|
_ = sshc.Conn.SetDeadline(time.Now().Add(5 * time.Second))
|
|
defer sshc.Conn.SetDeadline(time.Time{})
|
|
}
|
|
_, _, err := sshc.Client.SendRequest("keepalive@openssh.com", true, nil)
|
|
if err != nil {
|
|
return fmt.Errorf("ssh keepalive failed: %w", err)
|
|
}
|
|
}
|
|
return nil
|
|
}
|
|
|
|
func (m *Manager) reconnectLoop(p config.Profile, cause error) {
|
|
if !m.markReconnecting() {
|
|
return
|
|
}
|
|
defer m.clearReconnecting()
|
|
|
|
delay := time.Duration(p.Reconnect.DelaySeconds) * time.Second
|
|
if delay <= 0 {
|
|
delay = 3 * time.Second
|
|
}
|
|
maxRetries := p.Reconnect.MaxRetries
|
|
m.logger.Add("warn", "connection lost (%v); auto reconnect is enabled", cause)
|
|
if p.Tun.Enabled {
|
|
m.logger.Add("info", "destroying TUN before reconnect")
|
|
}
|
|
m.stop(false)
|
|
if p.Tun.Enabled {
|
|
time.Sleep(1200 * time.Millisecond)
|
|
}
|
|
|
|
for attempt := 1; maxRetries <= 0 || attempt <= maxRetries; attempt++ {
|
|
if m.isManualStopped() {
|
|
m.logger.Add("info", "auto reconnect cancelled by user")
|
|
return
|
|
}
|
|
m.logger.Add("info", "reconnect attempt %d%s in %s", attempt, reconnectLimitSuffix(maxRetries), delay)
|
|
select {
|
|
case <-time.After(delay):
|
|
}
|
|
if m.isManualStopped() {
|
|
m.logger.Add("info", "auto reconnect cancelled by user")
|
|
return
|
|
}
|
|
if err := m.Start(p); err != nil {
|
|
if m.isManualStopped() {
|
|
m.logger.Add("info", "auto reconnect cancelled by user")
|
|
return
|
|
}
|
|
m.logger.Add("warn", "reconnect attempt %d failed: %v", attempt, err)
|
|
if p.Tun.Enabled {
|
|
m.logger.Add("info", "destroying TUN after failed reconnect attempt")
|
|
m.stop(false)
|
|
time.Sleep(1200 * time.Millisecond)
|
|
}
|
|
continue
|
|
}
|
|
m.logger.Add("info", "reconnected successfully")
|
|
return
|
|
}
|
|
m.logger.Add("error", "auto reconnect stopped after %d failed attempt(s)", maxRetries)
|
|
}
|
|
|
|
// reconnectLimitSuffix renders the retry limit for attempt log lines: "/N"
// for a positive limit, " (unlimited)" otherwise.
func reconnectLimitSuffix(maxRetries int) string {
	if maxRetries > 0 {
		return "/" + fmt.Sprint(maxRetries)
	}
	return " (unlimited)"
}
|
|
|
|
func (m *Manager) stopLocked() {
|
|
wasActive := m.status.Running || m.status.Connecting
|
|
if m.cancel != nil {
|
|
m.cancel()
|
|
m.cancel = nil
|
|
}
|
|
if m.routeCleanup != nil {
|
|
m.routeCleanup.Run()
|
|
m.routeCleanup = nil
|
|
}
|
|
if m.tun != nil {
|
|
m.tun.Stop()
|
|
}
|
|
if m.socks != nil {
|
|
m.socks.Stop()
|
|
m.socks = nil
|
|
}
|
|
if m.ssh != nil {
|
|
_ = m.ssh.Client.Close()
|
|
_ = m.ssh.Conn.Close()
|
|
m.ssh = nil
|
|
}
|
|
if m.xray != nil {
|
|
m.xray.Stop()
|
|
m.xray = nil
|
|
}
|
|
if m.embeddedDNSTT != nil {
|
|
m.embeddedDNSTT.Stop()
|
|
m.embeddedDNSTT = nil
|
|
}
|
|
if m.dnstt != nil {
|
|
m.dnstt.Stop()
|
|
m.dnstt = nil
|
|
}
|
|
if wasActive {
|
|
m.logger.Add("info", "disconnected")
|
|
}
|
|
m.status = Status{}
|
|
}
|
|
|
|
// dnsttLogWriter adapts the manager's Logger to a Write([]byte) (int, error)
// method. Start passes it as dnsttclient.Options.LogWriter so the embedded
// dnstt client's output ends up in the manager log.
type dnsttLogWriter struct {
	// logger receives each written chunk; a nil logger discards output.
	logger *Logger
}
|
|
|
|
func (w dnsttLogWriter) Write(p []byte) (int, error) {
|
|
line := strings.TrimSpace(string(p))
|
|
if line != "" && w.logger != nil {
|
|
w.logger.Add("dnstt", "%s", line)
|
|
}
|
|
return len(p), nil
|
|
}
|