Safe Update
This commit is contained in:
@@ -14,6 +14,7 @@ import (
|
||||
"encoding/base32"
|
||||
"encoding/binary"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"log"
|
||||
@@ -332,55 +333,79 @@ func getDNSTTLogLines() []string {
|
||||
// the Noise private key from cfg.PrivKeyFile, parses cfg.Domain into a dns.Name,
|
||||
// and then launches runDNSTT in a goroutine. Any errors during start are
|
||||
// logged. The SSH server configuration is used when handling streams.
|
||||
func startDNSTT(cfg *DNSTTConfig, sshConf *ssh.ServerConfig) {
|
||||
func startDNSTT(cfg *DNSTTConfig, sshConf *ssh.ServerConfig) error {
|
||||
if cfg == nil {
|
||||
return
|
||||
return nil
|
||||
}
|
||||
startDNSTTCapReaper()
|
||||
dnsttSSHConfig = sshConf
|
||||
// Configure whether periodic DNSTT statistics should be emitted to stderr.
|
||||
// When DisableStatsLog is true, stats will be collected but log lines are suppressed.
|
||||
if cfg != nil {
|
||||
dnsttPrintStats = !cfg.DisableStatsLog
|
||||
// Initialise the log buffer once. Use a capacity of 100 lines (~few KB).
|
||||
if dnsttLogBuf == nil {
|
||||
dnsttLogBuf = newDNSTTLogBuffer(100)
|
||||
}
|
||||
// Configure the DNSTT logger output. If DisableConsoleLog is set,
|
||||
// write only to the buffer; otherwise tee to both the buffer and stderr.
|
||||
if cfg.DisableConsoleLog {
|
||||
dnsttLog.SetOutput(dnsttLogBuf)
|
||||
} else {
|
||||
dnsttLog.SetOutput(io.MultiWriter(dnsttLogBuf, os.Stderr))
|
||||
}
|
||||
dnsttPrintStats = !cfg.DisableStatsLog
|
||||
// Initialise the log buffer once. Use a capacity of 100 lines (~few KB).
|
||||
if dnsttLogBuf == nil {
|
||||
dnsttLogBuf = newDNSTTLogBuffer(100)
|
||||
}
|
||||
// Configure the DNSTT logger output. If DisableConsoleLog is set,
|
||||
// write only to the buffer; otherwise tee to both the buffer and stderr.
|
||||
if cfg.DisableConsoleLog {
|
||||
dnsttLog.SetOutput(dnsttLogBuf)
|
||||
} else {
|
||||
dnsttLog.SetOutput(io.MultiWriter(dnsttLogBuf, os.Stderr))
|
||||
}
|
||||
|
||||
// Read the private key from file.
|
||||
f, err := os.Open(cfg.PrivKeyFile)
|
||||
if err != nil {
|
||||
dnsttLog.Printf("cannot open privkey file %s: %v", cfg.PrivKeyFile, err)
|
||||
return
|
||||
msg := fmt.Errorf("cannot open privkey file %s: %w", cfg.PrivKeyFile, err)
|
||||
dnsttLog.Print(msg.Error())
|
||||
return msg
|
||||
}
|
||||
privkey, err := noise.ReadKey(f)
|
||||
f.Close()
|
||||
if err != nil {
|
||||
dnsttLog.Printf("cannot read privkey from file: %v", err)
|
||||
return
|
||||
msg := fmt.Errorf("cannot read privkey from file: %w", err)
|
||||
dnsttLog.Print(msg.Error())
|
||||
return msg
|
||||
}
|
||||
// Parse the domain name. dns.ParseName accepts a domain with a trailing
|
||||
// dot or without. Any error here will abort the dnstt server.
|
||||
// Parse the domain name. dns.ParseName accepts a domain with a trailing
|
||||
// dot or without. Any error here will abort the dnstt server.
|
||||
domain, err := dns.ParseName(cfg.Domain)
|
||||
if err != nil {
|
||||
dnsttLog.Printf("invalid domain %q: %v", cfg.Domain, err)
|
||||
return
|
||||
msg := fmt.Errorf("invalid domain %q: %w", cfg.Domain, err)
|
||||
dnsttLog.Print(msg.Error())
|
||||
return msg
|
||||
}
|
||||
udpListen := cfg.UDPListen
|
||||
if udpListen == "" {
|
||||
udpListen = defaultDNSTTListen
|
||||
cfg.UDPListen = udpListen
|
||||
}
|
||||
|
||||
// Bind synchronously so the admin panel can immediately know whether DNSTT
|
||||
// really started or failed because of a bad address/locked port.
|
||||
dnsConn, err := net.ListenPacket("udp", udpListen)
|
||||
if err != nil {
|
||||
msg := fmt.Errorf("dnstt: opening UDP listener on %s: %w", udpListen, err)
|
||||
dnsttLog.Print(msg.Error())
|
||||
return msg
|
||||
}
|
||||
|
||||
// Log initialisation parameters so DNSTT startup is visible even when
|
||||
// quiet logging is enabled. This helps with debugging.
|
||||
dnsttLog.Printf("starting: domain=%q udp_listen=%q privkey=%q", cfg.Domain, cfg.UDPListen, cfg.PrivKeyFile)
|
||||
// quiet logging is enabled. This helps with debugging.
|
||||
dnsttLog.Printf("starting: domain=%q udp_listen=%q privkey=%q", cfg.Domain, udpListen, cfg.PrivKeyFile)
|
||||
go func() {
|
||||
if err := runDNSTT(privkey, domain, cfg.UDPListen); err != nil {
|
||||
if err := runDNSTTOnConn(privkey, domain, udpListen, dnsConn); err != nil && !errors.Is(err, net.ErrClosed) {
|
||||
dnsttLog.Printf("server exited with error: %v", err)
|
||||
}
|
||||
}()
|
||||
return nil
|
||||
}
|
||||
|
||||
func dnsttRunning() bool {
|
||||
dnsttConnMu.Lock()
|
||||
defer dnsttConnMu.Unlock()
|
||||
return dnsttConn != nil
|
||||
}
|
||||
|
||||
// handleDNSTTStream accepts a smux.Stream from a client and hands it off to
|
||||
@@ -991,6 +1016,10 @@ func runDNSTT(privkey []byte, domain dns.Name, udpListen string) error {
|
||||
if err != nil {
|
||||
return fmt.Errorf("dnstt: opening UDP listener on %s: %v", udpListen, err)
|
||||
}
|
||||
return runDNSTTOnConn(privkey, domain, udpListen, dnsConn)
|
||||
}
|
||||
|
||||
func runDNSTTOnConn(privkey []byte, domain dns.Name, udpListen string, dnsConn net.PacketConn) error {
|
||||
if udp, ok := dnsConn.(*net.UDPConn); ok {
|
||||
_ = udp.SetReadBuffer(4 * 1024 * 1024)
|
||||
_ = udp.SetWriteBuffer(4 * 1024 * 1024)
|
||||
@@ -998,11 +1027,18 @@ func runDNSTT(privkey []byte, domain dns.Name, udpListen string) error {
|
||||
|
||||
// Register so stopDNSTT() can close this socket and unblock the read loop.
|
||||
dnsttConnMu.Lock()
|
||||
if dnsttConn != nil {
|
||||
if dnsttConn != nil && dnsttConn != dnsConn {
|
||||
_ = dnsttConn.Close()
|
||||
}
|
||||
dnsttConn = dnsConn
|
||||
dnsttConnMu.Unlock()
|
||||
defer func() {
|
||||
dnsttConnMu.Lock()
|
||||
if dnsttConn == dnsConn {
|
||||
dnsttConn = nil
|
||||
}
|
||||
dnsttConnMu.Unlock()
|
||||
}()
|
||||
// Log readiness of the UDP listener.
|
||||
dnsttLog.Printf("udp listener ready on %s", udpListen)
|
||||
// compute maximum encoded payload and resulting MTU
|
||||
|
||||
Reference in New Issue
Block a user