Coverage for /builds/debichem-team/python-ase/ase/io/vtkxml.py: 5.33%
75 statements
« prev ^ index » next coverage.py v7.5.3, created at 2025-03-06 04:00 +0000
« prev ^ index » next coverage.py v7.5.3, created at 2025-03-06 04:00 +0000
1import numpy as np
# Output-mode switch read by write_vti and write_vtu: when true the
# writers store raw (unencoded) appended data; when false they write ASCII.
fast: bool = False
def write_vti(filename, atoms, data=None):
    """Write a 3D scalar field to a VTK XML Image Data (.vti) file.

    Parameters
    ----------
    filename : str, os.PathLike or file object
        Destination.  For a file object its ``name`` attribute is used,
        because the VTK writer opens the file itself.
    atoms : Atoms or list of Atoms
        Supplies the unit cell, which must be orthogonal (diagonal).  A list
        must contain exactly one configuration.
    data : array_like
        Three-dimensional grid of values, axis 0 mapping to x.  Complex data
        is written as its absolute value; values are stored as doubles.

    Raises
    ------
    ValueError
        If not exactly one configuration is given, ``data`` is missing or
        not three-dimensional, or the cell is not orthogonal.

    The module-level ``fast`` flag selects raw appended output instead of
    ASCII.
    """
    import os

    if isinstance(atoms, list):
        if not atoms:
            raise ValueError('No configuration to write to a VTI file!')
        if len(atoms) > 1:
            raise ValueError('Can only write one configuration to a VTI file!')
        atoms = atoms[0]

    if data is None:
        raise ValueError('VTK XML Image Data (VTI) format requires data!')

    data = np.asarray(data)
    if data.ndim != 3:
        raise ValueError(
            'VTI data must be a 3D array, got shape {}'.format(data.shape))

    # iscomplexobj also catches complex64, which `dtype == complex` misses.
    if np.iscomplexobj(data):
        data = np.abs(data)

    cell = np.asarray(atoms.get_cell(), dtype=float)
    if not np.all(cell == np.diag(np.diag(cell))):
        raise ValueError('Unit cell must be orthogonal')
    lengths = cell.diagonal()

    # Import VTK only after validation so argument errors surface even
    # where VTK is not installed.
    from vtk import (
        VTK_MAJOR_VERSION,
        vtkStructuredPoints,
        vtkXMLImageDataWriter,
    )
    from vtk.util.numpy_support import numpy_to_vtk

    # Create a VTK grid of structured points spanning the cell.
    spts = vtkStructuredPoints()
    if VTK_MAJOR_VERSION <= 5:
        # VTK 6 removed pipeline meta-data such as the whole bounding box
        # from data objects; only older versions accept it.
        bbox = np.column_stack((np.zeros(3), lengths)).ravel()
        spts.SetWholeBoundingBox(bbox)
    spts.SetDimensions(data.shape)
    spts.SetSpacing(lengths / data.shape)

    # VTK orders image points with x varying fastest.  Axis 0 of `data`
    # is x (see SetDimensions above), so swap x and z before a C-order
    # flatten.  numpy_to_vtk needs a contiguous array; float64 yields a
    # vtkDoubleArray, and deep=1 copies so VTK owns its memory.
    values = np.ascontiguousarray(data.swapaxes(0, 2).ravel(),
                                  dtype=np.float64)
    da = numpy_to_vtk(values, deep=1)
    da.SetName('scalars')
    spts.GetPointData().SetScalars(da)

    # Save the ImageData dataset to a VTK XML file.
    w = vtkXMLImageDataWriter()
    if fast:
        w.SetDataModeToAppend()
        w.EncodeAppendedDataOff()
    else:
        w.SetDataModeToAscii()

    if isinstance(filename, (str, os.PathLike)):
        path = os.fspath(filename)
    else:
        path = filename.name
    w.SetFileName(path)
    # SetInput was replaced by SetInputData in VTK 6.
    if VTK_MAJOR_VERSION <= 5:
        w.SetInput(spts)
    else:
        w.SetInputData(spts)
    w.Write()
def write_vtu(filename, atoms, data=None):
    """Write atoms as points of a VTK XML Unstructured Grid (.vtu) file.

    Each atom becomes a point at its Cartesian position, carrying the point
    arrays "atomic numbers", "tags" and "radii" (covalent radii from
    ``ase.data``).  No cells are added to the grid.

    Parameters
    ----------
    filename : str, os.PathLike or file object
        Destination.  For a file object its ``name`` attribute is used,
        because the VTK writer opens the file itself.
    atoms : Atoms or list of Atoms
        A list must contain exactly one configuration.
    data : optional
        Currently unused.

    Raises
    ------
    ValueError
        If a list with other than exactly one configuration is given.

    The module-level ``fast`` flag selects raw appended output instead of
    ASCII.
    """
    import os

    if isinstance(atoms, list):
        if not atoms:
            raise ValueError('No configuration to write to a VTU file!')
        if len(atoms) > 1:
            raise ValueError('Can only write one configuration to a VTU file!')
        atoms = atoms[0]

    # Import VTK only after validation so argument errors surface even
    # where VTK is not installed.
    from vtk import (
        VTK_MAJOR_VERSION,
        vtkPoints,
        vtkUnstructuredGrid,
        vtkXMLUnstructuredGridWriter,
    )
    from vtk.util.numpy_support import numpy_to_vtk

    from ase.data import covalent_radii

    # Create an unstructured grid holding the atoms as points.
    ugd = vtkUnstructuredGrid()

    p = vtkPoints()
    # Choose the data type first: changing it replaces the underlying
    # array, which would discard the preallocation below.
    p.SetDataTypeToDouble()
    p.SetNumberOfPoints(len(atoms))
    for i, pos in enumerate(atoms.get_positions()):
        p.InsertPoint(i, *pos)
    ugd.SetPoints(p)

    # Attach per-atom point arrays; name each before adding it.
    point_data = ugd.GetPointData()
    for name, values in (
        ("atomic numbers", atoms.get_atomic_numbers()),
        ("tags", atoms.get_tags()),
        ("radii", covalent_radii[atoms.numbers]),
    ):
        array = numpy_to_vtk(values, deep=1)
        array.SetName(name)
        point_data.AddArray(array)

    # Save the UnstructuredGrid dataset to a VTK XML file.
    w = vtkXMLUnstructuredGridWriter()
    if fast:
        w.SetDataModeToAppend()
        w.EncodeAppendedDataOff()
    else:
        compressor = w.GetCompressor()
        if compressor is not None:
            compressor.SetCompressionLevel(0)
        w.SetDataModeToAscii()

    if isinstance(filename, (str, os.PathLike)):
        path = os.fspath(filename)
    else:
        path = filename.name
    w.SetFileName(path)
    # SetInput was replaced by SetInputData in VTK 6.
    if VTK_MAJOR_VERSION <= 5:
        w.SetInput(ugd)
    else:
        w.SetInputData(ugd)
    w.Write()