Coverage for /builds/debichem-team/python-ase/ase/calculators/socketio.py: 90.91%
396 statements
« prev ^ index » next coverage.py v7.5.3, created at 2025-03-06 04:00 +0000
1import os
2import socket
3from contextlib import contextmanager
4from subprocess import PIPE, Popen
6import numpy as np
8import ase.units as units
9from ase.calculators.calculator import (
10 Calculator,
11 OldShellProfile,
12 PropertyNotImplementedError,
13 StandardProfile,
14 all_changes,
15)
16from ase.calculators.genericfileio import GenericFileIOCalculator
17from ase.parallel import world
18from ase.stress import full_3x3_to_voigt_6_stress
19from ase.utils import IOContext
def actualunixsocketname(name):
    """Return the filesystem path used for the UNIX socket called *name*.

    Server and client both derive the path from the same short name by
    prefixing it with ``/tmp/ipi_``.
    """
    template = '/tmp/ipi_{}'
    return template.format(name)
class SocketClosed(OSError):
    """Raised when a read finds that the peer has closed the connection.

    A blocking ``recv`` on an open socket returns at least one byte, so
    an empty read is treated as the connection having been closed.
    """
class IPIProtocol:
    """Communication using the i-PI socket protocol.

    Every message starts with a 12-byte ASCII header, right-padded with
    spaces, optionally followed by raw binary payload.  Lengths and
    positions are converted between ASE units and Bohr/Hartree at the
    socket boundary (see the ``units.Bohr`` / ``units.Ha`` factors below).
    """

    def __init__(self, socket, txt=None):
        """Wrap a connected *socket*.

        txt: file object or None
            If given, debug messages prefixed with ``Driver:`` are
            written (and flushed) to it.
        """
        self.socket = socket

        if txt is None:
            def log(*args):
                pass
        else:
            def log(*args):
                print('Driver:', *args, file=txt)
                txt.flush()
        self.log = log

    def sendmsg(self, msg):
        """Send *msg* as a 12-byte, space-padded ASCII header."""
        self.log(' sendmsg', repr(msg))
        msg = msg.encode('ascii').ljust(12)
        self.socket.sendall(msg)

    def _recvall(self, nbytes):
        """Repeatedly read chunks until we have nbytes.

        Normally we get all bytes in one read, but that is not guaranteed.
        Raises SocketClosed if the peer closes the connection first."""
        remaining = nbytes
        chunks = []
        while remaining > 0:
            chunk = self.socket.recv(remaining)
            if len(chunk) == 0:
                # (If socket is still open, recv returns at least one byte)
                raise SocketClosed
            chunks.append(chunk)
            remaining -= len(chunk)
        msg = b''.join(chunks)
        assert len(msg) == nbytes and remaining == 0
        return msg

    def recvmsg(self):
        """Receive a 12-byte header and return it as a stripped string."""
        # _recvall() raises SocketClosed instead of returning short data,
        # so an empty result cannot occur here.
        msg = self._recvall(12)
        assert len(msg) == 12, msg
        msg = msg.rstrip().decode('ascii')
        self.log(' recvmsg', repr(msg))
        return msg

    def send(self, a, dtype):
        """Send *a*, converted to *dtype*, as raw bytes."""
        buf = np.asarray(a, dtype).tobytes()
        self.log(f' send {len(buf)} bytes of {dtype}')
        self.socket.sendall(buf)

    def recv(self, shape, dtype):
        """Receive an array of the given *shape* and *dtype*.

        Raises AssertionError if any received value is not finite.
        """
        a = np.empty(shape, dtype)
        # Plain int: socket.recv() is handed this count directly.
        nbytes = np.dtype(dtype).itemsize * int(np.prod(shape))
        buf = self._recvall(nbytes)
        assert len(buf) == nbytes, (len(buf), nbytes)
        self.log(f' recv {len(buf)} bytes of {dtype}')
        a.flat[:] = np.frombuffer(buf, dtype=dtype)
        assert np.isfinite(a).all()
        return a

    def sendposdata(self, cell, icell, positions):
        """Send a POSDATA message: cell, inverse cell and positions.

        Matrices are sent transposed (recvposdata transposes back) and
        lengths are converted to Bohr.
        """
        assert cell.size == 9
        assert icell.size == 9
        assert positions.size % 3 == 0

        self.log(' sendposdata')
        self.sendmsg('POSDATA')
        self.send(cell.T / units.Bohr, np.float64)
        self.send(icell.T * units.Bohr, np.float64)
        self.send(len(positions), np.int32)
        self.send(positions / units.Bohr, np.float64)

    def recvposdata(self):
        """Receive the payload of a POSDATA message.

        Returns (cell, icell, positions) converted back from Bohr.
        """
        cell = self.recv((3, 3), np.float64).T.copy()
        icell = self.recv((3, 3), np.float64).T.copy()
        natoms = self.recv(1, np.int32)[0]
        assert natoms >= 0
        positions = self.recv((int(natoms), 3), np.float64)
        return cell * units.Bohr, icell / units.Bohr, positions * units.Bohr

    def sendrecv_force(self):
        """Send GETFORCE and receive the FORCEREADY reply.

        Returns (energy, forces, virial, morebytes) with energy, forces
        and virial converted from Hartree/Bohr.
        """
        self.log(' sendrecv_force')
        self.sendmsg('GETFORCE')
        msg = self.recvmsg()
        assert msg == 'FORCEREADY', msg
        e = self.recv(1, np.float64)[0]
        natoms = self.recv(1, np.int32)[0]
        assert natoms >= 0
        forces = self.recv((int(natoms), 3), np.float64)
        virial = self.recv((3, 3), np.float64).T.copy()
        nmorebytes = self.recv(1, np.int32)[0]
        morebytes = self.recv(int(nmorebytes), np.byte)
        return (e * units.Ha, (units.Ha / units.Bohr) * forces,
                units.Ha * virial, morebytes)

    def sendforce(self, energy, forces, virial, morebytes=None):
        """Send a FORCEREADY message with energy, forces and virial.

        morebytes: array-like of bytes or None
            Extra payload; defaults to a single zero byte.
        """
        if morebytes is None:
            # We prefer to always send at least one byte due to trouble
            # with empty messages.  Reading a closed socket yields 0 bytes
            # and thus can be confused with a 0-length bytestring.
            morebytes = np.zeros(1, dtype=np.byte)
        assert np.array([energy]).size == 1
        assert forces.shape[1] == 3
        assert virial.shape == (3, 3)

        self.log(' sendforce')
        self.sendmsg('FORCEREADY')  # mind the units
        self.send(np.array([energy / units.Ha]), np.float64)
        natoms = len(forces)
        self.send(np.array([natoms]), np.int32)
        self.send(units.Bohr / units.Ha * forces, np.float64)
        self.send(1.0 / units.Ha * virial.T, np.float64)
        self.send(np.array([len(morebytes)]), np.int32)
        self.send(morebytes, np.byte)

    def status(self):
        """Send STATUS and return the peer's reply string."""
        self.log(' status')
        self.sendmsg('STATUS')
        return self.recvmsg()

    def end(self):
        """Send EXIT to the peer."""
        self.log(' end')
        self.sendmsg('EXIT')

    def recvinit(self):
        """Receive the payload of an INIT message.

        Returns (bead_index, initbytes) as a Python int and bytes.
        """
        self.log(' recvinit')
        bead_index = int(self.recv(1, np.int32)[0])
        nbytes = int(self.recv(1, np.int32)[0])
        initbytes = self.recv(nbytes, np.byte).tobytes()
        return bead_index, initbytes

    def sendinit(self):
        """Send an INIT message with bead index 0 and one zero byte."""
        # XXX Not sure what this function is supposed to send.
        # It 'works' with QE, but for now we try not to call it.
        self.log(' sendinit')
        self.sendmsg('INIT')
        self.send(0, np.int32)  # 'bead index' always zero for now
        # We send one byte, which is zero, since things may not work
        # with 0 bytes.  Apparently implementations ignore the
        # initialization string anyway.
        self.send(1, np.int32)
        self.send(np.zeros(1), np.byte)  # initialization string

    def calculate(self, positions, cell):
        """Run one server-side cycle and return the results as a dict.

        Sequence: STATUS (answering NEEDINIT with INIT), POSDATA, STATUS,
        then GETFORCE.  The dict has keys energy, forces, virial and
        morebytes.
        """
        self.log('calculate')
        msg = self.status()
        # We don't know how NEEDINIT is supposed to work, but some codes
        # seem to be okay if we skip it and send the positions instead.
        if msg == 'NEEDINIT':
            self.sendinit()
            msg = self.status()
        assert msg == 'READY', msg
        icell = np.linalg.pinv(cell).transpose()
        self.sendposdata(cell, icell, positions)
        msg = self.status()
        assert msg == 'HAVEDATA', msg
        e, forces, virial, morebytes = self.sendrecv_force()
        r = dict(energy=e,
                 forces=forces,
                 virial=virial,
                 morebytes=morebytes)
        return r
@contextmanager
def bind_unixsocket(socketfile):
    """Create a UNIX socket bound to *socketfile* and yield it.

    *socketfile* must start with ``/tmp/ipi_`` (asserted).  On exit the
    socket is closed and *socketfile* is unlinked.  If binding fails, the
    unbound socket is closed and an OSError naming the path is raised,
    chained to the original error.
    """
    assert socketfile.startswith('/tmp/ipi_'), socketfile
    serversocket = socket.socket(socket.AF_UNIX)
    try:
        serversocket.bind(socketfile)
    except OSError as err:
        # Do not leak the descriptor of the socket we failed to bind.
        serversocket.close()
        raise OSError(f'{err}: {socketfile!r}') from err

    try:
        with serversocket:
            yield serversocket
    finally:
        os.unlink(socketfile)
@contextmanager
def bind_inetsocket(port):
    """Create an INET socket bound to *port* on all interfaces and yield it.

    SO_REUSEADDR is enabled before binding.  The socket is closed on
    exit, and also when ``setsockopt`` or ``bind`` fails (e.g. the port
    is already in use), so no descriptor is leaked.
    """
    serversocket = socket.socket(socket.AF_INET)
    # Enter the context before binding so a failing bind still closes
    # the socket.
    with serversocket:
        serversocket.setsockopt(socket.SOL_SOCKET,
                                socket.SO_REUSEADDR, 1)
        serversocket.bind(('', port))
        yield serversocket
class FileIOSocketClientLauncher:
    """Start the external client program of a file-based calculator.

    Calling an instance writes the calculator's input files for the given
    atoms and launches its executable in ``calc.directory``.  The return
    value is the handle of the started process.
    """

    def __init__(self, calc):
        self.calc = calc

    def __call__(self, atoms, properties=None, port=None, unixsocket=None):
        calc = self.calc
        assert calc is not None
        workdir = calc.directory
        profile = getattr(calc, 'profile', None)

        if not isinstance(calc, GenericFileIOCalculator):
            return self._launch_old_style(atoms, properties, profile,
                                          workdir, port, unixsocket)

        # New GenericFileIOCalculator: its template builds the argv.
        template = calc.template
        calc.write_inputfiles(atoms, properties)
        # Only one transport is forwarded; a UNIX socket takes precedence.
        if unixsocket is None:
            transport = {'unixsocket': None, 'port': port}
        else:
            transport = {'unixsocket': unixsocket, 'port': None}
        argv = template.socketio_argv(profile, **transport)
        return Popen(argv, cwd=workdir, env=os.environ)

    def _launch_old_style(self, atoms, properties, profile, workdir,
                          port, unixsocket):
        """Write inputs and start the process of an old FileIOCalculator."""
        calc = self.calc
        calc.write_input(atoms, properties=properties,
                         system_changes=all_changes)

        if isinstance(profile, StandardProfile):
            return profile.execute_nonblocking(calc)

        if profile is None:
            command = calc.command.replace('PREFIX', calc.prefix)
            command = command.format(port=port, unixsocket=unixsocket)
        elif isinstance(profile, OldShellProfile):
            command = profile.command.replace("PREFIX", calc.prefix)
        else:
            raise TypeError(
                f"Profile type {type(profile)} not supported for socketio")

        return Popen(command, shell=True, cwd=workdir)
class SocketServer(IOContext):
    """Server side of an i-PI socket connection.

    Listens on a UNIX socket or an INET port.  Once a client connects,
    calculate() drives it through IPIProtocol to obtain energy, forces
    and virial.
    """
    default_port = 31415

    def __init__(self,
                 port=None, unixsocket=None, timeout=None,
                 log=None):
        """Create server and listen for connections.

        The client is not accepted here.  Once calculate() is called,
        the server blocks until a client connects.

        Parameters:

        port: integer or None
            Port on which to listen for INET connections.  Defaults
            to 31415 if neither this nor unixsocket is specified.
        unixsocket: string or None
            Name of the UNIX socket; the socket file is
            ``/tmp/ipi_<unixsocket>``.
        timeout: float or None
            Timeout in seconds, or unlimited by default.
            This parameter is passed to the Python socket object; see
            documentation thereof.  When no client subprocess is attached
            (``self.proc`` is None), waiting longer than this for a
            client raises ``socket.timeout``.
        log: file object or None
            Useful debug messages are written to this.

        Raises ValueError if both port and unixsocket are given."""
        if unixsocket is None and port is None:
            port = self.default_port
        elif unixsocket is not None and port is not None:
            raise ValueError('Specify only one of unixsocket and port')

        self.port = port
        self.unixsocket = unixsocket
        self.timeout = timeout
        self._closed = False

        if unixsocket is not None:
            actualsocket = actualunixsocketname(unixsocket)
            conn_name = f'UNIX-socket {actualsocket}'
            socket_context = bind_unixsocket(actualsocket)
        else:
            conn_name = f'INET port {port}'
            socket_context = bind_inetsocket(port)

        self.serversocket = self.closelater(socket_context)

        if log:
            print(f'Accepting clients on {conn_name}', file=log)

        self.serversocket.settimeout(timeout)

        self.serversocket.listen(1)

        self.log = log

        # Client process handle; assigned externally when we launch it.
        self.proc = None

        self.protocol = None
        self.clientsocket = None
        self.address = None

    def _accept(self):
        """Wait for a client and establish the connection.

        Raises OSError if a launched subprocess exits before connecting.
        Raises socket.timeout if no subprocess is attached and the
        configured timeout expires.
        """
        log = self.log
        if log:
            print('Awaiting client', file=self.log)

        # If we launched the subprocess, the process may crash.
        # We want to detect this, using loop with timeouts, and
        # raise an error rather than blocking forever.
        if self.proc is not None:
            self.serversocket.settimeout(1.0)

        try:
            while True:
                try:
                    self.clientsocket, self.address = (
                        self.serversocket.accept())
                except socket.timeout:
                    if self.proc is None:
                        # Nothing to poll: the user's timeout expired.
                        raise
                    status = self.proc.poll()
                    if status is not None:
                        raise OSError('Subprocess terminated unexpectedly'
                                      ' with status {}'.format(status))
                else:
                    break
        finally:
            # Undo the temporary 1 s polling timeout, even on error.
            self.serversocket.settimeout(self.timeout)

        self.closelater(self.clientsocket)
        self.clientsocket.settimeout(self.timeout)

        if log:
            # For unix sockets the peer address is empty (b'' or '').
            source = self.address or 'client'
            print(f'Accepted connection from {source}', file=log)

        self.protocol = IPIProtocol(self.clientsocket, txt=log)

    def close(self):
        """Close the sockets and wait for a launched subprocess to exit.

        Calling close() again after it has completed does nothing.  A
        non-zero subprocess exit status only produces a warning.
        """
        if self._closed:
            return

        super().close()

        if self.log:
            print('Close socket server', file=self.log)
        self._closed = True

        # NOTE: we do not send EXIT (protocol.end()) to the client here;
        # the proper way to shut down i-PI connections is unclear.
        self.protocol = None
        if self.proc is not None:
            exitcode = self.proc.wait()
            if exitcode != 0:
                import warnings

                # Quantum Espresso seems to always exit with status 128,
                # even if successful.
                # Should investigate at some point
                warnings.warn('Subprocess exited with status {}'
                              .format(exitcode))

    def calculate(self, atoms):
        """Send geometry to client and return calculated things as dict.

        This will block until client has established connection, then
        wait for the client to finish the calculation."""
        assert not self._closed

        # If we have not established connection yet, we must block
        # until the client catches up:
        if self.protocol is None:
            self._accept()
        return self.protocol.calculate(atoms.positions, atoms.cell)
class SocketClient:
    def __init__(self, host='localhost', port=None,
                 unixsocket=None, timeout=None, log=None, comm=world):
        """Create client and connect to server.

        Parameters:

        host: string
            Hostname of server. Defaults to localhost
        port: integer or None
            Port to which to connect. By default 31415.
        unixsocket: string or None
            If specified, use corresponding UNIX socket.
            See documentation of unixsocket for SocketIOCalculator.
        timeout: float or None
            See documentation of timeout for SocketIOCalculator.
            It is applied to the socket after the connection is made.
        log: file object or None
            Log events to this file
        comm: communicator or None
            MPI communicator object. Defaults to ase.parallel.world.
            When ASE runs in parallel, only the process with world.rank == 0
            will communicate over the socket. The received information
            will then be broadcast on the communicator. The SocketClient
            must be created on all ranks of world, and will see the same
            Atoms objects.

        Raises OSError if the connection cannot be made; the socket is
        closed before the error propagates."""
        # Only rank0 actually does the socket work.
        # The other ranks only need to follow.
        #
        # Note: We actually refrain from assigning all the
        # socket-related things except on master
        self.comm = comm

        if self.comm.rank == 0:
            if unixsocket is not None:
                sock = socket.socket(socket.AF_UNIX)
                address = actualunixsocketname(unixsocket)
            else:
                if port is None:
                    port = SocketServer.default_port
                sock = socket.socket(socket.AF_INET)
                address = (host, port)
            try:
                sock.connect(address)
            except OSError:
                # Do not leak the socket when the server is unreachable.
                sock.close()
                raise
            sock.settimeout(timeout)
            self.host = host
            self.port = port
            self.unixsocket = unixsocket

            self.protocol = IPIProtocol(sock, txt=log)
            self.log = self.protocol.log
            self.closed = False

            self.bead_index = 0
            self.bead_initbytes = b''
            self.state = 'READY'

    def close(self):
        """Close the connection to the server; later calls do nothing.

        Only rank 0 owns a socket, so on other ranks this is a no-op
        (those ranks never set ``closed`` or ``protocol``).
        """
        if self.comm.rank != 0:
            return
        if not self.closed:
            self.log('Close SocketClient')
            self.closed = True
            self.protocol.socket.close()

    def calculate(self, atoms, use_stress):
        """Compute energy, forces and virial for *atoms* on all ranks.

        Positions and cell are first broadcast from rank 0 so that every
        rank evaluates the same geometry.  Returns (energy, forces,
        virial); the virial is a zero 3x3 matrix unless use_stress is set.
        """
        # We should also broadcast the bead index, once we support doing
        # multiple beads.
        self.comm.broadcast(atoms.positions, 0)
        self.comm.broadcast(np.ascontiguousarray(atoms.cell), 0)

        energy = atoms.get_potential_energy()
        forces = atoms.get_forces()
        if use_stress:
            stress = atoms.get_stress(voigt=False)
            virial = -atoms.get_volume() * stress
        else:
            virial = np.zeros((3, 3))
        return energy, forces, virial

    def irun(self, atoms, use_stress=None):
        """Serve the server's requests, yielding after each calculation.

        If use_stress is None, stress is computed when any direction of
        atoms is periodic.
        """
        if use_stress is None:
            use_stress = any(atoms.pbc)

        my_irun = self.irun_rank0 if self.comm.rank == 0 else self.irun_rankN
        return my_irun(atoms, use_stress)

    def irun_rankN(self, atoms, use_stress=True):
        """Non-master loop: wait for rank 0's stop flag, else calculate."""
        stop_criterion = np.zeros(1, bool)
        while True:
            self.comm.broadcast(stop_criterion, 0)
            if stop_criterion[0]:
                return

            self.calculate(atoms, use_stress)
            yield

    def irun_rank0(self, atoms, use_stress=True):
        """Master loop: answer server messages until EXIT or disconnect.

        The socket is closed when the loop ends, for any reason.
        Raises KeyError on an unknown message.
        """
        # For every step we either calculate or quit.  We need to
        # tell other MPI processes (if this is MPI-parallel) whether they
        # should calculate or quit.
        try:
            while True:
                try:
                    msg = self.protocol.recvmsg()
                except SocketClosed:
                    # Server closed the connection, but we want to
                    # exit gracefully anyway
                    msg = 'EXIT'

                if msg == 'EXIT':
                    # Send stop signal to clients:
                    self.comm.broadcast(np.ones(1, bool), 0)
                    # (When otherwise exiting, things crashed and we should
                    # let MPI_ABORT take care of the mess instead of trying
                    # to synchronize the exit)
                    return
                elif msg == 'STATUS':
                    self.protocol.sendmsg(self.state)
                elif msg == 'POSDATA':
                    assert self.state == 'READY'
                    cell, _icell, positions = self.protocol.recvposdata()
                    atoms.cell[:] = cell
                    atoms.positions[:] = positions

                    # User may wish to do something with the atoms object now.
                    # Should we provide option to yield here?
                    #
                    # (In that case we should MPI-synchronize *before*
                    #  whereas now we do it after.)

                    # Send signal for other ranks to proceed with calculation:
                    self.comm.broadcast(np.zeros(1, bool), 0)
                    energy, forces, virial = self.calculate(atoms, use_stress)

                    self.state = 'HAVEDATA'
                    yield
                elif msg == 'GETFORCE':
                    assert self.state == 'HAVEDATA', self.state
                    self.protocol.sendforce(energy, forces, virial)
                    self.state = 'NEEDINIT'
                elif msg == 'INIT':
                    assert self.state == 'NEEDINIT'
                    bead_index, initbytes = self.protocol.recvinit()
                    self.bead_index = bead_index
                    self.bead_initbytes = initbytes
                    self.state = 'READY'
                else:
                    raise KeyError('Bad message', msg)
        finally:
            self.close()

    def run(self, atoms, use_stress=False):
        """Serve the server until it sends EXIT or disconnects."""
        for _ in self.irun(atoms, use_stress=use_stress):
            pass
class SocketIOCalculator(Calculator, IOContext):
    implemented_properties = ['energy', 'free_energy', 'forces', 'stress']
    supported_changes = {'positions', 'cell'}

    def __init__(self, calc=None, port=None,
                 unixsocket=None, timeout=None, log=None, *,
                 launch_client=None, comm=world):
        """Initialize socket I/O calculator.

        This calculator launches a server which passes atomic
        coordinates and unit cells to an external code via a socket,
        and receives energy, forces, and stress in return.

        ASE integrates this with the Quantum Espresso, FHI-aims and
        Siesta calculators.  This works with any external code that
        supports running as a client over the i-PI protocol.

        Parameters:

        calc: calculator or None

            If calc is not None, a client process will be launched
            using calc.command, and the input file will be generated
            using ``calc.write_input()``.  Otherwise only the server will
            run, and it is up to the user to launch a compliant client
            process.

        port: integer

            port number for socket.  Should normally be between 1025
            and 65535.  Typical ports are 31415 (default) or 3141.

        unixsocket: str or None

            if not None, ignore port, creating instead a
            unix socket using this name prefixed with ``/tmp/ipi_``.
            The socket is deleted when the calculator is closed.

        timeout: float >= 0 or None

            timeout for connection, by default infinite.  See
            documentation of Python sockets.  For longer jobs it is
            recommended to set a timeout in case of undetected
            client-side failure.

        log: file object or None (default)

            logfile for communication over socket.  For debugging or
            the curious.

        launch_client: callable or None (keyword-only)

            Called as ``launch_client(atoms, properties, port=...,
            unixsocket=...)`` on the first calculation; its return value
            is stored as the server's client process.  Cannot be
            combined with calc.

        comm: communicator (keyword-only)

            Passed on when opening the log file.  Defaults to
            ase.parallel.world.

        In order to correctly close the sockets, it is
        recommended to use this class within a with-block:

        >>> from ase.calculators.socketio import SocketIOCalculator

        >>> with SocketIOCalculator(...) as calc:  # doctest:+SKIP
        ...    atoms.calc = calc
        ...    atoms.get_forces()
        ...    atoms.rattle()
        ...    atoms.get_forces()

        It is also possible to call calc.close() after
        use.  This is best done in a finally-block."""

        Calculator.__init__(self)

        if calc is not None and launch_client is not None:
            raise ValueError('Cannot pass both calc and launch_client')
        if calc is not None:
            launch_client = FileIOSocketClientLauncher(calc)
        self.launch_client = launch_client
        self.timeout = timeout
        self.server = None

        self.log = self.openfile(file=log, comm=comm)

        # Kept only to be forwarded to the server (and the client
        # launcher); either may be None here.
        self._port = port
        self._unixsocket = unixsocket

        # With a client launcher, server start-up is deferred to
        # calculate(), which must launch the client before blocking.
        # Without one, the server is started right away.
        if self.launch_client is None:
            self.server = self.launch_server()

    def todict(self):
        """Return a minimal dict description of this calculator."""
        return {'type': 'calculator', 'name': 'socket-driver'}

    def launch_server(self):
        """Create a SocketServer that is closed together with this object."""
        server = SocketServer(port=self._port,
                              unixsocket=self._unixsocket,
                              timeout=self.timeout,
                              log=self.log)
        return self.closelater(server)

    def calculate(self, atoms=None, properties=['energy'],
                  system_changes=all_changes):
        """Evaluate *atoms* through the socket client.

        Only changes of positions and cell are allowed once a system has
        been set; anything else raises PropertyNotImplementedError.
        Stress is stored only for cells of rank 3 with any periodic
        direction.
        """
        unsupported = [change for change in system_changes
                       if change not in self.supported_changes]

        # On the first call system_changes is all_changes; afterwards
        # only positions and cell may change.
        if self.atoms is not None and any(unsupported):
            what = unsupported[0] if len(unsupported) == 1 else unsupported
            raise PropertyNotImplementedError(
                'Cannot change {} through IPI protocol. '
                'Please create new socket calculator.'
                .format(what))

        self.atoms = atoms.copy()

        if self.server is None:
            self.server = self.launch_server()
            self.server.proc = self.launch_client(
                atoms, properties,
                port=self._port,
                unixsocket=self._unixsocket)

        results = self.server.calculate(atoms)
        results['free_energy'] = results['energy']
        virial = results.pop('virial')
        has_stress = self.atoms.cell.rank == 3 and any(self.atoms.pbc)
        if has_stress:
            volume = atoms.get_volume()
            voigt_virial = full_3x3_to_voigt_6_stress(virial)
            results['stress'] = -voigt_virial / volume
        self.results.update(results)

    def close(self):
        """Drop the server reference and close everything registered."""
        self.server = None
        super().close()
class PySocketIOClient:
    """Run an ASE calculator as a socket client in a Python subprocess.

    An instance can serve as ``launch_client``.  Calling it starts
    ``python -m ase.calculators.socketio`` and writes the pickled socket
    settings, a copy of the atoms and the calculator factory to the
    child's stdin.
    """

    def __init__(self, calculator_factory):
        self._calculator_factory = calculator_factory

    def __call__(self, atoms, properties=None, port=None, unixsocket=None):
        import pickle
        import sys

        # Serialize up front so that, once started, the only interaction
        # with the child process is writing this payload.
        socket_settings = {'unixsocket': unixsocket, 'port': port}
        payload = pickle.dumps(
            [socket_settings, atoms.copy(), self._calculator_factory])

        child = Popen([sys.executable, '-m', 'ase.calculators.socketio'],
                      stdin=PIPE)
        child.stdin.write(payload)
        child.stdin.close()
        return child

    @staticmethod
    def main():
        """Entry point of the child process started by ``__call__``."""
        import pickle
        import sys

        # NOTE: this unpickles stdin, which is only acceptable because
        # the parent process (``__call__``) is the one writing it.
        socketinfo, atoms, get_calculator = pickle.load(sys.stdin.buffer)
        atoms.calc = get_calculator()
        client = SocketClient(host='localhost',
                              unixsocket=socketinfo.get('unixsocket'),
                              port=socketinfo.get('port'))
        # XXX In principle we could avoid calculating stress until
        # someone requests the stress, could we not?
        # Which would make use_stress boolean unnecessary.
        client.run(atoms, use_stress=True)
# Executed as ``python -m ase.calculators.socketio``: this is the child
# process that PySocketIOClient starts and feeds through stdin.
if __name__ == '__main__':
    PySocketIOClient.main()