/*
 * Decompiled with CFR 0.152.
 */
package org.sonar.python.checks;

import java.util.HashMap;
import java.util.LinkedHashMap;
import java.util.Map;
import java.util.Optional;
import java.util.Set;
import java.util.stream.Stream;
import org.sonar.check.Rule;
import org.sonar.plugins.python.api.PythonCheck;
import org.sonar.plugins.python.api.PythonSubscriptionCheck;
import org.sonar.plugins.python.api.SubscriptionCheck;
import org.sonar.plugins.python.api.SubscriptionContext;
import org.sonar.plugins.python.api.quickfix.PythonQuickFix;
import org.sonar.plugins.python.api.symbols.ClassSymbol;
import org.sonar.plugins.python.api.tree.AssignmentStatement;
import org.sonar.plugins.python.api.tree.BaseTreeVisitor;
import org.sonar.plugins.python.api.tree.ClassDef;
import org.sonar.plugins.python.api.tree.Name;
import org.sonar.plugins.python.api.tree.QualifiedExpression;
import org.sonar.plugins.python.api.tree.Tree;
import org.sonar.plugins.python.api.tree.Tuple;
import org.sonar.python.quickfix.TextEditUtils;
import org.sonar.python.tree.FunctionDefImpl;
import org.sonar.python.tree.TreeUtils;

/**
 * Rule S6974: reports attributes assigned on {@code self} inside the {@code __init__} method of a
 * scikit-learn estimator when the attribute name ends with a trailing underscore. Such attributes
 * should be set in the {@code fit} method instead (see {@link #MESSAGE}).
 *
 * <p>A class is considered an estimator when it is or extends {@code sklearn.base.BaseEstimator}
 * or one of the mixins listed in {@link #MIXINS_FULLY_QUALIFIED_NAME}.
 *
 * <p>Each issue is offered up to two quick fixes:
 * <ul>
 *   <li>removing the statement, only for a single-target assignment of {@code None};</li>
 *   <li>renaming the attribute without its trailing underscores, only when that leaves a
 *       non-empty name.</li>
 * </ul>
 */
@Rule(key="S6974")
public class SkLearnEstimatorDontInitializeEstimatedValuesCheck extends PythonSubscriptionCheck {
    private static final String BASE_ESTIMATOR_FULLY_QUALIFIED_NAME = "sklearn.base.BaseEstimator";
    private static final Set<String> MIXINS_FULLY_QUALIFIED_NAME = Set.of(
        "sklearn.base.BiclusterMixin",
        "sklearn.base.ClassifierMixin",
        "sklearn.base.ClusterMixin",
        "sklearn.base.DensityMixin",
        "sklearn.base.MetaEstimatorMixin",
        "sklearn.base.OneToOneFeatureMixin",
        "sklearn.base.OutlierMixin",
        "sklearn.base.RegressorMixin",
        "sklearn.base.TransformerMixin");
    private static final String MESSAGE = "Move this estimated attribute in the `fit` method.";
    private static final String MESSAGE_SECONDARY = "The attribute is used in this estimator";
    public static final String QUICK_FIX_MESSAGE = "Remove the statement";
    public static final String QUICK_FIX_RENAME_MESSAGE = "Remove all trailing underscores from the variable name";

    /** Registers {@link #checkFunction} on every function definition. */
    @Override
    public void initialize(SubscriptionCheck.Context context) {
        context.registerSyntaxNodeConsumer(Tree.Kind.FUNCDEF, SkLearnEstimatorDontInitializeEstimatedValuesCheck::checkFunction);
    }

    /** Returns {@code true} when {@code classSymbol} is or extends one of the scikit-learn mixins. */
    private static boolean inheritsMixin(ClassSymbol classSymbol) {
        return MIXINS_FULLY_QUALIFIED_NAME.stream().anyMatch(classSymbol::isOrExtends);
    }

    /** Returns {@code true} when {@code classSymbol} is or extends {@code BaseEstimator} or a known mixin. */
    private static boolean isEstimator(ClassSymbol classSymbol) {
        return classSymbol.isOrExtends(BASE_ESTIMATOR_FULLY_QUALIFIED_NAME) || inheritsMixin(classSymbol);
    }

    /**
     * Inspects an {@code __init__} function whose closest enclosing class is an estimator.
     * Raises one issue per offending {@code self.<name>_} assignment target, with the class name
     * as a secondary location.
     */
    private static void checkFunction(SubscriptionContext subscriptionContext) {
        FunctionDefImpl functionDef = (FunctionDefImpl) subscriptionContext.syntaxNode();
        if (!"__init__".equals(functionDef.name().name())) {
            return;
        }
        ClassDef classDef = (ClassDef) TreeUtils.firstAncestorOfKind(functionDef, Tree.Kind.CLASSDEF);
        if (classDef == null) {
            return;
        }
        ClassSymbol classSymbol = TreeUtils.getClassSymbolFromDef(classDef);
        if (classSymbol == null || !isEstimator(classSymbol)) {
            return;
        }
        VariableDeclarationEndingWithUnderscoreVisitor visitor = new VariableDeclarationEndingWithUnderscoreVisitor();
        functionDef.body().accept(visitor);
        Name secondaryLocation = classDef.name();
        visitor.qualifiedExpressions.forEach((qualifiedExpression, assignmentStatement) -> {
            PythonCheck.PreciseIssue issue = subscriptionContext.addIssue(qualifiedExpression.name(), MESSAGE)
                .secondary(secondaryLocation, MESSAGE_SECONDARY);
            createQuickFix(assignmentStatement).ifPresent(issue::addQuickFix);
            createQuickFixRename(qualifiedExpression).ifPresent(issue::addQuickFix);
        });
    }

    /**
     * Builds a quick fix that renames every usage of the attribute to its name without trailing
     * underscores.
     *
     * @return the quick fix, or empty when stripping the underscores would leave no name at all
     *     (e.g. {@code self._}), since an empty identifier would produce invalid Python
     */
    private static Optional<PythonQuickFix> createQuickFixRename(QualifiedExpression qualifiedExpression) {
        String newName = stripTrailingUnderscores(qualifiedExpression.name().name());
        if (newName.isEmpty()) {
            return Optional.empty();
        }
        PythonQuickFix.Builder builder = PythonQuickFix.newQuickFix(QUICK_FIX_RENAME_MESSAGE);
        builder.addTextEdit(TextEditUtils.renameAllUsages(qualifiedExpression.name(), newName));
        return Optional.of(builder.build());
    }

    /** Returns {@code name} with every trailing {@code '_'} removed. */
    private static String stripTrailingUnderscores(String name) {
        int end = name.length();
        while (end > 0 && name.charAt(end - 1) == '_') {
            end--;
        }
        return name.substring(0, end);
    }

    /**
     * Builds a quick fix that removes the whole assignment statement. It is only offered for a
     * single assignment with a single target whose assigned value is {@code None}.
     * Chained assignments and multi-target statements are left alone.
     */
    private static Optional<PythonQuickFix> createQuickFix(AssignmentStatement assignmentStatement) {
        boolean singleTarget = assignmentStatement.lhsExpressions().size() == 1
            && assignmentStatement.lhsExpressions().get(0).expressions().size() == 1;
        if (!singleTarget || !assignmentStatement.assignedValue().is(Tree.Kind.NONE)) {
            return Optional.empty();
        }
        PythonQuickFix.Builder builder = PythonQuickFix.newQuickFix(QUICK_FIX_MESSAGE);
        builder.addTextEdit(TextEditUtils.removeStatement(assignmentStatement));
        return Optional.of(builder.build());
    }

    /**
     * Collects assignment targets of the form {@code self.<name>_} in the visited tree.
     * Candidates are direct targets or elements of a top-level tuple target. Names starting with
     * {@code "__"} are excluded. Each target is mapped to the statement that assigns it.
     */
    private static final class VariableDeclarationEndingWithUnderscoreVisitor extends BaseTreeVisitor {
        // Insertion-ordered so issues are raised in source order; a HashMap keyed by tree nodes
        // would iterate in an unspecified order.
        private final Map<QualifiedExpression, AssignmentStatement> qualifiedExpressions = new LinkedHashMap<>();

        private static boolean isOffendingQualifiedExpression(QualifiedExpression qualifiedExpression) {
            String attributeName = qualifiedExpression.name().name();
            return !attributeName.startsWith("__")
                && attributeName.endsWith("_")
                && qualifiedExpression.qualifier().is(Tree.Kind.NAME)
                && "self".equals(((Name) qualifiedExpression.qualifier()).name());
        }

        @Override
        public void visitAssignmentStatement(AssignmentStatement pyAssignmentStatementTree) {
            pyAssignmentStatementTree.lhsExpressions().stream()
                .flatMap(expressionList -> expressionList.expressions().stream())
                // Unpack one level of tuple targets, e.g. `self.a_, self.b_ = ...`.
                .flatMap(expression -> expression.is(Tree.Kind.TUPLE)
                    ? ((Tuple) expression).elements().stream()
                    : Stream.of(expression))
                .filter(expression -> expression.is(Tree.Kind.QUALIFIED_EXPR))
                .map(QualifiedExpression.class::cast)
                .filter(VariableDeclarationEndingWithUnderscoreVisitor::isOffendingQualifiedExpression)
                .forEach(qualifiedExpression -> qualifiedExpressions.put(qualifiedExpression, pyAssignmentStatementTree));
            super.visitAssignmentStatement(pyAssignmentStatementTree);
        }
    }
}

