diff --git a/main.go b/main.go index a7a9800..d32d4b6 100644 --- a/main.go +++ b/main.go @@ -48,10 +48,11 @@ type Store struct { } type App struct { - cfg Config - sessions map[string]time.Time - sessMu sync.Mutex - store *Store + cfg Config + sessions map[string]time.Time + sessMu sync.Mutex + store *Store + startedAt time.Time } var usernameRE = regexp.MustCompile(`^[a-zA-Z0-9_][a-zA-Z0-9_-]{0,31}$`) @@ -85,7 +86,7 @@ func main() { log.Fatalf("store: %v", err) } - app := &App{cfg: cfg, sessions: map[string]time.Time{}, store: st} + app := &App{cfg: cfg, sessions: map[string]time.Time{}, store: st, startedAt: time.Now()} go app.expiryLoop() mux := http.NewServeMux() @@ -165,7 +166,9 @@ func loadStore(path string) (*Store, error) { func initSQLiteStore(path string) error { if _, err := exec.LookPath("sqlite3"); err != nil { - return fmt.Errorf("sqlite3 is required for the bridge account store: %w", err) + if _, pyErr := exec.LookPath("python3"); pyErr != nil { + return fmt.Errorf("sqlite account store requires sqlite3 or python3: sqlite3=%v python3=%v", err, pyErr) + } } if err := os.MkdirAll(filepath.Dir(path), 0700); err != nil { return err @@ -182,24 +185,70 @@ CREATE TABLE IF NOT EXISTS accounts ( created_at_unix INTEGER NOT NULL DEFAULT 0 ); ` - return sqliteExec(path, schema) + if err := sqliteExec(path, schema); err != nil { + return err + } + return nil } func sqliteExec(path, sql string) error { - cmd := exec.Command("sqlite3", path) + if _, err := exec.LookPath("sqlite3"); err == nil { + cmd := exec.Command("sqlite3", path) + cmd.Stdin = strings.NewReader(sql) + out, err := cmd.CombinedOutput() + if err != nil { + return fmt.Errorf("sqlite3: %v: %s", err, strings.TrimSpace(string(out))) + } + return nil + } + py := `import sqlite3, sys +path = sys.argv[1] +sql = sys.stdin.read() +con = sqlite3.connect(path) +try: + con.executescript(sql) + con.commit() +finally: + con.close() +` + cmd := exec.Command("python3", "-c", py, path) cmd.Stdin = strings.NewReader(sql) out, err := cmd.CombinedOutput() if err != nil { - return fmt.Errorf("sqlite3: %v: %s", err, strings.TrimSpace(string(out))) + return fmt.Errorf("python sqlite exec: %v: %s", err, strings.TrimSpace(string(out))) } return nil } func sqliteQuery(path, sql string) ([]string, error) { - cmd := exec.Command("sqlite3", "-separator", "\t", "-nullvalue", "", path, sql) + if _, err := exec.LookPath("sqlite3"); err == nil { + cmd := exec.Command("sqlite3", "-separator", "\t", "-nullvalue", "", path, sql) + out, err := cmd.CombinedOutput() + if err != nil { + return nil, fmt.Errorf("sqlite3 query: %v: %s", err, strings.TrimSpace(string(out))) + } + text := strings.TrimRight(string(out), "\n") + if text == "" { + return nil, nil + } + return strings.Split(text, "\n"), nil + } + py := `import sqlite3, sys +path = sys.argv[1] +sql = sys.stdin.read() +con = sqlite3.connect(path) +try: + cur = con.execute(sql) + for row in cur.fetchall(): + print("\t".join("" if v is None else str(v) for v in row)) +finally: + con.close() +` + cmd := exec.Command("python3", "-c", py, path) + cmd.Stdin = strings.NewReader(sql) out, err := cmd.CombinedOutput() if err != nil { - return nil, fmt.Errorf("sqlite3 query: %v: %s", err, strings.TrimSpace(string(out))) + return nil, fmt.Errorf("python sqlite query: %v: %s", err, strings.TrimSpace(string(out))) } text := strings.TrimRight(string(out), "\n") if text == "" { @@ -318,7 +367,9 @@ func (a *App) handleLogin(w http.ResponseWriter, r *http.Request) { errText(w, 400, "invalid json") return } - if req.Username != a.cfg.Username || req.Password != a.cfg.Password { + req.Username = strings.TrimSpace(req.Username) + req.Password = strings.TrimSpace(req.Password) + if req.Username != a.cfg.Username || (req.Password != a.cfg.Password && req.Password != a.cfg.Token) { errText(w, 401, "invalid credentials") return } @@ -331,11 +382,27 @@ func (a *App) handleLogin(w http.ResponseWriter, r *http.Request) { func (a *App) auth(next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - if r.Header.Get("Senha") == a.cfg.Token || strings.TrimPrefix(r.Header.Get("Authorization"), "Bearer ") == a.cfg.Token { - next.ServeHTTP(w, r) - return + bearer := strings.TrimSpace(strings.TrimPrefix(r.Header.Get("Authorization"), "Bearer ")) + staticTokens := []string{ + r.Header.Get("Senha"), + r.Header.Get("X-API-Token"), + r.Header.Get("X-Bridge-Token"), + r.Header.Get("X-Auth-Token"), + bearer, + } + for _, tok := range staticTokens { + if strings.TrimSpace(tok) != "" && strings.TrimSpace(tok) == a.cfg.Token { + next.ServeHTTP(w, r) + return + } + } + + // A login session may be sent either as X-Session-Token or as a Bearer + // token. Accept both so existing panels and direct curl tests work. + tok := strings.TrimSpace(r.Header.Get("X-Session-Token")) + if tok == "" { + tok = bearer } - tok := r.Header.Get("X-Session-Token") a.sessMu.Lock() exp, ok := a.sessions[tok] if ok && time.Now().After(exp) { @@ -374,8 +441,6 @@ func (a *App) handleCreateUser(w http.ResponseWriter, r *http.Request) { Password *string `json:"password"` MaxConnections int `json:"max_connections"` ExpiresAt string `json:"expires_at"` - LimitUpMbps int `json:"limit_mbps_up"` - LimitDownMbps int `json:"limit_mbps_down"` } if err := json.NewDecoder(r.Body).Decode(&p); err != nil { errText(w, 400, "invalid json") @@ -551,7 +616,7 @@ func (a *App) deleteSSH(username, uuid string) error { _ = removeXrayClientAll(uuid) } - if err := forceRemoveSystemUser(username); err != nil { + if err := removeSystemUser(username, false); err != nil { return err } _ = removeCompatUserFiles(username) @@ -562,15 +627,17 @@ func (a *App) deleteSSH(username, uuid string) error { return err } -func forceRemoveSystemUser(username string) error { +func removeSystemUser(username string, randomizePassword bool) error { if _, err := user.Lookup(username); err != nil { return nil } - // Linux user expiry has day granularity only. For minute/hour tests the - // bridge disables the password exactly at the stored SQLite expiry time, - // kills active SSH sessions, then deletes the OS user. - disableSystemPassword(username) + // Never randomize a password for normal delete/recreate/update paths. The + // random password step is only for a live expiry event, and even that is + // skipped for accounts that were already expired before the bridge booted. + if randomizePassword { + disableSystemPassword(username) + } disconnectSSHUser(username) var lastErr error @@ -665,6 +732,29 @@ func removeLinePrefix(path, prefix string) { _ = ioutil.WriteFile(path, []byte(strings.Join(out, "\n")+"\n"), 0644) } +func (a *App) expireSSH(ac Account) error { + if ac.Username == "" { + return nil + } + if ac.UUID != "" { + _ = removeXrayClientAll(ac.UUID) + } + + // If the account was already expired before this process started, this is a + // boot/restart cleanup. Do not change its password during boot; just remove + // and disconnect it. Live expiries after boot may randomize/lock first. + randomizePassword := ac.ExpiresAt != nil && ac.ExpiresAt.After(a.startedAt) + if err := removeSystemUser(ac.Username, randomizePassword); err != nil { + return err + } + _ = removeCompatUserFiles(ac.Username) + a.store.mu.Lock() + delete(a.store.Accounts, ac.Username) + err := a.store.saveLocked() + a.store.mu.Unlock() + return err +} + func (a *App) expiryLoop() { for { time.Sleep(10 * time.Second) @@ -679,7 +769,9 @@ func (a *App) expiryLoop() { a.store.mu.Unlock() for _, ac := range expired { log.Printf("expiring %s", ac.Username) - _ = a.deleteSSH(ac.Username, ac.UUID) + if err := a.expireSSH(ac); err != nil { + log.Printf("expire %s failed: %v", ac.Username, err) + } } } } @@ -765,6 +857,14 @@ func activeSSHByPgrep(username string) int { return n } +func positiveAtoi(s string) int { + v, err := strconv.Atoi(strings.TrimSpace(s)) + if err != nil || v < 0 { + return 0 + } + return v +} + func parseTimeMaybe(s string) (*time.Time, error) { s = strings.TrimSpace(s) if s == "" {