Source code for openff.evaluator.layers.reweighting

"""This module implements a `CalculationLayer` which attempts to
'reweight' cached simulation data to evaluate the values of properties
at states which have not previously been simulated directly, but where
simulations at similar states have been run.
"""

import copy
import os
import time

import dateutil.parser
import numpy
from openff.units import unit

from openff.evaluator.attributes import Attribute, PlaceholderValue
from openff.evaluator.datasets import PropertyPhase
from openff.evaluator.layers import calculation_layer
from openff.evaluator.layers.workflow import (
    WorkflowCalculationLayer,
    WorkflowCalculationSchema,
)
from openff.evaluator.storage.query import SimulationDataQuery


def default_storage_query():
    """Return the default query to use when retrieving cached simulation
    data from the storage backend.

    Currently this query will search for data for the full substance of
    interest in the liquid phase.

    Returns
    -------
    dict of str and SimulationDataQuery
        A single query with a key of `"full_system_data"`.
    """

    query = SimulationDataQuery()
    query.substance = PlaceholderValue()
    query.property_phase = PropertyPhase.Liquid

    return {"full_system_data": query}
class ReweightingSchema(WorkflowCalculationSchema):
    """A schema which encodes the options and the workflow schema
    that the `ReweightingLayer` should use when estimating a given class
    of physical properties using the built-in workflow framework.
    """

    storage_queries = Attribute(
        docstring="The queries to perform when retrieving data for each "
        "of the components in the system from the storage backend. The "
        "keys of this dictionary will correspond to the metadata keys made "
        "available to the workflow system.",
        type_hint=dict,
        default_value=default_storage_query(),
    )

    maximum_data_points = Attribute(
        docstring="The maximum number of data points to include "
        "as part of the multi-state reweighting calculations. If "
        "zero, no cap will be applied.",
        type_hint=int,
        default_value=4,
    )

    temperature_cutoff = Attribute(
        docstring="The maximum difference between the target temperature "
        "and the temperature at which the cached data was collected. Data "
        "collected at temperatures outside of this cutoff will be ignored.",
        type_hint=unit.Quantity,
        default_value=5.0 * unit.kelvin,
    )
    def validate(self, attribute_type=None):
        super(ReweightingSchema, self).validate(attribute_type)

        assert len(self.storage_queries) > 0
        assert self.maximum_data_points > 0

        assert all(
            isinstance(x, SimulationDataQuery) for x in self.storage_queries.values()
        )
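# An illustrative sketch (not part of the module) of how the schema above
# might be configured. The "liquid_data" key is an arbitrary example name
# which would become the metadata key exposed to the workflow system, and the
# attribute values shown are assumptions rather than recommended settings:
#
#     custom_query = SimulationDataQuery()
#     custom_query.substance = PlaceholderValue()
#     custom_query.property_phase = PropertyPhase.Liquid
#
#     schema = ReweightingSchema()
#     schema.storage_queries = {"liquid_data": custom_query}
#     schema.maximum_data_points = 6
#     schema.temperature_cutoff = 10.0 * unit.kelvin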
@calculation_layer()
class ReweightingLayer(WorkflowCalculationLayer):
    """A `CalculationLayer` which attempts to 'reweight' cached simulation
    data to evaluate the values of properties at states which have not previously
    been simulated directly, but where simulations at similar states have been run
    previously.
    """
    @classmethod
    def required_schema_type(cls):
        return ReweightingSchema
    @staticmethod
    def _rank_cached_data(data_list, target_temperature, temperature_cutoff):
        """Sorts the data retrieved from a storage backend based upon
        the likelihood that the data will contribute significantly to
        the reweighting. Currently we naively just prefer newer data
        over older data.

        Parameters
        ----------
        data_list: list of tuple of str, StoredSimulationData and str
            A list of query results which take the form
            (storage_key, data_object, data_directory_path).
        target_temperature: openff.evaluator.unit.Quantity
            The temperature that the data will be reweighted to.
        temperature_cutoff: openff.evaluator.unit.Quantity
            The maximum difference in the target and data temperatures.

        Returns
        -------
        list of tuple of str, StoredSimulationData and str
            The ranked data.
        """
        sorted_list = []
        times_created = []

        # First remove any data measured outside of the allowed temperature range.
        for data_tuple in data_list:
            _, data_object, data_directory = data_tuple

            if (
                numpy.abs(
                    data_object.thermodynamic_state.temperature - target_temperature
                )
                > temperature_cutoff
            ):
                continue

            sorted_list.append(data_tuple)

            # Roughly determine the time at which this data was created.
            time_created = dateutil.parser.parse(
                time.ctime(os.path.getctime(data_directory))
            )
            times_created.append(time_created)

        return [
            data
            for _, data in reversed(
                sorted(zip(times_created, sorted_list), key=lambda pair: pair[0])
            )
        ]

    @staticmethod
    def _get_workflow_metadata(
        working_directory,
        physical_property,
        force_field_path,
        parameter_gradient_keys,
        storage_backend,
        calculation_schema,
    ):
        """
        Parameters
        ----------
        calculation_schema: ReweightingSchema
        """
        global_metadata = WorkflowCalculationLayer._get_workflow_metadata(
            working_directory,
            physical_property,
            force_field_path,
            parameter_gradient_keys,
            storage_backend,
            calculation_schema,
        )

        template_queries = calculation_schema.storage_queries

        # Apply the storage queries
        required_force_field_keys = set()

        for key in template_queries:
            query = copy.deepcopy(template_queries[key])

            # Fill in any placeholder values.
            if isinstance(query.substance, PlaceholderValue):
                query.substance = physical_property.substance

            # Apply the query.
            query_results = storage_backend.query(query)

            if len(query_results) == 0:
                # We haven't found any cached data which is compatible
                # with this property.
                return None

            # Save a local copy of the data object file.
            stored_data_tuples = []

            for query_list in query_results.values():
                query_list = ReweightingLayer._rank_cached_data(
                    query_list,
                    physical_property.thermodynamic_state.temperature,
                    calculation_schema.temperature_cutoff,
                )

                if calculation_schema.maximum_data_points > 0:
                    query_list = query_list[0 : calculation_schema.maximum_data_points]

                if len(query_list) == 0:
                    # Make sure we still have data after the cutoff check.
                    return None

                query_data_tuples = []

                for storage_key, data_object, data_directory in query_list:
                    object_path = os.path.join(working_directory, f"{storage_key}")

                    force_field_path = os.path.join(
                        working_directory, f"{data_object.force_field_id}"
                    )

                    # Save a local copy of the data object file.
                    if not os.path.isfile(object_path):
                        data_object.json(object_path)

                    required_force_field_keys.add(data_object.force_field_id)

                    query_data_tuples.append(
                        (object_path, data_directory, force_field_path)
                    )

                stored_data_tuples.append(query_data_tuples)

            # Add the results to the metadata.
            if len(stored_data_tuples) == 1:
                stored_data_tuples = stored_data_tuples[0]

            global_metadata[key] = stored_data_tuples

        # Make a local copy of the required force fields
        for force_field_id in required_force_field_keys:
            force_field_path = os.path.join(working_directory, force_field_id)

            if not os.path.isfile(force_field_path):
                existing_force_field = storage_backend.retrieve_force_field(
                    force_field_id
                )
                existing_force_field.json(force_field_path)

        return global_metadata
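# For reference, a rough sketch of the metadata entry that the code above
# builds for the default "full_system_data" query: each matched data set
# contributes a tuple of (data object path, data directory, force field path),
# and the outer list is flattened when the query returns a single list of
# results. The paths shown are placeholders, not literal values:
#
#     global_metadata["full_system_data"] = [
#         (object_path, data_directory, force_field_path),
#         ...,
#     ]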