Coverage for /builds/debichem-team/python-ase/ase/formula.py: 91.42%

268 statements  

« prev     ^ index     » next       coverage.py v7.5.3, created at 2025-03-06 04:00 +0000

1import re 

2from functools import lru_cache 

3from math import gcd 

4from typing import Dict, List, Sequence, Tuple, Union 

5 

6from ase.data import atomic_numbers, chemical_symbols 

7 

# For type hints (A, A2, A+B):
# a Tree is either a symbol string (``'A'``), a ``(subtree, count)`` pair
# (``('A', 2)`` for ``A2``), or a list of subtrees (``['A', 'B']`` for
# ``AB``).  parse() returns a list with one ``(subtree, count)`` pair per
# ``+``-separated term, e.g. parse('2A+BC2') == [('A', 2),
# (['B', ('C', 2)], 1)].
Tree = Union[str, Tuple['Tree', int], List['Tree']]

10 

11 

class Formula:
    def __init__(self,
                 formula: Union[str, 'Formula'] = '',
                 *,
                 strict: bool = False,
                 format: str = '',
                 _tree: Union[Tree, None] = None,
                 _count: Union[Dict[str, int], None] = None):
        """Chemical formula object.

        Parameters
        ----------
        formula: str
            Text string representation of formula.  Examples: ``'6CO2'``,
            ``'30Cu+2CO'``, ``'Pt(CO)6'``.
        strict: bool
            Only allow real chemical symbols.
        format: str
            Reorder according to *format*.  Must be one of hill, metal,
            abc, ab2, a2b, periodic or reduce.

        Examples
        --------
        >>> from ase.formula import Formula
        >>> w = Formula('H2O')
        >>> w.count()
        {'H': 2, 'O': 1}
        >>> 'H' in w
        True
        >>> w == 'HOH'
        True
        >>> f'{w:latex}'
        'H$_{2}$O'
        >>> w.format('latex')
        'H$_{2}$O'
        >>> divmod(6 * w + 'Cu', w)
        (6, Formula('Cu'))

        Raises
        ------
        ValueError
            on malformed formula, on an illegal *format*, or (with
            *strict*) on an unknown chemical symbol.
        """

        # Be sure that Formula(x) works the same whether x is string or Formula
        assert isinstance(formula, (str, Formula))
        formula = str(formula)

        if format:
            assert _tree is None and _count is None
            if format not in {'hill', 'metal', 'abc', 'reduce', 'ab2', 'a2b',
                              'periodic'}:
                raise ValueError(f'Illegal format: {format}')
            formula = Formula(formula).format(format)

        self._formula = formula

        # _tree/_count are a private fast path used by from_dict(),
        # from_list() and the formula_* helpers; an empty value falls back
        # to parsing/counting.
        self._tree = _tree or parse(formula)
        self._count = _count or count_tree(self._tree)
        if strict:
            for symbol in self._count:
                if symbol not in atomic_numbers:
                    raise ValueError('Unknown chemical symbol: ' + symbol)

    def convert(self, fmt: str) -> 'Formula':
        """Reformat this formula as a new Formula.

        Same formatting rules as Formula(format=...) keyword.
        """
        return Formula(self._formula, format=fmt)

    def count(self) -> Dict[str, int]:
        """Return dictionary mapping chemical symbol to number of atoms.

        A copy is returned, so the caller may modify it freely.

        Example
        -------
        >>> Formula('H2O').count()
        {'H': 2, 'O': 1}
        """
        return self._count.copy()

    def reduce(self) -> Tuple['Formula', int]:
        """Reduce formula.

        Returns
        -------
        formula: Formula
            Reduced formula.
        n: int
            Number of reduced formula units.

        Example
        -------
        >>> Formula('2H2O').reduce()
        (Formula('H2O'), 2)
        """
        dct, N = self._reduce()
        return self.from_dict(dct), N

    def stoichiometry(self) -> Tuple['Formula', 'Formula', int]:
        """Reduce to unique stoichiometry using "chemical symbols" A, B, C, ...

        Symbols are ranked by (reduced count, symbol) and renamed A, B, C, ...
        in that order.

        Returns the generic formula, the reduced formula in the same order
        and the number of reduced formula units.

        Examples
        --------
        >>> Formula('CO2').stoichiometry()
        (Formula('AB2'), Formula('CO2'), 1)
        >>> Formula('(H2O)4').stoichiometry()
        (Formula('AB2'), Formula('OH2'), 4)
        """
        count1, N = self._reduce()
        c = ord('A')
        count2 = {}
        count3 = {}
        for n, symb in sorted((n, symb)
                              for symb, n in count1.items()):
            count2[chr(c)] = n
            count3[symb] = n
            c += 1
        return self.from_dict(count2), self.from_dict(count3), N

    def format(self, fmt: str = '') -> str:
        """Format formula as string.

        Formats:

        * ``'hill'``: alphabetically ordered with C and H first
        * ``'metal'``: alphabetically ordered with metals first
        * ``'ab2'``: count-ordered first then alphabetically ordered
        * ``'abc'``: old name for ``'ab2'``
        * ``'a2b'``: reverse count-ordered first then alphabetically ordered
        * ``'periodic'``: periodic-table ordered: period first then group
        * ``'reduce'``: Reduce and keep order (ABBBC -> AB3C)
        * ``'latex'``: LaTeX representation
        * ``'html'``: HTML representation
        * ``'rest'``: reStructuredText representation

        Example
        -------
        >>> Formula('H2O').format('html')
        'H<sub>2</sub>O'
        """
        return format(self, fmt)

    def __format__(self, fmt: str) -> str:
        """Format Formula as str.

        Possible formats: ``'hill'``, ``'metal'``, ``'abc'``, ``'ab2'``,
        ``'a2b'``, ``'periodic'``, ``'reduce'``, ``'latex'``, ``'html'``,
        ``'rest'`` and ``''`` (the original string).  See :meth:`format`
        for a description of each.

        Example
        -------
        >>> f = Formula('OH2')
        >>> '{f}, {f:hill}, {f:latex}'.format(f=f)
        'OH2, H2O, OH$_{2}$'

        Raises
        ------
        ValueError
            For an unknown format specifier.
        """

        if fmt == 'hill':
            # NOTE: H is moved to the front even when there is no C;
            # strict Hill notation would sort it alphabetically then.
            count = self.count()
            count2 = {symb: count.pop(symb) for symb in 'CH' if symb in count}
            for symb, n in sorted(count.items()):
                count2[symb] = n
            return dict2str(count2)

        if fmt == 'metal':
            # Metals alphabetically, then the non-metals alphabetically.
            count = self.count()
            result2 = [(s, count.pop(s)) for s in non_metals if s in count]
            result = [(s, count[s]) for s in sorted(count)]
            result += sorted(result2)
            return dict2str(dict(result))

        if fmt == 'abc' or fmt == 'ab2':
            _, f, N = self.stoichiometry()
            return dict2str({symb: n * N for symb, n in f._count.items()})

        if fmt == 'a2b':
            _, f, N = self.stoichiometry()
            # Sort on negated counts so the largest count comes first;
            # ties are broken alphabetically.
            return dict2str({symb: -n * N
                             for n, symb
                             in sorted([(-n, symb) for symb, n
                                        in f._count.items()])})

        if fmt == 'periodic':
            count = self.count()
            order = periodic_table_order()
            items = sorted(count.items(),
                           key=lambda item: order.get(item[0], 0))
            return dict2str(dict(items))

        if fmt == 'reduce':
            # Collapse runs of identical consecutive symbols: ABBBC -> AB3C.
            symbols = list(self)
            nsymb = len(symbols)
            parts = []
            i1 = 0
            for i2, symbol in enumerate(symbols):
                if i2 == nsymb - 1 or symbol != symbols[i2 + 1]:
                    parts.append(symbol)
                    m = i2 + 1 - i1
                    if m > 1:
                        parts.append(str(m))
                    i1 = i2 + 1
            return ''.join(parts)

        if fmt == 'latex':
            return self._tostr('$_{', '}$')

        if fmt == 'html':
            return self._tostr('<sub>', '</sub>')

        if fmt == 'rest':
            return self._tostr(r'\ :sub:`', r'`\ ')

        if fmt == '':
            return self._formula

        raise ValueError('Invalid format specifier')

    @staticmethod
    def from_dict(dct: Dict[str, int]) -> 'Formula':
        """Convert dict to Formula.

        Symbols with a count of zero are dropped.

        >>> Formula.from_dict({'H': 2})
        Formula('H2')

        Raises
        ------
        ValueError
            If a key is not a str or a value is not a non-negative int.
        """
        dct2 = {}
        for symb, n in dct.items():
            if not (isinstance(symb, str) and isinstance(n, int) and n >= 0):
                raise ValueError(f'Bad dictionary: {dct}')
            if n > 0:  # filter out n=0 symbols
                dct2[symb] = n
        return Formula(dict2str(dct2),
                       _tree=[([(symb, n) for symb, n in dct2.items()], 1)],
                       _count=dct2)

    @staticmethod
    def from_list(symbols: Sequence[str]) -> 'Formula':
        """Convert list of chemical symbols to Formula.

        Any sequence (list, tuple, ...) of symbols is accepted.

        >>> Formula.from_list(('H', 'H', 'O'))
        Formula('HHO')
        """
        # Always store a *list*: inside a tree, tuples mean
        # (subtree, count) pairs, so a tuple of symbols would be misread.
        symbols = list(symbols)
        return Formula(''.join(symbols),
                       _tree=[(symbols, 1)])  # type: ignore[list-item]

    def __len__(self) -> int:
        """Number of atoms."""
        return sum(self._count.values())

    def __getitem__(self, symb: str) -> int:
        """Number of atoms with chemical symbol *symb* (0 if absent)."""
        return self._count.get(symb, 0)

    def __contains__(self, f: Union[str, 'Formula']) -> bool:
        """Check if formula contains chemical symbols in *f*.

        Type of *f* must be str or Formula.

        Examples
        --------
        >>> 'OH' in Formula('H2O')
        True
        >>> 'O2' in Formula('H2O')
        False
        """
        if isinstance(f, str):
            f = Formula(f)
        for symb, n in f._count.items():
            if self[symb] < n:
                return False
        return True

    def __eq__(self, other) -> bool:
        """Equality check.

        Note that order is not important.

        Example
        -------
        >>> Formula('CO') == Formula('OC')
        True
        """
        if isinstance(other, str):
            other = Formula(other)
        elif not isinstance(other, Formula):
            return False
        return self._count == other._count

    def __add__(self, other: Union[str, 'Formula']) -> 'Formula':
        """Add two formulas."""
        if not isinstance(other, str):
            other = other._formula
        return Formula(self._formula + '+' + other)

    def __radd__(self, other: str):  # -> Formula
        return Formula(other) + self

    def __mul__(self, N: int) -> 'Formula':
        """Repeat formula `N` times."""
        if N == 0:
            return Formula('')
        return self.from_dict({symb: n * N
                               for symb, n in self._count.items()})

    def __rmul__(self, N: int):  # -> Formula
        return self * N

    def __divmod__(self,
                   other: Union['Formula', str]) -> Tuple[int, 'Formula']:
        """Return the tuple (self // other, self % other).

        Invariant::

            div, mod = divmod(self, other)
            div * other + mod == self

        Example
        -------
        >>> divmod(Formula('H2O'), 'H')
        (2, Formula('O'))

        Raises
        ------
        ValueError
            If *other* is an empty formula.
        """
        if isinstance(other, str):
            other = Formula(other)
        if not other._count:
            # Previously surfaced as "min() arg is an empty sequence".
            raise ValueError(f'Cannot divide {self!r} by an empty formula')
        N = min(self[symb] // n for symb, n in other._count.items())
        dct = self.count()
        if N:
            for symb, n in other._count.items():
                dct[symb] -= n * N
                if dct[symb] == 0:
                    del dct[symb]
        return N, self.from_dict(dct)

    def __rdivmod__(self, other):
        return divmod(Formula(other), self)

    def __mod__(self, other):
        return divmod(self, other)[1]

    def __rmod__(self, other):
        return Formula(other) % self

    def __floordiv__(self, other):
        return divmod(self, other)[0]

    def __rfloordiv__(self, other):
        return Formula(other) // self

    def __iter__(self):
        """Iterate over chemical symbols, expanding every count."""
        return self._tree_iter()

    def _tree_iter(self, tree=None):
        # Recursive generator: str -> one symbol, (subtree, N) -> subtree
        # repeated N times, list -> each subtree in order.
        if tree is None:
            tree = self._tree
        if isinstance(tree, str):
            yield tree
        elif isinstance(tree, tuple):
            subtree, N = tree
            for _ in range(N):
                yield from self._tree_iter(subtree)
        else:
            for subtree in tree:
                yield from self._tree_iter(subtree)

    def __str__(self):
        return self._formula

    def __repr__(self):
        return f'Formula({self._formula!r})'

    def _reduce(self):
        # Greatest common divisor of all counts (0 for an empty formula,
        # in which case the returned dict is empty too).
        N = 0
        for n in self._count.values():
            if N == 0:
                N = n
            else:
                N = gcd(n, N)
        dct = {symb: n // N for symb, n in self._count.items()}
        return dct, N

    def _tostr(self, sub1, sub2):
        # Render each top-level '+' term with tree2str(), dropping the
        # outer parentheses and prefixing the term multiplier when > 1.
        parts = []
        for tree, n in self._tree:
            s = tree2str(tree, sub1, sub2)
            # startswith/endswith instead of s[0]/s[-1]: a term such as
            # '()' renders as '' and must not raise IndexError.
            if s.startswith('(') and s.endswith(')'):
                s = s[1:-1]
            if n > 1:
                s = str(n) + s
            parts.append(s)
        return '+'.join(parts)

396 

397 

def dict2str(dct: Dict[str, int]) -> str:
    """Join symbols and their counts into a formula string.

    Symbols are emitted in the dictionary's iteration order and a count is
    written only when it is greater than one.

    >>> dict2str({'A': 1, 'B': 2})
    'AB2'
    """
    pieces = []
    for symbol, number in dct.items():
        pieces.append(symbol + str(number) if number > 1 else symbol + '')
    return ''.join(pieces)

406 

407 

def parse(f: str) -> Tree:
    """Parse a complete formula string, including ``+``-separated terms.

    Every term may start with a multiplier.  The result is a list holding
    one ``(subtree, multiplier)`` pair per term; an empty string gives an
    empty list.

    >>> parse('2A+BC2')
    [('A', 2), (['B', ('C', 2)], 1)]
    """
    if not f:
        return []
    # Lazily split each term so strip_number/parse2 run term by term.
    split_terms = (strip_number(term) for term in f.split('+'))
    tree = [(parse2(body), multiplier) for multiplier, body in split_terms]
    return tree  # type: ignore[return-value]

422 

423 

def parse2(f: str) -> Tree:
    """Convert formula string to tree structure (no "+" symbols).

    The string is consumed left to right.  Each unit is either a
    parenthesised group followed by an optional repeat count, or a symbol
    (one upper-case letter, optionally one lower-case letter) followed by
    an optional count.  A single unit is returned on its own; several
    units are returned as a list; an empty string gives an empty list.

    >>> parse2('H2O')
    [('H', 2), 'O']
    >>> parse2('(H2O)10')
    ([('H', 2), 'O'], 10)

    Raises
    ------
    ValueError
        If a parenthesis is not closed or a fragment cannot be parsed as
        a symbol.
    """
    text = f  # kept for error messages
    units = []
    while f:
        unit: Union[str, Tuple[str, int], Tree]
        if f[0] == '(':
            # Find the matching ')' while tracking nesting depth.
            level = 0
            for i, c in enumerate(f[1:], 1):
                if c == '(':
                    level += 1
                elif c == ')':
                    if level == 0:
                        break
                    level -= 1
            else:
                raise ValueError(f'Unbalanced parenthesis in formula: '
                                 f'{text!r}')
            f2 = f[1:i]
            n, f = strip_number(f[i + 1:])
            unit = (parse2(f2), n)
        else:
            m = re.match(r'([A-Z][a-z]?)([0-9]*)', f)
            if m is None:
                raise ValueError(f'Cannot parse {f!r} in formula {text!r}')
            symb = m.group(1)
            number = m.group(2)
            if number:
                unit = (symb, int(number))
            else:
                unit = symb
            f = f[m.end():]
        units.append(unit)
    if len(units) == 1:
        return units[0]
    return units

462 

463 

def strip_number(s: str) -> Tuple[int, str]:
    """Split a leading decimal multiplier off *s*.

    Only ASCII digits are recognised.  When *s* does not start with a
    number the multiplier is 1.

    >>> strip_number('10AB2')
    (10, 'AB2')
    >>> strip_number('AB2')
    (1, 'AB2')
    """
    ndigits = len(s) - len(s.lstrip('0123456789'))
    digits, rest = s[:ndigits], s[ndigits:]
    multiplier = int(digits) if digits else 1
    return multiplier, rest

475 

476 

def tree2str(tree: Tree,
             sub1: str, sub2: str) -> str:
    """Helper function for html, latex and rest formats.

    A symbol is written as-is.  A list of subtrees is concatenated inside
    parentheses.  A ``(subtree, count)`` pair appends
    ``sub1 + count + sub2``, except that for a count of 1 redundant outer
    parentheses are removed instead.

    >>> tree2str([('H', 2), 'O'], '<sub>', '</sub>')
    '(H<sub>2</sub>O)'
    """
    if isinstance(tree, str):
        return tree
    if isinstance(tree, tuple):
        subtree, N = tree
        s = tree2str(subtree, sub1, sub2)
        if N == 1:
            # startswith/endswith instead of s[0]/s[-1]: an empty group
            # such as '(())' renders as '' and must not raise IndexError.
            if s.startswith('(') and s.endswith(')'):
                return s[1:-1]
            return s
        return s + sub1 + str(N) + sub2
    return '(' + ''.join(tree2str(child, sub1, sub2) for child in tree) + ')'

491 

492 

def count_tree(tree: Tree) -> Dict[str, int]:
    """Return a symbol -> total count mapping for a parse tree.

    Counts are multiplied through ``(subtree, count)`` pairs and summed
    across list entries.  Symbols keep first-appearance order and zero
    counts are kept.
    """
    if isinstance(tree, str):
        return {tree: 1}
    if isinstance(tree, tuple):
        subtree, multiplier = tree
        inner = count_tree(subtree)
        return {symbol: number * multiplier
                for symbol, number in inner.items()}
    totals: Dict[str, int] = {}
    for child in tree:
        for symbol, number in count_tree(child).items():
            totals[symbol] = totals.get(symbol, 0) + number
    return totals

505 

506 

# Non-metals, half-metals/metalloids, halogens and noble gases, listed
# period by period.  The 'metal' format puts these symbols last.
non_metals = (['H', 'He'] +
              ['B', 'C', 'N', 'O', 'F', 'Ne'] +
              ['Si', 'P', 'S', 'Cl', 'Ar'] +
              ['Ge', 'As', 'Se', 'Br', 'Kr'] +
              ['Sb', 'Te', 'I', 'Xe'] +
              ['Po', 'At', 'Rn'])

513 

514 

@lru_cache
def periodic_table_order() -> Dict[str, int]:
    """Return a symbol -> sort-key mapping used by the 'periodic' format.

    Symbols are ranked period by period starting from the heaviest period
    (atomic number 87 onwards) down to the first (H, He).  Within a period
    they keep increasing-atomic-number order.  The dictionary is cached, so
    callers must not modify it.
    """
    # (start, stop) slices of chemical_symbols, heaviest period first.
    period_slices = [(87, None), (55, 87), (37, 55), (19, 37),
                     (11, 19), (3, 11), (1, 3)]
    ordered = []
    for start, stop in period_slices:
        ordered.extend(chemical_symbols[start:stop])
    return {symbol: rank for rank, symbol in enumerate(ordered)}

525 

526 

# Backwards compatibility:
def formula_hill(numbers, empirical=False):
    """Return a Hill-ordered formula string for a list of atomic numbers.

    C and H (when present) come first, followed by the remaining elements
    in alphabetical order.  With *empirical* true, all counts are first
    divided by their greatest common divisor (e.g. ``C2H4`` -> ``CH2``).
    """
    formula = Formula('', _tree=[([chemical_symbols[z] for z in numbers], 1)])
    if empirical:
        formula = formula.reduce()[0]
    return formula.format('hill')

540 

541 

# Backwards compatibility:
def formula_metal(numbers, empirical=False):
    """Return a metals-first formula string for a list of atomic numbers.

    Metals come first in alphabetical order, followed by the non-metals in
    alphabetical order.  With *empirical* true, all counts are first
    divided by their greatest common divisor.
    """
    formula = Formula('', _tree=[([chemical_symbols[z] for z in numbers], 1)])
    if empirical:
        formula = formula.reduce()[0]
    return formula.format('metal')