import { Node, useReactFlow, XYPosition } from '@xyflow/react'
import { useState, useCallback, MouseEvent, useEffect } from 'react'
import { NodeType } from '@/types'
import { useDebouncedCallback } from 'use-debounce'

import { useCanvasContext } from '../../context'
import { useDetachNodes } from '../../hooks'
import { getDistance, sortNodes } from '../../utils/nodeUtils'

/**
 * Returns `nodes` with every `data.isIntersecting` flag cleared.
 *
 * Nodes that are not flagged keep their object identity. When no node is
 * flagged, the input array itself is returned, which lets a React state setter
 * receiving the result bail out instead of re-rendering every node.
 */
const clearIntersectionFlags = (nodes: Node[]): Node[] => {
  if (!nodes.some((n) => n.data?.isIntersecting)) {
    return nodes
  }

  return nodes.map((n) =>
    n.data?.isIntersecting
      ? { ...n, data: { ...n.data, isIntersecting: false } }
      : n,
  )
}

/**
 * Drag-time intersection handling for canvas nodes.
 *
 * While a node is being dragged, `checkIntersections` looks for overlap with
 * media nodes. Once an overlap starts, it schedules (debounced by 500 ms) the
 * flagging of the overlapped media nodes with `data.isIntersecting = true`.
 * On drag stop, `handleInvokeNodeIntersection` picks the closest intersecting
 * node. It then fires intersection callbacks, snaps the node into an assemble
 * slot, reverts a media-on-media drop, or detaches the node when nothing is
 * intersected.
 *
 * @param setNodes - Setter for the canvas node array.
 * @returns `checkIntersections` (call while dragging), `handleIntersection`
 *   (the debounced flagging function) and `handleInvokeNodeIntersection`
 *   (call on drag stop).
 */
export const useNodeIntersection = (
  setNodes: React.Dispatch<React.SetStateAction<any[]>>,
) => {
  // True while the dragged node overlaps at least one media node; prevents the
  // debounced flagging from being re-armed on every drag tick.
  const [isIntersecting, setIsIntersecting] = useState(false)

  const { getIntersectingNodes, updateNode } = useReactFlow()
  const { invokeOnNodeIntersectionCallbacks, invokeOnNodeDragStopCallbacks } =
    useCanvasContext()
  const detachNodes = useDetachNodes()

  /** Marks each given media node as intersected (`data.isIntersecting`). */
  const handleIntersection = (mediaNodes: Node[]) => {
    mediaNodes.forEach((mediaNode) => {
      updateNode(mediaNode.id, (n) => ({
        ...n,
        data: { ...n.data, isIntersecting: true },
      }))
    })
  }

  const debouncedHandleIntersection = useDebouncedCallback(
    handleIntersection,
    500,
  )

  // Drop any pending flagging when the component using this hook unmounts.
  useEffect(() => {
    return () => {
      debouncedHandleIntersection.cancel()
    }
  }, [debouncedHandleIntersection])

  const checkIntersections = useCallback(
    (node: Node) => {
      const mediaIntersections = getIntersectingNodes(node).filter(
        (n) => n.type === NodeType.MediaNode,
      )

      if (mediaIntersections.length) {
        // Arm the debounce only on the transition into an overlap.
        if (!isIntersecting) {
          setIsIntersecting(true)
          debouncedHandleIntersection(mediaIntersections)
        }
        return
      }

      setIsIntersecting(false)
      debouncedHandleIntersection.cancel()
      // This runs on every drag tick without overlap; clearIntersectionFlags
      // returns the same array when nothing is flagged, so no node is rebuilt.
      setNodes(clearIntersectionFlags)
    },
    [
      getIntersectingNodes,
      debouncedHandleIntersection,
      setNodes,
      isIntersecting,
    ],
  )

  const handleInvokeNodeIntersection = useCallback(
    (event: MouseEvent, node: Node) => {
      // The drag is over: a still-pending debounced flag must not fire after
      // the drop, and the next drag must be able to arm it again.
      debouncedHandleIntersection.cancel()
      setIsIntersecting(false)

      const intersections = getIntersectingNodes(node)

      let updatedNode = { ...node }

      if (intersections.length) {
        // The intersecting node closest to the dragged node (per getDistance).
        const intersectingNode = intersections.reduce(
          (closestNode, currentNode) => {
            const currentDistance = getDistance(node, currentNode)
            const closestDistance = getDistance(node, closestNode)
            return currentDistance < closestDistance ? currentNode : closestNode
          },
        )

        if (
          node.type === NodeType.MediaNode &&
          (intersectingNode.data.isIntersecting ||
            intersectingNode.type === NodeType.FlowNode ||
            intersectingNode.type === NodeType.CollectionNode)
        ) {
          invokeOnNodeIntersectionCallbacks(node, intersectingNode)
        }

        if (intersectingNode.type === NodeType.AssembleSlotNode) {
          // Snap onto the slot's position and adopt the slot's parent.
          updatedNode = {
            ...node,
            position: {
              x: intersectingNode.position?.x ?? 0,
              y: intersectingNode.position?.y ?? 0,
            },
            parentId: intersectingNode.parentId,
            data: {
              ...node.data,
              canvasId: intersectingNode.id,
            },
          }

          setNodes((nds) =>
            nds
              .map((n) => (n.id === updatedNode.id ? updatedNode : n))
              .sort(sortNodes),
          )
        }

        if (
          node.type === NodeType.MediaNode &&
          intersectingNode.type === NodeType.MediaNode &&
          intersectingNode.data.isIntersecting
        ) {
          // Media dropped on a flagged media node: move it back to where the
          // drag started (presumably recorded on drag start — set elsewhere).
          // Keep the current position if no start position was recorded,
          // rather than setting position to undefined.
          updateNode(node.id, (n) => ({
            ...n,
            position:
              (n.data.dragStartPosition as XYPosition | undefined) ??
              n.position,
          }))

          setNodes(clearIntersectionFlags)
        }
      } else {
        // Dropped on empty canvas: detach from any parent.
        detachNodes([node.id])
        updatedNode = {
          ...node,
          parentId: null,
          data: { ...node.data, canvasId: null },
        }

        setNodes((nds) =>
          nds
            .map((n) => (n.id === updatedNode.id ? updatedNode : n))
            .sort(sortNodes),
        )
      }

      updateNode(node.id, (n) => ({
        ...n,
        dragging: false,
        style: { ...n.style, opacity: 1 },
      }))

      // NOTE(review): when the node was reverted to dragStartPosition above,
      // updatedNode still carries the drop position — confirm drag-stop
      // listeners expect that.
      invokeOnNodeDragStopCallbacks(event, updatedNode)
    },
    [
      invokeOnNodeIntersectionCallbacks,
      setNodes,
      getIntersectingNodes,
      updateNode,
      detachNodes,
      invokeOnNodeDragStopCallbacks,
      debouncedHandleIntersection,
    ],
  )

  return {
    checkIntersections,
    handleIntersection: debouncedHandleIntersection,
    handleInvokeNodeIntersection,
  }
}
