Dmitry Beresnev committed
Commit 5260ec0 · Parent: 37c39e5

add cache for the downloaded data

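At the call site the change is additive; a minimal sketch of the entry point after this commit (arguments and defaults as defined in the diff below):

data = run_parallel_data_downloader(StockExchange.NASDAQ, limit=200, use_cache=True)
# use_cache is the new flag; cached entries expire after CACHE_EXPIRY_HOURS (2 h)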
src/core/ticker_scanner/parallel_data_downloader.py CHANGED
@@ -2,12 +2,14 @@
 parallel_yf_downloader.py
 Parallel downloading of ticker historical prices using multiprocessing,
 with retry and rate-limit handling and batching.
+Includes in-memory caching with 2-hour expiry.
 """
 
 import time
 import random
 from itertools import islice
-from typing import Any
+from typing import Any, Optional
+from datetime import datetime, timedelta
 from concurrent.futures import ProcessPoolExecutor, as_completed
 
 import yfinance as yf
@@ -22,13 +24,75 @@ MAX_RETRIES = 3 # Retry count on failure
 SLEEP_BETWEEN_RETRIES = 1.0 # Seconds between retries
 BATCH_SIZE = 50 # Number of tickers per batch
 MIN_DATA_POINTS = 50 # Minimum number of price points required
+CACHE_EXPIRY_HOURS = 2 # Cache expiry time in hours
 
+# In-memory cache for ticker data
+_ticker_cache: dict[str, dict[str, Any]] = {}
+_cache_timestamps: dict[str, datetime] = {}
 
-def fetch_prices(ticker: str, max_retries: int = MAX_RETRIES) -> dict[str, Any]:
+
+def _is_cache_valid(ticker: str) -> bool:
+    """Check if cached data for ticker is still valid (not expired)"""
+    if ticker not in _cache_timestamps:
+        return False
+
+    cache_age = datetime.now() - _cache_timestamps[ticker]
+    return cache_age < timedelta(hours=CACHE_EXPIRY_HOURS)
+
+
+def _get_cached_data(ticker: str) -> Optional[dict[str, Any]]:
+    """Get cached data if valid, None otherwise"""
+    if _is_cache_valid(ticker):
+        logger.debug(f"Using cached data for {ticker}")
+        return _ticker_cache.get(ticker)
+    return None
+
+
+def _cache_data(ticker: str, data: dict[str, Any]) -> None:
+    """Cache ticker data with current timestamp"""
+    _ticker_cache[ticker] = data
+    _cache_timestamps[ticker] = datetime.now()
+    logger.debug(f"Cached data for {ticker}")
+
+
+def clear_cache() -> None:
+    """Clear all cached data (useful for testing or manual refresh)"""
+    global _ticker_cache, _cache_timestamps
+    _ticker_cache.clear()
+    _cache_timestamps.clear()
+    logger.info("Cache cleared")
+
+
+def get_cache_stats() -> dict[str, Any]:
+    """Get cache statistics"""
+    valid_count = sum(1 for ticker in _ticker_cache.keys() if _is_cache_valid(ticker))
+    return {
+        'total_cached': len(_ticker_cache),
+        'valid_cached': valid_count,
+        'expired_cached': len(_ticker_cache) - valid_count
+    }
+
+
+def fetch_prices(ticker: str, max_retries: int = MAX_RETRIES, use_cache: bool = True) -> dict[str, Any]:
     """
     Download all-time closing prices for a single ticker safely.
-    Returns dict {'ticker': ticker, 'prices': ndarray, 'dates': DatetimeIndex} or None if failed.
+    Uses in-memory cache if available and not expired.
+
+    Args:
+        ticker: Stock ticker symbol
+        max_retries: Maximum number of retry attempts
+        use_cache: Whether to use cached data if available
+
+    Returns:
+        dict {'ticker': ticker, 'prices': ndarray, 'dates': DatetimeIndex} or None if failed
     """
+    # Check cache first
+    if use_cache:
+        cached_data = _get_cached_data(ticker)
+        if cached_data is not None:
+            return cached_data
+
+    # Download fresh data
    for attempt in range(max_retries):
        try:
            df = yf.download(ticker, period="max", progress=False, auto_adjust=True)
@@ -59,11 +123,18 @@ def fetch_prices(ticker: str, max_retries: int = MAX_RETRIES) -> dict[str, Any]:
            if prices.ndim > 1:
                prices = prices.flatten()
 
-            return {
+            result = {
                "ticker": ticker,
                "prices": prices,
                "dates": dates
            }
+
+            # Cache the result
+            if use_cache:
+                _cache_data(ticker, result)
+
+            return result
+
        except yf.shared.YFRateLimitError:
            wait = SLEEP_BETWEEN_RETRIES + random.random()
            logger.warning(f"Rate limited for {ticker}. Waiting {wait:.1f}s and retrying...")
@@ -84,25 +155,55 @@ def batch(iterable: list[str], n: int = BATCH_SIZE):
            break
        yield chunk
 
-def download_tickers_parallel(tickers: list[str], max_workers: int = MAX_WORKERS) -> list[dict[str, Any]]:
+def download_tickers_parallel(tickers: list[str], max_workers: int = MAX_WORKERS,
+                              use_cache: bool = True) -> list[dict[str, Any]]:
     """
     Download a large list of tickers in parallel batches.
-    Returns a list of {'ticker': ..., 'prices': ..., 'dates': ...} dicts.
+    Uses in-memory cache to avoid re-downloading recently fetched data.
+
+    Args:
+        tickers: List of ticker symbols to download
+        max_workers: Number of parallel workers
+        use_cache: Whether to use cached data
+
+    Returns:
+        List of {'ticker': ..., 'prices': ..., 'dates': ...} dicts
     """
-    all_results = []
+    # Separate cached and non-cached tickers
+    cached_results = []
+    tickers_to_download = []
+
+    if use_cache:
+        for ticker in tickers:
+            cached_data = _get_cached_data(ticker)
+            if cached_data:
+                cached_results.append(cached_data)
+            else:
+                tickers_to_download.append(ticker)
+
+        if cached_results:
+            logger.info(f"Using cached data for {len(cached_results)} tickers")
+    else:
+        tickers_to_download = tickers
+
+    # Download remaining tickers
+    all_results = cached_results.copy()
    all_failed = []
 
-    for batch_num, ticker_batch in enumerate(batch(tickers, BATCH_SIZE), start=1):
-        logger.info(f"Processing batch {batch_num}: {len(ticker_batch)} tickers")
-        results, failed = process_batch(ticker_batch, max_workers)
-        all_results.extend(results)
-        all_failed.extend(failed)
-        # small sleep between batches to reduce rate-limit chance
-        time.sleep(1 + random.random())
+    if tickers_to_download:
+        logger.info(f"Downloading {len(tickers_to_download)} tickers...")
+        for batch_num, ticker_batch in enumerate(batch(tickers_to_download, BATCH_SIZE), start=1):
+            logger.info(f"Processing batch {batch_num}: {len(ticker_batch)} tickers")
+            results, failed = process_batch(ticker_batch, max_workers)
+            all_results.extend(results)
+            all_failed.extend(failed)
+            # small sleep between batches to reduce rate-limit chance
+            time.sleep(1 + random.random())
 
-    logger.info(f"Total downloaded: {len(all_results)}")
+    logger.info(f"Total available: {len(all_results)} (cached: {len(cached_results)}, downloaded: {len(all_results) - len(cached_results)})")
    if all_failed:
        logger.warning(f"Total failed: {len(all_failed)} - {all_failed[:10]}") # Show first 10
+
    return all_results
 
 def process_batch(ticker_batch: list[str], max_workers: int) -> tuple[list[dict[str, Any]], list[Any]]:
@@ -127,22 +228,29 @@ def process_batch(ticker_batch: list[str], max_workers: int) -> tuple[list[dict[str, Any]], list[Any]]:
    return results, failed
 
 def run_parallel_data_downloader(exchange: StockExchange = StockExchange.NASDAQ,
-                                 limit: int = 200) -> list[dict[str, Any]]:
+                                 limit: int = 200,
+                                 use_cache: bool = True) -> list[dict[str, Any]]:
     """
-    Main function to download ticker data in parallel.
+    Main function to download ticker data in parallel with caching.
 
    Args:
        exchange: Stock exchange to download from
        limit: Maximum number of tickers to download
+        use_cache: Whether to use cached data (expires after 2 hours)
 
    Returns:
        List of dicts with ticker, prices, and dates
    """
    all_tickers = TickersProvider().get_tickers(exchange)
    tickers = all_tickers[:limit]
-    logger.info(f"Starting parallel download for {len(tickers)} tickers from {exchange.value}...")
-    data = download_tickers_parallel(tickers)
-    logger.info(f"Downloaded {len(data)} tickers successfully")
+
+    # Log cache stats
+    cache_stats = get_cache_stats()
+    logger.info(f"Cache stats: {cache_stats['valid_cached']} valid, {cache_stats['expired_cached']} expired")
+
+    logger.info(f"Starting download for {len(tickers)} tickers from {exchange.value}...")
+    data = download_tickers_parallel(tickers, use_cache=use_cache)
+    logger.info(f"Retrieved {len(data)} tickers successfully")
    return data
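For reference, a minimal usage sketch of the cache API this commit introduces. The import path and the "AAPL" symbol are illustrative assumptions, not taken from the repo; adjust to the project's actual package layout.

# Hypothetical import path, for illustration only
from core.ticker_scanner.parallel_data_downloader import (
    fetch_prices,
    get_cache_stats,
    clear_cache,
)

data = fetch_prices("AAPL")        # first call downloads via yfinance and caches the result
data_again = fetch_prices("AAPL")  # within 2 hours, served from the in-memory cache

print(get_cache_stats())           # e.g. {'total_cached': 1, 'valid_cached': 1, 'expired_cached': 0}

fresh = fetch_prices("AAPL", use_cache=False)  # bypass the cache for a forced refresh
clear_cache()                      # drop all cached entries

Note that the cache is a pair of module-level dicts, so it is per-process: entries written by fetch_prices inside ProcessPoolExecutor workers do not propagate back to the parent process.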