r"""
Generate NASA9 coefficients for the electronic excited states of an atom.
=========================================================================

The partition function of an atom is the product of the partition functions
of its electronic and translational degrees of freedom:

.. math::

    Z_{\text{tot}} = Z_{\text{trans}} Z_{\text{el}}

The translational partition function of an atom is given by:

.. math::

    Z_{\text{trans}} = \left( \frac{2 \pi m k_B T}{h^2} \right)^{3/2} V

The electronic partition function of an atom is a sum over its electronic states,
each with degeneracy :math:`g_i` and energy :math:`E_i`:

.. math::

    Z_{\text{el}} = \sum_i g_i e^{-E_i/(k_B T)}


---

For an excited atomic state, the electronic partition function is simply:

.. math::

    Z_{\text{el}} = g e^{-E/(k_B T)}

Since :math:`\ln Z = \ln g - \frac{E}{k_B T}`,
we have :math:`\frac{\partial \ln Z}{\partial T} = \frac{E}{k_B T^2}`.


Therefore, the (electronic) internal energy of an excited atomic state is given by:

.. math::

    U_{\text{el}} = k_B T^2 \frac{\partial \ln Z_{\text{el}}}{\partial T} = E

The (electronic) entropy of an excited atomic state is given by:

.. math::

    S_{\text{el}} = k_B \left( \ln Z_{\text{el}} + T \frac{\partial \ln Z_{\text{el}}}{\partial T} \right)
                  = k_B \left( \ln g - \frac{E}{k_B T} + \frac{E}{k_B T} \right)
                  = k_B \ln g

The (electronic) heat capacity (at constant volume) of an excited atomic state is given by:

.. math::

    C_{V,\text{el}} = \frac{\partial U_{\text{el}}}{\partial T} = 0


---

Therefore, the total internal energy of an excited atomic state is given by:

.. math::

    U_{\text{tot}} = U_{\text{trans}} + U_{\text{el}}
                   = \frac{3}{2} k_B T + E

For the total enthalpy, we have:

.. math::

    H_{\text{tot}} = U_{\text{tot}} + P V / N
                   = \frac{3}{2} k_B T + E + k_B T
                   = \frac{5}{2} k_B T + E


The total entropy of an excited atomic state is given by:

.. math::

    S_{\text{tot}} = S_{\text{trans}} + S_{\text{el}}
                   = S_{\text{trans}} + k_B \ln g

The total heat capacity (at constant volume) of an excited atomic state is given by:

.. math::

    C_{V,\text{tot}} = C_{V,\text{trans}} + C_{V,\text{el}} = \frac{3}{2} k_B

The total heat capacity (at constant pressure) of an excited atomic state follows from
the ideal-gas relation :math:`C_P - C_V = k_B` (per particle):

.. math::

    C_{P,\text{tot}} = C_{V,\text{tot}} + k_B = \frac{5}{2} k_B

---

Then, we can fit to the NASA9 format, which is given by:

.. math::

    C_p(T)/R = \frac{a_0}{T^2} + \frac{a_1}{T} + a_2 + a_3 T + a_4 T^2 + a_5 T^3 + a_6 T^4 \\
    H(T)/RT  = -\frac{a_0}{T^2} + \frac{a_1 \ln T}{T} + a_2 + \frac{a_3 T}{2} + \frac{a_4 T^2}{3}
               + \frac{a_5 T^3}{4} + \frac{a_6 T^4}{5} + \frac{a_7}{T} \\
    S(T)/R   = -\frac{a_0}{2 T^2} - \frac{a_1}{T} + a_2 \ln T + a_3 T + \frac{a_4 T^2}{2}
               + \frac{a_5 T^3}{3} + \frac{a_6 T^4}{4} + a_8

where :math:`a_0` to :math:`a_8` are the coefficients to be fitted.

We see that the heat capacity is constant, so we set all coefficients from :math:`a_0` to
:math:`a_6` to zero except :math:`a_2`, which we set to :math:`5/2` (i.e. :math:`C_p = \frac{5}{2} R`).

Then, the enthalpy per particle is given by :math:`H(T) = \frac{5}{2} k_B T + E`.
Converting to per mole, we have :math:`H(T) = \frac{5}{2} R T + E N_A`.
Dividing by :math:`R T`, we have :math:`H(T)/R T = \frac{5}{2} + \frac{E N_A}{R T}`.
On the other hand, the NASA9 format for enthalpy is given by :math:`H(T)/R T = a_2 + \frac{a_7}{T}`.
Therefore, we can set :math:`a_7 = E N_A / R`.
To account for reference energy, we have :math:`a_7 = (E + E_{\text{ref}}) N_A / R`, so that
:math:`H(T) = \frac{5}{2} R T + (E + E_{\text{ref}}) N_A`. For an atom in its ground state,
:math:`E = 0`, thus at 0 K, we have :math:`H(0) = E_{\text{ref}} N_A`.


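Finally, the entropy per particle is given by :math:`S(T) = S_{\text{trans}}(T) + k_B \ln g`.
Converting to per mole (with :math:`R = N_A k_B`), we have :math:`S(T) = N_A S_{\text{trans}}(T) + R \ln g`.
Dividing by :math:`R`, we have :math:`S(T)/R = S_{\text{trans}}(T)/k_B + \ln g`.
On the other hand, the NASA9 format for entropy is given by :math:`S(T)/R = a_2 \ln T + a_8`.
Since :math:`S_{\text{trans}}(T)/k_B = \frac{5}{2} \ln T + \text{const}` (Sackur-Tetrode equation),
the two expressions agree at every temperature once they agree at a reference temperature
:math:`T_{\text{ref}}`. Therefore, we can set
:math:`a_8 = S_{\text{trans}}(T_{\text{ref}})/k_B + \ln g - a_2 \ln T_{\text{ref}}`.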
"""  # noqa: D205


# %%
# Import the required libraries.
# ------------------------------

from dataclasses import dataclass
from pathlib import Path

import numpy as np
import yaml

import rizer.misc.units as u
from rizer.misc.ct_utils import (
    CanteraStringInput,
    CanteraStringInputThermo,
    from_ct_dict_to_ct_str,
)
from rizer.misc.utils import get_path_to_data

# %%
# Define the input data file.
# ---------------------------

INPUT_DATA_FILE = get_path_to_data(
    "mechanisms",
    "thermo",
    "electronic_excited_states",
    "atomic_electronic_states.yaml",
)


# %%
# Define the functions to compute the translational entropy and total entropy of an atomic excited state.
# -------------------------------------------------------------------------------------------------------


def translational_entropy(mass: float, T: float, P_ref: float = 1e5) -> float:
    r"""Compute the translational entropy of a species (Sackur-Tetrode equation).

    Parameters
    ----------
    mass : float
        The mass of the species in kg.
    T : float
        The temperature in K. A NumPy array of temperatures also works, since
        only NumPy operations are applied to it.
    P_ref : float, optional
        The reference pressure in Pa. Default is 1e5 Pa.

    Returns
    -------
    float
        The translational entropy in J/K/particle.

    Notes
    -----
    For an ideal gas at pressure :math:`P_\text{ref}`, the Sackur-Tetrode
    equation reads:

    .. math::

        S_{\text{trans}} = k_\text{B} \left[ \ln \left(
            \left( \frac{2 \pi m k_\text{B} T}{h^2} \right)^{3/2}
            \frac{k_\text{B} T}{P_\text{ref}} \right) + \frac{5}{2} \right]

    It is evaluated in the expanded form:

    .. math::

        S_{\text{trans}} = k_\text{B} \left( \frac{5}{2} \ln T - \ln P_\text{ref}
                           + \ln \left[ \left( \frac{2 \pi m}{h^2} \right)^{3/2}
                           k_\text{B}^{5/2} \right]
                           + \frac{5}{2} \right)
    """
    # All temperature-independent constants (mass, h, k_B) are combined inside a
    # single logarithm; the temperature dependence is carried by (5/2) ln T.
    return u.k_b * (
        5 / 2 * np.log(T)
        - np.log(P_ref)
        + np.log((2 * np.pi * mass / u.h**2) ** 1.5 * u.k_b**2.5)
        + 5 / 2
    )  # J/K/particle


def compute_total_entropy_atomic_excited_state(
    mass: float, T_ref: float, g: int, P_ref: float = 1e5
) -> float:
    r"""Return the molar entropy of an atom occupying one electronic state.

    The entropy is the translational entropy of the atom plus the electronic
    contribution of a single level of degeneracy ``g``. The sum is converted
    from per particle to per mole.

    Parameters
    ----------
    mass : float
        Mass of the atom in kg.
    T_ref : float
        Temperature in K at which the entropy is evaluated.
    g : int
        Degeneracy of the electronic state.
    P_ref : float, optional
        Reference pressure in Pa. Default is 1e5 Pa.

    Returns
    -------
    float
        Total entropy of the state in J/mol/K.

    Notes
    -----
    .. math::

        S_{\text{tot}} = N_A \left( S_{\text{trans}} + k_\text{B} \ln g \right)
    """
    # Per-particle entropy (J/K): translational part plus the electronic part of
    # a single degenerate level.
    entropy_per_particle = translational_entropy(mass, T_ref, P_ref) + np.log(g) * u.k_b
    # Scale by Avogadro's number to obtain J/mol/K.
    return entropy_per_particle * u.N_a


# %%
# Define the data classes for the electronic excited states and element configurations.
# --------------------------------------------------------------------------------------


@dataclass
class ElectronicExcitedState:
    """A single electronic state of an atom, treated as its own species.

    Attributes
    ----------
    symbol : str
        Symbol of the element the state belongs to.
    term : str
        Term symbol of the state.
    configuration : str
        Electron configuration of the state.
    g : int
        Degeneracy of the state.
    E_J : float
        Energy of the state in J.
    mass : float
        Mass of the atom (in kg, as used by the entropy functions).
    """

    symbol: str
    term: str
    configuration: str
    g: int
    E_J: float
    mass: float

    def name(self) -> str:
        """Return the species name, formatted as ``<symbol>(<term>)``."""
        return "{}({})".format(self.symbol, self.term)


@dataclass
class ElementConfig:
    """Per-element data loaded from the atomic electronic states YAML file."""

    symbol: str  # element symbol, shared by all of its states
    element_name: str  # human-readable element name, used in output descriptions/notes
    mass: float  # atomic mass, converted from "mass_Da" via u.Da by the loader
    H_0_kJ_per_mol: float  # reference enthalpy in kJ/mol, added to every state's enthalpy
    output_file: str  # file name passed to get_path_to_data when writing the YAML output
    ground_state_term: str  # term symbol used to recognise the ground state
    ground_state_alias: str  # extra species name under which the ground state is also written
    ground_state_alias_note: str  # sentence appended to the notes of the alias entry
    states: list[ElectronicExcitedState]  # all electronic states to fit for this element


# %%
# Define the function to load the element configurations from the input data file.
# -----------------------------------------------------------------------------------


def load_element_configs(
    yaml_path: Path = INPUT_DATA_FILE,
) -> tuple[dict, list[ElementConfig]]:
    """Load atomic electronic state data from a YAML file.

    Parameters
    ----------
    yaml_path : Path or str, optional
        Path to the YAML data file. Defaults to ``INPUT_DATA_FILE``. Plain
        string paths are also accepted and converted to :class:`Path`.

    Returns
    -------
    reference : dict
        The top-level ``reference`` mapping of the file, returned unchanged.
    element_configs : list[ElementConfig]
        One configuration per entry of the top-level ``elements`` list, in file
        order. Masses are converted from ``mass_Da`` with ``u.Da`` and state
        energies from ``E_eV`` with ``u.eV_to_J``.

    Raises
    ------
    KeyError
        If a required key is missing from the YAML data.
    """
    # Path() is a no-op for Path objects and lets callers pass plain strings.
    with Path(yaml_path).open(encoding="utf-8") as f:
        # safe_load only constructs plain YAML types (no arbitrary Python objects).
        data = yaml.safe_load(f)

    reference = data["reference"]
    element_configs = []
    for element in data["elements"]:
        symbol = element["symbol"]
        mass = element["mass_Da"] * u.Da
        ground_state = element["ground_state"]
        # Every state of an element shares the element's symbol and atomic mass.
        states = [
            ElectronicExcitedState(
                symbol=symbol,
                term=state["term"],
                configuration=state["configuration"],
                g=state["g"],
                E_J=state["E_eV"] * u.eV_to_J,
                mass=mass,
            )
            for state in element["states"]
        ]
        element_configs.append(
            ElementConfig(
                symbol=symbol,
                element_name=element["name"],
                mass=mass,
                H_0_kJ_per_mol=element["H_0_kJ_per_mol"],
                output_file=element["output_file"],
                ground_state_term=ground_state["term"],
                ground_state_alias=ground_state["alias"],
                ground_state_alias_note=ground_state["alias_note"],
                states=states,
            )
        )

    return reference, element_configs


# %%
# Define the function to compute the NASA9 coefficients for an atomic excited state.
# -----------------------------------------------------------------------------------


def compute_nasa9_coefficients(
    state: ElectronicExcitedState,
    T_ref: float,
    P_ref: float,
    H_0_kJ_per_mol: float,
    a_2: float,
) -> list[float]:
    """Compute NASA9 coefficients for an atomic excited state.

    Among ``a_0`` to ``a_6``, only ``a_2`` (the constant ``C_p/R``) is
    non-zero. ``a_7`` and ``a_8`` are chosen so that the NASA9 enthalpy and
    entropy reproduce the physical values at ``T_ref``:

    * ``H(T_ref) = 5/2 R T_ref + E N_A + H_0``
    * ``S(T_ref) = N_A S_trans(T_ref) + R ln g`` (at pressure ``P_ref``)

    Parameters
    ----------
    state : ElectronicExcitedState
        The state to fit; its ``E_J``, ``g`` and ``mass`` are used.
    T_ref : float
        Matching temperature in K.
    P_ref : float
        Reference pressure in Pa for the entropy.
    H_0_kJ_per_mol : float
        Reference enthalpy of the element in kJ/mol, added to every state.
    a_2 : float
        Constant ``C_p/R`` coefficient (5/2 for the translational heat capacity
        of an atom).

    Returns
    -------
    list[float]
        The nine coefficients ``[a_0, ..., a_8]`` as built-in Python floats.
    """
    # Molar enthalpy at T_ref: translational part, excitation energy, and the
    # element's reference enthalpy converted from kJ/mol to J/mol.
    H_tot_ref = 5 / 2 * u.R * T_ref + state.E_J * u.N_a + H_0_kJ_per_mol * 1e3  # J/mol
    S_tot_ref = compute_total_entropy_atomic_excited_state(
        mass=state.mass,
        T_ref=T_ref,
        g=state.g,
        P_ref=P_ref,
    )  # J/mol/K

    # With only a_2 non-zero among a_0..a_6, NASA9 reduces to
    #   H/(R T) = a_2 + a_7 / T   and   S/R = a_2 ln T + a_8,
    # so both constants are solved for at T = T_ref.
    a_7 = H_tot_ref / u.R - a_2 * T_ref
    a_8 = S_tot_ref / u.R - a_2 * np.log(T_ref)
    # np.log yields NumPy scalars; convert to built-in floats so that a
    # NumPy-specific repr (e.g. "np.float64(...)" in NumPy >= 2) cannot leak
    # into the serialized YAML output.
    return [0.0, 0.0, float(a_2), 0.0, 0.0, 0.0, 0.0, float(a_7), float(a_8)]


# %%
# Define the functions to build the Cantera YAML string input for an atomic excited state.
# -------------------------------------------------------------------------------------------


def build_cantera_string_input(
    state: ElectronicExcitedState,
    element_name: str,
    temperature_ranges: list[float],
    coefficients: list[float],
) -> CanteraStringInput:
    """Assemble the Cantera species entry describing one atomic electronic state.

    Parameters
    ----------
    state : ElectronicExcitedState
        The state to describe; its name, symbol, energy, term and configuration
        are used.
    element_name : str
        Human-readable element name, quoted in the notes.
    temperature_ranges : list[float]
        Temperature bounds of the NASA9 fit.
    coefficients : list[float]
        The nine NASA9 coefficients, stored as the single data range.

    Returns
    -------
    CanteraStringInput
        The species entry, with a NASA9 thermo block and descriptive notes.
    """
    species_name = state.name()
    energy_eV = state.E_J * u.J_to_eV

    # One temperature range -> a single row of NASA9 coefficients.
    nasa9_thermo = CanteraStringInputThermo(
        model="NASA9",
        temperature_ranges=temperature_ranges,
        data=[coefficients],
    )
    note_text = (
        f"Fit for the excited state {species_name} of {element_name} "
        f"({energy_eV:.2f} eV), with term {state.term} "
        f"and configuration {state.configuration}."
    )
    return CanteraStringInput(
        name=species_name,
        composition="{" + state.symbol + ": 1}",
        thermo=nasa9_thermo,
        notes=note_text,
    )


# %%
# Define the function to write the NASA9 coefficients for all excited states of an element to a YAML file.
# --------------------------------------------------------------------------------------------------------


def write_atomic_excited_states_nasa9_yaml(
    element: ElementConfig,
    T_ref: float,
    P_ref: float,
    temperature_ranges: list[float],
    a_2: float,
) -> Path:
    """Write NASA9 coefficients for all excited states of an element.

    Builds a YAML document with a ``description`` block followed by a
    ``species`` list, containing one entry per state in ``element.states``.
    The ground state (term equal to ``element.ground_state_term`` and zero
    energy) is written a second time under ``element.ground_state_alias``, with
    ``element.ground_state_alias_note`` appended to its notes.

    Parameters
    ----------
    element : ElementConfig
        The element whose states are written.
    T_ref : float
        Temperature in K at which enthalpy and entropy are matched.
    P_ref : float
        Reference pressure in Pa for the entropy.
    temperature_ranges : list[float]
        Temperature bounds written into each species' thermo entry.
    a_2 : float
        Constant ``C_p/R`` coefficient passed to the fit.

    Returns
    -------
    Path
        Path of the written file (``element.output_file`` under the
        ``mechanisms/thermo/electronic_excited_states`` data directory).
    """
    # NOTE(review): force_return=True presumably makes get_path_to_data return the
    # path even if the file does not exist yet — TODO confirm.
    output_file = get_path_to_data(
        "mechanisms",
        "thermo",
        "electronic_excited_states",
        element.output_file,
        force_return=True,
    )
    txt = "description: |-\n"
    txt += f"  NASA9 polynomial fits for atomic excited states of {element.element_name}.\n\n"
    txt += "species:\n"

    for state in element.states:
        coefficients = compute_nasa9_coefficients(
            state,
            T_ref=T_ref,
            P_ref=P_ref,
            H_0_kJ_per_mol=element.H_0_kJ_per_mol,
            a_2=a_2,
        )
        cantera_string_input = build_cantera_string_input(
            state,
            element_name=element.element_name,
            temperature_ranges=temperature_ranges,
            coefficients=coefficients,
        )
        # Each entry is followed by an extra "\n" used as a separator.
        txt += from_ct_dict_to_ct_str(cantera_string_input) + "\n"

        is_ground_state = state.term == element.ground_state_term and state.E_J == 0.0
        if is_ground_state:
            # Re-emit the same entry (same coefficients and notes) under the alias name.
            cantera_string_input.name = element.ground_state_alias
            txt += from_ct_dict_to_ct_str(cantera_string_input)
            # Drop the last TWO characters of the alias entry — presumably the final
            # "." of its notes plus the trailing newline — so the alias note continues
            # the same sentence. Assumes from_ct_dict_to_ct_str output ends with the
            # notes text followed by "\n" — TODO confirm.
            txt = txt[:-2]
            txt += f". {element.ground_state_alias_note}.\n\n"

    # Drop the separator newline added after the last entry.
    txt = txt[:-1]

    with output_file.open("w", encoding="utf-8") as f:
        f.write(txt)

    return output_file


# %%
# Generate the NASA9 coefficients for all excited states of all elements and write them to YAML files.
# ----------------------------------------------------------------------------------------------------

reference, element_configs = load_element_configs()  # reads INPUT_DATA_FILE
T_ref = float(reference["T_ref"])  # matching temperature, K
P_ref = float(reference["P_ref"])  # reference pressure, Pa
temperature_ranges = [float(T) for T in reference["temperature_ranges"]]  # K
a_2 = 5 / 2  # constant C_p/R: translational heat capacity of an atom

# One YAML output file is written per element.
for element in element_configs:
    output_file = write_atomic_excited_states_nasa9_yaml(
        element,
        T_ref=T_ref,
        P_ref=P_ref,
        temperature_ranges=temperature_ranges,
        a_2=a_2,
    )
    print(
        f"NASA9 coefficients for atomic excited states of {element.element_name} "
        f"written to {output_file}"
    )

# %%
