Hot-keys on this page:
r, m, x, p: toggle line displays
j, k: next / previous highlighted chunk
0 (zero): top of page
1 (one): first highlighted chunk
# flake8: noqa
import numpy as np

# Module-level aliases for the Python builtins ``min`` and ``max``.
pymin, pymax = min, max
class LineSearch:
    """Line search for a step satisfying the strong Wolfe conditions.

    The algorithm, its comments and its ``task`` strings follow the
    MINPACK-2 routines ``dcsrch``/``dcstep`` (Moré-Thuente line search),
    with two local additions:

    * every new trial step is length-limited by :meth:`determine_step`;
    * the search stops early with ``no_update`` set when the step is
      pinned at ``stpmax`` at the best point found so far.

    The object is stateful: ``task``, ``isave`` and ``dsave`` carry the
    search state between trial steps, and nothing resets ``task`` to
    ``'START'`` once a search finishes, so use a fresh instance for each
    line search.
    """

    def __init__(self, xtol=1e-14):
        """Create a line search in the ``'START'`` state.

        Parameters
        ----------
        xtol : float
            Relative tolerance on the width of the interval of uncertainty;
            once a minimizer is bracketed and
            ``stmax - stmin <= xtol * stmax`` the search stops with a warning.
        """
        self.xtol = xtol
        self.task = 'START'
        # Integer state: [bracketed flag (0/1), stage].
        self.isave = np.zeros((2,), np.intc)
        # Float state read back by ``step``; ``save`` replaces it by a tuple.
        self.dsave = np.zeros((13,), float)
        # Function-evaluation counter (finite-difference gradients included).
        self.fc = 0
        # Analytic gradient-evaluation counter.
        self.gc = 0
        # Which of the four ``update`` cases was used last (0 = none yet).
        self.case = 0
        # Last evaluated step; origin used by ``determine_step``.
        self.old_stp = 0

    def _line_search(self, func, myfprime, xk, pk, gfk, old_fval, old_old_fval,
                     maxstep=.2, c1=.23, c2=0.46, xtrapl=1.1, xtrapu=4.,
                     stpmax=50., stpmin=1e-8, args=()):
        """Search along ``pk`` from ``xk`` for an acceptable step length.

        Parameters
        ----------
        func : callable
            ``func(x, *args)`` returns the objective value at ``x``.
        myfprime : callable or tuple
            Either ``fprime(x, *args)`` returning the gradient, or a tuple
            ``(approx_fprime, eps)`` called as
            ``approx_fprime(x, func, eps, *args)``; each such call is counted
            as ``len(xk) + 1`` function evaluations in ``self.fc``.
        xk : ndarray
            Starting point.
        pk : ndarray
            Search direction. It must be a descent direction
            (``dot(gfk, pk) < 0``), otherwise the search fails. Its size must
            be a multiple of 3 (see :meth:`determine_step`).
        gfk : ndarray
            Gradient at ``xk``.
        old_fval : float
            Objective value at ``xk``.
        old_old_fval : float
            Unused; kept for interface compatibility.
        maxstep : float
            Largest allowed length of any 3-component block of
            ``(stp - old_stp) * pk`` between consecutive trial steps.
        c1, c2 : float
            Sufficient-decrease and curvature constants.
        xtrapl, xtrapu : float
            Lower/upper extrapolation factors used before a minimizer is
            bracketed.
        stpmax, stpmin : float
            Bounds on the step length.
        args : tuple
            Extra arguments passed to ``func`` and the gradient routine.

        Returns
        -------
        stp : float or None
            The final step, or ``None`` if the search ended in an
            ``'ERROR'`` task.
        fval : float
            Objective value at the last evaluated point (``old_fval`` if no
            point was evaluated).
        old_fval : float
            The ``old_fval`` argument, unchanged.
        no_update : bool
            ``True`` if the search stopped because the step was pinned at
            ``stpmax`` (see :meth:`step`).
        """
        self.stpmin = stpmin
        self.pk = pk
        self.stpmax = stpmax
        self.xtrapl = xtrapl
        self.xtrapu = xtrapu
        self.maxstep = maxstep
        phi0 = old_fval
        derphi0 = np.dot(gfk, pk)
        self.dim = len(pk)
        self.gms = np.sqrt(self.dim) * maxstep
        alpha1 = 1.
        self.no_update = False

        if isinstance(myfprime, tuple):
            # (approx_fprime, eps): finite-difference gradient that needs the
            # objective and the difference step ahead of the user arguments.
            fprime, eps = myfprime[0], myfprime[1]
            newargs = (func, eps) + tuple(args)
            gradient = False
        else:
            fprime = myfprime
            newargs = args
            gradient = True

        fval = old_fval
        gval = gfk
        self.steps = []

        while True:
            stp = self.step(alpha1, phi0, derphi0, c1, c2,
                            self.xtol, self.isave, self.dsave)
            if self.task[:2] != 'FG':
                # CONVERGENCE, WARNING or ERROR: the search is over.
                break

            # 'FG': evaluate the function and gradient at the new trial step.
            alpha1 = stp
            fval = func(xk + stp * pk, *args)
            self.fc += 1
            gval = fprime(xk + stp * pk, *newargs)
            if gradient:
                self.gc += 1
            else:
                # Count the finite-difference gradient as len(xk) + 1 calls.
                self.fc += len(xk) + 1
            phi0 = fval
            derphi0 = np.dot(gval, pk)
            self.old_stp = alpha1
            if self.no_update:
                break

        # Only ERROR tasks are reported as failure; WARNING tasks (for example
        # 'WARNING: STP = maxstep') return the current step.
        # NOTE(review): confirm callers rely on warnings being non-fatal.
        if self.task[:5] == 'ERROR':
            stp = None  # failed
        return stp, fval, old_fval, self.no_update

    def step(self, stp, f, g, c1, c2, xtol, isave, dsave):
        """Advance the search by one trial step.

        On the first call (``task == 'START'``) the arguments are validated,
        the state is initialised, ``stp`` is the initial trial step and
        ``f``/``g`` are the function value and directional derivative at
        step 0. On later calls ``f``/``g`` are the values at ``stp``; the
        search either finishes (``task`` becomes ``'CONVERGENCE'`` or starts
        with ``'WARNING'``) or asks for another evaluation (``task == 'FG'``).

        Parameters
        ----------
        stp : float
            Current step.
        f, g : float
            Function value and directional derivative (see above).
        c1, c2 : float
            Sufficient-decrease and curvature constants.
        xtol : float
            Validated on the first call; the tolerance tests use ``self.xtol``.
        isave, dsave
            Unused; the state is read from ``self.isave``/``self.dsave``.

        Returns
        -------
        float
            The next step to evaluate, or the final step when ``task`` is no
            longer ``'FG'``.
        """
        if self.task[:5] == 'START':
            # Check the input arguments for errors; later checks win.
            if stp < self.stpmin:
                self.task = 'ERROR: STP .LT. minstep'
            if stp > self.stpmax:
                self.task = 'ERROR: STP .GT. maxstep'
            if g >= 0:
                self.task = 'ERROR: INITIAL G >= 0'
            if c1 < 0:
                self.task = 'ERROR: c1 .LT. 0'
            if c2 < 0:
                self.task = 'ERROR: c2 .LT. 0'
            if xtol < 0:
                self.task = 'ERROR: XTOL .LT. 0'
            if self.stpmin < 0:
                self.task = 'ERROR: minstep .LT. 0'
            if self.stpmax < self.stpmin:
                self.task = 'ERROR: maxstep .LT. minstep'
            if self.task[:5] == 'ERROR':
                return stp

            # Initialize local variables.
            self.bracket = False
            stage = 1
            finit = f
            ginit = g
            gtest = c1 * ginit
            width = self.stpmax - self.stpmin
            width1 = width / .5

            # stx, fx, gx: step, function and derivative at the best step.
            # sty, fy, gy: step, function and derivative at sty (the other
            #   end of the interval once a minimizer is bracketed).
            # stp, f, g: step, function and derivative at stp.
            stx = 0
            fx = finit
            gx = ginit
            sty = 0
            fy = finit
            gy = ginit
            stmin = 0
            stmax = stp + self.xtrapu * stp
            self.task = 'FG'
            self.save((stage, ginit, gtest, gx,
                       gy, finit, fx, fy, stx, sty,
                       stmin, stmax, width, width1))
            return self.determine_step(stp)

        # Restore the state saved by the previous call.
        self.bracket = bool(self.isave[0] == 1)
        stage = self.isave[1]
        (ginit, gtest, gx, gy, finit, fx, fy, stx, sty, stmin, stmax,
         width, width1) = self.dsave

        # If psi(stp) <= 0 and f'(stp) >= 0 for some step, then the
        # algorithm enters the second stage (psi(stp) <= 0 <=> f <= ftest).
        ftest = finit + stp * gtest
        if stage == 1 and f <= ftest and g >= 0.:
            stage = 2

        # Test for warnings.
        if self.bracket and (stp <= stmin or stp >= stmax):
            self.task = 'WARNING: ROUNDING ERRORS PREVENT PROGRESS'
        if self.bracket and stmax - stmin <= self.xtol * stmax:
            self.task = 'WARNING: XTOL TEST SATISFIED'
        if stp == self.stpmax and f <= ftest and g <= gtest:
            self.task = 'WARNING: STP = maxstep'
        if stp == self.stpmin and (f > ftest or g >= gtest):
            self.task = 'WARNING: STP = minstep'

        # Test for convergence: sufficient decrease and curvature condition.
        if f <= ftest and abs(g) <= c2 * (- ginit):
            self.task = 'CONVERGENCE'

        # Test for termination.
        if self.task[:4] in ('WARN', 'CONV'):
            self.save((stage, ginit, gtest, gx,
                       gy, finit, fx, fy, stx, sty,
                       stmin, stmax, width, width1))
            return stp

        # The stage-1 variant that updates with the modified function
        # f - stp * gtest is not used here; the plain update always applies.
        # Update stx, sty and compute the new step.
        stx, sty, stp, gx, fx, gy, fy = self.update(
            stx, fx, gx, sty, fy, gy, stp, f, g, stmin, stmax)

        # Decide if a bisection step is needed.
        if self.bracket:
            if abs(sty - stx) >= .66 * width1:
                stp = stx + .5 * (sty - stx)
            width1 = width
            width = abs(sty - stx)

        # Set the minimum and maximum steps allowed for stp.
        if self.bracket:
            stmin = min(stx, sty)
            stmax = max(stx, sty)
        else:
            stmin = stp + self.xtrapl * (stp - stx)
            stmax = stp + self.xtrapu * (stp - stx)

        # Force the step to be within the bounds stpmin and stpmax.
        stp = max(stp, self.stpmin)
        stp = min(stp, self.stpmax)

        # Pinned at stpmax at the best point while the next interval lies
        # beyond stpmax: tell the caller to stop evaluating.
        if stx == stp and stp == self.stpmax and stmin > self.stpmax:
            self.no_update = True

        # If further progress is not possible, let stp be the best point
        # obtained during the search. Both tests only apply once a minimizer
        # is bracketed (same grouping as the warning tests above).
        if self.bracket and (stp < stmin or stp >= stmax
                             or stmax - stmin < self.xtol * stmax):
            stp = stx

        # Obtain another function and derivative.
        self.task = 'FG'
        self.save((stage, ginit, gtest, gx,
                   gy, finit, fx, fy, stx, sty,
                   stmin, stmax, width, width1))
        return stp

    def update(self, stx, fx, gx, sty, fy, gy, stp, fp, gp,
               stpmin, stpmax):
        """Compute a safeguarded new step and update the interval.

        Sets ``self.case`` to the case (1-4) that was applied and
        ``self.bracket`` to ``True`` in cases 1 and 2.

        Parameters
        ----------
        stx, fx, gx : float
            Best step so far and its function value / derivative.
        sty, fy, gy : float
            Other interval end point and its function value / derivative.
        stp, fp, gp : float
            Current step and its function value / derivative.
        stpmin, stpmax : float
            Bounds applied to the new step in the unbracketed cases 3 and 4.

        Returns
        -------
        tuple
            ``(stx, sty, stp, gx, fx, gy, fy)``, where the new ``stp`` has
            been passed through :meth:`determine_step`.
        """
        # Only the sign of gp * gx is used below. Multiplying signs avoids
        # the division gx / abs(gx), which fails (or gives nan) for gx == 0;
        # then the product is 0, i.e. not treated as "opposite sign".
        sign = np.sign(gp) * np.sign(gx)

        # First case: A higher function value. The minimum is bracketed.
        # If the cubic step is closer to stx than the quadratic step, the
        # cubic step is taken, otherwise the average of the cubic and
        # quadratic steps is taken.
        if fp > fx:
            self.case = 1
            theta = 3. * (fx - fp) / (stp - stx) + gx + gp
            s = max(abs(theta), abs(gx), abs(gp))
            gamma = s * np.sqrt((theta / s) ** 2. - (gx / s) * (gp / s))
            if stp < stx:
                gamma = -gamma
            p = (gamma - gx) + theta
            q = ((gamma - gx) + gamma) + gp
            r = p / q
            stpc = stx + r * (stp - stx)
            stpq = stx + ((gx / ((fx - fp) / (stp - stx) + gx)) / 2.) \
                * (stp - stx)
            if abs(stpc - stx) < abs(stpq - stx):
                stpf = stpc
            else:
                stpf = stpc + (stpq - stpc) / 2.
            self.bracket = True

        # Second case: A lower function value and derivatives of opposite
        # sign. The minimum is bracketed. If the cubic step is farther from
        # stp than the secant step, the cubic step is taken, otherwise the
        # secant step is taken.
        elif sign < 0:
            self.case = 2
            theta = 3. * (fx - fp) / (stp - stx) + gx + gp
            s = max(abs(theta), abs(gx), abs(gp))
            gamma = s * np.sqrt((theta / s) ** 2 - (gx / s) * (gp / s))
            if stp > stx:
                gamma = -gamma
            p = (gamma - gp) + theta
            q = ((gamma - gp) + gamma) + gx
            r = p / q
            stpc = stp + r * (stx - stp)
            stpq = stp + (gp / (gp - gx)) * (stx - stp)
            if abs(stpc - stp) > abs(stpq - stp):
                stpf = stpc
            else:
                stpf = stpq
            self.bracket = True

        # Third case: A lower function value, derivatives of the same sign,
        # and the magnitude of the derivative decreases.
        elif abs(gp) < abs(gx):
            self.case = 3
            # The cubic step is computed only if the cubic tends to infinity
            # in the direction of the step or if the minimum of the cubic
            # is beyond stp. Otherwise the cubic step is defined to be the
            # secant step.
            theta = 3. * (fx - fp) / (stp - stx) + gx + gp
            s = max(abs(theta), abs(gx), abs(gp))
            # The case gamma = 0 only arises if the cubic does not tend
            # to infinity in the direction of the step.
            gamma = s * np.sqrt(
                max(0., (theta / s) ** 2 - (gx / s) * (gp / s)))
            if stp > stx:
                gamma = -gamma
            p = (gamma - gp) + theta
            q = (gamma + (gx - gp)) + gamma
            r = p / q
            if r < 0. and gamma != 0:
                stpc = stp + r * (stx - stp)
            elif stp > stx:
                stpc = stpmax
            else:
                stpc = stpmin
            stpq = stp + (gp / (gp - gx)) * (stx - stp)

            if self.bracket:
                # A minimizer has been bracketed. If the cubic step is
                # closer to stp than the secant step, the cubic step is
                # taken, otherwise the secant step is taken.
                if abs(stpc - stp) < abs(stpq - stp):
                    stpf = stpc
                else:
                    stpf = stpq
                if stp > stx:
                    stpf = min(stp + .66 * (sty - stp), stpf)
                else:
                    stpf = max(stp + .66 * (sty - stp), stpf)
            else:
                # A minimizer has not been bracketed. If the cubic step is
                # farther from stp than the secant step, the cubic step is
                # taken, otherwise the secant step is taken.
                if abs(stpc - stp) > abs(stpq - stp):
                    stpf = stpc
                else:
                    stpf = stpq
                stpf = min(stpmax, stpf)
                stpf = max(stpmin, stpf)

        # Fourth case: A lower function value, derivatives of the same sign,
        # and the magnitude of the derivative does not decrease. If the
        # minimum is not bracketed, the step is either stpmin or stpmax,
        # otherwise the cubic step is taken.
        else:
            self.case = 4
            if self.bracket:
                theta = 3. * (fp - fy) / (sty - stp) + gy + gp
                s = max(abs(theta), abs(gy), abs(gp))
                gamma = s * np.sqrt((theta / s) ** 2 - (gy / s) * (gp / s))
                if stp > sty:
                    gamma = -gamma
                p = (gamma - gp) + theta
                q = ((gamma - gp) + gamma) + gy
                r = p / q
                stpc = stp + r * (sty - stp)
                stpf = stpc
            elif stp > stx:
                stpf = stpmax
            else:
                stpf = stpmin

        # Update the interval which contains a minimizer.
        if fp > fx:
            sty = stp
            fy = fp
            gy = gp
        else:
            if sign < 0:
                sty = stx
                fy = fx
                gy = gx
            stx = stp
            fx = fp
            gx = gp

        # Compute the new (length-limited) step.
        stp = self.determine_step(stpf)
        return stx, sty, stp, gx, fx, gy, fy

    def determine_step(self, stp):
        """Limit the change from ``self.old_stp`` to ``stp``.

        ``self.pk`` is viewed as rows of 3 components (``reshape`` raises if
        its size is not a multiple of 3; presumably per-atom Cartesian
        vectors, so confirm with callers). If the longest row of
        ``(stp - old_stp) * pk`` is at least ``self.maxstep`` long, the change
        is scaled so that row has length exactly ``self.maxstep``.

        Returns
        -------
        float
            ``self.old_stp`` plus the (possibly scaled) change.
        """
        dr = stp - self.old_stp
        x = np.reshape(self.pk, (-1, 3))
        steplengths = ((dr * x) ** 2).sum(1) ** 0.5
        maxsteplength = steplengths.max()
        if maxsteplength >= self.maxstep:
            dr *= self.maxstep / maxsteplength
        return self.old_stp + dr

    def save(self, data):
        """Store the search state for the next call to :meth:`step`.

        ``data`` is ``(stage, ginit, gtest, gx, gy, finit, fx, fy, stx, sty,
        stmin, stmax, width, width1)``. ``self.bracket`` goes to
        ``isave[0]``, ``stage`` to ``isave[1]``, and the remaining 13 values
        replace ``dsave``.
        """
        self.isave[0] = 1 if self.bracket else 0
        self.isave[1] = data[0]
        self.dsave = data[1:]