r"""
Cross sections and reaction rate constants for the electron impact of hydrocarbon.
==================================================================================

This example creates a set of electronic reactions for electron impact on hydrocarbon species.


This example shows how to load cross sections for the electron impact of hydrocarbons
from different databases, plot the cross sections, and compute the reaction rate constants.

The cross section is loaded from the following databases
(all available on the LXCat website, except for the Janev database):

* Hayashi [Hayashi2010]_
* IST Lisbonne [ISTLisbonne1995]_
* Morgan [Morgan1992]_
* Song Bouwman [SongBouwman2021]_
* Janev [Janev2002]_
* TODO: Add references for the other databases.


Note
----
The reaction rate constant is computed from the cross section, assuming that electrons
follow a Maxwellian distribution, using the following formula:

.. math::

    k(T) = v_{th, e} \int_0^{\infty} x e^{-x} \tilde{Q}_{12}\left(k_b T_e x\right) d x

where:

- :math:`k(T)` is the reaction rate constant,
- :math:`v_{th, e}=\sqrt{\frac{8 k_b T_e}{\pi m_e}}` is the thermal electron velocity,
- :math:`\tilde{Q}_{12}` is the (electron-energy-dependent) cross section,
- :math:`T_e` is the electron temperature.


Then, the reaction rate constant is fitted to a modified Arrhenius expression, through a
least-squares fit. The modified Arrhenius expression is given by:

.. math::

    k(T) = A T^b \exp\left(-\frac{E_a}{T}\right)

where:

- :math:`A` is the pre-exponential factor,
- :math:`b` is the temperature exponent,
- :math:`E_a` is the activation energy (here, in Kelvin).


.. tags:: kinetics, cross section, reaction rate constant, hydrocarbon, CxHy
"""  # noqa: D205

# %%
# Import the required libraries.
# ------------------------------

from dataclasses import dataclass
from pathlib import Path

import matplotlib.pyplot as plt
import numpy as np
import yaml
from adjustText import adjust_text

from rizer import __version__
from rizer.kin.fit_arrhenius import (
    ArrheniusRate,
    arrhenius_rate,
    arrhenius_rate_fit_from_cross_section,
    group_and_sort_arrhenius_rate_by_equation,
)
from rizer.kin.load_cross_sections import (
    CrossSectionData,
    filter_cross_sections_to_load,
    group_cross_sections_by_equation,
    load_cross_sections_data,
    load_cross_sections_summary_yaml,
)
from rizer.misc.plt_utils import get_species_in_latex, set_mpl_style
from rizer.misc.utils import get_path_to_data, get_root

# Called once, before any figure is created below.
# NOTE(review): presumably applies the project-wide Matplotlib style — confirm in rizer.misc.plt_utils.
set_mpl_style()

# %%
# Parameters for the script.
# --------------------------

# ----- Plot options ----- #
# Set to True to plot the cross sections, and the reaction rates.
plot_cross_sections = True
plot_reactions_rates = True
# Plot only the following equation. If None, plot all.
plot_only_equations: None | list[str] = None


# ----- Reaction rate options ----- #
# Set to True to only keep the Janev fit, even if it is not the best fit according to the cost function.
only_janev = True
# Set to True to only keep carbon excited states.
only_excited_state = False

# Number of point to use when computing the forward reaction rate.
nb_points_integrand = 100_000

# Set the temperature range for the reaction rate constant fit.
temperature_min = 1000  # [K]
temperature_max = 50_000  # [K]
temperature_step = 50  # [K]


# ----- Write options ----- #
write_mechanism = True

species_to_keep: list[str] | None = None
species_to_remove: list[str] | None = None
equations_to_keep: list[str] | None = None
equations_to_remove: list[str] | None = None
name = "electron_heavy_forward_reactions.yaml"
if only_janev and only_excited_state:
    raise ValueError(
        "Both `only_janev` and `only_excited_state` can not be set to True at the same time."
    )
elif only_janev:
    name = "electron_heavy_forward_reactions_Janev.yaml"
    species_to_remove = ["C(1D)", "C(1S)", "C(5So)"]
    equations_to_remove = [
        "e- + C => e- + C(1D)",
        "e- + C => e- + C(1S)",
        "e- + C => e- + C(5So)",
    ]
    plot_only_equations = ["e- + CH4 => e- + e- + CH4+"]
elif only_excited_state:
    name = "electronic_excited_states/electron_heavy_forward_reactions_carbon_excited_states.yaml"
    species_to_keep = ["C", "C(1D)", "C(1S)", "C(5So)"]
    equations_to_remove = ["e- + C => e- + e- + C+"]
    plot_only_equations = ["e- + C => e- + C(1D)"]

# Location of the generated mechanism file, expressed as path components
# handed to `get_path_to_data`. The last component depends on the preset chosen above.
output_path_components = ("mechanisms", "Goutier2025", "builder", name)
output_file = get_path_to_data(*output_path_components, force_return=True)

# %%
# Load cross sections.
# --------------------

# Read the summary of the available cross sections, then narrow it down with
# the species/equation filters chosen in the parameters section.
cross_sections_summary = load_cross_sections_summary_yaml()
filtered_data = filter_cross_sections_to_load(
    cross_sections_summary,
    species_to_keep=species_to_keep,
    species_to_remove=species_to_remove,
    equations_to_keep=equations_to_keep,
    equations_to_remove=equations_to_remove,
)

# Load the data of every cross section that passed the filters.
print("Loading cross sections...")
cross_sections: list[CrossSectionData] = load_cross_sections_data(filtered_data)
nb_loaded_cross_sections = len(cross_sections)
print(
    f"Finished loading cross sections. Found {nb_loaded_cross_sections} cross sections."
)


# %%
# Load plot options.
# ------------------


@dataclass
class DatabasePlotOptions:
    """Plot appearance shared by every curve coming from one database."""

    # Colour passed to Matplotlib for the curves and text labels of the database.
    color: str
    # Text displayed next to the curves of the database.
    label: str


# Plot options (colour and label), indexed by database name.
plot_options_dict: dict[str, DatabasePlotOptions] = {}
path_to_cross_sections_yaml = get_path_to_data(
    "kin", "cross_section", "plot_options_by_database.yaml"
)
# Keep the file open only for the time needed to parse it.
with open(path_to_cross_sections_yaml, "r", encoding="utf-8") as yaml_file:
    data = yaml.safe_load(yaml_file)

# Validate the content with explicit exceptions rather than `assert`:
# assertions are stripped when Python runs with the -O flag, which would
# let malformed entries through silently.
if not isinstance(data, dict):
    raise TypeError(
        f"Expected a mapping at the top level of {path_to_cross_sections_yaml}, "
        f"got {type(data).__name__}."
    )

for database, color_plot in data.items():
    if not isinstance(database, str):
        raise TypeError(
            f"Database names in {path_to_cross_sections_yaml} must be strings, "
            f"got {database!r}."
        )
    if not isinstance(color_plot, dict):
        raise TypeError(
            f"Plot options of database {database!r} must be a mapping, "
            f"got {type(color_plot).__name__}."
        )
    # Both entries are required and must be strings.
    for key in ("color", "label"):
        if not isinstance(color_plot.get(key), str):
            raise TypeError(
                f"Plot option {key!r} of database {database!r} must be a string, "
                f"got {color_plot.get(key)!r}."
            )

    plot_options_dict[database] = DatabasePlotOptions(
        color=color_plot["color"], label=color_plot["label"]
    )

# %%
# Plot cross sections.
# --------------------

if plot_cross_sections:
    group_cross_sections = group_cross_sections_by_equation(cross_sections)

    # One figure per equation, with one curve per cross section of that equation.
    for equation, cross_section_list in group_cross_sections.items():
        # Honour the optional `plot_only_equations` filter.
        is_selected = plot_only_equations is None or equation in plot_only_equations
        if not is_selected:
            continue

        fig, ax = plt.subplots()

        # Text labels are collected so that `adjust_text` can reposition them together.
        texts = []

        for cross_section in cross_section_list:
            energy_eV: np.ndarray = cross_section.df.energy_eV  # [eV]
            cross_section_cm2: np.ndarray = cross_section.df.cross_section_cm2  # [cm²]

            # Colour and label are defined per database.
            database_options = plot_options_dict[cross_section.database]
            color = database_options.color
            label = database_options.label

            ax.plot(energy_eV, cross_section_cm2, color=color)

            # Anchor the label where the cross section reaches its maximum.
            peak_index = np.argmax(cross_section_cm2)
            peak_label = ax.text(
                x=energy_eV[peak_index],
                y=np.max(cross_section_cm2),
                s=label,
                color=color,
                va="center",
                ha="center",
                bbox={"facecolor": "white", "alpha": 0.8, "edgecolor": "none"},
            )
            texts.append(peak_label)

        # Axis labels.
        ax.set_xlabel("Energy [eV]")
        ax.set_ylabel("Cross section [cm²]")

        # Title: each side of the equation is converted species by species, the
        # two sides are joined by an arrow, and the inner "$" are dropped so
        # that the whole equation forms a single math expression.
        reactant, product = equation.split(" => ")
        latex_sides = [
            " + ".join(get_species_in_latex(species) for species in side.split(" + "))
            for side in (reactant, product)
        ]
        title = "$" + r"\Rightarrow".join(latex_sides).replace("$", "") + "$"
        ax.set_title(f"Cross section of {title}")

        # Log-log axes, starting at 1 eV.
        ax.set_xlim(left=1)
        ax.set_xscale("log")
        ax.set_yscale("log")
        adjust_text(texts, arrowprops={"arrowstyle": "->", "color": "k"})
        plt.show()


# %%
# Compute reaction rate.
# ----------------------


# Electron temperatures ([K]) at which the reaction rate constant is computed.
# Note that `np.arange` excludes the stop value, so `temperature_max` itself
# is not part of the grid.
electron_temperatures = np.arange(
    temperature_min, temperature_max, temperature_step, dtype=float
)

# Fit one Arrhenius rate per loaded cross section, reporting the progress.
nb_cross_sections = len(cross_sections)
arrhenius_rates: list[ArrheniusRate] = []
for rate_number, cross_section in enumerate(cross_sections, start=1):
    print(f"Computing Arrhenius rate {rate_number}/{nb_cross_sections}")
    fitted_rate = arrhenius_rate_fit_from_cross_section(
        cross_section,
        electron_temperatures,
        nb_points_integrand=nb_points_integrand,
    )
    arrhenius_rates.append(fitted_rate)
print("All Arrhenius rates have been computed.")

# %%
# Plot reaction rates.
# --------------------

# Group the fitted rates by equation. The first rate of each group is the one
# written uncommented to the mechanism file, so the Janev fit is only forced
# to the front when `only_janev` is requested. Hard-coding True here would make
# `only_janev = False` ineffective and contradict the generated description,
# which states that the fit with the lowest cost is the uncommented one.
# NOTE(review): assumes the helper orders by fit cost when the flag is False — confirm.
group_arrhenius_rates = group_and_sort_arrhenius_rate_by_equation(
    arrhenius_rates, janev_cross_section_first=only_janev
)
if plot_reactions_rates:
    # One figure per equation, with one fitted curve per Arrhenius rate.
    for equation, arrhenius_rate_list in group_arrhenius_rates.items():
        # Honour the optional `plot_only_equations` filter.
        is_selected = plot_only_equations is None or equation in plot_only_equations
        if not is_selected:
            continue

        fig, ax = plt.subplots()

        # Text labels are collected so that `adjust_text` can reposition them together.
        texts = []

        for rate in arrhenius_rate_list:
            # Evaluate the fitted Arrhenius expression on the temperature grid.
            k_fit = arrhenius_rate(
                rate.A_m3_per_s, rate.b, rate.Ea_K, electron_temperatures
            )

            # Colour and label are defined per source database.
            source_options = plot_options_dict[rate.source]
            color = source_options.color

            # Fitted curve as a line, raw data (one point out of ten) as markers.
            ax.plot(electron_temperatures, k_fit, color=color)
            ax.scatter(
                rate.electron_temperature[::10],
                rate.k_raw[::10],
                color=color,
                marker="+",
                s=200,
            )

            # Anchor the label on the fitted curve at 3000 K.
            label = source_options.label
            label_temperature = 3000
            curve_label = ax.text(
                x=label_temperature,
                y=arrhenius_rate(
                    rate.A_m3_per_s, rate.b, rate.Ea_K, label_temperature
                ),
                s=label,
                color=color,
                va="center",
                ha="center",
                bbox={"facecolor": "white", "alpha": 0.8, "edgecolor": "none"},
            )
            texts.append(curve_label)

        # Axis labels.
        ax.set_xlabel("Electron temperature [K]")
        ax.set_ylabel("Reaction rate [m³/s]")

        # Title: each side of the equation is converted species by species, the
        # two sides are joined by an arrow, and the inner "$" are dropped so
        # that the whole equation forms a single math expression.
        reactant, product = equation.split(" => ")
        latex_sides = [
            " + ".join(get_species_in_latex(species) for species in side.split(" + "))
            for side in (reactant, product)
        ]
        title = "$" + r"\Rightarrow".join(latex_sides).replace("$", "") + "$"
        ax.set_title(f"Reaction rate of {title}")

        # Log-log axes spanning the whole temperature grid.
        ax.set_xlim(
            left=np.min(electron_temperatures), right=np.max(electron_temperatures)
        )
        ax.set_xscale("log")
        ax.set_yscale("log")
        adjust_text(texts, arrowprops={"arrowstyle": "->", "color": "k"})
        plt.show()


# %%
# Write the reactions to a given path.
# ------------------------------------

# Header of the generated mechanism file. The pieces are collected in a list
# and joined once at the end. Everything below is written verbatim into the
# YAML file: the description is a YAML block scalar, hence the two-space
# indentation of every line that belongs to it.
header_parts: list[str] = [
    r"""description: |-
  Reaction rate constants fitted to Arrhenius expressions.

  The reaction rate constant is first computed from the cross section,
  using the equation:

    k(T) = v_{th, e} \int_0^{\infty} x e^{-x} \tilde{Q}_{12}\left(k_b T_e x\right) d x

  where:

  - :math:`k(T)` is the reaction rate constant,
  - :math:`v_{th, e}=\sqrt{\frac{8 k_b T_e}{\pi m_e}}` is the thermal electron velocity,
  - :math:`\tilde{Q}_{12}` is the (electron dependent) cross-section,
  - :math:`T_e` is the electron temperature.

  The reaction rate constant is then fitted to an Arrhenius expression:

    k(T) = A T^b exp(-Ea / T)

  The units for A is cm³/moles/s/K^b, and the units for Ea is K.

  The fit is done using a least square optimization.
  A cost function is computed to evaluate the quality of the fit.
  The cost function is the sum of the square of the residuals, divided by 2.
  The residuals are the difference between the fitted rate constant and the training data.
  The fit with the lowest cost is considered the best fit, and is uncommented.
""",
]
# Since k = A T^b exp(-Ea / T) carries the units of the rate constant, A carries
# those units divided by K^b (the exponent is +b after the slash, not -b).

if only_janev:
    header_parts.append(
        "  However, since `only_janev` is true, Janev cross section are used\n"
    )
    header_parts.append(
        "  even if it is not the best fit according to the cost function.\n\n"
    )
else:
    header_parts.append("\n")

# Record how the file was generated: package version, generating script
# (relative to the parent directory of the path returned by `get_root`), and
# the numerical parameters of the fit.
relative_path = Path(__file__).relative_to(get_root().parent)
header_parts += [
    f"  Rizer version: {__version__}\n",
    f"  Script: {relative_path}\n",
    f"  Number of points for computing k: {nb_points_integrand}\n",
    f"  Fit temperature range: {temperature_min} K to {temperature_max} K, {temperature_step} K step.\n",
    "  >>>>> Do not edit this file manually. <<<<<\n\n",
    "units: {length: cm, time: s, quantity: mol, activation-energy: K}\n\n",
    "reactions:\n",
]

text = "".join(header_parts)


# Append one section per equation: a banner, the first rate of the group
# uncommented, then every other rate of the group commented out.
reaction_chunks: list[str] = []
for equation, arrhenius_rate_list in group_arrhenius_rates.items():
    reaction_chunks.append(
        "\n\n# -----------------"
        + f" Reaction: {equation} "
        + "-----------------\n\n"
    )

    selected_rate = arrhenius_rate_list[0]
    reaction_chunks.append(
        selected_rate.as_cantera_string(
            unit_A="cm3/mol/s", unit_Ea="K", commented=False
        )
    )
    reaction_chunks.extend(
        rate.as_cantera_string(unit_A="cm3/mol/s", unit_Ea="K", commented=True)
        for rate in arrhenius_rate_list[1:]
    )

# Closing banner made of three identical separator lines.
separator_line = "######################################################################"
reaction_chunks.append(
    "\n\n" + separator_line + "\n" + separator_line + "\n" + separator_line + "\n"
)

text += "".join(reaction_chunks)


# %%
# Combine all the text and write it to the YAML file.
# ---------------------------------------------------

# Always echo the generated mechanism, even when it is not written to disk.
print(text)

if write_mechanism:
    # Write the electronic reactions to the file.
    print(f"Writing to {output_file}...")
    # The output name may contain a sub-directory (e.g. the excited-states
    # preset); create it if needed so that `open` does not fail with
    # FileNotFoundError on a fresh checkout.
    Path(output_file).parent.mkdir(parents=True, exist_ok=True)
    with open(output_file, "w", encoding="utf-8") as f:
        f.write(text)
    print(f"File written to {output_file}!")

# %%
