├── .gitattributes ├── .gitignore ├── LICENSE ├── README.md ├── logo.gif ├── problems.txt ├── problems └── 1000.json ├── settings_sample.json └── src ├── build_locale.py ├── build_summary.py ├── custom_about.py ├── embedder.py ├── requirements.txt ├── ui.py └── utils.py /.gitattributes: -------------------------------------------------------------------------------- 1 | # Auto detect text files and perform LF normalization 2 | * text=auto 3 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | embs*/ 2 | *.bak* 3 | tmp/ 4 | runit.bat 5 | 6 | # Byte-compiled / optimized / DLL files 7 | __pycache__/ 8 | *.py[cod] 9 | *$py.class 10 | 11 | 12 | .DS_Store 13 | settings.json 14 | 15 | # C extensions 16 | *.so 17 | 18 | # Distribution / packaging 19 | .Python 20 | build/ 21 | develop-eggs/ 22 | dist/ 23 | downloads/ 24 | eggs/ 25 | .eggs/ 26 | lib/ 27 | lib64/ 28 | parts/ 29 | sdist/ 30 | var/ 31 | wheels/ 32 | share/python-wheels/ 33 | *.egg-info/ 34 | .installed.cfg 35 | *.egg 36 | MANIFEST 37 | 38 | # PyInstaller 39 | # Usually these files are written by a python script from a template 40 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
41 | *.manifest 42 | *.spec 43 | 44 | # Installer logs 45 | pip-log.txt 46 | pip-delete-this-directory.txt 47 | 48 | # Unit test / coverage reports 49 | htmlcov/ 50 | .tox/ 51 | .nox/ 52 | .coverage 53 | .coverage.* 54 | .cache 55 | nosetests.xml 56 | coverage.xml 57 | *.cover 58 | *.py,cover 59 | .hypothesis/ 60 | .pytest_cache/ 61 | cover/ 62 | 63 | # Translations 64 | *.mo 65 | *.pot 66 | 67 | # Django stuff: 68 | *.log 69 | local_settings.py 70 | db.sqlite3 71 | db.sqlite3-journal 72 | 73 | # Flask stuff: 74 | instance/ 75 | .webassets-cache 76 | 77 | # Scrapy stuff: 78 | .scrapy 79 | 80 | # Sphinx documentation 81 | docs/_build/ 82 | 83 | # PyBuilder 84 | .pybuilder/ 85 | target/ 86 | 87 | # Jupyter Notebook 88 | .ipynb_checkpoints 89 | 90 | # IPython 91 | profile_default/ 92 | ipython_config.py 93 | 94 | # pyenv 95 | # For a library or package, you might want to ignore these files since the code is 96 | # intended to run in multiple environments; otherwise, check them in: 97 | # .python-version 98 | 99 | # pipenv 100 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 101 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 102 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 103 | # install all needed dependencies. 104 | #Pipfile.lock 105 | 106 | # poetry 107 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. 108 | # This is especially recommended for binary packages to ensure reproducibility, and is more 109 | # commonly ignored for libraries. 110 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control 111 | #poetry.lock 112 | 113 | # PEP 582; used by e.g. 
github.com/David-OConnor/pyflow 114 | __pypackages__/ 115 | 116 | # Celery stuff 117 | celerybeat-schedule 118 | celerybeat.pid 119 | 120 | # SageMath parsed files 121 | *.sage.py 122 | 123 | # Environments 124 | .env 125 | .venv 126 | env/ 127 | venv/ 128 | ENV/ 129 | env.bak/ 130 | venv.bak/ 131 | 132 | # Spyder project settings 133 | .spyderproject 134 | .spyproject 135 | 136 | # Rope project settings 137 | .ropeproject 138 | 139 | # mkdocs documentation 140 | /site 141 | 142 | # mypy 143 | .mypy_cache/ 144 | .dmypy.json 145 | dmypy.json 146 | 147 | # Pyre type checker 148 | .pyre/ 149 | 150 | # pytype static type analyzer 151 | .pytype/ 152 | 153 | # Cython debug symbols 154 | cython_debug/ 155 | 156 | # PyCharm 157 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can 158 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore 159 | # and can be added to the global gitignore or merged into this file. For a more nuclear 160 | # option (not recommended) you can uncomment the following to ignore the entire idea folder. 
161 | #.idea/ 162 | /problems/vjudge 163 | /vjudge 164 | /problemso 165 | /vjudge--- 166 | /vjudge-- 167 | /vjudge- 168 | scrapped.tar.gz 169 | skp.7z 170 | src/vjudge_helper.py 171 | is-my-problem-new.7z 172 | perfopt.ipynb 173 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2023 Ziqian Zhong 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Is my problem new? 2 | A simple semantic search engine on competitive programming problems. 3 | http://yuantiji.ac | Buy me a boba 4 | 5 | 6 | 7 | **Update (2024/7/16):** It has been a long time :) Reorganized problems path. 
Switched LLM / embedder to [Gemma 2 9B](https://huggingface.co/google/gemma-2-9b-it) hosted by [together.ai](https://docs.together.ai) and [voyage-large-2-instruct](https://docs.voyageai.com/docs/pricing). Tweaked the prompt a little bit. Bought a new domain (see the link above) and switched to [vjudge](https://vjudge.net) as data source. See branch `old_ver` or history commits for the previous version. 8 | 9 | **Update (2024/5/19):** Added AtCoder. Thanks [@fstqwq](https://github.com/fstqwq) for the contribution! 10 | 11 | #### How does this work? 12 | 13 | This idea is simple: 14 | 15 | 1. Simplify the statement & remove background by prompting LLM. 16 | 17 | 2. Embed the simplified documents and queries to perform vector searches. 18 | 19 | It is only recently that both models have become good and cheap enough. 20 | 21 | This pipeline is also not limited, of course, to competitive programming problems. You can use it to search for any kind of documents by modifying the prompt. 22 | 23 | #### Deploy 24 | 25 | You will need API keys from OpenAI, Together and Voyage. You can check their pricings online. 26 | 27 | Put problems in `problems/` folder following the given example (`problems/1000.json`). Naming could be arbitrary and you could also have nested folders. Run `python -m src.build_summary` to get paraphrased statements, run `python -m src.embedder` to build embeddings and run `python -m src.build_locale` to detect language of problems. Finally, run `python -m src.ui` to start serving. 28 | 29 | For large-scale deployments decent CPUs are needed, as vector searching is CPU-intensive. You might also want to modify `max_workers` in `src/ui.py`. 30 | 31 | Due to copyright concerns we're not providing scraped vjudge problems and the vjudge scraper. Sorry D: We also did not process the statements in PDF. 
If you have problems you want to add that are not supported in vjudge feel free to contact me or send PR and I'll see what I can do (would be perfect if you can just send over a zip in the correct format xd). 32 | 33 | For reference, adding all ~160k problems from vjudge cost ~$60 and as of writing the deployed site is running on a 8vCPU server. 34 | -------------------------------------------------------------------------------- /logo.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/fjzzq2002/is-my-problem-new/9e6feba786d86a9942d12ec0309d0adc2c00811d/logo.gif -------------------------------------------------------------------------------- /problems/1000.json: -------------------------------------------------------------------------------- 1 | {"uid": "51Nod/1000", "url": "https://vjudge.net/problem/51Nod-1000", "tags": [], "title": "A + B", "statement": "给出 $2$ 个整数 $A$ 和 $B$ ,计算两个数的和。", "source": "51Nod", "vjudge": true} -------------------------------------------------------------------------------- /settings_sample.json: -------------------------------------------------------------------------------- 1 | { 2 | "OPENAI_API_KEY": "[TODO]", 3 | "TOGETHER_API_KEY": "[TODO]", 4 | "VOYAGE_API_KEY": "[TODO]", 5 | "TEMPLATES": [ 6 | "I have the following competitive programming problem that I want to show someone else:\n\n=======\n[[ORIGINAL]]\n=======\n\nStrip off all the stories, legends, characters, backgrounds etc. from the statement while still enabling everyone to understand the problem. Also remove the name of the character if possible. This is to say, do not remove anything necessary to understand the full problem and one should feel safe to replace the original statement with your version of the statement. If it is not in English make it English. Provide the simplified statement directly without jargon. Use mathjax ($...$) for math. 
Start your response with \"Simplified statement:\".", 7 | "I have the following competitive programming problem that I want to show someone else:\n\n=======\n[[ORIGINAL]]\n=======\n\nStrip off all the stories, legends, characters, backgrounds, examples, well-known definitions etc. from the statement while still enabling everyone to understand the problem. Also remove the name of the character if applicable. If it is not in English translate it. Make it as succinct as possible while still being understandable. Try to avoid formulas and symbols. Abstract freely - for example, if the problem is about buying sushi, you can just phrase it as a knapsack problem. If necessary, mathjax ($...$) for math. Provide the *succinct* simplified statement directly without jargon. Start your response with \"Simplified statement:\"." 8 | ], 9 | "CUSTOM_HEADER": "", 10 | "CUSTOM_ABOUT_PY": "src/custom_about.py" 11 | } -------------------------------------------------------------------------------- /src/build_locale.py: -------------------------------------------------------------------------------- 1 | # run summarize for all the problems 2 | # use the chatgpt api 3 | import requests 4 | import json 5 | from .utils import read_problem, problem_filenames, dump_json_safe, dump_json_safe_utf8 6 | from openai import AsyncOpenAI 7 | from together import AsyncTogether 8 | import anthropic 9 | import hashlib 10 | import asyncio 11 | from tqdm.auto import tqdm 12 | 13 | # from tqdm import tqdm 14 | import time 15 | 16 | start_time = time.time() 17 | 18 | 19 | with open("settings.json") as f: 20 | settings = json.load(f) 21 | 22 | 23 | from lingua import Language, LanguageDetectorBuilder 24 | detector = LanguageDetectorBuilder.from_all_languages().with_minimum_relative_distance(0.9).build() 25 | # ty chatgpt 26 | lang_map = { 27 | Language.AFRIKAANS: ("za", "Afrikaans"), 28 | Language.ALBANIAN: ("al", "Albanian"), 29 | Language.ARABIC: ("sa", "Arabic"), 30 | Language.ARMENIAN: ("am", 
"Armenian"), 31 | Language.AZERBAIJANI: ("az", "Azerbaijani"), 32 | Language.BASQUE: ("eu", "Basque"), 33 | Language.BELARUSIAN: ("by", "Belarusian"), 34 | Language.BENGALI: ("bd", "Bengali"), 35 | Language.BOKMAL: ("no", "Bokmål"), 36 | Language.BOSNIAN: ("ba", "Bosnian"), 37 | Language.BULGARIAN: ("bg", "Bulgarian"), 38 | Language.CATALAN: ("ad", "Catalan"), 39 | Language.CHINESE: ("cn", "Chinese"), 40 | Language.CROATIAN: ("hr", "Croatian"), 41 | Language.CZECH: ("cz", "Czech"), 42 | Language.DANISH: ("dk", "Danish"), 43 | Language.DUTCH: ("nl", "Dutch"), 44 | Language.ENGLISH: ("gb", "English"), 45 | Language.ESPERANTO: ("eo", "Esperanto"), 46 | Language.ESTONIAN: ("ee", "Estonian"), 47 | Language.FINNISH: ("fi", "Finnish"), 48 | Language.FRENCH: ("fr", "French"), 49 | Language.GANDA: ("ug", "Ganda"), 50 | Language.GEORGIAN: ("ge", "Georgian"), 51 | Language.GERMAN: ("de", "German"), 52 | Language.GREEK: ("gr", "Greek"), 53 | Language.GUJARATI: ("in", "Gujarati"), 54 | Language.HEBREW: ("il", "Hebrew"), 55 | Language.HINDI: ("in", "Hindi"), 56 | Language.HUNGARIAN: ("hu", "Hungarian"), 57 | Language.ICELANDIC: ("is", "Icelandic"), 58 | Language.INDONESIAN: ("id", "Indonesian"), 59 | Language.IRISH: ("ie", "Irish"), 60 | Language.ITALIAN: ("it", "Italian"), 61 | Language.JAPANESE: ("jp", "Japanese"), 62 | Language.KAZAKH: ("kz", "Kazakh"), 63 | Language.KOREAN: ("kr", "Korean"), 64 | Language.LATIN: ("va", "Latin"), 65 | Language.LATVIAN: ("lv", "Latvian"), 66 | Language.LITHUANIAN: ("lt", "Lithuanian"), 67 | Language.MACEDONIAN: ("mk", "Macedonian"), 68 | Language.MALAY: ("my", "Malay"), 69 | Language.MAORI: ("nz", "Maori"), 70 | Language.MARATHI: ("in", "Marathi"), 71 | Language.MONGOLIAN: ("mn", "Mongolian"), 72 | Language.NYNORSK: ("no", "Nynorsk"), 73 | Language.PERSIAN: ("ir", "Persian"), 74 | Language.POLISH: ("pl", "Polish"), 75 | Language.PORTUGUESE: ("pt", "Portuguese"), 76 | Language.PUNJABI: ("in", "Punjabi"), 77 | Language.ROMANIAN: ("ro", 
"Romanian"), 78 | Language.RUSSIAN: ("ru", "Russian"), 79 | Language.SERBIAN: ("rs", "Serbian"), 80 | Language.SHONA: ("zw", "Shona"), 81 | Language.SLOVAK: ("sk", "Slovak"), 82 | Language.SLOVENE: ("si", "Slovene"), 83 | Language.SOMALI: ("so", "Somali"), 84 | Language.SOTHO: ("za", "Sotho"), 85 | Language.SPANISH: ("es", "Spanish"), 86 | Language.SWAHILI: ("ke", "Swahili"), 87 | Language.SWEDISH: ("se", "Swedish"), 88 | Language.TAGALOG: ("ph", "Tagalog"), 89 | Language.TAMIL: ("in", "Tamil"), 90 | Language.TELUGU: ("in", "Telugu"), 91 | Language.THAI: ("th", "Thai"), 92 | Language.TSONGA: ("za", "Tsonga"), 93 | Language.TSWANA: ("bw", "Tswana"), 94 | Language.TURKISH: ("tr", "Turkish"), 95 | Language.UKRAINIAN: ("ua", "Ukrainian"), 96 | Language.URDU: ("pk", "Urdu"), 97 | Language.VIETNAMESE: ("vn", "Vietnamese"), 98 | Language.WELSH: ("gb", "Welsh"), 99 | Language.XHOSA: ("za", "Xhosa"), 100 | Language.YORUBA: ("ng", "Yoruba"), 101 | Language.ZULU: ("za", "Zulu"), 102 | } 103 | def process_all_problems(): 104 | fns = list(problem_filenames()) 105 | for problem_file_cur in tqdm(fns): 106 | try: 107 | p = read_problem(problem_file_cur) 108 | except Exception as e: 109 | print('error',problem_file_cur,e) 110 | continue 111 | if 'locale' in p: 112 | continue 113 | detected = detector.detect_language_of(p["title"]+'\n'+p['statement']) 114 | rst = None 115 | if detected is None: 116 | rst = ('un', 'Unknown') 117 | else: 118 | rst = lang_map[detected] 119 | p['locale'] = rst 120 | dump_json_safe_utf8(p, problem_file_cur) 121 | 122 | if __name__ == "__main__": 123 | process_all_problems() -------------------------------------------------------------------------------- /src/build_summary.py: -------------------------------------------------------------------------------- 1 | # run summarize for all the problems 2 | # use the chatgpt api 3 | import requests 4 | import json 5 | from .utils import read_problem, problem_filenames, dump_json_safe, dump_json_safe_utf8 6 | 
from openai import AsyncOpenAI 7 | from together import AsyncTogether 8 | import anthropic 9 | import hashlib 10 | import asyncio 11 | from tqdm.auto import tqdm 12 | 13 | # from tqdm import tqdm 14 | import time 15 | 16 | start_time = time.time() 17 | 18 | 19 | with open("settings.json") as f: 20 | settings = json.load(f) 21 | 22 | client = AsyncTogether( 23 | api_key=settings['TOGETHER_API_KEY'], 24 | ) 25 | 26 | 27 | def check_processed(p, template): 28 | ORIGINAL = p["statement"] 29 | prompt = template.replace("[[ORIGINAL]]", ORIGINAL).strip() 30 | prompt_md5 = hashlib.md5(prompt.encode("utf-8")).hexdigest()[:8] 31 | for f in p["processed"]: 32 | if f["prompt_md5"][:8] == prompt_md5: 33 | return True 34 | return False 35 | 36 | 37 | async def process(p, template, delay = 0): 38 | # sleep for delay first 39 | await asyncio.sleep(delay) 40 | ORIGINAL = p["statement"] 41 | prompt = template.replace("[[ORIGINAL]]", ORIGINAL).strip() 42 | template_md5 = hashlib.md5(template.encode("utf-8")).hexdigest()[:8] 43 | prompt_md5 = hashlib.md5(prompt.encode("utf-8")).hexdigest()[:8] 44 | already_processed = False 45 | for f in p["processed"]: 46 | if f["prompt_md5"][:8] == prompt_md5: 47 | already_processed = True 48 | if already_processed: 49 | return 50 | # print(prompt, prompt_md5) 51 | # print(num_tokens_from_string(prompt)) 52 | result = None 53 | try: 54 | response = await client.chat.completions.create( 55 | messages=[ 56 | { 57 | "role": "user", 58 | "content": prompt, 59 | }, 60 | { "role": "assistant", "content": "Simplified statement:" } 61 | ], 62 | model="google/gemma-2-9b-it", 63 | ) 64 | # assert chat_completion.stop_reason=='end_turn' 65 | result = response.choices[0].message.content.strip() 66 | print(f"Number of tokens spent: {response.usage.total_tokens}") 67 | except Exception as e: 68 | print("Error while prompting:", e) 69 | if result is None: 70 | return [] 71 | return [ 72 | { 73 | "prompt_md5": prompt_md5, 74 | "template_md5": template_md5, 75 | 
"result": result, 76 | } 77 | ] 78 | 79 | 80 | async def process_all_problems(): 81 | # apparently some mysterious OJs are spending my money ;_; 82 | goodojs = ['UOJ', 'Codeforces', '洛谷', 'DMOJ', 'HDU', 'CodeChef', 'AtCoder', 'LibreOJ', 'TopCoder', 'SPOJ', '51Nod', '黑暗爆炸', 'UVA'] #, 'USACO' 83 | badojs = ['HYSBZ', 'BZOJ'] 84 | fns = sorted(list(problem_filenames()),key=lambda x:int(not any(goodoj.lower() in x.lower() for goodoj in goodojs))+int(any(badoj.lower() in x.lower() for badoj in badojs))) 85 | chunk_size = 50 86 | gap_every = 1/8.5 87 | problem_files = [] 88 | for problem_file_cur in tqdm(fns):#tqdm(range(0,len(fns),chunk_size)): 89 | try: 90 | p = read_problem(problem_file_cur) 91 | except Exception as e: 92 | print('error',problem_file_cur,e) 93 | continue 94 | need_work = False 95 | for template in settings["TEMPLATES"]: 96 | if 'processed' in p and check_processed(p, template): 97 | continue 98 | need_work = True 99 | if need_work: 100 | problem_files.append(problem_file_cur) 101 | if len(problem_files) >= chunk_size or problem_file_cur == fns[-1]: 102 | for template in settings["TEMPLATES"]: 103 | t0 = time.time() 104 | tasks = [] 105 | notprocessed = [] 106 | for idx, problem_file in enumerate(problem_files): 107 | p = read_problem(problem_file) 108 | if "processed" not in p: 109 | p["processed"] = [] 110 | if check_processed(p, template): 111 | continue 112 | notprocessed.append(problem_file) 113 | tasks.append(process(p, template, idx * gap_every)) 114 | if not len(tasks): 115 | continue 116 | WAIT = chunk_size * gap_every + .5 117 | results = await asyncio.gather(*tasks) 118 | for problem_file, result in zip(notprocessed, results): 119 | if not len(result): 120 | WAIT = 6 121 | continue 122 | p = read_problem(problem_file) 123 | if "processed" not in p: 124 | p["processed"] = [] 125 | p["processed"].extend(result) 126 | print(problem_file) 127 | dump_json_safe_utf8(p, problem_file) 128 | t1 = time.time() 129 | print('time elapsed',t1-t0) 130 | # 
wait till WAIT 131 | if t1-t0 < WAIT: 132 | await asyncio.sleep(WAIT-(t1-t0)) 133 | problem_files = [] 134 | 135 | if __name__ == "__main__": 136 | asyncio.run(process_all_problems()) -------------------------------------------------------------------------------- /src/custom_about.py: -------------------------------------------------------------------------------- 1 | gr.Markdown( 2 | """English version | 中文版本\n\n"""+ 3 | tr('This project is maintained by @TLE. You can find the open source version here. We would like to thank virtual judge for supporting this project!', '本项目由 @TLE 维护。你可以在 这里 找到开源版本。感谢 Virtual Judge 对本项目的大力支持!')+'\n\n'+ 4 | tr('Consider donating if the project is helpful for you! All proceeds will be put into maintenance :)','服务器和运营成本都是 @TLE 自己掏的,所以如果你觉得这个项目对你有用欢迎 捐助!') 5 | +'\n\n'+tr('QQ Discussion Group (for Chinese users): ','QQ群:')+'829942060'+tr('',',欢迎前来吹水对线!')) -------------------------------------------------------------------------------- /src/embedder.py: -------------------------------------------------------------------------------- 1 | import pickle 2 | import requests 3 | import json 4 | from .utils import read_problem, problem_filenames, dump_json_safe, dump_json_safe_utf8, dump_pickle_safe 5 | from openai import AsyncOpenAI 6 | from together import Together, AsyncTogether 7 | import anthropic 8 | import hashlib 9 | import asyncio 10 | from tqdm.auto import tqdm 11 | 12 | # from tqdm import tqdm 13 | import time, os 14 | import random 15 | import voyageai 16 | 17 | 18 | with open("settings.json") as f: 19 | settings = json.load(f) 20 | 21 | voyage_client = voyageai.Client( 22 | api_key=settings['VOYAGE_API_KEY'], 23 | max_retries=3, 24 | timeout=120, 25 | ) 26 | 27 | # client = Together( 28 | # api_key=settings['TOGETHER_API_KEY'], 29 | # ) 30 | 31 | 32 | def processed_promptmd5(statement, template): 33 | ORIGINAL = statement 34 | prompt = template.replace("[[ORIGINAL]]", ORIGINAL).strip() 35 | return 
hashlib.md5(prompt.encode("utf-8")).hexdigest()[:8] 36 | 37 | import numpy as np 38 | 39 | def problem_embeds(problem_file_cur): 40 | try: 41 | problem = read_problem(problem_file_cur) 42 | # load from corresponding npy 43 | except Exception as e: 44 | print('error',problem_file_cur,e) 45 | return None, None 46 | try: 47 | embeds = [] 48 | with open(problem_file_cur.replace(".json", ".vopkl"), "rb") as f: 49 | embeds = pickle.load(f) 50 | except: 51 | pass 52 | return problem, embeds 53 | 54 | # quick and dirty vector database implementation 55 | class VectorDB: 56 | def __init__(self): 57 | pass 58 | 59 | def load_all(self, shuffle = False, load_around = None, record_tasks=False, skipped_sources = []): 60 | self.arr = [] 61 | self.metadata = [] 62 | self.todos = [] 63 | self.sources = {} 64 | fns = list(problem_filenames()) 65 | if shuffle: 66 | random.shuffle(fns) 67 | for problem_file_cur in tqdm(fns): 68 | # if '洛谷' not in problem_file_cur: 69 | # continue 70 | if load_around is not None and len(self.arr) > load_around * 2: 71 | break 72 | if not record_tasks and not os.path.exists(problem_file_cur.replace(".json", ".vopkl")): 73 | continue 74 | problem, embeds = problem_embeds(problem_file_cur) 75 | if problem is None: 76 | continue 77 | statement = problem['statement'] 78 | source = problem['source'] 79 | if source in skipped_sources: 80 | continue 81 | self.sources[source] = self.sources.get(source, 0) + 1 82 | need_work = False 83 | for template in settings["TEMPLATES"]: 84 | md5 = processed_promptmd5(statement, template) 85 | found = False 86 | for m, u in embeds: 87 | if m[:8] == md5: 88 | found = True 89 | self.arr.append(np.array(u/np.linalg.norm(u),dtype=np.float16)) 90 | self.metadata.append((problem_file_cur, source, len(statement.strip()))) 91 | break 92 | if not found: 93 | need_work = True 94 | if need_work and record_tasks: 95 | self.todos.append(problem_file_cur) 96 | print('found',len(self.arr),'embeds') 97 | self.arr = 
np.array(self.arr,dtype=np.float16) 98 | if record_tasks: 99 | print('found',len(self.todos),'todos') 100 | 101 | 102 | def complete_todos(self, chunk_size = 200, length_limit = 1300, shuffle = False): 103 | todos = self.todos 104 | if shuffle: 105 | import random 106 | random.shuffle(todos) 107 | for i in tqdm(range(0,len(todos),chunk_size)): 108 | problems = todos[i:i+chunk_size] 109 | infos = {} 110 | for problem_file_cur in problems: 111 | try: 112 | full_problem = read_problem(problem_file_cur) 113 | statement = full_problem['statement'] 114 | # load from corresponding npy 115 | except Exception as e: 116 | print('error',problem_file_cur,e) 117 | continue 118 | try: 119 | embeds = [] 120 | with open(problem_file_cur.replace(".json", ".vopkl"), "rb") as f: 121 | embeds = pickle.load(f) 122 | except: 123 | pass 124 | infos[problem_file_cur] = full_problem.get('processed',[]), statement, embeds 125 | for template in settings["TEMPLATES"]: 126 | queues = [] 127 | max_length = 0 128 | for problem_file_cur, (processed, statement, embeds) in infos.items(): 129 | md5 = processed_promptmd5(statement, template) 130 | if any(m[:8] == md5 for m, u in embeds): continue 131 | # get processed 132 | processed_text = None 133 | for f in processed: 134 | if f["prompt_md5"][:8] == md5: 135 | if len(f['result']) > length_limit: 136 | continue # too long? 
137 | processed_text = f["result"] 138 | max_length = max(max_length, len(processed_text)) 139 | if processed_text is None: 140 | continue 141 | queues.append((processed_text, problem_file_cur, md5)) 142 | if len(queues) == 0: 143 | continue 144 | print('batch',len(queues),' maxlen',max_length) 145 | try: 146 | t0 = time.time() 147 | response = voyage_client.embed( 148 | [ 149 | x[0] for x in queues 150 | ], 151 | model="voyage-large-2-instruct", 152 | input_type='document' 153 | ) 154 | print('Token spent',response.total_tokens) 155 | t1 = time.time() 156 | # wait till 0.5s 157 | if t1 - t0 < 0.2: 158 | time.sleep(0.2 - (t1 - t0)) 159 | for q,e in zip(queues, response.embeddings): 160 | infos[q[1]][2].append((q[2], np.array(e))) 161 | except Exception as e: 162 | print('error',e) 163 | for problem_file_cur, (processed, statement, embeds) in infos.items(): 164 | dump_pickle_safe(embeds, problem_file_cur.replace(".json", ".vopkl")) 165 | 166 | 167 | def query_nearest(self, emb, k=1000, dedup=True): 168 | # return the k nearest embeddings with cosine similarity 169 | # return a list of (cosine similarity, metadata) tuples 170 | # the list is sorted by cosine similarity 171 | # normailze emb 172 | emb = np.array(emb) 173 | if len(emb.shape) == 1: 174 | emb = emb[None, :] 175 | emb = emb / np.linalg.norm(emb, axis=1, keepdims=True) 176 | emb = np.array(emb, dtype=np.float16) 177 | sims = np.max(self.arr @ emb.T, axis=1) 178 | sims = np.clip((sims+1)/2, 0, 1) # [-1,1] -> [0,1] 179 | topk = np.argsort(sims)[::-1] 180 | nearest = [] 181 | keys = set() 182 | # print(f'query nearest {len(emb)=} {len(sims)=} {len(topk)=} {k=}') 183 | for i in topk: 184 | if dedup: 185 | key = self.metadata[i][0] 186 | if key in keys: 187 | continue 188 | keys.add(key) 189 | nearest.append((sims[i], i)) 190 | if len(nearest) >= k: 191 | break 192 | return nearest 193 | 194 | if __name__ == "__main__": 195 | db = VectorDB() 196 | db.load_all(record_tasks=True) 197 | 
db.complete_todos(chunk_size=128) 198 | -------------------------------------------------------------------------------- /src/requirements.txt: -------------------------------------------------------------------------------- 1 | anthropic==0.32.0 2 | beautifulsoup4==4.11.1 3 | fastapi==0.112.0 4 | gradio==4.41.0 5 | lingua==4.15.0 6 | numpy==1.24.4 7 | openai==1.40.2 8 | Requests==2.32.3 9 | together==1.2.7 10 | tqdm==4.66.4 11 | uvicorn==0.30.5 12 | voyageai==0.2.3 13 | -------------------------------------------------------------------------------- /src/ui.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from .embedder import VectorDB, processed_promptmd5 3 | from .utils import read_problem 4 | from tqdm.auto import tqdm 5 | import gradio as gr 6 | import json 7 | import asyncio 8 | from openai import AsyncOpenAI 9 | from together import AsyncTogether 10 | from fastapi import FastAPI 11 | from fastapi.responses import HTMLResponse, FileResponse 12 | import uvicorn 13 | import urllib 14 | import time 15 | import voyageai 16 | from concurrent.futures import ThreadPoolExecutor 17 | executor = ThreadPoolExecutor(max_workers=8) 18 | 19 | db = VectorDB() 20 | db.load_all() 21 | print("read", len(set(x[0] for x in db.metadata)), "problems") 22 | print(db.metadata[:100]) 23 | 24 | with open("settings.json") as f: 25 | settings = json.load(f) 26 | 27 | voyage_client = voyageai.Client( 28 | api_key=settings['VOYAGE_API_KEY'], 29 | max_retries=3, 30 | timeout=120, 31 | ) 32 | 33 | openai_client = AsyncOpenAI( 34 | api_key=settings["OPENAI_API_KEY"], 35 | ) 36 | 37 | together_client = AsyncTogether( 38 | api_key=settings['TOGETHER_API_KEY'], 39 | ) 40 | 41 | async def querier_i18n(locale, statement, *template_choices): 42 | assert len(template_choices) % 3 == 0 43 | yields = [] 44 | ORIGINAL = statement.strip() 45 | t1 = time.time() 46 | 47 | async def process_template(engine, prompt, prefix): 48 | if 'origin' in 
engine.lower() or '保' in engine.lower(): 49 | return ORIGINAL 50 | if 'none' in engine.lower() or '跳' in engine.lower(): 51 | return '' 52 | 53 | prompt = prompt.replace("[[ORIGINAL]]", ORIGINAL).strip() 54 | 55 | if "gemma" in engine.lower(): 56 | response = await together_client.chat.completions.create( 57 | messages=[ 58 | {"role": "user", "content": prompt}, 59 | {"role": "assistant", "content": prefix} 60 | ], 61 | model="google/gemma-2-27b-it", 62 | ) 63 | return response.choices[0].message.content.strip() 64 | elif "gpt" in engine.lower(): 65 | response = await openai_client.chat.completions.create( 66 | messages=[ 67 | {"role": "system", "content": "You are a helpful assistant."}, 68 | {"role": "user", "content": prompt}, 69 | ], 70 | model="gpt-4o-mini" 71 | ) 72 | return response.choices[0].message.content.strip().replace(prefix.strip(), '', 1).strip() 73 | else: 74 | raise NotImplementedError(engine) 75 | 76 | tasks = [process_template(template_choices[i], template_choices[i+1], template_choices[i+2]) 77 | for i in range(0, len(template_choices), 3)] 78 | yields = await asyncio.gather(*tasks) 79 | 80 | t2 = time.time() 81 | print('query llm', t2-t1) 82 | response = voyage_client.embed( 83 | list(set(y.strip() for y in yields if len(y))), 84 | model="voyage-large-2-instruct", 85 | input_type='query' 86 | ) 87 | print('Token spent',response.total_tokens) 88 | emb = [d for d in response.embeddings] 89 | t3 = time.time() 90 | print('query emb', t3-t2) 91 | 92 | loop = asyncio.get_running_loop() 93 | nearest = await loop.run_in_executor(executor, db.query_nearest, emb, 5000) 94 | # nearest = db.query_nearest(emb, k=5000) 95 | t4 = time.time() 96 | print('query nearest', t4-t3) 97 | 98 | sim = np.array([x[0] for x in nearest]) 99 | ids = np.array([x[1] for x in nearest], dtype=np.int32) 100 | 101 | info = 'Fetched top ' + str(len(sim)) + ' matches! 
Go to the next tab to view results~' if locale == 'en' else \ 102 | '已查找到前' + str(len(sim)) + '个匹配!进入下一页查看结果~' 103 | 104 | return [info, (sim, ids)] + yields 105 | 106 | 107 | def format_problem_i18n(locale, uid, sim): 108 | def tr(en,zh): 109 | if locale == 'en': return en 110 | if locale == 'zh': return zh 111 | raise NotImplementedError(locale) 112 | # be careful about arbitrary reads 113 | uid = db.metadata[int(uid)][0] 114 | problem = read_problem(uid) 115 | statement = problem["statement"].replace("\n", "\n\n") 116 | # summary = sorted(problem.get("processed",[]), key=lambda t: t["template_md5"]) 117 | # if len(summary): 118 | # summary = summary[0]["result"] 119 | # else: 120 | # summary = None 121 | title = problem['title'] 122 | lang = problem.get('locale',('un', 'Unknown')) 123 | def to_flag(t,u): 124 | if t == 'un': 125 | # get a ? with border, 14x20 126 | return f"""
?
""" 127 | else: 128 | return f"""{u}""" 136 | # flag = ''.join(to_flag(t) for t in lang_mapper.values()) # debug only 137 | flag = to_flag(*lang) 138 | url = problem["url"] 139 | problemlink = uid.replace('/',' ').replace('\\',' ').strip().replace('problems vjudge','',1).strip().replace('_','-') 140 | assert problemlink.endswith('.json') 141 | problemlink = problemlink[:-5].strip() 142 | # markdown = f"# [{title} ({problemlink})]({url})\n\n" 143 | html = f'

{title}  {problemlink} ({round(sim*100)}%)

\n' 144 | link0 = 'https://www.google.com/search?'+urllib.parse.urlencode({'q': problemlink}) 145 | link1 = 'https://www.google.com/search?'+urllib.parse.urlencode({'q': problem['source']+' '+title}) 146 | link0_bd = 'https://www.baidu.com/s?'+urllib.parse.urlencode({'wd': problemlink}) 147 | link1_bd = 'https://www.baidu.com/s?'+urllib.parse.urlencode({'wd': problem['source']+' '+title}) 148 | # {tr("Google2","谷歌2")} 149 | # Baidu2 150 | html += f'{flag}   VJudge  {tr("Google","谷歌")}  {tr("Baidu","百度")}' 151 | markdown = '' 152 | rsts = [] 153 | for template in settings['TEMPLATES']: 154 | md5 = processed_promptmd5(problem['statement'], template) 155 | rst = None 156 | for t in problem.get("processed",[]): 157 | if t["prompt_md5"][:8] == md5: 158 | rst = t["result"] 159 | if rst is not None: 160 | rsts.append(rst) 161 | rsts.sort(key=len) 162 | for idx, rst in enumerate(rsts): 163 | markdown += f'### {tr("Summary", "简要题意")} {idx+1}\n\n{rst}\n\n' 164 | if markdown != '': 165 | markdown += '
\n\n' 166 | markdown += f'### {tr("Raw Statement", "原始题面")}\n\n{statement}' 167 | return html, markdown 168 | 169 | def get_block(locale): 170 | def tr(en,zh): 171 | if locale == 'en': return en 172 | if locale == 'zh': return zh 173 | raise NotImplementedError(locale) 174 | 175 | with gr.Blocks( 176 | title=tr("Is my problem new?","原题机"), css=""" 177 | .mymarkdown {font-size: 15px !important} 178 | footer{display:none !important} 179 | .centermarkdown{text-align:center !important} 180 | .pagedisp{text-align:center !important; font-size: 20px !important} 181 | .realfooter{color: #888 !important; font-size: 14px !important; text-align: center !important;} 182 | .realfooter a{color: #888 !important;} 183 | .smallbutton {min-width: 30px !important;} 184 | """, 185 | head=settings.get('CUSTOM_HEADER','') 186 | ) as demo: 187 | gr.Markdown( 188 | tr(""" 189 | # Is my problem new? 190 | A semantic search engine for competitive programming problems. 191 | """,""" 192 | # 原题机 193 | 原题在哪里啊,原题在这里~""" 194 | )) 195 | with gr.Tabs() as tabs: 196 | with gr.TabItem(tr("Search",'搜索'),id=0): 197 | input_text = gr.TextArea( 198 | label=tr("Statement",'题目描述'), 199 | info=tr("Paste your statement here!",'在这里粘贴你要搜索的题目!'), 200 | value=tr("Calculate the longest increasing subsequence of the input sequence.", 201 | '计算最长上升子序列长度。'), 202 | ) 203 | bundles = [] 204 | with gr.Accordion(tr("Rewriting Setup (Advanced)","高级设置"), open=False): 205 | gr.Markdown(tr("Several rewritten version of the original statement will be calculated and the maximum embedding similarity is used for sorting.", 206 | "输入的问题描述将被重写为多个版本并计算与每个原问题的最大相似度。")) 207 | for template_id in range(5): 208 | with gr.Accordion(tr("Template ",'版本 ')+str(template_id+1)): 209 | with gr.Row(): 210 | with gr.Group(): 211 | template = settings['TEMPLATES'][(template_id-1)%2] if template_id in [1,2] else None 212 | # engines = [tr("Keep Original",'保留原描述'), "Gemma 2 (27B)", "GPT4o Mini", tr('None', '跳过该版本')] 213 | # engine = gr.Radio( 214 
| # engines, 215 | # label=tr("Engine",'使用的语言模型'), 216 | # value=engines[-1] if template is None else engines[1] if template_id<=2 else engines[2], 217 | # interactive=True, 218 | # ) 219 | engines = [tr("Keep Original",'保留原描述'), "GPT4o Mini", tr('None', '跳过该版本')] 220 | engine = gr.Radio( 221 | engines, 222 | label=tr("Engine",'使用的语言模型'), 223 | value=engines[-1] if template is None else engines[1],# if template_id<=2 else engines[2], 224 | interactive=True, 225 | ) 226 | prompt = gr.TextArea( 227 | label=tr("Prompt ([[ORIGINAL]] will be replaced)",'提示词 ([[ORIGINAL]] 将被替换为问题描述)'), 228 | value=template if template is not None else settings['TEMPLATES'][0], 229 | interactive=True, 230 | visible=template is not None, 231 | ) 232 | prefix = gr.Textbox( 233 | label=tr("Prefix", '回复前缀'), 234 | value="Simplified statement:", 235 | interactive=True, 236 | visible=template is not None, 237 | ) 238 | # hide these when engine has wrong value 239 | engine.change(lambda engine: (gr.update(visible=any(s in engine.lower() for s in ['gpt','gemma'])),)*2, engine, [prompt, prefix]) 240 | output_text = gr.TextArea( 241 | label=tr('Output','重写结果'), 242 | value="", 243 | interactive=False, 244 | ) 245 | bundles.append((engine, prompt, prefix, output_text)) 246 | search_result = gr.State(([],[])) 247 | submit_button = gr.Button(tr("Search!",'搜索!')) 248 | status_text = gr.Markdown("", elem_classes="centermarkdown") 249 | with gr.TabItem(tr("View Results",'查看结果'),id=1): 250 | cur_idx = gr.State(0) 251 | num_columns = gr.State(1) 252 | ojs = [f'{t} ({c})' for t,c in sorted(db.sources.items())] 253 | oj_dropdown = gr.Dropdown( 254 | ojs, value=ojs, multiselect=True, label=tr("Displayed OJs",'展示的OJ'), 255 | info=tr('Problems from OJ not in this list will be ignored.', 256 | '不在这个列表里的OJ的题目将被忽略。可以在这里删掉你不认识的OJ。'), 257 | ) 258 | # on change, change cur_idx to 1 259 | oj_dropdown.change(lambda: 0, None, cur_idx) 260 | statement_min_len = gr.Slider( 261 | minimum=1, 262 | maximum=1000, 263 | 
label=tr("Minimum Statement Length",'最小题面长度'), 264 | value=20, 265 | info=tr('The statements shorter than this after removing digits + blanks will be ignored. Useful for filtering out meaningless statements.', 266 | '去除数字和空白字符后题面长度小于该值的题目将被忽略。可以用来筛掉一些奇怪的题面。'), 267 | ) 268 | 269 | with gr.Row(): 270 | # home_page = gr.Button("H") 271 | add_column = gr.Button("+", elem_classes='smallbutton') 272 | prev_page = gr.Button("←", elem_classes='smallbutton') 273 | home_page = gr.Button("H", elem_classes='smallbutton') 274 | next_page = gr.Button("→", elem_classes='smallbutton') 275 | remove_column = gr.Button("-", elem_classes='smallbutton') 276 | # bind to cur_page and num_columns 277 | # home_page.click(lambda: 1, None, cur_page) 278 | prev_page.click(lambda cur_idx, num_columns: max(cur_idx - num_columns, 0), [cur_idx, num_columns], cur_idx, concurrency_limit=None) 279 | next_page.click(lambda cur_idx, num_columns: cur_idx + num_columns, [cur_idx, num_columns], cur_idx, concurrency_limit=None) 280 | home_page.click(lambda: 0, None, cur_idx, concurrency_limit=None) 281 | def adj_idx(idx, col): 282 | return int(round(idx / col)) * col 283 | add_column.click(lambda cur_idx, num_columns: (adj_idx(cur_idx, num_columns + 1), num_columns + 1), [cur_idx, num_columns], [cur_idx, num_columns], concurrency_limit=None) 284 | remove_column.click(lambda cur_idx, num_columns: (adj_idx(cur_idx, num_columns - 1), num_columns - 1) if num_columns >1 else (cur_idx, num_columns), [cur_idx, num_columns], [cur_idx, num_columns], concurrency_limit=None) 285 | 286 | 287 | @gr.render(inputs=[search_result, oj_dropdown, cur_idx, num_columns, statement_min_len], concurrency_limit=None) 288 | def show_OJs(search_result, oj_dropdown, cur_idx, num_columns, statement_min_len): 289 | allowed_OJs = set([oj[:oj.find(' (')] for oj in oj_dropdown]) 290 | tot = 0 291 | # print(len(search_result[0]),len(search_result[1])) 292 | for sim, idx in zip(search_result[0], search_result[1]): 293 | if 
db.metadata[idx][1] not in allowed_OJs or db.metadata[idx][2] < statement_min_len: 294 | continue 295 | tot += 1 296 | gr.Markdown(tr(f"Page {round(cur_idx/num_columns)+1} of {(tot+num_columns-1)//num_columns} ({num_columns} per page)", 297 | f'第 {round(cur_idx/num_columns)+1} 页 / 共 {(tot+num_columns-1)//num_columns} 页 (每页显示 {num_columns} 个)'), 298 | elem_classes="pagedisp") 299 | cnt = 0 300 | with gr.Row(): 301 | for sim, idx in zip(search_result[0], search_result[1]): 302 | if db.metadata[idx][1] not in allowed_OJs or db.metadata[idx][2] < statement_min_len: 303 | continue 304 | cnt += 1 305 | if cur_idx+1 <= cnt: 306 | if cnt > cur_idx+num_columns: break 307 | with gr.Column(variant='compact'): 308 | html, md = format_problem_i18n(locale, idx, sim) 309 | gr.HTML(html) 310 | gr.Markdown( 311 | latex_delimiters=[ 312 | {"left": "$$", "right": "$$", "display": True}, 313 | {"left": "$", "right": "$", "display": False}, 314 | {"left": "\\(", "right": "\\)", "display": False}, 315 | {"left": "\\[", "right": "\\]", "display": True}, 316 | ], 317 | value=md, 318 | elem_classes="mymarkdown", 319 | ) 320 | if 'CUSTOM_ABOUT_PY' in settings and settings['CUSTOM_ABOUT_PY'].endswith('.py'): 321 | with gr.TabItem(tr("About",'关于'),id=2): 322 | with open(settings['CUSTOM_ABOUT_PY'], 'r', encoding='utf-8') as f: eval(f.read()) 323 | 324 | # add a footer 325 | gr.HTML( 326 | """
Built with ❤️ by @TLE
""" 327 | ) 328 | async def async_querier_wrapper(*args): 329 | result = await querier_i18n(locale, *args) 330 | return (gr.Tabs(selected=1),) + tuple(result) 331 | submit_button.click( 332 | fn=async_querier_wrapper, 333 | inputs=sum([list(t[:-1]) for t in bundles], [input_text]), 334 | outputs=[tabs, status_text, search_result] + [t[-1] for t in bundles], 335 | concurrency_limit=7, 336 | ) 337 | # output_labels.select(fn=show_problem, inputs=None, outputs=[my_markdown]) 338 | return demo 339 | 340 | 341 | 342 | app = FastAPI() 343 | favicon_path = 'favicon.ico' 344 | 345 | @app.get('/favicon.ico', include_in_schema=False) 346 | async def favicon(): 347 | return FileResponse(favicon_path) 348 | @app.get("/", response_class=HTMLResponse) 349 | async def read_main(): 350 | html_content = """ 351 | 352 | 353 | 354 | 355 | 356 | Is my problem new? 357 | 367 | 368 | 369 |

Redirecting based on your browser's locale...

370 |

English | 中文

371 | 372 | 373 | """ 374 | return HTMLResponse(content=html_content) 375 | 376 | app = gr.mount_gradio_app(app, get_block('zh'), path="/zh") 377 | app = gr.mount_gradio_app(app, get_block('en'), path="/en") 378 | 379 | if __name__ == "__main__": 380 | uvicorn.run(app, host="0.0.0.0", port=80) 381 | -------------------------------------------------------------------------------- /src/utils.py: -------------------------------------------------------------------------------- 1 | import json 2 | import pickle 3 | import os 4 | import shutil 5 | import typing 6 | import bs4 7 | import numpy as np 8 | import tempfile 9 | 10 | 11 | # https://stackoverflow.com/a/66835172 12 | def get_text(tag: bs4.Tag) -> str: 13 | _inline_elements = { 14 | "a", 15 | "span", 16 | "em", 17 | "strong", 18 | "u", 19 | "i", 20 | "font", 21 | "mark", 22 | "label", 23 | "s", 24 | "sub", 25 | "sup", 26 | "tt", 27 | "bdo", 28 | "button", 29 | "cite", 30 | "del", 31 | "b", 32 | "a", 33 | "font", 34 | } 35 | 36 | def _get_text(tag: bs4.Tag) -> typing.Generator: 37 | for child in tag.children: 38 | if isinstance(child, bs4.Tag): 39 | # if the tag is a block type tag then yield new lines before after 40 | is_block_element = child.name not in _inline_elements 41 | if is_block_element: 42 | yield "\n" 43 | yield from ["\n"] if child.name == "br" else _get_text(child) 44 | if is_block_element: 45 | yield "\n" 46 | elif isinstance(child, bs4.NavigableString): 47 | yield child.string 48 | 49 | return "".join(_get_text(tag)) 50 | 51 | def cleanup_str(s: str, allow_double_breaks = False) -> str: 52 | s = '\n'.join( line.strip() for line in s.splitlines() ).strip() 53 | # remove redundant linebreaks 54 | while True: 55 | if allow_double_breaks: 56 | ss = s.replace('\n\n\n','\n\n') 57 | else: 58 | ss = s.replace('\n\n','\n') 59 | if ss == s: break 60 | s = ss 61 | return s 62 | 63 | def read_problem(filename): 64 | # read as a json 65 | with open(filename, encoding='utf-8') as f: 66 | return json.load(f) 67 | 
def problem_filenames(path='problems/'):
    """Yield the path of every ``*.json`` problem file under *path*, recursively."""
    for root, dirs, files in os.walk(path):
        for filename in files:
            if not filename.endswith(".json"):
                continue
            yield os.path.join(root, filename)


def list_problems(embed=False):
    """Iterate over all problems stored under ``problems/``.

    With ``embed=False`` (default), yields the JSON filename of each problem.
    With ``embed=True``, yields ``(problem_dict, embedding)`` pairs, silently
    skipping problems that have no precomputed ``.npy`` embedding file
    saved alongside the JSON.
    """
    for problem_filename in problem_filenames():
        assert problem_filename.endswith(".json")
        if not embed:
            # BUG FIX: the original did ``yield problem_filenames`` (no final
            # "e" missing -- it yielded the *function object* itself) instead
            # of the current filename from the loop.
            yield problem_filename
            continue
        npy_filename = problem_filename[:-5] + ".npy"
        if os.path.exists(npy_filename):
            yield read_problem(problem_filename), np.load(npy_filename)


def dump_json_safe(obj, filename):
    """Write *obj* as JSON via a temp file, then move it into place.

    Writing to a temporary file first means a crash mid-dump never leaves a
    truncated ``filename`` behind.

    NOTE(review): the temp file is created in the system temp directory,
    which may live on a different filesystem than *filename*; in that case
    ``shutil.move`` degrades to copy+delete rather than an atomic rename.
    Also uses the platform default text encoding -- use
    :func:`dump_json_safe_utf8` when non-ASCII output matters.
    """
    with tempfile.NamedTemporaryFile(mode="w", delete=False) as f:
        json.dump(obj, f)
    shutil.move(f.name, filename)


def dump_json_safe_utf8(obj, filename):
    """Like :func:`dump_json_safe`, but writes UTF-8 and keeps non-ASCII
    characters literal (``ensure_ascii=False``)."""
    with tempfile.NamedTemporaryFile(mode="w", delete=False, encoding='utf-8') as f:
        json.dump(obj, f, ensure_ascii=False)
    shutil.move(f.name, filename)


def dump_pickle_safe(obj, filename):
    """Pickle *obj* (protocol 4) via a temp file, then move it into place."""
    with tempfile.NamedTemporaryFile(mode="wb", delete=False) as f:
        pickle.dump(obj, f, protocol=4)
    shutil.move(f.name, filename)


def dump_numpy_safe(obj, filename):
    """Save *obj* with :func:`numpy.save` via a temp file, then move it into place."""
    with tempfile.NamedTemporaryFile(mode="wb", delete=False) as f:
        np.save(f, obj)
    shutil.move(f.name, filename)