import React, { useRef, useState, useEffect, useMemo } from 'react'
import PropTypes from 'prop-types'
import { motion, useViewportScroll, useTransform, useSpring, transform } from 'framer-motion'
import { useWindowSize } from '@react-hook/window-size'

/**
 * Parallax element while scrolling. Uses framer-motion springs for each property
 */
const ScrollParallax = ({
  children,
  as = 'div',
  style,
  keyframes,
  ease,
  transition,
  horizontal = false,
  disabled,
  times = {},
  startOffset = 0,
  stopOffset = 0,
  debug = false,
  detectOverflow = true,
  clamp = true,
  speed,
  motionValue,
  ...props
}) => {
  const [viewportWidth, viewportHeight] = useWindowSize({ initialHeight: 0, initialWidth: 0, wait: 200 })

  const ref = useRef()
  const calcEl = useRef()

  // default scroll value
  if (!motionValue) {
    const { scrollY } = useViewportScroll()
    motionValue = scrollY
  }

  const springs = useRef({}).current
  const [bounds, setBounds] = useState([0, 1])

  // deep copy keyframes for initial shape
  const inputRange = useMemo(() => {
    return JSON.parse(JSON.stringify(keyframes))
  }, [keyframes, horizontal])

  const progress = useTransform(motionValue, bounds, [0, 1], { clamp })
  const springProgress = useSpring(progress, transition) // FIX can't allow overflow on each value when using one spring

  // debug
  // useEffect(() => progress.onChange((val) => console.log('p', val)), [])

  // generate easings from shortcut value
  const easings = useMemo(() => {
    const easings = {}
    if (!ease) return easings
    Object.keys(keyframes).forEach((prop) => {
      if (Array.isArray(ease[prop])) {
        easings[prop] = ease[prop]
      } else if (typeof ease === 'function' || typeof ease[prop] === 'function') {
        const value = ease[prop] || ease
        // apply to all frames
        easings[prop] = [...Array(keyframes[prop].length - 1)].map(() => value)
      }
    })
    return easings
  }, [ease, keyframes])

  const outputRange = useMemo(() => {
    const outputRange = JSON.parse(JSON.stringify(keyframes)) // deep clone

    // speed shorthand
    if (speed) {
      if (horizontal) {
        outputRange.x = [
          transform(speed, [0.5, 1, 2], ['-50vw', '0vw', '50vw'], { clamp: false }),
          transform(speed, [0.5, 1, 2], ['50vw', '0vw', '-50vw'], { clamp: false }),
        ]
      } else {
        outputRange.y = [
          transform(speed, [0.5, 1, 2], ['-50vh', '0vh', '50vh'], { clamp: false }),
          transform(speed, [0.5, 1, 2], ['50vh', '0vh', '-50vh'], { clamp: false }),
        ]
      }
    }

    // Generate times/inputRange for all props in keyframe
    Object.keys(outputRange).forEach((prop) => {
      // check if keyframe times provided, if not autogenerate with equal spacing
      if (!times[prop]) {
        if (Array.isArray(times) && times.length === outputRange[prop].length) {
          times[prop] = times
        } else {
          // auto gen inputRage for number of keyframes
          const firstProp = outputRange[prop]
          const l = firstProp.length
          times[prop] = firstProp.map((_, index) => (1 / (l - 1)) * index)
        }
      }

      // create input range for each time
      times[prop].forEach((_keyframe, i) => {
        if (!inputRange[prop]) inputRange[prop] = []
        inputRange[prop][i] = _keyframe
      })
    })

    return outputRange
  }, [keyframes, times, horizontal, speed, inputRange])

  const transformer = (prop) => (val) => {
    const ease = easings[prop]
    return transform(val, inputRange[prop], outputRange[prop], { clamp, ease })
  }

  // create a motionvalue spring for each prop to keyframes
  Object.keys(outputRange).forEach((prop) => {
    springs[prop] = useTransform(transition ? springProgress : progress, transformer(prop))
  })

  // evaluate css to pixels using DOM and getComputedStyle
  const toPixels = (val) => {
    const axis = horizontal ? 'left' : 'top'
    calcEl.current.style[axis] = isNaN(val) ? val : `${val}px`
    const px = parseFloat(window.getComputedStyle(calcEl.current)[axis])
    return px
  }

  // Calculate position and overflow
  useEffect(() => {
    var rect = ref.current.getBoundingClientRect()
    const scrollPos = horizontal ? motionValue.get() : window.pageYOffset

    let startOffsetPx = toPixels(startOffset, rect)
    let stopOffsetPx = toPixels(stopOffset, rect)

    if (detectOverflow) {
      if (!horizontal && outputRange?.y) {
        startOffsetPx = startOffsetPx || Math.min(0, toPixels(outputRange?.y[0], rect))
        stopOffsetPx = stopOffsetPx || Math.max(0, toPixels(outputRange?.y[outputRange.y.length - 1], rect))
      }
      if (horizontal && outputRange?.x) {
        startOffsetPx = startOffsetPx || toPixels(outputRange?.x[0], rect)
        stopOffsetPx = stopOffsetPx || toPixels(outputRange?.x[outputRange.x.length - 1], rect)
      }
    }

    // Y bounds
    const top = rect.top + startOffsetPx + scrollPos - viewportHeight
    const bottom = rect.top + scrollPos + rect.height + stopOffsetPx
    const left = rect.left + startOffsetPx + scrollPos - viewportWidth
    const right = rect.left + scrollPos + rect.width + stopOffsetPx
    setBounds(horizontal ? [left, right] : [top, bottom])
  }, [viewportWidth, viewportHeight, horizontal, motionValue, motionValue.get()]) // detect update scroll on re-render to support reload when scrolled down

  const springStyles = disabled ? {} : springs
  const MotionElement = motion[as]

  return (
    <div ref={ref} style={{ position: 'relative', outline: debug && '1px dashed orange', ...style }} {...props}>
      {debug &&
        // <div
        //   style={{
        //     position: 'absolute',
        //     top: keyframes.y[0] || 0,
        //     left: keyframes.x[0] || 0,
        //     right: '-' + keyframes.x[keyframes.x.length - 1] || 0,
        //     bottom: '-' + keyframes.y[keyframes.y.length - 1] || 0,
        //     background: 'orange',
        //     opacity: 0.2,
        //   }}
        // ></div>
        null}
      <MotionElement style={{ willChange: 'transform', ...springStyles }}>{children}</MotionElement>

      {/* El used to evaluate css sizes to pixels */}
      <div
        ref={calcEl}
        style={{
          position: 'absolute',
          top: 0,
          left: 0,
          right: 0,
          bottom: 0,
          pointerEvents: 'none',
          visibility: 'hidden',
          zIndex: -1,
        }}
      />
    </div>
  )
}

// Fallback prop values applied when the caller omits them.
// A spring `transition` preset is kept below for reference only; it is not enabled:
//   transition: { type: 'spring', stiffness: 50, damping: 20, restDelta: 0.001, speedDelta: 0.001 }
ScrollParallax.defaultProps = {
  keyframes: {},
  disabled: false,
}

// Runtime prop validation (development builds only).
ScrollParallax.propTypes = {
  children: PropTypes.node, // content rendered inside the animated element
  style: PropTypes.object, // styles for the outer wrapper
  horizontal: PropTypes.bool,
  disabled: PropTypes.bool,
  times: PropTypes.any, // progress positions for the keyframes: one array, or an object of arrays per prop
  keyframes: PropTypes.object, // props and values to keyframes between
  transition: PropTypes.oneOfType([PropTypes.bool, PropTypes.object]), // framer-motion transition settings
  as: PropTypes.string,
  ease: PropTypes.oneOfType([PropTypes.func, PropTypes.object]),
  // numbers are treated as pixels, strings as CSS lengths; the component defaults both to the number 0
  startOffset: PropTypes.oneOfType([PropTypes.number, PropTypes.string]),
  stopOffset: PropTypes.oneOfType([PropTypes.number, PropTypes.string]),
  debug: PropTypes.bool,
  detectOverflow: PropTypes.bool,
  clamp: PropTypes.bool,
  motionValue: PropTypes.object,
  speed: PropTypes.number,
}

// The component is this module's default export.
export default ScrollParallax
