├── .gitignore ├── CMakeLists.txt ├── LICENSE ├── README.md ├── compile_flags.txt ├── src ├── alphabetic_huffman_code.hpp ├── alphabetic_huffman_code.test.cpp ├── bit.hpp ├── bit_cast.hpp ├── bm.hpp ├── bm.test.cpp ├── cartesian_tree.hpp ├── cartesian_tree.test.cpp ├── char_poly.hpp ├── cnt_min.hpp ├── dirichlet_series.hpp ├── dirichlet_series.test.cpp ├── fft.hpp ├── fft.test.cpp ├── fraction.hpp ├── geometry │ ├── point.hpp │ └── point3d.hpp ├── graph │ └── make_st_dag.hpp ├── hash_map.hpp ├── jacobi.hpp ├── lattice_cnt.hpp ├── lattice_cnt.test.cpp ├── lct.hpp ├── level_ancestor.hpp ├── manacher.hpp ├── mcmf.hpp ├── modnum.hpp ├── modnum.test.cpp ├── nim_prod.hpp ├── optimize.hpp ├── order_statistic.hpp ├── perm_tree.hpp ├── perm_tree.test.cpp ├── quaternion_hurwitz.hpp ├── reverse_comparator.hpp ├── rmq.hpp ├── rmq.test.cpp ├── seg_tree.hpp ├── seg_tree.test.cpp ├── smawk.hpp ├── smawk.test.cpp ├── static_tree.hpp ├── suffix_array.hpp ├── tensor.hpp ├── tensor.test.cpp ├── top_tree.hpp └── yc.hpp └── third_party └── sais-lite-2.4.1 ├── COPYING ├── Makefile ├── README ├── is_orig.c ├── sais.c ├── sais.h ├── sais.hxx ├── suftest.c └── test.c /.gitignore: -------------------------------------------------------------------------------- 1 | # Prerequisites 2 | *.d 3 | 4 | # Compiled Object files 5 | *.slo 6 | *.lo 7 | *.o 8 | *.obj 9 | 10 | # Precompiled Headers 11 | *.gch 12 | *.pch 13 | 14 | # Compiled Dynamic libraries 15 | *.so 16 | *.dylib 17 | *.dll 18 | 19 | # Fortran module files 20 | *.mod 21 | *.smod 22 | 23 | # Compiled Static libraries 24 | *.lai 25 | *.la 26 | *.a 27 | *.lib 28 | 29 | # Executables 30 | *.exe 31 | *.out 32 | *.app 33 | 34 | /build/ 35 | -------------------------------------------------------------------------------- /CMakeLists.txt: -------------------------------------------------------------------------------- 1 | cmake_minimum_required(VERSION 3.27) 2 | project(cp-book) 3 | 4 | set(CMAKE_EXPORT_COMPILE_COMMANDS ON) 5 | 6 | 
Include(FetchContent) 7 | 8 | FetchContent_Declare( 9 | Catch2 10 | GIT_REPOSITORY https://github.com/catchorg/Catch2.git 11 | GIT_TAG v3.4.0 # or a later release 12 | ) 13 | 14 | FetchContent_MakeAvailable(Catch2) 15 | 16 | file(GLOB TEST_SRC_FILES "src/*.test.cpp") 17 | file(GLOB SRC_FILES "src/*.hpp") 18 | 19 | add_library(cp-book INTERFACE) 20 | target_include_directories(cp-book INTERFACE src) 21 | target_compile_features(cp-book INTERFACE cxx_std_20) 22 | # TODO: Clean these up somehow 23 | target_compile_options(cp-book INTERFACE -O2 -Wall -Wextra -pedantic -Wshadow -Wformat=2 -Wfloat-equal -Wconversion -Wlogical-op -Wshift-overflow=2 -Wduplicated-cond -Wcast-qual -Wcast-align -Wno-unused-result -Wno-sign-conversion -g -D_GLIBCXX_DEBUG -D_GLIBCXX_DEBUG_PEDANTIC -fsanitize=address -fsanitize=undefined -fno-sanitize-recover=all -fstack-protector -D_FORTIFY_SOURCE=2) 24 | target_link_options(cp-book INTERFACE -O2 -fsanitize=address -fsanitize=undefined -fno-sanitize-recover=all -fstack-protector) 25 | 26 | add_executable(tests "${TEST_SRC_FILES}") 27 | target_link_libraries(tests PUBLIC cp-book) 28 | target_link_libraries(tests PRIVATE Catch2::Catch2WithMain) 29 | target_compile_features(tests PUBLIC cxx_std_17) 30 | target_include_directories(tests PRIVATE src) 31 | 32 | list(APPEND CMAKE_MODULE_PATH ${catch2_SOURCE_DIR}/extras) 33 | include(CTest) 34 | include(Catch) 35 | catch_discover_tests(tests) 36 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | CC0 1.0 Universal 2 | 3 | Statement of Purpose 4 | 5 | The laws of most jurisdictions throughout the world automatically confer 6 | exclusive Copyright and Related Rights (defined below) upon the creator and 7 | subsequent owner(s) (each and all, an "owner") of an original work of 8 | authorship and/or a database (each, a "Work"). 
9 | 10 | Certain owners wish to permanently relinquish those rights to a Work for the 11 | purpose of contributing to a commons of creative, cultural and scientific 12 | works ("Commons") that the public can reliably and without fear of later 13 | claims of infringement build upon, modify, incorporate in other works, reuse 14 | and redistribute as freely as possible in any form whatsoever and for any 15 | purposes, including without limitation commercial purposes. These owners may 16 | contribute to the Commons to promote the ideal of a free culture and the 17 | further production of creative, cultural and scientific works, or to gain 18 | reputation or greater distribution for their Work in part through the use and 19 | efforts of others. 20 | 21 | For these and/or other purposes and motivations, and without any expectation 22 | of additional consideration or compensation, the person associating CC0 with a 23 | Work (the "Affirmer"), to the extent that he or she is an owner of Copyright 24 | and Related Rights in the Work, voluntarily elects to apply CC0 to the Work 25 | and publicly distribute the Work under its terms, with knowledge of his or her 26 | Copyright and Related Rights in the Work and the meaning and intended legal 27 | effect of CC0 on those rights. 28 | 29 | 1. Copyright and Related Rights. A Work made available under CC0 may be 30 | protected by copyright and related or neighboring rights ("Copyright and 31 | Related Rights"). Copyright and Related Rights include, but are not limited 32 | to, the following: 33 | 34 | i. the right to reproduce, adapt, distribute, perform, display, communicate, 35 | and translate a Work; 36 | 37 | ii. moral rights retained by the original author(s) and/or performer(s); 38 | 39 | iii. publicity and privacy rights pertaining to a person's image or likeness 40 | depicted in a Work; 41 | 42 | iv. 
rights protecting against unfair competition in regards to a Work, 43 | subject to the limitations in paragraph 4(a), below; 44 | 45 | v. rights protecting the extraction, dissemination, use and reuse of data in 46 | a Work; 47 | 48 | vi. database rights (such as those arising under Directive 96/9/EC of the 49 | European Parliament and of the Council of 11 March 1996 on the legal 50 | protection of databases, and under any national implementation thereof, 51 | including any amended or successor version of such directive); and 52 | 53 | vii. other similar, equivalent or corresponding rights throughout the world 54 | based on applicable law or treaty, and any national implementations thereof. 55 | 56 | 2. Waiver. To the greatest extent permitted by, but not in contravention of, 57 | applicable law, Affirmer hereby overtly, fully, permanently, irrevocably and 58 | unconditionally waives, abandons, and surrenders all of Affirmer's Copyright 59 | and Related Rights and associated claims and causes of action, whether now 60 | known or unknown (including existing as well as future claims and causes of 61 | action), in the Work (i) in all territories worldwide, (ii) for the maximum 62 | duration provided by applicable law or treaty (including future time 63 | extensions), (iii) in any current or future medium and for any number of 64 | copies, and (iv) for any purpose whatsoever, including without limitation 65 | commercial, advertising or promotional purposes (the "Waiver"). Affirmer makes 66 | the Waiver for the benefit of each member of the public at large and to the 67 | detriment of Affirmer's heirs and successors, fully intending that such Waiver 68 | shall not be subject to revocation, rescission, cancellation, termination, or 69 | any other legal or equitable action to disrupt the quiet enjoyment of the Work 70 | by the public as contemplated by Affirmer's express Statement of Purpose. 71 | 72 | 3. Public License Fallback. 
Should any part of the Waiver for any reason be 73 | judged legally invalid or ineffective under applicable law, then the Waiver 74 | shall be preserved to the maximum extent permitted taking into account 75 | Affirmer's express Statement of Purpose. In addition, to the extent the Waiver 76 | is so judged Affirmer hereby grants to each affected person a royalty-free, 77 | non transferable, non sublicensable, non exclusive, irrevocable and 78 | unconditional license to exercise Affirmer's Copyright and Related Rights in 79 | the Work (i) in all territories worldwide, (ii) for the maximum duration 80 | provided by applicable law or treaty (including future time extensions), (iii) 81 | in any current or future medium and for any number of copies, and (iv) for any 82 | purpose whatsoever, including without limitation commercial, advertising or 83 | promotional purposes (the "License"). The License shall be deemed effective as 84 | of the date CC0 was applied by Affirmer to the Work. Should any part of the 85 | License for any reason be judged legally invalid or ineffective under 86 | applicable law, such partial invalidity or ineffectiveness shall not 87 | invalidate the remainder of the License, and in such case Affirmer hereby 88 | affirms that he or she will not (i) exercise any of his or her remaining 89 | Copyright and Related Rights in the Work or (ii) assert any associated claims 90 | and causes of action with respect to the Work, in either case contrary to 91 | Affirmer's express Statement of Purpose. 92 | 93 | 4. Limitations and Disclaimers. 94 | 95 | a. No trademark or patent rights held by Affirmer are waived, abandoned, 96 | surrendered, licensed or otherwise affected by this document. 97 | 98 | b. 
Affirmer offers the Work as-is and makes no representations or warranties 99 | of any kind concerning the Work, express, implied, statutory or otherwise, 100 | including without limitation warranties of title, merchantability, fitness 101 | for a particular purpose, non infringement, or the absence of latent or 102 | other defects, accuracy, or the present or absence of errors, whether or not 103 | discoverable, all to the greatest extent permissible under applicable law. 104 | 105 | c. Affirmer disclaims responsibility for clearing rights of other persons 106 | that may apply to the Work or any use thereof, including without limitation 107 | any person's Copyright and Related Rights in the Work. Further, Affirmer 108 | disclaims responsibility for obtaining any necessary consents, permissions 109 | or other rights required for any use of the Work. 110 | 111 | d. Affirmer understands and acknowledges that Creative Commons is not a 112 | party to this document and has no duty or obligation with respect to this 113 | CC0 or use of the Work. 114 | 115 | For more information, please see 116 | 117 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # ecnerwala's CP Book 2 | 3 | This is my library of reference code for competitive programming. The goal is to 4 | write generic, fast, and clean algorithm implementations for use in contests 5 | like CodeForces or ICPC. 6 | 7 | ## Building 8 | 9 | Build using 10 | 11 | ```sh 12 | cmake -B build 13 | cmake --build build 14 | ``` 15 | 16 | Test with 17 | 18 | ```sh 19 | ctest --test-dir build 20 | ``` 21 | 22 | or directly with 23 | 24 | ```sh 25 | ./build/tests 26 | ``` 27 | 28 | ## License and Attribution 29 | 30 | All code in this book is written by me and CC0 licensed unless otherwise noted 31 | in the file. 
Inspiration is largely drawn from KACTL 32 | (https://github.com/kth-competitive-programming/kactl/) and other references. 33 | -------------------------------------------------------------------------------- /compile_flags.txt: -------------------------------------------------------------------------------- 1 | -x 2 | c++ 3 | -Wall 4 | -Wextra 5 | -pedantic 6 | -std=c++20 7 | -O2 8 | -Wshadow 9 | -Wformat=2 10 | -Wfloat-equal 11 | -Wconversion 12 | -Wshift-overflow 13 | -Wcast-qual 14 | -Wcast-align 15 | -D_GLIBCXX_DEBUG 16 | -D_GLIBCXX_DEBUG_PEDANTIC 17 | -fsanitize=address 18 | -fsanitize=undefined 19 | -fno-sanitize-recover=all 20 | -fstack-protector 21 | -D_FORTIFY_SOURCE=2 22 | -Wno-sign-conversion 23 | 24 | -Isrc 25 | -Itest 26 | -------------------------------------------------------------------------------- /src/alphabetic_huffman_code.hpp: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include 4 | #include 5 | #include 6 | 7 | // Finds an optimal alphabetic (binary) Huffman code, i.e. 
one that preserves the ordering of the original weights 8 | // Implements the Garsia-Wachs algorithm: https://en.wikipedia.org/wiki/Garsia%E2%80%93Wachs_algorithm 9 | // Returns the code specified as a sequence of depths for each input weight 10 | template std::vector alphabetic_huffman_code(std::vector weights) { 11 | int N = int(weights.size()); 12 | if (N == 0) return {}; 13 | std::vector> ch; ch.reserve(N-1); 14 | 15 | { 16 | struct splay_node { 17 | mutable splay_node* p = nullptr; 18 | std::array c{nullptr, nullptr}; 19 | int d() const { return this == p->c[1]; } 20 | 21 | T_sum value; 22 | T_sum max_value; 23 | int idx; 24 | 25 | void update() { 26 | max_value = value; 27 | for (auto ch : c) { 28 | if (ch && max_value < ch->max_value) max_value = ch->max_value; 29 | } 30 | } 31 | 32 | void rot() { 33 | assert(p); 34 | 35 | int x = d(); 36 | splay_node* pa = p; 37 | splay_node* ch = c[!x]; 38 | 39 | if (ch) ch->p = pa; 40 | pa->c[x] = ch; 41 | 42 | if (pa->p) pa->p->c[pa->d()] = this; 43 | this->p = pa->p; 44 | 45 | this->c[!x] = pa; 46 | pa->p = this; 47 | 48 | pa->update(); 49 | } 50 | 51 | void splay_no_update(splay_node* top) { 52 | while (p != top) { 53 | if (p->p != top) { 54 | if (p->d() == d()) p->rot(); 55 | else rot(); 56 | } 57 | rot(); 58 | } 59 | } 60 | }; 61 | std::vector nodes(N+1); 62 | for (int i = 0; i < N; i++) { 63 | nodes[i].p = &nodes[i+1]; 64 | nodes[i+1].c[0] = &nodes[i]; 65 | nodes[i].value = T_sum(weights[i]); 66 | nodes[i].idx = i; 67 | } 68 | nodes[0].update(); 69 | splay_node* cur = &nodes[1]; 70 | 71 | // We'll store our current state as the left spine of some splay tree. 72 | // All vertices from cur to the root are precisely the vertices that may satisfy w[n-2] <= w[n] 73 | // (all others provably satisfy w[x-2] > w[x] at all times), 74 | // so cur is exactly the leftmost vertex that might satisfy w[n-2] <= w[n]. 
75 | // 76 | // We then check this condition, and if it does have w[n-2] <= w[n], 77 | // we merge w[n-2] and w[n-1] and reinsert somewhere according to Garsia-Wachs, 78 | // i.e. right after the last element of w[0:n-1] greater than or equal to it. 79 | // Then, the newly inserted node is added to the candidate chain 80 | // (exercise: prove that all other positions still satisfy w[x-2] > w[x]). 81 | 82 | while (cur) { 83 | // Note: cur is not necessarily updated 84 | 85 | // First, grab the 2nd child of the left side of cur 86 | splay_node* a = cur->c[0]; 87 | assert(a); 88 | while (a->c[1]) a = a->c[1]; 89 | if (a->c[0]) { 90 | a = a->c[0]; 91 | while (a->c[1]) a = a->c[1]; 92 | } else { 93 | a = a->p; 94 | } 95 | if (a == cur) { 96 | // size one, so we're done 97 | cur->update(); 98 | cur = cur->p; 99 | continue; 100 | } 101 | a->splay_no_update(cur); 102 | assert(a == cur->c[0]); 103 | assert(a->c[1] && !a->c[1]->c[0] && !a->c[1]->c[1]); 104 | if (cur->p && cur->value < a->value) { 105 | // no merging, so we're done 106 | a->update(); 107 | cur->update(); 108 | cur = cur->p; 109 | continue; 110 | } 111 | 112 | // Otherwise, merge a and a->c[1] 113 | { 114 | int n_idx = N + int(ch.size()); 115 | ch.push_back({a->idx, a->c[1]->idx}); 116 | a->idx = n_idx; 117 | } 118 | a->value += a->c[1]->value; 119 | a->c[1]->p = nullptr; 120 | a->c[1] = nullptr; 121 | 122 | // Now, insert a right after the first guy b which is b.v >= a.v 123 | if (!a->c[0] || a->c[0]->max_value < a->value) { 124 | a->c[1] = a->c[0]; 125 | a->c[0] = nullptr; 126 | a->update(); 127 | // Don't recurse on a, since it has no left child 128 | continue; 129 | } 130 | 131 | splay_node* b = a->c[0]; 132 | while (true) { 133 | assert(b); 134 | assert(!(b->max_value < a->value)); 135 | if (!b->c[1] || b->c[1]->max_value < a->value) { 136 | if (b->value < a->value) { 137 | assert(b->c[0]); 138 | b = b->c[0]; 139 | } else { 140 | break; 141 | } 142 | } else { 143 | b = b->c[1]; 144 | } 145 | } 146 | 
b->splay_no_update(a); 147 | assert(b == a->c[0]); 148 | if (b->c[1]) b->c[1]->p = a; 149 | a->c[1] = b->c[1]; 150 | b->c[1] = nullptr; 151 | b->update(); 152 | cur = a; 153 | continue; 154 | } 155 | } 156 | 157 | // Reconstruct depths 158 | assert(int(ch.size()) == N-1); 159 | std::vector res(2*N-1, -1); 160 | res[2*N-2] = 0; 161 | for (int i = 2*N-2; i >= N; i--) { 162 | assert(res[i] != -1); 163 | res[ch[i-N][0]] = res[i] + 1; 164 | res[ch[i-N][1]] = res[i] + 1; 165 | } 166 | res.resize(N); 167 | return res; 168 | } 169 | 170 | // Returns the lca array of length N - 1, suitable for building a Cartesian tree 171 | inline std::vector binary_code_depths_to_lca_depths(std::vector depths) { 172 | int N = int(depths.size()); 173 | if (N == 0) return {}; 174 | std::vector res; res.reserve(N-1); 175 | std::vector stk; stk.reserve(N); 176 | for (int v : depths) { 177 | while (!stk.empty() && stk.back() == v) { 178 | stk.pop_back(); 179 | v--; 180 | } 181 | assert(stk.empty() || stk.back() < v); 182 | if (v != 0) res.push_back(v-1); 183 | stk.push_back(v); 184 | } 185 | assert(int(stk.size()) == 1 && stk.back() == 0); 186 | return res; 187 | } 188 | -------------------------------------------------------------------------------- /src/alphabetic_huffman_code.test.cpp: -------------------------------------------------------------------------------- 1 | #include "alphabetic_huffman_code.hpp" 2 | 3 | #include 4 | #include 5 | #include 6 | #include 7 | 8 | template T alphabetic_huffman_code_naive(std::vector weights) { 9 | int N = int(weights.size()); 10 | if (N == 0) return 0; 11 | assert(N > 0); 12 | std::vector dp(N * N); 13 | for (int i = 0; i < N; i++) { 14 | dp[i * N + i] = 0; 15 | T pref = weights[i]; 16 | for (int j = i-1; j >= 0; j--) { 17 | T v = dp[i * N + i] + dp[(i-1) * N + j]; 18 | for (int k = i-1; k >= j+1; k--) { 19 | v = std::min(v, dp[i * N + k] + dp[(k-1) * N + j]); 20 | } 21 | pref += weights[j]; 22 | v += pref; 23 | dp[i * N + j] = v; 24 | } 25 | } 26 | 
return dp[(N-1) * N + 0]; 27 | } 28 | 29 | TEST_CASE("Alphabetic Huffman Code", "[alphabetic_huffman_code]") { 30 | std::mt19937 mt(Catch::getSeed()); 31 | for (int z = 0; z <= 1000; z++) { 32 | int N = std::uniform_int_distribution(1, 60)(mt); 33 | int MX = 1 << std::uniform_int_distribution(0, 15)(mt); 34 | std::vector weights(N); 35 | for (auto& w : weights) w = std::uniform_int_distribution(0, MX-1)(mt); 36 | auto naive_tot = alphabetic_huffman_code_naive(weights); 37 | 38 | auto code_depths = alphabetic_huffman_code(weights); 39 | auto code_lcp = binary_code_depths_to_lca_depths(code_depths); 40 | int tot = 0; 41 | for (int i = 0; i < N; i++) { 42 | tot += weights[i] * code_depths[i]; 43 | } 44 | REQUIRE(naive_tot == tot); 45 | } 46 | } 47 | -------------------------------------------------------------------------------- /src/bit.hpp: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include 4 | #include 5 | 6 | /** Binary-indexed tree 7 | * 8 | * A binary indexed tree with N nodes of type T provides the 9 | * following two functions for 0 <= i <= N: 10 | * 11 | * prefix(int i) -> prefix_range 12 | * suffix(int i) -> suffix_range 13 | * 14 | * such that size(suffix(i) intersect prefix(j)) = (1 if i < j else 0). 15 | * Furthermore, the resulting lists always have size at most log_2(N). 16 | * 17 | * This can be used to implement either point-update/(prefix|suffix)-query or 18 | * (prefix|suffix)-update/point-query over a virtual array of size N of a 19 | * commutative monoid. This can be generalized to implement 20 | * point-update/range-query or range-update/point-query over a virtual array 21 | * of size N of a commutative group. 22 | * 23 | * With 0-indexed data, prefixes are more natural: 24 | * * For range update/query, use prefix() for the ranges and suffix() for the points. 25 | * * For prefix update/query, no change. 
26 | * * For suffix update/query, use prefix(point + 1); 1-index the data. 27 | */ 28 | template class binary_indexed_tree { 29 | private: 30 | std::vector dat; 31 | public: 32 | binary_indexed_tree() {} 33 | explicit binary_indexed_tree(size_t N) : dat(N) {} 34 | binary_indexed_tree(size_t N, const T& t) : dat(N, t) {} 35 | 36 | size_t size() const { return dat.size(); } 37 | const std::vector& data() const { return dat; } 38 | std::vector& data() { return dat; } 39 | 40 | private: 41 | template struct iterator_range { 42 | private: 43 | I begin_; 44 | S end_; 45 | public: 46 | iterator_range() : begin_(), end_() {} 47 | iterator_range(const I& begin__, const S& end__) : begin_(begin__), end_(end__) {} 48 | iterator_range(I&& begin__, S&& end__) : begin_(begin__), end_(end__) {} 49 | I begin() const { return begin_; } 50 | S end() const { return end_; } 51 | }; 52 | 53 | public: 54 | class const_suffix_iterator { 55 | private: 56 | const T* dat; 57 | int a; 58 | const_suffix_iterator(const T* dat_, int a_) : dat(dat_), a(a_) {} 59 | friend class binary_indexed_tree; 60 | public: 61 | friend bool operator != (const const_suffix_iterator& i, const const_suffix_iterator& j) { 62 | assert(j.dat == nullptr); 63 | return i.a < j.a; 64 | } 65 | const_suffix_iterator& operator ++ () { 66 | a |= a+1; 67 | return *this; 68 | } 69 | const T& operator * () const { 70 | return dat[a]; 71 | } 72 | }; 73 | using const_suffix_range = iterator_range; 74 | const_suffix_range suffix(int a) const { 75 | assert(0 <= a && a <= int(dat.size())); 76 | return const_suffix_range{const_suffix_iterator{dat.data(), a}, const_suffix_iterator{nullptr, int(dat.size())}}; 77 | } 78 | 79 | class suffix_iterator { 80 | private: 81 | T* dat; 82 | int a; 83 | suffix_iterator(T* dat_, int a_) : dat(dat_), a(a_) {} 84 | friend class binary_indexed_tree; 85 | public: 86 | friend bool operator != (const suffix_iterator& i, const suffix_iterator& j) { 87 | assert(j.dat == nullptr); 88 | return i.a < 
j.a; 89 | } 90 | suffix_iterator& operator ++ () { 91 | a |= a+1; 92 | return *this; 93 | } 94 | T& operator * () const { 95 | return dat[a]; 96 | } 97 | }; 98 | using suffix_range = iterator_range; 99 | suffix_range suffix(int a) { 100 | assert(0 <= a && a <= int(dat.size())); 101 | return suffix_range{suffix_iterator{dat.data(), a}, suffix_iterator{nullptr, int(dat.size())}}; 102 | } 103 | 104 | class const_prefix_iterator { 105 | private: 106 | const T* dat; 107 | int a; 108 | const_prefix_iterator(const T* dat_, int a_) : dat(dat_), a(a_) {} 109 | friend class binary_indexed_tree; 110 | public: 111 | friend bool operator != (const const_prefix_iterator& i, const const_prefix_iterator& j) { 112 | assert(j.dat == nullptr); 113 | return i.a > 0; 114 | } 115 | const_prefix_iterator& operator ++ () { 116 | a &= a-1; 117 | return *this; 118 | } 119 | const T& operator * () const { 120 | return dat[a-1]; 121 | } 122 | }; 123 | using const_prefix_range = iterator_range; 124 | const_prefix_range prefix(int a) const { 125 | return const_prefix_range{const_prefix_iterator{dat.data(), a}, const_prefix_iterator{nullptr, 0}}; 126 | } 127 | 128 | class prefix_iterator { 129 | private: 130 | T* dat; 131 | int a; 132 | prefix_iterator(T* dat_, int a_) : dat(dat_), a(a_) {} 133 | friend class binary_indexed_tree; 134 | public: 135 | friend bool operator != (const prefix_iterator& i, const prefix_iterator& j) { 136 | assert(j.dat == nullptr); 137 | return i.a > 0; 138 | } 139 | prefix_iterator& operator ++ () { 140 | a &= a-1; 141 | return *this; 142 | } 143 | T& operator * () const { 144 | return dat[a-1]; 145 | } 146 | }; 147 | using prefix_range = iterator_range; 148 | prefix_range prefix(int a) { 149 | return prefix_range{prefix_iterator{dat.data(), a}, prefix_iterator{nullptr, 0}}; 150 | } 151 | }; 152 | -------------------------------------------------------------------------------- /src/bit_cast.hpp: 
-------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include 4 | #include 5 | 6 | // Copied from https://en.cppreference.com/w/cpp/numeric/bit_cast 7 | 8 | template 9 | typename std::enable_if_t< 10 | sizeof(To) == sizeof(From) && 11 | std::is_trivially_copyable_v && 12 | std::is_trivially_copyable_v, 13 | To> 14 | // constexpr support needs compiler magic 15 | bit_cast(const From& src) noexcept 16 | { 17 | static_assert(std::is_trivially_constructible_v, 18 | "This implementation additionally requires destination type to be trivially constructible"); 19 | 20 | To dst; 21 | std::memcpy(&dst, &src, sizeof(To)); 22 | return dst; 23 | } 24 | -------------------------------------------------------------------------------- /src/bm.hpp: -------------------------------------------------------------------------------- 1 | #pragma once 2 | #include 3 | 4 | template 5 | std::vector BerlekampMassey(const std::vector& s) { 6 | int n = int(s.size()), L = 0, m = 0; 7 | std::vector C(n), B(n), T; 8 | C[0] = B[0] = 1; 9 | 10 | num b = 1; 11 | for(int i = 0; i < n; i++) { ++m; 12 | num d = s[i]; 13 | for (int j = 1; j <= L; j++) d += C[j] * s[i - j]; 14 | if (d == 0) continue; 15 | T = C; num coef = d / b; 16 | for (int j = m; j < n; j++) C[j] -= coef * B[j - m]; 17 | if (2 * L > i) continue; 18 | L = i + 1 - L; B = T; b = d; m = 0; 19 | } 20 | 21 | C.resize(L + 1); C.erase(C.begin()); 22 | for (auto& x : C) { 23 | x = -x; 24 | } 25 | return C; 26 | } 27 | 28 | template 29 | num linearRec(const std::vector& S, const std::vector& tr, int64_t k) { 30 | int n = int(tr.size()); 31 | assert(S.size() >= tr.size()); 32 | 33 | auto combine = [&](std::vector a, std::vector b, bool e = false) { 34 | // multiply a * b * x^e 35 | std::vector res(int(a.size()) + int(b.size())); 36 | for (int i = 0; i < int(a.size()); i++) { 37 | for (int j = 0; j < int(b.size()); j++) { 38 | res[i + j + e] += a[i] * b[j]; 39 | } 40 | } 41 | for (int i = 
int(res.size())-1; i >= n; --i) { 42 | for (int j = 0; j < n; j++) { 43 | res[i - 1 - j] += res[i] * tr[j]; 44 | } 45 | } 46 | res.resize(n); 47 | return res; 48 | }; 49 | 50 | std::vector pol(n); 51 | if (n > 0) pol[0] = num(1); 52 | 53 | assert(k >= 0); 54 | for (int i = 64 - 1 - (k == 0 ? 64 : __builtin_clzll(k)); i >= 0; i--) { 55 | pol = combine(pol, pol, (k >> i) & 1); 56 | } 57 | 58 | num res = 0; 59 | for (int i = 0; i < n; i++) res += pol[i] * S[i]; 60 | return res; 61 | } 62 | -------------------------------------------------------------------------------- /src/bm.test.cpp: -------------------------------------------------------------------------------- 1 | #include 2 | 3 | #include "bm.hpp" 4 | #include "modnum.hpp" 5 | 6 | using namespace std; 7 | 8 | TEST_CASE("Berlekamp Massey", "[bm]") { 9 | using num = modnum; 10 | vector S({0, 1, 1, 2, 3, 5, 8, 13}); 11 | vector tr = BerlekampMassey(S); 12 | REQUIRE(tr == vector({num(1), num(1)})); 13 | num res = linearRec(S, tr, 1000); 14 | REQUIRE(res == num(517691607)); 15 | } 16 | -------------------------------------------------------------------------------- /src/cartesian_tree.hpp: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include 4 | #include 5 | 6 | #include "reverse_comparator.hpp" 7 | 8 | class CartesianTree { 9 | public: 10 | struct Node { 11 | int l, m, r; // inclusive ranges 12 | std::array c; 13 | }; 14 | std::vector nodes; 15 | int root = -1; 16 | 17 | CartesianTree() {} 18 | 19 | Node& operator [] (int idx) { return nodes[idx]; } 20 | const Node& operator [] (int idx) const { return nodes[idx]; } 21 | 22 | int size() const { return int(nodes.size()); } 23 | 24 | private: 25 | CartesianTree(std::vector&& nodes_, int root_) : nodes(std::move(nodes_)), root(root_) {} 26 | 27 | public: 28 | 29 | // min-cartesian-tree, with earlier cells tiebroken earlier 30 | template > 31 | static CartesianTree build_min_tree(const std::vector& v, Comp 
comp = Comp()) { 32 | std::vector nodes(v.size()*2+1); 33 | std::vector stk; stk.reserve(v.size()); 34 | int root = -1; 35 | for (int i = 0; i <= int(v.size()); i++) { 36 | int cur = 2*i; 37 | nodes[cur].l = i; 38 | nodes[cur].r = i-1; 39 | nodes[cur].m = i-1; 40 | nodes[cur].c = {-1, -1}; 41 | while (!stk.empty() && (i == int(v.size()) || comp(v[i], v[nodes[stk.back()].m]))) { 42 | int nxt = stk.back(); stk.pop_back(); 43 | nodes[nxt].c[1] = cur; 44 | nodes[nxt].r = nodes[cur].r; 45 | cur = nxt; 46 | } 47 | if (i == int(v.size())) { 48 | root = cur; 49 | break; 50 | } 51 | nodes[2*i+1].l = nodes[cur].l; 52 | nodes[2*i+1].m = i; 53 | nodes[2*i+1].c[0] = cur; 54 | stk.push_back(2*i+1); 55 | } 56 | return {std::move(nodes), root}; 57 | } 58 | 59 | // max-cartesian-tree, with earlier cells tiebroken earlier 60 | template > 61 | static CartesianTree build_max_tree(const std::vector& v, Comp comp = Comp()) { 62 | return build_min_tree(v, reverse_comparator(comp)); 63 | } 64 | }; 65 | -------------------------------------------------------------------------------- /src/cartesian_tree.test.cpp: -------------------------------------------------------------------------------- 1 | #include "cartesian_tree.hpp" 2 | 3 | #include 4 | #include 5 | #include 6 | 7 | TEST_CASE("Cartesian Tree", "[cartesian_tree]") { 8 | std::mt19937 mt(Catch::getSeed()); 9 | for (int sz : {0, 1, 2, 3, 5, 8, 13}) { 10 | std::vector v(sz); 11 | iota(v.begin(), v.end(), 0); 12 | shuffle(v.begin(), v.end(), mt); 13 | { 14 | CartesianTree t = CartesianTree::build_min_tree(v); 15 | for (int i = 1; i < int(t.size()); i += 2) { 16 | REQUIRE(t[i].m == i/2); 17 | REQUIRE(t[i].l <= t[i].m); 18 | REQUIRE(t[i].m <= t[i].r); 19 | 20 | REQUIRE(t[t[i].c[0]].l == t[i].l); 21 | REQUIRE(t[t[i].c[0]].r == t[i].m-1); 22 | 23 | REQUIRE(t[t[i].c[1]].l == t[i].m+1); 24 | REQUIRE(t[t[i].c[1]].r == t[i].r); 25 | 26 | REQUIRE((t[t[i].c[0]].l > t[t[i].c[0]].r || v[t[i].m] < v[t[t[i].c[0]].m])); 27 | REQUIRE((t[t[i].c[1]].l > 
t[t[i].c[1]].r || v[t[i].m] < v[t[t[i].c[1]].m])); 28 | } 29 | } 30 | { 31 | CartesianTree t = CartesianTree::build_max_tree(v); 32 | for (int i = 1; i < int(t.size()); i += 2) { 33 | REQUIRE(t[i].m == i/2); 34 | REQUIRE(t[i].l <= t[i].m); 35 | REQUIRE(t[i].m <= t[i].r); 36 | 37 | REQUIRE(t[t[i].c[0]].l == t[i].l); 38 | REQUIRE(t[t[i].c[0]].r == t[i].m-1); 39 | 40 | REQUIRE(t[t[i].c[1]].l == t[i].m+1); 41 | REQUIRE(t[t[i].c[1]].r == t[i].r); 42 | 43 | REQUIRE((t[t[i].c[0]].l > t[t[i].c[0]].r || v[t[i].m] > v[t[t[i].c[0]].m])); 44 | REQUIRE((t[t[i].c[1]].l > t[t[i].c[1]].r || v[t[i].m] > v[t[t[i].c[1]].m])); 45 | } 46 | } 47 | } 48 | } 49 | -------------------------------------------------------------------------------- /src/char_poly.hpp: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include 4 | #include 5 | #include 6 | 7 | // Compute the characteristic polynomial of a square matrix A over some field. 8 | // Not numerically stable at all. 9 | // Takes argument by value, use std::move if you can. 
10 | template std::vector charPoly(std::vector> A) { 11 | int N = int(A.size()); 12 | std::vector res; res.reserve(N+1); 13 | res.push_back(num(1)); 14 | for (int i = 0, deg = 0; i < N; i++) { 15 | auto& Ai = A[i]; 16 | 17 | int c = i+1; 18 | while (c < N && Ai[c] == num(0)) c++; 19 | if (c == N) { 20 | res.resize(i+2, num(0)); 21 | for (int x = deg; x >= 0; x--) { 22 | num v = res[x]; 23 | for (int y = x+1, z = i; z >= deg; z--, y++) { 24 | res[y] -= v * Ai[z]; 25 | } 26 | } 27 | deg = i+1; 28 | continue; 29 | } 30 | 31 | num vc = Ai[c]; 32 | num ivc = inv(vc); 33 | 34 | Ai[c] = Ai[i+1]; 35 | Ai[i+1] = 0; 36 | 37 | std::swap(A[i+1], A[c]); 38 | auto& Ai1 = A[i+1]; 39 | for (int k = deg; k < N; k++) { 40 | Ai1[k] *= vc; 41 | } 42 | 43 | for (int k = i+1; k < N; k++) { 44 | auto& Ak = A[k]; 45 | { 46 | auto& x = Ak[i+1]; 47 | auto& y = Ak[c]; 48 | num tmp = y; 49 | y = x; 50 | x = tmp * ivc; 51 | } 52 | { 53 | num v = Ak[i+1]; 54 | for (int j = deg; j < N; j++) { 55 | Ak[j] -= v * Ai[j]; 56 | } 57 | } 58 | if (k > i+1) { 59 | num v = Ai[k]; 60 | for (int j = deg; j < N; j++) { 61 | Ai1[j] += v * Ak[j]; 62 | } 63 | } 64 | } 65 | 66 | for (int k = deg; k <= i; k++) { 67 | Ai1[k+1] += Ai[k]; 68 | } 69 | } 70 | reverse(res.begin(), res.end()); 71 | return res; 72 | } 73 | 74 | // Compute the characteristic polynomial of a square matrix A over F2. 75 | // Takes argument by value, use std::move if you can. 
76 | // Note that MAXS must be at least N+1 77 | template std::bitset charPoly(std::vector> A) { 78 | using bs = std::bitset; 79 | int N = int(A.size()); 80 | assert(MAXS >= N+1); 81 | bs ans; ans[0] = 1; 82 | int deg = 0; 83 | for (int i = 0; i < N; i++) { 84 | { 85 | int j = int(A[i]._Find_next(i)); 86 | if (j >= N) { 87 | bs nans; 88 | for (; deg <= i; ans <<= 1, deg++) { 89 | if (A[i][deg]) nans ^= ans; 90 | } 91 | ans ^= nans; 92 | continue; 93 | } 94 | if (j != i+1) { 95 | swap(A[j], A[i+1]); 96 | for (auto& a : A) { 97 | bool tmp = a[j]; 98 | a[j] = a[i+1]; 99 | a[i+1] = tmp; 100 | } 101 | } 102 | } 103 | assert(A[i][i+1]); 104 | bs msk = A[i]; msk.flip(i+1); 105 | for (int k = 0; k < N; k++) { 106 | if (msk[k]) A[i+1] ^= A[k]; 107 | } 108 | for (auto& a : A) { 109 | if (a[i+1]) a ^= msk; 110 | } 111 | } 112 | return ans; 113 | } 114 | -------------------------------------------------------------------------------- /src/cnt_min.hpp: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include "reverse_comparator.hpp" 4 | 5 | template > struct cnt_min { 6 | T v; 7 | C cnt; 8 | 9 | cnt_min() : v(), cnt(0) {} 10 | explicit cnt_min(T v_) : v(v_), cnt(1) {} 11 | cnt_min(T v_, C cnt_) : v(v_), cnt(cnt_) {} 12 | 13 | friend cnt_min operator + (const cnt_min& a, const cnt_min& b) { 14 | if (!b.cnt) return a; 15 | else if (!a.cnt) return b; 16 | else if (Comp().operator()(a.v, b.v)) return a; 17 | else if (Comp().operator()(b.v, a.v)) return b; 18 | else return cnt_min(a.v, a.cnt + b.cnt); 19 | } 20 | 21 | cnt_min& operator += (const cnt_min& o) { 22 | return *this = (*this + o); 23 | } 24 | }; 25 | 26 | template > using cnt_max = cnt_min>; 27 | -------------------------------------------------------------------------------- /src/dirichlet_series.test.cpp: -------------------------------------------------------------------------------- 1 | #include "dirichlet_series.hpp" 2 | 3 | #include 4 | #include 5 | #include 
6 | 7 | #include "modnum.hpp" 8 | 9 | namespace dirichlet_series { 10 | 11 | namespace test { 12 | 13 | div_vector_layout layout; 14 | template using dv_values = dirichlet_series_values; 15 | template using dv_prefix = dirichlet_series_prefix; 16 | template using dv_bit = dirichlet_series_binary_indexed_tree; 17 | 18 | template 19 | dv_values multiply_slow(const dv_values& a, const dv_values& b) { 20 | dv_values r; 21 | for (int i = 1; i < layout.len; i++) { 22 | for (int j = 1; j < layout.len; j++) { 23 | int k = layout.get_value_bucket(layout.get_bucket_bound(i) * layout.get_bucket_bound(j)); 24 | if (k < layout.len) r.st[k] += a.st[i] * b.st[j]; 25 | } 26 | } 27 | return r; 28 | } 29 | 30 | TEMPLATE_TEST_CASE("Dirichlet series multiplication and inverse", "[dirichlet]", modnum, int64_t) { 31 | using num = TestType; 32 | for (int N = 1; N <= 30; N++) { 33 | INFO("N = " << N); 34 | std::mt19937 mt(Catch::getSeed()); 35 | layout = div_vector_layout(N); 36 | dv_prefix a([&](int64_t x) { return num(x); }); 37 | dv_prefix b([&](int64_t x) { return num(x) * num(x+1) / num(2); }); 38 | dv_prefix slow_res(multiply_slow(dv_values(a), dv_values(b))); 39 | dv_prefix fast_res = a * b; 40 | for (int i = 1; i < layout.len; i++) { 41 | INFO("i = " << i); 42 | REQUIRE(slow_res.st[i] == fast_res.st[i]); 43 | } 44 | dv_prefix a_2 = fast_res / b; 45 | for (int i = 1; i < layout.len; i++) { 46 | INFO("i = " << i); 47 | REQUIRE(a.st[i] == a_2.st[i]); 48 | } 49 | dv_prefix b_2 = fast_res / a; 50 | for (int i = 1; i < layout.len; i++) { 51 | INFO("i = " << i); 52 | REQUIRE(b.st[i] == b_2.st[i]); 53 | } 54 | if constexpr (!std::is_same_v) { 55 | dv_prefix rt_a = sqrt(a); 56 | dv_prefix a_3 = rt_a * rt_a; 57 | for (int i = 1; i < layout.len; i++) { 58 | INFO("i = " << i); 59 | REQUIRE(a.st[i] == a_3.st[i]); 60 | } 61 | dv_prefix rt_b = sqrt(b); 62 | dv_prefix b_3 = rt_b * rt_b; 63 | for (int i = 1; i < layout.len; i++) { 64 | INFO("i = " << i); 65 | REQUIRE(b.st[i] == b_3.st[i]); 66 | } 
67 | } 68 | } 69 | } 70 | 71 | TEMPLATE_TEST_CASE("Dirichlet series BIT sparse multiplication", "[dirichlet]", modnum) { 72 | using num = TestType; 73 | for (int N : {1, 2, 3, 4, 5, 24, 25, 26, 99, 100, 101, 1000}) { 74 | INFO("N = " << N); 75 | layout = div_vector_layout(N); 76 | for (int m = 2; m <= N+1; m++) { 77 | INFO("m = " << m); 78 | num w = -5; 79 | dv_prefix a([&](int64_t x) -> num { return num(x * (x+1) / 2); }); 80 | dv_prefix b([&](int64_t x) -> num { return 1 + (x >= m ? w : 0); }); 81 | { 82 | dv_prefix c_slow = a * b; 83 | dv_bit bit(a); 84 | bit.sparse_mul_at_most_one(m, w); 85 | dv_prefix c(bit); 86 | for (int i = 1; i < layout.len; i++) { 87 | INFO("i = " << i); 88 | REQUIRE(c_slow.st[i] == c.st[i]); 89 | } 90 | } 91 | { 92 | dv_prefix d_slow = a / b; 93 | dv_bit bit(a); 94 | bit.sparse_div_at_most_one(m, w); 95 | dv_prefix d(bit); 96 | for (int i = 1; i < layout.len; i++) { 97 | INFO("i = " << i); 98 | REQUIRE(d_slow.st[i] == d.st[i]); 99 | } 100 | } 101 | } 102 | } 103 | } 104 | 105 | TEMPLATE_TEST_CASE("Dirichlet series euler transform", "[dirichlet]", modnum) { 106 | using num = TestType; 107 | for (int N : {1, 2, 3, 4, 5, 24, 25, 26, 99, 100, 101, 1000}) { 108 | INFO("N = " << N); 109 | layout = div_vector_layout(N); 110 | dv_prefix a([&](int64_t x) { return num(x) * num(x+1) / num(2); }); 111 | dv_prefix primes = inverse_euler_transform_fraction(a); 112 | dv_prefix primes2 = inverse_euler_transform_binary_indexed_tree(a); 113 | dv_values primes_slow_v; 114 | for (int v = 2; v <= N; v++) { 115 | bool is_prime = true; 116 | for (int p = 2; p * p <= v; p++) { 117 | if (v % p == 0) { 118 | is_prime = false; 119 | break; 120 | } 121 | } 122 | primes_slow_v[v] += is_prime ? 
num(v) : 0; 123 | } 124 | 125 | dv_prefix primes_slow(primes_slow_v); 126 | for (int i = 1; i < layout.len; i++) { 127 | INFO("i = " << i); 128 | INFO("bound = " << layout.get_bucket_bound(i)); 129 | REQUIRE(primes_slow.st[i] == primes.st[i]); 130 | REQUIRE(primes_slow.st[i] == primes2.st[i]); 131 | } 132 | 133 | dv_prefix b = euler_transform_fraction(primes_slow); 134 | dv_prefix b2 = euler_transform_binary_indexed_tree(primes_slow); 135 | for (int i = 1; i < layout.len; i++) { 136 | INFO("i = " << i); 137 | INFO("bound = " << layout.get_bucket_bound(i)); 138 | REQUIRE(a.st[i] == b.st[i]); 139 | REQUIRE(a.st[i] == b2.st[i]); 140 | } 141 | } 142 | } 143 | 144 | } 145 | 146 | } 147 | -------------------------------------------------------------------------------- /src/fft.test.cpp: -------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | 4 | #include "fft.hpp" 5 | 6 | namespace ecnerwala { 7 | namespace fft { 8 | 9 | using namespace std; 10 | 11 | template vector multiply_slow(const vector& a, const vector& b) { 12 | if (a.empty() || b.empty()) return {}; 13 | vector res(a.size() + b.size() - 1); 14 | for (int i = 0; i < int(a.size()); i++) { 15 | for (int j = 0; j < int(b.size()); j++) { 16 | res[i+j] += a[i] * b[j]; 17 | } 18 | } 19 | return res; 20 | } 21 | 22 | TEST_CASE("FFT Multiply Mod", "[fft]") { 23 | using num = modnum; 24 | mt19937 mt(48); 25 | vector a(100); 26 | vector b(168); 27 | for (num& x : a) { x = num(mt()); } 28 | for (num& x : b) { x = num(mt()); } 29 | REQUIRE(multiply>(a,b) == multiply_slow(a, b)); 30 | } 31 | 32 | TEST_CASE("FFT Inverse", "[fft]") { 33 | using num = modnum<998244353>; 34 | mt19937 mt(48); 35 | vector a(298); 36 | for (num& x : a) { x = num(mt()); } 37 | auto i = inverse, num>>(a); 38 | auto r = multiply>(a, i); 39 | REQUIRE(r == multiply_slow(a, i)); 40 | 41 | r.resize(a.size()); 42 | vector tgt(a.size()); 43 | tgt[0] = 1; 44 | REQUIRE(r == tgt); 45 | } 46 | 47 | }} // 
namespace ecnerwala::fft 48 | -------------------------------------------------------------------------------- /src/fraction.hpp: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include 4 | #include 5 | 6 | template struct fraction_t { 7 | T numer = 0, denom = 1; 8 | 9 | fraction_t() : numer(0), denom(1) {} 10 | fraction_t(T v) : numer(v), denom(1) {} 11 | fraction_t(T n, T d) : numer(n), denom(d) { 12 | if (denom < 0 || (denom == 0 && numer < 0)) { 13 | numer = -numer; 14 | denom = -denom; 15 | } 16 | } 17 | template explicit fraction_t(const fraction_t o) : numer(T(o.numer)), denom(T(o.denom)) {} 18 | 19 | friend std::ostream& operator << (std::ostream& o, const fraction_t& f) { 20 | return o << f.numer << '/' << f.denom; 21 | } 22 | friend std::istream& operator >> (std::istream& i, const fraction_t& f) { 23 | return i >> f.numer >> f.denom; 24 | } 25 | 26 | friend MulT cross(const fraction_t& a, const fraction_t& b) { 27 | return MulT(a.numer) * MulT(b.denom) - MulT(b.numer) * MulT(a.denom); 28 | } 29 | 30 | friend bool operator == (const fraction_t& a, const fraction_t& b) { 31 | return cross(a, b) == 0; 32 | } 33 | friend std::strong_ordering operator <=> (const fraction_t& a, const fraction_t& b) { 34 | return cross(a, b) <=> 0; 35 | } 36 | 37 | fraction_t operator + () const { return fraction_t(+numer, denom); } 38 | fraction_t operator - () const { return fraction_t(-numer, denom); } 39 | 40 | fraction_t& operator *= (const fraction_t& o) { 41 | numer *= o.numer; 42 | denom *= o.denom; 43 | return *this; 44 | } 45 | fraction_t& operator /= (const fraction_t& o) { 46 | numer *= o.denom; 47 | denom *= o.numer; 48 | return *this; 49 | } 50 | friend fraction_t operator * (const fraction_t& a, const fraction_t& b) { 51 | return fraction_t(a.numer * b.numer, a.denom * b.denom); 52 | } 53 | friend fraction_t operator / (const fraction_t& a, const fraction_t& b) { 54 | return fraction_t(a.numer * b.denom, 
a.denom * b.numer); 55 | } 56 | 57 | friend fraction_t operator + (const fraction_t& a, const fraction_t& b) { 58 | return {a.numer * b.denom + b.numer * a.denom, a.denom * b.denom}; 59 | } 60 | friend fraction_t operator - (const fraction_t& a, const fraction_t& b) { 61 | return {a.numer * b.denom - b.numer * a.denom, a.denom * b.denom}; 62 | } 63 | fraction_t& operator += (const fraction_t& o) { return *this = *this + o; } 64 | fraction_t& operator -= (const fraction_t& o) { return *this = *this - o; } 65 | 66 | fraction_t& reduce() { 67 | using std::gcd; 68 | T g = gcd(numer, denom); 69 | numer /= g; 70 | denom /= g; 71 | return *this; 72 | } 73 | }; 74 | -------------------------------------------------------------------------------- /src/geometry/point.hpp: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include 4 | #include 5 | #include 6 | #include 7 | 8 | template struct Point { 9 | public: 10 | T x, y; 11 | Point() : x(0), y(0) {} 12 | Point(T x_, T y_) : x(x_), y(y_) {} 13 | template explicit Point(const Point& p) : x(p.x), y(p.y) {} 14 | Point(const std::pair& p) : x(p.first), y(p.second) {} 15 | Point(const std::complex& p) : x(real(p)), y(imag(p)) {} 16 | explicit operator std::pair () const { return std::pair(x, y); } 17 | explicit operator std::complex () const { return std::complex(x, y); } 18 | void as_pair() const { return std::pair(*this); } 19 | void as_complex() const { return std::complex(*this); } 20 | 21 | friend std::ostream& operator << (std::ostream& o, const Point& p) { return o << '(' << p.x << ',' << p.y << ')'; } 22 | friend std::istream& operator >> (std::istream& i, Point& p) { return i >> p.x >> p.y; } 23 | friend bool operator == (const Point& a, const Point& b) { return a.x == b.x && a.y == b.y; } 24 | friend bool operator != (const Point& a, const Point& b) { return !(a==b); } 25 | 26 | Point operator + () const { return Point(+x, +y); } 27 | Point operator - () const { 
return Point(-x, -y); } 28 | 29 | Point& operator += (const Point& p) { x += p.x, y += p.y; return *this; } 30 | Point& operator -= (const Point& p) { x -= p.x, y -= p.y; return *this; } 31 | Point& operator *= (const T& t) { x *= t, y *= t; return *this; } 32 | Point& operator /= (const T& t) { x /= t, y /= t; return *this; } 33 | 34 | friend Point operator + (const Point& a, const Point& b) { return Point(a.x+b.x, a.y+b.y); } 35 | friend Point operator - (const Point& a, const Point& b) { return Point(a.x-b.x, a.y-b.y); } 36 | friend Point operator * (const Point& a, const T& t) { return Point(a.x*t, a.y*t); } 37 | friend Point operator * (const T& t ,const Point& a) { return Point(t*a.x, t*a.y); } 38 | friend Point operator / (const Point& a, const T& t) { return Point(a.x/t, a.y/t); } 39 | 40 | AreaT dist2() const { return AreaT(x) * AreaT(x) + AreaT(y) * AreaT(y); } 41 | auto dist() const { return std::sqrt(dist2()); } 42 | Point unit() const { return *this / this->dist(); } 43 | auto angle() const { return std::atan2(y, x); } 44 | 45 | T int_norm() const { return std::gcd(x,y); } 46 | Point int_unit() const { if (!x && !y) return *this; return *this / this->int_norm(); } 47 | 48 | // Convenient free-functions, mostly for generic interop 49 | friend auto norm(const Point& a) { return a.dist2(); } 50 | friend auto abs(const Point& a) { return a.dist(); } 51 | friend auto unit(const Point& a) { return a.unit(); } 52 | friend auto arg(const Point& a) { return a.angle(); } 53 | friend auto int_norm(const Point& a) { return a.int_norm(); } 54 | friend auto int_unit(const Point& a) { return a.int_unit(); } 55 | 56 | Point perp_cw() const { return Point(y, -x); } 57 | Point perp_ccw() const { return Point(-y, x); } 58 | 59 | friend AreaT dot(const Point& a, const Point& b) { return AreaT(a.x) * AreaT(b.x) + AreaT(a.y) * AreaT(b.y); } 60 | friend AreaT cross(const Point& a, const Point& b) { return AreaT(a.x) * AreaT(b.y) - AreaT(a.y) * AreaT(b.x); } 61 | friend AreaT 
cross3(const Point& a, const Point& b, const Point& c) { return cross(b-a, c-a); } 62 | 63 | // Complex numbers and rotation 64 | friend Point conj(const Point& a) { return Point(a.x, -a.y); } 65 | 66 | // Returns conj(a) * b 67 | friend Point dot_cross(const Point& a, const Point& b) { return Point(dot(a, b), cross(a, b)); } 68 | friend Point cmul(const Point& a, const Point& b) { return dot_cross(conj(a), b); } 69 | friend Point cdiv(const Point& a, const Point& b) { return dot_cross(b, a) / b.dist2(); } 70 | 71 | // Must be a unit vector; otherwise multiplies the result by abs(u) 72 | Point rotate(const Point& u) const { return dot_cross(conj(u), *this); } 73 | Point unrotate(const Point& u) const { return dot_cross(u, *this); } 74 | 75 | friend bool lex_less(const Point& a, const Point& b) { 76 | return std::tie(a.x, a.y) < std::tie(b.x, b.y); 77 | } 78 | 79 | friend bool same_dir(const Point& a, const Point& b) { return cross(a,b) == 0 && dot(a,b) > 0; } 80 | 81 | // check if 180 <= s..t < 360 82 | friend bool is_reflex(const Point& a, const Point& b) { auto c = cross(a,b); return c ? (c < 0) : (dot(a, b) < 0); } 83 | 84 | // operator < (s,t) for angles in [base,base+2pi) 85 | friend bool angle_less(const Point& base, const Point& s, const Point& t) { 86 | int r = is_reflex(base, s) - is_reflex(base, t); 87 | return r ? (r < 0) : (0 < cross(s, t)); 88 | } 89 | 90 | friend auto angle_cmp(const Point& base) { 91 | return [base](const Point& s, const Point& t) { return angle_less(base, s, t); }; 92 | } 93 | friend auto angle_cmp_center(const Point& center, const Point& dir) { 94 | return [center, dir](const Point& s, const Point& t) -> bool { return angle_less(dir, s-center, t-center); }; 95 | } 96 | 97 | // is p in [s,t] taken ccw? 1/0/-1 for in/border/out 98 | friend int angle_between(const Point& s, const Point& t, const Point& p) { 99 | if (same_dir(p, s) || same_dir(p, t)) return 0; 100 | return angle_less(s, p, t) ? 
1 : -1; 101 | } 102 | }; 103 | -------------------------------------------------------------------------------- /src/geometry/point3d.hpp: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include 4 | 5 | const double PI = acos(-1.); 6 | const double TAU = 2 * PI; 7 | 8 | template struct Point3D { 9 | using P = Point3D; 10 | 11 | T x, y, z; 12 | Point3D() : x(0), y(0), z(0) {} 13 | Point3D(T x_, T y_, T z_) : x(x_), y(y_), z(z_) {} 14 | 15 | template explicit Point3D(const Point3D& p) : x(T(p.x)), y(T(p.y)), z(T(p.z)) {} 16 | 17 | friend std::istream& operator >> (std::istream& i, P& p) { return i >> p.x >> p.y >> p.z; } 18 | friend std::ostream& operator << (std::ostream& o, const P& p) { return o << "(" << p.x << "," << p.y << "," << p.z << ")"; } 19 | 20 | friend bool operator == (const P& a, const P& b) { return a.x == b.x && a.y == b.y && a.z == b.z; } 21 | friend bool operator != (const P& a, const P& b) { return a.x != b.x || a.y != b.y || a.z != b.z; } 22 | 23 | P& operator += (const P& o) { x += o.x, y += o.y, z += o.z; return *this; } 24 | P& operator -= (const P& o) { x -= o.x, y -= o.y, z -= o.z; return *this; } 25 | friend P operator + (const P& a, const P& b) { return P(a) += b; } 26 | friend P operator - (const P& a, const P& b) { return P(a) -= b; } 27 | 28 | P& operator *= (const T& t) { x *= t, y *= t, z *= t; return *this; } 29 | P& operator /= (const T& t) { x /= t, y /= t, z /= t; return *this; } 30 | friend P operator * (const P& p, const T& t) { return P(p) *= t; } 31 | friend P operator * (const T& t, const P& p) { return P(p) *= t; } 32 | friend P operator / (const P& a, const T& t) { return P(a) /= t; } 33 | 34 | friend P operator + (const P& a) { return P(+a.x, +a.y, +a.z); } 35 | friend P operator - (const P& a) { return P(-a.x, -a.y, -a.z); } 36 | 37 | friend AreaT dot(const P& a, const P& b) { return AreaT(a.x) * AreaT(b.x) + AreaT(a.y) * AreaT(b.y) + AreaT(a.z) * AreaT(b.z); } 
38 | friend AreaT norm(const P& a) { return dot(a,a); } 39 | // We're playing a little loose with this type, expliitly cast it if you need 40 | friend Point3D cross(const P& a, const P& b) { return Point3D(AreaT(a.y) * AreaT(b.z) - AreaT(a.z) * AreaT(b.y), AreaT(a.z) * AreaT(b.x) - AreaT(a.x) * AreaT(b.z), AreaT(a.x) * AreaT(b.y) - AreaT(a.y) * AreaT(b.x)); } 41 | 42 | friend T int_norm(const P& p) { 43 | return std::gcd(std::gcd(abs(p.x), abs(p.y)), abs(p.z)); 44 | } 45 | friend P int_unit(const P& p) { 46 | T g = int_norm(p); 47 | return g ? p / g : p; 48 | } 49 | 50 | friend T abs(const P& a) { return std::sqrt(std::max(T(0), norm(a))); } 51 | friend P unit(const P& a) { return a / abs(a); } 52 | 53 | friend VolT vol(const P& a, const P& b, const P& c, const P& d) { return dot(cross(b-a, c-a), Point3D(d-a)); } 54 | 55 | friend bool lexLess(const P& a, const P& b) { return tie(a.x, a.y, a.z) < tie(b.x, b.y, b.z); } 56 | 57 | friend bool parallelSame(const P& a, const P& b) { 58 | assert(a != P()); 59 | assert(b != P()); 60 | return lexLess(a, P()) == lexLess(b, P()); 61 | } 62 | }; 63 | -------------------------------------------------------------------------------- /src/graph/make_st_dag.hpp: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include 4 | #include 5 | 6 | #include "yc.hpp" 7 | 8 | // Direct a graph into a DAG so that given source and sink are the unique sources/sinks. 9 | // If there are any biconnected components not on the path from the source to 10 | // the sink, they will not be output, modify the code if necessary. 11 | // Returns a topological sort. Back out the edge directions yourself. 
12 | inline std::vector make_st_dag(const std::vector>& adj, int source = -1, int sink = -1) { 13 | int N = int(adj.size()); 14 | 15 | // What's even going on lol 16 | if (N == 0) return {}; 17 | 18 | // Make some arbitrary choices as defaults 19 | if (source == -1 && sink == -1) source = 0; 20 | if (source == -1) source = adj[sink].empty() ? sink : adj[sink][0]; 21 | if (sink == -1) sink = adj[source].empty() ? source : adj[source][0]; 22 | 23 | std::vector depth(N, -1); 24 | std::vector lowval(N); 25 | std::vector has_sink(N); 26 | std::vector> ch(N); 27 | std::y_combinator([&](auto self, int cur, int prv) -> void { 28 | depth[cur] = prv != -1 ? depth[prv] + 1 : 0; 29 | lowval[cur] = depth[cur]; 30 | ch[cur].reserve(adj[cur].size()); 31 | has_sink[cur] = (cur == sink); 32 | for (int nxt : adj[cur]) { 33 | if (nxt == prv) continue; 34 | if (depth[nxt] == -1) { 35 | ch[cur].push_back(nxt); 36 | self(nxt, cur); 37 | lowval[cur] = std::min(lowval[cur], lowval[nxt]); 38 | if (has_sink[nxt]) has_sink[cur] = true; 39 | } else if (depth[nxt] < depth[cur]) { 40 | lowval[cur] = std::min(lowval[cur], depth[nxt]); 41 | } else { 42 | // down edge 43 | } 44 | } 45 | })(source, -1); 46 | 47 | // true is after, false is before 48 | std::vector edge_dir(N, false); 49 | std::vector lst_nxt(N, -1); 50 | auto lst = std::y_combinator([&](auto self, int cur) -> std::array { 51 | std::array res{cur, cur}; 52 | for (int nxt : ch[cur]) { 53 | // If we're on the path to the sink, mark it as downwards. 54 | 55 | // Comment out this line to direct extra bcc's as extra sinks 56 | if (!has_sink[nxt] && lowval[nxt] >= depth[cur]) continue; 57 | 58 | bool d = (has_sink[nxt] || lowval[nxt] >= depth[cur]) ? 
true : !edge_dir[lowval[nxt]]; 59 | edge_dir[depth[cur]] = d; 60 | 61 | auto ch_res = self(nxt); 62 | 63 | // Join res and ch 64 | if (!d) std::swap(res, ch_res); 65 | lst_nxt[std::exchange(res[1], ch_res[1])] = ch_res[0]; 66 | } 67 | return res; 68 | })(source); 69 | 70 | std::vector res; res.reserve(N); 71 | int cur = lst[0]; 72 | while (cur != -1) { 73 | res.push_back(cur); 74 | cur = lst_nxt[cur]; 75 | } 76 | return res; 77 | } 78 | -------------------------------------------------------------------------------- /src/hash_map.hpp: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include 4 | // #include 5 | #include 6 | 7 | struct splitmix64_hash { 8 | static uint64_t splitmix64(uint64_t x) { 9 | // http://xorshift.di.unimi.it/splitmix64.c 10 | x += 0x9e3779b97f4a7c15; 11 | x = (x ^ (x >> 30)) * 0xbf58476d1ce4e5b9; 12 | x = (x ^ (x >> 27)) * 0x94d049bb133111eb; 13 | return x ^ (x >> 31); 14 | } 15 | 16 | size_t operator()(uint64_t x) const { 17 | static const uint64_t FIXED_RANDOM = std::chrono::steady_clock::now().time_since_epoch().count(); 18 | return splitmix64(x + FIXED_RANDOM); 19 | } 20 | }; 21 | 22 | template 23 | using hash_map = __gnu_pbds::gp_hash_table; 24 | 25 | template 26 | using hash_set = hash_map; 27 | -------------------------------------------------------------------------------- /src/jacobi.hpp: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include 4 | #include 5 | 6 | // Computes (n on m) == 1 using the binary-gcd method 7 | // m must be positive and odd, and n must be relatively prime 8 | template bool is_qr_jacobi(T n, T m) { 9 | bool r = true; 10 | assert(m & 1); 11 | assert(m > 0); 12 | if (n < 0) { 13 | if (m & 2) r = !r; 14 | n = -n; 15 | } 16 | while (m > 1) { 17 | assert(n > 0); 18 | int t = __builtin_ctzll(n); 19 | n >>= t; 20 | if ((t & 1) && (((m & 7) == 3) || ((m & 7) == 5))) { 21 | r = !r; 22 | } 23 | // n and 
m both odd 24 | if (n < m) { 25 | if ((n & 2) && (m & 2)) { 26 | r = !r; 27 | } 28 | using std::swap; 29 | swap(n, m); 30 | } 31 | n -= m; 32 | } 33 | return r; 34 | } 35 | -------------------------------------------------------------------------------- /src/lattice_cnt.hpp: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include 4 | #include 5 | 6 | // number of integer solutions to Ax + By <= C and x,y >= 0 7 | inline long long lattice_cnt(long long A, long long B, long long C) { 8 | using ll = long long; 9 | 10 | assert(A >= 0 && B >= 0); 11 | if (C < 0) return 0; 12 | 13 | assert(A > 0 && B > 0); 14 | if (A > B) std::swap(A, B); 15 | assert(A <= B); 16 | 17 | ll ans = 0; 18 | while (C >= 0) { 19 | assert(0 < A && A <= B); 20 | 21 | ll k = B/A; 22 | ll l = B%A; 23 | assert(B == k * A + l); 24 | 25 | ll f = C/B; 26 | ll e = C%B / A; 27 | ll g = C%B % A; 28 | assert(C == f * B + e * A + g); 29 | assert(C == (f * k + e) * A + f * l + g); 30 | 31 | // either x + ky <= f*k+e 32 | // i.e. 
0 <= x <= (f-y) * k + e 33 | // or x >= fk + e + 1 - ky 34 | // and Ax + (Ak+l) y <= C = (fk + e + 1) A + fl - A + g 35 | // Let z = x - (fk + e + 1 - ky) 36 | // Az + A(fk + e + 1 - ky) + Aky + ly <= C = A (fk + e + 1) + fl - A + g 37 | // Az + ly <= fl - A + g 38 | 39 | ans += (f+1) * (e+1) + (f+1) * f / 2 * k; 40 | 41 | C = f*l - A + g; 42 | B = A; 43 | A = l; 44 | } 45 | return ans; 46 | } 47 | 48 | // count the number of 0 <= (a * x % m) < c for 0 <= x < n 49 | inline long long mod_count(long long a, long long m, long long c, long long n) { 50 | assert(m > 0); 51 | if (n == 0) return 0; 52 | 53 | a %= m; if (a < 0) a += m; 54 | 55 | long long extraC = c / m; c %= m; 56 | if (c < 0) extraC--, c += m; 57 | assert(0 <= c && c < m); 58 | 59 | long long ans = extraC * n; 60 | 61 | long long extraN = n / m; n %= m; 62 | if (n < 0) extraN--, n += m; 63 | assert(0 <= n && n < m); 64 | 65 | if (extraN) { 66 | ans += extraN * (lattice_cnt(m, a+m, (a+m) * (m-1)) - lattice_cnt(m, a+m, (a+m) * (m-1) - c)); 67 | } 68 | 69 | if (n) { 70 | // we want solutions to 0 <= a(N-1-x) - my < c with 0 <= x <= N-1 71 | // a * (N-1) >= ax + my > a * (N-1) - c 72 | ans += lattice_cnt(m, a+m, (a+m) * (n-1)) - lattice_cnt(m, a+m, (a+m) * (n-1) - c); 73 | } 74 | 75 | return ans; 76 | } 77 | 78 | inline long long mod_count_range(long long a, long long m, long long clo, long long chi, long long nlo, long long nhi) { 79 | return mod_count(a, m, chi, nhi) - mod_count(a, m, chi, nlo) - mod_count(a, m, clo, nhi) + mod_count(a, m, clo, nlo); 80 | } 81 | -------------------------------------------------------------------------------- /src/lattice_cnt.test.cpp: -------------------------------------------------------------------------------- 1 | #include 2 | 3 | #include "lattice_cnt.hpp" 4 | 5 | using namespace std; 6 | 7 | long long lattice_cnt_slow(long long A, long long B, long long C) { 8 | using ll = long long; 9 | ll ans = 0; 10 | for (ll x = 0; A * x <= C; x++) { 11 | for (ll y = 0; A * x + B 
* y <= C; y++) { 12 | ans++; 13 | } 14 | } 15 | return ans; 16 | } 17 | 18 | long long mod_count_range_slow(long long a, long long m, long long clo, long long chi, long long nlo, long long nhi) { 19 | assert(nlo <= nhi); 20 | assert(clo <= chi); 21 | long long ans = 0; 22 | for (long long i = nlo; i < nhi; i++) { 23 | for (long long j = clo; j < chi; j++) { 24 | ans += (((a * i - j) % m) == 0); 25 | } 26 | } 27 | return ans; 28 | } 29 | 30 | TEST_CASE("Lattice Count", "[lattice_cnt]") { 31 | for (int a = 0; a <= 50; a++) { 32 | for (int b = 0; b <= 10; b++) { 33 | for (int c = -1; c <= 100; c++) { 34 | if ((a == 0 || b == 0) && c >= 0) continue; 35 | INFO("a = " << a); 36 | INFO("b = " << b); 37 | INFO("c = " << c); 38 | REQUIRE(lattice_cnt(a, b, c) == lattice_cnt_slow(a, b, c)); 39 | } 40 | } 41 | } 42 | } 43 | 44 | TEST_CASE("Mod Count (positive)", "[lattice_cnt]") { 45 | for (int m = 1; m <= 25; m++) { 46 | for (int a = 0; a <= m+10; a++) { 47 | for (int c = 0; c <= m; c++) { 48 | INFO("a = " << a); 49 | INFO("m = " << m); 50 | INFO("c = " << c); 51 | int trueAns = 0; 52 | for (int n = 1; n <= m+10; n++) { 53 | INFO("n = " << n); 54 | 55 | trueAns += (a * (n-1) % m) < c; 56 | REQUIRE(mod_count(a, m, c, n) == trueAns); 57 | } 58 | } 59 | } 60 | } 61 | } 62 | 63 | TEST_CASE("Mod Count (negatives)", "[lattice_cnt]") { 64 | for (int m : {1, 2, 3, 5, 8, 13, 21}) { 65 | for (int a : {-10, 0, 1, 2, 3, 5, m, m+5}) { 66 | auto cnds = {-37, -2*m-1, -m, -m+1, -m/2, -1, 0, 1, m/2, m+1, 2*m-1, 34}; 67 | INFO("a = " << a); 68 | INFO("m = " << m); 69 | for (int clo : cnds) { 70 | for (int nlo : cnds) { 71 | INFO("clo = " << clo); 72 | INFO("nlo = " << nlo); 73 | REQUIRE(mod_count_range(a, m, clo, 47, nlo, 49) == mod_count_range_slow(a, m, clo, 47, nlo, 49)); 74 | } 75 | } 76 | 77 | for (int chi : cnds) { 78 | for (int nhi : cnds) { 79 | INFO("chi = " << chi); 80 | INFO("nhi = " << nhi); 81 | REQUIRE(mod_count_range(a, m, -55, chi, -57, nhi) == mod_count_range_slow(a, m, -55, 
chi, -57, nhi)); 82 | } 83 | } 84 | } 85 | } 86 | } 87 | -------------------------------------------------------------------------------- /src/lct.hpp: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include 4 | #include 5 | 6 | namespace lct { 7 | 8 | struct node { 9 | node* p; 10 | node* c[2]; 11 | 12 | int s; 13 | 14 | bool flip; 15 | 16 | // isroot 17 | inline bool r() { return p == nullptr || !(this == p->c[0] || this == p->c[1]); } 18 | // direction 19 | inline bool d() { assert(!r()); return this == p->c[1]; } 20 | 21 | inline void update() { s = 1 + (c[0] ? c[0]->s : 0) + (c[1] ? c[1]->s : 0); } 22 | void propogate() { 23 | if(flip) { 24 | std::swap(c[0], c[1]); 25 | if(c[0]) c[0]->flip = !c[0]->flip; 26 | if(c[1]) c[1]->flip = !c[1]->flip; 27 | flip = false; 28 | } 29 | } 30 | 31 | // precondition: parent and current are propogated 32 | void rot() { 33 | assert(!r()); 34 | 35 | int x = d(); 36 | node* pa = p; 37 | node* ch = c[!x]; 38 | 39 | assert(!pa->flip); 40 | assert(!flip); 41 | 42 | assert((!ch) || ch->p == this); 43 | 44 | if(!pa->r()) pa->p->c[pa->d()] = this; 45 | this->p = pa->p; 46 | 47 | pa->c[x] = ch; 48 | if(ch) ch->p = pa; 49 | 50 | this->c[!x] = pa; 51 | pa->p = this; 52 | 53 | pa->update(); 54 | update(); 55 | } 56 | 57 | // postcondition: always propogated 58 | void splay() { 59 | if(r()) { 60 | update(); 61 | propogate(); 62 | return; 63 | } 64 | 65 | while(!r()) { 66 | if(!p->r()) { 67 | node* gp = p->p; 68 | node* pa = p; 69 | gp->propogate(); 70 | pa->propogate(); 71 | propogate(); 72 | if(d() == p->d()) { 73 | pa->rot(); 74 | assert(p == pa); 75 | } else { 76 | rot(); 77 | assert(p == gp); 78 | } 79 | rot(); 80 | } else { 81 | p->propogate(); 82 | propogate(); 83 | rot(); 84 | assert(r()); 85 | } 86 | } 87 | update(); 88 | } 89 | 90 | // attach on right side 91 | // precondition: propogated 92 | void make_child(node* n) { 93 | assert(!flip); 94 | assert(r()); 95 | 96 | 
if(c[1]) { 97 | node* v = c[1]; 98 | c[1] = nullptr; 99 | assert(v->r()); 100 | 101 | update(); 102 | } 103 | 104 | assert(!flip); 105 | assert(!c[1]); 106 | 107 | if(n) { 108 | 109 | assert(n->r()); 110 | assert(n->p == this); 111 | 112 | c[1] = n; 113 | assert(c[1]->p == this); 114 | 115 | update(); 116 | } 117 | } 118 | 119 | // postcondition: propogated 120 | void expose() { 121 | splay(); 122 | assert(!flip); 123 | make_child(nullptr); 124 | while(p) { 125 | assert(r()); 126 | p->splay(); 127 | p->make_child(this); 128 | assert(!p->flip); 129 | assert(!flip); 130 | rot(); 131 | update(); 132 | assert(r()); 133 | } 134 | assert(!p); 135 | assert(!c[1]); 136 | } 137 | 138 | // does not propogate 139 | void make_root() { 140 | expose(); 141 | assert(p == nullptr); 142 | assert(r()); 143 | flip = !flip; 144 | } 145 | 146 | }; 147 | 148 | } // namespace lct 149 | -------------------------------------------------------------------------------- /src/level_ancestor.hpp: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include 4 | #include 5 | 6 | #include "yc.hpp" 7 | 8 | namespace ecnerwala { 9 | 10 | using std::swap; 11 | 12 | struct level_ancestor { 13 | int N; 14 | std::vector preorder; 15 | std::vector idx; 16 | std::vector> heavyPar; // heavy parent, distance 17 | level_ancestor() : N(0) {} 18 | 19 | level_ancestor(const std::vector& par) : N(int(par.size())), preorder(N), idx(N), heavyPar(N) { 20 | std::vector> ch(N); 21 | for (int i = 0; i < N; i++) { 22 | if (par[i] != -1) ch[par[i]].push_back(i); 23 | } 24 | std::vector sz(N); 25 | int nxt_idx = 0; 26 | for (int i = 0; i < N; i++) { 27 | if (par[i] == -1) { 28 | std::y_combinator([&](auto self, int cur) -> void { 29 | sz[cur] = 1; 30 | for (int nxt : ch[cur]) { 31 | self(nxt); 32 | sz[cur] += sz[nxt]; 33 | } 34 | if (!ch[cur].empty()) { 35 | auto mit = max_element(ch[cur].begin(), ch[cur].end(), [&](int a, int b) { return sz[a] < sz[b]; }); 36 | 
swap(*ch[cur].begin(), *mit); 37 | } 38 | })(i); 39 | std::y_combinator([&](auto self, int cur, int isRoot = true) -> void { 40 | preorder[idx[cur] = nxt_idx++] = cur; 41 | if (isRoot) { 42 | heavyPar[idx[cur]] = {par[cur] == -1 ? -1 : idx[par[cur]], 1}; 43 | } else { 44 | assert(idx[par[cur]] == idx[cur]-1); 45 | heavyPar[idx[cur]] = heavyPar[idx[cur]-1]; 46 | heavyPar[idx[cur]].second++; 47 | } 48 | bool chRoot = false; 49 | for (int nxt : ch[cur]) { 50 | self(nxt, chRoot); 51 | chRoot = true; 52 | } 53 | })(i); 54 | } 55 | } 56 | } 57 | 58 | int get_ancestor(int a, int k) const { 59 | assert(k >= 0); 60 | a = idx[a]; 61 | while (a != -1 && k) { 62 | if (k >= heavyPar[a].second) { 63 | k -= heavyPar[a].second; 64 | assert(heavyPar[a].first <= a - heavyPar[a].second); 65 | a = heavyPar[a].first; 66 | } else { 67 | a -= k; 68 | k = 0; 69 | } 70 | } 71 | if (a == -1) return -1; 72 | else return preorder[a]; 73 | } 74 | 75 | int lca(int a, int b) const { 76 | a = idx[a], b = idx[b]; 77 | while (true) { 78 | if (a > b) swap(a, b); 79 | assert(a <= b); 80 | if (a > b - heavyPar[b].second) { 81 | return preorder[a]; 82 | } 83 | b = heavyPar[b].first; 84 | if (b == -1) return -1; 85 | } 86 | } 87 | 88 | int dist(int a, int b) const { 89 | a = idx[a], b = idx[b]; 90 | int res = 0; 91 | while (true) { 92 | if (a > b) swap(a, b); 93 | assert(a <= b); 94 | if (a > b - heavyPar[b].second) { 95 | res += b - a; 96 | break; 97 | } 98 | res += heavyPar[b].second; 99 | b = heavyPar[b].first; 100 | if (b == -1) return -1; 101 | } 102 | return res; 103 | } 104 | }; 105 | 106 | } // namespace ecnerwala 107 | -------------------------------------------------------------------------------- /src/manacher.hpp: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include 4 | #include 5 | 6 | /** 7 | * manacher(S): return the maximum palindromic substring of S centered at each point 8 | * 9 | * Input: string (or vector) of length N (no 
restrictions on character-set) 10 | * Output: vector res of length 2*N+1 11 | * For any 0 <= i <= 2*N: 12 | * * i % 2 == res[i] % 2 13 | * * the half-open substring S[(i-res[i])/2, (i+res[i])/2) is a palindrome of length res[i] 14 | * * For odd palindromes, take odd i, and vice versa 15 | */ 16 | template <typename V> std::vector<int> manacher(const V& S) { 17 | int N = int(S.size()); 18 | std::vector<int> res(2*N+1, 0); 19 | for (int i = 1, j = -1, r = 0; i < 2*N; i++, j--) { 20 | if (i > r) { 21 | r = i+1, res[i] = 1; 22 | } else { 23 | res[i] = res[j]; 24 | } 25 | if (i+res[i] >= r) { 26 | int b = r>>1, a = i-b; 27 | while (a > 0 && b < N && S[a-1] == S[b]) { 28 | a--, b++; 29 | } 30 | res[i] = b-a, j = i, r = b<<1; 31 | } 32 | } 33 | return res; 34 | } 35 | 36 | /** 37 | * manacher_odd(S): return the maximum palindromic substring of S centered at each point 38 | * 39 | * Input: string (or vector) of length N (no restrictions on character-set) 40 | * Output: vector res of length N 41 | * For any 0 <= i < N: 42 | * * the closed (inclusive) substring S[i-res[i], i+res[i]] is a palindrome of length 2*res[i]+1 43 | */ 44 | template <typename V> std::vector<int> manacher_odd(const V& S) { 45 | int N = int(S.size()); 46 | std::vector<int> res(N); 47 | for (int i = 1, j = -1, r = 0; i < N; i++, j--) { 48 | if (i > r) { 49 | r = i, res[i] = 0; 50 | } else { 51 | res[i] = res[j]; 52 | } 53 | if (i+res[i] >= r) { 54 | int b = r, a = 2*i-r; 55 | while (a-1 >= 0 && b+1 < N && S[a-1] == S[b+1]) { 56 | a--, b++; 57 | } 58 | res[i] = b-i, j = i, r = b; 59 | } 60 | } 61 | return res; 62 | } 63 | -------------------------------------------------------------------------------- /src/mcmf.hpp: -------------------------------------------------------------------------------- 1 | #pragma once 2 | #include 3 | // #include 4 | #include 5 | 6 | // NOTE: This doesn't support negative-cost edges; you can adjust edge weights 7 | // (e.g. by precomputing a potential function) to make them positive.
8 | 9 | template 10 | struct MCMF_SSPA { 11 | int N; 12 | std::vector> adj; 13 | struct edge_t { 14 | int dest; 15 | flow_t cap; 16 | cost_t cost; 17 | }; 18 | std::vector edges; 19 | 20 | std::vector seen; 21 | std::vector pi; 22 | std::vector prv; 23 | 24 | explicit MCMF_SSPA(int N_) : N(N_), adj(N), pi(N, 0), prv(N) {} 25 | 26 | void add_edge(int from, int to, flow_t cap, cost_t cost) { 27 | assert(cap >= 0); 28 | assert(cost + pi[from] - pi[to] >= 0); // TODO: Remove this restriction 29 | int e = int(edges.size()); 30 | edges.emplace_back(edge_t{to, cap, cost}); 31 | edges.emplace_back(edge_t{from, 0, -cost}); 32 | adj[from].push_back(e); 33 | adj[to].push_back(e+1); 34 | } 35 | 36 | static constexpr cost_t INF_COST = std::numeric_limits::max() / 4; 37 | static constexpr flow_t INF_FLOW = std::numeric_limits::max() / 4; 38 | std::vector dist; 39 | __gnu_pbds::priority_queue> q; 40 | std::vector its; 41 | cost_t dijkstra(int s, int t) { 42 | dist.assign(N, INF_COST); 43 | dist[s] = 0; 44 | 45 | its.assign(N, q.end()); 46 | its[s] = q.push({-(dist[s] - pi[s]), s}); 47 | 48 | while (!q.empty()) { 49 | int i = q.top().second; q.pop(); 50 | cost_t d = dist[i]; 51 | for (int e : adj[i]) { 52 | if (edges[e].cap) { 53 | int j = edges[e].dest; 54 | cost_t nd = d + edges[e].cost; 55 | if (nd < dist[j]) { 56 | dist[j] = nd; 57 | prv[j] = e; 58 | if (its[j] == q.end()) { 59 | its[j] = q.push({-(dist[j] - pi[j]), j}); 60 | } else { 61 | q.modify(its[j], {-(dist[j] - pi[j]), j}); 62 | } 63 | } 64 | } 65 | } 66 | } 67 | 68 | swap(pi, dist); 69 | return pi[t]; 70 | } 71 | 72 | flow_t path(int s, int t) { 73 | flow_t cur_flow = std::numeric_limits::max(); 74 | for (int cur = t; cur != s; ) { 75 | int e = prv[cur]; 76 | int nxt = edges[e^1].dest; 77 | cur_flow = std::min(cur_flow, edges[e].cap); 78 | cur = nxt; 79 | } 80 | for (int cur = t; cur != s; ) { 81 | int e = prv[cur]; 82 | int nxt = edges[e^1].dest; 83 | edges[e].cap -= cur_flow; 84 | edges[e^1].cap += cur_flow; 85 | 
cur = nxt; 86 | } 87 | return cur_flow; 88 | } 89 | 90 | std::vector> all_flows(int s, int t, cost_t max_cost = INF_COST - 1) { 91 | assert(s != t); 92 | std::vector> res; 93 | while (dijkstra(s, t) <= max_cost) { 94 | assert(res.empty() || pi[t] >= res.back().second); 95 | flow_t f = path(s, t); 96 | res.push_back({f, pi[t]}); 97 | } 98 | return res; 99 | } 100 | 101 | std::pair max_flow(int s, int t, cost_t max_cost = INF_COST - 1) { 102 | assert(s != t); 103 | flow_t tot_flow = 0; cost_t tot_cost = 0; 104 | while (dijkstra(s, t) <= max_cost) { 105 | flow_t cur_flow = path(s, t); 106 | tot_flow += cur_flow; 107 | tot_cost += cur_flow * pi[t]; 108 | } 109 | return {tot_flow, tot_cost}; 110 | } 111 | }; 112 | 113 | template 114 | struct MCMF_Dinic { 115 | int N; 116 | std::vector> adj; 117 | struct edge_t { 118 | int dest; 119 | flow_t cap; 120 | cost_t cost; 121 | }; 122 | std::vector edges; 123 | 124 | std::vector seen; 125 | std::vector pi; 126 | 127 | explicit MCMF_Dinic(int N_) : N(N_), adj(N), pi(N, 0) {} 128 | 129 | void add_edge(int from, int to, flow_t cap, cost_t cost) { 130 | assert(cap >= 0); 131 | assert(cost + pi[from] - pi[to] >= 0); // TODO: Remove this restriction 132 | int e = int(edges.size()); 133 | edges.emplace_back(edge_t{to, cap, cost}); 134 | edges.emplace_back(edge_t{from, 0, -cost}); 135 | adj[from].push_back(e); 136 | adj[to].push_back(e+1); 137 | } 138 | 139 | static constexpr cost_t INF_COST = std::numeric_limits::max() / 4; 140 | static constexpr flow_t INF_FLOW = std::numeric_limits::max() / 4; 141 | std::vector dist; 142 | __gnu_pbds::priority_queue> q; 143 | std::vector its; 144 | cost_t dijkstra(int s, int t) { 145 | dist.assign(N, INF_COST); 146 | dist[s] = 0; 147 | 148 | its.assign(N, q.end()); 149 | its[s] = q.push({-(dist[s] - pi[s]), s}); 150 | 151 | while (!q.empty()) { 152 | int i = q.top().second; q.pop(); 153 | cost_t d = dist[i]; 154 | for (int e : adj[i]) { 155 | if (edges[e].cap) { 156 | int j = edges[e].dest; 157 | 
cost_t nd = d + edges[e].cost; 158 | if (nd < dist[j]) { 159 | dist[j] = nd; 160 | if (its[j] == q.end()) { 161 | its[j] = q.push({-(dist[j] - pi[j]), j}); 162 | } else { 163 | q.modify(its[j], {-(dist[j] - pi[j]), j}); 164 | } 165 | } 166 | } 167 | } 168 | } 169 | 170 | std::swap(pi, dist); 171 | return pi[t]; 172 | } 173 | 174 | std::vector buf; 175 | std::vector level; 176 | flow_t dinic_dfs(int cur, int t, flow_t f) { 177 | if (cur == t) return f; 178 | flow_t cur_f = 0; 179 | assert(f > 0); 180 | for (; buf[cur] < int(adj[cur].size()); buf[cur]++) { 181 | int e = adj[cur][buf[cur]]; 182 | int nxt = edges[e].dest; 183 | if (level[nxt] == level[cur] + 1 && edges[e].cap > 0 && edges[e].cost == pi[nxt] - pi[cur]) { 184 | flow_t v = dinic_dfs(nxt, t, std::min(f, edges[e].cap)); 185 | edges[e].cap -= v; 186 | edges[e^1].cap += v; 187 | f -= v; 188 | cur_f += v; 189 | if (f == 0) break; 190 | } 191 | } 192 | return cur_f; 193 | } 194 | flow_t dinic(int s, int t) { 195 | flow_t tot_flow = 0; 196 | while (true) { 197 | buf.clear(); 198 | buf.reserve(N); 199 | level.assign(N, -1); 200 | buf.push_back(s); 201 | level[s] = 0; 202 | for (int z = 0; z < int(buf.size()); z++) { 203 | int cur = buf[z]; 204 | for (int e : adj[cur]) { 205 | int nxt = edges[e].dest; 206 | if (edges[e].cap > 0 && edges[e].cost == pi[nxt] - pi[cur] && level[nxt] == -1) { 207 | level[nxt] = level[cur] + 1; 208 | buf.push_back(nxt); 209 | } 210 | } 211 | } 212 | if (level[t] == -1) break; 213 | buf.assign(N, 0); 214 | tot_flow += dinic_dfs(s, t, INF_FLOW); 215 | } 216 | return tot_flow; 217 | } 218 | 219 | std::vector> all_flows(int s, int t, cost_t max_cost = INF_COST - 1) { 220 | assert(s != t); 221 | std::vector> res; 222 | while (dijkstra(s, t) <= max_cost) { 223 | assert(res.empty() || pi[t] > res.back().second); 224 | flow_t f = dinic(s, t); 225 | res.push_back({f, pi[t]}); 226 | } 227 | return res; 228 | } 229 | 230 | std::pair max_flow(int s, int t, cost_t max_cost = INF_COST - 1) { 231 | 
assert(s != t); 232 | flow_t tot_flow = 0; cost_t tot_cost = 0; 233 | while (dijkstra(s, t) <= max_cost) { 234 | flow_t cur_flow = dinic(s, t); 235 | tot_flow += cur_flow; 236 | tot_cost += cur_flow * pi[t]; 237 | } 238 | return {tot_flow, tot_cost}; 239 | } 240 | }; 241 | 242 | template 243 | struct Dinic { 244 | int N; 245 | std::vector> adj; 246 | struct edge_t { 247 | int dest; 248 | flow_t cap; 249 | }; 250 | std::vector edges; 251 | 252 | std::vector seen; 253 | 254 | explicit Dinic(int N_) : N(N_), adj(N) {} 255 | 256 | void add_edge(int from, int to, flow_t cap) { 257 | return add_bi_edge(from, to, cap, 0); 258 | } 259 | 260 | void add_bi_edge(int from, int to, flow_t cap, flow_t rev_cap) { 261 | assert(cap >= 0); 262 | assert(rev_cap >= 0); 263 | int e = int(edges.size()); 264 | edges.emplace_back(edge_t{to, cap}); 265 | edges.emplace_back(edge_t{from, rev_cap}); 266 | adj[from].push_back(e); 267 | adj[to].push_back(e+1); 268 | } 269 | 270 | static constexpr tot_flow_t INF_FLOW = std::numeric_limits::max() / 4; 271 | std::vector buf; 272 | std::vector level; 273 | tot_flow_t dinic_dfs(int cur, int t, tot_flow_t f) { 274 | if (cur == t) return f; 275 | tot_flow_t cur_f = 0; 276 | assert(f > 0); 277 | for (; buf[cur] < int(adj[cur].size()); buf[cur]++) { 278 | int e = adj[cur][buf[cur]]; 279 | int nxt = edges[e].dest; 280 | if (level[nxt] == level[cur] + 1 && edges[e].cap > 0) { 281 | flow_t v = flow_t(dinic_dfs(nxt, t, std::min(f, edges[e].cap))); 282 | edges[e].cap -= v; 283 | edges[e^1].cap += v; 284 | f -= v; 285 | cur_f += v; 286 | if (f == 0) break; 287 | } 288 | } 289 | return cur_f; 290 | } 291 | tot_flow_t dinic(int s, int t) { 292 | tot_flow_t tot_flow = 0; 293 | while (true) { 294 | buf.clear(); 295 | buf.reserve(N); 296 | level.assign(N, -1); 297 | buf.push_back(s); 298 | level[s] = 0; 299 | for (int z = 0; z < int(buf.size()); z++) { 300 | int cur = buf[z]; 301 | for (int e : adj[cur]) { 302 | int nxt = edges[e].dest; 303 | if (edges[e].cap > 0 
&& level[nxt] == -1) { 304 | level[nxt] = level[cur] + 1; 305 | buf.push_back(nxt); 306 | } 307 | } 308 | } 309 | if (level[t] == -1) break; 310 | buf.assign(N, 0); 311 | tot_flow += dinic_dfs(s, t, INF_FLOW); 312 | } 313 | return tot_flow; 314 | } 315 | tot_flow_t max_flow(int s, int t) { return dinic(s, t); } 316 | }; 317 | -------------------------------------------------------------------------------- /src/modnum.hpp: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include 4 | #include 5 | #include 6 | 7 | template T mod_inv_in_range(T a, T m) { 8 | // assert(0 <= a && a < m); 9 | T x = a, y = m; 10 | // coeff of a in x and y 11 | T vx = 1, vy = 0; 12 | while (x) { 13 | T k = y / x; 14 | y %= x; 15 | vy -= k * vx; 16 | std::swap(x, y); 17 | std::swap(vx, vy); 18 | } 19 | assert(y == 1); 20 | return vy < 0 ? m + vy : vy; 21 | } 22 | 23 | template struct extended_gcd_result { 24 | T gcd; 25 | T coeff_a, coeff_b; 26 | }; 27 | template extended_gcd_result extended_gcd(T a, T b) { 28 | T x = a, y = b; 29 | // coeff of a and b in x and y 30 | T ax = 1, ay = 0; 31 | T bx = 0, by = 1; 32 | while (x) { 33 | T k = y / x; 34 | y %= x; 35 | ay -= k * ax; 36 | by -= k * bx; 37 | std::swap(x, y); 38 | std::swap(ax, ay); 39 | std::swap(bx, by); 40 | } 41 | return {y, ay, by}; 42 | } 43 | 44 | template T mod_inv(T a, T m) { 45 | a %= m; 46 | a = a < 0 ? 
a + m : a; 47 | return mod_inv_in_range(a, m); 48 | } 49 | 50 | template struct modnum { 51 | static constexpr int MOD = MOD_; 52 | static_assert(MOD_ > 0, "MOD must be positive"); 53 | 54 | private: 55 | int v; 56 | 57 | public: 58 | 59 | modnum() : v(0) {} 60 | modnum(int64_t v_) : v(int(v_ % MOD)) { if (v < 0) v += MOD; } 61 | explicit operator int() const { return v; } 62 | friend std::ostream& operator << (std::ostream& out, const modnum& n) { return out << int(n); } 63 | friend std::istream& operator >> (std::istream& in, modnum& n) { int64_t v_; in >> v_; n = modnum(v_); return in; } 64 | 65 | friend bool operator == (const modnum& a, const modnum& b) { return a.v == b.v; } 66 | friend bool operator != (const modnum& a, const modnum& b) { return a.v != b.v; } 67 | 68 | modnum inv() const { 69 | modnum res; 70 | res.v = mod_inv_in_range(v, MOD); 71 | return res; 72 | } 73 | friend modnum inv(const modnum& m) { return m.inv(); } 74 | modnum neg() const { 75 | modnum res; 76 | res.v = v ? MOD-v : 0; 77 | return res; 78 | } 79 | friend modnum neg(const modnum& m) { return m.neg(); } 80 | 81 | modnum operator- () const { 82 | return neg(); 83 | } 84 | modnum operator+ () const { 85 | return modnum(*this); 86 | } 87 | 88 | modnum& operator ++ () { 89 | v ++; 90 | if (v == MOD) v = 0; 91 | return *this; 92 | } 93 | modnum& operator -- () { 94 | if (v == 0) v = MOD; 95 | v --; 96 | return *this; 97 | } 98 | modnum& operator += (const modnum& o) { 99 | v -= MOD-o.v; 100 | v = (v < 0) ? v + MOD : v; 101 | return *this; 102 | } 103 | modnum& operator -= (const modnum& o) { 104 | v -= o.v; 105 | v = (v < 0) ? 
v + MOD : v; 106 | return *this; 107 | } 108 | modnum& operator *= (const modnum& o) { 109 | v = int(int64_t(v) * int64_t(o.v) % MOD); 110 | return *this; 111 | } 112 | modnum& operator /= (const modnum& o) { 113 | return *this *= o.inv(); 114 | } 115 | 116 | friend modnum operator ++ (modnum& a, int) { modnum r = a; ++a; return r; } 117 | friend modnum operator -- (modnum& a, int) { modnum r = a; --a; return r; } 118 | friend modnum operator + (const modnum& a, const modnum& b) { return modnum(a) += b; } 119 | friend modnum operator - (const modnum& a, const modnum& b) { return modnum(a) -= b; } 120 | friend modnum operator * (const modnum& a, const modnum& b) { return modnum(a) *= b; } 121 | friend modnum operator / (const modnum& a, const modnum& b) { return modnum(a) /= b; } 122 | }; 123 | 124 | template T pow(T a, long long b) { 125 | assert(b >= 0); 126 | T r = 1; while (b) { if (b & 1) r *= a; b >>= 1; a *= a; } return r; 127 | } 128 | 129 | template struct pairnum { 130 | U u; 131 | V v; 132 | 133 | pairnum() : u(0), v(0) {} 134 | pairnum(long long val) : u(val), v(val) {} 135 | pairnum(const U& u_, const V& v_) : u(u_), v(v_) {} 136 | 137 | friend std::ostream& operator << (std::ostream& out, const pairnum& n) { return out << '(' << n.u << ',' << ' ' << n.v << ')'; } 138 | friend std::istream& operator >> (std::istream& in, pairnum& n) { long long val; in >> val; n = pairnum(val); return in; } 139 | 140 | friend bool operator == (const pairnum& a, const pairnum& b) { return a.u == b.u && a.v == b.v; } 141 | friend bool operator != (const pairnum& a, const pairnum& b) { return a.u != b.u || a.v != b.v; } 142 | 143 | pairnum inv() const { 144 | return pairnum(u.inv(), v.inv()); 145 | } 146 | pairnum neg() const { 147 | return pairnum(u.neg(), v.neg()); 148 | } 149 | pairnum operator- () const { 150 | return pairnum(-u, -v); 151 | } 152 | pairnum operator+ () const { 153 | return pairnum(+u, +v); 154 | } 155 | 156 | pairnum& operator ++ () { 157 | ++u, ++v; 
158 | return *this; 159 | } 160 | pairnum& operator -- () { 161 | --u, --v; 162 | return *this; 163 | } 164 | 165 | pairnum& operator += (const pairnum& o) { 166 | u += o.u; 167 | v += o.v; 168 | return *this; 169 | } 170 | pairnum& operator -= (const pairnum& o) { 171 | u -= o.u; 172 | v -= o.v; 173 | return *this; 174 | } 175 | pairnum& operator *= (const pairnum& o) { 176 | u *= o.u; 177 | v *= o.v; 178 | return *this; 179 | } 180 | pairnum& operator /= (const pairnum& o) { 181 | u /= o.u; 182 | v /= o.v; 183 | return *this; 184 | } 185 | 186 | friend pairnum operator ++ (pairnum& a, int) { pairnum r = a; ++a; return r; } 187 | friend pairnum operator -- (pairnum& a, int) { pairnum r = a; --a; return r; } 188 | friend pairnum operator + (const pairnum& a, const pairnum& b) { return pairnum(a) += b; } 189 | friend pairnum operator - (const pairnum& a, const pairnum& b) { return pairnum(a) -= b; } 190 | friend pairnum operator * (const pairnum& a, const pairnum& b) { return pairnum(a) *= b; } 191 | friend pairnum operator / (const pairnum& a, const pairnum& b) { return pairnum(a) /= b; } 192 | }; 193 | 194 | template struct dynamic_modnum { 195 | private: 196 | #if __cpp_inline_variables >= 201606 197 | // C++17 and up 198 | inline static int MOD_ = 0; 199 | inline static uint64_t BARRETT_M = 0; 200 | #else 201 | // NB: these must be initialized out of the class by hand: 202 | // static int dynamic_modnum::MOD = 0; 203 | // static int dynamic_modnum::BARRETT_M = 0; 204 | static int MOD_; 205 | static uint64_t BARRETT_M; 206 | #endif 207 | 208 | public: 209 | // Make only the const-reference public, to force the use of set_mod 210 | static constexpr int const& MOD = MOD_; 211 | 212 | // Barret reduction taken from KACTL: 213 | /** 214 | * Author: Simon Lindholm 215 | * Date: 2020-05-30 216 | * License: CC0 217 | * Source: https://en.wikipedia.org/wiki/Barrett_reduction 218 | * Description: Compute $a \% b$ about 5 times faster than usual, where $b$ is constant but 
not known at compile time. 219 | * Returns a value congruent to $a \pmod b$ in the range $[0, 2b)$. 220 | * Status: proven correct, stress-tested 221 | * Measured as having 4 times lower latency, and 8 times higher throughput, see stress-test. 222 | * Details: 223 | * More precisely, it can be proven that the result equals 0 only if $a = 0$, 224 | * and otherwise lies in $[1, (1 + a/2^64) * b)$. 225 | */ 226 | static void set_mod(int mod) { 227 | assert(mod > 0); 228 | MOD_ = mod; 229 | BARRETT_M = (uint64_t(-1) / MOD); 230 | } 231 | static uint32_t barrett_reduce_partial(uint64_t a) { 232 | return uint32_t(a - uint64_t((__uint128_t(BARRETT_M) * a) >> 64) * MOD); 233 | } 234 | static int barrett_reduce(uint64_t a) { 235 | int32_t res = int32_t(barrett_reduce_partial(a) - MOD); 236 | return (res < 0) ? res + MOD : res; 237 | } 238 | 239 | struct mod_reader { 240 | friend std::istream& operator >> (std::istream& i, mod_reader) { 241 | int mod; i >> mod; 242 | dynamic_modnum::set_mod(mod); 243 | return i; 244 | } 245 | }; 246 | static mod_reader MOD_READER() { 247 | return mod_reader(); 248 | } 249 | 250 | private: 251 | int v; 252 | 253 | public: 254 | 255 | dynamic_modnum() : v(0) {} 256 | dynamic_modnum(int64_t v_) : v(int(v_ % MOD)) { if (v < 0) v += MOD; } 257 | explicit operator int() const { return v; } 258 | friend std::ostream& operator << (std::ostream& out, const dynamic_modnum& n) { return out << int(n); } 259 | friend std::istream& operator >> (std::istream& in, dynamic_modnum& n) { int64_t v_; in >> v_; n = dynamic_modnum(v_); return in; } 260 | 261 | friend bool operator == (const dynamic_modnum& a, const dynamic_modnum& b) { return a.v == b.v; } 262 | friend bool operator != (const dynamic_modnum& a, const dynamic_modnum& b) { return a.v != b.v; } 263 | 264 | dynamic_modnum inv() const { 265 | dynamic_modnum res; 266 | res.v = mod_inv_in_range(v, MOD); 267 | return res; 268 | } 269 | friend dynamic_modnum inv(const dynamic_modnum& m) { return m.inv(); 
} 270 | dynamic_modnum neg() const { 271 | dynamic_modnum res; 272 | res.v = v ? MOD-v : 0; 273 | return res; 274 | } 275 | friend dynamic_modnum neg(const dynamic_modnum& m) { return m.neg(); } 276 | 277 | dynamic_modnum operator- () const { 278 | return neg(); 279 | } 280 | dynamic_modnum operator+ () const { 281 | return dynamic_modnum(*this); 282 | } 283 | 284 | dynamic_modnum& operator ++ () { 285 | v ++; 286 | if (v == MOD) v = 0; 287 | return *this; 288 | } 289 | dynamic_modnum& operator -- () { 290 | if (v == 0) v = MOD; 291 | v --; 292 | return *this; 293 | } 294 | dynamic_modnum& operator += (const dynamic_modnum& o) { 295 | v -= MOD-o.v; 296 | v = (v < 0) ? v + MOD : v; 297 | return *this; 298 | } 299 | dynamic_modnum& operator -= (const dynamic_modnum& o) { 300 | v -= o.v; 301 | v = (v < 0) ? v + MOD : v; 302 | return *this; 303 | } 304 | dynamic_modnum& operator *= (const dynamic_modnum& o) { 305 | v = barrett_reduce(int64_t(v) * int64_t(o.v)); 306 | return *this; 307 | } 308 | dynamic_modnum& operator /= (const dynamic_modnum& o) { 309 | return *this *= o.inv(); 310 | } 311 | 312 | friend dynamic_modnum operator ++ (dynamic_modnum& a, int) { dynamic_modnum r = a; ++a; return r; } 313 | friend dynamic_modnum operator -- (dynamic_modnum& a, int) { dynamic_modnum r = a; --a; return r; } 314 | friend dynamic_modnum operator + (const dynamic_modnum& a, const dynamic_modnum& b) { return dynamic_modnum(a) += b; } 315 | friend dynamic_modnum operator - (const dynamic_modnum& a, const dynamic_modnum& b) { return dynamic_modnum(a) -= b; } 316 | friend dynamic_modnum operator * (const dynamic_modnum& a, const dynamic_modnum& b) { return dynamic_modnum(a) *= b; } 317 | friend dynamic_modnum operator / (const dynamic_modnum& a, const dynamic_modnum& b) { return dynamic_modnum(a) /= b; } 318 | }; 319 | 320 | template struct mod_constraint { 321 | T v, mod; 322 | 323 | friend mod_constraint operator & (mod_constraint a, mod_constraint b) { 324 | if (a.mod < b.mod) 
std::swap(a, b); 325 | if (b.mod == 1) return a; 326 | 327 | extended_gcd_result egcd = extended_gcd(a.mod, b.mod); 328 | assert(a.v % egcd.gcd == b.v % egcd.gcd); 329 | 330 | T extra = b.v - a.v % b.mod; 331 | extra /= egcd.gcd; 332 | 333 | extra *= egcd.coeff_a; 334 | extra %= b.mod / egcd.gcd; 335 | extra += (extra < 0) ? b.mod / egcd.gcd : 0; 336 | 337 | return mod_constraint{ 338 | a.v + extra * a.mod, 339 | a.mod * (b.mod / egcd.gcd) 340 | }; 341 | } 342 | }; 343 | -------------------------------------------------------------------------------- /src/modnum.test.cpp: -------------------------------------------------------------------------------- 1 | #include "modnum.hpp" 2 | #include 3 | #include // Include for std::lcm and std::gcd 4 | 5 | TEST_CASE("Mod Constraint Regression Test", "[mod_constraint]") { 6 | for (int a_mod = 1; a_mod <= 10; ++a_mod) { 7 | for (int a_val = 0; a_val < a_mod; ++a_val) { 8 | for (int b_mod = 1; b_mod <= 10; ++b_mod) { 9 | for (int b_val = 0; b_val < b_mod; ++b_val) { 10 | if (a_val % std::gcd(a_mod, b_mod) != b_val % std::gcd(a_mod, b_mod)) continue; 11 | 12 | mod_constraint a{a_val, a_mod}; 13 | mod_constraint b{b_val, b_mod}; 14 | 15 | mod_constraint r = a & b; 16 | 17 | // Check that r.mod is the LCM of a.mod and b.mod 18 | int lcm_ab = std::lcm(a.mod, b.mod); 19 | REQUIRE(r.mod == lcm_ab); 20 | 21 | // Check that r.v % a.mod == a.v (and likewise for b) 22 | REQUIRE(r.v % a.mod == a.v); 23 | REQUIRE(r.v % b.mod == b.v); 24 | 25 | // Check that r.v is between 0 and r.mod 26 | REQUIRE(r.v >= 0); 27 | REQUIRE(r.v < r.mod); 28 | } 29 | } 30 | } 31 | } 32 | } 33 | -------------------------------------------------------------------------------- /src/nim_prod.hpp: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include 4 | #include 5 | 6 | // Usage: 7 | // constexpr nim_prod_t nimProd; 8 | // C++20: 9 | // constinit nim_prod_t nimProd; 10 | struct nim_prod_t { 11 | uint64_t 
bit_prod[64][64]{}; 12 | constexpr nim_prod_t() { 13 | for (int i = 0; i < 64; i++) { 14 | for (int j = 0; j < 64; j++) { 15 | if ((i & j) == 0) { 16 | bit_prod[i][j] = uint64_t(1) << (i|j); 17 | } else { 18 | int a = (i&j) & -(i&j); 19 | bit_prod[i][j] = bit_prod[i ^ a][j] ^ bit_prod[(i ^ a) | (a-1)][(j ^ a) | (i & (a-1))]; 20 | } 21 | } 22 | } 23 | } 24 | constexpr uint64_t operator () (uint64_t x, uint64_t y) const { 25 | uint64_t res = 0; 26 | for (int i = 0; i < 64 && (x >> i); i++) 27 | if ((x >> i) & 1) 28 | for (int j = 0; j < 64 && (y >> j); j++) 29 | if ((y >> j) & 1) 30 | res ^= bit_prod[i][j]; 31 | return res; 32 | } 33 | }; 34 | -------------------------------------------------------------------------------- /src/optimize.hpp: -------------------------------------------------------------------------------- 1 | #pragma once 2 | #pragma GCC optimize("unroll-loops") 3 | #pragma GCC optimize("Ofast") 4 | #pragma GCC target("sse,sse2,sse3,ssse3,popcnt,abm,mmx") // Safe for yandex 5 | 6 | #pragma GCC target("sse,sse2,sse3,ssse3,sse4,popcnt,abm,bmi,bmi2,mmx,avx,avx2,fma") // Requires AVX2 7 | 8 | // See https://codeforces.com/blog/entry/96344 9 | 10 | inline void disable_denormal_floats() { 11 | // https://stackoverflow.com/a/8217313 12 | #define CSR_FLUSH_TO_ZERO (1 << 15) 13 | unsigned csr = __builtin_ia32_stmxcsr(); 14 | csr |= CSR_FLUSH_TO_ZERO; 15 | __builtin_ia32_ldmxcsr(csr); 16 | #undef CSR_FLUSH_TO_ZERO 17 | } 18 | -------------------------------------------------------------------------------- /src/order_statistic.hpp: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include 4 | 5 | template > 6 | using order_statistic_map = __gnu_pbds::tree< 7 | K, V, Comp, 8 | __gnu_pbds::rb_tree_tag, 9 | __gnu_pbds::tree_order_statistics_node_update 10 | >; 11 | 12 | template > 13 | using order_statistic_set = order_statistic_map; 14 | 15 | // Supports 16 | // auto iterator = 
order_statistic_set().find_by_order(idx); // (0-indexed) 17 | // int num_strictly_smaller = order_statistic_set().order_of_key(key); 18 | -------------------------------------------------------------------------------- /src/perm_tree.hpp: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include 4 | #include 5 | #include 6 | 7 | class PermTree { 8 | public: 9 | enum class NodeType { 10 | LEAF, 11 | INCR, 12 | DECR, 13 | FULL, 14 | PARTIAL, 15 | }; 16 | 17 | struct Node { 18 | std::array c; 19 | NodeType type; 20 | int l, r, lo, hi; 21 | }; 22 | 23 | std::vector nodes; 24 | int root = -1; 25 | 26 | PermTree() {} 27 | Node& operator [] (int idx) { return nodes[idx]; } 28 | const Node& operator [] (int idx) const { return nodes[idx]; } 29 | 30 | int size() const { return int(nodes.size()); } 31 | 32 | PermTree(const std::vector& A) : nodes(int(A.size())*2-1) { 33 | int N = int(A.size()); 34 | std::vector nxt_earlier(N); 35 | std::vector prv_earlier(N); 36 | for (int i = 0; i < N; i++) { 37 | nxt_earlier[i] = i+1; 38 | prv_earlier[i] = i-1; 39 | } 40 | for (int i = N-1; i >= 0; i--) { 41 | int a = A[i]; 42 | int p = prv_earlier[a]; 43 | int n = nxt_earlier[a]; 44 | if (p != -1) nxt_earlier[p] = n; 45 | if (n != N) prv_earlier[n] = p; 46 | } 47 | 48 | struct cnd_t { 49 | int left; 50 | int lo; 51 | int lo_gap; 52 | int hi; 53 | int hi_gap; 54 | int node; 55 | }; 56 | 57 | std::vector stk; stk.reserve(N); 58 | 59 | for (int i = 0; i < N; i++) { 60 | int a = A[i]; 61 | while (true) { 62 | if (!stk.empty() && (a < stk.back().lo_gap || a > stk.back().hi_gap)) { 63 | assert(stk.size() >= 2); 64 | stk.end()[-2].lo = std::min(stk.end()[-2].lo, stk.back().lo); 65 | stk.end()[-2].hi = std::max(stk.end()[-2].hi, stk.back().hi); 66 | 67 | int n = 2 * stk.back().left - 1; 68 | nodes[n].c = {stk.end()[-2].node, stk.end()[-1].node}; 69 | nodes[n].type = NodeType::PARTIAL; 70 | nodes[n].l = stk.end()[-2].left; 71 | nodes[n].r = 
i-1; 72 | nodes[n].lo = stk.end()[-2].lo; 73 | nodes[n].hi = stk.end()[-2].hi; 74 | 75 | stk.pop_back(); 76 | 77 | stk.back().node = n; 78 | } else { 79 | break; 80 | } 81 | } 82 | 83 | stk.push_back({i, a, prv_earlier[a]+1, a, nxt_earlier[a]-1, 2*i}); 84 | nodes[2*i].type = NodeType::LEAF; 85 | nodes[2*i].c = {-1, -1}; 86 | nodes[2*i].l = nodes[2*i].r = i; 87 | nodes[2*i].lo = nodes[2*i].hi = a; 88 | 89 | while (stk.size() >= 2 && std::max(stk.back().hi, stk.end()[-2].hi) - std::min(stk.back().lo, stk.end()[-2].lo) == i - stk.end()[-2].left) { 90 | // merge these two nodes into one 91 | stk.end()[-2].lo = std::min(stk.end()[-2].lo, stk.back().lo); 92 | stk.end()[-2].hi = std::max(stk.end()[-2].hi, stk.back().hi); 93 | 94 | int n = 2 * stk.back().left - 1; 95 | nodes[n].c = {stk.end()[-2].node, stk.end()[-1].node}; 96 | if (stk.end()[-2].lo == stk.end()[-1].lo) { 97 | nodes[n].type = NodeType::DECR; 98 | } else if (stk.end()[-2].hi == stk.end()[-1].hi) { 99 | nodes[n].type = NodeType::INCR; 100 | } else { 101 | nodes[n].type = NodeType::FULL; 102 | } 103 | nodes[n].l = stk.end()[-2].left; 104 | nodes[n].r = i; 105 | nodes[n].lo = stk.end()[-2].lo; 106 | nodes[n].hi = stk.end()[-2].hi; 107 | 108 | stk.pop_back(); 109 | stk.back().node = n; 110 | } 111 | } 112 | 113 | assert(stk.size() == 1); 114 | root = stk.back().node; 115 | } 116 | }; 117 | -------------------------------------------------------------------------------- /src/perm_tree.test.cpp: -------------------------------------------------------------------------------- 1 | #include "perm_tree.hpp" 2 | 3 | #include 4 | #include 5 | 6 | void check_tree(std::vector A) { 7 | int N = int(A.size()); 8 | std::vector, std::array>> actual_ranges; 9 | for (int i = 0; i < N; i++) { 10 | int lo = A[i], hi = A[i]; 11 | for (int j = i; j < N; j++) { 12 | lo = std::min(lo, A[j]); 13 | hi = std::max(hi, A[j]); 14 | assert(hi - lo >= j - i); 15 | if (hi - lo == j - i) { 16 | actual_ranges.push_back({{i, j}, {lo, hi}}); 17 | 
} 18 | } 19 | } 20 | 21 | PermTree tree(A); 22 | std::vector, std::array>> computed_ranges; 23 | for (int n = 0; n < tree.size(); n++) { 24 | const auto& node = tree[n]; 25 | if (node.type != PermTree::NodeType::PARTIAL) { 26 | computed_ranges.push_back({{node.l, node.r}, {node.lo, node.hi}}); 27 | } 28 | if (node.type == PermTree::NodeType::LEAF) { 29 | REQUIRE(node.c[0] == -1); 30 | REQUIRE(node.c[1] == -1); 31 | REQUIRE(node.l == n/2); 32 | REQUIRE(node.r == n/2); 33 | REQUIRE(node.lo == A[n/2]); 34 | REQUIRE(node.hi == A[n/2]); 35 | continue; 36 | } 37 | REQUIRE(node.c[0] != -1); 38 | REQUIRE(node.c[1] != -1); 39 | REQUIRE(node.l == tree[node.c[0]].l); 40 | REQUIRE(node.r == tree[node.c[1]].r); 41 | REQUIRE(tree[node.c[0]].r + 1 == tree[node.c[1]].l); 42 | REQUIRE(node.lo == std::min(tree[node.c[0]].lo, tree[node.c[1]].lo)); 43 | REQUIRE(node.hi == std::max(tree[node.c[0]].hi, tree[node.c[1]].hi)); 44 | if (node.type == PermTree::NodeType::FULL) { 45 | // There should be at least 3 pieces 46 | REQUIRE(( 47 | tree[node.c[0]].type == PermTree::NodeType::PARTIAL 48 | || tree[node.c[1]].type == PermTree::NodeType::PARTIAL 49 | )); 50 | } 51 | if (node.type == PermTree::NodeType::INCR) { 52 | REQUIRE(tree[node.c[0]].hi + 1 == tree[node.c[1]].lo); 53 | 54 | REQUIRE(tree[node.c[1]].type != PermTree::NodeType::INCR); 55 | for (int cur = node.c[0]; tree[cur].type == PermTree::NodeType::INCR; cur = tree[cur].c[0]) { 56 | int ch = tree[cur].c[1]; 57 | computed_ranges.push_back({{tree[ch].l, node.r}, {tree[ch].lo, node.hi}}); 58 | } 59 | } 60 | if (node.type == PermTree::NodeType::DECR) { 61 | REQUIRE(tree[node.c[0]].lo - 1 == tree[node.c[1]].hi); 62 | 63 | REQUIRE(tree[node.c[1]].type != PermTree::NodeType::DECR); 64 | for (int cur = node.c[0]; tree[cur].type == PermTree::NodeType::DECR; cur = tree[cur].c[0]) { 65 | int ch = tree[cur].c[1]; 66 | computed_ranges.push_back({{tree[ch].l, node.r}, {node.lo, tree[ch].hi}}); 67 | } 68 | } 69 | } 70 | 
std::sort(computed_ranges.begin(), computed_ranges.end()); 71 | REQUIRE(actual_ranges == computed_ranges); 72 | } 73 | 74 | TEST_CASE("Permutation Tree", "[perm_tree]") { 75 | for (int N = 1; N <= 7; N++) { 76 | std::vector A(N); 77 | std::iota(A.begin(), A.end(), 0); 78 | do { 79 | check_tree(A); 80 | } while (next_permutation(A.begin(), A.end())); 81 | } 82 | } 83 | -------------------------------------------------------------------------------- /src/quaternion_hurwitz.hpp: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include 4 | #include 5 | #include 6 | #include 7 | 8 | template 9 | struct hurwitz_quaternion { 10 | // we store the doubled quaternion 11 | num s,x,y,z; 12 | hurwitz_quaternion() : s(0), x(0), y(0), z(0) {} 13 | hurwitz_quaternion(num v) : s(2*v), x(0), y(0), z(0) {} 14 | hurwitz_quaternion(num s_, num x_, num y_, num z_) : s(2*s_), x(2*x_), y(2*y_), z(2*z_) {} 15 | struct doubled_coords_tag {}; 16 | hurwitz_quaternion(doubled_coords_tag, num s_, num x_, num y_, num z_) : s(s_), x(x_), y(y_), z(z_) { 17 | assert((s & 1) == (x & 1) && (s & 1) == (y & 1) && (s & 1) == (z & 1)); 18 | } 19 | friend std::ostream& operator << (std::ostream& o, const hurwitz_quaternion& q) { 20 | o << double(q.s)/2; 21 | { 22 | std::ios_base::fmtflags f(o.flags()); 23 | o << std::showpos << double(q.x)/2 << "i" << double(q.y)/2 << "j" << double(q.z)/2 << "k"; 24 | o.flags(f); 25 | } 26 | return o; 27 | } 28 | 29 | explicit operator bool() const { 30 | return s || x || y || z; 31 | } 32 | 33 | friend bool operator == (const hurwitz_quaternion& a, const hurwitz_quaternion& b) { 34 | return std::tie(a.s,a.x,a.y,a.z) == std::tie(b.s,b.x,b.y,b.z); 35 | } 36 | friend bool operator != (const hurwitz_quaternion& a, const hurwitz_quaternion& b) { return !(a == b); } 37 | 38 | num real_doubled() const { 39 | return s; 40 | } 41 | num real() const { 42 | assert(!(s & 1)); 43 | return s >> 1; 44 | } 45 | std::array 
imag_doubled() const { 46 | return {x, y, z}; 47 | } 48 | std::array imag() const { 49 | assert(!(s & 1)); 50 | return {x>>1, y>>1, z>>1}; 51 | } 52 | std::array coords_doubled() const { 53 | return {s, x, y, z}; 54 | } 55 | std::array coords() const { 56 | assert(!(s & 1)); 57 | return {s>>1, x>>1, y>>1, z>>1}; 58 | } 59 | 60 | friend num norm(const hurwitz_quaternion& q) { 61 | return (q.s * q.s + q.x * q.x + q.y * q.y + q.z * q.z) >> 2; 62 | } 63 | friend hurwitz_quaternion conj(const hurwitz_quaternion& q) { 64 | return hurwitz_quaternion(doubled_coords_tag{}, q.s, -q.x, -q.y, -q.z); 65 | } 66 | 67 | friend hurwitz_quaternion operator + (const hurwitz_quaternion& q) { 68 | return hurwitz_quaternion(doubled_coords_tag{}, +q.s, +q.x, +q.y, +q.z); 69 | } 70 | friend hurwitz_quaternion operator - (const hurwitz_quaternion& q) { 71 | return hurwitz_quaternion(doubled_coords_tag{}, -q.s, -q.x, -q.y, -q.z); 72 | } 73 | 74 | hurwitz_quaternion& operator += (const hurwitz_quaternion& o) { 75 | s += o.s; 76 | x += o.x; 77 | y += o.y; 78 | z += o.z; 79 | return *this; 80 | } 81 | friend hurwitz_quaternion operator + (const hurwitz_quaternion& a, const hurwitz_quaternion& b) { 82 | return hurwitz_quaternion(doubled_coords_tag{}, a.s + b.s, a.x + b.x, a.y + b.y, a.z + b.z); 83 | } 84 | hurwitz_quaternion& operator -= (const hurwitz_quaternion& o) { 85 | s -= o.s; 86 | x -= o.x; 87 | y -= o.y; 88 | z -= o.z; 89 | return *this; 90 | } 91 | friend hurwitz_quaternion operator - (const hurwitz_quaternion& a, const hurwitz_quaternion& b) { 92 | return hurwitz_quaternion(doubled_coords_tag{}, a.s - b.s, a.x - b.x, a.y - b.y, a.z - b.z); 93 | } 94 | 95 | friend hurwitz_quaternion operator * (const num& a, const hurwitz_quaternion& q) { 96 | return hurwitz_quaternion(doubled_coords_tag{}, a*q.s, a*q.x, a*q.y, a*q.z); 97 | } 98 | friend hurwitz_quaternion operator * (const hurwitz_quaternion& q, const num& a) { 99 | return hurwitz_quaternion(doubled_coords_tag{}, q.s*a, q.x*a, q.y*a, 
q.z*a); 100 | } 101 | hurwitz_quaternion& operator *= (const num& a) { 102 | s *= a; 103 | x *= a; 104 | y *= a; 105 | z *= a; 106 | return *this; 107 | } 108 | 109 | friend hurwitz_quaternion operator * (const hurwitz_quaternion& a, const hurwitz_quaternion& b) { 110 | return hurwitz_quaternion( 111 | doubled_coords_tag{}, 112 | (a.s * b.s - a.x * b.x - a.y * b.y - a.z * b.z) >> 1, 113 | (a.s * b.x + a.x * b.s + a.y * b.z - a.z * b.y) >> 1, 114 | (a.s * b.y + a.y * b.s + a.z * b.x - a.x * b.z) >> 1, 115 | (a.s * b.z + a.z * b.s + a.x * b.y - a.y * b.x) >> 1 116 | ); 117 | } 118 | hurwitz_quaternion& operator *= (const hurwitz_quaternion& o) { 119 | return *this = *this * o; 120 | } 121 | 122 | struct div_t { 123 | hurwitz_quaternion quot, rem; 124 | }; 125 | // a = b * quot + rem 126 | friend div_t right_div(const hurwitz_quaternion& a, const hurwitz_quaternion& b) { 127 | hurwitz_quaternion numer = conj(b) * a; 128 | num denom = norm(b); 129 | 130 | auto floor_div = [](num u, num v) -> num { 131 | if ((u^v) >= 0) { 132 | return u/v; 133 | } else { 134 | auto res = std::div(u, v); 135 | return res.quot - bool(res.rem); 136 | } 137 | }; 138 | num s = floor_div(numer.s, denom); 139 | num x = floor_div(numer.x, denom); 140 | num y = floor_div(numer.y, denom); 141 | num z = floor_div(numer.z, denom); 142 | 143 | hurwitz_quaternion q_odd(doubled_coords_tag{}, s | 1, x | 1, y | 1, z | 1); 144 | hurwitz_quaternion r_odd = a - b * q_odd; 145 | hurwitz_quaternion q_even(doubled_coords_tag{}, (s+1)&~num(1), (x+1)&~num(1), (y+1)&~num(1), (z+1)&~num(1)); 146 | hurwitz_quaternion r_even = a - b * q_even; 147 | div_t res = norm(r_odd) < norm(r_even) ? 
// Static range-minimum structure over a fixed sequence.
//
// The sequence is cut into buckets of BUCKET_SIZE elements:
//   * queries inside a single bucket are answered from a per-position bitmask
//     (`range_mask`) describing a monotone stack of candidate minima;
//   * queries spanning several buckets combine a bucket-suffix minimum, a
//     bucket-prefix minimum and a sparse table over whole-bucket minima.
//
// T must be default-constructible and copy-assignable. Compare is a strict
// weak ordering functor (inherited privately so that empty comparators take no
// space); `Compare(a, b)` means "a is preferred over b".
template <typename T, typename Compare = std::less<T>>
class RangeMinQuery : private Compare {
	static constexpr int BUCKET_SIZE = 32;
	static constexpr int BUCKET_SIZE_LOG = 5;
	static_assert(BUCKET_SIZE == (1 << BUCKET_SIZE_LOG), "BUCKET_SIZE should be a power of 2");
	static constexpr int CACHE_LINE_ALIGNMENT = 64;  // currently unused

	int n = 0;
	std::vector<T> data;          // copy of the input
	std::vector<T> pref_data;     // min of data[bucket_start .. i]
	std::vector<T> suff_data;     // min of data[i .. bucket_end]
	std::vector<T> sparse_table;  // level-major sparse table over full buckets
	std::vector<uint32_t> range_mask;  // bit k set <=> bucket offset k is on the stack at i

private:
	// Number of *complete* buckets; a trailing partial bucket is only ever
	// reached through pref_data / range_mask.
	int num_buckets() const {
		return n >> BUCKET_SIZE_LOG;
	}
	int num_levels() const {
		return num_buckets() ? 32 - __builtin_clz(num_buckets()) : 0;
	}
	int sparse_table_size() const {
		return num_buckets() * num_levels();
	}

private:
	// Returns b on ties.
	const T& min(const T& a, const T& b) const {
		return Compare::operator()(a, b) ? a : b;
	}
	// Replaces a by b only when b is strictly preferred.
	void setmin(T& a, const T& b) const {
		if (Compare::operator()(b, a)) a = b;
	}

	template <typename Vec> static int get_size(const Vec& v) { using std::size; return int(size(v)); }

public:
	RangeMinQuery() {}

	// Builds the structure from any container supporting size() and
	// operator[] with int indices.
	template <typename Vec> explicit RangeMinQuery(const Vec& data_, const Compare& comp_ = Compare())
		: Compare(comp_)
		, n(get_size(data_))
		, data(n)
		, pref_data(n)
		, suff_data(n)
		, sparse_table(sparse_table_size())
		, range_mask(n)
	{
		for (int i = 0; i < n; i++) data[i] = data_[i];

		// Monotone stack per bucket, encoded as a bitmask of bucket offsets.
		for (int i = 0; i < n; i++) {
			if (i & (BUCKET_SIZE-1)) {
				uint32_t m = range_mask[i-1];
				// Pop every stacked position whose value is not strictly
				// preferred over data[i]. The highest set bit of m is the top
				// of the stack; (i | 31) - clz(m) is its absolute index.
				while (m && !Compare::operator()(data[(i | (BUCKET_SIZE-1)) - __builtin_clz(m)], data[i])) {
					m -= uint32_t(1) << (BUCKET_SIZE - 1 - __builtin_clz(m));
				}
				m |= uint32_t(1) << (i & (BUCKET_SIZE - 1));
				range_mask[i] = m;
			} else {
				// First slot of a bucket: the stack holds just this position.
				range_mask[i] = 1;
			}
		}

		// Prefix minima, restarted at every bucket boundary.
		for (int i = 0; i < n; i++) {
			pref_data[i] = data[i];
			if (i & (BUCKET_SIZE-1)) {
				setmin(pref_data[i], pref_data[i-1]);
			}
		}

		// Suffix minima, restarted at every bucket boundary.
		for (int i = n-1; i >= 0; i--) {
			suff_data[i] = data[i];
			if (i+1 < n && ((i+1) & (BUCKET_SIZE-1))) {
				setmin(suff_data[i], suff_data[i+1]);
			}
		}

		// Level 0 of the sparse table: minimum of each complete bucket.
		const int nb = num_buckets();
		for (int i = 0; i < nb; i++) {
			sparse_table[i] = data[i * BUCKET_SIZE];
			for (int v = 1; v < BUCKET_SIZE; v++) {
				setmin(sparse_table[i], data[i * BUCKET_SIZE + v]);
			}
		}

		// Level l+1 covers 2^(l+1) consecutive buckets.
		for (int l = 0; l+1 < num_levels(); l++) {
			for (int i = 0; i + (1 << (l+1)) <= nb; i++) {
				sparse_table[(l+1) * nb + i] = min(sparse_table[l * nb + i], sparse_table[l * nb + i + (1 << l)]);
			}
		}
	}

	// Returns the minimum (w.r.t. Compare) of data[l..r], both ends inclusive.
	// Requires 0 <= l <= r < n.
	T query(int l, int r) const {
		assert(0 <= l && l <= r && r < n);
		int bucket_l = (l >> BUCKET_SIZE_LOG);
		int bucket_r = (r >> BUCKET_SIZE_LOG);
		if (bucket_l == bucket_r) {
			// Drop stack entries left of l; the lowest surviving one is the answer.
			uint32_t msk = range_mask[r] & ~((uint32_t(1) << (l & (BUCKET_SIZE-1))) - 1);
			int ind = (l & ~(BUCKET_SIZE-1)) + __builtin_ctz(msk);
			return data[ind];
		} else {
			T ans = min(suff_data[l], pref_data[r]);
			bucket_l++;
			if (bucket_l < bucket_r) {
				// Cover the whole buckets [bucket_l, bucket_r) with two
				// overlapping power-of-two windows.
				int level = (32 - __builtin_clz(bucket_r - bucket_l)) - 1;
				setmin(ans, sparse_table[level * num_buckets() + bucket_l]);
				setmin(ans, sparse_table[level * num_buckets() + bucket_r - (1 << level)]);
			}
			return ans;
		}
	}
};

// Same structure with the ordering reversed, i.e. range maximum.
template <typename T> using RangeMaxQuery = RangeMinQuery<T, std::greater<T>>;
// Index arithmetic for iterative (bottom-up) segment trees stored in a flat
// array where node 1 is the root and node v has children 2v and 2v+1.
namespace seg_tree {

// Floor of log_2(a); index of highest 1-bit. Returns -1 for a == 0.
inline int floor_log_2(int a) {
	// __builtin_clz is undefined for 0, hence the explicit guard.
	return a ? int(8 * sizeof(a)) - 1 - __builtin_clz(a) : -1;
}

// Ceiling of log_2(a) for a >= 1; returns -1 for a == 0.
// Written as floor_log_2(a-1)+1 so that it does not overflow for a > 2^30.
inline int ceil_log_2(int a) {
	return a ? floor_log_2(a - 1) + 1 : -1;
}

// Smallest power of two that is >= a. Non-positive a yields 1 (there is no
// negative shift). The result is only representable for a <= 2^30.
inline int next_pow_2(int a) {
	return a > 0 ? 1 << ceil_log_2(a) : 1;
}

// A single node of the tree, identified by its array index.
struct point {
	int a;
	point() : a(0) {}
	// -1 is tolerated so that `point(N-1)` works as a loop start when N == 0.
	explicit point(int a_) : a(a_) { assert(a >= -1); }

	explicit operator bool () const { return bool(a); }

	// This is useful so you can directly do array indices
	/* implicit */ operator int() const { return a; }

	// Child: z == 0 is the left child, z == 1 the right child.
	point c(bool z) const {
		return point((a<<1)|z);
	}

	point operator [] (bool z) const {
		return c(z);
	}

	// Parent.
	point p() const {
		return point(a>>1);
	}

	friend std::ostream& operator << (std::ostream& o, const point& p) { return o << int(p); }

	// Visits this node and then each ancestor, bottom-up.
	template <typename F> void for_each(F f) const {
		for (int v = a; v > 0; v >>= 1) {
			f(point(v));
		}
	}

	// Visits the root, then each node on the path down to and including this one.
	template <typename F> void for_each_down(F f) const {
		// strictly greater than 0
		for (int L = floor_log_2(a); L >= 0; L--) {
			f(point(a >> L));
		}
	}

	// Same traversal as for_each: this node first, then its ancestors.
	template <typename F> void for_each_up(F f) const {
		for (int v = a; v > 0; v >>= 1) {
			f(point(v));
		}
	}

	// Visits the strict ancestors only, root first.
	template <typename F> void for_parents_down(F f) const {
		// strictly greater than 0
		for (int L = floor_log_2(a); L > 0; L--) {
			f(point(a >> L));
		}
	}

	// Visits the strict ancestors only, parent first.
	template <typename F> void for_parents_up(F f) const {
		for (int v = a >> 1; v > 0; v >>= 1) {
			f(point(v));
		}
	}

	point& operator ++ () { ++a; return *this; }
	point operator ++ (int) { return point(a++); }
	point& operator -- () { --a; return *this; }
	point operator -- (int) { return point(a--); }
};

// A half-open range [a, b) of leaf-level node indices.
struct range {
	int a, b;
	range() : a(1), b(1) {}
	range(int a_, int b_) : a(a_), b(b_) {
		assert(1 <= a && a <= b && b <= 2 * a);
	}
	explicit range(std::array<int, 2> r) : range(r[0], r[1]) {}

	explicit operator std::array<int, 2>() const {
		return {a,b};
	}

	const int& operator[] (bool z) const {
		return z ? b : a;
	}

	friend std::ostream& operator << (std::ostream& o, const range& r) { return o << "[" << r.a << ".." << r.b << ")"; }

	// Iterate over the range from outside-in.
	// Calls f(point a)
	template <typename F> void for_each(F f) const {
		for (int x = a, y = b; x < y; x >>= 1, y >>= 1) {
			if (x & 1) f(point(x++));
			if (y & 1) f(point(--y));
		}
	}

	// Iterate over the range from outside-in.
	// Calls f(point a, bool is_right)
	template <typename F> void for_each_with_side(F f) const {
		for (int x = a, y = b; x < y; x >>= 1, y >>= 1) {
			if (x & 1) f(point(x++), false);
			if (y & 1) f(point(--y), true);
		}
	}

	// Iterate over the range from left to right.
	// Calls f(point)
	template <typename F> void for_each_l_to_r(F f) const {
		int anc_depth = floor_log_2((a-1) ^ b);
		int anc_msk = (1 << anc_depth) - 1;
		// Left boundary: lowest set bit first.
		for (int v = (-a) & anc_msk; v; v &= v-1) {
			int i = __builtin_ctz(v);
			f(point(((a-1) >> i) + 1));
		}
		// Right boundary: highest set bit first.
		for (int v = b & anc_msk; v; ) {
			int i = floor_log_2(v);
			f(point((b >> i) - 1));
			v ^= (1 << i);
		}
	}

	// Iterate over the range from right to left.
	// Calls f(point)
	template <typename F> void for_each_r_to_l(F f) const {
		int anc_depth = floor_log_2((a-1) ^ b);
		int anc_msk = (1 << anc_depth) - 1;
		for (int v = b & anc_msk; v; v &= v-1) {
			int i = __builtin_ctz(v);
			f(point((b >> i) - 1));
		}
		for (int v = (-a) & anc_msk; v; ) {
			int i = floor_log_2(v);
			f(point(((a-1) >> i) + 1));
			v ^= (1 << i);
		}
	}

	// Calls f(point) on ancestors of the range's boundary nodes, top-down.
	// NOTE(review): exact visiting set follows from the index arithmetic
	// below; kept verbatim.
	template <typename F> void for_parents_down(F f) const {
		int x = a, y = b;
		if ((x ^ y) > x) { x <<= 1, std::swap(x, y); }
		int dx = __builtin_ctz(x);
		int dy = __builtin_ctz(y);
		int anc_depth = floor_log_2((x-1) ^ y);
		for (int i = floor_log_2(x); i > dx; i--) {
			f(point(x >> i));
		}
		for (int i = anc_depth; i > dy; i--) {
			f(point(y >> i));
		}
	}

	// Bottom-up counterpart of for_parents_down.
	template <typename F> void for_parents_up(F f) const {
		int x = a, y = b;
		if ((x ^ y) > x) { x <<= 1, std::swap(x, y); }
		int dx = __builtin_ctz(x);
		int dy = __builtin_ctz(y);
		int anc_depth = floor_log_2((x-1) ^ y);
		for (int i = dx+1; i <= anc_depth; i++) {
			f(point(x >> i));
		}
		for (int v = y >> (dy+1); v; v >>= 1) {
			f(point(v));
		}
	}
};

// Layout in which every node covers a contiguous, non-wrapping interval of
// leaf indices. Uses node indices 1 .. 2N-1.
struct in_order_layout {
	// Alias them in for convenience
	using point = seg_tree::point;
	using range = seg_tree::range;

	int N, S;  // N = number of leaves, S = next_pow_2(N) (0 when N == 0)
	in_order_layout() : N(0), S(0) {}
	in_order_layout(int N_) : N(N_), S(N ? next_pow_2(N) : 0) {}

	// Node holding leaf a.
	point get_point(int a) const {
		assert(0 <= a && a < N);
		a += S;
		return point(a >= 2 * N ? a - N : a);
	}

	// Node range for the half-open leaf interval [a, b).
	range get_range(int a, int b) const {
		assert(0 <= a && a <= b && b <= N);
		if (N == 0) return range();
		a += S, b += S;
		return range((a >= 2 * N ? 2*(a-N) : a), (b >= 2 * N ? 2*(b-N) : b));
	}

	range get_range(std::array<int, 2> p) const {
		return get_range(p[0], p[1]);
	}

	// Inverse of get_point.
	int get_leaf_index(point pt) const {
		int a = int(pt);
		assert(N <= a && a < 2 * N);
		return (a < S ? a + N : a) - S;
	}

	// Half-open leaf interval {l, r} covered by the node.
	std::array<int, 2> get_node_bounds(point pt) const {
		int a = int(pt);
		assert(1 <= a && a < 2 * N);
		int l = __builtin_clz(a) - __builtin_clz(2*N-1);
		int x = a << l, y = (a+1) << l;
		assert(S <= x && x < y && y <= 2*S);
		return {(x >= 2 * N ? (x>>1) + N : x) - S, (y >= 2 * N ? (y>>1) + N : y) - S};
	}

	// First leaf index of the right child of an internal node.
	int get_node_split(point pt) const {
		int a = int(pt);
		assert(1 <= a && a < N);
		int l = __builtin_clz(2*a+1) - __builtin_clz(2*N-1);
		int x = (2*a+1) << l;
		assert(S <= x && x < 2*S);
		return (x >= 2 * N ? (x>>1) + N : x) - S;
	}

	int get_node_size(point pt) const {
		auto bounds = get_node_bounds(pt);
		return bounds[1] - bounds[0];
	}
};

// Layout with leaf i at node N + i; internal nodes may cover an interval that
// wraps around from N-1 back to 0.
struct circular_layout {
	// Alias them in for convenience
	using point = seg_tree::point;
	using range = seg_tree::range;

	int N;
	circular_layout() : N(0) {}
	circular_layout(int N_) : N(N_) {}

	point get_point(int a) const {
		assert(0 <= a && a < N);
		return point(N + a);
	}

	range get_range(int a, int b) const {
		assert(0 <= a && a <= b && b <= N);
		if (N == 0) return range();
		return range(N + a, N + b);
	}

	range get_range(std::array<int, 2> p) const {
		return get_range(p[0], p[1]);
	}

	int get_leaf_index(point pt) const {
		int a = int(pt);
		assert(N <= a && a < 2 * N);
		return a - N;
	}

	// Returns {x,y} so that 0 <= x < N and 1 <= y <= N
	// If the point is non-wrapping, then 0 <= x < y <= N
	std::array<int, 2> get_node_bounds(point pt) const {
		int a = int(pt);
		assert(1 <= a && a < 2 * N);
		int l = __builtin_clz(a) - __builtin_clz(2*N-1);
		int S = next_pow_2(N);
		int x = a << l, y = (a+1) << l;
		assert(S <= x && x < y && y <= 2*S);
		// Note the asymmetry (>= for x, > for y): y == 2N must map to N, not 0.
		return {(x >= 2 * N ? x >> 1 : x) - N, (y > 2 * N ? y >> 1 : y) - N};
	}

	// Returns the split point of the node, such that 1 <= s <= N.
	int get_node_split(point pt) const {
		int a = int(pt);
		assert(1 <= a && a < N);
		return get_node_bounds(pt.c(0))[1];
	}

	// Number of leaves under the node; a non-positive difference means the
	// interval wraps, so N is added back.
	int get_node_size(point pt) const {
		auto bounds = get_node_bounds(pt);
		int r = bounds[1] - bounds[0];
		return r > 0 ? r : r + N;
	}
};

} // namespace seg_tree
(int l = 0; l <= N; l++) { 29 | for (int r = l; r <= N; r++) { 30 | auto rng = seg.get_range(l, r); 31 | 32 | { 33 | int x = l, y = r; 34 | rng.for_each([&](auto a) { 35 | auto bounds = seg.get_node_bounds(a); 36 | if (x == bounds[0]) { 37 | x = bounds[1]; 38 | } else if (y == bounds[1]) { 39 | y = bounds[0]; 40 | } else assert(false); 41 | }); 42 | REQUIRE(x == y); 43 | } 44 | { 45 | int x = l, y = r; 46 | rng.for_each_with_side([&](auto a, bool d) { 47 | auto bounds = seg.get_node_bounds(a); 48 | if (d == 0) { 49 | REQUIRE(x == bounds[0]); 50 | x = bounds[1]; 51 | } else if (d == 1) { 52 | REQUIRE(y == bounds[1]); 53 | y = bounds[0]; 54 | } else assert(false); 55 | }); 56 | REQUIRE(x == y); 57 | } 58 | { 59 | int x = l; 60 | rng.for_each_l_to_r([&](auto a) { 61 | auto bounds = seg.get_node_bounds(a); 62 | REQUIRE(x == bounds[0]); 63 | x = bounds[1]; 64 | }); 65 | REQUIRE(x == r); 66 | } 67 | { 68 | int y = r; 69 | rng.for_each_r_to_l([&](auto a) { 70 | auto bounds = seg.get_node_bounds(a); 71 | REQUIRE(y == bounds[1]); 72 | y = bounds[0]; 73 | }); 74 | REQUIRE(y == l); 75 | } 76 | } 77 | } 78 | } 79 | } 80 | -------------------------------------------------------------------------------- /src/smawk.hpp: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include 4 | #include 5 | #include 6 | #if __cpp_concepts >= 202002 7 | #include 8 | #endif 9 | 10 | namespace smawk { 11 | 12 | template struct value_t { 13 | T v; 14 | int col; 15 | }; 16 | 17 | // Get(int row, int col) -> T 18 | // Select(int row, const value_t& opt_0, const value_t& opt_1) returns 0 or 1 for which is better 19 | #if __cpp_concepts >= 202002 20 | template concept totally_monotone_matrix_oracle = 21 | std::default_initializable && std::movable 22 | && std::invocable && std::convertible_to, T> 23 | && std::predicate&, const value_t&>; 24 | #endif 25 | 26 | 27 | template > 28 | #if __cpp_concepts >= 202002 29 | requires 
totally_monotone_matrix_oracle 30 | #endif 31 | class LARSCH { 32 | public: 33 | int N; 34 | Get get; 35 | Select select; 36 | int L; 37 | int num_rows; 38 | 39 | std::vector>> stk; 40 | std::vector, int>> bests; 41 | LARSCH() {} 42 | LARSCH(int N_, Get&& get_, Select&& select_) : N(N_), get(std::forward(get_)), select(std::forward