diff --git a/internal/tunnel/client.go b/internal/tunnel/client.go index 1c59c67..c15d539 100644 --- a/internal/tunnel/client.go +++ b/internal/tunnel/client.go @@ -33,7 +33,6 @@ func Client(parsedURL *url.URL) error { } defer linkConn.Close() if err := linkConn.Handshake(); err != nil { - linkConn.Close() return err } log.Info("Tunnel connection established to: [%v]", linkAddr) diff --git a/internal/tunnel/server.go b/internal/tunnel/server.go index 905e948..f1c35a2 100644 --- a/internal/tunnel/server.go +++ b/internal/tunnel/server.go @@ -53,6 +53,7 @@ func Server(parsedURL *url.URL, whiteList *sync.Map, tlsConfig *tls.Config) erro log.Info("Tunnel connection established from: [%v]", linkConn.RemoteAddr().String()) var sharedMU sync.Mutex errChan := make(chan error, 2) + done := make(chan struct{}) go func() { for { time.Sleep(internal.MaxReportInterval * time.Second) @@ -60,18 +61,20 @@ func Server(parsedURL *url.URL, whiteList *sync.Map, tlsConfig *tls.Config) erro _, err = linkTLS.Write([]byte("[REPORT]\n")) sharedMU.Unlock() if err != nil { - log.Error("TLS connection health check failed: %v", err) + log.Error("Tunnel connection health check failed: %v", err) linkTLS.Close() + linkListen.Close() + close(done) errChan <- err - return + break } } }() go func() { - errChan <- ServeTCP(parsedURL, whiteList, linkAddr, targetTCPAddr, linkListen, linkTLS, &sharedMU) + errChan <- ServeTCP(parsedURL, whiteList, linkAddr, targetTCPAddr, linkListen, linkTLS, &sharedMU, done) }() go func() { - errChan <- ServeUDP(parsedURL, whiteList, linkAddr, targetUDPAddr, linkListen, linkTLS, &sharedMU) + errChan <- ServeUDP(parsedURL, whiteList, linkAddr, targetUDPAddr, linkListen, linkTLS, &sharedMU, done) }() return <-errChan } diff --git a/internal/tunnel/tcp.go b/internal/tunnel/tcp.go index 257f2ae..56b4f35 100644 --- a/internal/tunnel/tcp.go +++ b/internal/tunnel/tcp.go @@ -13,75 +13,83 @@ import ( "github.com/yosebyte/passport/pkg/log" ) -func ServeTCP(parsedURL *url.URL, whiteList *sync.Map, linkAddr, targetAddr *net.TCPAddr, linkListen net.Listener, linkTLS *tls.Conn, mu *sync.Mutex) error { - targetListen, err := net.ListenTCP("tcp", targetAddr) - if err != nil { - log.Error("Unable to listen target address: [%v]", targetAddr) - return err - } - defer targetListen.Close() - sem := make(chan struct{}, internal.MaxSemaphoreLimit) +func ServeTCP(parsedURL *url.URL, whiteList *sync.Map, linkAddr, targetAddr *net.TCPAddr, linkListen net.Listener, linkTLS *tls.Conn, mu *sync.Mutex, done <-chan struct{}) error { for { - targetConn, err := targetListen.AcceptTCP() - if err != nil { - log.Error("Unable to accept connections form target address: [%v] %v", targetAddr, err) - time.Sleep(1 * time.Second) - continue - } - clientAddr := targetConn.RemoteAddr().String() - log.Info("Target connection established from: [%v]", clientAddr) - if parsedURL.Fragment != "" { - clientIP, _, err := net.SplitHostPort(clientAddr) - if err != nil { - log.Error("Unable to extract client IP address: [%v] %v", clientAddr, err) - targetConn.Close() - time.Sleep(1 * time.Second) - continue - } - if _, exists := whiteList.Load(clientIP); !exists { - log.Warn("Unauthorized IP address blocked: [%v]", clientIP) - targetConn.Close() - continue - } - } - sem <- struct{}{} - go func(targetConn *net.TCPConn) { - defer func() { <-sem }() - mu.Lock() - _, err = linkTLS.Write([]byte("[PASSPORT]\n")) - mu.Unlock() + select { + case <-done: + log.Warn("TCP server received shutdown signal") + return nil + default: + targetListen, err := net.ListenTCP("tcp", targetAddr) if err != nil { - log.Error("Unable to send signal: %v", err) - targetConn.Close() - return + log.Error("Unable to listen target address: [%v]", targetAddr) + return err } - remoteConn, err := linkListen.Accept() - if err != nil { - log.Error("Unable to accept connections form link address: [%v] %v", linkAddr, err) - return - } - remoteTLS, ok := remoteConn.(*tls.Conn) - if !ok { - log.Error("Non-TLS connection received") - targetConn.Close() - remoteConn.Close() - return - } - if err := remoteTLS.Handshake(); err != nil { - log.Error("TLS handshake failed: %v", err) - targetConn.Close() - remoteTLS.Close() - return - } - log.Info("Starting data exchange: [%v] <-> [%v]", clientAddr, targetAddr) - if err := conn.DataExchange(remoteTLS, targetConn); err != nil { - if err == io.EOF { - log.Info("Connection closed successfully: %v", err) - } else { - log.Warn("Connection closed unexpectedly: %v", err) + defer targetListen.Close() + sem := make(chan struct{}, internal.MaxSemaphoreLimit) + for { + targetConn, err := targetListen.AcceptTCP() + if err != nil { + log.Error("Unable to accept connections form target address: [%v] %v", targetAddr, err) + time.Sleep(1 * time.Second) + continue + } + clientAddr := targetConn.RemoteAddr().String() + log.Info("Target connection established from: [%v]", clientAddr) + if parsedURL.Fragment != "" { + clientIP, _, err := net.SplitHostPort(clientAddr) + if err != nil { + log.Error("Unable to extract client IP address: [%v] %v", clientAddr, err) + targetConn.Close() + time.Sleep(1 * time.Second) + continue + } + if _, exists := whiteList.Load(clientIP); !exists { + log.Warn("Unauthorized IP address blocked: [%v]", clientIP) + targetConn.Close() + continue + } } + sem <- struct{}{} + go func(targetConn *net.TCPConn) { + defer func() { <-sem }() + mu.Lock() + _, err = linkTLS.Write([]byte("[PASSPORT]\n")) + mu.Unlock() + if err != nil { + log.Error("Unable to send signal: %v", err) + targetConn.Close() + return + } + remoteConn, err := linkListen.Accept() + if err != nil { + log.Error("Unable to accept connections form link address: [%v] %v", linkAddr, err) + return + } + remoteTLS, ok := remoteConn.(*tls.Conn) + if !ok { + log.Error("Non-TLS connection received") + targetConn.Close() + remoteConn.Close() + return + } + if err := remoteTLS.Handshake(); err != nil { + log.Error("TLS handshake failed: %v", err) + targetConn.Close() + remoteTLS.Close() + return + } + log.Info("Starting data exchange: [%v] <-> [%v]", clientAddr, targetAddr) + if err := conn.DataExchange(remoteTLS, targetConn); err != nil { + if err == io.EOF { + log.Info("Connection closed successfully: %v", err) + } else { + log.Warn("Connection closed unexpectedly: %v", err) + } + } + }(targetConn) } - }(targetConn) + } } } diff --git a/internal/tunnel/udp.go b/internal/tunnel/udp.go index 656901b..9b6d046 100644 --- a/internal/tunnel/udp.go +++ b/internal/tunnel/udp.go @@ -11,80 +11,88 @@ import ( "github.com/yosebyte/passport/pkg/log" ) -func ServeUDP(parsedURL *url.URL, whiteList *sync.Map, linkAddr *net.TCPAddr, targetAddr *net.UDPAddr, linkListen net.Listener, linkTLS *tls.Conn, mu *sync.Mutex) error { - targetConn, err := net.ListenUDP("udp", targetAddr) - if err != nil { - log.Error("Unable to listen target address: [%v]", targetAddr) - return err - } - defer targetConn.Close() - sem := make(chan struct{}, internal.MaxSemaphoreLimit) +func ServeUDP(parsedURL *url.URL, whiteList *sync.Map, linkAddr *net.TCPAddr, targetAddr *net.UDPAddr, linkListen net.Listener, linkTLS *tls.Conn, mu *sync.Mutex, done <-chan struct{}) error { for { - buffer := make([]byte, internal.MaxDataBuffer) - n, clientAddr, err := targetConn.ReadFromUDP(buffer) - if err != nil { - log.Error("Unable to read from client address: [%v] %v", clientAddr, err) - time.Sleep(1 * time.Second) - continue - } - if parsedURL.Fragment != "" { - clientIP := clientAddr.IP.String() - if _, exists := whiteList.Load(clientIP); !exists { - log.Warn("Unauthorized IP address blocked: [%v]", clientIP) - continue - } - } - mu.Lock() - _, err = linkTLS.Write([]byte("[PASSPORT]\n")) - mu.Unlock() - if err != nil { - log.Error("Unable to send signal: %v", err) - time.Sleep(1 * time.Second) - continue - } - remoteConn, err := linkListen.Accept() - if err != nil { - log.Error("Unable to accept connections from link address: [%v] %v", linkAddr, err) - time.Sleep(1 * time.Second) - continue - } - remoteTLS, ok := remoteConn.(*tls.Conn) - if !ok { - log.Error("Non-TLS connection received") - remoteConn.Close() - time.Sleep(1 * time.Second) - continue - } - if err := remoteTLS.Handshake(); err != nil { - log.Error("TLS handshake failed: %v", err) - remoteTLS.Close() - time.Sleep(1 * time.Second) - continue - } - sem <- struct{}{} - go func(buffer []byte, n int, remoteTLS *tls.Conn, clientAddr *net.UDPAddr) { - defer func() { - <-sem - remoteTLS.Close() - }() - log.Info("Starting data transfer: [%v] <-> [%v]", clientAddr, targetAddr) - _, err = remoteTLS.Write(buffer[:n]) + select { + case <-done: + log.Warn("UDP server received shutdown signal") + return nil + default: + targetConn, err := net.ListenUDP("udp", targetAddr) if err != nil { - log.Error("Unable to write to link address: [%v] %v", linkAddr, err) - return + log.Error("Unable to listen target address: [%v]", targetAddr) + return err } - n, err = remoteTLS.Read(buffer) - if err != nil { - log.Error("Unable to read from link address: [%v] %v", linkAddr, err) - return + defer targetConn.Close() + sem := make(chan struct{}, internal.MaxSemaphoreLimit) + for { + buffer := make([]byte, internal.MaxDataBuffer) + n, clientAddr, err := targetConn.ReadFromUDP(buffer) + if err != nil { + log.Error("Unable to read from client address: [%v] %v", clientAddr, err) + time.Sleep(1 * time.Second) + continue + } + if parsedURL.Fragment != "" { + clientIP := clientAddr.IP.String() + if _, exists := whiteList.Load(clientIP); !exists { + log.Warn("Unauthorized IP address blocked: [%v]", clientIP) + continue + } + } + mu.Lock() + _, err = linkTLS.Write([]byte("[PASSPORT]\n")) + mu.Unlock() + if err != nil { + log.Error("Unable to send signal: %v", err) + time.Sleep(1 * time.Second) + continue + } + remoteConn, err := linkListen.Accept() + if err != nil { + log.Error("Unable to accept connections from link address: [%v] %v", linkAddr, err) + time.Sleep(1 * time.Second) + continue + } + remoteTLS, ok := remoteConn.(*tls.Conn) + if !ok { + log.Error("Non-TLS connection received") + remoteConn.Close() + time.Sleep(1 * time.Second) + continue + } + if err := remoteTLS.Handshake(); err != nil { + log.Error("TLS handshake failed: %v", err) + remoteTLS.Close() + time.Sleep(1 * time.Second) + continue + } + sem <- struct{}{} + go func(buffer []byte, n int, remoteTLS *tls.Conn, clientAddr *net.UDPAddr) { + defer func() { + <-sem + remoteTLS.Close() + }() + log.Info("Starting data transfer: [%v] <-> [%v]", clientAddr, targetAddr) + _, err = remoteTLS.Write(buffer[:n]) + if err != nil { + log.Error("Unable to write to link address: [%v] %v", linkAddr, err) + return + } + n, err = remoteTLS.Read(buffer) + if err != nil { + log.Error("Unable to read from link address: [%v] %v", linkAddr, err) + return + } + _, err = targetConn.WriteToUDP(buffer[:n], clientAddr) + if err != nil { + log.Error("Unable to write to client address: [%v] %v", clientAddr, err) + return + } + log.Info("Transfer completed successfully") + }(buffer, n, remoteTLS, clientAddr) } - _, err = targetConn.WriteToUDP(buffer[:n], clientAddr) - if err != nil { - log.Error("Unable to write to client address: [%v] %v", clientAddr, err) - return - } - log.Info("Transfer completed successfully") - }(buffer, n, remoteTLS, clientAddr) + } } }