Coverage for /builds/debichem-team/python-ase/ase/optimize/fire.py: 92.94%

85 statements  

« prev     ^ index     » next       coverage.py v7.5.3, created at 2025-03-06 04:00 +0000

1from typing import IO, Any, Callable, Dict, List, Optional, Union 

2 

3import numpy as np 

4 

5from ase import Atoms 

6from ase.optimize.optimize import Optimizer 

7from ase.utils import deprecated 

8 

9 

10def _forbid_maxmove(args: List, kwargs: Dict[str, Any]) -> bool: 

11 """Set maxstep with maxmove if not set.""" 

12 maxstep_index = 6 

13 maxmove_index = 7 

14 

15 def _pop_arg(name: str) -> Any: 

16 to_pop = None 

17 if len(args) > maxmove_index: 

18 to_pop = args[maxmove_index] 

19 args[maxmove_index] = None 

20 

21 elif name in kwargs: 

22 to_pop = kwargs[name] 

23 del kwargs[name] 

24 return to_pop 

25 

26 if len(args) > maxstep_index and args[maxstep_index] is None: 

27 value = args[maxstep_index] = _pop_arg("maxmove") 

28 elif kwargs.get("maxstep", None) is None: 

29 value = kwargs["maxstep"] = _pop_arg("maxmove") 

30 else: 

31 return False 

32 

33 return value is not None 

34 

35 

class FIRE(Optimizer):
    """Optimizer using FIRE damped dynamics.

    Atoms are moved with an Euler-type velocity/position update.  While
    the velocity points along the force, the velocity is mixed towards
    the force direction and the time step grows.  Otherwise the velocity
    is zeroed and the time step shrinks.
    """

    @deprecated(
        "Use of `maxmove` is deprecated. Use `maxstep` instead.",
        category=FutureWarning,
        callback=_forbid_maxmove,
    )
    def __init__(
        self,
        atoms: Atoms,
        restart: Optional[str] = None,
        logfile: Union[IO, str] = '-',
        trajectory: Optional[str] = None,
        dt: float = 0.1,
        maxstep: Optional[float] = None,
        maxmove: Optional[float] = None,
        dtmax: float = 1.0,
        Nmin: int = 5,
        finc: float = 1.1,
        fdec: float = 0.5,
        astart: float = 0.1,
        fa: float = 0.99,
        a: float = 0.1,
        downhill_check: bool = False,
        position_reset_callback: Optional[Callable] = None,
        **kwargs,
    ):
        """

        Parameters
        ----------
        atoms: :class:`~ase.Atoms`
            The Atoms object to relax.

        restart: str
            File used to store the optimizer state (the current
            velocities and time step) after every step. If set and the
            file exists, the stored state is read and used to continue
            the relaxation.

        logfile: file object or str
            If *logfile* is a string, a file with that name will be opened.
            Use '-' for stdout.

        trajectory: str
            Trajectory file used to store optimisation path.

        dt: float
            Initial time step. Default value is 0.1. When a time step is
            restored from *restart*, the restored value takes precedence.

        maxstep: float
            Maximum length of the displacement of the whole system per
            iteration (the Euclidean norm over all atoms' displacements);
            a longer step is scaled down to this length. If None, the
            optimizer default ``self.defaults["maxstep"]`` is used
            (default value is 0.2).

        maxmove: float
            Deprecated alias for *maxstep*; its value is transferred to
            *maxstep* when *maxstep* is not given.

        dtmax: float
            Maximum time step. Default value is 1.0

        Nmin: int
            Number of steps to wait after the last time the dot product of
            the velocity and force was not positive (P in the FIRE article)
            before increasing the time step. Default value is 5.

        finc: float
            Factor to increase the time step. Default value is 1.1

        fdec: float
            Factor to decrease the time step. Default value is 0.5

        astart: float
            Value the mixing parameter a is reset to whenever the dot
            product of velocity and force is not positive. a is the
            coefficient for mixing the velocity and the force, called
            alpha in the FIRE article. Default value 0.1.

        fa: float
            Factor to decrease the parameter alpha. Default value is 0.99

        a: float
            Initial coefficient for mixing the velocity and the force.
            Called alpha in the FIRE article. Default value 0.1.

        downhill_check: bool
            Downhill check directly compares potential energies of subsequent
            steps of the FIRE algorithm rather than relying on the current
            product v*f that is positive if the FIRE dynamics moves downhill.
            This can detect numerical issues where at large time steps the step
            is uphill in energy even though locally v*f is positive, i.e. the
            algorithm jumps over a valley because of a too large time step.

        position_reset_callback: function(optimizable, r, e, e_last)
            Called as ``callback(optimizable, r, e, e_last)`` with the
            optimizer's ``optimizable`` object, the array of positions *r*
            that the optimizer will revert to, the current energy *e* and
            the energy of the last step *e_last*. This is only called if
            *downhill_check* is enabled and e > e_last.

        kwargs : dict, optional
            Extra arguments passed to
            :class:`~ase.optimize.optimize.Optimizer`.

        .. deprecated:: 3.19.3
            Use of ``maxmove`` is deprecated; please use ``maxstep``.

        """
        # Assign dt *before* the base initializer: when restarting from
        # an existing file, Optimizer.__init__ is expected to call
        # self.read(), which restores the saved dt.  Assigning afterwards
        # (as before) silently discarded the restored time step.
        self.dt = dt

        Optimizer.__init__(self, atoms, restart, logfile, trajectory, **kwargs)

        # Consecutive steps with positive v.f (P > 0 in the FIRE article).
        self.Nsteps = 0

        if maxstep is not None:
            self.maxstep = maxstep
        else:
            self.maxstep = self.defaults["maxstep"]

        self.dtmax = dtmax
        self.Nmin = Nmin
        self.finc = finc
        self.fdec = fdec
        self.astart = astart
        self.fa = fa
        self.a = a
        self.downhill_check = downhill_check
        self.position_reset_callback = position_reset_callback

    def initialize(self):
        """Start without velocities; they are created on the first step."""
        self.v = None

    def read(self):
        """Restore velocities and time step saved by :meth:`step`."""
        self.v, self.dt = self.load()

    def step(self, f=None):
        """Perform one FIRE iteration and save ``(v, dt)`` via dump.

        Parameters
        ----------
        f : array, optional
            Forces to use; queried from ``self.optimizable.get_forces()``
            when not given.
        """
        optimizable = self.optimizable

        if f is None:
            f = optimizable.get_forces()

        if self.v is None:
            # First step (no velocities yet): start from rest.
            self.v = np.zeros((len(optimizable), 3))
            if self.downhill_check:
                self.e_last = optimizable.get_potential_energy()
                self.r_last = optimizable.get_positions().copy()
                self.v_last = self.v.copy()
        else:
            is_uphill = False
            if self.downhill_check:
                e = optimizable.get_potential_energy()
                # Check if the energy actually decreased
                if e > self.e_last:
                    # If not, reset to old positions...
                    if self.position_reset_callback is not None:
                        self.position_reset_callback(
                            optimizable, self.r_last, e,
                            self.e_last)
                    optimizable.set_positions(self.r_last)
                    is_uphill = True
                self.e_last = optimizable.get_potential_energy()
                self.r_last = optimizable.get_positions().copy()
                self.v_last = self.v.copy()

            vf = np.vdot(f, self.v)
            if vf > 0.0 and not is_uphill:
                # Moving along the force: rotate v towards the force
                # direction while keeping |v|; after more than Nmin such
                # steps, grow dt (capped at dtmax) and shrink a.
                # vf > 0 guarantees f is non-zero, so the division is safe.
                self.v = (1.0 - self.a) * self.v + self.a * f / np.sqrt(
                    np.vdot(f, f)) * np.sqrt(np.vdot(self.v, self.v))
                if self.Nsteps > self.Nmin:
                    self.dt = min(self.dt * self.finc, self.dtmax)
                    self.a *= self.fa
                self.Nsteps += 1
            else:
                # Not moving downhill: stop, reset mixing, shrink dt.
                self.v[:] *= 0.0
                self.a = self.astart
                self.dt *= self.fdec
                self.Nsteps = 0

        # Update velocity first, then positions with the new velocity.
        self.v += self.dt * f
        dr = self.dt * self.v
        # maxstep limits the norm of the whole displacement vector,
        # not each atom's displacement individually.
        normdr = np.sqrt(np.vdot(dr, dr))
        if normdr > self.maxstep:
            dr = self.maxstep * dr / normdr
        r = optimizable.get_positions()
        optimizable.set_positions(r + dr)
        self.dump((self.v, self.dt))