From a61191165bde7c84bfd2851ae667ebf35dfdfe44 Mon Sep 17 00:00:00 2001 From: Lester Hedges Date: Wed, 15 Apr 2026 09:02:16 +0100 Subject: [PATCH 1/2] Improve energy components debugging option. --- README.md | 19 ++++++ src/somd2/config/_config.py | 5 +- src/somd2/runner/_base.py | 69 +++++++++++++++++----- src/somd2/runner/_repex.py | 26 ++++++-- src/somd2/runner/_runner.py | 114 +++++++++++++++++++++++++++++++++--- 5 files changed, 204 insertions(+), 29 deletions(-) diff --git a/README.md b/README.md index 8fc73d6..d61601e 100644 --- a/README.md +++ b/README.md @@ -242,6 +242,25 @@ geometry. To override this for all groups: somd2 perturbable_system.bss --terminal-flip-frequency "1 ps" --terminal-flip-angle "180 degrees" ``` +## Debugging with energy components + +To help diagnose simulation instabilities, `SOMD2` can record the potential +energy contribution from each OpenMM force group at every `energy-frequency` +interval. This is enabled with the `--save-energy-components` flag: + +``` +somd2 perturbable_system.bss --save-energy-components +``` + +One Parquet file per λ window is written to the output directory, named +`energy_components_.parquet`. Times are in nanoseconds and energies in +kcal/mol; both are stored as schema metadata in the file. + +> [!NOTE] +> Energy components are written more frequently than checkpoint files and are +> not guarded by the file lock, so they may lead the checkpoint files by up +> to one `checkpoint-frequency` interval when copying output mid-simulation. + ## Copying output files during a simulation When `SOMD2` writes checkpoint files it acquires an exclusive diff --git a/src/somd2/config/_config.py b/src/somd2/config/_config.py index 6c68e54..1ba21b5 100644 --- a/src/somd2/config/_config.py +++ b/src/somd2/config/_config.py @@ -506,7 +506,10 @@ def __init__( Whether to save a crash report if the simulation crashes. save_energy_components: bool - Whether to save the energy contribution for each force when checkpointing. + Whether to save per-force-group energy contributions to a Parquet file + in the output directory. Energies are recorded at every 'energy_frequency' + interval, or 'gcmc_frequency' when running with GCMC. Intended for + debugging purposes. save_xml: bool Whether to write an XML file for the OpenMM system to the output diff --git a/src/somd2/runner/_base.py b/src/somd2/runner/_base.py index e717a29..664138d 100644 --- a/src/somd2/runner/_base.py +++ b/src/somd2/runner/_base.py @@ -498,6 +498,10 @@ def __init__(self, system, config): # Check the output directories and create names of output files. self._filenames = self._prepare_output() + # Per-window cache of the last saved energy-components time (ns), + # used to skip duplicate rows on restart. + self._last_ec_time = {} + # Store the current system as a reference. self._reference_system = self._system.clone() @@ -1194,7 +1198,7 @@ def increment_filename(base_filename, suffix): filenames["trajectory"] = str(output_directory / f"traj_{lam}.dcd") filenames["trajectory_chunk"] = str(output_directory / f"traj_{lam}_") filenames["energy_components"] = str( - output_directory / f"energy_components_{lam}.csv" + output_directory / f"energy_components_{lam}.parquet" ) filenames["gcmc_ghosts"] = str(output_directory / f"gcmc_ghosts_{lam}.txt") filenames["sampler_stats"] = str(output_directory / f"sampler_stats_{lam}.pkl") @@ -2020,12 +2024,23 @@ def _backup_checkpoint(self, index): except Exception as e: return index, e + try: + # Backup the existing energy components file, if it exists. + path = _Path(self._filenames[index]["energy_components"]) + if path.exists() and path.stat().st_size > 0: + _copyfile( + self._filenames[index]["energy_components"], + str(self._filenames[index]["energy_components"]) + ".bak", + ) + except Exception as e: + return index, e + return index, None def _save_energy_components(self, index, context, time_ns): """ Internal function to save the energy components for each force group to a - CSV file. + Parquet file. Parameters ---------- @@ -2040,11 +2055,28 @@ def _save_energy_components(self, index, context, time_ns): The current simulation time in nanoseconds. """ - import csv as _csv + import json as _json import openmm + import pandas as _pd + import pyarrow as _pa + import pyarrow.parquet as _pq_local filepath = self._filenames[index]["energy_components"] - file_exists = _Path(filepath).exists() + + # Lazy-initialise the last saved time for restart deduplication. + # On the first call for this window, read the existing file (if any) + # to find the maximum time already written. + if index not in self._last_ec_time: + path = _Path(filepath) + if path.exists() and path.stat().st_size > 0: + existing = _pq_local.read_table(filepath).to_pandas() + self._last_ec_time[index] = float(existing["time"].max()) + else: + self._last_ec_time[index] = -1.0 + + # Skip rows that have already been written (restart deduplication). + if time_ns <= self._last_ec_time[index]: + return # Use the named force groups already assigned by sire_to_openmm_system, # sorted alphabetically for a consistent column order across runs. @@ -2055,18 +2087,25 @@ def _save_energy_components(self, index, context, time_ns): openmm.unit.kilocalories_per_mole ) - columns = ["time"] + list(energies.keys()) - row = {"time": round(time_ns, 6)} | { - name: round(nrg, 4) for name, nrg in energies.items() - } + row = {"time": round(time_ns, 6)} | energies + df = _pd.DataFrame([row]) + + path = _Path(filepath) + if path.exists() and path.stat().st_size > 0: + _parquet_append(filepath, df) + else: + # First write: embed units as schema metadata under the "somd2" key, + # consistent with how the energy trajectory parquet files are written. + table = _pa.Table.from_pandas(df) + meta = _json.dumps( + {"time_units": "ns", "energy_units": "kcal/mol"} + ).encode() + table = table.replace_schema_metadata( + {b"somd2": meta, **table.schema.metadata} + ) + _pq_local.write_table(table, filepath) - with open(filepath, "a", newline="") as f: - writer = _csv.DictWriter(f, fieldnames=columns) - if not file_exists: - # Write a comment line with units before the header. - f.write("# time: ns, energy: kcal/mol\n") - writer.writeheader() - writer.writerow(row) + self._last_ec_time[index] = time_ns def _restore_backup_files(self): """ diff --git a/src/somd2/runner/_repex.py b/src/somd2/runner/_repex.py index 09fc46a..3950a7c 100644 --- a/src/somd2/runner/_repex.py +++ b/src/somd2/runner/_repex.py @@ -1108,6 +1108,16 @@ def run(self): # Whether a frame is saved at the end of the cycle. write_gcmc_ghosts = (i + 1) % cycles_per_frame == 0 + # Current simulation time in ns for energy components saving. + time_ns = ( + ( + self._start_block * checkpoint_frequency + + (i + 1) * self._config.energy_frequency + ).to("ns") + if self._config.save_energy_components + else None + ) + # Run a dynamics block for each replica, making sure only each GPU is only # oversubscribed by a factor of self._config.oversubscription_factor. for j in range(num_batches): @@ -1121,6 +1131,7 @@ def run(self): repeat(is_gcmc), repeat(write_gcmc_ghosts), repeat(is_terminal_flip), + repeat(time_ns), ): if not result: _logger.error( @@ -1294,6 +1305,7 @@ def _run_block( is_gcmc=False, write_gcmc_ghosts=False, is_terminal_flip=False, + time_ns=None, ): """ Run a dynamics block for a given replica. @@ -1321,6 +1333,10 @@ def _run_block( Whether a terminal flip MC move should be performed before the dynamics block. + time_ns: float or None + The current simulation time in nanoseconds, used when saving energy + components. If None, energy components are not saved. + Returns ------- @@ -1417,6 +1433,10 @@ def _run_block( # Save the OpenMM state. self._dynamics_cache.save_openmm_state(index) + # Save the energy contribution for each force. + if self._config.save_energy_components and time_ns is not None: + self._save_energy_components(index, dynamics.context(), time_ns) + # Get the energy at each lambda value. energies = dynamics._current_energy_array() @@ -1781,12 +1801,6 @@ def _checkpoint(self, index, lambdas, block, num_blocks, is_final_block=False): # Commit the current system. system = dynamics.commit() - # Save the energy contribution for each force. - if self._config.save_energy_components: - self._save_energy_components( - index, dynamics.context(), system.time().to("ns") - ) - # If performing GCMC, then we need to flag the ghost waters. if gcmc_sampler is not None: system = gcmc_sampler._flag_ghost_waters(system) diff --git a/src/somd2/runner/_runner.py b/src/somd2/runner/_runner.py index ccd1073..542444c 100644 --- a/src/somd2/runner/_runner.py +++ b/src/somd2/runner/_runner.py @@ -781,6 +781,9 @@ def generate_lam_vals(lambda_base, increment=0.001): save_frames = self._config.frame_frequency > 0 next_frame = self._config.frame_frequency flip_counter = 0 + # Track elapsed simulation time separately for energy components, + # since dynamics blocks increment by gcmc_frequency not energy_frequency. + ec_elapsed = _sr.u("0ps") # Loop until we reach the runtime. while runtime < checkpoint_frequency: @@ -863,6 +866,56 @@ def generate_lam_vals(lambda_base, increment=0.001): # Update the runtime and flip counter. runtime += self._config.energy_frequency + ec_elapsed += self._config.gcmc_frequency + flip_counter += 1 + + # Save the energy contribution for each force. + if self._config.save_energy_components: + self._save_energy_components( + index, + dynamics.context(), + (block * checkpoint_frequency + ec_elapsed).to( + "ns" + ), + ) + + elif self._config.save_energy_components: + # Sub-block loop to save energy components at energy_frequency + # intervals, with optional terminal flip moves. + runtime = _sr.u("0ps") + flip_counter = 0 + while runtime < checkpoint_frequency: + if ( + terminal_flip_sampler is not None + and flip_counter % flip_every == 0 + ): + _logger.info( + f"Performing terminal flip move at " + f"{_lam_sym} = {lambda_value:.5f}" + ) + if ( + terminal_flip_sampler.move(dynamics.context()) + and self._config.randomise_velocities + ): + dynamics.randomise_velocities() + dynamics.run( + self._config.energy_frequency, + energy_frequency=self._config.energy_frequency, + frame_frequency=self._config.frame_frequency, + lambda_windows=lambda_array, + rest2_scale_factors=rest2_scale_factors, + save_velocities=self._config.save_velocities, + auto_fix_minimise=self._config.auto_fix_minimise, + num_energy_neighbours=num_energy_neighbours, + null_energy=self._config.null_energy, + save_crash_report=self._config.save_crash_report, + ) + runtime += self._config.energy_frequency + self._save_energy_components( + index, + dynamics.context(), + (block * checkpoint_frequency + runtime).to("ns"), + ) flip_counter += 1 elif terminal_flip_sampler is not None: @@ -924,12 +977,6 @@ def generate_lam_vals(lambda_base, increment=0.001): # Commit the current system. system = dynamics.commit() - # Save the energy contribution for each force. - if self._config.save_energy_components: - self._save_energy_components( - index, dynamics.context(), system.time().to("ns") - ) - # If performing GCMC, then we need to flag the ghost waters. if gcmc_sampler is not None: system = gcmc_sampler._flag_ghost_waters(system) @@ -1103,7 +1150,7 @@ def generate_lam_vals(lambda_base, increment=0.001): ) except Exception as e: raise RuntimeError( - f"Final dynamics block for {lam_sym} = {lambda_value:.5f} failed: {e}" + f"Final dynamics block for {_lam_sym} = {lambda_value:.5f} failed: {e}" ) else: try: @@ -1113,6 +1160,9 @@ def generate_lam_vals(lambda_base, increment=0.001): save_frames = self._config.frame_frequency > 0 next_frame = self._config.frame_frequency flip_counter = 0 + # Track elapsed simulation time separately for energy components, + # since dynamics blocks increment by gcmc_frequency not energy_frequency. + ec_elapsed = _sr.u("0ps") # Loop until we reach the runtime. while runtime < time: @@ -1182,6 +1232,56 @@ def generate_lam_vals(lambda_base, increment=0.001): # Update the runtime and flip counter. runtime += self._config.energy_frequency + ec_elapsed += self._config.gcmc_frequency + flip_counter += 1 + + # Save the energy contribution for each force. + if self._config.save_energy_components: + self._save_energy_components( + index, + dynamics.context(), + (self._config.runtime - time + ec_elapsed).to("ns"), + ) + + elif self._config.save_energy_components: + # Sub-block loop to save energy components at energy_frequency + # intervals, with optional terminal flip moves. + runtime = _sr.u("0ps") + flip_counter = 0 + # Elapsed time before this run (0 for fresh, restart time for restart). + time_base = self._config.runtime - time + while runtime < time: + if ( + terminal_flip_sampler is not None + and flip_counter % flip_every == 0 + ): + _logger.info( + f"Performing terminal flip move at " + f"{_lam_sym} = {lambda_value:.5f}" + ) + if ( + terminal_flip_sampler.move(dynamics.context()) + and self._config.randomise_velocities + ): + dynamics.randomise_velocities() + dynamics.run( + self._config.energy_frequency, + energy_frequency=self._config.energy_frequency, + frame_frequency=self._config.frame_frequency, + lambda_windows=lambda_array, + rest2_scale_factors=rest2_scale_factors, + save_velocities=self._config.save_velocities, + auto_fix_minimise=self._config.auto_fix_minimise, + num_energy_neighbours=num_energy_neighbours, + null_energy=self._config.null_energy, + save_crash_report=self._config.save_crash_report, + ) + runtime += self._config.energy_frequency + self._save_energy_components( + index, + dynamics.context(), + (time_base + runtime).to("ns"), + ) flip_counter += 1 elif terminal_flip_sampler is not None: From 8e238660987b9c879a2b3b01a6a3ebbaba2afc4c Mon Sep 17 00:00:00 2001 From: Lester Hedges Date: Wed, 15 Apr 2026 09:42:18 +0100 Subject: [PATCH 2/2] Consolidate conditional branches to remove code duplication. --- README.md | 12 +- src/somd2/config/_config.py | 5 +- src/somd2/runner/_runner.py | 411 +++++++++++++++--------------------- 3 files changed, 180 insertions(+), 248 deletions(-) diff --git a/README.md b/README.md index d61601e..6ec3b90 100644 --- a/README.md +++ b/README.md @@ -245,8 +245,8 @@ somd2 perturbable_system.bss --terminal-flip-frequency "1 ps" --terminal-flip-an ## Debugging with energy components To help diagnose simulation instabilities, `SOMD2` can record the potential -energy contribution from each OpenMM force group at every `energy-frequency` -interval. This is enabled with the `--save-energy-components` flag: +energy contribution from each OpenMM force group. This is enabled with the +`--save-energy-components` flag: ``` somd2 perturbable_system.bss --save-energy-components @@ -256,6 +256,14 @@ One Parquet file per λ window is written to the output directory, named `energy_components_.parquet`. Times are in nanoseconds and energies in kcal/mol; both are stored as schema metadata in the file. +The recording interval depends on the runner and active samplers: + +- **Replica exchange**: always `energy-frequency` +- **Standard runner, no MC**: `energy-frequency` +- **Standard runner, with MC**: the shortest active MC frequency, i.e. + `gcmc-frequency`, `terminal-flip-frequency`, or the smaller of the two + when both are active + > [!NOTE] > Energy components are written more frequently than checkpoint files and are > not guarded by the file lock, so they may lead the checkpoint files by up diff --git a/src/somd2/config/_config.py b/src/somd2/config/_config.py index 1ba21b5..2a9a461 100644 --- a/src/somd2/config/_config.py +++ b/src/somd2/config/_config.py @@ -508,8 +508,9 @@ def __init__( save_energy_components: bool Whether to save per-force-group energy contributions to a Parquet file in the output directory. Energies are recorded at every 'energy_frequency' - interval, or 'gcmc_frequency' when running with GCMC. Intended for - debugging purposes. + interval. When not running replica exchange, the interval is instead the + shortest active MC frequency when running with GCMC or terminal flip moves. + Intended for debugging purposes. save_xml: bool Whether to write an XML file for the OpenMM system to the output diff --git a/src/somd2/runner/_runner.py b/src/somd2/runner/_runner.py index 542444c..7479c59 100644 --- a/src/somd2/runner/_runner.py +++ b/src/somd2/runner/_runner.py @@ -473,22 +473,12 @@ def generate_lam_vals(lambda_base, increment=0.001): self._terminal_groups, float(self._config.temperature.value()), ) - flip_every = max( - 1, - round( - ( - self._config.terminal_flip_frequency - / self._config.energy_frequency - ).value() - ), - ) _logger.info( f"Terminal flip sampler ready at {_lam_sym} = {lambda_value:.5f} " - f"(every {flip_every} energy block(s))" + f"(every {self._config.terminal_flip_frequency})" ) else: terminal_flip_sampler = None - flip_every = None # Minimisation. if self._config.minimise: @@ -772,40 +762,87 @@ def generate_lam_vals(lambda_base, increment=0.001): # Run the dynamics. try: - # GCMC specific handling. Note that the frame and checkpoint - # frequencies are multiples of the energy frequency so we can - # run in energy frequency blocks with no remainder. - if self._config.gcmc: - # Initialise the run time and time at which the next frame is saved. + # Run in sub-blocks when any MC sampler is active or energy + # components are being saved; otherwise run the full block. + needs_subblock = ( + gcmc_sampler is not None + or terminal_flip_sampler is not None + or self._config.save_energy_components + ) + if needs_subblock: runtime = _sr.u("0ps") - save_frames = self._config.frame_frequency > 0 - next_frame = self._config.frame_frequency - flip_counter = 0 - # Track elapsed simulation time separately for energy components, - # since dynamics blocks increment by gcmc_frequency not energy_frequency. ec_elapsed = _sr.u("0ps") - - # Loop until we reach the runtime. - while runtime < checkpoint_frequency: - # Perform a GCMC move before dynamics so the ghost - # state is consistent with the energies computed - # during dynamics. - _logger.info( - f"Performing GCMC move at {_lam_sym} = {lambda_value:.5f}" + flip_counter = 0 + save_frames = ( + gcmc_sampler is not None + and self._config.frame_frequency > 0 + ) + next_frame = ( + self._config.frame_frequency if save_frames else None + ) + # Sub-block size: shortest active MC frequency, or + # energy_frequency when only saving energy components. + if ( + gcmc_sampler is not None + and terminal_flip_sampler is not None + ): + block_size = min( + self._config.gcmc_frequency, + self._config.terminal_flip_frequency, ) - gcmc_sampler.push() - try: - gcmc_sampler.move(dynamics.context()) - finally: - gcmc_sampler.pop() + elif gcmc_sampler is not None: + block_size = self._config.gcmc_frequency + elif terminal_flip_sampler is not None: + block_size = self._config.terminal_flip_frequency + else: + block_size = self._config.energy_frequency + # How often to attempt each MC move (in sub-block units). + gcmc_every = ( + max( + 1, + round( + (self._config.gcmc_frequency / block_size).value() + ), + ) + if gcmc_sampler is not None + else None + ) + mc_flip_every = ( + max( + 1, + round( + ( + self._config.terminal_flip_frequency + / block_size + ).value() + ), + ) + if terminal_flip_sampler is not None + else None + ) - # GCMC always changes positions. - needs_pre_run_snapshot = self._config.auto_fix_minimise + while runtime < checkpoint_frequency: + needs_pre_run_snapshot = False - # Perform a terminal flip move at the specified frequency. + # GCMC move. + if ( + gcmc_sampler is not None + and flip_counter % gcmc_every == 0 + ): + _logger.info( + f"Performing GCMC move at {_lam_sym} = {lambda_value:.5f}" + ) + gcmc_sampler.push() + try: + gcmc_sampler.move(dynamics.context()) + finally: + gcmc_sampler.pop() + needs_pre_run_snapshot = self._config.auto_fix_minimise + + # Terminal flip move. if ( terminal_flip_sampler is not None - and flip_counter % flip_every == 0 + and flip_counter % mc_flip_every == 0 ): _logger.info( f"Performing terminal flip move at " @@ -820,30 +857,24 @@ def generate_lam_vals(lambda_base, increment=0.001): if self._config.randomise_velocities: dynamics.randomise_velocities() - # Snapshot the context state once for crash recovery - # if any MC move changed positions. + # Snapshot the context state for crash recovery if + # any MC move changed positions. if needs_pre_run_snapshot: dynamics._d._pre_run_state = ( dynamics.context().getState( getPositions=True, getVelocities=True ) ) - needs_pre_run_snapshot = False - # Write ghost residues immediately after the GCMC - # move if a frame will be saved in the upcoming - # dynamics block. - if ( - save_frames - and runtime + self._config.energy_frequency - >= next_frame - ): + # Write ghost residues immediately before the dynamics + # block if a frame will be saved within it. + if save_frames and runtime + block_size >= next_frame: gcmc_sampler.write_ghost_residues() next_frame += self._config.frame_frequency - # Run the dynamics in blocks of the GCMC frequency. + # Run the dynamics block. dynamics.run( - self._config.gcmc_frequency, + block_size, energy_frequency=self._config.energy_frequency, frame_frequency=self._config.frame_frequency, lambda_windows=lambda_array, @@ -853,7 +884,6 @@ def generate_lam_vals(lambda_base, increment=0.001): num_energy_neighbours=num_energy_neighbours, null_energy=self._config.null_energy, save_crash_report=self._config.save_crash_report, - # GCMC specific options. excess_chemical_potential=( self._mu_ex if gcmc_sampler is not None else None ), @@ -864,12 +894,11 @@ def generate_lam_vals(lambda_base, increment=0.001): ), ) - # Update the runtime and flip counter. - runtime += self._config.energy_frequency - ec_elapsed += self._config.gcmc_frequency + runtime += block_size + ec_elapsed += block_size flip_counter += 1 - # Save the energy contribution for each force. + # Save energy components. if self._config.save_energy_components: self._save_energy_components( index, @@ -879,81 +908,6 @@ def generate_lam_vals(lambda_base, increment=0.001): ), ) - elif self._config.save_energy_components: - # Sub-block loop to save energy components at energy_frequency - # intervals, with optional terminal flip moves. - runtime = _sr.u("0ps") - flip_counter = 0 - while runtime < checkpoint_frequency: - if ( - terminal_flip_sampler is not None - and flip_counter % flip_every == 0 - ): - _logger.info( - f"Performing terminal flip move at " - f"{_lam_sym} = {lambda_value:.5f}" - ) - if ( - terminal_flip_sampler.move(dynamics.context()) - and self._config.randomise_velocities - ): - dynamics.randomise_velocities() - dynamics.run( - self._config.energy_frequency, - energy_frequency=self._config.energy_frequency, - frame_frequency=self._config.frame_frequency, - lambda_windows=lambda_array, - rest2_scale_factors=rest2_scale_factors, - save_velocities=self._config.save_velocities, - auto_fix_minimise=self._config.auto_fix_minimise, - num_energy_neighbours=num_energy_neighbours, - null_energy=self._config.null_energy, - save_crash_report=self._config.save_crash_report, - ) - runtime += self._config.energy_frequency - self._save_energy_components( - index, - dynamics.context(), - (block * checkpoint_frequency + runtime).to("ns"), - ) - flip_counter += 1 - - elif terminal_flip_sampler is not None: - # Terminal flip without GCMC: perform flip moves at the - # specified frequency then run the full dynamics block. - n_flips = max( - 1, - round( - ( - checkpoint_frequency - / self._config.terminal_flip_frequency - ).value() - ), - ) - for _ in range(n_flips): - _logger.info( - f"Performing terminal flip move at " - f"{_lam_sym} = {lambda_value:.5f}" - ) - if ( - terminal_flip_sampler.move(dynamics.context()) - and self._config.randomise_velocities - ): - dynamics.randomise_velocities() - - dynamics.run( - checkpoint_frequency, - energy_frequency=self._config.energy_frequency, - frame_frequency=self._config.frame_frequency, - lambda_windows=lambda_array, - rest2_scale_factors=rest2_scale_factors, - save_velocities=self._config.save_velocities, - auto_fix_minimise=self._config.auto_fix_minimise, - num_energy_neighbours=num_energy_neighbours, - null_energy=self._config.null_energy, - save_crash_report=self._config.save_crash_report, - ) - else: dynamics.run( checkpoint_frequency, @@ -1154,37 +1108,75 @@ def generate_lam_vals(lambda_base, increment=0.001): ) else: try: - if gcmc_sampler is not None: - # Initialise the run time and time at which the next frame is saved. + # Run in sub-blocks when any MC sampler is active or energy + # components are being saved; otherwise run a single block. + needs_subblock = ( + gcmc_sampler is not None + or terminal_flip_sampler is not None + or self._config.save_energy_components + ) + if needs_subblock: runtime = _sr.u("0ps") - save_frames = self._config.frame_frequency > 0 - next_frame = self._config.frame_frequency - flip_counter = 0 - # Track elapsed simulation time separately for energy components, - # since dynamics blocks increment by gcmc_frequency not energy_frequency. ec_elapsed = _sr.u("0ps") + flip_counter = 0 + save_frames = ( + gcmc_sampler is not None and self._config.frame_frequency > 0 + ) + next_frame = self._config.frame_frequency if save_frames else None + # Sub-block size: shortest active MC frequency, or + # energy_frequency when only saving energy components. + if gcmc_sampler is not None and terminal_flip_sampler is not None: + block_size = min( + self._config.gcmc_frequency, + self._config.terminal_flip_frequency, + ) + elif gcmc_sampler is not None: + block_size = self._config.gcmc_frequency + elif terminal_flip_sampler is not None: + block_size = self._config.terminal_flip_frequency + else: + block_size = self._config.energy_frequency + # How often to attempt each MC move (in sub-block units). + gcmc_every = ( + max( + 1, round((self._config.gcmc_frequency / block_size).value()) + ) + if gcmc_sampler is not None + else None + ) + mc_flip_every = ( + max( + 1, + round( + ( + self._config.terminal_flip_frequency / block_size + ).value() + ), + ) + if terminal_flip_sampler is not None + else None + ) + time_base = self._config.runtime - time - # Loop until we reach the runtime. while runtime < time: - # Perform a GCMC move before dynamics so the ghost - # state is consistent with the energies computed - # during dynamics. - _logger.info( - f"Performing GCMC move at {_lam_sym} = {lambda_value:.5f}" - ) - gcmc_sampler.push() - try: - gcmc_sampler.move(dynamics.context()) - finally: - gcmc_sampler.pop() + needs_pre_run_snapshot = False - # GCMC always changes positions. - needs_pre_run_snapshot = True + # GCMC move. + if gcmc_sampler is not None and flip_counter % gcmc_every == 0: + _logger.info( + f"Performing GCMC move at {_lam_sym} = {lambda_value:.5f}" + ) + gcmc_sampler.push() + try: + gcmc_sampler.move(dynamics.context()) + finally: + gcmc_sampler.pop() + needs_pre_run_snapshot = self._config.auto_fix_minimise - # Perform a terminal flip move at the specified frequency. + # Terminal flip move. if ( terminal_flip_sampler is not None - and flip_counter % flip_every == 0 + and flip_counter % mc_flip_every == 0 ): _logger.info( f"Performing terminal flip move at " @@ -1194,31 +1186,27 @@ def generate_lam_vals(lambda_base, increment=0.001): dynamics.context() ) if flip_accepted: - needs_pre_run_snapshot = True + if self._config.auto_fix_minimise: + needs_pre_run_snapshot = True if self._config.randomise_velocities: dynamics.randomise_velocities() - # Snapshot the context state once for crash recovery - # if any MC move changed positions. + # Snapshot the context state for crash recovery if + # any MC move changed positions. if needs_pre_run_snapshot: dynamics._d._pre_run_state = dynamics.context().getState( getPositions=True, getVelocities=True ) - needs_pre_run_snapshot = False - # Write ghost residues immediately after the GCMC - # move if a frame will be saved in the upcoming - # dynamics block. - if ( - save_frames - and runtime + self._config.energy_frequency >= next_frame - ): + # Write ghost residues immediately before the dynamics + # block if a frame will be saved within it. + if save_frames and runtime + block_size >= next_frame: gcmc_sampler.write_ghost_residues() next_frame += self._config.frame_frequency - # Run the dynamics in blocks of the GCMC frequency. + # Run the dynamics block. dynamics.run( - self._config.gcmc_frequency, + block_size, energy_frequency=self._config.energy_frequency, frame_frequency=self._config.frame_frequency, lambda_windows=lambda_array, @@ -1228,93 +1216,28 @@ def generate_lam_vals(lambda_base, increment=0.001): num_energy_neighbours=num_energy_neighbours, null_energy=self._config.null_energy, save_crash_report=self._config.save_crash_report, + excess_chemical_potential=( + self._mu_ex if gcmc_sampler is not None else None + ), + num_waters=( + _np.sum(gcmc_sampler.water_state()) + if gcmc_sampler is not None + else None + ), ) - # Update the runtime and flip counter. - runtime += self._config.energy_frequency - ec_elapsed += self._config.gcmc_frequency + runtime += block_size + ec_elapsed += block_size flip_counter += 1 - # Save the energy contribution for each force. + # Save energy components. if self._config.save_energy_components: self._save_energy_components( index, dynamics.context(), - (self._config.runtime - time + ec_elapsed).to("ns"), + (time_base + ec_elapsed).to("ns"), ) - elif self._config.save_energy_components: - # Sub-block loop to save energy components at energy_frequency - # intervals, with optional terminal flip moves. - runtime = _sr.u("0ps") - flip_counter = 0 - # Elapsed time before this run (0 for fresh, restart time for restart). - time_base = self._config.runtime - time - while runtime < time: - if ( - terminal_flip_sampler is not None - and flip_counter % flip_every == 0 - ): - _logger.info( - f"Performing terminal flip move at " - f"{_lam_sym} = {lambda_value:.5f}" - ) - if ( - terminal_flip_sampler.move(dynamics.context()) - and self._config.randomise_velocities - ): - dynamics.randomise_velocities() - dynamics.run( - self._config.energy_frequency, - energy_frequency=self._config.energy_frequency, - frame_frequency=self._config.frame_frequency, - lambda_windows=lambda_array, - rest2_scale_factors=rest2_scale_factors, - save_velocities=self._config.save_velocities, - auto_fix_minimise=self._config.auto_fix_minimise, - num_energy_neighbours=num_energy_neighbours, - null_energy=self._config.null_energy, - save_crash_report=self._config.save_crash_report, - ) - runtime += self._config.energy_frequency - self._save_energy_components( - index, - dynamics.context(), - (time_base + runtime).to("ns"), - ) - flip_counter += 1 - - elif terminal_flip_sampler is not None: - # Terminal flip without GCMC: perform flip moves at the - # start then run the full dynamics block. - n_flips = max( - 1, - round((time / self._config.terminal_flip_frequency).value()), - ) - for _ in range(n_flips): - _logger.info( - f"Performing terminal flip move at " - f"{_lam_sym} = {lambda_value:.5f}" - ) - if ( - terminal_flip_sampler.move(dynamics.context()) - and self._config.randomise_velocities - ): - dynamics.randomise_velocities() - - dynamics.run( - time, - energy_frequency=self._config.energy_frequency, - frame_frequency=self._config.frame_frequency, - lambda_windows=lambda_array, - rest2_scale_factors=rest2_scale_factors, - save_velocities=self._config.save_velocities, - auto_fix_minimise=self._config.auto_fix_minimise, - num_energy_neighbours=num_energy_neighbours, - null_energy=self._config.null_energy, - save_crash_report=self._config.save_crash_report, - ) - else: dynamics.run( time,