import "./Canvas.css"
import { useRef, useState, useEffect, forwardRef, useImperativeHandle} from "react"
import { drawCanvasBounds, drawScrollState, getTouchDirection, getStrokeRadius, getCompositeBounds, getBrushCursor, drawSketches, getCanvasBounds, trimCanvas, getSetCursor, shouldAddToQueue, loadNewImage, trackTransforms, getCanvasPos, clearCanvas } from "../CanvasUtils"
import { TOOLMODE, TOOLMAP_ARRAY, DEFAULT_PEN_PRESSURE } from "../Constants"
import CanvasScroller from "./CanvasScroller"
import { useElementSize } from '@custom-react-hooks/use-element-size'
import CanvasInstructions from "./CanvasInstructions"
import { drawCanvasExport, saveImage } from "../ExportUtils"

// Device pixel ratio: each canvas gets a backing store this many times larger than its CSS size
const { devicePixelRatio: scale } = window
// Zoom multiplier for one zoom step; setZoom applies it as Math.pow(scaleFactor, delta)
const scaleFactor = 1.1

function CanvasLayer ({name, layerRef, width, height, scale}) {
    return (
        <canvas
            id={name}
            className="canvas-layer"
            style={{width: `${width}px`, height: `${height}px`}}
            ref={layerRef}
            width={width * scale}
            height={height * scale}
        />
    )
}

const Canvas = forwardRef(({disablePointerEvents, scrollCanvasRef, color, brushsize, toolType, setShowActionBar, isGenerate, errorMsg, setErrorMsg}, ref) => {
    
    const canvasBase = useRef()
    const canvasContext = useRef()
    const scrollbarVertical = useRef() 
    const scrollbarHorizontal = useRef()
    const instructRef = useRef()
    let didPan = useRef(false)
    let didZoom = useRef(false)

    const imagesRef = useRef()
    const segmentsRef = useRef()
    const scribblesRef = useRef()
    const inpaintsRef = useRef()

    const [canvasWrapperRef, canvasSize] = useElementSize()
    const [shouldTrimBounds, setShouldTrimBounds] = useState(false)
    const [canvasBounds, setCanvasBounds] = useState({x: 0, y: 0, w: 512, h: 512})

    const [scribbles, setScribbles] = useState([])
    const [segments, setSegments] = useState([])
    const [inpaints, setInpaints] = useState([])
    const [images, setImages] = useState([])

    const [undoQueue, setUndoQueue] = useState([])
    const [redoQueue, setRedoQueue] = useState([])
    const [updatedQueue, setUpdatedQueue] = useState(false)

    const valuesMapping = {
        scribbles: {
            values: scribbles,
            setValues: setScribbles,
            canvas: scribblesRef
        },
        segments: {
            values: segments,
            setValues: setSegments,
            canvas: segmentsRef
        },
        inpaints: {
            values: inpaints,
            setValues: setInpaints,
            canvas: inpaintsRef
        },
        images: {
            values: images,
            setValues: setImages,
            canvas: imagesRef
        }
    }

    const canvasLayers = [{
        canvas: imagesRef,
        name: "images",
        values: images,
        setValues: setImages
    }, {
        canvas: inpaintsRef,
        name: "inpaints",
        values: inpaints,
        setValues: setInpaints
    }, {
        canvas: segmentsRef,
        name: "segments",
        values: segments,
        setValues: setSegments
    }, {
        canvas: scribblesRef,
        name: "scribbles",
        values: scribbles,
        setValues: setScribbles
    }]

    let isPointerDown = false
    let currStroke = null
    let strokes = []
    let bounds = canvasBounds
    let cursorPos = { 
        x: canvasSize.width / 2,
        y: canvasSize.height / 2
    }

    let isDragging = false
    let dragStart = { x: 0, y: 0 }
    let initialPinchDistance = null
    let prevTouch1 = null
    let prevTouch2 = null
    
    let moveCanvas = null
    let pressure = DEFAULT_PEN_PRESSURE
    let isGesture = false
    let dontZoom = false

    useImperativeHandle(ref, () => ({
        zoom(amount, level){ setZoom(amount, level) },
        undo() { undoRedo(undoQueue, setUndoQueue, setRedoQueue) },
        redo() { undoRedo(redoQueue, setRedoQueue, setUndoQueue) },
        getImage() { return (images[0])},
        downloadImage() { saveImage(images)},
        exportCanvas() {
            if ([...scribbles, ...segments, ...inpaints].length === 0) { return null }
            const { w, h, scaleFactor } = canvasBounds
            return {
                width: Math.ceil(w * scaleFactor),
                height: Math.ceil(h * scaleFactor),
                scribbles: drawCanvasExport(scribbles, bounds, scaleFactor),
                segments: drawCanvasExport(segments, bounds, scaleFactor),
                images: drawCanvasExport(images, bounds, scaleFactor),
                inpaints: drawCanvasExport(inpaints, bounds, scaleFactor, images[0] || null),
            }
        },
        async addImage(imgData) {
            const newImage = await loadNewImage(imgData, canvasBounds)
            if (newImage) {
                setImages([newImage])
                setScribbles([])
                setSegments([])
                setInpaints([])
            }
        }
    }))

    useEffect(() => {
        canvasContext.current = canvasBase.current.getContext("2d")
        trackTransforms(canvasContext.current)
    }, [])

    useEffect(() => {
        if (updatedQueue) {
            setUpdatedQueue(false)
            draw()
        } else {
            const newState = { scribbles, segments, images, inpaints }
            if (shouldAddToQueue(undoQueue, newState)) {
                setUndoQueue(prev => [...prev, newState])
                setRedoQueue([])
            }
        }
        draw()
        setErrorMsg(null)
    }, [scribbles, segments, images, inpaints])

    useEffect(() => {
        if (updatedQueue) {
            const lastState = undoQueue.length > 0 ? undoQueue[undoQueue.length - 1] : {
                scribbles: [],
                segments: [],
                images: [],
                inpaints: []
            };

            ["scribbles", "segments", "images", "inpaints"].map(mode => {
                valuesMapping[mode].setValues(lastState[mode])
            })
        }
        setShowActionBar(undoQueue.length > 0 || redoQueue.length > 0) 
    }, [undoQueue])

    useEffect(() => {
        draw()
    }, [canvasSize])

    useEffect(() => {
        if (shouldTrimBounds) {
            trimCanvasBounds()
            setShouldTrimBounds(false)
        } else {
            draw()
        }
    }, [canvasBounds])

    useEffect(() => {
        setBrushCursor(brushsize)
    }, [brushsize])

    const setBrushCursor = (brushsize) => {
        const origin = canvasContext.current.invertedPoint(0, 0)
        const dist = canvasContext.current.invertedPoint(brushsize, brushsize)
        const size = Math.abs(origin.x - dist.x)
        getSetCursor(canvasBase.current, getBrushCursor(size))
    }
    const trimCanvasBounds = () => {
        let tempBounds = null;

        ["images", "segments", "scribbles"].map(mode => {
            const {values, setValues} = valuesMapping[mode]
            
            if (values.length > 0) {
                
                const valuesBounds =  trimCanvas(mode, values, canvasBounds)
                
                if (valuesBounds === null) 
                    setValues([])
                else 
                    tempBounds = tempBounds && valuesBounds ? getCompositeBounds(tempBounds, valuesBounds): valuesBounds
            }
        })

        if (tempBounds) {
            setCanvasBounds(tempBounds)
        }
    }

    const draw = () => {

        ["images", "inpaints", "segments", "scribbles"].forEach(mode => {
            const { values, canvas } = valuesMapping[mode]
            const context = clearCanvas(canvas.current)
            const sketches = TOOLMODE[toolType] === mode || toolType === "erase" ? [...values, ...strokes] : values
            drawSketches(context, sketches)
        })
        
        //drawCanvasBounds(canvasBase.current, bounds)
        drawScrollState(scrollbarVertical.current, scrollbarHorizontal.current, canvasBase.current, bounds)
    }

    const onPointerDown = (e, constrainX = false, constrainY = false, isScrolling = false) => {
        
        e.stopPropagation()
        disablePointerEvents(true)
        
        cursorPos = getCanvasPos(e, scale)

        if (moveCanvas || isScrolling) {
            
            isDragging = true
            dragStart = canvasContext.current.transformedPoint(cursorPos.x, cursorPos.y)

            if (!constrainX && !constrainY) getSetCursor(canvasBase.current, "grabbing")

        } else {

            isPointerDown = true            
            pressure = e.pressure || pressure

            const {x, y} = canvasContext.current.transformedPoint(cursorPos.x, cursorPos.y)

            currStroke = {
                mode: TOOLMODE[toolType],
                type: toolType,
                thickness: getStrokeRadius(brushsize, pressure, toolType),
                color: color,
                points: [{x, y}]
            }

            strokes.push(currStroke)
            bounds = getCanvasBounds(x, y, bounds, currStroke.thickness)
            draw()
        }
    }

    const onPointerUp = (e,  isScrolling = false) => {
        e.stopPropagation()
        disablePointerEvents(false)

        if ((moveCanvas || isScrolling) && isDragging) {
            dragStart = null
            isDragging = false
            isGesture = false
            initialPinchDistance = null
        } else {

            isPointerDown = false
            
            if (currStroke) {
                TOOLMAP_ARRAY[currStroke.type].map(mode => {
                    valuesMapping[mode].setValues(prev => [...prev, ...strokes])
                })
                setCanvasBounds(bounds)
                setShouldTrimBounds(true)
            }
        }
    }

    const onPointerMove = (e,  constrainX = false, constrainY = false, isScrolling = false) => {
        
        e.stopPropagation()
        canvasBase.current.focus()

        if (isGesture) return

        cursorPos = getCanvasPos(e, scale)

        if ((moveCanvas || isScrolling) && isDragging) {
            if (!constrainX && !constrainY) 
                getSetCursor(canvasBase.current, "grabbing")

			dragCanvas(cursorPos, constrainX, constrainY)
        } 
        else if (isPointerDown && currStroke) {

            const new_pressure = e.pressure || DEFAULT_PEN_PRESSURE
            
            if (new_pressure === 0) onPointerUp(e)   
            else {

                const {x, y} = canvasContext.current.transformedPoint(cursorPos.x, cursorPos.y)
            
                currStroke.points.push({ x, y })
                bounds = getCanvasBounds(x, y, bounds, currStroke.thickness)
                draw()
                            
                if (Math.abs(new_pressure - pressure) > 0.1) onPointerDown(e)
            }     
        }
    }

    const dragCanvas = (dest, constrainX = false, constrainY = false) => {
        if (dragStart){
            
            const { x, y } = canvasContext.current.transformedPoint(dest.x, dest.y)
         
            const translateX = constrainX ? 0 : x - dragStart.x
            const translateY = constrainY ? 0 : y - dragStart.y;
            
            [canvasBase, scribblesRef, inpaintsRef, segmentsRef, imagesRef].map(canvas => {
                const context = canvas.current.getContext("2d")
                context.translate(translateX, translateY);
            })
            draw()

            if (didPan.current === false) {
                didPan.current = true
                instructRef.current.didDo("move canvas")
            }
        }
    }

    const handleTouch = (e) => {
        e.stopPropagation()

         if (e.touches.length === 2) {
            isGesture = true
            currStroke = null
            strokes = []
            
            if (e.type === "touchstart") {
                moveCanvas = true
                onPointerDown(e)
                moveCanvas = false
            } 
            else if (e.type === "touchmove") { handleGesture(e) }

        } else {
            isGesture = false
            initialPinchDistance = null
            dontZoom = false
        }
    }


    const handleGesture = (e) => {
        e.stopPropagation()
        e.preventDefault()
        
        let touch1 = { x: e.touches[0].clientX, y: e.touches[0].clientY }
        let touch2 = { x: e.touches[1].clientX, y: e.touches[1].clientY }
        
        // This is distance squared, but no need for an expensive sqrt as it's only used in ratio
        let currentDistance = (touch1.x - touch2.x)**2 + (touch1.y - touch2.y)**2
        
        let touchDir1 = getTouchDirection(touch1, prevTouch1 || touch1)
        let touchDir2 = getTouchDirection(touch2, prevTouch2 || touch2)
        
        if (initialPinchDistance === null){

            initialPinchDistance = currentDistance
        
        } else if (touchDir1.x === touchDir2.x && touchDir1.y === touchDir2.y) {  

            dragCanvas(getCanvasPos(e, scale))
            dontZoom = true

        } else {
            
            let delta = (currentDistance / initialPinchDistance)
            
            if (dontZoom && Math.abs(delta) < 1) return

            dontZoom = false
            delta *= currentDistance > initialPinchDistance ? 0.25 : -0.5
            setZoom(delta)
        }

        prevTouch1 = touch1
        prevTouch2 = touch2
        
    }

    const onKeyDown = e => {
        if (e.code === "Space") {
            moveCanvas = true

            if (getSetCursor(canvasBase.current) !== "grabbing")
                getSetCursor(canvasBase.current, "grab")
        }
    }

    const onKeyUp = e => {
        if (e.code === "Space") {
            moveCanvas = false
            setBrushCursor(brushsize)
        }
    }

    const onMouseOut = e => onPointerUp(e)

    const undoRedo = (queue, setQueue, setInverseQueue) => {
        if (queue.length > 0) {
            let temp = [...queue]
            const lastItem = temp.pop()
            setInverseQueue(prev => [...prev, lastItem])
            setQueue([...temp])
            setUpdatedQueue(true)
        }
    }

    
    const setZoom = (delta, level) => {

        [canvasBase, scribblesRef, inpaintsRef, segmentsRef, imagesRef].map(canvas => {
            const context = canvas.current.getContext("2d")
            const pt = canvasContext.current.transformedPoint(cursorPos.x, cursorPos.y)
            context.translate(pt.x,pt.y)
            const factor = delta ? Math.pow(scaleFactor, delta) : 1
            if (level) {
                context.setTransform(1, 0, 0, 1, 0, 0)
            } else {
                context.scale(factor,factor)
                context.translate(-pt.x,-pt.y)
            }

            setBrushCursor(brushsize)
        })
       
        draw()

        if (didZoom.current === false) {
            didZoom.current = true
            instructRef.current.didDo("zoom canvas")
        }
    }

    const onWheel = e => { 
        setZoom(e.deltaY/40) 
    }

    return (
        <div className={`canvas-wrapper ${isGenerate && "generating"}`}
            ref={canvasWrapperRef} 
        >
            <canvas
                ref={canvasBase}
                className="canvas-base"
                style={{width:`${canvasSize.width}px`, height:`${canvasSize.height}px`}}
                width={canvasSize.width * scale}
                height={canvasSize.height * scale}
                onPointerDown={onPointerDown}
                onPointerUp={onPointerUp}
                onPointerMove={onPointerMove}
                onPointerOut={onMouseOut}
                onKeyDown={onKeyDown}
                onKeyUp={onKeyUp}
                onWheel={onWheel}
                onContextMenu={(e)=> e.preventDefault()}
                tabIndex={"0"}   
                onTouchStart={handleTouch}
                onTouchEnd={handleTouch}
                onTouchMove={handleTouch}
            ></canvas>

            {canvasLayers && canvasLayers.map((layer, i) => (
                <CanvasLayer
                    key={i}
                    name={layer.name}
                    layerRef={layer.canvas}
                    width={canvasSize.width}
                    height={canvasSize.height}
                    scale={scale}
                />
            ))}

            <CanvasScroller 
                scrollCanvasRef={scrollCanvasRef}
                onCanvasDown={onPointerDown}
                onCanvasUp={onPointerUp}
                onCanvasMove={onPointerMove}
                scrollbarHorizontal={scrollbarHorizontal}
                scrollbarVertical={scrollbarVertical}
                canvasSize={canvasSize}
            />

            <CanvasInstructions
                ref={instructRef}
                errorMsg={errorMsg}
            />
            
        </div>
    )
})

export default Canvas