Hot-keys on this page
r m x p toggle line displays
j k next/prev highlighted chunk
0 (zero) top of page
1 (one) first highlighted chunk
"""
Provides LoggingCalculator class to wrap a Calculator and record
number of energy and force calls
"""
6import json
7import logging
8import time
9from typing import Dict, Any
11import numpy as np
13from ase.calculators.calculator import Calculator, all_properties
# Module-level logger; LoggingCalculator.calculate reports each energy and
# force call through it at INFO level.
logger = logging.getLogger(name=__name__)
class LoggingCalculator(Calculator):
    """Calculator wrapper to record and plot history of energy and force
    evaluations.

    Each :meth:`calculate` request is forwarded to the wrapped calculator.
    Bookkeeping is kept per ``self.label`` (initially ``'(none)'``):

    ``energy_evals[label]``
        Number of calls that requested ``'energy'`` or ``'energies'``.
    ``fmax[label]``
        For every call whose results include forces, the largest absolute
        force component after ``atoms.constraints`` have adjusted a copy
        of the forces.
    ``walltime[label]``
        ``time.time()`` at each of those force calls.
    ``energy_count[label]``
        Value of ``energy_evals[label]`` at each of those force calls.

    Parameters
    ----------
    calculator
        Object to wrap. If it is an ASE ``Calculator`` its
        ``get_property`` is used; otherwise, for each requested property,
        the method named in ``property_to_method_name`` is called with the
        atoms.
    jsonfile
        Optional path of a file written by :meth:`write_json` whose
        history is loaded on construction.
    dumpjson
        If true, the whole history is written to ``dump.json`` at the end
        of every :meth:`calculate` call.
    """
    implemented_properties = all_properties
    default_parameters: Dict[str, Any] = {}
    name = 'LoggingCalculator'

    # Fallback used when the wrapped object is not an ASE Calculator:
    # maps a property name to the getter method called on the wrapped object.
    property_to_method_name = {
        'energy': 'get_potential_energy',
        'energies': 'get_potential_energies',
        'forces': 'get_forces',
        'stress': 'get_stress',
        'stresses': 'get_stresses'}

    def __init__(self, calculator, jsonfile=None, dumpjson=False):
        Calculator.__init__(self)
        self.calculator = calculator
        self.fmax = {}
        self.walltime = {}
        self.energy_evals = {}
        self.energy_count = {}
        self.set_label('(none)')
        if jsonfile is not None:
            self.read_json(jsonfile)
        self.dumpjson = dumpjson

    def calculate(self, atoms, properties, system_changes):
        """Compute *properties* with the wrapped calculator and log the call.

        The results are stored in ``self.results``; the per-label counters
        described in the class docstring are updated as a side effect.
        """
        Calculator.calculate(self, atoms, properties, system_changes)

        if isinstance(self.calculator, Calculator):
            results = [self.calculator.get_property(prop, atoms)
                       for prop in properties]
        else:
            results = []
            for prop in properties:
                method_name = self.property_to_method_name[prop]
                method = getattr(self.calculator, method_name)
                results.append(method(atoms))

        if 'energy' in properties or 'energies' in properties:
            self.energy_evals.setdefault(self.label, 0)
            self.energy_evals[self.label] += 1
            # Test membership explicitly: list.index raises ValueError (not
            # IndexError) for a missing item, so a try/except IndexError
            # never reached the 'energies' fallback.
            if 'energy' in properties:
                energy = results[properties.index('energy')]
            else:
                energy = sum(results[properties.index('energies')])
            logger.info('energy call count=%d energy=%.3f',
                        self.energy_evals[self.label], energy)
        self.results = dict(zip(properties, results))

        if 'forces' in self.results:
            fmax = self.fmax.setdefault(self.label, [])
            walltime = self.walltime.setdefault(self.label, [])
            energy_count = self.energy_count.setdefault(self.label, [])
            energy_count.append(self.energy_evals.setdefault(self.label, 0))
            # Work on a copy so constraint adjustment does not alter the
            # forces stored in self.results.
            forces = self.results['forces'].copy()
            for constraint in atoms.constraints:
                constraint.adjust_forces(atoms, forces)
            # NOTE(review): this records the largest absolute Cartesian
            # component, not the largest per-atom force norm -- confirm
            # this matches the convergence criterion being compared against.
            fmax.append(abs(forces).max())
            walltime.append(time.time())
            logger.info('force call fmax=%.3f', fmax[-1])

        if self.dumpjson:
            self.write_json('dump.json')

    def write_json(self, filename):
        """Write the recorded history (``fmax``, ``walltime``,
        ``energy_evals``, ``energy_count``) to *filename* as JSON."""
        with open(filename, 'w') as fd:
            json.dump({'fmax': self.fmax,
                       'walltime': self.walltime,
                       'energy_evals': self.energy_evals,
                       'energy_count': self.energy_count}, fd)

    def read_json(self, filename, append=False, label=None):
        """Load history previously written by :meth:`write_json`.

        Parameters
        ----------
        filename
            Path of the JSON file to read.
        append
            If false (default) the current history is discarded first;
            otherwise the file's entries are merged in, replacing entries
            that have the same label.
        label
            If given and the file holds exactly one label, that label's
            history is stored under *label* instead.
        """
        with open(filename, 'r') as fd:
            dct = json.load(fd)

        # dict views are not subscriptable in Python 3; materialise them.
        labels = list(dct['fmax'])
        if label is not None and len(labels) == 1:
            for key in ('fmax', 'walltime', 'energy_evals', 'energy_count'):
                # pop before assigning so that label == labels[0] is a no-op
                # instead of deleting the entry that was just written.
                dct[key][label] = dct[key].pop(labels[0])
        if not append:
            self.fmax = {}
            self.walltime = {}
            self.energy_evals = {}
            self.energy_count = {}
        self.fmax.update(dct['fmax'])
        self.walltime.update(dct['walltime'])
        self.energy_evals.update(dct['energy_evals'])
        self.energy_count.update(dct['energy_count'])

    def tabulate(self):
        """Print, for every label with recorded force calls, the number of
        force calls, the number of energy calls and the wall time elapsed
        between the first and last force call."""
        fmt1 = '%-10s %10s %10s %8s'
        title = fmt1 % ('Label', '# Force', '# Energy', 'Walltime/s')
        print(title)
        print('-' * len(title))
        fmt2 = '%-10s %10d %10d %8.2f'
        for label in sorted(self.fmax):
            walltime = self.walltime[label]
            # len(energy_count[label]) only repeats the force-call count;
            # the number of energy calls is kept in energy_evals.
            print(fmt2 % (label, len(self.fmax[label]),
                          self.energy_evals.get(label, 0),
                          walltime[-1] - walltime[0]))

    def plot(self, fmaxlim=(1e-2, 1e2), forces=True, energy=True,
             walltime=True,
             markers=None, labels=None, **kwargs):
        """Plot the recorded maximum force, one subplot per enabled panel.

        Parameters
        ----------
        fmaxlim
            ``(ymin, ymax)`` limits of the logarithmic force axis.
        forces, energy, walltime
            Enable the panels plotting maximum force against the number of
            force evaluations, the number of energy evaluations and the
            elapsed wall time, respectively.
        markers
            matplotlib format strings assigned in turn to the sorted
            labels; cycled when there are more labels than markers.
        labels
            Optional new names for the recorded labels, assigned in sorted
            order of the current labels. The renaming is stored on the
            instance (it persists after this call); labels without a new
            name keep their old one.
        **kwargs
            Passed on to ``matplotlib.pyplot.semilogy``.
        """
        from itertools import cycle

        import matplotlib.pyplot as plt

        if markers is None:
            markers = [c + s for c in ['r', 'g', 'b', 'c', 'm', 'y', 'k']
                       for s in ['.-', '.--']]
        nsub = sum([forces, energy, walltime])
        nplot = 0

        if labels is not None:
            # One old->new mapping applied to every history dict keeps them
            # consistent with each other (including energy_evals, which
            # tabulate reads).
            renames = dict(zip(sorted(self.fmax), labels))

            def relabel(history):
                return {renames.get(key, key): value
                        for key, value in history.items()}

            self.fmax = relabel(self.fmax)
            self.energy_count = relabel(self.energy_count)
            self.walltime = relabel(self.walltime)
            self.energy_evals = relabel(self.energy_evals)

        if forces:
            nplot += 1
            plt.subplot(nsub, 1, nplot)
            for label, color in zip(sorted(self.fmax), cycle(markers)):
                fmax = np.array(self.fmax[label])
                plt.semilogy(np.arange(len(fmax)), fmax, color,
                             label=label, **kwargs)

            plt.xlabel('Number of force evaluations')
            plt.ylabel('Maximum force / eV/A')
            plt.ylim(*fmaxlim)
            plt.legend()

        if energy:
            nplot += 1
            plt.subplot(nsub, 1, nplot)
            for label, color in zip(sorted(self.energy_count),
                                    cycle(markers)):
                plt.semilogy(np.array(self.energy_count[label]),
                             np.array(self.fmax[label]), color,
                             label=label, **kwargs)

            plt.xlabel('Number of energy evaluations')
            plt.ylabel('Maximum force / eV/A')
            plt.ylim(*fmaxlim)
            plt.legend()

        if walltime:
            nplot += 1
            plt.subplot(nsub, 1, nplot)
            for label, color in zip(sorted(self.walltime), cycle(markers)):
                # Separate name so the `walltime` flag is not shadowed.
                stamps = np.array(self.walltime[label])
                plt.semilogy(stamps - stamps[0], np.array(self.fmax[label]),
                             color, label=label, **kwargs)

            plt.xlabel('Walltime / s')
            plt.ylabel('Maximum force / eV/A')
            plt.ylim(*fmaxlim)
            plt.legend()

        plt.subplots_adjust(hspace=0.33)