import time
from pathlib import Path

import matplotlib.pyplot as plt
import numpy as np
import scipy.sparse as sp
from scipy.fft import fft2, ifft2, fftshift, ifftshift

from phantoms import *
from sampling import *
from admm import ROISpecificADMM

# Seed NumPy's global RNG so that anything drawing from it is reproducible.
np.random.seed(seed=42)

# Acceleration factor. Its reciprocal is the undersampling fraction handed to
# cs_sampling, and it is also embedded in the saved figure file names.
accel = 1
undersamp_factor = 1 / accel

# Enlarge subplot titles for every figure this script produces.
plt.rcParams.update({'axes.titlesize': 20})

def _make_ros_mask(phantom_type, size):
    """Build the region-of-support mask for ``phantom_type``.

    Raises:
        ValueError: If ``phantom_type`` is not one of the supported names.
    """
    if phantom_type == 'circle':
        return create_circular_phantom(size, radius=size//4)
    if phantom_type == 'off_center_circle':
        return create_circular_phantom(size, radius=size//4, center=(size//3, size//3))
    if phantom_type == 'rectangle':
        return create_rectangular_phantom(size)
    raise ValueError(f"Unknown phantom type: {phantom_type}")


def _mean_ros_error(reference, reconstruction, ros_mask):
    """Return the mean of ``|reference - |reconstruction||`` over pixels where ``ros_mask > 0``.

    The difference is computed directly on the selected pixels, so an integer
    ``reference`` does not truncate the errors. NaN is returned, without a
    warning, when the mask selects no pixels.
    """
    inside = ros_mask > 0
    if not np.any(inside):
        return float('nan')
    abs_diff = np.abs(reference[inside] - np.abs(reconstruction)[inside])
    return float(np.mean(abs_diff))


def run_cs_simulation(phantom_type='circle', size=32, n_samples=None, tradeoff=0.3, use_admm=True):
    """Run one simulated compressed-sensing reconstruction experiment.

    A region-of-support (ROS) mask is built for ``phantom_type`` and an image
    pattern is added to it with ``add_image_to_phantom``. K-space samples are
    then selected with ``cs_sampling``, and the image is reconstructed with
    ``reconstruct_from_samples`` (least squares). When ``use_admm`` is set, it
    is also reconstructed with ``ROISpecificADMM``. The mean absolute error
    inside the ROS is printed for each reconstruction.

    Args:
        phantom_type: ``'circle'``, ``'off_center_circle'`` or ``'rectangle'``.
        size: Grid size passed to the phantom builders; the ADMM solver uses
            an image shape of ``(size, size)``.
        n_samples: Forwarded unchanged to ``cs_sampling``.
        tradeoff: Forwarded unchanged to ``cs_sampling``. The demos use 0 for
            ROI-focused sampling and 1 for Gaussian sampling.
        use_admm: Whether to also run the ROI-specific ADMM reconstruction.

    Returns:
        tuple: ``(phantom, reconstructed_standard, reconstructed_admm_roi,
        sampling_pattern, ros_mask, timing_data)``.
        ``reconstructed_admm_roi`` is ``None`` when ``use_admm`` is False.
        ``timing_data`` maps ``'sampling'``, ``'standard'`` and, if ADMM ran,
        ``'admm_roi'`` to elapsed seconds.

    Raises:
        ValueError: If ``phantom_type`` is not recognised.
    """
    ros_mask = _make_ros_mask(phantom_type, size)

    # Add an image pattern to the phantom
    phantom = add_image_to_phantom(ros_mask)

    timing_data = {}

    # perf_counter is monotonic and high-resolution, so it is the right clock
    # for measuring intervals (time.time() can jump if the wall clock changes).
    print("Generating k-space sampling pattern...")
    start_time = time.perf_counter()
    selected_indices, sampling_pattern = cs_sampling(
        phantom, ros_mask, n_samples, tradeoff, undersamp_factor=undersamp_factor
    )
    timing_data['sampling'] = time.perf_counter() - start_time
    print(f"Sample selection took {timing_data['sampling']:.2f} seconds")

    print("Performing standard reconstruction with least-squares...")
    start_time = time.perf_counter()
    reconstructed_standard = reconstruct_from_samples(phantom, selected_indices, ros_mask, size)
    timing_data['standard'] = time.perf_counter() - start_time
    print(f"Standard reconstruction took {timing_data['standard']:.2f} seconds")

    mean_error_standard = _mean_ros_error(phantom, reconstructed_standard, ros_mask)
    print(f"Standard reconstruction - Mean error within ROS: {mean_error_standard:.6f}")

    reconstructed_admm_roi = None
    if use_admm:
        print("Performing ROI-focused ADMM reconstruction...")
        roi_admm = ROISpecificADMM(
            img_shape=(size, size),
            n_iterations=20,
            rho=0.5,
            lambda_val=0.0005,  # original tuning note: 0.0005-0.001 seems to be best
            verbose=True
        )

        # Only the reconstruction itself is timed, not solver construction.
        start_time = time.perf_counter()
        reconstructed_admm_roi, _ = roi_admm.reconstruct(phantom, sampling_pattern, ros_mask)
        timing_data['admm_roi'] = time.perf_counter() - start_time
        print(f"ROI-focused ADMM reconstruction took {timing_data['admm_roi']:.2f} seconds")

        mean_error_admm_roi = _mean_ros_error(phantom, reconstructed_admm_roi, ros_mask)
        print(f"ROI-focused ADMM reconstruction - Mean error within ROS: {mean_error_admm_roi:.6f}")

    return phantom, reconstructed_standard, reconstructed_admm_roi, sampling_pattern, ros_mask, timing_data


# ===== Visualization of results =====

def _plot_recon_row(row_axes, phantom, reconstruction, recon_title, error_title):
    """Draw ``|reconstruction|`` and its absolute error map vs. ``phantom`` on a pair of axes."""
    recon_ax, error_ax = row_axes
    recon_mag = np.abs(reconstruction)

    recon_ax.imshow(recon_mag, cmap='gray')
    recon_ax.set_title(recon_title)
    recon_ax.axis('off')

    im = error_ax.imshow(np.abs(phantom - recon_mag), cmap='hot')
    error_ax.set_title(error_title)
    error_ax.axis('off')
    cbar = plt.colorbar(im, ax=error_ax, fraction=0.04, pad=0.04)
    cbar.ax.tick_params(labelsize=14)


def plot_cs_comparison_results(phantom, reconstructed_standard, reconstructed_admm_roi,
                               sampling_pattern, ros_mask, timing_data, title_prefix, suptitle, suffix=""):
    """Plot, save and show the inputs and reconstructions of one CS experiment.

    Two figures are written under ``results/``, which is created if missing.
    Each figure is shown and then closed, so repeated calls do not accumulate
    open figures.

    1. Phantom, region-of-support mask and sampling pattern, titled ``suptitle``.
    2. A 2x2 grid. The top row holds the least-squares reconstruction and its
       error map. The bottom row holds the ROI-specific ADMM reconstruction and
       its error map, or is hidden when ``reconstructed_admm_roi`` is ``None``.

    Args:
        phantom: Ground-truth image.
        reconstructed_standard: Least-squares reconstruction; its magnitude is plotted.
        reconstructed_admm_roi: ADMM reconstruction, or ``None`` if it was not run.
        sampling_pattern: K-space sampling pattern image.
        ros_mask: Region-of-support mask.
        timing_data: Dict with optional ``'standard'`` / ``'admm_roi'`` durations in seconds.
        title_prefix: Prefix for subplot titles; spaces become underscores in file names.
        suptitle: Title of the first figure.
        suffix: Tag appended to both saved file names.
    """
    results_dir = Path("results")
    # savefig does not create missing directories; make sure the target exists.
    results_dir.mkdir(parents=True, exist_ok=True)
    file_prefix = title_prefix.replace(' ', '_')

    # --- Figure 1: phantom, ROS and sampling pattern ---
    fig, axes = plt.subplots(1, 3, figsize=(20, 6))
    axes[0].imshow(phantom, cmap='gray', vmin=0)
    axes[0].set_title(f"{title_prefix} - Masked Phantom")
    axes[1].imshow(ros_mask, cmap='binary')
    axes[1].set_title(f"{title_prefix} - Region of Support (ROS)")
    axes[2].imshow(sampling_pattern, cmap='binary')
    axes[2].set_title(f"{title_prefix} - Sampling Pattern")
    for ax in axes:
        ax.axis('off')

    fig.suptitle(suptitle, fontsize=24)
    fig.tight_layout()
    fig.savefig(results_dir / f"{file_prefix}_cs_comparison_sampling_{accel}x_accel_{suffix}.png")
    plt.show()
    plt.close(fig)

    # --- Figure 2: reconstructions and error maps ---
    fig, axes = plt.subplots(2, 2, figsize=(12, 10))
    _plot_recon_row(
        axes[0], phantom, reconstructed_standard,
        f"Direct LS Reconstruction ({timing_data.get('standard', 0):.2f}s)",
        "LS - Error Map",
    )
    if reconstructed_admm_roi is not None:
        _plot_recon_row(
            axes[1], phantom, reconstructed_admm_roi,
            f"ROI-specific ADMM Reconstruction ({timing_data.get('admm_roi', 0):.2f}s)",
            "ROI-specific ADMM - Error Map",
        )
    else:
        # Hide the unused bottom row instead of leaving blank framed axes.
        for ax in axes[1]:
            ax.axis('off')

    fig.tight_layout()
    fig.savefig(results_dir / f"{file_prefix}_cs_comparison_{accel}x_accel_{suffix}.png")
    plt.show()
    plt.close(fig)


# ===== CS Demo Function =====

def run_cs_demos():
    """Run the CS demo for every phantom type with two sampling strategies.

    Each phantom ('circle', 'off_center_circle', 'rectangle') is simulated
    twice: first with ``tradeoff=0`` (ROI-focused samples only), then with
    ``tradeoff=1`` (all Gaussian sampling). Each run is plotted with
    ``plot_cs_comparison_results``.

    Returns:
        dict: The ``timing_data`` dict of every run. Keys are
        ``"<phantom>_cs"`` for optimized sampling and ``"<phantom>_cs_high"``
        for Gaussian sampling.
    """
    size = 64
    use_admm = True

    # (tradeoff, message printed before the run, figure suptitle, file suffix, timing-key suffix)
    sampling_configs = (
        # Only ROI-focused samples
        (0, None, "Optimized sampling\n", "optimized", "_cs"),
        # All Gaussian sampling
        (1, "\nRunning with random sampling...", "Gaussian sampling\n", "gaussian", "_cs_high"),
    )

    all_timing_data = {}
    for phantom_type in ('circle', 'off_center_circle', 'rectangle'):
        print("\n" + "="*50)
        print(f"Running CS simulation for {phantom_type} phantom")
        print("="*50)

        for tradeoff, announcement, suptitle, suffix, key_suffix in sampling_configs:
            if announcement is not None:
                print(announcement)

            (phantom, reconstructed_standard, reconstructed_admm_roi,
             sampling_pattern, ros_mask, timing_data) = run_cs_simulation(
                phantom_type=phantom_type,
                size=size,
                n_samples=None,
                tradeoff=tradeoff,
                use_admm=use_admm,
            )
            all_timing_data[f"{phantom_type}{key_suffix}"] = timing_data

            plot_cs_comparison_results(
                phantom,
                reconstructed_standard,
                reconstructed_admm_roi,
                sampling_pattern,
                ros_mask,
                timing_data,
                f"{phantom_type.title()} ROI",
                suptitle,
                suffix,
            )

    return all_timing_data


def main():
    """Script entry point: print a banner, then run all compressed-sensing demos."""
    print("\n" + "="*50)
    print("Running compressed sensing simulations")
    print("="*50)
    run_cs_demos()


if __name__ == "__main__":
    # Go through main() so the banner is printed; calling run_cs_demos()
    # directly left main() as dead code.
    main()
