# Source code for propertyestimator.workflow.protocols

"""
A collection of specialized workflow building blocks, which when chained together,
form a larger property estimation workflow.
"""

import copy

from propertyestimator.utils import graph, utils
from propertyestimator.utils.utils import get_nested_attribute, set_nested_attribute
from propertyestimator.workflow.decorators import (
    InequalityMergeBehaviour,
    MergeBehaviour,
    protocol_input,
    protocol_output,
)
from propertyestimator.workflow.schemas import ProtocolSchema
from propertyestimator.workflow.utils import PlaceholderInput, ProtocolPath


class BaseProtocol:
    """The base class for a protocol which would form one step of a larger
    property calculation workflow.

    A protocol may for example:

        * create the coordinates of a mixed simulation box
        * set up a bound ligand-protein system
        * build the simulation topology
        * perform an energy minimisation

    An individual protocol may require a set of inputs, which may either be
    set as constants

    >>> from propertyestimator.protocols.simulation import RunOpenMMSimulation
    >>>
    >>> npt_equilibration = RunOpenMMSimulation('npt_equilibration')
    >>> npt_equilibration.ensemble = RunOpenMMSimulation.Ensemble.NPT

    or from the output of another protocol, pointed to by a ProtocolPath

    >>> npt_production = RunOpenMMSimulation('npt_production')
    >>> # Use the coordinate file output by the npt_equilibration protocol
    >>> # as the input to the npt_production protocol
    >>> npt_production.input_coordinate_file = ProtocolPath('output_coordinate_file',
    >>>                                                     npt_equilibration.id)

    In this way protocols may be chained together, thus defining a larger
    property calculation workflow from simple, reusable building blocks.

    .. warning:: This class is still heavily under development and is subject
                 to rapid changes.
    """

    @property
    def id(self):
        """str: The unique id of this protocol."""
        return self._id

    @property
    def schema(self):
        """ProtocolSchema: A serializable schema for this object."""
        return self._get_schema()

    @schema.setter
    def schema(self, schema_value):
        self._set_schema(schema_value)

    @property
    def dependencies(self):
        """list of ProtocolPath: A list of pointers to the protocols which
        this protocol takes input from.
        """
        return_dependencies = []

        for input_path in self.required_inputs:

            # An input without references yields an empty dictionary, so no
            # explicit emptiness check is needed before iterating.
            for value_reference in self.get_value_references(input_path).values():

                if value_reference in return_dependencies:
                    continue

                # References without a protocol, or which point back at this
                # protocol, are not external dependencies.
                if (
                    value_reference.start_protocol is None
                    or value_reference.start_protocol == self.id
                ):
                    continue

                return_dependencies.append(value_reference)

        return return_dependencies

    allow_merging = protocol_input(
        docstring="Defines whether this protocols is allowed "
        "to merge with other protocols.",
        type_hint=bool,
        default_value=True,
    )

    def __init__(self, protocol_id):
        """Constructs a new protocol with the given id.

        Parameters
        ----------
        protocol_id: str
            The unique identifier to assign to this protocol.
        """
        # A unique identifier for this node.
        self._id = protocol_id

        # Defines whether a protocol is allowed to try and merge with other
        # identical ones.
        self._allow_merging = True

        # Created here so that the attributes always exist; `_initialize`
        # rebuilds both lists from the decorated class attributes.
        self.provided_outputs = []
        self.required_inputs = []

        self._initialize()

    def execute(self, directory, available_resources):
        """Execute the protocol.

        Protocols may be chained together by passing the output of previous
        protocols as input to the current one.

        This base implementation performs no work, and simply returns the
        current values of the protocol's outputs.

        Parameters
        ----------
        directory: str
            The directory to store output data in.
        available_resources: ComputeResources
            The resources available to execute on.

        Returns
        -------
        Dict[str, Any]
            The output of the execution.
        """
        return self._get_output_dictionary()

    def _initialize(self):
        """Initialize the protocol."""

        # Find the required inputs and outputs from the attributes of this
        # type which were declared with the corresponding decorators.
        output_attributes = utils.find_types_with_decorator(type(self), protocol_output)
        input_attributes = utils.find_types_with_decorator(type(self), protocol_input)

        self.provided_outputs = [ProtocolPath(name) for name in output_attributes]
        self.required_inputs = [ProtocolPath(name) for name in input_attributes]

        # The directory in which to execute the protocol.
        self.directory = None

    def _get_schema(self):
        """Returns this protocols properties (i.e id and parameters)
        as a ProtocolSchema

        Returns
        -------
        ProtocolSchema
            The schema representation.
        """
        schema = ProtocolSchema()

        schema.id = self.id
        schema.type = type(self).__name__

        for input_path in self.required_inputs:

            # Only serialize inputs which belong directly to this protocol.
            if not (
                input_path.start_protocol is None
                or (
                    input_path.start_protocol == self.id
                    and input_path.start_protocol == input_path.last_protocol
                )
            ):
                continue

            # Always make sure to only pass a copy of the input. Changing the
            # schema should NOT change the protocol.
            schema.inputs[input_path.full_path] = copy.deepcopy(
                self.get_value(input_path)
            )

        return schema

    def _set_schema(self, schema_value):
        """Sets this protocols properties (i.e id and parameters)
        from a ProtocolSchema

        Parameters
        ----------
        schema_value: ProtocolSchema
            The schema which will describe this protocol.

        Raises
        ------
        ValueError
            If the type named by the schema does not match the type of this
            protocol. Note that the id of this protocol has already been
            updated from the schema by the time this is raised.
        """
        self._id = schema_value.id

        if type(self).__name__ != schema_value.type:
            # Make sure this object is the correct type.
            raise ValueError(
                "Cannot convert a {} protocol to a {}.".format(
                    str(type(self)), schema_value.type
                )
            )

        for input_full_path, input_value in schema_value.inputs.items():

            # Copy the value so that later changes to the schema do not leak
            # into the protocol.
            value = copy.deepcopy(input_value)

            input_path = ProtocolPath.from_string(input_full_path)
            self.set_value(input_path, value)

    def _get_output_dictionary(self):
        """Builds a dictionary of the output property names and their values.

        Returns
        -------
        Dict[str, Any]
            A dictionary whose keys are the output property names, and the
            values their associated values.
        """
        return {
            output_path.full_path: self.get_value(output_path)
            for output_path in self.provided_outputs
        }

    def set_uuid(self, value):
        """Store the uuid of the calculation this protocol belongs to

        Parameters
        ----------
        value : str
            The uuid of the parent calculation.
        """
        # Nothing to do if the uuid already forms part of this protocol's id.
        if value in self.id:
            return

        self._id = graph.append_uuid(self.id, value)

        for input_path in self.required_inputs:

            input_path.append_uuid(value)

            for value_reference in self.get_value_references(input_path).values():
                value_reference.append_uuid(value)

        for output_path in self.provided_outputs:
            output_path.append_uuid(value)

    def replace_protocol(self, old_id, new_id):
        """Finds each input which came from a given protocol
        and redirects it to instead take input from a new one.

        Notes
        -----
        This method is mainly intended to be used only when merging
        multiple protocols into one.

        Parameters
        ----------
        old_id : str
            The id of the old input protocol.
        new_id : str
            The id of the new input protocol.
        """
        for input_path in self.required_inputs:

            input_path.replace_protocol(old_id, new_id)

            # NOTE(review): as written, value references are only updated for
            # inputs whose path has no start protocol. This differs from the
            # filter used by `_get_schema` (which also accepts paths that
            # start at this protocol) - confirm this is intentional.
            if input_path.start_protocol is not None or (
                input_path.start_protocol != input_path.last_protocol
                and input_path.start_protocol != self.id
            ):
                continue

            for value_reference in self.get_value_references(input_path).values():
                value_reference.replace_protocol(old_id, new_id)

        for output_path in self.provided_outputs:
            output_path.replace_protocol(old_id, new_id)

        if self._id == old_id:
            self._id = new_id

    def _find_inputs_to_merge(self):
        """Returns a list of those inputs which should be considered when
        attempting to merge two different protocols of the same type.

        Returns
        -------
        set of ProtocolPath
            References to those inputs which should be considered.
        """
        inputs_to_consider = set()

        for input_path in self.required_inputs:

            # Do not consider paths that point to child (e.g grouped)
            # protocols. These should be handled by the container classes
            # themselves.
            if (
                input_path.start_protocol is not None
                and input_path.start_protocol != self.id
            ):
                continue

            if not (
                input_path.start_protocol is None
                or (
                    input_path.start_protocol == input_path.last_protocol
                    and input_path.start_protocol == self.id
                )
            ):
                continue

            # If no merge behaviour flag is present (for example in the case
            # of ConditionalGroup conditions), simply assume this is handled
            # explicitly elsewhere.
            if not hasattr(
                getattr(type(self), input_path.property_name), "merge_behavior"
            ):
                continue

            inputs_to_consider.add(input_path)

        return inputs_to_consider

    def can_merge(self, other, path_replacements=None):
        """Determines whether this protocol can be merged with another.

        Parameters
        ----------
        other : :obj:`BaseProtocol`
            The protocol to compare against.
        path_replacements: list of tuple of str, optional
            Replacements to make in any value reference protocol paths
            before comparing for equality.

        Returns
        -------
        bool
            True if the two protocols are safe to merge.
        """
        if not self.allow_merging or not isinstance(self, type(other)):
            return False

        if path_replacements is None:
            path_replacements = []

        for input_path in self._find_inputs_to_merge():

            # Do a quick sanity check that the other protocol
            # does in fact also require this input.
            if input_path not in other.required_inputs:
                return False

            merge_behavior = getattr(
                type(self), input_path.property_name
            ).merge_behavior

            self_value = self.get_value(input_path)
            other_value = other.get_value(input_path)

            self_is_placeholder = isinstance(self_value, PlaceholderInput)
            other_is_placeholder = isinstance(other_value, PlaceholderInput)

            if self_is_placeholder != other_is_placeholder:
                # We cannot safely merge inputs when only one of the values
                # is currently known.
                return False

            if isinstance(self_value, ProtocolPath) and isinstance(
                other_value, ProtocolPath
            ):
                # Work on a copy so that the other protocol's own reference
                # is left untouched by the replacements.
                other_value_post_merge = ProtocolPath.from_string(other_value.full_path)

                for original_id, new_id in path_replacements:
                    other_value_post_merge.replace_protocol(original_id, new_id)

                # We cannot safely choose which value to take when the
                # values are not known ahead of time unless the two values
                # come from the exact same source.
                # NOTE(review): only `protocol_path` is compared here, not
                # `full_path` - confirm whether the referenced property name
                # should also be required to match.
                if self_value.protocol_path != other_value_post_merge.protocol_path:
                    return False

            elif self_is_placeholder and other_is_placeholder:
                # Neither value is known yet, so they cannot be compared.
                return False

            elif (
                merge_behavior == MergeBehaviour.ExactlyEqual
                and self_value != other_value
            ):
                return False

        return True

    def merge(self, other):
        """Merges another BaseProtocol with this one. The id
        of this protocol will remain unchanged.

        Whether the protocols are compatible is re-checked with `can_merge`
        (without any path replacements) before anything is changed.

        Parameters
        ----------
        other: BaseProtocol
            The protocol to merge into this one.

        Returns
        -------
        Dict[str, str]
            A map between any original protocol ids and their new merged
            values. This base implementation always returns an empty map.

        Raises
        ------
        ValueError
            If `can_merge` reports that the protocols cannot be merged.
        NotImplementedError
            If an input has a merge behaviour which is not handled here.
        """
        if not self.can_merge(other):
            raise ValueError("These protocols can not be safely merged.")

        for input_path in self._find_inputs_to_merge():

            merge_behavior = getattr(
                type(self), input_path.property_name
            ).merge_behavior

            # Inputs which must be exactly equal need no reconciliation.
            if merge_behavior == MergeBehaviour.ExactlyEqual:
                continue

            self_value = self.get_value(input_path)
            other_value = other.get_value(input_path)

            # Values which are references to other protocols are left as is.
            if isinstance(self_value, ProtocolPath) or isinstance(
                other_value, ProtocolPath
            ):
                continue

            if merge_behavior == InequalityMergeBehaviour.SmallestValue:
                value = min(self_value, other_value)
            elif merge_behavior == InequalityMergeBehaviour.LargestValue:
                value = max(self_value, other_value)
            else:
                raise NotImplementedError()

            self.set_value(input_path, value)

        return {}

    def get_value_references(self, input_path):
        """Returns a dictionary of references to the protocols which one of
        this protocols inputs (specified by `input_path`) takes its value
        from.

        Notes
        -----
        Currently this method only functions correctly for an input value
        which is either currently a :obj:`ProtocolPath`, or a `list` / `tuple`
        / `dict` which contains at least one :obj:`ProtocolPath`.

        Parameters
        ----------
        input_path: :obj:`propertyestimator.workflow.utils.ProtocolPath`
            The input value to check.

        Returns
        -------
        dict of ProtocolPath and ProtocolPath
            A dictionary of the protocol paths that the input targeted by
            `input_path` depends upon. Keys are the path of the input (with
            an index / key suffix for container inputs), and values the
            referenced paths.
        """
        input_value = self.get_value(input_path)

        if isinstance(input_value, ProtocolPath):
            return {input_path: input_value}

        if not isinstance(input_value, (list, tuple, dict)):
            return {}

        property_name, protocols_ids = ProtocolPath.to_components(input_path.full_path)

        # Pair each entry with the index (sequences) or key (dictionaries)
        # used to address it within the container.
        if isinstance(input_value, dict):
            keyed_values = ((key, input_value[key]) for key in input_value)
        else:
            keyed_values = enumerate(input_value)

        return_paths = {}

        for key, value in keyed_values:

            if not isinstance(value, ProtocolPath):
                continue

            path_index = ProtocolPath(
                property_name + "[{}]".format(key), *protocols_ids
            )
            return_paths[path_index] = value

        return return_paths

    def get_class_attribute(self, reference_path):
        """Returns one of this protocols, or any of its children's,
        attributes directly (rather than its value).

        Parameters
        ----------
        reference_path: ProtocolPath
            The path pointing to the attribute to return.

        Returns
        -------
        object:
            The class attribute.

        Raises
        ------
        ValueError
            If the path does not point to this protocol, or if it targets a
            nested or indexed property.
        """
        if (
            reference_path.start_protocol is not None
            and reference_path.start_protocol != self.id
        ):
            raise ValueError(
                "The reference path {} does not point to this protocol".format(
                    reference_path
                )
            )

        # `str.find(...) > 0` would miss a bracket in the first position, so
        # use a plain containment check to reject every indexed name.
        if (
            ProtocolPath.property_separator in reference_path.property_name
            or "[" in reference_path.property_name
        ):
            raise ValueError(
                "The expected attribute cannot be found for "
                "nested property names: {}".format(reference_path.property_name)
            )

        return getattr(type(self), reference_path.property_name)

    def get_value(self, reference_path):
        """Returns the value of one of this protocols inputs / outputs.

        Parameters
        ----------
        reference_path: ProtocolPath
            The path pointing to the value to return.

        Returns
        -------
        Any:
            The value of the input / output

        Raises
        ------
        ValueError
            If the path targets a different protocol, or does not name a
            property.
        """
        if (
            reference_path.start_protocol is not None
            and reference_path.start_protocol != self.id
        ):
            raise ValueError("The reference path does not target this protocol.")

        if reference_path.property_name is None or reference_path.property_name == "":
            raise ValueError(
                "The reference path does not specify a property to return."
            )

        return get_nested_attribute(self, reference_path.property_name)

    def set_value(self, reference_path, value):
        """Sets the value of one of this protocols inputs.

        Parameters
        ----------
        reference_path: ProtocolPath
            The path pointing to the value to set.
        value: Any
            The value to set.

        Raises
        ------
        ValueError
            If the path targets a different protocol, does not name a
            property, or points to one of this protocol's outputs.
        """
        if (
            reference_path.start_protocol is not None
            and reference_path.start_protocol != self.id
        ):
            raise ValueError("The reference path does not target this protocol.")

        if reference_path.property_name is None or reference_path.property_name == "":
            raise ValueError("The reference path does not specify a property to set.")

        if reference_path in self.provided_outputs:
            raise ValueError("Output values cannot be set by this method.")

        set_nested_attribute(self, reference_path.property_name, value)

    def apply_replicator(
        self,
        replicator,
        template_values,
        template_index=-1,
        template_value=None,
        update_input_references=False,
    ):
        """Applies a `ProtocolReplicator` to this protocol. This method
        should clone any protocols whose id contains the id of the
        replicator (in the format `$(replicator.id)`).

        This base implementation replicates nothing and returns an empty
        dictionary.

        Parameters
        ----------
        replicator: ProtocolReplicator
            The replicator to apply.
        template_values: list of Any
            A list of the values which will be inserted into the newly
            replicated protocols.

            This parameter is mutually exclusive with `template_index` and
            `template_value`
        template_index: int, optional
            A specific index which should be used for any protocols flagged
            as to be replicated by the replicator. This option is mainly
            used when replicating children of an already replicated
            protocol.

            This parameter is mutually exclusive with `template_values` and
            must be set along with a `template_value`.
        template_value: Any, optional
            A specific value which should be used for any protocols flagged
            as to be replicated by the replicator. This option is mainly
            used when replicating children of an already replicated
            protocol.

            This parameter is mutually exclusive with `template_values` and
            must be set along with a `template_index`.
        update_input_references: bool
            If true, any protocols which take their input from a protocol
            which was flagged for replication will be updated to take input
            from the actually replicated protocol. This should only be set
            to true if this protocol is not nested within a workflow or a
            protocol group.

            This option cannot be used when a specific `template_index` or
            `template_value` is provided.

        Returns
        -------
        dict of ProtocolPath and list of tuple of ProtocolPath and int
            A dictionary of references to all of the protocols which have
            been replicated, with keys of original protocol ids. Each value
            is comprised of a list of the replicated protocol ids, and their
            index into the `template_values` array.
        """
        return {}