import * as d3 from 'd3';
import { takeDistributed } from 'src/utils/array';

/**
 * Measures the drawable area of an SVG once the given margins are applied.
 *
 * A temporary `.reference` rect is joined into the SVG, positioned and sized
 * through CSS (`translate()` / `calc()`), measured with
 * `getBoundingClientRect()` and removed again before returning.
 *
 * @param {object} params
 * @param {object} params.svg - d3 selection wrapping the target SVG element.
 * @param {object} params.margin - `{ top, right, bottom, left }`; the values
 *   are interpolated verbatim into CSS, so they are presumably CSS length
 *   strings such as `'16px'` or `'1em'` — confirm against callers.
 * @returns {object} `{}` when the selection holds no element, otherwise
 *   `{ width, height, top, left, right, bottom, margin }` where the edges are
 *   offsets from the top-left corner of the SVG's bounding box and `margin`
 *   holds the measured distance to each side of the SVG.
 */
function getCanvas({ svg, margin }) {
  const svgNode = svg.node();
  if (!svgNode) return {};

  const svgBox = svgNode.getBoundingClientRect();

  const reference = svg
    .selectAll('.reference')
    .data([0])
    .join('rect')
    .attr('class', 'reference')
    .style('fill', '#ffffff11')
    .style('stroke', '#cccccc33')
    // The closing parenthesis was missing here; the value only worked through
    // the CSS parser's error recovery.
    .style('transform', `translate(${margin.left}, ${margin.top})`)
    .attr('width', `calc(${svgBox.width}px - ${margin.left} - ${margin.right})`)
    .attr('height', `calc(${svgBox.height}px - ${margin.top} - ${margin.bottom})`);

  const box = reference.node().getBoundingClientRect();

  // Convert viewport coordinates into offsets relative to the SVG itself.
  const top = box.top - svgBox.top;
  const left = box.left - svgBox.left;
  const right = left + box.width;
  const bottom = top + box.height;

  // The rect exists only to be measured.
  reference.remove();

  return {
    width: box.width,
    height: box.height,
    top,
    left,
    right,
    bottom,
    margin: {
      left,
      top,
      right: svgBox.width - right,
      bottom: svgBox.height - bottom,
    },
  };
}

/**
 * Builds a band scale with one band per consecutive vintage and a bottom axis
 * that prints a label only for a subset of those bands.
 *
 * @param {object} params
 * @param {Array<object>} params.data - Entries whose `vintage` is read;
 *   vintages that do not coerce to a finite number count as 0.
 * @param {number[]} params.range - `[start, end]` pixel range of the scale.
 * @returns {{xAxis: Function, xScale: Function}} The d3 axis and band scale.
 */
function getXAxis({ data, range }) {
  const toVintage = d => {
    const year = +d.vintage;
    return Number.isFinite(year) ? year : 0;
  };
  const [firstVintage, lastVintage] = d3.extent(data, toVintage);

  // One string entry for every value from the first to the last vintage.
  const bandCount = lastVintage - firstVintage + 1;
  const domain = Array.from(new Array(bandCount), (_, offset) => `${firstVintage + offset}`);

  // Wider ranges may show more labels, but never more than there are bands.
  const availableWidth = range[1] - range[0];
  let maxTicks = 2;
  if (availableWidth > 600) {
    maxTicks = 8;
  } else if (availableWidth > 340) {
    maxTicks = 5;
  } else if (availableWidth > 220) {
    maxTicks = 3;
  }
  const tickCount = Math.min(maxTicks, domain.length);

  // takeDistributed is defined elsewhere; presumably it picks `tickCount`
  // evenly spread entries — confirm in src/utils/array.
  const tickLabels = takeDistributed(tickCount)(domain);
  const labelFor = band => (tickLabels.includes(band) ? band : '');

  const xScale = d3.scaleBand().domain(domain).range(range).padding(0.1);
  const xAxis = d3.axisBottom(xScale).tickFormat(labelFor);

  return { xAxis, xScale };
}

/**
 * Builds the linear y scale and a right-oriented axis whose ticks are
 * `width` long.
 *
 * @param {object} params
 * @param {number[]} params.range - Pixel range handed to the scale.
 * @param {number[]} params.domain - Value domain handed to the scale.
 * @param {number} params.width - Tick size of the axis.
 * @param {number} params.height - Decides the tick count: 8 above 200, else 4.
 * @param {Function} params.yAxisFormatter - Tick label formatter.
 * @returns {{yAxis: Function, yScale: Function}} The d3 axis and linear scale.
 */
function getYAxis({ range, domain, width, height, yAxisFormatter }) {
  const tickCount = height > 200 ? 8 : 4;

  const yScale = d3.scaleLinear();
  yScale.domain(domain);
  yScale.range(range);

  // d3 setters return the object they configure, so separate statements
  // configure the same axis as a chained expression would.
  const yAxis = d3.axisRight(yScale);
  yAxis.tickSize(width);
  yAxis.tickFormat(yAxisFormatter);
  yAxis.ticks(tickCount);

  return { yAxis, yScale };
}

/**
 * Builds an ordinal colour scale keyed by strategy id.
 *
 * Distinct `strategyId` values are collected in first-seen order and each is
 * given a colour interpolated between the two ends of `colorRange` according
 * to its position. Ids outside the domain resolve to `'#ccc'`.
 *
 * @param {object} params
 * @param {Array<object>} params.data - Entries with an `items` array whose
 *   elements carry a `strategyId`.
 * @param {string[]} [params.colorRange] - Start and end colours.
 * @returns {Function} d3 ordinal scale from strategy id to colour.
 */
function getColorScale({ data, colorRange = ['#18bfed', '#0f3a45'] }) {
  const strategyIds = new Set();
  data.forEach(d => {
    d.items.forEach(item => strategyIds.add(item.strategyId));
  });
  const domain = [...strategyIds];

  // Interpolate over the positions 0..n-1 rather than over the ids themselves.
  const positions = Array.from(domain.keys());
  const interpolate = d3.scaleLinear().domain(d3.extent(positions)).range(colorRange);
  const colors = positions.map(position => interpolate(position));

  const ordinal = d3.scaleOrdinal();
  return ordinal.domain(domain).range(colors).unknown('#ccc');
}

/**
 * Prepares stacked-bar segments for every entry in `data`.
 *
 * Hidden items are dropped; each remaining item is copied and given `start`
 * and `end`, the cumulative `commitmentAmount` before and after that item.
 * The input objects are not modified.
 *
 * @param {object} params
 * @param {Array<object>} params.data - Entries with an `items` array whose
 *   elements carry `commitmentAmount` and optionally `hidden`.
 * @returns {Array<object>} Copies of the entries with `items` replaced by the
 *   visible items extended with `start` and `end`.
 */
function getBarData({ data }) {
  return data.map(d => {
    const items = [];
    // A running total replaces re-summing every preceding item for each bar.
    let runningTotal = 0;

    for (const item of d.items) {
      if (item.hidden) continue;

      const amount = +item.commitmentAmount;
      const start = runningTotal;
      const end = amount + start;

      // Zero and NaN amounts leave the total untouched (as d3.sum does), so a
      // non-numeric amount only affects its own `end`, not the bars after it.
      if (amount) runningTotal += amount;

      items.push({ ...item, start, end });
    }

    return { ...d, items };
  });
}

/**
 * Builds one legend entry per distinct strategy.
 *
 * Entries appear in the order their `strategyId` is first seen, and the label
 * is the `strategy` of that first item.
 *
 * @param {object} params
 * @param {Array<object>} params.data - Entries with an `items` array whose
 *   elements carry `strategyId` and `strategy`.
 * @returns {Array<{key: *, label: *}>} Legend entries.
 */
function getLegendData({ data }) {
  // A Map keeps first-seen order and resolves every id in one pass, instead
  // of searching all items again for each distinct id. It also matches NaN
  // keys, which a `===` search can never find.
  const labelsByStrategyId = new Map();

  for (const { items } of data) {
    for (const { strategyId, strategy } of items) {
      if (!labelsByStrategyId.has(strategyId)) {
        labelsByStrategyId.set(strategyId, strategy);
      }
    }
  }

  return Array.from(labelsByStrategyId, ([key, label]) => ({ key, label }));
}

/**
 * Serialises coordinate arrays into a `points` attribute string: the values
 * of each point are joined with commas and the points with spaces.
 *
 * @param {Array<Array<number>>} [arr] - Points to serialise.
 * @returns {string} For example `'0,0 18,0 0,18'`.
 */
function polyPoints(arr = [[]]) {
  const serializePoint = point => point.join(',');
  const pairs = arr.map(serializePoint);
  return pairs.join(' ');
}

/**
 * Creates a renderer that draws a stacked bar chart of commitment amounts per
 * vintage, with a two-tone strategy legend underneath.
 *
 * @param {object} params
 * @param {Array<object>} params.data - Chart entries; this function reads
 *   `vintage` and `totalCommitmentAmount` from each entry and passes the
 *   whole array to the legend, bar, axis and colour helpers.
 * @param {Function} params.yAxisFormatter - Tick formatter for the y axis.
 * @returns {Function} Renderer taking `{ svg, width, height, margin }`, where
 *   `svg` is a d3 selection. It returns nothing and does nothing when there
 *   is no data, no size, or no element to draw into.
 */
function drawCalendarChart({ data, yAxisFormatter }) {
  return function ({ svg, width, height, margin }) {
    if (!data) return;
    if (data.length < 1) return;
    if (!width || !height) return;

    const legendData = getLegendData({ data });
    const barData = getBarData({ data });

    const canvas = getCanvas({ svg, margin });
    // getCanvas yields an empty object when the selection holds no element;
    // every measurement below would be NaN, so stop here.
    if (canvas.width === undefined) return;

    const legendColumns = width > 300 ? 2 : 1;
    const legendRowHeight = 24;
    // Items are placed row by row, so a partly filled last row still takes a
    // full row of height.
    const legendRows = Math.ceil(legendData.length / legendColumns);
    const legendHeight = legendRows * legendRowHeight;
    const legendPosition = canvas.bottom - legendHeight + 10;

    // The bars get whatever height the legend leaves free.
    const maxBarHeight = canvas.height - legendHeight;
    const xAxisPosition = maxBarHeight + 3;
    const maxCommitmentAmount = d3.max(data, d => d.totalCommitmentAmount);

    const yAxisWidth = 16 * 2;

    const { xScale, xAxis } = getXAxis({
      data,
      // Bars start to the right of the space reserved for y-axis labels.
      range: [canvas.left + yAxisWidth, canvas.right],
      width,
    });

    const { yScale, yAxis } = getYAxis({
      range: [canvas.top, maxBarHeight],
      // Inverted domain: the largest amount maps to the top of the range.
      domain: [maxCommitmentAmount, 0],
      width: canvas.width,
      height: maxBarHeight,
      yAxisFormatter,
    });

    // Same strategy domain in two palettes: purple for target funds, blue for
    // everything else.
    const blueScale = getColorScale({
      data,
      colorRange: ['#18bfed', '#0f3a45'],
    });

    const purpleScale = getColorScale({
      data,
      colorRange: ['#b049fd', '#45008a'],
    });

    svg.style('font-size', `0.8em`);

    svg
      .selectAll('.x-axis')
      .data([data])
      .join('g')
      .attr('class', 'x-axis')
      .style('transform', `translateY(${xAxisPosition}px)`)
      .style('font-size', '1em')
      .call(xAxis);

    svg
      .selectAll('.y-axis')
      .data([data])
      .join('g')
      .attr('class', 'y-axis')
      .style('transform', `translateX(${canvas.left}px)`)
      .style('font-size', '1em')
      .call(yAxis)
      // Move the tick labels to the axis origin, just above their grid line.
      .call(g => g.selectAll('text').attr('x', 0).attr('y', -8));

    // NOTE(review): this value has no unit, so browsers are expected to reject
    // it and leave the group untranslated. The x range above already includes
    // yAxisWidth, so adding 'px' would shift bars and legend a second time.
    // Confirm the intended behaviour before changing it.
    const root = svg
      .selectAll('.root')
      .data([true])
      .join('g')
      .attr('class', 'root')
      .style('transform', `translateX(${yAxisWidth})`);

    root
      .selectAll('.column-group')
      .data(barData)
      .join('g')
      .attr('class', d => `column-group column-group-${d.vintage}`)
      .selectAll('.bar')
      .data(d => d.items)
      .join('rect')
      .attr('class', d => `bar bar-${d.key}`)
      .style('fill', d => {
        return d.type === 'targetFund' ? purpleScale(d.strategyId) : blueScale(d.strategyId);
      })
      .attr('x', d => Math.max(xScale(d.vintage), 0))
      .attr('y', d => Math.max(yScale(d.end), 0))
      .attr('width', Math.max(xScale.bandwidth(), 1))
      .attr('height', d => {
        return Math.max(yScale(d.start) - yScale(d.end), 0);
      });

    root
      .selectAll('.legend')
      .data([legendData])
      .join('g')
      .attr('class', 'legend')
      .style('transform', `translate(${canvas.left}px, ${legendPosition + 10}px )`)
      .selectAll('.legend-item')
      .data(legendData)
      .join('g')
      .attr('class', d => `legend-item legend-item-${d.key}`)
      .style('transform', (_, i) => {
        const row = Math.floor(i / legendColumns) * legendRowHeight;
        const column = (i % legendColumns) * (canvas.width / legendColumns);
        return `translate(${column}px, ${row}px )`;
      })
      .call(selection => {
        selection
          .selectAll('text')
          .data(d => [d])
          .join('text')
          .text(d => d.label)
          .attr('x', '2em')
          .attr('y', '0.75em')
          .style('fill', 'currentColor')
          .style('alignment-baseline', 'middle')
          .style('font-size', '1em');

        // Two triangles form one square swatch: the upper-left half in the
        // blue palette, the lower-right half in the purple palette.
        selection
          .selectAll('polygon')
          .data(d => [d.key, d.key])
          .join('polygon')
          .attr('fill', (d, i) => {
            return i === 0 ? blueScale(d) : purpleScale(d);
          })
          .attr('points', (_, i) => {
            const size = 18;
            return i === 0
              ? polyPoints([
                  [0, 0],
                  [size, 0],
                  [0, size],
                ])
              : polyPoints([
                  [size, 0],
                  [size, size],
                  [0, size],
                ]);
          });
      });
  };
}

export default drawCalendarChart;
