# inference.py
#
import argparse
import pandas as pd
from models.my_decision_tree import MyDecisionTree


def inference(model_path: str, test_data: str, input_index: int = 0):
    clf = MyDecisionTree()
    clf.load_model(model_path)
    print(f"Test columns: {clf.feature_columns}")
    df = pd.read_csv(test_data)

    inferece_data = df.iloc[[input_index]][clf.feature_columns] # 2차원 DataFrame 유지
    prediction = clf.predict(inferece_data)

    input_row = df.iloc[input_index]
    true_label = input_row.iloc[-1]  # 마지막 열이 레이블이라고 가정

    print(f'---' * 20)
    print(f"Inputdata: Row Indix{input_index}")
    print(input_row)
    print(f"\n>>> Label: {true_label}\tPrediction: {prediction[0]}")


def parser() -> argparse.ArgumentParser:
    parser = argparse.ArgumentParser(description="의사결정트리 모델 추론")
    parser.add_argument("--model_path", "-m", type=str, default="decision_tree_model.pkl",
                        help="모델 파일 경로, default: decision_tree_model.pkl")
    parser.add_argument("--test_data", "-d", type=str, default="dataset/test.csv",
                        help="테스트 데이터셋, default: dataset/test.csv")
    parser.add_argument("--input_index", "-i", type=int, default=0,
                        help="입력 인덱스, default: 0")
    args = parser.parse_args()
    return args

if __name__ == "__main__":
    args = parser()
    inference(**vars(args))  # Unpack args to pass as keyword arguments

