"""Molecule class for optimisation."""
from __future__ import annotations
from dataclasses import dataclass
from typing import TYPE_CHECKING
import networkx as nx
import numpy as np
if TYPE_CHECKING:
from collections import abc
from .atom import Atom
from .bond import Bond
@dataclass
class Molecule:
    """Molecule to optimize.

    Parameters:
        atoms:
            Atoms that define the molecule.

        bonds:
            Bonds between atoms that define the molecule.

        position_matrix:
            A ``(n, 3)`` matrix (or array-like) holding the position of
            every atom in the :class:`.Molecule`. After initialisation
            the attribute holds a float64 copy of its transpose, with
            shape ``(3, n)``; use :meth:`get_position_matrix` to get the
            ``(n, 3)`` layout back.

    """

    atoms: tuple[Atom, ...]
    bonds: tuple[Bond, ...]
    position_matrix: np.ndarray

    def __post_init__(self) -> None:
        """Normalise the fields after dataclass initialisation."""
        self.atoms = tuple(self.atoms)
        self.bonds = tuple(self.bonds)
        # Positions are stored transposed, as ``(3, n)``, so that
        # ``position_matrix[:, ids]`` selects atom columns. ``np.asarray``
        # lets nested lists be passed as well as arrays; ``np.array``
        # then makes an owned float64 copy.
        self.position_matrix = np.array(
            np.asarray(self.position_matrix).T,
            dtype=np.float64,
        )

    def get_position_matrix(self) -> np.ndarray:
        """Return a matrix holding the atomic positions.

        Returns:
            A copy with shape ``(n, 3)``. Each row holds the x, y and z
            coordinates of an atom.

        """
        return np.array(self.position_matrix.T)

    def with_displacement(self, displacement: np.ndarray) -> Molecule:
        """Return a displaced clone Molecule.

        Parameters:
            displacement:
                The displacement added to the ``(n, 3)`` position
                matrix. It follows NumPy broadcasting rules, so a
                ``(3,)`` vector shifts every atom by the same amount.

        Returns:
            A new :class:`Molecule` with the same atoms and bonds. This
            molecule is not modified.

        """
        new_position_matrix = self.position_matrix.T + displacement
        return Molecule(
            atoms=tuple(self.atoms),
            bonds=tuple(self.bonds),
            position_matrix=np.array(new_position_matrix),
        )

    def with_position_matrix(self, position_matrix: np.ndarray) -> Molecule:
        """Return clone Molecule with new position matrix.

        Parameters:
            position_matrix:
                A position matrix of the clone. The shape of the matrix
                is ``(n, 3)``.

        Returns:
            A new :class:`Molecule` with the same atoms and bonds.

        """
        return Molecule(
            atoms=tuple(self.atoms),
            bonds=tuple(self.bonds),
            position_matrix=np.array(position_matrix),
        )

    def write_xyz_content(self) -> list[str]:
        """Return basic `.xyz` file content of Molecule.

        Returns:
            The file content as a list of strings. The first element is
            the atom count followed by an empty comment line; each
            following element is one ``element x y z`` atom line.

        """
        coords = self.get_position_matrix()
        # The header count comes from the atom tuple, not a loop counter,
        # so an empty molecule yields "0" instead of raising.
        content = [f"{self.get_num_atoms()}\n\n"]
        for atom in self.get_atoms():
            x, y, z = coords[atom.get_id()]
            content.append(
                f"{atom.get_element_string()} {x:f} {y:f} {z:f}\n"
            )
        return content

    def write_xyz_file(self, path: str) -> None:
        """Write basic `.xyz` file of Molecule to `path`.

        Connectivity is not maintained in this file type!
        """
        content = self.write_xyz_content()
        with open(path, "w", encoding="utf-8") as f:
            f.write("".join(content))

    def _write_pdb_content(self) -> list[str]:
        """Return basic `.pdb` file content of Molecule.

        Every atom becomes a ``HETATM`` record (residue ``UNL``, serial
        number = atom id + 1, name = element + per-element counter).
        Every bond whose two atoms were written becomes a ``CONECT``
        record. The content ends with ``END``.
        """
        content = []
        atom_counts: dict[str, int] = {}
        hetatm = "HETATM"
        alt_loc = ""
        res_name = "UNL"
        chain_id = ""
        res_seq = "1"
        i_code = ""
        occupancy = "1.00"
        temp_factor = "0.00"
        charge = 0
        coords = self.get_position_matrix()
        # Ids of written atoms; bonds to any other id are skipped.
        atoms = set()
        for atom in self.get_atoms():
            atom_id = atom.get_id()
            x, y, z = coords[atom_id]
            atoms.add(atom_id)
            serial = atom_id + 1
            element = atom.get_element_string()
            atom_counts[element] = atom_counts.get(element, 0) + 1
            name = f"{element}{atom_counts[element]}"
            # Fixed-column PDB layout: name in cols 13-16, x/y/z in
            # cols 31-54, occupancy 55-60, temp factor 61-66, element
            # 77-78 and charge 79-80.
            content.append(
                f"{hetatm:<6}{serial:>5} {name:<4}"
                f"{alt_loc:<1}{res_name:<3} {chain_id:<1}"
                f"{res_seq:>4}{i_code:<1}   "
                f" {x:>7.3f} {y:>7.3f} {z:>7.3f}"
                f"{occupancy:>6}{temp_factor:>6}          "
                f"{element:>2}{charge:>2}\n"
            )
        conect = "CONECT"
        for bond in self.get_bonds():
            a1 = bond.get_atom1_id()
            a2 = bond.get_atom2_id()
            if a1 in atoms and a2 in atoms:
                content.append(f"{conect:<6}{a1+1:>5}{a2+1:>5} \n")
        content.append("END\n")
        return content

    def write_pdb_file(self, path: str) -> None:
        """Write basic `.pdb` file of Molecule to `path`."""
        content = self._write_pdb_content()
        with open(path, "w", encoding="utf-8") as f:
            f.write("".join(content))

    def get_atoms(self) -> abc.Iterable[Atom]:
        """Yield the atoms in the molecule, ordered as input."""
        yield from self.atoms

    def get_bonds(self) -> abc.Iterable[Bond]:
        """Yield the bonds in the molecule, ordered as input."""
        yield from self.bonds

    def get_num_atoms(self) -> int:
        """Return the number of atoms in the molecule."""
        return len(self.atoms)

    def get_centroid(
        self,
        atom_ids: int | abc.Iterable[int] | None = None,
    ) -> np.ndarray:
        """Return the centroid.

        Parameters:
            atom_ids:
                The ids of atoms which are used to calculate the
                centroid. Can be a single :class:`int`, if a single
                atom is to be used, or ``None`` if all atoms are to be
                used.

        Returns:
            The centroid of atoms specified by `atom_ids`, as an array
            of shape ``(3,)``.

        Raises:
            ValueError: If `atom_ids` has a length of ``0``.

        """
        ids: abc.Sequence[int]
        if atom_ids is None:
            ids = range(self.get_num_atoms())
        elif isinstance(atom_ids, (int, np.integer)):
            # Wrap a single id so it can be used as a fancy index.
            ids = (int(atom_ids),)
        elif isinstance(atom_ids, (list, tuple)):
            ids = atom_ids
        else:
            # Materialise sets, generators, etc. so that len() and NumPy
            # fancy indexing both work.
            ids = tuple(atom_ids)

        if len(ids) == 0:
            msg = "atom_ids was of length 0."
            raise ValueError(msg)

        return np.divide(
            self.position_matrix[:, ids].sum(axis=1),
            len(ids),
        )

    def get_subunits(
        self,
        bond_pair_ids: abc.Iterable[tuple[int, int]],
    ) -> dict[int, set[int]]:
        """Get connected graphs based on Molecule separated by bonds.

        Parameters:
            bond_pair_ids:
                Iterable of pairs of atom ids with a bond between them
                to optimize. The order of the ids within a pair does not
                matter.

        Returns:
            The subunits of the molecule split by the bonds defined by
            `bond_pair_ids`. Key is subunit identifier, value is the
            :class:`set` of atom ids in the subunit.

        """
        # Bonds are undirected, so compare pairs as frozensets: (a2, a1)
        # matches a bond stored as (a1, a2). Building the set once gives
        # O(1) lookups and does not exhaust a one-shot iterator argument.
        excluded = {frozenset(pair) for pair in bond_pair_ids}

        # Build a graph of the molecule without the edges of the bonds
        # to be optimized.
        mol_graph = nx.Graph()
        for atom in self.get_atoms():
            mol_graph.add_node(atom.get_id())
        for bond in self.get_bonds():
            pair_ids = (bond.get_atom1_id(), bond.get_atom2_id())
            if frozenset(pair_ids) not in excluded:
                mol_graph.add_edge(*pair_ids)

        # Atom ids in each disconnected subgraph.
        return dict(enumerate(nx.connected_components(mol_graph)))

    def __str__(self) -> str:
        """String representation."""
        return repr(self)

    def __repr__(self) -> str:
        """String representation."""
        return (
            f"<{self.__class__.__name__}({len(self.atoms)} atoms) "
            f"at {id(self)}>"
        )