Refactored out duplicate code.

This commit is contained in:
Jeroen van Erp
2019-05-08 13:44:04 +02:00
parent 0e784dd171
commit c2b9c0266d
6 changed files with 100 additions and 75 deletions

View File

@@ -19,11 +19,7 @@ import net.schmizz.sshj.common.*;
import net.schmizz.sshj.connection.Connection;
import net.schmizz.sshj.connection.ConnectionException;
import net.schmizz.sshj.connection.ConnectionImpl;
import net.schmizz.sshj.connection.channel.direct.DirectConnection;
import net.schmizz.sshj.connection.channel.direct.LocalPortForwarder;
import net.schmizz.sshj.connection.channel.direct.Session;
import net.schmizz.sshj.connection.channel.direct.SessionChannel;
import net.schmizz.sshj.connection.channel.direct.SessionFactory;
import net.schmizz.sshj.connection.channel.direct.*;
import net.schmizz.sshj.connection.channel.forwarded.ConnectListener;
import net.schmizz.sshj.connection.channel.forwarded.RemotePortForwarder;
import net.schmizz.sshj.connection.channel.forwarded.RemotePortForwarder.ForwardedTCPIPChannel;
@@ -665,7 +661,7 @@ public class SSHClient
*
* @return a {@link LocalPortForwarder}
*/
public LocalPortForwarder newLocalPortForwarder(LocalPortForwarder.Parameters parameters,
public LocalPortForwarder newLocalPortForwarder(Parameters parameters,
ServerSocket serverSocket) {
LocalPortForwarder forwarder = new LocalPortForwarder(conn, parameters, serverSocket, loggerFactory);
forwarders.add(forwarder);

View File

@@ -15,33 +15,22 @@
*/
package net.schmizz.sshj.connection.channel.direct;
import net.schmizz.sshj.common.SSHPacket;
import net.schmizz.sshj.connection.Connection;
/** A channel for creating a direct TCP/IP connection from the server to a remote address. */
public class DirectConnection extends AbstractDirectChannel {
private final String remoteHost;
private final int remotePort;
public class DirectConnection extends DirectTCPIPChannel {
public static final String LOCALHOST = "localhost";
public static final int LOCALPORT = 65536;
public DirectConnection(Connection conn, String remoteHost, int remotePort) {
super(conn, "direct-tcpip");
this.remoteHost = remoteHost;
this.remotePort = remotePort;
}
@Override protected SSHPacket buildOpenReq() {
return super.buildOpenReq()
.putString(getRemoteHost())
.putUInt32(getRemotePort())
.putString("localhost")
.putUInt32(65536); // it looks like OpenSSH uses this value in stdio-forward
super(conn, new Parameters(LOCALHOST, LOCALPORT, remoteHost, remotePort));
}
public String getRemoteHost() {
return remoteHost;
return parameters.getRemoteHost();
}
public int getRemotePort() {
return remotePort;
return parameters.getRemotePort();
}
}

View File

@@ -0,0 +1,37 @@
/*
* Copyright (C)2009 - SSHJ Contributors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package net.schmizz.sshj.connection.channel.direct;
import net.schmizz.sshj.common.SSHPacket;
import net.schmizz.sshj.connection.Connection;
public class DirectTCPIPChannel extends AbstractDirectChannel {
protected final Parameters parameters;
protected DirectTCPIPChannel(Connection conn, Parameters parameters) {
super(conn, "direct-tcpip");
this.parameters = parameters;
}
@Override
protected SSHPacket buildOpenReq() {
return super.buildOpenReq()
.putString(parameters.getRemoteHost())
.putUInt32(parameters.getRemotePort())
.putString(parameters.getLocalHost())
.putUInt32(parameters.getLocalPort());
}
}

View File

@@ -18,7 +18,6 @@ package net.schmizz.sshj.connection.channel.direct;
import net.schmizz.concurrent.Event;
import net.schmizz.sshj.common.IOUtils;
import net.schmizz.sshj.common.LoggerFactory;
import net.schmizz.sshj.common.SSHPacket;
import net.schmizz.sshj.common.StreamCopier;
import net.schmizz.sshj.connection.Connection;
import net.schmizz.sshj.connection.channel.SocketStreamCopyMonitor;
@@ -34,48 +33,14 @@ import static com.hierynomus.sshj.backport.Sockets.asCloseable;
public class LocalPortForwarder {
public static class Parameters {
private final String localHost;
private final int localPort;
private final String remoteHost;
private final int remotePort;
public Parameters(String localHost, int localPort, String remoteHost, int remotePort) {
this.localHost = localHost;
this.localPort = localPort;
this.remoteHost = remoteHost;
this.remotePort = remotePort;
}
public String getRemoteHost() {
return remoteHost;
}
public int getRemotePort() {
return remotePort;
}
public String getLocalHost() {
return localHost;
}
public int getLocalPort() {
return localPort;
}
}
public static class DirectTCPIPChannel
extends AbstractDirectChannel {
public static class ForwardedChannel
extends DirectTCPIPChannel {
protected final Socket socket;
protected final Parameters parameters;
public DirectTCPIPChannel(Connection conn, Socket socket, Parameters parameters) {
super(conn, "direct-tcpip");
public ForwardedChannel(Connection conn, Socket socket, Parameters parameters) {
super(conn, parameters);
this.socket = socket;
this.parameters = parameters;
}
protected void start()
@@ -90,16 +55,6 @@ public class LocalPortForwarder {
.spawnDaemon("chan2soc");
SocketStreamCopyMonitor.monitor(5, TimeUnit.SECONDS, soc2chan, chan2soc, this, socket);
}
@Override
protected SSHPacket buildOpenReq() {
return super.buildOpenReq()
.putString(parameters.getRemoteHost())
.putUInt32(parameters.getRemotePort())
.putString(parameters.getLocalHost())
.putUInt32(parameters.getLocalPort());
}
}
private final LoggerFactory loggerFactory;
@@ -118,7 +73,7 @@ public class LocalPortForwarder {
}
private void startChannel(Socket socket) throws IOException {
DirectTCPIPChannel chan = new DirectTCPIPChannel(conn, socket, parameters);
ForwardedChannel chan = new ForwardedChannel(conn, socket, parameters);
try {
chan.open();
chan.start();

View File

@@ -0,0 +1,48 @@
/*
* Copyright (C)2009 - SSHJ Contributors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package net.schmizz.sshj.connection.channel.direct;
public class Parameters {
private final String localHost;
private final int localPort;
private final String remoteHost;
private final int remotePort;
public Parameters(String localHost, int localPort, String remoteHost, int remotePort) {
this.localHost = localHost;
this.localPort = localPort;
this.remoteHost = remoteHost;
this.remotePort = remotePort;
}
public String getRemoteHost() {
return remoteHost;
}
public int getRemotePort() {
return remotePort;
}
public String getLocalHost() {
return localHost;
}
public int getLocalPort() {
return localPort;
}
}

View File

@@ -16,7 +16,7 @@
package com.hierynomus.sshj.connection.channel.direct
import com.hierynomus.sshj.test.SshFixture
import net.schmizz.sshj.connection.channel.direct.LocalPortForwarder
import net.schmizz.sshj.connection.channel.direct.Parameters
import org.junit.Rule
import spock.lang.Specification
import spock.util.concurrent.PollingConditions
@@ -33,7 +33,7 @@ class LocalPortForwarderSpec extends Specification {
def client = tunnelFixture.setupConnectedDefaultClient()
client.authPassword("test", "test")
def socket = new ServerSocket(0)
def lpf = client.newLocalPortForwarder(new LocalPortForwarder.Parameters("localhost", socket.getLocalPort(), "localhost", realServer.server.port), socket)
def lpf = client.newLocalPortForwarder(new Parameters("localhost", socket.getLocalPort(), "localhost", realServer.server.port), socket)
def thread = new Thread(new Runnable() {
@Override
void run() {