@@ -94,38 +94,18 @@ def _get_ctmrg_1x2_structure(
9494 return view_tensors , view_tensor_objs
9595
9696
97- def _post_process_CTM_tensors (a : jnp .ndarray , config : VariPEPS_Config ) -> jnp .ndarray :
97+ def _post_process_CTM_tensors (
98+ a : jnp .ndarray , a_old : jnp .ndarray , config : VariPEPS_Config
99+ ) -> jnp .ndarray :
98100 a = a / jnp .linalg .norm (a )
99- a_abs = jnp .abs (a )
100- a_abs_max = jnp .max (a_abs )
101101
102- def scan_max_element (carry , x ):
103- x_a , x_a_abs = x
104- found , phase = carry
105-
106- def new_phase (ph , curr_x , curr_x_abs ):
107- return cond (
108- curr_x_abs >= (config .svd_sign_fix_eps * a_abs_max ),
109- lambda p , c_x , c_x_a : c_x / c_x_a ,
110- lambda p , c_x , c_x_a : p ,
111- ph ,
112- curr_x ,
113- curr_x_abs ,
114- )
115-
116- phase = cond (
117- found , lambda ph , curr_x , curr_x_abs : ph , new_phase , phase , x_a , x_a_abs
118- )
119-
120- return (jnp .logical_not (jnp .isnan (phase )), phase ), None
121-
122- (_ , phase ), _ = scan (
123- scan_max_element ,
124- (jnp .array (False ), jnp .array (jnp .nan , dtype = a .dtype )),
125- (a .flatten (), a_abs .flatten ()),
126- )
102+ if a_old .shape == a .shape :
103+ phase = jnp .sum (a .conj () * a_old )
104+ phase = phase / jnp .abs (phase )
105+ else :
106+ phase = 1
127107
128- return a * phase . conj ()
108+ return a * phase
129109
130110
131111def do_left_absorption (
@@ -187,7 +167,9 @@ class definition for details.
187167 [working_tensor_obj ],
188168 [C1_projector ],
189169 )
190- new_C1 .append (_post_process_CTM_tensors (new_C1_tmp , config ))
170+ new_C1 .append (
171+ _post_process_CTM_tensors (new_C1_tmp , view [0 , 1 ][0 ][0 ].C1 , config )
172+ )
191173
192174 T4_projector_top = left_projectors .get_projector (x , y , - 1 , 0 ).top
193175 T4_projector_bottom = left_projectors .get_projector (x , y , 0 , 0 ).bottom
@@ -209,7 +191,9 @@ class definition for details.
209191 [working_tensor_obj ],
210192 [T4_projector_top , T4_projector_bottom ],
211193 )
212- new_T4 .append (_post_process_CTM_tensors (new_T4_tmp , config ))
194+ new_T4 .append (
195+ _post_process_CTM_tensors (new_T4_tmp , view [0 , 1 ][0 ][0 ].T4 , config )
196+ )
213197
214198 C4_projector = left_projectors .get_projector (x , y , 0 , 0 ).top
215199 new_C4_tmp = apply_contraction_jitted (
@@ -218,7 +202,9 @@ class definition for details.
218202 [working_tensor_obj ],
219203 [C4_projector ],
220204 )
221- new_C4 .append (_post_process_CTM_tensors (new_C4_tmp , config ))
205+ new_C4 .append (
206+ _post_process_CTM_tensors (new_C4_tmp , view [0 , 1 ][0 ][0 ].C4 , config )
207+ )
222208
223209 for x , view in column_views :
224210 view [0 , 1 ] = view [0 , 1 ][0 ][0 ].replace_left_env_tensors (
@@ -291,7 +277,9 @@ class definition for details.
291277 [working_tensor_obj ],
292278 [C2_projector ],
293279 )
294- new_C2 .append (_post_process_CTM_tensors (new_C2_tmp , config ))
280+ new_C2 .append (
281+ _post_process_CTM_tensors (new_C2_tmp , view [0 , - 1 ][0 ][0 ].C2 , config )
282+ )
295283
296284 T2_projector_top = right_projectors .get_projector (x , y , - 1 , 0 ).top
297285 T2_projector_bottom = right_projectors .get_projector (x , y , 0 , 0 ).bottom
@@ -312,7 +300,9 @@ class definition for details.
312300 [working_tensor_obj ],
313301 [T2_projector_top , T2_projector_bottom ],
314302 )
315- new_T2 .append (_post_process_CTM_tensors (new_T2_tmp , config ))
303+ new_T2 .append (
304+ _post_process_CTM_tensors (new_T2_tmp , view [0 , - 1 ][0 ][0 ].T2 , config )
305+ )
316306
317307 C3_projector = right_projectors .get_projector (x , y , 0 , 0 ).top
318308 new_C3_tmp = apply_contraction_jitted (
@@ -321,7 +311,9 @@ class definition for details.
321311 [working_tensor_obj ],
322312 [C3_projector ],
323313 )
324- new_C3 .append (_post_process_CTM_tensors (new_C3_tmp , config ))
314+ new_C3 .append (
315+ _post_process_CTM_tensors (new_C3_tmp , view [0 , - 1 ][0 ][0 ].C3 , config )
316+ )
325317
326318 for x , view in column_views :
327319 view [0 , - 1 ] = view [0 , - 1 ][0 ][0 ].replace_right_env_tensors (
@@ -390,7 +382,9 @@ class definition for details.
390382 [working_tensor_obj ],
391383 [C1_projector ],
392384 )
393- new_C1 .append (_post_process_CTM_tensors (new_C1_tmp , config ))
385+ new_C1 .append (
386+ _post_process_CTM_tensors (new_C1_tmp , view [1 , 0 ][0 ][0 ].C1 , config )
387+ )
394388
395389 T1_projector_left = top_projectors .get_projector (x , y , 0 , - 1 ).left # type: ignore
396390 T1_projector_right = top_projectors .get_projector (x , y , 0 , 0 ).right # type: ignore
@@ -411,7 +405,9 @@ class definition for details.
411405 [working_tensor_obj ],
412406 [T1_projector_left , T1_projector_right ],
413407 )
414- new_T1 .append (_post_process_CTM_tensors (new_T1_tmp , config ))
408+ new_T1 .append (
409+ _post_process_CTM_tensors (new_T1_tmp , view [1 , 0 ][0 ][0 ].T1 , config )
410+ )
415411
416412 C2_projector = top_projectors .get_projector (x , y , 0 , 0 ).left # type: ignore
417413 new_C2_tmp = apply_contraction_jitted (
@@ -420,7 +416,9 @@ class definition for details.
420416 [working_tensor_obj ],
421417 [C2_projector ],
422418 )
423- new_C2 .append (_post_process_CTM_tensors (new_C2_tmp , config ))
419+ new_C2 .append (
420+ _post_process_CTM_tensors (new_C2_tmp , view [1 , 0 ][0 ][0 ].C2 , config )
421+ )
424422
425423 for y , view in row_views :
426424 view [1 , 0 ] = view [1 , 0 ][0 ][0 ].replace_top_env_tensors (
@@ -491,7 +489,9 @@ class definition for details.
491489 [working_tensor_obj ],
492490 [C4_projector ],
493491 )
494- new_C4 .append (_post_process_CTM_tensors (new_C4_tmp , config ))
492+ new_C4 .append (
493+ _post_process_CTM_tensors (new_C4_tmp , view [- 1 , 0 ][0 ][0 ].C4 , config )
494+ )
495495
496496 T3_projector_left = bottom_projectors .get_projector (x , y , 0 , - 1 ).left # type: ignore
497497 T3_projector_right = bottom_projectors .get_projector (x , y , 0 , 0 ).right # type: ignore
@@ -512,7 +512,9 @@ class definition for details.
512512 [working_tensor_obj ],
513513 [T3_projector_left , T3_projector_right ],
514514 )
515- new_T3 .append (_post_process_CTM_tensors (new_T3_tmp , config ))
515+ new_T3 .append (
516+ _post_process_CTM_tensors (new_T3_tmp , view [- 1 , 0 ][0 ][0 ].T3 , config )
517+ )
516518
517519 C3_projector = bottom_projectors .get_projector (x , y , 0 , 0 ).left # type: ignore
518520 new_C3_tmp = apply_contraction_jitted (
@@ -521,7 +523,9 @@ class definition for details.
521523 [working_tensor_obj ],
522524 [C3_projector ],
523525 )
524- new_C3 .append (_post_process_CTM_tensors (new_C3_tmp , config ))
526+ new_C3 .append (
527+ _post_process_CTM_tensors (new_C3_tmp , view [- 1 , 0 ][0 ][0 ].C3 , config )
528+ )
525529
526530 for y , view in row_views :
527531 view [- 1 , 0 ] = view [- 1 , 0 ][0 ][0 ].replace_bottom_env_tensors (
@@ -738,7 +742,9 @@ def do_left_absorption_split_transfer(
738742 [working_tensor_obj ],
739743 [C1_ket_projector , C1_bra_projector ],
740744 )
741- new_C1_list .append (_post_process_CTM_tensors (new_C1_tmp , config ))
745+ new_C1_list .append (
746+ _post_process_CTM_tensors (new_C1_tmp , view [0 , 1 ][0 ][0 ].C1 , config )
747+ )
742748
743749 T4_ket_projector_top = left_projectors .get_projector (x , y , - 1 , 0 ).top_ket
744750 T4_bra_projector_top = left_projectors .get_projector (x , y , - 1 , 0 ).top_bra
@@ -785,8 +791,12 @@ def do_left_absorption_split_transfer(
785791 ],
786792 )
787793
788- new_T4_ket_list .append (_post_process_CTM_tensors (new_T4_ket , config ))
789- new_T4_bra_list .append (_post_process_CTM_tensors (new_T4_bra , config ))
794+ new_T4_ket_list .append (
795+ _post_process_CTM_tensors (new_T4_ket , view [0 , 1 ][0 ][0 ].T4_ket , config )
796+ )
797+ new_T4_bra_list .append (
798+ _post_process_CTM_tensors (new_T4_bra , view [0 , 1 ][0 ][0 ].T4_bra , config )
799+ )
790800
791801 C4_ket_projector = left_projectors .get_projector (x , y , 0 , 0 ).top_ket
792802 C4_bra_projector = left_projectors .get_projector (x , y , 0 , 0 ).top_bra
@@ -796,7 +806,9 @@ def do_left_absorption_split_transfer(
796806 [working_tensor_obj ],
797807 [C4_ket_projector , C4_bra_projector ],
798808 )
799- new_C4_list .append (_post_process_CTM_tensors (new_C4_tmp , config ))
809+ new_C4_list .append (
810+ _post_process_CTM_tensors (new_C4_tmp , view [0 , 1 ][0 ][0 ].C4 , config )
811+ )
800812
801813 for x , view in column_views :
802814 view [0 , 1 ] = view [0 , 1 ][0 ][0 ].replace_left_env_tensors (
@@ -874,7 +886,9 @@ def do_right_absorption_split_transfer(
874886 [working_tensor_obj ],
875887 [C2_ket_projector , C2_bra_projector ],
876888 )
877- new_C2_list .append (_post_process_CTM_tensors (new_C2_tmp , config ))
889+ new_C2_list .append (
890+ _post_process_CTM_tensors (new_C2_tmp , view [0 , - 1 ][0 ][0 ].C2 , config )
891+ )
878892
879893 T2_ket_projector_top = right_projectors .get_projector (x , y , - 1 , 0 ).top_ket
880894 T2_bra_projector_top = right_projectors .get_projector (x , y , - 1 , 0 ).top_bra
@@ -921,8 +935,12 @@ def do_right_absorption_split_transfer(
921935 ],
922936 )
923937
924- new_T2_ket_list .append (_post_process_CTM_tensors (new_T2_ket , config ))
925- new_T2_bra_list .append (_post_process_CTM_tensors (new_T2_bra , config ))
938+ new_T2_ket_list .append (
939+ _post_process_CTM_tensors (new_T2_ket , view [0 , - 1 ][0 ][0 ].T2_ket , config )
940+ )
941+ new_T2_bra_list .append (
942+ _post_process_CTM_tensors (new_T2_bra , view [0 , - 1 ][0 ][0 ].T2_bra , config )
943+ )
926944
927945 C3_ket_projector = right_projectors .get_projector (x , y , 0 , 0 ).top_ket
928946 C3_bra_projector = right_projectors .get_projector (x , y , 0 , 0 ).top_bra
@@ -932,7 +950,9 @@ def do_right_absorption_split_transfer(
932950 [working_tensor_obj ],
933951 [C3_ket_projector , C3_bra_projector ],
934952 )
935- new_C3_list .append (_post_process_CTM_tensors (new_C3_tmp , config ))
953+ new_C3_list .append (
954+ _post_process_CTM_tensors (new_C3_tmp , view [0 , - 1 ][0 ][0 ].C3 , config )
955+ )
936956
937957 for x , view in column_views :
938958 view [0 , - 1 ] = view [0 , - 1 ][0 ][0 ].replace_right_env_tensors (
@@ -1006,7 +1026,9 @@ def do_top_absorption_split_transfer(
10061026 [working_tensor_obj ],
10071027 [C1_ket_projector , C1_bra_projector ],
10081028 )
1009- new_C1_list .append (_post_process_CTM_tensors (new_C1_tmp , config ))
1029+ new_C1_list .append (
1030+ _post_process_CTM_tensors (new_C1_tmp , view [1 , 0 ][0 ][0 ].C1 , config )
1031+ )
10101032
10111033 T1_ket_projector_left = top_projectors .get_projector (x , y , 0 , - 1 ).left_ket
10121034 T1_bra_projector_left = top_projectors .get_projector (x , y , 0 , - 1 ).left_bra
@@ -1049,8 +1071,12 @@ def do_top_absorption_split_transfer(
10491071 ],
10501072 )
10511073
1052- new_T1_ket_list .append (_post_process_CTM_tensors (new_T1_ket , config ))
1053- new_T1_bra_list .append (_post_process_CTM_tensors (new_T1_bra , config ))
1074+ new_T1_ket_list .append (
1075+ _post_process_CTM_tensors (new_T1_ket , view [1 , 0 ][0 ][0 ].T4_ket , config )
1076+ )
1077+ new_T1_bra_list .append (
1078+ _post_process_CTM_tensors (new_T1_bra , view [1 , 0 ][0 ][0 ].T4_bra , config )
1079+ )
10541080
10551081 C2_ket_projector = top_projectors .get_projector (x , y , 0 , 0 ).left_ket
10561082 C2_bra_projector = top_projectors .get_projector (x , y , 0 , 0 ).left_bra
@@ -1060,7 +1086,9 @@ def do_top_absorption_split_transfer(
10601086 [working_tensor_obj ],
10611087 [C2_ket_projector , C2_bra_projector ],
10621088 )
1063- new_C2_list .append (_post_process_CTM_tensors (new_C2_tmp , config ))
1089+ new_C2_list .append (
1090+ _post_process_CTM_tensors (new_C2_tmp , view [1 , 0 ][0 ][0 ].C2 , config )
1091+ )
10641092
10651093 for y , view in row_views :
10661094 view [1 , 0 ] = view [1 , 0 ][0 ][0 ].replace_top_env_tensors (
@@ -1136,7 +1164,9 @@ def do_bottom_absorption_split_transfer(
11361164 [working_tensor_obj ],
11371165 [C4_ket_projector , C4_bra_projector ],
11381166 )
1139- new_C4_list .append (_post_process_CTM_tensors (new_C4_tmp , config ))
1167+ new_C4_list .append (
1168+ _post_process_CTM_tensors (new_C4_tmp , view [- 1 , 0 ][0 ][0 ].C4 , config )
1169+ )
11401170
11411171 T3_ket_projector_left = bottom_projectors .get_projector (
11421172 x , y , 0 , - 1
@@ -1187,8 +1217,12 @@ def do_bottom_absorption_split_transfer(
11871217 ],
11881218 )
11891219
1190- new_T3_ket_list .append (_post_process_CTM_tensors (new_T3_ket , config ))
1191- new_T3_bra_list .append (_post_process_CTM_tensors (new_T3_bra , config ))
1220+ new_T3_ket_list .append (
1221+ _post_process_CTM_tensors (new_T3_ket , view [- 1 , 0 ][0 ][0 ].T3_ket , config )
1222+ )
1223+ new_T3_bra_list .append (
1224+ _post_process_CTM_tensors (new_T3_bra , view [- 1 , 0 ][0 ][0 ].T3_bra , config )
1225+ )
11921226
11931227 C3_ket_projector = bottom_projectors .get_projector (x , y , 0 , 0 ).left_ket
11941228 C3_bra_projector = bottom_projectors .get_projector (x , y , 0 , 0 ).left_bra
@@ -1198,7 +1232,9 @@ def do_bottom_absorption_split_transfer(
11981232 [working_tensor_obj ],
11991233 [C3_ket_projector , C3_bra_projector ],
12001234 )
1201- new_C3_list .append (_post_process_CTM_tensors (new_C3_tmp , config ))
1235+ new_C3_list .append (
1236+ _post_process_CTM_tensors (new_C3_tmp , view [- 1 , 0 ][0 ][0 ].C3 , config )
1237+ )
12021238
12031239 for y , view in row_views :
12041240 view [- 1 , 0 ] = view [- 1 , 0 ][0 ][0 ].replace_bottom_env_tensors (
0 commit comments