diff --git a/src/main/java/com/hierynomus/sshj/backport/JavaVersion.java b/src/main/java/com/hierynomus/sshj/backport/JavaVersion.java new file mode 100644 index 00000000..548488c1 --- /dev/null +++ b/src/main/java/com/hierynomus/sshj/backport/JavaVersion.java @@ -0,0 +1,13 @@ +package com.hierynomus.sshj.backport; + +import java.math.BigDecimal; + +public class JavaVersion { + public static boolean isJava7OrEarlier() { + String property = System.getProperty("java.specification.version"); + float diff = Float.parseFloat(property) - 1.7f; + + return diff < 0.01; + } + +} diff --git a/src/main/java/com/hierynomus/sshj/backport/Jdk7HttpProxySocket.java b/src/main/java/com/hierynomus/sshj/backport/Jdk7HttpProxySocket.java new file mode 100644 index 00000000..38f7f329 --- /dev/null +++ b/src/main/java/com/hierynomus/sshj/backport/Jdk7HttpProxySocket.java @@ -0,0 +1,62 @@ +package com.hierynomus.sshj.backport; + +import java.io.IOException; +import java.io.InputStream; +import java.net.*; +import java.nio.charset.Charset; + +public class Jdk7HttpProxySocket extends Socket { + + private Proxy httpProxy = null; + + public Jdk7HttpProxySocket(Proxy proxy) { + super(proxy.type() == Proxy.Type.HTTP ? Proxy.NO_PROXY : proxy); + if (proxy.type() == Proxy.Type.HTTP) { + this.httpProxy = proxy; + } + } + + @Override + public void connect(SocketAddress endpoint, int timeout) throws IOException { + if (httpProxy != null) { + connectHttpProxy(endpoint, timeout); + } else { + super.connect(endpoint, timeout); + } + } + + private void connectHttpProxy(SocketAddress endpoint, int timeout) throws IOException { + super.connect(httpProxy.address(), timeout); + + if (!(endpoint instanceof InetSocketAddress)) { + throw new SocketException("Expected an InetSocketAddress to connect to, got: " + endpoint); + } + InetSocketAddress isa = (InetSocketAddress) endpoint; + String httpConnect = "CONNECT " + isa.getHostName() + ":" + isa.getPort() + " HTTP/1.0\n\n"; + getOutputStream().write(httpConnect.getBytes(Charset.forName("UTF-8"))); + checkAndFlushProxyResponse(); + } + + private void checkAndFlushProxyResponse()throws IOException { + InputStream socketInput = getInputStream(); + byte[] tmpBuffer = new byte[512]; + int len = socketInput.read(tmpBuffer, 0, tmpBuffer.length); + + if (len == 0) { + throw new SocketException("Empty response from proxy"); + } + + String proxyResponse = new String(tmpBuffer, 0, len, "UTF-8"); + + // Expecting HTTP/1.x 200 OK + if (proxyResponse.contains("200")) { + // Flush any outstanding message in buffer + if (socketInput.available() > 0) { + socketInput.skip(socketInput.available()); + } + // Proxy Connect Successful + } else { + throw new SocketException("Fail to create Socket\nResponse was:" + proxyResponse); + } + } +} diff --git a/src/main/java/net/schmizz/sshj/SocketClient.java b/src/main/java/net/schmizz/sshj/SocketClient.java index 57293b80..b4caf5bc 100644 --- a/src/main/java/net/schmizz/sshj/SocketClient.java +++ b/src/main/java/net/schmizz/sshj/SocketClient.java @@ -15,6 +15,9 @@ */ package net.schmizz.sshj; +import com.hierynomus.sshj.backport.JavaVersion; +import com.hierynomus.sshj.backport.Jdk7HttpProxySocket; + import javax.net.SocketFactory; import java.io.IOException; import java.io.InputStream; @@ -45,34 +48,35 @@ public abstract class SocketClient { this.defaultPort = defaultPort; } - public void connect(InetAddress host, int port) - throws IOException { + public void connect(InetAddress host, int port) throws IOException { socket = socketFactory.createSocket(); socket.connect(new InetSocketAddress(host, port), connectTimeout); onConnect(); } - public void connect(InetAddress host, int port, Proxy proxy) - throws IOException { - socket = new Socket(proxy); + + public void connect(InetAddress host, int port, Proxy proxy) throws IOException { + if (JavaVersion.isJava7OrEarlier() && proxy.type() == Proxy.Type.HTTP) { + // Java7 and earlier have no support for HTTP Connect proxies, return our custom socket. + socket = new Jdk7HttpProxySocket(proxy); + } else { + socket = new Socket(proxy); + } socket.connect(new InetSocketAddress(host, port), connectTimeout); onConnect(); } - public void connect(String hostname, int port) - throws IOException { + public void connect(String hostname, int port) throws IOException { this.hostname = hostname; connect(InetAddress.getByName(hostname), port); } - public void connect(String hostname, int port, Proxy proxy) - throws IOException { + public void connect(String hostname, int port, Proxy proxy) throws IOException { this.hostname = hostname; connect(InetAddress.getByName(hostname), port, proxy); } - public void connect(InetAddress host, int port, - InetAddress localAddr, int localPort) + public void connect(InetAddress host, int port, InetAddress localAddr, int localPort) throws IOException { socket = socketFactory.createSocket(); socket.bind(new InetSocketAddress(localAddr, localPort)); @@ -80,35 +84,28 @@ public abstract class SocketClient { onConnect(); } - public void connect(String hostname, int port, - InetAddress localAddr, int localPort) - throws IOException { + public void connect(String hostname, int port, InetAddress localAddr, int localPort) throws IOException { this.hostname = hostname; connect(InetAddress.getByName(hostname), port, localAddr, localPort); } - public void connect(InetAddress host) - throws IOException { + public void connect(InetAddress host) throws IOException { connect(host, defaultPort); } - public void connect(String hostname) - throws IOException { + public void connect(String hostname) throws IOException { connect(hostname, defaultPort); } - public void connect(InetAddress host, Proxy proxy) - throws IOException { + public void connect(InetAddress host, Proxy proxy) throws IOException { connect(host, defaultPort, proxy); } - public void connect(String hostname, Proxy proxy) - throws IOException { + public void connect(String hostname, Proxy proxy) throws IOException { connect(hostname, defaultPort, proxy); } - public void disconnect() - throws IOException { + public void disconnect() throws IOException { if (socket != null) { socket.close(); socket = null; @@ -131,7 +128,6 @@ public abstract class SocketClient { return socket.getLocalPort(); } - public InetAddress getLocalAddress() { return socket.getLocalAddress(); } @@ -149,10 +145,11 @@ public abstract class SocketClient { } public void setSocketFactory(SocketFactory factory) { - if (factory == null) + if (factory == null) { socketFactory = SocketFactory.getDefault(); - else + } else { socketFactory = factory; + } } public SocketFactory getSocketFactory() { @@ -187,8 +184,7 @@ public abstract class SocketClient { return output; } - void onConnect() - throws IOException { + void onConnect() throws IOException { socket.setSoTimeout(timeout); input = socket.getInputStream(); output = socket.getOutputStream();