import { Dispatch, SetStateAction, useEffect, useMemo } from "react";

// Packages
import { Node as ReactFlowNode, Edge as ReactFlowEdge } from "react-flow-renderer";
import { HierarchyNode, hierarchy as d3Hierarchy, tree as d3Tree } from "d3-hierarchy";
import { every, filter, forEach, keyBy, mapValues, uniqBy } from "lodash";

// Utils
import { hideRootNode } from "../utils/Dag.helpers";

// Contexts
import { useDagFlowContext } from "../context/useDagFlowContext";

// Types
import { EdgeType, NodeType } from "../../../Dag.types";

/** Arguments accepted by `useExpandCollapse`. */
type Props = {
  /** State setter the hook calls with the React Flow nodes it computes. */
  setNodes: Dispatch<SetStateAction<ReactFlowNode[]>>;
  /** State setter the hook calls with the React Flow edges it computes. */
  setEdges: Dispatch<SetStateAction<ReactFlowEdge[]>>;
};

/**
 * Derives the visible React Flow nodes/edges from the DAG held in the flow
 * context, honouring each node's expanded/collapsed state, and pushes the
 * result through the supplied `setNodes` / `setEdges` setters.
 *
 * The graph is rebuilt whenever the context's nodes, edges, expansion map or
 * edge type change.
 *
 * @param props - Setters that receive the computed nodes and edges.
 * @returns Flags telling whether every entry of the expansion map is expanded
 *          or collapsed, plus helpers that expand or collapse every entry.
 *          Note that both flags are `true` when the expansion map is empty,
 *          because lodash `every` is vacuously true for empty collections.
 */
const useExpandCollapse = (props: Props) => {
  const { setNodes, setEdges } = props || {};

  const { initialNodes, initialEdges, edgeType, nodesExpanded, setNodesExpanded } =
    useDagFlowContext();

  const isAllNodesExpanded = useMemo(
    () => every(nodesExpanded, (value) => value === true),
    [nodesExpanded]
  );

  const isAllNodesCollapsed = useMemo(
    () => every(nodesExpanded, (value) => value === false),
    [nodesExpanded]
  );

  /** Marks every node tracked in the expansion map as expanded. */
  const expandAllNodes = () => {
    setNodesExpanded(mapValues(nodesExpanded, () => true));
  };

  /** Marks every node tracked in the expansion map as collapsed. */
  const collapseAllNodes = () => {
    setNodesExpanded(mapValues(nodesExpanded, () => false));
  };

  // The layout factory is stateless with respect to the data, so build it once
  // instead of on every render. nodeSize is [25, 200]: 25 between siblings and
  // 200 between depth levels (x/y are swapped when positions are emitted).
  const layout = useMemo(() => d3Tree<NodeType>().nodeSize([25, 200]), []);

  /**
   * Builds the nested tree structure expected by `d3Hierarchy` from the flat
   * node and edge lists held in the context.
   *
   * @returns The root nodes, i.e. nodes that are not the target of any edge,
   *          each carrying a `children` array populated from the edges.
   */
  const getNodesTree = () => {
    // Work on shallow copies so the node objects owned by the context are never
    // mutated; each copy starts with an empty children array.
    const nodeMap = mapValues(keyBy(initialNodes, "id"), (node) => ({
      ...node,
      children: []
    })) as unknown as { [key: string]: NodeType };

    // Every edge target is recorded, even when one endpoint is missing from the
    // node list, so root detection matches the edge list exactly.
    const targetIds = new Set<EdgeType["target"]>();

    // Build the hierarchical structure using the edges
    forEach(initialEdges, (edge: EdgeType) => {
      targetIds.add(edge.target);

      const sourceNode = nodeMap[edge.source];
      const targetNode = nodeMap[edge.target];
      if (sourceNode && targetNode) {
        // @ts-ignore -- `children` is attached locally and is not part of NodeType
        sourceNode.children.push(targetNode);
      }
    });

    // Root nodes are the nodes that are not targets in any edge
    return filter(nodeMap, (node) => !targetIds.has(node.id));
  };

  /**
   * Prunes the children of collapsed nodes, runs the tree layout, and converts
   * the laid-out hierarchy into React Flow nodes and edges.
   *
   * @param hierarchy - Hierarchy whose data objects carry `expanded` and
   *                    `children` (the full, unpruned child list).
   * @returns The visible nodes and the edges linking them.
   */
  const getElements = (hierarchy: HierarchyNode<NodeType>) => {
    hierarchy.descendants().forEach((d: HierarchyNode<NodeType>) => {
      // Collapsed nodes lose their children so the layout skips their subtree.
      // @ts-ignore
      d.children = d.data.expanded ? d.data.children : null;
    });

    const root = layout(hierarchy);

    const nodes = root.descendants().map((d: HierarchyNode<NodeType>) => ({
      id: d.data.id,
      type: d.data.type,
      hidden: d.data.hidden,
      draggable: d.data.draggable,
      // An explicit position on the node wins; otherwise use the layout result
      // with the axes swapped so depth runs along the horizontal axis.
      position: d.data.position || { x: d.y, y: d.x },
      data: { ...d.data.data }
    }));

    // Lookup from "source_target" to the original edge id so existing ids are kept.
    const edgeIds: { [key: string]: string } = {};
    forEach(initialEdges, (edge: EdgeType) => {
      edgeIds[`${edge.source}_${edge.target}`] = edge.id;
    });

    const edges = root.links().map((d, i) => ({
      // Links with no matching original edge fall back to the link index as id.
      id: edgeIds[`${d.source.data.id}_${d.target.data.id}`] || `${i}`,
      source: d.source.data.id,
      target: d.target.data.id,
      type: edgeType
    }));

    return { nodes, edges };
  };

  useEffect(() => {
    // d3Hierarchy needs a single root, so all real roots hang off a dummy node.
    const dummyRootNode = { name: "rootNode", children: getNodesTree() };
    // @ts-ignore
    const hierarchy = d3Hierarchy<NodeType>(dummyRootNode);

    hierarchy.descendants().forEach((d: HierarchyNode<NodeType>) => {
      // The explicit state from the expansion map wins; otherwise a node is
      // expanded unless its data marks it as collapsed.
      // @ts-ignore
      d.data.expanded = nodesExpanded[d.data.id] ?? !d.data.data?.collapsed;
      // Remember the full child list on the data object so getElements can
      // restore or prune it according to `expanded`.
      // @ts-ignore
      d.data.children = d.children;
    });

    const initialElements = getElements(hierarchy);

    // A node reachable through several parents appears more than once in the
    // hierarchy, so nodes and edges are de-duplicated by id.
    // NOTE(review): hideRootNode is presumably what hides the dummy root — confirm.
    // @ts-ignore
    const nodes = uniqBy(hideRootNode(initialElements.nodes), "id");
    const edges = uniqBy(initialElements.edges, "id");

    // @ts-ignore
    setNodes(nodes);
    setEdges(edges);
    // edgeType is a dependency because every generated edge carries it.
  }, [initialNodes, initialEdges, nodesExpanded, edgeType]);

  return { isAllNodesExpanded, isAllNodesCollapsed, expandAllNodes, collapseAllNodes };
};

export default useExpandCollapse;
