-
Notifications
You must be signed in to change notification settings - Fork 1
Expand file tree
/
Copy pathtest_agents.py
More file actions
194 lines (148 loc) · 5.63 KB
/
Copy pathtest_agents.py
File metadata and controls
194 lines (148 loc) · 5.63 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
"""
Test script for QuantAgent components
"""
import os
import sys
import pandas as pd
import numpy as np
from datetime import datetime, timedelta
# Add the project root to the path
sys.path.append(os.path.dirname(os.path.abspath(__file__)))
from utils.data_fetcher import DataFetcher
from utils.config import Config
from agents.indicator_agent import IndicatorAgent
from agents.pattern_agent import PatternAgent
from agents.trend_agent import TrendAgent
def create_sample_data(*, start='2023-01-01', end='2023-12-31', seed=42):
    """Create deterministic synthetic daily OHLCV data for testing.

    Close prices follow a geometric random walk starting at 100. Each bar
    opens at the previous bar's close (the first bar opens at its own close).
    High and Low are set around the bar's Open/Close range, so every row
    satisfies Low <= min(Open, Close) <= max(Open, Close) <= High.

    Args:
        start: First calendar date of the daily index (inclusive).
        end: Last calendar date of the daily index (inclusive).
        seed: Seed for a local NumPy Generator. The global NumPy RNG state
            is left untouched, and the same seed gives identical output.

    Returns:
        A DataFrame indexed by date with columns
        ``Open, High, Low, Close, Volume``.
    """
    dates = pd.date_range(start=start, end=end, freq='D')
    n = len(dates)
    rng = np.random.default_rng(seed)

    base_price = 100
    # Random walk with slight upward bias: 0.1% daily drift, 2% volatility.
    returns = rng.normal(0.001, 0.02, size=n)
    close = base_price * np.cumprod(1 + returns)

    # Each bar opens at the previous close; the first bar opens at its own close.
    open_ = np.concatenate((close[:1], close[:-1]))

    # Extend the Open/Close body by up to 1%, so High/Low always contain both
    # the Open and the Close (valid OHLC bars).
    daily_volatility = 0.01
    body_top = np.maximum(open_, close)
    body_bottom = np.minimum(open_, close)
    high = body_top * (1 + rng.uniform(0, daily_volatility, size=n))
    low = body_bottom * (1 - rng.uniform(0, daily_volatility, size=n))

    volume = rng.integers(1_000_000, 10_000_000, size=n)

    return pd.DataFrame(
        {
            'Open': open_,
            'High': high,
            'Low': low,
            'Close': close,
            'Volume': volume,
        },
        index=dates,
    )
def test_data_fetcher():
    """Fetch one month of daily AAPL bars, falling back to synthetic data.

    Returns:
        The DataFrame from ``DataFetcher.fetch_data`` when the call succeeds
        and yields at least one row. Otherwise, the output of
        ``create_sample_data()``. Because of this fallback, the return value
        is normally never ``None``.
    """
    print("Testing DataFetcher...")
    data = None
    try:
        fetcher = DataFetcher()
        # Test with a well-known symbol
        data = fetcher.fetch_data("AAPL", period="1mo", interval="1d")
    except Exception as e:
        # NOTE(review): fetch_data presumably does network I/O. A failure here
        # should use the sample-data fallback below instead of aborting the run.
        print(f"✗ DataFetcher raised: {e!s}")
        data = None

    # An empty result would make data.index[0] raise IndexError, so treat it
    # the same as a failed fetch.
    if data is not None and len(data) > 0:
        print(f"✓ Successfully fetched {len(data)} rows of AAPL data")
        print(f" Columns: {list(data.columns)}")
        print(f" Date range: {data.index[0]} to {data.index[-1]}")
        return data

    print("✗ Failed to fetch data, using sample data for testing")
    data = create_sample_data()
    print(f"✓ Created sample data with {len(data)} rows")
    return data
def test_indicator_agent(data):
    """Run IndicatorAgent on *data* and print its forecast, evidence and indicators.

    Returns:
        The mapping returned by ``IndicatorAgent.analyze``, or ``None`` if
        building the agent, running the analysis, or reading any of the
        printed keys raised an exception.
    """
    print("\nTesting IndicatorAgent...")
    try:
        analysis = IndicatorAgent().analyze(data)
        print("✓ IndicatorAgent analysis completed")
        # Print the result fields in a fixed order: forecast, evidence, indicators.
        for key, label in (
            ('forecast', 'Forecast'),
            ('evidence', 'Evidence'),
            ('indicators', 'Indicators'),
        ):
            print(f" {label}: {analysis[key]}")
    except Exception as err:
        # Report the failure; main() treats a None result as FAIL.
        print(f"✗ IndicatorAgent failed: {err!s}")
        return None
    return analysis
def test_pattern_agent(data):
    """Run PatternAgent on *data* (symbol "TEST", charts under "charts").

    Prints the saved chart path and the first 100 characters of the pattern
    description.

    Returns:
        The mapping returned by ``PatternAgent.analyze``, or ``None`` if any
        step raised an exception.
    """
    print("\nTesting PatternAgent...")
    try:
        outcome = PatternAgent(chart_dir="charts").analyze(data, symbol="TEST")
        print("✓ PatternAgent analysis completed")
        print(f" Chart saved to: {outcome['chart_path']}")
        # Only a 100-character preview is shown; "..." is always appended.
        preview = outcome['pattern_description'][:100]
        print(f" Pattern description: {preview}...")
    except Exception as err:
        print(f"✗ PatternAgent failed: {err!s}")
        return None
    return outcome
def test_trend_agent(data):
    """Run TrendAgent on *data* and print direction, confidence and summary.

    Returns:
        The mapping returned by ``TrendAgent.analyze``, or ``None`` if the
        analysis or formatting any printed field raised an exception.
    """
    print("\nTesting TrendAgent...")
    try:
        trend = TrendAgent().analyze(data)
        print("✓ TrendAgent analysis completed")
        print(f" Overall direction: {trend['direction']}")
        # Formatting with :.2f raises if 'confidence' is not numeric.
        print(f" Confidence: {trend['confidence']:.2f}")
        summary_preview = trend['summary'][:100]
        print(f" Summary: {summary_preview}...")
    except Exception as err:
        print(f"✗ TrendAgent failed: {err!s}")
        return None
    return trend
def test_config():
    """Load Config and print its API-key validation and any missing keys.

    Returns:
        The loaded ``Config`` instance, or ``None`` if constructing it or
        querying its API-key status raised an exception.
    """
    print("\nTesting Config...")
    try:
        cfg = Config()
        print("✓ Config loaded successfully")
        # Test validation
        print(f" API key validation: {cfg.validate_api_keys()}")
        missing_keys = cfg.get_missing_keys()
        if not missing_keys:
            print(" All API keys configured")
        else:
            print(f" Missing keys: {missing_keys}")
    except Exception as err:
        print(f"✗ Config failed: {err!s}")
        return None
    return cfg
def main():
    """Run every component check and print a PASS/FAIL summary.

    Config is checked first but does not gate the other checks. The agent
    checks run only when test data is available.

    Returns:
        0 when every check passed, 1 otherwise. The value can be passed
        directly to ``sys.exit``.
    """
    print("QuantAgent Component Tests")
    print("=" * 50)

    # Test configuration
    config = test_config()

    # Test data fetching (test_data_fetcher falls back to sample data, so a
    # None result is not expected here; guard anyway).
    data = test_data_fetcher()
    if data is None:
        print("\n✗ Cannot proceed without data")
        return 1

    # Test all agents
    indicator_result = test_indicator_agent(data)
    pattern_result = test_pattern_agent(data)
    trend_result = test_trend_agent(data)

    # (label, passed) pairs in report order; each helper returns None on failure.
    checks = [
        ("Data Fetcher", data is not None),
        ("IndicatorAgent", bool(indicator_result)),
        ("PatternAgent", bool(pattern_result)),
        ("TrendAgent", bool(trend_result)),
        ("Config", bool(config)),
    ]

    print("\n" + "=" * 50)
    print("Test Summary:")
    for label, passed in checks:
        # Mark each line by its own outcome rather than an unconditional check mark.
        mark = "✓" if passed else "✗"
        print(f"{mark} {label}: {'PASS' if passed else 'FAIL'}")

    if all(passed for _, passed in checks):
        print("\n🎉 All tests passed! Ready to build the Streamlit app.")
        return 0
    print("\n⚠️ Some tests failed. Check the errors above.")
    return 1
if __name__ == "__main__":
    # Use main()'s return value as the process exit status; sys.exit(None)
    # exits with status 0, so a main() that returns nothing still succeeds.
    sys.exit(main())