#!/usr/bin/env python3
"""
Image Processor Script
======================

This script processes images in a specified input folder, resizing and cropping them to fit into
standardized dimensions or "buckets" determined by clustering image sizes. The processed images
are saved to an output folder.

Usage:
    python fit_in_buckets.py input_folder output_folder number_of_buckets [options]

Parameters:
    input_folder (str): Path to the folder containing the input images.
    output_folder (str): Path to the folder where processed images will be saved.
    number_of_buckets (int): The number of size categories to cluster the images into.

Optional Arguments:
    --min-images-per-bucket (int): Minimum number of images per bucket before merging (default: 6).
    --bucket-dimension-multiple (int): Multiple to which bucket dimensions are adjusted (default: 64).
    --crop-mode (str): Crop mode, 'center' or 'random' (default: 'center').

Example:
    python fit_in_buckets.py ./images ./processed_images 5 --min-images-per-bucket 6 --bucket-dimension-multiple 64 --crop-mode center

Dependencies:
    - Python 3.x
    - Pillow (PIL)
    - NumPy
    - scikit-learn

Installation:
    pip install Pillow numpy scikit-learn

Notes:
    - Supported image formats are PNG and JPEG (extensions: .png, .jpg, .jpeg).
    - Images are resized to fill the bucket dimensions and then cropped to fit exactly.
    - Bucket dimensions are adjusted to the nearest multiple of the specified value (default is 64 pixels).
    - The script outputs information about the resizing and cropping process for each image.
    - Initial and final bucket counts are displayed for comparison.
    - Clusters with fewer than the specified minimum number of images are merged into the nearest other cluster, so that buckets hold at least that many images wherever enough images exist.
    - **Cropping Options**:
        - **Center Cropping**: Cropping is performed equally from the center, preserving the central part of the image.
        - **Random Cropping**: Cropping position is randomly selected, which can introduce variation if central content is not critical.
    - **Minimizing Image Modification**:
        - The script aims to minimize modifications to the original images by:
            - **Uniform Scaling**: Images are resized uniformly, preserving the original aspect ratio.
            - **Minimal Scaling**: The scaling factor used is the minimum required to fill the bucket dimensions.
            - **Adjustable Bucket Dimensions**: Bucket dimensions are adjusted to the nearest multiple specified by the user, providing flexibility in standardization.
            - **Cropping Options**: Allowing the user to choose cropping from the center or random positions to suit different types of images.

Author:
    [Your Name]

Date:
    [Date]

"""

import argparse
import os
import random
import sys
from collections import defaultdict

import numpy as np
from PIL import Image
from sklearn.cluster import KMeans
from sklearn.preprocessing import StandardScaler

# File extensions (compared lower-case) that are treated as input images.
IMAGE_EXTENSIONS = ('.png', '.jpg', '.jpeg')


def parse_arguments(argv=None):
    """Parse and validate the command-line arguments.

    Args:
        argv: List of argument strings to parse. ``None`` lets argparse read
            ``sys.argv[1:]``.

    Returns:
        The ``argparse.Namespace`` with ``input_folder``, ``output_folder``,
        ``number_of_buckets``, ``min_images_per_bucket``,
        ``bucket_dimension_multiple`` and ``crop_mode``.

    The process exits with status 1 (after printing an error) when the bucket
    count or the dimension multiple is not positive, or when the input folder
    is not an existing directory.
    """
    parser = argparse.ArgumentParser(description='Process images into standardized buckets.')
    parser.add_argument('input_folder', help='Path to the folder containing the input images.')
    parser.add_argument('output_folder', help='Path to the folder where processed images will be saved.')
    parser.add_argument('number_of_buckets', type=int, help='The number of size categories to cluster the images into.')
    parser.add_argument('--min-images-per-bucket', type=int, default=6, help='Minimum number of images per bucket before merging (default: 6).')
    parser.add_argument('--bucket-dimension-multiple', type=int, default=64, help='Multiple to which bucket dimensions are adjusted (default: 64).')
    parser.add_argument('--crop-mode', choices=['center', 'random'], default='center', help="Crop mode: 'center' or 'random' (default: 'center').")
    args = parser.parse_args(argv)

    # argparse already converted the value to int; only the range is left to check.
    if args.number_of_buckets <= 0:
        print("Error: number_of_buckets must be a positive integer.")
        sys.exit(1)
    # The multiple is used as a divisor, so zero or negative values cannot work.
    if args.bucket_dimension_multiple <= 0:
        print("Error: --bucket-dimension-multiple must be a positive integer.")
        sys.exit(1)
    if not os.path.isdir(args.input_folder):
        print(f"Error: input folder '{args.input_folder}' does not exist or is not a directory.")
        sys.exit(1)
    return args


def collect_image_data(image_folder):
    """Read the pixel dimensions of every supported image in a folder.

    Files that cannot be opened, or that report a zero width or height, are
    reported on stdout and skipped.

    Args:
        image_folder: Directory to scan (not recursive).

    Returns:
        A tuple ``(image_names, widths, heights)``: a list of file names and
        two NumPy arrays of the same length holding the sizes in pixels.
    """
    image_names = []
    widths = []
    heights = []
    # os.listdir returns entries in arbitrary order; sorting keeps the input
    # to the seeded clustering identical from run to run.
    for filename in sorted(os.listdir(image_folder)):
        if not filename.lower().endswith(IMAGE_EXTENSIONS):
            continue
        image_path = os.path.join(image_folder, filename)
        try:
            with Image.open(image_path) as img:
                width, height = img.size
        except Exception as e:  # best effort: one bad file must not stop the scan
            print(f"Error processing image {filename}: {e}")
            continue
        if width == 0 or height == 0:
            print(f"Warning: Image {filename} has zero width or height. Skipping.")
            continue
        image_names.append(filename)
        widths.append(width)
        heights.append(height)
    return image_names, np.array(widths), np.array(heights)


def merge_small_clusters(labels, points, centers, min_size):
    """Dissolve clusters that hold fewer than ``min_size`` points.

    The smallest undersized cluster is dissolved first: each of its points is
    moved to the nearest remaining non-empty cluster, measured against the
    given cluster centers. This repeats until every non-empty cluster holds at
    least ``min_size`` points or only one cluster is left, so points are never
    sent to an empty cluster and the loop always terminates.

    Args:
        labels: 1-D sequence of non-negative integer cluster labels.
        points: 2-D array-like, one row of features per labelled point.
        centers: 2-D array-like, row ``k`` is the center of cluster ``k``.
        min_size: Minimum number of points a surviving cluster should have.

    Returns:
        A new NumPy array of labels; the input is not modified. Labels of
        dissolved clusters no longer occur in the result.
    """
    labels = np.array(labels, copy=True)
    points = np.asarray(points, dtype=float)
    centers = np.asarray(centers, dtype=float)

    while True:
        counts = np.bincount(labels, minlength=len(centers))
        active = np.flatnonzero(counts > 0)
        if active.size <= 1:
            break  # nothing left to merge into
        undersized = active[counts[active] < min_size]
        if undersized.size == 0:
            break
        victim = undersized[np.argmin(counts[undersized])]
        targets = active[active != victim]
        members = np.flatnonzero(labels == victim)
        # Distance of every member to every candidate target center.
        distances = np.linalg.norm(
            points[members][:, None, :] - centers[targets][None, :, :], axis=2
        )
        labels[members] = targets[np.argmin(distances, axis=1)]
    return labels


def cluster_images(widths, heights, number_of_buckets, min_images_per_bucket):
    """Cluster images by width, height and aspect ratio.

    The three features are standardized, clustered with KMeans
    (``random_state=42``), undersized clusters are merged, and the surviving
    labels are renumbered to ``0..k-1``.

    Args:
        widths: NumPy array of image widths.
        heights: NumPy array of image heights (all non-zero).
        number_of_buckets: Requested number of clusters; reduced to the number
            of images when larger, because KMeans cannot form more clusters
            than there are samples.
        min_images_per_bucket: Clusters smaller than this are merged.

    Returns:
        NumPy array with one consecutive cluster label per image.
    """
    n_images = len(widths)
    if number_of_buckets > n_images:
        print(f"Warning: requested {number_of_buckets} buckets but only {n_images} images were found. Using {n_images} buckets.")
        number_of_buckets = n_images

    features = np.column_stack((widths, heights, widths / heights))
    features_scaled = StandardScaler().fit_transform(features)

    kmeans = KMeans(n_clusters=number_of_buckets, random_state=42)
    labels = kmeans.fit_predict(features_scaled)

    counts = np.bincount(labels, minlength=number_of_buckets)
    n_small = int(np.count_nonzero((counts > 0) & (counts < min_images_per_bucket)))
    if n_small > 0:
        print(f"Found {n_small} clusters with fewer than {min_images_per_bucket} images. Merging them with nearest clusters.")
        labels = merge_small_clusters(labels, features_scaled, kmeans.cluster_centers_, min_images_per_bucket)

    # Renumber so that labels are consecutive and no empty cluster remains.
    _, labels = np.unique(labels, return_inverse=True)
    return labels


def snap_to_multiple(value, multiple):
    """Round ``value`` to the nearest multiple of ``multiple``, never below one multiple.

    Args:
        value: Dimension in pixels (any real number).
        multiple: Positive integer step.

    Returns:
        An int that is a multiple of ``multiple`` and at least ``multiple``,
        so a bucket can never end up with a zero-sized side.
    """
    snapped = int(round(float(value) / multiple)) * multiple
    return max(multiple, snapped)


def determine_buckets(clusters, widths, heights, bucket_dimension_multiple):
    """Derive one bucket (target size) per cluster from the mean image size.

    Args:
        clusters: Cluster label of every image.
        widths: Image widths, aligned with ``clusters``.
        heights: Image heights, aligned with ``clusters``.
        bucket_dimension_multiple: Step the bucket sides are snapped to.

    Returns:
        Dict mapping cluster label to a dict with the keys ``'bucket_width'``,
        ``'bucket_height'`` and ``'image_indices'`` (NumPy array of the
        positions of the images belonging to that cluster).
    """
    clusters = np.asarray(clusters)
    widths = np.asarray(widths)
    heights = np.asarray(heights)

    buckets = {}
    for cluster_idx in np.unique(clusters):
        indices = np.flatnonzero(clusters == cluster_idx)
        buckets[int(cluster_idx)] = {
            'bucket_width': snap_to_multiple(np.mean(widths[indices]), bucket_dimension_multiple),
            'bucket_height': snap_to_multiple(np.mean(heights[indices]), bucket_dimension_multiple),
            'image_indices': indices,
        }
    return buckets


def initial_bucket_counts(buckets):
    """Count the images assigned to each bucket size.

    Counts are accumulated per ``(width, height)`` because two clusters can
    snap to the same dimensions.

    Args:
        buckets: Mapping as returned by ``determine_buckets``.

    Returns:
        Dict mapping ``(bucket_width, bucket_height)`` to an image count.
    """
    counts = defaultdict(int)
    for bucket in buckets.values():
        dims = (bucket['bucket_width'], bucket['bucket_height'])
        counts[dims] += len(bucket['image_indices'])
    return dict(counts)


def plan_resize_and_crop(orig_width, orig_height, bucket_width, bucket_height, crop_mode='center'):
    """Compute how to scale and crop an image so it exactly fills a bucket.

    The image is scaled uniformly by the smallest factor that makes it cover
    the bucket, then the excess is cropped away.

    Args:
        orig_width: Source width in pixels (> 0).
        orig_height: Source height in pixels (> 0).
        bucket_width: Target width in pixels.
        bucket_height: Target height in pixels.
        crop_mode: ``'center'`` crops equally from both sides; any other value
            (the CLI passes ``'random'``) picks a random crop offset.

    Returns:
        Tuple ``(scale, new_size, crop_box, cropped_axis, pixels_cropped)``:
        the scale factor, the ``(width, height)`` to resize to, the
        ``(left, upper, right, lower)`` box whose size equals the bucket, the
        axis (``'x'`` or ``'y'``) with the larger excess, and that excess.
    """
    scale = max(bucket_width / orig_width, bucket_height / orig_height)
    # Floating-point error can leave orig * scale a hair below the bucket size
    # (e.g. 49 * (64 / 49) == 63.99999999999999), so round and clamp instead
    # of truncating; the resized image must never be smaller than the bucket.
    new_width = max(bucket_width, round(orig_width * scale))
    new_height = max(bucket_height, round(orig_height * scale))

    excess_x = new_width - bucket_width
    excess_y = new_height - bucket_height

    if crop_mode == 'center':
        left = excess_x // 2
        top = excess_y // 2
    else:
        left = random.randint(0, excess_x)
        top = random.randint(0, excess_y)
    crop_box = (left, top, left + bucket_width, top + bucket_height)

    if excess_x >= excess_y:
        cropped_axis, pixels_cropped = 'x', excess_x
    else:
        cropped_axis, pixels_cropped = 'y', excess_y
    return scale, (new_width, new_height), crop_box, cropped_axis, pixels_cropped


def process_images(image_folder, output_folder, image_names, buckets, crop_mode):
    """Resize, crop and save every image at the size of its bucket.

    An image that cannot be read, transformed or written is reported on
    stdout and skipped; the remaining images are still processed.

    Args:
        image_folder: Directory containing the source images.
        output_folder: Existing directory the results are written to, using
            the original file names.
        image_names: File names, indexed consistently with ``buckets``.
        buckets: Mapping as returned by ``determine_buckets``.
        crop_mode: ``'center'`` or ``'random'``.

    Returns:
        Dict mapping ``(bucket_width, bucket_height)`` to the number of images
        that were actually written at that size.
    """
    image_to_bucket = {}
    for bucket_idx, bucket in buckets.items():
        for img_idx in bucket['image_indices']:
            image_to_bucket[int(img_idx)] = bucket_idx

    processed_counts = defaultdict(int)
    for img_idx, filename in enumerate(image_names):
        bucket = buckets[image_to_bucket[img_idx]]
        bucket_width = bucket['bucket_width']
        bucket_height = bucket['bucket_height']
        image_path = os.path.join(image_folder, filename)
        output_path = os.path.join(output_folder, filename)

        try:
            with Image.open(image_path) as img:
                orig_width, orig_height = img.size
                scale, new_size, crop_box, cropped_axis, pixels_cropped = plan_resize_and_crop(
                    orig_width, orig_height, bucket_width, bucket_height, crop_mode
                )
                resized_img = img.resize(new_size, Image.LANCZOS)
                # The crop box has exactly the bucket size, so no further
                # (distorting) resize is needed.
                final_img = resized_img.crop(crop_box)
                final_img.save(output_path)
        except Exception as e:  # best effort: keep going with the other images
            print(f"Error processing image {filename}: {e}")
            continue

        resize_percentage = round((scale - 1) * 100, 2)
        print(f"Image {filename} is resized by {resize_percentage}% and cropped from axis {cropped_axis} by {pixels_cropped} pixels to fit in bucket ({bucket_width}, {bucket_height})")
        processed_counts[(bucket_width, bucket_height)] += 1
    return dict(processed_counts)


def main(argv=None):
    """Run the whole pipeline: scan, cluster, derive buckets, process, report.

    Args:
        argv: Optional argument list; ``None`` uses ``sys.argv[1:]``.
    """
    args = parse_arguments(argv)

    # Create the output folder if it doesn't exist
    os.makedirs(args.output_folder, exist_ok=True)

    # Step 1: Collect Image Data
    image_names, widths, heights = collect_image_data(args.input_folder)
    if not image_names:
        print("No valid images found in the input folder.")
        sys.exit(1)

    # Steps 2-3: Clustering and merging of small clusters
    clusters = cluster_images(widths, heights, args.number_of_buckets, args.min_images_per_bucket)

    # Steps 4-5: Bucket determination and initial counting
    buckets = determine_buckets(clusters, widths, heights, args.bucket_dimension_multiple)
    bucket_counts_initial = initial_bucket_counts(buckets)

    print("Initial Bucket Dimensions and Image Counts:")
    for dims, count in bucket_counts_initial.items():
        print(f"Bucket {dims}: {count} images")
    print("\nProcessing images...\n")

    # Step 6: Resizing and Cropping
    bucket_counts_new = process_images(
        args.input_folder, args.output_folder, image_names, buckets, args.crop_mode
    )

    # Step 7: Report what was actually written and compare with the plan
    print("\nNew Bucket Dimensions and Image Counts:")
    for dims, count in bucket_counts_new.items():
        print(f"Bucket {dims}: {count} images")

    print("\nComparison of Initial and New Bucket Counts:")
    for dims, initial_count in bucket_counts_initial.items():
        new_count = bucket_counts_new.get(dims, 0)
        print(f"Bucket {dims}: Initial Count = {initial_count}, New Count = {new_count}")


if __name__ == "__main__":
    main()
