/*
 * Decompiled with CFR 0.152.
 */
package org.opensearch.knn.index.query;

import com.google.common.base.Predicates;
import java.io.IOException;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Comparator;
import java.util.HashMap;
import java.util.Locale;
import java.util.function.Predicate;
import lombok.Generated;
import lombok.NonNull;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.apache.lucene.index.FieldInfo;
import org.apache.lucene.index.LeafReader;
import org.apache.lucene.index.LeafReaderContext;
import org.apache.lucene.index.SegmentReader;
import org.apache.lucene.index.VectorSimilarityFunction;
import org.apache.lucene.search.DocIdSetIterator;
import org.apache.lucene.search.HitQueue;
import org.apache.lucene.search.ScoreDoc;
import org.apache.lucene.search.TopDocs;
import org.apache.lucene.search.TopDocsCollector;
import org.apache.lucene.search.TotalHits;
import org.apache.lucene.search.join.BitSetProducer;
import org.opensearch.common.lucene.Lucene;
import org.opensearch.knn.common.FieldInfoExtractor;
import org.opensearch.knn.index.SpaceType;
import org.opensearch.knn.index.VectorDataType;
import org.opensearch.knn.index.engine.KNNEngine;
import org.opensearch.knn.index.query.SegmentLevelQuantizationInfo;
import org.opensearch.knn.index.query.SegmentLevelQuantizationUtil;
import org.opensearch.knn.index.query.iterators.BinaryVectorIdsKNNIterator;
import org.opensearch.knn.index.query.iterators.ByteVectorIdsKNNIterator;
import org.opensearch.knn.index.query.iterators.KNNIterator;
import org.opensearch.knn.index.query.iterators.NestedBinaryVectorIdsKNNIterator;
import org.opensearch.knn.index.query.iterators.NestedByteVectorIdsKNNIterator;
import org.opensearch.knn.index.query.iterators.NestedVectorIdsKNNIterator;
import org.opensearch.knn.index.query.iterators.VectorIdsKNNIterator;
import org.opensearch.knn.index.vectorvalues.KNNBinaryVectorValues;
import org.opensearch.knn.index.vectorvalues.KNNByteVectorValues;
import org.opensearch.knn.index.vectorvalues.KNNFloatVectorValues;
import org.opensearch.knn.index.vectorvalues.KNNVectorValues;
import org.opensearch.knn.index.vectorvalues.KNNVectorValuesFactory;
import org.opensearch.knn.indices.ModelDao;

/**
 * Performs exact (exhaustive) k-NN and radial search over the vectors of a single leaf (segment).
 *
 * <p>Candidate documents are produced and scored one by one by a {@link KNNIterator} that reads the
 * field's vectors through {@link KNNVectorValuesFactory}.
 */
public class ExactSearcher {
    @Generated
    private static final Logger log = LogManager.getLogger(ExactSearcher.class);
    private final ModelDao modelDao;

    /**
     * Runs an exact search against one leaf.
     *
     * <ul>
     *   <li>If the field is not present in this segment, returns {@link TopDocsCollector#EMPTY_TOPDOCS}.</li>
     *   <li>If a radius is set, performs a radial (min-score) search.</li>
     *   <li>If a matched-docs iterator is present and it matches at most {@code k} documents, every matched
     *       document is scored and returned, sorted by descending score.</li>
     *   <li>Otherwise the {@code k} highest-scoring documents are returned.</li>
     * </ul>
     *
     * @param leafReaderContext the leaf to search
     * @param context search parameters
     * @return the hits for this leaf, sorted by descending score
     * @throws IOException if reading vectors fails
     */
    public TopDocs searchLeaf(LeafReaderContext leafReaderContext, ExactSearcherContext context) throws IOException {
        KNNIterator iterator = this.getKNNIterator(leafReaderContext, context);
        if (iterator == null) {
            return TopDocsCollector.EMPTY_TOPDOCS;
        }
        if (context.getRadius() != null) {
            return this.doRadialSearch(leafReaderContext, context, iterator);
        }
        // With a filter that matches no more than k docs, every matched doc is a result: skip the heap.
        if (context.getMatchedDocsIterator() != null && context.getNumberOfMatchedDocs() <= (long) context.getK()) {
            return this.scoreAllDocs(iterator);
        }
        return this.searchTopCandidates(iterator, context.getK(), Predicates.alwaysTrue());
    }

    /**
     * Returns every document whose score is at least the score derived from the context's radius, capped at
     * {@code maxResultWindow} results.
     *
     * @throws IllegalArgumentException if the field's engine is not FAISS
     */
    private TopDocs doRadialSearch(LeafReaderContext leafReaderContext, ExactSearcherContext context, KNNIterator iterator) throws IOException {
        assert context.getIsMemoryOptimizedSearchEnabled() != null;
        SegmentReader reader = Lucene.segmentReader(leafReaderContext.reader());
        FieldInfo fieldInfo = FieldInfoExtractor.getFieldInfo(reader, context.getField());
        if (fieldInfo == null) {
            return TopDocsCollector.EMPTY_TOPDOCS;
        }
        KNNEngine engine = FieldInfoExtractor.extractKNNEngine(fieldInfo);
        if (KNNEngine.FAISS != engine) {
            throw new IllegalArgumentException(String.format(Locale.ROOT, "Engine [%s] does not support radial search", engine));
        }
        SpaceType spaceType = FieldInfoExtractor.getSpaceType(this.modelDao, fieldInfo);
        float radius = context.getRadius().floatValue();
        // With memory-optimized search the radius is used as the score threshold as-is; otherwise it is
        // translated into a score by the field's space type. Unboxing throws NPE if the flag is unset.
        float minScore = context.getIsMemoryOptimizedSearchEnabled() ? radius : spaceType.scoreTranslation(radius);
        return this.filterDocsByMinScore(context, iterator, minScore);
    }

    /**
     * Scores every document produced by {@code iterator}.
     *
     * @return all scored documents sorted by descending score, with an exact ({@code EQUAL_TO}) hit count
     */
    private TopDocs scoreAllDocs(KNNIterator iterator) throws IOException {
        ArrayList<ScoreDoc> scoreDocs = new ArrayList<>();
        int docId;
        while ((docId = iterator.nextDoc()) != DocIdSetIterator.NO_MORE_DOCS) {
            scoreDocs.add(new ScoreDoc(docId, iterator.score()));
        }
        // Descending by score; Float.compare matches the ordering of Float's natural order, without boxing.
        scoreDocs.sort((left, right) -> Float.compare(right.score, left.score));
        return new TopDocs(new TotalHits(scoreDocs.size(), TotalHits.Relation.EQUAL_TO), scoreDocs.toArray(ScoreDoc[]::new));
    }

    /**
     * Keeps the {@code limit} highest-scoring documents whose score passes {@code filterScore}.
     *
     * <p>NOTE(review): entries with a negative score are dropped at the end, which {@link #scoreAllDocs}
     * does not do.
     *
     * @param iterator source of candidate documents and their scores
     * @param limit maximum number of hits to keep
     * @param filterScore predicate a score must satisfy to be considered
     * @return up to {@code limit} hits sorted by descending score
     */
    private TopDocs searchTopCandidates(KNNIterator iterator, int limit, @NonNull Predicate<Float> filterScore) throws IOException {
        if (filterScore == null) {
            throw new NullPointerException("filterScore is marked non-null but is null");
        }
        // HitQueue is a min-heap; pre-populating fills it with sentinel entries scored -Infinity, so the head is
        // always the weakest entry kept so far and can be overwritten in place.
        HitQueue queue = new HitQueue(limit, true);
        ScoreDoc weakest = queue.top();
        int docId;
        while ((docId = iterator.nextDoc()) != DocIdSetIterator.NO_MORE_DOCS) {
            float score = iterator.score();
            if (filterScore.test(score) && score > weakest.score) {
                weakest.score = score;
                weakest.doc = docId;
                weakest = queue.updateTop();
            }
        }
        // Remove sentinels left over when fewer than `limit` docs qualified. The `< 0` test also discards any
        // genuinely negative scores.
        while (queue.size() > 0 && queue.top().score < 0.0f) {
            queue.pop();
        }
        ScoreDoc[] topScoreDocs = new ScoreDoc[queue.size()];
        // pop() returns the smallest remaining score, so fill the array back-to-front for descending order.
        for (int i = topScoreDocs.length - 1; i >= 0; --i) {
            topScoreDocs[i] = queue.pop();
        }
        return new TopDocs(new TotalHits(topScoreDocs.length, TotalHits.Relation.EQUAL_TO), topScoreDocs);
    }

    /**
     * Keeps at most {@code maxResultWindow} documents whose score is {@code >= minScore}.
     */
    private TopDocs filterDocsByMinScore(ExactSearcherContext context, KNNIterator iterator, float minScore) throws IOException {
        int maxResultWindow = context.getMaxResultWindow();
        Predicate<Float> atLeastMinScore = score -> score.floatValue() >= minScore;
        return this.searchTopCandidates(iterator, maxResultWindow, atLeastMinScore);
    }

    /**
     * Builds the iterator that scores candidate documents for the context's field in this leaf.
     *
     * <p>The iterator type follows the field's vector data type (binary, byte, otherwise float). A nested
     * variant is used when the context has a parents filter. For float fields, the query vector is
     * quantized, or transformed for ADC, when quantized-vector search is requested.
     *
     * @return the iterator, or {@code null} if the field is not present in this segment
     */
    private KNNIterator getKNNIterator(LeafReaderContext leafReaderContext, ExactSearcherContext exactSearcherContext) throws IOException {
        DocIdSetIterator matchedDocs = exactSearcherContext.getMatchedDocsIterator();
        SegmentReader reader = Lucene.segmentReader(leafReaderContext.reader());
        FieldInfo fieldInfo = FieldInfoExtractor.getFieldInfo(reader, exactSearcherContext.getField());
        if (fieldInfo == null) {
            log.debug("[KNN] Cannot get KNNIterator as Field info not found for {}:{}", exactSearcherContext.getField(), reader.getSegmentName());
            return null;
        }
        VectorDataType vectorDataType = FieldInfoExtractor.extractVectorDataType(fieldInfo);
        SpaceType spaceType = FieldInfoExtractor.getSpaceType(this.modelDao, fieldInfo);
        boolean isNestedRequired = exactSearcherContext.getParentsFilter() != null;

        if (VectorDataType.BINARY == vectorDataType) {
            KNNVectorValues vectorValues = KNNVectorValuesFactory.getVectorValues(fieldInfo, reader);
            if (isNestedRequired) {
                return new NestedBinaryVectorIdsKNNIterator(matchedDocs, exactSearcherContext.getByteQueryVector(), (KNNBinaryVectorValues) vectorValues, spaceType, exactSearcherContext.getParentsFilter().getBitSet(leafReaderContext));
            }
            return new BinaryVectorIdsKNNIterator(matchedDocs, exactSearcherContext.getByteQueryVector(), (KNNBinaryVectorValues) vectorValues, spaceType);
        }

        if (VectorDataType.BYTE == vectorDataType) {
            // Byte fields are queried with the float query vector.
            KNNVectorValues vectorValues = KNNVectorValuesFactory.getVectorValues(fieldInfo, reader);
            if (isNestedRequired) {
                return new NestedByteVectorIdsKNNIterator(matchedDocs, exactSearcherContext.getFloatQueryVector(), (KNNByteVectorValues) vectorValues, spaceType, exactSearcherContext.getParentsFilter().getBitSet(leafReaderContext));
            }
            return new ByteVectorIdsKNNIterator(matchedDocs, exactSearcherContext.getFloatQueryVector(), (KNNByteVectorValues) vectorValues, spaceType);
        }

        byte[] quantizedQueryVector = null;
        SegmentLevelQuantizationInfo segmentLevelQuantizationInfo = null;
        if (exactSearcherContext.isUseQuantizedVectorsForSearch()) {
            segmentLevelQuantizationInfo = SegmentLevelQuantizationInfo.build(reader, fieldInfo, exactSearcherContext.getField());
            if (SegmentLevelQuantizationUtil.isAdcEnabled(segmentLevelQuantizationInfo)) {
                // NOTE(review): no result is captured here, so this presumably rewrites the context's float query
                // vector in place — confirm, and check that the same context is not transformed twice.
                SegmentLevelQuantizationUtil.transformVectorWithADC(exactSearcherContext.getFloatQueryVector(), segmentLevelQuantizationInfo, spaceType);
            } else {
                quantizedQueryVector = SegmentLevelQuantizationUtil.quantizeVector(exactSearcherContext.getFloatQueryVector(), segmentLevelQuantizationInfo);
            }
        }
        KNNVectorValues vectorValues = KNNVectorValuesFactory.getVectorValues(fieldInfo, reader);
        if (isNestedRequired) {
            return new NestedVectorIdsKNNIterator(matchedDocs, exactSearcherContext.getFloatQueryVector(), (KNNFloatVectorValues) vectorValues, spaceType, exactSearcherContext.getParentsFilter().getBitSet(leafReaderContext), quantizedQueryVector, segmentLevelQuantizationInfo);
        }
        return new VectorIdsKNNIterator(matchedDocs, exactSearcherContext.getFloatQueryVector(), (KNNFloatVectorValues) vectorValues, spaceType, quantizedQueryVector, segmentLevelQuantizationInfo);
    }

    /**
     * @param modelDao used to resolve the space type of model-based fields
     */
    @Generated
    public ExactSearcher(ModelDao modelDao) {
        this.modelDao = modelDao;
    }

    /**
     * Parameters for a single exact search, created through {@link #builder()}.
     *
     * <p>Array fields are stored and returned without defensive copies.
     */
    public static final class ExactSearcherContext {
        /** When true, float fields are searched with a quantized (or ADC-transformed) query vector. */
        private final boolean useQuantizedVectorsForSearch;
        /** Number of top hits to return for non-radial search. */
        private final int k;
        /** When non-null, a radial search is performed instead of a top-k search. */
        private final Float radius;
        /** Optional filter restricting which documents are scored. */
        private final DocIdSetIterator matchedDocsIterator;
        /** Number of documents matched by {@link #matchedDocsIterator}; compared against {@link #k}. */
        private final long numberOfMatchedDocs;
        /** When non-null, nested-document iterators are used. */
        private final BitSetProducer parentsFilter;
        private final float[] floatQueryVector;
        private final byte[] byteQueryVector;
        private final String field;
        /** Upper bound on the number of hits returned by a radial search. */
        private final Integer maxResultWindow;
        /** Not read by {@link ExactSearcher} itself. */
        private final VectorSimilarityFunction similarityFunction;
        /** Must be set for radial search; when true the radius is used as the score threshold directly. */
        private final Boolean isMemoryOptimizedSearchEnabled;

        @Generated
        ExactSearcherContext(boolean useQuantizedVectorsForSearch, int k, Float radius, DocIdSetIterator matchedDocsIterator, long numberOfMatchedDocs, BitSetProducer parentsFilter, float[] floatQueryVector, byte[] byteQueryVector, String field, Integer maxResultWindow, VectorSimilarityFunction similarityFunction, Boolean isMemoryOptimizedSearchEnabled) {
            this.useQuantizedVectorsForSearch = useQuantizedVectorsForSearch;
            this.k = k;
            this.radius = radius;
            this.matchedDocsIterator = matchedDocsIterator;
            this.numberOfMatchedDocs = numberOfMatchedDocs;
            this.parentsFilter = parentsFilter;
            this.floatQueryVector = floatQueryVector;
            this.byteQueryVector = byteQueryVector;
            this.field = field;
            this.maxResultWindow = maxResultWindow;
            this.similarityFunction = similarityFunction;
            this.isMemoryOptimizedSearchEnabled = isMemoryOptimizedSearchEnabled;
        }

        @Generated
        public static ExactSearcherContextBuilder builder() {
            return new ExactSearcherContextBuilder();
        }

        @Generated
        public boolean isUseQuantizedVectorsForSearch() {
            return this.useQuantizedVectorsForSearch;
        }

        @Generated
        public int getK() {
            return this.k;
        }

        @Generated
        public Float getRadius() {
            return this.radius;
        }

        @Generated
        public DocIdSetIterator getMatchedDocsIterator() {
            return this.matchedDocsIterator;
        }

        @Generated
        public long getNumberOfMatchedDocs() {
            return this.numberOfMatchedDocs;
        }

        @Generated
        public BitSetProducer getParentsFilter() {
            return this.parentsFilter;
        }

        @Generated
        public float[] getFloatQueryVector() {
            return this.floatQueryVector;
        }

        @Generated
        public byte[] getByteQueryVector() {
            return this.byteQueryVector;
        }

        @Generated
        public String getField() {
            return this.field;
        }

        @Generated
        public Integer getMaxResultWindow() {
            return this.maxResultWindow;
        }

        @Generated
        public VectorSimilarityFunction getSimilarityFunction() {
            return this.similarityFunction;
        }

        @Generated
        public Boolean getIsMemoryOptimizedSearchEnabled() {
            return this.isMemoryOptimizedSearchEnabled;
        }

        @Generated
        @Override
        public boolean equals(Object o) {
            if (o == this) {
                return true;
            }
            if (!(o instanceof ExactSearcherContext)) {
                return false;
            }
            ExactSearcherContext other = (ExactSearcherContext) o;
            if (this.isUseQuantizedVectorsForSearch() != other.isUseQuantizedVectorsForSearch()) {
                return false;
            }
            if (this.getK() != other.getK()) {
                return false;
            }
            if (this.getNumberOfMatchedDocs() != other.getNumberOfMatchedDocs()) {
                return false;
            }
            Float this$radius = this.getRadius();
            Float other$radius = other.getRadius();
            if (this$radius == null ? other$radius != null : !this$radius.equals(other$radius)) {
                return false;
            }
            Integer this$maxResultWindow = this.getMaxResultWindow();
            Integer other$maxResultWindow = other.getMaxResultWindow();
            if (this$maxResultWindow == null ? other$maxResultWindow != null : !this$maxResultWindow.equals(other$maxResultWindow)) {
                return false;
            }
            Boolean this$isMemoryOptimizedSearchEnabled = this.getIsMemoryOptimizedSearchEnabled();
            Boolean other$isMemoryOptimizedSearchEnabled = other.getIsMemoryOptimizedSearchEnabled();
            if (this$isMemoryOptimizedSearchEnabled == null ? other$isMemoryOptimizedSearchEnabled != null : !this$isMemoryOptimizedSearchEnabled.equals(other$isMemoryOptimizedSearchEnabled)) {
                return false;
            }
            DocIdSetIterator this$matchedDocsIterator = this.getMatchedDocsIterator();
            DocIdSetIterator other$matchedDocsIterator = other.getMatchedDocsIterator();
            if (this$matchedDocsIterator == null ? other$matchedDocsIterator != null : !this$matchedDocsIterator.equals(other$matchedDocsIterator)) {
                return false;
            }
            BitSetProducer this$parentsFilter = this.getParentsFilter();
            BitSetProducer other$parentsFilter = other.getParentsFilter();
            if (this$parentsFilter == null ? other$parentsFilter != null : !this$parentsFilter.equals(other$parentsFilter)) {
                return false;
            }
            if (!Arrays.equals(this.getFloatQueryVector(), other.getFloatQueryVector())) {
                return false;
            }
            if (!Arrays.equals(this.getByteQueryVector(), other.getByteQueryVector())) {
                return false;
            }
            String this$field = this.getField();
            String other$field = other.getField();
            if (this$field == null ? other$field != null : !this$field.equals(other$field)) {
                return false;
            }
            VectorSimilarityFunction this$similarityFunction = this.getSimilarityFunction();
            VectorSimilarityFunction other$similarityFunction = other.getSimilarityFunction();
            return !(this$similarityFunction == null ? other$similarityFunction != null : !this$similarityFunction.equals(other$similarityFunction));
        }

        @Generated
        @Override
        public int hashCode() {
            final int PRIME = 59;
            int result = 1;
            result = result * PRIME + (this.isUseQuantizedVectorsForSearch() ? 79 : 97);
            result = result * PRIME + this.getK();
            long $numberOfMatchedDocs = this.getNumberOfMatchedDocs();
            result = result * PRIME + (int) ($numberOfMatchedDocs >>> 32 ^ $numberOfMatchedDocs);
            Float $radius = this.getRadius();
            result = result * PRIME + ($radius == null ? 43 : $radius.hashCode());
            Integer $maxResultWindow = this.getMaxResultWindow();
            result = result * PRIME + ($maxResultWindow == null ? 43 : $maxResultWindow.hashCode());
            Boolean $isMemoryOptimizedSearchEnabled = this.getIsMemoryOptimizedSearchEnabled();
            result = result * PRIME + ($isMemoryOptimizedSearchEnabled == null ? 43 : $isMemoryOptimizedSearchEnabled.hashCode());
            DocIdSetIterator $matchedDocsIterator = this.getMatchedDocsIterator();
            result = result * PRIME + ($matchedDocsIterator == null ? 43 : $matchedDocsIterator.hashCode());
            BitSetProducer $parentsFilter = this.getParentsFilter();
            result = result * PRIME + ($parentsFilter == null ? 43 : $parentsFilter.hashCode());
            result = result * PRIME + Arrays.hashCode(this.getFloatQueryVector());
            result = result * PRIME + Arrays.hashCode(this.getByteQueryVector());
            String $field = this.getField();
            result = result * PRIME + ($field == null ? 43 : $field.hashCode());
            VectorSimilarityFunction $similarityFunction = this.getSimilarityFunction();
            result = result * PRIME + ($similarityFunction == null ? 43 : $similarityFunction.hashCode());
            return result;
        }

        @Generated
        @Override
        public String toString() {
            return "ExactSearcher.ExactSearcherContext(useQuantizedVectorsForSearch=" + this.isUseQuantizedVectorsForSearch() + ", k=" + this.getK() + ", radius=" + this.getRadius() + ", matchedDocsIterator=" + String.valueOf(this.getMatchedDocsIterator()) + ", numberOfMatchedDocs=" + this.getNumberOfMatchedDocs() + ", parentsFilter=" + String.valueOf(this.getParentsFilter()) + ", floatQueryVector=" + Arrays.toString(this.getFloatQueryVector()) + ", byteQueryVector=" + Arrays.toString(this.getByteQueryVector()) + ", field=" + this.getField() + ", maxResultWindow=" + this.getMaxResultWindow() + ", similarityFunction=" + String.valueOf(this.getSimilarityFunction()) + ", isMemoryOptimizedSearchEnabled=" + this.getIsMemoryOptimizedSearchEnabled() + ")";
        }

        /** Fluent builder for {@link ExactSearcherContext}; unset fields take Java default values. */
        @Generated
        public static class ExactSearcherContextBuilder {
            @Generated
            private boolean useQuantizedVectorsForSearch;
            @Generated
            private int k;
            @Generated
            private Float radius;
            @Generated
            private DocIdSetIterator matchedDocsIterator;
            @Generated
            private long numberOfMatchedDocs;
            @Generated
            private BitSetProducer parentsFilter;
            @Generated
            private float[] floatQueryVector;
            @Generated
            private byte[] byteQueryVector;
            @Generated
            private String field;
            @Generated
            private Integer maxResultWindow;
            @Generated
            private VectorSimilarityFunction similarityFunction;
            @Generated
            private Boolean isMemoryOptimizedSearchEnabled;

            @Generated
            ExactSearcherContextBuilder() {
            }

            @Generated
            public ExactSearcherContextBuilder useQuantizedVectorsForSearch(boolean useQuantizedVectorsForSearch) {
                this.useQuantizedVectorsForSearch = useQuantizedVectorsForSearch;
                return this;
            }

            @Generated
            public ExactSearcherContextBuilder k(int k) {
                this.k = k;
                return this;
            }

            @Generated
            public ExactSearcherContextBuilder radius(Float radius) {
                this.radius = radius;
                return this;
            }

            @Generated
            public ExactSearcherContextBuilder matchedDocsIterator(DocIdSetIterator matchedDocsIterator) {
                this.matchedDocsIterator = matchedDocsIterator;
                return this;
            }

            @Generated
            public ExactSearcherContextBuilder numberOfMatchedDocs(long numberOfMatchedDocs) {
                this.numberOfMatchedDocs = numberOfMatchedDocs;
                return this;
            }

            @Generated
            public ExactSearcherContextBuilder parentsFilter(BitSetProducer parentsFilter) {
                this.parentsFilter = parentsFilter;
                return this;
            }

            @Generated
            public ExactSearcherContextBuilder floatQueryVector(float[] floatQueryVector) {
                this.floatQueryVector = floatQueryVector;
                return this;
            }

            @Generated
            public ExactSearcherContextBuilder byteQueryVector(byte[] byteQueryVector) {
                this.byteQueryVector = byteQueryVector;
                return this;
            }

            @Generated
            public ExactSearcherContextBuilder field(String field) {
                this.field = field;
                return this;
            }

            @Generated
            public ExactSearcherContextBuilder maxResultWindow(Integer maxResultWindow) {
                this.maxResultWindow = maxResultWindow;
                return this;
            }

            @Generated
            public ExactSearcherContextBuilder similarityFunction(VectorSimilarityFunction similarityFunction) {
                this.similarityFunction = similarityFunction;
                return this;
            }

            @Generated
            public ExactSearcherContextBuilder isMemoryOptimizedSearchEnabled(Boolean isMemoryOptimizedSearchEnabled) {
                this.isMemoryOptimizedSearchEnabled = isMemoryOptimizedSearchEnabled;
                return this;
            }

            @Generated
            public ExactSearcherContext build() {
                return new ExactSearcherContext(this.useQuantizedVectorsForSearch, this.k, this.radius, this.matchedDocsIterator, this.numberOfMatchedDocs, this.parentsFilter, this.floatQueryVector, this.byteQueryVector, this.field, this.maxResultWindow, this.similarityFunction, this.isMemoryOptimizedSearchEnabled);
            }

            @Generated
            @Override
            public String toString() {
                return "ExactSearcher.ExactSearcherContext.ExactSearcherContextBuilder(useQuantizedVectorsForSearch=" + this.useQuantizedVectorsForSearch + ", k=" + this.k + ", radius=" + this.radius + ", matchedDocsIterator=" + String.valueOf(this.matchedDocsIterator) + ", numberOfMatchedDocs=" + this.numberOfMatchedDocs + ", parentsFilter=" + String.valueOf(this.parentsFilter) + ", floatQueryVector=" + Arrays.toString(this.floatQueryVector) + ", byteQueryVector=" + Arrays.toString(this.byteQueryVector) + ", field=" + this.field + ", maxResultWindow=" + this.maxResultWindow + ", similarityFunction=" + String.valueOf(this.similarityFunction) + ", isMemoryOptimizedSearchEnabled=" + this.isMemoryOptimizedSearchEnabled + ")";
            }
        }
    }
}

