/*
 * Decompiled with CFR 0.152.
 */
package org.apache.sysds.hops.ipa;

import java.util.ArrayList;
import java.util.List;
import org.apache.sysds.common.Types;
import org.apache.sysds.hops.FunctionOp;
import org.apache.sysds.hops.Hop;
import org.apache.sysds.hops.HopsException;
import org.apache.sysds.hops.LiteralOp;
import org.apache.sysds.hops.ipa.FunctionCallGraph;
import org.apache.sysds.hops.ipa.FunctionCallSizeInfo;
import org.apache.sysds.hops.ipa.IPAPass;
import org.apache.sysds.hops.recompile.Recompiler;
import org.apache.sysds.hops.rewrite.HopRewriteUtils;
import org.apache.sysds.parser.DMLProgram;
import org.apache.sysds.parser.DataIdentifier;
import org.apache.sysds.parser.ForStatement;
import org.apache.sysds.parser.ForStatementBlock;
import org.apache.sysds.parser.FunctionStatement;
import org.apache.sysds.parser.FunctionStatementBlock;
import org.apache.sysds.parser.IfStatement;
import org.apache.sysds.parser.IfStatementBlock;
import org.apache.sysds.parser.StatementBlock;
import org.apache.sysds.parser.WhileStatement;
import org.apache.sysds.parser.WhileStatementBlock;
import org.apache.sysds.runtime.controlprogram.LocalVariableMap;
import org.apache.sysds.runtime.instructions.cp.ScalarObjectFactory;

/**
 * IPA pass that propagates scalar literals and replaces variable reads with those literals.
 * <p>
 * The pass works in two phases:
 * <ol>
 * <li>Main program: literals assigned via transient writes in last-level statement blocks are
 * propagated into subsequent statement blocks until the variable is updated again.</li>
 * <li>Reachable functions with at least one call: input parameters reported as safe literals by
 * {@link FunctionCallSizeInfo} are replaced by the literal passed at the first call site, and
 * function-local literal assignments are propagated across the function body exactly as in
 * phase 1.</li>
 * </ol>
 * HOP DAGs are modified in place.
 */
public class IPAPassPropagateReplaceLiterals
extends IPAPass {
    /**
     * This pass is always applicable.
     */
    @Override
    public boolean isApplicable(FunctionCallGraph fgraph) {
        return true;
    }

    /**
     * Runs both propagation phases over the given program.
     *
     * @param prog the DML program whose HOP DAGs are rewritten in place
     * @param fgraph function call graph providing reachable functions and their call sites
     * @param fcallSizes per-function information about which inputs are safe literals
     * @return always {@code false}
     */
    @Override
    public boolean rewriteProgram(DMLProgram prog, FunctionCallGraph fgraph, FunctionCallSizeInfo fcallSizes) {
        // phase 1: propagate literal assignments across the main program
        this.rReplaceLiterals(prog.getStatementBlocks());

        // phase 2: propagate literals into and across reachable functions
        for (String fkey : fgraph.getReachableFunctions()) {
            List<FunctionOp> flist = fgraph.getFunctionCalls(fkey);
            // robustness: functions without remaining call sites are left untouched
            if (flist.isEmpty()) {
                continue;
            }
            FunctionStatementBlock fsb = prog.getFunctionStatementBlock(fkey);
            if (fsb == null) {
                continue;
            }
            FunctionStatement fstmt = (FunctionStatement)fsb.getStatement(0);

            // replace safe literal inputs (taken from the first call site) inside the body;
            // presumably isSafeLiteral guarantees all call sites agree on the value
            if (fcallSizes.hasSafeLiterals(fkey)) {
                LocalVariableMap callVars = collectSafeLiteralInputs(fkey, flist.get(0), fstmt, fcallSizes);
                for (StatementBlock sb : fstmt.getBody()) {
                    this.rReplaceLiterals(sb, callVars);
                }
            }

            // function-local literal assignments are independent of call-site literals, so
            // they are propagated for every reachable function, like the main program
            this.rReplaceLiterals(fstmt.getBody());
        }
        return false;
    }

    /**
     * Builds a variable map containing the function inputs that are safe to replace by literals.
     *
     * @param fkey function key used to query {@code fcallSizes}
     * @param call the call site whose literal inputs are used as values
     * @param fstmt the function statement providing the input parameter list
     * @param fcallSizes per-function safe-literal information
     * @return map from input variable name to the scalar literal value
     */
    private static LocalVariableMap collectSafeLiteralInputs(String fkey, FunctionOp call,
        FunctionStatement fstmt, FunctionCallSizeInfo fcallSizes) {
        ArrayList<DataIdentifier> finputs = fstmt.getInputParams();
        String[] inputNames = call.getInputVariableNames();
        LocalVariableMap callVars = new LocalVariableMap();
        for (int j = 0; j < finputs.size(); ++j) {
            if (!fcallSizes.isSafeLiteral(fkey, j)) {
                continue;
            }
            LiteralOp lit = (LiteralOp)call.getInput().get(j);
            // prefer the call's explicit input names, fall back to the declared parameter name
            String varname = inputNames != null ? inputNames[j] : finputs.get(j).getName();
            callVars.put(varname, ScalarObjectFactory.createScalarObject(lit.getValueType(), lit));
        }
        return callVars;
    }

    /**
     * Propagates literal assignments across a sequence of statement blocks.
     * <p>
     * Literals written via transient writes in a last-level block become available to all
     * subsequent blocks of {@code sbs} until a block updates that variable.
     *
     * @param sbs the statement blocks to process, in program order
     */
    private void rReplaceLiterals(List<StatementBlock> sbs) {
        LocalVariableMap constants = new LocalVariableMap();
        for (StatementBlock sb : sbs) {
            // drops constants updated in sb, then replaces remaining constant reads in sb
            this.rReplaceLiterals(sb, constants);
            if (!HopRewriteUtils.isLastLevelStatementBlock(sb)) {
                continue;
            }
            // collect literal assignments of this block for the following blocks
            for (Hop root : sb.getHops()) {
                if (HopRewriteUtils.isData(root, Types.OpOpData.TRANSIENTWRITE)
                    && root.getInput().get(0) instanceof LiteralOp) {
                    constants.put(root.getName(),
                        ScalarObjectFactory.createScalarObject((LiteralOp)root.getInput().get(0)));
                }
            }
        }
    }

    /**
     * Replaces reads of variables in {@code constants} with their literal values inside
     * {@code sb}, recursing into while, if/else and for bodies.
     * <p>
     * All variables updated anywhere within {@code sb} are first removed from
     * {@code constants}; the map is therefore mutated by this call.
     *
     * @param sb the statement block to rewrite
     * @param constants variable-to-literal map (mutated)
     */
    private void rReplaceLiterals(StatementBlock sb, LocalVariableMap constants) {
        constants.removeAllIn(sb.variablesUpdated().getVariableNames());
        if (sb instanceof WhileStatementBlock) {
            WhileStatementBlock wsb = (WhileStatementBlock)sb;
            WhileStatement ws = (WhileStatement)sb.getStatement(0);
            replaceLiterals(wsb.getPredicateHops(), constants);
            for (StatementBlock current : ws.getBody()) {
                this.rReplaceLiterals(current, constants);
            }
        } else if (sb instanceof IfStatementBlock) {
            IfStatementBlock isb = (IfStatementBlock)sb;
            IfStatement ifs = (IfStatement)sb.getStatement(0);
            replaceLiterals(isb.getPredicateHops(), constants);
            for (StatementBlock current : ifs.getIfBody()) {
                this.rReplaceLiterals(current, constants);
            }
            for (StatementBlock current : ifs.getElseBody()) {
                this.rReplaceLiterals(current, constants);
            }
        } else if (sb instanceof ForStatementBlock) {
            ForStatementBlock fsb = (ForStatementBlock)sb;
            ForStatement fs = (ForStatement)sb.getStatement(0);
            replaceLiterals(fsb.getFromHops(), constants);
            replaceLiterals(fsb.getToHops(), constants);
            replaceLiterals(fsb.getIncrementHops(), constants);
            for (StatementBlock current : fs.getBody()) {
                this.rReplaceLiterals(current, constants);
            }
        } else {
            replaceLiterals(sb.getHops(), constants);
        }
    }

    /**
     * Replaces constant variable reads in every DAG rooted at {@code roots}.
     *
     * @param roots DAG roots, may be {@code null} (no-op)
     * @param constants variable-to-literal map
     * @throws HopsException wrapping any failure raised during replacement
     */
    private static void replaceLiterals(ArrayList<Hop> roots, LocalVariableMap constants) {
        if (roots == null) {
            return;
        }
        try {
            Hop.resetVisitStatus(roots);
            for (Hop root : roots) {
                Recompiler.rReplaceLiterals(root, constants, true);
            }
            Hop.resetVisitStatus(roots);
        }
        catch (Exception ex) {
            throw new HopsException(ex);
        }
    }

    /**
     * Replaces constant variable reads in the single DAG rooted at {@code root}
     * (e.g., a predicate or loop-bound expression).
     *
     * @param root DAG root, may be {@code null} (no-op)
     * @param constants variable-to-literal map
     * @throws HopsException wrapping any failure raised during replacement
     */
    private static void replaceLiterals(Hop root, LocalVariableMap constants) {
        if (root == null) {
            return;
        }
        try {
            root.resetVisitStatus();
            Recompiler.rReplaceLiterals(root, constants, true);
            root.resetVisitStatus();
        }
        catch (Exception ex) {
            throw new HopsException(ex);
        }
    }
}

