Launch
This commit is contained in:
353
internal/engine/udpgw_client.go
Normal file
353
internal/engine/udpgw_client.go
Normal file
@@ -0,0 +1,353 @@
|
||||
package engine
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"encoding/binary"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
"strconv"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
)
|
||||
|
||||
const (
|
||||
udpgwMaxFrame = 64 * 1024
|
||||
|
||||
udpgwProtocolBadVPN = "badvpn"
|
||||
udpgwProtocolLegacy = "legacy"
|
||||
|
||||
udpgwFlagKeepAlive = 1 << 0
|
||||
udpgwFlagRebind = 1 << 1
|
||||
udpgwFlagDNS = 1 << 2
|
||||
udpgwFlagIPv6 = 1 << 3
|
||||
)
|
||||
|
||||
type udpgwSession struct {
|
||||
mu sync.Mutex
|
||||
nextID uint16
|
||||
pairID map[string]uint16
|
||||
idClient map[uint16]*net.UDPAddr
|
||||
}
|
||||
|
||||
func (s *SocksServer) handleUDPGWAssociate(c net.Conn) {
|
||||
udp, err := net.ListenUDP("udp", &net.UDPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 0})
|
||||
if err != nil {
|
||||
_ = writeReply(c, socksRepFail)
|
||||
s.Logger.Add("warn", "socks UDP associate failed: %v", err)
|
||||
return
|
||||
}
|
||||
defer udp.Close()
|
||||
|
||||
gwAddr := net.JoinHostPort(s.UDPGW.Host, strconv.Itoa(s.UDPGW.Port))
|
||||
gwConn, err := dialSSHDirectTCPAddr(s.SSH, gwAddr, c.RemoteAddr(), s.Logger)
|
||||
if err != nil {
|
||||
_ = writeReply(c, socksRepFail)
|
||||
s.Logger.Add("warn", "udpgw dial through SSH failed %s: %v", gwAddr, err)
|
||||
return
|
||||
}
|
||||
defer gwConn.Close()
|
||||
|
||||
if err := writeReplyWithAddr(c, socksRepOK, udp.LocalAddr().(*net.UDPAddr)); err != nil {
|
||||
s.Logger.Add("debug", "socks UDP associate reply failed: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
_ = c.SetDeadline(time.Time{})
|
||||
proto := normalizeUDPGWProtocol(s.UDPGW.Protocol)
|
||||
s.Logger.Add("info", "SOCKS5 UDP associate using UDPGW %s over SSH; protocol=%s local UDP=%s", gwAddr, proto, udp.LocalAddr().String())
|
||||
|
||||
done := make(chan struct{})
|
||||
closeOnce := sync.Once{}
|
||||
stop := func() {
|
||||
closeOnce.Do(func() {
|
||||
_ = c.Close()
|
||||
_ = gwConn.Close()
|
||||
_ = udp.Close()
|
||||
close(done)
|
||||
})
|
||||
}
|
||||
|
||||
go func() {
|
||||
_, _ = io.Copy(io.Discard, c)
|
||||
stop()
|
||||
}()
|
||||
|
||||
sess := &udpgwSession{
|
||||
nextID: 1,
|
||||
pairID: make(map[string]uint16),
|
||||
idClient: make(map[uint16]*net.UDPAddr),
|
||||
}
|
||||
|
||||
go s.readUDPGWReplies(done, proto, gwConn, udp, sess)
|
||||
s.forwardUDPToUDPGW(done, proto, gwConn, udp, sess)
|
||||
stop()
|
||||
}
|
||||
|
||||
func normalizeUDPGWProtocol(v string) string {
|
||||
switch strings.ToLower(strings.TrimSpace(v)) {
|
||||
case udpgwProtocolLegacy:
|
||||
return udpgwProtocolLegacy
|
||||
default:
|
||||
return udpgwProtocolBadVPN
|
||||
}
|
||||
}
|
||||
|
||||
func (s *SocksServer) forwardUDPToUDPGW(done <-chan struct{}, proto string, gwConn net.Conn, udp *net.UDPConn, sess *udpgwSession) {
|
||||
buf := make([]byte, 64*1024)
|
||||
for {
|
||||
_ = udp.SetReadDeadline(time.Now().Add(500 * time.Millisecond))
|
||||
n, clientAddr, err := udp.ReadFromUDP(buf)
|
||||
if err != nil {
|
||||
select {
|
||||
case <-done:
|
||||
return
|
||||
default:
|
||||
}
|
||||
if ne, ok := err.(net.Error); ok && ne.Timeout() {
|
||||
continue
|
||||
}
|
||||
s.Logger.Add("debug", "udpgw local UDP read failed: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
addr, payload, err := parseSocksUDP(buf[:n])
|
||||
if err != nil {
|
||||
if shouldLogUDPError(err) {
|
||||
s.Logger.Add("debug", "udpgw parse SOCKS UDP failed: %v", err)
|
||||
}
|
||||
continue
|
||||
}
|
||||
|
||||
connID, isNew, err := sess.idFor(clientAddr, addr)
|
||||
if err != nil {
|
||||
s.Logger.Add("debug", "udpgw session id failed for %s: %v", addr.Addr(), err)
|
||||
continue
|
||||
}
|
||||
|
||||
frame, err := udpgwBuildClientFrame(proto, connID, isNew, addr, payload)
|
||||
if err != nil {
|
||||
if shouldLogUDPError(err) {
|
||||
s.Logger.Add("debug", "udpgw frame build failed for %s: %v", addr.Addr(), err)
|
||||
}
|
||||
continue
|
||||
}
|
||||
_ = gwConn.SetWriteDeadline(time.Now().Add(15 * time.Second))
|
||||
if _, err := gwConn.Write(frame); err != nil {
|
||||
s.Logger.Add("warn", "udpgw TCP write failed: %v", err)
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (s *SocksServer) readUDPGWReplies(done <-chan struct{}, proto string, gwConn net.Conn, udp *net.UDPConn, sess *udpgwSession) {
|
||||
br := bufio.NewReaderSize(gwConn, 256*1024)
|
||||
for {
|
||||
select {
|
||||
case <-done:
|
||||
return
|
||||
default:
|
||||
}
|
||||
payload, err := udpgwReadFrame(br, udpgwMaxFrame)
|
||||
if err != nil {
|
||||
select {
|
||||
case <-done:
|
||||
return
|
||||
default:
|
||||
}
|
||||
if !errors.Is(err, io.EOF) {
|
||||
s.Logger.Add("debug", "udpgw TCP read failed: %v", err)
|
||||
}
|
||||
return
|
||||
}
|
||||
connID, src, data, err := udpgwParseReplyPayload(proto, payload)
|
||||
if err != nil {
|
||||
s.Logger.Add("debug", "udpgw bad reply frame: %v", err)
|
||||
continue
|
||||
}
|
||||
clientAddr := sess.clientFor(connID)
|
||||
if clientAddr == nil {
|
||||
continue
|
||||
}
|
||||
packet, err := buildSocksUDP(src, data)
|
||||
if err != nil {
|
||||
s.Logger.Add("debug", "udpgw SOCKS UDP reply build failed: %v", err)
|
||||
continue
|
||||
}
|
||||
_, _ = udp.WriteToUDP(packet, clientAddr)
|
||||
}
|
||||
}
|
||||
|
||||
func (s *udpgwSession) idFor(client *net.UDPAddr, dest socksRequest) (uint16, bool, error) {
|
||||
if client == nil {
|
||||
return 0, false, errors.New("missing local UDP client address")
|
||||
}
|
||||
key := client.String() + "|" + dest.Addr()
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
if id, ok := s.pairID[key]; ok {
|
||||
s.idClient[id] = cloneUDPAddr(client)
|
||||
return id, false, nil
|
||||
}
|
||||
id := s.nextID
|
||||
if id == 0 {
|
||||
id = 1
|
||||
}
|
||||
s.nextID = id + 1
|
||||
if s.nextID == 0 {
|
||||
s.nextID = 1
|
||||
}
|
||||
s.pairID[key] = id
|
||||
s.idClient[id] = cloneUDPAddr(client)
|
||||
return id, true, nil
|
||||
}
|
||||
|
||||
func (s *udpgwSession) clientFor(id uint16) *net.UDPAddr {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
return cloneUDPAddr(s.idClient[id])
|
||||
}
|
||||
|
||||
func cloneUDPAddr(a *net.UDPAddr) *net.UDPAddr {
|
||||
if a == nil {
|
||||
return nil
|
||||
}
|
||||
out := *a
|
||||
if a.IP != nil {
|
||||
out.IP = append(net.IP(nil), a.IP...)
|
||||
}
|
||||
return &out
|
||||
}
|
||||
|
||||
func udpgwReadFrame(r *bufio.Reader, max int) ([]byte, error) {
|
||||
var lenBuf [2]byte
|
||||
if _, err := io.ReadFull(r, lenBuf[:]); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
n := int(binary.LittleEndian.Uint16(lenBuf[:]))
|
||||
if n <= 0 || n > max {
|
||||
return nil, fmt.Errorf("invalid udpgw frame length %d", n)
|
||||
}
|
||||
b := make([]byte, n)
|
||||
if _, err := io.ReadFull(r, b); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return b, nil
|
||||
}
|
||||
|
||||
func udpgwBuildClientFrame(proto string, connID uint16, isNew bool, dest socksRequest, data []byte) ([]byte, error) {
|
||||
if normalizeUDPGWProtocol(proto) == udpgwProtocolLegacy {
|
||||
return udpgwBuildLegacyFrame(connID, 0, dest, data)
|
||||
}
|
||||
return udpgwBuildBadVPNFrame(connID, isNew, dest, data)
|
||||
}
|
||||
|
||||
func udpgwParseReplyPayload(proto string, payload []byte) (uint16, socksRequest, []byte, error) {
|
||||
if normalizeUDPGWProtocol(proto) == udpgwProtocolLegacy {
|
||||
return udpgwParseLegacyReplyPayload(payload)
|
||||
}
|
||||
return udpgwParseBadVPNReplyPayload(payload)
|
||||
}
|
||||
|
||||
// udpgwBuildBadVPNFrame implements the normal badvpn-udpgw PacketProto frame:
|
||||
// length(little-endian uint16) + flags(1) + conid(little-endian uint16) +
|
||||
// IPv4/IPv6 destination + destination port(network byte order) + UDP payload.
|
||||
// This is the same framing used by the Android badvpn UDPGW client and supports
|
||||
// IPv6 targets with UDPGW_CLIENT_FLAG_IPV6.
|
||||
func udpgwBuildBadVPNFrame(connID uint16, isNew bool, dest socksRequest, data []byte) ([]byte, error) {
|
||||
ip := net.ParseIP(dest.Host)
|
||||
if ip == nil {
|
||||
return nil, fmt.Errorf("badvpn UDPGW requires an IP target, got %q", dest.Host)
|
||||
}
|
||||
flags := byte(0)
|
||||
if isNew {
|
||||
flags |= udpgwFlagRebind
|
||||
}
|
||||
|
||||
var addr []byte
|
||||
if v4 := ip.To4(); v4 != nil {
|
||||
addr = append(addr, v4...)
|
||||
} else if v6 := ip.To16(); v6 != nil {
|
||||
flags |= udpgwFlagIPv6
|
||||
addr = append(addr, v6...)
|
||||
} else {
|
||||
return nil, fmt.Errorf("invalid UDPGW target IP %q", dest.Host)
|
||||
}
|
||||
|
||||
payloadLen := 1 + 2 + len(addr) + 2 + len(data)
|
||||
if payloadLen <= 0 || payloadLen > 65535 {
|
||||
return nil, fmt.Errorf("UDPGW payload too large: %d", payloadLen)
|
||||
}
|
||||
out := make([]byte, 2+payloadLen)
|
||||
binary.LittleEndian.PutUint16(out[0:2], uint16(payloadLen))
|
||||
out[2] = flags
|
||||
binary.LittleEndian.PutUint16(out[3:5], connID)
|
||||
copy(out[5:5+len(addr)], addr)
|
||||
portOff := 5 + len(addr)
|
||||
binary.BigEndian.PutUint16(out[portOff:portOff+2], uint16(dest.Port))
|
||||
copy(out[portOff+2:], data)
|
||||
return out, nil
|
||||
}
|
||||
|
||||
func udpgwParseBadVPNReplyPayload(payload []byte) (uint16, socksRequest, []byte, error) {
|
||||
var src socksRequest
|
||||
if len(payload) < 1+2+4+2 {
|
||||
return 0, src, nil, fmt.Errorf("short badvpn udpgw payload %d", len(payload))
|
||||
}
|
||||
flags := payload[0]
|
||||
if flags&udpgwFlagKeepAlive != 0 {
|
||||
return 0, src, nil, errors.New("unexpected udpgw keepalive reply")
|
||||
}
|
||||
connID := binary.LittleEndian.Uint16(payload[1:3])
|
||||
off := 3
|
||||
if flags&udpgwFlagIPv6 != 0 {
|
||||
if len(payload) < off+16+2 {
|
||||
return 0, src, nil, fmt.Errorf("short badvpn udpgw ipv6 payload %d", len(payload))
|
||||
}
|
||||
src.Host = net.IP(payload[off : off+16]).String()
|
||||
off += 16
|
||||
} else {
|
||||
if len(payload) < off+4+2 {
|
||||
return 0, src, nil, fmt.Errorf("short badvpn udpgw ipv4 payload %d", len(payload))
|
||||
}
|
||||
src.Host = net.IP(payload[off : off+4]).String()
|
||||
off += 4
|
||||
}
|
||||
src.Port = int(binary.BigEndian.Uint16(payload[off : off+2]))
|
||||
off += 2
|
||||
return connID, src, payload[off:], nil
|
||||
}
|
||||
|
||||
// Legacy frame kept only for older experimental PC builds. It is IPv4-only.
|
||||
func udpgwBuildLegacyFrame(connID uint16, x byte, dest socksRequest, data []byte) ([]byte, error) {
|
||||
ip := net.ParseIP(dest.Host)
|
||||
v4 := ip.To4()
|
||||
if v4 == nil {
|
||||
return nil, fmt.Errorf("legacy UDPGW only supports IPv4 UDP targets; got %s", dest.Addr())
|
||||
}
|
||||
payloadLen := 2 + 1 + 4 + 2 + len(data)
|
||||
if payloadLen <= 0 || payloadLen > 65535 {
|
||||
return nil, fmt.Errorf("UDPGW payload too large: %d", payloadLen)
|
||||
}
|
||||
out := make([]byte, 2+payloadLen)
|
||||
binary.LittleEndian.PutUint16(out[0:2], uint16(payloadLen))
|
||||
binary.BigEndian.PutUint16(out[2:4], connID)
|
||||
out[4] = x
|
||||
copy(out[5:9], v4)
|
||||
binary.BigEndian.PutUint16(out[9:11], uint16(dest.Port))
|
||||
copy(out[11:], data)
|
||||
return out, nil
|
||||
}
|
||||
|
||||
func udpgwParseLegacyReplyPayload(payload []byte) (uint16, socksRequest, []byte, error) {
|
||||
var src socksRequest
|
||||
if len(payload) < 2+1+4+2 {
|
||||
return 0, src, nil, fmt.Errorf("short legacy udpgw payload %d", len(payload))
|
||||
}
|
||||
connID := binary.BigEndian.Uint16(payload[0:2])
|
||||
src.Host = net.IP(payload[3:7]).String()
|
||||
src.Port = int(binary.BigEndian.Uint16(payload[7:9]))
|
||||
return connID, src, payload[9:], nil
|
||||
}
|
||||
Reference in New Issue
Block a user