From fc535a5e76c6a729a191ef7dec5c4bd9aa0006aa Mon Sep 17 00:00:00 2001 From: Jeroen van Erp Date: Fri, 27 Mar 2015 15:19:04 +0100 Subject: [PATCH] Added support for (unauthenticated) HTTP proxies (fixes #170) --- .../hierynomus/sshj/socket/SocketFactory.java | 92 +++++++++++++++++++ .../java/net/schmizz/sshj/SocketClient.java | 24 ++--- 2 files changed, 102 insertions(+), 14 deletions(-) create mode 100644 src/main/java/com/hierynomus/sshj/socket/SocketFactory.java diff --git a/src/main/java/com/hierynomus/sshj/socket/SocketFactory.java b/src/main/java/com/hierynomus/sshj/socket/SocketFactory.java new file mode 100644 index 00000000..d7f4e047 --- /dev/null +++ b/src/main/java/com/hierynomus/sshj/socket/SocketFactory.java @@ -0,0 +1,92 @@ +package com.hierynomus.sshj.socket; + +import java.io.IOException; +import java.io.InputStream; +import java.net.InetSocketAddress; +import java.net.Proxy; +import java.net.Socket; +import java.net.SocketException; + +import static java.lang.String.format; + +/** + * https://code.google.com/p/java-socket-over-http-proxy-connect/source/browse/trunk/src/sg/com/en/SocketFactory.java + */ +public class SocketFactory { + + private static final int DEFAULT_CONNECT_TIMEOUT = 0; + private int connectTimeout = DEFAULT_CONNECT_TIMEOUT; + + private final javax.net.SocketFactory delegateSocketFactory = javax.net.SocketFactory.getDefault(); + + public static SocketFactory getDefault() { + return new SocketFactory(); + } + + public Socket createSocket(String address, int port) throws IOException { + return createSocket(new InetSocketAddress(address, port)); + } + + public Socket createSocket(InetSocketAddress inetSocketAddress) throws IOException { + Socket socket = delegateSocketFactory.createSocket(); + socket.connect(inetSocketAddress, connectTimeout); + return socket; + } + + public Socket createSocket(InetSocketAddress inetSocketAddress, Proxy proxy) throws IOException { + if (proxy.type() == Proxy.Type.HTTP) { + return createHttpProxySocket(inetSocketAddress, proxy); + } + Socket socket = new Socket(proxy); + socket.connect(inetSocketAddress, connectTimeout); + return socket; + } + + private Socket createHttpProxySocket(InetSocketAddress inetSocketAddress, Proxy proxy) throws IOException { + Socket socket = delegateSocketFactory.createSocket(); + socket.connect(proxy.address()); + + String connect = format("CONNECT %s:%d\n\n", inetSocketAddress.getHostName(), inetSocketAddress.getPort()); + socket.getOutputStream().write(connect.getBytes()); + checkAndFlushProxyResponse(socket); + return socket; + } + + private void checkAndFlushProxyResponse(Socket socket)throws IOException { + InputStream socketInput = socket.getInputStream(); + byte[] tmpBuffer = new byte[512]; + int len = socketInput.read(tmpBuffer, 0, tmpBuffer.length); + + if (len == 0) { + throw new SocketException("Empty response from proxy"); + } + + String proxyResponse = new String(tmpBuffer, 0, len, "UTF-8"); + + // Expecting HTTP/1.x 200 OK + if (proxyResponse.contains("200")) { + // Flush any outstanding message in buffer + if (socketInput.available() > 0) + socketInput.skip(socketInput.available()); + // Proxy Connect Successful + } else { + throw new SocketException("Fail to create Socket\nResponse was:" + proxyResponse); + } + } + + public Socket createSocket(InetSocketAddress bindpoint, InetSocketAddress endpoint) throws IOException { + Socket socket = delegateSocketFactory.createSocket(); + socket.bind(bindpoint); + socket.connect(endpoint, connectTimeout); + return socket; + } + + public int getConnectTimeout() { + return connectTimeout; + } + + public void setConnectTimeout(int connectTimeout) { + this.connectTimeout = connectTimeout; + } + +} diff --git a/src/main/java/net/schmizz/sshj/SocketClient.java b/src/main/java/net/schmizz/sshj/SocketClient.java index 57293b80..7002f851 100644 --- a/src/main/java/net/schmizz/sshj/SocketClient.java +++ b/src/main/java/net/schmizz/sshj/SocketClient.java @@ -15,7 +15,8 @@ */ package net.schmizz.sshj; -import javax.net.SocketFactory; +import com.hierynomus.sshj.socket.SocketFactory; + import java.io.IOException; import java.io.InputStream; import java.io.OutputStream; @@ -32,10 +33,7 @@ public abstract class SocketClient { private InputStream input; private OutputStream output; - private SocketFactory socketFactory = SocketFactory.getDefault(); - - private static final int DEFAULT_CONNECT_TIMEOUT = 0; - private int connectTimeout = DEFAULT_CONNECT_TIMEOUT; + private SocketFactory socketFactory = new SocketFactory(); private int timeout = 0; @@ -47,15 +45,13 @@ public abstract class SocketClient { public void connect(InetAddress host, int port) throws IOException { - socket = socketFactory.createSocket(); - socket.connect(new InetSocketAddress(host, port), connectTimeout); + socket = socketFactory.createSocket(new InetSocketAddress(host, port)); onConnect(); } public void connect(InetAddress host, int port, Proxy proxy) throws IOException { - socket = new Socket(proxy); - socket.connect(new InetSocketAddress(host, port), connectTimeout); + socket = socketFactory.createSocket(new InetSocketAddress(host, port), proxy); onConnect(); } @@ -74,9 +70,9 @@ public abstract class SocketClient { public void connect(InetAddress host, int port, InetAddress localAddr, int localPort) throws IOException { - socket = socketFactory.createSocket(); - socket.bind(new InetSocketAddress(localAddr, localPort)); - socket.connect(new InetSocketAddress(host, port), connectTimeout); + InetSocketAddress bindpoint = new InetSocketAddress(localAddr, localPort); + InetSocketAddress endpoint = new InetSocketAddress(host, port); + socket = socketFactory.createSocket(bindpoint, endpoint); onConnect(); } @@ -160,11 +156,11 @@ public abstract class SocketClient { } public int getConnectTimeout() { - return connectTimeout; + return socketFactory.getConnectTimeout(); } public void setConnectTimeout(int connectTimeout) { - this.connectTimeout = connectTimeout; + socketFactory.setConnectTimeout(connectTimeout); } public int getTimeout() {