# train.py

import argparse

from matplotlib import pyplot as plt
import pandas as pd
from models.my_decision_tree import MyDecisionTree
from sklearn.metrics import (
    confusion_matrix, accuracy_score,
    precision_score, recall_score, f1_score
)
import seaborn as sns


class Trainer:
    """Train a ``MyDecisionTree``, save it, and optionally evaluate it on a test CSV.

    Args:
        train_data: Path to the training CSV, passed to ``MyDecisionTree.load_data``.
        test_data: Optional path to a test CSV; when falsy, evaluation is skipped.
        model_path: Destination passed to ``MyDecisionTree.save_model``.
        max_depth: Maximum depth forwarded to the ``MyDecisionTree`` constructor.
    """

    def __init__(self, train_data, test_data, model_path, max_depth):
        self.train_data = train_data
        self.test_data = test_data
        self.model_path = model_path
        self.max_depth = max_depth

    def run(self):
        """Train the model, save it to ``model_path``, then evaluate if ``test_data`` is set."""
        # 1. Train and persist the model.
        clf = MyDecisionTree(max_depth=self.max_depth)
        clf.load_data(self.train_data)
        clf.train()
        clf.save_model(self.model_path)
        print(f"✅ 모델이 저장되었습니다: {self.model_path}")

        # 2. Performance evaluation (optional).
        if self.test_data:
            self.evaluate(clf)

    @staticmethod
    def _label_order(y_true, y_pred):
        """Return the sorted union of class labels found in ``y_true`` and ``y_pred``.

        Used so the confusion-matrix rows/columns and the heatmap tick labels are
        derived from one and the same label list.
        """
        return sorted(set(y_true) | set(y_pred))

    def evaluate(self, clf: "MyDecisionTree", save_path="confusion_matrix.png"):
        """Print classification metrics for ``clf`` on ``self.test_data`` and save a heatmap.

        The feature columns are selected with ``clf.feature_columns`` and the label
        is taken from the LAST column of the test CSV.
        NOTE(review): confirm the test CSV always keeps the label in the last column.

        Args:
            clf: A trained model exposing ``feature_columns`` and ``predict``.
            save_path: File path the confusion-matrix PNG is written to
                (default ``"confusion_matrix.png"``, the previous hard-coded value).
        """
        print(f"\n성능 평가: {self.test_data}")
        df_test = pd.read_csv(self.test_data)

        # Split features and ground-truth labels.
        X_test = df_test[clf.feature_columns]
        y_test = df_test.iloc[:, -1]

        y_pred = clf.predict(X_test)

        # Macro averaging weights every class equally, regardless of support.
        print(f"Accuracy:  {accuracy_score(y_test, y_pred):.4f}")
        print(f"Precision: {precision_score(y_test, y_pred, average='macro'):.4f}")
        print(f"Recall:    {recall_score(y_test, y_pred, average='macro'):.4f}")
        print(f"F1-Score:  {f1_score(y_test, y_pred, average='macro'):.4f}")

        # Build the matrix over the union of true AND predicted labels, and use the
        # same list for the tick labels. Taking tick labels from y_test alone would
        # mislabel (or break) the heatmap when the model predicts a class that is
        # absent from y_test.
        labels = self._label_order(y_test, y_pred)
        confusion_matrix_result = confusion_matrix(y_test, y_pred, labels=labels)

        fig, ax = plt.subplots(figsize=(6, 5))
        try:
            sns.heatmap(
                confusion_matrix_result, annot=True, fmt="d", cmap="Blues",
                xticklabels=labels, yticklabels=labels, ax=ax,
            )
            ax.set_title("Confusion Matrix")
            ax.set_xlabel("Predicted Label")
            ax.set_ylabel("True Label")
            fig.savefig(save_path)
            print(f"Confusion_matrix is saved -> {save_path}")
        finally:
            # Release the figure even if plotting or saving fails.
            plt.close(fig)

def parser(argv: "list[str] | None" = None) -> argparse.Namespace:
    """Parse command-line options for training (and optionally evaluating) the tree.

    Args:
        argv: Argument list to parse. ``None`` (the default) makes argparse read
            ``sys.argv[1:]``, which matches the previous behavior; passing a list
            allows programmatic use and testing.

    Returns:
        The parsed namespace with ``train_data``, ``test_data``, ``model_path`` and
        ``max_depth`` attributes (the keyword names ``Trainer.__init__`` accepts).
    """
    # Named ``arg_parser`` so it does not shadow this function's own name.
    arg_parser = argparse.ArgumentParser(description="의사결정트리 모델 학습 및 저장")
    arg_parser.add_argument("--train_data", '-t', type=str, default="dataset/train.csv",
                            help="학습용 CSV 파일 경로, 예: data/train.csv")
    arg_parser.add_argument("--test_data", '-e', type=str, default=None,
                            help="평가할 테스트셋 경로 (선택), 예: dataset/test.csv")
    arg_parser.add_argument("--model_path", type=str,
                            default="decision_tree_model.pkl", help="저장할 모델 파일 경로, default: decision_tree_model.pkl")
    arg_parser.add_argument("--max_depth", type=int, default=3, help="의사결정트리 최대 깊이, default: 3")
    return arg_parser.parse_args(argv)


if __name__ == "__main__":
    # Parse the CLI and hand every option to Trainer as a keyword argument;
    # the option names match Trainer.__init__'s parameter names one-to-one.
    cli_options = vars(parser())
    Trainer(**cli_options).run()

