Coverage for /builds/debichem-team/python-ase/ase/io/castep/castep_input_file.py: 78.55%

289 statements  

« prev     ^ index     » next       coverage.py v7.5.3, created at 2025-03-06 04:00 +0000

1import difflib 

2import re 

3import warnings 

4from typing import List, Set 

5 

6import numpy as np 

7 

8from ase import Atoms 

9 

10# A convenient table to avoid the previously used "eval" 

11_tf_table = { 

12 '': True, # Just the keyword is equivalent to True 

13 'True': True, 

14 'False': False} 

15 

16 

17def _parse_tss_block(value, scaled=False): 

18 # Parse the assigned value for a Transition State Search structure block 

19 is_atoms = isinstance(value, Atoms) 

20 try: 

21 is_strlist = all(map(lambda x: isinstance(x, str), value)) 

22 except TypeError: 

23 is_strlist = False 

24 

25 if not is_atoms: 

26 if not is_strlist: 

27 # Invalid! 

28 raise TypeError('castep.cell.positions_abs/frac_intermediate/' 

29 'product expects Atoms object or list of strings') 

30 

31 # First line must be Angstroms, or nothing 

32 has_units = len(value[0].strip().split()) == 1 

33 if (not scaled) and has_units and value[0].strip() != 'ang': 

34 raise RuntimeError('Only ang units currently supported in castep.' 

35 'cell.positions_abs_intermediate/product') 

36 return '\n'.join(map(str.strip, value)) 

37 else: 

38 text_block = '' if scaled else 'ang\n' 

39 positions = (value.get_scaled_positions() if scaled else 

40 value.get_positions()) 

41 symbols = value.get_chemical_symbols() 

42 for s, p in zip(symbols, positions): 

43 text_block += ' {} {:.3f} {:.3f} {:.3f}\n'.format(s, *p) 

44 

45 return text_block 

46 

47 

class CastepOption:
    """A CASTEP option. It handles basic conversions from string to its value
    type."""

    # Maps the (lower-cased) CASTEP option type to the suffix of the
    # ``_parse_<suffix>`` method used to convert assigned values.
    default_convert_types = {
        'boolean (logical)': 'bool',
        'defined': 'bool',
        'string': 'str',
        'integer': 'int',
        'real': 'float',
        'integer vector': 'int_vector',
        'real vector': 'float_vector',
        'physical': 'float_physical',
        'block': 'block'
    }

    def __init__(self, keyword, level, option_type, value=None,
                 docstring='No information available'):
        """Create an option.

        Parameters
        ----------
        keyword : str
            Option name.
        level : str
            Option level label (e.g. 'Unknown' for user-added options).
        option_type : str
            CASTEP type name; looked up case-insensitively in
            ``default_convert_types`` when a value is assigned.
        value : optional
            Initial raw value, stored as given (not parsed).
        docstring : str
            Help text, stored as the instance ``__doc__``.
        """
        self.keyword = keyword
        self.level = level
        self.type = option_type
        self._value = value
        self.__doc__ = docstring

    @property
    def value(self):
        """The value rendered as text for an input file, or None if unset.

        Vectors and physical quantities are space-joined and booleans are
        upper-cased.
        """
        if self._value is not None:
            if self.type.lower() in ('integer vector', 'real vector',
                                     'physical'):
                return ' '.join(map(str, self._value))
            elif self.type.lower() in ('boolean (logical)', 'defined'):
                return str(self._value).upper()
            else:
                return str(self._value)

    @value.setter  # type: ignore[attr-defined, no-redef]
    def value(self, val):
        # None unsets the option; anything else is converted by the parser
        # matching the option type (unknown types are treated as strings).
        if val is None:
            self.clear()
            return

        ctype = self.default_convert_types.get(self.type.lower(), 'str')
        parser = getattr(self, f'_parse_{ctype}')
        try:
            self._value = parser(val)
        except ValueError as err:
            raise ConversionError(ctype, self.keyword, val) from err

    @property
    def raw_value(self):
        """The parsed value, not converted to a string."""
        return self._value

    def clear(self):
        """Reset the value of the option to None again"""
        self._value = None

    @staticmethod
    def _parse_bool(value):
        # Accepts '', 'true'/'false' in any case, or Python booleans
        try:
            value = _tf_table[str(value).strip().title()]
        except (KeyError, ValueError):
            raise ValueError() from None
        return value

    @staticmethod
    def _parse_str(value):
        return str(value)

    @staticmethod
    def _parse_int(value):
        return int(value)

    @staticmethod
    def _parse_float(value):
        return float(value)

    @staticmethod
    def _parse_int_vector(value):
        # Accepts either a string ("1 2 3" or "1,2,3") or an actual
        # list/numpy array of ints
        if isinstance(value, str):
            value = list(map(int, value.replace(',', ' ').split()))

        arr = np.asarray(value)

        # Any integer dtype is accepted (e.g. int32 arrays), not only the
        # platform default int.
        if arr.shape != (3,) or not np.issubdtype(arr.dtype, np.integer):
            raise ValueError()

        # tolist() yields plain Python ints rather than numpy scalars
        return arr.tolist()

    @staticmethod
    def _parse_float_vector(value):
        # Accepts either a string ("0.1 0.2 0.3" or "0.1,0.2,0.3") or an
        # actual list/numpy array of numbers
        if isinstance(value, str):
            value = list(map(float, value.replace(',', ' ').split()))

        try:
            arr = np.asarray(value, dtype=float)
        except TypeError:
            # e.g. None or complex entries; report as a conversion failure
            raise ValueError() from None

        if arr.shape != (3,):
            raise ValueError()

        # tolist() yields plain Python floats rather than numpy scalars
        return arr.tolist()

    @staticmethod
    def _parse_float_physical(value):
        # Returns (number, unit); a string containing units keeps them, a
        # bare number gets the unit ''.
        if isinstance(value, str):
            value = value.split()

        try:
            n_parts = len(value)
        except TypeError:
            n_parts = 1
            value = [value]

        if n_parts == 1:
            try:
                value = (float(value[0]), '')
            except (TypeError, ValueError):
                raise ValueError() from None
        elif n_parts == 2:
            try:
                value = (float(value[0]), value[1])
            except (TypeError, ValueError, IndexError):
                raise ValueError() from None
        else:
            raise ValueError()

        return value

    @staticmethod
    def _parse_block(value):
        # A block is either already text or a sequence of lines
        if isinstance(value, str):
            return value
        elif hasattr(value, '__getitem__'):
            return '\n'.join(value)  # Arrays of lines
        else:
            raise ValueError()

    def __repr__(self):
        # Test against None so that falsy values such as 0 or False are not
        # reported as unset.
        if self._value is not None:
            expr = ('Option: {keyword}({type}, {level}):\n{_value}\n'
                    ).format(**self.__dict__)
        else:
            expr = ('Option: {keyword}[unset]({type}, {level})'
                    ).format(**self.__dict__)
        return expr

    def __eq__(self, other):
        # Options are equal when all their instance attributes match
        if not isinstance(other, CastepOption):
            return NotImplemented
        return self.__dict__ == other.__dict__

211 

212 

class CastepOptionDict:
    """A dictionary-like object to hold a set of options for .cell or .param
    files loaded from a dictionary, for the sake of validation.

    Replaces the old CastepCellDict and CastepParamDict that were defined in
    the castep_keywords.py file.
    """

    def __init__(self, options=None):
        """Build CastepOption objects from a mapping of option specs.

        Parameters
        ----------
        options : mapping, optional
            Maps a key to a dict of keyword arguments for CastepOption.
            Each option is stored under its own ``keyword``, both in
            ``_options`` and as an attribute. None (the default) means no
            options.
        """
        object.__init__(self)
        # ComparableDict is not needed any more as CastepOptions can be
        # compared directly now
        self._options = {}
        if options is None:
            options = {}
        for kw in options:
            opt = CastepOption(**options[kw])
            self._options[opt.keyword] = opt
            self.__dict__[opt.keyword] = opt

229 

230 

class CastepInputFile:

    """Master class for CastepParam and CastepCell to inherit from"""

    # Groups of keywords that may not be set together: assigning a non-None
    # value to one member resets every other member that currently has a
    # value (see __setattr__).
    _keyword_conflicts: List[Set[str]] = []

    def __init__(self, options_dict=None, keyword_tolerance=1):
        """Set up the known options and keyword checking.

        Parameters
        ----------
        options_dict : CastepOptionDict, optional
            Provides the known options. Its ``_options`` dict is shared,
            not copied. Defaults to an empty CastepOptionDict.
        keyword_tolerance : int
            How strict the checks on new attributes are (clipped to 0..2):
            0 = no new attributes allowed, 1 = new attributes allowed with a
            warning, 2 = new attributes allowed silently.
        """
        object.__init__(self)

        if options_dict is None:
            options_dict = CastepOptionDict({})

        self._options = options_dict._options
        self.__dict__.update(self._options)
        self._perm = np.clip(keyword_tolerance, 0, 2)

        # Compile a dictionary for quick check of conflict sets
        self._conflict_dict = {
            kw: set(cset).difference({kw})
            for cset in self._keyword_conflicts for kw in cset}

    def __repr__(self):
        expr = ''
        is_default = True
        for key, option in sorted(self._options.items()):
            if option.value is not None:
                is_default = False
                expr += ('%20s : %s\n' % (key, option.value))
        if is_default:
            expr = 'Default\n'

        expr += f'Keyword tolerance: {self._perm}'
        return expr

    def __setattr__(self, attr, value):
        """Assign *value* to the option named *attr*.

        Names starting with '_' are ordinary instance attributes. Any other
        name is matched case-insensitively against the known options. An
        unknown name is rejected, added with a warning, or added silently,
        depending on the keyword tolerance. Non-block string values have ':'
        replaced by ' '. Conflicting options are reset. A method named
        ``_parse_<keyword>`` on this object is used to preprocess the value
        when one exists.
        """
        # Hidden attributes are treated normally
        if attr.startswith('_'):
            self.__dict__[attr] = value
            return

        # Options are stored under lower-case keywords. Normalise *before*
        # the lookup, so that e.g. 'KPOINT_MP_GRID' reaches the known option
        # instead of being treated as unknown and replacing it with a new
        # 'Unknown' string option.
        key = attr.lower()

        if key not in self._options:
            # Do we consider it a string or a block?
            is_str = isinstance(value, str)
            is_block = ((hasattr(value, '__getitem__') and not is_str)
                        or (is_str and len(value.split('\n')) > 1))

            if self._perm == 0:
                similars = difflib.get_close_matches(key,
                                                     self._options.keys())
                if similars:
                    raise RuntimeError(
                        f'Option "{attr}" not known! You mean "{similars[0]}"?')
                raise RuntimeError(f'Option "{attr}" is not known!')
            if self._perm == 1:
                warnings.warn(('Option "%s" is not known and will '
                               'be added as a %s') % (attr,
                                                      ('block' if is_block else
                                                       'string')))
            opt = CastepOption(keyword=key, level='Unknown',
                               option_type='block' if is_block else 'string')
            self._options[key] = opt
            self.__dict__[key] = opt
        else:
            opt = self._options[key]

        if opt.type.lower() != 'block' and isinstance(value, str):
            value = value.replace(':', ' ')

        # Check for any conflicts if the value is not None
        if value is not None:
            for c in self._conflict_dict.get(key, set()):
                if c in self._options and self._options[c].value:
                    warnings.warn(
                        'option "{attr}" conflicts with "{conflict}" in '
                        'calculator. Setting "{conflict}" to '
                        'None.'.format(attr=key, conflict=c))
                    self._options[c].value = None

        # Use a keyword-specific parser if one is defined; __getattr__ raises
        # AttributeError for '_' names, so getattr falls back to None.
        parser = getattr(self, f'_parse_{key}', None)
        opt.value = parser(value) if parser is not None else value

    def __getattr__(self, name):
        """Fallback for names not found by normal attribute lookup.

        Names starting with '_' and any name when the keyword tolerance is 0
        raise AttributeError. Otherwise an unset placeholder option is
        returned, with a warning when the tolerance is 1.
        """
        if name.startswith('_') or self._perm == 0:
            raise AttributeError(name)

        if self._perm == 1:
            warnings.warn(f'Option {(name)} is not known, returning None')

        return CastepOption(keyword='none', level='Unknown',
                            option_type='string', value=None)

    def get_attr_dict(self, raw=False, types=False):
        """Return the options that are set, as a plain dict.

        Parameters
        ----------
        raw : bool
            If True, values are the parsed values (``raw_value``); otherwise
            their string form (``value``).
        types : bool
            If True, each value is replaced by a ``(value, option_type)``
            pair.
        """
        attrdict = {k: o.raw_value if raw else o.value
                    for k, o in self._options.items() if o.value is not None}

        if types:
            attrdict = {k: (v, self._options[k].type)
                        for k, v in attrdict.items()}

        return attrdict

351 

352 

class CastepParam(CastepInputFile):
    """CastepParam abstracts the settings that go into the .param file"""

    # Keywords that may not be set together: assigning one resets the other
    # (with a warning) in CastepInputFile.__setattr__.
    _keyword_conflicts = [{'cut_off_energy', 'basis_precision'}]

    def __init__(self, castep_keywords, keyword_tolerance=1):
        """Create the .param settings.

        *castep_keywords* must provide ``castep_version`` and a
        ``CastepParamDict()`` factory for the known options; presumably the
        generated castep_keywords module (confirm against callers).
        *keyword_tolerance* is passed on to CastepInputFile.
        """
        self._castep_version = castep_keywords.castep_version
        super().__init__(castep_keywords.CastepParamDict(), keyword_tolerance)

    @property
    def castep_version(self):
        """The castep_version reported by the keywords object."""
        return self._castep_version

    def _reuse_or_continuation(self, value, other_keyword):
        """Shared parser for the mutually exclusive reuse/continuation keys.

        None resets the value. If *other_keyword* already has a value, a
        warning is issued and None is returned. Otherwise True becomes
        'default' and anything else its string form.
        """
        if value is None:
            return None  # Reset the value
        try:
            other_is_set = bool(self._options[other_keyword].value)
        except KeyError:
            other_is_set = False
        if other_is_set:
            warnings.warn('Cannot set reuse if continuation is set, and '
                          'vice versa. Set the other to None, if you want '
                          'this setting.')
            return None
        if value is True:
            return 'default'
        return str(value)

    # .param specific parsers
    def _parse_reuse(self, value):
        return self._reuse_or_continuation(value, 'continuation')

    def _parse_continuation(self, value):
        return self._reuse_or_continuation(value, 'reuse')

393 

394 

class CastepCell(CastepInputFile):

    """CastepCell abstracts all setting that go into the .cell file"""

    # Groups of keywords that may not be set together: assigning one resets
    # the others that are set (see CastepInputFile.__setattr__).
    _keyword_conflicts = [
        {'kpoint_mp_grid', 'kpoint_mp_spacing', 'kpoint_list',
         'kpoints_mp_grid', 'kpoints_mp_spacing', 'kpoints_list'},
        {'bs_kpoint_mp_grid',
         'bs_kpoint_mp_spacing',
         'bs_kpoint_list',
         'bs_kpoint_path',
         'bs_kpoints_mp_grid',
         'bs_kpoints_mp_spacing',
         'bs_kpoints_list',
         'bs_kpoints_path'},
        {'spectral_kpoint_mp_grid',
         'spectral_kpoint_mp_spacing',
         'spectral_kpoint_list',
         'spectral_kpoint_path',
         'spectral_kpoints_mp_grid',
         'spectral_kpoints_mp_spacing',
         'spectral_kpoints_list',
         'spectral_kpoints_path'},
        {'phonon_kpoint_mp_grid',
         'phonon_kpoint_mp_spacing',
         'phonon_kpoint_list',
         'phonon_kpoint_path',
         'phonon_kpoints_mp_grid',
         'phonon_kpoints_mp_spacing',
         'phonon_kpoints_list',
         'phonon_kpoints_path'},
        {'fine_phonon_kpoint_mp_grid',
         'fine_phonon_kpoint_mp_spacing',
         'fine_phonon_kpoint_list',
         'fine_phonon_kpoint_path'},
        {'magres_kpoint_mp_grid',
         'magres_kpoint_mp_spacing',
         'magres_kpoint_list',
         'magres_kpoint_path'},
        {'elnes_kpoint_mp_grid',
         'elnes_kpoint_mp_spacing',
         'elnes_kpoint_list',
         'elnes_kpoint_path'},
        {'optics_kpoint_mp_grid',
         'optics_kpoint_mp_spacing',
         'optics_kpoint_list',
         'optics_kpoint_path'},
        {'supercell_kpoint_mp_grid',
         'supercell_kpoint_mp_spacing',
         'supercell_kpoint_list',
         'supercell_kpoint_path'}, ]

    def __init__(self, castep_keywords, keyword_tolerance=1):
        """Create the .cell settings.

        *castep_keywords* must provide ``castep_version`` and a
        ``CastepCellDict()`` factory for the known options; presumably the
        generated castep_keywords module (confirm against callers).
        *keyword_tolerance* is passed on to CastepInputFile.
        """
        self._castep_version = castep_keywords.castep_version
        CastepInputFile.__init__(self, castep_keywords.CastepCellDict(),
                                 keyword_tolerance)

    @property
    def castep_version(self):
        return self._castep_version

    # .cell specific parsers
    def _parse_species_pot(self, value):
        """Merge (species, pseudopotential file) pairs into species_pot.

        *value* is a single 2-tuple or a sequence of 2-tuples of strings.
        For each pair, any existing line for that species is removed from
        the current block text. '<species> <file>' is then appended unless
        the file is empty, so an empty file just removes the species.
        None returns None (a reset). Any other input triggers a warning and
        returns None.
        """
        if value is None:
            return None  # Reset the value

        # Single tuple
        if isinstance(value, tuple) and len(value) == 2:
            value = [value]

        # Must now be a list of (species, file) pairs. Non-sequences used to
        # crash with an unbound local; treat them as invalid input instead.
        if (not hasattr(value, '__getitem__')
                or not all(len(x) == 2 for x in value)):
            warnings.warn(
                'Please specify pseudopotentials in python as '
                'a tuple or a list of tuples formatted like: '
                '(species, file), e.g. ("O", "path-to/O_OTFG.usp") '
                'Anything else will be ignored')
            return None

        pspots = [tuple(map(str.strip, x)) for x in value]

        text_block = self._options['species_pot'].value or ''

        # Remove any duplicates. The pattern matches a whole line starting
        # with the (escaped) species label, together with the newline before
        # it, so labels with regex metacharacters are matched literally and
        # the match can never run on into the following line.
        for species, path in pspots:
            pattern = rf'\n?^[ \t]*{re.escape(species)}[ \t]+.*'
            text_block = re.sub(pattern, '', text_block, flags=re.MULTILINE)
            if path:
                text_block += f'\n{species} {path}'

        return text_block

    def _parse_symmetry_ops(self, value):
        """Format a (rotations, translations) pair as symmetry_ops text.

        Rotations must have shape (N, 3, 3) and translations (N, 3); array
        likes such as nested lists are accepted. Each operation is written
        as three rotation rows, one translation row and a blank line. None
        returns None (a reset). Invalid input warns and returns None.
        """
        if value is None:
            return None  # Reset the value

        rotations = translations = None
        if isinstance(value, tuple) and len(value) == 2:
            try:
                rotations = np.asarray(value[0])
                translations = np.asarray(value[1])
            except ValueError:  # e.g. ragged nested lists
                rotations = translations = None

        if (rotations is None
                or rotations.shape[1:] != (3, 3)
                or translations.shape[1:] != (3,)
                or rotations.shape[0] != translations.shape[0]):
            warnings.warn('Invalid symmetry_ops block, skipping')
            return None

        # Now on to print...
        text_block = ''
        for op_rot, op_transl in zip(rotations, translations):
            text_block += '\n'.join([' '.join([str(x) for x in row])
                                     for row in op_rot])
            text_block += '\n'
            text_block += ' '.join([str(x) for x in op_transl])
            text_block += '\n\n'

        return text_block

    # Transition State Search blocks: absolute positions in Angstrom, or
    # fractional positions for the *_frac_* variants.
    def _parse_positions_abs_intermediate(self, value):
        return _parse_tss_block(value)

    def _parse_positions_abs_product(self, value):
        return _parse_tss_block(value)

    def _parse_positions_frac_intermediate(self, value):
        return _parse_tss_block(value, True)

    def _parse_positions_frac_product(self, value):
        return _parse_tss_block(value, True)

514 

515 

class ConversionError(Exception):

    """Raised when a value cannot be converted to an option's type.

    The message points out that the keyword may simply need special handling
    that is not implemented yet.
    """

    def __init__(self, key_type, attr, value):
        # Forward the arguments so that ``args`` is populated; unpickling an
        # exception calls ``cls(*args)``, which failed with empty args.
        super().__init__(key_type, attr, value)
        self.key_type = key_type
        self.value = value
        self.attr = attr

    def __str__(self):
        contact_email = 'simon.rittmeyer@tum.de'
        # Every fragment that interpolates must be an f-string; previously
        # the key type was printed literally as '{self.key_type}'.
        return (f'Could not convert {self.attr} = {self.value} '
                f'to {self.key_type}\n'
                'This means you either tried to set a value of the wrong\n'
                'type or this keyword needs some special care. Please feel\n'
                'free to add it to the corresponding __setattr__ method and '
                'send\n'
                f'the patch to {contact_email}, so we can all benefit.')