├── .gitattributes
├── .gitignore
├── LICENSE
├── README.md
├── logo.gif
├── problems.txt
├── problems
└── 1000.json
├── settings_sample.json
└── src
├── build_locale.py
├── build_summary.py
├── custom_about.py
├── embedder.py
├── requirements.txt
├── ui.py
└── utils.py
/.gitattributes:
--------------------------------------------------------------------------------
1 | # Auto detect text files and perform LF normalization
2 | * text=auto
3 |
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | embs*/
2 | *.bak*
3 | tmp/
4 | runit.bat
5 |
6 | # Byte-compiled / optimized / DLL files
7 | __pycache__/
8 | *.py[cod]
9 | *$py.class
10 |
11 |
12 | .DS_Store
13 | settings.json
14 |
15 | # C extensions
16 | *.so
17 |
18 | # Distribution / packaging
19 | .Python
20 | build/
21 | develop-eggs/
22 | dist/
23 | downloads/
24 | eggs/
25 | .eggs/
26 | lib/
27 | lib64/
28 | parts/
29 | sdist/
30 | var/
31 | wheels/
32 | share/python-wheels/
33 | *.egg-info/
34 | .installed.cfg
35 | *.egg
36 | MANIFEST
37 |
38 | # PyInstaller
39 | # Usually these files are written by a python script from a template
40 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
41 | *.manifest
42 | *.spec
43 |
44 | # Installer logs
45 | pip-log.txt
46 | pip-delete-this-directory.txt
47 |
48 | # Unit test / coverage reports
49 | htmlcov/
50 | .tox/
51 | .nox/
52 | .coverage
53 | .coverage.*
54 | .cache
55 | nosetests.xml
56 | coverage.xml
57 | *.cover
58 | *.py,cover
59 | .hypothesis/
60 | .pytest_cache/
61 | cover/
62 |
63 | # Translations
64 | *.mo
65 | *.pot
66 |
67 | # Django stuff:
68 | *.log
69 | local_settings.py
70 | db.sqlite3
71 | db.sqlite3-journal
72 |
73 | # Flask stuff:
74 | instance/
75 | .webassets-cache
76 |
77 | # Scrapy stuff:
78 | .scrapy
79 |
80 | # Sphinx documentation
81 | docs/_build/
82 |
83 | # PyBuilder
84 | .pybuilder/
85 | target/
86 |
87 | # Jupyter Notebook
88 | .ipynb_checkpoints
89 |
90 | # IPython
91 | profile_default/
92 | ipython_config.py
93 |
94 | # pyenv
95 | # For a library or package, you might want to ignore these files since the code is
96 | # intended to run in multiple environments; otherwise, check them in:
97 | # .python-version
98 |
99 | # pipenv
100 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
101 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
102 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
103 | # install all needed dependencies.
104 | #Pipfile.lock
105 |
106 | # poetry
107 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
108 | # This is especially recommended for binary packages to ensure reproducibility, and is more
109 | # commonly ignored for libraries.
110 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
111 | #poetry.lock
112 |
113 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow
114 | __pypackages__/
115 |
116 | # Celery stuff
117 | celerybeat-schedule
118 | celerybeat.pid
119 |
120 | # SageMath parsed files
121 | *.sage.py
122 |
123 | # Environments
124 | .env
125 | .venv
126 | env/
127 | venv/
128 | ENV/
129 | env.bak/
130 | venv.bak/
131 |
132 | # Spyder project settings
133 | .spyderproject
134 | .spyproject
135 |
136 | # Rope project settings
137 | .ropeproject
138 |
139 | # mkdocs documentation
140 | /site
141 |
142 | # mypy
143 | .mypy_cache/
144 | .dmypy.json
145 | dmypy.json
146 |
147 | # Pyre type checker
148 | .pyre/
149 |
150 | # pytype static type analyzer
151 | .pytype/
152 |
153 | # Cython debug symbols
154 | cython_debug/
155 |
156 | # PyCharm
157 | # JetBrains specific template is maintainted in a separate JetBrains.gitignore that can
158 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
159 | # and can be added to the global gitignore or merged into this file. For a more nuclear
160 | # option (not recommended) you can uncomment the following to ignore the entire idea folder.
161 | #.idea/
162 | /problems/vjudge
163 | /vjudge
164 | /problemso
165 | /vjudge---
166 | /vjudge--
167 | /vjudge-
168 | scrapped.tar.gz
169 | skp.7z
170 | src/vjudge_helper.py
171 | is-my-problem-new.7z
172 | perfopt.ipynb
173 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2023 Ziqian Zhong
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Is my problem new?
2 | A simple semantic search engine on competitive programming problems.
3 | http://yuantiji.ac | Buy me a boba
4 |
5 |
6 |
7 | **Update (2024/7/16):** It has been a long time :) Reorganized problems path. Switched LLM / embedder to [Gemma 2 9B](https://huggingface.co/google/gemma-2-9b-it) hosted by [together.ai](https://docs.together.ai) and [voyage-large-2-instruct](https://docs.voyageai.com/docs/pricing). Tweaked the prompt a little bit. Bought a new domain (see the link above) and switched to [vjudge](https://vjudge.net) as data source. See branch `old_ver` or history commits for the previous version.
8 |
9 | **Update (2024/5/19):** Added AtCoder. Thanks [@fstqwq](https://github.com/fstqwq) for the contribution!
10 |
11 | #### How does this work?
12 |
13 | This idea is simple:
14 |
15 | 1. Simplify the statement & remove background by prompting LLM.
16 |
17 | 2. Embed the simplified documents and queries to perform vector searches.
18 |
19 | It only happens recently that both models are good and cheap enough.
20 |
21 | This pipeline is also not limited, of course, to competitive programming problems. You can use it to search for any kind of documents by modifying the prompt.
22 |
23 | #### Deploy
24 |
25 | You will need API keys from OpenAI, Together and Voyage. You can check their pricings online.
26 |
27 | Put problems in `problems/` folder following the given example (`problems/1000.json`). Naming can be arbitrary and you can also use nested folders. Run `python -m src.build_summary` to get paraphrased statements, run `python -m src.embedder` to build embeddings and run `python -m src.build_locale` to detect language of problems. Finally, run `python -m src.ui` to start serving.
28 |
29 | For large-scale deployments, decent CPUs are needed as vector searching is CPU-intensive. You might also want to modify `max_workers` in `src/ui.py`.
30 |
31 | Due to copyright concerns we're not providing the scraped vjudge problems or the vjudge scraper. Sorry D: We also did not process the statements in PDF. If you have problems you want to add that are not supported in vjudge feel free to contact me or send a PR and I'll see what I can do (would be perfect if you can just send over a zip in the correct format xd).
32 |
33 | For reference, adding all ~160k problems from vjudge cost ~$60 and as of writing the deployed site is running on an 8vCPU server.
34 |
--------------------------------------------------------------------------------
/logo.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/fjzzq2002/is-my-problem-new/9e6feba786d86a9942d12ec0309d0adc2c00811d/logo.gif
--------------------------------------------------------------------------------
/problems/1000.json:
--------------------------------------------------------------------------------
1 | {"uid": "51Nod/1000", "url": "https://vjudge.net/problem/51Nod-1000", "tags": [], "title": "A + B", "statement": "给出 $2$ 个整数 $A$ 和 $B$ ,计算两个数的和。", "source": "51Nod", "vjudge": true}
--------------------------------------------------------------------------------
/settings_sample.json:
--------------------------------------------------------------------------------
1 | {
2 | "OPENAI_API_KEY": "[TODO]",
3 | "TOGETHER_API_KEY": "[TODO]",
4 | "VOYAGE_API_KEY": "[TODO]",
5 | "TEMPLATES": [
6 | "I have the following competitive programming problem that I want to show someone else:\n\n=======\n[[ORIGINAL]]\n=======\n\nStrip off all the stories, legends, characters, backgrounds etc. from the statement while still enabling everyone to understand the problem. Also remove the name of the character if possible. This is to say, do not remove anything necessary to understand the full problem and one should feel safe to replace the original statement with your version of the statement. If it is not in English make it English. Provide the simplified statement directly without jargon. Use mathjax ($...$) for math. Start your response with \"Simplified statement:\".",
7 | "I have the following competitive programming problem that I want to show someone else:\n\n=======\n[[ORIGINAL]]\n=======\n\nStrip off all the stories, legends, characters, backgrounds, examples, well-known definitions etc. from the statement while still enabling everyone to understand the problem. Also remove the name of the character if applicable. If it is not in English translate it. Make it as succinct as possible while still being understandable. Try to avoid formulas and symbols. Abstract freely - for example, if the problem is about buying sushi, you can just phrase it as a knapsack problem. If necessary, mathjax ($...$) for math. Provide the *succinct* simplified statement directly without jargon. Start your response with \"Simplified statement:\"."
8 | ],
9 | "CUSTOM_HEADER": "",
10 | "CUSTOM_ABOUT_PY": "src/custom_about.py"
11 | }
--------------------------------------------------------------------------------
/src/build_locale.py:
--------------------------------------------------------------------------------
1 | # run summarize for all the problems
2 | # use the chatgpt api
3 | import requests
4 | import json
5 | from .utils import read_problem, problem_filenames, dump_json_safe, dump_json_safe_utf8
6 | from openai import AsyncOpenAI
7 | from together import AsyncTogether
8 | import anthropic
9 | import hashlib
10 | import asyncio
11 | from tqdm.auto import tqdm
12 |
13 | # from tqdm import tqdm
14 | import time
15 |
# NOTE(review): start_time appears unused in this module — presumably copied
# from build_summary.py; confirm before removing.
start_time = time.time()


# Load API keys / prompt templates.  settings.json is gitignored; see
# settings_sample.json for the expected fields.
# NOTE(review): `settings` is not referenced elsewhere in this module.
with open("settings.json") as f:
    settings = json.load(f)
21 |
22 |
# Language detection is done locally with lingua.  The 0.9 minimum relative
# distance makes the detector abstain (return None) unless one language is a
# clearly better match than the runner-up.
from lingua import Language, LanguageDetectorBuilder
detector = LanguageDetectorBuilder.from_all_languages().with_minimum_relative_distance(0.9).build()
# Map each lingua Language to a (country_code, display_name) pair; the code
# is stored as the first element of a problem's `locale` tuple and is
# presumably used for flag display in the UI.  (table originally generated
# with ChatGPT)
lang_map = {
    Language.AFRIKAANS: ("za", "Afrikaans"),
    Language.ALBANIAN: ("al", "Albanian"),
    Language.ARABIC: ("sa", "Arabic"),
    Language.ARMENIAN: ("am", "Armenian"),
    Language.AZERBAIJANI: ("az", "Azerbaijani"),
    Language.BASQUE: ("eu", "Basque"),
    Language.BELARUSIAN: ("by", "Belarusian"),
    Language.BENGALI: ("bd", "Bengali"),
    Language.BOKMAL: ("no", "Bokmål"),
    Language.BOSNIAN: ("ba", "Bosnian"),
    Language.BULGARIAN: ("bg", "Bulgarian"),
    Language.CATALAN: ("ad", "Catalan"),
    Language.CHINESE: ("cn", "Chinese"),
    Language.CROATIAN: ("hr", "Croatian"),
    Language.CZECH: ("cz", "Czech"),
    Language.DANISH: ("dk", "Danish"),
    Language.DUTCH: ("nl", "Dutch"),
    Language.ENGLISH: ("gb", "English"),
    Language.ESPERANTO: ("eo", "Esperanto"),
    Language.ESTONIAN: ("ee", "Estonian"),
    Language.FINNISH: ("fi", "Finnish"),
    Language.FRENCH: ("fr", "French"),
    Language.GANDA: ("ug", "Ganda"),
    Language.GEORGIAN: ("ge", "Georgian"),
    Language.GERMAN: ("de", "German"),
    Language.GREEK: ("gr", "Greek"),
    Language.GUJARATI: ("in", "Gujarati"),
    Language.HEBREW: ("il", "Hebrew"),
    Language.HINDI: ("in", "Hindi"),
    Language.HUNGARIAN: ("hu", "Hungarian"),
    Language.ICELANDIC: ("is", "Icelandic"),
    Language.INDONESIAN: ("id", "Indonesian"),
    Language.IRISH: ("ie", "Irish"),
    Language.ITALIAN: ("it", "Italian"),
    Language.JAPANESE: ("jp", "Japanese"),
    Language.KAZAKH: ("kz", "Kazakh"),
    Language.KOREAN: ("kr", "Korean"),
    Language.LATIN: ("va", "Latin"),
    Language.LATVIAN: ("lv", "Latvian"),
    Language.LITHUANIAN: ("lt", "Lithuanian"),
    Language.MACEDONIAN: ("mk", "Macedonian"),
    Language.MALAY: ("my", "Malay"),
    Language.MAORI: ("nz", "Maori"),
    Language.MARATHI: ("in", "Marathi"),
    Language.MONGOLIAN: ("mn", "Mongolian"),
    Language.NYNORSK: ("no", "Nynorsk"),
    Language.PERSIAN: ("ir", "Persian"),
    Language.POLISH: ("pl", "Polish"),
    Language.PORTUGUESE: ("pt", "Portuguese"),
    Language.PUNJABI: ("in", "Punjabi"),
    Language.ROMANIAN: ("ro", "Romanian"),
    Language.RUSSIAN: ("ru", "Russian"),
    Language.SERBIAN: ("rs", "Serbian"),
    Language.SHONA: ("zw", "Shona"),
    Language.SLOVAK: ("sk", "Slovak"),
    Language.SLOVENE: ("si", "Slovene"),
    Language.SOMALI: ("so", "Somali"),
    Language.SOTHO: ("za", "Sotho"),
    Language.SPANISH: ("es", "Spanish"),
    Language.SWAHILI: ("ke", "Swahili"),
    Language.SWEDISH: ("se", "Swedish"),
    Language.TAGALOG: ("ph", "Tagalog"),
    Language.TAMIL: ("in", "Tamil"),
    Language.TELUGU: ("in", "Telugu"),
    Language.THAI: ("th", "Thai"),
    Language.TSONGA: ("za", "Tsonga"),
    Language.TSWANA: ("bw", "Tswana"),
    Language.TURKISH: ("tr", "Turkish"),
    Language.UKRAINIAN: ("ua", "Ukrainian"),
    Language.URDU: ("pk", "Urdu"),
    Language.VIETNAMESE: ("vn", "Vietnamese"),
    Language.WELSH: ("gb", "Welsh"),
    Language.XHOSA: ("za", "Xhosa"),
    Language.YORUBA: ("ng", "Yoruba"),
    Language.ZULU: ("za", "Zulu"),
}
def process_all_problems():
    """Detect and persist the language of every problem that lacks one.

    Walks all problem JSON files, runs the lingua detector over the title
    plus statement, and writes a ``locale`` tuple of
    ``(country_code, language_name)`` back into each file.  Files that
    already carry a ``locale`` key, or cannot be parsed, are skipped.
    """
    for filename in tqdm(list(problem_filenames())):
        try:
            problem = read_problem(filename)
        except Exception as exc:
            print('error', filename, exc)
            continue
        # Language was already detected on a previous run.
        if 'locale' in problem:
            continue
        text = problem["title"] + '\n' + problem['statement']
        detected = detector.detect_language_of(text)
        # The detector abstains (None) when no language is a clear winner;
        # fall back to an explicit "unknown" marker in that case.
        problem['locale'] = ('un', 'Unknown') if detected is None else lang_map[detected]
        dump_json_safe_utf8(problem, filename)

if __name__ == "__main__":
    process_all_problems()
--------------------------------------------------------------------------------
/src/build_summary.py:
--------------------------------------------------------------------------------
1 | # run summarize for all the problems
2 | # use the chatgpt api
3 | import requests
4 | import json
5 | from .utils import read_problem, problem_filenames, dump_json_safe, dump_json_safe_utf8
6 | from openai import AsyncOpenAI
7 | from together import AsyncTogether
8 | import anthropic
9 | import hashlib
10 | import asyncio
11 | from tqdm.auto import tqdm
12 |
13 | # from tqdm import tqdm
14 | import time
15 |
# NOTE(review): start_time is never read afterwards in this module — confirm
# before removing.
start_time = time.time()


# Load API keys and the summarization prompt templates; settings.json is
# gitignored, see settings_sample.json for the expected fields.
with open("settings.json") as f:
    settings = json.load(f)

# Async Together client used for all summarization calls below.
client = AsyncTogether(
    api_key=settings['TOGETHER_API_KEY'],
)
25 |
26 |
def check_processed(p, template):
    """Return True if problem ``p`` already has a summary for ``template``.

    The summary cache is keyed by the first 8 hex digits of the MD5 of the
    fully instantiated prompt (the template with ``[[ORIGINAL]]`` replaced
    by the problem statement, stripped).
    """
    prompt = template.replace("[[ORIGINAL]]", p["statement"]).strip()
    digest = hashlib.md5(prompt.encode("utf-8")).hexdigest()[:8]
    return any(entry["prompt_md5"][:8] == digest for entry in p["processed"])
35 |
36 |
async def process(p, template, delay=0):
    """Ask the LLM to produce a simplified statement for problem ``p``.

    Sleeps ``delay`` seconds first so that batched tasks are spread out and
    stay under the API rate limit.  Returns a list containing one result
    record ``{"prompt_md5", "template_md5", "result"}`` on success, and an
    empty list when the prompt was already processed or the API call failed,
    so callers can always take ``len()`` of the return value.
    """
    # Stagger task start times so a batch does not burst past the rate limit.
    await asyncio.sleep(delay)
    ORIGINAL = p["statement"]
    prompt = template.replace("[[ORIGINAL]]", ORIGINAL).strip()
    template_md5 = hashlib.md5(template.encode("utf-8")).hexdigest()[:8]
    prompt_md5 = hashlib.md5(prompt.encode("utf-8")).hexdigest()[:8]
    # Re-check the cache here: the problem may have been summarized between
    # scheduling and execution.
    already_processed = any(f["prompt_md5"][:8] == prompt_md5 for f in p["processed"])
    if already_processed:
        # BUGFIX: this path previously used a bare `return` (i.e. None) while
        # every other path returns a list; callers do `len(result)` on the
        # gathered results and would raise TypeError on None.
        return []
    result = None
    try:
        response = await client.chat.completions.create(
            messages=[
                {
                    "role": "user",
                    "content": prompt,
                },
                # Prefill the assistant turn so the model continues directly
                # after "Simplified statement:" instead of adding preamble.
                {"role": "assistant", "content": "Simplified statement:"},
            ],
            model="google/gemma-2-9b-it",
        )
        result = response.choices[0].message.content.strip()
        print(f"Number of tokens spent: {response.usage.total_tokens}")
    except Exception as e:
        print("Error while prompting:", e)
    if result is None:
        return []
    return [
        {
            "prompt_md5": prompt_md5,
            "template_md5": template_md5,
            "result": result,
        }
    ]
78 |
79 |
async def process_all_problems():
    """Summarize every problem with every configured template, in batches.

    Problems are ordered so that well-known judges ("good OJs") come first
    and known-bad ones last; work is accumulated into chunks of
    ``chunk_size`` files, and the API calls for each chunk are spread out
    with ``gap_every`` seconds between task starts to respect rate limits.
    """
    # apparently some mysterious OJs are spending my money ;_;
    goodojs = ['UOJ', 'Codeforces', '洛谷', 'DMOJ', 'HDU', 'CodeChef', 'AtCoder', 'LibreOJ', 'TopCoder', 'SPOJ', '51Nod', '黑暗爆炸', 'UVA'] #, 'USACO'
    badojs = ['HYSBZ', 'BZOJ']
    # Sort key: 0 for good OJs, 1 for unknown, 2 for bad OJs (matched by
    # case-insensitive substring of the filename).
    fns = sorted(list(problem_filenames()),key=lambda x:int(not any(goodoj.lower() in x.lower() for goodoj in goodojs))+int(any(badoj.lower() in x.lower() for badoj in badojs)))
    chunk_size = 50
    gap_every = 1/8.5  # seconds between task starts (~8.5 requests/second)
    problem_files = []
    for problem_file_cur in tqdm(fns):#tqdm(range(0,len(fns),chunk_size)):
        try:
            p = read_problem(problem_file_cur)
        except Exception as e:
            print('error',problem_file_cur,e)
            continue
        # Queue the file if at least one template has no cached summary yet.
        need_work = False
        for template in settings["TEMPLATES"]:
            if 'processed' in p and check_processed(p, template):
                continue
            need_work = True
        if need_work:
            problem_files.append(problem_file_cur)
        # Flush a full chunk, or whatever is queued when we hit the last file.
        if len(problem_files) >= chunk_size or problem_file_cur == fns[-1]:
            for template in settings["TEMPLATES"]:
                t0 = time.time()
                tasks = []
                notprocessed = []
                for idx, problem_file in enumerate(problem_files):
                    # Re-read fresh: the file may have been updated by a
                    # previous template pass.
                    p = read_problem(problem_file)
                    if "processed" not in p:
                        p["processed"] = []
                    if check_processed(p, template):
                        continue
                    notprocessed.append(problem_file)
                    tasks.append(process(p, template, idx * gap_every))
                if not len(tasks):
                    continue
                # Minimum wall-clock time per chunk, to keep the average
                # request rate at ~1/gap_every.
                WAIT = chunk_size * gap_every + .5
                results = await asyncio.gather(*tasks)
                for problem_file, result in zip(notprocessed, results):
                    if not len(result):
                        # A failure likely means rate limiting; back off more.
                        WAIT = 6
                        continue
                    p = read_problem(problem_file)
                    if "processed" not in p:
                        p["processed"] = []
                    p["processed"].extend(result)
                    print(problem_file)
                    dump_json_safe_utf8(p, problem_file)
                t1 = time.time()
                print('time elapsed',t1-t0)
                # Sleep the remainder so the chunk takes at least WAIT seconds.
                if t1-t0 < WAIT:
                    await asyncio.sleep(WAIT-(t1-t0))
            problem_files = []

if __name__ == "__main__":
    asyncio.run(process_all_problems())
--------------------------------------------------------------------------------
/src/custom_about.py:
--------------------------------------------------------------------------------
# Custom "about" section for the deployed site.  This file is referenced by
# the CUSTOM_ABOUT_PY entry in settings.json and is not importable on its
# own: `gr` (gradio) and the `tr(english, chinese)` locale helper are
# presumably injected by the code in src/ui.py that loads this snippet —
# TODO(review): confirm how it is executed.
gr.Markdown(
    """English version | 中文版本\n\n"""+
    tr('This project is maintained by @TLE. You can find the open source version here. We would like to thank virtual judge for supporting this project!', '本项目由 @TLE 维护。你可以在 这里 找到开源版本。感谢 Virtual Judge 对本项目的大力支持!')+'\n\n'+
    tr('Consider donating if the project is helpful for you! All proceeds will be put into maintenance :)','服务器和运营成本都是 @TLE 自己掏的,所以如果你觉得这个项目对你有用欢迎 捐助!')
    +'\n\n'+tr('QQ Discussion Group (for Chinese users): ','QQ群:')+'829942060'+tr('',',欢迎前来吹水对线!'))
--------------------------------------------------------------------------------
/src/embedder.py:
--------------------------------------------------------------------------------
1 | import pickle
2 | import requests
3 | import json
4 | from .utils import read_problem, problem_filenames, dump_json_safe, dump_json_safe_utf8, dump_pickle_safe
5 | from openai import AsyncOpenAI
6 | from together import Together, AsyncTogether
7 | import anthropic
8 | import hashlib
9 | import asyncio
10 | from tqdm.auto import tqdm
11 |
12 | # from tqdm import tqdm
13 | import time, os
14 | import random
15 | import voyageai
16 |
17 |
# Load API keys; settings.json is gitignored, see settings_sample.json.
with open("settings.json") as f:
    settings = json.load(f)

# Voyage client used for all embedding calls in this module.
voyage_client = voyageai.Client(
    api_key=settings['VOYAGE_API_KEY'],
    max_retries=3,
    timeout=120,
)

# client = Together(
#     api_key=settings['TOGETHER_API_KEY'],
# )
30 |
31 |
def processed_promptmd5(statement, template):
    """Return the 8-hex-digit MD5 key for ``template`` applied to ``statement``.

    ``[[ORIGINAL]]`` in the template is substituted with the statement and
    the result stripped; the returned prefix matches the ``prompt_md5``
    values stored alongside cached summaries and embeddings.
    """
    prompt = template.replace("[[ORIGINAL]]", statement).strip()
    return hashlib.md5(prompt.encode("utf-8")).hexdigest()[:8]
36 |
37 | import numpy as np
38 |
def problem_embeds(problem_file_cur):
    """Load a problem and its cached embeddings.

    Returns ``(problem_dict, embeds)`` where ``embeds`` is the list of
    ``(prompt_md5, vector)`` pairs loaded from the sidecar ``.vopkl`` file,
    or an empty list when no embeddings have been computed yet.  Returns
    ``(None, None)`` when the problem file itself cannot be read.
    """
    try:
        problem = read_problem(problem_file_cur)
    except Exception as e:
        print('error', problem_file_cur, e)
        return None, None
    embeds = []
    try:
        # Embeddings live next to the problem JSON in a pickle sidecar.
        with open(problem_file_cur.replace(".json", ".vopkl"), "rb") as f:
            embeds = pickle.load(f)
    except Exception:
        # BUGFIX: was a bare `except:`, which also swallowed
        # KeyboardInterrupt/SystemExit.  A missing or unreadable sidecar
        # simply means "no embeddings yet".
        pass
    return problem, embeds
53 |
# quick and dirty vector database implementation
class VectorDB:
    """In-memory store of problem embeddings with brute-force cosine search.

    ``self.arr`` holds row-normalized float16 embedding vectors;
    ``self.metadata`` holds, per row, ``(problem_filename, source,
    statement_length)``.  Both are populated by :meth:`load_all`.
    """

    def __init__(self):
        # State is populated lazily by load_all().
        pass

    def load_all(self, shuffle=False, load_around=None, record_tasks=False, skipped_sources=None):
        """Scan all problem files and load their cached embeddings.

        Args:
            shuffle: visit problem files in random order.
            load_around: stop early once more than ``2 * load_around``
                embeddings are loaded (quick partial loads).
            record_tasks: collect problems still missing embeddings into
                ``self.todos``, and do not skip files without a sidecar.
            skipped_sources: optional collection of sources to ignore.
                (BUGFIX: was a mutable default argument ``[]``.)
        """
        if skipped_sources is None:
            skipped_sources = []
        self.arr = []
        self.metadata = []
        self.todos = []
        self.sources = {}  # per-source problem counts
        fns = list(problem_filenames())
        if shuffle:
            random.shuffle(fns)
        for problem_file_cur in tqdm(fns):
            if load_around is not None and len(self.arr) > load_around * 2:
                break
            # Without record_tasks, only look at problems that already have
            # an embedding sidecar on disk.
            if not record_tasks and not os.path.exists(problem_file_cur.replace(".json", ".vopkl")):
                continue
            problem, embeds = problem_embeds(problem_file_cur)
            if problem is None:
                continue
            statement = problem['statement']
            source = problem['source']
            if source in skipped_sources:
                continue
            self.sources[source] = self.sources.get(source, 0) + 1
            need_work = False
            for template in settings["TEMPLATES"]:
                md5 = processed_promptmd5(statement, template)
                found = False
                for m, u in embeds:
                    if m[:8] == md5:
                        found = True
                        # Store unit vectors so search reduces to a dot product.
                        self.arr.append(np.array(u / np.linalg.norm(u), dtype=np.float16))
                        self.metadata.append((problem_file_cur, source, len(statement.strip())))
                        break
                if not found:
                    need_work = True
            if need_work and record_tasks:
                self.todos.append(problem_file_cur)
        print('found', len(self.arr), 'embeds')
        self.arr = np.array(self.arr, dtype=np.float16)
        if record_tasks:
            print('found', len(self.todos), 'todos')

    def complete_todos(self, chunk_size=200, length_limit=1300, shuffle=False):
        """Embed the cached summaries of every problem in ``self.todos``.

        Works in chunks of ``chunk_size`` files; summaries longer than
        ``length_limit`` characters are skipped.  New ``(md5, vector)``
        pairs are appended to each problem's ``.vopkl`` sidecar.
        """
        todos = self.todos
        if shuffle:
            # `random` is already imported at module level; the previous
            # redundant function-local import was removed.
            random.shuffle(todos)
        for i in tqdm(range(0, len(todos), chunk_size)):
            problems = todos[i:i + chunk_size]
            infos = {}
            for problem_file_cur in problems:
                try:
                    full_problem = read_problem(problem_file_cur)
                    statement = full_problem['statement']
                except Exception as e:
                    print('error', problem_file_cur, e)
                    continue
                embeds = []
                try:
                    with open(problem_file_cur.replace(".json", ".vopkl"), "rb") as f:
                        embeds = pickle.load(f)
                except Exception:
                    # BUGFIX: was a bare `except:`; a missing/corrupt sidecar
                    # just means no embeddings exist yet.
                    pass
                infos[problem_file_cur] = full_problem.get('processed', []), statement, embeds
            for template in settings["TEMPLATES"]:
                queues = []
                max_length = 0
                for problem_file_cur, (processed, statement, embeds) in infos.items():
                    md5 = processed_promptmd5(statement, template)
                    if any(m[:8] == md5 for m, u in embeds):
                        continue  # already embedded for this template
                    # Find the cached summary matching this template.
                    processed_text = None
                    for f in processed:
                        if f["prompt_md5"][:8] == md5:
                            if len(f['result']) > length_limit:
                                continue  # too long?
                            processed_text = f["result"]
                            max_length = max(max_length, len(processed_text))
                    if processed_text is None:
                        continue
                    queues.append((processed_text, problem_file_cur, md5))
                if len(queues) == 0:
                    continue
                print('batch', len(queues), ' maxlen', max_length)
                try:
                    t0 = time.time()
                    response = voyage_client.embed(
                        [x[0] for x in queues],
                        model="voyage-large-2-instruct",
                        input_type='document'
                    )
                    print('Token spent', response.total_tokens)
                    t1 = time.time()
                    # Simple rate limiting: space batches >= 0.2s apart.
                    if t1 - t0 < 0.2:
                        time.sleep(0.2 - (t1 - t0))
                    for q, e in zip(queues, response.embeddings):
                        infos[q[1]][2].append((q[2], np.array(e)))
                except Exception as e:
                    print('error', e)
            for problem_file_cur, (processed, statement, embeds) in infos.items():
                dump_pickle_safe(embeds, problem_file_cur.replace(".json", ".vopkl"))

    def query_nearest(self, emb, k=1000, dedup=True):
        """Return the ``k`` best rows as ``(similarity, row_index)`` pairs.

        ``emb`` may be a single vector or a batch; with a batch, a row's
        score is its best match over all query vectors.  Cosine similarity
        is rescaled from [-1, 1] to [0, 1] and results are sorted by
        descending similarity.  With ``dedup``, only the best row per
        problem file is kept.
        """
        emb = np.array(emb)
        if len(emb.shape) == 1:
            emb = emb[None, :]
        # Normalize queries; stored rows are already unit vectors.
        emb = emb / np.linalg.norm(emb, axis=1, keepdims=True)
        emb = np.array(emb, dtype=np.float16)
        sims = np.max(self.arr @ emb.T, axis=1)
        sims = np.clip((sims + 1) / 2, 0, 1)  # [-1,1] -> [0,1]
        topk = np.argsort(sims)[::-1]
        nearest = []
        keys = set()
        for i in topk:
            if dedup:
                key = self.metadata[i][0]
                if key in keys:
                    continue
                keys.add(key)
            nearest.append((sims[i], i))
            if len(nearest) >= k:
                break
        return nearest
193 |
if __name__ == "__main__":
    # Build/refresh the embedding cache: scan every problem, record the ones
    # still missing embeddings (record_tasks=True), then batch-embed them.
    db = VectorDB()
    db.load_all(record_tasks=True)
    db.complete_todos(chunk_size=128)
198 |
--------------------------------------------------------------------------------
/src/requirements.txt:
--------------------------------------------------------------------------------
1 | anthropic==0.32.0
2 | beautifulsoup4==4.11.1
3 | fastapi==0.112.0
4 | gradio==4.41.0
5 | lingua==4.15.0
6 | numpy==1.24.4
7 | openai==1.40.2
8 | Requests==2.32.3
9 | together==1.2.7
10 | tqdm==4.66.4
11 | uvicorn==0.30.5
12 | voyageai==0.2.3
13 |
--------------------------------------------------------------------------------
/src/ui.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | from .embedder import VectorDB, processed_promptmd5
3 | from .utils import read_problem
4 | from tqdm.auto import tqdm
5 | import gradio as gr
6 | import json
7 | import asyncio
8 | from openai import AsyncOpenAI
9 | from together import AsyncTogether
10 | from fastapi import FastAPI
11 | from fastapi.responses import HTMLResponse, FileResponse
12 | import uvicorn
13 | import urllib
14 | import time
15 | import voyageai
16 | from concurrent.futures import ThreadPoolExecutor
# Thread pool for CPU-heavy vector searches; README suggests tuning
# max_workers for larger deployments.
executor = ThreadPoolExecutor(max_workers=8)

# Load every cached embedding into memory at startup.
db = VectorDB()
db.load_all()
print("read", len(set(x[0] for x in db.metadata)), "problems")
print(db.metadata[:100])

# API keys; settings.json is gitignored, see settings_sample.json.
with open("settings.json") as f:
    settings = json.load(f)

# Embedding client (synchronous; called via the executor where needed).
voyage_client = voyageai.Client(
    api_key=settings['VOYAGE_API_KEY'],
    max_retries=3,
    timeout=120,
)

# Async LLM clients used to paraphrase incoming queries.
openai_client = AsyncOpenAI(
    api_key=settings["OPENAI_API_KEY"],
)

together_client = AsyncTogether(
    api_key=settings['TOGETHER_API_KEY'],
)
40 |
async def querier_i18n(locale, statement, *template_choices):
    """Run the full search pipeline for one user query.

    ``template_choices`` is a flat sequence of (engine, prompt, prefix)
    triples, one per configured template.  Each triple is rewritten by the
    selected LLM (or passed through / skipped), the rewrites are embedded as
    queries via Voyage, and the nearest problems are fetched from the
    in-memory vector DB.

    Returns ``[status_message, (similarities, row_ids)] + rewritten_texts``.
    """
    assert len(template_choices) % 3 == 0
    yields = []
    ORIGINAL = statement.strip()
    t1 = time.time()

    async def process_template(engine, prompt, prefix):
        # Engine is matched by substring; '保'/'跳' presumably match the
        # Chinese UI labels for "keep original" / "skip" — TODO confirm.
        if 'origin' in engine.lower() or '保' in engine.lower():
            return ORIGINAL
        if 'none' in engine.lower() or '跳' in engine.lower():
            return ''

        prompt = prompt.replace("[[ORIGINAL]]", ORIGINAL).strip()

        if "gemma" in engine.lower():
            # Prefill the assistant turn so the model continues right after
            # the prefix instead of adding preamble.
            response = await together_client.chat.completions.create(
                messages=[
                    {"role": "user", "content": prompt},
                    {"role": "assistant", "content": prefix}
                ],
                model="google/gemma-2-27b-it",
            )
            return response.choices[0].message.content.strip()
        elif "gpt" in engine.lower():
            response = await openai_client.chat.completions.create(
                messages=[
                    {"role": "system", "content": "You are a helpful assistant."},
                    {"role": "user", "content": prompt},
                ],
                model="gpt-4o-mini"
            )
            # No prefill here, so strip the echoed prefix from the reply.
            return response.choices[0].message.content.strip().replace(prefix.strip(), '', 1).strip()
        else:
            raise NotImplementedError(engine)

    # Rewrite all templates concurrently.
    tasks = [process_template(template_choices[i], template_choices[i+1], template_choices[i+2])
             for i in range(0, len(template_choices), 3)]
    yields = await asyncio.gather(*tasks)

    t2 = time.time()
    print('query llm', t2-t1)
    # Deduplicate and drop empty rewrites before embedding.
    response = voyage_client.embed(
        list(set(y.strip() for y in yields if len(y))),
        model="voyage-large-2-instruct",
        input_type='query'
    )
    print('Token spent',response.total_tokens)
    emb = [d for d in response.embeddings]
    t3 = time.time()
    print('query emb', t3-t2)

    # Vector search is CPU-bound; run it in the thread pool so the event
    # loop stays responsive.
    loop = asyncio.get_running_loop()
    nearest = await loop.run_in_executor(executor, db.query_nearest, emb, 5000)
    # nearest = db.query_nearest(emb, k=5000)
    t4 = time.time()
    print('query nearest', t4-t3)

    sim = np.array([x[0] for x in nearest])
    ids = np.array([x[1] for x in nearest], dtype=np.int32)

    info = 'Fetched top ' + str(len(sim)) + ' matches! Go to the next tab to view results~' if locale == 'en' else \
           '已查找到前' + str(len(sim)) + '个匹配!进入下一页查看结果~'

    return [info, (sim, ids)] + yields
105 |
106 |
107 | def format_problem_i18n(locale, uid, sim):
108 | def tr(en,zh):
109 | if locale == 'en': return en
110 | if locale == 'zh': return zh
111 | raise NotImplementedError(locale)
112 | # be careful about arbitrary reads
113 | uid = db.metadata[int(uid)][0]
114 | problem = read_problem(uid)
115 | statement = problem["statement"].replace("\n", "\n\n")
116 | # summary = sorted(problem.get("processed",[]), key=lambda t: t["template_md5"])
117 | # if len(summary):
118 | # summary = summary[0]["result"]
119 | # else:
120 | # summary = None
121 | title = problem['title']
122 | lang = problem.get('locale',('un', 'Unknown'))
123 | def to_flag(t,u):
124 | if t == 'un':
125 | # get a ? with border, 14x20
126 | return f"""
{title} {problemlink} ({round(sim*100)}%)
\n' 144 | link0 = 'https://www.google.com/search?'+urllib.parse.urlencode({'q': problemlink}) 145 | link1 = 'https://www.google.com/search?'+urllib.parse.urlencode({'q': problem['source']+' '+title}) 146 | link0_bd = 'https://www.baidu.com/s?'+urllib.parse.urlencode({'wd': problemlink}) 147 | link1_bd = 'https://www.baidu.com/s?'+urllib.parse.urlencode({'wd': problem['source']+' '+title}) 148 | # {tr("Google2","谷歌2")} 149 | # Baidu2 150 | html += f'{flag} VJudge {tr("Google","谷歌")} {tr("Baidu","百度")}' 151 | markdown = '' 152 | rsts = [] 153 | for template in settings['TEMPLATES']: 154 | md5 = processed_promptmd5(problem['statement'], template) 155 | rst = None 156 | for t in problem.get("processed",[]): 157 | if t["prompt_md5"][:8] == md5: 158 | rst = t["result"] 159 | if rst is not None: 160 | rsts.append(rst) 161 | rsts.sort(key=len) 162 | for idx, rst in enumerate(rsts): 163 | markdown += f'### {tr("Summary", "简要题意")} {idx+1}\n\n{rst}\n\n' 164 | if markdown != '': 165 | markdown += 'Redirecting based on your browser's locale...
370 | 371 | 372 | 373 | """ 374 | return HTMLResponse(content=html_content) 375 | 376 | app = gr.mount_gradio_app(app, get_block('zh'), path="/zh") 377 | app = gr.mount_gradio_app(app, get_block('en'), path="/en") 378 | 379 | if __name__ == "__main__": 380 | uvicorn.run(app, host="0.0.0.0", port=80) 381 | -------------------------------------------------------------------------------- /src/utils.py: -------------------------------------------------------------------------------- 1 | import json 2 | import pickle 3 | import os 4 | import shutil 5 | import typing 6 | import bs4 7 | import numpy as np 8 | import tempfile 9 | 10 | 11 | # https://stackoverflow.com/a/66835172 12 | def get_text(tag: bs4.Tag) -> str: 13 | _inline_elements = { 14 | "a", 15 | "span", 16 | "em", 17 | "strong", 18 | "u", 19 | "i", 20 | "font", 21 | "mark", 22 | "label", 23 | "s", 24 | "sub", 25 | "sup", 26 | "tt", 27 | "bdo", 28 | "button", 29 | "cite", 30 | "del", 31 | "b", 32 | "a", 33 | "font", 34 | } 35 | 36 | def _get_text(tag: bs4.Tag) -> typing.Generator: 37 | for child in tag.children: 38 | if isinstance(child, bs4.Tag): 39 | # if the tag is a block type tag then yield new lines before after 40 | is_block_element = child.name not in _inline_elements 41 | if is_block_element: 42 | yield "\n" 43 | yield from ["\n"] if child.name == "br" else _get_text(child) 44 | if is_block_element: 45 | yield "\n" 46 | elif isinstance(child, bs4.NavigableString): 47 | yield child.string 48 | 49 | return "".join(_get_text(tag)) 50 | 51 | def cleanup_str(s: str, allow_double_breaks = False) -> str: 52 | s = '\n'.join( line.strip() for line in s.splitlines() ).strip() 53 | # remove redundant linebreaks 54 | while True: 55 | if allow_double_breaks: 56 | ss = s.replace('\n\n\n','\n\n') 57 | else: 58 | ss = s.replace('\n\n','\n') 59 | if ss == s: break 60 | s = ss 61 | return s 62 | 63 | def read_problem(filename): 64 | # read as a json 65 | with open(filename, encoding='utf-8') as f: 66 | return 
json.load(f) 67 | 68 | def problem_filenames(path='problems/'): 69 | for root, dirs, files in os.walk(path): 70 | for filename in files: 71 | if not filename.endswith(".json"): 72 | continue 73 | yield os.path.join(root, filename) 74 | 75 | 76 | def list_problems(embed = False): 77 | # list all problems under problems/ 78 | for problem_filename in problem_filenames(): 79 | assert problem_filename.endswith(".json") 80 | if not embed: 81 | yield problem_filenames 82 | continue 83 | npy_filename = problem_filename[:-5] + ".npy" 84 | if os.path.exists(npy_filename): 85 | yield read_problem(problem_filename), np.load(npy_filename) 86 | 87 | 88 | def dump_json_safe(obj, filename): 89 | with tempfile.NamedTemporaryFile(mode="w", delete=False) as f: 90 | json.dump(obj, f) 91 | shutil.move(f.name, filename) 92 | 93 | 94 | def dump_json_safe_utf8(obj, filename): 95 | with tempfile.NamedTemporaryFile(mode="w", delete=False, encoding='utf-8') as f: 96 | json.dump(obj, f, ensure_ascii=False) 97 | shutil.move(f.name, filename) 98 | 99 | def dump_pickle_safe(obj, filename): 100 | with tempfile.NamedTemporaryFile(mode="wb", delete=False) as f: 101 | pickle.dump(obj, f, protocol=4) 102 | shutil.move(f.name, filename) 103 | 104 | 105 | def dump_numpy_safe(obj, filename): 106 | with tempfile.NamedTemporaryFile(mode="wb", delete=False) as f: 107 | np.save(f, obj) 108 | shutil.move(f.name, filename) 109 | --------------------------------------------------------------------------------