File size: 14,357 Bytes
82a7a28
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
import json
import copy

from tinytroupe.utils import logger

class JsonSerializableRegistry:
    """
    A mixin class that provides JSON serialization, deserialization, and subclass registration.

    Subclasses can tune (de)serialization with these optional class attributes:

    - ``serializable_attributes`` (list): names of the attributes to (de)serialize. When no class in
      the hierarchy declares any, every instance attribute (``self.__dict__``) is serialized and every
      key of the document other than the type marker is loaded.
    - ``suppress_attributes_from_serialization`` (list): names that are never written or loaded.
    - ``custom_serialization_initializers`` (dict): attribute name -> callable that builds the
      attribute from its JSON value on load. A subclass's entry overrides its parents'.
    - ``serializable_attributes_renaming`` (dict): attribute name -> key used in the JSON document.

    Every subclass is registered under its class name in ``class_mapping`` so that ``from_json``
    can re-create the concrete class named in the document.
    """

    # Class name -> subclass; filled in by __init_subclass__ and read through every subclass.
    class_mapping = {}

    def to_json(self, include: list = None, suppress: list = None, file_path: str = None,
                serialization_type_field_name = "json_serializable_class_name") -> dict:
        """
        Returns a JSON representation of the object.

        Args:
            include (list, optional): Attributes to include in the serialization. Replaces the
                class-declared ``serializable_attributes``.
            suppress (list, optional): Attributes to suppress from the serialization. Added to the
                class-declared suppressions.
            file_path (str, optional): Path to a file where the JSON will be written (UTF-8, indented).
                Missing parent directories are created.
            serialization_type_field_name (str, optional): Key under which the concrete class name
                is stored.

        Returns:
            dict: The JSON-compatible representation. Nested JsonSerializableRegistry objects (direct
            values, list items, or dict values) are serialized recursively; other values are deep-copied.
        """
        # Gather the serializable attributes from the class hierarchy. A dict keeps first-seen order,
        # so the emitted key order is stable across runs (a set's order depends on the hash seed).
        serializable_attrs = {}
        suppress_attrs = set()
        for klass in self.__class__.__mro__:
            declared = getattr(klass, 'serializable_attributes', None)
            if isinstance(declared, list):
                serializable_attrs.update(dict.fromkeys(declared))
            suppressed = getattr(klass, 'suppress_attributes_from_serialization', None)
            if isinstance(suppressed, list):
                suppress_attrs.update(suppressed)

        # Method parameters override (include) or extend (suppress) the class configuration.
        if include:
            serializable_attrs = dict.fromkeys(include)
        if suppress:
            suppress_attrs.update(suppress)

        def _serialize(value):
            # Registry objects serialize themselves with the same type-field name; anything else is
            # deep-copied so the result never aliases the object's state.
            if isinstance(value, JsonSerializableRegistry):
                return value.to_json(serialization_type_field_name=serialization_type_field_name)
            return copy.deepcopy(value)

        result = {serialization_type_field_name: self.__class__.__name__}
        attrs_to_emit = serializable_attrs if serializable_attrs else self.__dict__
        for attr in attrs_to_emit:
            if attr in suppress_attrs:
                continue
            value = getattr(self, attr, None)
            json_name = self._programmatic_name_to_json_name(attr)
            if isinstance(value, JsonSerializableRegistry):
                result[json_name] = _serialize(value)
            elif isinstance(value, list):
                result[json_name] = [_serialize(item) for item in value]
            elif isinstance(value, dict):
                result[json_name] = {k: _serialize(v) for k, v in value.items()}
            else:
                result[json_name] = copy.deepcopy(value)

        if file_path:
            import os
            directory = os.path.dirname(file_path)
            # dirname() is '' for a bare file name, and os.makedirs('') raises FileNotFoundError.
            if directory:
                os.makedirs(directory, exist_ok=True)
            with open(file_path, 'w', encoding='utf-8') as f:
                json.dump(result, f, indent=4)

        return result

    @classmethod
    def from_json(cls, json_dict_or_path, suppress: list = None, 
                  serialization_type_field_name = "json_serializable_class_name", 
                  post_init_params: dict = None):
        """
        Loads a JSON representation of the object and creates an instance of the class.

        The concrete class is looked up in ``class_mapping`` by the value stored under
        ``serialization_type_field_name``; if it is missing or unknown, ``cls`` is used. The instance
        is created with ``__new__`` (``__init__`` is NOT called), populated, and then
        ``_post_deserialization_init(**post_init_params)`` is invoked.

        Args:
            json_dict_or_path (dict or str): The JSON dictionary representing the object or a file path
                to load the JSON from (read as UTF-8).
            suppress (list, optional): Attributes to suppress from being loaded.
            serialization_type_field_name (str, optional): Key holding the concrete class name.
            post_init_params (dict, optional): Keyword arguments for ``_post_deserialization_init``.

        Returns:
            An instance of the class populated with the data from json_dict_or_path.
        """
        if isinstance(json_dict_or_path, str):
            with open(json_dict_or_path, 'r', encoding='utf-8') as f:
                json_dict = json.load(f)
        else:
            json_dict = json_dict_or_path

        subclass_name = json_dict.get(serialization_type_field_name)
        target_class = cls.class_mapping.get(subclass_name, cls)
        instance = target_class.__new__(target_class)  # Create an instance without calling __init__

        # Gather the configuration from the class hierarchy. The MRO is walked from the most generic
        # class to the most derived one so that a subclass's custom initializer overrides its parents'.
        serializable_attrs = set()
        custom_serialization_initializers = {}
        suppress_attrs = set(suppress) if suppress else set()
        for klass in reversed(target_class.__mro__):
            declared = getattr(klass, 'serializable_attributes', None)
            if isinstance(declared, list):
                serializable_attrs.update(declared)
            initializers = getattr(klass, 'custom_serialization_initializers', None)
            if isinstance(initializers, dict):
                custom_serialization_initializers.update(initializers)
            suppressed = getattr(klass, 'suppress_attributes_from_serialization', None)
            if isinstance(suppressed, list):
                suppress_attrs.update(suppressed)

        # (attribute name, JSON key) pairs to load. Renaming is resolved on target_class, the class
        # whose to_json produced the document.
        if serializable_attrs:
            candidates = [(name, target_class._programmatic_name_to_json_name(name))
                          for name in serializable_attrs]
        else:
            # No declaration: load every key except the type marker, undoing any renaming.
            candidates = [(target_class._json_name_to_programmatic_name(json_name), json_name)
                          for json_name in json_dict
                          if json_name != serialization_type_field_name]

        def _revive(item):
            # Rebuild serialized registry objects; copy plain data so the instance does not alias json_dict.
            if isinstance(item, dict) and serialization_type_field_name in item:
                return JsonSerializableRegistry.from_json(item, serialization_type_field_name=serialization_type_field_name)
            return copy.deepcopy(item)

        for name, json_name in candidates:
            if json_name not in json_dict or name in suppress_attrs:
                continue
            value = json_dict[json_name]
            if name in custom_serialization_initializers:
                # Use custom initializer if provided
                setattr(instance, name, custom_serialization_initializers[name](value))
            elif isinstance(value, list):
                setattr(instance, name, [_revive(item) for item in value])
            elif isinstance(value, dict) and serialization_type_field_name not in value:
                # Plain dict: to_json serializes registry objects found among dict values, so revive them.
                setattr(instance, name, {k: _revive(v) for k, v in value.items()})
            else:
                # A serialized registry object, or a scalar value.
                setattr(instance, name, _revive(value))

        # Call post-deserialization initialization if available
        if hasattr(instance, '_post_deserialization_init') and callable(instance._post_deserialization_init):
            instance._post_deserialization_init(**(post_init_params or {}))

        return instance

    def __init_subclass__(cls, **kwargs):
        super().__init_subclass__(**kwargs)
        # Register the subclass using its name as the key (a later class with the same name replaces it).
        JsonSerializableRegistry.class_mapping[cls.__name__] = cls

        # Extend the list-valued settings with those of the direct bases. dict.fromkeys() de-duplicates
        # while keeping a deterministic order (base entries first), unlike list(set(...)).
        for list_attr in ('serializable_attributes', 'suppress_attributes_from_serialization'):
            if not isinstance(getattr(cls, list_attr, None), list):
                continue
            for base in cls.__bases__:
                inherited = getattr(base, list_attr, None)
                if isinstance(inherited, list):
                    setattr(cls, list_attr, list(dict.fromkeys(inherited + getattr(cls, list_attr))))

        # Merge custom initializers: the subclass's entries override its bases'.
        if isinstance(getattr(cls, 'custom_serialization_initializers', None), dict):
            for base in cls.__bases__:
                base_initializers = getattr(base, 'custom_serialization_initializers', None)
                if isinstance(base_initializers, dict):
                    merged = dict(base_initializers)
                    merged.update(cls.custom_serialization_initializers)
                    cls.custom_serialization_initializers = merged

    def _post_deserialization_init(self, **kwargs):
        # If there's a _post_init method, call it after deserialization, forwarding the kwargs.
        if hasattr(self, '_post_init'):
            self._post_init(**kwargs)

    @classmethod
    def _programmatic_name_to_json_name(cls, name):
        """
        Maps an attribute name to its JSON key using the class's ``serializable_attributes_renaming``
        dict, if any. Names without an entry (or classes without the dict) map to themselves.
        """
        if hasattr(cls, 'serializable_attributes_renaming') and isinstance(cls.serializable_attributes_renaming, dict):
            return cls.serializable_attributes_renaming.get(name, name)
        return name
    
    @classmethod
    def _json_name_to_programmatic_name(cls, name):
        """
        Maps a JSON key back to its attribute name by inverting ``serializable_attributes_renaming``.

        Raises:
            ValueError: If two attributes are renamed to the same JSON key, so the mapping cannot be inverted.
        """
        if hasattr(cls, 'serializable_attributes_renaming') and isinstance(cls.serializable_attributes_renaming, dict):
            reverse_rename = {}
            for k, v in cls.serializable_attributes_renaming.items():
                if v in reverse_rename:
                    raise ValueError(f"Duplicate value '{v}' found in serializable_attributes_renaming.")
                reverse_rename[v] = k
            return reverse_rename.get(name, name)
        return name

def post_init(cls):
    """
    Decorator to enforce a post-initialization method call in a class, if it has one.
    The method must be named `_post_init`.

    The class's ``__init__`` is replaced by a wrapper that runs the original ``__init__`` and then,
    if the decorated class has a ``_post_init`` attribute, calls ``cls._post_init(self)`` with no
    extra arguments. The wrapper copies the original ``__init__``'s metadata (``__name__``,
    ``__doc__``, ``__wrapped__``, ...), so introspection such as ``inspect.signature(cls)`` still
    reports the real constructor parameters.

    Args:
        cls (type): The class to decorate; it is modified in place.

    Returns:
        type: The same class object.
    """
    import functools

    original_init = cls.__init__

    @functools.wraps(original_init)
    def new_init(self, *args, **kwargs):
        original_init(self, *args, **kwargs)
        # _post_init is looked up on the decorated class, not on type(self): an override in an
        # undecorated subclass is not what runs here. NOTE(review): confirm this is intended.
        if hasattr(cls, '_post_init'):
            cls._post_init(self)

    cls.__init__ = new_init
    return cls

def merge_dicts(current, additions, overwrite=False, error_on_conflict=True):
    """
    Merges two dictionaries and returns a new dictionary. Works as follows:
    - If a key exists in the additions dictionary but not in the current dictionary, it is added.
    - If a key maps to None in the current dictionary, it is replaced by the value in the additions dictionary.
    - If a key exists in both dictionaries and the values are dictionaries, the function is called recursively.
    - If a key exists in both dictionaries and the values are lists, the lists are concatenated and duplicates
      are removed (first occurrence kept, order preserved).
    - If the values are of different types, a TypeError is raised. This includes a None in `additions`
      for a key whose current value is not None.
    - If the values are of the same type but not both lists/dictionaries, the value from the additions
      dictionary overwrites the value in the current dictionary when `overwrite` is True. Otherwise,
      differing values raise a ValueError if `error_on_conflict` is True, and the current value is kept
      if it is False.

    Neither input dictionary is modified. The merge is otherwise shallow: values taken over as-is are
    shared with the inputs.

    Parameters:
    - current (dict): The original dictionary.
    - additions (dict): The dictionary with values to add.
    - overwrite (bool): Whether to overwrite values if they are of the same type but not both lists/dictionaries.
    - error_on_conflict (bool): Whether to raise an error if there is a conflict and overwrite is False.

    Returns:
    - dict: A new dictionary with merged values.

    Raises:
    - TypeError: If the values for a key have different types.
    - ValueError: On a conflicting value when overwrite is False and error_on_conflict is True.
    """
    # Shallow copy: nested containers from `current` are shared, so below they are only ever
    # replaced, never mutated in place.
    merged = current.copy()

    for key, new_value in additions.items():
        if key not in merged:
            # If the key is not present in merged, add it from additions
            merged[key] = new_value
            continue

        old_value = merged[key]
        if old_value is None:
            merged[key] = new_value
        elif isinstance(old_value, dict) and isinstance(new_value, dict):
            merged[key] = merge_dicts(old_value, new_value, overwrite, error_on_conflict)
        elif isinstance(old_value, list) and isinstance(new_value, list):
            # Build a new list: extending old_value in place would also change the caller's `current`.
            merged[key] = remove_duplicates(old_value + new_value)
        elif type(old_value) is not type(new_value):
            raise TypeError(f"Cannot merge different types: {type(old_value)} and {type(new_value)} for key '{key}'")
        elif overwrite:
            merged[key] = new_value
        elif old_value != new_value and error_on_conflict:
            raise ValueError(f"Conflict at key '{key}': overwrite is set to False and values are different.")
        # Otherwise the values are equal, or the conflict is tolerated: keep the current value.

    return merged

def remove_duplicates(lst):
    """
    Removes duplicates from a list while preserving order.

    Each item is compared with ``==`` against the items already kept, so no hashing is needed and
    unhashable items are supported. This covers lists and sets, and also dicts whose values are
    themselves unhashable (e.g. nested lists or dicts). The scan is O(n^2) in the list length.

    Parameters:
    - lst (list): The list to remove duplicates from. It is not modified.

    Returns:
    - list: A new list holding the first occurrence of each distinct item, in original order.
    """
    result = []
    for item in lst:
        # `in` on a list tests equality. Converting dicts to frozensets (as done previously) was
        # unnecessary and raised TypeError when a dict held unhashable values.
        if item not in result:
            result.append(item)
    return result