Coverage for /builds/debichem-team/python-ase/ase/gui/images.py: 71.17%
281 statements
« prev ^ index » next coverage.py v7.5.3, created at 2025-03-06 04:00 +0000
« prev ^ index » next coverage.py v7.5.3, created at 2025-03-06 04:00 +0000
1import warnings
2from math import sqrt
4import numpy as np
6from ase import Atoms
7from ase.calculators.singlepoint import SinglePointCalculator
8from ase.constraints import FixAtoms
9from ase.data import covalent_radii
10from ase.geometry import find_mic
11from ase.gui.defaults import read_defaults
12from ase.gui.i18n import _
13from ase.io import read, string2index, write
class Images:
    """Container for the Atoms frames displayed by the ASE GUI.

    Besides the frames themselves (stored uncopied in ``self._images``)
    it keeps GUI state shared by all frames: per-atom ``selected`` and
    ``visible`` masks sized to the largest frame, the current unit-cell
    repetition ``repeat`` and the radii used when drawing atoms.
    """

    def __init__(self, images=None):
        # Private copy so that scale_radii() never mutates ase.data.
        self.covalent_radii = covalent_radii.copy()
        self.config = read_defaults()
        self.atom_scale = self.config['radii_scale']
        if images is None:
            images = [Atoms()]
        self.initialize(images)

    def __len__(self):
        return len(self._images)

    def __getitem__(self, index):
        return self._images[index]

    def __iter__(self):
        return iter(self._images)

    # XXXXXXX hack
    # compatibility hacks while allowing variable number of atoms
    def get_dynamic(self, atoms: Atoms) -> np.ndarray:
        """Return a boolean mask that is False for atoms fixed by FixAtoms."""
        dynamic = np.ones(len(atoms), bool)
        for constraint in atoms.constraints:
            if isinstance(constraint, FixAtoms):
                dynamic[constraint.index] = False
        return dynamic

    def set_dynamic(self, mask, value):
        """Set the dynamic flag of the atoms selected by *mask* to *value*.

        In every frame, existing FixAtoms constraints are replaced by a
        single FixAtoms constraint; all other constraints are kept.
        """
        # Does not make much sense if different images have different
        # atom counts.  Attempts to apply mask to all images,
        # to the extent possible.
        for atoms in self:
            dynamic = self.get_dynamic(atoms)
            dynamic[mask[:len(atoms)]] = value
            constraints = [c for c in atoms.constraints
                           if not isinstance(c, FixAtoms)]
            constraints.append(FixAtoms(mask=~dynamic))
            atoms.constraints = constraints

    def scale_radii(self, scaling_factor):
        """Multiply the (per-instance) covalent radii by *scaling_factor*."""
        self.covalent_radii *= scaling_factor

    def get_energy(self, atoms: Atoms) -> np.float64:
        """Return the potential energy of *atoms*, or NaN if unavailable."""
        try:
            return atoms.get_potential_energy()
        except RuntimeError:
            return np.nan  # type: ignore[return-value]

    def get_forces(self, atoms: Atoms):
        """Return unconstrained forces of *atoms*, or None if unavailable."""
        try:
            return atoms.get_forces(apply_constraint=False)
        except RuntimeError:
            return None

    def initialize(self, images, filenames=None):
        """Replace the stored frames with *images* and reset GUI state.

        Parameters:
            images: sequence of Atoms objects; they are stored without
                copying, so edits in the GUI act on the caller's objects.
            filenames: optional list with one name per image.
        """
        nimages = len(images)
        if filenames is None:
            filenames = [None] * nimages
        self.filenames = filenames

        pbc_differs = False

        self._images = []

        # Whether length or chemical composition changes:
        self.have_varying_species = False
        for i, atoms in enumerate(images):
            # copy atoms or not?  Not copying allows back-editing,
            # but copying actually forgets things like the attached
            # calculator (might have forces/energies)
            self._images.append(atoms)
            self.have_varying_species |= not np.array_equal(self[0].numbers,
                                                            atoms.numbers)
            if hasattr(self, 'Q'):
                assert False  # XXX askhl fix quaternions
                self.Q[i] = atoms.get_quaternions()
            if (atoms.pbc != self[0].pbc).any():
                pbc_differs = True

        if pbc_differs:
            warnings.warn('Not all images have the same boundary conditions!')

        # default=0: an empty image list (e.g. after deleting the last
        # frame) must not raise ValueError here.
        self.maxnatoms = max((len(atoms) for atoms in self), default=0)
        self.selected = np.zeros(self.maxnatoms, bool)
        self.selected_ordered = []
        self.visible = np.ones(self.maxnatoms, bool)
        self.repeat = np.ones(3, int)

    def get_radii(self, atoms: Atoms) -> np.ndarray:
        """Return the drawing radius of each atom.

        The radius is the covalent radius for the atomic number, times
        ``self.atom_scale``.
        """
        return np.take(self.covalent_radii, atoms.numbers) * self.atom_scale

    def read(self, filenames, default_index=':', filetype=None):
        """Read frames from *filenames* and replace the current images.

        A filename may carry an ``@index`` suffix.  Otherwise
        *default_index* is used.  The filename ``'-'`` reads a
        trajectory from standard input.  Frames are named
        ``filename@index``.
        """
        from ase.io.formats import parse_filename

        if isinstance(default_index, str):
            default_index = string2index(default_index)

        images = []
        names = []
        for filename in filenames:
            # '@' normally introduces an index, but postgres URLs
            # already contain one '@' of their own.
            has_index = (('@' in filename and 'postgres' not in filename)
                         or ('postgres' in filename
                             and filename.count('@') == 2))
            if has_index:
                actual_filename, index = parse_filename(filename, None)
            else:
                actual_filename, index = parse_filename(filename,
                                                        default_index)

            # The format is chosen per file: forcing 'traj' for stdin
            # must not leak into the files that follow.
            file_format = filetype

            # Read from stdin:
            if filename == '-':
                import sys
                from io import BytesIO
                buf = BytesIO(sys.stdin.buffer.read())
                buf.seek(0)
                filename = buf
                file_format = 'traj'

            imgs = read(filename, index, file_format)
            if hasattr(imgs, 'iterimages'):
                # A single Atoms object (integer index) -> list of one.
                imgs = list(imgs.iterimages())

            images.extend(imgs)

            # Name each file as filename@index:
            if isinstance(index, slice):
                start = index.start or 0
                step = index.step or 1
            else:
                start = index
                step = 1
            for i, img in enumerate(imgs):
                if isinstance(start, int):
                    names.append('{}@{}'.format(
                        actual_filename, start + i * step))
                else:
                    names.append(f'{actual_filename}@{start}')

        self.initialize(images, names)

    def repeat_results(self, atoms: Atoms, repeat=None, oldprod=None):
        """Return a dictionary which updates the magmoms, energy and forces
        to the repeated amount of atoms.

        Parameters:
            atoms: the image, currently containing *oldprod* copies of
                the original cell.
            repeat: the new repetition, either a 3-vector or an integer
                product of repetitions.  Defaults to the product of the
                current ``self.repeat``.
            oldprod: number of copies *atoms* currently contains.
                Defaults to the product of the current ``self.repeat``.

        Properties the calculator cannot provide without a new
        calculation are left out of the returned dictionary.
        """
        def getresult(name, get_quantity):
            # ase/io/trajectory.py line 170 does this by using
            # the get_property(prop, atoms, allow_calculation=False)
            # so that is an alternative option.
            try:
                if (not atoms.calc or
                        atoms.calc.calculation_required(atoms, [name])):
                    quantity = None
                else:
                    quantity = get_quantity()
            except Exception as err:
                # Best effort: a failing property is skipped, not fatal.
                quantity = None
                errmsg = ('An error occurred while retrieving {} '
                          'from the calculator: {}'.format(name, err))
                warnings.warn(errmsg)
            return quantity

        if repeat is None:
            repeat = self.repeat.prod()
        if oldprod is None:
            oldprod = self.repeat.prod()

        results = {}

        original_length = len(atoms) // oldprod
        # np.prod accepts a repeat vector as well as an integer product
        # (repeat_unit_cell passes the latter).
        newprod = np.prod(repeat)

        # Read the old properties
        magmoms = getresult('magmoms', atoms.get_magnetic_moments)
        magmom = getresult('magmom', atoms.get_magnetic_moment)
        energy = getresult('energy', atoms.get_potential_energy)
        forces = getresult('forces', atoms.get_forces)

        # Update old properties to the repeated image
        if magmoms is not None:
            magmoms = np.tile(magmoms[:original_length], newprod)
            results['magmoms'] = magmoms

        if magmom is not None:
            magmom = magmom * newprod / oldprod
            results['magmom'] = magmom

        if forces is not None:
            forces = np.tile(forces[:original_length].T, newprod).T
            results['forces'] = forces

        if energy is not None:
            energy = energy * newprod / oldprod
            results['energy'] = energy

        return results

    def repeat_unit_cell(self):
        """Make the current repetition permanent by enlarging each cell.

        The calculator results are frozen into a SinglePointCalculator
        and ``self.repeat`` is reset to (1, 1, 1).
        """
        for atoms in self:
            # Get quantities taking into account current repeat():
            results = self.repeat_results(atoms, self.repeat.prod(),
                                          oldprod=self.repeat.prod())

            atoms.cell *= self.repeat.reshape((3, 1))
            atoms.calc = SinglePointCalculator(atoms, **results)
        self.repeat = np.ones(3, int)

    def repeat_images(self, repeat):
        """Repeat every image *repeat* times, keeping the reference cell.

        Constraints other than FixAtoms are discarded; a warning dialog
        is shown if any were removed.
        """
        repeat = np.array(repeat)
        oldprod = self.repeat.prod()
        images = []
        constraints_removed = False

        for atoms in self:
            refcell = atoms.get_cell()
            fixatoms = []
            for c in atoms.constraints:
                if isinstance(c, FixAtoms):
                    fixatoms.append(c)
                else:
                    constraints_removed = True
            atoms.set_constraint(fixatoms)

            # Update results dictionary to repeated atoms
            results = self.repeat_results(atoms, repeat, oldprod)

            del atoms[len(atoms) // oldprod:]  # Original atoms

            atoms *= repeat
            atoms.cell = refcell

            atoms.calc = SinglePointCalculator(atoms, **results)

            images.append(atoms)

        if constraints_removed:
            from ase.gui.ui import showwarning, tk

            # We must be able to show warning before the main GUI
            # has been created.  So we create a new window,
            # then show the warning, then destroy the window.
            tmpwindow = tk.Tk()
            try:
                tmpwindow.withdraw()  # Host window will never be shown
                showwarning(_('Constraints discarded'),
                            _('Constraints other than FixAtoms '
                              'have been discarded.'))
            finally:
                tmpwindow.destroy()

        self.initialize(images, filenames=self.filenames)
        self.repeat = repeat

    def center(self):
        """Center each image in the existing unit cell, keeping the
        cell constant."""
        for atoms in self:
            atoms.center()

    def graph(self, expr: str) -> np.ndarray:
        """Routine to create the data in graphs, defined by the
        string expr.

        *expr* is a comma-separated list of Python expressions that is
        evaluated once per image.  Available names: i, s (accumulated
        path length), E (all energies), R, V, F, A (cell), M (masses),
        f, fmax, fave, epot, ekin, e, T, plus the helpers d(n1, n2),
        a(n1, n2, n3) and dih(n1, n2, n3, n4).  F/f/fmax/fave exist only
        for images with forces; T only for images with dynamic atoms.

        Returns an array of shape (number of expressions, number of
        images).
        """
        import ase.units as units

        # NOTE: expr is evaluated with full builtins available; it is
        # meant for expressions typed by the GUI user and must never be
        # fed untrusted input.
        code = compile(expr + ',', '<input>', 'eval')

        nimages = len(self)

        def d(n1, n2):
            return sqrt(((R[n1] - R[n2])**2).sum())

        def a(n1, n2, n3):
            v1 = R[n1] - R[n2]
            v2 = R[n3] - R[n2]
            arg = np.vdot(v1, v2) / (sqrt((v1**2).sum() * (v2**2).sum()))
            if arg > 1.0:
                arg = 1.0
            if arg < -1.0:
                arg = -1.0
            return 180.0 * np.arccos(arg) / np.pi

        def dih(n1, n2, n3, n4):
            # vector 0->1, 1->2, 2->3 and their normalized cross products:
            a = R[n2] - R[n1]
            b = R[n3] - R[n2]
            c = R[n4] - R[n3]
            bxa = np.cross(b, a)
            bxa /= np.sqrt(np.vdot(bxa, bxa))
            cxb = np.cross(c, b)
            cxb /= np.sqrt(np.vdot(cxb, cxb))
            angle = np.vdot(bxa, cxb)
            # check for numerical trouble due to finite precision:
            if angle < -1:
                angle = -1
            if angle > 1:
                angle = 1
            angle = np.arccos(angle)
            if np.vdot(bxa, c) > 0:
                angle = 2 * np.pi - angle
            return angle * 180.0 / np.pi

        # Potential energy of every image (NaN where unavailable):
        E = np.array([self.get_energy(atoms) for atoms in self])

        s = 0.0

        # Names shared by every per-image evaluation:
        base_ns = {'E': E,
                   'd': d, 'a': a, 'dih': dih}

        for i in range(nimages):
            atoms = self[i]
            # A fresh namespace per image, so quantities that exist only
            # for some images (F, f, fmax, fave, T) never leak stale
            # values from a previous image.
            ns = dict(base_ns)
            ns['i'] = i
            ns['s'] = s
            # R is also read by the d/a/dih closures.
            ns['R'] = R = atoms.get_positions()
            ns['V'] = atoms.get_velocities()
            F = self.get_forces(atoms)
            ns['A'] = atoms.get_cell()
            ns['M'] = atoms.get_masses()
            # XXX askhl verify:
            dynamic = self.get_dynamic(atoms)
            if F is not None:
                ns['F'] = F
                # Force magnitudes, with fixed atoms counted as zero.
                ns['f'] = f = ((F * dynamic[:, None])**2).sum(1)**.5
                ns['fmax'] = max(f)
                ns['fave'] = f.mean()
            ns['epot'] = epot = E[i]
            ns['ekin'] = ekin = atoms.get_kinetic_energy()
            ns['e'] = epot + ekin
            ndynamic = dynamic.sum()
            if ndynamic > 0:
                ns['T'] = 2.0 * ekin / (3.0 * ndynamic * units.kB)
            data = eval(code, ns)
            if i == 0:
                nvariables = len(data)
                xy = np.empty((nvariables, nimages))
            xy[:, i] = data
            if i + 1 < nimages and not self.have_varying_species:
                # Path length along the band, using minimum-image
                # displacements.
                dR = find_mic(self[i + 1].positions - R, atoms.get_cell(),
                              atoms.get_pbc())[0]
                s += sqrt((dR**2).sum())
        return xy

    def write(self, filename, rotations='', bbox=None,
              **kwargs):
        """Write the images to *filename*.

        A trailing ``@index`` or ``@start:stop:step`` selects the
        frames to write.  A suffix that is not an index stays part of
        the filename.  For .eps/.png/.pov output, *rotations* and
        *bbox* are passed on to the writer.
        """
        # XXX We should show the unit cell whenever there is one
        indices = range(len(self))
        p = filename.rfind('@')
        if p != -1:
            try:
                selection = string2index(filename[p + 1:])
            except ValueError:
                selection = None
            # string2index returns the text unchanged when it is not a
            # number; such a suffix is part of the filename, not an index.
            if selection is not None and not isinstance(selection, str):
                indices = indices[selection]
                filename = filename[:p]
        if isinstance(indices, int):
            indices = [indices]

        images = [self.get_atoms(i) for i in indices]
        if len(filename) > 4 and filename.endswith(('.eps', '.png', '.pov')):
            write(filename, images,
                  rotation=rotations,
                  bbox=bbox, **kwargs)
        else:
            write(filename, images, **kwargs)

    def get_atoms(self, frame, remove_hidden=False):
        """Return frame *frame* with its energy/forces in a
        SinglePointCalculator, optionally dropping hidden atoms.
        """
        atoms = self[frame]
        try:
            E = atoms.get_potential_energy()
        except RuntimeError:
            E = None
        try:
            F = atoms.get_forces()
        except RuntimeError:
            F = None

        # Remove hidden atoms if applicable.  self.visible is sized for
        # the largest frame, so trim it to this frame's atom count.
        if remove_hidden:
            visible = self.visible[:len(atoms)]
            atoms = atoms[visible]
            if F is not None:
                F = F[visible]
        # NOTE(review): without remove_hidden this replaces the calculator
        # of the stored frame itself (with constrained forces) -- confirm
        # that callers rely on this before changing it.
        atoms.calc = SinglePointCalculator(atoms, energy=E, forces=F)
        return atoms

    def delete(self, i):
        """Remove image *i* (and its filename) and reinitialize."""
        self._images.pop(i)
        self.filenames.pop(i)
        self.initialize(self._images, self.filenames)