/*
 * Decompiled with CFR 0.152.
 */
package org.elasticsearch.xpack.inference.chunking;

import java.util.ArrayList;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.concurrent.atomic.AtomicReferenceArray;
import java.util.function.Supplier;
import java.util.stream.Collectors;
import org.elasticsearch.ElasticsearchStatusException;
import org.elasticsearch.action.ActionListener;
import org.elasticsearch.common.util.concurrent.AtomicArray;
import org.elasticsearch.inference.ChunkInferenceInput;
import org.elasticsearch.inference.ChunkedInference;
import org.elasticsearch.inference.ChunkingSettings;
import org.elasticsearch.inference.ChunkingStrategy;
import org.elasticsearch.inference.InferenceServiceResults;
import org.elasticsearch.rest.RestStatus;
import org.elasticsearch.xpack.core.inference.results.ChunkedInferenceEmbedding;
import org.elasticsearch.xpack.core.inference.results.ChunkedInferenceError;
import org.elasticsearch.xpack.core.inference.results.EmbeddingResults;
import org.elasticsearch.xpack.inference.chunking.Chunker;
import org.elasticsearch.xpack.inference.chunking.ChunkerBuilder;
import org.elasticsearch.xpack.inference.chunking.WordBoundaryChunkingSettings;

/**
 * Splits inference inputs into chunks, groups the chunks into batch requests, and
 * reassembles the per-batch embedding results into one {@link ChunkedInference} per input.
 *
 * <p>Flow as implemented here: the constructor chunks every input and builds the batches;
 * {@link #batchRequestsWithListeners(ActionListener)} pairs every batch with a listener;
 * once as many batch listeners have completed (successfully or not) as there are batches,
 * the final listener receives one entry per input, in input order.
 *
 * @param <E> the concrete embedding type; embeddings that map to the same stored chunk are
 *            combined through {@code merge}
 */
public class EmbeddingRequestChunker<E extends EmbeddingResults.Embedding<E>> {
    /** Chunking settings used when neither the input nor the caller supplies any. */
    private static final ChunkingSettings DEFAULT_CHUNKING_SETTINGS = new WordBoundaryChunkingSettings(250, 100);

    /**
     * Maximum number of result chunks stored per input. When the chunker produces more chunks
     * than this, every chunk is still sent as its own request, but several consecutive chunks
     * share one stored slot: the slot's end offset is extended and the embeddings are merged.
     */
    private static final int MAX_CHUNKS = 512;

    /** All chunk requests grouped into batches. */
    private final List<BatchRequest> batchRequests;
    /** Number of batch listeners that have completed, successfully or with a failure. */
    private final AtomicInteger resultCount = new AtomicInteger();
    /** Per input: start offset of every stored chunk. */
    private final List<List<Integer>> resultOffsetStarts;
    /** Per input: end offset of every stored chunk. */
    private final List<List<Integer>> resultOffsetEnds;
    /** Per input: the (possibly merged) embedding of every stored chunk. */
    private final List<AtomicReferenceArray<E>> resultEmbeddings;
    /** Per input: the most recently recorded batch failure, or {@code null} if none was recorded. */
    private final AtomicArray<Exception> resultsErrors;
    /** Receives the assembled results; assigned by {@link #batchRequestsWithListeners(ActionListener)}. */
    private ActionListener<List<ChunkedInference>> finalListener;

    /**
     * Creates a chunker that falls back to the built-in default chunking settings for inputs
     * that carry none.
     *
     * @param inputs                    the texts to chunk, each with optional chunking settings
     * @param maxNumberOfInputsPerBatch divisor used to group chunk requests into batches
     */
    public EmbeddingRequestChunker(List<ChunkInferenceInput> inputs, int maxNumberOfInputsPerBatch) {
        this(inputs, maxNumberOfInputsPerBatch, null);
    }

    /**
     * Creates a chunker whose fallback settings are word-boundary settings built from the
     * given values.
     *
     * @param inputs                    the texts to chunk, each with optional chunking settings
     * @param maxNumberOfInputsPerBatch divisor used to group chunk requests into batches
     * @param wordsPerChunk             first argument passed to {@link WordBoundaryChunkingSettings}
     * @param chunkOverlap              second argument passed to {@link WordBoundaryChunkingSettings}
     */
    public EmbeddingRequestChunker(List<ChunkInferenceInput> inputs, int maxNumberOfInputsPerBatch, int wordsPerChunk, int chunkOverlap) {
        this(inputs, maxNumberOfInputsPerBatch, new WordBoundaryChunkingSettings(wordsPerChunk, chunkOverlap));
    }

    /**
     * Chunks every input and groups the resulting chunk requests into batches.
     *
     * @param inputs                    the texts to chunk, each with optional chunking settings
     * @param maxNumberOfInputsPerBatch divisor used to group chunk requests into batches; request
     *                                  number {@code k} (counting from 0 over all inputs) is placed
     *                                  in group {@code k / maxNumberOfInputsPerBatch}
     * @param defaultChunkingSettings   settings for inputs without their own; when {@code null}
     *                                  the built-in defaults are used
     */
    public EmbeddingRequestChunker(List<ChunkInferenceInput> inputs, int maxNumberOfInputsPerBatch, ChunkingSettings defaultChunkingSettings) {
        this.resultEmbeddings = new ArrayList<>(inputs.size());
        this.resultOffsetStarts = new ArrayList<>(inputs.size());
        this.resultOffsetEnds = new ArrayList<>(inputs.size());
        this.resultsErrors = new AtomicArray<>(inputs.size());

        final ChunkingSettings fallbackSettings = defaultChunkingSettings == null ? DEFAULT_CHUNKING_SETTINGS : defaultChunkingSettings;

        // One chunker per distinct strategy explicitly requested by an input, plus one for the fallback.
        Map<ChunkingStrategy, Chunker> chunkers = inputs.stream()
            .map(ChunkInferenceInput::chunkingSettings)
            .filter(Objects::nonNull)
            .map(ChunkingSettings::getChunkingStrategy)
            .distinct()
            .collect(Collectors.toMap(chunkingStrategy -> chunkingStrategy, ChunkerBuilder::fromChunkingStrategy));
        Chunker defaultChunker = ChunkerBuilder.fromChunkingStrategy(fallbackSettings.getChunkingStrategy());

        List<Request> allRequests = new ArrayList<>();
        for (int inputIndex = 0; inputIndex < inputs.size(); inputIndex++) {
            ChunkInferenceInput input = inputs.get(inputIndex);
            ChunkingSettings inputSettings = input.chunkingSettings();
            ChunkingSettings chunkingSettings = inputSettings == null ? fallbackSettings : inputSettings;

            Chunker chunker = chunkers.getOrDefault(chunkingSettings.getChunkingStrategy(), defaultChunker);
            String inputString = input.input();
            List<Chunker.ChunkOffset> chunks = chunker.chunk(inputString, chunkingSettings);
            int chunkCount = chunks.size();
            int storedChunkCount = Math.min(chunkCount, MAX_CHUNKS);

            List<Integer> offsetStarts = new ArrayList<>(storedChunkCount);
            List<Integer> offsetEnds = new ArrayList<>(storedChunkCount);
            this.resultEmbeddings.add(new AtomicReferenceArray<>(storedChunkCount));
            this.resultOffsetStarts.add(offsetStarts);
            this.resultOffsetEnds.add(offsetEnds);

            for (int chunkIndex = 0; chunkIndex < chunkCount; chunkIndex++) {
                Chunker.ChunkOffset chunk = chunks.get(chunkIndex);
                // With more chunks than MAX_CHUNKS, scale the index into [0, MAX_CHUNKS) so that
                // consecutive chunks share a stored slot. The product is computed as a long:
                // in int arithmetic it overflows once chunkIndex reaches 2^31 / MAX_CHUNKS.
                int targetChunkIndex = chunkCount <= MAX_CHUNKS ? chunkIndex : (int) ((long) chunkIndex * MAX_CHUNKS / chunkCount);
                if (offsetStarts.size() <= targetChunkIndex) {
                    // First chunk mapped to this slot: it defines both the start and the end.
                    offsetStarts.add(chunk.start());
                    offsetEnds.add(chunk.end());
                } else {
                    // Slot already exists: keep its start and extend its end to this chunk's end.
                    offsetEnds.set(targetChunkIndex, chunk.end());
                }
                allRequests.add(new Request(inputIndex, targetChunkIndex, chunk, inputString));
            }
        }

        // Number the requests in encounter order and group them by (number / batch size).
        AtomicInteger counter = new AtomicInteger();
        this.batchRequests = allRequests.stream()
            .collect(Collectors.groupingBy(it -> counter.getAndIncrement() / maxNumberOfInputsPerBatch))
            .values()
            .stream()
            .map(BatchRequest::new)
            .toList();
    }

    /**
     * Pairs every batch with a listener that stores the batch's result and, once all batches
     * have completed, delivers the assembled response to {@code finalListener}.
     *
     * <p>Nothing in this class invokes {@code finalListener} other than a completing batch
     * listener, so when there are no batches the returned list is empty and the final listener
     * is not called from here.
     *
     * @param finalListener receives one {@link ChunkedInference} per input, in input order
     * @return the batches together with the listeners to complete for them
     */
    public List<BatchRequestAndListener> batchRequestsWithListeners(ActionListener<List<ChunkedInference>> finalListener) {
        this.finalListener = finalListener;
        return this.batchRequests.stream().map(req -> new BatchRequestAndListener(req, new DebatchingListener(req))).toList();
    }

    /**
     * Builds one response entry per input (an error entry if a failure was recorded for the
     * input, the merged embeddings otherwise) and passes the list to the final listener.
     */
    private void sendFinalResponse() {
        List<ChunkedInference> response = new ArrayList<>(this.resultEmbeddings.size());
        for (int i = 0; i < this.resultEmbeddings.size(); i++) {
            Exception error = this.resultsErrors.get(i);
            if (error != null) {
                response.add(new ChunkedInferenceError(error));
                // NOTE(review): presumably cleared so the exception is no longer referenced from here.
                this.resultsErrors.set(i, null);
            } else {
                response.add(mergeResultsWithInputs(i));
            }
        }
        this.finalListener.onResponse(response);
    }

    /**
     * Combines the stored offsets and embeddings of one input into a chunked result.
     *
     * @param inputIndex position of the input in the list given to the constructor
     * @return the embedding chunks of that input, one per stored slot
     */
    private ChunkedInference mergeResultsWithInputs(int inputIndex) {
        List<Integer> startOffsets = this.resultOffsetStarts.get(inputIndex);
        List<Integer> endOffsets = this.resultOffsetEnds.get(inputIndex);
        AtomicReferenceArray<E> embeddings = this.resultEmbeddings.get(inputIndex);

        List<EmbeddingResults.Chunk> chunks = new ArrayList<>();
        for (int i = 0; i < embeddings.length(); i++) {
            ChunkedInference.TextOffset offset = new ChunkedInference.TextOffset(startOffsets.get(i), endOffsets.get(i));
            chunks.add(new EmbeddingResults.Chunk(embeddings.get(i), offset));
        }
        return new ChunkedInferenceEmbedding(chunks);
    }

    /**
     * A single chunk to embed.
     *
     * @param inputIndex position of the originating input
     * @param chunkIndex slot in the input's stored results that this chunk contributes to
     * @param chunk      offsets of the chunk within {@code input}
     * @param input      the full input text
     */
    record Request(int inputIndex, int chunkIndex, Chunker.ChunkOffset chunk, String input) {
        /** Returns the substring of the input covered by this chunk. */
        public String chunkText() {
            return this.input.substring(this.chunk.start(), this.chunk.end());
        }
    }

    /**
     * A batch together with the listener that must be completed with the batch's result.
     *
     * @param batch    the chunk requests to send
     * @param listener the listener to notify with the inference result or failure
     */
    public record BatchRequestAndListener(BatchRequest batch, ActionListener<InferenceServiceResults> listener) {
    }

    /**
     * Listener for one batch: stores the returned embeddings (or the failure) in the enclosing
     * chunker and triggers the final response when it is the last batch to complete.
     *
     * <p>The reference to the batch is cleared on completion, so an instance cannot be
     * completed a second time.
     */
    private class DebatchingListener implements ActionListener<InferenceServiceResults> {
        private BatchRequest request;

        DebatchingListener(BatchRequest request) {
            this.request = request;
        }

        /**
         * Stores each returned embedding in the slot of its request, merging with an embedding
         * already present in that slot. A result of an unexpected type or with a mismatching
         * number of embeddings is handled as a failure of the whole batch.
         */
        @Override
        public void onResponse(InferenceServiceResults inferenceServiceResults) {
            if (!(inferenceServiceResults instanceof EmbeddingResults<?>)) {
                onFailure(unexpectedResultTypeException(inferenceServiceResults.getWriteableName()));
                return;
            }
            // The element type cannot be verified at runtime; it is taken to match this chunker's E.
            @SuppressWarnings("unchecked")
            EmbeddingResults<E> embeddingResults = (EmbeddingResults<E>) inferenceServiceResults;
            List<E> embeddings = embeddingResults.embeddings();
            if (embeddings.size() != this.request.requests().size()) {
                onFailure(numResultsDoesntMatchException(embeddings.size(), this.request.requests().size()));
                return;
            }

            for (int i = 0; i < embeddings.size(); i++) {
                E newEmbedding = embeddings.get(i);
                Request chunkRequest = this.request.requests().get(i);
                EmbeddingRequestChunker.this.resultEmbeddings.get(chunkRequest.inputIndex())
                    .updateAndGet(
                        chunkRequest.chunkIndex(),
                        oldEmbedding -> oldEmbedding == null ? newEmbedding : oldEmbedding.merge(newEmbedding)
                    );
            }
            completeBatch();
        }

        private ElasticsearchStatusException numResultsDoesntMatchException(int numResults, int numRequests) {
            return new ElasticsearchStatusException(
                "Error the number of embedding responses [{}] does not equal the number of requests [{}]",
                RestStatus.INTERNAL_SERVER_ERROR,
                numResults,
                numRequests
            );
        }

        private ElasticsearchStatusException unexpectedResultTypeException(String resultType) {
            return new ElasticsearchStatusException(
                "Unexpected inference result type [{}], expected [EmbeddingResults]",
                RestStatus.INTERNAL_SERVER_ERROR,
                resultType
            );
        }

        /** Records the failure for every input that has a chunk in this batch. */
        @Override
        public void onFailure(Exception e) {
            for (Request failedRequest : this.request.requests()) {
                EmbeddingRequestChunker.this.resultsErrors.set(failedRequest.inputIndex(), e);
            }
            completeBatch();
        }

        /** Drops the batch reference and sends the final response if this was the last open batch. */
        private void completeBatch() {
            this.request = null;
            if (EmbeddingRequestChunker.this.resultCount.incrementAndGet() == EmbeddingRequestChunker.this.batchRequests.size()) {
                EmbeddingRequestChunker.this.sendFinalResponse();
            }
        }
    }

    /**
     * A group of chunk requests that is sent to the inference service together.
     *
     * @param requests the chunk requests in this batch
     */
    public record BatchRequest(List<Request> requests) {
        /**
         * Returns a supplier that extracts the chunk texts each time it is invoked, in request order.
         */
        public Supplier<List<String>> inputs() {
            return () -> this.requests.stream().map(Request::chunkText).collect(Collectors.toList());
        }
    }
}

