Files
DragonCoreSSH-NewWEB/udpgw_integration.go
2026-05-02 23:20:13 -03:00

499 lines
14 KiB
Go
Raw Permalink Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

package main
// This file embeds the UDP gateway (udpgw) service into the main
// application. The original udpgw program (public domain) accepts a
// TCP connection, then forwards framed UDP datagrams to arbitrary
// destinations and demultiplexes replies back to the originating
// connection. Here we expose the same functionality via a
// configuration key in config.json. When enabled, the server binds
// to the provided Listen address (or the default 0.0.0.0:7400 if
// unspecified) and handles each client in its own goroutine. The
// gateway runs entirely in-process and behaves like the standalone
// badvpn-udpgw daemon.
import (
"bufio"
"context"
"encoding/binary"
"errors"
"fmt"
"io"
"log"
"net"
"sync"
"time"
)
// udpgwMu guards udpgwLn, the listener of the currently running UDPGW
// instance. udpgwLn is nil whenever the gateway is stopped.
var udpgwMu sync.Mutex
var udpgwLn net.Listener

// stopUDPGW shuts down the running UDPGW instance, if any, by closing and
// unregistering its listener; the accept loop sees the closed listener and
// exits. Calling it while UDPGW is not running does nothing.
func stopUDPGW() {
	udpgwMu.Lock()
	defer udpgwMu.Unlock()
	if ln := udpgwLn; ln != nil {
		udpgwLn = nil
		_ = ln.Close() // best effort: a close error is not actionable here
	}
}

// udpgwRunning reports whether a UDPGW listener is currently registered.
func udpgwRunning() bool {
	udpgwMu.Lock()
	running := udpgwLn != nil
	udpgwMu.Unlock()
	return running
}
// startUDPGW starts the integrated UDP gateway described by cfg. A nil cfg
// is a no-op and returns nil.
//
// An empty cfg.Listen defaults to defaultUDPGWListen, and the effective
// address is written back into cfg.Listen. Unset or invalid settings fall
// back to built-in defaults (see newInternalUDPGWConfig).
//
// If the TCP listener cannot be opened the failure is logged and returned;
// the process is never terminated. On success the new listener replaces
// (and closes) any previously registered one, so stopUDPGW can shut it
// down, and clients are accepted in a background goroutine.
func startUDPGW(cfg *UDPGWConfig) error {
	if cfg == nil {
		return nil
	}
	if cfg.Listen == "" {
		cfg.Listen = defaultUDPGWListen
	}
	c := newInternalUDPGWConfig(cfg)

	ln, err := net.Listen("tcp", c.listen)
	if err != nil {
		log.Printf("udpgw: listen failed on %s: %v", c.listen, err)
		return fmt.Errorf("udpgw: listen failed on %s: %w", c.listen, err)
	}

	// Register as the active listener so stopUDPGW can close it.
	udpgwMu.Lock()
	if udpgwLn != nil {
		_ = udpgwLn.Close()
	}
	udpgwLn = ln
	udpgwMu.Unlock()

	if c.debug {
		log.Printf("udpgw: listening on %s", c.listen)
	}
	go udpgwAcceptLoop(ln, c)
	return nil
}

// newInternalUDPGWConfig converts cfg into the runtime configuration.
// cfg.Listen is copied verbatim (the caller defaults it). Non-positive
// numeric fields are replaced by their defaults. map_ttl and reap_every
// must be positive durations; reap_every drives time.NewTicker, which
// panics on a non-positive interval. idle_timeout may be zero or negative,
// which disables the idle deadline.
func newInternalUDPGWConfig(cfg *UDPGWConfig) *internalUDPGWConfig {
	return &internalUDPGWConfig{
		listen:         cfg.Listen,
		maxFrame:       udpgwPositiveOr(cfg.MaxFrame, 64*1024),
		debug:          cfg.Debug,
		hexdumpN:       udpgwPositiveOr(cfg.HexdumpN, 64),
		writeChan:      udpgwPositiveOr(cfg.WriteChan, 4096),
		udpBindIP:      cfg.UDPBindIP,
		udpRBuf:        udpgwPositiveOr(cfg.UDPRBuf, 8*1024*1024),
		udpWBuf:        udpgwPositiveOr(cfg.UDPWBuf, 8*1024*1024),
		mapTTL:         udpgwDurationOr("map_ttl", cfg.MapTTL, "90s", true),
		reapEvery:      udpgwDurationOr("reap_every", cfg.ReapEvery, "10s", true),
		idleTimeout:    udpgwDurationOr("idle_timeout", cfg.IdleTimeout, "2m", false),
		maxClientConns: udpgwPositiveOr(cfg.MaxClientConns, 10),
		maxMapEntries:  udpgwPositiveOr(cfg.MaxMapEntries, 32768),
	}
}

// udpgwPositiveOr returns v when it is positive and def otherwise.
func udpgwPositiveOr(v, def int) int {
	if v > 0 {
		return v
	}
	return def
}

// udpgwDurationOr parses the duration setting raw for config key key. An
// empty raw yields the default defText. An unparsable value, or a zero or
// negative one when requirePositive is set, is logged and replaced by the
// default. defText must be a valid time.ParseDuration string; it is kept
// as text so log messages show the default as written.
func udpgwDurationOr(key, raw, defText string, requirePositive bool) time.Duration {
	def, err := time.ParseDuration(defText)
	if err != nil {
		panic("udpgw: invalid built-in default duration " + defText) // programmer error
	}
	if raw == "" {
		return def
	}
	d, err := time.ParseDuration(raw)
	if err != nil {
		log.Printf("udpgw: invalid %s %q: %v; using default %s", key, raw, err, defText)
		return def
	}
	if requirePositive && d <= 0 {
		log.Printf("udpgw: invalid %s %q: must be positive; using default %s", key, raw, defText)
		return def
	}
	return d
}

// udpgwAcceptLoop accepts clients on ln until ln is closed, serving each in
// its own goroutine. Any other accept error is logged and retried with
// exponential backoff (5ms doubling up to 1s), so a persistent failure such
// as running out of file descriptors cannot turn into a busy loop.
func udpgwAcceptLoop(ln net.Listener, c *internalUDPGWConfig) {
	var backoff time.Duration
	for {
		conn, err := ln.Accept()
		if err != nil {
			if errors.Is(err, net.ErrClosed) {
				return
			}
			if backoff == 0 {
				backoff = 5 * time.Millisecond
			} else {
				backoff *= 2
			}
			if backoff > time.Second {
				backoff = time.Second
			}
			log.Printf("udpgw: accept error: %v; retrying in %v", err, backoff)
			time.Sleep(backoff)
			continue
		}
		backoff = 0
		go handleUDPGWClient(conn, c)
	}
}
// internalUDPGWConfig is the resolved form of UDPGWConfig that the gateway
// uses at runtime: defaults have been applied and the duration settings are
// held as time.Duration instead of strings. It is never serialized, so it
// carries no JSON tags.
type internalUDPGWConfig struct {
	// listen is the TCP address the gateway accepts clients on.
	listen string
	// maxFrame is the largest client frame payload accepted, in bytes.
	maxFrame int
	// debug enables verbose per-client and per-datagram logging.
	debug bool
	// hexdumpN is not read in this part of the file; presumably the number
	// of payload bytes to hexdump in debug output — TODO confirm.
	hexdumpN int
	// writeChan is the capacity of each client's outgoing frame queue.
	writeChan int
	// udpBindIP is the local IP for per-client UDP sockets ("" = unspecified).
	udpBindIP string
	// udpRBuf and udpWBuf are the requested UDP socket receive/send buffer sizes.
	udpRBuf, udpWBuf int
	// mapTTL is how long destination mappings and connID activity stay valid,
	// reapEvery is the background purge interval, and idleTimeout is the TCP
	// read deadline (zero or negative disables it).
	mapTTL, reapEvery, idleTimeout time.Duration
	// maxClientConns and maxMapEntries cap, per client, the tracked connIDs
	// and destination mappings; zero or negative disables the cap.
	maxClientConns, maxMapEntries int
}
// Per-client reply routing types. Each client keeps a
// map[udpDestKey]udpMapVal so that a datagram arriving from a destination
// can be framed back with the connID and x byte of the request that last
// targeted that destination.
type (
	// udpDestKey identifies a UDP destination by IPv4 address and port.
	// It is comparable and is used directly as a map key.
	udpDestKey struct {
		ip   [4]byte
		port uint16
	}

	// udpMapVal is the routing entry for one destination: the connID and x
	// byte to echo in reply frames, and the time after which the entry is
	// treated as expired.
	udpMapVal struct {
		connID uint16
		x      byte
		exp    time.Time
	}
)
// handleUDPGWClient serves one TCP client of the UDP gateway until the
// client disconnects, stays idle longer than c.idleTimeout (when positive),
// or an I/O error occurs on either side.
//
// Each client gets its own UDP socket, bound to c.udpBindIP when set.
// Every frame read from the TCP stream carries connID (2 bytes, big
// endian), a flag byte x, a destination IPv4 address, a destination port
// (big endian) and the datagram data. The data is sent to that destination,
// and the destination is remembered for c.mapTTL together with the frame's
// connID and x. A UDP datagram arriving from a remembered destination is
// framed with udpgwBuildFrame and queued for the client. Datagrams from
// unknown or expired destinations or from non-IPv4 sources are dropped, as
// are replies that arrive while the outgoing queue is full.
//
// Routing is keyed by destination only: when several connIDs send to the
// same ip:port, replies go to whichever connID sent to it most recently.
//
// Up to four goroutines share one context: this one (TCP reader / UDP
// sender), a TCP writer draining writeCh, a UDP reader, and a reaper that
// purges expired entries every c.reapEvery. Whichever side fails first
// cancels the context and closes conn, so the rest of the session unwinds
// with it. conn and the UDP socket are closed on return.
func handleUDPGWClient(conn net.Conn, c *internalUDPGWConfig) {
	defer conn.Close()
	remote := conn.RemoteAddr().String()
	if c.debug {
		log.Printf("udpgw: client connected: %s", remote)
	}
	// Lower latency for interactive applications by disabling Nagle.
	if tcp, ok := conn.(*net.TCPConn); ok {
		_ = tcp.SetNoDelay(true)
	}
	br := bufio.NewReaderSize(conn, 256*1024)
	ctx, cancel := context.WithCancel(context.Background())
	defer cancel()
	// teardown ends the session from a helper goroutine. Cancelling stops the
	// other goroutines, and closing conn unblocks the TCP read loop below,
	// which would otherwise keep forwarding traffic for a dead session.
	teardown := func() {
		cancel()
		_ = conn.Close()
	}

	// Bind a UDP socket for this client. Use c.udpBindIP if provided.
	var laddr *net.UDPAddr
	if c.udpBindIP != "" {
		ip := net.ParseIP(c.udpBindIP)
		if ip == nil {
			log.Printf("udpgw[%s]: invalid udp_bind IP %q", remote, c.udpBindIP)
			return
		}
		laddr = &net.UDPAddr{IP: ip, Port: 0}
	}
	udpConn, err := net.ListenUDP("udp", laddr)
	if err != nil {
		log.Printf("udpgw[%s]: UDP listen failed: %v", remote, err)
		return
	}
	defer udpConn.Close()
	// Best effort: buffer-size errors are ignored.
	_ = udpConn.SetReadBuffer(c.udpRBuf)
	_ = udpConn.SetWriteBuffer(c.udpWBuf)

	// TCP writer: drains writeCh to the client. done is closed when it exits.
	writeCh := make(chan []byte, c.writeChan)
	done := make(chan struct{})
	go func() {
		defer close(done)
		for {
			select {
			case <-ctx.Done():
				return
			case b := <-writeCh:
				if len(b) == 0 {
					continue
				}
				// If the client stops reading, don't block forever.
				_ = conn.SetWriteDeadline(time.Now().Add(30 * time.Second))
				if _, err := conn.Write(b); err != nil {
					log.Printf("udpgw[%s]: write error: %v", remote, err)
					teardown()
					return
				}
			}
		}
	}()

	// Destination -> connID+x mapping and active connID tracking per TCP client.
	var mu sync.Mutex
	destToConn := make(map[udpDestKey]udpMapVal)
	connIDLastSeen := make(map[uint16]time.Time)

	// Reaper: periodically purges expired mappings. It stops with ctx. A
	// non-positive interval disables it, because time.NewTicker panics on
	// one. Expiry is still checked on every reply lookup, and both maps are
	// still pruned by the main loop.
	if c.reapEvery > 0 {
		go func() {
			t := time.NewTicker(c.reapEvery)
			defer t.Stop()
			for {
				select {
				case <-ctx.Done():
					return
				case <-t.C:
					now := time.Now()
					mu.Lock()
					for k, v := range destToConn {
						if now.After(v.exp) {
							delete(destToConn, k)
						}
					}
					for id, lastSeen := range connIDLastSeen {
						if now.Sub(lastSeen) > c.mapTTL {
							delete(connIDLastSeen, id)
						}
					}
					mu.Unlock()
				}
			}
		}()
	}

	// UDP reader: frames replies from mapped destinations back to the client.
	go func() {
		buf := make([]byte, 65535)
		for {
			n, from, err := udpConn.ReadFromUDP(buf)
			if err != nil {
				// A closed socket or cancelled context is the normal shutdown
				// path. Any other failure would leave the client connected but
				// deaf to every reply, so end the session instead.
				if ctx.Err() == nil && !errors.Is(err, net.ErrClosed) {
					log.Printf("udpgw[%s]: UDP read error: %v", remote, err)
					teardown()
				}
				return
			}
			// Zero-length datagrams are valid UDP and are relayed like any other.
			ip4 := from.IP.To4()
			if ip4 == nil {
				// IPv6 replies are dropped because framing only supports IPv4.
				continue
			}
			k := udpDestKey{port: uint16(from.Port)}
			copy(k.ip[:], ip4)
			mu.Lock()
			v, ok := destToConn[k]
			mu.Unlock()
			if !ok || time.Now().After(v.exp) {
				if c.debug {
					log.Printf("udpgw[%s]: dropping %dB from %s (no mapping)", remote, n, from.String())
				}
				continue
			}
			// Avoid unbounded memory growth: if the client is slow and the
			// write queue is full, drop replies rather than allocating more.
			if len(writeCh) == cap(writeCh) {
				if c.debug {
					log.Printf("udpgw[%s]: drop UDP reply %dB from %s (write queue full)", remote, n, from.String())
				}
				continue
			}
			frame := udpgwBuildFrame(v.connID, v.x, k.ip, k.port, buf[:n])
			select {
			case <-ctx.Done():
				return
			case writeCh <- frame:
			default:
				// Race with another sender filling the channel.
			}
			if c.debug {
				log.Printf("udpgw[%s]: UDP<- %dB from %s -> connID=%d x=0x%02x", remote, n, from.String(), v.connID, v.x)
			}
		}
	}()

	// Main loop: read frames from TCP, update mapping and send UDP.
	for {
		// Close idle TCP clients to avoid leaking goroutines.
		if c.idleTimeout > 0 {
			_ = conn.SetReadDeadline(time.Now().Add(c.idleTimeout))
		}
		payload, err := udpgwReadPayload(br, c.maxFrame)
		if err != nil {
			var ne net.Error
			switch {
			case ctx.Err() != nil:
				// conn was closed by the writer or UDP reader, which already
				// logged the underlying failure.
			case errors.Is(err, io.EOF):
				if c.debug {
					log.Printf("udpgw: client disconnected: %s", remote)
				}
			case errors.As(err, &ne) && ne.Timeout():
				if c.debug {
					log.Printf("udpgw[%s]: idle timeout (%s); closing", remote, c.idleTimeout)
				}
			default:
				log.Printf("udpgw[%s]: read error: %v", remote, err)
			}
			teardown()
			<-done
			return
		}
		// payload: connID(2) + X(1) + dstIPv4(4) + dstPort(2) + data
		if len(payload) < 2+1+4+2 {
			if c.debug {
				log.Printf("udpgw[%s]: too short payload %dB", remote, len(payload))
			}
			continue
		}
		connID := binary.BigEndian.Uint16(payload[0:2])
		x := payload[2]
		var dstIP [4]byte
		copy(dstIP[:], payload[3:7])
		dstPort := binary.BigEndian.Uint16(payload[7:9])
		data := payload[9:]
		if c.debug {
			log.Printf("udpgw[%s]: RX connID=%d dst=%d.%d.%d.%d:%d x=0x%02x len=%d", remote, connID, dstIP[0], dstIP[1], dstIP[2], dstIP[3], dstPort, x, len(data))
		}
		// Enforce a per-client logical session cap using udpgw connID.
		k := udpDestKey{ip: dstIP, port: dstPort}
		now := time.Now()
		mu.Lock()
		// Step 1: evict TTL-expired connIDs first.
		for id, lastSeen := range connIDLastSeen {
			if now.Sub(lastSeen) > c.mapTTL {
				delete(connIDLastSeen, id)
			}
		}
		// Step 2: if the connID is new and we are still at or above the cap
		// after TTL eviction, evict the single oldest entry to make room.
		// This prevents a client rotating connIDs faster than mapTTL from
		// growing the map indefinitely.
		if _, alreadyKnown := connIDLastSeen[connID]; !alreadyKnown && c.maxClientConns > 0 && len(connIDLastSeen) >= c.maxClientConns {
			var oldestID uint16
			var oldestTime time.Time
			first := true
			for id, lastSeen := range connIDLastSeen {
				if first || lastSeen.Before(oldestTime) {
					oldestID = id
					oldestTime = lastSeen
					first = false
				}
			}
			delete(connIDLastSeen, oldestID)
			if c.debug {
				log.Printf("udpgw[%s]: connID cap reached (%d); evicted oldest connID=%d to admit connID=%d", remote, c.maxClientConns, oldestID, connID)
			}
		}
		connIDLastSeen[connID] = now
		// Update the mapping (bounded to protect memory).
		if c.maxMapEntries > 0 && len(destToConn) >= c.maxMapEntries {
			// First, purge expired entries.
			for dk, dv := range destToConn {
				if now.After(dv.exp) {
					delete(destToConn, dk)
				}
			}
			// If still too large, evict arbitrary entries until we're under the cap.
			if len(destToConn) >= c.maxMapEntries {
				evict := (len(destToConn) - c.maxMapEntries) + 1
				for dk := range destToConn {
					delete(destToConn, dk)
					evict--
					if evict <= 0 {
						break
					}
				}
				if c.debug {
					log.Printf("udpgw[%s]: mapping cap reached; evicted entries, size=%d", remote, len(destToConn))
				}
			}
		}
		destToConn[k] = udpMapVal{connID: connID, x: x, exp: now.Add(c.mapTTL)}
		mu.Unlock()
		// Send the UDP datagram.
		raddr := &net.UDPAddr{IP: net.IP(dstIP[:]), Port: int(dstPort)}
		if _, err := udpConn.WriteToUDP(data, raddr); err != nil {
			if c.debug {
				log.Printf("udpgw[%s]: UDP write failed to %s: %v", remote, raddr.String(), err)
			}
			continue
		}
		if c.debug {
			log.Printf("udpgw[%s]: UDP-> %dB to %s", remote, len(data), raddr.String())
		}
	}
}
// udpgwReadPayload reads one length-prefixed frame payload from r.
//
// The prefix is a little-endian uint16 giving the payload length; a length
// of zero or greater than maxLen is rejected with an error. If r is at end
// of stream before the prefix, io.EOF is returned unchanged so the caller
// can treat it as a clean disconnect. If the stream ends anywhere inside a
// frame (prefix or payload), io.ErrUnexpectedEOF is returned.
func udpgwReadPayload(r *bufio.Reader, maxLen int) ([]byte, error) {
	var hdr [2]byte
	if _, err := io.ReadFull(r, hdr[:]); err != nil {
		return nil, err
	}
	n := int(binary.LittleEndian.Uint16(hdr[:]))
	if n == 0 || n > maxLen {
		return nil, fmt.Errorf("udpgw: invalid frame length %d", n)
	}
	b := make([]byte, n)
	if _, err := io.ReadFull(r, b); err != nil {
		if errors.Is(err, io.EOF) {
			// The prefix was already consumed, so EOF here is a truncated
			// frame, not a clean disconnect.
			err = io.ErrUnexpectedEOF
		}
		return nil, err
	}
	return b, nil
}
// udpgwBuildFrame encodes a reply frame for the client. The frame starts
// with a little-endian uint16 length prefix, followed by the payload:
// connID (big endian), the x byte, the source IPv4 address, the source port
// (big endian) and data.
//
// The prefix can describe at most 0xFFFF payload bytes. If data is too
// large for that (more than 0xFFFF-9 bytes), nil is returned. Otherwise the
// truncated length would desynchronize the client's frame stream.
func udpgwBuildFrame(connID uint16, x byte, ip [4]byte, port uint16, data []byte) []byte {
	const headerLen = 2 + 1 + 4 + 2 // connID + x + IPv4 + port
	payloadLen := headerLen + len(data)
	if payloadLen > 0xFFFF {
		return nil
	}
	out := make([]byte, 2+payloadLen)
	binary.LittleEndian.PutUint16(out[0:2], uint16(payloadLen))
	binary.BigEndian.PutUint16(out[2:4], connID)
	out[4] = x
	copy(out[5:9], ip[:])
	binary.BigEndian.PutUint16(out[9:11], port)
	copy(out[11:], data)
	return out
}