import numpy as np
from scipy.fft import fft2, ifft2, fftshift, ifftshift
import scipy.sparse as sp


class ROISpecificADMM:
    """ADMM reconstruction of an MRI image restricted to a region of interest (ROI).

    Only the ROI pixels are unknowns. The data model embeds the ROI vector into
    an otherwise-zero image, applies a centred 2-D FFT (``fftshift(fft2(.))``)
    and multiplies by the k-space sampling pattern. Regularisation is an L1
    penalty on the horizontal and vertical finite differences between
    neighbouring ROI pixels (anisotropic TV), solved with scaled-form ADMM:

        x-update: conjugate gradient on (A^H A + rho D^H D) x = A^H y + rho D^H (z - u)
        z-update: soft thresholding of D x + u with threshold lambda / rho
        u-update: u <- u + D x - z
    """

    def __init__(self, img_shape, n_iterations=20, rho=0.1, lambda_val=0.001, cg_iterations=15, verbose=False):
        """Store the solver settings.

        Args:
            img_shape: Image dimensions ``(height, width)``.
            n_iterations: Number of ADMM (outer) iterations.
            rho: ADMM penalty parameter.
            lambda_val: Weight of the L1 penalty on the ROI finite differences.
            cg_iterations: Maximum number of conjugate-gradient iterations per x-update.
            verbose: If True, print progress information.
        """
        self.img_shape = img_shape
        self.n_iterations = n_iterations
        self.rho = rho
        self.lambda_val = lambda_val
        self.cg_iterations = cg_iterations
        self.verbose = verbose

    @staticmethod
    def _pair_difference_matrix(first, second, roi_idx_map, n_roi):
        """Return a sparse matrix with one row ``x[second] - x[first]`` per pixel pair.

        Pairs where either pixel is outside the ROI (``roi_idx_map == -1``) are
        dropped; the remaining rows keep the order in which the pairs are given.
        """
        col_first = roi_idx_map[first]
        col_second = roi_idx_map[second]
        keep = (col_first >= 0) & (col_second >= 0)
        col_first = col_first[keep]
        col_second = col_second[keep]

        n_diffs = col_first.size
        rows = np.arange(n_diffs)
        # First pixel gets -1, second pixel gets +1.
        data = np.concatenate([-np.ones(n_diffs), np.ones(n_diffs)])
        return sp.csr_matrix(
            (data, (np.concatenate([rows, rows]), np.concatenate([col_first, col_second]))),
            shape=(n_diffs, n_roi),
        )

    def _compute_flat_diff_matrices(self, roi_indices, shape):
        """Build the finite-difference operator acting on a flattened ROI vector.

        Args:
            roi_indices: Row-major (C-order) flat indices of the ROI pixels. The
                position of an index in this sequence is that pixel's column in ``D``.
            shape: Image shape ``(height, width)``.

        Returns:
            ``(D, D_adj)``. The rows of ``D`` are first the horizontal differences
            ``x[r, c+1] - x[r, c]`` (enumerated in row-major order), then the
            vertical differences ``x[r+1, c] - x[r, c]`` (enumerated in
            column-major order). A difference is kept only when both pixels lie
            in the ROI. ``D_adj`` is the transpose of ``D``.
        """
        height, width = shape
        n_pixels = height * width
        roi_indices = np.asarray(roi_indices, dtype=int)
        n_roi = roi_indices.size

        # Map each full-image flat index to its column in D (-1 = outside the ROI).
        roi_idx_map = np.full(n_pixels, -1, dtype=int)
        roi_idx_map[roi_indices] = np.arange(n_roi)

        flat_index = np.arange(n_pixels).reshape(height, width)

        # Horizontal neighbours in row-major order. Excluding the last column
        # from `left` prevents pairs from wrapping onto the next row.
        left = flat_index[:, :-1].ravel()
        right = flat_index[:, 1:].ravel()

        # Vertical neighbours in column-major order (column by column). Excluding
        # the last row from `top` prevents pairs from wrapping onto the next
        # column. This is correct for non-square images as well.
        top = flat_index[:-1, :].ravel(order='F')
        bottom = flat_index[1:, :].ravel(order='F')

        D_c = self._pair_difference_matrix(left, right, roi_idx_map, n_roi)
        D_f = self._pair_difference_matrix(top, bottom, roi_idx_map, n_roi)
        n_diffs_c = D_c.shape[0]
        n_diffs_f = D_f.shape[0]

        if self.verbose:
            print(f"Row-major differences: {n_diffs_c}, Column-major differences: {n_diffs_f}")

        # Combined matrix for both difference directions, and its adjoint.
        D = sp.vstack([D_c, D_f]).tocsr()
        D_adj = D.transpose()

        return D, D_adj

    @staticmethod
    def _soft_threshold(v, threshold):
        """Apply complex-safe soft thresholding.

        Shrinks the magnitude of each entry by ``threshold`` (clipping at zero)
        and keeps its phase. For real input this equals
        ``sign(v) * max(|v| - threshold, 0)``. For complex input the phase
        ``v / |v|`` is used explicitly, because ``np.sign`` only returns that
        from NumPy 2.0 onward.
        """
        magnitude = np.abs(v)
        shrunk = np.maximum(magnitude - threshold, 0.0)
        # Zero-magnitude entries are shrunk to zero, so any nonzero divisor works.
        safe_magnitude = np.where(magnitude > 0, magnitude, 1.0)
        return v * (shrunk / safe_magnitude)

    @staticmethod
    def _conjugate_gradient(apply_op, rhs, x0, n_iter, tol=1e-6, eps=1e-10):
        """Approximately solve ``apply_op(x) = rhs`` by conjugate gradient.

        ``apply_op`` is expected to be Hermitian positive (semi-)definite. The
        solve starts from ``x0`` and stops after ``n_iter`` iterations or once
        the residual 2-norm drops below ``tol``. ``eps`` guards the step-size
        divisions against zero denominators.
        """
        x = x0.copy()
        r = rhs - apply_op(x)
        p = r.copy()
        r_norm_sq = np.real(np.vdot(r, r))
        for _ in range(n_iter):
            Ap = apply_op(p)
            alpha = r_norm_sq / (np.real(np.vdot(p, Ap)) + eps)
            x = x + alpha * p
            r = r - alpha * Ap
            if np.linalg.norm(r) < tol:
                break
            r_norm_sq_new = np.real(np.vdot(r, r))
            beta = r_norm_sq_new / (r_norm_sq + eps)
            p = r + beta * p
            r_norm_sq = r_norm_sq_new
        return x

    def reconstruct(self, phantom, sampling_pattern, roi_mask):
        """Reconstruct the ROI of ``phantom`` from masked k-space via ADMM.

        Args:
            phantom: Ground-truth image, expected to have shape ``self.img_shape``.
                Its centred 2-D FFT is multiplied by ``sampling_pattern`` to
                simulate the acquisition.
            sampling_pattern: Array multiplied element-wise with the centred k-space.
            roi_mask: Array of shape ``self.img_shape``. Nonzero entries select
                the pixels that are reconstructed.

        Returns:
            ``(x_full_img, x_history)``. ``x_full_img`` is the final complex image,
            which is zero outside the ROI. ``x_history`` starts with the
            zero-filled reconstruction over the whole image, followed by the
            ROI estimate after each ADMM iteration.

        Raises:
            ValueError: If ``roi_mask`` is None.
        """
        if roi_mask is None:
            raise ValueError("ROI mask required")

        size_y, size_x = self.img_shape
        n_pixels = size_y * size_x

        roi_indices = np.flatnonzero(roi_mask)
        n_roi = roi_indices.size

        if self.verbose:
            print(f"ROI contains {n_roi} pixels")

        D, D_adj = self._compute_flat_diff_matrices(roi_indices, (size_y, size_x))

        # Simulated acquisition: centred k-space of the phantom, masked.
        kspace_sampled = fftshift(fft2(phantom)) * sampling_pattern

        def embed(x_roi_vec):
            # Place ROI values into an otherwise-zero complex image.
            x_full_vec = np.zeros(n_pixels, dtype=complex)
            x_full_vec[roi_indices] = x_roi_vec
            return x_full_vec.reshape(size_y, size_x)

        def forward_op(x_roi_vec):
            return (fftshift(fft2(embed(x_roi_vec))) * sampling_pattern).ravel()

        def adjoint_op(kspace_vec):
            # scipy's ifft2 carries a 1/N factor (N = number of pixels), so
            # adjoint_op(forward_op(v)) is A^H A v / N for a 0/1 sampling
            # pattern. The right-hand side (A_H_y) is scaled identically, which
            # amounts to weighting the data term by 1/N.
            x_full_img = ifft2(ifftshift(kspace_vec.reshape(size_y, size_x)))
            return x_full_img.ravel()[roi_indices]

        def normal_op(v):
            # (A^H A + rho D^H D) v
            return adjoint_op(forward_op(v)) + self.rho * (D_adj @ (D @ v))

        # Initialise with the zero-filled reconstruction.
        x_full = ifft2(ifftshift(kspace_sampled))
        x_roi = x_full.ravel()[roi_indices]
        x_history = [x_full.copy()]

        # Loop-invariant part of the x-update right-hand side.
        A_H_y = adjoint_op(kspace_sampled.ravel())

        # ADMM auxiliary (differences) and scaled dual variables.
        z = D @ x_roi
        u = np.zeros_like(z)
        threshold = self.lambda_val / self.rho
        phantom_roi = np.ravel(phantom)[roi_indices] if self.verbose else None

        for i in range(self.n_iterations):
            # x-update: solve (A^H A + rho D^H D) x = A^H y + rho D^H (z - u),
            # warm-started from the current estimate.
            rhs = A_H_y + self.rho * (D_adj @ (z - u))
            x_roi = self._conjugate_gradient(normal_op, rhs, x_roi, self.cg_iterations)

            # z-update: proximal step of lambda * ||.||_1 on D x + u.
            Dx = D @ x_roi
            z = self._soft_threshold(Dx + u, threshold)

            # u-update: scaled dual ascent.
            u = u + Dx - z

            x_history.append(embed(x_roi))

            if self.verbose and (i % 5 == 0 or i == self.n_iterations - 1):
                roi_mse = np.mean(np.abs(x_roi - phantom_roi)**2)
                print(f"ROI-Focused Finite-Diff ADMM Iter {i+1}/{self.n_iterations}, ROI MSE: {roi_mse:.6f}")
                print(f"  L1 norm of differences: {np.sum(np.abs(Dx)):.4f}")

        return embed(x_roi), x_history