/*
 * Decompiled with CFR 0.152.
 */
package org.opensearch.ad.ml;

import com.google.gson.annotations.Expose;
import com.google.gson.annotations.JsonAdapter;
import com.yahoo.sketches.kll.KllFloatsSketch;
import com.yahoo.sketches.kll.KllFloatsSketchIterator;
import java.util.ArrayList;
import java.util.List;
import org.apache.commons.math3.special.Erf;
import org.apache.commons.math3.stat.descriptive.SummaryStatistics;
import org.opensearch.ad.ml.KllFloatsSketchSerDe;
import org.opensearch.ad.ml.ThresholdingModel;

/**
 * Thresholding model that converts raw anomaly scores into anomaly grades
 * using a KLL quantile sketch of previously observed scores.
 *
 * <p>The model is "hybrid" in that {@link #train(double[])} seeds the sketch
 * with quantiles of a log-normal distribution fitted to the training scores,
 * while {@link #update(double)} afterwards feeds observed scores directly into
 * the same sketch. Once the sketch has seen
 * {@code downsampleMaxNumObservations} values it is rebuilt from a fixed
 * number of evenly spaced quantiles so that newer observations keep
 * influencing the estimated distribution.
 *
 * <p>This class performs no synchronization of its own.
 */
public class HybridThresholdingModel implements ThresholdingModel {

    /** Scores at or below this value always receive an anomaly grade of zero. */
    public static final double MIN_SCORE = 0.4;

    /** Flag handed to the sketch API when converting between {@code k} and rank error. */
    private static final boolean USE_DOUBLE_SIDED_ERROR = true;

    /** Fixed value reported by {@link #confidence()}. */
    private static final double CONFIDENCE = 0.99;

    /** Sketch holding the (approximate) distribution of anomaly scores. */
    @Expose
    @JsonAdapter(value=KllFloatsSketchSerDe.class)
    private KllFloatsSketch quantileSketch;
    private double maxScore;
    private int numLogNormalQuantiles;
    private double minPvalueThreshold;
    private int downsampleNumSamples;
    private long downsampleMaxNumObservations;

    /**
     * Creates an untrained model backed by an empty quantile sketch.
     *
     * @param minPvalueThreshold rank below which a score is graded zero; must
     *        lie strictly between 0 and 1
     * @param maxRankError rank error used to size the sketch; must be positive
     *        and no larger than {@code 1 - minPvalueThreshold}
     * @param maxScore largest anomaly score the model expects; must be positive
     * @param numLogNormalQuantiles number of log-normal quantiles inserted by
     *        {@link #train(double[])}; must be non-negative
     * @param downsampleNumSamples number of values kept in the sketch after a
     *        downsample; must be greater than one
     * @param downsampleMaxNumObservations observation count at which a
     *        downsample is triggered; must exceed {@code downsampleNumSamples}
     * @throws IllegalArgumentException if any argument violates the
     *         constraints above (NaN values are rejected as well)
     */
    public HybridThresholdingModel(double minPvalueThreshold, double maxRankError, double maxScore, int numLogNormalQuantiles, int downsampleNumSamples, long downsampleMaxNumObservations) {
        // Written as a negated conjunction so that NaN is rejected too.
        if (!(minPvalueThreshold > 0.0 && minPvalueThreshold < 1.0)) {
            throw new IllegalArgumentException("minPvalueThreshold must be strictly between 0 and 1.");
        }
        if (maxRankError > 1.0 - minPvalueThreshold) {
            throw new IllegalArgumentException("maxRankError must be smaller than 1 - minPvalueThreshold in order to accurately estimate that threshold.");
        }
        if (Double.isNaN(maxRankError) || maxRankError <= 0.0) {
            throw new IllegalArgumentException("maxRankError must be positive.");
        }
        if (Double.isNaN(maxScore) || maxScore <= 0.0) {
            throw new IllegalArgumentException("maxScore must be positive.");
        }
        if (numLogNormalQuantiles < 0) {
            throw new IllegalArgumentException("The maximum number of log-normal quantiles to compute must be non-negative.");
        }
        if (downsampleNumSamples <= 1) {
            throw new IllegalArgumentException("Number of downsamples must be greater than one.");
        }
        if ((long) downsampleNumSamples >= downsampleMaxNumObservations) {
            throw new IllegalArgumentException("The number of samples to downsample to must be less than the number of observations before downsampling is triggered.");
        }
        this.minPvalueThreshold = minPvalueThreshold;
        this.quantileSketch = new KllFloatsSketch(KllFloatsSketch.getKFromEpsilon(maxRankError, USE_DOUBLE_SIDED_ERROR));
        this.maxScore = maxScore;
        this.numLogNormalQuantiles = numLogNormalQuantiles;
        this.downsampleNumSamples = downsampleNumSamples;
        this.downsampleMaxNumObservations = downsampleMaxNumObservations;
    }

    /**
     * Creates an instance with every field left at its default value; in
     * particular the quantile sketch is {@code null} until it is populated
     * externally.
     * NOTE(review): presumably intended for deserialization frameworks only -
     * confirm against callers.
     */
    public HybridThresholdingModel() {
    }

    /**
     * @return the rank threshold below which scores are graded zero
     */
    public double getMinPvalueThreshold() {
        return this.minPvalueThreshold;
    }

    /**
     * @return the normalized rank error currently reported by the underlying sketch
     */
    public double getMaxRankError() {
        return this.quantileSketch.getNormalizedRankError(USE_DOUBLE_SIDED_ERROR);
    }

    /**
     * @return the maximum anomaly score configured for this model
     */
    public double getMaxScore() {
        return this.maxScore;
    }

    /**
     * @return the number of log-normal quantiles inserted during training
     */
    public int getNumLogNormalQuantiles() {
        return this.numLogNormalQuantiles;
    }

    /**
     * @return the number of values retained in the sketch after a downsample
     */
    public int getDownsampleNumSamples() {
        return this.downsampleNumSamples;
    }

    /**
     * @return the observation count at which the sketch is downsampled
     */
    public long getDownsampleMaxNumObservations() {
        return this.downsampleMaxNumObservations;
    }

    /**
     * Seeds the sketch from a log-normal fit of the given scores.
     *
     * <p>The mean and standard deviation of {@code log(score)} are estimated,
     * the fitted CDF is evaluated at {@code maxScore}, and
     * {@code numLogNormalQuantiles} quantiles evenly spaced in rank between 0
     * and that CDF value (both exclusive) are passed to
     * {@link #update(double)}.
     *
     * <p>If the fit is degenerate (for example an empty array, or scores whose
     * logarithm is not finite) so that the rank step is NaN or not positive,
     * nothing is inserted.
     *
     * @param anomalyScores training scores; must not be {@code null}
     */
    @Override
    public void train(double[] anomalyScores) {
        SummaryStatistics logScoreStats = new SummaryStatistics();
        for (double score : anomalyScores) {
            logScoreStats.addValue(Math.log(score));
        }
        double mu = logScoreStats.getMean();
        double sigma = logScoreStats.getStandardDeviation();

        double maxScorePvalue = computeLogNormalCdf(this.maxScore, mu, sigma);
        double pvalueStep = maxScorePvalue / ((double) this.numLogNormalQuantiles + 1.0);
        // A NaN or non-positive step means there is no usable rank range.
        if (!(pvalueStep > 0.0)) {
            return;
        }

        // Index with an integer rather than accumulating a double: repeated
        // addition can drift just below maxScorePvalue and trigger an extra
        // iteration at a rank of almost exactly maxScorePvalue.
        for (int i = 1; i <= this.numLogNormalQuantiles; i++) {
            double pvalue = i * pvalueStep;
            if (pvalue >= maxScorePvalue) {
                break;
            }
            update(computeLogNormalQuantile(pvalue, mu, sigma));
        }
    }

    /**
     * Evaluates the CDF of a log-normal distribution.
     *
     * @param anomalyScore point at which to evaluate the CDF
     * @param mu mean of the logarithm of the variable
     * @param sigma standard deviation of the logarithm of the variable
     * @return {@code (1 + erf((ln(anomalyScore) - mu) / (sqrt(2) * sigma))) / 2}
     */
    private double computeLogNormalCdf(double anomalyScore, double mu, double sigma) {
        double standardized = (Math.log(anomalyScore) - mu) / (Math.sqrt(2.0) * sigma);
        return (1.0 + Erf.erf(standardized)) / 2.0;
    }

    /**
     * Evaluates the quantile function (inverse CDF) of a log-normal distribution.
     *
     * @param pvalue rank at which to evaluate the quantile
     * @param mu mean of the logarithm of the variable
     * @param sigma standard deviation of the logarithm of the variable
     * @return {@code exp(mu + sqrt(2) * sigma * erfInv(2 * pvalue - 1))}
     */
    private double computeLogNormalQuantile(double pvalue, double mu, double sigma) {
        return Math.exp(mu + Math.sqrt(2.0) * sigma * Erf.erfInv(2.0 * pvalue - 1.0));
    }

    /**
     * Adds one score to the sketch and downsamples the sketch once it has
     * recorded at least {@code downsampleMaxNumObservations} values.
     *
     * @param anomalyScore score to record; it is narrowed to {@code float}
     */
    @Override
    public void update(double anomalyScore) {
        this.quantileSketch.update((float) anomalyScore);
        if (this.quantileSketch.getN() >= this.downsampleMaxNumObservations) {
            downsample();
        }
    }

    /**
     * Converts an anomaly score into an anomaly grade.
     *
     * <p>Scores not exceeding {@link #MIN_SCORE} are graded 0. Otherwise the
     * score's rank in the sketch is rescaled linearly so that a rank of
     * {@code minPvalueThreshold} maps to 0 and a rank of 1 maps to 1; negative
     * or NaN results are reported as 0.
     *
     * @param anomalyScore score to grade
     * @return the anomaly grade, never negative and never NaN
     */
    @Override
    public double grade(double anomalyScore) {
        if (!(anomalyScore > MIN_SCORE)) {
            return 0.0;
        }
        double scale = 1.0 / (1.0 - this.minPvalueThreshold);
        double pvalue = this.quantileSketch.getRank((float) anomalyScore);
        double anomalyGrade = scale * (pvalue - this.minPvalueThreshold);
        return Double.isNaN(anomalyGrade) ? 0.0 : Math.max(0.0, anomalyGrade);
    }

    /**
     * @return the fixed confidence value of this model
     */
    @Override
    public double confidence() {
        return CONFIDENCE;
    }

    /**
     * Replaces the sketch with a fresh one of the same {@code k} that holds
     * {@code downsampleNumSamples} values: {@code downsampleNumSamples - 1}
     * quantiles of the current sketch at evenly spaced ranks in
     * {@code [0, 1)}, followed by {@code maxScore}.
     */
    private void downsample() {
        KllFloatsSketch downsampledSketch = new KllFloatsSketch(this.quantileSketch.getK());
        int numQuantiles = this.downsampleNumSamples - 1;
        // Integer-indexed so that exactly numQuantiles ranks are visited; an
        // accumulated double step can stop just short of 1.0 and add one more.
        for (int i = 0; i < numQuantiles; i++) {
            double pvalue = (double) i / numQuantiles;
            downsampledSketch.update(this.quantileSketch.getQuantile(pvalue));
        }
        downsampledSketch.update((float) this.maxScore);
        this.quantileSketch = downsampledSketch;
    }

    /**
     * Collects every value the sketch's iterator yields.
     *
     * @return a new mutable list of the values, in iteration order
     */
    @Override
    public List<Double> extractScores() {
        List<Double> scores = new ArrayList<>();
        KllFloatsSketchIterator iter = this.quantileSketch.iterator();
        while (iter.next()) {
            scores.add((double) iter.getValue());
        }
        return scores;
    }
}

