Skip to content

Commit 8efb3bb

Browse files
committed
Add parallel test support and improve test isolation
Enhanced pytest_groups.py to support parallel execution via pytest-xdist, updated documentation and environment variable handling, and improved isolation for problematic tests. Added new files to the isolated test list to prevent state conflicts and blocking issues. requirements-dev.txt updated to include pytest-xdist. Minor fix in test_abstract_models.py to ensure dt is set. Synchronized dt setting in environment.py with brainstate.environ.
1 parent 1280296 commit 8efb3bb

4 files changed

Lines changed: 87 additions & 29 deletions

File tree

brainpy/_src/dyn/synapses/tests/test_abstract_models.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010

1111
class TestDualExpon(unittest.TestCase):
1212
def test_dual_expon(self):
13-
# bm.set(dt=0.01)
13+
bm.set(dt=0.01)
1414

1515
class Net(bp.DynSysGroup):
1616
def __init__(self, tau_r, tau_d, n_spk):

brainpy/_src/math/environment.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
import warnings
1111
from typing import Any, Callable, TypeVar, cast
1212

13+
import brainstate.environ
1314
import jax
1415
from jax import config, numpy as jnp, devices
1516
from jax.lib import xla_bridge
@@ -563,6 +564,7 @@ def set_dt(dt):
563564
"""
564565
assert isinstance(dt, float), f'"dt" must a float, but we got {dt}'
565566
defaults.__dict__['dt'] = dt
567+
brainstate.environ.set(dt=dt)
566568

567569

568570
def get_dt():

pytest_groups.py

Lines changed: 83 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,19 @@
11
#!/usr/bin/env python3
22
"""
33
Minimal isolated test runner for BrainPy - only isolates known problematic tests.
4+
5+
Supports parallel execution via pytest-xdist for faster testing:
6+
7+
Usage:
8+
python pytest_groups.py # Auto-detect optimal worker count
9+
PYTEST_WORKERS=4 python pytest_groups.py # Use 4 workers
10+
PYTEST_WORKERS=1 python pytest_groups.py # Sequential execution
11+
12+
Environment variables:
13+
PYTEST_WORKERS: Number of workers ('auto' or integer, default: 'auto')
14+
IS_GITHUB_ACTIONS: Set to '1' in CI for cleaner output
15+
16+
Note: Isolated tests always run sequentially to avoid state conflicts.
417
"""
518

619
import subprocess
@@ -24,6 +37,10 @@
2437
# Files that contain problematic tests (need to be run separately)
2538
ISOLATED_FILES = [
2639
"brainpy/_src/math/object_transform/tests/test_base.py", # causes state pollution
40+
"brainpy/_src/dyn/neurons/tests/test_lif.py", # NoneType * DynamicJaxprTracer in parallel
41+
"brainpy/_src/integrators/sde/tests/test_normal.py", # plt.show() blocking on macOS
42+
"brainpy/_src/integrators/tests/test_integ_runner.py", # plt.show() blocking on macOS
43+
"brainpy/_src/analysis/lowdim/tests/test_phase_plane.py", # plt.show() blocking on macOS
2744
]
2845

2946
# Additional files that need isolation in CI environments
@@ -43,7 +60,7 @@ def run_isolated_test(test_path):
4360
if is_github_actions:
4461
test_args.extend(["--maxfail=1", "-q"])
4562

46-
cmd = base_cmd + [test_path] + test_args
63+
cmd = base_cmd + test_args + [test_path]
4764
result = subprocess.run(cmd)
4865
return result.returncode == 0
4966

@@ -55,6 +72,12 @@ def main():
5572
is_github_actions = os.getenv('IS_GITHUB_ACTIONS') == '1'
5673
base_cmd = [sys.executable, "-m", "pytest"]
5774
test_args = ["-v", "--tb=short"]
75+
76+
# Add parallel execution support
77+
workers = os.getenv('PYTEST_WORKERS', 'auto')
78+
if workers != '1': # Skip parallel if explicitly set to 1
79+
test_args.extend(["-n", workers])
80+
5881
if is_github_actions:
5982
test_args.extend(["--maxfail=5"])
6083

@@ -80,52 +103,84 @@ def main():
80103
print("=" * 80)
81104

82105
# Run main test suite (excluding problematic files)
83-
print(f"\n{'Running main test suite...':<60} ", end="", flush=True)
106+
print(f"\n{'Main test suite:':<60}")
107+
print("-" * 80)
84108

85-
cmd = base_cmd + ["brainpy/_src/"] + test_args + ignore_patterns
109+
cmd = base_cmd + test_args + ignore_patterns + ["brainpy/_src/"]
86110
main_start = time.time()
87-
main_result = subprocess.run(cmd, capture_output=True, text=True)
88-
main_time = time.time() - main_start
89-
main_passed = main_result.returncode == 0
90111

91-
if main_passed:
92-
print(f"PASSED ({main_time:.1f}s)")
93-
else:
94-
print(f"FAILED ({main_time:.1f}s)")
95-
if not is_github_actions:
96-
# Extract key info from pytest output
112+
if is_github_actions:
113+
# In CI, capture output to keep logs clean
114+
main_result = subprocess.run(cmd, capture_output=True, text=True)
115+
main_passed = main_result.returncode == 0
116+
main_time = time.time() - main_start
117+
118+
print(f"Main test suite: {'PASSED' if main_passed else 'FAILED'} ({main_time:.1f}s)")
119+
120+
if not main_passed:
121+
# Show failures in CI
97122
lines = main_result.stdout.split('\n')
98-
failed_lines = [line for line in lines if 'FAILED' in line][:5] # Show first 5 failures
123+
failed_lines = [line for line in lines if 'FAILED' in line][:5]
99124
if failed_lines:
100-
print("\n Recent failures:")
125+
print("Recent failures:")
101126
for line in failed_lines:
102-
print(f" {line}")
127+
print(f" {line}")
128+
else:
129+
# Locally, show real-time progress
130+
main_result = subprocess.run(cmd)
131+
main_passed = main_result.returncode == 0
132+
main_time = time.time() - main_start
133+
134+
print(f"\nMain test suite: {'PASSED' if main_passed else 'FAILED'} ({main_time:.1f}s)")
103135

104136
# Run isolated problematic files
105137
isolated_results = []
106138
for file_path in sorted(all_problematic_files):
107139
if os.path.exists(file_path):
108140
file_name = file_path.split("/")[-1]
109-
print(f"{'Isolated: ' + file_name:<60} ", end="", flush=True)
141+
print(f"\n{'Isolated: ' + file_name:<60}")
142+
print("-" * 80)
110143

111-
cmd = base_cmd + [file_path] + test_args + ["-x"]
144+
# For isolated tests, remove parallel args to avoid conflicts
145+
iso_test_args = []
146+
skip_next = False
147+
for arg in test_args:
148+
if skip_next:
149+
skip_next = False
150+
continue
151+
if arg == "-n":
152+
skip_next = True # Skip the next argument (worker count)
153+
continue
154+
iso_test_args.append(arg)
155+
iso_test_args.append("-x")
156+
157+
cmd = base_cmd + iso_test_args + [file_path]
112158
iso_start = time.time()
113-
result = subprocess.run(cmd, capture_output=True, text=True)
114-
iso_time = time.time() - iso_start
115-
passed = result.returncode == 0
116-
isolated_results.append(passed)
117159

118-
if passed:
119-
print(f"PASSED ({iso_time:.1f}s)")
120-
else:
121-
print(f"FAILED ({iso_time:.1f}s)")
122-
if not is_github_actions:
160+
if is_github_actions:
161+
# In CI, capture output
162+
result = subprocess.run(cmd, capture_output=True, text=True)
163+
passed = result.returncode == 0
164+
iso_time = time.time() - iso_start
165+
isolated_results.append(passed)
166+
167+
print(f"Isolated {file_name}: {'PASSED' if passed else 'FAILED'} ({iso_time:.1f}s)")
168+
169+
if not passed:
123170
lines = result.stdout.split('\n')
124171
failed_lines = [line for line in lines if 'FAILED' in line][:3]
125172
if failed_lines:
126-
print(" Failures:")
173+
print("Failures:")
127174
for line in failed_lines:
128-
print(f" {line}")
175+
print(f" {line}")
176+
else:
177+
# Locally, show real-time progress
178+
result = subprocess.run(cmd)
179+
passed = result.returncode == 0
180+
iso_time = time.time() - iso_start
181+
isolated_results.append(passed)
182+
183+
print(f"\nIsolated {file_name}: {'PASSED' if passed else 'FAILED'} ({iso_time:.1f}s)")
129184
else:
130185
isolated_results.append(True)
131186

requirements-dev.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,4 +12,5 @@ setuptools
1212

1313
# test requirements
1414
pytest
15+
pytest-xdist # for parallel test execution
1516
absl-py

0 commit comments

Comments
 (0)