import { useMutation } from "@tanstack/react-query";
import { useCommonStore } from "common/store";
import { AnalyticsAPIReq, AnalyticsResponse } from "common/types/types";
import { csvParseRows } from "d3";
import debounce from "lodash/debounce";
import { useCriteriaBuilder } from "modules/core/Core";
import { Scope } from "modules/scope-metadata/types";
import { useSearchStore } from "modules/search/store";
import { useEffect, useMemo, useState } from "react";
import { SankeyData } from "../types";
import groupBy from "lodash/groupBy";

/** Options consumed by `transformToSankeyData`. */
interface SankeyTransformerProps {
	/** Ordered dimension names the aggregate is grouped by; the first one is the root level. */
	groupBy: string[];
	/** Roughly how many values per dimension keep their own label; the rest become "Other …" nodes. */
	topN: number;
	/** When false, each link's `log` field is set to its raw count instead of the log-scaled value. */
	useLog: boolean;
	/** When true, node values get a "/<dimension>" suffix. */
	shouldAppendDimensionInName?: boolean;
}
/** Inputs for `useSankeyData`: transformer options plus the aggregate request parameters. */
interface SankeyHookProps extends SankeyTransformerProps {
	/** Analytics scope sent with the request and used in the mutation key. */
	scope: Scope;
	/** Statistic names requested from the aggregate API. */
	statistics: string[];
}

/**
 * Creates the aggregate-analytics mutation for `scope`.
 *
 * No `mutationFn` is passed here, so the request is presumably supplied by
 * mutation defaults registered on the QueryClient for this key — confirm.
 *
 * @param scope analytics scope; forms the `${scope}Aggregate` mutation key.
 */
function useAggregateAPI(scope: string) {
	return useMutation<AnalyticsResponse, Error, AnalyticsAPIReq>([
		`${scope}Aggregate`,
		"aggregate",
	]);
}

/**
 * Dimensions for which `useSankeyData` asks the transformer to append the
 * dimension name to node values when more than one of them is grouped —
 * presumably because they share value labels (e.g. risk levels); confirm.
 *
 * Built on a null-prototype object so that looking up an arbitrary dimension
 * name such as "constructor" or "toString" does not hit an inherited
 * `Object.prototype` member and count as conflicting.
 */
const conflictingDimensions: { [key: string]: boolean } = Object.assign(
	Object.create(null),
	{
		assetrisk: true,
		businessvalue: true,
		attacksurface: true,
		blastradius: true,
	}
);

/**
 * Fetches aggregate analytics for the current search/facet criteria and
 * returns them transformed into Sankey `{ nodes, links }`.
 *
 * Requests are debounced by 500 ms and re-issued whenever the criteria,
 * scope, grouping, statistics, `topN` or `useLog` change. Nothing is fetched
 * while `useCriteriaBuilder` yields falsy criteria.
 *
 * @returns the most recently transformed Sankey data, or `undefined` until a
 *   response has been transformed (or when the last response had no items).
 */
export function useSankeyData(props: SankeyHookProps) {
	const api = useAggregateAPI(props.scope);
	// Debounce `mutate` rather than `mutateAsync`: the debounced wrapper
	// discards the promise it produces, so a failed `mutateAsync` call would
	// surface as an unhandled promise rejection. `mutate` handles the
	// rejection itself and records it in the mutation state.
	const mutator = useMemo(() => debounce(api.mutate, 500), [api.mutate]);

	const reset = api.reset;

	const searchQuery = useSearchStore(state => state.search);
	const facetState = useCommonStore(state => state.facets);
	const metadata = useCommonStore(state => state.metadata);

	// TODO: facet state without selection
	const criteria = useCriteriaBuilder(searchQuery, facetState, metadata);

	// Serialise the array props so the effect below re-runs only when their
	// contents change, not whenever the caller passes a new array instance.
	const groupS = props.groupBy.join(",");
	const statsS = props.statistics.join(",");

	const [sankeyData, setSankeyData] = useState<SankeyData | undefined>(
		undefined
	);

	useEffect(() => {
		// Cancel a still-pending debounced request when the inputs change or the
		// component unmounts, so a request for stale inputs is never sent.
		const cancelPending = () => mutator.cancel();

		reset();
		if (!criteria) {
			return cancelPending;
		}

		const dimensions = groupS.split(",");
		// `=== true` keeps inherited Object.prototype members (e.g. a dimension
		// named "constructor") from counting as conflicting.
		const conflictingCount = dimensions.filter(
			dimension => conflictingDimensions[dimension] === true
		).length;

		mutator(
			{
				criteria,
				groupBy: dimensions,
				statistics: statsS.split(","),
				scope: props.scope,
			},
			{
				onSuccess(data) {
					if (!data) {
						return;
					}
					setSankeyData(
						transformToSankeyData(data, {
							topN: props.topN,
							groupBy: dimensions,
							shouldAppendDimensionInName: conflictingCount > 1,
							useLog: props.useLog,
						})
					);
				},
			}
		);

		return cancelPending;
	}, [
		criteria,
		props.scope,
		mutator,
		reset,
		groupS,
		statsS,
		props.topN,
		props.useLog,
	]);

	return sankeyData;
}

/**
 * One parent→child edge used while building Sankey links. `key` starts as
 * `"<dim>::<value>,<dim>::<value>"` and is later rewritten to the display
 * pair `"<source>,<target>"`.
 */
interface GroupStatistic {
	/** Edge identifier (see above). */
	key: string;
	/** Summed raw count for this edge. */
	count: number;
	/** Summed log-scaled count for this edge. */
	log: number;
	/** Dimension name of the edge's source. */
	firstDimensionName: string | undefined;
	/** Dimension name of the edge's target. */
	secondDimensionName: string | undefined;
}

/**
 * Recursively walks one level of the nested aggregate response and records
 * the count of every parent→child edge in `counts`.
 *
 * `items` maps values of `firstDimensionName` to an object of the form
 * `{ [secondDimensionName]: { [secondDimensionValue]: node } }`, where a node
 * is either a leaf carrying `statistics.assetidcount` or another nested level
 * (shape inferred from this traversal — TODO confirm against the API type).
 *
 * Edges are keyed `"<firstDimensionName>::<fValue>,<secondDimensionName>::<sValue>"`.
 * Empty or `"<nil>"` values are labelled `Untagged <dimension>s`. When
 * `shouldAppendDimensionInName` is set, every value gets a `/<dimension>`
 * suffix. Counts for repeated keys are added to the existing entry. Entries
 * that are not nested objects (e.g. scalar fields beside `items`) are skipped.
 *
 * @returns `agg` (summed count) and `aggLog` (summed log) of all children
 *   visited at this level, nested levels included.
 */
function getStatistics(
	items: any,
	firstDimensionName: string,
	counts: Map<string, { count: number; log: number }>,
	shouldAppendDimensionInName: boolean
) {
	const label = (value: string, dimensionName: string) => {
		const base =
			!value || value === "<nil>" ? `Untagged ${dimensionName}s` : value;
		return shouldAppendDimensionInName ? `${base}/${dimensionName}` : base;
	};
	const isObject = (value: unknown) => !!value && typeof value === "object";

	let agg = 0;
	let aggLog = 0;
	Object.keys(items).forEach(firstDimensionValue => {
		const level = items[firstDimensionValue];
		if (!isObject(level)) {
			return;
		}
		const secondDimensionName = Object.keys(level)[0];
		const buckets =
			secondDimensionName === undefined ? undefined : level[secondDimensionName];
		if (!isObject(buckets)) {
			return;
		}
		// Computed once per parent. The original reassigned the loop variable
		// to "<nil>" here, which broke the lookup for the following children.
		const fValue = label(firstDimensionValue, firstDimensionName);

		Object.keys(buckets).forEach(secondDimensionValue => {
			const data = buckets[secondDimensionValue];
			if (!isObject(data)) {
				return;
			}
			let count = 0;
			let log = 0;

			if (data.statistics) {
				count = data.statistics.assetidcount;
				log = logFn(count);
			} else {
				// Nested level: record its edges first, then use its totals for this edge.
				const nested = getStatistics(
					{ [secondDimensionValue]: data },
					secondDimensionName,
					counts,
					shouldAppendDimensionInName
				);
				count = nested.agg;
				log = nested.aggLog;
			}

			const sValue = label(secondDimensionValue, secondDimensionName);
			const key = `${firstDimensionName}::${fValue},${secondDimensionName}::${sValue}`;
			const previous = counts.get(key) ?? { count: 0, log: 0 };

			counts.set(key, {
				count: previous.count + count,
				log: previous.log + log,
			});
			agg += count;
			aggLog += log;
		});
	});
	return { agg, aggLog };
}

/**
 * Converts a nested aggregate API response into Sankey `{ nodes, links }`.
 *
 * Pipeline:
 *  1. `getStatistics` flattens `data.items` into per-edge counts keyed
 *     `"<dimA>::<valueA>,<dimB>::<valueB>"`.
 *  2. Edges are grouped by their source (`"<dim>::<value>"`), with each
 *     source's total count.
 *  3. Per dimension, sources are ranked by total. Those at rank >= `topN` are
 *     relabelled `Other <dimension>s`. Targets marked excluded are relabelled
 *     `Other <targetDimension>s`. Edges that end up with the same label pair
 *     are merged.
 *  4. Merged edges are serialised as CSV rows and parsed back with
 *     `csvParseRows` into link objects; nodes are the distinct link endpoints.
 *
 * @param data aggregate response; expected to carry a nested `items` object
 *   (shape assumed from `getStatistics` — TODO confirm against the API type).
 * @param props transformer options: `topN`, `groupBy` (dimension order),
 *   `useLog`, `shouldAppendDimensionInName`.
 * @returns Sankey data, or `undefined` when `data` or `data.items` is missing.
 */
function transformToSankeyData(
	data: any,
	props: SankeyTransformerProps
): SankeyData | undefined {
	// encounteredValues.clear();
	if (!data || !data.items) {
		return undefined;
	}
	const { topN, groupBy: groupOrder, useLog } = props;

	// Nest the response items under the first grouped dimension so that
	// getStatistics sees a synthetic "items" root whose children are the first
	// dimension's values. The shallow copy keeps the caller's object unchanged.
	data = { ...data };
	data.items = {
		[groupOrder[0]]: data.items,
	};

	// Edge key "<dimA>::<valueA>,<dimB>::<valueB>" -> summed count/log.
	let counts = new Map<string, { count: number; log: number }>();

	// Only the side effect on `counts` is used; the returned totals are ignored.
	getStatistics(
		data,
		"items",
		counts,
		props.shouldAppendDimensionInName || false
	);

	// CSV rows: source,target,count,log,sourceDimension,targetDimension.
	let results: Array<string> = [];

	// Edges grouped by source key "<dim>::<value>", with the source's total count.
	let aggregatedStats: {
		[key: string]: { stats: Array<GroupStatistic>; total: number };
	} = {};
	counts?.forEach((count, key) => {
		let keySplit = key.split(",");
		let firstDimensionName = getNameAndValue(keySplit[0])[0];
		let secondDimensionName = getNameAndValue(keySplit[1])[0];

		let groupParent = keySplit[0];
		let groupValues = aggregatedStats[groupParent]?.stats || [];

		groupValues.push({
			key,
			firstDimensionName,
			secondDimensionName,
			count: count.count,
			log: count.log,
		});
		// NOTE(review): `total` is recomputed over every edge on each push
		// (quadratic per source); a running sum would give the same value.
		aggregatedStats[groupParent] = {
			stats: groupValues,
			total: groupValues.reduce((prev, s) => prev + s.count, 0),
		};
	});

	// Drop the edges that start at the synthetic root.
	// NOTE(review): with shouldAppendDimensionInName the root key is
	// "items::items/items", so this delete does not match it; those edges are
	// removed later by the `startsWith("items")` filter instead.
	delete aggregatedStats["items::items"];
	// Keyed by "<dim>::<value>": marks values that keep their own label
	// (included) versus values folded into an "Other <dim>s" node (excluded).
	let excludedDimensions = new Map();
	let includedDimensions = new Map();

	let groupParents = Object.keys(aggregatedStats);

	// Source keys grouped by dimension name (the text before the first ":").
	let groupedStats = groupBy(groupParents, a => a.split(":")[0]);
	let dimensions = Object.keys(groupedStats);

	// Deepest dimension (highest index in groupOrder) first. Names absent from
	// groupOrder (index -1, e.g. the synthetic "items") sort last. Note that
	// this sorts `dimensions` in place; `sortedKeys` is the same array.
	let sortedKeys = dimensions.sort(
		(a, b) =>
			groupOrder.indexOf(getNameAndValue(b)[0]) -
			groupOrder.indexOf(getNameAndValue(a)[0])
	);

	// Merged edges keyed by the rewritten "<source>,<target>" label pair.
	let groupResults: { [key: string]: GroupStatistic } = {};
	// Stores `stat` under `key`, or adds its count/log to the stat already
	// stored there. The stored object is mutated in place, and it is the same
	// object that is referenced from aggregatedStats.
	const pushToResults = (key: string, stat: GroupStatistic) => {
		let old = groupResults[key];
		if (!old) {
			groupResults[key] = stat;
			return;
		}
		old.count += stat.count;
		old.log += stat.log;
	};

	// The deepest *source* dimension. The final entry of groupOrder only ever
	// appears as a target, so its values are chosen from this dimension's edges.
	let lastDimension = sortedKeys[0];
	let lastDimensionStats = groupedStats[lastDimension];
	lastDimensionStats?.sort((a, b) => {
		const aS = aggregatedStats[a];
		const bS = aggregatedStats[b];
		return bS.total - aS.total;
	});

	// Walk the edges of the largest sources first and mark which final-column
	// targets keep their own label. Once the quota is reached, targets not
	// already included are marked excluded.
	// NOTE(review): `> topN` admits topN + 1 edges before excluding anything,
	// and edges to already-included targets still increment the counter. The
	// source ranking below keeps exactly `index < topN`; confirm which is intended.
	let totalLastDimensionsPicked = 0;
	lastDimensionStats?.forEach(v => {
		let dimensionCounts = aggregatedStats[v].stats;
		for (let i = 0; i < dimensionCounts.length; i++) {
			let keySplit = dimensionCounts[i].key.split(",");

			let childKey = keySplit[1];
			if (
				totalLastDimensionsPicked > topN &&
				!includedDimensions.get(childKey)
			) {
				excludedDimensions.set(childKey, true);
			} else {
				totalLastDimensionsPicked++;
				includedDimensions.set(childKey, true);
			}
		}
	});

	// For each dimension, deepest first, rank its sources by total. The top
	// `topN` keep their value; the rest are relabelled "Other <dim>s" and, unless
	// already included, marked excluded. Because deeper dimensions are handled
	// first, a value's inclusion is already decided by the time it appears as a
	// target of the next shallower dimension.
	sortedKeys?.forEach(dimension => {
		let dimensionStats = groupedStats[dimension];

		dimensionStats.sort((a, b) => {
			const aS = aggregatedStats[a];
			const bS = aggregatedStats[b];
			return bS.total - aS.total;
		});

		dimensionStats?.forEach((parent, index) => {
			let dimensionCounts = aggregatedStats[parent].stats;

			let firstDimensionValue = getNameAndValue(parent)[1];

			for (let i = 0; i < dimensionCounts.length; i++) {
				let keySplit = dimensionCounts[i].key.split(",");

				let childKey = keySplit[1];
				let isChildExcluded =
					excludedDimensions.get(childKey) && !includedDimensions.get(childKey);

				let childName = isChildExcluded
					? `Other ${dimensionCounts[0].secondDimensionName}s`
					: getNameAndValue(childKey)[1];

				// Rewrite the key to the "<sourceLabel>,<targetLabel>" pair
				// (dimension names dropped), then merge it into groupResults.
				if (index < topN) {
					dimensionCounts[i].key = `${firstDimensionValue},${childName}`;

					includedDimensions.set(keySplit[0], true);
				} else {
					let otherFirstDimension = `Other ${dimensionCounts[0].firstDimensionName}s`;
					dimensionCounts[i].key = `${otherFirstDimension},${childName}`;

					if (!includedDimensions.get(keySplit[0])) {
						excludedDimensions.set(keySplit[0], true);
					}
				}

				pushToResults(dimensionCounts[i].key, dimensionCounts[i]);
			}
		});
	});

	// Always true for entries built above, since firstDimensionName comes from
	// String.split and is therefore a string.
	function isGroupStatistic(obj: any): obj is GroupStatistic {
		return (
			(obj && obj.firstDimensionName !== undefined) ||
			(obj && obj.secondDimensionName !== undefined)
		);
	}

	// Serialise each merged edge as a CSV row.
	// NOTE(review): `key` is now "<sourceValue>,<targetLabel>" (dimension names
	// were stripped above). These startsWith checks therefore compare a *value*
	// against "items" or the first dimension's name, and they also drop
	// legitimate sources whose value merely begins with those strings. Confirm intent.
	// NOTE(review): values are interpolated unescaped, so a value containing
	// ",", a double quote or a newline corrupts the parsed row.
	Object.values(groupResults)?.forEach(groupCount => {
		if (
			!isGroupStatistic(groupCount) ||
			groupCount.key.startsWith("items") ||
			groupCount.key.startsWith(groupOrder[0])
		) {
			return;
		}
		let count = groupCount.count;
		let firstDimensionName = groupCount.firstDimensionName || "";
		let secondDimensionName = groupCount.secondDimensionName || "";
		results.push(
			`${groupCount.key},${count},${groupCount.log},${firstDimensionName},${secondDimensionName}`
		);
	});

	// Parse the rows back into link objects. Empty, zero or non-numeric
	// counts become 0. Rows with an empty source or target are dropped,
	// because returning null from the row function omits the row.
	// Endpoint ids are "<dimension>:<value>", so equal values in different
	// dimensions stay distinct nodes.
	const links = csvParseRows(
		results.join("\n"),
		([source, target, value, log, sourceDimension, targetDimension]) => {
			let logN = !+log || isNaN(+log) ? 0 : +log;

			let valueN = !+value || isNaN(+value) ? 0 : +value;
			// Without useLog, the `log` field mirrors the raw count.
			if (!useLog) {
				logN = valueN;
			}
			return source && target
				? {
						source: `${sourceDimension}:${source}`,
						target: `${targetDimension}:${target}`,
						value: valueN,
						log: logN,
						sourceDimension,
						targetDimension,
					}
				: null;
		}
	);
	// One node per distinct link endpoint; `name` and `id` are both the endpoint id.
	const nodeByName = new Map();
	for (const link of links) {
		if (!nodeByName.has(link.source))
			nodeByName.set(link.source, {
				name: link.source,
				id: link.source,
				dimension: link.sourceDimension,
			});
		if (!nodeByName.has(link.target))
			nodeByName.set(link.target, {
				name: link.target,
				id: link.target,
				dimension: link.targetDimension,
			});
	}

	return { nodes: Array.from(nodeByName.values()), links };
}

/**
 * Splits a `"<dimension>::<value>"` node key into `[dimension, value]`.
 *
 * Only the first `"::"` is treated as the separator, so a value that itself
 * contains `"::"` (e.g. an IPv6 address such as `fe80::1`) is kept intact
 * instead of being truncated. A key without a separator yields `[key]`.
 */
function getNameAndValue(key: string): string[] {
	const separatorIndex = key.indexOf("::");
	if (separatorIndex === -1) {
		return [key];
	}
	return [key.slice(0, separatorIndex), key.slice(separatorIndex + 2)];
}

/**
 * Log-scales a count for use as a link width.
 *
 * A count of 1 maps to 0.1 instead of `Math.log(1) === 0`, presumably so that
 * single-item links keep a non-zero width. `Math.log` would return -Infinity
 * for 0 and NaN for negative or NaN input. Either value would poison the
 * summed log values, so those inputs return 0.
 */
function logFn(num: number) {
	if (!(num > 0)) {
		return 0;
	}
	if (num === 1) {
		return 0.1;
	}
	return Math.log(num);
}
