Coverage for eminus/io/json.py: 98.31%

59 statements  

« prev     ^ index     » next       coverage.py v7.6.4, created at 2024-11-08 12:59 +0000

1# SPDX-FileCopyrightText: 2021 The eminus developers 

2# SPDX-License-Identifier: Apache-2.0 

3"""JSON file handling.""" 

4 

5import base64 

6import copy 

7import json 

8 

9import numpy as np 

10 

11 

def _custom_object_hook(dct):  # noqa: PLR0911
    """Custom JSON object hook to create eminus classes after deserialization.

    The hook is passed as ``object_hook`` to :func:`json.load`, which calls it for every decoded
    JSON object from the innermost outwards, so nested arrays and objects are already converted
    when an enclosing object is handled.

    Args:
        dct: Decoded JSON object.

    Returns:
        A writable NumPy array, an eminus object, or the unchanged input if no marker key matches.
    """
    if not isinstance(dct, dict):
        return dct

    # ndarrays are base64 encoded, decode and recreate
    if "__ndarray__" in dct:
        data = base64.b64decode(dct["__ndarray__"])
        # np.frombuffer returns a read-only view into the immutable bytes object
        # Copy it so the decoded array is writable, like the array that was serialized
        return np.frombuffer(data, dct["dtype"]).reshape(dct["shape"]).copy()

    # Plain dictionaries are returned as-is without importing eminus
    markers = ("_atom", "_atoms", "Ekin", "NbetaNL", "_Nelec", "_kmesh")
    if not any(marker in dct for marker in markers):
        return dct

    import eminus

    def set_attrs(obj, attrs):
        """Set attributes of an object using a given dictionary, skipping the logger entry."""
        for attr, value in attrs.items():
            # Keep the logger that was created when the object was instantiated
            if attr == "_log":
                continue
            setattr(obj, attr, copy.deepcopy(value))
        return obj

    # Create simple eminus objects and set all attributes afterwards
    # Explicitly call objects with verbosity since the logger is created at instantiation
    # The checks are ordered, the first matching marker key decides the object type

    # Atoms objects
    if "_atom" in dct:
        atoms = eminus.Atoms(dct["_atom"], dct["_pos"], verbose=dct["_verbose"])
        atoms = set_attrs(atoms, dct)
        # The tuple type is not preserved when serializing, manually cast the only important one
        if isinstance(atoms._active, list):
            atoms._active = [tuple(i) for i in atoms._active]
        return atoms
    # SCF objects
    if "_atoms" in dct:
        scf = eminus.SCF(dct["_atoms"], verbose=dct["_verbose"])
        return set_attrs(scf, dct)
    # Energy objects
    if "Ekin" in dct:
        energies = eminus.energies.Energy()
        return set_attrs(energies, dct)
    # GTH objects
    if "NbetaNL" in dct:
        gth = eminus.gth.GTH()
        return set_attrs(gth, dct)
    # Occupations objects
    if "_Nelec" in dct:
        occ = eminus.occupations.Occupations()
        return set_attrs(occ, dct)
    # KPoints objects
    if "_kmesh" in dct:
        kpts = eminus.kpoints.KPoints(dct["lattice"])
        return set_attrs(kpts, dct)
    return dct

61 

62 

def read_json(filename):
    """Load objects from a JSON file.

    The ".json" extension is appended if the file name does not already end with it.

    Args:
        filename: JSON input file path/name, as a string or path-like object.

    Returns:
        Class object.
    """
    import os

    # Accept pathlib.Path and other path-like objects, not only strings
    filename = os.fspath(filename)
    if not filename.endswith(".json"):
        filename += ".json"

    with open(filename, encoding="utf-8") as fh:
        return json.load(fh, object_hook=_custom_object_hook)

77 

78 

class _CustomEncoder(json.JSONEncoder):
    """Custom JSON encoder class to serialize eminus classes."""

    def default(self, obj):
        """Overwrite the default function to handle eminus objects.

        The encoder calls this for every object it cannot serialize natively and then serializes
        the returned value itself, calling this function again for any nested unsupported values.

        Args:
            obj: Object to serialize.

        Returns:
            A JSON-serializable replacement for obj.

        Raises:
            TypeError: If obj is of an unsupported type.
        """
        # ndarrays are not JSON serializable, encode them as base64 to save them
        # base64 needs a C-contiguous buffer, np.ascontiguousarray only copies when necessary
        if isinstance(obj, np.ndarray):
            data = base64.b64encode(np.ascontiguousarray(obj)).decode("utf-8")
            return {"__ndarray__": data, "dtype": str(obj.dtype), "shape": obj.shape}

        import eminus

        # If obj is an eminus class dump its attributes as a dictionary
        # The encoder serializes the returned dict recursively, so an intermediate dumps/loads
        # round trip (which re-serialized nested objects once per nesting level) is not needed
        if isinstance(
            obj,
            (
                eminus.Atoms,
                eminus.SCF,
                eminus.energies.Energy,
                eminus.gth.GTH,
                eminus.kpoints.KPoints,
                eminus.occupations.Occupations,
            ),
        ):
            return dict(obj.__dict__)
        # The logger class is not serializable, just ignore it
        if isinstance(obj, eminus.logger.CustomLogger):
            return None
        return json.JSONEncoder.default(self, obj)

110 

111 

def write_json(obj, filename):
    """Save objects in a JSON file.

    The ".json" extension is appended if the file name does not already end with it.

    Args:
        obj: Class object.
        filename: JSON output file path/name, as a string or path-like object.
    """
    import os

    # Accept pathlib.Path and other path-like objects, not only strings
    filename = os.fspath(filename)
    if not filename.endswith(".json"):
        filename += ".json"

    with open(filename, "w", encoding="utf-8") as fp:
        json.dump(obj, fp, cls=_CustomEncoder)