/*
 * Decompiled with CFR 0.152.
 */
package org.sonar.python.checks;

import java.util.Arrays;
import java.util.Collection;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Optional;
import java.util.Set;
import java.util.function.Function;
import java.util.stream.Collector;
import java.util.stream.Collectors;
import java.util.stream.Stream;
import javax.annotation.Nullable;
import org.sonar.check.Rule;
import org.sonar.plugins.python.api.PythonSubscriptionCheck;
import org.sonar.plugins.python.api.SubscriptionCheck;
import org.sonar.plugins.python.api.SubscriptionContext;
import org.sonar.plugins.python.api.symbols.ClassSymbol;
import org.sonar.plugins.python.api.symbols.FunctionSymbol;
import org.sonar.plugins.python.api.symbols.Symbol;
import org.sonar.plugins.python.api.tree.CallExpression;
import org.sonar.plugins.python.api.tree.DictionaryLiteral;
import org.sonar.plugins.python.api.tree.Expression;
import org.sonar.plugins.python.api.tree.ExpressionList;
import org.sonar.plugins.python.api.tree.Name;
import org.sonar.plugins.python.api.tree.QualifiedExpression;
import org.sonar.plugins.python.api.tree.RegularArgument;
import org.sonar.plugins.python.api.tree.StringLiteral;
import org.sonar.plugins.python.api.tree.Tree;
import org.sonar.python.checks.utils.Expressions;
import org.sonar.python.tree.DictionaryLiteralImpl;
import org.sonar.python.tree.KeyValuePairImpl;
import org.sonar.python.tree.ListLiteralImpl;
import org.sonar.python.tree.TreeUtils;
import org.sonar.python.tree.TupleImpl;

@Rule(key="S6972")
/**
 * Reports parameters set on a scikit-learn pipeline step that are not parameters of the {@code __init__}
 * declared by that step's estimator class.
 *
 * <p>Two kinds of calls are inspected:
 * <ul>
 *   <li>calls to one of {@link #SKLEARN_SEARCH_FQNS}: the string keys of the {@code param_grid} argument
 *       (2nd positional argument or {@code param_grid=} keyword) are checked against the pipeline passed as
 *       {@code estimator} (1st positional argument or {@code estimator=} keyword);</li>
 *   <li>calls to {@code sklearn.pipeline.Pipeline.set_params}: the keyword argument names are checked against
 *       the pipeline named by the call's qualifier ({@code pipeline.set_params(...)}).</li>
 * </ul>
 * The pipeline (and, for searches, the {@code param_grid} dictionary) must be a name that
 * {@code Expressions.singleAssignedNonNameValue} resolves to a {@code Pipeline(...)} call (resp. a dictionary literal).
 */
public class SklearnPipelineParameterAreCorrectCheck extends PythonSubscriptionCheck {
    public static final String MESSAGE = "Provide valid parameters to the estimator.";

    /** Fully qualified names of the hyper-parameter search classes whose {@code param_grid} is inspected. */
    private static final Set<String> SKLEARN_SEARCH_FQNS = Set.of(
        "sklearn.model_selection._search.GridSearchCV",
        "sklearn.model_selection._search_successive_halving.HalvingGridSearchCV",
        "sklearn.model_selection._search.RandomizedSearchCV",
        "sklearn.model_selection._search_successive_halving.HalvingRandomSearchCV");
    private static final String PIPELINE_FQN = "sklearn.pipeline.Pipeline";
    private static final String PIPELINE_SET_PARAMS_FQN = "sklearn.pipeline.Pipeline.set_params";
    /** Separator between step names and the parameter name in pipeline parameter keys. */
    private static final String STEP_SEPARATOR = "__";
    /**
     * Nesting depth beyond which nested pipelines are no longer expanded. Bounds the recursion in
     * {@link #expandNestedPipelines(ExpressionAndPrefix)} so cyclic or very deep definitions cannot expand without limit.
     */
    private static final int MAX_PIPELINE_NESTING_DEPTH = 10;

    @Override
    public void initialize(SubscriptionCheck.Context context) {
        context.registerSyntaxNodeConsumer(Tree.Kind.CALL_EXPR, SklearnPipelineParameterAreCorrectCheck::checkCallExpression);
    }

    /**
     * Entry point for every call expression: parses the parameters set by a search call or by
     * {@code Pipeline.set_params}, resolves the targeted pipeline definition, and reports invalid parameters.
     */
    private static void checkCallExpression(SubscriptionContext subscriptionContext) {
        CallExpression callExpression = (CallExpression) subscriptionContext.syntaxNode();
        Optional<String> fqn = calleeFqn(callExpression);
        Optional<PipelineNameAndParsedParameters> parsed = fqn
            .filter(SKLEARN_SEARCH_FQNS::contains)
            .flatMap(searchFqn -> getPipelineNameAndParsedParametersFromSearchFunctions(
                getStepAndParametersFromDict(callExpression), callExpression))
            .or(() -> fqn
                .filter(PIPELINE_SET_PARAMS_FQN::equals)
                .flatMap(setParamsFqn -> getPipelineNameAndParsedParametersFromPipelineSetParamsFunction(
                    getStepAndParametersFromArguments(callExpression), callExpression)));

        parsed.ifPresent(pipelineNameAndParsedParameters ->
            Expressions.singleAssignedNonNameValue(pipelineNameAndParsedParameters.pipelineName())
                .flatMap(TreeUtils.toOptionalInstanceOfMapper(CallExpression.class))
                .ifPresent(pipelineCallExpr -> findProblems(
                    pipelineNameAndParsedParameters.parsedParameters(), parsePipeline(pipelineCallExpr), subscriptionContext)));
    }

    /** Fully qualified name of the callee of {@code callExpression}, if its symbol and name are known. */
    private static Optional<String> calleeFqn(CallExpression callExpression) {
        return Optional.ofNullable(callExpression.calleeSymbol()).map(Symbol::fullyQualifiedName);
    }

    /**
     * For {@code <name>.set_params(...)}: pairs the qualifier name with the parsed parameters.
     * Empty when the callee is not a qualified expression whose qualifier is a plain name.
     */
    private static Optional<PipelineNameAndParsedParameters> getPipelineNameAndParsedParametersFromPipelineSetParamsFunction(
        Map<String, Set<ParameterNameAndLocation>> parsedParameters, CallExpression callExpression) {
        return Optional.ofNullable(callExpression.callee())
            .flatMap(TreeUtils.toOptionalInstanceOfMapper(QualifiedExpression.class))
            .map(QualifiedExpression::qualifier)
            .flatMap(TreeUtils.toOptionalInstanceOfMapper(Name.class))
            .map(pipelineName -> new PipelineNameAndParsedParameters(pipelineName, parsedParameters));
    }

    /**
     * For search calls: pairs the name passed as {@code estimator} (1st positional or keyword) with the parsed
     * parameters. Empty when that argument is missing or is not a plain name.
     */
    private static Optional<PipelineNameAndParsedParameters> getPipelineNameAndParsedParametersFromSearchFunctions(
        Map<String, Set<ParameterNameAndLocation>> parsedParameters, CallExpression callExpression) {
        return Optional.ofNullable(TreeUtils.nthArgumentOrKeyword(0, "estimator", callExpression.arguments()))
            .map(RegularArgument::expression)
            .flatMap(TreeUtils.toOptionalInstanceOfMapper(Name.class))
            .map(pipelineName -> new PipelineNameAndParsedParameters(pipelineName, parsedParameters));
    }

    /**
     * Reports, for every step of {@code setParameters} that has a resolved estimator class in
     * {@code pipelineDefinition}, each parameter name that the class's {@code __init__} does not declare.
     *
     * @param setParameters       parameters to check, grouped by (possibly prefixed) step name
     * @param pipelineDefinition  estimator class of each step, as built by {@link #parsePipeline(CallExpression)}
     * @param subscriptionContext context used to raise issues
     */
    private static void findProblems(Map<String, Set<ParameterNameAndLocation>> setParameters,
        Map<String, ClassSymbol> pipelineDefinition, SubscriptionContext subscriptionContext) {
        for (Map.Entry<String, Set<ParameterNameAndLocation>> entry : setParameters.entrySet()) {
            Set<ParameterNameAndLocation> stringAndTree = entry.getValue();
            // Steps without a resolved class, or whose __init__ cannot be resolved, cannot be validated.
            Optional.ofNullable(pipelineDefinition.get(entry.getKey()))
                .flatMap(SklearnPipelineParameterAreCorrectCheck::getInitParameters)
                .ifPresent(possibleParameters -> reportInvalidParameters(subscriptionContext, stringAndTree, possibleParameters));
        }
    }

    /** Raises one issue per distinct parameter name of {@code stringAndTree} that is not in {@code possibleParameters}. */
    private static void reportInvalidParameters(SubscriptionContext subscriptionContext, Set<ParameterNameAndLocation> stringAndTree,
        List<FunctionSymbol.Parameter> possibleParameters) {
        Set<String> parameters = stringAndTree.stream().map(ParameterNameAndLocation::string).collect(Collectors.toSet());
        for (String parameter : parameters) {
            if (isNotAValidParameter(parameter, possibleParameters)) {
                createIssue(subscriptionContext, parameter, stringAndTree);
            }
        }
    }

    /** Raises {@link #MESSAGE} on the first location in {@code parameterNameAndLocation} whose name is {@code parameter}. */
    private static void createIssue(SubscriptionContext subscriptionContext, String parameter, Set<ParameterNameAndLocation> parameterNameAndLocation) {
        parameterNameAndLocation.stream()
            .filter(candidate -> candidate.string().equals(parameter))
            .findFirst()
            .ifPresent(location -> subscriptionContext.addIssue(location.tree(), MESSAGE));
    }

    /** True when no parameter of {@code possibleParameters} is named {@code parameter}. */
    private static boolean isNotAValidParameter(String parameter, List<FunctionSymbol.Parameter> possibleParameters) {
        return possibleParameters.stream().noneMatch(symbol -> Objects.equals(symbol.name(), parameter));
    }

    /**
     * Parameters of the {@code __init__} declared directly on {@code classSymbol} (inherited members are not
     * looked up).
     *
     * @return the declared parameters; an empty list when the class declares no {@code __init__};
     *         {@code Optional.empty()} when an {@code __init__} member exists but is not a function symbol
     *         (e.g. ambiguous), since its signature is then unknown — previously this case threw ClassCastException
     */
    private static Optional<List<FunctionSymbol.Parameter>> getInitParameters(ClassSymbol classSymbol) {
        var init = classSymbol.declaredMembers().stream()
            .filter(memberSymbol -> "__init__".equals(memberSymbol.name()))
            .findFirst();
        if (init.isEmpty()) {
            return Optional.of(List.of());
        }
        return init
            .filter(symbol -> symbol.is(Symbol.Kind.FUNCTION))
            .map(symbol -> ((FunctionSymbol) symbol).parameters());
    }

    /** Elements of {@code argument}'s value when it is a list literal; an empty stream otherwise (including null). */
    private static Stream<Expression> getExpressionsFromArgument(@Nullable RegularArgument argument) {
        return Optional.ofNullable(argument)
            .map(RegularArgument::expression)
            .flatMap(TreeUtils.toOptionalInstanceOfMapper(ListLiteralImpl.class))
            .map(ListLiteralImpl::elements)
            .map(ExpressionList::expressions)
            .stream()
            .flatMap(Collection::stream);
    }

    /** Two-element tuple literals of the {@code steps} list literal (1st positional or keyword) of a pipeline call. */
    private static Stream<List<Expression>> pipelineSteps(CallExpression pipelineCall) {
        RegularArgument stepsArgument = TreeUtils.nthArgumentOrKeyword(0, "steps", pipelineCall.arguments());
        return getExpressionsFromArgument(stepsArgument)
            .map(TreeUtils.toInstanceOfMapper(TupleImpl.class))
            .filter(Objects::nonNull)
            .map(TupleImpl::elements)
            .filter(SklearnPipelineParameterAreCorrectCheck::isTwoElementTuple);
    }

    /**
     * Maps each step name of the pipeline built by {@code callExpression} to its estimator class. Steps of nested
     * pipelines (referenced through a name) are flattened with an {@code outer__inner} prefix. When the same key
     * occurs more than once, the last occurrence wins.
     */
    private static Map<String, ClassSymbol> parsePipeline(CallExpression callExpression) {
        return pipelineSteps(callExpression)
            .map(tuple -> new ExpressionAndPrefix(tuple, "", 0))
            .flatMap(SklearnPipelineParameterAreCorrectCheck::expandNestedPipelines)
            .flatMap(expressionAndPrefix -> getResult(expressionAndPrefix.tuple()).stream()
                .map(stepAndClassifier -> new StepAndClassifier(
                    expressionAndPrefix.prefix() + stepAndClassifier.stepName(), stepAndClassifier.classifierName())))
            .collect(Collectors.toMap(StepAndClassifier::stepName, StepAndClassifier::classifierName, (first, second) -> second, HashMap::new));
    }

    private static boolean isTwoElementTuple(List<Expression> elements) {
        return elements.size() == 2;
    }

    /**
     * Replaces a {@code ("name", <Name>)} step whose name resolves to a nested {@code Pipeline(...)} call by that
     * pipeline's steps (recursively), prefixed with {@code name__}. Any other step is returned unchanged.
     */
    private static Stream<ExpressionAndPrefix> expandNestedPipelines(ExpressionAndPrefix expressionAndPrefix) {
        Expression step = expressionAndPrefix.tuple().get(0);
        Expression classifier = expressionAndPrefix.tuple().get(1);
        boolean mayBeNestedPipeline = step.is(Tree.Kind.STRING_LITERAL) && classifier.is(Tree.Kind.NAME);
        if (!mayBeNestedPipeline || expressionAndPrefix.depth() > MAX_PIPELINE_NESTING_DEPTH) {
            return Stream.of(expressionAndPrefix);
        }
        String nestedPrefix = expressionAndPrefix.prefix() + ((StringLiteral) step).trimmedQuotesValue() + STEP_SEPARATOR;
        int nestedDepth = expressionAndPrefix.depth() + 1;
        // A name that does not resolve to a nested Pipeline yields no steps, so the step is dropped; getResult would
        // reject it anyway because it only accepts call expressions as classifiers.
        return classifierIsANestedPipeline((Name) classifier).stream()
            .flatMap(SklearnPipelineParameterAreCorrectCheck::pipelineSteps)
            .map(nestedTuple -> new ExpressionAndPrefix(nestedTuple, nestedPrefix, nestedDepth))
            .flatMap(SklearnPipelineParameterAreCorrectCheck::expandNestedPipelines);
    }

    /** The {@code Pipeline(...)} call {@code classifier} is singly assigned to, if any. */
    private static Optional<CallExpression> classifierIsANestedPipeline(Name classifier) {
        return Expressions.singleAssignedNonNameValue(classifier)
            .flatMap(TreeUtils.toOptionalInstanceOfMapper(CallExpression.class))
            .filter(callExpression -> calleeFqn(callExpression).filter(PIPELINE_FQN::equals).isPresent());
    }

    /**
     * Reads a {@code ("step", SomeClass(...))} tuple: present only when the first element is a string literal and
     * the second is a call to a class symbol other than {@code Pipeline}.
     */
    private static Optional<StepAndClassifier> getResult(List<Expression> tuple) {
        Optional<String> stepName = Optional.ofNullable(tuple.get(0))
            .flatMap(TreeUtils.toOptionalInstanceOfMapper(StringLiteral.class))
            .map(StringLiteral::trimmedQuotesValue);
        Optional<ClassSymbol> classifier = Optional.ofNullable(tuple.get(1))
            .flatMap(TreeUtils.toOptionalInstanceOfMapper(CallExpression.class))
            .map(CallExpression::calleeSymbol)
            .filter(symbol -> symbol.is(Symbol.Kind.CLASS) && !PIPELINE_FQN.equals(symbol.fullyQualifiedName()))
            .map(ClassSymbol.class::cast);
        return stepName.flatMap(name -> classifier.map(classSymbol -> new StepAndClassifier(name, classSymbol)));
    }

    /** Groups the {@code step__param} keyword arguments of {@code callExpression} by step. */
    private static Map<String, Set<ParameterNameAndLocation>> getStepAndParametersFromArguments(CallExpression callExpression) {
        return callExpression.arguments().stream()
            .filter(argument -> argument.is(Tree.Kind.REGULAR_ARGUMENT))
            .map(RegularArgument.class::cast)
            .map(RegularArgument::keywordArgument)
            .filter(Objects::nonNull)
            .map(SklearnPipelineParameterAreCorrectCheck::getStepAndParameterFromName)
            .flatMap(Optional::stream)
            .collect(mergeStringAndTreeToMapCollector());
    }

    /**
     * Groups the string keys of the {@code param_grid} dictionary (2nd positional or keyword argument) by step.
     * NOTE(review): only a name bound to a dictionary literal is inspected; an inline dict literal or a list of
     * dicts passed directly yields no parameters.
     */
    private static Map<String, Set<ParameterNameAndLocation>> getStepAndParametersFromDict(CallExpression callExpression) {
        return Optional.ofNullable(TreeUtils.nthArgumentOrKeyword(1, "param_grid", callExpression.arguments()))
            .map(RegularArgument::expression)
            .stream()
            .flatMap(SklearnPipelineParameterAreCorrectCheck::extractKeyValuePairFromDictLiteral)
            .map(KeyValuePairImpl::key)
            .map(TreeUtils.toInstanceOfMapper(StringLiteral.class))
            .filter(Objects::nonNull)
            .map(stringLiteral -> getStepAndParameterFromString(stringLiteral.trimmedQuotesValue(), stringLiteral))
            .flatMap(Optional::stream)
            .collect(mergeStringAndTreeToMapCollector());
    }

    /** Key/value pairs of the dictionary literal {@code expression} names (through a single assignment), if any. */
    private static Stream<KeyValuePairImpl> extractKeyValuePairFromDictLiteral(Expression expression) {
        return Optional.of(expression)
            .flatMap(TreeUtils.toOptionalInstanceOfMapper(Name.class))
            .flatMap(Expressions::singleAssignedNonNameValue)
            .flatMap(TreeUtils.toOptionalInstanceOfMapper(DictionaryLiteralImpl.class))
            .map(DictionaryLiteral::elements)
            .stream()
            .flatMap(Collection::stream)
            .map(TreeUtils.toInstanceOfMapper(KeyValuePairImpl.class))
            .filter(Objects::nonNull);
    }

    /** Collects step/parameter pairs into a map from step name to the set of its parameters (with locations). */
    private static Collector<StepAndParameter, ?, Map<String, Set<ParameterNameAndLocation>>> mergeStringAndTreeToMapCollector() {
        return Collectors.toMap(
            StepAndParameter::step,
            stepAndParameter -> Set.of(new ParameterNameAndLocation(stepAndParameter.parameter(), stepAndParameter.location())),
            (set1, set2) -> {
                Set<ParameterNameAndLocation> merged = new HashSet<>(set1);
                merged.addAll(set2);
                return merged;
            });
    }

    /**
     * Splits a keyword name such as {@code outer__inner__C} into step {@code outer__inner} (everything before the
     * last separator) and parameter {@code C}; the step therefore matches the prefixed keys built by
     * {@link #parsePipeline(CallExpression)}.
     */
    private static Optional<StepAndParameter> getStepAndParameterFromName(Name name) {
        return splitStepString(name.name()).map(split -> {
            int last = split.length - 1;
            String step = String.join(STEP_SEPARATOR, Arrays.copyOfRange(split, 0, last));
            return new StepAndParameter(step, split[last], name);
        });
    }

    /**
     * Splits a {@code param_grid} key into step (first segment) and parameter (second segment).
     * NOTE(review): unlike {@link #getStepAndParameterFromName(Name)}, only the first two segments are used, so a key
     * such as {@code outer__inner__C} is looked up under step {@code outer}, never under the flattened nested step
     * {@code outer__inner} built by {@link #parsePipeline(CallExpression)} — confirm this asymmetry is intended.
     */
    private static Optional<StepAndParameter> getStepAndParameterFromString(String string, Tree location) {
        return splitStepString(string).map(split -> new StepAndParameter(split[0], split[1], location));
    }

    /**
     * Splits {@code string} on {@code __}; empty when there are fewer than two segments or the string ends with
     * {@code __} ({@link String#split} would otherwise silently drop the trailing empty segment).
     */
    private static Optional<String[]> splitStepString(String string) {
        String[] split = string.split(STEP_SEPARATOR);
        if (split.length < 2 || string.endsWith(STEP_SEPARATOR)) {
            return Optional.empty();
        }
        return Optional.of(split);
    }

    /** A pipeline step tuple, the prefix of the enclosing nested steps, and its nesting depth. */
    private record ExpressionAndPrefix(List<Expression> tuple, String prefix, int depth) {
    }

    /** A parameter set on a step, with the tree to report on. */
    private record StepAndParameter(String step, String parameter, Tree location) {
    }

    /** A parameter name and the tree to report on. */
    private record ParameterNameAndLocation(String string, Tree tree) {
    }

    /** A step name and the estimator class it is built from. */
    private record StepAndClassifier(String stepName, ClassSymbol classifierName) {
    }

    /** The name referring to the pipeline, and the parameters set on its steps. */
    private record PipelineNameAndParsedParameters(Name pipelineName, Map<String, Set<ParameterNameAndLocation>> parsedParameters) {
    }
}

