├── .gitignore
├── BFS.hs
├── BinarySearch.hs
├── CHANGELOG.md
├── Enumeration.hs
├── Geom.hs
├── IdempotentSemigroup.hs
├── Interval.hs
├── LICENSE
├── NumberTheory.hs
├── Perm.hs
├── Queue.hs
├── README.md
├── Scanner.hs
├── ScannerBS.hs
├── SegTree.hs
├── Sieve.hs
├── Slice.hs
├── SparseTable.hs
├── SqrtTree.hs
├── Stack.hs
├── Tree.hs
├── Trie.hs
├── UnionFind.hs
├── Util.hs
└── comprog-hs.cabal
/.gitignore:
--------------------------------------------------------------------------------
1 | *~
2 | TAGS
3 | *.o
4 | *.hi
5 | dist-newstyle
6 |
--------------------------------------------------------------------------------
/BFS.hs:
--------------------------------------------------------------------------------
1 | -- https://byorgey.wordpress.com/2021/10/14/competitive-programming-in-haskell-bfs-part-1/
2 | -- https://byorgey.wordpress.com/2021/10/18/competitive-programming-in-haskell-bfs-part-2-alternative-apis/
3 | -- https://byorgey.wordpress.com/2021/10/29/competitive-programming-in-haskell-bfs-part-3-implementation-via-hashmap/
4 | -- https://byorgey.wordpress.com/2021/11/15/competitive-programming-in-haskell-enumeration/
5 | {-# LANGUAGE FlexibleContexts #-}
6 | {-# LANGUAGE RankNTypes #-}
7 | {-# LANGUAGE RecordWildCards #-}
8 | {-# LANGUAGE ScopedTypeVariables #-}
9 |
10 | module BFS where
11 |
12 | import Enumeration
13 |
14 | import Control.Arrow ((>>>))
15 | import Control.Monad
16 | import Control.Monad.ST
17 | import qualified Data.Array.IArray as IA
18 | import Data.Array.ST
19 | import Data.Array.Unboxed (UArray)
20 | import qualified Data.Array.Unboxed as U
21 | import Data.Array.Unsafe (unsafeFreeze)
22 | import Data.Sequence (Seq (..), ViewL (..), (<|), (|>))
23 | import qualified Data.Sequence as Seq
24 |
25 | ------------------------------------------------------------
26 | -- Utilities
27 | ------------------------------------------------------------
28 |
29 | infixl 0 >$>
30 | (>$>) :: a -> (a -> b) -> b
31 | x >$> f = f x -- feed a value into a function, pipeline style
32 | {-# INLINE (>$>) #-}
33 |
34 | exhaustM :: Monad m => (a -> m (Maybe a)) -> a -> m a
35 | exhaustM step = loop -- keep stepping until @step@ answers 'Nothing'; return the last state
36 |   where
37 |     loop cur = do
38 |       next <- step cur
39 |       case next of Nothing -> pure cur; Just cur' -> loop cur'
40 |
41 | ------------------------------------------------------------
42 | -- BFS
43 | ------------------------------------------------------------
44 |
45 | data BFSResult v = BFSR {getLevel :: v -> Maybe Int, getParent :: v -> Maybe v} -- per-vertex level and parent lookups; 'bfs' builds them so that a stored -1 reads back as Nothing
46 |
47 | type V = Int -- vertices are used directly as array indices
48 | data BFSState s = BS {level :: STUArray s V Int, parent :: STUArray s V V, queue :: Seq V} -- mutable level and parent tables plus the queue of vertices still to be expanded
49 |
50 | initBFSState :: Int -> [Int] -> ST s (BFSState s)
51 | initBFSState size sources = do
52 |   levels <- newArray (0, size - 1) (-1) -- -1 = not reached yet
53 |   parents <- newArray (0, size - 1) (-1) -- -1 = no parent recorded
54 |   -- Every source vertex starts at level 0 and seeds the queue.
55 |   mapM_ (\src -> writeArray levels src 0) sources
56 |   pure (BS levels parents (Seq.fromList sources))
57 |
58 | bfs :: forall v. Enumeration v -> [v] -> (v -> [v]) -> (v -> Bool) -> BFSResult v
59 | bfs Enumeration {..} starts neighbours isGoal =
60 |   toResult (bfs' card (locate <$> starts) (fmap locate . neighbours . select) (isGoal . select))
61 |   where
62 |     toResult :: (forall s. ST s (BFSState s)) -> BFSResult v
63 |     toResult search = runST $ do
64 |       final <- search
65 |       (levels :: UArray V Int) <- unsafeFreeze (level final)
66 |       (parents :: UArray V V) <- unsafeFreeze (parent final)
67 |       -- A stored -1 means "never reached" / "no parent"; report it as Nothing.
68 |       let known f i = if i == -1 then Nothing else Just (f i)
69 |       pure $ BFSR (known id . (levels IA.!) . locate)
70 |                   (known select . (parents IA.!) . locate)
71 |
72 | visited :: BFSState s -> V -> ST s Bool
73 | visited st v = fmap (\lvl -> lvl /= -1) (readArray (level st) v) -- reached iff a level has been written
74 | {-# INLINE visited #-}
75 |
76 | bfs' :: Int -> [V] -> (V -> [V]) -> (V -> Bool) -> ST s (BFSState s)
77 | bfs' n vs next goal = do
78 | st <- initBFSState n vs
79 | exhaustM bfsStep st -- run single steps until one of them returns Nothing
80 | where
81 | bfsStep st@BS {..} = case Seq.viewl queue of
82 | EmptyL -> return Nothing -- queue exhausted: stop
83 | v :< q'
84 | | goal v -> return Nothing -- stop as soon as a goal vertex reaches the front of the queue
85 | | otherwise ->
86 | v >$> next
87 | >>> filterM (fmap not . visited st) -- keep unvisited neighbours; the whole list is filtered before any update, so a neighbour listed twice is enqueued twice
88 | >=> foldM (upd v) (st {queue = q'}) -- record each kept neighbour, starting from the state with v dequeued
89 | >>> fmap Just
90 |
91 | upd p b@BS {..} v = do -- mark v as discovered from p
92 | lp <- readArray level p
93 | writeArray level v (lp + 1) -- one level deeper than its parent
94 | writeArray parent v p
95 | return $ b {queue = queue |> v} -- append v at the back of the queue
96 |
--------------------------------------------------------------------------------
/BinarySearch.hs:
--------------------------------------------------------------------------------
1 | -- https://byorgey.wordpress.com/2023/01/01/competitive-programming-in-haskell-better-binary-search/
2 | -- https://byorgey.wordpress.com/2023/01/02/binary-search-over-floating-point-representations/
3 | -- https://julesjacobs.com/notes/binarysearch/binarysearch.pdf
4 |
5 | module BinarySearch where
6 |
7 | import Data.Bits
8 | import Data.Word (Word64)
9 |
10 | -- | Generic search function. Takes a step function, a predicate, and
11 | -- low and high values, and it finds values (l,r) right next to each
12 | -- other where the predicate switches from False to True.
13 | --
14 | -- More specifically, for @search mid p l r@:
15 | -- - PRECONDITIONS:
16 | -- - @(p l)@ must be @False@ and @(p r)@ must be @True@!
17 | -- - If using one of the binary search step functions, the
18 | -- predicate @p@ must be monotonic, that is, intuitively, @p@
19 | -- must switch from @False@ to @True@ exactly once. Formally, if
20 | -- @x <= y@ then @p x <= p y@ (where @False <= True@).
21 | -- - The mid function is called as @mid l r@ to find the next
22 | -- index to search on the interval [l, r]. @mid l r@ must
23 | -- return a value strictly between @l@ and @r@.
24 | -- - Use 'binary' to do binary search over the integers.
25 | -- - Use @('continuous' eps)@ to do binary search over rational
26 | -- or floating point values. (l,r) are returned such that r -
27 | -- l <= eps.
28 | -- - Use 'fwd' or 'bwd' to do linear search through a range.
29 | -- - If you are feeling adventurous, use 'floating' to do
30 | -- precise binary search over the bit representation of
31 | -- 'Double' values.
32 | -- - Returns (l,r) such that:
33 | -- - @not (p l) && p r@
34 | -- - @mid l r == Nothing@
35 | search :: (a -> a -> Maybe a) -> (a -> Bool) -> a -> a -> (a, a)
36 | search mid p = narrow
37 |   where
38 |     -- Invariant (given the preconditions): @p lo@ is False and @p hi@ is True.
39 |     narrow lo hi = maybe (lo, hi) shrink (mid lo hi)
40 |       where
41 |         -- Keep whichever half still straddles the False/True switch.
42 |         shrink m = if p m then narrow lo m else narrow m hi
43 |
44 | -- | Step function for binary search over an integral type. Stops
45 | -- when @r - l <= 1@; otherwise returns their midpoint.
46 | binary :: Integral a => a -> a -> Maybe a
47 | binary l r =
48 |   let gap = r - l -- distance still to be searched
49 |    in if gap > 1 then Just (l + gap `div` 2) else Nothing
50 |
51 | -- | Step function for continuous binary search. Stops once
52 | -- @r - l <= eps@; otherwise returns their midpoint.
53 | continuous :: (Fractional a, Ord a) => a -> a -> a -> Maybe a
54 | continuous eps l r =
55 |   let width = r - l -- current size of the search interval
56 |    in if width > eps then Just (l + width / 2) else Nothing
57 |
58 | -- | Step function for forward linear search. Stops when @r - l <=
59 | -- 1@; otherwise returns @l + 1@.
60 | fwd :: (Num a, Ord a) => a -> a -> Maybe a
61 | fwd l r =
62 |   let gap = r - l -- the search stops once this is <= 1
63 |    in if gap > 1 then Just (l + 1) else Nothing
64 |
65 | -- | Step function for backward linear search. Stops when @r - l <=
66 | -- 1@; otherwise returns @r - 1@.
67 | bwd :: (Num a, Ord a) => a -> a -> Maybe a
68 | bwd l r =
69 |   let gap = r - l -- the search stops once this is <= 1
70 |    in if gap > 1 then Just (r - 1) else Nothing
71 |
72 | ------------------------------------------------------------
73 | -- Binary search on floating-point representations
74 |
75 | -- A lot of blood, sweat, and tears went into these functions. Are
76 | -- they even correct? Was it worth it? Who knows! See:
77 | --
78 | -- https://byorgey.wordpress.com/2023/01/01/competitive-programming-in-haskell-better-binary-search/#comment-40882
79 | -- https://byorgey.wordpress.com/2023/01/02/binary-search-over-floating-point-representations/
80 | -- https://web.archive.org/web/20220326204603/http://stereopsis.com/radix.html
81 |
82 | -- | Step function for precise binary search over the bit
83 | -- representation of @Double@ values.
84 | floating :: Double -> Double -> Maybe Double
85 | floating l r = fmap b2f (binary (f2b l) (f2b r)) -- take the midpoint in the Word64 encoding, then decode it
86 |
87 | -- | A monotonic conversion from 'Double' to 'Word64'. That is,
88 | -- @x < y@ iff @f2b x < f2b y@. 'b2f' is its inverse.
89 | f2b :: Double -> Word64
90 | f2b 0 = 0x7fffffffffffffff -- zero gets a fixed code: every bit set except bit 63
91 | f2b x =
92 | (if m < 0 then 0 else bit 63) -- bit 63 is set for positive mantissas and left clear for negative ones
93 | `xor` flipNeg (eBits `xor` mBits)
94 | where
95 | (m, e) = decodeFloat x -- x == m * 2^e
96 | eBits = fromIntegral (e + d + bias) `shiftL` d -- offset exponent, placed above the low d bits
97 | mBits = fromIntegral (abs m) `clearBit` d -- mantissa magnitude with bit d cleared
98 |
99 | d = floatDigits x - 1 -- one less than the number of mantissa digits of x
100 | bias = 1023
101 | flipNeg
102 | | m < 0 = (`clearBit` 63) . complement -- negative inputs: invert every bit, then clear bit 63 again
103 | | otherwise = id
104 |
105 | prop_f2b_monotonic :: Double -> Double -> Bool
106 | prop_f2b_monotonic x y = (f2b x < f2b y) == (x < y) -- the encoding must order exactly like the doubles
107 |
108 | -- | The left inverse of 'f2b', that is, for all @x :: Double@, @b2f
109 | -- (f2b x) == x@. Note @f2b (b2f x) == x@ does not strictly hold
110 | -- since not every @Word64@ value corresponds to a valid @Double@
111 | -- value.
112 | b2f :: Word64 -> Double
113 | b2f 0x7fffffffffffffff = 0 -- the fixed code that 'f2b' assigns to zero
114 | b2f w = encodeFloat m (fromIntegral e - bias - d) -- rebuild m * 2^(e - bias - d)
115 | where
116 | s = testBit w 63 -- True when bit 63 is set, i.e. the non-negative half of the encoding
117 | w' = (if s then id else complement) w -- undo the bit inversion applied to negative values
118 |
119 | d = floatDigits (1 :: Double) - 1
120 | bias = 1023
121 | m = (if s then id else negate) (fromIntegral ((w' .&. ((1 `shiftL` d) - 1)) `setBit` d)) -- low d bits with bit d set again, negated when s is False
122 | e = clearBit w' 63 `shiftR` d -- the bits between the mantissa and bit 63
123 |
124 | prop_b2f_f2b :: Double -> Bool
125 | prop_b2f_f2b x = roundTrip == x where roundTrip = b2f (f2b x) -- encode, decode, compare with the input
126 |
127 | -- Some Word64 values correspond to +/-Infinity or NaN. For most
128 | -- others, f2b is inverse to b2f; for a few that represent very tiny
129 | -- floating-point values, the Word64 resulting from a round trip may
130 | -- differ by 1.
131 | prop_f2b_b2f :: Word64 -> Bool
132 | prop_f2b_b2f w = or [isInfinite d, isNaN d, gap (f2b d) w <= 1]
133 |   where
134 |     d = b2f w
135 |     -- Absolute difference; subtracting the smaller from the larger
136 |     -- keeps the unsigned 'Word64' arithmetic from wrapping around.
137 |     gap a b = max a b - min a b
138 |
139 | -- Given two distinct Double values, if we take the midpoint of their
140 | -- corresponding Word64 values, we get another Word64 that represents
141 | -- a floating point number strictly in between the original two.
142 | prop_f2b_mid_monotonic :: Double -> Double -> Bool
143 | prop_f2b_mid_monotonic x y
144 |   | x == y = True
145 |   | otherwise = small < between && between < large
146 |   where
147 |     (small, large) = (min x y, max x y)
148 |     (loBits, hiBits) = (f2b small, f2b large)
149 |     -- Midpoint of the two encodings, decoded back to a 'Double'.
150 |     between = b2f (loBits + (hiBits - loBits) `div` 2)
151 |
--------------------------------------------------------------------------------
/CHANGELOG.md:
--------------------------------------------------------------------------------
1 | # Revision history for comprog-hs
2 |
3 | ## 0.1.0.0 -- YYYY-mm-dd
4 |
5 | * First version. Released on an unsuspecting world.
6 |
--------------------------------------------------------------------------------
/Enumeration.hs:
--------------------------------------------------------------------------------
1 | -- https://byorgey.wordpress.com/2021/11/15/competitive-programming-in-haskell-enumeration/
2 |
3 | {-# LANGUAGE ScopedTypeVariables #-}
4 | {-# LANGUAGE TypeApplications #-}
5 |
6 | module Enumeration where
7 |
8 | import qualified Data.List as L
9 |
10 | import Data.Hashable
11 | import qualified Data.Array as A
12 | import qualified Data.HashMap.Strict as HM
13 |
14 | data Enumeration a = Enumeration
15 | { card :: !Int -- number of elements; valid indices are 0 .. card - 1
16 | , select :: Int -> a -- element at a given index
17 | , locate :: a -> Int -- index of a given element
18 | }
19 |
20 | -- | Map a pair of inverse functions over an invertible enumeration of
21 | -- @a@ values to turn it into an invertible enumeration of @b@
22 | -- values. Because invertible enumerations contain a /bijection/ to
23 | -- the natural numbers, we really do need both directions of a
24 | -- bijection between @a@ and @b@ in order to map. This is why
25 | -- 'Enumeration' cannot be an instance of 'Functor'.
26 | mapE :: (a -> b) -> (b -> a) -> Enumeration a -> Enumeration b
27 | mapE f g e = Enumeration (card e) (f . select e) (locate e . g) -- same indices; values translated on the way out and back in
28 |
29 | -- | List the elements of an enumeration in order. Inverse of
30 | -- 'listE'.
31 | enumerate :: Enumeration a -> [a]
32 | enumerate e = [select e i | i <- [0 .. card e - 1]] -- walk the indices in increasing order
33 |
34 | -- | The empty enumeration, with cardinality zero and no elements.
35 | voidE :: Enumeration a
36 | voidE = Enumeration {card = 0, select = error "select void", locate = error "locate void"}
37 |
38 | -- | The enumeration of the unit type: just @()@, at index 0.
39 | unitE :: Enumeration ()
40 | unitE = Enumeration {card = 1, select = const (), locate = const 0}
41 |
42 | -- | The one-element enumeration holding the given value at index 0.
43 | singletonE :: a -> Enumeration a
44 | singletonE x = Enumeration {card = 1, select = const x, locate = const 0}
45 |
46 | -- | The first @n@ natural numbers, @0 .. n-1@, each at its own index.
47 | finiteE :: Int -> Enumeration Int
48 | finiteE n = Enumeration {card = n, select = id, locate = id}
49 |
50 | -- | Construct an enumeration from the elements of a finite list.
51 | -- The elements of the list must all be distinct. To turn an
52 | -- enumeration back into a list, use 'enumerate'.
53 | listE :: forall a. (Hashable a, Eq a) => [a] -> Enumeration a
54 | listE as = Enumeration {card = size, select = (byIndex A.!), locate = (indexOf HM.!)}
55 |   where
56 |     size = length as
57 |     byIndex :: A.Array Int a -- index -> element, backing 'select'
58 |     byIndex = A.listArray (0, size - 1) as
59 |
60 |     indexOf :: HM.HashMap a Int -- element -> index, backing 'locate'
61 |     indexOf = HM.fromList (zip as [0 ..])
62 |
63 | -- | Enumerate all the values of a bounded 'Enum' instance.
64 | boundedEnum :: forall a. (Enum a, Bounded a) => Enumeration a
65 | boundedEnum =
66 |   Enumeration
67 |     (hi - lo + 1) -- number of values between the two bounds, inclusive
68 |     (\k -> toEnum (k + lo)) -- index 0 is 'minBound'
69 |     (\x -> fromEnum x - lo)
70 |   where
71 |     -- 'fromEnum' already yields an 'Int', so no conversion is needed.
72 |     lo = fromEnum (minBound :: a)
73 |     hi = fromEnum (maxBound :: a)
74 |
75 | -- | Sum, /i.e./ disjoint union, of two enumerations. All the values
76 | -- of the first are enumerated before the values of the second.
77 | (>+<) :: Enumeration a -> Enumeration b -> Enumeration (Either a b)
78 | ea >+< eb = Enumeration (na + card eb) pick index
79 |   where
80 |     na = card ea -- indices below @na@ belong to the left summand
81 |     pick k | k < na = Left (select ea k) | otherwise = Right (select eb (k - na))
82 |     index = either (locate ea) (\y -> locate eb y + na)
83 |
84 | -- | Cartesian product of enumerations, with a lexicographic ordering.
85 | (>*<) :: Enumeration a -> Enumeration b -> Enumeration (a,b)
86 | ea >*< eb = Enumeration (card ea * nb) pick index
87 |   where
88 |     nb = card eb -- the second component varies fastest
89 |     pick k = (select ea (k `div` nb), select eb (k `mod` nb))
90 |     index (x, y) = nb * locate ea x + locate eb y
91 |
92 | -- | Take a finite prefix from the beginning of an enumeration. @takeE
93 | -- k e@ always yields the empty enumeration for \(k \leq 0\), and
94 | -- results in @e@ whenever @k@ is greater than or equal to the
95 | -- cardinality of the enumeration. Otherwise @takeE k e@ has
96 | -- cardinality @k@ and matches @e@ from @0@ to @k-1@.
97 | takeE :: Int -> Enumeration a -> Enumeration a
98 | takeE k e =
99 |   if k <= 0 then voidE
100 |   else if k >= card e then e
101 |   else e {card = k} -- same element/index maps, smaller cardinality
102 |
103 | -- | Drop some elements from the beginning of an enumeration. @dropE k
104 | -- e@ yields @e@ unchanged if \(k \leq 0\), and results in the empty
105 | -- enumeration whenever @k@ is greater than or equal to the
106 | -- cardinality of @e@.
107 | dropE :: Int -> Enumeration a -> Enumeration a
108 | dropE k e
109 |   | k <= 0 = e
110 |   | k >= n = voidE
111 |   | otherwise = Enumeration (n - k) shifted unshifted
112 |   where
113 |     n = card e
114 |     shifted i = select e (i + k) -- new index i is old index i + k
115 |     unshifted x = locate e x - k
116 |
117 | -- | Zip two enumerations in parallel, producing the pair of
118 | -- elements at each index. The resulting enumeration is truncated
119 | -- to the cardinality of the smaller of the two arguments.
120 | zipE :: Enumeration a -> Enumeration b -> Enumeration (a,b)
121 | zipE ea eb =
122 |   Enumeration
123 |     (min (card ea) (card eb))
124 |     (\k -> (select ea k, select eb k))
125 |     (locate ea . fst) -- the index is read off the first component alone
126 |
--------------------------------------------------------------------------------
/Geom.hs:
--------------------------------------------------------------------------------
1 | -- https://byorgey.wordpress.com/2020/06/24/competitive-programming-in-haskell-vectors-and-2d-geometry/
2 | {-# LANGUAGE DeriveFunctor #-}
3 | {-# LANGUAGE GeneralizedNewtypeDeriving #-}
4 | {-# LANGUAGE ViewPatterns #-}
5 |
6 | module Geom where
7 |
8 | import Data.Function (on)
9 | import Data.List (nub)
10 | import Data.Maybe (mapMaybe)
11 | import Data.Ord (compare)
12 | import Data.Ratio
13 |
14 | ------------------------------------------------------------
15 | -- 2D points and vectors
16 |
17 | data V2 s = V2 {getX :: !s, getY :: !s} deriving (Eq, Ord, Show, Functor) -- 2D vector with strict x and y components
18 | type V2D = V2 Double
19 |
20 | type P2 s = V2 s -- points share the vector representation
21 | type P2D = P2 Double
22 |
23 | instance Foldable V2 where
24 |   foldMap inj (V2 a b) = inj a <> inj b -- x component first, then y
25 |
26 | zero :: Num s => V2 s
27 | zero = V2 0 0 -- the zero vector (equivalently, the origin as a point)
28 |
29 | -- | Componentwise sum and difference of two vectors.
30 | (^+^), (^-^) :: Num s => V2 s -> V2 s -> V2 s
31 | (^+^) (V2 ax ay) (V2 bx by) = V2 (ax + bx) (ay + by)
32 | (^-^) (V2 ax ay) (V2 bx by) = V2 (ax - bx) (ay - by)
33 |
34 | -- | Scale both components of a vector by @k@.
35 | (*^) :: Num s => s -> V2 s -> V2 s
36 | k *^ V2 x y = V2 (k * x) (k * y)
37 |
38 | (^/) :: Fractional s => V2 s -> s -> V2 s
39 | v ^/ k = let inv = 1 / k in inv *^ v -- divide by k by scaling with its reciprocal
40 |
41 | ------------------------------------------------------------
42 | -- Utilities
43 |
44 | -- | These combinators allow us to write e.g. 'v2 int' or 'v2 double'
45 | -- to get a 'Scanner (V2 s)'.
46 | v2, p2 :: Applicative f => f s -> f (V2 s)
47 | v2 s = fmap V2 s <*> s -- run the same action twice: first result is x, second is y
48 | p2 s = v2 s
49 |
50 | newtype ByX s = ByX {unByX :: V2 s} deriving (Eq, Show, Functor)
51 | newtype ByY s = ByY {unByY :: V2 s} deriving (Eq, Show, Functor)
52 |
53 | instance Ord s => Ord (ByX s) where -- orders by x only; the derived 'Eq' still compares both coordinates
54 |   compare (ByX u) (ByX v) = compare (getX u) (getX v)
55 |
56 | instance Ord s => Ord (ByY s) where -- orders by y only; the derived 'Eq' still compares both coordinates
57 |   compare (ByY u) (ByY v) = compare (getY u) (getY v)
58 |
59 | -- | Manhattan (L1) distance: the sum of the absolute coordinate differences.
60 | manhattan :: Num s => P2 s -> P2 s -> s
61 | manhattan p q = let V2 dx dy = p ^-^ q in abs dx + abs dy
62 |
63 | ------------------------------------------------------------
64 | -- Angles
65 |
66 | newtype Angle s = A s -- an angle, stored in radians
67 |   deriving (Show, Eq, Ord, Num, Fractional, Floating)
68 |
69 | fromDeg :: Floating s => s -> Angle s
70 | fromDeg degrees = fromRad (degrees * pi / 180) -- 180 degrees correspond to pi radians
71 |
72 | fromRad :: s -> Angle s
73 | fromRad radians = A radians
74 |
75 | toDeg :: Floating s => Angle s -> s
76 | toDeg angle = toRad angle * 180 / pi
77 |
78 | toRad :: Angle s -> s
79 | toRad (A radians) = radians
80 |
81 | dir :: V2D -> Angle Double
82 | dir v = A (atan2 (getY v) (getX v)) -- angle of v as given by @atan2 y x@
83 |
84 | -- | The vector with polar coordinates @r@ (radius) and @θ@ (angle).
85 | fromPolar :: Floating s => s -> Angle s -> V2 s
86 | fromPolar r θ = θ `rot` V2 r 0
87 |
88 | -- | Rotate a vector counterclockwise about the origin by the given angle.
89 | rot :: Floating s => Angle s -> V2 s -> V2 s
90 | rot (A θ) (V2 x y) = V2 (c * x - s * y) (s * x + c * y) where (c, s) = (cos θ, sin θ)
91 |
92 | perp :: Num s => V2 s -> V2 s
93 | perp v = V2 (negate (getY v)) (getX v) -- quarter turn counterclockwise
94 |
95 | ------------------------------------------------------------
96 | -- Dot product
97 |
98 | -- | Dot product.  Since u·v = |u||v| cos θ, with θ the (unsigned) angle
99 | -- between u and v, the result is zero iff the two vectors are
100 | -- perpendicular.
101 | dot :: Num s => V2 s -> V2 s -> s
102 | dot u v = getX u * getX v + getY u * getY v
103 |
104 | -- | @dotP p1 p2 p3@ is the dot product of the two vectors that run
105 | -- from p1 to p2 and from p1 to p3.
106 | dotP :: Num s => P2 s -> P2 s -> P2 s -> s
107 | dotP origin a b = (a ^-^ origin) `dot` (b ^-^ origin)
108 |
109 | -- | Squared length of a vector, /i.e./ its dot product with itself;
110 | -- only 'Num' is required, so no square root is involved.
111 | normSq :: Num s => V2 s -> s
112 | normSq v = v `dot` v
113 |
114 | -- | Euclidean length of a vector.
115 | norm :: Floating s => V2 s -> s
116 | norm v = sqrt (normSq v)
117 |
118 | normalize :: Floating s => V2 s -> V2 s
119 | normalize v = v ^/ len where len = norm v -- v scaled down by its own length
120 |
121 | -- | 'angleP p1 p2 p3' computes the (unsigned) angle of p1-p2-p3
122 | -- (/i.e./ the angle at p2 made by rays to p1 and p3). The result
123 | -- will always be in the range $[0, \pi]$.
124 | angleP :: Floating s => P2 s -> P2 s -> P2 s -> Angle s
125 | angleP p1 p2 p3 = A (acos cosine)
126 |   where
127 |     (toP1, toP3) = (p1 ^-^ p2, p3 ^-^ p2) -- the two rays leaving p2
128 |     cosine = dot toP1 toP3 / (norm toP1 * norm toP3)
129 |
130 | -- | 'signedAngleP p1 p2 p3' computes the /signed/ angle p1-p2-p3
131 | -- (/i.e./ the angle at p2 made by rays to p1 and p3), in the range
132 | -- $[-\pi, \pi]$. Positive iff the ray from p2 to p3 is
133 | -- counterclockwise of the ray from p2 to p1.
134 | signedAngleP :: (Floating s, Ord s) => P2 s -> P2 s -> P2 s -> Angle s
135 | signedAngleP p1 p2 p3 = flipIf (turnP p1 p2 p3) (angleP p1 p2 p3)
136 |   where -- a left turn p1 -> p2 -> p3 puts p3 clockwise of p1 as seen from p2, hence the negation
137 |     flipIf turn = case turn of CCW -> negate; _ -> id
138 |
139 | ------------------------------------------------------------
140 | -- Cross product
141 |
142 | -- | 2D cross product of two vectors. Gives the signed area of their
143 | -- parallelogram (positive iff the second is counterclockwise of the
144 | -- first). [Geometric algebra tells us that this is really the
145 | -- coefficient of the bivector resulting from the outer product of
146 | -- the two vectors.]
147 | --
148 | -- Note this works even for integral scalar types.
149 | cross :: Num s => V2 s -> V2 s -> s
150 | cross u v = getX u * getY v - getX v * getY u
151 |
152 | -- | Cross product of the two vectors that share the tail @p1@ and
153 | -- end at @p2@ and @p3@ respectively, /i.e./ @cross (p2 - p1) (p3 - p1)@.
154 | -- The first argument is the common tail; the second and third are
155 | -- the heads of the two vectors.
156 | crossP :: Num s => P2 s -> P2 s -> P2 s -> s
157 | crossP p1 p2 p3 = (p2 ^-^ p1) `cross` (p3 ^-^ p1)
158 |
159 | -- | The signed area of a triangle with given vertices can be computed
160 | -- as half the cross product of two of the edges.
161 | --
162 | -- Note that this requires 'Fractional' because of the division by
163 | -- two. If you want to stick with integral scalars, you can just
164 | -- use 'crossP' to get twice the signed area.
165 | signedTriArea :: Fractional s => P2 s -> P2 s -> P2 s -> s
166 | signedTriArea a b c = crossP a b c / 2
167 |
168 | -- | Unsigned area of the triangle with the given vertices.
169 | triArea :: Fractional s => P2 s -> P2 s -> P2 s -> s
170 | triArea a b c = abs (signedTriArea a b c)
171 |
172 | -- | Signed area of a polygon by the "shoelace formula": sum the signed
173 | -- areas of the triangles spanned by the origin and each edge.  The
174 | -- result is positive iff the vertices are listed counterclockwise.
175 | signedPolyArea :: Fractional s => [P2 s] -> s
176 | signedPolyArea pts = sum [signedTriArea zero p q | (p, q) <- zip pts (drop 1 pts ++ take 1 pts)]
177 |
178 | -- | Unsigned area of the polygon with the given vertices.
179 | polyArea :: Fractional s => [P2 s] -> s
180 | polyArea pts = abs (signedPolyArea pts)
181 |
182 | -- | Direction of a turn: counterclockwise (left), clockwise (right),
183 | -- or parallel (/i.e./ 0 or 180 degree turn).
184 | data Turn = CCW | Par | CW -- no instances are derived, so values can only be inspected by pattern matching
185 |
186 | -- | Cross product can also be used to compute the direction of a
187 | -- turn. If you are travelling from p1 to p2, 'turnP p1 p2 p3' says
188 | -- whether you have to make a left (ccw) or right (cw) turn to
189 | -- continue on to p3 (or if it is parallel). Equivalently, if you
190 | -- are standing at p1 looking towards p2, and imagine the line
191 | -- through p1 and p2 dividing the plane in two, is p3 on the right
192 | -- side, left side, or on the line?
193 | turnP :: (Num s, Ord s) => P2 s -> P2 s -> P2 s -> Turn
194 | turnP p1 p2 p3 =
195 |   if sign > 0 then CCW
196 |   else if sign == 0 then Par
197 |   else CW
198 |   where
199 |     sign = signum (crossP p1 p2 p3) -- positive: left turn, zero: collinear, otherwise: right turn
200 |
201 | ------------------------------------------------------------
202 | -- 2D Lines
203 |
204 | data L2 s = L2 {getDirection :: !(V2 s), getOffset :: !s} -- the points p with @cross getDirection p == getOffset@
205 | type L2D = L2 Double
206 |
207 | lineFromEquation :: Num s => s -> s -> s -> L2 s
208 | lineFromEquation a b c = L2 {getDirection = V2 b (negate a), getOffset = c} -- the line a*x + b*y = c
209 |
210 | lineFromPoints :: Num s => P2 s -> P2 s -> L2 s
211 | lineFromPoints p q =
212 |   let along = q ^-^ p -- direction from p towards q
213 |    in L2 along (cross along p)
214 |
215 | slope :: (Integral n, Eq n) => L2 n -> Maybe (Ratio n)
216 | slope l
217 |   | run == 0 = Nothing -- vertical direction: no slope
218 |   | otherwise = Just (rise % run) where V2 run rise = getDirection l
219 |
220 | dslope :: (Fractional s, Eq s) => L2 s -> Maybe s
221 | dslope l
222 |   | run == 0 = Nothing -- vertical direction: no slope
223 |   | otherwise = Just (rise / run) where V2 run rise = getDirection l
224 |
225 | side :: Num s => L2 s -> P2 s -> s
226 | side l p = cross (getDirection l) p - getOffset l -- zero on the line; the sign tells the two sides apart
227 |
228 | leftOf :: (Num s, Ord s) => L2 s -> P2 s -> Bool
229 | leftOf l = (> 0) . side l
230 |
231 | rightOf :: (Num s, Ord s) => L2 s -> P2 s -> Bool
232 | rightOf l = (< 0) . side l
233 |
234 | toProjection :: Fractional s => L2 s -> P2 s -> V2 s
235 | toProjection l p = let v = getDirection l in negate (side l p / normSq v) *^ perp v -- vector from p to the nearest point of l
236 |
237 | project :: Fractional s => L2 s -> P2 s -> P2 s
238 | project l p = p ^+^ offset where offset = toProjection l p
239 |
240 | reflectAcross :: Fractional s => L2 s -> P2 s -> P2 s
241 | reflectAcross l p = p ^+^ (2 *^ offset) where offset = toProjection l p -- go twice as far as the projection
242 |
243 | lineIntersection :: (Fractional s, Eq s) => L2 s -> L2 s -> Maybe (P2 s)
244 | lineIntersection (L2 v1 c1) (L2 v2 c2) =
245 |   let det = cross v1 v2 -- zero when the two directions are parallel
246 |    in if det == 0 then Nothing else Just (((c1 *^ v2) ^-^ (c2 *^ v1)) ^/ det)
247 |
248 | ------------------------------------------------------------
249 | -- Segments
250 |
251 | data Seg s = Seg (P2 s) (P2 s) deriving (Eq, Show) -- a segment, given by its two endpoints
252 |
253 | segLine :: Num s => Seg s -> L2 s
254 | segLine (Seg start end) = lineFromPoints start end -- the infinite line through both endpoints
255 |
256 | -- Test whether two segments intersect.
257 | -- http://www.geeksforgeeks.org/check-if-two-given-line-segments-intersect/
258 | segsIntersect :: (Ord s, Num s) => Seg s -> Seg s -> Bool
259 | segsIntersect (Seg p1 q1) (Seg p2 q2) =
260 |   or
261 |     [ o1 /= o2 && o3 /= o4 -- general case: each segment's endpoints lie on different sides of the other
262 |     , o1 == 0 && within p1 p2 q1 -- p2 is collinear with p1-q1 and lies on it
263 |     , o2 == 0 && within p1 q2 q1 -- likewise for q2
264 |     , o3 == 0 && within p2 p1 q2 -- p1 is collinear with p2-q2 and lies on it
265 |     , o4 == 0 && within p2 q1 q2 -- likewise for q1
266 |     ]
267 |   where
268 |     o1 = signum (crossP p1 q1 p2)
269 |     o2 = signum (crossP p1 q1 q2)
270 |     o3 = signum (crossP p2 q2 p1)
271 |     o4 = signum (crossP p2 q2 q1)
272 |
273 |     -- For three *collinear* points a, b, c: does b lie on the segment ac?
274 |     within (V2 ax ay) (V2 bx by) (V2 cx cy) =
275 |       between ax bx cx && between ay by cy
276 |     -- @between e1 x e2@: x lies in the closed range spanned by e1 and e2.
277 |     between e1 x e2 = min e1 e2 <= x && x <= max e1 e2
278 |
279 | segsIntersection :: (Ord s, Fractional s) => Seg s -> Seg s -> Maybe (P2 s)
280 | segsIntersection s1 s2 =
281 |   if segsIntersect s1 s2
282 |     then lineIntersection (segLine s1) (segLine s2) else Nothing -- also Nothing when the supporting lines are parallel
283 |
284 | ------------------------------------------------------------
285 | -- Rectangles
286 |
287 | data Rect s = Rect {lowerLeft :: P2 s, dims :: V2 s} deriving (Eq, Show) -- a corner point plus a (width, height) vector
288 |
289 | rectFromCorners :: (Num s, Ord s) => P2 s -> P2 s -> Rect s
290 | rectFromCorners a b =
291 |   Rect {lowerLeft = V2 (min (getX a) (getX b)) (min (getY a) (getY b)), dims = fmap abs (b ^-^ a)}
292 |
293 | pointInRect :: (Ord s, Num s) => P2 s -> Rect s -> Bool
294 | pointInRect (V2 px py) (Rect (V2 llx lly) (V2 dx dy)) =
295 |   spans llx dx px && spans lly dy py
296 |   where
297 |     -- @spans start len v@: v lies in the closed interval
298 |     -- [start, start + len]; the rectangle's boundary counts as inside.
299 |     spans start len v =
300 |       v >= start && v <= start + len
301 |
302 | rectSegs :: Num s => Rect s -> [Seg s]
303 | rectSegs (Rect ll d@(V2 dx dy)) = zipWith Seg corners (drop 1 (cycle corners))
304 |   where
305 |     -- Corners in the order lower-left, upper-left, upper-right, lower-right;
306 |     -- each edge joins a corner to the next one, wrapping around at the end.
307 |     corners = [ll, ll ^+^ V2 0 dy, ll ^+^ d, ll ^+^ V2 dx 0]
308 |
309 | rectSegIntersection :: (Fractional s, Ord s) => Rect s -> Seg s -> Maybe (Seg s)
310 | rectSegIntersection r s@(Seg t u) -- the part of s that lies in r, or Nothing when there is none
311 |   | pointInRect t r && pointInRect u r = Just s -- both endpoints inside: s is returned unchanged
312 |   | otherwise = case nub (mapMaybe (segsIntersection s) (rectSegs r)) of
313 |       [] -> Nothing
314 |       [p] | pointInRect t r -> Just (Seg p t) -- one crossing: pair it with the endpoint inside r
315 |           | pointInRect u r -> Just (Seg p u)
316 |           | otherwise -> Nothing
317 |       [p, q] -> Just (Seg p q)
318 |       ps -> Just (Seg (minimum ps) (maximum ps)) -- 3+ distinct hits (e.g. rounding near a corner): span the extreme points instead of failing the pattern match
319 |
320 | ------------------------------------------------------------
321 | -- Circles
322 |
323 | data Circle s = Circle {center :: P2 s, radius :: s} deriving (Eq, Show) -- a circle, given by centre and radius
324 |
325 | pointInCircle :: (Ord s, Num s) => P2 s -> Circle s -> Bool
326 | pointInCircle p circle = normSq (p ^-^ center circle) <= radius circle * radius circle -- compares squared distances; the boundary counts as inside
327 |
328 | rectCircleIntersection :: (Ord s, Num s) => Rect s -> Circle s -> Bool
329 | rectCircleIntersection (Rect ll@(V2 llx lly) d@(V2 dx dy)) (Circle c r) =
330 |   any (pointInRect c) grown
331 |     || any nearCorner corners
332 |   where
333 |     -- The rectangle grown by r to the left and right, and by r below and above.
334 |     grown = [Rect (V2 (llx - r) lly) (V2 (dx + 2 * r) dy), Rect (V2 llx (lly - r)) (V2 dx (dy + 2 * r))]
335 |     -- The centre may instead lie within distance r of one of the corners.
336 |     corners = [ll, ll ^+^ V2 dx 0, ll ^+^ V2 0 dy, ll ^+^ d]
337 |     nearCorner corner = pointInCircle c (Circle corner r)
338 |
--------------------------------------------------------------------------------
/IdempotentSemigroup.hs:
--------------------------------------------------------------------------------
1 | module IdempotentSemigroup where
2 |
3 | import Data.Bits
4 | import Data.Semigroup
5 |
6 | -- | An idempotent semigroup is one where the binary operation
7 | -- satisfies the law @x <> x = x@ for all @x@.
8 | class Semigroup m => IdempotentSemigroup m -- no methods: an instance only asserts that the law holds for '<>'
9 |
10 | instance Ord a => IdempotentSemigroup (Min a) -- min x x = x
11 | instance Ord a => IdempotentSemigroup (Max a) -- max x x = x
12 | instance IdempotentSemigroup All -- x && x = x
13 | instance IdempotentSemigroup Any -- x || x = x
14 | instance IdempotentSemigroup Ordering
15 | instance IdempotentSemigroup ()
16 | instance IdempotentSemigroup (First a) -- the semigroup 'First' from Data.Semigroup (the only Data.* import exposing it here)
17 | instance IdempotentSemigroup (Last a) -- likewise the semigroup 'Last'
18 | instance Bits a => IdempotentSemigroup (And a) -- bitwise and: x .&. x = x
19 | instance Bits a => IdempotentSemigroup (Ior a) -- bitwise or: x .|. x = x
20 |
--------------------------------------------------------------------------------
/Interval.hs:
--------------------------------------------------------------------------------
1 | module Interval where
2 |
3 | import Prelude hiding (drop, length, take)
4 |
5 | data I = I {lo :: !Int, hi :: !Int} -- NOTE(review): 'length', 'range', 'take' and 'drop' treat this as the half-open range [lo, hi), whereas 'isEmpty' (l > h) and 'point' (I x x) read as the closed range [lo, hi] — confirm the intended convention
6 | deriving (Eq, Ord, Show)
7 |
8 | point :: Int -> I
9 | point x = I {lo = x, hi = x} -- both endpoints equal x
10 |
11 | mkI :: Int -> Int -> I
12 | mkI l h =
13 |   if l == h then I 0 0 -- equal endpoints collapse to the canonical @I 0 0@
14 |   else I l h
15 |
16 | (∪), (∩) :: I -> I -> I
17 | a ∪ b = I (lo a `min` lo b) (hi a `max` hi b) -- smallest interval covering both (a hull, not a set union)
18 | a ∩ b = I (lo a `max` lo b) (hi a `min` hi b) -- the overlap; may come out with lo > hi
19 |
20 | intersects :: I -> I -> Bool
21 | intersects i1 i2 = not . isEmpty $ i1 ∩ i2
22 |
23 | isEmpty :: I -> Bool
24 | isEmpty i = lo i > hi i -- only a reversed pair counts as empty; equal endpoints do not
25 |
26 | (⊆) :: I -> I -> Bool
27 | a ⊆ b = (a ∪ b) == b -- a adds nothing to b's extent
28 |
29 | uncons :: I -> I
30 | uncons i = mkI (lo i + 1) (hi i) -- advance the lower endpoint by one
31 |
32 | splits :: I -> [(I, I)]
33 | splits (I l h) = map (\m -> (mkI l m, mkI m h)) [l .. h] -- one pair per cut position, both ends included
34 |
35 | subs :: I -> [I]
36 | subs (I l h) = mkI 0 0 : concatMap (\a -> map (mkI a) [a + 1 .. h]) [l .. h - 1] -- the canonical @I 0 0@ first, then every @mkI a b@ with l <= a < b <= h
37 |
38 | length :: I -> Int
39 | length i = hi i - lo i
40 |
41 | take :: Int -> I -> I
42 | take k i = mkI (lo i) cut where cut = min (lo i + k) (hi i) -- new upper end, never beyond the old one
43 |
44 | drop :: Int -> I -> I
45 | drop k i = mkI cut (hi i) where cut = min (lo i + k) (hi i) -- new lower end, never beyond the old upper end
46 |
47 | range :: I -> [Int]
48 | range i = enumFromTo (lo i) (hi i - 1) -- lo up to, but not including, hi
49 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | Copyright (c) 2021, Brent Yorgey
2 |
3 | All rights reserved.
4 |
5 | Redistribution and use in source and binary forms, with or without
6 | modification, are permitted provided that the following conditions are met:
7 |
8 | * Redistributions of source code must retain the above copyright
9 | notice, this list of conditions and the following disclaimer.
10 |
11 | * Redistributions in binary form must reproduce the above
12 | copyright notice, this list of conditions and the following
13 | disclaimer in the documentation and/or other materials provided
14 | with the distribution.
15 |
16 | * Neither the name of Brent Yorgey nor the names of other
17 | contributors may be used to endorse or promote products derived
18 | from this software without specific prior written permission.
19 |
20 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
21 | "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
22 | LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
23 | A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
24 | OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
25 | SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
26 | LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
27 | DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
28 | THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
29 | (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
30 | OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
31 |
--------------------------------------------------------------------------------
/NumberTheory.hs:
--------------------------------------------------------------------------------
1 | -- https://byorgey.wordpress.com/2020/02/07/competitive-programming-in-haskell-primes-and-factoring/
2 | -- https://byorgey.wordpress.com/2020/02/15/competitive-programming-in-haskell-modular-arithmetic-part-1/
3 | -- https://byorgey.wordpress.com/2020/03/03/competitive-programming-in-haskell-modular-arithmetic-part-2/
4 |
5 | module NumberTheory where
6 |
7 | import qualified Data.Foldable as F
8 | import Data.Map (Map)
9 | import qualified Data.Map as M
10 |
11 | import Control.Arrow
12 | import Data.Bits
13 | import Data.List (group, sort)
14 | import Data.Maybe (fromJust)
15 |
16 | ------------------------------------------------------------
17 | -- Modular exponentiation
18 |
19 | -- | @modexp b e m@ is @b^e `mod` m@, by repeated squaring. Needs @e >= 0@ and @m /= 0@.
20 | modexp :: Integer -> Integer -> Integer -> Integer
21 | modexp b e m
22 |   | e < 0 = error "modexp: negative exponent"
23 |   | otherwise = go e
24 |   where
25 |     go 0 = 1 `mod` m -- not a bare 1, so that the result is 0 when m == 1
26 |     go k = (if k `testBit` 0 then b * r * r else r * r) `mod` m
27 |       where r = go (k `shiftR` 1) -- b^(k div 2) mod m
28 |
29 | ------------------------------------------------------------
30 | -- (Extended) Euclidean algorithm
31 |
32 | -- egcd a b = (g,x,y)
33 | -- g is the gcd of a and b, and ax + by = g
34 | egcd :: Integer -> Integer -> (Integer, Integer, Integer)
35 | egcd a 0 = if a < 0 then (negate a, -1, 0) else (a, 1, 0)
36 | egcd a b = (g, t, s - q * t)
37 |   where
38 |     -- Coefficients for (a, b) are read off from those for (b, a mod b).
39 |     (q, r) = a `divMod` b
40 |     (g, s, t) = egcd b r
41 |
42 | -- g = bx + (a mod b)y
43 | -- = bx + (a - b(a/b))y
44 | -- = ay + b(x - (a/b)y)
45 |
46 | -- inverse p a is the multiplicative inverse of a mod p
47 | inverse :: Integer -> Integer -> Integer
48 | inverse p a = y `mod` p
49 | where
50 | (_, _, y) = egcd p a
51 |
52 | ------------------------------------------------------------
53 | -- Primes, factoring, and divisors
54 |
55 | --------------------------------------------------
56 | -- Miller-Rabin primality testing
57 |
58 | -- Need to upgrade to Baillie-PSW? See ~/learning/primality/baille-PSW.py
59 |
60 | smallPrimes = take 20 primes
61 |
62 | -- Witness bases that make Miller-Rabin exact for all n below each bound; see https://en.wikipedia.org/wiki/Miller%E2%80%93Rabin_primality_test#Testing_against_small_sets_of_bases
63 | mrPrimes :: Integer -> [Integer]
64 | mrPrimes n | n < 2047 = [2]
65 |            | n < 1373653 = [2, 3]
66 |            | n < 9080191 = [31, 73]
67 |            | n < 25326001 = [2, 3, 5]
68 |            | n < 3215031751 = [2, 3, 5, 7]
69 |            | n < 4759123141 = [2, 7, 61]
70 |            | n < 1122004669633 = [2, 13, 23, 1662803]
71 |            | n < 2152302898747 = [2, 3, 5, 7, 11]
72 |            | n < 3474749660383 = [2, 3, 5, 7, 11, 13]
73 |            | n < 341550071728321 = [2, 3, 5, 7, 11, 13, 17]
74 |            | n < 3825123056546413051 = [2, 3, 5, 7, 11, 13, 17, 19, 23]
75 |            | n < 318665857834031151167461 = [2, 3, 5, 7, 11, 13, 17, 19, 23, 29, 31, 37]
76 |            | n < 3317044064679887385961981 = [2, 3, 5, 7, 11, 13, 17, 19, 23, 29, 31, 37, 41]
77 |            | otherwise = error "mrPrimes: no deterministic set of bases is known for such a large n"
78 | -- With these values of a, guaranteed to work up to 3*10^24 (see https://pastebin.com/6XEFRPaZ)
79 | isPrime :: Integer -> Bool
80 | isPrime n
81 |   | n < 2 = False -- 0, 1 and negatives are not prime; negatives must not reach the Miller-Rabin test
82 |   | n `elem` smallPrimes = True
83 |   | any ((== 0) . (n `mod`)) smallPrimes = False
84 |   | otherwise = spps n (mrPrimes n)
85 | -- Primality by trial division by the primes p with p*p <= n; numbers below 2 are not prime.
86 | isPrimeTrialDiv :: Integer -> Bool
87 | isPrimeTrialDiv n =
88 |   n >= 2 && all ((/= 0) . (n `mod`)) (takeWhile (\p -> p * p <= n) primes)
89 |
90 | -- spp n a tests whether n is a strong probable prime to base a.
91 | spp :: Integer -> Integer -> Bool
92 | spp n a = spp' n s d a
93 | where
94 | (s, d) = decompose (n - 1)
95 |
96 | spp' n s d a = (ad == 1) || (n - 1) `elem` as
97 | where
98 | ad = modexp a d n
99 | as = take s (iterate ((`mod` n) . (^ 2)) ad)
100 |
101 | -- spps n as tests whether n is a strong probable prime to all the
102 | -- given bases.
103 | spps :: Integer -> [Integer] -> Bool
104 | spps n = all (spp' n s d)
105 | where
106 | (s, d) = decompose (n - 1)
107 |
108 | -- | @decompose n = (s, d)@ such that @n = 2^s * d@ and @d@ is odd, for any
109 | -- nonzero @n@ (there is no 64-bit limit); @decompose 0@ is defined as @(0, 0)@.
110 | decompose :: Integer -> (Int, Integer)
111 | decompose 0 = (0, 0)
112 | decompose n = (tz, n `shiftR` tz)
113 |   where tz = length (takeWhile (not . testBit n) [0 ..]) -- index of the lowest set bit
114 |
115 | --------------------------------------------------
116 | -- Pollard Rho factoring algorithm
117 |
118 | -- | @pollardRho a n@ looks for a non-trivial factor of @n@, starting the
119 | -- iteration from @a@; 'Nothing' means this starting value found none.
120 | pollardRho :: Integer -> Integer -> Maybe Integer
121 | pollardRho a n = chase (step a) (step (step a))
122 |   where
123 |     -- The iterated map x -> x^2 + 1 (mod n).
124 |     step v = (v * v + 1) `mod` n
125 |     -- The slow sequence advances one step per round, the fast one two.
126 |     chase slow fast = case gcd (abs (slow - fast)) n of
127 |       d | d == n -> Nothing
128 |         | d == 1 -> chase (step slow) (step (step fast))
129 |         | otherwise -> Just d
130 |
131 | -- Find a nontrivial factor of a number we know for sure is composite.
132 | compositeFactor :: Integer -> Integer
133 | compositeFactor n | even n = 2
134 | compositeFactor 25 = 5
135 | compositeFactor n = fromJust (F.asum (map (`pollardRho` n) [2 ..]))
136 |
137 | --------------------------------------------------
138 | -- Factoring
139 |
140 | factorMap :: Integer -> Map Integer Int
141 | factorMap = factor >>> M.fromList
142 |
143 | factor :: Integer -> [(Integer, Int)]
144 | factor = listFactors >>> group >>> map (head &&& length)
145 |
146 | primes :: [Integer]
147 | primes = 2 : sieve primes [3 ..]
148 | where
149 | sieve (p : ps) xs =
150 | let (h, t) = span (< p * p) xs
151 | in h ++ sieve ps (filter ((/= 0) . (`mod` p)) t)
152 |
153 | -- | Prime factors of a positive integer, with multiplicity, in nondecreasing order.
154 | listFactors :: Integer -> [Integer]
155 | listFactors n0
156 |   | n0 < 1 = error "listFactors: argument must be positive" -- 0 would otherwise never terminate
157 |   | otherwise = sort (go n0)
158 |   where
159 |     go 1 = []
160 |     go n | isPrime n = [n]
161 |          | otherwise = let d = compositeFactor n in go d ++ go (n `div` d)
162 |
163 | -- listFactors :: Integer -> [Integer]
164 | -- listFactors = go primes
165 | -- where
166 | -- go _ 1 = []
167 | -- go (p:ps) n
168 | -- | p*p > n = [n]
169 | -- | n `mod` p == 0 = p : go (p:ps) (n `div` p)
170 | -- | otherwise = go ps n
171 |
172 | divisors :: Integer -> [Integer]
173 | divisors =
174 | factor
175 | >>> map (\(p, k) -> take (k + 1) (iterate (* p) 1))
176 | >>> sequence
177 | >>> map product
178 |
179 | totient :: Integer -> Integer
180 | totient = factor >>> map (\(p, k) -> p ^ (k - 1) * (p - 1)) >>> product
181 |
182 | ------------------------------------------------------------
183 | -- Solving modular equations
184 |
185 | -- solveMod a b m solves ax = b (mod m), returning (y,k) such that all
186 | -- solutions are equivalent to y (mod k)
187 | solveMod :: Integer -> Integer -> Integer -> Maybe (Integer, Integer)
188 | solveMod a b m
189 | | g == 1 = Just ((b * inverse m a) `mod` m, m)
190 | | b `mod` g == 0 = solveMod (a `div` g) (b `div` g) (m `div` g)
191 | | otherwise = Nothing
192 | where
193 | g = gcd a m
194 |
195 | -- gcrt solves a system of modular equations, each x = a (mod n) given
196 | -- as a pair (a,n). Returns Just (z, k) with 0 <= z < k such that the
197 | -- solutions are exactly x = z + kt for integer t, or Nothing if the
198 | -- system is inconsistent. The empty system yields (0, 1): any x works.
199 | gcrt :: [(Integer, Integer)] -> Maybe (Integer, Integer)
200 | gcrt [] = Just (0, 1)
201 | gcrt (e : es) = gcrt es >>= gcrt2 e -- via gcrt2, even a lone equation comes back with 0 <= z < k
202 |
203 | -- gcrt2 (a,n) (b,m) solves the pair of modular equations
204 | --
205 | -- x = a (mod n)
206 | -- x = b (mod m)
207 | --
208 | -- It returns a pair (c, k) such that 0 <= c < k and all solutions for
209 | -- x satisfy x = c (mod k), that is, solutions are of the form x = c +
210 | -- kt for integer t.
211 | gcrt2 :: (Integer, Integer) -> (Integer, Integer) -> Maybe (Integer, Integer)
212 | gcrt2 (a, n) (b, m)
213 | | a `mod` g == b `mod` g = Just (((a * v * m + b * u * n) `div` g) `mod` k, k)
214 | | otherwise = Nothing
215 | where
216 | (g, u, v) = egcd n m
217 | k = (m * n) `div` g
218 |
--------------------------------------------------------------------------------
/Perm.hs:
--------------------------------------------------------------------------------
1 | -- https://byorgey.wordpress.com/2020/07/18/competitive-programming-in-haskell-cycle-decomposition-with-mutable-arrays/
2 | {-# LANGUAGE BangPatterns #-}
3 |
4 | module Perm where
5 |
6 | import Control.Arrow
7 | import Control.Monad.ST
8 | import Data.Array.Base
9 | import Data.Array.MArray
10 | import Data.Array.ST
11 | import Data.Array.Unboxed
12 |
13 | -- | 'Perm' represents a /1-indexed/ permutation. It can also be
14 | -- thought of as an endofunction on the set @{1 .. n}@.
15 | newtype Perm = Perm {getPerm :: UArray Int Int}
16 | deriving (Eq, Ord, Show)
17 |
18 | idPerm :: Int -> Perm
19 | idPerm n = fromList [1 .. n]
20 |
21 | -- | Construct a 'Perm' from a list containing a permutation of the
22 | -- numbers 1..n. The resulting 'Perm' sends @i@ to whatever number
23 | -- is at index @i-1@ in the list.
24 | fromList :: [Int] -> Perm
25 | fromList xs = Perm $ listArray (1, length xs) xs
26 |
27 | -- | Compose two permutations (corresponds to backwards function
28 | -- composition). Only defined if the permutations have the same
29 | -- size.
30 | andThen :: Perm -> Perm -> Perm
31 | andThen (Perm p1) (Perm p2) = Perm $ listArray (bounds p1) (map ((p1 !) >>> (p2 !)) (range (bounds p1)))
32 |
33 | instance Semigroup Perm where
34 | (<>) = andThen
35 |
36 | -- | The inverse permutation: it sends @p ! k@ back to @k@ for every index @k@.
37 | inverse :: Perm -> Perm
38 | inverse (Perm p) = Perm $ array (bounds p) [(image, k) | (k, image) <- assocs p]
39 |
40 | data CycleDecomp = CD
41 | { cycleID :: UArray Int Int -- ^ Each number maps to the ID # of the cycle it is part of
42 | , cycleLen :: UArray Int Int
43 | -- ^ Each cycle ID maps to
44 | -- the length of that cycle
45 | , cycleIndex :: UArray Int Int
46 | -- ^ Each element maps to its (0-based) index in its cycle
47 | , cycleCounts :: UArray Int Int
48 | -- ^ Each size maps to the number of cycles of that size
49 | }
50 | -- 'permToCycles' fills the fields in this order: IDs, lengths, indices, counts.
51 |
52 | deriving (Show)
53 |
54 | -- | Cycle decomposition of a permutation in O(n), using mutable arrays.
55 | permToCycles :: Perm -> CycleDecomp
56 | permToCycles (Perm p) = cd
57 | where
58 | (_, n) = bounds p
59 |
60 | cd = runST $ do
61 | cid <- newArray (1, n) 0
62 | cix <- newArray (1, n) 0
63 | ccs <- newArray (1, n) 0
64 |
65 | lens <- findCycles cid cix ccs 1 1
66 | cid' <- freeze cid
67 | cix' <- freeze cix
68 | ccs' <- freeze ccs
69 | return $ CD cid' (listArray (1, length lens) lens) cix' ccs'
70 |
71 | findCycles ::
72 | STUArray s Int Int ->
73 | STUArray s Int Int ->
74 | STUArray s Int Int ->
75 | Int ->
76 | Int ->
77 | ST s [Int]
78 | findCycles cid cix ccs l !k -- l = next available cycle ID; k = cur element
79 | | k > n = return []
80 | | otherwise = do
81 | -- check if k is already marked as part of a cycle
82 | id <- readArray cid k
83 | case id of
84 | 0 -> do
85 | -- k is unvisited. Explore its cycle and label it as l.
86 | len <- labelCycle cid cix l k 0
87 |
88 | -- Remember that we have one more cycle of this size.
89 | count <- readArray ccs len
90 | writeArray ccs len (count + 1)
91 |
92 | -- Continue with the next label and the next element, and
93 | -- remember the size of this cycle
94 | (len :) <$> findCycles cid cix ccs (l + 1) (k + 1)
95 |
96 | -- k is already visited: just go on to the next element
97 | _ -> findCycles cid cix ccs l (k + 1)
98 |
99 | -- Explore a single cycle, label all its elements and return its size.
100 | labelCycle cid cix l k !i = do
101 | -- Keep going as long as the next element is unlabelled.
102 | id <- readArray cid k
103 | case id of
104 | 0 -> do
105 | -- Label the current element with l.
106 | writeArray cid k l
107 | -- The index of the current element is i.
108 | writeArray cix k i
109 |
110 | -- Look up the next element in the permutation and continue.
111 | (1 +) <$> labelCycle cid cix l (p ! k) (i + 1)
112 | _ -> return 0
113 |
--------------------------------------------------------------------------------
/Queue.hs:
--------------------------------------------------------------------------------
1 | -- https://byorgey.github.io/blog/posts/2024/11/27/stacks-queues.html
2 | {-# LANGUAGE ConstraintKinds #-}
3 | {-# LANGUAGE ImportQualifiedPost #-}
4 |
5 | module Queue where
6 |
7 | import Data.Bifunctor (second)
8 | import Data.List (foldl')
9 | import Data.Monoid (Dual (..))
10 | import Stack (Stack)
11 | import Stack qualified
12 |
13 | data Queue m a = Queue {getFront :: !(Stack m a), getBack :: !(Stack (Dual m) a)}
14 | deriving (Show, Eq)
15 |
16 | new :: (a -> m) -> Queue m a
17 | new f = Queue (Stack.new f) (Stack.new (Dual . f))
18 |
19 | size :: Queue m a -> Int
20 | size (Queue front back) = Stack.size front + Stack.size back
21 |
22 | measure :: Monoid m => Queue m a -> m
23 | measure (Queue front back) = Stack.measure front <> getDual (Stack.measure back)
24 |
25 | enqueue :: Monoid m => a -> Queue m a -> Queue m a
26 | enqueue a (Queue front back) = Queue front (Stack.push a back)
27 |
28 | dequeue :: Monoid m => Queue m a -> Maybe (a, Queue m a)
29 | dequeue (Queue front back)
30 | | Stack.size front == 0 && Stack.size back == 0 = Nothing
31 | | Stack.size front == 0 = dequeue (Queue (Stack.reverse' getDual back) (Stack.reverse' Dual front))
32 | | otherwise = second (`Queue` back) <$> Stack.pop front
33 |
34 | -- | Discard the element at the front of the queue, if there is one; an
35 | -- empty queue is returned unchanged.
36 | drop1 :: Monoid m => Queue m a -> Queue m a
37 | drop1 q = maybe q snd (dequeue q)
38 |
39 | ------------------------------------------------------------
40 | -- Sliding windows
41 |
42 | -- | @windows w f as@ computes the monoidal sum @foldMap f window@ for each
43 | -- w-@window@ (contiguous subsequence of length @w@) of @as@, in O(length as)
44 | -- time; the result is empty if @w <= 0@ or @as@ has fewer than @w@ elements.
45 | -- E.g. @windows 3 Sum [4,1,2,8,3] = [7,11,13]@, @windows 3 Max [4,1,2,8,3] = [4,8,8]@.
46 | windows :: Monoid m => Int -> (a -> m) -> [a] -> [m]
47 | windows w f as
48 |   | w <= 0 || length start < w = [] -- there is no window of length w at all
49 |   | otherwise = go startQ rest
50 |   where
51 |     (start, rest) = splitAt w as
52 |     startQ = foldl' (flip enqueue) (new f) start
53 |     go q xs = measure q : case xs of -- slide: drop the oldest element, add the next one
54 |       [] -> []
55 |       x : xs' -> go (enqueue x (drop1 q)) xs'
56 |
57 | data Max a = NegInf | Max a deriving (Eq, Ord, Show)
58 |
59 | instance Ord a => Semigroup (Max a) where
60 | NegInf <> a = a
61 | a <> NegInf = a
62 | Max a <> Max b = Max (max a b)
63 |
64 | instance Ord a => Monoid (Max a) where
65 | mempty = NegInf
66 |
67 | data Min a = Min a | PosInf deriving (Eq, Ord, Show)
68 |
69 | instance Ord a => Semigroup (Min a) where
70 | PosInf <> a = a
71 | a <> PosInf = a
72 | Min a <> Min b = Min (min a b)
73 |
74 | instance Ord a => Monoid (Min a) where
75 | mempty = PosInf
76 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | Competitive programming utilities in Haskell
2 | --------------------------------------------
3 |
4 | Code that I use for solving competitive programming problems in
5 | Haskell. See [Competitive Programming in Haskell: Basic
6 | Setup](https://byorgey.wordpress.com/2019/04/24/competitive-programming-in-haskell-basic-setup/)
7 | for an introduction. Individual files also link to blog posts where
8 | they are discussed.
9 |
10 | Copyright 2019 by Brent Yorgey. Everything in this repository is
11 | licensed under a Creative Commons Attribution 4.0 International License.
12 |
13 | 
14 |
--------------------------------------------------------------------------------
/Scanner.hs:
--------------------------------------------------------------------------------
1 | -- https://byorgey.wordpress.com/2019/05/22/competitive-programming-in-haskell-scanner/
2 | {-# LANGUAGE LambdaCase #-}
3 |
4 | module Scanner where
5 |
6 | import Control.Monad (replicateM)
7 | import Control.Monad.State
8 |
9 | type Scanner = State [String]
10 |
11 | runScanner :: Scanner a -> String -> a
12 | runScanner = runScannerWith words
13 |
14 | runScannerWith :: (String -> [String]) -> Scanner a -> String -> a
15 | runScannerWith t s = evalState s . t
16 | -- | Consume and return the next token; fails with a descriptive error when no tokens are left.
17 | str :: Scanner String
18 | str = get >>= \case s : ss -> put ss >> return s; [] -> error "Scanner.str: unexpected end of input"
19 |
20 | int :: Scanner Int
21 | int = read <$> str
22 |
23 | integer :: Scanner Integer
24 | integer = read <$> str
25 |
26 | double :: Scanner Double
27 | double = read <$> str
28 |
29 | decimal :: Int -> Scanner Int
30 | decimal p = (round . ((10 ^ p) *)) <$> double
31 |
32 | numberOf :: Scanner a -> Scanner [a]
33 | numberOf s = int >>= flip replicateM s
34 |
35 | many :: Scanner a -> Scanner [a]
36 | many s = get >>= \case [] -> return []; _ -> (:) <$> s <*> many s
37 |
38 | times :: Int -> Scanner a -> Scanner [a]
39 | times = replicateM
40 |
41 | (><) :: Int -> Scanner a -> Scanner [a]
42 | (><) = times
43 |
44 | two, three, four :: Scanner a -> Scanner [a]
45 | [two, three, four] = map times [2 .. 4]
46 |
--------------------------------------------------------------------------------
/ScannerBS.hs:
--------------------------------------------------------------------------------
1 | -- https://byorgey.wordpress.com/2019/10/12/competitive-programming-in-haskell-reading-large-inputs-with-bytestring/
2 |
3 | {-# LANGUAGE LambdaCase #-}
4 |
5 | module ScannerBS where
6 |
7 | import Control.Applicative (liftA2)
8 | import Control.Monad (replicateM)
9 | import Control.Monad.State
10 | import qualified Data.ByteString.Lazy.Char8 as C
11 | import Data.Maybe (fromJust)
12 |
13 | type Scanner = State [C.ByteString]
14 |
15 | runScanner :: Scanner a -> C.ByteString -> a
16 | runScanner = runScannerWith C.words
17 |
18 | runScannerWith :: (C.ByteString -> [C.ByteString]) -> Scanner a -> C.ByteString -> a
19 | runScannerWith t s = evalState s . t
20 |
21 | peek :: Scanner C.ByteString
22 | peek = head <$> get
23 | -- | Consume and return the next token; fails with a descriptive error when no tokens are left.
24 | str :: Scanner C.ByteString
25 | str = get >>= \case { s:ss -> put ss >> return s; [] -> error "ScannerBS.str: unexpected end of input" }
26 |
27 | int :: Scanner Int
28 | int = (fst . fromJust . C.readInt) <$> str
29 |
30 | integer :: Scanner Integer
31 | integer = (read . C.unpack) <$> str
32 |
33 | double :: Scanner Double
34 | double = (read . C.unpack) <$> str
35 |
36 | decimal :: Int -> Scanner Int
37 | decimal p = (round . ((10^p)*)) <$> double
38 |
39 | numberOf :: Scanner a -> Scanner [a]
40 | numberOf s = int >>= flip replicateM s
41 |
42 | many :: Scanner a -> Scanner [a]
43 | many s = get >>= \case { [] -> return []; _ -> (:) <$> s <*> many s }
44 |
45 | -- | @till p s@ runs @s@ repeatedly until the next token satisfies @p@ (that token is left unconsumed) or no input is left.
46 | till :: (C.ByteString -> Bool) -> Scanner a -> Scanner [a]
47 | till p s = get >>= \case
48 |   [] -> return [] -- out of input: stop instead of crashing on an empty token list
49 |   t : _ | p t -> return []
50 |         | otherwise -> (:) <$> s <*> till p s
51 |
52 | times :: Int -> Scanner a -> Scanner [a]
53 | times = replicateM
54 | (><) :: Int -> Scanner a -> Scanner [a]
55 | (><) = times -- infix synonym for 'times'
56 |
57 | two, three, four :: Scanner a -> Scanner [a]
58 | [two, three, four] = map times [2..4]
59 |
60 | pair :: Scanner a -> Scanner b -> Scanner (a,b)
61 | pair = liftA2 (,)
62 |
--------------------------------------------------------------------------------
/SegTree.hs:
--------------------------------------------------------------------------------
1 | {-# LANGUAGE FlexibleInstances #-}
2 | {-# LANGUAGE MultiParamTypeClasses #-}
3 | {-# LANGUAGE ViewPatterns #-}
4 |
5 | module SegTree where
6 |
7 | import Data.List (find)
8 | import Data.Maybe (fromMaybe)
9 |
10 | class Action m s where
11 | act :: m -> s -> s
12 |
13 | data SegTree m a
14 | = Leaf !Int !a
15 | | Node !Int !Int !a !m (SegTree m a) (SegTree m a)
16 | deriving Show
17 |
18 | instance Action () s where
19 | act _ = id
20 |
21 | node :: (Action m a, Semigroup a) => m -> SegTree m a -> SegTree m a -> SegTree m a
22 | node m l r = Node (getLeft l) (getRight r) (act m (getValue l <> getValue r)) m l r
23 |
24 | getValue :: SegTree m a -> a
25 | getValue (Leaf _ a) = a
26 | getValue (Node _ _ a _ _ _) = a
27 |
28 | getLeft :: SegTree m a -> Int
29 | getLeft (Leaf i _) = i
30 | getLeft (Node i _ _ _ _ _) = i
31 |
32 | getRight :: SegTree m a -> Int
33 | getRight (Leaf i _) = i+1
34 | getRight (Node _ j _ _ _ _) = j
35 |
36 | mkSegTree :: (Monoid m, Monoid a, Action m a) => [a] -> SegTree m a
37 | mkSegTree as = go 1 n (as ++ replicate (n - length as) mempty)
38 | where
39 | Just n = find (>= length as) (iterate (*2) 1)
40 |
41 | go i _ [a] = Leaf i a
42 | go i j as = node mempty l r
43 | where
44 | (as1, as2) = splitAt h as
45 | h = (j-i+1) `div` 2
46 | l = go i (i+h-1) as1
47 | r = go (i+h) j as2
48 |
49 | push :: (Monoid m, Action m a) => SegTree m a -> SegTree m a
50 | push (Node i j a m l r) = Node i j a mempty (applyAct m l) (applyAct m r)
51 | push t@Leaf{} = t
52 |
53 | applyAct :: (Monoid m, Action m a) => m -> SegTree m a -> SegTree m a
54 | applyAct m (Leaf i a) = Leaf i (act m a)
55 | applyAct m (Node i j a m2 l r) = Node i j (act m a) (m <> m2) l r
56 |
57 | update :: (Monoid m, Semigroup a, Action m a) => Int -> (a -> a) -> SegTree m a -> SegTree m a
58 | update _ f (Leaf i a) = Leaf i (f a)
59 | update p f (push -> Node _ _ _ _ l r)
60 | | p < getLeft r = node mempty (update p f l) r
61 | | otherwise = node mempty l (update p f r)
62 |
63 | set :: (Monoid m, Semigroup a, Action m a) => Int -> a -> SegTree m a -> SegTree m a
64 | set p = update p . const
65 |
66 | get :: (Monoid m, Action m a) => Int -> SegTree m a -> a
67 | get p (Leaf _ a) = a
68 | get p (push -> Node _ _ _ _ l r)
69 | | p < getLeft r = get p l
70 | | otherwise = get p r
71 |
72 | -- | @range x y t@ combines the values at positions @[x, y)@ of @t@.
73 | range :: (Monoid m, Monoid a, Action m a) => Int -> Int -> SegTree m a -> a
74 | range x y _ | x >= y = mempty
75 | range x y (Leaf i a) = if x <= i && i < y then a else mempty
76 | range x y (push -> Node i j a _ l r)
77 |   | y <= i || j <= x = mempty
78 |   | x <= i && j <= y = a -- fully covered: answer from the cached value instead of descending
79 |   | otherwise = range x y l <> range x y r
80 |
81 | apply :: (Monoid m, Semigroup a, Action m a) => Int -> m -> SegTree m a -> SegTree m a
82 | apply p = update p . act
83 |
84 | -- | @applyRange x y m t@ applies the action @m@ to every position in @[x, y)@.
85 | applyRange :: (Monoid m, Semigroup a, Action m a) => Int -> Int -> m -> SegTree m a -> SegTree m a
86 | applyRange x y _ t | x >= y = t
87 | applyRange x y m t@(Leaf i _) = if x <= i && i < y then applyAct m t else t
88 | applyRange x y m t@(Node i j _ m' l r)
89 |   | y <= i || j <= x = t -- disjoint: leave the subtree untouched
90 |   | x <= i && j <= y = applyAct m t -- covered: refresh the cached value, defer the children
91 |   -- Partial overlap: push the pending action m' down, recurse, and rebuild the node.
92 |   | otherwise = node mempty (applyRange x y m (applyAct m' l)) (applyRange x y m (applyAct m' r))
93 |
94 | startingFrom :: (Monoid m, Action m a) => Int -> SegTree m a -> [SegTree m a]
95 | startingFrom l t = go t []
96 | where
97 | go t@(Leaf i _)
98 | | l <= i = (t:)
99 | | otherwise = id
100 | go (push -> t@(Node i j _ _ lt rt))
101 | | l <= i = (t:)
102 | | l < getLeft rt = go lt . (rt:)
103 | | l < getRight rt = go rt
104 | | otherwise = id
105 |
106 | -- | Preconditions:
107 | --
108 | -- - @l <= getRight t@
109 | -- - @g(mempty) == True@
110 | -- - @g@ is antitone, that is, if @g(a)@ is false then so is @g(a <> b)@.
111 | --
112 | -- Given these preconditions, @maxRight l g t@ returns the biggest
113 | -- @r@ such that @g (range l r t) == True@ but @g (range l (r+1) t)
114 | -- == False@ (or @r = getRight t@ if there is no such @r@).
115 | maxRight :: (Monoid a, Monoid m, Action m a) => Int -> (a -> Bool) -> SegTree m a -> Int
116 | maxRight l g t = fromMaybe (getRight t) (go mempty (startingFrom l t))
117 | where
118 | go _ [] = Nothing
119 | go cur (Leaf i a : ts)
120 | | g (cur <> a) = go (cur <> a) ts
121 | | otherwise = Just i
122 | go cur ((push -> Node i j a _ lt rt) : ts)
123 | | g (cur <> a) = go (cur <> a) ts
124 | | otherwise = go cur (lt : rt : ts)
125 | -- | Not implemented yet (presumably meant as the mirror image of 'maxRight'); using it raises an error.
126 | minLeft :: (Monoid a, Monoid m, Action m a) => Int -> (a -> Bool) -> SegTree m a -> Int
127 | minLeft = error "SegTree.minLeft: not implemented"
128 |
--------------------------------------------------------------------------------
/Sieve.hs:
--------------------------------------------------------------------------------
1 | -- https://codeforces.com/blog/entry/54090
2 | {-# LANGUAGE FlexibleContexts #-}
3 | {-# LANGUAGE PartialTypeSignatures #-}
4 |
5 | module Sieve where
6 |
7 | import Control.Monad (forM_, unless)
8 | import Control.Monad.ST
9 | import Data.Array.ST
10 | import Data.Array.Unboxed
11 | import Data.STRef
12 |
13 | newSTUArray :: (MArray (STUArray s) e (ST s), Ix i) => (i, i) -> e -> ST s (STUArray s i e)
14 | newSTUArray = newArray
15 |
16 | -- | Basic version. Suppose we want to compute a multiplicative function f such that
17 | -- - f(p) = g(p) for prime p
18 | -- - f(ip) = h(i, p, f(i)) when p divides i
19 | -- Then @sieve n g h@ returns an array f on (1,n) such that f!i = f(i), in O(n) time.
20 | sieve :: Int -> (Int -> Int) -> (Int -> Int -> Int -> Int) -> UArray Int Int
21 | sieve n g h = runSTUArray $ do
22 | primes <- newArray (0, n) (0 :: Int)
23 | numPrimes <- newSTRef 0
24 | composite <- newSTUArray (0, n) False
25 | f <- newArray (1, n) 1
26 | forM_ [2 .. n] $ \i -> do
27 | isComp <- readArray composite i
28 | unless isComp $ do
29 | appendPrime primes numPrimes i
30 | writeArray f i (g i)
31 |
32 | np <- readSTRef numPrimes
33 |
34 | let markComposites j
35 | | j >= np = return ()
36 | | otherwise = do
37 | p <- readArray primes j
38 | let ip = i * p
39 | case ip > n of
40 | True -> return ()
41 | False -> do
42 | writeArray composite ip True
43 | fi <- readArray f i
44 | case i `mod` p of
45 | 0 -> writeArray f ip (h i p fi)
46 | _ -> do
47 | writeArray f ip (fi * g p)
48 | markComposites (j + 1)
49 |
50 | markComposites 0
51 | return f
52 |
53 | appendPrime :: STUArray s Int Int -> STRef s Int -> Int -> ST s ()
54 | appendPrime primes numPrimes p = do
55 | n <- readSTRef numPrimes
56 | writeArray primes n p
57 | writeSTRef numPrimes $! n + 1
58 |
59 | -- | More general version. Suppose we want to compute a
60 | -- multiplicative function f such that f(p^k) = g p k. Then
61 | -- @genSieve n g@ returns an array f on (1,n) such that f!i = f(i),
62 | -- in O(n) time (but with a slightly higher constant factor than
63 | -- 'sieve').
64 | genSieve :: Int -> (Int -> Int -> Int) -> UArray Int Int
65 | genSieve n g = runSTUArray $ do
66 | primes <- newArray (0, n) (0 :: Int)
67 | numPrimes <- newSTRef 0
68 | composite <- newSTUArray (0, n) False
69 | count <- newSTUArray (2, n) (0 :: Int)
70 | f <- newArray (1, n) 1
71 | forM_ [2 .. n] $ \i -> do
72 | isComp <- readArray composite i
73 | unless isComp $ do
74 | appendPrime primes numPrimes i
75 | writeArray f i (g i 1)
76 | writeArray count i 1
77 |
78 | np <- readSTRef numPrimes
79 |
80 | let markComposites j
81 | | j >= np = return ()
82 | | otherwise = do
83 | p <- readArray primes j
84 | let ip = i * p
85 | case ip > n of
86 | True -> return ()
87 | False -> do
88 | writeArray composite ip True
89 | fi <- readArray f i
90 | case i `mod` p of
91 | 0 -> do
92 | c <- readArray count i
93 | f' <- readArray f (i `div` (p ^ c))
94 | writeArray f ip (f' * g p (c + 1))
95 | writeArray count ip (c + 1)
96 | _ -> do
97 | writeArray f ip (fi * g p 1)
98 | writeArray count ip 1
99 | markComposites (j + 1)
100 |
101 | markComposites 0
102 | return f
103 |
104 | -- Some examples!
105 |
106 | -- Euler's phi (totient) function
107 | phi :: Int -> UArray Int Int
108 | phi n = sieve n (subtract 1) (\_ p fi -> fi * p)
109 |
110 | -- Divisor sigma function, σₐ(x) = Σ{d | x} dᵃ; a = 0 counts the divisors (the sum form avoids dividing by p^a - 1 = 0)
111 | divisorSigma :: Int -> Int -> UArray Int Int
112 | divisorSigma n a = genSieve n (\p k -> sum [p ^ (a * j) | j <- [0 .. k]])
113 |
114 | -- Möbius function (μ)
115 | mu :: Int -> UArray Int Int
116 | mu n = sieve n (const (-1)) (\_ _ _ -> 0)
117 |
--------------------------------------------------------------------------------
/Slice.hs:
--------------------------------------------------------------------------------
1 | {-# LANGUAGE FlexibleContexts #-}
2 | {-# LANGUAGE StandaloneDeriving #-}
3 | {-# LANGUAGE UndecidableInstances #-}
4 |
5 | module Slice where
6 |
7 | import Control.Arrow ((***))
8 | import Data.Array.Unboxed as U
9 | import Interval (I)
10 | import qualified Interval as I
11 | import Prelude as P
12 |
13 | data Slice a = S { storage :: !(UArray Int a), interval :: !I }
14 |
15 | deriving instance (Show a, IArray UArray a) => Show (Slice a)
16 |
17 | {-# SPECIALIZE fromList :: String -> Slice Char #-}
18 | fromList :: IArray UArray a => [a] -> Slice a
19 | fromList s = S (listArray (0,n-1) s) (I.mkI 0 n)
20 | where
21 | n = P.length s
22 |
23 | {-# SPECIALIZE toList :: Slice Char -> String #-}
24 | toList :: IArray UArray a => Slice a -> [a]
25 | toList (S s i) = map (s U.!) $ I.range i
26 |
27 | length :: Slice a -> Int
28 | length (S _ i) = I.length i
29 |
30 | take :: Int -> Slice a -> Slice a
31 | take k (S s i) = S s (I.take k i)
32 |
33 | subs :: Slice a -> [Slice a]
34 | subs (S s i) = map (S s) (I.subs i)
35 |
36 | splits :: Slice a -> [(Slice a, Slice a)]
37 | splits (S s i) = map (S s *** S s) (I.splits i)
38 |
39 | {-# SPECIALIZE (Slice.!) :: Slice Char -> Int -> Char #-}
40 | (!) :: IArray UArray a => Slice a -> Int -> a
41 | (S s i) ! k = s U.! (k + I.lo i)
42 |
43 | range :: Slice a -> [Int]
44 | range (S _ i) = map (subtract (I.lo i)) (I.range i)
45 |
--------------------------------------------------------------------------------
/SparseTable.hs:
--------------------------------------------------------------------------------
1 | -- https://cp-algorithms.com/data_structures/sparse-table.html
2 |
3 | {-# LANGUAGE TupleSections #-}
4 |
5 | module SparseTable where
6 |
7 | import Data.Array (Array, array, (!))
8 | import Data.Bifunctor (first)
9 | import Data.Bits
10 | import IdempotentSemigroup
11 |
12 | newtype SparseTable m = SparseTable (Array (Int, Int) m)
13 | deriving (Show)
14 |
15 | -- | Logarithm base 2, rounded down to the nearest integer. Computed
16 | -- efficiently using primitive bitwise instructions.
17 | lg :: Int -> Int
18 | lg n = finiteBitSize n - 1 - countLeadingZeros n
19 |
20 | -- | Construct a sparse table which can answer range queries over the
21 | -- given list in $O(1)$ time. Constructing the sparse table takes
22 | -- $O(n \lg n)$ time and space, where $n$ is the length of the list.
23 | fromList :: IdempotentSemigroup m => [m] -> SparseTable m
24 | fromList ms = SparseTable st
25 | where
26 | n = length ms
27 | lgn = lg n
28 |
29 | st =
30 | array ((0, 0), (lgn, n - 1)) $
31 | (map (first (0,)) $ zip [0 ..] ms)
32 | ++ [ ((i, j), st ! (i - 1, j) <> st ! (i - 1, j + 1 !<<. (i - 1)))
33 | | i <- [1 .. lgn]
34 | , j <- [0 .. n - 1 !<<. i]
35 | ]
36 |
37 | -- | \$O(1)$. @range st l r@ computes the range query which is the
38 | -- @sconcat@ of all the elements from index @l@ to @r@ (inclusive).
39 | range :: IdempotentSemigroup m => SparseTable m -> Int -> Int -> m
40 | range (SparseTable st) l r = st ! (i, l) <> st ! (i, r - (1 !<<. i) + 1)
41 | where
42 | i = lg (r - l + 1)
43 |
--------------------------------------------------------------------------------
/SqrtTree.hs:
--------------------------------------------------------------------------------
1 | -- https://cp-algorithms.com/data_structures/sqrt-tree.html
2 | -- https://cp-algorithms.com/data_structures/sqrt_decomposition.html
3 |
4 | {-# LANGUAGE TypeApplications #-}
5 | {-# LANGUAGE GADTSyntax #-}
6 |
7 | module SqrtTree where
8 |
9 | import Data.Array (Array, array, listArray, (!), bounds)
10 | import Data.List (scanl1, scanr1)
11 |
12 | import Data.Semigroup
13 | import Control.Monad
14 | import System.Random
15 |
16 | data Block m = Block
17 | { total :: m -- ^ total of this entire block
18 | , prefix :: Array Int m -- ^ prefix sums for this block
19 | , suffix :: Array Int m -- ^ suffix sums for this block
20 | , subtree :: SqrtTree m -- ^ sqrt tree for this block
21 | }
22 | deriving (Show)
23 |
-- | A sqrt tree over a sequence of semigroup elements.
data SqrtTree m
  = -- | A range holding a single element.
    One m
  | -- | A range holding exactly two elements, in order.
    Two m m
  | -- | Any other range, split into blocks.
    Branch
      Int                    -- ^ block size
      (Array Int (Block m))  -- ^ blocks
      (Array (Int, Int) m)   -- ^ between sums for blocks
  deriving (Show)
33 |
-- | Build a sqrt tree from a list, indexing its elements from 0.
fromList :: Semigroup m => [m] -> SqrtTree m
fromList = fromArray . toArray
 where
  toArray xs = listArray (0, length xs - 1) xs
36 |
-- | Build a sqrt tree covering every index of the given array.
fromArray :: Semigroup m => Array Int m -> SqrtTree m
fromArray ms = case bounds ms of
  (lo, hi) -> mkSqrtTree ms lo (hi + 1)
41 |
-- | @mkSqrtTree ms lo hi@ makes a sqrt tree on the half-open index
--   range @ms[lo..hi)@.
--
--   Ranges of one or two elements become 'One' and 'Two'.  Any other
--   range is cut into blocks of @k = ceiling (sqrt (hi - lo))@ elements
--   (the last block may be shorter), each with its own prefix/suffix
--   sums and nested sqrt tree.
mkSqrtTree :: Semigroup m => Array Int m -> Int -> Int -> SqrtTree m
mkSqrtTree ms lo hi
  | hi - lo == 1 = One (ms ! lo)
  | hi - lo == 2 = Two (ms ! lo) (ms ! (lo + 1))
  | otherwise = Branch k blocks between
 where
  -- Block size.
  k :: Int
  k = ceiling (sqrt (fromIntegral (hi - lo) :: Double))

  -- Number of blocks actually needed.  This can be smaller than k
  -- (e.g. 5 elements give k = 3 but only 2 blocks), so the arrays
  -- below are sized by nb rather than k; sizing them by k would leave
  -- undefined cells behind.
  nb :: Int
  nb
    | hi <= lo || k <= 0 = 0
    | otherwise = (hi - lo + k - 1) `div` k

  -- blocks :: Array Int (Block m)
  blocks = listArray (0, nb - 1) . map mkBlock . take nb . iterate (+ k) $ lo

  -- Build the block whose first element is at index i.
  mkBlock i = Block (pref ! (n - 1)) pref suf (mkSqrtTree ms i j)
   where
    j = min (i + k) hi
    n = j - i
    elts = map (ms !) [i .. j - 1]
    pref = listArray (0, n - 1) (scanl1 (<>) elts)
    suf = listArray (0, n - 1) (scanr1 (<>) elts)

  -- between :: Array (Int,Int) m
  -- For s <= t, between!(s,t) combines the totals of blocks s..t, built
  -- lazily from between!(s,t-1).
  between =
    array ((0, 0), (nb - 1, nb - 1)) $
      [((s, s), total (blocks ! s)) | s <- [0 .. nb - 1]]
        ++ [ ((s, t), between ! (s, t - 1) <> total (blocks ! t))
           | s <- [0 .. nb - 2]
           , t <- [s + 1 .. nb - 1]
           ]
        -- Mirror the upper triangle into (t,s) so that every cell is
        -- defined and the whole array can be printed.
        ++ [ ((t, s), between ! (s, t - 1) <> total (blocks ! t))
           | s <- [0 .. nb - 2]
           , t <- [s + 1 .. nb - 1]
           ]
75 |
-- | @range t l r@ combines the elements at indices @l@ through @r@
--   (inclusive), where indices are relative to the start of @t@.
--
--   A 'One' ignores the indices; a 'Two' only handles the three valid
--   index pairs.  For a 'Branch', a query inside a single block is
--   passed to that block's subtree; otherwise it is assembled from a
--   suffix of the left block, the cached sum of the blocks strictly in
--   between (if any), and a prefix of the right block.
range :: Semigroup m => SqrtTree m -> Int -> Int -> m
range tree l r = case tree of
  One m -> m
  Two x y -> case (l, r) of
    (0, 0) -> x
    (0, 1) -> x <> y
    (1, 1) -> y
  Branch k blocks between ->
    let (lb, li) = l `divMod` k
        (rb, ri) = r `divMod` k
        leftPart = suffix (blocks ! lb) ! li
        rightPart = prefix (blocks ! rb) ! ri
     in case rb - lb of
          0 -> range (subtree (blocks ! lb)) li ri
          1 -> leftPart <> rightPart
          _ -> leftPart <> between ! (lb + 1, rb - 1) <> rightPart
88 |
-- | Run one range query with random endpoints @l <= r@ drawn from
--   @[0, 999999]@.
randomRange :: Semigroup m => SqrtTree m -> IO m
randomRange t =
  randomRIO (0, 999999) >>= \l ->
    range t l <$> randomRIO (l, 999999)
94 |
-- | Ad-hoc benchmark: build a sqrt tree over a million random numbers
--   in @[0, 1000]@, run a million random range-sum queries, and print
--   the sum of all the answers.
main :: IO ()
main = do
  t <- fromList . map Sum <$> replicateM count (randomRIO (0, 1000 :: Int))
  answers <- replicateM count (randomRange t)
  print . getSum . mconcat $ answers
 where
  count = 1000000
100 |
--------------------------------------------------------------------------------
/Stack.hs:
--------------------------------------------------------------------------------
1 | -- https://byorgey.github.io/blog/posts/2024/11/27/stacks-queues.html
2 | {-# LANGUAGE BangPatterns #-}
3 |
4 | module Stack where
5 |
6 | import Data.List (foldl')
7 | import Data.Monoid (Dual (..))
8 |
-- | A stack of @a@ values that also caches a measure of type @m@.
data Stack m a
  = Stack
      !(a -> m)  -- ^ measure of a single element
      !Int       -- ^ number of elements
      ![(m, a)]  -- ^ elements, top first, each paired with the
                 --   measure of the stack from that element down

-- | Shows only the list of (measure, element) entries.
instance (Show m, Show a) => Show (Stack m a) where
  show (Stack _ _ entries) = show entries

-- | Compares only the entry lists; the measuring function and the
--   cached size are ignored.
instance (Eq m, Eq a) => Eq (Stack m a) where
  (==) (Stack _ _ xs) (Stack _ _ ys) = xs == ys
16 |
-- | An empty stack that measures each element with the given function.
new :: (a -> m) -> Stack m a
new f = Stack f size0 entries0
 where
  size0 = 0
  entries0 = []
19 |
-- | Number of elements on the stack (read from the cached count).
size :: Stack m a -> Int
size stack = case stack of
  Stack _ count _ -> count
22 |
-- | Measure of the whole stack: 'mempty' when empty, otherwise the
--   measure cached alongside the top element.
measure :: Monoid m => Stack m a -> m
measure (Stack _ _ []) = mempty
measure (Stack _ _ ((m, _) : _)) = m
27 |
-- | Push an element.  Its cached measure is the element's own measure
--   combined (on the left) with the measure of the stack below it.
push :: Monoid m => a -> Stack m a -> Stack m a
push a s@(Stack f n as) = Stack f (n + 1) (entry : as)
 where
  entry = (f a <> measure s, a)
30 |
-- | Remove the top element, returning it with the remaining stack, or
--   'Nothing' if the stack is empty.
pop :: Stack m a -> Maybe (a, Stack m a)
pop (Stack _ _ []) = Nothing
pop (Stack f n ((_, a) : rest)) = Just (a, Stack f (n - 1) rest)
35 |
-- | Reverse a stack while post-processing each element's measure with
--   @g@.  The elements are pushed, top first, onto a fresh stack, so
--   every cached measure is recomputed.
reverse' :: Monoid n => (m -> n) -> Stack m a -> Stack n a
reverse' g (Stack f _ as) = foldl' pushOnto (new (g . f)) as
 where
  pushOnto acc (_, a) = push a acc
38 |
-- | Reverse a stack, keeping the same measuring function.
reverse :: Monoid m => Stack m a -> Stack m a
reverse s = reverse' id s
41 |
--------------------------------------------------------------------------------
/Tree.hs:
--------------------------------------------------------------------------------
1 | -- https://byorgey.github.io/blog/posts/2024/07/11/cpih-factor-full-tree.html
2 | -- https://byorgey.github.io/blog/posts/2024/08/08/TreeDecomposition.html
3 | {-# LANGUAGE TupleSections #-}
4 |
5 | module Tree where
6 |
7 | import Control.Arrow ((***))
8 | import Control.Category ((>>>))
9 | import Data.Bifunctor (second)
10 | import Data.List (maximumBy, sortBy)
11 | import Data.List.NonEmpty (NonEmpty)
12 | import qualified Data.List.NonEmpty as NE
13 | import Data.Map (Map, (!?))
14 | import qualified Data.Map as M
15 | import Data.Ord (Down (..), comparing)
16 | import Data.Tree
17 | import Data.Tuple (swap)
18 |
-- | Adjacency map of an undirected graph: every edge is entered in
--   both directions.
edgesToMap :: Ord a => [(a, a)] -> Map a [a]
edgesToMap es =
  dirEdgesToMap [d | (u, v) <- es, d <- [(u, v), (v, u)]]
21 |
-- | Adjacency map of a directed graph: each source vertex maps to the
--   list of its targets.  Targets of later edges end up earlier in the
--   list.
dirEdgesToMap :: Ord a => [(a, a)] -> Map a [a]
dirEdgesToMap es = M.fromListWith (++) [(u, [v]) | (u, v) <- es]
24 |
-- | Fold an undirected adjacency map into a tree-shaped result rooted
--   at @root@.  @nd@ is given each vertex with the results for its
--   children; a vertex's children are all its neighbours except the
--   vertex it was reached from.  Vertices missing from the map are
--   treated as having no neighbours.
mapToTree :: Ord a => (a -> [b] -> b) -> Map a [a] -> a -> b
mapToTree nd m root = build root root
 where
  build from cur =
    nd cur [build cur nxt | nxt <- M.findWithDefault [] cur m, nxt /= from]
29 |
-- | Like 'mapToTree', but starting from a list of undirected edges.
edgesToTree :: Ord a => (a -> [b] -> b) -> [(a, a)] -> a -> b
edgesToTree nd es = mapToTree nd (edgesToMap es)
32 |
-- | Map every non-root node label to the label of its parent.
parentMap :: Ord a => Tree a -> Map a a
parentMap = snd . foldTree step
 where
  -- Each subtree yields its root label together with the parent map
  -- collected below it.
  step a kids =
    (a, M.fromList [(c, a) | (c, _) <- kids] <> mconcat (map snd kids))
39 |
-- | Given a node's label and its child subtrees, optionally pick one
--   child to continue the current path through, returning it together
--   with the remaining children.
type SubtreeSelector a = a -> [Tree a] -> Maybe (Tree a, [Tree a])

-- | Split a tree into paths.  The first path starts at the root and
--   follows the selector; every subtree left hanging off that path is
--   then decomposed the same way.
pathDecomposition :: SubtreeSelector a -> Tree a -> [NonEmpty a]
pathDecomposition select t = path : concatMap (pathDecomposition select) rest
 where
  (path, rest) = selectPath select t
46 |
-- | Follow the selector down from the root, returning the labels along
--   the chosen path and all the subtrees that were not chosen (those
--   nearer the root first).
selectPath :: SubtreeSelector a -> Tree a -> (NonEmpty a, [Tree a])
selectPath select (Node a ts) =
  maybe (a NE.:| [], ts) extend (select a ts)
 where
  extend (child, others) =
    let (path, leftovers) = selectPath select child
     in (a NE.<| path, others ++ leftovers)
53 |
-- | Height of a subtree; a leaf has height 0.
type Height = Int

-- | Number of nodes in a subtree.
type Size = Int

-- | Annotate every node with the height of the subtree rooted there.
labelHeight :: Tree a -> Tree (Height, a)
labelHeight (Node a ts) = Node (h, a) kids
 where
  kids = map labelHeight ts
  -- Seeding with 0 covers the leaf case; child heights are never
  -- negative, so it does not affect internal nodes.
  h = maximum (0 : map ((+ 1) . fst . rootLabel) kids)
63 |
-- | Annotate every node with the number of nodes in its subtree
--   (counting the node itself).
labelSize :: Tree a -> Tree (Size, a)
labelSize (Node a ts) = Node (1 + sum (map (fst . rootLabel) kids), a) kids
 where
  kids = map labelSize ts
66 |
-- | Decompose a tree into chains by length: first the longest possible
--   chain, then the longest chain from what remains, and so on.  Each
--   label is paired with the height of its node, and the chains are
--   sorted by decreasing height of their first node.
maxChainDecomposition :: Tree a -> [NonEmpty (Height, a)]
maxChainDecomposition t =
  sortBy (comparing (Down . fst . NE.head)) chains
 where
  chains = pathDecomposition tallestChild (labelHeight t)
  tallestChild _ kids = selectMaxBy (comparing (fst . rootLabel)) kids
75 |
-- | Pull a maximum element (by the given comparison) out of a list,
--   returning it together with the remaining elements, or 'Nothing'
--   for an empty list.  On ties the earliest maximum wins.  The
--   remaining elements are not guaranteed to keep their original order.
selectMaxBy :: (a -> a -> Ordering) -> [a] -> Maybe (a, [a])
selectMaxBy cmp = foldr step Nothing
 where
  step a Nothing = Just (a, [])
  step a (Just (best, others))
    | cmp a best == LT = Just (best, a : others)
    | otherwise = Just (a, best : others)
83 |
-- | Heavy-light decomposition of a tree: each path continues into the
--   child with the largest subtree.  Each label is paired with the
--   size of the subtree rooted at its node.
heavyLightDecomposition :: Tree a -> [NonEmpty (Size, a)]
heavyLightDecomposition t = pathDecomposition heaviestChild (labelSize t)
 where
  heaviestChild _ kids = selectMaxBy (comparing (fst . rootLabel)) kids
88 |
--------------------------------------------------------------------------------
/Trie.hs:
--------------------------------------------------------------------------------
1 | module Trie where
2 |
3 | import Control.Monad ((>=>))
4 | import qualified Data.ByteString.Lazy.Char8 as C
5 | import Data.List (foldl')
6 | import Data.Map (Map, (!))
7 | import qualified Data.Map as M
8 | import Data.Maybe (fromMaybe, isJust)
9 |
-- | A trie keyed by strings of 'Char', with optional values at nodes.
data Trie a = Trie
  { -- | Number of keys stored in this subtree (nodes holding a value).
    trieSize :: !Int
  , -- | Value stored for the key that ends at this node, if any.
    value :: !(Maybe a)
  , -- | Child tries, indexed by the next character of the key.
    children :: !(Map Char (Trie a))
  }
  deriving (Show)
16 |
-- | The trie containing no keys.
emptyTrie :: Trie a
emptyTrie =
  Trie
    { trieSize = 0
    , value = Nothing
    , children = M.empty
    }
19 |
-- | Insert a key/value pair into a trie, overwriting any value already
--   stored under that key and keeping the sizes along the path up to
--   date.
insert :: C.ByteString -> a -> Trie a -> Trie a
insert w a t = fst (go w t)
 where
  -- Returns the updated node together with the change in key count
  -- (1 for a brand-new key, 0 for an overwrite).
  go key (Trie n v m) = case C.uncons key of
    Nothing ->
      let ds = maybe 1 (const 0) v
       in (Trie (n + ds) (Just a) m, ds)
    Just (c, rest) ->
      let (sub, ds) = go rest (fromMaybe emptyTrie (M.lookup c m))
       in (Trie (n + ds) v (M.insert c sub m), ds)
34 |
-- | Build a trie from a list of key/value pairs, inserted left to
--   right.  If the same key occurs more than once, the last pair wins.
mkTrie :: [(C.ByteString, a)] -> Trie a
mkTrie = foldl' (\t (k, v) -> insert k v t) emptyTrie
40 |
-- | Follow a single character from a trie node, returning the child
--   trie reached (if any).
lookup1 :: Char -> Trie a -> Maybe (Trie a)
lookup1 c t = M.lookup c (children t)
45 |
-- | Look up a string key in a trie, returning the value stored under
--   it (if any).
lookup :: C.ByteString -> Trie a -> Maybe a
lookup key root = C.foldr descend value key root
 where
  -- Step down by one character, then continue with the rest of the key.
  descend c rest node = lookup1 c node >>= rest
50 |
-- | Fold a trie bottom-up into a summary value.  The function receives
--   each node's size, its optional value, and the summaries of its
--   children.
foldTrie :: (Int -> Maybe a -> Map Char r -> r) -> Trie a -> r
foldTrie f = go
 where
  go (Trie n v kids) = f n v (fmap go kids)
54 |
-- | "Decode" a string by repeatedly looking up consecutive
--   characters.  Every time we reach a node which holds a value, emit
--   that value and restart at the root.  This is of particular use in
--   decoding a prefix-free code.
--
--   Note that this function calls 'error' (naming the offending
--   character) if it ever looks up a character which has no child in
--   the current trie node.
decode :: Trie a -> C.ByteString -> [a]
decode t = reverse . snd . C.foldl' step (t, [])
 where
  -- The state is the current trie node plus the values emitted so
  -- far, most recent first.
  step (s, as) c = case lookup1 c s of
    Nothing ->
      error ("Trie.decode: no transition for character " ++ show c)
    Just s' -> maybe (s', as) (\a -> (t, a : as)) (value s')
67 |
--------------------------------------------------------------------------------
/UnionFind.hs:
--------------------------------------------------------------------------------
1 | -- Adapted from https://kseo.github.io/posts/2014-01-30-implementing-union-find-in-haskell.html
2 | -- https://byorgey.github.io/blog/posts/2024/11/02/UnionFind.html
3 | -- https://byorgey.github.io/blog/posts/2024/11/18/UnionFind-sols.html
4 | {-# LANGUAGE RecordWildCards #-}
5 |
6 | module UnionFind where
7 |
8 | import Control.Monad (when)
9 | import Control.Monad.ST
10 | import Data.Array.ST
11 |
-- | Elements of the union-find structure are identified by 'Int's.
type Node = Int

-- | A mutable union-find structure whose sets carry an annotation of
--   type @m@.
data UnionFind s m = UnionFind
  { -- | Parent pointer of each node; a root is its own parent.
    parent :: !(STUArray s Node Node)
  , -- | Set size, meaningful at root nodes.
    sz :: !(STArray s Node Int)
  , -- | Set annotation, meaningful at root nodes.
    ann :: !(STArray s Node m)
  }
19 |
-- | Create @n@ singleton sets (nodes @0 .. n-1@), all carrying the
--   same initial annotation.
new :: Int -> m -> ST s (UnionFind s m)
new n m = newWith n (\_ -> m)
22 |
-- | Create @n@ singleton sets (nodes @0 .. n-1@), where the initial
--   annotation of each node is computed from its index.
newWith :: Int -> (Node -> m) -> ST s (UnionFind s m)
newWith n m = do
  let top = n - 1
  parents <- newListArray (0, top) [0 .. top]
  sizes <- newArray (0, top) 1
  anns <- newListArray (0, top) [m v | v <- [0 .. top]]
  pure (UnionFind parents sizes anns)
29 |
-- | Are the two nodes in the same set?
connected :: UnionFind s m -> Node -> Node -> ST s Bool
connected uf x y = do
  rx <- find uf x
  ry <- find uf y
  pure (rx == ry)
32 |
-- | Find the root (representative) of the set containing a node.
--   Every node visited on the way is re-pointed directly at the root.
find :: UnionFind s m -> Node -> ST s Node
find uf@(UnionFind {..}) x = do
  p <- readArray parent x
  if p == x
    then pure x
    else do
      root <- find uf p
      writeArray parent x root
      pure root
42 |
-- | Apply a function to the annotation of the set containing a node.
updateAnn :: Semigroup m => UnionFind s m -> Node -> (m -> m) -> ST s ()
updateAnn uf@(UnionFind {..}) x f = do
  root <- find uf x
  -- Explicit read-then-write: modifyArray is not available in the
  -- Kattis test environment.
  readArray ann root >>= writeArray ann root . f
48 |
-- | Merge the sets containing the two nodes (a no-op if they are
--   already the same set).  The root of the smaller set is attached to
--   the root of the larger one; the merged set gets the summed size
--   and the annotation of @x@'s set combined with that of @y@'s set.
union :: Semigroup m => UnionFind s m -> Node -> Node -> ST s ()
union uf@(UnionFind {..}) x y = do
  rx <- find uf x
  ry <- find uf y
  when (rx /= ry) $ do
    sx <- readArray sz rx
    sy <- readArray sz ry
    mx <- readArray ann rx
    my <- readArray ann ry
    -- The smaller set's root becomes the child.
    let (child, root) = if sx < sy then (rx, ry) else (ry, rx)
    writeArray parent child root
    writeArray sz root (sx + sy)
    writeArray ann root (mx <> my)
67 |
-- | Size of the set containing a node.
size :: UnionFind s m -> Node -> ST s Int
size uf@(UnionFind {..}) x = find uf x >>= readArray sz
72 |
-- | Annotation of the set containing a node.
getAnn :: UnionFind s m -> Node -> ST s m
getAnn uf@(UnionFind {..}) x = find uf x >>= readArray ann
77 |
-- | List the size and annotation of every set, one entry per root
--   node, in order of increasing root index.
allAnns :: UnionFind s m -> ST s [(Int, m)]
allAnns UnionFind {..} = do
  links <- getAssocs parent
  -- Roots are exactly the nodes that are their own parent.
  mapM describe [v | (v, p) <- links, v == p]
 where
  describe root = do
    a <- readArray ann root
    s <- readArray sz root
    pure (s, a)
88 |
-- | Alternative union that interleaves the two root searches, linking
--   by node index instead of by set size.  See
--   https://algocoding.wordpress.com/2015/05/13/simple-union-find-techniques/
--
--   Returns 'True' if the two nodes were found to be in the same set
--   already (their parent entries coincide), and 'False' if a root had
--   to be linked to join the two sets.
--
--   NOTE: this only rewrites parent pointers; the @sz@ and @ann@ arrays
--   are never touched here, so 'size' and 'getAnn' will not reflect
--   merges performed with this function.
--   NOTE(review): it appears to rely on parent indices never being
--   smaller than the node's own index, which 'union' does not maintain
--   -- confirm before mixing the two on one structure.
unite :: Semigroup m => UnionFind s m -> Node -> Node -> ST s Bool
unite uf@(UnionFind {..}) x y = do
  px <- readArray parent x
  py <- readArray parent y
  case compare px py of
    EQ -> pure True
    LT -> do
      -- Point x at the larger parent; if x was a root the sets are
      -- now joined, otherwise carry on from x's old parent.
      writeArray parent x py
      if x == px
        then pure False
        else unite uf px y
    GT -> do
      -- Symmetric case for y.
      writeArray parent y px
      if y == py
        then pure False
        else unite uf x py
107 |
--------------------------------------------------------------------------------
/Util.hs:
--------------------------------------------------------------------------------
1 | {-# LANGUAGE TupleSections #-}
2 |
3 | module Util where
4 |
5 | import Control.Arrow
6 | import Data.Array.IArray
7 | import Data.Bits (finiteBitSize, countLeadingZeros)
8 | import Data.Function
9 | import Data.List
10 | import qualified Data.List.NonEmpty as NE
11 | import Data.Maybe
12 | import Data.Ord
13 |
-- | Short alias for 'fromIntegral'.
fi :: (Integral a, Num b) => a -> b
fi x = fromIntegral x
16 |
-- | Short alias for 'fromJust'; partial, it fails on 'Nothing'.
fj :: Maybe a -> a
fj m = fromJust m
19 |
-- | Apply the same function to both components of a pair.
both :: (a -> b) -> (a, a) -> (b, b)
both f pair = case pair of
  (l, r) -> (f l, f r)
22 |
-- | Sort a list by a key function and group together the elements
--   that share a key.  Each group is returned with its key; groups
--   appear in increasing key order, and elements within a group keep
--   their original relative order.
sortGroupOn :: Ord b => (a -> b) -> [a] -> [(b, NE.NonEmpty a)]
sortGroupOn f xs =
  [(f (NE.head grp), grp) | grp <- NE.groupBy ((==) `on` f) (sortOn f xs)]
25 |
-- | All pairs of elements taken from two different positions of the
--   list, with the earlier element first.
pairs :: [a] -> [(a, a)]
pairs xs = [(a, b) | a : rest <- tails xs, b <- rest]
29 |
-- | Apply a function to every pair of elements taken from two
--   different positions of the list (earlier element first) and
--   combine the results monoidally, in the same order as 'pairs'.
withPairs :: Monoid r => (a -> a -> r) -> [a] -> r
withPairs f xs =
  foldr (<>) mempty [f a b | a : rest <- tails xs, b <- rest]
37 |
-- | Build an immutable array over the given bounds whose element at
--   each index is computed from that index.
generate :: (Ix i, IArray a e) => (i, i) -> (i -> e) -> a i e
generate rng f = listArray rng [f i | i <- range rng]
40 |
-- | Build an immutable array from a default value plus a list of
--   explicit associations; indices not mentioned get the default.
--
--   The explicit associations are appended after a full list of
--   defaults, so this depends on 'array' keeping the last association
--   given for an index.
arraydef :: (Ix i, IArray a e) => (i, i) -> e -> [(i, e)] -> a i e
arraydef rng def vs = array rng (defaults ++ vs)
 where
  defaults = zip (range rng) (repeat def)
43 |
-- | Logarithm base 2, rounded down to the nearest integer.  Computed
--   from the position of the highest set bit, using primitive bitwise
--   instructions.
lg :: Int -> Int
lg n = finiteBitSize n - (countLeadingZeros n + 1)
48 |
--------------------------------------------------------------------------------
/comprog-hs.cabal:
--------------------------------------------------------------------------------
1 | cabal-version: 2.4
2 | name: comprog-hs
3 | version: 0.1.0.0
4 | synopsis:
5 |
6 | -- A longer description of the package.
7 | -- description:
8 | homepage:
9 |
10 | -- A URL where users can report bugs.
11 | -- bug-reports:
12 | license: BSD-3-Clause
13 | license-file: LICENSE
14 | author: Brent Yorgey
15 | maintainer: byorgey@gmail.com
16 |
17 | -- A copyright notice.
18 | -- copyright:
19 | category: Development
20 | extra-source-files:
21 | CHANGELOG.md
22 | README.md
23 |
library
    exposed-modules:  BFS,
                      BinarySearch,
                      Enumeration,
                      Geom,
                      IdempotentSemigroup,
                      Interval,
                      NumberTheory,
                      Perm,
                      Queue,
                      ScannerBS,
                      Scanner,
                      SegTree,
                      Sieve,
                      Slice,
                      SparseTable,
                      SqrtTree,
                      Stack,
                      Tree,
                      Util

    -- random is needed by SqrtTree, which imports System.Random.
    build-depends:    base,
                      containers,
                      hashable,
                      unordered-containers,
                      array,
                      mtl,
                      bytestring,
                      random

    -- Directories containing source files.
    -- hs-source-dirs:
    default-language: Haskell2010
56 |
--------------------------------------------------------------------------------