Coverage for /builds/debichem-team/python-ase/ase/optimize/precon/fire.py: 85.22%

115 statements  

« prev     ^ index     » next       coverage.py v7.5.3, created at 2025-03-06 04:00 +0000

1import time 

2 

3import numpy as np 

4 

5from ase.filters import UnitCellFilter 

6from ase.optimize.optimize import Optimizer 

7 

8 

class PreconFIRE(Optimizer):

    def __init__(self, atoms, restart=None, logfile='-', trajectory=None,
                 dt=0.1, maxmove=0.2, dtmax=1.0, Nmin=5, finc=1.1, fdec=0.5,
                 astart=0.1, fa=0.99, a=0.1, theta=0.1,
                 precon=None, use_armijo=True, variable_cell=False, **kwargs):
        """
        Preconditioned version of the FIRE optimizer

        In time this implementation is expected to replace
        :class:`~ase.optimize.fire.FIRE`.

        Parameters
        ----------
        atoms: :class:`~ase.Atoms`
            The Atoms object to relax.

        restart: string
            File used to store the optimizer state (velocity and time
            step). If set and the file exists, the stored state is read
            and used to continue the optimization.

        logfile: file object or str
            If *logfile* is a string, a file with that name will be opened.
            Use '-' for stdout.

        trajectory: string
            Trajectory file used to store optimisation path.

        dt: float
            Initial time step.

        maxmove: float
            Maximum Euclidean norm of the complete displacement (all
            coordinates together) taken in a single step.

        dtmax: float
            Upper bound for the time step when it is increased.

        Nmin: int
            The time step is only increased once the number of consecutive
            steps with positive power (v . f > 0) exceeds *Nmin*.

        finc: float
            Factor by which the time step is increased.

        fdec: float
            Factor by which the time step is decreased on a reset.

        astart: float
            Value the mixing parameter *a* is reset to.

        fa: float
            Factor by which *a* is multiplied whenever the time step grows.

        a: float
            Initial mixing parameter between the velocity and the
            (preconditioned) force direction.

        theta: float
            Sufficient-decrease constant of the Armijo test.

        precon: preconditioner or None
            Object providing ``make_precon``, ``solve`` and ``dot``. If
            None, the plain forces are used as the search direction.

        use_armijo: bool
            If True, test each trial step with an Armijo condition and
            reset the dynamics when it fails.

        variable_cell: bool
            If True, wrap atoms in UnitCellFilter to relax cell and positions.

        kwargs : dict, optional
            Extra arguments passed to
            :class:`~ase.optimize.optimize.Optimizer`.

        """
        if variable_cell:
            atoms = UnitCellFilter(atoms)

        # Assign the hyper-parameters and default state *before* calling
        # Optimizer.__init__: that constructor is expected to call either
        # initialize() or, when restarting, read() (which restores dt).
        # Assigning afterwards would clobber the restored dt and leave
        # e1 / skip_flag undefined on the restart path.
        self.dt = dt
        self.Nsteps = 0
        self.maxmove = maxmove
        self.dtmax = dtmax
        self.Nmin = Nmin
        self.finc = finc
        self.fdec = fdec
        self.astart = astart
        self.fa = fa
        self.a = a
        self.theta = theta
        self.precon = precon
        self.use_armijo = use_armijo
        self.skip_flag = False
        self.e1 = None
        # Stress criterion; set by run(), falls back to fmax otherwise.
        self.smax = None

        Optimizer.__init__(self, atoms, restart, logfile, trajectory, **kwargs)

        self._actual_atoms = atoms

    def initialize(self):
        """Reset the optimizer state for a fresh (non-restarted) run."""
        self.v = None
        self.skip_flag = False
        self.e1 = None

    def read(self):
        """Restore the velocity and time step saved by ``step``."""
        self.v, self.dt = self.load()

    def step(self, f=None):
        """Take one FIRE step, optionally preconditioned and Armijo-checked.

        Parameters
        ----------
        f: array, optional
            Forces at the current positions. Evaluated if not given.
        """
        atoms = self._actual_atoms

        if f is None:
            f = atoms.get_forces()

        r = atoms.get_positions()

        # Direction the dynamics is driven along: f itself, or P^{-1} f.
        if self.precon is None:
            direction = f
        else:
            # Can this be moved out of the step method?
            self.precon.make_precon(atoms)
            direction = self.precon.solve(
                f.reshape(-1)).reshape(len(atoms), -1)

        if self.v is None:
            self.v = np.zeros((len(atoms), 3))
        else:
            if self.use_armijo:
                self.skip_flag = self._armijo_rejects(atoms, r, f, direction)
            if not self.skip_flag:
                self._mix_velocity(f, direction)

        self.v += self.dt * direction
        dr = self.dt * self.v
        normdr = np.sqrt(np.vdot(dr, dr))
        if normdr > self.maxmove:
            # Cap the norm of the whole displacement, keeping its direction.
            dr = self.maxmove * dr / normdr
        atoms.set_positions(r + dr)
        self.dump((self.v, self.dt))

    def _armijo_rejects(self, atoms, r, f, direction):
        """Return True if the trial step fails the Armijo test.

        On failure the FIRE state is reset via ``_reset``. The positions
        are restored to *r* before returning.
        """
        v_test = self.v + self.dt * direction
        r_test = r + self.dt * v_test

        # Energy at the current positions, read *before* the trial move so
        # that an energy the calculator may already hold for these
        # positions is reused instead of being recomputed afterwards.
        e_current = atoms.get_potential_energy()
        func_val = self.func(r_test)
        self.e1 = func_val  # energy at the most recent trial point
        atoms.set_positions(r)

        if func_val > e_current - self.theta * self.dt * np.vdot(v_test, f):
            self._reset()
            return True
        return False

    def _mix_velocity(self, f, direction):
        """FIRE velocity update: mix towards *direction* while the power
        v . f is positive, otherwise reset the dynamics."""
        v_f = np.vdot(self.v, f)
        if v_f > 0.0:
            if self.precon is None:
                self.v = (1.0 - self.a) * self.v + self.a * f / \
                    np.sqrt(np.vdot(f, f)) * \
                    np.sqrt(np.vdot(self.v, self.v))
            else:
                # P^{-1} f scaled by sqrt(precon.dot(v, v)) / sqrt(f . P^{-1} f)
                self.v = (
                    (1.0 - self.a) * self.v +
                    self.a *
                    (np.sqrt(self.precon.dot(self.v.reshape(-1),
                                             self.v.reshape(-1))) /
                     np.sqrt(np.dot(f.reshape(-1),
                                    direction.reshape(-1))) * direction))
            if self.Nsteps > self.Nmin:
                self.dt = min(self.dt * self.finc, self.dtmax)
                self.a *= self.fa
            self.Nsteps += 1
        else:
            self._reset()

    def _reset(self):
        """Zero the velocity, restore *a* to *astart*, shrink the time step
        by *fdec* and restart the positive-power step counter."""
        # Assignment (not `*= 0.0`) also clears NaN/inf components.
        self.v[:] = 0.0
        self.a = self.astart
        self.dt *= self.fdec
        self.Nsteps = 0

    def func(self, x):
        """Objective function for use of the optimizers"""
        self._actual_atoms.set_positions(x.reshape(-1, 3))
        potl = self._actual_atoms.get_potential_energy()
        return potl

    def run(self, fmax=0.05, steps=100000000, smax=None):
        """Run the optimization.

        Parameters
        ----------
        fmax: float
            Convergence criterion on the maximum atomic force.
        steps: int
            Maximum number of steps.
        smax: float, optional
            Convergence criterion on the maximum stress component, used
            with a UnitCellFilter. Defaults to *fmax*.
        """
        if smax is None:
            smax = fmax
        self.smax = smax
        return Optimizer.run(self, fmax, steps)

    def _max_force_stress_sq(self, forces=None):
        """Return (max squared atomic force, max squared stress component).

        With a UnitCellFilter only the first ``len(filter.atoms)`` rows of
        the forces are used, and the stress is read from the filter's
        ``stress`` attribute. Otherwise the second value is None.
        """
        if forces is None:
            forces = self._actual_atoms.get_forces()
        smax_sq = None
        if isinstance(self._actual_atoms, UnitCellFilter):
            natoms = len(self._actual_atoms.atoms)
            forces, stress = forces[:natoms], self._actual_atoms.stress
            smax_sq = (stress**2).max()
        fmax_sq = (forces**2).sum(axis=1).max()
        return fmax_sq, smax_sq

    def converged(self, forces=None):
        """Did the optimization converge?"""
        fmax_sq, smax_sq = self._max_force_stress_sq(forces)
        if smax_sq is None:
            return fmax_sq < self.fmax**2
        # run() sets smax; fall back to fmax if the optimizer is driven
        # without going through run().
        smax = getattr(self, 'smax', None)
        if smax is None:
            smax = self.fmax
        return (fmax_sq < self.fmax**2 and smax_sq < smax**2)

    def log(self, forces=None):
        """Write one log line with step, time, energy and maximum force
        (plus maximum stress when a UnitCellFilter is used)."""
        fmax_sq, smax_sq = self._max_force_stress_sq(forces)
        fmax = np.sqrt(fmax_sq)
        # Energy of the current configuration. self.e1 holds the energy of
        # the last Armijo *trial* point r_test, which is generally not where
        # the step moved the atoms, so it must not be reported here.
        e = self._actual_atoms.get_potential_energy()
        T = time.localtime()
        if self.logfile is not None:
            name = self.__class__.__name__
            if smax_sq is not None:
                smax = np.sqrt(smax_sq)
                self.logfile.write(
                    '%s: %3d %02d:%02d:%02d %15.6f %12.4f %12.4f\n' %
                    (name, self.nsteps, T[3], T[4], T[5], e, fmax, smax))
            else:
                self.logfile.write(
                    '%s: %3d %02d:%02d:%02d %15.6f %12.4f\n' %
                    (name, self.nsteps, T[3], T[4], T[5], e, fmax))
            self.logfile.flush()