Hot-keys on this page
r m x p toggle line displays
j k next/prev highlighted chunk
0 (zero) top of page
1 (one) first highlighted chunk
1import os
2import socket
3from subprocess import Popen, PIPE
4from contextlib import contextmanager
6import numpy as np
8from ase.calculators.calculator import (Calculator, all_changes,
9 PropertyNotImplementedError)
10import ase.units as units
11from ase.utils import IOContext
12from ase.stress import full_3x3_to_voigt_6_stress
def actualunixsocketname(name):
    """Return the filesystem path used for the UNIX socket called *name*.

    The path is the fixed prefix ``/tmp/ipi_`` followed by *name*."""
    return f'/tmp/ipi_{name}'
class SocketClosed(OSError):
    """Raised when a read finds that the peer has closed the socket."""
class IPIProtocol:
    """Communication using IPI protocol.

    Wraps a connected socket and implements both sides of the message
    exchange: the driving side (``status``, ``sendposdata``,
    ``sendrecv_force``, ``calculate``, ...) and the answering side
    (``recvposdata``, ``sendforce``, ``recvinit``).  Quantities are
    converted between ASE units and the atomic units (Bohr, Hartree)
    used on the wire."""

    def __init__(self, socket, txt=None):
        """Store the socket and set up logging.

        socket: connected socket object used for all traffic.
        txt: file object or None.  If given, each protocol step is
            printed to it with a 'Driver:' prefix and flushed."""
        self.socket = socket

        if txt is None:
            def log(*args):
                pass
        else:
            def log(*args):
                print('Driver:', *args, file=txt)
                txt.flush()
        self.log = log

    def sendmsg(self, msg):
        """Send the header string *msg*, space-padded to 12 bytes."""
        self.log(' sendmsg', repr(msg))
        msg = msg.encode('ascii').ljust(12)
        self.socket.sendall(msg)

    def _recvall(self, nbytes):
        """Repeatedly read chunks until we have nbytes.

        Normally we get all bytes in one read, but that is not guaranteed.

        Raises SocketClosed if the peer closes the connection first."""
        remaining = nbytes
        chunks = []
        while remaining > 0:
            chunk = self.socket.recv(remaining)
            if len(chunk) == 0:
                # (If socket is still open, recv returns at least one byte)
                raise SocketClosed()
            chunks.append(chunk)
            remaining -= len(chunk)
        msg = b''.join(chunks)
        assert len(msg) == nbytes and remaining == 0
        return msg

    def recvmsg(self):
        """Receive one 12-byte header and return it as a stripped string.

        Raises SocketClosed (from ``_recvall``) if the peer hung up."""
        # _recvall() either returns exactly 12 bytes or raises, so no
        # further emptiness/length checks are needed here.
        msg = self._recvall(12)
        msg = msg.rstrip().decode('ascii')
        self.log(' recvmsg', repr(msg))
        return msg

    def send(self, a, dtype):
        """Send *a* converted to an array of *dtype* as raw bytes."""
        buf = np.asarray(a, dtype).tobytes()
        self.log(' send {} bytes of {}'.format(len(buf), dtype))
        self.socket.sendall(buf)

    def recv(self, shape, dtype):
        """Receive an array of the given *shape* and *dtype*.

        The received values are asserted to be finite."""
        a = np.empty(shape, dtype)
        # a.nbytes is a plain Python int (itemsize * number of elements).
        buf = self._recvall(a.nbytes)
        self.log(' recv {} bytes of {}'.format(len(buf), dtype))
        a.flat[:] = np.frombuffer(buf, dtype=dtype)
        assert np.isfinite(a).all()
        return a

    def sendposdata(self, cell, icell, positions):
        """Send a POSDATA message: cell, inverse cell, atom count, positions.

        Lengths are converted from ASE units to Bohr; both cell matrices
        are sent transposed."""
        assert cell.size == 9
        assert icell.size == 9
        assert positions.size % 3 == 0

        self.log(' sendposdata')
        self.sendmsg('POSDATA')
        self.send(cell.T / units.Bohr, np.float64)
        self.send(icell.T * units.Bohr, np.float64)
        self.send(len(positions), np.int32)
        self.send(positions / units.Bohr, np.float64)

    def recvposdata(self):
        """Receive the payload of a POSDATA message.

        Returns (cell, icell, positions) converted back from Bohr."""
        cell = self.recv((3, 3), np.float64).T.copy()
        icell = self.recv((3, 3), np.float64).T.copy()
        # Index the 1-element array explicitly: int() on an array with
        # ndim > 0 is deprecated in recent NumPy versions.
        natoms = int(self.recv(1, np.int32)[0])
        positions = self.recv((natoms, 3), np.float64)
        return cell * units.Bohr, icell / units.Bohr, positions * units.Bohr

    def sendrecv_force(self):
        """Send GETFORCE and receive the FORCEREADY answer.

        Returns (energy, forces, virial, morebytes) with energy-like
        quantities converted from Hartree/Bohr.  ``morebytes`` is an
        array of bytes, or ``b''`` if the peer announced zero bytes."""
        self.log(' sendrecv_force')
        self.sendmsg('GETFORCE')
        msg = self.recvmsg()
        assert msg == 'FORCEREADY', msg
        e = self.recv(1, np.float64)[0]
        natoms = int(self.recv(1, np.int32)[0])
        assert natoms >= 0
        forces = self.recv((natoms, 3), np.float64)
        virial = self.recv((3, 3), np.float64).T.copy()
        nmorebytes = int(self.recv(1, np.int32)[0])
        if nmorebytes > 0:
            # Receiving 0 bytes will block forever on python2.
            morebytes = self.recv(nmorebytes, np.byte)
        else:
            morebytes = b''
        return (e * units.Ha, (units.Ha / units.Bohr) * forces,
                units.Ha * virial, morebytes)

    def sendforce(self, energy, forces, virial,
                  morebytes=np.zeros(1, dtype=np.byte)):
        """Send a FORCEREADY message with energy, forces and virial.

        energy: scalar; forces: (natoms, 3) array; virial: (3, 3) array,
        all in ASE units (converted to Hartree/Bohr here).
        morebytes: extra payload; the default is one zero byte (the
        default array is shared between calls and must not be modified)."""
        assert np.array([energy]).size == 1
        assert forces.shape[1] == 3
        assert virial.shape == (3, 3)

        self.log(' sendforce')
        self.sendmsg('FORCEREADY')  # mind the units
        self.send(np.array([energy / units.Ha]), np.float64)
        natoms = len(forces)
        self.send(np.array([natoms]), np.int32)
        self.send(units.Bohr / units.Ha * forces, np.float64)
        self.send(1.0 / units.Ha * virial.T, np.float64)
        # We prefer to always send at least one byte due to trouble with
        # empty messages.  Reading a closed socket yields 0 bytes
        # and thus can be confused with a 0-length bytestring.
        self.send(np.array([len(morebytes)]), np.int32)
        self.send(morebytes, np.byte)

    def status(self):
        """Send STATUS and return the peer's reply string."""
        self.log(' status')
        self.sendmsg('STATUS')
        msg = self.recvmsg()
        return msg

    def end(self):
        """Send EXIT to the peer."""
        self.log(' end')
        self.sendmsg('EXIT')

    def recvinit(self):
        """Receive the payload of an INIT message.

        Returns (bead_index, initbytes): a 1-element int32 array and an
        array holding the initialization bytes."""
        self.log(' recvinit')
        bead_index = self.recv(1, np.int32)
        # Use a plain int for the byte count rather than a 1-element array.
        nbytes = int(self.recv(1, np.int32)[0])
        initbytes = self.recv(nbytes, np.byte)
        return bead_index, initbytes

    def sendinit(self):
        """Send an INIT message with bead index 0 and a single zero byte."""
        # XXX Not sure what this function is supposed to send.
        # It 'works' with QE, but for now we try not to call it.
        self.log(' sendinit')
        self.sendmsg('INIT')
        self.send(0, np.int32)  # 'bead index' always zero for now
        # We send one byte, which is zero, since things may not work
        # with 0 bytes.  Apparently implementations ignore the
        # initialization string anyway.
        self.send(1, np.int32)
        self.send(np.zeros(1), np.byte)  # initialization string

    def calculate(self, positions, cell):
        """Run one full exchange: send geometry, return results as dict.

        The dict has keys 'energy', 'forces' and 'virial', plus
        'morebytes' when the peer sent extra data containing at least one
        non-zero byte."""
        self.log('calculate')
        msg = self.status()
        # We don't know how NEEDINIT is supposed to work, but some codes
        # seem to be okay if we skip it and send the positions instead.
        if msg == 'NEEDINIT':
            self.sendinit()
            msg = self.status()
        assert msg == 'READY', msg
        icell = np.linalg.pinv(cell).transpose()
        self.sendposdata(cell, icell, positions)
        msg = self.status()
        assert msg == 'HAVEDATA', msg
        e, forces, virial, morebytes = self.sendrecv_force()
        r = dict(energy=e,
                 forces=forces,
                 virial=virial)
        # A plain ``if morebytes:`` raises ValueError for arrays with more
        # than one element.  Test explicitly instead; an empty payload or
        # the single zero placeholder byte still counts as "no data".
        if len(morebytes) > 0 and np.any(morebytes):
            r['morebytes'] = morebytes
        return r
@contextmanager
def bind_unixsocket(socketfile):
    """Context manager yielding a UNIX server socket bound to *socketfile*.

    *socketfile* must start with ``/tmp/ipi_``.  On exit the socket is
    closed and the socket file is removed.

    Raises OSError (chained to the original error, with the file name
    appended to the message) if binding fails; in that case the socket
    is closed and the existing file is left alone."""
    assert socketfile.startswith('/tmp/ipi_'), socketfile
    serversocket = socket.socket(socket.AF_UNIX)
    try:
        serversocket.bind(socketfile)
    except OSError as err:
        # Do not leak the descriptor when the bind fails.
        serversocket.close()
        raise OSError('{}: {}'.format(err, repr(socketfile))) from err

    try:
        with serversocket:
            yield serversocket
    finally:
        os.unlink(socketfile)
@contextmanager
def bind_inetsocket(port):
    """Context manager yielding an INET server socket bound to *port*.

    The socket is bound on all interfaces with SO_REUSEADDR set, and is
    closed on exit -- including when configuring or binding it fails."""
    serversocket = socket.socket(socket.AF_INET)
    # Enter the socket's context before configuring it so that a failing
    # setsockopt()/bind() does not leak the descriptor.
    with serversocket:
        serversocket.setsockopt(socket.SOL_SOCKET,
                                socket.SO_REUSEADDR, 1)
        serversocket.bind(('', port))
        yield serversocket
class FileIOSocketClientLauncher:
    """Callable that writes input files for a calculator and starts it.

    The wrapped calculator must provide ``command``, ``prefix``,
    ``directory`` and ``write_input()``."""

    def __init__(self, calc):
        self.calc = calc

    def __call__(self, atoms, properties=None, port=None, unixsocket=None):
        """Write the input for *atoms* and launch the client process.

        'PREFIX' in the calculator's command is replaced by its prefix,
        and ``{port}`` / ``{unixsocket}`` placeholders are filled in.
        Returns the Popen object of the shell running the command."""
        calc = self.calc
        assert calc is not None
        template = calc.command.replace('PREFIX', calc.prefix)
        calc.write_input(atoms, properties=properties,
                         system_changes=all_changes)
        workdir = calc.directory
        command = template.format(port=port, unixsocket=unixsocket)
        return Popen(command, shell=True, cwd=workdir)
class SocketServer(IOContext):
    """Server end of the connection: binds a socket and drives a client."""
    default_port = 31415

    def __init__(self,
                 port=None, unixsocket=None, timeout=None,
                 log=None):
        """Create server and listen for connections.

        The server does not launch a client itself.  The user (or the
        code owning the server) is expected to launch a client whose
        connection the server will then accept.  Once calculate() is
        called, the server will block to wait for the client.  If a
        subprocess handle is stored in ``self.proc``, it is polled while
        waiting and waited for in close().

        Parameters:

        port: integer or None
            Port on which to listen for INET connections.  Defaults
            to 31415 if neither this nor unixsocket is specified.
        unixsocket: string or None
            Filename for unix socket.
        timeout: float or None
            timeout in seconds, or unlimited by default.
            This parameter is passed to the Python socket object; see
            documentation thereof
        log: file object or None
            useful debug messages are written to this.

        Raises ValueError if both port and unixsocket are given."""

        if unixsocket is None and port is None:
            port = self.default_port
        elif unixsocket is not None and port is not None:
            raise ValueError('Specify only one of unixsocket and port')

        self.port = port
        self.unixsocket = unixsocket
        self.timeout = timeout
        self._closed = False

        if unixsocket is not None:
            actualsocket = actualunixsocketname(unixsocket)
            conn_name = 'UNIX-socket {}'.format(actualsocket)
            socket_context = bind_unixsocket(actualsocket)
        else:
            conn_name = 'INET port {}'.format(port)
            socket_context = bind_inetsocket(port)

        # NOTE(review): closelater() comes from IOContext; presumably it
        # enters the context manager, returns the bound socket and arranges
        # for it to be closed by close() -- confirm in ase.utils.
        self.serversocket = self.closelater(socket_context)

        if log:
            print('Accepting clients on {}'.format(conn_name), file=log)

        self.serversocket.settimeout(timeout)

        self.serversocket.listen(1)

        self.log = log

        # Subprocess handle of the client, if somebody launched one for us.
        self.proc = None

        self.protocol = None
        self.clientsocket = None
        self.address = None

    def _accept(self):
        """Wait for client and establish connection.

        Raises OSError if ``self.proc`` is set and the subprocess
        terminates before connecting."""
        # It should perhaps be possible for process to be launched by user
        log = self.log
        if log:
            print('Awaiting client', file=self.log)

        # If we launched the subprocess, the process may crash.
        # We want to detect this, using loop with timeouts, and
        # raise an error rather than blocking forever.
        if self.proc is not None:
            self.serversocket.settimeout(1.0)

        while True:
            try:
                self.clientsocket, self.address = self.serversocket.accept()
                self.closelater(self.clientsocket)
            except socket.timeout:
                # NOTE: without a subprocess to poll, a timeout simply
                # leads to another accept() attempt.
                if self.proc is not None:
                    status = self.proc.poll()
                    if status is not None:
                        raise OSError('Subprocess terminated unexpectedly'
                                      ' with status {}'.format(status))
            else:
                break

        self.serversocket.settimeout(self.timeout)
        self.clientsocket.settimeout(self.timeout)

        if log:
            # For unix sockets the peer address is empty ('' or b'',
            # depending on how the platform reports it); any empty
            # address is shown as 'client'.
            source = self.address if self.address else 'client'
            print('Accepted connection from {}'.format(source), file=log)

        self.protocol = IPIProtocol(self.clientsocket, txt=log)

    def close(self):
        """Close the sockets and wait for the subprocess, if any.

        Calling close() more than once is harmless.  A non-zero exit
        status of the subprocess only triggers a warning."""
        if self._closed:
            return

        super().close()

        if self.log:
            print('Close socket server', file=self.log)
        self._closed = True

        # Proper way to close sockets?
        # And indeed i-pi connections...
        # (We currently do not send an end-of-communication message.)
        self.protocol = None
        if self.proc is not None:
            exitcode = self.proc.wait()
            if exitcode != 0:
                import warnings
                # Quantum Espresso seems to always exit with status 128,
                # even if successful.
                # Should investigate at some point
                warnings.warn('Subprocess exited with status {}'
                              .format(exitcode))

    def calculate(self, atoms):
        """Send geometry to client and return calculated things as dict.

        This will block until client has established connection, then
        wait for the client to finish the calculation."""
        assert not self._closed

        # If we have not established connection yet, we must block
        # until the client catches up:
        if self.protocol is None:
            self._accept()
        return self.protocol.calculate(atoms.positions, atoms.cell)
class SocketClient:
    """Client end of the connection: answers a server's requests by
    running calculations on an Atoms object."""

    def __init__(self, host='localhost', port=None,
                 unixsocket=None, timeout=None, log=None, comm=None):
        """Create client and connect to server.

        Parameters:

        host: string
            Hostname of server.  Defaults to localhost
        port: integer or None
            Port to which to connect.  By default 31415.
        unixsocket: string or None
            If specified, use corresponding UNIX socket.
            See documentation of unixsocket for SocketIOCalculator.
        timeout: float or None
            See documentation of timeout for SocketIOCalculator.
        log: file object or None
            Log events to this file
        comm: communicator or None
            MPI communicator object.  Defaults to ase.parallel.world.
            When ASE runs in parallel, only the process with world.rank == 0
            will communicate over the socket.  The received information
            will then be broadcast on the communicator.  The SocketClient
            must be created on all ranks of world, and will see the same
            Atoms objects."""
        if comm is None:
            from ase.parallel import world
            comm = world

        # Only rank0 actually does the socket work.
        # The other ranks only need to follow.
        #
        # Note: We actually refrain from assigning all the
        # socket-related things except on master
        self.comm = comm

        if self.comm.rank == 0:
            if unixsocket is not None:
                family = socket.AF_UNIX
                address = actualunixsocketname(unixsocket)
            else:
                if port is None:
                    port = SocketServer.default_port
                family = socket.AF_INET
                address = (host, port)

            sock = socket.socket(family)
            try:
                sock.connect(address)
            except BaseException:
                # Do not leak the descriptor if the server is unreachable.
                sock.close()
                raise
            sock.settimeout(timeout)
            self.host = host
            self.port = port
            self.unixsocket = unixsocket

            self.protocol = IPIProtocol(sock, txt=log)
            self.log = self.protocol.log
            self.closed = False

            self.bead_index = 0
            self.bead_initbytes = b''
            self.state = 'READY'

    def close(self):
        """Close the socket (only meaningful on rank 0); idempotent."""
        if not self.closed:
            self.log('Close SocketClient')
            self.closed = True
            self.protocol.socket.close()

    def calculate(self, atoms, use_stress):
        """Broadcast the geometry from rank 0 and evaluate *atoms*.

        Returns (energy, forces, virial); the virial is zero unless
        *use_stress* is true, in which case it is -volume * stress."""
        # We should also broadcast the bead index, once we support doing
        # multiple beads.
        self.comm.broadcast(atoms.positions, 0)
        self.comm.broadcast(np.ascontiguousarray(atoms.cell), 0)

        energy = atoms.get_potential_energy()
        forces = atoms.get_forces()
        if use_stress:
            stress = atoms.get_stress(voigt=False)
            virial = -atoms.get_volume() * stress
        else:
            virial = np.zeros((3, 3))
        return energy, forces, virial

    def irun(self, atoms, use_stress=None):
        """Return a generator that serves requests, yielding per step.

        If *use_stress* is None, stress is calculated when any direction
        of *atoms* is periodic."""
        if use_stress is None:
            use_stress = any(atoms.pbc)

        my_irun = self.irun_rank0 if self.comm.rank == 0 else self.irun_rankN
        return my_irun(atoms, use_stress)

    def irun_rankN(self, atoms, use_stress=True):
        """Follower loop for ranks other than 0.

        Each iteration receives a stop flag from rank 0; unless it is
        set, the rank takes part in the calculation and yields."""
        stop_criterion = np.zeros(1, bool)
        while True:
            self.comm.broadcast(stop_criterion, 0)
            if stop_criterion[0]:
                return

            self.calculate(atoms, use_stress)
            yield

    def irun_rank0(self, atoms, use_stress=True):
        """Message loop on rank 0; yields after each calculation.

        The socket is closed when the generator finishes for any reason.
        Raises KeyError on an unknown message."""
        # For every step we either calculate or quit.  We need to
        # tell other MPI processes (if this is MPI-parallel) whether they
        # should calculate or quit.
        try:
            while True:
                try:
                    msg = self.protocol.recvmsg()
                except SocketClosed:
                    # Server closed the connection, but we want to
                    # exit gracefully anyway
                    msg = 'EXIT'

                if msg == 'EXIT':
                    # Send stop signal to clients:
                    self.comm.broadcast(np.ones(1, bool), 0)
                    # (When otherwise exiting, things crashed and we should
                    # let MPI_ABORT take care of the mess instead of trying
                    # to synchronize the exit)
                    return
                elif msg == 'STATUS':
                    self.protocol.sendmsg(self.state)
                elif msg == 'POSDATA':
                    assert self.state == 'READY'
                    cell, icell, positions = self.protocol.recvposdata()
                    atoms.cell[:] = cell
                    atoms.positions[:] = positions

                    # User may wish to do something with the atoms object
                    # now.  Should we provide option to yield here?
                    #
                    # (In that case we should MPI-synchronize *before*
                    # whereas now we do it after.)

                    # Send signal for other ranks to proceed with
                    # calculation:
                    self.comm.broadcast(np.zeros(1, bool), 0)
                    energy, forces, virial = self.calculate(atoms, use_stress)

                    self.state = 'HAVEDATA'
                    yield
                elif msg == 'GETFORCE':
                    assert self.state == 'HAVEDATA', self.state
                    self.protocol.sendforce(energy, forces, virial)
                    self.state = 'NEEDINIT'
                elif msg == 'INIT':
                    assert self.state == 'NEEDINIT'
                    bead_index, initbytes = self.protocol.recvinit()
                    self.bead_index = bead_index
                    self.bead_initbytes = initbytes
                    self.state = 'READY'
                else:
                    raise KeyError('Bad message', msg)
        finally:
            self.close()

    def run(self, atoms, use_stress=False):
        """Serve requests until the server says EXIT or disconnects."""
        for _ in self.irun(atoms, use_stress=use_stress):
            pass
class SocketIOCalculator(Calculator, IOContext):
    """ASE calculator that obtains results from a client over a socket."""

    implemented_properties = ['energy', 'free_energy', 'forces', 'stress']
    supported_changes = {'positions', 'cell'}

    def __init__(self, calc=None, port=None,
                 unixsocket=None, timeout=None, log=None, *,
                 launch_client=None):
        """Initialize socket I/O calculator.

        This calculator launches a server which passes atomic
        coordinates and unit cells to an external code via a socket,
        and receives energy, forces, and stress in return.

        ASE integrates this with the Quantum Espresso, FHI-aims and
        Siesta calculators.  This works with any external code that
        supports running as a client over the i-PI protocol.

        Parameters:

        calc: calculator or None

            If calc is not None, a client process will be launched
            using calc.command, and the input file will be generated
            using ``calc.write_input()``.  Otherwise only the server will
            run, and it is up to the user to launch a compliant client
            process.

        port: integer

            port number for socket.  Should normally be between 1025
            and 65535.  Typical ports are 31415 (default) or 3141.

        unixsocket: str or None

            if not None, ignore host and port, creating instead a
            unix socket using this name prefixed with ``/tmp/ipi_``.
            The socket is deleted when the calculator is closed.

        timeout: float >= 0 or None

            timeout for connection, by default infinite.  See
            documentation of Python sockets.  For longer jobs it is
            recommended to set a timeout in case of undetected
            client-side failure.

        log: file object or None (default)

            logfile for communication over socket.  For debugging or
            the curious.

        launch_client: callable or None (keyword-only)

            alternative to calc: called on the first calculation as
            ``launch_client(atoms, properties, port=..., unixsocket=...)``;
            its return value is stored as the server's subprocess.
            Cannot be combined with calc (ValueError).

        In order to correctly close the sockets, it is
        recommended to use this class within a with-block:

        >>> with SocketIOCalculator(...) as calc:
        ...    atoms.calc = calc
        ...    atoms.get_forces()
        ...    atoms.rattle()
        ...    atoms.get_forces()

        It is also possible to call calc.close() after
        use.  This is best done in a finally-block."""

        Calculator.__init__(self)

        if calc is not None and launch_client is not None:
            raise ValueError('Cannot pass both calc and launch_client')
        if calc is not None:
            launch_client = FileIOSocketClientLauncher(calc)

        self.launch_client = launch_client
        self.timeout = timeout
        self.server = None

        self.log = self.openfile(log)

        # Kept only to be handed to the server; either may be None here.
        self._port = port
        self._unixsocket = unixsocket

        # With a launcher we postpone the server until calculate(), since
        # we must also start the external process before blocking.
        # Without one, the server is created right away.
        if self.launch_client is None:
            self.server = self.launch_server()

    def todict(self):
        """Return a minimal dictionary description of this calculator."""
        return {'type': 'calculator',
                'name': 'socket-driver'}

    def launch_server(self):
        """Create a SocketServer and register it for closing."""
        server = SocketServer(port=self._port,
                              unixsocket=self._unixsocket,
                              timeout=self.timeout,
                              log=self.log)
        return self.closelater(server)

    def calculate(self, atoms=None, properties=['energy'],
                  system_changes=all_changes):
        """Send *atoms* to the client and store the returned results.

        Raises PropertyNotImplementedError if, after the first call,
        anything other than positions or cell has changed."""
        unsupported = [change for change in system_changes
                       if change not in self.supported_changes]

        # First time calculate() is called, system_changes will be
        # all_changes.  After that, only positions and cell may change.
        if self.atoms is not None and any(unsupported):
            if len(unsupported) > 1:
                what = unsupported
            else:
                what = unsupported[0]
            raise PropertyNotImplementedError(
                'Cannot change {} through IPI protocol. '
                'Please create new socket calculator.'
                .format(what))

        self.atoms = atoms.copy()

        if self.server is None:
            # Deferred start: bring up the server, then the client process.
            self.server = self.launch_server()
            self.server.proc = self.launch_client(
                atoms, properties,
                port=self._port,
                unixsocket=self._unixsocket)  # XXX nasty hack

        results = self.server.calculate(atoms)
        results['free_energy'] = results['energy']
        virial = results.pop('virial')
        has_full_cell = self.atoms.cell.rank == 3
        if has_full_cell and any(self.atoms.pbc):
            vol = atoms.get_volume()
            results['stress'] = -full_3x3_to_voigt_6_stress(virial) / vol
        self.results.update(results)

    def close(self):
        """Forget the server and close everything registered for closing."""
        self.server = None
        super().close()
class PySocketIOClient:
    """Client launcher that runs a Python calculator in a subprocess.

    The parent pickles the socket settings, a copy of the atoms and the
    calculator factory to the child's stdin; the child (``main``)
    unpickles them and serves requests as a SocketClient."""

    def __init__(self, calculator_factory):
        self._calculator_factory = calculator_factory

    def __call__(self, atoms, properties=None, port=None, unixsocket=None):
        """Start the child process and return its Popen object."""
        import pickle
        import sys

        # Everything is pickled up front, so once the write succeeds we
        # do not need to interact with the process any further.
        socketinfo = {'unixsocket': unixsocket, 'port': port}
        payload = pickle.dumps(
            [socketinfo, atoms.copy(), self._calculator_factory])

        argv = [sys.executable, '-m', 'ase.calculators.socketio']
        child = Popen(argv, stdin=PIPE)

        child.stdin.write(payload)
        child.stdin.close()
        return child

    @staticmethod
    def main():
        """Child-process entry point: read the pickle from stdin and run."""
        import pickle
        import sys

        # The pickle comes from the parent process via stdin.
        socketinfo, atoms, get_calculator = pickle.load(sys.stdin.buffer)
        atoms.calc = get_calculator()
        unixsocket = socketinfo.get('unixsocket')
        port = socketinfo.get('port')
        client = SocketClient(host='localhost',
                              unixsocket=unixsocket,
                              port=port)
        # XXX In principle we could avoid calculating stress until
        # someone requests the stress, could we not?
        # Which would make use_stress boolean unnecessary.
        client.run(atoms, use_stress=True)
# Entry point of the child process that PySocketIOClient.__call__ starts
# with ``python -m ase.calculators.socketio``.
if __name__ == '__main__':
    PySocketIOClient.main()