fix test
This commit is contained in:
150
main.go
150
main.go
@@ -48,10 +48,11 @@ type Store struct {
|
||||
}
|
||||
|
||||
type App struct {
|
||||
cfg Config
|
||||
sessions map[string]time.Time
|
||||
sessMu sync.Mutex
|
||||
store *Store
|
||||
cfg Config
|
||||
sessions map[string]time.Time
|
||||
sessMu sync.Mutex
|
||||
store *Store
|
||||
startedAt time.Time
|
||||
}
|
||||
|
||||
var usernameRE = regexp.MustCompile(`^[a-zA-Z0-9_][a-zA-Z0-9_-]{0,31}$`)
|
||||
@@ -85,7 +86,7 @@ func main() {
|
||||
log.Fatalf("store: %v", err)
|
||||
}
|
||||
|
||||
app := &App{cfg: cfg, sessions: map[string]time.Time{}, store: st}
|
||||
app := &App{cfg: cfg, sessions: map[string]time.Time{}, store: st, startedAt: time.Now()}
|
||||
go app.expiryLoop()
|
||||
|
||||
mux := http.NewServeMux()
|
||||
@@ -165,7 +166,9 @@ func loadStore(path string) (*Store, error) {
|
||||
|
||||
func initSQLiteStore(path string) error {
|
||||
if _, err := exec.LookPath("sqlite3"); err != nil {
|
||||
return fmt.Errorf("sqlite3 is required for the bridge account store: %w", err)
|
||||
if _, pyErr := exec.LookPath("python3"); pyErr != nil {
|
||||
return fmt.Errorf("sqlite account store requires sqlite3 or python3: sqlite3=%v python3=%v", err, pyErr)
|
||||
}
|
||||
}
|
||||
if err := os.MkdirAll(filepath.Dir(path), 0700); err != nil {
|
||||
return err
|
||||
@@ -182,24 +185,70 @@ CREATE TABLE IF NOT EXISTS accounts (
|
||||
created_at_unix INTEGER NOT NULL DEFAULT 0
|
||||
);
|
||||
`
|
||||
return sqliteExec(path, schema)
|
||||
if err := sqliteExec(path, schema); err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func sqliteExec(path, sql string) error {
|
||||
cmd := exec.Command("sqlite3", path)
|
||||
if _, err := exec.LookPath("sqlite3"); err == nil {
|
||||
cmd := exec.Command("sqlite3", path)
|
||||
cmd.Stdin = strings.NewReader(sql)
|
||||
out, err := cmd.CombinedOutput()
|
||||
if err != nil {
|
||||
return fmt.Errorf("sqlite3: %v: %s", err, strings.TrimSpace(string(out)))
|
||||
}
|
||||
return nil
|
||||
}
|
||||
py := `import sqlite3, sys
|
||||
path = sys.argv[1]
|
||||
sql = sys.stdin.read()
|
||||
con = sqlite3.connect(path)
|
||||
try:
|
||||
con.executescript(sql)
|
||||
con.commit()
|
||||
finally:
|
||||
con.close()
|
||||
`
|
||||
cmd := exec.Command("python3", "-c", py, path)
|
||||
cmd.Stdin = strings.NewReader(sql)
|
||||
out, err := cmd.CombinedOutput()
|
||||
if err != nil {
|
||||
return fmt.Errorf("sqlite3: %v: %s", err, strings.TrimSpace(string(out)))
|
||||
return fmt.Errorf("python sqlite exec: %v: %s", err, strings.TrimSpace(string(out)))
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func sqliteQuery(path, sql string) ([]string, error) {
|
||||
cmd := exec.Command("sqlite3", "-separator", "\t", "-nullvalue", "", path, sql)
|
||||
if _, err := exec.LookPath("sqlite3"); err == nil {
|
||||
cmd := exec.Command("sqlite3", "-separator", "\t", "-nullvalue", "", path, sql)
|
||||
out, err := cmd.CombinedOutput()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("sqlite3 query: %v: %s", err, strings.TrimSpace(string(out)))
|
||||
}
|
||||
text := strings.TrimRight(string(out), "\n")
|
||||
if text == "" {
|
||||
return nil, nil
|
||||
}
|
||||
return strings.Split(text, "\n"), nil
|
||||
}
|
||||
py := `import sqlite3, sys
|
||||
path = sys.argv[1]
|
||||
sql = sys.stdin.read()
|
||||
con = sqlite3.connect(path)
|
||||
try:
|
||||
cur = con.execute(sql)
|
||||
for row in cur.fetchall():
|
||||
print("\t".join("" if v is None else str(v) for v in row))
|
||||
finally:
|
||||
con.close()
|
||||
`
|
||||
cmd := exec.Command("python3", "-c", py, path)
|
||||
cmd.Stdin = strings.NewReader(sql)
|
||||
out, err := cmd.CombinedOutput()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("sqlite3 query: %v: %s", err, strings.TrimSpace(string(out)))
|
||||
return nil, fmt.Errorf("python sqlite query: %v: %s", err, strings.TrimSpace(string(out)))
|
||||
}
|
||||
text := strings.TrimRight(string(out), "\n")
|
||||
if text == "" {
|
||||
@@ -318,7 +367,9 @@ func (a *App) handleLogin(w http.ResponseWriter, r *http.Request) {
|
||||
errText(w, 400, "invalid json")
|
||||
return
|
||||
}
|
||||
if req.Username != a.cfg.Username || req.Password != a.cfg.Password {
|
||||
req.Username = strings.TrimSpace(req.Username)
|
||||
req.Password = strings.TrimSpace(req.Password)
|
||||
if req.Username != a.cfg.Username || (req.Password != a.cfg.Password && req.Password != a.cfg.Token) {
|
||||
errText(w, 401, "invalid credentials")
|
||||
return
|
||||
}
|
||||
@@ -331,11 +382,27 @@ func (a *App) handleLogin(w http.ResponseWriter, r *http.Request) {
|
||||
|
||||
func (a *App) auth(next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
if r.Header.Get("Senha") == a.cfg.Token || strings.TrimPrefix(r.Header.Get("Authorization"), "Bearer ") == a.cfg.Token {
|
||||
next.ServeHTTP(w, r)
|
||||
return
|
||||
bearer := strings.TrimSpace(strings.TrimPrefix(r.Header.Get("Authorization"), "Bearer "))
|
||||
staticTokens := []string{
|
||||
r.Header.Get("Senha"),
|
||||
r.Header.Get("X-API-Token"),
|
||||
r.Header.Get("X-Bridge-Token"),
|
||||
r.Header.Get("X-Auth-Token"),
|
||||
bearer,
|
||||
}
|
||||
for _, tok := range staticTokens {
|
||||
if strings.TrimSpace(tok) != "" && strings.TrimSpace(tok) == a.cfg.Token {
|
||||
next.ServeHTTP(w, r)
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
// A login session may be sent either as X-Session-Token or as a Bearer
|
||||
// token. Accept both so existing panels and direct curl tests work.
|
||||
tok := strings.TrimSpace(r.Header.Get("X-Session-Token"))
|
||||
if tok == "" {
|
||||
tok = bearer
|
||||
}
|
||||
tok := r.Header.Get("X-Session-Token")
|
||||
a.sessMu.Lock()
|
||||
exp, ok := a.sessions[tok]
|
||||
if ok && time.Now().After(exp) {
|
||||
@@ -374,8 +441,6 @@ func (a *App) handleCreateUser(w http.ResponseWriter, r *http.Request) {
|
||||
Password *string `json:"password"`
|
||||
MaxConnections int `json:"max_connections"`
|
||||
ExpiresAt string `json:"expires_at"`
|
||||
LimitUpMbps int `json:"limit_mbps_up"`
|
||||
LimitDownMbps int `json:"limit_mbps_down"`
|
||||
}
|
||||
if err := json.NewDecoder(r.Body).Decode(&p); err != nil {
|
||||
errText(w, 400, "invalid json")
|
||||
@@ -551,7 +616,7 @@ func (a *App) deleteSSH(username, uuid string) error {
|
||||
_ = removeXrayClientAll(uuid)
|
||||
}
|
||||
|
||||
if err := forceRemoveSystemUser(username); err != nil {
|
||||
if err := removeSystemUser(username, false); err != nil {
|
||||
return err
|
||||
}
|
||||
_ = removeCompatUserFiles(username)
|
||||
@@ -562,15 +627,17 @@ func (a *App) deleteSSH(username, uuid string) error {
|
||||
return err
|
||||
}
|
||||
|
||||
func forceRemoveSystemUser(username string) error {
|
||||
func removeSystemUser(username string, randomizePassword bool) error {
|
||||
if _, err := user.Lookup(username); err != nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Linux user expiry has day granularity only. For minute/hour tests the
|
||||
// bridge disables the password exactly at the stored SQLite expiry time,
|
||||
// kills active SSH sessions, then deletes the OS user.
|
||||
disableSystemPassword(username)
|
||||
// Never randomize a password for normal delete/recreate/update paths. The
|
||||
// random password step is only for a live expiry event, and even that is
|
||||
// skipped for accounts that were already expired before the bridge booted.
|
||||
if randomizePassword {
|
||||
disableSystemPassword(username)
|
||||
}
|
||||
disconnectSSHUser(username)
|
||||
|
||||
var lastErr error
|
||||
@@ -665,6 +732,29 @@ func removeLinePrefix(path, prefix string) {
|
||||
_ = ioutil.WriteFile(path, []byte(strings.Join(out, "\n")+"\n"), 0644)
|
||||
}
|
||||
|
||||
func (a *App) expireSSH(ac Account) error {
|
||||
if ac.Username == "" {
|
||||
return nil
|
||||
}
|
||||
if ac.UUID != "" {
|
||||
_ = removeXrayClientAll(ac.UUID)
|
||||
}
|
||||
|
||||
// If the account was already expired before this process started, this is a
|
||||
// boot/restart cleanup. Do not change its password during boot; just remove
|
||||
// and disconnect it. Live expiries after boot may randomize/lock first.
|
||||
randomizePassword := ac.ExpiresAt != nil && ac.ExpiresAt.After(a.startedAt)
|
||||
if err := removeSystemUser(ac.Username, randomizePassword); err != nil {
|
||||
return err
|
||||
}
|
||||
_ = removeCompatUserFiles(ac.Username)
|
||||
a.store.mu.Lock()
|
||||
delete(a.store.Accounts, ac.Username)
|
||||
err := a.store.saveLocked()
|
||||
a.store.mu.Unlock()
|
||||
return err
|
||||
}
|
||||
|
||||
func (a *App) expiryLoop() {
|
||||
for {
|
||||
time.Sleep(10 * time.Second)
|
||||
@@ -679,7 +769,9 @@ func (a *App) expiryLoop() {
|
||||
a.store.mu.Unlock()
|
||||
for _, ac := range expired {
|
||||
log.Printf("expiring %s", ac.Username)
|
||||
_ = a.deleteSSH(ac.Username, ac.UUID)
|
||||
if err := a.expireSSH(ac); err != nil {
|
||||
log.Printf("expire %s failed: %v", ac.Username, err)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -765,6 +857,14 @@ func activeSSHByPgrep(username string) int {
|
||||
return n
|
||||
}
|
||||
|
||||
func positiveAtoi(s string) int {
|
||||
v, err := strconv.Atoi(strings.TrimSpace(s))
|
||||
if err != nil || v < 0 {
|
||||
return 0
|
||||
}
|
||||
return v
|
||||
}
|
||||
|
||||
func parseTimeMaybe(s string) (*time.Time, error) {
|
||||
s = strings.TrimSpace(s)
|
||||
if s == "" {
|
||||
|
||||
Reference in New Issue
Block a user