Hide keyboard shortcuts

Hot-keys on this page

r m x p   toggle line displays

j k   next/prev highlighted chunk

0   (zero) top of page

1   (one) first highlighted chunk

1import os 

2import atexit 

3import functools 

4import pickle 

5import sys 

6import time 

7import warnings 

8 

9import numpy as np 

10 

11 

def get_txt(txt, rank):
    """Return a writable text stream for output on the given rank.

    An object that already has a ``write`` attribute is handed back
    unchanged, whatever the rank.  Otherwise only rank 0 gets a real
    destination: ``None`` maps to the null device, ``'-'`` to standard
    output, and any other value is opened for writing with line
    buffering.  Every other rank gets the null device.
    """
    # Note: a user-supplied object might write to files from many ranks.
    if hasattr(txt, 'write'):
        return txt
    if rank != 0 or txt is None:
        return open(os.devnull, 'w')
    if txt == '-':
        return sys.stdout
    return open(txt, 'w', 1)

25 

26 

def paropen(name, mode='r', buffering=-1, encoding=None, comm=None):
    """MPI-safe version of the built-in open function.

    In read mode the file is opened on all ranks.  In any other mode
    the file is opened on rank 0 only, while every other rank opens the
    null device instead.

    When *comm* is None the module-level ``world`` communicator decides
    the rank.  The remaining arguments are passed on to ``open``.
    """
    communicator = world if comm is None else comm
    # The mode is only inspected on ranks above 0, as in a plain
    # short-circuiting ``and``.
    redirect = communicator.rank > 0 and mode[0] != 'r'
    target = os.devnull if redirect else name
    return open(target, mode, buffering, encoding)

39 

40 

def parprint(*args, **kwargs):
    """MPI-safe print: output is produced on rank 0 of ``world`` only.

    All arguments are forwarded unchanged to the built-in ``print``.
    """
    if world.rank != 0:
        return
    print(*args, **kwargs)

45 

46 

class DummyMPI:
    """Single-process stand-in for a communicator (rank 0 of size 1)."""

    rank = 0
    size = 1

    def _returnval(self, a, root=-1):
        """Return *a* if it is a scalar, otherwise check it and return None.

        The MPI interface works either on numbers, in which case a number
        is returned, or on arrays, in place.  Objects exposing
        ``__array__`` are converted through it before the type check.
        """
        if np.isscalar(a):
            return a
        array = a.__array__() if hasattr(a, '__array__') else a
        assert isinstance(array, np.ndarray)
        return None

    def sum(self, a, root=-1):
        """Sum over one process: scalars come back as is, arrays stay put."""
        return self._returnval(a)

    def product(self, a, root=-1):
        """Product over one process: same contract as ``sum``."""
        return self._returnval(a)

    def broadcast(self, a, root):
        """Broadcast from *root*, which must be 0 as it is the only rank."""
        assert root == 0
        return self._returnval(a)

    def barrier(self):
        """Do nothing."""
        return None

73 

74 

class MPI:
    """Wrapper for MPI world object.

    Decides at runtime (after all imports) which one to use:

    * MPI4Py
    * GPAW
    * a dummy implementation for serial runs

    The real communicator is obtained from ``_get_comm()`` on the first
    attribute access that is not satisfied by the wrapper itself, cached
    in ``self.comm``, and every such attribute is forwarded to it.
    """
    def __init__(self):
        # None means "not resolved yet".
        self.comm = None

    def __getattr__(self, name):
        """Forward *name* to the lazily resolved communicator.

        Raises AttributeError for ``comm`` itself when the instance has
        no such attribute, and whatever the communicator raises for
        names it does not provide.
        """
        # __getattr__ only runs when normal lookup fails, so a request for
        # 'comm' means the instance never ran __init__ (for example it was
        # created through __new__ by copy or pickle).  Reading self.comm
        # here would re-enter __getattr__ without end, so report the
        # attribute as missing instead.
        if name == 'comm':
            raise AttributeError(name)
        # Look in __dict__ directly for the same reason.
        if self.__dict__.get('comm') is None:
            self.comm = _get_comm()
        return getattr(self.comm, name)

92 

93 

def _get_comm():
    """Get the correct MPI world object.

    Only modules that are already present in ``sys.modules`` are
    considered.  mpi4py wins if loaded; otherwise ``_gpaw`` and then
    ``_asap`` are used when they expose a ``Communicator``; failing all
    of that a ``DummyMPI`` is returned.
    """
    loaded = sys.modules
    if 'mpi4py' in loaded:
        return MPI4PY()
    for modname in ('_gpaw', '_asap'):
        if modname not in loaded:
            continue
        module = __import__(modname)
        if hasattr(module, 'Communicator'):
            return module.Communicator()
    return DummyMPI()

107 

108 

class MPI4PY:
    """Adapter exposing an mpi4py communicator through this module's API.

    The wrapped mpi4py communicator is kept in ``self.comm``.
    """

    def __init__(self, mpi4py_comm=None):
        """Wrap *mpi4py_comm*, or ``MPI.COMM_WORLD`` when it is None."""
        if mpi4py_comm is None:
            # Imported here so mpi4py is only needed when no communicator
            # is supplied.
            from mpi4py import MPI
            mpi4py_comm = MPI.COMM_WORLD
        self.comm = mpi4py_comm

    @property
    def rank(self):
        """Rank of this process in the wrapped communicator."""
        return self.comm.rank

    @property
    def size(self):
        """Number of processes in the wrapped communicator."""
        return self.comm.size

    def _returnval(self, a, b):
        """Behave correctly when working on scalars/arrays.

        Either input is an array and we in-place write b (output from
        mpi4py) back into a, or input is a scalar and we return the
        corresponding output scalar."""
        if np.isscalar(a):
            assert np.isscalar(b)
            return b
        else:
            assert not np.isscalar(b)
            a[:] = b
            return None

    def sum(self, a, root=-1):
        """Sum *a* over all ranks.

        With ``root=-1`` every rank receives the result.  Otherwise the
        result is delivered on rank *root* only; on the other ranks an
        array argument is left untouched and None is returned.

        Scalars are returned, arrays are updated in place (returning
        None).
        """
        if root == -1:
            b = self.comm.allreduce(a)
        else:
            # The second positional parameter of mpi4py's reduce() is the
            # reduction operator, so the root must be passed by keyword.
            b = self.comm.reduce(a, root=root)
            if self.rank != root:
                # reduce() yields None away from the root; there is
                # nothing to hand back or to write into a.
                return None
        return self._returnval(a, b)

    def split(self, split_size=None):
        """Divide the communicator.

        Ranks are grouped into consecutive blocks of
        ``self.size / split_size`` ranks.  A falsy *split_size* is
        replaced by ``self.size``, i.e. blocks of a single rank.
        Returns a new ``MPI4PY`` wrapping the resulting communicator.
        """
        # color - subgroup id
        # key - new subgroup rank
        if not split_size:
            split_size = self.size
        color = int(self.rank // (self.size / split_size))
        key = int(self.rank % (self.size / split_size))
        comm = self.comm.Split(color, key)
        return MPI4PY(comm)

    def barrier(self):
        """Call ``barrier()`` on the wrapped communicator."""
        self.comm.barrier()

    def abort(self, code):
        """Call ``Abort(code)`` on the wrapped communicator."""
        self.comm.Abort(code)

    def broadcast(self, a, root):
        """Broadcast *a* from rank *root*.

        On the root rank a scalar is returned as is and anything else
        yields None.  On the other ranks a scalar result is returned and
        any other result is written into *a* in place (returning None).
        """
        b = self.comm.bcast(a, root=root)
        if self.rank == root:
            if np.isscalar(a):
                return a
            return
        return self._returnval(a, b)

169 

170 

# Module-level communicator.  It is chosen once, at import time, by the
# first branch of the chain below that applies; if none of them produces
# a communicator, a lazy MPI() wrapper is installed instead.
world = None

# Check for special MPI-enabled Python interpreters:
if '_gpaw' in sys.builtin_module_names:
    # http://wiki.fysik.dtu.dk/gpaw
    import _gpaw
    world = _gpaw.Communicator()
elif '_asap' in sys.builtin_module_names:
    # Modern version of Asap
    # http://wiki.fysik.dtu.dk/asap
    # We cannot import asap3.mpi here, as that creates an import deadlock
    import _asap
    world = _asap.Communicator()

# Check if MPI implementation has been imported already:
elif '_gpaw' in sys.modules:
    # Same thing as above but for the module version
    import _gpaw
    try:
        world = _gpaw.Communicator()
    except AttributeError:
        # Deliberately ignored: world stays None and the lazy wrapper
        # below is used.
        pass
elif '_asap' in sys.modules:
    import _asap
    try:
        world = _asap.Communicator()
    except AttributeError:
        # Deliberately ignored, as for _gpaw above.
        pass
elif 'mpi4py' in sys.modules:
    world = MPI4PY()

# Nothing was selected above: defer the decision to the first attribute
# access on the wrapper.
if world is None:
    world = MPI()

204 

205 

def barrier():
    """Call ``barrier()`` on the module-level ``world`` communicator."""
    communicator = world
    communicator.barrier()

208 

209 

def broadcast(obj, root=0, comm=world):
    """Broadcast a Python object across an MPI communicator and return it.

    The root rank pickles *obj* and sends first the length of the
    pickle and then its bytes, using two ``comm.broadcast`` calls on
    integer arrays.  The root returns *obj* itself; every other rank
    returns the object rebuilt from the received bytes.
    """
    is_root = comm.rank == root

    # Step 1: agree on the payload length.
    if is_root:
        payload = pickle.dumps(obj, pickle.HIGHEST_PROTOCOL)
        length = np.array([len(payload)], int)
    else:
        payload = None
        length = np.empty(1, int)
    comm.broadcast(length, root)

    # Step 2: ship the pickled bytes as an int8 array.
    if is_root:
        buf = np.frombuffer(payload, np.int8)
    else:
        buf = np.zeros(length, np.int8)
    comm.broadcast(buf, root)

    if is_root:
        return obj
    return pickle.loads(buf.tobytes())

228 

229 

def parallel_function(func):
    """Decorator for broadcasting from master to slaves using MPI.

    The wrapped function runs on rank 0 only; its return value, or the
    exception it raised, is broadcast and returned or re-raised on
    every rank.

    Disable by passing parallel=False to the function.  For a method,
    you can also disable the parallel behavior by giving the instance
    a self.serial = True.
    """

    @functools.wraps(func)
    def new_func(*args, **kwargs):
        # The three "disable" checks run in this order and stop at the
        # first hit, so the 'parallel' keyword is only removed from
        # kwargs when the first two checks did not apply.
        if world.size == 1:
            return func(*args, **kwargs)
        if args and getattr(args[0], 'serial', False):
            return func(*args, **kwargs)
        if not kwargs.pop('parallel', True):
            return func(*args, **kwargs)

        error = None
        value = None
        if world.rank == 0:
            try:
                value = func(*args, **kwargs)
            except Exception as caught:
                error = caught
        error, value = broadcast((error, value))
        if error is not None:
            raise error
        return value

    return new_func

259 

260 

def parallel_generator(generator):
    """Decorator for broadcasting yields from master to slaves using MPI.

    The wrapped generator runs on rank 0 only.  Each item it yields is
    broadcast and yielded on every rank; an exception it raises is
    broadcast and re-raised on every rank.

    Disable by passing parallel=False to the function.  For a method,
    you can also disable the parallel behavior by giving the instance
    a self.serial = True.
    """

    @functools.wraps(generator)
    def new_generator(*args, **kwargs):
        if (world.size == 1 or
            args and getattr(args[0], 'serial', False) or
            not kwargs.pop('parallel', True)):
            # Disable:
            for result in generator(*args, **kwargs):
                yield result
            return

        # Every message is a tuple (exception, done, result).  The
        # explicit 'done' flag marks the end of the stream, so that a
        # yielded None is passed on like any other item instead of being
        # mistaken for the end marker on the receiving ranks.
        if world.rank == 0:
            try:
                for result in generator(*args, **kwargs):
                    broadcast((None, False, result))
                    yield result
            except Exception as ex:
                broadcast((ex, True, None))
                raise
            broadcast((None, True, None))
        else:
            while True:
                # The argument is a placeholder; only the value sent by
                # rank 0 is used.
                ex2, done, result = broadcast((None, True, None))
                if ex2 is not None:
                    raise ex2
                if done:
                    break
                yield result

    return new_generator

299 

300 

def register_parallel_cleanup_function():
    """Call MPI_Abort if python crashes.

    This will terminate the processes on the other nodes.  Nothing is
    registered when ``world`` consists of a single process.
    """
    if world.size == 1:
        return

    # sys, time and world are bound as default arguments -- presumably so
    # the hook still has them when it runs at interpreter exit.
    def cleanup(sys=sys, time=time, world=world):
        error = getattr(sys, 'last_type', None)
        if not error:
            return
        sys.stdout.flush()
        message = ('ASE CLEANUP (node %d): %s occurred. ' +
                   'Calling MPI_Abort!\n') % (world.rank, error)
        sys.stderr.write(message)
        sys.stderr.flush()
        # Give other nodes a moment to crash by themselves (perhaps
        # producing helpful error messages):
        time.sleep(3)
        world.abort(42)

    atexit.register(cleanup)

322 

323 

def distribute_cpus(size, comm):
    """Distribute cpus to tasks and calculators.

    Input:
    size: number of nodes per calculator
    comm: total communicator object

    Output:
    communicator for this rank, number of calculators, index for this rank

    *size* must not exceed ``comm.size`` and must divide it evenly.
    """
    assert size <= comm.size
    assert comm.size % size == 0

    # Consecutive ranks are grouped: ranks [k * size, (k + 1) * size)
    # form group number k.
    group_index = comm.rank // size
    first = group_index * size
    members = np.arange(first, first + size)
    subcomm = comm.new_communicator(members)

    return subcomm, comm.size // size, group_index

345 

346 

class ParallelModuleWrapper:
    """Attribute proxy for the ``_parallel`` module object.

    ``rank`` and ``size`` are answered from ``world`` after emitting a
    FutureWarning; every other attribute is fetched from ``_parallel``.
    """

    def __getattr__(self, name):
        if name not in ('rank', 'size'):
            return getattr(_parallel, name)

        template = ('ase.parallel.{name} has been deprecated.  '
                    'Please use ase.parallel.world.{name} instead.')
        warnings.warn(template.format(name=name), FutureWarning)
        return getattr(world, name)

356 

357 

# Keep a reference to the module object registered as 'ase.parallel', then
# put a ParallelModuleWrapper instance in its place in sys.modules.  The
# wrapper forwards attribute lookups to the saved module via _parallel.
_parallel = sys.modules['ase.parallel']
sys.modules['ase.parallel'] = ParallelModuleWrapper()  # type: ignore