import React, { useEffect, useRef, useState, useMemo, useCallback } from 'react';
import * as d3 from 'd3';
import { isEqual } from 'lodash';
import { motion } from 'framer-motion';

import Icon from './Icon';
import { TooltipButton } from './Buttons';

const PapersGraph = ({
  references=[], 
  recommendations = [], 
  graphParams = {}, 
  selectedPaper, 
  setSelectedPaper, 
  stackTitle=false, 
  title='Citations graph' ,
  hideAttributeSelection=false,
  wide=false
}) => {
  const svgRef = useRef(null);
  const containerRef = useRef(null);
  const [dimensions, setDimensions] = useState({ width: 0, height: 0 });
  const [hasSimulated, setHasSimulated] = useState(false);
  const [currentTransform, setCurrentTransform] = useState(d3.zoomIdentity.scale(0.3));
  const [isSimulationRunning, setIsSimulationRunning] = useState(false);
  const [nodeSizeAttribute, setNodeSizeAttribute] = useState('default');
  const [refreshKey, setRefreshKey] = useState(0);
  const [showReferences, setShowReferences] = useState(true);

  // Modify the nodePositionsRef to reset on refresh
  const nodePositionsRef = useRef({});

  const allPapers = [...references, ...recommendations];
  
  const simDuration = allPapers.length > 300 ? 5000 : 3000;

  const getNodeSizeValue = (node) => {
    switch (nodeSizeAttribute) {
      case 'score':
        return node.score;
      case 'n_citations':
        return node.n_citations;
      default:
        return node.linkCount;
    }
  };

  const memoizedData = useMemo(() => {
    if ((references.length === 0) && (recommendations.length === 0)) {
      return { nodes: [], links: [], nodeSizeScale: null };
    }

    const averageRecommendationScore = recommendations.length > 0
      ? recommendations.reduce((sum, rec) => sum + (rec.score || 0), 0) / recommendations.length
      : 0.5;

    const nodes = allPapers
      .filter(paper => showReferences || !references.includes(paper))
      .map(paper => ({ 
        id: paper.paperId, 
        group: references.includes(paper) ? 'reference' : 'recommendation',
        linkCount: 0,
        score: references.includes(paper) 
          ? (recommendations.length > 0 ? averageRecommendationScore : 0.5) 
          : paper.score,
        n_citations: paper.n_citations || 0,
        ...(nodePositionsRef.current[paper.paperId] || {})
      }));

    const nodeIds = new Set(nodes.map(node => node.id));

    let links = [];
    nodes.forEach(paper => {
      const paperData = allPapers.find(p => p.paperId === paper.id);
      if (paperData) {
        paperData.references.forEach(refId => {
          if (nodeIds.has(refId)) {
            links.push({ source: paper.id, target: refId });
            paper.linkCount++;
            nodes.find(n => n.id === refId).linkCount++;
          }
        });
      }
    });


    const values = nodes.map(getNodeSizeValue);
    const sortedValues = values.sort((a, b) => a - b);
    const p5 = d3.quantile(sortedValues, 0.05);
    const p95 = d3.quantile(sortedValues, 0.95);

    const nodeSizeScale = d3.scaleLinear()
      .domain([p5, p95])
      .range([0, 1])
      .clamp(true);

    const dataSignature = { 
      nodeIds: nodes.map(n => n.id).sort().join(','), 
      linkIds: links.map(l => `${l.source}-${l.target}`).sort().join(','),
      nodeSizeAttribute,
    };
    return { nodes, links, nodeSizeScale, dataSignature };
  }, [allPapers, nodeSizeAttribute, refreshKey, showReferences, getNodeSizeValue]);



  const { nodes, links, nodeSizeScale, dataSignature } = memoizedData;

  const handleNodeClick = useCallback((d) => {    
    const  paper = allPapers.find(p => p.paperId === d.id);

    if (paper && setSelectedPaper) {
      setSelectedPaper(paper);
    }

  }, [setSelectedPaper]);


  // Update dimensions on resize
  useEffect(() => {
    const updateDimensions = () => {
      if (containerRef.current) {
        const { width, height } = containerRef.current.getBoundingClientRect();
        setDimensions({ width, height });
      }
    };

    updateDimensions();
    window.addEventListener('resize', updateDimensions);

    return () => window.removeEventListener('resize', updateDimensions);
  }, []);

  const refreshSimulation = useCallback(() => {
    if (isSimulationRunning) {
      return;
    }
    setHasSimulated(false);
    setIsSimulationRunning(true);
    setRefreshKey(prevKey => prevKey + 1);
    // Clear stored node positions
    nodePositionsRef.current = {};
    setShowReferences(true); // Reset showReferences to true when refreshing
  }, [isSimulationRunning]);

  const getNodeRadius = useCallback((d) => {
    const { minNodeRadius = 8, maxNodeRadius = 20 } = graphParams;
    const referenceRadius = minNodeRadius + (maxNodeRadius - minNodeRadius) /  2;
    
    if (references.some(ref => ref.paperId === d.id)) {
      return referenceRadius;
    }
    
    return minNodeRadius + (maxNodeRadius - minNodeRadius) * nodeSizeScale(getNodeSizeValue(d));
  }, [graphParams, references, nodeSizeScale, getNodeSizeValue]);

  // Create and update the graph
  useEffect(() => {
    if (dimensions.width === 0 || dimensions.height === 0 || !nodeSizeScale) {
      return;
    }


    const {
      width = dimensions.width,
      height = dimensions.height,
      chargeStrength = -140,
      centerForceStrength = 1,
      linkDistance = 50,
      referenceColor = '#7a96b2',
      recommendationColor ='#de5b7e' ,
      orphanColor = '#2ca02c',
      linkThickness = 1.5,
      linkColor = '#757575',
      linkOpacity = 1,
      selectedPaperColor = 'white',
    } = graphParams;

    const orphanedNodes = nodes.filter(node => node.linkCount === 0);
    const connectedNodes = nodes.filter(node => node.linkCount > 0);

    const svg = d3.select(svgRef.current)
      .attr('width', width)
      .attr('height', height);

    svg.selectAll('*').remove();

    const g = svg.append('g')
      .attr('transform', currentTransform);

    const link = g.append('g')
      .selectAll('line')
      .data(links)
      .join('line')
      .attr('stroke', linkColor)
      .attr('stroke-width', linkThickness)
      .attr('stroke-opacity', linkOpacity);

    const node = g.append('g')
      .selectAll('circle')
      .data(nodes, d => d.id)  // Use key function to maintain node identity
      .join(
        enter => enter.append('circle')
          .attr('r', getNodeRadius)
          .attr('fill', d => {
            if (d.id === selectedPaper?.paperId) return selectedPaperColor;
            if (d.group === 'recommendation') return recommendationColor;
            if (d.group === 'reference') return referenceColor;
            return orphanColor;
          })
          .attr('stroke', d => d.id === selectedPaper?.paperId ? (d.group === 'recommendation' ? recommendationColor : referenceColor) : 'white')
          .attr('stroke-width', d => d.id === selectedPaper?.paperId ? 3 : 1)
          .attr('cx', d => d.x || Math.random() * width)
          .attr('cy', d => d.y || Math.random() * height),
        update => update
          .attr('fill', d => {
            if (d.id === selectedPaper?.paperId) return selectedPaperColor;
            if (d.group === 'recommendation') return recommendationColor;
            if (d.group === 'reference') return referenceColor;
            return orphanColor;
          })
          .attr('stroke', d => d.id === selectedPaper?.paperId ? (d.group === 'recommendation' ? recommendationColor : referenceColor) : 'white')
          .attr('stroke-width', d => d.id === selectedPaper?.paperId ? 3 : 1),
        exit => exit.remove()
      )
      .on('click', (event, d) => {
        if (!isSimulationRunning) {
          event.stopPropagation();
          handleNodeClick(d);
        }
      })
      .on('mouseover', function(event, d) {
        if (!isSimulationRunning) {
          d3.select(this).attr('stroke', d.id === selectedPaper?.paperId ? (d.group === 'recommendation' ? recommendationColor : referenceColor) : 'black');
        }
      })
      .on('mouseout', function(event, d) {
        if (!isSimulationRunning) {
          d3.select(this).attr('stroke', d.id === selectedPaper?.paperId ? (d.group === 'recommendation' ? recommendationColor : referenceColor) : 'white');
        }
      });

    // Prevent zoom on node click
    node.on("mousedown.zoom", null);
    node.on("touchstart.zoom", null);

    // Define updateGraph function after node and link are created
    function updateGraph() {
      node.attr('cx', d => d.x)
         .attr('cy', d => d.y)
         .attr('r', getNodeRadius);

      link.attr('x1', d => d.source.x)
          .attr('y1', d => d.source.y)
          .attr('x2', d => d.target.x)
          .attr('y2', d => d.target.y);
      
      // Store node positions
      nodes.forEach(d => {
        nodePositionsRef.current[d.id] = { x: d.x, y: d.y };
      });
    }

    const simulation = d3.forceSimulation(nodes)
      .force('link', d3.forceLink(links).id(d => d.id).distance(linkDistance))
      .force('charge', d3.forceManyBody().strength(d => 
        orphanedNodes.includes(d) ? chargeStrength / 4 : (chargeStrength * 1.5)
      ))
      .force('center', d3.forceCenter(width / 2, height / 2).strength(centerForceStrength))
      .force('collision', d3.forceCollide().radius(d => getNodeRadius(d) * 2).strength(0.8))
      .force('x', d3.forceX(width / 2).strength(d => orphanedNodes.includes(d) ? 0.02 : 0.01))
      .force('y', d3.forceY(height / 2).strength(d => orphanedNodes.includes(d) ? 0.02 : 0.01));

    // Add radial force for connected nodes
    simulation.force('radial', d3.forceRadial(Math.min(width, height) / 3, width / 2, height / 2)
      .strength(d => connectedNodes.includes(d) ? 0.1 : 0.05));

    // // Add clustering force for connected nodes
    // const clusterPadding = 50;
    // const clusters = d3.group(connectedNodes, d => d.group);
    
    // simulation.force('cluster', alpha => {
    //   connectedNodes.forEach(d => {
    //     const cluster = clusters.get(d.group);
    //     if (cluster) {
    //       const centroid = d3.mean(cluster, n => n.x);
    //       const k = alpha * 1;
    //       d.vx -= (d.x - centroid) * k * 0;
    //     }
    //   });
    // });

    if (!hasSimulated) {
      setIsSimulationRunning(true);
      simulation.on('tick', updateGraph);
      simulation.alpha(1).restart();

      setTimeout(() => {
        simulation.stop();
        setHasSimulated(true);
        setIsSimulationRunning(false);
      }, simDuration);
    } else {
      // If simulation has already run, just update the graph with stored positions
      updateGraph();
    }

    const zoom = d3.zoom()
      .scaleExtent([0.1, 4])
      .on('zoom', (event) => {
        g.attr('transform', event.transform);
        setCurrentTransform(event.transform);
      });

    svg.call(zoom)
      .call(zoom.transform, currentTransform)
      .on("dblclick.zoom", null)
      .on("click.zoom", null);

  
    // Show zoom buttons on mobile
    const showZoomButtonsOnMobile = () => {
      if (window.innerWidth <= 768) {
        d3.selectAll('.zoom-buttons').style('display', null);
      } else {
        d3.selectAll('.zoom-buttons').style('display', 'none');
      }
    };

    showZoomButtonsOnMobile();

    // Cleanup function
    return () => {
      simulation.stop();
      window.removeEventListener('resize', showZoomButtonsOnMobile);
    };
  }, [dimensions, nodeSizeScale, handleNodeClick, references, recommendations, graphParams, selectedPaper, hasSimulated, dataSignature, currentTransform, nodeSizeAttribute, refreshKey, getNodeRadius]);

  const toolTipButtonStyle = "text-xs text-gray-500 hover:text-gray-700 border-b  hover:border-secondary/50";
  const activeTooTipButtonStyle = "text-secondary border-secondary";

  return (
    <motion.div 
      className="group"
      initial={{ opacity: 0, y: 20 }}
      animate={{ opacity: 1, y: 0 }}
      exit={{ opacity: 0, y: 20 }}
      transition={{ duration: 0.5 }}
    >
      <div className={`flex justify-between items-center mb-3 mt-1 w-10/12 lg:full max-w-7xl mx-auto ${stackTitle ? 'flex-col' : ''}`}>
        <p className={`text-sm font-semibold ${stackTitle ? 'mb-1' : ''}`}><Icon icon="chart-network" className="mr-1" /> {title}</p>
        {!hideAttributeSelection && (
          <div className="flex space-x-2">
            <TooltipButton
              icon="circle-o"
              tooltip="Toggle node size by number of connections"
              onClick={() => setNodeSizeAttribute('default')}
              className={`${toolTipButtonStyle} ${nodeSizeAttribute === 'default' ? activeTooTipButtonStyle : 'border-transparent'}`}
            />
            <TooltipButton
              icon="fire"
              tooltip="Toggle node size by relevance score"
              onClick={() => setNodeSizeAttribute('score')}
              className={`${toolTipButtonStyle} ${nodeSizeAttribute === 'score' ? activeTooTipButtonStyle : 'border-transparent'}`}
            />
            <TooltipButton
              icon="chart-line"
              tooltip="Toggle node size by number of citations"
              onClick={() => setNodeSizeAttribute('n_citations')}
              className={`${toolTipButtonStyle} ${nodeSizeAttribute === 'n_citations' ? activeTooTipButtonStyle : 'border-transparent'}`}
            />
          </div>
        )}
        <div className={`${!hideAttributeSelection ? 'border-l border-gray-400 pl-2' : ''} flex items-center space-x-2 `}>

        {recommendations.length > 0 && (
          <TooltipButton
            icon="book"
            tooltip={showReferences ? "Hide references" : "Show references"}
            onClick={() => hasSimulated && setShowReferences(!showReferences)}
            className={`${toolTipButtonStyle} ${showReferences ? activeTooTipButtonStyle : 'border-transparent'} ${!hasSimulated ? 'opacity-50 cursor-not-allowed' : ''}`}
          />
        )}
          <TooltipButton
            icon="refresh"
            tooltip="Refresh simulation"
            onClick={refreshSimulation}
            className={`${toolTipButtonStyle} border-transparent`}
          />
        </div>
      </div>

      <motion.div  
        ref={containerRef} 
        className={`rounded-lg border bg-secondary-dark/10  mb-1 cursor-pointer w-10/12 ${wide ? 'lg:w-10/12' : 'lg:w-2/3'} aspect-square mx-auto`}
        initial={{ scale: 0.9 }}
        animate={{ scale: 1 }}
        exit={{ scale: 0.9 }}
        transition={{ duration: 0.3, delay: 0.2 }}
      >
        <svg ref={svgRef} className="w-full h-full"></svg> 
      </motion.div>
    
      {/* legend */}
      {references.length > 0 && recommendations.length > 0 && (
        <div className={`flex flex-row justify-center items-center gap-8 w-full px-2 py-1  text-xs `}>
          <div className="flex flex-row items-center">
            <span className="inline-block w-3 h-3 border border-white rounded-full bg-[#6f8193] mr-1"></span> references
          </div>
          <div className="flex flex-row items-center">
            <span className="inline-block w-3 h-3 border border-white rounded-full bg-[#de5b7e] mr-1"></span> recommendations
          </div>
        </div>
      )}

      {setSelectedPaper && (
        <p className="text-xs text-center group-hover:opacity-100 opacity-0 transition-opacity duration-300 mb-2">
          Click on a node to see a <Icon icon="file-lines" className="-mr-0.5" /> paper's details
        </p>
      )}
    </motion.div>
  );
};

// Memoised wrapper: re-render only when the props change (shallow comparison).
const MemoizedPapersGraph = React.memo(PapersGraph);

export default MemoizedPapersGraph;
