Package cc.redberry.core.transformations

Source Code of cc.redberry.core.transformations.DifferentiateTransformation$SymmetricDifferentiationRule

/*
* Redberry: symbolic tensor computations.
*
* Copyright (c) 2010-2013:
*   Stanislav Poslavsky   <stvlpos@mail.ru>
*   Bolotin Dmitriy       <bolotin.dmitriy@gmail.com>
*
* This file is part of Redberry.
*
* Redberry is free software: you can redistribute it and/or modify
* it under the terms of the GNU General Public License as published by
* the Free Software Foundation, either version 3 of the License, or
* (at your option) any later version.
*
* Redberry is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
* GNU General Public License for more details.
*
* You should have received a copy of the GNU General Public License
* along with Redberry. If not, see <http://www.gnu.org/licenses/>.
*/
package cc.redberry.core.transformations;

import cc.redberry.core.indexgenerator.IndexGenerator;
import cc.redberry.core.indexmapping.Mapping;
import cc.redberry.core.indices.IndicesFactory;
import cc.redberry.core.indices.IndicesUtils;
import cc.redberry.core.indices.SimpleIndices;
import cc.redberry.core.number.Complex;
import cc.redberry.core.tensor.*;
import cc.redberry.core.tensor.functions.ScalarFunction;
import cc.redberry.core.transformations.substitutions.SubstitutionTransformation;
import cc.redberry.core.transformations.symmetrization.SymmetrizeSimpleTensorTransformation;
import cc.redberry.core.utils.TensorUtils;
import gnu.trove.set.hash.TIntHashSet;

import static cc.redberry.core.indices.IndicesUtils.*;
import static cc.redberry.core.tensor.ApplyIndexMapping.*;
import static cc.redberry.core.tensor.Tensors.*;
import static cc.redberry.core.utils.ArraysUtils.addAll;

/**
* Differentiates specified tensor with respect to specified simple tensors.
* It temporary does not support derivatives of tensor fields.
*
* @author Dmitry Bolotin
* @author Stanislav Poslavsky
* @since 1.0
*/
public final class DifferentiateTransformation implements Transformation {

    private final SimpleTensor[] vars;
    private final Transformation[] expandAndContract;

    /**
     * Creates transformations which differentiate with respect to specified simple tensors.
     *
     * @param vars
     */
    public DifferentiateTransformation(SimpleTensor... vars) {
        this.vars = vars;
        this.expandAndContract = new Transformation[0];
    }

    public DifferentiateTransformation(Transformation[] expandAndContract, SimpleTensor... vars) {
        this.vars = vars;
        this.expandAndContract = expandAndContract;
    }

    @Override
    public Tensor transform(Tensor t) {
        return differentiate(t, expandAndContract, vars);
    }

    /**
     * Gives the multiple derivative of specified order of specified tensor with respect to specified simple tensor.
     *
     * @param tensor tensor to be differentiated
     * @param var    simple tensor
     * @param order  order of derivative
     * @return derivative
     * @throws IllegalArgumentException if both order is not one and var is not scalar.
     */
    public static Tensor differentiate(Tensor tensor, SimpleTensor var, int order) {
        if (var.getIndices().size() != 0 && order > 1)
            throw new IllegalArgumentException();
        for (; order > 0; --order)
            tensor = differentiate(tensor, new Transformation[0], var);
        return tensor;
    }

    /**
     * Gives the multiple derivative of specified tensor with respect to specified arguments.
     *
     * @param tensor tensor to be differentiated
     * @param vars   arguments
     * @return derivative
     * @throws IllegalArgumentException if there is clash of indices
     */
    public static Tensor differentiate(Tensor tensor, SimpleTensor... vars) {
        if (vars.length == 0)
            return tensor;
        if (vars.length == 1)
            return differentiate(tensor, new Transformation[0], vars[0]);
        return differentiate(tensor, new Transformation[0], vars);
    }

    /**
     * Gives the multiple derivative of specified tensor with respect to specified arguments.
     *
     * @param tensor            tensor to be differentiated
     * @param vars              arguments
     * @param expandAndContract additional transformations to be applied after each step of differentiation
     * @return derivative
     * @throws IllegalArgumentException if there is clash of indices
     */
    public static Tensor differentiate(Tensor tensor, Transformation[] expandAndContract, SimpleTensor... vars) {
        if (vars.length == 0)
            return tensor;
        if (vars.length == 1)
            return differentiate(tensor, expandAndContract, vars[0]);

        boolean needRename = false;
        for (SimpleTensor var : vars)
            if (var.getIndices().size() != 0) {
                needRename = true;
                break;
            }

        SimpleTensor[] resolvedVars = vars;
        if (needRename) {
            TIntHashSet forbidden = TensorUtils.getAllIndicesNamesT(tensor);
            for (SimpleTensor var : vars)
                forbidden.addAll(getIndicesNames(var.getIndices().getFree()));

            resolvedVars = vars.clone();
            for (int i = 0; i < vars.length; ++i)
                if (!forbidden.isEmpty() && resolvedVars[i].getIndices().size() != 0) {
                    if (resolvedVars[i].getIndices().size() != resolvedVars[i].getIndices().getFree().size())
                        resolvedVars[i] = (SimpleTensor) renameDummy(resolvedVars[i], forbidden.toArray());
                    forbidden.addAll(getIndicesNames(resolvedVars[i].getIndices()));
                }
            tensor = renameDummy(tensor, TensorUtils.getAllIndicesNamesT(resolvedVars).toArray(), forbidden);
            tensor = renameIndicesOfFieldsArguments(tensor, forbidden);
        }

        for (SimpleTensor var : resolvedVars)
            tensor = differentiate1(tensor, createRule(var), expandAndContract);

        return tensor;
    }

    private static Tensor differentiate(Tensor tensor, Transformation[] expandAndContract, SimpleTensor var) {
        if (var.getIndices().size() != 0) {
            TIntHashSet forbidden = TensorUtils.getAllIndicesNamesT(tensor);
            var = (SimpleTensor) renameDummy(var, TensorUtils.getAllIndicesNamesT(tensor).toArray());
            forbidden.addAll(IndicesUtils.getIndicesNames(var.getIndices()));
            tensor = renameDummy(tensor, TensorUtils.getAllIndicesNamesT(var).toArray(), forbidden);
            tensor = renameIndicesOfFieldsArguments(tensor, forbidden);
        }
        return differentiate1(tensor, createRule(var), expandAndContract);
    }

    private static Tensor differentiateWithRenaming(Tensor tensor, SimpleTensorDifferentiationRule rule, Transformation[] expandAndContarct) {
        SimpleTensorDifferentiationRule newRule = rule.newRuleForTensor(tensor);
        tensor = renameDummy(tensor, newRule.getForbidden());
        return differentiate1(tensor, newRule, expandAndContarct);
    }

    private static Tensor differentiate1(Tensor tensor, SimpleTensorDifferentiationRule rule, Transformation[] transformations) {
        if (tensor.getClass() == SimpleTensor.class) {
            Tensor temp = rule.differentiateSimpleTensor((SimpleTensor) tensor);
            return applyTransformations(temp, transformations);
        }
        if (tensor.getClass() == TensorField.class) {
            TensorField field = (TensorField) tensor;
            SumBuilder result = new SumBuilder(tensor.size());
            Tensor dArg;

            for (int i = tensor.size() - 1; i >= 0; --i) {
                dArg = differentiate1(field.get(i), rule, transformations);
                if (TensorUtils.isZero(dArg)) continue;

                result.put(
                        multiply(dArg,
                                fieldDerivative(field, field.getArgIndices(i).getInverted(), i))
                );

            }
            return applyTransformations(EliminateMetricsTransformation.eliminate(result.build()), transformations);
        }
        if (tensor instanceof Sum) {
            SumBuilder builder = new SumBuilder();
            Tensor temp;
            for (Tensor t : tensor) {
                temp = differentiate1(t, rule, transformations);
                temp = applyTransformations(temp, transformations);
                builder.put(temp);
            }
            return builder.build();
        }
        if (tensor instanceof ScalarFunction) {
            Tensor temp = multiply(((ScalarFunction) tensor).derivative(),
                    differentiateWithRenaming(tensor.get(0), rule, transformations));
            temp = applyTransformations(temp, transformations);
            return temp;
        }
        if (tensor instanceof Power) {
            //e^f*ln(g) -> g^f*(f'*ln(g)+f/g*g') ->f*g^(f-1)*g' + g^f*ln(g)*f'
            Tensor temp = sum(
                    multiply(tensor.get(1),
                            pow(tensor.get(0), sum(tensor.get(1), Complex.MINUS_ONE)),
                            differentiate1(tensor.get(0), rule, transformations)),
                    multiply(tensor,
                            log(tensor.get(0)),
                            differentiateWithRenaming(tensor.get(1), rule, transformations)));
            temp = applyTransformations(temp, transformations);
            return temp;
        }
        if (tensor instanceof Product) {
            SumBuilder result = new SumBuilder();
            Tensor temp;
            for (int i = tensor.size() - 1; i >= 0; --i) {
                temp = tensor.set(i, differentiate1(tensor.get(i), rule, transformations));
                if (rule.var.getIndices().size() != 0)
                    temp = EliminateMetricsTransformation.eliminate(temp);
                temp = applyTransformations(temp, transformations);
                result.put(temp);
            }
            return result.build();
        }
        if (tensor instanceof Complex)
            return Complex.ZERO;
        throw new UnsupportedOperationException();
    }

    private static Tensor applyTransformations(Tensor tensor, Transformation[] transformations) {
        for (Transformation transformation : transformations)
            tensor = transformation.transform(tensor);
        return tensor;
    }

    private static SimpleTensorDifferentiationRule createRule(SimpleTensor var) {
        if (var.getIndices().size() == 0)
            return new SymbolicDifferentiationRule(var);
        return new SymmetricDifferentiationRule(var);
    }

    private static abstract class SimpleTensorDifferentiationRule {

        protected final SimpleTensor var;

        protected SimpleTensorDifferentiationRule(SimpleTensor var) {
            this.var = var;
        }

        Tensor differentiateSimpleTensor(SimpleTensor simpleTensor) {
            if (simpleTensor.getName() != var.getName())
                return Complex.ZERO;
            return differentiateSimpleTensorWithoutCheck(simpleTensor);
        }

        abstract SimpleTensorDifferentiationRule newRuleForTensor(Tensor tensor);

        abstract Tensor differentiateSimpleTensorWithoutCheck(SimpleTensor simpleTensor);

        abstract int[] getForbidden();
    }

    private static final class SymbolicDifferentiationRule extends SimpleTensorDifferentiationRule {

        private SymbolicDifferentiationRule(SimpleTensor var) {
            super(var);
        }

        @Override
        Tensor differentiateSimpleTensorWithoutCheck(SimpleTensor simpleTensor) {
            return Complex.ONE;
        }

        @Override
        SimpleTensorDifferentiationRule newRuleForTensor(Tensor tensor) {
            return this;
        }

        @Override
        int[] getForbidden() {
            return new int[0];
        }
    }

    private static final class SymmetricDifferentiationRule extends SimpleTensorDifferentiationRule {

        private final Tensor derivative;
        private final int[] allFreeFrom, freeVarIndices;

        private SymmetricDifferentiationRule(SimpleTensor var, Tensor derivative, int[] allFreeFrom, int[] freeVarIndices) {
            super(var);
            this.derivative = derivative;
            this.allFreeFrom = allFreeFrom;
            this.freeVarIndices = freeVarIndices;
        }

        SymmetricDifferentiationRule(SimpleTensor var) {
            super(var);
            SimpleIndices varIndices = var.getIndices();
            int[] allFreeVarIndices = new int[varIndices.size()];
            int[] allFreeArgIndices = new int[varIndices.size()];
            byte type;
            int state, i = 0, length = allFreeArgIndices.length;
            IndexGenerator indexGenerator = new IndexGenerator(varIndices);
            for (; i < length; ++i) {
                type = getType(varIndices.get(i));
                state = getRawStateInt(varIndices.get(i));
                allFreeVarIndices[i] = setRawState(indexGenerator.generate(type), inverseIndexState(state));
                allFreeArgIndices[i] = setRawState(indexGenerator.generate(type), state);
            }
            int[] allIndices = addAll(allFreeVarIndices, allFreeArgIndices);
            SimpleIndices dIndices = IndicesFactory.createSimple(null, allIndices);
            SimpleTensor symmetric = simpleTensor("@!@#@##_AS@23@@#", dIndices);
            Tensor derivative = SymmetrizeSimpleTensorTransformation.symmetrize(
                    symmetric,
                    allFreeVarIndices,
                    varIndices.getSymmetries().getInnerSymmetries());
            derivative = applyIndexMapping(
                    derivative,
                    new Mapping(allIndices,
                            addAll(varIndices.getInverted().getAllIndices().copy(), allFreeArgIndices)),
                    new int[0]);
            ProductBuilder builder = new ProductBuilder(0, length);
            for (i = 0; i < length; ++i)
                builder.put(createMetricOrKronecker(allFreeArgIndices[i], allFreeVarIndices[i]));
            derivative = new SubstitutionTransformation(symmetric, builder.build()).transform(derivative);
            this.derivative = derivative;
            this.freeVarIndices = var.getIndices().getFree().getInverted().getAllIndices().copy();
            this.allFreeFrom = addAll(allFreeArgIndices, freeVarIndices);
        }

        @Override
        Tensor differentiateSimpleTensorWithoutCheck(SimpleTensor simpleTensor) {
            int[] to = simpleTensor.getIndices().getAllIndices().copy();
            to = addAll(to, freeVarIndices);
            return applyIndexMapping(derivative, new Mapping(allFreeFrom, to), new int[0]);
        }

        @Override
        SimpleTensorDifferentiationRule newRuleForTensor(Tensor tensor) {
            return new SymmetricDifferentiationRule(this.var,
                    renameDummy(derivative, TensorUtils.getAllIndicesNamesT(tensor).toArray()), allFreeFrom, freeVarIndices);
        }

        @Override
        int[] getForbidden() {
            return TensorUtils.getAllIndicesNamesT(derivative).toArray();
        }
    }
}
TOP

Related Classes of cc.redberry.core.transformations.DifferentiateTransformation$SymmetricDifferentiationRule

TOP
Copyright © 2018 www.massapi.com. All rights reserved.
All source code are property of their respective owners. Java is a trademark of Sun Microsystems, Inc and owned by ORACLE Inc. Contact coftware#gmail.com.