import { select, extent, max, scaleLinear, scaleBand, easeCubicInOut, axisLeft, axisBottom, scaleQuantize } from 'd3';
import { round } from 'src/formatters';

// Number of equal-width bins the metric's range is split into (used by draw).
const columnCount = 11;
// Tick-label formatter for the x axis, built by the project's `round` helper.
// NOTE(review): presumably rounds to 2 decimal places — confirm in src/formatters.
const xFormat = round(2);

/**
 * Adapt a ref-like object into a d3 event listener.
 *
 * The returned listener is a regular `function` (not an arrow) so that d3 can bind
 * `this` to the element the event fired on. It ignores the event argument and calls
 * `ref.current(element, datum)`. `ref.current` is read when the event fires, not when
 * the listener is created, so swapping the callback on the ref takes effect without
 * rebinding. Presumably `ref` is a React ref — confirm with callers.
 *
 * @param {{ current: Function }} ref
 * @returns {Function} listener with the d3 v6+ `(event, datum)` signature
 */
function handleEvent(ref) {
  return function forwardToRef(_event, datum) {
    const element = this;
    ref.current(element, datum);
  };
}

/**
 * Build the chart's static scaffolding inside the element `#${id}`:
 * - a `g.root` group offset by the top/left margin;
 * - empty `g.x-axis` and `g.y-axis` containers;
 * - the two axis titles, with the y title rotated -90°.
 *
 * Size-dependent positions (label x/y, axis transforms) are not set here.
 *
 * @param {{ id: string, margin: { left: number, top: number } }} options
 * @returns the d3 selection for the newly appended `g.root`
 */
function create({ id, margin }) {
  const { left, top } = margin;

  const group = select(`#${id}`)
    .append('g')
    .attr('class', 'root')
    .attr('transform', `translate(${left},${top})`);

  // Axis containers; crispEdges asks the renderer to favor sharp edges over anti-aliasing.
  ['x-axis', 'y-axis'].forEach(axisClass => {
    group.append('g').attr('class', axisClass).style('shape-rendering', 'crispEdges');
  });

  // Title under the x axis.
  const xLabel = group.append('text');
  xLabel.attr('class', 'x-axis-label').text('RETURN PERFORMANCE OF PEER GROUP');

  // Title beside the y axis, rotated to read bottom-to-top.
  const yLabel = group.append('text');
  yLabel
    .attr('transform', 'rotate(-90)')
    .attr('dy', '1em')
    .attr('class', 'y-axis-label')
    .style('text-anchor', 'middle')
    .text('# OF FUNDS');

  return group;
}

/**
 * Return the index of the bin in `bins` that `value` belongs to.
 *
 * Each bin's `value` is its inclusive lower edge, and `bins` is in ascending order.
 * A value belongs to the last bin whose lower edge it reaches. The final bin is
 * therefore open-ended, so the largest datum is always counted however large it is.
 * Values below the first edge are assigned to bin 0.
 *
 * @param {number} value
 * @param {Array<{ value: number }>} bins - non-empty, ascending by `value`
 * @returns {number} index into `bins`
 */
function binIndexFor(value, bins) {
  let index = 0;
  while (index + 1 < bins.length && value >= bins[index + 1].value) {
    index += 1;
  }
  return index;
}

/**
 * Bucket `values` into equal-width histogram bins spanning their extent.
 *
 * @param {number[]} values - finite numbers
 * @param {number} binCount - number of bins to use when the values are not all equal
 * @returns {Array<{ value: number, count: number }>} one entry per bin; `value` is the
 *   bin's lower edge. The result is empty when `values` is empty. It is a single bin
 *   when every value is the same, because a zero-width range cannot be split (and
 *   splitting it would yield duplicate bin keys).
 */
function buildSeries(values, binCount) {
  if (values.length === 0) return [];

  const [lo, hi] = extent(values);
  const count = lo === hi ? 1 : binCount;
  const width = (hi - lo) / count;
  const series = Array.from({ length: count }, (_, index) => ({ value: lo + width * index, count: 0 }));

  values.forEach(value => {
    series[binIndexFor(value, series)].count += 1;
  });
  return series;
}

/**
 * Place each selected peer that appears in `visibleData` into the bin containing its
 * metric value.
 *
 * Peers sharing a bin get increasing stack levels so their markers can be drawn on
 * top of one another.
 *
 * @param {Object} selectedPeers - fundId -> value used as that peer's circle fill
 * @param {Object[]} visibleData - records compared to the keys via `fundId ===`
 * @param {string} metricKey - field of each record that was binned
 * @param {Array<{ value: number }>} series - bins produced by buildSeries
 * @returns {Array<{ value: number, color: *, fundId: *, index: number }>} one entry per
 *   matched peer, in `selectedPeers` key order. `value` is the bin's lower edge and
 *   `index` is the 1-based stack level within that bin. The first peer in key order
 *   gets the highest level.
 */
function placeSelectedPeers(selectedPeers, visibleData, metricKey, series) {
  if (series.length === 0) return [];

  const placed = [];
  Object.keys(selectedPeers).forEach(fundId => {
    const fund = visibleData.find(d => d.fundId === fundId);
    if (!fund) return;
    const bin = series[binIndexFor(fund[metricKey], series)];
    placed.push({ value: bin.value, color: selectedPeers[fundId], fundId: fund.fundId, index: 0 });
  });

  // Walk backwards so the last peer in each bin sits at level 1 (bottom).
  const levels = new Map();
  for (let i = placed.length - 1; i >= 0; i -= 1) {
    const level = (levels.get(placed[i].value) || 0) + 1;
    levels.set(placed[i].value, level);
    placed[i].index = level;
  }
  return placed;
}

/**
 * Render or update the peer-group histogram inside the element `#${id}`.
 *
 * Records are skipped when their `metric.key` value is not a finite number. Records
 * flagged with `${metric.key}Outlier` are also skipped unless `showOutliers` is set.
 * The remaining values are split into at most `columnCount` equal-width bars. Each
 * selected peer is drawn as a circle near the bottom of the bar containing its
 * value, stacked upwards when several peers share a bar. The scaffolding is created
 * on the first call (when no `.root` exists); later calls transition existing
 * elements.
 *
 * @param {Object} options
 * @param {string} options.id - id of the target element (without `#`)
 * @param {Object[]} options.data - records; reads `fundId`, `[metric.key]` and `[metric.key + 'Outlier']`
 * @param {Object} options.selectedPeers - fundId -> value used as that peer's circle fill
 * @param {{ key: string }} options.metric - which field of each record to plot
 * @param {{ width: number, height: number }} options.size - outer size; nothing is drawn if either is < 1
 * @param {{ top: number, right: number, bottom: number, left: number }} options.margin
 * @param {Object} options.mouseEventRefs - refs whose `current(element, datum)` handles bar events
 * @param {boolean} options.showOutliers
 */
function draw({ id, data, selectedPeers, metric, size, margin, mouseEventRefs, showOutliers }) {
  if (size.width < 1 || size.height < 1) return;

  const { itemOver, itemOut, itemClick, itemTouchStart, itemTouchEnd } = mouseEventRefs;
  const POINT_SPACING = 15; // vertical distance between stacked peer circles, in px
  const MAX_Y_TICKS = 10;

  const visibleData = data
    .filter(d => Number.isFinite(d[metric.key]))
    .filter(d => showOutliers || !d[`${metric.key}Outlier`]);
  const series = buildSeries(
    visibleData.map(d => d[metric.key]),
    columnCount
  );

  let root = select(`#${id} .root`);
  if (!root.node()) {
    root = create({ id, margin });
  }
  // create() applies the margin only once. Re-apply it so a later margin change lines
  // up with the svgWidth/svgHeight computed below.
  root.attr('transform', `translate(${margin.left},${margin.top})`);

  // Size of the plotting area inside the margins.
  const svgWidth = size.width - margin.left - margin.right;
  const svgHeight = size.height - margin.top - margin.bottom;

  // One band per bin, keyed by the bin's lower edge.
  const x = scaleBand()
    .domain(series.map(d => d.value))
    .range([0, svgWidth])
    .padding(0.1);

  // The x axis snaps tick values onto columnCount + 1 evenly spaced positions.
  // NOTE(review): these positions ignore the band padding, and the domain stops at the
  // last bin's lower edge. Labels therefore only approximately line up with bar edges.
  const xRange = Array.from({ length: columnCount }, (_, index) => (svgWidth / columnCount) * index);
  const xQuantize = scaleQuantize()
    .domain(extent(series, d => d.value))
    .range([...xRange, svgWidth]);

  // Keep the y domain non-degenerate when there are no bars or every bar is empty.
  // Otherwise a [0, 0] domain maps 0 to mid-height.
  const maxY = max(series, d => d.count) || 0;
  const y = scaleLinear()
    .domain([0, Math.max(maxY, 1)])
    .range([svgHeight, 0]);

  const xAxis = axisBottom(xQuantize)
    .tickSize(4)
    .tickFormat(d => xFormat(d));

  // A tick count no larger than maxY keeps the tick step >= 1, so only whole counts
  // are labelled. The cap avoids one grid line per fund for large peer groups.
  const yAxis = axisLeft(y)
    .ticks(Math.min(Math.max(maxY, 1), MAX_Y_TICKS))
    .tickSize(-svgWidth);

  const t = root.transition().duration(500).ease(easeCubicInOut);

  root.select('.x-axis').attr('transform', `translate(0,${svgHeight})`).transition(t).call(xAxis);
  root.select('.y-axis').transition(t).call(yAxis);

  // Axis titles are positioned immediately rather than transitioned.
  root
    .select('.x-axis-label')
    .attr('x', svgWidth / 2)
    .attr('y', svgHeight + margin.bottom - 20);
  // The y title is rotated -90°, so its x/y attributes are swapped on screen.
  root
    .select('.y-axis-label')
    .attr('y', 0 - margin.left)
    .attr('x', 0 - svgHeight / 2);

  // Existing and new bars get the same listeners.
  const bindItemEvents = selection =>
    selection
      .on('mouseover', handleEvent(itemOver))
      .on('mouseout', handleEvent(itemOut))
      .on('click', handleEvent(itemClick))
      .on('touchstart', handleEvent(itemTouchStart))
      .on('touchend', handleEvent(itemTouchEnd));

  const rects = root.selectAll('.bar').data(series, d => d.value);

  // Bars whose bin key disappeared shrink into the baseline and are removed.
  rects.exit().transition(t).attr('height', 0).attr('y', svgHeight).remove();

  // Existing bars move and resize to their new bin.
  rects
    .call(bindItemEvents)
    .transition(t)
    .attr('x', d => x(d.value))
    .attr('y', d => y(d.count))
    .attr('width', x.bandwidth())
    .attr('height', d => svgHeight - y(d.count));

  // New bars grow up from the baseline, synchronised with the other transitions.
  rects
    .enter()
    .append('rect')
    .attr('class', 'bar')
    .attr('height', 0)
    .attr('y', svgHeight)
    .attr('x', d => x(d.value))
    .attr('width', x.bandwidth())
    .call(bindItemEvents)
    .transition(t)
    .attr('height', d => svgHeight - y(d.count))
    .attr('y', d => y(d.count));

  // Selected peers: one circle per fund, centred on its bar and stacked by level.
  const selectedValues = placeSelectedPeers(selectedPeers, visibleData, metric.key, series);
  const pointX = d => x(d.value) + x.bandwidth() / 2;
  const pointY = d => svgHeight - d.index * POINT_SPACING;

  const points = root.selectAll('.point').data(selectedValues, d => d.fundId);

  points.exit().transition(t).attr('opacity', 0).attr('cy', svgHeight).remove();

  // Refresh fill too, in case a still-selected peer's color changed.
  points
    .transition(t)
    .attr('opacity', 1)
    .attr('fill', d => d.color)
    .attr('cy', pointY)
    .attr('cx', pointX);

  points
    .enter()
    .append('circle')
    .attr('class', 'point')
    .attr('cx', pointX)
    .attr('cy', svgHeight)
    .attr('r', 6.5)
    .attr('opacity', 0)
    .attr('fill', d => d.color)
    .transition(t)
    .attr('opacity', 1)
    .attr('cy', pointY);
}

export default draw;
