# train.py

import pandas as pd
import pickle
from sklearn.tree import DecisionTreeClassifier

class MyDecisionTree:
    """Thin wrapper around scikit-learn's ``DecisionTreeClassifier``.

    Loads a training table from CSV, fits the tree, and persists the fitted
    model together with the feature column names it was trained on, so that
    :meth:`predict` can feed the model the same columns later.
    """

    def __init__(self, max_depth=None):
        """Create an unfitted classifier.

        Args:
            max_depth: Forwarded unchanged to
                ``DecisionTreeClassifier(max_depth=...)``.
        """
        self.max_depth = max_depth
        self.model = DecisionTreeClassifier(max_depth=self.max_depth)
        # Names of the model's input columns; set by load_data() or load_model().
        self.feature_columns = None
        # Training data populated by load_data(); None until then so train()
        # can report a clear error instead of an AttributeError.
        self.X_train = None
        self.y_train = None

    def load_data(self, dataset_path):
        """Read a CSV training table and store features/target on the instance.

        The first column is skipped (presumably an ID/index column -- confirm
        against the dataset), the last column is the target, and every column
        in between is a feature.

        Args:
            dataset_path: Anything ``pd.read_csv`` accepts (path or buffer).

        Raises:
            ValueError: If the table has fewer than three columns, which would
                leave no feature columns between the skipped first column and
                the target.
        """
        df = pd.read_csv(dataset_path)
        # Validate before mutating state so a bad file leaves the instance untouched.
        if len(df.columns) < 3:
            raise ValueError(
                "Dataset must have at least 3 columns (skipped first column, "
                f"features, target); got {len(df.columns)}: {list(df.columns)}"
            )
        self.feature_columns = df.columns[1:-1].tolist()
        self.X_train = df[self.feature_columns]
        self.y_train = df[df.columns[-1]]

    def train(self):
        """Fit the tree on the data stored by :meth:`load_data`.

        Raises:
            RuntimeError: If no training data has been loaded yet.
        """
        if self.X_train is None or self.y_train is None:
            raise RuntimeError("No training data loaded; call load_data() first.")
        self.model.fit(self.X_train, self.y_train)

    def save_model(self, model_path):
        """Pickle the model and its feature column names to ``model_path``."""
        with open(model_path, 'wb') as f:
            pickle.dump({
                'model': self.model,
                'feature_columns': self.feature_columns,
            }, f)

    def load_model(self, model_path):
        """Restore a model and feature columns written by :meth:`save_model`.

        WARNING: ``pickle.load`` can execute arbitrary code while unpickling.
        Only load files that come from a trusted source.
        """
        with open(model_path, 'rb') as f:
            saved = pickle.load(f)
        self.model = saved['model']
        self.feature_columns = saved['feature_columns']
        # Keep the wrapper's max_depth in sync with the restored estimator.
        self.max_depth = getattr(self.model, 'max_depth', self.max_depth)

    def predict(self, input_data: pd.DataFrame):
        """Predict class labels for ``input_data``.

        When ``input_data`` is a DataFrame and the training feature columns are
        known, only those columns are passed to the model, in training order.
        Extra columns (e.g. the skipped first column or the target) are
        therefore ignored.

        Raises:
            ValueError: If a training feature column is missing from
                ``input_data``.
        """
        if self.feature_columns is not None and isinstance(input_data, pd.DataFrame):
            missing = [c for c in self.feature_columns if c not in input_data.columns]
            if missing:
                raise ValueError(f"Input data is missing feature columns: {missing}")
            input_data = input_data[self.feature_columns]
        return self.model.predict(input_data)
