import { CollectionAction } from '@kaiber/shared-types'
import { getOutgoers, useReactFlow } from '@xyflow/react'
import { useCallback, useEffect, useRef } from 'react'

import {
  processMedia,
  generateFlipbookV1,
  generateMotionV3,
  generateTransformV3,
  generateInpaintImage,
  generateUpscaleV1Image,
  generateFluxImage,
} from '../../api'
import { FlowNodeData } from '../../components/k2/Nodes/FlowNode/FlowNode'
import { FLOWS_CONFIG } from '../../config/flowsConfig'
import { MEDIA_TYPES } from '../../constants'
import { useAnalytics, useMutateCollection, useNodeUtility } from '../../hooks'
import { mediaStore, myLibraryStore } from '../../stores'
import { getDimensionsFromAspectRatio } from '../../utils/imageUtils'
import { makeTempMediaId } from '../../utils/mediaUtils'
import {
  Media,
  Status,
  MediaForm,
  ModelType,
  AllNodeTypes,
  NodeType,
  Flow,
  ElementType,
  AnalyticsEvent,
  NodeOrigin,
} from '@/types'

/**
 * Resolves the output dimensions for a generation request.
 *
 * Returns the form's explicit `width`/`height` when both are set (truthy);
 * otherwise falls back to `getDimensionsFromAspectRatio(formValues.aspectRatio)`.
 * Pure helper — kept at module scope so it is not recreated on every render.
 */
const extractWidthAndHeightFromMediaForm = (formValues: MediaForm) => {
  const { width, height, aspectRatio } = formValues

  if (!width || !height) {
    return getDimensionsFromAspectRatio(aspectRatio)
  }

  return { width, height }
}

/**
 * Hook exposing `createMediaNode`, which starts a media generation from a
 * node on the React Flow canvas and tracks its pending/done state.
 *
 * @returns `{ createMediaNode }`
 */
export const useCreateMediaNode = () => {
  const { updateNodeData, addNodes, getInternalNode, getNodes, getEdges } =
    useReactFlow()
  const { trackNodeEvent, trackEvent } = useAnalytics()
  const { getMaxNodeZIndex } = useNodeUtility()

  // Keep a ref to the latest `updateNodeData` so the async callback below,
  // which can resolve after later renders, always calls the current function.
  const updateNodeDataRef = useRef(updateNodeData)
  useEffect(() => {
    updateNodeDataRef.current = updateNodeData
  }, [updateNodeData])

  const { replaceAndPrependMedia, updateCollectionLocally } =
    useMutateCollection()

  /**
   * Creates a pending placeholder media for `parentId`, requests generation
   * for `modelType`, and swaps the placeholder for the returned media once
   * the request responds with data.
   *
   * If the parent has outgoing edges to CollectionNodes, the placeholder is
   * added to those collections; otherwise a new MediaNode is placed on the
   * canvas to the right of the parent.
   *
   * @param parentId - ID of the node that triggered the generation.
   * @param formData - Generation settings. Not mutated: missing width/height
   *   are resolved into a copy that is stored on the flow and media.
   * @param modelType - Model to generate with; must be a key of FLOWS_CONFIG.
   * @param customElements - Elements to use instead of the parent FlowNode's
   *   own elements. Defaults to the FlowNode's elements, or `[]`.
   * @throws Error if `modelType` is not in FLOWS_CONFIG or the parent node
   *   cannot be found.
   */
  const createMediaNode = useCallback(
    async (
      parentId: string,
      formData: MediaForm,
      modelType: ModelType,
      customElements?: ElementType[],
    ) => {
      if (!(modelType in FLOWS_CONFIG)) {
        throw new Error(`Unsupported model type: ${modelType}`)
      }

      const parent = getInternalNode(parentId)
      if (!parent) {
        throw new Error(`Parent node not found: ${parentId}`)
      }

      const createdAt = new Date()
      // Use the full epoch timestamp: getMilliseconds() only yields 0-999,
      // which produced colliding node IDs across generations.
      const tempId = `${NodeType.MediaNode}-${createdAt.getTime()}`

      // Place a new node to the right of the parent, jittered so repeated
      // generations don't land exactly on top of each other.
      const parentHeight = parent.measured?.height || 0
      const parentWidth = parent.measured?.width || 0
      const randomOffsetX = Math.floor(Math.random() * 150) + 50
      const randomOffsetY = Math.floor(Math.random() * (parentHeight / 2)) + 20
      const tempNodeX = parent.position.x + parentWidth + randomOffsetX
      const tempNodeY = parent.position.y + randomOffsetY

      const flowNodeData =
        parent.type === NodeType.FlowNode
          ? (parent.data as FlowNodeData)
          : undefined
      // Fall back to [] so the validation below never dereferences undefined.
      const elements: ElementType[] =
        customElements ?? flowNodeData?.elements ?? []
      const name = flowNodeData?.name

      // Validate elements against FLOWS_CONFIG (warn only; do not block).
      const { requiredElements, optionalElements } = FLOWS_CONFIG[modelType]
      const missingRequiredElements = requiredElements.filter(
        (el: ElementType) => !elements.includes(el),
      )
      if (missingRequiredElements.length > 0) {
        console.warn(
          `Warning: Missing required elements for ${modelType}: ${missingRequiredElements.join(', ')}`,
        )
      }

      const unexpectedElements = elements.filter(
        (el: ElementType) =>
          !requiredElements.includes(el) && !optionalElements.includes(el),
      )
      if (unexpectedElements.length > 0) {
        console.warn(
          `Warning: Unexpected elements for ${modelType}: ${unexpectedElements.join(', ')}`,
        )
      }

      // @todo - https://kaiberteam.atlassian.net/browse/ENG-2364
      // Flows will eventually always carry their own width/height values, so
      // this fallback will no longer be needed. Resolve into a copy rather
      // than mutating the caller's object.
      const { width, height } = extractWidthAndHeightFromMediaForm(formData)
      const formValues: MediaForm = { ...formData, width, height }

      const flow: Flow = {
        modelType,
        elements,
        formValues,
        name,
      }

      const tempMedia: Media = {
        mediaId: makeTempMediaId(),
        status: Status.Pending,
        type: MEDIA_TYPES[modelType],
        createdAt,
        flow,
        ...formValues,
      }
      mediaStore.setMedia(tempMedia)
      // lastTempId lets only the most recent generation mark the parent Done.
      updateNodeDataRef.current(parentId, {
        ...(parent.data ?? {}),
        status: Status.Pending,
        lastTempId: tempId,
      })

      // CollectionIds of any CollectionNode connected to the current Flow node
      const connectedCollectionIds = new Set(
        getOutgoers(parent, getNodes(), getEdges())
          .filter((node) => node.type === NodeType.CollectionNode)
          .map((node) => node.data.collectionId as string),
      )

      if (connectedCollectionIds.size > 0) {
        // Prepend the temp media to every connected collection node.
        connectedCollectionIds.forEach((collectionId: string) => {
          updateCollectionLocally(
            collectionId,
            'ADD_MEDIA' as CollectionAction,
            {
              mediaIds: [tempMedia.mediaId],
            },
          )
        })
      } else {
        // No connected collection nodes: add a standalone node to the canvas.
        const tempNode: AllNodeTypes = {
          id: tempId,
          zIndex: getMaxNodeZIndex(),
          position: { x: tempNodeX, y: tempNodeY },
          data: {
            media: tempMedia,
          },
          type: NodeType.MediaNode,
        }
        addNodes(tempNode)
        trackNodeEvent(tempNode, AnalyticsEvent.NodeAdded, {
          nodeOrigin: NodeOrigin.GeneratedMedia,
        })
      }

      // TODO(review): a "MediaNodeProcessingStarted" marker was left here —
      // presumably a placeholder for an analytics event; confirm intent.

      // NOTE(review): if the request below rejects or returns no data, the
      // parent stays in Status.Pending and the placeholder is never resolved.
      let response
      switch (modelType) {
        case ModelType.FlipbookV1:
          response = await generateFlipbookV1(flow)
          break
        case ModelType.FluxImage:
          response = await generateFluxImage(flow)
          break
        case ModelType.ImageLab:
          response = await processMedia(flow)
          break
        case ModelType.InpaintV1:
          response = await generateInpaintImage(flow)
          break
        case ModelType.MotionV3:
          response = await generateMotionV3(flow)
          break
        case ModelType.TransformV3:
          response = await generateTransformV3(flow)
          break
        case ModelType.UpscaleV1:
          response = await generateUpscaleV1Image(flow)
          break
        default:
          console.warn(`No generation handler for model type: ${modelType}`)
      }

      if (!response?.data) {
        return
      }

      const updatedMedia = response.data.medias?.[0] || {}
      const media = { ...tempMedia, ...updatedMedia }

      trackEvent(AnalyticsEvent.MediaProcessed, {
        mediaType: media.type,
        mediaId: media.mediaId,
        modelType,
        nodeId: parentId,
      })

      try {
        mediaStore.signalTempMediaMatured(tempMedia, media)
        myLibraryStore.prependMediaId(media.mediaId)

        // Wait for all collection updates to complete
        await Promise.all(
          Array.from(connectedCollectionIds).map((collectionId) =>
            replaceAndPrependMedia({
              collectionId,
              tempMediaId: tempMedia.mediaId,
              newMediaId: media.mediaId,
            }),
          ),
        )

        // Re-read the parent: it may have changed, or been deleted, while the
        // request was in flight. Only the latest generation marks it Done.
        const latestParent = getInternalNode(parentId)
        if (latestParent?.data?.lastTempId === tempId) {
          updateNodeDataRef.current(parentId, {
            ...(latestParent.data ?? {}),
            status: Status.Done,
          })
        }
      } catch (error) {
        console.error('Error updating media in collection:', error)
      }
    },
    [
      getInternalNode,
      getNodes,
      getEdges,
      getMaxNodeZIndex,
      addNodes,
      trackNodeEvent,
      trackEvent,
      replaceAndPrependMedia,
      updateCollectionLocally,
    ],
  )
  return { createMediaNode }
}
