Package jadx.core.dex.visitors

Source Code of jadx.core.dex.visitors.SimplifyVisitor

package jadx.core.dex.visitors;

import jadx.core.Consts;
import jadx.core.dex.info.FieldInfo;
import jadx.core.dex.info.MethodInfo;
import jadx.core.dex.instructions.ArithNode;
import jadx.core.dex.instructions.ArithOp;
import jadx.core.dex.instructions.IfNode;
import jadx.core.dex.instructions.IndexInsnNode;
import jadx.core.dex.instructions.InsnType;
import jadx.core.dex.instructions.InvokeNode;
import jadx.core.dex.instructions.args.ArgType;
import jadx.core.dex.instructions.args.FieldArg;
import jadx.core.dex.instructions.args.InsnArg;
import jadx.core.dex.instructions.args.InsnWrapArg;
import jadx.core.dex.instructions.args.LiteralArg;
import jadx.core.dex.instructions.mods.ConstructorInsn;
import jadx.core.dex.instructions.mods.TernaryInsn;
import jadx.core.dex.nodes.BlockNode;
import jadx.core.dex.nodes.InsnNode;
import jadx.core.dex.nodes.MethodNode;
import jadx.core.dex.regions.conditions.IfCondition;

import java.util.ArrayList;
import java.util.Collections;
import java.util.List;

import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

public class SimplifyVisitor extends AbstractVisitor {

  private static final Logger LOG = LoggerFactory.getLogger(SimplifyVisitor.class);

  @Override
  public void visit(MethodNode mth) {
    if (mth.isNoCode()) {
      return;
    }
    for (BlockNode block : mth.getBasicBlocks()) {
      List<InsnNode> list = block.getInstructions();
      for (int i = 0; i < list.size(); i++) {
        InsnNode modInsn = simplifyInsn(mth, list.get(i));
        if (modInsn != null) {
          list.set(i, modInsn);
        }
      }
    }
  }

  private static InsnNode simplifyInsn(MethodNode mth, InsnNode insn) {
    for (InsnArg arg : insn.getArguments()) {
      if (arg.isInsnWrap()) {
        InsnNode ni = simplifyInsn(mth, ((InsnWrapArg) arg).getWrapInsn());
        if (ni != null) {
          arg.wrapInstruction(ni);
        }
      }
    }
    switch (insn.getType()) {
      case ARITH:
        return simplifyArith(insn);

      case IF:
        simplifyIf((IfNode) insn);
        break;
      case TERNARY:
        simplifyTernary((TernaryInsn) insn);
        break;

      case INVOKE:
        return convertInvoke(mth, insn);

      case IPUT:
      case SPUT:
        return convertFieldArith(mth, insn);

      case CHECK_CAST:
        InsnArg castArg = insn.getArg(0);
        ArgType castType = (ArgType) ((IndexInsnNode) insn).getIndex();
        if (!ArgType.isCastNeeded(castArg.getType(), castType)) {
          InsnNode insnNode = new InsnNode(InsnType.MOVE, 1);
          insnNode.setResult(insn.getResult());
          insnNode.addArg(castArg);
          return insnNode;
        }
        break;

      default:
        break;
    }
    return null;
  }

  /**
   * Simplify 'cmp' instruction in if condition
   */
  private static void simplifyIf(IfNode insn) {
    InsnArg f = insn.getArg(0);
    if (f.isInsnWrap()) {
      InsnNode wi = ((InsnWrapArg) f).getWrapInsn();
      if (wi.getType() == InsnType.CMP_L || wi.getType() == InsnType.CMP_G) {
        if (insn.getArg(1).isLiteral()
            && ((LiteralArg) insn.getArg(1)).getLiteral() == 0) {
          insn.changeCondition(insn.getOp(), wi.getArg(0), wi.getArg(1));
        } else {
          LOG.warn("TODO: cmp {}", insn);
        }
      }
    }
  }

  /**
   * Simplify condition in ternary operation
   */
  private static void simplifyTernary(TernaryInsn insn) {
    IfCondition condition = insn.getCondition();
    if (condition.isCompare()) {
      simplifyIf(condition.getCompare().getInsn());
    } else {
      insn.simplifyCondition();
    }
  }

  private static InsnNode convertInvoke(MethodNode mth, InsnNode insn) {
    MethodInfo callMth = ((InvokeNode) insn).getCallMth();
    if (callMth.getDeclClass().getFullName().equals(Consts.CLASS_STRING_BUILDER)
        && callMth.getShortId().equals(Consts.MTH_TOSTRING_SIGNATURE)
        && insn.getArg(0).isInsnWrap()) {
      try {
        List<InsnNode> chain = flattenInsnChain(insn);
        if (chain.size() > 1 && chain.get(0).getType() == InsnType.CONSTRUCTOR) {
          ConstructorInsn constr = (ConstructorInsn) chain.get(0);
          if (constr.getClassType().getFullName().equals(Consts.CLASS_STRING_BUILDER)
              && constr.getArgsCount() == 0) {
            int len = chain.size();
            InsnNode concatInsn = new InsnNode(InsnType.STR_CONCAT, len - 1);
            for (int i = 1; i < len; i++) {
              concatInsn.addArg(chain.get(i).getArg(1));
            }
            concatInsn.setResult(insn.getResult());
            return concatInsn;
          }
        }
      } catch (Throwable e) {
        LOG.debug("Can't convert string concatenation: {} insn: {}", mth, insn, e);
      }
    }
    return null;
  }

  private static InsnNode simplifyArith(InsnNode insn) {
    ArithNode arith = (ArithNode) insn;
    if (arith.getArgsCount() != 2) {
      return null;
    }
    InsnArg litArg = null;
    InsnArg secondArg = arith.getArg(1);
    if (secondArg.isInsnWrap()) {
      InsnNode wr = ((InsnWrapArg) secondArg).getWrapInsn();
      if (wr.getType() == InsnType.CONST) {
        litArg = wr.getArg(0);
      }
    } else if (secondArg.isLiteral()) {
      litArg = secondArg;
    }
    if (litArg != null) {
      long lit = ((LiteralArg) litArg).getLiteral();
      // fix 'c + (-1)' => 'c - (1)'
      if (arith.getOp() == ArithOp.ADD && lit < 0) {
        return new ArithNode(ArithOp.SUB,
            arith.getResult(), insn.getArg(0),
            InsnArg.lit(-lit, litArg.getType()));
      }
    }
    return null;
  }

  /**
   * Convert field arith operation to arith instruction
   * (IPUT = ARITH (IGET, lit) -> ARITH (fieldArg <op>= lit))
   */
  private static InsnNode convertFieldArith(MethodNode mth, InsnNode insn) {
    InsnArg arg = insn.getArg(0);
    if (!arg.isInsnWrap()) {
      return null;
    }
    InsnNode wrap = ((InsnWrapArg) arg).getWrapInsn();
    InsnType wrapType = wrap.getType();
    if ((wrapType != InsnType.ARITH && wrapType != InsnType.STR_CONCAT)
        || !wrap.getArg(0).isInsnWrap()) {
      return null;
    }
    InsnNode get = ((InsnWrapArg) wrap.getArg(0)).getWrapInsn();
    InsnType getType = get.getType();
    if (getType != InsnType.IGET && getType != InsnType.SGET) {
      return null;
    }
    FieldInfo field = (FieldInfo) ((IndexInsnNode) insn).getIndex();
    FieldInfo innerField = (FieldInfo) ((IndexInsnNode) get).getIndex();
    if (!field.equals(innerField)) {
      return null;
    }
    try {
      InsnArg reg = null;
      if (getType == InsnType.IGET) {
        reg = get.getArg(0);
      }
      FieldArg fArg = new FieldArg(field, reg);
      if (reg != null) {
        fArg.setType(get.getArg(0).getType());
      }
      if (wrapType == InsnType.ARITH) {
        ArithNode ar = (ArithNode) wrap;
        return new ArithNode(ar.getOp(), fArg, ar.getArg(1));
      } else {
        int argsCount = wrap.getArgsCount();
        InsnNode concat = new InsnNode(InsnType.STR_CONCAT, argsCount - 1);
        for (int i = 1; i < argsCount; i++) {
          concat.addArg(wrap.getArg(i));
        }
        return new ArithNode(ArithOp.ADD, fArg, InsnArg.wrapArg(concat));
      }
    } catch (Throwable e) {
      LOG.debug("Can't convert field arith insn: {}, mth: {}", insn, mth, e);
    }
    return null;
  }

  private static List<InsnNode> flattenInsnChain(InsnNode insn) {
    List<InsnNode> chain = new ArrayList<InsnNode>();
    InsnArg i = insn.getArg(0);
    while (i.isInsnWrap()) {
      InsnNode wrapInsn = ((InsnWrapArg) i).getWrapInsn();
      chain.add(wrapInsn);
      if (wrapInsn.getArgsCount() == 0) {
        break;
      }

      i = wrapInsn.getArg(0);
    }
    Collections.reverse(chain);
    return chain;
  }
}
TOP

Related Classes of jadx.core.dex.visitors.SimplifyVisitor

TOP
Copyright © 2018 www.massapi.com. All rights reserved.
All source code are property of their respective owners. Java is a trademark of Sun Microsystems, Inc and owned by ORACLE Inc. Contact coftware#gmail.com.