import { Matrix } from 'ml-matrix';
import { multivariateNormalPDF, argmax } from '../../mathUtils';
import { CLUSTERING_CONFIG } from '../constants';

/**
 * Gaussian Mixture Model with full covariance matrices, fitted by
 * Expectation-Maximization (EM).
 *
 * Each call to `fit` re-initializes the parameters from the data:
 * - means use deterministic farthest-point seeding;
 * - covariances start as diagonal matrices of per-feature variance;
 * - weights start uniform.
 *
 * It then alternates E- and M-steps until the change in log-likelihood drops
 * below the tolerance or the iteration budget is exhausted.
 */
export class GMM {
  private means: Matrix;
  private covariances: Matrix[];
  private weights: number[];
  private readonly nClusters: number;
  private readonly nFeatures: number;
  private readonly regCovar: number;

  /**
   * Create a new GMM instance.
   * @param nClusters Number of mixture components (positive integer)
   * @param nFeatures Number of features (columns) in the data (positive integer)
   * @param regCovar Non-negative value added to every covariance diagonal to
   *   keep the matrices invertible (default 1e-6)
   * @throws RangeError if any argument is out of range
   */
  constructor(nClusters: number, nFeatures: number, regCovar = 1e-6) {
    if (!Number.isInteger(nClusters) || nClusters < 1) {
      throw new RangeError(`GMM: nClusters must be a positive integer, got ${nClusters}`);
    }
    if (!Number.isInteger(nFeatures) || nFeatures < 1) {
      throw new RangeError(`GMM: nFeatures must be a positive integer, got ${nFeatures}`);
    }
    // Written as a negated comparison so NaN is rejected too.
    if (!(regCovar >= 0)) {
      throw new RangeError(`GMM: regCovar must be non-negative, got ${regCovar}`);
    }
    this.nClusters = nClusters;
    this.nFeatures = nFeatures;
    this.regCovar = regCovar;
    this.means = Matrix.zeros(nClusters, nFeatures);
    // One identity matrix per component; Array.fill would share a single instance.
    this.covariances = Array.from({ length: nClusters }, () => Matrix.eye(nFeatures));
    this.weights = new Array<number>(nClusters).fill(1 / nClusters);
  }

  /**
   * Throw if the data does not have exactly `nFeatures` columns.
   * @param X Data matrix
   * @throws RangeError on a feature-count mismatch
   */
  private assertFeatureCount(X: Matrix): void {
    if (X.columns !== this.nFeatures) {
      throw new RangeError(`GMM: expected ${this.nFeatures} features, got ${X.columns}`);
    }
  }

  /**
   * Mixing weight of component k times its Gaussian density at point x.
   * @param x One data row
   * @param k Component index
   * @returns weights[k] * N(x | means[k], covariances[k])
   */
  private weightedDensity(x: number[], k: number): number {
    return (
      multivariateNormalPDF(x, this.means.getRow(k), this.covariances[k]) * this.weights[k]
    );
  }

  /**
   * Reset all parameters from the data so EM starts from distinct components.
   *
   * Means are seeded by deterministic farthest-point selection. Row 0 is the
   * first seed. Each further seed is the row with the largest squared Euclidean
   * distance to its nearest seed so far. If every mean started at the same point
   * (for example all zeros), every component would get identical
   * responsibilities, and EM could never separate them.
   *
   * Covariances start as diagonal matrices of per-feature population variance
   * (a zero-variance feature uses 1) plus `regCovar`. Weights start uniform.
   * @param X Data matrix with at least one row
   */
  private initializeParameters(X: Matrix): void {
    const rows = X.to2DArray();
    const n = rows.length;

    const seeds = [0];
    // Squared distance from each row to its nearest seed chosen so far.
    const nearestSeedDist = new Array<number>(n).fill(Infinity);
    while (seeds.length < this.nClusters) {
      const lastSeed = rows[seeds[seeds.length - 1]];
      let best = 0;
      let bestDist = -1;
      for (let i = 0; i < n; i++) {
        let d = 0;
        for (let j = 0; j < this.nFeatures; j++) {
          const delta = rows[i][j] - lastSeed[j];
          d += delta * delta;
        }
        nearestSeedDist[i] = Math.min(nearestSeedDist[i], d);
        if (nearestSeedDist[i] > bestDist) {
          bestDist = nearestSeedDist[i];
          best = i;
        }
      }
      // If every row already coincides with a seed, this repeats one. Those
      // duplicate components start identical.
      seeds.push(best);
    }
    this.means = new Matrix(seeds.map((i) => rows[i]));

    const variances = Array.from({ length: this.nFeatures }, (_, j) => {
      let mean = 0;
      for (const row of rows) mean += row[j];
      mean /= n;
      let squares = 0;
      for (const row of rows) squares += (row[j] - mean) ** 2;
      const variance = squares / n;
      return (variance > 0 ? variance : 1) + this.regCovar;
    });
    this.covariances = Array.from({ length: this.nClusters }, () => Matrix.diag(variances));
    this.weights = new Array<number>(this.nClusters).fill(1 / this.nClusters);
  }

  /**
   * E-step: compute the posterior probability (responsibility) of each component for each row.
   * @param X Data matrix, one row per sample
   * @returns (X.rows × nClusters) matrix whose rows sum to 1. A row whose
   *   weighted densities are all 0 (e.g. after floating-point underflow) is
   *   left as all zeros.
   */
  private expectation(X: Matrix): Matrix {
    const responsibilities = Matrix.zeros(X.rows, this.nClusters);
    for (let i = 0; i < X.rows; i++) {
      const x = X.getRow(i);
      let total = 0;
      for (let k = 0; k < this.nClusters; k++) {
        const p = this.weightedDensity(x, k);
        responsibilities.set(i, k, p);
        total += p;
      }
      if (total > 0) {
        for (let k = 0; k < this.nClusters; k++) {
          responsibilities.set(i, k, responsibilities.get(i, k) / total);
        }
      }
    }
    return responsibilities;
  }

  /**
   * M-step: re-estimate weights, means and covariances from responsibilities.
   *
   * A component whose total responsibility is not positive keeps its previous
   * mean and covariance, because dividing by it would produce NaN.
   * @param X Data matrix (not modified)
   * @param responsibilities Output of the E-step
   */
  private maximization(X: Matrix, responsibilities: Matrix): void {
    // Effective number of points assigned to each component.
    const Nk = responsibilities.sum('column');
    this.weights = Nk.map((nk) => nk / X.rows);

    for (let k = 0; k < this.nClusters; k++) {
      if (!(Nk[k] > 0)) continue;
      const r = responsibilities.getColumn(k);

      // mean_k = Xᵀ r / Nk. r must be an (N × 1) column vector for the product.
      const mean = X.transpose().mmul(Matrix.columnVector(r)).div(Nk[k]).getColumn(0);
      this.means.setRow(k, mean);

      // cov_k = Σ_i r_i (x_i − μ_k)(x_i − μ_k)ᵀ / Nk.
      // ml-matrix instance sub*/mul* methods mutate their receiver. Clone X so
      // the caller's data stays intact, and clone diff so it is weighted once.
      const diff = X.clone().subRowVector(mean);
      const weighted = diff.clone().mulColumnVector(r);
      const covariance = diff.transpose().mmul(weighted).div(Nk[k]);

      // Regularize the diagonal to prevent singular covariance matrices.
      for (let j = 0; j < this.nFeatures; j++) {
        covariance.set(j, j, covariance.get(j, j) + this.regCovar);
      }
      this.covariances[k] = covariance;
    }
  }

  /**
   * Log-likelihood of the data under the current parameters.
   * @param X Data matrix
   * @returns Σ_i log Σ_k weights[k] · N(x_i | means[k], covariances[k]),
   *   with each row's density clamped to at least 1e-10
   */
  private calculateLogLikelihood(X: Matrix): number {
    let logLikelihood = 0;
    for (let i = 0; i < X.rows; i++) {
      const x = X.getRow(i);
      let density = 0;
      for (let k = 0; k < this.nClusters; k++) {
        density += this.weightedDensity(x, k);
      }
      // Clamp so a zero-density row adds log(1e-10) instead of -Infinity.
      logLikelihood += Math.log(Math.max(density, 1e-10));
    }
    return logLikelihood;
  }

  /**
   * Fit the GMM to data with EM, starting from parameters re-initialized from `X`.
   * @param X Data matrix, one row per sample, `nFeatures` columns (not modified)
   * @param maxIterations Maximum number of EM iterations
   * @param tolerance Stop once the absolute change in log-likelihood is below this
   * @returns Hard cluster label per row: the index of the component with the
   *   largest responsibility under the final parameters (empty for empty `X`)
   * @throws RangeError if `X` does not have `nFeatures` columns
   */
  fit(
    X: Matrix,
    maxIterations = CLUSTERING_CONFIG.MAX_ITERATIONS,
    tolerance = CLUSTERING_CONFIG.CONVERGENCE_THRESHOLD
  ): number[] {
    this.assertFeatureCount(X);
    if (X.rows === 0) {
      return [];
    }
    this.initializeParameters(X);

    let prevLogLikelihood = -Infinity;
    for (let iter = 0; iter < maxIterations; iter++) {
      const responsibilities = this.expectation(X);
      this.maximization(X, responsibilities);

      const logLikelihood = this.calculateLogLikelihood(X);
      if (Math.abs(logLikelihood - prevLogLikelihood) < tolerance) {
        console.log(`GMM converged after ${iter + 1} iterations`);
        break;
      }
      prevLogLikelihood = logLikelihood;
    }

    // Labels come from the final parameters. Labels from the loop's last
    // E-step would be one M-step out of date.
    return argmax(this.expectation(X), 'row');
  }

  /**
   * Predict cluster membership probabilities for new data.
   * @param X Data matrix with `nFeatures` columns
   * @returns (X.rows × nClusters) matrix of normalized probabilities (see the
   *   E-step for the all-zero-row case)
   * @throws RangeError if `X` does not have `nFeatures` columns
   */
  predict(X: Matrix): Matrix {
    this.assertFeatureCount(X);
    return this.expectation(X);
  }
}
