import React, { useMemo, useCallback, useRef, useState } from 'react';
import { scaleLinear } from '@visx/scale';
import { Group } from '@visx/group';
import { AxisLeft, AxisBottom } from '@visx/axis';
import { LinePath, Line } from '@visx/shape';
import { Grid } from '@visx/grid';
import { extent } from 'd3-array';
import { Text } from '@visx/text';
import { useTooltip } from '@visx/tooltip';
import { localPoint } from '@visx/event';

import DownloadChartButton from '../DownloadChartButton';
import { ModeledData } from '../../domain/ModeledData';

/**
 * Props accepted by the OAT (outside air temperature) chart component.
 */
interface OATChartProps {
  /** Observations to plot; each one is positioned by its `oat` (x) and `rate` (y). */
  data: ModeledData[];
  /** Set of mode numbers whose points and prediction lines are drawn. */
  visibleModes: Set<number>;
  /** When true, points whose `cause` is not 'Clean' are drawn and included in the axis extents. */
  showExcluded: boolean;

  /** Labels for modes; the tooltip shows `modes[point.mode - 1]`. */
  modes: string[];
  /** Palette, picked per point as `colors[(mode - 1) % colors.length]`. */
  colors: string[];

  /** Outer width of the chart's SVG element. */
  width: number;
  /** Outer height of the chart's SVG element. */
  height: number;
}

// Helper types
/** Marker shape used to draw a single observation. */
type PointType =
  | 'circle'
  | 'square'
  | 'cross';

// Helper functions
/**
 * Whether `date` lies less than `daysThreshold` days before the current time.
 *
 * Dates in the future have a negative age and therefore also count as recent;
 * an unparseable date yields NaN and is reported as not recent.
 *
 * @param date - Date string accepted by `Date.parse`.
 * @param daysThreshold - Exclusive upper bound on the age in days (default 32).
 */
const isRecent = (date: string, daysThreshold: number = 32): boolean => {
  const msPerDay = 24 * 60 * 60 * 1000;
  const ageInDays = (Date.now() - Date.parse(date)) / msPerDay;
  return ageInDays < daysThreshold;
};

/**
 * Whether `date` lies less than 365 days before the current time.
 * Future dates count as within the year; an unparseable date yields NaN and returns false.
 *
 * @param date - Date string accepted by `Date.parse`.
 */
const isWithinYear = (date: string): boolean => {
  const elapsedMs = Date.now() - Date.parse(date);
  const elapsedDays = elapsedMs / (24 * 60 * 60 * 1000);
  return elapsedDays < 365;
};

/** Space (px) between the SVG edge and the plotting area. */
const CHART_MARGIN = { top: 40, right: 20, bottom: 80, left: 60 };

/** A drag must span more than this many pixels horizontally to count as a zoom selection. */
const MIN_ZOOM_DRAG_PX = 100;

/**
 * Slack (px) left unclipped around the plotting area, so markers sitting exactly on the
 * edge of the domain are not cut in half by the clip path.
 */
const CLIP_PADDING_PX = 5;

/** Id of the clip path that hides points pushed outside the plotting area by an x-zoom. */
const PLOT_CLIP_ID = 'oat-chart-plot-clip';

/**
 * Convert a d3 `extent` result into a scale domain. `extent` yields `[undefined, undefined]`
 * for empty input; fall back to `[0, 1]` so the scales never produce NaN coordinates.
 */
const toDomain = ([lo, hi]: readonly [number | undefined, number | undefined]): [number, number] =>
  lo === undefined || hi === undefined ? [0, 1] : [lo, hi];

interface SparkLineProps {
  data: number[];
  width: number;
  height: number;
}

/**
 * Tiny inline line chart of `data` against its index, fitted into a `width` x `height` box
 * with a 2px inset. Renders nothing for an empty series.
 *
 * Declared at module scope so React keeps a stable component identity across renders
 * (a component declared inside another component's render is remounted every render).
 */
const SparkLine: React.FC<SparkLineProps> = ({ data, width, height }) => {
  if (!data || data.length === 0) return null;

  const padding = 2;
  const innerWidth = width - 2 * padding;
  const innerHeight = height - 2 * padding;

  const xScale = scaleLinear({
    domain: [0, data.length - 1],
    range: [0, innerWidth],
  });

  const yScale = scaleLinear({
    domain: extent(data) as [number, number],
    range: [innerHeight, 0],
  });

  const points = data.map((value, i) => ({
    x: xScale(i) + padding,
    y: yScale(value) + padding,
  }));

  return (
    <svg width={width} height={height} style={{ display: 'block' }}>
      <LinePath
        data={points}
        x={(p) => p.x}
        y={(p) => p.y}
        stroke="#666"
        strokeWidth={1}
      />
    </svg>
  );
};

interface DataPointProps {
  point: ModeledData;
  /** Marker centre, in the plotting group's pixel coordinates. */
  x: number;
  y: number;
  type: PointType;
  color: string;
}

/**
 * One observation marker. Size is 3 for points within the last year and 1 for older ones;
 * the marker is filled with `color` only when `isRecent` (default 32-day window) holds.
 *
 * Memoised at module scope: its props are primitives plus the stable `point` reference, so
 * hover-driven re-renders of the chart skip re-rendering every marker.
 */
const DataPoint = React.memo(function DataPoint({ point, x, y, type, color }: DataPointProps) {
  const size = isWithinYear(point.ts) ? 3 : 1;
  const fill = isRecent(point.ts) ? color : 'none';

  switch (type) {
    case 'cross':
      // The cross symbol spans 0..8 on both axes, so shift by 4 to centre it on (x, y).
      return <use href="#andreas-cross-oat" x={x - 4} y={y - 4} stroke={color} />;
    case 'square':
      return (
        <rect
          x={x - size / 2}
          y={y - size / 2}
          width={size}
          height={size}
          opacity={1}
          fill={fill}
          stroke={color}
        />
      );
    case 'circle':
    default:
      return <circle cx={x} cy={y} r={size} opacity={0.5} fill={fill} stroke={color} />;
  }
});

const TOOLTIP_BOX_STYLE: React.CSSProperties = {
  position: 'absolute',
  transform: 'translate(-50%, -100%)',
  backgroundColor: 'white',
  padding: '8px',
  borderRadius: '4px',
  boxShadow: '0 2px 4px rgba(0,0,0,0.1)',
  border: '1px solid #ccc',
  fontSize: '12px',
  zIndex: 1000,
  // The box's bottom edge sits at the cursor; without this, moving the mouse up onto it
  // fires the SVG's mouseleave and the tooltip flickers off.
  pointerEvents: 'none',
};
const ALIGN_LEFT: React.CSSProperties = { textAlign: 'left' };
const ALIGN_RIGHT: React.CSSProperties = { textAlign: 'right' };

interface TooltipContentProps {
  point: ModeledData;
  /** Anchor position (px) relative to the chart container; the box is drawn above it. */
  top: number;
  left: number;
  modes: string[];
}

/** Details table for the hovered observation, including its profile sparkline when present. */
const TooltipContent = React.memo(function TooltipContent({ point, top, left, modes }: TooltipContentProps) {
  return (
    <div style={{ ...TOOLTIP_BOX_STYLE, top, left }}>
      <table style={{ borderSpacing: '4px' }}>
        <tbody>
          <tr><th style={ALIGN_LEFT}>Time</th><td>{point.ts}</td></tr>
          <tr><th style={ALIGN_LEFT}>OAT</th><td style={ALIGN_RIGHT}>{point.oat?.toFixed(2)}</td><td style={ALIGN_LEFT}>°C</td></tr>
          <tr><th style={ALIGN_LEFT}>Rate</th><td style={ALIGN_RIGHT}>{point.rate?.toFixed(2)}</td><td style={ALIGN_LEFT}>kW</td></tr>
          <tr><th style={ALIGN_LEFT}>Expected</th><td style={ALIGN_RIGHT}>{point.pred?.toFixed(2)}</td><td style={ALIGN_LEFT}>kW</td></tr>
          <tr><th style={ALIGN_LEFT}>Day</th><td>{modes[point.mode - 1]}</td></tr>
          {point.profile && (
            <tr>
              <th style={{ textAlign: 'left', verticalAlign: 'top' }}>Profile</th>
              <td colSpan={2}>
                <SparkLine data={point.profile} width={100} height={30} />
              </td>
            </tr>
          )}
        </tbody>
      </table>
    </div>
  );
});

/**
 * Scatter plot of power (`rate`, kW) against outside air temperature (`oat`, °C).
 *
 * - Points whose `cause` is 'Clean' are circles, or squares when their timestamp falls on a
 *   Sunday (local time); all other points are crosses and are drawn only when `showExcluded`.
 * - Every visible mode with more than 10 points (excluded points counted, shown or not) also
 *   gets its `pred` values drawn as a line.
 * - Hovering snaps to the drawn point nearest in OAT and shows its details in a tooltip.
 * - Dragging horizontally across more than {@link MIN_ZOOM_DRAG_PX} pixels zooms the x axis to
 *   the selection (in either direction, also while already zoomed); "Reset Zoom" restores it.
 */
const OATChart: React.FC<OATChartProps> = ({
  data,
  modes,
  width,
  height,
  colors,
  visibleModes,
  showExcluded,
}) => {
  const xMax = width - CHART_MARGIN.left - CHART_MARGIN.right;
  const yMax = height - CHART_MARGIN.top - CHART_MARGIN.bottom;

  const [hoveredPoint, setHoveredPoint] = useState<ModeledData | null>(null);
  // Zoomed x domain, or null to show the full data extent.
  const [xDomain, setXDomain] = useState<[number, number] | null>(null);
  // In-progress zoom selection, in plotting-area pixel coordinates.
  const [zoomRect, setZoomRect] = useState<{ startX: number; endX: number } | null>(null);
  const [isDragging, setIsDragging] = useState(false);
  const chartRef = useRef<SVGSVGElement>(null);

  const {
    showTooltip,
    hideTooltip,
    tooltipData,
    tooltipTop = 0,
    tooltipLeft = 0,
  } = useTooltip<ModeledData>();

  const colorForMode = (mode: number): string => colors[(mode - 1) % colors.length];

  // Points that define the axis extents: visible modes only, excluded points only when shown.
  const filteredRawData = useMemo(
    () => data.filter((d) => visibleModes.has(d.mode) && (d.cause === 'Clean' || showExcluded)),
    [data, visibleModes, showExcluded],
  );

  const domains = useMemo(
    () => ({
      xExtent: toDomain(extent(filteredRawData, (d) => d.oat)),
      yExtent: toDomain(extent(filteredRawData, (d) => d.rate)),
    }),
    [filteredRawData],
  );

  const xScale = useMemo(
    () => scaleLinear<number>({
      domain: xDomain ?? domains.xExtent,
      range: [0, xMax],
      nice: true,
    }),
    [xMax, domains.xExtent, xDomain],
  );

  const yScale = useMemo(
    () => scaleLinear<number>({
      domain: domains.yExtent,
      range: [yMax, 0],
      nice: true,
    }),
    [yMax, domains.yExtent],
  );

  // Visible-mode points (any cause) sorted by OAT; feeds the prediction lines and hover lookup.
  // `filter` returns a fresh array, so sorting it in place never mutates the `data` prop.
  const visibleData = useMemo(
    () => data.filter((d) => visibleModes.has(d.mode)).sort((a, b) => a.oat - b.oat),
    [data, visibleModes],
  );

  // Only points that are actually drawn may be hovered; hidden excluded points must not be.
  const hoverableData = useMemo(
    () => (showExcluded ? visibleData : visibleData.filter((d) => d.cause === 'Clean')),
    [visibleData, showExcluded],
  );

  /**
   * Return the hoverable point whose OAT is closest to the OAT under SVG x-coordinate `svgX`,
   * or null when nothing is hoverable. Binary search over the OAT-sorted list; on a tie the
   * lower-OAT neighbour wins.
   */
  const findClosestPoint = useCallback(
    (svgX: number): ModeledData | null => {
      if (hoverableData.length === 0) return null;
      const targetOat = xScale.invert(svgX - CHART_MARGIN.left);

      // Lower bound: first index whose OAT is >= targetOat.
      let lo = 0;
      let hi = hoverableData.length;
      while (lo < hi) {
        const mid = (lo + hi) >>> 1;
        if (hoverableData[mid].oat < targetOat) lo = mid + 1;
        else hi = mid;
      }

      const above = lo < hoverableData.length ? hoverableData[lo] : null;
      const below = lo > 0 ? hoverableData[lo - 1] : null;
      if (!below) return above;
      if (!above) return below;
      return targetOat - below.oat <= above.oat - targetOat ? below : above;
    },
    [hoverableData, xScale],
  );

  const handleMouseMove = useCallback(
    (event: React.MouseEvent<SVGElement>) => {
      const cursor = localPoint(event);
      if (!cursor) return;

      const closest = findClosestPoint(cursor.x);
      if (!closest) return;

      setHoveredPoint(closest);
      showTooltip({
        tooltipData: closest,
        tooltipLeft: cursor.x,
        tooltipTop: cursor.y,
      });
    },
    [findClosestPoint, showTooltip],
  );

  const handleMouseLeave = useCallback(() => {
    hideTooltip();
    setHoveredPoint(null);
    // Abandon any in-progress zoom drag: the matching mouseup may happen outside the SVG
    // and would never reach this component, leaving the drag stuck.
    setIsDragging(false);
    setZoomRect(null);
  }, [hideTooltip]);

  const handleMouseDown = useCallback((event: React.MouseEvent<SVGElement>) => {
    // Only the primary button starts a zoom selection.
    if (event.button !== 0) return;
    const cursor = localPoint(event);
    if (!cursor) return;
    const plotX = cursor.x - CHART_MARGIN.left;
    setZoomRect({ startX: plotX, endX: plotX });
    setIsDragging(true);
  }, []);

  const handleDragMove = useCallback(
    (event: React.MouseEvent<SVGElement>) => {
      if (!isDragging) return;
      const cursor = localPoint(event);
      if (!cursor) return;
      const endX = cursor.x - CHART_MARGIN.left;
      setZoomRect((prev) => (prev ? { ...prev, endX } : null));
    },
    [isDragging],
  );

  const handleDragEnd = useCallback(() => {
    setIsDragging(false);
    setZoomRect(null);
    if (!zoomRect) return;

    const leftPx = Math.min(zoomRect.startX, zoomRect.endX);
    const rightPx = Math.max(zoomRect.startX, zoomRect.endX);
    // Short drags are treated as clicks rather than zoom gestures.
    if (rightPx - leftPx <= MIN_ZOOM_DRAG_PX) return;

    // Invert through the scale currently on screen (which may already be zoomed), then clamp
    // to the data extent so the zoom never widens beyond the data.
    const [fullMin, fullMax] = domains.xExtent;
    const x1 = Math.max(xScale.invert(leftPx), fullMin);
    const x2 = Math.min(xScale.invert(rightPx), fullMax);
    if (x2 > x1) {
      setXDomain([x1, x2]);
    }
  }, [zoomRect, xScale, domains.xExtent]);

  const resetZoom = useCallback(() => {
    setXDomain(null);
  }, []);

  // Split visible-mode points by cause, preserving input order (which sets drawing order).
  const { excludedPoints, cleanPoints } = useMemo(() => {
    const excluded: ModeledData[] = [];
    const clean: ModeledData[] = [];
    for (const d of data) {
      if (!visibleModes.has(d.mode)) continue;
      if (d.cause === 'Clean') clean.push(d);
      else excluded.push(d);
    }
    return { excludedPoints: excluded, cleanPoints: clean };
  }, [data, visibleModes]);

  // One OAT-sorted series per mode, ascending by mode; only modes with enough points get a line.
  const predictionLines = useMemo(() => {
    const byMode = new Map<number, ModeledData[]>();
    for (const row of visibleData) {
      const rows = byMode.get(row.mode);
      if (rows) rows.push(row);
      else byMode.set(row.mode, [row]);
    }
    return Array.from(byMode, ([mode, rows]) => ({ mode, rows }))
      .filter(({ rows }) => rows.length > 10)
      .sort((a, b) => a.mode - b.mode);
  }, [visibleData]);

  // Draw the y axis through OAT = 0, pinned to the nearest plot edge when 0 is out of view.
  const yAxisLeft = Math.min(Math.max(xScale(0), 0), xMax);

  return (
    <div style={{ position: 'relative' }}>
      <DownloadChartButton chartRef={chartRef} filename="oat-chart" />
      <button onClick={resetZoom} style={{ position: 'absolute', top: 10, right: 10 }}>
        Reset Zoom
      </button>
      <svg
        ref={chartRef}
        width={width}
        height={height}
        onMouseMove={handleMouseMove}
        onMouseLeave={handleMouseLeave}
        onMouseDown={handleMouseDown}
        onMouseMoveCapture={handleDragMove}
        onMouseUp={handleDragEnd}
      >
        <defs>
          <symbol id="andreas-cross-oat" >
            <line
              x1="0" y1="0"
              x2="8" y2="8"
              strokeWidth={1.5}
            />
            <line
              x1="0" y1="8"
              x2="8" y2="0"
              strokeWidth={1.5}
            />
          </symbol>
          {/* Coordinates are in the user space of the clipped <g>, i.e. inside the margin group. */}
          <clipPath id={PLOT_CLIP_ID}>
            <rect
              x={-CLIP_PADDING_PX}
              y={-CLIP_PADDING_PX}
              width={Math.max(xMax, 0) + 2 * CLIP_PADDING_PX}
              height={Math.max(yMax, 0) + 2 * CLIP_PADDING_PX}
            />
          </clipPath>
        </defs>

        <Group left={CHART_MARGIN.left} top={CHART_MARGIN.top}>
          <Grid
            xScale={xScale}
            yScale={yScale}
            width={xMax}
            height={yMax}
            stroke="#e0e0e0"
            strokeOpacity={0.5}
          />

          {/* Data layer, clipped so zoomed-out-of-range points stay off the axes and margins. */}
          <g clipPath={`url(#${PLOT_CLIP_ID})`}>
            {showExcluded && excludedPoints.map((d, i) => (
              <DataPoint
                key={`excluded-${i}`}
                point={d}
                x={xScale(d.oat)}
                y={yScale(d.rate)}
                type="cross"
                color={colorForMode(d.mode)}
              />
            ))}

            {/* Clean points: squares on Sundays, circles otherwise. */}
            {cleanPoints.map((d, i) => (
              <DataPoint
                key={`clean-${i}`}
                point={d}
                x={xScale(d.oat)}
                y={yScale(d.rate)}
                type={new Date(d.ts).getDay() === 0 ? 'square' : 'circle'}
                color={colorForMode(d.mode)}
              />
            ))}

            {predictionLines.map(({ mode, rows }) => (
              <LinePath<ModeledData>
                key={`line-${mode}`}
                data={rows}
                x={(d: ModeledData) => xScale(d.oat)}
                y={(d: ModeledData) => yScale(d.pred)}
                stroke={colorForMode(mode)}
                strokeWidth={0.7}
              />
            ))}
          </g>

          <AxisLeft
            scale={yScale}
            left={yAxisLeft}
            label="Power (kW)"
            numTicks={6}
            tickFormat={(value) => `${Math.round(+value)}`}
          />
          <AxisBottom
            scale={xScale}
            top={yMax}
            label="OAT (°C)"
          />

          <Text
            angle={-90}
            width={yMax}
            y={-CHART_MARGIN.left + 15}
            x={-yMax / 2}
            textAnchor="middle"
          >
            Power (kW)
          </Text>

          {tooltipData && (
            <Line
              from={{ x: tooltipLeft - CHART_MARGIN.left, y: 0 }}
              to={{ x: tooltipLeft - CHART_MARGIN.left, y: yMax }}
              stroke="#999"
              strokeWidth={1}
              pointerEvents="none"
            />
          )}

          {/* The group is already offset by the top margin, so the selection starts at y = 0. */}
          {zoomRect && (
            <rect
              x={Math.min(zoomRect.startX, zoomRect.endX)}
              y={0}
              width={Math.abs(zoomRect.endX - zoomRect.startX)}
              height={yMax}
              fill="rgba(0, 128, 255, 0.2)"
              stroke="blue"
              strokeWidth={1}
              strokeDasharray="4 4"
            />
          )}
        </Group>
      </svg>
      {hoveredPoint && (
        <TooltipContent
          point={hoveredPoint}
          top={tooltipTop}
          left={tooltipLeft}
          modes={modes}
        />
      )}
    </div>
  );
};

export default OATChart;
