Coverage for /builds/debichem-team/python-ase/ase/io/cif.py: 90.84%

491 statements  

« prev     ^ index     » next       coverage.py v7.5.3, created at 2025-03-06 04:00 +0000

1"""Module to read and write atoms in cif file format. 

2 

3See http://www.iucr.org/resources/cif/spec/version1.1/cifsyntax for a 

4description of the file format. STAR extensions such as save frames, 

5global blocks, nested loops and multi-data values are not supported. 

6The "latin-1" encoding is required by the IUCR specification. 

7""" 

8 

9import collections.abc 

10import io 

11import re 

12import shlex 

13import warnings 

14from typing import Any, Dict, Iterator, List, Optional, Sequence, Tuple, Union 

15 

16import numpy as np 

17 

18from ase import Atoms 

19from ase.cell import Cell 

20from ase.io.cif_unicode import format_unicode, handle_subscripts 

21from ase.spacegroup import crystal 

22from ase.spacegroup.spacegroup import Spacegroup, spacegroup_from_data 

23from ase.utils import iofunction 

24 

# ITA numbers of the seven space groups with a rhombohedral (R) lattice;
# these can be described in a hexagonal or a rhombohedral setting.
rhombohedral_spacegroups = {146, 148, 155, 160, 161, 166, 167}


# Superseded Hermann-Mauguin symbols mapped to their current "e"-glide
# equivalents (the 'Ccca' entry previously mapped to the non-existent
# symbol 'Ccc1').
old_spacegroup_names = {'Abm2': 'Aem2',
                        'Aba2': 'Aea2',
                        'Cmca': 'Cmce',
                        'Cmma': 'Cmme',
                        'Ccca': 'Ccce'}

# CIF maps names to either single values or to multiple values via loops.
CIFDataValue = Union[str, int, float]
CIFData = Union[CIFDataValue, List[CIFDataValue]]

37 

38 

def convert_value(value: str) -> CIFDataValue:
    """Convert CIF value string to corresponding python type.

    The value is stripped of surrounding whitespace, then:

    * a value fully enclosed in matching double or single quotes has the
      quotes removed;
    * an integer literal becomes ``int``;
    * a decimal/exponent literal becomes ``float``, and a trailing
      standard uncertainty such as ``1.234(5)`` is discarded (an unclosed
      ``1.234(5`` is accepted with a warning);
    * anything else is kept as a string.

    All string results pass through ``handle_subscripts``.
    """
    value = value.strip()
    # The alternation is grouped so that ``$`` anchors *both* quote styles;
    # otherwise e.g. ``"abc" def`` would be mistaken for a quoted string.
    if re.match('(".*"|\'.*\')$', value):
        return handle_subscripts(value[1:-1])
    elif re.match(r'[+-]?\d+$', value):
        return int(value)
    elif re.match(r'[+-]?(?:\d+(?:\.\d*)?|\.\d+)(?:[eE][+-]?\d+)?$', value):
        return float(value)
    elif re.match(r'[+-]?(?:\d+(?:\.\d*)?|\.\d+)(?:[eE][+-]?\d+)?\(\d+\)$',
                  value):
        return float(value[:value.index('(')])  # strip off uncertainties
    elif re.match(r'[+-]?(?:\d+(?:\.\d*)?|\.\d+)(?:[eE][+-]?\d+)?\(\d+$',
                  value):
        warnings.warn(f'Badly formed number: "{value}"')
        return float(value[:value.index('(')])  # strip off uncertainties
    else:
        return handle_subscripts(value)

57 

58 

def parse_multiline_string(lines: List[str], line: str) -> str:
    """Return the text of a semicolon-delimited CIF text field.

    ``line`` is the opening line (starting with ``;``); further lines are
    popped from the end of ``lines`` up to and including the closing line
    that starts with ``;``.  Each line is stripped, the lines are joined
    with newlines and the result is stripped again.
    """
    assert line[0] == ';'
    collected = [line[1:].lstrip()]
    current = lines.pop().strip()
    while not current.startswith(';'):
        collected.append(current)
        current = lines.pop().strip()
    return '\n'.join(collected).strip()

69 

70 

def parse_singletag(lines: List[str], line: str) -> Tuple[str, CIFDataValue]:
    """Parse a single (non-loop) CIF tag and return ``(key, value)``.

    If ``line`` holds only the tag name, the value is taken from the next
    line in ``lines`` that is neither blank nor a comment; that value may
    be a semicolon-delimited text field.  The value is converted with
    ``convert_value``.
    """
    parts = line.split(None, 1)
    if len(parts) != 1:
        key, value = parts
        return key, convert_value(value)

    key = line
    value_line = lines.pop().strip()
    while not value_line or value_line.startswith('#'):
        value_line = lines.pop().strip()

    if value_line.startswith(';'):
        value = parse_multiline_string(lines, value_line)
    else:
        value = value_line
    return key, convert_value(value)

87 

88 

def parse_cif_loop_headers(lines: List[str]) -> Iterator[str]:
    """Yield the lower-cased tag names that head a ``loop_``.

    Lines are popped from the end of ``lines`` while each consists of a
    single token starting with ``_``.  The first line that does not match
    is pushed back onto ``lines`` and iteration stops.
    """
    while lines:
        candidate = lines.pop()
        fields = candidate.split()
        is_header = len(fields) == 1 and fields[0].startswith('_')
        if not is_header:
            lines.append(candidate)  # 'undo' pop
            break
        yield fields[0].lower()

100 

101 

def parse_cif_loop_data(lines: List[str],
                        ncolumns: int) -> List[List[CIFDataValue]]:
    """Read the data rows of a CIF loop into ``ncolumns`` columns.

    Lines are popped from the end of ``lines`` until a blank line, a tag
    (``_...``), ``data_`` or ``loop_`` is reached; that line is pushed back.
    A row may span several lines: tokens accumulate until there are at
    least ``ncolumns`` of them.  A row with exactly ``ncolumns`` tokens is
    converted with ``convert_value`` and stored; a row with too many tokens
    is dropped with a warning.

    Raises RuntimeError if the input ends in the middle of a row.
    """
    columns: List[List[CIFDataValue]] = [[] for _ in range(ncolumns)]
    pending: List[str] = []

    while lines:
        text = lines.pop().strip()
        if (not text or text.startswith('_')
                or text.lower().startswith(('data_', 'loop_'))):
            # Not loop data any more: hand the line back to the caller.
            lines.append(text)
            break

        if text.startswith('#'):
            continue

        # Drop a trailing comment (naive split: quotes are not respected).
        text = text.split(' #')[0]

        if text.startswith(';'):
            pending.append(parse_multiline_string(lines, text))
        elif ncolumns == 1:
            # With one column the whole line is a single token.
            pending.append(text)
        else:
            pending.extend(shlex.split(text, posix=False))

        if len(pending) < ncolumns:
            # The row continues on the following line.
            continue

        if len(pending) == ncolumns:
            for column, token in zip(columns, pending):
                column.append(convert_value(token))
        else:
            warnings.warn(f'Wrong number {len(pending)} of tokens, '
                          f'expected {ncolumns}: {pending}')

        pending = []

    if pending:
        assert len(pending) < ncolumns
        raise RuntimeError('CIF loop ended unexpectedly with incomplete row: '
                           f'{pending}, expected {ncolumns} tokens')

    return columns

149 

150 

def parse_loop(lines: List[str]) -> Dict[str, List[CIFDataValue]]:
    """Parse the headers and data of a CIF ``loop_``.

    The ``loop_`` keyword itself must already have been consumed.  Returns
    a dict mapping each (lower-cased) header tag to its column of values.
    If a tag is repeated, only its first column is kept and a warning is
    issued.
    """
    headers = list(parse_cif_loop_headers(lines))
    # A dict cannot be used directly for the headers: tags may repeat.
    columns = parse_cif_loop_data(lines, len(headers))

    result = {}
    for header, column in zip(headers, columns):
        if header in result:
            warnings.warn(f'Duplicated loop tags: {header}')
            continue
        result[header] = column
    return result

167 

168 

def parse_items(lines: List[str], line: str) -> Dict[str, CIFData]:
    """Parse the data items of one CIF block and return a dict of all tags.

    Lines are popped from the end of ``lines`` until the input is exhausted
    or the next ``data_`` line is found (which is pushed back for the
    caller).  Single tags and loops are collected with lower-cased keys.
    Free-standing semicolon text fields are read and discarded.

    ``line`` (the block's ``data_`` header) is not used; the parameter is
    kept for interface compatibility.

    Raises ValueError for a line that is none of the above.
    """
    tags: Dict[str, CIFData] = {}

    while lines:
        line = lines.pop().strip()
        if not line or line.startswith('#'):
            continue

        lowerline = line.lower()
        if line.startswith('_'):
            key, value = parse_singletag(lines, line)
            tags[key.lower()] = value
        elif lowerline.startswith('loop_'):
            tags.update(parse_loop(lines))
        elif lowerline.startswith('data_'):
            # Start of the next block: leave it for the caller.
            lines.append(line)
            break
        elif line.startswith(';'):
            # Text field not attached to a tag; consumed and ignored.
            parse_multiline_string(lines, line)
        else:
            raise ValueError(f'Unexpected CIF file entry: "{line}"')

    return tags

196 

197 

class NoStructureData(RuntimeError):
    """Raised when a CIF block lacks atomic symbols or site coordinates."""

200 

201 

class CIFBlock(collections.abc.Mapping):
    """A block (i.e., a single system) in a crystallographic information file.

    Use this object to query CIF tags or import information as ASE objects.
    The block is a read-only mapping from tag name to parsed value.
    """

    # Cell parameter tags in cellpar order: a, b, c, alpha, beta, gamma.
    cell_tags = ['_cell_length_a', '_cell_length_b', '_cell_length_c',
                 '_cell_angle_alpha', '_cell_angle_beta', '_cell_angle_gamma']

    def __init__(self, name: str, tags: Dict[str, CIFData]):
        self.name = name
        self._tags = tags

    def __repr__(self) -> str:
        tags = set(self._tags)
        return f'CIFBlock({self.name}, tags={tags})'

    def __getitem__(self, key: str) -> CIFData:
        return self._tags[key]

    def __iter__(self) -> Iterator[str]:
        return iter(self._tags)

    def __len__(self) -> int:
        return len(self._tags)

    def get(self, key, default=None):
        """Return the value of tag *key*, or *default* if it is absent."""
        return self._tags.get(key, default)

    def get_cellpar(self) -> Optional[List]:
        """Return [a, b, c, alpha, beta, gamma], or None if any is missing."""
        try:
            return [self[tag] for tag in self.cell_tags]
        except KeyError:
            return None

    def get_cell(self) -> Cell:
        """Return the unit cell, or ``Cell.new([0, 0, 0])`` without one."""
        cellpar = self.get_cellpar()
        if cellpar is None:
            return Cell.new([0, 0, 0])
        return Cell.new(cellpar)

    def _raw_scaled_positions(self) -> Optional[np.ndarray]:
        """Return the fract_x/y/z columns transposed (one row per site when
        they come from a loop), or None if any of them is missing."""
        coords = [self.get(name) for name in ['_atom_site_fract_x',
                                              '_atom_site_fract_y',
                                              '_atom_site_fract_z']]
        # XXX Shall we try to handle mixed coordinates?
        # (Some scaled vs others fractional)
        if None in coords:
            return None
        return np.array(coords).T

    def _raw_positions(self) -> Optional[np.ndarray]:
        """Return the Cartn_x/y/z columns transposed, or None if any of
        them is missing."""
        coords = [self.get('_atom_site_cartn_x'),
                  self.get('_atom_site_cartn_y'),
                  self.get('_atom_site_cartn_z')]
        if None in coords:
            return None
        return np.array(coords).T

    def _get_site_coordinates(self):
        """Return ``('scaled', array)`` or ``('cartesian', array)``.

        Fractional coordinates take precedence.  Raises NoStructureData if
        neither kind of coordinates is present.
        """
        scaled = self._raw_scaled_positions()

        if scaled is not None:
            return 'scaled', scaled

        cartesian = self._raw_positions()

        if cartesian is None:
            raise NoStructureData('No positions found in structure')

        return 'cartesian', cartesian

    def _get_symbols_with_deuterium(self):
        """Return chemical symbols (keeping 'D') extracted from the site
        type symbols, or from the site labels if type symbols are absent.

        Raises NoStructureData if no labels exist or a label is '.'/'?',
        and ValueError if a label contains no recognisable element symbol.
        """
        labels = self._get_any(['_atom_site_type_symbol',
                                '_atom_site_label'])
        if labels is None:
            raise NoStructureData('No symbols')

        symbols = []
        for label in labels:
            if label == '.' or label == '?':
                raise NoStructureData('Symbols are undetermined')
            # Strip off additional labeling on chemical symbols.
            # str(): purely numeric labels are converted to int by the parser.
            match = re.search(r'([A-Z][a-z]?)', str(label))
            if match is None:
                raise ValueError(
                    f'Cannot determine chemical symbol from label {label!r}')
            symbol = match.group(0)
            symbols.append(symbol)
        return symbols

    def get_symbols(self) -> List[str]:
        """Return chemical symbols with deuterium ('D') mapped to 'H'."""
        symbols = self._get_symbols_with_deuterium()
        return [symbol if symbol != 'D' else 'H' for symbol in symbols]

    def _where_deuterium(self):
        """Return a boolean mask that is True for deuterium sites."""
        return np.array([symbol == 'D' for symbol
                         in self._get_symbols_with_deuterium()], bool)

    def _get_masses(self) -> Optional[np.ndarray]:
        """Return per-site masses if any site is deuterium, else None."""
        mask = self._where_deuterium()
        if not any(mask):
            return None

        symbols = self.get_symbols()
        masses = Atoms(symbols).get_masses()
        # NOTE(review): 2.01355 u is close to the deuteron (nuclear) mass;
        # the atomic mass of deuterium is ~2.01410 u -- confirm intent.
        masses[mask] = 2.01355
        return masses

    def _get_any(self, names):
        """Return the value of the first tag in *names* present, or None."""
        for name in names:
            if name in self:
                return self[name]
        return None

    def _get_spacegroup_number(self):
        # Symmetry specification, see
        # http://www.iucr.org/resources/cif/dictionaries/cif_sym for a
        # complete list of official keys. In addition we also try to
        # support some commonly used deprecated notations
        return self._get_any(['_space_group.it_number',
                              '_space_group_it_number',
                              '_symmetry_int_tables_number'])

    def _get_spacegroup_name(self):
        """Return the Hermann-Mauguin symbol, with superseded symbols
        translated via ``old_spacegroup_names``, or None."""
        # The 'ase' reader lower-cases tags; the mixed-case key can only
        # match tags from readers that keep the original case (presumably
        # pycodcif -- verify).
        hm_symbol = self._get_any(['_space_group_name_h-m_alt',
                                   '_symmetry_space_group_name_h-m',
                                   '_space_group.Patterson_name_h-m',
                                   '_space_group.patterson_name_h-m'])

        hm_symbol = old_spacegroup_names.get(hm_symbol, hm_symbol)
        return hm_symbol

    def _get_sitesym(self):
        """Return the symmetry operations as a list of strings, or None."""
        sitesym = self._get_any(['_space_group_symop_operation_xyz',
                                 '_space_group_symop.operation_xyz',
                                 '_symmetry_equiv_pos_as_xyz'])
        if isinstance(sitesym, str):
            # A single non-loop value.
            sitesym = [sitesym]
        return sitesym

    def _get_fractional_occupancies(self):
        """Return the ``_atom_site_occupancy`` values, or None."""
        return self.get('_atom_site_occupancy')

    def _get_setting(self) -> Optional[int]:
        """Return ``_symmetry_space_group_setting`` as 1 or 2, or None.

        Raises ValueError for any other setting value.
        """
        setting_str = self.get('_symmetry_space_group_setting')
        if setting_str is None:
            return None

        setting = int(setting_str)
        if setting not in [1, 2]:
            raise ValueError(
                f'Spacegroup setting must be 1 or 2, not {setting}')
        return setting

    def get_spacegroup(self, subtrans_included) -> Spacegroup:
        """Build the Spacegroup of this block.

        Symmetry operations take precedence over the space group number,
        which takes precedence over the H-M symbol; P 1 is the fallback.
        The setting comes from ``_symmetry_space_group_setting`` or, for
        rhombohedral groups, from the crystal-system tags.
        """
        # XXX The logic in this method needs serious cleaning up!
        no = self._get_spacegroup_number()
        if isinstance(no, str):
            # If the value was specified as "key 'value'" with ticks,
            # then "integer values" become strings and we'll have to
            # manually convert it:
            no = int(no)

        hm_symbol = self._get_spacegroup_name()
        sitesym = self._get_sitesym()

        if sitesym:
            # Special cases: sitesym can be None or an empty list.
            # The empty list could be replaced with just the identity
            # function, but it seems more correct to try to get the
            # spacegroup number and derive the symmetries for that.
            subtrans = [(0.0, 0.0, 0.0)] if subtrans_included else None

            spacegroup = spacegroup_from_data(
                no=no, symbol=hm_symbol, sitesym=sitesym,
                subtrans=subtrans,
                setting=1)  # should the setting be passed from somewhere?
        elif no is not None:
            spacegroup = no
        elif hm_symbol is not None:
            spacegroup = hm_symbol
        else:
            spacegroup = 1

        setting_std = self._get_setting()

        setting = 1
        setting_name = None
        if '_symmetry_space_group_setting' in self:
            assert setting_std is not None
            setting = setting_std
        elif '_space_group_crystal_system' in self:
            setting_name = self['_space_group_crystal_system']
        elif '_symmetry_cell_setting' in self:
            setting_name = self['_symmetry_cell_setting']

        if setting_name:
            # ``no`` is re-derived here and checked against ``spg`` below.
            no = Spacegroup(spacegroup).no
            if no in rhombohedral_spacegroups:
                if setting_name == 'hexagonal':
                    setting = 1
                elif setting_name in ('trigonal', 'rhombohedral'):
                    setting = 2
                else:
                    warnings.warn(
                        f'unexpected crystal system {setting_name!r} '
                        f'for space group {spacegroup!r}')
            # FIXME - check for more crystal systems...
            else:
                warnings.warn(
                    f'crystal system {setting_name!r} is not '
                    f'interpreted for space group {spacegroup!r}. '
                    'This may result in wrong setting!')

        spg = Spacegroup(spacegroup, setting)
        if no is not None:
            assert int(spg) == no, (int(spg), no)
        return spg

    def get_unsymmetrized_structure(self) -> Atoms:
        """Return Atoms without symmetrizing coordinates.

        This returns a (normally) unphysical Atoms object
        corresponding only to those coordinates included
        in the CIF file, useful for e.g. debugging.

        This method may change behaviour in the future."""
        symbols = self.get_symbols()
        coordtype, coords = self._get_site_coordinates()

        atoms = Atoms(symbols=symbols,
                      cell=self.get_cell(),
                      masses=self._get_masses())

        if coordtype == 'scaled':
            atoms.set_scaled_positions(coords)
        else:
            assert coordtype == 'cartesian'
            atoms.positions[:] = coords

        return atoms

    def has_structure(self):
        """Whether this CIF block has an atomic configuration."""
        try:
            self.get_symbols()
            self._get_site_coordinates()
        except NoStructureData:
            return False
        else:
            return True

    def get_atoms(self, store_tags=False, primitive_cell=False,
                  subtrans_included=True, fractional_occupancies=True) -> Atoms:
        """Returns an Atoms object from a cif tags dictionary. See read_cif()
        for a description of the arguments."""
        if primitive_cell and subtrans_included:
            raise RuntimeError(
                'Primitive cell cannot be determined when sublattice '
                'translations are included in the symmetry operations listed '
                'in the CIF file, i.e. when `subtrans_included` is True.')

        cell = self.get_cell()
        assert cell.rank in [0, 3]

        kwargs: Dict[str, Any] = {}
        if store_tags:
            kwargs['info'] = self._tags.copy()

        if fractional_occupancies:
            occupancies = self._get_fractional_occupancies()
        else:
            occupancies = None

        if occupancies is not None:
            # no warnings in this case
            kwargs['onduplicates'] = 'keep'

        # The unsymmetrized_structure is not the asymmetric unit
        # because the asymmetric unit should have (in general) a smaller cell,
        # whereas we have the full cell.
        unsymmetrized_structure = self.get_unsymmetrized_structure()

        if cell.rank == 3:
            spacegroup = self.get_spacegroup(subtrans_included)
            atoms = crystal(unsymmetrized_structure,
                            spacegroup=spacegroup,
                            setting=spacegroup.setting,
                            occupancies=occupancies,
                            primitive_cell=primitive_cell,
                            **kwargs)
        else:
            # Without a cell no symmetry is applied; info/occupancy are
            # attached by hand since crystal() is not called.
            atoms = unsymmetrized_structure
            if kwargs.get('info') is not None:
                atoms.info.update(kwargs['info'])
            if occupancies is not None:
                occ_dict = {
                    str(i): {sym: occupancies[i]}
                    for i, sym in enumerate(atoms.symbols)
                }
                atoms.info['occupancy'] = occ_dict

        return atoms

502 

503 

def parse_block(lines: List[str], line: str) -> CIFBlock:
    """Parse one CIF data block whose ``data_<name>`` header is ``line``.

    The block name is the text after the first underscore; the block's
    items are read from ``lines``.
    """
    assert line.lower().startswith('data_')
    return CIFBlock(line.split('_', 1)[1].rstrip(),
                    parse_items(lines, line))

509 

510 

def parse_cif(fileobj, reader='ase') -> Iterator[CIFBlock]:
    """Return an iterator of CIFBlock objects parsed from *fileobj*.

    *reader* is ``'ase'`` (built-in parser) or ``'pycodcif'``; anything
    else raises ValueError.
    """
    if reader == 'ase':
        parser = parse_cif_ase
    elif reader == 'pycodcif':
        parser = parse_cif_pycodcif
    else:
        raise ValueError(f'No such reader: {reader}')
    return parser(fileobj)

518 

519 

def parse_cif_ase(fileobj) -> Iterator[CIFBlock]:
    """Parse a CIF file using ase CIF parser.

    *fileobj* is a path (read in binary mode) or an object with ``read()``
    returning bytes or str.  Bytes are decoded as latin-1.  Yields one
    CIFBlock per data block; blank and comment lines between blocks are
    skipped.
    """
    if isinstance(fileobj, str):
        with open(fileobj, 'rb') as handle:
            raw = handle.read()
    else:
        raw = fileobj.read()

    text = raw.decode('latin1') if isinstance(raw, bytes) else raw
    text = format_unicode(text)
    nonempty = [row for row in text.split('\n') if row]

    if nonempty and nonempty[0].rstrip() == '#\\#CIF_2.0':
        warnings.warn('CIF v2.0 file format detected; `ase` CIF reader might '
                      'incorrectly interpret some syntax constructions, use '
                      '`pycodcif` reader instead')

    # Reversed so that pop() returns lines in file order.
    stack = [''] + nonempty[::-1]

    while stack:
        current = stack.pop().strip()
        if current and not current.startswith('#'):
            yield parse_block(stack, current)

545 

546 

def parse_cif_pycodcif(fileobj) -> Iterator[CIFBlock]:
    """Parse a CIF file using pycodcif CIF parser.

    *fileobj* is a path or a file object; for a file object its ``name``
    attribute is passed to pycodcif.  Every value list from pycodcif is
    converted with ``convert_value``; single-element lists are unwrapped to
    a scalar.  Yields one CIFBlock per data block.

    Raises ImportError (chained to the original) if pycodcif is missing.
    """
    if not isinstance(fileobj, str):
        fileobj = fileobj.name

    try:
        from pycodcif import parse
    except ImportError as err:
        raise ImportError(
            'parse_cif_pycodcif requires pycodcif ' +
            '(http://wiki.crystallography.net/cod-tools/pycodcif/)') from err

    data, _, _ = parse(fileobj)

    for datablock in data:
        tags = datablock['values']
        # Only values are replaced, so mutating while iterating is safe.
        for tag in tags.keys():
            values = [convert_value(x) for x in tags[tag]]
            if len(values) == 1:
                tags[tag] = values[0]
            else:
                tags[tag] = values
        yield CIFBlock(datablock['name'], tags)

570 

571 

def iread_cif(
    fileobj,
    index=-1,
    store_tags: bool = False,
    primitive_cell: bool = False,
    subtrans_included: bool = True,
    fractional_occupancies: bool = True,
    reader: str = 'ase',
) -> Iterator[Atoms]:
    """Yield Atoms for the CIF blocks selected by *index*.

    All blocks that have a structure are converted first; *index* (an int,
    a slice, ``None`` or ``':'``) then selects among them.  See read_cif()
    for the other arguments.
    """
    # TODO: return Atoms of the block name ``index`` if it is a string.
    images = [
        block.get_atoms(store_tags, primitive_cell, subtrans_included,
                        fractional_occupancies=fractional_occupancies)
        for block in parse_cif(fileobj, reader)
        if block.has_structure()
    ]

    if index is None or index == ':':
        index = slice(None, None, None)
    elif not isinstance(index, (slice, str)):
        # A single integer becomes a one-element slice; -1 -> slice(-1, None).
        index = slice(index, (index + 1) or None)

    yield from images[index]

602 

603 

def read_cif(
    fileobj,
    index=-1,
    *,
    store_tags: bool = False,
    primitive_cell: bool = False,
    subtrans_included: bool = True,
    fractional_occupancies: bool = True,
    reader: str = 'ase',
) -> Union[Atoms, List[Atoms]]:
    """Read Atoms object from CIF file.

    Parameters
    ----------
    fileobj : str or file
        Path of the CIF file, or an open file object.
    index : int, slice or str
        Selects the structure(s) to return (default: the last one). A slice
        or ``':'`` returns a list; an integer returns a single Atoms object.
    store_tags : bool
        If true, the *info* attribute of the returned Atoms object will be
        populated with all tags in the corresponding cif data block.
    primitive_cell : bool
        If true, the primitive cell is built instead of the conventional cell.
    subtrans_included : bool
        If true, sublattice translations are assumed to be included among the
        symmetry operations listed in the CIF file (seems to be the common
        behaviour of CIF files).
        Otherwise the sublattice translations are determined from setting 1 of
        the extracted space group. As a consequence of setting this flag to
        true, it is not possible to determine the primitive cell.
    fractional_occupancies : bool
        If true, the resulting atoms object will be equipped with a
        dictionary `occupancy`. The keys of this dictionary will be integers
        converted to strings. The conversion to string is done in order to
        avoid troubles with JSON encoding/decoding of the dictionaries with
        non-string keys.
        Also, in case of mixed occupancies, the atom's chemical symbol will be
        that of the most dominant species.
    reader : str
        Select CIF reader.

        * ``ase`` : built-in CIF reader (default)
        * ``pycodcif`` : CIF reader based on ``pycodcif`` package

    Notes
    -----
    Only blocks with valid crystal data will be included. When a single
    structure is requested and none is found, ``StopIteration`` propagates.
    """
    images = iread_cif(fileobj, index, store_tags, primitive_cell,
                       subtrans_included, fractional_occupancies, reader)
    if not isinstance(index, (slice, str)):
        # Return single atoms object
        return next(images)
    # Return list of atoms
    return list(images)

663 

664 

def format_cell(cell: Cell) -> str:
    """Return the six ``_cell_*`` lines for a full-rank *cell*."""
    assert cell.rank == 3
    rows = [f'{tag:20} {param}\n'
            for tag, param in zip(CIFBlock.cell_tags, cell.cellpar())]
    assert len(rows) == 6
    return ''.join(rows)

673 

674 

def format_generic_spacegroup_info() -> str:
    """Return CIF text declaring space group P 1 with the identity operation.

    We assume no symmetry whatsoever.
    """
    header = ('_space_group_name_H-M_alt "P 1"\n'
              '_space_group_IT_number 1\n')
    symops = ('loop_\n'
              ' _space_group_symop_operation_xyz\n'
              " 'x, y, z'\n")
    return header + '\n' + symops

686 

687 

class CIFLoop:
    """Collect the columns of a CIF ``loop_`` and render them as text."""

    def __init__(self):
        self.names = []
        self.formats = []
        self.arrays = []

    def add(self, name, array, fmt):
        """Append a column to the loop.

        name: tag name; must start with ``_``.
        array: column values, one per row; its length must equal that of
            the first column.
        fmt: ``str.format`` template used for one value of this column.

        Raises ValueError on a length mismatch, leaving the loop unchanged.
        """
        assert name.startswith('_')
        # Validate before mutating so a rejected column leaves no trace.
        if self.arrays and len(array) != len(self.arrays[0]):
            raise ValueError(f'Loop data "{name}" has {len(array)} '
                             f'elements, expected {len(self.arrays[0])}')
        self.names.append(name)
        self.formats.append(fmt)
        self.arrays.append(array)

    def tostring(self):
        """Return the loop as CIF text: ``loop_``, the header names, then
        one formatted line per row, ending with a newline."""
        lines = []
        append = lines.append
        append('loop_')
        for name in self.names:
            append(f' {name}')

        template = ' ' + ' '.join(self.formats)

        ncolumns = len(self.arrays)
        nrows = len(self.arrays[0]) if ncolumns > 0 else 0
        for row in range(nrows):
            arraydata = [array[row] for array in self.arrays]
            line = template.format(*arraydata)
            append(line)
        append('')
        return '\n'.join(lines)

720 

721 

@iofunction('wb')
def write_cif(fd, images, cif_format=None,
              wrap=True, labels=None, loop_keys=None) -> None:
    r"""Write *images* to CIF file.

    images: Atoms or list of Atoms
        A single object with ``get_positions`` or a sequence of them; each
        image is written as its own ``data_image<i>`` block.

    cif_format: deprecated
        Only triggers a FutureWarning when given; otherwise unused.

    wrap: bool
        Wrap atoms into unit cell.

    labels: list
        Use this list (shaped list[i_frame][i_atom] = string) for the
        '_atom_site_label' section instead of automatically generating
        it from the element symbol.

    loop_keys: dict
        Add the information from this dictionary to the `loop\_`
        section. Keys are printed to the `loop\_` section preceded by
        ' _'. dict[key] should contain the data printed for each atom,
        so it needs to have the setup `dict[key][i_frame][i_atom] =
        string`. The strings are printed as they are, so take care of
        formatting. Information can be re-read using the `store_tags`
        option of the cif reader.

    """
    if cif_format is not None:
        warnings.warn('The cif_format argument is deprecated and may be '
                      'removed in the future. Use loop_keys to customize '
                      'data written in loop.', FutureWarning)

    loop_keys = {} if loop_keys is None else loop_keys

    if hasattr(images, 'get_positions'):
        images = [images]

    text_fd = io.TextIOWrapper(fd, encoding='latin-1')
    try:
        for frame, atoms in enumerate(images):
            frame_keys = {key: loop_keys[key][frame] for key in loop_keys}
            frame_labels = labels[frame] if labels is not None else None
            write_cif_image(f'data_image{frame}\n', atoms, text_fd,
                            wrap=wrap,
                            labels=frame_labels,
                            loop_keys=frame_keys)
    finally:
        # Detach so that closing the wrapper does not close the binary
        # stream owned by the caller.
        text_fd.detach()

773 

774 

def autolabel(symbols: Sequence[str]) -> List[str]:
    """Label each symbol with its running count, e.g. Fe1, O1, Fe2."""
    seen: Dict[str, int] = {}
    result = []
    for sym in symbols:
        seen[sym] = seen.get(sym, 0) + 1
        result.append(f'{sym}{seen[sym]}')
    return result

785 

786 

def chemical_formula_header(atoms):
    """Return the structural and summed chemical-formula lines for *atoms*."""
    formula_sum = ' '.join(
        f'{sym}{count}'
        for sym, count in atoms.symbols.formula.count().items())
    structural = f'_chemical_formula_structural {atoms.symbols}\n'
    summed = f'_chemical_formula_sum "{formula_sum}"\n'
    return structural + summed

793 

794 

class BadOccupancies(ValueError):
    """Raised when occupancy info does not match the atoms being written."""

797 

798 

def expand_kinds(atoms, coords):
    """Expand mixed-occupancy sites into one entry per species.

    Uses ``atoms.info['occupancy']`` (kind as str -> {symbol: occupancy})
    together with ``atoms.arrays['spacegroup_kinds']``.  If either is
    missing, every atom gets occupancy 1.  Otherwise each atom takes the
    occupancy of its own symbol, and the other species of its kind are
    appended at the end with a copy of the atom's coordinates.

    Returns new lists ``(symbols, coords, occupancies)``.

    Raises BadOccupancies if an atom's kind or symbol has no occupancy
    entry.
    """
    # try to fetch occupancies // spacegroup_kinds - occupancy mapping
    symbols = list(atoms.symbols)
    coords = list(coords)
    occupancies = [1] * len(symbols)
    occ_info = atoms.info.get('occupancy')
    kinds = atoms.arrays.get('spacegroup_kinds')
    if occ_info is None or kinds is None:
        return symbols, coords, occupancies

    for i, kind in enumerate(kinds):
        occ_info_kind = occ_info.get(str(kind))
        symbol = symbols[i]
        if occ_info_kind is None or symbol not in occ_info_kind:
            raise BadOccupancies('Occupancies present but no occupancy '
                                 f'info for "{symbol}"')
        occupancies[i] = occ_info_kind[symbol]
        # extend the positions array in case of mixed occupancy
        for sym, occ in occ_info_kind.items():
            if sym != symbol:
                symbols.append(sym)
                coords.append(coords[i])
                occupancies.append(occ)
    return symbols, coords, occupancies

821 

822 

def atoms_to_loop_data(atoms, wrap, labels, loop_keys):
    """Collect the per-atom ``_atom_site`` columns for *atoms*.

    Coordinates are fractional for a rank-3 cell and Cartesian otherwise.
    Mixed-occupancy sites are expanded with ``expand_kinds``; if that
    fails, a warning is issued and every atom gets occupancy 1.  Labels are
    generated with ``autolabel`` unless given.

    Returns ``(loopdata, coord_headers)``: loopdata maps each tag name to a
    ``(values, format string)`` pair, and coord_headers are the three
    coordinate tag names.
    """
    if atoms.cell.rank == 3:
        coord_type = 'fract'
        coords = atoms.get_scaled_positions(wrap).tolist()
    else:
        coord_type = 'Cartn'
        coords = atoms.get_positions(wrap).tolist()

    try:
        symbols, coords, occupancies = expand_kinds(atoms, coords)
    except BadOccupancies as err:
        warnings.warn(str(err))
        occupancies = [1] * len(atoms)
        symbols = list(atoms.symbols)

    if labels is None:
        labels = autolabel(symbols)

    coord_headers = [f'_atom_site_{coord_type}_{axisname}'
                     for axisname in 'xyz']

    loopdata = {}
    loopdata['_atom_site_label'] = (labels, '{:<8s}')
    loopdata['_atom_site_occupancy'] = (occupancies, '{:6.4f}')

    # reshape keeps a (0, 3) array for a structure without atoms; a bare
    # np.array([]) would have shape (0,) and break the column indexing.
    _coords = np.array(coords).reshape(-1, 3)
    for i, key in enumerate(coord_headers):
        loopdata[key] = (_coords[:, i], '{}')

    loopdata['_atom_site_type_symbol'] = (symbols, '{:<2s}')
    loopdata['_atom_site_symmetry_multiplicity'] = (
        [1.0] * len(symbols), '{}')

    for key in loop_keys:
        # Should expand the loop_keys like we expand the occupancy stuff.
        # Otherwise user will never figure out how to do this.
        values = [loop_keys[key][i] for i in range(len(symbols))]
        loopdata['_' + key] = (values, '{}')

    return loopdata, coord_headers

863 

864 

def write_cif_image(blockname, atoms, fd, *, wrap,
                    labels, loop_keys):
    """Write one CIF data block describing *atoms* to the text stream *fd*.

    Writes the *blockname* line and chemical formula tags.  For a rank-3
    cell it then writes the cell parameters and P 1 symmetry information,
    and for any other rank except 0 it raises ValueError.  Finally it
    writes the ``_atom_site`` loop.
    """
    fd.write(blockname)
    fd.write(chemical_formula_header(atoms))

    rank = atoms.cell.rank
    if rank not in (0, 3):
        raise ValueError('CIF format can only represent systems with '
                         f'0 or 3 lattice vectors. Got {rank}.')
    if rank == 3:
        fd.write(format_cell(atoms.cell))
        fd.write('\n')
        fd.write(format_generic_spacegroup_info())
        fd.write('\n')

    loopdata, coord_headers = atoms_to_loop_data(atoms, wrap, labels,
                                                 loop_keys)

    column_order = ['_atom_site_type_symbol',
                    '_atom_site_label',
                    '_atom_site_symmetry_multiplicity',
                    *coord_headers,
                    '_atom_site_occupancy']
    column_order.extend('_' + key for key in loop_keys)

    loop = CIFLoop()
    for tag in column_order:
        values, fmt = loopdata[tag]
        loop.add(tag, values, fmt)

    fd.write(loop.tostring())