Launch
This commit is contained in:
490
udpgw_integration.go
Normal file
490
udpgw_integration.go
Normal file
@@ -0,0 +1,490 @@
|
||||
package main
|
||||
|
||||
// This file embeds the UDP gateway (udpgw) service into the main
|
||||
// application. The original udpgw program (public domain) accepts a
|
||||
// TCP connection, then forwards framed UDP datagrams to arbitrary
|
||||
// destinations and demultiplexes replies back to the originating
|
||||
// connection. Here we expose the same functionality via a
|
||||
// configuration key in config.json. When enabled, the server binds
|
||||
// to the provided Listen address (or the default 0.0.0.0:7400 if
|
||||
// unspecified) and handles each client in its own goroutine. The
|
||||
// gateway runs entirely in‑process and behaves like the standalone
|
||||
// badvpn-udpgw daemon.
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"context"
|
||||
"encoding/binary"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"log"
|
||||
"net"
|
||||
"sync"
|
||||
"time"
|
||||
)
|
||||
|
||||
var (
|
||||
udpgwMu sync.Mutex
|
||||
udpgwLn net.Listener
|
||||
)
|
||||
|
||||
// stopUDPGW closes the active UDPGW listener, causing the accept loop to exit.
|
||||
// It is a no-op if UDPGW is not running.
|
||||
func stopUDPGW() {
|
||||
udpgwMu.Lock()
|
||||
defer udpgwMu.Unlock()
|
||||
if udpgwLn != nil {
|
||||
_ = udpgwLn.Close()
|
||||
udpgwLn = nil
|
||||
}
|
||||
}
|
||||
|
||||
// startUDPGW starts the integrated UDP gateway if cfg is non‑nil and
|
||||
// cfg.Listen is non‑empty. It applies default values to any zero
|
||||
// configuration fields and converts duration strings to time.Duration.
|
||||
// The server runs in a goroutine; any fatal errors are logged and
|
||||
// prevent the gateway from starting, but do not terminate the main
|
||||
// process.
|
||||
func startUDPGW(cfg *UDPGWConfig) {
|
||||
if cfg == nil {
|
||||
return
|
||||
}
|
||||
// Default the listen address to the standalone default (0.0.0.0:7400) if
|
||||
// unspecified. This matches the behaviour of the original
|
||||
// badvpn-udpgw program, which listens on all interfaces by default.
|
||||
listenAddr := cfg.Listen
|
||||
if listenAddr == "" {
|
||||
listenAddr = "0.0.0.0:7400"
|
||||
}
|
||||
// Apply defaults for numeric fields if zero.
|
||||
c := &internalUDPGWConfig{}
|
||||
c.listen = listenAddr
|
||||
if cfg.MaxFrame > 0 {
|
||||
c.maxFrame = cfg.MaxFrame
|
||||
} else {
|
||||
c.maxFrame = 64 * 1024
|
||||
}
|
||||
c.debug = cfg.Debug
|
||||
if cfg.HexdumpN > 0 {
|
||||
c.hexdumpN = cfg.HexdumpN
|
||||
} else {
|
||||
c.hexdumpN = 64
|
||||
}
|
||||
if cfg.WriteChan > 0 {
|
||||
c.writeChan = cfg.WriteChan
|
||||
} else {
|
||||
c.writeChan = 4096
|
||||
}
|
||||
c.udpBindIP = cfg.UDPBindIP
|
||||
if cfg.UDPRBuf > 0 {
|
||||
c.udpRBuf = cfg.UDPRBuf
|
||||
} else {
|
||||
c.udpRBuf = 8 * 1024 * 1024
|
||||
}
|
||||
if cfg.UDPWBuf > 0 {
|
||||
c.udpWBuf = cfg.UDPWBuf
|
||||
} else {
|
||||
c.udpWBuf = 8 * 1024 * 1024
|
||||
}
|
||||
// Parse durations with fallback defaults.
|
||||
if cfg.MapTTL != "" {
|
||||
if d, err := time.ParseDuration(cfg.MapTTL); err == nil {
|
||||
c.mapTTL = d
|
||||
} else {
|
||||
log.Printf("udpgw: invalid map_ttl %q: %v; using default 90s", cfg.MapTTL, err)
|
||||
c.mapTTL = 90 * time.Second
|
||||
}
|
||||
} else {
|
||||
c.mapTTL = 90 * time.Second
|
||||
}
|
||||
if cfg.ReapEvery != "" {
|
||||
if d, err := time.ParseDuration(cfg.ReapEvery); err == nil {
|
||||
c.reapEvery = d
|
||||
} else {
|
||||
log.Printf("udpgw: invalid reap_every %q: %v; using default 10s", cfg.ReapEvery, err)
|
||||
c.reapEvery = 10 * time.Second
|
||||
}
|
||||
} else {
|
||||
c.reapEvery = 10 * time.Second
|
||||
}
|
||||
// Idle timeout for TCP clients.
|
||||
if cfg.IdleTimeout != "" {
|
||||
if d, err := time.ParseDuration(cfg.IdleTimeout); err == nil {
|
||||
c.idleTimeout = d
|
||||
} else {
|
||||
log.Printf("udpgw: invalid idle_timeout %q: %v; using default 2m", cfg.IdleTimeout, err)
|
||||
c.idleTimeout = 2 * time.Minute
|
||||
}
|
||||
} else {
|
||||
c.idleTimeout = 2 * time.Minute
|
||||
}
|
||||
// Per-client logical connID cap.
|
||||
if cfg.MaxClientConns > 0 {
|
||||
c.maxClientConns = cfg.MaxClientConns
|
||||
} else {
|
||||
c.maxClientConns = 10
|
||||
}
|
||||
// Per-client destination mapping cap.
|
||||
if cfg.MaxMapEntries > 0 {
|
||||
c.maxMapEntries = cfg.MaxMapEntries
|
||||
} else {
|
||||
c.maxMapEntries = 32768
|
||||
}
|
||||
// Start listening.
|
||||
ln, err := net.Listen("tcp", c.listen)
|
||||
if err != nil {
|
||||
log.Printf("udpgw: listen failed on %s: %v", c.listen, err)
|
||||
return
|
||||
}
|
||||
|
||||
// Register as the active listener so stopUDPGW can close it.
|
||||
udpgwMu.Lock()
|
||||
if udpgwLn != nil {
|
||||
_ = udpgwLn.Close()
|
||||
}
|
||||
udpgwLn = ln
|
||||
udpgwMu.Unlock()
|
||||
|
||||
if c.debug {
|
||||
log.Printf("udpgw: listening on %s", c.listen)
|
||||
}
|
||||
go func() {
|
||||
for {
|
||||
conn, err := ln.Accept()
|
||||
if err != nil {
|
||||
if errors.Is(err, net.ErrClosed) {
|
||||
return
|
||||
}
|
||||
log.Printf("udpgw: accept error: %v", err)
|
||||
continue
|
||||
}
|
||||
go handleUDPGWClient(conn, c)
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
// internalUDPGWConfig mirrors the exported UDPGWConfig but with
|
||||
// time.Duration fields for TTL and reaper intervals. It does not
|
||||
// embed JSON tags because it is not exposed to the user.
|
||||
type internalUDPGWConfig struct {
|
||||
listen string
|
||||
maxFrame int
|
||||
debug bool
|
||||
hexdumpN int
|
||||
writeChan int
|
||||
udpBindIP string
|
||||
udpRBuf int
|
||||
udpWBuf int
|
||||
mapTTL time.Duration
|
||||
reapEvery time.Duration
|
||||
idleTimeout time.Duration
|
||||
maxClientConns int
|
||||
maxMapEntries int
|
||||
}
|
||||
|
||||
// udpDestKey identifies a destination IPv4:port for the UDP gateway. A
|
||||
// mapping of this key to a connID+x byte is kept per client so
|
||||
// replies can be routed correctly. Each client maintains its own map
|
||||
// of udpDestKey->udpMapVal entries.
|
||||
type udpDestKey struct {
|
||||
ip [4]byte
|
||||
port uint16
|
||||
}
|
||||
|
||||
// udpMapVal stores the mapping from udpDestKey to connID, the x byte and the
|
||||
// expiration time.
|
||||
type udpMapVal struct {
|
||||
connID uint16
|
||||
x byte
|
||||
exp time.Time
|
||||
}
|
||||
|
||||
// handleUDPGWClient manages a single TCP client connection to the UDP
|
||||
// gateway. It creates a per‑client UDP socket, reads frames from the
|
||||
// TCP connection, sends UDP datagrams to the requested destination,
|
||||
// maintains a mapping of udpDestKey->udpMapVal for routing replies, and
|
||||
// writes reply frames back over the TCP connection. When the client
|
||||
// disconnects or an error occurs, all goroutines are terminated and the
|
||||
// UDP socket is closed.
|
||||
func handleUDPGWClient(conn net.Conn, c *internalUDPGWConfig) {
|
||||
defer conn.Close()
|
||||
remote := conn.RemoteAddr().String()
|
||||
if c.debug {
|
||||
log.Printf("udpgw: client connected: %s", remote)
|
||||
}
|
||||
// Lower latency for interactive applications by disabling Nagle.
|
||||
if tcp, ok := conn.(*net.TCPConn); ok {
|
||||
_ = tcp.SetNoDelay(true)
|
||||
}
|
||||
br := bufio.NewReaderSize(conn, 256*1024)
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
// Bind a UDP socket for this client. Use cfg.udpBindIP if provided.
|
||||
var laddr *net.UDPAddr
|
||||
if c.udpBindIP != "" {
|
||||
ip := net.ParseIP(c.udpBindIP)
|
||||
if ip != nil {
|
||||
laddr = &net.UDPAddr{IP: ip, Port: 0}
|
||||
} else {
|
||||
log.Printf("udpgw[%s]: invalid udp_bind IP %q", remote, c.udpBindIP)
|
||||
return
|
||||
}
|
||||
}
|
||||
udpConn, err := net.ListenUDP("udp", laddr)
|
||||
if err != nil {
|
||||
log.Printf("udpgw[%s]: UDP listen failed: %v", remote, err)
|
||||
return
|
||||
}
|
||||
defer udpConn.Close()
|
||||
_ = udpConn.SetReadBuffer(c.udpRBuf)
|
||||
_ = udpConn.SetWriteBuffer(c.udpWBuf)
|
||||
// Channel to queue outgoing frames back to the client.
|
||||
writeCh := make(chan []byte, c.writeChan)
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
defer close(done)
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return
|
||||
case b := <-writeCh:
|
||||
if len(b) == 0 {
|
||||
continue
|
||||
}
|
||||
// If the client stops reading, don't block forever.
|
||||
_ = conn.SetWriteDeadline(time.Now().Add(30 * time.Second))
|
||||
_, err := conn.Write(b)
|
||||
if err != nil {
|
||||
cancel()
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
}()
|
||||
// Destination -> connID+x mapping and active connID tracking per TCP client.
|
||||
var mu sync.Mutex
|
||||
destToConn := make(map[udpDestKey]udpMapVal)
|
||||
connIDLastSeen := make(map[uint16]time.Time)
|
||||
// Start a reaper goroutine to purge expired mappings.
|
||||
// Uses ctx instead of a separate stopReaper channel so that cancellation
|
||||
// is always uniform: any exit path that calls cancel() (error, disconnect,
|
||||
// idle timeout, panic-deferred cancel) also stops the reaper.
|
||||
go func() {
|
||||
t := time.NewTicker(c.reapEvery)
|
||||
defer t.Stop()
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return
|
||||
case <-t.C:
|
||||
now := time.Now()
|
||||
mu.Lock()
|
||||
for k, v := range destToConn {
|
||||
if now.After(v.exp) {
|
||||
delete(destToConn, k)
|
||||
}
|
||||
}
|
||||
for id, lastSeen := range connIDLastSeen {
|
||||
if now.Sub(lastSeen) > c.mapTTL {
|
||||
delete(connIDLastSeen, id)
|
||||
}
|
||||
}
|
||||
mu.Unlock()
|
||||
}
|
||||
}
|
||||
}()
|
||||
// Goroutine to read UDP replies and write framed responses back.
|
||||
go func() {
|
||||
buf := make([]byte, 65535)
|
||||
for {
|
||||
n, from, err := udpConn.ReadFromUDP(buf)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
if n <= 0 {
|
||||
continue
|
||||
}
|
||||
ip4 := from.IP.To4()
|
||||
if ip4 == nil {
|
||||
// IPv6 replies are dropped because framing only supports IPv4.
|
||||
continue
|
||||
}
|
||||
k := udpDestKey{port: uint16(from.Port)}
|
||||
copy(k.ip[:], ip4)
|
||||
mu.Lock()
|
||||
v, ok := destToConn[k]
|
||||
mu.Unlock()
|
||||
if !ok || time.Now().After(v.exp) {
|
||||
if c.debug {
|
||||
log.Printf("udpgw[%s]: dropping %dB from %s (no mapping)", remote, n, from.String())
|
||||
}
|
||||
continue
|
||||
}
|
||||
// Avoid unbounded memory growth: if the client is slow and the
|
||||
// write queue is full, drop replies rather than allocating more.
|
||||
if len(writeCh) == cap(writeCh) {
|
||||
if c.debug {
|
||||
log.Printf("udpgw[%s]: drop UDP reply %dB from %s (write queue full)", remote, n, from.String())
|
||||
}
|
||||
continue
|
||||
}
|
||||
frame := udpgwBuildFrame(v.connID, v.x, k.ip, k.port, buf[:n])
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return
|
||||
case writeCh <- frame:
|
||||
default:
|
||||
// Race with another sender filling the channel.
|
||||
}
|
||||
if c.debug {
|
||||
log.Printf("udpgw[%s]: UDP<- %dB from %s -> connID=%d x=0x%02x", remote, n, from.String(), v.connID, v.x)
|
||||
}
|
||||
}
|
||||
}()
|
||||
// Main loop: read frames from TCP, update mapping and send UDP.
|
||||
for {
|
||||
// Close idle TCP clients to avoid leaking goroutines.
|
||||
if c.idleTimeout > 0 {
|
||||
_ = conn.SetReadDeadline(time.Now().Add(c.idleTimeout))
|
||||
}
|
||||
payload, err := udpgwReadPayload(br, c.maxFrame)
|
||||
if err != nil {
|
||||
if ne, ok := err.(net.Error); ok && ne.Timeout() {
|
||||
if c.debug {
|
||||
log.Printf("udpgw[%s]: idle timeout (%s); closing", remote, c.idleTimeout)
|
||||
}
|
||||
}
|
||||
if err == io.EOF {
|
||||
if c.debug {
|
||||
log.Printf("udpgw: client disconnected: %s", remote)
|
||||
}
|
||||
} else {
|
||||
log.Printf("udpgw[%s]: read error: %v", remote, err)
|
||||
}
|
||||
cancel()
|
||||
_ = conn.Close()
|
||||
<-done
|
||||
return
|
||||
}
|
||||
// payload: connID(2) + X(1) + dstIPv4(4) + dstPort(2) + data
|
||||
if len(payload) < 2+1+4+2 {
|
||||
if c.debug {
|
||||
log.Printf("udpgw[%s]: too short payload %dB", remote, len(payload))
|
||||
}
|
||||
continue
|
||||
}
|
||||
connID := binary.BigEndian.Uint16(payload[0:2])
|
||||
x := payload[2]
|
||||
var dstIP [4]byte
|
||||
copy(dstIP[:], payload[3:7])
|
||||
dstPort := binary.BigEndian.Uint16(payload[7:9])
|
||||
data := payload[9:]
|
||||
if c.debug {
|
||||
log.Printf("udpgw[%s]: RX connID=%d dst=%d.%d.%d.%d:%d x=0x%02x len=%d", remote, connID, dstIP[0], dstIP[1], dstIP[2], dstIP[3], dstPort, x, len(data))
|
||||
}
|
||||
// Enforce a per-client logical session cap using udpgw connID.
|
||||
k := udpDestKey{ip: dstIP, port: dstPort}
|
||||
now := time.Now()
|
||||
mu.Lock()
|
||||
// Step 1: evict TTL-expired connIDs first.
|
||||
for id, lastSeen := range connIDLastSeen {
|
||||
if now.Sub(lastSeen) > c.mapTTL {
|
||||
delete(connIDLastSeen, id)
|
||||
}
|
||||
}
|
||||
// Step 2: if the connID is new and we are still at or above the cap
|
||||
// after TTL eviction, evict the single oldest entry to make room.
|
||||
// This prevents a client rotating connIDs faster than mapTTL from
|
||||
// bypassing maxClientConns and growing the map indefinitely.
|
||||
if _, alreadyKnown := connIDLastSeen[connID]; !alreadyKnown && c.maxClientConns > 0 && len(connIDLastSeen) >= c.maxClientConns {
|
||||
// Find and remove the oldest connID seen.
|
||||
var oldestID uint16
|
||||
var oldestTime time.Time
|
||||
first := true
|
||||
for id, lastSeen := range connIDLastSeen {
|
||||
if first || lastSeen.Before(oldestTime) {
|
||||
oldestID = id
|
||||
oldestTime = lastSeen
|
||||
first = false
|
||||
}
|
||||
}
|
||||
delete(connIDLastSeen, oldestID)
|
||||
if c.debug {
|
||||
log.Printf("udpgw[%s]: connID cap reached (%d); evicted oldest connID=%d to admit connID=%d", remote, c.maxClientConns, oldestID, connID)
|
||||
}
|
||||
}
|
||||
connIDLastSeen[connID] = now
|
||||
// Update the mapping (bounded to protect memory).
|
||||
if c.maxMapEntries > 0 && len(destToConn) >= c.maxMapEntries {
|
||||
// First, purge expired entries.
|
||||
for dk, dv := range destToConn {
|
||||
if now.After(dv.exp) {
|
||||
delete(destToConn, dk)
|
||||
}
|
||||
}
|
||||
// If still too large, evict arbitrary entries until we're under the cap.
|
||||
if len(destToConn) >= c.maxMapEntries {
|
||||
evict := (len(destToConn) - c.maxMapEntries) + 1
|
||||
for dk := range destToConn {
|
||||
delete(destToConn, dk)
|
||||
evict--
|
||||
if evict <= 0 {
|
||||
break
|
||||
}
|
||||
}
|
||||
if c.debug {
|
||||
log.Printf("udpgw[%s]: mapping cap reached; evicted entries, size=%d", remote, len(destToConn))
|
||||
}
|
||||
}
|
||||
}
|
||||
destToConn[k] = udpMapVal{connID: connID, x: x, exp: now.Add(c.mapTTL)}
|
||||
mu.Unlock()
|
||||
// Send the UDP datagram.
|
||||
raddr := &net.UDPAddr{IP: net.IP(dstIP[:]), Port: int(dstPort)}
|
||||
_, err = udpConn.WriteToUDP(data, raddr)
|
||||
if err != nil {
|
||||
if c.debug {
|
||||
log.Printf("udpgw[%s]: UDP write failed to %s: %v", remote, raddr.String(), err)
|
||||
}
|
||||
continue
|
||||
}
|
||||
if c.debug {
|
||||
log.Printf("udpgw[%s]: UDP-> %dB to %s", remote, len(data), raddr.String())
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// udpgwReadPayload reads a length‑prefixed payload from r. The length
|
||||
// prefix is a little‑endian uint16. Payloads larger than max cause an
|
||||
// error.
|
||||
func udpgwReadPayload(r *bufio.Reader, max int) ([]byte, error) {
|
||||
var lenBuf [2]byte
|
||||
if _, err := io.ReadFull(r, lenBuf[:]); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
n := int(binary.LittleEndian.Uint16(lenBuf[:]))
|
||||
if n <= 0 || n > max {
|
||||
return nil, fmt.Errorf("udpgw: invalid frame length %d", n)
|
||||
}
|
||||
b := make([]byte, n)
|
||||
if _, err := io.ReadFull(r, b); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return b, nil
|
||||
}
|
||||
|
||||
// udpgwBuildFrame constructs a reply frame for the client. The frame
|
||||
// consists of a little‑endian length prefix, then connID (big endian),
|
||||
// x byte, source IPv4, source port, and the data.
|
||||
func udpgwBuildFrame(connID uint16, x byte, ip [4]byte, port uint16, data []byte) []byte {
|
||||
payloadLen := 2 + 1 + 4 + 2 + len(data)
|
||||
out := make([]byte, 2+payloadLen)
|
||||
binary.LittleEndian.PutUint16(out[0:2], uint16(payloadLen))
|
||||
binary.BigEndian.PutUint16(out[2:4], connID)
|
||||
out[4] = x
|
||||
copy(out[5:9], ip[:])
|
||||
binary.BigEndian.PutUint16(out[9:11], port)
|
||||
copy(out[11:], data)
|
||||
return out
|
||||
}
|
||||
Reference in New Issue
Block a user