/*
 * Decompiled with CFR 0.152.
 */
package org.opensearch.timeseries;

import com.amazon.randomcutforest.RandomCutForest;
import com.amazon.randomcutforest.parkservices.RCFCaster;
import com.amazon.randomcutforest.parkservices.ThresholdedRandomCutForest;
import java.util.EnumMap;
import java.util.Locale;
import java.util.Map;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.opensearch.ad.settings.AnomalyDetectorSettings;
import org.opensearch.cluster.service.ClusterService;
import org.opensearch.monitor.jvm.JvmService;
import org.opensearch.timeseries.breaker.CircuitBreakerService;
import org.opensearch.timeseries.common.exception.LimitExceededException;

/**
 * Tracks the memory attributed to models (per {@link Origin}) and decides whether further
 * allocations fit under a heap limit derived from the JVM max heap size.
 *
 * <p>Two running counters are kept. {@code totalMemoryBytes} counts every consumed byte.
 * {@code reservedMemoryBytes} counts only bytes consumed with {@code reserved == true}. Each
 * counter also has a per-origin breakdown. All methods that read or mutate these counters are
 * {@code synchronized} on this instance.
 */
public class MemoryTracker {
    private static final Logger LOG = LogManager.getLogger(MemoryTracker.class);

    /** Value assigned to {@link #thresholdModelBytes} by the constructor. */
    private static final int DEFAULT_THRESHOLD_MODEL_BYTES = 180_000;

    /** Largest shingle size supported by {@link #estimateTRCFModelSize(int, int, double, int, int)}. */
    private static final int MAX_SHINGLE_SIZE = 128;

    protected long totalMemoryBytes = 0L;
    protected final Map<Origin, Long> totalMemoryBytesByOrigin = new EnumMap<>(Origin.class);
    protected long reservedMemoryBytes = 0L;
    protected final Map<Origin, Long> reservedMemoryBytesByOrigin = new EnumMap<>(Origin.class);
    protected long heapSize;
    // volatile: rewritten by the settings-update consumer registered in the constructor, which
    // does not hold this object's monitor, and read by unsynchronized getHeapLimit().
    protected volatile long heapLimitBytes;
    protected int thresholdModelBytes;
    protected CircuitBreakerService timeSeriesCircuitBreakerService;

    /**
     * Creates a tracker whose limit is {@code heapMax * modelMaxSizePercentage}.
     *
     * @param jvmService provides the JVM max heap size
     * @param modelMaxSizePercentage multiplier applied to the max heap size to obtain the limit
     *     (so a value such as 0.1 yields 10% of the heap)
     * @param clusterService if non-null, used to follow dynamic updates of
     *     {@code AnomalyDetectorSettings.AD_MODEL_MAX_SIZE_PERCENTAGE}; may be null
     * @param timeSeriesCircuitBreakerService breaker consulted by the allocation checks
     */
    public MemoryTracker(
        JvmService jvmService,
        double modelMaxSizePercentage,
        ClusterService clusterService,
        CircuitBreakerService timeSeriesCircuitBreakerService
    ) {
        this.heapSize = jvmService.info().getMem().getHeapMax().getBytes();
        this.heapLimitBytes = (long) ((double) this.heapSize * modelMaxSizePercentage);
        if (clusterService != null) {
            clusterService.getClusterSettings().addSettingsUpdateConsumer(AnomalyDetectorSettings.AD_MODEL_MAX_SIZE_PERCENTAGE, it -> {
                this.heapLimitBytes = (long) ((double) this.heapSize * it);
            });
        }
        this.thresholdModelBytes = DEFAULT_THRESHOLD_MODEL_BYTES;
        this.timeSeriesCircuitBreakerService = timeSeriesCircuitBreakerService;
    }

    /**
     * Checks whether {@code requiredBytes} more reserved memory fits under the heap limit.
     *
     * @param requiredBytes additional bytes that would be reserved
     * @return false if the circuit breaker is open; otherwise whether
     *     {@code reservedMemoryBytes + requiredBytes <= heapLimitBytes}
     */
    public synchronized boolean canAllocateReserved(long requiredBytes) {
        return !timeSeriesCircuitBreakerService.isOpen() && reservedMemoryBytes + requiredBytes <= heapLimitBytes;
    }

    /**
     * Checks whether {@code bytes} more memory fits under the heap limit.
     *
     * @param bytes additional bytes that would be consumed
     * @return false if the circuit breaker is open; otherwise whether
     *     {@code totalMemoryBytes + bytes <= heapLimitBytes}
     */
    public synchronized boolean canAllocate(long bytes) {
        return !timeSeriesCircuitBreakerService.isOpen() && totalMemoryBytes + bytes <= heapLimitBytes;
    }

    /**
     * Records {@code memoryToConsume} bytes against the total counters and, when
     * {@code reserved} is true, against the reserved counters as well.
     *
     * @param memoryToConsume bytes to add
     * @param reserved whether the bytes also count toward the reserved counters
     * @param origin the origin the bytes are attributed to
     */
    public synchronized void consumeMemory(long memoryToConsume, boolean reserved, Origin origin) {
        totalMemoryBytes += memoryToConsume;
        adjustOriginMemoryConsumption(memoryToConsume, origin, totalMemoryBytesByOrigin);
        if (reserved) {
            reservedMemoryBytes += memoryToConsume;
            adjustOriginMemoryConsumption(memoryToConsume, origin, reservedMemoryBytesByOrigin);
        }
    }

    /** Adds {@code memoryToConsume} to the origin's entry, treating a missing entry as 0. */
    private void adjustOriginMemoryConsumption(long memoryToConsume, Origin origin, Map<Origin, Long> mapToUpdate) {
        mapToUpdate.merge(origin, memoryToConsume, Long::sum);
    }

    /**
     * Subtracts {@code memoryToShed} bytes from the total counters and, when {@code reserved}
     * is true, from the reserved counters as well. The global counters are always decremented.
     * A per-origin entry is only decremented if the origin already has one.
     *
     * @param memoryToShed bytes to release
     * @param reserved whether the bytes were counted toward the reserved counters
     * @param origin the origin the bytes were attributed to
     */
    public synchronized void releaseMemory(long memoryToShed, boolean reserved, Origin origin) {
        totalMemoryBytes -= memoryToShed;
        adjustOriginMemoryRelease(memoryToShed, origin, totalMemoryBytesByOrigin);
        if (reserved) {
            reservedMemoryBytes -= memoryToShed;
            adjustOriginMemoryRelease(memoryToShed, origin, reservedMemoryBytesByOrigin);
        }
    }

    /** Subtracts {@code memoryToRelease} from the origin's entry if one exists; otherwise no-op. */
    private void adjustOriginMemoryRelease(long memoryToRelease, Origin origin, Map<Origin, Long> mapToUpdate) {
        mapToUpdate.computeIfPresent(origin, (key, current) -> current - memoryToRelease);
    }

    /**
     * Estimates the size in bytes of a thresholded RCF model from its configuration.
     *
     * @param dimension total (shingled) input dimension
     * @param numberOfTrees number of trees in the forest
     * @param boundingBoxCacheFraction bounding-box cache fraction; any value above 0 adds a
     *     per-tree constant
     * @param shingleSize shingle size, must be in [1, 128]
     * @param sampleSize per-tree sample size
     * @return the estimated model size in bytes
     * @throws IllegalArgumentException if {@code shingleSize} is outside [1, 128]
     */
    public long estimateTRCFModelSize(int dimension, int numberOfTrees, double boundingBoxCacheFraction, int shingleSize, int sampleSize) {
        // Validate before dividing by shingleSize. Previously 0 threw ArithmeticException and
        // negative values silently produced an estimate.
        double pointStoreSizeConstant = pointStoreSizeFactor(shingleSize);
        // Integer division kept from the original formula; assumes dimension is a multiple of
        // shingleSize (TODO confirm with callers).
        double baseDimension = dimension / shingleSize;
        int capacity = sampleSize * numberOfTrees;
        int pointStoreCapacity = Math.max(capacity + 1, 2 * sampleSize);
        // The per-entry constant grows from 2 to 4 once shingleSize * pointStoreCapacity reaches 65535.
        int pointStoreTypeConstant = shingleSize * pointStoreCapacity >= 65535 ? 4 : 2;
        int boundingBoxExistsConstant = boundingBoxCacheFraction > 0.0 ? 1 : 0;
        int nodeStoreSize = nodeStoreSize(dimension, sampleSize);
        // Operand order and int/double casts are kept exactly as before so results are unchanged.
        return (long) (152.0 * baseDimension
            + (double) (4 * dimension) * pointStoreSizeConstant * (double) capacity
            + (double) (64 * dimension)
            + (double) (pointStoreTypeConstant * capacity)
            + (double) (4 * shingleSize)
            + (double) capacity
            + (double) numberOfTrees * ((double) (32 * boundingBoxExistsConstant)
                + 8.0 * boundingBoxCacheFraction * (double) dimension * (double) sampleSize
                + 8.0 * boundingBoxCacheFraction * (double) sampleSize
                + (double) nodeStoreSize
                + (double) (8 * sampleSize)
                + 352.0)
            + 3944.0);
    }

    /** Returns the point-store size factor used by the estimate for a shingle size in [1, 128]. */
    private static double pointStoreSizeFactor(int shingleSize) {
        if (shingleSize < 1 || shingleSize > MAX_SHINGLE_SIZE) {
            throw new IllegalArgumentException("out of range shingle size " + shingleSize);
        }
        if (shingleSize == 1) {
            return 1.0;
        } else if (shingleSize == 2) {
            return 0.53;
        } else if (shingleSize <= 4) {
            return 0.27;
        } else if (shingleSize <= 8) {
            return 0.18;
        } else if (shingleSize <= 16) {
            return 0.13;
        } else if (shingleSize <= 32) {
            return 0.07;
        }
        // 33..128 share the same factor.
        return 0.05;
    }

    /** Returns the per-tree node-store size term, which steps up as sampleSize/dimension grow. */
    private static int nodeStoreSize(int dimension, int sampleSize) {
        int numberOfInternalNodes = sampleSize - 1;
        if (numberOfInternalNodes < 256 && dimension <= 256) {
            return 10 * sampleSize + 208;
        }
        if (numberOfInternalNodes < 65535 && dimension <= 65535) {
            return 16 * sampleSize + 202;
        }
        return 20 * sampleSize + 198;
    }

    /**
     * Estimates the size in bytes of an RCF caster.
     *
     * <p>The estimate is the TRCF estimate plus an error-handler term that grows with
     * {@code horizon}.
     *
     * @param dimension total (shingled) input dimension
     * @param numberOfTrees number of trees in the forest
     * @param boundingBoxCacheFraction bounding-box cache fraction
     * @param shingleSize shingle size, must be in [1, 128]
     * @param sampleSize per-tree sample size
     * @param horizon forecast horizon
     * @return the estimated model size in bytes
     * @throws IllegalArgumentException if {@code shingleSize} is outside [1, 128]
     */
    public long estimateCasterModelSize(int dimension, int numberOfTrees, double boundingBoxCacheFraction, int shingleSize, int sampleSize, int horizon) {
        long trcfModelSize = estimateTRCFModelSize(dimension, numberOfTrees, boundingBoxCacheFraction, shingleSize, sampleSize);
        // Integer division, as in estimateTRCFModelSize.
        double baseDimension = dimension / shingleSize;
        double errorHandlerSize = 176.0 * baseDimension * (double) horizon
            + 28.0 * baseDimension
            + (double) (12 * horizon) * (baseDimension * (double) horizon + 6.0)
            + 2556.0;
        return (long) ((double) trcfModelSize + errorHandlerSize);
    }

    /**
     * Returns how many bytes the total exceeds the heap limit by.
     *
     * @return {@code totalMemoryBytes - heapLimitBytes}; negative when under the limit
     */
    public synchronized long memoryToShed() {
        return totalMemoryBytes - heapLimitBytes;
    }

    /** Returns the current heap limit in bytes. */
    public long getHeapLimit() {
        return heapLimitBytes;
    }

    /**
     * Returns the total bytes currently recorded.
     *
     * <p>Synchronized because the 64-bit counter is only written under this monitor; an
     * unsynchronized read could observe a stale or torn value.
     */
    public synchronized long getTotalMemoryBytes() {
        return totalMemoryBytes;
    }

    /**
     * Overwrites the recorded counters for {@code origin} with the supplied actual values and
     * shifts the global counters by the differences.
     *
     * @param origin origin whose counters are being reconciled
     * @param totalBytes actual total bytes for the origin
     * @param reservedBytes actual reserved bytes for the origin
     * @return true if a correction was applied, false if the recorded values already matched
     */
    public synchronized boolean syncMemoryState(Origin origin, long totalBytes, long reservedBytes) {
        long recordedTotalBytes = totalMemoryBytesByOrigin.getOrDefault(origin, 0L);
        long recordedReservedBytes = reservedMemoryBytesByOrigin.getOrDefault(origin, 0L);
        if (totalBytes == recordedTotalBytes && reservedBytes == recordedReservedBytes) {
            return false;
        }
        LOG.info(String.format(Locale.ROOT, "Memory states do not match.  Recorded: total bytes %d, reserved bytes %d. Actual: total bytes %d, reserved bytes: %d", recordedTotalBytes, recordedReservedBytes, totalBytes, reservedBytes));
        reservedMemoryBytesByOrigin.put(origin, reservedBytes);
        reservedMemoryBytes += reservedBytes - recordedReservedBytes;
        totalMemoryBytesByOrigin.put(origin, totalBytes);
        totalMemoryBytes += totalBytes - recordedTotalBytes;
        return true;
    }

    /** Returns the value of {@link #thresholdModelBytes} (presumably a threshold-model size estimate — confirm). */
    public int getThresholdModelBytes() {
        return thresholdModelBytes;
    }

    /**
     * Checks that the estimated size of {@code trcf} fits in the remaining reserved budget.
     *
     * @param configId id reported in the exception when the check fails
     * @param trcf model whose size is estimated
     * @return always true when it returns normally
     * @throws LimitExceededException if the reserved budget cannot accommodate the model or
     *     the circuit breaker is open
     */
    public synchronized boolean isHostingAllowed(String configId, ThresholdedRandomCutForest trcf) {
        long requiredBytes = estimateTRCFModelSize(trcf);
        if (canAllocateReserved(requiredBytes)) {
            return true;
        }
        throw new LimitExceededException(configId, String.format(Locale.ROOT, "Exceeded memory limit. New size is %d bytes and max limit is %d bytes", reservedMemoryBytes + requiredBytes, heapLimitBytes));
    }

    /** Estimates the size of {@code trcf} from its underlying forest's configuration. */
    public long estimateTRCFModelSize(ThresholdedRandomCutForest trcf) {
        RandomCutForest forest = trcf.getForest();
        return estimateTRCFModelSize(forest.getDimensions(), forest.getNumberOfTrees(), forest.getBoundingBoxCacheFraction(), forest.getShingleSize(), forest.getSampleSize());
    }

    /** Estimates the size of {@code caster} from its forest configuration and forecast horizon. */
    public long estimateCasterModelSize(RCFCaster caster) {
        RandomCutForest forest = caster.getForest();
        return estimateCasterModelSize(forest.getDimensions(), forest.getNumberOfTrees(), forest.getBoundingBoxCacheFraction(), forest.getShingleSize(), forest.getSampleSize(), caster.getForecastHorizon());
    }

    /** Categories under which memory is attributed. */
    public enum Origin {
        REAL_TIME_DETECTOR,
        HISTORICAL_SINGLE_ENTITY_DETECTOR,
        REAL_TIME_FORECASTER
    }
}

