Coverage for eminus/io/json.py: 98.31%
59 statements
« prev ^ index » next coverage.py v7.6.4, created at 2024-11-08 12:59 +0000
1# SPDX-FileCopyrightText: 2021 The eminus developers
2# SPDX-License-Identifier: Apache-2.0
3"""JSON file handling."""
import base64
import copy
import json
import os

import numpy as np
def _custom_object_hook(dct):  # noqa: PLR0911
    """Custom JSON object hook to create eminus classes after deserialization.

    Args:
        dct: Decoded JSON object; json.load passes every JSON object to the hook as a dictionary.

    Returns:
        A decoded (writeable) NumPy array, a reconstructed eminus object, or dct unchanged.
    """

    def set_attrs(obj, dct):
        """Set attributes of an object using a given dictionary."""
        for attr in dct:
            # Keep the logger created at instantiation instead of the serialized placeholder
            if attr == "_log":
                continue
            setattr(obj, attr, copy.deepcopy(dct[attr]))
        return obj

    # Only JSON objects (dictionaries) can describe arrays or eminus objects
    if not isinstance(dct, dict):
        return dct

    # ndarrays are base64 encoded, decode and recreate
    if "__ndarray__" in dct:
        data = base64.b64decode(dct["__ndarray__"])
        # np.frombuffer returns a read-only view of the immutable bytes object,
        # copy it so callers receive a regular, writeable array
        return np.frombuffer(data, dct["dtype"]).reshape(dct["shape"]).copy()

    # Imported lazily, presumably to avoid a circular import with the eminus package
    import eminus

    # Create simple eminus objects and set all attributes afterwards
    # Explicitly call objects with verbosity since the logger is created at instantiation

    # Atoms objects
    if "_atom" in dct:
        atoms = eminus.Atoms(dct["_atom"], dct["_pos"], verbose=dct["_verbose"])
        atoms = set_attrs(atoms, dct)
        # The tuple type is not preserved when serializing, manually cast the only important one
        if isinstance(atoms._active, list):
            atoms._active = [tuple(i) for i in atoms._active]
        return atoms
    # SCF objects
    if "_atoms" in dct:
        scf = eminus.SCF(dct["_atoms"], verbose=dct["_verbose"])
        return set_attrs(scf, dct)
    # Energy objects
    if "Ekin" in dct:
        energies = eminus.energies.Energy()
        return set_attrs(energies, dct)
    # GTH objects
    if "NbetaNL" in dct:
        gth = eminus.gth.GTH()
        return set_attrs(gth, dct)
    # Occupations objects
    if "_Nelec" in dct:
        occ = eminus.occupations.Occupations()
        return set_attrs(occ, dct)
    # KPoints objects
    if "_kmesh" in dct:
        kpts = eminus.kpoints.KPoints(dct["lattice"])
        return set_attrs(kpts, dct)
    return dct
def read_json(filename):
    """Load objects from a JSON file.

    Args:
        filename: JSON input file path/name (str or os.PathLike). A ".json" suffix is appended
            if it is missing.

    Returns:
        Class object.
    """
    # Accept path-like objects (e.g. pathlib.Path) in addition to plain strings
    filename = os.fspath(filename)
    if not filename.endswith(".json"):
        filename += ".json"

    with open(filename, encoding="utf-8") as fh:
        return json.load(fh, object_hook=_custom_object_hook)
class _CustomEncoder(json.JSONEncoder):
    """Custom JSON encoder class to serialize eminus classes."""

    def default(self, obj):
        """Overwrite the default function to handle eminus objects.

        Args:
            obj: Object the base encoder cannot serialize natively.

        Returns:
            A JSON-serializable representation of obj.

        Raises:
            TypeError: If obj has no known serialization (raised by the base class).
        """
        # ndarrays are not JSON serializable, encode their C-ordered raw bytes as base64
        if isinstance(obj, np.ndarray):
            # Only copies when the array is not already C-contiguous; the shape is taken from obj
            data = base64.b64encode(np.ascontiguousarray(obj)).decode("utf-8")
            return {"__ndarray__": data, "dtype": str(obj.dtype), "shape": obj.shape}

        # Imported lazily, presumably to avoid a circular import with the eminus package
        import eminus

        # If obj is an eminus class dump it as a dictionary of its attributes
        if isinstance(
            obj,
            (
                eminus.Atoms,
                eminus.SCF,
                eminus.energies.Energy,
                eminus.gth.GTH,
                eminus.kpoints.KPoints,
                eminus.occupations.Occupations,
            ),
        ):
            # The encoder recursively calls default() for values it cannot serialize itself
            # (arrays, nested eminus objects, the logger), so no intermediate dumps/loads is needed
            return obj.__dict__
        # The logger class is not serializable, just ignore it
        if isinstance(obj, eminus.logger.CustomLogger):
            return None
        return super().default(obj)
def write_json(obj, filename):
    """Save objects in a JSON file.

    Args:
        obj: Class object.
        filename: JSON output file path/name (str or os.PathLike). A ".json" suffix is appended
            if it is missing.
    """
    # Accept path-like objects (e.g. pathlib.Path) in addition to plain strings
    filename = os.fspath(filename)
    if not filename.endswith(".json"):
        filename += ".json"

    with open(filename, "w", encoding="utf-8") as fp:
        json.dump(obj, fp, cls=_CustomEncoder)