Hot-keys on this page:
r, m, x, p: toggle line displays
j, k: next / previous highlighted chunk
0 (zero): top of page
1 (one): first highlighted chunk
# flake8: noqa
import numpy as np

# Module-level aliases for the built-in min/max (presumably kept so code that
# imports them from this module keeps working -- confirm before removing).
pymin = min
pymax = max
class LineSearch:
    """Line search for a step satisfying sufficient decrease and curvature.

    ``_line_search`` drives the search: it repeatedly calls ``step`` (a
    reverse-communication state machine) and evaluates the objective and
    gradient whenever ``step`` leaves ``task`` set to ``'FG'``.  Every trial
    step is passed through ``determine_step``, which limits the largest
    displacement of any row of ``pk.reshape(-1, 3)`` to ``maxstep``.

    Attributes set by this class:
        xtol: relative tolerance on the width of the bracketing interval.
        task: state string -- 'START', 'FG', 'CONVERGENCE', or a string
            beginning with 'ERROR' or 'WARNING'.
        isave: int array; ``isave[0]`` is the bracket flag, ``isave[1]``
            the stage (written by ``save``).
        dsave: saved floating-point search state (written by ``save``).
        fc, gc: function / gradient evaluation counters.
        case: which of the four cases ``update`` used last (0 before any).
        old_stp: the last step at which the objective was evaluated.
    """

    def __init__(self, xtol=1e-14):
        self.xtol = xtol
        self.task = 'START'
        self.isave = np.zeros((2,), np.intc)
        self.dsave = np.zeros((13,), float)
        self.fc = 0
        self.gc = 0
        self.case = 0
        self.old_stp = 0

    def _line_search(self, func, myfprime, xk, pk, gfk, old_fval, old_old_fval,
                     maxstep=.2, c1=.23, c2=0.46, xtrapl=1.1, xtrapu=4.,
                     stpmax=50., stpmin=1e-8, args=()):
        """Search along ``pk`` from ``xk`` for an acceptable step length.

        Parameters
        ----------
        func : callable
            Objective, called as ``func(x, *args)``.
        myfprime : callable or tuple
            Gradient, called as ``fprime(x, *args)``.  If a tuple is given,
            its first element is used as the gradient and each gradient call
            is counted as ``len(xk) + 1`` function evaluations in ``self.fc``
            instead of one gradient evaluation in ``self.gc``.
        xk : ndarray
            Starting point.
        pk : ndarray
            Search direction; ``np.dot(gfk, pk)`` must be negative, otherwise
            the search stops with an 'ERROR' task.  Its length must be a
            multiple of 3 (see ``determine_step``).
        gfk : ndarray
            Gradient at ``xk``.
        old_fval : float
            Objective value at ``xk``.
        old_old_fval : float
            Unused; accepted for interface compatibility.
        maxstep : float
            Largest displacement of any row of ``pk.reshape(-1, 3)`` allowed
            per trial step.
        c1, c2 : float
            Sufficient-decrease and curvature parameters.
        xtrapl, xtrapu : float
            Lower/upper extrapolation factors used before the minimum is
            bracketed.
        stpmax, stpmin : float
            Bounds on the step length.
        args : tuple
            Extra arguments forwarded to ``func`` and the gradient.

        Returns
        -------
        stp : float or None
            Accepted step length, or None if the search stopped with an
            'ERROR' task.
        fval : float
            Objective value at the last evaluated point (``old_fval`` if no
            evaluation took place).
        old_fval : float
            The input ``old_fval``, unchanged.
        no_update : bool
            True if ``step`` found that no further progress is possible at
            ``stpmax``.
        """
        self.stpmin = stpmin
        self.pk = pk
        self.stpmax = stpmax
        self.xtrapl = xtrapl
        self.xtrapu = xtrapu
        self.maxstep = maxstep
        phi0 = old_fval
        derphi0 = np.dot(gfk, pk)  # directional derivative at xk
        self.dim = len(pk)
        self.gms = np.sqrt(self.dim) * maxstep
        alpha1 = 1.
        self.no_update = False

        if isinstance(myfprime, tuple):
            fprime = myfprime[0]
            gradient = False
        else:
            fprime = myfprime
            gradient = True
        # Both forms forward the same extra arguments to the gradient
        # (previously ``newargs`` was left unbound for tuple input).
        newargs = args

        fval = old_fval
        self.steps = []

        while True:
            stp = self.step(alpha1, phi0, derphi0, c1, c2,
                            self.xtol, self.isave, self.dsave)
            if self.task[:2] != 'FG':
                # CONVERGENCE, WARNING or ERROR: the search is over.
                break
            alpha1 = stp
            fval = func(xk + stp * pk, *args)
            self.fc += 1
            gval = fprime(xk + stp * pk, *newargs)
            if gradient:
                self.gc += 1
            else:
                self.fc += len(xk) + 1
            phi0 = fval
            derphi0 = np.dot(gval, pk)
            self.old_stp = alpha1
            if self.no_update:
                break

        # Only 'ERROR' (inputs rejected on the first call) is reported as a
        # failure.  The previous extra test ``self.task[1:4] == 'WARN'``
        # compared a 3-character slice with a 4-character string and could
        # never match, so warning terminations have always returned the
        # current step; that behaviour is kept.
        if self.task[:5] == 'ERROR':
            stp = None
        return stp, fval, old_fval, self.no_update

    def step(self, stp, f, g, c1, c2, xtol, isave, dsave):
        """Advance the line-search state machine by one iteration.

        ``f`` and ``g`` are the objective value and directional derivative
        at the trial step ``stp``.  On the first call (``task == 'START'``)
        the inputs are validated and the search state is initialised.  On
        later calls the state is restored from ``self.isave``/``self.dsave``,
        warnings and convergence are tested, and otherwise a new trial step
        is computed.

        On return ``self.task`` is 'FG' (evaluate at the returned step),
        'CONVERGENCE', or a string starting with 'ERROR' or 'WARNING'.

        ``isave`` and ``dsave`` are accepted for interface compatibility;
        the state is always read from and written to ``self.isave`` and
        ``self.dsave``.
        """
        if self.task[:5] == 'START':
            # Check the input arguments for errors.
            if stp < self.stpmin:
                self.task = 'ERROR: STP .LT. minstep'
            if stp > self.stpmax:
                self.task = 'ERROR: STP .GT. maxstep'
            if g >= 0:
                self.task = 'ERROR: INITIAL G >= 0'
            if c1 < 0:
                self.task = 'ERROR: c1 .LT. 0'
            if c2 < 0:
                self.task = 'ERROR: c2 .LT. 0'
            if xtol < 0:
                self.task = 'ERROR: XTOL .LT. 0'
            if self.stpmin < 0:
                self.task = 'ERROR: minstep .LT. 0'
            if self.stpmax < self.stpmin:
                self.task = 'ERROR: maxstep .LT. minstep'
            if self.task[:5] == 'ERROR':
                return stp

            # Initialize local variables.
            self.bracket = False
            stage = 1
            finit = f
            ginit = g
            gtest = c1 * ginit
            width = self.stpmax - self.stpmin
            width1 = width / .5
            # stx, fx, gx: step, function and derivative at the best step.
            # sty, fy, gy: step, function and derivative at sty.
            # stp, f, g:   step, function and derivative at stp.
            stx = 0
            fx = finit
            gx = ginit
            sty = 0
            fy = finit
            gy = ginit
            stmin = 0
            stmax = stp + self.xtrapu * stp
            self.task = 'FG'
            self.save((stage, ginit, gtest, gx,
                       gy, finit, fx, fy, stx, sty,
                       stmin, stmax, width, width1))
            stp = self.determine_step(stp)
            return stp

        # Restore the state saved by the previous call.
        self.bracket = bool(self.isave[0] == 1)
        stage = self.isave[1]
        (ginit, gtest, gx, gy, finit, fx, fy, stx, sty, stmin, stmax,
         width, width1) = self.dsave

        # If psi(stp) <= 0 and f'(stp) >= 0 for some step, then the
        # algorithm enters the second stage.
        ftest = finit + stp * gtest
        if stage == 1 and f < ftest and g >= 0.:
            stage = 2

        # Test for warnings.
        if self.bracket and (stp <= stmin or stp >= stmax):
            self.task = 'WARNING: ROUNDING ERRORS PREVENT PROGRESS'
        if self.bracket and stmax - stmin <= self.xtol * stmax:
            self.task = 'WARNING: XTOL TEST SATISFIED'
        if stp == self.stpmax and f <= ftest and g <= gtest:
            self.task = 'WARNING: STP = maxstep'
        if stp == self.stpmin and (f > ftest or g >= gtest):
            self.task = 'WARNING: STP = minstep'

        # Test for convergence.
        if f <= ftest and abs(g) <= c2 * (- ginit):
            self.task = 'CONVERGENCE'

        # Test for termination.
        if self.task[:4] == 'WARN' or self.task[:4] == 'CONV':
            self.save((stage, ginit, gtest, gx,
                       gy, finit, fx, fy, stx, sty,
                       stmin, stmax, width, width1))
            return stp

        # The original algorithm uses a modified function (f - stp * gtest)
        # during the first stage when a lower, but not sufficiently lower,
        # function value has been obtained.  That branch is disabled here:
        # ``stage`` is tracked and saved, but ``update`` always works on f
        # and g directly.

        # Call update to update stx, sty, and to compute the new step.
        stx, sty, stp, gx, fx, gy, fy = self.update(stx, fx, gx, sty,
                                                    fy, gy, stp, f, g,
                                                    stmin, stmax)

        # Decide if a bisection step is needed.
        if self.bracket:
            if abs(sty - stx) >= .66 * width1:
                stp = stx + .5 * (sty - stx)
            width1 = width
            width = abs(sty - stx)

        # Set the minimum and maximum steps allowed for stp.
        if self.bracket:
            stmin = min(stx, sty)
            stmax = max(stx, sty)
        else:
            stmin = stp + self.xtrapl * (stp - stx)
            stmax = stp + self.xtrapu * (stp - stx)

        # Force the step to be within the bounds stpmax and stpmin.
        stp = max(stp, self.stpmin)
        stp = min(stp, self.stpmax)

        if stx == stp and stp == self.stpmax and stmin > self.stpmax:
            self.no_update = True

        # If further progress is not possible, let stp be the best point
        # obtained during the search.
        # NOTE(review): by operator precedence the first clause reads
        # ``(bracket and stp < stmin) or stp >= stmax``, so the stmax test
        # also applies before a minimum is bracketed; the MINPACK-2 dcsrch
        # reference guards both comparisons with ``bracket``.  Kept as is to
        # preserve existing behaviour -- confirm intent.
        if ((self.bracket and stp < stmin or stp >= stmax)
                or (self.bracket and stmax - stmin < self.xtol * stmax)):
            stp = stx

        # Obtain another function and derivative.
        self.task = 'FG'
        self.save((stage, ginit, gtest, gx,
                   gy, finit, fx, fy, stx, sty,
                   stmin, stmax, width, width1))
        return stp

    def update(self, stx, fx, gx, sty, fy, gy, stp, fp, gp,
               stpmin, stpmax):
        """Update the interval of uncertainty and compute a new trial step.

        Chooses between cubic, quadratic and secant steps in four cases
        (recorded in ``self.case``); cases 1 and 2 set ``self.bracket``.
        ``stpmin``/``stpmax`` bound the step in the unbracketed branches of
        cases 3 and 4.  The chosen step is passed through
        ``determine_step``.

        Returns ``(stx, sty, stp, gx, fx, gy, fy)`` -- note the order.
        """
        # Positive when gp and gx have the same sign (gx assumed nonzero).
        sign = gp * (gx / abs(gx))

        # First case: a higher function value.  The minimum is bracketed.
        # If the cubic step is closer to stx than the quadratic step, the
        # cubic step is taken, otherwise the average of the cubic and
        # quadratic steps is taken.
        if fp > fx:
            self.case = 1
            theta = 3. * (fx - fp) / (stp - stx) + gx + gp
            s = max(abs(theta), abs(gx), abs(gp))
            gamma = s * np.sqrt((theta / s) ** 2. - (gx / s) * (gp / s))
            if stp < stx:
                gamma = -gamma
            p = (gamma - gx) + theta
            q = ((gamma - gx) + gamma) + gp
            r = p / q
            stpc = stx + r * (stp - stx)
            stpq = stx + ((gx / ((fx - fp) / (stp - stx) + gx)) / 2.) \
                * (stp - stx)
            if abs(stpc - stx) < abs(stpq - stx):
                stpf = stpc
            else:
                stpf = stpc + (stpq - stpc) / 2.
            self.bracket = True

        # Second case: a lower function value and derivatives of opposite
        # sign.  The minimum is bracketed.  If the cubic step is farther
        # from stp than the secant step, the cubic step is taken, otherwise
        # the secant step is taken.
        elif sign < 0:
            self.case = 2
            theta = 3. * (fx - fp) / (stp - stx) + gx + gp
            s = max(abs(theta), abs(gx), abs(gp))
            gamma = s * np.sqrt((theta / s) ** 2 - (gx / s) * (gp / s))
            if stp > stx:
                gamma = -gamma
            p = (gamma - gp) + theta
            q = ((gamma - gp) + gamma) + gx
            r = p / q
            stpc = stp + r * (stx - stp)
            stpq = stp + (gp / (gp - gx)) * (stx - stp)
            if abs(stpc - stp) > abs(stpq - stp):
                stpf = stpc
            else:
                stpf = stpq
            self.bracket = True

        # Third case: a lower function value, derivatives of the same sign,
        # and the magnitude of the derivative decreases.
        elif abs(gp) < abs(gx):
            self.case = 3
            # The cubic step is computed only if the cubic tends to infinity
            # in the direction of the step or if the minimum of the cubic is
            # beyond stp.  Otherwise the cubic step is defined to be the
            # secant step.
            theta = 3. * (fx - fp) / (stp - stx) + gx + gp
            s = max(abs(theta), abs(gx), abs(gp))
            # gamma = 0 only arises if the cubic does not tend to infinity
            # in the direction of the step.
            gamma = s * np.sqrt(max(0., (theta / s) ** 2 - (gx / s) * (gp / s)))
            if stp > stx:
                gamma = -gamma
            p = (gamma - gp) + theta
            q = (gamma + (gx - gp)) + gamma
            r = p / q
            if r < 0. and gamma != 0:
                stpc = stp + r * (stx - stp)
            elif stp > stx:
                stpc = stpmax
            else:
                stpc = stpmin
            stpq = stp + (gp / (gp - gx)) * (stx - stp)

            if self.bracket:
                # A minimizer has been bracketed.  If the cubic step is
                # closer to stp than the secant step, the cubic step is
                # taken, otherwise the secant step is taken.
                if abs(stpc - stp) < abs(stpq - stp):
                    stpf = stpc
                else:
                    stpf = stpq
                if stp > stx:
                    stpf = min(stp + .66 * (sty - stp), stpf)
                else:
                    stpf = max(stp + .66 * (sty - stp), stpf)
            else:
                # A minimizer has not been bracketed.  If the cubic step is
                # farther from stp than the secant step, the cubic step is
                # taken, otherwise the secant step is taken.
                if abs(stpc - stp) > abs(stpq - stp):
                    stpf = stpc
                else:
                    stpf = stpq
                stpf = min(stpmax, stpf)
                stpf = max(stpmin, stpf)

        # Fourth case: a lower function value, derivatives of the same sign,
        # and the magnitude of the derivative does not decrease.  If the
        # minimum is not bracketed, the step is either stpmin or stpmax,
        # otherwise the cubic step is taken.
        else:
            self.case = 4
            if self.bracket:
                theta = 3. * (fp - fy) / (sty - stp) + gy + gp
                s = max(abs(theta), abs(gy), abs(gp))
                gamma = s * np.sqrt((theta / s) ** 2 - (gy / s) * (gp / s))
                if stp > sty:
                    gamma = -gamma
                p = (gamma - gp) + theta
                q = ((gamma - gp) + gamma) + gy
                r = p / q
                stpc = stp + r * (sty - stp)
                stpf = stpc
            elif stp > stx:
                stpf = stpmax
            else:
                stpf = stpmin

        # Update the interval which contains a minimizer.
        if fp > fx:
            sty = stp
            fy = fp
            gy = gp
        else:
            if sign < 0:
                sty = stx
                fy = fx
                gy = gx
            stx = stp
            fx = fp
            gx = gp

        # Compute the new step.
        stp = self.determine_step(stpf)

        return stx, sty, stp, gx, fx, gy, fy

    def determine_step(self, stp):
        """Cap the change of step so no row of ``pk`` moves too far.

        The change ``dr = stp - self.old_stp`` is applied to
        ``self.pk.reshape(-1, 3)``.  If the largest row norm of ``dr * pk``
        is ``self.maxstep`` or more, ``dr`` is scaled down uniformly so that
        norm equals ``self.maxstep``.  ``len(self.pk)`` must be a multiple
        of 3 (presumably per-atom Cartesian components -- confirm with
        caller); otherwise the reshape raises ValueError.
        """
        dr = stp - self.old_stp
        x = np.reshape(self.pk, (-1, 3))
        steplengths = ((dr * x) ** 2).sum(1) ** 0.5
        maxsteplength = steplengths.max()
        if maxsteplength >= self.maxstep:
            dr *= self.maxstep / maxsteplength
        stp = self.old_stp + dr
        return stp

    def save(self, data):
        """Store the search state for the next call to ``step``.

        ``data[0]`` is the stage (stored in ``isave[1]``); ``data[1:]`` are
        the 13 floating-point values that ``step`` unpacks from ``dsave``.
        ``isave[0]`` records ``self.bracket`` as 1/0.
        """
        self.isave[0] = 1 if self.bracket else 0
        self.isave[1] = data[0]
        self.dsave = data[1:]