/*
* Redberry: symbolic tensor computations.
*
* Copyright (c) 2010-2012:
* Stanislav Poslavsky <stvlpos@mail.ru>
* Bolotin Dmitriy <bolotin.dmitriy@gmail.com>
*
* This file is part of Redberry.
*
* Redberry is free software: you can redistribute it and/or modify
* it under the terms of the GNU General Public License as published by
* the Free Software Foundation, either version 3 of the License, or
* (at your option) any later version.
*
* Redberry is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU General Public License for more details.
*
* You should have received a copy of the GNU General Public License
* along with Redberry. If not, see <http://www.gnu.org/licenses/>.
*/
package cc.redberry.transformation.collect.old;
import java.util.ArrayList;
import java.util.List;
import java.util.Map;
import cc.redberry.core.context.CC;
import static cc.redberry.core.indices.IndicesUtils.*;
import cc.redberry.core.indexgenerator.IndexGenerator;
import cc.redberry.core.tensor.Product;
import cc.redberry.core.tensor.SimpleTensor;
import cc.redberry.core.tensor.Sum;
import cc.redberry.core.tensor.Tensor;
import cc.redberry.core.tensor.TensorIterator;
import cc.redberry.core.tensor.TensorNumber;
import cc.redberry.core.indexmapping.IndexMappingBufferRecord;
import cc.redberry.core.indexmapping.IndexMappingDirectAllowingUnmapped;
import cc.redberry.core.indexmapping.IndexMappingUtils;
import cc.redberry.core.indexmapping.IndexMappingBuffer;
import cc.redberry.core.tensor.testing.TTest;
import cc.redberry.core.transformations.ApplyIndexMappingDirectTransformation;
import cc.redberry.transformation.Transformation;
import cc.redberry.transformation.Transformations;
import cc.redberry.transformation.collect.RenameContractedIndices;
import cc.redberry.transformation.collect.SplitPattern;
import cc.redberry.transformation.contractions.UncontractIndices;
import cc.redberry.transformation.contractions.UncontractIndicesAndRename;
import cc.redberry.core.utils.IntArrayList;
import cc.redberry.core.utils.TensorUtils;
/**
 * Collects similar terms of a {@link Sum}.
 *
 * <p>Each summand is split by the supplied {@link SplitPattern} into a part that is factored out
 * and a remaining part (the "collected item"). A summand whose collected item has the same tensor
 * structure (as checked by {@code TTest.testEqualstensorStructure}) as an already collected item
 * is merged into that term. Each term is finally emitted as
 * {@code (f_1 + f_2 + ...) * collectedItem}. Index mismatches between merged summands are
 * compensated with Kronecker deltas. When the chosen index mapping carries a sign, a factor of
 * {@code -1} is added.
 *
 * <p>Note: splitting removes the factored-out multipliers from product summands via
 * {@code TensorIterator.remove()}, so the products of the input sum are modified. Instances
 * keep mutable, unsynchronized state and must not be shared between threads.
 *
 * @author Dmitry Bolotin
 * @author Stanislav Poslavsky
 */
public class CollectSimilarTerms implements Transformation {
    /** Pattern deciding which multipliers of a summand are factored out. */
    private final SplitPattern splitPatterns;
    /** Pattern used by {@link #split(Tensor)} during the current collection pass. */
    private SplitPattern currentPattern;
    /** Terms collected by the most recent call of {@link #generateCollectedTermsList(Sum)}. */
    final List<CollectedTerm> collectedTerms = new ArrayList<>();

    /**
     * @param splitPatterns pattern deciding which multipliers of each summand are factored out
     */
    public CollectSimilarTerms(SplitPattern splitPatterns) {
        this.splitPatterns = splitPatterns;
    }

    /**
     * Collects similar terms of {@code tensor}.
     *
     * @param tensor tensor to transform; only {@link Sum} instances are processed
     * @return the sum of collected terms (numbers simplified via
     *         {@code Transformations.calculateNumbers}), or {@code tensor} itself when it is not a sum
     */
    @Override
    public Tensor transform(Tensor tensor) {
        if (!(tensor instanceof Sum)) {
            return tensor;
        }
        generateCollectedTermsList((Sum) tensor);
        Sum result = new Sum();
        for (CollectedTerm collectedTerm : collectedTerms) {
            Tensor term = Transformations.calculateNumbers(collectedTerm.result());
            result.add(term.equivalent());
        }
        return result.equivalent();
    }

    /**
     * Rebuilds {@link #collectedTerms} from the summands of {@code tensor}.
     *
     * @param tensor sum whose summands are grouped; product summands are modified while splitting
     */
    void generateCollectedTermsList(Sum tensor) {
        // Start from scratch: otherwise a second transform() call on the same instance would
        // merge into, and emit, the terms collected from the previously processed sum.
        collectedTerms.clear();
        currentPattern = splitPatterns;
        for (Tensor current : tensor) {
            Split split = split(current);
            CollectedTerm similar = null;
            for (CollectedTerm candidate : collectedTerms) {
                if (TTest.testEqualstensorStructure(candidate.collectedItem, split.collectedTerm)) {
                    similar = candidate;
                    break;
                }
            }
            if (similar != null) {
                mergeInto(similar, split);
                continue;
            }
            // No structurally equal item yet: start a new collected term.
            UncontractIndices uncontractIndicesTransformation =
                    new UncontractIndices(TensorUtils.getAllIndicesNames(split.factoredOut));
            split.collectedTerm = uncontractIndicesTransformation.renameIndicesAndBuidKroneckers(split.collectedTerm);
            split.factoredOut.add(uncontractIndicesTransformation.getKroneckers());
            collectedTerms.add(new CollectedTerm(split.collectedTerm, split.factoredOut));
        }
    }

    /**
     * Merges {@code split} into {@code collectedTerm}, whose collected item has the same tensor
     * structure as {@code split.collectedTerm}.
     *
     * <p>Indices of the incoming term are renamed and mapped onto the collected item. Index pairs
     * that cannot simply be renamed are bridged with Kronecker deltas, added to the incoming
     * factor and/or to every factor already collected. The collected item itself may be
     * re-indexed in place.
     */
    private static void mergeInto(CollectedTerm collectedTerm, Split split) {
        UncontractIndicesAndRename uncontractIndicesTransformation =
                new UncontractIndicesAndRename(TensorUtils.getAllIndicesNames(split.factoredOut),
                        collectedTerm.getCollectedTermIndicesNames().toArray());
        split.collectedTerm = uncontractIndicesTransformation.renameIndicesAndBuidKroneckers(split.collectedTerm);
        List<Tensor> generatedKroneckers = uncontractIndicesTransformation.getKroneckers();
        split.factoredOut.add(generatedKroneckers);
        // Rename contracted indices of the incoming term so they do not clash with the
        // index names already used by the collected factors.
        new RenameContractedIndices(collectedTerm.getCollectedFactorsIndicesNames().toArray())
                .transform(split.collectedTerm, split.factoredOut);

        Tensor toCollect = split.collectedTerm;
        Tensor collectedTensor = collectedTerm.collectedItem;
        List<IndexMappingBuffer> buffers = IndexMappingUtils.createAllMappings(toCollect, collectedTensor, false);
        IntArrayList uncontractedIndicesToCollect =
                TensorUtils.getContractedIndicesNames(split.collectedTerm, split.factoredOut);
        IntArrayList uncontractedIndicesCollected =
                TensorUtils.getContractedIndicesNames(collectedTerm.collectedItem,
                        collectedTerm.collectedFactors.getElements().get(0));
        IndexMappingBuffer bestBuffer =
                findBestMapping(buffers, uncontractedIndicesToCollect, uncontractedIndicesCollected);

        IndexMappingDirectAllowingUnmapped mappingToCollect = new IndexMappingDirectAllowingUnmapped();
        IndexMappingDirectAllowingUnmapped mappingCollectedTensor = new IndexMappingDirectAllowingUnmapped();
        // Generator of fresh index names that are not used by either side.
        IntArrayList usedIndices = new IntArrayList(TensorUtils.getAllIndicesNames(split.factoredOut));
        usedIndices.addAll(collectedTerm.getCollectedFactorsIndicesNames());
        usedIndices.addAll(collectedTerm.getCollectedTermIndicesNames());
        IndexGenerator ig = new IndexGenerator(usedIndices.toArray());

        List<Tensor> kroneckersCollected = new ArrayList<>();
        for (Map.Entry<Integer, IndexMappingBufferRecord> entry : bestBuffer.getMap().entrySet()) {
            // Raw state bit (bit 31) derived from the lowest bit of the record's states.
            int rawState = ((entry.getValue().getStates() & 1) ^ 1) << 31;
            int indexCollected = rawState | entry.getValue().getIndexName();
            int indexToCollect = rawState | entry.getKey();
            if (indexToCollect == indexCollected) {
                continue;
            }
            if (uncontractedIndicesToCollect.contains(getNameWithType(indexToCollect))
                    && uncontractedIndicesCollected.contains(getNameWithType(indexCollected))) {
                // Both indices may be renamed freely: rename inside the incoming factor.
                mappingToCollect.add(inverseIndexState(indexToCollect), inverseIndexState(indexCollected));
            } else if (uncontractedIndicesToCollect.contains(getNameWithType(indexToCollect))) {
                SimpleTensor kroneckerCollected = CC.createKronecker(inverseIndexState(indexToCollect), indexCollected);
                collectedTerm.getCollectedFactorsIndicesNames().add(getNameWithType(indexCollected));
                collectedTerm.getCollectedFactorsIndicesNames().add(getNameWithType(indexToCollect));
                collectedTerm.getCollectedTermIndicesNames()
                        .replaceFirst(getNameWithType(indexCollected), getNameWithType(indexToCollect));
                mappingCollectedTensor.add(indexCollected, indexToCollect);
                kroneckersCollected.add(kroneckerCollected);
            } else if (uncontractedIndicesCollected.contains(getNameWithType(indexCollected))) {
                SimpleTensor kroneckerToCollect = CC.createKronecker(indexToCollect, inverseIndexState(indexCollected));
                split.factoredOut.add(kroneckerToCollect);
            } else {
                // Neither index may be renamed: route both through a fresh index.
                int newIndex = ig.generate(getType(indexToCollect));
                collectedTerm.getCollectedFactorsIndicesNames().add(getNameWithType(indexCollected));
                collectedTerm.getCollectedFactorsIndicesNames().add(newIndex);
                collectedTerm.getCollectedTermIndicesNames().replaceFirst(getNameWithType(indexCollected), newIndex);
                newIndex = getRawStateInt(indexToCollect) | newIndex;
                SimpleTensor kroneckerCollected = CC.createKronecker(indexCollected, inverseIndexState(newIndex));
                SimpleTensor kroneckerToCollect = CC.createKronecker(indexToCollect, inverseIndexState(newIndex));
                mappingCollectedTensor.add(indexCollected, newIndex);
                kroneckersCollected.add(kroneckerCollected);
                split.factoredOut.add(kroneckerToCollect);
            }
        }

        if (!kroneckersCollected.isEmpty()) {
            for (Tensor p : collectedTerm.collectedFactors) {
                ((Product) p).add(kroneckersCollected);
            }
        }
        if (bestBuffer.getSignum()) {
            split.factoredOut.addFirst(TensorNumber.createMINUSONE());
        }
        if (!mappingToCollect.isEmpty()) {
            ApplyIndexMappingDirectTransformation.INSTANCE.perform(split.factoredOut, mappingToCollect);
        }
        collectedTerm.collectedFactors.add(split.factoredOut);
        collectedTerm.getCollectedFactorsIndicesNames().addAll(TensorUtils.getAllIndicesNames(split.factoredOut));
        if (!mappingCollectedTensor.isEmpty()) {
            ApplyIndexMappingDirectTransformation.INSTANCE.perform(collectedTensor, mappingCollectedTensor);
        }
    }

    /**
     * Chooses, among {@code buffers}, the mapping with the largest number of index pairs that are
     * either identical or present in both "uncontracted" lists. The code in {@code mergeInto}
     * bridges all other pairs with Kronecker deltas. On a tie, a mapping without sign is
     * preferred over one with sign, avoiding a spurious {@code -1} factor. This preference was
     * the evident intent of the original {@code sign} flag, which was never assigned.
     *
     * @throws IllegalStateException if {@code buffers} is empty
     */
    private static IndexMappingBuffer findBestMapping(List<IndexMappingBuffer> buffers,
                                                      IntArrayList uncontractedIndicesToCollect,
                                                      IntArrayList uncontractedIndicesCollected) {
        // Bottleneck for very large sums: every mapping is scored.
        IndexMappingBuffer best = null;
        int bestConcurrence = 0;
        for (IndexMappingBuffer buffer : buffers) {
            int concurrence = 0;
            for (Map.Entry<Integer, IndexMappingBufferRecord> entry : buffer.getMap().entrySet()) {
                int indexToCollect = entry.getKey();
                int indexCollected = entry.getValue().getIndexName();
                if (indexToCollect == indexCollected
                        || (uncontractedIndicesToCollect.contains(indexToCollect)
                        && uncontractedIndicesCollected.contains(indexCollected))) {
                    concurrence++;
                }
            }
            if (best == null
                    || concurrence > bestConcurrence
                    || (concurrence == bestConcurrence && best.getSignum() && !buffer.getSignum())) {
                best = buffer;
                bestConcurrence = concurrence;
            }
        }
        if (best == null) {
            throw new IllegalStateException("No index mapping found between structurally equal terms.");
        }
        return best;
    }

    /**
     * Splits a summand into the part to collect and the factored-out multipliers.
     *
     * <p>For a product, multipliers accepted by the current pattern are removed from it
     * (via {@code TensorIterator.remove()}) and become the factored-out part.
     * An empty remainder or an empty factored-out part is replaced by {@code 1}.
     *
     * @throws UnsupportedOperationException if {@code tensor} is neither a product nor a simple tensor
     */
    private Split split(Tensor tensor) {
        if (tensor instanceof Product) {
            TensorIterator it = tensor.iterator();
            List<Tensor> factored = new ArrayList<>();
            while (it.hasNext()) {
                Tensor current = it.next();
                if (currentPattern.factorOut(current)) {
                    factored.add(current);
                    it.remove();
                }
            }
            if (factored.isEmpty()) {
                factored.add(TensorNumber.createONE());
            }
            if (((Product) tensor).isEmpty()) {
                return new Split(TensorNumber.createONE(), factored);
            }
            return new Split(tensor.equivalent(), factored);
        }
        if (tensor instanceof SimpleTensor) {
            List<Tensor> factoredOut = new ArrayList<>();
            if (currentPattern.factorOut(tensor)) {
                factoredOut.add(tensor);
                return new Split(TensorNumber.createONE(), factoredOut);
            }
            factoredOut.add(TensorNumber.createONE());
            return new Split(tensor, factoredOut);
        }
        throw new UnsupportedOperationException("Cannot split tensor of type " + tensor.getClass().getName());
    }

    /** Result of {@link #split(Tensor)}: the term to collect and the product of factored-out multipliers. */
    private static class Split {
        Tensor collectedTerm;
        Product factoredOut;

        public Split(Tensor collectingTerm, List<Tensor> factoredOut) {
            this.collectedTerm = collectingTerm;
            this.factoredOut = new Product(factoredOut);
        }
    }

    /**
     * A group of summands sharing one collected item; it stands for
     * {@code collectedFactors * collectedItem}.
     */
    static class CollectedTerm {
        final Tensor collectedItem;
        final Sum collectedFactors = new Sum();
        /** Index names of {@link #collectedItem}. */
        private final IntArrayList collectedTermIndicesNames;
        /** Index names used anywhere in {@link #collectedFactors} (kept without repeats on add). */
        private final IntArrayBuffer collectedFactorsIndicesNames;

        public CollectedTerm(Tensor collectedTerm, Tensor factoredOut) {
            this.collectedItem = collectedTerm;
            collectedFactors.add(factoredOut);
            collectedFactorsIndicesNames = new IntArrayBuffer(TensorUtils.getAllIndicesNames(factoredOut));
            collectedTermIndicesNames = new IntArrayList(TensorUtils.getAllIndicesNames(collectedTerm));
        }

        /** @return {@code (sum of collected factors) * collectedItem} as a new product */
        Tensor result() {
            return new Product(collectedFactors.equivalent(), collectedItem.equivalent());
        }

        IntArrayList getCollectedFactorsIndicesNames() {
            return collectedFactorsIndicesNames;
        }

        IntArrayList getCollectedTermIndicesNames() {
            return collectedTermIndicesNames;
        }
    }

    /**
     * {@link IntArrayList} whose {@code add(int)} and {@code addAll(int[])} skip values that are
     * already present. Values passed to the constructor are stored as given.
     */
    private static class IntArrayBuffer extends IntArrayList {
        public IntArrayBuffer(int[] data) {
            super(data);
        }

        @Override
        public void add(int num) {
            if (!contains(num)) {
                super.add(num);
            }
        }

        /** Adds each value of {@code arr} not yet present; repeated values within {@code arr} are added once. */
        @Override
        public void addAll(int[] arr) {
            for (int value : arr) {
                add(value);
            }
        }
    }
}