Coverage for /builds/debichem-team/python-ase/ase/io/vtkxml.py: 5.33%

75 statements  

« prev     ^ index     » next       coverage.py v7.5.3, created at 2025-03-06 04:00 +0000

1import numpy as np 

2 

# Module-level output switch read by write_vti and write_vtu at write time:
# True -> VTK "appended" data mode with appended-data encoding turned off;
# False -> ASCII data mode.
fast = False

4 

5 

def write_vti(filename, atoms, data=None):
    """Write a 3D scalar field on an orthogonal cell to a VTK XML Image Data file.

    The grid spans the box ``[0, a] x [0, b] x [0, c]`` given by the cell
    diagonal. The spacing along axis i is ``cell[i, i] / data.shape[i]``, so
    ``data[ix, iy, iz]`` is placed at ``(ix * hx, iy * hy, iz * hz)``. The
    values are stored as a point-data array named ``'scalars'``.

    Parameters
    ----------
    filename : str, os.PathLike or file object
        Destination. For a file object, its ``name`` attribute is used as
        the path, because VTK opens the file itself.
    atoms : Atoms or list of Atoms
        Only ``atoms.get_cell()`` is used. A list must hold exactly one
        configuration.
    data : array_like
        3D array of grid values. Complex input is written as its absolute
        value; everything else is converted to float64.

    Raises
    ------
    ValueError
        If more than one configuration is given, ``data`` is missing or not
        3D, or the cell is not diagonal.
    """
    if isinstance(atoms, list):
        if len(atoms) > 1:
            raise ValueError('Can only write one configuration to a VTI file!')
        atoms = atoms[0]

    if data is None:
        raise ValueError('VTK XML Image Data (VTI) format requires data!')

    data = np.asarray(data)
    if data.ndim != 3:
        raise ValueError(
            'VTI data must be a 3D array, got shape {}'.format(data.shape))

    # iscomplexobj catches every complex dtype (complex64, complex128, ...);
    # comparing ``dtype == complex`` only matched complex128.
    if np.iscomplexobj(data):
        data = np.abs(data)

    # np.array accepts both a plain ndarray cell and an ase.cell.Cell.
    cell = np.array(atoms.get_cell(), dtype=float)
    # Only axis-aligned (diagonal) cells can be represented by image data.
    if not np.all(cell == np.diag(np.diag(cell))):
        raise ValueError('Unit cell must be orthogonal')
    lengths = cell.diagonal()

    # Import the optional VTK dependency only once the input is known to be
    # valid, so argument errors are reported even without VTK installed.
    from os import PathLike, fspath

    from vtk import (
        VTK_MAJOR_VERSION,
        vtkStructuredPoints,
        vtkXMLImageDataWriter,
    )
    from vtk.util.numpy_support import numpy_to_vtk

    # Create a VTK grid of structured points.
    spts = vtkStructuredPoints()
    if VTK_MAJOR_VERSION <= 5:
        # Pipeline meta-data setter that only exists on VTK <= 5 data
        # objects; on newer VTK the bounds follow from dimensions/spacing.
        bbox = np.array(list(zip(np.zeros(3), lengths))).ravel()
        spts.SetWholeBoundingBox(bbox)
    spts.SetDimensions(data.shape)
    spts.SetSpacing(lengths / data.shape)

    # VTK image data stores points with x varying fastest, so flatten
    # data[ix, iy, iz] in Fortran order (identical to
    # data.swapaxes(0, 2).flatten()) and copy it in one call.
    values = np.ascontiguousarray(data.ravel(order='F'), dtype=np.float64)
    da = numpy_to_vtk(values, deep=1)
    da.SetName('scalars')
    spts.GetPointData().SetScalars(da)

    # Save the ImageData dataset to a VTK XML file.
    w = vtkXMLImageDataWriter()
    if fast:
        w.SetDataModeToAppend()
        w.EncodeAppendedDataOff()
    else:
        w.SetDataModeToAscii()

    if isinstance(filename, (str, PathLike)):
        w.SetFileName(fspath(filename))
    else:
        w.SetFileName(filename.name)
    # SetInput was replaced by SetInputData in VTK 6.
    if VTK_MAJOR_VERSION <= 5:
        w.SetInput(spts)
    else:
        w.SetInputData(spts)
    w.Write()

84 

85 

def write_vtu(filename, atoms, data=None):
    """Write atoms as points of a VTK XML Unstructured Grid (VTU) file.

    Each atom becomes one point at its Cartesian position. Three point-data
    arrays are attached: ``"atomic numbers"``, ``"tags"`` and ``"radii"``
    (covalent radii from ``ase.data.covalent_radii``).

    Parameters
    ----------
    filename : str, os.PathLike or file object
        Destination. For a file object, its ``name`` attribute is used as
        the path, because VTK opens the file itself.
    atoms : Atoms or list of Atoms
        A list must hold exactly one configuration.
    data : optional
        Unused by this function.

    Raises
    ------
    ValueError
        If a list with more than one configuration is given.
    """
    if isinstance(atoms, list):
        if len(atoms) > 1:
            raise ValueError('Can only write one configuration to a VTU file!')
        atoms = atoms[0]

    # Import the optional VTK dependency only after validating the input.
    from os import PathLike, fspath

    from vtk import (
        VTK_MAJOR_VERSION,
        vtkPoints,
        vtkUnstructuredGrid,
        vtkXMLUnstructuredGridWriter,
    )
    from vtk.util.numpy_support import numpy_to_vtk

    from ase.data import covalent_radii

    # Create an unstructured grid holding only points (no cells).
    ugd = vtkUnstructuredGrid()

    # Add atoms as double-precision VTK points. The data type is set before
    # sizing: changing the type swaps in a new, empty array, which would
    # otherwise discard the preallocated storage.
    points = vtkPoints()
    points.SetDataTypeToDouble()
    points.SetNumberOfPoints(len(atoms))
    for i, pos in enumerate(atoms.get_positions()):
        points.SetPoint(i, *pos)
    ugd.SetPoints(points)

    # Attach per-atom point-data arrays (deep copies of the numpy data).
    point_data = ugd.GetPointData()
    for name, values in (
        ("atomic numbers", atoms.get_atomic_numbers()),
        ("tags", atoms.get_tags()),
        ("radii", covalent_radii[atoms.numbers]),
    ):
        array = numpy_to_vtk(values, deep=1)
        array.SetName(name)
        point_data.AddArray(array)

    # Save the UnstructuredGrid dataset to a VTK XML file.
    w = vtkXMLUnstructuredGridWriter()
    if fast:
        w.SetDataModeToAppend()
        w.EncodeAppendedDataOff()
    else:
        # Guard in case no compressor is configured on the writer.
        compressor = w.GetCompressor()
        if compressor is not None:
            compressor.SetCompressionLevel(0)
        w.SetDataModeToAscii()

    # Path-like objects must be converted with fspath: their ``.name`` is
    # only the basename and would drop the directory part.
    if isinstance(filename, (str, PathLike)):
        w.SetFileName(fspath(filename))
    else:
        w.SetFileName(filename.name)
    # SetInput was replaced by SetInputData in VTK 6.
    if VTK_MAJOR_VERSION <= 5:
        w.SetInput(ugd)
    else:
        w.SetInputData(ugd)
    w.Write()