/*
 * Decompiled with CFR 0.152.
 */
package org.sonar.python.checks;

import java.util.Map;
import java.util.Optional;
import java.util.Set;
import java.util.function.Predicate;
import javax.annotation.Nullable;
import org.sonar.check.Rule;
import org.sonar.plugins.python.api.PythonSubscriptionCheck;
import org.sonar.plugins.python.api.SubscriptionCheck;
import org.sonar.plugins.python.api.SubscriptionContext;
import org.sonar.plugins.python.api.symbols.FunctionSymbol;
import org.sonar.plugins.python.api.symbols.Symbol;
import org.sonar.plugins.python.api.tree.CallExpression;
import org.sonar.plugins.python.api.tree.Expression;
import org.sonar.plugins.python.api.tree.Name;
import org.sonar.plugins.python.api.tree.RegularArgument;
import org.sonar.plugins.python.api.tree.Tree;
import org.sonar.plugins.python.api.types.v2.PythonType;
import org.sonar.python.cfg.fixpoint.ReachingDefinitionsAnalysis;
import org.sonar.python.checks.cdk.CdkPredicate;
import org.sonar.python.checks.hotspots.CommonValidationUtils;
import org.sonar.python.semantic.ClassSymbolImpl;
import org.sonar.python.semantic.SymbolUtils;
import org.sonar.python.tree.TreeUtils;
import org.sonar.python.types.v2.TypeCheckMap;

@Rule(key="S6709")
public class RandomSeedCheck
extends PythonSubscriptionCheck {
    private static final String NUMPY_SEED_ARG_NAME = "seed";
    private static final Map<String, String> SEED_METHODS_TO_CHECK = Map.of("numpy.seed", "seed", "numpy.random.seed", "seed", "numpy.random.default_rng", "seed", "numpy.random.SeedSequence", "entropy", "numpy.random.PCG64", "seed", "numpy.random.PCG64DXSM", "seed", "numpy.random.MT19937", "seed", "numpy.random.SFC64", "seed", "numpy.random.Philox", "seed");
    private static final String SKLEARN_FQN = "sklearn";
    private static final String SKLEARN_ARG_NAME = "random_state";
    private static final String MESSAGE = "Provide a seed for this random generator.";
    private static final String SKLEARN_MESSAGE = "Provide a seed for the random_state parameter.";
    private ReachingDefinitionsAnalysis reachingDefinitionsAnalysis;
    private static final Predicate<CallExpression> SOLVER_NOT_SAG_SAGA = RandomSeedCheck.keywordAbsentOrNotIn("solver", "sag", "saga");
    private static final Predicate<CallExpression> SELECTION_NOT_RANDOM = RandomSeedCheck.keywordAbsentOrNotIn("selection", "random");
    private static final Map<String, Predicate<CallExpression>> SKLEARN_EXCEPTIONS = Map.ofEntries(Map.entry("sklearn.svm._classes.SVC", RandomSeedCheck.probabilityArgAbsent()), Map.entry("sklearn.linear_model._logistic.LogisticRegression", SOLVER_NOT_SAG_SAGA), Map.entry("sklearn.linear_model._ridge.Ridge", SOLVER_NOT_SAG_SAGA), Map.entry("sklearn.linear_model._coordinate_descent.Lasso", SELECTION_NOT_RANDOM), Map.entry("sklearn.linear_model._coordinate_descent.ElasticNet", SELECTION_NOT_RANDOM));
    private TypeCheckMap<Predicate<CallExpression>> typeCheckMap;

    private static Predicate<CallExpression> keywordAbsentOrNotIn(String keyword, String ... restrictedValues) {
        Set<String> restrictedValueSet = Set.of(restrictedValues);
        return call -> {
            RegularArgument arg = TreeUtils.argumentByKeyword(keyword, call.arguments());
            if (arg == null) {
                return true;
            }
            String expressionString = CommonValidationUtils.singleAssignedString(arg.expression());
            return restrictedValueSet.stream().noneMatch(expressionString::equals);
        };
    }

    private static Predicate<CallExpression> probabilityArgAbsent() {
        return call -> {
            RegularArgument probabilityArg = TreeUtils.argumentByKeyword("probability", call.arguments());
            return probabilityArg == null || CdkPredicate.isFalse().test(probabilityArg.expression());
        };
    }

    @Override
    public void initialize(SubscriptionCheck.Context context) {
        context.registerSyntaxNodeConsumer(Tree.Kind.FILE_INPUT, ctx -> {
            this.reachingDefinitionsAnalysis = new ReachingDefinitionsAnalysis(ctx.pythonFile());
            this.typeCheckMap = new TypeCheckMap();
            SKLEARN_EXCEPTIONS.forEach((fqn, predicate) -> this.typeCheckMap.put(ctx.typeChecker().typeCheckBuilder().isTypeWithFqn((String)fqn), (Predicate<CallExpression>)predicate));
        });
        context.registerSyntaxNodeConsumer(Tree.Kind.CALL_EXPR, this::checkEmptySeedCall);
    }

    private void checkEmptySeedCall(SubscriptionContext ctx) {
        CallExpression call = (CallExpression)ctx.syntaxNode();
        Optional<Symbol> maybeCalleeSymbol = Optional.ofNullable(call.calleeSymbol());
        maybeCalleeSymbol.map(Symbol::fullyQualifiedName).map(SEED_METHODS_TO_CHECK::get).filter(argName -> this.isArgumentAbsentOrNone(TreeUtils.nthArgumentOrKeyword(0, argName, call.arguments()))).map(arg -> MESSAGE).or(() -> maybeCalleeSymbol.filter(symbol -> symbol.fullyQualifiedName() != null && symbol.fullyQualifiedName().startsWith(SKLEARN_FQN)).filter(RandomSeedCheck::hasRandomStateParameter).filter(symbol -> this.isArgumentAbsentOrNone(TreeUtils.argumentByKeyword(SKLEARN_ARG_NAME, call.arguments()))).filter(symbol -> !this.isSKLearnException(call)).map(symbol -> SKLEARN_MESSAGE)).ifPresent(message -> ctx.addIssue(call.callee(), (String)message));
    }

    private static boolean hasRandomStateParameter(Symbol calleeSymbol) {
        return RandomSeedCheck.isClassInstantiationWithRandomStateParameter(calleeSymbol).or(() -> RandomSeedCheck.isFunctionWithRandomStateParameter(calleeSymbol)).orElse(false);
    }

    private static Optional<Boolean> isClassInstantiationWithRandomStateParameter(Symbol calleeSymbol) {
        return Optional.of(calleeSymbol).filter(s -> s.is(Symbol.Kind.CLASS)).map(ClassSymbolImpl.class::cast).map(classSymbol -> classSymbol.declaredMembers().stream().filter(member -> "__init__".equals(member.name())).toList()).filter(members -> members.size() == 1).map(members -> (Symbol)members.get(0)).map(RandomSeedCheck::hasRandomStateParameter);
    }

    private static Optional<Boolean> isFunctionWithRandomStateParameter(Symbol calleeSymbol) {
        return Optional.of(calleeSymbol).filter(s1 -> s1.is(Symbol.Kind.FUNCTION)).map(SymbolUtils::getFunctionSymbols).filter(symbols -> symbols.size() == 1).map(symbols -> (FunctionSymbol)symbols.get(0)).map(symbol -> symbol.parameters().stream().map(FunctionSymbol.Parameter::name).anyMatch(SKLEARN_ARG_NAME::equals));
    }

    private boolean isArgumentAbsentOrNone(@Nullable RegularArgument arg) {
        return arg == null || arg.expression().is(Tree.Kind.NONE) || this.isAssignedNone(arg.expression());
    }

    private boolean isAssignedNone(Expression exp) {
        return Optional.of(exp).flatMap(TreeUtils.toOptionalInstanceOfMapper(Name.class)).map(this.reachingDefinitionsAnalysis::valuesAtLocation).filter(Predicate.not(Set::isEmpty)).filter(values -> values.stream().allMatch(value -> value.is(Tree.Kind.NONE))).isPresent();
    }

    private boolean isSKLearnException(CallExpression call) {
        PythonType calleeType = call.callee().typeV2();
        return this.typeCheckMap.getOptionalForType(calleeType).map(predicate -> predicate.test(call)).orElse(false);
    }
}

