Coverage for /builds/debichem-team/python-ase/ase/io/castep/castep_input_file.py: 78.55%
289 statements
« prev ^ index » next coverage.py v7.5.3, created at 2025-03-06 04:00 +0000
« prev ^ index » next coverage.py v7.5.3, created at 2025-03-06 04:00 +0000
1import difflib
2import re
3import warnings
4from typing import List, Set
6import numpy as np
8from ase import Atoms
10# A convenient table to avoid the previously used "eval"
11_tf_table = {
12 '': True, # Just the keyword is equivalent to True
13 'True': True,
14 'False': False}
def _parse_tss_block(value, scaled=False):
    """Convert a Transition State Search structure into .cell block text.

    Parameters
    ----------
    value : Atoms, str or iterable of str
        An ``Atoms`` object, a multi-line string, or the lines of the block.
    scaled : bool
        If True the block holds fractional positions and carries no unit
        line; otherwise positions are absolute and, when built from an
        ``Atoms`` object, are preceded by an ``ang`` unit line.

    Returns
    -------
    str
        The text of the block.

    Raises
    ------
    TypeError
        If ``value`` is neither an ``Atoms`` object nor a non-empty
        collection of strings.
    RuntimeError
        If absolute positions come with a unit line other than ``ang``.
    """
    if isinstance(value, Atoms):
        # One " symbol x y z" line per atom, three decimals per coordinate
        positions = (value.get_scaled_positions() if scaled else
                     value.get_positions())
        body = ''.join(
            ' {} {:.3f} {:.3f} {:.3f}\n'.format(s, *p)
            for s, p in zip(value.get_chemical_symbols(), positions))
        return ('' if scaled else 'ang\n') + body

    # A plain string must be split into lines: iterating it directly would
    # yield single characters, which all pass the "is a str" check below.
    if isinstance(value, str):
        lines = value.strip().splitlines()
    else:
        try:
            lines = list(value)
        except TypeError:
            lines = []

    # Invalid! (an empty block is rejected here too instead of failing
    # later with an IndexError on the first line)
    if not lines or not all(isinstance(x, str) for x in lines):
        raise TypeError('castep.cell.positions_abs/frac_intermediate/'
                        'product expects Atoms object or list of strings')

    # First line must be Angstroms, or nothing
    has_units = len(lines[0].strip().split()) == 1
    if (not scaled) and has_units and lines[0].strip() != 'ang':
        raise RuntimeError('Only ang units currently supported in castep.'
                           'cell.positions_abs_intermediate/product')
    return '\n'.join(map(str.strip, lines))
48class CastepOption:
49 """"A CASTEP option. It handles basic conversions from string to its value
50 type."""
52 default_convert_types = {
53 'boolean (logical)': 'bool',
54 'defined': 'bool',
55 'string': 'str',
56 'integer': 'int',
57 'real': 'float',
58 'integer vector': 'int_vector',
59 'real vector': 'float_vector',
60 'physical': 'float_physical',
61 'block': 'block'
62 }
64 def __init__(self, keyword, level, option_type, value=None,
65 docstring='No information available'):
66 self.keyword = keyword
67 self.level = level
68 self.type = option_type
69 self._value = value
70 self.__doc__ = docstring
72 @property
73 def value(self):
75 if self._value is not None:
76 if self.type.lower() in ('integer vector', 'real vector',
77 'physical'):
78 return ' '.join(map(str, self._value))
79 elif self.type.lower() in ('boolean (logical)', 'defined'):
80 return str(self._value).upper()
81 else:
82 return str(self._value)
84 @property
85 def raw_value(self):
86 # The value, not converted to a string
87 return self._value
89 @value.setter # type: ignore[attr-defined, no-redef]
90 def value(self, val):
92 if val is None:
93 self.clear()
94 return
96 ctype = self.default_convert_types.get(self.type.lower(), 'str')
97 typeparse = f'_parse_{ctype}'
98 try:
99 self._value = getattr(self, typeparse)(val)
100 except ValueError:
101 raise ConversionError(ctype, self.keyword, val)
103 def clear(self):
104 """Reset the value of the option to None again"""
105 self._value = None
107 @staticmethod
108 def _parse_bool(value):
109 try:
110 value = _tf_table[str(value).strip().title()]
111 except (KeyError, ValueError):
112 raise ValueError()
113 return value
115 @staticmethod
116 def _parse_str(value):
117 value = str(value)
118 return value
120 @staticmethod
121 def _parse_int(value):
122 value = int(value)
123 return value
125 @staticmethod
126 def _parse_float(value):
127 value = float(value)
128 return value
130 @staticmethod
131 def _parse_int_vector(value):
132 # Accepts either a string or an actual list/numpy array of ints
133 if isinstance(value, str):
134 if ',' in value:
135 value = value.replace(',', ' ')
136 value = list(map(int, value.split()))
138 value = np.array(value)
140 if value.shape != (3,) or value.dtype != int:
141 raise ValueError()
143 return list(value)
145 @staticmethod
146 def _parse_float_vector(value):
147 # Accepts either a string or an actual list/numpy array of floats
148 if isinstance(value, str):
149 if ',' in value:
150 value = value.replace(',', ' ')
151 value = list(map(float, value.split()))
153 value = np.array(value) * 1.0
155 if value.shape != (3,) or value.dtype != float:
156 raise ValueError()
158 return list(value)
160 @staticmethod
161 def _parse_float_physical(value):
162 # If this is a string containing units, saves them
163 if isinstance(value, str):
164 value = value.split()
166 try:
167 l = len(value)
168 except TypeError:
169 l = 1
170 value = [value]
172 if l == 1:
173 try:
174 value = (float(value[0]), '')
175 except (TypeError, ValueError):
176 raise ValueError()
177 elif l == 2:
178 try:
179 value = (float(value[0]), value[1])
180 except (TypeError, ValueError, IndexError):
181 raise ValueError()
182 else:
183 raise ValueError()
185 return value
187 @staticmethod
188 def _parse_block(value):
190 if isinstance(value, str):
191 return value
192 elif hasattr(value, '__getitem__'):
193 return '\n'.join(value) # Arrays of lines
194 else:
195 raise ValueError()
197 def __repr__(self):
198 if self._value:
199 expr = ('Option: {keyword}({type}, {level}):\n{_value}\n'
200 ).format(**self.__dict__)
201 else:
202 expr = ('Option: {keyword}[unset]({type}, {level})'
203 ).format(**self.__dict__)
204 return expr
206 def __eq__(self, other):
207 if not isinstance(other, CastepOption):
208 return False
209 else:
210 return self.__dict__ == other.__dict__
class CastepOptionDict:
    """A dictionary-like object to hold a set of options for .cell or .param
    files loaded from a dictionary, for the sake of validation.

    Replaces the old CastepCellDict and CastepParamDict that were defined in
    the castep_keywords.py file.

    Parameters
    ----------
    options : mapping, optional
        Maps a name to a dict of ``CastepOption`` constructor arguments.
        Each resulting option is stored under its ``keyword`` both in
        ``_options`` and as an instance attribute. None means no options.
    """

    def __init__(self, options=None):
        # A plain dict suffices: CastepOptions can be compared directly
        self._options = {}
        if options is None:
            options = {}
        for kw in options:
            opt = CastepOption(**options[kw])
            self._options[opt.keyword] = opt
            self.__dict__[opt.keyword] = opt
class CastepInputFile:
    """Master class for CastepParam and CastepCell to inherit from.

    Options live as ``CastepOption`` objects in ``self._options`` and are
    mirrored into the instance ``__dict__`` so they can be read as
    attributes. Assigning to a public attribute goes through
    ``__setattr__``, which validates the keyword, clears conflicting
    keywords and applies any ``_parse_<keyword>`` method a subclass defines.
    """

    # Sets of mutually exclusive keywords; subclasses override this
    _keyword_conflicts: List[Set[str]] = []

    def __init__(self, options_dict=None, keyword_tolerance=1):
        if options_dict is None:
            options_dict = CastepOptionDict({})

        self._options = options_dict._options
        self.__dict__.update(self._options)
        # keyword_tolerance means how strict the checks on new attributes are
        # 0 = no new attributes allowed
        # 1 = new attributes allowed, warning given
        # 2 = new attributes allowed, silent
        self._perm = np.clip(keyword_tolerance, 0, 2)

        # Map every keyword to the other keywords in its conflict set(s)
        self._conflict_dict = {
            kw: set(cset).difference({kw})
            for cset in self._keyword_conflicts for kw in cset}

    def __repr__(self):
        lines = ['%20s : %s\n' % (key, option.value)
                 for key, option in sorted(self._options.items())
                 if option.value is not None]
        expr = ''.join(lines) if lines else 'Default\n'
        return expr + f'Keyword tolerance: {self._perm}'

    def __setattr__(self, attr, value):
        # Hidden attributes are treated normally
        if attr.startswith('_'):
            self.__dict__[attr] = value
            return

        # Keywords are case-insensitive and stored lower-case. Normalise
        # before the lookup so that e.g. CUT_OFF_ENERGY updates the known
        # option instead of being registered as a new, unknown one.
        attr = attr.lower()

        if attr not in self._options:
            if self._perm == 0:
                similars = difflib.get_close_matches(attr,
                                                     self._options.keys())
                if similars:
                    raise RuntimeError(
                        f'Option "{attr}" not known! You mean "{similars[0]}"?')
                raise RuntimeError(f'Option "{attr}" is not known!')

            # Do we consider it a string or a block?
            is_str = isinstance(value, str)
            is_block = ((hasattr(value, '__getitem__') and not is_str)
                        or (is_str and len(value.split('\n')) > 1))
            option_type = 'block' if is_block else 'string'
            if self._perm == 1:
                warnings.warn(f'Option "{attr}" is not known and will '
                              f'be added as a {option_type}')
            opt = CastepOption(keyword=attr, level='Unknown',
                               option_type=option_type)
            self._options[attr] = opt
            self.__dict__[attr] = opt
        else:
            opt = self._options[attr]

        # Colons are turned into spaces for everything except blocks
        if opt.type.lower() != 'block' and isinstance(value, str):
            value = value.replace(':', ' ')

        # Check for any conflicts if the value is not None
        if value is not None:
            for c in self._conflict_dict.get(attr, set()):
                if c in self._options and self._options[c].value:
                    warnings.warn(
                        'option "{attr}" conflicts with "{conflict}" in '
                        'calculator. Setting "{conflict}" to '
                        'None.'.format(attr=attr, conflict=c))
                    self._options[c].value = None

        # Use the keyword-specific parser if the (sub)class defines one
        attrparse = f'_parse_{attr}'
        if hasattr(self, attrparse):
            value = getattr(self, attrparse)(value)
        opt.value = value

    def __getattr__(self, name):
        # Only called for names not found normally: private names and
        # strict mode raise; otherwise a dummy, unset option is returned.
        if name.startswith('_') or self._perm == 0:
            raise AttributeError(name)

        if self._perm == 1:
            warnings.warn(f'Option {name} is not known, returning None')

        return CastepOption(keyword='none', level='Unknown',
                            option_type='string', value=None)

    def get_attr_dict(self, raw=False, types=False):
        """Settings that go into .param file in a traditional dict

        Only set options are included. With ``raw`` the unconverted values
        are returned instead of their string form; with ``types`` each value
        becomes a ``(value, option_type)`` pair.
        """
        attrdict = {k: o.raw_value if raw else o.value
                    for k, o in self._options.items() if o.value is not None}

        if types:
            attrdict = {k: (v, self._options[k].type)
                        for k, v in attrdict.items()}

        return attrdict
class CastepParam(CastepInputFile):
    """CastepParam abstracts the settings that go into the .param file"""

    _keyword_conflicts = [{'cut_off_energy', 'basis_precision'}, ]

    def __init__(self, castep_keywords, keyword_tolerance=1):
        self._castep_version = castep_keywords.castep_version
        CastepInputFile.__init__(self, castep_keywords.CastepParamDict(),
                                 keyword_tolerance)

    @property
    def castep_version(self):
        return self._castep_version

    def _exclusive_restart_value(self, value, other):
        """Shared logic for the mutually exclusive reuse/continuation options.

        Returns None (unset) if ``value`` is None, or, with a warning, if the
        ``other`` option is already set. Otherwise returns ``'default'`` for
        ``True`` and ``str(value)`` for anything else.
        """
        if value is None:
            return None  # Reset the value
        other_opt = self._options.get(other)
        if other_opt is not None and other_opt.value:
            warnings.warn('Cannot set reuse if continuation is set, and '
                          'vice versa. Set the other to None, if you want '
                          'this setting.')
            return None
        return 'default' if (value is True) else str(value)

    # .param specific parsers
    def _parse_reuse(self, value):
        return self._exclusive_restart_value(value, 'continuation')

    def _parse_continuation(self, value):
        return self._exclusive_restart_value(value, 'reuse')
class CastepCell(CastepInputFile):

    """CastepCell abstracts all settings that go into the .cell file"""

    _keyword_conflicts = [
        {'kpoint_mp_grid', 'kpoint_mp_spacing', 'kpoint_list',
         'kpoints_mp_grid', 'kpoints_mp_spacing', 'kpoints_list'},
        {'bs_kpoint_mp_grid',
         'bs_kpoint_mp_spacing',
         'bs_kpoint_list',
         'bs_kpoint_path',
         'bs_kpoints_mp_grid',
         'bs_kpoints_mp_spacing',
         'bs_kpoints_list',
         'bs_kpoints_path'},
        {'spectral_kpoint_mp_grid',
         'spectral_kpoint_mp_spacing',
         'spectral_kpoint_list',
         'spectral_kpoint_path',
         'spectral_kpoints_mp_grid',
         'spectral_kpoints_mp_spacing',
         'spectral_kpoints_list',
         'spectral_kpoints_path'},
        {'phonon_kpoint_mp_grid',
         'phonon_kpoint_mp_spacing',
         'phonon_kpoint_list',
         'phonon_kpoint_path',
         'phonon_kpoints_mp_grid',
         'phonon_kpoints_mp_spacing',
         'phonon_kpoints_list',
         'phonon_kpoints_path'},
        {'fine_phonon_kpoint_mp_grid',
         'fine_phonon_kpoint_mp_spacing',
         'fine_phonon_kpoint_list',
         'fine_phonon_kpoint_path'},
        {'magres_kpoint_mp_grid',
         'magres_kpoint_mp_spacing',
         'magres_kpoint_list',
         'magres_kpoint_path'},
        {'elnes_kpoint_mp_grid',
         'elnes_kpoint_mp_spacing',
         'elnes_kpoint_list',
         'elnes_kpoint_path'},
        {'optics_kpoint_mp_grid',
         'optics_kpoint_mp_spacing',
         'optics_kpoint_list',
         'optics_kpoint_path'},
        {'supercell_kpoint_mp_grid',
         'supercell_kpoint_mp_spacing',
         'supercell_kpoint_list',
         'supercell_kpoint_path'}, ]

    def __init__(self, castep_keywords, keyword_tolerance=1):
        self._castep_version = castep_keywords.castep_version
        CastepInputFile.__init__(self, castep_keywords.CastepCellDict(),
                                 keyword_tolerance)

    @property
    def castep_version(self):
        return self._castep_version

    # .cell specific parsers
    def _parse_species_pot(self, value):
        """Merge ``(species, file)`` assignments into the species_pot block.

        ``value`` is a single ``(species, file)`` tuple or a sequence of
        them. Each species' existing line is removed from the current block
        and, if ``file`` is non-empty, a ``species file`` line is appended.
        Returns the new block text, or None (with a warning for malformed
        entries) if the input cannot be used.
        """
        # Single tuple
        if isinstance(value, tuple) and len(value) == 2:
            value = [value]
        if not hasattr(value, '__getitem__'):
            return None
        # List of tuples: validate before touching the entries
        if not all(len(x) == 2 for x in value):
            warnings.warn(
                'Please specify pseudopotentials in python as '
                'a tuple or a list of tuples formatted like: '
                '(species, file), e.g. ("O", "path-to/O_OTFG.usp") '
                'Anything else will be ignored')
            return None
        pspots = [tuple(map(str.strip, x)) for x in value]

        text_block = self._options['species_pot'].value or ''
        for species, pspot in pspots:
            # Remove any existing line for this species. The label is
            # escaped and anchored to the start of a line, and only spaces
            # or tabs may follow it, so labels with regex metacharacters or
            # matches inside other lines cannot delete unrelated entries.
            text_block = re.sub(
                r'\n?^[ \t]*' + re.escape(species) + r'[ \t]+.*$',
                '', text_block, flags=re.MULTILINE)
            if pspot:
                text_block += f'\n{species} {pspot}'

        return text_block

    def _parse_symmetry_ops(self, value):
        """Format ``(rotations, translations)`` as a symmetry_ops block.

        ``rotations`` must be (N, 3, 3) and ``translations`` (N, 3) array-like.
        Each operation is written as three rotation rows, the translation
        row and a blank line. Invalid input is skipped with a warning.
        """
        if not isinstance(value, tuple) or len(value) != 2:
            warnings.warn('Invalid symmetry_ops block, skipping')
            return None
        try:
            rotations = np.asarray(value[0])
            translations = np.asarray(value[1])
        except ValueError:  # e.g. ragged nested lists
            warnings.warn('Invalid symmetry_ops block, skipping')
            return None
        if (rotations.shape[1:] != (3, 3)
                or translations.shape[1:] != (3,)
                or rotations.shape[0] != translations.shape[0]):
            warnings.warn('Invalid symmetry_ops block, skipping')
            return None

        # Now on to print...
        chunks = []
        for op_rot, op_transl in zip(rotations, translations):
            rows = '\n'.join(' '.join(str(x) for x in row) for row in op_rot)
            chunks.append(rows + '\n'
                          + ' '.join(str(x) for x in op_transl) + '\n\n')
        return ''.join(chunks)

    def _parse_positions_abs_intermediate(self, value):
        return _parse_tss_block(value)

    def _parse_positions_abs_product(self, value):
        return _parse_tss_block(value)

    def _parse_positions_frac_intermediate(self, value):
        return _parse_tss_block(value, True)

    def _parse_positions_frac_product(self, value):
        return _parse_tss_block(value, True)
class ConversionError(Exception):

    """Print customized error for options that are not converted correctly
    and point out that they are maybe not implemented, yet"""

    def __init__(self, key_type, attr, value):
        # Forward the arguments so ``args`` is populated; pickling and
        # copying rebuild the exception as ConversionError(*args).
        super().__init__(key_type, attr, value)
        self.key_type = key_type
        self.value = value
        self.attr = attr

    def __str__(self):
        contact_email = 'simon.rittmeyer@tum.de'
        return (f'Could not convert {self.attr} = {self.value} '
                f'to {self.key_type}\n'
                'This means you either tried to set a value of the wrong\n'
                'type or this keyword needs some special care. Please feel\n'
                'free to add it to the corresponding __setattr__ method and '
                'send\n'
                f'the patch to {contact_email}, so we can all benefit.')