/*
 * Decompiled with CFR 0.152.
 */
package org.apache.datasketches.count;

import java.lang.foreign.MemorySegment;
import java.lang.foreign.ValueLayout;
import java.nio.charset.StandardCharsets;
import java.util.Random;
import org.apache.datasketches.common.Family;
import org.apache.datasketches.common.SketchesArgumentException;
import org.apache.datasketches.common.SketchesException;
import org.apache.datasketches.common.Util;
import org.apache.datasketches.common.positional.PositionalSegment;
import org.apache.datasketches.hash.MurmurHash3;

public class CountMinSketch {
    private final byte numHashes_;
    private final int numBuckets_;
    private final long seed_;
    private final long[] hashSeeds_;
    private final long[] sketchArray_;
    private long totalWeight_;
    private static final ThreadLocal<MemorySegment> LONG_SEGMENT = ThreadLocal.withInitial(() -> MemorySegment.ofArray(new byte[8]));

    CountMinSketch(byte numHashes, int numBuckets, long seed) {
        if (numHashes <= 0) {
            throw new SketchesArgumentException("Number of hash functions must be positive, got: " + (byte)numHashes);
        }
        if (numBuckets <= 0) {
            throw new SketchesArgumentException("Number of buckets must be positive, got: " + numBuckets);
        }
        if (numBuckets < 3) {
            throw new SketchesArgumentException("Number of buckets must be at least 3 to ensure relative error \u2264 1.0. With " + numBuckets + " buckets, relative error would be " + String.format("%.3f", Math.exp(1.0) / (double)numBuckets));
        }
        long totalSize = (long)numHashes * (long)numBuckets;
        if (totalSize > Integer.MAX_VALUE) {
            throw new SketchesArgumentException("Sketch array size would overflow: " + (byte)numHashes + " * " + numBuckets + " = " + totalSize + " > 2147483647");
        }
        if (totalSize >= 0x40000000L) {
            throw new SketchesArgumentException("Sketch would require excessive memory: " + (byte)numHashes + " * " + numBuckets + " = " + totalSize + " elements (~" + String.format("%d", totalSize * 8L / 0x40000000L) + " GB). Consider reducing numHashes or numBuckets.");
        }
        this.numHashes_ = numHashes;
        this.numBuckets_ = numBuckets;
        this.seed_ = seed;
        this.hashSeeds_ = new long[numHashes];
        this.sketchArray_ = new long[(int)totalSize];
        this.totalWeight_ = 0L;
        Random rand = new Random(seed);
        for (int i = 0; i < numHashes; ++i) {
            this.hashSeeds_[i] = rand.nextLong();
        }
    }

    private static byte[] longToBytes(long value) {
        MemorySegment segment = LONG_SEGMENT.get();
        segment.set(ValueLayout.JAVA_LONG_UNALIGNED, 0L, value);
        return segment.toArray(ValueLayout.JAVA_BYTE);
    }

    private long[] getHashes(byte[] item) {
        long[] updateLocations = new long[this.numHashes_];
        for (int i = 0; i < this.numHashes_; ++i) {
            long[] index = MurmurHash3.hash(item, this.hashSeeds_[i]);
            updateLocations[i] = (long)i * (long)this.numBuckets_ + (long)Math.floorMod(index[0], this.numBuckets_);
        }
        return updateLocations;
    }

    public boolean isEmpty() {
        return this.totalWeight_ == 0L;
    }

    public byte getNumHashes_() {
        return this.numHashes_;
    }

    public int getNumBuckets_() {
        return this.numBuckets_;
    }

    public long getSeed_() {
        return this.seed_;
    }

    public long getTotalWeight_() {
        return this.totalWeight_;
    }

    public double getRelativeError() {
        return Math.exp(1.0) / (double)this.numBuckets_;
    }

    public static byte suggestNumHashes(double confidence) {
        if (confidence < 0.0 || confidence > 1.0) {
            throw new SketchesException("Confidence must be between 0 and 1.0 (inclusive).");
        }
        int value = (int)Math.ceil(Math.log(1.0 / (1.0 - confidence)));
        return (byte)Math.min(value, 127);
    }

    public static int suggestNumBuckets(double relativeError) {
        if (relativeError < 0.0) {
            throw new SketchesException("Relative error must be at least 0.");
        }
        return (int)Math.ceil(Math.exp(1.0) / relativeError);
    }

    public void update(long item, long weight) {
        this.update(CountMinSketch.longToBytes(item), weight);
    }

    public void update(String item, long weight) {
        if (item == null || item.isEmpty()) {
            return;
        }
        byte[] strByte = item.getBytes(StandardCharsets.UTF_8);
        this.update(strByte, weight);
    }

    public void update(byte[] item, long weight) {
        long[] hashLocations;
        if (item.length == 0) {
            return;
        }
        this.totalWeight_ += weight > 0L ? weight : -weight;
        for (long h : hashLocations = this.getHashes(item)) {
            int n = (int)h;
            this.sketchArray_[n] = this.sketchArray_[n] + weight;
        }
    }

    public long getEstimate(long item) {
        return this.getEstimate(CountMinSketch.longToBytes(item));
    }

    public long getEstimate(String item) {
        if (item == null || item.isEmpty()) {
            return 0L;
        }
        byte[] strByte = item.getBytes(StandardCharsets.UTF_8);
        return this.getEstimate(strByte);
    }

    public long getEstimate(byte[] item) {
        if (item.length == 0) {
            return 0L;
        }
        long[] hashLocations = this.getHashes(item);
        long res = this.sketchArray_[(int)hashLocations[0]];
        for (int i = 1; i < hashLocations.length; ++i) {
            res = Math.min(res, this.sketchArray_[(int)hashLocations[i]]);
        }
        return res;
    }

    public long getUpperBound(long item) {
        return this.getUpperBound(CountMinSketch.longToBytes(item));
    }

    public long getUpperBound(String item) {
        if (item == null || item.isEmpty()) {
            return 0L;
        }
        byte[] strByte = item.getBytes(StandardCharsets.UTF_8);
        return this.getUpperBound(strByte);
    }

    public long getUpperBound(byte[] item) {
        if (item.length == 0) {
            return 0L;
        }
        return this.getEstimate(item) + (long)(this.getRelativeError() * (double)this.getTotalWeight_());
    }

    public long getLowerBound(long item) {
        return this.getLowerBound(CountMinSketch.longToBytes(item));
    }

    public long getLowerBound(String item) {
        if (item == null || item.isEmpty()) {
            return 0L;
        }
        byte[] strByte = item.getBytes(StandardCharsets.UTF_8);
        return this.getLowerBound(strByte);
    }

    public long getLowerBound(byte[] item) {
        return this.getEstimate(item);
    }

    public void merge(CountMinSketch other) {
        boolean acceptableConfig;
        if (this == other) {
            throw new SketchesException("Cannot merge a sketch with itself");
        }
        boolean bl = acceptableConfig = this.getNumBuckets_() == other.getNumBuckets_() && this.getNumHashes_() == other.getNumHashes_() && this.getSeed_() == other.getSeed_();
        if (!acceptableConfig) {
            throw new SketchesException("Incompatible sketch configuration.");
        }
        for (int i = 0; i < this.sketchArray_.length; ++i) {
            int n = i;
            this.sketchArray_[n] = this.sketchArray_[n] + other.sketchArray_[i];
        }
        this.totalWeight_ += other.getTotalWeight_();
    }

    private int getSerializedSizeBytes() {
        int preambleBytes = Family.COUNTMIN.getMinPreLongs() * 8;
        if (this.isEmpty()) {
            return preambleBytes;
        }
        return preambleBytes + 8 + this.sketchArray_.length * 8;
    }

    public byte[] toByteArray() {
        int serializedSizeBytes = this.getSerializedSizeBytes();
        byte[] bytes = new byte[serializedSizeBytes];
        PositionalSegment posSeg = PositionalSegment.wrap(MemorySegment.ofArray(bytes));
        int preambleLongs = Family.COUNTMIN.getMinPreLongs();
        posSeg.setByte((byte)preambleLongs);
        boolean serialVersion = true;
        posSeg.setByte((byte)1);
        int familyId = Family.COUNTMIN.getID();
        posSeg.setByte((byte)familyId);
        int flagsByte = this.isEmpty() ? Flag.IS_EMPTY.mask() : 0;
        posSeg.setByte((byte)flagsByte);
        boolean NULL_32 = false;
        posSeg.setInt(0);
        posSeg.setInt(this.numBuckets_);
        posSeg.setByte(this.numHashes_);
        short hashSeed = Util.computeSeedHash(this.seed_);
        posSeg.setShort(hashSeed);
        boolean NULL_8 = false;
        posSeg.setByte((byte)0);
        if (this.isEmpty()) {
            return bytes;
        }
        posSeg.setLong(this.totalWeight_);
        for (long w : this.sketchArray_) {
            posSeg.setLong(w);
        }
        return bytes;
    }

    public static CountMinSketch deserialize(byte[] b, long seed) {
        long w;
        boolean empty;
        PositionalSegment posSeg = PositionalSegment.wrap(MemorySegment.ofArray(b));
        byte preambleLongs = posSeg.getByte();
        byte serialVersion = posSeg.getByte();
        byte familyId = posSeg.getByte();
        byte flagsByte = posSeg.getByte();
        posSeg.getInt();
        int expectedPreambleLongs = Family.COUNTMIN.getMinPreLongs();
        if (preambleLongs != expectedPreambleLongs) {
            throw new SketchesArgumentException("Preamble longs mismatch: expected " + expectedPreambleLongs + ", actual " + preambleLongs);
        }
        boolean expectedSerialVersion = true;
        if (serialVersion != 1) {
            throw new SketchesArgumentException("Serial version mismatch: expected 1, actual " + serialVersion);
        }
        int expectedFamilyId = Family.COUNTMIN.getID();
        if (familyId != expectedFamilyId) {
            throw new SketchesArgumentException("Family ID mismatch: expected " + expectedFamilyId + ", actual " + familyId);
        }
        int numBuckets = posSeg.getInt();
        byte numHashes = posSeg.getByte();
        short seedHash = posSeg.getShort();
        posSeg.getByte();
        if (seedHash != Util.computeSeedHash(seed)) {
            throw new SketchesArgumentException("Incompatible seed hashes: " + seedHash + ", " + Util.computeSeedHash(seed));
        }
        CountMinSketch cms = new CountMinSketch(numHashes, numBuckets, seed);
        boolean bl = empty = (flagsByte & Flag.IS_EMPTY.mask()) > 0;
        if (empty) {
            return cms;
        }
        cms.totalWeight_ = w = posSeg.getLong();
        for (int i = 0; i < cms.sketchArray_.length; ++i) {
            cms.sketchArray_[i] = posSeg.getLong();
        }
        return cms;
    }

    private static enum Flag {
        IS_EMPTY;


        int mask() {
            return 1 << this.ordinal();
        }
    }
}

