Skip to content

Commit 52ff1ce

Browse files
consistent qir
1 parent ce07d4f commit 52ff1ce

5 files changed

Lines changed: 244 additions & 131 deletions

File tree

tensorcircuit/mpscircuit.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -121,6 +121,7 @@ def __init__(
121121
"split": split,
122122
"tensors": tensors,
123123
"wavefunction": wavefunction,
124+
"dim": dim,
124125
}
125126
if split is None:
126127
split = {}

tensorcircuit/stabilizercircuit.py

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -48,9 +48,17 @@ def __init__(
4848
self._stim_circuit = stim.Circuit()
4949
self._qir: List[Dict[str, Any]] = []
5050
self.is_dm = False
51-
self.inputs = None
51+
self.inputs = inputs
52+
self.tableau_inputs = tableau_inputs
5253
self._extra_qir: List[Dict[str, Any]] = []
5354
self.current_sim = stim.TableauSimulator()
55+
56+
self.circuit_param = {
57+
"nqubits": nqubits,
58+
"inputs": inputs,
59+
"tableau_inputs": tableau_inputs,
60+
}
61+
5462
if inputs:
5563
self.current_sim.set_state_from_stabilizers(inputs)
5664
if tableau_inputs:

tensorcircuit/u1circuit.py

Lines changed: 42 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,13 @@ def __init__(
4949
self._extra_qir: List[Dict[str, Any]] = []
5050
self.is_mps = False
5151

52+
self.circuit_param = {
53+
"nqubits": nqubits,
54+
"k": k,
55+
"filled": filled,
56+
"inputs": inputs,
57+
}
58+
5259
# Mapping helpers
5360
# TC uses qubit 0 as the leftmost (highest) bit
5461
# So qubit i corresponds to bit position (n-1-i) in the state integer
@@ -192,32 +199,6 @@ def _apply_iswap(self, i: int, j: int, theta: Any) -> None:
192199
cos_t * self._state + isin_t * swapped_state
193200
)
194201

195-
def _apply_fsim(self, i: int, j: int, theta: Any, phi: Any) -> None:
196-
"""Apply fSim(theta, phi): fermionic simulation gate."""
197-
bpi, bpj = self._bit_position(i), self._bit_position(j)
198-
mask_i = 1 << bpi
199-
mask_j = 1 << bpj
200-
bi = backend.right_shift(backend.bitwise_and(self._basis_tensor, mask_i), bpi)
201-
bj = backend.right_shift(backend.bitwise_and(self._basis_tensor, mask_j), bpj)
202-
diff = backend.bitwise_xor(bi, bj)
203-
mask_swap = (1 << bpi) | (1 << bpj)
204-
new_basis = backend.bitwise_xor(self._basis_tensor, diff * mask_swap)
205-
indices = backend.searchsorted(self._basis_tensor, new_basis)
206-
207-
theta_c = backend.cast(backend.convert_to_tensor(theta), dtypestr)
208-
phi_c = backend.cast(backend.convert_to_tensor(phi), dtypestr)
209-
both_on = backend.cast(backend.bitwise_and(bi, bj), dtypestr)
210-
phi_factor = 1.0 + (backend.exp(-1j * phi_c) - 1.0) * both_on
211-
212-
cos_theta = backend.cos(theta_c)
213-
isin_theta = 1j * backend.sin(theta_c)
214-
diff_f = backend.cast(diff, dtypestr)
215-
swapped_state = backend.gather1d(self._state, indices)
216-
217-
self._state = (1.0 - diff_f) * self._state * phi_factor + diff_f * (
218-
cos_theta * self._state - isin_theta * swapped_state
219-
)
220-
221202
# -------------------------------------------------------------------------
222203
# Public gate methods (delegate to internal implementations)
223204
# -------------------------------------------------------------------------
@@ -227,16 +208,24 @@ def apply_general_gate(
227208
gate: Any,
228209
*index: int,
229210
name: Optional[str] = None,
211+
split: Optional[Dict[str, Any]] = None,
212+
mpo: bool = False,
213+
ir_dict: Optional[Dict[str, Any]] = None,
230214
**kwargs: Any,
231215
) -> None:
232216
"""
233217
Apply a gate by name. Called by _meta_apply generated methods.
234218
235219
:param gate: Gate tensor (ignored, dispatch is by name)
236220
:param index: Qubit indices
237-
:param name: Gate name (rz, rzz, cz, cphase, swap, iswap, fsim)
238-
:param kwargs: May contain ir_dict with parameters from _meta_apply
221+
:param name: Gate name (rz, rzz, cz, cphase, swap, iswap)
222+
:param split: Split configuration (ignored in U1Circuit)
223+
:param mpo: MPO flag (ignored in U1Circuit)
224+
:param ir_dict: QIR dictionary for recording
225+
:param kwargs: Extra parameters
239226
"""
227+
if name is None:
228+
name = ""
240229
non_u1 = {"x", "y", "h", "t", "s", "td", "sd", "rx", "ry"}
241230
if name and name.lower() in non_u1:
242231
raise ValueError(
@@ -246,7 +235,28 @@ def apply_general_gate(
246235
gate_name = name.lower() if name else None
247236

248237
# Extract parameters: _meta_apply puts them in ir_dict['parameters']
249-
params = kwargs.get("ir_dict", {}).get("parameters", kwargs)
238+
# Also check kwargs for direct calls
239+
params = {}
240+
if ir_dict is not None and "parameters" in ir_dict:
241+
params.update(ir_dict["parameters"])
242+
params.update(kwargs)
243+
244+
# Record gate in QIR
245+
gate_dict = {
246+
"gate": gate,
247+
"index": index,
248+
"name": name,
249+
"split": split,
250+
"mpo": mpo,
251+
}
252+
if params:
253+
gate_dict["parameters"] = params
254+
255+
if ir_dict is not None:
256+
ir_dict.update(gate_dict)
257+
else:
258+
ir_dict = gate_dict
259+
self._qir.append(ir_dict)
250260

251261
if gate_name == "rz":
252262
self._apply_rz(index[0], params.get("theta", 0))
@@ -260,23 +270,16 @@ def apply_general_gate(
260270
self._apply_swap(index[0], index[1])
261271
elif gate_name == "iswap":
262272
self._apply_iswap(index[0], index[1], params.get("theta", 1.0))
263-
elif gate_name == "fsim":
264-
self._apply_fsim(
265-
index[0], index[1], params.get("theta", 0), params.get("phi", 0)
266-
)
273+
elif gate_name == "cphase":
274+
self._apply_cphase(index[0], index[1], params.get("theta", 0))
267275
else:
268276
raise ValueError(
269277
f"Gate {name} not implemented in U1Circuit. "
270-
"Supported: rz, rzz, cz, cphase, swap, iswap, fsim."
278+
"Supported: rz, rzz, cz, cphase, swap, iswap."
271279
)
272280

273281
# Note: Most gate methods (rz, rzz, cz, swap, iswap, cphase, etc.) are
274282
# auto-generated by _meta_apply() which calls apply_general_gate.
275-
# fsim is not in the standard gate list, so we define it explicitly.
276-
277-
def fsim(self, i: int, j: int, theta: Any = 0, phi: Any = 0, **kwargs: Any) -> None:
278-
"""Apply fSim gate on qubits i and j."""
279-
self._apply_fsim(i, j, theta, phi)
280283

281284
# -------------------------------------------------------------------------
282285
# State and expectation methods
@@ -585,5 +588,3 @@ def measure(
585588

586589
# Register gates via _meta_apply
587590
U1Circuit._meta_apply()
588-
589-
# TODO(@refraction-ray): qir support

tests/test_qir_unification.py

Lines changed: 167 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,167 @@
1+
import numpy as np
2+
import pytest
3+
from pytest_lazyfixture import lazy_fixture as lf
4+
import tensorcircuit as tc
5+
from tensorcircuit.u1circuit import U1Circuit
6+
from tensorcircuit.stabilizercircuit import StabilizerCircuit
7+
from tensorcircuit.mpscircuit import MPSCircuit
8+
9+
10+
@pytest.mark.parametrize("backend", [lf("npb"), lf("tfb"), lf("jaxb")])
11+
def test_u1circuit_qir_roundtrip(backend):
12+
n = 4
13+
k = 2
14+
c = U1Circuit(n, k=k, filled=[0, 1])
15+
c.rz(0, theta=0.5)
16+
c.cz(0, 1)
17+
c.iswap(1, 2, theta=0.8)
18+
c.cphase(2, 3, theta=0.4)
19+
20+
qir = c.to_qir()
21+
assert len(qir) == 4
22+
23+
# Reconstruct from QIR
24+
c2 = U1Circuit.from_qir(qir, c.circuit_param)
25+
26+
# Verify expectations
27+
for i in range(n):
28+
np.testing.assert_allclose(
29+
tc.backend.numpy(c.expectation_z(i)),
30+
tc.backend.numpy(c2.expectation_z(i)),
31+
atol=1e-5,
32+
)
33+
34+
35+
@pytest.mark.parametrize("backend", [lf("npb")])
36+
def test_stabilizer_qir_roundtrip(backend):
37+
n = 4
38+
c = StabilizerCircuit(n)
39+
c.h(0)
40+
c.cnot(0, 1)
41+
c.s(1)
42+
43+
qir = c.to_qir()
44+
assert len(qir) == 3
45+
46+
c2 = StabilizerCircuit.from_qir(qir, c.circuit_param)
47+
48+
# Compare state vectors (TableauSimulator to_state_vector)
49+
np.testing.assert_allclose(
50+
tc.backend.numpy(c.state()), tc.backend.numpy(c2.state()), atol=1e-5
51+
)
52+
53+
54+
@pytest.mark.parametrize("backend", [lf("npb"), lf("jaxb")])
55+
def test_mps_qir_roundtrip(backend):
56+
n = 4
57+
c = MPSCircuit(n)
58+
c.h(0)
59+
c.cnot(0, 1)
60+
c.rz(1, theta=0.8)
61+
62+
qir = c.to_qir()
63+
assert len(qir) == 3
64+
65+
c2 = MPSCircuit.from_qir(qir, c.circuit_param)
66+
67+
np.testing.assert_allclose(
68+
tc.backend.numpy(c.wavefunction()),
69+
tc.backend.numpy(c2.wavefunction()),
70+
atol=1e-5,
71+
)
72+
73+
74+
@pytest.mark.parametrize("backend", [lf("npb")])
75+
def test_circuit_to_stabilizer(backend):
76+
n = 2
77+
c = tc.Circuit(n)
78+
c.h(0)
79+
c.cnot(0, 1)
80+
c.s(1)
81+
c.x(0)
82+
83+
qir = c.to_qir()
84+
# Convert to StabilizerCircuit
85+
c_stab = StabilizerCircuit.from_qir(qir, circuit_params={"nqubits": n})
86+
87+
np.testing.assert_allclose(
88+
tc.backend.numpy(c.state()), tc.backend.numpy(c_stab.state()), atol=1e-5
89+
)
90+
91+
92+
@pytest.mark.parametrize("backend", [lf("npb"), lf("jaxb")])
93+
def test_u1_to_mps(backend):
94+
n = 4
95+
c_u1 = U1Circuit(n, k=2, filled=[0, 1])
96+
c_u1.rz(0, theta=0.5)
97+
c_u1.cz(0, 1)
98+
c_u1.iswap(1, 2, theta=1.2)
99+
c_u1.rzz(3, 2, theta=-0.9)
100+
101+
qir = c_u1.to_qir()
102+
103+
c_mps = MPSCircuit(n)
104+
# Match initial state of U1 ([0,1] filled)
105+
c_mps.x(0)
106+
c_mps.x(1)
107+
c_mps.append_from_qir(qir)
108+
109+
np.testing.assert_allclose(
110+
tc.backend.numpy(c_u1.to_dense()),
111+
tc.backend.numpy(c_mps.wavefunction().reshape([-1])),
112+
atol=1e-5,
113+
)
114+
115+
116+
@pytest.mark.parametrize("backend", [lf("npb")])
117+
def test_mps_to_circuit(backend):
118+
n = 2
119+
c_mps = MPSCircuit(n)
120+
c_mps.h(0)
121+
c_mps.cnot(0, 1)
122+
c_mps.ryy(1, 0, theta=0.9)
123+
c_mps.x(1)
124+
125+
qir = c_mps.to_qir()
126+
c_std = tc.Circuit.from_qir(qir, circuit_params={"nqubits": n})
127+
128+
np.testing.assert_allclose(
129+
tc.backend.numpy(c_mps.wavefunction().reshape([-1])),
130+
tc.backend.numpy(c_std.state()),
131+
atol=1e-5,
132+
)
133+
134+
135+
@pytest.mark.parametrize("backend", [lf("npb"), lf("jaxb")])
136+
def test_dm_to_u1(backend):
137+
n = 2
138+
c_dm = tc.DMCircuit(n)
139+
c_dm.rz(0, theta=0.5)
140+
c_dm.cz(0, 1)
141+
c_dm.cphase(1, 0, theta=-1.1)
142+
143+
qir = c_dm.to_qir()
144+
c_u1 = U1Circuit(n, k=0)
145+
c_u1.append_from_qir(qir)
146+
147+
dm = tc.backend.numpy(c_dm.state())
148+
st = tc.backend.numpy(c_u1.to_dense())
149+
np.testing.assert_allclose(dm, np.outer(st, np.conj(st)), atol=1e-5)
150+
151+
152+
@pytest.mark.parametrize("backend", [lf("npb")])
153+
def test_mps_to_stabilizer(backend):
154+
n = 2
155+
c_mps = MPSCircuit(n)
156+
c_mps.h(0)
157+
c_mps.cnot(0, 1)
158+
c_mps.s(1)
159+
160+
qir = c_mps.to_qir()
161+
c_stab = StabilizerCircuit.from_qir(qir, circuit_params={"nqubits": n})
162+
163+
np.testing.assert_allclose(
164+
tc.backend.numpy(c_mps.wavefunction().reshape([-1])),
165+
tc.backend.numpy(c_stab.state()),
166+
atol=1e-5,
167+
)

0 commit comments

Comments
 (0)