├── .gitignore ├── .travis.yml ├── README.md ├── segmenttree └── __init__.py ├── setup.py └── tests └── test_segtree.py /.gitignore: -------------------------------------------------------------------------------- 1 | *.py[cod] 2 | 3 | # C extensions 4 | *.so 5 | 6 | # Packages 7 | *.egg 8 | *.egg-info 9 | dist 10 | build 11 | eggs 12 | parts 13 | bin 14 | var 15 | sdist 16 | develop-eggs 17 | .installed.cfg 18 | lib 19 | lib64 20 | 21 | # Installer logs 22 | pip-log.txt 23 | 24 | # Unit test / coverage reports 25 | .coverage 26 | .tox 27 | nosetests.xml 28 | 29 | # Translations 30 | *.mo 31 | 32 | # Mr Developer 33 | .mr.developer.cfg 34 | .project 35 | .pydevproject 36 | -------------------------------------------------------------------------------- /.travis.yml: -------------------------------------------------------------------------------- 1 | language: python 2 | python: 3 | - "2.7" 4 | - "3.4" 5 | # command to run tests 6 | script: nosetests 7 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | segmenttree 2 | =========== 3 | 4 | This is a Python implementation of [Segment tree](http://en.wikipedia.org/wiki/Segment_tree). 5 | 6 | See [How to use it](https://leons.im/posts/a-python-implementation-of-segment-tree/). 
class SegmentTree(object):
    """Segment tree over the closed integer interval ``[start, end]``.

    Every node is identified by its ``(start, end)`` interval and keeps three
    aggregates in parallel dicts keyed by that tuple:

    * ``max_value`` -- largest accumulated weight of any point in the node,
    * ``sum_value`` -- total accumulated weight of all points in the node,
    * ``len_value`` -- number of points whose accumulated weight is > 0.

    Ranges given to the public methods are clamped to the tree's bounds.
    Updates descend to every leaf of the updated range (there is no lazy
    propagation), so ``add`` does work proportional to the range width.
    """

    def __init__(self, start, end):
        """Build an all-zero tree covering ``start..end`` inclusive."""
        self.start = start
        self.end = end
        self.max_value = {}
        self.sum_value = {}
        self.len_value = {}
        self._init(start, end)

    def out_of_bounds_check(self, start, end):
        """Clamp ``[start, end]`` to the tree's bounds.

        Returns the clamped ``(start, end)`` pair, or ``(None, None)`` when
        the clamped range is empty (no overlap with the tree, or
        ``start > end``).
        """
        start = max(start, self.start)
        end = min(end, self.end)
        if start > end:
            return None, None
        return start, end

    def add(self, start, end, weight=1):
        """Add ``weight`` to every point in ``[start, end]``.

        Returns ``True`` if at least part of the range lies inside the tree
        and was updated, ``False`` if the range is entirely out of bounds.
        """
        start, end = self.out_of_bounds_check(start, end)
        if start is None:
            return False
        self._add(start, end, weight, self.start, self.end)
        return True

    def query_max(self, start, end):
        """Return the largest point weight in ``[start, end]``.

        Returns ``None`` when the range does not overlap the tree.
        """
        start, end = self.out_of_bounds_check(start, end)
        if start is None:
            return None
        return self._query_max(start, end, self.start, self.end)

    def query_sum(self, start, end):
        """Return the sum of point weights in ``[start, end]``.

        Returns ``0`` when the range does not overlap the tree.
        """
        start, end = self.out_of_bounds_check(start, end)
        if start is None:
            return 0
        return self._query_sum(start, end, self.start, self.end)

    def query_len(self, start, end):
        """Return how many points in ``[start, end]`` have weight > 0.

        Returns ``0`` when the range does not overlap the tree.
        """
        start, end = self.out_of_bounds_check(start, end)
        if start is None:
            return 0
        return self._query_len(start, end, self.start, self.end)

    @staticmethod
    def _mid(lo, hi):
        """Return the split point of node ``[lo, hi]``.

        The left child is ``[lo, mid]`` and the right child ``[mid+1, hi]``.
        Floor division keeps the arithmetic in integers on both Python 2
        and Python 3.
        """
        return lo + (hi - lo) // 2

    def _init(self, start, end):
        """Create zeroed entries for node ``[start, end]`` and its subtree."""
        key = (start, end)
        self.max_value[key] = 0
        self.sum_value[key] = 0
        self.len_value[key] = 0
        if start < end:
            mid = self._mid(start, end)
            self._init(start, mid)
            self._init(mid + 1, end)

    def _add(self, start, end, weight, in_start, in_end):
        """Add ``weight`` to ``[start, end]`` inside node ``[in_start, in_end]``.

        ``[start, end]`` must already lie within the node. Leaves are
        updated directly; inner nodes are recomputed from their children
        once the recursion returns.
        """
        key = (in_start, in_end)
        if in_start == in_end:
            self.max_value[key] += weight
            self.sum_value[key] += weight
            # A point counts towards the covered length only while its
            # accumulated weight is strictly positive.
            self.len_value[key] = 1 if self.sum_value[key] > 0 else 0
            return

        mid = self._mid(in_start, in_end)
        if end <= mid:
            # Range lies entirely in the left child.
            self._add(start, end, weight, in_start, mid)
        elif start >= mid + 1:
            # Range lies entirely in the right child.
            self._add(start, end, weight, mid + 1, in_end)
        else:
            # Range straddles the split: cut it at mid and update both sides.
            self._add(start, mid, weight, in_start, mid)
            self._add(mid + 1, end, weight, mid + 1, in_end)

        left = (in_start, mid)
        right = (mid + 1, in_end)
        self.max_value[key] = max(self.max_value[left], self.max_value[right])
        self.sum_value[key] = self.sum_value[left] + self.sum_value[right]
        self.len_value[key] = self.len_value[left] + self.len_value[right]

    def _query(self, table, combine, start, end, in_start, in_end):
        """Aggregate ``table`` over ``[start, end]`` within a node.

        ``table`` is one of the per-node dicts and ``combine`` merges the
        results of the left and right halves when the queried range
        straddles the node's split point. ``[start, end]`` must lie within
        ``[in_start, in_end]``.
        """
        if start == in_start and end == in_end:
            return table[(start, end)]

        mid = self._mid(in_start, in_end)
        if end <= mid:
            return self._query(table, combine, start, end, in_start, mid)
        if start >= mid + 1:
            return self._query(table, combine, start, end, mid + 1, in_end)
        return combine(
            self._query(table, combine, start, mid, in_start, mid),
            self._query(table, combine, mid + 1, end, mid + 1, in_end),
        )

    def _query_max(self, start, end, in_start, in_end):
        """Return the maximum weight in ``[start, end]`` within the node."""
        return self._query(self.max_value, max, start, end, in_start, in_end)

    def _query_sum(self, start, end, in_start, in_end):
        """Return the total weight in ``[start, end]`` within the node."""
        return self._query(
            self.sum_value, lambda left, right: left + right,
            start, end, in_start, in_end)

    def _query_len(self, start, end, in_start, in_end):
        """Return the count of positive-weight points in ``[start, end]``."""
        return self._query(
            self.len_value, lambda left, right: left + right,
            start, end, in_start, in_end)
class TestSegmentTree(TestCase):
    """Behavioural tests for SegmentTree range adds and max/len/sum queries."""

    def test_segtree(self):
        """Queries over [2, 5] follow a sequence of overlapping range adds."""
        segtree = SegmentTree(1, 8)
        segtree.add(1, 3, 1)
        self.assertEqual(1, segtree.query_max(2, 5))
        self.assertEqual(2, segtree.query_len(2, 5))
        self.assertEqual(2, segtree.query_sum(2, 5))

        segtree.add(3, 4, 1)
        self.assertEqual(2, segtree.query_max(2, 5))
        self.assertEqual(3, segtree.query_len(2, 5))
        self.assertEqual(4, segtree.query_sum(2, 5))

        segtree.add(4, 5, 1)
        self.assertEqual(2, segtree.query_max(2, 5))
        self.assertEqual(4, segtree.query_len(2, 5))
        self.assertEqual(6, segtree.query_sum(2, 5))

        segtree.add(3, 6, 2)
        self.assertEqual(4, segtree.query_max(2, 5))
        self.assertEqual(4, segtree.query_len(2, 5))
        self.assertEqual(12, segtree.query_sum(2, 5))

        # Default weight is 1.
        segtree.add(1, 7)
        self.assertEqual(5, segtree.query_max(2, 5))
        self.assertEqual(4, segtree.query_len(2, 5))
        self.assertEqual(16, segtree.query_sum(2, 5))

    def test_demo(self):
        """Same scenario as test_segtree, checked only after all adds."""
        segtree = SegmentTree(1, 8)
        segtree.add(1, 3, 1)
        segtree.add(3, 4, 1)
        segtree.add(4, 5, 1)
        segtree.add(3, 6, 2)
        # Was `add(1, 71)`, which only passed because the end bound is
        # clamped to the tree; the intended call mirrors test_segtree.
        segtree.add(1, 7, 1)
        self.assertEqual(5, segtree.query_max(2, 5))
        self.assertEqual(4, segtree.query_len(2, 5))
        self.assertEqual(16, segtree.query_sum(2, 5))

    def test_empty(self):
        """Single-point adds on an otherwise empty tree."""
        segtree = SegmentTree(0, 8)
        segtree.add(1, 1)
        segtree.add(8, 8)
        self.assertEqual(1, segtree.query_max(0, 8))
        self.assertEqual(2, segtree.query_len(0, 8))
        self.assertEqual(2, segtree.query_sum(0, 8))

    def test_full_out_of_bound(self):
        """Ranges entirely outside the tree are rejected or yield neutral values."""
        segtree = SegmentTree(0, 8)
        segtree.add(0, 8)

        # Fully out-of-bounds add is refused
        self.assertIs(False, segtree.add(10, 16))
        self.assertIs(False, segtree.add(-16, -10))

        # Fully out-of-bounds len query returns 0
        self.assertEqual(0, segtree.query_len(10, 16))
        self.assertEqual(0, segtree.query_len(-16, -10))

        # Fully out-of-bounds sum query returns 0
        self.assertEqual(0, segtree.query_sum(10, 16))
        self.assertEqual(0, segtree.query_sum(-16, -10))

        # Fully out-of-bounds max query returns None
        self.assertIsNone(segtree.query_max(10, 16))
        self.assertIsNone(segtree.query_max(-16, -10))

    def test_partial_out_of_bound(self):
        """Ranges overlapping the tree's edge are clamped to its bounds."""
        segtree = SegmentTree(0, 8)
        segtree.add(0, 8)

        # Partially out-of-bounds add is applied to the overlapping part
        self.assertIs(True, segtree.add(8, 16))
        self.assertIs(True, segtree.add(-16, 0))

        # Partially out-of-bounds len query
        self.assertEqual(2, segtree.query_len(7, 16))
        self.assertEqual(2, segtree.query_len(-16, 1))

        # Partially out-of-bounds sum query
        self.assertEqual(3, segtree.query_sum(7, 16))
        self.assertEqual(3, segtree.query_sum(-16, 1))

        # Partially out-of-bounds max query
        self.assertEqual(2, segtree.query_max(7, 16))
        self.assertEqual(2, segtree.query_max(-16, 1))