Package org.apache.activemq.apollo.mqtt

Source Code of org.apache.activemq.apollo.mqtt.MqttProtocolHandler$Break

/**
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements.  See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License.  You may obtain a copy of the License at
*
*      http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.activemq.apollo.mqtt;

import org.apache.activemq.apollo.broker.*;
import org.apache.activemq.apollo.broker.protocol.AbstractProtocolHandler;
import org.apache.activemq.apollo.broker.protocol.ProtocolFilter3;
import org.apache.activemq.apollo.broker.protocol.ProtocolFilter3$;
import org.apache.activemq.apollo.broker.security.SecurityContext;
import org.apache.activemq.apollo.dto.AcceptingConnectorDTO;
import org.apache.activemq.apollo.dto.ProtocolDTO;
import org.apache.activemq.apollo.mqtt.dto.MqttConnectionStatusDTO;
import org.apache.activemq.apollo.mqtt.dto.MqttDTO;
import org.apache.activemq.apollo.util.*;
import org.fusesource.hawtbuf.UTF8Buffer;
import org.fusesource.hawtdispatch.DispatchQueue;
import org.fusesource.hawtdispatch.Task;
import org.fusesource.hawtdispatch.transport.HeartBeatMonitor;
import org.fusesource.mqtt.codec.*;

import java.io.IOException;
import java.util.ArrayList;
import java.util.concurrent.TimeUnit;
import java.util.regex.Pattern;

import static org.apache.activemq.apollo.util.Scala2Java.*;
import static org.fusesource.hawtdispatch.Dispatch.NOOP;

/**
* @author <a href="http://hiramchirino.com">Hiram Chirino</a>
*/
public class MqttProtocolHandler extends AbstractProtocolHandler {

    public static final Scala2Java.Logger log = new Scala2Java.Logger(Log$.MODULE$.apply(MqttProtocolHandler.class));

    public static <T> T received(T value) {
        log.trace("received: %s", value);
        return value;
    }

    static Fn0<String> WAITING_ON_CLIENT_REQUEST = new Fn0<String>() {
        @Override
        public String apply() {
            return "client request";
        }
    };

    public String protocol() {
        return "mqtt";
    }

    public Broker broker() {
        return connection().connector().broker();
    }

    public DispatchQueue queue() {
        return connection().dispatch_queue();
    }

    Scala2Java.Logger connection_log = log;
    MqttDTO config = null;

    public DestinationParser destination_parser() {
        DestinationParser destination_parser = MqttProtocol.destination_parser;
        if (config.queue_prefix != null ||
                config.path_separator != null ||
                config.any_child_wildcard != null ||
                config.any_descendant_wildcard != null ||
                config.regex_wildcard_start != null ||
                config.regex_wildcard_end != null ||
                config.part_pattern != null
                ) {
            destination_parser = new DestinationParser().copy(destination_parser);
            if (config.queue_prefix != null) {
                destination_parser.queue_prefix_$eq(config.queue_prefix);
            }
            if (config.path_separator != null) {
                destination_parser.path_separator_$eq(config.path_separator);
            }
            if (config.any_child_wildcard != null) {
                destination_parser.any_child_wildcard_$eq(config.any_child_wildcard);
            }
            if (config.any_descendant_wildcard != null) {
                destination_parser.any_descendant_wildcard_$eq(config.any_descendant_wildcard);
            }
            if (config.regex_wildcard_start != null) {
                destination_parser.regex_wildcard_start_$eq(config.regex_wildcard_start);
            }
            if (config.regex_wildcard_end != null) {
                destination_parser.regex_wildcard_end_$eq(config.regex_wildcard_end);
            }
            if (config.part_pattern != null) {
                destination_parser.part_pattern_$eq(Pattern.compile(config.part_pattern));
            }
        }
        return destination_parser;
    }

    final ArrayList<ProtocolFilter3> protocol_filters = new ArrayList<ProtocolFilter3>();

    /////////////////////////////////////////////////////////////////////
    //
    // Bits related setting up a client connection
    //
    /////////////////////////////////////////////////////////////////////
    public String session_id() {
        return security_context.session_id();
    }

    final SecurityContext security_context = new SecurityContext();
    SinkMux<Request> sink_manager = null;
    Sink<Request> connection_sink = null;
    MQTTProtocolCodec codec = null;


    static private MqttDTO find_config(AcceptingConnectorDTO connector_config) {
        for (ProtocolDTO protocol : connector_config.protocols) {
            if (protocol instanceof MqttDTO) {
                return (MqttDTO) protocol;
            }
        }
        return new MqttDTO();
    }

    public void on_transport_connected() {

        codec = (MQTTProtocolCodec) connection().transport().getProtocolCodec();
        AcceptingConnectorDTO connector_config = (AcceptingConnectorDTO) connection().connector().config();
        config = find_config(connector_config);

        codec.setMaxMessageLength(get(config.max_message_length, codec.getMaxMessageLength()));

        protocol_filters.clear();
        protocol_filters.addAll(ProtocolFilter3$.MODULE$.create_filters(config.protocol_filters, this));

        security_context.local_address_$eq(connection().transport().getLocalAddress());
        security_context.remote_address_$eq(connection().transport().getRemoteAddress());
        security_context.connector_id_$eq(connection().connector().id());
        security_context.certificates_$eq(connection().certificates());

        connection_log = new Scala2Java.Logger(connection().connector().broker().connection_log());

        Sink<Request> filtering_sink = new AbstractSinkMapper<Request, Object>() {
            public Sink<Object> downstream() {
                return (Sink<Object>) connection().transport_sink();
            }

            public MQTTFrame passing(Request request) {
                log.trace("sent: %s", request.message);
                request.delivered = true;
                if (request.id == 0 && request.ack != null) {
                    request.ack.apply(Consumed$.MODULE$);
                }
                return request.frame;

            }
        };

        if (!protocol_filters.isEmpty()) {
            final Sink<Request> downstream  = filtering_sink;
            filtering_sink = new AbstractSinkFilter<Request, Request>() {

                @Override
                public Sink<Request> downstream() {
                    return downstream;
                }

                @Override
                public Request filter(Request value) {
                    Request cur = value;
                    for (ProtocolFilter3 filter : protocol_filters) {
                        cur = filter.filter_outbound(cur);
                        if(cur == null ) {
                            break;
                        }
                    }
                    return cur;
                }
            };
        }
        sink_manager = new SinkMux<Request>(filtering_sink);
        connection_sink = new OverflowSink(sink_manager.open());
        resume_read();
    }


    /////////////////////////////////////////////////////////////////////
    //
    // Bits related tearing down a client connection
    //
    /////////////////////////////////////////////////////////////////////
    boolean closed = false;
    public static UnitFn1<Object> dead_handler = new UnitFn1<Object>() {
        @Override
        public void call(Object v1) {
        }
    };

    public void on_transport_disconnected() {
        if (!closed) {
            closed = true;
            dead = true;
            command_handler = dead_handler;

            security_context.logout(toScala(new UnitFn1<Throwable>() {
                @Override
                public void call(Throwable e) {
                    if (e != null) {
                        connection_log.info(e, "MQTT connection '%s' log out error: %s", security_context.remote_address(), e.toString());
                    }
                }
            }));

            heart_beat_monitor.stop();
            if (!connection().stopped()) {
                connection().stop(NOOP);
            }
            log.trace("mqtt protocol resources released");
        }
    }

    public void on_transport_failure(IOException error) {
        if (!dead) {
            command_handler.apply("failure");
            dead = true;
            command_handler = dead_handler;
            if (!connection().stopped()) {
                connection_log.info(error, "Shutting connection '%s'  down due to: %s", security_context.remote_address(), error);
                super.on_transport_failure(error);
            }
        }
    }

    /////////////////////////////////////////////////////////////////////
    //
    // Bits related managing connection flow control
    //
    /////////////////////////////////////////////////////////////////////

    Fn0<String> status = WAITING_ON_CLIENT_REQUEST;

    public void _suspend_read(final String reason) {
        suspend_read(new Fn0<String>() {
            @Override
            public String apply() {
                return reason;
            }
        });
    }

    public void suspend_read(Fn0<String> reason) {
        status = reason;
        connection().transport().suspendRead();
        heart_beat_monitor.suspendRead();
    }

    public void resume_read() {
        status = WAITING_ON_CLIENT_REQUEST;
        connection().transport().resumeRead();
        heart_beat_monitor.resumeRead();
    }

    /////////////////////////////////////////////////////////////////////
    //
    // Bits related to raising connection failure signals
    //
    /////////////////////////////////////////////////////////////////////

    boolean dead = false;

    public long die_delay() {
        return get(config.die_delay, 1000 * 5L);
    }

    class Break extends RuntimeException {
    }

    public void async_die(String msg) {
        async_die(msg, null);
    }

    public void async_die(String msg, Throwable e) {
        try {
            die(msg, e);
        } catch (Break x) {
        }
    }

    public void async_die(MessageSupport.Message response, String msg) {
        try {
            die(response, msg, null);
        } catch (Break x) {
        }
    }

    public <T> T die(String msg) {
        return die(null, msg, null);
    }

    public <T> T die(String msg, Throwable e) {
        return die(null, msg, e);
    }

    public <T> T die(MessageSupport.Message response, String msg) {
        return die(response, msg, null);
    }

    public <T> T die(MessageSupport.Message response, String msg, Throwable e) {
        if (e != null) {
            connection_log.info(e, "MQTT connection '%s' error: %s", security_context.remote_address(), msg, e);
        } else {
            connection_log.info("MQTT connection '%s' error: %s", security_context.remote_address(), msg);
        }
        return die(response);
    }

    public <T> T die(MessageSupport.Message response) {
        if (!dead) {
            command_handler.apply("failure");
            dead = true;
            command_handler = dead_handler;
            status = new Fn0<String>() {
                @Override
                public String apply() {
                    return "shuting down";
                }
            };
            if (response != null) {
                connection().transport().resumeRead();
                connection_sink.offer(new Request((short) 0, response, null));
                // TODO: if there are too many open connections we should just close the connection
                // without waiting for the error to get sent to the client.
                queue().executeAfter(die_delay(), TimeUnit.MILLISECONDS, new Task() {
                    @Override
                    public void run() {
                        connection().stop(NOOP);
                    }
                });
            } else {
                connection().stop(NOOP);
            }
        }
        throw new Break();
    }

    /////////////////////////////////////////////////////////////////////
    //
    // Bits for dispatching client requests.
    //
    /////////////////////////////////////////////////////////////////////
    UnitFn1<Object> command_handler = connect_handler();

    public void on_transport_command(Object command) {
        try {
            if (!protocol_filters.isEmpty()) {
                for (ProtocolFilter3 filter : protocol_filters) {
                    command = filter.filter_inbound(command);
                    if (command==null) {
                        return; // dropping the frame.
                    }
                }
            }
            command_handler.apply(command);
        } catch (Break e) {
        } catch (Exception e) {
            // To avoid double logging to the same log category..
            String msg = "Internal Server Error: " + e;
            if (connection_log != MqttProtocolHandler.log) {
                // but we also want the error on the apollo.log file.
                MqttProtocolHandler.log.warn(e, msg);
            }
            async_die(msg, e);
        }
    }


    /////////////////////////////////////////////////////////////////////
    //
    // Bits related establishing the client connection
    //
    /////////////////////////////////////////////////////////////////////

    CONNECT connect_message = null;
    HeartBeatMonitor heart_beat_monitor = new HeartBeatMonitor();
    VirtualHost host = null;

    public UnitFn1<Object> connect_handler() {
        return new UnitFn1<Object>() {
            @Override
            public void call(Object o) {
                if (o instanceof MQTTFrame) {
                    MQTTFrame command = (MQTTFrame) o;
                    try {
                        if (command.messageType() == CONNECT.TYPE) {
                            connect_message = received(new CONNECT().decode(command));
                            on_mqtt_connect();
                        } else {
                            die("Expecting an MQTT CONNECT message, but got: " + command.getClass());
                        }
                    } catch (java.net.ProtocolException e) {
                        die("Internal Server Error: bad mqtt command: " + command);
                    }
                } else if ("failure".equals(o)) {
                } else {
                    die("Internal Server Error: unexpected mqtt command: " + o.getClass());
                }
            }
        };
    }

    public void on_mqtt_connect() {

        final CONNACK connack = new CONNACK();

        switch(connect_message.version()) {
            case 3:case 4: break;
            default:
                connack.code(CONNACK.Code.CONNECTION_REFUSED_UNACCEPTED_PROTOCOL_VERSION);
                die(connack, "Unsupported protocol version: " + connect_message.version());
        }

        if( (connect_message.clientId() == null || connect_message.clientId().length==0) && !connect_message.cleanSession() ) {
            die(connack, "A clean session must be requested when no client id is provided.");
        }

        UTF8Buffer client_id = connect_message.clientId();
        security_context.user_$eq(Scala2Java.toString(connect_message.userName()));
        security_context.password_$eq(Scala2Java.toString(connect_message.password()));
        security_context.session_id_$eq(client_id.toString());

        final short keep_alive = connect_message.keepAlive();
        if (keep_alive > 0) {
            heart_beat_monitor.setReadInterval(((long) (keep_alive * 1.5)) * 1000);
            heart_beat_monitor.setOnDead(new Task() {
                @Override
                public void run() {
                    async_die("Missed keep alive set to " + keep_alive + " seconds");
                }
            });
        }
        heart_beat_monitor.suspendRead();
        heart_beat_monitor.setTransport(connection().transport());
        heart_beat_monitor.start();

        _suspend_read("virtual host lookup");
        broker().dispatch_queue().execute(new Task() {
            @Override
            public void run() {
                host = connection().connector().broker().get_default_virtual_host();
                queue().execute(new Task() {
                    @Override
                    public void run() {
                        resume_read();
                        if (host == null) {
                            connack.code(CONNACK.Code.CONNECTION_REFUSED_SERVER_UNAVAILABLE);
                            async_die(connack, "Default virtual host not found.");
                        } else if (!host.service_state().is_started()) {
                            connack.code(CONNACK.Code.CONNECTION_REFUSED_SERVER_UNAVAILABLE);
                            async_die(connack, "Default virtual host stopped.");
                        } else {
                            connection_log = new Scala2Java.Logger(host.connection_log());
                            if (host.authenticator() != null && host.authorizer() != null) {
                                _suspend_read("authenticating and authorizing connect");
                                host.authenticator().authenticate(security_context, toScala(new UnitFn1<String>() {
                                    public void call(final String auth_err) {
                                        queue().execute(new Task() {
                                            @Override
                                            public void run() {
                                                if (auth_err != null) {
                                                    connack.code(CONNACK.Code.CONNECTION_REFUSED_BAD_USERNAME_OR_PASSWORD);
                                                    async_die(connack, auth_err + ". Credentials=" + security_context.credential_dump());
                                                } else if (!host.authorizer().can(security_context, "connect", connection().connector())) {
                                                    connack.code(CONNACK.Code.CONNECTION_REFUSED_NOT_AUTHORIZED);
                                                    async_die(connack, String.format("Not authorized to connect to connector '%s'. Principals=", connection().connector().id(), security_context.principal_dump()));
                                                } else if (!host.authorizer().can(security_context, "connect", host)) {
                                                    connack.code(CONNACK.Code.CONNECTION_REFUSED_NOT_AUTHORIZED);
                                                    async_die(connack, String.format("Not authorized to connect to virtual host '%s'. Principals=", host.id(), security_context.principal_dump()));
                                                } else {
                                                    resume_read();
                                                    on_host_connected(host);
                                                }
                                            }
                                        });
                                    }
                                }));
                            } else {
                                on_host_connected(host);
                            }
                        }
                    }
                });
            }
        });
    }

    public void on_host_connected(VirtualHost host) {
        MqttSessionManager.attach(host, connect_message.clientId(), this);
    }

    /////////////////////////////////////////////////////////////////////
    //
    // Other msic bits.
    //
    /////////////////////////////////////////////////////////////////////
    LongCounter messages_sent = new LongCounter(0);
    LongCounter messages_received = new LongCounter(0);
    int subscription_count = 0;

    public MqttConnectionStatusDTO create_connection_status(boolean debug) {
        MqttConnectionStatusDTO rc = new MqttConnectionStatusDTO();
        rc.protocol_version = "3.1";
        rc.messages_sent = messages_sent.get();
        rc.messages_received = messages_received.get();
        rc.subscription_count = subscription_count;
        rc.waiting_on = status.apply();
        return rc;
    }


}
TOP

Related Classes of org.apache.activemq.apollo.mqtt.MqttProtocolHandler$Break

TOP
Copyright © 2018 www.massapi.com. All rights reserved.
All source code are property of their respective owners. Java is a trademark of Sun Microsystems, Inc and owned by ORACLE Inc. Contact coftware#gmail.com.