import { useEffect, useMemo } from "react";

import { HierarchyNode, hierarchy as d3Hierarchy, tree as d3Tree } from "d3-hierarchy";
import { every, filter, forEach, keyBy, mapValues, uniqBy } from "lodash";

import { hideRootNode } from "../utils/Dag.helpers";

import { useDagFlowContext } from "../context/useDagFlowContext";

/**
 * Derives the nodes and edges to render from the expand/collapse state held in the DAG
 * flow context, and pushes them to the caller through `setNodes` / `setEdges`.
 *
 * The context's initial nodes and edges are arranged into a tree under a synthetic root,
 * the children of collapsed nodes are pruned, and what remains is positioned with a d3
 * tree layout (unless a node already carries its own `position`).
 *
 * @param setNodes - Called with the de-duplicated node list after every recomputation.
 * @param setEdges - Called with the de-duplicated edge list after every recomputation.
 * @returns `isAllNodesExpanded` / `isAllNodesCollapsed` flags plus `expandAllNodes` /
 * `collapseAllNodes` callbacks that rewrite every entry of `nodesExpanded`.
 */
const useExpandCollapse = ({ setNodes, setEdges }: $TSFixMe) => {
  const { initialNodes, initialEdges, edgeType, nodesExpanded, setNodesExpanded } =
    useDagFlowContext();

  // Note: lodash `every` returns true for an empty collection, so both flags are true
  // while `nodesExpanded` has no entries.
  const isAllNodesExpanded = useMemo(
    () => every(nodesExpanded, (value) => value === true),
    [nodesExpanded]
  );

  const isAllNodesCollapsed = useMemo(
    () => every(nodesExpanded, (value) => value === false),
    [nodesExpanded]
  );

  const expandAllNodes = () => {
    setNodesExpanded(mapValues(nodesExpanded, () => true));
  };

  const collapseAllNodes = () => {
    setNodesExpanded(mapValues(nodesExpanded, () => false));
  };

  // The layout carries no per-render state, so build it once instead of on every render.
  const layout = useMemo(() => d3Tree<$TSFixMe>().nodeSize([25, 200]), []);

  /**
   * Builds the forest described by `initialNodes` / `initialEdges`.
   *
   * @returns The nodes that are not the target of any edge, each with a `children` array
   * that (recursively) holds the nodes reachable through the edges.
   */
  const getNodesTree = () => {
    // Work on shallow copies: `children` is attached here and the effect below writes
    // `expanded` / `children` as well, and none of that should leak into the node objects
    // owned by the context (which may be frozen or shared with other consumers).
    const nodeMap: Record<string, $TSFixMe> = mapValues(
      keyBy(initialNodes, "id"),
      (node: $TSFixMe) => ({ ...node, children: [] })
    );

    // Attach every edge target to its source; edges referencing unknown nodes are skipped.
    forEach(initialEdges, (edge) => {
      const sourceNode: $TSFixMe = nodeMap[edge?.source];
      const targetNode = nodeMap[edge?.target];
      if (sourceNode && targetNode) {
        sourceNode.children.push(targetNode);
      }
    });

    // Root nodes are the ones that never appear as an edge target.
    const targetIds = new Set(initialEdges?.map((edge: $TSFixMe) => edge?.target));
    return filter(nodeMap, (node) => !targetIds.has(node?.id));
  };

  /**
   * Prunes collapsed branches from `hierarchy`, lays out the remainder and converts it
   * into plain node / edge objects.
   *
   * @param hierarchy - d3 hierarchy whose data objects carry `expanded` and `children`.
   * @returns The nodes and edges of the pruned tree (the hierarchy's root included).
   */
  const getElements = (hierarchy: $TSFixMe) => {
    // Only expanded nodes keep their children; everything below a collapsed node is dropped.
    hierarchy?.descendants()?.forEach((d: $TSFixMe) => {
      d.children = d?.data?.expanded ? d?.data?.children : null;
    });

    const root = layout(hierarchy);

    const nodes = root?.descendants()?.map((d: $TSFixMe) => ({
      id: d?.data?.id,
      type: d?.data?.type,
      hidden: d?.data?.hidden,
      draggable: d?.data?.draggable,
      // An explicit position wins; otherwise use the layout result with the axes swapped
      // (layout `y` becomes `x` and vice versa).
      position: d?.data?.position || { x: d?.y, y: d?.x },
      data: { ...d?.data?.data }
    }));

    // Lookup from "source_target" to the id of the original edge, so rendered edges keep
    // their original ids where one exists.
    const edgeIds: { [key: string]: string } = {};
    forEach(initialEdges, (edge) => {
      edgeIds[`${edge?.source}_${edge?.target}`] = edge?.id;
    });

    const edges = root?.links()?.map((d: $TSFixMe, i: number) => ({
      // Links without a matching original edge fall back to their index as id.
      id: edgeIds[`${d?.source?.data?.id}_${d?.target?.data?.id}`] || `${i}`,
      source: d?.source?.data?.id,
      target: d?.target?.data?.id,
      type: edgeType
    }));

    return { nodes, edges };
  };

  useEffect(() => {
    // A single synthetic root turns the forest into one tree that d3 can lay out.
    const dummyRootNode = { name: "rootNode", children: getNodesTree() };
    const hierarchy = d3Hierarchy<$TSFixMe>(dummyRootNode);

    hierarchy.descendants().forEach((d: HierarchyNode<$TSFixMe>) => {
      // Explicit state from the context wins; otherwise fall back to the node's own
      // `collapsed` flag. The synthetic root has neither and therefore counts as expanded.
      d.data.expanded = nodesExpanded?.[d.data?.id] ?? !d.data?.data?.collapsed;
      // Remember the hierarchy children on the data object so `getElements` can restore
      // them for expanded nodes.
      d.data.children = d.children;
    });

    const initialElements = getElements(hierarchy);

    // A node reachable through several parents appears once per path in the tree, so both
    // lists are de-duplicated by id.
    const nodes = uniqBy(hideRootNode(initialElements?.nodes), "id");
    const edges = uniqBy(initialElements?.edges, "id");

    setNodes(nodes);
    setEdges(edges);
    // `edgeType` is stamped onto every edge, so the elements must be rebuilt when it changes.
    // `setNodes` / `setEdges` are intentionally not listed: they are supplied by the caller
    // and are assumed to be stable setters; verify before adding them.
  }, [initialNodes, initialEdges, nodesExpanded, edgeType]);

  return { isAllNodesExpanded, isAllNodesCollapsed, expandAllNodes, collapseAllNodes };
};

export default useExpandCollapse;
