import { useReactFlow, Node, NodeResizer } from '@xyflow/react'
import { ResizeDragEvent, ResizeParams } from '@xyflow/system'
import { memo, useEffect, useCallback } from 'react'
import { withErrorBoundary } from 'react-error-boundary'

import { CROP_TITLE_HEIGHT } from './CropperNode'
import { ErrorNode } from './ErrorNode'
import {
  DEFAULT_FORM_VALUES,
  MEDIA_CANVAS_DEFAULT_SIZE,
  STABLE_DIFFUSION_CHECKPOINTS,
  MEDIA_CANVAS_MAX_SIZE,
  MEDIA_CANVAS_MIN_SIZE,
} from '../../../constants'
import { useCanvasContext } from '../../../context'
import { useCreateMediaNode } from '../../../hooks/k2/useCreateMediaNode'
import { mediaStore } from '../../../stores'
import {
  fetchImageDimensions,
  calculateScaledDimensions,
} from '../../../utils/imageUtils'
import { getAssetKeyFromMedia } from '../../../utils/mediaUtils'
import { MediaDisplay } from '../Media'
import {
  Media,
  NodeType,
  ModelType,
  ImageForm,
  MediaType,
  Image,
  ElementType,
} from '@/types'

/**
 * Data payload carried by a React Flow node rendered with the MediaNode component.
 */
export interface MediaNodeData extends Record<string, unknown> {
  /** Media shown by the node. Refreshed from MediaStore unless `noSubscribe` is set. */
  media: Media
  /** Size picked by the user through the resize handles; wins over the auto-scaled size. */
  displaySize?: { width: number; height: number }
  /**
   * When true, the node is not subscribed to MediaStore, so `media` is never replaced
   * by store updates. For example, CropperNode opts out because it generates its own image.
   */
  noSubscribe?: boolean
  /** Processing flag — not read in this component; TODO confirm its consumers. */
  isProcessing?: boolean
  /** Forwarded to MediaDisplay (presumably true while another node overlaps this one). */
  isIntersecting?: boolean
}

/**
 * Props received by the MediaNode component (the subset of React Flow node props it uses).
 */
export interface MediaNodeProps {
  /** Payload of the node; see MediaNodeData. */
  data: MediaNodeData
  /** React Flow id of the node being rendered. */
  id: string
}

// Blank value used for the negativePrompt, style and subject fields of the
// image-blend form built when two image nodes are overlapped.
const EMPTY_PROMPT = ''
// Weight given to each of the two reference images in that blend form; both
// references use the same value (presumably so neither image dominates).
const IMAGE_BLEND_REFERENCE_WEIGHT = 1

/**
 * Upper bound for the NodeResizer along one axis.
 *
 * Images may be enlarged up to MEDIA_CANVAS_MAX_SIZE (they just get blurrier).
 * Other media (e.g. video) is additionally capped at its intrinsic size, since the
 * browser doesn't seem to upscale video past its native resolution. When the
 * intrinsic size is unknown (0/undefined), the canvas maximum is used instead of
 * yielding NaN or a 0 bound that would sit below MEDIA_CANVAS_MIN_SIZE.
 */
const getMaxResizeSize = (
  isImage: boolean,
  intrinsicSize: number | undefined,
): number =>
  isImage || !intrinsicSize
    ? MEDIA_CANVAS_MAX_SIZE
    : Math.min(MEDIA_CANVAS_MAX_SIZE, intrinsicSize)

/**
 * Canvas node that renders a single piece of media.
 *
 * - Keeps `data.media` in sync with MediaStore (unless `data.noSubscribe`).
 * - When another image MediaNode is dropped onto this one, creates a new media node
 *   that blends both images with ImageLab.
 * - Supports aspect-ratio-preserving resizing and switching to a CropperNode.
 */
const BaseMediaNode: React.FC<MediaNodeProps> = memo(({ id, data }) => {
  const { registerOnNodeIntersection } = useCanvasContext()
  const { createMediaNode } = useCreateMediaNode()
  const { setNodes, getNode, updateNodeData } = useReactFlow()

  const scaledDimensions = calculateScaledDimensions(
    data.media.width,
    data.media.height,
    MEDIA_CANVAS_DEFAULT_SIZE,
  )

  const nodeDimensions = {
    width:
      data.displaySize?.width || // (1) User-resized width, if available
      scaledDimensions.width || // (2) Auto-scaled width based on media's aspect ratio
      MEDIA_CANVAS_DEFAULT_SIZE, // (3) Default width if neither of the above is available
    height:
      data.displaySize?.height ||
      scaledDimensions.height ||
      MEDIA_CANVAS_DEFAULT_SIZE,
  }

  // Subscribe to MediaStore so that the node's media is in sync with the store.
  // Throwing here (missing mediaId) is surfaced through the error boundary wrapping this node.
  useEffect(() => {
    if (data.noSubscribe) return
    if (!data.media.mediaId)
      throw new Error(
        'mediaId is required to subscribe to MediaStore. Node: ' + id,
      )
    return mediaStore.subscribeToMediaId(data.media.mediaId, (newMedia) => {
      updateNodeData(id, { media: newMedia })
    })
  }, [data.media.mediaId, data.noSubscribe, id, updateNodeData])

  // NOTE(review): the return value of registerOnNodeIntersection is discarded, so nothing
  // is unregistered on unmount or when deps change — confirm whether it returns a cleanup.
  useEffect(() => {
    registerOnNodeIntersection(
      id,
      async (selfNode: Node, intersectingNode: Node) => {
        // selfNode is the node receiving the overlap (not being dragged);
        // intersectingNode is the node being dragged that triggered the event.
        if (intersectingNode.type !== NodeType.MediaNode) return

        const selfNodeMedia = selfNode.data.media as Image
        const intersectingNodeMedia = intersectingNode.data.media as Image

        // Blending is only defined for two images.
        if (
          selfNodeMedia.type !== MediaType.Image ||
          intersectingNodeMedia.type !== MediaType.Image
        ) {
          return
        }

        const selfNodeSource = selfNodeMedia.source
        const intersectingNodeSource = intersectingNodeMedia.source
        if (!selfNodeSource || !intersectingNodeSource) return

        // The two lookups are independent, so run them concurrently.
        const [selfNodeMediaAssetKey, intersectingNodeMediaAssetKey] =
          await Promise.all([
            getAssetKeyFromMedia(selfNodeMedia),
            getAssetKeyFromMedia(intersectingNodeMedia),
          ])
        if (!selfNodeMediaAssetKey || !intersectingNodeMediaAssetKey) return

        // The generated image takes the receiving node's dimensions. If they are
        // missing, measure that same image (selfNodeSource), not the dragged one.
        let { width, height } = selfNodeMedia
        if (!width || !height) {
          const imageDimensions = await fetchImageDimensions(selfNodeSource)
          width = imageDimensions.width
          height = imageDimensions.height
        }

        const imageFormData: ImageForm = {
          aspectRatio: DEFAULT_FORM_VALUES.aspectRatio,
          width,
          height,
          negativePrompt: EMPTY_PROMPT,
          style: EMPTY_PROMPT,
          subject: EMPTY_PROMPT,
          type: MediaType.Image,
          imageReference: [
            selfNodeMediaAssetKey,
            intersectingNodeMediaAssetKey,
          ],
          imageReferenceWeights: [
            IMAGE_BLEND_REFERENCE_WEIGHT,
            IMAGE_BLEND_REFERENCE_WEIGHT,
          ],
          stableDiffusionCheckpoint: STABLE_DIFFUSION_CHECKPOINTS.Animated,
          version: DEFAULT_FORM_VALUES.version,
        }

        await createMediaNode(selfNode.id, imageFormData, ModelType.ImageLab, [
          ElementType.Prompt,
          ElementType.ImageReferenceUpload,
        ])
      },
    )
  }, [id, registerOnNodeIntersection, createMediaNode])

  // Persist the user-chosen size so it overrides the auto-scaled dimensions.
  const onResize = useCallback(
    (
      _event: ResizeDragEvent,
      { width: newWidth, height: newHeight }: ResizeParams,
    ) => {
      updateNodeData(id, {
        displaySize: { width: newWidth, height: newHeight },
      })
    },
    [id, updateNodeData],
  )

  // Replace this node with a CropperNode carrying the same data, shifted up by the
  // cropper's title height so the media stays in place on the canvas.
  const handleChangeToCropperNode = useCallback(() => {
    const node = getNode(id)
    if (!node) return
    const cropperNode: Node<MediaNodeData> = {
      ...node,
      position: {
        x: node.position.x,
        y: node.position.y - CROP_TITLE_HEIGHT,
      },
      data,
      type: NodeType.CropperNode,
    }
    setNodes((nodes) => nodes.map((n) => (n.id === id ? cropperNode : n)))
  }, [id, getNode, setNodes, data])

  const isImage = data.media.type === MediaType.Image
  const maxWidth = getMaxResizeSize(isImage, data.media.width)
  const maxHeight = getMaxResizeSize(isImage, data.media.height)

  return (
    <div className='relative group' style={nodeDimensions}>
      <NodeResizer
        minWidth={MEDIA_CANVAS_MIN_SIZE}
        minHeight={MEDIA_CANVAS_MIN_SIZE}
        maxWidth={maxWidth}
        maxHeight={maxHeight}
        onResize={onResize}
        keepAspectRatio
        lineStyle={{
          border: 'none',
          padding: '10px',
        }}
        handleStyle={{
          width: '20px',
          height: '20px',
          border: 'none',
          background: 'transparent',
          padding: '10px',
        }}
      />
      <MediaDisplay
        media={data.media}
        width={nodeDimensions.width}
        height={nodeDimensions.height}
        nodeId={id}
        onCropperNodeChange={handleChangeToCropperNode}
        showNodeToolbar
        isIntersecting={data.isIntersecting}
      />
      {/* <Handle type='target' position={Position.Left} /> */}
    </div>
  )
})

BaseMediaNode.displayName = 'MediaNode'

/**
 * Public MediaNode component: BaseMediaNode wrapped in an error boundary that
 * renders ErrorNode in its place when it throws, and logs the error.
 */
export const MediaNode = withErrorBoundary(BaseMediaNode, {
  FallbackComponent: ErrorNode,
  // `unknown` instead of `any`: the values are only logged, and `unknown`
  // parameters remain assignable to the library's onError callback type.
  onError(error: unknown, info: unknown) {
    console.error('MediaNode Error caught by Error Boundary:', error, info)
  },
})
