r"""
Plot electron-impact cross sections for atomic hydrogen.
========================================================

This example loads hydrogen cross section data from the Morgan and Janev databases,
and plots them in three groups:

* Excitation from the ground state (1s): Morgan 2s/2p and Janev H → H(n=2|3|4)
* Excitation between excited states: Janev H(n=2) → H(n=3|4) and H(n=3) → H(n=4)
* Ionization to H\ :sup:`+`: Morgan ground state and Janev from H, H(n=2|3|4)

Data files:

* Morgan [Morgan1992]_ — ``data/kin/cross_section/H/Morgan.txt``
* Janev [JanevHydrogen]_ — ``data/kin/cross_section/<species>/janev_cross_sections_<species>.txt``,
  with ``<species>`` one of ``H``, ``H(n=2)``, ``H(n=3)``, ``H(n=4)``

.. tags:: kinetics, cross section, hydrogen
"""  # noqa: D205

# %%
# Import the required libraries.
# ------------------------------

from dataclasses import dataclass

import matplotlib.pyplot as plt
import numpy as np
from adjustText import adjust_text
from matplotlib.axes import Axes

import rizer.misc.units as u
from rizer.io.lxcat import Collision, LXCat
from rizer.misc.plt_utils import get_text, set_mpl_style
from rizer.misc.utils import get_path_to_data
from rizer.plasma.equations import (
    druyvesteyn_distribution_function_in_energy,
    maxwellian_distribution_function_in_energy,
)

# Apply the project's Matplotlib style to every figure created below.
# NOTE(review): ``nb_columns=1`` presumably selects a single-column figure
# layout — confirm in ``rizer.misc.plt_utils``.
set_mpl_style(nb_columns=1)

# %%
# Compute Druyvesteyn and Maxwellian distribution functions in energy.
# --------------------------------------------------------------------

# Electron temperature: 4 eV expressed in kelvin.
T = 4 * u.eV_to_K  # K
# Energy grid: 10,000 evenly spaced points from 0 to 100 eV, converted to joules.
energies = np.linspace(0, 100, 10000) * u.eV_to_J  # J
# Both distribution functions evaluated on that grid at temperature ``T``.
# NOTE(review): in the code visible in this file only ``T`` is reused by the
# figures; ``plot_distribution_function`` rebuilds its own grid and values.
f_M = maxwellian_distribution_function_in_energy(T, energies)  # J^-1
f_D = druyvesteyn_distribution_function_in_energy(T, energies)  # J^-1

# %%
# Helpers to load and plot cross sections.
# ----------------------------------------


@dataclass
class CrossSectionData:
    """Describe one cross-section curve: where to load it from and how to draw it."""

    # Sub-directory of ``kin/cross_section`` that contains the data file.
    subdir: str
    # Name of the LXCat file inside ``subdir``.
    filename: str
    # Key of the species in the loaded LXCat data.
    species: str
    # Key of the collision (reaction) within that species.
    reaction: str
    # Text of the annotation placed at the curve's maximum.
    label: str
    # Matplotlib color used for both the curve and its annotation.
    color: str


def load_collision(
    subdir: str,
    filename: str,
    species: str,
    reaction: str,
) -> Collision:
    """Read an LXCat file under ``kin/cross_section`` and return one collision.

    Parameters
    ----------
    subdir
        Sub-directory of ``kin/cross_section`` holding the file.
    filename
        Name of the LXCat file inside ``subdir``.
    species
        Key used to select the species in the parsed data.
    reaction
        Key used to select the collision of that species.

    Returns
    -------
    Collision
        The entry stored under ``reaction`` for ``species``.
    """
    reader = LXCat(verbose=False)
    data_file = get_path_to_data("kin", "cross_section", subdir, filename)
    reader.read(file=data_file)

    # Two-level lookup: species first, then the reaction of that species.
    species_entry = reader.species[species]
    return species_entry.collisions[reaction]


def plot_curves(
    ax: Axes,
    cross_section_data: list[CrossSectionData],
) -> None:
    """Plot each described cross section on ``ax`` and label it at its maximum.

    For every entry the collision is loaded with ``load_collision``, its cross
    section (cm²) is drawn against energy (eV) in the entry's color, and the
    entry's label is placed at the curve's peak. The labels are then passed to
    ``adjust_text`` so they can be repositioned on ``ax``.

    Parameters
    ----------
    ax
        Axes to draw the curves and labels on.
    cross_section_data
        Descriptions (subdir, file, species, reaction, label, color) of the
        curves to plot. An empty list draws nothing.
    """
    texts = []
    for cs_data in cross_section_data:
        collision = load_collision(
            cs_data.subdir, cs_data.filename, cs_data.species, cs_data.reaction
        )
        energy_eV = collision.energy_eV
        cross_section_cm2 = collision.cross_section_cm2
        ax.plot(energy_eV, cross_section_cm2, color=cs_data.color)

        # Locate the peak once and reuse the index for both coordinates.
        i_peak = np.argmax(cross_section_cm2)
        texts.append(
            get_text(
                float(energy_eV[i_peak]),
                float(cross_section_cm2[i_peak]),
                cs_data.label,
                ax=ax,
                color=cs_data.color,
            )
        )

    if texts:
        # Pass ``ax`` explicitly: without it ``adjust_text`` works on the current
        # pyplot axes, which is not necessarily the axes given to this function.
        adjust_text(texts, ax=ax, avoid_self=False)


def style_cross_section_axes(
    ax: Axes, title: str, xlim: tuple[float, float], ylim: tuple[float, float]
) -> None:
    """Give ``ax`` the shared look of the cross section figures.

    Sets the energy / cross section axis labels and the title, switches both
    axes to logarithmic scale, then applies the requested limits.

    Parameters
    ----------
    ax
        Axes to configure.
    title
        Title of the axes.
    xlim
        ``(left, right)`` limits of the energy axis.
    ylim
        ``(bottom, top)`` limits of the cross section axis.
    """
    x_low, x_high = xlim[0], xlim[1]
    y_low, y_high = ylim[0], ylim[1]

    # Text decorations.
    ax.set_xlabel("Energy [eV]")
    ax.set_ylabel("Cross section [cm²]")
    ax.set_title(title)

    # Log-log representation; scales are set before the limits, x before y.
    for set_scale in (ax.set_xscale, ax.set_yscale):
        set_scale("log")

    ax.set_xlim(left=x_low, right=x_high)
    ax.set_ylim(bottom=y_low, top=y_high)


def plot_distribution_function(
    ax: Axes,
    T: float,
    x_text_in_eV: float,
    plot_druyvesteyn: bool = False,
) -> None:
    """Plot electron energy distribution function(s) on ``ax``.

    The Maxwellian distribution is always drawn (dashed black line); the
    Druyvesteyn distribution is added (dotted black line) when requested.
    Both are evaluated on a fixed grid of 10,000 points from 0 to 100 eV and
    converted from J^-1 to eV^-1 before plotting. Each curve gets a label
    showing the temperature in eV, placed on the curve at ``x_text_in_eV``.
    The y-label is set and the grid of ``ax`` is turned off.

    Parameters
    ----------
    ax
        Axes to draw on.
    T
        Temperature in K.
    x_text_in_eV
        Energy in eV at which the curve labels are anchored; the nearest grid
        point provides the label height.
    plot_druyvesteyn
        Whether to also draw the Druyvesteyn distribution.
    """
    energies = np.linspace(0, 100, 10000) * u.eV_to_J  # J
    energies_eV = energies * u.J_to_eV  # eV

    # Grid index closest to the requested label abscissa.
    i_text = np.argmin(np.abs(energies_eV - x_text_in_eV))

    # round() rather than a bare int(): T typically comes from an eV -> K
    # conversion, and the K -> eV round trip can land just below the integer
    # (e.g. 3.9999999999999996), which truncation would label as "3 eV".
    T_eV = int(round(T * u.K_to_eV))

    ax.set_ylabel(r"$f\left(\varepsilon\right)$ [$\mathrm{eV^{-1}}$]", color="black")
    ax.grid(False)

    # (distribution function, line style, subscript used in the label)
    curves = [(maxwellian_distribution_function_in_energy, "--", "M")]
    if plot_druyvesteyn:
        curves.append((druyvesteyn_distribution_function_in_energy, ":", "D"))

    for distribution, linestyle, tag in curves:
        f_eV = distribution(T, energies) / u.J_to_eV  # eV^-1
        ax.plot(energies_eV, f_eV, color="black", linestyle=linestyle)
        get_text(
            x_text_in_eV,
            f_eV[i_text],
            rf"$T_\mathrm{{e}}={T_eV:,} \ \mathrm{{eV}}$ ($f_\text{{{tag}}}$)",
            ax=ax,
            color="black",
        )


# %%
# Figure 1 — Excitation from H.
# -----------------------------

# Curves of figure 1: Morgan excitation to 2s and 2p, then Janev excitation to
# n=2, 3 and 4. All entries use species ``H`` and files in the ``H`` directory.
FIG1_CURVES = [
    CrossSectionData(
        subdir="H",
        filename="Morgan.txt",
        species="H",
        reaction="H -> H(2s)(10.2eV)",
        label="Morgan, 2s",
        color="r",
    ),
    CrossSectionData(
        subdir="H",
        filename="Morgan.txt",
        species="H",
        reaction="H -> H(2p)(10.2eV)",
        label="Morgan, 2p",
        color="b",
    ),
    CrossSectionData(
        subdir="H",
        filename="janev_cross_sections_H.txt",
        species="H",
        reaction="H -> H(n=2)",
        label="Janev, n=2",
        color="g",
    ),
    CrossSectionData(
        subdir="H",
        filename="janev_cross_sections_H.txt",
        species="H",
        reaction="H -> H(n=3)",
        label="Janev, n=3",
        color="k",
    ),
    CrossSectionData(
        subdir="H",
        filename="janev_cross_sections_H.txt",
        species="H",
        reaction="H -> H(n=4)",
        label="Janev, n=4",
        color="c",
    ),
]

fig1, ax1 = plt.subplots()
# Configure labels, log-log scales and limits, then draw the cross sections.
style_cross_section_axes(
    ax1,
    r"Electronic excitation of $\mathrm{H}$: "
    r"$\mathrm{e^- + H \rightarrow e^- + H(n)}$",
    xlim=(1, 1e3),
    ylim=(1e-19, 1e-16),
)
plot_curves(ax1, FIG1_CURVES)
# Overlay the distribution function on a twin axes with its own y-axis.
ax1_twin = ax1.twinx()
plot_distribution_function(ax=ax1_twin, T=T, x_text_in_eV=2)
ax1.set_zorder(ax1_twin.get_zorder() + 1)  # put ax1 in front of ax1_twin
ax1.patch.set_visible(False)  # hide ax1's background patch so ax1_twin stays visible
plt.show()

# %%
# Figure 2 — Excitation between excited states.
# ---------------------------------------------

# Curves of figure 2: Janev excitation between excited states
# (n=2 -> n=3, n=2 -> n=4 and n=3 -> n=4), each read from the directory of
# its initial state.
FIG2_CURVES = [
    CrossSectionData(
        subdir="H(n=2)",
        filename="janev_cross_sections_H(n=2).txt",
        species="H(n=2)",
        reaction="H(n=2) -> H(n=3)",
        label="Janev, n=2 → n=3",
        color="r",
    ),
    CrossSectionData(
        subdir="H(n=2)",
        filename="janev_cross_sections_H(n=2).txt",
        species="H(n=2)",
        reaction="H(n=2) -> H(n=4)",
        label="Janev, n=2 → n=4",
        color="b",
    ),
    CrossSectionData(
        subdir="H(n=3)",
        filename="janev_cross_sections_H(n=3).txt",
        species="H(n=3)",
        reaction="H(n=3) -> H(n=4)",
        label="Janev, n=3 → n=4",
        color="g",
    ),
]

fig2, ax2 = plt.subplots()
# Configure labels, log-log scales and limits, then draw the cross sections.
style_cross_section_axes(
    ax2,
    r"Electronic excitation of $\mathrm{H}$: "
    r"$\mathrm{e^- + H(n) \rightarrow e^- + H(m)}$",
    xlim=(1e-1, 1e3),
    ylim=(1e-18, 1e-13),
)
plot_curves(ax2, FIG2_CURVES)
# Overlay the distribution function on a twin axes with its own y-axis.
ax2_twin = ax2.twinx()
plot_distribution_function(ax=ax2_twin, T=T, x_text_in_eV=0.4)
ax2.set_zorder(ax2_twin.get_zorder() + 1)  # put ax2 in front of ax2_twin
ax2.patch.set_visible(False)  # hide ax2's background patch so ax2_twin stays visible
plt.show()

# %%
# Figure 3 — Ionization to H\ :sup:`+`.
# --------------------------------------

# Curves of figure 3: ionization from the ground state (Morgan and Janev) and
# from the excited states n=2, 3 and 4 (Janev). The two ground-state entries
# share the reaction key ``H -> H^+`` but are read from different files.
FIG3_CURVES = [
    CrossSectionData(
        subdir="H",
        filename="Morgan.txt",
        species="H",
        reaction="H -> H^+",
        label="Morgan, H(1s)",
        color="r",
    ),
    CrossSectionData(
        subdir="H",
        filename="janev_cross_sections_H.txt",
        species="H",
        reaction="H -> H^+",
        label="Janev, H(1s)",
        color="b",
    ),
    CrossSectionData(
        subdir="H(n=2)",
        filename="janev_cross_sections_H(n=2).txt",
        species="H(n=2)",
        reaction="H(n=2) -> H^+",
        label="Janev, H(n=2)",
        color="g",
    ),
    CrossSectionData(
        subdir="H(n=3)",
        filename="janev_cross_sections_H(n=3).txt",
        species="H(n=3)",
        reaction="H(n=3) -> H^+",
        label="Janev, H(n=3)",
        color="k",
    ),
    CrossSectionData(
        subdir="H(n=4)",
        filename="janev_cross_sections_H(n=4).txt",
        species="H(n=4)",
        reaction="H(n=4) -> H^+",
        label="Janev, H(n=4)",
        color="c",
    ),
]

fig3, ax3 = plt.subplots()
# Configure labels, log-log scales and limits, then draw the cross sections.
style_cross_section_axes(
    ax3,
    r"Ionization of $\mathrm{H}$: $\mathrm{e^- + H(n) \rightarrow e^- + H^+ + e^-}$",
    xlim=(0.1, 1e3),
    ylim=(1e-18, 1e-14),
)
plot_curves(ax3, FIG3_CURVES)
# Overlay the distribution function on a twin axes with its own y-axis.
ax3_twin = ax3.twinx()
plot_distribution_function(ax=ax3_twin, T=T, x_text_in_eV=0.3)
ax3.set_zorder(ax3_twin.get_zorder() + 1)  # put ax3 in front of ax3_twin
ax3.patch.set_visible(False)  # hide ax3's background patch so ax3_twin stays visible
plt.show()

# %%
