/*
 * Decompiled with CFR 0.152.
 */
package org.opensearch.ad.ml;

import com.amazon.randomcutforest.RandomCutForest;
import com.amazon.randomcutforest.config.ForestMode;
import com.amazon.randomcutforest.config.Precision;
import com.amazon.randomcutforest.config.TransformMethod;
import com.amazon.randomcutforest.parkservices.AnomalyDescriptor;
import com.amazon.randomcutforest.parkservices.ThresholdedRandomCutForest;
import java.time.Clock;
import java.time.Duration;
import java.time.Instant;
import java.util.Iterator;
import java.util.List;
import java.util.Locale;
import java.util.Map;
import java.util.Optional;
import java.util.Set;
import java.util.concurrent.ConcurrentHashMap;
import java.util.stream.Collectors;
import java.util.stream.IntStream;
import java.util.stream.Stream;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.opensearch.ad.indices.ADIndex;
import org.opensearch.ad.indices.ADIndexManagement;
import org.opensearch.ad.ml.ADCheckpointDao;
import org.opensearch.ad.ml.ADColdStart;
import org.opensearch.ad.ml.ThresholdingModel;
import org.opensearch.ad.ml.ThresholdingResult;
import org.opensearch.ad.model.AnomalyDetector;
import org.opensearch.ad.model.AnomalyResult;
import org.opensearch.cluster.service.ClusterService;
import org.opensearch.common.settings.Setting;
import org.opensearch.common.settings.Settings;
import org.opensearch.common.unit.TimeValue;
import org.opensearch.core.action.ActionListener;
import org.opensearch.timeseries.MemoryTracker;
import org.opensearch.timeseries.common.exception.ResourceNotFoundException;
import org.opensearch.timeseries.common.exception.TimeSeriesException;
import org.opensearch.timeseries.feature.FeatureManager;
import org.opensearch.timeseries.feature.Features;
import org.opensearch.timeseries.ml.MemoryAwareConcurrentHashmap;
import org.opensearch.timeseries.ml.ModelColdStart;
import org.opensearch.timeseries.ml.ModelManager;
import org.opensearch.timeseries.ml.ModelState;
import org.opensearch.timeseries.model.Config;
import org.opensearch.timeseries.util.DateUtils;
import org.opensearch.timeseries.util.ModelUtil;

/**
 * Model manager for single-stream anomaly detection.
 *
 * <p>Keeps two in-memory caches keyed by model id: thresholded RCF models and thresholding
 * models. On a cache miss the model is loaded through {@link ADCheckpointDao}. Models are
 * written back to checkpoints when stopped and periodically during {@link #maintenance}.
 */
public class ADModelManager
extends ModelManager<ThresholdedRandomCutForest, AnomalyResult, ThresholdingResult, ADIndex, ADIndexManagement, ADCheckpointDao, ADColdStart> {
    protected static final String ENTITY_SAMPLE = "sp";
    protected static final String ENTITY_RCF = "rcf";
    protected static final String ENTITY_THRESHOLD = "th";
    private static final Logger logger = LogManager.getLogger(ADModelManager.class);
    /** Cached TRCF models keyed by model id. */
    private final MemoryAwareConcurrentHashmap<ThresholdedRandomCutForest> forests;
    /** Cached thresholding models keyed by model id. */
    private final Map<String, ModelState<ThresholdingModel>> thresholds;
    private final double thresholdMinPvalue;
    private final int minPreviewSize;
    private final Duration modelTtl;
    // volatile: may be rewritten by the cluster-settings update consumer registered in the
    // constructor while stopModel/maintenance read it from other threads.
    private volatile Duration checkpointInterval;
    private final double initialAcceptFraction;

    /**
     * Creates the manager.
     *
     * @param checkpointDao DAO used to load and persist checkpoints
     * @param clock clock used for checkpoint/expiry timestamps
     * @param rcfNumTrees number of trees for preview models
     * @param rcfNumSamplesInTree sample size per tree
     * @param rcfNumMinSamples number of samples before a preview model outputs results
     * @param thresholdMinPvalue preview models use {@code 1 - thresholdMinPvalue} as anomaly rate
     * @param minPreviewSize minimum number of data points accepted by {@link #getPreviewResults}
     * @param modelTtl idle time after which cached models are evicted during maintenance
     * @param checkpointIntervalSetting dynamic setting holding the checkpoint interval
     * @param entityColdStarter cold starter passed to the base class
     * @param featureManager feature manager passed to the base class
     * @param memoryTracker tracker deciding whether restored models may be hosted
     * @param settings node settings used to read the initial checkpoint interval
     * @param clusterService when non-null, used to subscribe to checkpoint interval updates
     */
    public ADModelManager(ADCheckpointDao checkpointDao, Clock clock, int rcfNumTrees, int rcfNumSamplesInTree, int rcfNumMinSamples, double thresholdMinPvalue, int minPreviewSize, Duration modelTtl, Setting<TimeValue> checkpointIntervalSetting, ADColdStart entityColdStarter, FeatureManager featureManager, MemoryTracker memoryTracker, Settings settings, ClusterService clusterService) {
        super(rcfNumTrees, rcfNumSamplesInTree, rcfNumMinSamples, entityColdStarter, memoryTracker, clock, featureManager, checkpointDao);
        this.thresholdMinPvalue = thresholdMinPvalue;
        this.minPreviewSize = minPreviewSize;
        this.modelTtl = modelTtl;
        this.checkpointInterval = DateUtils.toDuration(checkpointIntervalSetting.get(settings));
        this.forests = new MemoryAwareConcurrentHashmap<>(memoryTracker);
        this.thresholds = new ConcurrentHashMap<>();
        this.initialAcceptFraction = (double) rcfNumMinSamples / rcfNumSamplesInTree;
        // Register last so the consumer never observes a partially initialised instance.
        if (clusterService != null) {
            clusterService.getClusterSettings().addSettingsUpdateConsumer(checkpointIntervalSetting, it -> this.checkpointInterval = DateUtils.toDuration(it));
        }
    }

    /**
     * Scores {@code point} with the TRCF model {@code modelId}, restoring it from its checkpoint
     * and caching it when it is not already in memory.
     *
     * @param detectorId detector owning the model
     * @param modelId model id
     * @param point input point
     * @param listener receives the result, or a failure if the model cannot be found/restored
     */
    @Deprecated
    public void getTRcfResult(String detectorId, String modelId, double[] point, ActionListener<ThresholdingResult> listener) {
        // Single lookup: with containsKey + get, a concurrent stopModel/maintenance removal
        // between the two calls would hand a null ModelState to the scoring overload.
        ModelState<ThresholdedRandomCutForest> cached = this.forests.get(modelId);
        if (cached != null) {
            this.getTRcfResult(cached, point, listener);
            return;
        }
        this.checkpointDao.getTRCFModel(modelId, ActionListener.wrap(restoredModel -> this.processRestoredTRcf(restoredModel, modelId, detectorId, point, listener), listener::onFailure));
    }

    /**
     * Runs {@code point} through the given TRCF model and reports the thresholding result.
     * Fails with {@link TimeSeriesException} when the state holds no model.
     */
    @Deprecated
    private void getTRcfResult(ModelState<ThresholdedRandomCutForest> modelState, double[] point, ActionListener<ThresholdingResult> listener) {
        Optional<ThresholdedRandomCutForest> trcfOptional = modelState.getModel();
        if (trcfOptional.isEmpty()) {
            listener.onFailure(new TimeSeriesException("empty model"));
            return;
        }
        ThresholdedRandomCutForest trcf = trcfOptional.get();
        ThresholdingResult thresholdingResult;
        try {
            AnomalyDescriptor result = trcf.process(point, 0L);
            double[] attribution = ModelUtil.normalizeAttribution(trcf.getForest(), result.getRelevantAttribution());
            thresholdingResult = new ThresholdingResult(result.getAnomalyGrade(), result.getDataConfidence(), result.getRCFScore(), result.getTotalUpdates(), result.getRelativeIndex(), attribution, result.getPastValues(), result.getExpectedValuesList(), result.getLikelihoodOfValues(), result.getThreshold(), result.getNumberOfTrees(), point, null);
        } catch (Exception e) {
            listener.onFailure(e);
            return;
        }
        // Respond outside the try block so an exception thrown by the listener itself is not
        // delivered back to the same listener as a second, failure notification.
        listener.onResponse(thresholdingResult);
    }

    /**
     * Wraps a restored TRCF model in a {@link ModelState}, provided the memory tracker allows
     * hosting it.
     *
     * @return the new state, or empty if no model was restored or hosting is not allowed
     */
    Optional<ModelState<ThresholdedRandomCutForest>> restoreModelState(Optional<ThresholdedRandomCutForest> rcfModel, String modelId, String detectorId) {
        return rcfModel
            .filter(rcf -> this.memoryTracker.isHostingAllowed(detectorId, rcf))
            .map(rcf -> new ModelState<ThresholdedRandomCutForest>(rcf, modelId, detectorId, ModelManager.ModelType.TRCF.getName(), this.clock));
    }

    /** Caches a restored TRCF model and scores {@code point} with it. */
    private void processRestoredTRcf(Optional<ThresholdedRandomCutForest> rcfModel, String modelId, String detectorId, double[] point, ActionListener<ThresholdingResult> listener) {
        Optional<ModelState<ThresholdedRandomCutForest>> model = this.restoreModelState(rcfModel, modelId, detectorId);
        if (model.isEmpty()) {
            listener.onFailure(new ResourceNotFoundException(detectorId, "No checkpoints found for model id " + modelId));
            return;
        }
        this.forests.put(modelId, model.get());
        this.getTRcfResult(model.get(), point, listener);
    }

    /**
     * Caches a restored TRCF model and reports its forest's total update count. Always
     * completes {@code listener}: fails with {@link ResourceNotFoundException} when nothing
     * could be restored or the restored model has no forest.
     */
    private void processRestoredCheckpoint(Optional<ThresholdedRandomCutForest> checkpointModel, String modelId, String detectorId, ActionListener<Long> listener) {
        logger.info("Restoring checkpoint for {}", modelId);
        Optional<ModelState<ThresholdedRandomCutForest>> model = this.restoreModelState(checkpointModel, modelId, detectorId);
        if (model.isEmpty()) {
            listener.onFailure(new ResourceNotFoundException(detectorId, "No checkpoints found for model id " + modelId));
            return;
        }
        ModelState<ThresholdedRandomCutForest> modelState = model.get();
        this.forests.put(modelId, modelState);
        // Previously an empty getModel() left the listener uncompleted; now it fails explicitly.
        Optional<RandomCutForest> forest = modelState.getModel().map(ThresholdedRandomCutForest::getForest);
        if (forest.isPresent()) {
            listener.onResponse(forest.get().getTotalUpdates());
        } else {
            listener.onFailure(new ResourceNotFoundException(detectorId, "No checkpoints found for model id " + modelId));
        }
    }

    /**
     * Grades {@code score} with the thresholding model {@code modelId}, restoring it from its
     * checkpoint and caching it when it is not already in memory.
     */
    public void getThresholdingResult(String detectorId, String modelId, double score, ActionListener<ThresholdingResult> listener) {
        // Single lookup for the same race reason as getTRcfResult.
        ModelState<ThresholdingModel> cached = this.thresholds.get(modelId);
        if (cached != null) {
            this.getThresholdingResult(cached, score, listener);
            return;
        }
        this.checkpointDao.getThresholdModel(modelId, ActionListener.wrap(model -> this.processThresholdCheckpoint(model, modelId, detectorId, score, listener), listener::onFailure));
    }

    /**
     * Computes grade and confidence for {@code score}, then feeds positive scores back into the
     * thresholding model.
     */
    private void getThresholdingResult(ModelState<ThresholdingModel> modelState, double score, ActionListener<ThresholdingResult> listener) {
        Optional<ThresholdingModel> thresholdOptional = modelState.getModel();
        if (thresholdOptional.isEmpty()) {
            listener.onFailure(new ResourceNotFoundException(modelState.getConfigId(), "No checkpoints found for model id " + modelState.getModelId()));
            return;
        }
        ThresholdingModel threshold = thresholdOptional.get();
        // Grade/confidence are computed before the update so the current score does not
        // influence its own grading.
        double grade = threshold.grade(score);
        double confidence = threshold.confidence();
        if (score > 0.0) {
            threshold.update(score);
        }
        listener.onResponse(new ThresholdingResult(grade, confidence, score));
    }

    /** Caches a restored thresholding model and grades {@code score} with it. */
    private void processThresholdCheckpoint(Optional<ThresholdingModel> thresholdModel, String modelId, String detectorId, double score, ActionListener<ThresholdingResult> listener) {
        Optional<ModelState<ThresholdingModel>> model = thresholdModel.map(threshold -> new ModelState<ThresholdingModel>(threshold, modelId, detectorId, ModelManager.ModelType.THRESHOLD.getName(), this.clock));
        if (model.isEmpty()) {
            listener.onFailure(new ResourceNotFoundException(detectorId, "No checkpoints found for model id " + modelId));
            return;
        }
        this.thresholds.put(modelId, model.get());
        this.getThresholdingResult(model.get(), score, listener);
    }

    /** Returns the ids of all cached TRCF and thresholding models. */
    public Set<String> getAllModelIds() {
        return Stream.concat(this.forests.keySet().stream(), this.thresholds.keySet().stream()).collect(Collectors.toSet());
    }

    /** Returns all cached TRCF and thresholding model states. */
    public List<ModelState<?>> getAllModels() {
        return Stream.<ModelState<?>>concat(this.forests.values().stream(), this.thresholds.values().stream()).collect(Collectors.toList());
    }

    /**
     * Evicts the TRCF and then the thresholding model for {@code modelId}, checkpointing each
     * one whose last checkpoint is older than the checkpoint interval.
     */
    public void stopModel(String detectorId, String modelId, ActionListener<Void> listener) {
        logger.info("Stopping detector {} model {}", detectorId, modelId);
        this.stopModel(this.forests, modelId, ActionListener.wrap(r -> this.stopModel(this.thresholds, modelId, listener), listener::onFailure));
    }

    /**
     * Removes {@code modelId} from {@code models}; if the removed model's last checkpoint is
     * older than the checkpoint interval, persists it before completing {@code listener}.
     */
    private <T> void stopModel(Map<String, ModelState<T>> models, String modelId, ActionListener<Void> listener) {
        Instant now = this.clock.instant();
        Optional<T> model = Optional.ofNullable(models.remove(modelId))
            .filter(state -> state.getLastCheckpointTime().plus(this.checkpointInterval).isBefore(now))
            .flatMap(state -> state.getModel());
        if (model.isEmpty()) {
            listener.onResponse(null);
            return;
        }
        T stopped = model.get();
        if (stopped instanceof ThresholdedRandomCutForest) {
            this.checkpointDao.putTRCFCheckpoint(modelId, (ThresholdedRandomCutForest) stopped, ActionListener.wrap(r -> listener.onResponse(null), listener::onFailure));
        } else if (stopped instanceof ThresholdingModel) {
            this.checkpointDao.putThresholdCheckpoint(modelId, (ThresholdingModel) stopped, ActionListener.wrap(r -> listener.onResponse(null), listener::onFailure));
        } else {
            listener.onFailure(new IllegalArgumentException("Unexpected model type"));
        }
    }

    /** Clears the detector's TRCF models and then its thresholding models. */
    public void clear(String detectorId, ActionListener<Void> listener) {
        this.clearModels(detectorId, this.forests, ActionListener.wrap(r -> this.clearModels(detectorId, this.thresholds, listener), listener::onFailure));
    }

    /**
     * Runs maintenance over all cached models: evicts expired models and checkpoints models
     * whose last checkpoint is older than the checkpoint interval. TRCF models are processed
     * before thresholding models.
     */
    public void maintenance(ActionListener<Void> listener) {
        this.maintenanceForIterator(this.forests, this.forests.entrySet().iterator(), ActionListener.wrap(r -> this.maintenanceForIterator(this.thresholds, this.thresholds.entrySet().iterator(), listener), listener::onFailure));
    }

    /**
     * Walks {@code iter}, evicting expired models and checkpointing stale ones. Models needing
     * no checkpoint are handled in a loop (not by recursion) so stack depth does not grow
     * with the number of cached models; after each asynchronous checkpoint write the walk is
     * resumed from the write's callback. Checkpoint failures are logged and skipped.
     * {@code listener} is completed once the iterator is exhausted.
     */
    private <T> void maintenanceForIterator(Map<String, ModelState<T>> models, Iterator<Map.Entry<String, ModelState<T>>> iter, ActionListener<Void> listener) {
        while (iter.hasNext()) {
            Map.Entry<String, ModelState<T>> modelEntry = iter.next();
            String modelId = modelEntry.getKey();
            ModelState<T> modelState = modelEntry.getValue();
            Instant now = this.clock.instant();
            if (modelState.expired(this.modelTtl)) {
                // An expired model is dropped from the cache but may still be checkpointed below.
                models.remove(modelId);
            }
            if (!modelState.getLastCheckpointTime().plus(this.checkpointInterval).isBefore(now)) {
                continue;
            }
            Optional<T> modelOptional = modelState.getModel();
            if (modelOptional.isEmpty()) {
                continue;
            }
            ActionListener<Void> checkpointListener = ActionListener.wrap(r -> {
                modelState.setLastCheckpointTime(now);
                this.maintenanceForIterator(models, iter, listener);
            }, e -> {
                logger.warn("Failed to finish maintenance for model id {}", modelId, e);
                this.maintenanceForIterator(models, iter, listener);
            });
            T model = modelOptional.get();
            if (model instanceof ThresholdedRandomCutForest) {
                this.checkpointDao.putTRCFCheckpoint(modelId, (ThresholdedRandomCutForest) model, checkpointListener);
            } else if (model instanceof ThresholdingModel) {
                this.checkpointDao.putThresholdCheckpoint(modelId, (ThresholdingModel) model, checkpointListener);
            } else {
                checkpointListener.onFailure(new IllegalArgumentException("Unexpected model type"));
            }
            // The checkpoint callback continues the walk; stop this invocation here.
            return;
        }
        listener.onResponse(null);
    }

    /**
     * Builds a fresh TRCF from the detector configuration and scores every preview data point.
     *
     * @param features unprocessed feature vectors plus their time ranges (one per point)
     * @param detector detector supplying shingle size, time decay, imputation and rules
     * @return one result per data point, in input order ({@code null} where TRCF returned no
     *     descriptor)
     * @throws IllegalArgumentException if there are no data points, fewer than the minimum
     *     preview size, or the time range count differs from the data point count
     */
    public List<ThresholdingResult> getPreviewResults(Features features, AnomalyDetector detector) {
        double[][] dataPoints = features.getUnprocessedFeatures();
        // The empty check guards dataPoints[0] below even if minPreviewSize is configured <= 0.
        if (dataPoints.length == 0 || dataPoints.length < this.minPreviewSize) {
            throw new IllegalArgumentException("Insufficient data for preview results. Minimum required: " + this.minPreviewSize);
        }
        List<Map.Entry<Long, Long>> timeRanges = features.getTimeRanges();
        if (timeRanges.size() != dataPoints.length) {
            throw new IllegalArgumentException(String.format(Locale.ROOT, "time range size %d does not match data points size %d", timeRanges.size(), dataPoints.length));
        }
        int shingleSize = detector.getShingleSize();
        double rcfTimeDecay = detector.getTimeDecay();
        int baseDimension = dataPoints[0].length;
        ThresholdedRandomCutForest.Builder trcfBuilder = ThresholdedRandomCutForest.builder()
            .randomSeed(0L)
            .dimensions(baseDimension * shingleSize)
            .sampleSize(this.rcfNumSamplesInTree)
            .numberOfTrees(this.rcfNumTrees)
            .timeDecay(rcfTimeDecay)
            .outputAfter(this.rcfNumMinSamples)
            .initialAcceptFraction(this.initialAcceptFraction)
            .parallelExecutionEnabled(false)
            .compact(true)
            .precision(Precision.FLOAT_32)
            .boundingBoxCacheFraction(1.0)
            .shingleSize(shingleSize)
            .anomalyRate(1.0 - this.thresholdMinPvalue)
            .transformMethod(TransformMethod.NORMALIZE)
            .alertOnce(true)
            .autoAdjust(true)
            .internalShinglingEnabled(true);
        if (shingleSize > 1) {
            trcfBuilder.forestMode(ForestMode.STREAMING_IMPUTE);
            trcfBuilder = ModelColdStart.applyImputationMethod(detector, trcfBuilder);
        } else {
            trcfBuilder.forestMode(ForestMode.STANDARD);
        }
        ADColdStart.applyRule(trcfBuilder, detector);
        ThresholdedRandomCutForest trcf = trcfBuilder.build();
        return IntStream.range(0, dataPoints.length).mapToObj(i -> {
            double[] point = dataPoints[i];
            // Time range end is divided by 1000 before being passed to TRCF
            // (presumably milliseconds -> seconds; confirm against Features).
            long timestampSecs = timeRanges.get(i).getValue() / 1000L;
            AnomalyDescriptor descriptor = trcf.process(point, timestampSecs);
            return descriptor == null ? null : this.toResult(trcf.getForest(), descriptor, point, false, detector);
        }).collect(Collectors.toList());
    }

    /**
     * Reports the total update count of the TRCF model {@code modelId}. Uses the cached model
     * when present (0 if it has no model or forest); otherwise restores and caches it from its
     * checkpoint.
     */
    public void getTotalUpdates(String modelId, String detectorId, ActionListener<Long> listener) {
        ModelState<ThresholdedRandomCutForest> model = this.forests.get(modelId);
        if (model == null) {
            this.checkpointDao.getTRCFModel(modelId, ActionListener.wrap(checkpoint -> this.processRestoredCheckpoint(checkpoint, modelId, detectorId, listener), listener::onFailure));
            return;
        }
        long totalUpdates = model.getModel().map(ThresholdedRandomCutForest::getForest).map(RandomCutForest::getTotalUpdates).orElse(0L);
        listener.onResponse(totalUpdates);
    }

    /** Returns a result with zero grade, confidence and score. */
    @Override
    protected ThresholdingResult createEmptyResult() {
        return new ThresholdingResult(0.0, 0.0, 0.0);
    }

    /** Converts a TRCF descriptor into a {@link ThresholdingResult} via {@link ModelUtil#toResult}. */
    @Override
    protected ThresholdingResult toResult(RandomCutForest rcf, AnomalyDescriptor anomalyDescriptor, double[] point, boolean isImputed, Config config) {
        return ModelUtil.toResult(rcf, anomalyDescriptor, point, isImputed, config);
    }
}

