/*
 * Decompiled with CFR 0.152.
 */
package com.jetbrains.python.codeInsight.controlflow;

import com.intellij.openapi.util.Ref;
import com.intellij.psi.PsiElement;
import com.intellij.psi.util.PsiTreeUtil;
import com.intellij.util.containers.ContainerUtil;
import com.intellij.util.containers.Stack;
import com.jetbrains.python.PyTokenTypes;
import com.jetbrains.python.codeInsight.controlflow.InstructionTypeCallback;
import com.jetbrains.python.codeInsight.stdlib.PyStdlibTypeProvider;
import com.jetbrains.python.codeInsight.typing.PyTypingTypeProvider;
import com.jetbrains.python.psi.PyBinaryExpression;
import com.jetbrains.python.psi.PyCallExpression;
import com.jetbrains.python.psi.PyCaseClause;
import com.jetbrains.python.psi.PyConditionalExpression;
import com.jetbrains.python.psi.PyElementType;
import com.jetbrains.python.psi.PyExpression;
import com.jetbrains.python.psi.PyIfPart;
import com.jetbrains.python.psi.PyMatchStatement;
import com.jetbrains.python.psi.PyPattern;
import com.jetbrains.python.psi.PyPrefixExpression;
import com.jetbrains.python.psi.PyRecursiveElementVisitor;
import com.jetbrains.python.psi.PyReferenceExpression;
import com.jetbrains.python.psi.PyTupleExpression;
import com.jetbrains.python.psi.PyUtil;
import com.jetbrains.python.psi.impl.PyEvaluator;
import com.jetbrains.python.psi.impl.PyPsiUtils;
import com.jetbrains.python.psi.types.PyClassType;
import com.jetbrains.python.psi.types.PyInstantiableType;
import com.jetbrains.python.psi.types.PyLiteralType;
import com.jetbrains.python.psi.types.PyNoneType;
import com.jetbrains.python.psi.types.PyStructuralType;
import com.jetbrains.python.psi.types.PyTupleType;
import com.jetbrains.python.psi.types.PyType;
import com.jetbrains.python.psi.types.PyTypeChecker;
import com.jetbrains.python.psi.types.PyTypeUtil;
import com.jetbrains.python.psi.types.PyUnionType;
import com.jetbrains.python.psi.types.TypeEvalContext;
import java.util.ArrayList;
import java.util.Collection;
import java.util.List;
import java.util.function.Function;
import java.util.stream.Stream;
import one.util.streamex.StreamEx;
import org.jetbrains.annotations.ApiStatus;
import org.jetbrains.annotations.NotNull;
import org.jetbrains.annotations.Nullable;

/**
 * Collects type assertions implied by a condition expression.
 * <p>
 * The evaluator is run over a condition and records an {@link Assertion} for every recognized pattern:
 * <ul>
 *   <li>{@code isinstance(x, T)} and {@code assertIsInstance(x, T)};</li>
 *   <li>{@code issubclass(x, T)};</li>
 *   <li>{@code callable(x)};</li>
 *   <li>{@code x is None}, {@code x == <literal>} and their negations;</li>
 *   <li>{@code x in (<literals>)} and {@code x not in (<literals>)};</li>
 *   <li>bare truthiness tests {@code if x} and {@code a if x else b};</li>
 *   <li>{@code case} patterns of a {@code match} statement.</li>
 * </ul>
 * Each assertion lazily computes the narrowed type of the referenced variable for a given
 * {@link TypeEvalContext}.
 * <p>
 * {@code myPositive} tells whether the sub-expression currently being visited is assumed to be true
 * ({@code true}) or false ({@code false}). Negating constructs temporarily flip it for the duration of
 * the nested visit via {@link #setPositive(boolean, Runnable)}. These constructs are {@code is not},
 * {@code !=}, {@code <>}, {@code not in}, and a {@code False} literal on either side of
 * {@code ==}/{@code is}.
 * <p>
 * Null checks for {@code @NotNull} parameters are expected to come from the build's NotNull
 * instrumentation rather than being written by hand.
 */
public class PyTypeAssertionEvaluator extends PyRecursiveElementVisitor {
    private final Stack<Assertion> myStack = new Stack<>();
    private boolean myPositive;

    /**
     * @param positive {@code true} to collect assertions that hold when the visited condition is true,
     *                 {@code false} to collect those that hold when it is false
     */
    public PyTypeAssertionEvaluator(boolean positive) {
        myPositive = positive;
    }

    /**
     * Returns the assertions collected so far. This is the live internal list, not a copy.
     */
    List<Assertion> getDefinitions() {
        return myStack;
    }

    /**
     * Handles {@code isinstance}/{@code assertIsInstance}, {@code callable} and {@code issubclass} calls
     * whose first argument is a plain reference expression.
     */
    @Override
    public void visitPyCallExpression(@NotNull PyCallExpression node) {
        if (node.isCalleeText("isinstance", "assertIsInstance")) {
            PyExpression[] args = node.getArguments();
            if (args.length == 2 && args[0] instanceof PyReferenceExpression) {
                PyReferenceExpression target = (PyReferenceExpression)args[0];
                PyExpression typeElement = args[1];
                // isinstance narrows to instances of the given class(es).
                pushAssertion(target, myPositive,
                              context -> transformTypeFromAssertion(context.getType(typeElement), false, context, typeElement));
            }
        }
        else if (node.isCalleeText("callable")) {
            PyExpression[] args = node.getArguments();
            if (args.length == 1 && args[0] instanceof PyReferenceExpression) {
                PyReferenceExpression target = (PyReferenceExpression)args[0];
                pushAssertion(target, myPositive, context -> PyTypingTypeProvider.createTypingCallableType(node));
            }
        }
        else if (node.isCalleeText("issubclass")) {
            PyExpression[] args = node.getArguments();
            if (args.length == 2 && args[0] instanceof PyReferenceExpression) {
                PyReferenceExpression target = (PyReferenceExpression)args[0];
                PyExpression typeElement = args[1];
                // issubclass narrows to the class objects themselves (transformToDefinition = true).
                pushAssertion(target, myPositive,
                              context -> transformTypeFromAssertion(context.getType(typeElement), true, context, typeElement));
            }
        }
    }

    /**
     * A bare reference used as an {@code if} condition, as the condition of a conditional expression, or
     * under {@code not} directly inside an {@code if} part is recorded as "not None". This happens only
     * when the condition is assumed true. Otherwise the children are visited as usual.
     */
    @Override
    public void visitPyReferenceExpression(@NotNull PyReferenceExpression node) {
        if (myPositive && (isIfReferenceStatement(node) || isIfReferenceConditionalStatement(node) || isIfNotReferenceStatement(node))) {
            // myPositive is true here, so this is a negative assertion: the reference is not None.
            pushAssertion(node, false, context -> PyNoneType.INSTANCE);
            return;
        }
        super.visitPyReferenceExpression(node);
    }

    /**
     * Dispatches identity/equality operators ({@code is}, {@code is not}, {@code ==}, {@code !=}, {@code <>})
     * and membership operators ({@code in}, {@code not in}) with the appropriate polarity.
     */
    @Override
    public void visitPyBinaryExpression(@NotNull PyBinaryExpression node) {
        PyExpression lhs = PyPsiUtils.flattenParens(node.getLeftExpression());
        PyExpression rhs = PyPsiUtils.flattenParens(node.getRightExpression());
        if (lhs == null || rhs == null) {
            return;
        }
        PyElementType operator = node.getOperator();
        boolean isOrEqualsOperator = node.isOperator("is") || PyTokenTypes.EQEQ.equals(operator);
        if (isOrEqualsOperator || node.isOperator("isnot") || PyTokenTypes.NE.equals(operator) || PyTokenTypes.NE_OLD.equals(operator)) {
            setPositive(isOrEqualsOperator, () -> processIsOrEquals(lhs, rhs));
        }
        boolean isIn = PyTokenTypes.IN_KEYWORD.equals(operator);
        if (isIn || node.isOperator("notin")) {
            setPositive(isIn, () -> processIn(lhs, rhs));
        }
    }

    /**
     * Handles {@code lhs is/== rhs} under the current polarity.
     * <ul>
     *   <li>If either side evaluates to a boolean literal, the other side is visited. A {@code False}
     *       literal flips the polarity for that visit.</li>
     *   <li>Otherwise a {@code None} literal on either side produces a None assertion for a reference
     *       on the other side.</li>
     *   <li>Otherwise a reference on the left is compared against the literal type of the right.</li>
     * </ul>
     */
    private void processIsOrEquals(@NotNull PyExpression lhs, @NotNull PyExpression rhs) {
        Boolean leftBoolean = PyEvaluator.evaluateNoResolve(lhs, Boolean.class);
        if (leftBoolean != null) {
            setPositive(leftBoolean, () -> rhs.accept(this));
            return;
        }
        Boolean rightBoolean = PyEvaluator.evaluateNoResolve(rhs, Boolean.class);
        if (rightBoolean != null) {
            setPositive(rightBoolean, () -> lhs.accept(this));
            return;
        }
        if (PyLiteralType.isNone(lhs)) {
            if (rhs instanceof PyReferenceExpression) {
                pushAssertion((PyReferenceExpression)rhs, myPositive, context -> PyNoneType.INSTANCE);
            }
            return;
        }
        if (PyLiteralType.isNone(rhs)) {
            if (lhs instanceof PyReferenceExpression) {
                pushAssertion((PyReferenceExpression)lhs, myPositive, context -> PyNoneType.INSTANCE);
            }
            return;
        }
        if (lhs instanceof PyReferenceExpression) {
            pushAssertion((PyReferenceExpression)lhs, myPositive, context -> getLiteralType(rhs, context));
        }
    }

    /**
     * Handles {@code ref in (a, b, ...)}.
     * <p>
     * The suggested type is the union of the element types. It is {@code null} when any element is
     * neither {@code None} nor of a literal type.
     */
    private void processIn(@NotNull PyExpression lhs, @NotNull PyExpression rhs) {
        if (!(lhs instanceof PyReferenceExpression) || !(rhs instanceof PyTupleExpression)) {
            return;
        }
        PyReferenceExpression referenceExpr = (PyReferenceExpression)lhs;
        PyTupleExpression tupleExpr = (PyTupleExpression)rhs;
        pushAssertion(referenceExpr, myPositive, context -> {
            PyExpression[] elements = tupleExpr.getElements();
            List<PyType> types = new ArrayList<>(elements.length);
            for (PyExpression element : elements) {
                PyType type = PyLiteralType.isNone(element) ? PyNoneType.INSTANCE : getLiteralType(element, context);
                if (type == null) {
                    // A non-literal element: nothing precise can be suggested.
                    return null;
                }
                types.add(type);
            }
            return PyUnionType.union(types);
        });
    }

    /**
     * Returns the type of {@code element} when every member of it is a {@link PyLiteralType}, otherwise
     * {@code null}. {@link PyLiteralType#getLiteralType} is tried first; the context type is used as a
     * fallback.
     */
    @Nullable
    private static PyType getLiteralType(@NotNull PyExpression element, @NotNull TypeEvalContext context) {
        PyType type = PyLiteralType.getLiteralType(element, context);
        if (type == null) {
            type = context.getType(element);
        }
        return PyTypeUtil.toStream(type).allMatch(subtype -> subtype instanceof PyLiteralType) ? type : null;
    }

    /**
     * Runs {@code runnable} with the polarity combined with {@code positive}.
     * <p>
     * A {@code false} argument flips the current polarity for the duration of the call; {@code true}
     * keeps it. The previous polarity is always restored, even if {@code runnable} throws.
     */
    private void setPositive(boolean positive, @NotNull Runnable runnable) {
        boolean oldPositive = myPositive;
        if (!positive) {
            myPositive = !myPositive;
        }
        try {
            runnable.run();
        }
        finally {
            myPositive = oldPositive;
        }
    }

    /**
     * For a top-level {@code case} pattern whose match subject is a (possibly parenthesized) reference,
     * records that the subject has the pattern's type.
     */
    @Override
    public void visitPyPattern(@NotNull PyPattern node) {
        PsiElement parent = PsiTreeUtil.skipParentsOfType(node, PyCaseClause.class);
        if (!(parent instanceof PyMatchStatement)) {
            return;
        }
        PyExpression subject = PyPsiUtils.flattenParens(((PyMatchStatement)parent).getSubject());
        if (subject instanceof PyReferenceExpression) {
            pushAssertion((PyReferenceExpression)subject, myPositive, context -> context.getType(node));
        }
    }

    /**
     * Computes the subject's type when no {@code case} clause matched: every unguarded pattern's type is
     * subtracted from the subject's type.
     * <p>
     * Only valid for the negative evaluator. The subject is unwrapped from parentheses, as in
     * {@link #visitPyPattern}.
     */
    @Override
    public void visitPyMatchStatement(@NotNull PyMatchStatement matchStatement) {
        assert !myPositive;
        PyExpression subject = PyPsiUtils.flattenParens(matchStatement.getSubject());
        if (!(subject instanceof PyReferenceExpression)) {
            return;
        }
        PyReferenceExpression target = (PyReferenceExpression)subject;
        pushAssertion(target, true, context -> {
            PyType subjectType = context.getType(target);
            for (PyCaseClause cs : matchStatement.getCaseClauses()) {
                PyPattern pattern = cs.getPattern();
                // A guarded clause may not match even when its pattern does, so it cannot narrow.
                if (pattern == null || cs.getGuardCondition() != null) {
                    continue;
                }
                subjectType = Ref.deref(createAssertionType(subjectType, context.getType(pattern), false, context));
            }
            return subjectType;
        });
    }

    /**
     * Narrows {@code initial} using {@code suggested}.
     *
     * @param initial   the current type of the target
     * @param suggested the type implied by the assertion
     * @param positive  {@code true} if the target is asserted to be {@code suggested}, {@code false} if it
     *                  is asserted not to be
     * @param context   the type evaluation context
     * @return the narrowed type wrapped in a {@link Ref}, or {@code null} when, in the negative case,
     *     {@code initial} matches {@code suggested} entirely. A {@code null} return means the branch is
     *     unreachable for that type.
     */
    @ApiStatus.Internal
    @Nullable
    public static Ref<PyType> createAssertionType(@Nullable PyType initial,
                                                  @Nullable PyType suggested,
                                                  boolean positive,
                                                  @NotNull TypeEvalContext context) {
        if (positive) {
            // Keep the members of `initial` that satisfy `suggested`...
            List<PyType> initialSubtypes = PyTypeUtil.toStream(initial)
                .filter(initialSubtype -> match(suggested, initialSubtype, context))
                .toList();
            // ...plus the members of `suggested` that are compatible with `initial` but not already covered.
            StreamEx<PyType> suggestedSubtypes = PyTypeUtil.toStream(suggested)
                .filter(suggestedSubtype -> match(initial, suggestedSubtype, context))
                .filter(suggestedSubtype -> !ContainerUtil.exists(initialSubtypes,
                                                                  initialSubtype -> match(initialSubtype, suggestedSubtype, context)));
            List<PyType> types = StreamEx.of(initialSubtypes).append(suggestedSubtypes).toList();
            return Ref.create(types.isEmpty() ? suggested : PyUnionType.union(types));
        }
        if (initial instanceof PyUnionType) {
            return Ref.create(excludeFromUnion((PyUnionType)initial, suggested, context));
        }
        if (match(suggested, initial, context)) {
            return null;
        }
        Ref<PyType> diff = trySubtract(initial, suggested, context);
        return diff != null ? diff : Ref.create(initial);
    }

    /**
     * Removes from {@code unionType} every member that matches {@code type}. Enum members are
     * subtracted member-wise where possible.
     */
    @Nullable
    private static PyType excludeFromUnion(@NotNull PyUnionType unionType, @Nullable PyType type, @NotNull TypeEvalContext context) {
        List<PyType> members = new ArrayList<>();
        for (PyType member : unionType.getMembers()) {
            Ref<@Nullable PyType> diff = trySubtract(member, type, context);
            if (diff != null) {
                members.add(diff.get());
            }
            else if (!PyTypeChecker.match(type, member, context)) {
                members.add(member);
            }
        }
        return PyUnionType.union(members);
    }

    /**
     * Subtracts {@code type2} from a non-union {@code type1}.
     * <p>
     * This is only supported when {@code type1} is a custom (non-{@code enum.Flag}) enum class type. In
     * that case the result is the union of the enum members that do not match {@code type2}, or
     * {@code type1} itself if none were removed.
     *
     * @return the difference wrapped in a {@link Ref}, or {@code null} if subtraction is not applicable
     */
    private static @Nullable Ref<@Nullable PyType> trySubtract(@Nullable PyType type1, @Nullable PyType type2, @NotNull TypeEvalContext context) {
        assert !(type1 instanceof PyUnionType);
        if (type1 instanceof PyLiteralType || !(type1 instanceof PyClassType)) {
            return null;
        }
        PyClassType classType1 = (PyClassType)type1;
        if (!PyStdlibTypeProvider.isCustomEnum(classType1.getPyClass(), context)) {
            return null;
        }
        // Flag enums combine members bitwise, so excluding one member says nothing about the rest.
        if (ContainerUtil.exists(classType1.getPyClass().getAncestorClasses(context), cls -> "enum.Flag".equals(cls.getQualifiedName()))) {
            return null;
        }
        List<PyLiteralType> enumMembers = PyStdlibTypeProvider.getEnumMembers(classType1.getPyClass(), context).toList();
        List<PyLiteralType> filteredEnumMembers = ContainerUtil.filter(enumMembers, m -> !PyTypeChecker.match(type2, m, context));
        return Ref.create(enumMembers.size() == filteredEnumMembers.size() ? type1 : PyUnionType.union(filteredEnumMembers));
    }

    /**
     * Like {@link PyTypeChecker#match}, but never matches a structural or unknown {@code actual} type.
     */
    private static boolean match(@Nullable PyType expected, @Nullable PyType actual, @NotNull TypeEvalContext context) {
        return !(actual instanceof PyStructuralType)
               && !PyTypeChecker.isUnknown(actual, context)
               && PyTypeChecker.match(expected, actual, context);
    }

    /**
     * Converts the type of an {@code isinstance}/{@code issubclass} second argument into the type it
     * asserts.
     * <ul>
     *   <li>Tuples of classes become unions.</li>
     *   <li>{@code types.UnionType} values ({@code A | B}) are re-resolved through the typing provider.</li>
     *   <li>Instantiable types become instances, or stay class objects when {@code transformToDefinition}
     *       is {@code true}.</li>
     * </ul>
     *
     * @param typeElement the expression that produced {@code type}, if known; used to recover per-element
     *                    expressions of a tuple and to re-resolve {@code types.UnionType}
     */
    @Nullable
    private static PyType transformTypeFromAssertion(@Nullable PyType type,
                                                     boolean transformToDefinition,
                                                     @NotNull TypeEvalContext context,
                                                     @Nullable PyExpression typeElement) {
        if (type instanceof PyTupleType) {
            PyTupleType tupleType = (PyTupleType)type;
            int count = tupleType.getElementCount();
            PyTupleExpression tupleExpression = PyUtil.as(PyPsiUtils.flattenParens(typeElement), PyTupleExpression.class);
            // Per-element expressions are only usable when they line up one-to-one with the tuple type.
            PyExpression[] elements = tupleExpression != null && tupleExpression.getElements().length == count
                                      ? tupleExpression.getElements()
                                      : null;
            List<PyType> members = new ArrayList<>(count);
            for (int i = 0; i < count; i++) {
                PyExpression elementExpr = elements != null ? elements[i] : null;
                members.add(transformTypeFromAssertion(tupleType.getElementType(i), transformToDefinition, context, elementExpr));
            }
            return PyUnionType.union(members);
        }
        if (type instanceof PyUnionType) {
            return ((PyUnionType)type).map(member -> transformTypeFromAssertion(member, transformToDefinition, context, null));
        }
        if (type instanceof PyClassType && "types.UnionType".equals(((PyClassType)type).getClassQName()) && typeElement != null) {
            Ref<PyType> typeFromTypingProvider = PyTypingTypeProvider.getType(typeElement, context);
            if (typeFromTypingProvider != null) {
                return transformTypeFromAssertion(typeFromTypingProvider.get(), transformToDefinition, context, null);
            }
        }
        else if (type instanceof PyInstantiableType) {
            PyInstantiableType<?> instantiableType = (PyInstantiableType<?>)type;
            return transformToDefinition ? instantiableType.toClass() : instantiableType.toInstance();
        }
        return type;
    }

    /**
     * Records an assertion for {@code target}.
     * <p>
     * The narrowed type is computed lazily: for each context, the target's current type is combined
     * with {@code suggestedType} via {@link #createAssertionType}.
     */
    private void pushAssertion(final @NotNull PyReferenceExpression target,
                               final boolean positive,
                               final @NotNull Function<TypeEvalContext, PyType> suggestedType) {
        InstructionTypeCallback typeCallback = new InstructionTypeCallback() {
            @Override
            public Ref<PyType> getType(TypeEvalContext context) {
                return createAssertionType(context.getType(target), suggestedType.apply(context), positive, context);
            }
        };
        myStack.push(new Assertion(target, typeCallback));
    }

    /** {@code true} if the reference is itself the condition of an {@code if}/{@code elif} part. */
    private static boolean isIfReferenceStatement(@NotNull PyReferenceExpression node) {
        return node.getParent() instanceof PyIfPart;
    }

    /** {@code true} if the reference is the condition of a conditional expression {@code a if x else b}. */
    private static boolean isIfReferenceConditionalStatement(@NotNull PyReferenceExpression node) {
        PsiElement parent = node.getParent();
        return parent instanceof PyConditionalExpression && node == ((PyConditionalExpression)parent).getCondition();
    }

    /**
     * {@code true} if the reference is the operand of {@code not}, and that {@code not} expression is
     * itself an {@code if} condition.
     * <p>
     * NOTE(review): this class does not flip {@code myPositive} for {@code not}. Presumably the caller
     * accounts for the negation when choosing which evaluator to run; confirm.
     */
    private static boolean isIfNotReferenceStatement(@NotNull PyReferenceExpression node) {
        PsiElement parent = node.getParent();
        return parent instanceof PyPrefixExpression
               && ((PyPrefixExpression)parent).getOperator() == PyTokenTypes.NOT_KEYWORD
               && parent.getParent() instanceof PyIfPart;
    }

    /** A single collected assertion: the referenced variable and a lazy producer of its narrowed type. */
    static class Assertion {
        private final PyReferenceExpression element;
        private final InstructionTypeCallback myFunction;

        Assertion(PyReferenceExpression element, InstructionTypeCallback getType) {
            this.element = element;
            this.myFunction = getType;
        }

        /** The reference expression whose type is being asserted. */
        public PyReferenceExpression getElement() {
            return element;
        }

        /** The callback computing the narrowed type for a given context. */
        public InstructionTypeCallback getTypeEvalFunction() {
            return myFunction;
        }
    }
}

