Fix stuck users

This commit is contained in:
2026-05-03 10:15:28 -03:00
parent 09f3959aa2
commit 43482c88fa
3 changed files with 166 additions and 23 deletions

View File

@@ -59,7 +59,8 @@ func checkSSHUser(w http.ResponseWriter, username string) {
}
u.mu.Lock()
activeConns := u.ActiveConns
activeConns := len(u.conns)
u.ActiveConns = activeConns
maxConns := u.Cfg.MaxConnections
expiresAt := u.ExpiresAt
u.mu.Unlock()

View File

@@ -292,8 +292,10 @@ func applyFullConfigReload(newCfg *Config) ConfigReloadReport {
}
setBannerText(bt)
// Default per-connection bandwidth limits (picked up by new connections)
// Default per-connection bandwidth limits and SSH inactivity cleanup
// (picked up by new connections).
setDefaultLimits(newCfg.DefaultLimitMbpsUp, newCfg.DefaultLimitMbpsDown)
setSSHIdleTimeoutFromConfig(newCfg.SSHIdleTimeout)
// Quiet logging / user count display
if newCfg.Quiet {

182
main.go
View File

@@ -35,6 +35,9 @@ const (
tlsHandshakeTimeout = 15 * time.Second
// Dial timeout for direct-tcpip backend connections.
directTCPIPDialTimeout = 10 * time.Second
// Default post-auth SSH inactivity timeout. This is based on real bytes
// moving in either direction, so live upload/download tunnels are not closed.
defaultSSHIdleTimeout = 5 * time.Minute
)
// ---------- Config types ----------
@@ -75,6 +78,11 @@ type Config struct {
UserCount bool `json:"user_count"`
// SSHIdleTimeout controls how long an authenticated SSH connection may
// remain with no bytes moving in either direction before it is closed and
// released from the active user count. Empty = default 5m. Use "0s" to disable.
SSHIdleTimeout string `json:"ssh_idle_timeout,omitempty"`
// NEW: Directory to serve the admin panel from
AdminDir string `json:"admin_dir"`
@@ -278,9 +286,12 @@ func (m *UserManager) ReplaceAll(newUsers map[string]*UserState) {
m.users = newUsers
}
// ReplaceAllPreserveRuntime replaces the user map while keeping runtime connection
// state (ActiveConns + conns) for users that already exist.
// This prevents the admin panel from showing everyone as "offline" after a DB reload.
// ReplaceAllPreserveRuntime replaces the user map while keeping the same
// runtime UserState object for users that already exist. This is important:
// active handleConn goroutines hold a pointer to the old UserState and run the
// decrement in a defer. If we copy ActiveConns into a new UserState during a DB
// reload, that later decrement happens on the old object and the visible count
// in the new map stays stuck.
func (m *UserManager) ReplaceAllPreserveRuntime(newUsers map[string]*UserState) {
m.mu.Lock()
old := m.users
@@ -288,10 +299,17 @@ func (m *UserManager) ReplaceAllPreserveRuntime(newUsers map[string]*UserState)
for username, nu := range newUsers {
if ou, ok := old[username]; ok && ou != nil && nu != nil {
ou.mu.Lock()
nu.ActiveConns = ou.ActiveConns
// Preserve the live connection set so we can still disconnect correctly.
nu.conns = ou.conns
ou.Cfg = nu.Cfg
ou.ExpiresAt = nu.ExpiresAt
ou.PubKey = nu.PubKey
if ou.conns == nil {
ou.conns = make(map[*ssh.ServerConn]struct{})
}
// Self-heal any previously stale counter by trusting the live connection map.
ou.ActiveConns = len(ou.conns)
ou.mu.Unlock()
newUsers[username] = ou
}
}
@@ -326,6 +344,9 @@ var (
userMgr = &UserManager{users: make(map[string]*UserState)}
userCountEnabled bool
sshIdleTimeoutMu sync.RWMutex
currentSSHIdleTimeout = defaultSSHIdleTimeout
displayMu sync.Mutex
lastDisplayLen int
@@ -395,6 +416,111 @@ func copyWithRateLimit(dst io.Writer, src io.Reader, lim *rate.Limiter) (written
return written, err
}
// parseSSHIdleTimeout converts the textual ssh_idle_timeout setting into a
// time.Duration.
//
// Surrounding whitespace is ignored. An empty value yields
// defaultSSHIdleTimeout. A value that time.ParseDuration rejects, or that
// parses to a negative duration, is logged and also replaced by
// defaultSSHIdleTimeout. A zero duration (for example "0s") is returned
// unchanged.
func parseSSHIdleTimeout(raw string) time.Duration {
	value := strings.TrimSpace(raw)
	if value == "" {
		return defaultSSHIdleTimeout
	}

	parsed, parseErr := time.ParseDuration(value)
	switch {
	case parseErr != nil:
		log.Printf("invalid ssh_idle_timeout %q: %v; using default %s", value, parseErr, defaultSSHIdleTimeout)
		return defaultSSHIdleTimeout
	case parsed < 0:
		log.Printf("invalid negative ssh_idle_timeout %q; using default %s", value, defaultSSHIdleTimeout)
		return defaultSSHIdleTimeout
	}
	return parsed
}
// setSSHIdleTimeoutFromConfig parses raw with parseSSHIdleTimeout and stores
// the result in currentSSHIdleTimeout under the sshIdleTimeoutMu write lock.
func setSSHIdleTimeoutFromConfig(raw string) {
	// Parse before taking the lock so any logging happens outside the
	// critical section.
	timeout := parseSSHIdleTimeout(raw)

	sshIdleTimeoutMu.Lock()
	defer sshIdleTimeoutMu.Unlock()
	currentSSHIdleTimeout = timeout
}
// getSSHIdleTimeout returns the value of currentSSHIdleTimeout, read under
// the sshIdleTimeoutMu read lock.
func getSSHIdleTimeout() time.Duration {
	sshIdleTimeoutMu.RLock()
	defer sshIdleTimeoutMu.RUnlock()
	return currentSSHIdleTimeout
}
// activityConn wraps a net.Conn and remembers when bytes last moved through
// it, in either direction. monitorSSHIdle polls LastActivity rather than
// relying on a read deadline, so a connection that is only sending or only
// receiving still counts as active.
type activityConn struct {
	net.Conn

	mu   sync.Mutex // guards last
	last time.Time  // time of the most recent Read or Write that moved at least one byte
}

// newActivityConn wraps c; the wrapper's last-activity time starts at the
// moment of construction.
func newActivityConn(c net.Conn) *activityConn {
	ac := &activityConn{Conn: c}
	ac.last = time.Now()
	return ac
}

// touch records the current time as the latest activity.
func (ac *activityConn) touch() {
	ac.mu.Lock()
	defer ac.mu.Unlock()
	ac.last = time.Now()
}

// LastActivity returns the time recorded by the most recent touch, or the
// construction time if no bytes have moved yet.
func (ac *activityConn) LastActivity() time.Time {
	ac.mu.Lock()
	defer ac.mu.Unlock()
	return ac.last
}

// Read reads from the wrapped connection and records activity only when at
// least one byte was received.
func (ac *activityConn) Read(p []byte) (int, error) {
	n, err := ac.Conn.Read(p)
	if n <= 0 {
		return n, err
	}
	ac.touch()
	return n, err
}

// Write writes to the wrapped connection and records activity only when at
// least one byte was sent.
func (ac *activityConn) Write(p []byte) (int, error) {
	n, err := ac.Conn.Write(p)
	if n <= 0 {
		return n, err
	}
	ac.touch()
	return n, err
}
// monitorSSHIdle closes an SSH connection once no bytes have moved through c
// for at least idleTimeout. It is meant to run in its own goroutine.
//
// It returns immediately when idleTimeout is not positive, returns when done
// is closed or becomes readable, and returns after closing both sshConn and c
// on an idle timeout. Idle time is sampled on a ticker that fires every
// idleTimeout/4, clamped to the range 5s to 30s, so the close can lag the
// configured timeout by up to one tick.
func monitorSSHIdle(c *activityConn, sshConn *ssh.ServerConn, username string, idleTimeout time.Duration, done <-chan struct{}) {
	if idleTimeout <= 0 {
		return
	}

	const (
		minInterval = 5 * time.Second
		maxInterval = 30 * time.Second
	)
	interval := idleTimeout / 4
	switch {
	case interval < minInterval:
		interval = minInterval
	case interval > maxInterval:
		interval = maxInterval
	}

	poll := time.NewTicker(interval)
	defer poll.Stop()

	for {
		select {
		case <-done:
			return
		case <-poll.C:
		}

		quietFor := time.Since(c.LastActivity())
		if quietFor < idleTimeout {
			continue
		}

		log.Printf("ssh idle timeout: user=%s remote=%s idle=%s limit=%s; closing stale connection",
			username, sshConn.RemoteAddr(), quietFor.Round(time.Second), idleTimeout)
		// Best-effort teardown: close errors are ignored because the
		// connection is being discarded either way.
		_ = sshConn.Close()
		_ = c.Close()
		return
	}
}
// ---------- Server stats (CPU + network interfaces) ----------
// per-interface stats returned by /api/stats
@@ -1165,7 +1291,8 @@ func handleListUsers(w http.ResponseWriter, r *http.Request) {
out := make([]UserDTO, 0, len(states))
for _, u := range states {
u.mu.Lock()
c := u.ActiveConns
c := len(u.conns)
u.ActiveConns = c
cfg := u.Cfg
expires := u.ExpiresAt
u.mu.Unlock()
@@ -1536,18 +1663,20 @@ func publicKeyCallback(meta ssh.ConnMetadata, key ssh.PublicKey) (*ssh.Permissio
// ---------- Connection handling ----------
func handleConn(tcpConn net.Conn, config *ssh.ServerConfig) {
defer tcpConn.Close()
trackedConn := newActivityConn(tcpConn)
defer trackedConn.Close()
// Prevent goroutine leaks from clients that connect but never complete the SSH handshake.
_ = tcpConn.SetReadDeadline(time.Now().Add(sshHandshakeTimeout))
_ = trackedConn.SetReadDeadline(time.Now().Add(sshHandshakeTimeout))
sshConn, chans, reqs, err := ssh.NewServerConn(tcpConn, config)
sshConn, chans, reqs, err := ssh.NewServerConn(trackedConn, config)
if err != nil {
log.Println("ssh handshake failed:", err)
return
}
// Clear deadlines after a successful handshake.
_ = tcpConn.SetDeadline(time.Time{})
// Clear deadlines after a successful handshake. Runtime cleanup is handled
// by monitorSSHIdle, which checks traffic in both directions.
_ = trackedConn.SetDeadline(time.Time{})
username := sshConn.User()
log.Printf("new SSH connection from %s as %s", sshConn.RemoteAddr(), username)
@@ -1558,19 +1687,22 @@ func handleConn(tcpConn net.Conn, config *ssh.ServerConfig) {
return
}
// Track active connection and enforce max_connections
// Track active connection and enforce max_connections. The connection map is
// treated as the source of truth so stale counters can self-heal.
u.mu.Lock()
if u.Cfg.MaxConnections > 0 && u.ActiveConns >= u.Cfg.MaxConnections {
if u.conns == nil {
u.conns = make(map[*ssh.ServerConn]struct{})
}
activeConns := len(u.conns)
u.ActiveConns = activeConns
if u.Cfg.MaxConnections > 0 && activeConns >= u.Cfg.MaxConnections {
u.mu.Unlock()
log.Printf("user %s exceeded max_connections (%d)", username, u.Cfg.MaxConnections)
sshConn.Close()
return
}
u.ActiveConns++
if u.conns == nil {
u.conns = make(map[*ssh.ServerConn]struct{})
}
u.conns[sshConn] = struct{}{}
u.ActiveConns = len(u.conns)
u.mu.Unlock()
// Use per-user limit if set; otherwise fall back to the global config default.
@@ -1596,10 +1728,16 @@ func handleConn(tcpConn net.Conn, config *ssh.ServerConfig) {
updateUserDisplay()
idleDone := make(chan struct{})
if idleTimeout := getSSHIdleTimeout(); idleTimeout > 0 {
go monitorSSHIdle(trackedConn, sshConn, username, idleTimeout, idleDone)
}
defer func() {
close(idleDone)
u.mu.Lock()
u.ActiveConns--
delete(u.conns, sshConn)
u.ActiveConns = len(u.conns)
u.mu.Unlock()
updateUserDisplay()
@@ -1727,7 +1865,8 @@ func updateUserDisplay() {
parts := make([]string, 0, len(userStates))
for _, u := range userStates {
u.mu.Lock()
c := u.ActiveConns
c := len(u.conns)
u.ActiveConns = c
name := u.Cfg.Username
u.mu.Unlock()
parts = append(parts, fmt.Sprintf("%s: %d", name, c))
@@ -2494,8 +2633,9 @@ func main() {
log.SetOutput(io.Discard)
}
// Initialise default per-connection bandwidth limits.
// Initialise default per-connection bandwidth limits and SSH inactivity cleanup.
setDefaultLimits(cfg.DefaultLimitMbpsUp, cfg.DefaultLimitMbpsDown)
setSSHIdleTimeoutFromConfig(cfg.SSHIdleTimeout)
// Initialise listener pools (used for initial startup and hot-reload alike).
publicPool = newListenerPool(serveHTTP80)