import { isMatrixNdimensional, matrixShape } from '../common';
import { sum } from 'lodash';
import { nestedPositions } from '../../array';
import { invertAxes, MatrixAxisSpecification } from '../axes';
import { associativeMatrixAggregate } from './util';
import { zeros } from '../creation';
import { Matrix } from '@evasys/globals/evainsights/models/report-item';

/**
 * Sum the cells of `matrix`, collapsing every axis that is not retained.
 *
 * The retained axes are whatever `invertAxes(matrix, axis)` reports (presumably the
 * complement of `axis` — confirm in `../axes`). They become the dimensions of the
 * result, in the order `invertAxes` returns them. Each input cell is added into the
 * output cell addressed by its position restricted to those retained axes.
 *
 * When no axis is retained, the whole matrix is reduced with
 * `associativeMatrixAggregate(matrix, sum)` (presumably yielding the scalar total).
 *
 * @param matrix - the matrix to reduce; it is only read, never modified
 * @param axis - axis specification forwarded to `invertAxes`; the `{ exclude }` form is
 *   typed to produce a one-dimensional result
 * @returns a fresh matrix whose shape is the input shape restricted to the retained axes
 */
export function matrixSum(matrix: Matrix, axis: { exclude: number }): number[];
export function matrixSum(matrix: Matrix, axis?: MatrixAxisSpecification): Matrix;
export function matrixSum(matrix: Matrix, axis?: MatrixAxisSpecification): Matrix {
	const keptAxes = invertAxes(matrix, axis);

	if (keptAxes.length > 0) {
		const sourceShape = matrixShape(matrix);
		// Start from an all-zero accumulator shaped like the kept dimensions.
		const accumulator = zeros(keptAxes.map((axisIndex) => sourceShape[axisIndex]));

		for (const entry of nestedPositions(matrix)) {
			const [sourcePosition, cellValue] = entry;
			// Project the source position onto the kept axes to find the target cell.
			const targetPosition = keptAxes.map((axisIndex) => sourcePosition[axisIndex]);
			incrementCell(accumulator, targetPosition, cellValue);
		}

		return accumulator;
	}

	// Nothing survives the reduction: aggregate over every cell at once.
	return associativeMatrixAggregate(matrix, sum);
}

/**
 * Add `increment` to the single cell of `matrix` addressed by `position`, in place.
 *
 * Every index except the last descends one level into the nested arrays. The last index
 * selects the cell inside the innermost array, which must be one-dimensional.
 *
 * @param matrix - matrix to modify; it is mutated and nothing is returned
 * @param position - index path to the target cell, outermost dimension first; must be non-empty
 * @param increment - amount added to the target cell
 * @throws Error when `position` is empty, or when the matrix runs out of nested levels
 *   before the full `position` path has been followed
 */
function incrementCell(matrix: Matrix, position: number[], increment: number): void {
	const depth = position.length;
	if (depth === 0) {
		throw new Error('Cannot increment a cell of a zero-dimensional matrix');
	}

	// Walk down through all but the final index.
	let cursor = matrix;
	for (const step of position.slice(0, -1)) {
		if (!Array.isArray(cursor)) {
			throw new Error('Matrix has fewer dimensions than position indicates');
		}
		cursor = cursor[step];
	}

	// The remaining level must be a flat row of numbers so the last index hits a cell.
	if (!isMatrixNdimensional(cursor, 1)) {
		throw new Error('Matrix has fewer dimensions than position indicates');
	}

	const lastIndex = position[depth - 1];
	cursor[lastIndex] += increment;
}
