/*
 * Decompiled with CFR 0.152.
 */
package org.opensearch.knn.index.engine.faiss;

import org.opensearch.common.ValidationException;
import org.opensearch.knn.index.engine.Encoder;
import org.opensearch.knn.index.engine.KNNMethodConfigContext;
import org.opensearch.knn.index.engine.KNNMethodContext;
import org.opensearch.knn.index.engine.MethodComponentContext;
import org.opensearch.knn.index.engine.TrainingConfigValidationInput;
import org.opensearch.knn.index.engine.TrainingConfigValidationOutput;
import org.opensearch.knn.index.mapper.CompressionLevel;

import java.util.Locale;

/**
 * Base class for Faiss product-quantization (PQ) encoders.
 *
 * <p>Derives a compression level from the PQ {@code m} and {@code code_size} parameters and
 * validates the encoder configuration before training. Subclasses supply the parameter
 * definitions through {@code getMethodComponent()}.
 */
public abstract class AbstractFaissPQEncoder
implements Encoder {
    // Encoder-level parameter: number of sub-vectors each vector is split into.
    private static final String PQ_PARAMETER_M = "m";
    // Encoder-level parameter: number of bits used to encode each sub-vector.
    private static final String PQ_PARAMETER_CODE_SIZE = "code_size";
    // Method-level parameter under which the encoder's own context is stored.
    private static final String METHOD_PARAMETER_ENCODER = "encoder";
    // Minimum training-set size used when code_size is not configured.
    private static final long DEFAULT_MIN_TRAINING_VECTOR_COUNT = 1000L;

    /**
     * Estimates the compression level produced by this PQ configuration.
     *
     * <p>A {@code dimension}-dimensional float32 vector occupies {@code dimension * 32} bits.
     * PQ stores it as {@code m} codes of {@code code_size} bits each. The resulting ratio is
     * bucketed down to a supported {@link CompressionLevel}: below 2 maps to x1, below 4 to x2,
     * and so on up to x32. Anything at or above 64 maps to
     * {@link CompressionLevel#MAX_COMPRESSION_LEVEL}.
     *
     * @param methodComponentContext encoder context holding {@code m} and {@code code_size}
     * @param knnMethodConfigContext config context supplying the vector dimension
     * @return {@link CompressionLevel#NOT_CONFIGURED} if either parameter is absent, otherwise the bucketed level
     * @throws ValidationException if either parameter fails this encoder's parameter validation
     */
    @Override
    public CompressionLevel calculateCompressionLevel(MethodComponentContext methodComponentContext, KNNMethodConfigContext knnMethodConfigContext) {
        if (!methodComponentContext.getParameters().containsKey(PQ_PARAMETER_M)
            || !methodComponentContext.getParameters().containsKey(PQ_PARAMETER_CODE_SIZE)) {
            return CompressionLevel.NOT_CONFIGURED;
        }
        Integer m = this.validatedIntegerParameter(methodComponentContext, PQ_PARAMETER_M, knnMethodConfigContext);
        Integer codeSize = this.validatedIntegerParameter(methodComponentContext, PQ_PARAMETER_CODE_SIZE, knnMethodConfigContext);
        int dimension = knnMethodConfigContext.getDimension();
        // Widen to long before multiplying so large parameter values cannot overflow int.
        long compressedBits = (long) m * codeSize;
        float actualCompression = (float) dimension * 32.0f / (float) compressedBits;
        if (actualCompression < 2.0f) {
            return CompressionLevel.x1;
        }
        if (actualCompression < 4.0f) {
            return CompressionLevel.x2;
        }
        if (actualCompression < 8.0f) {
            return CompressionLevel.x4;
        }
        if (actualCompression < 16.0f) {
            return CompressionLevel.x8;
        }
        if (actualCompression < 32.0f) {
            return CompressionLevel.x16;
        }
        if (actualCompression < 64.0f) {
            return CompressionLevel.x32;
        }
        return CompressionLevel.MAX_COMPRESSION_LEVEL;
    }

    /**
     * Validates the PQ encoder configuration ahead of training.
     *
     * <p>The method runs two checks, each only when its inputs are available:
     * <ul>
     *   <li>The vector dimension must be divisible by the encoder's {@code m}.</li>
     *   <li>The number of training vectors must be at least {@code 2^code_size}, or 1000 when
     *       {@code code_size} is not set.</li>
     * </ul>
     *
     * @param trainingConfigValidationInput method context, config context and training-vector count
     * @return a validation result. When invalid, it carries an error message; for the
     *         training-count check it also carries the required minimum.
     */
    @Override
    public TrainingConfigValidationOutput validateEncoderConfig(TrainingConfigValidationInput trainingConfigValidationInput) {
        KNNMethodContext knnMethodContext = trainingConfigValidationInput.getKnnMethodContext();
        KNNMethodConfigContext knnMethodConfigContext = trainingConfigValidationInput.getKnnMethodConfigContext();
        Long trainingVectors = trainingConfigValidationInput.getTrainingVectorsCount();
        TrainingConfigValidationOutput.TrainingConfigValidationOutputBuilder builder = TrainingConfigValidationOutput.builder();
        if (knnMethodContext != null && knnMethodConfigContext != null) {
            // Read PQ "m" from the encoder's context, not the enclosing method's: the method level
            // may define its own, unrelated "m" parameter.
            MethodComponentContext encoderContext = getEncoderContext(knnMethodContext);
            if (encoderContext != null && encoderContext.getParameters().containsKey(PQ_PARAMETER_M)) {
                int m = (Integer) encoderContext.getParameters().get(PQ_PARAMETER_M);
                // A non-positive m cannot partition the vector; checking it first also avoids "% 0".
                if (m <= 0 || knnMethodConfigContext.getDimension() % m != 0) {
                    builder.valid(false);
                    builder.errorMessage("Training request ENCODER_PARAMETER_PQ_M is not divisible by vector dimensions");
                    return builder.build();
                }
            }
            builder.valid(true);
        }
        if (knnMethodContext != null && trainingVectors != null) {
            long minTrainingVectorCount = DEFAULT_MIN_TRAINING_VECTOR_COUNT;
            MethodComponentContext encoderContext = getEncoderContext(knnMethodContext);
            if (encoderContext != null && encoderContext.getParameters().containsKey(PQ_PARAMETER_CODE_SIZE)) {
                int codeSize = (Integer) encoderContext.getParameters().get(PQ_PARAMETER_CODE_SIZE);
                // A code_size-bit code indexes 2^code_size centroids per sub-quantizer, so training
                // needs at least that many points.
                minTrainingVectorCount = (long) Math.pow(2.0, codeSize);
            }
            if (trainingVectors < minTrainingVectorCount) {
                builder.valid(false).minTrainingVectorCount(minTrainingVectorCount);
                builder.errorMessage(String.format(Locale.ROOT, "Number of training points should be greater than %d", minTrainingVectorCount));
                return builder.build();
            }
            builder.valid(true);
        }
        return builder.build();
    }

    /**
     * Reads a parameter from {@code context}, validates it against this encoder's parameter
     * definition, and returns it cast to {@link Integer}.
     *
     * @throws ValidationException if validation fails
     */
    private Integer validatedIntegerParameter(MethodComponentContext context, String name, KNNMethodConfigContext knnMethodConfigContext) {
        Object value = context.getParameters().get(name);
        ValidationException validationException = this.getMethodComponent().getParameters().get(name).validate(value, knnMethodConfigContext);
        if (validationException != null) {
            throw validationException;
        }
        return (Integer) value;
    }

    /**
     * Returns the encoder's own context, stored under the method-level "encoder" parameter.
     * Returns {@code null} when it is absent or not a {@link MethodComponentContext}.
     */
    private static MethodComponentContext getEncoderContext(KNNMethodContext knnMethodContext) {
        MethodComponentContext methodContext = knnMethodContext.getMethodComponentContext();
        if (methodContext == null) {
            return null;
        }
        Object encoder = methodContext.getParameters().get(METHOD_PARAMETER_ENCODER);
        return encoder instanceof MethodComponentContext ? (MethodComponentContext) encoder : null;
    }
}

