Package ch.ethz.inf.vs.scandium

Source Code of ch.ethz.inf.vs.scandium.DTLSConnector$RetransmitTask

/*******************************************************************************
* Copyright (c) 2014, Institute for Pervasive Computing, ETH Zurich.
* All rights reserved.
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions
* are met:
* 1. Redistributions of source code must retain the above copyright
*    notice, this list of conditions and the following disclaimer.
* 2. Redistributions in binary form must reproduce the above copyright
*    notice, this list of conditions and the following disclaimer in the
*    documentation and/or other materials provided with the distribution.
* 3. Neither the name of the Institute nor the names of its contributors
*    may be used to endorse or promote products derived from this software
*    without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE INSTITUTE AND CONTRIBUTORS "AS IS" AND
* ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
* ARE DISCLAIMED.  IN NO EVENT SHALL THE INSTITUTE OR CONTRIBUTORS BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS
* OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION)
* HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT
* LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY
* OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF
* SUCH DAMAGE.
*
* This file is part of the Scandium (Sc) Security for Californium.
******************************************************************************/
package ch.ethz.inf.vs.scandium;

import java.io.File;
import java.io.IOException;
import java.net.DatagramPacket;
import java.net.DatagramSocket;
import java.net.InetSocketAddress;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.Timer;
import java.util.TimerTask;
import java.util.concurrent.ConcurrentHashMap;
import java.util.Map.Entry;
import java.util.logging.Level;
import java.util.logging.Logger;

import ch.ethz.inf.vs.elements.ConnectorBase;
import ch.ethz.inf.vs.elements.RawData;
import ch.ethz.inf.vs.scandium.dtls.AlertMessage;
import ch.ethz.inf.vs.scandium.dtls.AlertMessage.AlertDescription;
import ch.ethz.inf.vs.scandium.dtls.AlertMessage.AlertLevel;
import ch.ethz.inf.vs.scandium.dtls.ApplicationMessage;
import ch.ethz.inf.vs.scandium.dtls.ByteArrayUtils;
import ch.ethz.inf.vs.scandium.dtls.ClientHandshaker;
import ch.ethz.inf.vs.scandium.dtls.ClientHello;
import ch.ethz.inf.vs.scandium.dtls.ContentType;
import ch.ethz.inf.vs.scandium.dtls.DTLSFlight;
import ch.ethz.inf.vs.scandium.dtls.DTLSMessage;
import ch.ethz.inf.vs.scandium.dtls.DTLSSession;
import ch.ethz.inf.vs.scandium.dtls.FragmentedHandshakeMessage;
import ch.ethz.inf.vs.scandium.dtls.HandshakeException;
import ch.ethz.inf.vs.scandium.dtls.HandshakeMessage;
import ch.ethz.inf.vs.scandium.dtls.Handshaker;
import ch.ethz.inf.vs.scandium.dtls.Record;
import ch.ethz.inf.vs.scandium.dtls.ResumingClientHandshaker;
import ch.ethz.inf.vs.scandium.dtls.ResumingServerHandshaker;
import ch.ethz.inf.vs.scandium.dtls.ServerHandshaker;
import ch.ethz.inf.vs.scandium.dtls.ServerHello;
import ch.ethz.inf.vs.scandium.util.ScProperties;

public class DTLSConnector extends ConnectorBase {
 
  /*
   * Note: DTLSConnector can also implement the interface Connector instead of
   * extending ConnectorBase
   */
 
  private final static Logger LOGGER = Logger.getLogger(DTLSConnector.class.getCanonicalName());

  public static final String KEY_STORE_LOCATION = ScProperties.std.getProperty("KEY_STORE_LOCATION".replace("/", File.pathSeparator));
  public static final String TRUST_STORE_LOCATION = ScProperties.std.getProperty("TRUST_STORE_LOCATION".replace("/", File.pathSeparator));
 
  private int maxFragmentLength = ScProperties.std.getInt("MAX_FRAGMENT_LENGTH");
 
  /** The overhead for the record header (13 bytes) and the handshake header (12 bytes) is 25 bytes */
  private int maxPayloadSize = maxFragmentLength + 25;

  /** The initial timer value for retransmission; rfc6347, section: 4.2.4.1 */
  private int retransmission_timeout = ScProperties.std.getInt("RETRANSMISSION_TIMEOUT");
 
  /** Maximal number of retransmissions before the attempt to transmit a message is canceled */
  private int max_retransmit = ScProperties.std.getInt("MAX_RETRANSMIT");
 
  private final InetSocketAddress address;
 
  private DatagramSocket socket;
 
  /** The timer daemon to schedule retransmissions. */
  private Timer timer = new Timer(true); // run as daemon
 
  /** Storing sessions according to peer-addresses */
  private ConcurrentHashMap<String, DTLSSession> dtlsSessions = new ConcurrentHashMap<String, DTLSSession>();

  /** Storing handshakers according to peer-addresses. */
  private ConcurrentHashMap<String, Handshaker> handshakers = new ConcurrentHashMap<String, Handshaker>();

  /** Storing flights according to peer-addresses. */
  private ConcurrentHashMap<String, DTLSFlight> flights = new ConcurrentHashMap<String, DTLSFlight>();
 
  public DTLSConnector(InetSocketAddress address) {
    super(address);
    this.address = address;
  }
 
  /**
   * Close the DTLS session with all peers.
   */
  public void close() {
    for (DTLSSession session : dtlsSessions.values()) {
      this.close(session.getPeer());
    }
  }
 
  /**
   * Close the DTLS session with the given peer.
   *
   * @param peerAddress the remote endpoint of the session to close
   */
  public void close(InetSocketAddress peerAddress) {
    DTLSSession session = dtlsSessions.get(addressToKey(peerAddress));
   
    if (session != null) {
      DTLSMessage closeNotify = new AlertMessage(AlertLevel.WARNING, AlertDescription.CLOSE_NOTIFY);
     
      DTLSFlight flight = new DTLSFlight();
      flight.addMessage(new Record(ContentType.ALERT, session.getWriteEpoch(), session.getSequenceNumber(), closeNotify, session));
      flight.setRetransmissionNeeded(false);
     
      cancelPreviousFlight(peerAddress);
 
      flight.setPeerAddress(peerAddress);
      flight.setSession(session);
 
      LOGGER.fine("Sending CLOSE_NOTIFY to " + peerAddress.toString());
     
      sendFlight(flight);
    } else {
      LOGGER.warning("Session to close not found: " + peerAddress.toString());
    }
  }
 
  @Override
  public synchronized void start() throws IOException {
    socket = new DatagramSocket(address.getPort(), address.getAddress());
    super.start();
   
    LOGGER.info("DLTS connector listening on "+address);
  }
 
  @Override
  public synchronized void stop() {
    this.close();
    this.socket.close();
    super.stop();
  }
 
  // TODO: We should not return null
  @Override
  protected RawData receiveNext() throws Exception {
    byte[] buffer = new byte[maxPayloadSize];
    DatagramPacket packet = new DatagramPacket(buffer, buffer.length);
    socket.receive(packet);
   
    if (packet.getLength() == 0)
      return null;

    InetSocketAddress peerAddress = new InetSocketAddress(packet.getAddress(), packet.getPort());
    LOGGER.finest(" => find handshaker for key "+peerAddress.toString());
    DTLSSession session = dtlsSessions.get(addressToKey(peerAddress));
    Handshaker handshaker = handshakers.get(addressToKey(peerAddress));
    byte[] data = Arrays.copyOfRange(packet.getData(), packet.getOffset(), packet.getLength());

    try {
      List<Record> records = Record.fromByteArray(data);

      for (Record record : records) {
        record.setSession(session);

        RawData raw = null;

        ContentType contentType = record.getType();
        LOGGER.finest(" => contentType: "+contentType);
        DTLSFlight flight = null;
        switch (contentType) {
        case APPLICATION_DATA:
          if (session == null) {
            // There is no session available, so no application data
            // should be received, discard it
            LOGGER.info("Discarded unexpected application data message from " + peerAddress.toString());
            return null;
          }
          // at this point, the current handshaker is not needed
          // anymore, remove it
          handshakers.remove(addressToKey(peerAddress));

          ApplicationMessage applicationData = (ApplicationMessage) record.getFragment();
          raw = new RawData(applicationData.getData());
          break;

        case ALERT:
          AlertMessage alert = (AlertMessage) record.getFragment();
          switch (alert.getDescription()) {
          case CLOSE_NOTIFY:
            session.setActive(false);
           
            LOGGER.fine("Received CLOSE_NOTIFY from " + peerAddress.toString());
           
            // server must reply with CLOSE_NOTIFY
            if (!session.isClient()) {
              DTLSMessage closeNotify = new AlertMessage(AlertLevel.WARNING, AlertDescription.CLOSE_NOTIFY);
              flight = new DTLSFlight();
              flight.addMessage(new Record(ContentType.ALERT, session.getWriteEpoch(), session.getSequenceNumber(), closeNotify, session));
              flight.setRetransmissionNeeded(false);
            }
           
            if (dtlsSessions.remove(addressToKey(peerAddress))!=null) {
              LOGGER.info("Closed session with peer: " + peerAddress.toString());
            } else {
              LOGGER.warning("Session to close not found: " + peerAddress.toString());
            }
           
            break;
         
          // remote implementation might use any alert (e.g., against padding oracle attack)
          default:
            LOGGER.warning(alert.getDescription() + " with " + peerAddress.toString());

            // cleaning up
            cancelPreviousFlight(peerAddress);
            dtlsSessions.remove(addressToKey(peerAddress));
            handshakers.remove(addressToKey(peerAddress));
            break;
           
            //TODO somehow tell CoAP endpoint to cancel
          }
          break;
        case CHANGE_CIPHER_SPEC:
        case HANDSHAKE:
          LOGGER.finest(" => handshaker: "+handshaker);
          if (handshaker == null) {
            /*
             * A handshake message received, but no handshaker
             * available: this must mean that we either received
             * a HelloRequest (from server) or a ClientHello
             * (from client) => initialize appropriate
             * handshaker type
             */

            HandshakeMessage handshake = (HandshakeMessage) record.getFragment();

            switch (handshake.getMessageType()) {
            case HELLO_REQUEST:
              /*
               * Client side: server desires a re-handshake
               */
              if (session == null) {
                // create new session
                session = new DTLSSession(peerAddress, true);
                // store session according to peer address
                dtlsSessions.put(addressToKey(peerAddress), session);

                LOGGER.info("Created new session as client with peer: " + peerAddress.toString());
              }
              handshaker = new ClientHandshaker(peerAddress, null, session);
              handshakers.put(addressToKey(peerAddress), handshaker);
              LOGGER.finest("Stored re-handshaker: " + handshaker.toString() + " for " + peerAddress.toString());
              break;

            case CLIENT_HELLO:
              /*
               * Server side: server received a client hello:
               * check first if client wants to resume a
               * session (message must contain session
               * identifier) and then check if particular
               * session still available, otherwise conduct
               * full handshake with fresh session.
               */

              if (!(handshake instanceof FragmentedHandshakeMessage)) {
                // check if session identifier set
                ClientHello clientHello = (ClientHello) handshake;
                session = getSessionByIdentifier(clientHello.getSessionId().getSessionId());
              }
             
              if (session == null) {
                // create new session
                session = new DTLSSession(peerAddress, false);
                // store session according to peer address
                dtlsSessions.put(addressToKey(peerAddress), session);

                LOGGER.info("Created new session as server with peer: " + peerAddress.toString());
                handshaker = new ServerHandshaker(peerAddress, session);
              } else {
                handshaker = new ResumingServerHandshaker(peerAddress, session);
              }
              handshakers.put(addressToKey(peerAddress), handshaker);
              LOGGER.finest("Stored handshaker: " + handshaker.toString() + " for " + peerAddress.toString());
              break;

            default:
              LOGGER.severe("Received unexpected first handshake message (type="+handshake.getMessageType()+") from " + peerAddress.toString() + ":\n" + handshake.toString());
              break;
            }
          }
          flight = handshaker.processMessage(record);
          break;

        default:
          LOGGER.severe("Received unknown DTLS record from " + peerAddress.toString() + ":\n" + ByteArrayUtils.toHexString(data));
          break;
        }

        if (flight != null) {
          cancelPreviousFlight(peerAddress);

          flight.setPeerAddress(peerAddress);
          flight.setSession(session);

          if (flight.isRetransmissionNeeded()) {
            flights.put(addressToKey(peerAddress), flight);
            scheduleRetransmission(flight);
          }

          sendFlight(flight);
        }

        if (raw != null) {

          raw.setAddress(packet.getAddress());
          raw.setPort(packet.getPort());

          return raw;
        }
      }

    } catch (Exception e) {
      /*
       * If it is a known handshake failure, send the specific Alert,
       * otherwise the general Handshake_Failure Alert.
       */
      DTLSFlight flight = new DTLSFlight();
      flight.setRetransmissionNeeded(false);
      flight.setPeerAddress(peerAddress);
      flight.setSession(session);
     
      AlertMessage alert;
      if (e instanceof HandshakeException) {
        alert = ((HandshakeException) e).getAlert();
        LOGGER.severe("Handshake Exception (" + peerAddress.toString() + "): " + e.getMessage());
      } else {
        alert = new AlertMessage(AlertLevel.FATAL, AlertDescription.HANDSHAKE_FAILURE);
        LOGGER.log(Level.SEVERE, "Unknown Exception (" + peerAddress + ").", e);
      }

      LOGGER.log(Level.SEVERE, "Datagram which lead to exception (" + peerAddress + "): " + ByteArrayUtils.toHexString(data), e);
     
      if (session == null) {
        // if the first received message failed, no session has been set
        session = new DTLSSession(peerAddress, false);
      }
      cancelPreviousFlight(peerAddress);
     
      flight.addMessage(new Record(ContentType.ALERT, session.getWriteEpoch(), session.getSequenceNumber(), alert, session));
      sendFlight(flight);
    } // receive()
    return null;
  }

  @Override
  protected void sendNext(RawData message) throws Exception {
   
    InetSocketAddress peerAddress = message.getInetSocketAddress();
    LOGGER.fine("Sending message to " + peerAddress);

    DTLSSession session = dtlsSessions.get(addressToKey(peerAddress));
   
    /*
     * When the DTLS layer receives a message from an upper layer, there is
     * either a already a DTLS session available with the peer or a new
     * handshake must be executed. If a session is available and active, the
     * message will be encrypted and send to the peer, otherwise a short
     * handshake will be initiated.
     */
    Record encryptedMessage = null;
    Handshaker handshaker = null;

    if (session == null) {
      // no session with endpoint available, create new empty session,
      // start fresh handshake
      session = new DTLSSession(peerAddress, true);
      dtlsSessions.put(addressToKey(peerAddress), session);
      handshaker = new ClientHandshaker(peerAddress, message, session);

    } else {

      if (session.isActive()) {
        // session to peer is active, send encrypted message
        DTLSMessage fragment = new ApplicationMessage(message.getBytes());
        encryptedMessage = new Record(ContentType.APPLICATION_DATA, session.getWriteEpoch(), session.getSequenceNumber(), fragment, session);
       
      } else {
        // try resuming session
        handshaker = new ResumingClientHandshaker(peerAddress, message, session);
      }
    }
   
    DTLSFlight flight = new DTLSFlight();
    // the CoAP message can not be encrypted since no session with peer
    // available, start DTLS handshake protocol
    if (handshaker != null) {
      // get starting handshake message
      handshakers.put(addressToKey(peerAddress), handshaker);
      LOGGER.finest("Stored handshaker on send: " + handshaker.toString() + " for " + peerAddress.toString());

      flight = handshaker.getStartHandshakeMessage();
      flights.put(addressToKey(peerAddress), flight);
      scheduleRetransmission(flight);
    }
   
    // the CoAP message has been encrypted and can be sent to the peer
    if (encryptedMessage != null) {
      flight.addMessage(encryptedMessage);
    }
   
    flight.setPeerAddress(peerAddress);
    flight.setSession(session);
    sendFlight(flight);
  }
 
  /**
   * Searches through all stored sessions and returns that session which
   * matches the session identifier or <code>null</code> if no such session
   * available. This method is used when the server receives a
   * {@link ClientHello} containing a session identifier indicating that the
   * client wants to resume a previous session. If a matching session is
   * found, the server will resume the session with a abbreviated handshake,
   * otherwise a full handshake (with new session identifier in
   * {@link ServerHello}) is conducted.
   *
   * @param sessionID
   *            the client's session identifier.
   * @return the session which matches the session identifier or
   *         <code>null</code> if no such session exists.
   */
  private DTLSSession getSessionByIdentifier(byte[] sessionID) {
    if (sessionID == null) {
      return null;
    }
   
    for (Entry<String, DTLSSession> entry : dtlsSessions.entrySet()) {
      // FIXME session identifiers may not be set, when the handshake failed after the initial message
      // these sessions must be deleted when this happens
      try {
        byte[] id = entry.getValue().getSessionIdentifier().getSessionId();
        if (Arrays.equals(sessionID, id)) {
          return entry.getValue();
        }
      } catch (Exception e) {
        continue;
      }
    }
   
    for (DTLSSession session:dtlsSessions.values()) {
      try {
        byte[] id = session.getSessionIdentifier().getSessionId();
        if (Arrays.equals(sessionID, id)) {
          return session;
        }
      } catch (Exception e) {
        continue;
      }
    }
   
    return null;
  }
 
  private void sendFlight(DTLSFlight flight) {
    byte[] payload = new byte[] {};
   
    // put as many records into one datagram as allowed by the block size
    List<DatagramPacket> datagrams = new ArrayList<DatagramPacket>();

    for (Record record : flight.getMessages()) {
      if (flight.getTries() > 0) {
        // adjust the record sequence number
        int epoch = record.getEpoch();
        record.setSequenceNumber(flight.getSession().getSequenceNumber(epoch));
      }
     
      byte[] recordBytes = record.toByteArray();
      if (payload.length + recordBytes.length > maxPayloadSize) {
        // can't add the next record, send current payload as datagram
        DatagramPacket datagram = new DatagramPacket(payload, payload.length, flight.getPeerAddress().getAddress(), flight.getPeerAddress().getPort());
        datagrams.add(datagram);
        payload = new byte[] {};
      }

      // retrieve payload
      payload = ByteArrayUtils.concatenate(payload, recordBytes);
    }
    DatagramPacket datagram = new DatagramPacket(payload, payload.length, flight.getPeerAddress().getAddress(), flight.getPeerAddress().getPort());
    datagrams.add(datagram);

    // send it over the UDP socket
    try {
      for (DatagramPacket datagramPacket : datagrams) {
        socket.send(datagramPacket);
      }
     
    } catch (IOException e) {
      LOGGER.log(Level.SEVERE, "Could not send the datagram", e);
    }
  }
 
  private void handleTimeout(DTLSFlight flight) {

    // set DTLS retransmission maximum
    final int max = max_retransmit;

    // check if limit of retransmissions reached
    if (flight.getTries() < max) {

      flight.incrementTries();

      sendFlight(flight);

      // schedule next retransmission
      scheduleRetransmission(flight);

    } else {
      LOGGER.fine("Maximum retransmissions reached.");
    }
  }

  private void scheduleRetransmission(DTLSFlight flight) {

    // cancel existing schedule (if any)
    if (flight.getRetransmitTask() != null) {
      flight.getRetransmitTask().cancel();
    }
   
    if (flight.isRetransmissionNeeded()) {
      // create new retransmission task
      flight.setRetransmitTask(new RetransmitTask(flight));
 
      // calculate timeout using exponential back-off
      if (flight.getTimeout() == 0) {
        // use initial timeout
        flight.setTimeout(retransmission_timeout);
      } else {
        // double timeout
        flight.incrementTimeout();
      }
 
      // schedule retransmission task
      timer.schedule(flight.getRetransmitTask(), flight.getTimeout());
    }
  }
 
  /**
   * Cancels the retransmission timer of the previous flight (if available).
   *
   * @param peerAddress
   *            the peer's address.
   */
  private void cancelPreviousFlight(InetSocketAddress peerAddress) {
    DTLSFlight previousFlight = flights.get(addressToKey(peerAddress));
    if (previousFlight != null) {
      previousFlight.getRetransmitTask().cancel();
      previousFlight.setRetransmitTask(null);
      flights.remove(addressToKey(peerAddress));
    }
  }

  @Override
  public String getName() {
    return "DTLS";
  }

  public InetSocketAddress getAddress() {
    if (socket == null) return getLocalAddr();
    else return new InetSocketAddress(socket.getLocalAddress(), socket.getLocalPort());
  }
 
  private class RetransmitTask extends TimerTask {

    private DTLSFlight flight;

    RetransmitTask(DTLSFlight flight) {
      this.flight = flight;
    }

    @Override
    public void run() {
      handleTimeout(flight);
    }
  }
 
  private String addressToKey(InetSocketAddress address) {
    return address.toString().split("/")[1];
  }

}
TOP

Related Classes of ch.ethz.inf.vs.scandium.DTLSConnector$RetransmitTask

TOP
Copyright © 2018 www.massapi.com. All rights reserved.
All source code are property of their respective owners. Java is a trademark of Sun Microsystems, Inc and owned by ORACLE Inc. Contact coftware#gmail.com.