Source code for endstate_correction.simulation.base

import abc
import logging
import os
from typing import List, Literal, Union

import mdtraj
from mdtraj.core.trajectory import Trajectory as mdtraj_trajectory
from mdtraj.reporters import HDF5Reporter
from openmm import Integrator, LangevinIntegrator, OpenMMException
from openmm import unit
from openmm.app import (
    AmberPrmtopFile,
    CharmmPsfFile,
    NoCutoff,
    PME,
    Simulation,
    Topology,
)
from openmmml import MLPotential
from openmmtools.utils import get_fastest_platform

from ..protocol import BSSProtocol
from ..topology import AMBERTopology, CHARMMTopology


class EndstateCorrectionBase(abc.ABC):
    """Base class for endstate-correction simulations on a mixed MM/ML system.

    On construction the MM system is created from the topology object supplied
    by the subclass, converted into a mixed system with
    ``MLPotential.createMixedSystem`` and bound to an OpenMM ``Simulation``.

    Subclasses supply the force-field specific parts by implementing
    ``_get_mm_topology`` and ``_get_initial_coordinates``.
    """

    def __init__(
        self,
        top: Union[AMBERTopology, CHARMMTopology],
        env: Literal["complex", "waterbox", "vacuum"],
        ml_atoms: List[int],
        protocol: BSSProtocol,
        name: str = "endstate_correction",
        work_dir: str = "./",
        potential: str = "ani2x",
        implementation: Literal["nnpops", "torchani"] = "nnpops",
        interpolate: bool = True,
    ):
        """Build the mixed system and the ``Simulation`` that samples it.

        Parameters
        ----------
        top
            Topology description; stored on the instance as ``self.top``.
        env
            Environment of the system. ``"vacuum"`` selects ``NoCutoff``,
            every other value selects ``PME`` (see ``_createSystem``).
        ml_atoms
            Atom indices forwarded to ``MLPotential.createMixedSystem``.
        protocol
            Simulation settings. This class reads ``temperature``,
            ``collision_rate``, ``timestep``, ``rlist``, ``lam``, ``restart``,
            ``restart_interval`` and ``n_integration_steps`` from it.
        name
            Base name (without extension) of the trajectory file.
        work_dir
            Directory in which the trajectory file is written.
        potential
            Name passed to ``MLPotential``.
        implementation
            Forwarded to ``createMixedSystem`` as ``implementation``.
        interpolate
            Forwarded to ``createMixedSystem`` as ``interpolate``.
        """
        self.logger = logging.getLogger("EndstateCorrectionBase")
        self.top = top
        self.protocol = protocol
        self.env = env
        self.name = name
        self.work_dir = work_dir

        # Lazily filled caches, see get_mm_topology / get_initial_coordinates.
        self._mm_topology = None
        self._initial_coordinates = None
        # Set by start() or set_trajectory(); None means "no trajectory yet".
        self._traj_file = None

        mm_system = self._createSystem()
        ml_potential = MLPotential(potential)
        ml_system = ml_potential.createMixedSystem(
            self.get_mm_topology().topology,
            mm_system,
            ml_atoms,
            interpolate=interpolate,
            implementation=implementation,
        )

        integrator = self.get_integrator()
        platform = get_fastest_platform(minimum_precision="mixed")
        self.simulation = Simulation(
            self.get_mm_topology().topology,
            ml_system,
            integrator,
            platform=platform,
        )

    def get_mm_topology(self) -> Union[CharmmPsfFile, AmberPrmtopFile]:
        """Return the topology object, creating and caching it on first use.

        The returned object is whatever ``_get_mm_topology`` produces; this
        class uses its ``topology`` attribute and ``createSystem`` method.
        """
        if self._mm_topology is None:
            self._mm_topology = self._get_mm_topology()
        return self._mm_topology

    def get_initial_coordinates(self) -> unit.Quantity:
        """Return the initial coordinates, loading and caching them on first use."""
        if self._initial_coordinates is None:
            self._initial_coordinates = self._get_initial_coordinates()
        return self._initial_coordinates

    @abc.abstractmethod
    def _get_mm_topology(self) -> Union[CharmmPsfFile, AmberPrmtopFile]:
        """Create the force-field specific topology object (subclass hook)."""

    @abc.abstractmethod
    def _get_initial_coordinates(self) -> unit.Quantity:
        """Load the starting coordinates of the system (subclass hook)."""

    def get_integrator(self) -> Integrator:
        """Return a new ``LangevinIntegrator`` configured from the protocol."""
        return LangevinIntegrator(
            self.protocol.temperature,
            self.protocol.collision_rate,
            self.protocol.timestep,
        )

    def _createSystem(self, **kwargs):
        """Create the MM system for the configured environment.

        ``"vacuum"`` uses ``NoCutoff``; any other environment uses ``PME``
        with ``protocol.rlist`` as the nonbonded cutoff. Extra keyword
        arguments are forwarded to the topology's ``createSystem``.
        """
        if self.env == "vacuum":
            return self.get_mm_topology().createSystem(
                nonbondedMethod=NoCutoff, **kwargs
            )
        return self.get_mm_topology().createSystem(
            nonbondedMethod=PME,
            nonbondedCutoff=self.protocol.rlist,
            **kwargs,
        )

    def _trajectory_path(self) -> str:
        """Return the path of the HDF5 trajectory written by ``start``.

        ``os.path.join`` is used instead of string formatting so that an empty
        ``work_dir`` yields a path relative to the current directory rather
        than one rooted at ``/``, and a trailing separator in ``work_dir``
        does not produce a doubled slash.
        """
        return os.path.join(self.work_dir, f"{self.name}.h5")

    def start(self):
        """Run the sampling and write the trajectory to ``<work_dir>/<name>.h5``.

        Sets the ``lambda_interpolate`` context parameter and the initial
        positions, seeds velocities unless ``protocol.restart`` is truthy, and
        integrates for ``protocol.n_integration_steps`` steps with an
        ``HDF5Reporter`` attached.
        """
        # define lambda state
        lamb = self.protocol.lam
        self.logger.info(f"{lamb=}")

        # define where to store samples
        trajectory_file = self._trajectory_path()
        self._traj_file = trajectory_file
        self.logger.info(f"Trajectory saved to: {trajectory_file}")

        context = self.simulation.context
        context.setParameter("lambda_interpolate", lamb)
        context.setPositions(self.get_initial_coordinates())

        # Try to set velocities using OpenMM; fall back to manual velocity
        # seeding if that fails.
        if not self.protocol.restart:
            try:
                context.setVelocitiesToTemperature(self.protocol.temperature)
            except OpenMMException:
                from endstate_correction.equ import _seed_velocities, _get_masses

                context.setVelocities(
                    _seed_velocities(_get_masses(self.simulation.system))
                )

        reporter = HDF5Reporter(trajectory_file, self.protocol.restart_interval)
        self.simulation.reporters.append(reporter)
        try:
            # perform sampling
            self.simulation.step(self.protocol.n_integration_steps)
        finally:
            # Detach and close the reporter even when stepping fails, so the
            # HDF5 file handle is released and the simulation can be reused.
            self.simulation.reporters.clear()
            reporter.close()

    def set_trajectory(self, traj_file: str) -> None:
        """Point this object at an existing trajectory file."""
        self._traj_file = traj_file

    def get_trajectory(self) -> mdtraj_trajectory:
        """Load the trajectory written by ``start`` or set via ``set_trajectory``.

        For the ``"waterbox"`` environment the imaged trajectory is returned.

        Raises
        ------
        RuntimeError
            If no trajectory file is known yet.
        """
        traj_file = getattr(self, "_traj_file", None)
        if traj_file is None:
            raise RuntimeError(
                "No trajectory available: call start() or set_trajectory() first."
            )

        traj = mdtraj.load_hdf5(traj_file)
        if self.env == "waterbox":
            # image_molecules() is not in-place by default; it returns the
            # imaged trajectory, so the result has to be kept.
            traj = traj.image_molecules()
        return traj

    def get_simulation(self) -> Simulation:
        """Return the underlying OpenMM ``Simulation``."""
        return self.simulation

    def get_xyz(self) -> List[unit.Quantity]:
        """Return the positions of every trajectory frame as OpenMM quantities."""
        traj = self.get_trajectory()
        return [traj.openmm_positions(frame) for frame in range(traj.n_frames)]