├── __init__.py ├── .gitignore ├── tests ├── __init__.py ├── character_test.py └── kostka_test.py ├── .DS_Store ├── FIGS ├── char.pdf ├── char_dim.pdf └── kostka_1e-12.pdf ├── index.html ├── DATA ├── .gitignore ├── mps_data.dat ├── sage_data.dat └── kostka_short.dat ├── __pycache__ └── character_builder.cpython-311.pyc ├── requirements.txt ├── champs ├── __init__.py ├── kostka_builder.py ├── character_builder.py └── builder.py ├── example.py ├── README.md ├── plotting ├── make_plot_kostka.py ├── make_plot_dim.py └── make_plot.py ├── timing ├── timing_sage.py ├── timing_gap.g ├── timing_mps.py └── timing_kostka.py ├── utils.py └── experiments.ipynb /__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | 2 | *.pyc 3 | -------------------------------------------------------------------------------- /tests/__init__.py: -------------------------------------------------------------------------------- 1 | from ..utils import * 2 | -------------------------------------------------------------------------------- /.DS_Store: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sbravyi/symmetric_group_characters/HEAD/.DS_Store -------------------------------------------------------------------------------- /FIGS/char.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sbravyi/symmetric_group_characters/HEAD/FIGS/char.pdf -------------------------------------------------------------------------------- /index.html: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | ./ 5 | -------------------------------------------------------------------------------- /DATA/.gitignore: -------------------------------------------------------------------------------- 1 | # Ignore all datafiles in this directory 2 | 3 | # Except this file 4 | !.gitignore -------------------------------------------------------------------------------- /DATA/mps_data.dat: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sbravyi/symmetric_group_characters/HEAD/DATA/mps_data.dat -------------------------------------------------------------------------------- /DATA/sage_data.dat: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sbravyi/symmetric_group_characters/HEAD/DATA/sage_data.dat -------------------------------------------------------------------------------- /FIGS/char_dim.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sbravyi/symmetric_group_characters/HEAD/FIGS/char_dim.pdf -------------------------------------------------------------------------------- /DATA/kostka_short.dat: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sbravyi/symmetric_group_characters/HEAD/DATA/kostka_short.dat -------------------------------------------------------------------------------- /FIGS/kostka_1e-12.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sbravyi/symmetric_group_characters/HEAD/FIGS/kostka_1e-12.pdf -------------------------------------------------------------------------------- /tests/character_test.py: -------------------------------------------------------------------------------- 1 | """ 2 | Validation of the character computation for the symmetric group S_n using the MPS algorithm. 3 | """ 4 | -------------------------------------------------------------------------------- /__pycache__/character_builder.cpython-311.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sbravyi/symmetric_group_characters/HEAD/__pycache__/character_builder.cpython-311.pyc -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | matplotlib==3.10.0 2 | matplotlib-inline==0.1.7 3 | mpnum==1.0.2 4 | numpy==2.1.3 5 | pillow==11.1.0 6 | quimb==1.10.0 7 | scipy==1.15.0 8 | sphinx 9 | sphinx_rtd_theme 10 | 11 | 12 | -------------------------------------------------------------------------------- /champs/__init__.py: -------------------------------------------------------------------------------- 1 | from .builder import Builder 2 | from .character_builder import CharacterBuilder 3 | from .kostka_builder import KostkaBuilder 4 | 5 | __all__ = ["Builder", "CharacterBuilder", "KostkaBuilder"] 6 | -------------------------------------------------------------------------------- /example.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import random 3 | import champs 4 | from utils import get_partitions 5 | 6 | # compute characters of the symmetric group S_n using the MPS algorithm 7 | n = 8 8 | 9 | Pn = get_partitions(n) 10 | 11 | Mu = random.choice(Pn) 12 | print("n=", n) 13 | print("Conjugacy class Mu=", Mu) 14 | 15 | # compute MPS that encodes all characters of Mu 16 | builder = champs.CharacterBuilder(Mu) 17 | 18 | # compute all characters of Mu 19 | for Lambda in Pn: 20 | chi = builder.get_character(Lambda) 21 | print("irrep Lambda=", Lambda, "character=", chi) 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # ChaMPS 2 | 3 | This is a Matrix Product State (MPS) algorithm for computing characters of irreducible representations of the symmetric group $S_n$ that appeared [here](https://arxiv.org/abs/2501.12579). 4 | The algorithm computes an MPS encoding all characters of a given conjugacy class of $S_n$. It relies on a mapping from characters of $S_n$ to quantum spin chains proposed by 5 | [Marcos Crichigno and Anupam Prakash](https://arxiv.org/abs/2404.04322) 6 | 7 | ## Documentation 8 | Read the [documentation](https://sbravyi.github.io/symmetric_group_characters/index.html) 9 | 10 | ## Examples 11 | Example of how to use the algorithm can be found in **example.py**. See **experiments.ipynb** jupyter notebook to reproduce all the experiments from the paper. 12 | 13 | -------------------------------------------------------------------------------- /plotting/make_plot_kostka.py: -------------------------------------------------------------------------------- 1 | import matplotlib.pyplot as plt 2 | import seaborn as sns 3 | import pandas as pd 4 | from pathlib import Path 5 | 6 | SCRIPT_DIR = Path(__file__).parent.resolve() 7 | DATA_DIR = SCRIPT_DIR.parent / "DATA" 8 | FIG_DIR = SCRIPT_DIR.parent / "FIGS" 9 | path = DATA_DIR # change for different datasets 10 | 11 | # input data file 12 | time_data = pd.DataFrame(pd.read_pickle(DATA_DIR / "kostka_short.dat")) 13 | 14 | fig, ax = plt.subplots(figsize=(10, 8)) 15 | 16 | g = sns.boxplot( 17 | data=time_data, 18 | x="n", 19 | y="Runtime", 20 | hue="Algorithm", 21 | log_scale=True, 22 | linewidth=0.8, 23 | widths=0.35, 24 | showfliers=False, 25 | ax=ax, 26 | ) 27 | plt.grid() 28 | handles, labels = ax.get_legend_handles_labels() 29 | ax.legend(handles=handles[0:], labels=labels[0:]) 30 | plt.legend(ncol=len(time_data.columns)) 31 | plt.ylabel("Runtime (seconds)") 32 | 33 | file_name = FIG_DIR / "kostka_times.pdf" 34 | plt.savefig(file_name) 35 | -------------------------------------------------------------------------------- /timing/timing_sage.py: -------------------------------------------------------------------------------- 1 | from sage.all import symmetrica 2 | import numpy as np 3 | import time 4 | import pickle 5 | from utils import get_partitions 6 | from pathlib import Path 7 | 8 | result = [] 9 | 10 | # partitions for the timing test 11 | SelectMu = [] 12 | for m in range(2, 16): 13 | SelectMu.append([2] * m) 14 | 15 | 16 | for Mu in SelectMu: 17 | n = np.sum(Mu) 18 | 19 | # compute all partitions of n 20 | Pn = get_partitions(n) 21 | 22 | print("n=", n) 23 | print("Number of partitions=", len(Pn)) 24 | 25 | result_entry = {} 26 | result_entry["num_partitions"] = len(Pn) 27 | 28 | print("Mu=", Mu) 29 | 30 | result_entry["Mu"] = Mu 31 | 32 | t = time.time() 33 | table = {} 34 | for Lambda in Pn: 35 | table[Lambda] = int(symmetrica.charvalue(Lambda, Mu)) 36 | sage_runtime = time.time() - t 37 | print("sage runtime=", "{0:.5f}".format(sage_runtime)) 38 | result_entry["sage_runtime"] = sage_runtime 39 | result_entry["table"] = table 40 | result.append(result_entry) 41 | print("###################################") 42 | 43 | 44 | SCRIPT_DIR = Path(__file__).parent.resolve() 45 | DATA_DIR = SCRIPT_DIR.parent / "DATA" 46 | 47 | path = DATA_DIR # data directory 48 | 49 | file_name = path / "sage_data.dat" 50 | 51 | with open(file_name, "wb") as fp: 52 | pickle.dump(result, fp) 53 | print("Done") 54 | 55 | print("file_name=", file_name) 56 | -------------------------------------------------------------------------------- /plotting/make_plot_dim.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import pickle 3 | import matplotlib.pyplot as plt 4 | from pathlib import Path 5 | 6 | SCRIPT_DIR = Path(__file__).parent.resolve() 7 | DATA_DIR = SCRIPT_DIR.parent / "DATA" 8 | FIG_DIR = SCRIPT_DIR.parent / "FIGS" 9 | 10 | 11 | file_name = "mps_data.dat" 12 | 13 | with open(DATA_DIR / file_name, "rb") as fp: 14 | result_mps = pickle.load(fp) 15 | 16 | num_trials = len(result_mps) 17 | 18 | # minimum n to include in the plot 19 | n_min = 10 20 | result_mps = [ 21 | result_mps[i] for i in range(num_trials) if np.sum(result_mps[i]["Mu"]) >= n_min 22 | ] 23 | num_trials = len(result_mps) 24 | 25 | partitions = [str(np.sum(result_mps[i]["Mu"])) for i in range(num_trials)] 26 | 27 | data = {} 28 | data["D"] = [result_mps[i]["D"] for i in range(num_trials)] 29 | data["upper bound"] = [result_mps[i]["Dupper"] for i in range(num_trials)] 30 | 31 | 32 | x = np.arange(num_trials) # the label locations 33 | width = 0.25 # the width of the bars 34 | multiplier = 0 35 | 36 | plt.figure(figsize=(10, 8)) 37 | plt.rcParams["font.size"] = "20" 38 | ax = plt.gca() 39 | 40 | 41 | color = {} 42 | color["D"] = "C0" 43 | color["upper bound"] = "C2" 44 | for name, D in data.items(): 45 | offset = width * multiplier 46 | rects = ax.bar(x + offset, D, width, label=name, color=color[name]) 47 | multiplier += 1 48 | 49 | # Add some text for labels, title and custom x-axis tick labels, etc. 50 | ax.set_ylabel("MPS bond dimension $D$") 51 | ax.set_xticks(x + width, partitions) 52 | ax.legend(loc="upper left", ncols=3) 53 | plt.yscale("log") 54 | ax.set_xlabel("n") 55 | ax.grid(True) 56 | 57 | 58 | plt.show() 59 | plt.savefig(FIG_DIR / "char_dim.pdf") 60 | -------------------------------------------------------------------------------- /timing/timing_gap.g: -------------------------------------------------------------------------------- 1 | # measure the runtime for computing a given column of the character table of the symmetric group S_n 2 | 3 | LoadPackage("ctbllib"); 4 | 5 | 6 | 7 | # permutation that defines a column of the character table 8 | #sigma:=(1,2)(3,4)(5,6)(7,8)(9,10); 9 | #sigma:=(1,2)(3,4)(5,6)(7,8)(9,10)(11,12); 10 | #sigma:=(1,2)(3,4)(5,6)(7,8)(9,10)(11,12)(13,14); 11 | #sigma:=(1,2)(3,4)(5,6)(7,8)(9,10)(11,12)(13,14)(15,16); 12 | #sigma:=(1,2)(3,4)(5,6)(7,8)(9,10)(11,12)(13,14)(15,16)(17,18); 13 | #sigma:=(1,2)(3,4)(5,6)(7,8)(9,10)(11,12)(13,14)(15,16)(17,18)(19,20); 14 | #sigma:=(1,2)(3,4)(5,6)(7,8)(9,10)(11,12)(13,14)(15,16)(17,18)(19,20)(21,22); 15 | #sigma:=(1,2)(3,4)(5,6)(7,8)(9,10)(11,12)(13,14)(15,16)(17,18)(19,20)(21,22)(23,24); 16 | #sigma:=(1,2)(3,4)(5,6)(7,8)(9,10)(11,12)(13,14)(15,16)(17,18)(19,20)(21,22)(23,24)(25,26); 17 | #sigma:=(1,2)(3,4)(5,6)(7,8)(9,10)(11,12)(13,14)(15,16)(17,18)(19,20)(21,22)(23,24)(25,26); 18 | #sigma:=(1,2)(3,4)(5,6)(7,8)(9,10)(11,12)(13,14)(15,16)(17,18)(19,20)(21,22)(23,24)(25,26)(27,28); 19 | sigma:=(1,2)(3,4)(5,6)(7,8)(9,10)(11,12)(13,14)(15,16)(17,18)(19,20)(21,22)(23,24)(25,26)(27,28)(29,30); 20 | n:=Size(ListPerm(sigma)); 21 | Print("n="); 22 | Print(n); 23 | Print("\n"); 24 | Print("group element="); 25 | Print(sigma); 26 | Print("\n"); 27 | 28 | G:=SymmetricGroup(n); 29 | irrG:=Irr(G); 30 | Print("Begin computation"); 31 | Print("\n"); 32 | 33 | startTime := Runtime(); 34 | for irrep in irrG do 35 | chi:=(sigma^irrep); 36 | #Print("chi="); 37 | #Print(chi); 38 | #Print("\n"); 39 | od; 40 | 41 | Print("Done"); 42 | Print("\n"); 43 | endTime := Runtime(); 44 | runtime:= endTime - startTime; 45 | Print("runtime="); 46 | Print(runtime); 47 | Print("\n"); 48 | -------------------------------------------------------------------------------- /timing/timing_mps.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import time 3 | import pickle 4 | import champs 5 | from utils import get_partitions 6 | from pathlib import Path 7 | 8 | 9 | # stores all computation results and runtime 10 | result = [] 11 | 12 | # partitions for the runtime test: (n/2) cycles of length=2 13 | SelectMu = [] 14 | for m in range(2, 16): 15 | SelectMu.append([2] * m) 16 | 17 | for Mu in SelectMu: 18 | n = np.sum(Mu) 19 | Pn = get_partitions(n) 20 | 21 | print("n=", n) 22 | print("Number of partitions=", len(Pn)) 23 | 24 | result_entry = {} 25 | result_entry["num_partitions"] = len(Pn) 26 | 27 | print("Mu=", Mu) 28 | result_entry["Mu"] = Mu 29 | 30 | t = time.time() 31 | builder = champs.CharacterBuilder(Mu) 32 | # runtime for computing the MPS 33 | runtime_part1 = time.time() - t 34 | 35 | D = builder.get_bond_dimension() 36 | print("Maximum MPS bond dimension =", D) 37 | result_entry["D"] = D 38 | 39 | # analytic upper bound on the bond dimension 40 | Dupper = np.prod([Mu[i] + 2 for i in range(len(Mu))]) 41 | # print('Upper bound on MPS bond dimension =',Dupper) 42 | result_entry["Dupper"] = Dupper 43 | 44 | # compute all characters of Mu 45 | table_mps = {} 46 | t = time.time() 47 | for Lambda in Pn: 48 | table_mps[Lambda] = builder.get_character(Lambda) 49 | # runtime for computing the characters 50 | runtime_part2 = time.time() - t 51 | print("MPS runtime (part 1)=", "{0:.2f}".format(runtime_part1)) 52 | print("MPS runtime (part 2)=", "{0:.2f}".format(runtime_part2)) 53 | result_entry["MPS_runtime_part1"] = runtime_part1 54 | result_entry["MPS_runtime_part2"] = runtime_part2 55 | result_entry["table_mps"] = table_mps 56 | result.append(result_entry) 57 | print("###################################") 58 | 59 | 60 | SCRIPT_DIR = Path(__file__).parent.resolve() 61 | DATA_DIR = SCRIPT_DIR.parent / "DATA" 62 | 63 | path = DATA_DIR # data directory 64 | 65 | file_name = path / "mps_data.dat" 66 | 67 | 68 | with open(file_name, "wb") as fp: 69 | pickle.dump(result, fp) 70 | print("Done") 71 | 72 | print("file_name=", file_name) 73 | -------------------------------------------------------------------------------- /utils.py: -------------------------------------------------------------------------------- 1 | from math import lgamma, exp 2 | 3 | 4 | def partitions(n: int, minimum_partition_value: int = 1): 5 | """ 6 | Returns a generator that yields all partitions of n. 7 | 8 | Source: 9 | https://stackoverflow.com/questions/10035752/elegant-python-code-for-integer-partitioning 10 | 11 | Args: 12 | n (int): Integer to partition 13 | minimum_partition_value (int, optional): Minimum partition value. Defaults to 1. 14 | """ 15 | yield (n,) 16 | for i in range(minimum_partition_value, n // 2 + 1): 17 | for p in partitions(n - i, i): 18 | yield (i,) + p 19 | 20 | 21 | def get_partitions(n: int) -> list[tuple[int]]: 22 | """ 23 | Returns a list of all partitions of n. 24 | 25 | Args: 26 | n (int): Integer to partition 27 | 28 | Returns: 29 | list[tuple[int]]: List of all partitions of n. 30 | """ 31 | return [tuple(reversed(list(p))) for p in list(partitions(n))] 32 | 33 | 34 | def perm_module_d(mu: tuple[int]) -> int: 35 | """ 36 | Returns the dimension of the permutation module of label Mu 37 | 38 | Args: 39 | Mu (tuple[int]): Partition as a list of positive integers in nonincreasing order that sum up to n. 40 | 41 | Returns: 42 | int: Dimension of the permutation module of label Mu 43 | """ 44 | val = lgamma(sum(mu) + 1) 45 | for part in mu: 46 | val -= lgamma(part + 1) 47 | return int(round(exp(val))) 48 | 49 | 50 | def majorize(mu: tuple[int], ell: tuple[int], eq=True) -> bool: 51 | """ 52 | Determines if lambda >= Mu in majorization order 53 | 54 | Args: 55 | mu (tuple[int]): Partition as a list of positive integers in nonincreasing order. 56 | ell (tuple[int]): Partition as a list of positive integers in nonincreasing order. 57 | eq: Flag for requiring that Mu and Lambda are partitions of the same number 58 | 59 | Returns: 60 | bool: True if Lambda >= Mu in majorization order, False otherwise. 61 | """ 62 | sum_mu = 0 63 | sum_lm = 0 64 | 65 | for i in range(min(len(ell), len(mu))): 66 | sum_mu += mu[i] 67 | sum_lm += ell[i] 68 | if sum_mu > sum_lm: 69 | return False 70 | 71 | remaining_mu = sum(mu[i:]) 72 | remaining_lm = sum(ell[i:]) 73 | if eq: 74 | return sum_mu + remaining_mu == sum_lm + remaining_lm 75 | return True 76 | -------------------------------------------------------------------------------- /tests/kostka_test.py: -------------------------------------------------------------------------------- 1 | """ 2 | Validation of skew and standard Kostka numbers 3 | Tested m < 8 and >= 8 (caching is handled different in these cases) 4 | """ 5 | 6 | from kostka_builder import KostkaBuilder 7 | from utils import get_partitions 8 | from sage.all import * 9 | import random as rm 10 | 11 | P7 = utils.get_partitions(7) 12 | P8 = utils.get_partitions(8) 13 | 14 | # Standard Kostka Numbers 15 | 16 | Mu7 = rm.choice(P7) 17 | bd_7 = KostkaBuilder(Mu7) 18 | table_7 = {} 19 | errors_7 = 0 20 | for Lambda in P7: 21 | table_7[Lambda] = bd_7.get_kostka(Lambda) 22 | if table_7[Lambda] != symmetrica.kostka_number(Lambda, Mu7): 23 | errors_7 += 1 24 | print("Errors for Kostkas with Mu=" + str(Mu7) + ": ", errors_7) 25 | 26 | 27 | Mu8 = rm.choice(P8) 28 | bd_8 = KostkaBuilder(Mu8) 29 | table_8 = {} 30 | errors_8 = 0 31 | for Lambda in P8: 32 | table_8[Lambda] = bd_8.get_kostka(Lambda) 33 | if table_8[Lambda] != symmetrica.kostka_number(Lambda, Mu8): 34 | errors_8 += 1 35 | print("Errors for Kostkas with Mu=" + str(Mu8) + ": ", errors_8) 36 | 37 | 38 | # Skew Kostka Numbers: 39 | 40 | # correct skew Kostka numbers for 41 | # Mu = (2,1,1) 42 | # Nu = (2,1) 43 | skew_test_7 = { 44 | (7,): 0, 45 | (6, 1): 1, 46 | (5, 1, 1): 3, 47 | (4, 1, 1, 1): 3, 48 | (3, 1, 1, 1, 1): 1, 49 | (2, 1, 1, 1, 1, 1): 0, 50 | (1, 1, 1, 1, 1, 1, 1): 0, 51 | (2, 2, 1, 1, 1): 1, 52 | (3, 2, 1, 1): 5, 53 | (4, 2, 1): 7, 54 | (2, 2, 2, 1): 2, 55 | (3, 3, 1): 4, 56 | (5, 2): 3, 57 | (3, 2, 2): 4, 58 | (4, 3): 3, 59 | } 60 | 61 | bd_skew_7 = KostkaBuilder((2, 1, 1), Nu=(2, 1)) 62 | table_skew_7 = {} 63 | errors_skew_7 = 0 64 | for Lambda in P7: 65 | table_skew_7[Lambda] = bd_skew_7.get_kostka(Lambda) 66 | if table_skew_7[Lambda] != skew_test_7[Lambda]: 67 | errors_skew_7 += 1 68 | print("Errors for skew Kostkas of size n=7: ", errors_skew_7) 69 | 70 | # correct skew Kostka numbers for 71 | # Mu = (2,1,1) 72 | # Nu = (2,1,1) 73 | skew_test_8 = { 74 | (8,): 0, 75 | (7, 1): 0, 76 | (6, 1, 1): 1, 77 | (5, 1, 1, 1): 3, 78 | (4, 1, 1, 1, 1): 3, 79 | (3, 1, 1, 1, 1, 1): 1, 80 | (2, 1, 1, 1, 1, 1, 1): 0, 81 | (1, 1, 1, 1, 1, 1, 1, 1): 0, 82 | (2, 2, 1, 1, 1, 1): 1, 83 | (3, 2, 1, 1, 1): 5, 84 | (4, 2, 1, 1): 7, 85 | (2, 2, 2, 1, 1): 2, 86 | (3, 3, 1, 1): 4, 87 | (5, 2, 1): 3, 88 | (3, 2, 2, 1): 5, 89 | (4, 3, 1): 3, 90 | (6, 2): 0, 91 | (4, 2, 2): 3, 92 | (2, 2, 2, 2): 1, 93 | (3, 3, 2): 2, 94 | (5, 3): 0, 95 | (4, 4): 0, 96 | } 97 | 98 | bd_skew_8 = KostkaBuilder((2, 1, 1), Nu=(2, 1, 1)) 99 | table_skew_8 = {} 100 | errors_skew_8 = 0 101 | for Lambda in P8: 102 | table_skew_8[Lambda] = bd_skew_8.get_kostka(Lambda) 103 | if table_skew_8[Lambda] != skew_test_8[Lambda]: 104 | errors_skew_8 += 1 105 | print("Errors for skew Kostkas of size n=8: ", errors_skew_8) 106 | -------------------------------------------------------------------------------- /champs/kostka_builder.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import mpnum as mp # MPS/MPO package 3 | from utils import majorize 4 | import champs 5 | 6 | 7 | class KostkaBuilder(champs.Builder): 8 | """ 9 | MPS algorithm for skew Kostka numbers. Computes Kostkas for a given weight vector Mu and skew Nu. 10 | Set Nu = (0,) for standard Kostka numbers. 11 | 12 | Args: 13 | Mu (tuple[int]): we assume that Mu is given in non-increasing order 14 | Nu (tuple[int], optional): _description_. Defaults to (0, ). 15 | relerr (_type_, optional): _description_. Defaults to 1e-14. 16 | """ 17 | 18 | def __init__(self, Mu: tuple[int], Nu: tuple[int] = (0,), relerr=1e-14): 19 | super().__init__(Mu, Nu, relerr=relerr) 20 | 21 | # Computes the skew Kostka K_Lambda\Nu,Mu for a partition Lambda 22 | # Input: 23 | # Lambda: a non-increasing list of positive integers summing to n 24 | 25 | def get_kostka(self, Lambda: tuple[int]) -> int: 26 | r""" 27 | Computes the Kostka K_lambda,Mu for a partition Lambda 28 | 29 | Args: 30 | Lambda (tuple[int]): Partition as a list of positive integers in nonincreasing order that sum up to n. 31 | 32 | Returns: 33 | int: Kostka number K_lambda,Mu or (skew Kostka K_(Lambda / Nu),Mu if Nu != (0,)) 34 | """ 35 | # check majorization or if lambda \ nu is a valid skew partition 36 | if self.Nu == (0,) and not majorize(self.Mu, Lambda): # standard Kostka 37 | return 0 38 | elif not self.valid_skew(Lambda): # skew Kostka 39 | return 0 40 | 41 | return int(np.round(self._contract(Lambda))) 42 | 43 | # Returns a MPO representing (operator) complete symmetric polynomials 44 | def get_MPO(self, k): 45 | array = [] 46 | # index ordering LUDR 47 | 48 | # left boundary 49 | tensor = np.zeros((1, 2, 2, 2 * k + 1)) 50 | tensor[0, :, :, 0] = np.eye(2) 51 | tensor[0, :, :, 1] = np.array([[0, 1], [0, 0]]) # annihilate 52 | array.append(tensor) 53 | 54 | # bulk 55 | tensor = np.zeros((2 * k + 1, 2, 2, 2 * k + 1)) 56 | for i in range(k - 1): # runs until k-2 57 | tensor[2 * i, :, :, 2 * i] = np.eye(2) 58 | tensor[2 * i + 1, :, :, 2 * i + 2] = np.array([[0, 0], [1, 0]]) 59 | tensor[2 * i + 1, :, :, 2 * i + 3] = np.array([[1, 0], [0, 0]]) 60 | tensor[2 * i, :, :, 2 * i + 1] = np.array([[0, 1], [0, 0]]) 61 | 62 | tensor[2 * k - 2, :, :, 2 * k - 2] = np.eye(2) 63 | tensor[2 * k - 2, :, :, 2 * k - 1] = np.array([[0, 1], [0, 0]]) 64 | tensor[2 * k - 1, :, :, 2 * k] = np.array([[0, 0], [1, 0]]) 65 | tensor[2 * k, :, :, 2 * k] = np.eye(2) 66 | 67 | array = array + (2 * self.m - 2) * [tensor] 68 | 69 | # right boundary 70 | tensor = np.zeros((2 * k + 1, 2, 2, 1)) 71 | tensor[2 * k, :, :, 0] = np.eye(2) 72 | tensor[2 * k - 1, :, :, 0] = np.array([[0, 0], [1, 0]]) # create 73 | array.append(tensor) 74 | 75 | return mp.MPArray(mp.mpstruct.LocalTensors(array)) 76 | -------------------------------------------------------------------------------- /plotting/make_plot.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import pickle 3 | import matplotlib.pyplot as plt 4 | from utils import get_partitions 5 | from pathlib import Path 6 | 7 | SCRIPT_DIR = Path(__file__).parent.resolve() 8 | DATA_DIR = SCRIPT_DIR.parent / "DATA" 9 | FIG_DIR = SCRIPT_DIR.parent / "FIGS" 10 | 11 | path = DATA_DIR # data directory 12 | 13 | file_name = "mps_data.dat" 14 | with open(path / file_name, "rb") as fp: 15 | result_mps = pickle.load(fp) 16 | 17 | file_name = "sage_data.dat" 18 | with open(path / file_name, "rb") as fp: 19 | result_sage = pickle.load(fp) 20 | 21 | num_trials = len(result_mps) 22 | assert len(result_sage) == num_trials 23 | 24 | # GAP runtimes (in milliseconds) 25 | runtime_gap = {} 26 | runtime_gap[10] = 25 27 | runtime_gap[12] = 58 28 | runtime_gap[14] = 115 29 | runtime_gap[16] = 234 30 | runtime_gap[18] = 443 31 | runtime_gap[20] = 874 32 | runtime_gap[22] = 1558 33 | runtime_gap[24] = 2757 34 | runtime_gap[26] = 4811 35 | runtime_gap[28] = 8340 36 | runtime_gap[30] = 14023 37 | 38 | 39 | # minimum n to include in the plot 40 | n_min = 10 41 | result_mps = [ 42 | result_mps[i] for i in range(num_trials) if np.sum(result_mps[i]["Mu"]) >= n_min 43 | ] 44 | result_sage = [ 45 | result_sage[i] for i in range(num_trials) if np.sum(result_sage[i]["Mu"]) >= n_min 46 | ] 47 | num_trials = len(result_mps) 48 | assert num_trials == len(result_sage) 49 | 50 | 51 | data = {} 52 | data["MPS"] = [ 53 | result_mps[i]["MPS_runtime_part1"] + result_mps[i]["MPS_runtime_part2"] 54 | for i in range(num_trials) 55 | ] 56 | data["GAP"] = [runtime_gap[i] / 1000 for i in runtime_gap] 57 | data["SAGE"] = [result_sage[i]["sage_runtime"] for i in range(num_trials)] 58 | 59 | 60 | # compute maximum approximation error 61 | err_max_full = 0 62 | for i in range(num_trials): 63 | Mu = result_mps[i]["Mu"] 64 | assert Mu == result_sage[i]["Mu"] 65 | n = np.sum(Mu) 66 | print("n=", n) 67 | Pn = get_partitions(n) 68 | err_max = 0 69 | for Lambda in Pn: 70 | chi_mps = result_mps[i]["table_mps"][Lambda] 71 | chi_sage = result_sage[i]["table"][Lambda] 72 | err = np.abs(chi_mps - chi_sage) 73 | err_max = max(err, err_max) 74 | print("Mu=", Mu, "maximum approximation error=", err_max) 75 | err_max_full = max(err_max_full, err_max) 76 | print("Full maximum approximation error = ", err_max_full) 77 | 78 | 79 | x = np.arange(num_trials) # the label locations 80 | width = 0.25 # the width of the bars 81 | multiplier = 0 82 | 83 | assert len(x) == len(data["MPS"]) 84 | assert len(x) == len(data["SAGE"]) 85 | assert len(x) == len(data["GAP"]) 86 | 87 | plt.rcParams["font.size"] = "20" 88 | plt.figure(figsize=(10, 8)) 89 | ax = plt.gca() 90 | 91 | for name, runtime in data.items(): 92 | offset = width * multiplier 93 | rects = ax.bar(x + offset, runtime, width, label=name) 94 | multiplier += 1 95 | 96 | 97 | partitions = [str(np.sum(result_mps[i]["Mu"])) for i in range(num_trials)] 98 | 99 | 100 | ax.set_ylabel("Runtime (seconds)") 101 | ax.set_xticks(x + width, partitions) 102 | ax.legend(loc="upper left", ncols=3) 103 | plt.yscale("log") 104 | ax.set_xlabel("n") 105 | ax.grid(True) 106 | plt.show() 107 | plt.savefig(FIG_DIR / "char.pdf") 108 | -------------------------------------------------------------------------------- /timing/timing_kostka.py: -------------------------------------------------------------------------------- 1 | import time 2 | from sage.all import * 3 | from champs.kostka_builder import KostkaBuilder 4 | from utils import get_partitions 5 | import random as rm 6 | import numpy as np 7 | from pathlib import Path 8 | import pickle 9 | 10 | # Code to time and compare the MPS Kostka algorithm to symmetrica 11 | SCRIPT_DIR = Path(__file__).parent.resolve() 12 | DATA_DIR = SCRIPT_DIR.parent / "DATA" 13 | 14 | path = DATA_DIR # data directory 15 | file_name = path / "kostka_short.dat" 16 | 17 | start = 10 18 | stop = 40 # non inclusive 19 | step = 4 20 | relerr = 1e-14 21 | its = 100 # number of iterations per size 22 | 23 | 24 | def trial_mps(Mu:tuple[int], Pn:list[tuple[int]]) -> tuple[float, dict]: 25 | """ 26 | Compute Kostka numbers using MPS algorithm for a given weight vector Mu and partitions Pn. 27 | 28 | Args: 29 | Mu (tuple[int]): weight vector 30 | Pn (tuple[int]): partitions 31 | 32 | Returns: 33 | tuple[float, dict]: elapsed time and dictionary of Kostka numbers 34 | """ 35 | table_mps = {} 36 | t = time.time() 37 | builder = KostkaBuilder(Mu, relerr=relerr) 38 | for Lambda in Pn: 39 | table_mps[Lambda] = builder.get_kostka(Lambda) 40 | return time.time() - t, table_mps 41 | 42 | 43 | def trial_sage(Mu:tuple[int], Pn:list[tuple[int]]) -> tuple[float, dict]: 44 | """ 45 | Compute Kostka numbers using SAGE for a given weight vector Mu and partitions Pn. 46 | 47 | Args: 48 | Mu (tuple[int]): weight vector 49 | Pn (list[tuple[int]]): partitions 50 | 51 | Returns: 52 | tuple[float, dict]: elapsed time and dictionary of Kostka numbers 53 | """ 54 | table_sage = {} 55 | t = time.time() 56 | for Lambda in Pn: 57 | table_sage[Lambda] = symmetrica.kostka_number(Lambda, Mu) 58 | return time.time() - t, table_sage 59 | 60 | 61 | results = [] # array of dictionaries 62 | 63 | for n in range(start, stop, step): 64 | arr = () 65 | Pn = get_partitions(n) 66 | 67 | print("Running Kostkas for partitions of length " + str(n)) 68 | 69 | # run time trials 70 | for i in range(its): 71 | print("Iteration: " + str(i), end="\r") 72 | Mu = rm.choice(Pn) # random Mu 73 | while len(Mu) > int(n / 3): # require Mu to be "short" 74 | Mu = rm.choice(Pn) 75 | 76 | elapsed_mps, table_mps = trial_mps(Mu, Pn) 77 | elapsed_sage, table_sage = trial_sage(Mu, Pn) 78 | 79 | # check for errors 80 | max_error = 0 81 | num_error = 0 82 | if table_mps and table_sage: 83 | for Lambda in Pn: 84 | tmp = np.abs(table_mps[Lambda] - table_sage[Lambda]) 85 | if tmp > max_error: 86 | max_error = tmp 87 | if tmp >= 0.5: 88 | num_error += 1 89 | 90 | # MPS data 91 | results.append( 92 | { 93 | "n": n, 94 | "Algorithm": "MPS", 95 | "Runtime": elapsed_mps, 96 | "Mu": Mu, 97 | "Errors": num_error, 98 | "Max error": max_error, 99 | "Relerr": relerr, 100 | } 101 | ) 102 | 103 | # SAGE data 104 | results.append( 105 | { 106 | "n": n, 107 | "Algorithm": "SAGE", 108 | "Runtime": elapsed_sage, 109 | "Mu": Mu, 110 | "Errors": 0, 111 | "Max error": 0, 112 | "Relerr": 0, 113 | } 114 | ) 115 | 116 | print("################################################") 117 | 118 | with open(file_name, "wb") as fp: 119 | pickle.dump(results, fp) 120 | print("Done") 121 | print("file_name=", file_name) 122 | -------------------------------------------------------------------------------- /experiments.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "# Experiments\n", 8 | "\n", 9 | "The following notebook reproduces the data and code used in the manuscript.\n", 10 | "\n", 11 | "Run this with SageMath kernel. You can install it from [here](https://www.sagemath.org).\n", 12 | "\n", 13 | "SageMath kernel may not see your virtual environment. \n", 14 | "You can install the minimum requirements for executing the timing experiments by executing the following code: " 15 | ] 16 | }, 17 | { 18 | "cell_type": "code", 19 | "execution_count": null, 20 | "metadata": { 21 | "vscode": { 22 | "languageId": "python" 23 | } 24 | }, 25 | "outputs": [], 26 | "source": [ 27 | "%pip install -q -r requirements.txt" 28 | ] 29 | }, 30 | { 31 | "cell_type": "markdown", 32 | "metadata": {}, 33 | "source": [ 34 | "## Characters" 35 | ] 36 | }, 37 | { 38 | "cell_type": "markdown", 39 | "metadata": {}, 40 | "source": [ 41 | "### Timing\n", 42 | "\n", 43 | "#### MPS \n", 44 | "\n", 45 | "Runs the MPS character timing. This generates data to reproduce Fig. 1 in the manuscript.\n", 46 | "\n", 47 | "The character computation timing experiments take about 16 minutes on M1 Mac with 64GB memory. " 48 | ] 49 | }, 50 | { 51 | "cell_type": "code", 52 | "execution_count": null, 53 | "metadata": { 54 | "vscode": { 55 | "languageId": "python" 56 | } 57 | }, 58 | "outputs": [], 59 | "source": [ 60 | "%run ./timing/timing_mps.py" 61 | ] 62 | }, 63 | { 64 | "cell_type": "markdown", 65 | "metadata": {}, 66 | "source": [ 67 | "#### SAGE and GAP\n", 68 | "\n", 69 | "Note that Sage also executes the GAP code in timing_gap.g. " 70 | ] 71 | }, 72 | { 73 | "cell_type": "code", 74 | "execution_count": null, 75 | "metadata": { 76 | "vscode": { 77 | "languageId": "python" 78 | } 79 | }, 80 | "outputs": [], 81 | "source": [ 82 | "%run ./timing/timing_sage.py" 83 | ] 84 | }, 85 | { 86 | "cell_type": "markdown", 87 | "metadata": {}, 88 | "source": [ 89 | "### Plotting\n", 90 | "\n", 91 | "#### Runtime Comparison" 92 | ] 93 | }, 94 | { 95 | "cell_type": "code", 96 | "execution_count": null, 97 | "metadata": { 98 | "vscode": { 99 | "languageId": "python" 100 | } 101 | }, 102 | "outputs": [], 103 | "source": [ 104 | "%run ./plotting/make_plot.py" 105 | ] 106 | }, 107 | { 108 | "cell_type": "markdown", 109 | "metadata": {}, 110 | "source": [ 111 | "#### Bond Dimension" 112 | ] 113 | }, 114 | { 115 | "cell_type": "code", 116 | "execution_count": null, 117 | "metadata": { 118 | "vscode": { 119 | "languageId": "python" 120 | } 121 | }, 122 | "outputs": [], 123 | "source": [ 124 | "%run ./plotting/make_plot_dim.py" 125 | ] 126 | }, 127 | { 128 | "cell_type": "markdown", 129 | "metadata": {}, 130 | "source": [ 131 | "## Kostka\n", 132 | "\n", 133 | "### Timing\n", 134 | "\n", 135 | "#### MPS\n", 136 | "\n", 137 | "Kostka number MPS algorithm timing experiment. Note: this takes about an hour on a M1 Mac with 64GB RAM." 138 | ] 139 | }, 140 | { 141 | "cell_type": "code", 142 | "execution_count": null, 143 | "metadata": { 144 | "vscode": { 145 | "languageId": "python" 146 | } 147 | }, 148 | "outputs": [], 149 | "source": [ 150 | "%run ./timing/timing_kostka.py" 151 | ] 152 | }, 153 | { 154 | "cell_type": "markdown", 155 | "metadata": {}, 156 | "source": [ 157 | "### Plotting" 158 | ] 159 | }, 160 | { 161 | "cell_type": "code", 162 | "execution_count": null, 163 | "metadata": { 164 | "vscode": { 165 | "languageId": "python" 166 | } 167 | }, 168 | "outputs": [], 169 | "source": [ 170 | "%run ./plotting/make_plot_kostka.py" 171 | ] 172 | } 173 | ], 174 | "metadata": { 175 | "kernelspec": { 176 | "display_name": "SageMath 10.3", 177 | "language": "sage", 178 | "name": "SageMath-10.3" 179 | }, 180 | "language_info": { 181 | "codemirror_mode": { 182 | "name": "ipython", 183 | "version": 3 184 | }, 185 | "file_extension": ".py", 186 | "mimetype": "text/x-python", 187 | "name": "sage", 188 | "nbconvert_exporter": "python", 189 | "pygments_lexer": "ipython3", 190 | "version": "3.11.8" 191 | } 192 | }, 193 | "nbformat": 4, 194 | "nbformat_minor": 2 195 | } 196 | -------------------------------------------------------------------------------- /champs/character_builder.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | # NOTE: MPNUM is no longer maintained, but it's still a good package for MPS/MPO simulations! 4 | # The following fixes dependency issues for numpy 2.0 5 | if np.version.version > "2.0": 6 | np.float_ = np.float64 7 | np.complex_ = np.complex128 8 | 9 | # The following fixes dependency issues for python >= 3.7 10 | import sys 11 | import collections 12 | 13 | if sys.version_info[0] >= 3 and sys.version_info[1] >= 7: 14 | collections.Sequence = collections.abc.Sequence 15 | collections.Iterable = collections.abc.Iterable 16 | collections.Iterator = collections.abc.Iterator 17 | 18 | import mpnum as mp # MPS/MPO simulation package 19 | import champs 20 | 21 | import quimb.tensor as qtn 22 | from champs.builder import QUIMB_BACKEND, MPNUM_BACKEND 23 | 24 | 25 | class CharacterBuilder(champs.Builder): 26 | def __init__( 27 | self, 28 | Mu: tuple[int], 29 | Nu: tuple[int] = (0,), 30 | relerr: float = 1e-10, 31 | backend: str = MPNUM_BACKEND, 32 | ): 33 | """ 34 | MPS algorithm for characters of the symmetric group S_n described in arXiv:2501.???? 35 | 36 | Takes as input a conjugacy class Mu of S_n specified as a list of 37 | positive integers that sum to n 38 | 39 | Args: 40 | Mu (tuple[int]): S_n conjugacy class as a list of positive integers that sum up to n. 41 | relerr (float, optional): MPS compression relative error. Defaults to 1e-10. 42 | """ 43 | 44 | super().__init__(Mu=Mu, Nu=Nu, relerr=relerr, backend=backend) 45 | 46 | def get_character(self, Lambda: tuple[int]) -> int: 47 | """ 48 | Computes the character chi_Lambda(Mu) for a conjugacy class Mu and an irrep Lambda of S_n 49 | Note that the conjugacy class Mu is fixed by the CharacterBuilder object. Caches the partial products of MPS matrices over each interval to speed up the computation. 50 | 51 | Args: 52 | Lambda (tuple[int]): an irrep of S_n as a list of positive integers that sums up to n. 53 | 54 | Returns: 55 | int: character chi_Lambda(Mu) 56 | """ 57 | assert len(Lambda) <= self.n 58 | 59 | return int(np.round(self._contract(Lambda))) 60 | 61 | def _get_MPNUM_MPO(self, k: int) -> mp.MPArray: 62 | """ 63 | MPO representation of the current operator J_k = sum_i a_i a_{i+k}^dag. 64 | 65 | Uses the MPNUM package to build the MPO. 66 | 67 | Args: 68 | k (int): parameter specifying the current operator J_k. 69 | 70 | Returns: 71 | mp.MPArray: MPO representation of the current operator J_k. 72 | """ 73 | array = [] 74 | 75 | # left boundary 76 | tensor = np.zeros((1, 2, 2, k + 2)) 77 | tensor[0, :, :, 0] = np.eye(2) 78 | tensor[0, :, :, 1] = np.array([[0, 1], [0, 0]]) # flip qubit from '1' to '0' 79 | array.append(tensor) 80 | 81 | # bulk 82 | tensor = np.zeros((k + 2, 2, 2, k + 2)) # index ordering Left Right Up Down 83 | tensor[0, :, :, 0] = np.eye(2) 84 | tensor[k + 1, :, :, k + 1] = np.eye(2) 85 | tensor[0, :, :, 1] = np.array([[0, 1], [0, 0]]) # flip qubit from '1' to '0' 86 | tensor[k, :, :, k + 1] = np.array( 87 | [[0, 0], [1, 0]] 88 | ) # flip qubit from '0' to '1' 89 | 90 | # Pauli Z 91 | for j in range(1, k): 92 | tensor[j, :, :, j + 1] = np.array([[1, 0], [0, -1]]) 93 | 94 | array = array + (2 * self.m - 2) * [tensor] 95 | 96 | # right boundary 97 | tensor = np.zeros((k + 2, 2, 2, 1)) 98 | tensor[k + 1, :, :, 0] = np.eye(2) 99 | # flip qubit from '0' to '1' 100 | tensor[k, :, :, 0] = np.array([[0, 0], [1, 0]]) 101 | array.append(tensor) 102 | return mp.MPArray(mp.mpstruct.LocalTensors(array)) 103 | 104 | def _get_QUIMB_MPO(self, k: int) -> qtn.tensor_1d.MatrixProductOperator: 105 | """ 106 | MPO representation of the current operator J_k = sum_i a_i a_{i+k}^dag. 107 | Uses the QUIMB package to build the MPO. 108 | 109 | Args: 110 | k (int): parameter specifying the current operator J_k. 111 | 112 | Returns: 113 | qtn.tensor_1d.MatrixProductOperator: MPO representation of the current operator J_k. 114 | """ 115 | array = [] 116 | 117 | # left boundary 118 | tensor = np.zeros((k + 2, 2, 2)) 119 | tensor[0, :, :] = np.eye(2) 120 | # flip qubit from '1' to '0' 121 | tensor[1, :, :] = np.array([[0, 1], [0, 0]]) 122 | array.append(tensor) 123 | 124 | # bulk 125 | tensor = np.zeros((k + 2, k + 2, 2, 2)) 126 | tensor[0, 0, :, :] = np.eye(2) 127 | tensor[k + 1, k + 1, :, :] = np.eye(2) 128 | # flip qubit from '1' to '0' 129 | tensor[0, 1, :, :] = np.array([[0, 1], [0, 0]]) 130 | # flip qubit from '0' to '1' 131 | tensor[k, k + 1, :, :] = np.array([[0, 0], [1, 0]]) 132 | 133 | # Pauli Z 134 | for j in range(1, k): 135 | tensor[j, j + 1, :, :] = np.array([[1, 0], [0, -1]]) 136 | 137 | array = array + (2 * self.n - 2) * [tensor] 138 | 139 | # right boundary 140 | tensor = np.zeros((k + 2, 2, 2)) 141 | tensor[k + 1, :, :] = np.eye(2) 142 | # flip qubit from '0' to '1' 143 | tensor[k, :, :] = np.array([[0, 0], [1, 0]]) 144 | array.append(tensor) 145 | 146 | return qtn.tensor_1d.MatrixProductOperator( 147 | array, 148 | shape="lrud", 149 | tags=self.qubits, 150 | upper_ind_id="k{}", 151 | lower_ind_id="b{}", 152 | site_tag_id="I{}", 153 | ) 154 | 155 | def get_MPO(self, k: int) -> mp.MPArray | qtn.tensor_1d.MatrixProductOperator: 156 | """ 157 | MPO representation of the current operator J_k = sum_i a_i a_{i+k}^dag. 158 | 159 | Args: 160 | k (int): parameter specifying the current operator J_k. 161 | 162 | Returns: 163 | mpnum.MPArray | qtn.tensor_1d.MatrixProductOperator: MPO representation of the current operator J_k. Return type depends on the self.backend. 164 | """ 165 | 166 | if self.backend == MPNUM_BACKEND: 167 | return self._get_MPNUM_MPO(k) 168 | 169 | elif self.backend == QUIMB_BACKEND: 170 | return self._get_QUIMB_MPO(k) 171 | -------------------------------------------------------------------------------- /champs/builder.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | # NOTE: MPNUM is no longer maintained, but it's still a good package for MPS/MPO simulations! 4 | # The following fixes dependency issues for numpy 2.0 5 | if np.version.version > "2.0": 6 | np.float_ = np.float64 7 | np.complex_ = np.complex128 8 | 9 | # The following fixes dependency issues for python >= 3.7 10 | import sys 11 | import collections 12 | 13 | if sys.version_info[0] >= 3 and sys.version_info[1] >= 7: 14 | collections.Sequence = collections.abc.Sequence 15 | collections.Iterable = collections.abc.Iterable 16 | collections.Iterator = collections.abc.Iterator 17 | 18 | # ----------------- QUIMB Imports ----------------- 19 | import quimb.tensor as qtn 20 | from quimb.tensor.tensor_1d_compress import mps_gate_with_mpo_direct 21 | 22 | # ----------------- MPNUM Imports ----------------- 23 | import mpnum as mp # MPS/MPO simulation package 24 | from utils import majorize 25 | 26 | 27 | # ----------------- MPNUM Constants ----------------- 28 | MPNUM_BACKEND = "mpnum" 29 | 30 | # Basis states for the mpnum backend 31 | MPNUM_UP: np.array = np.zeros((1, 2, 1)) 32 | MPNUM_UP[0, 1, 0] = 1 33 | 34 | MPNUM_DOWN: np.array = np.zeros((1, 2, 1)) 35 | MPNUM_DOWN[0, 0, 0] = 1 36 | 37 | # ----------------- QUIMB Constants ----------------- 38 | QUIMB_BACKEND = "quimb" 39 | 40 | # Basis states for the quimb backend 41 | QUIMB_UP_BOUNDARY: np.array = np.zeros((1, 2)) 42 | QUIMB_UP_BOUNDARY[0, 1] = 1 43 | 44 | QUIMB_UP_BULK: np.array = np.zeros((1, 1, 2)) 45 | QUIMB_UP_BULK[0, 0, 1] = 1 46 | 47 | QUIMB_DOWN_BULK: np.array = np.zeros((1, 1, 2)) 48 | QUIMB_DOWN_BULK[0, 0, 0] = 1 49 | 50 | QUIMB_DOWN_BOUNDARY: np.array = np.zeros((1, 2)) 51 | QUIMB_DOWN_BOUNDARY[0, 0] = 1 52 | 53 | 54 | class Builder: 55 | """ 56 | Defines the structure of the builders for 57 | Kostka numbers and Sn characters. 58 | 59 | Args: 60 | Mu (tuple[int]): S_n conjugacy class as a list of positive integers that sum up to n. 61 | Nu (tuple[int], optional): S_n conjugacy class as a list of positive integers that sum up to n. Defaults to (0,). Used to encode skew partitions Mu/Nu. 62 | relerr (float, optional): MPS compression relative error. Defaults to 1e-10. 63 | """ 64 | 65 | def __init__( 66 | self, 67 | Mu: tuple[int], 68 | Nu: tuple[int] = (0,), 69 | relerr: float = 1e-10, 70 | backend: str = MPNUM_BACKEND, 71 | ): 72 | self.Mu = Mu 73 | self.Nu = Nu 74 | self.n = np.sum(self.Mu) 75 | # size of Lambdas to evaluate will be n + sum(Nu) 76 | self.m = self.n + np.sum(self.Nu) # size of Hilbert space 77 | self.Nu = Nu + (0,) * (self.m - len(Nu)) # pad Nu with 0s 78 | 79 | self.relerr = relerr # relative error for MPS compression 80 | self.backend = backend 81 | self.maximum_rank = 1 82 | 83 | self.qubits = [ 84 | i for i in range(2 * self.m) 85 | ] # used by quimb to label the qubits 86 | 87 | # compute the MPS that encodes all characters of Mu 88 | self.mps = self.get_MPS() 89 | 90 | # CACHING: 91 | 92 | # divide the spin chain into four intervals: left (L), center left 93 | # (C1), center right C2, right (R) 94 | self.m1 = int(np.round(self.m / 2)) 95 | self.m2 = self.m 96 | self.m3 = int(np.round(3 * self.m / 2)) 97 | self.L = [i for i in range(2 * self.m) if i < self.m1] 98 | self.C1 = [i for i in range(2 * self.m) if i >= self.m1 and i < self.m2] 99 | self.C2 = [i for i in range(2 * self.m) if i >= self.m2 and i < self.m3] 100 | self.R = [i for i in range(2 * self.m) if i >= self.m3] 101 | self.mL = len(self.L) 102 | self.mC1 = len(self.C1) 103 | self.mC2 = len(self.C2) 104 | self.mR = len(self.R) 105 | # cache partial products of MPS matrices over each interval 106 | self.cacheL = {} 107 | self.cacheC1 = {} 108 | self.cacheC2 = {} 109 | self.cacheR = {} 110 | 111 | def get_MPS(self) -> mp.MPArray | qtn.tensor_1d.MatrixProductState: 112 | """ 113 | Compute the MPS that encodes all characters of Mu. 114 | 115 | Returns: 116 | mp.MPArray | qtn.tensor_1d.MatrixProductState: MPS that encodes all characters of Mu. The return type depends on self.backend. 117 | """ 118 | 119 | self.maximum_rank = 1 120 | mps = self.get_initial_MPS() 121 | 122 | if self.backend == MPNUM_BACKEND: 123 | for k in self.Mu: 124 | mpo = self.get_MPO(k) 125 | mps = mp.dot(mpo, mps) 126 | mps.compress(method="svd", relerr=self.relerr) 127 | self.maximum_rank = max(self.maximum_rank, np.max(mps.ranks)) 128 | 129 | elif self.backend == QUIMB_BACKEND: 130 | for k in self.Mu: 131 | mpo = self.get_MPO(k) 132 | mps_gate_with_mpo_direct( 133 | mps, mpo, cutoff=self.relerr, cutoff_mode="rsum1", inplace=True 134 | ) 135 | for q in self.qubits: 136 | if q == 0 or q == (2 * self.m - 1): 137 | D = mps.arrays[q].shape[0] 138 | else: 139 | D = max(mps.arrays[q].shape[0], mps.arrays[q].shape[1]) 140 | self.maximum_rank = max(D, self.maximum_rank) 141 | return mps 142 | 143 | def get_bond_dimension(self) -> int: 144 | """ 145 | Returns the maximum bond dimension (maximum Schmidt rank) of the MPS. 146 | 147 | Returns: 148 | int: _description_ 149 | """ 150 | return self.maximum_rank 151 | 152 | def get_initial_MPS(self) -> mp.MPArray | qtn.tensor_1d.MatrixProductState: 153 | """ 154 | Compute the MPS that encodes the initial state. 155 | 156 | Returns: 157 | mp.MPArray | qtn.tensor_1d.MatrixProductState: an MPS tensor representation of the initial state. The return type depends on self.backend. 158 | """ 159 | 160 | if self.backend == MPNUM_BACKEND: 161 | array = [] # Local tensors 162 | # Traverse Nu in reverse order 163 | array += [MPNUM_DOWN] * self.Nu[self.m - 1] # step right 164 | array += [MPNUM_UP] # step up 165 | for i in range(self.m - 1, 0, -1): 166 | array += [MPNUM_DOWN] * (self.Nu[i - 1] - self.Nu[i]) # step right 167 | array += [MPNUM_UP] # step up 168 | array = array + [MPNUM_DOWN] * (2 * self.m - len(array)) # step right 169 | return mp.MPArray(mp.mpstruct.LocalTensors(array)) 170 | 171 | # NOTE: for Nu = (0, ) the initial state is the vacuum state 172 | # MPS representation of the initial state |1^n 0^n> 173 | # array = self.n * [self.tensor1] + self.n * [self.tensor0] 174 | # mps = mp.MPArray(mp.mpstruct.LocalTensors(array)) 175 | 176 | elif self.backend == QUIMB_BACKEND: 177 | if sum(self.Nu) != 0: 178 | raise NotImplementedError( 179 | "Initial state for skew partitions with QUIMB backend is not implemented yet." 180 | ) 181 | 182 | array = ( 183 | [QUIMB_UP_BOUNDARY] 184 | + (self.m - 1) * [QUIMB_UP_BULK] 185 | + (self.m - 1) * [QUIMB_DOWN_BULK] 186 | + [QUIMB_DOWN_BOUNDARY] 187 | ) 188 | 189 | return qtn.tensor_1d.MatrixProductState( 190 | array, 191 | shape="lrp", 192 | tags=self.qubits, 193 | site_ind_id="k{}", 194 | site_tag_id="I{}", 195 | ) 196 | 197 | def get_MPO(self, k: int) -> mp.MPArray | qtn.tensor_1d.MatrixProductOperator: 198 | """ 199 | MPO representation of the current operator J_k = sum_i a_i a_{i+k}^dag. 200 | 201 | Args: 202 | k (int): parameter specifying the current operator J_k. 203 | 204 | Returns: 205 | mpnum.MPArray | qtn.tensor_1d.MatrixProductOperator: MPO representation of the current operator J_k. the return type depends on self.backend. 206 | """ 207 | raise NotImplementedError 208 | 209 | def valid_skew(self, Lambda: tuple[int]) -> bool: 210 | """ 211 | Determines if Lamba \ Nu has enough boxes to have weight Mu. 212 | 213 | Args: 214 | Lambda (tuple[int]): partition that defines the skew shape Lambda \ Nu. 215 | 216 | Returns: 217 | bool: True if Lambda \ Nu has enough boxes to have weight Mu, False otherwise. 218 | """ 219 | return majorize(self.Nu, Lambda, eq=False) and sum(Lambda) == self.m 220 | 221 | def _contract(self, Lambda: tuple[int]): 222 | """ 223 | 224 | Args: 225 | Lambda (tuple[int]): partition that defines the skew shape Lambda \ Nu 226 | 227 | Returns: 228 | float: the amplitude in the mps 229 | 230 | """ 231 | 232 | # Lambda must be a partition of m 233 | assert sum(Lambda) == self.m 234 | 235 | padded_Lambda = list(Lambda) + [0] * (self.m - len(Lambda)) 236 | 237 | if self.m < 8: 238 | # don't use caching for small m's 239 | array = [MPNUM_DOWN] * (2 * self.m) 240 | for i in range(self.m): 241 | array[padded_Lambda[i] + self.m - 1 - i] = MPNUM_UP 242 | basis_state_mps = mp.MPArray(mp.mpstruct.LocalTensors(array)) 243 | # compute inner product between a basis state and the MPS 244 | return mp.mparray.inner(basis_state_mps, self.mps) 245 | 246 | bitstring = np.zeros(2 * self.m, dtype=int) 247 | supp = [padded_Lambda[i] + self.m - i - 1 for i in range(self.m)] 248 | bitstring[supp] = 1 249 | # project bitstring onto each caching register 250 | xL = bitstring[self.L] 251 | xC1 = bitstring[self.C1] 252 | xC2 = bitstring[self.C2] 253 | xR = bitstring[self.R] 254 | 255 | if not (tuple(xL) in self.cacheL): 256 | self.cacheL[tuple(xL)] = np.linalg.multi_dot( 257 | [self.mps.lt[self.L[i]][:, xL[i], :] for i in range(self.mL)] 258 | ) 259 | 260 | if not (tuple(xC1) in self.cacheC1): 261 | self.cacheC1[tuple(xC1)] = np.linalg.multi_dot( 262 | [self.mps.lt[self.C1[i]][:, xC1[i], :] for i in range(self.mC1)] 263 | ) 264 | 265 | if not (tuple(xC2) in self.cacheC2): 266 | self.cacheC2[tuple(xC2)] = np.linalg.multi_dot( 267 | [self.mps.lt[self.C2[i]][:, xC2[i], :] for i in range(self.mC2)] 268 | ) 269 | 270 | if not (tuple(xR) in self.cacheR): 271 | self.cacheR[tuple(xR)] = np.linalg.multi_dot( 272 | [self.mps.lt[self.R[i]][:, xR[i], :] for i in range(self.mR)] 273 | ) 274 | 275 | chi = (self.cacheL[tuple(xL)] @ self.cacheC1[tuple(xC1)]) @ ( 276 | self.cacheC2[tuple(xC2)] @ self.cacheR[tuple(xR)] 277 | ) 278 | return chi[0][0] 279 | --------------------------------------------------------------------------------