Hide keyboard shortcuts

Hot-keys on this page

r m x p   toggle line displays

j k   next/prev highlighted chunk

0   (zero) top of page

1   (one) first highlighted chunk

1# flake8: noqa 

2import numpy as np 

# Module-level aliases for the built-in ``min`` and ``max`` functions.
pymin = min

pymax = max

5 

class LineSearch:
    """Stateful, reverse-communication line search along a fixed direction.

    The driver :meth:`_line_search` repeatedly asks :meth:`step` for a trial
    step length, evaluates the objective and its gradient there, and feeds
    the results back until ``self.task`` leaves the ``'FG'`` state.  The
    search state between calls is kept in ``self.isave`` / ``self.dsave``
    (see :meth:`save`).

    The structure and the Fortran-style task strings suggest this is a port
    of the MINPACK-2 ``dcsrch``/``dcstep`` (More-Thuente) routines with an
    extra per-row step-length cap (:meth:`determine_step`) -- presumably; the
    file itself does not name its origin.

    Attributes set by ``__init__``:
        xtol: relative tolerance on the width of the bracketing interval.
        task: state string ('START', 'FG', 'CONVERGENCE', 'WARNING: ...',
            'ERROR: ...').
        fc, gc: running counts of function / gradient evaluations.
        case: which of the four interpolation cases :meth:`update` used last.
        old_stp: the step length at which the objective was last evaluated.
    """

    def __init__(self, xtol=1e-14):
        self.xtol = xtol
        self.task = 'START'
        self.isave = np.zeros((2,), np.intc)
        self.dsave = np.zeros((13,), float)
        self.fc = 0
        self.gc = 0
        self.case = 0
        self.old_stp = 0

    def _line_search(self, func, myfprime, xk, pk, gfk, old_fval,
                     old_old_fval, maxstep=.2, c1=.23, c2=0.46, xtrapl=1.1,
                     xtrapu=4., stpmax=50., stpmin=1e-8, args=()):
        """Search along ``pk`` from ``xk`` for an acceptable step length.

        Parameters:
            func: objective, called as ``func(x, *args)``.
            myfprime: either a gradient callable ``fprime(x, *args)``, or a
                tuple ``(approx_fprime, eps)`` whose first item is called as
                ``approx_fprime(x, func, eps, *args)``.
            xk: starting point.
            pk: search direction; must be reshapeable to ``(-1, 3)``
                (required by :meth:`determine_step`).
            gfk: gradient at ``xk``.
            old_fval: objective value at ``xk``.
            old_old_fval: accepted for interface compatibility; not used.
            maxstep: cap passed on to :meth:`determine_step`.
            c1, c2: sufficient-decrease and curvature constants used by
                :meth:`step`.
            xtrapl, xtrapu: extrapolation factors for the unbracketed
                step bounds.
            stpmax, stpmin: absolute bounds on the step length.
            args: extra positional arguments for ``func`` / the gradient.

        Returns:
            ``(stp, fval, old_fval, no_update)`` where ``stp`` is the final
            step length (``None`` if the search ended in an ERROR state),
            ``fval`` the last evaluated objective value and ``no_update``
            the flag set by :meth:`step` when the step is pinned at
            ``stpmax``.
        """
        self.stpmin = stpmin
        self.pk = pk
        self.stpmax = stpmax
        self.xtrapl = xtrapl
        self.xtrapu = xtrapu
        self.maxstep = maxstep
        phi0 = old_fval
        derphi0 = np.dot(gfk, pk)
        self.dim = len(pk)
        self.gms = np.sqrt(self.dim) * maxstep
        alpha1 = 1.
        self.no_update = False

        if isinstance(myfprime, tuple):
            # Finite-difference gradient: the callable additionally needs the
            # objective and the step size.  (This assignment used to be
            # commented out, which left ``newargs`` unbound on this path.)
            eps = myfprime[1]
            fprime = myfprime[0]
            newargs = (func, eps) + tuple(args)
            gradient = False
        else:
            fprime = myfprime
            newargs = args
            gradient = True

        fval = old_fval
        self.steps = []

        while True:
            stp = self.step(alpha1, phi0, derphi0, c1, c2, self.xtol,
                            self.isave, self.dsave)

            if self.task[:2] != 'FG':
                # Converged, warned or errored: nothing more to evaluate.
                break

            alpha1 = stp
            trial = xk + stp * pk
            fval = func(trial, *args)
            self.fc += 1
            gval = fprime(trial, *newargs)
            if gradient:
                self.gc += 1
            else:
                # Cost accounting for a forward-difference gradient.
                self.fc += len(xk) + 1
            phi0 = fval
            derphi0 = np.dot(gval, pk)
            self.old_stp = alpha1
            if self.no_update:
                break

        # NOTE(review): ``self.task[1:4]`` is a three-character slice, so it
        # can never equal 'WARN'; as written only ERROR states clear the
        # step.  Left unchanged on purpose: switching to ``[:4]`` would make
        # every WARNING termination return ``stp=None`` -- confirm with the
        # callers before changing.
        if self.task[:5] == 'ERROR' or self.task[1:4] == 'WARN':
            stp = None  # failed
        return stp, fval, old_fval, self.no_update

    def step(self, stp, f, g, c1, c2, xtol, isave, dsave):
        """Advance the search by one call and return the next trial step.

        Parameters:
            stp: current step length (the initial guess on the first call).
            f, g: objective value and directional derivative at ``stp``
                (at step 0 on the first call).
            c1, c2: sufficient-decrease / curvature constants.
            xtol: tolerance, only validated here (``self.xtol`` is what the
                tests below use).
            isave, dsave: accepted for interface compatibility; the state is
                read from ``self.isave`` / ``self.dsave``.

        Side effects:
            Sets ``self.task`` to 'FG' (evaluate at the returned step),
            'CONVERGENCE', 'WARNING: ...' or 'ERROR: ...'; updates
            ``self.bracket`` and possibly ``self.no_update``; stores the
            state through :meth:`save`.
        """
        if self.task[:5] == 'START':
            # Check the input arguments for errors.  Later checks overwrite
            # earlier messages, so the last failing check wins.
            if stp < self.stpmin:
                self.task = 'ERROR: STP .LT. minstep'
            if stp > self.stpmax:
                self.task = 'ERROR: STP .GT. maxstep'
            if g >= 0:
                self.task = 'ERROR: INITIAL G >= 0'
            if c1 < 0:
                self.task = 'ERROR: c1 .LT. 0'
            if c2 < 0:
                self.task = 'ERROR: c2 .LT. 0'
            if xtol < 0:
                self.task = 'ERROR: XTOL .LT. 0'
            if self.stpmin < 0:
                self.task = 'ERROR: minstep .LT. 0'
            if self.stpmax < self.stpmin:
                self.task = 'ERROR: maxstep .LT. minstep'
            if self.task[:5] == 'ERROR':
                return stp

            # Initialize local variables.
            self.bracket = False
            stage = 1
            finit = f
            ginit = g
            gtest = c1 * ginit
            width = self.stpmax - self.stpmin
            width1 = width / .5

            # The variables stx, fx, gx contain the values of the step,
            # function, and derivative at the best step.
            # The variables sty, fy, gy contain the values of the step,
            # function, and derivative at sty.
            # The variables stp, f, g contain the values of the step,
            # function, and derivative at stp.
            stx = 0
            fx = finit
            gx = ginit
            sty = 0
            fy = finit
            gy = ginit
            stmin = 0
            stmax = stp + self.xtrapu * stp
            self.task = 'FG'
            self.save((stage, ginit, gtest, gx,
                       gy, finit, fx, fy, stx, sty,
                       stmin, stmax, width, width1))
            return self.determine_step(stp)

        # Restore the state stored by the previous call.
        self.bracket = bool(self.isave[0] == 1)
        stage = self.isave[1]
        (ginit, gtest, gx, gy, finit, fx, fy, stx, sty, stmin, stmax,
         width, width1) = self.dsave

        # If psi(stp) <= 0 and f'(stp) >= 0 for some step, then the
        # algorithm enters the second stage.
        ftest = finit + stp * gtest
        if stage == 1 and f < ftest and g >= 0.:
            stage = 2

        # Test for warnings.
        if self.bracket and (stp <= stmin or stp >= stmax):
            self.task = 'WARNING: ROUNDING ERRORS PREVENT PROGRESS'
        if self.bracket and stmax - stmin <= self.xtol * stmax:
            self.task = 'WARNING: XTOL TEST SATISFIED'
        if stp == self.stpmax and f <= ftest and g <= gtest:
            self.task = 'WARNING: STP = maxstep'
        if stp == self.stpmin and (f > ftest or g >= gtest):
            self.task = 'WARNING: STP = minstep'

        # Test for convergence.
        if f <= ftest and abs(g) <= c2 * (- ginit):
            self.task = 'CONVERGENCE'

        # Test for termination.
        if self.task[:4] == 'WARN' or self.task[:4] == 'CONV':
            self.save((stage, ginit, gtest, gx,
                       gy, finit, fx, fy, stx, sty,
                       stmin, stmax, width, width1))
            return stp

        # NOTE: a stage-1 variant that interpolates a "modified function"
        # (f - stp * gtest and friends) when a lower value was found without
        # sufficient decrease used to be sketched here as commented-out
        # code; it is disabled, so the update always works on the
        # unmodified function and derivative values.
        stx, sty, stp, gx, fx, gy, fy = self.update(stx, fx, gx, sty,
                                                    fy, gy, stp, f, g,
                                                    stmin, stmax)

        # Decide if a bisection step is needed.
        if self.bracket:
            if abs(sty - stx) >= .66 * width1:
                stp = stx + .5 * (sty - stx)
            width1 = width
            width = abs(sty - stx)

        # Set the minimum and maximum steps allowed for stp.
        if self.bracket:
            stmin = min(stx, sty)
            stmax = max(stx, sty)
        else:
            stmin = stp + self.xtrapl * (stp - stx)
            stmax = stp + self.xtrapu * (stp - stx)

        # Force the step to be within the bounds maxstep and minstep.
        stp = max(stp, self.stpmin)
        stp = min(stp, self.stpmax)

        if stx == stp and stp == self.stpmax and stmin > self.stpmax:
            self.no_update = True

        # If further progress is not possible, let stp be the best
        # point obtained during the search.
        # NOTE(review): the parentheses below spell out how the original
        # un-parenthesised ``a and b or c`` actually evaluates.  The
        # reference algorithm appears to group this as
        # ``bracket and (stp <= stmin or stp >= stmax)``; behaviour is kept
        # as-is here -- confirm before changing.
        if ((self.bracket and stp < stmin) or stp >= stmax
                or (self.bracket and stmax - stmin < self.xtol * stmax)):
            stp = stx

        # Obtain another function and derivative.
        self.task = 'FG'
        self.save((stage, ginit, gtest, gx,
                   gy, finit, fx, fy, stx, sty,
                   stmin, stmax, width, width1))
        return stp

    def update(self, stx, fx, gx, sty, fy, gy, stp, fp, gp,
               stpmin, stpmax):
        """Compute a new trial step and update the interval end points.

        Parameters:
            stx, fx, gx: step, function value and derivative at the best
                step so far.
            sty, fy, gy: the same at the other end point of the interval.
            stp, fp, gp: the same at the current trial step.
            stpmin, stpmax: bounds used to safeguard the new step.

        Returns:
            ``(stx, sty, stp, gx, fx, gy, fy)`` with the updated end points
            and the new trial step (already passed through
            :meth:`determine_step`).

        Side effects:
            Records the interpolation case in ``self.case`` and sets
            ``self.bracket`` to True in cases 1 and 2.
        """
        # Sign of gp relative to gx.  ``np.sign`` yields 0 for gx == 0
        # instead of dividing by zero; a zero here selects the same branches
        # as the NaN the old ``gx / abs(gx)`` produced for NumPy floats.
        sign = gp * np.sign(gx)

        # First case: A higher function value. The minimum is bracketed.
        # If the cubic step is closer to stx than the quadratic step, the
        # cubic step is taken, otherwise the average of the cubic and
        # quadratic steps is taken.
        if fp > fx:  # case1
            self.case = 1
            theta = 3. * (fx - fp) / (stp - stx) + gx + gp
            s = max(abs(theta), abs(gx), abs(gp))
            gamma = s * np.sqrt((theta / s) ** 2. - (gx / s) * (gp / s))
            if stp < stx:
                gamma = -gamma
            p = (gamma - gx) + theta
            q = ((gamma - gx) + gamma) + gp
            r = p / q
            stpc = stx + r * (stp - stx)
            stpq = stx + ((gx / ((fx - fp) / (stp - stx) + gx)) / 2.) \
                * (stp - stx)
            if abs(stpc - stx) < abs(stpq - stx):
                stpf = stpc
            else:
                stpf = stpc + (stpq - stpc) / 2.

            self.bracket = True

        # Second case: A lower function value and derivatives of opposite
        # sign. The minimum is bracketed. If the cubic step is farther from
        # stp than the secant step, the cubic step is taken, otherwise the
        # secant step is taken.
        elif sign < 0:  # case2
            self.case = 2
            theta = 3. * (fx - fp) / (stp - stx) + gx + gp
            s = max(abs(theta), abs(gx), abs(gp))
            gamma = s * np.sqrt((theta / s) ** 2 - (gx / s) * (gp / s))
            if stp > stx:
                gamma = -gamma
            p = (gamma - gp) + theta
            q = ((gamma - gp) + gamma) + gx
            r = p / q
            stpc = stp + r * (stx - stp)
            stpq = stp + (gp / (gp - gx)) * (stx - stp)
            if abs(stpc - stp) > abs(stpq - stp):
                stpf = stpc
            else:
                stpf = stpq
            self.bracket = True

        # Third case: A lower function value, derivatives of the same sign,
        # and the magnitude of the derivative decreases.
        elif abs(gp) < abs(gx):  # case3
            self.case = 3

            # The cubic step is computed only if the cubic tends to infinity
            # in the direction of the step or if the minimum of the cubic
            # is beyond stp. Otherwise the cubic step is defined to be the
            # secant step.
            theta = 3. * (fx - fp) / (stp - stx) + gx + gp
            s = max(abs(theta), abs(gx), abs(gp))

            # The case gamma = 0 only arises if the cubic does not tend
            # to infinity in the direction of the step.
            gamma = s * np.sqrt(max(0., (theta / s) ** 2
                                    - (gx / s) * (gp / s)))
            if stp > stx:
                gamma = -gamma
            p = (gamma - gp) + theta
            q = (gamma + (gx - gp)) + gamma
            r = p / q
            if r < 0. and gamma != 0:
                stpc = stp + r * (stx - stp)
            elif stp > stx:
                stpc = stpmax
            else:
                stpc = stpmin
            stpq = stp + (gp / (gp - gx)) * (stx - stp)

            if self.bracket:
                # A minimizer has been bracketed. If the cubic step is
                # closer to stp than the secant step, the cubic step is
                # taken, otherwise the secant step is taken.
                if abs(stpc - stp) < abs(stpq - stp):
                    stpf = stpc
                else:
                    stpf = stpq
                if stp > stx:
                    stpf = min(stp + .66 * (sty - stp), stpf)
                else:
                    stpf = max(stp + .66 * (sty - stp), stpf)
            else:
                # A minimizer has not been bracketed. If the cubic step is
                # farther from stp than the secant step, the cubic step is
                # taken, otherwise the secant step is taken.
                if abs(stpc - stp) > abs(stpq - stp):
                    stpf = stpc
                else:
                    stpf = stpq
                stpf = min(stpmax, stpf)
                stpf = max(stpmin, stpf)

        # Fourth case: A lower function value, derivatives of the same sign,
        # and the magnitude of the derivative does not decrease. If the
        # minimum is not bracketed, the step is either minstep or maxstep,
        # otherwise the cubic step is taken.
        else:  # case4
            self.case = 4
            if self.bracket:
                theta = 3. * (fp - fy) / (sty - stp) + gy + gp
                s = max(abs(theta), abs(gy), abs(gp))
                gamma = s * np.sqrt((theta / s) ** 2 - (gy / s) * (gp / s))
                if stp > sty:
                    gamma = -gamma
                p = (gamma - gp) + theta
                q = ((gamma - gp) + gamma) + gy
                r = p / q
                stpc = stp + r * (sty - stp)
                stpf = stpc
            elif stp > stx:
                stpf = stpmax
            else:
                stpf = stpmin

        # Update the interval which contains a minimizer.
        if fp > fx:
            sty = stp
            fy = fp
            gy = gp
        else:
            if sign < 0:
                sty = stx
                fy = fx
                gy = gx
            stx = stp
            fx = fp
            gx = gp

        # Compute the new step.
        stp = self.determine_step(stpf)

        return stx, sty, stp, gx, fx, gy, fy

    def determine_step(self, stp):
        """Cap the change in step length relative to ``self.old_stp``.

        ``self.pk`` is viewed as rows of three components (presumably one
        Cartesian vector per atom -- confirm with the caller).  If moving
        from ``old_stp`` to ``stp`` would displace any row by ``maxstep`` or
        more, the change is scaled down so the largest row displacement
        equals ``self.maxstep``.

        Returns the (possibly shortened) step length.
        """
        dr = stp - self.old_stp
        x = np.reshape(self.pk, (-1, 3))
        steplengths = ((dr * x) ** 2).sum(1) ** 0.5
        maxsteplength = steplengths.max()
        if maxsteplength >= self.maxstep:
            dr *= self.maxstep / maxsteplength
        stp = self.old_stp + dr
        return stp

    def save(self, data):
        """Store the search state for the next :meth:`step` call.

        ``data`` is the 14-tuple ``(stage, ginit, gtest, gx, gy, finit, fx,
        fy, stx, sty, stmin, stmax, width, width1)``.  ``self.isave`` holds
        the bracket flag (from ``self.bracket``) and the stage; the other
        13 values go to ``self.dsave`` in the order :meth:`step` unpacks
        them.
        """
        self.isave[0] = 1 if self.bracket else 0
        self.isave[1] = data[0]
        self.dsave = data[1:]