import React, { useMemo, useState, useRef, useCallback, useEffect } from 'react';
import { scaleTime, scaleLinear, scaleOrdinal } from '@visx/scale';
import { Group } from '@visx/group';
import { AxisLeft, AxisBottom } from '@visx/axis';
import { LinePath, Line } from '@visx/shape';
import { extent, bisector } from 'd3-array';

import { useTooltip } from '@visx/tooltip';
import { localPoint } from '@visx/event';
 
import { GridRows, GridColumns } from '@visx/grid';
import { RawDataItem } from '../../domain/RawDataItem';
import { ModeledData} from '../../domain/ModeledData';
import DownloadChartButton from '../DownloadChartButton';

/** A `ModeledData` row extended with a `date` field holding a `Date` object. */
type ProcessedModeledData = ModeledData & { date: Date };

 
/** Props accepted by the time/mode chart component. */
interface TimeChartProps {
  /** Modeled rows; drawn as one line per mode (using `pred`) plus a dot per row (using `rate`). */
  data: ModeledData[];
  /** Readings drawn as dots (using `rate`); their `rate` values also determine the y-axis domain. */
  cleanedReadings: RawDataItem[];
  /** Mode display names; a row with mode `m` is labelled `modes[m - 1]`. */
  modes: string[];
  /** Colour palette; a row with mode `m` uses `colors[(m - 1) % colors.length]`. */
  colors: string[];
  /** Mode numbers currently shown; rows whose mode is not in this set are filtered out. */
  visibleModes: Set<number>;
  /** Total SVG width in pixels (margins included). */
  width: number;
  /** Total SVG height in pixels (margins included). */
  height: number;
  /** Initial x-axis (time) domain; also the domain restored by the "Reset Zoom" button. */
  xDomain: [Date, Date];

}

/** Plot-area margins in pixels. Hoisted so the object is not rebuilt on every render. */
const MARGIN = { top: 20, right: 20, bottom: 60, left: 40 };

/** Drags narrower than this many pixels are treated as clicks and do not zoom. */
const MIN_ZOOM_WIDTH_PX = 5;

/** Returns `date` unchanged when it is already a `Date`, otherwise parses it with `new Date`. */
function ensureDate(date: Date | string): Date {
  return date instanceof Date ? date : new Date(date);
}

/** Copies each item and adds a `date` field holding the item's `ts` as a `Date` object. */
function processDataPoints<T extends { ts: string | Date }>(items: T[]): (T & { date: Date })[] {
  return items.map(item => ({ ...item, date: ensureDate(item.ts) }));
}

/**
 * Picks whichever of two neighbouring points is nearer in time to `xDate`.
 *
 * Either neighbour may be missing (cursor before the first or after the last
 * point); in that case the other one is returned. Ties go to `before`.
 */
function pickNearestPoint(
  before: ProcessedModeledData | undefined,
  after: ProcessedModeledData | undefined,
  xDate: Date
): ProcessedModeledData | undefined {
  if (!before) return after;
  if (!after) return before;
  const distanceToBefore = xDate.valueOf() - before.date.valueOf();
  const distanceToAfter = after.date.valueOf() - xDate.valueOf();
  return distanceToBefore > distanceToAfter ? after : before;
}

/**
 * Time-series chart of readings and modeled values, coloured by mode.
 *
 * Renders:
 * - a dot for every visible cleaned reading (`rate` over time),
 * - per visible mode, a line through the modeled `pred` values and a dot for each modeled `rate`,
 * - a hover tooltip with a vertical cursor line,
 * - a drag-to-zoom rectangle on the x axis, with a "Reset Zoom" button that restores `xDomain`.
 *
 * The y domain is derived from the visible readings' `rate` values only.
 */
const TimeModeChart: React.FC<TimeChartProps> = ({
  data,
  cleanedReadings,
  modes,
  colors,
  visibleModes,
  width,
  height,
  xDomain,
}) => {
  // Clamp to zero so a very small container never produces negative SVG sizes.
  const innerWidth = Math.max(0, width - MARGIN.left - MARGIN.right);
  const innerHeight = Math.max(0, height - MARGIN.top - MARGIN.bottom);

  const [xDomainState, setXDomainState] = useState<[Date, Date]>(xDomain);

  // Zoom selection, in plot-area pixel coordinates (i.e. relative to the left margin).
  const [zoomRect, setZoomRect] = useState<{ startX: number; endX: number } | null>(null);
  const [isDragging, setIsDragging] = useState(false);

  const chartRef = useRef<SVGSVGElement>(null);

  // The tooltip hook is the single source of truth for the hovered point and cursor position.
  const { showTooltip, hideTooltip, tooltipData, tooltipLeft, tooltipTop } =
    useTooltip<ProcessedModeledData>();

  // Process data once with proper date objects
  const processedData = useMemo(() => processDataPoints(data), [data]);
  const processedCleanedReadings = useMemo(
    () => processDataPoints(cleanedReadings),
    [cleanedReadings]
  );

  // Update xDomainState when xDomain prop changes
  useEffect(() => {
    setXDomainState(xDomain);
  }, [xDomain]);

  // Keep only the rows whose mode is currently visible
  const filteredData = useMemo(
    () => processedData.filter(d => visibleModes.has(d.mode)),
    [processedData, visibleModes]
  );

  const filteredReadings = useMemo(
    () => processedCleanedReadings.filter(d => visibleModes.has(d.mode)),
    [processedCleanedReadings, visibleModes]
  );

  // Combine scale calculations
  const { xScale, yScale } = useMemo(() => {
    const x = scaleTime({
      range: [0, innerWidth],
      domain: xDomainState,
    });

    // `extent` yields [undefined, undefined] for empty input; fall back to a
    // placeholder domain instead of feeding NaN into the scale.
    const [minRate, maxRate] = extent(filteredReadings, d => d.rate) as [
      number | undefined,
      number | undefined
    ];
    const yDomain: [number, number] =
      minRate === undefined || maxRate === undefined
        ? [0, 1]
        : [Math.floor(minRate), Math.ceil(maxRate)];

    const y = scaleLinear({
      range: [innerHeight, 0],
      domain: yDomain,
      nice: true,
    });

    return { xScale: x, yScale: y };
  }, [innerWidth, innerHeight, xDomainState, filteredReadings]);

  // Memoize axis ticks and gridlines together
  const { yTicks, xTicks } = useMemo(
    () => ({
      yTicks: yScale.ticks().filter(Number.isInteger),
      xTicks: xScale.ticks(12),
    }),
    [yScale, xScale]
  );

  // Group data by mode (1..modes.length) in a single pass
  const groupedData = useMemo(() => {
    const groups: { [key: number]: ProcessedModeledData[] } = {};
    for (let mode = 1; mode <= modes.length; mode++) {
      groups[mode] = [];
    }
    for (const row of filteredData) {
      // Rows whose mode has no entry in `modes` are not drawn as a series.
      groups[row.mode]?.push(row);
    }
    return groups;
  }, [filteredData, modes.length]);

  const bisectDate = useMemo(
    () => bisector<ProcessedModeledData, Date>(d => d.date).left,
    []
  );

  const handleMouseMove = useCallback(
    (event: React.MouseEvent<SVGElement>) => {
      const point = localPoint(event);
      if (!point || filteredData.length === 0) return;

      const xDate = xScale.invert(point.x - MARGIN.left);
      // NOTE(review): bisecting assumes `data` arrives sorted by `ts` ascending — confirm with caller.
      const index = bisectDate(filteredData, xDate, 1);

      // `filteredData` holds visible modes only, so no extra visibility check is needed.
      // The right-hand neighbour is missing when the cursor is past the last point.
      const closestPoint = pickNearestPoint(filteredData[index - 1], filteredData[index], xDate);
      if (!closestPoint) return;

      showTooltip({
        tooltipData: closestPoint,
        tooltipLeft: point.x,
        tooltipTop: point.y,
      });
    },
    [showTooltip, xScale, filteredData, bisectDate]
  );

  const handleMouseLeave = useCallback(() => {
    hideTooltip();
    // Abort any drag in progress: the mouse-up would otherwise happen outside
    // the SVG and leave the selection rectangle stuck to the cursor.
    setIsDragging(false);
    setZoomRect(null);
  }, [hideTooltip]);

  const resetZoom = useCallback(() => {
    setXDomainState(xDomain);
  }, [xDomain]);

  // Zoom functionality
  const handleMouseDown = useCallback((event: React.MouseEvent<SVGElement>) => {
    const point = localPoint(event);
    if (point) {
      const plotX = point.x - MARGIN.left;
      setZoomRect({ startX: plotX, endX: plotX });
      setIsDragging(true);
    }
  }, []);

  const handleDragMove = useCallback(
    (event: React.MouseEvent<SVGElement>) => {
      if (!isDragging) return;
      const point = localPoint(event);
      if (point) {
        const plotX = point.x - MARGIN.left;
        setZoomRect(prev => (prev ? { ...prev, endX: plotX } : null));
      }
    },
    [isDragging]
  );

  const handleDragEnd = useCallback(() => {
    setIsDragging(false);
    setZoomRect(null);
    if (!zoomRect) return;

    const leftPx = Math.min(zoomRect.startX, zoomRect.endX);
    const rightPx = Math.max(zoomRect.startX, zoomRect.endX);
    // A plain click (or a tiny jitter) must not collapse the domain to zero width.
    if (rightPx - leftPx < MIN_ZOOM_WIDTH_PX) return;

    // `filteredReadings` is already limited to visible modes. `extent` copes
    // with empty input and skips NaN timestamps, unlike Math.min(...spread).
    const [timeMin, timeMax] = extent(filteredReadings, d => d.date.getTime());
    if (timeMin === undefined || timeMax === undefined) return;

    // Constrain zoom to the visible data range
    const fromMs = Math.max(xScale.invert(leftPx).getTime(), timeMin);
    const toMs = Math.min(xScale.invert(rightPx).getTime(), timeMax);
    // Selection entirely outside the data range (or NaN): keep the current domain.
    if (!(fromMs < toMs)) return;

    setXDomainState([new Date(fromMs), new Date(toMs)]);
  }, [xScale, zoomRect, filteredReadings]);

  /** Palette colour for a 1-based mode number. */
  const colorForMode = (mode: number) => colors[(mode - 1) % colors.length];

  /** Renders one data dot, or nothing when either coordinate is not a finite number. */
  const renderDot = (key: string | number, cx: number, cy: number, color: string | undefined) =>
    Number.isFinite(cx) && Number.isFinite(cy) ? (
      <circle key={key} cx={cx} cy={cy} r={3} stroke={color} fill={color} />
    ) : null;

  // Decide the tooltip side once, from the raw cursor position, so the box is
  // never flipped back over the cursor near the middle of the chart.
  const tooltipOnRightHalf = tooltipLeft !== undefined && tooltipLeft >= width / 2;

  return (
    <div style={{ position: 'relative' }}>
      <DownloadChartButton chartRef={chartRef} filename="time-chart" />
      <button onClick={resetZoom} style={{ position: 'absolute', top: -10, right: 10 }}>
        Reset Zoom
      </button>
      <svg
        ref={chartRef}
        width={width}
        height={height}
        onMouseMove={handleMouseMove}
        onMouseLeave={handleMouseLeave}
        onMouseDown={handleMouseDown}
        onMouseMoveCapture={handleDragMove}
        onMouseUp={handleDragEnd}
      >
        <Group left={MARGIN.left} top={MARGIN.top}>
          <GridRows
            scale={yScale}
            width={innerWidth}
            height={innerHeight}
            stroke="#e0e0e0"
            strokeOpacity={0.5}
            tickValues={yTicks}
          />
          <GridColumns
            scale={xScale}
            width={innerWidth}
            height={innerHeight}
            stroke="#e0e0e0"
            strokeOpacity={0.5}
            tickValues={xTicks}
          />

          {/* Cleaned readings */}
          {filteredReadings.map((d, i) =>
            renderDot(i, xScale(d.date), yScale(d.rate), colorForMode(d.mode))
          )}

          {/* Modeled series: one predicted line plus observed dots per mode */}
          {Object.entries(groupedData).map(([modeStr, rows]) => {
            const mode = Number(modeStr);
            if (mode < 1 || mode > modes.length) return null;
            const color = colorForMode(mode);

            return (
              <React.Fragment key={modeStr}>
                <LinePath
                  data={rows}
                  x={d => xScale(d.date)}
                  y={d => yScale(d.pred)}
                  stroke={color}
                  strokeWidth={1}
                />
                {rows.map((d, i) =>
                  renderDot(`${modeStr}-${i}`, xScale(d.date), yScale(d.rate), color)
                )}
              </React.Fragment>
            );
          })}

          <AxisBottom top={innerHeight} scale={xScale} tickValues={xTicks} />
          <AxisLeft
            scale={yScale}
            numTicks={6}
            tickFormat={value => `${Math.round(+value)}`}
            tickValues={yTicks}
          />

          {/* Vertical cursor line at the hovered x position */}
          {tooltipData && tooltipLeft !== undefined && (
            <Line
              from={{ x: tooltipLeft - MARGIN.left, y: 0 }}
              to={{ x: tooltipLeft - MARGIN.left, y: innerHeight }}
              stroke="#999"
              strokeWidth={1}
              pointerEvents="none"
            />
          )}

          {/* Zoom rectangle */}
          {zoomRect && (
            <rect
              x={Math.min(zoomRect.startX, zoomRect.endX)}
              y={0}
              width={Math.abs(zoomRect.endX - zoomRect.startX)}
              height={innerHeight}
              fill="rgba(0, 128, 255, 0.2)"
              stroke="blue"
              strokeWidth={1}
              strokeDasharray="4 4"
            />
          )}
        </Group>

        <text
          x={-height / 2}
          y={15}
          transform="rotate(-90)"
          textAnchor="middle"
          fontSize={12}
          fill="#000000"
        >
          Power (kW)
        </text>
      </svg>

      {tooltipData && tooltipLeft !== undefined && tooltipTop !== undefined && (
        <div
          style={{
            position: 'absolute',
            top: tooltipTop,
            left: tooltipOnRightHalf ? tooltipLeft - 10 : tooltipLeft + 60,
            transform: `translate(${tooltipOnRightHalf ? '-100%' : '0%'}, -50%)`,
            backgroundColor: 'rgba(255, 255, 255, 0.8)',
            padding: '8px',
            borderRadius: '4px',
            boxShadow: '0 1px 10px rgba(0,0,0,0.1)',
            pointerEvents: 'none',
            display: 'grid',
            gridTemplateColumns: 'auto auto',
            gap: '4px 8px',
          }}
        >
          <div style={{ textAlign: 'left' }}><strong>Date</strong></div>
          <div style={{ textAlign: 'right' }}>
            {`${tooltipData.date.toLocaleDateString()} (${tooltipData.date.toLocaleDateString('en-US', { weekday: 'long' })})`}
          </div>
          <div style={{ textAlign: 'left' }}><strong>Mode</strong></div>
          <div style={{ textAlign: 'right' }}>{modes[(tooltipData.mode || 1) - 1]}</div>
          <div style={{ textAlign: 'left' }}><strong>Rate</strong></div>
          <div style={{ textAlign: 'right' }}>{tooltipData.rate?.toFixed(2)}</div>
          <div style={{ textAlign: 'left' }}><strong>Expected</strong></div>
          {/* Optional chaining keeps a missing prediction from crashing the tooltip */}
          <div style={{ textAlign: 'right' }}>{tooltipData.pred?.toFixed(2)}</div>
        </div>
      )}
    </div>
  );
};

export default TimeModeChart;