/*
 * Decompiled with CFR 0.152.
 */
package org.apache.sysds.runtime.compress.cocode;

import java.util.ArrayList;
import java.util.List;
import java.util.concurrent.Callable;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Future;
import org.apache.sysds.runtime.compress.CompressionSettings;
import org.apache.sysds.runtime.compress.DMLCompressionException;
import org.apache.sysds.runtime.compress.cocode.AColumnCoCoder;
import org.apache.sysds.runtime.compress.cocode.ColIndexes;
import org.apache.sysds.runtime.compress.cocode.Memorizer;
import org.apache.sysds.runtime.compress.colgroup.indexes.IColIndex;
import org.apache.sysds.runtime.compress.cost.ACostEstimate;
import org.apache.sysds.runtime.compress.estim.AComEst;
import org.apache.sysds.runtime.compress.estim.CompressedSizeInfo;
import org.apache.sysds.runtime.compress.estim.CompressedSizeInfoColGroup;
import org.apache.sysds.runtime.util.CommonThreadPool;

/**
 * Greedy column co-coder.
 * <p>
 * Starting from the given column groups, every round evaluates all pairwise joins of the current work set and
 * applies the join with the most negative change in estimated cost (ties broken by fewer columns). Rounds repeat
 * until no join reduces the cost or a single group remains. Size estimates of groups and joins are cached in a
 * {@link Memorizer}.
 */
public class CoCodeGreedy extends AColumnCoCoder {
    /** Cache of size estimates for column-index combinations evaluated during co-coding. */
    private final Memorizer mem;

    /**
     * Create a greedy co-coder with a fresh memorizer backed by the given size estimator.
     *
     * @param sizeEstimator estimator used for compressed size information
     * @param costEstimator estimator used to turn size information into a cost
     * @param cs            compression settings
     */
    protected CoCodeGreedy(AComEst sizeEstimator, ACostEstimate costEstimator, CompressionSettings cs) {
        super(sizeEstimator, costEstimator, cs);
        this.mem = new Memorizer(sizeEstimator);
    }

    /**
     * Create a greedy co-coder that reuses an existing memorizer.
     *
     * @param sizeEstimator estimator used for compressed size information
     * @param costEstimator estimator used to turn size information into a cost
     * @param cs            compression settings
     * @param mem           memorizer to read from and populate
     */
    protected CoCodeGreedy(AComEst sizeEstimator, ACostEstimate costEstimator, CompressionSettings cs, Memorizer mem) {
        super(sizeEstimator, costEstimator, cs);
        this.mem = mem;
    }

    @Override
    protected CompressedSizeInfo coCodeColumns(CompressedSizeInfo colInfos, int k) {
        colInfos.setInfo(combine(colInfos.compressionInfo, k));
        return colInfos;
    }

    /**
     * Register all input groups in the memorizer and greedily combine them.
     *
     * @param inputColumns the starting column groups
     * @param k            degree of parallelism; values greater than one pre-compute first-level joins in parallel
     * @return the resulting (possibly combined) column groups
     */
    protected List<CompressedSizeInfoColGroup> combine(List<CompressedSizeInfoColGroup> inputColumns, int k) {
        for (CompressedSizeInfoColGroup g : inputColumns) {
            mem.put(g);
        }
        return coCodeBruteForce(inputColumns, k);
    }

    /**
     * Greedy all-pairs merge loop.
     *
     * @param inputColumns the starting column groups (already registered in the memorizer)
     * @param k            degree of parallelism for the first-level join pre-computation
     * @return the column groups remaining in the work set after no further join improves the cost
     */
    private List<CompressedSizeInfoColGroup> coCodeBruteForce(List<CompressedSizeInfoColGroup> inputColumns, int k) {
        final List<ColIndexes> workSet = new ArrayList<>(inputColumns.size());
        for (CompressedSizeInfoColGroup g : inputColumns) {
            workSet.add(new ColIndexes(g.getColumns()));
        }

        if (k > 1) {
            // The pool is only used to pre-compute all first-level joins, so acquire it here and always
            // release it, even when the parallel combine fails.
            final ExecutorService pool = CommonThreadPool.get(k);
            try {
                parallelFirstCombine(workSet, pool);
            }
            finally {
                pool.shutdown();
            }
        }

        // Runner-up join remembered from the previous round. It seeds the next round's best candidate so the
        // pruning bound is tight from the start.
        double secondChange = 0.0;
        CompressedSizeInfoColGroup secondTmp = null;
        ColIndexes secondSelectedJ = null;
        ColIndexes secondSelected1 = null;
        ColIndexes secondSelected2 = null;

        while (workSet.size() > 1) {
            if (secondChange != 0.0) {
                mem.incst4();
            }

            // Best candidate of this round, seeded with the remembered runner-up (if any).
            double changeInCost = secondChange;
            CompressedSizeInfoColGroup tmp = secondTmp;
            ColIndexes selectedJ = secondSelectedJ;
            ColIndexes selected1 = secondSelected1;
            ColIndexes selected2 = secondSelected2;

            for (int i = 0; i < workSet.size(); i++) {
                for (int j = i + 1; j < workSet.size(); j++) {
                    final ColIndexes c1 = workSet.get(i);
                    final ColIndexes c2 = workSet.get(j);
                    final double costC1 = _cest.getCost(mem.get(c1));
                    final double costC2 = _cest.getCost(mem.get(c2));
                    mem.incst1();

                    // Pruning: the join's change is cost(c1c2) - costC1 - costC2. Assuming a joined group never
                    // costs less than its more expensive part, that change is at least -min(costC1, costC2), so
                    // this pair cannot beat the current best and is skipped without estimating the join.
                    if (-Math.min(costC1, costC2) > changeInCost) {
                        continue;
                    }

                    final IColIndex combined = c1._indexes.combine(c2._indexes);
                    final ColIndexes cI = new ColIndexes(combined);
                    final CompressedSizeInfoColGroup c1c2Inf = mem.getOrCreate(cI, c1, c2);
                    final double newCostIfJoined = _cest.getCost(c1c2Inf) - costC1 - costC2;

                    // Only joins that strictly reduce the cost are candidates (written this way to also reject NaN).
                    if (!(newCostIfJoined < 0.0)) {
                        continue;
                    }

                    if (tmp == null) {
                        // First improving join found in this round.
                        changeInCost = newCostIfJoined;
                        tmp = c1c2Inf;
                        selectedJ = cI;
                        selected1 = c1;
                        selected2 = c2;
                        continue;
                    }

                    final boolean better = newCostIfJoined < changeInCost
                        || (newCostIfJoined == changeInCost
                            && c1c2Inf.getColumns().size() < tmp.getColumns().size());
                    if (!better) {
                        continue;
                    }

                    // Demote the current best to runner-up, unless it is the remembered runner-up itself.
                    if (selected1 != secondSelected1 && selected2 != secondSelected2) {
                        secondTmp = tmp;
                        secondSelectedJ = selectedJ;
                        secondSelected1 = selected1;
                        secondSelected2 = selected2;
                        secondChange = changeInCost;
                    }
                    changeInCost = newCostIfJoined;
                    tmp = c1c2Inf;
                    selectedJ = cI;
                    selected1 = c1;
                    selected2 = c2;
                }
            }

            if (tmp == null) {
                // No join reduces the cost any further.
                break;
            }

            // Apply the best join: replace its two inputs with the combined group.
            workSet.remove(selected1);
            workSet.remove(selected2);
            mem.remove(selected1, selected2);
            mem.put(selectedJ, tmp);
            workSet.add(selectedJ);

            // Forget the runner-up when the new group "contains" it (per ColIndexes.contains); presumably its
            // inputs were just consumed by this join — confirm against ColIndexes.
            if (selectedJ.contains(secondSelected1, secondSelected2)) {
                secondTmp = null;
                secondSelectedJ = null;
                secondSelected1 = null;
                secondSelected2 = null;
                secondChange = 0.0;
            }
        }

        if (LOG.isDebugEnabled()) {
            LOG.debug("Memorizer stats:" + mem.stats());
        }
        mem.resetStats();

        final List<CompressedSizeInfoColGroup> ret = new ArrayList<>(workSet.size());
        for (ColIndexes w : workSet) {
            ret.add(mem.get(w));
        }
        return ret;
    }

    /**
     * Estimate every pairwise join of the given work set in parallel, storing the results in the memorizer.
     * <p>
     * NOTE(review): this calls {@code Memorizer.getOrCreate} from multiple threads; it presumably must be
     * thread-safe — confirm.
     *
     * @param workSet the current column groups
     * @param pool    executor used to run the join estimations
     * @throws DMLCompressionException if any estimation fails or the calling thread is interrupted
     */
    protected void parallelFirstCombine(List<ColIndexes> workSet, ExecutorService pool) {
        try {
            final List<CombineTask> tasks = new ArrayList<>();
            final int size = workSet.size();
            for (int i = 0; i < size; i++) {
                for (int j = i + 1; j < size; j++) {
                    tasks.add(new CombineTask(workSet.get(i), workSet.get(j)));
                }
            }
            for (Future<Object> t : pool.invokeAll(tasks)) {
                // Surface any failure raised inside a task.
                t.get();
            }
        }
        catch (InterruptedException e) {
            // Restore the interrupt flag so callers further up can still observe the interruption.
            Thread.currentThread().interrupt();
            throw new DMLCompressionException("Failed parallelize first level all join all", e);
        }
        catch (Exception e) {
            throw new DMLCompressionException("Failed parallelize first level all join all", e);
        }
    }

    /** Task that estimates the join of two column groups and caches it in the enclosing co-coder's memorizer. */
    protected class CombineTask implements Callable<Object> {
        private final ColIndexes _c1;
        private final ColIndexes _c2;

        /**
         * @param c1 first column group of the join
         * @param c2 second column group of the join
         */
        protected CombineTask(ColIndexes c1, ColIndexes c2) {
            this._c1 = c1;
            this._c2 = c2;
        }

        /**
         * Compute (or fetch) the memorized estimate of the join of both groups.
         *
         * @return always {@code null}; the result is kept only in the memorizer
         */
        @Override
        public Object call() {
            final IColIndex combined = _c1._indexes.combine(_c2._indexes);
            final ColIndexes cI = new ColIndexes(combined);
            CoCodeGreedy.this.mem.getOrCreate(cI, _c1, _c2);
            return null;
        }
    }
}

