Package org.jboss.aerogear.simplepush.server.datastore

Source Code of org.jboss.aerogear.simplepush.server.datastore.InMemoryDataStore$MutableChannel

/**
* JBoss, Home of Professional Open Source
* Copyright Red Hat, Inc., and individual contributors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
*   http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.jboss.aerogear.simplepush.server.datastore;

import static org.jboss.aerogear.simplepush.util.ArgumentUtil.checkNotNull;

import java.util.Collections;
import java.util.HashSet;
import java.util.Set;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.ConcurrentMap;
import java.util.concurrent.atomic.AtomicLong;

import org.jboss.aerogear.simplepush.protocol.Ack;
import org.jboss.aerogear.simplepush.protocol.impl.AckImpl;
import org.jboss.aerogear.simplepush.server.Channel;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
* A {@link DataStore} implementation that stores all information in memory.
*/
public class InMemoryDataStore implements DataStore {

    private final ConcurrentMap<String, MutableChannel> channels = new ConcurrentHashMap<String, MutableChannel>();
    private final ConcurrentMap<String, MutableChannel> endpoints = new ConcurrentHashMap<String, MutableChannel>();
    private final ConcurrentMap<String, Set<Ack>> unacked = new ConcurrentHashMap<String, Set<Ack>>();
    private final Logger logger = LoggerFactory.getLogger(InMemoryDataStore.class);

    private byte[] salt;

    @Override
    public void savePrivateKeySalt(final byte[] salt) {
        if (this.salt != null) {
            this.salt = salt;
        }
    }

    @Override
    public byte[] getPrivateKeySalt() {
        if (salt == null) {
            return new byte[]{};
        }
        return salt;
    }

    @Override
    public boolean saveChannel(final Channel ch) {
        checkNotNull(ch, "ch");
        final MutableChannel mutableChannel = new MutableChannel(ch);
        final Channel previous = channels.putIfAbsent(ch.getChannelId(), mutableChannel);
        endpoints.put(ch.getEndpointToken(), mutableChannel);
        return previous == null;
    }

    private boolean removeChannel(final String channelId) {
        checkNotNull(channelId, "channelId");
        final Channel channel = channels.remove(channelId);
        if (channel != null) {
            endpoints.remove(endpoints.get(channel.getEndpointToken()));
        }
        return channel != null;
    }

    @Override
    public Channel getChannel(final String channelId) throws ChannelNotFoundException {
        checkNotNull(channelId, "channelId");
        final Channel channel = channels.get(channelId);
        if (channel == null) {
            throw new ChannelNotFoundException("No Channel for [" + channelId + "] was found", channelId);
        }
        return channel;
    }

    @Override
    public void removeChannels(final String uaid) {
        checkNotNull(uaid, "uaid");
        for (Channel channel : channels.values()) {
            if (channel.getUAID().equals(uaid)) {
                removeChannel(channel.getChannelId());
                logger.info("Removing [" + channel.getChannelId() + "] for UserAgent [" + uaid + "]");
            }
        }
        unacked.remove(uaid);
    }

    @Override
    public void removeChannels(final Set<String> channelIds) {
        checkNotNull(channelIds, "channelIds");
        for (String channelId : channelIds) {
            removeChannel(channelId);
            logger.debug("Removing [" + channelId + "]");
        }
    }

    @Override
    public Set<String> getChannelIds(final String uaid) {
        checkNotNull(uaid, "uaid");
        final Set<String> channelIds = new HashSet<String>();
        for (Channel channel : channels.values()) {
            if (channel.getUAID().equals(uaid)) {
                channelIds.add(channel.getChannelId());
            }
        }
        return channelIds;
    }

    @Override
    public String updateVersion(final String endpointToken, final long version) throws VersionException, ChannelNotFoundException {
        final MutableChannel channel = endpoints.get(endpointToken);
        if (channel == null) {
            throw new ChannelNotFoundException("Could not find channel for endpointToken", endpointToken);
        }
        channel.updateVersion(version);
        return channel.getChannelId();
    }

    @Override
    public String saveUnacknowledged(final String channelId, final long version) throws ChannelNotFoundException {
        checkNotNull(channelId, "channelId");
        checkNotNull(version, "version");
        final Channel channel = channels.get(channelId);
        if (channel == null) {
            throw new ChannelNotFoundException("Could not find channel", channelId);
        }
        final String uaid = channel.getUAID();
        final Set<Ack> newAcks = Collections.newSetFromMap(new ConcurrentHashMap<Ack, Boolean>());
        newAcks.add(new AckImpl(channelId, version));
        while (true) {
            final Set<Ack> currentAcks = unacked.get(uaid);
            if (currentAcks == null) {
                final Set<Ack> previous = unacked.putIfAbsent(uaid, newAcks);
                if (previous != null) {
                    newAcks.addAll(previous);
                    if (unacked.replace(uaid, previous, newAcks)) {
                        break;
                    }
                }
            } else {
                newAcks.addAll(currentAcks);
                if (unacked.replace(uaid, currentAcks, newAcks)) {
                    break;
                }
            }
        }
        return uaid;
    }

    @Override
    public Set<Ack> getUnacknowledged(final String uaid) {
        checkNotNull(uaid, "uaid");
        final Set<Ack> acks = unacked.get(uaid);
        if (acks == null) {
            return Collections.emptySet();
        }
        return Collections.unmodifiableSet(acks);
    }

    @Override
    public Set<Ack> removeAcknowledged(final String uaid, final Set<Ack> acked) {
        checkNotNull(uaid, "uaid");
        checkNotNull(acked, "acked");
        while (true) {
            final Set<Ack> currentAcks = unacked.get(uaid);
            if (currentAcks == null || currentAcks.isEmpty()) {
                return Collections.emptySet();
            }
            final Set<Ack> newAcks = Collections.newSetFromMap(new ConcurrentHashMap<Ack, Boolean>());
            boolean added = newAcks.addAll(currentAcks);
            if (!added){
                return newAcks;
            }

            boolean removed = newAcks.removeAll(acked);
            if (removed) {
                if (unacked.replace(uaid, currentAcks, newAcks)) {
                    return newAcks;
                }
            } else {
                return newAcks;
            }
        }
    }

    /**
     * A Channel implementation which has a mutable version and indended for
     * usage with the InMemoryDataStore.
     * This class uses a concurrent data structure to store and update the version.
     */
    private static class MutableChannel implements Channel {

        private final Channel delegate;
        private final AtomicLong version;

        public MutableChannel(final Channel delegate) {
            this.delegate = delegate;
            version = new AtomicLong(delegate.getVersion());
        }

        @Override
        public String getUAID() {
            return delegate.getUAID();
        }

        @Override
        public String getChannelId() {
            return delegate.getChannelId();
        }

        @Override
        public long getVersion() {
            return version.get();
        }

        public void updateVersion(final long newVersion) {
            for (;;) {
                final long currentVersion = version.get();
                if (newVersion <= currentVersion) {
                    throw new VersionException("New version [" + newVersion + "] must be greater than current version [" + currentVersion + "]");
                }
                if (version.compareAndSet(currentVersion, newVersion)) {
                    break;
                }
            }
        }

        @Override
        public String getEndpointToken() {
            return delegate.getEndpointToken();
        }

    }

}
TOP

Related Classes of org.jboss.aerogear.simplepush.server.datastore.InMemoryDataStore$MutableChannel

TOP
Copyright © 2018 www.massapi.com. All rights reserved.
All source code are property of their respective owners. Java is a trademark of Sun Microsystems, Inc and owned by ORACLE Inc. Contact coftware#gmail.com.