Coverage for eminus/io/json.py: 98.36%

61 statements  

« prev     ^ index     » next       coverage.py v7.11.0, created at 2025-10-21 12:19 +0000

1# SPDX-FileCopyrightText: 2023 The eminus developers 

2# SPDX-License-Identifier: Apache-2.0 

3"""JSON file handling.""" 

4 

5import base64 

6import copy 

7import json 

8 

9import numpy as np 

10 

11from eminus import backend as xp 

12 

13 

def _custom_object_hook(dct):
    """Object hook that rebuilds arrays and eminus objects from decoded JSON dictionaries.

    Args:
        dct: Dictionary decoded from the JSON input.

    Returns:
        The recreated array or eminus object, or the unchanged input if no known key is found.
    """
    import eminus

    def set_attrs(target, values):
        """Assign a deep copy of every dictionary entry to the object, except for the logger."""
        for name, value in values.items():
            if name != "_log":
                setattr(target, name, copy.deepcopy(value))
        return target

    # Anything that is not a dictionary is handed back untouched
    if not isinstance(dct, dict):
        return dct

    # Arrays are stored base64 encoded together with their dtype and shape, decode and rebuild
    if "__ndarray__" in dct:
        raw = base64.b64decode(dct["__ndarray__"])
        decoded = np.frombuffer(raw, dct["dtype"]).reshape(dct["shape"])
        return xp.asarray(copy.deepcopy(decoded))

    # eminus objects are instantiated in a minimal form first, all attributes are restored after
    # The verbosity is passed explicitly since the logger is created at instantiation

    # Atoms objects
    if "_atom" in dct:
        atoms = set_attrs(eminus.Atoms(dct["_atom"], dct["_pos"], verbose=dct["_verbose"]), dct)
        # Serialization does not preserve the tuple type, cast the only important one by hand
        if not (atoms._active is None or isinstance(atoms._active, tuple)):
            atoms._active = [(xp.asarray(entry[0]),) for entry in atoms._active]
        return atoms
    # SCF objects
    if "_atoms" in dct:
        return set_attrs(eminus.SCF(dct["_atoms"], verbose=dct["_verbose"]), dct)
    # Energy objects
    if "Ekin" in dct:
        return set_attrs(eminus.energies.Energy(), dct)
    # GTH objects
    if "NbetaNL" in dct:
        return set_attrs(eminus.gth.GTH(), dct)
    # Occupations objects
    if "_Nelec" in dct:
        return set_attrs(eminus.occupations.Occupations(), dct)
    # KPoints objects
    if "_kmesh" in dct:
        return set_attrs(eminus.kpoints.KPoints(dct["lattice"]), dct)
    # No known key was found, keep the plain dictionary
    return dct

63 

64 

def read_json(filename):
    """Read a JSON file and recreate the objects stored in it.

    The ".json" extension is appended to the file name if it is missing.

    Args:
        filename: JSON input file path/name.

    Returns:
        Class object.
    """
    path = filename if filename.endswith(".json") else filename + ".json"

    with open(path, encoding="utf-8") as handle:
        # The object hook is called for every decoded JSON object
        return json.load(handle, object_hook=_custom_object_hook)

79 

80 

def write_json(obj, filename):
    """Serialize objects and write them to a JSON file.

    The ".json" extension is appended to the file name if it is missing.

    Args:
        obj: Class object.
        filename: JSON output file path/name.
    """
    import eminus

    class _CustomEncoder(json.JSONEncoder):
        """JSON encoder that knows how to serialize arrays and eminus classes."""

        def default(self, o):
            """Convert objects the standard encoder can not handle into serializable ones."""
            # Arrays are not JSON serializable, store them base64 encoded with dtype and shape
            if xp.is_array(o):
                array = xp.to_np(o)
                encoded = base64.b64encode(array.copy(order="C")).decode("utf-8")
                return {"__ndarray__": encoded, "dtype": str(array.dtype), "shape": array.shape}

            # eminus classes are dumped as a dictionary of their attributes
            eminus_classes = (
                eminus.Atoms,
                eminus.SCF,
                eminus.energies.Energy,
                eminus.gth.GTH,
                eminus.kpoints.KPoints,
                eminus.occupations.Occupations,
            )
            if isinstance(o, eminus_classes):
                # A single dump would only give a string, load it again to get a dictionary
                dumped = json.dumps(vars(o), cls=_CustomEncoder)
                return dict(json.loads(dumped))
            # The logger class is not serializable, it is stored as null instead
            if isinstance(o, eminus.logger.CustomLogger):
                return None
            # Everything else is left to the base class
            return super().default(o)

    path = filename if filename.endswith(".json") else filename + ".json"

    with open(path, "w", encoding="utf-8") as handle:
        json.dump(obj, handle, cls=_CustomEncoder)