Package com.sshtools.j2ssh.transport

Source Code of com.sshtools.j2ssh.transport.TransportProtocolCommon

/*
*  SSHTools - Java SSH2 API
*
*  Copyright (C) 2002-2003 Lee David Painter and Contributors.
*
*  Contributions made by:
*
*  Brett Smith
*  Richard Pernavas
*  Erwin Bolwidt
*
*  This program is free software; you can redistribute it and/or
*  modify it under the terms of the GNU General Public License
*  as published by the Free Software Foundation; either version 2
*  of the License, or (at your option) any later version.
*
*  This program is distributed in the hope that it will be useful,
*  but WITHOUT ANY WARRANTY; without even the implied warranty of
*  MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
*  GNU General Public License for more details.
*
*  You should have received a copy of the GNU General Public License
*  along with this program; if not, write to the Free Software
*  Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA  02110-1301, USA.
*/
package com.sshtools.j2ssh.transport;

import com.sshtools.j2ssh.SshException;
import com.sshtools.j2ssh.SshThread;
import com.sshtools.j2ssh.configuration.ConfigurationLoader;
import com.sshtools.j2ssh.configuration.SshConnectionProperties;
import com.sshtools.j2ssh.io.ByteArrayWriter;
import com.sshtools.j2ssh.net.TransportProvider;
import com.sshtools.j2ssh.transport.kex.KeyExchangeException;
import com.sshtools.j2ssh.transport.kex.SshKeyExchange;
import com.sshtools.j2ssh.transport.kex.SshKeyExchangeFactory;
import com.sshtools.j2ssh.util.Hash;

import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;

import java.io.IOException;
import java.io.InterruptedIOException;

import java.math.BigInteger;

import java.security.NoSuchAlgorithmException;

import java.util.ArrayList;
import java.util.HashMap;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Vector;


/**
*
*
* @author $author$
* @version $Revision: 1.2 $
*/
public abstract class TransportProtocolCommon
implements TransportProtocol, Runnable
{
    // Flag to keep on running
    //private boolean keepRunning = true;

    /**  */
    protected static Log log = LogFactory.getLog(TransportProtocolCommon.class);
    private static int nextThreadNo = 1;

    /**  */
    public final static int EOL_CRLF = 1;

    /**  */
    public final static int EOL_LF = 2;

    /**  */
    public static final String PROTOCOL_VERSION = "2.0";

    /**  */
    public static String SOFTWARE_VERSION_COMMENTS = "http://www.sshtools.com " +
        ConfigurationLoader.getVersionString("J2SSH", "j2ssh.properties");
    private int threadNo = nextThreadNo++;

    /**  */
    protected BigInteger k = null;

    /**  */
    protected Boolean completeOnNewKeys = new Boolean(false);

    /**  */
    protected HostKeyVerification hosts;

    /**  */
    protected Map kexs = new HashMap();
    private boolean sendIgnore = false;

    //protected Map transportMessages = new HashMap();

    /**  */
    protected SshConnectionProperties properties;

    /**  */
    protected SshMessageStore messageStore = new SshMessageStore();

    /**  */
    protected SshMsgKexInit clientKexInit = null;

    /**  */
    protected SshMsgKexInit serverKexInit = null;

    /**  */
    protected String clientIdent = null;

    /**  */
    protected String serverIdent = null;

    /**  */
    protected TransportProtocolAlgorithmSync algorithmsIn;

    /**  */
    protected TransportProtocolAlgorithmSync algorithmsOut;

    /**  */
    protected TransportProtocolState state = new TransportProtocolState();
    private byte[] exchangeHash = null;

    /**  */
    protected byte[] sessionIdentifier = null;

    /**  */
    protected byte[] hostKey = null;

    /**  */
    protected byte[] signature = null;
    private Vector eventHandlers = new Vector();

    // Storage of messages whilst in key exchange
    private List messageStack = new ArrayList();

    // Message notification registry
    private Map messageNotifications = new HashMap();

    // Key exchange lock for accessing the kex init messages
    private Object kexLock = new Object();

    // Object to synchronize key changing
    private Object keyLock = new Object();

    // The connected socket
    //private Socket socket;
    // The underlying transport provider
    TransportProvider provider;

    // The thread object
    private SshThread thread;
    private long kexTimeout = 3600000L;
    // private long kexTransferLimitKB = 100L; // 100 K
    private long kexTransferLimitKB = 1073741824L/1024L;
    // private long kexTransferLimit = 1073741824L;
    private long startTime = System.currentTimeMillis();
    private long transferredKB   = 0;
    private long lastTriggeredKB = 0;

    // The input stream for recieving data

    /**  */
    protected TransportProtocolInputStream sshIn;

    // The output stream for sending data

    /**  */
    protected TransportProtocolOutputStream sshOut;
    private int remoteEOL = EOL_CRLF;

    //private Map registeredMessages = new HashMap();
    private Vector messageStores = new Vector();

    /**
     * Creates a new TransportProtocolCommon object.
     */
    public TransportProtocolCommon() {
    }

    /**
     *
     *
     * @return
     */
    public int getConnectionId() {
        return threadNo;
    }

    /**
     *
     *
     * @return
     */
    public int getRemoteEOL() {
        return remoteEOL;
    }

    /**
     *
     *
     * @return
     */
    public TransportProtocolState getState() {
        return state;
    }

    /**
     *
     *
     * @return
     */
    public SshConnectionProperties getProperties() {
        return properties;
    }

    /**
     *
     */
    protected abstract void onDisconnect();

    /**
     *
     *
     * @param description
     */
    public void disconnect(String description) {
        if (log.isDebugEnabled()) {
            log.debug("Disconnect: " + description);
        }

        try {
            state.setValue(TransportProtocolState.DISCONNECTED);
            state.setDisconnectReason(description);

            // Send the disconnect message automatically
            sendDisconnect(SshMsgDisconnect.BY_APPLICATION, description);
        } catch (Exception e) {
            log.warn("Failed to send disconnect", e);
        }
    }

    /**
     *
     *
     * @param sendIgnore
     */
    public void setSendIgnore(boolean sendIgnore) {
        this.sendIgnore = sendIgnore;
    }

    /**
     *
     *
     * @param seconds
     *
     * @throws TransportProtocolException
     */
    public void setKexTimeout(long seconds) throws TransportProtocolException {
        if (seconds < 60) {
            throw new TransportProtocolException(
                "Keys can only be re-exchanged every minute or more");
        }

        kexTimeout = seconds * 1000;
    }

    /**
     *
     *
     * @param kilobytes
     *
     * @throws TransportProtocolException
     */
    public void setKexTransferLimit(long kilobytes)
        throws TransportProtocolException {
        if (kilobytes < 10) {
            throw new TransportProtocolException(
                "Keys can only be re-exchanged after every 10k of data, or more");
        }

        //kexTransferLimit = kilobytes * 1024;
        kexTransferLimitKB = kilobytes;
    }

    /*public InetSocketAddress getRemoteAddress() {
       return (InetSocketAddress)socket.getRemoteSocketAddress();
     }*/
    public long getOutgoingByteCount() {
        return sshOut.getNumBytesTransfered();
    }

    /**
     *
     *
     * @return
     */
    public long getIncomingByteCount() {
        return sshIn.getNumBytesTransfered();
    }

    /**
     *
     *
     * @param eventHandler
     */
    public void addEventHandler(TransportProtocolEventHandler eventHandler) {
        if (eventHandler != null) {
            eventHandlers.add(eventHandler);
        }
    }

    /**
     *
     *
     * @throws MessageAlreadyRegisteredException
     */
    public abstract void registerTransportMessages()
        throws MessageAlreadyRegisteredException;

    /**
     *
     *
     * @return
     */
    public byte[] getSessionIdentifier() {
        return (byte[]) sessionIdentifier.clone();
    }

    /**
     *
     */
    public void run() {
      try {
        state.setValue(TransportProtocolState.NEGOTIATING_PROTOCOL);
        log.info("Registering transport protocol messages with inputstream");
        algorithmsOut = new TransportProtocolAlgorithmSync();
        algorithmsIn = new TransportProtocolAlgorithmSync();
       
        // Create the input/output streams
        sshIn = new TransportProtocolInputStream(this,
            provider.getInputStream(), algorithmsIn);
        sshOut = new TransportProtocolOutputStream(provider.getOutputStream(),
            this, algorithmsOut);
       
        // Register the transport layer messages that this class will handle
        messageStore.registerMessage(SshMsgDisconnect.SSH_MSG_DISCONNECT,
            SshMsgDisconnect.class);
        messageStore.registerMessage(SshMsgIgnore.SSH_MSG_IGNORE,
            SshMsgIgnore.class);
        messageStore.registerMessage(SshMsgUnimplemented.SSH_MSG_UNIMPLEMENTED,
            SshMsgUnimplemented.class);
        messageStore.registerMessage(SshMsgDebug.SSH_MSG_DEBUG,
            SshMsgDebug.class);
        messageStore.registerMessage(SshMsgKexInit.SSH_MSG_KEX_INIT,
            SshMsgKexInit.class);
        messageStore.registerMessage(SshMsgNewKeys.SSH_MSG_NEWKEYS,
            SshMsgNewKeys.class);
        registerTransportMessages();
       
        List list = SshKeyExchangeFactory.getSupportedKeyExchanges();
        Iterator it = list.iterator();
       
        while (it.hasNext()) {
          String keyExchange = (String) it.next();
          SshKeyExchange kex = SshKeyExchangeFactory.newInstance(keyExchange);
          kex.init(this);
          kexs.put(keyExchange, kex);
        }
       
        // call abstract to initialise the local ident string
        setLocalIdent();
       
        // negotiate the protocol version
        negotiateVersion();
        startBinaryPacketProtocol();
      }
      catch (Throwable e) {
        if (e instanceof IOException) {
          state.setLastError((IOException) e);
        }
       
        if (state.getValue() != TransportProtocolState.DISCONNECTED) {
          log.error("The Transport Protocol thread failed", e);
         
          //log.info(e.getMessage());
          stop();
        }
      }
      finally {
        thread = null;
      }
     
      log.debug("The Transport Protocol has been stopped");
    }

    /**
     *
     *
     * @param msg
     * @param sender
     *
     * @throws IOException
     * @throws TransportProtocolException
     */
    public synchronized void sendMessage(SshMessage msg, Object sender)
        throws IOException {
        // Send a message, if were in key exchange then add it to
        // the list unless of course it is a transport protocol or key
        // exchange message
        if (log.isDebugEnabled()) {
            log.info("Sending " + msg.getMessageName());
        }

        int currentState = state.getValue();

        if (sender instanceof SshKeyExchange ||
                sender instanceof TransportProtocolCommon ||
                (currentState == TransportProtocolState.CONNECTED)) {
            sshOut.sendMessage(msg);

            if (currentState == TransportProtocolState.CONNECTED) {
                if (sendIgnore) {
                    byte[] count = new byte[1];
                    ConfigurationLoader.getRND().nextBytes(count);

                    byte[] rand = new byte[(count[0] & 0xFF) + 1];
                    ConfigurationLoader.getRND().nextBytes(rand);

                    SshMsgIgnore ignore = new SshMsgIgnore(new String(rand));

                    if (log.isDebugEnabled()) {
                        log.debug("Sending " + ignore.getMessageName());
                    }

                    sshOut.sendMessage(ignore);
                }
            }
        } else if (currentState == TransportProtocolState.PERFORMING_KEYEXCHANGE) {
            log.debug("Adding to message queue whilst in key exchange");

            synchronized (messageStack) {
                // Add this message to the end of the list
                messageStack.add(msg);
            }
        } else {
            throw new TransportProtocolException(
                "The transport protocol is disconnected");
        }
    }

    /**
     *
     *
     * @throws IOException
     */
    protected abstract void onStartTransportProtocol()
        throws IOException;

    /**
     *
     *
     * @param provider
     * @param properties
     *
     * @throws IOException
     */
    public void startTransportProtocol(TransportProvider provider,
        SshConnectionProperties properties) throws IOException {
        // Save the connected socket for later use
        this.provider = provider;
        this.properties = properties;

        // Start the transport layer message loop
        log.info("Starting transport protocol");
        thread = new SshThread(this, "Transport protocol", true);
        thread.start();
        onStartTransportProtocol();
    }

    /**
     *
     *
     * @return
     */
    public String getUnderlyingProviderDetail() {
        return provider.getProviderDetail();
    }

    /**
     *
     *
     * @param messageId
     * @param store
     *
     * @throws MessageNotRegisteredException
     */
    public void unregisterMessage(Integer messageId, SshMessageStore store)
        throws MessageNotRegisteredException {
        if (log.isDebugEnabled()) {
            log.debug("Unregistering message Id " + messageId.toString());
        }

        if (!messageNotifications.containsKey(messageId)) {
            throw new MessageNotRegisteredException(messageId);
        }

        SshMessageStore actual = (SshMessageStore) messageNotifications.get(messageId);

        if (!store.equals(actual)) {
            throw new MessageNotRegisteredException(messageId, store);
        }

        messageNotifications.remove(messageId);
    }

    /**
     *
     *
     * @return
     *
     * @throws AlgorithmNotAgreedException
     */
    protected abstract String getDecryptionAlgorithm()
        throws AlgorithmNotAgreedException;

    /**
     *
     *
     * @return
     *
     * @throws AlgorithmNotAgreedException
     */
    protected abstract String getEncryptionAlgorithm()
        throws AlgorithmNotAgreedException;

    /**
     *
     *
     * @return
     *
     * @throws AlgorithmNotAgreedException
     */
    protected abstract String getInputStreamCompAlgortihm()
        throws AlgorithmNotAgreedException;

    /**
     *
     *
     * @return
     *
     * @throws AlgorithmNotAgreedException
     */
    protected abstract String getInputStreamMacAlgorithm()
        throws AlgorithmNotAgreedException;

    /**
     *
     */
    protected abstract void setLocalIdent();

    /**
     *
     *
     * @return
     */
    public abstract String getLocalId();

    /**
     *
     *
     * @param msg
     */
    protected abstract void setLocalKexInit(SshMsgKexInit msg);

    /**
     *
     *
     * @return
     */
    protected abstract SshMsgKexInit getLocalKexInit();

    /**
     *
     *
     * @return
     *
     * @throws AlgorithmNotAgreedException
     */
    protected abstract String getOutputStreamCompAlgorithm()
        throws AlgorithmNotAgreedException;

    /**
     *
     *
     * @return
     *
     * @throws AlgorithmNotAgreedException
     */
    protected abstract String getOutputStreamMacAlgorithm()
        throws AlgorithmNotAgreedException;

    /**
     *
     *
     * @param ident
     */
    protected abstract void setRemoteIdent(String ident);

    /**
     *
     *
     * @return
     */
    public abstract String getRemoteId();

    /**
     *
     *
     * @param msg
     */
    protected abstract void setRemoteKexInit(SshMsgKexInit msg);

    /**
     *
     *
     * @return
     */
    protected abstract SshMsgKexInit getRemoteKexInit();

    /**
     *
     *
     * @param kex
     *
     * @throws IOException
     * @throws KeyExchangeException
     */
    protected abstract void performKeyExchange(SshKeyExchange kex)
        throws IOException, KeyExchangeException;

    /**
     *
     *
     * @return
     *
     * @throws AlgorithmNotAgreedException
     */
    protected String getKexAlgorithm() throws AlgorithmNotAgreedException {
        return determineAlgorithm(clientKexInit.getSupportedKex(),
            serverKexInit.getSupportedKex());
    }

    public boolean isConnected() {
        return (state.getValue() == TransportProtocolState.CONNECTED) ||
        (state.getValue() == TransportProtocolState.PERFORMING_KEYEXCHANGE);
    }

    /**
     *
     *
     * @throws IOException
     * @throws KeyExchangeException
     */
    protected void beginKeyExchange() throws IOException, KeyExchangeException {
        log.info("Starting key exchange");

        //state.setValue(TransportProtocolState.PERFORMING_KEYEXCHANGE);
        String kexAlgorithm = "";

        // We now have both kex inits, this is where client/server
        // implemtations take over so call abstract methods
        try {
            // Determine the key exchange algorithm
            kexAlgorithm = getKexAlgorithm();

            if (log.isDebugEnabled()) {
                log.debug("Key exchange algorithm: " + kexAlgorithm);
            }

            // Get an instance of the key exchange algortihm
            SshKeyExchange kex = (SshKeyExchange) kexs.get(kexAlgorithm);

            // Do the key exchange
            performKeyExchange(kex);

            // Record the output
            exchangeHash = kex.getExchangeHash();

            if (sessionIdentifier == null) {
                sessionIdentifier = new byte[exchangeHash.length];
                System.arraycopy(exchangeHash, 0, sessionIdentifier, 0,
                    sessionIdentifier.length);
                thread.setSessionId(sessionIdentifier);
            }

            hostKey = kex.getHostKey();
            signature = kex.getSignature();
            k = kex.getSecret();

            // Send new keys
            sendNewKeys();
            kex.reset();
        } catch (AlgorithmNotAgreedException e) {
            sendDisconnect(SshMsgDisconnect.KEY_EXCHANGE_FAILED,
                "No suitable key exchange algorithm was agreed");
            throw new KeyExchangeException(
                "No suitable key exchange algorithm could be agreed.");
        }
    }

    /**
     *
     *
     * @return
     *
     * @throws IOException
     */
    protected SshMsgKexInit createLocalKexInit() throws IOException {
        return new SshMsgKexInit(properties);
    }

    /**
     *
     */
    protected void onCorruptMac() {
        log.fatal("Corrupt Mac on Input");

        // Send a disconnect message
        sendDisconnect(SshMsgDisconnect.MAC_ERROR, "Corrupt Mac on input",
            new SshException("Corrupt Mac on Imput"));
    }

    /**
     *
     *
     * @param msg
     *
     * @throws IOException
     */
    protected abstract void onMessageReceived(SshMessage msg)
        throws IOException;

    /**
     *
     *
     * @param reason
     * @param description
     */
    protected void sendDisconnect(int reason, String description) {
        SshMsgDisconnect msg = new SshMsgDisconnect(reason, description, "");

        try {
            sendMessage(msg, this);
            stop();
        } catch (Exception e) {
            log.warn("Failed to send disconnect", e);
        }
    }

    /**
     *
     *
     * @param reason
     * @param description
     * @param error
     */
    protected void sendDisconnect(int reason, String description,
        IOException error) {
        state.setLastError(error);
        sendDisconnect(reason, description);
    }

    /**
     *
     *
     * @throws IOException
     */
    protected void sendKeyExchangeInit() throws IOException {
        setLocalKexInit(createLocalKexInit());
        sendMessage(getLocalKexInit(), this);
        state.setValue(TransportProtocolState.PERFORMING_KEYEXCHANGE);
    }

    /**
     *
     *
     * @throws IOException
     */
    protected void sendNewKeys() throws IOException {
        // Send new keys
        SshMsgNewKeys msg = new SshMsgNewKeys();
        sendMessage(msg, this);

        // Lock the outgoing algorithms so nothing else is sent untill
        // weve updated them with the new keys
        algorithmsOut.lock();

        // Do we need to hold the algorithmsOut lock during
        // the input message handling below? If not, then the
        // lock could be taken just before completeKeyExchange
        // or even moved into the completeKeyExchange method.
        // We would then not need the try-finally below (which
        // is needed for exceptions from eg the readMessage call).
        boolean hasReleasedLock = false;
        try {
            int[] filter = new int[1];
            filter[0] = SshMsgNewKeys.SSH_MSG_NEWKEYS;
            msg = (SshMsgNewKeys) readMessage(filter);
           
            if (log.isDebugEnabled()) {
                log.debug("Received " + msg.getMessageName());
            }
           
            // Release done in completeKeyExchange
            hasReleasedLock = true;
            completeKeyExchange();
        }
        finally {
            if( ! hasReleasedLock ) {
                algorithmsOut.release();
            }
        }
    }

    /**
     *
     *
     * @param encryptCSKey
     * @param encryptCSIV
     * @param encryptSCKey
     * @param encryptSCIV
     * @param macCSKey
     * @param macSCKey
     *
     * @throws AlgorithmNotAgreedException
     * @throws AlgorithmOperationException
     * @throws AlgorithmNotSupportedException
     * @throws AlgorithmInitializationException
     */
    protected abstract void setupNewKeys(byte[] encryptCSKey,
        byte[] encryptCSIV, byte[] encryptSCKey, byte[] encryptSCIV,
        byte[] macCSKey, byte[] macSCKey)
        throws AlgorithmNotAgreedException, AlgorithmOperationException,
            AlgorithmNotSupportedException, AlgorithmInitializationException;

    /**
     *
     *
     * @throws IOException
     * @throws TransportProtocolException
     */
    protected void completeKeyExchange() throws IOException {
        log.info("Completing key exchange");

        boolean hasReleasedLock = false;
        try {
            // Reset the state variables
            //completeOnNewKeys = new Boolean(false);
            log.debug("Making keys from key exchange output");

            // Make the keys
            byte[] encryptionKey = makeSshKey('C');
            byte[] encryptionIV = makeSshKey('A');
            byte[] decryptionKey = makeSshKey('D');
            byte[] decryptionIV = makeSshKey('B');
            byte[] sendMac = makeSshKey('E');
            byte[] receiveMac = makeSshKey('F');
            log.debug("Creating algorithm objects");
            setupNewKeys(encryptionKey, encryptionIV, decryptionKey,
                decryptionIV, sendMac, receiveMac);

            // Reset the key exchange
            clientKexInit = null;
            serverKexInit = null;

            //algorithmsIn.release();
            algorithmsOut.release();
            hasReleasedLock = true;

            /*
             *  Update our state, we can send all packets
             *
             */
            state.setValue(TransportProtocolState.CONNECTED);

            // Send any outstanding messages
            synchronized (messageStack) {
                Iterator it = messageStack.iterator();
                log.debug("Sending queued messages");

                while (it.hasNext()) {
                    SshMessage msg = (SshMessage) it.next();
                    sendMessage(msg, this);
                }

                messageStack.clear();
            }
        } catch (AlgorithmNotAgreedException anae) {
            sendDisconnect(SshMsgDisconnect.KEY_EXCHANGE_FAILED,
                "Algorithm not agreed");
            throw new TransportProtocolException(
                "The connection was disconnected because an algorithm could not be agreed");
        } catch (AlgorithmNotSupportedException anse) {
            sendDisconnect(SshMsgDisconnect.KEY_EXCHANGE_FAILED,
                "Application error");
            throw new TransportProtocolException(
                "The connection was disconnected because an algorithm class could not be loaded");
        } catch (AlgorithmOperationException aoe) {
            sendDisconnect(SshMsgDisconnect.KEY_EXCHANGE_FAILED,
                "Algorithm operation error");
            throw new TransportProtocolException(
                "The connection was disconnected because" +
                " of an algorithm operation error");
        } catch (AlgorithmInitializationException aie) {
            sendDisconnect(SshMsgDisconnect.KEY_EXCHANGE_FAILED,
                "Algorithm initialization error");
            throw new TransportProtocolException(
                "The connection was disconnected because" +
                " of an algorithm initialization error");
        }
        finally {
            if( ! hasReleasedLock ) {
                algorithmsOut.release();
            }
        }
    }

    /**
     *
     *
     * @return
     */
    protected List getEventHandlers() {
        return eventHandlers;
    }

    /**
     *
     *
     * @param clientAlgorithms
     * @param serverAlgorithms
     *
     * @return
     *
     * @throws AlgorithmNotAgreedException
     */
    protected String determineAlgorithm(List clientAlgorithms,
        List serverAlgorithms) throws AlgorithmNotAgreedException {
        if (log.isDebugEnabled()) {
            log.debug("Determine Algorithm");
            log.debug("Client Algorithms: " + clientAlgorithms.toString());
            log.debug("Server Algorithms: " + serverAlgorithms.toString());
        }

        String algorithmClient;
        String algorithmServer;
        Iterator itClient = clientAlgorithms.iterator();

        while (itClient.hasNext()) {
            algorithmClient = (String) itClient.next();

            Iterator itServer = serverAlgorithms.iterator();

            while (itServer.hasNext()) {
                algorithmServer = (String) itServer.next();

                if (algorithmClient.equals(algorithmServer)) {
                    log.debug("Returning " + algorithmClient);

                    return algorithmClient;
                }
            }
        }

        throw new AlgorithmNotAgreedException("Could not agree algorithm");
    }

    /**
     *
     *
     * @throws IOException
     */
    protected void startBinaryPacketProtocol() throws IOException {
        // Send our Kex Init
        sendKeyExchangeInit();

        SshMessage msg;

        // Perform a transport protocol message loop
        while (state.getValue() != TransportProtocolState.DISCONNECTED) {
            // Process incoming messages returning any transport protocol
            // messages to be handled here
            msg = processMessages();

            if (log.isDebugEnabled()) {
                log.debug("Received " + msg.getMessageName());
            }

            switch (msg.getMessageId()) {
            case SshMsgKexInit.SSH_MSG_KEX_INIT: {
                onMsgKexInit((SshMsgKexInit) msg);

                break;
            }

            case SshMsgDisconnect.SSH_MSG_DISCONNECT: {
                onMsgDisconnect((SshMsgDisconnect) msg);

                break;
            }

            case SshMsgIgnore.SSH_MSG_IGNORE: {
                onMsgIgnore((SshMsgIgnore) msg);

                break;
            }

            case SshMsgUnimplemented.SSH_MSG_UNIMPLEMENTED: {
                onMsgUnimplemented((SshMsgUnimplemented) msg);

                break;
            }

            case SshMsgDebug.SSH_MSG_DEBUG: {
                onMsgDebug((SshMsgDebug) msg);

                break;
            }

            default:
                onMessageReceived(msg);
            }
        }
    }

    /**
     *
     */
    protected final void stop() {
        onDisconnect();

        Iterator it = eventHandlers.iterator();
        TransportProtocolEventHandler eventHandler;

        while (it.hasNext()) {
            eventHandler = (TransportProtocolEventHandler) it.next();
            eventHandler.onDisconnect(this);
        }

        // Close the input/output streams
        //sshIn.close();
        if (messageStore != null) {
            messageStore.close();
        }

        // 05/01/2003 moiz change begin:
        // all close all the registerd messageStores
        SshMessageStore ms;

        for (it = messageStores.iterator(); (it != null) && it.hasNext();) {
            ms = (SshMessageStore) it.next();

            try {
                ms.close();
            } catch (Exception e) {
            }
        }

        messageStores.clear();

        // 05/01/2003 moizd change end:
        messageStore = null;

        try {
            provider.close();
        } catch (IOException ioe) {
        }
  state.setValue(TransportProtocolState.DISCONNECTED);
    }

    private byte[] makeSshKey(char chr) throws IOException {
        try {
            // Create the first 20 bytes of key data
            ByteArrayWriter keydata = new ByteArrayWriter();
            byte[] data = new byte[20];
            Hash hash = new Hash("SHA");

            // Put the dh k value
            hash.putBigInteger(k);

            // Put in the exchange hash
            hash.putBytes(exchangeHash);

            // Put in the character
            hash.putByte((byte) chr);

            // Put the exchange hash in again
            hash.putBytes(sessionIdentifier);

            // Create the fist 20 bytes
            data = hash.doFinal();
            keydata.write(data);

            // Now do the next 20
            hash.reset();

            // Put the dh k value in again
            hash.putBigInteger(k);

            // And the exchange hash
            hash.putBytes(exchangeHash);

            // Finally the first 20 bytes of data we created
            hash.putBytes(data);
            data = hash.doFinal();

            // Put it all together
            keydata.write(data);

            // Return it
            return keydata.toByteArray();
        } catch (NoSuchAlgorithmException nsae) {
            sendDisconnect(SshMsgDisconnect.KEY_EXCHANGE_FAILED,
                "Application error");
            throw new TransportProtocolException("SHA algorithm not supported");
        } catch (IOException ioe) {
            sendDisconnect(SshMsgDisconnect.KEY_EXCHANGE_FAILED,
                "Application error");
            throw new TransportProtocolException("Error writing key data");
        }
    }

    private void negotiateVersion() throws IOException {
        byte[] buf;
        int len;
        String remoteVer = "";
        log.info("Negotiating protocol version");
        log.debug("Local identification: " + getLocalId());

        // Get the local ident string by calling the abstract method, this
        // way the implementations set the correct variables for computing the
        // exchange hash
        String data = getLocalId() + "\r\n";

        // Send our version string
        provider.getOutputStream().write(data.getBytes());

        // Now wait for a reply and evaluate the ident string
        //buf = new byte[255];
        StringBuffer buffer = new StringBuffer();
        char ch;
        int MAX_BUFFER_LENGTH = 255;

        // Look for a string starting with "SSH-"
        while (!remoteVer.startsWith("SSH-") &&
                (buffer.length() < MAX_BUFFER_LENGTH)) {
            // Get the next string
            while (((ch = (char) provider.getInputStream().read()) != '\n') &&
                    (buffer.length() < MAX_BUFFER_LENGTH)) {
                buffer.append(ch);
            }

            // Set trimming off any EOL characters
            remoteVer = buffer.toString();

            // Guess the remote sides EOL by looking at the end of the ident string
            if (remoteVer.endsWith("\r")) {
                remoteEOL = EOL_CRLF;
            } else {
                remoteEOL = EOL_LF;
            }

            log.debug("EOL is guessed at " +
                ((remoteEOL == EOL_CRLF) ? "CR+LF" : "LF"));

            // Remove any \r
            remoteVer = remoteVer.trim();
        }

        // Get the index of the seperators
        int l = remoteVer.indexOf("-");
        int r = remoteVer.indexOf("-", l + 1);

        // Call abstract method so the implementations can set the
        // correct member variable
        setRemoteIdent(remoteVer.trim());

        if (log.isDebugEnabled()) {
            log.debug("Remote identification: " + getRemoteId());
        }

        // Get the version
        String remoteVersion = remoteVer.substring(l + 1, r);

        // Evaluate the version, we only support 2.0
        if (!(remoteVersion.equals("2.0") || (remoteVersion.equals("1.99")))) {
            log.fatal(
                "The remote computer does not support protocol version 2.0");
            throw new TransportProtocolException(
                "The protocol version of the remote computer is not supported!");
        }

        log.info("Protocol negotiation complete");
    }

    private void onMsgDebug(SshMsgDebug msg) {
        log.debug(msg.getMessage());
    }

    private void onMsgDisconnect(SshMsgDisconnect msg)
        throws IOException {
        log.info("The remote computer disconnected: " + msg.getDescription());
        state.setValue(TransportProtocolState.DISCONNECTED);
        state.setDisconnectReason(msg.getDescription());
        stop();
    }

    private void onMsgIgnore(SshMsgIgnore msg) {
        if (log.isDebugEnabled()) {
            log.debug("SSH_MSG_IGNORE with " +
                String.valueOf(msg.getData().length()) + " bytes of data");
        }
    }

    private void onMsgKexInit(SshMsgKexInit msg) throws IOException {
        log.debug("Received remote key exchange init message");
        log.debug(msg.toString());

        synchronized (kexLock) {
            setRemoteKexInit(msg);

            // As either party can initiate a key exchange then we
            // must check to see if we have sent our own
            if (state.getValue() != TransportProtocolState.PERFORMING_KEYEXCHANGE) {
                //if (getLocalKexInit() == null) {
                sendKeyExchangeInit();
            }

            //}
            beginKeyExchange();
        }
    }

    private void onMsgNewKeys(SshMsgNewKeys msg) throws IOException {
        // Determine whether we have completed our own
        log.debug("Received New Keys");
        //algorithmsIn.lock();

        synchronized (completeOnNewKeys) {
            if (completeOnNewKeys.booleanValue()) {
                // We need to take this lock since
                // it is released in completeKeyExchange.
                algorithmsOut.lock();
                completeKeyExchange();
            } else {
                completeOnNewKeys = new Boolean(true);
            }
        }
    }

    private void onMsgUnimplemented(SshMsgUnimplemented msg) {
        if (log.isDebugEnabled()) {
            log.debug("The message with sequence no " + msg.getSequenceNo() +
                " was reported as unimplemented by the remote end.");
        }
    }

    /**
     *
     *
     * @param filter
     *
     * @return
     *
     * @throws IOException
     */
    public SshMessage readMessage(int[] filter) throws IOException {
        byte[] msgdata = null;
        SshMessage msg;

        while (state.getValue() != TransportProtocolState.DISCONNECTED) {
            boolean hasmsg = false;

            while (!hasmsg) {
                msgdata = sshIn.readMessage();
                hasmsg = true;
            }

            Integer messageId = SshMessage.getMessageId(msgdata);

            // First check the filter
            for (int i = 0; i < filter.length; i++) {
                if (filter[i] == messageId.intValue()) {
                    if (messageStore.isRegisteredMessage(messageId)) {
                        return messageStore.createMessage(msgdata);
                    } else {
                        SshMessageStore ms = getMessageStore(messageId);
                        msg = ms.createMessage(msgdata);

                        if (log.isDebugEnabled()) {
                            log.debug("Processing " + msg.getMessageName());
                        }

                        return msg;
                    }
                }
            }

            if (messageStore.isRegisteredMessage(messageId)) {
                msg = messageStore.createMessage(msgdata);

                switch (messageId.intValue()) {
                case SshMsgDisconnect.SSH_MSG_DISCONNECT: {
                    onMsgDisconnect((SshMsgDisconnect) msg);

                    break;
                }

                case SshMsgIgnore.SSH_MSG_IGNORE: {
                    onMsgIgnore((SshMsgIgnore) msg);

                    break;
                }

                case SshMsgUnimplemented.SSH_MSG_UNIMPLEMENTED: {
                    onMsgUnimplemented((SshMsgUnimplemented) msg);

                    break;
                }

                case SshMsgDebug.SSH_MSG_DEBUG: {
                    onMsgDebug((SshMsgDebug) msg);

                    break;
                }

                default: // Exception not allowed
                    throw new IOException(
                        "Unexpected transport protocol message");
                }
            } else {
                throw new IOException("Unexpected message received");
            }
        }

        throw new IOException("The transport protocol disconnected");
    }

    /**
     *
     *
     * @return
     *
     * @throws IOException
     */
    protected SshMessage processMessages() throws IOException {
        byte[] msgdata = null;
        SshMessage msg;
        SshMessageStore ms;

        while (state.getValue() != TransportProtocolState.DISCONNECTED) {
            long currentTime = System.currentTimeMillis();

            transferredKB =  sshIn.getNumBytesTransfered()/1024
                          + sshOut.getNumBytesTransfered()/1024;
          
            long kbLimit = transferredKB - lastTriggeredKB;
           
            if (((currentTime - startTime) > kexTimeout) ||
                (kbLimit > kexTransferLimitKB) )
            {
                //    ((sshIn.getNumBytesTransfered() +
                //    sshOut.getNumBytesTransfered()) > kexTransferLimit)) {
              startTime     = currentTime;
              lastTriggeredKB = transferredKB;
              if (log.isDebugEnabled()) {
                log.info("rekey");
              }
              sendKeyExchangeInit();
            }
           
            boolean hasmsg = false;

            while (!hasmsg) {
                try {
                    msgdata = sshIn.readMessage();
                    hasmsg = true;
                } catch (InterruptedIOException ex /*SocketTimeoutException ex*/) {
                    log.info("Possible timeout on transport inputstream");

                    Iterator it = eventHandlers.iterator();
                    TransportProtocolEventHandler eventHandler;
                   
                    while (it.hasNext()) {
                        eventHandler = (TransportProtocolEventHandler) it.next();
                        eventHandler.onSocketTimeout(this /*,
                        provider.isConnected()*/);
                    }
                }
            }

            Integer messageId = SshMessage.getMessageId(msgdata);

            if (!messageStore.isRegisteredMessage(messageId)) {
                try {
                    ms = getMessageStore(messageId);
                    msg = ms.createMessage(msgdata);

                    if (log.isDebugEnabled()) {
                        log.info("Received " + msg.getMessageName());
                    }

                    ms.addMessage(msg);
                } catch (MessageNotRegisteredException mnre) {
                    log.info("Unimplemented message received " +
                        String.valueOf(messageId.intValue()));
                    msg = new SshMsgUnimplemented(sshIn.getSequenceNo());
                    sendMessage(msg, this);
                }
            } else {
                return messageStore.createMessage(msgdata);
            }
        }

        throw new IOException("The transport protocol has disconnected");
    }

    /**
     *
     *
     * @param store
     *
     * @throws MessageAlreadyRegisteredException
     */
    public void addMessageStore(SshMessageStore store)
        throws MessageAlreadyRegisteredException {
        messageStores.add(store);
    }

    private SshMessageStore getMessageStore(Integer messageId)
        throws MessageNotRegisteredException {
        SshMessageStore ms;

        for (Iterator it = messageStores.iterator();
                (it != null) && it.hasNext();) {
            ms = (SshMessageStore) it.next();

            if (ms.isRegisteredMessage(messageId)) {
                return ms;
            }
        }

        throw new MessageNotRegisteredException(messageId);
    }

    /**
     *
     *
     * @param ms
     */
    public void removeMessageStore(SshMessageStore ms) {
        messageStores.remove(ms);
    }
}
TOP

Related Classes of com.sshtools.j2ssh.transport.TransportProtocolCommon

TOP
Copyright © 2018 www.massapi.com. All rights reserved.
All source code are property of their respective owners. Java is a trademark of Sun Microsystems, Inc and owned by ORACLE Inc. Contact coftware#gmail.com.