Coverage for /builds/debichem-team/python-ase/ase/optimize/fire.py: 92.94%
85 statements
« prev ^ index » next coverage.py v7.5.3, created at 2025-03-06 04:00 +0000
« prev ^ index » next coverage.py v7.5.3, created at 2025-03-06 04:00 +0000
1from typing import IO, Any, Callable, Dict, List, Optional, Union
3import numpy as np
5from ase import Atoms
6from ase.optimize.optimize import Optimizer
7from ase.utils import deprecated
10def _forbid_maxmove(args: List, kwargs: Dict[str, Any]) -> bool:
11 """Set maxstep with maxmove if not set."""
12 maxstep_index = 6
13 maxmove_index = 7
15 def _pop_arg(name: str) -> Any:
16 to_pop = None
17 if len(args) > maxmove_index:
18 to_pop = args[maxmove_index]
19 args[maxmove_index] = None
21 elif name in kwargs:
22 to_pop = kwargs[name]
23 del kwargs[name]
24 return to_pop
26 if len(args) > maxstep_index and args[maxstep_index] is None:
27 value = args[maxstep_index] = _pop_arg("maxmove")
28 elif kwargs.get("maxstep", None) is None:
29 value = kwargs["maxstep"] = _pop_arg("maxmove")
30 else:
31 return False
33 return value is not None
class FIRE(Optimizer):
    """FIRE (Fast Inertial Relaxation Engine) structure optimizer.

    Each step performs damped dynamics. While the motion is downhill
    (v·f > 0), the velocity is mixed towards the force direction, and
    after more than ``Nmin`` such steps the time step grows and the
    mixing coefficient shrinks. Otherwise the velocity is zeroed, the
    mixing coefficient is reset and the time step is reduced.
    """

    @deprecated(
        "Use of `maxmove` is deprecated. Use `maxstep` instead.",
        category=FutureWarning,
        callback=_forbid_maxmove,
    )
    def __init__(
        self,
        atoms: Atoms,
        restart: Optional[str] = None,
        logfile: Union[IO, str] = '-',
        trajectory: Optional[str] = None,
        dt: float = 0.1,
        maxstep: Optional[float] = None,
        maxmove: Optional[float] = None,
        dtmax: float = 1.0,
        Nmin: int = 5,
        finc: float = 1.1,
        fdec: float = 0.5,
        astart: float = 0.1,
        fa: float = 0.99,
        a: float = 0.1,
        downhill_check: bool = False,
        position_reset_callback: Optional[Callable] = None,
        **kwargs,
    ):
        """

        Parameters
        ----------
        atoms: :class:`~ase.Atoms`
            The Atoms object to relax.

        restart: str
            File used to store the optimizer state, namely the velocities
            and the current time step, which are dumped after every step.
            If a file with this name exists, the stored state is used to
            resume the optimization.

        logfile: file object or str
            If *logfile* is a string, a file with that name will be opened.
            Use '-' for stdout.

        trajectory: str
            Trajectory file used to store optimisation path.

        dt: float
            Initial time step. Default value is 0.1.

        maxstep: float
            Upper bound on the step length per iteration. The bound
            applies to the norm of the whole displacement array (all
            atoms together); longer steps are rescaled to this length, so
            no single atom moves further than this either. If None,
            ``self.defaults["maxstep"]`` is used (0.2 by default).

        maxmove: float
            Deprecated alias for *maxstep*.

        dtmax: float
            Maximum time step. Default value is 1.0.

        Nmin: int
            Number of consecutive downhill steps (v·f > 0, P > 0 in the
            FIRE article) that must be exceeded before the time step is
            increased and *a* is decreased. The count is reset whenever a
            step is not downhill. Default value is 5.

        finc: float
            Factor to increase the time step. Default value is 1.1.

        fdec: float
            Factor to decrease the time step. Default value is 0.5.

        astart: float
            Value to which the mixing coefficient *a* is reset whenever a
            step is not downhill (alpha_start in the FIRE article).
            Default value is 0.1.

        fa: float
            Factor by which the mixing coefficient *a* is multiplied after
            each downhill step beyond *Nmin*. Default value is 0.99.

        a: float
            Initial value of the coefficient for mixing the velocity and
            the force. Called alpha in the FIRE article. Default value 0.1.

        downhill_check: bool
            Downhill check directly compares potential energies of subsequent
            steps of the FIRE algorithm rather than relying on the current
            product v*f that is positive if the FIRE dynamics moves downhill.
            This can detect numerical issues where at large time steps the step
            is uphill in energy even though locally v*f is positive, i.e. the
            algorithm jumps over a valley because of a too large time step.
            When the energy went up, the positions are reset to those of the
            previous step and the velocity is zeroed.

        position_reset_callback: function(atoms, r, e, e_last)
            Function that takes current *atoms* object, an array of position
            *r* that the optimizer will revert to, current energy *e* and
            energy of last step *e_last*. This is only called if e > e_last.

        kwargs : dict, optional
            Extra arguments passed to
            :class:`~ase.optimize.optimize.Optimizer`.

        .. deprecated:: 3.19.3
            Use of ``maxmove`` is deprecated; please use ``maxstep``.

        """
        Optimizer.__init__(self, atoms, restart, logfile, trajectory, **kwargs)

        self.dt = dt

        # Number of consecutive downhill steps (compared against Nmin).
        self.Nsteps = 0

        if maxstep is not None:
            self.maxstep = maxstep
        else:
            self.maxstep = self.defaults["maxstep"]

        self.dtmax = dtmax
        self.Nmin = Nmin
        self.finc = finc
        self.fdec = fdec
        self.astart = astart
        self.fa = fa
        self.a = a
        self.downhill_check = downhill_check
        self.position_reset_callback = position_reset_callback

    def initialize(self):
        """Start without velocities; ``step`` creates them as zeros."""
        self.v = None

    def read(self):
        """Restore the ``(v, dt)`` pair that ``step`` passes to ``dump``."""
        self.v, self.dt = self.load()

    def _revert_if_uphill(self) -> bool:
        """Revert to the previous reference positions if the energy rose.

        Compares the current potential energy with ``self.e_last``. If it
        is higher, calls ``position_reset_callback`` (when set) and then
        sets the positions back to ``self.r_last``.

        Returns True if the positions were reverted, False otherwise.
        """
        optimizable = self.optimizable
        e = optimizable.get_potential_energy()
        if e > self.e_last:
            if self.position_reset_callback is not None:
                self.position_reset_callback(
                    optimizable, self.r_last, e, self.e_last)
            optimizable.set_positions(self.r_last)
            return True
        return False

    def step(self, f=None):
        """Perform one FIRE step.

        Parameters
        ----------
        f : array, optional
            Forces at the current positions. Computed via
            ``optimizable.get_forces()`` when not given.
        """
        optimizable = self.optimizable

        if f is None:
            f = optimizable.get_forces()

        is_uphill = False
        if self.v is None:
            # Fresh start: begin at rest.
            self.v = np.zeros((len(optimizable), 3))
        elif self.downhill_check and hasattr(self, 'e_last'):
            # The hasattr guard matters after a restart: read() restores v
            # but no reference energy/positions exist yet, so there is
            # nothing to compare against on the first resumed step.
            is_uphill = self._revert_if_uphill()
            # NOTE(review): after a revert, f is still the force at the
            # rejected positions, not at r_last. Confirm this is intended.

        if self.downhill_check:
            # Reference state for the next step's downhill comparison.
            self.e_last = optimizable.get_potential_energy()
            self.r_last = optimizable.get_positions().copy()
            # v_last is stored but not read anywhere in this class.
            self.v_last = self.v.copy()

        vf = np.vdot(f, self.v)
        if vf > 0.0 and not is_uphill:
            # Downhill: mix v with the force direction, scaling the force
            # term to the current speed |v|.
            self.v = (1.0 - self.a) * self.v + self.a * f / np.sqrt(
                np.vdot(f, f)) * np.sqrt(np.vdot(self.v, self.v))
            if self.Nsteps > self.Nmin:
                self.dt = min(self.dt * self.finc, self.dtmax)
                self.a *= self.fa
            self.Nsteps += 1
        else:
            # Not downhill (or at rest / reverted): stop, reset the mixing
            # coefficient and shrink the time step. Assigning 0.0 (rather
            # than multiplying by 0.0) also clears any NaN/inf entries.
            self.v[:] = 0.0
            self.a = self.astart
            self.dt *= self.fdec
            self.Nsteps = 0

        # Velocity update, then displacement from the updated velocity.
        self.v += self.dt * f
        dr = self.dt * self.v
        # Cap the norm of the full displacement array at maxstep.
        normdr = np.sqrt(np.vdot(dr, dr))
        if normdr > self.maxstep:
            dr = self.maxstep * dr / normdr
        r = optimizable.get_positions()
        optimizable.set_positions(r + dr)
        self.dump((self.v, self.dt))