import Lottie from 'lottie-react'
import { useRef, useEffect, useState, useCallback } from 'react'
import { BiSolidBrush, BiUndo } from 'react-icons/bi'

// These icons disabled due to a memory problem with importing multiple icons sets from react-icons.
// The raw HTML is used instead for these two icons. See https://kaiberteam.atlassian.net/browse/ENG-1731.
// import { BsEraser } from 'react-icons/bs'
// import { SiTarget } from 'react-icons/si'
import Annotation from './Annotation'
import Slider from './Slider'
import { LABELS } from '../constants'
import LoadingWhite from '../images/lottie/loaderWhite.json'
import http from '../services/HttpService'

// Opaque hot pink used to paint brush strokes before they are merged into the mask.
const DRAWING_COLOR = 'rgba(255, 105, 180, 1)'
// Tag name passed to document.createElement for off-screen canvases.
const CANVAS = 'canvas'

// CSS cursor values already generated, keyed by pen size. The component calls
// getCursorStyle on every render while drawing mode is on (and a stroke re-renders
// on each pointer move), so without this cache every render allocated a canvas and
// PNG-encoded it.
const cursorStyleCache = new Map<number, string>()

/**
 * Builds a CSS `cursor` value that previews the brush footprint.
 *
 * The cursor image is a `2 * penSize` square with a ring of diameter `penSize`
 * stroked at its centre; the hotspot is that centre so the ring sits where strokes
 * land. Results are cached per pen size.
 *
 * @param penSize - Brush diameter in canvas pixels.
 * @returns A `cursor` declaration such as `url(data:image/png;base64,…) 25 25, auto`.
 */
const getCursorStyle = (penSize: number): string => {
  const cached = cursorStyleCache.get(penSize)
  if (cached !== undefined) {
    return cached
  }

  const cursorSize = penSize
  const cursorCanvas = document.createElement(CANVAS)
  cursorCanvas.width = cursorSize * 2
  cursorCanvas.height = cursorSize * 2

  const cursorCtx = cursorCanvas.getContext('2d')
  if (cursorCtx) {
    cursorCtx.beginPath()
    cursorCtx.arc(cursorSize, cursorSize, cursorSize / 2, 0, 2 * Math.PI)
    cursorCtx.strokeStyle = '#8DDE9C' // Hardcoded color value
    cursorCtx.lineWidth = 2
    cursorCtx.stroke()
    cursorCtx.closePath()
  }

  const cursorDataURL = cursorCanvas.toDataURL('image/png')
  const cursorOffset = cursorSize
  const cursorStyle = `url(${cursorDataURL}) ${cursorOffset} ${cursorOffset}, auto`

  cursorStyleCache.set(penSize, cursorStyle)
  return cursorStyle
}

/** Props accepted by the `ImageMask` editor. */
interface ImageMaskProps {
  /** URL of the image being masked; drawn on the base canvas and used to size every layer. */
  imageUrl: string
  /**
   * Raw image data uploaded to `/api/get_segmentation_masks` to obtain the selectable
   * segments (presumably the same picture as `imageUrl` — confirm with callers).
   */
  image: Blob
  /**
   * Receives the composited mask as a PNG after every edit.
   * NOTE(review): the component passes `null` here when the mask is empty, despite the
   * `File` parameter type — callers should handle that case.
   */
  onFinalMask: (maskFile: File) => void
  /** Invoked by the Save button. */
  handleCloseImageMaskModal: () => void
  /** Extra classes appended to the outermost wrapper; defaults to an empty string. */
  containerClassName?: string
  /** Previously saved mask to preload onto the mask canvas when `imageUrl` changes. */
  defaultMask?: File
}

const ImageMask: React.FC<ImageMaskProps> = ({
  imageUrl,
  image,
  onFinalMask,
  defaultMask,
  handleCloseImageMaskModal,
  containerClassName = '',
}) => {
  // State variables
  // Segmentation masks as base64 PNG payloads (no `data:` prefix), with pure-black
  // pixels already made transparent by the fetch effect.
  const [segmentationMasks, setSegmentationMasks] = useState<string[]>([])
  // Index into `segmentationMasks` of the mask under the mouse, or null for none.
  const [highlightedMask, setHighlightedMask] = useState<number | null>(null)
  // One RGBA byte array per segmentation mask, rasterised at `canvasSize`; used for
  // pointer hit-testing by flat pixel index.
  const [maskData, setMaskData] = useState<Uint8Array[]>([])
  // True while the segmentation request is in flight; shows the loading overlay.
  const [isProcessing, setIsProcessing] = useState(false)
  // True for brush/eraser input; false for segment-selection mode.
  const [isDrawingMode, setIsDrawingMode] = useState(false)
  // True between pointer-down and pointer-up of a brush/eraser stroke.
  const [isDrawing, setIsDrawing] = useState(false)
  // Brush/eraser diameter in canvas pixels; 25 matches the Slider's defaultOption.
  const [penSize, setPenSize] = useState(25)
  // Last pointer position of the active stroke; the next segment starts here.
  const [lastPoint, setLastPoint] = useState<{ x: number; y: number } | null>(
    null,
  )
  // Display size derived from the base image: its longer side is
  // min(window.innerWidth, 512) px, aspect ratio preserved. 0x0 until it loads.
  const [canvasSize, setCanvasSize] = useState({ width: 0, height: 0 })
  // When true (the eraser button also enables drawing mode), strokes cut into the
  // mask canvas instead of painting.
  const [isEraserMode, setIsEraserMode] = useState(false)
  // Undo stack of mask-canvas snapshots as PNG data URLs; the last entry is restored first.
  const [canvasHistory, setCanvasHistory] = useState<string[]>([])

  // Refs for canvas elements
  // Mask canvas: holds the composited mask and is the layer with the input handlers.
  const canvasRef = useRef<HTMLCanvasElement>(null)
  // Overlay showing the hovered segmentation mask in translucent white.
  const highlightedMaskCanvasRef = useRef<HTMLCanvasElement>(null)
  // The base image, drawn whenever `imageUrl` changes.
  const baseCanvasRef = useRef<HTMLCanvasElement>(null)
  // Scratch layer for an in-progress brush stroke; merged into the mask canvas when
  // the stroke ends.
  const drawingCanvasRef = useRef<HTMLCanvasElement>(null)

  /**
   * Wipes the mask canvas and drops the undo history, then re-runs `unifyAlpha` so
   * the parent's `onFinalMask` is invoked again for the cleared canvas.
   * Does nothing when the mask canvas or its 2D context is unavailable.
   */
  const resetCanvas = () => {
    const maskCanvas = canvasRef.current
    const maskCtx = maskCanvas?.getContext('2d')
    if (!maskCanvas || !maskCtx) {
      return
    }

    const { width, height } = maskCanvas
    maskCtx.clearRect(0, 0, width, height)

    // A reset cannot be undone: the history is discarded along with the pixels.
    setCanvasHistory([])
    unifyAlpha()
  }

  /**
   * Normalises the mask canvas so every painted pixel has the same opacity, then
   * reports the composited mask to the parent via `onFinalMask`.
   *
   * Any pixel with non-zero alpha becomes fully opaque (binarised), and the result is
   * redrawn at 70% opacity. Overlapping strokes/segments therefore never look darker
   * than a single one. Does nothing when the mask canvas is unavailable or has no area.
   */
  const unifyAlpha = () => {
    const canvas = canvasRef.current
    const ctx = canvas?.getContext('2d')
    if (!canvas || !ctx || canvas.width === 0 || canvas.height === 0) {
      return
    }

    // Copy the current mask onto a scratch canvas so it can be rewritten in place.
    const tempCanvas = document.createElement(CANVAS)
    tempCanvas.width = canvas.width
    tempCanvas.height = canvas.height
    const tempCtx = tempCanvas.getContext('2d')
    if (!tempCtx) {
      return
    }
    tempCtx.drawImage(canvas, 0, 0)

    // Binarize alpha: any coverage becomes fully opaque, none stays transparent.
    const imageData = tempCtx.getImageData(
      0,
      0,
      tempCanvas.width,
      tempCanvas.height,
    )
    const data = imageData.data
    for (let i = 3; i < data.length; i += 4) {
      data[i] = data[i] > 0 ? 255 : 0
    }
    tempCtx.putImageData(imageData, 0, 0)

    // Redraw at a fixed 70% opacity. save/restore keeps these settings local. The
    // composite op is forced to 'source-over' because eraser strokes set
    // 'destination-out' on this context, and redrawing under it after the clear
    // would leave the mask empty.
    ctx.save()
    ctx.clearRect(0, 0, canvas.width, canvas.height)
    ctx.globalCompositeOperation = 'source-over'
    ctx.globalAlpha = 0.7
    ctx.drawImage(tempCanvas, 0, 0)
    ctx.restore()

    // compositeFinalMask resolves with null for an empty mask; that value is
    // forwarded to the parent as-is.
    compositeFinalMask()
      .then((finalMaskFile) => {
        onFinalMask(finalMaskFile)
      })
      .catch((error) => {
        console.error('Error compositing final mask:', error)
      })
  }

  // Push a PNG data-URL snapshot of the mask canvas onto the undo history.
  const saveCanvasState = () => {
    const canvas = canvasRef.current
    if (!canvas) {
      return
    }
    const dataURL = canvas.toDataURL()
    // Functional update: appending to the render-time `canvasHistory` closure would
    // drop snapshots if this runs more than once before React re-renders.
    setCanvasHistory((history) => [...history, dataURL])
  }

  /**
   * Restores the newest snapshot in `canvasHistory` onto the mask canvas, removes it
   * from the history, and re-runs `unifyAlpha` so the parent receives the restored
   * mask. `handleInputStart` saves a snapshot on every pointer-down, so this returns
   * the canvas to its state before the most recent edit. No-op with an empty history.
   */
  const undo = () => {
    if (canvasHistory.length === 0) {
      return
    }
    const canvas = canvasRef.current
    const ctx = canvas?.getContext('2d')
    if (!canvas || !ctx) {
      return
    }

    const previousState = canvasHistory[canvasHistory.length - 1]
    const image = new Image()
    image.onload = () => {
      ctx.globalCompositeOperation = 'source-over'
      ctx.clearRect(0, 0, canvas.width, canvas.height)
      ctx.drawImage(image, 0, 0, canvas.width, canvas.height)
      // Pop only if the top entry is still the one restored. A rapid second undo
      // (same render, same snapshot) then cannot discard a state it never showed.
      setCanvasHistory((history) =>
        history[history.length - 1] === previousState
          ? history.slice(0, -1)
          : history,
      )
      // Must run after the snapshot is drawn. Calling it synchronously, before the
      // image loaded, normalised and reported the pre-undo mask to the parent.
      unifyAlpha()
    }
    image.onerror = () => {
      console.error('Error restoring previous mask state')
    }
    image.src = previousState
  }

  /**
   * Encodes the mask canvas as a PNG `File` named `mask.png`.
   *
   * Resolves with `null` when the mask has no painted pixels. It also resolves with
   * `null` when the canvas/context is unavailable or has no area, or when PNG
   * encoding yields no blob. Previously those cases either never settled or produced
   * a bogus file.
   */
  const compositeFinalMask = useCallback(() => {
    return new Promise<File | null>((resolve) => {
      const canvas = canvasRef.current
      const ctx = canvas?.getContext('2d')
      // getImageData throws on a zero-sized region, so treat that as "no mask".
      if (!canvas || !ctx || canvas.width === 0 || canvas.height === 0) {
        resolve(null)
        return
      }

      // The mask is empty when every pixel is fully transparent.
      const { data } = ctx.getImageData(0, 0, canvas.width, canvas.height)
      let isEmpty = true
      for (let i = 3; i < data.length; i += 4) {
        if (data[i] !== 0) {
          isEmpty = false
          break
        }
      }

      if (isEmpty) {
        resolve(null)
        return
      }

      canvas.toBlob((blob) => {
        // toBlob passes null when encoding fails; wrapping null would produce a
        // File whose content is the literal text "null".
        if (!blob) {
          resolve(null)
          return
        }
        resolve(new File([blob], 'mask.png', { type: 'image/png' }))
      }, 'image/png')
    })
  }, [])

  // Rasterise every segmentation mask at the current canvas size. The pointer
  // hit-tests in handleCanvasClick / handleCanvasMouseMove then read raw RGBA bytes
  // by flat pixel index.
  useEffect(() => {
    // canvasSize is 0x0 until the base image has loaded, and getImageData throws an
    // IndexSizeError for a zero-sized region, so wait for real dimensions.
    if (canvasSize.width < 1 || canvasSize.height < 1) {
      return undefined
    }

    // Set by the cleanup when the masks or size change again (or on unmount), so a
    // slower, outdated run cannot overwrite a newer result.
    let cancelled = false

    // Decode one base64 PNG mask and return its RGBA bytes scaled to canvasSize.
    const rasteriseMask = async (maskBase64: string): Promise<Uint8Array> => {
      const maskImage = new Image()
      await new Promise<void>((resolve, reject) => {
        maskImage.onload = () => resolve()
        maskImage.onerror = () =>
          reject(new Error('Failed to decode segmentation mask'))
        maskImage.src = `data:image/png;base64,${maskBase64}`
      })

      const maskCanvas = document.createElement(CANVAS)
      maskCanvas.width = canvasSize.width
      maskCanvas.height = canvasSize.height

      const maskCtx = maskCanvas.getContext('2d')
      maskCtx?.drawImage(
        maskImage,
        0,
        0,
        canvasSize.width,
        canvasSize.height,
      )

      const imageData = maskCtx?.getImageData(
        0,
        0,
        maskCanvas.width,
        maskCanvas.height,
      )

      return new Uint8Array(imageData?.data.buffer ?? [])
    }

    const precomputeMaskData = async () => {
      try {
        const maskDataArray = await Promise.all(
          segmentationMasks.map(rasteriseMask),
        )
        if (!cancelled) {
          setMaskData(maskDataArray)
        }
      } catch (error) {
        console.error('Error preparing segmentation mask data:', error)
      }
    }

    void precomputeMaskData()

    return () => {
      cancelled = true
    }
  }, [segmentationMasks, canvasSize])

  // Fetch segmentation masks whenever the source image changes.
  useEffect(() => {
    // Set by the cleanup when `image` changes or the component unmounts. A slow
    // response for an old image then cannot overwrite state for the current one.
    let cancelled = false

    // Load one mask and return it as a base64 PNG payload (no `data:` prefix), with
    // every pure-black pixel made fully transparent.
    const makeBlackTransparent = async (maskSrc: string): Promise<string> => {
      const maskImage = new Image()
      await new Promise<void>((resolve, reject) => {
        maskImage.onload = () => resolve()
        // Without this, a mask that fails to load left the promise pending forever,
        // so `finally` never ran and the loading overlay stayed up.
        maskImage.onerror = () =>
          reject(new Error('Failed to load segmentation mask'))
        maskImage.src = maskSrc
      })

      const maskCanvas = document.createElement(CANVAS)
      maskCanvas.width = maskImage.width
      maskCanvas.height = maskImage.height
      const maskCtx = maskCanvas.getContext('2d')
      if (maskCtx) {
        maskCtx.drawImage(maskImage, 0, 0)
        const imageData = maskCtx.getImageData(
          0,
          0,
          maskCanvas.width,
          maskCanvas.height,
        )
        const { data } = imageData
        for (let i = 0; i < data.length; i += 4) {
          const red = data[i]
          const green = data[i + 1]
          const blue = data[i + 2]
          if (red === 0 && green === 0 && blue === 0) {
            data[i + 3] = 0 // Set alpha to 0 for black pixels
          }
        }
        maskCtx.putImageData(imageData, 0, 0)
      }
      return maskCanvas.toDataURL('image/png').split(',')[1]
    }

    const fetchSegmentationMasks = async () => {
      try {
        setIsProcessing(true)
        const data = new FormData()
        data.append('image', image)
        const response = await http.post('/api/get_segmentation_masks', data, {
          headers: {
            'Content-Type': 'multipart/form-data',
          },
        })
        const maskUrls = response.data?.masks

        if (maskUrls) {
          // Modify the masks to convert black channels to transparency
          const modifiedMasks = await Promise.all(
            maskUrls.map(makeBlackTransparent),
          )
          if (!cancelled) {
            setSegmentationMasks(modifiedMasks)
          }
        }
      } catch (error) {
        console.error('Error fetching segmentation masks:', error)
      } finally {
        // A superseding run has already set the flag for the current image.
        if (!cancelled) {
          setIsProcessing(false)
        }
      }
    }

    void fetchSegmentationMasks()

    return () => {
      cancelled = true
    }
  }, [image])

  // Size every canvas layer from the base image whenever the image URL changes,
  // then paint `defaultMask` (if provided) onto the freshly sized mask canvas.
  useEffect(() => {
    const canvas = baseCanvasRef.current
    const ctx = canvas?.getContext('2d')
    if (!canvas || !ctx || !imageUrl) {
      return undefined
    }

    // Set by the cleanup so loads started for a previous URL (or before unmount)
    // are ignored.
    let cancelled = false

    // Draw `mask` onto the mask canvas and make it the only undo-history entry.
    // Call this only after the layers are resized: assigning canvas.width/height
    // clears a canvas. A default mask that finished loading before the base image
    // used to be wiped by that resize.
    const loadDefaultMask = (mask: File) => {
      // For some reason, the canvas won't render blob URLs (format blob:http://localhost:3000/...),
      // so the default mask is converted to a data URL with FileReader before it is drawn.
      const reader = new FileReader()
      reader.onload = () => {
        if (cancelled) {
          return
        }
        const defaultMaskImage = new Image()
        defaultMaskImage.onload = () => {
          const maskCanvas = canvasRef.current
          const maskCtx = maskCanvas?.getContext('2d')
          if (cancelled || !maskCanvas || !maskCtx) {
            return
          }
          maskCtx.clearRect(0, 0, maskCanvas.width, maskCanvas.height)
          maskCtx.globalCompositeOperation = 'source-over'
          maskCtx.drawImage(
            defaultMaskImage,
            0,
            0,
            maskCanvas.width,
            maskCanvas.height,
          )
          setCanvasHistory([maskCanvas.toDataURL()])
        }
        defaultMaskImage.src = reader.result as string
      }
      reader.readAsDataURL(mask)
    }

    const baseImage = new Image()
    baseImage.onload = () => {
      if (cancelled) {
        return
      }

      // Fit the longer side to the screen width, capped at 512 px, keeping the
      // base image's aspect ratio.
      const aspectRatio = baseImage.width / baseImage.height
      const maxDimension = Math.min(window.innerWidth, 512)
      let newWidth = maxDimension
      let newHeight = maxDimension
      if (aspectRatio > 1) {
        newHeight = newWidth / aspectRatio
      } else {
        newWidth = newHeight * aspectRatio
      }

      setCanvasSize({ width: newWidth, height: newHeight })

      // Resize the base canvas and draw the scaled image on it.
      canvas.width = newWidth
      canvas.height = newHeight
      ctx.drawImage(baseImage, 0, 0, newWidth, newHeight)

      // Every other layer must match the base canvas so pixel coordinates line up.
      const layers = [
        canvasRef.current,
        drawingCanvasRef.current,
        highlightedMaskCanvasRef.current,
      ]
      for (const layer of layers) {
        if (layer) {
          layer.width = newWidth
          layer.height = newHeight
        }
      }

      if (defaultMask) {
        loadDefaultMask(defaultMask)
      }
    }
    baseImage.onerror = () => {
      console.error('Error loading image for masking:', imageUrl)
    }
    baseImage.src = imageUrl

    return () => {
      cancelled = true
    }
    // defaultMask omitted from dependencies to avoid infinite loop
    // eslint-disable-next-line react-hooks/exhaustive-deps
  }, [imageUrl])

  // Draw the hovered segmentation mask as a translucent white overlay on its own
  // canvas. Returns a function that cancels a still-pending mask image load.
  const drawHighlightedMask = useCallback(() => {
    const canvas = highlightedMaskCanvasRef.current
    const ctx = canvas?.getContext('2d')
    if (!canvas || !ctx) {
      return undefined
    }

    const newWidth = canvas.width
    const newHeight = canvas.height

    // Clear unconditionally. This used to be skipped in drawing mode, so a
    // highlight left over from selection mode could stay on screen while drawing.
    ctx.clearRect(0, 0, newWidth, newHeight)

    if (isDrawingMode || highlightedMask === null) {
      return undefined
    }

    const highlightedMaskBase64 = segmentationMasks[highlightedMask]
    if (highlightedMaskBase64 === undefined) {
      return undefined
    }

    let cancelled = false
    const highlightedMaskImage = new Image()
    highlightedMaskImage.onload = () => {
      if (cancelled) {
        return
      }
      ctx.save()
      ctx.clearRect(0, 0, newWidth, newHeight)
      // Draw the highlighted mask image on the canvas
      ctx.drawImage(highlightedMaskImage, 0, 0, newWidth, newHeight)
      // Recolour only the mask's visible pixels
      ctx.globalCompositeOperation = 'source-in'
      ctx.fillStyle = 'rgba(255, 255, 255, 0.7)' // Translucent white
      ctx.fillRect(0, 0, newWidth, newHeight)
      ctx.restore()
    }
    highlightedMaskImage.src = `data:image/png;base64,${highlightedMaskBase64}`

    return () => {
      cancelled = true
    }
    // isDrawingMode is a real dependency; the effect below only draws and sets no
    // state, so including it cannot cause a render loop.
  }, [segmentationMasks, highlightedMask, isDrawingMode])

  // Redraw whenever the hovered mask, the masks, or the mode changes. The cleanup
  // cancels a highlight image that is still decoding. Without it, moving quickly
  // across segments could paint stale highlights, because each decode finished
  // after the synchronous clear for the next hover.
  useEffect(() => drawHighlightedMask(), [drawHighlightedMask])

  /**
   * Toggles segmentation mask `maskIndex` on the mask canvas behind `ctx`.
   *
   * If the canvas already has paint at (x, y), the mask's footprint is erased.
   * Otherwise it is painted in translucent hot pink. Either way `unifyAlpha` then
   * normalises the canvas and reports the new mask to the parent.
   *
   * This is deliberately a plain function rather than a `useCallback` memoised on
   * `segmentationMasks`. The memoised version kept the `unifyAlpha` (and therefore
   * `onFinalMask`) from the render that created it. It is only called from event
   * handlers, so recreating it each render costs nothing.
   */
  const drawMasksOnCanvas = (
    ctx: CanvasRenderingContext2D,
    maskIndex: number,
    x: number,
    y: number,
  ) => {
    const { width: newWidth, height: newHeight } = ctx.canvas
    const maskBase64 = segmentationMasks[maskIndex]
    const maskImage = new Image()

    maskImage.onload = () => {
      // Scratch canvas holding the mask scaled to the mask canvas size.
      const tempCanvas = document.createElement(CANVAS)
      tempCanvas.width = newWidth
      tempCanvas.height = newHeight
      const tempCtx = tempCanvas.getContext('2d')
      if (!tempCtx) {
        return
      }
      tempCtx.imageSmoothingEnabled = false // Disable image smoothing
      tempCtx.drawImage(maskImage, 0, 0, newWidth, newHeight)

      // Existing paint under the click means this mask is already selected.
      const isSelected = ctx.getImageData(x, y, 1, 1).data[3] > 0

      ctx.save()
      if (isSelected) {
        // Remove the mask's footprint from the main canvas.
        ctx.globalCompositeOperation = 'destination-out'
        ctx.drawImage(tempCanvas, 0, 0)
      } else {
        // Tint the mask's visible pixels, then paint it onto the main canvas.
        tempCtx.globalCompositeOperation = 'source-in'
        tempCtx.fillStyle = 'rgba(255, 105, 180, 0.7)' // Transparent hot pink
        tempCtx.fillRect(0, 0, newWidth, newHeight)

        ctx.globalCompositeOperation = 'source-over'
        ctx.drawImage(tempCanvas, 0, 0)
      }
      ctx.restore()

      unifyAlpha()
    }

    maskImage.src = `data:image/png;base64,${maskBase64}`
  }

  // Mouse-down / touch-start on the mask canvas. Snapshot the canvas for undo, then
  // either begin a brush/eraser stroke or toggle the segment under the pointer.
  const handleInputStart = (
    event:
      | React.MouseEvent<HTMLCanvasElement>
      | React.TouchEvent<HTMLCanvasElement>,
  ) => {
    saveCanvasState()
    const point = getInputCoordinates(event)

    if (!isDrawingMode) {
      handleCanvasClick(point.x, point.y)
      return
    }

    setIsDrawing(true)
    startDrawing(point.x, point.y)
  }

  // Convert a mouse event, or the first entry of a touch event's `touches`, into
  // integer CSS-pixel offsets from the mask canvas's top-left corner.
  // Returns (0, 0) when the canvas is not mounted.
  const getInputCoordinates = (
    event:
      | React.MouseEvent<HTMLCanvasElement>
      | React.TouchEvent<HTMLCanvasElement>,
  ): { x: number; y: number } => {
    const canvas = canvasRef.current
    if (!canvas) {
      return { x: 0, y: 0 }
    }

    const { left, top } = canvas.getBoundingClientRect()
    const pointer = 'touches' in event ? event.touches[0] : event

    return {
      x: Math.floor(pointer.clientX - left),
      y: Math.floor(pointer.clientY - top),
    }
  }

  // Selection mode: toggle the first segmentation mask that covers canvas pixel (x, y).
  const handleCanvasClick = (x: number, y: number) => {
    const canvas = canvasRef.current
    if (!canvas) {
      return
    }

    // Coordinates outside the pixel grid would wrap into a neighbouring row of the
    // row-major index below (or go negative) and hit the wrong mask. This can happen
    // with a second touch, since getInputCoordinates reads touches[0].
    if (x < 0 || y < 0 || x >= canvas.width || y >= canvas.height) {
      return
    }

    const index = y * canvas.width + x
    // A non-zero red byte marks a covered pixel; black mask pixels were made
    // transparent when the masks were loaded.
    const maskIndex = maskData.findIndex((data) => data[index * 4] > 0)
    if (maskIndex === -1) {
      return
    }

    const ctx = canvas.getContext('2d')
    if (ctx) {
      drawMasksOnCanvas(ctx, maskIndex, x, y)
    }
  }

  // Mouse-up / touch-end on the mask canvas: commit the active stroke, if any.
  // preventDefault runs for every event. Presumably it is there so touchend does
  // not also produce emulated mouse events that would re-enter these handlers.
  const handleInputEnd = (
    event:
      | React.MouseEvent<HTMLCanvasElement>
      | React.TouchEvent<HTMLCanvasElement>,
  ) => {
    event.preventDefault()

    const strokeActive = isDrawingMode && isDrawing
    if (!strokeActive) {
      return
    }

    setIsDrawing(false)
    finishDrawing()
  }

  // Mouse/touch move on the mask canvas: extend the active stroke. When no stroke is
  // active, mouse moves (not touch) update the hover highlight instead.
  const handleInputMove = (
    event:
      | React.MouseEvent<HTMLCanvasElement>
      | React.TouchEvent<HTMLCanvasElement>,
  ) => {
    const isStroking = isDrawingMode && isDrawing

    if (isStroking) {
      const point = getInputCoordinates(event)
      continueDrawing(point.x, point.y)
      return
    }

    if (event.type === 'mousemove') {
      handleCanvasMouseMove(event as React.MouseEvent<HTMLCanvasElement>)
    }
  }

  // The pointer left the mask canvas: drop the hover highlight and end any stroke in
  // progress. Releasing the button outside the canvas never reaches handleInputEnd.
  // Without this, `isDrawing` stayed true, so the pen kept drawing on re-entry with
  // the button up, and a brush stroke stayed uncommitted on the drawing layer.
  const handleMouseOut = () => {
    setHighlightedMask(null)
    if (isDrawingMode && isDrawing) {
      setIsDrawing(false)
      finishDrawing()
    }
  }

  // Begin a stroke at (x, y): configure the target context and stamp a round dot.
  // The eraser cuts straight into the mask canvas. The brush paints on the separate
  // drawing layer, which finishDrawing later merges into the mask canvas.
  const startDrawing = (x: number, y: number) => {
    const target = isEraserMode ? canvasRef.current : drawingCanvasRef.current
    const ctx = target?.getContext('2d')
    if (!ctx) {
      return
    }

    const inkColor = isEraserMode ? 'rgba(0, 0, 0, 1)' : DRAWING_COLOR
    ctx.strokeStyle = inkColor
    ctx.fillStyle = inkColor
    ctx.globalCompositeOperation = isEraserMode
      ? 'destination-out'
      : 'source-over'
    ctx.lineWidth = penSize
    ctx.lineCap = 'round'

    // Stamp a dot so a click without movement still leaves a mark.
    ctx.beginPath()
    ctx.arc(x, y, penSize / 2, 0, 2 * Math.PI)
    ctx.fill()
    ctx.closePath()

    // Remember where the stroke starts.
    setLastPoint({ x, y })
  }

  /**
   * Extend the active stroke from `lastPoint` to (x, y) on the same layer that
   * startDrawing used (mask canvas for the eraser, drawing layer for the brush),
   * then record (x, y) as the new `lastPoint`. No-op without an active `lastPoint`.
   *
   * The previous version "interpolated" by stroking `ceil(distance / 5)` segments
   * that all started at `lastPoint`. Each was a prefix of the final one, so they
   * covered the same pixels as a single segment while costing O(distance) strokes
   * per move event.
   */
  const continueDrawing = (x: number, y: number) => {
    const canvas = isEraserMode ? canvasRef.current : drawingCanvasRef.current
    const ctx = canvas?.getContext('2d')
    if (!ctx || !lastPoint) {
      return
    }

    const { x: lastX, y: lastY } = lastPoint

    if (isEraserMode) {
      ctx.strokeStyle = 'rgba(0, 0, 0, 1)'
      ctx.globalCompositeOperation = 'destination-out'
    } else {
      ctx.strokeStyle = DRAWING_COLOR
      ctx.globalCompositeOperation = 'source-over'
    }
    // Re-apply the pen geometry in case the context was reset (e.g. a resize)
    // since startDrawing configured it.
    ctx.lineWidth = penSize
    ctx.lineCap = 'round'

    // A zero-length move drew nothing before (0 steps); keep that behaviour.
    if (x !== lastX || y !== lastY) {
      ctx.beginPath()
      ctx.moveTo(lastX, lastY)
      ctx.lineTo(x, y)
      ctx.stroke()
    }

    setLastPoint({ x, y })
  }

  /**
   * End the active stroke. The sequence is:
   * 1. stamp a final dot at `lastPoint` (when known);
   * 2. merge the brush layer into the mask canvas and clear the brush layer;
   * 3. run unifyAlpha, which reports the updated mask to the parent.
   *
   * The merge no longer depends on `lastPoint`. Previously a missing `lastPoint`
   * skipped everything, stranding brush pixels on the drawing layer and leaving the
   * parent un-notified.
   */
  const finishDrawing = () => {
    const canvas = canvasRef.current
    const drawingCanvas = drawingCanvasRef.current
    const ctx = canvas?.getContext('2d')
    const drawingCtx = drawingCanvas?.getContext('2d')
    if (!canvas || !drawingCanvas || !ctx || !drawingCtx) {
      return
    }

    // Stamp a round dot of the current pen size at the stroke's end point.
    const stampDot = (
      target: CanvasRenderingContext2D,
      point: { x: number; y: number },
    ) => {
      target.beginPath()
      target.arc(point.x, point.y, penSize / 2, 0, 2 * Math.PI)
      target.fill()
      target.closePath()
    }

    if (isEraserMode) {
      if (lastPoint) {
        ctx.fillStyle = 'rgba(0, 0, 0, 1)'
        ctx.globalCompositeOperation = 'destination-out'
        stampDot(ctx, lastPoint)
      }
      // Return the mask context to normal painting after erasing.
      ctx.globalCompositeOperation = 'source-over'
    } else {
      if (lastPoint) {
        drawingCtx.fillStyle = DRAWING_COLOR
        drawingCtx.globalCompositeOperation = 'source-over'
        stampDot(drawingCtx, lastPoint)
      }

      // Force normal compositing before the merge. If an earlier eraser stroke left
      // the mask context in 'destination-out', the brush layer would erase instead.
      ctx.globalCompositeOperation = 'source-over'
      ctx.drawImage(drawingCanvas, 0, 0)

      // Clear the drawing canvas for the next stroke.
      drawingCtx.clearRect(0, 0, drawingCanvas.width, drawingCanvas.height)
    }

    setLastPoint(null)

    unifyAlpha()
  }

  // Mouse hover outside an active stroke (in either mode): highlight the first
  // segmentation mask under the cursor, or none.
  const handleCanvasMouseMove = (
    event: React.MouseEvent<HTMLCanvasElement>,
  ) => {
    const canvas = canvasRef.current
    if (!canvas) {
      return
    }

    const rect = canvas.getBoundingClientRect()
    const x = Math.floor(event.clientX - rect.left)
    const y = Math.floor(event.clientY - rect.top)

    // Outside the pixel grid (e.g. when the element is CSS-transformed) the flat
    // row-major index would wrap into another row and highlight the wrong mask.
    const inBounds = x >= 0 && y >= 0 && x < canvas.width && y < canvas.height
    const index = y * canvas.width + x
    const maskIndex = inBounds
      ? maskData.findIndex((data) => data[index * 4] > 0)
      : -1

    setHighlightedMask(maskIndex !== -1 ? maskIndex : null)
  }

  return (
    <div className={`flex flex-col items-center ${containerClassName}`}>
      <div
        className='relative mb-[95px] min-w-[320px]'
        style={{
          width: canvasSize.width,
          height: canvasSize.height,
        }}
      >
        {/* Canvas Container: absolutely positioned layers stacked by z-index */}
        <div className='flex justify-center w-full h-full rounded-lg'>
          {/* Base Image Canvas */}
          <canvas ref={baseCanvasRef} className='absolute z-0 rounded-lg' />
          {/* Highlighted Mask Canvas (hover overlay) */}
          <canvas
            ref={highlightedMaskCanvasRef}
            className='absolute z-10 rounded-lg'
          />
          {/* Main Canvas for Drawing and Selecting Masks; the only layer with input handlers */}
          <canvas
            ref={canvasRef}
            onMouseDown={handleInputStart}
            onMouseMove={handleInputMove}
            onMouseUp={handleInputEnd}
            onMouseOut={handleMouseOut}
            onTouchStart={handleInputStart}
            onTouchMove={handleInputMove}
            onTouchEnd={handleInputEnd}
            className={`absolute z-20 rounded-lg`}
            style={{
              cursor: isDrawingMode ? getCursorStyle(penSize) : 'default',
              // While drawing, stop the browser from scrolling/panning the page
              // under the finger. The touch handlers never call preventDefault,
              // and React 17+ registers touch listeners as passive anyway.
              touchAction: isDrawingMode ? 'none' : 'auto',
            }}
          />
          {/* In-progress brush stroke layer; ignores pointer events so input reaches the main canvas */}
          <canvas
            ref={drawingCanvasRef}
            className='absolute z-30 rounded-lg pointer-events-none'
            width={canvasSize.width}
            height={canvasSize.height}
          />
          {/* Loading Overlay */}
          {isProcessing && (
            <div
              className='absolute rounded-lg'
              style={{
                top: 0,
                left: 0,
                width: '100%',
                height: '100%',
                display: 'flex',
                flexDirection: 'column',
                justifyContent: 'center',
                alignItems: 'center',
                backgroundColor: 'rgba(0, 0, 0, 0.5)',
                zIndex: 50, // Above every canvas layer (z-0 to z-30), so it also blocks canvas input
              }}
            >
              <Lottie
                animationData={LoadingWhite}
                style={{ width: 100, height: 100 }}
              />
              <span className='text-white mt-2'>Preparing your image...</span>
            </div>
          )}
        </div>
        {/* Toolbar */}
        <div className='w-full'>
          {/* First Row */}
          <div className='flex justify-between mb-2.5 py-4 px-6 sm:px-0'>
            {/* Left Column */}
            <div className='flex items-center justify-center'>
              {/* Drawing Mode Button */}
              <Annotation
                text={LABELS.MOTION_MASK}
                hideAnnotation={!isDrawingMode || isEraserMode}
              >
                <div
                  className={`border-2 rounded-lg p-1.5 mr-2.5 ${
                    isDrawingMode && !isEraserMode
                      ? 'bg-kaiberGreen text-black border-kaiberGreen'
                      : 'bg-transparent text-k2-gray-200 border-k2-gray-200'
                  }`}
                >
                  {/* Icon-only buttons get an aria-label; type='button' avoids submitting an enclosing form */}
                  <button
                    type='button'
                    aria-label='Brush'
                    onClick={() => {
                      setIsDrawingMode(true)
                      setIsEraserMode(false)
                    }}
                    className={`bg-none border-none cursor-pointer outline-none flex items-center justify-center w-7.5 h-7.5`}
                  >
                    <BiSolidBrush size={20} />
                  </button>
                </div>
              </Annotation>
              {/* Selection Mode Button */}
              <Annotation text='Select Segments' hideAnnotation={isDrawingMode}>
                <div
                  className={`border-2 rounded-lg p-1.5 mr-2.5 ${
                    !isDrawingMode && !isEraserMode
                      ? 'bg-kaiberGreen text-black border-kaiberGreen'
                      : 'bg-transparent text-k2-gray-200 border-k2-gray-200'
                  }`}
                >
                  <button
                    type='button'
                    aria-label='Select segments'
                    onClick={() => {
                      setIsDrawingMode(false)
                      setIsEraserMode(false)
                    }}
                    className='bg-none border-none cursor-pointer outline-none flex items-center justify-center w-7.5 h-7.5'
                  >
                    {/* <SiTarget size={20} /> */}
                    <svg
                      stroke='currentColor'
                      fill='currentColor'
                      aria-hidden='true'
                      viewBox='0 0 24 24'
                      height='20'
                      width='20'
                      xmlns='http://www.w3.org/2000/svg'
                    >
                      <path d='M12.0005 0C18.627 0 24 5.373 24 12.0005 24 18.627 18.627 24 11.9995 24 5.373 24 0 18.627 0 11.9995 0 5.373 5.373 0 12.0005 0zm0 19.826a7.8265 7.8265 0 10-.001-15.652C7.7133 4.2246 4.2653 7.7136 4.2653 12c0 4.2864 3.448 7.7754 7.7342 7.826h.001zm0-3.9853a3.8402 3.8402 0 110-7.6803c2.1204.0006 3.839 1.7197 3.839 3.8401s-1.7186 3.8396-3.839 3.8402z'></path>
                    </svg>
                  </button>
                </div>
              </Annotation>
              {/* Eraser Mode Button */}
              <Annotation
                text='Eraser'
                hideAnnotation={!isDrawingMode || !isEraserMode}
              >
                <div
                  className={`border-2 rounded-lg p-1.5 mr-2.5 ${
                    isEraserMode
                      ? 'bg-kaiberGreen text-black border-kaiberGreen'
                      : 'bg-transparent text-k2-gray-200 border-k2-gray-200'
                  }`}
                >
                  <button
                    type='button'
                    aria-label='Eraser'
                    onClick={() => {
                      setIsDrawingMode(true)
                      setIsEraserMode(!isEraserMode)
                    }}
                    className='bg-none border-none cursor-pointer outline-none flex items-center justify-center w-7.5 h-7.5'
                  >
                    {/* <BsEraser size={20} /> */}
                    <svg
                      stroke='currentColor'
                      fill='currentColor'
                      strokeWidth='0'
                      aria-hidden='true'
                      viewBox='0 0 16 16'
                      height='20'
                      width='20'
                      xmlns='http://www.w3.org/2000/svg'
                    >
                      <path d='M8.086 2.207a2 2 0 0 1 2.828 0l3.879 3.879a2 2 0 0 1 0 2.828l-5.5 5.5A2 2 0 0 1 7.879 15H5.12a2 2 0 0 1-1.414-.586l-2.5-2.5a2 2 0 0 1 0-2.828l6.879-6.879zm2.121.707a1 1 0 0 0-1.414 0L4.16 7.547l5.293 5.293 4.633-4.633a1 1 0 0 0 0-1.414l-3.879-3.879zM8.746 13.547 3.453 8.254 1.914 9.793a1 1 0 0 0 0 1.414l2.5 2.5a1 1 0 0 0 .707.293H7.88a1 1 0 0 0 .707-.293l.16-.16z'></path>
                    </svg>
                  </button>
                </div>
              </Annotation>
            </div>
            {/* Right Column */}
            <div className='p-2.5 flex items-center justify-center'>
              {/* Undo Button */}
              <div
                className={`border-2 border-k2-gray-200 rounded-lg p-1.5 mr-2.5`}
              >
                <button
                  type='button'
                  aria-label='Undo'
                  onClick={undo}
                  className='bg-none text-k2-gray-200 border-none cursor-pointer outline-none flex items-center justify-center w-7.5 h-7.5'
                >
                  <BiUndo size={20} />
                </button>
              </div>

              {/* Reset Button */}
              <button
                type='button'
                onClick={resetCanvas}
                className='bg-none text-kaiberGreen border-2 border-kaiberGreen py-1.5 px-2.5 rounded-lg cursor-pointer'
              >
                Reset
              </button>
              {/* Save Button */}
              <button
                type='button'
                onClick={handleCloseImageMaskModal}
                className='bg-kaiberGreen text-black border-2 border-kaiberGreen py-1.5 px-2.5 rounded-lg cursor-pointer ml-2.5'
              >
                Save
              </button>
            </div>
          </div>
          {/* Second Row */}
          <div className='justify-between mb-4 pt-2 pb-4 px-6 sm:px-0'>
            {/* Pen Size Slider */}
            <Slider
              title=''
              description=''
              units='px'
              handleChange={(value: number) => setPenSize(value)}
              min={1}
              max={50}
              defaultOption={25}
              showSubscriptionTextCTA={false}
              subscriptionTextCTA={null}
            />
          </div>
        </div>
      </div>
    </div>
  )
}
export default ImageMask
