// Package org.ardverk.dht.io
//
// Source code of org.ardverk.dht.io.MessageDispatcher$MessageEntity

/*
* Copyright 2009-2012 Roger Kapsi
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
*   http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.ardverk.dht.io;

import java.io.Closeable;
import java.io.IOException;
import java.net.SocketAddress;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.concurrent.CopyOnWriteArrayList;
import java.util.concurrent.ScheduledExecutorService;
import java.util.concurrent.ScheduledFuture;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicBoolean;

import org.ardverk.collection.FixedSizeHashSet;
import org.ardverk.concurrent.EventUtils;
import org.ardverk.concurrent.ExecutorUtils;
import org.ardverk.dht.KUID;
import org.ardverk.dht.io.transport.Endpoint;
import org.ardverk.dht.io.transport.Transport;
import org.ardverk.dht.io.transport.TransportCallback;
import org.ardverk.dht.message.Message;
import org.ardverk.dht.message.MessageFactory;
import org.ardverk.dht.message.MessageId;
import org.ardverk.dht.message.RequestMessage;
import org.ardverk.dht.message.ResponseMessage;
import org.ardverk.dht.routing.Contact;
import org.ardverk.io.IoUtils;
import org.ardverk.lang.Bindable;
import org.ardverk.lang.BindableUtils;
import org.ardverk.lang.Precoditions;
import org.ardverk.lang.TimeStamp;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import com.google.inject.ImplementedBy;


/**
* The {@link MessageDispatcher} is responsible for sending messages over
* a given {@link Transport} and keeping track of the messages.
*/
@ImplementedBy(DefaultMessageDispatcher.class)
public abstract class MessageDispatcher
    implements Bindable<Transport>, Closeable {
 
  /** Logger of this class; used when a timeout handler fails with an {@link IOException}. */
  private static final Logger LOG
    = LoggerFactory.getLogger(MessageDispatcher.class);
 
  /**
   * The default {@link ScheduledExecutorService}. It is shared by all
   * instances created through {@link #MessageDispatcher(MessageFactory)}.
   */
  private static final ScheduledExecutorService EXECUTOR
    = ExecutorUtils.newSingleThreadScheduledExecutor(
      "MessageDispatcherThread");
 
  /**
   * The {@link TransportCallback} that is passed to the {@link Transport}
   * in {@link #bind(Transport)}. Every callback method forwards to the
   * method of the same name on the enclosing {@link MessageDispatcher}.
   */
  private final TransportCallback callback = new TransportCallback() {
    @Override
    public void messageSent(KUID contactId, Message message) {
      MessageDispatcher.this.messageSent(contactId, message);
    }
   
    @Override
    public ResponseMessage handleRequest(RequestMessage request)
        throws IOException {
      return MessageDispatcher.this.handleRequest(request);
    }

    @Override
    public boolean handleResponse(ResponseMessage response)
        throws IOException {
      return MessageDispatcher.this.handleResponse(response);
    }

    @Override
    public void handleException(Endpoint endpoint, Message message, Throwable t) {
      // The dispatcher reports whether the exception was handled; this
      // callback method is void, so that result is discarded here.
      MessageDispatcher.this.handleException(endpoint, message, t);
    }
  };
 
  /** The registered {@link MessageListener}s that are notified about sent and received messages. */
  private final List<MessageListener> listeners
    = new CopyOnWriteArrayList<MessageListener>();
 
  /** Keeps track of the requests that are still waiting for a response, a failure or a timeout. */
  private final MessageEntityManager entityManager
    = new MessageEntityManager();
   
  /** Runs the timeout task of each pending request. */
  private final ScheduledExecutorService executor;
 
  /** The {@link MessageFactory} as returned by {@link #getMessageFactory()}. */
  private final MessageFactory factory;
 
  /** Pre-filters incoming responses before they are matched against pending requests. */
  private final ResponseChecker checker;
 
  /** The bound {@link Transport} or {@code null} if unbound; every access synchronizes on {@code this}. */
  private Transport transport = null;
 
  /**
   * Creates a {@link MessageDispatcher} that schedules the timeouts of
   * its requests on the shared default executor.
   *
   * @param factory the {@link MessageFactory} of this dispatcher
   */
  public MessageDispatcher(MessageFactory factory) {
    this(EXECUTOR, factory);
  }
 
  /**
   * Creates a {@link MessageDispatcher} with a custom
   * {@link ScheduledExecutorService} that is used to schedule
   * the timeouts of pending requests.
   *
   * @param executor the executor that runs the timeout tasks
   * @param factory the {@link MessageFactory} of this dispatcher; it is
   *     also handed to the internal response checker
   */
  public MessageDispatcher(ScheduledExecutorService executor,
      MessageFactory factory) {
    this.executor = executor;
    this.factory = factory;
   
    // TODO: Is memorizing the 512 most recently received MessageIds
    // too much or too little?
    this.checker = new ResponseChecker(factory, 512);
  }
 
  /**
   * Returns the {@link Transport} this {@link MessageDispatcher} is
   * currently bound to.
   *
   * @return the bound {@link Transport} or {@code null} if unbound
   */
  public synchronized Transport getTransport() {
    return transport;
  }
 
  /**
   * Returns {@code true} if a {@link Transport} is currently bound,
   * i.e. the transport field is not {@code null}.
   */
  @Override
  public synchronized boolean isBound() {
    return transport != null;
  }
 
  @Override
  public synchronized void bind(Transport transport) throws IOException {
    if (transport == null) {
      throw new NullPointerException("transport");
    }
   
    if (isBound()) {
      throw new IOException();
    }
   
    transport.bind(callback);
    this.transport = transport;
  }
 
  /**
   * Unbinds the {@link Transport} without closing it. Does nothing
   * if no {@link Transport} is bound.
   */
  @Override
  public synchronized void unbind() {
    unbind(false);
  }
 
  /**
   * Unbinds the {@link Transport} and closes it optionally.
   */
  private synchronized void unbind(boolean close) {
    if (transport != null) {
      BindableUtils.unbind(transport);
     
      if (close) {
        IoUtils.close(transport);
      }
     
      transport = null;
    }
  }
 
  /**
   * Unbinds and closes the {@link Transport} (if one is bound) and then
   * closes the internal entity manager, which cancels and drops all
   * requests that are still pending.
   */
  @Override
  public void close() {
    unbind(true);
    entityManager.close();
  }
 
  /**
   * Returns the {@link MessageFactory} this {@link MessageDispatcher}
   * was created with.
   */
  public MessageFactory getMessageFactory() {
    return factory;
  }
 
  /**
   * Sends a {@link RequestMessage} to the a {@link Contact} with the
   * given {@link KUID}.
   *
   * <p>NOTE: The destination {@link SocketAddress} is encoded in the
   * {@link RequestMessage}. The {@link KUID} is used to validate
   * {@link ResponseMessage}s.
   */
  public void send(MessageCallback callback,
      KUID contactId, RequestMessage request,
      long timeout, TimeUnit unit) throws IOException {
   
    if (callback != null) {
      RequestEntity entity = new RequestEntity(
          contactId, request);
      entityManager.add(callback, entity, timeout, unit);
    }
   
    Transport transport = null;
    synchronized (this) {
      transport = this.transport;
    }
   
    if (transport == null) {
      throw new IOException();
    }
   
    transport.send(contactId, request, timeout, unit);
    fireMessageSent(contactId, request);
  }
 
  /**
   * Callback method for outgoing {@link Message} that failed to be sent.
   * Returns {@code true} if the {@link Throwable} was handled or not.
   */
  public boolean handleException(Endpoint endpoint,
      Message message, Throwable t) {
    MessageId messageId = message.getMessageId();
    MessageEntity entity = entityManager.get(messageId);
   
    if (entity != null) {
      entity.handleException(t);
      return true;
    }
    return false;
  }
 
  /**
   * Callback method for all outgoing {@link Message}s that have been sent.
   * This implementation does nothing; it may be overridden.
   */
  public void messageSent(KUID contactId, Message message) {
  }
 
  /**
   * Callback method for incoming {@link ResponseMessage}s. The
   * {@link MessageListener}s are notified first, then the response
   * is checked and matched against the pending requests.
   *
   * @return the result of the pending request's response handling, or
   *     {@code false} if the response was rejected or nobody waits for it
   */
  public boolean handleResponse(ResponseMessage response) throws IOException {
    fireMessageReceived(response);
    return handleResponse0(response);
  }
 
  private boolean handleResponse0(ResponseMessage response) throws IOException {
    boolean success = false;
    if (checker.check(response)) {
      MessageEntity entity = entityManager.get(response);
      if (entity != null) {
        success = entity.handleResponse(response);
      } else {
        lateResponse(response);
      }
    }
    return success;
  }
 
  /**
   * Callback method for incoming {@link RequestMessage}s.
   */
  public final ResponseMessage handleRequest(RequestMessage request) throws IOException {
    fireMessageReceived(request);
    ResponseMessage response = handleRequest0(request);
   
    if (response != null) {
      KUID contactId = request.getContact().getId();
      fireMessageSent(contactId, response);
    }
    return response;
  }
 
  /**
   * Creates the {@link ResponseMessage} for an incoming
   * {@link RequestMessage}. It's called by
   * {@link #handleRequest(RequestMessage)}, which treats a {@code null}
   * return value as "no response" and fires no message sent event for it.
   */
  protected abstract ResponseMessage handleRequest0(RequestMessage request) throws IOException;
 
  /**
   * Callback method for late incoming {@link ResponseMessage}s, i.e.
   * responses that passed the response check but for which no pending
   * request is registered (any more).
   */
  protected abstract void lateResponse(ResponseMessage response) throws IOException;
 
  /**
   * Callback method for {@link ResponseMessage}s that belong to a pending
   * request and were accepted by its {@link RequestEntity}. This
   * implementation forwards to the {@link MessageCallback}.
   *
   * @param time the age of the request at the time the response arrived
   * @param unit the unit of {@code time}
   * @return the value returned by the {@link MessageCallback}
   */
  protected boolean handleResponse(MessageCallback callback,
      RequestEntity entity, ResponseMessage response,
      long time, TimeUnit unit) throws IOException {
    return callback.handleResponse(entity, response, time, unit);
  }
 
  /**
   * Callback method for timeouts. This implementation forwards
   * to the {@link MessageCallback}.
   *
   * @param time the age of the request at the time it timed out
   * @param unit the unit of {@code time}
   */
  protected void handleTimeout(MessageCallback callback,
      RequestEntity entity, long time, TimeUnit unit)
        throws IOException {
    callback.handleTimeout(entity, time, unit);
  }
 
  /**
   * Callback method for illegal {@link ResponseMessage}s, i.e. responses
   * that belong to a pending request but were rejected by its
   * {@link RequestEntity}. This implementation forwards to the
   * {@link MessageCallback}.
   *
   * @param time the age of the request at the time the response arrived
   * @param unit the unit of {@code time}
   */
  protected void handleIllegalResponse(MessageCallback callback,
      RequestEntity entity, ResponseMessage response,
      long time, TimeUnit unit) throws IOException {
    callback.handleIllegalResponse(entity, response, time, unit);
  }
 
  /**
   * Callback for {@link Throwable}s that were reported for a pending
   * request. This implementation forwards to the {@link MessageCallback}.
   */
  protected void handleException(MessageCallback callback,
      RequestEntity entity, Throwable t) {
    callback.handleException(entity, t);
  }
 
  /**
   * Adds the given {@link MessageListener}.
   *
   * <p>NOTE(review): the argument goes through {@code Precoditions.notNull},
   * which presumably rejects {@code null} - confirm against that class.
   */
  public void addMessageListener(MessageListener l) {
    listeners.add(Precoditions.notNull(l, "l"));
  }
 
  /**
   * Removes the given {@link MessageListener}. Does nothing
   * if the {@link MessageListener} isn't registered.
   */
  public void removeMessageListener(MessageListener l) {
    listeners.remove(l);
  }
 
  /**
   * Returns all {@link MessageListener}s as a new array. The array
   * is empty if no {@link MessageListener}s are registered.
   */
  public MessageListener[] getMessageListeners() {
    return listeners.toArray(new MessageListener[0]);
  }
 
  /**
   * Fires a message sent event.
   */
  protected void fireMessageSent(final KUID contactId, final Message message) {
   
    if (!listeners.isEmpty()) {
      Runnable event = new Runnable() {
        @Override
        public void run() {
          for (MessageListener l : listeners) {
            l.handleMessageSent(contactId, message);
          }
        }
      };
      EventUtils.fireEvent(event);
    }
  }
 
  /**
   * Fires a message received event.
   */
  protected void fireMessageReceived(final Message message) {
    if (!listeners.isEmpty()) {
      Runnable event = new Runnable() {
        @Override
        public void run() {
          for (MessageListener l : listeners) {
            l.handleMessageReceived(message);
          }
        }
      };
      EventUtils.fireEvent(event);
    }
  }
 
  /**
   * The {@link MessageEntityManager} keeps track of {@link RequestEntity}s
   * and their {@link MessageCallback}s. It's also responsible for timing
   * out {@link RequestMessage} for which we haven't received any responses
   * within a defined time frame.
   */
  private class MessageEntityManager implements Closeable {
   
    private final Map<MessageId, MessageEntity> callbacks
      = Collections.synchronizedMap(new HashMap<MessageId, MessageEntity>());
   
    private boolean open = true;
   
    @Override
    public void close() {
      synchronized (callbacks) {
        if (!open) {
          return;
        }
       
        open = false;
       
        for (MessageEntity entity : callbacks.values()) {
          entity.cancel();
        }
       
        callbacks.clear();
      }
    }
   
    /**
     * Adds a {@link RequestEntity} and its {@link MessageCallback}.
     */
    public void add(MessageCallback callback, RequestEntity entity,
        long timeout, TimeUnit unit) {
     
      final MessageId messageId = entity.getMessageId();
     
      synchronized (callbacks) {
       
        if (!open) {
          throw new IllegalStateException();
        }
       
        if (callbacks.containsKey(messageId)) {
          throw new IllegalArgumentException("messageId=" + messageId);
        }
       
        Runnable task = new Runnable() {
          @Override
          public void run() {
            MessageEntity entity
              = callbacks.remove(messageId);
           
            if (entity != null) {
              try {
                entity.handleTimeout();
              } catch (IOException err) {
                LOG.error("IOException", err);
              }
            }
          }
        };
       
        ScheduledFuture<?> future
          = executor.schedule(task, timeout, unit);
       
        MessageEntity messageEntity = new MessageEntity(
            future, callback, entity);
        callbacks.put(messageId, messageEntity);
      }
    }
   
    /**
     * Returns a {@link MessageEntity} for the given {@link ResponseMessage}.
     */
    public MessageEntity get(ResponseMessage message) {
      return get(message.getMessageId());
    }
   
    /**
     * Returns a {@link MessageEntity} for the given {@link MessageId}.
     */
    public MessageEntity get(MessageId messageId) {
      return callbacks.remove(messageId);
    }
  }

  /**
   * A {@link MessageEntity} represents a {@link RequestMessage} we've sent
   * and provides the infrastructure to process the {@link ResponseMessage}.
   */
  private class MessageEntity {
   
    private final TimeStamp creationTime = TimeStamp.now();
   
    private final ScheduledFuture<?> future;

    private final MessageCallback callback;
   
    private final RequestEntity entity;
   
    private final AtomicBoolean open = new AtomicBoolean(true);
   
    private MessageEntity(ScheduledFuture<?> future,
        MessageCallback callback,
        RequestEntity entity) {
     
      this.future = Precoditions.notNull(future, "future");
      this.callback = Precoditions.notNull(callback, "callback");
      this.entity = Precoditions.notNull(entity, "entity");
    }

    /**
     * Cancels the {@link MessageEntity}.
     */
    public boolean cancel() {
      future.cancel(true);
      return open.getAndSet(false);
    }
   
    /**
     * Called if a {@link ResponseMessage} was received.
     */
    public boolean handleResponse(ResponseMessage response) throws IOException {
      boolean success = false;
      if (cancel()) {
        long time = creationTime.getAgeInMillis();
       
        if (entity.check(response)) {
          success = MessageDispatcher.this.handleResponse(callback, entity,
              response, time, TimeUnit.MILLISECONDS);
        } else {
          MessageDispatcher.this.handleIllegalResponse(callback,
              entity, response, time, TimeUnit.MILLISECONDS);
        }
      }
      return success;
    }

    /**
     * Called if a timeout occurred (i.e. we didn't receive a
     * {@link ResponseMessage} within in the predefined time).
     */
    public void handleTimeout() throws IOException {
      if (cancel()) {
       
        long time = creationTime.getAgeInMillis();
        MessageDispatcher.this.handleTimeout(callback, entity,
            time, TimeUnit.MILLISECONDS);
      }
    }
   
    /**
     * Called if an exception occured.
     */
    public void handleException(Throwable t) {
      if (cancel()) {
        MessageDispatcher.this.handleException(callback, entity, t);
      }
    }
  }
 
  /**
   * The {@link ResponseChecker} makes sure {@link ResponseMessage}s fulfill
   * certain requirements before they're considered for further processing.
   */
  private static class ResponseChecker {
   
    private static final Logger LOG
      = LoggerFactory.getLogger(ResponseChecker.class);
   
    private final MessageFactory factory;
   
    private final Set<MessageId> history;
   
    public ResponseChecker(MessageFactory factory, int historySize) {
      this.factory = Precoditions.notNull(factory, "factory");
      this.history = Collections.synchronizedSet(
          new FixedSizeHashSet<MessageId>(historySize));
    }
   
    /**
     * Returns {@code true} if the {@link ResponseMessage} is OK.
     */
    public boolean check(ResponseMessage response) {
      MessageId messageId = response.getMessageId();
      if (!history.add(messageId)) {
        if (LOG.isErrorEnabled()) {
          LOG.error("Multiple respones: " + response);
        }
        return false;
      }
     
      Contact contact = response.getContact();
      if (!factory.isFor(messageId, contact.getRemoteAddress())) {
        if (LOG.isErrorEnabled()) {
          LOG.error("Wrong MessageId signature: " + response);
        }
        return false;
      }
     
      return true;
    }
  }
}
// TOP
//
// Related Classes of org.ardverk.dht.io.MessageDispatcher$MessageEntity
//
// TOP
// Copyright © 2018 www.massapi.com. All rights reserved.
// All source code is the property of its respective owners. Java is a trademark of Sun Microsystems, Inc. and owned by ORACLE Inc. Contact coftware#gmail.com.