diff --git a/reverse-tunnel.conf.example b/reverse-tunnel.conf.example index 7b73fca..22e3bcc 100644 --- a/reverse-tunnel.conf.example +++ b/reverse-tunnel.conf.example @@ -1,6 +1,8 @@ -tunnel_port=2222 +tunnel_port_range=2224:2242 tunnel_host=0.0.0.0 http_port=2223 external_port_range=20000:30000 host_key_path=hostkey.ser -idle_token_timeout=86400 \ No newline at end of file +idle_token_timeout=86400 +ports_per_ssh_server=5 +check_ssh_servers_interval=120 \ No newline at end of file diff --git a/src/main/java/org/fogbowcloud/ssh/Main.java b/src/main/java/org/fogbowcloud/ssh/Main.java index 4374066..3afea94 100644 --- a/src/main/java/org/fogbowcloud/ssh/Main.java +++ b/src/main/java/org/fogbowcloud/ssh/Main.java @@ -11,27 +11,35 @@ public static void main(String[] args) throws IOException { FileInputStream input = new FileInputStream(args[0]); properties.load(input); - String tunnelPort = properties.getProperty("tunnel_port"); + String tunnelPortRange = properties.getProperty("tunnel_port_range"); + String[] tunnelPortRangeSplit = tunnelPortRange.split(":"); String tunnelHost = properties.getProperty("tunnel_host"); String httpPort = properties.getProperty("http_port"); String externalPortRange = properties.getProperty("external_port_range"); String[] externalRangeSplit = externalPortRange.split(":"); String externalHostKeyPath = properties.getProperty("host_key_path"); String idleTokenTimeoutStr = properties.getProperty("idle_token_timeout"); + String portsPerShhServer = properties.getProperty("ports_per_ssh_server"); Long idleTokenTimeout = null; if (idleTokenTimeoutStr != null) { idleTokenTimeout = Long.parseLong(idleTokenTimeoutStr) * 1000; } + String checkSSHServersIntervalStr = properties.getProperty("check_ssh_servers_interval"); + int checkSSHServersInterval = Integer.parseInt(checkSSHServersIntervalStr); + TunnelHttpServer tunnelHttpServer = new TunnelHttpServer( Integer.parseInt(httpPort), tunnelHost, - Integer.parseInt(tunnelPort), + Integer.parseInt(tunnelPortRangeSplit[0]), + Integer.parseInt(tunnelPortRangeSplit[1]), Integer.parseInt(externalRangeSplit[0]), Integer.parseInt(externalRangeSplit[1]), idleTokenTimeout, - externalHostKeyPath); + externalHostKeyPath, + Integer.parseInt(portsPerShhServer), checkSSHServersInterval); tunnelHttpServer.start(); + } -} +} \ No newline at end of file diff --git a/src/main/java/org/fogbowcloud/ssh/TunnelHttpServer.java b/src/main/java/org/fogbowcloud/ssh/TunnelHttpServer.java index aebad94..390c87d 100644 --- a/src/main/java/org/fogbowcloud/ssh/TunnelHttpServer.java +++ b/src/main/java/org/fogbowcloud/ssh/TunnelHttpServer.java @@ -5,8 +5,19 @@ import java.io.ObjectInputStream; import java.io.UnsupportedEncodingException; import java.security.KeyPair; +import java.util.ArrayList; +import java.util.HashMap; +import java.util.HashSet; +import java.util.List; import java.util.Map; +import java.util.Map.Entry; +import java.util.Set; +import java.util.concurrent.ConcurrentHashMap; +import java.util.concurrent.Executors; +import java.util.concurrent.ScheduledExecutorService; +import java.util.concurrent.TimeUnit; +import org.apache.log4j.Logger; import org.apache.sshd.common.util.Base64; import org.json.JSONObject; @@ -14,21 +25,71 @@ import fi.iki.elonen.NanoHTTPD.Response.Status; public class TunnelHttpServer extends NanoHTTPD { - - private TunnelServer tunneling; + + //private TunnelServer tunneling; + private static final int SSH_SERVER_VERIFICATION_TIME = 300; + private static final Logger LOGGER = Logger.getLogger(TunnelHttpServer.class); + + private Map tunnelServers = new ConcurrentHashMap(); + private String hostKeyPath; private KeyPair kp; + + private int lowerPort; + private int higherPort; + private String sshTunnelHost; + private int lowerSshTunnelPort; + private int higherSshTunnelPort; + private Long idleTokenTimeout; + private int checkSSHServersInterval; + + private int portsPerShhServer; + + private ScheduledExecutorService executor = Executors.newScheduledThreadPool(1); - public TunnelHttpServer(int httpPort, String sshTunnelHost, int sshTunnelPort, - int lowerPort, int higherPort, Long idleTokenTimeout, String hostKeyPath) { + public TunnelHttpServer(int httpPort, String sshTunnelHost, int lowerSshTunnelPort, int higherSshTunnelPort, + int lowerPort, int higherPort, Long idleTokenTimeout, String hostKeyPath, int portsPerShhServer, int checkSSHServersInterval) { super(httpPort); this.hostKeyPath = hostKeyPath; + + this.lowerPort = lowerPort; + this.higherPort = higherPort; + this.sshTunnelHost = sshTunnelHost; + this.lowerSshTunnelPort = lowerSshTunnelPort; + this.higherSshTunnelPort = higherSshTunnelPort; + this.idleTokenTimeout = idleTokenTimeout; + this.portsPerShhServer = portsPerShhServer; + this.checkSSHServersInterval = checkSSHServersInterval == 0 ? SSH_SERVER_VERIFICATION_TIME : checkSSHServersInterval; + try { - this.tunneling = new TunnelServer(sshTunnelHost, sshTunnelPort, - lowerPort, higherPort, idleTokenTimeout, hostKeyPath); - this.tunneling.start(); + + this.createNewTunnelServer(); + + executor.scheduleWithFixedDelay(new Runnable() { + @Override + public void run() { + + List tunnelsToRemove = new ArrayList(); + + for(Entry entry : tunnelServers.entrySet()){ + if(entry.getValue().getActiveTokensNumber() <= 0){ + tunnelsToRemove.add(entry.getValue()); + } + } + + for(TunnelServer tunneling : tunnelsToRemove){ + try { + removeTunnelServer(tunneling); + } catch (InterruptedException e) { + LOGGER.error(e.getMessage(), e); + } + } + + } + }, this.checkSSHServersInterval, this.checkSSHServersInterval, TimeUnit.SECONDS); + } catch (IOException e) { - e.printStackTrace(); + LOGGER.error(e.getMessage(), e); } } @@ -50,10 +111,13 @@ public Response serve(IHTTPSession session) { if (method.equals(Method.GET)) { if (splitUri.length == 4 && splitUri[3].equals("all")) { - Map ports = this.tunneling.getPortByPrefix(tokenId); + Map ports = new HashMap(); + for(TunnelServer tunneling : tunnelServers.values()){ + ports.putAll(tunneling.getPortByPrefix(tokenId)); + } return new NanoHTTPD.Response(new JSONObject(ports).toString()); } else { - Integer port = this.tunneling.getPort(tokenId); + Integer port = this.getPortByTokenId(tokenId); if (port == null) { return new NanoHTTPD.Response(Status.NOT_FOUND, MIME_PLAINTEXT, "404 Port Not Found"); @@ -61,13 +125,52 @@ public Response serve(IHTTPSession session) { return new NanoHTTPD.Response(port.toString()); } } - if (method.equals(Method.POST)) { - Integer port = this.tunneling.createPort(tokenId); - if (port == null) { - return new NanoHTTPD.Response(Status.INTERNAL_ERROR, MIME_PLAINTEXT, ""); + + //TODO verify if the request can request new port. (ports quota per instance.) + Integer instancePort = null ; + Integer sshServerPort = null ; + + if(tunnelServers.values() != null && !tunnelServers.values().isEmpty()){ + for(TunnelServer tunneling : tunnelServers.values()){ + instancePort = tunneling.createPort(tokenId); + if(instancePort != null){ + sshServerPort = tunneling.getSshTunnelPort(); + break; + } + } + } + + if (instancePort == null) { + try { + TunnelServer tunneling = this.createNewTunnelServer(); + if(tunneling != null){ + instancePort = tunneling.createPort(tokenId); + sshServerPort = tunneling.getSshTunnelPort(); + } + } catch (IOException e) { + return new NanoHTTPD.Response(Status.INTERNAL_ERROR, MIME_PLAINTEXT, "Error while creating shh server to handle new port."); + } + } + + if (instancePort == null) { + return new NanoHTTPD.Response(Status.FORBIDDEN, MIME_PLAINTEXT, "Token [" + tokenId + "] didn't get any port. All ssh servers are busy."); + } + //Return format: instancePort:sshTunnelServerPort (int:int) + return new NanoHTTPD.Response(instancePort.toString()+":"+sshServerPort.toString()); + } + + if (method.equals(Method.DELETE)) { + + if (splitUri.length == 4) { + String portNumber = splitUri[3]; + if(Utils.isNumber(portNumber)){ + if(this.releaseInstancePort(tokenId, Integer.parseInt(portNumber))){ + return new NanoHTTPD.Response(Status.OK, MIME_PLAINTEXT, "OK"); + } + } } - return new NanoHTTPD.Response(port.toString()); + return new NanoHTTPD.Response(Status.METHOD_NOT_ALLOWED, MIME_PLAINTEXT, "Token can not delete this port"); } return new NanoHTTPD.Response(Status.METHOD_NOT_ALLOWED, MIME_PLAINTEXT, ""); @@ -105,5 +208,94 @@ public Response serve(IHTTPSession session) { return new NanoHTTPD.Response(Status.METHOD_NOT_ALLOWED, MIME_PLAINTEXT, ""); } + + private TunnelServer createNewTunnelServer() throws IOException{ + + //Setting available ports to this tunnel server + int initialPort = 0; + int endPort = 0; + int sshTunnelPort = 0; + + Set usedInitialPorts = new HashSet(); + for (TunnelServer tunnelServer : tunnelServers.values()) { + usedInitialPorts.add(new Integer(tunnelServer.getLowerPort())); + } -} + for(int port = lowerPort; port < higherPort; port+=portsPerShhServer){ + if(!usedInitialPorts.contains(new Integer(port))){ + initialPort = port; + break; + } + } + + if(initialPort == 0){ + return null; + } + + endPort = initialPort+(portsPerShhServer-1); + if(endPort > higherPort){ + endPort = higherPort; + } + + //Setting the port that this tunnel Server listening to manage connections requests. + for(int port = lowerSshTunnelPort ; port <= higherSshTunnelPort ; port++){ + if(!tunnelServers.containsKey(new Integer(port))){ + sshTunnelPort = port; + break; + } + } + + if(sshTunnelPort == 0){ + return null; + } + + TunnelServer tunneling = new TunnelServer(sshTunnelHost, sshTunnelPort, + initialPort, endPort, idleTokenTimeout, hostKeyPath); + + tunnelServers.put(new Integer(sshTunnelPort), tunneling); + tunneling.start(); + + return tunneling; + } + + private Integer getPortByTokenId(String tokenId){ + for(TunnelServer tunneling : tunnelServers.values()){ + if(tunneling.getPort(tokenId) != null){ + return tunneling.getPort(tokenId); + } + } + return null; + } + + //TODO: Create new method to validate if the requester have available quota to request new port. + + + + private boolean releaseInstancePort(String tokenId, Integer port){ + for(TunnelServer tunneling : tunnelServers.values()){ + + Integer actualPort = tunneling.getAllPorts().get(tokenId); + + if( actualPort != null && (actualPort.compareTo(port)== 0) ){ + tunneling.releasePort(port); + if(tunneling.getActiveTokensNumber() == 0){ + try { + this.removeTunnelServer(tunneling); + } catch (InterruptedException e) { + e.printStackTrace(); + } + } + return true; + } + } + return false; + } + + private void removeTunnelServer(TunnelServer tunneling) throws InterruptedException{ + if(tunneling != null){ + tunneling.stop(); + LOGGER.warn("Removing ssh server with port: "+tunneling.getSshTunnelPort()); + tunnelServers.remove(tunneling.getSshTunnelPort()); + } + } +} \ No newline at end of file diff --git a/src/main/java/org/fogbowcloud/ssh/TunnelServer.java b/src/main/java/org/fogbowcloud/ssh/TunnelServer.java index d609575..028e9c0 100644 --- a/src/main/java/org/fogbowcloud/ssh/TunnelServer.java +++ b/src/main/java/org/fogbowcloud/ssh/TunnelServer.java @@ -53,12 +53,14 @@ public Token(Integer port) { private SshServer sshServer; private String sshTunnelHost; - private int sshTunnelPort; - private int lowerPort; - private int higherPort; + private final int sshTunnelPort; + private final int lowerPort; + private final int higherPort; private String hostKeyPath; private Long idleTokenTimeout; + private int nioWorkers; + public TunnelServer(String sshTunnelHost, int sshTunnelPort, int lowerPort, int higherPort, Long idleTokenTimeout, String hostKeyPath) { this.sshTunnelHost = sshTunnelHost; @@ -68,6 +70,7 @@ public TunnelServer(String sshTunnelHost, int sshTunnelPort, int lowerPort, this.idleTokenTimeout = idleTokenTimeout == null ? TOKEN_EXPIRATION_TIMEOUT : idleTokenTimeout; this.hostKeyPath = hostKeyPath; + this.nioWorkers = (higherPort - lowerPort)+2; //+2 is to have a secure margin of works for ports. If number of ports is 5, workers will be set to 6; } public synchronized Integer createPort(String token) { @@ -158,6 +161,7 @@ public String getName() { sshServer.setUserAuthFactories(userAuthenticators); sshServer.setHost(sshTunnelHost == null ? "0.0.0.0" : sshTunnelHost); sshServer.setPort(sshTunnelPort); + sshServer.setNioWorkers(nioWorkers); executor.scheduleWithFixedDelay(new Runnable() { @Override public void run() { @@ -267,4 +271,65 @@ public Map getPortByPrefix(String tokenId) { return portsByPrefix; } -} + //TODO: Create a method that return boolean for server busy (reached port limit) or not. + public boolean isServerBusy(){ + for (int port = lowerPort; port <= higherPort; port++) { + if (!isTaken(port)) { + return false; + } + } + return true; + } + + //TODO: Create a new method to remove a token and release the relative port. + public void removeToken(String tokenId){ + tokens.remove(tokenId); + + } + + public void releasePort(Integer port){ + if(port != null){ + String tokenToRemove = null; + for(Entry e : tokens.entrySet()){ + if(port.compareTo(e.getValue().port) == 0){ + tokenToRemove = e.getKey(); + break; + } + } + + if(this.getActiveSession(port.intValue()) != null){ + this.getActiveSession(port.intValue()).close(true); + } + tokens.remove(tokenToRemove); + } + } + + public void stop() throws InterruptedException{ + + List activeSessions = sshServer.getActiveSessions(); + if(activeSessions != null && !activeSessions.isEmpty()){ + for (AbstractSession session : activeSessions) { + session.close(true); + } + } + sshServer.stop(true); + + } + + public int getActiveTokensNumber(){ + return tokens.size(); + } + + public int getLowerPort() { + return lowerPort; + } + + public int getHigherPort() { + return higherPort; + } + + public int getSshTunnelPort() { + return sshTunnelPort; + } + +} \ No newline at end of file diff --git a/src/main/java/org/fogbowcloud/ssh/Utils.java b/src/main/java/org/fogbowcloud/ssh/Utils.java new file mode 100644 index 0000000..97f2989 --- /dev/null +++ b/src/main/java/org/fogbowcloud/ssh/Utils.java @@ -0,0 +1,20 @@ +package org.fogbowcloud.ssh; + +import java.math.BigDecimal; + +public class Utils { + + public static boolean isNumber(String value){ + if(value != null){ + try{ + BigDecimal bd = new BigDecimal(value); + bd = null; + return true; + }catch(NumberFormatException nfe){ + return false; + } + } + return false; + } + +}