-
Notifications
You must be signed in to change notification settings - Fork 1
Expand file tree
/
Copy pathtest_bitwise_fix.py
More file actions
133 lines (109 loc) · 4.22 KB
/
test_bitwise_fix.py
File metadata and controls
133 lines (109 loc) · 4.22 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
#!/usr/bin/env python3
"""Quick test to verify bitwise operations work with tensors"""
import sys
sys.path.insert(0, r'D:\stability\Data\Packages\ComfyUI')
import torch
from custom_nodes.more_math.more_math.Parser.UnifiedMathVisitor import UnifiedMathVisitor
def test_bitwise_xor_float():
    """Smoke-test XOR on float32 tensors (the originally reported bug).

    Calls ``UnifiedMathVisitor._bitwise_op`` with ``torch.bitwise_xor`` and
    prints the operands, the result and the result dtype. The check only
    verifies that no exception is raised; result values are not compared.
    The visitor is built outside the ``try``, so an error while constructing
    it propagates to the caller.

    Returns:
        bool: True if the call and the reporting ran without raising,
        False otherwise.
    """
    print("Testing bitwise XOR with float tensors...")

    # Float operands: this is the combination that used to raise.
    lhs = torch.tensor([1.0, 2.0, 3.0], dtype=torch.float32)
    rhs = torch.tensor([4.0, 5.0, 6.0], dtype=torch.float32)
    visitor = UnifiedMathVisitor({"a": lhs, "b": rhs})

    try:
        # Expected path: cast to int64, apply XOR, then cast back.
        outcome = visitor._bitwise_op(
            lhs, rhs, torch.bitwise_xor, lambda x, y: x ^ y
        )
        print("✓ XOR succeeded!")
        print(f" Input a (float32): {lhs}")
        print(f" Input b (float32): {rhs}")
        print(f" Result: {outcome}")
        print(f" Result dtype: {outcome.dtype}")
        passed = True
    except Exception as err:
        print(f"✗ XOR failed: {err}")
        passed = False
    return passed
def test_bitwise_and_float():
    """Smoke-test AND on float32 tensors.

    Calls ``UnifiedMathVisitor._bitwise_op`` with ``torch.bitwise_and`` and
    prints the operands, the result and the result dtype. Only the absence
    of an exception is checked; result values are not compared. The visitor
    is built outside the ``try``, so a construction error propagates.

    Returns:
        bool: True if the call and the reporting ran without raising,
        False otherwise.
    """
    print("\nTesting bitwise AND with float tensors...")

    lhs = torch.tensor([15.0, 14.0, 13.0], dtype=torch.float32)
    rhs = torch.tensor([7.0, 3.0, 1.0], dtype=torch.float32)
    visitor = UnifiedMathVisitor({"a": lhs, "b": rhs})

    try:
        outcome = visitor._bitwise_op(
            lhs, rhs, torch.bitwise_and, lambda x, y: x & y
        )
        print("✓ AND succeeded!")
        print(f" Input a (float32): {lhs}")
        print(f" Input b (float32): {rhs}")
        print(f" Result: {outcome}")
        print(f" Result dtype: {outcome.dtype}")
        passed = True
    except Exception as err:
        print(f"✗ AND failed: {err}")
        passed = False
    return passed
def test_bitwise_or_float():
    """Smoke-test OR on float32 tensors.

    Calls ``UnifiedMathVisitor._bitwise_op`` with ``torch.bitwise_or`` and
    prints the operands, the result and the result dtype. Only the absence
    of an exception is checked; result values are not compared. The visitor
    is built outside the ``try``, so a construction error propagates.

    Returns:
        bool: True if the call and the reporting ran without raising,
        False otherwise.
    """
    print("\nTesting bitwise OR with float tensors...")

    lhs = torch.tensor([15.0, 14.0, 13.0], dtype=torch.float32)
    rhs = torch.tensor([7.0, 3.0, 1.0], dtype=torch.float32)
    visitor = UnifiedMathVisitor({"a": lhs, "b": rhs})

    try:
        outcome = visitor._bitwise_op(
            lhs, rhs, torch.bitwise_or, lambda x, y: x | y
        )
        print("✓ OR succeeded!")
        print(f" Input a (float32): {lhs}")
        print(f" Input b (float32): {rhs}")
        print(f" Result: {outcome}")
        print(f" Result dtype: {outcome.dtype}")
        passed = True
    except Exception as err:
        print(f"✗ OR failed: {err}")
        passed = False
    return passed
def test_bitwise_not_float():
    """Smoke-test NOT on a float32 tensor.

    Calls ``UnifiedMathVisitor._bitwise_not`` and prints the operand, the
    result and the result dtype. Only the absence of an exception is
    checked; result values are not compared. The visitor is built outside
    the ``try``, so a construction error propagates.

    Returns:
        bool: True if the call and the reporting ran without raising,
        False otherwise.
    """
    print("\nTesting bitwise NOT with float tensors...")

    operand = torch.tensor([1.0, 2.0, 3.0], dtype=torch.float32)
    visitor = UnifiedMathVisitor({"a": operand})

    try:
        outcome = visitor._bitwise_not(operand)
        print("✓ NOT succeeded!")
        print(f" Input a (float32): {operand}")
        print(f" Result: {outcome}")
        print(f" Result dtype: {outcome.dtype}")
        passed = True
    except Exception as err:
        print(f"✗ NOT failed: {err}")
        passed = False
    return passed
def test_int_tensors_preserved():
    """Check that an integer input dtype survives a bitwise AND.

    Runs ``UnifiedMathVisitor._bitwise_op`` with ``torch.bitwise_and`` on
    two int16 tensors and fails if the result dtype is not int16. The
    visitor is built outside the ``try``, so a construction error
    propagates to the caller.

    Returns:
        bool: True if the result dtype is ``torch.int16`` and nothing
        raised, False otherwise.
    """
    print("\nTesting that int tensor dtypes are preserved...")
    a = torch.tensor([15, 14, 13], dtype=torch.int16)
    b = torch.tensor([7, 3, 1], dtype=torch.int16)
    visitor = UnifiedMathVisitor({"a": a, "b": b})
    try:
        result = visitor._bitwise_op(a, b, torch.bitwise_and, lambda x, y: x & y)
        # Raise explicitly rather than using `assert`: assert statements are
        # removed under `python -O`, which would make this check a no-op and
        # let the test report success regardless of the result dtype.
        if result.dtype != torch.int16:
            raise AssertionError(f"Expected int16, got {result.dtype}")
        print("✓ Int16 dtype preserved!")
        print(f" Input a (int16): {a}")
        print(f" Input b (int16): {b}")
        print(f" Result (int16): {result}")
        return True
    except Exception as e:
        # The AssertionError raised above lands here too, so a dtype
        # mismatch is reported the same way as any other failure.
        print(f"✗ Int16 test failed: {e}")
        return False
if __name__ == "__main__":
    # Script entry point: run every check, print a summary, and exit with
    # status 1 if any check returned False.
    rule = "=" * 70
    print(rule)
    print("Bitwise Operations Fix Verification")
    print(rule)

    # Every check is executed (no short-circuit) so that one failure does
    # not hide the output of the others.
    checks = (
        test_bitwise_xor_float,
        test_bitwise_and_float,
        test_bitwise_or_float,
        test_bitwise_not_float,
        test_int_tensors_preserved,
    )
    outcomes = [check() for check in checks]

    print("\n" + rule)
    if not all(outcomes):
        print("✗ Some tests failed")
        sys.exit(1)
    print("✓ All tests passed!")
    # The closing rule is only reached on success; the failure path exits
    # before it.
    print(rule)