package main // This file embeds the UDP gateway (udpgw) service into the main // application. The original udpgw program (public domain) accepts a // TCP connection, then forwards framed UDP datagrams to arbitrary // destinations and demultiplexes replies back to the originating // connection. Here we expose the same functionality via a // configuration key in config.json. When enabled, the server binds // to the provided Listen address (or the default 0.0.0.0:7400 if // unspecified) and handles each client in its own goroutine. The // gateway runs entirely in‑process and behaves like the standalone // badvpn-udpgw daemon. import ( "bufio" "context" "encoding/binary" "errors" "fmt" "io" "log" "net" "sync" "time" ) var ( udpgwMu sync.Mutex udpgwLn net.Listener ) // stopUDPGW closes the active UDPGW listener, causing the accept loop to exit. // It is a no-op if UDPGW is not running. func stopUDPGW() { udpgwMu.Lock() defer udpgwMu.Unlock() if udpgwLn != nil { _ = udpgwLn.Close() udpgwLn = nil } } func udpgwRunning() bool { udpgwMu.Lock() defer udpgwMu.Unlock() return udpgwLn != nil } // startUDPGW starts the integrated UDP gateway if cfg is non‑nil and // cfg.Listen is non‑empty. It applies default values to any zero // configuration fields and converts duration strings to time.Duration. // The server runs in a goroutine; any fatal errors are logged and // prevent the gateway from starting, but do not terminate the main // process. func startUDPGW(cfg *UDPGWConfig) error { if cfg == nil { return nil } // Default the listen address to the standalone default (0.0.0.0:7400) if // unspecified. This matches the behaviour of the original // badvpn-udpgw program, which listens on all interfaces by default. listenAddr := cfg.Listen if listenAddr == "" { listenAddr = defaultUDPGWListen } cfg.Listen = listenAddr // Apply defaults for numeric fields if zero. 
c := &internalUDPGWConfig{} c.listen = listenAddr if cfg.MaxFrame > 0 { c.maxFrame = cfg.MaxFrame } else { c.maxFrame = 64 * 1024 } c.debug = cfg.Debug if cfg.HexdumpN > 0 { c.hexdumpN = cfg.HexdumpN } else { c.hexdumpN = 64 } if cfg.WriteChan > 0 { c.writeChan = cfg.WriteChan } else { c.writeChan = 4096 } c.udpBindIP = cfg.UDPBindIP if cfg.UDPRBuf > 0 { c.udpRBuf = cfg.UDPRBuf } else { c.udpRBuf = 8 * 1024 * 1024 } if cfg.UDPWBuf > 0 { c.udpWBuf = cfg.UDPWBuf } else { c.udpWBuf = 8 * 1024 * 1024 } // Parse durations with fallback defaults. if cfg.MapTTL != "" { if d, err := time.ParseDuration(cfg.MapTTL); err == nil { c.mapTTL = d } else { log.Printf("udpgw: invalid map_ttl %q: %v; using default 90s", cfg.MapTTL, err) c.mapTTL = 90 * time.Second } } else { c.mapTTL = 90 * time.Second } if cfg.ReapEvery != "" { if d, err := time.ParseDuration(cfg.ReapEvery); err == nil { c.reapEvery = d } else { log.Printf("udpgw: invalid reap_every %q: %v; using default 10s", cfg.ReapEvery, err) c.reapEvery = 10 * time.Second } } else { c.reapEvery = 10 * time.Second } // Idle timeout for TCP clients. if cfg.IdleTimeout != "" { if d, err := time.ParseDuration(cfg.IdleTimeout); err == nil { c.idleTimeout = d } else { log.Printf("udpgw: invalid idle_timeout %q: %v; using default 2m", cfg.IdleTimeout, err) c.idleTimeout = 2 * time.Minute } } else { c.idleTimeout = 2 * time.Minute } // Per-client logical connID cap. if cfg.MaxClientConns > 0 { c.maxClientConns = cfg.MaxClientConns } else { c.maxClientConns = 10 } // Per-client destination mapping cap. if cfg.MaxMapEntries > 0 { c.maxMapEntries = cfg.MaxMapEntries } else { c.maxMapEntries = 32768 } // Start listening. ln, err := net.Listen("tcp", c.listen) if err != nil { log.Printf("udpgw: listen failed on %s: %v", c.listen, err) return fmt.Errorf("udpgw: listen failed on %s: %w", c.listen, err) } // Register as the active listener so stopUDPGW can close it. 
udpgwMu.Lock() if udpgwLn != nil { _ = udpgwLn.Close() } udpgwLn = ln udpgwMu.Unlock() if c.debug { log.Printf("udpgw: listening on %s", c.listen) } go func() { for { conn, err := ln.Accept() if err != nil { if errors.Is(err, net.ErrClosed) { return } log.Printf("udpgw: accept error: %v", err) continue } go handleUDPGWClient(conn, c) } }() return nil } // internalUDPGWConfig mirrors the exported UDPGWConfig but with // time.Duration fields for TTL and reaper intervals. It does not // embed JSON tags because it is not exposed to the user. type internalUDPGWConfig struct { listen string maxFrame int debug bool hexdumpN int writeChan int udpBindIP string udpRBuf int udpWBuf int mapTTL time.Duration reapEvery time.Duration idleTimeout time.Duration maxClientConns int maxMapEntries int } // udpDestKey identifies a destination IPv4:port for the UDP gateway. A // mapping of this key to a connID+x byte is kept per client so // replies can be routed correctly. Each client maintains its own map // of udpDestKey->udpMapVal entries. type udpDestKey struct { ip [4]byte port uint16 } // udpMapVal stores the mapping from udpDestKey to connID, the x byte and the // expiration time. type udpMapVal struct { connID uint16 x byte exp time.Time } // handleUDPGWClient manages a single TCP client connection to the UDP // gateway. It creates a per‑client UDP socket, reads frames from the // TCP connection, sends UDP datagrams to the requested destination, // maintains a mapping of udpDestKey->udpMapVal for routing replies, and // writes reply frames back over the TCP connection. When the client // disconnects or an error occurs, all goroutines are terminated and the // UDP socket is closed. func handleUDPGWClient(conn net.Conn, c *internalUDPGWConfig) { defer conn.Close() remote := conn.RemoteAddr().String() if c.debug { log.Printf("udpgw: client connected: %s", remote) } // Lower latency for interactive applications by disabling Nagle. 
if tcp, ok := conn.(*net.TCPConn); ok { _ = tcp.SetNoDelay(true) } br := bufio.NewReaderSize(conn, 256*1024) ctx, cancel := context.WithCancel(context.Background()) defer cancel() // Bind a UDP socket for this client. Use cfg.udpBindIP if provided. var laddr *net.UDPAddr if c.udpBindIP != "" { ip := net.ParseIP(c.udpBindIP) if ip != nil { laddr = &net.UDPAddr{IP: ip, Port: 0} } else { log.Printf("udpgw[%s]: invalid udp_bind IP %q", remote, c.udpBindIP) return } } udpConn, err := net.ListenUDP("udp", laddr) if err != nil { log.Printf("udpgw[%s]: UDP listen failed: %v", remote, err) return } defer udpConn.Close() _ = udpConn.SetReadBuffer(c.udpRBuf) _ = udpConn.SetWriteBuffer(c.udpWBuf) // Channel to queue outgoing frames back to the client. writeCh := make(chan []byte, c.writeChan) done := make(chan struct{}) go func() { defer close(done) for { select { case <-ctx.Done(): return case b := <-writeCh: if len(b) == 0 { continue } // If the client stops reading, don't block forever. _ = conn.SetWriteDeadline(time.Now().Add(30 * time.Second)) _, err := conn.Write(b) if err != nil { cancel() return } } } }() // Destination -> connID+x mapping and active connID tracking per TCP client. var mu sync.Mutex destToConn := make(map[udpDestKey]udpMapVal) connIDLastSeen := make(map[uint16]time.Time) // Start a reaper goroutine to purge expired mappings. // Uses ctx instead of a separate stopReaper channel so that cancellation // is always uniform: any exit path that calls cancel() (error, disconnect, // idle timeout, panic-deferred cancel) also stops the reaper. go func() { t := time.NewTicker(c.reapEvery) defer t.Stop() for { select { case <-ctx.Done(): return case <-t.C: now := time.Now() mu.Lock() for k, v := range destToConn { if now.After(v.exp) { delete(destToConn, k) } } for id, lastSeen := range connIDLastSeen { if now.Sub(lastSeen) > c.mapTTL { delete(connIDLastSeen, id) } } mu.Unlock() } } }() // Goroutine to read UDP replies and write framed responses back. 
go func() { buf := make([]byte, 65535) for { n, from, err := udpConn.ReadFromUDP(buf) if err != nil { return } if n <= 0 { continue } ip4 := from.IP.To4() if ip4 == nil { // IPv6 replies are dropped because framing only supports IPv4. continue } k := udpDestKey{port: uint16(from.Port)} copy(k.ip[:], ip4) mu.Lock() v, ok := destToConn[k] mu.Unlock() if !ok || time.Now().After(v.exp) { if c.debug { log.Printf("udpgw[%s]: dropping %dB from %s (no mapping)", remote, n, from.String()) } continue } // Avoid unbounded memory growth: if the client is slow and the // write queue is full, drop replies rather than allocating more. if len(writeCh) == cap(writeCh) { if c.debug { log.Printf("udpgw[%s]: drop UDP reply %dB from %s (write queue full)", remote, n, from.String()) } continue } frame := udpgwBuildFrame(v.connID, v.x, k.ip, k.port, buf[:n]) select { case <-ctx.Done(): return case writeCh <- frame: default: // Race with another sender filling the channel. } if c.debug { log.Printf("udpgw[%s]: UDP<- %dB from %s -> connID=%d x=0x%02x", remote, n, from.String(), v.connID, v.x) } } }() // Main loop: read frames from TCP, update mapping and send UDP. for { // Close idle TCP clients to avoid leaking goroutines. 
if c.idleTimeout > 0 { _ = conn.SetReadDeadline(time.Now().Add(c.idleTimeout)) } payload, err := udpgwReadPayload(br, c.maxFrame) if err != nil { if ne, ok := err.(net.Error); ok && ne.Timeout() { if c.debug { log.Printf("udpgw[%s]: idle timeout (%s); closing", remote, c.idleTimeout) } } if err == io.EOF { if c.debug { log.Printf("udpgw: client disconnected: %s", remote) } } else { log.Printf("udpgw[%s]: read error: %v", remote, err) } cancel() _ = conn.Close() <-done return } // payload: connID(2) + X(1) + dstIPv4(4) + dstPort(2) + data if len(payload) < 2+1+4+2 { if c.debug { log.Printf("udpgw[%s]: too short payload %dB", remote, len(payload)) } continue } connID := binary.BigEndian.Uint16(payload[0:2]) x := payload[2] var dstIP [4]byte copy(dstIP[:], payload[3:7]) dstPort := binary.BigEndian.Uint16(payload[7:9]) data := payload[9:] if c.debug { log.Printf("udpgw[%s]: RX connID=%d dst=%d.%d.%d.%d:%d x=0x%02x len=%d", remote, connID, dstIP[0], dstIP[1], dstIP[2], dstIP[3], dstPort, x, len(data)) } // Enforce a per-client logical session cap using udpgw connID. k := udpDestKey{ip: dstIP, port: dstPort} now := time.Now() mu.Lock() // Step 1: evict TTL-expired connIDs first. for id, lastSeen := range connIDLastSeen { if now.Sub(lastSeen) > c.mapTTL { delete(connIDLastSeen, id) } } // Step 2: if the connID is new and we are still at or above the cap // after TTL eviction, evict the single oldest entry to make room. // This prevents a client rotating connIDs faster than mapTTL from // bypassing maxClientConns and growing the map indefinitely. if _, alreadyKnown := connIDLastSeen[connID]; !alreadyKnown && c.maxClientConns > 0 && len(connIDLastSeen) >= c.maxClientConns { // Find and remove the oldest connID seen. 
var oldestID uint16 var oldestTime time.Time first := true for id, lastSeen := range connIDLastSeen { if first || lastSeen.Before(oldestTime) { oldestID = id oldestTime = lastSeen first = false } } delete(connIDLastSeen, oldestID) if c.debug { log.Printf("udpgw[%s]: connID cap reached (%d); evicted oldest connID=%d to admit connID=%d", remote, c.maxClientConns, oldestID, connID) } } connIDLastSeen[connID] = now // Update the mapping (bounded to protect memory). if c.maxMapEntries > 0 && len(destToConn) >= c.maxMapEntries { // First, purge expired entries. for dk, dv := range destToConn { if now.After(dv.exp) { delete(destToConn, dk) } } // If still too large, evict arbitrary entries until we're under the cap. if len(destToConn) >= c.maxMapEntries { evict := (len(destToConn) - c.maxMapEntries) + 1 for dk := range destToConn { delete(destToConn, dk) evict-- if evict <= 0 { break } } if c.debug { log.Printf("udpgw[%s]: mapping cap reached; evicted entries, size=%d", remote, len(destToConn)) } } } destToConn[k] = udpMapVal{connID: connID, x: x, exp: now.Add(c.mapTTL)} mu.Unlock() // Send the UDP datagram. raddr := &net.UDPAddr{IP: net.IP(dstIP[:]), Port: int(dstPort)} _, err = udpConn.WriteToUDP(data, raddr) if err != nil { if c.debug { log.Printf("udpgw[%s]: UDP write failed to %s: %v", remote, raddr.String(), err) } continue } if c.debug { log.Printf("udpgw[%s]: UDP-> %dB to %s", remote, len(data), raddr.String()) } } } // udpgwReadPayload reads a length‑prefixed payload from r. The length // prefix is a little‑endian uint16. Payloads larger than max cause an // error. 
// udpgwReadPayload reads one length-prefixed frame from r and returns its
// payload. The prefix is a little-endian uint16 byte count. A zero length,
// or a length greater than maxLen, is rejected with an error. I/O errors
// from r are returned unchanged: io.EOF when r is exhausted before any
// prefix byte, io.ErrUnexpectedEOF when a frame is truncated.
func udpgwReadPayload(r *bufio.Reader, maxLen int) ([]byte, error) {
	var lenBuf [2]byte
	if _, err := io.ReadFull(r, lenBuf[:]); err != nil {
		return nil, err
	}
	n := int(binary.LittleEndian.Uint16(lenBuf[:]))
	if n == 0 || n > maxLen {
		return nil, fmt.Errorf("udpgw: invalid frame length %d", n)
	}
	b := make([]byte, n)
	if _, err := io.ReadFull(r, b); err != nil {
		return nil, err
	}
	return b, nil
}

// udpgwBuildFrame constructs a reply frame for the client. The frame consists
// of a little-endian uint16 length prefix followed by the payload: connID
// (big endian), the x byte, the source IPv4 address, the source port (big
// endian) and the data.
//
// It returns nil when the payload would not fit in the uint16 prefix (data
// longer than 65535-9 bytes). Encoding such a frame would wrap the length
// and desynchronise the client's frame parser.
func udpgwBuildFrame(connID uint16, x byte, ip [4]byte, port uint16, data []byte) []byte {
	const headerLen = 2 + 1 + 4 + 2 // connID + x + IPv4 + port
	payloadLen := headerLen + len(data)
	if payloadLen > 0xFFFF {
		return nil
	}
	out := make([]byte, 2+payloadLen)
	binary.LittleEndian.PutUint16(out[0:2], uint16(payloadLen))
	binary.BigEndian.PutUint16(out[2:4], connID)
	out[4] = x
	copy(out[5:9], ip[:])
	binary.BigEndian.PutUint16(out[9:11], port)
	copy(out[11:], data)
	return out
}