import { useMemo } from 'react'

import { scaleLinear, scaleUtc } from '@visx/scale'

import { ensureValidDomain, getDomain, getLabel, getMirrorDomain, getNumberLabel, getValue, isValidDomain } from '../lib'

import { find, flatten, get, sum } from 'lodash-es'

import { getMergedLabels } from '../lib'

import { CombinedChartProps, QuantitativeCustomAxisChartTypeProps, SizeRequirements, SVGDatumType } from '../typings'

import { AxisScale, AxisScaleOutput } from '@visx/axis'

import { GenericAxisChartProps } from '../GenericAxisChart'

import { Accessor } from '@visx/shape/lib/types'


/**
 * Props accepted by `useQuantitativeAxes`.
 *
 * This is the intersection of:
 * - `QuantitativeCustomAxisChartTypeProps`
 * - `SizeRequirements`
 * - the `data` / `datasets` fields of `CombinedChartProps<T>`
 */
type QuantitativeAxesHookProps<T extends SVGDatumType = SVGDatumType> =
  & QuantitativeCustomAxisChartTypeProps
  & SizeRequirements
  & Pick<CombinedChartProps<T>, 'data' | 'datasets'>


/**
 * Wrap a raw-value accessor so it yields a coordinate in the scale's output
 * range instead of a raw data value.
 *
 * @param scale - Scale that the raw value is passed through.
 * @param accessor - Reads the raw value from a datum.
 * @param scaleType - Controls how the raw value is prepared before scaling:
 *   - `'linear'`: falsy raw values (e.g. `0`, `NaN`, or a missing value at
 *     runtime) are replaced by `0`.
 *   - Any other value (i.e. `'time'`): the raw value is wrapped in
 *     `new Date(...)`, so it should be something the `Date` constructor
 *     accepts.
 * @returns An accessor that returns the scaled value, cast to `number`.
 */
function makeAccessor<T extends SVGDatumType = SVGDatumType>(
  scale: AxisScale<AxisScaleOutput>,
  accessor: Accessor<T, number>,
  scaleType: 'linear' | 'time'
): Accessor<T, number> {
  if( scaleType !== 'linear' ){
    // Time scales are fed Date objects rather than raw numbers.
    return (d): number => {
      const when = new Date(accessor(d) as number)
      return scale(when) as number
    }
  }

  return (d): number => {
    const raw = accessor(d) || 0
    return scale(raw) as number
  }
}


/**
 * Build the x/y scales, and accessors that map a datum onto them, for a chart
 * with two quantitative axes. Each axis uses `scaleUtc` when its scale type is
 * `'time'` and `scaleLinear` otherwise.
 *
 * Data source:
 * - `datasets` takes precedence. Labels are merged across datasets.
 *   - When `stacked` is set, one synthetic `{ label, value }` point is built
 *     per label. Its value is the sum of that label's `value` across all
 *     datasets; missing or falsy values count as 0.
 *   - Otherwise every dataset's points are flattened together.
 * - Otherwise `data` is used as-is, and labels come from `getNumberLabel`.
 *
 * Domains:
 * - x is derived from the labels converted with `Number(...)`.
 *   `defaultXDomain` is used only when it is provided and the derived domain
 *   is not valid. Rounding (`nice`) follows `roundXDomain`, unless
 *   `arbitraryTimeline` is set, in which case it is disabled.
 * - y is `domain` when provided. Otherwise it is the mirrored domain (if
 *   `mirrorYDomain`) or the plain data domain. Rounding follows
 *   `roundYDomain`.
 *
 * Ranges:
 * - x runs from the resolved x offset to `width`. The offset is `xOffset`
 *   when given, otherwise 36 when both `axisLine` and `axisText` are set,
 *   else 0.
 * - y runs from `height - verticalMargin` down to 0.
 *
 * The returned object is memoized with `useMemo`.
 */
export function useQuantitativeAxes<T extends SVGDatumType = SVGDatumType>({
  data,
  datasets,
  width,
  height,
  stacked,
  xScaleType = 'linear',
  yScaleType = 'linear',
  domain,
  defaultXDomain,
  mirrorYDomain,
  xOffset,
  axisLine,
  axisText,
  verticalMargin = 0,
  roundXDomain = true,
  roundYDomain = true,
  arbitraryTimeline = false,
}: QuantitativeAxesHookProps<T>): (
  Pick<
    GenericAxisChartProps<T>,
    'xScale' | 'yScale'
  > & {
    xAccessor: Accessor<T, number>
    yAccessor: Accessor<T, number>
    xOffset: number
    yOffset: number
    labels: number[]
  }
) {

  // An explicit xOffset always wins. Otherwise 36 is reserved for the axis,
  // but only when both its line and its text are shown.
  let canonicalXOffset = 0
  if( xOffset !== undefined ){
    canonicalXOffset = xOffset
  }else if( axisLine && axisText ){
    canonicalXOffset = 36
  }

  // The y range ends `verticalMargin` short of `height`.
  const yMax = height - verticalMargin

  return useMemo(() => {
    let points: T[] = []
    let labels: number[] = []

    if( datasets ){
      labels = getMergedLabels(datasets) as number[]

      if( stacked ){
        // Collapse the datasets into one total per label.
        points = labels.map( label => {
          const perDataset = datasets.map( series =>
            get(
              find(series.data, point => point.label === label),
              'value'
            ) || 0
          )
          return { label, value: sum(perDataset) }
        }) as T[]
      }else{
        points = flatten(datasets.map( series => series.data ))
      }
    }else if( data ){
      points = data
      labels = points.map(getNumberLabel)
    }

    // ---- x axis ----
    const xDataDomain = getDomain(
      points,
      d => Number(getLabel(d)),
      [Infinity, -Infinity],
    )
    const xDomain = (defaultXDomain && !isValidDomain(xDataDomain))
      ? defaultXDomain
      : ensureValidDomain(xDataDomain)

    const makeXScale = xScaleType === 'time' ? scaleUtc : scaleLinear
    const xScale = makeXScale({
      range: [canonicalXOffset, width],
      domain: xDomain,
      // Arbitrary timelines keep their exact extent; otherwise honour roundXDomain.
      nice: !arbitraryTimeline && roundXDomain,
    }) as AxisScale<AxisScaleOutput>

    // ---- y axis ----
    // An explicit domain short-circuits any data-derived domain computation.
    const yDomain = domain || (
      mirrorYDomain ? getMirrorDomain(points) : getDomain(points)
    )

    const makeYScale = yScaleType === 'time' ? scaleUtc : scaleLinear
    const yScale = makeYScale({
      range: [yMax, 0],
      domain: ensureValidDomain(yDomain),
      nice: roundYDomain,
    }) as AxisScale<AxisScaleOutput>

    // ---- accessors ----
    const xAccessor = makeAccessor(xScale, getLabel as Accessor<T, number>, xScaleType)
    const yAccessor = makeAccessor(yScale, getValue as Accessor<T, number>, yScaleType)

    return {
      xScale,
      xAccessor,
      yScale,
      yAccessor,
      xOffset: canonicalXOffset,
      yOffset: 0,
      labels,
    }
  },
  [
    data, datasets, stacked,
    xScaleType, defaultXDomain, roundXDomain, arbitraryTimeline, canonicalXOffset, width,
    yScaleType, domain, mirrorYDomain, roundYDomain, yMax,
  ])
}