import { useCallback, useMemo } from 'react';
import { ensurePluginOrder, actions } from 'react-table';
import type { Hooks, TableInstance, ActionType, TableState, Row } from 'react-table';

import type { BaseRecord } from '../types';

/**
 * react-table plugin hook.
 *
 * Wires this module's `useInstance` into the instance hooks and its `reducer` into the
 * state reducers of the table being built.
 *
 * @param hooks - hook registry handed to plugins by react-table.
 */
export const useFilteredRowSelect = <T extends BaseRecord>(hooks: Hooks<T>): void => {
	const { stateReducers, useInstance: instanceHooks } = hooks;

	instanceHooks.push(useInstance);
	stateReducers.push(reducer);
};

/** Identifier of this plugin; also passed to `ensurePluginOrder` in `useInstance`. */
const PLUGIN_NAME = 'useFilteredRowSelect';

// Attach the name to the hook function itself, presumably so react-table's plugin-order checks
// can identify this plugin among the registered `plugins` (confirm against react-table docs).
useFilteredRowSelect.pluginName = PLUGIN_NAME;
// Register the custom action type on react-table's shared `actions` map; both `reducer` and the
// dispatch in `useInstance` compare against `actions.selectFilteredRows`.
actions.selectFilteredRows = 'selectFilteredRows';

/**
 * Collects the row ids represented by a list of selected rows.
 *
 * - A non-grouped row contributes its own `id`.
 * - A grouped row that is collapsed (`!isExpanded`) contributes the ids of all of its `leafRows`.
 * - A grouped row that is expanded contributes nothing itself; presumably its selected sub rows
 *   appear in `selectedRows` on their own (verify against the table's row pipeline).
 *
 * @param selectedRows - rows to flatten, e.g. the table instance's `selectedFlatRows`.
 * @returns the collected ids, in the order the rows were encountered.
 */
function getSelectedFlatRows<T extends object>(selectedRows: Row<T>[]): string[] {
	const ids: string[] = [];

	for (const { id, isGrouped, isExpanded, leafRows } of selectedRows) {
		if (!isGrouped) {
			ids.push(id);
		} else if (!isExpanded) {
			// Push one id at a time: spreading a large `leafRows` list into `push(...)` passes every
			// element as a call argument and can throw a RangeError for big groups.
			for (const leaf of leafRows) {
				ids.push(leaf.id);
			}
		}
	}

	return ids;
}

/**
 * State reducer for this plugin.
 *
 * For a `selectFilteredRows` action, returns a shallow copy of `state` whose `selectedRowIds`
 * is rebuilt from scratch: exactly the ids in `action.value`, each mapped to `true`.
 * Any other action returns `state` unchanged (same reference).
 *
 * @param state - current table state.
 * @param action - dispatched action; for `selectFilteredRows`, `value` holds the row ids to keep.
 */
function reducer(state: TableState, action: ActionType) {
	if (action.type !== actions.selectFilteredRows) {
		return state;
	}

	const rowIds: string[] = action.value;
	const selectedRowIds: Record<string, boolean> = {};
	rowIds.forEach(rowId => {
		selectedRowIds[rowId] = true;
	});

	return { ...state, selectedRowIds };
}

/**
 * Instance hook: drops ids from `state.selectedRowIds` that no longer correspond to selected rows
 * remaining after filtering.
 *
 * `state.selectedRowIds` can still hold ids of rows that filtering has hidden (or that were removed),
 * whereas `selectedFlatRows` only holds selected rows that are still present. When the number of ids
 * derived from `selectedFlatRows` differs from the number stored in state, the selection is replaced
 * with the derived ids via the `selectFilteredRows` action.
 *
 * `ensurePluginOrder` is called so that `useRowSelect` must be registered ahead of this plugin.
 *
 * @param instance - the table instance being built.
 */
function useInstance<T extends BaseRecord>(instance: TableInstance<T>) {
	const { plugins, state, dispatch, selectedFlatRows } = instance;

	ensurePluginOrder(plugins, ['useRowSelect'], PLUGIN_NAME);

	const selectFilteredRows = useCallback(
		(value: string[]) => dispatch({ type: actions.selectFilteredRows, value }),
		[dispatch]
	);

	// Ids currently stored in state: this is the selection from before filters were applied.
	const prevSelectedRowIds = Object.keys(state.selectedRowIds);

	// Ids of the selected rows that remain after filters were applied.
	// Memoized on the `selectedFlatRows` array itself, not its length: keying on the length reused
	// stale ids whenever the selection changed to a different set of the same size (e.g. a
	// programmatic change of `selectedRowIds`), which could dispatch an outdated selection.
	// If the array identity is not stable across renders this merely recomputes an O(n) pass.
	const newSelectedRowIds = useMemo(() => getSelectedFlatRows(selectedFlatRows), [selectedFlatRows]);

	// Sync only when there is a stored selection and its size disagrees with the post-filter one.
	// This also ensures that removed rows leave no remnants in selectedRowIds state.
	if (prevSelectedRowIds.length && prevSelectedRowIds.length !== newSelectedRowIds.length) {
		selectFilteredRows(newSelectedRowIds);
	}
}
