Coverage for eminus/io/json.py: 98.31%
59 statements
« prev ^ index » next coverage.py v7.6.9, created at 2024-12-18 08:43 +0000
1# SPDX-FileCopyrightText: 2021 The eminus developers
2# SPDX-License-Identifier: Apache-2.0
3"""JSON file handling."""
5import base64
6import copy
7import json
9import numpy as np
def _custom_object_hook(dct):  # noqa: PLR0911
    """Rebuild numpy arrays and eminus objects from decoded JSON objects.

    Used as the ``object_hook`` of :func:`json.load`, so it receives every decoded JSON object
    (inner objects are decoded before the objects containing them). The target type is inferred
    from marker keys present in the mapping.

    Args:
        dct: Mapping produced by the JSON decoder.

    Returns:
        A numpy array, an eminus object, or the unchanged input when no marker key matches.
    """
    import eminus

    def restore_state(instance, state):
        """Assign every entry of state to instance as an attribute, except the logger."""
        for name, value in state.items():
            # The logger is not serialized; keep the one created by the constructor.
            # Deep copies detach attributes from the decoded data (arrays created with
            # np.frombuffer are read-only, their copies are writable).
            if name != "_log":
                setattr(instance, name, copy.deepcopy(value))
        return instance

    if not isinstance(dct, dict):
        return dct

    # Arrays were stored as base64-encoded raw buffers together with their dtype and shape
    if "__ndarray__" in dct:
        buffer = base64.b64decode(dct["__ndarray__"])
        return np.frombuffer(buffer, dct["dtype"]).reshape(dct["shape"])

    # eminus objects: build a minimal instance, then overwrite its state with the stored attributes.
    # The verbosity is passed explicitly because the logger is created at instantiation.

    # Atoms objects
    if "_atom" in dct:
        atoms = restore_state(
            eminus.Atoms(dct["_atom"], dct["_pos"], verbose=dct["_verbose"]),
            dct,
        )
        # JSON turns tuples into lists, restore the tuple entries of _active
        if not isinstance(atoms._active, tuple):
            atoms._active = [tuple(entry) for entry in atoms._active]
        return atoms

    # SCF objects
    if "_atoms" in dct:
        return restore_state(eminus.SCF(dct["_atoms"], verbose=dct["_verbose"]), dct)

    # Energy objects
    if "Ekin" in dct:
        return restore_state(eminus.energies.Energy(), dct)

    # GTH objects
    if "NbetaNL" in dct:
        return restore_state(eminus.gth.GTH(), dct)

    # Occupations objects
    if "_Nelec" in dct:
        return restore_state(eminus.occupations.Occupations(), dct)

    # KPoints objects
    if "_kmesh" in dct:
        return restore_state(eminus.kpoints.KPoints(dct["lattice"]), dct)

    return dct
def read_json(filename):
    """Load objects from a JSON file.

    numpy arrays and eminus objects are reconstructed through ``_custom_object_hook``.

    Args:
        filename: JSON input file path/name (str or path-like object). The ".json" extension is
            appended if missing.

    Returns:
        Class object.
    """
    import os

    # Accept pathlib.Path and other path-like objects, not only strings
    filename = os.fspath(filename)
    if not filename.endswith(".json"):
        filename += ".json"

    with open(filename, encoding="utf-8") as fh:
        return json.load(fh, object_hook=_custom_object_hook)
def write_json(obj, filename):
    """Save objects in a JSON file.

    numpy arrays are stored as base64-encoded buffers with their dtype and shape, numpy scalars
    as the matching Python scalars, eminus objects as dictionaries of their attributes, and
    loggers as null.

    Args:
        obj: Class object.
        filename: JSON output file path/name (str or path-like object). The ".json" extension is
            appended if missing.
    """
    import os

    import eminus

    class _CustomEncoder(json.JSONEncoder):
        """Custom JSON encoder class to serialize numpy and eminus objects."""

        def default(self, obj):
            """Overwrite the default function to handle numpy and eminus objects."""
            # ndarrays are not JSON serializable, encode their C-ordered raw buffer as base64
            if isinstance(obj, np.ndarray):
                data = base64.b64encode(obj.copy(order="C")).decode("utf-8")
                return {"__ndarray__": data, "dtype": str(obj.dtype), "shape": obj.shape}
            # numpy scalars such as np.int64, np.float32, or np.bool_ are not JSON serializable
            # either (np.float64 is a float subclass and never reaches this method)
            if isinstance(obj, np.generic):
                return obj.item()
            # Dump eminus classes as dictionaries of their attributes; the encoder serializes
            # the returned dictionary recursively, calling this method for nested objects, so
            # no intermediate dump/load round trip is needed
            if isinstance(
                obj,
                (
                    eminus.Atoms,
                    eminus.SCF,
                    eminus.energies.Energy,
                    eminus.gth.GTH,
                    eminus.kpoints.KPoints,
                    eminus.occupations.Occupations,
                ),
            ):
                return obj.__dict__
            # The logger class is not serializable, just ignore it
            if isinstance(obj, eminus.logger.CustomLogger):
                return None
            return super().default(obj)

    # Accept pathlib.Path and other path-like objects, not only strings
    filename = os.fspath(filename)
    if not filename.endswith(".json"):
        filename += ".json"

    # Encode everything before opening the file, so a serialization error does not leave a
    # truncated file behind (or destroy a previously saved one)
    content = json.dumps(obj, cls=_CustomEncoder)
    with open(filename, "w", encoding="utf-8") as fp:
        fp.write(content)