Skip to content

Commit a4a69f0

Browse files
committed
Implement gauge fixing based on overlap with previous CTM tensor to make it more GPU friendly
1 parent c1cdeae commit a4a69f0

2 files changed

Lines changed: 157 additions & 73 deletions

File tree

varipeps/ctmrg/absorption.py

Lines changed: 93 additions & 57 deletions
Original file line numberDiff line numberDiff line change
@@ -94,38 +94,18 @@ def _get_ctmrg_1x2_structure(
9494
return view_tensors, view_tensor_objs
9595

9696

97-
def _post_process_CTM_tensors(a: jnp.ndarray, config: VariPEPS_Config) -> jnp.ndarray:
97+
def _post_process_CTM_tensors(
98+
a: jnp.ndarray, a_old: jnp.ndarray, config: VariPEPS_Config
99+
) -> jnp.ndarray:
98100
a = a / jnp.linalg.norm(a)
99-
a_abs = jnp.abs(a)
100-
a_abs_max = jnp.max(a_abs)
101101

102-
def scan_max_element(carry, x):
103-
x_a, x_a_abs = x
104-
found, phase = carry
105-
106-
def new_phase(ph, curr_x, curr_x_abs):
107-
return cond(
108-
curr_x_abs >= (config.svd_sign_fix_eps * a_abs_max),
109-
lambda p, c_x, c_x_a: c_x / c_x_a,
110-
lambda p, c_x, c_x_a: p,
111-
ph,
112-
curr_x,
113-
curr_x_abs,
114-
)
115-
116-
phase = cond(
117-
found, lambda ph, curr_x, curr_x_abs: ph, new_phase, phase, x_a, x_a_abs
118-
)
119-
120-
return (jnp.logical_not(jnp.isnan(phase)), phase), None
121-
122-
(_, phase), _ = scan(
123-
scan_max_element,
124-
(jnp.array(False), jnp.array(jnp.nan, dtype=a.dtype)),
125-
(a.flatten(), a_abs.flatten()),
126-
)
102+
if a_old.shape == a.shape:
103+
phase = jnp.sum(a.conj() * a_old)
104+
phase = phase / jnp.abs(phase)
105+
else:
106+
phase = 1
127107

128-
return a * phase.conj()
108+
return a * phase
129109

130110

131111
def do_left_absorption(
@@ -187,7 +167,9 @@ class definition for details.
187167
[working_tensor_obj],
188168
[C1_projector],
189169
)
190-
new_C1.append(_post_process_CTM_tensors(new_C1_tmp, config))
170+
new_C1.append(
171+
_post_process_CTM_tensors(new_C1_tmp, view[0, 1][0][0].C1, config)
172+
)
191173

192174
T4_projector_top = left_projectors.get_projector(x, y, -1, 0).top
193175
T4_projector_bottom = left_projectors.get_projector(x, y, 0, 0).bottom
@@ -209,7 +191,9 @@ class definition for details.
209191
[working_tensor_obj],
210192
[T4_projector_top, T4_projector_bottom],
211193
)
212-
new_T4.append(_post_process_CTM_tensors(new_T4_tmp, config))
194+
new_T4.append(
195+
_post_process_CTM_tensors(new_T4_tmp, view[0, 1][0][0].T4, config)
196+
)
213197

214198
C4_projector = left_projectors.get_projector(x, y, 0, 0).top
215199
new_C4_tmp = apply_contraction_jitted(
@@ -218,7 +202,9 @@ class definition for details.
218202
[working_tensor_obj],
219203
[C4_projector],
220204
)
221-
new_C4.append(_post_process_CTM_tensors(new_C4_tmp, config))
205+
new_C4.append(
206+
_post_process_CTM_tensors(new_C4_tmp, view[0, 1][0][0].C4, config)
207+
)
222208

223209
for x, view in column_views:
224210
view[0, 1] = view[0, 1][0][0].replace_left_env_tensors(
@@ -291,7 +277,9 @@ class definition for details.
291277
[working_tensor_obj],
292278
[C2_projector],
293279
)
294-
new_C2.append(_post_process_CTM_tensors(new_C2_tmp, config))
280+
new_C2.append(
281+
_post_process_CTM_tensors(new_C2_tmp, view[0, -1][0][0].C2, config)
282+
)
295283

296284
T2_projector_top = right_projectors.get_projector(x, y, -1, 0).top
297285
T2_projector_bottom = right_projectors.get_projector(x, y, 0, 0).bottom
@@ -312,7 +300,9 @@ class definition for details.
312300
[working_tensor_obj],
313301
[T2_projector_top, T2_projector_bottom],
314302
)
315-
new_T2.append(_post_process_CTM_tensors(new_T2_tmp, config))
303+
new_T2.append(
304+
_post_process_CTM_tensors(new_T2_tmp, view[0, -1][0][0].T2, config)
305+
)
316306

317307
C3_projector = right_projectors.get_projector(x, y, 0, 0).top
318308
new_C3_tmp = apply_contraction_jitted(
@@ -321,7 +311,9 @@ class definition for details.
321311
[working_tensor_obj],
322312
[C3_projector],
323313
)
324-
new_C3.append(_post_process_CTM_tensors(new_C3_tmp, config))
314+
new_C3.append(
315+
_post_process_CTM_tensors(new_C3_tmp, view[0, -1][0][0].C3, config)
316+
)
325317

326318
for x, view in column_views:
327319
view[0, -1] = view[0, -1][0][0].replace_right_env_tensors(
@@ -390,7 +382,9 @@ class definition for details.
390382
[working_tensor_obj],
391383
[C1_projector],
392384
)
393-
new_C1.append(_post_process_CTM_tensors(new_C1_tmp, config))
385+
new_C1.append(
386+
_post_process_CTM_tensors(new_C1_tmp, view[1, 0][0][0].C1, config)
387+
)
394388

395389
T1_projector_left = top_projectors.get_projector(x, y, 0, -1).left # type: ignore
396390
T1_projector_right = top_projectors.get_projector(x, y, 0, 0).right # type: ignore
@@ -411,7 +405,9 @@ class definition for details.
411405
[working_tensor_obj],
412406
[T1_projector_left, T1_projector_right],
413407
)
414-
new_T1.append(_post_process_CTM_tensors(new_T1_tmp, config))
408+
new_T1.append(
409+
_post_process_CTM_tensors(new_T1_tmp, view[1, 0][0][0].T1, config)
410+
)
415411

416412
C2_projector = top_projectors.get_projector(x, y, 0, 0).left # type: ignore
417413
new_C2_tmp = apply_contraction_jitted(
@@ -420,7 +416,9 @@ class definition for details.
420416
[working_tensor_obj],
421417
[C2_projector],
422418
)
423-
new_C2.append(_post_process_CTM_tensors(new_C2_tmp, config))
419+
new_C2.append(
420+
_post_process_CTM_tensors(new_C2_tmp, view[1, 0][0][0].C2, config)
421+
)
424422

425423
for y, view in row_views:
426424
view[1, 0] = view[1, 0][0][0].replace_top_env_tensors(
@@ -491,7 +489,9 @@ class definition for details.
491489
[working_tensor_obj],
492490
[C4_projector],
493491
)
494-
new_C4.append(_post_process_CTM_tensors(new_C4_tmp, config))
492+
new_C4.append(
493+
_post_process_CTM_tensors(new_C4_tmp, view[-1, 0][0][0].C4, config)
494+
)
495495

496496
T3_projector_left = bottom_projectors.get_projector(x, y, 0, -1).left # type: ignore
497497
T3_projector_right = bottom_projectors.get_projector(x, y, 0, 0).right # type: ignore
@@ -512,7 +512,9 @@ class definition for details.
512512
[working_tensor_obj],
513513
[T3_projector_left, T3_projector_right],
514514
)
515-
new_T3.append(_post_process_CTM_tensors(new_T3_tmp, config))
515+
new_T3.append(
516+
_post_process_CTM_tensors(new_T3_tmp, view[-1, 0][0][0].T3, config)
517+
)
516518

517519
C3_projector = bottom_projectors.get_projector(x, y, 0, 0).left # type: ignore
518520
new_C3_tmp = apply_contraction_jitted(
@@ -521,7 +523,9 @@ class definition for details.
521523
[working_tensor_obj],
522524
[C3_projector],
523525
)
524-
new_C3.append(_post_process_CTM_tensors(new_C3_tmp, config))
526+
new_C3.append(
527+
_post_process_CTM_tensors(new_C3_tmp, view[-1, 0][0][0].C3, config)
528+
)
525529

526530
for y, view in row_views:
527531
view[-1, 0] = view[-1, 0][0][0].replace_bottom_env_tensors(
@@ -738,7 +742,9 @@ def do_left_absorption_split_transfer(
738742
[working_tensor_obj],
739743
[C1_ket_projector, C1_bra_projector],
740744
)
741-
new_C1_list.append(_post_process_CTM_tensors(new_C1_tmp, config))
745+
new_C1_list.append(
746+
_post_process_CTM_tensors(new_C1_tmp, view[0, 1][0][0].C1, config)
747+
)
742748

743749
T4_ket_projector_top = left_projectors.get_projector(x, y, -1, 0).top_ket
744750
T4_bra_projector_top = left_projectors.get_projector(x, y, -1, 0).top_bra
@@ -785,8 +791,12 @@ def do_left_absorption_split_transfer(
785791
],
786792
)
787793

788-
new_T4_ket_list.append(_post_process_CTM_tensors(new_T4_ket, config))
789-
new_T4_bra_list.append(_post_process_CTM_tensors(new_T4_bra, config))
794+
new_T4_ket_list.append(
795+
_post_process_CTM_tensors(new_T4_ket, view[0, 1][0][0].T4_ket, config)
796+
)
797+
new_T4_bra_list.append(
798+
_post_process_CTM_tensors(new_T4_bra, view[0, 1][0][0].T4_bra, config)
799+
)
790800

791801
C4_ket_projector = left_projectors.get_projector(x, y, 0, 0).top_ket
792802
C4_bra_projector = left_projectors.get_projector(x, y, 0, 0).top_bra
@@ -796,7 +806,9 @@ def do_left_absorption_split_transfer(
796806
[working_tensor_obj],
797807
[C4_ket_projector, C4_bra_projector],
798808
)
799-
new_C4_list.append(_post_process_CTM_tensors(new_C4_tmp, config))
809+
new_C4_list.append(
810+
_post_process_CTM_tensors(new_C4_tmp, view[0, 1][0][0].C4, config)
811+
)
800812

801813
for x, view in column_views:
802814
view[0, 1] = view[0, 1][0][0].replace_left_env_tensors(
@@ -874,7 +886,9 @@ def do_right_absorption_split_transfer(
874886
[working_tensor_obj],
875887
[C2_ket_projector, C2_bra_projector],
876888
)
877-
new_C2_list.append(_post_process_CTM_tensors(new_C2_tmp, config))
889+
new_C2_list.append(
890+
_post_process_CTM_tensors(new_C2_tmp, view[0, -1][0][0].C2, config)
891+
)
878892

879893
T2_ket_projector_top = right_projectors.get_projector(x, y, -1, 0).top_ket
880894
T2_bra_projector_top = right_projectors.get_projector(x, y, -1, 0).top_bra
@@ -921,8 +935,12 @@ def do_right_absorption_split_transfer(
921935
],
922936
)
923937

924-
new_T2_ket_list.append(_post_process_CTM_tensors(new_T2_ket, config))
925-
new_T2_bra_list.append(_post_process_CTM_tensors(new_T2_bra, config))
938+
new_T2_ket_list.append(
939+
_post_process_CTM_tensors(new_T2_ket, view[0, -1][0][0].T2_ket, config)
940+
)
941+
new_T2_bra_list.append(
942+
_post_process_CTM_tensors(new_T2_bra, view[0, -1][0][0].T2_bra, config)
943+
)
926944

927945
C3_ket_projector = right_projectors.get_projector(x, y, 0, 0).top_ket
928946
C3_bra_projector = right_projectors.get_projector(x, y, 0, 0).top_bra
@@ -932,7 +950,9 @@ def do_right_absorption_split_transfer(
932950
[working_tensor_obj],
933951
[C3_ket_projector, C3_bra_projector],
934952
)
935-
new_C3_list.append(_post_process_CTM_tensors(new_C3_tmp, config))
953+
new_C3_list.append(
954+
_post_process_CTM_tensors(new_C3_tmp, view[0, -1][0][0].C3, config)
955+
)
936956

937957
for x, view in column_views:
938958
view[0, -1] = view[0, -1][0][0].replace_right_env_tensors(
@@ -1006,7 +1026,9 @@ def do_top_absorption_split_transfer(
10061026
[working_tensor_obj],
10071027
[C1_ket_projector, C1_bra_projector],
10081028
)
1009-
new_C1_list.append(_post_process_CTM_tensors(new_C1_tmp, config))
1029+
new_C1_list.append(
1030+
_post_process_CTM_tensors(new_C1_tmp, view[1, 0][0][0].C1, config)
1031+
)
10101032

10111033
T1_ket_projector_left = top_projectors.get_projector(x, y, 0, -1).left_ket
10121034
T1_bra_projector_left = top_projectors.get_projector(x, y, 0, -1).left_bra
@@ -1049,8 +1071,12 @@ def do_top_absorption_split_transfer(
10491071
],
10501072
)
10511073

1052-
new_T1_ket_list.append(_post_process_CTM_tensors(new_T1_ket, config))
1053-
new_T1_bra_list.append(_post_process_CTM_tensors(new_T1_bra, config))
1074+
new_T1_ket_list.append(
1075+
_post_process_CTM_tensors(new_T1_ket, view[1, 0][0][0].T4_ket, config)
1076+
)
1077+
new_T1_bra_list.append(
1078+
_post_process_CTM_tensors(new_T1_bra, view[1, 0][0][0].T4_bra, config)
1079+
)
10541080

10551081
C2_ket_projector = top_projectors.get_projector(x, y, 0, 0).left_ket
10561082
C2_bra_projector = top_projectors.get_projector(x, y, 0, 0).left_bra
@@ -1060,7 +1086,9 @@ def do_top_absorption_split_transfer(
10601086
[working_tensor_obj],
10611087
[C2_ket_projector, C2_bra_projector],
10621088
)
1063-
new_C2_list.append(_post_process_CTM_tensors(new_C2_tmp, config))
1089+
new_C2_list.append(
1090+
_post_process_CTM_tensors(new_C2_tmp, view[1, 0][0][0].C2, config)
1091+
)
10641092

10651093
for y, view in row_views:
10661094
view[1, 0] = view[1, 0][0][0].replace_top_env_tensors(
@@ -1136,7 +1164,9 @@ def do_bottom_absorption_split_transfer(
11361164
[working_tensor_obj],
11371165
[C4_ket_projector, C4_bra_projector],
11381166
)
1139-
new_C4_list.append(_post_process_CTM_tensors(new_C4_tmp, config))
1167+
new_C4_list.append(
1168+
_post_process_CTM_tensors(new_C4_tmp, view[-1, 0][0][0].C4, config)
1169+
)
11401170

11411171
T3_ket_projector_left = bottom_projectors.get_projector(
11421172
x, y, 0, -1
@@ -1187,8 +1217,12 @@ def do_bottom_absorption_split_transfer(
11871217
],
11881218
)
11891219

1190-
new_T3_ket_list.append(_post_process_CTM_tensors(new_T3_ket, config))
1191-
new_T3_bra_list.append(_post_process_CTM_tensors(new_T3_bra, config))
1220+
new_T3_ket_list.append(
1221+
_post_process_CTM_tensors(new_T3_ket, view[-1, 0][0][0].T3_ket, config)
1222+
)
1223+
new_T3_bra_list.append(
1224+
_post_process_CTM_tensors(new_T3_bra, view[-1, 0][0][0].T3_bra, config)
1225+
)
11921226

11931227
C3_ket_projector = bottom_projectors.get_projector(x, y, 0, 0).left_ket
11941228
C3_bra_projector = bottom_projectors.get_projector(x, y, 0, 0).left_bra
@@ -1198,7 +1232,9 @@ def do_bottom_absorption_split_transfer(
11981232
[working_tensor_obj],
11991233
[C3_ket_projector, C3_bra_projector],
12001234
)
1201-
new_C3_list.append(_post_process_CTM_tensors(new_C3_tmp, config))
1235+
new_C3_list.append(
1236+
_post_process_CTM_tensors(new_C3_tmp, view[-1, 0][0][0].C3, config)
1237+
)
12021238

12031239
for y, view in row_views:
12041240
view[-1, 0] = view[-1, 0][0][0].replace_bottom_env_tensors(

0 commit comments

Comments
 (0)