From 43482c88fa942fe7b928c9159e460f6f23c30571 Mon Sep 17 00:00:00 2001 From: penguinehis Date: Sun, 3 May 2026 10:15:28 -0300 Subject: [PATCH] Fix stuck users --- check_api.go | 3 +- hotreload.go | 4 +- main.go | 182 +++++++++++++++++++++++++++++++++++++++++++++------ 3 files changed, 166 insertions(+), 23 deletions(-) diff --git a/check_api.go b/check_api.go index 3240c04..cd5b868 100644 --- a/check_api.go +++ b/check_api.go @@ -59,7 +59,8 @@ func checkSSHUser(w http.ResponseWriter, username string) { } u.mu.Lock() - activeConns := u.ActiveConns + activeConns := len(u.conns) + u.ActiveConns = activeConns maxConns := u.Cfg.MaxConnections expiresAt := u.ExpiresAt u.mu.Unlock() diff --git a/hotreload.go b/hotreload.go index 956a554..97d730c 100644 --- a/hotreload.go +++ b/hotreload.go @@ -292,8 +292,10 @@ func applyFullConfigReload(newCfg *Config) ConfigReloadReport { } setBannerText(bt) - // Default per-connection bandwidth limits (picked up by new connections) + // Default per-connection bandwidth limits and SSH inactivity cleanup + // (picked up by new connections). setDefaultLimits(newCfg.DefaultLimitMbpsUp, newCfg.DefaultLimitMbpsDown) + setSSHIdleTimeoutFromConfig(newCfg.SSHIdleTimeout) // Quiet logging / user count display if newCfg.Quiet { diff --git a/main.go b/main.go index ac49c01..419a91c 100644 --- a/main.go +++ b/main.go @@ -35,6 +35,9 @@ const ( tlsHandshakeTimeout = 15 * time.Second // Dial timeout for direct-tcpip backend connections. directTCPIPDialTimeout = 10 * time.Second + // Default post-auth SSH inactivity timeout. This is based on real bytes + // moving in either direction, so live upload/download tunnels are not closed. + defaultSSHIdleTimeout = 5 * time.Minute ) // ---------- Config types ---------- @@ -75,6 +78,11 @@ type Config struct { UserCount bool `json:"user_count"` + // SSHIdleTimeout controls how long an authenticated SSH connection may + // remain with no bytes moving in either direction before it is closed and + // released from the active user count. Empty = default 5m. Use "0s" to disable. + SSHIdleTimeout string `json:"ssh_idle_timeout,omitempty"` + // NEW: Directory to serve the admin panel from AdminDir string `json:"admin_dir"` @@ -278,9 +286,12 @@ func (m *UserManager) ReplaceAll(newUsers map[string]*UserState) { m.users = newUsers } -// ReplaceAllPreserveRuntime replaces the user map while keeping runtime connection -// state (ActiveConns + conns) for users that already exist. -// This prevents the admin panel from showing everyone as "offline" after a DB reload. +// ReplaceAllPreserveRuntime replaces the user map while keeping the same +// runtime UserState object for users that already exist. This is important: +// active handleConn goroutines hold a pointer to the old UserState and run the +// decrement in a defer. If we copy ActiveConns into a new UserState during a DB +// reload, that later decrement happens on the old object and the visible count +// in the new map stays stuck. func (m *UserManager) ReplaceAllPreserveRuntime(newUsers map[string]*UserState) { m.mu.Lock() old := m.users @@ -288,10 +299,17 @@ func (m *UserManager) ReplaceAllPreserveRuntime(newUsers map[string]*UserState) for username, nu := range newUsers { if ou, ok := old[username]; ok && ou != nil && nu != nil { ou.mu.Lock() - nu.ActiveConns = ou.ActiveConns - // Preserve the live connection set so we can still disconnect correctly. - nu.conns = ou.conns + ou.Cfg = nu.Cfg + ou.ExpiresAt = nu.ExpiresAt + ou.PubKey = nu.PubKey + if ou.conns == nil { + ou.conns = make(map[*ssh.ServerConn]struct{}) + } + // Self-heal any previously stale counter by trusting the live connection map. + ou.ActiveConns = len(ou.conns) ou.mu.Unlock() + + newUsers[username] = ou } } @@ -326,6 +344,9 @@ var ( userMgr = &UserManager{users: make(map[string]*UserState)} userCountEnabled bool + sshIdleTimeoutMu sync.RWMutex + currentSSHIdleTimeout = defaultSSHIdleTimeout + displayMu sync.Mutex lastDisplayLen int @@ -395,6 +416,111 @@ func copyWithRateLimit(dst io.Writer, src io.Reader, lim *rate.Limiter) (written return written, err } +func parseSSHIdleTimeout(raw string) time.Duration { + raw = strings.TrimSpace(raw) + if raw == "" { + return defaultSSHIdleTimeout + } + d, err := time.ParseDuration(raw) + if err != nil { + log.Printf("invalid ssh_idle_timeout %q: %v; using default %s", raw, err, defaultSSHIdleTimeout) + return defaultSSHIdleTimeout + } + if d < 0 { + log.Printf("invalid negative ssh_idle_timeout %q; using default %s", raw, defaultSSHIdleTimeout) + return defaultSSHIdleTimeout + } + return d +} + +func setSSHIdleTimeoutFromConfig(raw string) { + d := parseSSHIdleTimeout(raw) + sshIdleTimeoutMu.Lock() + currentSSHIdleTimeout = d + sshIdleTimeoutMu.Unlock() +} + +func getSSHIdleTimeout() time.Duration { + sshIdleTimeoutMu.RLock() + d := currentSSHIdleTimeout + sshIdleTimeoutMu.RUnlock() + return d +} + +// activityConn tracks real SSH transport activity in both directions. The idle +// monitor uses this instead of a read deadline so download-only or upload-only +// tunnels are considered live and are not disconnected. +type activityConn struct { + net.Conn + mu sync.Mutex + last time.Time +} + +func newActivityConn(c net.Conn) *activityConn { + return &activityConn{Conn: c, last: time.Now()} +} + +func (c *activityConn) touch() { + c.mu.Lock() + c.last = time.Now() + c.mu.Unlock() +} + +func (c *activityConn) LastActivity() time.Time { + c.mu.Lock() + last := c.last + c.mu.Unlock() + return last +} + +func (c *activityConn) Read(p []byte) (int, error) { + n, err := c.Conn.Read(p) + if n > 0 { + c.touch() + } + return n, err +} + +func (c *activityConn) Write(p []byte) (int, error) { + n, err := c.Conn.Write(p) + if n > 0 { + c.touch() + } + return n, err +} + +func monitorSSHIdle(c *activityConn, sshConn *ssh.ServerConn, username string, idleTimeout time.Duration, done <-chan struct{}) { + if idleTimeout <= 0 { + return + } + checkEvery := idleTimeout / 4 + if checkEvery < 5*time.Second { + checkEvery = 5 * time.Second + } + if checkEvery > 30*time.Second { + checkEvery = 30 * time.Second + } + + ticker := time.NewTicker(checkEvery) + defer ticker.Stop() + + for { + select { + case <-done: + return + case <-ticker.C: + idleFor := time.Since(c.LastActivity()) + if idleFor >= idleTimeout { + log.Printf("ssh idle timeout: user=%s remote=%s idle=%s limit=%s; closing stale connection", + username, sshConn.RemoteAddr(), idleFor.Round(time.Second), idleTimeout) + _ = sshConn.Close() + _ = c.Close() + return + } + } + } +} + // ---------- Server stats (CPU + network interfaces) ---------- // per-interface stats returned by /api/stats @@ -1165,7 +1291,8 @@ func handleListUsers(w http.ResponseWriter, r *http.Request) { out := make([]UserDTO, 0, len(states)) for _, u := range states { u.mu.Lock() - c := u.ActiveConns + c := len(u.conns) + u.ActiveConns = c cfg := u.Cfg expires := u.ExpiresAt u.mu.Unlock() @@ -1536,18 +1663,20 @@ func publicKeyCallback(meta ssh.ConnMetadata, key ssh.PublicKey) (*ssh.Permissio // ---------- Connection handling ---------- func handleConn(tcpConn net.Conn, config *ssh.ServerConfig) { - defer tcpConn.Close() + trackedConn := newActivityConn(tcpConn) + defer trackedConn.Close() // Prevent goroutine leaks from clients that connect but never complete the SSH handshake. - _ = tcpConn.SetReadDeadline(time.Now().Add(sshHandshakeTimeout)) + _ = trackedConn.SetReadDeadline(time.Now().Add(sshHandshakeTimeout)) - sshConn, chans, reqs, err := ssh.NewServerConn(tcpConn, config) + sshConn, chans, reqs, err := ssh.NewServerConn(trackedConn, config) if err != nil { log.Println("ssh handshake failed:", err) return } - // Clear deadlines after a successful handshake. - _ = tcpConn.SetDeadline(time.Time{}) + // Clear deadlines after a successful handshake. Runtime cleanup is handled + // by monitorSSHIdle, which checks traffic in both directions. + _ = trackedConn.SetDeadline(time.Time{}) username := sshConn.User() log.Printf("new SSH connection from %s as %s", sshConn.RemoteAddr(), username) @@ -1558,19 +1687,22 @@ func handleConn(tcpConn net.Conn, config *ssh.ServerConfig) { return } - // Track active connection and enforce max_connections + // Track active connection and enforce max_connections. The connection map is + // treated as the source of truth so stale counters can self-heal. u.mu.Lock() - if u.Cfg.MaxConnections > 0 && u.ActiveConns >= u.Cfg.MaxConnections { + if u.conns == nil { + u.conns = make(map[*ssh.ServerConn]struct{}) + } + activeConns := len(u.conns) + u.ActiveConns = activeConns + if u.Cfg.MaxConnections > 0 && activeConns >= u.Cfg.MaxConnections { u.mu.Unlock() log.Printf("user %s exceeded max_connections (%d)", username, u.Cfg.MaxConnections) sshConn.Close() return } - u.ActiveConns++ - if u.conns == nil { - u.conns = make(map[*ssh.ServerConn]struct{}) - } u.conns[sshConn] = struct{}{} + u.ActiveConns = len(u.conns) u.mu.Unlock() // Use per-user limit if set; otherwise fall back to the global config default. @@ -1596,10 +1728,16 @@ func handleConn(tcpConn net.Conn, config *ssh.ServerConfig) { updateUserDisplay() + idleDone := make(chan struct{}) + if idleTimeout := getSSHIdleTimeout(); idleTimeout > 0 { + go monitorSSHIdle(trackedConn, sshConn, username, idleTimeout, idleDone) + } + defer func() { + close(idleDone) u.mu.Lock() - u.ActiveConns-- delete(u.conns, sshConn) + u.ActiveConns = len(u.conns) u.mu.Unlock() updateUserDisplay() @@ -1727,7 +1865,8 @@ func updateUserDisplay() { parts := make([]string, 0, len(userStates)) for _, u := range userStates { u.mu.Lock() - c := u.ActiveConns + c := len(u.conns) + u.ActiveConns = c name := u.Cfg.Username u.mu.Unlock() parts = append(parts, fmt.Sprintf("%s: %d", name, c)) @@ -2494,8 +2633,9 @@ func main() { log.SetOutput(io.Discard) } - // Initialise default per-connection bandwidth limits. + // Initialise default per-connection bandwidth limits and SSH inactivity cleanup. setDefaultLimits(cfg.DefaultLimitMbpsUp, cfg.DefaultLimitMbpsDown) + setSSHIdleTimeoutFromConfig(cfg.SSHIdleTimeout) // Initialise listener pools (used for initial startup and hot-reload alike). publicPool = newListenerPool(serveHTTP80)