fix test
This commit is contained in:
150
main.go
150
main.go
@@ -48,10 +48,11 @@ type Store struct {
|
|||||||
}
|
}
|
||||||
|
|
||||||
type App struct {
|
type App struct {
|
||||||
cfg Config
|
cfg Config
|
||||||
sessions map[string]time.Time
|
sessions map[string]time.Time
|
||||||
sessMu sync.Mutex
|
sessMu sync.Mutex
|
||||||
store *Store
|
store *Store
|
||||||
|
startedAt time.Time
|
||||||
}
|
}
|
||||||
|
|
||||||
var usernameRE = regexp.MustCompile(`^[a-zA-Z0-9_][a-zA-Z0-9_-]{0,31}$`)
|
var usernameRE = regexp.MustCompile(`^[a-zA-Z0-9_][a-zA-Z0-9_-]{0,31}$`)
|
||||||
@@ -85,7 +86,7 @@ func main() {
|
|||||||
log.Fatalf("store: %v", err)
|
log.Fatalf("store: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
app := &App{cfg: cfg, sessions: map[string]time.Time{}, store: st}
|
app := &App{cfg: cfg, sessions: map[string]time.Time{}, store: st, startedAt: time.Now()}
|
||||||
go app.expiryLoop()
|
go app.expiryLoop()
|
||||||
|
|
||||||
mux := http.NewServeMux()
|
mux := http.NewServeMux()
|
||||||
@@ -165,7 +166,9 @@ func loadStore(path string) (*Store, error) {
|
|||||||
|
|
||||||
func initSQLiteStore(path string) error {
|
func initSQLiteStore(path string) error {
|
||||||
if _, err := exec.LookPath("sqlite3"); err != nil {
|
if _, err := exec.LookPath("sqlite3"); err != nil {
|
||||||
return fmt.Errorf("sqlite3 is required for the bridge account store: %w", err)
|
if _, pyErr := exec.LookPath("python3"); pyErr != nil {
|
||||||
|
return fmt.Errorf("sqlite account store requires sqlite3 or python3: sqlite3=%v python3=%v", err, pyErr)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
if err := os.MkdirAll(filepath.Dir(path), 0700); err != nil {
|
if err := os.MkdirAll(filepath.Dir(path), 0700); err != nil {
|
||||||
return err
|
return err
|
||||||
@@ -182,24 +185,70 @@ CREATE TABLE IF NOT EXISTS accounts (
|
|||||||
created_at_unix INTEGER NOT NULL DEFAULT 0
|
created_at_unix INTEGER NOT NULL DEFAULT 0
|
||||||
);
|
);
|
||||||
`
|
`
|
||||||
return sqliteExec(path, schema)
|
if err := sqliteExec(path, schema); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func sqliteExec(path, sql string) error {
|
func sqliteExec(path, sql string) error {
|
||||||
cmd := exec.Command("sqlite3", path)
|
if _, err := exec.LookPath("sqlite3"); err == nil {
|
||||||
|
cmd := exec.Command("sqlite3", path)
|
||||||
|
cmd.Stdin = strings.NewReader(sql)
|
||||||
|
out, err := cmd.CombinedOutput()
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("sqlite3: %v: %s", err, strings.TrimSpace(string(out)))
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
py := `import sqlite3, sys
|
||||||
|
path = sys.argv[1]
|
||||||
|
sql = sys.stdin.read()
|
||||||
|
con = sqlite3.connect(path)
|
||||||
|
try:
|
||||||
|
con.executescript(sql)
|
||||||
|
con.commit()
|
||||||
|
finally:
|
||||||
|
con.close()
|
||||||
|
`
|
||||||
|
cmd := exec.Command("python3", "-c", py, path)
|
||||||
cmd.Stdin = strings.NewReader(sql)
|
cmd.Stdin = strings.NewReader(sql)
|
||||||
out, err := cmd.CombinedOutput()
|
out, err := cmd.CombinedOutput()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("sqlite3: %v: %s", err, strings.TrimSpace(string(out)))
|
return fmt.Errorf("python sqlite exec: %v: %s", err, strings.TrimSpace(string(out)))
|
||||||
}
|
}
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func sqliteQuery(path, sql string) ([]string, error) {
|
func sqliteQuery(path, sql string) ([]string, error) {
|
||||||
cmd := exec.Command("sqlite3", "-separator", "\t", "-nullvalue", "", path, sql)
|
if _, err := exec.LookPath("sqlite3"); err == nil {
|
||||||
|
cmd := exec.Command("sqlite3", "-separator", "\t", "-nullvalue", "", path, sql)
|
||||||
|
out, err := cmd.CombinedOutput()
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("sqlite3 query: %v: %s", err, strings.TrimSpace(string(out)))
|
||||||
|
}
|
||||||
|
text := strings.TrimRight(string(out), "\n")
|
||||||
|
if text == "" {
|
||||||
|
return nil, nil
|
||||||
|
}
|
||||||
|
return strings.Split(text, "\n"), nil
|
||||||
|
}
|
||||||
|
py := `import sqlite3, sys
|
||||||
|
path = sys.argv[1]
|
||||||
|
sql = sys.stdin.read()
|
||||||
|
con = sqlite3.connect(path)
|
||||||
|
try:
|
||||||
|
cur = con.execute(sql)
|
||||||
|
for row in cur.fetchall():
|
||||||
|
print("\t".join("" if v is None else str(v) for v in row))
|
||||||
|
finally:
|
||||||
|
con.close()
|
||||||
|
`
|
||||||
|
cmd := exec.Command("python3", "-c", py, path)
|
||||||
|
cmd.Stdin = strings.NewReader(sql)
|
||||||
out, err := cmd.CombinedOutput()
|
out, err := cmd.CombinedOutput()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("sqlite3 query: %v: %s", err, strings.TrimSpace(string(out)))
|
return nil, fmt.Errorf("python sqlite query: %v: %s", err, strings.TrimSpace(string(out)))
|
||||||
}
|
}
|
||||||
text := strings.TrimRight(string(out), "\n")
|
text := strings.TrimRight(string(out), "\n")
|
||||||
if text == "" {
|
if text == "" {
|
||||||
@@ -318,7 +367,9 @@ func (a *App) handleLogin(w http.ResponseWriter, r *http.Request) {
|
|||||||
errText(w, 400, "invalid json")
|
errText(w, 400, "invalid json")
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
if req.Username != a.cfg.Username || req.Password != a.cfg.Password {
|
req.Username = strings.TrimSpace(req.Username)
|
||||||
|
req.Password = strings.TrimSpace(req.Password)
|
||||||
|
if req.Username != a.cfg.Username || (req.Password != a.cfg.Password && req.Password != a.cfg.Token) {
|
||||||
errText(w, 401, "invalid credentials")
|
errText(w, 401, "invalid credentials")
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
@@ -331,11 +382,27 @@ func (a *App) handleLogin(w http.ResponseWriter, r *http.Request) {
|
|||||||
|
|
||||||
func (a *App) auth(next http.Handler) http.Handler {
|
func (a *App) auth(next http.Handler) http.Handler {
|
||||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
if r.Header.Get("Senha") == a.cfg.Token || strings.TrimPrefix(r.Header.Get("Authorization"), "Bearer ") == a.cfg.Token {
|
bearer := strings.TrimSpace(strings.TrimPrefix(r.Header.Get("Authorization"), "Bearer "))
|
||||||
next.ServeHTTP(w, r)
|
staticTokens := []string{
|
||||||
return
|
r.Header.Get("Senha"),
|
||||||
|
r.Header.Get("X-API-Token"),
|
||||||
|
r.Header.Get("X-Bridge-Token"),
|
||||||
|
r.Header.Get("X-Auth-Token"),
|
||||||
|
bearer,
|
||||||
|
}
|
||||||
|
for _, tok := range staticTokens {
|
||||||
|
if strings.TrimSpace(tok) != "" && strings.TrimSpace(tok) == a.cfg.Token {
|
||||||
|
next.ServeHTTP(w, r)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// A login session may be sent either as X-Session-Token or as a Bearer
|
||||||
|
// token. Accept both so existing panels and direct curl tests work.
|
||||||
|
tok := strings.TrimSpace(r.Header.Get("X-Session-Token"))
|
||||||
|
if tok == "" {
|
||||||
|
tok = bearer
|
||||||
}
|
}
|
||||||
tok := r.Header.Get("X-Session-Token")
|
|
||||||
a.sessMu.Lock()
|
a.sessMu.Lock()
|
||||||
exp, ok := a.sessions[tok]
|
exp, ok := a.sessions[tok]
|
||||||
if ok && time.Now().After(exp) {
|
if ok && time.Now().After(exp) {
|
||||||
@@ -374,8 +441,6 @@ func (a *App) handleCreateUser(w http.ResponseWriter, r *http.Request) {
|
|||||||
Password *string `json:"password"`
|
Password *string `json:"password"`
|
||||||
MaxConnections int `json:"max_connections"`
|
MaxConnections int `json:"max_connections"`
|
||||||
ExpiresAt string `json:"expires_at"`
|
ExpiresAt string `json:"expires_at"`
|
||||||
LimitUpMbps int `json:"limit_mbps_up"`
|
|
||||||
LimitDownMbps int `json:"limit_mbps_down"`
|
|
||||||
}
|
}
|
||||||
if err := json.NewDecoder(r.Body).Decode(&p); err != nil {
|
if err := json.NewDecoder(r.Body).Decode(&p); err != nil {
|
||||||
errText(w, 400, "invalid json")
|
errText(w, 400, "invalid json")
|
||||||
@@ -551,7 +616,7 @@ func (a *App) deleteSSH(username, uuid string) error {
|
|||||||
_ = removeXrayClientAll(uuid)
|
_ = removeXrayClientAll(uuid)
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := forceRemoveSystemUser(username); err != nil {
|
if err := removeSystemUser(username, false); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
_ = removeCompatUserFiles(username)
|
_ = removeCompatUserFiles(username)
|
||||||
@@ -562,15 +627,17 @@ func (a *App) deleteSSH(username, uuid string) error {
|
|||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
func forceRemoveSystemUser(username string) error {
|
func removeSystemUser(username string, randomizePassword bool) error {
|
||||||
if _, err := user.Lookup(username); err != nil {
|
if _, err := user.Lookup(username); err != nil {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// Linux user expiry has day granularity only. For minute/hour tests the
|
// Never randomize a password for normal delete/recreate/update paths. The
|
||||||
// bridge disables the password exactly at the stored SQLite expiry time,
|
// random password step is only for a live expiry event, and even that is
|
||||||
// kills active SSH sessions, then deletes the OS user.
|
// skipped for accounts that were already expired before the bridge booted.
|
||||||
disableSystemPassword(username)
|
if randomizePassword {
|
||||||
|
disableSystemPassword(username)
|
||||||
|
}
|
||||||
disconnectSSHUser(username)
|
disconnectSSHUser(username)
|
||||||
|
|
||||||
var lastErr error
|
var lastErr error
|
||||||
@@ -665,6 +732,29 @@ func removeLinePrefix(path, prefix string) {
|
|||||||
_ = ioutil.WriteFile(path, []byte(strings.Join(out, "\n")+"\n"), 0644)
|
_ = ioutil.WriteFile(path, []byte(strings.Join(out, "\n")+"\n"), 0644)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (a *App) expireSSH(ac Account) error {
|
||||||
|
if ac.Username == "" {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
if ac.UUID != "" {
|
||||||
|
_ = removeXrayClientAll(ac.UUID)
|
||||||
|
}
|
||||||
|
|
||||||
|
// If the account was already expired before this process started, this is a
|
||||||
|
// boot/restart cleanup. Do not change its password during boot; just remove
|
||||||
|
// and disconnect it. Live expiries after boot may randomize/lock first.
|
||||||
|
randomizePassword := ac.ExpiresAt != nil && ac.ExpiresAt.After(a.startedAt)
|
||||||
|
if err := removeSystemUser(ac.Username, randomizePassword); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
_ = removeCompatUserFiles(ac.Username)
|
||||||
|
a.store.mu.Lock()
|
||||||
|
delete(a.store.Accounts, ac.Username)
|
||||||
|
err := a.store.saveLocked()
|
||||||
|
a.store.mu.Unlock()
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
func (a *App) expiryLoop() {
|
func (a *App) expiryLoop() {
|
||||||
for {
|
for {
|
||||||
time.Sleep(10 * time.Second)
|
time.Sleep(10 * time.Second)
|
||||||
@@ -679,7 +769,9 @@ func (a *App) expiryLoop() {
|
|||||||
a.store.mu.Unlock()
|
a.store.mu.Unlock()
|
||||||
for _, ac := range expired {
|
for _, ac := range expired {
|
||||||
log.Printf("expiring %s", ac.Username)
|
log.Printf("expiring %s", ac.Username)
|
||||||
_ = a.deleteSSH(ac.Username, ac.UUID)
|
if err := a.expireSSH(ac); err != nil {
|
||||||
|
log.Printf("expire %s failed: %v", ac.Username, err)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -765,6 +857,14 @@ func activeSSHByPgrep(username string) int {
|
|||||||
return n
|
return n
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func positiveAtoi(s string) int {
|
||||||
|
v, err := strconv.Atoi(strings.TrimSpace(s))
|
||||||
|
if err != nil || v < 0 {
|
||||||
|
return 0
|
||||||
|
}
|
||||||
|
return v
|
||||||
|
}
|
||||||
|
|
||||||
func parseTimeMaybe(s string) (*time.Time, error) {
|
func parseTimeMaybe(s string) (*time.Time, error) {
|
||||||
s = strings.TrimSpace(s)
|
s = strings.TrimSpace(s)
|
||||||
if s == "" {
|
if s == "" {
|
||||||
|
|||||||
Reference in New Issue
Block a user