Coverage for eminus/io/json.py: 98.36%
61 statements
« prev ^ index » next coverage.py v7.11.0, created at 2025-10-21 12:19 +0000
« prev ^ index » next coverage.py v7.11.0, created at 2025-10-21 12:19 +0000
1# SPDX-FileCopyrightText: 2023 The eminus developers
2# SPDX-License-Identifier: Apache-2.0
3"""JSON file handling."""
import base64
import copy
import json
import os

import numpy as np

from eminus import backend as xp
def _custom_object_hook(dct):
    """Rebuild arrays and eminus objects while a JSON document is decoded.

    Passed as ``object_hook`` to :func:`json.load`, so it receives every decoded JSON object.
    Arrays are recognized by the ``__ndarray__`` key, eminus objects by a characteristic attribute
    name. Anything else is handed back unchanged.

    Args:
        dct: Decoded JSON object.

    Returns:
        A backend array, an eminus object, or the unchanged input.
    """
    import eminus

    def set_attrs(obj, attributes):
        """Deep-copy all entries of attributes onto obj as attributes, skipping the logger."""
        for name, value in attributes.items():
            if name == "_log":
                continue
            setattr(obj, name, copy.deepcopy(value))
        return obj

    if not isinstance(dct, dict):
        return dct

    # Arrays were stored as base64 encoded buffers together with dtype and shape
    if "__ndarray__" in dct:
        buffer = base64.b64decode(dct["__ndarray__"])
        array = np.frombuffer(buffer, dct["dtype"]).reshape(dct["shape"])
        # Copy so the result does not share memory with the decoded bytes
        return xp.asarray(copy.deepcopy(array))

    # Objects are instantiated first and all attributes are set afterwards
    # The verbosity is passed explicitly since the logger is created at instantiation
    if "_atom" in dct:
        atoms = set_attrs(eminus.Atoms(dct["_atom"], dct["_pos"], verbose=dct["_verbose"]), dct)
        # JSON turns tuples into lists, restore the tuple entries of the only attribute that needs it
        if atoms._active is not None and not isinstance(atoms._active, tuple):
            atoms._active = [(xp.asarray(entry[0]),) for entry in atoms._active]
        return atoms
    if "_atoms" in dct:
        return set_attrs(eminus.SCF(dct["_atoms"], verbose=dct["_verbose"]), dct)

    # Remaining classes as (marker attribute, constructor) pairs, checked in this order
    factories = (
        ("Ekin", lambda _: eminus.energies.Energy()),
        ("NbetaNL", lambda _: eminus.gth.GTH()),
        ("_Nelec", lambda _: eminus.occupations.Occupations()),
        ("_kmesh", lambda d: eminus.kpoints.KPoints(d["lattice"])),
    )
    for marker, factory in factories:
        if marker in dct:
            return set_attrs(factory(dct), dct)
    return dct
def read_json(filename):
    """Load objects from a JSON file.

    Arrays and eminus objects stored by write_json are reconstructed while decoding.

    Args:
        filename: JSON input file path/name. The ".json" extension is appended if missing.
            Strings and path-like objects (e.g., pathlib.Path) are accepted.

    Returns:
        Class object.
    """
    # Normalize path-like objects so the extension check below works for them as well
    filename = os.fspath(filename)
    if not filename.endswith(".json"):
        filename += ".json"

    with open(filename, encoding="utf-8") as fh:
        return json.load(fh, object_hook=_custom_object_hook)
def write_json(obj, filename):
    """Save objects in a JSON file.

    Backend arrays are stored as base64 encoded strings together with their dtype and shape.
    Supported eminus objects are stored as dictionaries of their attributes, and loggers are
    stored as null.

    Args:
        obj: Class object.
        filename: JSON output file path/name. The ".json" extension is appended if missing.
            Strings and path-like objects (e.g., pathlib.Path) are accepted.
    """
    import eminus

    class _CustomEncoder(json.JSONEncoder):
        """Custom JSON encoder class to serialize eminus classes."""

        def default(self, o):
            """Overwrite the default function to handle eminus objects."""
            # ndarrays are not JSON serializable, encode them as base64 to save them
            if xp.is_array(o):
                o = xp.to_np(o)
                # Use a C-ordered copy so the raw buffer matches the reshape done when reading
                data = base64.b64encode(o.copy(order="C")).decode("utf-8")
                return {"__ndarray__": data, "dtype": str(o.dtype), "shape": o.shape}

            # If obj is an eminus class dump it as a dictionary of its attributes
            # The encoder recursively encodes whatever this function returns, calling it again for
            # nested arrays and eminus objects. The attribute dictionary can therefore be returned
            # directly instead of being dumped and re-parsed once per nesting level.
            if isinstance(
                o,
                (
                    eminus.Atoms,
                    eminus.SCF,
                    eminus.energies.Energy,
                    eminus.gth.GTH,
                    eminus.kpoints.KPoints,
                    eminus.occupations.Occupations,
                ),
            ):
                return o.__dict__
            # The logger class is not serializable, just ignore it
            if isinstance(o, eminus.logger.CustomLogger):
                return None
            return super().default(o)

    # Normalize path-like objects so the extension check below works for them as well
    filename = os.fspath(filename)
    if not filename.endswith(".json"):
        filename += ".json"

    with open(filename, "w", encoding="utf-8") as fp:
        json.dump(obj, fp, cls=_CustomEncoder)