/*
 * Decompiled with CFR 0.152.
 */
package org.apache.sysds.runtime.compress.colgroup;

import org.apache.sysds.runtime.compress.DMLCompressionException;
import org.apache.sysds.runtime.compress.colgroup.AColGroup;
import org.apache.sysds.runtime.compress.colgroup.AColGroupCompressed;
import org.apache.sysds.runtime.compress.colgroup.AColGroupValue;
import org.apache.sysds.runtime.compress.colgroup.ColGroupUtils;
import org.apache.sysds.runtime.compress.colgroup.dictionary.ADictionary;
import org.apache.sysds.runtime.compress.colgroup.indexes.IColIndex;
import org.apache.sysds.runtime.compress.lib.CLALibLeftMultBy;
import org.apache.sysds.runtime.compress.lib.CLALibTSMM;
import org.apache.sysds.runtime.data.DenseBlock;
import org.apache.sysds.runtime.data.SparseBlock;
import org.apache.sysds.runtime.data.SparseRow;
import org.apache.sysds.runtime.matrix.data.MatrixBlock;

/**
 * Base class for value column groups that can be split into a common vector plus a remaining column group via
 * {@link #extractCommon(double[])}.
 * <p>
 * The decompression and matrix-multiplication overrides in this class are fallbacks (they log a warning when used).
 * Each one does the following:
 * <ul>
 * <li>splits this group into the extracted group and a common vector;</li>
 * <li>delegates the bulk of the work to the extracted group;</li>
 * <li>adds the contribution of the common vector separately.</li>
 * </ul>
 * Conceptually the group is treated as {@code B + 1 cv^T}, where {@code B} is the extracted group and {@code cv} the
 * common vector.
 */
public abstract class AMorphingMMColGroup extends AColGroupValue {
    private static final long serialVersionUID = -4265713396790607199L;

    /**
     * Constructs the group; all arguments are forwarded unchanged to {@link AColGroupValue}.
     *
     * @param colIndices   the column indexes covered by this group
     * @param dict         the dictionary of this group
     * @param cachedCounts cached counts, may be forwarded as-is
     */
    protected AMorphingMMColGroup(IColIndex colIndices, ADictionary dict, int[] cachedCounts) {
        super(colIndices, dict, cachedCounts);
    }

    @Override
    protected final void decompressToDenseBlockSparseDictionary(DenseBlock db, int rl, int ru, int offR, int offC,
        SparseBlock sb) {
        decompressToDenseBlockViaCommon(db, rl, ru, offR, offC);
    }

    @Override
    protected final void decompressToDenseBlockDenseDictionary(DenseBlock db, int rl, int ru, int offR, int offC,
        double[] values) {
        decompressToDenseBlockViaCommon(db, rl, ru, offR, offC);
    }

    /**
     * Fallback dense decompression for both dictionary representations.
     * <p>
     * Splits this group via {@link #extractCommon(double[])} and decompresses the extracted group into {@code db}.
     * It then adds the common vector to every row in {@code [rl, ru)}. The dictionary argument of the calling
     * override is not used.
     */
    private void decompressToDenseBlockViaCommon(DenseBlock db, int rl, int ru, int offR, int offC) {
        LOG.warn("Should never call decompress on morphing group instead extract common values and combine all commons");
        // As wide as the output, so extractCommon can fill it by absolute column index.
        final double[] cv = new double[db.getDim(1)];
        final AColGroup b = extractCommon(cv);
        b.decompressToDenseBlock(db, rl, ru, offR, offC);
        decompressToDenseBlockCommonVector(db, rl, ru, offR, offC, cv);
    }

    /**
     * Adds the common values of this group's columns to rows {@code [rl, ru)} of {@code db}, shifted by {@code offR}
     * rows and {@code offC} columns.
     *
     * @param common vector indexed by absolute column index, as filled by {@link #extractCommon(double[])}
     */
    private void decompressToDenseBlockCommonVector(DenseBlock db, int rl, int ru, int offR, int offC,
        double[] common) {
        final int nCol = _colIndexes.size();
        for (int i = rl, offT = rl + offR; i < ru; i++, offT++) {
            final double[] c = db.values(offT);
            final int off = db.pos(offT) + offC;
            for (int j = 0; j < nCol; j++) {
                final int col = _colIndexes.get(j);
                // common is indexed by absolute column, not by the position j inside this group
                c[off + col] += common[col];
            }
        }
    }

    @Override
    protected final void decompressToSparseBlockSparseDictionary(SparseBlock ret, int rl, int ru, int offR, int offC,
        SparseBlock sb) {
        decompressToSparseBlockViaCommon(ret, rl, ru, offR, offC);
    }

    @Override
    protected final void decompressToSparseBlockDenseDictionary(SparseBlock ret, int rl, int ru, int offR, int offC,
        double[] values) {
        decompressToSparseBlockViaCommon(ret, rl, ru, offR, offC);
    }

    /**
     * Fallback sparse decompression for both dictionary representations.
     * <p>
     * Splits this group via {@link #extractCommon(double[])} and decompresses the extracted group into {@code ret}.
     * It then adds the non-zero common values to every row in {@code [rl, ru)}. The dictionary argument of the
     * calling override is not used.
     */
    private void decompressToSparseBlockViaCommon(SparseBlock ret, int rl, int ru, int offR, int offC) {
        LOG.warn("Should never call decompress on morphing group instead extract common values and combine all commons");
        // Must reach the largest column index of this group.
        // NOTE(review): assumes _colIndexes is sorted ascending, so the last index is the largest.
        final double[] cv = new double[_colIndexes.get(_colIndexes.size() - 1) + 1];
        final AColGroup b = extractCommon(cv);
        b.decompressToSparseBlock(ret, rl, ru, offR, offC);
        decompressToSparseBlockCommonVector(ret, rl, ru, offR, offC, cv);
    }

    /**
     * Adds the non-zero common values of this group's columns to rows {@code [rl, ru)} of {@code sb}, shifted by
     * {@code offR} rows and {@code offC} columns, then compacts each touched row.
     *
     * @param common vector indexed by absolute column index, as filled by {@link #extractCommon(double[])}
     */
    private void decompressToSparseBlockCommonVector(SparseBlock sb, int rl, int ru, int offR, int offC,
        double[] common) {
        final int nCol = _colIndexes.size();
        for (int i = rl, offT = rl + offR; i < ru; i++, offT++) {
            for (int j = 0; j < nCol; j++) {
                final int col = _colIndexes.get(j);
                final double v = common[col];
                if (v != 0.0) {
                    sb.add(offT, col + offC, v);
                }
            }
            // Presumably removes entries that cancelled to (near) zero after adding the common values.
            final SparseRow sr = sb.get(offT);
            if (sr != null) {
                sr.compact(1.0E-20);
            }
        }
    }

    /**
     * Fallback left multiplication {@code matrix %*% this}, restricted to rows {@code [rl, ru)} and columns
     * {@code [cl, cu)} of {@code matrix}.
     * <p>
     * Uses {@code A (B + 1 cv^T) = A B + rowSums(A) cv^T}: the extracted group handles {@code A B}, and the outer
     * product of the row sums with the common vector is added to {@code result}.
     */
    @Override
    public final void leftMultByMatrixNoPreAgg(MatrixBlock matrix, MatrixBlock result, int rl, int ru, int cl, int cu) {
        LOG.warn("Should never call leftMultByMatrixNoPreAgg on morphing group");
        final double[] cv = new double[result.getNumColumns()];
        final AColGroup b = extractCommon(cv);
        b.leftMultByMatrixNoPreAgg(matrix, result, rl, ru, cl, cu);
        final boolean allColumns = cl == 0 && cu == matrix.getNumColumns();
        final double[] rowSum = allColumns
            ? matrix.rowSum().getDenseBlockValues()
            : CLALibLeftMultBy.rowSum(matrix, rl, ru, cl, cu);
        ColGroupUtils.outerProduct(rowSum, cv, result.getDenseBlockValues(), rl, ru);
    }

    /**
     * Fallback left multiplication by another column group.
     * <p>
     * Delegates to the extracted group, then adds the outer product of the column sums of {@code lhs} with the
     * common vector to {@code result}.
     * NOTE(review): presumably {@code result = t(lhs) %*% this}, so the column sums of {@code lhs} act as row sums.
     */
    @Override
    public final void leftMultByAColGroup(AColGroup lhs, MatrixBlock result, int nRows) {
        LOG.warn("Should never call leftMultByAColGroup on morphing group");
        final double[] cv = new double[result.getNumColumns()];
        final AColGroup b = extractCommon(cv);
        b.leftMultByAColGroup(lhs, result, nRows);
        final double[] rowSum = new double[result.getNumRows()];
        lhs.computeColSums(rowSum, nRows);
        ColGroupUtils.outerProduct(rowSum, cv, result.getDenseBlockValues(), 0, result.getNumRows());
    }

    /** Not supported on morphing groups; always throws. */
    @Override
    public final void tsmmAColGroup(AColGroup other, MatrixBlock result) {
        throw new DMLCompressionException("Should not be called tsmm on morphing");
    }

    /**
     * Fallback transpose-self multiplication.
     * <p>
     * Computes {@code t(B) B} through the extracted group. It then lets {@link CLALibTSMM#addCorrectionLayer} add
     * the terms that involve the common vector, given the column sums of {@code B}.
     */
    @Override
    protected final void tsmm(double[] result, int numColumns, int nRows) {
        LOG.warn("tsmm should not be called directly on a morphing column group");
        final double[] cv = new double[numColumns];
        // tsmm and computeColSums are invoked through the compressed base type
        final AColGroupCompressed b = (AColGroupCompressed) extractCommon(cv);
        b.tsmm(result, numColumns, nRows);
        final double[] colSum = new double[numColumns];
        b.computeColSums(colSum, nRows);
        CLALibTSMM.addCorrectionLayer(cv, colSum, nRows, result);
    }

    /** Morphing groups always keep every candidate output column of a dense right multiplication. */
    @Override
    protected IColIndex rightMMGetColsDense(double[] b, int nCols, IColIndex allCols, long nnz) {
        return allCols;
    }

    /** Morphing groups always keep every candidate output column of a sparse right multiplication. */
    @Override
    protected IColIndex rightMMGetColsSparse(SparseBlock b, int nCols, IColIndex allCols) {
        return allCols;
    }

    /**
     * Allocates the result group of a right multiplication.
     * <p>
     * Multiplies the common values (see {@link #getCommon()}) with the rows of {@code right} selected by this
     * group's column indexes. The resulting vector, one entry per column of {@code right}, is passed on to
     * {@link #allocateRightMultiplicationCommon}.
     */
    @Override
    protected AColGroup allocateRightMultiplication(MatrixBlock right, IColIndex colIndexes, ADictionary preAgg) {
        LOG.warn("right mm should not be called directly on a morphing column group");
        final double[] common = getCommon();
        final int rc = right.getNumColumns();
        final double[] commonMultiplied = new double[rc];
        final int lc = _colIndexes.size();
        if (right.isInSparseFormat()) {
            final SparseBlock sb = right.getSparseBlock();
            for (int r = 0; r < lc; r++) {
                final int of = _colIndexes.get(r);
                if (sb.isEmpty(of)) {
                    continue;
                }
                final int apos = sb.pos(of);
                final int alen = sb.size(of) + apos;
                final int[] aix = sb.indexes(of);
                final double[] avals = sb.values(of);
                final double v = common[r];
                // each non-zero contributes to its own column aix[j], not to the row's first column
                for (int j = apos; j < alen; j++) {
                    commonMultiplied[aix[j]] += v * avals[j];
                }
            }
        } else {
            final double[] rV = right.getDenseBlockValues();
            for (int r = 0; r < lc; r++) {
                final int rOff = rc * _colIndexes.get(r);
                final double v = common[r];
                for (int c = 0; c < rc; c++) {
                    commonMultiplied[c] += v * rV[rOff + c];
                }
            }
        }
        return allocateRightMultiplicationCommon(commonMultiplied, colIndexes, preAgg);
    }

    /**
     * Allocates the right-multiplication result group.
     *
     * @param common     the common values multiplied with the right-hand side, one entry per output column
     * @param colIndexes the output column indexes
     * @param preAgg     the pre-aggregated dictionary of the multiplication
     * @return the result group
     */
    protected abstract AColGroup allocateRightMultiplicationCommon(double[] common, IColIndex colIndexes,
        ADictionary preAgg);

    /**
     * Extracts this group's common values into {@code constV} and returns the remaining group.
     * <p>
     * Callers in this class allocate {@code constV} zeroed and at least as wide as the largest column index, and
     * read it by absolute column index.
     *
     * @param constV output vector receiving the common values
     * @return the group that, combined with the common values, represents this group
     */
    public abstract AColGroup extractCommon(double[] constV);

    /**
     * Returns the common values of this group, indexed by position within {@code _colIndexes}.
     *
     * @return the common values
     */
    public abstract double[] getCommon();
}

