Hide keyboard shortcuts

Hot-keys on this page

r m x p   toggle line displays

j k   next/prev highlighted chunk

0   (zero) top of page

1   (one) first highlighted chunk

1""" 

2Provides LoggingCalculator class to wrap a Calculator and record 

3number of energy and force calls 

4""" 

5 

6import json 

7import logging 

8import time 

9from typing import Dict, Any 

10 

11import numpy as np 

12 

13from ase.calculators.calculator import Calculator, all_properties 

14 

# Module-level logger; LoggingCalculator reports every energy and force call
# through it at INFO level.
logger = logging.getLogger(__name__)


class LoggingCalculator(Calculator):
    """Calculator wrapper to record and plot the history of energy and force
    evaluations.

    Every :meth:`calculate` call is forwarded to the wrapped calculator.
    The history is kept in four dictionaries keyed by ``self.label``:

    * ``energy_evals`` -- running count of calls that requested ``'energy'``
      or ``'energies'``;
    * ``fmax`` -- one entry per force call: the largest absolute force
      component after the constraints of ``atoms`` were applied;
    * ``walltime`` -- one ``time.time()`` stamp per force call;
    * ``energy_count`` -- the value of ``energy_evals`` at each force call.
    """
    implemented_properties = all_properties
    default_parameters: Dict[str, Any] = {}
    name = 'LoggingCalculator'

    # Getter to call on a wrapped object that is not a ``Calculator``.
    property_to_method_name = {
        'energy': 'get_potential_energy',
        'energies': 'get_potential_energies',
        'forces': 'get_forces',
        'stress': 'get_stress',
        'stresses': 'get_stresses'}

    def __init__(self, calculator, jsonfile=None, dumpjson=False):
        """Wrap *calculator* and start with an empty (or loaded) history.

        Parameters
        ----------
        calculator:
            Object doing the real work.  A ``Calculator`` instance is
            queried through ``get_property``; anything else must provide
            the getters named in ``property_to_method_name``.
        jsonfile:
            Optional path of a history file to load with :meth:`read_json`.
        dumpjson:
            If true, the history is written to ``'dump.json'`` at the end
            of every :meth:`calculate` call.
        """
        Calculator.__init__(self)
        self.calculator = calculator
        self.fmax = {}
        self.walltime = {}
        self.energy_evals = {}
        self.energy_count = {}
        self.set_label('(none)')
        if jsonfile is not None:
            self.read_json(jsonfile)
        self.dumpjson = dumpjson

    def calculate(self, atoms, properties, system_changes):
        """Evaluate *properties* with the wrapped calculator and log the call.

        The results are stored in ``self.results``.  A request containing
        ``'energy'`` or ``'energies'`` increments the energy counter of the
        current label; a request containing ``'forces'`` appends one entry
        to ``fmax``, ``walltime`` and ``energy_count``.

        Raises
        ------
        KeyError
            If the wrapped object is not a ``Calculator`` and a requested
            property has no entry in ``property_to_method_name``.
        """
        Calculator.calculate(self, atoms, properties, system_changes)

        if isinstance(self.calculator, Calculator):
            results = [self.calculator.get_property(prop, atoms)
                       for prop in properties]
        else:
            results = []
            for prop in properties:
                method_name = self.property_to_method_name[prop]
                method = getattr(self.calculator, method_name)
                results.append(method(atoms))

        if 'energy' in properties or 'energies' in properties:
            self.energy_evals.setdefault(self.label, 0)
            self.energy_evals[self.label] += 1
            # Explicit membership test: a missing item makes ``index``
            # raise ValueError, which an ``except IndexError`` never catches.
            if 'energy' in properties:
                energy = results[properties.index('energy')]
            else:
                energy = sum(results[properties.index('energies')])
            logger.info('energy call count=%d energy=%.3f',
                        self.energy_evals[self.label], energy)
        self.results = dict(zip(properties, results))

        if 'forces' in self.results:
            fmax = self.fmax.setdefault(self.label, [])
            walltime = self.walltime.setdefault(self.label, [])
            energy_count = self.energy_count.setdefault(self.label, [])
            energy_evals = self.energy_evals.setdefault(self.label, 0)
            energy_count.append(energy_evals)

            # Work on a copy so the stored results stay unconstrained.
            forces = self.results['forces'].copy()
            for constraint in atoms.constraints:
                constraint.adjust_forces(atoms, forces)
            # Plain float: json.dump cannot encode e.g. numpy.float32.
            fmax.append(float(abs(forces).max()))
            walltime.append(time.time())
            logger.info('force call fmax=%.3f', fmax[-1])

        if self.dumpjson:
            self.write_json('dump.json')

    def write_json(self, filename):
        """Write the four history dictionaries to *filename* as JSON."""
        with open(filename, 'w') as fd:
            json.dump({'fmax': self.fmax,
                       'walltime': self.walltime,
                       'energy_evals': self.energy_evals,
                       'energy_count': self.energy_count}, fd)

    def read_json(self, filename, append=False, label=None):
        """Load a history written by :meth:`write_json`.

        Parameters
        ----------
        filename:
            Path of the JSON file.
        append:
            If false (default) the current history is discarded first.
            If true the file is merged in; a label present in both is
            replaced by the file's data.
        label:
            If given and the file holds exactly one label, that entry is
            stored under *label* instead of its original name.
        """
        with open(filename, 'r') as fd:
            dct = json.load(fd)

        # Materialise the keys: dict views cannot be indexed.
        labels = list(dct['fmax'])
        if label is not None and len(labels) == 1:
            old_label = labels[0]
            for key in ('fmax', 'walltime', 'energy_evals', 'energy_count'):
                if old_label in dct[key]:
                    # Pop before assigning so that label == old_label does
                    # not delete the entry that was just written.
                    dct[key][label] = dct[key].pop(old_label)
        if not append:
            self.fmax = {}
            self.walltime = {}
            self.energy_evals = {}
            self.energy_count = {}
        self.fmax.update(dct['fmax'])
        self.walltime.update(dct['walltime'])
        self.energy_evals.update(dct['energy_evals'])
        self.energy_count.update(dct['energy_count'])

    def tabulate(self):
        """Print one row per label: force calls, energy calls, elapsed time.

        The energy column is the recorded number of energy evaluations.
        For a label without an ``energy_evals`` entry (e.g. after
        relabelling in :meth:`plot`) the last ``energy_count`` value is
        used.  Elapsed time is last minus first walltime stamp.
        """
        fmt1 = '%-10s %10s %10s %8s'
        title = fmt1 % ('Label', '# Force', '# Energy', 'Walltime/s')
        print(title)
        print('-' * len(title))
        fmt2 = '%-10s %10d %10d %8.2f'
        for label in sorted(self.fmax):
            n_energy = self.energy_evals.get(label)
            if n_energy is None:
                counts = self.energy_count.get(label)
                n_energy = counts[-1] if counts else 0
            stamps = self.walltime[label]
            elapsed = stamps[-1] - stamps[0] if stamps else 0.0
            print(fmt2 % (label, len(self.fmax[label]), n_energy, elapsed))

    @staticmethod
    def _relabelled(history, labels):
        """Return *history* values, ordered by sorted key, under *labels*."""
        return dict(zip(labels, [history[key] for key in sorted(history)]))

    def _fmax_panel(self, plt, position, history, xvalues, markers,
                    xlabel, fmaxlim, kwargs):
        """Draw one semilog panel of fmax for every label in *history*.

        *position* is the ``(nrows, index)`` pair for ``plt.subplot`` and
        *xvalues* maps ``(label, fmax_array)`` to the x coordinates.
        """
        from itertools import cycle

        nsub, nplot = position
        plt.subplot(nsub, 1, nplot)
        # Cycle the markers so labels beyond len(markers) are still drawn.
        for label, marker in zip(sorted(history), cycle(markers)):
            fmax = np.array(self.fmax[label])
            plt.semilogy(xvalues(label, fmax), fmax, marker, label=label,
                         **kwargs)

        plt.xlabel(xlabel)
        plt.ylabel('Maximum force / eV/A')
        plt.ylim(*fmaxlim)
        plt.legend()

    def plot(self, fmaxlim=(1e-2, 1e2), forces=True, energy=True,
             walltime=True,
             markers=None, labels=None, **kwargs):
        """Plot fmax against force calls, energy calls and walltime.

        Parameters
        ----------
        fmaxlim:
            ``(low, high)`` limits of the logarithmic fmax axis.
        forces, energy, walltime:
            Select which of the three stacked panels are drawn.
        markers:
            Matplotlib format strings, one per label in sorted order and
            reused cyclically; defaults to seven colours times two line
            styles.
        labels:
            If given, the stored ``fmax``, ``energy_count`` and
            ``walltime`` histories are renamed (permanently, in sorted
            order of the old labels) before plotting.
        **kwargs:
            Passed on to ``plt.semilogy``.
        """
        import matplotlib.pyplot as plt

        if markers is None:
            markers = [c + s for c in ['r', 'g', 'b', 'c', 'm', 'y', 'k']
                       for s in ['.-', '.--']]
        else:
            # A one-shot iterator would be exhausted by the first panel.
            markers = list(markers)
        # Count flags as booleans so any truthy value means one panel.
        nsub = sum(bool(flag) for flag in (forces, energy, walltime))
        nplot = 0

        if labels is not None:
            self.fmax = self._relabelled(self.fmax, labels)
            self.energy_count = self._relabelled(self.energy_count, labels)
            self.walltime = self._relabelled(self.walltime, labels)

        def call_index(label, fmax):
            return np.arange(len(fmax))

        def energy_calls(label, fmax):
            return np.array(self.energy_count[label])

        def elapsed(label, fmax):
            stamps = np.array(self.walltime[label])
            return stamps - stamps[0]

        if forces:
            nplot += 1
            self._fmax_panel(plt, (nsub, nplot), self.fmax, call_index,
                             markers, 'Number of force evaluations',
                             fmaxlim, kwargs)

        if energy:
            nplot += 1
            self._fmax_panel(plt, (nsub, nplot), self.energy_count,
                             energy_calls, markers,
                             'Number of energy evaluations',
                             fmaxlim, kwargs)

        if walltime:
            nplot += 1
            self._fmax_panel(plt, (nsub, nplot), self.walltime, elapsed,
                             markers, 'Walltime / s', fmaxlim, kwargs)

        plt.subplots_adjust(hspace=0.33)