/*
 * Decompiled with CFR 0.152.
 */
package org.opensearch.knn.quantization.quantizer;

import java.io.IOException;
import org.opensearch.knn.index.SpaceType;
import org.opensearch.knn.index.engine.faiss.QFrameBitEncoder;
import org.opensearch.knn.quantization.enums.ScalarQuantizationType;
import org.opensearch.knn.quantization.models.quantizationOutput.QuantizationOutput;
import org.opensearch.knn.quantization.models.quantizationParams.ScalarQuantizationParams;
import org.opensearch.knn.quantization.models.quantizationState.OneBitScalarQuantizationState;
import org.opensearch.knn.quantization.models.quantizationState.QuantizationState;
import org.opensearch.knn.quantization.models.requests.TrainingRequest;
import org.opensearch.knn.quantization.quantizer.BitPacker;
import org.opensearch.knn.quantization.quantizer.Quantizer;
import org.opensearch.knn.quantization.quantizer.QuantizerHelper;
import org.opensearch.knn.quantization.quantizer.RandomGaussianRotation;
import org.opensearch.knn.quantization.sampler.Sampler;
import org.opensearch.knn.quantization.sampler.SamplerType;
import org.opensearch.knn.quantization.sampler.SamplingFactory;

/**
 * {@link Quantizer} that turns a {@code float[]} vector into a packed one-bit-per-dimension
 * {@code byte[]} representation.
 *
 * <p>Training samples document ids through the configured {@link Sampler} and delegates the
 * computation of the {@link QuantizationState} to {@link QuantizerHelper}. Quantization compares
 * the (optionally rotated) vector against the mean thresholds stored in a
 * {@link OneBitScalarQuantizationState} and packs the result via {@link BitPacker}.
 */
public class OneBitScalarQuantizer
implements Quantizer<float[], byte[]> {
    /** Number of samples requested from the sampler when no explicit size is supplied. */
    private static final int DEFAULT_SAMPLE_SIZE = 25000;
    // Not referenced inside this class; kept as-is.
    private static final boolean IS_TRAINING_REQUIRED = true;

    private final int samplingSize;
    private final boolean shouldUseRandomRotation;
    private final Sampler sampler;

    /**
     * Creates a quantizer with the default sample size, the default random-rotation setting taken
     * from {@link QFrameBitEncoder#DEFAULT_ENABLE_RANDOM_ROTATION}, and a reservoir sampler.
     */
    public OneBitScalarQuantizer() {
        this(DEFAULT_SAMPLE_SIZE, QFrameBitEncoder.DEFAULT_ENABLE_RANDOM_ROTATION, SamplingFactory.getSampler(SamplerType.RESERVOIR));
    }

    /**
     * Canonical constructor; every other constructor delegates here.
     *
     * @param samplingSize            sample size handed to the sampler during training
     * @param shouldUseRandomRotation value forwarded as {@code enableRandomRotation} when training
     * @param sampler                 sampler used to pick the document ids to train on
     */
    public OneBitScalarQuantizer(int samplingSize, boolean shouldUseRandomRotation, Sampler sampler) {
        this.samplingSize = samplingSize;
        this.shouldUseRandomRotation = shouldUseRandomRotation;
        this.sampler = sampler;
    }

    /**
     * Creates a quantizer with the default sample size and a reservoir sampler.
     *
     * @param shouldUseRandomRotation value forwarded as {@code enableRandomRotation} when training
     */
    public OneBitScalarQuantizer(boolean shouldUseRandomRotation) {
        this(DEFAULT_SAMPLE_SIZE, shouldUseRandomRotation, SamplingFactory.getSampler(SamplerType.RESERVOIR));
    }

    /**
     * Creates a quantizer with the default random-rotation setting taken from
     * {@link QFrameBitEncoder#DEFAULT_ENABLE_RANDOM_ROTATION}.
     *
     * @param samplingSize sample size handed to the sampler during training
     * @param sampler      sampler used to pick the document ids to train on
     */
    public OneBitScalarQuantizer(int samplingSize, Sampler sampler) {
        this(samplingSize, QFrameBitEncoder.DEFAULT_ENABLE_RANDOM_ROTATION, sampler);
    }

    /**
     * Samples document ids from the training request and computes the quantization state for
     * one-bit scalar quantization.
     *
     * @param trainingRequest source of the vectors to train on
     * @return the state produced by {@link QuantizerHelper#calculateQuantizationState}
     * @throws IOException declared by this method; presumably raised while the helper reads
     *                     vectors — confirm against {@link QuantizerHelper}
     */
    @Override
    public QuantizationState train(TrainingRequest<float[]> trainingRequest) throws IOException {
        int[] sampledDocIds = this.sampler.sample(trainingRequest.getTotalNumberOfVectors(), this.samplingSize);
        ScalarQuantizationParams params = ScalarQuantizationParams.builder()
            .sqType(ScalarQuantizationType.ONE_BIT)
            .enableRandomRotation(this.shouldUseRandomRotation)
            .build();
        return QuantizerHelper.calculateQuantizationState(trainingRequest, sampledDocIds, params);
    }

    /**
     * Quantizes {@code vector} against the mean thresholds held by {@code state} and writes the
     * packed result into {@code output}.
     *
     * <p>If the state carries a rotation matrix, the vector is rotated before it is compared with
     * the thresholds. The caller's array is not modified by this method itself.
     *
     * @param vector vector to quantize; must not be {@code null}
     * @param state  must be a {@link OneBitScalarQuantizationState} whose thresholds have the same
     *               length as {@code vector}
     * @param output receives the quantized bytes
     * @throws IllegalArgumentException if {@code vector} is {@code null}, {@code state} has the
     *                                  wrong type, or the thresholds are {@code null} or of a
     *                                  different length than {@code vector}
     */
    @Override
    public void quantize(float[] vector, QuantizationState state, QuantizationOutput<byte[]> output) {
        if (vector == null) {
            throw new IllegalArgumentException("Vector to quantize must not be null.");
        }
        this.validateState(state);
        // Dimension is taken from the input vector, before any rotation is applied.
        int vectorLength = vector.length;
        OneBitScalarQuantizationState binaryState = (OneBitScalarQuantizationState) state;
        float[] thresholds = binaryState.getMeanThresholds();
        if (thresholds == null || thresholds.length != vectorLength) {
            throw new IllegalArgumentException("Thresholds must not be null and must match the dimension of the vector.");
        }
        float[][] rotationMatrix = binaryState.getRotationMatrix();
        float[] toQuantize = vector;
        if (rotationMatrix != null) {
            toQuantize = RandomGaussianRotation.applyRotation(vector, rotationMatrix);
        }
        output.prepareQuantizedVector(vectorLength);
        BitPacker.quantizeAndPackBits(toQuantize, thresholds, output.getQuantizedVector());
    }

    /**
     * Rewrites {@code vector} in place for ADC (presumably asymmetric distance computation —
     * confirm against the {@link Quantizer} contract).
     *
     * <p>The vector is rotated when the state carries a rotation matrix, each component is then
     * rescaled using the per-dimension below/above threshold means, and for {@link SpaceType#L2}
     * an additional correction is applied. All work happens on a separate array that is copied
     * back into {@code vector} only at the end, so {@code vector} is untouched if the
     * transformation throws part-way through.
     *
     * @param vector    vector to transform in place; must not be {@code null}
     * @param state     must be a {@link OneBitScalarQuantizationState}
     * @param spaceType space type deciding whether the correction is applied
     * @throws IllegalArgumentException if {@code vector} is {@code null} or {@code state} has the
     *                                  wrong type
     */
    @Override
    public void transformWithADC(float[] vector, QuantizationState state, SpaceType spaceType) {
        // Same validation and message as quantize(), instead of a bare NullPointerException.
        if (vector == null) {
            throw new IllegalArgumentException("Vector to quantize must not be null.");
        }
        this.validateState(state);
        OneBitScalarQuantizationState binaryState = (OneBitScalarQuantizationState) state;
        float[][] rotationMatrix = binaryState.getRotationMatrix();
        // Only copy the input when no rotation is applied; the rotation result replaces it otherwise.
        float[] transformed = rotationMatrix != null
            ? RandomGaussianRotation.applyRotation(vector, rotationMatrix)
            : vector.clone();
        if (this.shouldDoADCCorrection(spaceType)) {
            this.transformVectorWithADCCorrection(transformed, binaryState);
        } else {
            this.transformVectorWithADCNoCorrection(transformed, binaryState);
        }
        System.arraycopy(transformed, 0, vector, 0, vector.length);
    }

    /** Returns {@code true} only for {@link SpaceType#L2}; a {@code null} space type yields {@code false}. */
    private boolean shouldDoADCCorrection(SpaceType spaceType) {
        return SpaceType.L2.equals(spaceType);
    }

    /**
     * Rescales every component to {@code (x - below) / (above - below)} in place.
     *
     * <p>NOTE(review): a dimension whose above and below means are equal divides by zero and
     * produces NaN or an infinity; confirm whether training can yield such a state.
     */
    private void transformVectorWithADCNoCorrection(float[] vector, OneBitScalarQuantizationState binaryState) {
        // Fetch the mean arrays once instead of on every iteration.
        float[] aboveMeans = binaryState.getAboveThresholdMeans();
        float[] belowMeans = binaryState.getBelowThresholdMeans();
        for (int i = 0; i < vector.length; ++i) {
            float aboveThreshold = aboveMeans[i];
            float belowThreshold = belowMeans[i];
            vector[i] = (vector[i] - belowThreshold) / (aboveThreshold - belowThreshold);
        }
    }

    /**
     * Rescales every component like {@link #transformVectorWithADCNoCorrection} and then applies
     * {@code (above - below)^2 * (x - 0.5) + 0.5} in place.
     *
     * <p>NOTE(review): a dimension whose above and below means are equal divides by zero and
     * produces NaN; confirm whether training can yield such a state.
     */
    private void transformVectorWithADCCorrection(float[] vector, OneBitScalarQuantizationState binaryState) {
        // Fetch the mean arrays once instead of on every iteration.
        float[] aboveMeans = binaryState.getAboveThresholdMeans();
        float[] belowMeans = binaryState.getBelowThresholdMeans();
        for (int i = 0; i < vector.length; ++i) {
            float aboveThreshold = aboveMeans[i];
            float belowThreshold = belowMeans[i];
            double correction = Math.pow(aboveThreshold - belowThreshold, 2.0);
            float normalized = (vector[i] - belowThreshold) / (aboveThreshold - belowThreshold);
            vector[i] = (float) correction * (normalized - 0.5f) + 0.5f;
        }
    }

    /**
     * Ensures the supplied state can be cast to {@link OneBitScalarQuantizationState}.
     *
     * @throws IllegalArgumentException if {@code state} is {@code null} or of another type
     */
    private void validateState(QuantizationState state) {
        if (!(state instanceof OneBitScalarQuantizationState)) {
            throw new IllegalArgumentException("Quantization state must be of type OneBitScalarQuantizationState.");
        }
    }
}

