import numpy as np
from scipy.fft import fft2, ifft2, fftshift, ifftshift


# ===== Fast SFS Algorithm Implementation =====

def optimal_kspace_sampling(size, ros_mask, n_samples=None, verbose=True):
    """
    Greedily select k-space sample locations for a given region of support (ROS).

    One location is added per iteration, starting from the centre of the
    (size, size) grid. While fewer than ``n_ros`` samples have been chosen the
    candidate with the largest ``sigma`` score is taken (the original comments
    attribute this to eq. (21) of Gao and Reeves); afterwards a criterion based
    on the ``u`` vectors is used (attributed to eq. (11)).

    Parameters:
    -----------
    size : int
        Side length of the square k-space / image grid.
    ros_mask : ndarray
        Mask whose non-zero entries (in flattened row-major order) form the ROS.
    n_samples : int, optional
        Number of locations to select. Defaults to the number of ROS pixels.
        The centre sample is always returned, so at least one index comes back
        even when ``n_samples`` is 0.
    verbose : bool, optional
        Whether to print progress information.

    Returns:
    --------
    tuple
        (selected_indices, sampling_pattern)
        - selected_indices: list of flat indices (``row * size + col``) in
          selection order, beginning with the centre sample
        - sampling_pattern: (size, size) float array with 1 at selected locations

    Raises:
    -------
    ValueError
        If ``n_samples`` exceeds the number of grid locations ``size * size``.
    """
    ros_indices = np.nonzero(ros_mask.flatten())[0]
    n_ros = len(ros_indices)

    # Set default number of samples if not specified
    if n_samples is None:
        n_samples = n_ros

    N = size * size
    if n_samples > N:
        # Previously this surfaced much later as "max() arg is an empty sequence".
        raise ValueError(
            f"n_samples ({n_samples}) exceeds the number of k-space locations ({N})"
        )

    if verbose:
        print(f"ROS size: {n_ros} pixels")
        print(f"Number of samples to select: {n_samples}")

    # The first sample is the centre of the grid (DC in an fftshift-ed layout).
    mid_point = size // 2
    dc_idx = mid_point * size + mid_point
    selected_indices = [dc_idx]   # keeps selection order for the caller
    selected_set = {dc_idx}       # O(1) membership tests inside the loops

    if verbose:
        print(f"Selected sample 1/{n_samples} (DC component)")

    # Per-candidate recursion state.
    # v_dict is maintained by the recursion but is not read by either
    # selection criterion below.
    v_dict = {}
    u_dict = {}
    sigma_dict = {}

    # Fourier encoding matrix: row k is k-space location k, column j is ROS
    # pixel j. Same expression as the scalar form, evaluated for all pairs.
    # NOTE(review): the raw grid index is used as the frequency here, while
    # dc_idx assumes a centred (fftshift-ed) grid, so row dc_idx is not a
    # constant vector. Confirm which convention is intended.
    k_all = np.arange(N)
    k_y, k_x = k_all // size, k_all % size
    ros_y, ros_x = ros_indices // size, ros_indices % size
    phase = -2j * np.pi * (
        (k_y[:, None] * ros_y[None, :] / size)
        + (k_x[:, None] * ros_x[None, :] / size)
    )
    fourier_basis = np.exp(phase)

    # Initialise the recursion from the first (centre) sample.
    a_dc = fourier_basis[dc_idx, :]
    sigma_dc = np.real(np.vdot(a_dc, a_dc))
    # Every candidate starts from the same v; updates always build new arrays,
    # so sharing this one object is safe.
    v_init = a_dc / sigma_dc

    for k_idx in range(N):
        if k_idx == dc_idx:
            continue
        a_k = fourier_basis[k_idx, :]
        v_dict[k_idx] = v_init
        sigma_dict[k_idx] = np.real(
            np.vdot(a_k, a_k) - np.abs(np.vdot(a_k, a_dc))**2 / sigma_dc
        )

    # Main selection loop
    for i in range(1, n_samples):
        candidates = [k for k in range(N) if k not in selected_set]

        if i < n_ros:
            # Underdetermined phase: largest sigma wins (first one on ties).
            best_idx = max(candidates, key=lambda k: sigma_dict[k])
        else:
            # Determined/overdetermined phase.
            # NOTE(review): this key is complex-valued, so the ordering relies
            # on NumPy's comparison of complex scalars. Confirm against the
            # intended criterion.
            best_idx = max(
                candidates,
                key=lambda k: np.vdot(fourier_basis[k, :], u_dict.get(k, np.zeros(n_ros)))
            )

        selected_indices.append(best_idx)
        selected_set.add(best_idx)

        if verbose and ((i+1) % 50 == 0 or i+1 == n_samples):
            print(f"Selected sample {i+1}/{n_samples}")

        if i + 1 == n_samples:
            # Nothing is selected after this point, so the state update
            # (including the matrix inverse at the phase transition) would
            # never be read.
            break

        a_i = fourier_basis[best_idx, :]
        sigma_i = sigma_dict.get(best_idx, 0)
        remaining = [k for k in candidates if k != best_idx]

        # Transition from underdetermined to determined case
        if i + 1 == n_ros:
            # Exactly n_ros rows are selected, so the encoding matrix is square.
            A_selected = fourier_basis[selected_indices, :]

            try:
                A_inv = np.linalg.inv(A_selected)
            except np.linalg.LinAlgError:
                # The message is kept verbatim for callers that parse output;
                # note that no pseudoinverse is computed here, only sigma is
                # updated as in the underdetermined phase.
                print("Matrix inversion failed, using pseudoinverse")
                for k_idx in remaining:
                    a_k = fourier_basis[k_idx, :]
                    contribution = np.abs(np.vdot(a_k, a_i))**2 / sigma_i
                    sigma_dict[k_idx] = max(0, sigma_dict.get(k_idx, 0) - contribution)
            else:
                A_inv_h = A_inv.conj().T
                for k_idx in remaining:
                    u_dict[k_idx] = A_inv_h @ fourier_basis[k_idx, :].conj()
            continue

        if i + 1 < n_ros:
            # Underdetermined phase update (labelled a simplified version of
            # eq. 24 / eq. 31 in the original comments).
            v_best = v_dict.get(best_idx, np.zeros(n_ros))
            for k_idx in remaining:
                a_k = fourier_basis[k_idx, :]
                overlap = np.vdot(a_k, a_i)

                v_dict[k_idx] = v_dict[k_idx] - (overlap / sigma_i) * v_best

                # Clamp at zero so round-off cannot make a score negative.
                contribution = np.abs(overlap)**2 / sigma_i
                sigma_dict[k_idx] = max(0, sigma_dict.get(k_idx, 0) - contribution)
        else:
            # Determined/overdetermined phase update (labelled a simplified
            # version of eq. 42, Sherman-Morrison style, in the original).
            u_best = u_dict.get(best_idx, np.zeros(n_ros))
            denominator = 1 + np.vdot(a_i, u_best)
            for k_idx in remaining:
                if k_idx in u_dict:
                    u_i = u_dict[k_idx]
                    factor = np.vdot(a_i, u_i) / denominator
                    u_dict[k_idx] = u_i - factor * u_best

    # Create sampling pattern for visualization
    sampling_pattern = np.zeros((size, size))
    chosen = np.asarray(selected_indices, dtype=int)
    sampling_pattern[chosen // size, chosen % size] = 1

    return selected_indices, sampling_pattern


# ===== Compressed Sensing Sampling Function =====

def cs_sampling(phantom, ros_mask, n_samples=None, tradeoff=0.3, undersamp_factor=0.25):
    """
    Build a k-space sampling pattern mixing ROS-optimised and random samples.

    A fraction ``1 - tradeoff`` of the budget is chosen by
    ``optimal_kspace_sampling``; the rest is drawn without replacement from
    the remaining locations, weighted by a Gaussian centred on the grid
    centre so that locations near the centre are favoured.

    Parameters:
    -----------
    phantom : ndarray
        Image; only ``phantom.shape[0]`` is used, as the grid side length.
    ros_mask : ndarray
        Mask of the region of support (ROS); non-zero entries are ROS pixels.
    n_samples : int, optional
        Total number of samples. Defaults to ``int(undersamp_factor * n_ros)``.
    tradeoff : float, optional
        Fraction of the budget assigned to random samples.
    undersamp_factor : float, optional
        Used only to derive the default ``n_samples``.

    Returns:
    --------
    tuple
        (selected_indices, sampling_pattern)
        - selected_indices: list of flat indices, deterministic samples first
        - sampling_pattern: (size, size) float array with 1 at selected locations

    Random draws use the global ``np.random`` state; seed it with
    ``np.random.seed`` for reproducible patterns.
    """
    size = phantom.shape[0]
    n_ros = np.sum(ros_mask)

    if n_samples is None:
        n_samples = int(undersamp_factor * n_ros)

    # Determine the split between "deterministic" (ROI-optimized) and random samples
    n_deterministic = int((1 - tradeoff) * n_samples)

    # Get deterministic samples using the ROI-specific algorithm
    deterministic_indices, _ = optimal_kspace_sampling(
        size, ros_mask, n_samples=n_deterministic, verbose=False
    )

    # The selector always returns at least the centre sample, even when asked
    # for zero. Budgeting from what actually came back keeps the total at
    # n_samples instead of n_samples + 1 in that case.
    n_random = max(0, int(n_samples) - len(deterministic_indices))

    # Centred Gaussian weighting over the grid (distance normalised by size).
    center = size // 2
    offsets = (np.arange(size) - center) / size
    dist = np.sqrt(offsets[:, None]**2 + offsets[None, :]**2)
    pdf = np.exp(-10 * dist**2)
    pdf_flat = (pdf / np.sum(pdf)).flatten()

    # Candidates are every location not already taken deterministically.
    available = np.ones(size * size, dtype=bool)
    available[np.asarray(deterministic_indices, dtype=int)] = False
    candidate_indices = np.nonzero(available)[0]

    # Renormalise the weights over the remaining candidates.
    candidate_probs = pdf_flat[candidate_indices]
    total_prob = candidate_probs.sum()
    n_draw = min(n_random, len(candidate_indices))

    if total_prob > 0:
        random_indices = np.random.choice(
            candidate_indices,
            size=n_draw,
            replace=False,
            p=candidate_probs / total_prob
        )
    else:
        # Fallback if all probabilities are zero: draw uniformly.
        random_indices = np.random.choice(
            candidate_indices,
            size=n_draw,
            replace=False
        )

    # Combine deterministic and random samples
    selected_indices = list(deterministic_indices) + list(random_indices)

    # Create sampling pattern for visualization
    sampling_pattern = np.zeros((size, size))
    chosen = np.asarray(selected_indices, dtype=int)
    sampling_pattern[chosen // size, chosen % size] = 1

    return selected_indices, sampling_pattern


# ===== Standard Reconstruction Function =====

def reconstruct_from_samples(phantom, selected_indices, ros_mask, size):
    """
    Reconstruct the ROS image from selected k-space samples by direct solve.

    The measurements are read from ``fftshift(fft2(phantom))``, so flat index
    ``k`` refers to a centred grid whose zero frequency sits at
    ``size // 2``. The explicit encoding matrix is built with the matching
    centred frequencies. For large problems the matrix-free
    ``reconstruct_from_samples_cg`` avoids forming this matrix.

    Parameters:
    -----------
    phantom : ndarray
        Original (size, size) image used to generate the k-space data.
    selected_indices : sequence of int
        Flat indices (``row * size + col``) into the shifted k-space grid.
    ros_mask : ndarray
        Mask of the region of support (ROS); non-zero entries are unknowns.
    size : int
        Side length of the square image.

    Returns:
    --------
    ndarray
        Complex (size, size) image, zero outside the ROS. The solution is the
        minimum-norm one when there are fewer samples than ROS pixels, the
        exact one when the counts match, and the least-squares one otherwise.
    """
    sample_idx = np.asarray(selected_indices, dtype=int)
    ros_indices = np.nonzero(ros_mask.flatten())[0]
    n_samples = len(sample_idx)
    n_ros = len(ros_indices)

    k_space_selected = fftshift(fft2(phantom)).flatten()[sample_idx]

    # Encoding matrix A[i, j] = exp(-2*pi*1j * (f_y*r_y + f_x*r_x) / size).
    # fftshift places frequency 0 at index size // 2, so the frequency of
    # grid index k is k - size // 2. Using the raw index instead would
    # reconstruct the image multiplied by a per-pixel phase factor.
    center = size // 2
    f_y = sample_idx // size - center
    f_x = sample_idx % size - center
    r_y, r_x = ros_indices // size, ros_indices % size
    A = np.exp(-2j * np.pi * (np.outer(f_y, r_y) + np.outer(f_x, r_x)) / size)

    # Solve the linear system
    if n_samples < n_ros:
        # Underdetermined system - minimum norm solution x = A^H (A A^H)^-1 b
        A_h = A.conj().T
        gram = A @ A_h
        try:
            x = A_h @ np.linalg.solve(gram, k_space_selected)
        except np.linalg.LinAlgError:
            # Use pseudoinverse if the Gram matrix is singular
            print("Explicit min. norm solution failed - using psuedoinverse")
            x = A_h @ (np.linalg.pinv(gram) @ k_space_selected)
    elif n_samples == n_ros:
        # Determined system
        try:
            x = np.linalg.solve(A, k_space_selected)
        except np.linalg.LinAlgError:
            # Use pseudoinverse if direct solve fails
            print("np.linalg.solve failed - trying psuedoinverse")
            x = np.linalg.pinv(A) @ k_space_selected
    else:
        # Overdetermined system - use least squares
        x = np.linalg.lstsq(A, k_space_selected, rcond=None)[0]

    image_reconstructed = np.zeros(size*size, dtype=complex)
    image_reconstructed[ros_indices] = x

    return image_reconstructed.reshape(size, size)

def reconstruct_from_samples_cg(phantom, selected_indices, ros_mask, size, max_iter=100, tol=1e-6, verbose=False):
    """
    Reconstruct the image from selected k-space samples using conjugate gradients.

    The system ``A x = b`` (ROS pixels -> selected k-space samples) is solved
    in the least-squares sense with CGLS, i.e. conjugate gradients applied to
    the normal equations ``A^H A x = A^H b``. ``A`` is never formed: it is
    applied through FFTs.

    Parameters:
    -----------
    phantom : ndarray
        Original image (used for k-space data generation)
    selected_indices : list
        Flat indices of selected k-space samples in the fftshift-ed grid.
        Indices are assumed to be unique.
    ros_mask : ndarray
        Binary mask of the region of support (ROS)
    size : int
        Size of the image (assuming square image)
    max_iter : int, optional
        Maximum number of CG iterations
    tol : float, optional
        Convergence tolerance on the data residual norm relative to the
        residual of the zero-filled starting point
    verbose : bool, optional
        Whether to print progress information

    Returns:
    --------
    ndarray
        Complex (size, size) reconstructed image, zero outside the ROS.
    """
    sample_idx = np.asarray(selected_indices, dtype=int)
    ros_indices = np.nonzero(ros_mask.flatten())[0]
    n_pixels = size * size

    # Measured data: selected entries of the shifted k-space of the phantom
    k_space_selected = fftshift(fft2(phantom)).flatten()[sample_idx]

    if verbose:
        print(f"Reconstructing with CG: {len(sample_idx)} samples, {len(ros_indices)} ROS pixels")

    def embed_ros(values):
        """Place ROS values into a zero (size, size) image."""
        full_image = np.zeros(n_pixels, dtype=complex)
        full_image[ros_indices] = values
        return full_image.reshape(size, size)

    def forward_op(x):
        """Apply A*x: ROS to k-space samples"""
        return fftshift(fft2(embed_ros(x))).flatten()[sample_idx]

    def adjoint_op(y):
        """Apply A^H*y: k-space samples to ROS"""
        k_space_full = np.zeros(n_pixels, dtype=complex)
        k_space_full[sample_idx] = y
        image_full = ifft2(ifftshift(k_space_full.reshape(size, size)))
        # ifft2 divides by size*size; undo that so this is the exact adjoint
        # of the unnormalised fft2 in forward_op. CG step sizes depend on it.
        return n_pixels * image_full.flatten()[ros_indices]

    # Warm start: zero-filled inverse FFT restricted to the ROS
    x = adjoint_op(k_space_selected) / n_pixels

    # Compute residual: r = b - A*x
    r = k_space_selected - forward_op(x)
    initial_r_norm_sq = np.vdot(r, r).real
    b_norm_sq = np.vdot(k_space_selected, k_space_selected).real

    if verbose:
        print(f"Initial residual: {np.sqrt(initial_r_norm_sq)}")

    # If the warm start already matches the data to round-off there is nothing
    # to iterate on, and dividing by the initial residual would give 0/0.
    residual_floor_sq = (100 * np.finfo(float).eps)**2 * b_norm_sq
    if initial_r_norm_sq <= residual_floor_sq:
        if verbose:
            print("Converged in 0 iterations")
        return embed_ros(x)

    # CGLS state: s is the normal-equation residual A^H r, p the search direction
    s = adjoint_op(r)
    p = s
    gamma = np.vdot(s, s).real

    converged = False
    stalled = False

    # Main CG loop
    for iteration in range(max_iter):
        Ap = forward_op(p)
        Ap_norm_sq = np.vdot(Ap, Ap).real

        if gamma == 0.0 or Ap_norm_sq == 0.0:
            # Gradient of the least-squares objective vanished: x cannot be
            # improved further, and continuing would divide by zero.
            stalled = True
            break

        # Step size minimising ||b - A(x + alpha*p)||
        alpha = gamma / Ap_norm_sq

        # Update solution and data residual
        x = x + alpha * p
        r = r - alpha * Ap

        relative_residual = np.sqrt(np.vdot(r, r).real / initial_r_norm_sq)

        if verbose and (iteration % 10 == 0 or iteration + 1 == max_iter or relative_residual < tol):
            # Error against the phantom inside the ROS, for monitoring only
            phantom_ros = phantom.flatten()[ros_indices]
            ros_mse = np.mean(np.abs(x - phantom_ros)**2)
            print(f"Iteration {iteration+1}/{max_iter}, Relative residual: {relative_residual:.6e}, ROS MSE: {ros_mse:.6f}")

        if relative_residual < tol:
            converged = True
            if verbose:
                print(f"Converged in {iteration+1} iterations")
            break

        # New normal-equation residual and conjugate search direction
        s = adjoint_op(r)
        gamma_new = np.vdot(s, s).real
        beta = gamma_new / gamma
        p = s + beta * p
        gamma = gamma_new

    if verbose and not converged and not stalled:
        print(f"Maximum iterations ({max_iter}) reached without convergence")

    # Create reconstructed image
    return embed_ros(x)





def convert_indices_to_mask(selected_indices, size):
    """Return a (size, size) float mask with 1 at each flat index in selected_indices."""
    flat = np.asarray(selected_indices, dtype=np.intp)
    mask = np.zeros((size, size))
    # Flat index -> (row, column) in row-major order.
    rows, cols = flat // size, flat % size
    mask[rows, cols] = 1
    return mask
