├── .gitignore ├── 01knapsack.py ├── LICENSE ├── README.md ├── avl_tree.py ├── binary_indexed_tree.py ├── codeforces.py ├── cpp_extension ├── acl__document.py ├── acl_convolution.py ├── acl_lazysegtree.py ├── acl_math.py ├── acl_mcf_graph.py ├── acl_mf_graph.py ├── acl_segtree.py ├── acl_two_sat.py ├── ctypes_cpp_set.py ├── shorten.py ├── template.py ├── wrap_cpp_multiset.py └── wrap_cpp_set.py ├── fast_primality_test.py ├── geometry.py ├── numba_library.py ├── old.py ├── python2_template.py ├── segtree.py ├── snippet.py ├── snippet2.py └── square_skip_list.py /.gitignore: -------------------------------------------------------------------------------- 1 | .idea/ 2 | not_mine/ 3 | -------------------------------------------------------------------------------- /01knapsack.py: -------------------------------------------------------------------------------- 1 | # 枝刈り探索(分枝限定法) 2 | # 検証: https://atcoder.jp/contests/abc032/submissions/8197125 3 | 4 | class Knapsack: 5 | def __init__(self, VW): 6 | self.VW = VW 7 | self.VW.sort(key=lambda vw: vw[0] / vw[1], reverse=True) 8 | self.n = len(VW) 9 | 10 | def solve(self, capacity, ok=0): 11 | self.ok = ok 12 | self.capacity = capacity 13 | return self._dfs(0, 0, 0) 14 | 15 | def _dfs(self, i, v_now, w_now): 16 | if i==self.n: 17 | self.ok = max(self.ok, v_now) 18 | return v_now 19 | ng, f = self._solve_relaxation(i, self.capacity-w_now) 20 | ng += v_now 21 | if f: 22 | self.ok = max(self.ok, ng) 23 | return ng 24 | if ng < self.ok: 25 | return -float("inf") 26 | res = -float("inf") 27 | v, w = self.VW[i] 28 | if w_now + w <= self.capacity: 29 | res = max(res, self._dfs(i+1, v_now + v, w_now + w)) 30 | res = max(res, self._dfs(i+1, v_now, w_now)) 31 | return res 32 | 33 | def _solve_relaxation(self, i, capacity): 34 | res = 0 35 | f = True 36 | for v, w in self.VW[i:]: 37 | if capacity == 0: 38 | break 39 | if w <= capacity: 40 | capacity -= w 41 | res += v 42 | else: 43 | f = False 44 | res += v * (capacity / w) 45 | break 46 | return res, f 47 | 48 | def main(): 49 | N, W = map(int, input().split()) 50 | VW = [list(map(int, input().split())) for _ in range(N)] 51 | knapsack = Knapsack(VW) 52 | print(knapsack.solve(W)) 53 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Creative Commons Legal Code 2 | 3 | CC0 1.0 Universal 4 | 5 | CREATIVE COMMONS CORPORATION IS NOT A LAW FIRM AND DOES NOT PROVIDE 6 | LEGAL SERVICES. DISTRIBUTION OF THIS DOCUMENT DOES NOT CREATE AN 7 | ATTORNEY-CLIENT RELATIONSHIP. CREATIVE COMMONS PROVIDES THIS 8 | INFORMATION ON AN "AS-IS" BASIS. CREATIVE COMMONS MAKES NO WARRANTIES 9 | REGARDING THE USE OF THIS DOCUMENT OR THE INFORMATION OR WORKS 10 | PROVIDED HEREUNDER, AND DISCLAIMS LIABILITY FOR DAMAGES RESULTING FROM 11 | THE USE OF THIS DOCUMENT OR THE INFORMATION OR WORKS PROVIDED 12 | HEREUNDER. 13 | 14 | Statement of Purpose 15 | 16 | The laws of most jurisdictions throughout the world automatically confer 17 | exclusive Copyright and Related Rights (defined below) upon the creator 18 | and subsequent owner(s) (each and all, an "owner") of an original work of 19 | authorship and/or a database (each, a "Work"). 20 | 21 | Certain owners wish to permanently relinquish those rights to a Work for 22 | the purpose of contributing to a commons of creative, cultural and 23 | scientific works ("Commons") that the public can reliably and without fear 24 | of later claims of infringement build upon, modify, incorporate in other 25 | works, reuse and redistribute as freely as possible in any form whatsoever 26 | and for any purposes, including without limitation commercial purposes. 27 | These owners may contribute to the Commons to promote the ideal of a free 28 | culture and the further production of creative, cultural and scientific 29 | works, or to gain reputation or greater distribution for their Work in 30 | part through the use and efforts of others. 31 | 32 | For these and/or other purposes and motivations, and without any 33 | expectation of additional consideration or compensation, the person 34 | associating CC0 with a Work (the "Affirmer"), to the extent that he or she 35 | is an owner of Copyright and Related Rights in the Work, voluntarily 36 | elects to apply CC0 to the Work and publicly distribute the Work under its 37 | terms, with knowledge of his or her Copyright and Related Rights in the 38 | Work and the meaning and intended legal effect of CC0 on those rights. 39 | 40 | 1. Copyright and Related Rights. A Work made available under CC0 may be 41 | protected by copyright and related or neighboring rights ("Copyright and 42 | Related Rights"). Copyright and Related Rights include, but are not 43 | limited to, the following: 44 | 45 | i. the right to reproduce, adapt, distribute, perform, display, 46 | communicate, and translate a Work; 47 | ii. moral rights retained by the original author(s) and/or performer(s); 48 | iii. publicity and privacy rights pertaining to a person's image or 49 | likeness depicted in a Work; 50 | iv. rights protecting against unfair competition in regards to a Work, 51 | subject to the limitations in paragraph 4(a), below; 52 | v. rights protecting the extraction, dissemination, use and reuse of data 53 | in a Work; 54 | vi. database rights (such as those arising under Directive 96/9/EC of the 55 | European Parliament and of the Council of 11 March 1996 on the legal 56 | protection of databases, and under any national implementation 57 | thereof, including any amended or successor version of such 58 | directive); and 59 | vii. other similar, equivalent or corresponding rights throughout the 60 | world based on applicable law or treaty, and any national 61 | implementations thereof. 62 | 63 | 2. Waiver. To the greatest extent permitted by, but not in contravention 64 | of, applicable law, Affirmer hereby overtly, fully, permanently, 65 | irrevocably and unconditionally waives, abandons, and surrenders all of 66 | Affirmer's Copyright and Related Rights and associated claims and causes 67 | of action, whether now known or unknown (including existing as well as 68 | future claims and causes of action), in the Work (i) in all territories 69 | worldwide, (ii) for the maximum duration provided by applicable law or 70 | treaty (including future time extensions), (iii) in any current or future 71 | medium and for any number of copies, and (iv) for any purpose whatsoever, 72 | including without limitation commercial, advertising or promotional 73 | purposes (the "Waiver"). Affirmer makes the Waiver for the benefit of each 74 | member of the public at large and to the detriment of Affirmer's heirs and 75 | successors, fully intending that such Waiver shall not be subject to 76 | revocation, rescission, cancellation, termination, or any other legal or 77 | equitable action to disrupt the quiet enjoyment of the Work by the public 78 | as contemplated by Affirmer's express Statement of Purpose. 79 | 80 | 3. Public License Fallback. Should any part of the Waiver for any reason 81 | be judged legally invalid or ineffective under applicable law, then the 82 | Waiver shall be preserved to the maximum extent permitted taking into 83 | account Affirmer's express Statement of Purpose. In addition, to the 84 | extent the Waiver is so judged Affirmer hereby grants to each affected 85 | person a royalty-free, non transferable, non sublicensable, non exclusive, 86 | irrevocable and unconditional license to exercise Affirmer's Copyright and 87 | Related Rights in the Work (i) in all territories worldwide, (ii) for the 88 | maximum duration provided by applicable law or treaty (including future 89 | time extensions), (iii) in any current or future medium and for any number 90 | of copies, and (iv) for any purpose whatsoever, including without 91 | limitation commercial, advertising or promotional purposes (the 92 | "License"). The License shall be deemed effective as of the date CC0 was 93 | applied by Affirmer to the Work. Should any part of the License for any 94 | reason be judged legally invalid or ineffective under applicable law, such 95 | partial invalidity or ineffectiveness shall not invalidate the remainder 96 | of the License, and in such case Affirmer hereby affirms that he or she 97 | will not (i) exercise any of his or her remaining Copyright and Related 98 | Rights in the Work or (ii) assert any associated claims and causes of 99 | action with respect to the Work, in either case contrary to Affirmer's 100 | express Statement of Purpose. 101 | 102 | 4. Limitations and Disclaimers. 103 | 104 | a. No trademark or patent rights held by Affirmer are waived, abandoned, 105 | surrendered, licensed or otherwise affected by this document. 106 | b. Affirmer offers the Work as-is and makes no representations or 107 | warranties of any kind concerning the Work, express, implied, 108 | statutory or otherwise, including without limitation warranties of 109 | title, merchantability, fitness for a particular purpose, non 110 | infringement, or the absence of latent or other defects, accuracy, or 111 | the present or absence of errors, whether or not discoverable, all to 112 | the greatest extent permissible under applicable law. 113 | c. Affirmer disclaims responsibility for clearing rights of other persons 114 | that may apply to the Work or any use thereof, including without 115 | limitation any person's Copyright and Related Rights in the Work. 116 | Further, Affirmer disclaims responsibility for obtaining any necessary 117 | consents, permissions or other rights required for any use of the 118 | Work. 119 | d. Affirmer understands and acknowledges that Creative Commons is not a 120 | party to this document and has no duty or obligation with respect to 121 | this CC0 or use of the Work. 122 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # snippet 2 | 競プロの 3 | 4 | ##### 01knapsack.py 5 | 6 | - 分枝限定法 7 | 8 | ##### avl_tree.py 9 | 10 | - AVL 木(非推奨、square_skip_list.py を使うべき) 11 | 12 | ##### binary_indexed_tree.py 13 | 14 | - Binary Indexed Tree 15 | 16 | ##### fast_primality_test.py 17 | 18 | - 高速素数判定 19 | 20 | ##### geometry.py 21 | 22 | - 点と直線の距離 23 | - 2 線分の交差判定 24 | - Monotone Chain (凸包) 25 | - 最小包含円 26 | - 円と多角形の共通部分の面積 27 | 28 | ##### numba_library.py 29 | 30 | - Numba コンパイルのテンプレート 31 | - Numba で高速化したライブラリ 32 | 33 | ##### segtree.py 34 | 35 | - Segment Tree いろいろ 36 | 37 | ##### snippet.py 38 | - 拡張ユークリッド互除法 39 | - 中国剰余定理 40 | - mod 逆元 41 | - 組み合わせ計算 42 | - C (組み合わせ) 43 | - P (順列) 44 | - H (重複組み合わせ) 45 | - 上昇階乗冪 46 | - 第 1 種スターリング数 47 | - 第 2 種スターリング数 48 | - Balls and Boxes 3 (第2種スターリング数 \* k!) 49 | - ベルヌーイ数 50 | - ファウルハーバーの公式 51 | - Lah number 52 | - ベル数 53 | - 素数リスト作成 54 | - 素因数分解 (愚直) 55 | - 素因数分解 (ロー法) 56 | - ミラーラビン素数判定法 57 | - BIT (Binary Indexed Tree) 58 | - BIT でのいもす法 59 | - 区間加算 60 | - ダイクストラ法 61 | - SPFA (Shortest Path Faster Algorithm) 62 | - Union Find 木 63 | - 高速ゼータ変換 64 | - (倍数集合の高速ゼータ変換) 65 | - Segment Tree 66 | - Manacher (最長回文) 67 | - Z-algorithm 68 | - Dinic 法 (最大流問題) 69 | - LIS (最長増加部分列) 70 | - LCA (ダブリング) 71 | - ニュートン補間 72 | - X = [0, 1, 2, ... , n] のラグランジュ補間 73 | - ファウルハーバーの公式 74 | - ローリングハッシュ 75 | - 畳み込み (NumPy) 76 | - Garner のアルゴリズム (NumPy) 77 | - 任意 mod 畳み込み (Numpy) 78 | - n 個を k 人に均等/貪欲に分けるだけ 79 | - xorshift (乱数生成) 80 | - (重み付き Union Find) 81 | - (Trie) 82 | - (全方位木 DP) 83 | 84 | 85 | ##### snippet2.py 86 | 87 | - 分割数 88 | - 原始ピタゴラス数の列挙 89 | - オイラーの Φ 関数 90 | - sqrt(n) の連分数展開 91 | - ペル方程式 x\*x - n\*y\*y = 1 の最小整数解 92 | - ペル方程式の拡張 x\*x - n\*y\*y = -1 の最小整数解 93 | - カレンダー 94 | - ツェラーの公式 95 | - 階乗進数 96 | - 順列の辞書順 i 番目 97 | - ラグランジュ補間 98 | - 多倍長整数を使った Karatsuba 法 99 | - F2 上の Gauss Jordan の掃き出し法 100 | - LCA (HL分解) 101 | - ドント方式 102 | - 割り算の結果によって処理を変えることで場合分けが O(sqrt(N)) 通りで済む問題の補助 103 | - osa_k 法 104 | - 奇置換・偶置換の判定 105 | - 二項間漸化式 a_n = x \* a_{n-1} + y 106 | - 二項間漸化式 a_n = r \* a_{n-1} + b \* n + c 107 | - 多項式 108 | - きたまさ法 109 | - 最大クリーク (Bron–Kerbosch algorithm + ヒューリスティック枝刈り) 110 | - 全方位木 DP 111 | - スライド最小値 112 | - ダブリング 113 | - Berlekamp--Massey 114 | - Karatsuba 法 115 | - リスト埋め込み用エンコーダ・デコーダ 116 | 117 | ##### square_skip_list.py 118 | 119 | - C++ の multiset に相当するデータ構造 120 | 121 | ##### wrap_cpp_multiset.py 122 | 123 | - C++ の multiset のラッパー 124 | -------------------------------------------------------------------------------- /avl_tree.py: -------------------------------------------------------------------------------- 1 | class AvlTree: # std::set 2 | def __init__(self, values=None, sorted_=False, n=0): 3 | # values: 初期値のリスト 4 | # sorted_: 初期値がソート済みであるか 5 | # n: add メソッドを使う回数の最大値 6 | 7 | # sorted_==True であれば、初期値の数の線形時間で木を構築する 8 | # 値を追加するときは必ず n を設定する 9 | if values is None: 10 | self.left = [-1] * (n + 1) 11 | self.right = [-1] * (n + 1) 12 | self.values = [-float("inf")] 13 | self.diff = [0] * (n + 1) # left - right 14 | self.size_l = [0] * (n + 1) 15 | self.idx_new_val = 0 16 | else: 17 | if not sorted_: 18 | values.sort() 19 | len_ = self.idx_new_val = len(values) 20 | n += len_ 21 | self_left = self.left = [-1] * (n + 1) 22 | self_right = self.right = [-1] * (n + 1) 23 | self_values = self.values = [-float("inf")] + values 24 | self_diff = self.diff = [0] * (n + 1) # left - right 25 | self_size_l = self.size_l = [0] * (n + 1) 26 | 27 | st = [[1, len_ + 1, 0]] 28 | while len(st) > 0: # dfs っぽく木を構築 29 | l, r, idx_par = st.pop() # 半開区間 30 | c = (l + r) >> 1 # python -> //2 pypy -> >>1 31 | if self_values[c] < self_values[idx_par]: 32 | self_left[idx_par] = c 33 | else: 34 | self_right[idx_par] = c 35 | siz = r - l 36 | if siz & -siz == siz != 1: # 2 冪だったら 37 | self_diff[c] = 1 38 | self_size_l[c] = siz_l = c - l 39 | if siz_l > 0: 40 | st.append([l, c, c]) 41 | c1 = c + 1 42 | if c1 < r: # 左にノードがなければ右には必ず無いので 43 | st.append([c1, r, c]) 44 | 45 | def rotate_right(self, idx_par, lr): # lr: 親の左なら 0 46 | self_left = self.left 47 | self_right = self.right 48 | self_diff = self.diff 49 | self_size_l = self.size_l 50 | 51 | lr_container = self_right if lr else self_left 52 | idx = lr_container[idx_par] 53 | #assert self_diff[idx] == 2 54 | idx_l = self_left[idx] 55 | diff_l = self_diff[idx_l] 56 | 57 | if diff_l == -1: # 複回転 58 | idx_lr = self_right[idx_l] 59 | diff_lr = self_diff[idx_lr] 60 | if diff_lr == 0: 61 | self_diff[idx] = 0 62 | self_diff[idx_l] = 0 63 | elif diff_lr == 1: 64 | self_diff[idx] = -1 65 | self_diff[idx_l] = 0 66 | self_diff[idx_lr] = 0 67 | else: # diff_lr == -1 68 | self_diff[idx] = 0 69 | self_diff[idx_l] = 1 70 | self_diff[idx_lr] = 0 71 | 72 | # 部分木の大きさの計算 73 | self_size_l[idx_lr] += self_size_l[idx_l] + 1 74 | self_size_l[idx] -= self_size_l[idx_lr] + 1 75 | 76 | # 回転 77 | self_right[idx_l] = self_left[idx_lr] 78 | self_left[idx] = self_right[idx_lr] 79 | self_left[idx_lr] = idx_l 80 | self_right[idx_lr] = idx 81 | lr_container[idx_par] = idx_lr 82 | 83 | return 0 84 | 85 | else: # 単回転 86 | if diff_l == 0: 87 | self_diff[idx] = 1 88 | nb = self_diff[idx_l] = -1 89 | else: # diff_l == 1 90 | self_diff[idx] = 0 91 | nb = self_diff[idx_l] = 0 92 | 93 | # 部分木の大きさの計算 94 | self_size_l[idx] -= self_size_l[idx_l] + 1 95 | 96 | # 回転 97 | self_left[idx] = self_right[idx_l] 98 | self_right[idx_l] = idx 99 | lr_container[idx_par] = idx_l 100 | 101 | return nb # 新しい根の diff を返す 102 | 103 | def rotate_left(self, idx_par, lr): # lr: 親の左なら 0 104 | self_left = self.left 105 | self_right = self.right 106 | self_diff = self.diff 107 | self_size_l = self.size_l 108 | 109 | lr_container = self_right if lr else self_left 110 | idx = lr_container[idx_par] 111 | #assert self_diff[idx] == -2 112 | idx_r = self_right[idx] 113 | diff_l = self_diff[idx_r] 114 | 115 | if diff_l == 1: # 複回転 116 | idx_rl = self_left[idx_r] 117 | diff_rl = self_diff[idx_rl] 118 | if diff_rl == 0: 119 | self_diff[idx] = 0 120 | self_diff[idx_r] = 0 121 | elif diff_rl == -1: 122 | self_diff[idx] = 1 123 | self_diff[idx_r] = 0 124 | self_diff[idx_rl] = 0 125 | else: # diff_lr == 1 126 | self_diff[idx] = 0 127 | self_diff[idx_r] = -1 128 | self_diff[idx_rl] = 0 129 | 130 | # 部分木の大きさの計算 131 | self_size_l[idx_r] -= self_size_l[idx_rl] + 1 132 | self_size_l[idx_rl] += self_size_l[idx] + 1 133 | 134 | # 回転 135 | self_left[idx_r] = self_right[idx_rl] 136 | self_right[idx] = self_left[idx_rl] 137 | self_right[idx_rl] = idx_r 138 | self_left[idx_rl] = idx 139 | lr_container[idx_par] = idx_rl 140 | 141 | return 0 142 | 143 | else: # 単回転 144 | if diff_l == 0: 145 | self_diff[idx] = -1 146 | nb = self_diff[idx_r] = 1 147 | else: # diff_l == 1 148 | self_diff[idx] = 0 149 | nb = self_diff[idx_r] = 0 150 | 151 | # 部分木の大きさの計算 152 | self_size_l[idx_r] += self_size_l[idx] + 1 153 | 154 | # 回転 155 | self_right[idx] = self_left[idx_r] 156 | self_left[idx_r] = idx 157 | lr_container[idx_par] = idx_r 158 | 159 | return nb # 新しい根の diff を返す 160 | 161 | def add(self, x): # insert 162 | # x を加える 163 | # x が既に入ってる場合は False を、 164 | # そうでなければ True を返す 165 | 166 | idx = 0 167 | path = [] 168 | path_left = [] 169 | 170 | self_values = self.values 171 | self_left = self.left 172 | self_right = self.right 173 | 174 | while idx != -1: 175 | path.append(idx) 176 | value = self_values[idx] 177 | if x < value: 178 | path_left.append(idx) # 重複を許さないので処理を後にする必要がある 179 | idx = self_left[idx] 180 | elif value < x: 181 | idx = self_right[idx] 182 | else: # x == value 183 | return False # 重複を許さない 184 | 185 | self.idx_new_val += 1 186 | self_diff = self.diff 187 | self_size_l = self.size_l 188 | 189 | idx = path[-1] 190 | if x < value: 191 | self_left[idx] = self.idx_new_val 192 | else: 193 | self_right[idx] = self.idx_new_val 194 | 195 | self_values.append(x) 196 | 197 | for idx_ in path_left: 198 | self_size_l[idx_] += 1 199 | 200 | self_diff[idx] += 1 if x < value else -1 201 | for idx_par in path[-2::-1]: 202 | diff = self_diff[idx] 203 | if diff == 0: 204 | return True 205 | elif diff == 2: # 右回転 206 | self.rotate_right(idx_par, self_right[idx_par] == idx) 207 | return True 208 | elif diff == -2: # 左回転 209 | self.rotate_left(idx_par, self_right[idx_par] == idx) 210 | return True 211 | else: 212 | self_diff[idx_par] += 1 if self_left[idx_par] == idx else -1 213 | idx = idx_par 214 | return True 215 | 216 | def remove(self, x): # erase 217 | # x を削除する 218 | # x の存在が保証されている必要がある 219 | 220 | idx = 0 221 | path = [] 222 | idx_x = -1 223 | 224 | self_values = self.values 225 | self_left = self.left 226 | self_right = self.right 227 | self_diff = self.diff 228 | self_size_l = self.size_l 229 | 230 | while idx != -1: 231 | path.append(idx) 232 | value = self_values[idx] 233 | if value < x: 234 | idx = self_right[idx] 235 | elif x < value: 236 | self_size_l[idx] -= 1 # 値の存在を保証しているので 237 | idx = self_left[idx] 238 | else: # x == value 239 | idx_x = idx 240 | self_size_l[idx] -= 1 241 | idx = self_left[idx] 242 | 243 | idx_last_par, idx_last = path[-2:] 244 | 245 | if idx_last == idx_x: # x に左の子が存在しない 246 | # 親の idx を付け替える 247 | if self_left[idx_last_par] == idx_x: 248 | self_left[idx_last_par] = self_right[idx_x] 249 | self_diff[idx_last_par] -= 1 250 | else: 251 | self_right[idx_last_par] = self_right[idx_x] 252 | self_diff[idx_last_par] += 1 253 | else: # x に左の子が存在する 254 | # 自身の value を付け替える 255 | self_values[idx_x] = self_values[idx_last] 256 | if idx_last_par == idx_x: # x 左 idx_last (左 _)? 257 | self_left[idx_last_par] = self_left[idx_last] 258 | self_diff[idx_last_par] -= 1 259 | else: # x 左 _ 右 ... 右 idx_last (左 _)? 260 | self_right[idx_last_par] = self_left[idx_last] 261 | self_diff[idx_last_par] += 1 262 | 263 | self_rotate_left = self.rotate_left 264 | self_rotate_right = self.rotate_right 265 | diff = self_diff[idx_last_par] 266 | idx = idx_last_par 267 | for idx_par in path[-3::-1]: 268 | # assert diff == self_diff[idx] 269 | lr = self_right[idx_par] == idx 270 | if diff == 0: 271 | pass 272 | elif diff == 2: # 右回転 273 | diff_ = self_rotate_right(idx_par, lr) 274 | if diff_ != 0: 275 | return True 276 | elif diff == -2: # 左回転 277 | diff_ = self_rotate_left(idx_par, lr) 278 | if diff_ != 0: 279 | return True 280 | else: 281 | return True 282 | diff = self_diff[idx_par] = self_diff[idx_par] + (1 if lr else -1) 283 | idx = idx_par 284 | return True 285 | 286 | def pop(self, idx_): 287 | # 小さい方から idx_ 番目の要素を削除してその要素を返す(0-indexed) 288 | # idx_ 番目の値の存在が保証されている必要がある 289 | 290 | path = [0] 291 | idx_x = -1 292 | 293 | self_values = self.values 294 | self_left = self.left 295 | self_right = self.right 296 | self_diff = self.diff 297 | self_size_l = self.size_l 298 | 299 | sum_left = 0 300 | idx = self_right[0] 301 | while idx != -1: 302 | path.append(idx) 303 | c = sum_left + self_size_l[idx] 304 | if idx_ < c: 305 | self_size_l[idx] -= 1 # 値の存在が保証されているので 306 | idx = self_left[idx] 307 | elif c < idx_: 308 | idx = self_right[idx] 309 | sum_left = c + 1 310 | else: 311 | idx_x = idx 312 | x = self_values[idx] 313 | self_size_l[idx] -= 1 # なんで? 314 | idx = self_left[idx] 315 | 316 | idx_last_par, idx_last = path[-2:] 317 | 318 | if idx_last == idx_x: # x に左の子が存在しない 319 | # 親の idx を付け替える 320 | if self_left[idx_last_par] == idx_x: 321 | self_left[idx_last_par] = self_right[idx_x] 322 | self_diff[idx_last_par] -= 1 323 | else: 324 | self_right[idx_last_par] = self_right[idx_x] 325 | self_diff[idx_last_par] += 1 326 | else: # x に左の子が存在する 327 | # 自身の value を付け替える 328 | self_values[idx_x] = self_values[idx_last] 329 | if idx_last_par == idx_x: # x 左 idx_last (左 _)? 330 | self_left[idx_last_par] = self_left[idx_last] 331 | self_diff[idx_last_par] -= 1 332 | else: # x 左 _ 右 ... 右 idx_last (左 _)? 333 | self_right[idx_last_par] = self_left[idx_last] 334 | self_diff[idx_last_par] += 1 335 | 336 | self_rotate_left = self.rotate_left 337 | self_rotate_right = self.rotate_right 338 | diff = self_diff[idx_last_par] 339 | idx = idx_last_par 340 | for idx_par in path[-3::-1]: 341 | # assert diff == self_diff[idx] 342 | lr = self_right[idx_par] == idx 343 | if diff == 0: 344 | pass 345 | elif diff == 2: # 右回転 346 | diff_ = self_rotate_right(idx_par, lr) 347 | if diff_ != 0: 348 | return x 349 | elif diff == -2: # 左回転 350 | diff_ = self_rotate_left(idx_par, lr) 351 | if diff_ != 0: 352 | return x 353 | else: 354 | return x 355 | diff = self_diff[idx_par] = self_diff[idx_par] + (1 if lr else -1) 356 | idx = idx_par 357 | return x 358 | 359 | def __getitem__(self, idx_): 360 | # 小さい方から idx_ 番目の要素返す 361 | 362 | self_left = self.left 363 | self_right = self.right 364 | self_size_l = self.size_l 365 | 366 | sum_left = 0 367 | idx = self_right[0] 368 | while idx != -1: 369 | c = sum_left + self_size_l[idx] 370 | if idx_ < c: 371 | idx = self_left[idx] 372 | elif c < idx_: 373 | idx = self_right[idx] 374 | sum_left = c + 1 375 | else: # c == idx_ 376 | return self.values[idx] 377 | raise IndexError 378 | 379 | def __contains__(self, x): # count 380 | # 値 x があるか 381 | 382 | self_left = self.left 383 | self_right = self.right 384 | self_values = self.values 385 | self_size_l = self.size_l 386 | 387 | idx = self_right[0] 388 | res = 0 389 | while idx != -1: 390 | value = self_values[idx] 391 | if value < x: 392 | res += self_size_l[idx] + 1 393 | idx = self_right[idx] 394 | elif x < value: 395 | idx = self_left[idx] 396 | else: 397 | return True # res + self_size_l[idx] 398 | return False 399 | 400 | def bisect_left(self, x): # lower_bound 401 | self_left = self.left 402 | self_right = self.right 403 | self_values = self.values 404 | self_size_l = self.size_l 405 | 406 | idx = self_right[0] 407 | res = 0 408 | while idx != -1: 409 | value = self_values[idx] 410 | if value < x: 411 | res += self_size_l[idx] + 1 412 | idx = self_right[idx] 413 | elif x < value: 414 | idx = self_left[idx] 415 | else: # value == x 416 | return res + self_size_l[idx] 417 | return res 418 | 419 | def bisect_right(self, x): # upper_bound 420 | self_left = self.left 421 | self_right = self.right 422 | self_values = self.values 423 | self_size_l = self.size_l 424 | 425 | idx = self_right[0] 426 | res = 0 427 | while idx != -1: 428 | value = self_values[idx] 429 | if value < x: 430 | res += self_size_l[idx] + 1 431 | idx = self_right[idx] 432 | elif x < value: 433 | idx = self_left[idx] 434 | else: # value == x: 435 | return res + self_size_l[idx] + 1 436 | return res 437 | 438 | def print_tree(self, idx=0, depth=0, from_="・"): 439 | if idx == 0: 440 | idx = self.right[idx] 441 | if idx == -1: 442 | return 443 | self.print_tree(self.left[idx], depth + 1, "┏") 444 | print("\t\t" * depth + from_ + " val=[" + str(self.values[idx]) + 445 | "] diff=[" + str(self.diff[idx]) + 446 | "] size_l=[" + str(self.size_l[idx]) + "]") 447 | self.print_tree(self.right[idx], depth + 1, "┗") 448 | 449 | 450 | # 検証1: https://atcoder.jp/contests/cpsco2019-s1/submissions/5788902 451 | # 検証2: https://atcoder.jp/contests/arc033/submissions/6945940 452 | -------------------------------------------------------------------------------- /binary_indexed_tree.py: -------------------------------------------------------------------------------- 1 | class Bit: 2 | # 参考1: http://hos.ac/slides/20140319_bit.pdf 3 | # 参考2: https://atcoder.jp/contests/arc046/submissions/6264201 4 | # 検証: https://atcoder.jp/contests/arc046/submissions/7435621 5 | # values の 0 番目は使わない 6 | # len(values) を 2 冪 +1 にすることで二分探索の条件を減らす 7 | def __init__(self, a): 8 | if hasattr(a, "__iter__"): 9 | le = len(a) 10 | self.n = 1 << le.bit_length() # le を超える最小の 2 冪 11 | self.values = values = [0] * (self.n+1) 12 | values[1:le+1] = a[:] 13 | for i in range(1, self.n): 14 | values[i + (i & -i)] += values[i] 15 | elif isinstance(a, int): 16 | self.n = 1 << a.bit_length() 17 | self.values = [0] * (self.n+1) 18 | else: 19 | raise TypeError 20 | 21 | def add(self, i, val): 22 | n, values = self.n, self.values 23 | while i <= n: 24 | values[i] += val 25 | i += i & -i 26 | 27 | def sum(self, i): # (0, i] 28 | values = self.values 29 | res = 0 30 | while i > 0: 31 | res += values[i] 32 | i -= i & -i 33 | return res 34 | 35 | def bisect_left(self, v): # self.sum(i) が v 以上になる最小の i 36 | n, values = self.n, self.values 37 | if v > values[n]: 38 | return None 39 | i, step = 0, n>>1 40 | while step: 41 | if values[i+step] < v: 42 | i += step 43 | v -= values[i] 44 | step >>= 1 45 | return i + 1 46 | -------------------------------------------------------------------------------- /codeforces.py: -------------------------------------------------------------------------------- 1 | # region fastio # from https://codeforces.com/contest/1333/submission/75948789 2 | import sys, io, os 3 | 4 | BUFSIZE = 8192 5 | 6 | 7 | class FastIO(io.IOBase): 8 | newlines = 0 9 | 10 | def __init__(self, file): 11 | self._fd = file.fileno() 12 | self.buffer = io.BytesIO() 13 | self.writable = "x" in file.mode or "r" not in file.mode 14 | self.write = self.buffer.write if self.writable else None 15 | 16 | def read(self): 17 | while True: 18 | b = os.read(self._fd, max(os.fstat(self._fd).st_size, BUFSIZE)) 19 | if not b: 20 | break 21 | ptr = self.buffer.tell() 22 | self.buffer.seek(0, 2), self.buffer.write(b), self.buffer.seek(ptr) 23 | self.newlines = 0 24 | return self.buffer.read() 25 | 26 | def readline(self): 27 | while self.newlines == 0: 28 | b = os.read(self._fd, max(os.fstat(self._fd).st_size, BUFSIZE)) 29 | self.newlines = b.count(b"\n") + (not b) 30 | ptr = self.buffer.tell() 31 | self.buffer.seek(0, 2), self.buffer.write(b), self.buffer.seek(ptr) 32 | self.newlines -= 1 33 | return self.buffer.readline() 34 | 35 | def flush(self): 36 | if self.writable: 37 | os.write(self._fd, self.buffer.getvalue()) 38 | self.buffer.truncate(0), self.buffer.seek(0) 39 | 40 | 41 | class IOWrapper(io.IOBase): 42 | def __init__(self, file): 43 | self.buffer = FastIO(file) 44 | self.flush = self.buffer.flush 45 | self.writable = self.buffer.writable 46 | self.write = lambda s: self.buffer.write(s.encode("ascii")) 47 | self.read = lambda: self.buffer.read().decode("ascii") 48 | self.readline = lambda: self.buffer.readline().decode("ascii") 49 | 50 | 51 | def print(*args, **kwargs): 52 | """Prints the values to a stream, or to sys.stdout by default.""" 53 | sep, file = kwargs.pop("sep", " "), kwargs.pop("file", sys.stdout) 54 | at_start = True 55 | for x in args: 56 | if not at_start: 57 | file.write(sep) 58 | file.write(str(x)) 59 | at_start = False 60 | file.write(kwargs.pop("end", "\n")) 61 | if kwargs.pop("flush", False): 62 | file.flush() 63 | 64 | 65 | sys.stdin, sys.stdout = IOWrapper(sys.stdin), IOWrapper(sys.stdout) 66 | input = lambda: sys.stdin.readline().rstrip("\r\n") 67 | #endregion 68 | -------------------------------------------------------------------------------- /cpp_extension/acl__document.py: -------------------------------------------------------------------------------- 1 | class SegTree: 2 | def __init__(self, op, e, n): 3 | # n は要素数または iterable 4 | # O(n) 5 | pass 6 | 7 | def set(self, p, x): 8 | # O(log(n)) 9 | pass 10 | 11 | def get(self, p): 12 | # O(1) 13 | pass 14 | 15 | def prod(self, l, r): 16 | # O(log(n)) 17 | pass 18 | 19 | def all_prod(self): 20 | # O(1) 21 | pass 22 | 23 | def max_right(self, l, f): 24 | # O(log(n)) 25 | pass 26 | 27 | def max_left(self, r, f): 28 | # O(log(n)) 29 | pass 30 | 31 | def to_list(self): 32 | # O(n) 33 | pass 34 | 35 | def test_segtree(): 36 | from atcoder import SegTree 37 | 38 | input = sys.stdin.buffer.readline 39 | N, Q = map(int, input().split()) 40 | A = map(int, input().split()) 41 | 42 | op = lambda a, b: a if a > b else b 43 | e = -1 44 | seg = SegTree(op, e, A) 45 | 46 | Ans = [] 47 | m = map(int, sys.stdin.buffer.read().split()) 48 | for t, x, v in zip(m, m, m): 49 | x -= 1 50 | if t == 1: 51 | seg.set(x, v) 52 | elif t == 2: 53 | ans = seg.prod(x, v) 54 | Ans.append(ans) 55 | else: 56 | ans = seg.max_right(x, lambda a: a < v) + 1 57 | Ans.append(ans) 58 | if Ans: 59 | print("\n".join(map(str, Ans))) 60 | 61 | 62 | class LazySegTree: 63 | def __init__(self, op, e, mapping, composition, identity, n): 64 | # n は要素数または iterable 65 | # O(n) 66 | pass 67 | 68 | def set(self, p, x): 69 | # O(log(n)) 70 | pass 71 | 72 | def get(self, p): 73 | # O(log(n)) 74 | pass 75 | 76 | def prod(self, l, r): 77 | # O(log(n)) 78 | pass 79 | 80 | def prod_getitem(self, l, r, idx): # original 81 | # O(log(n)) 82 | pass 83 | 84 | def all_prod(self): 85 | # O(1) 86 | pass 87 | 88 | def apply(self, l, r, x=None): 89 | # O(log(n) 90 | # 2 引数で呼び出した場合は p, x で、 p の 1 箇所にのみ適用 91 | pass 92 | 93 | def max_right(self, l, f): 94 | # O(log(n)) 95 | pass 96 | 97 | def max_left(self, r, f): 98 | # O(log(n)) 99 | pass 100 | 101 | def to_list(self): 102 | # O(n) のはず 103 | pass 104 | 105 | 106 | 107 | -------------------------------------------------------------------------------- /cpp_extension/acl_lazysegtree.py: -------------------------------------------------------------------------------- 1 | # TODO: メモリリーク確認 2 | # TODO: max_right とかが正しく動くか検証 3 | # TODO: 更新ルールの異なる複数のセグ木を作ったときに正しく動くか検証 4 | 5 | 6 | code_lazy_segtree = r""" 7 | #define PY_SSIZE_T_CLEAN 8 | #include 9 | #include "structmember.h" 10 | 11 | //#define ALLOW_MEMORY_LEAK // メモリリーク許容して高速化 12 | #define ILLEGAL_FUNC_CALL // 違法な内部 API 呼び出しで高速化 13 | 14 | // >>> AtCoder >>> 15 | 16 | #ifndef ATCODER_LAZYSEGTREE_HPP 17 | #define ATCODER_LAZYSEGTREE_HPP 1 18 | 19 | #include 20 | 21 | #ifndef ATCODER_INTERNAL_BITOP_HPP 22 | #define ATCODER_INTERNAL_BITOP_HPP 1 23 | 24 | #ifdef _MSC_VER 25 | #include 26 | #endif 27 | 28 | namespace atcoder { 29 | 30 | namespace internal { 31 | 32 | // @param n `0 <= n` 33 | // @return minimum non-negative `x` s.t. `n <= 2**x` 34 | int ceil_pow2(int n) { 35 | int x = 0; 36 | while ((1U << x) < (unsigned int)(n)) x++; 37 | return x; 38 | } 39 | 40 | // @param n `1 <= n` 41 | // @return minimum non-negative `x` s.t. `(n & (1 << x)) != 0` 42 | int bsf(unsigned int n) { 43 | #ifdef _MSC_VER 44 | unsigned long index; 45 | _BitScanForward(&index, n); 46 | return index; 47 | #else 48 | return __builtin_ctz(n); 49 | #endif 50 | } 51 | 52 | } // namespace internal 53 | 54 | } // namespace atcoder 55 | 56 | #endif // ATCODER_INTERNAL_BITOP_HPP 57 | 58 | #include 59 | #include 60 | #include 61 | namespace atcoder { 62 | 63 | template 70 | struct lazy_segtree { 71 | public: 72 | lazy_segtree() : lazy_segtree(0) {} 73 | lazy_segtree(int n) : lazy_segtree(std::vector(n, e())) {} 74 | lazy_segtree(const std::vector& v) : _n(int(v.size())) { 75 | log = internal::ceil_pow2(_n); 76 | size = 1 << log; 77 | d = std::vector(2 * size, e()); 78 | lz = std::vector(size, id()); 79 | for (int i = 0; i < _n; i++) d[size + i] = v[i]; 80 | for (int i = size - 1; i >= 1; i--) { 81 | update(i); 82 | } 83 | } 84 | 85 | void set(int p, S x) { 86 | assert(0 <= p && p < _n); 87 | p += size; 88 | for (int i = log; i >= 1; i--) push(p >> i); 89 | d[p] = x; 90 | for (int i = 1; i <= log; i++) update(p >> i); 91 | } 92 | 93 | S get(int p) { 94 | assert(0 <= p && p < _n); 95 | p += size; 96 | for (int i = log; i >= 1; i--) push(p >> i); 97 | return d[p]; 98 | } 99 | 100 | S prod(int l, int r) { 101 | assert(0 <= l && l <= r && r <= _n); 102 | if (l == r) return e(); 103 | 104 | l += size; 105 | r += size; 106 | 107 | for (int i = log; i >= 1; i--) { 108 | if (((l >> i) << i) != l) push(l >> i); 109 | if (((r >> i) << i) != r) push(r >> i); 110 | } 111 | 112 | S sml = e(), smr = e(); 113 | while (l < r) { 114 | if (l & 1) sml = op(sml, d[l++]); 115 | if (r & 1) smr = op(d[--r], smr); 116 | l >>= 1; 117 | r >>= 1; 118 | } 119 | 120 | return op(sml, smr); 121 | } 122 | 123 | S all_prod() { return d[1]; } 124 | 125 | void apply(int p, F f) { 126 | assert(0 <= p && p < _n); 127 | p += size; 128 | for (int i = log; i >= 1; i--) push(p >> i); 129 | d[p] = mapping(f, d[p]); 130 | for (int i = 1; i <= log; i++) update(p >> i); 131 | } 132 | void apply(int l, int r, F f) { 133 | assert(0 <= l && l <= r && r <= _n); 134 | if (l == r) return; 135 | 136 | l += size; 137 | r += size; 138 | 139 | for (int i = log; i >= 1; i--) { 140 | if (((l >> i) << i) != l) push(l >> i); 141 | if (((r >> i) << i) != r) push((r - 1) >> i); 142 | } 143 | 144 | { 145 | int l2 = l, r2 = r; 146 | while (l < r) { 147 | if (l & 1) all_apply(l++, f); 148 | if (r & 1) all_apply(--r, f); 149 | l >>= 1; 150 | r >>= 1; 151 | } 152 | l = l2; 153 | r = r2; 154 | } 155 | 156 | for (int i = 1; i <= log; i++) { 157 | if (((l >> i) << i) != l) update(l >> i); 158 | if (((r >> i) << i) != r) update((r - 1) >> i); 159 | } 160 | } 161 | 162 | template int max_right(int l) { 163 | return max_right(l, [](S x) { return g(x); }); 164 | } 165 | template int max_right(int l, G g) { 166 | assert(0 <= l && l <= _n); 167 | assert(g(e())); 168 | if (l == _n) return _n; 169 | l += size; 170 | for (int i = log; i >= 1; i--) push(l >> i); 171 | S sm = e(); 172 | do { 173 | while (l % 2 == 0) l >>= 1; 174 | if (!g(op(sm, d[l]))) { 175 | while (l < size) { 176 | push(l); 177 | l = (2 * l); 178 | if (g(op(sm, d[l]))) { 179 | sm = op(sm, d[l]); 180 | l++; 181 | } 182 | } 183 | return l - size; 184 | } 185 | sm = op(sm, d[l]); 186 | l++; 187 | } while ((l & -l) != l); 188 | return _n; 189 | } 190 | 191 | template int min_left(int r) { 192 | return min_left(r, [](S x) { return g(x); }); 193 | } 194 | template int min_left(int r, G g) { 195 | assert(0 <= r && r <= _n); 196 | assert(g(e())); 197 | if (r == 0) return 0; 198 | r += size; 199 | for (int i = log; i >= 1; i--) push((r - 1) >> i); 200 | S sm = e(); 201 | do { 202 | r--; 203 | while (r > 1 && (r % 2)) r >>= 1; 204 | if (!g(op(d[r], sm))) { 205 | while (r < size) { 206 | push(r); 207 | r = (2 * r + 1); 208 | if (g(op(d[r], sm))) { 209 | sm = op(d[r], sm); 210 | r--; 211 | } 212 | } 213 | return r + 1 - size; 214 | } 215 | sm = op(d[r], sm); 216 | } while ((r & -r) != r); 217 | return 0; 218 | } 219 | 220 | private: 221 | int _n, size, log; 222 | std::vector d; 223 | std::vector lz; 224 | 225 | void update(int k) { d[k] = op(d[2 * k], d[2 * k + 1]); } 226 | void all_apply(int k, F f) { 227 | d[k] = mapping(f, d[k]); 228 | if (k < size) lz[k] = composition(f, lz[k]); 229 | } 230 | void push(int k) { 231 | all_apply(2 * k, lz[k]); 232 | all_apply(2 * k + 1, lz[k]); 233 | lz[k] = id(); 234 | } 235 | }; 236 | 237 | } // namespace atcoder 238 | 239 | #endif // ATCODER_LAZYSEGTREE_HPP 240 | 241 | // <<< AtCoder <<< 242 | 243 | using namespace std; 244 | using namespace atcoder; 245 | #define PARSE_ARGS(types, ...) if(!PyArg_ParseTuple(args, types, __VA_ARGS__)) return NULL 246 | 247 | struct AutoDecrefPtr{ 248 | PyObject* p; 249 | AutoDecrefPtr(PyObject* _p) : p(_p) {}; 250 | #ifndef ALLOW_MEMORY_LEAK 251 | AutoDecrefPtr(const AutoDecrefPtr& rhs) : p(rhs.p) { Py_INCREF(p); }; 252 | ~AutoDecrefPtr(){ Py_DECREF(p); } 253 | AutoDecrefPtr &operator=(const AutoDecrefPtr& rhs){ 254 | Py_DECREF(p); 255 | p = rhs.p; 256 | Py_INCREF(p); 257 | return *this; 258 | } 259 | #endif 260 | }; 261 | 262 | #ifdef ILLEGAL_FUNC_CALL 263 | static PyObject* py_function_args[2]; 264 | #endif 265 | 266 | // >>> functions for laze_segtree constructor >>> 267 | static PyObject* lazy_segtree_op_py; 268 | static AutoDecrefPtr lazy_segtree_op(AutoDecrefPtr a, AutoDecrefPtr b){ 269 | #ifdef ILLEGAL_FUNC_CALL 270 | py_function_args[0] = a.p; 271 | py_function_args[1] = b.p; 272 | PyObject* res = _PyObject_FastCall(lazy_segtree_op_py, py_function_args, 2); 273 | #else 274 | PyObject* res(PyObject_CallFunctionObjArgs(lazy_segtree_op_py, a.p, b.p, NULL)); 275 | #endif 276 | Py_INCREF(res); // ??????????????? 277 | return AutoDecrefPtr(res); 278 | } 279 | static PyObject* lazy_segtree_e_py; 280 | static AutoDecrefPtr lazy_segtree_e(){ 281 | Py_INCREF(lazy_segtree_e_py); 282 | return AutoDecrefPtr(lazy_segtree_e_py); 283 | } 284 | static PyObject* lazy_segtree_mapping_py; 285 | static AutoDecrefPtr lazy_segtree_mapping(AutoDecrefPtr f, AutoDecrefPtr x){ 286 | #ifdef ILLEGAL_FUNC_CALL 287 | py_function_args[0] = f.p; 288 | py_function_args[1] = x.p; 289 | PyObject* res = _PyObject_FastCall(lazy_segtree_mapping_py, py_function_args, 2); 290 | return AutoDecrefPtr(res); 291 | #else 292 | return AutoDecrefPtr(PyObject_CallFunctionObjArgs(lazy_segtree_mapping_py, f.p, x.p, NULL)); 293 | #endif 294 | } 295 | static PyObject* lazy_segtree_composition_py; 296 | static AutoDecrefPtr lazy_segtree_composition(AutoDecrefPtr f, AutoDecrefPtr g){ 297 | #ifdef ILLEGAL_FUNC_CALL 298 | py_function_args[0] = f.p; 299 | py_function_args[1] = g.p; 300 | PyObject* res = _PyObject_FastCall(lazy_segtree_composition_py, py_function_args, 2); 301 | return AutoDecrefPtr(res); 302 | #else 303 | return AutoDecrefPtr(PyObject_CallFunctionObjArgs(lazy_segtree_composition_py, f.p, g.p, NULL)); 304 | #endif 305 | } 306 | static PyObject* lazy_segtree_id_py; 307 | static AutoDecrefPtr lazy_segtree_id(){ 308 | Py_INCREF(lazy_segtree_id_py); 309 | return AutoDecrefPtr(lazy_segtree_id_py); 310 | } 311 | using lazyseg = lazy_segtree; 318 | // <<< functions for laze_segtree constructor <<< 319 | 320 | static PyObject* lazy_segtree_f_py; 321 | static bool lazy_segtree_f(AutoDecrefPtr x){ 322 | PyObject* pyfunc_res = PyObject_CallFunctionObjArgs(lazy_segtree_f_py, x.p, NULL); 323 | int res = PyObject_IsTrue(pyfunc_res); 324 | if(res == -1) PyErr_Format(PyExc_ValueError, "error in LazySegTree f"); 325 | return (bool)res; 326 | } 327 | 328 | struct LazySegTree{ 329 | PyObject_HEAD 330 | lazyseg* seg; 331 | PyObject* op; 332 | PyObject* e; 333 | PyObject* mapping; 334 | PyObject* composition; 335 | PyObject* id; 336 | int n; 337 | }; 338 | static inline void set_rules(LazySegTree* self){ 339 | lazy_segtree_op_py = self->op; 340 | lazy_segtree_e_py = self->e; 341 | lazy_segtree_mapping_py = self->mapping; 342 | lazy_segtree_composition_py = self->composition; 343 | lazy_segtree_id_py = self->id; 344 | } 345 | 346 | // >>> LazySegTree functions >>> 347 | 348 | extern PyTypeObject LazySegTreeType; 349 | 350 | static void LazySegTree_dealloc(LazySegTree* self){ 351 | delete self->seg; 352 | Py_DECREF(self->op); 353 | Py_DECREF(self->e); 354 | Py_DECREF(self->mapping); 355 | Py_DECREF(self->composition); 356 | Py_DECREF(self->id); 357 | Py_TYPE(self)->tp_free((PyObject*)self); 358 | } 359 | static PyObject* LazySegTree_new(PyTypeObject* type, PyObject* args, PyObject* kwds){ 360 | LazySegTree* self; 361 | self = (LazySegTree*)type->tp_alloc(type, 0); 362 | return (PyObject*)self; 363 | } 364 | static int LazySegTree_init(LazySegTree* self, PyObject* args){ 365 | if(Py_SIZE(args) != 6){ 366 | self->op = Py_None; // 何か入れておかないとヤバいことになる 367 | Py_INCREF(Py_None); 368 | self->e = Py_None; 369 | Py_INCREF(Py_None); 370 | self->mapping = Py_None; 371 | Py_INCREF(Py_None); 372 | self->composition = Py_None; 373 | Py_INCREF(Py_None); 374 | self->id = Py_None; 375 | Py_INCREF(Py_None); 376 | PyErr_Format(PyExc_TypeError, 377 | "LazySegTree constructor expected 6 arguments (op, e, mapping, composition, identity, n), got %d", Py_SIZE(args)); 378 | return -1; 379 | } 380 | PyObject* arg; 381 | if(!PyArg_ParseTuple(args, "OOOOOO", 382 | &self->op, &self->e, 383 | &self->mapping, &self->composition, &self->id, &arg)) return -1; 384 | Py_INCREF(self->op); 385 | Py_INCREF(self->e); 386 | Py_INCREF(self->mapping); 387 | Py_INCREF(self->composition); 388 | Py_INCREF(self->id); 389 | set_rules(self); 390 | if(PyLong_Check(arg)){ 391 | int n = (int)PyLong_AsLong(arg); 392 | if(PyErr_Occurred()) return -1; 393 | if(n < 0 || n > (int)1e8) { 394 | PyErr_Format(PyExc_ValueError, "constraint error in LazySegTree constructor (got %d)", n); 395 | return -1; 396 | } 397 | self->seg = new lazyseg(n); 398 | self->n = n; 399 | }else{ 400 | PyObject *iterator = PyObject_GetIter(arg); 401 | if(iterator==NULL) return -1; 402 | PyObject *item; 403 | vector vec; 404 | if(Py_TYPE(arg)->tp_as_sequence != NULL) vec.reserve((int)Py_SIZE(arg)); 405 | while(item = PyIter_Next(iterator)) { 406 | vec.emplace_back(item); 407 | } 408 | Py_DECREF(iterator); 409 | if (PyErr_Occurred()) return -1; 410 | self->seg = new lazyseg(vec); 411 | self->n = (int)vec.size(); 412 | } 413 | return 0; 414 | } 415 | static PyObject* LazySegTree_set(LazySegTree* self, PyObject* args){ 416 | long p; 417 | PyObject* x; 418 | PARSE_ARGS("lO", &p, &x); 419 | if(p < 0 || p >= self->n){ 420 | PyErr_Format(PyExc_IndexError, "LazySegTree set index out of range (size=%d, index=%d)", self->n, p); 421 | return (PyObject*)NULL; 422 | } 423 | Py_INCREF(x); 424 | set_rules(self); 425 | self->seg->set((int)p, AutoDecrefPtr(x)); 426 | Py_RETURN_NONE; 427 | } 428 | static PyObject* LazySegTree_get(LazySegTree* self, PyObject* args){ 429 | long p; 430 | PARSE_ARGS("l", &p); 431 | if(p < 0 || p >= self->n){ 432 | PyErr_Format(PyExc_IndexError, "LazySegTree get index out of range (size=%d, index=%d)", self->n, p); 433 | return (PyObject*)NULL; 434 | } 435 | set_rules(self); 436 | PyObject* res = self->seg->get((int)p).p; 437 | return Py_BuildValue("O", res); 438 | } 439 | static PyObject* LazySegTree_prod(LazySegTree* self, PyObject* args){ 440 | long l, r; 441 | PARSE_ARGS("ll", &l, &r); 442 | set_rules(self); 443 | PyObject* res = self->seg->prod((int)l, (int)r).p; 444 | return Py_BuildValue("O", res); 445 | } 446 | static PyObject* LazySegTree_prod_getitem(LazySegTree* self, PyObject* args){ 447 | long l, r, idx; 448 | PARSE_ARGS("lll", &l, &r, &idx); 449 | set_rules(self); 450 | PyObject* res = self->seg->prod((int)l, (int)r).p; 451 | res = PySequence_Fast_GET_ITEM(res, idx); // 要素がタプルと仮定 452 | return Py_BuildValue("O", res); 453 | } 454 | static PyObject* LazySegTree_all_prod(LazySegTree* self, PyObject* args){ 455 | PyObject* res = self->seg->all_prod().p; 456 | return Py_BuildValue("O", res); 457 | } 458 | static PyObject* LazySegTree_apply(LazySegTree* self, PyObject* args){ 459 | if(Py_SIZE(args) == 3){ 460 | long l, r; 461 | PyObject* x; 462 | PARSE_ARGS("llO", &l, &r, &x); 463 | Py_INCREF(x); 464 | set_rules(self); 465 | self->seg->apply(l, r, AutoDecrefPtr(x)); 466 | Py_RETURN_NONE; 467 | }else if(Py_SIZE(args) == 2){ 468 | long p; 469 | PyObject* x; 470 | PARSE_ARGS("lO", &p, &x); 471 | if(p < 0 || p >= self->n){ 472 | PyErr_Format(PyExc_IndexError, "LazySegTree apply index out of range (size=%d, index=%d)", self->n, p); 473 | return (PyObject*)NULL; 474 | } 475 | Py_INCREF(x); 476 | set_rules(self); 477 | self->seg->apply(p, AutoDecrefPtr(x)); 478 | Py_RETURN_NONE; 479 | }else{ 480 | PyErr_Format(PyExc_TypeError, 481 | "LazySegTree apply expected 2 (p, x) or 3 (l, r, x) arguments, got %d", Py_SIZE(args)); 482 | return (PyObject*)NULL; 483 | } 484 | } 485 | static PyObject* LazySegTree_max_right(LazySegTree* self, PyObject* args){ 486 | long l; 487 | PARSE_ARGS("lO", &l, &lazy_segtree_f_py); 488 | if(l < 0 || l > self->n){ 489 | PyErr_Format(PyExc_IndexError, "LazySegTree max_right index out of range (size=%d, l=%d)", self->n, l); 490 | return (PyObject*)NULL; 491 | } 492 | set_rules(self); 493 | int res = self->seg->max_right((int)l); 494 | return Py_BuildValue("l", res); 495 | } 496 | static PyObject* LazySegTree_min_left(LazySegTree* self, PyObject* args){ 497 | long r; 498 | PARSE_ARGS("lO", &r, &lazy_segtree_f_py); 499 | if(r < 0 || r > self->n){ 500 | PyErr_Format(PyExc_IndexError, "LazySegTree max_right index out of range (size=%d, r=%d)", self->n, r); 501 | return (PyObject*)NULL; 502 | } 503 | set_rules(self); 504 | int res = self->seg->min_left((int)r); 505 | return Py_BuildValue("l", res); 506 | } 507 | static PyObject* LazySegTree_to_list(LazySegTree* self){ 508 | PyObject* list = PyList_New(self->n); 509 | for(int i=0; in; i++){ 510 | PyObject* val = self->seg->get(i).p; 511 | Py_INCREF(val); 512 | PyList_SET_ITEM(list, i, val); 513 | } 514 | return list; 515 | } 516 | static PyObject* LazySegTree_repr(PyObject* self){ 517 | PyObject* list = LazySegTree_to_list((LazySegTree*)self); 518 | PyObject* res = PyUnicode_FromFormat("LazySegTree(%R)", list); 519 | Py_ReprLeave(self); 520 | Py_DECREF(list); 521 | return res; 522 | } 523 | // <<< LazySegTree functions <<< 524 | 525 | static PyMethodDef LazySegTree_methods[] = { 526 | {"set", (PyCFunction)LazySegTree_set, METH_VARARGS, "Set item"}, 527 | {"get", (PyCFunction)LazySegTree_get, METH_VARARGS, "Get item"}, 528 | {"prod", (PyCFunction)LazySegTree_prod, METH_VARARGS, "Get item"}, 529 | {"prod_getitem", (PyCFunction)LazySegTree_prod_getitem, METH_VARARGS, "Get item"}, 530 | {"all_prod", (PyCFunction)LazySegTree_all_prod, METH_VARARGS, "Get item"}, 531 | {"apply", (PyCFunction)LazySegTree_apply, METH_VARARGS, "Apply function"}, 532 | {"max_right", (PyCFunction)LazySegTree_max_right, METH_VARARGS, "Binary search on lazy segtree"}, 533 | {"min_left", (PyCFunction)LazySegTree_min_left, METH_VARARGS, "Binary search on lazy segtree"}, 534 | {"to_list", (PyCFunction)LazySegTree_to_list, METH_VARARGS, "Convert to list"}, 535 | {NULL} /* Sentinel */ 536 | }; 537 | PyTypeObject LazySegTreeType = { 538 | PyObject_HEAD_INIT(NULL) 539 | "atcoder.LazySegTree", /*tp_name*/ 540 | sizeof(LazySegTree), /*tp_basicsize*/ 541 | 0, /*tp_itemsize*/ 542 | (destructor)LazySegTree_dealloc, /*tp_dealloc*/ 543 | 0, /*tp_print*/ 544 | 0, /*tp_getattr*/ 545 | 0, /*tp_setattr*/ 546 | 0, /*reserved*/ 547 | LazySegTree_repr, /*tp_repr*/ 548 | 0, /*tp_as_number*/ 549 | 0, /*tp_as_sequence*/ 550 | 0, /*tp_as_mapping*/ 551 | 0, /*tp_hash*/ 552 | 0, /*tp_call*/ 553 | 0, /*tp_str*/ 554 | 0, /*tp_getattro*/ 555 | 0, /*tp_setattro*/ 556 | 0, /*tp_as_buffer*/ 557 | Py_TPFLAGS_DEFAULT | Py_TPFLAGS_BASETYPE, /*tp_flags*/ 558 | 0, /*tp_doc*/ 559 | 0, /*tp_traverse*/ 560 | 0, /*tp_clear*/ 561 | 0, /*tp_richcompare*/ 562 | 0, /*tp_weaklistoffset*/ 563 | 0, /*tp_iter*/ 564 | 0, /*tp_iternext*/ 565 | LazySegTree_methods, /*tp_methods*/ 566 | 0, /*tp_members*/ 567 | 0, /*tp_getset*/ 568 | 0, /*tp_base*/ 569 | 0, /*tp_dict*/ 570 | 0, /*tp_descr_get*/ 571 | 0, /*tp_descr_set*/ 572 | 0, /*tp_dictoffset*/ 573 | (initproc)LazySegTree_init, /*tp_init*/ 574 | 0, /*tp_alloc*/ 575 | LazySegTree_new, /*tp_new*/ 576 | 0, /*tp_free*/ 577 | 0, /*tp_is_gc*/ 578 | 0, /*tp_bases*/ 579 | 0, /*tp_mro*/ 580 | 0, /*tp_cache*/ 581 | 0, /*tp_subclasses*/ 582 | 0, /*tp_weaklist*/ 583 | 0, /*tp_del*/ 584 | 0, /*tp_version_tag*/ 585 | 0, /*tp_finalize*/ 586 | }; 587 | 588 | static PyModuleDef atcodermodule = { 589 | PyModuleDef_HEAD_INIT, 590 | "atcoder", 591 | NULL, 592 | -1, 593 | }; 594 | 595 | PyMODINIT_FUNC PyInit_atcoder(void) 596 | { 597 | PyObject* m; 598 | if(PyType_Ready(&LazySegTreeType) < 0) return NULL; 599 | 600 | m = PyModule_Create(&atcodermodule); 601 | if(m == NULL) return NULL; 602 | 603 | Py_INCREF(&LazySegTreeType); 604 | if (PyModule_AddObject(m, "LazySegTree", (PyObject*)&LazySegTreeType) < 0) { 605 | Py_DECREF(&LazySegTreeType); 606 | Py_DECREF(m); 607 | return NULL; 608 | } 609 | 610 | return m; 611 | } 612 | """ 613 | code_setup = r""" 614 | from distutils.core import setup, Extension 615 | module = Extension( 616 | "atcoder", 617 | sources=["atcoder_library_wrapper.cpp"], 618 | extra_compile_args=["-O3", "-march=native", "-std=c++14"] 619 | ) 620 | setup( 621 | name="atcoder-library", 622 | version="0.0.1", 623 | description="wrapper for atcoder library", 624 | ext_modules=[module] 625 | ) 626 | """ 627 | 628 | import os 629 | import sys 630 | 631 | if sys.argv[-1] == "ONLINE_JUDGE" or os.getcwd() != "/imojudge/sandbox": 632 | with open("atcoder_library_wrapper.cpp", "w") as f: 633 | f.write(code_lazy_segtree) 634 | with open("setup.py", "w") as f: 635 | f.write(code_setup) 636 | os.system(f"{sys.executable} setup.py build_ext --inplace") 637 | 638 | from atcoder import LazySegTree 639 | -------------------------------------------------------------------------------- /cpp_extension/acl_math.py: -------------------------------------------------------------------------------- 1 | # TODO: メモリリーク確認 2 | 3 | code_acl_math = r""" 4 | #define PY_SSIZE_T_CLEAN 5 | #include 6 | 7 | 8 | // >>> AtCoder >>> 9 | 10 | #ifndef ATCODER_MATH_HPP 11 | #define ATCODER_MATH_HPP 1 12 | 13 | #include 14 | #include 15 | #include 16 | #include 17 | 18 | #ifndef ATCODER_INTERNAL_MATH_HPP 19 | #define ATCODER_INTERNAL_MATH_HPP 1 20 | 21 | #include 22 | 23 | #ifdef _MSC_VER 24 | #include 25 | #endif 26 | 27 | namespace atcoder { 28 | 29 | namespace internal { 30 | 31 | // @param m `1 <= m` 32 | // @return x mod m 33 | constexpr long long safe_mod(long long x, long long m) { 34 | x %= m; 35 | if (x < 0) x += m; 36 | return x; 37 | } 38 | 39 | // Fast modular multiplication by barrett reduction 40 | // Reference: https://en.wikipedia.org/wiki/Barrett_reduction 41 | // NOTE: reconsider after Ice Lake 42 | struct barrett { 43 | unsigned int _m; 44 | unsigned long long im; 45 | 46 | // @param m `1 <= m < 2^31` 47 | barrett(unsigned int m) : _m(m), im((unsigned long long)(-1) / m + 1) {} 48 | 49 | // @return m 50 | unsigned int umod() const { return _m; } 51 | 52 | // @param a `0 <= a < m` 53 | // @param b `0 <= b < m` 54 | // @return `a * b % m` 55 | unsigned int mul(unsigned int a, unsigned int b) const { 56 | // [1] m = 1 57 | // a = b = im = 0, so okay 58 | 59 | // [2] m >= 2 60 | // im = ceil(2^64 / m) 61 | // -> im * m = 2^64 + r (0 <= r < m) 62 | // let z = a*b = c*m + d (0 <= c, d < m) 63 | // a*b * im = (c*m + d) * im = c*(im*m) + d*im = c*2^64 + c*r + d*im 64 | // c*r + d*im < m * m + m * im < m * m + 2^64 + m <= 2^64 + m * (m + 1) < 2^64 * 2 65 | // ((ab * im) >> 64) == c or c + 1 66 | unsigned long long z = a; 67 | z *= b; 68 | #ifdef _MSC_VER 69 | unsigned long long x; 70 | _umul128(z, im, &x); 71 | #else 72 | unsigned long long x = 73 | (unsigned long long)(((unsigned __int128)(z)*im) >> 64); 74 | #endif 75 | unsigned int v = (unsigned int)(z - x * _m); 76 | if (_m <= v) v += _m; 77 | return v; 78 | } 79 | }; 80 | 81 | // @param n `0 <= n` 82 | // @param m `1 <= m` 83 | // @return `(x ** n) % m` 84 | constexpr long long pow_mod_constexpr(long long x, long long n, int m) { 85 | if (m == 1) return 0; 86 | unsigned int _m = (unsigned int)(m); 87 | unsigned long long r = 1; 88 | unsigned long long y = safe_mod(x, m); 89 | while (n) { 90 | if (n & 1) r = (r * y) % _m; 91 | y = (y * y) % _m; 92 | n >>= 1; 93 | } 94 | return r; 95 | } 96 | 97 | // Reference: 98 | // M. Forisek and J. Jancina, 99 | // Fast Primality Testing for Integers That Fit into a Machine Word 100 | // @param n `0 <= n` 101 | constexpr bool is_prime_constexpr(int n) { 102 | if (n <= 1) return false; 103 | if (n == 2 || n == 7 || n == 61) return true; 104 | if (n % 2 == 0) return false; 105 | long long d = n - 1; 106 | while (d % 2 == 0) d /= 2; 107 | constexpr long long bases[3] = {2, 7, 61}; 108 | for (long long a : bases) { 109 | long long t = d; 110 | long long y = pow_mod_constexpr(a, t, n); 111 | while (t != n - 1 && y != 1 && y != n - 1) { 112 | y = y * y % n; 113 | t <<= 1; 114 | } 115 | if (y != n - 1 && t % 2 == 0) { 116 | return false; 117 | } 118 | } 119 | return true; 120 | } 121 | template constexpr bool is_prime = is_prime_constexpr(n); 122 | 123 | // @param b `1 <= b` 124 | // @return pair(g, x) s.t. g = gcd(a, b), xa = g (mod b), 0 <= x < b/g 125 | constexpr std::pair inv_gcd(long long a, long long b) { 126 | a = safe_mod(a, b); 127 | if (a == 0) return {b, 0}; 128 | 129 | // Contracts: 130 | // [1] s - m0 * a = 0 (mod b) 131 | // [2] t - m1 * a = 0 (mod b) 132 | // [3] s * |m1| + t * |m0| <= b 133 | long long s = b, t = a; 134 | long long m0 = 0, m1 = 1; 135 | 136 | while (t) { 137 | long long u = s / t; 138 | s -= t * u; 139 | m0 -= m1 * u; // |m1 * u| <= |m1| * s <= b 140 | 141 | // [3]: 142 | // (s - t * u) * |m1| + t * |m0 - m1 * u| 143 | // <= s * |m1| - t * u * |m1| + t * (|m0| + |m1| * u) 144 | // = s * |m1| + t * |m0| <= b 145 | 146 | auto tmp = s; 147 | s = t; 148 | t = tmp; 149 | tmp = m0; 150 | m0 = m1; 151 | m1 = tmp; 152 | } 153 | // by [3]: |m0| <= b/g 154 | // by g != b: |m0| < b/g 155 | if (m0 < 0) m0 += b / s; 156 | return {s, m0}; 157 | } 158 | 159 | // Compile time primitive root 160 | // @param m must be prime 161 | // @return primitive root (and minimum in now) 162 | constexpr int primitive_root_constexpr(int m) { 163 | if (m == 2) return 1; 164 | if (m == 167772161) return 3; 165 | if (m == 469762049) return 3; 166 | if (m == 754974721) return 11; 167 | if (m == 998244353) return 3; 168 | int divs[20] = {}; 169 | divs[0] = 2; 170 | int cnt = 1; 171 | int x = (m - 1) / 2; 172 | while (x % 2 == 0) x /= 2; 173 | for (int i = 3; (long long)(i)*i <= x; i += 2) { 174 | if (x % i == 0) { 175 | divs[cnt++] = i; 176 | while (x % i == 0) { 177 | x /= i; 178 | } 179 | } 180 | } 181 | if (x > 1) { 182 | divs[cnt++] = x; 183 | } 184 | for (int g = 2;; g++) { 185 | bool ok = true; 186 | for (int i = 0; i < cnt; i++) { 187 | if (pow_mod_constexpr(g, (m - 1) / divs[i], m) == 1) { 188 | ok = false; 189 | break; 190 | } 191 | } 192 | if (ok) return g; 193 | } 194 | } 195 | template constexpr int primitive_root = primitive_root_constexpr(m); 196 | 197 | } // namespace internal 198 | 199 | } // namespace atcoder 200 | 201 | #endif // ATCODER_INTERNAL_MATH_HPP 202 | 203 | namespace atcoder { 204 | 205 | long long pow_mod(long long x, long long n, int m) { 206 | assert(0 <= n && 1 <= m); 207 | if (m == 1) return 0; 208 | internal::barrett bt((unsigned int)(m)); 209 | unsigned int r = 1, y = (unsigned int)(internal::safe_mod(x, m)); 210 | while (n) { 211 | if (n & 1) r = bt.mul(r, y); 212 | y = bt.mul(y, y); 213 | n >>= 1; 214 | } 215 | return r; 216 | } 217 | 218 | long long inv_mod(long long x, long long m) { 219 | assert(1 <= m); 220 | auto z = internal::inv_gcd(x, m); 221 | assert(z.first == 1); 222 | return z.second; 223 | } 224 | 225 | // (rem, mod) 226 | std::pair crt(const std::vector& r, 227 | const std::vector& m) { 228 | assert(r.size() == m.size()); 229 | int n = int(r.size()); 230 | // Contracts: 0 <= r0 < m0 231 | long long r0 = 0, m0 = 1; 232 | for (int i = 0; i < n; i++) { 233 | assert(1 <= m[i]); 234 | long long r1 = internal::safe_mod(r[i], m[i]), m1 = m[i]; 235 | if (m0 < m1) { 236 | std::swap(r0, r1); 237 | std::swap(m0, m1); 238 | } 239 | if (m0 % m1 == 0) { 240 | if (r0 % m1 != r1) return {0, 0}; 241 | continue; 242 | } 243 | // assume: m0 > m1, lcm(m0, m1) >= 2 * max(m0, m1) 244 | 245 | // (r0, m0), (r1, m1) -> (r2, m2 = lcm(m0, m1)); 246 | // r2 % m0 = r0 247 | // r2 % m1 = r1 248 | // -> (r0 + x*m0) % m1 = r1 249 | // -> x*u0*g % (u1*g) = (r1 - r0) (u0*g = m0, u1*g = m1) 250 | // -> x = (r1 - r0) / g * inv(u0) (mod u1) 251 | 252 | // im = inv(u0) (mod u1) (0 <= im < u1) 253 | long long g, im; 254 | std::tie(g, im) = internal::inv_gcd(m0, m1); 255 | 256 | long long u1 = (m1 / g); 257 | // |r1 - r0| < (m0 + m1) <= lcm(m0, m1) 258 | if ((r1 - r0) % g) return {0, 0}; 259 | 260 | // u1 * u1 <= m1 * m1 / g / g <= m0 * m1 / g = lcm(m0, m1) 261 | long long x = (r1 - r0) / g % u1 * im % u1; 262 | 263 | // |r0| + |m0 * x| 264 | // < m0 + m0 * (u1 - 1) 265 | // = m0 + m0 * m1 / g - m0 266 | // = lcm(m0, m1) 267 | r0 += x * m0; 268 | m0 *= u1; // -> lcm(m0, m1) 269 | if (r0 < 0) r0 += m0; 270 | } 271 | return {r0, m0}; 272 | } 273 | 274 | long long floor_sum(long long n, long long m, long long a, long long b) { 275 | long long ans = 0; 276 | if (a >= m) { 277 | ans += (n - 1) * n * (a / m) / 2; 278 | a %= m; 279 | } 280 | if (b >= m) { 281 | ans += n * (b / m); 282 | b %= m; 283 | } 284 | 285 | long long y_max = (a * n + b) / m, x_max = (y_max * m - b); 286 | if (y_max == 0) return ans; 287 | ans += (n - (x_max + a - 1) / a) * y_max; 288 | ans += floor_sum(y_max, a, m, (a - x_max % a) % a); 289 | return ans; 290 | } 291 | 292 | } // namespace atcoder 293 | 294 | #endif // ATCODER_MATH_HPP 295 | 296 | // <<< AtCoder <<< 297 | 298 | 299 | using namespace std; 300 | using namespace atcoder; 301 | #define PARSE_ARGS(types, ...) if(!PyArg_ParseTuple(args, types, __VA_ARGS__)) return NULL 302 | 303 | 304 | // >>> acl_math definition >>> 305 | 306 | static PyObject* acl_math_pow_mod(PyObject* self, PyObject* args){ 307 | long long x, n; 308 | long m; 309 | PARSE_ARGS("LLl", &x, &n, &m); 310 | if(n < 0 || m <= 0){ 311 | PyErr_Format(PyExc_IndexError, 312 | "pow_mod constraint error (costraint: 0<=n, 1<=m, got x=%lld, n=%lld, m=%d)", x, n, m); 313 | return (PyObject*)NULL; 314 | } 315 | return Py_BuildValue("L", pow_mod(x, n, m)); 316 | } 317 | static PyObject* acl_math_inv_mod(PyObject* self, PyObject* args){ 318 | long long x, m; 319 | PARSE_ARGS("LL", &x, &m); 320 | if(m <= 0){ 321 | PyErr_Format(PyExc_IndexError, 322 | "inv_mod constraint error (costraint: 1<=m, got x=%lld, m=%d)", x, m); 323 | return (PyObject*)NULL; 324 | } 325 | return Py_BuildValue("L", inv_mod(x, m)); 326 | } 327 | static PyObject* acl_math_crt(PyObject* self, PyObject* args){ 328 | PyObject *r_iterable, *m_iterable, *iterator, *item; 329 | PARSE_ARGS("OO", &r_iterable, &m_iterable); 330 | vector r, m; 331 | 332 | iterator = PyObject_GetIter(r_iterable); 333 | if(iterator==NULL) return NULL; 334 | if(Py_TYPE(r_iterable)->tp_as_sequence != NULL) r.reserve((int)Py_SIZE(r_iterable)); 335 | while(item = PyIter_Next(iterator)) { 336 | const long long& ri = PyLong_AsLongLong(item); 337 | r.push_back(ri); 338 | Py_DECREF(item); 339 | } 340 | Py_DECREF(iterator); 341 | if (PyErr_Occurred()) return NULL; 342 | 343 | iterator = PyObject_GetIter(m_iterable); 344 | if(iterator==NULL) return NULL; 345 | if(Py_TYPE(m_iterable)->tp_as_sequence != NULL) m.reserve((int)Py_SIZE(m_iterable)); 346 | while(item = PyIter_Next(iterator)) { 347 | const long long& mi = PyLong_AsLongLong(item); 348 | if(mi <= 0 && !PyErr_Occurred()) PyErr_Format(PyExc_ValueError, 349 | "crt constraint error (constraint: m>=1, got %lld)", mi); 350 | m.push_back(mi); 351 | Py_DECREF(item); 352 | } 353 | Py_DECREF(iterator); 354 | if (PyErr_Occurred()) return NULL; 355 | 356 | if(r.size() != m.size()){ 357 | PyErr_Format(PyExc_ValueError, 358 | "crt constraint error (constraint: len(r)=len(m), got len(r)=%d, len(m)=%d)", r.size(), m.size()); 359 | return NULL; 360 | } 361 | const pair& res = crt(r, m); 362 | return Py_BuildValue("LL", res.first, res.second); 363 | } 364 | static PyObject* acl_math_floor_sum(PyObject* self, PyObject* args){ 365 | long long n, m, a, b; 366 | PARSE_ARGS("LLLL", &n, &m, &a, &b); 367 | if(n < 0 || n > (long long)1e9 || m <= 0 || m > (long long)1e9 || a < 0 || a >= m || b < 0 || b >= m){ 368 | PyErr_Format(PyExc_IndexError, 369 | "floor_sum constraint error (costraint: 0<=n<=1e9, 1<=m<=1e9, 0<=a,b 6 | #include "structmember.h" 7 | 8 | // 元のライブラリの private を剥がした 9 | 10 | // >>> AtCoder >>> 11 | 12 | #ifndef ATCODER_MINCOSTFLOW_HPP 13 | #define ATCODER_MINCOSTFLOW_HPP 1 14 | 15 | #include 16 | #include 17 | #include 18 | #include 19 | #include 20 | 21 | namespace atcoder { 22 | 23 | template struct mcf_graph { 24 | public: 25 | mcf_graph() {} 26 | mcf_graph(int n) : _n(n), g(n) {} 27 | 28 | int add_edge(int from, int to, Cap cap, Cost cost) { 29 | assert(0 <= from && from < _n); 30 | assert(0 <= to && to < _n); 31 | int m = int(pos.size()); 32 | pos.push_back({from, int(g[from].size())}); 33 | int from_id = int(g[from].size()); 34 | int to_id = int(g[to].size()); 35 | if (from == to) to_id++; 36 | g[from].push_back(_edge{to, to_id, cap, cost}); 37 | g[to].push_back(_edge{from, from_id, 0, -cost}); 38 | return m; 39 | } 40 | 41 | struct edge { 42 | int from, to; 43 | Cap cap, flow; 44 | Cost cost; 45 | }; 46 | 47 | edge get_edge(int i) { 48 | int m = int(pos.size()); 49 | assert(0 <= i && i < m); 50 | auto _e = g[pos[i].first][pos[i].second]; 51 | auto _re = g[_e.to][_e.rev]; 52 | return edge{ 53 | pos[i].first, _e.to, _e.cap + _re.cap, _re.cap, _e.cost, 54 | }; 55 | } 56 | std::vector edges() { 57 | int m = int(pos.size()); 58 | std::vector result(m); 59 | for (int i = 0; i < m; i++) { 60 | result[i] = get_edge(i); 61 | } 62 | return result; 63 | } 64 | 65 | std::pair flow(int s, int t) { 66 | return flow(s, t, std::numeric_limits::max()); 67 | } 68 | std::pair flow(int s, int t, Cap flow_limit) { 69 | return slope(s, t, flow_limit).back(); 70 | } 71 | std::vector> slope(int s, int t) { 72 | return slope(s, t, std::numeric_limits::max()); 73 | } 74 | std::vector> slope(int s, int t, Cap flow_limit) { 75 | assert(0 <= s && s < _n); 76 | assert(0 <= t && t < _n); 77 | assert(s != t); 78 | // variants (C = maxcost): 79 | // -(n-1)C <= dual[s] <= dual[i] <= dual[t] = 0 80 | // reduced cost (= e.cost + dual[e.from] - dual[e.to]) >= 0 for all edge 81 | std::vector dual(_n, 0), dist(_n); 82 | std::vector pv(_n), pe(_n); 83 | std::vector vis(_n); 84 | auto dual_ref = [&]() { 85 | std::fill(dist.begin(), dist.end(), 86 | std::numeric_limits::max()); 87 | std::fill(pv.begin(), pv.end(), -1); 88 | std::fill(pe.begin(), pe.end(), -1); 89 | std::fill(vis.begin(), vis.end(), false); 90 | struct Q { 91 | Cost key; 92 | int to; 93 | bool operator<(Q r) const { return key > r.key; } 94 | }; 95 | std::priority_queue que; 96 | dist[s] = 0; 97 | que.push(Q{0, s}); 98 | while (!que.empty()) { 99 | int v = que.top().to; 100 | que.pop(); 101 | if (vis[v]) continue; 102 | vis[v] = true; 103 | if (v == t) break; 104 | // dist[v] = shortest(s, v) + dual[s] - dual[v] 105 | // dist[v] >= 0 (all reduced cost are positive) 106 | // dist[v] <= (n-1)C 107 | for (int i = 0; i < int(g[v].size()); i++) { 108 | auto e = g[v][i]; 109 | if (vis[e.to] || !e.cap) continue; 110 | // |-dual[e.to] + dual[v]| <= (n-1)C 111 | // cost <= C - -(n-1)C + 0 = nC 112 | Cost cost = e.cost - dual[e.to] + dual[v]; 113 | if (dist[e.to] - dist[v] > cost) { 114 | dist[e.to] = dist[v] + cost; 115 | pv[e.to] = v; 116 | pe[e.to] = i; 117 | que.push(Q{dist[e.to], e.to}); 118 | } 119 | } 120 | } 121 | if (!vis[t]) { 122 | return false; 123 | } 124 | 125 | for (int v = 0; v < _n; v++) { 126 | if (!vis[v]) continue; 127 | // dual[v] = dual[v] - dist[t] + dist[v] 128 | // = dual[v] - (shortest(s, t) + dual[s] - dual[t]) + (shortest(s, v) + dual[s] - dual[v]) 129 | // = - shortest(s, t) + dual[t] + shortest(s, v) 130 | // = shortest(s, v) - shortest(s, t) >= 0 - (n-1)C 131 | dual[v] -= dist[t] - dist[v]; 132 | } 133 | return true; 134 | }; 135 | Cap flow = 0; 136 | Cost cost = 0, prev_cost_per_flow = -1; 137 | std::vector> result; 138 | result.push_back({flow, cost}); 139 | while (flow < flow_limit) { 140 | if (!dual_ref()) break; 141 | Cap c = flow_limit - flow; 142 | for (int v = t; v != s; v = pv[v]) { 143 | c = std::min(c, g[pv[v]][pe[v]].cap); 144 | } 145 | for (int v = t; v != s; v = pv[v]) { 146 | auto& e = g[pv[v]][pe[v]]; 147 | e.cap -= c; 148 | g[v][e.rev].cap += c; 149 | } 150 | Cost d = -dual[s]; 151 | flow += c; 152 | cost += c * d; 153 | if (prev_cost_per_flow == d) { 154 | result.pop_back(); 155 | } 156 | result.push_back({flow, cost}); 157 | prev_cost_per_flow = d; 158 | } 159 | return result; 160 | } 161 | 162 | // private: 163 | int _n; 164 | 165 | struct _edge { 166 | int to, rev; 167 | Cap cap; 168 | Cost cost; 169 | }; 170 | 171 | std::vector> pos; 172 | std::vector> g; 173 | }; 174 | 175 | } // namespace atcoder 176 | 177 | #endif // ATCODER_MINCOSTFLOW_HPP 178 | 179 | // <<< AtCoder <<< 180 | 181 | using namespace std; 182 | using namespace atcoder; 183 | #define PARSE_ARGS(types, ...) if(!PyArg_ParseTuple(args, types, __VA_ARGS__)) return NULL 184 | 185 | 186 | struct MCFGraph{ 187 | PyObject_HEAD 188 | mcf_graph* graph; 189 | }; 190 | 191 | struct MCFGraphEdge{ 192 | PyObject_HEAD 193 | mcf_graph::edge* edge; 194 | }; 195 | 196 | 197 | extern PyTypeObject MCFGraphType; 198 | extern PyTypeObject MCFGraphEdgeType; 199 | 200 | 201 | // >>> MCFGraph definition >>> 202 | 203 | static void MCFGraph_dealloc(MCFGraph* self){ 204 | delete self->graph; 205 | Py_TYPE(self)->tp_free((PyObject*)self); 206 | } 207 | static PyObject* MCFGraph_new(PyTypeObject* type, PyObject* args, PyObject* kwds){ 208 | return type->tp_alloc(type, 0); 209 | } 210 | static int MCFGraph_init(MCFGraph* self, PyObject* args){ 211 | long n; 212 | if(!PyArg_ParseTuple(args, "l", &n)) return -1; 213 | if(n < 0 || n > (long)1e8){ 214 | PyErr_Format(PyExc_IndexError, 215 | "constraint error in MCFGraph constructor (constraint: 0<=n<=1e8, got n=%d)", n); 216 | return -1; 217 | } 218 | self->graph = new mcf_graph(n); 219 | return 0; 220 | } 221 | static PyObject* MCFGraph_add_edge(MCFGraph* self, PyObject* args){ 222 | long from, to; 223 | long long cap, cost; 224 | PARSE_ARGS("llLL", &from, &to, &cap, &cost); 225 | if(from < 0 || from >= self->graph->_n || to < 0 || to >= self->graph->_n){ 226 | PyErr_Format(PyExc_IndexError, 227 | "MCFGraph add_edge index out of range (n=%d, from=%d, to=%d)", self->graph->_n, from, to); 228 | return (PyObject*)NULL; 229 | } 230 | if(from == to){ 231 | PyErr_Format(PyExc_IndexError, "got self-loop (from=%d, to=%d)", from, to); 232 | return (PyObject*)NULL; 233 | } 234 | if(cap < 0){ 235 | PyErr_Format(PyExc_IndexError, "got negative cap (cap=%d)", cap); 236 | return (PyObject*)NULL; 237 | } 238 | if(cost < 0){ 239 | PyErr_Format(PyExc_IndexError, "got negative cost (cap=%d)", cost); 240 | return (PyObject*)NULL; 241 | } 242 | const int res = self->graph->add_edge(from, to, cap, cost); 243 | return Py_BuildValue("l", res); 244 | } 245 | static PyObject* MCFGraph_flow(MCFGraph* self, PyObject* args){ 246 | long s, t; 247 | long long flow_limit = numeric_limits::max(); 248 | PARSE_ARGS("ll|L", &s, &t, &flow_limit); 249 | if(s < 0 || s >= self->graph->_n || t < 0 || t >= self->graph->_n){ 250 | PyErr_Format(PyExc_IndexError, 251 | "MCFGraph flow index out of range (n=%d, s=%d, t=%d)", self->graph->_n, s, t); 252 | return (PyObject*)NULL; 253 | } 254 | if(s == t){ 255 | PyErr_Format(PyExc_IndexError, "got s == t (s=%d, t=%d)", s, t); 256 | return (PyObject*)NULL; 257 | } 258 | const pair& flow_cost = self->graph->flow(s, t, flow_limit); 259 | return Py_BuildValue("LL", flow_cost.first, flow_cost.second); 260 | } 261 | static PyObject* MCFGraph_slope(MCFGraph* self, PyObject* args){ 262 | long s, t; 263 | long long flow_limit = numeric_limits::max(); 264 | PARSE_ARGS("ll|L", &s, &t, &flow_limit); 265 | if(s < 0 || s >= self->graph->_n || t < 0 || t >= self->graph->_n){ 266 | PyErr_Format(PyExc_IndexError, 267 | "MCFGraph slope index out of range (n=%d, s=%d, t=%d)", self->graph->_n, s, t); 268 | return (PyObject*)NULL; 269 | } 270 | if(s == t){ 271 | PyErr_Format(PyExc_IndexError, "got s == t (s=%d, t=%d)", s, t); 272 | return (PyObject*)NULL; 273 | } 274 | const vector>& slope = self->graph->slope(s, t, flow_limit); 275 | const int siz = (int)slope.size(); 276 | PyObject* list = PyList_New(siz); 277 | for(int i = 0; i < siz; i++){ 278 | PyList_SET_ITEM(list, i, Py_BuildValue("LL", slope[i].first, slope[i].second)); 279 | } 280 | return list; 281 | } 282 | static PyObject* MCFGraph_get_edge(MCFGraph* self, PyObject* args){ 283 | long i; 284 | PARSE_ARGS("l", &i); 285 | const int m = (int)self->graph->pos.size(); 286 | if(i < 0 || i >= m){ 287 | PyErr_Format(PyExc_IndexError, 288 | "MCFGraph get_edge index out of range (m=%d, i=%d)", m, i); 289 | return (PyObject*)NULL; 290 | } 291 | MCFGraphEdge* edge = PyObject_NEW(MCFGraphEdge, &MCFGraphEdgeType); 292 | edge->edge = new mcf_graph::edge(self->graph->get_edge(i)); 293 | return (PyObject*)edge; 294 | } 295 | static PyObject* MCFGraph_edges(MCFGraph* self, PyObject* args){ 296 | const auto& edges = self->graph->edges(); 297 | const int m = (int)edges.size(); 298 | PyObject* list = PyList_New(m); 299 | for(int i = 0; i < m; i++){ 300 | MCFGraphEdge* edge = PyObject_NEW(MCFGraphEdge, &MCFGraphEdgeType); 301 | edge->edge = new mcf_graph::edge(edges[i]); 302 | PyList_SET_ITEM(list, i, (PyObject*)edge); 303 | } 304 | return list; 305 | } 306 | static PyObject* MCFGraph_repr(PyObject* self){ 307 | PyObject* edges = MCFGraph_edges((MCFGraph*)self, NULL); 308 | PyObject* res = PyUnicode_FromFormat("MCFGraph(%R)", edges); 309 | Py_DECREF(edges); 310 | return res; 311 | } 312 | static PyMethodDef MCFGraph_methods[] = { 313 | {"add_edge", (PyCFunction)MCFGraph_add_edge, METH_VARARGS, "Add edge"}, 314 | {"flow", (PyCFunction)MCFGraph_flow, METH_VARARGS, "Flow"}, 315 | {"slope", (PyCFunction)MCFGraph_slope, METH_VARARGS, "Slope"}, 316 | {"get_edge", (PyCFunction)MCFGraph_get_edge, METH_VARARGS, "Get edge"}, 317 | {"edges", (PyCFunction)MCFGraph_edges, METH_VARARGS, "Get edges"}, 318 | {NULL} /* Sentinel */ 319 | }; 320 | PyTypeObject MCFGraphType = { 321 | PyObject_HEAD_INIT(NULL) 322 | "atcoder.MCFGraph", /*tp_name*/ 323 | sizeof(MCFGraph), /*tp_basicsize*/ 324 | 0, /*tp_itemsize*/ 325 | (destructor)MCFGraph_dealloc, /*tp_dealloc*/ 326 | 0, /*tp_print*/ 327 | 0, /*tp_getattr*/ 328 | 0, /*tp_setattr*/ 329 | 0, /*reserved*/ 330 | MCFGraph_repr, /*tp_repr*/ 331 | 0, /*tp_as_number*/ 332 | 0, /*tp_as_sequence*/ 333 | 0, /*tp_as_mapping*/ 334 | 0, /*tp_hash*/ 335 | 0, /*tp_call*/ 336 | 0, /*tp_str*/ 337 | 0, /*tp_getattro*/ 338 | 0, /*tp_setattro*/ 339 | 0, /*tp_as_buffer*/ 340 | Py_TPFLAGS_DEFAULT | Py_TPFLAGS_BASETYPE, /*tp_flags*/ 341 | 0, /*tp_doc*/ 342 | 0, /*tp_traverse*/ 343 | 0, /*tp_clear*/ 344 | 0, /*tp_richcompare*/ 345 | 0, /*tp_weaklistoffset*/ 346 | 0, /*tp_iter*/ 347 | 0, /*tp_iternext*/ 348 | MCFGraph_methods, /*tp_methods*/ 349 | 0, /*tp_members*/ 350 | 0, /*tp_getset*/ 351 | 0, /*tp_base*/ 352 | 0, /*tp_dict*/ 353 | 0, /*tp_descr_get*/ 354 | 0, /*tp_descr_set*/ 355 | 0, /*tp_dictoffset*/ 356 | (initproc)MCFGraph_init, /*tp_init*/ 357 | 0, /*tp_alloc*/ 358 | MCFGraph_new, /*tp_new*/ 359 | 0, /*tp_free*/ 360 | 0, /*tp_is_gc*/ 361 | 0, /*tp_bases*/ 362 | 0, /*tp_mro*/ 363 | 0, /*tp_cache*/ 364 | 0, /*tp_subclasses*/ 365 | 0, /*tp_weaklist*/ 366 | 0, /*tp_del*/ 367 | 0, /*tp_version_tag*/ 368 | 0, /*tp_finalize*/ 369 | }; 370 | 371 | // <<< MCFGraph definition <<< 372 | 373 | 374 | // >>> MCFGraphEdge definition >>> 375 | 376 | static void MCFGraphEdge_dealloc(MCFGraphEdge* self){ 377 | delete self->edge; 378 | Py_TYPE(self)->tp_free((PyObject*)self); 379 | } 380 | static PyObject* MCFGraphEdge_new(PyTypeObject* type, PyObject* args, PyObject* kwds){ 381 | return type->tp_alloc(type, 0); 382 | } 383 | static int MCFGraphEdge_init(MCFGraphEdge* self, PyObject* args){ 384 | int from, to; 385 | long long cap, flow, cost; 386 | if(!PyArg_ParseTuple(args, "llLLL", &from, &to, &cap, &flow, &cost)) return -1; 387 | self->edge = new mcf_graph::edge(mcf_graph::edge{from, to, cap, flow, cost}); 388 | return 0; 389 | } 390 | static PyObject* MCFGraphEdge_get_from(MCFGraphEdge* self, PyObject* args){ 391 | return PyLong_FromLong(self->edge->from); 392 | } 393 | static PyObject* MCFGraphEdge_get_to(MCFGraphEdge* self, PyObject* args){ 394 | return PyLong_FromLong(self->edge->to); 395 | } 396 | static PyObject* MCFGraphEdge_get_flow(MCFGraphEdge* self, PyObject* args){ 397 | return PyLong_FromLongLong(self->edge->flow); 398 | } 399 | static PyObject* MCFGraphEdge_get_cap(MCFGraphEdge* self, PyObject* args){ 400 | return PyLong_FromLongLong(self->edge->cap); 401 | } 402 | static PyObject* MCFGraphEdge_get_cost(MCFGraphEdge* self, PyObject* args){ 403 | return PyLong_FromLongLong(self->edge->cost); 404 | } 405 | static PyObject* MCFGraphEdge_repr(PyObject* self){ 406 | MCFGraphEdge* self_ = (MCFGraphEdge*)self; 407 | PyObject* res = PyUnicode_FromFormat("MCFGraphEdge(%2d -> %2d, %2lld / %2lld, cost = %2lld)", 408 | self_->edge->from, self_->edge->to, self_->edge->flow, self_->edge->cap, self_->edge->cost); 409 | return res; 410 | } 411 | PyGetSetDef MCFGraphEdge_getsets[] = { 412 | {"from_", (getter)MCFGraphEdge_get_from, NULL, NULL, NULL}, 413 | {"to", (getter)MCFGraphEdge_get_to, NULL, NULL, NULL}, 414 | {"flow", (getter)MCFGraphEdge_get_flow, NULL, NULL, NULL}, 415 | {"cap", (getter)MCFGraphEdge_get_cap, NULL, NULL, NULL}, 416 | {"cost", (getter)MCFGraphEdge_get_cost, NULL, NULL, NULL}, 417 | {NULL} 418 | }; 419 | PyTypeObject MCFGraphEdgeType = { 420 | PyObject_HEAD_INIT(NULL) 421 | "atcoder.MCFGraphEdge", /*tp_name*/ 422 | sizeof(MCFGraphEdge), /*tp_basicsize*/ 423 | 0, /*tp_itemsize*/ 424 | (destructor)MCFGraphEdge_dealloc, /*tp_dealloc*/ 425 | 0, /*tp_print*/ 426 | 0, /*tp_getattr*/ 427 | 0, /*tp_setattr*/ 428 | 0, /*reserved*/ 429 | MCFGraphEdge_repr, /*tp_repr*/ 430 | 0, /*tp_as_number*/ 431 | 0, /*tp_as_sequence*/ 432 | 0, /*tp_as_mapping*/ 433 | 0, /*tp_hash*/ 434 | 0, /*tp_call*/ 435 | 0, /*tp_str*/ 436 | 0, /*tp_getattro*/ 437 | 0, /*tp_setattro*/ 438 | 0, /*tp_as_buffer*/ 439 | Py_TPFLAGS_DEFAULT | Py_TPFLAGS_BASETYPE, /*tp_flags*/ 440 | 0, /*tp_doc*/ 441 | 0, /*tp_traverse*/ 442 | 0, /*tp_clear*/ 443 | 0, /*tp_richcompare*/ 444 | 0, /*tp_weaklistoffset*/ 445 | 0, /*tp_iter*/ 446 | 0, /*tp_iternext*/ 447 | 0, /*tp_methods*/ 448 | 0, /*tp_members*/ 449 | MCFGraphEdge_getsets, /*tp_getset*/ 450 | 0, /*tp_base*/ 451 | 0, /*tp_dict*/ 452 | 0, /*tp_descr_get*/ 453 | 0, /*tp_descr_set*/ 454 | 0, /*tp_dictoffset*/ 455 | (initproc)MCFGraphEdge_init, /*tp_init*/ 456 | 0, /*tp_alloc*/ 457 | MCFGraphEdge_new, /*tp_new*/ 458 | 0, /*tp_free*/ 459 | 0, /*tp_is_gc*/ 460 | 0, /*tp_bases*/ 461 | 0, /*tp_mro*/ 462 | 0, /*tp_cache*/ 463 | 0, /*tp_subclasses*/ 464 | 0, /*tp_weaklist*/ 465 | 0, /*tp_del*/ 466 | 0, /*tp_version_tag*/ 467 | 0, /*tp_finalize*/ 468 | }; 469 | 470 | // <<< MCFGraphEdge definition <<< 471 | 472 | 473 | static PyModuleDef atcodermodule = { 474 | PyModuleDef_HEAD_INIT, 475 | "atcoder", 476 | NULL, 477 | -1, 478 | }; 479 | 480 | PyMODINIT_FUNC PyInit_atcoder(void) 481 | { 482 | PyObject* m; 483 | if(PyType_Ready(&MCFGraphType) < 0) return NULL; 484 | if(PyType_Ready(&MCFGraphEdgeType) < 0) return NULL; 485 | 486 | m = PyModule_Create(&atcodermodule); 487 | if(m == NULL) return NULL; 488 | 489 | Py_INCREF(&MCFGraphType); 490 | if (PyModule_AddObject(m, "MCFGraph", (PyObject*)&MCFGraphType) < 0) { 491 | Py_DECREF(&MCFGraphType); 492 | Py_DECREF(m); 493 | return NULL; 494 | } 495 | 496 | Py_INCREF(&MCFGraphEdgeType); 497 | if (PyModule_AddObject(m, "MCFGraphEdge", (PyObject*)&MCFGraphEdgeType) < 0) { 498 | Py_DECREF(&MCFGraphEdgeType); 499 | Py_DECREF(m); 500 | return NULL; 501 | } 502 | return m; 503 | } 504 | """ 505 | code_mcf_graph_setup = r""" 506 | from distutils.core import setup, Extension 507 | module = Extension( 508 | "atcoder", 509 | sources=["mcf_graph.cpp"], 510 | extra_compile_args=["-O3", "-march=native", "-std=c++14"] 511 | ) 512 | setup( 513 | name="atcoder-library", 514 | version="0.0.1", 515 | description="wrapper for atcoder library", 516 | ext_modules=[module] 517 | ) 518 | """ 519 | 520 | import os 521 | import sys 522 | 523 | if sys.argv[-1] == "ONLINE_JUDGE" or os.getcwd() != "/imojudge/sandbox": 524 | with open("mcf_graph.cpp", "w") as f: 525 | f.write(code_mcf_graph) 526 | with open("mcf_graph_setup.py", "w") as f: 527 | f.write(code_mcf_graph_setup) 528 | os.system(f"{sys.executable} mcf_graph_setup.py build_ext --inplace") 529 | 530 | from atcoder import MCFGraph, MCFGraphEdge 531 | -------------------------------------------------------------------------------- /cpp_extension/acl_mf_graph.py: -------------------------------------------------------------------------------- 1 | # TODO: メモリリーク確認 2 | # TODO: min_cut, change_edge が正しく動くか確認 3 | 4 | code_mf_graph = r""" 5 | #define PY_SSIZE_T_CLEAN 6 | #include 7 | #include "structmember.h" 8 | 9 | // 元のライブラリの private を剥がした 10 | 11 | // >>> AtCoder >>> 12 | 13 | #ifndef ATCODER_MAXFLOW_HPP 14 | #define ATCODER_MAXFLOW_HPP 1 15 | 16 | #include 17 | 18 | #ifndef ATCODER_INTERNAL_QUEUE_HPP 19 | #define ATCODER_INTERNAL_QUEUE_HPP 1 20 | 21 | #include 22 | 23 | namespace atcoder { 24 | 25 | namespace internal { 26 | 27 | template struct simple_queue { 28 | std::vector payload; 29 | int pos = 0; 30 | void reserve(int n) { payload.reserve(n); } 31 | int size() const { return int(payload.size()) - pos; } 32 | bool empty() const { return pos == int(payload.size()); } 33 | void push(const T& t) { payload.push_back(t); } 34 | T& front() { return payload[pos]; } 35 | void clear() { 36 | payload.clear(); 37 | pos = 0; 38 | } 39 | void pop() { pos++; } 40 | }; 41 | 42 | } // namespace internal 43 | 44 | } // namespace atcoder 45 | 46 | #endif // ATCODER_INTERNAL_QUEUE_HPP 47 | 48 | #include 49 | #include 50 | #include 51 | #include 52 | 53 | namespace atcoder { 54 | 55 | template struct mf_graph { 56 | public: 57 | mf_graph() : _n(0) {} 58 | mf_graph(int n) : _n(n), g(n) {} 59 | 60 | int add_edge(int from, int to, Cap cap) { 61 | assert(0 <= from && from < _n); 62 | assert(0 <= to && to < _n); 63 | assert(0 <= cap); 64 | int m = int(pos.size()); 65 | pos.push_back({from, int(g[from].size())}); 66 | int from_id = int(g[from].size()); 67 | int to_id = int(g[to].size()); 68 | if (from == to) to_id++; 69 | g[from].push_back(_edge{to, to_id, cap}); 70 | g[to].push_back(_edge{from, from_id, 0}); 71 | return m; 72 | } 73 | 74 | struct edge { 75 | int from, to; 76 | Cap cap, flow; 77 | }; 78 | 79 | edge get_edge(int i) { 80 | int m = int(pos.size()); 81 | assert(0 <= i && i < m); 82 | auto _e = g[pos[i].first][pos[i].second]; 83 | auto _re = g[_e.to][_e.rev]; 84 | return edge{pos[i].first, _e.to, _e.cap + _re.cap, _re.cap}; 85 | } 86 | std::vector edges() { 87 | int m = int(pos.size()); 88 | std::vector result; 89 | for (int i = 0; i < m; i++) { 90 | result.push_back(get_edge(i)); 91 | } 92 | return result; 93 | } 94 | void change_edge(int i, Cap new_cap, Cap new_flow) { 95 | int m = int(pos.size()); 96 | assert(0 <= i && i < m); 97 | assert(0 <= new_flow && new_flow <= new_cap); 98 | auto& _e = g[pos[i].first][pos[i].second]; 99 | auto& _re = g[_e.to][_e.rev]; 100 | _e.cap = new_cap - new_flow; 101 | _re.cap = new_flow; 102 | } 103 | 104 | Cap flow(int s, int t) { 105 | return flow(s, t, std::numeric_limits::max()); 106 | } 107 | Cap flow(int s, int t, Cap flow_limit) { 108 | assert(0 <= s && s < _n); 109 | assert(0 <= t && t < _n); 110 | assert(s != t); 111 | 112 | std::vector level(_n), iter(_n); 113 | internal::simple_queue que; 114 | 115 | auto bfs = [&]() { 116 | std::fill(level.begin(), level.end(), -1); 117 | level[s] = 0; 118 | que.clear(); 119 | que.push(s); 120 | while (!que.empty()) { 121 | int v = que.front(); 122 | que.pop(); 123 | for (auto e : g[v]) { 124 | if (e.cap == 0 || level[e.to] >= 0) continue; 125 | level[e.to] = level[v] + 1; 126 | if (e.to == t) return; 127 | que.push(e.to); 128 | } 129 | } 130 | }; 131 | auto dfs = [&](auto self, int v, Cap up) { 132 | if (v == s) return up; 133 | Cap res = 0; 134 | int level_v = level[v]; 135 | for (int& i = iter[v]; i < int(g[v].size()); i++) { 136 | _edge& e = g[v][i]; 137 | if (level_v <= level[e.to] || g[e.to][e.rev].cap == 0) continue; 138 | Cap d = 139 | self(self, e.to, std::min(up - res, g[e.to][e.rev].cap)); 140 | if (d <= 0) continue; 141 | g[v][i].cap += d; 142 | g[e.to][e.rev].cap -= d; 143 | res += d; 144 | if (res == up) break; 145 | } 146 | return res; 147 | }; 148 | 149 | Cap flow = 0; 150 | while (flow < flow_limit) { 151 | bfs(); 152 | if (level[t] == -1) break; 153 | std::fill(iter.begin(), iter.end(), 0); 154 | while (flow < flow_limit) { 155 | Cap f = dfs(dfs, t, flow_limit - flow); 156 | if (!f) break; 157 | flow += f; 158 | } 159 | } 160 | return flow; 161 | } 162 | 163 | std::vector min_cut(int s) { 164 | std::vector visited(_n); 165 | internal::simple_queue que; 166 | que.push(s); 167 | while (!que.empty()) { 168 | int p = que.front(); 169 | que.pop(); 170 | visited[p] = true; 171 | for (auto e : g[p]) { 172 | if (e.cap && !visited[e.to]) { 173 | visited[e.to] = true; 174 | que.push(e.to); 175 | } 176 | } 177 | } 178 | return visited; 179 | } 180 | 181 | // private: 182 | int _n; 183 | struct _edge { 184 | int to, rev; 185 | Cap cap; 186 | }; 187 | std::vector> pos; 188 | std::vector> g; 189 | }; 190 | 191 | } // namespace atcoder 192 | 193 | #endif // ATCODER_MAXFLOW_HPP 194 | 195 | // <<< AtCoder <<< 196 | 197 | using namespace std; 198 | using namespace atcoder; 199 | #define PARSE_ARGS(types, ...) if(!PyArg_ParseTuple(args, types, __VA_ARGS__)) return NULL 200 | 201 | 202 | struct MFGraph{ 203 | PyObject_HEAD 204 | mf_graph* graph; 205 | //unique_ptr> graph; 206 | }; 207 | 208 | struct MFGraphEdge{ 209 | PyObject_HEAD 210 | mf_graph::edge* edge; 211 | //unique_ptr::edge> edge; 212 | }; 213 | 214 | 215 | extern PyTypeObject MFGraphType; 216 | extern PyTypeObject MFGraphEdgeType; 217 | 218 | 219 | // >>> MFGraph definition >>> 220 | 221 | static void MFGraph_dealloc(MFGraph* self){ 222 | delete self->graph; 223 | Py_TYPE(self)->tp_free((PyObject*)self); 224 | } 225 | static PyObject* MFGraph_new(PyTypeObject* type, PyObject* args, PyObject* kwds){ 226 | MFGraph* self; 227 | self = (MFGraph*)type->tp_alloc(type, 0); 228 | return (PyObject*)self; 229 | } 230 | static int MFGraph_init(MFGraph* self, PyObject* args){ 231 | long n; 232 | if(!PyArg_ParseTuple(args, "l", &n)) return -1; 233 | if(n < 0 || n > (long)1e8){ 234 | PyErr_Format(PyExc_IndexError, 235 | "constraint error in MFGraph constructor (constraint: 0<=n<=1e8, got n=%d)", n); 236 | return -1; 237 | } 238 | //self->graph = make_unique>(n); 239 | self->graph = new mf_graph(n); 240 | return 0; 241 | } 242 | static PyObject* MFGraph_add_edge(MFGraph* self, PyObject* args){ 243 | long from, to; 244 | long long cap; 245 | PARSE_ARGS("llL", &from, &to, &cap); 246 | if(from < 0 || from >= self->graph->_n || to < 0 || to >= self->graph->_n){ 247 | PyErr_Format(PyExc_IndexError, 248 | "MFGraph add_edge index out of range (n=%d, from=%d, to=%d)", self->graph->_n, from, to); 249 | return (PyObject*)NULL; 250 | } 251 | if(from == to){ 252 | PyErr_Format(PyExc_IndexError, "got self-loop (from=%d, to=%d)", from, to); 253 | return (PyObject*)NULL; 254 | } 255 | if(cap < 0){ 256 | PyErr_Format(PyExc_IndexError, "got negative cap (cap=%d)", cap); 257 | return (PyObject*)NULL; 258 | } 259 | const int res = self->graph->add_edge(from, to, cap); 260 | return Py_BuildValue("l", res); 261 | } 262 | static PyObject* MFGraph_flow(MFGraph* self, PyObject* args){ 263 | long s, t; 264 | long long flow_limit = numeric_limits::max(); 265 | PARSE_ARGS("ll|L", &s, &t, &flow_limit); 266 | if(s < 0 || s >= self->graph->_n || t < 0 || t >= self->graph->_n){ 267 | PyErr_Format(PyExc_IndexError, 268 | "MFGraph flow index out of range (n=%d, s=%d, t=%d)", self->graph->_n, s, t); 269 | return (PyObject*)NULL; 270 | } 271 | if(s == t){ 272 | PyErr_Format(PyExc_IndexError, "got s == t (s=%d, t=%d)", s, t); 273 | return (PyObject*)NULL; 274 | } 275 | const long long& flow = self->graph->flow(s, t, flow_limit); 276 | return Py_BuildValue("L", flow); 277 | } 278 | static PyObject* MFGraph_min_cut(MFGraph* self, PyObject* args){ 279 | long s; 280 | PARSE_ARGS("l", &s); 281 | if(s < 0 || s >= self->graph->_n){ 282 | PyErr_Format(PyExc_IndexError, 283 | "MFGraph min_cut index out of range (n=%d, s=%d)", self->graph->_n, s); 284 | return (PyObject*)NULL; 285 | } 286 | const vector& vec = self->graph->min_cut(s); 287 | PyObject* list = PyList_New(vec.size()); 288 | for(int i = 0; i < (int)vec.size(); i++){ 289 | PyObject* b = vec[i] ? Py_True : Py_False; 290 | Py_INCREF(b); 291 | PyList_SET_ITEM(list, i, b); 292 | } 293 | return list; 294 | } 295 | static PyObject* MFGraph_get_edge(MFGraph* self, PyObject* args){ 296 | long i; 297 | PARSE_ARGS("l", &i); 298 | const int m = (int)self->graph->pos.size(); 299 | if(i < 0 || i >= m){ 300 | PyErr_Format(PyExc_IndexError, 301 | "MFGraph get_edge index out of range (m=%d, i=%d)", m, i); 302 | return (PyObject*)NULL; 303 | } 304 | MFGraphEdge* edge = PyObject_NEW(MFGraphEdge, &MFGraphEdgeType); 305 | //edge->edge = make_unique::edge>(self->graph->get_edge(i)); // なぜか edge に値が入っていて詰まる 306 | edge->edge = new mf_graph::edge(self->graph->get_edge(i)); 307 | return (PyObject*)edge; 308 | } 309 | static PyObject* MFGraph_edges(MFGraph* self, PyObject* args){ 310 | const auto& edges = self->graph->edges(); 311 | const int m = (int)edges.size(); 312 | PyObject* list = PyList_New(m); 313 | for(int i = 0; i < m; i++){ 314 | MFGraphEdge* edge = PyObject_NEW(MFGraphEdge, &MFGraphEdgeType); 315 | //edge->edge = make_unique::edge>(edges[i]); 316 | edge->edge = new mf_graph::edge(edges[i]); 317 | PyList_SET_ITEM(list, i, (PyObject*)edge); 318 | } 319 | return list; 320 | } 321 | static PyObject* MFGraph_change_edge(MFGraph* self, PyObject* args){ 322 | long i; 323 | long long new_cap, new_flow; 324 | PARSE_ARGS("lLL", &i, &new_cap, &new_flow); 325 | const int m = (int)self->graph->pos.size(); 326 | if(i < 0 || i >= m){ 327 | PyErr_Format(PyExc_IndexError, 328 | "MFGraph change_edge index out of range (m=%d, i=%d)", m, i); 329 | return (PyObject*)NULL; 330 | } 331 | if(new_flow < 0 || new_cap < new_flow){ 332 | PyErr_Format( 333 | PyExc_IndexError, 334 | "MFGraph change_edge constraint error (constraint: 0<=new_flow<=new_cap, got new_flow=%lld, new_cap=%lld)", 335 | new_flow, new_cap 336 | ); 337 | return (PyObject*)NULL; 338 | } 339 | self->graph->change_edge(i, new_cap, new_flow); 340 | Py_RETURN_NONE; 341 | } 342 | static PyObject* MFGraph_repr(PyObject* self){ 343 | PyObject* edges = MFGraph_edges((MFGraph*)self, NULL); 344 | PyObject* res = PyUnicode_FromFormat("MFGraph(%R)", edges); 345 | Py_DECREF(edges); 346 | return res; 347 | } 348 | static PyMethodDef MFGraph_methods[] = { 349 | {"add_edge", (PyCFunction)MFGraph_add_edge, METH_VARARGS, "Add edge"}, 350 | {"flow", (PyCFunction)MFGraph_flow, METH_VARARGS, "Flow"}, 351 | {"min_cut", (PyCFunction)MFGraph_min_cut, METH_VARARGS, "Get vertices those can be reached from source"}, 352 | {"get_edge", (PyCFunction)MFGraph_get_edge, METH_VARARGS, "Get edge"}, 353 | {"edges", (PyCFunction)MFGraph_edges, METH_VARARGS, "Get edges"}, 354 | {"change_edge", (PyCFunction)MFGraph_change_edge, METH_VARARGS, "Change edge"}, 355 | {NULL} /* Sentinel */ 356 | }; 357 | PyTypeObject MFGraphType = { 358 | PyObject_HEAD_INIT(NULL) 359 | "atcoder.MFGraph", /*tp_name*/ 360 | sizeof(MFGraph), /*tp_basicsize*/ 361 | 0, /*tp_itemsize*/ 362 | (destructor)MFGraph_dealloc, /*tp_dealloc*/ 363 | 0, /*tp_print*/ 364 | 0, /*tp_getattr*/ 365 | 0, /*tp_setattr*/ 366 | 0, /*reserved*/ 367 | MFGraph_repr, /*tp_repr*/ 368 | 0, /*tp_as_number*/ 369 | 0, /*tp_as_sequence*/ 370 | 0, /*tp_as_mapping*/ 371 | 0, /*tp_hash*/ 372 | 0, /*tp_call*/ 373 | 0, /*tp_str*/ 374 | 0, /*tp_getattro*/ 375 | 0, /*tp_setattro*/ 376 | 0, /*tp_as_buffer*/ 377 | Py_TPFLAGS_DEFAULT | Py_TPFLAGS_BASETYPE, /*tp_flags*/ 378 | 0, /*tp_doc*/ 379 | 0, /*tp_traverse*/ 380 | 0, /*tp_clear*/ 381 | 0, /*tp_richcompare*/ 382 | 0, /*tp_weaklistoffset*/ 383 | 0, /*tp_iter*/ 384 | 0, /*tp_iternext*/ 385 | MFGraph_methods, /*tp_methods*/ 386 | 0, /*tp_members*/ 387 | 0, /*tp_getset*/ 388 | 0, /*tp_base*/ 389 | 0, /*tp_dict*/ 390 | 0, /*tp_descr_get*/ 391 | 0, /*tp_descr_set*/ 392 | 0, /*tp_dictoffset*/ 393 | (initproc)MFGraph_init, /*tp_init*/ 394 | 0, /*tp_alloc*/ 395 | MFGraph_new, /*tp_new*/ 396 | 0, /*tp_free*/ 397 | 0, /*tp_is_gc*/ 398 | 0, /*tp_bases*/ 399 | 0, /*tp_mro*/ 400 | 0, /*tp_cache*/ 401 | 0, /*tp_subclasses*/ 402 | 0, /*tp_weaklist*/ 403 | 0, /*tp_del*/ 404 | 0, /*tp_version_tag*/ 405 | 0, /*tp_finalize*/ 406 | }; 407 | 408 | // <<< MFGraph definition <<< 409 | 410 | 411 | // >>> MFGraphEdge definition >>> 412 | 413 | static void MFGraphEdge_dealloc(MFGraphEdge* self){ 414 | delete self->edge; 415 | Py_TYPE(self)->tp_free((PyObject*)self); 416 | } 417 | static PyObject* MFGraphEdge_new(PyTypeObject* type, PyObject* args, PyObject* kwds){ 418 | //MFGraphEdge* self; 419 | //self = (MFGraphEdge*)type->tp_alloc(type, 0); 420 | //return (PyObject*)self; 421 | return type->tp_alloc(type, 0); 422 | } 423 | static int MFGraphEdge_init(MFGraphEdge* self, PyObject* args){ 424 | int from, to; 425 | long long cap, flow; 426 | if(!PyArg_ParseTuple(args, "llLL", &from, &to, &cap, &flow)) return -1; 427 | //self->edge = make_unique::edge>(mf_graph::edge{from, to, cap, flow}); 428 | self->edge = new mf_graph::edge(mf_graph::edge{from, to, cap, flow}); 429 | return 0; 430 | } 431 | static PyObject* MFGraphEdge_get_from(MFGraphEdge* self, PyObject* args){ 432 | return PyLong_FromLong(self->edge->from); 433 | } 434 | static PyObject* MFGraphEdge_get_to(MFGraphEdge* self, PyObject* args){ 435 | return PyLong_FromLong(self->edge->to); 436 | } 437 | static PyObject* MFGraphEdge_get_flow(MFGraphEdge* self, PyObject* args){ 438 | return PyLong_FromLongLong(self->edge->flow); 439 | } 440 | static PyObject* MFGraphEdge_get_cap(MFGraphEdge* self, PyObject* args){ 441 | return PyLong_FromLongLong(self->edge->cap); 442 | } 443 | static PyObject* MFGraphEdge_repr(PyObject* self){ 444 | MFGraphEdge* self_ = (MFGraphEdge*)self; 445 | PyObject* res = PyUnicode_FromFormat("MFGraphEdge(%2d -> %2d, %2lld / %2lld)", 446 | self_->edge->from, self_->edge->to, self_->edge->flow, self_->edge->cap); 447 | return res; 448 | } 449 | PyGetSetDef MFGraphEdge_getsets[] = { 450 | {"from_", (getter)MFGraphEdge_get_from, NULL, NULL, NULL}, 451 | {"to", (getter)MFGraphEdge_get_to, NULL, NULL, NULL}, 452 | {"flow", (getter)MFGraphEdge_get_flow, NULL, NULL, NULL}, 453 | {"cap", (getter)MFGraphEdge_get_cap, NULL, NULL, NULL}, 454 | {NULL} 455 | }; 456 | PyTypeObject MFGraphEdgeType = { 457 | PyObject_HEAD_INIT(NULL) 458 | "atcoder.MFGraphEdge", /*tp_name*/ 459 | sizeof(MFGraphEdge), /*tp_basicsize*/ 460 | 0, /*tp_itemsize*/ 461 | (destructor)MFGraphEdge_dealloc, /*tp_dealloc*/ 462 | 0, /*tp_print*/ 463 | 0, /*tp_getattr*/ 464 | 0, /*tp_setattr*/ 465 | 0, /*reserved*/ 466 | MFGraphEdge_repr, /*tp_repr*/ 467 | 0, /*tp_as_number*/ 468 | 0, /*tp_as_sequence*/ 469 | 0, /*tp_as_mapping*/ 470 | 0, /*tp_hash*/ 471 | 0, /*tp_call*/ 472 | 0, /*tp_str*/ 473 | 0, /*tp_getattro*/ 474 | 0, /*tp_setattro*/ 475 | 0, /*tp_as_buffer*/ 476 | Py_TPFLAGS_DEFAULT | Py_TPFLAGS_BASETYPE, /*tp_flags*/ 477 | 0, /*tp_doc*/ 478 | 0, /*tp_traverse*/ 479 | 0, /*tp_clear*/ 480 | 0, /*tp_richcompare*/ 481 | 0, /*tp_weaklistoffset*/ 482 | 0, /*tp_iter*/ 483 | 0, /*tp_iternext*/ 484 | 0, /*tp_methods*/ 485 | 0, /*tp_members*/ 486 | MFGraphEdge_getsets, /*tp_getset*/ 487 | 0, /*tp_base*/ 488 | 0, /*tp_dict*/ 489 | 0, /*tp_descr_get*/ 490 | 0, /*tp_descr_set*/ 491 | 0, /*tp_dictoffset*/ 492 | (initproc)MFGraphEdge_init, /*tp_init*/ 493 | 0, /*tp_alloc*/ 494 | MFGraphEdge_new, /*tp_new*/ 495 | 0, /*tp_free*/ 496 | 0, /*tp_is_gc*/ 497 | 0, /*tp_bases*/ 498 | 0, /*tp_mro*/ 499 | 0, /*tp_cache*/ 500 | 0, /*tp_subclasses*/ 501 | 0, /*tp_weaklist*/ 502 | 0, /*tp_del*/ 503 | 0, /*tp_version_tag*/ 504 | 0, /*tp_finalize*/ 505 | }; 506 | 507 | // <<< MFGraphEdge definition <<< 508 | 509 | 510 | static PyModuleDef atcodermodule = { 511 | PyModuleDef_HEAD_INIT, 512 | "atcoder", 513 | NULL, 514 | -1, 515 | }; 516 | 517 | PyMODINIT_FUNC PyInit_atcoder(void) 518 | { 519 | PyObject* m; 520 | if(PyType_Ready(&MFGraphType) < 0) return NULL; 521 | if(PyType_Ready(&MFGraphEdgeType) < 0) return NULL; 522 | 523 | m = PyModule_Create(&atcodermodule); 524 | if(m == NULL) return NULL; 525 | 526 | Py_INCREF(&MFGraphType); 527 | if (PyModule_AddObject(m, "MFGraph", (PyObject*)&MFGraphType) < 0) { 528 | Py_DECREF(&MFGraphType); 529 | Py_DECREF(m); 530 | return NULL; 531 | } 532 | 533 | Py_INCREF(&MFGraphEdgeType); 534 | if (PyModule_AddObject(m, "MFGraphEdge", (PyObject*)&MFGraphEdgeType) < 0) { 535 | Py_DECREF(&MFGraphEdgeType); 536 | Py_DECREF(m); 537 | return NULL; 538 | } 539 | return m; 540 | } 541 | """ 542 | code_mf_graph_setup = r""" 543 | from distutils.core import setup, Extension 544 | module = Extension( 545 | "atcoder", 546 | sources=["mf_graph.cpp"], 547 | extra_compile_args=["-O3", "-march=native", "-std=c++14"] 548 | ) 549 | setup( 550 | name="atcoder-library", 551 | version="0.0.1", 552 | description="wrapper for atcoder library", 553 | ext_modules=[module] 554 | ) 555 | """ 556 | 557 | import os 558 | import sys 559 | 560 | if sys.argv[-1] == "ONLINE_JUDGE" or os.getcwd() != "/imojudge/sandbox": 561 | with open("mf_graph.cpp", "w") as f: 562 | f.write(code_mf_graph) 563 | with open("mf_graph_setup.py", "w") as f: 564 | f.write(code_mf_graph_setup) 565 | os.system(f"{sys.executable} mf_graph_setup.py build_ext --inplace") 566 | 567 | from atcoder import MFGraph, MFGraphEdge 568 | 569 | 570 | class ProjectSelection: 571 | def __init__(self, n, n_buffer=0): 572 | self.n = n 573 | self.n_buffer = n_buffer 574 | self.g = MFGraph(n + n_buffer + 2) 575 | self.offset = 0 576 | self.additional_node_index = n + 2 577 | 578 | def add_constraint(self, x, zero_one, gain): 579 | assert 0 <= x < self.n, f"x={x}, self.n={self.n}" 580 | s = self.n 581 | t = s + 1 582 | if zero_one == 0: 583 | if gain > 0: 584 | self.offset += gain 585 | self.g.add_edge(s, x, gain) 586 | elif gain < 0: 587 | self.g.add_edge(x, t, -gain) 588 | elif zero_one == 1: 589 | if gain > 0: 590 | self.offset += gain 591 | self.g.add_edge(x, t, gain) 592 | else: 593 | self.g.add_edge(s, x, -gain) 594 | else: 595 | assert False, f"zero_one={zero_one}" 596 | 597 | def add_constraint_01(self, x, y, gain): # (x, y) == (0, 1) なら gain 得る 598 | assert gain <= 0 599 | self.g.add_edge(x, y, -gain) 600 | 601 | def add_constraint_neq(self, x, y, gain): # 異なるなら gain 得る 602 | assert gain <= 0 603 | self.g.add_edge(x, y, -gain) 604 | self.g.add_edge(y, x, -gain) 605 | 606 | def add_constraint_eq(self, x, y, gain): # 同じなら gain 得る 607 | assert gain >= 0 608 | self.offset += gain 609 | self.g.add_edge(x, y, gain) 610 | self.g.add_edge(y, x, gain) 611 | 612 | def add_constraint_2(self, x, y, zero_one, gain): # x == y == zero_one なら gain 得る 613 | assert gain >= 0 614 | s = self.n 615 | t = s + 1 616 | w = self.additional_node_index 617 | assert w < self.n + self.n_buffer + 2 618 | self.offset += gain 619 | if zero_one == 0: 620 | self.g.add_edge(s, w, gain) 621 | self.g.add_edge(w, x, 10**18) 622 | self.g.add_edge(w, y, 10**18) 623 | elif zero_one == 1: 624 | self.g.add_edge(w, t, gain) 625 | self.g.add_edge(x, w, 10**18) 626 | self.g.add_edge(y, w, 10**18) 627 | else: 628 | assert False 629 | self.additional_node_index += 1 630 | 631 | def solve(self): 632 | return self.offset - self.g.flow(self.n, self.n+1) 633 | 634 | 635 | -------------------------------------------------------------------------------- /cpp_extension/acl_segtree.py: -------------------------------------------------------------------------------- 1 | # TODO: 更新ルールの異なる複数のセグ木を作ったときに正しく動くか検証 2 | # TODO: 引数の順番間違えたりすると黙って落ちるのを何とかする 3 | 4 | # 注: PyPy で普通に書いた方が速い 5 | 6 | 7 | code_segtree = r""" 8 | #define PY_SSIZE_T_CLEAN 9 | #include 10 | #include "structmember.h" 11 | 12 | // >>> AtCoder >>> 13 | 14 | #ifndef ATCODER_SEGTREE_HPP 15 | #define ATCODER_SEGTREE_HPP 1 16 | 17 | #include 18 | #ifndef ATCODER_INTERNAL_BITOP_HPP 19 | #define ATCODER_INTERNAL_BITOP_HPP 1 20 | 21 | #ifdef _MSC_VER 22 | #include 23 | #endif 24 | 25 | namespace atcoder { 26 | 27 | namespace internal { 28 | 29 | // @param n `0 <= n` 30 | // @return minimum non-negative `x` s.t. `n <= 2**x` 31 | int ceil_pow2(int n) { 32 | int x = 0; 33 | while ((1U << x) < (unsigned int)(n)) x++; 34 | return x; 35 | } 36 | 37 | // @param n `1 <= n` 38 | // @return minimum non-negative `x` s.t. `(n & (1 << x)) != 0` 39 | int bsf(unsigned int n) { 40 | #ifdef _MSC_VER 41 | unsigned long index; 42 | _BitScanForward(&index, n); 43 | return index; 44 | #else 45 | return __builtin_ctz(n); 46 | #endif 47 | } 48 | 49 | } // namespace internal 50 | 51 | } // namespace atcoder 52 | 53 | #endif // ATCODER_INTERNAL_BITOP_HPP 54 | #include 55 | #include 56 | 57 | namespace atcoder { 58 | 59 | template struct segtree { 60 | public: 61 | segtree() : segtree(0) {} 62 | segtree(int n) : segtree(std::vector(n, e())) {} 63 | segtree(const std::vector& v) : _n(int(v.size())) { 64 | log = internal::ceil_pow2(_n); 65 | size = 1 << log; 66 | d = std::vector(2 * size, e()); 67 | for (int i = 0; i < _n; i++) d[size + i] = v[i]; 68 | for (int i = size - 1; i >= 1; i--) { 69 | update(i); 70 | } 71 | } 72 | 73 | void set(int p, S x) { 74 | assert(0 <= p && p < _n); 75 | p += size; 76 | d[p] = x; 77 | for (int i = 1; i <= log; i++) update(p >> i); 78 | } 79 | 80 | S get(int p) { 81 | assert(0 <= p && p < _n); 82 | return d[p + size]; 83 | } 84 | 85 | S prod(int l, int r) { 86 | assert(0 <= l && l <= r && r <= _n); 87 | S sml = e(), smr = e(); 88 | l += size; 89 | r += size; 90 | 91 | while (l < r) { 92 | if (l & 1) sml = op(sml, d[l++]); 93 | if (r & 1) smr = op(d[--r], smr); 94 | l >>= 1; 95 | r >>= 1; 96 | } 97 | return op(sml, smr); 98 | } 99 | 100 | S all_prod() { return d[1]; } 101 | 102 | template int max_right(int l) { 103 | return max_right(l, [](S x) { return f(x); }); 104 | } 105 | template int max_right(int l, F f) { 106 | assert(0 <= l && l <= _n); 107 | assert(f(e())); 108 | if (l == _n) return _n; 109 | l += size; 110 | S sm = e(); 111 | do { 112 | while (l % 2 == 0) l >>= 1; 113 | if (!f(op(sm, d[l]))) { 114 | while (l < size) { 115 | l = (2 * l); 116 | if (f(op(sm, d[l]))) { 117 | sm = op(sm, d[l]); 118 | l++; 119 | } 120 | } 121 | return l - size; 122 | } 123 | sm = op(sm, d[l]); 124 | l++; 125 | } while ((l & -l) != l); 126 | return _n; 127 | } 128 | 129 | template int min_left(int r) { 130 | return min_left(r, [](S x) { return f(x); }); 131 | } 132 | template int min_left(int r, F f) { 133 | assert(0 <= r && r <= _n); 134 | assert(f(e())); 135 | if (r == 0) return 0; 136 | r += size; 137 | S sm = e(); 138 | do { 139 | r--; 140 | while (r > 1 && (r % 2)) r >>= 1; 141 | if (!f(op(d[r], sm))) { 142 | while (r < size) { 143 | r = (2 * r + 1); 144 | if (f(op(d[r], sm))) { 145 | sm = op(d[r], sm); 146 | r--; 147 | } 148 | } 149 | return r + 1 - size; 150 | } 151 | sm = op(d[r], sm); 152 | } while ((r & -r) != r); 153 | return 0; 154 | } 155 | 156 | private: 157 | int _n, size, log; 158 | std::vector d; 159 | 160 | void update(int k) { d[k] = op(d[2 * k], d[2 * k + 1]); } 161 | }; 162 | 163 | } // namespace atcoder 164 | 165 | #endif // ATCODER_SEGTREE_HPP 166 | 167 | // <<< AtCoder <<< 168 | 169 | using namespace std; 170 | using namespace atcoder; 171 | #define PARSE_ARGS(types, ...) if(!PyArg_ParseTuple(args, types, __VA_ARGS__)) return NULL 172 | 173 | struct AutoDecrefPtr{ 174 | PyObject* p; 175 | AutoDecrefPtr(PyObject* _p) : p(_p) {}; 176 | AutoDecrefPtr(const AutoDecrefPtr& rhs) : p(rhs.p) { Py_INCREF(p); }; 177 | ~AutoDecrefPtr(){ Py_DECREF(p); } 178 | AutoDecrefPtr &operator=(const AutoDecrefPtr& rhs){ 179 | Py_DECREF(p); 180 | p = rhs.p; 181 | Py_INCREF(p); 182 | return *this; 183 | } 184 | }; 185 | 186 | static PyObject* segtree_op_py; 187 | static AutoDecrefPtr segtree_op(AutoDecrefPtr a, AutoDecrefPtr b){ 188 | auto tmp = PyObject_CallFunctionObjArgs(segtree_op_py, a.p, b.p, NULL); 189 | return AutoDecrefPtr(tmp); 190 | } 191 | static PyObject* segtree_e_py; 192 | static AutoDecrefPtr segtree_e(){ 193 | Py_INCREF(segtree_e_py); 194 | return AutoDecrefPtr(segtree_e_py); 195 | } 196 | static PyObject* segtree_f_py; 197 | static bool segtree_f(AutoDecrefPtr x){ 198 | PyObject* pyfunc_res = PyObject_CallFunctionObjArgs(segtree_f_py, x.p, NULL); 199 | int res = PyObject_IsTrue(pyfunc_res); 200 | if(res == -1) PyErr_Format(PyExc_ValueError, "error in SegTree f"); 201 | return (bool)res; 202 | } 203 | struct SegTree{ 204 | PyObject_HEAD 205 | segtree* seg; 206 | PyObject* op; 207 | PyObject* e; 208 | int n; 209 | }; 210 | 211 | extern PyTypeObject SegTreeType; 212 | 213 | static void SegTree_dealloc(SegTree* self){ 214 | delete self->seg; 215 | Py_DECREF(self->op); 216 | Py_DECREF(self->e); 217 | Py_TYPE(self)->tp_free((PyObject*)self); 218 | } 219 | static PyObject* SegTree_new(PyTypeObject* type, PyObject* args, PyObject* kwds){ 220 | SegTree* self; 221 | self = (SegTree*)type->tp_alloc(type, 0); 222 | return (PyObject*)self; 223 | } 224 | static inline void set_op_e(SegTree* self){ 225 | segtree_op_py = self->op; 226 | segtree_e_py = self->e; 227 | } 228 | static int SegTree_init(SegTree* self, PyObject* args){ 229 | if(Py_SIZE(args) != 3){ 230 | self->op = Py_None; // 何か入れておかないとヤバいことになる 231 | Py_INCREF(Py_None); 232 | self->e = Py_None; 233 | Py_INCREF(Py_None); 234 | PyErr_Format(PyExc_TypeError, "SegTree constructor expected 3 arguments (op, e, n), got %d", Py_SIZE(args)); 235 | return -1; 236 | } 237 | PyObject* arg; 238 | if(!PyArg_ParseTuple(args, "OOO", &self->op, &self->e, &arg)) return -1; 239 | Py_INCREF(self->op); 240 | Py_INCREF(self->e); 241 | set_op_e(self); 242 | if(PyLong_Check(arg)){ 243 | int n = (int)PyLong_AsLong(arg); 244 | if(PyErr_Occurred()) return -1; 245 | if(n < 0 || n > (int)1e8) { 246 | PyErr_Format(PyExc_ValueError, "constraint error in SegTree constructor (got %d)", n); 247 | return -1; 248 | } 249 | self->seg = new segtree(n); 250 | self->n = n; 251 | }else{ 252 | PyObject *iterator = PyObject_GetIter(arg); 253 | if(iterator==NULL) return -1; 254 | PyObject *item; 255 | vector vec; 256 | if(Py_TYPE(arg)->tp_as_sequence != NULL) vec.reserve((int)Py_SIZE(arg)); 257 | while(item = PyIter_Next(iterator)) { 258 | vec.push_back(item); 259 | } 260 | Py_DECREF(iterator); 261 | if (PyErr_Occurred()) return -1; 262 | self->seg = new segtree(vec); 263 | self->n = (int)vec.size(); 264 | } 265 | return 0; 266 | } 267 | static PyObject* SegTree_set(SegTree* self, PyObject* args){ 268 | long p; 269 | PyObject* x; 270 | PARSE_ARGS("lO", &p, &x); 271 | if(p < 0 || p >= self->n){ 272 | PyErr_Format(PyExc_IndexError, "SegTree set index out of range (size=%d, index=%d)", self->n, p); 273 | return (PyObject*)NULL; 274 | } 275 | Py_INCREF(x); 276 | set_op_e(self); 277 | self->seg->set((int)p, AutoDecrefPtr(x)); 278 | Py_RETURN_NONE; 279 | } 280 | static PyObject* SegTree_get(SegTree* self, PyObject* args){ 281 | long p; 282 | PARSE_ARGS("l", &p); 283 | if(p < 0 || p >= self->n){ 284 | PyErr_Format(PyExc_IndexError, "SegTree get index out of range (size=%d, index=%d)", self->n, p); 285 | return (PyObject*)NULL; 286 | } 287 | PyObject* res = self->seg->get((int)p).p; 288 | return Py_BuildValue("O", res); 289 | } 290 | static PyObject* SegTree_prod(SegTree* self, PyObject* args){ 291 | long l, r; 292 | PARSE_ARGS("ll", &l, &r); 293 | set_op_e(self); 294 | auto res = self->seg->prod((int)l, (int)r).p; 295 | return Py_BuildValue("O", res); 296 | } 297 | static PyObject* SegTree_all_prod(SegTree* self, PyObject* args){ 298 | PyObject* res = self->seg->all_prod().p; 299 | return Py_BuildValue("O", res); 300 | } 301 | static PyObject* SegTree_max_right(SegTree* self, PyObject* args){ 302 | long l; 303 | PARSE_ARGS("lO", &l, &segtree_f_py); 304 | if(l < 0 || l > self->n){ 305 | PyErr_Format(PyExc_IndexError, "SegTree max_right index out of range (size=%d, l=%d)", self->n, l); 306 | return (PyObject*)NULL; 307 | } 308 | set_op_e(self); 309 | int res = self->seg->max_right((int)l); 310 | return Py_BuildValue("l", res); 311 | } 312 | static PyObject* SegTree_min_left(SegTree* self, PyObject* args){ 313 | long r; 314 | PARSE_ARGS("lO", &r, &segtree_f_py); 315 | if(r < 0 || r > self->n){ 316 | PyErr_Format(PyExc_IndexError, "SegTree max_right index out of range (size=%d, r=%d)", self->n, r); 317 | return (PyObject*)NULL; 318 | } 319 | set_op_e(self); 320 | int res = self->seg->min_left((int)r); 321 | return Py_BuildValue("l", res); 322 | } 323 | static PyObject* SegTree_to_list(SegTree* self){ 324 | PyObject* list = PyList_New(self->n); 325 | for(int i=0; in; i++){ 326 | PyObject* val = self->seg->get(i).p; 327 | Py_INCREF(val); 328 | PyList_SET_ITEM(list, i, val); 329 | } 330 | return list; 331 | } 332 | static PyObject* SegTree_repr(PyObject* self){ 333 | PyObject* list = SegTree_to_list((SegTree*)self); 334 | PyObject* res = PyUnicode_FromFormat("SegTree(%R)", list); 335 | Py_ReprLeave(self); 336 | Py_DECREF(list); 337 | return res; 338 | } 339 | 340 | static PyMethodDef SegTree_methods[] = { 341 | {"set", (PyCFunction)SegTree_set, METH_VARARGS, "Set item"}, 342 | {"get", (PyCFunction)SegTree_get, METH_VARARGS, "Get item"}, 343 | {"prod", (PyCFunction)SegTree_prod, METH_VARARGS, "Get item"}, 344 | {"all_prod", (PyCFunction)SegTree_all_prod, METH_VARARGS, "Get item"}, 345 | {"max_right", (PyCFunction)SegTree_max_right, METH_VARARGS, "Binary search on segtree"}, 346 | {"min_left", (PyCFunction)SegTree_min_left, METH_VARARGS, "Binary search on segtree"}, 347 | {"to_list", (PyCFunction)SegTree_to_list, METH_VARARGS, "Convert to list"}, 348 | {NULL} /* Sentinel */ 349 | }; 350 | PyTypeObject SegTreeType = { 351 | PyObject_HEAD_INIT(NULL) 352 | "atcoder.SegTree", /*tp_name*/ 353 | sizeof(SegTree), /*tp_basicsize*/ 354 | 0, /*tp_itemsize*/ 355 | (destructor)SegTree_dealloc, /*tp_dealloc*/ 356 | 0, /*tp_print*/ 357 | 0, /*tp_getattr*/ 358 | 0, /*tp_setattr*/ 359 | 0, /*reserved*/ 360 | SegTree_repr, /*tp_repr*/ 361 | 0, /*tp_as_number*/ 362 | 0, /*tp_as_sequence*/ 363 | 0, /*tp_as_mapping*/ 364 | 0, /*tp_hash*/ 365 | 0, /*tp_call*/ 366 | 0, /*tp_str*/ 367 | 0, /*tp_getattro*/ 368 | 0, /*tp_setattro*/ 369 | 0, /*tp_as_buffer*/ 370 | Py_TPFLAGS_DEFAULT | Py_TPFLAGS_BASETYPE, /*tp_flags*/ 371 | 0, /*tp_doc*/ 372 | 0, /*tp_traverse*/ 373 | 0, /*tp_clear*/ 374 | 0, /*tp_richcompare*/ 375 | 0, /*tp_weaklistoffset*/ 376 | 0, /*tp_iter*/ 377 | 0, /*tp_iternext*/ 378 | SegTree_methods, /*tp_methods*/ 379 | 0, /*tp_members*/ 380 | 0, /*tp_getset*/ 381 | 0, /*tp_base*/ 382 | 0, /*tp_dict*/ 383 | 0, /*tp_descr_get*/ 384 | 0, /*tp_descr_set*/ 385 | 0, /*tp_dictoffset*/ 386 | (initproc)SegTree_init, /*tp_init*/ 387 | 0, /*tp_alloc*/ 388 | SegTree_new, /*tp_new*/ 389 | 0, /*tp_free*/ 390 | 0, /*tp_is_gc*/ 391 | 0, /*tp_bases*/ 392 | 0, /*tp_mro*/ 393 | 0, /*tp_cache*/ 394 | 0, /*tp_subclasses*/ 395 | 0, /*tp_weaklist*/ 396 | 0, /*tp_del*/ 397 | 0, /*tp_version_tag*/ 398 | 0, /*tp_finalize*/ 399 | }; 400 | 401 | static PyModuleDef atcodermodule = { 402 | PyModuleDef_HEAD_INIT, 403 | "atcoder", 404 | NULL, 405 | -1, 406 | }; 407 | 408 | PyMODINIT_FUNC PyInit_atcoder(void) 409 | { 410 | PyObject* m; 411 | if(PyType_Ready(&SegTreeType) < 0) return NULL; 412 | 413 | m = PyModule_Create(&atcodermodule); 414 | if(m == NULL) return NULL; 415 | 416 | Py_INCREF(&SegTreeType); 417 | if (PyModule_AddObject(m, "SegTree", (PyObject*)&SegTreeType) < 0) { 418 | Py_DECREF(&SegTreeType); 419 | Py_DECREF(m); 420 | return NULL; 421 | } 422 | 423 | return m; 424 | } 425 | """ 426 | code_setup = r""" 427 | from distutils.core import setup, Extension 428 | module = Extension( 429 | "atcoder", 430 | sources=["atcoder_library_wrapper.cpp"], 431 | extra_compile_args=["-O3", "-march=native", "-std=c++14"] 432 | ) 433 | setup( 434 | name="atcoder-library", 435 | version="0.0.1", 436 | description="wrapper for atcoder library", 437 | ext_modules=[module] 438 | ) 439 | """ 440 | 441 | import os 442 | import sys 443 | if sys.argv[-1] == "ONLINE_JUDGE" or os.getcwd() != "/imojudge/sandbox": 444 | with open("atcoder_library_wrapper.cpp", "w") as f: 445 | f.write(code_segtree) 446 | with open("setup.py", "w") as f: 447 | f.write(code_setup) 448 | os.system(f"{sys.executable} setup.py build_ext --inplace") 449 | 450 | from atcoder import SegTree 451 | 452 | -------------------------------------------------------------------------------- /cpp_extension/acl_two_sat.py: -------------------------------------------------------------------------------- 1 | # TODO: メモリリーク確認 2 | # TODO: __repr__ を書く 3 | 4 | code_two_sat = r""" 5 | #define PY_SSIZE_T_CLEAN 6 | #include 7 | #include "structmember.h" 8 | 9 | // 元のライブラリの private を剥がした 10 | 11 | // >>> AtCoder >>> 12 | 13 | #ifndef ATCODER_TWOSAT_HPP 14 | #define ATCODER_TWOSAT_HPP 1 15 | 16 | #ifndef ATCODER_INTERNAL_SCC_HPP 17 | #define ATCODER_INTERNAL_SCC_HPP 1 18 | 19 | #include 20 | #include 21 | #include 22 | 23 | namespace atcoder { 24 | namespace internal { 25 | 26 | template struct csr { 27 | std::vector start; 28 | std::vector elist; 29 | csr(int n, const std::vector>& edges) 30 | : start(n + 1), elist(edges.size()) { 31 | for (auto e : edges) { 32 | start[e.first + 1]++; 33 | } 34 | for (int i = 1; i <= n; i++) { 35 | start[i] += start[i - 1]; 36 | } 37 | auto counter = start; 38 | for (auto e : edges) { 39 | elist[counter[e.first]++] = e.second; 40 | } 41 | } 42 | }; 43 | 44 | // Reference: 45 | // R. Tarjan, 46 | // Depth-First Search and Linear Graph Algorithms 47 | struct scc_graph { 48 | public: 49 | scc_graph(int n) : _n(n) {} 50 | 51 | int num_vertices() { return _n; } 52 | 53 | void add_edge(int from, int to) { edges.push_back({from, {to}}); } 54 | 55 | // @return pair of (# of scc, scc id) 56 | std::pair> scc_ids() { 57 | auto g = csr(_n, edges); 58 | int now_ord = 0, group_num = 0; 59 | std::vector visited, low(_n), ord(_n, -1), ids(_n); 60 | visited.reserve(_n); 61 | auto dfs = [&](auto self, int v) -> void { 62 | low[v] = ord[v] = now_ord++; 63 | visited.push_back(v); 64 | for (int i = g.start[v]; i < g.start[v + 1]; i++) { 65 | auto to = g.elist[i].to; 66 | if (ord[to] == -1) { 67 | self(self, to); 68 | low[v] = std::min(low[v], low[to]); 69 | } else { 70 | low[v] = std::min(low[v], ord[to]); 71 | } 72 | } 73 | if (low[v] == ord[v]) { 74 | while (true) { 75 | int u = visited.back(); 76 | visited.pop_back(); 77 | ord[u] = _n; 78 | ids[u] = group_num; 79 | if (u == v) break; 80 | } 81 | group_num++; 82 | } 83 | }; 84 | for (int i = 0; i < _n; i++) { 85 | if (ord[i] == -1) dfs(dfs, i); 86 | } 87 | for (auto& x : ids) { 88 | x = group_num - 1 - x; 89 | } 90 | return {group_num, ids}; 91 | } 92 | 93 | std::vector> scc() { 94 | auto ids = scc_ids(); 95 | int group_num = ids.first; 96 | std::vector counts(group_num); 97 | for (auto x : ids.second) counts[x]++; 98 | std::vector> groups(ids.first); 99 | for (int i = 0; i < group_num; i++) { 100 | groups[i].reserve(counts[i]); 101 | } 102 | for (int i = 0; i < _n; i++) { 103 | groups[ids.second[i]].push_back(i); 104 | } 105 | return groups; 106 | } 107 | 108 | private: 109 | int _n; 110 | struct edge { 111 | int to; 112 | }; 113 | std::vector> edges; 114 | }; 115 | 116 | } // namespace internal 117 | 118 | } // namespace atcoder 119 | 120 | #endif // ATCODER_INTERNAL_SCC_HPP 121 | 122 | #include 123 | #include 124 | 125 | namespace atcoder { 126 | 127 | // Reference: 128 | // B. Aspvall, M. Plass, and R. Tarjan, 129 | // A Linear-Time Algorithm for Testing the Truth of Certain Quantified Boolean 130 | // Formulas 131 | struct two_sat { 132 | public: 133 | two_sat() : _n(0), scc(0) {} 134 | two_sat(int n) : _n(n), _answer(n), scc(2 * n) {} 135 | 136 | void add_clause(int i, bool f, int j, bool g) { 137 | assert(0 <= i && i < _n); 138 | assert(0 <= j && j < _n); 139 | scc.add_edge(2 * i + (f ? 0 : 1), 2 * j + (g ? 1 : 0)); 140 | scc.add_edge(2 * j + (g ? 0 : 1), 2 * i + (f ? 1 : 0)); 141 | } 142 | bool satisfiable() { 143 | auto id = scc.scc_ids().second; 144 | for (int i = 0; i < _n; i++) { 145 | if (id[2 * i] == id[2 * i + 1]) return false; 146 | _answer[i] = id[2 * i] < id[2 * i + 1]; 147 | } 148 | return true; 149 | } 150 | std::vector answer() { return _answer; } 151 | 152 | // private: 153 | int _n; 154 | std::vector _answer; 155 | internal::scc_graph scc; 156 | }; 157 | 158 | } // namespace atcoder 159 | 160 | #endif // ATCODER_TWOSAT_HPP 161 | 162 | // <<< AtCoder <<< 163 | 164 | using namespace std; 165 | using namespace atcoder; 166 | #define PARSE_ARGS(types, ...) if(!PyArg_ParseTuple(args, types, __VA_ARGS__)) return NULL 167 | 168 | 169 | struct TwoSAT{ 170 | PyObject_HEAD 171 | two_sat* ts; 172 | }; 173 | 174 | 175 | extern PyTypeObject TwoSATType; 176 | 177 | 178 | // >>> TwoSAT definition >>> 179 | 180 | static void TwoSAT_dealloc(TwoSAT* self){ 181 | delete self->ts; 182 | Py_TYPE(self)->tp_free((PyObject*)self); 183 | } 184 | static PyObject* TwoSAT_new(PyTypeObject* type, PyObject* args, PyObject* kwds){ 185 | return type->tp_alloc(type, 0); 186 | } 187 | static int TwoSAT_init(TwoSAT* self, PyObject* args){ 188 | long n; 189 | if(!PyArg_ParseTuple(args, "l", &n)) return -1; 190 | if(n < 0 || n > (long)1e8){ 191 | PyErr_Format(PyExc_IndexError, 192 | "TwoSAT constructor constraint error (constraint: 0<=n<=1e8, got n=%d)", n); 193 | return -1; 194 | } 195 | self->ts = new two_sat(n); 196 | return 0; 197 | } 198 | static PyObject* TwoSAT_add_clause(TwoSAT* self, PyObject* args){ 199 | long i, j; 200 | int f, g; 201 | PARSE_ARGS("lplp", &i, &f, &j, &g); 202 | if(i < 0 || i >= self->ts->_n || j < 0 || j >= self->ts->_n){ 203 | PyErr_Format(PyExc_IndexError, 204 | "TwoSAT add_clause index out of range (n=%d, i=%d, j=%d)", self->ts->_n, i, j); 205 | return (PyObject*)NULL; 206 | } 207 | self->ts->add_clause(i, (bool)f, j, (bool)g); 208 | Py_RETURN_NONE; 209 | } 210 | static PyObject* TwoSAT_satisfiable(TwoSAT* self, PyObject* args){ 211 | PyObject* res = self->ts->satisfiable() ? Py_True : Py_False; 212 | return Py_BuildValue("O", res); 213 | } 214 | static PyObject* TwoSAT_answer(TwoSAT* self, PyObject* args){ 215 | const vector& answer = self->ts->answer(); 216 | const int& n = self->ts->_n; 217 | PyObject* list = PyList_New(n); 218 | for(int i = 0; i < n; i++){ 219 | PyList_SET_ITEM(list, i, Py_BuildValue("O", answer[i] ? Py_True : Py_False)); 220 | } 221 | return list; 222 | } 223 | /* 224 | static PyObject* TwoSAT_repr(PyObject* self){ 225 | PyObject* res = PyUnicode_FromFormat("TwoSAT()"); 226 | return res; 227 | } 228 | */ 229 | static PyMethodDef TwoSAT_methods[] = { 230 | {"add_clause", (PyCFunction)TwoSAT_add_clause, METH_VARARGS, "Add clause"}, 231 | {"satisfiable", (PyCFunction)TwoSAT_satisfiable, METH_VARARGS, "Check if problem satisfiable"}, 232 | {"answer", (PyCFunction)TwoSAT_answer, METH_VARARGS, "Get answer"}, 233 | {NULL} /* Sentinel */ 234 | }; 235 | PyTypeObject TwoSATType = { 236 | PyObject_HEAD_INIT(NULL) 237 | "acl_twosat.TwoSAT", /*tp_name*/ 238 | sizeof(TwoSAT), /*tp_basicsize*/ 239 | 0, /*tp_itemsize*/ 240 | (destructor)TwoSAT_dealloc, /*tp_dealloc*/ 241 | 0, /*tp_print*/ 242 | 0, /*tp_getattr*/ 243 | 0, /*tp_setattr*/ 244 | 0, /*reserved*/ 245 | 0,//TwoSAT_repr, /*tp_repr*/ 246 | 0, /*tp_as_number*/ 247 | 0, /*tp_as_sequence*/ 248 | 0, /*tp_as_mapping*/ 249 | 0, /*tp_hash*/ 250 | 0, /*tp_call*/ 251 | 0, /*tp_str*/ 252 | 0, /*tp_getattro*/ 253 | 0, /*tp_setattro*/ 254 | 0, /*tp_as_buffer*/ 255 | Py_TPFLAGS_DEFAULT | Py_TPFLAGS_BASETYPE, /*tp_flags*/ 256 | 0, /*tp_doc*/ 257 | 0, /*tp_traverse*/ 258 | 0, /*tp_clear*/ 259 | 0, /*tp_richcompare*/ 260 | 0, /*tp_weaklistoffset*/ 261 | 0, /*tp_iter*/ 262 | 0, /*tp_iternext*/ 263 | TwoSAT_methods, /*tp_methods*/ 264 | 0, /*tp_members*/ 265 | 0, /*tp_getset*/ 266 | 0, /*tp_base*/ 267 | 0, /*tp_dict*/ 268 | 0, /*tp_descr_get*/ 269 | 0, /*tp_descr_set*/ 270 | 0, /*tp_dictoffset*/ 271 | (initproc)TwoSAT_init, /*tp_init*/ 272 | 0, /*tp_alloc*/ 273 | TwoSAT_new, /*tp_new*/ 274 | 0, /*tp_free*/ 275 | 0, /*tp_is_gc*/ 276 | 0, /*tp_bases*/ 277 | 0, /*tp_mro*/ 278 | 0, /*tp_cache*/ 279 | 0, /*tp_subclasses*/ 280 | 0, /*tp_weaklist*/ 281 | 0, /*tp_del*/ 282 | 0, /*tp_version_tag*/ 283 | 0, /*tp_finalize*/ 284 | }; 285 | 286 | // <<< TwoSAT definition <<< 287 | 288 | 289 | static PyModuleDef acl_twosatmodule = { 290 | PyModuleDef_HEAD_INIT, 291 | "acl_twosat", 292 | NULL, 293 | -1, 294 | }; 295 | 296 | PyMODINIT_FUNC PyInit_acl_twosat(void) 297 | { 298 | PyObject* m; 299 | if(PyType_Ready(&TwoSATType) < 0) return NULL; 300 | 301 | m = PyModule_Create(&acl_twosatmodule); 302 | if(m == NULL) return NULL; 303 | 304 | Py_INCREF(&TwoSATType); 305 | if (PyModule_AddObject(m, "TwoSAT", (PyObject*)&TwoSATType) < 0) { 306 | Py_DECREF(&TwoSATType); 307 | Py_DECREF(m); 308 | return NULL; 309 | } 310 | 311 | return m; 312 | } 313 | """ 314 | code_two_sat_setup = r""" 315 | from distutils.core import setup, Extension 316 | module = Extension( 317 | "acl_twosat", 318 | sources=["two_sat.cpp"], 319 | extra_compile_args=["-O3", "-march=native", "-std=c++14"] 320 | ) 321 | setup( 322 | name="acl_twosat", 323 | version="0.0.1", 324 | description="wrapper for atcoder library twosat", 325 | ext_modules=[module] 326 | ) 327 | """ 328 | 329 | import os 330 | import sys 331 | 332 | if sys.argv[-1] == "ONLINE_JUDGE" or os.getcwd() != "/imojudge/sandbox": 333 | with open("two_sat.cpp", "w") as f: 334 | f.write(code_two_sat) 335 | with open("two_sat_setup.py", "w") as f: 336 | f.write(code_two_sat_setup) 337 | os.system(f"{sys.executable} two_sat_setup.py build_ext --inplace") 338 | 339 | from acl_twosat import TwoSAT 340 | -------------------------------------------------------------------------------- /cpp_extension/ctypes_cpp_set.py: -------------------------------------------------------------------------------- 1 | # 拡張モジュールの方が速いっぽい? 2 | 3 | cppset_code = r""" // 参考: https://atcoder.jp/contests/abc128/submissions/5808742 4 | #include 5 | //#undef __GNUC__ 6 | #ifdef __GNUC__ 7 | #include 8 | #include 9 | using namespace std; 10 | using namespace __gnu_pbds; 11 | using pb_set = tree< 12 | long long, 13 | null_type, 14 | less, 15 | rb_tree_tag, 16 | tree_order_statistics_node_update 17 | >; 18 | #else 19 | #include 20 | using namespace std; 21 | using pb_set = set; 22 | #endif 23 | 24 | extern "C" { 25 | 26 | void* set_construct(){ 27 | return (void*)(new pb_set); 28 | } 29 | 30 | void set_destruct(void* st){ 31 | delete (pb_set*)st; 32 | } 33 | 34 | bool set_add(void* st, long long x){ 35 | return ((pb_set*)st)->insert(x).second; 36 | } 37 | 38 | void set_remove(void* st, long long x){ 39 | auto it = ((pb_set*)st)->find(x); 40 | if(it == ((pb_set*)st)->end()){ 41 | //fprintf(stderr, "cppset remove: KeyError\n"); 42 | return; 43 | } 44 | ((pb_set*)st)->erase(it); 45 | } 46 | 47 | long long set_search_higher_equal(void* st, long long x){ 48 | return *((pb_set*)st)->lower_bound(x); 49 | } 50 | 51 | long long set_min(void* st){ 52 | if(((pb_set*)st)->size()==0){ 53 | //fprintf(stderr, "min from an empty set"); 54 | return -1; 55 | } 56 | return *((pb_set*)st)->begin(); 57 | } 58 | 59 | long long set_max(void* st){ 60 | if(((pb_set*)st)->size()==0){ 61 | //fprintf(stderr, "max from an empty set"); 62 | return -1; 63 | } 64 | return *prev(((pb_set*)st)->end()); 65 | } 66 | 67 | long long set_pop_min(void* st){ 68 | if(((pb_set*)st)->size()==0){ 69 | //fprintf(stderr, "pop_min from an empty set"); 70 | return -1; 71 | } 72 | auto it = ((pb_set*)st)->begin(); 73 | long long res = *it; 74 | ((pb_set*)st)->erase(it); 75 | return res; 76 | } 77 | 78 | long long set_pop_max(void* st){ 79 | if(((pb_set*)st)->size()==0){ 80 | //fprintf(stderr, "pop_max from an empty set"); 81 | return -1; 82 | } 83 | auto it = prev(((pb_set*)st)->end()); 84 | long long res = *it; 85 | ((pb_set*)st)->erase(it); 86 | return res; 87 | } 88 | 89 | long long set_len(void* st){ 90 | return ((pb_set*)st)->size(); 91 | } 92 | 93 | bool set_contains(void* st, long long x){ 94 | return ((pb_set*)st)->find(x) != ((pb_set*)st)->end(); 95 | } 96 | 97 | long long set_getitem(void* st_, long long idx){ 98 | pb_set* st = (pb_set*)st_; 99 | long long idx_pos = idx >= 0 ? idx : idx + (long long)st->size(); 100 | if(idx_pos >= (long long)st->size() || idx_pos < 0){ 101 | //fprintf(stderr, "cppset getitem: index out of range\n"); 102 | return -1; 103 | } 104 | #ifdef __GNUC__ 105 | auto it = st->find_by_order(idx_pos); 106 | #else 107 | auto it = st->begin(); 108 | for(int i=0; i= 0 ? idx : idx + (long long)st->size(); 116 | if(idx_pos >= (long long)st->size() || idx_pos < 0){ 117 | //fprintf(stderr, "cppset pop: index out of range\n"); 118 | return -1; 119 | } 120 | #ifdef __GNUC__ 121 | auto it = st->find_by_order(idx_pos); 122 | #else 123 | auto it = st->begin(); 124 | for(int i=0; ierase(it); 128 | return res; 129 | } 130 | 131 | long long set_index(void* st, long long x){ 132 | #ifdef __GNUC__ 133 | return ((pb_set*)st)->order_of_key(x); 134 | #else 135 | long long res = 0; 136 | auto it = ((pb_set*)st)->begin(); 137 | while(it != ((pb_set*)st)->end() && *it < x) it++, res++; 138 | return res; 139 | #endif 140 | } 141 | 142 | 143 | } // extern "C" 144 | """ 145 | 146 | import os 147 | import sys 148 | from functools import partial 149 | import distutils.ccompiler 150 | from ctypes import CDLL, CFUNCTYPE, c_bool, c_longlong, c_void_p 151 | 152 | if sys.argv[-1] == "ONLINE_JUDGE" or os.getcwd() != "/imojudge/sandbox": 153 | with open("cppset.cpp", "w") as f: 154 | f.write(cppset_code) 155 | if os.name == "nt": 156 | os.system(r'"C:\Program Files\mingw-w64\x86_64-8.1.0-posix-seh-rt_v6-rev0\mingw64\bin\g++" -fPIC -shared -std=c++14 -O3 cppset.cpp -o cppset.dll') 157 | # compiler = distutils.ccompiler.new_compiler() 158 | # compiler.compile(["cppset.cpp"], extra_postargs=["/LD"]) 159 | # link_args = ["/DLL"] 160 | # compiler.link_shared_lib(["cppset.obj"], "cppset", extra_postargs=link_args) 161 | else: 162 | os.system(f"g++ -fPIC -shared -std=c++14 -O3 cppset.cpp -o cppset.so") 163 | if os.name == "nt": 164 | lib = CDLL(f"{os.getcwd()}/cppset.dll") 165 | else: 166 | lib = CDLL(f"{os.getcwd()}/cppset.so") 167 | 168 | 169 | class CppSetInt: 170 | def __init__(self): 171 | self.ptr = CFUNCTYPE(c_void_p)(("set_construct", lib))() 172 | self.add = partial(CFUNCTYPE(c_bool, c_void_p, c_longlong)(("set_add", lib)), self.ptr) 173 | self.remove = partial(CFUNCTYPE(None, c_void_p, c_longlong)(("set_remove", lib)), self.ptr) 174 | self.search_higher_equal = partial( 175 | CFUNCTYPE(c_longlong, c_void_p, c_longlong)(("set_search_higher_equal", lib)), self.ptr) 176 | self.min = partial(CFUNCTYPE(c_longlong, c_void_p)(("set_min", lib)), self.ptr) 177 | self.max = partial(CFUNCTYPE(c_longlong, c_void_p)(("set_max", lib)), self.ptr) 178 | self.pop_min = partial(CFUNCTYPE(c_longlong, c_void_p)(("set_pop_min", lib)), self.ptr) 179 | self.pop_max = partial(CFUNCTYPE(c_longlong, c_void_p)(("set_pop_max", lib)), self.ptr) 180 | self.__len__ = partial(CFUNCTYPE(c_longlong, c_void_p)(("set_len", lib)), self.ptr) 181 | self.contains = partial(CFUNCTYPE(c_bool, c_void_p, c_longlong)(("set_contains", lib)), self.ptr) 182 | self.__getitem__ = partial(CFUNCTYPE(c_longlong, c_void_p, c_longlong)(("set_getitem", lib)), self.ptr) 183 | self.pop = partial(CFUNCTYPE(c_longlong, c_void_p, c_longlong)(("set_pop", lib)), self.ptr) 184 | self.index = partial(CFUNCTYPE(c_longlong, c_void_p, c_longlong)(("set_index", lib)), self.ptr) 185 | 186 | def __len__(self): 187 | return self.__len__() 188 | 189 | def __contains__(self): 190 | return self.__contains__() 191 | 192 | -------------------------------------------------------------------------------- /cpp_extension/shorten.py: -------------------------------------------------------------------------------- 1 | # C++ のコードを 1 行にする 2 | 3 | import re 4 | def comment_remover(text): 5 | # from https://stackoverflow.com/questions/241327/remove-c-and-c-comments-using-python 6 | def replacer(match): 7 | s = match.group(0) 8 | if s.startswith('/'): 9 | return " " # note: a space and not an empty string 10 | else: 11 | return s 12 | pattern = re.compile( 13 | r'//.*?$|/\*.*?\*/|\'(?:\\.|[^\\\'])*\'|"(?:\\.|[^\\"])*"', 14 | re.DOTALL | re.MULTILINE 15 | ) 16 | return re.sub(pattern, replacer, text) 17 | 18 | def shorten(text): 19 | text = comment_remover(text) 20 | return re.sub(r" +", " ", text).replace("\n ", "\n").replace("\n", r"\n") 21 | -------------------------------------------------------------------------------- /cpp_extension/template.py: -------------------------------------------------------------------------------- 1 | code_cppset = r""" 2 | #define PY_SSIZE_T_CLEAN 3 | #include 4 | #include "structmember.h" 5 | 6 | typedef struct { 7 | PyObject_HEAD 8 | PyObject *first; /* first name */ 9 | PyObject *last; /* last name */ 10 | int number; 11 | } CustomObject; 12 | 13 | static void 14 | Custom_dealloc(CustomObject *self) 15 | { 16 | Py_XDECREF(self->first); 17 | Py_XDECREF(self->last); 18 | Py_TYPE(self)->tp_free((PyObject *) self); 19 | } 20 | 21 | static PyObject * 22 | Custom_new(PyTypeObject *type, PyObject *args, PyObject *kwds) 23 | { 24 | CustomObject *self; 25 | self = (CustomObject *) type->tp_alloc(type, 0); 26 | if (self != NULL) { 27 | self->first = PyUnicode_FromString(""); 28 | if (self->first == NULL) { 29 | Py_DECREF(self); 30 | return NULL; 31 | } 32 | self->last = PyUnicode_FromString(""); 33 | if (self->last == NULL) { 34 | Py_DECREF(self); 35 | return NULL; 36 | } 37 | self->number = 0; 38 | } 39 | return (PyObject *) self; 40 | } 41 | 42 | static int 43 | Custom_init(CustomObject *self, PyObject *args, PyObject *kwds) 44 | { 45 | static char *kwlist[] = {"first", "last", "number", NULL}; 46 | PyObject *first = NULL, *last = NULL, *tmp; 47 | 48 | if (!PyArg_ParseTupleAndKeywords(args, kwds, "|OOi", kwlist, 49 | &first, &last, 50 | &self->number)) 51 | return -1; 52 | 53 | if (first) { 54 | tmp = self->first; 55 | Py_INCREF(first); 56 | self->first = first; 57 | Py_XDECREF(tmp); 58 | } 59 | if (last) { 60 | tmp = self->last; 61 | Py_INCREF(last); 62 | self->last = last; 63 | Py_XDECREF(tmp); 64 | } 65 | return 0; 66 | } 67 | 68 | static PyMemberDef Custom_members[] = { 69 | {"first", T_OBJECT_EX, offsetof(CustomObject, first), 0, 70 | "first name"}, 71 | {"last", T_OBJECT_EX, offsetof(CustomObject, last), 0, 72 | "last name"}, 73 | {"number", T_INT, offsetof(CustomObject, number), 0, 74 | "custom number"}, 75 | {NULL} /* Sentinel */ 76 | }; 77 | 78 | static PyObject * 79 | Custom_name(CustomObject *self, PyObject *Py_UNUSED(ignored)) 80 | { 81 | if (self->first == NULL) { 82 | PyErr_SetString(PyExc_AttributeError, "first"); 83 | return NULL; 84 | } 85 | if (self->last == NULL) { 86 | PyErr_SetString(PyExc_AttributeError, "last"); 87 | return NULL; 88 | } 89 | return PyUnicode_FromFormat("%S %S", self->first, self->last); 90 | } 91 | 92 | static PyMethodDef Custom_methods[] = { 93 | {"name", (PyCFunction) Custom_name, METH_NOARGS, 94 | "Return the name, combining the first and last name" 95 | }, 96 | {NULL} /* Sentinel */ 97 | }; 98 | 99 | static PyTypeObject CustomType = { 100 | PyVarObject_HEAD_INIT(NULL, 0) 101 | "custom2.Custom", /*tp_name*/ 102 | sizeof(CustomObject), /*tp_basicsize*/ 103 | 0, /*tp_itemsize*/ 104 | (destructor) Custom_dealloc, /*tp_dealloc*/ 105 | 0, /*tp_print*/ 106 | 0, /*tp_getattr*/ 107 | 0, /*tp_setattr*/ 108 | 0, /*reserved*/ 109 | 0, /*tp_repr*/ 110 | 0, /*tp_as_number*/ 111 | 0, /*tp_as_sequence*/ 112 | 0, /*tp_as_mapping*/ 113 | 0, /*tp_hash*/ 114 | 0, /*tp_call*/ 115 | 0, /*tp_str*/ 116 | 0, /*tp_getattro*/ 117 | 0, /*tp_setattro*/ 118 | 0, /*tp_as_buffer*/ 119 | Py_TPFLAGS_DEFAULT | Py_TPFLAGS_BASETYPE, /*tp_flags*/ 120 | 0, /*tp_doc*/ 121 | 0, /*tp_traverse*/ 122 | 0, /*tp_clear*/ 123 | 0, /*tp_richcompare*/ 124 | 0, /*tp_weaklistoffset*/ 125 | 0, /*tp_iter*/ 126 | 0, /*tp_iternext*/ 127 | Custom_methods, /*tp_methods*/ 128 | Custom_members, /*tp_members*/ 129 | 0, /*tp_getset*/ 130 | 0, /*tp_base*/ 131 | 0, /*tp_dict*/ 132 | 0, /*tp_descr_get*/ 133 | 0, /*tp_descr_set*/ 134 | 0, /*tp_dictoffset*/ 135 | (initproc) Custom_init, /*tp_init*/ 136 | 0, /*tp_alloc*/ 137 | Custom_new, /*tp_new*/ 138 | 0, /*tp_free*/ 139 | 0, /*tp_is_gc*/ 140 | 0, /*tp_bases*/ 141 | 0, /*tp_mro*/ 142 | 0, /*tp_cache*/ 143 | 0, /*tp_subclasses*/ 144 | 0, /*tp_weaklist*/ 145 | 0, /*tp_del*/ 146 | 0, /*tp_version_tag*/ 147 | 0, /*tp_finalize*/ 148 | }; 149 | 150 | static PyModuleDef custommodule = { 151 | PyModuleDef_HEAD_INIT, 152 | "custom2", 153 | NULL, 154 | -1, 155 | }; 156 | 157 | PyMODINIT_FUNC 158 | PyInit_custom2(void) 159 | { 160 | PyObject *m; 161 | if (PyType_Ready(&CustomType) < 0) 162 | return NULL; 163 | 164 | m = PyModule_Create(&custommodule); 165 | if (m == NULL) 166 | return NULL; 167 | 168 | Py_INCREF(&CustomType); 169 | if (PyModule_AddObject(m, "Custom", (PyObject *) &CustomType) < 0) { 170 | Py_DECREF(&CustomType); 171 | Py_DECREF(m); 172 | return NULL; 173 | } 174 | 175 | return m; 176 | } 177 | """ 178 | code_setup = r""" 179 | from distutils.core import setup, Extension 180 | module = Extension( 181 | "custom2", 182 | sources=["set_wrapper.cpp"], 183 | extra_compile_args=["-O3", "-march=native", "-std=c++14"] 184 | ) 185 | setup( 186 | name="SetMethod", 187 | version="0.1.2", 188 | description="wrapper for C++ set", 189 | ext_modules=[module] 190 | ) 191 | """ 192 | 193 | import os 194 | import sys 195 | if sys.argv[-1] == "ONLINE_JUDGE" or os.getcwd() != "/imojudge/sandbox": 196 | with open("set_wrapper.cpp", "w") as f: 197 | f.write(code_cppset) 198 | with open("setup.py", "w") as f: 199 | f.write(code_setup) 200 | os.system(f"{sys.executable} setup.py build_ext --inplace") 201 | 202 | 203 | import custom2 204 | 205 | 206 | print(custom2.__dir__()) 207 | print(custom2.Custom().__dir__()) 208 | a = custom2.Custom() 209 | b = custom2.Custom() 210 | a.first = "kae" 211 | a.last = "hiiragi" 212 | b.first = "kasumi" 213 | b.last = "honjo" 214 | print(a.name()) 215 | print(b.name()) 216 | -------------------------------------------------------------------------------- /cpp_extension/wrap_cpp_multiset.py: -------------------------------------------------------------------------------- 1 | # 検証: https://atcoder.jp/contests/cpsco2019-s1/submissions/14705333 2 | 3 | import os 4 | import sys 5 | python_version = f"{sys.version_info.major}.{sys.version_info.minor}" 6 | 7 | if sys.argv[-1] == "ONLINE_JUDGE": 8 | code_multiset = r""" 9 | #include 10 | #include 11 | #include 12 | using namespace std; 13 | 14 | // 参考1: http://www.speedupcode.com/c-class-in-python3/ 15 | // 参考2: https://qiita.com/junkoda/items/2b1eda7569186809ca14 16 | 17 | const static auto comp_pyobj = [](PyObject* const & lhs, PyObject* const & rhs){ 18 | return Py_TYPE(lhs)->tp_richcompare(lhs, rhs, Py_LT) == Py_True; 19 | }; 20 | 21 | struct MultiSet4PyObject{ 22 | multiset st; 23 | MultiSet4PyObject() : st(comp_pyobj) {} 24 | MultiSet4PyObject(vector& vec) : st(vec.begin(), vec.end(), comp_pyobj) {} 25 | ~MultiSet4PyObject(){ 26 | for(PyObject* const & p : st){ 27 | Py_DECREF(p); 28 | } 29 | } 30 | void add(PyObject* x){ 31 | Py_INCREF(x); 32 | st.insert(x); 33 | } 34 | void remove(PyObject* x){ 35 | auto it = st.find(x); 36 | st.erase(it); 37 | Py_DECREF(*it); 38 | } 39 | PyObject* search_higher_equal(PyObject* x) const { 40 | return *st.lower_bound(x); 41 | } 42 | }; 43 | 44 | void MultiSet4PyObject_free(PyObject *obj){ 45 | MultiSet4PyObject* const p = (MultiSet4PyObject*) PyCapsule_GetPointer(obj, "MultiSet4PyObjectPtr"); 46 | delete p; 47 | } 48 | PyObject* MultiSet4PyObject_construct(PyObject* self, PyObject* args){ 49 | MultiSet4PyObject* ms = new MultiSet4PyObject(); 50 | return PyCapsule_New((void*)ms, "MultiSet4PyObjectPtr", MultiSet4PyObject_free); 51 | } 52 | PyObject* MultiSet4PyObject_construct_from_list(PyObject* self, PyObject* args){ 53 | PyObject* lst; 54 | if(!PyArg_ParseTuple(args, "O", &lst)) return NULL; 55 | int siz; 56 | if(PyList_Check(lst)) siz = PyList_GET_SIZE(lst); 57 | else if(PyTuple_Check(lst)) siz = PyTuple_GET_SIZE(lst); 58 | else return NULL; 59 | vector vec(siz); 60 | for(int i=0; iadd(x); 72 | return Py_BuildValue(""); 73 | } 74 | PyObject* MultiSet4PyObject_remove(PyObject* self, PyObject* args){ 75 | PyObject *msCapsule, *x; 76 | if(!PyArg_ParseTuple(args, "OO", &msCapsule, &x)) return NULL; 77 | MultiSet4PyObject* ms = (MultiSet4PyObject*)PyCapsule_GetPointer(msCapsule, "MultiSet4PyObjectPtr"); 78 | ms->remove(x); 79 | return Py_BuildValue(""); 80 | } 81 | PyObject* MultiSet4PyObject_search_higher_equal(PyObject* self, PyObject* args){ 82 | PyObject *msCapsule, *x; 83 | if(!PyArg_ParseTuple(args, "OO", &msCapsule, &x)) return NULL; 84 | MultiSet4PyObject* ms = (MultiSet4PyObject*)PyCapsule_GetPointer(msCapsule, "MultiSet4PyObjectPtr"); 85 | PyObject* res = ms->search_higher_equal(x); 86 | return Py_BuildValue("O", res); 87 | } 88 | 89 | static PyMethodDef MultiSetMethods[] = { 90 | {"construct", MultiSet4PyObject_construct, METH_VARARGS, "Create multiset object"}, 91 | {"construct_from_list", MultiSet4PyObject_construct_from_list, METH_VARARGS, "Create multiset object from list"}, 92 | {"add", MultiSet4PyObject_add, METH_VARARGS, "Add item"}, 93 | {"remove", MultiSet4PyObject_remove, METH_VARARGS, "Remove item"}, 94 | {"search_higher_equal", MultiSet4PyObject_search_higher_equal, METH_VARARGS, "Search item"}, 95 | {NULL, NULL, 0, NULL} 96 | }; 97 | 98 | static struct PyModuleDef multisetmodule = { 99 | PyModuleDef_HEAD_INIT, 100 | "multiset", 101 | NULL, 102 | -1, 103 | MultiSetMethods, 104 | }; 105 | 106 | PyMODINIT_FUNC PyInit_multiset(void){ 107 | return PyModule_Create(&multisetmodule); 108 | } 109 | """ 110 | code_setup = r""" 111 | from distutils.core import setup, Extension 112 | module = Extension( 113 | "multiset", 114 | sources=["multiset_wrapper.cpp"], 115 | extra_compile_args=["-O3", "-march=native"] 116 | ) 117 | setup( 118 | name="MultiSetMethod", 119 | version="0.0.4", 120 | description="wrapper for C++ multiset", 121 | ext_modules=[module] 122 | ) 123 | """ 124 | with open("multiset_wrapper.cpp", "w") as f: 125 | f.write(code_multiset) 126 | with open("setup.py", "w") as f: 127 | f.write(code_setup) 128 | os.system(f"python{python_version} setup.py build_ext --inplace") 129 | exit() 130 | 131 | 132 | import multiset 133 | 134 | -------------------------------------------------------------------------------- /cpp_extension/wrap_cpp_set.py: -------------------------------------------------------------------------------- 1 | code_cppset = r""" 2 | #define PY_SSIZE_T_CLEAN 3 | #include 4 | #include "structmember.h" 5 | #include 6 | //#undef __GNUC__ // g++ 拡張を使わない場合はここのコメントアウトを外すと高速になる 7 | #ifdef __GNUC__ 8 | #include 9 | #include 10 | using namespace std; 11 | using namespace __gnu_pbds; 12 | const static auto comp_pyobj = [](PyObject* const & lhs, PyObject* const & rhs){ 13 | return (bool)PyObject_RichCompareBool(lhs, rhs, Py_LT); // 比較できない場合 -1 14 | }; 15 | using pb_set = tree< 16 | PyObject*, 17 | null_type, 18 | decltype(comp_pyobj), 19 | rb_tree_tag, 20 | tree_order_statistics_node_update 21 | >; 22 | #else 23 | #include 24 | using namespace std; 25 | const static auto comp_pyobj = [](PyObject* const & lhs, PyObject* const & rhs){ 26 | return (bool)PyObject_RichCompareBool(lhs, rhs, Py_LT); 27 | }; 28 | using pb_set = set; 29 | #endif 30 | #define PARSE_ARGS(types, ...) if(!PyArg_ParseTuple(args, types, __VA_ARGS__)) return NULL 31 | struct Set4PyObject{ 32 | pb_set st; 33 | pb_set::iterator it; 34 | Set4PyObject() : st(comp_pyobj), it(st.begin()) {} 35 | Set4PyObject(vector& vec) : st(vec.begin(), vec.end(), comp_pyobj), it(st.begin()) {} 36 | Set4PyObject(const Set4PyObject& obj) : st(obj.st), it(st.begin()) { 37 | for(PyObject* const & p : st) Py_INCREF(p); 38 | } 39 | ~Set4PyObject(){ 40 | for(PyObject* const & p : st) Py_DECREF(p); 41 | } 42 | bool add(PyObject* x){ 43 | const auto& r = st.insert(x); 44 | it = r.first; 45 | if(r.second){ 46 | Py_INCREF(x); 47 | return true; 48 | }else{ 49 | return false; 50 | } 51 | } 52 | PyObject* remove(PyObject* x){ 53 | it = st.find(x); 54 | if(it == st.end()) return PyErr_SetObject(PyExc_KeyError, x), (PyObject*)NULL; 55 | Py_DECREF(*it); 56 | it = st.erase(it); 57 | if(it == st.end()) return Py_None; 58 | return *it; 59 | } 60 | PyObject* search_higher_equal(PyObject* x){ 61 | it = st.lower_bound(x); 62 | if(it == st.end()) return Py_None; 63 | return *it; 64 | } 65 | PyObject* min(){ 66 | if(st.size()==0) 67 | return PyErr_SetString(PyExc_IndexError, "min from an empty set"), (PyObject*)NULL; 68 | it = st.begin(); 69 | return *it; 70 | } 71 | PyObject* max(){ 72 | if(st.size()==0) 73 | return PyErr_SetString(PyExc_IndexError, "max from an empty set"), (PyObject*)NULL; 74 | it = prev(st.end()); 75 | return *it; 76 | } 77 | PyObject* pop_min(){ 78 | if(st.size()==0) 79 | return PyErr_SetString(PyExc_IndexError, "pop_min from an empty set"), (PyObject*)NULL; 80 | it = st.begin(); 81 | PyObject* res = *it; 82 | it = st.erase(it); 83 | return res; 84 | } 85 | PyObject* pop_max(){ 86 | if(st.size()==0) 87 | return PyErr_SetString(PyExc_IndexError, "pop_max from an empty set"), (PyObject*)NULL; 88 | it = prev(st.end()); 89 | PyObject* res = *it; 90 | it = st.erase(it); 91 | return res; 92 | } 93 | size_t len() const { 94 | return st.size(); 95 | } 96 | PyObject* iter_next(){ 97 | if(it == st.end()) return Py_None; 98 | if(++it == st.end()) return Py_None; 99 | return *it; 100 | } 101 | PyObject* iter_prev(){ 102 | if(it == st.begin()) return Py_None; 103 | return *--it; 104 | } 105 | PyObject* to_list() const { 106 | PyObject* list = PyList_New(st.size()); 107 | int i = 0; 108 | for(PyObject* const & p : st){ 109 | Py_INCREF(p); 110 | PyList_SET_ITEM(list, i++, p); 111 | } 112 | return list; 113 | } 114 | PyObject* get() const { 115 | if(it == st.end()) return Py_None; 116 | return *it; 117 | } 118 | PyObject* erase(){ 119 | if(it == st.end()) return PyErr_SetString(PyExc_KeyError, "erase end"), (PyObject*)NULL; 120 | it = st.erase(it); 121 | if(it == st.end()) return Py_None; 122 | return *it; 123 | } 124 | PyObject* getitem(const long& idx){ 125 | long idx_pos = idx >= 0 ? idx : idx + (long)st.size(); 126 | if(idx_pos >= (long)st.size() || idx_pos < 0) 127 | return PyErr_Format( 128 | PyExc_IndexError, 129 | "cppset getitem index out of range (size=%d, idx=%d)", st.size(), idx 130 | ), (PyObject*)NULL; 131 | #ifdef __GNUC__ 132 | it = st.find_by_order(idx_pos); 133 | #else 134 | it = st.begin(); 135 | for(int i=0; i= 0 ? idx : idx + (long)st.size(); 141 | if(idx_pos >= (long)st.size() || idx_pos < 0) 142 | return PyErr_Format( 143 | PyExc_IndexError, 144 | "cppset pop index out of range (size=%d, idx=%d)", st.size(), idx 145 | ), (PyObject*)NULL; 146 | #ifdef __GNUC__ 147 | it = st.find_by_order(idx_pos); 148 | #else 149 | it = st.begin(); 150 | for(int i=0; ist; 177 | Py_TYPE(self)->tp_free((PyObject*)self); 178 | } 179 | static PyObject* CppSet_new(PyTypeObject* type, PyObject* args, PyObject* kwds){ 180 | CppSet* self; 181 | self = (CppSet*)type->tp_alloc(type, 0); 182 | return (PyObject*)self; 183 | } 184 | static int CppSet_init(CppSet* self, PyObject* args, PyObject* kwds){ 185 | static char* kwlist[] = {(char*)"lst", NULL}; 186 | PyObject* lst = NULL; 187 | if(!PyArg_ParseTupleAndKeywords(args, kwds, "|O", kwlist, &lst)) return -1; 188 | if(lst == NULL){ 189 | self->st = new Set4PyObject(); 190 | Py_SIZE(self) = 0; 191 | }else{ 192 | int siz; 193 | if(PyList_Check(lst)) siz = (int)PyList_GET_SIZE(lst); 194 | else if(PyTuple_Check(lst)) siz = (int)PyTuple_GET_SIZE(lst); 195 | else return PyErr_SetString(PyExc_TypeError, "got neither list nor tuple"), NULL; 196 | vector vec(siz); 197 | for(int i=0; ist = new Set4PyObject(vec); 202 | Py_SIZE(self) = siz; 203 | } 204 | return 0; 205 | } 206 | static PyObject* CppSet_add(CppSet* self, PyObject* args){ 207 | PyObject* x; 208 | PARSE_ARGS("O", &x); 209 | bool res = self->st->add(x); 210 | if(res) Py_SIZE(self)++; 211 | return Py_BuildValue("O", res ? Py_True : Py_False); 212 | } 213 | static PyObject* CppSet_remove(CppSet* self, PyObject* args){ 214 | PyObject* x; 215 | PARSE_ARGS("O", &x); 216 | PyObject* res = self->st->remove(x); 217 | if(res==NULL) return (PyObject*)NULL; 218 | Py_SIZE(self)--; 219 | return Py_BuildValue("O", res); 220 | } 221 | static PyObject* CppSet_search_higher_equal(CppSet* self, PyObject* args){ 222 | PyObject* x; 223 | PARSE_ARGS("O", &x); 224 | PyObject* res = self->st->search_higher_equal(x); 225 | return Py_BuildValue("O", res); 226 | } 227 | static PyObject* CppSet_min(CppSet* self, PyObject* args){ 228 | PyObject* res = self->st->min(); 229 | return Py_BuildValue("O", res); 230 | } 231 | static PyObject* CppSet_max(CppSet* self, PyObject* args){ 232 | PyObject* res = self->st->max(); 233 | return Py_BuildValue("O", res); 234 | } 235 | static PyObject* CppSet_pop_min(CppSet* self, PyObject* args){ 236 | PyObject* res = self->st->pop_min(); 237 | if(res==NULL) return (PyObject*)NULL; 238 | Py_SIZE(self)--; 239 | return res; // 参照カウントを増やさない 240 | } 241 | static PyObject* CppSet_pop_max(CppSet* self, PyObject* args){ 242 | PyObject* res = self->st->pop_max(); 243 | if(res==NULL) return (PyObject*)NULL; 244 | Py_SIZE(self)--; 245 | return res; // 参照カウントを増やさない 246 | } 247 | static Py_ssize_t CppSet_len(CppSet* self){ 248 | return Py_SIZE(self); 249 | } 250 | static PyObject* CppSet_next(CppSet* self, PyObject* args){ 251 | PyObject* res = self->st->iter_next(); 252 | return Py_BuildValue("O", res); 253 | } 254 | static PyObject* CppSet_prev(CppSet* self, PyObject* args){ 255 | PyObject* res = self->st->iter_prev(); 256 | return Py_BuildValue("O", res); 257 | } 258 | static PyObject* CppSet_to_list(CppSet* self, PyObject* args){ 259 | PyObject* res = self->st->to_list(); 260 | return res; 261 | } 262 | static PyObject* CppSet_get(CppSet* self, PyObject* args){ 263 | PyObject* res = self->st->get(); 264 | return Py_BuildValue("O", res); 265 | } 266 | static PyObject* CppSet_erase(CppSet* self, PyObject* args){ 267 | PyObject* res = self->st->erase(); 268 | if(res==NULL) return (PyObject*)NULL; 269 | Py_SIZE(self)--; 270 | return Py_BuildValue("O", res); 271 | } 272 | static PyObject* CppSet_copy(CppSet* self, PyObject* args){ 273 | CppSet* st2 = (CppSet*)CppSet_new(&CppSetType, (PyObject*)NULL, (PyObject*)NULL); 274 | if (st2==NULL) return (PyObject*)NULL; 275 | st2->st = new Set4PyObject(*self->st); 276 | Py_SIZE(st2) = Py_SIZE(self); 277 | return (PyObject*)st2; 278 | } 279 | static PyObject* CppSet_getitem(CppSet* self, Py_ssize_t idx){ 280 | PyObject* res = self->st->getitem((long)idx); 281 | return Py_BuildValue("O", res); 282 | } 283 | static PyObject* CppSet_pop(CppSet* self, PyObject* args){ 284 | long idx; 285 | PARSE_ARGS("l", &idx); 286 | PyObject* res = self->st->pop(idx); 287 | if(res==NULL) return (PyObject*)NULL; 288 | Py_SIZE(self)--; 289 | return Py_BuildValue("O", res); 290 | } 291 | static PyObject* CppSet_index(CppSet* self, PyObject* args){ 292 | PyObject* x; 293 | PARSE_ARGS("O", &x); 294 | long res = self->st->index(x); 295 | return Py_BuildValue("l", res); 296 | } 297 | static int CppSet_contains(CppSet* self, PyObject* x){ 298 | return PyObject_RichCompareBool(self->st->search_higher_equal(x), x, Py_EQ); 299 | } 300 | static int CppSet_bool(CppSet* self){ 301 | return Py_SIZE(self) != 0; 302 | } 303 | static PyObject* CppSet_repr(PyObject* self){ 304 | PyObject *result, *aslist; 305 | aslist = ((CppSet*)self)->st->to_list(); 306 | result = PyUnicode_FromFormat("CppSet(%R)", aslist); 307 | Py_ReprLeave(self); 308 | Py_DECREF(aslist); 309 | return result; 310 | } 311 | 312 | static PyMethodDef CppSet_methods[] = { 313 | {"add", (PyCFunction)CppSet_add, METH_VARARGS, "Add item"}, 314 | {"remove", (PyCFunction)CppSet_remove, METH_VARARGS, "Remove item"}, 315 | {"search_higher_equal", (PyCFunction)CppSet_search_higher_equal, METH_VARARGS, "Search item"}, 316 | {"min", (PyCFunction)CppSet_min, METH_VARARGS, "Get minimum item"}, 317 | {"max", (PyCFunction)CppSet_max, METH_VARARGS, "Get maximum item"}, 318 | {"pop_min", (PyCFunction)CppSet_pop_min, METH_VARARGS, "Pop minimum item"}, 319 | {"pop_max", (PyCFunction)CppSet_pop_max, METH_VARARGS, "Pop maximum item"}, 320 | {"next", (PyCFunction)CppSet_next, METH_VARARGS, "Get next value"}, 321 | {"prev", (PyCFunction)CppSet_prev, METH_VARARGS, "Get previous value"}, 322 | {"to_list", (PyCFunction)CppSet_to_list, METH_VARARGS, "Make list from set"}, 323 | {"get", (PyCFunction)CppSet_get, METH_VARARGS, "Get item that iterator is pointing at"}, 324 | {"erase", (PyCFunction)CppSet_erase, METH_VARARGS, "Erase item that iterator is pointing at"}, 325 | {"copy", (PyCFunction)CppSet_copy, METH_VARARGS, "Copy set"}, 326 | {"getitem", (PyCFunction)CppSet_getitem, METH_VARARGS, "Get item by index"}, 327 | {"pop", (PyCFunction)CppSet_pop, METH_VARARGS, "Pop item"}, 328 | {"index", (PyCFunction)CppSet_index, METH_VARARGS, "Get index of item"}, 329 | {NULL} /* Sentinel */ 330 | }; 331 | static PySequenceMethods CppSet_as_sequence = { 332 | (lenfunc)CppSet_len, /* sq_length */ 333 | 0, /* sq_concat */ 334 | 0, /* sq_repeat */ 335 | (ssizeargfunc)CppSet_getitem, /* sq_item */ 336 | 0, /* sq_slice */ 337 | 0, /* sq_ass_item */ 338 | 0, /* sq_ass_slice */ 339 | (objobjproc)CppSet_contains, /* sq_contains */ 340 | 0, /* sq_inplace_concat */ 341 | 0, /* sq_inplace_repeat */ 342 | }; 343 | static PyNumberMethods CppSet_as_number = { 344 | 0, /* nb_add */ 345 | 0, /* nb_subtract */ 346 | 0, /* nb_multiply */ 347 | 0, /* nb_remainder */ 348 | 0, /* nb_divmod */ 349 | 0, /* nb_power */ 350 | 0, /* nb_negative */ 351 | 0, /* nb_positive */ 352 | 0, /* nb_absolute */ 353 | (inquiry)CppSet_bool, /* nb_bool */ 354 | 0, /* nb_invert */ 355 | }; 356 | PyTypeObject CppSetType = { 357 | PyVarObject_HEAD_INIT(NULL, 0) 358 | "cppset.CppSet", /*tp_name*/ 359 | sizeof(CppSet), /*tp_basicsize*/ 360 | 0, /*tp_itemsize*/ 361 | (destructor) CppSet_dealloc, /*tp_dealloc*/ 362 | 0, /*tp_print*/ 363 | 0, /*tp_getattr*/ 364 | 0, /*tp_setattr*/ 365 | 0, /*reserved*/ 366 | CppSet_repr, /*tp_repr*/ 367 | &CppSet_as_number, /*tp_as_number*/ 368 | &CppSet_as_sequence, /*tp_as_sequence*/ 369 | 0, /*tp_as_mapping*/ 370 | 0, /*tp_hash*/ 371 | 0, /*tp_call*/ 372 | 0, /*tp_str*/ 373 | 0, /*tp_getattro*/ 374 | 0, /*tp_setattro*/ 375 | 0, /*tp_as_buffer*/ 376 | Py_TPFLAGS_DEFAULT | Py_TPFLAGS_BASETYPE, /*tp_flags*/ 377 | 0, /*tp_doc*/ 378 | 0, /*tp_traverse*/ 379 | 0, /*tp_clear*/ 380 | 0, /*tp_richcompare*/ 381 | 0, /*tp_weaklistoffset*/ 382 | 0, /*tp_iter*/ 383 | 0, /*tp_iternext*/ 384 | CppSet_methods, /*tp_methods*/ 385 | 0, /*tp_members*/ 386 | 0, /*tp_getset*/ 387 | 0, /*tp_base*/ 388 | 0, /*tp_dict*/ 389 | 0, /*tp_descr_get*/ 390 | 0, /*tp_descr_set*/ 391 | 0, /*tp_dictoffset*/ 392 | (initproc)CppSet_init, /*tp_init*/ 393 | 0, /*tp_alloc*/ 394 | CppSet_new, /*tp_new*/ 395 | 0, /*tp_free*/ 396 | 0, /*tp_is_gc*/ 397 | 0, /*tp_bases*/ 398 | 0, /*tp_mro*/ 399 | 0, /*tp_cache*/ 400 | 0, /*tp_subclasses*/ 401 | 0, /*tp_weaklist*/ 402 | 0, /*tp_del*/ 403 | 0, /*tp_version_tag*/ 404 | 0, /*tp_finalize*/ 405 | }; 406 | 407 | static PyModuleDef cppsetmodule = { 408 | PyModuleDef_HEAD_INIT, 409 | "cppset", 410 | NULL, 411 | -1, 412 | }; 413 | 414 | PyMODINIT_FUNC PyInit_cppset(void) 415 | { 416 | PyObject* m; 417 | if(PyType_Ready(&CppSetType) < 0) return NULL; 418 | 419 | m = PyModule_Create(&cppsetmodule); 420 | if(m == NULL) return NULL; 421 | 422 | Py_INCREF(&CppSetType); 423 | if (PyModule_AddObject(m, "CppSet", (PyObject*) &CppSetType) < 0) { 424 | Py_DECREF(&CppSetType); 425 | Py_DECREF(m); 426 | return NULL; 427 | } 428 | 429 | return m; 430 | } 431 | """ 432 | code_setup = r""" 433 | from distutils.core import setup, Extension 434 | module = Extension( 435 | "cppset", 436 | sources=["set_wrapper.cpp"], 437 | extra_compile_args=["-O3", "-march=native", "-std=c++14"] 438 | ) 439 | setup( 440 | name="SetMethod", 441 | version="0.2.1", 442 | description="wrapper for C++ set", 443 | ext_modules=[module] 444 | ) 445 | """ 446 | 447 | import os 448 | import sys 449 | if sys.argv[-1] == "ONLINE_JUDGE" or os.getcwd() != "/imojudge/sandbox": 450 | with open("set_wrapper.cpp", "w") as f: 451 | f.write(code_cppset) 452 | with open("setup.py", "w") as f: 453 | f.write(code_setup) 454 | os.system(f"{sys.executable} setup.py build_ext --inplace") 455 | 456 | from cppset import CppSet 457 | -------------------------------------------------------------------------------- /geometry.py: -------------------------------------------------------------------------------- 1 | def norm(x1, y1, x2, y2): 2 | # hypot を使ったほうが良さそう 3 | return ((x1-x2)**2 + (y1-y2)**2)**0.5 4 | 5 | def d(a, b, c, x, y): 6 | # 点と直線の距離 7 | return abs(a*x + b*y + c) / (a**2 + b**2)**0.5 8 | 9 | def line_cross(x1, y1, x2, y2, x3, y3, x4, y4): 10 | # 線分 AB と CD の交差判定 11 | def f(x, y, x1, y1, x2, y2): # 直線上にあるとき 0 になる 12 | return (x1-x2)*(y-y1)+(y1-y2)*(x1-x) 13 | b1 = f(x3, y3, x1, y1, x2, y2) * f(x4, y4, x1, y1, x2, y2) < 0 # 点 C と点 D が直線 AB の異なる側にある 14 | b2 = f(x1, y1, x3, y3, x4, y4) * f(x2, y2, x3, y3, x4, y4) < 0 # 点 A と点 B が直線 CD の異なる側にある 15 | return b1 and b2 16 | 17 | def get_convex_hull(points): # 複素数 18 | # 凸包 Monotone Chain O(nlogn) 19 | # 参考: https://matsu7874.hatenablog.com/entry/2018/12/17/025713 20 | def det(p, q): 21 | return (p.conjugate()*q).imag 22 | points.sort(key=lambda x: (x.real, x.imag)) 23 | ch = [] 24 | for p in points: 25 | while len(ch) > 1: 26 | v_cur = ch[-1]-ch[-2] 27 | v_new = p-ch[-2] 28 | if det(v_cur, v_new) > 0: 29 | break 30 | ch.pop() 31 | ch.append(p) 32 | t = len(ch) 33 | for p in points[-2::-1]: 34 | while len(ch) > t: 35 | v_cur = ch[-1]-ch[-2] 36 | v_new = p-ch[-2] 37 | if det(v_cur, v_new) > 0: 38 | break 39 | ch.pop() 40 | ch.append(p) 41 | return ch[:-1] 42 | 43 | def minimum_enclosing_circle(points): # 複素数 44 | # 最小包含円 O(N) 45 | # 返り値は中心の座標と半径 46 | # 参考: https://tubo28.me/compprog/algorithm/minball/ 47 | # 検証: https://atcoder.jp/contests/abc151/submissions/15834319 48 | from random import sample 49 | N = len(points) 50 | if N == 1: 51 | return points[0], 0 52 | points = sample(points, N) 53 | def cross(a, b): 54 | return a.real * b.imag - a.imag * b.real 55 | def norm2(a): 56 | return a.real * a.real + a.imag * a.imag 57 | def make_circle_3(a, b, c): 58 | A, B, C = norm2(b-c), norm2(c-a), norm2(a-b) 59 | S = cross(b-a, c-a) 60 | p = (A*(B+C-A)*a + B*(C+A-B)*b + C*(A+B-C)*c) / (4*S*S) 61 | radius = abs(p-a) 62 | return p, radius 63 | def make_circle_2(a, b): 64 | c = (a+b) / 2 65 | radius = abs(a-c) 66 | return c, radius 67 | def in_circle(point, circle): 68 | return abs(point-circle[0]) <= circle[1]+1e-7 69 | p0 = points[0] 70 | circle = make_circle_2(p0, points[1]) 71 | for i, p_i in enumerate(points[2:], 2): 72 | if not in_circle(p_i, circle): 73 | circle = make_circle_2(p0, p_i) 74 | for j, p_j in enumerate(points[1:i], 1): 75 | if not in_circle(p_j, circle): 76 | circle = make_circle_2(p_i, p_j) 77 | for p_k in points[:j]: 78 | if not in_circle(p_k, circle): 79 | circle = make_circle_3(p_i, p_j, p_k) 80 | return circle 81 | 82 | 83 | def intersection(circle, polygon): 84 | # circle: (x, y, r) 85 | # polygon: [(x1, y1), (x2, y2), ...] 86 | # 円と多角形の共通部分の面積 87 | # 多角形の点が反時計回りで与えられれば正の値、時計回りなら負の値を返す 88 | from math import acos, hypot, isclose, sqrt 89 | def cross(v1, v2): # 外積 90 | x1, y1 = v1 91 | x2, y2 = v2 92 | return x1 * y2 - x2 * y1 93 | 94 | def dot(v1, v2): # 内積 95 | x1, y1 = v1 96 | x2, y2 = v2 97 | return x1 * x2 + y1 * y2 98 | 99 | def seg_intersection(circle, seg): 100 | # 円と線分の交点(円の中心が原点でない場合は未検証) 101 | x0, y0, r = circle 102 | p1, p2 = seg 103 | x1, y1 = p1 104 | x2, y2 = p2 105 | 106 | p1p2 = (x2 - x1) ** 2 + (y2 - y1) ** 2 107 | op1 = (x1 - x0) ** 2 + (y1 - y0) ** 2 108 | rr = r * r 109 | dp = dot((x1 - x0, y1 - y0), (x2 - x1, y2 - y1)) 110 | 111 | d = dp * dp - p1p2 * (op1 - rr) 112 | ps = [] 113 | 114 | if isclose(d, 0.0, abs_tol=1e-9): 115 | t = -dp / p1p2 116 | if ge(t, 0.0) and le(t, 1.0): 117 | ps.append((x1 + t * (x2 - x1), y1 + t * (y2 - y1))) 118 | elif d > 0.0: 119 | t1 = (-dp - sqrt(d)) / p1p2 120 | if ge(t1, 0.0) and le(t1, 1.0): 121 | ps.append((x1 + t1 * (x2 - x1), y1 + t1 * (y2 - y1))) 122 | t2 = (-dp + sqrt(d)) / p1p2 123 | if ge(t2, 0.0) and le(t2, 1.0): 124 | ps.append((x1 + t2 * (x2 - x1), y1 + t2 * (y2 - y1))) 125 | 126 | # assert all(isclose(r, hypot(x, y)) for x, y in ps) 127 | return ps 128 | 129 | def le(f1, f2): # less equal 130 | return f1 < f2 or isclose(f1, f2, abs_tol=1e-9) 131 | 132 | def ge(f1, f2): # greater equal 133 | return f1 > f2 or isclose(f1, f2, abs_tol=1e-9) 134 | 135 | x, y, r = circle 136 | polygon = [(xp-x, yp-y) for xp, yp in polygon] 137 | area = 0.0 138 | for p1, p2 in zip(polygon, polygon[1:] + [polygon[0]]): 139 | ps = seg_intersection((0, 0, r), (p1, p2)) 140 | for pp1, pp2 in zip([p1] + ps, ps + [p2]): 141 | c = cross(pp1, pp2) # pp1 と pp2 の位置関係によって正負が変わる 142 | if c == 0: # pp1, pp2, 原点が同一直線上にある場合 143 | continue 144 | d1 = hypot(*pp1) 145 | d2 = hypot(*pp2) 146 | if le(d1, r) and le(d2, r): 147 | area += c / 2 # pp1, pp2, 原点を結んだ三角形の面積 148 | else: 149 | t = acos(dot(pp1, pp2) / (d1 * d2)) # pp1-原点とpp2-原点の成す角 150 | sign = 1.0 if c >= 0 else -1.0 151 | area += sign * r * r * t / 2 # 扇形の面積 152 | return area 153 | 154 | 155 | -------------------------------------------------------------------------------- /numba_library.py: -------------------------------------------------------------------------------- 1 | # 参考 2 | # https://ikatakos.com/pot/programming/python/packages/numba 3 | 4 | 5 | # >>> numba compile >>> 6 | 7 | import sys 8 | import numpy as np 9 | 10 | def numba_compile(numba_config): 11 | import os, sys 12 | if sys.argv[-1] == "ONLINE_JUDGE": 13 | from numba import njit 14 | from numba.pycc import CC 15 | cc = CC("my_module") 16 | for func, signature in numba_config: 17 | globals()[func.__name__] = njit(signature)(func) 18 | cc.export(func.__name__, signature)(func) 19 | cc.compile() 20 | exit() 21 | elif os.name == "posix": 22 | exec(f"from my_module import {','.join(func.__name__ for func, _ in numba_config)}") 23 | for func, _ in numba_config: 24 | globals()[func.__name__] = vars()[func.__name__] 25 | else: 26 | from numba import njit 27 | for func, signature in numba_config: 28 | globals()[func.__name__] = njit(signature, cache=True)(func) 29 | print("compiled!", file=sys.stderr) 30 | 31 | def solve(In): 32 | idx_In = np.array([-1], dtype=np.int64) 33 | def read(): 34 | idx_In[0] += 1 35 | return In[idx_In[0]] 36 | 37 | numba_compile([ 38 | [solve, "void(i8[:])"], 39 | ]) 40 | 41 | def main(): 42 | In = np.array(sys.stdin.buffer.read().split(), dtype=np.int64) 43 | solve(In) 44 | 45 | main() 46 | 47 | # <<< numba compile <<< 48 | 49 | 50 | # >>> binary indexed tree >>> 51 | # 必要な要素数+1 の長さの ndarray の 1 要素目以降を使う 52 | def bitify(arr): # [bitify, "void(i8[:])"], 53 | # len(arr) は 2 冪 + 1 54 | for i in range(1, len(arr)-1): 55 | arr[i + (i & -i)] += arr[i] 56 | def bit_sum(bit, i): # [bit_sum, "i8(i8[:],i8)"], 57 | # (0, i] 58 | res = 0 59 | while i: 60 | res += bit[i] 61 | i -= i & -i 62 | return res 63 | def bit_add(bit, i, val): # [bit_add, "void(i8[:],i8,i8)"], 64 | n = len(bit) 65 | while i < n: 66 | bit[i] += val 67 | i += i & -i 68 | # <<< binary indexed tree <<< 69 | 70 | 71 | def inversion_number(arr): # [inversion_number, "i8(f8[:])"], 72 | # 転倒数 73 | n = len(arr) 74 | arr = np.argsort(arr) + 1 75 | bit = np.zeros(n+1, dtype=np.int64) 76 | res = n * (n-1) >> 1 77 | for val in arr: 78 | res -= bit_sum(bit, val) 79 | bit_add(bit, val, 1) 80 | return res 81 | 82 | 83 | def pow_mod(base, exp): # [numba_pow, "i8(i8,i8)"], 84 | # mod はグローバル変数を参照 85 | exp %= mod - 1 86 | res = 1 87 | while exp: 88 | if exp % 2: 89 | res = res * base % mod 90 | base = base * base % mod 91 | exp //= 2 92 | return res 93 | 94 | def comb_cunstruct(n): # [comb_cunstruct, "Tuple((i8[:],i8[:]))(i8,)"], 95 | # mod はグローバル変数を参照 96 | fac = np.empty(n + 1, dtype=np.int64) 97 | facinv = np.empty(n + 1, dtype=np.int64) 98 | fac[0] = f = 1 99 | for i in range(1, n + 1): 100 | f = f * i % mod 101 | fac[i] = f 102 | f = pow_mod(f, -1) 103 | for i in range(n, -1, -1): 104 | facinv[i] = f 105 | f = f * i % mod 106 | return fac, facinv 107 | 108 | def comb(n, r, fac, facinv): # [comb, "i8(i8,i8,i8[:],i8[:])"], 109 | # mod はグローバル変数を参照 110 | return fac[n] * facinv[r] % mod * facinv[n - r] % mod 111 | 112 | 113 | def z_algo(S): # [z_algo, "i8[:](i8[:])"], 114 | # Z-algoirhm O(n) 115 | # Z[i] := S と S[i:] で prefix が何文字一致しているか 116 | # 検証1: https://atcoder.jp/contests/abc150/submissions/15829530 117 | # 検証2: https://atcoder.jp/contests/abc141/submissions/15855247 118 | i, j, n = 1, 0, len(S) 119 | Z = np.zeros(S.shape, dtype=np.int64) 120 | Z[0] = n 121 | while i < n: 122 | while i+j < n and S[j] == S[i+j]: 123 | j += 1 124 | if j == 0: 125 | i += 1 126 | continue 127 | Z[i] = j 128 | d = 1 129 | while i+d < n and d+Z[d] < j: 130 | Z[i+d] = Z[d] 131 | d += 1 132 | i += d 133 | j -= d 134 | return Z 135 | 136 | 137 | def sort_edges(N, edges_): # [sort_edges, "Tuple((i8[:],i8[:]))(i8,i8[:,:])"], 138 | # N: 頂点番号の最大値 139 | M = len(edges_) 140 | edges = np.empty((M * 2, 2), dtype=np.int64) 141 | edges[:M] = edges_ 142 | edges[M:] = edges_[:, ::-1] 143 | order = np.argsort(edges[:, 0]) # O(N) にできなくもない 144 | edges = edges[order, 1] 145 | c = np.zeros(N+1, dtype=np.int64) 146 | c_ = np.bincount(edges_.ravel()) # minlength を使わせて 147 | c[:len(c_)] = c_ 148 | c = np.cumsum(c) 149 | lefts = np.zeros(len(c) + 1, dtype=np.int64) 150 | lefts[1:] = c 151 | return edges, lefts 152 | 153 | def eular_tour(edges, lefts, root): # [eular_tour, "Tuple((i8[:],i8[:],i8[:],i8[:]))(i8[:],i8[:],i8)"], 154 | # グラフは 1-indexed が良い 155 | n = len(lefts)-1 156 | stack = [root] 157 | tour = [0] * 0 158 | firsts = np.full(n, -100, dtype=np.int64) 159 | lasts = np.full(n, -100, dtype=np.int64) 160 | parents = np.full(n, -100, dtype=np.int64) 161 | while stack: 162 | v = stack.pop() 163 | if firsts[v] >= 0: 164 | lasts[v] = len(tour) 165 | tour.append(-v) # 帰りがけの辺の表現をマイナス以外にしたい場合ここを変える 166 | continue 167 | p = parents[v] 168 | firsts[v] = len(tour) 169 | tour.append(v) 170 | stack.append(v) 171 | for u in edges[lefts[v]:lefts[v+1]]: 172 | if p != u: 173 | parents[u] = v 174 | stack.append(u) 175 | tour = np.array(tour, dtype=np.int64) 176 | return tour, firsts, lasts, parents 177 | 178 | 179 | from functools import reduce 180 | def rerooting(n, edges): # [rerooting, "(i8,i8[:,:])"], 181 | # 全方位木 dp 182 | # 参考1: https://qiita.com/keymoon/items/2a52f1b0fb7ef67fb89e 183 | # 参考2: https://atcoder.jp/contests/abc160/submissions/15255726 184 | # 検証: https://atcoder.jp/contests/abc160/submissions/15971370 185 | 186 | # >>> ここを変える >>> 187 | # 必要な情報は引数に持たせる 188 | identity = (1, 0) 189 | def merge(a, b): 190 | return a[0] * b[0] % mod * comb(a[1] + b[1], a[1], fac, facinv) % mod, a[1] + b[1] 191 | def add_node(value, idx): 192 | return value[0], value[1] + 1 193 | # <<< ここを変える <<< 194 | 195 | G = [[0]*0 for _ in range(n)] 196 | for i in range(n-1): 197 | a, b = edges[i] 198 | G[a].append(b) 199 | G[b].append(a) 200 | # step 1 201 | order = [] # 行きがけ順 202 | stack = [0] 203 | while stack: 204 | v = stack.pop() 205 | order.append(v) 206 | for u in G[v]: 207 | stack.append(u) 208 | G[u].remove(v) 209 | # 下から登る 210 | dp_down = [identity] * n # 自身とその下 211 | for v in order[:0:-1]: 212 | dp_down[v] = add_node(reduce( 213 | merge, [dp_down[u] for u in G[v]], identity 214 | ), v) 215 | # step 2 216 | # 上から降りる 217 | dp_up = [identity] * n # 親とその先 218 | for v in order: 219 | Gv = G[v] 220 | if len(Gv) == 0: 221 | continue 222 | cum = identity 223 | right = [identity] 224 | for u in Gv[:0:-1]: 225 | cum = merge(dp_down[u], cum) 226 | right.append(cum) 227 | right.reverse() 228 | cum = dp_up[v] 229 | for u, cum_r in zip(Gv, right): 230 | dp_up[u] = add_node(merge(cum, cum_r), v) 231 | cum = merge(cum, dp_down[u]) 232 | results = [identity] * 0 233 | for v, Gv in enumerate(G): 234 | results.append(add_node( 235 | reduce(merge, [dp_down[u] for u in Gv], dp_up[v]), v 236 | )) 237 | return np.array(results) 238 | 239 | 240 | # セグメント木: https://atcoder.jp/contests/abc158/submissions/16233600 241 | # 平方分割(遅延評価): https://atcoder.jp/contests/abc177/submissions/16376895 242 | # 文字列を uint8 で読み込む: np.frombuffer(input(), dtype=np.uint8) 243 | 244 | -------------------------------------------------------------------------------- /old.py: -------------------------------------------------------------------------------- 1 | class Lca: # 最近共通祖先(ダブリング) 2 | # HL 分解での LCA を書いたのでボツ 3 | def __init__(self, E, root): 4 | import sys 5 | sys.setrecursionlimit(500000) 6 | self.root = root 7 | self.E = E # V 8 | self.n = len(E) # 頂点数 9 | self.logn = 1 # n < 1<= (1<= 0: 22 | self.parent[k+1][v] = self.parent[k][p_] 23 | 24 | def dfs(self, v, p, dep): 25 | # ノード番号、親のノード番号、深さ 26 | self.parent[0][v] = p 27 | self.depth[v] = dep 28 | for e in self.E[v]: 29 | if e != p: 30 | self.dfs(e, v, dep+1) 31 | 32 | def get(self, u, v): 33 | if self.depth[u] > self.depth[v]: 34 | u, v = v, u # self.depth[u] <= self.depth[v] 35 | dep_diff = self.depth[v]-self.depth[u] 36 | for k in range(self.logn): 37 | if dep_diff >> k & 1: 38 | v = self.parent[k][v] 39 | if u==v: 40 | return u 41 | for k in range(self.logn-1, -1, -1): 42 | if self.parent[k][u] != self.parent[k][v]: 43 | u = self.parent[k][u] 44 | v = self.parent[k][v] 45 | return self.parent[0][u] 46 | 47 | def get_convex_hull(points): # 凸包 48 | # 複素数のものを使っていきたいのでボツ 49 | def det(p, q): 50 | return p[0] * q[1] - p[1] * q[0] 51 | def sub(p, q): 52 | return (p[0] - q[0], p[1] - q[1]) 53 | points.sort() 54 | ch = [] 55 | for p in points: 56 | while len(ch) > 1: 57 | v_cur = sub(ch[-1], ch[-2]) 58 | v_new = sub(p, ch[-2]) 59 | if det(v_cur, v_new) > 0: 60 | break 61 | ch.pop() 62 | ch.append(p) 63 | t = len(ch) 64 | for p in points[-2::-1]: 65 | while len(ch) > t: 66 | v_cur = sub(ch[-1], ch[-2]) 67 | v_new = sub(p, ch[-2]) 68 | if det(v_cur, v_new) > 0: 69 | break 70 | ch.pop() 71 | ch.append(p) 72 | return ch[:-1] 73 | 74 | -------------------------------------------------------------------------------- /python2_template.py: -------------------------------------------------------------------------------- 1 | from __future__ import division 2 | import sys 3 | PYTHON3 = sys.version_info.major == 3 4 | if not PYTHON3: 5 | from itertools import izip as zip 6 | range = xrange 7 | input = raw_input 8 | -------------------------------------------------------------------------------- /segtree.py: -------------------------------------------------------------------------------- 1 | class SegmentTree(object): 2 | # https://atcoder.jp/contests/abc014/submissions/3935971 3 | __slots__ = ["elem_size", "tree", "default", "op"] 4 | def __init__(self, a: list, default: int, op): 5 | from math import ceil, log 6 | real_size = len(a) 7 | self.elem_size = elem_size = 1 << ceil(log(real_size, 2)) 8 | self.tree = tree = [default] * (elem_size * 2) 9 | tree[elem_size:elem_size + real_size] = a 10 | self.default = default 11 | self.op = op 12 | for i in range(elem_size - 1, 0, -1): 13 | tree[i] = op(tree[i << 1], tree[(i << 1) + 1]) 14 | 15 | def get_value(self, x: int, y: int) -> int: # 半開区間 16 | l, r = x + self.elem_size, y + self.elem_size 17 | tree, result, op = self.tree, self.default, self.op 18 | while l < r: 19 | if l & 1: 20 | result = op(tree[l], result) 21 | l += 1 22 | if r & 1: 23 | r -= 1 24 | result = op(tree[r], result) 25 | l, r = l >> 1, r >> 1 26 | return result 27 | 28 | def set_value(self, i: int, value: int) -> None: 29 | k = self.elem_size + i 30 | self.tree[k] = value 31 | self.update(k) 32 | 33 | def update(self, i: int) -> None: 34 | op, tree = self.op, self.tree 35 | while i > 1: 36 | i >>= 1 37 | tree[i] = op(tree[i << 1], tree[(i << 1) + 1]) 38 | 39 | # モノイドの元として (index, value) を使うと index も取得できるようになる 40 | seg = SegmentTree(list(enumerate(A)), 41 | (-1, float("inf")), 42 | lambda x, y: x if x[1]b else a, op2=lambda a,b:a>b): 51 | # 同じ場合最左のインデックスを返す 52 | # 最小値・最左: default=float("inf"), op=lambda a,b: b if a>b else a, op2=lambda a,b:a>b 53 | # 最大値・最左: -float("inf"), max, lambda a,b:a<=b 54 | from math import ceil, log 55 | real_size = len(a) 56 | self.elem_size = elem_size = 1 << ceil(log(real_size, 2)) 57 | self.tree = tree = [default] * (elem_size * 2) 58 | self.index = index = [0] * (elem_size * 2) 59 | tree[elem_size:elem_size + real_size] = a 60 | index[elem_size:elem_size + real_size] = list(range(real_size)) 61 | self.default = default 62 | self.op = op 63 | self.op2 = op2 64 | for i in range(elem_size-1, 0, -1): 65 | v1, v2 = tree[i<<1], tree[(i<<1)+1] 66 | tree[i] = op(v1, v2) 67 | index[i] = index[(i<<1) + op2(v1, v2)] 68 | 69 | def get_value(self, x: int, y: int) -> tuple: # 半開区間 70 | l, r = x + self.elem_size, y + self.elem_size 71 | tree, op, op2, index = self.tree, self.op, self.op2, self.index 72 | result_l = result_r = self.default 73 | idx_l = idx_r = -1 74 | while l < r: 75 | if l & 1: 76 | v1, v2 = result_l, tree[l] 77 | result_l = op(v1, v2) 78 | if op2(v1, v2)==1: 79 | idx_l = index[l] 80 | l += 1 81 | if r & 1: 82 | r -= 1 83 | v1, v2 = tree[r], result_r 84 | result_r = op(v1, v2) 85 | if op2(v1, v2)==0: 86 | idx_r = index[r] 87 | l, r = l >> 1, r >> 1 88 | result = op(result_l, result_r) 89 | idx = idx_r if op2(result_l, result_r) else idx_l 90 | return result, idx 91 | 92 | def set_value(self, i: int, value: int) -> None: 93 | k = self.elem_size + i 94 | self.tree[k] = value 95 | self.update(k) 96 | 97 | def update(self, i: int) -> None: 98 | op, tree, index, op2 = self.op, self.tree, self.index, self.op2 99 | while i > 1: 100 | i >>= 1 101 | v1, v2 = tree[i<<1], tree[(i<<1)+1] 102 | tree[i] = op(v1, v2) 103 | index[i] = index[(i<<1) + op2(v1, v2)] 104 | 105 | 106 | class SegTree(object): 107 | # 区間の中で v 以下の値のうち最も左にある値と index を取得 108 | # 普通のセグ木に get_threshold_left と get_threshold_left_all を加えただけ 109 | # 検証1: https://atcoder.jp/contests/arc038/submissions/6933949 (全区間のみ) 110 | # 検証2: https://atcoder.jp/contests/arc046/submissions/7430924 (全区間のみ) 111 | # 抽象化したい 112 | __slots__ = ["elem_size", "tree", "default", "op"] 113 | def __init__(self, a: list, default=float("inf"), op=min): 114 | from math import ceil, log 115 | real_size = len(a) 116 | self.elem_size = elem_size = 1 << ceil(log(real_size, 2)) 117 | self.tree = tree = [default] * (elem_size * 2) 118 | tree[elem_size:elem_size + real_size] = a 119 | self.default = default 120 | self.op = op 121 | for i in range(elem_size - 1, 0, -1): 122 | tree[i] = op(tree[i << 1], tree[(i << 1) + 1]) 123 | 124 | def get_value(self, x: int, y: int) -> int: # 半開区間 125 | l, r = x + self.elem_size, y + self.elem_size 126 | tree, result, op = self.tree, self.default, self.op 127 | while l < r: 128 | if l & 1: 129 | result = op(tree[l], result) 130 | l += 1 131 | if r & 1: 132 | r -= 1 133 | result = op(tree[r], result) 134 | l, r = l >> 1, r >> 1 135 | return result 136 | 137 | def get_threshold_left(self, x, y, v): 138 | # 区間 [x, y) 内で一番左の v 以下の値 139 | tree, result, op, elem_size = self.tree, self.default, self.op, self.elem_size 140 | l, r = x + elem_size, y + elem_size 141 | idx_left = idx_right = -1 # 内部 index 142 | while l < r: 143 | if l & 1: 144 | result = op(tree[l], result) 145 | if idx_left == -1 and tree[l] <= v: 146 | idx_left = l 147 | l += 1 148 | if r & 1: 149 | r -= 1 150 | result = op(tree[r], result) 151 | if tree[r] <= v: 152 | idx_right = r 153 | l, r = l >> 1, r >> 1 154 | if idx_left==idx_right==-1: 155 | return -1, -1 156 | idx = idx_left if idx_left!=-1 else idx_right 157 | while idx < elem_size: 158 | idx <<= 1 159 | if tree[idx] > v: 160 | idx += 1 161 | return tree[idx], idx-elem_size 162 | 163 | def get_threshold_left_all(self, v): 164 | # 全区間で一番左の v 以下の値 165 | tree, op, elem_size = self.tree, self.op, self.elem_size 166 | if tree[1] > v: 167 | return -1, -1 168 | idx = 1 169 | while idx < elem_size: 170 | idx <<= 1 171 | if tree[idx] > v: 172 | idx += 1 173 | return tree[idx], idx-elem_size 174 | 175 | def set_value(self, i: int, value: int) -> None: 176 | k = self.elem_size + i 177 | self.tree[k] = value 178 | self.update(k) 179 | 180 | def update(self, i: int) -> None: 181 | op, tree = self.op, self.tree 182 | while i > 1: 183 | i >>= 1 184 | tree[i] = op(tree[i << 1], tree[(i << 1) + 1]) 185 | 186 | 187 | class SparseTable: 188 | def __init__(self, values, op=min, zero_element=float("inf")): # O(nlogn * (op の計算量)) 189 | self.n = n = len(values) 190 | self.table = table = [values] 191 | self.op = op 192 | self.zero_element = zero_element 193 | for d in range(n.bit_length()-1): 194 | table.append([op(v1, v2) for v1, v2 in zip(table[-1], table[-1][1<= g: 208 | g_prov = g_next 209 | r += step 210 | step >>= 1 211 | return r 212 | 213 | 214 | class Rmq: 215 | # 平方分割 216 | # 値を変更すると元のリストの値も書き換わる 217 | # 検証: http://judge.u-aizu.ac.jp/onlinejudge/review.jsp?rid=3990681 218 | def __init__(self, a, sqrt_n=150, inf=(1<<31)-1): 219 | self.sqrt_n = sqrt_n 220 | if hasattr(a, "__iter__"): 221 | from itertools import zip_longest 222 | self.n = len(a) 223 | self.layer0 = [min(values) for values in zip_longest(*[iter(a)]*sqrt_n, fillvalue=inf)] 224 | self.layer1 = a 225 | elif isinstance(a, int): 226 | self.n = a 227 | self.layer0 = [inf] * ((a - 1) // sqrt_n + 1) 228 | self.layer1 = [inf] * a 229 | else: 230 | raise TypeError 231 | 232 | def get_min(self, l, r): 233 | sqrt_n = self.sqrt_n 234 | parent_l, parent_r = l//sqrt_n+1, (r-1)//sqrt_n 235 | if parent_l < parent_r: 236 | return min(min(self.layer0[parent_l:parent_r]), 237 | min(self.layer1[l:parent_l*sqrt_n]), 238 | min(self.layer1[parent_r*sqrt_n:r])) 239 | else: 240 | return min(self.layer1[l:r]) 241 | 242 | def set_value(self, idx, val): 243 | self.layer1[idx] = val 244 | idx0 = idx // self.sqrt_n 245 | idx1 = idx0 * self.sqrt_n 246 | self.layer0[idx0] = min(self.layer1[idx1:idx1+self.sqrt_n]) 247 | 248 | def chmin(self, idx, val): 249 | if self.layer1[idx] > val: 250 | self.layer1[idx] = val 251 | idx //= self.sqrt_n 252 | self.layer0[idx] = min(self.layer0[idx], val) 253 | 254 | def debug(self): 255 | print("layer0=", self.layer0) 256 | print("layer1=", self.layer1) 257 | 258 | def __getitem__(self, item): 259 | return self.layer1[item] 260 | 261 | def __setitem__(self, key, value): 262 | self.set_value(key, value) 263 | 264 | # https://atcoder.jp/contests/nikkei2019-2-qual/submissions/8434117 もある 265 | -------------------------------------------------------------------------------- /square_skip_list.py: -------------------------------------------------------------------------------- 1 | from bisect import bisect_left, bisect_right, insort_right 2 | class SquareSkipList: 3 | # SkipList の層数を 2 にした感じの何か 4 | # std::multiset の代用になる 5 | # 検証1 (add, pop) データ構造: https://atcoder.jp/contests/arc033/submissions/14718760 6 | # 検証2 (init, add, remove, search_higher_equal) Exclusive OR Queries: https://atcoder.jp/contests/cpsco2019-s1/submissions/14705333 7 | # 検証3 (add, search_higher, search_lower) Second Sum: https://atcoder.jp/contests/abc140/submissions/7488469 8 | # 検証4 (add, __getitem__) [CF] Optimal Subsequences (Hard Version): https://codeforces.com/contest/1261/submission/65643461 9 | # 検証5 (init, add, pop) Dinner Planning: https://atcoder.jp/contests/code-festival-2018-final-open/submissions/13916065 10 | # 検証6 (要素がタプル, init, add, remove, max, pop_max) Lake: https://atcoder.jp/contests/snuke21/submissions/14718529 11 | # 検証7 (init, add, remove, pop, max, pop_max) ドーナツの箱詰め: https://atcoder.jp/contests/donuts-2015/submissions/14829916 12 | # 検証8 (init, add, remove, min, max) Smart Infants: https://atcoder.jp/contests/abc170/submissions/15264112 13 | def __init__(self, values=None, sorted_=False, square=1000, seed=42, inf=float("inf")): 14 | # values: 初期値のリスト 15 | # sorted_: 初期値がソート済みであるか 16 | # square: 最大データ数の平方根 17 | # seed: 乱数のシード 18 | # inf: 番兵(要素がタプルのときは (float("inf"), float("inf")) にする) 19 | self.square = square 20 | if values is None: 21 | self.rand_y = seed 22 | self.layer1 = [inf] 23 | self.layer0 = [[]] 24 | else: 25 | self.layer1 = layer1 = [] 26 | self.layer0 = layer0 = [] 27 | if not sorted_: 28 | values.sort() 29 | y = seed 30 | l0 = [] 31 | for v in values: 32 | y ^= (y & 0x7ffff) << 13 33 | y ^= y >> 17 34 | y ^= (y & 0x7ffffff) << 5 35 | if y % square == 0: 36 | layer0.append(l0) 37 | l0 = [] 38 | layer1.append(v) 39 | else: 40 | l0.append(v) 41 | layer1.append(inf) 42 | layer0.append(l0) 43 | self.rand_y = y 44 | 45 | def add(self, x): # 要素の追加 # O(sqrt(n)) 46 | # xorshift 47 | y = self.rand_y 48 | y ^= (y & 0x7ffff) << 13 49 | y ^= y >> 17 50 | y ^= (y & 0x7ffffff) << 5 51 | self.rand_y = y 52 | 53 | if y % self.square == 0: 54 | layer1, layer0 = self.layer1, self.layer0 55 | idx1 = bisect_right(layer1, x) 56 | layer1.insert(idx1, x) 57 | layer0_idx1 = layer0[idx1] 58 | idx0 = bisect_right(layer0_idx1, x) 59 | layer0.insert(idx1 + 1, layer0_idx1[idx0:]) # layer0 は dict で管理した方が良いかもしれない # dict 微妙だった 60 | del layer0_idx1[idx0:] 61 | else: 62 | idx1 = bisect_right(self.layer1, x) 63 | insort_right(self.layer0[idx1], x) 64 | 65 | def remove(self, x): # 要素の削除 # O(sqrt(n)) 66 | # x が存在しない場合、x 以上の最小の要素が削除される 67 | idx1 = bisect_left(self.layer1, x) 68 | layer0_idx1 = self.layer0[idx1] 69 | idx0 = bisect_left(layer0_idx1, x) 70 | if idx0 == len(layer0_idx1): 71 | del self.layer1[idx1] 72 | self.layer0[idx1] += self.layer0.pop(idx1 + 1) 73 | else: 74 | del layer0_idx1[idx0] 75 | 76 | def search_higher_equal(self, x): # x 以上の最小の値を返す O(log(n)) 77 | idx1 = bisect_left(self.layer1, x) 78 | layer0_idx1 = self.layer0[idx1] 79 | idx0 = bisect_left(layer0_idx1, x) 80 | if idx0 == len(layer0_idx1): 81 | return self.layer1[idx1] 82 | return layer0_idx1[idx0] 83 | 84 | def search_higher(self, x): # x を超える最小の値を返す O(log(n)) 85 | idx1 = bisect_right(self.layer1, x) 86 | layer0_idx1 = self.layer0[idx1] 87 | idx0 = bisect_right(layer0_idx1, x) 88 | if idx0 == len(layer0_idx1): 89 | return self.layer1[idx1] 90 | return layer0_idx1[idx0] 91 | 92 | def search_lower(self, x): # x 未満の最大の値を返す O(log(n)) 93 | idx1 = bisect_left(self.layer1, x) 94 | layer0_idx1 = self.layer0[idx1] 95 | idx0 = bisect_left(layer0_idx1, x) 96 | if idx0 == 0: # layer0_idx1 が空の場合とすべて x 以上の場合 97 | return self.layer1[idx1 - 1] 98 | return layer0_idx1[idx0 - 1] 99 | 100 | def pop(self, idx): 101 | # 小さい方から idx 番目の要素を削除してその要素を返す(0-indexed) 102 | # O(sqrt(n)) 103 | # for を回すので重め、使うなら square パラメータを大きめにするべき 104 | layer0 = self.layer0 105 | s = -1 106 | for i, l0 in enumerate(layer0): 107 | s += len(l0) + 1 108 | if s >= idx: 109 | break 110 | if s == idx: 111 | layer0[i] += layer0.pop(i + 1) 112 | return self.layer1.pop(i) 113 | else: 114 | return layer0[i].pop(idx - s) 115 | 116 | def pop_max(self): 117 | # 最大値を削除してその要素を返す(0-indexed) O(1) 118 | # 空ならエラー 119 | if self.layer0[-1]: 120 | return self.layer0[-1].pop() 121 | else: 122 | del self.layer0[-1] 123 | return self.layer1.pop(-2) 124 | 125 | def __getitem__(self, item): 126 | # 小さい方から idx 番目の要素を返す O(sqrt(N)) 127 | layer0 = self.layer0 128 | s = -1 129 | for i, l0 in enumerate(layer0): 130 | s += len(l0) + 1 131 | if s >= item: 132 | break 133 | if s == item: 134 | return self.layer1[i] 135 | else: 136 | return layer0[i][item - s] 137 | 138 | def min(self): # 最小値を返す 空なら inf を返す O(1) 139 | return self.layer0[0][0] if self.layer0[0] else self.layer1[0] 140 | 141 | def max(self): # 最大値を返す 空ならエラー O(1) 142 | return self.layer0[-1][-1] if self.layer0[-1] else self.layer1[-2] 143 | 144 | def merge(self, r): # 結合 O(sqrt(n)) 145 | self.layer0[-1] += r.layer0[0] 146 | self.layer0 += r.layer0[1:] 147 | del self.layer1[-1] 148 | self.layer1 += r.layer1 149 | 150 | def split(self, k): # k 以上を切り離す O(sqrt(n)) 151 | idx1 = bisect_left(self.layer1, k) 152 | layer0_idx1 = self.layer0[idx1] 153 | idx0 = bisect_left(layer0_idx1, k) 154 | r = SquareSkipList(square=self.square, seed=self.rand_y) 155 | r.layer1 = self.layer1[idx1:] 156 | r.layer0 = [layer0_idx1[idx0:]] + self.layer0[idx1 + 1:] 157 | del self.layer1[idx1:-1], layer0_idx1[idx0:], self.layer0[idx1 + 1:] 158 | return r 159 | 160 | def print(self): 161 | print(self.layer1) 162 | print(self.layer0) 163 | 164 | def __iter__(self): 165 | layer1 = self.layer1 166 | layer0 = self.layer0 167 | idx1 = idx0 = 0 168 | layer0_idx1 = layer0[idx1] 169 | while True: 170 | if len(layer0_idx1) == idx0: 171 | if len(layer1) - 1 == idx1: 172 | return 173 | yield layer1[idx1] 174 | idx1 += 1 175 | layer0_idx1 = layer0[idx1] 176 | idx0 = 0 177 | else: 178 | yield layer0_idx1[idx0] 179 | idx0 += 1 180 | 181 | --------------------------------------------------------------------------------