Skip to content

Commit c1cdeae

Browse files
committed
Add function to fuse two contraction definitions
1 parent 61d3d42 commit c1cdeae

1 file changed

Lines changed: 69 additions & 1 deletion

File tree

varipeps/contractions/definitions.py

Lines changed: 69 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -179,7 +179,7 @@ def _process_def(cls, e, name):
179179
@classmethod
180180
def _prepare_defs(cls):
181181
for name in dir(cls):
182-
if name == "add_def" or name.startswith("_"):
182+
if name == "add_def" or name == "join_defs" or name.startswith("_"):
183183
continue
184184

185185
e = getattr(cls, name)
@@ -191,6 +191,74 @@ def add_def(cls, name, definition):
191191
cls._process_def(definition, name)
192192
setattr(cls, name, definition)
193193

194+
@classmethod
195+
def join_defs(cls, name1, name2, join_indices):
196+
new_name = f"joined_{name1}_{name2}_{join_indices}"
197+
if getattr(cls, new_name, None) is not None:
198+
return new_name
199+
200+
if len(join_indices[0]) != len(join_indices[1]):
201+
raise ValueError("Length of join indices mismatches.")
202+
203+
def1 = getattr(cls, name1)
204+
def2 = getattr(cls, name2)
205+
206+
new_def = {}
207+
208+
new_def["tensors"] = list(def1["tensors"]) + list(def2["tensors"])
209+
210+
def1_flatten = [j for i in def1["ncon_network"] for j in i]
211+
max1 = max(def1_flatten)
212+
min1 = min(def1_flatten)
213+
def2_flatten = [j for i in def2["ncon_network"] for j in i]
214+
max2 = max(def2_flatten)
215+
min2 = min(def2_flatten)
216+
217+
new_def["network"] = []
218+
219+
def map_join_indices(i, offset, neg_offset, indices_map):
220+
if i > 0:
221+
return i + offset
222+
elems_joined_before = 0
223+
for pos, j in enumerate(indices_map):
224+
if i == j:
225+
return max1 + max2 + pos + 1
226+
if j > i:
227+
elems_joined_before += 1
228+
return i + elems_joined_before - neg_offset
229+
230+
def gen_new_network_entry(n, offset, neg_offset, indices_map):
231+
if isinstance(n, (list, tuple)) and all(
232+
isinstance(ni, (list, tuple)) for ni in n
233+
):
234+
new_entry = []
235+
for e in n:
236+
new_entry.append(
237+
tuple(
238+
map_join_indices(i, offset, neg_offset, indices_map)
239+
for i in e
240+
)
241+
)
242+
return new_entry
243+
elif isinstance(n, (list, tuple)) and all(isinstance(ni, int) for ni in n):
244+
return tuple(
245+
map_join_indices(i, offset, neg_offset, indices_map) for i in n
246+
)
247+
248+
for n in def1["network"]:
249+
new_def["network"].append(gen_new_network_entry(n, 0, 0, join_indices[0]))
250+
251+
for n in def2["network"]:
252+
new_def["network"].append(
253+
gen_new_network_entry(
254+
n, max1, -min1 - len(join_indices[0]), join_indices[1]
255+
)
256+
)
257+
258+
cls.add_def(new_name, new_def)
259+
260+
return new_name
261+
194262
density_matrix_one_site: Definition = {
195263
"tensors": [
196264
["tensor", "tensor_conj", "C1", "T1", "C2", "T2", "C3", "T3", "C4", "T4"]

0 commit comments

Comments
 (0)