#!/usr/bin/env python3
"""Build the AnimeTimm organization card assets from checked-in model data."""

import csv
import json
from collections import OrderedDict
from pathlib import Path

import matplotlib
matplotlib.use("Agg")
import matplotlib.pyplot as plt
from matplotlib.patches import Rectangle

# Two levels above this file: the parent of the directory that contains the script.
ROOT = Path(__file__).resolve().parents[1]
# Output directory for the rendered PNG assets.
ASSETS = ROOT / "assets"
# Directory holding the model CSV and the JSON views written next to it.
DATA = ROOT / "data"
# Created at import time so the draw_* functions can save into it directly.
ASSETS.mkdir(exist_ok=True)

# Order in which families appear in the "best_by_family" view. Families that
# are not listed here are left out of that view by write_data_views().
FAMILY_ORDER = [
    "ConvNeXt",
    "EVA",
    "CAFormer",
    "SwinV2",
    "ViT",
    "MobileNetV4",
    "MobileNetV3",
    "MobileViT",
    "ResNet",
]


def load_rows(path=None):
    """Load the model table from CSV and return it sorted best-first.

    Args:
        path: Optional CSV location (``str`` or ``Path``). When omitted, the
            checked-in ``dbv4_full_models.csv`` inside ``DATA`` is read.

    Returns:
        A list of row dicts ordered by ``macro_best`` descending, with
        ``rank_by_macro_best`` converted to ``int`` and ``params_m``,
        ``macro_best``, ``macro_040`` and ``micro_040`` converted to ``float``.
        Every other column is kept as the string read from the file.

    Raises:
        KeyError: If one of the converted columns is missing from the header.
        ValueError: If a converted column holds a non-numeric value.
    """
    csv_path = DATA / "dbv4_full_models.csv" if path is None else Path(path)
    # "utf-8-sig" strips a leading byte-order mark if one is present; with
    # plain "utf-8" the BOM would be glued to the first header name and that
    # column would be unreachable under its real key. Files without a BOM
    # decode exactly as before.
    with csv_path.open(newline="", encoding="utf-8-sig") as f:
        rows = list(csv.DictReader(f))
    for row in rows:
        row["rank_by_macro_best"] = int(row["rank_by_macro_best"])
        for key in ("params_m", "macro_best", "macro_040", "micro_040"):
            row[key] = float(row[key])
    # sorted() is stable, so rows with equal scores keep their file order.
    return sorted(rows, key=lambda item: item["macro_best"], reverse=True)


def write_data_views(rows):
    """Write the JSON views derived from ``rows`` and return the featured subsets.

    ``rows`` must already be ordered best-first: the first five entries become
    the top-5 list, and the first row encountered for each family is used as
    that family's representative.

    Two files are written into ``DATA``:

    * ``dbv4_full_models.json`` - every row, in the order given.
    * ``featured_models.json`` - the top-5 list plus one row per family,
      arranged according to ``FAMILY_ORDER``.

    Returns:
        A ``(top5, family_rows)`` tuple holding the same two lists that were
        written to ``featured_models.json``.
    """
    full_dump = DATA / "dbv4_full_models.json"
    with full_dump.open("w", encoding="utf-8") as out:
        json.dump(rows, out, indent=2)

    top5 = rows[:5]

    # Keep only the first row seen per family; later rows never overwrite it.
    first_seen = {}
    for row in rows:
        if row["family"] not in first_seen:
            first_seen[row["family"]] = row
    # Families missing from the data are skipped rather than raising.
    family_rows = [first_seen[name] for name in FAMILY_ORDER if name in first_seen]

    featured = {"top5_macro_f1": top5, "best_by_family": family_rows}
    with (DATA / "featured_models.json").open("w", encoding="utf-8") as out:
        json.dump(featured, out, indent=2)
    return top5, family_rows


def draw_banner():
    """Render the organization banner and save it as ``animetimm-banner.png``.

    Everything is drawn on a 1200 x 360 data-unit canvas with the axes hidden:
    five tinted vertical bands, a faint 60-unit grid, four star markers, six
    rounded label chips, three lines of title text and a short rising line
    with dot markers. The figure is written into ``ASSETS`` and then closed.
    """
    backdrop = "#12131c"
    palette = ["#ff7ab6", "#8f7cff", "#55d6be", "#ffd166", "#7bdff2"]

    fig, ax = plt.subplots(figsize=(12, 3.6), dpi=160)
    fig.patch.set_facecolor(backdrop)
    ax.set_facecolor(backdrop)
    ax.set_xlim(0, 1200)
    ax.set_ylim(0, 360)
    ax.axis("off")

    # Background bands: 240 units wide each, opacity rising left to right.
    for band, tint in enumerate(palette):
        ax.add_patch(
            Rectangle((band * 240, 0), 240, 360, color=tint, alpha=0.10 + 0.02 * band, lw=0)
        )

    # Faint grid: vertical lines first, then horizontal ones.
    grid_style = dict(color="white", alpha=0.045, lw=0.7)
    for gx in range(0, 1201, 60):
        ax.plot([gx, gx], [0, 360], **grid_style)
    for gy in range(0, 361, 60):
        ax.plot([0, 1200], [gy, gy], **grid_style)

    # Star markers given as (x, y, size, color); the marker area is size * size / 4.
    stars = [
        (930, 240, 52, "#ff7ab6"),
        (1010, 165, 38, "#55d6be"),
        (1080, 260, 28, "#ffd166"),
        (845, 120, 32, "#8f7cff"),
    ]
    for sx, sy, size, tint in stars:
        ax.scatter(
            [sx],
            [sy],
            s=size * size / 4,
            marker="*",
            color=tint,
            alpha=0.85,
            edgecolors="white",
            linewidths=0.8,
        )

    # Label chips: the first five reuse the band palette, the last one is white.
    chips = [
        ((720, 78), "tags"),
        ((812, 285), "timm"),
        ((935, 65), "anime"),
        ((1045, 104), "F1"),
        ((705, 210), "ONNX"),
        ((1085, 205), "safetensors"),
    ]
    for ((cx, cy), caption), fill in zip(chips, palette + ["#ffffff"]):
        chip_box = dict(
            boxstyle="round,pad=0.35,rounding_size=0.18",
            fc=fill,
            ec="white",
            alpha=0.35,
            lw=0.8,
        )
        ax.text(
            cx,
            cy,
            caption,
            color="#f7f7fb",
            fontsize=13,
            fontweight="bold",
            ha="center",
            va="center",
            bbox=chip_box,
        )

    # Title, tagline and attribution, all left-aligned.
    anchor = dict(ha="left", va="center")
    ax.text(70, 218, "AnimeTimm", color="white", fontsize=54, fontweight="bold", **anchor)
    ax.text(
        75,
        157,
        "timm-based vision models for anime-style image tagging",
        color="#e6e7f3",
        fontsize=20,
        **anchor,
    )
    ax.text(
        76,
        108,
        "A DeepGHS research-and-hobbyist project",
        color="#ffcfdf",
        fontsize=15,
        **anchor,
    )

    # Decorative rising line; the dots use zorder=3 so they sit above the line.
    trend_x = [720, 800, 890, 980, 1080]
    trend_y = [135, 175, 205, 225, 252]
    ax.plot(trend_x, trend_y, color="#ffffff", lw=2.2, alpha=0.78)
    ax.scatter(trend_x, trend_y, s=75, color="#55d6be", edgecolor="white", linewidth=1.2, zorder=3)

    fig.savefig(ASSETS / "animetimm-banner.png", bbox_inches="tight", pad_inches=0)
    plt.close(fig)


def _axis_window(values, lower, upper, *, log_scale=False):
    """Return ``(lower, upper)``, widened only where a value falls outside it.

    The defaults are returned untouched when every value already fits, so the
    rendering for in-range data does not change. On a log axis the padding is
    multiplicative and non-positive values are ignored, since they cannot be
    placed on such an axis.
    """
    if log_scale:
        values = [value for value in values if value > 0]
    if not values:
        return lower, upper
    smallest, largest = min(values), max(values)
    if log_scale:
        if smallest < lower:
            lower = smallest / 1.5
        if largest > upper:
            upper = largest * 1.5
    else:
        pad = 0.05 * (upper - lower)
        if smallest < lower:
            lower = smallest - pad
        if largest > upper:
            # Extra room on the right for the value label printed past the bar.
            upper = largest + 2 * pad
    return lower, upper


def draw_snapshot(top5, family_rows):
    """Render the model snapshot chart and save it as ``dbv4-full-model-snapshot.png``.

    Each plotted row gets a horizontal bar for its ``macro_best`` score (bottom
    x-axis) and a dot for its ``params_m`` value on a log-scaled top x-axis.

    Args:
        top5: Rows drawn with purple bars.
        family_rows: Per-family rows; those not already in ``top5`` are added
            with blue bars.

    Both arguments are lists of row dicts providing ``model``, ``macro_best``
    and ``params_m``. The figure is written into ``ASSETS`` and then closed.
    """
    plot_rows = top5 + [row for row in family_rows if row not in top5]
    # Ascending sort puts the best model at the top of the horizontal bar chart.
    plot_rows = sorted(plot_rows, key=lambda row: row["macro_best"])
    scores = [row["macro_best"] for row in plot_rows]
    params = [row["params_m"] for row in plot_rows]

    fig, ax1 = plt.subplots(figsize=(11.5, 7.2), dpi=160)
    fig.patch.set_facecolor("#ffffff")
    ax1.set_facecolor("#fbfbff")
    y = list(range(len(plot_rows)))
    labels = [row["model"].replace(".dbv4-full", "") for row in plot_rows]
    bar_colors = ["#8f7cff" if row in top5 else "#55b9d6" for row in plot_rows]
    ax1.barh(y, scores, color=bar_colors, alpha=0.88)
    ax1.set_yticks(y)
    ax1.set_yticklabels(labels, fontsize=9)
    # Default window 0.28-0.63; widened only if a score lies outside it, so a
    # regenerated CSV cannot push a bar or its label off the chart.
    ax1.set_xlim(*_axis_window(scores, 0.28, 0.63))
    ax1.set_xlabel("Macro@Best F1", fontsize=11)
    ax1.grid(axis="x", alpha=0.22)
    for idx, row in enumerate(plot_rows):
        ax1.text(row["macro_best"] + 0.004, idx, f"{row['macro_best']:.3f}", va="center", fontsize=8, color="#222222")

    # Second x-axis sharing the same y positions, used for parameter counts.
    ax2 = ax1.twiny()
    ax2.scatter(params, y, color="#ff7ab6", s=46, edgecolor="white", linewidth=0.8, zorder=5)
    ax2.set_xscale("log")
    ax2.set_xlabel("Parameters (M, log scale)", fontsize=11)
    # Default window 12-900; widened only if a parameter count lies outside it,
    # because markers beyond the axis limits would otherwise be clipped away.
    ax2.set_xlim(*_axis_window(params, 12, 900, log_scale=True))
    for idx, row in enumerate(plot_rows):
        ax2.text(row["params_m"] * 1.04, idx + 0.14, f"{row['params_m']:.1f}M", va="center", fontsize=7, color="#8a1f55")

    ax1.set_title("AnimeTimm dbv4-full Model Snapshot", fontsize=16, fontweight="bold", pad=14)
    fig.text(0.125, 0.028, "Purple bars mark the top-5 Macro@Best F1 models. Blue bars are best-per-family representatives not already in top-5.", fontsize=8.5, color="#555555")
    # Reserve a strip at the bottom of the figure for the legend sentence.
    fig.tight_layout(rect=[0, 0.045, 1, 1])
    fig.savefig(ASSETS / "dbv4-full-model-snapshot.png", bbox_inches="tight")
    plt.close(fig)


def main():
    """Regenerate the JSON data views and both PNG assets from the model CSV."""
    featured, per_family = write_data_views(load_rows())
    draw_banner()
    draw_snapshot(featured, per_family)


# Run the full build only when executed as a script, not when imported.
if __name__ == "__main__":
    main()
