# Source code for propertyestimator.protocols.utils

"""
A set of utilities for setting up property estimation workflows.
"""
import copy
from collections import namedtuple

from propertyestimator import unit
from propertyestimator.protocols import analysis, forcefield, gradients, groups, reweighting, coordinates, simulation, \
    storage
from propertyestimator.thermodynamics import Ensemble
from propertyestimator.utils.statistics import ObservableType
from propertyestimator.workflow import WorkflowOptions
from propertyestimator.workflow.plugins import available_protocols
from propertyestimator.workflow.schemas import ProtocolReplicator, WorkflowSimulationDataToStore
from propertyestimator.workflow.utils import ProtocolPath, ReplicatorValue

# Named container for the protocols returned by
# `generate_base_reweighting_protocols`. Fields are positional, so the order
# below must match the order in which that function builds the tuple.
BaseReweightingProtocols = namedtuple(
    'BaseReweightingProtocols',
    [
        'unpack_stored_data',
        'analysis_protocol',
        'decorrelate_statistics',
        'decorrelate_trajectory',
        'concatenate_trajectories',
        'concatenate_statistics',
        'build_reference_system',
        'reduced_reference_potential',
        'build_target_system',
        'reduced_target_potential',
        'mbar_protocol',
    ],
)


# Named container for the protocols returned by
# `generate_base_simulation_protocols`. The `converge_uncertainty` field holds
# the conditional group which wraps the production simulation and the
# analysis protocol.
BaseSimulationProtocols = namedtuple(
    'BaseSimulationProtocols',
    [
        'build_coordinates',
        'assign_parameters',
        'energy_minimisation',
        'equilibration_simulation',
        'production_simulation',
        'analysis_protocol',
        'converge_uncertainty',
        'extract_uncorrelated_trajectory',
        'extract_uncorrelated_statistics',
    ],
)


def generate_base_reweighting_protocols(analysis_protocol, mbar_protocol, workflow_options,
                                        replicator_id='data_repl', id_suffix=''):
    """Constructs a set of protocols which, when combined in a workflow schema,
    may be executed to reweight a set of existing data to estimate a particular
    property. The reweighted observable of interest will be calculated by
    following the passed in `analysis_protocol`.

    Parameters
    ----------
    analysis_protocol: AveragePropertyProtocol
        The protocol which will take input from the stored data, and generate
        a set of observables to reweight. It must be either an
        `ExtractAverageStatistic` or an `AverageTrajectoryProperty` protocol,
        and its id must contain the `$(replicator_id)` placeholder.
    mbar_protocol: BaseMBARProtocol
        A template mbar reweighting protocol, which has its reference
        observables already set. This method will automatically set the
        reduced potentials on this object. Its id must not contain the
        `$(replicator_id)` placeholder.
    workflow_options: WorkflowOptions
        The options being used to generate a workflow. When the convergence
        mode is `NoChecks`, `mbar_protocol.required_effective_samples` is
        set to -1.
    replicator_id: str
        The id to use for the data replicator.
    id_suffix: str
        A string suffix to append to each of the protocol ids.

    Returns
    -------
    BaseReweightingProtocols:
        A named tuple of the protocols which should form the bulk of
        a property estimation workflow.
    ProtocolReplicator:
        A replicator which will clone the workflow for each piece of
        stored data (taken from the global `full_system_data` input).

    Raises
    ------
    ValueError
        If `analysis_protocol` is neither an `AverageTrajectoryProperty` nor
        an `ExtractAverageStatistic` protocol.
    """
    assert isinstance(analysis_protocol, analysis.AveragePropertyProtocol)

    # The analysis protocol is cloned for every piece of stored data (hence the
    # replicator placeholder in its id), whereas the single MBAR protocol
    # consumes the output of every clone and so must not be replicated.
    assert f'$({replicator_id})' in analysis_protocol.id
    assert f'$({replicator_id})' not in mbar_protocol.id

    # Reject unsupported analysis protocols up front (mirroring
    # `generate_base_simulation_protocols`) rather than silently returning a
    # workflow whose analysis protocol has none of its inputs connected.
    if not isinstance(analysis_protocol, (analysis.ExtractAverageStatistic,
                                          analysis.AverageTrajectoryProperty)):

        raise ValueError('The analysis protocol must inherit from either the '
                         'AverageTrajectoryProperty or ExtractAverageStatistic '
                         'protocols.')

    replicator_suffix = '_$({}){}'.format(replicator_id, id_suffix)

    # Unpack all the of the stored data.
    unpack_stored_data = storage.UnpackStoredSimulationData('unpack_data{}'.format(replicator_suffix))
    unpack_stored_data.simulation_data_path = ReplicatorValue(replicator_id)

    # The autocorrelation time of each of the stored files will be calculated
    # for this property using the passed in analysis protocol.
    if isinstance(analysis_protocol, analysis.ExtractAverageStatistic):

        analysis_protocol.statistics_path = ProtocolPath('statistics_file_path', unpack_stored_data.id)

    elif isinstance(analysis_protocol, analysis.AverageTrajectoryProperty):

        analysis_protocol.input_coordinate_file = ProtocolPath('coordinate_file_path', unpack_stored_data.id)
        analysis_protocol.trajectory_path = ProtocolPath('trajectory_file_path', unpack_stored_data.id)

    # Decorrelate the frames of the stored trajectory and statistics arrays.
    decorrelate_statistics = analysis.ExtractUncorrelatedStatisticsData('decorrelate_stats{}'.format(
        replicator_suffix))
    decorrelate_statistics.statistical_inefficiency = ProtocolPath('statistical_inefficiency', analysis_protocol.id)
    decorrelate_statistics.equilibration_index = ProtocolPath('equilibration_index', analysis_protocol.id)
    decorrelate_statistics.input_statistics_path = ProtocolPath('statistics_file_path', unpack_stored_data.id)

    decorrelate_trajectory = analysis.ExtractUncorrelatedTrajectoryData('decorrelate_traj{}'.format(
        replicator_suffix))
    decorrelate_trajectory.statistical_inefficiency = ProtocolPath('statistical_inefficiency', analysis_protocol.id)
    decorrelate_trajectory.equilibration_index = ProtocolPath('equilibration_index', analysis_protocol.id)
    decorrelate_trajectory.input_coordinate_file = ProtocolPath('coordinate_file_path', unpack_stored_data.id)
    decorrelate_trajectory.input_trajectory_path = ProtocolPath('trajectory_file_path', unpack_stored_data.id)

    # Stitch together all of the trajectories
    join_trajectories = reweighting.ConcatenateTrajectories('concat_traj' + id_suffix)
    join_trajectories.input_coordinate_paths = ProtocolPath('coordinate_file_path', unpack_stored_data.id)
    join_trajectories.input_trajectory_paths = ProtocolPath('output_trajectory_path', decorrelate_trajectory.id)

    join_statistics = reweighting.ConcatenateStatistics('concat_stats' + id_suffix)
    join_statistics.input_statistics_paths = ProtocolPath('output_statistics_path', decorrelate_statistics.id)

    # Calculate the reduced potentials for each of the reference states. Each
    # replicated system uses the force field / state of its own stored data set,
    # but is evaluated over the full concatenated trajectory.
    build_reference_system = forcefield.BuildSmirnoffSystem('build_system{}'.format(replicator_suffix))
    build_reference_system.force_field_path = ProtocolPath('force_field_path', unpack_stored_data.id)
    build_reference_system.substance = ProtocolPath('substance', unpack_stored_data.id)
    build_reference_system.coordinate_file_path = ProtocolPath('coordinate_file_path', unpack_stored_data.id)

    reduced_reference_potential = reweighting.CalculateReducedPotentialOpenMM('reduced_potential{}'.format(
        replicator_suffix))
    reduced_reference_potential.system_path = ProtocolPath('system_path', build_reference_system.id)
    reduced_reference_potential.thermodynamic_state = ProtocolPath('thermodynamic_state', unpack_stored_data.id)
    reduced_reference_potential.coordinate_file_path = ProtocolPath('coordinate_file_path', unpack_stored_data.id)
    reduced_reference_potential.trajectory_file_path = ProtocolPath('output_trajectory_path', join_trajectories.id)
    reduced_reference_potential.kinetic_energies_path = ProtocolPath('output_statistics_path', join_statistics.id)

    # Calculate the reduced potential of the target state.
    build_target_system = forcefield.BuildSmirnoffSystem('build_system_target' + id_suffix)
    build_target_system.force_field_path = ProtocolPath('force_field_path', 'global')
    build_target_system.substance = ProtocolPath('substance', 'global')
    build_target_system.coordinate_file_path = ProtocolPath('output_coordinate_path', join_trajectories.id)

    reduced_target_potential = reweighting.CalculateReducedPotentialOpenMM('reduced_potential_target' + id_suffix)
    reduced_target_potential.thermodynamic_state = ProtocolPath('thermodynamic_state', 'global')
    reduced_target_potential.system_path = ProtocolPath('system_path', build_target_system.id)
    reduced_target_potential.coordinate_file_path = ProtocolPath('output_coordinate_path', join_trajectories.id)
    reduced_target_potential.trajectory_file_path = ProtocolPath('output_trajectory_path', join_trajectories.id)
    reduced_target_potential.kinetic_energies_path = ProtocolPath('output_statistics_path', join_statistics.id)

    # Finally, apply MBAR to get the reweighted value.
    mbar_protocol.reference_reduced_potentials = ProtocolPath('statistics_file_path',
                                                              reduced_reference_potential.id)
    mbar_protocol.target_reduced_potentials = [ProtocolPath('statistics_file_path',
                                                            reduced_target_potential.id)]

    if (isinstance(mbar_protocol, reweighting.ReweightStatistics) and
            mbar_protocol.statistics_type != ObservableType.PotentialEnergy and
            mbar_protocol.statistics_type != ObservableType.TotalEnergy and
            mbar_protocol.statistics_type != ObservableType.Enthalpy and
            mbar_protocol.statistics_type != ObservableType.ReducedPotential):

        # Non-energy observables are read straight from the decorrelated
        # stored statistics.
        mbar_protocol.statistics_paths = ProtocolPath('output_statistics_path', decorrelate_statistics.id)

    elif isinstance(mbar_protocol, reweighting.ReweightStatistics):

        # Energy-like observables are taken from the single, concatenated
        # target state evaluation, so MBAR also needs to know how many frames
        # came from each stored data set.
        mbar_protocol.statistics_paths = [ProtocolPath('statistics_file_path', reduced_target_potential.id)]
        mbar_protocol.frame_counts = ProtocolPath('number_of_uncorrelated_samples', decorrelate_statistics.id)

    # TODO: Implement a cleaner way to handle this.
    if workflow_options.convergence_mode == WorkflowOptions.ConvergenceMode.NoChecks:
        mbar_protocol.required_effective_samples = -1

    base_protocols = BaseReweightingProtocols(unpack_stored_data,
                                              analysis_protocol,
                                              decorrelate_statistics,
                                              decorrelate_trajectory,
                                              join_trajectories,
                                              join_statistics,
                                              build_reference_system,
                                              reduced_reference_potential,
                                              build_target_system,
                                              reduced_target_potential,
                                              mbar_protocol)

    # Create the replicator object.
    component_replicator = ProtocolReplicator(replicator_id=replicator_id)
    component_replicator.template_values = ProtocolPath('full_system_data', 'global')

    return base_protocols, component_replicator
def generate_base_simulation_protocols(analysis_protocol, workflow_options, id_suffix='',
                                       conditional_group=None):
    """Constructs a set of protocols which, when combined in a workflow schema,
    may be executed to run a single simulation to estimate a particular
    property. The observable of interest to extract from the simulation is
    determined by the passed in `analysis_protocol`.

    The protocols returned will:

        1) Build a set of liquid coordinates for the property substance
           using packmol (with `max_molecules` set to 1000).

        2) Assign a set of smirnoff force field parameters to the system.

        3) Perform an energy minimisation on the system.

        4) Run a short NPT equilibration simulation for 100000 steps
           using a timestep of 2fs.

        5) Within a conditional group (up to a maximum of 100 times when
           the default group is used):

            5a) Run a longer NPT production simulation for 1000000 steps
                using a timestep of 2fs.

            5b) Extract the average value of an observable and its
                uncertainty.

            5c) If a convergence mode is set by the options, check if the
                target uncertainty has been met. If not, repeat steps 5a),
                5b) and 5c).

        6) Extract uncorrelated configurations from a generated production
           simulation.

        7) Extract uncorrelated statistics from a generated production
           simulation.

    Parameters
    ----------
    analysis_protocol: AveragePropertyProtocol
        The protocol which will extract the observable of interest from the
        generated simulation data. It must be either an
        `AverageTrajectoryProperty` or an `ExtractAverageStatistic` protocol.
    workflow_options: WorkflowOptions
        The options being used to generate a workflow.
    id_suffix: str
        A string suffix to append to each of the protocol ids.
    conditional_group: ProtocolGroup, optional
        A custom group to wrap the main simulation / extraction protocols
        within. It is up to the caller of this method to manually add the
        convergence conditions to this group. If `None`, a default group with
        uncertainty convergence conditions is automatically constructed.

    Returns
    -------
    BaseSimulationProtocols
        A named tuple of the generated protocols.
    ProtocolPath
        A reference to the final value of the estimated observable and its
        uncertainty (an `EstimatedQuantity`).
    WorkflowSimulationDataToStore
        An object which describes the default data from a simulation to
        store, such as the uncorrelated statistics and configurations.

    Raises
    ------
    ValueError
        If `analysis_protocol` is neither an `AverageTrajectoryProperty` nor
        an `ExtractAverageStatistic` protocol. This is checked before any
        protocol is built, so a caller supplied `conditional_group` is left
        unmodified.
    """
    assert isinstance(analysis_protocol, analysis.AveragePropertyProtocol)

    # Validate up front so that a caller supplied `conditional_group` is never
    # modified (via `add_protocols`) when the analysis protocol is unsupported.
    if not isinstance(analysis_protocol, (analysis.AverageTrajectoryProperty,
                                          analysis.ExtractAverageStatistic)):

        raise ValueError('The analysis protocol must inherit from either the '
                         'AverageTrajectoryProperty or ExtractAverageStatistic '
                         'protocols.')

    build_coordinates = coordinates.BuildCoordinatesPackmol(f'build_coordinates{id_suffix}')
    build_coordinates.substance = ProtocolPath('substance', 'global')
    build_coordinates.max_molecules = 1000

    assign_parameters = forcefield.BuildSmirnoffSystem(f'assign_parameters{id_suffix}')
    assign_parameters.force_field_path = ProtocolPath('force_field_path', 'global')
    assign_parameters.coordinate_file_path = ProtocolPath('coordinate_file_path', build_coordinates.id)
    assign_parameters.substance = ProtocolPath('substance', 'global')

    # Equilibration
    energy_minimisation = simulation.RunEnergyMinimisation(f'energy_minimisation{id_suffix}')
    energy_minimisation.input_coordinate_file = ProtocolPath('coordinate_file_path', build_coordinates.id)
    energy_minimisation.system_path = ProtocolPath('system_path', assign_parameters.id)

    equilibration_simulation = simulation.RunOpenMMSimulation(f'equilibration_simulation{id_suffix}')
    equilibration_simulation.ensemble = Ensemble.NPT
    equilibration_simulation.steps = 100000
    equilibration_simulation.output_frequency = 5000
    equilibration_simulation.timestep = 2.0 * unit.femtosecond
    equilibration_simulation.thermodynamic_state = ProtocolPath('thermodynamic_state', 'global')
    equilibration_simulation.input_coordinate_file = ProtocolPath('output_coordinate_file', energy_minimisation.id)
    equilibration_simulation.system_path = ProtocolPath('system_path', assign_parameters.id)

    # Production
    production_simulation = simulation.RunOpenMMSimulation(f'production_simulation{id_suffix}')
    production_simulation.ensemble = Ensemble.NPT
    production_simulation.steps = 1000000
    production_simulation.output_frequency = 3000
    production_simulation.timestep = 2.0 * unit.femtosecond
    production_simulation.thermodynamic_state = ProtocolPath('thermodynamic_state', 'global')
    production_simulation.input_coordinate_file = ProtocolPath('output_coordinate_file', equilibration_simulation.id)
    production_simulation.system_path = ProtocolPath('system_path', assign_parameters.id)

    # Set up a conditional group to ensure convergence of uncertainty
    if conditional_group is None:

        conditional_group = groups.ConditionalGroup(f'conditional_group{id_suffix}')
        conditional_group.max_iterations = 100

        if workflow_options.convergence_mode != WorkflowOptions.ConvergenceMode.NoChecks:

            # Keep iterating until the uncertainty of the analysed value drops
            # below the globally requested target uncertainty.
            condition = groups.ConditionalGroup.Condition()

            condition.left_hand_value = ProtocolPath('value.uncertainty',
                                                     conditional_group.id,
                                                     analysis_protocol.id)
            condition.right_hand_value = ProtocolPath('target_uncertainty', 'global')

            condition.condition_type = groups.ConditionalGroup.ConditionType.LessThan

            conditional_group.add_condition(condition)

    conditional_group.add_protocols(production_simulation, analysis_protocol)

    # Point the analyse protocol to the correct data source. The type was
    # validated at the top of this function, so one branch always applies.
    if isinstance(analysis_protocol, analysis.AverageTrajectoryProperty):

        analysis_protocol.input_coordinate_file = ProtocolPath('coordinate_file_path', build_coordinates.id)
        analysis_protocol.trajectory_path = ProtocolPath('trajectory_file_path', production_simulation.id)

    elif isinstance(analysis_protocol, analysis.ExtractAverageStatistic):

        analysis_protocol.statistics_path = ProtocolPath('statistics_file_path', production_simulation.id)

    # Finally, extract uncorrelated data
    statistical_inefficiency = ProtocolPath('statistical_inefficiency', conditional_group.id,
                                            analysis_protocol.id)
    equilibration_index = ProtocolPath('equilibration_index', conditional_group.id,
                                       analysis_protocol.id)
    coordinate_file = ProtocolPath('output_coordinate_file', conditional_group.id,
                                   production_simulation.id)
    trajectory_path = ProtocolPath('trajectory_file_path', conditional_group.id,
                                   production_simulation.id)
    statistics_path = ProtocolPath('statistics_file_path', conditional_group.id,
                                   production_simulation.id)

    extract_uncorrelated_trajectory = analysis.ExtractUncorrelatedTrajectoryData(f'extract_traj{id_suffix}')
    extract_uncorrelated_trajectory.statistical_inefficiency = statistical_inefficiency
    extract_uncorrelated_trajectory.equilibration_index = equilibration_index
    extract_uncorrelated_trajectory.input_coordinate_file = coordinate_file
    extract_uncorrelated_trajectory.input_trajectory_path = trajectory_path

    extract_uncorrelated_statistics = analysis.ExtractUncorrelatedStatisticsData(f'extract_stats{id_suffix}')
    extract_uncorrelated_statistics.statistical_inefficiency = statistical_inefficiency
    extract_uncorrelated_statistics.equilibration_index = equilibration_index
    extract_uncorrelated_statistics.input_statistics_path = statistics_path

    # Build the object which defines which pieces of simulation data to store.
    output_to_store = WorkflowSimulationDataToStore()

    output_to_store.total_number_of_molecules = ProtocolPath('final_number_of_molecules', build_coordinates.id)
    output_to_store.statistical_inefficiency = statistical_inefficiency
    output_to_store.statistics_file_path = ProtocolPath('output_statistics_path',
                                                        extract_uncorrelated_statistics.id)
    output_to_store.trajectory_file_path = ProtocolPath('output_trajectory_path',
                                                        extract_uncorrelated_trajectory.id)
    output_to_store.coordinate_file_path = coordinate_file

    # Define where the final values come from.
    final_value_source = ProtocolPath('value', conditional_group.id, analysis_protocol.id)

    base_protocols = BaseSimulationProtocols(build_coordinates,
                                             assign_parameters,
                                             energy_minimisation,
                                             equilibration_simulation,
                                             production_simulation,
                                             analysis_protocol,
                                             conditional_group,
                                             extract_uncorrelated_trajectory,
                                             extract_uncorrelated_statistics)

    return base_protocols, final_value_source, output_to_store
def generate_gradient_protocol_group(template_reweighting_protocol,
                                     reference_force_field_paths,
                                     target_force_field_path,
                                     coordinate_file_path,
                                     trajectory_file_path,
                                     statistics_file_path='',
                                     replicator_id='repl',
                                     perturbation_scale=1.0e-4,
                                     substance_source=None,
                                     id_suffix='',
                                     enable_pbc=True,
                                     use_subset_of_force_field=True,
                                     effective_sample_indices=None):
    """Builds a protocol group which estimates the gradient of an observable
    with respect to a single force field parameter by central finite
    differences: the observable is reweighted to a slightly decreased
    (reverse) and a slightly increased (forward) value of the parameter, and
    the two reweighted values are differenced.

    The group is intended to be replicated once per parameter of interest.

    Parameters
    ----------
    template_reweighting_protocol: BaseMBARProtocol
        A template protocol from which the forward and reverse reweighting
        protocols are cloned. Note that the template itself is modified in
        place: its `bootstrap_iterations` is set to 1, its
        `required_effective_samples` to 0, and its reference reduced
        potentials are pointed at the gradient reduced potentials protocol.
        If the template is a `ReweightStatistics` protocol whose observable
        is an energy (potential energy, reduced potential, total energy or
        enthalpy), the statistics of each clone are pointed at the energies
        evaluated with the perturbed parameter rather than those measured in
        the reference simulation.
    reference_force_field_paths: ProtocolPath or list of ProtocolPath
        The paths to the force field parameters which were used to generate
        the trajectories from which the observables of interest were
        calculated.
    target_force_field_path: ProtocolPath
        The path to the force field parameters at which the observables are
        being estimated (mainly useful when estimating the gradients of
        reweighted observables).
    coordinate_file_path: ProtocolPath
        A path to the initial coordinates of the simulation trajectory which
        was used to estimate the observable of interest.
    trajectory_file_path: ProtocolPath
        A path to the simulation trajectory which was used to estimate the
        observable of interest.
    statistics_file_path: ProtocolPath, optional
        A path to the statistics which were generated from the trajectory
        passed to `trajectory_file_path`. This is optional in cases where
        multiple reference force fields are passed to this method.
    replicator_id: str
        A unique id for the replicator which will copy this group for every
        parameter of interest.
    perturbation_scale: float
        The default amount to perturb parameters by.
    substance_source: PlaceholderInput, optional
        An optional protocol path to the substance whose gradient is being
        estimated. If None, the global property substance is used.
    id_suffix: str
        An optional string to append to the end of each of the protocol ids.
    enable_pbc: bool
        If true, periodic boundary conditions are employed when
        recalculating the reduced potentials.
    use_subset_of_force_field: bool
        If True, any reduced potentials will only be calculated from the
        subset of the force field which depends on the parameter of interest.
    effective_sample_indices: ProtocolPath, optional
        A placeholder which can be used to make the gradient protocols
        dependent on an MBAR protocol, so that gradients aren't calculated
        when that MBAR protocol failed due to insufficient samples. Defaults
        to an empty list.

    Returns
    -------
    ProtocolGroup
        The protocol group which will estimate the gradient of an observable
        with respect to one parameter.
    ProtocolReplicator
        The replicator which will copy the gradient group for every parameter
        of interest (taken from the global `parameter_gradient_keys` input).
    ProtocolPath
        A protocol path which points to the final gradient value.
    """
    assert isinstance(template_reweighting_protocol, reweighting.BaseMBARProtocol)

    # Every protocol in the group is replicated per parameter, so every id
    # carries the replicator placeholder.
    id_suffix = f'_$({replicator_id}){id_suffix}'

    # Resolve the optional arguments.
    if substance_source is None:
        substance_source = ProtocolPath('substance', 'global')

    if effective_sample_indices is None:
        effective_sample_indices = []

    # A single protocol evaluates the reduced potentials of the reference,
    # forward and reverse states (optionally from just the relevant subset of
    # the force field).
    reduced_potentials = gradients.GradientReducedPotentials(f'gradient_reduced_potentials{id_suffix}')
    reduced_potentials.substance = substance_source
    reduced_potentials.thermodynamic_state = ProtocolPath('thermodynamic_state', 'global')
    reduced_potentials.reference_force_field_paths = reference_force_field_paths
    reduced_potentials.reference_statistics_path = statistics_file_path
    reduced_potentials.force_field_path = target_force_field_path
    reduced_potentials.trajectory_file_path = trajectory_file_path
    reduced_potentials.coordinate_file_path = coordinate_file_path
    reduced_potentials.parameter_key = ReplicatorValue(replicator_id)
    reduced_potentials.perturbation_scale = perturbation_scale
    reduced_potentials.use_subset_of_force_field = use_subset_of_force_field
    reduced_potentials.enable_pbc = enable_pbc
    reduced_potentials.effective_sample_indices = effective_sample_indices

    # Configure the template before its schema is captured, so that both
    # clones inherit these settings.
    template_reweighting_protocol.bootstrap_iterations = 1
    template_reweighting_protocol.required_effective_samples = 0
    template_reweighting_protocol.reference_reduced_potentials = ProtocolPath('reference_potential_paths',
                                                                              reduced_potentials.id)

    # Energy-like observables depend on the parameter being perturbed, so they
    # must be taken from the perturbed state rather than the reference data.
    energy_observable_types = (ObservableType.PotentialEnergy,
                               ObservableType.ReducedPotential,
                               ObservableType.TotalEnergy,
                               ObservableType.Enthalpy)

    use_target_state_energies = (
        isinstance(template_reweighting_protocol, reweighting.ReweightStatistics) and
        template_reweighting_protocol.statistics_type in energy_observable_types
    )

    template_reweighting_schema = template_reweighting_protocol.schema

    def _clone_reweighting_protocol(protocol_id, potentials_attribute):
        """Creates an independent copy of the template which reweights to the
        state whose reduced potentials are exposed as `potentials_attribute`."""

        cloned_schema = copy.deepcopy(template_reweighting_schema)
        cloned_schema.id = protocol_id

        cloned_protocol = available_protocols[cloned_schema.type](cloned_schema.id)
        cloned_protocol.schema = cloned_schema

        cloned_protocol.target_reduced_potentials = [ProtocolPath(potentials_attribute, reduced_potentials.id)]

        if use_target_state_energies:
            cloned_protocol.statistics_paths = [ProtocolPath(potentials_attribute, reduced_potentials.id)]

        return cloned_protocol

    reverse_mbar = _clone_reweighting_protocol(f'reverse_reweight{id_suffix}', 'reverse_potentials_path')
    forward_mbar = _clone_reweighting_protocol(f'forward_reweight{id_suffix}', 'forward_potentials_path')

    # Combine the two reweighted values into a central difference gradient.
    central_difference = gradients.CentralDifferenceGradient(f'central_difference{id_suffix}')
    central_difference.parameter_key = ReplicatorValue(replicator_id)
    central_difference.reverse_observable_value = ProtocolPath('value', reverse_mbar.id)
    central_difference.forward_observable_value = ProtocolPath('value', forward_mbar.id)
    central_difference.reverse_parameter_value = ProtocolPath('reverse_parameter_value', reduced_potentials.id)
    central_difference.forward_parameter_value = ProtocolPath('forward_parameter_value', reduced_potentials.id)

    # Wrap everything in a single group so it can be replicated as a unit.
    gradient_group = groups.ProtocolGroup(f'gradient_group{id_suffix}')
    gradient_group.add_protocols(reduced_potentials, reverse_mbar, forward_mbar, central_difference)

    # One copy of the group is made per parameter gradient key.
    parameter_replicator = ProtocolReplicator(replicator_id=replicator_id)
    parameter_replicator.template_values = ProtocolPath('parameter_gradient_keys', 'global')

    gradient_source = ProtocolPath('gradient', gradient_group.id, central_difference.id)

    return gradient_group, parameter_replicator, gradient_source