Spaces:
Runtime error
Runtime error
File size: 14,357 Bytes
82a7a28 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 |
import json
import copy
from tinytroupe.utils import logger
class JsonSerializableRegistry:
    """
    A mixin class that provides JSON serialization, deserialization, and subclass registration.

    Subclasses may declare the following optional class attributes, which the methods
    below look up across the class hierarchy:
        - serializable_attributes (list): names of the attributes to serialize/load.
          When no class declares any, every attribute is used.
        - suppress_attributes_from_serialization (list): names of attributes to skip.
        - custom_serialization_initializers (dict): maps an attribute name to a callable
          that receives the raw JSON value and returns the value to assign when loading.
        - serializable_attributes_renaming (dict): maps a programmatic attribute name to
          the key used for it in the JSON output.
    """

    # Maps a subclass name to the subclass itself; populated by __init_subclass__ and
    # used by from_json to pick the class to instantiate.
    class_mapping = {}

    def to_json(self, include: list = None, suppress: list = None, file_path: str = None,
                serialization_type_field_name = "json_serializable_class_name") -> dict:
        """
        Returns a JSON representation of the object.

        Args:
            include (list, optional): Attributes to include in the serialization. Will override the default behavior.
            suppress (list, optional): Attributes to suppress from the serialization, in addition to the ones
                declared by the class hierarchy.
            file_path (str, optional): Path to a file where the JSON will be written.
            serialization_type_field_name (str, optional): Key under which the class name is stored in the output.

        Returns:
            dict: The serialized object. Nested JsonSerializableRegistry values (directly, inside lists
            or as dict values) are serialized recursively; everything else is deep-copied.
        """
        # Gather all serializable attributes from the class hierarchy
        serializable_attrs = set()
        suppress_attrs = set()
        for cls in self.__class__.__mro__:  # Traverse the class hierarchy
            if hasattr(cls, 'serializable_attributes') and isinstance(cls.serializable_attributes, list):
                serializable_attrs.update(cls.serializable_attributes)
            if hasattr(cls, 'suppress_attributes_from_serialization') and isinstance(cls.suppress_attributes_from_serialization, list):
                suppress_attrs.update(cls.suppress_attributes_from_serialization)

        # Override attributes with method parameters if provided
        if include:
            serializable_attrs = set(include)
        if suppress:
            suppress_attrs.update(suppress)

        result = {serialization_type_field_name: self.__class__.__name__}
        # Without any declared attribute, fall back to everything stored on the instance
        for attr in serializable_attrs if serializable_attrs else self.__dict__:
            if attr in suppress_attrs:
                continue

            value = getattr(self, attr, None)
            attr_renamed = self._programmatic_name_to_json_name(attr)
            if isinstance(value, JsonSerializableRegistry):
                result[attr_renamed] = value.to_json(serialization_type_field_name=serialization_type_field_name)
            elif isinstance(value, list):
                result[attr_renamed] = [
                    item.to_json(serialization_type_field_name=serialization_type_field_name)
                    if isinstance(item, JsonSerializableRegistry) else copy.deepcopy(item)
                    for item in value
                ]
            elif isinstance(value, dict):
                result[attr_renamed] = {
                    k: v.to_json(serialization_type_field_name=serialization_type_field_name)
                    if isinstance(v, JsonSerializableRegistry) else copy.deepcopy(v)
                    for k, v in value.items()
                }
            else:
                result[attr_renamed] = copy.deepcopy(value)

        if file_path:
            import os

            # Create directories if they do not exist. A bare file name has an empty
            # directory part, and os.makedirs("") would raise FileNotFoundError.
            directory = os.path.dirname(file_path)
            if directory:
                os.makedirs(directory, exist_ok=True)
            with open(file_path, 'w', encoding='utf-8') as f:
                json.dump(result, f, indent=4)

        return result

    @classmethod
    def from_json(cls, json_dict_or_path, suppress: list = None,
                  serialization_type_field_name = "json_serializable_class_name",
                  post_init_params: dict = None):
        """
        Loads a JSON representation of the object and creates an instance of the class.

        The class to instantiate is looked up in `class_mapping` using the class name stored
        under `serialization_type_field_name`; when it is not found, `cls` itself is used.
        The instance is created without calling `__init__`.

        Args:
            json_dict_or_path (dict or str): The JSON dictionary representing the object or a file path to load the JSON from.
            suppress (list, optional): Attributes to suppress from being loaded.
            serialization_type_field_name (str, optional): Key under which the class name is stored.
            post_init_params (dict, optional): Keyword arguments forwarded to `_post_deserialization_init`.

        Returns:
            An instance of the class populated with the data from json_dict_or_path.
        """
        if isinstance(json_dict_or_path, str):
            with open(json_dict_or_path, 'r', encoding='utf-8') as f:
                json_dict = json.load(f)
        else:
            json_dict = json_dict_or_path

        subclass_name = json_dict.get(serialization_type_field_name)
        target_class = cls.class_mapping.get(subclass_name, cls)
        instance = target_class.__new__(target_class)  # Create an instance without calling __init__

        # Gather all serializable attributes from the class hierarchy
        serializable_attrs = set()
        custom_serialization_initializers = {}
        suppress_attrs = set(suppress) if suppress else set()
        for target_mro in target_class.__mro__:
            if hasattr(target_mro, 'serializable_attributes') and isinstance(target_mro.serializable_attributes, list):
                serializable_attrs.update(target_mro.serializable_attributes)
            if hasattr(target_mro, 'custom_serialization_initializers') and isinstance(target_mro.custom_serialization_initializers, dict):
                custom_serialization_initializers.update(target_mro.custom_serialization_initializers)
            if hasattr(target_mro, 'suppress_attributes_from_serialization') and isinstance(target_mro.suppress_attributes_from_serialization, list):
                suppress_attrs.update(target_mro.suppress_attributes_from_serialization)

        # Build (programmatic name, JSON key) pairs. The renaming must come from the class
        # being instantiated (target_class), which is the class that produced the JSON keys
        # in to_json; `cls` may just be the generic entry point used for nested objects.
        if serializable_attrs:
            name_pairs = [(attr, target_class._programmatic_name_to_json_name(attr))
                          for attr in serializable_attrs]
        else:
            # Nothing declared: assign everything found in the JSON, translating each
            # JSON key back to its programmatic attribute name.
            name_pairs = [(target_class._json_name_to_programmatic_name(json_key), json_key)
                          for json_key in json_dict]

        # NOTE: to_json also serializes registry objects stored as dict values, but such
        # dicts are loaded back as plain (deep-copied) dicts below.
        for key, key_in_json in name_pairs:
            if key_in_json not in json_dict or key in suppress_attrs:
                continue

            value = json_dict[key_in_json]
            if key in custom_serialization_initializers:
                # Use custom initializer if provided
                setattr(instance, key, custom_serialization_initializers[key](value))
            elif isinstance(value, dict) and serialization_type_field_name in value:
                # Assume it's another JsonSerializableRegistry object
                setattr(instance, key, JsonSerializableRegistry.from_json(value, serialization_type_field_name=serialization_type_field_name))
            elif isinstance(value, list):
                # Handle collections, recursively deserialize if items are JsonSerializableRegistry objects
                deserialized_collection = []
                for item in value:
                    if isinstance(item, dict) and serialization_type_field_name in item:
                        deserialized_collection.append(JsonSerializableRegistry.from_json(item, serialization_type_field_name=serialization_type_field_name))
                    else:
                        deserialized_collection.append(copy.deepcopy(item))
                setattr(instance, key, deserialized_collection)
            else:
                setattr(instance, key, copy.deepcopy(value))

        # Call post-deserialization initialization if available
        if hasattr(instance, '_post_deserialization_init') and callable(instance._post_deserialization_init):
            post_init_params = post_init_params if post_init_params else {}
            instance._post_deserialization_init(**post_init_params)

        return instance

    def __init_subclass__(cls, **kwargs):
        """
        Registers every subclass in `class_mapping` and merges the serialization
        settings declared by its direct base classes into the subclass.
        """
        super().__init_subclass__(**kwargs)

        # Register the subclass using its name as the key
        JsonSerializableRegistry.class_mapping[cls.__name__] = cls

        # Automatically extend serializable attributes and custom initializers from parent classes
        if hasattr(cls, 'serializable_attributes') and isinstance(cls.serializable_attributes, list):
            for base in cls.__bases__:
                if hasattr(base, 'serializable_attributes') and isinstance(base.serializable_attributes, list):
                    cls.serializable_attributes = list(set(base.serializable_attributes + cls.serializable_attributes))

        if hasattr(cls, 'suppress_attributes_from_serialization') and isinstance(cls.suppress_attributes_from_serialization, list):
            for base in cls.__bases__:
                if hasattr(base, 'suppress_attributes_from_serialization') and isinstance(base.suppress_attributes_from_serialization, list):
                    cls.suppress_attributes_from_serialization = list(set(base.suppress_attributes_from_serialization + cls.suppress_attributes_from_serialization))

        if hasattr(cls, 'custom_serialization_initializers') and isinstance(cls.custom_serialization_initializers, dict):
            for base in cls.__bases__:
                if hasattr(base, 'custom_serialization_initializers') and isinstance(base.custom_serialization_initializers, dict):
                    # Entries declared on the subclass take precedence over the base's
                    base_initializers = base.custom_serialization_initializers.copy()
                    base_initializers.update(cls.custom_serialization_initializers)
                    cls.custom_serialization_initializers = base_initializers

    def _post_deserialization_init(self, **kwargs):
        """
        Hook called by from_json once the attributes have been assigned.
        """
        # if there's a _post_init method, call it after deserialization
        if hasattr(self, '_post_init'):
            self._post_init(**kwargs)

    @classmethod
    def _programmatic_name_to_json_name(cls, name):
        """
        Converts a programmatic name to a JSON name using the class's
        `serializable_attributes_renaming` mapping. Names without an entry
        (or when the class declares no mapping) are returned unchanged.
        """
        if hasattr(cls, 'serializable_attributes_renaming') and isinstance(cls.serializable_attributes_renaming, dict):
            return cls.serializable_attributes_renaming.get(name, name)
        return name

    @classmethod
    def _json_name_to_programmatic_name(cls, name):
        """
        Converts a JSON name to a programmatic name by reversing the class's
        `serializable_attributes_renaming` mapping. Names without an entry
        (or when the class declares no mapping) are returned unchanged.

        Raises:
            ValueError: If two programmatic names are renamed to the same JSON name,
                since the mapping cannot be reversed unambiguously.
        """
        if hasattr(cls, 'serializable_attributes_renaming') and isinstance(cls.serializable_attributes_renaming, dict):
            reverse_rename = {}
            for k, v in cls.serializable_attributes_renaming.items():
                if v in reverse_rename:
                    raise ValueError(f"Duplicate value '{v}' found in serializable_attributes_renaming.")
                reverse_rename[v] = k
            return reverse_rename.get(name, name)
        return name
def post_init(cls):
    """
    Decorator to enforce a post-initialization method call in a class, if it has one.
    The method must be named `_post_init`.

    The class's `__init__` is replaced by a wrapper that first runs the original
    `__init__` and then calls `cls._post_init(self)` when the decorated class has
    such a method. The wrapper keeps the original `__init__`'s name, docstring and
    signature metadata.

    Args:
        cls (type): The class to decorate.

    Returns:
        type: The same class object, with its `__init__` wrapped.
    """
    import functools  # local import: keeps the decorator independent of the file's import block

    original_init = cls.__init__

    @functools.wraps(original_init)
    def new_init(self, *args, **kwargs):
        original_init(self, *args, **kwargs)
        if hasattr(cls, '_post_init'):
            cls._post_init(self)

    cls.__init__ = new_init
    return cls
def merge_dicts(current, additions, overwrite=False, error_on_conflict=True):
    """
    Merges two dictionaries and returns a new dictionary. Works as follows:
    - If a key exists in the additions dictionary but not in the current dictionary, it is added.
    - If a key maps to None in the current dictionary, it is replaced by the value in the additions dictionary.
    - If a key exists in both dictionaries and the values are dictionaries, the function is called recursively.
    - If a key exists in both dictionaries and the values are lists, the lists are concatenated and duplicates are removed.
    - If the values are of different types, an exception is raised.
    - If the values are of the same type but not both lists/dictionaries, the value from the additions dictionary overwrites the value in the current dictionary based on the overwrite parameter.

    Neither input dictionary is modified, nor are the lists stored in `current`.

    Parameters:
    - current (dict): The original dictionary.
    - additions (dict): The dictionary with values to add.
    - overwrite (bool): Whether to overwrite values if they are of the same type but not both lists/dictionaries.
    - error_on_conflict (bool): Whether to raise an error if there is a conflict and overwrite is False.

    Returns:
    - dict: A new dictionary with merged values.

    Raises:
    - TypeError: If a key holds values of different types in the two dictionaries.
    - ValueError: If a key holds different values of the same type, overwrite is False
      and error_on_conflict is True.
    """
    merged = current.copy()  # Shallow copy: the values are still shared with `current`

    for key, incoming in additions.items():
        if key not in merged:
            # If the key is not present in merged, add it from additions
            merged[key] = incoming
            continue

        existing = merged[key]

        # If the current value is None, directly assign the new value
        if existing is None:
            merged[key] = incoming

        # If both values are dictionaries, merge them recursively
        elif isinstance(existing, dict) and isinstance(incoming, dict):
            merged[key] = merge_dicts(existing, incoming, overwrite, error_on_conflict)

        # If both values are lists, concatenate them and remove duplicates
        elif isinstance(existing, list) and isinstance(incoming, list):
            # Build a new list instead of extending in place: `existing` is the very
            # list object held by `current`, which must not be altered.
            merged[key] = remove_duplicates(existing + incoming)

        # If the values are of different types, raise an exception
        elif type(existing) != type(incoming):
            raise TypeError(f"Cannot merge different types: {type(existing)} and {type(incoming)} for key '{key}'")

        # If the values are of the same type but not both lists/dictionaries, decide based on the overwrite parameter
        elif overwrite:
            merged[key] = incoming
        elif existing != incoming and error_on_conflict:
            raise ValueError(f"Conflict at key '{key}': overwrite is set to False and values are different.")
        # Otherwise the values are equal, or the conflict is ignored: keep the current value

    return merged
def remove_duplicates(lst):
    """
    Removes duplicates from a list while preserving order.

    Elements are compared by equality rather than by hash, so unhashable elements
    (dictionaries, lists, dictionaries holding lists, ...) are handled as well.
    The first occurrence of each element is the one that is kept.

    Parameters:
    - lst (list): The list to remove duplicates from.

    Returns:
    - list: A new list with duplicates removed.
    """
    result = []
    for item in lst:
        # A list membership test only needs ==, so no hashable stand-in is required
        # (building one fails for dictionaries whose values are unhashable).
        if item not in result:
            result.append(item)
    return result