import {
	BarChartContent,
	ConditionalEvidence,
	CountLineChartContent,
	DataIndexDimensionType,
	ItemOptionDimension,
	Matrix,
	VisualizationDimensionType,
} from '@evasys/globals/evainsights/models/report-item';
import { isAbstention, isScalaItem } from '@evasys/globals/evainsights/helper/item';
import {
	combinations,
	findIndices,
	namedZipShortest,
	onlyIndexOf,
	zipShortest,
} from '@evasys/globals/shared/helper/array';
import {
	isMatrixNdimensional,
	matrixAt,
	matrixFullBy,
	matrixMultiply,
	matrixReshape,
	matrixShape,
	matrixSum,
	matrixTranspose,
} from '../matrix';
import { at, isEqual, range } from 'lodash';
import { OrdinalStatisticsFromCounts, Statistics } from '../math';
import { findPeriodMapping } from '@evasys/globals/evainsights/typeguards/axis-chart-dimension-mapping';
import { assertNotNullish, isNotNullish } from '@evasys/globals/evainsights/typeguards/common';

/**
 * How a single chart dimension is treated when the response counts are split into segments.
 */
export enum DimensionSegmentationType {
	/**
	 * Separate the responses based on the data dimension.
	 * Makes the segmentation output separate statistics for each value of this dimension.
	 */
	SEPARATE = 'SEPARATE',
	/** Add up the responses from each value along this dimension. */
	MERGE = 'MERGE',
	/**
	 * Retain the (ITEM_OPTION) dimension for the purpose of processing the scaled responses.
	 * The different values along an ORDINAL dimension are placed in the same segment but are not added up;
	 * instead, the segment keeps one count per value (option).
	 */
	ORDINAL = 'ORDINAL',
}

// A segment is either a `CountSegment` (total count only) or an `OrdinalSegment` (per-option counts of a scaled
// question); the two variants can be told apart with `'responseCount' in segment`.
export type Segment = CountSegment | OrdinalSegment;

// The segment's position with regard to `content.data.responseCounts`, with one entry per dimension.
// A `number` at a given index means that the segment covers only the cross-section of the responses at that index
// of the corresponding dimension.
// A `null` means that the segment combines the responses across all values along that dimension.
// A position consisting only of `null`s thus summarizes all responses included in the chart, giving, e.g., the
// total number of responses to the question.
// In contrast, a position consisting only of `number`s refers to the responses counted in a single entry of the
// `responseCounts` matrix.
export type SegmentPosition = Array<number | null>;

/**
 * Fields shared by all segments. A segment combines the responses across possibly multiple entries of the chart
 * data's `responseCounts`.
 */
interface SegmentBase {
	// Which part of `responseCounts` the segment covers; see `SegmentPosition`.
	position: SegmentPosition;
	// The number of abstention responses within the segment; these are not part of the regular response counts.
	abstentionCount: number;
	// The evidence value looked up for this segment from the chart's conditional evidences.
	// NOTE(review): the meaning of this value is defined by whoever produces `conditionalEvidences` — confirm.
	evidence: number;
}

/**
 * A segment that only reports the total number of its (non-abstention) responses.
 */
export interface CountSegment extends SegmentBase {
	// The total number of non-abstention responses within the segment
	responseCount: number;
}

/**
 * A segment of a scaled question that keeps its (non-abstention) responses per option so that ordinal statistics
 * can be computed from them.
 */
export interface OrdinalSegment extends SegmentBase {
	// The number of responses to each option of a scaled question, excluding abstention options.
	// The values are in the normalized order, i.e., unaffected by the form's or question's mirroring settings.
	optionResponseCounts: number[];
}

/**
 * Computes the statistics shown for a response-count bar chart or count line chart.
 *
 * The chart's response counts are split into segments according to the dimension mappings, and each segment is
 * converted into `ResponseSegmentStatistics`.
 *
 * @param content the chart content providing the dimension mappings, items and response counts
 * @returns one entry per segment, pairing the segment's position with its statistics
 */
export const getResponseCountAxisChartStatistics = (content: BarChartContent | CountLineChartContent) => {
	const segmentationTypes = determineDimensionSegmentationTypes(content);
	const abstentionIndices = getDimensionAbstentionIndices(content);
	const { responseCounts, conditionalEvidences } = content.data;

	const segments = computeSegments(responseCounts, conditionalEvidences, segmentationTypes, abstentionIndices);
	return segments.map((segment) => {
		const statistics = getSegmentStatistics(segment);
		return { position: segment.position, statistics };
	});
};

/**
 * Decides, for every dimension mapping of the chart, how that dimension is treated during segmentation.
 * Visible for testing only.
 *
 * - ITEM_OPTION dimensions become ORDINAL when every domain point consists solely of options of scala items and the
 *   chart has no period mapping; otherwise they are merged.
 * - ITEM dimensions become SEPARATE when they are visualized as a GROUP; otherwise they are merged.
 * - All other dimension types are merged.
 *
 * @returns one segmentation type per entry of `content.config.dimensionMappings`, in the same order
 */
export const determineDimensionSegmentationTypes = (
	content: BarChartContent | CountLineChartContent
): DimensionSegmentationType[] => {
	const scalaItems = content.data.items ? content.data.items.filter(isScalaItem) : [];
	const scalaItemOptionIds = new Set(scalaItems.flatMap((item) => item.itemOptions.map(({ id }) => id)));

	// NOTE(review): ordinal segmentation is only used when the chart has no period mapping — confirm that this is
	// the intended condition.
	const hasPeriodMapping = findPeriodMapping(content.config.dimensionMappings) !== undefined;

	return content.config.dimensionMappings.map((mapping): DimensionSegmentationType => {
		if (mapping.data.type === DataIndexDimensionType.ITEM_OPTION) {
			const consistsOfScalaOptions = mapping.data.domain.every(({ itemOptionIds }) =>
				itemOptionIds.every((itemOptionId) => scalaItemOptionIds.has(itemOptionId))
			);
			return consistsOfScalaOptions && !hasPeriodMapping
				? DimensionSegmentationType.ORDINAL
				: DimensionSegmentationType.MERGE;
		}

		if (mapping.data.type === DataIndexDimensionType.ITEM) {
			const isGrouped = mapping.visualizations.some(({ type }) => type === VisualizationDimensionType.GROUP);
			return isGrouped ? DimensionSegmentationType.SEPARATE : DimensionSegmentationType.MERGE;
		}

		return DimensionSegmentationType.MERGE;
	});
};

/**
 * Splits the response counts into segments according to the per-dimension segmentation types and attaches the
 * abstention count and evidence to each segment.
 * Visible for testing only.
 *
 * @param responseCounts the chart's response counts, one matrix dimension per segmentation type
 * @param conditionalEvidences the evidences available for the chart; one of them must match the SEPARATE dimensions
 * @param segmentations how each dimension of `responseCounts` is handled
 * @param dimensionAbstentionIndices for each dimension, the indices along it that are abstentions
 * @returns one segment per combination of indices along the SEPARATE dimensions
 * @throws Error if the abstention counts are not one-dimensional
 */
export const computeSegments = (
	responseCounts: Matrix,
	conditionalEvidences: ConditionalEvidence[],
	segmentations: DimensionSegmentationType[],
	dimensionAbstentionIndices: Set<number>[]
): Segment[] => {
	const shape = matrixShape(responseCounts);

	// Regular (non-abstention) responses: one count per segment, or one count per option if a dimension is ORDINAL.
	const answeredCounts = getSegmentResponseCounts(
		getResponseCountsWithoutAbstentions(dimensionAbstentionIndices, responseCounts),
		segmentations
	);

	// Abstentions only need a single number per segment, so ORDINAL dimensions are merged for them.
	const abstentionSegmentations = segmentations.map((segmentation) =>
		segmentation === DimensionSegmentationType.ORDINAL ? DimensionSegmentationType.MERGE : segmentation
	);
	// NOTE(review): presumes `matrixMultiply` multiplies element-wise, so only the abstention entries remain — confirm.
	const abstentionResponseCounts = matrixMultiply(
		responseCounts,
		getAbstentionMask(dimensionAbstentionIndices, shape)
	);
	const abstentionCounts = getSegmentResponseCounts(abstentionResponseCounts, abstentionSegmentations);
	if (!isMatrixNdimensional(abstentionCounts, 1)) {
		throw Error('Abstention counts not one-dimensional');
	}

	// A SEPARATE dimension contributes each of its indices; every other dimension is summarized as `null`.
	const positions = combinations(
		zipShortest(segmentations, shape).map(([segmentation, size]) =>
			segmentation === DimensionSegmentationType.SEPARATE ? range(size) : [null]
		)
	);

	const evidences = getEvidences(conditionalEvidences, segmentations);

	const segmentParts = namedZipShortest({
		nonAbstentionCount: answeredCounts,
		abstentionCount: abstentionCounts,
		position: positions,
	});
	return segmentParts.map(({ position, abstentionCount, nonAbstentionCount }): Segment => {
		const base = {
			position,
			abstentionCount,
			// the evidence matrix is indexed by the SEPARATE dimensions only
			evidence: matrixAt(evidences, position.filter(isNotNullish)),
		};
		if (typeof nonAbstentionCount === 'number') {
			return { ...base, responseCount: nonAbstentionCount };
		}
		return { ...base, optionResponseCounts: nonAbstentionCount };
	});
};

/**
 * Aggregates the response counts into per-segment counts according to the segmentation types.
 *
 * MERGE dimensions are summed out. The remaining SEPARATE dimensions are flattened into a single segment axis. An
 * ORDINAL dimension, if present, is kept as the trailing axis that holds the per-option counts.
 *
 * @param responseCounts response counts with one matrix dimension per segmentation type
 * @param segmentations how each dimension of `responseCounts` is handled
 * @returns one count per segment, or one array of option counts per segment if an ORDINAL dimension is present
 * @throws Error if the aggregated counts are neither one- nor two-dimensional
 */
const getSegmentResponseCounts = (
	responseCounts: Matrix,
	segmentations: DimensionSegmentationType[]
): number[] | number[][] => {
	// notes for future implementation once we allow for more flexible charting:
	// - show scala statistics per group if there is only one scaled ITEM_OPTION dimension mapped inside the group
	// - show scala statistics per cell if there is only one scaled ITEM_OPTION dimension mapped inside the cell
	// - else show only count statistics inside the cell

	// Worked example: a chart with three dimensions of size 2×3×4 under two configurations
	//   a) [MERGE, ORDINAL, SEPARATE]
	//   b) [SEPARATE, MERGE, SEPARATE]

	// Step 1: sum up the counts along all MERGE dimensions.
	//   a) dimension 0 is summed out, dimensions [1, 2] remain -> 3×4
	//   b) dimension 1 is summed out, dimensions [0, 2] remain -> 2×4
	const keptDimensions = findIndices(
		segmentations,
		(segmentation) => segmentation !== DimensionSegmentationType.MERGE
	);
	const summed = matrixSum(responseCounts, { exclude: keptDimensions });

	// Step 2: if there is an ORDINAL dimension, move it to the last axis.
	//   a) the ORDINAL dimension is the first remaining axis -> transpose with axes [1, 0] -> 4×3
	//   b) no ORDINAL dimension -> unchanged 2×4
	const ordinalAxis = onlyIndexOf(at(segmentations, keptDimensions), DimensionSegmentationType.ORDINAL);
	const ordinalLast =
		ordinalAxis === -1
			? summed
			: matrixTranspose(summed, [
					// every remaining axis in its original order, except the ordinal one ...
					...range(keptDimensions.length).filter((axis) => axis !== ordinalAxis),
					// ... which goes last
					ordinalAxis,
			  ]);

	// Step 3: flatten all SEPARATE dimensions into a single segment axis.
	//   a) ORDINAL dimension of size 3 -> reshape to (-1, 3) -> stays 4×3
	//   b) no ORDINAL dimension -> reshape to (-1) -> one-dimensional with 8 entries
	const ordinalDimension = namedZipShortest({
		segmentation: segmentations,
		size: matrixShape(responseCounts),
	}).find(({ segmentation }) => segmentation === DimensionSegmentationType.ORDINAL);
	const segmentCounts = matrixReshape(
		ordinalLast,
		ordinalDimension === undefined ? [-1] : [-1, ordinalDimension.size]
	);
	if (isMatrixNdimensional(segmentCounts, 1) || isMatrixNdimensional(segmentCounts, 2)) {
		return segmentCounts;
	}
	throw Error('The aggregated counts are not one- or two-dimensional');
};

/**
 * Builds a mask of the given shape that is 1 at every position where the index along at least one dimension is
 * marked as an abstention, and 0 everywhere else.
 *
 * @param dimensionAbstentionIndices for each dimension, the indices along it that are abstentions
 * @param responseCountsShape the shape of the mask to build
 */
const getAbstentionMask = (dimensionAbstentionIndices: Set<number>[], responseCountsShape: number[]): Matrix =>
	matrixFullBy(responseCountsShape, (position) => {
		const touchesAbstention = position.some((index, dimension) => dimensionAbstentionIndices[dimension].has(index));
		return touchesAbstention ? 1 : 0;
	});

/**
 * Removes every slice that corresponds to an abstention from the response counts.
 *
 * Along each dimension, the indices contained in that dimension's `dimensionAbstentionIndices` set are dropped.
 * The remaining indices keep their relative order.
 *
 * @returns a matrix whose dimensions are shrunk by the number of abstention indices along each of them
 */
const getResponseCountsWithoutAbstentions = (
	dimensionAbstentionIndices: Set<number>[],
	responseCounts: Matrix
): Matrix => {
	// For each dimension: the input indices that survive, i.e. the complement of its abstention indices.
	const keptIndices = matrixShape(responseCounts).map((size, dimension) => {
		const abstentions = dimensionAbstentionIndices[dimension];
		return range(size).filter((index) => !abstentions.has(index));
	});

	return matrixFullBy(
		keptIndices.map(({ length }) => length),
		// map each output position back to the input position it was taken from
		(outputPosition) =>
			matrixAt(
				responseCounts,
				outputPosition.map((index, dimension) => keptIndices[dimension][index])
			)
	);
};

/**
 * Determines which indices along each dimension in the `content.data.responseCounts` matrix are abstention indices.
 * Visible for testing only.
 *
 * Only ITEM_OPTION dimensions can contain abstentions. An index counts as one when every item option id of the
 * corresponding domain point belongs to an abstention option. Excluded domain points are skipped, so the indices
 * refer to the remaining points only.
 * NOTE(review): this presumes `responseCounts` has no entries for excluded points — confirm.
 * Note that a point without any item option ids also matches, because `every` on an empty list is `true`.
 *
 * @returns one set of abstention indices per dimension mapping, in mapping order
 */
export const getDimensionAbstentionIndices = (content: BarChartContent | CountLineChartContent) => {
	const abstentionOptionIds = new Set(
		content.data.items?.flatMap(({ itemOptions }) => itemOptions.filter(isAbstention).map(({ id }) => id))
	);

	return content.config.dimensionMappings.map(({ data }): Set<number> => {
		if (data.type !== DataIndexDimensionType.ITEM_OPTION) {
			return new Set();
		}

		const includedPoints: ItemOptionDimension['domain'] = data.domain.filter(({ exclude }) => !exclude);
		const abstentionPointIndices = findIndices(includedPoints, ({ itemOptionIds }) =>
			itemOptionIds.every((itemOptionId) => abstentionOptionIds.has(itemOptionId))
		);
		return new Set(abstentionPointIndices);
	});
};

/**
 * Looks up the evidence matrix matching the segmentation: the conditional evidence whose dimension indices are
 * exactly the SEPARATE dimensions (compared with lodash `isEqual`).
 *
 * Fails the `assertNotNullish` assertion if no conditional evidence matches.
 */
const getEvidences = (conditionalEvidences: ConditionalEvidence[], segmentation: DimensionSegmentationType[]) => {
	const separateDimensionIndices = findIndices(segmentation, (seg) => seg === DimensionSegmentationType.SEPARATE);
	const matchingEvidence = conditionalEvidences.find(({ evidenceDimensionIndices }) =>
		isEqual(evidenceDimensionIndices, separateDimensionIndices)
	);
	assertNotNullish(matchingEvidence);
	return matchingEvidence.evidence;
};

/**
 * Statistics describing a single segment of a response-count chart.
 */
export interface ResponseSegmentStatistics<Responses extends Statistics = Statistics> {
	// the segment's evidence value
	evidence: number;
	// the number of abstention responses within the segment
	abstentions: number;
	// statistics about the selected (non-abstention) response values
	values: Responses;
}

/**
 * Converts a segment into the statistics exposed for it: its evidence, its abstention count and statistics about
 * its non-abstention responses.
 */
export const getSegmentStatistics = (segment: Segment): ResponseSegmentStatistics => {
	const { evidence, abstentionCount } = segment;
	return {
		evidence,
		abstentions: abstentionCount,
		values: getSegmentNonAbstentionStatistics(segment),
	};
};

/**
 * Computes the statistics of a segment's non-abstention responses.
 *
 * Count segments only report their sample size. Ordinal segments yield ordinal statistics from their per-option
 * counts, unless every option count is zero; in that case only a sample size of 0 is reported.
 */
const getSegmentNonAbstentionStatistics = (segment: Segment): Statistics => {
	if (!('responseCount' in segment)) {
		const { optionResponseCounts } = segment;
		const hasAnyResponse = optionResponseCounts.some((count) => count > 0);
		return hasAnyResponse ? new OrdinalStatisticsFromCounts(optionResponseCounts) : { sampleSize: 0 };
	}
	return { sampleSize: segment.responseCount };
};
