Files
SocksRevive-PC/internal/engine/manager.go
2026-05-16 00:18:06 -03:00

526 lines
13 KiB
Go

package engine
import (
"context"
"fmt"
"net"
"net/url"
"strings"
"sync"
"time"
"socksrevivepc/internal/config"
"socksrevivepc/internal/dnsttclient"
"socksrevivepc/internal/routes"
"socksrevivepc/internal/tun"
)
// Status is a snapshot of the manager's session state. The zero value means
// idle (stopLocked resets status to Status{}). Fields carry JSON tags for
// serialization.
type Status struct {
	// Running is true once Start has finished connecting the profile.
	Running bool `json:"running"`
	// Connecting is true while Start is still establishing the session.
	Connecting bool `json:"connecting"`
	// ProfileID is the ID of the profile that is connecting or running.
	ProfileID string `json:"profile_id"`
	// Mode is the string form of the profile's connection mode.
	Mode string `json:"mode"`
	// SocksAddr is the host:port of the local SOCKS listener, once known.
	SocksAddr string `json:"socks_addr"`
	// Tun mirrors the profile's TUN-enabled flag.
	Tun bool `json:"tun"`
	// StartedAt is an RFC 3339 timestamp. Start sets it when the attempt
	// begins and again when the session becomes running.
	StartedAt string `json:"started_at"`
}
// Manager drives at most one connection session at a time and owns the
// resources backing it: helper processes (xray/dnstt), the SSH connection,
// the local SOCKS server, the TUN runner and route changes.
// Session fields are read and written under mu. root, logger and tun are
// assigned in NewManager.
type Manager struct {
	// root is passed to StartProcess when launching helper executables.
	root string
	mu   sync.Mutex
	// logger collects log entries exposed through LogsSince.
	logger *Logger
	// status is the snapshot returned by Status.
	status Status
	// cancel cancels the current session's context; nil when idle.
	cancel context.CancelFunc
	// ssh is the SSH connection for non-xray modes.
	ssh *sshBundle
	// socks is the local SOCKS server fronting the SSH client.
	socks *SocksServer
	// dnstt is the external dnstt process, when not using the embedded client.
	dnstt *ManagedProcess
	// embeddedDNSTT is the in-process dnstt client, when enabled.
	embeddedDNSTT *dnsttclient.Client
	// xray is the xray process in xray mode.
	xray *ManagedProcess
	// tun is the TUN runner, reused across sessions.
	tun *tun.Runner
	// routeCleanup undoes route changes applied for TUN mode.
	routeCleanup *routes.Cleanup
	// manualStop is set by Stop and cleared by Start. It suppresses auto reconnect.
	manualStop bool
	// reconnecting is true while a reconnectLoop is in progress.
	reconnecting bool
}
// NewManager returns an idle Manager rooted at root. A single fresh Logger
// is shared between the manager and its TUN runner.
func NewManager(root string) *Manager {
	logger := NewLogger()
	m := &Manager{root: root, logger: logger}
	m.tun = tun.NewRunner(logger)
	return m
}
// LogsSince returns the log entries the manager's logger reports after id.
func (m *Manager) LogsSince(id int64) []LogEntry {
	entries := m.logger.Since(id)
	return entries
}
// Status returns a copy of the current session status, taken under the lock.
func (m *Manager) Status() Status {
	m.mu.Lock()
	snapshot := m.status
	m.mu.Unlock()
	return snapshot
}
// Start connects profile p and, on success, marks the manager as running.
//
// The profile is defaulted and validated first. In xray mode Start launches
// the xray process and waits for its local SOCKS listener. In the other modes
// it optionally brings up a dnstt client (embedded or external), then connects
// SSH and serves a local SocksServer over it. When p.Tun.Enabled, the TUN
// runner is started on the local SOCKS address and route changes are applied.
// On success a reconnect monitor is started (startMonitor is a no-op unless
// p.Reconnect.Enabled).
//
// Only one session may be running or connecting at a time. A concurrent call
// returns an error without touching the active session. Failures after the
// session is registered are logged and tear down what was built.
func (m *Manager) Start(p config.Profile) error {
	config.ApplyDefaults(&p)
	if err := config.Validate(p); err != nil {
		return err
	}
	ctx, cancel := context.WithCancel(context.Background())
	m.mu.Lock()
	if m.status.Running || m.status.Connecting {
		m.mu.Unlock()
		cancel()
		return fmt.Errorf("another profile is already running or connecting")
	}
	m.cancel = cancel
	m.manualStop = false
	m.status = Status{Connecting: true, ProfileID: p.ID, Mode: string(p.Mode), Tun: p.Tun.Enabled, StartedAt: time.Now().Format(time.RFC3339)}
	m.mu.Unlock()
	m.logger.Add("info", "starting profile: %s", p.Name)

	// fail logs the formatted error and tears down the partially built session.
	// NOTE(review): m.stop(false) stops whatever session is current. If this
	// attempt was cancelled and a newer Start has already registered, that
	// session is torn down too. waitStartup below keeps the window short, but
	// it does not close it entirely.
	fail := func(format string, args ...any) error {
		err := fmt.Errorf(format, args...)
		m.logger.Add("error", "%v", err)
		m.stop(false)
		return err
	}
	// ensureCurrent reports an error if this attempt no longer owns the
	// manager: its context was cancelled, or the status no longer shows this
	// profile as connecting.
	ensureCurrent := func() error {
		if err := ctx.Err(); err != nil {
			return err
		}
		m.mu.Lock()
		defer m.mu.Unlock()
		if m.cancel == nil || !m.status.Connecting || m.status.ProfileID != p.ID {
			return context.Canceled
		}
		return nil
	}
	// waitStartup gives a freshly launched helper process d to come up, but
	// returns early if this attempt is cancelled (e.g. by Stop). The caller's
	// ensureCurrent check then fires immediately instead of after the full
	// timeout.
	waitStartup := func(d time.Duration) {
		if d <= 0 {
			return
		}
		t := time.NewTimer(d)
		defer t.Stop()
		select {
		case <-ctx.Done():
		case <-t.C:
		}
	}

	var socksAddr string
	var bypass []string
	if p.Mode == config.ModeXray {
		proc, err := StartProcess(ctx, m.root, "xray", p.Xray.Executable, p.Xray.Args, m.logger)
		if err != nil {
			return fail("start xray failed: %w", err)
		}
		m.mu.Lock()
		m.xray = proc
		m.mu.Unlock()
		waitStartup(time.Duration(p.Xray.StartupTimeoutMs) * time.Millisecond)
		if err := ensureCurrent(); err != nil {
			proc.Stop()
			return fail("connection cancelled: %w", err)
		}
		socksAddr = net.JoinHostPort(p.Xray.LocalSocksHost, fmt.Sprint(p.Xray.LocalSocksPort))
		if exited, procErr := proc.Exited(); exited {
			// Route through fail so this failure is logged like every other one.
			if procErr != nil {
				return fail("xray exited before opening local SOCKS on %s: %w", socksAddr, procErr)
			}
			return fail("xray exited before opening local SOCKS on %s", socksAddr)
		}
		if err := waitForTCP(ctx, socksAddr, time.Duration(p.Xray.StartupTimeoutMs)*time.Millisecond); err != nil {
			return fail("xray local SOCKS is not listening on %s: %w", socksAddr, err)
		}
		// xray mode contributes no bypass hosts for route setup.
		bypass = nil
	} else {
		if p.Mode == config.ModeDNSTT {
			if p.DNSTT.UseEmbedded {
				client, err := dnsttclient.Start(ctx, dnsttclient.Options{
					ResolverType:     p.DNSTT.ResolverType,
					ResolverAddress:  p.DNSTT.ResolverAddress,
					PublicKeyHex:     p.DNSTT.PublicKey,
					Domain:           p.DNSTT.Domain,
					LocalAddress:     net.JoinHostPort(p.DNSTT.LocalSSHHost, fmt.Sprint(p.DNSTT.LocalSSHPort)),
					UTLSDistribution: p.DNSTT.UTLSDistribution,
					StartupTimeout:   time.Duration(p.DNSTT.StartupTimeoutMs) * time.Millisecond,
					LogWriter:        dnsttLogWriter{logger: m.logger},
				})
				if err != nil {
					return fail("start embedded dnstt failed: %w", err)
				}
				m.mu.Lock()
				m.embeddedDNSTT = client
				m.mu.Unlock()
			} else {
				proc, err := StartProcess(ctx, m.root, "dnstt", p.DNSTT.Executable, p.DNSTT.Args, m.logger)
				if err != nil {
					return fail("start dnstt failed: %w", err)
				}
				m.mu.Lock()
				m.dnstt = proc
				m.mu.Unlock()
				waitStartup(time.Duration(p.DNSTT.StartupTimeoutMs) * time.Millisecond)
			}
			if err := ensureCurrent(); err != nil {
				return fail("connection cancelled: %w", err)
			}
		}
		sshc, err := connectSSH(ctx, p, m.logger)
		if err != nil {
			return fail("%w", err)
		}
		if err := ensureCurrent(); err != nil {
			// sshc is not registered yet, so stopLocked would not close it.
			_ = sshc.Client.Close()
			_ = sshc.Conn.Close()
			return fail("connection cancelled: %w", err)
		}
		m.mu.Lock()
		m.ssh = sshc
		m.mu.Unlock()
		socksAddr = net.JoinHostPort(p.Local.SocksHost, fmt.Sprint(p.Local.SocksPort))
		ss := &SocksServer{Addr: socksAddr, SSH: sshc.Client, Logger: m.logger, DNS: profileDNSServers(p), UDPGW: p.UDPGW}
		if err := ss.Start(); err != nil {
			return fail("%w", err)
		}
		m.mu.Lock()
		m.socks = ss
		m.status.SocksAddr = socksAddr
		m.mu.Unlock()
		if err := waitForTCP(ctx, socksAddr, 2500*time.Millisecond); err != nil {
			return fail("local SOCKS is not listening on %s: %w", socksAddr, err)
		}
		bypass = effectiveBypassHosts(p, sshc.ControlHosts)
	}
	if err := ensureCurrent(); err != nil {
		return fail("connection cancelled: %w", err)
	}
	if p.Tun.Enabled {
		if err := m.tun.Start(&p, socksAddr); err != nil {
			return fail("%w", err)
		}
		// Route errors are logged but not fatal; any cleanup handle returned
		// is still stored so stopLocked can undo partial changes.
		cleanup, err := routes.Apply(p, bypass, m.logger)
		if err != nil {
			m.logger.Add("warn", "route setup error: %v", err)
		}
		m.mu.Lock()
		m.routeCleanup = cleanup
		m.mu.Unlock()
	}
	// Promote to Running only if this attempt still owns the manager. The
	// check and the transition happen atomically under the lock.
	m.mu.Lock()
	current := m.cancel != nil && m.status.Connecting && m.status.ProfileID == p.ID
	err := ctx.Err()
	if err == nil && current {
		m.status = Status{Running: true, ProfileID: p.ID, Mode: string(p.Mode), SocksAddr: socksAddr, Tun: p.Tun.Enabled, StartedAt: time.Now().Format(time.RFC3339)}
	}
	m.mu.Unlock()
	if err != nil || !current {
		m.stop(false)
		if err != nil {
			return err
		}
		return context.Canceled
	}
	m.logger.Add("info", "profile is connected; local socks=%s tun=%v", socksAddr, p.Tun.Enabled)
	m.startMonitor(ctx, p)
	return nil
}
// profileDNSServers returns a fresh slice of the profile's TUN DNS servers.
// The IPv6 servers are appended only when IPv6 is enabled for the TUN.
// The result never aliases the profile's own slices.
func profileDNSServers(p config.Profile) []string {
	size := len(p.Tun.DNS)
	if p.Tun.IPv6Enabled {
		size += len(p.Tun.IPv6DNS)
	}
	servers := make([]string, 0, size)
	servers = append(servers, p.Tun.DNS...)
	if p.Tun.IPv6Enabled {
		servers = append(servers, p.Tun.IPv6DNS...)
	}
	return servers
}
// waitForTCP polls addr until a TCP connection succeeds, ctx is done, or
// timeout elapses (5s when timeout <= 0). Each dial is capped at 350ms and
// retries are spaced 150ms apart.
//
// It returns nil on the first successful dial (the probe connection is closed
// immediately) and ctx.Err() if ctx ends while waiting between attempts.
// Otherwise it returns the last dial error, or a generic timeout error if no
// dial was ever attempted.
func waitForTCP(ctx context.Context, addr string, timeout time.Duration) error {
	if timeout <= 0 {
		timeout = 5 * time.Second
	}
	var (
		deadline = time.Now().Add(timeout)
		dialer   = net.Dialer{Timeout: 350 * time.Millisecond}
		lastErr  error
	)
	for time.Now().Before(deadline) {
		conn, dialErr := dialer.DialContext(ctx, "tcp", addr)
		if dialErr == nil {
			_ = conn.Close()
			return nil
		}
		lastErr = dialErr
		select {
		case <-time.After(150 * time.Millisecond):
			continue
		case <-ctx.Done():
			return ctx.Err()
		}
	}
	if lastErr == nil {
		return fmt.Errorf("timeout waiting for %s", addr)
	}
	return lastErr
}
// effectiveBypassHosts returns the de-duplicated list of remote hosts to exclude
// from TUN routing, in first-seen order.
//
// It includes the control hosts actually used by the active transport. If
// there are none and the mode is not dnstt, it falls back to p.SSH.Host. For
// embedded dnstt it adds the resolver's host. Entries are trimmed, ports are
// dropped, and brackets are removed from IPv6 literals. Local hosts
// (localhost, loopback, unspecified) and blanks are skipped.
func effectiveBypassHosts(p config.Profile, activeControlHosts []string) []string {
	seen := map[string]bool{}
	out := make([]string, 0, len(activeControlHosts)+1)
	add := func(host string) {
		host = strings.TrimSpace(host)
		if host == "" {
			return
		}
		if h, _, err := net.SplitHostPort(host); err == nil {
			host = h
		}
		// SplitHostPort fails on a bare bracketed IPv6 literal such as
		// "[2001:db8::1]". Strip the brackets anyway so the entry is a plain IP
		// and dedupes against the unbracketed form.
		host = strings.Trim(host, "[]")
		if host == "" || isLocalBypassHost(host) || seen[host] {
			return
		}
		seen[host] = true
		out = append(out, host)
	}
	for _, host := range activeControlHosts {
		add(host)
	}
	if len(out) == 0 && p.Mode != config.ModeDNSTT {
		// Fallback for old profile formats or unexpected transports. This still only
		// adds the direct SSH host, not every rotated proxy from the profile.
		add(p.SSH.Host)
	}
	if p.Mode == config.ModeDNSTT && p.DNSTT.UseEmbedded {
		add(dnsttResolverHost(p.DNSTT.ResolverAddress))
	}
	return out
}
// isLocalBypassHost reports whether host never needs a bypass route.
// That covers blank input, "localhost" (case-insensitive), and loopback or
// unspecified IP literals. Surrounding spaces and IPv6 brackets are ignored.
func isLocalBypassHost(host string) bool {
	h := strings.ToLower(strings.TrimSpace(host))
	h = strings.Trim(h, "[]")
	switch h {
	case "", "localhost":
		return true
	}
	if ip := net.ParseIP(h); ip != nil {
		return ip.IsLoopback() || ip.IsUnspecified()
	}
	return false
}
// dnsttResolverHost extracts the host part of a dnstt resolver address.
//
// It accepts a URL ("https://dns.example/dns-query"), a host:port pair
// ("8.8.8.8:53", "[2001:db8::1]:53") or a bare host, and returns "" for blank
// input. A bare bracketed IPv6 literal ("[2001:db8::1]") is returned without
// its brackets, so the result is always a plain host name or IP.
func dnsttResolverHost(s string) string {
	s = strings.TrimSpace(s)
	if s == "" {
		return ""
	}
	if strings.Contains(s, "://") {
		if u, err := url.Parse(s); err == nil {
			return u.Hostname()
		}
	}
	if host, _, err := net.SplitHostPort(s); err == nil {
		return host
	}
	// Not a host:port pair. SplitHostPort rejects "[addr]" without a port,
	// so unwrap the brackets here.
	if len(s) >= 2 && strings.HasPrefix(s, "[") && strings.HasSuffix(s, "]") {
		return s[1 : len(s)-1]
	}
	return s
}
// Stop tears down the current session and records that it was stopped by the
// user, which suppresses auto reconnect until the next Start.
func (m *Manager) Stop() { m.stop(true) }
// stop tears down the current session under the lock. When manual is true
// the manual-stop flag is set. A false value never clears an existing flag.
func (m *Manager) stop(manual bool) {
	m.mu.Lock()
	defer m.mu.Unlock()
	m.manualStop = m.manualStop || manual
	m.stopLocked()
}
// startMonitor launches a background goroutine that probes the session owning
// ctx every p.Reconnect.CheckIntervalSeconds (default 10s). On a failed probe
// it hands off to reconnectLoop once. It does nothing unless
// p.Reconnect.Enabled.
//
// The goroutine exits when ctx is cancelled (the session was torn down), on a
// manual stop, or after the hand-off.
func (m *Manager) startMonitor(ctx context.Context, p config.Profile) {
	if !p.Reconnect.Enabled {
		return
	}
	interval := time.Duration(p.Reconnect.CheckIntervalSeconds) * time.Second
	if interval <= 0 {
		interval = 10 * time.Second
	}
	go func() {
		ticker := time.NewTicker(interval)
		defer ticker.Stop()
		for {
			select {
			case <-ctx.Done():
				return
			case <-ticker.C:
			}
			// select picks randomly when both cases are ready, so re-check
			// cancellation: a torn-down session must not be probed.
			if ctx.Err() != nil || m.isManualStopped() {
				return
			}
			err := m.probeConnection(p)
			if err == nil {
				continue
			}
			// probeConnection inspects whatever session is current. If ours was
			// stopped while the probe ran, the failure is not about this session
			// and must not trigger a reconnect.
			if ctx.Err() != nil {
				return
			}
			m.logger.Add("warn", "connection monitor detected tunnel loss: %v", err)
			m.reconnectLoop(p, err)
			return
		}
	}()
}
// isManualStopped reports, under the lock, whether the user stopped the session.
func (m *Manager) isManualStopped() bool {
	m.mu.Lock()
	stopped := m.manualStop
	m.mu.Unlock()
	return stopped
}
// markReconnecting claims the single reconnect slot. It returns false, without
// claiming, when the user has stopped manually or another reconnect is already
// in progress.
func (m *Manager) markReconnecting() bool {
	m.mu.Lock()
	defer m.mu.Unlock()
	claimed := !m.manualStop && !m.reconnecting
	if claimed {
		m.reconnecting = true
	}
	return claimed
}
// clearReconnecting releases the reconnect slot taken by markReconnecting.
func (m *Manager) clearReconnecting() {
	m.mu.Lock()
	defer m.mu.Unlock()
	m.reconnecting = false
}
// probeConnection checks whether the current session is still healthy and
// returns nil when nothing is running.
//
// For xray sessions it reports an exited process, then dials the local SOCKS
// address within a 2s budget. For SSH-based sessions it sends an OpenSSH
// keepalive request, bounded by a 5s deadline on the underlying connection
// when one is present; the deadline is cleared afterwards. p is currently
// unused.
func (m *Manager) probeConnection(p config.Profile) error {
	m.mu.Lock()
	running, socksAddr := m.status.Running, m.status.SocksAddr
	bundle, xrayProc := m.ssh, m.xray
	m.mu.Unlock()

	switch {
	case !running:
		return nil
	case xrayProc != nil:
		if exited, exitErr := xrayProc.Exited(); exited {
			if exitErr == nil {
				return fmt.Errorf("xray exited")
			}
			return fmt.Errorf("xray exited: %w", exitErr)
		}
		if socksAddr == "" {
			return nil
		}
		probeCtx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
		defer cancel()
		return waitForTCP(probeCtx, socksAddr, 2*time.Second)
	case bundle == nil || bundle.Client == nil:
		return nil
	}

	if bundle.Conn != nil {
		_ = bundle.Conn.SetDeadline(time.Now().Add(5 * time.Second))
		defer bundle.Conn.SetDeadline(time.Time{})
	}
	if _, _, err := bundle.Client.SendRequest("keepalive@openssh.com", true, nil); err != nil {
		return fmt.Errorf("ssh keepalive failed: %w", err)
	}
	return nil
}
// reconnectLoop tears down the failed session for p and retries Start(p). The
// wait between attempts is p.Reconnect.DelaySeconds (default 3s).
//
// It stops when:
//   - an attempt succeeds;
//   - the user stops manually;
//   - another session becomes active (e.g. the user started a different
//     profile);
//   - p.Reconnect.MaxRetries attempts have failed (unlimited when <= 0).
//
// Only one loop runs at a time (see markReconnecting). cause is logged.
func (m *Manager) reconnectLoop(p config.Profile, cause error) {
	if !m.markReconnecting() {
		return
	}
	defer m.clearReconnecting()
	delay := time.Duration(p.Reconnect.DelaySeconds) * time.Second
	if delay <= 0 {
		delay = 3 * time.Second
	}
	maxRetries := p.Reconnect.MaxRetries
	m.logger.Add("warn", "connection lost (%v); auto reconnect is enabled", cause)
	if p.Tun.Enabled {
		m.logger.Add("info", "destroying TUN before reconnect")
	}
	m.stop(false)
	if p.Tun.Enabled {
		// Give the TUN teardown time to settle before a new attempt.
		time.Sleep(1200 * time.Millisecond)
	}
	for attempt := 1; maxRetries <= 0 || attempt <= maxRetries; attempt++ {
		if m.isManualStopped() {
			m.logger.Add("info", "auto reconnect cancelled by user")
			return
		}
		m.logger.Add("info", "reconnect attempt %d%s in %s", attempt, reconnectLimitSuffix(maxRetries), delay)
		time.Sleep(delay)
		if m.isManualStopped() {
			m.logger.Add("info", "auto reconnect cancelled by user")
			return
		}
		// The manager is idle after the teardown above and after every failed
		// Start, so an active session here was started by someone else. A user
		// Start also clears manualStop, so without this check the loop would keep
		// retrying against that session.
		if st := m.Status(); st.Running || st.Connecting {
			m.logger.Add("info", "auto reconnect cancelled: another session is active")
			return
		}
		if err := m.Start(p); err != nil {
			if m.isManualStopped() {
				m.logger.Add("info", "auto reconnect cancelled by user")
				return
			}
			m.logger.Add("warn", "reconnect attempt %d failed: %v", attempt, err)
			if p.Tun.Enabled {
				m.logger.Add("info", "destroying TUN after failed reconnect attempt")
				// Only tear down when the manager is idle. If another session
				// registered in the meantime, stopping here would kill it. The check
				// and the teardown share one critical section.
				m.mu.Lock()
				if !m.status.Running && !m.status.Connecting {
					m.stopLocked()
				}
				m.mu.Unlock()
				time.Sleep(1200 * time.Millisecond)
			}
			continue
		}
		m.logger.Add("info", "reconnected successfully")
		return
	}
	m.logger.Add("error", "auto reconnect stopped after %d failed attempt(s)", maxRetries)
}
// reconnectLimitSuffix renders the retry limit for log lines. It returns
// "/N" for a positive limit and " (unlimited)" otherwise.
func reconnectLimitSuffix(maxRetries int) string {
	if maxRetries > 0 {
		return "/" + fmt.Sprint(maxRetries)
	}
	return " (unlimited)"
}
// stopLocked tears down the current session. The caller must hold m.mu.
//
// It runs these steps in order:
//  1. cancel the session context;
//  2. undo route changes and stop the TUN runner;
//  3. stop the local SOCKS server;
//  4. close the SSH connection;
//  5. stop xray and dnstt;
//  6. reset status to its zero value.
//
// It logs "disconnected" only if a session was running or connecting, and it
// is safe to call when nothing is active.
func (m *Manager) stopLocked() {
	wasActive := m.status.Running || m.status.Connecting
	if m.cancel != nil {
		m.cancel()
		m.cancel = nil
	}
	if m.routeCleanup != nil {
		m.routeCleanup.Run()
		m.routeCleanup = nil
	}
	if m.tun != nil {
		m.tun.Stop()
	}
	if m.socks != nil {
		m.socks.Stop()
		m.socks = nil
	}
	if m.ssh != nil {
		// Guard both halves, matching probeConnection. A nil dereference here
		// would panic while m.mu is held.
		if m.ssh.Client != nil {
			_ = m.ssh.Client.Close()
		}
		if m.ssh.Conn != nil {
			_ = m.ssh.Conn.Close()
		}
		m.ssh = nil
	}
	if m.xray != nil {
		m.xray.Stop()
		m.xray = nil
	}
	if m.embeddedDNSTT != nil {
		m.embeddedDNSTT.Stop()
		m.embeddedDNSTT = nil
	}
	if m.dnstt != nil {
		m.dnstt.Stop()
		m.dnstt = nil
	}
	if wasActive {
		m.logger.Add("info", "disconnected")
	}
	m.status = Status{}
}
// dnsttLogWriter adapts the embedded dnstt client's log output to the
// manager's Logger. Each non-blank Write is recorded under the "dnstt"
// category.
type dnsttLogWriter struct {
	logger *Logger
}

// Write logs p with surrounding whitespace trimmed and skips blank payloads or
// a nil logger. It always reports the whole buffer as written, with no error.
func (w dnsttLogWriter) Write(p []byte) (int, error) {
	n := len(p)
	if w.logger == nil {
		return n, nil
	}
	if msg := strings.TrimSpace(string(p)); msg != "" {
		w.logger.Add("dnstt", "%s", msg)
	}
	return n, nil
}