Hot-keys on this page
r m x p toggle line displays
j k next/prev highlighted chunk
0 (zero) top of page
1 (one) first highlighted chunk
1import os
2import atexit
3import functools
4import pickle
5import sys
6import time
7import warnings
9import numpy as np
def get_txt(txt, rank):
    """Return a writable text stream for log output on the given rank.

    Parameters:
        txt: Either an object with a ``write`` attribute (returned
            unchanged on every rank), ``None`` (output is discarded),
            the string ``'-'`` (standard output) or a file name, which
            is opened for writing with line buffering.
        rank: Rank of the calling process.  Unless ``txt`` already is a
            stream, only rank 0 gets a real stream; every other rank
            gets a handle on ``os.devnull``.
    """
    if hasattr(txt, 'write'):
        # Note: a user-supplied object might write to files from many ranks.
        return txt
    if rank != 0 or txt is None:
        return open(os.devnull, 'w')
    if txt == '-':
        return sys.stdout
    return open(txt, 'w', 1)
def paropen(name, mode='r', buffering=-1, encoding=None, comm=None):
    """MPI-safe version of the built-in ``open``.

    Files opened with a mode starting with ``'r'`` are opened on every
    rank.  For any other mode the real file is opened on rank 0 only;
    ranks above 0 open ``os.devnull`` instead.

    ``comm`` is the communicator whose ``rank`` decides this; the
    module-level ``world`` is used when it is None.
    """
    communicator = world if comm is None else comm
    # The mode is only inspected on ranks above 0 (short-circuit).
    redirect = communicator.rank > 0 and mode[0] != 'r'
    target = os.devnull if redirect else name
    return open(target, mode, buffering, encoding)
def parprint(*args, **kwargs):
    """MPI-safe print: forwards to ``print`` on rank 0 of ``world`` only."""
    if world.rank != 0:
        return
    print(*args, **kwargs)
class DummyMPI:
    """Communicator stand-in with a single process (rank 0, size 1).

    The collective operations have nothing to exchange, so they only
    reproduce the calling convention: a scalar argument is returned
    as-is, an array argument is left alone and ``None`` is returned.
    """

    rank = 0
    size = 1

    def _returnval(self, a, root=-1):
        """Return ``a`` if it is a scalar, otherwise check it and return None.

        Non-scalars must be NumPy arrays or expose ``__array__``
        producing one; anything else trips the assertion.
        """
        # The MPI interface works either on numbers, in which case a
        # number is returned, or on arrays, in-place.
        if np.isscalar(a):
            return a
        array = a.__array__() if hasattr(a, '__array__') else a
        assert isinstance(array, np.ndarray)
        return None

    def sum(self, a, root=-1):
        """Sum over one process: see ``_returnval``; ``root`` is ignored."""
        return self._returnval(a)

    def product(self, a, root=-1):
        """Product over one process: see ``_returnval``; ``root`` is ignored."""
        return self._returnval(a)

    def broadcast(self, a, root):
        """Broadcast from ``root``, which must be 0, the only rank."""
        assert root == 0
        return self._returnval(a)

    def barrier(self):
        """Do nothing: there is no other process to wait for."""
        pass
class MPI:
    """Wrapper for MPI world object.

    Decides at runtime (after all imports) which one to use:

    * MPI4Py
    * GPAW
    * a dummy implementation for serial runs

    The decision is taken by ``_get_comm()`` the first time an attribute
    that the wrapper itself does not have is requested; the result is
    cached in ``self.comm`` and all such lookups are forwarded to it.
    """

    def __init__(self):
        # None means "not resolved yet".
        self.comm = None

    def __getattr__(self, name):
        """Forward the lookup of ``name`` to the lazily created communicator.

        Raises:
            AttributeError: if ``name`` is ``'comm'``, or if the
                communicator has no such attribute.
        """
        if name == 'comm':
            # __getattr__ only runs when normal lookup failed, so asking
            # for 'comm' here means __init__ never ran (copy and pickle
            # create instances that way and then probe them with
            # hasattr()).  Reading self.comm below would re-enter this
            # method without end, so report the attribute as missing.
            raise AttributeError(name)
        if self.comm is None:
            self.comm = _get_comm()
        return getattr(self.comm, name)
def _get_comm():
    """Get the correct MPI world object.

    Only modules that have already been imported are considered, in
    this order: mpi4py, ``_gpaw``, ``_asap``.  A ``_gpaw``/``_asap``
    module without a ``Communicator`` attribute is skipped.  If nothing
    matches, a ``DummyMPI`` is returned.
    """
    loaded = sys.modules
    if 'mpi4py' in loaded:
        return MPI4PY()
    for backend in ('_gpaw', '_asap'):
        if backend not in loaded:
            continue
        # Same effect as an ``import`` statement for an already loaded,
        # top-level module.
        module = __import__(backend)
        if hasattr(module, 'Communicator'):
            return module.Communicator()
    return DummyMPI()
class MPI4PY:
    """Adapter exposing an mpi4py communicator through this file's interface.

    Collective operations accept either a scalar, in which case the
    result is returned, or an array, which is updated in place and for
    which ``None`` is returned.
    """

    def __init__(self, mpi4py_comm=None):
        """Wrap ``mpi4py_comm``, or ``MPI.COMM_WORLD`` when it is None."""
        if mpi4py_comm is None:
            # Imported here so that mpi4py is only needed when it is used.
            from mpi4py import MPI
            mpi4py_comm = MPI.COMM_WORLD
        self.comm = mpi4py_comm

    @property
    def rank(self):
        """Rank of this process in the wrapped communicator."""
        return self.comm.rank

    @property
    def size(self):
        """Number of processes in the wrapped communicator."""
        return self.comm.size

    def _returnval(self, a, b):
        """Behave correctly when working on scalars/arrays.

        Either input is an array and we in-place write b (output from
        mpi4py) back into a, or input is a scalar and we return the
        corresponding output scalar."""
        if np.isscalar(a):
            assert np.isscalar(b)
            return b
        else:
            assert not np.isscalar(b)
            a[:] = b
            return None

    def sum(self, a, root=-1):
        """Sum ``a`` over all ranks.

        With ``root == -1`` every rank receives the sum.  Otherwise only
        rank ``root`` receives it: there a scalar sum is returned or the
        array ``a`` is overwritten with the sum.  On the other ranks no
        result is available, so an array is left untouched (``None`` is
        returned) and a scalar is returned unchanged.
        """
        if root == -1:
            b = self.comm.allreduce(a)
        else:
            # ``root`` has to be passed by keyword: the second positional
            # parameter of mpi4py's reduce() is the reduction operator.
            b = self.comm.reduce(a, root=root)
            if self.rank != root:
                # reduce() returns None away from the root; writing that
                # into ``a`` would destroy the local data.
                return a if np.isscalar(a) else None
        return self._returnval(a, b)

    def split(self, split_size=None):
        """Divide the communicator.

        ``split_size`` defaults to ``self.size`` when it is falsy.  Each
        rank computes its color and key from ``size / split_size`` and
        the resulting mpi4py communicator is returned wrapped in a new
        ``MPI4PY``.
        """
        # color - subgroup id
        # key - new subgroup rank
        if not split_size:
            split_size = self.size
        color = int(self.rank // (self.size / split_size))
        key = int(self.rank % (self.size / split_size))
        comm = self.comm.Split(color, key)
        return MPI4PY(comm)

    def barrier(self):
        """Call ``barrier()`` on the wrapped communicator."""
        self.comm.barrier()

    def abort(self, code):
        """Call ``Abort(code)`` on the wrapped communicator."""
        self.comm.Abort(code)

    def broadcast(self, a, root):
        """Broadcast ``a`` from rank ``root``.

        On ``root`` a scalar ``a`` is returned and an array yields
        ``None``.  On the other ranks the received scalar is returned,
        or the received data is written into the array ``a`` in place.
        """
        b = self.comm.bcast(a, root=root)
        if self.rank == root:
            if np.isscalar(a):
                return a
            return
        return self._returnval(a, b)
# Module-level communicator.  It is created right away when an MPI
# backend can be identified at import time; otherwise the lazy MPI
# wrapper postpones the choice until the communicator is first used.
world = None

# Check for special MPI-enabled Python interpreters:
if '_gpaw' in sys.builtin_module_names:
    # http://wiki.fysik.dtu.dk/gpaw
    import _gpaw
    world = _gpaw.Communicator()
elif '_asap' in sys.builtin_module_names:
    # Modern version of Asap
    # http://wiki.fysik.dtu.dk/asap
    # asap3.mpi cannot be imported here, as that creates an import deadlock.
    import _asap
    world = _asap.Communicator()

# Check if an MPI implementation has been imported already:
elif '_gpaw' in sys.modules:
    # Same thing as above, but for the module version.
    import _gpaw
    try:
        world = _gpaw.Communicator()
    except AttributeError:
        # No usable Communicator: keep the "unresolved" marker.
        world = None
elif '_asap' in sys.modules:
    import _asap
    try:
        world = _asap.Communicator()
    except AttributeError:
        # No usable Communicator: keep the "unresolved" marker.
        world = None
elif 'mpi4py' in sys.modules:
    world = MPI4PY()

# Nothing was resolved above: decide lazily.
if world is None:
    world = MPI()
def barrier():
    """Call ``barrier()`` on the module-level ``world`` communicator."""
    world.barrier()
def broadcast(obj, root=0, comm=world):
    """Broadcast a Python object across an MPI communicator and return it.

    The object is pickled on ``root`` and sent in two steps: first the
    length of the pickle, then the pickle itself as an int8 array.
    ``root`` returns ``obj`` itself; the other ranks return the
    unpickled copy and ignore their own ``obj`` argument.
    """
    is_root = comm.rank == root

    # Step 1: tell every rank how many bytes to expect.
    if is_root:
        payload = pickle.dumps(obj, pickle.HIGHEST_PROTOCOL)
        length = np.array([len(payload)], int)
    else:
        payload = None
        length = np.empty(1, int)
    comm.broadcast(length, root)

    # Step 2: send the pickled bytes into a buffer of that size.
    if is_root:
        buffer = np.frombuffer(payload, np.int8)
    else:
        buffer = np.zeros(length, np.int8)
    comm.broadcast(buffer, root)

    return obj if is_root else pickle.loads(buffer.tobytes())
def parallel_function(func):
    """Decorator for broadcasting from master to slaves using MPI.

    The decorated function runs on rank 0 of ``world`` only; its return
    value, or the exception it raised, is broadcast and returned or
    re-raised on every rank.

    Disable by passing parallel=False to the function.  For a method,
    you can also disable the parallel behavior by giving the instance
    a self.serial = True.  When disabled, or when ``world.size`` is 1,
    the function is simply called on the calling rank.

    The ``parallel`` keyword belongs to the decorator and is never
    forwarded to ``func``.
    """

    @functools.wraps(func)
    def new_func(*args, **kwargs):
        # Consume the flag before anything else.  Popping it inside the
        # ``or`` chain below would be skipped whenever an earlier term
        # is true, leaving ``parallel`` in kwargs for ``func``.
        parallel = kwargs.pop('parallel', True)
        if (world.size == 1 or
                args and getattr(args[0], 'serial', False) or
                not parallel):
            # Disable:
            return func(*args, **kwargs)

        ex = None
        result = None
        if world.rank == 0:
            try:
                result = func(*args, **kwargs)
            except Exception as x:
                # Ship the exception so that all ranks fail together.
                ex = x
        ex, result = broadcast((ex, result))
        if ex is not None:
            raise ex
        return result

    return new_func
def parallel_generator(generator):
    """Decorator for broadcasting yields from master to slaves using MPI.

    The decorated generator runs on rank 0 of ``world`` only; every
    value it yields is broadcast and yielded on all ranks.  An
    exception raised on rank 0 is broadcast and re-raised on all ranks.

    Disable by passing parallel=False to the function.  For a method,
    you can also disable the parallel behavior by giving the instance
    a self.serial = True.  When disabled, or when ``world.size`` is 1,
    the generator is simply iterated on the calling rank.

    The ``parallel`` keyword belongs to the decorator and is never
    forwarded to ``generator``.
    """

    @functools.wraps(generator)
    def new_generator(*args, **kwargs):
        # Consume the flag before anything else.  Popping it inside the
        # ``or`` chain below would be skipped whenever an earlier term
        # is true, leaving ``parallel`` in kwargs for ``generator``.
        parallel = kwargs.pop('parallel', True)
        if (world.size == 1 or
                args and getattr(args[0], 'serial', False) or
                not parallel):
            # Disable:
            for result in generator(*args, **kwargs):
                yield result
            return

        # Every message is (exception, done, value).  The explicit
        # ``done`` flag marks the end of the stream, so that a yielded
        # None is passed on like any other value.
        if world.rank == 0:
            try:
                for result in generator(*args, **kwargs):
                    broadcast((None, False, result))
                    yield result
            except Exception as ex:
                broadcast((ex, True, None))
                raise ex
            broadcast((None, True, None))
        else:
            while True:
                # The argument is ignored away from the root rank.
                ex2, done, result = broadcast((None, True, None))
                if ex2 is not None:
                    raise ex2
                if done:
                    return
                yield result

    return new_generator
def register_parallel_cleanup_function():
    """Call MPI_Abort if python crashes.

    This will terminate the processes on the other nodes.  Nothing is
    registered when ``world.size`` is 1.
    """
    if world.size == 1:
        return

    # sys, time and world are bound as defaults when cleanup is defined.
    def cleanup(sys=sys, time=time, world=world):
        error = getattr(sys, 'last_type', None)
        if not error:
            # No recorded exception type: leave without aborting.
            return
        sys.stdout.flush()
        sys.stderr.write(('ASE CLEANUP (node %d): %s occurred. '
                          'Calling MPI_Abort!\n') % (world.rank, error))
        sys.stderr.flush()
        # Give other nodes a moment to crash by themselves (perhaps
        # producing helpful error messages):
        time.sleep(3)
        world.abort(42)

    atexit.register(cleanup)
def distribute_cpus(size, comm):
    """Distribute cpus to tasks and calculators.

    Input:
        size: number of nodes per calculator
        comm: total communicator object

    Output:
        communicator for this rank, number of calculators, index for this rank

    ``size`` must not exceed ``comm.size`` and must divide it evenly.
    """
    total = comm.size
    assert size <= total
    assert total % size == 0

    # Ranks are grouped in consecutive runs of ``size``.
    group_index = comm.rank // size
    first_rank = group_index * size
    members = np.arange(first_rank, first_rank + size)
    group_comm = comm.new_communicator(members)

    return group_comm, total // size, group_index
class ParallelModuleWrapper:
    """Module stand-in that keeps the deprecated ``rank``/``size`` working.

    Lookups of ``rank`` and ``size`` emit a FutureWarning and are
    answered by ``world``; every other name is looked up on the real
    module object kept in ``_parallel``.
    """

    def __getattr__(self, name):
        if name in ('rank', 'size'):
            warnings.warn('ase.parallel.{name} has been deprecated. '
                          'Please use ase.parallel.world.{name} instead.'
                          .format(name=name),
                          FutureWarning)
            return getattr(world, name)
        return getattr(_parallel, name)


# Keep a reference to the real module, then put the wrapper in its place
# in sys.modules.
_parallel = sys.modules['ase.parallel']
sys.modules['ase.parallel'] = ParallelModuleWrapper()  # type: ignore