From 6f677d272a3dce716f78ed9850eae6e650bfb43b Mon Sep 17 00:00:00 2001 From: penguinehis Date: Sat, 2 May 2026 18:42:58 -0300 Subject: [PATCH] Launch --- .gitignore | 27 + README.md | 379 +++++++ admin/index.html | 2150 +++++++++++++++++++++++++++++++++++ auth.go | 682 ++++++++++++ check_api.go | 121 ++ dnstt_integration.go | 1158 +++++++++++++++++++ go.mod | 20 + go.sum | 104 ++ hotreload.go | 269 +++++ install.sh | 386 +++++++ main.go | 2532 ++++++++++++++++++++++++++++++++++++++++++ server_config_api.go | 117 ++ tls_api.go | 168 +++ udpgw_integration.go | 490 ++++++++ update.sh | 199 ++++ xray_clients.go | 164 +++ xray_integration.go | 692 ++++++++++++ 17 files changed, 9658 insertions(+) create mode 100644 .gitignore create mode 100644 README.md create mode 100644 admin/index.html create mode 100644 auth.go create mode 100644 check_api.go create mode 100644 dnstt_integration.go create mode 100644 go.mod create mode 100644 go.sum create mode 100644 hotreload.go create mode 100644 install.sh create mode 100644 main.go create mode 100644 server_config_api.go create mode 100644 tls_api.go create mode 100644 udpgw_integration.go create mode 100644 update.sh create mode 100644 xray_clients.go create mode 100644 xray_integration.go diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..67c4eb1 --- /dev/null +++ b/.gitignore @@ -0,0 +1,27 @@ +# Build output +sshpanel +sshpanel.bak +*.bak + +# Runtime/generated config +.env +config.json +xray_config.json +banner.txt + +# Secrets / keys / certificates +keys/ +certs/ +*.pem +*.key +ssh_host_*_key +ssh_host_*_key.pub + +# Logs / runtime data +logs/ +*.log + +# Local/editor +.DS_Store +.vscode/ +.idea/ diff --git a/README.md b/README.md new file mode 100644 index 0000000..b29ac71 --- /dev/null +++ b/README.md @@ -0,0 +1,379 @@ +# DragonCoreSSH V40 + +## PT-BR + +DragonCoreSSH V40 é um painel/servidor em Go para SSH com HTTP Injection, painel web, PostgreSQL, integração com Xray-core e API pública para consultar status de usuário. + +### Requisitos + +- Servidor Linux com `systemd` +- Acesso `root` ou `sudo` +- Gerenciador de pacotes `apt`, `yum` ou `dnf` +- Portas liberadas no firewall/security group conforme a configuração usada + +Distribuições alvo: + +- Ubuntu / Debian / Linux Mint +- CentOS / RHEL / Rocky / AlmaLinux +- Fedora + +### Instalação + +Clone o projeto e execute o instalador: + +```bash +git clone +cd +sudo bash install.sh +``` + +Durante a instalação, o script instala/configura: + +- Go +- PostgreSQL +- Xray-core +- Binário do DragonCoreSSH V40 +- Serviço `systemd` chamado `sshpanel` +- Painel web +- Arquivos de runtime em `/opt/sshpanel` + +Ao finalizar, o instalador mostra os dados principais: + +```text +Server IP +SSH ports +VLESS port +VLESS UUID +Admin panel URL +Admin login +Admin password +Admin token +``` + +### Caminhos principais + +```text +/opt/sshpanel/sshpanel +/opt/sshpanel/.env +/opt/sshpanel/config.json +/opt/sshpanel/xray_config.json +/opt/sshpanel/admin/ +/opt/sshpanel/logs/panel.log +/etc/systemd/system/sshpanel.service +``` + +### Portas padrão + +```text +80 SSH com HTTP Injection +8080 SSH extra com HTTP Injection +9090 Painel web + API pública /check +10086 Xray VLESS +10088 SOCKS local em 127.0.0.1 +``` + +### Comandos úteis + +Ver status do serviço: + +```bash +systemctl status sshpanel --no-pager -l +``` + +Ver logs pelo `journalctl`: + +```bash +journalctl -u sshpanel -f +``` + +Ver log direto do painel: + +```bash +tail -f /opt/sshpanel/logs/panel.log +``` + +Reiniciar serviço: + +```bash +systemctl restart sshpanel +``` + +### Atualização + +Entre na nova pasta do código e execute: + +```bash +sudo bash update.sh +``` + +O update recompila o binário e atualiza os arquivos do painel web, mantendo as configurações e dados existentes. + +### API pública CheckUser + +Endpoint: + +```http +GET /check +``` + +URL padrão: + +```text +http://SERVER_IP:9090/check +``` + +Consultar usuário SSH: + +```bash +curl "http://SERVER_IP:9090/check?user=testuser" +``` + +Consultar UUID Xray: + +```bash +curl "http://SERVER_IP:9090/check?uuid=a499cb67-6c73-43cc-a84d-92cbb68d22d1" +``` + +Se `user` e `uuid` forem enviados juntos, `user` tem prioridade. + +Resposta de sucesso: + +```json +{ + "username": "testuser", + "count_connections": 1, + "expiration_date": "31/12/2026", + "expiration_days": 243, + "limit_connections": 2 +} +``` + +Conta ilimitada: + +```json +{ + "username": "testuser", + "count_connections": 0, + "expiration_date": "Unlimited", + "expiration_days": -1, + "limit_connections": 1 +} +``` + +Campos da resposta: + +| Campo | Tipo | Descrição | +| --- | --- | --- | +| `username` | string | Usuário SSH, nome do cliente Xray ou UUID. | +| `count_connections` | number | Conexões SSH ativas no momento. | +| `expiration_date` | string | Data de expiração em `DD/MM/YYYY` ou `Unlimited`. | +| `expiration_days` | number | Dias restantes. `-1` significa ilimitado. | +| `limit_connections` | number | Limite máximo de conexões. | + +Erros comuns: + +```json +{"error":"user or uuid parameter required"} +``` + +```json +{"error":"user not found"} +``` + +```json +{"error":"uuid not found"} +``` + +```json +{"error":"database not configured"} +``` + +--- + +## EN-US + +DragonCoreSSH V40 is a Go-based SSH HTTP Injection server with a web panel, PostgreSQL, Xray-core integration, and a public API for checking user status. + +### Requirements + +- Linux server with `systemd` +- `root` or `sudo` access +- `apt`, `yum`, or `dnf` package manager +- Required ports opened in the firewall/security group + +Target distributions: + +- Ubuntu / Debian / Linux Mint +- CentOS / RHEL / Rocky / AlmaLinux +- Fedora + +### Installation + +Clone the project and run the installer: + +```bash +git clone +cd +sudo bash install.sh +``` + +During installation, the script installs/configures: + +- Go +- PostgreSQL +- Xray-core +- DragonCoreSSH V40 binary +- `systemd` service named `sshpanel` +- Web panel +- Runtime files in `/opt/sshpanel` + +When finished, the installer prints the main access details: + +```text +Server IP +SSH ports +VLESS port +VLESS UUID +Admin panel URL +Admin login +Admin password +Admin token +``` + +### Main paths + +```text +/opt/sshpanel/sshpanel +/opt/sshpanel/.env +/opt/sshpanel/config.json +/opt/sshpanel/xray_config.json +/opt/sshpanel/admin/ +/opt/sshpanel/logs/panel.log +/etc/systemd/system/sshpanel.service +``` + +### Default ports + +```text +80 SSH with HTTP Injection +8080 Extra SSH with HTTP Injection +9090 Web panel + public /check API +10086 Xray VLESS +10088 Local SOCKS on 127.0.0.1 +``` + +### Useful commands + +Check service status: + +```bash +systemctl status sshpanel --no-pager -l +``` + +Follow logs with `journalctl`: + +```bash +journalctl -u sshpanel -f +``` + +Follow panel log file: + +```bash +tail -f /opt/sshpanel/logs/panel.log +``` + +Restart service: + +```bash +systemctl restart sshpanel +``` + +### Update + +Enter the new source-code folder and run: + +```bash +sudo bash update.sh +``` + +The update script rebuilds the binary and updates the web panel files while keeping existing configuration and user data. + +### Public CheckUser API + +Endpoint: + +```http +GET /check +``` + +Default URL: + +```text +http://SERVER_IP:9090/check +``` + +Check SSH username: + +```bash +curl "http://SERVER_IP:9090/check?user=testuser" +``` + +Check Xray UUID: + +```bash +curl "http://SERVER_IP:9090/check?uuid=a499cb67-6c73-43cc-a84d-92cbb68d22d1" +``` + +If both `user` and `uuid` are sent, `user` has priority. + +Success response: + +```json +{ + "username": "testuser", + "count_connections": 1, + "expiration_date": "31/12/2026", + "expiration_days": 243, + "limit_connections": 2 +} +``` + +Unlimited account: + +```json +{ + "username": "testuser", + "count_connections": 0, + "expiration_date": "Unlimited", + "expiration_days": -1, + "limit_connections": 1 +} +``` + +Response fields: + +| Field | Type | Description | +| --- | --- | --- | +| `username` | string | SSH username, Xray client name, or UUID. | +| `count_connections` | number | Current active SSH connections. | +| `expiration_date` | string | Expiration date in `DD/MM/YYYY` or `Unlimited`. | +| `expiration_days` | number | Remaining days. `-1` means unlimited. | +| `limit_connections` | number | Maximum connection limit. | + +Common errors: + +```json +{"error":"user or uuid parameter required"} +``` + +```json +{"error":"user not found"} +``` + +```json +{"error":"uuid not found"} +``` + +```json +{"error":"database not configured"} +``` diff --git a/admin/index.html b/admin/index.html new file mode 100644 index 0000000..6a3e3b3 --- /dev/null +++ b/admin/index.html @@ -0,0 +1,2150 @@ + + + + +SSH Panel + + + + +
+
+ + +
+
+
SSH Panel
+
Sign in with your admin or reseller credentials.
+ + + +
+
+
+ + + +
+
+ + + + diff --git a/auth.go b/auth.go new file mode 100644 index 0000000..8369195 --- /dev/null +++ b/auth.go @@ -0,0 +1,682 @@ +package main + +import ( + "context" + "crypto/rand" + "crypto/sha256" + "database/sql" + "encoding/hex" + "encoding/json" + "fmt" + "log" + "net/http" + "sync" + "time" +) + +const ( + RoleSuperAdmin = "superadmin" + RoleReseller = "reseller" + sessionTTL = 12 * time.Hour +) + +// ---------- AdminUser ---------- + +type AdminUser struct { + ID int + Username string + PasswordHash string + Role string + MaxUsers int + ExpiresAt *time.Time + IsActive bool + CreatedAt time.Time +} + +// ---------- Session store (in-memory) ---------- + +type AdminSession struct { + Token string + UserID int + Username string + Role string + ExpiresAt time.Time +} + +type sessionStoreT struct { + mu sync.RWMutex + m map[string]*AdminSession +} + +var sessions = &sessionStoreT{m: make(map[string]*AdminSession)} + +func (s *sessionStoreT) Create(userID int, username, role string) *AdminSession { + b := make([]byte, 32) + _, _ = rand.Read(b) + tok := hex.EncodeToString(b) + sess := &AdminSession{ + Token: tok, + UserID: userID, + Username: username, + Role: role, + ExpiresAt: time.Now().Add(sessionTTL), + } + s.mu.Lock() + s.m[tok] = sess + s.mu.Unlock() + return sess +} + +func (s *sessionStoreT) Get(token string) *AdminSession { + if token == "" { + return nil + } + s.mu.RLock() + sess := s.m[token] + s.mu.RUnlock() + if sess == nil || time.Now().After(sess.ExpiresAt) { + return nil + } + return sess +} + +func (s *sessionStoreT) Delete(token string) { + s.mu.Lock() + delete(s.m, token) + s.mu.Unlock() +} + +func (s *sessionStoreT) cleanup() { + s.mu.Lock() + defer s.mu.Unlock() + now := time.Now() + for tok, sess := range s.m { + if now.After(sess.ExpiresAt) { + delete(s.m, tok) + } + } +} + +// ---------- In-memory AdminUser cache ---------- + +type adminUserMgrT struct { + mu sync.RWMutex + m map[string]*AdminUser +} + +var adminUsers = &adminUserMgrT{m: make(map[string]*AdminUser)} + +func (m *adminUserMgrT) set(u *AdminUser) { + m.mu.Lock() + m.m[u.Username] = u + m.mu.Unlock() +} + +func (m *adminUserMgrT) get(username string) (*AdminUser, bool) { + m.mu.RLock() + u, ok := m.m[username] + m.mu.RUnlock() + return u, ok +} + +func (m *adminUserMgrT) delete(username string) { + m.mu.Lock() + delete(m.m, username) + m.mu.Unlock() +} + +func (m *adminUserMgrT) list() []*AdminUser { + m.mu.RLock() + defer m.mu.RUnlock() + out := make([]*AdminUser, 0, len(m.m)) + for _, u := range m.m { + cp := *u + out = append(out, &cp) + } + return out +} + +func (m *adminUserMgrT) replaceAll(users []*AdminUser) { + m.mu.Lock() + m.m = make(map[string]*AdminUser, len(users)) + for _, u := range users { + cp := *u + m.m[u.Username] = &cp + } + m.mu.Unlock() +} + +// ---------- Context helpers ---------- + +type ctxKeyAdmin struct{} + +func withSession(ctx context.Context, s *AdminSession) context.Context { + return context.WithValue(ctx, ctxKeyAdmin{}, s) +} + +func sessionFromCtx(ctx context.Context) *AdminSession { + s, _ := ctx.Value(ctxKeyAdmin{}).(*AdminSession) + return s +} + +// ---------- Middleware ---------- + +// sessionMiddleware requires a valid X-Session-Token header. +func sessionMiddleware(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + token := r.Header.Get("X-Session-Token") + s := sessions.Get(token) + if s == nil { + http.Error(w, "unauthorized", http.StatusUnauthorized) + return + } + next.ServeHTTP(w, r.WithContext(withSession(r.Context(), s))) + }) +} + +// superAdminOnly wraps a handler to require role == superadmin. +// Must be used AFTER sessionMiddleware (session must be in context). +func superAdminOnly(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + s := sessionFromCtx(r.Context()) + if s == nil || s.Role != RoleSuperAdmin { + http.Error(w, "forbidden", http.StatusForbidden) + return + } + next.ServeHTTP(w, r) + }) +} + +// saSession chains sessionMiddleware + superAdminOnly. +func saSession(next http.Handler) http.Handler { + return sessionMiddleware(superAdminOnly(next)) +} + +// ---------- Password hashing ---------- + +func hashAdminPassword(pw string) string { + h := sha256.Sum256([]byte(pw)) + return hex.EncodeToString(h[:]) +} + +// ---------- DB methods on Store ---------- + +func (s *Store) EnsureAdminUsersSchema(ctx context.Context) error { + stmts := []string{ + `CREATE TABLE IF NOT EXISTS admin_users ( + id SERIAL PRIMARY KEY, + username TEXT UNIQUE NOT NULL, + password_hash TEXT NOT NULL, + role TEXT NOT NULL DEFAULT 'reseller', + max_users INT NOT NULL DEFAULT 30, + expires_at TIMESTAMPTZ, + is_active BOOLEAN NOT NULL DEFAULT TRUE, + created_at TIMESTAMPTZ NOT NULL DEFAULT NOW() + )`, + `ALTER TABLE ssh_users ADD COLUMN IF NOT EXISTS owner_username TEXT NOT NULL DEFAULT ''`, + } + for _, stmt := range stmts { + if _, err := s.db.ExecContext(ctx, stmt); err != nil { + return fmt.Errorf("EnsureAdminUsersSchema: %w", err) + } + } + return nil +} + +func (s *Store) GetAdminUserByUsername(ctx context.Context, username string) (*AdminUser, error) { + u := &AdminUser{} + var expiresAt sql.NullTime + err := s.db.QueryRowContext(ctx, + `SELECT id, username, password_hash, role, max_users, expires_at, is_active, created_at + FROM admin_users WHERE username = $1`, username, + ).Scan(&u.ID, &u.Username, &u.PasswordHash, &u.Role, &u.MaxUsers, + &expiresAt, &u.IsActive, &u.CreatedAt) + if err == sql.ErrNoRows { + return nil, nil + } + if err != nil { + return nil, err + } + if expiresAt.Valid { + u.ExpiresAt = &expiresAt.Time + } + return u, nil +} + +func (s *Store) ListAdminUsers(ctx context.Context) ([]*AdminUser, error) { + rows, err := s.db.QueryContext(ctx, + `SELECT id, username, password_hash, role, max_users, expires_at, is_active, created_at + FROM admin_users ORDER BY role, username`) + if err != nil { + return nil, err + } + defer rows.Close() + var out []*AdminUser + for rows.Next() { + u := &AdminUser{} + var expiresAt sql.NullTime + if err := rows.Scan(&u.ID, &u.Username, &u.PasswordHash, &u.Role, + &u.MaxUsers, &expiresAt, &u.IsActive, &u.CreatedAt); err != nil { + return nil, err + } + if expiresAt.Valid { + u.ExpiresAt = &expiresAt.Time + } + out = append(out, u) + } + return out, rows.Err() +} + +func (s *Store) UpsertAdminUser(ctx context.Context, u *AdminUser) error { + var expiresAt interface{} + if u.ExpiresAt != nil { + expiresAt = *u.ExpiresAt + } + if u.ID == 0 { + return s.db.QueryRowContext(ctx, + `INSERT INTO admin_users (username, password_hash, role, max_users, expires_at, is_active) + VALUES ($1,$2,$3,$4,$5,$6) RETURNING id`, + u.Username, u.PasswordHash, u.Role, u.MaxUsers, expiresAt, u.IsActive, + ).Scan(&u.ID) + } + _, err := s.db.ExecContext(ctx, + `UPDATE admin_users SET password_hash=$2, role=$3, max_users=$4, + expires_at=$5, is_active=$6 WHERE id=$1`, + u.ID, u.PasswordHash, u.Role, u.MaxUsers, expiresAt, u.IsActive) + return err +} + +func (s *Store) DeleteAdminUser(ctx context.Context, username string) error { + _, err := s.db.ExecContext(ctx, `DELETE FROM admin_users WHERE username=$1`, username) + return err +} + +func (s *Store) SetAdminUserActive(ctx context.Context, username string, active bool) error { + _, err := s.db.ExecContext(ctx, `UPDATE admin_users SET is_active=$1 WHERE username=$2`, active, username) + return err +} + +func (s *Store) ListExpiredResellers(ctx context.Context) ([]*AdminUser, error) { + rows, err := s.db.QueryContext(ctx, + `SELECT id, username, password_hash, role, max_users, expires_at, is_active, created_at + FROM admin_users + WHERE role=$1 AND is_active=TRUE AND expires_at IS NOT NULL AND expires_at < NOW()`, + RoleReseller) + if err != nil { + return nil, err + } + defer rows.Close() + return scanAdminUsers(rows) +} + +func (s *Store) ListInactiveButRenewedResellers(ctx context.Context) ([]*AdminUser, error) { + rows, err := s.db.QueryContext(ctx, + `SELECT id, username, password_hash, role, max_users, expires_at, is_active, created_at + FROM admin_users + WHERE role=$1 AND is_active=FALSE AND (expires_at IS NULL OR expires_at > NOW())`, + RoleReseller) + if err != nil { + return nil, err + } + defer rows.Close() + return scanAdminUsers(rows) +} + +func scanAdminUsers(rows *sql.Rows) ([]*AdminUser, error) { + var out []*AdminUser + for rows.Next() { + u := &AdminUser{} + var expiresAt sql.NullTime + if err := rows.Scan(&u.ID, &u.Username, &u.PasswordHash, &u.Role, + &u.MaxUsers, &expiresAt, &u.IsActive, &u.CreatedAt); err != nil { + return nil, err + } + if expiresAt.Valid { + u.ExpiresAt = &expiresAt.Time + } + out = append(out, u) + } + return out, rows.Err() +} + +// BootstrapSuperAdmin creates a default "admin" superadmin if none exists. +// Returns the generated password, or "" if a superadmin already existed. +func (s *Store) BootstrapSuperAdmin(ctx context.Context) (string, error) { + var count int + if err := s.db.QueryRowContext(ctx, + `SELECT COUNT(*) FROM admin_users WHERE role=$1`, RoleSuperAdmin, + ).Scan(&count); err != nil { + return "", err + } + if count > 0 { + return "", nil + } + b := make([]byte, 10) + _, _ = rand.Read(b) + pw := hex.EncodeToString(b) + u := &AdminUser{ + Username: "admin", + PasswordHash: hashAdminPassword(pw), + Role: RoleSuperAdmin, + MaxUsers: 0, + IsActive: true, + } + if err := s.UpsertAdminUser(ctx, u); err != nil { + return "", err + } + return pw, nil +} + +// loadAdminUsersIntoCache reloads all admin_users rows into the in-memory cache. +func loadAdminUsersIntoCache(ctx context.Context, store *Store) error { + users, err := store.ListAdminUsers(ctx) + if err != nil { + return err + } + adminUsers.replaceAll(users) + return nil +} + +// ---------- Owner check (called from SSH auth callbacks) ---------- + +// ownerIsActive returns nil if an SSH user's reseller owner is active, or an error if suspended/expired. +func ownerIsActive(ownerUsername string) error { + if ownerUsername == "" { + return nil + } + u, ok := adminUsers.get(ownerUsername) + if !ok { + return fmt.Errorf("reseller account not found") + } + if !u.IsActive { + return fmt.Errorf("reseller account suspended") + } + if u.ExpiresAt != nil && time.Now().After(*u.ExpiresAt) { + return fmt.Errorf("reseller account expired") + } + return nil +} + +// disconnectOwnerUsers forcibly closes all active SSH connections for users owned by owner. +func disconnectOwnerUsers(ownerUsername string) { + for _, u := range userMgr.List() { + if u.Cfg.OwnerUsername == ownerUsername { + userMgr.DisconnectUser(u.Cfg.Username) + } + } +} + +// countOwnedUsers counts SSH users in memory that belong to owner. +func countOwnedUsers(ownerUsername string) int { + n := 0 + for _, u := range userMgr.List() { + if u.Cfg.OwnerUsername == ownerUsername { + n++ + } + } + return n +} + +// ---------- Reseller expiry background checker ---------- + +func startResellerExpiryChecker(store *Store) { + if store == nil { + return + } + go func() { + ticker := time.NewTicker(60 * time.Second) + defer ticker.Stop() + for range ticker.C { + ctx := context.Background() + + // Expire active resellers past their deadline + expired, err := store.ListExpiredResellers(ctx) + if err != nil { + log.Printf("reseller expiry check: %v", err) + } + for _, u := range expired { + log.Printf("reseller %s expired — suspending", u.Username) + if err := store.SetAdminUserActive(ctx, u.Username, false); err != nil { + log.Printf("reseller expiry: %v", err) + continue + } + u.IsActive = false + adminUsers.set(u) + disconnectOwnerUsers(u.Username) + } + + // Reactivate resellers that have been renewed (inactive but expiry now in future/nil) + renewed, err := store.ListInactiveButRenewedResellers(ctx) + if err != nil { + log.Printf("reseller renewal check: %v", err) + } + for _, u := range renewed { + log.Printf("reseller %s renewed — reactivating", u.Username) + if err := store.SetAdminUserActive(ctx, u.Username, true); err != nil { + log.Printf("reseller renewal: %v", err) + continue + } + u.IsActive = true + adminUsers.set(u) + } + + sessions.cleanup() + } + }() +} + +// ---------- HTTP handlers ---------- + +func handleLogin(store *Store) http.HandlerFunc { + return func(w http.ResponseWriter, r *http.Request) { + if r.Method != http.MethodPost { + w.WriteHeader(http.StatusMethodNotAllowed) + return + } + var req struct { + Username string `json:"username"` + Password string `json:"password"` + } + if err := json.NewDecoder(r.Body).Decode(&req); err != nil { + http.Error(w, "invalid json", http.StatusBadRequest) + return + } + if req.Username == "" || req.Password == "" { + http.Error(w, "username and password required", http.StatusBadRequest) + return + } + + u, err := store.GetAdminUserByUsername(r.Context(), req.Username) + if err != nil { + log.Printf("login db: %v", err) + http.Error(w, "server error", http.StatusInternalServerError) + return + } + if u == nil || u.PasswordHash != hashAdminPassword(req.Password) { + http.Error(w, "invalid credentials", http.StatusUnauthorized) + return + } + if !u.IsActive { + http.Error(w, "account suspended", http.StatusForbidden) + return + } + if u.ExpiresAt != nil && time.Now().After(*u.ExpiresAt) { + http.Error(w, "account expired", http.StatusForbidden) + return + } + + sess := sessions.Create(u.ID, u.Username, u.Role) + w.Header().Set("Content-Type", "application/json") + _ = json.NewEncoder(w).Encode(map[string]interface{}{ + "token": sess.Token, + "username": u.Username, + "role": u.Role, + }) + } +} + +func handleLogout(w http.ResponseWriter, r *http.Request) { + if r.Method != http.MethodPost { + w.WriteHeader(http.StatusMethodNotAllowed) + return + } + sessions.Delete(r.Header.Get("X-Session-Token")) + w.WriteHeader(http.StatusOK) +} + +func handleMe(w http.ResponseWriter, r *http.Request) { + if r.Method != http.MethodGet { + w.WriteHeader(http.StatusMethodNotAllowed) + return + } + s := sessionFromCtx(r.Context()) + resp := map[string]interface{}{ + "username": s.Username, + "role": s.Role, + } + if s.Role == RoleReseller { + if u, ok := adminUsers.get(s.Username); ok { + resp["max_users"] = u.MaxUsers + resp["used_users"] = countOwnedUsers(s.Username) + resp["expires_at"] = u.ExpiresAt + resp["is_active"] = u.IsActive + } + } + w.Header().Set("Content-Type", "application/json") + _ = json.NewEncoder(w).Encode(resp) +} + +// ---------- Reseller management (superadmin only) ---------- + +type ResellerDTO struct { + ID int `json:"id"` + Username string `json:"username"` + Role string `json:"role"` + MaxUsers int `json:"max_users"` + UsedUsers int `json:"used_users"` + ExpiresAt *time.Time `json:"expires_at,omitempty"` + IsActive bool `json:"is_active"` + CreatedAt time.Time `json:"created_at"` +} + +func handleListResellers(store *Store) http.HandlerFunc { + return func(w http.ResponseWriter, r *http.Request) { + if r.Method != http.MethodGet { + w.WriteHeader(http.StatusMethodNotAllowed) + return + } + users, err := store.ListAdminUsers(r.Context()) + if err != nil { + http.Error(w, "db error", http.StatusInternalServerError) + return + } + out := make([]ResellerDTO, 0, len(users)) + for _, u := range users { + out = append(out, ResellerDTO{ + ID: u.ID, + Username: u.Username, + Role: u.Role, + MaxUsers: u.MaxUsers, + UsedUsers: countOwnedUsers(u.Username), + ExpiresAt: u.ExpiresAt, + IsActive: u.IsActive, + CreatedAt: u.CreatedAt, + }) + } + w.Header().Set("Content-Type", "application/json") + _ = json.NewEncoder(w).Encode(out) + } +} + +type ResellerPayload struct { + Username string `json:"username"` + Password string `json:"password,omitempty"` + MaxUsers int `json:"max_users"` + ExpiresAt string `json:"expires_at"` + IsActive bool `json:"is_active"` +} + +func handleCreateReseller(store *Store) http.HandlerFunc { + return func(w http.ResponseWriter, r *http.Request) { + if r.Method != http.MethodPost { + w.WriteHeader(http.StatusMethodNotAllowed) + return + } + var p ResellerPayload + if err := json.NewDecoder(r.Body).Decode(&p); err != nil { + http.Error(w, "invalid json", http.StatusBadRequest) + return + } + if p.Username == "" { + http.Error(w, "username required", http.StatusBadRequest) + return + } + + ctx := r.Context() + existing, err := store.GetAdminUserByUsername(ctx, p.Username) + if err != nil { + http.Error(w, "db error", http.StatusInternalServerError) + return + } + + var u *AdminUser + if existing != nil { + u = existing + } else { + if p.Password == "" { + http.Error(w, "password required for new account", http.StatusBadRequest) + return + } + u = &AdminUser{Username: p.Username, Role: RoleReseller} + } + + if p.Password != "" { + u.PasswordHash = hashAdminPassword(p.Password) + } + u.MaxUsers = p.MaxUsers + u.IsActive = p.IsActive + u.ExpiresAt = nil + if p.ExpiresAt != "" { + t, err := time.Parse(time.RFC3339, p.ExpiresAt) + if err != nil { + http.Error(w, "invalid expires_at (RFC3339 required)", http.StatusBadRequest) + return + } + u.ExpiresAt = &t + } + + if err := store.UpsertAdminUser(ctx, u); err != nil { + log.Printf("upsert reseller: %v", err) + http.Error(w, "db error", http.StatusInternalServerError) + return + } + adminUsers.set(u) + + // If reseller was reactivated, users can reconnect automatically. + // Reconnect of existing SSH connections happens via the expiry checker. + + w.WriteHeader(http.StatusCreated) + } +} + +func handleDeleteReseller(store *Store) http.HandlerFunc { + return func(w http.ResponseWriter, r *http.Request) { + if r.Method != http.MethodDelete { + w.WriteHeader(http.StatusMethodNotAllowed) + return + } + username := r.URL.Query().Get("username") + if username == "" { + http.Error(w, "username required", http.StatusBadRequest) + return + } + ctx := r.Context() + if err := store.DeleteAdminUser(ctx, username); err != nil { + http.Error(w, "db error", http.StatusInternalServerError) + return + } + disconnectOwnerUsers(username) + adminUsers.delete(username) + w.WriteHeader(http.StatusNoContent) + } +} diff --git a/check_api.go b/check_api.go new file mode 100644 index 0000000..3240c04 --- /dev/null +++ b/check_api.go @@ -0,0 +1,121 @@ +package main + +import ( + "database/sql" + "encoding/json" + "net/http" + "time" +) + +// CheckResponse is returned by the public /check endpoint. +type CheckResponse struct { + Username string `json:"username"` + CountConnections int `json:"count_connections"` + ExpirationDate string `json:"expiration_date"` + ExpirationDays int `json:"expiration_days"` + LimitConnections int `json:"limit_connections"` +} + +// handleCheck is the public user-check API. No authentication required. +// Accepts ?user= or ?uuid= (or both; user takes priority). +// Returns JSON matching CheckResponse. +func handleCheck(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Access-Control-Allow-Origin", "*") + w.Header().Set("Access-Control-Allow-Methods", "GET, OPTIONS") + w.Header().Set("Content-Type", "application/json") + + if r.Method == http.MethodOptions { + w.WriteHeader(http.StatusNoContent) + return + } + if r.Method != http.MethodGet { + w.WriteHeader(http.StatusMethodNotAllowed) + return + } + + q := r.URL.Query() + username := q.Get("user") + uuid := q.Get("uuid") + + if username == "" && uuid == "" { + w.WriteHeader(http.StatusBadRequest) + _, _ = w.Write([]byte(`{"error":"user or uuid parameter required"}`)) + return + } + + if username != "" { + checkSSHUser(w, username) + return + } + checkXrayUUID(w, r, uuid) +} + +func checkSSHUser(w http.ResponseWriter, username string) { + u, ok := userMgr.Get(username) + if !ok { + w.WriteHeader(http.StatusNotFound) + _, _ = w.Write([]byte(`{"error":"user not found"}`)) + return + } + + u.mu.Lock() + activeConns := u.ActiveConns + maxConns := u.Cfg.MaxConnections + expiresAt := u.ExpiresAt + u.mu.Unlock() + + resp := CheckResponse{ + Username: username, + CountConnections: activeConns, + LimitConnections: maxConns, + } + fillExpiry(&resp, expiresAt) + _ = json.NewEncoder(w).Encode(resp) +} + +func checkXrayUUID(w http.ResponseWriter, r *http.Request, uuid string) { + if statsStore == nil { + w.WriteHeader(http.StatusServiceUnavailable) + _, _ = w.Write([]byte(`{"error":"database not configured"}`)) + return + } + + meta, err := statsStore.GetXrayClientMeta(r.Context(), uuid) + if err == sql.ErrNoRows { + w.WriteHeader(http.StatusNotFound) + _, _ = w.Write([]byte(`{"error":"uuid not found"}`)) + return + } + if err != nil { + w.WriteHeader(http.StatusInternalServerError) + _, _ = w.Write([]byte(`{"error":"database error"}`)) + return + } + + displayName := meta.Name + if displayName == "" { + displayName = meta.UUID + } + + resp := CheckResponse{ + Username: displayName, + CountConnections: 0, + LimitConnections: meta.MaxConns, + } + fillExpiry(&resp, meta.ExpiresAt) + _ = json.NewEncoder(w).Encode(resp) +} + +func fillExpiry(resp *CheckResponse, expiresAt *time.Time) { + if expiresAt == nil { + resp.ExpirationDate = "Unlimited" + resp.ExpirationDays = -1 + return + } + resp.ExpirationDate = expiresAt.Local().Format("02/01/2006") + days := int(time.Until(*expiresAt).Hours() / 24) + if days < 0 { + days = 0 + } + resp.ExpirationDays = days +} diff --git a/dnstt_integration.go b/dnstt_integration.go new file mode 100644 index 0000000..79f3d70 --- /dev/null +++ b/dnstt_integration.go @@ -0,0 +1,1158 @@ +package main + +// This file integrates a minimal DNS‑tunnel server (dnstt) into the main +// application. It is adapted from the public‑domain dnstt-server project +// (see https://www.bamsoftware.com/software/dnstt/) but modified to +// terminate streams into the existing SSH handler (handleConn) rather than +// forwarding them to an upstream TCP service. Each stream accepted via +// dnstt is wrapped as a net.Conn and passed to handleConn. The DNS +// transport itself uses KCP over UDP, Noise encryption and smux +// multiplexing as in the original dnstt. + +import ( + "bytes" + "encoding/base32" + "encoding/binary" + "encoding/json" + "fmt" + "io" + "log" + "net" + "net/http" + "os" + "strings" + "sync" + "sync/atomic" + "time" + + "github.com/xtaci/kcp-go/v5" + "github.com/xtaci/smux" + "golang.org/x/crypto/ssh" + "www.bamsoftware.com/git/dnstt.git/dns" + "www.bamsoftware.com/git/dnstt.git/noise" + "www.bamsoftware.com/git/dnstt.git/turbotunnel" +) + +// ---------- Hot-reload stop mechanism ---------- + +var ( + dnsttConnMu sync.Mutex + dnsttConn net.PacketConn // active UDP socket; closing it stops runDNSTT +) + +// stopDNSTT closes the active DNSTT UDP listener, causing runDNSTT to exit. +// It is a no-op if DNSTT is not running. +func stopDNSTT() { + dnsttConnMu.Lock() + defer dnsttConnMu.Unlock() + if dnsttConn != nil { + _ = dnsttConn.Close() + dnsttConn = nil + } +} + +// Constants mirrored from dnstt-server. See dnstt-server/main.go for +// commentary. +const ( + // smux streams will be closed after this much time without receiving data. + idleTimeout = 2 * time.Minute + // How to set the TTL field in Answer resource records. + responseTTL = 60 + // How long we may wait for downstream data before sending an empty + // response. This number should be less than 2 seconds (Quad9 DNS + // timeout as of 2019). + maxResponseDelay = 1 * time.Second +) + +// We don't send UDP payloads larger than this, in an attempt to avoid +// network-layer fragmentation. 1280 is the minimum IPv6 MTU, 40 bytes is +// the size of an IPv6 header (without extension headers), and 8 bytes is +// the size of a UDP header【561853413345496†L97-L109】. +// Control this value with the -mtu command-line option in the standalone +// dnstt-server. Here we use the default. +// maxUDPPayload defines the maximum UDP payload we ever send in a DNS +// response. It defaults to the IPv6 minimum MTU minus the IPv6 and UDP +// header sizes (1232 octets) but is clamped per‑query in responseFor and +// sendLoop based on the EDNS UDP payload advertised by the client. See +// startDNSTT for how this may be overridden via configuration. +var maxUDPPayload = 1280 - 40 - 8 + +// noEDNSFallbackPayload is the assumed UDP payload capability for resolver +// paths that do not include an EDNS(0) OPT RR (or that clamp it to 512) but +// still reliably carry larger UDP DNS messages. Many DNSTT deployments rely +// on this behaviour in the wild. This value is used as a *floor* for the +// inferred payload limit and is still clamped by maxUDPPayload. +// +// If you want strict RFC behaviour, set this to 512. +const noEDNSFallbackPayload = 932 + +// dnsttPrintStats controls whether periodic statistics are printed to stderr. +// It is set based on the DNSTT configuration provided by the main program. +// When false, the periodic stats will still be collected and made available +// via the admin API, but no log lines will be emitted. The default is +// true to preserve existing behaviour. +var dnsttPrintStats = true + +// DnsttStatsSnapshot holds a recent snapshot of DNSTT counters over the +// previous 5‑second window. It is updated every 5 seconds by the +// runDNSTT goroutine. The Timestamp field records when the snapshot was +// taken. These values are surfaced via the admin API so that the web +// panel can display tunnel health without reading stderr logs. +type DnsttStatsSnapshot struct { + Timestamp time.Time `json:"timestamp"` + DNSRx uint64 `json:"dns_rx"` + ParseErr uint64 `json:"parse_err"` + NoEDNS uint64 `json:"no_edns"` + Limit512 uint64 `json:"limit512"` + RecQueued uint64 `json:"rec_queued"` + RecDropped uint64 `json:"rec_dropped"` + RespSent uint64 `json:"resp_sent"` + RespBytes uint64 `json:"resp_bytes"` + RespEmpty uint64 `json:"resp_empty"` + RespData uint64 `json:"resp_data"` + RespOversize uint64 `json:"resp_oversize"` + KCPNew uint64 `json:"kcp_new"` + KCPEnd uint64 `json:"kcp_end"` + SmuxNew uint64 `json:"smux_new"` + SmuxEnd uint64 `json:"smux_end"` + ChLen int `json:"ch_len"` +} + +var ( + dnsttStatsMu sync.Mutex + lastDnsttStats DnsttStatsSnapshot +) + +// GetDNSTTStatsSnapshot returns the most recent DNSTT stats snapshot. +// It is safe for concurrent use by HTTP handlers and always returns a +// defensive copy of the snapshot. +func GetDNSTTStatsSnapshot() DnsttStatsSnapshot { + dnsttStatsMu.Lock() + defer dnsttStatsMu.Unlock() + return lastDnsttStats +} + +// dnsttCounters holds aggregated counters used to debug tunnel instability. +// All fields are updated via sync/atomic. +type dnsttCounters struct { + // StreamSeq is a monotonically increasing identifier used to tag smux + // streams as they are handed off to the SSH handler. + StreamSeq uint64 + DNSRx uint64 + DNSParseErr uint64 + // NoEDNS counts DNS queries without an EDNS(0) OPT RR. In many real-world + // deployments (especially mobile / carrier resolvers), EDNS may be stripped + // even though the path can carry UDP responses larger than 512 bytes. + NoEDNS uint64 + // SmallEDNS counts queries where the inferred/advertised UDP payload limit + // is at the classic DNS size (512 bytes). Kept for backward-compatible + // stats output. + SmallEDNS uint64 + RecQueued uint64 + RecDropped uint64 + RespSent uint64 + RespSentBytes uint64 + RespEmpty uint64 + RespWithData uint64 + RespOversize uint64 + KCPSessionsNew uint64 + KCPSessionsEnd uint64 + SmuxStreamsNew uint64 + SmuxStreamsEnd uint64 +} + +var dnsttStats dnsttCounters + +var maxEncodedPayloadCache sync.Map // map[int]int + +// dnsttClientPayloadCap tracks an inferred per-client UDP payload capability. +// Keyed by the turbotunnel "remote address" string (a hex ClientID). We use +// this to set a per-session KCP MTU so the tunnel can function on paths that +// do not send EDNS(0) but still support UDP payloads > 512. +// dnsttClientPayloadCap stores per-client capability with a last-seen timestamp. +// This needs periodic cleanup to avoid unbounded growth when many unique client +// IDs are observed over time. +type clientCapEntry struct { + Cap int + LastSeen int64 // unix nano +} + +var dnsttClientPayloadCap sync.Map // map[string]clientCapEntry + +var dnsttCapReaperOnce sync.Once + +func startDNSTTCapReaper() { + // Reap old client capability entries so the map can't grow forever. + // Defaults chosen to be conservative: keep "active" client IDs for 6 hours. + const ( + reapEvery = 10 * time.Minute + maxAge = 6 * time.Hour + ) + dnsttCapReaperOnce.Do(func() { + go func() { + t := time.NewTicker(reapEvery) + defer t.Stop() + for range t.C { + cutoff := time.Now().Add(-maxAge).UnixNano() + dnsttClientPayloadCap.Range(func(k, v any) bool { + e, ok := v.(clientCapEntry) + if ok && e.LastSeen > 0 && e.LastSeen < cutoff { + dnsttClientPayloadCap.Delete(k) + } + return true + }) + } + }() + }) +} + +func dnsttClientKey(clientID turbotunnel.ClientID) string { + return fmt.Sprintf("%x", clientID[:]) +} + +func updateClientPayloadCap(clientID turbotunnel.ClientID, cap int) { + if cap <= 0 { + return + } + key := dnsttClientKey(clientID) + now := time.Now().UnixNano() + if v, ok := dnsttClientPayloadCap.Load(key); ok { + e := v.(clientCapEntry) + if cap > e.Cap { + e.Cap = cap + } + e.LastSeen = now + dnsttClientPayloadCap.Store(key, e) + return + } + dnsttClientPayloadCap.Store(key, clientCapEntry{Cap: cap, LastSeen: now}) +} + +func cachedMaxEncodedPayload(limit int) int { + if limit <= 0 { + return 0 + } + if v, ok := maxEncodedPayloadCache.Load(limit); ok { + return v.(int) + } + m := computeMaxEncodedPayload(limit) + maxEncodedPayloadCache.Store(limit, m) + return m +} + +// base32Encoding is a base32 encoding without padding, as used by dnstt. +var base32Encoding = base32.StdEncoding.WithPadding(base32.NoPadding) + +// dnsttSSHConfig holds the SSH server configuration used by the DNS tunnel. +// It is set by startDNSTT before any dnstt sessions are accepted. +var dnsttSSHConfig *ssh.ServerConfig + +// dnsttLog is a dedicated logger for the integrated DNSTT server. It writes +// to stderr and uses a per-log prefix and microsecond precision. Unlike the +// global log.Logger used by the rest of the application, this logger is not +// affected by calls to log.SetOutput in main.go (e.g. when quiet mode is +// enabled). All log lines emitted by DNSTT should go through dnsttLog so +// that debugging output remains visible even when other logs are suppressed. +var dnsttLog = log.New(os.Stderr, "dnstt: ", log.LstdFlags|log.Lmicroseconds) + +// dnsttLogBuf stores recent DNSTT log lines for the web panel. It acts as a +// circular buffer retaining the last N log lines. The capacity is set when +// the DNSTT server is started, typically around 100 lines to stay well +// within the configured memory budget. Lines are stored without the +// trailing newline. Access to the buffer is synchronized by an internal +// mutex. +type dnsttLogBuffer struct { + mu sync.Mutex + lines []string + maxLines int +} + +// newDNSTTLogBuffer constructs a new ring buffer with the given capacity. +// The capacity must be positive. +func newDNSTTLogBuffer(maxLines int) *dnsttLogBuffer { + if maxLines <= 0 { + maxLines = 100 + } + return &dnsttLogBuffer{ + lines: make([]string, 0, maxLines), + maxLines: maxLines, + } +} + +// Write implements io.Writer and appends complete lines from p to the +// buffer. It splits on '\n' and discards empty segments. When the +// buffer reaches its maximum length, the oldest lines are removed to make +// room for new ones. +func (b *dnsttLogBuffer) Write(p []byte) (n int, err error) { + b.mu.Lock() + defer b.mu.Unlock() + s := string(p) + // Split incoming data by newline. We intentionally discard empty + // strings to avoid blank lines from being stored. + parts := strings.Split(s, "\n") + for _, part := range parts { + if part == "" { + continue + } + if len(b.lines) < b.maxLines { + b.lines = append(b.lines, part) + } else { + // Shift left by one and append new line at end. + copy(b.lines, b.lines[1:]) + b.lines[len(b.lines)-1] = part + } + } + return len(p), nil +} + +// GetLines returns a copy of the current log lines. The lines are +// returned in chronological order from oldest to newest. +func (b *dnsttLogBuffer) GetLines() []string { + b.mu.Lock() + defer b.mu.Unlock() + out := make([]string, len(b.lines)) + copy(out, b.lines) + return out +} + +// global log buffer for DNSTT logs. It is initialised when the server +// starts. Access via getDNSTTLogLines for concurrency safety. +var dnsttLogBuf *dnsttLogBuffer + +// getDNSTTLogLines returns the current DNSTT log lines in order. If the +// buffer has not been initialised, it returns an empty slice. +func getDNSTTLogLines() []string { + if dnsttLogBuf == nil { + return nil + } + return dnsttLogBuf.GetLines() +} + +// startDNSTT starts the integrated dnstt server if cfg is non-nil. It reads +// the Noise private key from cfg.PrivKeyFile, parses cfg.Domain into a dns.Name, +// and then launches runDNSTT in a goroutine. Any errors during start are +// logged. The SSH server configuration is used when handling streams. +func startDNSTT(cfg *DNSTTConfig, sshConf *ssh.ServerConfig) { + if cfg == nil { + return + } + startDNSTTCapReaper() + dnsttSSHConfig = sshConf + // Configure whether periodic DNSTT statistics should be emitted to stderr. + // When DisableStatsLog is true, stats will be collected but log lines are suppressed. + if cfg != nil { + dnsttPrintStats = !cfg.DisableStatsLog + // Initialise the log buffer once. Use a capacity of 100 lines (~few KB). + if dnsttLogBuf == nil { + dnsttLogBuf = newDNSTTLogBuffer(100) + } + // Configure the DNSTT logger output. If DisableConsoleLog is set, + // write only to the buffer; otherwise tee to both the buffer and stderr. + if cfg.DisableConsoleLog { + dnsttLog.SetOutput(dnsttLogBuf) + } else { + dnsttLog.SetOutput(io.MultiWriter(dnsttLogBuf, os.Stderr)) + } + } + // Read the private key from file. + f, err := os.Open(cfg.PrivKeyFile) + if err != nil { + dnsttLog.Printf("cannot open privkey file %s: %v", cfg.PrivKeyFile, err) + return + } + privkey, err := noise.ReadKey(f) + f.Close() + if err != nil { + dnsttLog.Printf("cannot read privkey from file: %v", err) + return + } + // Parse the domain name. dns.ParseName accepts a domain with a trailing + // dot or without. Any error here will abort the dnstt server. + domain, err := dns.ParseName(cfg.Domain) + if err != nil { + dnsttLog.Printf("invalid domain %q: %v", cfg.Domain, err) + return + } + // Log initialisation parameters so DNSTT startup is visible even when + // quiet logging is enabled. This helps with debugging. + dnsttLog.Printf("starting: domain=%q udp_listen=%q privkey=%q", cfg.Domain, cfg.UDPListen, cfg.PrivKeyFile) + go func() { + if err := runDNSTT(privkey, domain, cfg.UDPListen); err != nil { + dnsttLog.Printf("server exited with error: %v", err) + } + }() +} + +// handleDNSTTStream accepts a smux.Stream from a client and hands it off to +// handleConn. The stream is wrapped as a net.Conn by streamConn so that +// handleConn sees a minimal net.Conn interface. This function blocks until +// handleConn returns, then returns nil. Any errors from handleConn are +// ignored; they are logged by handleConn itself. +func handleDNSTTStream(stream *smux.Stream, conv uint32) error { + // Assign a per-stream sequence number to help correlate open/close events. + sid := atomic.AddUint64(&dnsttStats.StreamSeq, 1) + start := time.Now() + dnsttLog.Printf("ssh stream begin: conv=%d sid=%d", conv, sid) + sc := &streamConn{Stream: stream} + // Delegate to the existing SSH connection handler. This call blocks + // until the SSH connection terminates. The smux stream will be closed + // by handleConn when it returns. + handleConn(sc, dnsttSSHConfig) + dnsttLog.Printf("ssh stream end: conv=%d sid=%d duration=%s", conv, sid, time.Since(start)) + return nil +} + +// streamConn adapts a smux.Stream to the net.Conn interface expected by +// handleConn. smux.Stream already implements Read and Write, but does not +// satisfy net.Conn because it lacks methods for deadlines and addresses. We +// implement those methods with no‑ops and placeholder addresses. +type streamConn struct { + *smux.Stream +} + +func (s *streamConn) LocalAddr() net.Addr { return dummyAddr{} } +func (s *streamConn) RemoteAddr() net.Addr { return dummyAddr{} } +func (s *streamConn) SetDeadline(t time.Time) error { return nil } +func (s *streamConn) SetReadDeadline(t time.Time) error { return nil } +func (s *streamConn) SetWriteDeadline(t time.Time) error { return nil } + +// dummyAddr is a stand‑in net.Addr implementation for dnstt streams. It +// reports a generic network and address; this satisfies the net.Conn +// interface. handleConn logs the remote address, but here we provide +// "dnstt" as both network and address to indicate a tunnelled connection. +type dummyAddr struct{} + +func (d dummyAddr) Network() string { return "dnstt" } +func (d dummyAddr) String() string { return "dnstt" } + +// acceptDNSTTStreams wraps a KCP session in a Noise channel and an smux +// Session, then waits for smux streams. Each stream is passed to +// handleDNSTTStream. Any errors from the Noise or smux layers are returned. +func acceptDNSTTStreams(conn *kcp.UDPSession, privkey []byte) error { + // Put a Noise channel on top of the KCP conn. + rw, err := noise.NewServer(conn, privkey) + if err != nil { + return err + } + + // Put an smux session on top of the encrypted Noise channel. + smuxConfig := smux.DefaultConfig() + smuxConfig.Version = 2 + smuxConfig.KeepAliveTimeout = idleTimeout + smuxConfig.MaxStreamBuffer = 1 * 1024 * 1024 + sess, err := smux.Server(rw, smuxConfig) + if err != nil { + return err + } + defer sess.Close() + + for { + stream, err := sess.AcceptStream() + if err != nil { + if err, ok := err.(net.Error); ok && err.Temporary() { + continue + } + return err + } + // Log the creation of each new smux stream. Reporting the conv helps + // to correlate streams with their parent KCP session. + atomic.AddUint64(&dnsttStats.SmuxStreamsNew, 1) + dnsttLog.Printf("new smux stream: conv=%d", conn.GetConv()) + // For each new smux stream, hand it off to our SSH handler. + go func(s *smux.Stream, conv uint32) { + defer s.Close() + _ = handleDNSTTStream(s, conv) + atomic.AddUint64(&dnsttStats.SmuxStreamsEnd, 1) + dnsttLog.Printf("smux stream closed: conv=%d", conv) + }(stream, conn.GetConv()) + } +} + +// acceptDNSTTSessions listens for incoming KCP connections and passes them to +// acceptDNSTTStreams. It configures window sizes and MTU on each accepted +// session as in the original dnstt-server. +func acceptDNSTTSessions(ln *kcp.Listener, privkey []byte, mtu int) error { + for { + conn, err := ln.AcceptKCP() + if err != nil { + if err, ok := err.(net.Error); ok && err.Temporary() { + continue + } + return err + } + from := conn.RemoteAddr().String() + // Choose a per-session MTU derived from the inferred/advertised UDP payload + // capability for this client. This is essential on paths that clamp UDP + // payloads below our global cap (e.g. ~930), because KCP packets larger than + // what can fit in a single DNS response will cause the tunnel to stall. + effectiveLimit := maxUDPPayload + if v, ok := dnsttClientPayloadCap.Load(from); ok { + effectiveLimit = v.(clientCapEntry).Cap + if effectiveLimit < 512 { + effectiveLimit = 512 + } + if effectiveLimit > maxUDPPayload { + effectiveLimit = maxUDPPayload + } + } + maxEnc := cachedMaxEncodedPayload(effectiveLimit) + mtuSession := mtu + if maxEnc > 0 { + if m := maxEnc - 2; m >= 80 { + mtuSession = m + } + } + // Log each newly accepted KCP session. Include conversation ID, remote address + // (ClientID), and the chosen MTU. + atomic.AddUint64(&dnsttStats.KCPSessionsNew, 1) + dnsttLog.Printf("new KCP session: conv=%d from=%s mtu=%d limit=%d", conn.GetConv(), from, mtuSession, effectiveLimit) + // Permit coalescing the payloads of consecutive sends. + conn.SetStreamMode(true) + // Disable the dynamic congestion window (limit only by the maximum of + // local and remote static windows). + conn.SetNoDelay(0, 0, 0, 1) + conn.SetWindowSize(turbotunnel.QueueSize/2, turbotunnel.QueueSize/2) + if rc := conn.SetMtu(mtuSession); !rc { + panic(rc) + } + go func(c *kcp.UDPSession, conv uint32, from string) { + defer c.Close() + err := acceptDNSTTStreams(c, privkey) + atomic.AddUint64(&dnsttStats.KCPSessionsEnd, 1) + if err != nil && err != io.ErrClosedPipe { + dnsttLog.Printf("kcp session closed: conv=%d from=%s err=%v", conv, from, err) + } else { + dnsttLog.Printf("kcp session closed: conv=%d from=%s", conv, from) + } + }(conn, conn.GetConv(), conn.RemoteAddr().String()) + } +} + +// record represents a DNS message appropriate for a response to a previously +// received query, along with metadata necessary for sending the response. +// recvLoop sends instances of record to sendLoop via a channel. sendLoop +// receives instances of record and may fill in the message's Answer section +// before sending it. +type record struct { + Resp *dns.Message + Addr net.Addr + ClientID turbotunnel.ClientID + // PayloadLimit holds the maximum UDP payload size advertised by the + // client via EDNS(0). sendLoop uses this to clamp outgoing DNS + // responses so they never exceed what the client claims it will + // accept. A zero value means no per‑client limit and defaults to + // maxUDPPayload. + PayloadLimit int +} + +// nextPacket reads the next length‑prefixed packet from r, ignoring padding. +// It returns a nil error only when a packet was read successfully. It +// returns io.EOF only when there were 0 bytes remaining to read from r. It +// returns io.ErrUnexpectedEOF when EOF occurs in the middle of an encoded +// packet. See dnstt-server/main.go for details. +func nextPacket(r *bytes.Reader) ([]byte, error) { + eof := func(err error) error { + if err == io.EOF { + err = io.ErrUnexpectedEOF + } + return err + } + for { + prefix, err := r.ReadByte() + if err != nil { + // We may return a real io.EOF only here. + return nil, err + } + if prefix >= 224 { + paddingLen := prefix - 224 + _, err := io.CopyN(io.Discard, r, int64(paddingLen)) + if err != nil { + return nil, eof(err) + } + } else { + p := make([]byte, int(prefix)) + _, err = io.ReadFull(r, p) + return p, eof(err) + } + } +} + +// responseFor constructs a response dns.Message that is appropriate for query. +// Along with the dns.Message, it returns the query's decoded data payload. If +// the returned dns.Message is nil, it means that there should be no response +// to this query. If the returned dns.Message has an Rcode() of +// dns.RcodeNoError, the message is a candidate for carrying downstream data +// in a TXT record. This function is adapted from dnstt-server/main.go. +func responseFor(query *dns.Message, domain dns.Name) (*dns.Message, []byte) { + resp := &dns.Message{ + ID: query.ID, + Flags: 0x8000, // QR = 1, RCODE = no error + Question: query.Question, + } + if query.Flags&0x8000 != 0 { + // QR != 0, this is not a query. Don't even send a response. + return nil, nil + } + // Check for EDNS(0) support. Include our own OPT RR only if we receive + // one from the requester. + payloadSize := 0 + for _, rr := range query.Additional { + if rr.Type != dns.RRTypeOPT { + continue + } + version := (rr.TTL >> 16) & 0xff + if version != 0 { + resp.Flags |= dns.ExtendedRcodeBadVers & 0xf + additional := dns.RR{ + Name: dns.Name{}, + Type: dns.RRTypeOPT, + Class: 0, + TTL: (dns.ExtendedRcodeBadVers >> 4) << 24, + Data: []byte{}, + } + resp.Additional = append(resp.Additional, additional) + return resp, nil + } + payloadSize = int(rr.Class) + } + if payloadSize < 512 { + payloadSize = 512 + } + // There must be exactly one question. + if len(query.Question) != 1 { + resp.Flags |= dns.RcodeFormatError + dnsttLog.Printf("FORMERR: too few or too many questions (%d)", len(query.Question)) + return resp, nil + } + question := query.Question[0] + // Check the name to see if it ends in our chosen domain, and extract + // all that comes before the domain if it does. If it does not, we + // return RcodeNameError. + prefix, ok := question.Name.TrimSuffix(domain) + if !ok { + resp.Flags |= dns.RcodeNameError + // NXDOMAIN: not authoritative for this name + return resp, nil + } + resp.Flags |= 0x0400 // AA = 1 + if query.Opcode() != 0 { + resp.Flags |= dns.RcodeNotImplemented + return resp, nil + } + if question.Type != dns.RRTypeTXT { + // We only support QTYPE == TXT. + resp.Flags |= dns.RcodeNameError + return resp, nil + } + encoded := bytes.ToUpper(bytes.Join(prefix, nil)) + payload := make([]byte, base32Encoding.DecodedLen(len(encoded))) + n, err := base32Encoding.Decode(payload, encoded) + if err != nil { + resp.Flags |= dns.RcodeNameError + return resp, nil + } + payload = payload[:n] + // Do not reject queries advertising a smaller EDNS UDP payload than + // maxUDPPayload. We clamp responses to the client‑advertised size + // later in sendLoop. + return resp, payload +} + +// recvLoop repeatedly calls dnsConn.ReadFrom, extracts the packets contained in +// the incoming DNS queries, and puts them on ttConn's incoming queue. +// Whenever a query calls for a response, constructs a partial response and +// passes it to sendLoop over ch. +func recvLoop(domain dns.Name, dnsConn net.PacketConn, ttConn *turbotunnel.QueuePacketConn, ch chan<- *record) error { + for { + var buf [4096]byte + n, addr, err := dnsConn.ReadFrom(buf[:]) + if err != nil { + if err, ok := err.(net.Error); ok && err.Temporary() { + dnsttLog.Printf("ReadFrom temporary error: %v", err) + continue + } + return err + } + atomic.AddUint64(&dnsttStats.DNSRx, 1) + // Parse DNS query. + query, err := dns.MessageFromWireFormat(buf[:n]) + if err != nil { + atomic.AddUint64(&dnsttStats.DNSParseErr, 1) + dnsttLog.Printf("cannot parse DNS query: %v", err) + continue + } + // Determine an effective UDP payload limit for this query. + // + // Prefer EDNS(0) if present. However, many resolver paths strip EDNS while + // still allowing UDP responses larger than 512. For those paths, infer a + // practical limit from the observed query size (if we received a ~900-byte + // query, the path clearly supports >512 UDP). Clamp to our global cap. + payloadLimit := 0 + hasEDNS := false + for _, rr := range query.Additional { + if rr.Type != dns.RRTypeOPT { + continue + } + hasEDNS = true + // The lower 16 bits of the Class field of the OPT RR specify the + // requestor's maximum UDP payload size (RFC 6891). + sz := int(rr.Class) + if sz < 512 { + sz = 512 + } + payloadLimit = sz + break + } + if !hasEDNS { + atomic.AddUint64(&dnsttStats.NoEDNS, 1) + payloadLimit = n + if payloadLimit < 512 { + payloadLimit = 512 + } + } else { + // If EDNS is present but appears to under-advertise the true path MTU, + // treat the observed query size as a lower bound. Some resolvers clamp or + // rewrite EDNS values even though they can carry larger UDP payloads. + if n > payloadLimit { + payloadLimit = n + } + } + // Many resolver paths do not include EDNS (or clamp it to 512) but still + // reliably carry larger UDP DNS messages (commonly ~900 bytes). For DNSTT + // this matters because downstream data rides in responses; treating such + // paths as strictly 512-byte causes the tunnel to stall. + // + // noEDNSFallbackPayload acts as a floor for these cases; if you want strict + // RFC 1035 behaviour, set it to 512. + if payloadLimit < noEDNSFallbackPayload { + payloadLimit = noEDNSFallbackPayload + } + if payloadLimit > maxUDPPayload { + payloadLimit = maxUDPPayload + } + if payloadLimit == 512 { + atomic.AddUint64(&dnsttStats.SmallEDNS, 1) + } + resp, payload := responseFor(&query, domain) + // Extract ClientID + var clientID turbotunnel.ClientID + n = copy(clientID[:], payload) + payload = payload[n:] + if n == len(clientID) { + // Update our per-client capability estimate so we can choose a suitable + // per-session KCP MTU even when EDNS is stripped. + updateClientPayloadCap(clientID, payloadLimit) + // Feed packets into KCP. + r := bytes.NewReader(payload) + for { + p, err := nextPacket(r) + if err != nil { + break + } + ttConn.QueueIncoming(p, clientID) + } + } else { + if resp != nil && resp.Rcode() == dns.RcodeNoError { + resp.Flags |= dns.RcodeNameError + } + } + if resp != nil { + // Push this record along with the per‑client payload limit to the sender. + rec := &record{ + Resp: resp, + Addr: addr, + ClientID: clientID, + PayloadLimit: payloadLimit, + } + select { + case ch <- rec: + atomic.AddUint64(&dnsttStats.RecQueued, 1) + default: + d := atomic.AddUint64(&dnsttStats.RecDropped, 1) + // Log occasionally to avoid flooding logs under sustained overload. + if d == 1 || d%1000 == 0 { + dnsttLog.Printf("dropping response record: ch_len=%d ch_cap=%d dropped=%d", len(ch), cap(ch), d) + } + } + } + } +} + +// sendLoop repeatedly receives records from ch. Those that represent an +// error response are sent immediately. Those that represent a response +// capable of carrying data are packed full of as many packets as will fit +// while keeping the total size under maxEncodedPayload, then sent. +func sendLoop(dnsConn net.PacketConn, ttConn *turbotunnel.QueuePacketConn, ch <-chan *record, maxEncodedPayload int) error { + var nextRec *record + for { + rec := nextRec + nextRec = nil + if rec == nil { + var ok bool + rec, ok = <-ch + if !ok { + break + } + } + // Determine the effective per-query UDP payload limit. + // + // rec.PayloadLimit comes from EDNS(0) and our heuristic floor/observations. + // Additionally, we maintain a per-client learned capability (dnsttClientPayloadCap) + // that can be promoted when we observe that larger downstream responses are + // needed and appear to be supported. Use the larger of the two, then clamp + // to our global cap. + effectivePayloadLimit := rec.PayloadLimit + if v, ok := dnsttClientPayloadCap.Load(dnsttClientKey(rec.ClientID)); ok { + if cap := v.(clientCapEntry).Cap; cap > effectivePayloadLimit { + effectivePayloadLimit = cap + } + } + if effectivePayloadLimit <= 0 { + effectivePayloadLimit = maxUDPPayload + } + if effectivePayloadLimit > maxUDPPayload { + effectivePayloadLimit = maxUDPPayload + } + // Compute a per-limit "max encoded payload" so we avoid generating a DNS + // response that would need truncation (TC). Truncation breaks DNSTT because + // resolvers retry over TCP, which we don't support here. + maxEnc := cachedMaxEncodedPayload(effectivePayloadLimit) + if maxEnc < 0 { + maxEnc = 0 + } + // Note: KCP MTU is set per session in acceptDNSTTSessions based on inferred + // per-client capability, so we do not warn here about MTU mismatches. + + if rec.Resp.Rcode() == dns.RcodeNoError && len(rec.Resp.Question) == 1 { + // It's a non-error response, so we can fill the Answer section. + rec.Resp.Answer = []dns.RR{{ + Name: rec.Resp.Question[0].Name, + Type: rec.Resp.Question[0].Type, + Class: rec.Resp.Question[0].Class, + TTL: responseTTL, + Data: nil, + }} + var payload bytes.Buffer + limit := maxEnc + timer := time.NewTimer(maxResponseDelay) + packets := 0 + for { + var p []byte + unstash := ttConn.Unstash(rec.ClientID) + outgoing := ttConn.OutgoingQueue(rec.ClientID) + select { + case p = <-unstash: + default: + select { + case p = <-unstash: + case p = <-outgoing: + default: + select { + case p = <-unstash: + case p = <-outgoing: + case <-timer.C: + case nextRec = <-ch: + } + } + } + timer.Reset(0) + if len(p) == 0 { + break + } + limit -= 2 + len(p) + // Never exceed the per-query maximum encoded payload. If the next packet + // would overflow, stash it for the next response (even if this would have + // been the first packet). + if limit < 0 { + ttConn.Stash(p, rec.ClientID) + break + } + binary.Write(&payload, binary.BigEndian, uint16(len(p))) + payload.Write(p) + packets++ + } + timer.Stop() + rec.Resp.Answer[0].Data = dns.EncodeRDataTXT(payload.Bytes()) + if packets == 0 { + atomic.AddUint64(&dnsttStats.RespEmpty, 1) + } else { + atomic.AddUint64(&dnsttStats.RespWithData, 1) + } + } + buf, err := rec.Resp.WireFormat() + if err != nil { + dnsttLog.Printf("resp WireFormat: %v", err) + continue + } + if len(buf) > effectivePayloadLimit { + // The resolver may under-advertise (or omit) its UDP payload limit while + // still accepting larger UDP responses. Rather than dropping downstream + // data (which can stall DNSTT), treat this as a hint and promote the + // per-client inferred payload capability. + atomic.AddUint64(&dnsttStats.RespOversize, 1) + promoteTo := len(buf) + if promoteTo > maxUDPPayload { + dnsttLog.Printf("oversize DNS response: size=%d limit=%d cap=%d (dropping)", len(buf), effectivePayloadLimit, maxUDPPayload) + continue + } + updateClientPayloadCap(rec.ClientID, promoteTo) + dnsttLog.Printf("oversize DNS response: size=%d limit=%d -> promote client to %d", len(buf), effectivePayloadLimit, promoteTo) + // After promotion, allow this response to be sent. + effectivePayloadLimit = promoteTo + } + _, err = dnsConn.WriteTo(buf, rec.Addr) + if err != nil { + if err, ok := err.(net.Error); ok && err.Temporary() { + dnsttLog.Printf("WriteTo temporary error: %v", err) + continue + } + return err + } + atomic.AddUint64(&dnsttStats.RespSent, 1) + atomic.AddUint64(&dnsttStats.RespSentBytes, uint64(len(buf))) + } + return nil +} + +// computeMaxEncodedPayload computes the maximum amount of downstream TXT RR +// data that keep the overall response size less than maxUDPPayload, in the +// worst case when the response answers a query that has a maximum-length name +// in its Question section. Returns 0 in the case that no amount of data +// makes the overall response size small enough. +func computeMaxEncodedPayload(limit int) int { + maxLengthName, err := dns.NewName([][]byte{ + []byte("AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA"), + []byte("AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA"), + []byte("AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA"), + []byte("AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA"), + }) + if err != nil { + panic(err) + } + { + n := 0 + for _, label := range maxLengthName { + n += len(label) + 1 + } + n += 1 + if n != 255 { + panic(fmt.Sprintf("dnstt: max-length name is %d octets, should be %d", n, 255)) + } + } + queryLimit := uint16(limit) + if int(queryLimit) != limit { + queryLimit = 0xffff + } + query := &dns.Message{ + Question: []dns.Question{{ + Name: maxLengthName, + Type: dns.RRTypeTXT, + Class: dns.RRTypeTXT, + }}, + Additional: []dns.RR{{ + Name: dns.Name{}, + Type: dns.RRTypeOPT, + Class: uint16(queryLimit), + TTL: 0, + Data: []byte{}, + }}, + } + resp, _ := responseFor(query, dns.Name([][]byte{})) + resp.Answer = []dns.RR{{ + Name: query.Question[0].Name, + Type: query.Question[0].Type, + Class: query.Question[0].Class, + TTL: responseTTL, + Data: nil, + }} + low := 0 + high := 32768 + for low+1 < high { + mid := (low + high) / 2 + resp.Answer[0].Data = dns.EncodeRDataTXT(make([]byte, mid)) + buf, err := resp.WireFormat() + if err != nil { + panic(err) + } + if len(buf) <= limit { + low = mid + } else { + high = mid + } + } + return low +} + +// runDNSTT starts a dnstt server on udpListen. It computes the effective +// MTU based on the configured maxUDPPayload, then accepts KCP sessions and +// handles DNS queries. Errors are returned only for fatal conditions. +func runDNSTT(privkey []byte, domain dns.Name, udpListen string) error { + dnsConn, err := net.ListenPacket("udp", udpListen) + if err != nil { + return fmt.Errorf("dnstt: opening UDP listener on %s: %v", udpListen, err) + } + if udp, ok := dnsConn.(*net.UDPConn); ok { + _ = udp.SetReadBuffer(4 * 1024 * 1024) + _ = udp.SetWriteBuffer(4 * 1024 * 1024) + } + + // Register so stopDNSTT() can close this socket and unblock the read loop. + dnsttConnMu.Lock() + if dnsttConn != nil { + _ = dnsttConn.Close() + } + dnsttConn = dnsConn + dnsttConnMu.Unlock() + // Log readiness of the UDP listener. + dnsttLog.Printf("udp listener ready on %s", udpListen) + // compute maximum encoded payload and resulting MTU + maxEncodedPayload := computeMaxEncodedPayload(maxUDPPayload) + mtu := maxEncodedPayload - 2 + if mtu < 80 { + if mtu < 0 { + mtu = 0 + } + return fmt.Errorf("dnstt: computed MTU %d too small", mtu) + } + // set up turbotunnel and KCP listener + ttConn := turbotunnel.NewQueuePacketConn(turbotunnel.DummyAddr{}, idleTimeout*2) + ln, err := kcp.ServeConn(nil, 0, 0, ttConn) + if err != nil { + return fmt.Errorf("dnstt: opening KCP listener: %v", err) + } + go func() { + if err := acceptDNSTTSessions(ln, privkey, mtu); err != nil { + dnsttLog.Printf("acceptSessions error: %v", err) + } + }() + // NOTE: This channel buffers pending DNS response records. Keeping this + // extremely large can look like a "memory leak" under bursty load because + // records (and their associated allocations) are retained until drained. + // A moderate size provides smoothing while still applying backpressure. + ch := make(chan *record, 20000) + // Periodically aggregate DNSTT counters. This goroutine runs every 5 seconds, + // resetting the atomic counters, storing them in lastDnsttStats and optionally + // emitting a log line. Even when dnsttPrintStats is false, statistics will + // still be collected and made available via the admin API. + go func() { + t := time.NewTicker(5 * time.Second) + defer t.Stop() + for range t.C { + dnsRx := atomic.SwapUint64(&dnsttStats.DNSRx, 0) + parseErr := atomic.SwapUint64(&dnsttStats.DNSParseErr, 0) + noEDNS := atomic.SwapUint64(&dnsttStats.NoEDNS, 0) + limit512 := atomic.SwapUint64(&dnsttStats.SmallEDNS, 0) + queued := atomic.SwapUint64(&dnsttStats.RecQueued, 0) + dropped := atomic.SwapUint64(&dnsttStats.RecDropped, 0) + respSent := atomic.SwapUint64(&dnsttStats.RespSent, 0) + respBytes := atomic.SwapUint64(&dnsttStats.RespSentBytes, 0) + respEmpty := atomic.SwapUint64(&dnsttStats.RespEmpty, 0) + respData := atomic.SwapUint64(&dnsttStats.RespWithData, 0) + over := atomic.SwapUint64(&dnsttStats.RespOversize, 0) + kcpNew := atomic.SwapUint64(&dnsttStats.KCPSessionsNew, 0) + kcpEnd := atomic.SwapUint64(&dnsttStats.KCPSessionsEnd, 0) + smuxNew := atomic.SwapUint64(&dnsttStats.SmuxStreamsNew, 0) + smuxEnd := atomic.SwapUint64(&dnsttStats.SmuxStreamsEnd, 0) + // Update the snapshot + dnsttStatsMu.Lock() + lastDnsttStats = DnsttStatsSnapshot{ + Timestamp: time.Now(), + DNSRx: dnsRx, + ParseErr: parseErr, + NoEDNS: noEDNS, + Limit512: limit512, + RecQueued: queued, + RecDropped: dropped, + RespSent: respSent, + RespBytes: respBytes, + RespEmpty: respEmpty, + RespData: respData, + RespOversize: over, + KCPNew: kcpNew, + KCPEnd: kcpEnd, + SmuxNew: smuxNew, + SmuxEnd: smuxEnd, + ChLen: len(ch), + } + dnsttStatsMu.Unlock() + // Optionally log the snapshot to stderr + if dnsttPrintStats { + dnsttLog.Printf( + "stats 5s: dns_rx=%d parse_err=%d no_edns=%d limit512=%d rec_queued=%d rec_dropped=%d resp_sent=%d resp_bytes=%d resp_empty=%d resp_data=%d resp_oversize=%d kcp_new=%d kcp_end=%d smux_new=%d smux_end=%d ch_len=%d", + dnsRx, parseErr, noEDNS, limit512, queued, dropped, respSent, respBytes, respEmpty, respData, over, kcpNew, kcpEnd, smuxNew, smuxEnd, len(ch), + ) + } + } + }() + go func() { + if err := sendLoop(dnsConn, ttConn, ch, maxEncodedPayload); err != nil { + dnsttLog.Printf("sendLoop error: %v", err) + } + }() + return recvLoop(domain, dnsConn, ttConn, ch) +} + +// ---- Key management API handlers ---- + +const dnsttKeyFile = "/opt/sshpanel/dnstt.key" + +// handleDnsttGenKey generates a new Noise keypair, saves the private key to +// dnsttKeyFile, and returns the hex-encoded public key. +func handleDnsttGenKey(w http.ResponseWriter, r *http.Request) { + if r.Method != http.MethodPost { + w.WriteHeader(http.StatusMethodNotAllowed) + return + } + privkey, err := noise.GeneratePrivkey() + if err != nil { + http.Error(w, "keygen: "+err.Error(), http.StatusInternalServerError) + return + } + f, err := os.OpenFile(dnsttKeyFile, os.O_WRONLY|os.O_CREATE|os.O_TRUNC, 0o600) + if err != nil { + http.Error(w, "write key: "+err.Error(), http.StatusInternalServerError) + return + } + if err := noise.WriteKey(f, privkey); err != nil { + f.Close() + http.Error(w, "write key: "+err.Error(), http.StatusInternalServerError) + return + } + f.Close() + pubkey := noise.PubkeyFromPrivkey(privkey) + w.Header().Set("Content-Type", "application/json") + _ = json.NewEncoder(w).Encode(map[string]string{ + "privkey_file": dnsttKeyFile, + "pubkey": noise.EncodeKey(pubkey), + }) +} + +// handleDnsttGetPubKey reads the configured private key and returns the +// corresponding public key so the admin can share it with dnstt clients. +func handleDnsttGetPubKey(w http.ResponseWriter, r *http.Request) { + if r.Method != http.MethodGet { + w.WriteHeader(http.StatusMethodNotAllowed) + return + } + cfg := getGlobalCfg() + keyPath := dnsttKeyFile + if cfg != nil && cfg.DNSTT != nil && cfg.DNSTT.PrivKeyFile != "" { + keyPath = cfg.DNSTT.PrivKeyFile + } + f, err := os.Open(keyPath) + if err != nil { + http.Error(w, "open key: "+err.Error(), http.StatusInternalServerError) + return + } + defer f.Close() + privkey, err := noise.ReadKey(f) + if err != nil { + http.Error(w, "read key: "+err.Error(), http.StatusInternalServerError) + return + } + pubkey := noise.PubkeyFromPrivkey(privkey) + w.Header().Set("Content-Type", "application/json") + _ = json.NewEncoder(w).Encode(map[string]string{ + "pubkey": noise.EncodeKey(pubkey), + }) +} diff --git a/go.mod b/go.mod new file mode 100644 index 0000000..204406e --- /dev/null +++ b/go.mod @@ -0,0 +1,20 @@ +module shell2 + +go 1.25.4 + +require golang.org/x/crypto v0.45.0 + +require ( + github.com/flynn/noise v1.0.0 // indirect + github.com/klauspost/cpuid/v2 v2.2.6 // indirect + github.com/klauspost/reedsolomon v1.12.0 // indirect + github.com/lib/pq v1.10.9 // indirect + github.com/pkg/errors v0.9.1 // indirect + github.com/tjfoc/gmsm v1.4.1 // indirect + github.com/xtaci/kcp-go/v5 v5.6.61 // indirect + github.com/xtaci/smux v1.5.50 // indirect + golang.org/x/net v0.47.0 // indirect + golang.org/x/sys v0.38.0 // indirect + golang.org/x/time v0.14.0 // indirect + www.bamsoftware.com/git/dnstt.git v1.20241021.0 // indirect +) diff --git a/go.sum b/go.sum new file mode 100644 index 0000000..abce5e9 --- /dev/null +++ b/go.sum @@ -0,0 +1,104 @@ +cloud.google.com/go v0.26.0/go.mod h1:aQUYkXzVsufM+DwF1aE+0xfcU+56JwCaLick0ClmMTw= +github.com/BurntSushi/toml v0.3.1/go.mod h1:xHWCNGjB5oqiDr8zfno3MHue2Ht5sIBksp03qcyfWMU= +github.com/census-instrumentation/opencensus-proto v0.2.1/go.mod h1:f6KPmirojxKA12rnyqOA5BBL4O983OfeGPqjHWSTneU= +github.com/client9/misspell v0.3.4/go.mod h1:qj6jICC3Q7zFZvVWo7KLAzC3yx5G7kyvSDkc90ppPyw= +github.com/cncf/udpa/go v0.0.0-20191209042840-269d4d468f6f/go.mod h1:M8M6+tZqaGXZJjfX53e64911xZQV5JYwmTeXPW+k8Sc= +github.com/envoyproxy/go-control-plane v0.9.0/go.mod h1:YTl/9mNaCwkRvm6d1a2C3ymFceY/DCBVvsKhRF0iEA4= +github.com/envoyproxy/go-control-plane v0.9.4/go.mod h1:6rpuAdCZL397s3pYoYcLgu1mIlRU8Am5FuJP05cCM98= +github.com/envoyproxy/protoc-gen-validate v0.1.0/go.mod h1:iSmxcyjqTsJpI2R4NaDN7+kN2VEUnK/pcBlmesArF7c= +github.com/flynn/noise v1.0.0 h1:DlTHqmzmvcEiKj+4RYo/imoswx/4r6iBlCMfVtrMXpQ= +github.com/flynn/noise v1.0.0/go.mod h1:xbMo+0i6+IGbYdJhF31t2eR1BIU0CYc12+BNAKwUTag= +github.com/golang/glog v0.0.0-20160126235308-23def4e6c14b/go.mod h1:SBH7ygxi8pfUlaOkMMuAQtPIUF8ecWP5IEl/CR7VP2Q= +github.com/golang/mock v1.1.1/go.mod h1:oTYuIxOrZwtPieC+H1uAHpcLFnEyAGVDL/k47Jfbm0A= +github.com/golang/protobuf v1.2.0/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U= +github.com/golang/protobuf v1.3.2/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U= +github.com/golang/protobuf v1.3.3/go.mod h1:vzj43D7+SQXF/4pzW/hwtAqwc6iTitCiVSaWz5lYuqw= +github.com/golang/protobuf v1.4.0-rc.1/go.mod h1:ceaxUfeHdC40wWswd/P6IGgMaK3YpKi5j83Wpe3EHw8= +github.com/golang/protobuf v1.4.0-rc.1.0.20200221234624-67d41d38c208/go.mod h1:xKAWHe0F5eneWXFV3EuXVDTCmh+JuBKY0li0aMyXATA= +github.com/golang/protobuf v1.4.0-rc.2/go.mod h1:LlEzMj4AhA7rCAGe4KMBDvJI+AwstrUpVNzEA03Pprs= +github.com/golang/protobuf v1.4.0-rc.4.0.20200313231945-b860323f09d0/go.mod h1:WU3c8KckQ9AFe+yFwt9sWVRKCVIyN9cPHBJSNnbL67w= +github.com/golang/protobuf v1.4.0/go.mod h1:jodUvKwWbYaEsadDk5Fwe5c77LiNKVO9IDvqG2KuDX0= +github.com/golang/protobuf v1.4.2/go.mod h1:oDoupMAO8OvCJWAcko0GGGIgR6R6ocIYbsSw735rRwI= +github.com/google/go-cmp v0.2.0/go.mod h1:oXzfMopK8JAjlY9xF4vHSVASa0yLyX7SntLO5aqRK0M= +github.com/google/go-cmp v0.3.0/go.mod h1:8QqcDgzrUqlUb/G2PQTWiueGozuR1884gddMywk6iLU= +github.com/google/go-cmp v0.3.1/go.mod h1:8QqcDgzrUqlUb/G2PQTWiueGozuR1884gddMywk6iLU= +github.com/google/go-cmp v0.4.0/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= +github.com/klauspost/cpuid/v2 v2.2.6 h1:ndNyv040zDGIDh8thGkXYjnFtiN02M1PVVF+JE/48xc= +github.com/klauspost/cpuid/v2 v2.2.6/go.mod h1:Lcz8mBdAVJIBVzewtcLocK12l3Y+JytZYpaMropDUws= +github.com/klauspost/reedsolomon v1.12.0 h1:I5FEp3xSwVCcEh3F5A7dofEfhXdF/bWhQWPH+XwBFno= +github.com/klauspost/reedsolomon v1.12.0/go.mod h1:EPLZJeh4l27pUGC3aXOjheaoh1I9yut7xTURiW3LQ9Y= +github.com/kr/pretty v0.2.1/go.mod h1:ipq/a2n7PKx3OHsz4KJII5eveXtPO4qwEXGdVfWzfnI= +github.com/kr/pty v1.1.1/go.mod h1:pFQYn66WHrOpPYNljwOMqo10TkYh1fy3cYio2l3bCsQ= +github.com/kr/text v0.1.0/go.mod h1:4Jbv+DJW3UT/LiOwJeYQe1efqtUx/iVham/4vfdArNI= +github.com/lib/pq v1.10.9 h1:YXG7RB+JIjhP29X+OtkiDnYaXQwpS4JEWq7dtCCRUEw= +github.com/lib/pq v1.10.9/go.mod h1:AlVN5x4E4T544tWzH6hKfbfQvm3HdbOxrmggDNAPY9o= +github.com/pkg/errors v0.9.1 h1:FEBLx1zS214owpjy7qsBeixbURkuhQAwrK5UwLGTwt4= +github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= +github.com/prometheus/client_model v0.0.0-20190812154241-14fe0d1b01d4/go.mod h1:xMI15A0UPsDsEKsMN9yxemIoYk6Tm2C1GtYGdfGttqA= +github.com/tjfoc/gmsm v1.4.1 h1:aMe1GlZb+0bLjn+cKTPEvvn9oUEBlJitaZiiBwsbgho= +github.com/tjfoc/gmsm v1.4.1/go.mod h1:j4INPkHWMrhJb38G+J6W4Tw0AbuN8Thu3PbdVYhVcTE= +github.com/xtaci/kcp-go/v5 v5.6.61 h1:ajm12pGuWO+GWQNusPyPESC7Rq0yTC2rEXVYkM8ExOg= +github.com/xtaci/kcp-go/v5 v5.6.61/go.mod h1:9O3D8WR+cyyUjGiTILYfg17vn72otWuXK2AFfqIe6CM= +github.com/xtaci/smux v1.5.50 h1:y/1DlWQC9bnMeZzsyk4oL2hbLK6uVk4BKTz5BeQqUEA= +github.com/xtaci/smux v1.5.50/go.mod h1:IGQ9QYrBphmb/4aTnLEcJby0TNr3NV+OslIOMrX825Q= +golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= +golang.org/x/crypto v0.0.0-20200622213623-75b288015ac9/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto= +golang.org/x/crypto v0.0.0-20201012173705-84dcc777aaee/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto= +golang.org/x/crypto v0.0.0-20210322153248-0c34fe9e7dc2/go.mod h1:T9bdIzuCu7OtxOm1hfPfRQxPLYneinmdGuTeoZ9dtd4= +golang.org/x/crypto v0.45.0 h1:jMBrvKuj23MTlT0bQEOBcAE0mjg8mK9RXFhRH6nyF3Q= +golang.org/x/crypto v0.45.0/go.mod h1:XTGrrkGJve7CYK7J8PEww4aY7gM3qMCElcJQ8n8JdX4= +golang.org/x/exp v0.0.0-20190121172915-509febef88a4/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA= +golang.org/x/lint v0.0.0-20181026193005-c67002cb31c3/go.mod h1:UVdnD1Gm6xHRNCYTkRU2/jEulfH38KcIWyp/GAMgvoE= +golang.org/x/lint v0.0.0-20190227174305-5b3e6a55c961/go.mod h1:wehouNa3lNwaWXcvxsM5YxQ5yQlVC4a0KAMCusXpPoU= +golang.org/x/lint v0.0.0-20190313153728-d0100b6bd8b3/go.mod h1:6SW0HCj/g11FgYtHlgUYUwCkIfeOF89ocIRzGO/8vkc= +golang.org/x/net v0.0.0-20180724234803-3673e40ba225/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= +golang.org/x/net v0.0.0-20180826012351-8a410e7b638d/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= +golang.org/x/net v0.0.0-20190213061140-3a22650c66bd/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= +golang.org/x/net v0.0.0-20190311183353-d8887717615a/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg= +golang.org/x/net v0.0.0-20190404232315-eb5bcb51f2a3/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg= +golang.org/x/net v0.0.0-20201010224723-4f7140c49acb/go.mod h1:sp8m0HH+o8qH0wwXwYZr8TS3Oi6o0r6Gce1SSxlDquU= +golang.org/x/net v0.0.0-20210226172049-e18ecbb05110/go.mod h1:m0MpNAwzfU5UDzcl9v0D8zg8gWTRqZa9RBIspLL5mdg= +golang.org/x/net v0.47.0 h1:Mx+4dIFzqraBXUugkia1OOvlD6LemFo1ALMHjrXDOhY= +golang.org/x/net v0.47.0/go.mod h1:/jNxtkgq5yWUGYkaZGqo27cfGZ1c5Nen03aYrrKpVRU= +golang.org/x/oauth2 v0.0.0-20180821212333-d2e6202438be/go.mod h1:N/0e6XlmueqKjAGxoOufVs8QHGRruUQn6yWY3a++T0U= +golang.org/x/sync v0.0.0-20180314180146-1d60e4601c6f/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= +golang.org/x/sync v0.0.0-20181108010431-42b317875d0f/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= +golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= +golang.org/x/sys v0.0.0-20180830151530-49385e6e1522/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= +golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= +golang.org/x/sys v0.0.0-20190412213103-97732733099d/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20200930185726-fdedc70b468f/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.5.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.38.0 h1:3yZWxaJjBmCWXqhN1qh02AkOnCQ1poK6oF+a7xWL6Gc= +golang.org/x/sys v0.38.0/go.mod h1:OgkHotnGiDImocRcuBABYBEXf8A9a87e/uXjp9XT3ks= +golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo= +golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= +golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= +golang.org/x/time v0.14.0 h1:MRx4UaLrDotUKUdCIqzPC48t1Y9hANFKIRpNx+Te8PI= +golang.org/x/time v0.14.0/go.mod h1:eL/Oa2bBBK0TkX57Fyni+NgnyQQN4LitPmob2Hjnqw4= +golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= +golang.org/x/tools v0.0.0-20190114222345-bf090417da8b/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= +golang.org/x/tools v0.0.0-20190226205152-f727befe758c/go.mod h1:9Yl7xja0Znq3iFh3HoIrodX9oNMXvdceNzlUR8zjMvY= +golang.org/x/tools v0.0.0-20190311212946-11955173bddd/go.mod h1:LCzVGOaR6xXOjkQ3onu1FJEFr0SW1gC7cKk1uF8kGRs= +golang.org/x/tools v0.0.0-20190524140312-2c0ae7006135/go.mod h1:RgjU9mgBXZiqYHBnxXauZ1Gv1EHHAz9KjViQ78xBX0Q= +golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= +google.golang.org/appengine v1.1.0/go.mod h1:EbEs0AVv82hx2wNQdGPgUI5lhzA/G0D9YwlJXL52JkM= +google.golang.org/appengine v1.4.0/go.mod h1:xpcJRLb0r/rnEns0DIKYYv+WjYCduHsrkT7/EB5XEv4= +google.golang.org/genproto v0.0.0-20180817151627-c66870c02cf8/go.mod h1:JiN7NxoALGmiZfu7CAH4rXhgtRTLTxftemlI0sWmxmc= +google.golang.org/genproto v0.0.0-20190819201941-24fa4b261c55/go.mod h1:DMBHOl98Agz4BDEuKkezgsaosCRResVns1a3J2ZsMNc= +google.golang.org/grpc v1.19.0/go.mod h1:mqu4LbDTu4XGKhr4mRzUsmM4RtVoemTSY81AxZiDr8c= +google.golang.org/grpc v1.23.0/go.mod h1:Y5yQAOtifL1yxbo5wqy6BxZv8vAUGQwXBOALyacEbxg= +google.golang.org/grpc v1.25.1/go.mod h1:c3i+UQWmh7LiEpx4sFZnkU36qjEYZ0imhYfXVyQciAY= +google.golang.org/grpc v1.31.0/go.mod h1:N36X2cJ7JwdamYAgDz+s+rVMFjt3numwzf/HckM8pak= +google.golang.org/protobuf v0.0.0-20200109180630-ec00e32a8dfd/go.mod h1:DFci5gLYBciE7Vtevhsrf46CRTquxDuWsQurQQe4oz8= +google.golang.org/protobuf v0.0.0-20200221191635-4d8936d0db64/go.mod h1:kwYJMbMJ01Woi6D6+Kah6886xMZcty6N08ah7+eCXa0= +google.golang.org/protobuf v0.0.0-20200228230310-ab0ca4ff8a60/go.mod h1:cfTl7dwQJ+fmap5saPgwCLgHXTUD7jkjRqWcaiX5VyM= +google.golang.org/protobuf v1.20.1-0.20200309200217-e05f789c0967/go.mod h1:A+miEFZTKqfCUM6K7xSMQL9OKL/b6hQv+e19PK+JZNE= +google.golang.org/protobuf v1.21.0/go.mod h1:47Nbq4nVaFHyn7ilMalzfO3qCViNmqZ2kzikPIcrTAo= +google.golang.org/protobuf v1.23.0/go.mod h1:EGpADcykh3NcUnDUJcl1+ZksZNG86OlYog2l/sGQquU= +gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c/go.mod h1:JHkPIbrfpd72SG/EVd6muEfDQjcINNoR0C8j2r3qZ4Q= +honnef.co/go/tools v0.0.0-20190102054323-c2f93a96b099/go.mod h1:rf3lG4BRIbNafJWhAfAdb/ePZxsR/4RtNHQocxwk9r4= +honnef.co/go/tools v0.0.0-20190523083050-ea95bdfd59fc/go.mod h1:rf3lG4BRIbNafJWhAfAdb/ePZxsR/4RtNHQocxwk9r4= +www.bamsoftware.com/git/dnstt.git v1.20241021.0 h1:Xi0lmT+5kcgzY7P+r726eBXKMZKgGoD8GTNKrlh8TuE= +www.bamsoftware.com/git/dnstt.git v1.20241021.0/go.mod h1:J4kVFxhn2bZqSqfE9l7keNTtsc+dRR6+uNH4kPu5VIs= diff --git a/hotreload.go b/hotreload.go new file mode 100644 index 0000000..b2aa09c --- /dev/null +++ b/hotreload.go @@ -0,0 +1,269 @@ +package main + +import ( + "crypto/tls" + "errors" + "fmt" + "io" + "log" + "net" + "net/http" + "os" + "sync" + + "golang.org/x/crypto/ssh" +) + +// ---------- Global SSH config ---------- + +var ( + sshCfgMu sync.RWMutex + currentSSHCfg *ssh.ServerConfig +) + +func setSSHConfig(c *ssh.ServerConfig) { + sshCfgMu.Lock() + currentSSHCfg = c + sshCfgMu.Unlock() +} + +func getSSHConfig() *ssh.ServerConfig { + sshCfgMu.RLock() + defer sshCfgMu.RUnlock() + return currentSSHCfg +} + +// ---------- Dynamic TCP listener pool ---------- + +// listenerPool manages a set of net.Listener instances by address. +// Calling Sync() opens new addresses and closes removed ones; unchanged +// addresses are left untouched so existing connections are not disrupted. +type listenerPool struct { + mu sync.Mutex + entries map[string]net.Listener + serve func(net.Listener) // goroutine launched for each new listener +} + +func newListenerPool(serve func(net.Listener)) *listenerPool { + return &listenerPool{entries: make(map[string]net.Listener), serve: serve} +} + +// Sync ensures exactly addrs are listening. Returns errors for addresses +// that could not be opened. +func (p *listenerPool) Sync(addrs []string) []error { + p.mu.Lock() + defer p.mu.Unlock() + + want := make(map[string]bool, len(addrs)) + for _, a := range addrs { + if a != "" { + want[a] = true + } + } + + for addr, ln := range p.entries { + if !want[addr] { + _ = ln.Close() // causes the serve goroutine to exit + delete(p.entries, addr) + log.Printf("hotreload: stopped %s", addr) + } + } + + var errs []error + for addr := range want { + if _, ok := p.entries[addr]; ok { + continue + } + ln, err := net.Listen("tcp", addr) + if err != nil { + errs = append(errs, fmt.Errorf("listen %s: %w", addr, err)) + continue + } + p.entries[addr] = ln + log.Printf("hotreload: listening on %s", addr) + go p.serve(ln) + } + return errs +} + +// ---------- Dynamic TLS listener pool ---------- + +type tlsListenerPool struct { + mu sync.Mutex + entries map[string]net.Listener +} + +func newTLSListenerPool() *tlsListenerPool { + return &tlsListenerPool{entries: make(map[string]net.Listener)} +} + +// Sync ensures exactly forwarders are listening (matched by listen address). +func (p *tlsListenerPool) Sync(forwarders []TLSForwarderConfig) []error { + p.mu.Lock() + defer p.mu.Unlock() + + want := make(map[string]TLSForwarderConfig, len(forwarders)) + for _, f := range forwarders { + if f.Listen != "" { + want[f.Listen] = f + } + } + + for addr, ln := range p.entries { + if _, ok := want[addr]; !ok { + _ = ln.Close() + delete(p.entries, addr) + log.Printf("hotreload: stopped TLS %s", addr) + } + } + + var errs []error + for addr, fwd := range want { + if _, ok := p.entries[addr]; ok { + continue + } + cert, err := tls.LoadX509KeyPair(fwd.CertFile, fwd.KeyFile) + if err != nil { + errs = append(errs, fmt.Errorf("TLS cert/key %s: %w", addr, err)) + continue + } + ln, err := tls.Listen("tcp", addr, &tls.Config{ + Certificates: []tls.Certificate{cert}, + MinVersion: tls.VersionTLS12, + }) + if err != nil { + errs = append(errs, fmt.Errorf("TLS listen %s: %w", addr, err)) + continue + } + p.entries[addr] = ln + log.Printf("hotreload: TLS+SSH listening on %s", addr) + go serveTLSSSH(ln) + } + return errs +} + +// ---------- Global pool instances (initialised in main) ---------- + +var ( + publicPool *listenerPool // HTTP+SSH: listen + extra_listen + localPool *listenerPool // raw SSH: local_ssh_listen + tlsPool *tlsListenerPool // TLS forwarders +) + +// isListenerClosed reports whether err came from using a closed listener. +func isListenerClosed(err error) bool { + return errors.Is(err, net.ErrClosed) +} + +// ---------- Runtime settings ---------- + +var ( + defaultLimitsMu sync.RWMutex + runtimeLimitMbpsUp int + runtimeLimitMbpsDown int + + adminHandlerMu sync.RWMutex + adminHandler http.Handler +) + +func setDefaultLimits(up, down int) { + defaultLimitsMu.Lock() + runtimeLimitMbpsUp = up + runtimeLimitMbpsDown = down + defaultLimitsMu.Unlock() +} + +func getDefaultLimits() (up, down int) { + defaultLimitsMu.RLock() + defer defaultLimitsMu.RUnlock() + return runtimeLimitMbpsUp, runtimeLimitMbpsDown +} + +func setAdminHandler(h http.Handler) { + adminHandlerMu.Lock() + adminHandler = h + adminHandlerMu.Unlock() +} + +func getAdminHandler() http.Handler { + adminHandlerMu.RLock() + defer adminHandlerMu.RUnlock() + return adminHandler +} + +// ---------- Full live reload ---------- + +// applyFullConfigReload applies every field in newCfg to the running server +// without a process restart. Port changes, DNSTT/UDPGW changes, Xray changes, +// and bandwidth defaults all take effect immediately. +// The only field that still requires a restart is host_key_file. +func applyFullConfigReload(newCfg *Config) { + // Banner + bt := newCfg.Banner + if bt == "" && newCfg.BannerFile != "" { + if data, err := os.ReadFile(newCfg.BannerFile); err == nil { + bt = string(data) + } + } + setBannerText(bt) + + // Default per-connection bandwidth limits (picked up by new connections) + setDefaultLimits(newCfg.DefaultLimitMbpsUp, newCfg.DefaultLimitMbpsDown) + + // Quiet logging / user count display + if newCfg.Quiet { + log.SetOutput(io.Discard) + } else if !newCfg.UserCount { + log.SetOutput(os.Stderr) + } + userCountEnabled = newCfg.UserCount + + // Admin panel directory (hot-swaps the file server on the next request) + if newCfg.AdminDir != "" { + setAdminHandler(http.FileServer(http.Dir(newCfg.AdminDir))) + } + + // Public SSH listeners (main listen + extra_listen) + publicAddrs := append([]string{newCfg.Listen}, newCfg.ExtraListen...) + for _, e := range publicPool.Sync(publicAddrs) { + log.Printf("hotreload: %v", e) + } + + // Local raw SSH listener + var localAddrs []string + if newCfg.LocalSSHListen != "" { + localAddrs = []string{newCfg.LocalSSHListen} + } + for _, e := range localPool.Sync(localAddrs) { + log.Printf("hotreload: %v", e) + } + + // TLS forwarders + for _, e := range tlsPool.Sync(newCfg.TLSForwarders) { + log.Printf("hotreload: %v", e) + } + + // DNSTT — stop current instance (no-op if not running) then start new one + stopDNSTT() + startDNSTT(newCfg.DNSTT, getSSHConfig()) + + // UDPGW — same pattern + stopUDPGW() + startUDPGW(newCfg.UDPGW) + + // Xray — update stored config then restart/stop as needed + if newCfg.Xray != nil { + xrayMgr.mu.Lock() + xrayMgr.cfg = newCfg.Xray + xrayMgr.mu.Unlock() + if newCfg.Xray.Enabled { + _ = xrayMgr.Restart() + } else { + _ = xrayMgr.Stop() + } + } else { + _ = xrayMgr.Stop() + } + + setGlobalCfg(newCfg) +} diff --git a/install.sh b/install.sh new file mode 100644 index 0000000..2662597 --- /dev/null +++ b/install.sh @@ -0,0 +1,386 @@ +#!/bin/bash +# Auto-install script for SSH Panel + Xray-core (Ubuntu/Debian/CentOS) +# Usage: sudo bash install.sh +set -euo pipefail + +RED='\033[0;31m'; GREEN='\033[0;32m'; YELLOW='\033[1;33m'; NC='\033[0m' +info() { echo -e "${GREEN}[+]${NC} $*"; } +warn() { echo -e "${YELLOW}[!]${NC} $*"; } +error() { echo -e "${RED}[x]${NC} $*"; exit 1; } + +# ── config ────────────────────────────────────────────────────────────────── +INSTALL_DIR="/opt/sshpanel" +SERVICE_NAME="sshpanel" +SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)" +GO_VERSION="${GO_VERSION:-$(awk '$1 == "go" {print $2; exit}' "$SCRIPT_DIR/go.mod" 2>/dev/null || echo "1.22.5")}" +# ──────────────────────────────────────────────────────────────────────────── + +[[ $EUID -ne 0 ]] && error "Run as root: sudo bash $0" + +echo -e "\n${GREEN}══════════════════════════════════════════${NC}" +echo -e "${GREEN} SSH Panel + Xray-core · Installer ${NC}" +echo -e "${GREEN}══════════════════════════════════════════${NC}\n" + +# ── 1. OS detection ────────────────────────────────────────────────────────── +info "[1/9] Detecting OS…" +if [[ -f /etc/os-release ]]; then + # shellcheck disable=SC1091 + . /etc/os-release + OS_ID="${ID:-unknown}" +else + OS_ID="unknown" +fi + +case "$OS_ID" in + ubuntu|debian|linuxmint) + PKG_UPDATE="apt-get update -qq" + PKG_INSTALL="DEBIAN_FRONTEND=noninteractive apt-get install -y" + PKG_DEPS="curl wget git build-essential postgresql postgresql-contrib ca-certificates unzip openssh-client openssl" + ;; + centos|rhel|rocky|almalinux) + PKG_UPDATE="yum makecache -q" + PKG_INSTALL="yum install -y" + PKG_DEPS="curl wget git gcc make postgresql-server postgresql-contrib ca-certificates unzip openssh-clients openssl" + ;; + fedora) + PKG_UPDATE="dnf makecache -q" + PKG_INSTALL="dnf install -y" + PKG_DEPS="curl wget git gcc make postgresql-server postgresql-contrib ca-certificates unzip openssh-clients openssl" + ;; + *) + warn "Unknown OS '$OS_ID' — attempting apt-get…" + PKG_UPDATE="apt-get update -qq" + PKG_INSTALL="DEBIAN_FRONTEND=noninteractive apt-get install -y" + PKG_DEPS="curl wget git build-essential postgresql postgresql-contrib ca-certificates unzip openssh-client openssl" + ;; +esac +info " OS: $OS_ID" + +# ── 2. System dependencies ─────────────────────────────────────────────────── +info "[2/9] Installing system packages…" +eval "$PKG_UPDATE" +eval "$PKG_INSTALL $PKG_DEPS" + +# ── 3. Go ──────────────────────────────────────────────────────────────────── +info "[3/9] Installing Go ${GO_VERSION}…" +NEED_GO=true +if command -v go &>/dev/null; then + CURRENT_GO=$(go version 2>/dev/null | awk '{print $3}' | sed 's/go//') + if [[ "$(printf '%s\n' "$GO_VERSION" "$CURRENT_GO" | sort -V | head -1)" == "$GO_VERSION" ]]; then + info " Go $CURRENT_GO already installed — skipping" + NEED_GO=false + fi +fi + +if $NEED_GO; then + MACHINE=$(uname -m) + case "$MACHINE" in + x86_64) GOARCH="amd64" ;; + aarch64) GOARCH="arm64" ;; + armv7l) GOARCH="armv6l" ;; + *) GOARCH="amd64" ;; + esac + GO_URL="https://go.dev/dl/go${GO_VERSION}.linux-${GOARCH}.tar.gz" + info " Downloading $GO_URL" + wget -q --show-progress -O /tmp/go.tar.gz "$GO_URL" + rm -rf /usr/local/go + tar -C /usr/local -xzf /tmp/go.tar.gz + rm -f /tmp/go.tar.gz + echo 'export PATH=$PATH:/usr/local/go/bin' > /etc/profile.d/go.sh + chmod +x /etc/profile.d/go.sh +fi + +export PATH=$PATH:/usr/local/go/bin +go version + +# ── 4. Directory layout ────────────────────────────────────────────────────── +info "[4/9] Setting up ${INSTALL_DIR}…" +mkdir -p "$INSTALL_DIR/admin" "$INSTALL_DIR/keys" "$INSTALL_DIR/logs" + +# ── 5. Build SSH panel binary ──────────────────────────────────────────────── +info "[5/9] Building SSH Panel binary…" +cd "$SCRIPT_DIR" +export GOPATH=/tmp/gopath_sshpanel +export GOCACHE=/tmp/gocache_sshpanel +go mod download +go build -ldflags="-s -w" -o "$INSTALL_DIR/sshpanel" . +info " Binary: $INSTALL_DIR/sshpanel" +cp -r "$SCRIPT_DIR/admin/"* "$INSTALL_DIR/admin/" +info " Admin panel copied" + +# ── 6. Xray binary ────────────────────────────────────────────────────────── +info "[6/9] Downloading Xray-core…" +XRAY_VER=$(curl -sf "https://api.github.com/repos/XTLS/Xray-core/releases/latest" \ + | grep '"tag_name"' | head -1 | cut -d'"' -f4 || echo "v24.11.30") +MACHINE=$(uname -m) +case "$MACHINE" in + x86_64) XRAY_ARCH="64" ;; + aarch64) XRAY_ARCH="arm64-v8a" ;; + armv7l) XRAY_ARCH="arm32-v7a" ;; + *) XRAY_ARCH="64" ;; +esac +XRAY_URL="https://github.com/XTLS/Xray-core/releases/download/${XRAY_VER}/Xray-linux-${XRAY_ARCH}.zip" +info " Xray ${XRAY_VER} (${XRAY_ARCH})" +wget -q --show-progress -O /tmp/xray.zip "$XRAY_URL" +unzip -o /tmp/xray.zip xray -d "$INSTALL_DIR" > /dev/null 2>&1 || { + mkdir -p /tmp/xray_extract + unzip -o /tmp/xray.zip -d /tmp/xray_extract > /dev/null 2>&1 + mv /tmp/xray_extract/xray "$INSTALL_DIR/xray" +} +chmod +x "$INSTALL_DIR/xray" +rm -f /tmp/xray.zip +"$INSTALL_DIR/xray" version + +# ── 7. PostgreSQL ──────────────────────────────────────────────────────────── +info "[7/9] Configuring PostgreSQL…" +case "$OS_ID" in + centos|rhel|rocky|almalinux|fedora) + postgresql-setup --initdb 2>/dev/null || true ;; +esac +systemctl start postgresql 2>/dev/null || service postgresql start 2>/dev/null || true +systemctl enable postgresql 2>/dev/null || true + +DB_NAME="sshpanel" +DB_USER="sshpanel" +DB_PASS=$(tr -dc 'A-Za-z0-9' < /dev/urandom | head -c 32 || true) +if [[ ${#DB_PASS} -lt 32 ]]; then + DB_PASS=$(openssl rand -hex 16 2>/dev/null || date +%s%N) +fi + +su -c "psql -tc \"SELECT 1 FROM pg_roles WHERE rolname='${DB_USER}'\" | grep -q 1 || \ + psql -c \"CREATE USER ${DB_USER} WITH PASSWORD '${DB_PASS}';\"" postgres +# Reinstall-safe: if the role already existed, make the new .env password valid. +su -c "psql -c \"ALTER USER ${DB_USER} WITH PASSWORD '${DB_PASS}';\"" postgres + +su -c "psql -tc \"SELECT 1 FROM pg_database WHERE datname='${DB_NAME}'\" | grep -q 1 || \ + psql -c \"CREATE DATABASE ${DB_NAME} OWNER ${DB_USER};\"" postgres +# Reinstall-safe: if the database already existed, make sshpanel its owner. +su -c "psql -c \"ALTER DATABASE ${DB_NAME} OWNER TO ${DB_USER};\"" postgres + +su -c "psql -d ${DB_NAME} -c \" +CREATE TABLE IF NOT EXISTS ssh_users ( + username TEXT PRIMARY KEY, + password TEXT NOT NULL DEFAULT '', + max_connections INT NOT NULL DEFAULT 0, + expires_at TEXT, + limit_mbps_up INT NOT NULL DEFAULT 0, + limit_mbps_down INT NOT NULL DEFAULT 0, + totp_secret TEXT NOT NULL DEFAULT '', + totp_period INT NOT NULL DEFAULT 60, + totp_window INT NOT NULL DEFAULT 1, + totp_digits INT NOT NULL DEFAULT 6, + allow_static_password BOOLEAN NOT NULL DEFAULT FALSE, + owner_username TEXT NOT NULL DEFAULT '' +); +ALTER TABLE ssh_users ADD COLUMN IF NOT EXISTS totp_secret TEXT NOT NULL DEFAULT ''; +ALTER TABLE ssh_users ADD COLUMN IF NOT EXISTS totp_period INT NOT NULL DEFAULT 60; +ALTER TABLE ssh_users ADD COLUMN IF NOT EXISTS totp_window INT NOT NULL DEFAULT 1; +ALTER TABLE ssh_users ADD COLUMN IF NOT EXISTS totp_digits INT NOT NULL DEFAULT 6; +ALTER TABLE ssh_users ADD COLUMN IF NOT EXISTS allow_static_password BOOLEAN NOT NULL DEFAULT FALSE; +ALTER TABLE ssh_users ADD COLUMN IF NOT EXISTS owner_username TEXT NOT NULL DEFAULT ''; +ALTER TABLE ssh_users ALTER COLUMN password SET DEFAULT ''; + +CREATE TABLE IF NOT EXISTS ssh_iface_totals ( + iface TEXT PRIMARY KEY, + total_rx_bytes BIGINT NOT NULL DEFAULT 0, + total_tx_bytes BIGINT NOT NULL DEFAULT 0, + last_kernel_rx_bytes BIGINT NOT NULL DEFAULT 0, + last_kernel_tx_bytes BIGINT NOT NULL DEFAULT 0, + updated_at TIMESTAMPTZ NOT NULL DEFAULT NOW() +); +ALTER TABLE ssh_iface_totals ADD COLUMN IF NOT EXISTS total_rx_bytes BIGINT NOT NULL DEFAULT 0; +ALTER TABLE ssh_iface_totals ADD COLUMN IF NOT EXISTS total_tx_bytes BIGINT NOT NULL DEFAULT 0; +ALTER TABLE ssh_iface_totals ADD COLUMN IF NOT EXISTS last_kernel_rx_bytes BIGINT NOT NULL DEFAULT 0; +ALTER TABLE ssh_iface_totals ADD COLUMN IF NOT EXISTS last_kernel_tx_bytes BIGINT NOT NULL DEFAULT 0; +ALTER TABLE ssh_iface_totals ADD COLUMN IF NOT EXISTS updated_at TIMESTAMPTZ NOT NULL DEFAULT NOW(); + +CREATE TABLE IF NOT EXISTS admin_users ( + id SERIAL PRIMARY KEY, + username TEXT UNIQUE NOT NULL, + password_hash TEXT NOT NULL, + role TEXT NOT NULL DEFAULT 'reseller', + max_users INT NOT NULL DEFAULT 30, + expires_at TIMESTAMPTZ, + is_active BOOLEAN NOT NULL DEFAULT TRUE, + created_at TIMESTAMPTZ NOT NULL DEFAULT NOW() +); + +CREATE TABLE IF NOT EXISTS xray_clients ( + uuid TEXT PRIMARY KEY, + name TEXT NOT NULL DEFAULT '', + email TEXT NOT NULL DEFAULT '', + inbound_tag TEXT NOT NULL DEFAULT '', + expires_at TIMESTAMPTZ, + max_conns INT NOT NULL DEFAULT 0, + created_at TIMESTAMPTZ NOT NULL DEFAULT NOW() +); + +ALTER SCHEMA public OWNER TO ${DB_USER}; +ALTER TABLE IF EXISTS ssh_users OWNER TO ${DB_USER}; +ALTER TABLE IF EXISTS ssh_iface_totals OWNER TO ${DB_USER}; +ALTER TABLE IF EXISTS admin_users OWNER TO ${DB_USER}; +ALTER TABLE IF EXISTS xray_clients OWNER TO ${DB_USER}; +ALTER SEQUENCE IF EXISTS admin_users_id_seq OWNER TO ${DB_USER}; +GRANT ALL PRIVILEGES ON DATABASE ${DB_NAME} TO ${DB_USER}; +GRANT ALL PRIVILEGES ON SCHEMA public TO ${DB_USER}; +GRANT ALL PRIVILEGES ON ALL TABLES IN SCHEMA public TO ${DB_USER}; +GRANT ALL PRIVILEGES ON ALL SEQUENCES IN SCHEMA public TO ${DB_USER}; +\"" postgres + +info " PostgreSQL database '${DB_NAME}' ready" + +# ── 8. Config files ────────────────────────────────────────────────────────── +info "[8/9] Generating config files…" + +# Admin token +ADMIN_TOKEN=$(tr -dc 'A-Za-z0-9' < /dev/urandom | head -c 48 || true) +if [[ ${#ADMIN_TOKEN} -lt 48 ]]; then + ADMIN_TOKEN=$(openssl rand -hex 24 2>/dev/null || date +%s%N) +fi + +# Admin panel login password. The web panel login is username/password; +# ADMIN_TOKEN is only for bearer-token API access and is not the login password. +ADMIN_PASSWORD=$(tr -dc 'A-Za-z0-9' < /dev/urandom | head -c 20 || true) +if [[ ${#ADMIN_PASSWORD} -lt 20 ]]; then + ADMIN_PASSWORD=$(openssl rand -hex 10 2>/dev/null || date +%s%N) +fi +ADMIN_PASSWORD_HASH=$(printf '%s' "${ADMIN_PASSWORD}" | sha256sum | awk '{print $1}') +su -c "psql -d ${DB_NAME}" postgres < "$INSTALL_DIR/.env" </dev/null \ + || curl -sf --max-time 5 https://api.ipify.org 2>/dev/null \ + || hostname -I | awk '{print $1}') + +# config.json +cat > "$INSTALL_DIR/config.json" </dev/null \ + || python3 -c "import uuid; print(uuid.uuid4())" 2>/dev/null \ + || echo "11111111-2222-3333-4444-555555555555") + +# xray_config.json (default VLESS + SOCKS inbounds — no geoip routing needed) +cat > "$INSTALL_DIR/xray_config.json" < "/etc/systemd/system/${SERVICE_NAME}.service" <connID mapping remains + // valid after the last packet from that destination. Expressed as a + // duration string. Default is "90s". + MapTTL string `json:"map_ttl"` + // ReapEvery controls how often expired mappings are purged. Expressed + // as a duration string. Default is "10s". + ReapEvery string `json:"reap_every"` + // IdleTimeout controls how long a TCP client connection may remain + // idle (no frames received) before being closed. Expressed as a duration + // string. Default is "2m". + IdleTimeout string `json:"idle_timeout"` + // MaxClientConns limits how many distinct udpgw connIDs a single TCP + // client connection may keep active at once. This is the primary guard + // against a single client creating effectively infinite logical UDP + // sessions and growing RAM usage over time. Default is 10. + MaxClientConns int `json:"max_client_conns"` + // MaxMapEntries limits the maximum number of destination->connID + // mappings kept per client. This protects the server from unbounded + // growth if a client sprays packets to many unique destinations. + // Default is 32768. + MaxMapEntries int `json:"max_map_entries"` +} + +type UserConfig struct { + Username string `json:"username"` + Password string `json:"password"` + PublicKeyFile string `json:"public_key_file"` // optional .pub path + + // Optional rotating password mode for shared/public accounts. When + // totp_secret is set, the server accepts a TOTP code as the SSH + // password. This is useful when you want to make copied credentials + // expire quickly without relying on a separate API. + // Secret must be Base32 (with or without padding). Example generated + // secret: JBSWY3DPEHPK3PXP + TOTPSecret string `json:"totp_secret"` + // Period in seconds for each password step. Default: 60. + TOTPPeriod int `json:"totp_period"` + // Number of adjacent time windows accepted for clock drift. + // Default: 1 (previous/current/next). + TOTPWindow int `json:"totp_window"` + // Number of digits in the generated TOTP. Default: 6. + TOTPDigits int `json:"totp_digits"` + // When true, both the static Password and the TOTP code are accepted. + // When false and totp_secret is set, only the TOTP code is accepted. + AllowStaticPassword bool `json:"allow_static_password"` + + MaxConnections int `json:"max_connections"` + ExpiresAt string `json:"expires_at"` // RFC3339 or empty + + // Per-connection limits (one limiter per SSH connection) + LimitMbpsUp int `json:"limit_mbps_up"` // Mbps upstream + LimitMbpsDown int `json:"limit_mbps_down"` // Mbps downstream + + // OwnerUsername is the reseller who created this SSH user. Empty = superadmin-owned. + OwnerUsername string `json:"owner_username,omitempty"` +} + +type UserState struct { + Cfg UserConfig + ExpiresAt *time.Time + PubKey ssh.PublicKey // may be nil + + mu sync.Mutex + ActiveConns int + conns map[*ssh.ServerConn]struct{} // active SSH connections for this user +} + +type UserManager struct { + mu sync.RWMutex + users map[string]*UserState +} + +func (m *UserManager) Get(username string) (*UserState, bool) { + m.mu.RLock() + defer m.mu.RUnlock() + u, ok := m.users[username] + return u, ok +} + +func (m *UserManager) List() []*UserState { + m.mu.RLock() + defer m.mu.RUnlock() + out := make([]*UserState, 0, len(m.users)) + for _, u := range m.users { + out = append(out, u) + } + return out +} + +func (m *UserManager) ReplaceAll(newUsers map[string]*UserState) { + m.mu.Lock() + defer m.mu.Unlock() + m.users = newUsers +} + +// ReplaceAllPreserveRuntime replaces the user map while keeping runtime connection +// state (ActiveConns + conns) for users that already exist. +// This prevents the admin panel from showing everyone as "offline" after a DB reload. +func (m *UserManager) ReplaceAllPreserveRuntime(newUsers map[string]*UserState) { + m.mu.Lock() + old := m.users + + for username, nu := range newUsers { + if ou, ok := old[username]; ok && ou != nil && nu != nil { + ou.mu.Lock() + nu.ActiveConns = ou.ActiveConns + // Preserve the live connection set so we can still disconnect correctly. + nu.conns = ou.conns + ou.mu.Unlock() + } + } + + m.users = newUsers + m.mu.Unlock() +} + +// DisconnectUser closes all active SSH connections for the given user. +// Safe if the user does not exist or has no connections. +func (m *UserManager) DisconnectUser(username string) { + m.mu.RLock() + u, ok := m.users[username] + m.mu.RUnlock() + if !ok { + return + } + + u.mu.Lock() + conns := make([]*ssh.ServerConn, 0, len(u.conns)) + for c := range u.conns { + conns = append(conns, c) + } + u.mu.Unlock() + + for _, c := range conns { + _ = c.Close() + } +} + +// Global state +var ( + userMgr = &UserManager{users: make(map[string]*UserState)} + userCountEnabled bool + + displayMu sync.Mutex + lastDisplayLen int + + // Server stats (CPU + network interfaces) + statsMu sync.RWMutex + currentStats StatsDTO + + // Optional: persist interface byte counters across process restarts / reboots (requires PG_DSN). + ifaceTotalsMgr *IfaceTotalsManager + statsStore *Store +) + +// ---------- Helpers ---------- + +func publicKeysEqual(a, b ssh.PublicKey) bool { + if a == nil || b == nil { + return a == b + } + return a.Type() == b.Type() && bytes.Equal(a.Marshal(), b.Marshal()) +} + +func mbpsToBytesPerSec(mbps int) int64 { + if mbps <= 0 { + return 0 + } + return int64(mbps) * 1024 * 1024 / 8 +} + +func copyWithRateLimit(dst io.Writer, src io.Reader, lim *rate.Limiter) (written int64, err error) { + if lim == nil { + return io.Copy(dst, src) + } + + const bufSize = 32 * 1024 + buf := make([]byte, bufSize) + ctx := context.Background() + + for { + nr, er := src.Read(buf) + if nr > 0 { + if err := lim.WaitN(ctx, nr); err != nil { + return written, err + } + + nw, ew := dst.Write(buf[:nr]) + if nw > 0 { + written += int64(nw) + } + if ew != nil { + err = ew + break + } + if nr != nw { + err = io.ErrShortWrite + break + } + } + + if er != nil { + if er != io.EOF { + err = er + } + break + } + } + + return written, err +} + +// ---------- Server stats (CPU + network interfaces) ---------- + +// per-interface stats returned by /api/stats +type InterfaceStats struct { + Name string `json:"name"` + RxBytes uint64 `json:"rx_bytes"` + TxBytes uint64 `json:"tx_bytes"` + RxMbps float64 `json:"rx_mbps"` + TxMbps float64 `json:"tx_mbps"` +} + +type StatsDTO struct { + CPUPercent float64 `json:"cpu_percent"` + MemTotal uint64 `json:"mem_total_bytes"` + MemUsed uint64 `json:"mem_used_bytes"` + MemAvail uint64 `json:"mem_avail_bytes"` + MemPercent float64 `json:"mem_percent"` + Interfaces []InterfaceStats `json:"interfaces"` +} +type ifaceCounters struct { + RxBytes uint64 + TxBytes uint64 +} + +func getCurrentStats() StatsDTO { + statsMu.RLock() + defer statsMu.RUnlock() + return currentStats +} + +func setCurrentStats(s StatsDTO) { + statsMu.Lock() + currentStats = s + statsMu.Unlock() +} + +type IfaceTotals struct { + Iface string + TotalRxBytes uint64 + TotalTxBytes uint64 + LastKernelRxBytes uint64 + LastKernelTxBytes uint64 + UpdatedAt time.Time +} + +type IfaceTotalsManager struct { + mu sync.Mutex + m map[string]*IfaceTotals +} + +func NewIfaceTotalsManager() *IfaceTotalsManager { + return &IfaceTotalsManager{m: make(map[string]*IfaceTotals)} +} + +// ApplyKernel updates cumulative totals using kernel counters (from /proc/net/dev). +// It is resilient to kernel counter resets (e.g. host reboot): if the kernel counter +// goes backwards, it treats the new value as "delta since reset". +func (tm *IfaceTotalsManager) ApplyKernel(iface string, kRx, kTx uint64) (totalRx, totalTx uint64) { + tm.mu.Lock() + defer tm.mu.Unlock() + + st, ok := tm.m[iface] + if !ok { + st = &IfaceTotals{Iface: iface} + tm.m[iface] = st + } + + // RX + if st.LastKernelRxBytes == 0 && st.TotalRxBytes == 0 { + st.TotalRxBytes = kRx + } else if kRx >= st.LastKernelRxBytes { + st.TotalRxBytes += kRx - st.LastKernelRxBytes + } else { + // kernel reset or wrap + st.TotalRxBytes += kRx + } + st.LastKernelRxBytes = kRx + + // TX + if st.LastKernelTxBytes == 0 && st.TotalTxBytes == 0 { + st.TotalTxBytes = kTx + } else if kTx >= st.LastKernelTxBytes { + st.TotalTxBytes += kTx - st.LastKernelTxBytes + } else { + st.TotalTxBytes += kTx + } + st.LastKernelTxBytes = kTx + + st.UpdatedAt = time.Now() + return st.TotalRxBytes, st.TotalTxBytes +} + +func (tm *IfaceTotalsManager) Load(rows []IfaceTotals) { + tm.mu.Lock() + defer tm.mu.Unlock() + for _, r := range rows { + cp := r // copy + tm.m[r.Iface] = &cp + } +} + +func (tm *IfaceTotalsManager) Snapshot() []IfaceTotals { + tm.mu.Lock() + defer tm.mu.Unlock() + out := make([]IfaceTotals, 0, len(tm.m)) + for _, v := range tm.m { + out = append(out, *v) + } + return out +} + +// startStatsCollector periodically reads /proc/stat and /proc/net/dev +// to compute CPU usage and per-interface traffic in Mbps. +func startStatsCollector() { + go func() { + // Recover from any panic so the goroutine doesn't silently die and + // leave stats frozen. Log the panic and restart after a short delay. + defer func() { + if r := recover(); r != nil { + log.Printf("startStatsCollector: panic recovered: %v; restarting in 5s", r) + time.Sleep(5 * time.Second) + go startStatsCollector() + } + }() + + var ( + prevIdle, prevTotal uint64 + prevNet map[string]ifaceCounters + prevTime time.Time + ) + + // Use a ticker for the poll interval so the goroutine can be + // cleanly stopped in tests and so the flush ticker is always + // paired with a matching Stop(). + pollTicker := time.NewTicker(2 * time.Second) + defer pollTicker.Stop() + + var flushTicker *time.Ticker + if statsStore != nil && ifaceTotalsMgr != nil { + flushTicker = time.NewTicker(30 * time.Second) + defer flushTicker.Stop() + } + + for range pollTicker.C { + now := time.Now() + + idle, total, err := readCPUStat() + if err != nil { + // keep previous CPU if error + } + + netMap, err := readNetDev() + if err != nil { + // keep previous net if error + } + + // CPU usage + var cpuPercent float64 + if prevTotal != 0 && total > prevTotal { + idleDelta := float64(idle - prevIdle) + totalDelta := float64(total - prevTotal) + if totalDelta > 0 { + cpuPercent = 100.0 * (1.0 - idleDelta/totalDelta) + } + } + + // Per-interface Mbps (Rx/Tx) + var interfaces []InterfaceStats + dt := now.Sub(prevTime).Seconds() + if netMap != nil { + for name, ctrs := range netMap { + st := InterfaceStats{ + Name: name, + } + // Bytes: if persistence enabled, show cumulative totals across restarts; else show kernel counters. + if ifaceTotalsMgr != nil { + rxTotal, txTotal := ifaceTotalsMgr.ApplyKernel(name, ctrs.RxBytes, ctrs.TxBytes) + st.RxBytes = rxTotal + st.TxBytes = txTotal + } else { + st.RxBytes = ctrs.RxBytes + st.TxBytes = ctrs.TxBytes + } + if prevNet != nil && dt > 0 { + if prev, ok := prevNet[name]; ok { + if ctrs.RxBytes >= prev.RxBytes { + rxDelta := ctrs.RxBytes - prev.RxBytes + st.RxMbps = float64(rxDelta*8) / dt / 1_000_000 + } + if ctrs.TxBytes >= prev.TxBytes { + txDelta := ctrs.TxBytes - prev.TxBytes + st.TxMbps = float64(txDelta*8) / dt / 1_000_000 + } + } + } + interfaces = append(interfaces, st) + } + } + + sort.Slice(interfaces, func(i, j int) bool { + return interfaces[i].Name < interfaces[j].Name + }) + + // RAM usage (/proc/meminfo) + memTotal, memAvail, _ := readMemInfo() + var memUsed uint64 + var memPercent float64 + if memTotal > 0 { + if memAvail <= memTotal { + memUsed = memTotal - memAvail + memPercent = 100.0 * float64(memUsed) / float64(memTotal) + } else { + memAvail = memTotal + } + } + + setCurrentStats(StatsDTO{ + CPUPercent: cpuPercent, + MemTotal: memTotal, + MemUsed: memUsed, + MemAvail: memAvail, + MemPercent: memPercent, + Interfaces: interfaces, + }) + + // Persist interface totals periodically (optional). + if flushTicker != nil && statsStore != nil && ifaceTotalsMgr != nil { + select { + case <-flushTicker.C: + ctx := context.Background() + _ = statsStore.UpsertIfaceTotals(ctx, ifaceTotalsMgr.Snapshot()) + default: + } + } + + prevIdle, prevTotal = idle, total + prevNet = netMap + prevTime = now + } + }() +} + +func readMemInfo() (totalBytes, availBytes uint64, err error) { + f, err := os.Open("/proc/meminfo") + if err != nil { + return 0, 0, err + } + defer f.Close() + + var ( + memTotalKB uint64 + memAvailKB uint64 + memFreeKB uint64 + buffersKB uint64 + cachedKB uint64 + ) + + scanner := bufio.NewScanner(f) + for scanner.Scan() { + line := scanner.Text() + fields := strings.Fields(line) + if len(fields) < 2 { + continue + } + key := strings.TrimSuffix(fields[0], ":") + val, perr := strconv.ParseUint(fields[1], 10, 64) + if perr != nil { + continue + } + switch key { + case "MemTotal": + memTotalKB = val + case "MemAvailable": + memAvailKB = val + case "MemFree": + memFreeKB = val + case "Buffers": + buffersKB = val + case "Cached": + cachedKB = val + } + } + if err := scanner.Err(); err != nil { + return 0, 0, err + } + if memTotalKB == 0 { + return 0, 0, nil + } + if memAvailKB == 0 { + memAvailKB = memFreeKB + buffersKB + cachedKB + if memAvailKB > memTotalKB { + memAvailKB = memTotalKB + } + } + return memTotalKB * 1024, memAvailKB * 1024, nil +} + +func readCPUStat() (idle, total uint64, err error) { + f, err := os.Open("/proc/stat") + if err != nil { + return 0, 0, err + } + defer f.Close() + + scanner := bufio.NewScanner(f) + for scanner.Scan() { + line := scanner.Text() + if strings.HasPrefix(line, "cpu ") { + fields := strings.Fields(line) + if len(fields) < 5 { + break + } + for i := 1; i < len(fields); i++ { + v, err2 := strconv.ParseUint(fields[i], 10, 64) + if err2 != nil { + continue + } + total += v + // idle + iowait + if i == 4 || i == 5 { + idle += v + } + } + break + } + } + if err := scanner.Err(); err != nil { + return 0, 0, err + } + return idle, total, nil +} + +func readNetDev() (map[string]ifaceCounters, error) { + f, err := os.Open("/proc/net/dev") + if err != nil { + return nil, err + } + defer f.Close() + + stats := make(map[string]ifaceCounters) + scanner := bufio.NewScanner(f) + lineNum := 0 + for scanner.Scan() { + lineNum++ + if lineNum <= 2 { + // headers + continue + } + line := strings.TrimSpace(scanner.Text()) + if line == "" { + continue + } + parts := strings.SplitN(line, ":", 2) + if len(parts) != 2 { + continue + } + iface := strings.TrimSpace(parts[0]) + fields := strings.Fields(parts[1]) + if len(fields) < 9 { + continue + } + rx, err1 := strconv.ParseUint(fields[0], 10, 64) + tx, err2 := strconv.ParseUint(fields[8], 10, 64) + if err1 != nil || err2 != nil { + continue + } + stats[iface] = ifaceCounters{ + RxBytes: rx, + TxBytes: tx, + } + } + if err := scanner.Err(); err != nil { + return nil, err + } + return stats, nil +} + +// ---------- Config loading ---------- + +func loadConfig(path string) (*Config, map[string]*UserState, error) { + data, err := os.ReadFile(path) + if err != nil { + return nil, nil, fmt.Errorf("read config: %w", err) + } + + var cfg Config + if err := json.Unmarshal(data, &cfg); err != nil { + return nil, nil, fmt.Errorf("parse config: %w", err) + } + + if cfg.Listen == "" { + cfg.Listen = ":2222" + } + if cfg.HostKeyFile == "" { + cfg.HostKeyFile = "ssh_host_rsa_key" + } + + userMap := make(map[string]*UserState) + for _, u := range cfg.Users { + if u.Username == "" { + return nil, nil, fmt.Errorf("config: user with empty username") + } + if _, exists := userMap[u.Username]; exists { + return nil, nil, fmt.Errorf("config: duplicate user %q", u.Username) + } + + // Apply default bandwidth limits if not set per-user. + if u.LimitMbpsUp == 0 { + u.LimitMbpsUp = cfg.DefaultLimitMbpsUp + } + if u.LimitMbpsDown == 0 { + u.LimitMbpsDown = cfg.DefaultLimitMbpsDown + } + + st := &UserState{Cfg: u} + + if u.ExpiresAt != "" { + t, err := time.Parse(time.RFC3339, u.ExpiresAt) + if err != nil { + return nil, nil, fmt.Errorf("config: invalid expires_at for user %q: %w", u.Username, err) + } + st.ExpiresAt = &t + } + + // Load per-user public key if configured + if u.PublicKeyFile != "" { + pkBytes, err := os.ReadFile(u.PublicKeyFile) + if err != nil { + return nil, nil, fmt.Errorf("config: could not read public_key_file for user %q: %w", u.Username, err) + } + pubKey, _, _, _, err := ssh.ParseAuthorizedKey(pkBytes) + if err != nil { + return nil, nil, fmt.Errorf("config: invalid public key in %s for user %q: %w", + u.PublicKeyFile, u.Username, err) + } + st.PubKey = pubKey + } + + userMap[u.Username] = st + } + + return &cfg, userMap, nil +} + +// ---------- Database store & admin API ---------- + +type Store struct { + db *sql.DB +} + +func NewStore(dsn string) (*Store, error) { + db, err := sql.Open("postgres", dsn) + if err != nil { + return nil, err + } + db.SetMaxOpenConns(5) + db.SetMaxIdleConns(5) + db.SetConnMaxLifetime(time.Hour) + + if err := db.Ping(); err != nil { + return nil, err + } + store := &Store{db: db} + ctx := context.Background() + if err := store.EnsureUsersSchema(ctx); err != nil { + return nil, err + } + if err := store.EnsureAdminUsersSchema(ctx); err != nil { + return nil, err + } + return store, nil +} + +func (s *Store) EnsureUsersSchema(ctx context.Context) error { + stmts := []string{ + `CREATE TABLE IF NOT EXISTS ssh_users ( + username TEXT PRIMARY KEY, + password TEXT NOT NULL DEFAULT '', + max_connections INT NOT NULL DEFAULT 0, + expires_at TEXT, + limit_mbps_up INT NOT NULL DEFAULT 0, + limit_mbps_down INT NOT NULL DEFAULT 0, + totp_secret TEXT NOT NULL DEFAULT '', + totp_period INT NOT NULL DEFAULT 60, + totp_window INT NOT NULL DEFAULT 1, + totp_digits INT NOT NULL DEFAULT 6, + allow_static_password BOOLEAN NOT NULL DEFAULT FALSE + )`, + `ALTER TABLE ssh_users ADD COLUMN IF NOT EXISTS totp_secret TEXT NOT NULL DEFAULT ''`, + `ALTER TABLE ssh_users ADD COLUMN IF NOT EXISTS totp_period INT NOT NULL DEFAULT 60`, + `ALTER TABLE ssh_users ADD COLUMN IF NOT EXISTS totp_window INT NOT NULL DEFAULT 1`, + `ALTER TABLE ssh_users ADD COLUMN IF NOT EXISTS totp_digits INT NOT NULL DEFAULT 6`, + `ALTER TABLE ssh_users ADD COLUMN IF NOT EXISTS allow_static_password BOOLEAN NOT NULL DEFAULT FALSE`, + `ALTER TABLE ssh_users ALTER COLUMN password SET DEFAULT ''`, + } + for _, stmt := range stmts { + if _, err := s.db.ExecContext(ctx, stmt); err != nil { + return err + } + } + return nil +} + +// LoadUsers loads all users from the ssh_users table into in-memory UserState map. +func (s *Store) LoadUsers(ctx context.Context) (map[string]*UserState, error) { + rows, err := s.db.QueryContext(ctx, ` + SELECT username, password, max_connections, expires_at, limit_mbps_up, limit_mbps_down, + COALESCE(totp_secret, ''), COALESCE(totp_period, 60), COALESCE(totp_window, 1), + COALESCE(totp_digits, 6), COALESCE(allow_static_password, FALSE), + COALESCE(owner_username, '') + FROM ssh_users`) + if err != nil { + return nil, err + } + defer rows.Close() + + users := make(map[string]*UserState) + for rows.Next() { + var ( + username string + password string + maxConnections int + expiresAt sql.NullString + limitUp int + limitDown int + totpSecret string + totpPeriod int + totpWindow int + totpDigits int + allowStaticPassword bool + ownerUsername string + ) + if err := rows.Scan(&username, &password, &maxConnections, &expiresAt, &limitUp, &limitDown, + &totpSecret, &totpPeriod, &totpWindow, &totpDigits, &allowStaticPassword, &ownerUsername); err != nil { + return nil, err + } + + cfg := UserConfig{ + Username: username, + Password: password, + MaxConnections: maxConnections, + LimitMbpsUp: limitUp, + LimitMbpsDown: limitDown, + TOTPSecret: totpSecret, + TOTPPeriod: totpPeriod, + TOTPWindow: totpWindow, + TOTPDigits: totpDigits, + AllowStaticPassword: allowStaticPassword, + OwnerUsername: ownerUsername, + } + + st := &UserState{Cfg: cfg} + if expiresAt.Valid && expiresAt.String != "" { + t, err := time.Parse(time.RFC3339, expiresAt.String) + if err != nil { + log.Printf("invalid expires_at for user %s in db: %v", username, err) + } else { + st.ExpiresAt = &t + st.Cfg.ExpiresAt = expiresAt.String + } + } + users[username] = st + } + if err := rows.Err(); err != nil { + return nil, err + } + return users, nil +} + +// UpsertUser creates or updates a row in ssh_users. +func (s *Store) UpsertUser(ctx context.Context, u UserConfig) error { + _, err := s.db.ExecContext(ctx, ` + INSERT INTO ssh_users ( + username, password, max_connections, expires_at, limit_mbps_up, limit_mbps_down, + totp_secret, totp_period, totp_window, totp_digits, allow_static_password, owner_username + ) + VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12) + ON CONFLICT (username) DO UPDATE + SET password = EXCLUDED.password, + max_connections = EXCLUDED.max_connections, + expires_at = EXCLUDED.expires_at, + limit_mbps_up = EXCLUDED.limit_mbps_up, + limit_mbps_down = EXCLUDED.limit_mbps_down, + totp_secret = EXCLUDED.totp_secret, + totp_period = EXCLUDED.totp_period, + totp_window = EXCLUDED.totp_window, + totp_digits = EXCLUDED.totp_digits, + allow_static_password = EXCLUDED.allow_static_password`, + // owner_username is intentionally excluded from UPDATE — ownership is set at creation only. + u.Username, u.Password, u.MaxConnections, u.ExpiresAt, u.LimitMbpsUp, u.LimitMbpsDown, + u.TOTPSecret, u.TOTPPeriod, u.TOTPWindow, u.TOTPDigits, u.AllowStaticPassword, u.OwnerUsername) + return err +} + +func (s *Store) DeleteUser(ctx context.Context, username string) error { + _, err := s.db.ExecContext(ctx, `DELETE FROM ssh_users WHERE username = $1`, username) + return err +} + +// ---------- Optional persistence for interface totals ---------- + +func (s *Store) EnsureIfaceTotalsTable(ctx context.Context) error { + _, err := s.db.ExecContext(ctx, ` + CREATE TABLE IF NOT EXISTS ssh_iface_totals ( + iface TEXT PRIMARY KEY, + total_rx_bytes BIGINT NOT NULL DEFAULT 0, + total_tx_bytes BIGINT NOT NULL DEFAULT 0, + last_kernel_rx_bytes BIGINT NOT NULL DEFAULT 0, + last_kernel_tx_bytes BIGINT NOT NULL DEFAULT 0, + updated_at TIMESTAMPTZ NOT NULL DEFAULT NOW() + )`) + return err +} + +func (s *Store) LoadIfaceTotals(ctx context.Context) ([]IfaceTotals, error) { + rows, err := s.db.QueryContext(ctx, ` + SELECT iface, total_rx_bytes, total_tx_bytes, last_kernel_rx_bytes, last_kernel_tx_bytes, updated_at + FROM ssh_iface_totals`) + if err != nil { + return nil, err + } + defer rows.Close() + + out := []IfaceTotals{} + for rows.Next() { + var r IfaceTotals + var updated time.Time + if err := rows.Scan(&r.Iface, &r.TotalRxBytes, &r.TotalTxBytes, &r.LastKernelRxBytes, &r.LastKernelTxBytes, &updated); err != nil { + return nil, err + } + r.UpdatedAt = updated + out = append(out, r) + } + if err := rows.Err(); err != nil { + return nil, err + } + return out, nil +} + +func (s *Store) UpsertIfaceTotals(ctx context.Context, rows []IfaceTotals) error { + if len(rows) == 0 { + return nil + } + // Simple loop (small N: number of interfaces). Keeps CPU/DB overhead minimal. + for _, r := range rows { + _, err := s.db.ExecContext(ctx, ` + INSERT INTO ssh_iface_totals (iface, total_rx_bytes, total_tx_bytes, last_kernel_rx_bytes, last_kernel_tx_bytes, updated_at) + VALUES ($1, $2, $3, $4, $5, NOW()) + ON CONFLICT (iface) DO UPDATE + SET total_rx_bytes = EXCLUDED.total_rx_bytes, + total_tx_bytes = EXCLUDED.total_tx_bytes, + last_kernel_rx_bytes = EXCLUDED.last_kernel_rx_bytes, + last_kernel_tx_bytes = EXCLUDED.last_kernel_tx_bytes, + updated_at = NOW()`, + r.Iface, r.TotalRxBytes, r.TotalTxBytes, r.LastKernelRxBytes, r.LastKernelTxBytes) + if err != nil { + return err + } + } + return nil +} + +func reloadUsersFromDB(ctx context.Context, store *Store) { + if store == nil { + return + } + users, err := store.LoadUsers(ctx) + if err != nil { + log.Printf("failed to reload users from db: %v", err) + return + } + userMgr.ReplaceAllPreserveRuntime(users) + updateUserDisplay() +} + +func startAdminAPI(store *Store, addr string, adminDir string) { + mux := http.NewServeMux() + + // Auth (no session required) + mux.Handle("/api/auth/login", http.HandlerFunc(handleLogin(store))) + + // Auth (session required) + mux.Handle("/api/auth/logout", sessionMiddleware(http.HandlerFunc(handleLogout))) + mux.Handle("/api/auth/me", sessionMiddleware(http.HandlerFunc(handleMe))) + + // SSH user management (session required; role-filtered inside handlers) + mux.Handle("/api/users", sessionMiddleware(http.HandlerFunc(handleListUsers))) + mux.Handle("/api/users/create", sessionMiddleware(http.HandlerFunc(handleCreateUser(store)))) + mux.Handle("/api/users/delete", sessionMiddleware(http.HandlerFunc(handleDeleteUser(store)))) + + // Superadmin-only: server stats + DNSTT + mux.Handle("/api/stats", saSession(http.HandlerFunc(handleStats))) + mux.Handle("/api/dnstt", saSession(http.HandlerFunc(handleDnsttStats))) + mux.Handle("/api/dnstt/logs", saSession(http.HandlerFunc(handleDnsttLogs))) + + // Superadmin-only: reseller management + mux.Handle("/api/resellers", saSession(http.HandlerFunc(handleListResellers(store)))) + mux.Handle("/api/resellers/create", saSession(http.HandlerFunc(handleCreateReseller(store)))) + mux.Handle("/api/resellers/delete", saSession(http.HandlerFunc(handleDeleteReseller(store)))) + + // Superadmin-only: Xray-core management + mux.Handle("/api/xray/status", saSession(http.HandlerFunc(handleXrayStatus))) + mux.Handle("/api/xray/start", saSession(http.HandlerFunc(handleXrayStart))) + mux.Handle("/api/xray/stop", saSession(http.HandlerFunc(handleXrayStop))) + mux.Handle("/api/xray/restart", saSession(http.HandlerFunc(handleXrayRestart))) + mux.Handle("/api/xray/config", saSession(http.HandlerFunc(handleXrayConfig))) + mux.Handle("/api/xray/logs", saSession(http.HandlerFunc(handleXrayLogs))) + mux.Handle("/api/xray/inbounds", saSession(http.HandlerFunc(handleXrayInbounds))) + mux.Handle("/api/xray/clients/add", saSession(http.HandlerFunc(handleXrayClientAdd))) + mux.Handle("/api/xray/clients/update", saSession(http.HandlerFunc(handleXrayClientUpdate))) + mux.Handle("/api/xray/clients/remove", saSession(http.HandlerFunc(handleXrayClientRemove))) + + // Superadmin-only: TLS certificate generation + mux.Handle("/api/tls/generate-selfsigned", saSession(http.HandlerFunc(handleTLSGenerateSelfSigned))) + mux.Handle("/api/tls/letsencrypt", saSession(http.HandlerFunc(handleTLSLetsEncrypt))) + mux.Handle("/api/tls/upload-pem", saSession(http.HandlerFunc(handleTLSUploadPEM))) + + // Superadmin-only: DNSTT key management + mux.Handle("/api/dnstt/genkey", saSession(http.HandlerFunc(handleDnsttGenKey))) + mux.Handle("/api/dnstt/pubkey", saSession(http.HandlerFunc(handleDnsttGetPubKey))) + + // Superadmin-only: server config (read/write config.json + live banner apply) + mux.Handle("/api/server/config", saSession(http.HandlerFunc(handleServerConfig))) + + // Public: user/UUID check — no auth, CORS *. + mux.Handle("/check", http.HandlerFunc(handleCheck)) + + // Static panel — delegates to a global handler so admin_dir can be hot-swapped. + setAdminHandler(http.FileServer(http.Dir(adminDir))) + mux.Handle("/", http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + getAdminHandler().ServeHTTP(w, r) + })) + + go func() { + log.Printf("Admin HTTP (panel + API) listening on %s", addr) + if err := http.ListenAndServe(addr, mux); err != nil { + log.Printf("admin http error: %v", err) + } + }() +} + +// UserDTO is returned by the admin API for listing. +type UserDTO struct { + Username string `json:"username"` + ActiveConns int `json:"active_conns"` + MaxConnections int `json:"max_connections"` + ExpiresAt *time.Time `json:"expires_at,omitempty"` + LimitUpMbps int `json:"limit_mbps_up"` + LimitDownMbps int `json:"limit_mbps_down"` + TOTPSecret string `json:"totp_secret,omitempty"` + TOTPPeriod int `json:"totp_period"` + TOTPWindow int `json:"totp_window"` + TOTPDigits int `json:"totp_digits"` + AllowStaticPassword bool `json:"allow_static_password"` + TOTPEnabled bool `json:"totp_enabled"` + OwnerUsername string `json:"owner_username,omitempty"` +} + +func handleListUsers(w http.ResponseWriter, r *http.Request) { + if r.Method != http.MethodGet { + w.WriteHeader(http.StatusMethodNotAllowed) + return + } + + sess := sessionFromCtx(r.Context()) + states := userMgr.List() + out := make([]UserDTO, 0, len(states)) + for _, u := range states { + u.mu.Lock() + c := u.ActiveConns + cfg := u.Cfg + expires := u.ExpiresAt + u.mu.Unlock() + + // Resellers only see their own users + if sess != nil && sess.Role == RoleReseller && cfg.OwnerUsername != sess.Username { + continue + } + + out = append(out, UserDTO{ + Username: cfg.Username, + ActiveConns: c, + MaxConnections: cfg.MaxConnections, + ExpiresAt: expires, + LimitUpMbps: cfg.LimitMbpsUp, + LimitDownMbps: cfg.LimitMbpsDown, + TOTPSecret: cfg.TOTPSecret, + TOTPPeriod: cfg.TOTPPeriod, + TOTPWindow: cfg.TOTPWindow, + TOTPDigits: cfg.TOTPDigits, + AllowStaticPassword: cfg.AllowStaticPassword, + TOTPEnabled: strings.TrimSpace(cfg.TOTPSecret) != "", + OwnerUsername: cfg.OwnerUsername, + }) + } + + w.Header().Set("Content-Type", "application/json") + _ = json.NewEncoder(w).Encode(out) +} + +// NOTE: Password is pointer to distinguish "field missing" vs "field present". +type UserPayload struct { + Username string `json:"username"` + Password *string `json:"password,omitempty"` // nil or empty = keep existing if user already exists + MaxConnections int `json:"max_connections"` + ExpiresAt string `json:"expires_at"` + LimitUpMbps int `json:"limit_mbps_up"` + LimitDownMbps int `json:"limit_mbps_down"` + TOTPSecret string `json:"totp_secret"` + TOTPPeriod int `json:"totp_period"` + TOTPWindow int `json:"totp_window"` + TOTPDigits int `json:"totp_digits"` + AllowStaticPassword bool `json:"allow_static_password"` +} + +func handleCreateUser(store *Store) http.HandlerFunc { + return func(w http.ResponseWriter, r *http.Request) { + if r.Method != http.MethodPost { + w.WriteHeader(http.StatusMethodNotAllowed) + return + } + if store == nil { + http.Error(w, "database not configured", http.StatusServiceUnavailable) + return + } + + var p UserPayload + if err := json.NewDecoder(r.Body).Decode(&p); err != nil { + http.Error(w, "invalid json", http.StatusBadRequest) + return + } + if p.Username == "" { + http.Error(w, "username required", http.StatusBadRequest) + return + } + + ctx := r.Context() + + // Decide what password to use: + // - if payload has non-empty password -> use it + // - else try to read existing password from DB + // - if user exists -> keep existing password + // - if not exists and TOTP is enabled -> allow blank password + // - otherwise reject (password required for new user) + var password string + + if p.Password != nil && *p.Password != "" { + password = *p.Password + } else { + var existing string + err := store.db.QueryRowContext(ctx, + `SELECT password FROM ssh_users WHERE username = $1`, + p.Username, + ).Scan(&existing) + + if err == sql.ErrNoRows { + if strings.TrimSpace(p.TOTPSecret) == "" { + http.Error(w, "password or totp_secret required for new user", http.StatusBadRequest) + return + } + password = "" + } else if err != nil { + log.Printf("failed to load existing password for %s: %v", p.Username, err) + http.Error(w, "db error", http.StatusInternalServerError) + return + } else { + password = existing + } + } + + // Determine owner and enforce reseller quota + sess := sessionFromCtx(ctx) + ownerUsername := "" + if sess != nil && sess.Role == RoleReseller { + ownerUsername = sess.Username + // Enforce user limit — only count on new user creation + var existsInDB bool + _ = store.db.QueryRowContext(ctx, + `SELECT TRUE FROM ssh_users WHERE username=$1`, p.Username, + ).Scan(&existsInDB) + if !existsInDB { + owner, ok := adminUsers.get(sess.Username) + if ok && owner.MaxUsers > 0 && countOwnedUsers(sess.Username) >= owner.MaxUsers { + http.Error(w, fmt.Sprintf("user limit reached (%d)", owner.MaxUsers), http.StatusForbidden) + return + } + } + } + + cfg := UserConfig{ + Username: p.Username, + Password: password, + MaxConnections: p.MaxConnections, + ExpiresAt: p.ExpiresAt, + LimitMbpsUp: p.LimitUpMbps, + LimitMbpsDown: p.LimitDownMbps, + TOTPSecret: strings.TrimSpace(p.TOTPSecret), + TOTPPeriod: p.TOTPPeriod, + TOTPWindow: p.TOTPWindow, + TOTPDigits: p.TOTPDigits, + AllowStaticPassword: p.AllowStaticPassword, + OwnerUsername: ownerUsername, + } + + if err := store.UpsertUser(ctx, cfg); err != nil { + log.Printf("failed to upsert user: %v", err) + http.Error(w, "db error", http.StatusInternalServerError) + return + } + + // Force-disconnect all active sessions for this user so new config applies. + userMgr.DisconnectUser(p.Username) + + reloadUsersFromDB(ctx, store) + w.WriteHeader(http.StatusCreated) + } +} + +func handleDeleteUser(store *Store) http.HandlerFunc { + return func(w http.ResponseWriter, r *http.Request) { + if r.Method != http.MethodDelete { + w.WriteHeader(http.StatusMethodNotAllowed) + return + } + if store == nil { + http.Error(w, "database not configured", http.StatusServiceUnavailable) + return + } + + username := r.URL.Query().Get("username") + if username == "" { + http.Error(w, "username required", http.StatusBadRequest) + return + } + + ctx := r.Context() + + // Resellers may only delete their own users + sess := sessionFromCtx(ctx) + if sess != nil && sess.Role == RoleReseller { + u, ok := userMgr.Get(username) + if !ok || u.Cfg.OwnerUsername != sess.Username { + http.Error(w, "forbidden", http.StatusForbidden) + return + } + } + + if err := store.DeleteUser(ctx, username); err != nil { + log.Printf("failed to delete user: %v", err) + http.Error(w, "db error", http.StatusInternalServerError) + return + } + + // Kick any live sessions for that user + userMgr.DisconnectUser(username) + + reloadUsersFromDB(ctx, store) + w.WriteHeader(http.StatusNoContent) + } +} + +// ---------- Stats API handler ---------- + +func handleStats(w http.ResponseWriter, r *http.Request) { + if r.Method != http.MethodGet { + w.WriteHeader(http.StatusMethodNotAllowed) + return + } + stats := getCurrentStats() + w.Header().Set("Content-Type", "application/json") + _ = json.NewEncoder(w).Encode(stats) +} + +// handleDnsttStats returns the most recent DNSTT statistics snapshot. These +// counters represent activity in the last 5‑second window. The handler +// requires a GET request and is protected by the same bearer token as +// other admin API endpoints. +func handleDnsttStats(w http.ResponseWriter, r *http.Request) { + if r.Method != http.MethodGet { + w.WriteHeader(http.StatusMethodNotAllowed) + return + } + // Obtain a copy of the current snapshot. + stats := GetDNSTTStatsSnapshot() + w.Header().Set("Content-Type", "application/json") + _ = json.NewEncoder(w).Encode(stats) +} + +// handleDnsttLogs returns the recent DNSTT log lines as a JSON array of strings. +// If DNSTT logging has not been initialised, it returns an empty array. The +// handler requires a GET request and uses the same bearer token as other +// admin endpoints. +func handleDnsttLogs(w http.ResponseWriter, r *http.Request) { + if r.Method != http.MethodGet { + w.WriteHeader(http.StatusMethodNotAllowed) + return + } + lines := getDNSTTLogLines() + w.Header().Set("Content-Type", "application/json") + _ = json.NewEncoder(w).Encode(lines) +} + +// ---------- Auth helpers ---------- + +func normalizeBase32Secret(secret string) string { + secret = strings.ToUpper(strings.TrimSpace(secret)) + secret = strings.ReplaceAll(secret, " ", "") + return secret +} + +func totpDigitsOrDefault(n int) int { + if n < 6 || n > 8 { + return 6 + } + return n +} + +func totpPeriodOrDefault(n int) int64 { + if n <= 0 { + return 60 + } + return int64(n) +} + +func totpWindowOrDefault(n int) int { + if n < 0 { + return 0 + } + if n == 0 { + return 1 + } + return n +} + +func pow10(n int) uint32 { + v := uint32(1) + for i := 0; i < n; i++ { + v *= 10 + } + return v +} + +func generateTOTPCode(secret string, ts time.Time, period int64, digits int) (string, error) { + secret = normalizeBase32Secret(secret) + key, err := base32.StdEncoding.WithPadding(base32.NoPadding).DecodeString(secret) + if err != nil { + return "", fmt.Errorf("decode TOTP secret: %w", err) + } + counter := uint64(ts.Unix() / period) + var buf [8]byte + binary.BigEndian.PutUint64(buf[:], counter) + h := hmac.New(sha1.New, key) + _, _ = h.Write(buf[:]) + sum := h.Sum(nil) + offset := sum[len(sum)-1] & 0x0f + code := (uint32(sum[offset])&0x7f)<<24 | + uint32(sum[offset+1])<<16 | + uint32(sum[offset+2])<<8 | + uint32(sum[offset+3]) + mod := pow10(digits) + return fmt.Sprintf("%0*d", digits, code%mod), nil +} + +func matchTOTPPassword(u *UserState, supplied string, now time.Time) bool { + secret := strings.TrimSpace(u.Cfg.TOTPSecret) + if secret == "" { + return false + } + period := totpPeriodOrDefault(u.Cfg.TOTPPeriod) + digits := totpDigitsOrDefault(u.Cfg.TOTPDigits) + window := totpWindowOrDefault(u.Cfg.TOTPWindow) + for i := -window; i <= window; i++ { + stepTime := now.Add(time.Duration(i) * time.Duration(period) * time.Second) + code, err := generateTOTPCode(secret, stepTime, period, digits) + if err != nil { + log.Printf("user %s has invalid TOTP configuration: %v", u.Cfg.Username, err) + return false + } + if supplied == code { + return true + } + } + return false +} + +// ---------- Auth callbacks ---------- + +func passwordCallback(meta ssh.ConnMetadata, pass []byte) (*ssh.Permissions, error) { + u, ok := userMgr.Get(meta.User()) + if !ok { + return nil, fmt.Errorf("authentication failed") + } + now := time.Now() + if u.ExpiresAt != nil && now.After(*u.ExpiresAt) { + log.Printf("user %s tried to connect but account is expired", meta.User()) + return nil, fmt.Errorf("account expired") + } + if err := ownerIsActive(u.Cfg.OwnerUsername); err != nil { + return nil, fmt.Errorf("authentication failed: %w", err) + } + supplied := string(pass) + if strings.TrimSpace(u.Cfg.TOTPSecret) != "" { + if matchTOTPPassword(u, supplied, now) { + return nil, nil + } + if u.Cfg.AllowStaticPassword && u.Cfg.Password == supplied { + return nil, nil + } + return nil, fmt.Errorf("authentication failed") + } + if u.Cfg.Password != supplied { + return nil, fmt.Errorf("authentication failed") + } + return nil, nil +} + +func publicKeyCallback(meta ssh.ConnMetadata, key ssh.PublicKey) (*ssh.Permissions, error) { + u, ok := userMgr.Get(meta.User()) + if !ok { + return nil, fmt.Errorf("authentication failed") + } + if u.ExpiresAt != nil && time.Now().After(*u.ExpiresAt) { + log.Printf("user %s tried to connect but account is expired", meta.User()) + return nil, fmt.Errorf("account expired") + } + if err := ownerIsActive(u.Cfg.OwnerUsername); err != nil { + return nil, fmt.Errorf("authentication failed: %w", err) + } + if u.PubKey == nil { + return nil, fmt.Errorf("no public key configured") + } + if !publicKeysEqual(u.PubKey, key) { + return nil, fmt.Errorf("authentication failed") + } + return nil, nil +} + +// ---------- Connection handling ---------- + +func handleConn(tcpConn net.Conn, config *ssh.ServerConfig) { + defer tcpConn.Close() + + // Prevent goroutine leaks from clients that connect but never complete the SSH handshake. + _ = tcpConn.SetReadDeadline(time.Now().Add(sshHandshakeTimeout)) + + sshConn, chans, reqs, err := ssh.NewServerConn(tcpConn, config) + if err != nil { + log.Println("ssh handshake failed:", err) + return + } + // Clear deadlines after a successful handshake. + _ = tcpConn.SetDeadline(time.Time{}) + username := sshConn.User() + log.Printf("new SSH connection from %s as %s", sshConn.RemoteAddr(), username) + + u, ok := userMgr.Get(username) + if !ok { + log.Printf("user %s not found in config (should not happen if auth is correct)", username) + sshConn.Close() + return + } + + // Track active connection and enforce max_connections + u.mu.Lock() + if u.Cfg.MaxConnections > 0 && u.ActiveConns >= u.Cfg.MaxConnections { + u.mu.Unlock() + log.Printf("user %s exceeded max_connections (%d)", username, u.Cfg.MaxConnections) + sshConn.Close() + return + } + u.ActiveConns++ + if u.conns == nil { + u.conns = make(map[*ssh.ServerConn]struct{}) + } + u.conns[sshConn] = struct{}{} + u.mu.Unlock() + + // Use per-user limit if set; otherwise fall back to the global config default. + limitUp, limitDown := u.Cfg.LimitMbpsUp, u.Cfg.LimitMbpsDown + if limitUp == 0 || limitDown == 0 { + defUp, defDown := getDefaultLimits() + if limitUp == 0 { + limitUp = defUp + } + if limitDown == 0 { + limitDown = defDown + } + } + var upLimiter, downLimiter *rate.Limiter + if limitUp > 0 { + bps := mbpsToBytesPerSec(limitUp) + upLimiter = rate.NewLimiter(rate.Limit(bps), int(bps)) + } + if limitDown > 0 { + bps := mbpsToBytesPerSec(limitDown) + downLimiter = rate.NewLimiter(rate.Limit(bps), int(bps)) + } + + updateUserDisplay() + + defer func() { + u.mu.Lock() + u.ActiveConns-- + delete(u.conns, sshConn) + u.mu.Unlock() + + updateUserDisplay() + sshConn.Close() + }() + + go ssh.DiscardRequests(reqs) + + for newChan := range chans { + switch newChan.ChannelType() { + case "direct-tcpip": + go handleDirectTCPIP(newChan, u, upLimiter, downLimiter) + case "session": + go handleDummySession(newChan) + default: + newChan.Reject(ssh.UnknownChannelType, "unsupported channel type") + } + } +} + +type directTCPIPReq struct { + Host string + Port uint32 + OriginAddr string + OriginPort uint32 +} + +func handleDirectTCPIP(newChan ssh.NewChannel, u *UserState, upLimiter, downLimiter *rate.Limiter) { + var req directTCPIPReq + if err := ssh.Unmarshal(newChan.ExtraData(), &req); err != nil { + newChan.Reject(ssh.Prohibited, "bad direct-tcpip request") + return + } + + target := fmt.Sprintf("%s:%d", req.Host, req.Port) + log.Printf("direct-tcpip: user=%s connecting to %s from %s:%d", + u.Cfg.Username, target, req.OriginAddr, req.OriginPort) + + dialer := &net.Dialer{Timeout: directTCPIPDialTimeout, KeepAlive: 30 * time.Second} + backend, err := dialer.Dial("tcp", target) + if err != nil { + log.Printf("failed to connect to %s: %v", target, err) + newChan.Reject(ssh.ConnectionFailed, err.Error()) + return + } + + ch, reqs, err := newChan.Accept() + if err != nil { + _ = backend.Close() + return + } + + go func() { + defer ch.Close() + for req := range reqs { + if req.WantReply { + req.Reply(false, nil) + } + } + }() + + // Use a WaitGroup + shared closer so that when either copy direction + // finishes (including a half-close that never completes), both sides + // are force-closed and the other goroutine is unblocked. Without + // this, a backend that issues CloseWrite but never closes the read + // side would leave the downstream goroutine blocked indefinitely, + // leaking a goroutine, a rate-limiter, and the SSH channel. + var wg sync.WaitGroup + closeAll := func() { + _ = backend.Close() + _ = ch.Close() + } + + // upstream: SSH channel -> backend + wg.Add(1) + go func() { + defer wg.Done() + _, _ = copyWithRateLimit(backend, ch, upLimiter) + // Signal to the backend that we are done writing. + if cw, ok := backend.(interface{ CloseWrite() error }); ok { + _ = cw.CloseWrite() + } + closeAll() + }() + + // downstream: backend -> SSH channel + wg.Add(1) + go func() { + defer wg.Done() + _, _ = copyWithRateLimit(ch, backend, downLimiter) + closeAll() + }() + + // Wait for both goroutines to finish, then ensure everything is closed. + go func() { + wg.Wait() + closeAll() + }() +} + +func handleDummySession(newChan ssh.NewChannel) { + ch, reqs, err := newChan.Accept() + if err != nil { + return + } + + go func() { + defer ch.Close() + for req := range reqs { + if req.WantReply { + req.Reply(false, nil) + } + } + }() +} + +// ---------- UI / display ---------- + +func updateUserDisplay() { + if !userCountEnabled { + return + } + + userStates := userMgr.List() + parts := make([]string, 0, len(userStates)) + for _, u := range userStates { + u.mu.Lock() + c := u.ActiveConns + name := u.Cfg.Username + u.mu.Unlock() + parts = append(parts, fmt.Sprintf("%s: %d", name, c)) + } + + sort.Strings(parts) + line := strings.Join(parts, " ") + if line == "" { + line = "no users" + } + + displayMu.Lock() + defer displayMu.Unlock() + + if len(line) < lastDisplayLen { + line = line + strings.Repeat(" ", lastDisplayLen-len(line)) + } + lastDisplayLen = len(line) + + fmt.Printf("\r%s", line) +} + +// ---------- Main ---------- + +// ======================== +// Port-80 HTTP-clean + SSH +// ======================== + +// bufferedConn is a net.Conn wrapper whose Read() is served by a bufio.Reader. +// This preserves any bytes already buffered/peeked during HTTP cleanup. +type bufferedConn struct { + net.Conn + r *bufio.Reader +} + +func (c *bufferedConn) Read(p []byte) (int, error) { return c.r.Read(p) } + +// handleHTTP80Conn applies the same "HTTP injection cleanup" that proxy.go did, +// then hands the cleaned connection directly to the SSH server handler, avoiding +// the extra io.Copy hop to a separate backend SSH process. +func handleHTTP80Conn(raw net.Conn, sshConfig *ssh.ServerConfig) { + // Best-effort TCP tuning (no hard dependency on TCPConn) + if tc, ok := raw.(*net.TCPConn); ok { + _ = tc.SetKeepAlive(true) + _ = tc.SetKeepAlivePeriod(30 * time.Second) + _ = tc.SetNoDelay(true) + } + + // 101 response up-front (keeps behavior identical to the old proxy) + status := "Switching Protocols" + _, _ = raw.Write([]byte(fmt.Sprintf("HTTP/1.1 101 %s\r\n\r\n", status))) + + skip200 := false + br := bufio.NewReaderSize(raw, 32<<10) + + // Drain chained HTTP header blocks with a short rolling deadline so Peek/ReadBytes never stalls. + cleanWindow := 30 * time.Second + _ = raw.SetReadDeadline(time.Now().Add(cleanWindow)) + + for { + // swallow stray CR/LF between chained blocks + if skipped, _ := eatLeadingEOL(br); skipped > 0 { + log.Printf("[CLEAN] skipped %d EOL bytes between blocks", skipped) + } + + if lit, _ := eatLiteralBackslashCRLF(br); lit > 0 { + log.Printf("[CLEAN] skipped %d bytes of literal \"\\r\\n\"", lit) + // keep rolling the deadline since we consumed data + _ = raw.SetReadDeadline(time.Now().Add(cleanWindow)) + // continue to try cleaning more + continue + } + if litN, _ := eatLiteralBackslashLF(br); litN > 0 { + log.Printf("[CLEAN] skipped %d bytes of literal \"\\n\"", litN) + _ = raw.SetReadDeadline(time.Now().Add(cleanWindow)) + continue + } + + // peek a bit more so we can classify reliably + size := 8 + if b := br.Buffered(); b > 0 && b < size { + size = b + } + if size < 1 { + size = 1 + } + + peek, err := br.Peek(size) + if err != nil { + if ne, ok := err.(net.Error); ok && ne.Timeout() { + log.Printf("[CLEAN] timeout; pass-through next=%s", peekPreview(br, 64)) + break + } + if err == io.EOF { + // Normal when payload stops. Proceed. + log.Printf("[CLEAN] EOF; nothing left") + break + } + log.Println("[CLEAN] peek error:", err) + break + } + + // Fast path: HTTP-ish at offset 0 (request-line OR status-line) + if maybeHTTPStartPrefix(peek) { + firstLine, headers, err := discardAndCaptureHTTP(br) + if err != nil { + log.Println("[CLEAN] discard error:", err) + break + } + if headerHasValue(headers, "npv-x", "ok") { + cleanWindow = 2 * time.Second + _ = raw.SetReadDeadline(time.Now().Add(cleanWindow)) + skip200 = true + status = "OK" + _, err = raw.Write([]byte(fmt.Sprintf("HTTP/1.1 200 %s\r\n\r\n", status))) + if err != nil { + log.Println("Failed to write 200 OK:", err) + return + } + log.Printf("[CLEAN] NPV-X: OK; promote clean window to %s", cleanWindow) + } + if bytes.Contains(bytes.ToLower(headers), []byte("x-skip: 200")) { + skip200 = true + log.Printf("[CLEAN] detected and stripped X-Skip: 200 header") + } + log.Printf("[CLEAN] discarded HTTP block: %q", firstLine) + _ = raw.SetReadDeadline(time.Now().Add(cleanWindow)) + continue + } + + // Slow path: scan the already-buffered bytes for a later HTTP token + if bufN := br.Buffered(); bufN > 0 { + pbuf, _ := br.Peek(bufN) + + // If we see SSH banner soon, stop cleaning now + if iSSH := bytes.Index(pbuf, []byte("SSH-")); iSSH >= 0 && iSSH < 64 { + log.Printf("[CLEAN] SSH banner detected at +%d; stop cleaning", iSSH) + log.Printf("[CLEAN] pass-through begins; next=%s", peekPreview(br, 64)) + break + } + + if pos := findHTTPStartIndex(pbuf); pos > 0 { + // 1) discard junk prefix before the HTTP token + noise := make([]byte, pos) + if _, err := io.ReadFull(br, noise); err != nil { + log.Printf("[CLEAN] error discarding noise prefix: %v", err) + break + } + log.Printf("[CLEAN] discarded NOISE prefix (%d bytes): %q (hex:%s)", + pos, asciiPreview(noise, 64), hexPreview(noise, 64)) + + // 2) discard the HTTP block starting at current position + firstLine, err := discardOneHTTP(br) + if err != nil { + log.Printf("[CLEAN] discard error after noise: %v", err) + break + } + log.Printf("[CLEAN] discarded HTTP block after noise: %q", firstLine) + _ = raw.SetReadDeadline(time.Now().Add(cleanWindow)) + continue + } + + // No HTTP tokens; pass-through whatever remains + log.Printf("[CLEAN] pass-through begins; next=%s", peekPreview(br, 64)) + } else { + log.Printf("[CLEAN] pass-through (no buffered data)") + } + break + } + + if skip200 { + log.Printf("[SKIP] X-Skip: 200 detected earlier; not sending 200 OK") + } else { + status = "OK" + _, err := raw.Write([]byte(fmt.Sprintf("HTTP/1.1 200 %s\r\n\r\n", status))) + if err != nil { + log.Println("Failed to write 200 OK:", err) + return + } + } + + // Clear deadline before SSH handshake begins. + _ = raw.SetReadDeadline(time.Time{}) + + // Hand off to SSH server, ensuring any buffered bytes are preserved. + handleConn(&bufferedConn{Conn: raw, r: br}, sshConfig) +} + +// serveHTTP80 accepts connections on ln and dispatches each to +// handleHTTP80Conn. It logs accept errors but otherwise runs +// indefinitely. This helper is used to support multiple listen +// addresses without duplicating accept loop code. +func serveHTTP80(ln net.Listener) { + for { + tcpConn, err := ln.Accept() + if err != nil { + if isListenerClosed(err) { + return + } + log.Printf("accept error on %s: %v", ln.Addr().String(), err) + continue + } + go handleHTTP80Conn(tcpConn, getSSHConfig()) + } +} + +// serveRawSSH accepts connections on ln and passes each directly into +// handleConn. Unlike serveHTTP80 it does not perform HTTP cleanup +// and starts the SSH handshake immediately. This is used for +// local-only listeners that other daemons (e.g. DNSTT) dial into. +func serveRawSSH(ln net.Listener) { + for { + c, err := ln.Accept() + if err != nil { + if isListenerClosed(err) { + return + } + log.Printf("local accept error on %s: %v", ln.Addr().String(), err) + continue + } + go handleConn(c, getSSHConfig()) + } +} + +// serveTLSSSH accepts TLS-wrapped connections on ln. It forces a +// TLS handshake on each accepted connection so handshake failures are +// surfaced early, then passes the decrypted stream to handleConn. +// There is no backend dial or data copying; the SSH handler runs +// directly on the TLS stream. +func serveTLSSSH(ln net.Listener) { + for { + c, err := ln.Accept() + if err != nil { + if isListenerClosed(err) { + return + } + log.Printf("tls accept error on %s: %v", ln.Addr().String(), err) + continue + } + go func(client net.Conn) { + _ = client.SetDeadline(time.Now().Add(tlsHandshakeTimeout)) + if tc, ok := client.(*tls.Conn); ok { + if err := tc.Handshake(); err != nil { + log.Printf("tls handshake error from %s: %v", client.RemoteAddr().String(), err) + _ = client.Close() + return + } + } + _ = client.SetDeadline(time.Time{}) + handleConn(client, getSSHConfig()) + }(c) + } +} + +var httpMethods = [][]byte{ + // HTTP/1.1 + PATCH + []byte("GET "), []byte("HEAD "), []byte("POST "), []byte("PUT "), + []byte("DELETE "), []byte("CONNECT "), []byte("OPTIONS "), []byte("TRACE "), + []byte("PATCH "), + + // WebDAV (RFC 4918) + []byte("PROPFIND "), []byte("PROPPATCH "), []byte("MKCOL "), + []byte("COPY "), []byte("MOVE "), []byte("LOCK "), []byte("UNLOCK "), + + // WebDAV SEARCH (RFC 5323) + []byte("SEARCH "), + + // WebDAV ACL (RFC 3744) + []byte("ACL "), + + // WebDAV Versioning/DeltaV (RFC 3253) + []byte("VERSION-CONTROL "), []byte("REPORT "), []byte("CHECKOUT "), + []byte("CHECKIN "), []byte("UNCHECKOUT "), []byte("MKWORKSPACE "), + []byte("UPDATE "), []byte("LABEL "), []byte("MERGE "), + []byte("BASELINE-CONTROL "), []byte("MKACTIVITY "), + + // WebDAV Ordering (RFC 3648) + []byte("ORDERPATCH "), + + // WebDAV Bindings (RFC 5842) + []byte("BIND "), []byte("UNBIND "), []byte("REBIND "), + + // CalDAV (RFC 4791) + []byte("MKCALENDAR "), + + []byte("x-skip: "), + []byte("npv-x: "), +} + +func eatLeadingEOL(br *bufio.Reader) (int, error) { + n := 0 + for { + b, err := br.Peek(1) + if err != nil { + return n, err + } + if b[0] != '\r' && b[0] != '\n' { + return n, nil + } + if _, err := br.ReadByte(); err != nil { + return n, err + } + n++ + } +} + +func asciiPreview(b []byte, max int) string { + if len(b) > max { + b = b[:max] + } + out := make([]byte, len(b)) + for i, c := range b { + if c < 32 || c == 127 { + out[i] = '.' + } else { + out[i] = c + } + } + return string(out) +} + +func peekPreview(br *bufio.Reader, max int) string { + n := br.Buffered() + if n < 1 { + n = max + } + if n > max { + n = max + } + p, _ := br.Peek(n) + return fmt.Sprintf("%q (hex:% X)", asciiPreview(p, n), p) +} + +func httpStartTokens() [][]byte { + toks := [][]byte{ + []byte("HTTP/"), []byte("HTTP/1."), []byte("HTTP/2"), []byte("HTTP/3"), + []byte("PRI "), []byte("ICY "), + } + toks = append(toks, httpMethods...) + toks = append(toks, []byte("CONNECT")) // bare + return toks +} + +func findHTTPStartIndex(b []byte) int { + if len(b) == 0 { + return -1 + } + upper := bytes.ToUpper(b) + min := -1 + for _, t := range httpStartTokens() { + tu := bytes.ToUpper(t) + if idx := bytes.Index(upper, tu); idx >= 0 { + if min == -1 || idx < min { + min = idx + } + } + } + return min +} + +func hexPreview(b []byte, max int) string { + if len(b) > max { + b = b[:max] + } + return fmt.Sprintf("% X", b) +} + +func eatLiteralBackslashCRLF(br *bufio.Reader) (int, error) { + count := 0 + for { + p, err := br.Peek(4) + if err != nil { + // If fewer than 4 bytes buffered, nothing more to eat. + return count, nil + } + if len(p) >= 4 && p[0] == '\\' && p[1] == 'r' && p[2] == '\\' && p[3] == 'n' { + if _, err := br.Discard(4); err != nil { + return count, err + } + count += 4 + continue + } + return count, nil + } +} + +func eatLiteralBackslashLF(br *bufio.Reader) (int, error) { + p, err := br.Peek(2) + if err != nil || len(p) < 2 { + return 0, nil + } + if p[0] == '\\' && p[1] == 'n' { + if _, err := br.Discard(2); err != nil { + return 0, err + } + return 2, nil + } + return 0, nil +} + +func headerHasValue(headers []byte, key, want string) bool { + lh := bytes.ToLower(headers) + keyb := []byte(strings.ToLower(key)) + wantb := []byte(strings.ToLower(want)) + + // find ":" (allow spaces before colon too) + i := bytes.Index(lh, keyb) + for i >= 0 { + rest := lh[i+len(keyb):] + rest = bytes.TrimLeft(rest, " \t") + if len(rest) > 0 && rest[0] == ':' { + // slice to end-of-line + line := rest[1:] + if j := bytes.IndexByte(line, '\n'); j >= 0 { + line = line[:j] + } + line = bytes.TrimSpace(bytes.TrimSuffix(line, []byte{'\r'})) + return bytes.Equal(line, wantb) + } + // keep searching (avoid false positives on substrings) + next := lh[i+len(keyb):] + j := bytes.Index(next, keyb) + if j < 0 { + return false + } + i += len(keyb) + j + } + return false +} + +func discardAndCaptureHTTP(br *bufio.Reader) (string, []byte, error) { + const hardCap = 64 << 10 // 64 KiB + var buf bytes.Buffer + total := 0 + + first, err := br.ReadSlice('\n') + if err != nil { + return "", nil, err + } + buf.Write(first) + total += len(first) + if total > hardCap { + return "", nil, fmt.Errorf("http header too large") + } + + for { + line, err := br.ReadSlice('\n') + if err != nil { + return "", nil, err + } + total += len(line) + if total > hardCap { + return "", nil, fmt.Errorf("http header too large") + } + buf.Write(line) + if (len(line) == 1 && line[0] == '\n') || + (len(line) == 2 && line[0] == '\r' && line[1] == '\n') { + break + } + } + fl := strings.TrimRight(string(first), "\r\n") + return fl, buf.Bytes(), nil +} + +func discardOneHTTP(br *bufio.Reader) (string, error) { + const hardCap = 64 << 10 // 64 KiB + total := 0 + + // capture first line for printing + first, err := br.ReadSlice('\n') + if err != nil { + return "", err + } + total += len(first) + if total > hardCap { + return "", fmt.Errorf("http header too large") + } + + for { + line, err := br.ReadSlice('\n') + if err != nil { + return "", err + } + total += len(line) + if total > hardCap { + return "", fmt.Errorf("http header too large") + } + + // End of headers: "\n" or "\r\n" + if (len(line) == 1 && line[0] == '\n') || + (len(line) == 2 && line[0] == '\r' && line[1] == '\n') { + break + } + } + // Trim trailing CR/LF for nicer logs + fl := strings.TrimRight(string(first), "\r\n") + return fl, nil +} + +func maybeHTTPStartPrefix(b []byte) bool { + u := bytes.ToUpper(b) + + // Helper: partial prefix match + startsWithPartial := func(s string) bool { + su := []byte(s) + if len(u) <= len(su) { + return bytes.Equal(u, su[:len(u)]) + } + return bytes.HasPrefix(u, su) + } + + // Status-lines / prefaces (allow partial) + if startsWithPartial("HTTP/") || startsWithPartial("HTTP/1.") || + startsWithPartial("HTTP/2") || startsWithPartial("HTTP/3") || + startsWithPartial("PRI ") || startsWithPartial("ICY ") { + return true + } + + // Request-lines (methods) with partial match + for _, m := range httpMethods { + um := bytes.ToUpper(m) + if len(u) <= len(um) && bytes.Equal(u, um[:len(u)]) { // partial ok + return true + } + if len(u) >= len(um) && bytes.HasPrefix(u, um) { + return true + } + } + return false +} + +func main() { + configPath := flag.String("config", "config.json", "path to JSON config file") + quietFlag := flag.Bool("quiet", false, "override config and disable logs") + userCountFlag := flag.Bool("usercount", false, "show per-user connection counters (single line)") + flag.Parse() + + cfg, userMap, err := loadConfig(*configPath) + if err != nil { + log.Fatalf("failed to load config: %v", err) + } + userMgr.ReplaceAll(userMap) + + // Store config path and live config for hot-reload via admin API. + globalCfgPath = *configPath + setGlobalCfg(cfg) + + userCountEnabled = cfg.UserCount || *userCountFlag + + var store *Store + pgDSN := os.Getenv("PG_DSN") + if pgDSN != "" { + store, err = NewStore(pgDSN) + if err != nil { + log.Fatalf("failed to connect to postgres: %v", err) + } + + ctx := context.Background() + dbUsers, err := store.LoadUsers(ctx) + if err != nil { + log.Fatalf("failed to load users from database: %v", err) + } + if len(dbUsers) > 0 { + userMgr.ReplaceAll(dbUsers) + } + } + + // Bootstrap admin users and reseller expiry checker. + if store != nil { + ctx := context.Background() + if pw, err := store.BootstrapSuperAdmin(ctx); err != nil { + log.Printf("failed to bootstrap superadmin: %v", err) + } else if pw != "" { + log.Printf("=== FIRST RUN: superadmin created — username: admin password: %s ===", pw) + } + if err := loadAdminUsersIntoCache(ctx, store); err != nil { + log.Printf("failed to load admin users: %v", err) + } + startResellerExpiryChecker(store) + } + + // Optional: initialize interface totals persistence (best-effort). + if store != nil { + statsStore = store + ctx := context.Background() + if err := store.EnsureXrayClientsSchema(ctx); err != nil { + log.Printf("xray clients table: %v", err) + } else { + startXrayClientExpiryChecker(store) + } + if err := store.EnsureIfaceTotalsTable(ctx); err == nil { + rows, err2 := store.LoadIfaceTotals(ctx) + if err2 == nil { + ifaceTotalsMgr = NewIfaceTotalsManager() + ifaceTotalsMgr.Load(rows) + + // Bootstrap with current kernel counters once at startup + if netMap, err3 := readNetDev(); err3 == nil { + for name, ctrs := range netMap { + ifaceTotalsMgr.ApplyKernel(name, ctrs.RxBytes, ctrs.TxBytes) + } + } + } else { + log.Printf("iface totals persistence disabled: %v", err2) + } + } else { + log.Printf("iface totals persistence disabled: %v", err) + } + } + + // start background collector for CPU + interface stats + startStatsCollector() + + adminAddr := os.Getenv("ADMIN_HTTP_ADDR") + if adminAddr == "" { + adminAddr = "0.0.0.0:9090" + } + + // Determine admin directory: from config or default ./admin + adminDir := cfg.AdminDir + if adminDir == "" { + adminDir = "./admin" + } + startAdminAPI(store, adminAddr, adminDir) + + // Start the integrated Xray-core subprocess if configured. + initXrayManager(cfg.Xray) + + // Global banner text (from config or file) — stored in a global so the + // admin API can update it on the fly without a restart. + { + bt := cfg.Banner + if bt == "" && cfg.BannerFile != "" { + if data, err := os.ReadFile(cfg.BannerFile); err == nil { + bt = string(data) + } else { + log.Printf("failed to read banner file %s: %v", cfg.BannerFile, err) + } + } + setBannerText(bt) + } + + sshConfig := &ssh.ServerConfig{ + PasswordCallback: passwordCallback, + PublicKeyCallback: publicKeyCallback, + NoClientAuth: false, + } + + // Optional compatibility mode for legacy SSH clients (weak algorithms). + // Enable only if you must support older embedded/mobile SSH stacks. + // Usage: SSH_LEGACY=1 ./server + if os.Getenv("SSH_LEGACY") == "1" { + sshConfig.Config = ssh.Config{ + Ciphers: []string{ + // Modern first + "chacha20-poly1305@openssh.com", + "aes128-ctr", "aes192-ctr", "aes256-ctr", + // Legacy (weak) + "aes128-cbc", "aes192-cbc", "aes256-cbc", + "3des-cbc", + }, + KeyExchanges: []string{ + // Modern first + "curve25519-sha256", "curve25519-sha256@libssh.org", + "diffie-hellman-group-exchange-sha256", + // Legacy (weak) + "diffie-hellman-group14-sha1", + "diffie-hellman-group1-sha1", + }, + MACs: []string{ + // Modern first + "hmac-sha2-256", "hmac-sha2-512", + // Legacy (weak) + "hmac-sha1", + }, + } + log.Println("[SSH] legacy crypto enabled (weak algorithms allowed)") + } + + // Fancy per-user banner: + // - Optional global banner first + // - Then block: + // Username: user + // ----------------- + // Expiration: 10/12/2025 + // ----------------- + // Max Upload: 10 Mbps + // ----------------- + // Max Download: 10 Mbps + sshConfig.BannerCallback = func(meta ssh.ConnMetadata) string { + var sb strings.Builder + + // Global / custom banner (reads live value so admin-panel edits apply immediately) + if bt := getBannerText(); bt != "" { + sb.WriteString(bt) + if !strings.HasSuffix(bt, "\n") { + sb.WriteString("\n") + } + sb.WriteString("\n") // extra blank line before user block + } + + u, ok := userMgr.Get(meta.User()) + if !ok { + // If user not found (should be rare), just return global banner + return sb.String() + } + + u.mu.Lock() + exp := u.ExpiresAt + up := u.Cfg.LimitMbpsUp + down := u.Cfg.LimitMbpsDown + name := u.Cfg.Username + u.mu.Unlock() + + // BR-style date: dd/MM/yyyy + expStr := "Sem limite" + if exp != nil { + expStr = exp.Local().Format("02/01/2006") + } + sb.WriteString("\n
-----------------
\n") + sb.WriteString("Account Information") + sb.WriteString("\n
-----------------
\n") + sb.WriteString("Username: ") + sb.WriteString(name) + sb.WriteString("\n
-----------------
\n") + + sb.WriteString("Expiration: ") + sb.WriteString(expStr) + sb.WriteString("\n
-----------------
\n") + + sb.WriteString("Max Upload: ") + sb.WriteString(strconv.Itoa(up)) + sb.WriteString(" Mbps") + sb.WriteString("\n
-----------------
\n") + sb.WriteString("Max Download: ") + sb.WriteString(strconv.Itoa(down)) + sb.WriteString(" Mbps") + sb.WriteString("\n
-----------------
\n") + return sb.String() + } + + hostKeyBytes, err := os.ReadFile(cfg.HostKeyFile) + if err != nil { + log.Fatalf("failed to load host key %s: %v", cfg.HostKeyFile, err) + } + hostKey, err := ssh.ParsePrivateKey(hostKeyBytes) + if err != nil { + log.Fatalf("failed to parse host key: %v", err) + } + + // Enforce RSA host key (ssh-rsa) for maximum compatibility with legacy SSH clients. + // If you need ed25519 or other key types, remove this check. + if hostKey.PublicKey().Type() != ssh.KeyAlgoRSA { + log.Fatalf("host key %s is %s; this build requires RSA (ssh-rsa). Generate an RSA key with: ssh-keygen -t rsa -b 2048 -f %s -N \"\"", cfg.HostKeyFile, hostKey.PublicKey().Type(), cfg.HostKeyFile) + } + sshConfig.AddHostKey(hostKey) + + // Store SSH config globally so hot-reloaded serve loops always use the latest. + setSSHConfig(sshConfig) + + // Logging setup. + quietLogs := cfg.Quiet || *quietFlag || userCountEnabled + if quietLogs { + log.SetOutput(io.Discard) + } + + // Initialise default per-connection bandwidth limits. + setDefaultLimits(cfg.DefaultLimitMbpsUp, cfg.DefaultLimitMbpsDown) + + // Start the integrated DNSTT and UDPGW if configured. + startDNSTT(cfg.DNSTT, sshConfig) + startUDPGW(cfg.UDPGW) + + // Initialise listener pools (used for initial startup and hot-reload alike). + publicPool = newListenerPool(serveHTTP80) + localPool = newListenerPool(serveRawSSH) + tlsPool = newTLSListenerPool() + + // Start public SSH listeners (listen + extra_listen). + publicAddrs := append([]string{cfg.Listen}, cfg.ExtraListen...) + for _, e := range publicPool.Sync(publicAddrs) { + log.Fatalf("failed to start listener: %v", e) + } + + // Start local raw SSH listener if configured. + if cfg.LocalSSHListen != "" { + for _, e := range localPool.Sync([]string{cfg.LocalSSHListen}) { + log.Fatalf("failed to start local SSH listener: %v", e) + } + } + + // Start TLS forwarder listeners if configured. + for _, e := range tlsPool.Sync(cfg.TLSForwarders) { + log.Fatalf("failed to start TLS listener: %v", e) + } + + // Print user counts once at startup. + updateUserDisplay() + + // Block forever. Individual accept loops run in goroutines. + select {} +} diff --git a/server_config_api.go b/server_config_api.go new file mode 100644 index 0000000..ad32a02 --- /dev/null +++ b/server_config_api.go @@ -0,0 +1,117 @@ +package main + +import ( + "encoding/json" + "io" + "net/http" + "os" + "sync" +) + +// Note: applyFullConfigReload is defined in hotreload.go + +// ---------- Global config holder ---------- + +var ( + globalCfgMu sync.RWMutex + globalCfg *Config + globalCfgPath string // set in main() from the -config flag + + bannerMu sync.RWMutex + currentBannerText string +) + +func setGlobalCfg(c *Config) { + globalCfgMu.Lock() + globalCfg = c + globalCfgMu.Unlock() +} + +func getGlobalCfg() *Config { + globalCfgMu.RLock() + defer globalCfgMu.RUnlock() + return globalCfg +} + +func setBannerText(s string) { + bannerMu.Lock() + currentBannerText = s + bannerMu.Unlock() +} + +func getBannerText() string { + bannerMu.RLock() + defer bannerMu.RUnlock() + return currentBannerText +} + +// ---------- HTTP handler ---------- + +// handleServerConfig dispatches GET (read) and POST (write) for /api/server/config. +func handleServerConfig(w http.ResponseWriter, r *http.Request) { + switch r.Method { + case http.MethodGet: + serverConfigGet(w, r) + case http.MethodPost: + serverConfigPost(w, r) + default: + w.WriteHeader(http.StatusMethodNotAllowed) + } +} + +func serverConfigGet(w http.ResponseWriter, _ *http.Request) { + if globalCfgPath == "" { + http.Error(w, "config path not set", http.StatusInternalServerError) + return + } + data, err := os.ReadFile(globalCfgPath) + if err != nil { + http.Error(w, "failed to read config: "+err.Error(), http.StatusInternalServerError) + return + } + w.Header().Set("Content-Type", "application/json") + _, _ = w.Write(data) +} + +func serverConfigPost(w http.ResponseWriter, r *http.Request) { + if globalCfgPath == "" { + http.Error(w, "config path not set", http.StatusInternalServerError) + return + } + body, err := io.ReadAll(io.LimitReader(r.Body, 512*1024)) + if err != nil { + http.Error(w, "failed to read body", http.StatusBadRequest) + return + } + var newCfg Config + if err := json.Unmarshal(body, &newCfg); err != nil { + http.Error(w, "invalid JSON: "+err.Error(), http.StatusBadRequest) + return + } + if newCfg.Listen == "" { + http.Error(w, "listen address required", http.StatusBadRequest) + return + } + + // Preserve file-based users array (not editable through the UI). + globalCfgMu.RLock() + if globalCfg != nil { + newCfg.Users = globalCfg.Users + } + globalCfgMu.RUnlock() + + out, err := json.MarshalIndent(newCfg, "", " ") + if err != nil { + http.Error(w, "marshal error", http.StatusInternalServerError) + return + } + if err := os.WriteFile(globalCfgPath, out, 0o644); err != nil { + http.Error(w, "failed to write config: "+err.Error(), http.StatusInternalServerError) + return + } + + // Apply all changes live — no restart needed. + applyFullConfigReload(&newCfg) + + w.WriteHeader(http.StatusOK) +} diff --git a/tls_api.go b/tls_api.go new file mode 100644 index 0000000..a6f67aa --- /dev/null +++ b/tls_api.go @@ -0,0 +1,168 @@ +package main + +import ( + "crypto/ecdsa" + "crypto/elliptic" + "crypto/rand" + "crypto/x509" + "crypto/x509/pkix" + "encoding/json" + "encoding/pem" + "fmt" + "math/big" + "net/http" + "os" + "os/exec" + "path/filepath" + "time" +) + +const tlsCertsDir = "/opt/sshpanel/certs" + +// handleTLSGenerateSelfSigned generates a self-signed TLS certificate for the +// given domain, writes it to /opt/sshpanel/certs//, and returns the paths. +func handleTLSGenerateSelfSigned(w http.ResponseWriter, r *http.Request) { + if r.Method != http.MethodPost { + w.WriteHeader(http.StatusMethodNotAllowed) + return + } + var req struct { + Domain string `json:"domain"` + } + if err := json.NewDecoder(r.Body).Decode(&req); err != nil || req.Domain == "" { + http.Error(w, "domain required", http.StatusBadRequest) + return + } + + certDir := filepath.Join(tlsCertsDir, req.Domain) + if err := os.MkdirAll(certDir, 0o700); err != nil { + http.Error(w, "mkdir: "+err.Error(), http.StatusInternalServerError) + return + } + certFile := filepath.Join(certDir, "cert.pem") + keyFile := filepath.Join(certDir, "key.pem") + + priv, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader) + if err != nil { + http.Error(w, "keygen: "+err.Error(), http.StatusInternalServerError) + return + } + tmpl := &x509.Certificate{ + SerialNumber: big.NewInt(1), + Subject: pkix.Name{CommonName: req.Domain}, + NotBefore: time.Now().Add(-time.Minute), + NotAfter: time.Now().Add(10 * 365 * 24 * time.Hour), + KeyUsage: x509.KeyUsageDigitalSignature, + ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth}, + DNSNames: []string{req.Domain}, + } + der, err := x509.CreateCertificate(rand.Reader, tmpl, tmpl, &priv.PublicKey, priv) + if err != nil { + http.Error(w, "certgen: "+err.Error(), http.StatusInternalServerError) + return + } + cf, err := os.OpenFile(certFile, os.O_WRONLY|os.O_CREATE|os.O_TRUNC, 0o600) + if err != nil { + http.Error(w, "write cert: "+err.Error(), http.StatusInternalServerError) + return + } + _ = pem.Encode(cf, &pem.Block{Type: "CERTIFICATE", Bytes: der}) + cf.Close() + + privDER, err := x509.MarshalECPrivateKey(priv) + if err != nil { + http.Error(w, "marshal key: "+err.Error(), http.StatusInternalServerError) + return + } + kf, err := os.OpenFile(keyFile, os.O_WRONLY|os.O_CREATE|os.O_TRUNC, 0o600) + if err != nil { + http.Error(w, "write key: "+err.Error(), http.StatusInternalServerError) + return + } + _ = pem.Encode(kf, &pem.Block{Type: "EC PRIVATE KEY", Bytes: privDER}) + kf.Close() + + w.Header().Set("Content-Type", "application/json") + _ = json.NewEncoder(w).Encode(map[string]string{ + "cert_file": certFile, + "key_file": keyFile, + }) +} + +// handleTLSLetsEncrypt runs certbot to obtain a certificate via Let's Encrypt. +// Requires certbot installed on the server and port 80 available. +func handleTLSLetsEncrypt(w http.ResponseWriter, r *http.Request) { + if r.Method != http.MethodPost { + w.WriteHeader(http.StatusMethodNotAllowed) + return + } + var req struct { + Domain string `json:"domain"` + Email string `json:"email"` + } + if err := json.NewDecoder(r.Body).Decode(&req); err != nil || req.Domain == "" || req.Email == "" { + http.Error(w, "domain and email required", http.StatusBadRequest) + return + } + + cmd := exec.Command("certbot", "certonly", "--standalone", "--non-interactive", + "--agree-tos", "-m", req.Email, "-d", req.Domain) + out, err := cmd.CombinedOutput() + if err != nil { + http.Error(w, fmt.Sprintf("certbot failed: %v\n%s", err, string(out)), http.StatusInternalServerError) + return + } + + certFile := "/etc/letsencrypt/live/" + req.Domain + "/fullchain.pem" + keyFile := "/etc/letsencrypt/live/" + req.Domain + "/privkey.pem" + + w.Header().Set("Content-Type", "application/json") + _ = json.NewEncoder(w).Encode(map[string]string{ + "cert_file": certFile, + "key_file": keyFile, + "output": string(out), + }) +} + +// handleTLSUploadPEM accepts PEM text for cert and key, saves them to disk under +// /opt/sshpanel/certs//, and returns the file paths. +func handleTLSUploadPEM(w http.ResponseWriter, r *http.Request) { + if r.Method != http.MethodPost { + w.WriteHeader(http.StatusMethodNotAllowed) + return + } + var req struct { + Name string `json:"name"` + Cert string `json:"cert"` + Key string `json:"key"` + } + if err := json.NewDecoder(r.Body).Decode(&req); err != nil || req.Name == "" || req.Cert == "" || req.Key == "" { + http.Error(w, "name, cert, and key required", http.StatusBadRequest) + return + } + name := filepath.Base(req.Name) + if name == "." || name == "/" || name == "" { + http.Error(w, "invalid name", http.StatusBadRequest) + return + } + certDir := filepath.Join(tlsCertsDir, name) + if err := os.MkdirAll(certDir, 0o700); err != nil { + http.Error(w, "mkdir: "+err.Error(), http.StatusInternalServerError) + return + } + certFile := filepath.Join(certDir, "cert.pem") + keyFile := filepath.Join(certDir, "key.pem") + if err := os.WriteFile(certFile, []byte(req.Cert), 0o600); err != nil { + http.Error(w, "write cert: "+err.Error(), http.StatusInternalServerError) + return + } + if err := os.WriteFile(keyFile, []byte(req.Key), 0o600); err != nil { + http.Error(w, "write key: "+err.Error(), http.StatusInternalServerError) + return + } + w.Header().Set("Content-Type", "application/json") + _ = json.NewEncoder(w).Encode(map[string]string{ + "cert_file": certFile, + "key_file": keyFile, + }) +} diff --git a/udpgw_integration.go b/udpgw_integration.go new file mode 100644 index 0000000..b3aa035 --- /dev/null +++ b/udpgw_integration.go @@ -0,0 +1,490 @@ +package main + +// This file embeds the UDP gateway (udpgw) service into the main +// application. The original udpgw program (public domain) accepts a +// TCP connection, then forwards framed UDP datagrams to arbitrary +// destinations and demultiplexes replies back to the originating +// connection. Here we expose the same functionality via a +// configuration key in config.json. When enabled, the server binds +// to the provided Listen address (or the default 0.0.0.0:7400 if +// unspecified) and handles each client in its own goroutine. The +// gateway runs entirely in‑process and behaves like the standalone +// badvpn-udpgw daemon. + +import ( + "bufio" + "context" + "encoding/binary" + "errors" + "fmt" + "io" + "log" + "net" + "sync" + "time" +) + +var ( + udpgwMu sync.Mutex + udpgwLn net.Listener +) + +// stopUDPGW closes the active UDPGW listener, causing the accept loop to exit. +// It is a no-op if UDPGW is not running. +func stopUDPGW() { + udpgwMu.Lock() + defer udpgwMu.Unlock() + if udpgwLn != nil { + _ = udpgwLn.Close() + udpgwLn = nil + } +} + +// startUDPGW starts the integrated UDP gateway if cfg is non‑nil and +// cfg.Listen is non‑empty. It applies default values to any zero +// configuration fields and converts duration strings to time.Duration. +// The server runs in a goroutine; any fatal errors are logged and +// prevent the gateway from starting, but do not terminate the main +// process. +func startUDPGW(cfg *UDPGWConfig) { + if cfg == nil { + return + } + // Default the listen address to the standalone default (0.0.0.0:7400) if + // unspecified. This matches the behaviour of the original + // badvpn-udpgw program, which listens on all interfaces by default. + listenAddr := cfg.Listen + if listenAddr == "" { + listenAddr = "0.0.0.0:7400" + } + // Apply defaults for numeric fields if zero. + c := &internalUDPGWConfig{} + c.listen = listenAddr + if cfg.MaxFrame > 0 { + c.maxFrame = cfg.MaxFrame + } else { + c.maxFrame = 64 * 1024 + } + c.debug = cfg.Debug + if cfg.HexdumpN > 0 { + c.hexdumpN = cfg.HexdumpN + } else { + c.hexdumpN = 64 + } + if cfg.WriteChan > 0 { + c.writeChan = cfg.WriteChan + } else { + c.writeChan = 4096 + } + c.udpBindIP = cfg.UDPBindIP + if cfg.UDPRBuf > 0 { + c.udpRBuf = cfg.UDPRBuf + } else { + c.udpRBuf = 8 * 1024 * 1024 + } + if cfg.UDPWBuf > 0 { + c.udpWBuf = cfg.UDPWBuf + } else { + c.udpWBuf = 8 * 1024 * 1024 + } + // Parse durations with fallback defaults. + if cfg.MapTTL != "" { + if d, err := time.ParseDuration(cfg.MapTTL); err == nil { + c.mapTTL = d + } else { + log.Printf("udpgw: invalid map_ttl %q: %v; using default 90s", cfg.MapTTL, err) + c.mapTTL = 90 * time.Second + } + } else { + c.mapTTL = 90 * time.Second + } + if cfg.ReapEvery != "" { + if d, err := time.ParseDuration(cfg.ReapEvery); err == nil { + c.reapEvery = d + } else { + log.Printf("udpgw: invalid reap_every %q: %v; using default 10s", cfg.ReapEvery, err) + c.reapEvery = 10 * time.Second + } + } else { + c.reapEvery = 10 * time.Second + } + // Idle timeout for TCP clients. + if cfg.IdleTimeout != "" { + if d, err := time.ParseDuration(cfg.IdleTimeout); err == nil { + c.idleTimeout = d + } else { + log.Printf("udpgw: invalid idle_timeout %q: %v; using default 2m", cfg.IdleTimeout, err) + c.idleTimeout = 2 * time.Minute + } + } else { + c.idleTimeout = 2 * time.Minute + } + // Per-client logical connID cap. + if cfg.MaxClientConns > 0 { + c.maxClientConns = cfg.MaxClientConns + } else { + c.maxClientConns = 10 + } + // Per-client destination mapping cap. + if cfg.MaxMapEntries > 0 { + c.maxMapEntries = cfg.MaxMapEntries + } else { + c.maxMapEntries = 32768 + } + // Start listening. + ln, err := net.Listen("tcp", c.listen) + if err != nil { + log.Printf("udpgw: listen failed on %s: %v", c.listen, err) + return + } + + // Register as the active listener so stopUDPGW can close it. + udpgwMu.Lock() + if udpgwLn != nil { + _ = udpgwLn.Close() + } + udpgwLn = ln + udpgwMu.Unlock() + + if c.debug { + log.Printf("udpgw: listening on %s", c.listen) + } + go func() { + for { + conn, err := ln.Accept() + if err != nil { + if errors.Is(err, net.ErrClosed) { + return + } + log.Printf("udpgw: accept error: %v", err) + continue + } + go handleUDPGWClient(conn, c) + } + }() +} + +// internalUDPGWConfig mirrors the exported UDPGWConfig but with +// time.Duration fields for TTL and reaper intervals. It does not +// embed JSON tags because it is not exposed to the user. +type internalUDPGWConfig struct { + listen string + maxFrame int + debug bool + hexdumpN int + writeChan int + udpBindIP string + udpRBuf int + udpWBuf int + mapTTL time.Duration + reapEvery time.Duration + idleTimeout time.Duration + maxClientConns int + maxMapEntries int +} + +// udpDestKey identifies a destination IPv4:port for the UDP gateway. A +// mapping of this key to a connID+x byte is kept per client so +// replies can be routed correctly. Each client maintains its own map +// of udpDestKey->udpMapVal entries. +type udpDestKey struct { + ip [4]byte + port uint16 +} + +// udpMapVal stores the mapping from udpDestKey to connID, the x byte and the +// expiration time. +type udpMapVal struct { + connID uint16 + x byte + exp time.Time +} + +// handleUDPGWClient manages a single TCP client connection to the UDP +// gateway. It creates a per‑client UDP socket, reads frames from the +// TCP connection, sends UDP datagrams to the requested destination, +// maintains a mapping of udpDestKey->udpMapVal for routing replies, and +// writes reply frames back over the TCP connection. When the client +// disconnects or an error occurs, all goroutines are terminated and the +// UDP socket is closed. +func handleUDPGWClient(conn net.Conn, c *internalUDPGWConfig) { + defer conn.Close() + remote := conn.RemoteAddr().String() + if c.debug { + log.Printf("udpgw: client connected: %s", remote) + } + // Lower latency for interactive applications by disabling Nagle. + if tcp, ok := conn.(*net.TCPConn); ok { + _ = tcp.SetNoDelay(true) + } + br := bufio.NewReaderSize(conn, 256*1024) + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + // Bind a UDP socket for this client. Use cfg.udpBindIP if provided. + var laddr *net.UDPAddr + if c.udpBindIP != "" { + ip := net.ParseIP(c.udpBindIP) + if ip != nil { + laddr = &net.UDPAddr{IP: ip, Port: 0} + } else { + log.Printf("udpgw[%s]: invalid udp_bind IP %q", remote, c.udpBindIP) + return + } + } + udpConn, err := net.ListenUDP("udp", laddr) + if err != nil { + log.Printf("udpgw[%s]: UDP listen failed: %v", remote, err) + return + } + defer udpConn.Close() + _ = udpConn.SetReadBuffer(c.udpRBuf) + _ = udpConn.SetWriteBuffer(c.udpWBuf) + // Channel to queue outgoing frames back to the client. + writeCh := make(chan []byte, c.writeChan) + done := make(chan struct{}) + go func() { + defer close(done) + for { + select { + case <-ctx.Done(): + return + case b := <-writeCh: + if len(b) == 0 { + continue + } + // If the client stops reading, don't block forever. + _ = conn.SetWriteDeadline(time.Now().Add(30 * time.Second)) + _, err := conn.Write(b) + if err != nil { + cancel() + return + } + } + } + }() + // Destination -> connID+x mapping and active connID tracking per TCP client. + var mu sync.Mutex + destToConn := make(map[udpDestKey]udpMapVal) + connIDLastSeen := make(map[uint16]time.Time) + // Start a reaper goroutine to purge expired mappings. + // Uses ctx instead of a separate stopReaper channel so that cancellation + // is always uniform: any exit path that calls cancel() (error, disconnect, + // idle timeout, panic-deferred cancel) also stops the reaper. + go func() { + t := time.NewTicker(c.reapEvery) + defer t.Stop() + for { + select { + case <-ctx.Done(): + return + case <-t.C: + now := time.Now() + mu.Lock() + for k, v := range destToConn { + if now.After(v.exp) { + delete(destToConn, k) + } + } + for id, lastSeen := range connIDLastSeen { + if now.Sub(lastSeen) > c.mapTTL { + delete(connIDLastSeen, id) + } + } + mu.Unlock() + } + } + }() + // Goroutine to read UDP replies and write framed responses back. + go func() { + buf := make([]byte, 65535) + for { + n, from, err := udpConn.ReadFromUDP(buf) + if err != nil { + return + } + if n <= 0 { + continue + } + ip4 := from.IP.To4() + if ip4 == nil { + // IPv6 replies are dropped because framing only supports IPv4. + continue + } + k := udpDestKey{port: uint16(from.Port)} + copy(k.ip[:], ip4) + mu.Lock() + v, ok := destToConn[k] + mu.Unlock() + if !ok || time.Now().After(v.exp) { + if c.debug { + log.Printf("udpgw[%s]: dropping %dB from %s (no mapping)", remote, n, from.String()) + } + continue + } + // Avoid unbounded memory growth: if the client is slow and the + // write queue is full, drop replies rather than allocating more. + if len(writeCh) == cap(writeCh) { + if c.debug { + log.Printf("udpgw[%s]: drop UDP reply %dB from %s (write queue full)", remote, n, from.String()) + } + continue + } + frame := udpgwBuildFrame(v.connID, v.x, k.ip, k.port, buf[:n]) + select { + case <-ctx.Done(): + return + case writeCh <- frame: + default: + // Race with another sender filling the channel. + } + if c.debug { + log.Printf("udpgw[%s]: UDP<- %dB from %s -> connID=%d x=0x%02x", remote, n, from.String(), v.connID, v.x) + } + } + }() + // Main loop: read frames from TCP, update mapping and send UDP. + for { + // Close idle TCP clients to avoid leaking goroutines. + if c.idleTimeout > 0 { + _ = conn.SetReadDeadline(time.Now().Add(c.idleTimeout)) + } + payload, err := udpgwReadPayload(br, c.maxFrame) + if err != nil { + if ne, ok := err.(net.Error); ok && ne.Timeout() { + if c.debug { + log.Printf("udpgw[%s]: idle timeout (%s); closing", remote, c.idleTimeout) + } + } + if err == io.EOF { + if c.debug { + log.Printf("udpgw: client disconnected: %s", remote) + } + } else { + log.Printf("udpgw[%s]: read error: %v", remote, err) + } + cancel() + _ = conn.Close() + <-done + return + } + // payload: connID(2) + X(1) + dstIPv4(4) + dstPort(2) + data + if len(payload) < 2+1+4+2 { + if c.debug { + log.Printf("udpgw[%s]: too short payload %dB", remote, len(payload)) + } + continue + } + connID := binary.BigEndian.Uint16(payload[0:2]) + x := payload[2] + var dstIP [4]byte + copy(dstIP[:], payload[3:7]) + dstPort := binary.BigEndian.Uint16(payload[7:9]) + data := payload[9:] + if c.debug { + log.Printf("udpgw[%s]: RX connID=%d dst=%d.%d.%d.%d:%d x=0x%02x len=%d", remote, connID, dstIP[0], dstIP[1], dstIP[2], dstIP[3], dstPort, x, len(data)) + } + // Enforce a per-client logical session cap using udpgw connID. + k := udpDestKey{ip: dstIP, port: dstPort} + now := time.Now() + mu.Lock() + // Step 1: evict TTL-expired connIDs first. + for id, lastSeen := range connIDLastSeen { + if now.Sub(lastSeen) > c.mapTTL { + delete(connIDLastSeen, id) + } + } + // Step 2: if the connID is new and we are still at or above the cap + // after TTL eviction, evict the single oldest entry to make room. + // This prevents a client rotating connIDs faster than mapTTL from + // bypassing maxClientConns and growing the map indefinitely. + if _, alreadyKnown := connIDLastSeen[connID]; !alreadyKnown && c.maxClientConns > 0 && len(connIDLastSeen) >= c.maxClientConns { + // Find and remove the oldest connID seen. + var oldestID uint16 + var oldestTime time.Time + first := true + for id, lastSeen := range connIDLastSeen { + if first || lastSeen.Before(oldestTime) { + oldestID = id + oldestTime = lastSeen + first = false + } + } + delete(connIDLastSeen, oldestID) + if c.debug { + log.Printf("udpgw[%s]: connID cap reached (%d); evicted oldest connID=%d to admit connID=%d", remote, c.maxClientConns, oldestID, connID) + } + } + connIDLastSeen[connID] = now + // Update the mapping (bounded to protect memory). + if c.maxMapEntries > 0 && len(destToConn) >= c.maxMapEntries { + // First, purge expired entries. + for dk, dv := range destToConn { + if now.After(dv.exp) { + delete(destToConn, dk) + } + } + // If still too large, evict arbitrary entries until we're under the cap. + if len(destToConn) >= c.maxMapEntries { + evict := (len(destToConn) - c.maxMapEntries) + 1 + for dk := range destToConn { + delete(destToConn, dk) + evict-- + if evict <= 0 { + break + } + } + if c.debug { + log.Printf("udpgw[%s]: mapping cap reached; evicted entries, size=%d", remote, len(destToConn)) + } + } + } + destToConn[k] = udpMapVal{connID: connID, x: x, exp: now.Add(c.mapTTL)} + mu.Unlock() + // Send the UDP datagram. + raddr := &net.UDPAddr{IP: net.IP(dstIP[:]), Port: int(dstPort)} + _, err = udpConn.WriteToUDP(data, raddr) + if err != nil { + if c.debug { + log.Printf("udpgw[%s]: UDP write failed to %s: %v", remote, raddr.String(), err) + } + continue + } + if c.debug { + log.Printf("udpgw[%s]: UDP-> %dB to %s", remote, len(data), raddr.String()) + } + } +} + +// udpgwReadPayload reads a length‑prefixed payload from r. The length +// prefix is a little‑endian uint16. Payloads larger than max cause an +// error. +func udpgwReadPayload(r *bufio.Reader, max int) ([]byte, error) { + var lenBuf [2]byte + if _, err := io.ReadFull(r, lenBuf[:]); err != nil { + return nil, err + } + n := int(binary.LittleEndian.Uint16(lenBuf[:])) + if n <= 0 || n > max { + return nil, fmt.Errorf("udpgw: invalid frame length %d", n) + } + b := make([]byte, n) + if _, err := io.ReadFull(r, b); err != nil { + return nil, err + } + return b, nil +} + +// udpgwBuildFrame constructs a reply frame for the client. The frame +// consists of a little‑endian length prefix, then connID (big endian), +// x byte, source IPv4, source port, and the data. +func udpgwBuildFrame(connID uint16, x byte, ip [4]byte, port uint16, data []byte) []byte { + payloadLen := 2 + 1 + 4 + 2 + len(data) + out := make([]byte, 2+payloadLen) + binary.LittleEndian.PutUint16(out[0:2], uint16(payloadLen)) + binary.BigEndian.PutUint16(out[2:4], connID) + out[4] = x + copy(out[5:9], ip[:]) + binary.BigEndian.PutUint16(out[9:11], port) + copy(out[11:], data) + return out +} diff --git a/update.sh b/update.sh new file mode 100644 index 0000000..5459844 --- /dev/null +++ b/update.sh @@ -0,0 +1,199 @@ +#!/bin/bash +# Update script for SSH Panel — updates the binary and admin panel in place. +# Preserves: .env, config.json, xray_config.json, SSH keys, database, certs. +# Usage: sudo bash update.sh +set -euo pipefail + +RED='\033[0;31m'; GREEN='\033[0;32m'; YELLOW='\033[1;33m'; NC='\033[0m' +info() { echo -e "${GREEN}[+]${NC} $*"; } +warn() { echo -e "${YELLOW}[!]${NC} $*"; } +error() { echo -e "${RED}[x]${NC} $*"; exit 1; } + +# ── config ──────────────────────────────────────────────────────────────────── +INSTALL_DIR="/opt/sshpanel" +SERVICE_NAME="sshpanel" +SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)" +GO_VERSION="${GO_VERSION:-$(awk '$1 == "go" {print $2; exit}' "$SCRIPT_DIR/go.mod" 2>/dev/null || echo "1.22.5")}" +# ───────────────────────────────────────────────────────────────────────────── + +[[ $EUID -ne 0 ]] && error "Run as root: sudo bash $0" + +echo -e "\n${GREEN}══════════════════════════════════════════${NC}" +echo -e "${GREEN} SSH Panel · Updater ${NC}" +echo -e "${GREEN}══════════════════════════════════════════${NC}\n" + +# ── 1. Pre-flight checks ────────────────────────────────────────────────────── +info "[1/5] Pre-flight checks…" + +[[ -d "$INSTALL_DIR" ]] || error "Install dir $INSTALL_DIR not found — run install.sh first." +[[ -f "$INSTALL_DIR/.env" ]] || error "$INSTALL_DIR/.env not found — run install.sh first." +[[ -f "$SCRIPT_DIR/go.mod" ]] || error "go.mod not found — run this script from the source directory." + +info " Install dir : $INSTALL_DIR" +info " Source dir : $SCRIPT_DIR" +info " Go version : $GO_VERSION" + +# ── 2. Go toolchain ─────────────────────────────────────────────────────────── +info "[2/5] Checking Go toolchain…" + +NEED_GO=true +if command -v go &>/dev/null; then + CURRENT_GO=$(go version 2>/dev/null | awk '{print $3}' | sed 's/go//') + if [[ "$(printf '%s\n' "$GO_VERSION" "$CURRENT_GO" | sort -V | head -1)" == "$GO_VERSION" ]]; then + info " Go $CURRENT_GO already installed — skipping" + NEED_GO=false + fi +fi + +if $NEED_GO; then + MACHINE=$(uname -m) + case "$MACHINE" in + x86_64) GOARCH="amd64" ;; + aarch64) GOARCH="arm64" ;; + armv7l) GOARCH="armv6l" ;; + *) GOARCH="amd64" ;; + esac + GO_URL="https://go.dev/dl/go${GO_VERSION}.linux-${GOARCH}.tar.gz" + info " Downloading Go ${GO_VERSION} (${GOARCH})…" + wget -q --show-progress -O /tmp/go.tar.gz "$GO_URL" + rm -rf /usr/local/go + tar -C /usr/local -xzf /tmp/go.tar.gz + rm -f /tmp/go.tar.gz + echo 'export PATH=$PATH:/usr/local/go/bin' > /etc/profile.d/go.sh + chmod +x /etc/profile.d/go.sh +fi + +export PATH=$PATH:/usr/local/go/bin +go version + +# ── 3. Build new binary ─────────────────────────────────────────────────────── +info "[3/5] Building new sshpanel binary…" + +cd "$SCRIPT_DIR" +export GOPATH=/tmp/gopath_sshpanel +export GOCACHE=/tmp/gocache_sshpanel +go mod download +go build -ldflags="-s -w" -o /tmp/sshpanel_new . +info " Build complete." + +# ── 4. Apply update ─────────────────────────────────────────────────────────── +info "[4/5] Applying update…" + +# Stop the service +if systemctl is-active --quiet "$SERVICE_NAME" 2>/dev/null; then + info " Stopping $SERVICE_NAME…" + systemctl stop "$SERVICE_NAME" + RESTART_NEEDED=true +else + RESTART_NEEDED=false +fi + +# Backup old binary +if [[ -f "$INSTALL_DIR/sshpanel" ]]; then + cp "$INSTALL_DIR/sshpanel" "$INSTALL_DIR/sshpanel.bak" + info " Old binary backed up to sshpanel.bak" +fi + +# Replace binary +mv /tmp/sshpanel_new "$INSTALL_DIR/sshpanel" +chmod +x "$INSTALL_DIR/sshpanel" +info " Binary updated." + +# Update admin panel files +mkdir -p "$INSTALL_DIR/admin" +cp -r "$SCRIPT_DIR/admin/"* "$INSTALL_DIR/admin/" +info " Admin panel updated." + +# Ensure banner file exists (new in this version) +if [[ ! -f "$INSTALL_DIR/banner.txt" ]]; then + touch "$INSTALL_DIR/banner.txt" + info " Created banner.txt" +fi + +# Ensure certs directory exists (new in this version) +mkdir -p "$INSTALL_DIR/certs" + +# Patch config.json to add missing fields introduced in this version +# without overwriting user-configured values. +CFG="$INSTALL_DIR/config.json" +if [[ -f "$CFG" ]]; then + # Add banner_file if not present + if ! python3 -c "import json,sys; d=json.load(open('$CFG')); sys.exit(0 if 'banner_file' in d else 1)" 2>/dev/null; then + python3 - "$CFG" << 'PYEOF' +import json, sys +path = sys.argv[1] +with open(path) as f: + d = json.load(f) +if 'banner_file' not in d: + d['banner_file'] = '/opt/sshpanel/banner.txt' +with open(path, 'w') as f: + json.dump(d, f, indent=2) +PYEOF + info " Added banner_file to config.json" + fi + + # Fix routing: remove geoip:private rules that require geoip.dat from xray_config.json + XCFG="$INSTALL_DIR/xray_config.json" + if [[ -f "$XCFG" ]]; then + if grep -q '"geoip:private"' "$XCFG" 2>/dev/null; then + python3 - "$XCFG" << 'PYEOF' +import json, sys +path = sys.argv[1] +with open(path) as f: + d = json.load(f) +routing = d.get('routing', {}) +rules = routing.get('rules', []) +# Remove rules that reference geoip:private +new_rules = [r for r in rules if 'geoip:private' not in r.get('ip', [])] +if new_rules != rules: + if new_rules: + d['routing']['rules'] = new_rules + else: + d.pop('routing', None) + with open(path, 'w') as f: + json.dump(d, f, indent=2) +PYEOF + info " Removed geoip:private routing rule from xray_config.json" + fi + fi +fi + +# ── 5. Restart service ──────────────────────────────────────────────────────── +info "[5/5] Restarting service…" + +if $RESTART_NEEDED; then + systemctl start "$SERVICE_NAME" + sleep 2 + if systemctl is-active --quiet "$SERVICE_NAME"; then + info " $SERVICE_NAME is running." + else + warn " $SERVICE_NAME failed to start — check logs:" + warn " journalctl -u $SERVICE_NAME -n 30 --no-pager" + warn " You can restore the old binary:" + warn " mv $INSTALL_DIR/sshpanel.bak $INSTALL_DIR/sshpanel && systemctl start $SERVICE_NAME" + exit 1 + fi +else + warn " Service was not running; start it with: systemctl start $SERVICE_NAME" +fi + +echo "" +echo -e "${GREEN}══════════════════════════════════════════${NC}" +echo -e "${GREEN} Update complete! ${NC}" +echo -e "${GREEN}══════════════════════════════════════════${NC}" +echo "" +echo -e " Logs: ${YELLOW}journalctl -u ${SERVICE_NAME} -f${NC}" +echo -e " ${YELLOW}tail -f ${INSTALL_DIR}/logs/panel.log${NC}" +echo "" +echo -e " Backup: ${YELLOW}${INSTALL_DIR}/sshpanel.bak${NC}" +echo "" +echo -e "${YELLOW}What was updated:${NC}" +echo -e " • sshpanel binary" +echo -e " • Admin panel (admin/index.html)" +echo -e "${YELLOW}What was preserved:${NC}" +echo -e " • .env (DB credentials, tokens)" +echo -e " • config.json (your server settings)" +echo -e " • xray_config.json (your Xray settings)" +echo -e " • SSH host keys" +echo -e " • All user data in PostgreSQL" +echo "" diff --git a/xray_clients.go b/xray_clients.go new file mode 100644 index 0000000..4813038 --- /dev/null +++ b/xray_clients.go @@ -0,0 +1,164 @@ +package main + +import ( + "context" + "database/sql" + "log" + "time" +) + +// XrayClientMeta holds metadata stored in PostgreSQL for an Xray client. +// Xray's own config only stores uuid/email/level; expiry and display name live here. +type XrayClientMeta struct { + UUID string + Name string + Email string + InboundTag string + ExpiresAt *time.Time + MaxConns int + CreatedAt time.Time +} + +func (s *Store) EnsureXrayClientsSchema(ctx context.Context) error { + _, err := s.db.ExecContext(ctx, ` + CREATE TABLE IF NOT EXISTS xray_clients ( + uuid TEXT PRIMARY KEY, + name TEXT NOT NULL DEFAULT '', + email TEXT NOT NULL DEFAULT '', + inbound_tag TEXT NOT NULL DEFAULT '', + expires_at TIMESTAMPTZ, + max_conns INT NOT NULL DEFAULT 0, + created_at TIMESTAMPTZ NOT NULL DEFAULT NOW() + )`) + return err +} + +func (s *Store) UpsertXrayClientMeta(ctx context.Context, m XrayClientMeta) error { + var expiresAt interface{} + if m.ExpiresAt != nil { + expiresAt = *m.ExpiresAt + } + _, err := s.db.ExecContext(ctx, ` + INSERT INTO xray_clients (uuid, name, email, inbound_tag, expires_at, max_conns) + VALUES ($1, $2, $3, $4, $5, $6) + ON CONFLICT (uuid) DO UPDATE SET + name = EXCLUDED.name, + email = EXCLUDED.email, + inbound_tag = EXCLUDED.inbound_tag, + expires_at = EXCLUDED.expires_at, + max_conns = EXCLUDED.max_conns`, + m.UUID, m.Name, m.Email, m.InboundTag, expiresAt, m.MaxConns) + return err +} + +func (s *Store) GetXrayClientMeta(ctx context.Context, uuid string) (*XrayClientMeta, error) { + m := &XrayClientMeta{} + var expiresAt sql.NullTime + err := s.db.QueryRowContext(ctx, ` + SELECT uuid, name, email, inbound_tag, expires_at, max_conns, created_at + FROM xray_clients WHERE uuid = $1`, uuid). + Scan(&m.UUID, &m.Name, &m.Email, &m.InboundTag, &expiresAt, &m.MaxConns, &m.CreatedAt) + if err != nil { + return nil, err + } + if expiresAt.Valid { + m.ExpiresAt = &expiresAt.Time + } + return m, nil +} + +func (s *Store) DeleteXrayClientMeta(ctx context.Context, uuid string) error { + _, err := s.db.ExecContext(ctx, `DELETE FROM xray_clients WHERE uuid = $1`, uuid) + return err +} + +func (s *Store) ListAllXrayClients(ctx context.Context) ([]*XrayClientMeta, error) { + rows, err := s.db.QueryContext(ctx, ` + SELECT uuid, name, email, inbound_tag, expires_at, max_conns, created_at + FROM xray_clients ORDER BY created_at DESC`) + if err != nil { + return nil, err + } + defer rows.Close() + var out []*XrayClientMeta + for rows.Next() { + m := &XrayClientMeta{} + var expiresAt sql.NullTime + if err := rows.Scan(&m.UUID, &m.Name, &m.Email, &m.InboundTag, &expiresAt, &m.MaxConns, &m.CreatedAt); err != nil { + return nil, err + } + if expiresAt.Valid { + m.ExpiresAt = &expiresAt.Time + } + out = append(out, m) + } + return out, rows.Err() +} + +func (s *Store) ListExpiredXrayClients(ctx context.Context) ([]*XrayClientMeta, error) { + rows, err := s.db.QueryContext(ctx, ` + SELECT uuid, name, email, inbound_tag, expires_at, max_conns, created_at + FROM xray_clients WHERE expires_at IS NOT NULL AND expires_at <= NOW()`) + if err != nil { + return nil, err + } + defer rows.Close() + var out []*XrayClientMeta + for rows.Next() { + m := &XrayClientMeta{} + var expiresAt sql.NullTime + if err := rows.Scan(&m.UUID, &m.Name, &m.Email, &m.InboundTag, &expiresAt, &m.MaxConns, &m.CreatedAt); err != nil { + return nil, err + } + if expiresAt.Valid { + m.ExpiresAt = &expiresAt.Time + } + out = append(out, m) + } + return out, rows.Err() +} + +// startXrayClientExpiryChecker runs a background goroutine that removes expired +// Xray clients from both the config file and the database every 5 minutes. +func startXrayClientExpiryChecker(store *Store) { + if store == nil { + return + } + go func() { + ticker := time.NewTicker(5 * time.Minute) + defer ticker.Stop() + for range ticker.C { + ctx := context.Background() + expired, err := store.ListExpiredXrayClients(ctx) + if err != nil { + log.Printf("xray expiry checker: list error: %v", err) + continue + } + if len(expired) == 0 { + continue + } + needRestart := false + for _, m := range expired { + tag := m.InboundTag + if tag == "" { + _ = store.DeleteXrayClientMeta(ctx, m.UUID) + continue + } + if err := xrayMgr.RemoveXrayClient(tag, m.UUID); err != nil { + log.Printf("xray expiry: remove %s from %s: %v", m.UUID, tag, err) + } else { + needRestart = true + } + if err := store.DeleteXrayClientMeta(ctx, m.UUID); err != nil { + log.Printf("xray expiry: delete meta %s: %v", m.UUID, err) + } + log.Printf("xray expiry: removed expired client %q (%s) from inbound %s", m.Name, m.UUID, tag) + } + if needRestart { + if err := xrayMgr.Restart(); err != nil { + log.Printf("xray expiry: restart error: %v", err) + } + } + } + }() +} diff --git a/xray_integration.go b/xray_integration.go new file mode 100644 index 0000000..afd0922 --- /dev/null +++ b/xray_integration.go @@ -0,0 +1,692 @@ +package main + +import ( + "encoding/json" + "fmt" + "io" + "log" + "net/http" + "os" + "os/exec" + "strings" + "sync" + "syscall" + "time" +) + +// XrayConfig holds Xray process management settings embedded in the main Config. +type XrayConfig struct { + Enabled bool `json:"enabled"` + BinPath string `json:"bin_path"` // e.g. /opt/sshpanel/xray + ConfigFile string `json:"config_file"` // e.g. /opt/sshpanel/xray_config.json +} + +// xrayLogRing is a fixed-capacity circular buffer for captured log lines. +type xrayLogRing struct { + mu sync.Mutex + lines []string + pos int +} + +const xrayLogCap = 200 + +func (r *xrayLogRing) add(line string) { + r.mu.Lock() + defer r.mu.Unlock() + if len(r.lines) < xrayLogCap { + r.lines = append(r.lines, line) + } else { + r.lines[r.pos] = line + r.pos = (r.pos + 1) % xrayLogCap + } +} + +func (r *xrayLogRing) snapshot() []string { + r.mu.Lock() + defer r.mu.Unlock() + if len(r.lines) == 0 { + return nil + } + out := make([]string, len(r.lines)) + if len(r.lines) < xrayLogCap { + copy(out, r.lines) + } else { + n := copy(out, r.lines[r.pos:]) + copy(out[n:], r.lines[:r.pos]) + } + return out +} + +var xrayLogBuf = &xrayLogRing{} + +// xrayWriter captures writes from the xray subprocess into the log ring buffer +// and forwards them to stderr so they appear in the process log. +type xrayWriter struct{} + +func (w xrayWriter) Write(p []byte) (int, error) { + text := strings.TrimRight(string(p), "\n") + for _, line := range strings.Split(text, "\n") { + if line != "" { + xrayLogBuf.add(line) + } + } + return os.Stderr.Write(p) +} + +// XrayManager manages the lifecycle of the external xray subprocess. +type XrayManager struct { + mu sync.Mutex + cmd *exec.Cmd + doneCh chan struct{} + cfg *XrayConfig + startTime time.Time + lastErr string +} + +var xrayMgr = &XrayManager{} + +// initXrayManager stores the config and auto-starts Xray if Enabled is true. +func initXrayManager(cfg *XrayConfig) { + if cfg == nil { + return + } + xrayMgr.mu.Lock() + xrayMgr.cfg = cfg + xrayMgr.mu.Unlock() + + if cfg.Enabled { + if err := xrayMgr.Start(); err != nil { + log.Printf("xray: auto-start failed: %v", err) + } + } +} + +// isRunning returns true if the subprocess is currently alive. +// Must be called with m.mu held. +func (m *XrayManager) isRunning() bool { + if m.doneCh == nil { + return false + } + select { + case <-m.doneCh: + return false + default: + return true + } +} + +// Start launches the xray subprocess. Returns an error if already running or misconfigured. +func (m *XrayManager) Start() error { + m.mu.Lock() + defer m.mu.Unlock() + + if m.isRunning() { + return fmt.Errorf("xray already running (pid %d)", m.cmd.Process.Pid) + } + if m.cfg == nil { + return fmt.Errorf("xray not configured") + } + if _, err := os.Stat(m.cfg.BinPath); err != nil { + return fmt.Errorf("xray binary not found at %s", m.cfg.BinPath) + } + + args := []string{"run"} + if m.cfg.ConfigFile != "" { + args = append(args, "-c", m.cfg.ConfigFile) + } + + cmd := exec.Command(m.cfg.BinPath, args...) + cmd.Stdout = xrayWriter{} + cmd.Stderr = xrayWriter{} + + if err := cmd.Start(); err != nil { + m.lastErr = err.Error() + return fmt.Errorf("xray start: %w", err) + } + + doneCh := make(chan struct{}) + m.cmd = cmd + m.doneCh = doneCh + m.startTime = time.Now() + m.lastErr = "" + + go func() { + err := cmd.Wait() + close(doneCh) + m.mu.Lock() + if err != nil { + m.lastErr = err.Error() + } + m.mu.Unlock() + log.Printf("xray: process exited: %v", err) + }() + + log.Printf("xray: started (pid %d)", cmd.Process.Pid) + return nil +} + +// Stop sends SIGTERM and waits up to 5 s before forcing SIGKILL. +func (m *XrayManager) Stop() error { + m.mu.Lock() + if !m.isRunning() { + m.mu.Unlock() + return nil + } + doneCh := m.doneCh + cmd := m.cmd + m.mu.Unlock() + + _ = cmd.Process.Signal(syscall.SIGTERM) + + select { + case <-doneCh: + case <-time.After(5 * time.Second): + _ = cmd.Process.Kill() + select { + case <-doneCh: + case <-time.After(2 * time.Second): + } + } + log.Printf("xray: stopped") + return nil +} + +// Restart stops then starts the xray subprocess. +func (m *XrayManager) Restart() error { + if err := m.Stop(); err != nil { + return err + } + return m.Start() +} + +// XrayStatusDTO is returned by /api/xray/status. +type XrayStatusDTO struct { + Enabled bool `json:"enabled"` + Running bool `json:"running"` + PID int `json:"pid,omitempty"` + Uptime string `json:"uptime,omitempty"` + Error string `json:"error,omitempty"` +} + +// Status returns a snapshot of the current xray process state. +func (m *XrayManager) Status() XrayStatusDTO { + m.mu.Lock() + defer m.mu.Unlock() + + s := XrayStatusDTO{} + if m.cfg != nil { + s.Enabled = m.cfg.Enabled + } + if m.isRunning() && m.cmd != nil && m.cmd.Process != nil { + s.Running = true + s.PID = m.cmd.Process.Pid + s.Uptime = time.Since(m.startTime).Round(time.Second).String() + } + if m.lastErr != "" { + s.Error = m.lastErr + } + return s +} + +// GetConfig reads the current xray JSON config file. +func (m *XrayManager) GetConfig() ([]byte, error) { + m.mu.Lock() + defer m.mu.Unlock() + if m.cfg == nil || m.cfg.ConfigFile == "" { + return nil, fmt.Errorf("xray config file not configured") + } + return os.ReadFile(m.cfg.ConfigFile) +} + +// SetConfig validates and atomically writes a new xray JSON config file. +func (m *XrayManager) SetConfig(data []byte) error { + m.mu.Lock() + defer m.mu.Unlock() + if m.cfg == nil || m.cfg.ConfigFile == "" { + return fmt.Errorf("xray config file not configured") + } + if !json.Valid(data) { + return fmt.Errorf("invalid JSON") + } + return os.WriteFile(m.cfg.ConfigFile, data, 0o600) +} + +// ---- Admin HTTP handlers ---- + +func handleXrayStatus(w http.ResponseWriter, r *http.Request) { + if r.Method != http.MethodGet { + w.WriteHeader(http.StatusMethodNotAllowed) + return + } + w.Header().Set("Content-Type", "application/json") + _ = json.NewEncoder(w).Encode(xrayMgr.Status()) +} + +func handleXrayStart(w http.ResponseWriter, r *http.Request) { + if r.Method != http.MethodPost { + w.WriteHeader(http.StatusMethodNotAllowed) + return + } + if err := xrayMgr.Start(); err != nil { + http.Error(w, err.Error(), http.StatusInternalServerError) + return + } + w.WriteHeader(http.StatusOK) +} + +func handleXrayStop(w http.ResponseWriter, r *http.Request) { + if r.Method != http.MethodPost { + w.WriteHeader(http.StatusMethodNotAllowed) + return + } + if err := xrayMgr.Stop(); err != nil { + http.Error(w, err.Error(), http.StatusInternalServerError) + return + } + w.WriteHeader(http.StatusOK) +} + +func handleXrayRestart(w http.ResponseWriter, r *http.Request) { + if r.Method != http.MethodPost { + w.WriteHeader(http.StatusMethodNotAllowed) + return + } + if err := xrayMgr.Restart(); err != nil { + http.Error(w, err.Error(), http.StatusInternalServerError) + return + } + w.WriteHeader(http.StatusOK) +} + +func handleXrayConfig(w http.ResponseWriter, r *http.Request) { + switch r.Method { + case http.MethodGet: + data, err := xrayMgr.GetConfig() + if err != nil { + http.Error(w, err.Error(), http.StatusInternalServerError) + return + } + w.Header().Set("Content-Type", "application/json") + _, _ = w.Write(data) + + case http.MethodPost: + body, err := io.ReadAll(io.LimitReader(r.Body, 512*1024)) + if err != nil { + http.Error(w, "failed to read body", http.StatusBadRequest) + return + } + if err := xrayMgr.SetConfig(body); err != nil { + http.Error(w, err.Error(), http.StatusBadRequest) + return + } + w.WriteHeader(http.StatusOK) + + default: + w.WriteHeader(http.StatusMethodNotAllowed) + } +} + +func handleXrayLogs(w http.ResponseWriter, r *http.Request) { + if r.Method != http.MethodGet { + w.WriteHeader(http.StatusMethodNotAllowed) + return + } + lines := xrayLogBuf.snapshot() + if lines == nil { + lines = []string{} + } + w.Header().Set("Content-Type", "application/json") + _ = json.NewEncoder(w).Encode(map[string]interface{}{"lines": lines}) +} + +// ---- Inbound / client management ---- + +// XrayClientInfo is a single client entry inside an Xray inbound. +type XrayClientInfo struct { + UUID string `json:"id"` + Email string `json:"email"` + Level int `json:"level,omitempty"` + // Metadata from PostgreSQL (enriched by handleXrayInbounds) + Name string `json:"name,omitempty"` + ExpiresAt *time.Time `json:"expires_at,omitempty"` + ExpirationDays int `json:"expiration_days"` + MaxConns int `json:"max_conns"` + Expired bool `json:"expired,omitempty"` +} + +// XrayInboundInfo is returned by /api/xray/inbounds. +type XrayInboundInfo struct { + Tag string `json:"tag"` + Protocol string `json:"protocol"` + Port json.RawMessage `json:"port,omitempty"` + Listen string `json:"listen,omitempty"` + Clients []XrayClientInfo `json:"clients"` +} + +// protocols that carry a "clients" array in their settings +var xrayClientProtos = map[string]bool{ + "vless": true, "vmess": true, "trojan": true, +} + +// ListInbounds parses the config and returns only inbounds that support client lists. +func (m *XrayManager) ListInbounds() ([]XrayInboundInfo, error) { + m.mu.Lock() + defer m.mu.Unlock() + if m.cfg == nil || m.cfg.ConfigFile == "" { + return nil, fmt.Errorf("xray config file not configured") + } + data, err := os.ReadFile(m.cfg.ConfigFile) + if err != nil { + return nil, err + } + var cfg struct { + Inbounds []json.RawMessage `json:"inbounds"` + } + if err := json.Unmarshal(data, &cfg); err != nil { + return nil, fmt.Errorf("parse xray config: %w", err) + } + var result []XrayInboundInfo + for _, raw := range cfg.Inbounds { + var ib struct { + Tag string `json:"tag"` + Protocol string `json:"protocol"` + Port json.RawMessage `json:"port"` + Listen string `json:"listen"` + Settings struct { + Clients []XrayClientInfo `json:"clients"` + } `json:"settings"` + } + if err := json.Unmarshal(raw, &ib); err != nil { + continue + } + if !xrayClientProtos[strings.ToLower(ib.Protocol)] { + continue + } + clients := ib.Settings.Clients + if clients == nil { + clients = []XrayClientInfo{} + } + result = append(result, XrayInboundInfo{ + Tag: ib.Tag, + Protocol: strings.ToLower(ib.Protocol), + Port: ib.Port, + Listen: ib.Listen, + Clients: clients, + }) + } + if result == nil { + result = []XrayInboundInfo{} + } + return result, nil +} + +// modifyRawConfig reads the config as a generic map, calls fn to mutate it, then writes it back. +// Caller must hold m.mu. +func (m *XrayManager) modifyRawConfig(fn func(cfg map[string]interface{}) error) error { + if m.cfg == nil || m.cfg.ConfigFile == "" { + return fmt.Errorf("xray config file not configured") + } + data, err := os.ReadFile(m.cfg.ConfigFile) + if err != nil { + return err + } + var raw map[string]interface{} + if err := json.Unmarshal(data, &raw); err != nil { + return fmt.Errorf("parse xray config: %w", err) + } + if err := fn(raw); err != nil { + return err + } + out, err := json.MarshalIndent(raw, "", " ") + if err != nil { + return err + } + return os.WriteFile(m.cfg.ConfigFile, out, 0o600) +} + +// AddXrayClient adds a client to the named inbound and saves the config. +func (m *XrayManager) AddXrayClient(inboundTag, uuid, email string) error { + m.mu.Lock() + defer m.mu.Unlock() + return m.modifyRawConfig(func(raw map[string]interface{}) error { + inbounds, _ := raw["inbounds"].([]interface{}) + for _, ib := range inbounds { + ibMap, ok := ib.(map[string]interface{}) + if !ok { + continue + } + if tag, _ := ibMap["tag"].(string); tag != inboundTag { + continue + } + settings, _ := ibMap["settings"].(map[string]interface{}) + if settings == nil { + settings = make(map[string]interface{}) + ibMap["settings"] = settings + } + clients, _ := settings["clients"].([]interface{}) + for _, c := range clients { + if cm, ok := c.(map[string]interface{}); ok { + if id, _ := cm["id"].(string); id == uuid { + return fmt.Errorf("UUID %s already exists in inbound %s", uuid, inboundTag) + } + } + } + settings["clients"] = append(clients, map[string]interface{}{ + "id": uuid, "email": email, "level": 0, + }) + return nil + } + return fmt.Errorf("inbound %q not found", inboundTag) + }) +} + +// RemoveXrayClient removes a client by UUID from the named inbound and saves the config. +func (m *XrayManager) RemoveXrayClient(inboundTag, uuid string) error { + m.mu.Lock() + defer m.mu.Unlock() + return m.modifyRawConfig(func(raw map[string]interface{}) error { + inbounds, _ := raw["inbounds"].([]interface{}) + for _, ib := range inbounds { + ibMap, ok := ib.(map[string]interface{}) + if !ok { + continue + } + if tag, _ := ibMap["tag"].(string); tag != inboundTag { + continue + } + settings, _ := ibMap["settings"].(map[string]interface{}) + if settings == nil { + return fmt.Errorf("inbound %s has no settings", inboundTag) + } + clients, _ := settings["clients"].([]interface{}) + var kept []interface{} + removed := false + for _, c := range clients { + if cm, ok := c.(map[string]interface{}); ok { + if id, _ := cm["id"].(string); id == uuid { + removed = true + continue + } + } + kept = append(kept, c) + } + if !removed { + return fmt.Errorf("UUID %s not found in inbound %s", uuid, inboundTag) + } + settings["clients"] = kept + return nil + } + return fmt.Errorf("inbound %q not found", inboundTag) + }) +} + +// ---- HTTP handlers for inbound/client management ---- + +func handleXrayInbounds(w http.ResponseWriter, r *http.Request) { + if r.Method != http.MethodGet { + w.WriteHeader(http.StatusMethodNotAllowed) + return + } + inbounds, err := xrayMgr.ListInbounds() + if err != nil { + http.Error(w, err.Error(), http.StatusInternalServerError) + return + } + // Enrich clients with metadata from PostgreSQL when available. + if statsStore != nil { + metas, err := statsStore.ListAllXrayClients(r.Context()) + if err == nil { + metaMap := make(map[string]*XrayClientMeta, len(metas)) + for _, m := range metas { + metaMap[m.UUID] = m + } + now := time.Now() + for i := range inbounds { + for j := range inbounds[i].Clients { + c := &inbounds[i].Clients[j] + m, ok := metaMap[c.UUID] + if !ok { + c.ExpirationDays = -1 + continue + } + c.Name = m.Name + c.ExpiresAt = m.ExpiresAt + c.MaxConns = m.MaxConns + if m.ExpiresAt == nil { + c.ExpirationDays = -1 + } else if m.ExpiresAt.Before(now) { + c.Expired = true + c.ExpirationDays = 0 + } else { + c.ExpirationDays = int(m.ExpiresAt.Sub(now).Hours() / 24) + } + } + } + } + } + w.Header().Set("Content-Type", "application/json") + _ = json.NewEncoder(w).Encode(inbounds) +} + +func handleXrayClientAdd(w http.ResponseWriter, r *http.Request) { + if r.Method != http.MethodPost { + w.WriteHeader(http.StatusMethodNotAllowed) + return + } + var req struct { + InboundTag string `json:"inbound_tag"` + UUID string `json:"uuid"` + Email string `json:"email"` + Name string `json:"name"` + ExpiresAt string `json:"expires_at"` // RFC3339 or YYYY-MM-DD or empty + MaxConnections int `json:"max_connections"` + } + if err := json.NewDecoder(r.Body).Decode(&req); err != nil { + http.Error(w, "invalid json", http.StatusBadRequest) + return + } + if req.InboundTag == "" || req.UUID == "" { + http.Error(w, "inbound_tag and uuid required", http.StatusBadRequest) + return + } + if err := xrayMgr.AddXrayClient(req.InboundTag, req.UUID, req.Email); err != nil { + http.Error(w, err.Error(), http.StatusBadRequest) + return + } + if statsStore != nil { + meta := XrayClientMeta{ + UUID: req.UUID, + Name: req.Name, + Email: req.Email, + InboundTag: req.InboundTag, + MaxConns: req.MaxConnections, + } + if req.ExpiresAt != "" { + var t time.Time + var err error + for _, layout := range []string{time.RFC3339, "2006-01-02T15:04", "2006-01-02"} { + t, err = time.Parse(layout, req.ExpiresAt) + if err == nil { + break + } + } + if err == nil { + meta.ExpiresAt = &t + } + } + if err := statsStore.UpsertXrayClientMeta(r.Context(), meta); err != nil { + log.Printf("xray: save meta for %s: %v", req.UUID, err) + } + } + _ = xrayMgr.Restart() + w.WriteHeader(http.StatusCreated) +} + +// handleXrayClientUpdate updates the metadata (name, email, expiry, max_conns) +// of an existing Xray client in PostgreSQL without touching the config file. +func handleXrayClientUpdate(w http.ResponseWriter, r *http.Request) { + if r.Method != http.MethodPost { + w.WriteHeader(http.StatusMethodNotAllowed) + return + } + var req struct { + UUID string `json:"uuid"` + Name string `json:"name"` + Email string `json:"email"` + ExpiresAt string `json:"expires_at"` + MaxConnections int `json:"max_connections"` + } + if err := json.NewDecoder(r.Body).Decode(&req); err != nil { + http.Error(w, "invalid json", http.StatusBadRequest) + return + } + if req.UUID == "" { + http.Error(w, "uuid required", http.StatusBadRequest) + return + } + if statsStore == nil { + http.Error(w, "storage not available", http.StatusInternalServerError) + return + } + meta := XrayClientMeta{ + UUID: req.UUID, + Name: req.Name, + Email: req.Email, + MaxConns: req.MaxConnections, + } + if req.ExpiresAt != "" { + for _, layout := range []string{time.RFC3339, "2006-01-02T15:04", "2006-01-02"} { + if t, err := time.Parse(layout, req.ExpiresAt); err == nil { + meta.ExpiresAt = &t + break + } + } + } + if err := statsStore.UpsertXrayClientMeta(r.Context(), meta); err != nil { + http.Error(w, "update failed: "+err.Error(), http.StatusInternalServerError) + return + } + w.WriteHeader(http.StatusOK) +} + +func handleXrayClientRemove(w http.ResponseWriter, r *http.Request) { + if r.Method != http.MethodDelete { + w.WriteHeader(http.StatusMethodNotAllowed) + return + } + inboundTag := r.URL.Query().Get("inbound_tag") + uuid := r.URL.Query().Get("uuid") + if inboundTag == "" || uuid == "" { + http.Error(w, "inbound_tag and uuid required", http.StatusBadRequest) + return + } + if err := xrayMgr.RemoveXrayClient(inboundTag, uuid); err != nil { + http.Error(w, err.Error(), http.StatusBadRequest) + return + } + if statsStore != nil { + _ = statsStore.DeleteXrayClientMeta(r.Context(), uuid) + } + _ = xrayMgr.Restart() + w.WriteHeader(http.StatusNoContent) +}