lamhieu committed
Commit 8ad7a2a · 1 parent: b0cb394

chore: using redis sync instead of

Files changed (1)
  1. lightweight_embeddings/analytics.py +138 -60
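The change below swaps the `redis.asyncio` client for the synchronous `redis` client and offloads every blocking call to a worker thread via `loop.run_in_executor`. As a minimal sketch of that pattern, independent of this repository (`do_blocking_io` is a hypothetical stand-in for a synchronous Redis command):

import asyncio
import time
from functools import partial


def do_blocking_io(key: str) -> str:
    # Stands in for a blocking socket round trip (e.g., a sync Redis command).
    time.sleep(0.1)
    return f"value-for-{key}"


async def main():
    loop = asyncio.get_running_loop()
    # partial() binds the arguments up front, since run_in_executor
    # only forwards positional args to the callable it is given.
    result = await loop.run_in_executor(None, partial(do_blocking_io, key="example"))
    print(result)


asyncio.run(main())

Running the blocking call on the default thread-pool executor keeps the event loop responsive, which is the same trade the commit makes throughout `analytics.py`.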
diff --git a/lightweight_embeddings/analytics.py b/lightweight_embeddings/analytics.py
--- a/lightweight_embeddings/analytics.py
+++ b/lightweight_embeddings/analytics.py
@@ -1,56 +1,76 @@
 import logging
 import asyncio
-import redis.asyncio as redis
+import redis
 import redis.exceptions
 from datetime import datetime
 from collections import defaultdict
 from typing import Dict
+from functools import partial
 
 logger = logging.getLogger(__name__)
 
+
 class Analytics:
     def __init__(self, redis_url: str, sync_interval: int = 60, max_retries: int = 5):
         """
-        Initializes the Analytics class with an async Redis connection and sync interval.
+        Initializes the Analytics class with a synchronous Redis client,
+        wrapped in asynchronous methods by using run_in_executor.
 
         Parameters:
-        - redis_url: Redis connection URL (e.g., 'redis://localhost:6379/0')
-        - sync_interval: Interval in seconds for syncing with Redis.
-        - max_retries: Maximum number of retries for reconnecting to Redis.
+        - redis_url (str): Redis connection URL (e.g., 'redis://localhost:6379/0').
+        - sync_interval (int): Interval in seconds for syncing with Redis.
+        - max_retries (int): Maximum number of reconnection attempts to Redis.
         """
         self.redis_url = redis_url
         self.sync_interval = sync_interval
         self.max_retries = max_retries
+
+        # Synchronous Redis client
         self.redis_client = self._create_redis_client()
+
+        # Local buffer stores cumulative data for two-way sync
         self.local_buffer = {
-            "access": defaultdict(
-                lambda: defaultdict(int)
-            ),  # {period: {model_id: access_count}}
-            "tokens": defaultdict(
-                lambda: defaultdict(int)
-            ),  # {period: {model_id: tokens_count}}
+            "access": defaultdict(lambda: defaultdict(int)),
+            "tokens": defaultdict(lambda: defaultdict(int)),
         }
-        self.lock = asyncio.Lock()  # Async lock for thread-safe updates
-        asyncio.create_task(self._start_sync_task())
+
+        # Asynchronous lock to protect shared data
+        self.lock = asyncio.Lock()
+
+        # Initialize data from Redis, then start the periodic sync loop
+        asyncio.create_task(self._initialize())
 
         logger.info("Initialized Analytics with Redis connection: %s", redis_url)
 
     def _create_redis_client(self) -> redis.Redis:
         """
-        Creates and returns a new Redis client.
+        Creates and returns a new synchronous Redis client.
         """
         return redis.from_url(
             self.redis_url,
             decode_responses=True,
             health_check_interval=10,
             socket_connect_timeout=5,
-            retry_on_timeout=True,
             socket_keepalive=True,
         )
 
+    async def _initialize(self):
+        """
+        Fetches existing data from Redis into the local buffer,
+        then starts the periodic synchronization task.
+        """
+        try:
+            await self._sync_from_redis()
+            logger.info("Initial sync from Redis to local buffer completed.")
+        except Exception as e:
+            logger.error("Error during initial sync from Redis: %s", e)
+
+        # Launch the periodic sync task
+        asyncio.create_task(self._start_sync_task())
+
     def _get_period_keys(self) -> tuple:
         """
-        Returns keys for day, week, month, and year based on the current date.
+        Returns day, week, month, and year keys based on the current UTC date.
         """
         now = datetime.utcnow()
         day_key = now.strftime("%Y-%m-%d")
@@ -61,23 +81,23 @@ class Analytics:
 
     async def access(self, model_id: str, tokens: int):
         """
-        Records an access and token usage for a specific model_id.
+        Records an access event and token usage for a specific model.
 
         Parameters:
-        - model_id: The ID of the model being accessed.
-        - tokens: Number of tokens used in this access.
+        - model_id (str): The ID of the accessed model.
+        - tokens (int): Number of tokens used in this access event.
         """
         day_key, week_key, month_key, year_key = self._get_period_keys()
 
         async with self.lock:
-            # Increment access count
+            # Access counts
             self.local_buffer["access"][day_key][model_id] += 1
             self.local_buffer["access"][week_key][model_id] += 1
             self.local_buffer["access"][month_key][model_id] += 1
             self.local_buffer["access"][year_key][model_id] += 1
             self.local_buffer["access"]["total"][model_id] += 1
 
-            # Increment token count
+            # Token usage
             self.local_buffer["tokens"][day_key][model_id] += tokens
             self.local_buffer["tokens"][week_key][model_id] += tokens
             self.local_buffer["tokens"][month_key][model_id] += tokens
@@ -86,10 +106,7 @@ class Analytics:
 
     async def stats(self) -> Dict[str, Dict[str, Dict[str, int]]]:
         """
-        Returns statistics for all models from the local buffer.
-
-        Returns:
-        - A dictionary with access counts and token usage for each period.
+        Returns a copy of current statistics from the local buffer.
         """
         async with self.lock:
             return {
@@ -103,31 +120,87 @@ class Analytics:
             },
         }
 
+    async def _sync_from_redis(self):
+        """
+        Pulls existing analytics data from Redis into the local buffer.
+        Uses run_in_executor to avoid blocking the event loop.
+        """
+        loop = asyncio.get_running_loop()
+
+        async with self.lock:
+            # Scan 'access' keys
+            cursor = 0
+            while True:
+                cursor, keys = await loop.run_in_executor(
+                    None,
+                    partial(
+                        self.redis_client.scan,
+                        cursor=cursor,
+                        match="analytics:access:*",
+                        count=100,
+                    ),
+                )
+                for key in keys:
+                    # key is "analytics:access:<period>"
+                    period = key.replace("analytics:access:", "")
+                    data = await loop.run_in_executor(
+                        None, partial(self.redis_client.hgetall, key)
+                    )
+                    for model_id, count_str in data.items():
+                        self.local_buffer["access"][period][model_id] += int(count_str)
+                if cursor == 0:
+                    break
+
+            # Scan 'tokens' keys
+            cursor = 0
+            while True:
+                cursor, keys = await loop.run_in_executor(
+                    None,
+                    partial(
+                        self.redis_client.scan,
+                        cursor=cursor,
+                        match="analytics:tokens:*",
+                        count=100,
+                    ),
+                )
+                for key in keys:
+                    # key is "analytics:tokens:<period>"
+                    period = key.replace("analytics:tokens:", "")
+                    data = await loop.run_in_executor(
+                        None, partial(self.redis_client.hgetall, key)
+                    )
+                    for model_id, count_str in data.items():
+                        self.local_buffer["tokens"][period][model_id] += int(count_str)
+                if cursor == 0:
+                    break
+
     async def _sync_to_redis(self):
         """
-        Synchronizes local buffer data with Redis.
+        Pushes the local buffer data to Redis (local -> Redis).
+        Uses a pipeline to minimize round trips and run_in_executor to avoid blocking.
         """
+        loop = asyncio.get_running_loop()
+
         async with self.lock:
             try:
-                pipeline = self.redis_client.pipeline()
+                pipeline = self.redis_client.pipeline(transaction=False)
 
-                # Sync access counts
+                # Push 'access' data
                 for period, models in self.local_buffer["access"].items():
+                    redis_key = f"analytics:access:{period}"
                     for model_id, count in models.items():
-                        redis_key = f"analytics:access:{period}"
                         pipeline.hincrby(redis_key, model_id, count)
 
-                # Sync token counts
+                # Push 'tokens' data
                 for period, models in self.local_buffer["tokens"].items():
+                    redis_key = f"analytics:tokens:{period}"
                     for model_id, count in models.items():
-                        redis_key = f"analytics:tokens:{period}"
                         pipeline.hincrby(redis_key, model_id, count)
 
-                pipeline.execute()
-                self.local_buffer["access"].clear()  # Clear access buffer after sync
-                self.local_buffer["tokens"].clear()  # Clear tokens buffer after sync
-                logger.info("Synced analytics data to Redis.")
+                # Execute the pipeline in a separate thread
+                await loop.run_in_executor(None, pipeline.execute)
 
+                logger.info("Analytics data successfully synced to Redis.")
             except redis.exceptions.ConnectionError as e:
                 logger.error("Redis connection error during sync: %s", e)
                 raise e
@@ -137,64 +210,69 @@ class Analytics:
 
     async def _start_sync_task(self):
         """
-        Starts a background task that periodically syncs data to Redis.
-        Implements retry logic with exponential backoff on connection failures.
+        Periodically runs _sync_to_redis at a configurable interval.
+        Also handles reconnections on ConnectionError.
         """
-        retry_delay = 1  # Initial retry delay in seconds
-
         while True:
             await asyncio.sleep(self.sync_interval)
             try:
                 await self._sync_to_redis()
-                retry_delay = 1  # Reset retry delay after successful sync
             except redis.exceptions.ConnectionError as e:
-                logger.error("Redis connection error: %s", e)
+                logger.error("Redis connection error during scheduled sync: %s", e)
                 await self._handle_redis_reconnection()
             except Exception as e:
-                logger.error("Error during sync: %s", e)
-                # Depending on the error, you might want to handle differently
+                logger.error("Error during scheduled sync: %s", e)
+                # Handle other errors as appropriate
 
     async def _handle_redis_reconnection(self):
         """
-        Handles Redis reconnection with exponential backoff.
+        Attempts to reconnect to Redis using exponential backoff.
        """
+        loop = asyncio.get_running_loop()
         retry_count = 0
-        delay = 1  # Start with 1 second delay
+        delay = 1
 
         while retry_count < self.max_retries:
             try:
-                logger.info("Attempting to reconnect to Redis (Attempt %d)...", retry_count + 1)
-                self.redis_client.close()
+                logger.info(
+                    "Attempting to reconnect to Redis (attempt %d)...", retry_count + 1
+                )
+                # Close existing connection
+                await loop.run_in_executor(None, self.redis_client.close)
+                # Create a new client
                 self.redis_client = self._create_redis_client()
-                # Optionally, perform a simple command to check connection
-                self.redis_client.ping()
+                # Test the new connection
+                await loop.run_in_executor(None, self.redis_client.ping)
                 logger.info("Successfully reconnected to Redis.")
                 return
             except redis.exceptions.ConnectionError as e:
                 logger.error("Reconnection attempt %d failed: %s", retry_count + 1, e)
                 retry_count += 1
                 await asyncio.sleep(delay)
-                delay *= 2  # Exponential backoff
+                delay *= 2  # exponential backoff
 
-        logger.critical("Max reconnection attempts reached. Unable to reconnect to Redis.")
-        # Depending on your application's requirements, you might choose to exit or keep retrying indefinitely
-        # For example, to keep retrying:
+        logger.critical(
+            "Max reconnection attempts reached. Unable to reconnect to Redis."
+        )
+
+        # Optional: Keep retrying indefinitely instead of giving up.
         while True:
             try:
                 logger.info("Retrying to reconnect to Redis...")
-                self.redis_client.close()
+                await loop.run_in_executor(None, self.redis_client.close)
                 self.redis_client = self._create_redis_client()
-                self.redis_client.ping()
-                logger.info("Successfully reconnected to Redis.")
+                await loop.run_in_executor(None, self.redis_client.ping)
+                logger.info("Reconnected to Redis after extended retries.")
                 break
             except redis.exceptions.ConnectionError as e:
-                logger.error("Reconnection attempt failed: %s", e)
+                logger.error("Extended reconnection attempt failed: %s", e)
                 await asyncio.sleep(delay)
-                delay = min(delay * 2, 60)  # Cap the delay to 60 seconds
+                delay = min(delay * 2, 60)  # Cap at 60 seconds or choose your own max
 
     async def close(self):
         """
-        Closes the Redis connection gracefully.
+        Closes the Redis client connection. Still wrapped in an async method to avoid blocking.
         """
-        self.redis_client.close()
+        loop = asyncio.get_running_loop()
+        await loop.run_in_executor(None, self.redis_client.close)
         logger.info("Closed Redis connection.")