Coverage for /builds/debichem-team/python-ase/ase/optimize/precon/fire.py: 85.22%
115 statements
coverage.py v7.5.3, created at 2025-03-06 04:00 +0000
1import time
3import numpy as np
5from ase.filters import UnitCellFilter
6from ase.optimize.optimize import Optimizer
class PreconFIRE(Optimizer):
    """Preconditioned FIRE optimizer with an optional Armijo safeguard."""

    def __init__(self, atoms, restart=None, logfile='-', trajectory=None,
                 dt=0.1, maxmove=0.2, dtmax=1.0, Nmin=5, finc=1.1, fdec=0.5,
                 astart=0.1, fa=0.99, a=0.1, theta=0.1,
                 precon=None, use_armijo=True, variable_cell=False, **kwargs):
        """
        Preconditioned version of the FIRE optimizer

        In time this implementation is expected to replace
        :class:`~ase.optimize.fire.FIRE`.

        Parameters
        ----------
        atoms: :class:`~ase.Atoms`
            The Atoms object to relax.

        restart: string
            File used to store the optimizer state (the velocities and the
            current time step, written at the end of every step). If set,
            a file with such a name will be searched for and, if it exists,
            the stored state will be used to resume the relaxation.

        logfile: file object or str
            If *logfile* is a string, a file with that name will be opened.
            Use '-' for stdout.

        trajectory: string
            Trajectory file used to store optimisation path.

        dt: float
            Initial time step (ignored when the state is restored from
            *restart*).

        maxmove: float
            Maximum norm of the total displacement applied in one step.

        dtmax: float
            Upper bound on the time step.

        Nmin: int
            The time step may only grow once more than *Nmin* consecutive
            downhill steps have been taken.

        finc: float
            Factor by which the time step grows after more than *Nmin*
            consecutive downhill steps (capped at *dtmax*).

        fdec: float
            Factor by which the time step shrinks after an uphill step or a
            failed Armijo test.

        astart: float
            Value the mixing parameter *a* is reset to after an uphill step
            or a failed Armijo test.

        fa: float
            Factor by which *a* is multiplied when the time step grows.

        a: float
            Initial velocity mixing parameter.

        theta: float
            Sufficient-decrease parameter of the Armijo test.

        precon: preconditioner or None
            If given, it must provide ``make_precon(atoms)``, ``solve(vec)``
            and ``dot(x, y)``; it is rebuilt at every step.

        use_armijo: bool
            If True, test a trial step against the Armijo condition before
            each update and reset the dynamics when it fails.

        variable_cell: bool
            If True, wrap atoms in UnitCellFilter to relax cell and positions.

        kwargs : dict, optional
            Extra arguments passed to
            :class:`~ase.optimize.optimize.Optimizer`.
        """
        if variable_cell:
            atoms = UnitCellFilter(atoms)

        # Assign the default time step *before* the base initializer: when
        # resuming from *restart*, ``read`` (presumably invoked from
        # ``Optimizer.__init__``) restores the saved time step, and that
        # value must not be clobbered afterwards.
        self.dt = dt
        Optimizer.__init__(self, atoms, restart, logfile, trajectory, **kwargs)

        # The object actually relaxed (the UnitCellFilter if variable_cell).
        self._actual_atoms = atoms

        self.Nsteps = 0
        self.maxmove = maxmove
        self.dtmax = dtmax
        self.Nmin = Nmin
        self.finc = finc
        self.fdec = fdec
        self.astart = astart
        self.fa = fa
        self.a = a
        self.theta = theta
        self.precon = precon
        self.use_armijo = use_armijo

    def initialize(self):
        """Set the initial (empty) optimizer state."""
        self.v = None           # velocities; created on the first step
        self.skip_flag = False  # True when the last Armijo test failed
        self.e1 = None          # energy at the most recent Armijo trial point

    def read(self):
        """Restore velocities and time step saved by ``step`` via ``dump``."""
        self.v, self.dt = self.load()

    def step(self, f=None):
        """Advance the (preconditioned) FIRE dynamics by one step.

        Parameters
        ----------
        f: array or None
            Forces at the current positions; evaluated if not given.

        The atoms are moved by at most *maxmove* (total displacement norm)
        and the state ``(v, dt)`` is dumped for restarting.
        """
        atoms = self._actual_atoms

        if f is None:
            f = atoms.get_forces()

        r = atoms.get_positions()

        invP_f = None
        if self.precon is not None:
            # Can this be moved out of the step method?
            self.precon.make_precon(atoms)
            invP_f = self.precon.solve(f.reshape(-1)).reshape(len(atoms), -1)
        # Driving direction: the raw forces, or the preconditioned P^-1 f.
        direction = f if invP_f is None else invP_f

        if self.v is None:
            self.v = np.zeros((len(atoms), 3))
        else:
            if self.use_armijo:
                self.skip_flag = bool(self._armijo_rejects(r, f, direction))
                if self.skip_flag:
                    self._fire_reset()
            if not self.skip_flag:
                self._fire_mix(f, invP_f)

        self.v += self.dt * direction
        dr = self.dt * self.v
        normdr = np.sqrt(np.vdot(dr, dr))
        if normdr > self.maxmove:
            dr = self.maxmove * dr / normdr
        atoms.set_positions(r + dr)
        self.dump((self.v, self.dt))

    def _armijo_rejects(self, r, f, direction):
        """Return True if the trial step fails the Armijo condition.

        The trial point is ``r + dt * (v + dt * direction)``; its energy is
        stored in ``self.e1``. The atoms are put back at *r* afterwards.
        """
        atoms = self._actual_atoms
        v_test = self.v + self.dt * direction
        r_test = r + self.dt * v_test
        # Take the energy at the current positions *before* moving: forces
        # were just evaluated here, so calculators that cache per-configuration
        # results can typically return it without a new calculation. Asking
        # for it after visiting r_test would force a full recalculation.
        e_current = atoms.get_potential_energy()
        self.e1 = self.func(r_test)
        atoms.set_positions(r)
        return self.e1 > e_current - self.theta * self.dt * np.vdot(v_test, f)

    def _fire_reset(self):
        """Stop the dynamics: zero v, reset *a*, shrink dt, restart count."""
        # Assign rather than multiply by zero so inf/NaN entries are cleared.
        self.v[:] = 0.0
        self.a = self.astart
        self.dt *= self.fdec
        self.Nsteps = 0

    def _fire_mix(self, f, invP_f):
        """FIRE velocity mixing when moving downhill, otherwise a reset.

        Downhill means ``v . f > 0``; then v is mixed towards the (scaled)
        driving direction and, after more than *Nmin* consecutive downhill
        steps, dt grows and *a* shrinks.
        """
        v_f = np.vdot(self.v, f)
        if v_f > 0.0:
            if self.precon is None:
                self.v = (1.0 - self.a) * self.v + self.a * f / \
                    np.sqrt(np.vdot(f, f)) * \
                    np.sqrt(np.vdot(self.v, self.v))
            else:
                self.v = (
                    (1.0 - self.a) * self.v +
                    self.a *
                    (np.sqrt(self.precon.dot(self.v.reshape(-1),
                                             self.v.reshape(-1))) /
                     np.sqrt(np.dot(f.reshape(-1),
                                    invP_f.reshape(-1))) * invP_f))
            if self.Nsteps > self.Nmin:
                self.dt = min(self.dt * self.finc, self.dtmax)
                self.a *= self.fa
            self.Nsteps += 1
        else:
            self._fire_reset()

    def func(self, x):
        """Objective function for use of the optimizers

        Moves the atoms to *x* (reshaped to rows of 3) and returns the
        potential energy there.
        """
        self._actual_atoms.set_positions(x.reshape(-1, 3))
        potl = self._actual_atoms.get_potential_energy()
        return potl

    def run(self, fmax=0.05, steps=100000000, smax=None):
        """Run the relaxation; *smax* (stress threshold) defaults to *fmax*."""
        if smax is None:
            smax = fmax
        self.smax = smax
        return Optimizer.run(self, fmax, steps)

    def _squared_maxima(self, forces):
        """Return ``(fmax_sq, smax_sq)``; ``smax_sq`` is None without a cell.

        For a UnitCellFilter only the first ``natoms`` rows of *forces* (the
        atoms) enter ``fmax_sq``; the cell is judged by the filter's stress.
        """
        if isinstance(self._actual_atoms, UnitCellFilter):
            natoms = len(self._actual_atoms.atoms)
            fmax_sq = (forces[:natoms]**2).sum(axis=1).max()
            smax_sq = (self._actual_atoms.stress**2).max()
            return fmax_sq, smax_sq
        return (forces**2).sum(axis=1).max(), None

    def converged(self, forces=None):
        """Did the optimization converge?

        True when the largest per-atom force norm is below ``self.fmax``
        and, for variable-cell relaxations, the largest stress component
        magnitude is below ``self.smax`` (set by ``run``).
        """
        if forces is None:
            forces = self._actual_atoms.get_forces()
        fmax_sq, smax_sq = self._squared_maxima(forces)
        if smax_sq is None:
            return fmax_sq < self.fmax**2
        return (fmax_sq < self.fmax**2 and smax_sq < self.smax**2)

    def log(self, forces=None):
        """Write step, wall-clock time, energy, max force (and max stress)."""
        if self.logfile is None:
            return
        if forces is None:
            forces = self._actual_atoms.get_forces()
        fmax_sq, smax_sq = self._squared_maxima(forces)
        fmax = np.sqrt(fmax_sq)
        # Log the energy of the configuration the forces belong to.
        # ``self.e1`` is the energy of the Armijo trial point, which is not
        # the position ``step`` actually moved to (and is the rejected uphill
        # point when the test fails), so it is deliberately not reused.
        e = self._actual_atoms.get_potential_energy()
        T = time.localtime()
        name = self.__class__.__name__
        if smax_sq is None:
            self.logfile.write(
                '%s: %3d %02d:%02d:%02d %15.6f %12.4f\n' %
                (name, self.nsteps, T[3], T[4], T[5], e, fmax))
        else:
            smax = np.sqrt(smax_sq)
            self.logfile.write(
                '%s: %3d %02d:%02d:%02d %15.6f %12.4f %12.4f\n' %
                (name, self.nsteps, T[3], T[4], T[5], e, fmax, smax))
        self.logfile.flush()