/*
 * Decompiled with CFR 0.152.
 */
package org.opensearch.neuralsearch.sparse.query.explain;

import java.io.IOException;
import java.util.ArrayList;
import java.util.Comparator;
import java.util.List;
import java.util.Locale;
import java.util.Map;
import java.util.Objects;
import lombok.Generated;
import lombok.NonNull;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.apache.lucene.index.FieldInfo;
import org.apache.lucene.index.LeafReaderContext;
import org.apache.lucene.search.Explanation;
import org.apache.lucene.search.Query;
import org.apache.lucene.util.BitSet;
import org.opensearch.neuralsearch.sparse.accessor.SparseVectorReader;
import org.opensearch.neuralsearch.sparse.common.IteratorWrapper;
import org.opensearch.neuralsearch.sparse.data.SparseVector;
import org.opensearch.neuralsearch.sparse.quantization.ByteQuantizationUtil;
import org.opensearch.neuralsearch.sparse.query.SparseVectorQuery;

/**
 * Builds a Lucene {@link Explanation} describing how a sparse_ann score was produced for a single document.
 *
 * <p>The explained score is {@code rawScore * rescaledBoost}, where {@code rawScore} is the quantized (byte) dot product
 * between the document's sparse vector and the query's dense byte vector, and
 * {@code rescaledBoost = boost * ceilingIngest * ceilingSearch / 255 / 255}. The explanation details cover:
 * <ul>
 *   <li>query token pruning;</li>
 *   <li>the raw score with per-token contributions;</li>
 *   <li>the quantization rescaling;</li>
 *   <li>the filter, when the query has one.</li>
 * </ul>
 *
 * <p>Instances are created through {@link #builder()}; every reference-typed component is required to be non-null.
 */
public class SparseExplanationBuilder {
    @Generated
    private static final Logger log = LogManager.getLogger(SparseExplanationBuilder.class);

    /** Largest value of an unsigned byte; used to undo the byte quantization of both query and document weights. */
    private static final int MAX_UNSIGNED_BYTE_VALUE = 255;

    /** Segment the document belongs to; its {@code id()} keys the per-segment filter results. */
    @NonNull
    private final LeafReaderContext context;
    /** Segment-local document id being explained. */
    private final int docId;
    @NonNull
    private final SparseVectorQuery query;
    /** Boost applied to the query before quantization rescaling. */
    private final float boost;
    /** Field metadata from which the quantization ceilings are read. */
    @NonNull
    private final FieldInfo fieldInfo;
    /** Reader used to load the document's sparse vector. */
    @NonNull
    private final SparseVectorReader reader;

    /**
     * Produces the explanation for {@code docId}.
     *
     * @return a match whose value is {@code rawScore * rescaledBoost}. Returns a no-match instead when any of these hold:
     *     the document id is negative; the query vector is null or empty; the document vector cannot be read or is
     *     absent; or the document is excluded by the query's filter.
     */
    public Explanation explain() {
        if (this.docId < 0) {
            return Explanation.noMatch(String.format(Locale.ROOT, "invalid document ID %d (must be non-negative)", this.docId));
        }
        // The message promises to cover a null vector too, so check it explicitly instead of failing with an NPE.
        if (this.query.getQueryVector() == null || this.query.getQueryVector().getSize() == 0) {
            return Explanation.noMatch(
                String.format(Locale.ROOT, "query vector is empty or null for field '%s'", this.query.getFieldName())
            );
        }

        SparseVector docVector;
        try {
            docVector = this.reader.read(this.docId);
        } catch (IOException e) {
            // Keep the full cause available for diagnosis; the explanation itself only carries the message.
            log.debug("Failed to read sparse vector for doc {} in field '{}'", this.docId, this.query.getFieldName(), e);
            return Explanation.noMatch(
                String.format(
                    Locale.ROOT,
                    "error reading document %d in field '%s': %s",
                    this.docId,
                    this.query.getFieldName(),
                    e.getMessage()
                )
            );
        }
        if (docVector == null) {
            return Explanation.noMatch(
                String.format(
                    Locale.ROOT,
                    "document %d not found or has no sparse vector in field '%s'",
                    this.docId,
                    this.query.getFieldName()
                )
            );
        }

        byte[] queryDenseVector = this.query.getQueryVector().toDenseVector();
        int rawScore = docVector.dotProduct(queryDenseVector);

        List<Explanation> details = new ArrayList<>();
        details.add(explainQueryPruning());
        details.add(explainRawScore(rawScore, docVector, queryDenseVector));
        details.add(explainQuantizationRescaling());

        Explanation filterExplanation = null;
        if (this.query.getFilter() != null) {
            filterExplanation = explainFilter();
            details.add(filterExplanation);
        }
        // A document rejected by the filter is never returned by the search, so the explanation must not claim a match.
        if (filterExplanation != null && !filterExplanation.isMatch()) {
            return Explanation.noMatch(
                String.format(
                    Locale.ROOT,
                    "sparse_ann: doc %d in field '%s' excluded by filter",
                    this.docId,
                    this.query.getFieldName()
                ),
                details
            );
        }

        float finalScore = rawScore * calculateRescaledBoost();
        return Explanation.match(
            finalScore,
            String.format(Locale.ROOT, "sparse_ann score for doc %d in field '%s'", this.docId, this.query.getFieldName()),
            details
        );
    }

    /**
     * Explains the quantized raw score by listing each query-context token that contributes to it.
     *
     * <p>Tokens that are not integers are logged and skipped. So are tokens outside the dense query vector, tokens with a
     * zero query weight, and tokens absent from (or zero in) the document vector. Contributions are listed in descending
     * order.
     *
     * <p>NOTE(review): {@code rawScore} is computed over the full query vector, whereas these details iterate only the
     * query-context tokens (which {@link #explainQueryPruning()} reports as the kept tokens). The details may therefore
     * not sum to {@code rawScore} — confirm this is intended.
     *
     * @param rawScore quantized dot product already computed by the caller
     * @param docVector the document's sparse vector
     * @param queryDenseVector the query vector in dense byte form, indexed by token id
     * @return a match explanation valued at {@code rawScore}, with one detail per contributing token
     */
    private Explanation explainRawScore(int rawScore, SparseVector docVector, byte[] queryDenseVector) {
        List<String> queryTokens = this.query.getQueryContext().getTokens();
        List<TokenContribution> contributions = new ArrayList<>();
        for (String tokenStr : queryTokens) {
            int tokenId;
            try {
                tokenId = Integer.parseInt(tokenStr);
            } catch (NumberFormatException e) {
                log.warn("Invalid token ID '{}' in query context, skipping", tokenStr);
                continue;
            }
            if (tokenId < 0 || tokenId >= queryDenseVector.length) {
                continue;
            }
            byte queryWeight = queryDenseVector[tokenId];
            if (queryWeight == 0) {
                continue;
            }
            byte docWeight = findDocWeight(docVector, tokenId);
            if (docWeight == 0) {
                continue;
            }
            int contribution = ByteQuantizationUtil.multiplyUnsignedByte(queryWeight, docWeight);
            contributions.add(new TokenContribution(tokenStr, queryWeight, docWeight, contribution));
        }

        // List.sort is stable, so equal contributions keep their query-token order.
        contributions.sort(Comparator.comparingInt(TokenContribution::getContribution).reversed());

        List<Explanation> tokenDetails = new ArrayList<>(contributions.size());
        for (TokenContribution tc : contributions) {
            tokenDetails.add(
                Explanation.match(
                    tc.getContribution(),
                    String.format(
                        Locale.ROOT,
                        "token '%s' contribution: query_weight=%d * doc_weight=%d",
                        tc.getToken(),
                        ByteQuantizationUtil.getUnsignedByte(tc.getQueryWeight()),
                        ByteQuantizationUtil.getUnsignedByte(tc.getDocWeight())
                    )
                )
            );
        }
        return Explanation.match(rawScore, String.format(Locale.ROOT, "raw dot product score (quantized): %d", rawScore), tokenDetails);
    }

    /**
     * Scans {@code docVector} for the entry whose token equals {@code SparseVector.prepareTokenForShortType(tokenId)}.
     *
     * @return that entry's weight, or {@code 0} if no entry matches
     */
    private static byte findDocWeight(SparseVector docVector, int tokenId) {
        IteratorWrapper<SparseVector.Item> iterator = docVector.iterator();
        SparseVector.Item item;
        while ((item = iterator.next()) != null) {
            if (item.getToken() == SparseVector.prepareTokenForShortType(tokenId)) {
                return item.getWeight();
            }
        }
        return 0;
    }

    /** Returns {@code boost * ceilingIngest * ceilingSearch / 255 / 255}, the factor applied to the quantized raw score. */
    private float calculateRescaledBoost() {
        float ceilingIngest = ByteQuantizationUtil.getCeilingValueIngest(this.fieldInfo);
        float ceilingSearch = ByteQuantizationUtil.getCeilingValueSearch(this.fieldInfo);
        return this.boost * ceilingIngest * ceilingSearch / MAX_UNSIGNED_BYTE_VALUE / MAX_UNSIGNED_BYTE_VALUE;
    }

    /** Explains how {@link #calculateRescaledBoost()} is derived from the boost and the field's quantization ceilings. */
    private Explanation explainQuantizationRescaling() {
        float ceilingIngest = ByteQuantizationUtil.getCeilingValueIngest(this.fieldInfo);
        float ceilingSearch = ByteQuantizationUtil.getCeilingValueSearch(this.fieldInfo);
        float rescaledBoost = calculateRescaledBoost();

        List<Explanation> details = new ArrayList<>();
        details.add(Explanation.match(this.boost, String.format(Locale.ROOT, "original boost: %.4f", this.boost)));
        details.add(
            Explanation.match(ceilingIngest, String.format(Locale.ROOT, "ceiling_ingest (quantization parameter): %.2f", ceilingIngest))
        );
        details.add(
            Explanation.match(ceilingSearch, String.format(Locale.ROOT, "ceiling_search (quantization parameter): %.2f", ceilingSearch))
        );
        details.add(
            Explanation.match(
                MAX_UNSIGNED_BYTE_VALUE,
                String.format(Locale.ROOT, "MAX_UNSIGNED_BYTE_VALUE: %d", MAX_UNSIGNED_BYTE_VALUE)
            )
        );
        return Explanation.match(
            rescaledBoost,
            String.format(
                Locale.ROOT,
                "quantization rescaling: %.4f * %.2f * %.2f / %d / %d = %.6f",
                this.boost,
                ceilingIngest,
                ceilingSearch,
                MAX_UNSIGNED_BYTE_VALUE,
                MAX_UNSIGNED_BYTE_VALUE,
                rescaledBoost
            ),
            details
        );
    }

    /**
     * Reports how many query-context tokens were kept relative to the number of entries in the query vector.
     * Assumes the query vector is non-null; {@link #explain()} checks this before calling here.
     */
    private Explanation explainQueryPruning() {
        int originalTokenCount = this.query.getQueryVector().getSize();
        int prunedTokenCount = this.query.getQueryContext().getTokens().size();
        if (originalTokenCount == prunedTokenCount) {
            return Explanation.match(
                prunedTokenCount,
                String.format(Locale.ROOT, "query token pruning: kept all %d tokens (no pruning occurred)", prunedTokenCount)
            );
        }
        return Explanation.match(
            prunedTokenCount,
            String.format(Locale.ROOT, "query token pruning: kept top %d of %d tokens", prunedTokenCount, originalTokenCount)
        );
    }

    /**
     * Explains whether the document passed the query's filter, using the per-segment bitsets from
     * {@code query.getFilterResults()}.
     *
     * @return the explanation, selected as follows:
     *     <ul>
     *       <li>a match valued 1.0 when no filter results are available;</li>
     *       <li>a match valued 1.0 when the document's bit is set — the description states exact mode if the filter
     *           matched at most {@code k} documents in the segment, approximate mode otherwise;</li>
     *       <li>a no-match when the segment has no bitset or the document's bit is clear.</li>
     *     </ul>
     */
    private Explanation explainFilter() {
        Map<Object, BitSet> filterResults = this.query.getFilterResults();
        if (filterResults == null) {
            return Explanation.match(1.0f, "filter present but no filter results available");
        }
        BitSet bitSet = filterResults.get(this.context.id());
        if (bitSet == null) {
            return Explanation.noMatch("document filtered out (no documents in segment matched filter)");
        }

        List<Explanation> details = new ArrayList<>();
        Query filterQuery = this.query.getFilter();
        if (filterQuery != null) {
            details.add(Explanation.match(1.0f, String.format(Locale.ROOT, "filter criteria: %s", filterQuery)));
        }
        if (!bitSet.get(this.docId)) {
            return Explanation.noMatch("document filtered out (did not match filter criteria, filter multiplier: 0.0)", details);
        }

        int passedCount = bitSet.cardinality();
        int k = this.query.getQueryContext().getK();
        if (passedCount <= k) {
            return Explanation.match(
                1.0f,
                String.format(
                    Locale.ROOT,
                    "document passed filter with exact search mode (filter matched %d documents <= k=%d, all filtered documents scored exactly)",
                    passedCount,
                    k
                ),
                details
            );
        }
        return Explanation.match(
            1.0f,
            String.format(
                Locale.ROOT,
                "document passed filter with approximate search mode (filter matched %d documents > k=%d, ANN search performed first then filtered)",
                passedCount,
                k
            ),
            details
        );
    }

    /**
     * Creates the builder; normally invoked via {@link SparseExplanationBuilderBuilder#build()}.
     *
     * @throws NullPointerException if {@code context}, {@code query}, {@code fieldInfo} or {@code reader} is null
     */
    @Generated
    SparseExplanationBuilder(
        @NonNull LeafReaderContext context,
        int docId,
        @NonNull SparseVectorQuery query,
        float boost,
        @NonNull FieldInfo fieldInfo,
        @NonNull SparseVectorReader reader
    ) {
        Objects.requireNonNull(context, "context is marked non-null but is null");
        Objects.requireNonNull(query, "query is marked non-null but is null");
        Objects.requireNonNull(fieldInfo, "fieldInfo is marked non-null but is null");
        Objects.requireNonNull(reader, "reader is marked non-null but is null");
        this.context = context;
        this.docId = docId;
        this.query = query;
        this.boost = boost;
        this.fieldInfo = fieldInfo;
        this.reader = reader;
    }

    /** Returns a new, empty builder. */
    @Generated
    public static SparseExplanationBuilderBuilder builder() {
        return new SparseExplanationBuilderBuilder();
    }

    /** Immutable record of one token's contribution to the quantized raw score. */
    private static final class TokenContribution {
        private final String token;
        private final byte queryWeight;
        private final byte docWeight;
        private final int contribution;

        @Generated
        public TokenContribution(String token, byte queryWeight, byte docWeight, int contribution) {
            this.token = token;
            this.queryWeight = queryWeight;
            this.docWeight = docWeight;
            this.contribution = contribution;
        }

        @Generated
        public String getToken() {
            return this.token;
        }

        @Generated
        public byte getQueryWeight() {
            return this.queryWeight;
        }

        @Generated
        public byte getDocWeight() {
            return this.docWeight;
        }

        @Generated
        public int getContribution() {
            return this.contribution;
        }

        @Generated
        @Override
        public boolean equals(Object o) {
            if (o == this) {
                return true;
            }
            if (!(o instanceof TokenContribution)) {
                return false;
            }
            TokenContribution other = (TokenContribution) o;
            return this.queryWeight == other.queryWeight
                && this.docWeight == other.docWeight
                && this.contribution == other.contribution
                && Objects.equals(this.token, other.token);
        }

        @Generated
        @Override
        public int hashCode() {
            final int prime = 59;
            int result = 1;
            result = result * prime + this.queryWeight;
            result = result * prime + this.docWeight;
            result = result * prime + this.contribution;
            result = result * prime + (this.token == null ? 43 : this.token.hashCode());
            return result;
        }

        @Generated
        @Override
        public String toString() {
            return "SparseExplanationBuilder.TokenContribution(token="
                + this.token
                + ", queryWeight="
                + this.queryWeight
                + ", docWeight="
                + this.docWeight
                + ", contribution="
                + this.contribution
                + ")";
        }
    }

    /** Fluent builder for {@link SparseExplanationBuilder}; non-null components are validated as they are set and again on build. */
    @Generated
    public static class SparseExplanationBuilderBuilder {
        @Generated
        private LeafReaderContext context;
        @Generated
        private int docId;
        @Generated
        private SparseVectorQuery query;
        @Generated
        private float boost;
        @Generated
        private FieldInfo fieldInfo;
        @Generated
        private SparseVectorReader reader;

        @Generated
        SparseExplanationBuilderBuilder() {}

        @Generated
        public SparseExplanationBuilderBuilder context(@NonNull LeafReaderContext context) {
            Objects.requireNonNull(context, "context is marked non-null but is null");
            this.context = context;
            return this;
        }

        @Generated
        public SparseExplanationBuilderBuilder docId(int docId) {
            this.docId = docId;
            return this;
        }

        @Generated
        public SparseExplanationBuilderBuilder query(@NonNull SparseVectorQuery query) {
            Objects.requireNonNull(query, "query is marked non-null but is null");
            this.query = query;
            return this;
        }

        @Generated
        public SparseExplanationBuilderBuilder boost(float boost) {
            this.boost = boost;
            return this;
        }

        @Generated
        public SparseExplanationBuilderBuilder fieldInfo(@NonNull FieldInfo fieldInfo) {
            Objects.requireNonNull(fieldInfo, "fieldInfo is marked non-null but is null");
            this.fieldInfo = fieldInfo;
            return this;
        }

        @Generated
        public SparseExplanationBuilderBuilder reader(@NonNull SparseVectorReader reader) {
            Objects.requireNonNull(reader, "reader is marked non-null but is null");
            this.reader = reader;
            return this;
        }

        /** @throws NullPointerException if any required component was never set */
        @Generated
        public SparseExplanationBuilder build() {
            return new SparseExplanationBuilder(this.context, this.docId, this.query, this.boost, this.fieldInfo, this.reader);
        }

        @Generated
        @Override
        public String toString() {
            return "SparseExplanationBuilder.SparseExplanationBuilderBuilder(context="
                + this.context
                + ", docId="
                + this.docId
                + ", query="
                + this.query
                + ", boost="
                + this.boost
                + ", fieldInfo="
                + this.fieldInfo
                + ", reader="
                + this.reader
                + ")";
        }
    }
}

