// Package: cc.redberry.core.transformations.collect
//
// Source code of cc.redberry.core.transformations.collect.CollectTransformation$Split

/*
* Redberry: symbolic tensor computations.
*
* Copyright (c) 2010-2013:
*   Stanislav Poslavsky   <stvlpos@mail.ru>
*   Bolotin Dmitriy       <bolotin.dmitriy@gmail.com>
*
* This file is part of Redberry.
*
* Redberry is free software: you can redistribute it and/or modify
* it under the terms of the GNU General Public License as published by
* the Free Software Foundation, either version 3 of the License, or
* (at your option) any later version.
*
* Redberry is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
* GNU General Public License for more details.
*
* You should have received a copy of the GNU General Public License
* along with Redberry. If not, see <http://www.gnu.org/licenses/>.
*/
package cc.redberry.core.transformations.collect;

import cc.redberry.concurrent.OutputPortUnsafe;
import cc.redberry.core.combinatorics.Combinatorics;
import cc.redberry.core.indexgenerator.IndexGeneratorImpl;
import cc.redberry.core.indexmapping.IndexMapping;
import cc.redberry.core.indexmapping.IndexMappings;
import cc.redberry.core.indexmapping.Mapping;
import cc.redberry.core.indices.*;
import cc.redberry.core.number.Complex;
import cc.redberry.core.tensor.*;
import cc.redberry.core.transformations.EliminateMetricsTransformation;
import cc.redberry.core.transformations.Transformation;
import cc.redberry.core.transformations.expand.ExpandPort;
import cc.redberry.core.transformations.powerexpand.PowerExpandUnwrapTransformation;
import cc.redberry.core.utils.ArraysUtils;
import cc.redberry.core.utils.IntArrayList;
import cc.redberry.core.utils.TensorUtils;
import gnu.trove.map.hash.TIntObjectHashMap;
import gnu.trove.set.hash.TIntHashSet;

import java.util.ArrayList;
import java.util.Arrays;

import static cc.redberry.core.indices.IndicesUtils.*;
import static cc.redberry.core.tensor.Tensors.multiply;
import static cc.redberry.core.tensor.Tensors.sum;

/**
* Collects together terms that involve the same powers of objects matching specified simple tensors or tensor fields.
*
* @author Dmitry Bolotin
* @author Stanislav Poslavsky
* @since 1.1.5
*/
public class CollectTransformation implements Transformation {
    private final TIntHashSet patternsNames;
    private final Transformation powerExpand;
    private final Transformation[] transformations;

    /**
     * Creates Collect transformation that collects together terms that involve
     * the same powers of objects matching specified simple tensors or tensor fields and applies specified
     * transformations to the expression that forms the coefficient of each term obtained.
     *
     * @param patterns        specified simple tensors or tensor fields
     * @param transformations transformations to be applied to the expression that forms the coefficient
     *                        of each term obtained
     */
    public CollectTransformation(SimpleTensor[] patterns, Transformation[] transformations) {
        patternsNames = new TIntHashSet();
        powerExpand = new PowerExpandUnwrapTransformation(patterns);
        for (SimpleTensor t : patterns)
            patternsNames.add(t.getName());
        this.transformations = transformations;
    }

    /**
     * Creates Collect transformation that collects together terms that involve
     * the same powers of objects matching specified simple tensors or tensor fields.
     *
     * @param patterns specified simple tensors or tensor fields
     */
    public CollectTransformation(SimpleTensor... patterns) {
        this(patterns, new Transformation[0]);
    }

    @Override
    public Tensor transform(Tensor t) {
        SumBuilder notMatched = new SumBuilder();
        TIntObjectHashMap<ArrayList<Split>> map = new TIntObjectHashMap<>();
        OutputPortUnsafe<Tensor> port = ExpandPort.createPort(t);
        Tensor current;
        Split toAdd;
        ArrayList<Split> nodes;
        out:
        while ((current = port.take()) != null) {
            toAdd = split(current);
            if (toAdd.factors.length == 0) {
                notMatched.put(current);
                continue;
            }

            nodes = map.get(toAdd.hashCode);
            if (nodes == null) {
                nodes = new ArrayList<>();
                nodes.add(toAdd);
                map.put(toAdd.hashCode, nodes);
                continue;
            }

            int[] match;
            for (Split base : nodes) {
                if ((match = matchFactors(base.factors, toAdd.factors)) != null) {
                    Tensor[] toAddFactors = Combinatorics.reorder(toAdd.factors, match);
                    Mapping mapping =
                            IndexMappings.createBijectiveProductPort(toAddFactors, base.factors).take();

//                    mapping =  mapping.inverseStates();
//                    for (Map.Entry<Integer, IndexMappingBufferRecord> entry : mapping.getMap().entrySet())
//                        entry.getValue().invertStates();


                    base.summands.add(ApplyIndexMapping.applyIndexMappingAutomatically(toAdd.summands.get(0), mapping, base.forbidden));
                    continue out;
                }
            }

            nodes.add(toAdd);
        }


        for (ArrayList<Split> splits : map.valueCollection())
            for (Split split : splits)
                notMatched.put(split.toTensor(transformations));


        return notMatched.build();
    }

    private boolean match(Tensor t) {
        if (t instanceof SimpleTensor)
            return patternsNames.contains(t.hashCode());
        if (TensorUtils.isPositiveIntegerPower(t))
            return patternsNames.contains(t.get(0).hashCode());
        return false;
    }

    private Split split(Tensor tensor) {
        Tensor[] factors;
        Tensor summand;

        if (tensor instanceof SimpleTensor || TensorUtils.isPositiveIntegerPowerOfSimpleTensor(tensor))
            if (match(tensor)) {
                factors = new Tensor[1];
                factors[0] = tensor;
                summand = Complex.ONE;
            } else
                return new Split(new Tensor[0], tensor);
        else if (tensor instanceof Product || TensorUtils.isPositiveIntegerPowerOfProduct(tensor)) {
            //early check
            tensor = powerExpand.transform(tensor);

            boolean containsMatch = false;
            for (Tensor t : (tensor instanceof Product ? tensor : tensor.get(0))) {
                if (match(t)) {
                    containsMatch = true;
                    break;
                }
            }
            if (!containsMatch) return new Split(new Tensor[0], tensor);

            assert tensor instanceof Product;

            ArrayList<Tensor> factorsList = new ArrayList<>();
            summand = tensor;
            for (Tensor t : tensor) {
                if (match(t)) {
                    factorsList.add(t);
                    assert summand != Complex.ONE;
                    if (summand instanceof Product)
                        summand = ((Product) summand).remove(t);
                    else summand = Complex.ONE;
                }
            }
            factors = factorsList.toArray(new Tensor[factorsList.size()]);
        } else
            return new Split(new Tensor[0], tensor);


        TIntHashSet freeIndices = new TIntHashSet(IndicesUtils.getIndicesNames(tensor.getIndices().getFree()));

        //now we need to uncontract dummies and free

        Indices factorIndices = new IndicesBuilder().append(factors).getIndices();
        TIntHashSet dummies = new TIntHashSet(IndicesUtils.getIntersections(
                factorIndices.getUpper().copy(), factorIndices.getLower().copy()));
        SimpleIndices currentFactorIndices;
        IntArrayList from = new IntArrayList(), to = new IntArrayList();
        ArrayList<Tensor> kroneckers = new ArrayList<>();
        int j, index, newIndex;
        IndexGeneratorImpl generator = new IndexGeneratorImpl(TensorUtils.getAllIndicesNamesT(tensor).toArray());
        for (int i = 0; i < factors.length; ++i) {
            from.clear();
            to.clear();
            currentFactorIndices = IndicesFactory.createSimple(null, factors[i].getIndices());

            for (j = currentFactorIndices.size() - 1; j >= 0; --j) {
                index = currentFactorIndices.get(j);
                if (freeIndices.contains(getNameWithType(index))) {
                    newIndex = setRawState(getRawStateInt(index), generator.generate(getType(index)));
                    from.add(index);
                    to.add(newIndex);
                    kroneckers.add(Tensors.createKronecker(index, inverseIndexState(newIndex)));
                } else if (IndicesUtils.getState(index) && dummies.contains(getNameWithType(index))) {
                    newIndex = setRawState(getRawStateInt(index), generator.generate(getType(index)));
                    from.add(index);
                    to.add(newIndex);
                    kroneckers.add(Tensors.createKronecker(index, inverseIndexState(newIndex)));
                }
            }

            factors[i] = applyDirectMapping(factors[i],
                    new StateSensitiveMapping(from.toArray(), to.toArray()));
        }

        //temp check
//            factorIndices = new IndicesBuilder().append(factors).getIndices();
//            assert factorIndices.size() == factorIndices.getFree().size();

        kroneckers.add(summand);
        summand = Tensors.multiply(kroneckers.toArray(new Tensor[kroneckers.size()]));
        summand = EliminateMetricsTransformation.eliminate(summand);

        return new Split(factors, summand);
    }

    private static final class Split {
        final Tensor[] factors;
        final ArrayList<Tensor> summands = new ArrayList<>();
        final int hashCode;//real hash code (with fields args)
        final int[] forbidden;

        private Split(Tensor[] factors, Tensor summand) {
            this.factors = factors;
            this.summands.add(summand);
            Arrays.sort(factors);
            this.hashCode = Arrays.hashCode(factors);
            this.forbidden = IndicesUtils.getIndicesNames(new IndicesBuilder().append(factors).getIndices());
        }


        @Override
        public int hashCode() {
            return hashCode;
        }

        Tensor toTensor(Transformation[] transformations) {
            Tensor sum = Transformation.Util.applySequentially(
                    Tensors.sum(summands.toArray(new Tensor[summands.size()])),
                    transformations);
            Tensor[] ms = new Tensor[factors.length + 1];
            ms[ms.length - 1] = sum;
            System.arraycopy(factors, 0, ms, 0, factors.length);
            return Tensors.multiply(ms);
        }

        @Override
        public String toString() {
            return multiply(factors) + " : " + sum(summands.toArray(new Tensor[summands.size()]));
        }
    }

    static int[] matchFactors(final Tensor[] a, final Tensor[] b) {
        if (a.length != b.length) return null;
        int begin = 0, j, n, length = a.length;

        int[] permutation = new int[length];
        Arrays.fill(permutation, -1);

        for (int i = 1; i <= length; ++i) {
            if (i == length || a[i].hashCode() != b[i - 1].hashCode()) {
                if (i - 1 != begin) {
                    OUT:
                    for (n = begin; n < i; ++n) {
                        for (j = begin; j < i; ++j)
                            if (permutation[j] == -1 && matchSimpleTensors(a[n], b[j])) {
                                permutation[j] = n;
                                continue OUT;
                            }
                        return null;
                    }
                } else {
                    if (!matchSimpleTensors(a[i - 1], b[i - 1])) return null;
                    else permutation[i - 1] = i - 1;
                }
                begin = i;
            }
        }
        return Combinatorics.inverse(permutation);
    }

    private static boolean matchSimpleTensors(Tensor a, Tensor b) {
        if (a.getClass() != b.getClass()) return false;
        if (a.hashCode() != b.hashCode()) return false;
        if (TensorUtils.isPositiveIntegerPowerOfSimpleTensor(a))
            return TensorUtils.isPositiveIntegerPowerOfSimpleTensor(b)
                    && a.get(1).equals(b.get(1))
                    && matchSimpleTensors(a.get(0), b.get(0));
        if (a instanceof TensorField)
            for (int i = a.size() - 1; i >= 0; --i)
                if (!IndexMappings.positiveMappingExists(a.get(i), b.get(i))) return false;
        return true;
    }

    private static Tensor applyDirectMapping(Tensor t, DirectIndexMapping mapping) {
        if (t instanceof SimpleTensor) {
            SimpleTensor st = (SimpleTensor) t;
            SimpleIndices newIndices = st.getIndices().applyIndexMapping(mapping);
            if (t instanceof TensorField)
                return Tensors.field(st.getName(), newIndices, ((TensorField) st).getArgIndices(), ((TensorField) st).getArguments());
            else
                return Tensors.simpleTensor(st.getName(), newIndices);
        } else {
            assert t.getIndices().size() == 0;
            return t;
        }
    }


    private static abstract class DirectIndexMapping implements IndexMapping {
        final int[] from, to;

        private DirectIndexMapping(int[] from, int[] to) {
            ArraysUtils.quickSort(from, to);
            this.from = from;
            this.to = to;
        }

    }

    private static final class StateSensitiveMapping extends DirectIndexMapping {
        private StateSensitiveMapping(int[] from, int[] to) {
            super(from, to);
        }

        @Override
        public int map(int from) {
            int index;
            if ((index = Arrays.binarySearch(this.from, from)) >= 0)
                return to[index];
            return from;
        }
    }
}
// TOP
//
// Related Classes of cc.redberry.core.transformations.collect.CollectTransformation$Split
//
// TOP
// Copyright © 2018 www.massapi.com. All rights reserved.
// All source code is the property of its respective owners. Java is a trademark of Sun Microsystems, Inc., now owned by Oracle Inc. Contact coftware#gmail.com.