-
Notifications
You must be signed in to change notification settings - Fork 1
Expand file tree
/
Copy pathtrain_ensemble.py
More file actions
437 lines (367 loc) · 19.2 KB
/
train_ensemble.py
File metadata and controls
437 lines (367 loc) · 19.2 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
"""
Time Series Ensemble Training Application
Trains TimeSeriesStackingEnsemble with TimesNet, Non-stationary Transformer, Informer, and LuTransformer
"""
# Suppress TensorFlow oneDNN messages
import os
os.environ['TF_ENABLE_ONEDNN_OPTS'] = '0'
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'
import sys
import pytorch_lightning as pl
import logging
# Suppress protobuf version warnings
import warnings
warnings.filterwarnings('ignore', category=UserWarning, module='google.protobuf')
warnings.filterwarnings('ignore', message='.*protobuf.*version.*')
import argparse
import numpy as np
import pandas as pd
import time
from datetime import datetime, timedelta
from zoneinfo import ZoneInfo
from pathlib import Path
from typing import Dict, Any
# Add TradeBot to path for logging config
sys.path.insert(0, str(Path(__file__).parent))
from TradeBot.config.logging_config import setup_logging
# Set PyTorch deterministic behavior for consistent predictions
import torch
#torch.use_deterministic_algorithms(True, warn_only=True)
#torch.backends.cudnn.deterministic = True
#torch.backends.cudnn.benchmark = False
# Add TimeSeriesLib and its utils to path first to avoid conflicts
timeserieslib_path = os.path.join(os.path.dirname(__file__), 'TimeSeriesLib')
timeserieslib_utils_path = os.path.join(timeserieslib_path, 'utils')
if timeserieslib_path not in sys.path:
sys.path.insert(0, timeserieslib_path)
if timeserieslib_utils_path not in sys.path:
sys.path.insert(0, timeserieslib_utils_path)
from Training.TimeSeriesSimpleEnsemble import TimeSeriesSimpleEnsemble
from Training.StockDataSet import TimeSeriesDataset
from Core.config import get_timeseries_ensemble_config, get_model_config, get_training_config, get_data_config, get_paths, get_path_manager
from Core.helpers import set_random_seed, cleanup_memory
from Core.market_data_validator import MarketDataValidator
def _build_arg_parser(training_config: Dict[str, Any]) -> argparse.ArgumentParser:
    """Create the CLI parser; --epochs/--patience defaults come from *training_config*."""
    parser = argparse.ArgumentParser(description='Train Time Series Simple Ensemble with Weighted Voting')
    parser.add_argument('--ticker', type=str, required=True, help='Stock ticker symbol')
    parser.add_argument('--epochs', type=int, default=training_config['epochs'], help='Number of training epochs')
    parser.add_argument('--patience', type=int, default=training_config['early_stopping_patience'], help='Early stopping patience')
    parser.add_argument('--from-scratch', action='store_true', help='Clean up existing model files and train a new model from scratch')
    parser.add_argument('--select-features', action='store_true', help='Enable automatic feature selection using StockFeatureSelector (forces from-scratch)')
    parser.add_argument('--skip-base-models', action='store_true', help='Skip base model training (requires existing trained base models)')
    parser.add_argument('--last-fold', action='store_true', help='Train only on the last fold (skip previous folds)')
    parser.add_argument('--deployment', action='store_true', help='Enable deployment mode (train on all data without validation)')
    parser.add_argument('--skip-meta-learner', action='store_true', help='Skip base calibration and meta-learner training.')
    parser.add_argument('--version', type=str, default=None, help='Model version (e.g., v1.0_2025-01-15). If not provided, auto-generates from current date.')
    parser.add_argument('--auto-update-holdout', action='store_true', help='Automatically update holdout_datetime to 30 days before current date in config.py')
    return parser


def _validate_args(args: argparse.Namespace) -> None:
    """Reject non-positive epochs/patience and blank tickers.

    Raises:
        ValueError: if any argument is invalid.
    """
    if args.epochs <= 0:
        raise ValueError(f"Epochs must be positive, got {args.epochs}")
    if args.patience <= 0:
        raise ValueError(f"Patience must be positive, got {args.patience}")
    if not args.ticker or not args.ticker.strip():
        raise ValueError(f"Ticker must be a non-empty string, got {args.ticker}")


def _check_holdout_age(holdout_datetime, logger: logging.Logger) -> None:
    """Warn when *holdout_datetime* is more than 30 days old or in the future; no-op for None."""
    if holdout_datetime is None:
        return
    # Both sides are compared as naive wall-clock times.
    # NOTE(review): dropping tzinfo assumes holdout_datetime is expressed in
    # New York time -- confirm against how the config defines it.
    ny_time = datetime.now(ZoneInfo('America/New_York')).replace(tzinfo=None)
    holdout_naive = holdout_datetime.replace(tzinfo=None)
    holdout_age = (ny_time - holdout_naive).days
    if holdout_age > 30:
        logger.warning(f"⚠️ WARNING: holdout_datetime is {holdout_age} days old (more than 30 days)")
        logger.warning(f" Current holdout_datetime: {holdout_naive.strftime('%Y-%m-%d %H:%M:%S')}")
        logger.warning(f" Current date: {ny_time.strftime('%Y-%m-%d %H:%M:%S')}")
        logger.warning(" Consider updating holdout_datetime to 30 days before current date")
        logger.warning(" Use --auto-update-holdout flag to automatically update it")
    elif holdout_age < 0:
        logger.warning(f"⚠️ WARNING: holdout_datetime is in the future: {holdout_naive.strftime('%Y-%m-%d %H:%M:%S')}")
    else:
        logger.info(f"✓ Holdout datetime is {holdout_age} days old (within 30-day window)")


def _log_config_source(logger: logging.Logger) -> None:
    """Write the text of Core/config.py to the DEBUG log.

    Skipped entirely when DEBUG is disabled, so the file is neither read nor
    required in that case.
    """
    if not logger.isEnabledFor(logging.DEBUG):
        return
    config_py_path = os.path.join(os.path.dirname(__file__), "Core", "config.py")
    logger.debug('=' * 60)
    logger.debug(f"Content of config.py ({config_py_path}):")
    logger.debug('-' * 60)
    # Explicit UTF-8, matching update_holdout_datetime which writes this file.
    with open(config_py_path, "r", encoding="utf-8") as f:
        logger.debug(f.read())
    logger.debug('=' * 60)


def main():
    """CLI entry point: validate inputs, load pre-holdout data, train the ensemble, log its report.

    Logging is configured from the LOG_DIR, LOG_LEVEL, LOG_CONSOLE and LOG_FILE
    environment variables.

    Raises:
        ValueError: on invalid CLI arguments (non-positive epochs/patience, blank ticker).
    """
    # Configure training logging (file + console)
    setup_logging(
        log_dir=os.getenv('LOG_DIR', None),  # Default: ./logs or /var/log/tradebot
        log_level=os.getenv('LOG_LEVEL', 'DEBUG'),
        console_output=os.getenv('LOG_CONSOLE', 'true').lower() == 'true',
        file_output=os.getenv('LOG_FILE', 'true').lower() == 'true',
        app_log_filename='training.log',
        error_log_filename='training_error.log'
    )
    logger = logging.getLogger(__name__)
    set_random_seed(42)
    torch.set_float32_matmul_precision('high')

    training_config = get_training_config()
    data_config = get_data_config()
    ensemble_config = get_timeseries_ensemble_config()

    args = _build_arg_parser(training_config).parse_args()
    # Fail fast on bad arguments before the (potentially slow) market-data check.
    _validate_args(args)

    if args.from_scratch or args.select_features:
        # The cleanup / feature-selection handling for these flags is currently
        # disabled, so say so instead of ignoring them silently.
        logger.warning("⚠️ --from-scratch and --select-features are currently not implemented in this script and will be ignored")

    # Validate market data indexes before training
    tickers_to_validate = [args.ticker, 'XLP', 'VIXY']
    logger.info(f"Validating market data indexes ({', '.join(tickers_to_validate)})...")
    MarketDataValidator().validate_and_raise(tickers_to_validate)

    mode = "Load existing base models" if args.skip_base_models else ("Deployment" if args.deployment else "Train with CV")
    logger.info(f"🚀 Simple Ensemble Training | Ticker: {args.ticker} | Epochs: {args.epochs} | Patience: {args.patience} | Mode: {mode} | Deployment: {args.deployment}")

    # Optionally refresh the holdout cutoff, then warn if it is stale or in the future.
    holdout_datetime = data_config.get('holdout_datetime')
    if args.auto_update_holdout:
        data_config = update_holdout_datetime(logger)
        holdout_datetime = data_config.get('holdout_datetime')
    _check_holdout_age(holdout_datetime, logger)

    _log_config_source(logger)

    dataset = load_data_excluding_holdout(args.ticker, logger)

    # Generate version if not provided
    if args.version is None:
        ny_time = datetime.now(ZoneInfo('America/New_York'))
        args.version = f"v1.0_{ny_time.strftime('%Y-%m-%d')}"
        logger.info(f"Auto-generated version: {args.version}")

    ensemble = train_ensemble(dataset, args.ticker, ensemble_config, args.epochs, args.patience, args.version, args.skip_base_models, args.last_fold, args.deployment, args.skip_meta_learner, logger)
    print_performance_report(ensemble, logger)

    # Final cleanup
    cleanup_memory()
    logger.info("✅ Done.")
    logger.info("=" * 60)
def update_holdout_datetime(logger: logging.Logger) -> Dict[str, Any]:
    """Set holdout_datetime in Core/config.py to 30 days before now (New York time) and reload the config.

    Every ``'holdout_datetime': <value>`` entry is rewritten. The key may use single
    or double quotes, and <value> may be a quoted string or ``None``. The new value is
    written as a single-quoted ``'%Y-%m-%d %H:%M:%S'`` string.

    Args:
        logger: Logger used to report the new value.

    Returns:
        The data config dict from the reloaded ``Core.config`` module.

    Raises:
        ValueError: if no holdout_datetime entry is found. In that case the file is
            left untouched, instead of a successful update being reported falsely.
    """
    import importlib
    import re

    ny_time = datetime.now(ZoneInfo('America/New_York'))
    new_holdout_str = (ny_time - timedelta(days=30)).strftime('%Y-%m-%d %H:%M:%S')

    config_py_path = os.path.join(os.path.dirname(__file__), "Core", "config.py")
    with open(config_py_path, 'r', encoding='utf-8') as f:
        config_content = f.read()

    # Group 1 keeps the key and colon exactly as written; only the value is replaced.
    pattern = r"""(['"]holdout_datetime['"]\s*:\s*)(?:'[^']*'|"[^"]*"|None)"""
    # A callable replacement avoids backreference/escape pitfalls in the template.
    updated_content, n_replaced = re.subn(
        pattern, lambda m: f"{m.group(1)}'{new_holdout_str}'", config_content
    )
    if n_replaced == 0:
        raise ValueError(
            f"Could not find a 'holdout_datetime' entry in {config_py_path}; holdout_datetime was not updated"
        )

    with open(config_py_path, 'w', encoding='utf-8') as f:
        f.write(updated_content)
    logger.info(f"✅ Auto-updated holdout_datetime to: {new_holdout_str} (30 days before today)")

    # Reload the config module and read through its namespace so the fresh values are returned.
    import Core.config
    importlib.reload(Core.config)
    return Core.config.get_data_config()
def load_data_excluding_holdout(ticker: str, logger: logging.Logger) -> TimeSeriesDataset:
    """Load preprocessed data for *ticker*, keep rows up to holdout_datetime, and build an unscaled dataset.

    The CSV path comes from the path manager, using the configured ``period`` and
    ``interval``. Rows whose index is later than ``holdout_datetime`` are dropped for
    training. When the setting is missing or ``None``, all rows are kept.

    Args:
        ticker: Symbol whose processed market-data file is loaded.
        logger: Logger for progress messages.

    Returns:
        A TimeSeriesDataset created with ``train=True`` and ``scaler=None``.

    Raises:
        ValueError: if required data/model config keys are missing, or no rows remain.
        FileNotFoundError: if the preprocessed data file does not exist.
    """
    # Get period/interval from config - FAIL FAST if missing
    data_config = get_data_config()
    for key in ('period', 'interval'):
        if key not in data_config:
            raise ValueError(f"{key} not found in data configuration")
    period = data_config['period']
    interval = data_config['interval']
    # A missing key means "no holdout", consistent with how main() reads this setting.
    holdout_datetime = data_config.get('holdout_datetime')
    logger.info(f"Holdout datetime: {holdout_datetime}")

    # Load preprocessed data with features
    data_file = get_path_manager().get_processed_market_data_file(ticker, period, interval)
    if not os.path.exists(data_file):
        raise FileNotFoundError(f"Preprocessed data file not found: {data_file}")
    df = pd.read_csv(data_file, index_col=0)
    df.index = pd.to_datetime(df.index)
    # Remove prediction test columns if present
    df = df.drop(columns=['Pred_Peak_Proba', 'Pred_Bottom_Proba', 'Pred_Slope_Proba'], errors='ignore')
    n_loaded = len(df)

    # Filter data to exclude rows newer than holdout_datetime for training
    if holdout_datetime is not None:
        # NOTE(review): assumes the CSV index is timezone-naive; comparing a
        # tz-aware index with a naive datetime would raise -- confirm.
        if holdout_datetime.tzinfo is not None:
            holdout_datetime = holdout_datetime.replace(tzinfo=None)
        df = df[df.index <= holdout_datetime]

    if df.empty:
        if n_loaded == 0:
            raise ValueError(f"Data file is empty for {ticker}")
        # The file had rows, but none fall on or before the holdout cutoff.
        raise ValueError(
            f"No rows on or before holdout_datetime {holdout_datetime} for {ticker} "
            f"({n_loaded} rows in {data_file})"
        )
    logger.info(f"📈 Loaded {len(df)} samples. Date range: {df.index[0]} to {df.index[-1]}")

    # Create StockDataSet
    model_config = get_model_config()
    required_keys = ['input_features', 'output_features', 'freq', 'seq_len', 'pred_len', 'task_type']
    for key in required_keys:
        if key not in model_config:
            raise ValueError(f"Required config key '{key}' not found in model configuration")
    return TimeSeriesDataset.from_dataframe(
        train=True,
        dataframe=df,
        input_feature_names=model_config['input_features'],
        output_feature_names=model_config['output_features'],
        freq=model_config['freq'],
        seq_len=model_config['seq_len'],
        pred_len=model_config['pred_len'],
        scaler=None,
        # Targets are scaled only for regression tasks.
        scale_output_too=model_config['task_type'] == 'regression'
    )
def train_ensemble(dataset: 'TimeSeriesDataset',
                   ticker: str,
                   ensemble_config: Dict[str, Any],
                   epochs: int,
                   patience: int,
                   version: str,
                   skip_base_models: bool = False,
                   last_fold: bool = False,
                   deployment: bool = False,
                   skip_meta_learner: bool = False,
                   logger: logging.Logger = None) -> 'TimeSeriesSimpleEnsemble':
    """Build and fit the simple weighted-voting ensemble - FAIL FAST on config errors.

    Args:
        dataset: Training dataset passed to ``ensemble.fit``.
        ticker: Symbol being trained.
        ensemble_config: Must contain a non-empty ``base_models`` mapping and ``cv_folds``.
        epochs: Training epochs for the base models.
        patience: Early-stopping patience.
        version: Model version string forwarded to ``fit``.
        skip_base_models, last_fold, deployment, skip_meta_learner: Flags forwarded
            unchanged to ``ensemble.fit``.
        logger: Logger to use; defaults to this module's logger.

    Returns:
        The fitted TimeSeriesSimpleEnsemble.

    Raises:
        ValueError: if a required config key is missing or ``base_models`` is empty.
    """
    if logger is None:
        logger = logging.getLogger(__name__)

    # Validate ensemble configuration before constructing anything.
    for key in ('base_models', 'cv_folds'):
        if key not in ensemble_config:
            raise ValueError(f"Required ensemble config key '{key}' not found")
    base_models = ensemble_config['base_models']
    if not base_models:
        raise ValueError("base_models cannot be empty")
    cv_folds = ensemble_config['cv_folds']

    logger.info(f"Base models: {list(base_models.keys())}")
    logger.info(f"CV folds: {cv_folds}")
    ensemble = TimeSeriesSimpleEnsemble(
        base_models_config=base_models,
        cv_folds=cv_folds,
        epochs=epochs,
        patience=patience
    )

    # perf_counter is monotonic, so wall-clock adjustments cannot distort the duration.
    start_time = time.perf_counter()
    logger.info(f"Starting ensemble training... (skip_base_models: {skip_base_models}, last_fold: {last_fold}, deployment: {deployment}, skip_meta_learner: {skip_meta_learner}, version: {version})")
    ensemble.fit(dataset, ticker, version=version, skip_base_models=skip_base_models, last_fold=last_fold, deployment=deployment, skip_meta_learner=skip_meta_learner)
    training_time = time.perf_counter() - start_time
    logger.info(f"Ensemble training completed in {training_time/60:.2f} minutes")
    return ensemble
# Grid columns as (header label, cell key, column width). The header and the data
# rows are both rendered from this single table, so their widths cannot drift apart.
_REPORT_COLUMNS = (
    ('Horizon', 'horizon', 8),
    ('AUC-PR', 'auc_pr', 9),
    ('Brier', 'brier', 7),
    ('RMSE', 'brier_rmse', 7),
    ('Acc Peak', 'acc_peak', 9),
    ('Acc Slope', 'acc_slope', 10),
    ('Acc Bottom', 'acc_bottom', 11),
    ('Proba Peak', 'proba_peak', 11),
    ('Proba Slope', 'proba_slope', 12),
    ('Proba Bottom', 'proba_bottom', 13),
    ('Bias Peak', 'bias_peak', 10),
    ('Bias Slope', 'bias_slope', 11),
    ('Bias Bottom', 'bias_bottom', 12),
)
# Class keys used in the per-class metric dicts, in column order.
_REPORT_CLASSES = ('peak', 'slope', 'bottom')


def _format_metric(value: Any, spec: str, suffix: str = "") -> str:
    """Format *value* with format spec *spec* followed by *suffix*; ``None`` renders as 'n/a'."""
    if value is None:
        return "n/a"
    return f"{value:{spec}}{suffix}"


def _ground_truth_lines(gt_dist: Any) -> list:
    """Return report lines for the ground-truth class distribution.

    The result is empty when the distribution is missing or its total is not positive.
    """
    if not gt_dist:
        return []
    total = gt_dist.get('total') or 0
    if total <= 0:
        return []
    lines = ["-" * 80, "Ground Truth Distribution (horizon 0):"]
    for label, key in ((" Peaks", 'peak'), (" Slopes", 'slope'), (" Bottoms", 'bottom')):
        count = gt_dist.get(key, 0)
        lines.append(f"{label}: {count} ({count / total * 100:.1f}%)")
    lines.append(f" Total: {total}")
    return lines


def _horizon_cells(h_metric: Dict[str, Any]) -> Dict[str, str]:
    """Return the formatted grid cells for one horizon; missing metrics default to 0."""
    brier = h_metric.get('brier', 0.0)
    # `or {}` also covers a per-class entry that is present but None.
    accuracy = h_metric.get('per_class_accuracy') or {}
    pred_proba = h_metric.get('per_class_pred_proba') or {}
    bias = h_metric.get('per_class_bias') or {}
    cells = {
        'horizon': _format_metric(h_metric.get('horizon', 0), ''),
        'auc_pr': _format_metric(h_metric.get('auc_pr', 0.0), '.3f'),
        'brier': _format_metric(brier, '.4f'),
        # The RMSE column is the square root of the Brier score.
        'brier_rmse': _format_metric(None if brier is None else np.sqrt(brier), '.4f'),
    }
    for cls in _REPORT_CLASSES:
        cells[f'acc_{cls}'] = _format_metric(accuracy.get(cls, 0.0), '.3f')
        cells[f'proba_{cls}'] = _format_metric(pred_proba.get(cls, 0.0), '.3f')
        cells[f'bias_{cls}'] = _format_metric(bias.get(cls, 0.0), '+.1f', '%')
    return cells


def _horizon_grid_lines(horizons: Any) -> list:
    """Return report lines for the per-horizon metrics grid; empty when there are no horizons."""
    if not horizons:
        return []
    header = " | ".join(f"{label:<{width}}" for label, _, width in _REPORT_COLUMNS)
    separator = "-" * len(header)
    lines = ["-" * 80, "ENSEMBLE PERFORMANCE BY HORIZON (Grid Format):", "", header, separator]
    for h_metric in horizons:
        cells = _horizon_cells(h_metric)
        lines.append(" | ".join(f"{cells[key]:<{width}}" for _, key, width in _REPORT_COLUMNS))
    lines.append(separator)
    return lines


def print_performance_report(ensemble: 'TimeSeriesSimpleEnsemble', logger: logging.Logger = None) -> None:
    """Log a formatted performance report built from ``ensemble.metrics``.

    The report contains:

    * the ground-truth class distribution, from ``metrics['ground_truth_distribution']``,
      when present and its ``total`` is positive;
    * a grid with one row per entry of ``metrics['horizons']``.

    Missing metric keys default to 0. Metric values that are present but ``None``
    render as ``n/a`` rather than raising ``TypeError``.

    Args:
        ensemble: Object exposing a ``metrics`` dict, or ``None`` when not fitted.
        logger: Logger to write to; defaults to this module's logger.
    """
    if logger is None:
        logger = logging.getLogger(__name__)
    metrics = ensemble.metrics
    if metrics is None:
        logger.warning("⚠️ No metrics available. Ensemble may not be fitted yet.")
        return

    report_lines = ["=" * 80, "ENSEMBLE PERFORMANCE REPORT (Last Fold Validation)"]
    report_lines += _ground_truth_lines(metrics.get('ground_truth_distribution'))
    report_lines += _horizon_grid_lines(metrics.get('horizons'))
    # Emit one log record so the report is not split into multiple log entries.
    logger.info("\n".join(report_lines))
def get_class_name(class_idx: int) -> str:
    """Return the output-class name at position *class_idx*.

    Names are the configured ``output_features`` (as returned by get_class_names).
    Plain list indexing applies: an out-of-range index raises IndexError, and a
    negative index counts from the end.
    """
    return get_class_names()[class_idx]
def get_class_names() -> list:
    """Return the class labels, i.e. ``output_features`` from the model configuration.

    NOTE(review): the list is returned as-is from get_model_config(); if that
    returns shared state, callers should copy it before mutating -- confirm.
    """
    output_features = get_model_config()['output_features']
    return output_features
if __name__ == "__main__":
    # sys.exit rather than the interactive `exit` helper, which the `site`
    # module may not install (e.g. `python -S` or frozen executables).
    sys.exit(main())