Coverage for /builds/debichem-team/python-ase/ase/calculators/kim/kimpy_wrappers.py: 75.81%
339 statements
« prev ^ index » next coverage.py v7.5.3, created at 2025-03-06 04:00 +0000
« prev ^ index » next coverage.py v7.5.3, created at 2025-03-06 04:00 +0000
1"""
2Wrappers that provide a minimal interface to kimpy methods and objects
4Daniel S. Karls
5University of Minnesota
6"""
8import functools
9from abc import ABC
11import numpy as np
13from .exceptions import (
14 KIMModelInitializationError,
15 KIMModelNotFound,
16 KIMModelParameterError,
17 KimpyError,
18)
class LazyKimpyImport:
    """Stand-in for the optional ``kimpy`` package.

    The real module is imported only the first time an attribute is
    requested from an instance, so importing this module never requires
    kimpy to be installed.
    """

    @functools.cached_property
    def _kimpy(self):
        # Deferred import; cached_property memoizes the module object on the
        # instance so the import statement runs at most once per instance.
        import kimpy as kimpy_module
        return kimpy_module

    def __getattr__(self, attr):
        # Python calls __getattr__ only after regular attribute lookup has
        # failed, so every unknown name is forwarded to the kimpy module.
        kimpy_module = self._kimpy
        return getattr(kimpy_module, attr)
class Wrappers:
    """Shortcuts written in a way that avoids module-level kimpy import.

    Every accessor is a property, so the module-level ``kimpy`` proxy is
    only consulted when an accessor is actually used. Function accessors
    are bound to ``check_call`` so that a RuntimeError raised by kimpy
    surfaces as KimpyError.
    """

    @staticmethod
    def _checked(kimpy_function):
        # Pre-bind check_call to the given kimpy function.
        return functools.partial(check_call, kimpy_function)

    @property
    def collections_create(self):
        return self._checked(kimpy.collections.create)

    @property
    def model_create(self):
        return self._checked(kimpy.model.create)

    @property
    def simulator_model_create(self):
        return self._checked(kimpy.simulator_model.create)

    @property
    def get_species_name(self):
        return self._checked(kimpy.species_name.get_species_name)

    @property
    def get_number_of_species_names(self):
        return self._checked(
            kimpy.species_name.get_number_of_species_names)

    @property
    def collection_item_type_portableModel(self):
        # A plain kimpy value rather than a function, so it is returned
        # as-is without check_call wrapping.
        return kimpy.collection_item_type.portableModel
# Module-level singletons. Wrappers defines no __init__ and only reads
# ``kimpy`` inside its property getters. LazyKimpyImport defers the real
# import until first attribute access. Creating both here therefore does
# not import kimpy.
wrappers = Wrappers()
kimpy = LazyKimpyImport()
# Function used for casting parameter/extent indices to C-compatible ints
c_int = np.intc

# Function used for casting floating point parameter values to C-compatible
# doubles
c_double = np.double


def c_int_args(func):
    """
    Decorator for instance methods that casts every positional argument
    except the first (``self``) to a C-compatible integer (``c_int``)
    before calling the wrapped method.

    Keyword arguments are passed through unchanged.
    """

    @functools.wraps(func)
    def myfunc(*args, **kwargs):
        # Keep ``self`` untouched; cast the remaining positional arguments.
        args_cast = [args[0], *map(c_int, args[1:])]
        # Forward the *cast* arguments so the wrapped method actually
        # receives C-compatible ints.
        return func(*args_cast, **kwargs)

    return myfunc
def check_call(f, *args, **kwargs):
    """Call a kimpy function and translate its failures into KimpyError.

    Parameters
    ----------
    f : callable
        kimpy function (or bound method) to invoke.
    *args, **kwargs
        Forwarded unchanged to ``f``.

    Returns
    -------
    object
        Whatever ``f`` returns.

    Raises
    ------
    KimpyError
        If ``f`` raises a RuntimeError. (Starting with kimpy 2.0.0, a
        RuntimeError is the only exception type raised when something goes
        wrong.) The RuntimeError is chained as the cause.
    """
    try:
        return f(*args, **kwargs)
    except RuntimeError as e:
        # Not every callable has __name__ (e.g. functools.partial objects);
        # fall back to repr so building the message cannot itself fail.
        name = getattr(f, "__name__", repr(f))
        raise KimpyError(
            f'Calling kimpy function "{name}" failed:\n {e!s}') from e
def check_call_wrapper(func):
    """Decorator form of ``check_call``.

    The returned callable invokes ``func`` through ``check_call``, so a
    RuntimeError raised by ``func`` is re-raised as KimpyError. Function
    metadata (``__name__``, ``__doc__``, ...) is copied from ``func``.
    """

    def guarded(*args, **kwargs):
        return check_call(func, *args, **kwargs)

    return functools.wraps(func)(guarded)
class ModelCollections:
    """
    KIM Portable Models and Simulator Models are installed/managed into
    different "collections". In order to search through the different
    KIM API model collections on the system, a corresponding object must
    be instantiated. For more on model collections, see the KIM API's
    install file:
    https://github.com/openkim/kim-api/blob/master/INSTALL
    """

    def __init__(self):
        # wrappers.collections_create is bound to check_call, so kimpy
        # failures here surface as KimpyError.
        self.collection = wrappers.collections_create()

    def __enter__(self):
        return self

    def __exit__(self, exc_type, value, traceback):
        # Nothing to clean up; leaving the context is a no-op.
        pass

    def get_item_type(self, model_name):
        """Return the collection item type kimpy reports for ``model_name``.

        Raises
        ------
        KIMModelNotFound
            If kimpy fails to look the model up in the installed KIM API
            model collections. The underlying KimpyError is chained as the
            cause.
        """
        try:
            model_type = check_call(self.collection.get_item_type, model_name)
        except KimpyError as e:
            msg = (
                "Could not find model {} installed in any of the KIM API "
                "model collections on this system. See "
                "https://openkim.org/doc/usage/obtaining-models/ for "
                "instructions on installing models.".format(model_name)
            )
            raise KIMModelNotFound(msg) from e

        return model_type

    @property
    def initialized(self):
        """True if the ``collection`` attribute has been set (by __init__)."""
        return hasattr(self, "collection")
class PortableModel:
    """Creates a KIM API Portable Model object and provides a minimal
    interface to it"""

    def __init__(self, model_name, debug):
        """
        Parameters
        ----------
        model_name
            Name of the KIM Portable Model passed to kimpy's model factory.
        debug : bool
            If true, print the units in use after creation. The flag is
            also passed on to ComputeArguments objects created by
            ``compute_arguments_create``.

        Raises
        ------
        KIMModelInitializationError
            If the model does not accept the requested units.
        """
        self.model_name = model_name
        self.debug = debug

        # Create KIM API Model object with zero-based numbering and
        # Angstrom / eV / e / K / ps units.
        units_accepted, self.kim_model = wrappers.model_create(
            kimpy.numbering.zeroBased,
            kimpy.length_unit.A,
            kimpy.energy_unit.eV,
            kimpy.charge_unit.e,
            kimpy.temperature_unit.K,
            kimpy.time_unit.ps,
            self.model_name,
        )

        if not units_accepted:
            raise KIMModelInitializationError(
                "Requested units not accepted in kimpy.model.create"
            )

        if self.debug:
            l_unit, e_unit, c_unit, te_unit, ti_unit = check_call(
                self.kim_model.get_units
            )
            print(f"Length unit is: {l_unit}")
            print(f"Energy unit is: {e_unit}")
            print(f"Charge unit is: {c_unit}")
            print(f"Temperature unit is: {te_unit}")
            print(f"Time unit is: {ti_unit}")
            print()

        self._create_parameters()

    def __enter__(self):
        return self

    def __exit__(self, exc_type, value, traceback):
        pass

    @check_call_wrapper
    def _get_number_of_parameters(self):
        return self.kim_model.get_number_of_parameters()

    def _create_parameters(self):
        """Populate ``self._parameters``: one KIMModelParameter subclass
        instance per model parameter, keyed by parameter name."""

        def _kim_model_parameter(**kwargs):
            # Pick the concrete parameter class from the (repr'd) dtype.
            dtype = kwargs["dtype"]

            if dtype == "Integer":
                return KIMModelParameterInteger(**kwargs)
            elif dtype == "Double":
                return KIMModelParameterDouble(**kwargs)
            else:
                raise KIMModelParameterError(
                    f"Invalid model parameter type {dtype}. Supported types "
                    "'Integer' and 'Double'."
                )

        self._parameters = {}
        num_params = self._get_number_of_parameters()
        for index_param in range(num_params):
            parameter_metadata = self._get_one_parameter_metadata(index_param)
            name = parameter_metadata["name"]

            self._parameters[name] = _kim_model_parameter(
                kim_model=self.kim_model,
                dtype=parameter_metadata["dtype"],
                extent=parameter_metadata["extent"],
                name=name,
                description=parameter_metadata["description"],
                parameter_index=index_param,
            )

    def get_model_supported_species_and_codes(self):
        """Get all of the supported species for this model and their
        corresponding integer codes that are defined in the KIM API

        Returns
        -------
        species : list of str
            Abbreviated chemical symbols of all species the model
            supports (e.g. ["Mo", "S"])

        codes : list of int
            Integer codes used by the model for each species (order
            corresponds to the order of ``species``)
        """
        species = []
        codes = []
        num_kim_species = wrappers.get_number_of_species_names()

        # Walk every species known to the KIM API and keep those this model
        # reports as supported.
        for i in range(num_kim_species):
            species_name = wrappers.get_species_name(i)

            species_is_supported, code = self.get_species_support_and_code(
                species_name)

            if species_is_supported:
                species.append(str(species_name))
                codes.append(code)

        return species, codes

    @check_call_wrapper
    def clear_then_refresh(self):
        self.kim_model.clear_then_refresh()

    @c_int_args
    def _get_parameter_metadata(self, index_parameter):
        """Return ``(dtype, extent, name, description)`` for one parameter.

        Raises
        ------
        KIMModelParameterError
            If kimpy fails to return the metadata.
        """
        try:
            dtype, extent, name, description = check_call(
                self.kim_model.get_parameter_metadata, index_parameter
            )
        except KimpyError as e:
            raise KIMModelParameterError(
                "Failed to retrieve metadata for "
                f"parameter at index {index_parameter}"
            ) from e

        return dtype, extent, name, description

    def parameters_metadata(self):
        """Metadata associated with all model parameters.

        Returns
        -------
        dict
            Metadata associated with all model parameters.
        """
        return {
            param_name: param.metadata
            for param_name, param in self._parameters.items()
        }

    def parameter_names(self):
        """Names of model parameters registered in the KIM API.

        Returns
        -------
        tuple
            Names of model parameters registered in the KIM API
        """
        return tuple(self._parameters.keys())

    def get_parameters(self, **kwargs):
        """
        Get the values of one or more model parameter arrays.

        Given the names of one or more model parameters and a set of indices
        for each of them, retrieve the corresponding elements of the relevant
        model parameter arrays.

        Parameters
        ----------
        **kwargs
            Names of the model parameters and the indices whose values should
            be retrieved.

        Returns
        -------
        dict
            The requested indices and the values of the model's parameters.

        Note
        ----
        The output of this method can be used as input of
        ``set_parameters``.

        Example
        -------
        To get `epsilons` and `sigmas` in the LJ universal model for Mo-Mo
        (index 4879), Mo-S (index 2006) and S-S (index 1980) interactions::

            >>> LJ = 'LJ_ElliottAkerson_2015_Universal__MO_959249795837_003'
            >>> calc = KIM(LJ)
            >>> calc.get_parameters(epsilons=[4879, 2006, 1980],
            ...                     sigmas=[4879, 2006, 1980])
            {'epsilons': [[4879, 2006, 1980],
                          [4.47499, 4.421814057295943, 4.36927]],
             'sigmas': [[4879, 2006, 1980],
                        [2.74397, 2.30743, 1.87089]]}
        """
        parameters = {}
        for parameter_name, index_range in kwargs.items():
            parameters.update(
                self._get_one_parameter(
                    parameter_name,
                    index_range))
        return parameters

    def set_parameters(self, **kwargs):
        """
        Set the values of one or more model parameter arrays.

        Given the names of one or more model parameters and a set of indices
        and corresponding values for each of them, mutate the corresponding
        elements of the relevant model parameter arrays.

        Parameters
        ----------
        **kwargs
            Names of the model parameters to mutate and the corresponding
            indices and values to set.

        Returns
        -------
        dict
            The requested indices and the values of the model's parameters
            that were set.

        Example
        -------
        To set `epsilons` in the LJ universal model for Mo-Mo (index 4879),
        Mo-S (index 2006) and S-S (index 1980) interactions to 5.0, 4.5, and
        4.0, respectively::

            >>> LJ = 'LJ_ElliottAkerson_2015_Universal__MO_959249795837_003'
            >>> calc = KIM(LJ)
            >>> calc.set_parameters(epsilons=[[4879, 2006, 1980],
            ...                               [5.0, 4.5, 4.0]])
            {'epsilons': [[4879, 2006, 1980],
                          [5.0, 4.5, 4.0]]}
        """
        parameters = {}
        for parameter_name, parameter_data in kwargs.items():
            index_range, values = parameter_data
            self._set_one_parameter(parameter_name, index_range, values)
            parameters[parameter_name] = parameter_data

        return parameters

    def _get_parameter_object(self, parameter_name):
        """Return the KIMModelParameter registered under ``parameter_name``.

        Raises
        ------
        KIMModelParameterError
            If this model has no parameter with that name.
        """
        if parameter_name not in self._parameters:
            raise KIMModelParameterError(
                f"Parameter '{parameter_name}' is not "
                "supported by this model. "
                "Please check that the parameter name is spelled correctly."
            )

        return self._parameters[parameter_name]

    def _get_one_parameter(self, parameter_name, index_range):
        """
        Retrieve value of one or more components of a model parameter array.

        Parameters
        ----------
        parameter_name : str
            Name of model parameter registered in the KIM API.
        index_range : int or list
            Zero-based index (int) or indices (list of int) specifying the
            component(s) of the corresponding model parameter array that are
            to be retrieved.

        Returns
        -------
        dict
            The requested indices and the corresponding values of the model
            parameter array.
        """
        parameter = self._get_parameter_object(parameter_name)
        return parameter.get_values(index_range)

    def _set_one_parameter(self, parameter_name, index_range, values):
        """
        Set the value of one or more components of a model parameter array.

        Parameters
        ----------
        parameter_name : str
            Name of model parameter registered in the KIM API.
        index_range : int or list
            Zero-based index (int) or indices (list of int) specifying the
            component(s) of the corresponding model parameter array that are
            to be mutated.
        values : int/float or list
            Value(s) to assign to the component(s) of the model parameter
            array specified by ``index_range``.
        """
        parameter = self._get_parameter_object(parameter_name)
        parameter.set_values(index_range, values)

    def _get_one_parameter_metadata(self, index_parameter):
        """
        Get metadata associated with a single model parameter.

        Parameters
        ----------
        index_parameter : int
            Zero-based index used by the KIM API to refer to this model
            parameter.

        Returns
        -------
        dict
            Metadata associated with the requested model parameter.
        """
        dtype, extent, name, description = self._get_parameter_metadata(
            index_parameter)
        parameter_metadata = {
            "name": name,
            # repr() of the kimpy dtype is compared against "Integer" /
            # "Double" in _create_parameters.
            "dtype": repr(dtype),
            "extent": extent,
            "description": description,
        }
        return parameter_metadata

    @check_call_wrapper
    def compute(self, compute_args_wrapped, release_GIL):
        return self.kim_model.compute(
            compute_args_wrapped.compute_args, release_GIL)

    @check_call_wrapper
    def get_species_support_and_code(self, species_name):
        return self.kim_model.get_species_support_and_code(species_name)

    @check_call_wrapper
    def get_influence_distance(self):
        return self.kim_model.get_influence_distance()

    @check_call_wrapper
    def get_neighbor_list_cutoffs_and_hints(self):
        return self.kim_model.get_neighbor_list_cutoffs_and_hints()

    def compute_arguments_create(self):
        return ComputeArguments(self, self.debug)

    @property
    def initialized(self):
        return hasattr(self, "kim_model")
class KIMModelParameter(ABC):
    """One KIM model parameter array plus accessors for its elements.

    Concrete subclasses must define two class attributes:

    ``_dtype_c``
        Callable used to cast a value before it is passed to
        ``kim_model.set_parameter``.
    ``_dtype_accessor``
        Name of the ``kim_model`` method that reads a single element
        (looked up with ``getattr``).
    """

    def __init__(self, kim_model, dtype, extent,
                 name, description, parameter_index):
        self._kim_model = kim_model
        self._dtype = dtype
        self._extent = extent
        self._name = name
        self._description = description

        # Ensure that parameter_index is cast to a C-compatible integer. This
        # is necessary because this is passed to kimpy.
        self._parameter_index = c_int(parameter_index)

    @property
    def metadata(self):
        """Dict with this parameter's dtype, extent, name and description."""
        return {
            "dtype": self._dtype,
            "extent": self._extent,
            "name": self._name,
            "description": self._description,
        }

    @c_int_args
    def _get_one_value(self, index_extent):
        """Read element ``index_extent`` of this parameter array.

        Raises
        ------
        KIMModelParameterError
            If kimpy reports a failure (chained as the cause).
        """
        get_parameter = getattr(self._kim_model, self._dtype_accessor)
        try:
            return check_call(
                get_parameter, self._parameter_index, index_extent)
        except KimpyError as exception:
            raise KIMModelParameterError(
                f"Failed to access component {index_extent} of model "
                f"parameter of type '{self._dtype}' at parameter index "
                f"{self._parameter_index}"
            ) from exception

    def _set_one_value(self, index_extent, value):
        """Write ``value`` (cast with ``_dtype_c``) to element
        ``index_extent`` of this parameter array.

        Raises
        ------
        KIMModelParameterError
            If kimpy reports a failure (chained as the cause).
        """
        value_typecast = self._dtype_c(value)

        try:
            check_call(
                self._kim_model.set_parameter,
                self._parameter_index,
                c_int(index_extent),
                value_typecast,
            )
        except KimpyError as exception:
            # Chain the kimpy failure, as _get_one_value does, so its
            # message is not lost.
            raise KIMModelParameterError(
                f"Failed to set component {index_extent} at parameter index "
                f"{self._parameter_index} to {self._dtype} value "
                f"{value_typecast}"
            ) from exception

    def get_values(self, index_range):
        """Read one element (int index) or several (1-D list of indices).

        Returns
        -------
        dict
            ``{name: [index_range, values]}``; ``values`` is a single value
            for a scalar index, or a list for a list of indices.

        Raises
        ------
        KIMModelParameterError
            If ``index_range`` has more than one dimension.
        """
        index_range_dim = np.ndim(index_range)
        if index_range_dim == 0:
            values = self._get_one_value(index_range)
        elif index_range_dim == 1:
            values = [self._get_one_value(idx) for idx in index_range]
        else:
            raise KIMModelParameterError(
                "Index range must be an integer or a list of integers"
            )
        return {self._name: [index_range, values]}

    def set_values(self, index_range, values):
        """Write one element (int index) or several (1-D list of indices).

        ``values`` must have the same shape as ``index_range``.

        Raises
        ------
        AssertionError
            If ``index_range`` and ``values`` differ in shape.
        KIMModelParameterError
            If ``index_range`` has more than one dimension.
        """
        index_range_dim = np.ndim(index_range)
        values_dim = np.ndim(values)

        # Check the shape of index_range and values. The check raises
        # explicitly instead of using ``assert``, so it still runs under
        # ``python -O`` (where zip would otherwise truncate silently). The
        # exception type stays AssertionError for existing callers.
        msg = "index_range and values must have the same shape"
        if index_range_dim != values_dim:
            raise AssertionError(msg)

        if index_range_dim == 0:
            self._set_one_value(index_range, values)
        elif index_range_dim == 1:
            if len(index_range) != len(values):
                raise AssertionError(msg)
            for idx, value in zip(index_range, values):
                self._set_one_value(idx, value)
        else:
            raise KIMModelParameterError(
                "Index range must be an integer or a list of integers"
            )
class KIMModelParameterInteger(KIMModelParameter):
    """KIM model parameter whose elements are integers."""

    # kim_model method used to read a single element of the array
    _dtype_accessor = "get_parameter_int"
    # Cast applied to values before they are passed to set_parameter
    _dtype_c = c_int
class KIMModelParameterDouble(KIMModelParameter):
    """KIM model parameter whose elements are floating point numbers."""

    # kim_model method used to read a single element of the array
    _dtype_accessor = "get_parameter_double"
    # Cast applied to values before they are passed to set_parameter
    _dtype_c = c_double
class ComputeArguments:
    """Creates a KIM API ComputeArguments object from a KIM Portable
    Model object and configures it for ASE. A ComputeArguments object
    is associated with a KIM Portable Model and is used to inform the
    KIM API of what the model can compute. It is also used to
    register the data arrays that allow the KIM API to pass the atomic
    coordinates to the model and retrieve the corresponding energy and
    forces, etc."""

    def __init__(self, kim_model_wrapped, debug):
        self.kim_model_wrapped = kim_model_wrapped
        self.debug = debug

        # Create KIM API ComputeArguments object
        self.compute_args = check_call(
            self.kim_model_wrapped.kim_model.compute_arguments_create
        )

        # Validate what the model demands before it is used.
        self._check_compute_arguments()
        self._check_compute_callbacks()

    def _check_compute_arguments(self):
        """Raise KIMModelInitializationError if the model requires any
        compute argument other than partialEnergy and partialForces."""
        arg_names = kimpy.compute_argument_name
        n_args = arg_names.get_number_of_compute_argument_names()
        if self.debug:
            print(f"Number of compute_args: {n_args}")

        for index in range(n_args):
            arg_name = check_call(arg_names.get_compute_argument_name, index)
            arg_dtype = check_call(
                arg_names.get_compute_argument_data_type, arg_name)
            support = self.get_argument_support_status(arg_name)

            if self.debug:
                print(
                    "Compute Argument name {:21} is of type {:7} "
                    "and has support status {}".format(
                        str(arg_name), str(arg_dtype), str(support))
                )

            # Energy and forces are the only outputs this wrapper supplies
            # storage for, so any other "required" argument is fatal.
            if (support == kimpy.support_status.required
                    and arg_name != arg_names.partialEnergy
                    and arg_name != arg_names.partialForces):
                raise KIMModelInitializationError(
                    f"Unsupported required ComputeArgument {arg_name}"
                )

    def _check_compute_callbacks(self):
        """Raise KIMModelInitializationError if the model requires any
        compute callback."""
        cb_names = kimpy.compute_callback_name
        n_callbacks = cb_names.get_number_of_compute_callback_names()
        if self.debug:
            print()
            print(f"Number of callbacks: {n_callbacks}")

        for index in range(n_callbacks):
            cb_name = check_call(cb_names.get_compute_callback_name, index)
            status = self.get_callback_support_status(cb_name)

            if self.debug:
                print(
                    "Compute callback {:17} has support status {}".format(
                        str(cb_name), status
                    )
                )

            # Cannot handle any "required" callbacks
            if status == kimpy.support_status.required:
                raise KIMModelInitializationError(
                    f"Unsupported required ComputeCallback: {cb_name}"
                )

    @check_call_wrapper
    def set_argument_pointer(self, compute_arg_name, data_object):
        return self.compute_args.set_argument_pointer(
            compute_arg_name, data_object)

    @check_call_wrapper
    def get_argument_support_status(self, name):
        return self.compute_args.get_argument_support_status(name)

    @check_call_wrapper
    def get_callback_support_status(self, name):
        return self.compute_args.get_callback_support_status(name)

    @check_call_wrapper
    def set_callback(self, compute_callback_name,
                     callback_function, data_object):
        return self.compute_args.set_callback(
            compute_callback_name, callback_function, data_object
        )

    @check_call_wrapper
    def set_callback_pointer(
            self, compute_callback_name, callback, data_object):
        return self.compute_args.set_callback_pointer(
            compute_callback_name, callback, data_object
        )

    def update(
        self, num_particles, species_code, particle_contributing,
        coords, energy, forces
    ):
        """Point the KIM API compute arguments at the given input and output
        data objects (particle count, species codes, contributing flags,
        coordinates, partial energy and partial forces)."""
        arg = kimpy.compute_argument_name
        registrations = (
            (arg.numberOfParticles, num_particles),
            (arg.particleSpeciesCodes, species_code),
            (arg.particleContributing, particle_contributing),
            (arg.coordinates, coords),
            (arg.partialEnergy, energy),
            (arg.partialForces, forces),
        )
        for arg_name, data_object in registrations:
            self.set_argument_pointer(arg_name, data_object)

        if self.debug:
            print("Debug: called update_kim")
            print()
class SimulatorModel:
    """Creates a KIM API Simulator Model object and provides a minimal
    interface to it. This is only necessary in this package in order to
    extract any information about a given simulator model because it is
    generally embedded in a shared object.
    """

    def __init__(self, model_name):
        # Create a KIM API Simulator Model object for this model
        self.model_name = model_name
        self.simulator_model = wrappers.simulator_model_create(self.model_name)

        # Need to close template map in order to access simulator
        # model metadata
        self.simulator_model.close_template_map()

    def __enter__(self):
        return self

    def __exit__(self, exc_type, value, traceback):
        pass

    @property
    def simulator_name(self):
        """Simulator name reported by the model (the version is dropped)."""
        simulator_name, _ = (
            self.simulator_model.get_simulator_name_and_version())
        return simulator_name

    @property
    def num_supported_species(self):
        """Number of species the simulator model reports supporting.

        Raises
        ------
        KIMModelInitializationError
            If that number is zero.
        """
        num_supported_species = (
            self.simulator_model.get_number_of_supported_species())
        if num_supported_species == 0:
            raise KIMModelInitializationError(
                "Unable to determine supported species of "
                "simulator model {}.".format(self.model_name)
            )
        return num_supported_species

    @property
    def supported_species(self):
        """Tuple of the supported species, in the model's index order."""
        supported_species = []
        for spec_code in range(self.num_supported_species):
            species = check_call(
                self.simulator_model.get_supported_species, spec_code)
            supported_species.append(species)

        return tuple(supported_species)

    @property
    def num_metadata_fields(self):
        return self.simulator_model.get_number_of_simulator_fields()

    @property
    def metadata(self):
        """Map each simulator-model metadata field name to its list of lines.

        The metadata is re-read from kimpy on every access.
        """
        sm_metadata_fields = {}
        for field in range(self.num_metadata_fields):
            extent, field_name = check_call(
                self.simulator_model.get_simulator_field_metadata, field
            )
            sm_metadata_fields[field_name] = [
                check_call(
                    self.simulator_model.get_simulator_field_line, field, ln
                )
                for ln in range(extent)
            ]

        return sm_metadata_fields

    @property
    def supported_units(self):
        """First line of the 'units' metadata field.

        Raises
        ------
        KIMModelInitializationError
            If the field is missing or empty (the lookup error is chained
            as the cause).
        """
        try:
            supported_units = self.metadata["units"][0]
        except (KeyError, IndexError) as e:
            raise KIMModelInitializationError(
                "Unable to determine supported units of "
                "simulator model {}.".format(self.model_name)
            ) from e

        return supported_units

    @property
    def atom_style(self):
        """
        See if a 'model-init' field exists in the SM metadata and, if
        so, whether it contains any entries including an "atom_style"
        command. This is specific to LAMMPS SMs and is only required
        for using the LAMMPSrun calculator because it uses
        lammps.inputwriter to create a data file. All other content in
        'model-init', if it exists, is ignored.

        Returns the second whitespace-separated token of the last matching
        line, or None if no matching line has one.
        """
        atom_style = None
        for ln in self.metadata.get("model-init", []):
            if "atom_style" in ln:
                tokens = ln.split()
                # A line such as a bare "atom_style" has no style argument;
                # skip it instead of raising IndexError.
                if len(tokens) > 1:
                    atom_style = tokens[1]

        return atom_style

    @property
    def model_defn(self):
        return self.metadata["model-defn"]

    @property
    def initialized(self):
        return hasattr(self, "simulator_model")