├── data └── CodeForces.md ├── src ├── dp │ ├── circle_dp │ │ ├── template.py │ │ └── example.py │ ├── prob_dp │ │ ├── template.py │ │ └── example.py │ ├── sos_dp │ │ ├── template.py │ │ └── example.py │ ├── state_dp │ │ ├── template.py │ │ └── example.py │ ├── interval_dp │ │ ├── template.py │ │ └── example.py │ ├── outline_dp │ │ ├── template.py │ │ └── example.py │ ├── game_dp │ │ ├── example.py │ │ └── template.py │ ├── linear_dp │ │ ├── example.py │ │ └── template.py │ ├── matrix_dp │ │ ├── example.py │ │ └── template.py │ ├── bag_dp │ │ └── example.py │ └── digital_dp │ │ └── example.py ├── graph │ ├── bfs │ │ ├── template.py │ │ └── example.py │ ├── network_flow │ │ └── example.py │ ├── two_sat │ │ ├── template.py │ │ └── example.py │ ├── spfa │ │ └── example.py │ ├── dijkstra │ │ └── example.py │ ├── bipartite_matching │ │ ├── example.py │ │ └── template.py │ ├── binary_search_tree │ │ ├── example.py │ │ ├── template.py │ │ └── problem.py │ ├── floyd │ │ └── example.py │ ├── minimum_spanning_tree │ │ └── example.py │ ├── tarjan │ │ └── example.py │ ├── topological_sort │ │ └── example.py │ ├── prufer │ │ ├── problem.py │ │ ├── example.py │ │ └── template.py │ ├── euler_path │ │ └── example.py │ ├── dfs │ │ └── example.py │ └── union_find │ │ └── example.py ├── basis │ ├── brute_force │ │ ├── template.py │ │ └── example.py │ ├── interactive │ │ ├── template.py │ │ └── example.py │ ├── observation │ │ ├── template.py │ │ └── example.py │ ├── recursion │ │ ├── template.py │ │ └── example.py │ ├── construction │ │ ├── template.py │ │ └── example.py │ ├── meet_in_middle │ │ ├── template.py │ │ └── example.py │ ├── offline_query │ │ ├── template.py │ │ ├── example.py │ │ └── problem.py │ ├── hash │ │ ├── template.py │ │ └── example.py │ ├── performance │ │ ├── template.py │ │ ├── example.py │ │ └── problem.py │ ├── stack │ │ ├── example.py │ │ └── template.py │ ├── tree_node │ │ ├── example.py │ │ └── template.py │ ├── md_vector │ │ ├── example.py │ │ ├── 
template.py │ │ └── problem.py │ ├── serialization │ │ ├── example.py │ │ ├── problem.py │ │ └── template.py │ ├── circular_section │ │ ├── example.py │ │ └── template.py │ ├── range │ │ ├── example.py │ │ └── template.py │ ├── various_sort │ │ └── example.py │ ├── ternary_search │ │ └── example.py │ ├── binary_search │ │ ├── example.py │ │ └── template.py │ ├── date │ │ ├── example.py │ │ └── problem.py │ ├── two_pointers │ │ └── example.py │ ├── implemention │ │ └── example.py │ └── diff_array │ │ └── example.py ├── math │ ├── mex_like │ │ ├── template.py │ │ └── example.py │ ├── random_like │ │ ├── template.py │ │ └── example.py │ ├── partition_fft │ │ ├── template.py │ │ ├── example.py │ │ └── problem.py │ ├── scan_line │ │ ├── example.py │ │ ├── problem.py │ │ └── template.py │ ├── convex_hull │ │ ├── example.py │ │ ├── template.py │ │ └── problem.py │ ├── cantor_expands │ │ ├── example.py │ │ └── template.py │ ├── peishu_theorem │ │ ├── template.py │ │ ├── example.py │ │ └── problem.py │ ├── nim_game │ │ ├── template.py │ │ ├── example.py │ │ └── problem.py │ ├── extend_crt │ │ ├── example.py │ │ ├── template.py │ │ └── problem.py │ ├── geometry │ │ └── example.py │ ├── comb_perm │ │ ├── example.py │ │ └── template.py │ ├── fast_power │ │ ├── example.py │ │ └── template.py │ ├── high_precision │ │ └── example.py │ ├── prime_factor │ │ └── example.py │ ├── bit_operation │ │ ├── example.py │ │ └── template.py │ ├── linear_basis │ │ ├── example.py │ │ └── template.py │ ├── lexico_graphical_order │ │ └── example.py │ └── gcd_like │ │ ├── template.py │ │ └── example.py ├── struct │ ├── linked_list │ │ ├── template.py │ │ └── example.py │ ├── bit_set │ │ ├── example.py │ │ └── template.py │ ├── list_node │ │ ├── example.py │ │ ├── template.py │ │ └── problem.py │ ├── associative_array │ │ ├── example.py │ │ └── template.py │ ├── sqrt_decomposition │ │ ├── example.py │ │ └── template.py │ ├── monotonic_queue │ │ ├── example.py │ │ └── template.py │ ├── 
priority_queue │ │ └── example.py │ ├── sorted_list │ │ └── example.py │ ├── monotonic_stack │ │ └── example.py │ ├── sparse_table │ │ └── example.py │ └── trie_like │ │ └── example.py ├── tree │ ├── tree_diff_array │ │ ├── example.py │ │ └── template.py │ ├── tree_diameter │ │ ├── example.py │ │ └── template.py │ └── tree_lca │ │ └── example.py ├── string │ ├── automaton │ │ ├── example.py │ │ └── template.py │ ├── suffix_array │ │ ├── example.py │ │ └── template.py │ ├── expression │ │ ├── problem.py │ │ └── example.py │ ├── palindrome_num │ │ ├── example.py │ │ └── template.py │ ├── lyndon_decomposition │ │ ├── example.py │ │ └── template.py │ ├── kmp │ │ ├── example.py │ │ └── template.py │ ├── manacher_palindrome │ │ └── example.py │ └── string_hash │ │ └── example.py ├── greedy │ ├── brain_storming │ │ ├── example.py │ │ └── template.py │ └── longest_increasing_subsequence │ │ └── example.py └── util │ └── read_file.py ├── pytest.ini ├── tests ├── leetcode │ ├── template.py │ ├── simple.py │ ├── problem_1.py │ ├── problem_2.py │ ├── problem_3.py │ └── problem_4.py └── codeforces │ ├── template.py │ ├── simple.py │ ├── problem_a.py │ ├── problem_b.py │ ├── problem_c.py │ └── problem_d.py └── README.md /data/CodeForces.md: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /src/dp/circle_dp/template.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /src/dp/prob_dp/template.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /src/dp/sos_dp/template.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- 
/src/dp/state_dp/template.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /src/graph/bfs/template.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /src/basis/brute_force/template.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /src/basis/interactive/template.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /src/basis/observation/template.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /src/basis/recursion/template.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /src/dp/interval_dp/template.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /src/dp/outline_dp/template.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /src/graph/network_flow/example.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /src/graph/two_sat/template.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- 
/src/math/mex_like/template.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /src/math/random_like/template.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /src/basis/construction/template.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /src/basis/meet_in_middle/template.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /src/basis/offline_query/template.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /src/math/partition_fft/template.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /src/struct/linked_list/template.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /src/tree/tree_diff_array/example.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /src/basis/hash/template.py: -------------------------------------------------------------------------------- 1 | class HashMap: 2 | def __init__(self): 3 | return 4 | 5 | def gen_result(self): 6 | return 7 | -------------------------------------------------------------------------------- /src/basis/performance/template.py: 
-------------------------------------------------------------------------------- 1 | class Performance: 2 | def __init__(self): 3 | return 4 | 5 | def gen_result(self): 6 | return 7 | -------------------------------------------------------------------------------- /pytest.ini: -------------------------------------------------------------------------------- 1 | [pytest] 2 | log_cli = 1 3 | log_cli_level = INFO 4 | log_cli_format = %(asctime) s [%(levelname) 8s] %(message) s (%(filename) s:%(lineno) s) 5 | log_cli_date_format=%Y-%m-%d %H:%M:%S -------------------------------------------------------------------------------- /src/graph/spfa/example.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | 3 | 4 | class TestGeneral(unittest.TestCase): 5 | 6 | def test_spfa(self): 7 | return 8 | 9 | 10 | if __name__ == '__main__': 11 | unittest.main() 12 | -------------------------------------------------------------------------------- /src/graph/dijkstra/example.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | 3 | 4 | class TestGeneral(unittest.TestCase): 5 | 6 | def test_dijkstra(self): 7 | return 8 | 9 | 10 | if __name__ == '__main__': 11 | unittest.main() 12 | -------------------------------------------------------------------------------- /src/graph/bfs/example.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | 3 | 4 | class TestGeneral(unittest.TestCase): 5 | 6 | def test_xxx(self): 7 | pass 8 | return 9 | 10 | 11 | if __name__ == '__main__': 12 | unittest.main() 13 | -------------------------------------------------------------------------------- /src/graph/bipartite_matching/example.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | 3 | 4 | class TestGeneral(unittest.TestCase): 5 | 6 | def test_km(self): 7 | return 8 | 9 | 10 | if __name__ 
== '__main__': 11 | unittest.main() 12 | -------------------------------------------------------------------------------- /src/string/automaton/example.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | 3 | 4 | class TestGeneral(unittest.TestCase): 5 | 6 | def test_AhoCorasick(self): 7 | return 8 | 9 | 10 | if __name__ == '__main__': 11 | unittest.main() 12 | -------------------------------------------------------------------------------- /src/basis/stack/example.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | 3 | 4 | class TestGeneral(unittest.TestCase): 5 | 6 | def test_stack(self): 7 | pass 8 | return 9 | 10 | 11 | if __name__ == '__main__': 12 | unittest.main() 13 | -------------------------------------------------------------------------------- /src/dp/game_dp/example.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | 3 | 4 | class TestGeneral(unittest.TestCase): 5 | 6 | def test_game_dp(self): 7 | pass 8 | return 9 | 10 | 11 | if __name__ == '__main__': 12 | unittest.main() 13 | -------------------------------------------------------------------------------- /src/dp/prob_dp/example.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | 3 | 4 | class TestGeneral(unittest.TestCase): 5 | 6 | def test_xxx(self): 7 | pass 8 | return 9 | 10 | 11 | if __name__ == '__main__': 12 | unittest.main() 13 | -------------------------------------------------------------------------------- /src/graph/binary_search_tree/example.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | 3 | 4 | class TestGeneral(unittest.TestCase): 5 | 6 | def test_solution(self): 7 | return 8 | 9 | 10 | if __name__ == '__main__': 11 | unittest.main() 12 | 
-------------------------------------------------------------------------------- /src/math/scan_line/example.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | 3 | 4 | class TestGeneral(unittest.TestCase): 5 | def test_euler_phi(self): 6 | pass 7 | return 8 | 9 | 10 | if __name__ == '__main__': 11 | unittest.main() 12 | -------------------------------------------------------------------------------- /src/basis/hash/example.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | 3 | 4 | class TestGeneral(unittest.TestCase): 5 | 6 | def test_hash_map(self): 7 | pass 8 | return 9 | 10 | 11 | if __name__ == '__main__': 12 | unittest.main() 13 | -------------------------------------------------------------------------------- /src/basis/interactive/example.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | 3 | 4 | class TestGeneral(unittest.TestCase): 5 | 6 | def test_xxxx(self): 7 | pass 8 | return 9 | 10 | 11 | if __name__ == '__main__': 12 | unittest.main() 13 | -------------------------------------------------------------------------------- /src/basis/tree_node/example.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | 3 | 4 | class TestGeneral(unittest.TestCase): 5 | 6 | def test_xxxx(self): 7 | pass 8 | return 9 | 10 | 11 | if __name__ == '__main__': 12 | unittest.main() 13 | -------------------------------------------------------------------------------- /src/dp/circle_dp/example.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | 3 | 4 | class TestGeneral(unittest.TestCase): 5 | 6 | def test_circle_dp(self): 7 | pass 8 | return 9 | 10 | 11 | if __name__ == '__main__': 12 | unittest.main() 13 | -------------------------------------------------------------------------------- 
/src/dp/sos_dp/example.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | 3 | 4 | class TestGeneral(unittest.TestCase): 5 | 6 | def test_outline_dp(self): 7 | pass 8 | return 9 | 10 | 11 | if __name__ == '__main__': 12 | unittest.main() 13 | -------------------------------------------------------------------------------- /src/dp/state_dp/example.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | 3 | 4 | class TestGeneral(unittest.TestCase): 5 | 6 | def test_state_dp(self): 7 | pass 8 | return 9 | 10 | 11 | if __name__ == '__main__': 12 | unittest.main() 13 | -------------------------------------------------------------------------------- /src/graph/floyd/example.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | 3 | 4 | class TestGeneral(unittest.TestCase): 5 | 6 | def test_solution(self): 7 | pass 8 | return 9 | 10 | 11 | if __name__ == '__main__': 12 | unittest.main() 13 | -------------------------------------------------------------------------------- /src/math/mex_like/example.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | 3 | 4 | class TestGeneral(unittest.TestCase): 5 | 6 | def test_mex_like(self): 7 | pass 8 | return 9 | 10 | 11 | if __name__ == '__main__': 12 | unittest.main() 13 | -------------------------------------------------------------------------------- /src/math/random_like/example.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | 3 | 4 | class TestGeneral(unittest.TestCase): 5 | def test_euler_phi(self): 6 | pass 7 | return 8 | 9 | 10 | if __name__ == '__main__': 11 | unittest.main() 12 | -------------------------------------------------------------------------------- /src/struct/bit_set/example.py: 
-------------------------------------------------------------------------------- 1 | import unittest 2 | 3 | 4 | class TestGeneral(unittest.TestCase): 5 | 6 | def test_xxxx(self): 7 | pass 8 | return 9 | 10 | 11 | if __name__ == '__main__': 12 | unittest.main() 13 | -------------------------------------------------------------------------------- /src/struct/linked_list/example.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | 3 | 4 | class TestGeneral(unittest.TestCase): 5 | 6 | def test_xx(self): 7 | pass 8 | return 9 | 10 | 11 | if __name__ == '__main__': 12 | unittest.main() 13 | -------------------------------------------------------------------------------- /src/struct/list_node/example.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | 3 | 4 | class TestGeneral(unittest.TestCase): 5 | 6 | def test_xxxx(self): 7 | pass 8 | return 9 | 10 | 11 | if __name__ == '__main__': 12 | unittest.main() 13 | -------------------------------------------------------------------------------- /src/tree/tree_diameter/example.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | 3 | 4 | class TestGeneral(unittest.TestCase): 5 | 6 | def test_xxx(self): 7 | pass 8 | return 9 | 10 | 11 | if __name__ == '__main__': 12 | unittest.main() 13 | -------------------------------------------------------------------------------- /src/basis/md_vector/example.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | 3 | 4 | class TestGeneral(unittest.TestCase): 5 | 6 | def test_md_vector(self): 7 | pass 8 | return 9 | 10 | 11 | if __name__ == '__main__': 12 | unittest.main() 13 | -------------------------------------------------------------------------------- /src/basis/meet_in_middle/example.py: -------------------------------------------------------------------------------- 
1 | import unittest 2 | 3 | 4 | class TestGeneral(unittest.TestCase): 5 | 6 | def test_xxx(self): 7 | pass 8 | return 9 | 10 | 11 | if __name__ == '__main__': 12 | unittest.main() 13 | -------------------------------------------------------------------------------- /src/basis/offline_query/example.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | 3 | 4 | class TestGeneral(unittest.TestCase): 5 | 6 | def test_xxxx(self): 7 | pass 8 | return 9 | 10 | 11 | if __name__ == '__main__': 12 | unittest.main() 13 | -------------------------------------------------------------------------------- /src/basis/recursion/example.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | 3 | 4 | class TestGeneral(unittest.TestCase): 5 | 6 | def test_recursion(self): 7 | pass 8 | return 9 | 10 | 11 | if __name__ == '__main__': 12 | unittest.main() 13 | -------------------------------------------------------------------------------- /src/basis/serialization/example.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | 3 | 4 | class TestGeneral(unittest.TestCase): 5 | 6 | def test_xxxx(self): 7 | pass 8 | return 9 | 10 | 11 | if __name__ == '__main__': 12 | unittest.main() 13 | -------------------------------------------------------------------------------- /src/dp/interval_dp/example.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | 3 | 4 | class TestGeneral(unittest.TestCase): 5 | 6 | def test_interval_dp(self): 7 | pass 8 | return 9 | 10 | 11 | if __name__ == '__main__': 12 | unittest.main() 13 | -------------------------------------------------------------------------------- /src/dp/outline_dp/example.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | 3 | 4 | class TestGeneral(unittest.TestCase): 5 | 6 | 
def test_outline_dp(self): 7 | pass 8 | return 9 | 10 | 11 | if __name__ == '__main__': 12 | unittest.main() 13 | -------------------------------------------------------------------------------- /src/graph/two_sat/example.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | 3 | 4 | class TestGeneral(unittest.TestCase): 5 | 6 | def test_directed_graph(self): 7 | pass 8 | return 9 | 10 | 11 | if __name__ == '__main__': 12 | unittest.main() 13 | -------------------------------------------------------------------------------- /src/math/convex_hull/example.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | 3 | 4 | class TestGeneral(unittest.TestCase): 5 | 6 | def test_convex_hull(self): 7 | pass 8 | return 9 | 10 | 11 | if __name__ == '__main__': 12 | unittest.main() 13 | -------------------------------------------------------------------------------- /src/basis/construction/example.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | 3 | 4 | class TestGeneral(unittest.TestCase): 5 | 6 | def test_construction(self): 7 | pass 8 | return 9 | 10 | 11 | if __name__ == '__main__': 12 | unittest.main() 13 | -------------------------------------------------------------------------------- /src/basis/observation/example.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | 3 | 4 | class TestGeneral(unittest.TestCase): 5 | 6 | def test_observation(self): 7 | pass 8 | return 9 | 10 | 11 | if __name__ == '__main__': 12 | unittest.main() 13 | -------------------------------------------------------------------------------- /src/basis/performance/example.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | 3 | 4 | class TestGeneral(unittest.TestCase): 5 | 6 | def test_performance(self): 7 | pass 8 | return 9 
| 10 | 11 | if __name__ == '__main__': 12 | unittest.main() 13 | -------------------------------------------------------------------------------- /src/struct/associative_array/example.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | 3 | 4 | class TestGeneral(unittest.TestCase): 5 | 6 | def test_xxxx(self): 7 | pass 8 | return 9 | 10 | 11 | if __name__ == '__main__': 12 | unittest.main() 13 | -------------------------------------------------------------------------------- /src/basis/brute_force/example.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | 3 | 4 | class TestGeneral(unittest.TestCase): 5 | 6 | def test_violent_enumeration(self): 7 | pass 8 | return 9 | 10 | 11 | if __name__ == '__main__': 12 | unittest.main() 13 | -------------------------------------------------------------------------------- /src/basis/circular_section/example.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | 3 | 4 | class TestGeneral(unittest.TestCase): 5 | 6 | def test_circle_section(self): 7 | pass 8 | return 9 | 10 | 11 | if __name__ == '__main__': 12 | unittest.main() 13 | -------------------------------------------------------------------------------- /src/graph/minimum_spanning_tree/example.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | 3 | 4 | 5 | class TestGeneral(unittest.TestCase): 6 | 7 | def test_minimum_spanning_tree(self): 8 | return 9 | 10 | 11 | if __name__ == '__main__': 12 | unittest.main() 13 | -------------------------------------------------------------------------------- /src/math/cantor_expands/example.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | 3 | 4 | class TestGeneral(unittest.TestCase): 5 | 6 | def test_cantor_expands(self): 7 | pass 8 | return 9 | 10 | 11 | 
if __name__ == '__main__': 12 | unittest.main() 13 | -------------------------------------------------------------------------------- /src/math/peishu_theorem/template.py: -------------------------------------------------------------------------------- 1 | import math 2 | from functools import reduce 3 | 4 | 5 | class PeiShuTheorem: 6 | def __init__(self): 7 | return 8 | 9 | @staticmethod 10 | def get_lst_gcd(lst): 11 | return reduce(math.gcd, lst) 12 | -------------------------------------------------------------------------------- /src/math/nim_game/template.py: -------------------------------------------------------------------------------- 1 | from functools import reduce 2 | from operator import xor 3 | 4 | 5 | class Nim: 6 | def __init__(self, lst): 7 | self.lst = lst 8 | return 9 | 10 | def gen_result(self): 11 | return reduce(xor, self.lst) != 0 12 | -------------------------------------------------------------------------------- /src/graph/tarjan/example.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | 3 | from src.graph.tarjan.template import Tarjan 4 | 5 | 6 | class TestGeneral(unittest.TestCase): 7 | def test_undirected_graph(self): 8 | return 9 | 10 | 11 | if __name__ == '__main__': 12 | unittest.main() 13 | -------------------------------------------------------------------------------- /src/basis/stack/template.py: -------------------------------------------------------------------------------- 1 | class MaxStack: 2 | def __init__(self): 3 | return 4 | 5 | def gen_result(self): 6 | return 7 | 8 | 9 | class MinStack: 10 | def __init__(self): 11 | return 12 | 13 | def gen_result(self): 14 | return 15 | -------------------------------------------------------------------------------- /src/math/nim_game/example.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | 3 | from src.math.nim_game.template import Nim 4 | 5 | 6 | class 
TestGeneral(unittest.TestCase): 7 | def test_nim_game(self): 8 | nim = Nim([0, 2, 3]) 9 | assert nim.gen_result() 10 | return 11 | 12 | 13 | if __name__ == '__main__': 14 | unittest.main() 15 | -------------------------------------------------------------------------------- /src/dp/linear_dp/example.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | 3 | from src.dp.linear_dp.template import LinearDP 4 | 5 | 6 | class TestGeneral(unittest.TestCase): 7 | 8 | def test_linear_dp(self): 9 | ld = LinearDP() 10 | nums = [6, 3, 5, 2, 1, 6, 8, 9] 11 | assert ld.liner_dp_template(nums) == 4 12 | return 13 | 14 | 15 | if __name__ == '__main__': 16 | unittest.main() 17 | -------------------------------------------------------------------------------- /src/struct/associative_array/template.py: -------------------------------------------------------------------------------- 1 | import random 2 | from collections import Counter 3 | 4 | 5 | class HashWithRandomSeedEscapeExplode: 6 | def __int__(self): 7 | return 8 | 9 | @staticmethod 10 | def get_cnt(nums): 11 | """template of associative array""" 12 | seed = random.randint(0, 10 ** 9 + 7) 13 | return Counter([num ^ seed for num in nums]) 14 | -------------------------------------------------------------------------------- /src/math/peishu_theorem/example.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | 3 | from src.math.peishu_theorem.template import PeiShuTheorem 4 | 5 | 6 | class TestGeneral(unittest.TestCase): 7 | 8 | def test_peishu_theorem(self): 9 | lst = [4059, -1782] 10 | pst = PeiShuTheorem().get_lst_gcd(lst) 11 | assert pst == 99 12 | return 13 | 14 | 15 | if __name__ == '__main__': 16 | unittest.main() 17 | -------------------------------------------------------------------------------- /src/basis/range/example.py: -------------------------------------------------------------------------------- 1 | 
import unittest 2 | 3 | from src.basis.range.template import Range 4 | 5 | 6 | class TestGeneral(unittest.TestCase): 7 | 8 | def test_range_cover_count(self): 9 | rcc = Range() 10 | lst = [[1, 4], [2, 5], [3, 6], [8, 9]] 11 | assert rcc.range_merge_to_disjoint(lst) == [[1, 6], [8, 9]] 12 | return 13 | 14 | 15 | if __name__ == '__main__': 16 | unittest.main() 17 | -------------------------------------------------------------------------------- /src/basis/performance/problem.py: -------------------------------------------------------------------------------- 1 | """ 2 | Algorithm:speed_up|performance 3 | Description:some skills or tricks for better performance 4 | 5 | 6 | =====================================LuoGu====================================== 7 | P1188(https://www.luogu.com.cn/problem/P1188)slice 8 | 9 | """ 10 | 11 | 12 | class XXX: 13 | def __init__(self): 14 | return 15 | 16 | 17 | class Solution: 18 | def __int__(self): 19 | return 20 | -------------------------------------------------------------------------------- /src/greedy/brain_storming/example.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | 3 | from src.greedy.brain_storming.template import BrainStorming 4 | 5 | 6 | class TestGeneral(unittest.TestCase): 7 | 8 | def test_brain_storming(self): 9 | bs = BrainStorming() 10 | n, m = 4, 20 11 | nums = [1, 2, 5, 10] 12 | assert bs.minimal_coin_need(n, m, nums) == 5 13 | return 14 | 15 | 16 | if __name__ == '__main__': 17 | unittest.main() 18 | -------------------------------------------------------------------------------- /src/graph/topological_sort/example.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | 3 | from src.graph.topological_sort.template import TopologicalSort 4 | 5 | 6 | class TestGeneral(unittest.TestCase): 7 | 8 | def test_topological_sort(self): 9 | ts = TopologicalSort() 10 | n = 5 11 | edges = [[0, 1], [0, 2], [1, 
4], [2, 3], [3, 4]] 12 | assert ts.get_rank(n, edges) == [0, 1, 1, 2, 3] 13 | return 14 | 15 | 16 | if __name__ == '__main__': 17 | unittest.main() 18 | -------------------------------------------------------------------------------- /src/string/suffix_array/example.py: -------------------------------------------------------------------------------- 1 | import random 2 | import unittest 3 | 4 | from src.string.suffix_array.template import SuffixArray 5 | 6 | 7 | class TestGeneral(unittest.TestCase): 8 | 9 | def test_suffix_array(self): 10 | sa = SuffixArray() 11 | for x in range(7): 12 | lst = [random.randint(0, 25) for _ in range(10 ** x)] 13 | sa.build(lst[:], 26) 14 | return 15 | 16 | 17 | if __name__ == '__main__': 18 | unittest.main() 19 | -------------------------------------------------------------------------------- /src/dp/matrix_dp/example.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | 3 | from src.dp.matrix_dp.template import MatrixDP 4 | 5 | 6 | class TestGeneral(unittest.TestCase): 7 | 8 | def test_matrix_dp(self): 9 | md = MatrixDP() 10 | matrix = [["1", "0", "1", "0", "0"], ["1", "0", "1", "1", "1"], [ 11 | "1", "1", "1", "1", "1"], ["1", "0", "0", "1", "0"]] 12 | assert md.maximal_square(matrix) == 4 13 | return 14 | 15 | 16 | if __name__ == '__main__': 17 | unittest.main() 18 | -------------------------------------------------------------------------------- /src/dp/linear_dp/template.py: -------------------------------------------------------------------------------- 1 | class LinearDP: 2 | def __init__(self): 3 | return 4 | 5 | @staticmethod 6 | def liner_dp_template(nums): 7 | # example of lis(longest increasing sequence) 8 | n = len(nums) 9 | dp = [0] * (n + 1) 10 | for i in range(n): 11 | dp[i + 1] = 1 12 | for j in range(i): 13 | if nums[i] > nums[j] and dp[j] + 1 > dp[i + 1]: 14 | dp[i + 1] = dp[j] + 1 15 | return max(dp) 16 | 
-------------------------------------------------------------------------------- /src/graph/prufer/problem.py: -------------------------------------------------------------------------------- 1 | """ 2 | Algorithm:prufer_series 3 | Description:Prufer code is a method of representing labeled rootless trees with a unique sequence of integers, which can generate a bijective relationship between labeled rootless trees and Prufer sequences. 4 | 5 | =====================================LuoGu====================================== 6 | P6086(https://www.luogu.com.cn/problem/P6086)prufer|classical 7 | P2817(https://www.luogu.com.cn/problem/P2817)cayley|specific_plan 8 | 9 | """ 10 | -------------------------------------------------------------------------------- /src/math/partition_fft/example.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | 3 | 4 | 5 | 6 | class TestGeneral(unittest.TestCase): 7 | def test_euler_phi(self): 8 | # from src.math.partition_fft.problem import fft_v 9 | # 10 | # import numpy as np 11 | # 12 | # x = np.array([[3, 1, 2, 4]]) 13 | # 14 | # x_fft = fft_v(x) 15 | # 16 | # 17 | # print(list(x_fft)) 18 | return 19 | 20 | 21 | if __name__ == '__main__': 22 | unittest.main() 23 | -------------------------------------------------------------------------------- /src/string/expression/problem.py: -------------------------------------------------------------------------------- 1 | """ 2 | 3 | Algorithm:stack 4 | Description:xxx 5 | 6 | ====================================LeetCode==================================== 7 | 1597(https://leetcode.cn/problems/build-binary-expression-tree-from-math.infix-expression/) 8 | 9 | =====================================LuoGu====================================== 10 | P1175(https://www.luogu.com.cn/problem/P1175) 11 | P1617(https://www.luogu.com.cn/problem/P1617) 12 | P1322(https://www.luogu.com.cn/problem/P1322) 13 | """ 14 | 
-------------------------------------------------------------------------------- /src/dp/bag_dp/example.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | 3 | from src.dp.bag_dp.template import BagDP 4 | 5 | 6 | class TestGeneral(unittest.TestCase): 7 | 8 | def test_bag_dp(self): 9 | bd = BagDP() 10 | for num in range(1, 100000): 11 | lst1 = bd.bin_split_1(num) 12 | lst2 = bd.bin_split_2(num) 13 | assert sum(lst1) == num == sum(lst2) 14 | assert len(lst1) == len(lst2) 15 | return 16 | 17 | 18 | if __name__ == '__main__': 19 | unittest.main() 20 | -------------------------------------------------------------------------------- /src/math/extend_crt/example.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | 3 | from src.math.extend_crt.template import ExtendCRT, CRT 4 | 5 | 6 | class TestGeneral(unittest.TestCase): 7 | 8 | def test_crt(self): 9 | pairs = [(3, 1), (5, 1), (7, 2)] 10 | crt = CRT() 11 | assert crt.chinese_remainder(pairs) == 16 12 | 13 | exc = ExtendCRT() 14 | pairs = [(6, 11), (9, 25), (17, 33)] 15 | assert exc.pipline(pairs)[0] == 809 16 | return 17 | 18 | 19 | if __name__ == '__main__': 20 | unittest.main() 21 | -------------------------------------------------------------------------------- /src/basis/md_vector/template.py: -------------------------------------------------------------------------------- 1 | from functools import reduce 2 | from operator import mul 3 | 4 | 5 | class MdVector: 6 | def __init__(self, dimension, initial): 7 | self.dimension = dimension 8 | self.dp = [initial] * reduce(mul, dimension) 9 | self.m = len(dimension) 10 | self.pos = [] 11 | for i in range(self.m): 12 | self.pos.append(reduce(mul, dimension[i + 1:] + [1])) 13 | return 14 | 15 | def get(self, lst): 16 | return sum(x * y for x, y in zip(lst, self.pos)) 17 | -------------------------------------------------------------------------------- 
/src/math/geometry/example.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | 3 | from src.math.geometry.template import Geometry 4 | 5 | 6 | class TestGeneral(unittest.TestCase): 7 | 8 | def test_geometry(self): 9 | gm = Geometry() 10 | assert gm.compute_square_point(1, 1, 6, 6) == ((1.0, 6.0), (6.0, 1.0)) 11 | assert gm.compute_square_point_non_vertical(0, 0, 0, 2) == ((1.0, 1.0), (-1.0, 1.0)) 12 | 13 | assert gm.compute_triangle_area(0, 0, 2, 0, 1, 1) == 1.0 14 | return 15 | 16 | 17 | if __name__ == '__main__': 18 | unittest.main() 19 | -------------------------------------------------------------------------------- /src/string/palindrome_num/example.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | 3 | from src.string.palindrome_num.template import PalindromeNum 4 | 5 | 6 | class TestGeneral(unittest.TestCase): 7 | 8 | def test_palindrome_num(self): 9 | pn = PalindromeNum() 10 | assert pn.get_palindrome_num_1(12) == pn.get_palindrome_num_2(12) 11 | 12 | n = "44" 13 | nums = pn.get_recent_palindrome_num(n) 14 | nums = [num for num in nums if num > int(n)] 15 | assert min(nums) == 55 16 | return 17 | 18 | 19 | if __name__ == '__main__': 20 | unittest.main() 21 | -------------------------------------------------------------------------------- /src/basis/various_sort/example.py: -------------------------------------------------------------------------------- 1 | import random 2 | import unittest 3 | 4 | from src.basis.various_sort.template import VariousSort 5 | 6 | 7 | class TestGeneral(unittest.TestCase): 8 | 9 | def test_various_sort(self): 10 | vs = VariousSort() 11 | n = 200 12 | for _ in range(n): 13 | nums = [random.randint(0, n) for _ in range(n)] 14 | assert vs.defined_sort(nums) == vs.quick_sort_two(nums) == vs.range_merge_to_disjoint_sort(nums) == sorted( 15 | nums) 16 | return 17 | 18 | 19 | if __name__ == '__main__': 20 | unittest.main() 
21 | -------------------------------------------------------------------------------- /src/struct/bit_set/template.py: -------------------------------------------------------------------------------- 1 | class SegBitSet: 2 | def __init__(self, n): 3 | self.n = n 4 | self.val = 0 5 | return 6 | 7 | def update(self, ll, rr): 8 | assert 0 <= ll <= rr <= self.n - 1 9 | mask = ((1 << (rr - ll + 1)) - 1) << ll 10 | self.val ^= mask 11 | return 12 | 13 | def query(self, ll, rr): 14 | assert 0 <= ll <= rr <= self.n - 1 15 | if ll == 0 and rr == self.n - 1: 16 | return self.val.bit_count() 17 | mask = ((1 << (rr - ll + 1)) - 1) << ll 18 | return (self.val & mask).bit_count() 19 | -------------------------------------------------------------------------------- /src/basis/md_vector/problem.py: -------------------------------------------------------------------------------- 1 | """ 2 | Algorithm:md_vector 3 | Description: 4 | 5 | ====================================LeetCode==================================== 6 | 7 | =====================================LuoGu====================================== 8 | 9 | ===================================CodeForces=================================== 10 | 11 | ====================================AtCoder===================================== 12 | 13 | =====================================AcWing===================================== 14 | 15 | """ 16 | 17 | 18 | class Solution: 19 | def __init__(self): 20 | return 21 | -------------------------------------------------------------------------------- /src/string/lyndon_decomposition/example.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | 3 | from src.string.lyndon_decomposition.template import LyndonDecomposition 4 | 5 | 6 | class TestGeneral(unittest.TestCase): 7 | def test_solve_by_duval(self): 8 | ld = LyndonDecomposition() 9 | assert ld.solve_by_duval("ababa") == ["ab", "ab", "a"] 10 | return 11 | 12 | def test_min_cyclic_string(self): 
13 | ld = LyndonDecomposition() 14 | assert ld.min_cyclic_string("ababa") == "aabab" 15 | assert ld.min_express("ababa")[1] == "aabab" 16 | return 17 | 18 | 19 | if __name__ == '__main__': 20 | unittest.main() 21 | -------------------------------------------------------------------------------- /src/math/comb_perm/example.py: -------------------------------------------------------------------------------- 1 | import math 2 | import unittest 3 | 4 | from src.math.comb_perm.template import Combinatorics 5 | 6 | 7 | class TestGeneral(unittest.TestCase): 8 | def test_comb_perm(self): 9 | for x in range(10): 10 | n = 10 ** 3 + x * 10 ** 2 11 | mod = 10 ** 9 + 7 12 | cb = Combinatorics(n, mod) 13 | for i in range(1, n + 1): 14 | assert pow(i, -1, mod) == cb.inv[i] == cb.inverse(i) 15 | assert math.factorial(i) % mod == cb.perm[i] == cb.factorial(i) 16 | assert math.comb(n, i) % mod == cb.comb(n, i) 17 | return 18 | 19 | 20 | if __name__ == '__main__': 21 | unittest.main() 22 | -------------------------------------------------------------------------------- /src/string/expression/example.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | 3 | from src.string.expression.template import TreeExpression, EnglishNumber 4 | 5 | 6 | class TestGeneral(unittest.TestCase): 7 | 8 | def test_tree_expression(self): 9 | te = TreeExpression() 10 | lst = ["2*3^4^2+(5/2-2)", "-2+3", "2*(-5/2+3*2)-33", "((-2+3)*3+5-7/2)^2", "2*(-3)", "0-(-3)", "-(-3)+2"] 11 | for s in lst: 12 | assert int(te.main_1175(s)[-1][0]) == eval(s.replace("^", "**").replace("/", "//")) 13 | return 14 | 15 | def test_english_number(self): 16 | en = EnglishNumber() 17 | num = 5208 18 | assert en.number_to_english(num) == "five thousand two hundred and eight" 19 | 20 | 21 | if __name__ == '__main__': 22 | unittest.main() 23 | -------------------------------------------------------------------------------- /src/struct/sqrt_decomposition/example.py: 
-------------------------------------------------------------------------------- 1 | import unittest 2 | 3 | from src.struct.sqrt_decomposition.template import BlockSize 4 | 5 | 6 | class TestGeneral(unittest.TestCase): 7 | 8 | def test_block_size(self): 9 | bs = BlockSize() 10 | for x in range(1, 10 ** 4 + 1): 11 | bs.get_divisor_split(x) 12 | cnt, seg = bs.get_divisor_split(100) 13 | assert cnt == [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 2, 2, 4, 5, 8, 17, 50] 14 | assert seg == [[1, 4], [5, 5], [6, 5], [6, 6], [7, 6], [7, 7], [8, 7], [8, 8], [9, 9], [10, 10], [11, 11], 15 | [12, 12], [13, 14], [15, 16], [17, 20], [21, 25], [26, 33], [34, 50], [51, 100]] 16 | return 17 | 18 | 19 | if __name__ == '__main__': 20 | unittest.main() 21 | -------------------------------------------------------------------------------- /src/greedy/brain_storming/template.py: -------------------------------------------------------------------------------- 1 | import math 2 | 3 | 4 | class BrainStorming: 5 | def __init__(self): 6 | return 7 | 8 | @staticmethod 9 | def minimal_coin_need(n, m, nums): 10 | # there are n selectable and math.infinite coins 11 | # and the minimum number of coins required to form all combinations of 1-m 12 | nums += [m + 1] 13 | nums.sort() 14 | if nums[0] != 1: 15 | return -1 16 | ans = sum_ = 0 17 | for i in range(n): 18 | nex = nums[i + 1] - 1 19 | nex = nex if nex < m else m 20 | x = math.ceil((nex - sum_) / nums[i]) 21 | x = x if x >= 0 else 0 22 | ans += x 23 | sum_ += x * nums[i] 24 | if sum_ >= m: 25 | break 26 | return ans 27 | -------------------------------------------------------------------------------- /src/basis/serialization/problem.py: -------------------------------------------------------------------------------- 1 | """ 2 | Algorithm:serialization|deserialization 3 | Description:2-tree|n-tree|tree_hash|tree_to_str|tree_serialization|tree_deserialization 4 | 5 | ====================================LeetCode==================================== 6 | 
428(https://leetcode.cn/problems/serialize-and-deserialize-n-ary-tree/)n-tree|tree_serialization 7 | 297(https://leetcode.cn/problems/serialize-and-deserialize-binary-tree/)tree_deserialization 8 | 449(https://leetcode.cn/problems/serialize-and-deserialize-bst/)tree_deserialization 9 | 10 | =====================================LuoGu====================================== 11 | xx(xxx)xxx 12 | 13 | ===================================CodeForces=================================== 14 | xx(xxx)xxx 15 | 16 | """ 17 | 18 | 19 | class Solution: 20 | def __int__(self): 21 | return 22 | -------------------------------------------------------------------------------- /src/graph/euler_path/example.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | 3 | from src.graph.euler_path.template import DirectedEulerPath 4 | 5 | 6 | class TestGeneral(unittest.TestCase): 7 | 8 | def test_euler_path(self): 9 | pairs = [[1, 2], [2, 3], [3, 4], [4, 3], [3, 2], [2, 1]] 10 | pairs = [[x - 1, y - 1] for x, y in pairs] 11 | ep = DirectedEulerPath(4, pairs) 12 | ans = [[x + 1, y + 1] for x, y in ep.paths] 13 | assert ans == [[1, 2], [2, 3], [3, 4], [4, 3], [3, 2], [2, 1]] 14 | 15 | pairs = [[1, 3], [2, 1], [4, 2], [3, 3], [1, 2], [3, 4]] 16 | pairs = [[x - 1, y - 1] for x, y in pairs] 17 | ep = DirectedEulerPath(4, pairs) 18 | ans = [x + 1 for x in ep.nodes] 19 | assert ans == [1, 2, 1, 3, 3, 4, 2] 20 | return 21 | 22 | 23 | if __name__ == '__main__': 24 | unittest.main() 25 | -------------------------------------------------------------------------------- /src/struct/monotonic_queue/example.py: -------------------------------------------------------------------------------- 1 | import random 2 | import unittest 3 | 4 | from src.struct.monotonic_queue.template import PriorityQueue 5 | 6 | 7 | class TestGeneral(unittest.TestCase): 8 | 9 | def test_priority_queue(self): 10 | pq = PriorityQueue() 11 | 12 | for _ in range(10): 13 | n = 
random.randint(100, 1000) 14 | nums = [random.randint(1, n) for _ in range(n)] 15 | k = random.randint(1, n) 16 | ans = pq.sliding_window(nums, k, "max") 17 | for i in range(n - k + 1): 18 | assert ans[i] == max(nums[i:i + k]) 19 | 20 | ans = pq.sliding_window(nums, k, "min") 21 | for i in range(n - k + 1): 22 | assert ans[i] == min(nums[i:i + k]) 23 | return 24 | 25 | 26 | if __name__ == '__main__': 27 | unittest.main() 28 | -------------------------------------------------------------------------------- /src/graph/dfs/example.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | 3 | from src.graph.dfs.template import DFS, DfsEulerOrder 4 | 5 | 6 | class TestGeneral(unittest.TestCase): 7 | 8 | def test_dfs(self): 9 | dfs = DFS() 10 | dct = [[1, 2], [0, 3], [0, 4], [1], [2]] 11 | start, end = dfs.gen_bfs_order_iteration([d[::-1] for d in dct]) 12 | assert start == [x - 1 for x in [1, 2, 4, 3, 5]] 13 | assert end == [b - 1 for _, b in [[1, 5], [2, 3], [4, 5], [3, 3], [5, 5]]] 14 | return 15 | 16 | def test_dfs_euler(self): 17 | dct = [[1, 2], [3, 4], [0, 5], [1], [1, 6], [2], [4]] 18 | dfs = DfsEulerOrder(dct) 19 | assert dfs.order_to_node == [0, 1, 3, 4, 6, 2, 5] 20 | assert dfs.euler_order == [0, 1, 3, 1, 4, 6, 4, 1, 0, 2, 5, 2, 0] 21 | return 22 | 23 | 24 | if __name__ == '__main__': 25 | unittest.main() 26 | -------------------------------------------------------------------------------- /src/graph/union_find/example.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | 3 | from src.graph.union_find.template import UnionFind, PersistentUnionFind 4 | 5 | 6 | class TestGeneral(unittest.TestCase): 7 | 8 | def test_union_find(self): 9 | uf = UnionFind(5) 10 | for i, j in [[0, 1], [1, 2]]: 11 | uf.union(i, j) 12 | assert uf.part == 3 13 | return 14 | 15 | def test_persistent_union_find(self): 16 | n = 3 17 | puf = PersistentUnionFind(n) 18 | edge_list = [[0, 
1, 2], [1, 2, 4], [2, 0, 8], [1, 0, 16]] 19 | edge_list.sort(key=lambda item: item[2]) 20 | for x, y, tm in edge_list: 21 | puf.union(x, y, tm) 22 | queries = [[0, 1, 2], [0, 2, 5]] 23 | assert [puf.is_connected(x, y, tm) for x, y, tm in queries] == [False, True] 24 | 25 | 26 | if __name__ == '__main__': 27 | unittest.main() 28 | -------------------------------------------------------------------------------- /src/dp/game_dp/template.py: -------------------------------------------------------------------------------- 1 | class DateTime: 2 | def __init__(self): 3 | self.leap_month = [31, 29, 31, 30, 31, 30, 31, 31, 30, 31, 30, 31] 4 | self.not_leap_month = [31, 28, 31, 30, 31, 30, 31, 31, 30, 31, 30, 31] 5 | return 6 | 7 | def is_leap_year(self, yy): 8 | # Determine whether it is a leap year 9 | assert sum(self.leap_month) == 366 10 | assert sum(self.not_leap_month) == 365 11 | return yy % 400 == 0 or (yy % 4 == 0 and yy % 100 != 0) 12 | 13 | def year_month_day_cnt(self, yy, mm): 14 | ans = self.leap_month[mm - 1] if self.is_leap_year(yy) else self.not_leap_month[mm - 1] 15 | return ans 16 | 17 | def is_valid(self, yy, mm, dd): 18 | if not [1900, 1, 1] <= [yy, mm, dd] <= [2006, 11, 4]: 19 | return False 20 | day = self.year_month_day_cnt(yy, mm) 21 | if not 1 <= dd <= day: 22 | return False 23 | return True 24 | -------------------------------------------------------------------------------- /src/basis/ternary_search/example.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | 3 | from src.basis.ternary_search.template import TernarySearch, TriPartPackTriPart 4 | 5 | 6 | class TestGeneral(unittest.TestCase): 7 | 8 | def test_tri_part_search(self): 9 | tps = TernarySearch() 10 | 11 | def fun1(x): return (x - 1) * (x - 1) 12 | 13 | assert abs(tps.find_floor_point_float(fun1, -5, 100) - 1) < 1e-5 14 | 15 | def fun2(x): return -(x - 1) * (x - 1) 16 | 17 | assert abs(tps.find_ceil_point_float(fun2, -5, 100) - 1) < 
1e-5 18 | return 19 | 20 | def test_tri_part_pack_tri_part(self): 21 | tpt = TriPartPackTriPart() 22 | nodes = [[1, 1], [1, -1], [-1, 1], [-1, -1]] 23 | 24 | def target(x, y): return max([(x - p[0]) ** 2 + (y - p[1]) ** 2 for p in nodes]) 25 | 26 | x0, y0, _ = tpt.find_floor_point_float(target, -10, 10, -10, 10) 27 | assert abs(x0) < 1e-5 and abs(y0) < 1e-5 28 | return 29 | 30 | 31 | if __name__ == '__main__': 32 | unittest.main() 33 | -------------------------------------------------------------------------------- /src/math/partition_fft/problem.py: -------------------------------------------------------------------------------- 1 | """ 2 | Algorithm:divide_and_conquer|fft 3 | Description: 4 | 5 | =====================================LuoGu====================================== 6 | P4721(https://www.luogu.com.cn/problem/P4721)divide_and_conquer|fft 7 | """ 8 | 9 | 10 | def fft_v(x): 11 | try: 12 | import numpy as np 13 | except ImportError: 14 | return 15 | x = np.asarray(x, dtype=float) 16 | n = x.shape[0] 17 | if np.log2(n) % 1 > 0: 18 | raise ValueError("must be a power of 2") 19 | 20 | n_min = min(n, 2) 21 | 22 | n = np.arange(n_min) 23 | k = n[:, None] 24 | m = np.exp(-2j * np.pi * n * k / n_min) 25 | xx = np.dot(m, x.reshape((n_min, -1))) 26 | while xx.shape[0] < n: 27 | x_even = xx[:, :int(xx.shape[1] / 2)] 28 | x_odd = xx[:, int(xx.shape[1] / 2):] 29 | terms = np.exp(-1j * np.pi * np.arange(xx.shape[0]) 30 | / xx.shape[0])[:, None] 31 | xx = np.vstack([x_even + terms * x_odd, 32 | x_even - terms * x_odd]) 33 | return xx.ravel() 34 | -------------------------------------------------------------------------------- /src/util/read_file.py: -------------------------------------------------------------------------------- 1 | class ReadFile: 2 | def __init__(self, path): 3 | self.fr = open(path, "r", encoding="utf-8", errors="ignore") 4 | 5 | def close(self): 6 | self.fr.close() 7 | 8 | def read_int(self): 9 | return int(self.fr.readline().rstrip()) 10 | 11 | def 
read_float(self): 12 | return float(self.fr.readline().rstrip()) 13 | 14 | def read_list_ints(self): 15 | return list(map(int, self.fr.readline().rstrip().split())) 16 | 17 | def read_list_floats(self): 18 | return list(map(float, self.fr.readline().rstrip().split())) 19 | 20 | def read_list_ints_minus_one(self): 21 | return list(map(lambda x: int(x) - 1, self.fr.readline().rstrip().split())) 22 | 23 | def read_str(self): 24 | return self.fr.readline().rstrip() 25 | 26 | def read_list_strs(self): 27 | return self.fr.readline().rstrip().split() 28 | 29 | def read_list_str(self): 30 | return list(self.fr.readline().rstrip()) 31 | 32 | @staticmethod 33 | def st(s): 34 | print(s) 35 | return 36 | -------------------------------------------------------------------------------- /src/struct/sqrt_decomposition/template.py: -------------------------------------------------------------------------------- 1 | class BlockSize: 2 | def __init__(self): 3 | return 4 | 5 | @staticmethod 6 | def get_divisor_split(n): 7 | # Decompose the interval [1, n] into each interval whose divisor of n does not exceed the range 8 | if n == 1: 9 | return [1], [[1, 1]] 10 | m = int(n ** 0.5) 11 | pre = [] 12 | post = [] 13 | for x in range(1, m + 1): 14 | pre.append(x) 15 | post.append(n // x) 16 | if pre[-1] == post[-1]: 17 | post.pop() 18 | post.reverse() 19 | res = pre + post 20 | 21 | cnt = [res[0]] + [res[i + 1] - res[i] for i in range(len(res) - 1)] 22 | k = len(cnt) 23 | assert k == 2 * m - int(m == n // m) 24 | 25 | right = [n // (k - i) for i in range(1, k)] 26 | pre = n // k 27 | seg = [[1, pre - 1]] if pre > 1 else [] 28 | for num in right: 29 | seg.append([pre, num]) 30 | pre = num + 1 31 | assert sum([ls[1] - ls[0] + 1 for ls in seg]) == n 32 | return cnt, seg 33 | -------------------------------------------------------------------------------- /src/graph/prufer/example.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | 3 | 
from src.graph.prufer.template import PruferAndTree 4 | 5 | 6 | class TestGeneral(unittest.TestCase): 7 | def test_tree_to_prufer(self): 8 | ptt = PruferAndTree() 9 | adj = [[1, 2, 3, 4], [0], [0, 5, 6], [0], [0], [2], [2]] 10 | code = [0, 0, 0, 2, 2] 11 | assert ptt.tree_to_prufer(adj, root=6) == code 12 | 13 | ptt = PruferAndTree() 14 | adj = [[1], [0, 2, 3, 6], [1, 4, 5], [1], [2], [2], [1]] 15 | code = [1, 1, 2, 2, 1] 16 | assert ptt.tree_to_prufer(adj, root=1) == code 17 | return 18 | 19 | def test_prufer_to_tree(self): 20 | ptt = PruferAndTree() 21 | code = [0, 0, 0, 2, 2] 22 | adj = [[1, 2, 3, 4], [0], [0, 5, 6], [0], [0], [2], [2]] 23 | assert ptt.prufer_to_tree(code, root=6) == adj 24 | 25 | ptt = PruferAndTree() 26 | code = [1, 1, 2, 2, 1] 27 | adj = [[1], [0, 2, 3, 6], [1, 4, 5], [1], [2], [2], [1]] 28 | assert ptt.prufer_to_tree(code, root=1) == adj 29 | return 30 | 31 | 32 | if __name__ == '__main__': 33 | unittest.main() 34 | -------------------------------------------------------------------------------- /src/struct/priority_queue/example.py: -------------------------------------------------------------------------------- 1 | import random 2 | import unittest 3 | 4 | from src.struct.priority_queue.template import HeapqMedian, FindMedian 5 | from src.struct.sorted_list.template import SortedList 6 | 7 | 8 | class TestGeneral(unittest.TestCase): 9 | 10 | def test_heapq_median(self): 11 | ceil = 10000 12 | lst = SortedList() 13 | hm = FindMedian() 14 | for i in range(ceil): 15 | num = random.randint(0, ceil)*2 16 | x = random.randint(0, 5) 17 | if x == 0 and lst: 18 | i = random.randint(0, len(lst) - 1) 19 | hm.remove(lst.pop(i)) 20 | else: 21 | lst.add(num) 22 | hm.add(num) 23 | if not lst: 24 | continue 25 | assert len(lst) == hm.small_cnt + hm.big_cnt 26 | if len(lst) % 2: 27 | assert lst[len(lst)//2] == hm.find_median() 28 | else: 29 | assert lst[len(lst) // 2] + lst[len(lst) // 2 - 1] == hm.find_median()*2 30 | return 31 | 32 | 33 | if __name__ == 
'__main__': 34 | unittest.main() 35 | -------------------------------------------------------------------------------- /src/basis/binary_search/example.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | 3 | from src.basis.binary_search.template import BinarySearch 4 | 5 | 6 | class TestGeneral(unittest.TestCase): 7 | 8 | def test_binary_search(self): 9 | bs = BinarySearch() 10 | 11 | def check(xx): 12 | nonlocal tm 13 | tm += 1 14 | return xx >= y 15 | 16 | for x in range(1, 7): 17 | n = 10**x 18 | lst = [] 19 | for y in range(1, n + 1): 20 | tm = 0 21 | bs.find_int_left(1, n, check) 22 | lst.append(tm) 23 | assert (1 << max(lst)) >= n > (1 << (max(lst) - 1)) 24 | 25 | def check_right(xx): 26 | nonlocal tm 27 | tm += 1 28 | return xx <= y 29 | 30 | for x in range(1, 7): 31 | n = 10**x 32 | lst = [] 33 | for y in range(1, n + 1): 34 | tm = 0 35 | bs.find_int_right(1, n, check_right) 36 | lst.append(tm) 37 | assert (1 << max(lst)) >= n > (1 << (max(lst) - 1)) 38 | return 39 | 40 | 41 | if __name__ == '__main__': 42 | unittest.main() 43 | -------------------------------------------------------------------------------- /src/basis/date/example.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | 3 | from src.basis.date.template import DateTime 4 | 5 | 6 | class TestGeneral(unittest.TestCase): 7 | 8 | def test_date_time(self): 9 | dt = DateTime() 10 | assert dt.get_n_days(2023, 1, 2, 1) == "2023-01-03" 11 | 12 | assert dt.is_valid_date("2023-02-29") is False 13 | assert dt.is_valid_date("2023-02-28") is True 14 | assert dt.is_valid_date("0001-02-27") is True 15 | 16 | res = dt.all_palindrome_date() 17 | assert len(res) == 331 18 | 19 | assert dt.is_leap_year(2000) is True 20 | assert dt.is_leap_year(2100) is False 21 | assert dt.is_leap_year(0) is True 22 | 23 | assert dt.unix_to_time(1462451334) == "2016-05-05 20:28:54" 24 | assert 
dt.time_to_unix("2016-05-05 20:28:54") == 1462451334 25 | 26 | assert dt.unix_to_time(1462451335) == "2016-05-05 20:28:55" 27 | assert dt.time_to_unix("2016-05-05 20:28:55") == 1462451335 28 | 29 | ans = 0 30 | for i in range(10000): 31 | ans += dt.is_leap_year(i) 32 | assert ans == dt.leap_year_count(i) 33 | return 34 | 35 | 36 | if __name__ == '__main__': 37 | unittest.main() 38 | -------------------------------------------------------------------------------- /src/struct/list_node/template.py: -------------------------------------------------------------------------------- 1 | from typing import List 2 | 3 | 4 | class ListNode: 5 | def __init__(self, val=0, next=None): 6 | self.val = val 7 | self.next = next 8 | 9 | 10 | class ListNodeOperation: 11 | def __init__(self): 12 | return 13 | 14 | @staticmethod 15 | def node_to_lst(node: ListNode) -> List[int]: 16 | lst = [] 17 | while node: 18 | lst.append(node.val) 19 | node = node.next 20 | return lst 21 | 22 | @staticmethod 23 | def lst_to_node(lst: List[int]) -> ListNode: 24 | node = ListNode(-1) 25 | pre = node 26 | for num in lst: 27 | pre.next = ListNode(num) 28 | pre = pre.next 29 | return node.next 30 | 31 | @staticmethod 32 | def node_to_num(node: ListNode) -> int: 33 | num = 0 34 | while node: 35 | num = num * 10 + node.val 36 | node = node.next 37 | return num 38 | 39 | @staticmethod 40 | def num_to_node(num: int) -> ListNode: 41 | node = ListNode(-1) 42 | pre = node 43 | for x in str(num): 44 | pre.next = ListNode(int(x)) 45 | pre = pre.next 46 | return node.next 47 | -------------------------------------------------------------------------------- /src/dp/digital_dp/example.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | 3 | from src.dp.digital_dp.template import DigitalDP 4 | 5 | 6 | class TestGeneral(unittest.TestCase): 7 | 8 | def test_digital_dp(self): 9 | 10 | dd = DigitalDP() 11 | cnt = [0] * 10 12 | n = 1000000 13 | for i in range(1, n + 
1): 14 | for w in str(i): 15 | cnt[int(w)] += 1 16 | 17 | for d in range(10): 18 | assert dd.count_digit_dp(n, d) == cnt[d] 19 | for d in range(1, 10): 20 | ans1 = dd.count_num_base(n, d) 21 | ans2 = sum(str(d) not in str(num) for num in range(1, n + 1)) 22 | assert ans1 == ans2 23 | 24 | for d in range(10): 25 | ans1 = dd.count_num_dp(n, d) 26 | ans2 = sum(str(d) not in str(num) for num in range(1, n + 1)) 27 | assert ans1 == ans2 28 | 29 | for d in range(10): 30 | nums = [] 31 | for i in range(1, n + 1): 32 | if str(d) not in str(i): 33 | nums.append(i) 34 | for i, num in enumerate(nums): 35 | assert dd.get_kth_without_d(i + 1, d) == num 36 | return 37 | 38 | 39 | if __name__ == '__main__': 40 | unittest.main() 41 | -------------------------------------------------------------------------------- /tests/leetcode/template.py: -------------------------------------------------------------------------------- 1 | import bisect 2 | import random 3 | import re 4 | import sys 5 | import unittest 6 | from typing import List, Callable 7 | from typing import List 8 | import heapq 9 | import math 10 | from math import inf 11 | from collections import defaultdict, Counter, deque 12 | from functools import lru_cache, cmp_to_key 13 | from itertools import combinations, accumulate, chain, count 14 | from functools import reduce 15 | from heapq import heappush, heappop, heappushpop, heapify 16 | from operator import xor, mul, add, or_ 17 | from functools import lru_cache 18 | import random 19 | from itertools import permutations, combinations 20 | 21 | from decimal import Decimal 22 | 23 | import heapq 24 | import copy 25 | 26 | from src.struct.sorted_list.template import SortedList 27 | 28 | 29 | # sys.set_int_max_str_digits(0) # for big number in leet code 30 | 31 | 32 | def max(a, b): 33 | return a if a > b else b 34 | 35 | 36 | def min(a, b): 37 | return a if a < b else b 38 | 39 | 40 | class Solution: 41 | 42 | @staticmethod 43 | def example() -> int: 44 | return 0 45 | 46 | 
47 | class TestGeneral(unittest.TestCase): 48 | 49 | def test_example(self): 50 | assert Solution().example() == 0 51 | return 52 | 53 | 54 | if __name__ == '__main__': 55 | unittest.main() 56 | -------------------------------------------------------------------------------- /src/math/fast_power/example.py: -------------------------------------------------------------------------------- 1 | import copy 2 | import random 3 | import unittest 4 | 5 | from src.math.fast_power.template import FastPower, MatrixFastPower 6 | 7 | 8 | class TestGeneral(unittest.TestCase): 9 | 10 | def test_fast_power(self): 11 | fp = FastPower() 12 | 13 | for _ in range(1000): 14 | a, b, mod = random.randint(1, 123), random.randint(1, 1234), random.randint(1, 12345) 15 | assert fp.fast_power_api(a, b, mod) == fp.fast_power(a, b, mod) 16 | 17 | x, n = random.uniform(0, 1), random.randint(1, 1234) 18 | assert abs(fp.float_fast_pow(x, n) - pow(x, n)) < 1e-5 19 | 20 | mfp = MatrixFastPower() 21 | mat = [[1, 0, 1], [1, 0, 0], [0, 1, 0]] 22 | mod = 10 ** 9 + 7 23 | for _ in range(10): 24 | n = random.randint(1, 10000) 25 | cur = copy.deepcopy(mat) 26 | for _ in range(1, n): 27 | cur = mfp._matrix_mul(cur, mat, mod) 28 | assert cur == mfp.matrix_pow(mat, n, mod) == mfp.matrix_pow(mat, n, mod) 29 | 30 | ba = [[1, 0, 0], [0, 1, 0], [0, 0, 1]] 31 | assert mfp.matrix_pow(mat, 0, mod) == mfp.matrix_pow(mat, 0, mod) == ba 32 | return 33 | 34 | 35 | if __name__ == '__main__': 36 | unittest.main() 37 | -------------------------------------------------------------------------------- /src/basis/date/problem.py: -------------------------------------------------------------------------------- 1 | """ 2 | Algorithm:date 3 | Description:date|year|week|month|day|hour|second 4 | 5 | 6 | =====================================LuoGu====================================== 7 | P2655(https://www.luogu.com.cn/problem/P2655)after_date 8 | P1167#submit(https://www.luogu.com.cn/problem/P1167#submit)between_date 9 | 
P5407(https://www.luogu.com.cn/problem/P5407)between_date 10 | P5440(https://www.luogu.com.cn/problem/P5440)brute_force|prime 11 | 12 | 13 | """ 14 | import datetime 15 | from datetime import datetime, timedelta 16 | 17 | from src.util.fast_io import FastIO 18 | 19 | 20 | class Solution: 21 | def __init__(self): 22 | return 23 | 24 | @staticmethod 25 | def lg_p2655(ac=FastIO()): 26 | """ 27 | url: https://www.luogu.com.cn/problem/P2655 28 | tag: after_date 29 | """ 30 | n = ac.read_int() 31 | for _ in range(n): 32 | lst = ac.read_list_ints() 33 | x = (1 << (lst[0] - 1)) - 1 34 | y = lst[1] 35 | m, d, h, m, s = lst[2:] 36 | start_date = datetime(year=y, month=m, day=d, hour=h, minute=m, second=s) 37 | end_date = start_date + timedelta(seconds=x) 38 | ac.lst([end_date.year, end_date.month, end_date.day, end_date.hour, end_date.minute, end_date.second]) 39 | return 40 | -------------------------------------------------------------------------------- /src/string/kmp/example.py: -------------------------------------------------------------------------------- 1 | import random 2 | import unittest 3 | 4 | from src.string.kmp.template import KMP 5 | 6 | 7 | class TestGeneral(unittest.TestCase): 8 | 9 | def test_kmp(self): 10 | kmp = KMP() 11 | for _ in range(10): 12 | for x in range(4): 13 | lst = [random.randint(0, 10) for _ in range(10 ** x)] 14 | n = len(lst) 15 | pi = kmp.prefix_function(lst) 16 | nxt = kmp.prefix_function_reverse(lst) 17 | z = kmp.z_function(lst) 18 | for i in range(1, n): 19 | ceil = floor = 0 20 | for j in range(1, i + 1): 21 | if lst[j:i + 1] == lst[:i - j + 1]: 22 | if i - j + 1 > ceil: 23 | ceil = i - j + 1 24 | if floor == 0 or i - j + 1 < floor: 25 | floor = i - j + 1 26 | assert pi[i] == ceil 27 | assert nxt[i] == floor 28 | i1, j1 = 0, i 29 | while j1 < n and lst[j1] == lst[i1]: 30 | i1 += 1 31 | j1 += 1 32 | assert z[i] == i1 33 | return 34 | 35 | 36 | if __name__ == '__main__': 37 | unittest.main() 38 | 
-------------------------------------------------------------------------------- /tests/leetcode/simple.py: -------------------------------------------------------------------------------- 1 | import bisect 2 | import random 3 | import re 4 | import sys 5 | import unittest 6 | from typing import List, Callable 7 | from typing import List 8 | import heapq 9 | import math 10 | from math import inf 11 | from collections import defaultdict, Counter, deque 12 | from functools import lru_cache, cmp_to_key 13 | from itertools import combinations, accumulate, chain, count 14 | from functools import reduce 15 | from heapq import heappush, heappop, heappushpop, heapify 16 | from operator import xor, mul, add, or_ 17 | from functools import lru_cache 18 | import random 19 | from itertools import permutations, combinations 20 | 21 | from decimal import Decimal 22 | 23 | import heapq 24 | import copy 25 | 26 | # from src.struct.sorted_list.template import SortedList 27 | 28 | # from sortedcontainers import SortedList 29 | # sys.set_int_max_str_digits(0) # for big number in leet code 30 | 31 | 32 | def max(a, b): 33 | return a if a > b else b 34 | 35 | 36 | def min(a, b): 37 | return a if a < b else b 38 | 39 | 40 | class Solution: 41 | 42 | @staticmethod 43 | def example() -> int: 44 | return 0 45 | 46 | 47 | class TestGeneral(unittest.TestCase): 48 | 49 | def test_example(self): 50 | assert Solution().example() == 0 51 | return 52 | 53 | 54 | if __name__ == '__main__': 55 | unittest.main() 56 | -------------------------------------------------------------------------------- /tests/leetcode/problem_1.py: -------------------------------------------------------------------------------- 1 | import bisect 2 | import random 3 | import re 4 | import sys 5 | import unittest 6 | from typing import List, Callable 7 | from typing import List 8 | import heapq 9 | import math 10 | from math import inf 11 | from collections import defaultdict, Counter, deque 12 | from functools import 
lru_cache, cmp_to_key 13 | from itertools import combinations, accumulate, chain, count 14 | from functools import reduce 15 | from heapq import heappush, heappop, heappushpop, heapify 16 | from operator import xor, mul, add, or_ 17 | from functools import lru_cache 18 | import random 19 | from itertools import permutations, combinations 20 | 21 | from decimal import Decimal 22 | 23 | import heapq 24 | import copy 25 | 26 | # from src.struct.sorted_list.template import SortedList 27 | 28 | # from sortedcontainers import SortedList 29 | # sys.set_int_max_str_digits(0) # for big number in leet code 30 | 31 | 32 | def max(a, b): 33 | return a if a > b else b 34 | 35 | 36 | def min(a, b): 37 | return a if a < b else b 38 | 39 | 40 | class Solution: 41 | 42 | @staticmethod 43 | def example() -> int: 44 | return 0 45 | 46 | 47 | class TestGeneral(unittest.TestCase): 48 | 49 | def test_example(self): 50 | assert Solution().example() == 0 51 | return 52 | 53 | 54 | if __name__ == '__main__': 55 | unittest.main() 56 | -------------------------------------------------------------------------------- /tests/leetcode/problem_2.py: -------------------------------------------------------------------------------- 1 | import bisect 2 | import random 3 | import re 4 | import sys 5 | import unittest 6 | from typing import List, Callable 7 | from typing import List 8 | import heapq 9 | import math 10 | from math import inf 11 | from collections import defaultdict, Counter, deque 12 | from functools import lru_cache, cmp_to_key 13 | from itertools import combinations, accumulate, chain, count 14 | from functools import reduce 15 | from heapq import heappush, heappop, heappushpop, heapify 16 | from operator import xor, mul, add, or_ 17 | from functools import lru_cache 18 | import random 19 | from itertools import permutations, combinations 20 | 21 | from decimal import Decimal 22 | 23 | import heapq 24 | import copy 25 | 26 | # from src.struct.sorted_list.template import SortedList 27 
| 28 | # from sortedcontainers import SortedList 29 | # sys.set_int_max_str_digits(0) # for big number in leet code 30 | 31 | 32 | def max(a, b): 33 | return a if a > b else b 34 | 35 | 36 | def min(a, b): 37 | return a if a < b else b 38 | 39 | 40 | class Solution: 41 | 42 | @staticmethod 43 | def example() -> int: 44 | return 0 45 | 46 | 47 | class TestGeneral(unittest.TestCase): 48 | 49 | def test_example(self): 50 | assert Solution().example() == 0 51 | return 52 | 53 | 54 | if __name__ == '__main__': 55 | unittest.main() 56 | -------------------------------------------------------------------------------- /tests/leetcode/problem_3.py: -------------------------------------------------------------------------------- 1 | import bisect 2 | import random 3 | import re 4 | import sys 5 | import unittest 6 | from typing import List, Callable 7 | from typing import List 8 | import heapq 9 | import math 10 | from math import inf 11 | from collections import defaultdict, Counter, deque 12 | from functools import lru_cache, cmp_to_key 13 | from itertools import combinations, accumulate, chain, count 14 | from functools import reduce 15 | from heapq import heappush, heappop, heappushpop, heapify 16 | from operator import xor, mul, add, or_ 17 | from functools import lru_cache 18 | import random 19 | from itertools import permutations, combinations 20 | 21 | from decimal import Decimal 22 | 23 | import heapq 24 | import copy 25 | 26 | # from src.struct.sorted_list.template import SortedList 27 | 28 | # from sortedcontainers import SortedList 29 | # sys.set_int_max_str_digits(0) # for big number in leet code 30 | 31 | 32 | def max(a, b): 33 | return a if a > b else b 34 | 35 | 36 | def min(a, b): 37 | return a if a < b else b 38 | 39 | 40 | class Solution: 41 | 42 | @staticmethod 43 | def example() -> int: 44 | return 0 45 | 46 | 47 | class TestGeneral(unittest.TestCase): 48 | 49 | def test_example(self): 50 | assert Solution().example() == 0 51 | return 52 | 53 | 54 | if 
__name__ == '__main__': 55 | unittest.main() 56 | -------------------------------------------------------------------------------- /tests/leetcode/problem_4.py: -------------------------------------------------------------------------------- 1 | import bisect 2 | import random 3 | import re 4 | import sys 5 | import unittest 6 | from typing import List, Callable 7 | from typing import List 8 | import heapq 9 | import math 10 | from math import inf 11 | from collections import defaultdict, Counter, deque 12 | from functools import lru_cache, cmp_to_key 13 | from itertools import combinations, accumulate, chain, count 14 | from functools import reduce 15 | from heapq import heappush, heappop, heappushpop, heapify 16 | from operator import xor, mul, add, or_ 17 | from functools import lru_cache 18 | import random 19 | from itertools import permutations, combinations 20 | 21 | from decimal import Decimal 22 | 23 | import heapq 24 | import copy 25 | 26 | # from src.struct.sorted_list.template import SortedList 27 | 28 | # from sortedcontainers import SortedList 29 | # sys.set_int_max_str_digits(0) # for big number in leet code 30 | 31 | 32 | def max(a, b): 33 | return a if a > b else b 34 | 35 | 36 | def min(a, b): 37 | return a if a < b else b 38 | 39 | 40 | class Solution: 41 | 42 | @staticmethod 43 | def example() -> int: 44 | return 0 45 | 46 | 47 | class TestGeneral(unittest.TestCase): 48 | 49 | def test_example(self): 50 | assert Solution().example() == 0 51 | return 52 | 53 | 54 | if __name__ == '__main__': 55 | unittest.main() 56 | -------------------------------------------------------------------------------- /src/math/high_precision/example.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | 3 | from src.math.high_precision.template import HighPrecision, FloatToFrac 4 | 5 | 6 | class TestGeneral(unittest.TestCase): 7 | 8 | def test_high_precision(self): 9 | hp = HighPrecision() 10 | assert 
hp.float_pow("98.999", "5") == "9509420210.697891990494999" 11 | 12 | assert hp.fraction_to_decimal(45, 56) == "0.803(571428)" 13 | assert hp.fraction_to_decimal(2, 1) == "2.0" 14 | assert hp.decimal_to_fraction("0.803(571428)") == [45, 56] 15 | assert hp.decimal_to_fraction("2.0") == [2, 1] 16 | return 17 | 18 | def test_float_to_frac(self): 19 | ff = FloatToFrac() 20 | assert ff.frac_add([1, 2], [1, 3]) == [5, 6] 21 | assert ff.frac_add([1, 2], [1, -3]) == [1, 6] 22 | assert ff.frac_add([1, -2], [1, 3]) == [-1, 6] 23 | 24 | assert ff.frac_max([1, 2], [1, 3]) == [1, 2] 25 | assert ff.frac_min([1, 2], [1, 3]) == [1, 3] 26 | 27 | assert ff.frac_max([1, -2], [1, -3]) == [-1, 3] 28 | assert ff.frac_min([1, -2], [1, -3]) == [-1, 2] 29 | 30 | assert ff.frac_ceil([2, 3]) == 1 31 | assert ff.frac_ceil([5, 3]) == 2 32 | assert ff.frac_ceil([-2, 3]) == 0 33 | assert ff.frac_ceil([-5, 3]) == -1 34 | return 35 | 36 | 37 | if __name__ == '__main__': 38 | unittest.main() 39 | -------------------------------------------------------------------------------- /src/greedy/longest_increasing_subsequence/example.py: -------------------------------------------------------------------------------- 1 | import random 2 | import unittest 3 | from collections import defaultdict 4 | 5 | from src.greedy.longest_increasing_subsequence.template import LongestIncreasingSubsequence, LcsComputeByLis 6 | 7 | 8 | class TestGeneral(unittest.TestCase): 9 | 10 | def test_longest_increasing_subsequence(self): 11 | lis = LongestIncreasingSubsequence() 12 | nums = [1, 2, 3, 3, 2, 2, 1] 13 | assert lis.definitely_increase(nums) == 3 14 | assert lis.definitely_not_reduce(nums) == 4 15 | assert lis.definitely_reduce(nums) == 3 16 | assert lis.definitely_not_increase(nums) == 5 17 | 18 | for _ in range(10): 19 | nums = [random.randint(0, 100) for _ in range(10)] 20 | ans = LcsComputeByLis().length_and_max_sum_of_lis(nums) 21 | cur = defaultdict(int) 22 | n = len(nums) 23 | for i in range(1, 1 << n): 24 | lst 
= [nums[j] for j in range(n) if i & (1 << j)] 25 | m = len(lst) 26 | if lst == sorted(lst) and all(lst[j + 1] > lst[j] for j in range(m - 1)): 27 | a, b = cur[m], sum(lst) 28 | cur[m] = a if a > b else b 29 | length = max(cur) 30 | assert ans == cur[length] 31 | return 32 | 33 | 34 | if __name__ == '__main__': 35 | unittest.main() 36 | -------------------------------------------------------------------------------- /src/tree/tree_lca/example.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | 3 | from src.tree.tree_dp.template import WeightedTree 4 | 5 | 6 | class TestGeneral(unittest.TestCase): 7 | 8 | def test_tree_ancestor(self): 9 | parent = [-1, 0, 0, 1, 2] 10 | n = len(parent) 11 | graph = WeightedTree(n) 12 | for i in range(n): 13 | if parent[i] != -1: 14 | graph.add_undirected_edge(parent[i], i, 1) 15 | graph.lca_build_with_multiplication() 16 | assert graph.lca_get_kth_ancestor(4, 3) == -1 17 | assert graph.lca_get_kth_ancestor(4, 2) == 0 18 | assert graph.lca_get_kth_ancestor(4, 1) == 2 19 | assert graph.lca_get_kth_ancestor(4, 0) == 4 20 | assert graph.lca_get_lca_between_nodes(3, 4) == 0 21 | assert graph.lca_get_lca_between_nodes(2, 4) == 2 22 | assert graph.lca_get_lca_between_nodes(3, 1) == 1 23 | assert graph.lca_get_lca_between_nodes(3, 2) == 0 24 | assert graph.lca_get_lca_and_dist_between_nodes(0, 0)[1] == 0 25 | assert graph.lca_get_lca_and_dist_between_nodes(0, 4)[1] == 2 26 | assert graph.lca_get_lca_and_dist_between_nodes(3, 4)[1] == 4 27 | assert graph.lca_get_lca_and_dist_between_nodes(1, 0)[1] == 1 28 | assert graph.lca_get_lca_and_dist_between_nodes(2, 3)[1] == 3 29 | return 30 | 31 | 32 | if __name__ == '__main__': 33 | unittest.main() 34 | -------------------------------------------------------------------------------- /src/struct/monotonic_queue/template.py: -------------------------------------------------------------------------------- 1 | from collections import deque 2 | 3 | 
4 | class PriorityQueue: 5 | def __init__(self): 6 | return 7 | 8 | @staticmethod 9 | def sliding_window(nums, k: int, method="max"): 10 | assert k >= 1 11 | if method == "min": 12 | nums = [-num for num in nums] 13 | n = len(nums) 14 | stack = deque() 15 | ans = [] 16 | for i in range(n): 17 | while stack and stack[0][1] <= i - k: 18 | stack.popleft() 19 | while stack and stack[-1][0] <= nums[i]: 20 | stack.pop() 21 | stack.append([nums[i], i]) 22 | if i >= k - 1: 23 | ans.append(stack[0][0]) 24 | if method == "min": 25 | ans = [-num for num in ans] 26 | return ans 27 | 28 | @staticmethod 29 | def sliding_window_all(nums, k: int, method="max"): 30 | assert k >= 1 31 | if method == "min": 32 | nums = [-num for num in nums] 33 | n = len(nums) 34 | stack = deque() 35 | ans = [] 36 | for i in range(n): 37 | while stack and stack[0][1] <= i - k: 38 | stack.popleft() 39 | while stack and stack[-1][0] <= nums[i]: 40 | stack.pop() 41 | stack.append([nums[i], i]) 42 | ans.append(stack[0][0]) 43 | if method == "min": 44 | ans = [-num for num in ans] 45 | return ans 46 | -------------------------------------------------------------------------------- /src/math/scan_line/problem.py: -------------------------------------------------------------------------------- 1 | """ 2 | Algorithm:scan_line 3 | Description:plane|cube 4 | 5 | ====================================LeetCode==================================== 6 | 218(https://leetcode.cn/problems/the-skyline-problem/)scan_line 7 | 850(https://leetcode.cn/problems/rectangle-area-ii/)scan_line|segment_tree|discretization|O(nlogn) 8 | 9 | =====================================LuoGu====================================== 10 | P6265(https://www.luogu.com.cn/problem/P6265)scan_line 11 | P5490(https://www.luogu.com.cn/problem/P5490)scan_line 12 | P1884(https://www.luogu.com.cn/problem/P1884)scan_line 13 | P1904(https://www.luogu.com.cn/problem/P1904)scan_line 14 | 15 | """ 16 | from src.math.scan_line.template import ScanLine 17 | from 
src.util.fast_io import FastIO 18 | 19 | 20 | class Solution: 21 | def __init__(self): 22 | return 23 | 24 | @staticmethod 25 | def lg_p1884(ac=FastIO()): 26 | """ 27 | url: https://www.luogu.com.cn/problem/P1884 28 | tag: scan_line 29 | """ 30 | n = ac.read_int() 31 | lst = [] 32 | for _ in range(n): 33 | lst.append(ac.read_list_ints()) 34 | low_x = min(min(ls[0], ls[2]) for ls in lst) 35 | low_y = min(min(ls[1], ls[3]) for ls in lst) 36 | 37 | lst = [[ls[0] - low_x, ls[1] - low_y, ls[2] - low_x, ls[3] - low_y] for ls in lst] 38 | lst = [[ls[0], ls[3], ls[2], ls[1]] for ls in lst] 39 | ans = ScanLine().get_rec_area(lst) 40 | ac.st(ans) 41 | return 42 | -------------------------------------------------------------------------------- /src/struct/list_node/problem.py: -------------------------------------------------------------------------------- 1 | """ 2 | Algorithm:list_node|linked_list 3 | Description: 4 | 5 | ====================================LeetCode==================================== 6 | 6914(https://leetcode.cn/contest/weekly-contest-358/problems/double-a-number-represented-as-a-linked-list/)linked_list 7 | 8 | """ 9 | from typing import Optional 10 | 11 | from src.struct.list_node.template import ListNode, ListNodeOperation 12 | 13 | 14 | class Solution: 15 | def __int__(self): 16 | return 17 | 18 | @staticmethod 19 | def lc_6914_1(head: Optional[ListNode]) -> Optional[ListNode]: 20 | """ 21 | url: https://leetcode.cn/problems/double-a-number-represented-as-a-linked-list/ 22 | tag: linked_list 23 | """ 24 | lno = ListNodeOperation() 25 | lst = lno.node_to_lst(head)[::-1] 26 | 27 | nums = [] 28 | x = 0 29 | for num in lst: 30 | x += num * 2 31 | nums.append(x % 10) 32 | x = 1 if x >= 10 else 0 33 | if x: 34 | nums.append(x) 35 | 36 | nums.reverse() 37 | return lno.lst_to_node(nums) 38 | 39 | @staticmethod 40 | def lc_6914_2(head: Optional[ListNode]) -> Optional[ListNode]: 41 | """ 42 | url: 
https://leetcode.cn/contest/weekly-contest-358/problems/double-a-number-represented-as-a-linked-list/ 43 | tag: linked_list 44 | """ 45 | lno = ListNodeOperation() 46 | num = lno.node_to_num(head) * 2 47 | return lno.num_to_node(num) 48 | -------------------------------------------------------------------------------- /src/math/extend_crt/template.py: -------------------------------------------------------------------------------- 1 | from functools import reduce 2 | 3 | 4 | class CRT: 5 | def __init__(self): 6 | return 7 | 8 | @staticmethod 9 | def chinese_remainder(pairs): 10 | mod_list, remainder_list = [p[0] for p in pairs], [p[1] for p in pairs] 11 | mod_product = reduce(lambda x, y: x * y, mod_list) 12 | mi_list = [mod_product // x for x in mod_list] 13 | mi_inverse = [ExtendCRT().extend_gcd(mi_list[i], mod_list[i])[0] for i in range(len(mi_list))] 14 | x = 0 15 | for i in range(len(remainder_list)): 16 | x += mi_list[i] * mi_inverse[i] * remainder_list[i] 17 | x %= mod_product 18 | return x 19 | 20 | 21 | class ExtendCRT: 22 | def __init__(self): 23 | return 24 | 25 | def gcd(self, a, b): 26 | if b == 0: 27 | return a 28 | return self.gcd(b, a % b) 29 | 30 | def lcm(self, a, b): 31 | return a * b // self.gcd(a, b) 32 | 33 | def extend_gcd(self, a, b): 34 | if b == 0: 35 | return 1, 0 36 | x, y = self.extend_gcd(b, a % b) 37 | return y, x - a // b * y 38 | 39 | def uni(self, p, q): 40 | r1, m1 = p 41 | r2, m2 = q 42 | 43 | d = self.gcd(m1, m2) 44 | assert (r2 - r1) % d == 0 # else without solution 45 | l1, l2 = self.extend_gcd(m1 // d, m2 // d) 46 | 47 | return (r1 + (r2 - r1) // d * l1 * m1) % self.lcm(m1, m2), self.lcm(m1, m2) 48 | 49 | def pipline(self, eq): 50 | return reduce(self.uni, eq) 51 | -------------------------------------------------------------------------------- /src/math/prime_factor/example.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | 3 | from src.math.prime_factor.template 
import AllFactorCnt, PrimeFactor, RadFactor, AllFactor 4 | 5 | 6 | class TestGeneral(unittest.TestCase): 7 | 8 | 9 | def test_rad_factor(self): # 1.891 10 | n = 2 * 10 ** 5 11 | rf = RadFactor(n) 12 | for i in range(n+1): 13 | assert sorted(rf.get_rad_factor(i)) == sorted(rf.get_rad_factor2(i)) 14 | return 15 | 16 | 17 | def test_all_factor(self): # 1.891 18 | n = 2 * 10 ** 5 19 | all_factor = [[], [1]] + [[1, i] for i in range(2, n + 1)] 20 | af = AllFactor(n) 21 | for i in range(2, n + 1): 22 | x = i 23 | while x * i <= n: 24 | all_factor[x * i].append(i) 25 | if i != x: 26 | all_factor[x * i].append(x) 27 | x += 1 28 | for i in range(1, n + 1): 29 | all_factor[i].sort() 30 | assert all_factor[i] == sorted(af.get_all_factor(i)) 31 | assert [len(ls) for ls in all_factor] == AllFactorCnt(n).all_factor_cnt 32 | assert all_factor == PrimeFactor(n).all_factor 33 | return 34 | 35 | def test_prime_factor(self): # 1.891 36 | n = 2*10**5 37 | pf = PrimeFactor(n) 38 | assert pf.prime_factor_cnt[1:] == [len(ls) for ls in pf.prime_factor[1:]] 39 | assert pf.prime_factor_mi_cnt[1:] == [sum(x for _, x in ls) for ls in pf.prime_factor[1:]] 40 | return 41 | 42 | if __name__ == '__main__': 43 | unittest.main() 44 | -------------------------------------------------------------------------------- /src/string/manacher_palindrome/example.py: -------------------------------------------------------------------------------- 1 | import random 2 | import unittest 3 | 4 | from src.string.manacher_palindrome.template import ManacherPlindrome 5 | 6 | 7 | class TestGeneral(unittest.TestCase): 8 | 9 | def test_manacher_palindrome_count_start_end(self): 10 | mp = ManacherPlindrome() 11 | for x in range(4): 12 | n = 10 ** x 13 | for _ in range(10): 14 | nums = [random.randint(0, x) for _ in range(n)] 15 | start = [0] * n 16 | end = [0] * n 17 | cnt = [0] * (n + 1) 18 | start_odd = [0]*n 19 | end_odd = [0]*n 20 | for i in range(n): 21 | for j in range(i, n): 22 | if nums[i:j + 1] == nums[i:j + 
1][::-1]: 23 | start[i] += 1 24 | end[j] += 1 25 | cnt[j - i + 1] += 1 26 | if (j-i+1) % 2: 27 | start_odd[i] += 1 28 | end_odd[j] += 1 29 | assert start, end == mp.palindrome_count_start_end("".join(chr(x + ord("a")) for x in nums)) 30 | assert cnt == mp.palindrome_length_count("".join(chr(x + ord("a")) for x in nums)) 31 | assert sum(cnt) == mp.palindrome_count("".join(chr(x + ord("a")) for x in nums)) 32 | assert start_odd, end_odd == mp.palindrome_count_start_end_odd("".join(chr(x + ord("a")) for x in nums)) 33 | 34 | return 35 | 36 | 37 | if __name__ == '__main__': 38 | unittest.main() 39 | -------------------------------------------------------------------------------- /src/basis/two_pointers/example.py: -------------------------------------------------------------------------------- 1 | import random 2 | import unittest 3 | from functools import reduce 4 | from math import gcd 5 | from operator import mul, add, xor, and_, or_ 6 | 7 | from src.basis.two_pointers.template import TwoPointer, SlidingWindowAggregation 8 | 9 | 10 | class TestGeneral(unittest.TestCase): 11 | 12 | def test_two_pointers(self): 13 | nt = TwoPointer() 14 | nums = [1, 2, 3, 4, 4, 3, 3, 2, 1, 6, 3] 15 | assert nt.same_direction(nums) == 4 16 | 17 | nums = [1, 2, 3, 4, 4, 5, 6, 9] 18 | assert nt.opposite_direction(nums, 9) 19 | nums = [1, 2, 3, 4, 4, 5, 6, 9] 20 | assert not nt.opposite_direction(nums, 16) 21 | return 22 | 23 | def test_ops_es(self): 24 | 25 | dct = {max: 0, min: 1 << 64, gcd: 0, or_: 0, xor: 0, add: 0, mul: 1, and_: (1 << 32) - 1} 26 | for op in dct: 27 | for _ in range(1000): 28 | e = dct[op] 29 | n = 100 30 | nums = [random.randint(0, 10 ** 9) for _ in range(n)] 31 | swa = SlidingWindowAggregation(e, op) 32 | k = random.randint(1, 50) 33 | ans = [] 34 | res = [] 35 | for i in range(n): 36 | swa.append(nums[i]) 37 | if i >= k - 1: 38 | ans.append(swa.query()) 39 | swa.popleft() 40 | lst = nums[i - k + 1: i + 1] 41 | res.append(reduce(op, lst)) 42 | assert len(res) == 
len(ans) 43 | assert res == ans 44 | 45 | 46 | if __name__ == '__main__': 47 | unittest.main() 48 | -------------------------------------------------------------------------------- /src/basis/binary_search/template.py: -------------------------------------------------------------------------------- 1 | 2 | 3 | class BinarySearch: 4 | def __init__(self): 5 | return 6 | 7 | @staticmethod 8 | def find_int_left(low: int, high: int, check) -> int: 9 | """find the minimum int x which make check true""" 10 | while low < high: 11 | mid = low + (high - low) // 2 12 | if check(mid): 13 | high = mid 14 | else: 15 | low = mid + 1 16 | return low 17 | 18 | @staticmethod 19 | def find_int_right(low: int, high: int, check) -> int: 20 | """find the maximum int x which make check true""" 21 | while low < high: 22 | mid = low + (high - low + 1) // 2 23 | if check(mid): 24 | low = mid 25 | else: 26 | high = mid - 1 27 | return high 28 | 29 | @staticmethod 30 | def find_float_left(low: float, high: float, check, error=1e-6) -> float: 31 | """find the minimum float x which make check true""" 32 | while low < high - error: 33 | mid = low + (high - low) / 2 34 | if check(mid): 35 | high = mid 36 | else: 37 | low = mid 38 | return low if check(low) else high 39 | 40 | @staticmethod 41 | def find_float_right(low: float, high: float, check, error=1e-6) -> float: 42 | """find the maximum float x which make check true""" 43 | while low < high - error: 44 | mid = low + (high - low) / 2 45 | if check(mid): 46 | low = mid 47 | else: 48 | high = mid 49 | return high if check(high) else low 50 | -------------------------------------------------------------------------------- /src/struct/sorted_list/example.py: -------------------------------------------------------------------------------- 1 | import random 2 | import unittest 3 | 4 | from src.struct.sorted_list.template import SortedList 5 | 6 | 7 | class TestGeneral(unittest.TestCase): 8 | 9 | def test_custom_sorted_list(self): 10 | 11 | for _ 
in range(10): 12 | floor = -10 ** 8 13 | ceil = 10 ** 8 14 | low = -5 * 10 ** 7 15 | high = 6 * 10 ** 8 16 | n = 10 ** 4 17 | # add 18 | lst = SortedList() 19 | local_lst = SortedList() 20 | for _ in range(n): 21 | num = random.randint(low, high) 22 | lst.add(num) 23 | local_lst.add(num) 24 | assert all(lst[i] == local_lst[i] for i in range(n)) 25 | # discard 26 | for _ in range(n): 27 | num = random.randint(low, high) 28 | lst.discard(num) 29 | 30 | local_lst.discard(num) 31 | m = len(lst) 32 | assert all(lst[i] == local_lst[i] for i in range(m)) 33 | # bisect_left 34 | for _ in range(n): 35 | num = random.randint(low, high) 36 | lst.add(num) 37 | local_lst.add(num) 38 | for _ in range(n): 39 | num = random.randint(floor, ceil) 40 | assert lst.bisect_left(num) == local_lst.bisect_left(num) 41 | # bisect_right 42 | for _ in range(n): 43 | num = random.randint(floor, ceil) 44 | assert lst.bisect_right(num) == local_lst.bisect_right(num) 45 | return 46 | 47 | 48 | if __name__ == '__main__': 49 | unittest.main() 50 | -------------------------------------------------------------------------------- /src/basis/implemention/example.py: -------------------------------------------------------------------------------- 1 | import random 2 | import unittest 3 | 4 | from src.basis.implemention.template import SpiralMatrix 5 | 6 | 7 | class TestGeneral(unittest.TestCase): 8 | 9 | def test_spiral_matrix(self): 10 | sm = SpiralMatrix() 11 | nums = [[1, 2, 3, 4], [12, 13, 14, 5], [11, 16, 15, 6], [10, 9, 8, 7]] 12 | m = len(nums) 13 | n = len(nums[0]) 14 | for i in range(m): 15 | for j in range(n): 16 | assert sm.get_spiral_matrix_num1( 17 | m, n, i + 1, j + 1) == nums[i][j] 18 | assert sm.get_spiral_matrix_num2( 19 | m, n, i + 1, j + 1) == nums[i][j] 20 | 21 | nums = [[1, 2, 3, 4, 5, 6], [14, 15, 16, 17, 18, 7], 22 | [13, 12, 11, 10, 9, 8]] 23 | m = len(nums) 24 | n = len(nums[0]) 25 | for i in range(m): 26 | for j in range(n): 27 | assert sm.get_spiral_matrix_num1( 28 | m, n, i + 
1, j + 1) == nums[i][j] 29 | assert sm.get_spiral_matrix_num2( 30 | m, n, i + 1, j + 1) == nums[i][j] 31 | 32 | for _ in range(10): 33 | m = random.randint(5, 100) 34 | n = random.randint(5, 100) 35 | for i in range(m): 36 | for j in range(n): 37 | num = sm.get_spiral_matrix_num1(m, n, i + 1, j + 1) 38 | assert sm.get_spiral_matrix_num2(m, n, i + 1, j + 1) == num 39 | assert sm.get_spiral_matrix_loc( 40 | m, n, num) == [i + 1, j + 1] 41 | 42 | return 43 | 44 | 45 | if __name__ == '__main__': 46 | unittest.main() 47 | -------------------------------------------------------------------------------- /src/math/cantor_expands/template.py: -------------------------------------------------------------------------------- 1 | from src.struct.sorted_list.template import SortedList 2 | from src.struct.tree_array.template import PointAddRangeSum 3 | 4 | 5 | class CantorExpands: 6 | def __init__(self, n, mod=0): 7 | self.mod = mod 8 | self.perm = [1] * (n + 1) 9 | for i in range(2, n): 10 | if mod: 11 | self.perm[i] = i * self.perm[i - 1] % mod 12 | else: 13 | self.perm[i] = i * self.perm[i - 1] 14 | return 15 | 16 | def array_to_rank(self, nums): 17 | """"permutation rank of nums""" 18 | n = len(nums) 19 | out = 1 20 | lst = SortedList(nums) 21 | for i in range(n): 22 | fact = self.perm[n - i - 1] 23 | res = lst.bisect_left(nums[i]) 24 | lst.discard(nums[i]) 25 | out += res * fact 26 | if self.mod: 27 | out %= self.mod 28 | return out 29 | 30 | def array_to_rank_with_tree(self, nums): 31 | """"permutation rank of nums""" 32 | n = len(nums) 33 | out = 1 34 | tree = PointAddRangeSum(n) 35 | tree.build([1] * n) 36 | for i in range(n): 37 | fact = self.perm[n - i - 1] 38 | res = tree.range_sum(0, nums[i] - 2) if nums[i] >= 2 else 0 39 | tree.point_add(nums[i] - 1, -1) 40 | out += res * fact 41 | if self.mod: 42 | out %= self.mod 43 | return out 44 | 45 | 46 | def rank_to_array(self, n, k): 47 | """"nums with permutation rank k""" 48 | nums = list(range(1, n + 1)) 49 | ans = [] 50 
| while k and nums: 51 | single = self.perm[len(nums) - 1] 52 | i = (k - 1) // single 53 | ans.append(nums.pop(i)) 54 | k -= i * single 55 | return ans 56 | -------------------------------------------------------------------------------- /src/math/bit_operation/example.py: -------------------------------------------------------------------------------- 1 | import math 2 | import random 3 | import unittest 4 | 5 | from src.math.bit_operation.template import BitOperation, MinimumPairXor 6 | 7 | 8 | class TestGeneral(unittest.TestCase): 9 | 10 | def test_minimum_pair(self): 11 | for x in range(3): 12 | n = 5 * 10 ** x 13 | minimum_xor = MinimumPairXor() 14 | nums = [] 15 | for _ in range(n): 16 | num = random.randint(0, n) 17 | minimum_xor.add(num) 18 | nums.append(num) 19 | if len(nums) >= 2: 20 | c = len(nums) 21 | assert [minimum_xor.lst[i] for i in range(c)] == sorted(nums) 22 | floor = math.inf 23 | for a in range(c): 24 | for b in range(a + 1, c): 25 | cur = nums[a] ^ nums[b] 26 | if cur < floor: 27 | floor = cur 28 | assert floor == minimum_xor.query() 29 | return 30 | 31 | def test_bit_operation(self): 32 | bo = BitOperation() 33 | 34 | lst = [bo.integer_to_graycode(i) for i in range(11)] 35 | print(lst) 36 | 37 | assert bo.integer_to_graycode(0) == "0" 38 | assert bo.integer_to_graycode(22) == "11101" 39 | assert bo.graycode_to_integer("10110") == 27 40 | 41 | n = 8 42 | code = bo.get_graycode(n) 43 | m = len(code) 44 | for i in range(m): 45 | assert bo.graycode_to_integer(bin(code[i])[2:]) == i 46 | assert bo.integer_to_graycode(i) == bin(code[i])[2:] 47 | 48 | pre = 0 49 | for i in range(100000): 50 | pre ^= i 51 | assert bo.sum_xor(i) == pre 52 | return 53 | 54 | 55 | if __name__ == '__main__': 56 | unittest.main() 57 | -------------------------------------------------------------------------------- /src/math/scan_line/template.py: -------------------------------------------------------------------------------- 1 | import heapq 2 | from typing import List 
3 | 4 | 5 | class ScanLine: 6 | def __init__(self): 7 | return 8 | 9 | @staticmethod 10 | def get_sky_line(buildings: List[List[int]]) -> List[List[int]]: 11 | 12 | events = [] 13 | for left, right, height in buildings: 14 | events.append([left, -height, right]) 15 | events.append([right, 0, 0]) 16 | events.sort() 17 | 18 | res = [[0, 0]] 19 | stack = [[0, float('math.inf')]] 20 | for left, height, right in events: 21 | while left >= stack[0][1]: 22 | heapq.heappop(stack) 23 | if height < 0: 24 | heapq.heappush(stack, [height, right]) 25 | if res[-1][1] != -stack[0][0]: 26 | res.append([left, -stack[0][0]]) 27 | return res[1:] 28 | 29 | @staticmethod 30 | def get_rec_area(rectangles: List[List[int]]) -> int: 31 | 32 | axis = set() 33 | # [x1,y1,x2,y2] left_down to right_up 34 | for rec in rectangles: 35 | axis.add(rec[0]) 36 | axis.add(rec[2]) 37 | axis = sorted(list(axis)) 38 | ans = 0 39 | n = len(axis) 40 | for i in range(n - 1): 41 | 42 | x1, x2 = axis[i], axis[i + 1] 43 | width = x2 - x1 44 | if not width: 45 | continue 46 | 47 | items = [[rec[1], rec[3]] for rec in rectangles if rec[0] < x2 and rec[2] > x1] 48 | items.sort(key=lambda x: [x[0], -x[1]]) 49 | height = low = high = 0 50 | for y1, y2 in items: 51 | if y1 >= high: 52 | height += high - low 53 | low, high = y1, y2 54 | else: 55 | high = high if high > y2 else y2 56 | height += high - low 57 | 58 | ans += width * height 59 | return ans 60 | -------------------------------------------------------------------------------- /src/basis/tree_node/template.py: -------------------------------------------------------------------------------- 1 | from typing import Optional, List 2 | 3 | 4 | class TreeNode: 5 | def __init__(self, val=0, left=None, right=None): 6 | self.val = val 7 | self.left = left 8 | self.right = right 9 | 10 | 11 | class TreeOrder: 12 | def __init__(self): 13 | return 14 | 15 | @staticmethod 16 | def post_order(root: Optional[TreeNode]) -> List[int]: 17 | ans = [] 18 | stack = [[root, 1]] 
if root else [] 19 | while stack: 20 | node, state = stack.pop() 21 | if state: 22 | stack.append([node, 0]) 23 | if node.right: 24 | stack.append([node.right, 1]) 25 | if node.left: 26 | stack.append([node.left, 1]) 27 | else: 28 | ans.append(node.val) 29 | return ans 30 | 31 | @staticmethod 32 | def pre_order(root: Optional[TreeNode]) -> List[int]: 33 | ans = [] 34 | stack = [[root, 1]] if root else [] 35 | while stack: 36 | node, state = stack.pop() 37 | if state: 38 | if node.right: 39 | stack.append([node.right, 1]) 40 | if node.left: 41 | stack.append([node.left, 1]) 42 | stack.append([node, 0]) 43 | else: 44 | ans.append(node.val) 45 | return ans 46 | 47 | @staticmethod 48 | def in_order(root: Optional[TreeNode]) -> List[int]: 49 | ans = [] 50 | stack = [[root, 1]] if root else [] 51 | while stack: 52 | node, state = stack.pop() 53 | if state: 54 | if node.right: 55 | stack.append([node.right, 1]) 56 | stack.append([node, 0]) 57 | if node.left: 58 | stack.append([node.left, 1]) 59 | else: 60 | ans.append(node.val) 61 | return ans 62 | -------------------------------------------------------------------------------- /src/math/linear_basis/example.py: -------------------------------------------------------------------------------- 1 | import random 2 | import unittest 3 | from functools import reduce 4 | from operator import xor 5 | 6 | from src.math.linear_basis.template import LinearBasis 7 | 8 | 9 | class TestGeneral(unittest.TestCase): 10 | 11 | def test_linear_basis(self): 12 | for x in range(1000): 13 | m = 10 14 | lst = [random.randint(0, 1000000) for _ in range(m)] 15 | if x == 0: 16 | lst = [0] 17 | m = 1 18 | nums = [0] 19 | zero = 0 20 | for i in range(1, 1 << m): 21 | nums.append(reduce(xor, [lst[j] for j in range(m) if i & (1 << j)])) 22 | if not nums[-1]: 23 | zero = 1 24 | nums = sorted(set(nums)) 25 | lb = LinearBasis(20) 26 | for num in lst: 27 | lb.add(num) 28 | assert len(nums) == lb.tot 29 | assert lb.zero == zero 30 | n = len(nums) 31 | for 
i in range(n): 32 | assert lb.query_kth_xor(i) == nums[i] 33 | assert lb.query_xor_kth(nums[i]) == i 34 | 35 | x = random.randint(0, 1000) 36 | lst.append(x) 37 | m += 1 38 | zero = 0 39 | nums = [0] 40 | 41 | for i in range(1, 1 << m): 42 | nums.append(reduce(xor, [lst[j] for j in range(m) if i & (1 << j)])) 43 | if not nums[-1]: 44 | zero = 1 45 | nums = sorted(set(nums)) 46 | lb.add(x) 47 | assert len(nums) == lb.tot 48 | assert lb.zero == zero 49 | n = len(nums) 50 | for i in range(n): 51 | assert lb.query_kth_xor(i) == nums[i] 52 | assert lb.query_xor_kth(nums[i]) == i 53 | assert lb.query_max() == nums[-1] 54 | assert lb.query_min() == nums[0] 55 | return 56 | 57 | 58 | if __name__ == '__main__': 59 | unittest.main() 60 | -------------------------------------------------------------------------------- /src/tree/tree_diameter/template.py: -------------------------------------------------------------------------------- 1 | import math 2 | 3 | 4 | class GraphDiameter: 5 | def __init__(self): 6 | return 7 | 8 | @staticmethod 9 | def get_diameter(dct, root=0): 10 | n = len(dct) 11 | dis = [math.inf] * n 12 | stack = [root] 13 | dis[root] = 0 14 | while stack: 15 | nex = [] 16 | for i in stack: 17 | for j in dct[i]: 18 | if dis[j] == math.inf: 19 | dis[j] = dis[i] + 1 20 | nex.append(j) 21 | stack = nex[:] 22 | root = dis.index(max(dis)) 23 | dis = [math.inf] * n 24 | stack = [root] 25 | dis[root] = 0 26 | while stack: 27 | nex = [] 28 | for i in stack: 29 | for j in dct[i]: 30 | if dis[j] == math.inf: 31 | dis[j] = dis[i] + 1 32 | nex.append(j) 33 | stack = nex[:] 34 | return max(dis) 35 | 36 | 37 | class TreeDiameter: 38 | def __init__(self, dct): 39 | self.n = len(dct) 40 | self.dct = dct 41 | return 42 | 43 | def get_bfs_dis(self, root): 44 | dis = [math.inf] * self.n 45 | stack = [root] 46 | dis[root] = 0 47 | parent = [-1] * self.n 48 | while stack: 49 | i = stack.pop() 50 | for j, w in self.dct[i]: # weighted edge 51 | if j != parent[i]: 52 | parent[j] = i 
53 | dis[j] = dis[i] + w 54 | stack.append(j) 55 | return dis, parent 56 | 57 | def get_diameter_info(self): 58 | """get tree diameter detail by weighted bfs twice""" 59 | dis, _ = self.get_bfs_dis(0) 60 | x = dis.index(max(dis)) 61 | dis, parent = self.get_bfs_dis(x) 62 | y = dis.index(max(dis)) 63 | path = [y] 64 | while path[-1] != x: 65 | path.append(parent[path[-1]]) 66 | path.reverse() 67 | return x, y, path, dis[y] 68 | -------------------------------------------------------------------------------- /src/basis/circular_section/template.py: -------------------------------------------------------------------------------- 1 | class CircleSection: 2 | def __init__(self): 3 | return 4 | 5 | @staticmethod 6 | def compute_circle_result(n: int, m: int, x: int, tm: int) -> int: 7 | """use hash table and list to record the first pos of circle section""" 8 | dct = dict() 9 | # example is x = (x + m) % n 10 | lst = [] 11 | while x not in dct: 12 | dct[x] = len(lst) 13 | lst.append(x) 14 | x = (x + m) % n 15 | 16 | length = len(lst) 17 | # the first pos of circle section 18 | ind = dct[x] 19 | # current lst is enough 20 | if tm < length: 21 | return lst[tm] 22 | 23 | # compute by circle section 24 | circle = length - ind 25 | tm -= length 26 | j = tm % circle 27 | return lst[ind + j] 28 | 29 | @staticmethod 30 | def circle_section_pre(n, grid, c, sta, cur, h): 31 | """circle section with prefix sum""" 32 | dct = dict() 33 | lst = [] 34 | cnt = [] 35 | while sta not in dct: 36 | dct[sta] = len(dct) 37 | lst.append(sta) 38 | cnt.append(c) 39 | sta = cur 40 | c = 0 41 | cur = 0 42 | for i in range(n): 43 | num = 1 if sta & (1 << i) else 2 44 | for j in range(n): 45 | if grid[i][j] == "1": 46 | c += num 47 | cur ^= (num % 2) * (1 << j) 48 | 49 | length = len(lst) 50 | ind = dct[sta] 51 | pre = [0] * (length + 1) 52 | for i in range(length): 53 | pre[i + 1] = pre[i] + cnt[i] 54 | 55 | ans = 0 56 | if h < length: 57 | return ans + pre[h] 58 | 59 | circle = length - ind 60 | 
circle_cnt = pre[length] - pre[ind] 61 | 62 | h -= length 63 | ans += pre[length] 64 | 65 | ans += (h // circle) * circle_cnt 66 | 67 | j = h % circle 68 | ans += pre[ind + j] - pre[ind] 69 | return ans 70 | -------------------------------------------------------------------------------- /src/graph/binary_search_tree/template.py: -------------------------------------------------------------------------------- 1 | from src.graph.union_find.template import UnionFind 2 | 3 | 4 | class BinarySearchTree: 5 | 6 | def __init__(self): 7 | return 8 | 9 | @staticmethod 10 | def build_with_unionfind(nums): 11 | """build binary search tree by the order of nums with unionfind""" 12 | 13 | n = len(nums) 14 | ind = list(range(n)) 15 | ind.sort(key=lambda it: nums[it]) 16 | rank = {idx: i for i, idx in enumerate(ind)} 17 | 18 | dct = [[] for _ in range(n)] 19 | uf = UnionFind(n) 20 | post = {} 21 | for i in range(n - 1, -1, -1): 22 | x = rank[i] 23 | if x + 1 in post: 24 | r = uf.find(post[x + 1]) 25 | dct[i].append(r) 26 | uf.union_left(i, r) 27 | if x - 1 in post: 28 | r = uf.find(post[x - 1]) 29 | dct[i].append(r) 30 | uf.union_left(i, r) 31 | post[x] = i 32 | return dct 33 | 34 | @staticmethod 35 | def build_with_stack(nums): 36 | """build binary search tree by the order of nums with stack""" 37 | 38 | n = len(nums) 39 | 40 | lst = sorted(nums) 41 | dct = {num: i + 1 for i, num in enumerate(lst)} 42 | ind = {num: i for i, num in enumerate(nums)} 43 | 44 | order = [dct[i] for i in nums] 45 | father, occur, stack = [0] * (n + 1), [0] * (n + 1), [] 46 | deep = [0] * (n + 1) 47 | for i, x in enumerate(order, 1): 48 | occur[x] = i 49 | 50 | for x, i in enumerate(occur): 51 | while stack and occur[stack[-1]] > i: 52 | if occur[father[stack[-1]]] < i: 53 | father[stack[-1]] = x 54 | stack.pop() 55 | if stack: 56 | father[x] = stack[-1] 57 | stack.append(x) 58 | 59 | for x in order: 60 | deep[x] = 1 + deep[father[x]] 61 | 62 | dct = [[] for _ in range(n)] 63 | for i in range(1, n + 
1): 64 | if father[i]: 65 | u, v = father[i] - 1, i - 1 66 | x, y = ind[lst[u]], ind[lst[v]] 67 | dct[x].append(y) 68 | return dct 69 | -------------------------------------------------------------------------------- /src/basis/serialization/template.py: -------------------------------------------------------------------------------- 1 | from collections import deque 2 | from typing import Optional 3 | 4 | from src.basis.tree_node.template import TreeNode 5 | 6 | 7 | class CodecBFS: 8 | 9 | @staticmethod 10 | def serialize(root: Optional[TreeNode]) -> str: 11 | """Encodes a tree to a single string. 12 | """ 13 | stack = deque([root]) if root else deque() 14 | res = [] 15 | while stack: 16 | node = stack.popleft() 17 | if not node: 18 | res.append("n") 19 | continue 20 | else: 21 | res.append(str(node.val)) 22 | stack.append(node.left) 23 | stack.append(node.right) 24 | return ",".join(res) 25 | 26 | @staticmethod 27 | def deserialize(data: str) -> Optional[TreeNode]: 28 | """Decodes your encoded data to tree. 29 | """ 30 | if not data: 31 | return 32 | lst = deque(data.split(",")) 33 | ans = TreeNode(int(lst.popleft())) 34 | stack = deque([ans]) 35 | while lst: 36 | left, right = lst.popleft(), lst.popleft() 37 | pre = stack.popleft() 38 | if left != "n": 39 | pre.left = TreeNode(int(left)) 40 | stack.append(pre.left) 41 | if right != "n": 42 | pre.right = TreeNode(int(right)) 43 | stack.append(pre.right) 44 | return ans 45 | 46 | 47 | class CodecDFS: 48 | 49 | @staticmethod 50 | def serialize(root: TreeNode) -> str: 51 | """Encodes a tree to a single string. 52 | """ 53 | def dfs(node): 54 | if not node: 55 | return "n" 56 | return dfs(node.right) + "," + dfs(node.left) + "," + str(node.val) 57 | 58 | return dfs(root) 59 | 60 | @staticmethod 61 | def deserialize(data: str) -> TreeNode: 62 | """Decodes your encoded data to tree. 
63 | """ 64 | lst = data.split(",") 65 | 66 | def dfs(): 67 | if not lst: 68 | return 69 | val = lst.pop() 70 | if val == "n": 71 | return 72 | root = TreeNode(int(val)) 73 | root.left = dfs() 74 | root.right = dfs() 75 | return root 76 | 77 | return dfs() 78 | -------------------------------------------------------------------------------- /src/basis/diff_array/example.py: -------------------------------------------------------------------------------- 1 | import random 2 | import unittest 3 | 4 | from src.basis.diff_array.template import DiffArray, DiffMatrix, PreFixSumMatrix 5 | 6 | 7 | class TestGeneral(unittest.TestCase): 8 | 9 | def test_diff_array_range(self): 10 | dar = DiffArray() 11 | n = 3 12 | shifts = [[0, 1, 1], [1, 2, -1]] 13 | diff = dar.get_diff_array(n, shifts) 14 | assert diff == [1, 0, -1] 15 | 16 | n = 3 17 | shifts = [1, 2, 3] 18 | pre = dar.get_array_prefix_sum(n, shifts) 19 | assert pre == [0, 1, 3, 6] 20 | 21 | left = 1 22 | right = 2 23 | assert dar.get_array_range_sum(pre, left, right) == 5 24 | return 25 | 26 | def test_diff_array_matrix(self): 27 | dam = DiffMatrix() 28 | m = 3 29 | n = 3 30 | shifts = [[1, 2, 1, 2, 1], [2, 3, 2, 3, 1], 31 | [2, 2, 2, 2, 2], [1, 1, 3, 3, 3]] 32 | diff = [[1, 1, 3], [1, 4, 1], [0, 1, 1]] 33 | assert dam.get_diff_matrix(m, n, shifts) == diff 34 | 35 | shifts = [[1, 2, 1, 2, 1], [2, 3, 2, 3, 1], 36 | [2, 2, 2, 2, 2], [1, 1, 3, 3, 3]] 37 | shifts = [[x - 1 for x in ls[:-1]] + [ls[-1]] for ls in shifts] 38 | assert dam.get_diff_matrix2(m, n, shifts) == diff 39 | 40 | random.seed(2023) 41 | for _ in range(10): 42 | m = n = 2000 43 | nums = [[0] * n for _ in range(m)] 44 | shifts = [] 45 | for _ in range(100): 46 | x1 = 0 47 | y1 = 0 48 | x2 = m - 1 49 | y2 = n - 1 50 | num = random.randint(0, n) 51 | for i in range(x1, x2 + 1): 52 | for j in range(y1, y2 + 1): 53 | nums[i][j] += num 54 | shifts.append([x1, x2, y1, y2, num]) 55 | assert nums == dam.get_diff_matrix3(m, n, shifts) 56 | return 57 | 58 | def 
test_pre_fix_sum_matrix(self): 59 | diff = [[1, 1, 3], [1, 4, 1], [0, 1, 1]] 60 | pre = PreFixSumMatrix(diff) 61 | assert pre.pre == [[0, 0, 0, 0], [0, 1, 2, 5], [0, 2, 7, 11], [0, 2, 8, 13]] 62 | 63 | xa, ya, xb, yb = 1, 1, 2, 2 64 | assert pre.query(xa, ya, xb, yb) == sum(sum(d[ya: yb + 1]) for d in diff[xa: xb + 1]) 65 | return 66 | 67 | 68 | if __name__ == '__main__': 69 | unittest.main() 70 | -------------------------------------------------------------------------------- /src/struct/monotonic_stack/example.py: -------------------------------------------------------------------------------- 1 | import random 2 | import unittest 3 | 4 | from src.struct.monotonic_stack.template import MonotonicStack 5 | 6 | 7 | class TestGeneral(unittest.TestCase): 8 | 9 | def test_monotonic_stack(self): 10 | n = 1000 11 | nums = [random.randint(0, n) for _ in range(n)] 12 | ms = MonotonicStack(nums) 13 | for i in range(n): 14 | 15 | pre_bigger = pre_bigger_equal = pre_smaller = pre_smaller_equal = -1 16 | for j in range(i - 1, -1, -1): 17 | if nums[j] > nums[i]: 18 | pre_bigger = j 19 | break 20 | for j in range(i - 1, -1, -1): 21 | if nums[j] >= nums[i]: 22 | pre_bigger_equal = j 23 | break 24 | for j in range(i - 1, -1, -1): 25 | if nums[j] < nums[i]: 26 | pre_smaller = j 27 | break 28 | for j in range(i - 1, -1, -1): 29 | if nums[j] <= nums[i]: 30 | pre_smaller_equal = j 31 | break 32 | assert pre_bigger == ms.pre_bigger[i] 33 | assert pre_bigger_equal == ms.pre_bigger_equal[i] 34 | assert pre_smaller == ms.pre_smaller[i] 35 | assert pre_smaller_equal == ms.pre_smaller_equal[i] 36 | 37 | post_bigger = post_bigger_equal = post_smaller = post_smaller_equal = - 1 38 | for j in range(i + 1, n): 39 | if nums[j] > nums[i]: 40 | post_bigger = j 41 | break 42 | for j in range(i + 1, n): 43 | if nums[j] >= nums[i]: 44 | post_bigger_equal = j 45 | break 46 | for j in range(i + 1, n): 47 | if nums[j] < nums[i]: 48 | post_smaller = j 49 | break 50 | for j in range(i + 1, n): 51 | if 
nums[j] <= nums[i]: 52 | post_smaller_equal = j 53 | break 54 | assert post_bigger == ms.post_bigger[i] 55 | assert post_bigger_equal == ms.post_bigger_equal[i] 56 | assert post_smaller == ms.post_smaller[i] 57 | assert post_smaller_equal == ms.post_smaller_equal[i] 58 | 59 | return 60 | 61 | 62 | if __name__ == '__main__': 63 | unittest.main() 64 | -------------------------------------------------------------------------------- /src/graph/prufer/template.py: -------------------------------------------------------------------------------- 1 | class PruferAndTree: 2 | def __init__(self): 3 | return 4 | 5 | @staticmethod 6 | def adj_to_parent(adj, root): 7 | 8 | def dfs(v): 9 | for u in adj[v]: 10 | if u != parent[v]: 11 | parent[u] = v 12 | dfs(u) 13 | 14 | n = len(adj) 15 | parent = [-1] * n 16 | dfs(root) 17 | return parent 18 | 19 | @staticmethod 20 | def parent_to_adj(parent): 21 | n = len(parent) 22 | adj = [[] for _ in range(n)] 23 | for i in range(n): 24 | if parent[i] != -1: 25 | adj[i].append(parent[i]) 26 | adj[parent[i]].append(i) 27 | return parent 28 | 29 | def tree_to_prufer(self, adj, root): 30 | parent = self.adj_to_parent(adj, root) 31 | n = len(adj) 32 | ptr = -1 33 | degree = [0] * n 34 | for i in range(0, n): 35 | degree[i] = len(adj[i]) 36 | if degree[i] == 1 and ptr == -1: 37 | ptr = i 38 | 39 | code = [0] * (n - 2) 40 | leaf = ptr 41 | for i in range(0, n - 2): 42 | nex = parent[leaf] 43 | code[i] = nex 44 | degree[nex] -= 1 45 | if degree[nex] == 1 and nex < ptr: 46 | leaf = nex 47 | else: 48 | ptr = ptr + 1 49 | while degree[ptr] != 1: 50 | ptr = ptr + 1 51 | leaf = ptr 52 | return code 53 | 54 | @staticmethod 55 | def prufer_to_tree(code, root): 56 | n = len(code) + 2 57 | 58 | degree = [1] * n 59 | for i in code: 60 | degree[i] += 1 61 | ptr = 0 62 | while degree[ptr] != 1: 63 | ptr += 1 64 | leaf = ptr 65 | 66 | adj = [[] for _ in range(n)] 67 | for v in code: 68 | adj[v].append(leaf) 69 | adj[leaf].append(v) 70 | degree[v] -= 1 71 | if 
degree[v] == 1 and v < ptr and v != root: 72 | leaf = v 73 | else: 74 | ptr += 1 75 | while degree[ptr] != 1: 76 | ptr += 1 77 | leaf = ptr 78 | 79 | adj[leaf].append(root) 80 | adj[root].append(leaf) 81 | for i in range(n): 82 | adj[i].sort() 83 | return adj 84 | -------------------------------------------------------------------------------- /tests/codeforces/template.py: -------------------------------------------------------------------------------- 1 | from sys import stdin 2 | 3 | 4 | class FastIO: 5 | def __init__(self): 6 | self.random_seed = 0 7 | self.flush = False 8 | self.inf = 1 << 32 9 | return 10 | 11 | @staticmethod 12 | def read_int(): 13 | return int(stdin.readline().rstrip()) 14 | 15 | @staticmethod 16 | def read_float(): 17 | return float(stdin.readline().rstrip()) 18 | 19 | @staticmethod 20 | def read_list_ints(): 21 | return list(map(int, stdin.readline().rstrip().split())) 22 | 23 | @staticmethod 24 | def read_list_ints_minus_one(): 25 | return list(map(lambda x: int(x) - 1, stdin.readline().rstrip().split())) 26 | 27 | @staticmethod 28 | def read_str(): 29 | return stdin.readline().rstrip() 30 | 31 | @staticmethod 32 | def read_list_strs(): 33 | return stdin.readline().rstrip().split() 34 | 35 | def get_random_seed(self): 36 | import random 37 | self.random_seed = random.randint(0, 10 ** 9 + 7) 38 | return 39 | 40 | def st(self, x): 41 | return print(x, flush=self.flush) 42 | 43 | def lst(self, x): 44 | return print(*x, flush=self.flush) 45 | 46 | def flatten(self, lst): 47 | self.st("\n".join(str(x) for x in lst)) 48 | return 49 | 50 | def yes(self, s=None): 51 | self.st("Yes" if not s else s) 52 | return 53 | 54 | def no(self, s=None): 55 | self.st("No" if not s else s) 56 | return 57 | 58 | @staticmethod 59 | def max(a, b): 60 | return a if a > b else b 61 | 62 | @staticmethod 63 | def min(a, b): 64 | return a if a < b else b 65 | 66 | @staticmethod 67 | def ceil(a, b): 68 | return a // b + int(a % b != 0) 69 | 70 | @staticmethod 71 | def 
accumulate(nums): 72 | n = len(nums) 73 | pre = [0] * (n + 1) 74 | for i in range(n): 75 | pre[i + 1] = pre[i] + nums[i] 76 | return pre 77 | 78 | 79 | class Solution: 80 | def __init__(self): 81 | return 82 | 83 | @staticmethod 84 | def main(ac=FastIO()): 85 | """ 86 | url: https://codeforces.com/problemset/problem/1208/D 87 | tag: segment_tree|reverse_thinking|construction|point_set|range_sum_bisect_left 88 | """ 89 | for _ in range(ac.read_int()): 90 | pass 91 | return 92 | 93 | 94 | Solution().main() 95 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | 2 | 3 | # Templates, Examples and Problems Of Data Structure and Algorithm 4 | 5 | # Overview 6 | This project is organized with the following structure 7 | > **[src](https://github.com/liupengsay/PyIsTheBestLang/tree/main/src)** serve as template records for training 8 | > >**[basis](https://github.com/liupengsay/PyIsTheBestLang/tree/main/src/basis)** are some basic usages of data structures and algorithms 9 | 10 | > >**[struct](https://github.com/liupengsay/PyIsTheBestLang/tree/main/src/struct)** are some commonly used data structures 11 | 12 | > >**[dp](https://github.com/liupengsay/PyIsTheBestLang/tree/main/src/dp)** are summaries of dynamic programming 13 | 14 | 15 | > >**[graph](https://github.com/liupengsay/PyIsTheBestLang/tree/main/src/graph)** are some simple and classic graph theory algorithms 16 | 17 | 18 | > >**[greedy](https://github.com/liupengsay/PyIsTheBestLang/tree/main/src/greedy)** are some classic greedy problems 19 | 20 | 21 | > >**[math](https://github.com/liupengsay/PyIsTheBestLang/tree/main/src/math)** are mathematics-related algorithms 22 | 23 | 24 | >> **[tree](https://github.com/liupengsay/PyIsTheBestLang/tree/main/src/tree)** are some simple and classic 
tree-related algorithms 25 | 26 | >>
**[string](https://github.com/liupengsay/PyIsTheBestLang/tree/main/src/string)** are some classic usages of strings 27 | 28 | >>> **template.py** contains data structure and algorithm templates 29 | 30 | 31 | 32 | >>> **example.py** contains unit tests for the related templates 33 | 34 | 35 | 36 | >>> **problem.py** contains problems that can be solved with the related templates 37 | 38 | 39 | 40 | > **[tests](https://github.com/liupengsay/PyIsTheBestLang/tree/main/tests)** serves as a workspace during competitions 41 | 42 | > > **[codeforces](https://github.com/liupengsay/PyIsTheBestLang/tree/main/tests/codeforces)** prepared for codeforces, luogu, atcoder, acwing and so on 43 | 44 | > > **[leetcode](https://github.com/liupengsay/PyIsTheBestLang/tree/main/tests/leetcode)** prepared for leetcode only 45 | 46 | >>> **template.py** is a template to copy and paste before starting to solve a problem 47 | 48 | 49 | # Profile 50 | You are welcome to follow or star my profiles on the following platforms 51 | > [leetcode](https://leetcode.cn/u/liupengsay/) 52 | > [codeforces](https://codeforces.com/profile/liupengsay) 53 | > [atcoder](https://atcoder.jp/users/liupengsay) 54 | > [luogu](https://www.luogu.com.cn/user/739032) 55 | 56 | 57 | ## Thanks for Reading, and Good Luck on Training 58 | -------------------------------------------------------------------------------- /src/math/convex_hull/template.py: -------------------------------------------------------------------------------- 1 | import math 2 | import random 3 | from typing import List 4 | 5 | 6 | class MinCircleOverlap: 7 | def __init__(self): 8 | self.pi = math.acos(-1) 9 | self.esp = 10 ** (-10) 10 | return 11 | 12 | def get_min_circle_overlap(self, points: List[List[int]]): 13 | 14 | def cross(a, b): 15 | return a[0] * b[1] - b[0] * a[1] 16 | 17 | def intersection_point(p1, v1, p2, v2): 18 | u = (p1[0] - p2[0], p1[1] - p2[1]) 19 | t = cross(v2, u) / cross(v1, v2) 20 | return p1[0] + v1[0] * t, p1[1] + v1[1] * t 21 | 22 | def
is_point_in_circle(circle_x, circle_y, circle_r, x, y): 23 | res = math.sqrt((x - circle_x) ** 2 + (y - circle_y) ** 2) 24 | if abs(res - circle_r) < self.esp: 25 | return True 26 | if res < circle_r: 27 | return True 28 | return False 29 | 30 | def vec_rotate(v, theta): 31 | x, y = v 32 | return x * math.cos(theta) + y * math.sin(theta), -x * math.sin(theta) + y * math.cos(theta) 33 | 34 | def get_out_circle(x1, y1, x2, y2, x3, y3): 35 | xx1, yy1 = (x1 + x2) / 2, (y1 + y2) / 2 36 | vv1 = vec_rotate((x2 - x1, y2 - y1), self.pi / 2) 37 | xx2, yy2 = (x1 + x3) / 2, (y1 + y3) / 2 38 | vv2 = vec_rotate((x3 - x1, y3 - y1), self.pi / 2) 39 | pp = intersection_point((xx1, yy1), vv1, (xx2, yy2), vv2) 40 | res = math.sqrt((pp[0] - x1) ** 2 + (pp[1] - y1) ** 2) 41 | return pp[0], pp[1], res 42 | 43 | random.shuffle(points) 44 | n = len(points) 45 | p = points 46 | 47 | cc1 = (p[0][0], p[0][1], 0) 48 | for ii in range(1, n): 49 | if not is_point_in_circle(cc1[0], cc1[1], cc1[2], p[ii][0], p[ii][1]): 50 | cc2 = (p[ii][0], p[ii][1], 0) 51 | for jj in range(ii): 52 | if not is_point_in_circle(cc2[0], cc2[1], cc2[2], p[jj][0], p[jj][1]): 53 | dis = math.sqrt((p[jj][0] - p[ii][0]) ** 2 + (p[jj][1] - p[ii][1]) ** 2) 54 | cc3 = ((p[jj][0] + p[ii][0]) / 2, (p[jj][1] + p[ii][1]) / 2, dis / 2) 55 | for kk in range(jj): 56 | if not is_point_in_circle(cc3[0], cc3[1], cc3[2], p[kk][0], p[kk][1]): 57 | cc3 = get_out_circle(p[ii][0], p[ii][1], p[jj][0], p[jj][1], p[kk][0], p[kk][1]) 58 | cc2 = cc3 59 | cc1 = cc2 60 | 61 | return cc1 62 | -------------------------------------------------------------------------------- /src/math/lexico_graphical_order/example.py: -------------------------------------------------------------------------------- 1 | import random 2 | import unittest 3 | from itertools import permutations, combinations 4 | 5 | from src.math.lexico_graphical_order.template import LexicoGraphicalOrder, Permutation 6 | 7 | 8 | class TestGeneral(unittest.TestCase): 9 | 10 | def 
test_lexico_graphical_order(self): 11 | lgo = LexicoGraphicalOrder() 12 | 13 | n = 10 ** 5 14 | nums = sorted([str(x) for x in range(1, n + 1)]) 15 | for _ in range(100): 16 | i = random.randint(0, n - 1) 17 | num = nums[i] 18 | assert lgo.get_kth_num(n, i + 1) == int(num) 19 | assert lgo.get_num_kth(n, int(num)) == i + 1 20 | 21 | n = 10 22 | nums = [] 23 | for i in range(1 << n): 24 | nums.append([j + 1 for j in range(n) if i & (1 << j)]) 25 | nums.sort() 26 | nums[0] = [0] 27 | for _ in range(100): 28 | i = random.randint(0, n - 1) 29 | lst = nums[i] 30 | assert lgo.get_kth_subset(n, i + 1) == lst 31 | assert lgo.get_subset_kth(n, lst) == i + 1 32 | 33 | n = 10 34 | m = 4 35 | nums = [] 36 | for item in combinations(list(range(1, n + 1)), m): 37 | nums.append(list(item)) 38 | for _ in range(100): 39 | i = random.randint(0, len(nums) - 1) 40 | lst = nums[i] 41 | assert lgo.get_kth_subset_comb(n, m, i + 1) == lst 42 | assert lgo.get_subset_comb_kth(n, m, lst) == i + 1 43 | 44 | n = 8 45 | nums = [] 46 | for item in permutations(list(range(1, n + 1)), n): 47 | nums.append(list(item)) 48 | for i, lst in enumerate(nums): 49 | lst = nums[i] 50 | assert lgo.get_kth_subset_perm(n, i + 1) == lst 51 | assert lgo.get_subset_perm_kth(n, lst) == i + 1 52 | return 53 | 54 | def test_permutation(self): 55 | n = 8 56 | pm = Permutation() 57 | for x in range(100): 58 | nums = [random.randint(0, n - 1) for _ in range(n)] 59 | if x == 0: 60 | nums = list(range(n)) 61 | tot = set() 62 | for item in permutations(nums, n): 63 | tot.add(tuple(item)) 64 | tot = sorted(tot) 65 | m = len(tot) 66 | for i in range(m): 67 | assert pm.next_permutation(list(tot[i])) == list(tot[(i + 1) % m]) 68 | assert pm.prev_permutation(list(tot[i])) == list(tot[(i - 1) % m]) 69 | return 70 | 71 | 72 | if __name__ == '__main__': 73 | unittest.main() 74 | -------------------------------------------------------------------------------- /src/string/palindrome_num/template.py: 
-------------------------------------------------------------------------------- 1 | class PalindromeNum: 2 | def __init__(self): 3 | return 4 | 5 | @staticmethod 6 | def get_palindrome_num_1(n): 7 | """template of get all positive palindrome number with length not greater than n""" 8 | dp = [[""], [str(i) for i in range(10)]] 9 | for k in range(2, n + 1): 10 | # like dp to add palindrome character 11 | if k % 2 == 1: 12 | m = k // 2 13 | lst = [] 14 | for st in dp[-1]: 15 | for i in range(10): 16 | lst.append(st[:m] + str(i) + st[m:]) 17 | dp.append(lst) 18 | else: 19 | lst = [] 20 | for st in dp[-2]: 21 | for i in range(10): 22 | lst.append(str(i) + st + str(i)) 23 | dp.append(lst) 24 | 25 | nums = [] 26 | for lst in dp: 27 | for num in lst: 28 | if num and num[0] != "0": 29 | nums.append(int(num)) 30 | nums.sort() 31 | return nums 32 | 33 | @staticmethod 34 | def get_palindrome_num_2(n): 35 | assert n >= 1 36 | """template of get all positive palindrome number whose length not greater than n""" 37 | nums = list(range(1, 10)) 38 | x = 1 39 | while len(str(x)) * 2 <= n: 40 | num = str(x) + str(x)[::-1] 41 | nums.append(int(num)) 42 | if len(str(x)) * 2 + 1 <= n: 43 | for d in range(10): 44 | nums.append(int(str(x) + str(d) + str(x)[::-1])) 45 | x += 1 46 | nums.sort() 47 | return nums 48 | 49 | @staticmethod 50 | def get_palindrome_num_3(): 51 | """template of get all positive palindrome number whose length not greater than n""" 52 | nums = list(range(10)) 53 | for i in range(1, 10 ** 5): 54 | nums.append(int(str(i) + str(i)[::-1])) 55 | for j in range(10): 56 | nums.append(int(str(i) + str(j) + str(i)[::-1])) 57 | nums.sort() 58 | return nums 59 | 60 | @staticmethod 61 | def get_recent_palindrome_num(n: str) -> list: 62 | """template of recentest palindrome num of n""" 63 | m = len(n) 64 | candidates = [10 ** (m - 1) - 1, 10 ** m + 1] 65 | prefix = int(n[:(m + 1) // 2]) 66 | for x in range(prefix - 1, prefix + 2): 67 | y = x if m % 2 == 0 else x // 10 68 | while 
y: 69 | x = x * 10 + y % 10 70 | y //= 10 71 | candidates.append(x) 72 | return candidates 73 | -------------------------------------------------------------------------------- /src/math/nim_game/problem.py: -------------------------------------------------------------------------------- 1 | """ 2 | 3 | Algorithm:nim_game|sg_theorem 4 | Description:game_dp|winning_state|lose_state|sprague_grundy|sg_theorem 5 | 6 | =====================================LuoGu====================================== 7 | P2197(https://www.luogu.com.cn/problem/P2197)xor_sum|classical 8 | 9 | ===================================CodeForces=================================== 10 | 1396B(https://codeforces.com/contest/1396/problem/B)greedy|game_dp 11 | 2004E(https://codeforces.com/problemset/problem/2004/E)sprague_grundy|sg_theorem|game 12 | 13 | 14 | """ 15 | 16 | from src.math.nim_game.template import Nim 17 | from src.math.prime_factor.template import PrimeFactor 18 | from src.util.fast_io import FastIO 19 | 20 | 21 | class Solution: 22 | def __init__(self): 23 | return 24 | 25 | @staticmethod 26 | def cf_1396b(ac=FastIO()): 27 | """ 28 | url: https://codeforces.com/contest/1396/problem/B 29 | tag: greedy|game_dp 30 | """ 31 | for _ in range(ac.read_int()): 32 | ac.read_int() 33 | nums = ac.read_list_ints() 34 | ceil = max(nums) 35 | s = sum(nums) 36 | if ceil > s - ceil or s % 2: 37 | ac.st("T") 38 | else: 39 | ac.st("HL") 40 | return 41 | 42 | @staticmethod 43 | def lg_p2197(ac=FastIO()): 44 | """ 45 | url: https://www.luogu.com.cn/problem/P2197 46 | tag: xor_sum|classical 47 | """ 48 | for _ in range(ac.read_int()): 49 | ac.read_int() 50 | lst = ac.read_list_ints() 51 | nim = Nim(lst) 52 | if nim.gen_result(): 53 | ans = "Yes" 54 | else: 55 | ans = "No" 56 | ac.st(ans) 57 | return 58 | 59 | @staticmethod 60 | def cf_2004e(ac=FastIO()): 61 | """ 62 | url: https://codeforces.com/problemset/problem/2004/E 63 | tag: sprague_grundy|sg_theorem|game 64 | """ 65 | ceil = 10 ** 7 66 | pf = 
PrimeFactor(ceil + 10) 67 | sg = [0] * (ceil + 1) 68 | sg[1] = 1 69 | tot = 1 70 | for i in range(3, ceil + 1): 71 | if pf.min_prime[i] == i: 72 | tot += 1 73 | sg[i] = tot 74 | else: 75 | sg[i] = sg[pf.min_prime[i]] 76 | 77 | for _ in range(ac.read_int()): 78 | ac.read_int() 79 | nums = ac.read_list_ints() 80 | ans = 0 81 | for num in nums: 82 | ans ^= sg[num] 83 | if ans: 84 | ac.st("Alice") 85 | else: 86 | ac.st("Bob") 87 | return 88 | -------------------------------------------------------------------------------- /src/math/linear_basis/template.py: -------------------------------------------------------------------------------- 1 | from src.basis.binary_search.template import BinarySearch 2 | 3 | 4 | class LinearBasis: 5 | def __init__(self, m=64): 6 | self.m = m 7 | self.basis = [0] * self.m 8 | self.cnt = self.count_diff_xor() 9 | self.tot = 1 << self.cnt 10 | self.num = 0 11 | self.zero = 0 12 | self.length = 0 13 | return 14 | 15 | def minimize(self, x): 16 | for i in range(self.m): 17 | if x >> i & 1: 18 | x ^= self.basis[i] 19 | return x 20 | 21 | def add(self, x): 22 | assert x <= (1 << self.m) - 1 23 | x = self.minimize(x) 24 | self.num += 1 25 | if x: 26 | self.length += 1 27 | self.zero = int(self.length < self.num) 28 | 29 | for i in range(self.m - 1, -1, -1): 30 | if x >> i & 1: 31 | for j in range(self.m): 32 | if self.basis[j] >> i & 1: 33 | self.basis[j] ^= x 34 | self.basis[i] = x 35 | self.cnt = self.count_diff_xor() 36 | self.tot = 1 << self.cnt 37 | return True 38 | return False 39 | 40 | def count_diff_xor(self): 41 | num = 0 42 | for i in range(self.m): 43 | if self.basis[i] > 0: 44 | num += 1 45 | return num 46 | 47 | def query_kth_xor(self, x): 48 | res = 0 49 | for i in range(self.m): 50 | if self.basis[i]: 51 | if x & 1: 52 | res ^= self.basis[i] 53 | x >>= 1 54 | return res 55 | 56 | def query_xor_kth(self, num): 57 | bs = BinarySearch() 58 | 59 | def check(x): 60 | return self.query_kth_xor(x) <= num 61 | 62 | return 
bs.find_int_right(0, self.tot - 1, check) 63 | 64 | def query_max(self): 65 | return self.query_kth_xor(self.tot - 1) 66 | 67 | def query_min(self): 68 | # include empty subset 69 | return self.query_kth_xor(0) 70 | 71 | class LinearBasisVector: 72 | def __init__(self, m): 73 | self.basis = [[0] * m for _ in range(m)] 74 | self.m = m 75 | return 76 | 77 | def add(self, lst): 78 | for i in range(self.m): 79 | if self.basis[i][i] and lst[i]: 80 | a, b = self.basis[i][i], lst[i] 81 | self.basis[i] = [x * b for x in self.basis[i]] 82 | lst = [x * a for x in lst] 83 | lst = [lst[j] - self.basis[i][j] for j in range(self.m)] 84 | for j in range(self.m): 85 | if lst[j]: 86 | self.basis[j] = lst[:] 87 | return True 88 | return False -------------------------------------------------------------------------------- /src/tree/tree_diff_array/template.py: -------------------------------------------------------------------------------- 1 | class TreeDiffArray: 2 | 3 | def __init__(self): 4 | # node and edge differential method on tree 5 | return 6 | 7 | @staticmethod 8 | def bfs_iteration(dct, queries, root=0): 9 | """node differential method""" 10 | n = len(dct) 11 | stack = [root] 12 | parent = [-1] * n 13 | while stack: 14 | i = stack.pop() 15 | for j in dct[i]: 16 | if j != parent[i]: 17 | stack.append(j) 18 | parent[j] = i 19 | 20 | diff = [0] * n 21 | for u, v, ancestor in queries: 22 | # update on the path u to ancestor and v to ancestor 23 | diff[u] += 1 24 | diff[v] += 1 25 | diff[ancestor] -= 1 26 | if parent[ancestor] != -1: 27 | diff[parent[ancestor]] -= 1 28 | 29 | # differential summation from bottom to top 30 | stack = [root] 31 | while stack: 32 | i = stack.pop() 33 | if i >= 0: 34 | stack.append(~i) 35 | for j in dct[i]: 36 | if j != parent[i]: 37 | stack.append(j) 38 | else: 39 | i = ~i 40 | for j in dct[i]: 41 | if j != parent[i]: 42 | diff[i] += diff[j] 43 | return diff 44 | 45 | @staticmethod 46 | def bfs_iteration_edge(dct, queries, root=0): 47 | # 
Differential calculation of edges on the tree 48 | # where the count of edge is dropped to the corresponding down node 49 | n = len(dct) 50 | stack = [root] 51 | parent = [-1] * n 52 | while stack: 53 | i = stack.pop() 54 | for j in dct[i]: 55 | if j != parent[i]: 56 | stack.append(j) 57 | parent[j] = i 58 | 59 | # Perform edge difference counting 60 | diff = [0] * n 61 | for u, v, ancestor in queries: 62 | # update the edge on the path u to ancestor and v to ancestor 63 | diff[u] += 1 64 | diff[v] += 1 65 | # make the down node represent the edge count 66 | diff[ancestor] -= 2 67 | 68 | # differential summation from bottom to top 69 | stack = [[root, 1]] 70 | while stack: 71 | i, state = stack.pop() 72 | if state: 73 | stack.append([i, 0]) 74 | for j in dct[i]: 75 | if j != parent[i]: 76 | stack.append([j, 1]) 77 | else: 78 | for j in dct[i]: 79 | if j != parent[i]: 80 | diff[i] += diff[j] 81 | return diff 82 | -------------------------------------------------------------------------------- /src/math/bit_operation/template.py: -------------------------------------------------------------------------------- 1 | from typing import List 2 | 3 | from src.struct.sorted_list.template import SortedList 4 | 5 | 6 | class MinimumPairXor: 7 | def __init__(self): 8 | """ 9 | if x < y < z then min(x^y, y^z) < x^z, thus the minimum xor pair must be adjacent 10 | """ 11 | self.lst = SortedList() 12 | self.xor = SortedList() 13 | return 14 | 15 | def add(self, num): 16 | i = self.lst.bisect_left(num) 17 | if i < len(self.lst): 18 | if 0 <= i - 1: 19 | self.xor.discard(self.lst[i] ^ self.lst[i - 1]) 20 | self.lst.add(num) 21 | if 0 <= i - 1 < len(self.lst): 22 | self.xor.add(num ^ self.lst[i - 1]) 23 | if 0 <= i + 1 < len(self.lst): 24 | self.xor.add(num ^ self.lst[i + 1]) 25 | return 26 | 27 | def remove(self, num): 28 | i = self.lst.bisect_left(num) 29 | if 0 <= i - 1 < len(self.lst): 30 | self.xor.discard(num ^ self.lst[i - 1]) 31 | if 0 <= i + 1 < len(self.lst): 32 | 
self.xor.discard(num ^ self.lst[i + 1]) 33 | self.lst.discard(num) 34 | if i < len(self.lst) and i - 1 >= 0: 35 | self.xor.add(self.lst[i] ^ self.lst[i - 1]) 36 | return 37 | 38 | def query(self): 39 | return self.xor[0] 40 | 41 | 42 | class BitOperation: 43 | def __init__(self): 44 | return 45 | 46 | @staticmethod 47 | def sum_xor(n): 48 | """xor num of range(0, x+1)""" 49 | if n % 4 == 0: 50 | return n # (4*i)^(4*i+1)^(4*i+2)^(4*i+3)=0 51 | elif n % 4 == 1: 52 | return 1 # n^(n-1) 53 | elif n % 4 == 2: 54 | return n + 1 # n^(n-1)^(n-2) 55 | return 0 # n^(n-1)^(n-2)^(n-3) 56 | 57 | @staticmethod 58 | def graycode_to_integer(graycode): 59 | graycode_len = len(graycode) 60 | binary = list() 61 | binary.append(graycode[0]) 62 | for i in range(1, graycode_len): 63 | if graycode[i] == binary[i - 1]: 64 | b = 0 65 | else: 66 | b = 1 67 | binary.append(str(b)) 68 | return int("0b" + ''.join(binary), 2) 69 | 70 | @staticmethod 71 | def integer_to_graycode(integer): 72 | binary = bin(integer).replace('0b', '') 73 | graycode = list() 74 | binary_len = len(binary) 75 | graycode.append(binary[0]) 76 | for i in range(1, binary_len): 77 | if binary[i - 1] == binary[i]: 78 | g = 0 79 | else: 80 | g = 1 81 | graycode.append(str(g)) 82 | return ''.join(graycode) 83 | 84 | @staticmethod 85 | def get_graycode(n: int) -> List[int]: 86 | """all graycode number whose length small or equal to n""" 87 | code = [0, 1] 88 | for i in range(1, n): 89 | code.extend([(1 << i) + num for num in code[::-1]]) 90 | return code 91 | -------------------------------------------------------------------------------- /src/math/gcd_like/template.py: -------------------------------------------------------------------------------- 1 | class GcdLike: 2 | 3 | def __init__(self): 4 | return 5 | 6 | @staticmethod 7 | def extend_gcd(a, b): 8 | sub = dict() 9 | stack = [(a, b, 0)] 10 | while stack: 11 | a, b, s = stack.pop() 12 | if a == 0: 13 | sub[(a, b)] = (b, 0, 1) if b >= 0 else (-b, 0, -1) 14 | continue 15 
| if s == 0: 16 | stack.append((a, b, 1)) 17 | stack.append((b % a, a, 0)) 18 | else: 19 | gcd, x, y = sub[(b % a, a)] 20 | sub[(a, b)] = (gcd, y - (b // a) * x, x) if gcd >= 0 else (-gcd, -y + (b // a) * x, -x) 21 | assert gcd == a * (y - (b // a) * x) + b * x 22 | return sub[(a, b)] 23 | 24 | @staticmethod 25 | def binary_gcd(a, b): 26 | if a == 0: 27 | return abs(b) 28 | if b == 0: 29 | return abs(a) 30 | a, b = abs(a), abs(b) 31 | c = 1 32 | while a - b: 33 | if a & 1: 34 | if b & 1: 35 | if a > b: 36 | a = (a - b) >> 1 37 | else: 38 | b = (b - a) >> 1 39 | else: 40 | b = b >> 1 41 | else: 42 | if b & 1: 43 | a = a >> 1 44 | else: 45 | c = c << 1 46 | b = b >> 1 47 | a = a >> 1 48 | return c * a 49 | 50 | @staticmethod 51 | def general_gcd(x, y): 52 | while y: 53 | x, y = y, x % y 54 | return abs(x) 55 | 56 | def mod_reverse(self, a, p): 57 | g, x, y = self.extend_gcd(a, p) 58 | assert g == 1 # necessary of pow(a, -1, p) 59 | return (x + p) % p 60 | 61 | def solve_equation(self, a, b, n=1): 62 | """ 63 | a*x+b*y=n 64 | (a*x)%b=n 65 | """ 66 | gcd, x, y = self.extend_gcd(a, b) 67 | assert a * x + b * y == gcd 68 | if n % gcd: 69 | return [] 70 | x0 = x * (n // gcd) 71 | y0 = y * (n // gcd) 72 | # xt = x0 + b // gcd * t (t=0,1,2,3,...) 73 | # yt = y0 - a // gcd * t (t=0,1,2,3,...) 
74 | return [gcd, x0, y0] 75 | 76 | @staticmethod 77 | def add_to_n(n): 78 | # minimum times to make a == n or b == n by change [a, b] to [a + b, b] or [a, a + b] from [1, 1] 79 | if n == 1: 80 | return 0 81 | 82 | def gcd_minus(a, b, c): 83 | nonlocal ans 84 | if c >= ans or not b: 85 | return 86 | if b == 1: 87 | ans = ans if ans < c + a - 1 else c + a - 1 88 | return 89 | # reverse_thinking 90 | gcd_minus(b, a % b, c + a // b) 91 | return 92 | 93 | ans = n - 1 94 | for i in range(1, n): 95 | gcd_minus(n, i, 0) 96 | return ans 97 | -------------------------------------------------------------------------------- /src/string/lyndon_decomposition/template.py: -------------------------------------------------------------------------------- 1 | class LyndonDecomposition: 2 | def __init__(self): 3 | return 4 | 5 | @staticmethod 6 | def solve_by_duval(s): 7 | """template of duval algorithm""" 8 | n, i = len(s), 0 9 | factorization = [] 10 | while i < n: 11 | j, k = i + 1, i 12 | while j < n and s[k] <= s[j]: 13 | if s[k] < s[j]: 14 | k = i 15 | else: 16 | k += 1 17 | j += 1 18 | while i <= k: 19 | factorization.append(s[i: i + j - k]) 20 | i += j - k 21 | return factorization 22 | 23 | @staticmethod 24 | def min_cyclic_string(s): 25 | """template of smallest cyclic string""" 26 | s += s 27 | n = len(s) 28 | i, ans = 0, 0 29 | while i < n // 2: 30 | ans = i 31 | j, k = i + 1, i 32 | while j < n and s[k] <= s[j]: 33 | if s[k] < s[j]: 34 | k = i 35 | else: 36 | k += 1 37 | j += 1 38 | while i <= k: 39 | i += j - k 40 | return s[ans: ans + n // 2] 41 | 42 | @staticmethod 43 | def min_express(sec): 44 | """template of minimum lexicographic expression""" 45 | n = len(sec) # min_suffix 46 | k, i, j = 0, 0, 1 47 | while k < n and i < n and j < n: 48 | if sec[(i + k) % n] == sec[(j + k) % n]: 49 | k += 1 50 | else: 51 | if sec[(i + k) % n] > sec[(j + k) % n]: 52 | i = i + k + 1 53 | else: 54 | j = j + k + 1 55 | if i == j: 56 | i += 1 57 | k = 0 58 | i = i if i < j else j 59 | 
return i, sec[i:] + sec[:i] 60 | 61 | @staticmethod 62 | def max_express(sec): 63 | """template of maximum lexicographic expression""" 64 | n = len(sec) # max_suffix 65 | k, i, j = 0, 0, 1 66 | while k < n and i < n and j < n: 67 | if sec[(i + k) % n] == sec[(j + k) % n]: 68 | k += 1 69 | else: 70 | if sec[(i + k) % n] < sec[(j + k) % n]: 71 | i = i + k + 1 72 | else: 73 | j = j + k + 1 74 | if i == j: 75 | i += 1 76 | k = 0 77 | i = i if i < j else j 78 | return i, sec[i:] + sec[:i] 79 | 80 | @staticmethod 81 | def max_suffix(s): 82 | """template of maximum lexicographic suffix""" 83 | i, j, n = 0, 1, len(s) 84 | while j < n: 85 | k = 0 86 | while j + k < n and s[i + k] == s[j + k]: 87 | k += 1 88 | if j + k < n and s[i + k] < s[j + k]: 89 | i, j = j, max(j + 1, i + k + 1) 90 | else: 91 | j = j + k + 1 92 | return i 93 | -------------------------------------------------------------------------------- /src/math/peishu_theorem/problem.py: -------------------------------------------------------------------------------- 1 | """ 2 | Algorithm:peishu_theorem 3 | Description:ax+by=gcd(a,b) a!=0 or b!=0 4 | 5 | ====================================LeetCode==================================== 6 | 1250(https://leetcode.cn/problems/check-if-it-is-a-good-array/)gcd|peishu_theorem|classical 7 | 8 | ===================================CodeForces=================================== 9 | 1478D(https://codeforces.com/contest/1478/problem/D)peishu_theorem|number_theory|math 10 | 510D(https://codeforces.com/problemset/problem/510/D)peishu_theorem|linear_dp|observation 11 | 12 | =====================================LuoGu====================================== 13 | P4549(https://www.luogu.com.cn/problem/P4549)gcd|peishu_theorem 14 | P8646(https://www.luogu.com.cn/problem/P8646)peishu_theorem|bag_dp 15 | 16 | """ 17 | import math 18 | from collections import defaultdict 19 | 20 | from typing import List 21 | 22 | from src.math.peishu_theorem.template import PeiShuTheorem 23 | from 
src.util.fast_io import FastIO 24 | 25 | 26 | class Solution: 27 | def __init__(self): 28 | return 29 | 30 | @staticmethod 31 | def lc_1250(nums: List[int]) -> bool: 32 | """ 33 | url: https://leetcode.cn/problems/check-if-it-is-a-good-array/ 34 | tag: gcd|peishu_theorem|classical 35 | """ 36 | return PeiShuTheorem().get_lst_gcd(nums) == 1 37 | 38 | @staticmethod 39 | def lg_p4549(ac=FastIO()): 40 | """ 41 | url: https://www.luogu.com.cn/problem/P4549 42 | tag: gcd|peishu_theorem 43 | """ 44 | ac.read_int() 45 | nums = ac.read_list_ints() 46 | ac.st(PeiShuTheorem().get_lst_gcd(nums)) 47 | return 48 | 49 | @staticmethod 50 | def lg_p8646(ac=FastIO()): 51 | """ 52 | url: https://www.luogu.com.cn/problem/P8646 53 | tag: peishu_theorem|bag_dp 54 | """ 55 | n = ac.read_int() 56 | nums = [ac.read_int() for _ in range(n)] 57 | nums.sort() 58 | s = 10 ** 4 59 | dp = [0] * (s + 1) 60 | dp[0] = 1 61 | for i in range(1, s + 1): 62 | for num in nums: 63 | if num > i: 64 | break 65 | if dp[i - num]: 66 | dp[i] = 1 67 | ans = s + 1 - sum(dp) 68 | if PeiShuTheorem().get_lst_gcd(nums) != 1: 69 | ac.st("INF") 70 | else: 71 | ac.st(ans) 72 | return 73 | 74 | @staticmethod 75 | def cf_510d(ac=FastIO()): 76 | """ 77 | url: https://codeforces.com/problemset/problem/510/D 78 | tag: peishu_theorem|linear_dp|observation 79 | """ 80 | n = ac.read_int() 81 | ll = ac.read_list_ints() 82 | cc = ac.read_list_ints() 83 | pre = defaultdict(lambda: math.inf) 84 | for i in range(n): 85 | cur = pre.copy() 86 | for p in pre: 87 | g = math.gcd(p, ll[i]) 88 | cur[g] = min(cur[g], pre[p] + cc[i]) 89 | cur[ll[i]] = min(cur[ll[i]], cc[i]) 90 | pre = cur 91 | ac.st(pre[1] if pre[1] < math.inf else -1) 92 | return 93 | -------------------------------------------------------------------------------- /src/struct/sparse_table/example.py: -------------------------------------------------------------------------------- 1 | import math 2 | import random 3 | import unittest 4 | from functools import reduce 5 | 
from itertools import accumulate 6 | from operator import or_, and_ 7 | 8 | from src.struct.sparse_table.template import SparseTable2D, SparseTable 9 | 10 | 11 | class TestGeneral(unittest.TestCase): 12 | 13 | def test_sparse_table(self): 14 | 15 | def check_and(lst): 16 | ans = lst[0] 17 | for num in lst[1:]: 18 | ans &= num 19 | return ans 20 | 21 | def check_or(lst): 22 | ans = lst[0] 23 | for num in lst[1:]: 24 | ans |= num 25 | return ans 26 | 27 | nums = [9, 3, 1, 7, 5, 6, 0, 8] 28 | st = SparseTable(nums, max) 29 | queries = [[1, 6], [1, 5], [2, 7], [2, 6], [1, 8], [4, 8], [3, 7], [1, 8]] 30 | assert [st.query(left - 1, right - 1) for left, right in queries] == [9, 9, 7, 7, 9, 8, 7, 9] 31 | 32 | ceil = 2000 33 | nums = [random.randint(1, ceil) for _ in range(2000)] 34 | st1_max = SparseTable(nums, max) 35 | st1_min = SparseTable(nums, min) 36 | st1_gcd = SparseTable(nums, math.gcd) 37 | st1_lcm = SparseTable(nums, math.lcm) 38 | st1_and = SparseTable(nums, and_) 39 | st1_or = SparseTable(nums, or_) 40 | 41 | for _ in range(ceil): 42 | left = random.randint(1, ceil - 10) 43 | right = random.randint(left, ceil) 44 | left -= 1 45 | right -= 1 46 | assert st1_max.query(left, right) == max(nums[left:right + 1]) 47 | assert st1_min.query(left, right) == min(nums[left:right + 1]) 48 | assert st1_gcd.query(left, right) == reduce(math.gcd, nums[left:right + 1]) 49 | assert st1_lcm.query(left, right) == reduce(math.lcm, nums[left:right + 1]) 50 | assert st1_and.query(left, right) == check_and(nums[left:right + 1]) 51 | assert st1_or.query(left, right) == check_or(nums[left:right + 1]) 52 | pre = list(accumulate(nums[left:], and_)) 53 | for x in range(len(pre)): 54 | val = pre[x] 55 | right = left 56 | cur = nums[left] 57 | for y in range(left, ceil): 58 | cur &= nums[y] 59 | if cur >= val: 60 | right = y 61 | else: 62 | break 63 | assert right == st1_and.bisect_right(left, val, (1 << 32) - 1)[0] 64 | return 65 | 66 | def test_sparse_table_2d_max_min(self): 67 | 68 | m 
= n = 50 69 | high = 100000 70 | grid = [[random.randint(0, high) for _ in range(n)] for _ in range(m)] 71 | 72 | for method in ["max", "min", "lcm", "gcd", "or", "and"]: 73 | st = SparseTable2D(grid, method) 74 | x1 = random.randint(0, m - 1) 75 | y1 = random.randint(0, n - 1) 76 | x2 = random.randint(x1, m - 1) 77 | y2 = random.randint(y1, n - 1) 78 | 79 | ans1 = st.query(x1, y1, x2, y2) 80 | ans2 = st.fun([st.fun(g[y1:y2 + 1]) for g in grid[x1:x2 + 1]]) 81 | assert ans1 == ans2 82 | return 83 | 84 | 85 | if __name__ == '__main__': 86 | unittest.main() 87 | -------------------------------------------------------------------------------- /tests/codeforces/simple.py: -------------------------------------------------------------------------------- 1 | from sys import stdin, stdout 2 | import bisect 3 | import decimal 4 | import heapq 5 | from types import GeneratorType 6 | import random 7 | from bisect import bisect_left, bisect_right 8 | from heapq import heappush, heappop, heappushpop 9 | from functools import cmp_to_key 10 | from collections import defaultdict, Counter, deque 11 | import math 12 | from functools import lru_cache 13 | from heapq import nlargest 14 | from functools import reduce 15 | from decimal import Decimal 16 | from itertools import combinations, permutations 17 | from operator import xor, add 18 | from operator import mul 19 | from typing import List, Callable, Dict, Set, Tuple, DefaultDict 20 | from heapq import heappush, heappop, heapify 21 | 22 | 23 | class FastIO: 24 | def __init__(self): 25 | self.random_seed = 0 26 | self.flush = False 27 | self.inf = 1 << 32 28 | self.dire4 = [(0, -1), (0, 1), (1, 0), (-1, 0)] 29 | self.dire8 = [(0, -1), (0, 1), (1, 0), (-1, 0)] + [(1, -1), (1, 1), (-1, -1), (-1, 1)] 30 | return 31 | 32 | @staticmethod 33 | def read_int(): 34 | return int(stdin.readline().rstrip()) 35 | 36 | @staticmethod 37 | def read_float(): 38 | return float(stdin.readline().rstrip()) 39 | 40 | @staticmethod 41 | def 
read_list_ints(): 42 | return list(map(int, stdin.readline().rstrip().split())) 43 | 44 | @staticmethod 45 | def read_list_ints_minus_one(): 46 | return list(map(lambda x: int(x) - 1, stdin.readline().rstrip().split())) 47 | 48 | @staticmethod 49 | def read_str(): 50 | return stdin.readline().rstrip() 51 | 52 | @staticmethod 53 | def read_list_strs(): 54 | return stdin.readline().rstrip().split() 55 | 56 | def get_random_seed(self): 57 | import random 58 | self.random_seed = random.randint(0, 10 ** 9 + 7) 59 | return 60 | 61 | def st(self, x): 62 | return print(x, flush=self.flush) 63 | 64 | def yes(self, s=None): 65 | self.st("Yes" if not s else s) 66 | return 67 | 68 | def no(self, s=None): 69 | self.st("No" if not s else s) 70 | return 71 | 72 | def lst(self, x): 73 | return print(*x, flush=self.flush) 74 | 75 | def flatten(self, lst): 76 | self.st("\n".join(str(x) for x in lst)) 77 | return 78 | 79 | @staticmethod 80 | def max(a, b): 81 | return a if a > b else b 82 | 83 | @staticmethod 84 | def min(a, b): 85 | return a if a < b else b 86 | 87 | @staticmethod 88 | def ceil(a, b): 89 | return a // b + int(a % b != 0) 90 | 91 | @staticmethod 92 | def accumulate(nums): 93 | n = len(nums) 94 | pre = [0] * (n + 1) 95 | for i in range(n): 96 | pre[i + 1] = pre[i] + nums[i] 97 | return pre 98 | 99 | 100 | class Solution: 101 | def __init__(self): 102 | return 103 | 104 | @staticmethod 105 | def main(ac=FastIO()): 106 | """ 107 | url: url of the problem 108 | tag: algorithm tag 109 | """ 110 | for _ in range(ac.read_int()): 111 | pass 112 | return 113 | 114 | 115 | Solution().main() 116 | -------------------------------------------------------------------------------- /tests/codeforces/problem_a.py: -------------------------------------------------------------------------------- 1 | from sys import stdin, stdout 2 | import bisect 3 | import decimal 4 | import heapq 5 | from types import GeneratorType 6 | import random 7 | from bisect import bisect_left, bisect_right 8 
| from heapq import heappush, heappop, heappushpop 9 | from functools import cmp_to_key 10 | from collections import defaultdict, Counter, deque 11 | import math 12 | from functools import lru_cache 13 | from heapq import nlargest 14 | from functools import reduce 15 | from decimal import Decimal 16 | from itertools import combinations, permutations 17 | from operator import xor, add 18 | from operator import mul 19 | from typing import List, Callable, Dict, Set, Tuple, DefaultDict 20 | from heapq import heappush, heappop, heapify 21 | 22 | 23 | class FastIO: 24 | def __init__(self): 25 | self.random_seed = 0 26 | self.flush = False 27 | self.inf = 1 << 32 28 | self.dire4 = [(0, -1), (0, 1), (1, 0), (-1, 0)] 29 | self.dire8 = [(0, -1), (0, 1), (1, 0), (-1, 0)] + [(1, -1), (1, 1), (-1, -1), (-1, 1)] 30 | return 31 | 32 | @staticmethod 33 | def read_int(): 34 | return int(stdin.readline().rstrip()) 35 | 36 | @staticmethod 37 | def read_float(): 38 | return float(stdin.readline().rstrip()) 39 | 40 | @staticmethod 41 | def read_list_ints(): 42 | return list(map(int, stdin.readline().rstrip().split())) 43 | 44 | @staticmethod 45 | def read_list_ints_minus_one(): 46 | return list(map(lambda x: int(x) - 1, stdin.readline().rstrip().split())) 47 | 48 | @staticmethod 49 | def read_str(): 50 | return stdin.readline().rstrip() 51 | 52 | @staticmethod 53 | def read_list_strs(): 54 | return stdin.readline().rstrip().split() 55 | 56 | def get_random_seed(self): 57 | import random 58 | self.random_seed = random.randint(0, 10 ** 9 + 7) 59 | return 60 | 61 | def st(self, x): 62 | return print(x, flush=self.flush) 63 | 64 | def yes(self, s=None): 65 | self.st("Yes" if not s else s) 66 | return 67 | 68 | def no(self, s=None): 69 | self.st("No" if not s else s) 70 | return 71 | 72 | def lst(self, x): 73 | return print(*x, flush=self.flush) 74 | 75 | def flatten(self, lst): 76 | self.st("\n".join(str(x) for x in lst)) 77 | return 78 | 79 | @staticmethod 80 | def max(a, b): 81 | return 
a if a > b else b 82 | 83 | @staticmethod 84 | def min(a, b): 85 | return a if a < b else b 86 | 87 | @staticmethod 88 | def ceil(a, b): 89 | return a // b + int(a % b != 0) 90 | 91 | @staticmethod 92 | def accumulate(nums): 93 | n = len(nums) 94 | pre = [0] * (n + 1) 95 | for i in range(n): 96 | pre[i + 1] = pre[i] + nums[i] 97 | return pre 98 | 99 | 100 | class Solution: 101 | def __init__(self): 102 | return 103 | 104 | @staticmethod 105 | def main(ac=FastIO()): 106 | """ 107 | url: url of the problem 108 | tag: algorithm tag 109 | """ 110 | for _ in range(ac.read_int()): 111 | pass 112 | return 113 | 114 | 115 | Solution().main() 116 | -------------------------------------------------------------------------------- /tests/codeforces/problem_b.py: -------------------------------------------------------------------------------- 1 | from sys import stdin, stdout 2 | import bisect 3 | import decimal 4 | import heapq 5 | from types import GeneratorType 6 | import random 7 | from bisect import bisect_left, bisect_right 8 | from heapq import heappush, heappop, heappushpop 9 | from functools import cmp_to_key 10 | from collections import defaultdict, Counter, deque 11 | import math 12 | from functools import lru_cache 13 | from heapq import nlargest 14 | from functools import reduce 15 | from decimal import Decimal 16 | from itertools import combinations, permutations 17 | from operator import xor, add 18 | from operator import mul 19 | from typing import List, Callable, Dict, Set, Tuple, DefaultDict 20 | from heapq import heappush, heappop, heapify 21 | 22 | 23 | class FastIO: 24 | def __init__(self): 25 | self.random_seed = 0 26 | self.flush = False 27 | self.inf = 1 << 32 28 | self.dire4 = [(0, -1), (0, 1), (1, 0), (-1, 0)] 29 | self.dire8 = [(0, -1), (0, 1), (1, 0), (-1, 0)] + [(1, -1), (1, 1), (-1, -1), (-1, 1)] 30 | return 31 | 32 | @staticmethod 33 | def read_int(): 34 | return int(stdin.readline().rstrip()) 35 | 36 | @staticmethod 37 | def read_float(): 38 | 
return float(stdin.readline().rstrip()) 39 | 40 | @staticmethod 41 | def read_list_ints(): 42 | return list(map(int, stdin.readline().rstrip().split())) 43 | 44 | @staticmethod 45 | def read_list_ints_minus_one(): 46 | return list(map(lambda x: int(x) - 1, stdin.readline().rstrip().split())) 47 | 48 | @staticmethod 49 | def read_str(): 50 | return stdin.readline().rstrip() 51 | 52 | @staticmethod 53 | def read_list_strs(): 54 | return stdin.readline().rstrip().split() 55 | 56 | def get_random_seed(self): 57 | import random 58 | self.random_seed = random.randint(0, 10 ** 9 + 7) 59 | return 60 | 61 | def st(self, x): 62 | return print(x, flush=self.flush) 63 | 64 | def yes(self, s=None): 65 | self.st("Yes" if not s else s) 66 | return 67 | 68 | def no(self, s=None): 69 | self.st("No" if not s else s) 70 | return 71 | 72 | def lst(self, x): 73 | return print(*x, flush=self.flush) 74 | 75 | def flatten(self, lst): 76 | self.st("\n".join(str(x) for x in lst)) 77 | return 78 | 79 | @staticmethod 80 | def max(a, b): 81 | return a if a > b else b 82 | 83 | @staticmethod 84 | def min(a, b): 85 | return a if a < b else b 86 | 87 | @staticmethod 88 | def ceil(a, b): 89 | return a // b + int(a % b != 0) 90 | 91 | @staticmethod 92 | def accumulate(nums): 93 | n = len(nums) 94 | pre = [0] * (n + 1) 95 | for i in range(n): 96 | pre[i + 1] = pre[i] + nums[i] 97 | return pre 98 | 99 | 100 | class Solution: 101 | def __init__(self): 102 | return 103 | 104 | @staticmethod 105 | def main(ac=FastIO()): 106 | """ 107 | url: url of the problem 108 | tag: algorithm tag 109 | """ 110 | for _ in range(ac.read_int()): 111 | pass 112 | return 113 | 114 | 115 | Solution().main() 116 | -------------------------------------------------------------------------------- /tests/codeforces/problem_c.py: -------------------------------------------------------------------------------- 1 | from sys import stdin, stdout 2 | import bisect 3 | import decimal 4 | import heapq 5 | from types import 
GeneratorType 6 | import random 7 | from bisect import bisect_left, bisect_right 8 | from heapq import heappush, heappop, heappushpop 9 | from functools import cmp_to_key 10 | from collections import defaultdict, Counter, deque 11 | import math 12 | from functools import lru_cache 13 | from heapq import nlargest 14 | from functools import reduce 15 | from decimal import Decimal 16 | from itertools import combinations, permutations 17 | from operator import xor, add 18 | from operator import mul 19 | from typing import List, Callable, Dict, Set, Tuple, DefaultDict 20 | from heapq import heappush, heappop, heapify 21 | 22 | 23 | class FastIO: 24 | def __init__(self): 25 | self.random_seed = 0 26 | self.flush = False 27 | self.inf = 1 << 32 28 | self.dire4 = [(0, -1), (0, 1), (1, 0), (-1, 0)] 29 | self.dire8 = [(0, -1), (0, 1), (1, 0), (-1, 0)] + [(1, -1), (1, 1), (-1, -1), (-1, 1)] 30 | return 31 | 32 | @staticmethod 33 | def read_int(): 34 | return int(stdin.readline().rstrip()) 35 | 36 | @staticmethod 37 | def read_float(): 38 | return float(stdin.readline().rstrip()) 39 | 40 | @staticmethod 41 | def read_list_ints(): 42 | return list(map(int, stdin.readline().rstrip().split())) 43 | 44 | @staticmethod 45 | def read_list_ints_minus_one(): 46 | return list(map(lambda x: int(x) - 1, stdin.readline().rstrip().split())) 47 | 48 | @staticmethod 49 | def read_str(): 50 | return stdin.readline().rstrip() 51 | 52 | @staticmethod 53 | def read_list_strs(): 54 | return stdin.readline().rstrip().split() 55 | 56 | def get_random_seed(self): 57 | import random 58 | self.random_seed = random.randint(0, 10 ** 9 + 7) 59 | return 60 | 61 | def st(self, x): 62 | return print(x, flush=self.flush) 63 | 64 | def yes(self, s=None): 65 | self.st("Yes" if not s else s) 66 | return 67 | 68 | def no(self, s=None): 69 | self.st("No" if not s else s) 70 | return 71 | 72 | def lst(self, x): 73 | return print(*x, flush=self.flush) 74 | 75 | def flatten(self, lst): 76 | self.st("\n".join(str(x) 
for x in lst)) 77 | return 78 | 79 | @staticmethod 80 | def max(a, b): 81 | return a if a > b else b 82 | 83 | @staticmethod 84 | def min(a, b): 85 | return a if a < b else b 86 | 87 | @staticmethod 88 | def ceil(a, b): 89 | return a // b + int(a % b != 0) 90 | 91 | @staticmethod 92 | def accumulate(nums): 93 | n = len(nums) 94 | pre = [0] * (n + 1) 95 | for i in range(n): 96 | pre[i + 1] = pre[i] + nums[i] 97 | return pre 98 | 99 | 100 | class Solution: 101 | def __init__(self): 102 | return 103 | 104 | @staticmethod 105 | def main(ac=FastIO()): 106 | """ 107 | url: url of the problem 108 | tag: algorithm tag 109 | """ 110 | for _ in range(ac.read_int()): 111 | pass 112 | return 113 | 114 | 115 | Solution().main() 116 | -------------------------------------------------------------------------------- /tests/codeforces/problem_d.py: -------------------------------------------------------------------------------- 1 | from sys import stdin, stdout 2 | import bisect 3 | import decimal 4 | import heapq 5 | from types import GeneratorType 6 | import random 7 | from bisect import bisect_left, bisect_right 8 | from heapq import heappush, heappop, heappushpop 9 | from functools import cmp_to_key 10 | from collections import defaultdict, Counter, deque 11 | import math 12 | from functools import lru_cache 13 | from heapq import nlargest 14 | from functools import reduce 15 | from decimal import Decimal 16 | from itertools import combinations, permutations 17 | from operator import xor, add 18 | from operator import mul 19 | from typing import List, Callable, Dict, Set, Tuple, DefaultDict 20 | from heapq import heappush, heappop, heapify 21 | 22 | 23 | class FastIO: 24 | def __init__(self): 25 | self.random_seed = 0 26 | self.flush = False 27 | self.inf = 1 << 32 28 | self.dire4 = [(0, -1), (0, 1), (1, 0), (-1, 0)] 29 | self.dire8 = [(0, -1), (0, 1), (1, 0), (-1, 0)] + [(1, -1), (1, 1), (-1, -1), (-1, 1)] 30 | return 31 | 32 | @staticmethod 33 | def read_int(): 34 | return 
int(stdin.readline().rstrip()) 35 | 36 | @staticmethod 37 | def read_float(): 38 | return float(stdin.readline().rstrip()) 39 | 40 | @staticmethod 41 | def read_list_ints(): 42 | return list(map(int, stdin.readline().rstrip().split())) 43 | 44 | @staticmethod 45 | def read_list_ints_minus_one(): 46 | return list(map(lambda x: int(x) - 1, stdin.readline().rstrip().split())) 47 | 48 | @staticmethod 49 | def read_str(): 50 | return stdin.readline().rstrip() 51 | 52 | @staticmethod 53 | def read_list_strs(): 54 | return stdin.readline().rstrip().split() 55 | 56 | def get_random_seed(self): 57 | import random 58 | self.random_seed = random.randint(0, 10 ** 9 + 7) 59 | return 60 | 61 | def st(self, x): 62 | return print(x, flush=self.flush) 63 | 64 | def yes(self, s=None): 65 | self.st("Yes" if not s else s) 66 | return 67 | 68 | def no(self, s=None): 69 | self.st("No" if not s else s) 70 | return 71 | 72 | def lst(self, x): 73 | return print(*x, flush=self.flush) 74 | 75 | def flatten(self, lst): 76 | self.st("\n".join(str(x) for x in lst)) 77 | return 78 | 79 | @staticmethod 80 | def max(a, b): 81 | return a if a > b else b 82 | 83 | @staticmethod 84 | def min(a, b): 85 | return a if a < b else b 86 | 87 | @staticmethod 88 | def ceil(a, b): # integer ceiling of a / b (correct for either sign of b, b != 0) 89 | return a // b + int(a % b != 0) 90 | 91 | @staticmethod 92 | def accumulate(nums): 93 | n = len(nums) 94 | pre = [0] * (n + 1) 95 | for i in range(n): 96 | pre[i + 1] = pre[i] + nums[i] 97 | return pre 98 | 99 | 100 | class Solution: 101 | def __init__(self): 102 | return 103 | 104 | @staticmethod 105 | def main(ac=FastIO()): 106 | """ 107 | url: url of the problem 108 | tag: algorithm tag 109 | """ 110 | for _ in range(ac.read_int()): 111 | pass 112 | return 113 | 114 | 115 | Solution().main() 116 | -------------------------------------------------------------------------------- /src/string/suffix_array/template.py: -------------------------------------------------------------------------------- 1 | class SuffixArray: 2 | 3 |
def __init__(self): 4 | return 5 | 6 | @staticmethod 7 | def build(s, sig): 8 | # sa: index is rank and value is pos 9 | # rk: index is pos and value is rank 10 | # height: lcp of rank i-th suffix and (i-1)-th suffix (height[0] stays 0) 11 | # sum(height): count of duplicated substrings of s, i.e. n*(n+1)//2 minus the number of distinct ones 12 | # n*(n+1)//2 - sum(height): count of different substring of s 13 | # max(height): can compute the longest duplicate substring, 14 | # which is s[i: i + height[j]] and j = height.index(max(height)) and i = sa[j] 15 | # sig: number of unique rankings which initially is the size of the character set; s must be a mutable list of ints in range(sig), because rk = s[:] is written to and cnt is indexed by rk 16 | 17 | n = len(s) 18 | sa = list(range(n)) 19 | rk = s[:] 20 | ll = 0 # ll is the length that has already been sorted, and now it needs to be sorted by 2ll length 21 | tmp = [0] * n 22 | while True: 23 | p = [i for i in range(n - ll, n)] + [x - ll for i, x in enumerate(sa) if x >= ll] 24 | # p is the order by second keyword: suffixes with a length of at most ll have an empty second keyword, 25 | # so they rank smallest (in the first round ll == 0, so there are none) 26 | # for suffixes of other lengths, suffixes starting at 'sa[i]' rank i-th, and their 27 | # first ll characters happen to be the second keyword of suffixes starting at 'sa[i] - ll' 28 | # start counting sort on the first keyword (rk), stable with respect to the second-keyword order p 29 | # first, count how many values each has 30 | cnt = [0] * sig 31 | for i in range(n): 32 | cnt[rk[i]] += 1 33 | # prefix sums turn each count into the end position of that first-keyword bucket 34 | for i in range(1, sig): 35 | cnt[i] += cnt[i - 1] 36 | 37 | # then scan p backwards (keeps the sort stable) to place every suffix into the new sa 38 | for i in range(n - 1, -1, -1): 39 | w = rk[p[i]] 40 | cnt[w] -= 1 41 | sa[cnt[w]] = p[i] 42 | 43 | # two suffixes adjacent in the new sa get the same new rank iff both keywords are equal 44 | def equal(ii, jj, lll): 45 | if rk[ii] != rk[jj]: 46 | return False 47 | if ii + lll >= n and jj + lll >= n: 48 | return True 49 | if ii + lll < n and jj + lll < n: 50 | return rk[ii + lll] == rk[jj + lll] 51 | return False 52 | 53 | sig = -1 54 | for i in range(n): 55 | tmp[i] = 0 56 | 57 | for i in range(n): 58 |
# assign the new ranks in sa order: a new rank starts whenever equal() reports a change 59 | if i == 0 or not equal(sa[i], sa[i - 1], ll): 60 | sig += 1 61 | tmp[sa[i]] = sig 62 | 63 | for i in range(n): 64 | rk[i] = tmp[i] 65 | sig += 1 66 | if sig == n: # every suffix has its own rank, so sa is final 67 | break 68 | ll = ll << 1 if ll > 0 else 1 69 | 70 | # height via Kasai's algorithm: k drops by at most one from position i to position i + 1 71 | k = 0 72 | height = [0] * n 73 | for i in range(n): 74 | if rk[i] > 0: 75 | j = sa[rk[i] - 1] 76 | while i + k < n and j + k < n and s[i + k] == s[j + k]: 77 | k += 1 78 | height[rk[i]] = k 79 | k = 0 if k - 1 < 0 else k - 1 80 | return sa, rk, height 81 | -------------------------------------------------------------------------------- /src/math/gcd_like/example.py: -------------------------------------------------------------------------------- 1 | import math 2 | import random 3 | import unittest 4 | 5 | from src.math.gcd_like.template import GcdLike 6 | from src.util.fast_io import SEED 7 | 8 | random.seed(SEED) 9 | 10 | 11 | class TestGeneral(unittest.TestCase): 12 | 13 | def test_gcd_like(self): 14 | gl = GcdLike() 15 | n = 10 ** 5 16 | for _ in range(1000): 17 | a = random.randint(-n, n) 18 | b = random.randint(-n, n) 19 | gcd1, x, y = gl.extend_gcd(a, b) 20 | gcd2 = gl.binary_gcd(a, b) 21 | gcd3 = gl.general_gcd(a, b) 22 | gcd4 = math.gcd(a, b) 23 | assert gcd1 == gcd2 == gcd3 == gcd4 24 | assert a * x + b * y == gcd1 25 | 26 | for a in range(-1000, 1000): 27 | for b in range(-1000, 1000): 28 | gcd1, x, y = gl.extend_gcd(a, b) 29 | gcd2 = gl.binary_gcd(a, b) 30 | gcd3 = gl.general_gcd(a, b) 31 | gcd4 = math.gcd(a, b) 32 | assert gcd1 == gcd2 == gcd3 == gcd4 33 | assert a * x + b * y == gcd1 34 | return 35 | 36 | def test_gcd_like_extend_gcd(self): 37 | gl = GcdLike() 38 | n = 10 ** 5 39 | for _ in range(1000): 40 | a = random.randint(-n, n) 41 | b = random.randint(-n, n) 42 | gl.extend_gcd(a, b) 43 | 44 | for a in range(-1000, 1000): 45 | for b in range(-1000, 1000): 46 | gl.extend_gcd(a, b) 47 | return 48 | 49 | def test_gcd_like_binary_gcd(self): 50 | gl = GcdLike() 51 | n = 10 ** 5 52 | for _ in
range(1000): 53 | a = random.randint(-n, n) 54 | b = random.randint(-n, n) 55 | gl.binary_gcd(a, b) 56 | 57 | for a in range(-1000, 1000): 58 | for b in range(-1000, 1000): 59 | gl.binary_gcd(a, b) 60 | return 61 | 62 | def test_gcd_like_general_gcd(self): 63 | gl = GcdLike() 64 | n = 10 ** 5 65 | for _ in range(1000): 66 | a = random.randint(-n, n) 67 | b = random.randint(-n, n) 68 | gl.general_gcd(a, b) 69 | 70 | for a in range(-1000, 1000): 71 | for b in range(-1000, 1000): 72 | gl.general_gcd(a, b) 73 | return 74 | 75 | def test_gcd_like_math_gcd(self): 76 | n = 10 ** 5 77 | for _ in range(1000): 78 | a = random.randint(-n, n) 79 | b = random.randint(-n, n) 80 | math.gcd(a, b) 81 | 82 | for a in range(-1000, 1000): 83 | for b in range(-1000, 1000): 84 | math.gcd(a, b) 85 | return 86 | 87 | def test_gcd_like_mod_reverse(self): 88 | n = 10 ** 5 89 | gl = GcdLike() 90 | for _ in range(1000): 91 | a = random.randint(-n, n) 92 | b = random.randint(-n, n) 93 | if math.gcd(a, b) == 1 and b: 94 | assert gl.mod_reverse(a, b) == pow(a, -1, b) 95 | 96 | for a in range(-1000, 1000): 97 | for b in range(-1000, 1000): 98 | math.gcd(a, b) 99 | if math.gcd(a, b) == 1 and b: 100 | assert gl.mod_reverse(a, b) == pow(a, -1, b) 101 | return 102 | 103 | 104 | if __name__ == '__main__': 105 | unittest.main() 106 | -------------------------------------------------------------------------------- /src/basis/offline_query/problem.py: -------------------------------------------------------------------------------- 1 | """ 2 | Algorithm:offline_query|sorting 3 | Description:with the help of pointer and sorting for offline query 4 | 5 | ====================================LeetCode==================================== 6 | 100110(https://leetcode.cn/contest/weekly-contest-372/problems/find-building-where-alice-and-bob-can-meet/)offline_query|sorting 7 | 1851(https://leetcode.cn/problems/minimum-interval-to-include-each-query) 8 | 
2736(https://leetcode.cn/problems/maximum-sum-queries/description/) 9 | 10 | =====================================LuoGu====================================== 11 | xx(xxx)xxxxxxxxxxxxxxxxxxxx 12 | 13 | ===================================CodeForces=================================== 14 | 15 | ===================================AtCoder=================================== 16 | ABC245E(https://atcoder.jp/contests/abc245/tasks/abc245_e)sorted_list|offline_query 17 | 18 | """ 19 | from typing import List 20 | 21 | from src.struct.sorted_list.template import SortedList 22 | from src.util.fast_io import FastIO 23 | 24 | 25 | class Solution: 26 | def __int__(self): 27 | return 28 | 29 | @staticmethod 30 | def lc_100110(heights: List[int], queries: List[List[int]]) -> List[int]: 31 | """ 32 | url: https://leetcode.cn/contest/weekly-contest-372/problems/find-building-where-alice-and-bob-can-meet/ 33 | tag: offline_query|sorting 34 | """ 35 | m = len(queries) 36 | for i in range(m): 37 | x, y = queries[i] 38 | queries[i] = (i, x, y, max(heights[x], heights[y])) 39 | queries.sort(key=lambda it: -it[-1]) 40 | ans = [-1] * m 41 | 42 | n = len(heights) 43 | original = heights[:] 44 | heights = [(i, heights[i]) for i in range(n)] 45 | heights.sort(key=lambda it: -it[-1]) 46 | j = 0 47 | lst = SortedList() 48 | for i, x, y, c in queries: 49 | if x < y and original[x] < original[y]: 50 | ans[i] = y 51 | continue 52 | if y < x and original[y] < original[x]: 53 | ans[i] = x 54 | continue 55 | if x == y: 56 | ans[i] = y 57 | continue 58 | while j < n and heights[j][1] > c: 59 | lst.add(heights[j][0]) 60 | j += 1 61 | k = lst.bisect_right(max(x, y)) 62 | if 0 <= k < len(lst): 63 | ans[i] = lst[k] 64 | return ans 65 | 66 | @staticmethod 67 | def abc_245e(ac=FastIO()): 68 | """ 69 | url: https://atcoder.jp/contests/abc245/tasks/abc245_e 70 | tag: sorted_list|offline_query 71 | """ 72 | n, m = ac.read_list_ints() 73 | a = ac.read_list_ints() 74 | b = ac.read_list_ints() 75 | c = 
ac.read_list_ints() 76 | d = ac.read_list_ints() 77 | ind1 = list(range(n)) 78 | ind2 = list(range(m)) 79 | ind1.sort(key=lambda it: -a[it]) 80 | ind2.sort(key=lambda it: -c[it]) 81 | lst = SortedList() 82 | j = 0 83 | for i in ind1: 84 | aa, bb = a[i], b[i] 85 | while j < m and c[ind2[j]] >= aa: 86 | lst.add(d[ind2[j]]) 87 | j += 1 88 | if not lst or lst[-1] < bb: 89 | ac.no() 90 | break 91 | i = lst.bisect_left(bb) 92 | lst.pop(i) 93 | else: 94 | ac.yes() 95 | return 96 | -------------------------------------------------------------------------------- /src/math/comb_perm/template.py: -------------------------------------------------------------------------------- 1 | class Combinatorics: 2 | def __init__(self, n, mod): 3 | assert mod > n 4 | self.n = n + 10 5 | self.mod = mod 6 | 7 | self.perm = [1] 8 | self.rev = [1] 9 | self.inv = [0] 10 | self.fault = [0] 11 | 12 | self.build_perm() 13 | self.build_rev() 14 | self.build_inv() 15 | self.build_fault() 16 | return 17 | 18 | def build_perm(self): 19 | self.perm = [1] * (self.n + 1) # (i!) 
% mod 20 | for i in range(1, self.n + 1): 21 | self.perm[i] = self.perm[i - 1] * i % self.mod 22 | return 23 | 24 | def build_rev(self): 25 | self.rev = [1] * (self.n + 1) # pow(i!, -1, mod) 26 | self.rev[-1] = pow(self.perm[-1], -1, self.mod) # GcdLike().mod_reverse(self.perm[-1], self.mod) 27 | for i in range(self.n - 1, 0, -1): 28 | self.rev[i] = (self.rev[i + 1] * (i + 1) % self.mod) # pow(i!, -1, mod) 29 | return 30 | 31 | def build_inv(self): 32 | self.inv = [0] * (self.n + 1) # pow(i, -1, mod) 33 | self.inv[1] = 1 34 | for i in range(2, self.n + 1): 35 | self.inv[i] = (self.mod - self.mod // i) * self.inv[self.mod % i] % self.mod 36 | return 37 | 38 | def build_fault(self): 39 | self.fault = [0] * (self.n + 1) # fault permutation 40 | self.fault[0] = 1 41 | self.fault[2] = 1 42 | for i in range(3, self.n + 1): 43 | self.fault[i] = (i - 1) * (self.fault[i - 1] + self.fault[i - 2]) 44 | self.fault[i] %= self.mod 45 | return 46 | 47 | def comb(self, a, b): 48 | if a < b: 49 | return 0 50 | res = self.perm[a] * self.rev[b] * self.rev[a - b] # comb(a, b) % mod = (a!/(b!(a-b)!)) % mod 51 | return res % self.mod 52 | 53 | def factorial(self, a): 54 | res = self.perm[a] # (a!) 
% mod 55 | return res % self.mod 56 | 57 | def inverse(self, n): 58 | res = self.perm[n - 1] * self.rev[n] % self.mod # pow(n, -1, mod) 59 | return res 60 | 61 | def catalan(self, n): 62 | res = (self.comb(2 * n, n) - self.comb(2 * n, n - 1)) % self.mod 63 | return res 64 | 65 | 66 | class Lucas: 67 | def __init__(self): 68 | # comb(a,b) % p 69 | return 70 | 71 | @staticmethod 72 | def comb(n, m, p): 73 | # comb(n, m) % p 74 | ans = 1 75 | for x in range(n - m + 1, n + 1): 76 | ans *= x 77 | ans %= p 78 | for x in range(1, m + 1): 79 | ans *= pow(x, -1, p) 80 | ans %= p 81 | return ans 82 | 83 | def lucas_iter(self, n, m, p): 84 | # math.comb(n, m) % p where p is prime 85 | if m == 0: 86 | return 1 87 | stack = [[n, m]] 88 | dct = dict() 89 | while stack: 90 | n, m = stack.pop() 91 | if n >= 0: 92 | if m == 0: 93 | dct[(n, m)] = 1 94 | continue 95 | stack.append((~n, m)) 96 | stack.append((n // p, m // p)) 97 | else: 98 | n = ~n 99 | dct[(n, m)] = (self.comb(n % p, m % p, p) % p) * dct[(n // p, m // p)] % p 100 | return dct[(n, m)] 101 | 102 | @staticmethod 103 | def extend_lucas(self, n, m, p): 104 | # math.comb(n, m) % p where p is not necessary prime 105 | return 106 | -------------------------------------------------------------------------------- /src/math/convex_hull/problem.py: -------------------------------------------------------------------------------- 1 | """ 2 | Algorithm:convex_hull|minimum_circle_coverage|random_increment_method 3 | Description:convex_hull 4 | 5 | ====================================LeetCode==================================== 6 | 1924(https://leetcode.cn/problems/erect-the-fence-ii/)convex_hull|tripart_pack_tripart|minimum_circle_coverage 7 | 8 | =====================================LuoGu====================================== 9 | P1742(https://www.luogu.com.cn/problem/P1742)random_increment_method|minimum_circle_coverage 10 | 
P3517(https://www.luogu.com.cn/problem/P3517)binary_search_of_binary_search|random_increment_method|minimum_circle_coverage 11 | 12 | """ 13 | 14 | from typing import List 15 | 16 | from src.math.convex_hull.template import MinCircleOverlap 17 | from src.util.fast_io import FastIO 18 | 19 | 20 | class Solution: 21 | def __init__(self): 22 | return 23 | 24 | @staticmethod 25 | def lc_1924(trees: List[List[int]]) -> List[float]: 26 | """ 27 | url: https://leetcode.cn/problems/erect-the-fence-ii/ 28 | tag: convex_hull|tripart_pack_tripart|minimum_circle_coverage 29 | """ 30 | ans = MinCircleOverlap().get_min_circle_overlap(trees) 31 | return list(ans) 32 | 33 | @staticmethod 34 | def lg_p1742(ac=FastIO()): 35 | """ 36 | url: https://www.luogu.com.cn/problem/P1742 37 | tag: random_increment_method|minimum_circle_coverage 38 | """ 39 | n = ac.read_int() 40 | nums = [ac.read_list_floats() for _ in range(n)] 41 | x, y, r = MinCircleOverlap().get_min_circle_overlap(nums) 42 | ac.st(r) 43 | ac.lst([x, y]) 44 | return 45 | 46 | @staticmethod 47 | def lg_3517(ac=FastIO()): 48 | """ 49 | url: https://www.luogu.com.cn/problem/P3517 50 | tag: binary_search_of_binary_search|random_increment_method|minimum_circle_coverage 51 | """ 52 | n, m = ac.read_list_ints() 53 | nums = [ac.read_list_ints() for _ in range(n)] 54 | 55 | def check(r): 56 | 57 | def circle(lst): 58 | x, y, rr = MinCircleOverlap().get_min_circle_overlap(lst) 59 | return x, y, rr 60 | 61 | cnt = i = 0 62 | res = [] 63 | while i < n: 64 | left = i 65 | right = n - 1 66 | while left < right - 1: 67 | mm = left + (right - left) // 2 68 | if circle(nums[i:mm + 1])[2] <= r: 69 | left = mm 70 | else: 71 | right = mm 72 | ll = circle(nums[i:right + 1]) 73 | if ll[2] > r: 74 | ll = circle(nums[i:left + 1]) 75 | i = left + 1 76 | else: 77 | i = right + 1 78 | res.append(ll[:-1]) 79 | cnt += 1 80 | return res, cnt <= m 81 | 82 | low = 0 83 | high = 4 * 10 ** 6 84 | error = 10 ** (-6) 85 | while low < high - error: 86 | mid = 
low + (high - low) / 2 87 | if check(mid)[1]: 88 | high = mid 89 | else: 90 | low = mid 91 | 92 | nodes, flag = check(low) 93 | rrr = low 94 | if not flag: 95 | nodes, flag = check(high) 96 | rrr = high 97 | ac.st(rrr) 98 | ac.st(len(nodes)) 99 | for a in nodes: 100 | ac.lst([round(a[0], 10), round(a[1], 10)]) 101 | return 102 | -------------------------------------------------------------------------------- /src/string/kmp/template.py: -------------------------------------------------------------------------------- 1 | class KMP: 2 | def __init__(self): 3 | return 4 | 5 | @classmethod 6 | def prefix_function(cls, s): 7 | """calculate the longest common true prefix and true suffix for s [:i+1] and s [:i+1]""" 8 | n = len(s) # fail tree 9 | pi = [0] * n 10 | for i in range(1, n): 11 | j = pi[i - 1] 12 | while j > 0 and s[i] != s[j]: 13 | j = pi[j - 1] 14 | if s[i] == s[j]: # all pi[i] pi[pi[i]] ... are border 15 | j += 1 # all i+1-pi[i] pi[i]+1-pi[pi[i]] ... are circular_section 16 | pi[i] = j # pi[i] <= i also known as next 17 | # pi[0] = 0 18 | return pi # longest common true prefix_suffix / i+1-nex[i] is shortest circular_section 19 | 20 | @staticmethod 21 | def z_function(s): 22 | """calculate the longest common prefix between s[i:] and s""" 23 | n = len(s) 24 | z = [0] * n 25 | left, r = 0, 0 26 | for i in range(1, n): 27 | if i <= r and z[i - left] < r - i + 1: 28 | z[i] = z[i - left] 29 | else: 30 | z[i] = max(0, r - i + 1) 31 | while i + z[i] < n and s[z[i]] == s[i + z[i]]: 32 | z[i] += 1 33 | if i + z[i] - 1 > r: 34 | left = i 35 | r = i + z[i] - 1 36 | # z[0] = 0 37 | return z 38 | 39 | def prefix_function_reverse(self, s): 40 | n = len(s) 41 | nxt = [0] + self.prefix_function(s) 42 | nxt[1] = 0 43 | for i in range(2, n + 1): 44 | j = i 45 | while nxt[j]: 46 | j = nxt[j] 47 | if nxt[i]: 48 | nxt[i] = j 49 | return nxt[1:] # shortest common true prefix_suffix / i+1-nex[i] is longest circular_section 50 | 51 | def find(self, s1, s2): 52 | """find the index 
position of s2 in s1""" 53 | n, m = len(s1), len(s2) 54 | pi = self.prefix_function(s2 + "#" + s1) 55 | ans = [] 56 | for i in range(m + 1, m + n + 1): 57 | if pi[i] == m: 58 | ans.append(i - m - m) 59 | return ans 60 | 61 | def find_lst(self, s1, s2, tag=-1): 62 | """find the index position of s2 in s1""" 63 | n, m = len(s1), len(s2) 64 | pi = self.prefix_function(s2 + [tag] + s1) 65 | ans = [] 66 | for i in range(m + 1, m + n + 1): 67 | if pi[i] == m: 68 | ans.append(i - m - m) 69 | return ans 70 | 71 | def find_longest_palindrome(self, s, pos="prefix") -> int: 72 | """calculate the longest prefix and longest suffix palindrome substring""" 73 | if pos == "prefix": 74 | return self.prefix_function(s + "#" + s[::-1])[-1] 75 | return self.prefix_function(s[::-1] + "#" + s)[-1] 76 | 77 | @staticmethod 78 | def kmp_automaton(s, m=26): 79 | n = len(s) 80 | nxt = [0] * m * (n + 1) 81 | j = 0 82 | for i in range(1, n + 1): 83 | j = nxt[j * m + s[i - 1]] 84 | nxt[(i - 1) * m + s[i - 1]] = i 85 | for k in range(m): 86 | nxt[i * m + k] = nxt[j * m + k] 87 | return nxt 88 | 89 | @classmethod 90 | def merge_b_from_a(cls, a, b): 91 | c = b + "#" + a 92 | f = cls.prefix_function(c) 93 | m = len(b) 94 | if max(f[m:]) == m: 95 | return a 96 | x = f[-1] 97 | return a + b[x:] 98 | 99 | 100 | class InfiniteStream: 101 | def next(self) -> int: 102 | pass 103 | -------------------------------------------------------------------------------- /src/graph/binary_search_tree/problem.py: -------------------------------------------------------------------------------- 1 | """ 2 | Algorithm:binary_search_tree|binary_search_tree|array_to_bst|implemention 3 | Description:build a binary_search_tree by the order of array 4 | 5 | 6 | ====================================LeetCode==================================== 7 | 1569(https://leetcode.cn/problems/number-of-ways-to-reorder-array-to-get-same-bst/)array_to_bst|dp|comb|counter|specific_plan 8 | 
1902(https://leetcode.cn/problems/depth-of-bst-given-insertion-order/)array_to_bst|tree_depth|implemention 9 | 10 | =====================================LuoGu====================================== 11 | P2171(https://www.luogu.com.cn/problem/P2171)array_to_bst|reverse_order|union_find|implemention 12 | 13 | """ 14 | from typing import List 15 | 16 | from src.graph.binary_search_tree.template import BinarySearchTree 17 | from src.math.comb_perm.template import Combinatorics 18 | from src.util.fast_io import FastIO 19 | 20 | 21 | class Solution: 22 | def __init__(self): 23 | return 24 | 25 | @staticmethod 26 | def lg_p2171(ac=FastIO()): 27 | """ 28 | url: https://www.luogu.com.cn/problem/P2171 29 | tag: array_to_bst|reverse_order|union_find|implemention 30 | """ 31 | ac.read_int() 32 | nums = ac.read_list_ints() 33 | dct = BinarySearchTree().build_with_unionfind(nums) # or build_with_stack 34 | ans = [] 35 | depth = 0 36 | stack = [[0, 1]] 37 | while stack: 38 | i, d = stack.pop() 39 | if i >= 0: 40 | stack.append([~i, d]) 41 | dct[i].sort(key=lambda it: -nums[it]) 42 | for j in dct[i]: 43 | stack.append([j, d + 1]) 44 | else: 45 | i = ~i 46 | depth = depth if depth > d else d 47 | ans.append(nums[i]) 48 | ac.st(f"deep={depth}") 49 | for a in ans: 50 | ac.st(a) 51 | return 52 | 53 | @staticmethod 54 | def lc_1569(nums: List[int]) -> int: 55 | """ 56 | url: https://leetcode.cn/problems/number-of-ways-to-reorder-array-to-get-same-bst/ 57 | tag: array_to_bst|dp|comb|counter|specific_plan 58 | """ 59 | mod = 10 ** 9 + 7 60 | cb = Combinatorics(1000, mod) 61 | dct = BinarySearchTree().build_with_unionfind(nums) # build_with_stack is also ok 62 | stack = [0] 63 | n = len(nums) 64 | ans = [0] * n 65 | sub = [0] * n 66 | while stack: 67 | i = stack.pop() 68 | if i >= 0: 69 | stack.append(~i) 70 | for j in dct[i]: 71 | stack.append(j) 72 | else: 73 | i = ~i 74 | cur_ans = 1 75 | cur_sub = sum(sub[j] for j in dct[i]) 76 | sub[i] = cur_sub + 1 77 | for j in dct[i]: 78 | cur_ans 
*= cb.comb(cur_sub, sub[j]) * ans[j] 79 | cur_sub -= sub[j] 80 | cur_ans %= mod 81 | ans[i] = cur_ans 82 | return (ans[0] - 1) % mod 83 | 84 | @staticmethod 85 | def lc_1902(order: List[int]) -> int: 86 | """ 87 | url: https://leetcode.cn/problems/depth-of-bst-given-insertion-order/ 88 | tag: array_to_bst|tree_depth|implemention 89 | """ 90 | dct = BinarySearchTree().build_with_stack(order) # or build_with_unionfind 91 | stack = [[0, 1]] 92 | ans = 1 93 | while stack: 94 | i, d = stack.pop() 95 | for j in dct[i]: 96 | stack.append([j, d + 1]) 97 | ans = ans if ans > d + 1 else d + 1 98 | return ans 99 | -------------------------------------------------------------------------------- /src/graph/bipartite_matching/template.py: -------------------------------------------------------------------------------- 1 | class BipartiteMatching: 2 | def __init__(self, n, m): 3 | self._n = n 4 | self._m = m 5 | self._to = [[] for _ in range(n)] 6 | 7 | def add_edge(self, a, b): 8 | self._to[a].append(b) 9 | 10 | def solve(self): 11 | n, m, to = self._n, self._m, self._to 12 | prev = [-1] * n 13 | root = [-1] * n 14 | p = [-1] * n 15 | q = [-1] * m 16 | updated = True 17 | while updated: 18 | updated = False 19 | s = [] 20 | s_front = 0 21 | for i in range(n): 22 | if p[i] == -1: 23 | root[i] = i 24 | s.append(i) 25 | while s_front < len(s): 26 | v = s[s_front] 27 | s_front += 1 28 | if p[root[v]] != -1: 29 | continue 30 | for u in to[v]: 31 | if q[u] == -1: 32 | while u != -1: 33 | q[u] = v 34 | p[v], u = u, p[v] 35 | v = prev[v] 36 | updated = True 37 | break 38 | u = q[u] 39 | if prev[u] != -1: 40 | continue 41 | prev[u] = v 42 | root[u] = root[v] 43 | s.append(u) 44 | if updated: 45 | for i in range(n): 46 | prev[i] = -1 47 | root[i] = -1 48 | return [(v, p[v]) for v in range(n) if p[v] != -1] 49 | 50 | 51 | class Hungarian: 52 | def __init__(self): 53 | # Bipartite graph maximum math without weight 54 | return 55 | 56 | @staticmethod 57 | def dfs_recursion(n, m, dct): 58 | 
assert len(dct) == m 59 | 60 | def hungarian(i): 61 | for j in dct[i]: 62 | if not visit[j]: 63 | visit[j] = True 64 | if match[j] == -1 or hungarian(match[j]): 65 | match[j] = i 66 | return True 67 | return False 68 | 69 | # left group size is n 70 | match = [-1] * n 71 | ans = 0 72 | for x in range(m): 73 | # right group size is m 74 | visit = [False] * n 75 | if hungarian(x): 76 | ans += 1 77 | return ans 78 | 79 | @staticmethod 80 | def bfs_iteration(n, m, dct): 81 | 82 | assert len(dct) == m 83 | 84 | match = [-1] * n 85 | ans = 0 86 | for i in range(m): 87 | hungarian = [0] * m 88 | visit = [0] * n 89 | stack = [[i, 0]] 90 | while stack: 91 | x, ind = stack[-1] 92 | if ind == len(dct[x]) or hungarian[x]: 93 | stack.pop() 94 | continue 95 | y = dct[x][ind] 96 | if not visit[y]: 97 | visit[y] = 1 98 | if match[y] == -1: 99 | match[y] = x 100 | hungarian[x] = 1 101 | else: 102 | stack.append([match[y], 0]) 103 | else: 104 | if hungarian[match[y]]: 105 | match[y] = x 106 | hungarian[x] = 1 107 | stack[-1][1] += 1 108 | if hungarian[i]: 109 | ans += 1 110 | return ans 111 | -------------------------------------------------------------------------------- /src/string/string_hash/example.py: -------------------------------------------------------------------------------- 1 | import random 2 | import unittest 3 | 4 | from src.string.string_hash.template import PointSetRangeHashReverse, RangeSetRangeHashReverse 5 | 6 | 7 | class TestGeneral(unittest.TestCase): 8 | 9 | def test_string_hash(self): 10 | n = 1000 11 | st = "".join([chr(random.randint(0, 25) + ord("a")) for _ in range(n)]) 12 | 13 | p1, p2 = random.randint(26, 100), random.randint(26, 100) 14 | mod1, mod2 = random.randint( 15 | 10 ** 9 + 7, 2 ** 31 - 1), random.randint(10 ** 9 + 7, 2 ** 31 - 1) 16 | 17 | target = "".join([chr(random.randint(0, 25) + ord("a")) 18 | for _ in range(10)]) 19 | h1 = h2 = 0 20 | for w in target: 21 | h1 = h1 * p1 + (ord(w) - ord("a")) 22 | h1 %= mod1 23 | h2 = h2 * p2 + (ord(w) - 
ord("a")) 24 | h2 %= mod1 25 | 26 | # sliding_window|hash 27 | m = len(target) 28 | pow1 = pow(p1, m - 1, mod1) 29 | pow2 = pow(p2, m - 1, mod2) 30 | s1 = s2 = 0 31 | cnt = 0 32 | n = len(st) 33 | for i in range(n): 34 | w = st[i] 35 | s1 = s1 * p1 + (ord(w) - ord("a")) 36 | s1 %= mod1 37 | s2 = s2 * p2 + (ord(w) - ord("a")) 38 | s2 %= mod1 39 | if i >= m - 1: 40 | if (s1, s2) == (h1, h2): 41 | cnt += 1 42 | s1 = s1 - (ord(st[i - m + 1]) - ord("a")) * pow1 43 | s1 %= mod1 44 | s2 = s2 - (ord(st[i - m + 1]) - ord("a")) * pow2 45 | s2 %= mod1 46 | if st[i:i + m] == target: 47 | cnt -= 1 48 | assert cnt == 0 49 | return 50 | 51 | def test_point_set_range_hash_reverse(self): 52 | 53 | n = 10 ** 4 54 | nums = [0] * n 55 | tree = PointSetRangeHashReverse(n) 56 | for _ in range(1000): 57 | i = random.randint(0, n - 1) 58 | num = random.randint(0, n - 1) 59 | nums[i] = num 60 | tree.point_set(i, i, num) 61 | ll = random.randint(0, n - 1) 62 | rr = random.randint(ll, n - 1) 63 | res = 0 64 | for j in range(ll, rr + 1): 65 | res = (res * tree.p + nums[j]) % tree.mod 66 | assert res == tree.range_hash(ll, rr) 67 | res = 0 68 | for j in range(rr, ll - 1, -1): 69 | res = (res * tree.p + nums[j]) % tree.mod 70 | assert res == tree.range_hash_reverse(ll, rr) 71 | assert nums == tree.get() 72 | return 73 | 74 | def test_range_change_range_hash_reverse(self): 75 | 76 | n = 10 ** 4 77 | nums = [0] * n 78 | tree = RangeSetRangeHashReverse(n) 79 | for _ in range(1000): 80 | ll = random.randint(0, n - 1) 81 | rr = random.randint(ll, n - 1) 82 | num = random.randint(0, n - 1) 83 | for i in range(ll, rr + 1): 84 | nums[i] = num 85 | tree.range_set(ll, rr, num) 86 | 87 | ll = random.randint(0, n - 1) 88 | rr = random.randint(ll, n - 1) 89 | res = 0 90 | for j in range(ll, rr + 1): 91 | res = (res * tree.p + nums[j]) % tree.mod 92 | assert res == tree.range_hash(ll, rr) 93 | res = 0 94 | for j in range(rr, ll - 1, -1): 95 | res = (res * tree.p + nums[j]) % tree.mod 96 | assert res == 
tree.range_hash_reverse(ll, rr) 97 | assert nums == tree.get() 98 | return 99 | 100 | 101 | if __name__ == '__main__': 102 | unittest.main() 103 | -------------------------------------------------------------------------------- /src/math/extend_crt/problem.py: -------------------------------------------------------------------------------- 1 | """ 2 | Algorithm:chinese_remainder_theorem|extended_chinese_remainder_theorem|ex_crt|crt 3 | Description:equation|same_mod 4 | 5 | 6 | ====================================LeetCode==================================== 7 | 8 | =====================================LuoGu====================================== 9 | p1495(https://www.luogu.com.cn/problem/p1495)mod_coprime|chinese_reminder_theorem|classical 10 | P4777(https://www.luogu.com.cn/problem/P4777)mod_not_coprime|crt|chinese_reminder_theorem|classical 11 | P3868(https://www.luogu.com.cn/problem/P3868)ex_crt|chinese_reminder_theorem|classical 12 | 13 | ====================================AtCoder===================================== 14 | ABC286F(https://atcoder.jp/contests/abc286/tasks/abc286_f)chinese_reminder_theorem|interaction|circular_section|classical 15 | ABC371G(https://atcoder.jp/contests/abc371/tasks/abc371_g)ex_crt|implemention|greedy|classical 16 | 17 | 18 | ===================================CodeForces=================================== 19 | 20 | """ 21 | from src.math.extend_crt.template import CRT, ExtendCRT 22 | from src.math.number_theory.template import PrimeSieve 23 | from src.util.fast_io import FastIO 24 | 25 | 26 | class Solution: 27 | def __int__(self): 28 | return 29 | 30 | @staticmethod 31 | def main(ac=FastIO()): 32 | """ 33 | url: https://atcoder.jp/contests/abc286/tasks/abc286_f 34 | tag: chinese_reminder_theorem|interaction|circular_section|classical 35 | """ 36 | ac.flush = True 37 | lst = [x for x in PrimeSieve().eratosthenes_sieve(110) if x < 110] 38 | tot = lst[:9] 39 | tot[0] *= tot[0] 40 | tot[1] *= tot[1] 41 | m = sum(tot) 42 | assert m == 
108 43 | nums = list(range(1, m + 1)) 44 | pre = 0 45 | circle = dict() 46 | for num in tot: 47 | tmp = nums[pre:pre + num] 48 | nums[pre:pre + num] = tmp[1:] + tmp[:1] 49 | circle[pre + 1] = tmp[1:] + tmp[:1] 50 | pre += num 51 | ac.st(m) 52 | ac.lst(nums) 53 | b = ac.read_list_ints_minus_one() 54 | mod_res = [] 55 | pre = 0 56 | for num in tot: 57 | tmp = b[pre:pre + num] 58 | mod_res.append((num, (tmp[0] - pre) % num)) 59 | pre += num 60 | ans = CRT().chinese_remainder(mod_res) 61 | ac.st(ans) 62 | return 63 | 64 | @staticmethod 65 | def lg_p1495(ac=FastIO()): 66 | """ 67 | url: https://www.luogu.com.cn/problem/P1495 68 | tag: chinese_reminder_theorem|classical 69 | """ 70 | n = ac.read_int() 71 | crt = CRT() 72 | pairs = [ac.read_list_ints() for _ in range(n)] 73 | ans = crt.chinese_remainder(pairs) 74 | ac.st(ans) 75 | return 76 | 77 | @staticmethod 78 | def lg_p4777(ac=FastIO()): 79 | """ 80 | url: https://www.luogu.com.cn/problem/P4777 81 | tag: chinese_reminder_theorem|classical 82 | """ 83 | n = ac.read_int() 84 | ex_crt = ExtendCRT() 85 | pairs = [ac.read_list_ints()[::-1] for _ in range(n)] 86 | ans = ex_crt.pipline(pairs)[0] 87 | ac.st(ans) 88 | return 89 | 90 | @staticmethod 91 | def lg_p3868(ac=FastIO()): 92 | """ 93 | url: https://www.luogu.com.cn/problem/P3868 94 | tag: chinese_reminder_theorem|classical 95 | """ 96 | ac.read_int() 97 | a = ac.read_list_ints() 98 | b = ac.read_list_ints() 99 | ex_crt = ExtendCRT() 100 | pairs = [[x % y, y] for x, y in zip(a, b)] 101 | ans = ex_crt.pipline(pairs)[0] 102 | ac.st(ans) 103 | return 104 | -------------------------------------------------------------------------------- /src/string/automaton/template.py: -------------------------------------------------------------------------------- 1 | from collections import deque 2 | import math 3 | 4 | 5 | class Node: 6 | __slots__ = 'son', 'fail', 'last', 'len', 'val' 7 | 8 | def __init__(self): 9 | self.son = {} 10 | self.fail = self.last = None 11 | self.len = 0 12 
| self.val = math.inf 13 | 14 | 15 | class AhoCorasick: 16 | def __init__(self): 17 | self.root = Node() 18 | 19 | def insert(self, word, cost): 20 | x = self.root 21 | for c in word: 22 | if c not in x.son: 23 | x.son[c] = Node() 24 | x = x.son[c] 25 | x.len = len(word) 26 | x.val = min(x.val, cost) 27 | 28 | def set_fail(self): 29 | q = deque() 30 | for x in self.root.son.values(): 31 | x.fail = x.last = self.root 32 | q.append(x) 33 | while q: 34 | x = q.popleft() 35 | for c, son in x.son.items(): 36 | p = x.fail 37 | while p and c not in p.son: 38 | p = p.fail 39 | son.fail = p.son[c] if p else self.root 40 | son.last = son.fail if son.fail.len else son.fail.last 41 | q.append(son) 42 | 43 | def search(self, target): 44 | pos = [[] for _ in range(len(target))] 45 | x = self.root 46 | for i, c in enumerate(target): 47 | while x and c not in x.son: 48 | x = x.fail 49 | x = x.son[c] if x else self.root 50 | cur = x 51 | while cur: 52 | if cur.len: 53 | pos[i - cur.len + 1].append(cur.val) 54 | cur = cur.last 55 | return pos 56 | 57 | 58 | class AcAutomaton: 59 | def __init__(self, p): 60 | self.m = sum(len(t) for t in p) 61 | self.n = len(p) 62 | self.p = p 63 | self.tr = [[0] * 26 for _ in range(self.m + 1)] 64 | self.end = [0] * (self.m + 1) 65 | self.fail = [0] * (self.m + 1) 66 | self.cnt = 0 67 | for i, t in enumerate(self.p): 68 | self.insert(i + 1, t) 69 | self.set_fail() 70 | return 71 | 72 | def insert(self, i: int, word: str): 73 | x = 0 74 | for c in word: 75 | c = ord(c) - ord('a') 76 | if self.tr[x][c] == 0: 77 | self.cnt += 1 78 | self.tr[x][c] = self.cnt 79 | x = self.tr[x][c] 80 | self.end[i] = x 81 | 82 | def search(self, s): 83 | freq = [0] * (self.cnt + 1) 84 | x = 0 85 | for c in s: 86 | x = self.tr[x][ord(c) - ord('a')] 87 | freq[x] += 1 88 | 89 | rg = [[] for _ in range(self.cnt + 1)] 90 | for i in range(self.cnt + 1): 91 | rg[self.fail[i]].append(i) 92 | 93 | vis = [False] * (self.cnt + 1) 94 | st = [0] 95 | while st: 96 | x = st[-1] 97 | if 
not vis[x]: 98 | vis[x] = True 99 | for y in rg[x]: 100 | st.append(y) 101 | else: 102 | st.pop() 103 | for y in rg[x]: 104 | freq[x] += freq[y] 105 | 106 | res = [freq[self.end[i]] for i in range(1, self.n + 1)] 107 | return res 108 | 109 | def set_fail(self): 110 | q = deque([self.tr[0][i] for i in range(26) if self.tr[0][i]]) 111 | while q: 112 | x = q.popleft() 113 | for i in range(26): 114 | if self.tr[x][i] == 0: 115 | self.tr[x][i] = self.tr[self.fail[x]][i] 116 | else: 117 | self.fail[self.tr[x][i]] = self.tr[self.fail[x]][i] 118 | q.append(self.tr[x][i]) 119 | return 120 | -------------------------------------------------------------------------------- /src/dp/matrix_dp/template.py: -------------------------------------------------------------------------------- 1 | import math 2 | 3 | 4 | class MatrixDP: 5 | def __init__(self): 6 | return 7 | 8 | @staticmethod 9 | def lcp(s, t): 10 | # longest common prefix of s[i:] and t[j:] 11 | m, n = len(s), len(t) 12 | dp = [[0] * (n + 1) for _ in range(m + 1)] 13 | for i in range(n - 1, -1, -1): 14 | for j in range(n - 1, -1, -1): 15 | if s[i] == t[j]: 16 | dp[i][j] = dp[i + 1][j + 1] + 1 17 | return dp 18 | 19 | @staticmethod 20 | def min_distance(word1: str, word2: str): 21 | m, n = len(word1), len(word2) 22 | dp = [[math.inf] * (n + 1) for _ in range(m + 1)] 23 | # edit distance 24 | for i in range(m + 1): 25 | dp[i][n] = m - i 26 | for j in range(n + 1): 27 | dp[m][j] = n - j 28 | for i in range(m - 1, -1, -1): 29 | for j in range(n - 1, -1, -1): 30 | dp[i][j] = min(dp[i + 1][j] + 1, dp[i][j + 1] + 1, 31 | dp[i + 1][j + 1] + int(word1[i] != word2[j])) 32 | return dp[0][0] 33 | 34 | @staticmethod 35 | def path_mul_mod(m, n, k, grid): 36 | # calculate the modulus of the product of the matrix from the upper left corner to the lower right corner 37 | dp = [[set() for _ in range(n)] for _ in range(m)] 38 | dp[0][0].add(grid[0][0] % k) 39 | for i in range(1, m): 40 | x = grid[i][0] 41 | for p in dp[i - 1][0]: 42 | 
dp[i][0].add((p * x) % k) 43 | for j in range(1, n): 44 | x = grid[0][j] 45 | for p in dp[0][j - 1]: 46 | dp[0][j].add((p * x) % k) 47 | 48 | for i in range(1, m): 49 | for j in range(1, n): 50 | x = grid[i][j] 51 | for p in dp[i][j - 1]: 52 | dp[i][j].add((p * x) % k) 53 | for p in dp[i - 1][j]: 54 | dp[i][j].add((p * x) % k) 55 | ans = sorted(list(dp[-1][-1])) 56 | return ans 57 | 58 | @staticmethod 59 | def maximal_square(matrix) -> int: 60 | 61 | # The maximum square sub matrix with all value equal to 1 62 | m, n = len(matrix), len(matrix[0]) 63 | dp = [[0] * (n + 1) for _ in range(m + 1)] 64 | ans = 0 65 | for i in range(m): 66 | for j in range(n): 67 | if matrix[i][j] == "1": 68 | dp[i + 1][j + 1] = min(dp[i][j], dp[i + 1][j], dp[i][j + 1]) + 1 69 | if dp[i + 1][j + 1] > ans: 70 | ans = dp[i + 1][j + 1] 71 | # the ans is side length and ans**2 is area 72 | return ans ** 2 73 | 74 | @staticmethod 75 | def longest_common_sequence(s1, s2, s3) -> str: 76 | # Longest common subsequence LCS can be extended to 3D and 4D or higher dimension 77 | m, n, k = len(s1), len(s2), len(s3) 78 | # length of lcs 79 | dp = [[[0] * (k + 1) for _ in range(n + 1)] for _ in range(m + 1)] 80 | # example of lcs 81 | res = [[[""] * (k + 1) for _ in range(n + 1)] for _ in range(m + 1)] 82 | for i in range(m): 83 | for j in range(n): 84 | for p in range(k): 85 | if s1[i] == s2[j] == s3[p]: 86 | if dp[i + 1][j + 1][p + 1] < dp[i][j][p] + 1: 87 | dp[i + 1][j + 1][p + 1] = dp[i][j][p] + 1 88 | res[i + 1][j + 1][p + 1] = res[i][j][p] + s1[i] 89 | else: 90 | for a, b, c in [[1, 1, 0], [0, 1, 1], [1, 0, 1]]: # transfer formula 91 | if dp[i + 1][j + 1][p + 1] < dp[i + a][j + b][p + c]: 92 | dp[i + 1][j + 1][p + 1] = dp[i + a][j + b][p + c] 93 | res[i + 1][j + 1][p + 1] = res[i + a][j + b][p + c] 94 | return res[m][n][k] 95 | -------------------------------------------------------------------------------- /src/math/fast_power/template.py: 
-------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | class FastPower: 5 | def __init__(self): 6 | return 7 | 8 | @staticmethod 9 | def fast_power_api(a, b, mod): 10 | return pow(a, b, mod) 11 | 12 | @staticmethod 13 | def fast_power(a, b, mod): 14 | a = a % mod 15 | res = 1 16 | while b > 0: 17 | if b & 1: 18 | res = res * a % mod 19 | a = a * a % mod 20 | b >>= 1 21 | return res 22 | 23 | @staticmethod 24 | def float_fast_pow(x: float, m: int) -> float: 25 | 26 | if m >= 0: 27 | res = 1 28 | while m > 0: 29 | if m & 1: 30 | res *= x 31 | x *= x 32 | m >>= 1 33 | return res 34 | m = -m 35 | res = 1 36 | while m > 0: 37 | if m & 1: 38 | res *= x 39 | x *= x 40 | m >>= 1 41 | return 1.0 / res 42 | 43 | 44 | class MatrixFastPowerFlatten: 45 | def __init__(self): 46 | return 47 | 48 | @staticmethod 49 | def matrix_pow_flatten(base, n, p, mod=10 ** 9 + 7): 50 | assert len(base) == n * n 51 | res = [0] * n * n 52 | ans = [0] * n * n 53 | for i in range(n): 54 | ans[i * n + i] = 1 55 | while p: 56 | if p & 1: 57 | for i in range(n): 58 | for j in range(n): 59 | cur = 0 60 | for k in range(n): 61 | cur += ans[i * n + k] * base[k * n + j] 62 | cur %= mod 63 | res[i * n + j] = cur 64 | for i in range(n): 65 | for j in range(n): 66 | ans[i * n + j] = res[i * n + j] 67 | for i in range(n): 68 | for j in range(n): 69 | cur = 0 70 | for k in range(n): 71 | cur += base[i * n + k] * base[k * n + j] 72 | cur %= mod 73 | res[i * n + j] = cur 74 | for i in range(n): 75 | for j in range(n): 76 | base[i * n + j] = res[i * n + j] 77 | p >>= 1 78 | return ans 79 | 80 | class MatrixFastPowerMin: 81 | def __init__(self): 82 | return 83 | 84 | @staticmethod 85 | def _matrix_mul(a, b): 86 | n = len(a) 87 | res = [[0] * n for _ in range(n)] 88 | for i in range(n): 89 | for j in range(n): 90 | res[i][j] = min(max(a[i][k], b[k][j]) for k in range(n)) 91 | return res 92 | 93 | def matrix_pow(self, base, p): 94 | n = len(base) 95 | ans = [[math.inf] * n 
for _ in range(n)] 96 | for i in range(n): 97 | ans[i][i] = 0 98 | while p: 99 | if p & 1: 100 | ans = self._matrix_mul(ans, base) 101 | base = self._matrix_mul(base, base) 102 | p >>= 1 103 | return ans 104 | 105 | 106 | class MatrixFastPower: 107 | def __init__(self): 108 | return 109 | 110 | @staticmethod 111 | def _matrix_mul(a, b, mod=10 ** 9 + 7): 112 | n = len(a) 113 | res = [[0] * n for _ in range(n)] 114 | for i in range(n): 115 | for j in range(n): 116 | res[i][j] = sum(a[i][k] * b[k][j] for k in range(n)) % mod 117 | return res 118 | 119 | def matrix_pow(self, base, p, mod=10 ** 9 + 7): 120 | n = len(base) 121 | ans = [[0] * n for _ in range(n)] 122 | for i in range(n): 123 | ans[i][i] = 1 124 | while p: 125 | if p & 1: 126 | ans = self._matrix_mul(ans, base, mod) 127 | base = self._matrix_mul(base, base, mod) 128 | p >>= 1 129 | return ans 130 | -------------------------------------------------------------------------------- /src/struct/trie_like/example.py: -------------------------------------------------------------------------------- 1 | import random 2 | import unittest 3 | from collections import Counter 4 | 5 | from src.struct.trie_like.problem import Solution 6 | from src.struct.trie_like.template import BinaryTrieXor, StringTriePrefix 7 | 8 | 9 | class TestGeneral(unittest.TestCase): 10 | 11 | def test_binary_trie(self): 12 | random.seed(2024) 13 | for mi in range(10): 14 | max_num = 10 ** mi 15 | num_cnt = 5 * 10 ** 3 16 | trie = BinaryTrieXor(max_num, num_cnt) 17 | nums = [] 18 | for i in range(num_cnt): 19 | x = random.randint(0, 1) 20 | if x == 0 and nums: 21 | j = random.randint(0, len(nums) - 1) 22 | c = min(random.randint(1, nums.count(nums[j])), 3) 23 | num = nums[j] 24 | for _ in range(c): 25 | nums.remove(num) 26 | assert trie.remove(num, c) 27 | else: 28 | num = random.randint(0, max_num) 29 | c = random.randint(1, 3) 30 | for _ in range(c): 31 | nums.append(num) 32 | assert trie.add(num, c) 33 | dct = Counter(nums) 34 | for num in 
dct: 35 | assert trie.count(num) == dct[num] 36 | assert trie.son_and_cnt[0] & trie.mask == len(nums) 37 | if nums: 38 | num = random.randint(0, max_num) 39 | lst = [num ^ x for x in nums] 40 | lst.sort(reverse=True) 41 | assert trie.get_maximum_xor(num) == lst[0] 42 | assert trie.get_minimum_xor(num) == lst[-1] 43 | res = [trie.get_kth_maximum_xor(num, rk + 1) for rk in range(len(lst))] 44 | assert res == lst 45 | 46 | y = random.randint(0, max_num) 47 | assert trie.get_cnt_smaller_xor(num, y) == sum(num ^ x <= y for x in nums) 48 | return 49 | 50 | def test_string_trie(self): 51 | for _ in range(10): 52 | word_cnt = 10 ** 4 53 | word_length = 10 54 | trie = StringTriePrefix(word_cnt * word_length, word_cnt) 55 | words = [] 56 | for i in range(word_cnt): 57 | word = "".join(chr(ord("a") + random.randint(0, 25)) for _ in range(random.randint(1, word_length))) 58 | words.append(word) 59 | trie.add([ord(w) - ord("a") for w in word]) 60 | for i in range(word_cnt): 61 | word = "".join(chr(ord("a") + random.randint(0, 25)) for _ in range(random.randint(1, word_length))) 62 | res = 0 63 | for s in words: 64 | for j in range(min(len(word), len(s))): 65 | if word[j] == s[j]: 66 | res += 1 67 | else: 68 | break 69 | assert res == trie.count([ord(w) - ord("a") for w in word]) 70 | return 71 | 72 | def test_solution_lc_421_1(self): # 411 ms 73 | random.seed(2024) 74 | nums = [random.randint(0, (1 << 31) - 1) for _ in range(2 * 10 ** 5)] 75 | Solution().lc_421_1(nums) 76 | nums = list(range(2 * 10 ** 5)) 77 | Solution().lc_421_1(nums) 78 | return 79 | 80 | def test_solution_lc_421_2(self): # 247 ms 81 | random.seed(2024) 82 | nums = [random.randint(0, (1 << 31) - 1) for _ in range(2 * 10 ** 5)] 83 | Solution().lc_421_2(nums) 84 | nums = list(range(2 * 10 ** 5)) 85 | Solution().lc_421_2(nums) 86 | return 87 | 88 | def test_solution_lc_421(self): 89 | random.seed(2024) 90 | nums = [random.randint(0, (1 << 31) - 1) for _ in range(2 * 10 ** 5)] 91 | assert 
Solution().lc_421_1(nums) == Solution().lc_421_2(nums) 92 | return 93 | 94 | 95 | if __name__ == '__main__': 96 | unittest.main() 97 | -------------------------------------------------------------------------------- /src/basis/range/template.py: -------------------------------------------------------------------------------- 1 | import math 2 | 3 | 4 | class Range: 5 | def __init__(self): 6 | return 7 | 8 | @staticmethod 9 | def range_merge_to_disjoint(lst): 10 | """range_merge_to_disjoint intervals into disjoint intervals""" 11 | lst.sort(key=lambda it: it[0]) 12 | ans = [] 13 | x, y = lst[0] 14 | for a, b in lst[1:]: 15 | if a <= y: # [1, 3] + [3, 4] = [1, 4] 16 | # if wanted range_merge_to_disjoint like [1, 2] + [3, 4] = [1, 4] can change to a <= y+1 or a < y 17 | y = y if y > b else b 18 | else: 19 | ans.append([x, y]) 20 | x, y = a, b 21 | ans.append([x, y]) 22 | return ans 23 | 24 | @staticmethod 25 | def minimum_range_cover(s, t, lst, inter=True): 26 | """calculate the minimum number of intervals in lst for coverage [s, t]""" 27 | if not lst: 28 | return -1 29 | # [1, 3] + [3, 4] = [1, 4] by set inter=True 30 | # [1, 2] + [3, 4] = [1, 4] by set inter=False 31 | lst.sort(key=lambda x: [x[0], -x[1]]) 32 | if lst[0][0] != s: 33 | return -1 34 | if lst[0][1] >= t: 35 | return 1 36 | ans = 1 37 | end = lst[0][1] 38 | cur = -1 39 | for a, b in lst[1:]: 40 | if end >= t: 41 | return ans 42 | # can be next disjoint set 43 | if (end >= a and inter) or (not inter and end >= a - 1): 44 | cur = cur if cur > b else b 45 | else: 46 | if cur <= end: 47 | return -1 48 | # add new farthest range 49 | ans += 1 50 | end = cur 51 | cur = -1 52 | if end >= t: 53 | return ans 54 | if (end >= a and inter) or (not inter and end >= a - 1): 55 | cur = cur if cur > b else b 56 | else: 57 | return -1 # which is impossible to coverage [s, t] 58 | if cur >= t: 59 | ans += 1 60 | return ans 61 | return -1 62 | 63 | @staticmethod 64 | def minimum_interval_coverage(clips, time: int, 
inter=True) -> int: 65 | """calculate the minimum number of intervals in clips for coverage [0, time]""" 66 | assert inter 67 | assert time >= 0 68 | if not clips: 69 | return -1 70 | if time == 0: 71 | if min(x for x, _ in clips) > 0: 72 | return -1 73 | return 1 74 | 75 | if inter: 76 | # inter=True is necessary 77 | post = [0] * time 78 | for a, b in clips: 79 | if a < time: 80 | post[a] = post[a] if post[a] > b else b 81 | if not post[0]: 82 | return -1 83 | 84 | ans = right = pre_end = 0 85 | for i in range(time): 86 | right = right if right > post[i] else post[i] 87 | if i == right: 88 | return -1 89 | if i == pre_end: 90 | ans += 1 91 | pre_end = right 92 | else: 93 | ans = -1 94 | return ans 95 | 96 | @staticmethod 97 | def maximum_disjoint_range(lst): 98 | """select the maximum disjoint intervals""" 99 | lst.sort(key=lambda x: x[1]) 100 | ans = 0 101 | end = -math.inf 102 | for a, b in lst: 103 | if a >= end: 104 | ans += 1 105 | end = b 106 | return ans 107 | 108 | @staticmethod 109 | def minimum_point_cover_range(lst): 110 | """find the minimum number of point such that every range in lst has at least one point""" 111 | if not lst: 112 | return 0 113 | lst.sort(key=lambda it: it[1]) 114 | ans = 1 115 | a, b = lst[0] 116 | for c, d in lst[1:]: 117 | if b < c: 118 | ans += 1 119 | b = d 120 | return ans 121 | --------------------------------------------------------------------------------