Coverage for eminus/io/json.py: 98.31%

59 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2024-12-18 08:43 +0000

1# SPDX-FileCopyrightText: 2021 The eminus developers 

2# SPDX-License-Identifier: Apache-2.0 

3"""JSON file handling.""" 

4 

import base64
import copy
import json
import os

import numpy as np

10 

11 

def _custom_object_hook(dct):  # noqa: PLR0911
    """Custom JSON object hook to create eminus classes after deserialization.

    Args:
        dct: Dictionary produced by the JSON decoder.

    Returns:
        A NumPy array if dct encodes one, a recreated eminus object if dct holds the attributes of
        a supported eminus class, or dct unchanged otherwise.
    """
    import eminus

    def set_attrs(obj, dct):
        """Set attributes of an object using a given dictionary, skipping the "_log" entry."""
        for attr in dct:
            # Keep the logger that was created when the object was instantiated
            if attr == "_log":
                continue
            setattr(obj, attr, copy.deepcopy(dct[attr]))
        return obj

    if not isinstance(dct, dict):
        return dct

    # ndarrays are base64 encoded, decode and recreate
    if "__ndarray__" in dct:
        # Decode into a mutable bytearray: arrays created on top of immutable bytes are read-only
        data = bytearray(base64.b64decode(dct["__ndarray__"]))
        return np.frombuffer(data, dct["dtype"]).reshape(dct["shape"])

    # Create simple eminus objects and set all attributes afterwards
    # Explicitly call objects with verbosity since the logger is created at instantiation

    # Atoms objects
    if "_atom" in dct:
        atoms = eminus.Atoms(dct["_atom"], dct["_pos"], verbose=dct["_verbose"])
        atoms = set_attrs(atoms, dct)
        # The tuple type is not preserved when serializing, manually cast the only important one
        if not isinstance(atoms._active, tuple):
            atoms._active = [tuple(i) for i in atoms._active]
        return atoms
    # SCF objects
    if "_atoms" in dct:
        scf = eminus.SCF(dct["_atoms"], verbose=dct["_verbose"])
        return set_attrs(scf, dct)
    # Energy objects
    if "Ekin" in dct:
        return set_attrs(eminus.energies.Energy(), dct)
    # GTH objects
    if "NbetaNL" in dct:
        return set_attrs(eminus.gth.GTH(), dct)
    # Occupations objects
    if "_Nelec" in dct:
        return set_attrs(eminus.occupations.Occupations(), dct)
    # KPoints objects
    if "_kmesh" in dct:
        return set_attrs(eminus.kpoints.KPoints(dct["lattice"]), dct)
    return dct

61 

62 

def read_json(filename):
    """Load objects from a JSON file.

    Args:
        filename: JSON input file path/name (str or path-like). The ".json" extension is appended
            if it is missing.

    Returns:
        Class object, or the plain decoded JSON value if it does not describe an eminus object.
    """
    # Accept pathlib.Path and other path-like objects, which have no endswith method
    filename = os.fspath(filename)
    if not filename.endswith(".json"):
        filename += ".json"

    with open(filename, encoding="utf-8") as fh:
        return json.load(fh, object_hook=_custom_object_hook)

77 

78 

def write_json(obj, filename):
    """Save objects in a JSON file.

    NumPy arrays are stored base64 encoded together with their dtype (including byte order) and
    shape, NumPy scalars are stored as the equivalent Python scalars, and supported eminus objects
    are stored as their attribute dictionaries. Logger objects are stored as null.

    Args:
        obj: Class object.
        filename: JSON output file path/name (str or path-like). The ".json" extension is appended
            if it is missing.
    """

    class _CustomEncoder(json.JSONEncoder):
        """Custom JSON encoder class to serialize eminus classes."""

        def default(self, o):
            """Overwrite the default function to handle NumPy and eminus objects."""
            # ndarrays are not JSON serializable, encode them as base64 to save them
            if isinstance(o, np.ndarray):
                data = base64.b64encode(o.copy(order="C")).decode("utf-8")
                # dtype.str always contains the byte order (e.g. "<f8"), unlike str(dtype)
                return {"__ndarray__": data, "dtype": o.dtype.str, "shape": o.shape}
            # NumPy scalars like np.int64 or np.bool_ are not JSON serializable
            if isinstance(o, np.generic):
                return o.item()

            # Deferred import, only needed once a value is not a NumPy type
            import eminus

            # If o is an eminus class dump it as a dictionary of its attributes
            # The encoder serializes the returned dict recursively with this same default method,
            # so no intermediate dump/load round trip is needed
            if isinstance(
                o,
                (
                    eminus.Atoms,
                    eminus.SCF,
                    eminus.energies.Energy,
                    eminus.gth.GTH,
                    eminus.kpoints.KPoints,
                    eminus.occupations.Occupations,
                ),
            ):
                return vars(o)
            # The logger class is not serializable, just ignore it
            if isinstance(o, eminus.logger.CustomLogger):
                return None
            return super().default(o)

    # Accept pathlib.Path and other path-like objects, which have no endswith method
    filename = os.fspath(filename)
    if not filename.endswith(".json"):
        filename += ".json"

    with open(filename, "w", encoding="utf-8") as fp:
        json.dump(obj, fp, cls=_CustomEncoder)