# File: hw10_solution.py
"""
HW10 Solution Script
- 입력: SNAP 등에서 내려받은 엣지 리스트 파일(예: email-Eu-core.txt(.gz), facebook_combined.txt(.gz) 등)
- 그래프: 무방향, 다중 엣지 가중치 합산
- 분석:
  1) Degree / PageRank 상위 10개 기업(노드) 비교 테이블 생성 (CSV)
  2) Louvain 커뮤니티 탐지 및 시각화 (PNG)
  3) 핵심 노드 5곳(기본: Degree 기준)의 이웃 subgraph PDF 저장
  4) 선택: pyvis 인터랙티브 HTML 생성 (--pyvis 옵션)
- 사용 예:
  python hw10_solution.py --edge-file ./email-Eu-core.txt.gz --delimiter space --weighted False --topk 10 --ego 5 --outdir ./outputs --pyvis
"""

import argparse
import gzip
import os
import sys
from collections import defaultdict
from typing import Iterable, Tuple, Dict

import matplotlib
matplotlib.use("Agg")  # headless
import matplotlib.pyplot as plt
import networkx as nx
import pandas as pd

# Louvain (python-louvain)
try:
    import community as community_louvain  # package name: community-louvain
except Exception:
    print("[Error] 'community-louvain' 패키지가 필요합니다. 설치: pip install python-louvain", file=sys.stderr)
    raise


def read_edges(
    edge_file: str,
    delimiter: str = "space",
    weighted: bool = False
) -> Iterable[Tuple[str, str, float]]:
    """
    엣지 파일을 읽어 (u, v, w) 튜플 스트림으로 반환.
    - delimiter: 'space' | 'tab' | 'csv' (쉼표)
    - weighted=True면 세 번째 컬럼을 weight로 사용. 없으면 1.0
    """
    open_fn = gzip.open if edge_file.endswith(".gz") else open
    sep = {"space": None, "tab": "\t", "csv": ","}[delimiter]  # None=whitespace split

    with open_fn(edge_file, "rt", encoding="utf-8", errors="ignore") as f:
        for line in f:
            if not line.strip():
                continue
            if line.startswith("#") or line.startswith("%"):
                # 주석 라인 무시
                continue
            parts = line.strip().split(sep)
            if len(parts) < 2:
                continue
            u, v = parts[0], parts[1]
            if weighted:
                if len(parts) >= 3:
                    try:
                        w = float(parts[2])
                    except ValueError:
                        w = 1.0
                else:
                    w = 1.0
            else:
                w = 1.0
            if u == v:
                # self-loop 제거(선택)
                continue
            yield (u, v, w)


def build_undirected_weighted_graph(edges: Iterable[Tuple[str, str, float]]) -> nx.Graph:
    """
    무방향 그래프 생성 + 동일 엣지에 대한 가중치 합산
    """
    agg = defaultdict(float)
    for u, v, w in edges:
        a, b = (u, v) if u <= v else (v, u)
        agg[(a, b)] += w

    G = nx.Graph()
    for (u, v), w in agg.items():
        G.add_edge(u, v, weight=w)
    return G


def compute_centralities(G: nx.Graph) -> Dict[str, Dict[str, float]]:
    """
    중심성 계산:
      - Degree centrality (무게 미반영: 정의상 degree_centrality는 단순 연결수 비율)
      - PageRank (가중치 반영: edge weight 사용)
    """
    deg = nx.degree_centrality(G)
    # PageRank에서 weight 반영 (edge attribute 'weight')
    pr = nx.pagerank(G, alpha=0.85, weight="weight")
    return {"degree": deg, "pagerank": pr}


def topk_report(cent: Dict[str, Dict[str, float]], k: int = 10) -> pd.DataFrame:
    """
    degree, pagerank 상위 k 노드를 DataFrame으로 결합
    """
    d_series = pd.Series(cent["degree"], name="Degree")
    p_series = pd.Series(cent["pagerank"], name="PageRank")

    top_d = d_series.sort_values(ascending=False).head(k)
    top_p = p_series.sort_values(ascending=False).head(k)

    df = pd.DataFrame({"Degree": top_d}).merge(
        pd.DataFrame({"PageRank": top_p}),
        left_index=True, right_index=True, how="outer"
    ).fillna(0.0)

    # 보기 좋게 정렬: Degree 우선
    df = df.sort_values(by=["Degree", "PageRank"], ascending=False)
    return df


def draw_louvain_communities(G: nx.Graph, path_png: str, seed: int = 42) -> None:
    """
    Louvain 커뮤니티 탐지 후 색상 구분 시각화 (PNG)
    """
    # 커뮤니티 탐지: returns dict {node: community_id}
    part = community_louvain.best_partition(G, weight="weight", random_state=seed)

    # 레이아웃(고정 난수)
    pos = nx.spring_layout(G, seed=seed, k=None)

    # 커뮤니티별로 색상 다르게
    communities = pd.Series(part).astype("category")
    colors = communities.cat.codes  # 0,1,2,...

    plt.figure(figsize=(10, 8))
    nx.draw_networkx_edges(G, pos, alpha=0.15, width=0.5)
    nodes = nx.draw_networkx_nodes(
        G, pos,
        node_color=colors,
        cmap="tab20",
        node_size=20,
        linewidths=0
    )
    # 간단한 범례: 커뮤니티 별 노드 수 (상위 몇 개만)
    counts = communities.value_counts().sort_values(ascending=False)
    legend_txt = "Communities (size):\n" + "\n".join([f"{cid}: {cnt}" for cid, cnt in counts.head(10).items()])
    plt.gca().text(1.02, 0.5, legend_txt, transform=plt.gca().transAxes, va="center")

    plt.title("Louvain Communities (colored)")
    plt.axis("off")
    plt.tight_layout()
    plt.savefig(path_png, dpi=220)
    plt.close()


def save_ego_pdfs(G: nx.Graph, hubs: Iterable[str], outdir: str, seed: int = 42) -> None:
    """
    핵심 노드(hubs) 각각의 1-ego subgraph를 PDF로 저장
    """
    os.makedirs(outdir, exist_ok=True)
    for node in hubs:
        if node not in G:
            continue
        ego = nx.ego_graph(G, node, radius=1, undirected=True)
        pos = nx.spring_layout(ego, seed=seed)
        plt.figure(figsize=(7, 6))
        nx.draw_networkx_edges(ego, pos, alpha=0.4)
        # 중심 노드 크게
        sizes = [200 if n != node else 600 for n in ego.nodes()]
        colors = ["#9ecae1" if n != node else "#de2d26" for n in ego.nodes()]
        nx.draw_networkx_nodes(ego, pos, node_size=sizes, node_color=colors, edgecolors="black", linewidths=0.5)
        nx.draw_networkx_labels(ego, pos, font_size=8)
        plt.title(f"Ego subgraph (center={node})")
        plt.axis("off")
        plt.tight_layout()
        out_pdf = os.path.join(outdir, f"ego_{str(node)}.pdf")
        plt.savefig(out_pdf)
        plt.close()


def maybe_make_pyvis(G: nx.Graph, path_html: str) -> None:
    """
    선택: pyvis 인터랙티브 HTML 생성 (betweenness 색, PageRank 크기)
    """
    try:
        from pyvis.network import Network
        import math
        bet = nx.betweenness_centrality(G, weight="weight", normalized=True)
        pr = nx.pagerank(G, alpha=0.85, weight="weight")

        net = Network(height="700px", width="100%", bgcolor="#ffffff", font_color="#222222", cdn_resources="in_line")
        # 레이아웃 좌표를 pyvis에 넘기려면 수동으로 add_node
        pos = nx.spring_layout(G, seed=42)
        pr_vals = list(pr.values())
        lo, hi = min(pr_vals), max(pr_vals)
        def scale(x, a=8, b=36):
            if hi == lo:
                return (a + b) / 2
            z = (x - lo) / (hi - lo)
            return a + z * (b - a)

        for n in G.nodes():
            size = scale(pr[n])
            title = f"<b>{n}</b><br>PageRank={pr[n]:.4f}<br>Betweenness={bet[n]:.4f}"
            net.add_node(n, label=str(n), title=title, size=size)

        for u, v, data in G.edges(data=True):
            w = data.get("weight", 1.0)
            net.add_edge(u, v, value=w)

        # 저장 (윈도우 인코딩 이슈 방지)
        html = net.generate_html(notebook=False)
        with open(path_html, "w", encoding="utf-8") as f:
            f.write(html)
    except Exception as e:
        print(f"[Warn] pyvis HTML 생성 실패: {e}")


def main():
    parser = argparse.ArgumentParser()
    parser.add_argument("--edge-file", type=str, required=True, help="엣지 리스트 파일 경로 (.txt / .csv / .gz)")
    parser.add_argument("--delimiter", type=str, default="space", choices=["space", "tab", "csv"], help="엣지 구분자")
    parser.add_argument("--weighted", type=lambda x: str(x).lower() in {"1","true","yes","y"}, default=False,
                        help="세 번째 컬럼을 가중치로 사용할지 여부")
    parser.add_argument("--topk", type=int, default=10, help="상위 K (리포트)")
    parser.add_argument("--ego", type=int, default=5, help="Ego subgraph 저장 대상 개수 (Degree 기준)")
    parser.add_argument("--outdir", type=str, default="./outputs", help="결과물 저장 폴더")
    parser.add_argument("--pyvis", action="store_true", help="pyvis 인터랙티브 HTML 생성 여부")
    args = parser.parse_args()

    os.makedirs(args.outdir, exist_ok=True)

    # 버전 로깅
    import platform
    import networkx
    print("== Environment ==")
    print(f"Python: {platform.python_version()}")
    print(f"OS: {platform.platform()}")
    print(f"networkx: {networkx.__version__}")
    try:
        import community
        print(f"community-louvain: {community.__version__}")
    except Exception:
        print("community-louvain: N/A")
    try:
        import matplotlib
        print(f"matplotlib: {matplotlib.__version__}")
    except Exception:
        pass

    # 1) 로드 & 그래프 구성
    edges = read_edges(args.edge_file, delimiter=args.delimiter, weighted=args.weighted)
    G = build_undirected_weighted_graph(edges)
    print(f"Graph built: |V|={G.number_of_nodes()} |E|={G.number_of_edges()} (undirected, weighted)")

    # 2) 중심성
    cent = compute_centralities(G)
    report_df = topk_report(cent, k=args.topk)
    out_csv = os.path.join(args.outdir, "top10_degree_pagerank.csv")
    report_df.to_csv(out_csv, encoding="utf-8")
    print(f"Saved top-k report: {out_csv}")

    # 3) Louvain 커뮤니티 시각화
    out_comm_png = os.path.join(args.outdir, "communities_louvain.png")
    draw_louvain_communities(G, out_comm_png)
    print(f"Saved community plot: {out_comm_png}")

    # 4) 핵심 노드 5곳(기본: Degree 상위) ego PDF
    deg_series = pd.Series(cent["degree"]).sort_values(ascending=False)
    hubs = list(deg_series.head(args.ego).index)
    save_ego_pdfs(G, hubs, outdir=os.path.join(args.outdir, "ego_pdfs"))
    print(f"Saved ego PDFs for hubs: {hubs}")

    # 5) 선택: pyvis
    if args.pyvis:
        out_html = os.path.join(args.outdir, "interactive_network.html")
        maybe_make_pyvis(G, out_html)
        print(f"Saved pyvis HTML: {out_html}")

    print("Done.")


if __name__ == "__main__":
    main()
