r"""
Compute and write dissociative recombination of hydrocarbons.
=============================================================

This example creates a set of reactions for dissociative electron recombination in a Cantera-like format.

Two kinds of reactions are considered here:


1. e- + C2Hy+ => neutral + neutral, whose reaction rates come from [Janev2004]_.

Janev computes the reaction rate constants with equation 70:

.. math::

    k = \frac{F_2^{DR}(y)}{T^\frac{1}{2}} \cdot f_{\text{corr}}(T) \cdot 10^{-8} \text{cm}^3 \text{s}^{-1}

where:

- :math:`F_2^{DR}(y)` is a structural function given by equation (71),
- :math:`T` is the temperature in eV,
- :math:`f_{\text{corr}}(T)` is a correcting factor given by equation (69).

Since this does not have the form of a typical Arrhenius rate, a Cantera.ExtensibleRate is used to
correctly evaluate the reaction rate, using `type: janev-dissociative-recombination-C2Hy`.
It is defined in `./rizer/kin/extensible_rate.py`.


2. A set of three reactions that appear in neither [Janev2002]_ nor [Janev2004]_.

The reactions of interest here are the dissociative electron recombinations of:

- `e- + C+ => C`,
- `e- + H+ => H`, and
- `e- + H2+ => H + H`.

These reactions are missing from [Janev2002]_ and [Janev2004]_.
They are also not present in LXCat.
Therefore, we use the UMIST database and the KIDA database to get the reaction rate constants.


.. tags:: kinetics, cross section, reaction rate constant, hydrocarbon, CxHy
"""  # noqa: D205

# %%
# Import the required libraries.
# ------------------------------

from dataclasses import dataclass
from pathlib import Path

import matplotlib.pyplot as plt
import numpy as np

import rizer.misc.units as u
from rizer import __version__
from rizer.io.generate_janev_cross_section import JanevCrossSection
from rizer.kin.extensible_rate import to_janev_dissociative_recombination_C2Hy_format
from rizer.kin.fit_arrhenius import arrhenius_rate
from rizer.misc.ct_utils import to_cantera_format
from rizer.misc.plt_utils import get_species_in_latex, set_mpl_style
from rizer.misc.utils import get_path_to_data, get_root

set_mpl_style()

# %%
# Parameters for the script.
# --------------------------

# ----- Plot options ----- #
# True -> draw the reaction rate curves computed below.
plot_reactions_rates: bool = True
# Equations shown in the per-reaction C2Hy+ figures; ``None`` shows every one.
plot_only_equations: list[str] | None = ["e- + C2H2+ => C2H + H"]

# ----- Write options ----- #
# True -> save the generated YAML mechanism text to ``output_file``.
write_mechanism: bool = True
# Destination YAML file; built by ``get_path_to_data`` (``force_return=True``
# presumably returns the path even when the file does not exist yet).
output_file = get_path_to_data(
    "mechanisms",
    "Goutier2025",
    "builder",
    "dissociative_recombination_forward_reactions.yaml",
    force_return=True,
)


# %%
# Plot and prepare mechanism.
# ---------------------------

# Header of the generated YAML mechanism:
# - a `description` block scalar documenting the rate expressions and provenance,
# - the unit system,
# - the `reactions:` key that the rest of the script appends entries to.
text = r"""description: |-
  The reactions of interest here are the dissociative electron recombinations `e- + C2Hy+ => neutral + neutral`.

  The reaction rate constants are computed by Janev, in equation 70, by:

  .. math::

      k = \frac{F_2^{DR}(y)}{T^\frac{1}{2}} \cdot f_{\text{corr}}(T) \cdot 10^{-8} \text{cm}^3 \text{s}^{-1}

  where:

  - :math:`F_2^{DR}(y)` is a structural function given by equation (71),
  - :math:`T` is the temperature in eV,
  - :math:`f_{\text{corr}}(T)` is a correcting factor given by equation (69).

  Since this does not have the form of a typical Arrhenius rate, a Cantera.ExtensibleRate is used to
  correctly evaluate the reaction rate, using `type: janev-dissociative-recombination-C2Hy`.
  It is defined in `./rizer/kin/extensible_rate.py`.

  ---

  Reaction rates for the following equations are also added, see the end of the file:

  - `e- + C+ => C`,
  - `e- + H+ => H`, and
  - `e- + H2+ => H + H`.

"""
# Provenance lines, still inside the `description` block scalar (2-space indent).
text += f"  Rizer version: {__version__}\n"
relative_path = Path(__file__).relative_to(get_root().parent)
text += f"  Script: {relative_path}\n"
text += "  >>>>> Do not edit this file manually. <<<<<\n\n"
text += "units: {length: cm, time: s, quantity: mol, activation-energy: K}\n\n"
text += "reactions:\n"

# Load all the dissociative electron recombination reaction rates for C2Hy+.
janev_cross_section = JanevCrossSection()
ks: dict[str, float] = janev_cross_section.load_reaction_rates("C2Hy")


def _to_cantera_equation(reaction: str) -> str:
    """Convert a Janev reaction string into the equation syntax written here.

    ``->`` becomes ``=>`` and every standalone electron token ``e`` becomes
    ``e-``. Only whole ``e`` tokens are rewritten. A letter "e" inside a species
    name, or an electron already written ``e-``, is left untouched; a blanket
    ``str.replace("e", "e-")`` would corrupt both.
    """
    tokens = ["e-" if token == "e" else token for token in reaction.split(" ")]
    return " ".join(tokens).replace(" -> ", " => ")


def _janev_correction_factor(T_eV):
    """Return the correction factor f_corr(T) of Janev 2004, eq. (69), with T in eV."""
    return 1 / (1 + 0.27 * T_eV**0.55)


def _equation_to_latex(equation: str) -> str:
    """Render an ``"A + B => C + D"`` equation as a single LaTeX math string."""
    reactant, product = equation.split(" => ")
    lhs = " + ".join(get_species_in_latex(s) for s in reactant.split(" + "))
    rhs = " + ".join(get_species_in_latex(s) for s in product.split(" + "))
    # Presumably get_species_in_latex returns "$...$" strings: strip the inner
    # dollars so the whole equation becomes one math expression.
    return "$" + (lhs + r"\Rightarrow" + rhs).replace("$", "") + "$"


def _annotate_last_line(ax, x, y, s: str) -> None:
    """Write ``s`` centred at (x, y) in a rounded box coloured like the last line."""
    color = ax.lines[-1].get_color()
    ax.text(
        x=x,
        y=y,
        s=s,
        color=color,
        horizontalalignment="center",
        verticalalignment="center",
        bbox=dict(
            facecolor="white",
            alpha=0.8,
            edgecolor=color,
            boxstyle="round",
        ),
    )


# Electron temperature grid for the plots; identical for every reaction.
T_K = np.linspace(1000, 50_000, 1000)
T_eV = T_K * u.K_to_eV

for reaction, A_cm3_per_mol_per_s in ks.items():
    reaction_formatted = _to_cantera_equation(reaction)

    text += "\n\n# -----------------"
    text += f" Reaction: {reaction_formatted} "
    text += "-----------------\n\n"

    # NOTE(review): the author's note for this rate was
    # "Janev 2004, eq. 70 without the correction factor of eq. 69", but it was
    # never handed to the writer below. Forward it if
    # `to_janev_dissociative_recombination_C2Hy_format` accepts a note.
    text += to_janev_dissociative_recombination_C2Hy_format(
        equation=reaction_formatted,
        A=A_cm3_per_mol_per_s,
        unit="cm3/mol",
        source="Janev2004",
    )

    if not plot_reactions_rates:
        continue
    if (
        plot_only_equations is not None
        and reaction_formatted not in plot_only_equations
    ):
        continue

    fig, ax = plt.subplots()

    # k ∝ T^-1/2 (Janev eq. 70): arrhenius_rate with n = -0.5 and no activation
    # energy. A is converted from cm³/mol/s to m³/s.
    A_m3_per_s = A_cm3_per_mol_per_s * 1e-6 / u.N_a
    k = arrhenius_rate(A_m3_per_s, -0.5, 0.0, T_K)
    k_with_fcorr = k * _janev_correction_factor(T_eV)

    # Plot and annotate the reaction rate, without (solid) and with (dashed)
    # the eq. 69 correction factor.
    ax.plot(T_K, k)
    _annotate_last_line(
        ax, 10_000, arrhenius_rate(A_m3_per_s, -0.5, 0.0, 10_000), "$k(T_e)$"
    )
    ax.plot(T_K, k_with_fcorr, "--")
    T_label_K = 20_000
    _annotate_last_line(
        ax,
        T_label_K,
        arrhenius_rate(A_m3_per_s, -0.5, 0.0, T_label_K)
        * _janev_correction_factor(T_label_K * u.K_to_eV),
        r"$k_\text{corr}(T_e)$",
    )

    # Plot options.
    ax.set_xlabel("Electron temperature [K]")
    ax.set_ylabel("Reaction rate [m³/s]")
    ax.set_title(f"Reaction rate of {_equation_to_latex(reaction_formatted)}")
    ax.set_xlim(left=np.min(T_K), right=np.max(T_K))
    ax.set_xscale("log")
    ax.set_yscale("log")
    plt.show()


# Separator, in the generated YAML, between the Janev reactions and the
# database-sourced reactions appended below.
text += "\n\n######################################################################"
text += "\n######################################################################"
text += "\n######################################################################\n\n"


# %%
# Add missing dissociative recombination reactions for C+, H+ and H2+.
# --------------------------------------------------------------------

# YAML comment block explaining where the following reactions come from.
text += """
#  The reactions of interest here are the dissociative electron recombinations of:
#
#  - `e- + C+ => C`,
#  - `e- + H+ => H`, and
#  - `e- + H2+ => H + H`.
#
#  These reactions are missing from [Janev2002] and [Janev2004].
#  They are also not present on LXCAT.
#  Therefore, we use the UMIST database and the KIDA database to get the reaction rate constants.
"""


@dataclass
class MissingReactionRate:
    equation: str
    alpha: float
    beta: float
    gamma: float
    temperature_range: tuple[float, float]
    url: str
    source: str
    see_also: str


# Coefficients transcribed from the UMIST and KIDA databases, in their
# convention k = alpha * (T / 300)**beta * exp(-gamma / T) [cm³/s].
# `url`/`source` identify the entry actually used; `see_also` points to the
# corresponding entry in the other database, for cross-checking.
# NOTE(review): `temperature_range` is presumably the validity range [K]
# quoted by each database entry — confirm against the linked pages.
missing_reaction_rates: list[MissingReactionRate] = [
    # e- + C+ => C, from UMIST (KIDA cross-reference).
    MissingReactionRate(
        equation="e- + C+ => C",
        alpha=2.36e-12,
        beta=-0.29,
        # Negative gamma: exp(-gamma / T) = exp(+17.6 / T), slightly above 1.
        gamma=-17.6,
        temperature_range=(10.0, 41000.0),
        url="https://umistdatabase.uk/react/8741",
        source="UMIST Database",
        see_also="https://kida.astrochem-tools.org/reaction/2788/C+_+_e-.html?filter=Both",
    ),
    # e- + H+ => H, from UMIST (KIDA cross-reference); no activation term.
    MissingReactionRate(
        equation="e- + H+ => H",
        alpha=3.50e-12,
        beta=-0.75,
        gamma=0.0,
        temperature_range=(10.0, 20000.0),
        url="https://umistdatabase.uk/react/8745",
        source="UMIST Database",
        see_also="https://kida.astrochem-tools.org/reaction/2791/H+_+_e-.html?filter=Both",
    ),
    # e- + H2+ => H + H, from KIDA (UMIST cross-reference).
    MissingReactionRate(
        equation="e- + H2+ => H + H",
        alpha=1.59e-8,
        beta=-1.18,
        # exp(-7.12 / T) is ~1 once T >> 7 K: effectively no activation barrier.
        gamma=7.12,  # 7.12 K is very close to 0 K
        # NOTE(review): the upper bound (1000 K) sits at the bottom of the
        # electron temperatures plotted in this script; using the rate above it
        # is an extrapolation — confirm it is acceptable.
        temperature_range=(10.0, 1000.0),  # NOTE: very low range?
        url="https://kida.astrochem-tools.org/reaction/1318/H2+_+_e-.html?filter=Both",
        source="KIDA Database",
        see_also="https://umistdatabase.uk/react/1517",
    ),
]


# Temperature grid shared by every curve. It is defined once, before the loop,
# so the axis limits never depend on a loop variable (which would not exist if
# `missing_reaction_rates` were empty).
T_K = np.linspace(300, 50_000, 1000)

# Honour the same plotting switch as the Janev section.
if plot_reactions_rates:
    fig, ax = plt.subplots()

for missing_reaction in missing_reaction_rates:
    # From https://umistdatabase.uk/files/UDfA2024.pdf:
    # k = alpha * (T/300)**beta * exp(-gamma/T) [cm3/s]
    # i.e. the Arrhenius form A * T**n * exp(-Ea/T) with
    # A = alpha * 300**(-beta), n = beta and Ea = gamma [K].
    n = missing_reaction.beta
    Ea = missing_reaction.gamma  # [K]
    A_cm3_per_mol_per_s = (
        missing_reaction.alpha * (1 / 300) ** n * u.N_a
    )  # Convert from cm³/s to cm³/mol/s

    # Same header layout as the Janev reactions (space before the dashes).
    text += "\n\n# -----------------"
    text += f" Reaction: {missing_reaction.equation} + photon "
    text += "-----------------\n\n"
    text += to_cantera_format(
        A_cm3_per_mol_per_s,
        n,
        Ea,
        plasma=True,
        equation=missing_reaction.equation,
        note=missing_reaction.source,
        umist_url=missing_reaction.url,
        temperature_range=missing_reaction.temperature_range,
        source=missing_reaction.source,
        see_also=missing_reaction.see_also,
    )

    if not plot_reactions_rates:
        continue

    # Evaluate the rate in SI units (m³/s) on the plotting grid.
    A_m3_per_s = A_cm3_per_mol_per_s * 1e-6 / u.N_a
    k = arrhenius_rate(A_m3_per_s, n, Ea, T_K)

    # Faded dashed curve over the whole grid; solid curve over the validity
    # range (bounds included).
    ax.plot(T_K, k, alpha=0.5, ls="--")
    color = ax.lines[-1].get_color()
    T_min_K, T_max_K = missing_reaction.temperature_range
    in_range = (T_K >= T_min_K) & (T_K <= T_max_K)
    ax.plot(T_K[in_range], k[in_range], color=color)

    # LaTeX label "reactants => products + h nu".
    reactant, product = missing_reaction.equation.split(" => ")
    label = " + ".join(get_species_in_latex(s) for s in reactant.split(" + "))
    label += r"\Rightarrow"
    label += " + ".join(get_species_in_latex(s) for s in product.split(" + "))
    label += r" + h\nu"
    label = "$" + label.replace("$", "") + "$"
    ax.text(
        x=10_000,
        y=arrhenius_rate(A_m3_per_s, n, Ea, 10_000),
        s=label,
        color=color,
        horizontalalignment="center",
        verticalalignment="center",
        bbox=dict(
            facecolor="white",
            alpha=0.8,
            edgecolor=color,
            boxstyle="round",
        ),
    )

if plot_reactions_rates:
    # Plot options.
    ax.set_xlabel("Electron temperature [K]")
    ax.set_ylabel("Reaction rate [m³/s]")
    ax.set_title("Dissociative recombination reaction rates")
    ax.set_xlim(left=np.min(T_K), right=np.max(T_K))
    ax.set_xscale("log")
    ax.set_yscale("log")
    plt.show()

# %%
# Write the mechanism to the YAML file.
# -------------------------------------

if write_mechanism:
    # Persist the accumulated YAML text (the Janev C2Hy+ reactions followed by
    # the database-sourced C+, H+ and H2+ reactions) to `output_file`.
    print(f"Writing to {output_file}...")
    Path(output_file).write_text(text, encoding="utf-8")
    print(f"File written to {output_file}!")

# %%
