├── .gitignore
├── LICENSE
├── README.md
├── config.py
├── cookie
│   ├── bing_search_cookie.json
│   └── google_search_cookie.json
├── example_cmd.sh
├── init.py
├── main.py
├── modules
│   ├── __init__.py
│   ├── algorithm.py
│   ├── crawler.py
│   └── preprocess.py
└── requirements.txt
/.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | pip-wheel-metadata/ 24 | share/python-wheels/ 25 | *.egg-info/ 26 | .installed.cfg 27 | *.egg 28 | MANIFEST 29 | 30 | # PyInstaller 31 | # Usually these files are written by a python script from a template 32 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 33 | *.manifest 34 | *.spec 35 | 36 | # Installer logs 37 | pip-log.txt 38 | pip-delete-this-directory.txt 39 | 40 | # Unit test / coverage reports 41 | htmlcov/ 42 | .tox/ 43 | .nox/ 44 | .coverage 45 | .coverage.* 46 | .cache 47 | nosetests.xml 48 | coverage.xml 49 | *.cover 50 | *.py,cover 51 | .hypothesis/ 52 | .pytest_cache/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | target/ 76 | 77 | # Jupyter Notebook 78 | .ipynb_checkpoints 79 | 80 | # IPython 81 | profile_default/ 82 | ipython_config.py 83 | 84 | # pyenv 85 | .python-version 86 | 87 | # pipenv 88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 91 | # install all needed dependencies. 92 | #Pipfile.lock 93 | 94 | # PEP 582; used by e.g.
github.com/David-OConnor/pyflow 95 | __pypackages__/ 96 | 97 | # Celery stuff 98 | celerybeat-schedule 99 | celerybeat.pid 100 | 101 | # SageMath parsed files 102 | *.sage.py 103 | 104 | # Environments 105 | .env 106 | .venv 107 | env/ 108 | venv/ 109 | ENV/ 110 | env.bak/ 111 | venv.bak/ 112 | 113 | # Spyder project settings 114 | .spyderproject 115 | .spyproject 116 | 117 | # Rope project settings 118 | .ropeproject 119 | 120 | # mkdocs documentation 121 | /site 122 | 123 | # mypy 124 | .mypy_cache/ 125 | .dmypy.json 126 | dmypy.json 127 | tmp/ 128 | data/ 129 | *zip 130 | *.db 131 | 132 | # Pyre type checker 133 | .pyre/ 134 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2019 luogan1234 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # concept-expansion-snippet 2 | 3 | **Requirements:** Python >= 3.5, PyTorch >= 1.5.0, jieba, requests, bs4, nltk, pytorch-pretrained-bert 4 | 5 | This toolkit provides graph propagation, average distance, tf-idf, and PageRank algorithms on top of pretrained word vectors (BERT) for concept extraction and concept expansion. It supports both Chinese and English text. 6 | 7 | - **Concept extraction:** Given **seed concepts** and **text**, extract the related concepts from the input text. 8 | - **Concept expansion:** Given **seed concepts**, extract the related concepts from the snippets of search engine results (Baidu for Chinese, Google or Bing for English). 9 | 10 | ## Before running the code 11 | 1. Run `python init.py`. 12 | 13 | ## Parameters 14 | 15 | You can run `bash example_cmd.sh -t en0|en1|zh0|zh1` to try the examples, and follow `example_cmd.sh` as a template for running this toolkit.
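The cases in `example_cmd.sh` reduce to direct `main.py` invocations; a minimal sketch of two of them is shown below, assuming the sample `data/EN-DSA-*.txt` and `data/ZH-DSA-*.txt` files from the downloaded `data.zip` are in place:

```bash
# English concept extraction: rank candidates from the given text against the seed concepts.
python main.py -task extract -l en -it data/EN-DSA-text.txt -is data/EN-DSA-seed.txt

# Chinese concept expansion: crawl Baidu snippets for each seed, then rank the extracted candidates.
python main.py -task expand -l zh -is data/ZH-DSA-seed.txt -ss baidu
```

By default, results are written to `tmp/result.txt`, one JSON object (`name`, `score`) per line, sorted by score.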
16 | 17 | Some important parameters available in command line: 18 | 19 | ``` 20 | -task: required, extract | expand 21 | -input_text, -it: the text file for the concept extraction task 22 | -input_seed, -is: the seed file for the concept extraction & expansion tasks 23 | -language, -l: required, zh | en 24 | -snippet_source, -ss: baidu | google | bing 25 | -no_seed, -ns: set this flag to treat every candidate in the text as a seed 26 | -algorithm, -a: graph_propagation | average_distance | tf_idf | pagerank 27 | -result_path, -r 28 | -cpu 29 | ``` 30 | 31 | The following paths can be modified in `config.py`: 32 | 33 | ``` 34 | zh_list, en_list: all possible candidate concepts 35 | db: stores the crawled snippets 36 | cookie_paths: cookie files used by the crawler 37 | cached_vecs_path: caches the vectors of candidate concepts 38 | text_path, candidate_path: temporary files 39 | ``` 40 | 41 | ## Note 42 | 43 | To crawl Google search snippets, you need a VPN (for users in Mainland China). 44 | 45 | The crawler may be blocked by anti-crawler mechanisms, so do not crawl Google search results too fast. 46 | -------------------------------------------------------------------------------- /config.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | class Config: 4 | def __init__(self, task, input_text, input_seed, language, snippet_source, no_seed, algorithm, result_path, cpu): 5 | self.task = task 6 | self.input_text = input_text 7 | self.input_seed = input_seed 8 | self.language = language 9 | self.snippet_source = snippet_source 10 | self.no_seed = no_seed 11 | self.algorithm = algorithm 12 | self.result_path = result_path 13 | self.cpu = cpu 14 | 15 | self.times = 10 16 | self.max_num = -1 17 | self.threshold = 0.7 18 | self.decay = 0.8 19 | self.batch_size = 128 20 | 21 | self.zh_list = 'data/zh_list' 22 | self.en_list = 'data/en_list' 23 | self.db = 'snippet.db' 24 | self.cookie_paths = ['cookie/{}'.format(file) for file in os.listdir('cookie/')] 25 | self.proxy = {'http': 'http://localhost:8001', 'https': 'http://localhost:8001'} # change this to your own proxy 26 | self.text_path = 'tmp/text.txt' 27 | self.candidate_path = 'tmp/candidate.txt' 28 | self.cached_vecs_path = 'data/cached_vecs.pkl' -------------------------------------------------------------------------------- /cookie/bing_search_cookie.json: -------------------------------------------------------------------------------- 1 | [ 2 | { 3 | "domain": ".bing.com", 4 | "hostOnly": false, 5 | "httpOnly": true, 6 | "name": "_EDGE_S", 7 | "path": "/", 8 | "sameSite": "unspecified", 9 | "secure": false, 10 | "session": true, 11 | "storeId": "0", 12 | "value": "mkt=zh-cn&SID=17742F425B7F62C00CBD22BF5A5163D6", 13 | "id": 1 14 | }, 15 | { 16 | "domain": ".bing.com", 17 | "hostOnly": false, 18 | "httpOnly": false, 19 | "name": "_FP", 20 | "path": "/", 21 | "sameSite": "unspecified", 22 | "secure": false, 23 | "session": true, 24 | "storeId": "0", 25 | "value": "hta=on", 26 | "id": 2 27 | }, 28 | { 29 | "domain": ".bing.com", 30 | "hostOnly": false, 31 | "httpOnly": false, 32 | "name": "_SS", 33 | "path": "/", 34 | "sameSite": "unspecified", 35 | "secure": false, 36 | "session": true, 37 | "storeId": "0", 38 | "value": "SID=17742F425B7F62C00CBD22BF5A5163D6&HV=1571722979&bIm=588588", 39 | "id": 3 40 | }, 41 | { 42 | "domain": ".bing.com", 43 | "hostOnly": false, 44 | "httpOnly": false, 45 | "name": "DUP", 46 | "path": "/search", 47 | "sameSite": "unspecified", 48 | "secure": false, 49 | "session": 
true, 50 | "storeId": "0", 51 | "value": "Q=mZFLkyvTelC5g8XnyQrpOw2&T=372577225&A=2&IG=987498E931A647DEA3456C31FFD0552E", 52 | "id": 4 53 | }, 54 | { 55 | "domain": ".bing.com", 56 | "expirationDate": 1634794848, 57 | "hostOnly": false, 58 | "httpOnly": false, 59 | "name": "ENSEARCH", 60 | "path": "/", 61 | "sameSite": "unspecified", 62 | "secure": false, 63 | "session": false, 64 | "storeId": "0", 65 | "value": "BENVER=1", 66 | "id": 5 67 | }, 68 | { 69 | "domain": ".bing.com", 70 | "hostOnly": false, 71 | "httpOnly": false, 72 | "name": "ipv6", 73 | "path": "/", 74 | "sameSite": "unspecified", 75 | "secure": false, 76 | "session": true, 77 | "storeId": "0", 78 | "value": "hit=1571725982060&t=4", 79 | "id": 6 80 | }, 81 | { 82 | "domain": ".bing.com", 83 | "expirationDate": 1603174905.78698, 84 | "hostOnly": false, 85 | "httpOnly": false, 86 | "name": "MUID", 87 | "path": "/", 88 | "sameSite": "unspecified", 89 | "secure": false, 90 | "session": false, 91 | "storeId": "0", 92 | "value": "3D1C84B08ABD6D61282E887B8EBD6EA6", 93 | "id": 7 94 | }, 95 | { 96 | "domain": ".bing.com", 97 | "expirationDate": 1622477217, 98 | "hostOnly": false, 99 | "httpOnly": false, 100 | "name": "SerpPWA", 101 | "path": "/", 102 | "sameSite": "unspecified", 103 | "secure": false, 104 | "session": false, 105 | "storeId": "0", 106 | "value": "reg=1", 107 | "id": 8 108 | }, 109 | { 110 | "domain": ".bing.com", 111 | "expirationDate": 1612343964.947377, 112 | "hostOnly": false, 113 | "httpOnly": false, 114 | "name": "SRCHD", 115 | "path": "/", 116 | "sameSite": "unspecified", 117 | "secure": false, 118 | "session": false, 119 | "storeId": "0", 120 | "value": "AF=NOFORM", 121 | "id": 9 122 | }, 123 | { 124 | "domain": ".bing.com", 125 | "expirationDate": 1634794849, 126 | "hostOnly": false, 127 | "httpOnly": false, 128 | "name": "SRCHHPGUSR", 129 | "path": "/", 130 | "sameSite": "unspecified", 131 | "secure": false, 132 | "session": false, 133 | "storeId": "0", 134 | "value": "CW=1536&CH=239&DPR=1.25&UTC=480&WTS=63707319158", 135 | "id": 10 136 | }, 137 | { 138 | "domain": ".bing.com", 139 | "expirationDate": 1612343964.947452, 140 | "hostOnly": false, 141 | "httpOnly": false, 142 | "name": "SRCHUID", 143 | "path": "/", 144 | "sameSite": "unspecified", 145 | "secure": false, 146 | "session": false, 147 | "storeId": "0", 148 | "value": "V=2&GUID=194F51A9A90D4AD5921E6470B97CFA16&dmnchg=1", 149 | "id": 11 150 | }, 151 | { 152 | "domain": ".bing.com", 153 | "expirationDate": 1634794385, 154 | "hostOnly": false, 155 | "httpOnly": false, 156 | "name": "SRCHUSR", 157 | "path": "/", 158 | "sameSite": "unspecified", 159 | "secure": false, 160 | "session": false, 161 | "storeId": "0", 162 | "value": "DOB=20190203&T=1571722362000", 163 | "id": 12 164 | }, 165 | { 166 | "domain": ".bing.com", 167 | "expirationDate": 1634794383, 168 | "hostOnly": false, 169 | "httpOnly": false, 170 | "name": "ULC", 171 | "path": "/", 172 | "sameSite": "unspecified", 173 | "secure": false, 174 | "session": false, 175 | "storeId": "0", 176 | "value": "P=18F4B|9:6&H=18F4B|9:6&T=18F4B|9:6", 177 | "id": 13 178 | }, 179 | { 180 | "domain": "cn.bing.com", 181 | "expirationDate": 1605152065.43334, 182 | "hostOnly": true, 183 | "httpOnly": true, 184 | "name": "MUIDB", 185 | "path": "/", 186 | "sameSite": "unspecified", 187 | "secure": false, 188 | "session": false, 189 | "storeId": "0", 190 | "value": "3D1C84B08ABD6D61282E887B8EBD6EA6", 191 | "id": 14 192 | } 193 | ] -------------------------------------------------------------------------------- 
/cookie/google_search_cookie.json: -------------------------------------------------------------------------------- 1 | [ 2 | { 3 | "domain": ".google.com", 4 | "expirationDate": 1573298781.834242, 5 | "hostOnly": false, 6 | "httpOnly": false, 7 | "name": "1P_JAR", 8 | "path": "/", 9 | "sameSite": "no_restriction", 10 | "secure": false, 11 | "session": false, 12 | "storeId": "0", 13 | "value": "2019-10-10-11", 14 | "id": 1 15 | }, 16 | { 17 | "domain": ".google.com", 18 | "expirationDate": 1584755886.598209, 19 | "hostOnly": false, 20 | "httpOnly": true, 21 | "name": "ANID", 22 | "path": "/", 23 | "sameSite": "unspecified", 24 | "secure": false, 25 | "session": false, 26 | "storeId": "0", 27 | "value": "AHWqTUmoPghfOd9IcQKbFW4pCCFySNoK-ZyX_SpVk5fClSy930hAZEEui7sJtkZK", 28 | "id": 2 29 | }, 30 | { 31 | "domain": ".google.com", 32 | "expirationDate": 1624943265.834577, 33 | "hostOnly": false, 34 | "httpOnly": false, 35 | "name": "APISID", 36 | "path": "/", 37 | "sameSite": "unspecified", 38 | "secure": false, 39 | "session": false, 40 | "storeId": "0", 41 | "value": "rWCcCsCVxih2hun8/Al44JYMbRfMq8qsEY", 42 | "id": 3 43 | }, 44 | { 45 | "domain": ".google.com", 46 | "expirationDate": 1572430968.415373, 47 | "hostOnly": false, 48 | "httpOnly": true, 49 | "name": "CGIC", 50 | "path": "/search", 51 | "sameSite": "unspecified", 52 | "secure": false, 53 | "session": false, 54 | "storeId": "0", 55 | "value": "InZ0ZXh0L2h0bWwsYXBwbGljYXRpb24veGh0bWwreG1sLGFwcGxpY2F0aW9uL3htbDtxPTAuOSxpbWFnZS93ZWJwLGltYWdlL2FwbmcsKi8qO3E9MC44LGFwcGxpY2F0aW9uL3NpZ25lZC1leGNoYW5nZTt2PWIz", 56 | "id": 4 57 | }, 58 | { 59 | "domain": ".google.com", 60 | "expirationDate": 2146723200.313104, 61 | "hostOnly": false, 62 | "httpOnly": false, 63 | "name": "CONSENT", 64 | "path": "/", 65 | "sameSite": "unspecified", 66 | "secure": false, 67 | "session": false, 68 | "storeId": "0", 69 | "value": "YES+GB.zh-CN+V14", 70 | "id": 5 71 | }, 72 | { 73 | "domain": ".google.com", 74 | "expirationDate": 1624943265.834515, 75 | "hostOnly": false, 76 | "httpOnly": true, 77 | "name": "HSID", 78 | "path": "/", 79 | "sameSite": "unspecified", 80 | "secure": false, 81 | "session": false, 82 | "storeId": "0", 83 | "value": "Aqrade56VBN92-WiA", 84 | "id": 6 85 | }, 86 | { 87 | "domain": ".google.com", 88 | "expirationDate": 1586517323.832578, 89 | "hostOnly": false, 90 | "httpOnly": true, 91 | "name": "NID", 92 | "path": "/", 93 | "sameSite": "unspecified", 94 | "secure": false, 95 | "session": false, 96 | "storeId": "0", 97 | "value": "189=mge_3FAbhZyOWyaRaQTJsPeSICI3Eg1nu8-0tYTG-0mob-bSU2r0SCVBYybge2MRPnwQL2pJdZQeaowCKwxME4IYXzJjj140QZNHj_ik5JvMkFW25mQPfyg1H5nIg-18GS6NSdHqHXjiiNXX474tn4Xf5yvRdH2oK4ow3i6l9dg0b78DcgEJTlFi-tSRO8x3ySuzAX9D80vdwUsYMjZAPzyd1JczEl47BOpaHwTJUmHNI2t4FLRoaJNBVY67KH0k9W9HGZa2ok32fyECbNDc4SiAJ45X35Bz_-2rNABDWq6iUe8oEkYxD4UBaFJdG69kBk3lTlWq", 98 | "id": 7 99 | }, 100 | { 101 | "domain": ".google.com", 102 | "expirationDate": 1624943265.834604, 103 | "hostOnly": false, 104 | "httpOnly": false, 105 | "name": "SAPISID", 106 | "path": "/", 107 | "sameSite": "unspecified", 108 | "secure": true, 109 | "session": false, 110 | "storeId": "0", 111 | "value": "CNw9qOjhKyKWmNdO/AYdIJnDN0CLsll2OK", 112 | "id": 8 113 | }, 114 | { 115 | "domain": ".google.com", 116 | "expirationDate": 1586258114.645478, 117 | "hostOnly": false, 118 | "httpOnly": false, 119 | "name": "SEARCH_SAMESITE", 120 | "path": "/", 121 | "sameSite": "strict", 122 | "secure": false, 123 | "session": false, 124 | "storeId": "0", 125 | "value": "CgQIg44B", 126 | "id": 
9 127 | }, 128 | { 129 | "domain": ".google.com", 130 | "expirationDate": 1627825640.181697, 131 | "hostOnly": false, 132 | "httpOnly": false, 133 | "name": "SID", 134 | "path": "/", 135 | "sameSite": "unspecified", 136 | "secure": false, 137 | "session": false, 138 | "storeId": "0", 139 | "value": "mwdpOGoZ5P3AlVtNx4dseHwP7i_Npjdms6ibMFIIjUPBWwRi5AIfj1sq2Woo-UHyXlFTlw.", 140 | "id": 10 141 | }, 142 | { 143 | "domain": ".google.com", 144 | "expirationDate": 1578482781.834323, 145 | "hostOnly": false, 146 | "httpOnly": false, 147 | "name": "SIDCC", 148 | "path": "/", 149 | "sameSite": "unspecified", 150 | "secure": false, 151 | "session": false, 152 | "storeId": "0", 153 | "value": "AN0-TYs1nAX5DZjivRe9L5UQgVhjvvpAHBZ0T1cvmL2kyuAJ1gmFQ3J0zM8icaTHBEcjMNVhRPc", 154 | "id": 11 155 | }, 156 | { 157 | "domain": ".google.com", 158 | "expirationDate": 1624943265.834538, 159 | "hostOnly": false, 160 | "httpOnly": true, 161 | "name": "SSID", 162 | "path": "/", 163 | "sameSite": "unspecified", 164 | "secure": true, 165 | "session": false, 166 | "storeId": "0", 167 | "value": "AGIk5kkhqTmhIw_sg", 168 | "id": 12 169 | }, 170 | { 171 | "domain": "www.google.com", 172 | "expirationDate": 1570776482, 173 | "hostOnly": true, 174 | "httpOnly": false, 175 | "name": "OTZ", 176 | "path": "/", 177 | "sameSite": "unspecified", 178 | "secure": true, 179 | "session": false, 180 | "storeId": "0", 181 | "value": "5098008_24_24__24_", 182 | "id": 13 183 | }, 184 | { 185 | "domain": "www.google.com", 186 | "expirationDate": 1570792576, 187 | "hostOnly": true, 188 | "httpOnly": false, 189 | "name": "UULE", 190 | "path": "/", 191 | "sameSite": "unspecified", 192 | "secure": false, 193 | "session": false, 194 | "storeId": "0", 195 | "value": "a+cm9sZToxIHByb2R1Y2VyOjEyIHByb3ZlbmFuY2U6NiB0aW1lc3RhbXA6MTU3MDcwNjE3NjA3MTAwMCBsYXRsbmd7bGF0aXR1ZGVfZTc6MzU2MDkxOTczIGxvbmdpdHVkZV9lNzoxMzk3MzAzMzY0fSByYWRpdXM6MTY4NTc4MA==", 196 | "id": 14 197 | } 198 | ] -------------------------------------------------------------------------------- /example_cmd.sh: -------------------------------------------------------------------------------- 1 | zh="data/ZH-" 2 | en="data/EN-" 3 | while getopts "t:" arg 4 | do 5 | case $arg in 6 | t) 7 | case $OPTARG in 8 | zh0) 9 | cmd="python main.py -task extract -l zh -it ${zh}DSA-text.txt -is ${zh}DSA-seed.txt" 10 | echo $cmd 11 | ;; 12 | zh1) 13 | cmd="python main.py -task expand -l zh -is ${zh}DSA-seed.txt -ss baidu" 14 | echo $cmd 15 | ;; 16 | en0) 17 | cmd="python main.py -task extract -l en -it ${en}DSA-text.txt -is ${en}DSA-seed.txt" 18 | echo $cmd 19 | ;; 20 | en1) 21 | cmd="python main.py -task expand -l en -is ${en}DSA-seed.txt -ss bing" 22 | echo $cmd 23 | ;; 24 | *) 25 | echo "unknown value $OPTARG of arg t" 26 | ;; 27 | esac 28 | ;; 29 | *) 30 | ;; 31 | esac 32 | done 33 | $cmd 34 | -------------------------------------------------------------------------------- /init.py: -------------------------------------------------------------------------------- 1 | import nltk 2 | import os 3 | 4 | if not os.path.exists('tmp/'): 5 | os.mkdir('tmp/') 6 | if not os.path.exists('data/'): 7 | os.system('wget http://lfs.aminer.cn/misc/moocdata/toolkit/data.zip') 8 | os.system('unzip data.zip') 9 | nltk.download('punkt') 10 | nltk.download('averaged_perceptron_tagger') -------------------------------------------------------------------------------- /main.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | from config import Config 4 | 
import modules.preprocess as preprocess 5 | from modules.algorithm import Algorithm 6 | 7 | def main(): 8 | parser = argparse.ArgumentParser(description='Concept expansion with snippet') 9 | parser.add_argument('-task', type=str, choices=['extract', 'expand'], required=True) 10 | parser.add_argument('-input_text', '-it', type=str, help='the text file for concept extraction task') 11 | parser.add_argument('-input_seed', '-is', type=str, help='the seed file for concept extraction | expansion task') 12 | parser.add_argument('-language', '-l', type=str, choices=['zh', 'en'], required=True) 13 | parser.add_argument('-snippet_source', '-ss', default='baidu', type=str, choices=['baidu', 'google', 'bing']) 14 | parser.add_argument('-no_seed', '-ns', action='store_true', default=False, help='every candidate in text will be a seed') 15 | parser.add_argument('-algorithm', '-a', type=str, default='graph_propagation', choices=['graph_propagation', 'average_distance', 'tf_idf', 'pagerank']) 16 | parser.add_argument('-result_path', '-r', type=str, default='tmp/result.txt') 17 | parser.add_argument('-cpu', action='store_true', default=False) 18 | args = parser.parse_args() 19 | if not args.input_text and args.task == 'extract': 20 | raise Exception('concept extraction task need input_text') 21 | if not args.input_seed and args.task == 'expand': 22 | raise Exception('concept expansion task need input_seed') 23 | if not args.no_seed and not args.input_seed: 24 | raise Exception('seed config error') 25 | config = Config(args.task, args.input_text, args.input_seed, args.language, args.snippet_source, args.no_seed, args.algorithm, args.result_path, args.cpu) 26 | preprocess.get_candidates(config) 27 | algorithm = Algorithm(config) 28 | algorithm.get_result() 29 | 30 | if __name__ == '__main__': 31 | main() -------------------------------------------------------------------------------- /modules/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/luogan1234/concept-expansion-snippet/b1ed537eca6d41cb1905620f0b9d20940bdda690/modules/__init__.py -------------------------------------------------------------------------------- /modules/algorithm.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import os 3 | import torch 4 | import pickle 5 | import tqdm 6 | import math 7 | import json 8 | from pytorch_pretrained_bert import BertModel, BertTokenizer 9 | 10 | def calc_pow(x, y): 11 | if x > 0: 12 | return math.pow(x, y) 13 | else: 14 | return -math.pow(-x, y) 15 | 16 | class Algorithm: 17 | def __init__(self, config): 18 | self.config = config 19 | self._init() 20 | 21 | def _init(self): 22 | with open(self.config.candidate_path, 'r', encoding='utf-8') as f: 23 | self.candidates = f.read().split('\n') 24 | with open(self.config.text_path, 'r', encoding='utf-8') as f: 25 | self.text = f.read().split('\n') 26 | if os.path.exists(self.config.cached_vecs_path): 27 | with open(self.config.cached_vecs_path, 'rb') as f: 28 | self.cached_vecs = pickle.load(f) 29 | else: 30 | self.cached_vecs = {} 31 | print('Load data done, candidate number: {}, text line number: {}, cached vocab vectors: {}'.format(len(self.candidates), len(self.text), len(self.cached_vecs))) 32 | if self.config.language == 'zh': 33 | self.bert = BertModel.from_pretrained('bert-base-chinese') 34 | self.tokenizer = BertTokenizer.from_pretrained('bert-base-chinese') 35 | if self.config.language == 'en': 36 | self.bert = 
BertModel.from_pretrained('bert-base-uncased') 37 | self.tokenizer = BertTokenizer.from_pretrained('bert-base-uncased') 38 | self.bert.eval() 39 | for p in self.bert.parameters(): 40 | p.requires_grad = False 41 | if not self.config.cpu: 42 | self.bert.cuda() 43 | print('Load bert done.') 44 | 45 | def _get_vector(self, batch): 46 | input = [] 47 | for concept in batch: 48 | c = self.tokenizer.tokenize(concept) 49 | c = [101]+self.tokenizer.convert_tokens_to_ids(c)+[102] 50 | input.append(c) 51 | max_len = max([len(c) for c in input]) 52 | input = [torch.tensor(c+[0]*(max_len-len(c)), dtype=torch.long) for c in input] 53 | input = torch.stack(input) 54 | if not self.config.cpu: 55 | input = input.cuda() 56 | with torch.no_grad(): 57 | h, _ = self.bert(input, attention_mask=(input>0), output_all_encoded_layers=False) 58 | h = h.detach() 59 | for i, concept in enumerate(batch): 60 | r = torch.sum(input[i]>0) 61 | if r > 2: 62 | vec = torch.mean(h[i][1:r-1], 0).cpu().numpy() 63 | self.cached_vecs[concept] = vec / np.sqrt(np.sum(vec*vec)) 64 | else: 65 | self.cached_vecs[concept] = None 66 | 67 | def get_vector(self): 68 | print('Get concept vectors.') 69 | batch = [] 70 | for concept in self.candidates: 71 | if concept in self.cached_vecs: 72 | continue 73 | batch.append(concept) 74 | if len(batch) >= self.config.batch_size: 75 | self._get_vector(batch) 76 | batch = [] 77 | if batch: 78 | self._get_vector(batch) 79 | self.concepts = [] 80 | self.vecs = [] 81 | for c in self.candidates: 82 | if c in self.cached_vecs and self.cached_vecs[c] is not None: 83 | self.concepts.append(c) 84 | self.vecs.append(self.cached_vecs[c]) 85 | self.vecs = np.stack(self.vecs) 86 | print('Load vec done, concepts number: {}, vecs shape: {}'.format(len(self.concepts), self.vecs.shape)) 87 | with open(self.config.cached_vecs_path, 'wb') as f: 88 | pickle.dump(self.cached_vecs, f) 89 | 90 | def cal_vector_distance(self): 91 | print('Start calculate vector distance.') 92 | max_num = self.config.max_num 93 | n = self.vecs.shape[0] 94 | m = n if max_num == -1 else min(n, max_num) 95 | weights = np.dot(self.vecs, self.vecs.T) 96 | sorted_indexes = np.argsort(-weights)[:, :m] 97 | self.edges = [] 98 | for i in tqdm.tqdm(range(n)): 99 | weight, sorted_index = weights[i], sorted_indexes[i] 100 | edge = [] 101 | for j in range(m): 102 | w = weight[sorted_index[j]] 103 | target = sorted_index[j] 104 | if w > self.config.threshold: 105 | edge.append([w, target]) 106 | else: 107 | break 108 | self.edges.append(edge) 109 | 110 | def init_score_list(self): 111 | n = len(self.concepts) 112 | if self.config.no_seed: 113 | score_list = np.ones(n) 114 | else: 115 | with open(self.config.input_seed, 'r', encoding='utf-8') as f: 116 | seed_set = set([seed.strip() for seed in f.read().split('\n')]) 117 | score_list = np.zeros(n) 118 | for i, c in enumerate(self.concepts): 119 | if c in seed_set: 120 | score_list[i] = 1 121 | print('Seed number in candidate concepts:', np.sum(score_list)) 122 | return score_list 123 | 124 | def graph_propagation(self): 125 | print('Graph propagation:') 126 | self.cal_vector_distance() 127 | score_list = self.init_score_list() 128 | final_score_list = score_list 129 | for i in tqdm.tqdm(range(self.config.times)): 130 | new_score_list = np.zeros(score_list.shape) 131 | for source, score in enumerate(score_list): 132 | if score != 0.0: 133 | for (w, target) in self.edges[source]: 134 | s = score * w 135 | if self.config.language == 'zh': 136 | s *= math.log(len(self.concepts[target])+1) 137 | 
new_score_list[target] += s 138 | new_score_list /= np.max(new_score_list) 139 | score_list = new_score_list 140 | final_score_list += score_list * calc_pow(self.config.decay, i+1) 141 | return final_score_list 142 | 143 | def average_distance(self): 144 | print('average distance:') 145 | seed_set = set() 146 | n = len(self.concepts) 147 | if self.config.no_seed: 148 | seed_vecs = [vec for vec in self.vecs] 149 | else: 150 | with open(self.config.input_seed, 'r', encoding='utf-8') as f: 151 | seed_set = set([seed.strip() for seed in f.read().split('\n')]) 152 | seed_vecs = [] 153 | for i, c in enumerate(self.concepts): 154 | if c in seed_set: 155 | seed_vecs.append(self.vecs[i]) 156 | seed_vecs = np.stack(seed_vecs) 157 | print('Seed number in candidate concepts:', seed_vecs.shape[0]) 158 | score_list = np.mean(np.dot(self.vecs, seed_vecs.T), axis=1) 159 | return score_list 160 | 161 | def tf_idf(self): 162 | print('tf idf:') 163 | n = len(self.concepts) 164 | score_list = np.zeros(n) 165 | for i in tqdm.tqdm(range(n)): 166 | c = self.concepts[i] 167 | tf = max([len(t.split(c))-1 for t in self.text]) 168 | idf = sum([c in t for t in self.text]) 169 | score_list[i] = tf / math.log(1+idf) 170 | if self.config.language == 'zh': 171 | score_list[i] *= math.log(len(c)+1) 172 | return score_list 173 | 174 | def pagerank(self): 175 | score_list = self.init_score_list() 176 | n = len(self.concepts) 177 | mat = np.zeros((n, n)) 178 | for t in tqdm.tqdm(self.text): 179 | g = [i for i in range(n) if self.concepts[i] in t] 180 | for p1 in g: 181 | for p2 in g: 182 | mat[p1, p2] += 1.0 183 | for i in range(n): 184 | mat[i] /= np.sum(mat[i]) 185 | for i in tqdm.tqdm(range(self.config.times)): 186 | score_list = np.matmul(score_list, mat) 187 | return score_list 188 | 189 | def get_result(self): 190 | self.get_vector() 191 | if self.config.algorithm == 'graph_propagation': 192 | score_list = self.graph_propagation() 193 | if self.config.algorithm == 'average_distance': 194 | score_list = self.average_distance() 195 | if self.config.algorithm == 'tf_idf': 196 | score_list = self.tf_idf() 197 | if self.config.algorithm == 'pagerank': 198 | score_list = self.pagerank() 199 | sorted_list = np.argsort(-score_list) 200 | with open(self.config.result_path, 'w', encoding='utf-8') as f: 201 | for index in sorted_list: 202 | obj = {'name': self.concepts[index], 'score': float(score_list[index])} 203 | f.write(json.dumps(obj, ensure_ascii=False)+'\n') 204 | print('Get result finished.') -------------------------------------------------------------------------------- /modules/crawler.py: -------------------------------------------------------------------------------- 1 | import sqlite3 2 | import requests 3 | from bs4 import BeautifulSoup 4 | import re 5 | import time 6 | import random 7 | import urllib 8 | import json 9 | import os 10 | 11 | USER_AGENT = 'Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/71.0.3578.98 Safari/537.36' 12 | 13 | def sleep(t): 14 | time.sleep(t+random.random()*t*0.5) 15 | 16 | def clean(text): 17 | text = re.sub(r'\n|\r', '', text).strip() 18 | return text 19 | 20 | class Crawler: 21 | def __init__(self, config): 22 | self.config = config 23 | self.sess = requests.Session() 24 | if self.config.snippet_source == 'google': 25 | self.sess.proxies.update(self.config.proxy) 26 | self.tot_crawl = 0 27 | self._init() 28 | 29 | def _init(self): 30 | conn = sqlite3.connect(self.config.db) 31 | cursor = conn.cursor() 32 | cursor.execute('CREATE TABLE IF NOT 
EXISTS baidu (concept TEXT PRIMARY KEY NOT NULL, snippet TEXT NOT NULL)') 33 | cursor.execute('CREATE TABLE IF NOT EXISTS google (concept TEXT PRIMARY KEY NOT NULL, snippet TEXT NOT NULL)') 34 | cursor.execute('CREATE TABLE IF NOT EXISTS bing (concept TEXT PRIMARY KEY NOT NULL, snippet TEXT NOT NULL)') 35 | conn.commit() 36 | conn.close() 37 | for cookie_path in self.config.cookie_paths: 38 | with open(cookie_path, 'r', encoding='utf-8') as f: 39 | cookie = json.load(f) 40 | for i in range(len(cookie)): 41 | self.sess.cookies.set(cookie[i]['name'], cookie[i]['value']) 42 | 43 | def update_cookie(self, cookie): 44 | for c in cookie.split('; '): 45 | c = c.split('=', 1) 46 | if len(c) == 2: 47 | self.sess.cookies.set(c[0], c[1]) 48 | 49 | def crawl_snippet_google(self, concept): 50 | res = [] 51 | url = 'https://www.google.com/search?gws_rd=cr&q={}'.format(concept) 52 | headers = {'user-agent': USER_AGENT, 'referer': 'https://www.google.com/'} 53 | page = self.sess.get(url, headers=headers) 54 | if 'Set-Cookie' in page.headers: 55 | self.update_cookie(page.headers['Set-Cookie']) 56 | soup = BeautifulSoup(page.text, 'html.parser') 57 | block = soup.find('div', class_='ifM9O') 58 | if block is not None: 59 | title, snippet = '', '' 60 | t = block.find('div', class_='r') 61 | s = block.find('div', class_='LGOjhe') 62 | if t and t.find('a') and t.find('h3') and s: 63 | title = clean(t.find('a').find('h3').text) 64 | snippet = clean(s.text) 65 | res.append('{} {}'.format(title, snippet)) 66 | for block in soup.find_all('div', class_='g'): 67 | title, snippet = '', '' 68 | t = block.find('div', class_='r') 69 | s = block.find('span', class_='st') 70 | if t and t.find('a') and t.find('h3') and s: 71 | title = clean(t.find('a').find('h3').text) 72 | snippet = clean(s.text) 73 | res.append('{} {}'.format(title, snippet)) 74 | return res 75 | 76 | def crawl_snippet_baidu(self, concept): 77 | res = [] 78 | url = 'http://www.baidu.com/s?wd={}'.format(concept) 79 | headers = {'user-agent': USER_AGENT, 'referer': url} 80 | page = self.sess.get(url, headers=headers) 81 | soup = BeautifulSoup(page.text, 'html.parser') 82 | block = soup.find('div', class_='result-op c-container xpath-log') 83 | if block is not None: 84 | title, snippet = '', '' 85 | t = block.find('h3', class_='t') 86 | s = block.find('div', class_='c-span18 c-span-last') 87 | if t and t.find('a') and s and s.find('p'): 88 | title = clean(t.find('a').text) 89 | snippet = clean(s.find('p').text) 90 | res.append('{} {}'.format(title, snippet)) 91 | for block in soup.find_all('div', class_='result c-container' + (' ' if os.name == 'nt' else '')): 92 | title, snippet = '', '' 93 | t = block.find('h3', class_='t') 94 | s = block.find('div', class_='c-abstract') 95 | if t and t.find('a') and s: 96 | title = clean(t.find('a').text) 97 | snippet = clean(s.text) 98 | res.append('{} {}'.format(title, snippet)) 99 | return res 100 | 101 | def crawl_snippet_bing(self, concept): 102 | res = [] 103 | url = 'https://cn.bing.com/search?q={}'.format(concept) 104 | headers = {'user-agent': USER_AGENT, 'referer': url} 105 | page = self.sess.get(url, headers=headers) 106 | soup = BeautifulSoup(page.text, 'html.parser') 107 | if 'cookie' in page.headers: 108 | self.update_cookie(page.headers['cookie']) 109 | block = soup.find('div', class_='b_subModule') 110 | if block is not None: 111 | title, snippet = '', '' 112 | t = block.find('h2', class_='b_entityTitle') 113 | s = block.find('div', class_='b_lBottom') 114 | if t and s: 115 | title = clean(t.text) 116 | 
snippet = clean(s.text) 117 | res.append('{} {}'.format(title, snippet)) 118 | for block in soup.find_all('li', class_='b_algo'): 119 | title, snippet = '', '' 120 | t = block.find('h2') 121 | if t and t.find('a'): 122 | title = clean(t.find('a').text) 123 | s = block.find('div', class_='b_caption') 124 | if s and s.find('p'): 125 | snippet = clean(s.find('p').text) 126 | s = block.find('div', class_='tab-content') 127 | if s and s.find('div'): 128 | snippet = s.find('div').text 129 | if title and snippet: 130 | res.append('{} {}'.format(title, snippet)) 131 | return res 132 | 133 | def crawl_snippet(self, concept): 134 | self.tot_crawl += 1 135 | if self.tot_crawl % 100 == 0: 136 | print('sleep 60s~90s after every 100 crawls') 137 | sleep(60) 138 | concept = urllib.parse.quote_plus(concept) 139 | sleep(2) 140 | if self.config.snippet_source == 'baidu': 141 | res = self.crawl_snippet_baidu(concept) 142 | if self.config.snippet_source == 'google': 143 | res = self.crawl_snippet_google(concept) 144 | if self.config.snippet_source == 'bing': 145 | res = self.crawl_snippet_bing(concept) 146 | return '\n'.join(res) 147 | 148 | def get_snippet(self, concept): 149 | conn = sqlite3.connect(self.config.db) 150 | cursor = conn.cursor() 151 | cursor.execute('SELECT * FROM {} WHERE concept=?'.format(self.config.snippet_source), (concept, )) 152 | res = cursor.fetchall() 153 | if not res: 154 | snippet = self.crawl_snippet(concept) 155 | print('get snippet {} from source {}'.format(concept, self.config.snippet_source)) 156 | cursor.execute('INSERT INTO {} (concept, snippet) VALUES (?,?)'.format(self.config.snippet_source), (concept, snippet, )) 157 | conn.commit() 158 | else: 159 | snippet = res[0][1] 160 | conn.close() 161 | return snippet -------------------------------------------------------------------------------- /modules/preprocess.py: -------------------------------------------------------------------------------- 1 | import re 2 | import nltk 3 | import jieba 4 | import jieba.posseg as pseg 5 | import json 6 | from modules.crawler import Crawler 7 | 8 | def is_noun(config, flag): 9 | if config.language == 'en': 10 | flag = re.sub('JJ[RS]?', 'JJ', flag) 11 | flag = re.sub('NN[SP(PS)]?', 'NN', flag) 12 | if re.match(r'^((@(JJ|NN))+|(@(JJ|NN))*(@(NN|IN))?(@(JJ|NN))*)@NN$', flag) is not None: 13 | return True 14 | else: 15 | return False 16 | if config.language == 'zh': 17 | if re.match(r'^(@(([av]?n[rstz]?)|l|a|v))*(@(([av]?n[rstz]?)|l))$', flag) is not None: 18 | return True 19 | else: 20 | return False 21 | 22 | def get_candidates(config): 23 | if config.task == 'expand': 24 | crawler = Crawler(config) 25 | text = [] 26 | with open(config.input_seed, 'r', encoding='utf-8') as f: 27 | for seed in f.read().split('\n'): 28 | if seed: 29 | text.append(crawler.get_snippet(seed)) 30 | text = '\n'.join(text) 31 | else: 32 | with open(config.input_text, 'r', encoding='utf-8') as f: 33 | text = f.read() 34 | text = re.sub('\xa3|\xae|\x0d', '', text).lower() 35 | if config.language == 'en': 36 | with open(config.en_list, 'r', encoding='utf-8') as f: 37 | vocabs = set(f.read().split('\n')) 38 | if config.language == 'zh': 39 | with open(config.zh_list, 'r', encoding='utf-8') as f: 40 | vocabs = set(f.read().split('\n')) 41 | res = set() 42 | for line in text.split('\n'): 43 | if config.language == 'en': 44 | tmp = nltk.word_tokenize(line) 45 | seg = nltk.pos_tag(tmp) 46 | if config.language == 'zh': 47 | tmp = pseg.cut(line) 48 | seg = [(t.word, t.flag) for t in tmp] 49 | n = len(seg) 50 | for i in 
range(n): 51 | phrase, flag = seg[i][0], '@'+seg[i][1] 52 | for j in range(i+1, min(n+1, i+7)): 53 | if phrase not in res and phrase in vocabs and is_noun(config, flag): 54 | res.add(phrase) 55 | if j < n: 56 | if config.language == 'en': 57 | phrase += ' '+seg[j][0] 58 | if config.language == 'zh': 59 | phrase += seg[j][0] 60 | flag += '@'+seg[j][1] 61 | print('candidate concepts number: {}'.format(len(res))) 62 | with open(config.text_path, 'w', encoding='utf-8') as f: 63 | f.write(text) 64 | with open(config.candidate_path, 'w', encoding='utf-8') as f: 65 | f.write('\n'.join(list(res))) 66 | print('preprocess done.') -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | jieba 2 | requests 3 | bs4 4 | nltk 5 | pytorch-pretrained-bert --------------------------------------------------------------------------------