Coverage for /builds/debichem-team/python-ase/ase/calculators/genericfileio.py: 86.61%

127 statements  

« prev     ^ index     » next       coverage.py v7.5.3, created at 2025-03-06 04:00 +0000

1import shlex 

2from abc import ABC, abstractmethod 

3from contextlib import ExitStack 

4from os import PathLike 

5from pathlib import Path 

6from typing import Any, Iterable, List, Mapping, Optional, Set 

7 

8from ase.calculators.abc import GetOutputsMixin 

9from ase.calculators.calculator import ( 

10 BadConfiguration, 

11 BaseCalculator, 

12 _validate_command, 

13) 

14from ase.config import cfg as _cfg 

15 

# Documentation anchor quoted in the "no configuration" error message
# raised by GenericFileIOCalculator.
link_calculator_docs = (
    'https://wiki.fysik.dtu.dk/ase/ase/calculators/calculators.html'
    '#calculator-configuration'
)

20 

21 

class BaseProfile(ABC):
    """Description of how to launch an external program.

    A profile holds a command string, validated by ``_validate_command``
    on construction.  ``get_command`` splits that string with
    :func:`shlex.split` and appends the calculator-specific arguments
    returned by ``get_calculator_command``.  Subclasses must implement
    ``get_calculator_command`` and ``version``.
    """

    # Optional configuration keys that ``from_config`` forwards to the
    # constructor as keyword arguments when they are present in the section.
    configvars: Set[str] = set()

    def __init__(self, command):
        self.command = _validate_command(command)

    @property
    def _split_command(self):
        # Tokenize the command using shell-like quoting rules.
        return shlex.split(self.command)

    def get_command(self, inputfile, calc_command=None) -> List[str]:
        """
        Get the command to run. This should be a list of strings.

        Parameters
        ----------
        inputfile : str
            Name of the input file.  It is forwarded to
            ``get_calculator_command`` when ``calc_command`` is None.
        calc_command : list of str, optional
            Calculator-specific arguments to use instead of
            ``get_calculator_command(inputfile)`` (used for sockets).

        Returns
        -------
        list of str
            The split profile command followed by the calculator command.
        """
        if calc_command is None:
            calc_command = self.get_calculator_command(inputfile)
        return [*self._split_command, *calc_command]

    @abstractmethod
    def get_calculator_command(self, inputfile):
        """
        The calculator specific command as a list of strings.

        Parameters
        ----------
        inputfile : str

        Returns
        -------
        list of str
            The command to run.
        """

    def run(
        self, directory: Path, inputfile: Optional[str],
        outputfile: str, errorfile: Optional[str] = None,
        append: bool = False
    ) -> None:
        """
        Run the command in the given directory.

        Parameters
        ----------
        directory : pathlib.Path
            The directory to run the command in.
        inputfile : Optional[str]
            The name of the input file.
        outputfile : str
            Name of the file (relative to ``directory``) that receives
            standard output.
        errorfile : Optional[str]
            Name of the file (relative to ``directory``) that receives
            standard error.  If None, standard error is not redirected.
        append : bool
            If True, append to the output/error files instead of
            truncating them.

        Raises
        ------
        subprocess.CalledProcessError
            If the command exits with a non-zero status.
        """

        import os
        from subprocess import check_call

        argv_command = self.get_command(inputfile)
        mode = 'wb' if not append else 'ab'

        # ExitStack closes however many redirect files were actually opened.
        with ExitStack() as stack:
            output_path = directory / outputfile
            fd_out = stack.enter_context(open(output_path, mode))
            if errorfile is not None:
                error_path = directory / errorfile
                fd_err = stack.enter_context(open(error_path, mode))
            else:
                fd_err = None
            check_call(
                argv_command,
                cwd=directory,
                stdout=fd_out,
                stderr=fd_err,
                env=os.environ,
            )

    @abstractmethod
    def version(self):
        """Get the version of the code.

        Returns
        -------
        str
            The version of the code.
        """

    @classmethod
    def from_config(cls, cfg, section_name):
        """Create a profile from a configuration file.

        Parameters
        ----------
        cfg : ase.config.Config
            The configuration object.
        section_name : str
            The name of the section in the configuration file. E.g. the name
            of the template that this profile is for.

        Returns
        -------
        BaseProfile
            The profile object.

        Raises
        ------
        BadConfiguration
            If the constructor rejects the configured arguments with a
            TypeError.  The original TypeError is chained as the cause.
            Lookup errors from ``cfg.parser`` (a missing section or a
            missing ``command`` key) propagate unchanged.
        """
        section = cfg.parser[section_name]
        command = section['command']

        # Only keys declared in ``configvars`` are forwarded.  Other keys in
        # the section are ignored.
        kwargs = {
            varname: section[varname]
            for varname in cls.configvars if varname in section
        }

        try:
            return cls(command=command, **kwargs)
        except TypeError as err:
            # Chain explicitly so the offending TypeError stays visible.
            raise BadConfiguration(*err.args) from err

148 

149 

def read_stdout(args, createfile=None, encoding='utf-8'):
    """Run command in tempdir and return standard output.

    Helper function for getting version numbers of DFT codes.
    Most DFT codes don't implement a --version flag, so in order to
    determine the code version, we just run the code until it prints
    a version number.

    Parameters
    ----------
    args : list of str
        Command passed directly to ``subprocess.Popen``.
    createfile : str, optional
        Name of an empty file to create in the temporary directory before
        the command is run.
    encoding : str, optional
        Encoding used to decode the captured standard output.
        Defaults to ``'utf-8'``.

    Returns
    -------
    str
        Everything the process wrote to standard output.  Standard error
        is captured and discarded, and the exit status is ignored.
    """
    import tempfile
    from subprocess import PIPE, Popen

    with tempfile.TemporaryDirectory() as directory:
        if createfile is not None:
            path = Path(directory) / createfile
            path.touch()
        # stdin is a pipe that communicate() closes right away, so a code
        # that reads from stdin sees EOF instead of blocking.  Using Popen
        # as a context manager guarantees the pipes are closed.
        with Popen(
            args,
            stdout=PIPE,
            stderr=PIPE,
            stdin=PIPE,
            cwd=directory,
            encoding=encoding,
        ) as proc:
            stdout, _ = proc.communicate()
    # Exit code will be != 0 because there isn't an input file
    return stdout

175 

176 

class CalculatorTemplate(ABC):
    """Code-specific recipe consumed by ``GenericFileIOCalculator``.

    Concrete templates write the input, execute the code, read the
    results back and build a profile from configuration.
    """

    def __init__(self, name: str, implemented_properties: Iterable[str]):
        self.name = name
        # Store an immutable snapshot of the supported property names.
        self.implemented_properties = frozenset(implemented_properties)

    @abstractmethod
    def write_input(self, profile, directory, atoms, parameters, properties):
        """Write the input for a calculation, given the target ``directory``."""

    @abstractmethod
    def execute(self, directory, profile):
        """Run the code in ``directory`` using ``profile``."""

    @abstractmethod
    def read_results(self, directory: PathLike) -> Mapping[str, Any]:
        """Return the results read back from ``directory``."""

    @abstractmethod
    def load_profile(self, cfg):
        """Return a profile built from the configuration object ``cfg``."""

    def socketio_calculator(
        self,
        profile,
        parameters,
        directory,
        # Further SocketIOCalculator options (timeouts etc.) could be
        # exposed as additional keyword arguments here.
        unixsocket=None,
        port=None,
    ):
        """Return a ``SocketIOCalculator`` whose client runs this template.

        Exactly one of ``unixsocket`` and ``port`` must be truthy.  The
        template must provide ``socketio_argv()`` and
        ``socketio_parameters()``.  The launched client's standard output
        is written to ``directory / self.outputname``.

        Raises
        ------
        TypeError
            If both socket kinds or neither are given, or if the socket I/O
            hooks are missing.
        """
        import os
        from subprocess import Popen

        from ase.calculators.socketio import SocketIOCalculator

        if port and unixsocket:
            raise TypeError(
                'For the socketio_calculator only a UNIX '
                '(unixsocket) or INET (port) socket can be used'
                ' not both.'
            )

        if not (port or unixsocket):
            raise TypeError(
                'For the socketio_calculator either a '
                'UNIX (unixsocket) or INET (port) socket '
                'must be used'
            )

        hooks = ('socketio_argv', 'socketio_parameters')
        if not all(hasattr(self, hook) for hook in hooks):
            raise TypeError(
                f'Template {self} does not implement mandatory '
                'socketio_argv() and socketio_parameters()'
            )

        # XXX need socketio ABC or something
        client_argv = profile.get_command(
            inputfile=None,
            calc_command=self.socketio_argv(profile, unixsocket, port),
        )
        # User-supplied parameters take precedence over socket defaults.
        socket_parameters = {
            **self.socketio_parameters(unixsocket, port),
            **parameters,
        }

        def launch(atoms, properties, port, unixsocket):
            # The socket details are already part of client_argv, so the
            # port/unixsocket values handed over by SocketIOCalculator are
            # not needed here.
            directory.mkdir(exist_ok=True, parents=True)

            self.write_input(
                profile=profile,
                directory=directory,
                atoms=atoms,
                parameters=socket_parameters,
                properties=properties,
            )

            with open(directory / self.outputname, 'w') as logfile:
                return Popen(
                    client_argv,
                    stdout=logfile,
                    cwd=directory,
                    env=os.environ,
                )

        return SocketIOCalculator(
            launch_client=launch,
            unixsocket=unixsocket,
            port=port,
        )

266 

267 

class GenericFileIOCalculator(BaseCalculator, GetOutputsMixin):
    """Calculator that delegates file I/O and execution to a template.

    The ``template`` writes input, executes the code and reads results.
    The ``profile`` is handed to the template's ``write_input`` and
    ``execute``.  When no profile is given, one is loaded from ``cfg``.
    """

    # Configuration consulted when the caller does not pass a profile.
    cfg = _cfg

    def __init__(self, *, template, profile, directory, parameters=None):
        self.template = template

        if profile is None:
            name = template.name
            if name not in self.cfg.parser:
                raise BadConfiguration(
                    f"No configuration of '{name}'. "
                    f"See '{link_calculator_docs}'"
                )
            try:
                profile = template.load_profile(self.cfg)
            except Exception as err:
                raise BadConfiguration(
                    f'Failed to load section [{name}] '
                    f'from configuration: {self.cfg.as_dict()}'
                ) from err

        self.profile = profile

        # Maybe we should allow directory to be a factory, so
        # calculators e.g. produce new directories on demand.
        self.directory = Path(directory)
        super().__init__(parameters)

    def set(self, *args, **kwargs):
        """Refuse to change parameters; build a new calculator instead."""
        message = (
            'No setting parameters for now, please. '
            'Just create new calculators.'
        )
        raise RuntimeError(message)

    def __repr__(self):
        cls_name = type(self).__name__
        return f'{cls_name}({self.template.name})'

    @property
    def implemented_properties(self):
        """Properties supported by the underlying template."""
        return self.template.implemented_properties

    @property
    def name(self):
        """Name of the underlying template."""
        return self.template.name

    def write_inputfiles(self, atoms, properties):
        """Create the working directory and write input via the template."""
        # SocketIOCalculators like to write inputfiles
        # without calculating.
        self.directory.mkdir(parents=True, exist_ok=True)
        template = self.template
        template.write_input(
            profile=self.profile,
            atoms=atoms,
            parameters=self.parameters,
            properties=properties,
            directory=self.directory,
        )

    def calculate(self, atoms, properties, system_changes):
        """Write input, run the code and store the parsed results."""
        self.write_inputfiles(atoms, properties)
        template = self.template
        template.execute(self.directory, self.profile)
        self.results = template.read_results(self.directory)
        # XXX Return something useful?

    def _outputmixin_get_results(self):
        # Results of the latest calculation; presumably the hook that
        # GetOutputsMixin reads from.
        return self.results

    def socketio(self, **socketkwargs):
        """Return the template's socket I/O calculator for this setup.

        ``socketkwargs`` (e.g. ``unixsocket`` or ``port``) are forwarded to
        ``template.socketio_calculator``.
        """
        return self.template.socketio_calculator(
            directory=self.directory,
            parameters=self.parameters,
            profile=self.profile,
            **socketkwargs,
        )