├── .gitignore
├── BFS.hs
├── BinarySearch.hs
├── CHANGELOG.md
├── Enumeration.hs
├── Geom.hs
├── IdempotentSemigroup.hs
├── Interval.hs
├── LICENSE
├── NumberTheory.hs
├── Perm.hs
├── Queue.hs
├── README.md
├── Scanner.hs
├── ScannerBS.hs
├── SegTree.hs
├── Sieve.hs
├── Slice.hs
├── SparseTable.hs
├── SqrtTree.hs
├── Stack.hs
├── Tree.hs
├── Trie.hs
├── UnionFind.hs
├── Util.hs
└── comprog-hs.cabal

/.gitignore:
--------------------------------------------------------------------------------
*~
TAGS
*.o
*.hi
dist-newstyle

--------------------------------------------------------------------------------
/BFS.hs:
--------------------------------------------------------------------------------
-- https://byorgey.wordpress.com/2021/10/14/competitive-programming-in-haskell-bfs-part-1/
-- https://byorgey.wordpress.com/2021/10/18/competitive-programming-in-haskell-bfs-part-2-alternative-apis/
-- https://byorgey.wordpress.com/2021/10/29/competitive-programming-in-haskell-bfs-part-3-implementation-via-hashmap/
-- https://byorgey.wordpress.com/2021/11/15/competitive-programming-in-haskell-enumeration/
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE RecordWildCards #-}
{-# LANGUAGE ScopedTypeVariables #-}

module BFS where

import Enumeration

import Control.Arrow ((>>>))
import Control.Monad
import Control.Monad.ST
import qualified Data.Array.IArray as IA
import Data.Array.ST
import Data.Array.Unboxed (UArray)
import qualified Data.Array.Unboxed as U
import Data.Array.Unsafe (unsafeFreeze)
import Data.Sequence (Seq (..), ViewL (..), (<|), (|>))
import qualified Data.Sequence as Seq

------------------------------------------------------------
-- Utilities
------------------------------------------------------------

-- | Reverse function application at the lowest precedence, so that
--   @x >$> f >>> g@ feeds @x@ into the whole @f >>> g@ pipeline.
infixl 0 >$>
(>$>) :: a -> (a -> b) -> b
(>$>) = flip ($)
{-# INLINE (>$>) #-}

-- | Repeatedly apply a monadic step function, starting from the given
--   value, until it returns 'Nothing'; then return the last value it
--   was given.
exhaustM :: Monad m => (a -> m (Maybe a)) -> a -> m a
exhaustM f = go
  where
    go a = do
      ma <- f a
      maybe (return a) go ma

------------------------------------------------------------
-- BFS
------------------------------------------------------------

-- | Result of a BFS: the level (number of steps from a start vertex)
--   and BFS-tree parent of each vertex. Both are 'Nothing' for vertices
--   that were never reached; start vertices have no parent.
data BFSResult v = BFSR {getLevel :: v -> Maybe Int, getParent :: v -> Maybe v}

type V = Int

-- | Mutable BFS state over vertices @0 .. n-1@. A level or parent of
--   @-1@ means "not reached yet".
data BFSState s = BS {level :: STUArray s V Int, parent :: STUArray s V V, queue :: Seq V}

-- | Fresh state for @n@ vertices: the start vertices @vs@ get level 0
--   and are queued; every other vertex is unreached.
initBFSState :: Int -> [Int] -> ST s (BFSState s)
initBFSState n vs = do
  l <- newArray (0, n - 1) (-1)
  p <- newArray (0, n - 1) (-1)

  forM_ vs $ \v -> writeArray l v 0
  return $ BS l p (Seq.fromList vs)

-- | Breadth-first search from the start vertices @vs@, following the
--   edges given by @next@. The search stops as soon as a vertex
--   satisfying @goal@ is dequeued (pass @const False@ to explore
--   everything reachable). The 'Enumeration' maps vertices to the
--   array indices @0 .. card - 1@ used internally.
bfs :: forall v. Enumeration v -> [v] -> (v -> [v]) -> (v -> Bool) -> BFSResult v
bfs Enumeration {..} vs next goal =
  toResult $ bfs' card (map locate vs) (map locate . next . select) (goal . select)
  where
    toResult :: (forall s. ST s (BFSState s)) -> BFSResult v
    toResult m = runST $ do
      st <- m
      -- The mutable arrays are not written again after the search
      -- finishes, so freezing them in place is safe.
      (level' :: UArray V Int) <- unsafeFreeze (level st)
      (parent' :: UArray V V) <- unsafeFreeze (parent st)
      return $
        BFSR
          ((\l -> guard (l /= -1) >> Just l) . (level' IA.!) . locate)
          ((\p -> guard (p /= -1) >> Just (select p)) . (parent' IA.!) . locate)

-- | Has vertex @v@ already been assigned a level?
visited :: BFSState s -> V -> ST s Bool
visited BS {..} v = (/= -1) <$> readArray level v
{-# INLINE visited #-}

-- | BFS over integer vertices @0 .. n-1@; see 'bfs'.
bfs' :: Int -> [V] -> (V -> [V]) -> (V -> Bool) -> ST s (BFSState s)
bfs' n vs next goal = do
  st <- initBFSState n vs
  exhaustM bfsStep st
  where
    -- Dequeue one vertex and visit its neighbours; Nothing means stop
    -- (queue empty or goal reached).
    bfsStep st@BS {..} = case Seq.viewl queue of
      EmptyL -> return Nothing
      v :< q'
        | goal v -> return Nothing
        | otherwise ->
            v >$> next
              >>> foldM (upd v) (st {queue = q'})
              >>> fmap Just

    -- Visit neighbour v of p. The visited check is done here, per
    -- neighbour, rather than by filtering the neighbour list up front:
    -- otherwise a vertex listed several times by 'next' would pass the
    -- filter each time and be enqueued more than once.
    upd p b@BS {..} v = do
      seen <- visited b v
      if seen
        then return b
        else do
          lp <- readArray level p
          writeArray level v (lp + 1)
          writeArray parent v p
          return $ b {queue = queue |> v}

--------------------------------------------------------------------------------
/BinarySearch.hs:
--------------------------------------------------------------------------------
-- https://byorgey.wordpress.com/2023/01/01/competitive-programming-in-haskell-better-binary-search/
-- https://byorgey.wordpress.com/2023/01/02/binary-search-over-floating-point-representations/
-- https://julesjacobs.com/notes/binarysearch/binarysearch.pdf

module BinarySearch where

import Data.Bits
import Data.Word (Word64)

-- | Generic search function. Takes a step function, a predicate, and
--   low and high values, and it finds values (l,r) right next to each
--   other where the predicate switches from False to True.
--
--   More specifically, for @search mid p l r@:
--
--   - PRECONDITIONS:
--       - @(p l)@ must be @False@ and @(p r)@ must be @True@!
--       - If using one of the binary search step functions, the
--         predicate @p@ must be monotonic, that is, intuitively, @p@
--         must switch from @False@ to @True@ exactly once. Formally, if
--         @x <= y@ then @p x <= p y@ (where @False <= True@).
--   - The mid function is called as @mid l r@ to find the next
--     index to search on the interval [l, r]. @mid l r@ must return
--     either @Nothing@ (stop searching) or @Just m@ for some @m@
--     strictly between @l@ and @r@.
--       - Use 'binary' to do binary search over the integers.
--       - Use @('continuous' eps)@ to do binary search over rational
--         or floating point values. (l,r) are returned such that
--         @r - l <= eps@.
--       - Use 'fwd' or 'bwd' to do linear search through a range.
--       - If you are feeling adventurous, use 'floating' to do
--         precise binary search over the bit representation of
--         'Double' values.
--   - Returns (l,r) such that:
--       - @not (p l) && p r@
--       - @mid l r == Nothing@
search :: (a -> a -> Maybe a) -> (a -> Bool) -> a -> a -> (a, a)
search mid p = go
  where
    go l r = case mid l r of
      Nothing -> (l, r)
      Just m
        | p m -> go l m
        | otherwise -> go m r

-- | Step function for binary search over an integral type. Stops
--   when @r - l <= 1@; otherwise returns their midpoint.
binary :: Integral a => a -> a -> Maybe a
binary l r
  | r - l > 1 = Just (l + (r - l) `div` 2)
  | otherwise = Nothing

-- | Step function for continuous binary search. Stops once
--   @r - l <= eps@; otherwise returns their midpoint.
continuous :: (Fractional a, Ord a) => a -> a -> a -> Maybe a
continuous eps l r
  | r - l > eps = Just (l + (r - l) / 2)
  | otherwise = Nothing

-- | Step function for forward linear search. Stops when
--   @r - l <= 1@; otherwise returns @l + 1@.
fwd :: (Num a, Ord a) => a -> a -> Maybe a
fwd l r
  | r - l > 1 = Just (l + 1)
  | otherwise = Nothing

-- | Step function for backward linear search. Stops when
--   @r - l <= 1@; otherwise returns @r - 1@.
bwd :: (Num a, Ord a) => a -> a -> Maybe a
bwd l r
  | r - l > 1 = Just (r - 1)
  | otherwise = Nothing

------------------------------------------------------------
-- Binary search on floating-point representations

-- A lot of blood, sweat, and tears went into these functions. Are
-- they even correct? Was it worth it? Who knows! See:
--
-- https://byorgey.wordpress.com/2023/01/01/competitive-programming-in-haskell-better-binary-search/#comment-40882
-- https://byorgey.wordpress.com/2023/01/02/binary-search-over-floating-point-representations/
-- https://web.archive.org/web/20220326204603/http://stereopsis.com/radix.html

-- | Step function for precise binary search over the bit
--   representation of @Double@ values.
floating :: Double -> Double -> Maybe Double
floating l r = b2f <$> binary (f2b l) (f2b r)

-- | A monotonic conversion from 'Double' to 'Word64'. That is,
--   @x < y@ iff @f2b x < f2b y@. 'b2f' is its inverse.
f2b :: Double -> Word64
f2b 0 = 0x7fffffffffffffff
f2b x =
  (if m < 0 then 0 else bit 63)
    `xor` flipNeg (eBits `xor` mBits)
  where
    (m, e) = decodeFloat x
    eBits = fromIntegral (e + d + bias) `shiftL` d
    mBits = fromIntegral (abs m) `clearBit` d

    d = floatDigits x - 1
    bias = 1023
    flipNeg
      | m < 0 = (`clearBit` 63) . complement
      | otherwise = id

prop_f2b_monotonic :: Double -> Double -> Bool
prop_f2b_monotonic x y = (x < y) == (f2b x < f2b y)

-- | The left inverse of 'f2b', that is, for all @x :: Double@,
--   @b2f (f2b x) == x@. Note @f2b (b2f x) == x@ does not strictly hold
--   since not every @Word64@ value corresponds to a valid @Double@
--   value.
b2f :: Word64 -> Double
b2f 0x7fffffffffffffff = 0
b2f w = encodeFloat m (fromIntegral e - bias - d)
  where
    s = testBit w 63
    w' = (if s then id else complement) w

    d = floatDigits (1 :: Double) - 1
    bias = 1023
    m = (if s then id else negate) (fromIntegral ((w' .&. ((1 `shiftL` d) - 1)) `setBit` d))
    e = clearBit w' 63 `shiftR` d

prop_b2f_f2b :: Double -> Bool
prop_b2f_f2b x = b2f (f2b x) == x

-- Some Word64 values correspond to +/-Infinity or NaN. For most
-- others, f2b is inverse to b2f; for a few that represent very tiny
-- floating-point values, the Word64 resulting from a round trip may
-- differ by 1.
prop_f2b_b2f :: Word64 -> Bool
prop_f2b_b2f w = isInfinite x || isNaN x || dist (f2b x) w <= 1
  where
    x = b2f w
    dist x y
      | x < y = y - x
      | otherwise = x - y

-- Given two distinct Double values whose Word64 representations are
-- not adjacent, if we take the midpoint of their corresponding Word64
-- values, we get another Word64 that represents a floating point
-- number strictly in between the original two. Adjacent Doubles
-- (representations differing by 1) have nothing strictly between
-- them, so they are excluded -- exactly the case where 'binary', and
-- hence 'floating', stops.
prop_f2b_mid_monotonic :: Double -> Double -> Bool
prop_f2b_mid_monotonic x y = x == y || r - l <= 1 || (x' < z && z < y')
  where
    x' = min x y
    y' = max x y
    l = f2b x'
    r = f2b y'
    m = l + (r - l) `div` 2
    z = b2f m

--------------------------------------------------------------------------------
/CHANGELOG.md:
--------------------------------------------------------------------------------
# Revision history for comprog-hs

## 0.1.0.0 -- YYYY-mm-dd

* First version. Released on an unsuspecting world.
--------------------------------------------------------------------------------
/Enumeration.hs:
--------------------------------------------------------------------------------
-- https://byorgey.wordpress.com/2021/11/15/competitive-programming-in-haskell-enumeration/

{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TypeApplications #-}

module Enumeration where

import qualified Data.List as L

import Data.Hashable
import qualified Data.Array as A
import qualified Data.HashMap.Strict as HM

data Enumeration a = Enumeration
  { card :: !Int
  , select :: Int -> a
  , locate :: a -> Int
  }

-- | Map a pair of inverse functions over an invertible enumeration of
--   @a@ values to turn it into an invertible enumeration of @b@
--   values. Because invertible enumerations contain a /bijection/ to
--   the natural numbers, we really do need both directions of a
--   bijection between @a@ and @b@ in order to map. This is why
--   'Enumeration' cannot be an instance of 'Functor'.
mapE :: (a -> b) -> (b -> a) -> Enumeration a -> Enumeration b
mapE f g (Enumeration c s l) = Enumeration c (f . s) (l . g)

-- | List the elements of an enumeration in order. Inverse of
--   'listE'.
enumerate :: Enumeration a -> [a]
enumerate e = map (select e) [0 .. card e-1]

-- | The empty enumeration, with cardinality zero and no elements.
voidE :: Enumeration a
voidE = Enumeration 0 (error "select void") (error "locate void")

-- | The unit enumeration, with a single value of @()@ at index 0.
unitE :: Enumeration ()
unitE = singletonE ()

-- | An enumeration of a single given element at index 0.
singletonE :: a -> Enumeration a
singletonE a = Enumeration 1 (const a) (const 0)

-- | A finite prefix of the natural numbers.
finiteE :: Int -> Enumeration Int
finiteE n = Enumeration n id id

-- | Construct an enumeration from the elements of a finite list.
--   The elements of the list must all be distinct. To turn an
--   enumeration back into a list, use 'enumerate'.
listE :: forall a. (Hashable a, Eq a) => [a] -> Enumeration a
listE as = Enumeration n (toA A.!) (fromA HM.!)
  where
    n = length as
    toA :: A.Array Int a
    toA = A.listArray (0,n-1) as

    fromA :: HM.HashMap a Int
    fromA = HM.fromList (zip as [0 :: Int ..])

-- | Enumerate all the values of a bounded 'Enum' instance.
boundedEnum :: forall a. (Enum a, Bounded a) => Enumeration a
boundedEnum = Enumeration
  { card = hi - lo + 1
  , select = toEnum . (+lo)
  , locate = subtract lo . fromEnum
  }
  where
    lo, hi :: Int
    lo = fromIntegral (fromEnum (minBound @a))
    hi = fromIntegral (fromEnum (maxBound @a))

-- | Sum, /i.e./ disjoint union, of two enumerations. All the values
--   of the first are enumerated before the values of the second.
(>+<) :: Enumeration a -> Enumeration b -> Enumeration (Either a b)
a >+< b = Enumeration
  { card = card a + card b
  , select = \k -> if k < card a then Left (select a k) else Right (select b (k - card a))
  , locate = either (locate a) ((+card a) . locate b)
  }

-- | Cartesian product of enumerations, with a lexicographic ordering.
(>*<) :: Enumeration a -> Enumeration b -> Enumeration (a,b)
a >*< b = Enumeration
  { card = card a * card b
  , select = \k -> let (i,j) = k `divMod` card b in (select a i, select b j)
  , locate = \(x,y) -> card b * locate a x + locate b y
  }

-- | Take a finite prefix from the beginning of an enumeration. @takeE
--   k e@ always yields the empty enumeration for \(k \leq 0\), and
--   results in @e@ whenever @k@ is greater than or equal to the
--   cardinality of the enumeration. Otherwise @takeE k e@ has
--   cardinality @k@ and matches @e@ from @0@ to @k-1@.
takeE :: Int -> Enumeration a -> Enumeration a
takeE k e
  | k <= 0 = voidE
  | k >= card e = e
  | otherwise = Enumeration k (select e) (locate e)

-- | Drop some elements from the beginning of an enumeration. @dropE k
--   e@ yields @e@ unchanged if \(k \leq 0\), and results in the empty
--   enumeration whenever @k@ is greater than or equal to the
--   cardinality of @e@.
dropE :: Int -> Enumeration a -> Enumeration a
dropE k e
  | k <= 0 = e
  | k >= card e = voidE
  | otherwise = Enumeration
      { card = card e - k
      , select = select e . (+k)
      , locate = subtract k . locate e
      }

-- | Zip two enumerations in parallel, producing the pair of
--   elements at each index. The resulting enumeration is truncated
--   to the cardinality of the smaller of the two arguments. Its
--   'locate' looks only at the first component of the pair.
zipE :: Enumeration a -> Enumeration b -> Enumeration (a,b)
zipE ea eb = Enumeration
  { card = min (card ea) (card eb)
  , select = \k -> (select ea k, select eb k)
  , locate = locate ea . fst
  }

--------------------------------------------------------------------------------
/Geom.hs:
--------------------------------------------------------------------------------
-- https://byorgey.wordpress.com/2020/06/24/competitive-programming-in-haskell-vectors-and-2d-geometry/
{-# LANGUAGE DeriveFunctor #-}
{-# LANGUAGE GeneralizedNewtypeDeriving #-}
{-# LANGUAGE ViewPatterns #-}

module Geom where

import Data.Function (on)
import Data.List (maximumBy, minimumBy, nub)
import Data.Maybe (mapMaybe)
import Data.Ord (compare, comparing)
import Data.Ratio

------------------------------------------------------------
-- 2D points and vectors

data V2 s = V2 {getX :: !s, getY :: !s} deriving (Eq, Ord, Show, Functor)
type V2D = V2 Double

type P2 s = V2 s
type P2D = P2 Double

instance Foldable V2 where
  foldMap f (V2 x y) = f x <> f y

zero :: Num s => V2 s
zero = V2 0 0

-- | Adding and subtracting vectors.
(^+^), (^-^) :: Num s => V2 s -> V2 s -> V2 s
V2 x1 y1 ^+^ V2 x2 y2 = V2 (x1 + x2) (y1 + y2)
V2 x1 y1 ^-^ V2 x2 y2 = V2 (x1 - x2) (y1 - y2)

-- | Scalar multiple of a vector.
(*^) :: Num s => s -> V2 s -> V2 s
(*^) k = fmap (k *)

-- | Divide a vector by a scalar.
(^/) :: Fractional s => V2 s -> s -> V2 s
v ^/ k = (1 / k) *^ v

------------------------------------------------------------
-- Utilities

-- | These combinators allow us to write e.g. @v2 int@ or @v2 double@
--   to get a @Scanner (V2 s)@.
v2, p2 :: Applicative f => f s -> f (V2 s)
v2 s = V2 <$> s <*> s
p2 = v2

newtype ByX s = ByX {unByX :: V2 s} deriving (Eq, Show, Functor)
newtype ByY s = ByY {unByY :: V2 s} deriving (Eq, Show, Functor)

instance Ord s => Ord (ByX s) where
  compare = compare `on` (getX . unByX)

instance Ord s => Ord (ByY s) where
  compare = compare `on` (getY . unByY)

-- | Manhattan distance.
manhattan :: Num s => P2 s -> P2 s -> s
manhattan (V2 x1 y1) (V2 x2 y2) = abs (x1 - x2) + abs (y1 - y2)

------------------------------------------------------------
-- Angles

newtype Angle s = A s -- angle (radians)
  deriving (Show, Eq, Ord, Num, Fractional, Floating)

fromDeg :: Floating s => s -> Angle s
fromDeg d = A (d * pi / 180)

fromRad :: s -> Angle s
fromRad = A

toDeg :: Floating s => Angle s -> s
toDeg (A r) = r * 180 / pi

toRad :: Angle s -> s
toRad (A r) = r

-- | Direction of a vector, measured from the positive x-axis with
--   'atan2'.
dir :: V2D -> Angle Double
dir (V2 x y) = A $ atan2 y x

-- | Construct a vector in polar coordinates.
fromPolar :: Floating s => s -> Angle s -> V2 s
fromPolar r θ = rot θ (V2 r 0)

-- | Rotate a vector counterclockwise by a given angle.
rot :: Floating s => Angle s -> V2 s -> V2 s
rot (A θ) (V2 x y) = V2 (cos θ * x - sin θ * y) (sin θ * x + cos θ * y)

-- | Rotate a vector 90 degrees counterclockwise.
perp :: Num s => V2 s -> V2 s
perp (V2 x y) = V2 (-y) x

------------------------------------------------------------
-- Dot product

-- | Dot product of two vectors. u·v = |u||v| cos θ (where θ is the
--   (unsigned) angle between u and v). So u·v is zero iff the vectors
--   are perpendicular.
dot :: Num s => V2 s -> V2 s -> s
dot (V2 x1 y1) (V2 x2 y2) = x1 * x2 + y1 * y2

-- | @dotP p1 p2 p3@ computes the dot product of the vectors from p1
--   to p2 and from p1 to p3.
dotP :: Num s => P2 s -> P2 s -> P2 s -> s
dotP p1 p2 p3 = dot (p2 ^-^ p1) (p3 ^-^ p1)

-- | Squared norm of a vector, /i.e./ square of its length, /i.e./ dot
--   product with itself.
normSq :: Num s => V2 s -> s
normSq v = dot v v

-- | Norm, /i.e./ length of a vector.
norm :: Floating s => V2 s -> s
norm = sqrt . normSq

-- | Scale a vector to unit length (the zero vector has no direction).
normalize :: Floating s => V2 s -> V2 s
normalize v = v ^/ norm v

-- | @angleP p1 p2 p3@ computes the (unsigned) angle of p1-p2-p3
--   (/i.e./ the angle at p2 made by rays to p1 and p3). The result
--   will always be in the range \([0, \pi]\).
angleP :: Floating s => P2 s -> P2 s -> P2 s -> Angle s
angleP x y z = A $ acos (dot a b / (norm a * norm b))
  where
    a = x ^-^ y
    b = z ^-^ y

-- | @signedAngleP p1 p2 p3@ computes the /signed/ angle p1-p2-p3
--   (/i.e./ the angle at p2 made by rays to p1 and p3), in the range
--   \([-\pi, \pi]\). Positive iff the ray from p2 to p3 is
--   counterclockwise of the ray from p2 to p1.
signedAngleP :: (Floating s, Ord s) => P2 s -> P2 s -> P2 s -> Angle s
signedAngleP x y z = case turnP x y z of
  CCW -> -angleP x y z
  _ -> angleP x y z

------------------------------------------------------------
-- Cross product

-- | 2D cross product of two vectors. Gives the signed area of their
--   parallelogram (positive iff the second is counterclockwise of the
--   first). [Geometric algebra tells us that this is really the
--   coefficient of the bivector resulting from the outer product of
--   the two vectors.]
--
--   Note this works even for integral scalar types.
cross :: Num s => V2 s -> V2 s -> s
cross (V2 ux uy) (V2 vx vy) = ux * vy - vx * uy

-- | A version of cross product specialized to three points describing
--   the endpoints of the vectors. The first argument is the shared
--   tail of the vectors, and the second and third arguments are the
--   endpoints of the vectors.
crossP :: Num s => P2 s -> P2 s -> P2 s -> s
crossP p1 p2 p3 = cross (p2 ^-^ p1) (p3 ^-^ p1)

-- | The signed area of a triangle with given vertices can be computed
--   as half the cross product of two of the edges.
--
--   Note that this requires 'Fractional' because of the division by
--   two. If you want to stick with integral scalars, you can just
--   use 'crossP' to get twice the signed area.
signedTriArea :: Fractional s => P2 s -> P2 s -> P2 s -> s
signedTriArea p1 p2 p3 = crossP p1 p2 p3 / 2

-- | The (nonnegative) area of the triangle with the given vertices.
triArea :: Fractional s => P2 s -> P2 s -> P2 s -> s
triArea p1 p2 p3 = abs (signedTriArea p1 p2 p3)

-- | The signed area of the polygon with the given vertices, via the
--   "shoelace formula". Positive iff the points are given in
--   counterclockwise order.
signedPolyArea :: Fractional s => [P2 s] -> s
signedPolyArea pts = sum $ zipWith (signedTriArea zero) pts (tail pts ++ [head pts])

-- | The (nonnegative) area of the polygon with the given vertices.
polyArea :: Fractional s => [P2 s] -> s
polyArea = abs . signedPolyArea

-- | Direction of a turn: counterclockwise (left), clockwise (right),
--   or parallel (/i.e./ 0 or 180 degree turn).
data Turn = CCW | Par | CW

-- | Cross product can also be used to compute the direction of a
--   turn. If you are travelling from p1 to p2, @turnP p1 p2 p3@ says
--   whether you have to make a left (ccw) or right (cw) turn to
--   continue on to p3 (or if it is parallel). Equivalently, if you
--   are standing at p1 looking towards p2, and imagine the line
--   through p1 and p2 dividing the plane in two, is p3 on the right
--   side, left side, or on the line?
turnP :: (Num s, Ord s) => P2 s -> P2 s -> P2 s -> Turn
turnP x y z
  | s > 0 = CCW
  | s == 0 = Par
  | otherwise = CW
  where
    s = signum (crossP x y z)

------------------------------------------------------------
-- 2D Lines

-- | A line with direction vector @v@ and offset @c@: the points @p@ on
--   the line are exactly those with @cross v p == c@ (see 'side').
data L2 s = L2 {getDirection :: !(V2 s), getOffset :: !s}
type L2D = L2 Double

-- | The line with equation @a*x + b*y = c@.
lineFromEquation :: Num s => s -> s -> s -> L2 s
lineFromEquation a b c = L2 (V2 b (-a)) c

-- | The line through two points, directed from the first to the second.
lineFromPoints :: Num s => P2 s -> P2 s -> L2 s
lineFromPoints p q = L2 v (v `cross` p)
  where
    v = q ^-^ p

-- | Exact slope of a line; 'Nothing' for vertical lines.
slope :: (Integral n, Eq n) => L2 n -> Maybe (Ratio n)
slope (getDirection -> V2 x y) = case x of
  0 -> Nothing
  _ -> Just (y % x)

-- | Slope of a line over a fractional type; 'Nothing' for vertical lines.
dslope :: (Fractional s, Eq s) => L2 s -> Maybe s
dslope (getDirection -> V2 x y) = case x of
  0 -> Nothing
  _ -> Just (y / x)

-- | Signed (scaled) distance of a point from a line: positive on the
--   left of the line's direction, negative on the right, zero on it.
side :: Num s => L2 s -> P2 s -> s
side (L2 v c) p = cross v p - c

leftOf :: (Num s, Ord s) => L2 s -> P2 s -> Bool
leftOf l p = side l p > 0

rightOf :: (Num s, Ord s) => L2 s -> P2 s -> Bool
rightOf l p = side l p < 0

-- | The vector from a point to its orthogonal projection on a line.
toProjection :: Fractional s => L2 s -> P2 s -> V2 s
toProjection l@(L2 v _) p = (-side l p / normSq v) *^ perp v

project :: Fractional s => L2 s -> P2 s -> P2 s
project l p = p ^+^ toProjection l p

reflectAcross :: Fractional s => L2 s -> P2 s -> P2 s
reflectAcross l p = p ^+^ (2 *^ toProjection l p)

-- | Intersection point of two lines; 'Nothing' if they are parallel.
lineIntersection :: (Fractional s, Eq s) => L2 s -> L2 s -> Maybe (P2 s)
lineIntersection (L2 v1 c1) (L2 v2 c2)
  | cross v1 v2 == 0 = Nothing
  | otherwise = Just $ ((c1 *^ v2) ^-^ (c2 *^ v1)) ^/ cross v1 v2

------------------------------------------------------------
-- Segments

data Seg s = Seg (P2 s) (P2 s) deriving (Eq, Show)

segLine :: Num s => Seg s -> L2 s
segLine (Seg p q) = lineFromPoints p q

-- Test whether two segments intersect.
-- http://www.geeksforgeeks.org/check-if-two-given-line-segments-intersect/
segsIntersect :: (Ord s, Num s) => Seg s -> Seg s -> Bool
segsIntersect (Seg p1 q1) (Seg p2 q2)
  | o1 /= o2 && o3 /= o4 = True
  | o1 == 0 && onSegment p1 p2 q1 = True
  | o2 == 0 && onSegment p1 q2 q1 = True
  | o3 == 0 && onSegment p2 p1 q2 = True
  | o4 == 0 && onSegment p2 q1 q2 = True
  | otherwise = False
  where
    o1 = signum $ crossP p1 q1 p2
    o2 = signum $ crossP p1 q1 q2
    o3 = signum $ crossP p2 q2 p1
    o4 = signum $ crossP p2 q2 q1

-- | Given three /collinear/ points p, q, r, check whether q lies on
--   the segment pr (a bounding-box test; collinearity is not checked).
onSegment :: Ord s => P2 s -> P2 s -> P2 s -> Bool
onSegment (V2 px py) (V2 qx qy) (V2 rx ry) =
  min px rx <= qx
    && qx <= max px rx
    && min py ry <= qy
    && qy <= max py ry

-- | Intersection point of two segments, if they intersect at a single
--   point ('Nothing' for disjoint or overlapping collinear segments).
segsIntersection :: (Ord s, Fractional s) => Seg s -> Seg s -> Maybe (P2 s)
segsIntersection s1 s2
  | segsIntersect s1 s2 = lineIntersection (segLine s1) (segLine s2)
  | otherwise = Nothing

------------------------------------------------------------
-- Rectangles

data Rect s = Rect {lowerLeft :: P2 s, dims :: V2 s} deriving (Eq, Show)

rectFromCorners :: (Num s, Ord s) => P2 s -> P2 s -> Rect s
rectFromCorners (V2 x1 y1) (V2 x2 y2) =
  Rect (V2 (min x1 x2) (min y1 y2)) (V2 (abs (x2 - x1)) (abs (y2 - y1)))

-- | Is the point inside the rectangle (boundary included)?
pointInRect :: (Ord s, Num s) => P2 s -> Rect s -> Bool
pointInRect (V2 px py) (Rect (V2 llx lly) (V2 dx dy)) =
  and
    [ px >= llx
    , px <= llx + dx
    , py >= lly
    , py <= lly + dy
    ]

-- | The four edges of a rectangle: left, top, right, bottom.
rectSegs :: Num s => Rect s -> [Seg s]
rectSegs (Rect ll d@(V2 dx dy)) = [Seg ll ul, Seg ul ur, Seg ur lr, Seg lr ll]
  where
    ul = ll ^+^ V2 0 dy
    ur = ll ^+^ d
    lr = ll ^+^ V2 dx 0

-- | Clip a segment to a rectangle: the part of the segment inside the
--   rectangle, or 'Nothing' if there is none.
rectSegIntersection :: (Fractional s, Ord s) => Rect s -> Seg s -> Maybe (Seg s)
rectSegIntersection r s@(Seg t u)
  | pointInRect t r && pointInRect u r = Just s
  | otherwise = case nub (mapMaybe (segsIntersection s) (rectSegs r)) of
      [] -> Nothing
      [p, q] -> Just $ Seg p q
      [p]
        | pointInRect t r -> Just (Seg p t)
        | pointInRect u r -> Just (Seg p u)
        | otherwise -> Nothing
      -- With inexact arithmetic, a segment passing through a corner
      -- can produce several nearly-equal intersection points that 'nub'
      -- does not merge. Keep the two extreme points along s instead of
      -- failing with a pattern-match error.
      ps -> Just $ Seg (minimumBy alongS ps) (maximumBy alongS ps)
  where
    -- Order points by their projection onto the direction of s.
    alongS = comparing (dot (u ^-^ t) . (^-^ t))

------------------------------------------------------------
-- Circles

data Circle s = Circle {center :: P2 s, radius :: s} deriving (Eq, Show)

-- | Is the point inside the circle (boundary included)?
pointInCircle :: (Ord s, Num s) => P2 s -> Circle s -> Bool
pointInCircle p (Circle c r) = normSq (p ^-^ c) <= r * r

-- | Do the (closed) rectangle and circle overlap?
rectCircleIntersection :: (Ord s, Num s) => Rect s -> Circle s -> Bool
rectCircleIntersection (Rect ll@(V2 llx lly) d@(V2 dx dy)) (Circle c r) =
  or
    [ pointInRect c (Rect (V2 (llx - r) lly) (V2 (dx + 2 * r) dy))
    , pointInRect c (Rect (V2 llx (lly - r)) (V2 dx (dy + 2 * r)))
    , pointInCircle c (Circle ll r)
    , pointInCircle c (Circle (ll ^+^ V2 dx 0) r)
    , pointInCircle c (Circle (ll ^+^ V2 0 dy) r)
    , pointInCircle c (Circle (ll ^+^ d) r)
    ]

--------------------------------------------------------------------------------
/IdempotentSemigroup.hs:
--------------------------------------------------------------------------------
module IdempotentSemigroup where

import Data.Bits
import Data.Semigroup

-- | An idempotent semigroup is one where the binary operation
--   satisfies the law @x <> x = x@ for all @x@.
class Semigroup m => IdempotentSemigroup m

instance Ord a => IdempotentSemigroup (Min a)
instance Ord a => IdempotentSemigroup (Max a)
instance IdempotentSemigroup All
instance IdempotentSemigroup Any
instance IdempotentSemigroup Ordering
instance IdempotentSemigroup ()
instance IdempotentSemigroup (First a)
instance IdempotentSemigroup (Last a)
instance Bits a => IdempotentSemigroup (And a)
instance Bits a => IdempotentSemigroup (Ior a)

--------------------------------------------------------------------------------
/Interval.hs:
--------------------------------------------------------------------------------
module Interval where

import Prelude hiding (drop, length, take)

-- | Half-open integer intervals @[lo, hi)@: 'length' is @hi - lo@ and
--   'range' lists @lo .. hi - 1@. An interval is empty when
--   @lo >= hi@; 'mkI' turns @l == h@ into the canonical empty
--   interval @I 0 0@.
data I = I {lo :: !Int, hi :: !Int}
  deriving (Eq, Ord, Show)

-- | NOTE(review): under the half-open reading used by 'length' and
--   'range', @point x@ has length 0, so 'isEmpty' holds for it;
--   confirm whether @I x (x + 1)@ was intended.
point :: Int -> I
point x = I x x

-- | Smart constructor: @mkI l h@ is @I l h@, except that @l == h@
--   yields the canonical empty interval @I 0 0@.
mkI :: Int -> Int -> I
mkI l h
  | l == h = I 0 0
  | otherwise = I l h

-- | @i1 ∪ i2@ is the smallest interval covering both arguments (their
--   hull, not their set union). Empty arguments are ignored, so an
--   empty @I 0 0@ does not stretch the hull towards 0.
--   @i1 ∩ i2@ is the intersection, which may come out empty
--   (possibly with @lo > hi@).
(∪), (∩) :: I -> I -> I
i1 ∪ i2
  | isEmpty i1 = i2
  | isEmpty i2 = i1
  | otherwise = I (min (lo i1) (lo i2)) (max (hi i1) (hi i2))
I l1 h1 ∩ I l2 h2 = I (max l1 l2) (min h1 h2)

-- | Do the two intervals share at least one integer?
intersects :: I -> I -> Bool
intersects i1 i2 = not (isEmpty (i1 ∩ i2))

-- | An interval is empty when it contains no integers, i.e.
--   @lo >= hi@. This includes the canonical empty interval @I 0 0@
--   produced by 'mkI'.
isEmpty :: I -> Bool
isEmpty (I l h) = l >= h

-- | Containment. An empty interval is contained in every interval.
(⊆) :: I -> I -> Bool
i1 ⊆ i2 = i1 ∪ i2 == i2

-- | Drop the first element.
uncons :: I -> I
uncons (I l h) = mkI (l + 1) h

-- | Every way to split an interval into a prefix and a suffix.
splits :: I -> [(I, I)]
splits (I l h) = [(mkI l m, mkI m h) | m <- [l .. h]]

-- | All subintervals: the empty interval, then every nonempty
--   @I l' h'@ with @l <= l' < h' <= h@.
subs :: I -> [I]
subs (I l h) = mkI 0 0 : [mkI l' h' | l' <- [l .. h - 1], h' <- [l' + 1 .. h]]

-- | @hi - lo@, i.e. the number of elements of a nonempty interval.
length :: I -> Int
length (I l h) = h - l

-- | The first @k@ elements.
take :: Int -> I -> I
take k (I l h) = mkI l (min (l + k) h)

-- | Everything after the first @k@ elements.
drop :: Int -> I -> I
drop k (I l h) = mkI (min (l + k) h) h

-- | The integers in the interval, in increasing order.
range :: I -> [Int]
range (I l h) = [l .. h - 1]

--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
Copyright (c) 2021, Brent Yorgey

All rights reserved.

Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions are met:

    * Redistributions of source code must retain the above copyright
      notice, this list of conditions and the following disclaimer.

    * Redistributions in binary form must reproduce the above
      copyright notice, this list of conditions and the following
      disclaimer in the documentation and/or other materials provided
      with the distribution.

    * Neither the name of Brent Yorgey nor the names of other
      contributors may be used to endorse or promote products derived
      from this software without specific prior written permission.

THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
"AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
31 | -------------------------------------------------------------------------------- /NumberTheory.hs: -------------------------------------------------------------------------------- 1 | -- https://byorgey.wordpress.com/2020/02/07/competitive-programming-in-haskell-primes-and-factoring/ 2 | -- https://byorgey.wordpress.com/2020/02/15/competitive-programming-in-haskell-modular-arithmetic-part-1/ 3 | -- https://byorgey.wordpress.com/2020/03/03/competitive-programming-in-haskell-modular-arithmetic-part-2/ 4 | 5 | module NumberTheory where 6 | 7 | import qualified Data.Foldable as F 8 | import Data.Map (Map) 9 | import qualified Data.Map as M 10 | 11 | import Control.Arrow 12 | import Data.Bits 13 | import Data.List (group, sort) 14 | import Data.Maybe (fromJust) 15 | 16 | ------------------------------------------------------------ 17 | -- Modular exponentiation 18 | 19 | modexp :: Integer -> Integer -> Integer -> Integer 20 | modexp b e m = go e 21 | where 22 | go 0 = 1 23 | go e 24 | | e `testBit` 0 = (b * r * r) `mod` m 25 | | otherwise = (r * r) `mod` m 26 | where 27 | r = go (e `shiftR` 1) 28 | 29 | ------------------------------------------------------------ 30 | -- (Extended) Euclidean algorithm 31 | 32 | -- egcd a b = (g,x,y) 33 | -- g is the gcd of a and b, and ax + by = g 34 | egcd :: Integer -> Integer -> (Integer, Integer, Integer) 35 | egcd a 0 36 | | a < 0 = (-a, -1, 0) 37 | | otherwise = (a, 1, 0) 38 | egcd a b = (g, y, x - (a `div` b) * y) 39 | where 40 | (g, x, y) = egcd b (a `mod` b) 41 | 42 | -- g = bx + (a mod b)y 43 | -- = bx + (a - b(a/b))y 44 | -- = ay + b(x - (a/b)y) 45 | 46 | -- inverse p a is the multiplicative inverse of a mod p 47 | inverse :: Integer -> Integer -> Integer 48 | inverse p a = y `mod` p 49 | where 50 | (_, _, y) = egcd p a 51 | 52 | ------------------------------------------------------------ 53 | -- Primes, factoring, and divisors 54 | 55 | -------------------------------------------------- 56 | -- Miller-Rabin primality 
testing 57 | 58 | -- Need to upgrade to Baille-PSW? See ~/learning/primality/baille-PSW.py 59 | 60 | smallPrimes = take 20 primes 61 | 62 | -- https://en.wikipedia.org/wiki/Miller%E2%80%93Rabin_primality_test#Testing_against_small_sets_of_bases 63 | mrPrimes n 64 | | n < 2047 = [2] 65 | | n < 1373653 = [2, 3] 66 | | n < 9080191 = [31, 73] 67 | | n < 25326001 = [2, 3, 5] 68 | | n < 3215031751 = [2, 3, 5, 7] 69 | | n < 4759123141 = [2, 7, 61] 70 | | n < 1122004669633 = [2, 13, 23, 1662803] 71 | | n < 2152302898747 = [2, 3, 5, 7, 11] 72 | | n < 3474749660383 = [2, 3, 5, 7, 11, 13] 73 | | n < 341550071728321 = [2, 3, 5, 7, 11, 13, 17] 74 | | n < 3825123056546413051 = [2, 3, 5, 7, 11, 13, 17, 19, 23] 75 | | n < 318665857834031151167461 = [2, 3, 5, 7, 11, 13, 17, 19, 23, 29, 31, 37] 76 | | n < 3317044064679887385961981 = [2, 3, 5, 7, 11, 13, 17, 19, 23, 29, 31, 37, 41] 77 | 78 | -- With these values of a, guaranteed to work up to 3*10^24 (see https://pastebin.com/6XEFRPaZ) 79 | isPrime :: Integer -> Bool 80 | isPrime n 81 | | n == 1 = False 82 | | n `elem` smallPrimes = True 83 | | any ((== 0) . (n `mod`)) smallPrimes = False 84 | | otherwise = spps n (mrPrimes n) 85 | 86 | isPrimeTrialDiv :: Integer -> Bool 87 | isPrimeTrialDiv 1 = False 88 | isPrimeTrialDiv n = all ((/=0) . (n `mod`)) (takeWhile (\p -> p*p <= n) primes) 89 | 90 | -- spp n a tests whether n is a strong probable prime to base a. 91 | spp :: Integer -> Integer -> Bool 92 | spp n a = spp' n s d a 93 | where 94 | (s, d) = decompose (n - 1) 95 | 96 | spp' n s d a = (ad == 1) || (n - 1) `elem` as 97 | where 98 | ad = modexp a d n 99 | as = take s (iterate ((`mod` n) . (^ 2)) ad) 100 | 101 | -- spps n as tests whether n is a strong probable prime to all the 102 | -- given bases. 103 | spps :: Integer -> [Integer] -> Bool 104 | spps n = all (spp' n s d) 105 | where 106 | (s, d) = decompose (n - 1) 107 | 108 | -- decompose n = (s,d) such that n = 2^s * d and d is odd. 109 | -- Only works for n < 2^63. 
110 | decompose :: Integer -> (Int, Integer) 111 | decompose n = (tz, n `shiftR` tz) 112 | where 113 | tz = countTrailingZeros (fromIntegral n :: Int) 114 | 115 | -------------------------------------------------- 116 | -- Pollard Rho factoring algorithm 117 | 118 | -- Tries to find a non-trivial factor of the given number, using the 119 | -- given starting value. 120 | pollardRho :: Integer -> Integer -> Maybe Integer 121 | pollardRho a n = go (g a) (g (g a)) 122 | where 123 | go x y 124 | | d == n = Nothing 125 | | d == 1 = go (g x) (g (g y)) 126 | | otherwise = Just d 127 | where 128 | d = gcd (abs (x - y)) n 129 | g x = (x * x + 1) `mod` n 130 | 131 | -- Find a nontrivial factor of a number we know for sure is composite. 132 | compositeFactor :: Integer -> Integer 133 | compositeFactor n | even n = 2 134 | compositeFactor 25 = 5 135 | compositeFactor n = fromJust (F.asum (map (`pollardRho` n) [2 ..])) 136 | 137 | -------------------------------------------------- 138 | -- Factoring 139 | 140 | factorMap :: Integer -> Map Integer Int 141 | factorMap = factor >>> M.fromList 142 | 143 | factor :: Integer -> [(Integer, Int)] 144 | factor = listFactors >>> group >>> map (head &&& length) 145 | 146 | primes :: [Integer] 147 | primes = 2 : sieve primes [3 ..] 148 | where 149 | sieve (p : ps) xs = 150 | let (h, t) = span (< p * p) xs 151 | in h ++ sieve ps (filter ((/= 0) . (`mod` p)) t) 152 | 153 | listFactors :: Integer -> [Integer] 154 | listFactors = sort . 
go 155 | where 156 | go 1 = [] 157 | go n 158 | | isPrime n = [n] 159 | | otherwise = go d ++ go (n `div` d) 160 | where 161 | d = compositeFactor n 162 | 163 | -- listFactors :: Integer -> [Integer] 164 | -- listFactors = go primes 165 | -- where 166 | -- go _ 1 = [] 167 | -- go (p:ps) n 168 | -- | p*p > n = [n] 169 | -- | n `mod` p == 0 = p : go (p:ps) (n `div` p) 170 | -- | otherwise = go ps n 171 | 172 | divisors :: Integer -> [Integer] 173 | divisors = 174 | factor 175 | >>> map (\(p, k) -> take (k + 1) (iterate (* p) 1)) 176 | >>> sequence 177 | >>> map product 178 | 179 | totient :: Integer -> Integer 180 | totient = factor >>> map (\(p, k) -> p ^ (k - 1) * (p - 1)) >>> product 181 | 182 | ------------------------------------------------------------ 183 | -- Solving modular equations 184 | 185 | -- solveMod a b m solves ax = b (mod m), returning (y,k) such that all 186 | -- solutions are equivalent to y (mod k) 187 | solveMod :: Integer -> Integer -> Integer -> Maybe (Integer, Integer) 188 | solveMod a b m 189 | | g == 1 = Just ((b * inverse m a) `mod` m, m) 190 | | b `mod` g == 0 = solveMod (a `div` g) (b `div` g) (m `div` g) 191 | | otherwise = Nothing 192 | where 193 | g = gcd a m 194 | 195 | -- gcrt solves a system of modular equations. Each equation x = a 196 | -- (mod n) is given as a pair (a,n). Returns a pair (z, k) such that 197 | -- 0 <= z < k and solutions for x satisfy x = z (mod k), that is, 198 | -- solutions are of the form x = z + kt for integer t. 199 | gcrt :: [(Integer, Integer)] -> Maybe (Integer, Integer) 200 | gcrt [e] = Just e 201 | gcrt (e1 : e2 : es) = gcrt2 e1 e2 >>= \e -> gcrt (e : es) 202 | 203 | -- gcrt2 (a,n) (b,m) solves the pair of modular equations 204 | -- 205 | -- x = a (mod n) 206 | -- x = b (mod m) 207 | -- 208 | -- It returns a pair (c, k) such that 0 <= c < k and all solutions for 209 | -- x satisfy x = c (mod k), that is, solutions are of the form x = c + 210 | -- kt for integer t. 
211 | gcrt2 :: (Integer, Integer) -> (Integer, Integer) -> Maybe (Integer, Integer) 212 | gcrt2 (a, n) (b, m) 213 | | a `mod` g == b `mod` g = Just (((a * v * m + b * u * n) `div` g) `mod` k, k) 214 | | otherwise = Nothing 215 | where 216 | (g, u, v) = egcd n m 217 | k = (m * n) `div` g 218 | -------------------------------------------------------------------------------- /Perm.hs: -------------------------------------------------------------------------------- 1 | -- https://byorgey.wordpress.com/2020/07/18/competitive-programming-in-haskell-cycle-decomposition-with-mutable-arrays/ 2 | {-# LANGUAGE BangPatterns #-} 3 | 4 | module Perm where 5 | 6 | import Control.Arrow 7 | import Control.Monad.ST 8 | import Data.Array.Base 9 | import Data.Array.MArray 10 | import Data.Array.ST 11 | import Data.Array.Unboxed 12 | 13 | -- | 'Perm' represents a /1-indexed/ permutation. It can also be 14 | -- thought of as an endofunction on the set @{1 .. n}@. 15 | newtype Perm = Perm {getPerm :: UArray Int Int} 16 | deriving (Eq, Ord, Show) 17 | 18 | idPerm :: Int -> Perm 19 | idPerm n = fromList [1 .. n] 20 | 21 | -- | Construct a 'Perm' from a list containing a permutation of the 22 | -- numbers 1..n. The resulting 'Perm' sends @i@ to whatever number 23 | -- is at index @i-1@ in the list. 24 | fromList :: [Int] -> Perm 25 | fromList xs = Perm $ listArray (1, length xs) xs 26 | 27 | -- | Compose two permutations (corresponds to backwards function 28 | -- composition). Only defined if the permutations have the same 29 | -- size. 30 | andThen :: Perm -> Perm -> Perm 31 | andThen (Perm p1) (Perm p2) = Perm $ listArray (bounds p1) (map ((p1 !) >>> (p2 !)) (range (bounds p1))) 32 | 33 | instance Semigroup Perm where 34 | (<>) = andThen 35 | 36 | -- | Compute the inverse of a permutation. 37 | inverse :: Perm -> Perm 38 | inverse (Perm p) = Perm $ array (bounds p) [(p ! 
k, k) | k <- range (bounds p)] 39 | 40 | data CycleDecomp = CD 41 | { cycleID :: UArray Int Int 42 | , cycleLen :: UArray Int Int 43 | -- ^ Each number maps to the ID # 44 | -- of the cycle it is part of 45 | , cycleIndex :: UArray Int Int 46 | -- ^ Each cycle ID maps to the length of that cycle 47 | , cycleCounts :: UArray Int Int 48 | -- ^ Each element maps to its (0-based) index in its cycle 49 | } 50 | -- \| Each size maps to the number of cycles of that size 51 | 52 | deriving (Show) 53 | 54 | -- | Cycle decomposition of a permutation in O(n), using mutable arrays. 55 | permToCycles :: Perm -> CycleDecomp 56 | permToCycles (Perm p) = cd 57 | where 58 | (_, n) = bounds p 59 | 60 | cd = runST $ do 61 | cid <- newArray (1, n) 0 62 | cix <- newArray (1, n) 0 63 | ccs <- newArray (1, n) 0 64 | 65 | lens <- findCycles cid cix ccs 1 1 66 | cid' <- freeze cid 67 | cix' <- freeze cix 68 | ccs' <- freeze ccs 69 | return $ CD cid' (listArray (1, length lens) lens) cix' ccs' 70 | 71 | findCycles :: 72 | STUArray s Int Int -> 73 | STUArray s Int Int -> 74 | STUArray s Int Int -> 75 | Int -> 76 | Int -> 77 | ST s [Int] 78 | findCycles cid cix ccs l !k -- l = next available cycle ID; k = cur element 79 | | k > n = return [] 80 | | otherwise = do 81 | -- check if k is already marked as part of a cycle 82 | id <- readArray cid k 83 | case id of 84 | 0 -> do 85 | -- k is unvisited. Explore its cycle and label it as l. 86 | len <- labelCycle cid cix l k 0 87 | 88 | -- Remember that we have one more cycle of this size. 89 | count <- readArray ccs len 90 | writeArray ccs len (count + 1) 91 | 92 | -- Continue with the next label and the next element, and 93 | -- remember the size of this cycle 94 | (len :) <$> findCycles cid cix ccs (l + 1) (k + 1) 95 | 96 | -- k is already visited: just go on to the next element 97 | _ -> findCycles cid cix ccs l (k + 1) 98 | 99 | -- Explore a single cycle, label all its elements and return its size. 
100 | labelCycle cid cix l k !i = do 101 | -- Keep going as long as the next element is unlabelled. 102 | id <- readArray cid k 103 | case id of 104 | 0 -> do 105 | -- Label the current element with l. 106 | writeArray cid k l 107 | -- The index of the current element is i. 108 | writeArray cix k i 109 | 110 | -- Look up the next element in the permutation and continue. 111 | (1 +) <$> labelCycle cid cix l (p ! k) (i + 1) 112 | _ -> return 0 113 | -------------------------------------------------------------------------------- /Queue.hs: -------------------------------------------------------------------------------- 1 | -- https://byorgey.github.io/blog/posts/2024/11/27/stacks-queues.html 2 | {-# LANGUAGE ConstraintKinds #-} 3 | {-# LANGUAGE ImportQualifiedPost #-} 4 | 5 | module Queue where 6 | 7 | import Data.Bifunctor (second) 8 | import Data.List (foldl') 9 | import Data.Monoid (Dual (..)) 10 | import Stack (Stack) 11 | import Stack qualified 12 | 13 | data Queue m a = Queue {getFront :: !(Stack m a), getBack :: !(Stack (Dual m) a)} 14 | deriving (Show, Eq) 15 | 16 | new :: (a -> m) -> Queue m a 17 | new f = Queue (Stack.new f) (Stack.new (Dual . 
f)) 18 | 19 | size :: Queue m a -> Int 20 | size (Queue front back) = Stack.size front + Stack.size back 21 | 22 | measure :: Monoid m => Queue m a -> m 23 | measure (Queue front back) = Stack.measure front <> getDual (Stack.measure back) 24 | 25 | enqueue :: Monoid m => a -> Queue m a -> Queue m a 26 | enqueue a (Queue front back) = Queue front (Stack.push a back) 27 | 28 | dequeue :: Monoid m => Queue m a -> Maybe (a, Queue m a) 29 | dequeue (Queue front back) 30 | | Stack.size front == 0 && Stack.size back == 0 = Nothing 31 | | Stack.size front == 0 = dequeue (Queue (Stack.reverse' getDual back) (Stack.reverse' Dual front)) 32 | | otherwise = second (`Queue` back) <$> Stack.pop front 33 | 34 | drop1 :: Monoid m => Queue m a -> Queue m a 35 | drop1 q = case dequeue q of 36 | Nothing -> q 37 | Just (_, q') -> q' 38 | 39 | ------------------------------------------------------------ 40 | -- Sliding windows 41 | 42 | -- @windows w f as@ computes the monoidal sum @foldMap f window@ 43 | -- for each w-@window@ (i.e. contiguous subsequence of length @w@) of 44 | -- @as@, in only O(length as) time. For example, @windows 3 Sum 45 | -- [4,1,2,8,3] = [7, 11, 13]@, and @windows 3 Max [4,1,2,8,3] = [4,8,8]@. 
46 | windows :: Monoid m => Int -> (a -> m) -> [a] -> [m] 47 | windows w f as = go startQ rest 48 | where 49 | (start, rest) = splitAt w as 50 | startQ = foldl' (flip enqueue) (new f) start 51 | 52 | go q as = 53 | measure q : case as of 54 | [] -> [] 55 | a : as -> go (enqueue a (drop1 q)) as 56 | 57 | data Max a = NegInf | Max a deriving (Eq, Ord, Show) 58 | 59 | instance Ord a => Semigroup (Max a) where 60 | NegInf <> a = a 61 | a <> NegInf = a 62 | Max a <> Max b = Max (max a b) 63 | 64 | instance Ord a => Monoid (Max a) where 65 | mempty = NegInf 66 | 67 | data Min a = Min a | PosInf deriving (Eq, Ord, Show) 68 | 69 | instance Ord a => Semigroup (Min a) where 70 | PosInf <> a = a 71 | a <> PosInf = a 72 | Min a <> Min b = Min (min a b) 73 | 74 | instance Ord a => Monoid (Min a) where 75 | mempty = PosInf 76 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | Competitive programming utilities in Haskell 2 | -------------------------------------------- 3 | 4 | Code that I use for solving competitive programming problems in 5 | Haskell. See [Competitive Programming in Haskell: Basic 6 | Setup](https://byorgey.wordpress.com/2019/04/24/competitive-programming-in-haskell-basic-setup/) 7 | for an introduction. Individual files also link to blog posts where 8 | they are discussed. 9 | 10 | Copyright 2019 by Brent Yorgey. Everything in this repository is 11 | licensed under a Creative Commons Attribution 4.0 International License. 12 | 13 | Creative Commons License
14 | -------------------------------------------------------------------------------- /Scanner.hs: -------------------------------------------------------------------------------- 1 | -- https://byorgey.wordpress.com/2019/05/22/competitive-programming-in-haskell-scanner/ 2 | {-# LANGUAGE LambdaCase #-} 3 | 4 | module Scanner where 5 | 6 | import Control.Monad (replicateM) 7 | import Control.Monad.State 8 | 9 | type Scanner = State [String] 10 | 11 | runScanner :: Scanner a -> String -> a 12 | runScanner = runScannerWith words 13 | 14 | runScannerWith :: (String -> [String]) -> Scanner a -> String -> a 15 | runScannerWith t s = evalState s . t 16 | 17 | str :: Scanner String 18 | str = get >>= \case s : ss -> put ss >> return s 19 | 20 | int :: Scanner Int 21 | int = read <$> str 22 | 23 | integer :: Scanner Integer 24 | integer = read <$> str 25 | 26 | double :: Scanner Double 27 | double = read <$> str 28 | 29 | decimal :: Int -> Scanner Int 30 | decimal p = (round . ((10 ^ p) *)) <$> double 31 | 32 | numberOf :: Scanner a -> Scanner [a] 33 | numberOf s = int >>= flip replicateM s 34 | 35 | many :: Scanner a -> Scanner [a] 36 | many s = get >>= \case [] -> return []; _ -> (:) <$> s <*> many s 37 | 38 | times :: Int -> Scanner a -> Scanner [a] 39 | times = replicateM 40 | 41 | (><) :: Int -> Scanner a -> Scanner [a] 42 | (><) = times 43 | 44 | two, three, four :: Scanner a -> Scanner [a] 45 | [two, three, four] = map times [2 .. 
4] 46 | -------------------------------------------------------------------------------- /ScannerBS.hs: -------------------------------------------------------------------------------- 1 | -- https://byorgey.wordpress.com/2019/10/12/competitive-programming-in-haskell-reading-large-inputs-with-bytestring/ 2 | 3 | {-# LANGUAGE LambdaCase #-} 4 | 5 | module ScannerBS where 6 | 7 | import Control.Applicative (liftA2) 8 | import Control.Monad (replicateM) 9 | import Control.Monad.State 10 | import qualified Data.ByteString.Lazy.Char8 as C 11 | import Data.Maybe (fromJust) 12 | 13 | type Scanner = State [C.ByteString] 14 | 15 | runScanner :: Scanner a -> C.ByteString -> a 16 | runScanner = runScannerWith C.words 17 | 18 | runScannerWith :: (C.ByteString -> [C.ByteString]) -> Scanner a -> C.ByteString -> a 19 | runScannerWith t s = evalState s . t 20 | 21 | peek :: Scanner C.ByteString 22 | peek = head <$> get 23 | 24 | str :: Scanner C.ByteString 25 | str = get >>= \case { s:ss -> put ss >> return s } 26 | 27 | int :: Scanner Int 28 | int = (fst . fromJust . C.readInt) <$> str 29 | 30 | integer :: Scanner Integer 31 | integer = (read . C.unpack) <$> str 32 | 33 | double :: Scanner Double 34 | double = (read . C.unpack) <$> str 35 | 36 | decimal :: Int -> Scanner Int 37 | decimal p = (round . 
((10^p)*)) <$> double 38 | 39 | numberOf :: Scanner a -> Scanner [a] 40 | numberOf s = int >>= flip replicateM s 41 | 42 | many :: Scanner a -> Scanner [a] 43 | many s = get >>= \case { [] -> return []; _ -> (:) <$> s <*> many s } 44 | 45 | till :: (C.ByteString -> Bool) -> Scanner a -> Scanner [a] 46 | till p s = do 47 | t <- peek 48 | case p t of 49 | True -> return [] 50 | False -> (:) <$> s <*> till p s 51 | 52 | times :: Int -> Scanner a -> Scanner [a] 53 | times = replicateM 54 | 55 | (><) = times 56 | 57 | two, three, four :: Scanner a -> Scanner [a] 58 | [two, three, four] = map times [2..4] 59 | 60 | pair :: Scanner a -> Scanner b -> Scanner (a,b) 61 | pair = liftA2 (,) 62 | -------------------------------------------------------------------------------- /SegTree.hs: -------------------------------------------------------------------------------- 1 | {-# LANGUAGE FlexibleInstances #-} 2 | {-# LANGUAGE MultiParamTypeClasses #-} 3 | {-# LANGUAGE ViewPatterns #-} 4 | 5 | module SegTree where 6 | 7 | import Data.List (find) 8 | import Data.Maybe (fromMaybe) 9 | 10 | class Action m s where 11 | act :: m -> s -> s 12 | 13 | data SegTree m a 14 | = Leaf !Int !a 15 | | Node !Int !Int !a !m (SegTree m a) (SegTree m a) 16 | deriving Show 17 | 18 | instance Action () s where 19 | act _ = id 20 | 21 | node :: (Action m a, Semigroup a) => m -> SegTree m a -> SegTree m a -> SegTree m a 22 | node m l r = Node (getLeft l) (getRight r) (act m (getValue l <> getValue r)) m l r 23 | 24 | getValue :: SegTree m a -> a 25 | getValue (Leaf _ a) = a 26 | getValue (Node _ _ a _ _ _) = a 27 | 28 | getLeft :: SegTree m a -> Int 29 | getLeft (Leaf i _) = i 30 | getLeft (Node i _ _ _ _ _) = i 31 | 32 | getRight :: SegTree m a -> Int 33 | getRight (Leaf i _) = i+1 34 | getRight (Node _ j _ _ _ _) = j 35 | 36 | mkSegTree :: (Monoid m, Monoid a, Action m a) => [a] -> SegTree m a 37 | mkSegTree as = go 1 n (as ++ replicate (n - length as) mempty) 38 | where 39 | Just n = find (>= length 
as) (iterate (*2) 1) 40 | 41 | go i _ [a] = Leaf i a 42 | go i j as = node mempty l r 43 | where 44 | (as1, as2) = splitAt h as 45 | h = (j-i+1) `div` 2 46 | l = go i (i+h-1) as1 47 | r = go (i+h) j as2 48 | 49 | push :: (Monoid m, Action m a) => SegTree m a -> SegTree m a 50 | push (Node i j a m l r) = Node i j a mempty (applyAct m l) (applyAct m r) 51 | push t@Leaf{} = t 52 | 53 | applyAct :: (Monoid m, Action m a) => m -> SegTree m a -> SegTree m a 54 | applyAct m (Leaf i a) = Leaf i (act m a) 55 | applyAct m (Node i j a m2 l r) = Node i j (act m a) (m <> m2) l r 56 | 57 | update :: (Monoid m, Semigroup a, Action m a) => Int -> (a -> a) -> SegTree m a -> SegTree m a 58 | update _ f (Leaf i a) = Leaf i (f a) 59 | update p f (push -> Node _ _ _ _ l r) 60 | | p < getLeft r = node mempty (update p f l) r 61 | | otherwise = node mempty l (update p f r) 62 | 63 | set :: (Monoid m, Semigroup a, Action m a) => Int -> a -> SegTree m a -> SegTree m a 64 | set p = update p . const 65 | 66 | get :: (Monoid m, Action m a) => Int -> SegTree m a -> a 67 | get p (Leaf _ a) = a 68 | get p (push -> Node _ _ _ _ l r) 69 | | p < getLeft r = get p l 70 | | otherwise = get p r 71 | 72 | range :: (Monoid m, Monoid a, Action m a) => Int -> Int -> SegTree m a -> a 73 | range x y _ | x == y = mempty 74 | range x y (Leaf i a) 75 | | x <= i && i < y = a 76 | | otherwise = mempty 77 | range x y (push -> Node i j _ _ l r) 78 | | y <= i || j <= x = mempty 79 | | otherwise = range x y l <> range x y r 80 | 81 | apply :: (Monoid m, Semigroup a, Action m a) => Int -> m -> SegTree m a -> SegTree m a 82 | apply p = update p . 
act 83 | 84 | applyRange :: (Monoid m, Semigroup a, Action m a) => Int -> Int -> m -> SegTree m a -> SegTree m a 85 | applyRange x y _ t | x == y = t 86 | applyRange x y m l@(Leaf i a) 87 | | x <= i && i < y = Leaf i (act m a) 88 | | otherwise = l 89 | applyRange x y m n@(Node i j a m' l r) 90 | | x <= i && j <= y = Node i j a (m <> m') l r 91 | | otherwise = case push n of 92 | Node _ _ _ _ l r -> node mempty (applyRange x y m l) (applyRange x y m r) 93 | 94 | startingFrom :: (Monoid m, Action m a) => Int -> SegTree m a -> [SegTree m a] 95 | startingFrom l t = go t [] 96 | where 97 | go t@(Leaf i _) 98 | | l <= i = (t:) 99 | | otherwise = id 100 | go (push -> t@(Node i j _ _ lt rt)) 101 | | l <= i = (t:) 102 | | l < getLeft rt = go lt . (rt:) 103 | | l < getRight rt = go rt 104 | | otherwise = id 105 | 106 | -- | Preconditions: 107 | -- 108 | -- - @l <= getRight t@ 109 | -- - @g(mempty) == True@ 110 | -- - @g@ is antitone, that is, if @g(a)@ is false then so is @g(a <> b)@. 111 | -- 112 | -- Given these preconditions, @maxRight l g t@ returns the biggest 113 | -- @r@ such that @g (range l r t) == True@ but @g (range l (r+1) t) 114 | -- == False@ (or @r = getRight t@ if there is no such @r@). 
115 | maxRight :: (Monoid a, Monoid m, Action m a) => Int -> (a -> Bool) -> SegTree m a -> Int 116 | maxRight l g t = fromMaybe (getRight t) (go mempty (startingFrom l t)) 117 | where 118 | go _ [] = Nothing 119 | go cur (Leaf i a : ts) 120 | | g (cur <> a) = go (cur <> a) ts 121 | | otherwise = Just i 122 | go cur ((push -> Node i j a _ lt rt) : ts) 123 | | g (cur <> a) = go (cur <> a) ts 124 | | otherwise = go cur (lt : rt : ts) 125 | 126 | minLeft :: (Monoid a, Monoid m, Action m a) => Int -> (a -> Bool) -> SegTree m a -> Int 127 | minLeft = undefined 128 | -------------------------------------------------------------------------------- /Sieve.hs: -------------------------------------------------------------------------------- 1 | -- https://codeforces.com/blog/entry/54090 2 | {-# LANGUAGE FlexibleContexts #-} 3 | {-# LANGUAGE PartialTypeSignatures #-} 4 | 5 | module Sieve where 6 | 7 | import Control.Monad (forM_, unless) 8 | import Control.Monad.ST 9 | import Data.Array.ST 10 | import Data.Array.Unboxed 11 | import Data.STRef 12 | 13 | newSTUArray :: (MArray (STUArray s) e (ST s), Ix i) => (i, i) -> e -> ST s (STUArray s i e) 14 | newSTUArray = newArray 15 | 16 | -- | Basic version. Suppose we want to compute a multiplicative function f such that 17 | -- - f(p) = g(p) for prime p 18 | -- - f(ip) = h(i, p, f(i)) when p divides i 19 | -- Then @sieve n g h@ returns an array f on (1,n) such that f!i = f(i), in O(n) time. 20 | sieve :: Int -> (Int -> Int) -> (Int -> Int -> Int -> Int) -> UArray Int Int 21 | sieve n g h = runSTUArray $ do 22 | primes <- newArray (0, n) (0 :: Int) 23 | numPrimes <- newSTRef 0 24 | composite <- newSTUArray (0, n) False 25 | f <- newArray (1, n) 1 26 | forM_ [2 .. 
n] $ \i -> do 27 | isComp <- readArray composite i 28 | unless isComp $ do 29 | appendPrime primes numPrimes i 30 | writeArray f i (g i) 31 | 32 | np <- readSTRef numPrimes 33 | 34 | let markComposites j 35 | | j >= np = return () 36 | | otherwise = do 37 | p <- readArray primes j 38 | let ip = i * p 39 | case ip > n of 40 | True -> return () 41 | False -> do 42 | writeArray composite ip True 43 | fi <- readArray f i 44 | case i `mod` p of 45 | 0 -> writeArray f ip (h i p fi) 46 | _ -> do 47 | writeArray f ip (fi * g p) 48 | markComposites (j + 1) 49 | 50 | markComposites 0 51 | return f 52 | 53 | appendPrime :: STUArray s Int Int -> STRef s Int -> Int -> ST s () 54 | appendPrime primes numPrimes p = do 55 | n <- readSTRef numPrimes 56 | writeArray primes n p 57 | writeSTRef numPrimes $! n + 1 58 | 59 | -- | More general version. Suppose we want to compute a 60 | -- multiplicative function f such that f(p^k) = g p k. Then 61 | -- @genSieve n g@ returns an array f on (1,n) such that f!i = f(i), 62 | -- in O(n) time (but with a slightly higher constant factor than 63 | -- 'sieve'). 64 | genSieve :: Int -> (Int -> Int -> Int) -> UArray Int Int 65 | genSieve n g = runSTUArray $ do 66 | primes <- newArray (0, n) (0 :: Int) 67 | numPrimes <- newSTRef 0 68 | composite <- newSTUArray (0, n) False 69 | count <- newSTUArray (2, n) (0 :: Int) 70 | f <- newArray (1, n) 1 71 | forM_ [2 .. 
n] $ \i -> do 72 | isComp <- readArray composite i 73 | unless isComp $ do 74 | appendPrime primes numPrimes i 75 | writeArray f i (g i 1) 76 | writeArray count i 1 77 | 78 | np <- readSTRef numPrimes 79 | 80 | let markComposites j 81 | | j >= np = return () 82 | | otherwise = do 83 | p <- readArray primes j 84 | let ip = i * p 85 | case ip > n of 86 | True -> return () 87 | False -> do 88 | writeArray composite ip True 89 | fi <- readArray f i 90 | case i `mod` p of 91 | 0 -> do 92 | c <- readArray count i 93 | f' <- readArray f (i `div` (p ^ c)) 94 | writeArray f ip (f' * g p (c + 1)) 95 | writeArray count ip (c + 1) 96 | _ -> do 97 | writeArray f ip (fi * g p 1) 98 | writeArray count ip 1 99 | markComposites (j + 1) 100 | 101 | markComposites 0 102 | return f 103 | 104 | -- Some examples! 105 | 106 | -- Euler's phi (totient) function 107 | phi :: Int -> UArray Int Int 108 | phi n = sieve n (subtract 1) (\_ p fi -> fi * p) 109 | 110 | -- Divisor sigma function, σₐ(x) = Σ{d | x} dᵃ 111 | divisorSigma :: Int -> Int -> UArray Int Int 112 | divisorSigma n a = genSieve n (\p k -> (p ^ (a * (k + 1)) - 1) `div` (p ^ a - 1)) 113 | 114 | -- Möbius function (μ) 115 | mu :: Int -> UArray Int Int 116 | mu n = sieve n (const (-1)) (\_ _ _ -> 0) 117 | -------------------------------------------------------------------------------- /Slice.hs: -------------------------------------------------------------------------------- 1 | {-# LANGUAGE FlexibleContexts #-} 2 | {-# LANGUAGE StandaloneDeriving #-} 3 | {-# LANGUAGE UndecidableInstances #-} 4 | 5 | module Slice where 6 | 7 | import Control.Arrow ((***)) 8 | import Data.Array.Unboxed as U 9 | import Interval (I) 10 | import qualified Interval as I 11 | import Prelude as P 12 | 13 | data Slice a = S { storage :: !(UArray Int a), interval :: !I } 14 | 15 | deriving instance (Show a, IArray UArray a) => Show (Slice a) 16 | 17 | {-# SPECIALIZE fromList :: String -> Slice Char #-} 18 | fromList :: IArray UArray a => [a] -> Slice a 19 
| fromList s = S (listArray (0,n-1) s) (I.mkI 0 n) 20 | where 21 | n = P.length s 22 | 23 | {-# SPECIALIZE toList :: Slice Char -> String #-} 24 | toList :: IArray UArray a => Slice a -> [a] 25 | toList (S s i) = map (s U.!) $ I.range i 26 | 27 | length :: Slice a -> Int 28 | length (S _ i) = I.length i 29 | 30 | take :: Int -> Slice a -> Slice a 31 | take k (S s i) = S s (I.take k i) 32 | 33 | subs :: Slice a -> [Slice a] 34 | subs (S s i) = map (S s) (I.subs i) 35 | 36 | splits :: Slice a -> [(Slice a, Slice a)] 37 | splits (S s i) = map (S s *** S s) (I.splits i) 38 | 39 | {-# SPECIALIZE (Slice.!) :: Slice Char -> Int -> Char #-} 40 | (!) :: IArray UArray a => Slice a -> Int -> a 41 | (S s i) ! k = s U.! (k + I.lo i) 42 | 43 | range :: Slice a -> [Int] 44 | range (S _ i) = map (subtract (I.lo i)) (I.range i) 45 | -------------------------------------------------------------------------------- /SparseTable.hs: -------------------------------------------------------------------------------- 1 | -- https://cp-algorithms.com/data_structures/sparse-table.html 2 | 3 | {-# LANGUAGE TupleSections #-} 4 | 5 | module SparseTable where 6 | 7 | import Data.Array (Array, array, (!)) 8 | import Data.Bifunctor (first) 9 | import Data.Bits 10 | import IdempotentSemigroup 11 | 12 | newtype SparseTable m = SparseTable (Array (Int, Int) m) 13 | deriving (Show) 14 | 15 | -- | Logarithm base 2, rounded down to the nearest integer. Computed 16 | -- efficiently using primitive bitwise instructions. 17 | lg :: Int -> Int 18 | lg n = finiteBitSize n - 1 - countLeadingZeros n 19 | 20 | -- | Construct a sparse table which can answer range queries over the 21 | -- given list in $O(1)$ time. Constructing the sparse table takes 22 | -- $O(n \lg n)$ time and space, where $n$ is the length of the list. 
23 | fromList :: IdempotentSemigroup m => [m] -> SparseTable m 24 | fromList ms = SparseTable st 25 | where 26 | n = length ms 27 | lgn = lg n 28 | 29 | st = 30 | array ((0, 0), (lgn, n - 1)) $ 31 | (map (first (0,)) $ zip [0 ..] ms) 32 | ++ [ ((i, j), st ! (i - 1, j) <> st ! (i - 1, j + 1 !<<. (i - 1))) 33 | | i <- [1 .. lgn] 34 | , j <- [0 .. n - 1 !<<. i] 35 | ] 36 | 37 | -- | \$O(1)$. @range st l r@ computes the range query which is the 38 | -- @sconcat@ of all the elements from index @l@ to @r@ (inclusive). 39 | range :: IdempotentSemigroup m => SparseTable m -> Int -> Int -> m 40 | range (SparseTable st) l r = st ! (i, l) <> st ! (i, r - (1 !<<. i) + 1) 41 | where 42 | i = lg (r - l + 1) 43 | -------------------------------------------------------------------------------- /SqrtTree.hs: -------------------------------------------------------------------------------- 1 | -- https://cp-algorithms.com/data_structures/sqrt-tree.html 2 | -- https://cp-algorithms.com/data_structures/sqrt_decomposition.html 3 | 4 | {-# LANGUAGE TypeApplications #-} 5 | {-# LANGUAGE GADTSyntax #-} 6 | 7 | module SqrtTree where 8 | 9 | import Data.Array (Array, array, listArray, (!), bounds) 10 | import Data.List (scanl1, scanr1) 11 | 12 | import Data.Semigroup 13 | import Control.Monad 14 | import System.Random 15 | 16 | data Block m = Block 17 | { total :: m -- ^ total of this entire block 18 | , prefix :: Array Int m -- ^ prefix sums for this block 19 | , suffix :: Array Int m -- ^ suffix sums for this block 20 | , subtree :: SqrtTree m -- ^ sqrt tree for this block 21 | } 22 | deriving (Show) 23 | 24 | data SqrtTree m where 25 | One :: m -> SqrtTree m 26 | Two :: m -> m -> SqrtTree m 27 | Branch :: 28 | Int -> -- ^ block size 29 | Array Int (Block m) -> -- ^ blocks 30 | Array (Int,Int) m -> -- ^ between sums for blocks 31 | SqrtTree m 32 | deriving (Show) 33 | 34 | fromList :: Semigroup m => [m] -> SqrtTree m 35 | fromList ms = fromArray $ listArray (0,length ms-1) ms 36 | 37 | 
fromArray :: Semigroup m => Array Int m -> SqrtTree m {- query indices are relative to the lower bound of the array -} 38 | fromArray ms = mkSqrtTree ms lo (hi+1) {- mkSqrtTree expects an exclusive upper bound -} 39 | where 40 | (lo,hi) = bounds ms 41 | 42 | -- | @mkSqrtTree ms lo hi@ makes a sqrt tree on ms[lo..hi). Ranges of length 1 or 2 become 'One'/'Two' leaves; longer ranges are cut into blocks of k = ceiling (sqrt (hi - lo)) elements, each with its own prefix/suffix sums and sub-tree. 43 | mkSqrtTree :: Semigroup m => Array Int m -> Int -> Int -> SqrtTree m 44 | mkSqrtTree ms lo hi 45 | | hi - lo == 1 = One (ms ! lo) 46 | | hi - lo == 2 = Two (ms ! lo) (ms ! (lo+1)) 47 | | otherwise = Branch k blocks between 48 | where 49 | k :: Int 50 | k = ceiling (sqrt @Double (fromIntegral (hi - lo))) {- NOTE(review): k == 0 when hi == lo, so range on such an empty tree divides by zero -} 51 | 52 | -- blocks :: Array Int (Block m); block b starts at index lo + b*k. There may be fewer than k blocks, in which case GHC listArray leaves the trailing entries undefined (range never reads them). 53 | blocks = listArray (0,k-1) . map mkBlock . takeWhile (< hi) . iterate (+k) $ lo 54 | 55 | mkBlock i = Block (pref!(n-1)) pref suf (mkSqrtTree ms i j) 56 | where 57 | n = j-i 58 | elts = map (ms!) [i..j-1] 59 | pref = listArray (0,n-1) (scanl1 (<>) elts) 60 | suf = listArray (0,n-1) (scanr1 (<>) elts) 61 | j = min (i + k) hi {- exclusive end of the block starting at i -} 62 | 63 | -- between :: Array (Int,Int) m; between!(s,t) for s <= t is the <> of the totals of blocks s through t 64 | between = array ((0,0), (k-1, k-1)) $ 65 | [((s,s), total (blocks!s)) | s <- [0 .. k-1]] 66 | ++ [ ((s,t), between!(s,t-1) <> total (blocks!t)) 67 | | s <- [0 .. k-2] 68 | , t <- [s+1 .. k-1] 69 | ] 70 | -- XXX only to print out: copies each (s,t) entry to (t,s); range itself only reads entries with s <= t 71 | ++ [ ((t,s), between!(s,t-1) <> total (blocks!t)) 72 | | s <- [0 .. k-2] 73 | , t <- [s+1 .. k-1] 74 | ] 75 | 76 | range :: Semigroup m => SqrtTree m -> Int -> Int -> m {- range t l r is the <> of elements l through r, inclusive and 0-based -} 77 | range (One m) _ _ = m 78 | range (Two x y) 0 0 = x 79 | range (Two x y) 0 1 = x <> y 80 | range (Two x y) 1 1 = y 81 | range (Branch k blocks between) l r 82 | | lb == rb = range (subtree (blocks!lb)) li ri {- both ends in one block: recurse into its sub-tree -} 83 | | rb - lb == 1 = suffix (blocks!lb) ! li <> prefix (blocks!rb) ! ri {- adjacent blocks: suffix of the left one, prefix of the right one -} 84 | | otherwise = suffix (blocks!lb)!li <> between!(lb+1,rb-1) <> prefix (blocks!rb) !
ri {- otherwise: suffix of the first block, precomputed total of the full blocks in between, prefix of the last block -} 85 | where 86 | (lb, li) = l `divMod` k {- block index and offset within that block -} 87 | (rb, ri) = r `divMod` k 88 | 89 | randomRange :: Semigroup m => SqrtTree m -> IO m {- benchmark helper; assumes the tree holds 1000000 elements -} 90 | randomRange t = do 91 | l <- randomRIO (0,999999) 92 | r <- randomRIO (l,999999) 93 | pure $ range t l r 94 | 95 | main = do {- benchmark: 1000000 random queries over 1000000 random elements -} 96 | ns <- replicateM 1000000 (randomRIO (0,1000 :: Int)) 97 | let t = fromList (map Sum ns) 98 | rs <- replicateM 1000000 (randomRange t) 99 | print (getSum $ mconcat rs) 100 | -------------------------------------------------------------------------------- /Stack.hs: -------------------------------------------------------------------------------- 1 | -- https://byorgey.github.io/blog/posts/2024/11/27/stacks-queues.html 2 | {-# LANGUAGE BangPatterns #-} {- NOTE(review): no bang patterns appear below; the strictness annotations on the Stack fields do not need this extension -} 3 | 4 | module Stack where 5 | 6 | import Data.List (foldl') 7 | import Data.Monoid (Dual (..)) {- NOTE(review): Dual appears unused in this module -} 8 | 9 | data Stack m a = Stack !(a -> m) !Int ![(m, a)] {- measure function, size, and elements (top first), each paired with the measure of itself and everything below it -} 10 | 11 | instance (Show m, Show a) => Show (Stack m a) where {- shows only the element list -} 12 | show (Stack _ _ as) = show as 13 | 14 | instance (Eq m, Eq a) => Eq (Stack m a) where {- compares only the element lists -} 15 | Stack _ _ as1 == Stack _ _ as2 = as1 == as2 16 | 17 | new :: (a -> m) -> Stack m a 18 | new f = Stack f 0 [] 19 | 20 | size :: Stack m a -> Int 21 | size (Stack _ n _) = n 22 | 23 | measure :: Monoid m => Stack m a -> m {- O(1): the <> of all element measures, top element leftmost -} 24 | measure (Stack _ _ as) = case as of 25 | [] -> mempty 26 | (m, _) : _ -> m 27 | 28 | push :: Monoid m => a -> Stack m a -> Stack m a 29 | push a s@(Stack f n as) = Stack f (n + 1) ((f a <> measure s, a) : as) 30 | 31 | pop :: Stack m a -> Maybe (a, Stack m a) 32 | pop (Stack f n as) = case as of 33 | [] -> Nothing 34 | (_, a) : as' -> Just (a, Stack f (n - 1) as') 35 | 36 | reverse' :: Monoid n => (m -> n) -> Stack m a -> Stack n a {- re-pushes every element in reverse order, measuring with g . f -} 37 | reverse' g (Stack f _ as) = foldl' (flip push) (new (g .
f)) (map snd as) 38 | 39 | reverse :: Monoid m => Stack m a -> Stack m a 40 | reverse = reverse' id 41 | -------------------------------------------------------------------------------- /Tree.hs: -------------------------------------------------------------------------------- 1 | -- https://byorgey.github.io/blog/posts/2024/07/11/cpih-factor-full-tree.html 2 | -- https://byorgey.github.io/blog/posts/2024/08/08/TreeDecomposition.html 3 | {-# LANGUAGE TupleSections #-} 4 | 5 | module Tree where 6 | 7 | import Control.Arrow ((***)) 8 | import Control.Category ((>>>)) 9 | import Data.Bifunctor (second) 10 | import Data.List (maximumBy, sortBy) {- NOTE(review): maximumBy appears unused in this module -} 11 | import Data.List.NonEmpty (NonEmpty) 12 | import qualified Data.List.NonEmpty as NE 13 | import Data.Map (Map, (!?)) 14 | import qualified Data.Map as M 15 | import Data.Ord (Down (..), comparing) 16 | import Data.Tree 17 | import Data.Tuple (swap) 18 | 19 | edgesToMap :: Ord a => [(a, a)] -> Map a [a] {- adjacency map for undirected edges: each edge is recorded in both directions -} 20 | edgesToMap = concatMap (\p -> [p, swap p]) >>> dirEdgesToMap 21 | 22 | dirEdgesToMap :: Ord a => [(a, a)] -> Map a [a] {- adjacency map for directed edges -} 23 | dirEdgesToMap = map (second (: [])) >>> M.fromListWith (++) 24 | 25 | mapToTree :: Ord a => (a -> [b] -> b) -> Map a [a] -> a -> b {- folds the tree hanging from root, never stepping back to the parent of a node -} 26 | mapToTree nd m root = go root root 27 | where 28 | go parent root = nd root (maybe [] (map (go root) . filter (/= parent)) (m !? root)) 29 | 30 | edgesToTree :: Ord a => (a -> [b] -> b) -> [(a, a)] -> a -> b 31 | edgesToTree nd = mapToTree nd .
edgesToMap 32 | 33 | parentMap :: Ord a => Tree a -> Map a a {- maps each non-root node to its parent -} 34 | parentMap = foldTree node >>> snd 35 | where 36 | node a b = (a, M.fromList (map (,a) as) <> mconcat ms) 37 | where 38 | (as, ms) = unzip b 39 | 40 | type SubtreeSelector a = a -> [Tree a] -> Maybe (Tree a, [Tree a]) 41 | 42 | pathDecomposition :: (a -> [Tree a] -> Maybe (Tree a, [Tree a])) -> Tree a -> [NonEmpty a] {- splits the tree into paths, extending each path downward with the child picked by select; the selector type is SubtreeSelector a -} 43 | pathDecomposition select = go 44 | where 45 | go = selectPath select >>> second (concatMap go) >>> uncurry (:) 46 | 47 | selectPath :: SubtreeSelector a -> Tree a -> (NonEmpty a, [Tree a]) {- follows select down from the root, returning the path and all subtrees hanging off it -} 48 | selectPath select = go 49 | where 50 | go (Node a ts) = case select a ts of 51 | Nothing -> (NE.singleton a, ts) 52 | Just (t, ts') -> ((a NE.<|) *** (ts' ++)) (go t) 53 | 54 | type Height = Int 55 | type Size = Int 56 | 57 | labelHeight :: Tree a -> Tree (Height, a) {- labels each node with its height; leaves have height 0 -} 58 | labelHeight = foldTree node 59 | where 60 | node a ts = case ts of 61 | [] -> Node (0, a) [] 62 | _ -> Node (1 + maximum (map (fst . rootLabel) ts), a) ts 63 | 64 | labelSize :: Tree a -> Tree (Size, a) {- labels each node with the number of nodes in its subtree -} 65 | labelSize = foldTree $ \a ts -> Node (1 + sum (map (fst . rootLabel) ts), a) ts 66 | 67 | -- | Decompose a tree into chains by length, first the longest 68 | -- possible chain, then the longest chain from what remains, and so 69 | -- on. 70 | maxChainDecomposition :: Tree a -> [NonEmpty (Height, a)] 71 | maxChainDecomposition = 72 | labelHeight 73 | >>> pathDecomposition (const (selectMaxBy (comparing (fst . rootLabel)))) 74 | >>> sortBy (comparing (Down . fst . NE.head)) 75 | 76 | selectMaxBy :: (a -> a -> Ordering) -> [a] -> Maybe (a, [a]) {- returns a maximum element (the earliest among ties) together with the remaining elements -} 77 | selectMaxBy _ [] = Nothing 78 | selectMaxBy cmp (a : as) = case selectMaxBy cmp as of 79 | Nothing -> Just (a, []) 80 | Just (b, bs) -> case cmp a b of 81 | LT -> Just (b, a : bs) 82 | _ -> Just (a, b : bs) 83 | 84 | -- | Heavy-light decomposition of a tree: each path continues into the child with the largest subtree.
85 | heavyLightDecomposition :: Tree a -> [NonEmpty (Size, a)] 86 | heavyLightDecomposition = 87 | labelSize >>> pathDecomposition (const (selectMaxBy (comparing (fst . rootLabel)))) 88 | -------------------------------------------------------------------------------- /Trie.hs: -------------------------------------------------------------------------------- 1 | module Trie where 2 | 3 | import Control.Monad ((>=>)) 4 | import qualified Data.ByteString.Lazy.Char8 as C 5 | import Data.List (foldl') 6 | import Data.Map (Map, (!)) 7 | import qualified Data.Map as M 8 | import Data.Maybe (fromMaybe, isJust) 9 | 10 | data Trie a = Trie 11 | { trieSize :: !Int 12 | , value :: !(Maybe a) 13 | , children :: !(Map Char (Trie a)) 14 | } 15 | deriving Show 16 | 17 | emptyTrie :: Trie a 18 | emptyTrie = Trie 0 Nothing M.empty 19 | 20 | -- | Insert a new key/value pair into a trie, updating the size 21 | -- appropriately. 22 | insert :: C.ByteString -> a -> Trie a -> Trie a 23 | insert w a t = fst (go w t) 24 | where 25 | go = C.foldr 26 | (\c insSuffix (Trie n v m) -> 27 | let (t', ds) = insSuffix (fromMaybe emptyTrie (M.lookup c m)) 28 | in (Trie (n+ds) v (M.insert c t' m), ds) 29 | ) 30 | (\(Trie n v m) -> 31 | let ds = if isJust v then 0 else 1 32 | in (Trie (n+ds) (Just a) m, ds) 33 | ) 34 | 35 | -- | Create an initial trie from a list of key/value pairs. If there 36 | -- are multiple pairs with the same key, later pairs override 37 | -- earlier ones. 38 | mkTrie :: [(C.ByteString, a)] -> Trie a 39 | mkTrie = foldl' (flip (uncurry insert)) emptyTrie 40 | 41 | -- | Look up a single character in a trie, returning the corresponding 42 | -- child trie (if any). 43 | lookup1 :: Char -> Trie a -> Maybe (Trie a) 44 | lookup1 c = M.lookup c . children 45 | 46 | -- | Look up a string key in a trie, returning the corresponding value 47 | -- (if any). 48 | lookup :: C.ByteString -> Trie a -> Maybe a 49 | lookup = C.foldr ((>=>) .
lookup1) value 50 | 51 | -- | Fold a trie into a summary value. 52 | foldTrie :: (Int -> Maybe a -> Map Char r -> r) -> Trie a -> r 53 | foldTrie f (Trie n b m) = f n b (M.map (foldTrie f) m) 54 | 55 | -- | "Decode" a string by repeatedly looking up consecutive 56 | -- characters. Every time we find a key which exists in the trie, 57 | -- emit the corresponding value and restart at the root. This is of 58 | -- particular use in decoding a prefix-free code. Note that this 59 | -- function will crash if it ever looks up a character which is not 60 | -- in the current trie. 61 | decode :: Trie a -> C.ByteString -> [a] 62 | decode t = reverse . snd . C.foldl' step (t, []) 63 | where 64 | step (s, as) c = 65 | let Just s' = lookup1 c s 66 | in maybe (s', as) (\a -> (t, a:as)) (value s') 67 | -------------------------------------------------------------------------------- /UnionFind.hs: -------------------------------------------------------------------------------- 1 | -- Adapted from https://kseo.github.io/posts/2014-01-30-implementing-union-find-in-haskell.html 2 | -- https://byorgey.github.io/blog/posts/2024/11/02/UnionFind.html 3 | -- https://byorgey.github.io/blog/posts/2024/11/18/UnionFind-sols.html 4 | {-# LANGUAGE RecordWildCards #-} 5 | 6 | module UnionFind where 7 | 8 | import Control.Monad (when) 9 | import Control.Monad.ST 10 | import Data.Array.ST 11 | 12 | type Node = Int 13 | 14 | data UnionFind s m = UnionFind 15 | { parent :: !(STUArray s Node Node) 16 | , sz :: !(STUArray s Node Int) 17 | , ann :: !(STArray s Node m) 18 | } 19 | 20 | new :: Int -> m -> ST s (UnionFind s m) 21 | new n m = newWith n (const m) 22 | 23 | newWith :: Int -> (Node -> m) -> ST s (UnionFind s m) 24 | newWith n m = 25 | UnionFind 26 | <$> newListArray (0, n - 1) [0 .. n - 1] 27 | <*> newArray (0, n - 1) 1 28 | <*> newListArray (0, n - 1) (map m [0 ..
n - 1]) 29 | 30 | connected :: UnionFind s m -> Node -> Node -> ST s Bool 31 | connected uf x y = (==) <$> find uf x <*> find uf y 32 | 33 | find :: UnionFind s m -> Node -> ST s Node 34 | find uf@(UnionFind {..}) x = do 35 | p <- readArray parent x 36 | if p /= x 37 | then do 38 | r <- find uf p 39 | writeArray parent x r 40 | pure r 41 | else pure x 42 | 43 | updateAnn :: Semigroup m => UnionFind s m -> Node -> (m -> m) -> ST s () 44 | updateAnn uf@(UnionFind {..}) x f = do 45 | x <- find uf x 46 | old <- readArray ann x -- modifyArray is not available in Kattis test environment 47 | writeArray ann x (f old) 48 | 49 | union :: Semigroup m => UnionFind s m -> Node -> Node -> ST s () 50 | union uf@(UnionFind {..}) x y = do 51 | x <- find uf x 52 | y <- find uf y 53 | when (x /= y) $ do 54 | sx <- readArray sz x 55 | sy <- readArray sz y 56 | mx <- readArray ann x 57 | my <- readArray ann y 58 | if sx < sy 59 | then do 60 | writeArray parent x y 61 | writeArray sz y (sx + sy) 62 | writeArray ann y (mx <> my) 63 | else do 64 | writeArray parent y x 65 | writeArray sz x (sx + sy) 66 | writeArray ann x (mx <> my) 67 | 68 | size :: UnionFind s m -> Node -> ST s Int 69 | size uf@(UnionFind {..}) x = do 70 | x <- find uf x 71 | readArray sz x 72 | 73 | getAnn :: UnionFind s m -> Node -> ST s m 74 | getAnn uf@(UnionFind {..}) x = do 75 | x <- find uf x 76 | readArray ann x 77 | 78 | allAnns :: UnionFind s m -> ST s [(Int, m)] 79 | allAnns UnionFind {..} = do 80 | ps <- getAssocs parent 81 | flip foldMap ps $ \(p, x) -> 82 | if p == x 83 | then do 84 | a <- readArray ann x 85 | s <- readArray sz x 86 | pure [(s, a)] 87 | else pure [] 88 | 89 | -- | Merge the sets containing @x@ and @y@, returning 'True' if they 90 | -- were already in the same set and 'False' if this call merged them. 91 | -- 92 | -- Both nodes are walked upward together: at each step, whichever 93 | -- current node has the smaller parent index gets its parent pointer 94 | -- redirected to the other node's (larger) parent, and the walk 95 | -- continues from its old parent. It stops when both parents 96 | -- coincide, or right after a root has been linked under the other tree. 97 | -- 98 | -- Only the @parent@ array is modified: @sz@ and @ann@ are not updated, 99 | -- so 'size' and 'getAnn' do not reflect merges made here. The index 100 | -- comparisons appear to rely on every parent pointer pointing to an 101 | -- index >= its own (as set up by 'newWith' and kept by 'find' and 102 | -- 'unite'), so do not mix this with 'union', which links by size. 103 | -- https://algocoding.wordpress.com/2015/05/13/simple-union-find-techniques/ 104 | unite :: Semigroup m => UnionFind s m -> Node -> Node -> ST s Bool 105 | unite uf@(UnionFind {..}) x y = do -- {..} binds the parent array; the bare name parent would be the record selector function 106 | px <- readArray parent x 107 | py <- readArray parent y 108 | case compare px py of 109 | EQ -> pure True 110 | LT -> do 111 | writeArray
parent x py 112 | case x == px of 113 | True -> pure False 114 | False -> unite uf px y 115 | GT -> do 116 | writeArray parent y px 117 | case y == py of 118 | True -> pure False 119 | False -> unite uf x py 120 | -------------------------------------------------------------------------------- /Util.hs: -------------------------------------------------------------------------------- 1 | {-# LANGUAGE TupleSections #-} 2 | 3 | module Util where 4 | 5 | import Control.Arrow 6 | import Data.Array.IArray 7 | import Data.Bits (finiteBitSize, countLeadingZeros) 8 | import Data.Function 9 | import Data.List 10 | import qualified Data.List.NonEmpty as NE 11 | import Data.Maybe 12 | import Data.Ord 13 | 14 | fi :: (Integral a, Num b) => a -> b 15 | fi = fromIntegral 16 | 17 | fj :: Maybe a -> a 18 | fj = fromJust 19 | 20 | both :: (a -> b) -> (a, a) -> (b, b) 21 | both f (x, y) = (f x, f y) 22 | 23 | sortGroupOn :: Ord b => (a -> b) -> [a] -> [(b, NE.NonEmpty a)] 24 | sortGroupOn f = sortOn f >>> NE.groupBy ((==) `on` f) >>> map ((f . NE.head) &&& id) 25 | 26 | pairs :: [a] -> [(a, a)] 27 | pairs [] = [] 28 | pairs (a : as) = map (a,) as ++ pairs as 29 | 30 | withPairs :: Monoid r => (a -> a -> r) -> [a] -> r 31 | withPairs _ [] = mempty 32 | withPairs _ [_] = mempty 33 | withPairs f (a : as) = go as 34 | where 35 | go [] = withPairs f as 36 | go (a2 : rest) = f a a2 <> go rest 37 | 38 | generate :: (Ix i, IArray a e) => (i, i) -> (i -> e) -> a i e 39 | generate rng f = listArray rng (map f (range rng)) 40 | 41 | arraydef :: (Ix i, IArray a e) => (i, i) -> e -> [(i, e)] -> a i e 42 | arraydef rng def vs = array rng ([(i, def) | i <- range rng] ++ vs) 43 | 44 | -- | Logarithm base 2, rounded down to the nearest integer. Computed 45 | -- efficiently using primitive bitwise instructions.
46 | lg :: Int -> Int {- lg 0 == -1; for negative n, lg n == finiteBitSize n - 1 -} 47 | lg n = finiteBitSize n - 1 - countLeadingZeros n 48 | -------------------------------------------------------------------------------- /comprog-hs.cabal: -------------------------------------------------------------------------------- 1 | cabal-version: 2.4 2 | name: comprog-hs 3 | version: 0.1.0.0 4 | synopsis: 5 | 6 | -- A longer description of the package. 7 | -- description: 8 | homepage: 9 | 10 | -- A URL where users can report bugs. 11 | -- bug-reports: 12 | license: BSD-3-Clause 13 | license-file: LICENSE 14 | author: Brent Yorgey 15 | maintainer: byorgey@gmail.com 16 | 17 | -- A copyright notice. 18 | -- copyright: 19 | category: Development 20 | extra-source-files: 21 | CHANGELOG.md 22 | README.md 23 | 24 | library 25 | exposed-modules: BFS, 26 | BinarySearch, 27 | Enumeration, 28 | Geom, 29 | IdempotentSemigroup, 30 | Interval, 31 | NumberTheory, 32 | Perm, 33 | Queue, 34 | ScannerBS, 35 | Scanner, 36 | SegTree, 37 | Sieve, 38 | Slice, 39 | SparseTable, 40 | SqrtTree, 41 | Stack, 42 | Tree, 43 | Trie, 44 | Util 45 | 46 | build-depends: base, 47 | containers, 48 | hashable, 49 | unordered-containers, 50 | array, 51 | mtl, 52 | bytestring, 53 | random 54 | 55 | -- Directories containing source files. 56 | -- hs-source-dirs: 57 | default-language: Haskell2010 58 | --------------------------------------------------------------------------------