import argparse
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt
from sklearn.neighbors import KNeighborsClassifier
from sklearn.metrics import (
    accuracy_score, 
    precision_score, 
    recall_score, 
    f1_score, 
    confusion_matrix, 
    classification_report
)

class KNN:
    """k-nearest-neighbours classification workflow built on scikit-learn.

    Bundles a ``KNeighborsClassifier`` with CSV loading, metric reporting,
    confusion-matrix plotting and an optional search over ``k``.
    """

    def __init__(self, k: int, train: str, test: str, best_k: bool = False):
        """Store the configuration and create the (unfitted) classifier.

        Args:
            k: Number of neighbours used by the classifier.
            train: Path to the training CSV file.
            test: Path to the test CSV file.
            best_k: When true, ``evaluate`` also runs ``find_best_k``.
        """
        self.k = k
        self.train_path = train
        self.test_path = test
        self.model = KNeighborsClassifier(n_neighbors=self.k)
        # Populated by data_loader(); None until then.
        self.X_train = None
        self.y_train = None
        self.X_test = None
        self.y_test = None
        self.best_k = best_k
        self.feature_cols = None
        self.label_col = None

    def data_loader(self, feature_cols: list[str], label_col: str):
        """Read both CSV files and split them into features and labels.

        Args:
            feature_cols: Column names used as model inputs.
            label_col: Column name holding the class label.
        """
        train = pd.read_csv(self.train_path)
        test = pd.read_csv(self.test_path)
        self.X_train = train[feature_cols]
        self.y_train = train[label_col]
        self.X_test = test[feature_cols]
        self.y_test = test[label_col]
        self.feature_cols = feature_cols
        self.label_col = label_col

    def print_dataset(self):
        """Print the shape and the last rows of each loaded split."""
        print("\nTraining Dataset:")
        print("----------" * 5)
        print('X_train shape:', self.X_train.shape)
        print(self.X_train.tail())
        print('y_train shape:', self.y_train.shape)
        print(self.y_train.tail())
        print("\nTest Dataset:")
        print("----------" * 5)
        print('X_test shape:', self.X_test.shape)
        print(self.X_test.tail())
        print('y_test shape:', self.y_test.shape)
        print(self.y_test.tail())

    def train(self):
        """Fit the classifier on the loaded training split."""
        self.model.fit(self.X_train, self.y_train)

    def evaluate(self):
        """Score the fitted model on the test split and report the results.

        Prints accuracy and weighted precision/recall/F1, prints per-class
        metrics for the labels ``C`` and ``SG`` when they occur in the test
        labels, saves a confusion-matrix image, and, if ``self.best_k`` is
        set, runs ``find_best_k`` over k = 1..10.

        Returns:
            The accuracy of the fitted model on the test split.
        """
        y_pred = self.model.predict(self.X_test)

        acc = accuracy_score(self.y_test, y_pred)
        precision = precision_score(self.y_test, y_pred, average='weighted', zero_division=0)
        recall = recall_score(self.y_test, y_pred, average='weighted', zero_division=0)
        f1 = f1_score(self.y_test, y_pred, average='weighted', zero_division=0)

        print(f"\nEvaluation Results (k={self.k})")
        print("-" * 40)
        print(f"Accuracy : {acc:.4f}")
        print(f"Precision: {precision:.4f}")
        print(f"Recall   : {recall:.4f}")
        print(f"F1-score : {f1:.4f}")

        self._print_per_class_metrics(y_pred)
        self._save_confusion_matrix(y_pred)

        # If best_k flag is on, perform k optimization
        if self.best_k:
            print("\nSearching for best k...")
            self.find_best_k(k_range=10)

        return acc

    def _print_per_class_metrics(self, y_pred):
        """Print precision/recall/F1 for the 'C' and 'SG' classes if present."""
        report_dict = classification_report(self.y_test, y_pred, output_dict=True, zero_division=0)
        present_labels = set(self.y_test.unique())

        print("\nPer-Class Metrics (C, SG only):")
        for label in ('C', 'SG'):
            # Only labels that occur in the test split are reported.
            if label not in present_labels:
                continue
            scores = report_dict[label]
            print(f"{label:>5} | Precision: {scores['precision']:.4f}  Recall: {scores['recall']:.4f}  F1-score: {scores['f1-score']:.4f}")

    def _save_confusion_matrix(self, y_pred):
        """Plot the confusion matrix with class-name ticks and save it as PNG."""
        # Fix the label order explicitly so the matrix rows/columns and the
        # tick labels drawn on the heatmap are guaranteed to line up.
        labels = sorted(set(self.y_test) | set(y_pred))
        cm = confusion_matrix(self.y_test, y_pred, labels=labels)

        plt.figure(figsize=(6, 5))
        sns.heatmap(cm, annot=True, fmt='d', cmap='Blues', cbar=False,
                    xticklabels=labels, yticklabels=labels)
        plt.title(f"Confusion Matrix (k={self.k})")
        plt.xlabel("Predicted")
        plt.ylabel("Actual")
        plt.tight_layout()
        file_name = f"confusion_matrix (k={self.k}).png"
        plt.savefig(file_name)
        print(f"\nConfusion matrix image saved as: {file_name}")
        plt.close()

    def find_best_k(self, k_range: int = 10):
        """Try k = 1..k_range, plot test accuracy per k, and return the best k.

        A fresh classifier is fitted for every candidate, so ``self.model``
        is left untouched. Ties keep the smallest k because only a strictly
        higher score replaces the current best.

        NOTE(review): candidates are ranked by accuracy on the test split,
        so the reported best score is likely optimistic.

        Args:
            k_range: Largest k to try (inclusive).

        Returns:
            The k with the highest test accuracy, or ``self.k`` when no
            candidate scores above zero.
        """
        best_k = self.k
        best_score = 0
        k_list = []
        acc_list = []

        for k in range(1, k_range + 1):
            model = KNeighborsClassifier(n_neighbors=k)
            model.fit(self.X_train, self.y_train)
            score = model.score(self.X_test, self.y_test)
            print(f"k={k}, Test Accuracy: {score:.4f}")
            k_list.append(k)
            acc_list.append(score)
            if score > best_score:
                best_k = k
                best_score = score

        # Visualization
        plt.figure(figsize=(8, 5))
        plt.plot(k_list, acc_list, marker='o', linestyle='-', color='blue')
        plt.title(f'Accuracy by Number of Neighbors (k: 1 ~ {k_range})')
        plt.xlabel('k (Number of Neighbors)')
        plt.ylabel('Test Accuracy')
        plt.xticks(k_list)
        plt.grid(True)
        plt.tight_layout()
        # Announce the file name that is actually written.
        file_name = f'Find best K (k=1~{k_range}).png'
        print(f"Saving the accuracy vs k plot -> '{file_name}'")
        plt.savefig(file_name)
        # Release the figure, matching what evaluate() does for its plot.
        plt.close()

        print(f"\nBest k: {best_k} with Test Accuracy: {best_score:.4f}")
        return best_k


def arg_parser(argv=None):
    """Parse command line arguments for k-NN parameters.

    Args:
        argv: Optional list of argument strings to parse. When ``None``
            (the default) argparse reads the process command line.

    Returns:
        An ``argparse.Namespace`` with the attributes ``k``, ``train``,
        ``test`` and ``best_k``.
    """
    parser = argparse.ArgumentParser(description="Run k-NN classification using scikit-learn")
    # --k has no default: it is mandatory, and the help text now says so.
    parser.add_argument("--k", '-k', type=int, required=True,
                        help="Number of neighbors (k), required")
    parser.add_argument("--train", "-tr", type=str, default="./basketball_train.csv",
                        help="Path to training CSV file, default: ./basketball_train.csv")
    parser.add_argument("--test", "-te", type=str, default="./basketball_test.csv",
                        help="Path to test CSV file, default: ./basketball_test.csv")
    parser.add_argument("--best_k", "-b", action="store_true",
                        help="Find the best k value based on test accuracy, default: False")
    return parser.parse_args(argv)


if __name__ == "__main__":
    # Script entry point: parse the CLI options, then load, show, fit and score.
    cli = arg_parser()
    classifier = KNN(
        k=cli.k,
        train=cli.train,
        test=cli.test,
        best_k=cli.best_k,
    )
    classifier.data_loader(label_col="Pos", feature_cols=["3P", "TRB", "BLK"])
    classifier.print_dataset()
    classifier.train()
    classifier.evaluate()