/*
 * Decompiled with CFR 0.152.
 */
package org.elasticsearch.xpack.inference.services.sagemaker.schema.elastic;

import java.util.EnumSet;
import org.elasticsearch.common.Strings;
import org.elasticsearch.inference.TaskType;
import org.elasticsearch.xcontent.ConstructingObjectParser;
import org.elasticsearch.xcontent.ToXContent;
import org.elasticsearch.xcontent.XContentParser;
import org.elasticsearch.xcontent.XContentParserConfiguration;
import org.elasticsearch.xcontent.json.JsonXContent;
import org.elasticsearch.xpack.core.inference.action.InferenceAction;
import org.elasticsearch.xpack.core.inference.results.RankedDocsResults;
import org.elasticsearch.xpack.inference.services.sagemaker.SageMakerInferenceRequest;
import org.elasticsearch.xpack.inference.services.sagemaker.model.SageMakerModel;
import org.elasticsearch.xpack.inference.services.sagemaker.schema.SageMakerStoredTaskSchema;
import org.elasticsearch.xpack.inference.services.sagemaker.schema.elastic.ElasticPayload;
import org.elasticsearch.xpack.inference.services.sagemaker.schema.elastic.SageMakerElasticTaskSettings;
import software.amazon.awssdk.core.SdkBytes;
import software.amazon.awssdk.services.sagemakerruntime.model.InvokeEndpointResponse;

/**
 * {@link ElasticPayload} implementation for {@link TaskType#RERANK} requests sent to a SageMaker endpoint.
 * <p>
 * The request body is a JSON object built from the inference request's input, query, optional
 * return-documents / top-n values and any non-empty {@link SageMakerElasticTaskSettings}. The response
 * body is parsed back into {@link RankedDocsResults}.
 */
public class ElasticRerankPayload implements ElasticPayload {
    private static final EnumSet<TaskType> SUPPORTED_TASKS = EnumSet.of(TaskType.RERANK);

    // NOTE(review): the meaning of the boolean flag is defined by RankedDocsResults.createParser
    // (presumably whether unknown response fields are tolerated) — confirm before changing it.
    private static final ConstructingObjectParser<RankedDocsResults, Void> PARSER = RankedDocsResults.createParser(false);

    /**
     * Returns the task types this payload can handle.
     *
     * @return a set containing only {@link TaskType#RERANK}
     */
    @Override
    public EnumSet<TaskType> supportedTasks() {
        return SUPPORTED_TASKS;
    }

    /**
     * Serializes a rerank request into the UTF-8 JSON body sent to the SageMaker endpoint.
     * <p>
     * Fields are written under the preferred names of {@code InferenceAction.Request.INPUT}, {@code QUERY},
     * {@code RETURN_DOCUMENTS}, {@code TOP_N} and {@code TASK_SETTINGS}. The return-documents, top-n and
     * task-settings fields are written only when present or non-empty.
     *
     * @param model   the model; its API task settings must be {@link SageMakerElasticTaskSettings}
     * @param request the inference request; its query is expected to be non-null (asserted, validated upstream)
     * @return the JSON request body as {@link SdkBytes}
     * @throws Exception the exception produced by {@code createUnsupportedSchemaException(model)} when the
     *                   model's task settings are not {@link SageMakerElasticTaskSettings}
     */
    @Override
    public SdkBytes requestBytes(SageMakerModel model, SageMakerInferenceRequest request) throws Exception {
        SageMakerStoredTaskSchema taskSettings = model.apiTaskSettings();
        if (!(taskSettings instanceof SageMakerElasticTaskSettings)) {
            throw createUnsupportedSchemaException(model);
        }
        SageMakerElasticTaskSettings elasticTaskSettings = (SageMakerElasticTaskSettings) taskSettings;

        return SdkBytes.fromUtf8String(Strings.toString((builder, params) -> {
            // Exactly one input is written as a plain string; any other count is written as an array.
            // Testing for "== 1" (not "> 1") keeps an empty input list from reaching get(0) and
            // failing with IndexOutOfBoundsException.
            if (request.input().size() == 1) {
                builder.field(InferenceAction.Request.INPUT.getPreferredName(), request.input().get(0));
            } else {
                builder.field(InferenceAction.Request.INPUT.getPreferredName(), request.input());
            }

            assert (request.query() != null) : "InferenceAction.Request will validate that rerank requests have a query field";
            builder.field(InferenceAction.Request.QUERY.getPreferredName(), request.query());

            if (request.returnDocuments() != null) {
                builder.field(InferenceAction.Request.RETURN_DOCUMENTS.getPreferredName(), request.returnDocuments());
            }
            if (request.topN() != null) {
                builder.field(InferenceAction.Request.TOP_N.getPreferredName(), request.topN());
            }

            if (!elasticTaskSettings.isEmpty()) {
                builder.field(InferenceAction.Request.TASK_SETTINGS.getPreferredName());
                // A fragment writes bare fields, so wrap it in an object here; the value under the
                // task-settings key is then always a JSON object.
                boolean wrap = elasticTaskSettings.isFragment();
                if (wrap) {
                    builder.startObject();
                }
                // Explicit cast pins the ToXContent overload of XContentBuilder#value.
                builder.value((ToXContent) elasticTaskSettings);
                if (wrap) {
                    builder.endObject();
                }
            }
            return builder;
        }));
    }

    /**
     * Parses the endpoint's JSON response body into {@link RankedDocsResults}.
     * <p>
     * NOTE(review): this presumably implements {@code responseBody} from the payload interface with a
     * covariant return type; consider adding {@code @Override} once confirmed.
     *
     * @param model    unused by this implementation
     * @param response the SageMaker invoke-endpoint response whose body is read as JSON
     * @return the parsed ranked documents
     * @throws Exception if the body cannot be read or does not match the expected schema
     */
    public RankedDocsResults responseBody(SageMakerModel model, InvokeEndpointResponse response) throws Exception {
        try (XContentParser parser = JsonXContent.jsonXContent.createParser(
                XContentParserConfiguration.EMPTY, response.body().asInputStream())) {
            return PARSER.apply(parser, null);
        }
    }
}

