"""Plotly 시각화 모듈."""

import pandas as pd
import plotly.express as px
import plotly.graph_objects as go
from plotly.graph_objs import Figure


def fig_category_sales(filtered_frame: pd.DataFrame) -> Figure:
    """
    카테고리별 매출 합계를 막대 차트로 시각화한다.

    Parameters
    ----------
    filtered_frame : pandas.DataFrame
        필터가 적용된 데이터프레임.

    Returns
    -------
    plotly.graph_objs.Figure
        카테고리별 Sales 합계 막대 차트.
    """
    summary = (
        filtered_frame.groupby("Category", as_index=False)["Sales"]
        .sum()
        .sort_values("Sales", ascending=False)
    )
    if summary.empty:
        return go.Figure().update_layout(title="카테고리별 매출 (데이터 없음)")

    figure = px.bar(summary, x="Category", y="Sales", title="카테고리별 매출 (Sales)")
    figure.update_layout(xaxis={"categoryorder": "total descending"}, margin=dict(l=20, r=20, t=40, b=20))
    return figure


def fig_region_share(filtered_frame: pd.DataFrame) -> Figure:
    """
    지역별 매출 비중을 도넛 차트로 시각화한다.

    Parameters
    ----------
    filtered_frame : pandas.DataFrame
        필터가 적용된 데이터프레임.

    Returns
    -------
    plotly.graph_objs.Figure
        채널별 Sales 비중 도넛 차트.
    """
    summary = (
        filtered_frame.groupby("Region", as_index=False)["Sales"]
        .sum()
        .sort_values("Sales", ascending=False)
    )
    if summary.empty:
        return go.Figure().update_layout(title="지역별 매출 비중 (데이터 없음)")

    figure = px.pie(summary, names="Region", values="Sales", title="지역별 매출 비중", hole=0.45)
    figure.update_traces(textinfo="percent+label")
    figure.update_layout(margin=dict(l=20, r=20, t=40, b=20))
    return figure


def fig_state_topn(filtered_frame: pd.DataFrame, top_n: int = 10) -> Figure:
    """
    주(State)별 상위 매출 Top-N을 막대 차트로 시각화한다.

    Parameters
    ----------
    filtered_frame : pandas.DataFrame
        필터가 적용된 데이터프레임.
    top_n : int, optional
        상위 도시 개수, 기본값 10.

    Returns
    -------
    plotly.graph_objs.Figure
        Top-N 도시별 Sales 막대 차트.
    """
    summary = (
        filtered_frame.groupby("State", as_index=False)["Sales"]
        .sum()
        .sort_values("Sales", ascending=False)
        .head(top_n)
    )
    if summary.empty:
        return go.Figure().update_layout(title=f"상위 {top_n} 주(State)별 매출 (데이터 없음)")

    figure = px.bar(summary, x="State", y="Sales", title=f"상위 {top_n} 주(State)별 매출")
    figure.update_layout(xaxis={"categoryorder": "total descending"}, margin=dict(l=20, r=20, t=40, b=20))
    return figure


def fig_monthly_sales(filtered_frame: pd.DataFrame) -> Figure:
    """월별 매출 추이를 선 그래프로 표현한다."""

    if "OrderYearMonth" not in filtered_frame.columns:
        return go.Figure().update_layout(title="월별 매출 추이 (데이터 없음)")

    summary = (
        filtered_frame.groupby("OrderYearMonth", as_index=False)["Sales"].sum()
    )

    if summary.empty:
        return go.Figure().update_layout(title="월별 매출 추이 (데이터 없음)")

    summary["OrderMonth"] = pd.to_datetime(summary["OrderYearMonth"], format="%Y-%m", errors="coerce")
    summary = summary.sort_values("OrderMonth")

    figure = px.line(
        summary,
        x="OrderMonth",
        y="Sales",
        markers=True,
        title="월별 매출 추이",
    )
    figure.update_xaxes(title="Order Month")
    figure.update_yaxes(title="Sales")
    figure.update_layout(margin=dict(l=20, r=20, t=40, b=20))
    return figure


def fig_product_treemap(filtered_frame: pd.DataFrame, max_items: int = 25) -> Figure:
    """
    카테고리/서브카테고리 매출을 트리맵으로 시각화한다. (상위 max_items로 제한)

    Parameters
    ----------
    filtered_frame : pandas.DataFrame
        필터가 적용된 데이터프레임.
    max_items : int, optional
        트리맵에 표시할 (Sales 기준) 상위 아이템 수, 기본값 25.

    Returns
    -------
    plotly.graph_objs.Figure
        제품 트리맵.
    """
    group_cols = [col for col in ["Category", "Sub-Category"] if col in filtered_frame.columns]
    if not group_cols:
        group_cols = ["Category"]
    summary = (
        filtered_frame.groupby(group_cols, as_index=False)["Sales"]
        .sum()
        .sort_values("Sales", ascending=False)
        .head(max_items)
    )

    # --- 방어적 변환: narwhals-like 객체를 pandas DataFrame으로 변환하고 인덱스 정리 ---
    try:
        if hasattr(summary, "to_pandas"):
            summary = summary.to_pandas()
        else:
            summary = pd.DataFrame(summary)
    except Exception:
        summary = pd.DataFrame(summary)

    # Sales 컬럼을 숫자로 강제하고 NaN은 0으로 채움
    if "Sales" in summary.columns:
        summary["Sales"] = pd.to_numeric(summary["Sales"], errors="coerce").fillna(0)
    else:
        summary["Sales"] = 0

    summary = summary.reset_index(drop=True)
    # -------------------------------------------------------------------------------

    if summary.empty:
        return go.Figure().update_layout(title="카테고리별 매출 트리맵 (데이터 없음)")

    # px.treemap 호출을 안전하게 감싸서 실패시 빈(플레이스홀더) Figure 반환
    try:
        path_columns = [col for col in ["Category", "Sub-Category"] if col in summary.columns]
        if not path_columns:
            path_columns = ["Category"]
        figure = px.treemap(
            summary,
            path=path_columns,
            values="Sales",
            title=f"카테고리별 매출 트리맵 (상위 {max_items}개)"
        )
        figure.update_layout(margin=dict(l=20, r=20, t=40, b=20))
    except Exception:
        import plotly.graph_objects as go
        figure = go.Figure().update_layout(title="데이터 오류 (데이터 없음)")

    return figure
