Hide keyboard shortcuts

Hot-keys on this page

r m x p   toggle line displays

j k   next/prev highlighted chunk

0   (zero) top of page

1   (one) first highlighted chunk

1import numpy as np 

2 

3 

# Module-wide output switch read by write_vti and write_vtu.
# True: the VTK XML writers use appended data mode with encoding of the
# appended data turned off (SetDataModeToAppend + EncodeAppendedDataOff).
# False: they write ASCII data (SetDataModeToAscii).
fast = False


def write_vti(filename, atoms, data=None):
    """Write grid data for one configuration to a VTK XML ImageData file.

    The grid spans the (orthogonal) unit cell of ``atoms``, with one point
    per element of ``data`` and spacing ``cell_length / n_points`` along
    each axis.  The values are stored as a point-data array named
    ``'scalars'``.

    Parameters
    ----------
    filename : str, path-like, or open file object
        Destination.  For an open file object its ``name`` attribute is
        used, because the VTK writer opens the file by path itself.
    atoms : Atoms or list of Atoms
        Configuration whose unit cell defines the grid extent.  A list is
        accepted only if it holds exactly one configuration.
    data : array_like
        Three-dimensional array of grid values.  Complex input is replaced
        by its magnitude.

    Raises
    ------
    ValueError
        If the list holds zero or several configurations, ``data`` is
        missing or not three-dimensional, or the unit cell is not
        orthogonal.
    """
    import os

    if isinstance(atoms, list):
        if len(atoms) > 1:
            raise ValueError('Can only write one configuration to a VTI file!')
        if not atoms:
            raise ValueError('No configuration to write to a VTI file!')
        atoms = atoms[0]

    if data is None:
        raise ValueError('VTK XML Image Data (VTI) format requires data!')

    data = np.asarray(data)

    # iscomplexobj also catches complex64 etc., which ``dtype == complex``
    # (i.e. complex128 only) would miss.
    if np.iscomplexobj(data):
        data = np.abs(data)

    if data.ndim != 3:
        raise ValueError('VTI data must be a three-dimensional array, '
                         'got shape %s' % (data.shape,))

    cell = np.asarray(atoms.get_cell())

    if not np.all(cell == np.diag(np.diag(cell))):
        raise ValueError('Unit cell must be orthogonal')

    # Import VTK only once the inexpensive argument checks have passed.
    from vtk import (VTK_MAJOR_VERSION, vtkStructuredPoints,
                     vtkXMLImageDataWriter)
    from vtk.util.numpy_support import numpy_to_vtk

    lengths = cell.diagonal()

    # Create a VTK grid of structured points covering the unit cell.
    spts = vtkStructuredPoints()
    if hasattr(spts, 'SetWholeBoundingBox'):
        # Not present in newer VTK releases; only set where it exists.
        bbox = np.column_stack((np.zeros(3), lengths)).ravel()
        spts.SetWholeBoundingBox(bbox)
    spts.SetDimensions(data.shape)
    spts.SetSpacing(lengths / data.shape)

    # VTK image point ids run with the first (x) index fastest.  Swapping
    # axes 0 and 2 and flattening in C order produces that ordering.
    flat = np.ascontiguousarray(data.swapaxes(0, 2).ravel(),
                                dtype=np.float64)
    # deep=1 copies, so the VTK array does not depend on ``flat`` staying
    # alive.
    da = numpy_to_vtk(flat, deep=1)
    da.SetName('scalars')

    # Assign the VTK array as point data of the grid.
    spts.GetPointData().SetScalars(da)

    # Save the ImageData dataset to a VTK XML file.
    w = vtkXMLImageDataWriter()

    if fast:
        w.SetDataModeToAppend()
        w.EncodeAppendedDataOff()
    else:
        w.SetDataModeToAscii()

    if isinstance(filename, str):
        path = filename
    elif hasattr(filename, '__fspath__'):
        path = os.fspath(filename)
    else:
        # Open file object: VTK writes by path.
        path = filename.name
    w.SetFileName(path)
    if VTK_MAJOR_VERSION <= 5:
        w.SetInput(spts)
    else:
        w.SetInputData(spts)
    w.Write()


def write_vtu(filename, atoms, data=None):
    """Write the atoms of one configuration to a VTK XML UnstructuredGrid file.

    Each atom becomes a VTK point at its Cartesian position.  The points
    carry three point-data arrays:
    ``'atomic numbers'``, ``'tags'`` and ``'radii'`` (covalent radii
    looked up in ``ase.data.covalent_radii``).

    Parameters
    ----------
    filename : str, path-like, or open file object
        Destination.  For an open file object its ``name`` attribute is
        used, because the VTK writer opens the file by path itself.
    atoms : Atoms or list of Atoms
        Configuration to write.  A list is accepted only if it holds
        exactly one configuration.
    data : optional
        Not used by this writer; accepted for signature compatibility.

    Raises
    ------
    ValueError
        If the list holds zero or several configurations.
    """
    import os

    if isinstance(atoms, list):
        if len(atoms) > 1:
            raise ValueError('Can only write one configuration to a VTU file!')
        if not atoms:
            raise ValueError('No configuration to write to a VTU file!')
        atoms = atoms[0]

    # Import VTK only once the inexpensive argument checks have passed.
    from vtk import (VTK_MAJOR_VERSION, vtkUnstructuredGrid, vtkPoints,
                     vtkXMLUnstructuredGridWriter)
    from vtk.util.numpy_support import numpy_to_vtk
    from ase.data import covalent_radii

    # Create a VTK unstructured grid holding one point per atom.
    ugd = vtkUnstructuredGrid()

    # Set the data type before sizing the container.  Changing the type
    # afterwards replaces the underlying array, so the pre-sizing would be
    # lost.
    p = vtkPoints()
    p.SetDataTypeToDouble()
    p.SetNumberOfPoints(len(atoms))
    for i, pos in enumerate(atoms.get_positions()):
        p.InsertPoint(i, *pos)
    ugd.SetPoints(p)

    # Attach per-atom point-data arrays.  deep=1 copies the numpy data, so
    # the VTK arrays do not depend on the temporaries staying alive.
    numbers = atoms.get_atomic_numbers()
    point_data = ugd.GetPointData()
    for name, values in (('atomic numbers', numbers),
                         ('tags', atoms.get_tags()),
                         ('radii', covalent_radii[numbers])):
        array = numpy_to_vtk(values, deep=1)
        array.SetName(name)
        point_data.AddArray(array)

    # Save the UnstructuredGrid dataset to a VTK XML file.
    w = vtkXMLUnstructuredGridWriter()

    if fast:
        w.SetDataModeToAppend()
        w.EncodeAppendedDataOff()
    else:
        compressor = w.GetCompressor()
        if compressor is not None:
            compressor.SetCompressionLevel(0)
        w.SetDataModeToAscii()

    if isinstance(filename, str):
        path = filename
    elif hasattr(filename, '__fspath__'):
        # Path-like: use the full path.  Its ``.name`` would be only the
        # final component, which would write into the wrong directory.
        path = os.fspath(filename)
    else:
        # Open file object: VTK writes by path.
        path = filename.name
    w.SetFileName(path)
    if VTK_MAJOR_VERSION <= 5:
        w.SetInput(ugd)
    else:
        w.SetInputData(ugd)
    w.Write()
142 w.Write()