Refactored out duplicate code.

This commit is contained in:
Jeroen van Erp
2019-05-08 13:44:04 +02:00
parent 0e784dd171
commit c2b9c0266d
6 changed files with 100 additions and 75 deletions

View File

@@ -19,11 +19,7 @@ import net.schmizz.sshj.common.*;
import net.schmizz.sshj.connection.Connection; import net.schmizz.sshj.connection.Connection;
import net.schmizz.sshj.connection.ConnectionException; import net.schmizz.sshj.connection.ConnectionException;
import net.schmizz.sshj.connection.ConnectionImpl; import net.schmizz.sshj.connection.ConnectionImpl;
import net.schmizz.sshj.connection.channel.direct.DirectConnection; import net.schmizz.sshj.connection.channel.direct.*;
import net.schmizz.sshj.connection.channel.direct.LocalPortForwarder;
import net.schmizz.sshj.connection.channel.direct.Session;
import net.schmizz.sshj.connection.channel.direct.SessionChannel;
import net.schmizz.sshj.connection.channel.direct.SessionFactory;
import net.schmizz.sshj.connection.channel.forwarded.ConnectListener; import net.schmizz.sshj.connection.channel.forwarded.ConnectListener;
import net.schmizz.sshj.connection.channel.forwarded.RemotePortForwarder; import net.schmizz.sshj.connection.channel.forwarded.RemotePortForwarder;
import net.schmizz.sshj.connection.channel.forwarded.RemotePortForwarder.ForwardedTCPIPChannel; import net.schmizz.sshj.connection.channel.forwarded.RemotePortForwarder.ForwardedTCPIPChannel;
@@ -665,7 +661,7 @@ public class SSHClient
* *
* @return a {@link LocalPortForwarder} * @return a {@link LocalPortForwarder}
*/ */
public LocalPortForwarder newLocalPortForwarder(LocalPortForwarder.Parameters parameters, public LocalPortForwarder newLocalPortForwarder(Parameters parameters,
ServerSocket serverSocket) { ServerSocket serverSocket) {
LocalPortForwarder forwarder = new LocalPortForwarder(conn, parameters, serverSocket, loggerFactory); LocalPortForwarder forwarder = new LocalPortForwarder(conn, parameters, serverSocket, loggerFactory);
forwarders.add(forwarder); forwarders.add(forwarder);

View File

@@ -15,33 +15,22 @@
*/ */
package net.schmizz.sshj.connection.channel.direct; package net.schmizz.sshj.connection.channel.direct;
import net.schmizz.sshj.common.SSHPacket;
import net.schmizz.sshj.connection.Connection; import net.schmizz.sshj.connection.Connection;
/** A channel for creating a direct TCP/IP connection from the server to a remote address. */ /** A channel for creating a direct TCP/IP connection from the server to a remote address. */
public class DirectConnection extends AbstractDirectChannel { public class DirectConnection extends DirectTCPIPChannel {
private final String remoteHost; public static final String LOCALHOST = "localhost";
private final int remotePort; public static final int LOCALPORT = 65536;
public DirectConnection(Connection conn, String remoteHost, int remotePort) { public DirectConnection(Connection conn, String remoteHost, int remotePort) {
super(conn, "direct-tcpip"); super(conn, new Parameters(LOCALHOST, LOCALPORT, remoteHost, remotePort));
this.remoteHost = remoteHost;
this.remotePort = remotePort;
}
@Override protected SSHPacket buildOpenReq() {
return super.buildOpenReq()
.putString(getRemoteHost())
.putUInt32(getRemotePort())
.putString("localhost")
.putUInt32(65536); // it looks like OpenSSH uses this value in stdio-forward
} }
public String getRemoteHost() { public String getRemoteHost() {
return remoteHost; return parameters.getRemoteHost();
} }
public int getRemotePort() { public int getRemotePort() {
return remotePort; return parameters.getRemotePort();
} }
} }

View File

@@ -0,0 +1,37 @@
/*
* Copyright (C)2009 - SSHJ Contributors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package net.schmizz.sshj.connection.channel.direct;
import net.schmizz.sshj.common.SSHPacket;
import net.schmizz.sshj.connection.Connection;
public class DirectTCPIPChannel extends AbstractDirectChannel {
protected final Parameters parameters;
protected DirectTCPIPChannel(Connection conn, Parameters parameters) {
super(conn, "direct-tcpip");
this.parameters = parameters;
}
@Override
protected SSHPacket buildOpenReq() {
return super.buildOpenReq()
.putString(parameters.getRemoteHost())
.putUInt32(parameters.getRemotePort())
.putString(parameters.getLocalHost())
.putUInt32(parameters.getLocalPort());
}
}

View File

@@ -18,7 +18,6 @@ package net.schmizz.sshj.connection.channel.direct;
import net.schmizz.concurrent.Event; import net.schmizz.concurrent.Event;
import net.schmizz.sshj.common.IOUtils; import net.schmizz.sshj.common.IOUtils;
import net.schmizz.sshj.common.LoggerFactory; import net.schmizz.sshj.common.LoggerFactory;
import net.schmizz.sshj.common.SSHPacket;
import net.schmizz.sshj.common.StreamCopier; import net.schmizz.sshj.common.StreamCopier;
import net.schmizz.sshj.connection.Connection; import net.schmizz.sshj.connection.Connection;
import net.schmizz.sshj.connection.channel.SocketStreamCopyMonitor; import net.schmizz.sshj.connection.channel.SocketStreamCopyMonitor;
@@ -34,48 +33,14 @@ import static com.hierynomus.sshj.backport.Sockets.asCloseable;
public class LocalPortForwarder { public class LocalPortForwarder {
public static class Parameters { public static class ForwardedChannel
extends DirectTCPIPChannel {
private final String localHost;
private final int localPort;
private final String remoteHost;
private final int remotePort;
public Parameters(String localHost, int localPort, String remoteHost, int remotePort) {
this.localHost = localHost;
this.localPort = localPort;
this.remoteHost = remoteHost;
this.remotePort = remotePort;
}
public String getRemoteHost() {
return remoteHost;
}
public int getRemotePort() {
return remotePort;
}
public String getLocalHost() {
return localHost;
}
public int getLocalPort() {
return localPort;
}
}
public static class DirectTCPIPChannel
extends AbstractDirectChannel {
protected final Socket socket; protected final Socket socket;
protected final Parameters parameters;
public DirectTCPIPChannel(Connection conn, Socket socket, Parameters parameters) { public ForwardedChannel(Connection conn, Socket socket, Parameters parameters) {
super(conn, "direct-tcpip"); super(conn, parameters);
this.socket = socket; this.socket = socket;
this.parameters = parameters;
} }
protected void start() protected void start()
@@ -90,16 +55,6 @@ public class LocalPortForwarder {
.spawnDaemon("chan2soc"); .spawnDaemon("chan2soc");
SocketStreamCopyMonitor.monitor(5, TimeUnit.SECONDS, soc2chan, chan2soc, this, socket); SocketStreamCopyMonitor.monitor(5, TimeUnit.SECONDS, soc2chan, chan2soc, this, socket);
} }
@Override
protected SSHPacket buildOpenReq() {
return super.buildOpenReq()
.putString(parameters.getRemoteHost())
.putUInt32(parameters.getRemotePort())
.putString(parameters.getLocalHost())
.putUInt32(parameters.getLocalPort());
}
} }
private final LoggerFactory loggerFactory; private final LoggerFactory loggerFactory;
@@ -118,7 +73,7 @@ public class LocalPortForwarder {
} }
private void startChannel(Socket socket) throws IOException { private void startChannel(Socket socket) throws IOException {
DirectTCPIPChannel chan = new DirectTCPIPChannel(conn, socket, parameters); ForwardedChannel chan = new ForwardedChannel(conn, socket, parameters);
try { try {
chan.open(); chan.open();
chan.start(); chan.start();

View File

@@ -0,0 +1,48 @@
/*
* Copyright (C)2009 - SSHJ Contributors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package net.schmizz.sshj.connection.channel.direct;
public class Parameters {
private final String localHost;
private final int localPort;
private final String remoteHost;
private final int remotePort;
public Parameters(String localHost, int localPort, String remoteHost, int remotePort) {
this.localHost = localHost;
this.localPort = localPort;
this.remoteHost = remoteHost;
this.remotePort = remotePort;
}
public String getRemoteHost() {
return remoteHost;
}
public int getRemotePort() {
return remotePort;
}
public String getLocalHost() {
return localHost;
}
public int getLocalPort() {
return localPort;
}
}

View File

@@ -16,7 +16,7 @@
package com.hierynomus.sshj.connection.channel.direct package com.hierynomus.sshj.connection.channel.direct
import com.hierynomus.sshj.test.SshFixture import com.hierynomus.sshj.test.SshFixture
import net.schmizz.sshj.connection.channel.direct.LocalPortForwarder import net.schmizz.sshj.connection.channel.direct.Parameters
import org.junit.Rule import org.junit.Rule
import spock.lang.Specification import spock.lang.Specification
import spock.util.concurrent.PollingConditions import spock.util.concurrent.PollingConditions
@@ -33,7 +33,7 @@ class LocalPortForwarderSpec extends Specification {
def client = tunnelFixture.setupConnectedDefaultClient() def client = tunnelFixture.setupConnectedDefaultClient()
client.authPassword("test", "test") client.authPassword("test", "test")
def socket = new ServerSocket(0) def socket = new ServerSocket(0)
def lpf = client.newLocalPortForwarder(new LocalPortForwarder.Parameters("localhost", socket.getLocalPort(), "localhost", realServer.server.port), socket) def lpf = client.newLocalPortForwarder(new Parameters("localhost", socket.getLocalPort(), "localhost", realServer.server.port), socket)
def thread = new Thread(new Runnable() { def thread = new Thread(new Runnable() {
@Override @Override
void run() { void run() {