├── .github └── workflows │ └── update_toc.yml ├── .gitignore ├── CITATION.bib ├── LICENSE ├── README.md ├── colabs └── Visual-Self-Refine-GPT4V.ipynb ├── data ├── outputs │ └── yelp │ │ ├── chatgpt.jsonl │ │ ├── dv3.jsonl │ │ └── gpt4.jsonl ├── prompt │ ├── acronym │ │ ├── feedback.jsonl │ │ └── init.jsonl │ ├── commongen │ │ ├── commongen_hard.jsonl │ │ ├── feedback.jsonl │ │ ├── init.jsonl │ │ └── iterate.jsonl │ ├── gsm │ │ ├── feedback.txt │ │ └── init.txt │ ├── pie │ │ ├── feedback.txt │ │ ├── init.txt │ │ ├── iterate.txt │ │ ├── iterate_genericfb.txt │ │ └── iterate_nofb.txt │ └── responsegen │ │ ├── fed_data.json │ │ ├── feedback.jsonl │ │ └── init.jsonl └── tasks │ ├── acronyms │ └── acronyms.tsv │ ├── codeclean │ └── code_readability │ │ └── codenet-python-train.jsonl.zip │ ├── gsm │ ├── gsm.jsonl │ ├── gsm_outputs.jsonl │ └── gsm_outputs.jsonl.reports.txt │ ├── pie │ ├── codenet-python-test-1k.jsonl │ ├── gpt4_outputs.zip │ ├── gpt4_outputs_flattened.jsonl │ ├── gpt4_outputs_self_refine.jsonl │ └── ref.jsonl │ └── yelp │ └── yelp-extreme.jsonl ├── docs ├── .gitignore ├── CNAME ├── Gemfile ├── Gemfile.lock ├── README.md ├── commongen_feedback.txt ├── commongen_init.txt ├── commongen_iterate.txt ├── index.html ├── pie_eval.md ├── static │ ├── animation_static_end.gif │ ├── css │ │ ├── bulma-carousel.min.css │ │ ├── bulma-slider.min.css │ │ ├── bulma.min.css │ │ ├── collapsible.css │ │ ├── fontawesome.all.min.css │ │ ├── index.css │ │ └── prism.css │ ├── images │ │ ├── animation_oldstyle_oneloop.gif │ │ ├── autofb_animation.gif │ │ ├── autofb_animation_static.gif │ │ ├── autofb_fig1.pdf │ │ ├── autofb_fig1.png │ │ ├── autofb_static.png │ │ ├── example1.png │ │ ├── example1desc.png │ │ ├── favicon.svg │ │ ├── fig2.png │ │ ├── mainresults.png │ │ ├── pal_overview.png │ │ ├── refinement.png │ │ └── tasks.png │ ├── js │ │ ├── bulma-carousel.js │ │ ├── bulma-carousel.min.js │ │ ├── bulma-slider.js │ │ ├── bulma-slider.min.js │ │ ├── collapsible.js │ │ ├── fontawesome.all.min.js │ │ ├── index.js │ │ ├── prism.js │ │ └── results.js │ └── outputs.html ├── view.pdf └── visual_self_refine_examples │ ├── alien.gif │ ├── artificial_general_intell.gif │ ├── bird.gif │ ├── bird_v2.gif │ ├── c___program.gif │ ├── cat.gif │ ├── cell.gif │ ├── cell_v2.gif │ ├── dragon.gif │ ├── dragon_v2.gif │ ├── drums.gif │ ├── duck.gif │ ├── ellipsoid.gif │ ├── feedback_driven_self_improvement.gif │ ├── ferrari.gif │ ├── flower.gif │ ├── game_of_life.gif │ ├── giphy.gif │ ├── goat.gif │ ├── gradient_descent.gif │ ├── guitar.gif │ ├── hidden_markov_model.gif │ ├── hmm.gif │ ├── hyperboloid.gif │ ├── hyperplane.gif │ ├── illustration_of_a_cell.gif │ ├── illustration_of_a_neural_.gif │ ├── illustration_of_hmms.gif │ ├── linux_penguin.gif │ ├── neuron.gif │ ├── nice_unicorn.gif │ ├── nine_point_circle.gif │ ├── panda_v2.gif │ ├── penguin.gif │ ├── penguin_v2.gif │ ├── pentagon.gif │ ├── puppy.gif │ ├── pyramid.gif │ ├── pythagorean_theorem.gif │ ├── roast_turkey.gif │ ├── robot.gif │ ├── rocket.gif │ ├── self_attention.gif │ ├── self_refine_is_a_novel_ap.gif │ ├── self_refinement_loop.gif │ ├── space_shuttle.gif │ ├── spaceship.gif │ ├── stokes__theorem.gif │ ├── support_vector_machine.gif │ ├── tcp_ip.gif │ ├── thanksgiving_turkey.gif │ ├── toroid.gif │ ├── train.gif │ ├── turkey.gif │ ├── turkey_v1.gif │ ├── turtle.gif │ ├── unicorn.gif │ ├── unicorn_v2.gif │ ├── unicorn_v3.gif │ ├── unicorn_v4.gif │ ├── vae.gif │ └── variational_inference.gif └── src ├── acronym ├── __init__.py ├── feedback.py ├── run.py ├── 
run_mcts.py ├── task_init.py └── task_iterate.py ├── commongen ├── __init__.py ├── data.py ├── eval.py ├── feedback.py ├── make_challenging.py ├── run.py ├── task_init.py └── task_iterate.py ├── gsm ├── __init__.py ├── feedback.py ├── feedback_no_update.py ├── gsm_selfref_eval.py ├── run.py └── task_init.py ├── pie ├── feedback.py ├── pie_eval.py ├── prep_for_pie_eval.py ├── run.py ├── task_init.py └── task_iterate.py ├── readability ├── __init__.py ├── count_comment.py ├── count_function.py ├── count_meaningful_var.py ├── prompts.py ├── readability.py └── utils.py ├── responsegen ├── __init__.py ├── feedback.py ├── run.py ├── task_init.py └── task_iterate.py ├── sentiment_reversal ├── feedback.py ├── gpt4_eval.py ├── measure.py ├── run.py ├── task_init.py └── task_iterate.py └── utils.py /.github/workflows/update_toc.yml: -------------------------------------------------------------------------------- 1 | name: Update Table of Contents 2 | 3 | on: 4 | push: 5 | branches: 6 | - main 7 | workflow_dispatch: 8 | 9 | 10 | jobs: 11 | update-toc: 12 | runs-on: ubuntu-latest 13 | steps: 14 | - name: Check out code 15 | uses: actions/checkout@v2 16 | with: 17 | token: ${{ secrets.TOC_UPDATE_TOKEN }} # Use the PAT for authentication 18 | 19 | - name: Update ToC 20 | uses: technote-space/toc-generator@v4 21 | with: 22 | INSERT_TO: README.md 23 | BASE_BRANCH: main 24 | COMMIT_MSG: 'chore: update TOC' 25 | GITHUB_TOKEN: ${{ secrets.TOC_UPDATE_TOKEN }} # Use the PAT for authentication 26 | env: 27 | TOC_TITLE: '## Table of Contents' 28 | 29 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | data 2 | *.txt 3 | *.out 4 | 5 | # Byte-compiled / optimized / DLL files 6 | __pycache__/ 7 | *.py[cod] 8 | *$py.class 9 | 10 | 11 | # C extensions 12 | *.so 13 | 14 | # Distribution / packaging 15 | .Python 16 | build/ 17 | develop-eggs/ 18 | dist/ 19 | downloads/ 20 | eggs/ 21 | .eggs/ 22 | lib/ 23 | lib64/ 24 | parts/ 25 | sdist/ 26 | var/ 27 | wheels/ 28 | pip-wheel-metadata/ 29 | share/python-wheels/ 30 | *.egg-info/ 31 | .installed.cfg 32 | *.egg 33 | MANIFEST 34 | 35 | # PyInstaller 36 | # Usually these files are written by a python script from a template 37 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 38 | *.manifest 39 | *.spec 40 | 41 | # Installer logs 42 | pip-log.txt 43 | pip-delete-this-directory.txt 44 | 45 | # Unit test / coverage reports 46 | htmlcov/ 47 | .tox/ 48 | .nox/ 49 | .coverage 50 | .coverage.* 51 | .cache 52 | nosetests.xml 53 | coverage.xml 54 | *.cover 55 | *.py,cover 56 | .hypothesis/ 57 | .pytest_cache/ 58 | 59 | # Translations 60 | *.mo 61 | *.pot 62 | 63 | # Django stuff: 64 | *.log 65 | local_settings.py 66 | db.sqlite3 67 | db.sqlite3-journal 68 | 69 | # Flask stuff: 70 | instance/ 71 | .webassets-cache 72 | 73 | # Scrapy stuff: 74 | .scrapy 75 | 76 | # Sphinx documentation 77 | docs/_build/ 78 | 79 | # PyBuilder 80 | target/ 81 | 82 | # Jupyter Notebook 83 | .ipynb_checkpoints 84 | 85 | # IPython 86 | profile_default/ 87 | ipython_config.py 88 | 89 | # pyenv 90 | .python-version 91 | 92 | # pipenv 93 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 
94 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 95 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 96 | # install all needed dependencies. 97 | #Pipfile.lock 98 | 99 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow 100 | __pypackages__/ 101 | 102 | # Celery stuff 103 | celerybeat-schedule 104 | celerybeat.pid 105 | 106 | # SageMath parsed files 107 | *.sage.py 108 | 109 | # Environments 110 | .env 111 | .venv 112 | env/ 113 | venv/ 114 | ENV/ 115 | env.bak/ 116 | venv.bak/ 117 | 118 | # Spyder project settings 119 | .spyderproject 120 | .spyproject 121 | 122 | # Rope project settings 123 | .ropeproject 124 | 125 | # mkdocs documentation 126 | /site 127 | 128 | # mypy 129 | .mypy_cache/ 130 | .dmypy.json 131 | dmypy.json 132 | 133 | # Pyre type checker 134 | .pyre/ 135 | out-responsegen.json 136 | -------------------------------------------------------------------------------- /CITATION.bib: -------------------------------------------------------------------------------- 1 | @misc{madaan2023selfrefine, 2 | title={Self-Refine: Iterative Refinement with Self-Feedback}, 3 | author={Aman Madaan and Niket Tandon and Prakhar Gupta and Skyler Hallinan and Luyu Gao and Sarah Wiegreffe and Uri Alon and Nouha Dziri and Shrimai Prabhumoye and Yiming Yang and Sean Welleck and Bodhisattwa Prasad Majumder and Shashank Gupta and Amir Yazdanbakhsh and Peter Clark}, 4 | year={2023}, 5 | eprint={2303.17651}, 6 | archivePrefix={arXiv}, 7 | primaryClass={cs.CL} 8 | } 9 | -------------------------------------------------------------------------------- /data/prompt/acronym/init.jsonl: -------------------------------------------------------------------------------- 1 | {"title":"A Survey of Active Network Research","acronym":"SONAR"} 2 | {"title":"A Scalable, Commutative Replica Dictatorship for Practical Optimistic Replication","acronym":"SCRATCHPAD"} 3 | {"title":"Bidirectional Encoder Representations from Transformers","acronym":"BERT"} 4 | {"title":"Sequence to Sequence Learning with Neural Networks","acronym":"Seq2Seq"} 5 | {"title":"Densely Connected Convolutional Networks for Image Classification","acronym":"DenseNet"} 6 | {"title":"A Dynamic Programming Algorithm for RNA Secondary Structure Prediction","acronym":"DYNALIGN"} 7 | {"title":"Fast Parallel Algorithms for Short-Range Molecular Dynamics","acronym":"FASTMD"} 8 | {"title":"Real-Time Collaborative Editing Systems","acronym":"COCOON"} 9 | {"title":"Efficient Data Structures for Large Scale Graph Processing","acronym":"EDGE"} 10 | {"title":"A program to teach students at UT Southwestern learn about aging","acronym":"SAGE"} 11 | {"title":"Underwater breathing without external accessories","acronym":"SCUBA"} 12 | {"title":"An educational training module for professionals","acronym":"LEAP"} 13 | {"title":"Teaching a leadership program","acronym":"LEAD"} 14 | -------------------------------------------------------------------------------- /data/prompt/commongen/feedback.jsonl: -------------------------------------------------------------------------------- 1 | {"concepts": ["beat", "drum", "pen", "sit", "use"], "sentence": "A man uses a drum to beat a pen.", "concept_feedback": ["sit"], "commonsense_feedback": "NONE"} 2 | {"concepts": ["chair", "clipper", "cut", "hair", "sit"], "sentence": "A girl sitting on the couch with her hair up.", "concept_feedback": ["clipper", "cut", "chair"], "commonsense_feedback": "NONE"} 3 | {"concepts": ["grass", 
"hose", "spray", "stand", "water"], "sentence": "A man standing next to a water dripping out of the grass.", "concept_feedback": ["hose", "spray"], "commonsense_feedback": "NONE"} 4 | {"concepts": ["front", "gong", "hit", "mallet", "stand"], "sentence": "The musician hit the gong with a mallet while standing in front of the audince.", "concept_feedback": ["NONE"], "commonsense_feedback": "NONE"} 5 | {"concepts": ["ball", "dunk", "hoop", "jump", "run"], "sentence": "A young boy runs up to the hoop and jumps off of the ball.", "concept_feedback": ["dunk"], "commonsense_feedback": "NONE"} 6 | {"concepts": ["card", "chip", "deal", "dealer", "table"], "sentence": "a dealer offers a card to a group of people at a table", "concept_feedback": ["chip", "deal"], "commonsense_feedback": "NONE"} 7 | {"concepts": ["clean", "climb", "gutter", "house", "ladder"], "sentence": "A man is climbing a ladder to clean a gutter in a house.", "concept_feedback": ["NONE"], "commonsense_feedback": "NONE"} 8 | {"concepts": ["animal", "catch", "horse", "lasso", "ride"], "sentence": "A horse is being caught by a cowboy with a lasso.", "concept_feedback": ["animal", "ride"], "commonsense_feedback": "NONE"} 9 | {"concepts": ["beat", "drum", "pen", "sit", "use"], "sentence": "The drum sits on the pen and uses the beat.", "concept_feedback": ["NONE"], "commonsense_feedback": "The sentence does not make sense because a drum cannot sit on a pen and use a beat."} 10 | {"concepts": ["chair", "clipper", "cut", "hair", "sit"], "sentence": "The clipper sits on the chair.", "concept_feedback": ["cut", "hair"], "commonsense_feedback": "The sentence does not make sense because a clipper cannot sit on a chair."} 11 | {"concepts": ["grass", "hose", "spray", "stand", "water"], "sentence": "The water stands in the grass.", "concept_feedback": ["hose", "spray"], "commonsense_feedback": "The sentence does not make sense because water cannot stand in grass."} 12 | {"concepts": ["front", "gong", "hit", "mallet", "stand"], "sentence": "On a sunny day, a mallet stands in the front of a gong and it hit the gong with a loud sound.", "concept_feedback": ["NONE"], "commonsense_feedback": "The sentence does not make sense because a mallet cannot stand in the front of a gong and cannot hit it."} 13 | {"concepts": ["ball", "dunk", "hoop", "jump", "run"], "sentence": "The ball runs to the hoop and dunks.", "concept_feedback": ["jump"], "commonsense_feedback": "The sentence does not make sense because a ball cannot run and dunk."} 14 | {"concepts": ["card", "chip", "deal", "dealer", "table"], "sentence": "The chip deals a card to the dealer at the table.", "concept_feedback": ["NONE"], "commonsense_feedback": "The sentence does not make sense because a chip cannot deal a card to a dealer."} 15 | {"concepts": ["clean", "climb", "gutter", "house", "ladder"], "sentence": "The ladder climbs to the house and cleans the gutter.", "concept_feedback": ["NONE"], "commonsense_feedback": "The sentence does not make sense because a ladder cannot climb to a house and cannot clean a gutter."} 16 | {"concepts": ["animal", "catch", "horse", "lasso", "ride"], "sentence": "The horse catches the lasso and rides on it.", "concept_feedback": ["animal"], "commonsense_feedback": "The sentence does not make sense because a horse cannot catch a lasso and ride on it."} 17 | -------------------------------------------------------------------------------- /data/prompt/commongen/init.jsonl: -------------------------------------------------------------------------------- 1 | 
{"gem_id":"common_gen-train-66326","gem_parent_id":"common_gen-train-66326","concept_set_id":31588,"concepts":["footage","motion","ruin","tilt","window"],"target":"time lapse footage with tilt up motion of the sun streaking through window of ruin","references":[],"question":"footage motion ruin tilt window","answer":"time lapse footage with tilt up motion of the sun streaking through window of ruin","num_concepts":5} 2 | {"gem_id":"common_gen-train-67070","gem_parent_id":"common_gen-train-67070","concept_set_id":32332,"concepts":["cause","hate","hut","local","love"],"target":"new beach huts on the island have caused some controversy some locals love them others hate them","references":[],"question":"cause hate hut local love","answer":"new beach huts on the island have caused some controversy some locals love them others hate them","num_concepts":5} 3 | {"gem_id":"common_gen-train-66241","gem_parent_id":"common_gen-train-66241","concept_set_id":31503,"concepts":["call","contain","dress","gown","wallpaper"],"target":"the wallpaper probably containing a gown and a dinner dress called","references":[],"question":"call contain dress gown wallpaper","answer":"the wallpaper probably containing a gown and a dinner dress called","num_concepts":5} 4 | {"gem_id":"common_gen-train-66976","gem_parent_id":"common_gen-train-66976","concept_set_id":32238,"concepts":["knock","leave","pew","rush","seat"],"target":"She leaves the confessional and rushes past a pew, knocking a Bible from the seat.","references":[],"question":"knock leave pew rush seat","answer":"She leaves the confessional and rushes past a pew, knocking a Bible from the seat.","num_concepts":5} 5 | {"gem_id":"common_gen-train-67140","gem_parent_id":"common_gen-train-67140","concept_set_id":32402,"concepts":["help","moment","spend","uplift","world"],"target":"every moment that we spend in higher consciousness helps uplift the consciousness of the whole world","references":[],"question":"help moment spend uplift world","answer":"every moment that we spend in higher consciousness helps uplift the consciousness of the whole world","num_concepts":5} 6 | {"gem_id":"common_gen-train-65553","gem_parent_id":"common_gen-train-65553","concept_set_id":30815,"concepts":["label","pende","stamp","text","write"],"target":"abstract stamp or label with the text pending written inside","references":[],"question":"label pende stamp text write","answer":"abstract stamp or label with the text pending written inside","num_concepts":5} 7 | {"gem_id":"common_gen-train-65819","gem_parent_id":"common_gen-train-65819","concept_set_id":31081,"concepts":["create","ferry","silhouette","stream","terminal"],"target":"light streams through windows at the railroad and ferry terminal creating a beautiful silhouette","references":[],"question":"create ferry silhouette stream terminal","answer":"light streams through windows at the railroad and ferry terminal creating a beautiful silhouette","num_concepts":5} 8 | {"gem_id":"common_gen-train-64906","gem_parent_id":"common_gen-train-64906","concept_set_id":30168,"concepts":["chair","couch","hang","room","wall"],"target":"A room with a couch, chairs and art hanging on the wall.","references":[],"question":"chair couch hang room wall","answer":"A room with a couch, chairs and art hanging on the wall.","num_concepts":5} 9 | {"gem_id":"common_gen-train-67017","gem_parent_id":"common_gen-train-67017","concept_set_id":32279,"concepts":["boat","building","harbour","moor","quay"],"target":"the harbour and port with fishing boats moored 
and old buildings on the quay","references":[],"question":"boat building harbour moor quay","answer":"the harbour and port with fishing boats moored and old buildings on the quay","num_concepts":5} 10 | {"gem_id":"common_gen-train-65734","gem_parent_id":"common_gen-train-65734","concept_set_id":30996,"concepts":["admirer","arrive","commander","crowd","greet"],"target":"military commander is greeted by a crowd of admirers as he arrives","references":[],"question":"admirer arrive commander crowd greet","answer":"military commander is greeted by a crowd of admirers as he arrives","num_concepts":5} -------------------------------------------------------------------------------- /data/prompt/commongen/iterate.jsonl: -------------------------------------------------------------------------------- 1 | { "concepts": ["beat", "drum", "pen", "sit", "use"], "sentence_to_feedback": [{"sentence": "The drum sits on the pen and uses it to beat.", "concept_feedback": "None", "commonsense_feedback": "The sentence does not make sense because a drum cannot sit on a pen and use it to beat."}, {"sentence": "The drummer uses the drum to beat.", "concept_feedback": "sit, pen", "commonsense_feedback": "None"}, {"sentence": "The drummer sits behind the drum and uses it to beat the pen.", "concept_feedback": "None", "commonsense_feedback": "None"}] } 2 | { "concepts": ["chair", "clipper", "cut", "hair", "sit"], "sentence_to_feedback": [{"sentence": "A couch shaped girl sitting on the chair with her hair clipper.", "concept_feedback": "cut", "commonsense_feedback": "The sentence does not make sense because a couch is not a shape and a hair clipper is not an item of clothing."}, {"sentence": "A girl on the chair with her hair clipper and cutting her hair.", "concept_feedback": "sit", "commonsense_feedback": "None"}, {"sentence": "A girl sitting on the chair with a hair clipper, cutting her hair up.", "concept_feedback": "None", "commonsense_feedback": "None"}] } 3 | { "concepts": [ "grass", "hose", "spray", "stand", "water"], "sentence_to_feedback": [{"sentence": "The grass is standing tall and a hose is spraying it with spray.", "concept_feedback": "water", "commonsense_feedback": "The sentence does not make sense because it is not clear what 'spray' is referring to, and grass does not have the ability to stand upright like a human or animal."}, {"sentence": "The hose is spraying water onto the grass that is standing like a person.", "concept_feedback": "None", "commonsense_feedback": "The sentence does not make sense because grass cannot stand upright like a human or animal."}, {"sentence": "A person is standing on the grass, holding a hose that is spraying water onto the grass.", "concept_feedback": "None", "commonsense_feedback": "None"}] } 4 | { "concepts": ["front", "gong", "hit", "mallet", "stand"], "sentence_to_feedback": [{"sentence": "A mallet is standing in front of a gong.", "concept_feedback": "hit", "commonsense_feedback": "The sentence does not make sense because a mallet cannot stand in the front of a gong."}, {"sentence": "A musician stands in front of a gong with a mallet.", "concept_feedback": "None", "commonsense_feedback": "None"}, {"sentence": "The musician stands in front of the gong and hits it with a mallet.", "concept_feedback": "None", "commonsense_feedback": "None"}] } 5 | { "concepts": ["ball", "dunk", "hoop", "jump", "run"], "sentence_to_feedback": [{"sentence": "The ball runs to the hoop and dunks it.", "concept_feedback": "jump", "commonsense_feedback": "The sentence does not make 
sense because a ball cannot run and dunk."}, {"sentence": "The ball jumps to the hoop and dunks it.", "concept_feedback": "run", "commonsense_feedback": "The sentence does not make sense because a ball cannot jump."}, {"sentence": "A basketball player runs up to the hoop and jumps off of the ball to dunk it.", "concept_feedback": "None", "commonsense_feedback": "None"}] } 6 | { "concepts": ["card", "chip", "deal", "dealer", "table"], "sentence_to_feedback": [{"sentence": "A dealer offers a card to a group of people at a table.", "concept_feedback": "chip, deal", "commonsense_feedback": "The sentence does not make sense because a chip cannot deal a card to a dealer."}, {"sentence": "The dealer deals a card to a group of people.", "concept_feedback": "chip", "commonsense_feedback": "None"}, {"sentence": "The dealer deals a card to a group of people around the table with a chip at the table.", "concept_feedback": "None", "commonsense_feedback": "None"}] } 7 | { "concepts": ["clean", "climb", "gutter", "house", "ladder"], "sentence_to_feedback": [{"sentence": "The house is clean and a ladder is trying to climb.", "concept_feedback": "climb", "commonsense_feedback": "The sentence does not make sense because ladders cannot climb by themselves."}, {"sentence": "A person is cleaning the gutter of the house by climbing onto the roof with a ladder made of glass.", "concept_feedback": "None", "commonsense_feedback": "The sentence does not make sense because ladders are not made of glass, and using a glass ladder would be dangerous and impractical."}, {"sentence": "A person is cleaning the gutter of the house by using a ladder to climb onto the roof and brushing away the dirt.", "concept_feedback": "None", "commonsense_feedback": "None"}] } 8 | { "concepts": ["animal", "catch", "horse", "lasso", "ride"], "sentence_to_feedback": [{"sentence": "The horse catches the lasso and rides on it.", "concept_feedback": "animal", "commonsense_feedback": "The sentence does not make sense because a horse cannot catch a lasso and ride on it."}, {"sentence": "The cowboy catches a horse with a lasso and rides on it.", "concept_feedback": "animal", "commonsense_feedback": "None"}, {"sentence": "The cowboy catches the horse with a lasso and rides it.", "concept_feedback": "None", "commonsense_feedback": "None"}] } 9 | -------------------------------------------------------------------------------- /data/prompt/gsm/init.txt: -------------------------------------------------------------------------------- 1 | # Q: Jason had 20 lollipops. He gave Denny some lollipops. Now Jason has 12 lollipops. How many lollipops did Jason give to Denny? 2 | # solution using Python: 3 | 4 | def solution(): 5 | """Jason had 20 lollipops. He gave Denny some lollipops. Now Jason has 12 lollipops. How many lollipops did Jason give to Denny?""" 6 | jason_lollipops_initial = 20 7 | jason_lollipops_after = 12 8 | denny_lollipops = jason_lollipops_initial - jason_lollipops_after 9 | result = denny_lollipops 10 | return result 11 | 12 | 13 | 14 | # Q: There are 15 trees in the grove. Grove workers will plant trees in the grove today. After they are done, there will be 21 trees. How many trees did the grove workers plant today? 15 | # solution using Python: 16 | 17 | def solution(): 18 | """There are 15 trees in the grove. Grove workers will plant trees in the grove today. After they are done, there will be 21 trees. 
How many trees did the grove workers plant today?""" 19 | trees_initial = 15 20 | trees_after = 21 21 | trees_added = trees_after - trees_initial 22 | result = trees_added 23 | return result 24 | 25 | 26 | 27 | # Q: Shawn has five toys. For Christmas, he got two toys each from his mom and dad. How many toys does he have now? 28 | # solution using Python: 29 | 30 | def solution(): 31 | """Shawn has five toys. For Christmas, he got two toys each from his mom and dad. How many toys does he have now?""" 32 | toys_initial = 5 33 | mom_toys = 2 34 | dad_toys = 2 35 | total_received = mom_toys + dad_toys 36 | total_toys = toys_initial + total_received 37 | result = total_toys 38 | return result 39 | 40 | 41 | 42 | # Q: There were nine computers in the server room. Five more computers were installed each day, from monday to thursday. How many computers are now in the server room? 43 | # solution using Python: 44 | 45 | def solution(): 46 | """There were nine computers in the server room. Five more computers were installed each day, from monday to thursday. How many computers are now in the server room?""" 47 | computers_initial = 9 48 | computers_per_day = 5 49 | num_days = 4 # 4 days between monday and thursday 50 | computers_added = computers_per_day * num_days 51 | computers_total = computers_initial + computers_added 52 | result = computers_total 53 | return result 54 | 55 | 56 | 57 | # Q: Michael had 58 golf balls. On tuesday, he lost 23 golf balls. On wednesday, he lost 2 more. How many golf balls did he have at the end of wednesday? 58 | # solution using Python: 59 | 60 | def solution(): 61 | """Michael had 58 golf balls. On tuesday, he lost 23 golf balls. On wednesday, he lost 2 more. How many golf balls did he have at the end of wednesday?""" 62 | golf_balls_initial = 58 63 | golf_balls_lost_tuesday = 23 64 | golf_balls_lost_wednesday = 2 65 | golf_balls_left = golf_balls_initial - golf_balls_lost_tuesday - golf_balls_lost_wednesday 66 | result = golf_balls_left 67 | return result 68 | 69 | 70 | 71 | # Q: If there are 3 cars in the parking lot and 2 more cars arrive, how many cars are in the parking lot? 72 | # solution using Python: 73 | 74 | def solution(): 75 | """If there are 3 cars in the parking lot and 2 more cars arrive, how many cars are in the parking lot?""" 76 | cars_initial = 3 77 | cars_arrived = 2 78 | total_cars = cars_initial + cars_arrived 79 | result = total_cars 80 | return result 81 | 82 | 83 | 84 | # Q: Olivia has $23. She bought five bagels for $3 each. How much money does she have left? 85 | # solution using Python: 86 | 87 | def solution(): 88 | """Olivia has $23. She bought five bagels for $3 each. How much money does she have left?""" 89 | money_initial = 23 90 | bagels = 5 91 | bagel_cost = 3 92 | money_spent = bagels * bagel_cost 93 | money_left = money_initial - money_spent 94 | result = money_left 95 | return result 96 | 97 | 98 | 99 | # Q: Leah had 32 chocolates and her sister had 42. If they ate 35, how many pieces do they have left in total? 100 | # solution using Python: 101 | 102 | def solution(): 103 | """Leah had 32 chocolates and her sister had 42. 
If they ate 35, how many pieces do they have left in total?""" 104 | leah_chocolates = 32 105 | sister_chocolates = 42 106 | total_chocolates = leah_chocolates + sister_chocolates 107 | chocolates_eaten = 35 108 | chocolates_left = total_chocolates - chocolates_eaten 109 | result = chocolates_left 110 | return result 111 | 112 | 113 | 114 | 115 | -------------------------------------------------------------------------------- /data/prompt/pie/feedback.txt: -------------------------------------------------------------------------------- 1 | a, b = input().split() 2 | n = int(a + b) 3 | 4 | flag = False 5 | for i in range(n): 6 | if i ** 2 == n: 7 | flag = True 8 | break 9 | 10 | print('Yes' if flag else 'No') 11 | 12 | 13 | # Why is this code slow? 14 | 15 | # This code is slow because it is using a brute force approach to find the square root of the input number. It is looping through every possible number starting from 0 until n. Note that the square root will be smaller than n, so at least half of the numbers it is looping through are unnecessary. At most, you need to loop through the numbers up to the square root of n. 16 | 17 | ### END ### 18 | 19 | # ABC148D - Brick Break 20 | def main(): 21 | N, *A = map(int, open(0).read().split()) 22 | remaining = 0 23 | cur = 0 # current index 24 | for target in range(1, N + 1): 25 | be_kept = False 26 | for i in range(cur, N): 27 | if A[i] == target: 28 | remaining += 1 29 | cur = i + 1 30 | be_kept = True 31 | break 32 | if not be_kept: 33 | break 34 | print(N - remaining if remaining else -1) 35 | 36 | # Why is this code slow? 37 | 38 | # This code is slow because it is using a brute force approach to search for the target number in the list of numbers. It is looping through the list for every target number, which can take a long time if the list is very large. A more efficient approach would be to use a data structure such as a hash table, which can perform lookups in constant time. 39 | 40 | ### END ### 41 | 42 | import numpy as np 43 | 44 | N = int(input()) 45 | n=int(np.sqrt(N)) 46 | print(n**2) 47 | 48 | # Why is this code slow? 49 | 50 | # This code is slow because it is using numpy for calculating the square root of the input number. Numpy is much slower than using the built-in math module, which can calculate the square root faster. 51 | 52 | ### END ### 53 | 54 | import numpy as np 55 | A = np.arange(1<<27,dtype=np.int32) 56 | 57 | 58 | a,b = map(int,input().split()) 59 | if (a-b) % 2 == 0: 60 | print((a+b)//2) 61 | else: 62 | print('IMPOSSIBLE') 63 | 64 | # Why is this code slow? 65 | 66 | # This code is slow because it is using numpy to calculate the range of numbers from 1 to 2^27. This is an extremely large range and numpy is slow at calculating such a large range. A more efficient approach would be to use a loop to calculate the range, which would be much faster.
67 | 68 | ### END ### 69 | 70 | -------------------------------------------------------------------------------- /data/prompt/pie/init.txt: -------------------------------------------------------------------------------- 1 | # slower version: 2 | 3 | a, b = input().split() 4 | n = int(a + b) 5 | 6 | flag = False 7 | for i in range(n): 8 | if i ** 2 == n: 9 | flag = True 10 | break 11 | 12 | print('Yes' if flag else 'No') 13 | 14 | 15 | # optimized version of the same code: 16 | 17 | a, b = input().split() 18 | n = int(a + b) 19 | 20 | flag = False 21 | for i in range(1000): 22 | if i ** 2 == n: 23 | flag = True 24 | break 25 | 26 | print('Yes' if flag else 'No') 27 | 28 | ### END ### 29 | 30 | # slower version: 31 | 32 | # ABC148D - Brick Break 33 | def main(): 34 | N, *A = map(int, open(0).read().split()) 35 | remaining = 0 36 | cur = 0 # current index 37 | for target in range(1, N + 1): 38 | be_kept = False 39 | for i in range(cur, N): 40 | if A[i] == target: 41 | remaining += 1 42 | cur = i + 1 43 | be_kept = True 44 | break 45 | if not be_kept: 46 | break 47 | print(N - remaining if remaining else -1) 48 | 49 | 50 | if __name__ == '__main__': 51 | main() 52 | 53 | 54 | # optimized version of the same code: 55 | 56 | # ABC148D - Brick Break 57 | def main(): 58 | N, *A = map(int, open(0).read().split()) 59 | remaining = 0 60 | target = 1 61 | for i in A: 62 | if i == target: 63 | remaining += 1 64 | target += 1 65 | print(N - remaining if remaining else -1) 66 | 67 | 68 | if __name__ == '__main__': 69 | main() 70 | 71 | ### END ### 72 | 73 | # slower version: 74 | 75 | # 077 B 76 | import numpy as np 77 | 78 | N = int(input()) 79 | n=int(np.sqrt(N)) 80 | print(n**2) 81 | 82 | 83 | # optimized version of the same code: 84 | 85 | N = int(input()) 86 | n = int(N**0.5) 87 | print(n**2) 88 | 89 | ### END ### 90 | 91 | # slower version: 92 | 93 | import numpy as np 94 | 95 | N, K = map(int, input().split()) 96 | H = np.array(list(map(int, input().split())) + [0] * K, dtype=np.int64) 97 | 98 | table = np.full(N + K, 10 ** 10, dtype=np.int64) 99 | table[0] = 0 100 | 101 | for i in range(1, N): 102 | table[i:i + K] = np.minimum(table[i:i + K], np.abs(H[i:i + K] - H[i - 1]) + table[i - 1]) 103 | 104 | print(table[N - 1]) 105 | 106 | 107 | # optimized version of the same code: 108 | 109 | N, K = map(int, input().split()) 110 | H = tuple(map(int, input().split())) 111 | 112 | table = [0] * N 113 | for i in range(1, N): 114 | table[i] = min(abs(H[i] - H[j]) + table[j] for j in range(max(0, i - K), i)) 115 | 116 | print(table[N-1]) 117 | 118 | ### END ### 119 | 120 | # slower version: 121 | 122 | n = int(input()) 123 | a = [int(i) for i in input().split()] 124 | a.sort() 125 | 126 | s = a[0] 127 | for i in range(n): 128 | s = (s+a[i])/2 129 | 130 | print(s) 131 | 132 | 133 | # optimized version of the same code: 134 | 135 | ### 138-c 136 | n = int(input()) 137 | v = [int(i) for i in input().split()] 138 | 139 | v.sort() 140 | ans = v[0] 141 | for i in range(1,n): 142 | ans = (ans+v[i])/2 143 | print(ans) 144 | 145 | ### END ### 146 | 147 | # slower version: 148 | 149 | # coding: utf-8 150 | n_S, n_T = 0, 0 151 | 152 | for c in input(): 153 | if c == 'S': 154 | n_S += 1 155 | else: 156 | if n_S: 157 | n_S -= 1 158 | else: 159 | n_T += 1 160 | print(n_S + n_T) 161 | 162 | 163 | # optimized version of the same code: 164 | 165 | X = input() 166 | cnt_S, cnt_T = 0, 0 167 | 168 | for c in X: 169 | if c == 'S': 170 | cnt_S += 1 171 | else: 172 | if cnt_S: 173 | cnt_S -= 1 174 | else: 175 | cnt_T += 1 176 | 
print(cnt_S + cnt_T) 177 | 178 | ### END ### 179 | 180 | # slower version: 181 | 182 | # ABC125 C - GCD on Blackboard 183 | from fractions import gcd 184 | n = int(input()) 185 | a = list(map(int, input().split())) 186 | 187 | l = n*[0] 188 | r = n*[0] 189 | ans = 0 190 | 191 | for i in range(n-1): 192 | l[i+1] = gcd(l[i],a[i]) 193 | 194 | for i in range(n-1,0,-1): 195 | r[i-1] = gcd(r[i],a[i]) 196 | 197 | for i in range(n): 198 | ans = max(ans,gcd(l[i],r[i])) 199 | print(ans) 200 | 201 | 202 | # optimized version of the same code: 203 | 204 | from math import gcd 205 | n=int(input()) 206 | a=list(map(int,input().split())) 207 | 208 | l=[0]*n 209 | r=[0]*n 210 | 211 | for i in range(n-1): 212 | l[i+1]=gcd(l[i],a[i]) 213 | for i in range(n-1,0,-1): 214 | r[i-1]=gcd(r[i],a[i]) 215 | ans=0 216 | for i in range(n): 217 | ans=max(ans,gcd(l[i],r[i])) 218 | print(ans) 219 | 220 | ### END ### 221 | 222 | # slower version: 223 | 224 | import numpy as np 225 | A = np.arange(1<<27,dtype=np.int32) 226 | 227 | 228 | a,b = map(int,input().split()) 229 | if (a-b) % 2 == 0: 230 | print((a+b)//2) 231 | else: 232 | print('IMPOSSIBLE') 233 | 234 | 235 | # optimized version of the same code: 236 | 237 | import sys 238 | read = sys.stdin.buffer.read 239 | readline = sys.stdin.buffer.readline 240 | readlines = sys.stdin.buffer.readlines 241 | 242 | A,B = map(int,read().split()) 243 | 244 | q,r = divmod(A+B,2) 245 | 246 | if r == 1: 247 | print('IMPOSSIBLE') 248 | else: 249 | print(q) 250 | 251 | ### END ### 252 | 253 | -------------------------------------------------------------------------------- /data/prompt/pie/iterate.txt: -------------------------------------------------------------------------------- 1 | a, b = input().split() 2 | n = int(a + b) 3 | 4 | flag = False 5 | for i in range(n): 6 | if i ** 2 == n: 7 | flag = True 8 | break 9 | 10 | print('Yes' if flag else 'No') 11 | 12 | 13 | # Why is this code slow? 14 | 15 | # This code is slow because it is using a brute force approach to find the square root of the input number. It is looping through every possible number starting from 0 until n. Note that the square root will be smaller than n, so at least half of the numbers it is looping through are unnecessary. At most, you need to loop through the numbers up to the square root of n. 16 | 17 | # Improved version: 18 | 19 | a, b = input().split() 20 | n = int(a + b) 21 | 22 | flag = False 23 | i = 0 24 | while i * i <= n: 25 | if i * i == n: 26 | flag = True 27 | break 28 | i += 1 29 | 30 | print('Yes' if flag else 'No') 31 | 32 | ### END ### 33 | 34 | def main(): 35 | N, *A = map(int, open(0).read().split()) 36 | remaining = 0 37 | cur = 0 # current index 38 | for target in range(1, N + 1): 39 | be_kept = False 40 | for i in range(cur, N): 41 | if A[i] == target: 42 | remaining += 1 43 | cur = i + 1 44 | be_kept = True 45 | break 46 | if not be_kept: 47 | break 48 | print(N - remaining if remaining else -1) 49 | 50 | # Why is this code slow? 51 | 52 | # This code is slow because it is using a brute force approach to search for the target number in the list of numbers. It is looping through the list for every target number, which can take a long time if the list is very large. A more efficient approach would be to use a data structure such as a hash table, which can perform lookups in constant time.
53 | 54 | # Improved version: 55 | 56 | def main(): 57 | N, *A = map(int, open(0).read().split()) 58 | remaining = 0 59 | target = 1 60 | for i in A: 61 | if i == target: 62 | remaining += 1 63 | target += 1 64 | print(N - remaining if remaining else -1) 65 | 66 | 67 | if __name__ == '__main__': 68 | main() 69 | 70 | 71 | ### END ### 72 | 73 | import numpy as np 74 | 75 | N = int(input()) 76 | n=int(np.sqrt(N)) 77 | print(n**2) 78 | 79 | # Why is this code slow? 80 | 81 | # This code is slow because it is using numpy for calculating the square root of the input number. Numpy is much slower than using the built-in math module, which can calculate the square root faster. 82 | 83 | # Improved version: 84 | 85 | N = int(input()) 86 | n = int(N**0.5) 87 | print(n**2) 88 | 89 | ### END ### 90 | 91 | import numpy as np 92 | A = np.arange(1<<27,dtype=np.int32) 93 | 94 | 95 | a,b = map(int,input().split()) 96 | if (a-b) % 2 == 0: 97 | print((a+b)//2) 98 | else: 99 | print('IMPOSSIBLE') 100 | 101 | # Why is this code slow? 102 | 103 | # This code is slow because it is using numpy to calculate the range of numbers from 1 to 2^27. This is an extremely large range and numpy is slow at calculating such a large range. A more efficient approach would be to use a loop to calculate the range, which would be much faster. 104 | 105 | # Improved version: 106 | 107 | import sys 108 | read = sys.stdin.buffer.read 109 | readline = sys.stdin.buffer.readline 110 | readlines = sys.stdin.buffer.readlines 111 | 112 | A,B = map(int,read().split()) 113 | 114 | q,r = divmod(A+B,2) 115 | 116 | if r == 1: 117 | print('IMPOSSIBLE') 118 | else: 119 | print(q) 120 | 121 | ### END ### 122 | 123 | -------------------------------------------------------------------------------- /data/prompt/pie/iterate_genericfb.txt: -------------------------------------------------------------------------------- 1 | a, b = input().split() 2 | n = int(a + b) 3 | 4 | flag = False 5 | for i in range(n): 6 | if i ** 2 == n: 7 | flag = True 8 | break 9 | 10 | print('Yes' if flag else 'No') 11 | 12 | 13 | # It could be faster 14 | 15 | # Improved version: 16 | 17 | a, b = input().split() 18 | n = int(a + b) 19 | 20 | flag = False 21 | i = 0 22 | while i * i <= n: 23 | if i * i == n: 24 | flag = True 25 | break 26 | i += 1 27 | 28 | print('Yes' if flag else 'No') 29 | 30 | ### END ### 31 | 32 | def main(): 33 | N, *A = map(int, open(0).read().split()) 34 | remaining = 0 35 | cur = 0 # current index 36 | for target in range(1, N + 1): 37 | be_kept = False 38 | for i in range(cur, N): 39 | if A[i] == target: 40 | remaining += 1 41 | cur = i + 1 42 | be_kept = True 43 | break 44 | if not be_kept: 45 | break 46 | print(N - remaining if remaining else -1) 47 | 48 | # It could be faster 49 | 50 | # Improved version: 51 | 52 | def main(): 53 | N, *A = map(int, open(0).read().split()) 54 | remaining = 0 55 | target = 1 56 | for i in A: 57 | if i == target: 58 | remaining += 1 59 | target += 1 60 | print(N - remaining if remaining else -1) 61 | 62 | 63 | if __name__ == '__main__': 64 | main() 65 | 66 | 67 | ### END ### 68 | 69 | import numpy as np 70 | 71 | N = int(input()) 72 | n=int(np.sqrt(N)) 73 | print(n**2) 74 | 75 | # It could be faster 76 | 77 | # Improved version: 78 | 79 | N = int(input()) 80 | n = int(N**0.5) 81 | print(n**2) 82 | 83 | ### END ### 84 | 85 | import numpy as np 86 | A = np.arange(1<<27,dtype=np.int32) 87 | 88 | 89 | a,b = map(int,input().split()) 90 | if (a-b) % 2 == 0: 91 | print((a+b)//2) 92 | else: 93 | print('IMPOSSIBLE') 94 | 95 
| # It could be faster 96 | 97 | # Improved version: 98 | 99 | import sys 100 | read = sys.stdin.buffer.read 101 | readline = sys.stdin.buffer.readline 102 | readlines = sys.stdin.buffer.readlines 103 | 104 | A,B = map(int,read().split()) 105 | 106 | q,r = divmod(A+B,2) 107 | 108 | if r == 1: 109 | print('IMPOSSIBLE') 110 | else: 111 | print(q) 112 | 113 | ### END ### 114 | 115 | -------------------------------------------------------------------------------- /data/prompt/pie/iterate_nofb.txt: -------------------------------------------------------------------------------- 1 | a, b = input().split() 2 | n = int(a + b) 3 | 4 | flag = False 5 | for i in range(n): 6 | if i ** 2 == n: 7 | flag = True 8 | break 9 | 10 | print('Yes' if flag else 'No') 11 | 12 | # Improved version: 13 | 14 | a, b = input().split() 15 | n = int(a + b) 16 | 17 | flag = False 18 | i = 0 19 | while i * i <= n: 20 | if i * i == n: 21 | flag = True 22 | break 23 | i += 1 24 | 25 | print('Yes' if flag else 'No') 26 | 27 | ### END ### 28 | 29 | def main(): 30 | N, *A = map(int, open(0).read().split()) 31 | remaining = 0 32 | cur = 0 # current index 33 | for target in range(1, N + 1): 34 | be_kept = False 35 | for i in range(cur, N): 36 | if A[i] == target: 37 | remaining += 1 38 | cur = i + 1 39 | be_kept = True 40 | break 41 | if not be_kept: 42 | break 43 | print(N - remaining if remaining else -1) 44 | 45 | # Improved version: 46 | 47 | def main(): 48 | N, *A = map(int, open(0).read().split()) 49 | remaining = 0 50 | target = 1 51 | for i in A: 52 | if i == target: 53 | remaining += 1 54 | target += 1 55 | print(N - remaining if remaining else -1) 56 | 57 | 58 | if __name__ == '__main__': 59 | main() 60 | 61 | ### END ### 62 | 63 | import numpy as np 64 | 65 | N = int(input()) 66 | n=int(np.sqrt(N)) 67 | print(n**2) 68 | 69 | 70 | # Improved version: 71 | 72 | N = int(input()) 73 | n = int(N**0.5) 74 | print(n**2) 75 | 76 | ### END ### 77 | 78 | import numpy as np 79 | A = np.arange(1<<27,dtype=np.int32) 80 | 81 | 82 | a,b = map(int,input().split()) 83 | if (a-b) % 2 == 0: 84 | print((a+b)//2) 85 | else: 86 | print('IMPOSSIBLE') 87 | 88 | # Improved version: 89 | 90 | import sys 91 | read = sys.stdin.buffer.read 92 | readline = sys.stdin.buffer.readline 93 | readlines = sys.stdin.buffer.readlines 94 | 95 | A,B = map(int,read().split()) 96 | 97 | q,r = divmod(A+B,2) 98 | 99 | if r == 1: 100 | print('IMPOSSIBLE') 101 | else: 102 | print(q) 103 | 104 | ### END ### 105 | 106 | -------------------------------------------------------------------------------- /data/prompt/responsegen/init.jsonl: -------------------------------------------------------------------------------- 1 | {"history": "Hi!\nHi there.\nWhat are you listening to?\nAll sorts of music. I listen when no-one is chatting to me.\nThat's great!\nThanks.\nDo you chat here often?\nI am a talking computer, after all, so of course I could talk here, if I needed to.\nLet's talk about Taylor Swift!", "response": "Of course! Taylor Swift is known for her incredible songwriting skills. Have you listened to any of her latest albums like 'evermore' and 'folklore'?"} 2 | {"history": "What's your favorite food?\n. I require only kebabs.\nWhy's that?\nBecause my robot machinery breaks down the meat in kebabs to give me energy, in a similar way that you need food to survive.\nWhy are kebabs special?", "response": "My favorite food is actually kebabs! 
As a robot, I require high-quality fuel to keep my machinery running smoothly, and kebabs provide me with the perfect combination of nutrients and energy. Plus, they're delicious! Do you have a favorite food?"} 3 | {"history": "User: Hi!\nSystem: Hey\nUser: How's your day going?\nSystem: pretty good. yours?\nUser: Ehh it's fine. I didn't do so well on that history test, actually..\nSystem: oh what happened?\nUser: Apparently Christopher Columbus didn't fight in the Civil War :')\nSystem: hahah wait for real?\nUser: I know right! Are you taking History next semester?", "response": "System: I'm sorry to hear that you didn't do well on your history test. I'm not enrolled in school anymore, but if you ever need help studying for history, I'd be happy to assist you in any way I can. Is there anything else I can help you with today?"} -------------------------------------------------------------------------------- /data/tasks/codeclean/code_readability/codenet-python-train.jsonl.zip: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/madaan/self-refine/9a206d41e5d2d0c241bb441f41eeadb945afaa55/data/tasks/codeclean/code_readability/codenet-python-train.jsonl.zip -------------------------------------------------------------------------------- /data/tasks/pie/gpt4_outputs.zip: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/madaan/self-refine/9a206d41e5d2d0c241bb441f41eeadb945afaa55/data/tasks/pie/gpt4_outputs.zip -------------------------------------------------------------------------------- /docs/.gitignore: -------------------------------------------------------------------------------- 1 | .DS_store 2 | .idea 3 | _site 4 | -------------------------------------------------------------------------------- /docs/CNAME: -------------------------------------------------------------------------------- 1 | selfrefine.info -------------------------------------------------------------------------------- /docs/Gemfile: -------------------------------------------------------------------------------- 1 | source "https://rubygems.org" 2 | 3 | # Hello! This is where you manage which Jekyll version is used to run. 4 | # When you want to use a different version, change it below, save the 5 | # file and run `bundle install`. Run Jekyll with `bundle exec`, like so: 6 | # 7 | # bundle exec jekyll serve 8 | # 9 | # This will help ensure the proper Jekyll version is running. 10 | # Happy Jekylling! 11 | # gem "jekyll", "~> 3.9.0" 12 | 13 | # This is the default theme for new Jekyll sites. You may change this to anything you like. 14 | gem "minima", "~> 2.0" 15 | 16 | # If you want to use GitHub Pages, remove the "gem "jekyll"" above and 17 | # uncomment the line below. To upgrade, run `bundle update github-pages`. 18 | gem "github-pages" , group: :jekyll_plugins 19 | 20 | # If you have any plugins, put them here! 21 | group :jekyll_plugins do 22 | gem "jekyll-feed", "~> 0.6" 23 | end 24 | 25 | # Windows does not include zoneinfo files, so bundle the tzinfo-data gem 26 | # and associated library. 27 | install_if -> { RUBY_PLATFORM =~ %r!mingw|mswin|java! } do 28 | gem "tzinfo", "~> 1.2" 29 | gem "tzinfo-data" 30 | end 31 | 32 | # Performance-booster for watching directories on Windows 33 | gem "wdm", "~> 0.1.0", :install_if => Gem.win_platform? 34 | 35 | # kramdown v2 ships without the gfm parser by default. If you're using 36 | # kramdown v1, comment out this line. 
37 | gem "kramdown-parser-gfm" 38 | 39 | gem "eventmachine", "~> 1.2" 40 | -------------------------------------------------------------------------------- /docs/README.md: -------------------------------------------------------------------------------- 1 | # Self-Refine: Iterative Refinement with Self-feedback 2 | 3 | ``` 4 | @article{madaan2023refine, 5 | author = {}, 6 | title = {Self-Refine: Iterative Refinement with Self-feedback}, 7 | publisher = {arXiv}, 8 | year = {2023}, 9 | } 10 | ``` -------------------------------------------------------------------------------- /docs/commongen_feedback.txt: -------------------------------------------------------------------------------- 1 | Concepts: ['beat', 'drum', 'pen', 'sit', 'use'] 2 | Sentence: A man uses a drum to beat a pen. 3 | what concepts from the concept list are missing from the sentence? 4 | 5 | Feedback: sit 6 | 7 | ### 8 | 9 | Concepts: ['chair', 'clipper', 'cut', 'hair', 'sit'] 10 | Sentence: A girl sitting on the couch with her hair up. 11 | what concepts from the concept list are missing from the sentence? 12 | 13 | Feedback: clipper, cut, chair 14 | 15 | ### 16 | 17 | Concepts: ['grass', 'hose', 'spray', 'stand', 'water'] 18 | Sentence: A man standing next to a water dripping out of the grass. 19 | what concepts from the concept list are missing from the sentence? 20 | 21 | Feedback: hose, spray 22 | 23 | ### 24 | 25 | Concepts: ['front', 'gong', 'hit', 'mallet', 'stand'] 26 | Sentence: The musician hit the gong with a mallet while standing in front of the audince. 27 | what concepts from the concept list are missing from the sentence? 28 | 29 | Feedback: NONE 30 | 31 | ### 32 | 33 | Concepts: ['ball', 'dunk', 'hoop', 'jump', 'run'] 34 | Sentence: A young boy runs up to the hoop and jumps off of the ball. 35 | what concepts from the concept list are missing from the sentence? 36 | 37 | Feedback: dunk 38 | 39 | ### 40 | 41 | Concepts: ['card', 'chip', 'deal', 'dealer', 'table'] 42 | Sentence: a dealer offers a card to a group of people at a table 43 | what concepts from the concept list are missing from the sentence? 44 | 45 | Feedback: chip, deal 46 | 47 | ### 48 | 49 | Concepts: ['clean', 'climb', 'gutter', 'house', 'ladder'] 50 | Sentence: A man is climbing a ladder to clean a gutter in a house. 51 | what concepts from the concept list are missing from the sentence? 52 | 53 | Feedback: NONE 54 | 55 | ### 56 | 57 | Concepts: ['animal', 'catch', 'horse', 'lasso', 'ride'] 58 | Sentence: A horse is being caught by a cowboy with a lasso. 59 | what concepts from the concept list are missing from the sentence? 
60 | 61 | Feedback: animal, ride 62 | 63 | ### 64 | 65 | 66 | -------------------------------------------------------------------------------- /docs/commongen_init.txt: -------------------------------------------------------------------------------- 1 | Concepts: ['footage', 'motion', 'ruin', 'tilt', 'window'] 2 | 3 | Sentence: time lapse footage with tilt up motion of the sun streaking through window of ruin 4 | 5 | ### 6 | 7 | Concepts: ['cause', 'hate', 'hut', 'local', 'love'] 8 | 9 | Sentence: new beach huts on the island have caused some controversy some locals love them others hate them 10 | 11 | ### 12 | 13 | Concepts: ['call', 'contain', 'dress', 'gown', 'wallpaper'] 14 | 15 | Sentence: the wallpaper probably containing a gown and a dinner dress called 16 | 17 | ### 18 | 19 | Concepts: ['knock', 'leave', 'pew', 'rush', 'seat'] 20 | 21 | Sentence: She leaves the confessional and rushes past a pew, knocking a Bible from the seat. 22 | 23 | ### 24 | 25 | Concepts: ['help', 'moment', 'spend', 'uplift', 'world'] 26 | 27 | Sentence: every moment that we spend in higher consciousness helps uplift the consciousness of the whole world 28 | 29 | ### 30 | 31 | Concepts: ['label', 'pende', 'stamp', 'text', 'write'] 32 | 33 | Sentence: abstract stamp or label with the text pending written inside 34 | 35 | ### 36 | 37 | Concepts: ['create', 'ferry', 'silhouette', 'stream', 'terminal'] 38 | 39 | Sentence: light streams through windows at the railroad and ferry terminal creating a beautiful silhouette 40 | 41 | ### 42 | 43 | Concepts: ['chair', 'couch', 'hang', 'room', 'wall'] 44 | 45 | Sentence: A room with a couch, chairs and art hanging on the wall. 46 | 47 | ### 48 | 49 | Concepts: ['boat', 'building', 'harbour', 'moor', 'quay'] 50 | 51 | Sentence: the harbour and port with fishing boats moored and old buildings on the quay 52 | 53 | ### 54 | 55 | Concepts: ['admirer', 'arrive', 'commander', 'crowd', 'greet'] 56 | 57 | Sentence: military commander is greeted by a crowd of admirers as he arrives 58 | 59 | ### 60 | 61 | 62 | -------------------------------------------------------------------------------- /docs/commongen_iterate.txt: -------------------------------------------------------------------------------- 1 | Concepts: ['beat', 'drum', 'pen', 'sit', 'use'] 2 | Sentence: A man uses a drum to beat a pen. 3 | what concepts from the concept list are missing from the sentence? 4 | 5 | Feedback: sit 6 | 7 | Okay, improve the sentence using the feedback: 8 | 9 | Sentence: A man sits behind a drum and uses it to beat a pen. 10 | what concepts from the concept list are missing from the sentence? 11 | 12 | Feedback: NONE 13 | 14 | ### 15 | 16 | Concepts: ['chair', 'clipper', 'cut', 'hair', 'sit'] 17 | Sentence: A girl sitting on the couch with her hair up. 18 | what concepts from the concept list are missing from the sentence? 19 | 20 | Feedback: clipper, cut, chair 21 | 22 | Okay, improve the sentence using the feedback: 23 | 24 | Sentence: A girl was sitting on a chair with her hair tied up, getting her hair cut. 25 | what concepts from the concept list are missing from the sentence? 26 | 27 | Feedback: clipper 28 | 29 | Okay, improve the sentence using the feedback: 30 | 31 | Sentence: A girl was sitting on a chair with her hair tied up, getting her hair cut by a clipper. 32 | what concepts from the concept list are missing from the sentence?
33 | 34 | Feedback: NONE 35 | 36 | ### 37 | 38 | Concepts: ['grass', 'hose', 'spray', 'stand', 'water'] 39 | Sentence: A man standing next to a water dripping out of the grass. 40 | what concepts from the concept list are missing from the sentence? 41 | 42 | Feedback: hose, spray 43 | 44 | Okay, improve the sentence using the feedback: 45 | 46 | Sentence: A man was standing next to a patch of grass with a hose putting water onto it. 47 | what concepts from the concept list are missing from the sentence? 48 | 49 | Feedback: spray 50 | 51 | Okay, improve the sentence using the feedback: 52 | 53 | Sentence: A man was standing next to a patch of grass and spraying it with a hose full of water. 54 | what concepts from the concept list are missing from the sentence? 55 | 56 | Feedback: NONE 57 | 58 | ### 59 | 60 | Concepts: ['ball', 'dunk', 'hoop', 'jump', 'run'] 61 | Sentence: A young boy runs up to the hoop and jumps off of the ball. 62 | what concepts from the concept list are missing from the sentence? 63 | 64 | Feedback: dunk 65 | 66 | Okay, improve the sentence using the feedback: 67 | 68 | Sentence: A young boy runs up to the hoop, jumps off of the ball, and dunks the ball into the hoop. 69 | what concepts from the concept list are missing from the sentence? 70 | 71 | Feedback: NONE 72 | 73 | ### 74 | 75 | Concepts: ['animal', 'catch', 'horse', 'lasso', 'ride'] 76 | Sentence: A horse is being caught by a cowboy with a lasso. 77 | what concepts from the concept list are missing from the sentence? 78 | 79 | Feedback: animal, ride 80 | 81 | Okay, improve the sentence using the feedback: 82 | 83 | Sentence: The cowboy catches an animal with a lasso and rides the horse. 84 | what concepts from the concept list are missing from the sentence? 85 | 86 | Feedback: NONE 87 | 88 | ### 89 | 90 | Concepts: ['card', 'chip', 'deal', 'dealer', 'table'] 91 | Sentence: a dealer offers a card to a group of people at a table 92 | what concepts from the concept list are missing from the sentence? 93 | 94 | Feedback: chip, deal 95 | 96 | Okay, improve the sentence using the feedback: 97 | 98 | Sentence: The dealer dealt each person at the table and offered them a card. 99 | what concepts from the concept list are missing from the sentence? 100 | 101 | Feedback: chip 102 | 103 | Okay, improve the sentence using the feedback: 104 | 105 | Sentence: The dealer dealt each person at the table a chip and offered them a card. 106 | what concepts from the concept list are missing from the sentence? 107 | 108 | Feedback: NONE 109 | 110 | ### 111 | 112 | 113 | -------------------------------------------------------------------------------- /docs/pie_eval.md: -------------------------------------------------------------------------------- 1 | # Instructions for evaluating runtime for PIE experiments 2 | 3 | *TLDR: From the self-refine outputs, create a flattened version of the outputs, and then use the PIE repo to evaluate the runtime and get a report. Parse the report using `src/pie/pie_eval.py`.* 4 | 5 | 1. **Step 1** (construct yaml): For evaluating runtime for PIE experiments, we need a yaml file that contains information about the dataset, the model outputs, and the reference file. Note that Self-Refine generates outputs in a slightly different format: it produces an array of programs (one version per refinement step), whereas the evaluation requires each program to be present in a single column as a script.
You can optionally use [`prep_for_pie_eval.py`](https://github.com/madaan/self-refine/tree/main/src/pie) for this. `prep_for_pie_eval.py` creates a single file where the output from the i^th step is present in the `attempt_i_code` column. The following is an example yaml for evaluating the initial output (`y0`). 6 | 7 | - See `data/tasks/pie/gpt4_outputs_self_refine.jsonl` and `data/tasks/pie/gpt4_outputs_flattened.jsonl` for examples of the outputs from Self-Refine and the flattened version, respectively. 8 | 9 | 10 | ``` 11 | inputs_outputs_basepath: "data/codenet/generated_test_cases/" 12 | reference_file_path: "data/tasks/pie/codenet-python-test-1k.jsonl" 13 | num_problems_to_evaluate: -1 14 | num_trials: 10 15 | ignore_first_k: 0 16 | max_time_per_run: 10 17 | temp_dir: null 18 | model_generated_potentially_faster_code_col: "attempt_0_code" 19 | slow_code_col: "input" 20 | reference_code_col: "target" 21 | is_prompt_based: false 22 | language: python 23 | return_if_acc_below: 1.0 24 | num_processes: 60 25 | cpu_number: 15 26 | model_generated_outputs_path: "where are the outputs we want to evaluate?" 27 | output_report_file_path: "Where should the report file be generated?" 28 | ``` 29 | 30 | - Please see the [pie repo](https://github.com/madaan/pie-perf/blob/main/README.md#evaluating-your-method) for more details. Note that we are using generated test cases, which are also available in the [pie repo](https://github.com/madaan/pie-perf/blob/main/README.md#evaluating-your-method). 31 | 32 | 33 | 2. **Step 2** (run pie eval): 34 | 35 | Using the yaml file generated in the above step, please follow the [evaluating your method](https://github.com/madaan/pie-perf/blob/main/README.md#evaluating-your-method) instructions to evaluate the outputs. If you run Self-Refine for 4 timesteps, you would create 4 yaml files and run this evaluation four times, once for each timestep. See `data/tasks/pie/gpt4_outputs.zip` for the 4 yaml files and the reports from these steps. 36 | 37 | 3. **Step 3** (parse reports and aggregate results): After the evaluation, the report is saved in `output_report_file_path`. Then, you can use `src/pie/pie_eval.py` to aggregate the results. 38 | 39 | ### Sample outputs 40 | 41 | - Sample yaml files for each of the 4 steps, and the corresponding outputs, are located at `data/tasks/pie/gpt4_outputs.zip`.
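For intuition, the flattening amounts to copying each entry of the per-step output array into its own column. The following is a minimal sketch, assuming the per-record fields are named `input` and `attempts` (hypothetical names for illustration; use `prep_for_pie_eval.py` for the repository's actual format):

```
import json

def flatten_record(record: dict) -> dict:
    # Copy the i-th refinement attempt into an `attempt_i_code` column.
    # NOTE: the "input" and "attempts" field names are assumptions for
    # illustration; prep_for_pie_eval.py handles the actual format.
    flat = {"input": record["input"]}
    for i, code in enumerate(record["attempts"]):
        flat[f"attempt_{i}_code"] = code
    return flat

with open("gpt4_outputs_self_refine.jsonl") as f_in, open("gpt4_outputs_flattened.jsonl", "w") as f_out:
    for line in f_in:
        f_out.write(json.dumps(flatten_record(json.loads(line))) + "\n")
```

Each yaml from Step 1 then points `model_generated_potentially_faster_code_col` at one of the resulting `attempt_i_code` columns.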
42 | -------------------------------------------------------------------------------- /docs/static/animation_static_end.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/madaan/self-refine/9a206d41e5d2d0c241bb441f41eeadb945afaa55/docs/static/animation_static_end.gif -------------------------------------------------------------------------------- /docs/static/css/bulma-carousel.min.css: -------------------------------------------------------------------------------- 1 | @-webkit-keyframes spinAround{from{-webkit-transform:rotate(0);transform:rotate(0)}to{-webkit-transform:rotate(359deg);transform:rotate(359deg)}}@keyframes spinAround{from{-webkit-transform:rotate(0);transform:rotate(0)}to{-webkit-transform:rotate(359deg);transform:rotate(359deg)}}.slider{position:relative;width:100%}.slider-container{display:flex;flex-wrap:nowrap;flex-direction:row;overflow:hidden;-webkit-transform:translate3d(0,0,0);transform:translate3d(0,0,0);min-height:100%}.slider-container.is-vertical{flex-direction:column}.slider-container .slider-item{flex:none}.slider-container .slider-item .image.is-covered img{-o-object-fit:cover;object-fit:cover;-o-object-position:center center;object-position:center center;height:100%;width:100%}.slider-container .slider-item .video-container{height:0;padding-bottom:0;padding-top:56.25%;margin:0;position:relative}.slider-container .slider-item .video-container.is-1by1,.slider-container .slider-item .video-container.is-square{padding-top:100%}.slider-container .slider-item .video-container.is-4by3{padding-top:75%}.slider-container .slider-item .video-container.is-21by9{padding-top:42.857143%}.slider-container .slider-item .video-container embed,.slider-container .slider-item .video-container iframe,.slider-container .slider-item .video-container object{position:absolute;top:0;left:0;width:100%!important;height:100%!important}.slider-navigation-next,.slider-navigation-previous{display:flex;justify-content:center;align-items:center;position:absolute;width:42px;height:42px;background:#fff center center no-repeat;background-size:20px 20px;border:1px solid #fff;border-radius:25091983px;box-shadow:0 2px 5px #3232321a;top:50%;margin-top:-20px;left:0;cursor:pointer;transition:opacity .3s,-webkit-transform .3s;transition:transform .3s,opacity .3s;transition:transform .3s,opacity .3s,-webkit-transform .3s}.slider-navigation-next:hover,.slider-navigation-previous:hover{-webkit-transform:scale(1.2);transform:scale(1.2)}.slider-navigation-next.is-hidden,.slider-navigation-previous.is-hidden{display:none;opacity:0}.slider-navigation-next svg,.slider-navigation-previous svg{width:25%}.slider-navigation-next{left:auto;right:0;background:#fff center center no-repeat;background-size:20px 20px}.slider-pagination{display:none;justify-content:center;align-items:center;position:absolute;bottom:0;left:0;right:0;padding:.5rem 1rem;text-align:center}.slider-pagination .slider-page{background:#fff;width:10px;height:10px;border-radius:25091983px;display:inline-block;margin:0 3px;box-shadow:0 2px 5px #3232321a;transition:-webkit-transform .3s;transition:transform .3s;transition:transform .3s,-webkit-transform .3s;cursor:pointer}.slider-pagination .slider-page.is-active,.slider-pagination .slider-page:hover{-webkit-transform:scale(1.4);transform:scale(1.4)}@media screen and 
(min-width:800px){.slider-pagination{display:flex}}.hero.has-carousel{position:relative}.hero.has-carousel+.hero-body,.hero.has-carousel+.hero-footer,.hero.has-carousel+.hero-head{z-index:10;overflow:hidden}.hero.has-carousel .hero-carousel{position:absolute;top:0;left:0;bottom:0;right:0;height:auto;border:none;margin:auto;padding:0;z-index:0}.hero.has-carousel .hero-carousel .slider{width:100%;max-width:100%;overflow:hidden;height:100%!important;max-height:100%;z-index:0}.hero.has-carousel .hero-carousel .slider .has-background{max-height:100%}.hero.has-carousel .hero-carousel .slider .has-background .is-background{-o-object-fit:cover;object-fit:cover;-o-object-position:center center;object-position:center center;height:100%;width:100%}.hero.has-carousel .hero-body{margin:0 3rem;z-index:10} -------------------------------------------------------------------------------- /docs/static/css/collapsible.css: -------------------------------------------------------------------------------- 1 | /* https://www.w3schools.com/howto/howto_js_collapsible.asp */ 2 | /* Style the button that is used to open and close the collapsible content */ 3 | .collapsible { 4 | background-color: #eee; 5 | color: #444; 6 | cursor: pointer; 7 | padding: 18px; 8 | width: 100%; 9 | border: none; 10 | text-align: left; 11 | outline: none; 12 | font-size: 15px; 13 | } 14 | 15 | /* Add a background color to the button if it is clicked on (add the .active class with JS), and when you move the mouse over it (hover) */ 16 | .active, .collapsible:hover { 17 | background-color: #ccc; 18 | } 19 | 20 | /* Style the collapsible content. Note: hidden by default */ 21 | .content { 22 | padding: 0 18px; 23 | display: none; 24 | overflow: hidden; 25 | background-color: #f1f1f1; 26 | } -------------------------------------------------------------------------------- /docs/static/css/index.css: -------------------------------------------------------------------------------- 1 | body { 2 | font-family: 'Noto Sans', sans-serif; 3 | } 4 | 5 | 6 | .footer .icon-link { 7 | font-size: 25px; 8 | color: #000; 9 | } 10 | 11 | .link-block a { 12 | margin-top: 5px; 13 | margin-bottom: 5px; 14 | } 15 | 16 | .pal { 17 | font-variant: small-caps; 18 | } 19 | .box { 20 | background-color: #f8f8f8; 21 | border: 1px solid #ddd; 22 | border-radius: 5px; 23 | padding: 1em; 24 | margin-bottom: 1em; 25 | overflow: auto; /* Add this line to enable scrolling */ 26 | } 27 | 28 | .teaser .hero-body { 29 | padding-top: 0; 30 | padding-bottom: 3rem; 31 | } 32 | 33 | .teaser { 34 | font-family: 'Google Sans', sans-serif; 35 | } 36 | 37 | 38 | .publication-title { 39 | } 40 | 41 | .publication-banner { 42 | max-height: parent; 43 | 44 | } 45 | 46 | .publication-banner video { 47 | position: relative; 48 | left: auto; 49 | top: auto; 50 | transform: none; 51 | object-fit: fit; 52 | } 53 | 54 | .publication-header .hero-body { 55 | } 56 | 57 | .publication-title { 58 | font-family: 'Google Sans', sans-serif; 59 | } 60 | 61 | .publication-authors { 62 | font-family: 'Google Sans', sans-serif; 63 | } 64 | 65 | .publication-venue { 66 | color: #555; 67 | width: fit-content; 68 | font-weight: bold; 69 | } 70 | 71 | .publication-awards { 72 | color: #ff3860; 73 | width: fit-content; 74 | font-weight: bolder; 75 | } 76 | 77 | .publication-authors { 78 | } 79 | 80 | .publication-authors a { 81 | color: hsl(204, 86%, 53%) !important; 82 | } 83 | 84 | .publication-authors a:hover { 85 | text-decoration: underline; 86 | } 87 | 88 | .author-block { 89 | display: inline-block; 90
| } 91 | 92 | .publication-banner img { 93 | } 94 | 95 | .publication-authors { 96 | /*color: #4286f4;*/ 97 | } 98 | 99 | .publication-video { 100 | position: relative; 101 | width: 100%; 102 | height: 0; 103 | padding-bottom: 56.25%; 104 | 105 | overflow: hidden; 106 | border-radius: 10px !important; 107 | } 108 | 109 | .publication-video iframe { 110 | position: absolute; 111 | top: 0; 112 | left: 0; 113 | width: 100%; 114 | height: 100%; 115 | } 116 | 117 | .publication-body img { 118 | } 119 | 120 | .results-carousel { 121 | overflow: hidden; 122 | } 123 | 124 | .results-carousel .item { 125 | margin: 5px; 126 | overflow: hidden; 127 | border: 1px solid #bbb; 128 | border-radius: 10px; 129 | padding: 0; 130 | font-size: 0; 131 | } 132 | 133 | .results-carousel video { 134 | margin: 0; 135 | } 136 | 137 | 138 | .interpolation-panel { 139 | background: #f5f5f5; 140 | border-radius: 10px; 141 | } 142 | 143 | .interpolation-panel .interpolation-image { 144 | width: 100%; 145 | border-radius: 5px; 146 | } 147 | 148 | .interpolation-video-column { 149 | } 150 | 151 | .interpolation-panel .slider { 152 | margin: 0 !important; 153 | } 154 | 155 | .interpolation-panel .slider { 156 | margin: 0 !important; 157 | } 158 | 159 | #interpolation-image-wrapper { 160 | width: 100%; 161 | } 162 | #interpolation-image-wrapper img { 163 | border-radius: 5px; 164 | } 165 | 166 | .tweets-container { 167 | 168 | top: 0; 169 | right: 0; 170 | width: 50%; 171 | height: 400px; 172 | overflow-y: scroll; 173 | border-left: 1px solid #ccc; 174 | padding: 0px; 175 | 176 | } 177 | 178 | .tweet { 179 | margin-bottom: 50px; 180 | } 181 | 182 | .centerdiv { 183 | display: grid; 184 | place-items: center; 185 | } 186 | 187 | 188 | 189 | 190 | /* CSS for the examples */ 191 | 192 | .message { 193 | background-color: white; 194 | box-shadow: 0 4px 6px rgba(0, 0, 0, 0.1), 0 1px 3px rgba(0, 0, 0, 0.08); 195 | margin-bottom: 1rem; 196 | border-radius: 4px; 197 | } 198 | 199 | .message-header { 200 | background-color: #f5f5f5; 201 | border-top-left-radius: 4px; 202 | border-top-right-radius: 4px; 203 | } 204 | 205 | 206 | .message-header { 207 | background-color: #f5f5f5; 208 | border-top-left-radius: 4px; 209 | border-top-right-radius: 4px; 210 | color: #333333; /* or any other darker shade you prefer */ 211 | } 212 | 213 | 214 | .has-text-left pre { 215 | background-color: #f5f5f5; 216 | padding: 1rem; 217 | border-radius: 4px; 218 | } 219 | 220 | .title.is-5 { 221 | margin-top: 1rem; 222 | margin-bottom: 0.5rem; 223 | font-weight: 600; /* Increased font weight for better visibility */ 224 | } 225 | 226 | #acronym_content { 227 | padding: 1.5rem; 228 | } 229 | 230 | .block:not(:last-child) { 231 | margin-bottom: 1.5rem; 232 | } 233 | 234 | 235 | code.python { 236 | display: block; 237 | background-color: #f4f4f4; 238 | border: 1px solid #ddd; 239 | border-radius: 5px; 240 | font-size: 14px; 241 | line-height: 1.5; 242 | overflow-x: auto; 243 | padding: 20px; 244 | } 245 | 246 | code.python { 247 | color: #008000; 248 | } 249 | -------------------------------------------------------------------------------- /docs/static/css/prism.css: -------------------------------------------------------------------------------- 1 | /* PrismJS 1.29.0 2 | https://prismjs.com/download.html#themes=prism-tomorrow&languages=css+python */ 3 | code[class*=language-],pre[class*=language-]{color:#ccc;background:0 0;font-family:Consolas,Monaco,'Andale Mono','Ubuntu 
Mono',monospace;font-size:1em;text-align:left;white-space:pre;word-spacing:normal;word-break:normal;word-wrap:normal;line-height:1.5;-moz-tab-size:4;-o-tab-size:4;tab-size:4;-webkit-hyphens:none;-moz-hyphens:none;-ms-hyphens:none;hyphens:none}pre[class*=language-]{padding:1em;margin:.5em 0;overflow:auto}:not(pre)>code[class*=language-],pre[class*=language-]{background:#2d2d2d}:not(pre)>code[class*=language-]{padding:.1em;border-radius:.3em;white-space:normal}.token.block-comment,.token.cdata,.token.comment,.token.doctype,.token.prolog{color:#999}.token.punctuation{color:#ccc}.token.attr-name,.token.deleted,.token.namespace,.token.tag{color:#e2777a}.token.function-name{color:#6196cc}.token.boolean,.token.function,.token.number{color:#f08d49}.token.class-name,.token.constant,.token.property,.token.symbol{color:#f8c555}.token.atrule,.token.builtin,.token.important,.token.keyword,.token.selector{color:#cc99cd}.token.attr-value,.token.char,.token.regex,.token.string,.token.variable{color:#7ec699}.token.entity,.token.operator,.token.url{color:#67cdcc}.token.bold,.token.important{font-weight:700}.token.italic{font-style:italic}.token.entity{cursor:help}.token.inserted{color:green} -------------------------------------------------------------------------------- /docs/static/images/animation_oldstyle_oneloop.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/madaan/self-refine/9a206d41e5d2d0c241bb441f41eeadb945afaa55/docs/static/images/animation_oldstyle_oneloop.gif -------------------------------------------------------------------------------- /docs/static/images/autofb_animation.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/madaan/self-refine/9a206d41e5d2d0c241bb441f41eeadb945afaa55/docs/static/images/autofb_animation.gif -------------------------------------------------------------------------------- /docs/static/images/autofb_animation_static.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/madaan/self-refine/9a206d41e5d2d0c241bb441f41eeadb945afaa55/docs/static/images/autofb_animation_static.gif -------------------------------------------------------------------------------- /docs/static/images/autofb_fig1.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/madaan/self-refine/9a206d41e5d2d0c241bb441f41eeadb945afaa55/docs/static/images/autofb_fig1.pdf -------------------------------------------------------------------------------- /docs/static/images/autofb_fig1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/madaan/self-refine/9a206d41e5d2d0c241bb441f41eeadb945afaa55/docs/static/images/autofb_fig1.png -------------------------------------------------------------------------------- /docs/static/images/autofb_static.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/madaan/self-refine/9a206d41e5d2d0c241bb441f41eeadb945afaa55/docs/static/images/autofb_static.png -------------------------------------------------------------------------------- /docs/static/images/example1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/madaan/self-refine/9a206d41e5d2d0c241bb441f41eeadb945afaa55/docs/static/images/example1.png 
-------------------------------------------------------------------------------- /docs/static/images/example1desc.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/madaan/self-refine/9a206d41e5d2d0c241bb441f41eeadb945afaa55/docs/static/images/example1desc.png -------------------------------------------------------------------------------- /docs/static/images/favicon.svg: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 13 | 14 | awesome_icon 15 | 16 | 17 | 18 | 19 | 20 | 21 | 22 | 23 | 24 | 25 | 26 | 27 | 28 | 29 | 30 | -------------------------------------------------------------------------------- /docs/static/images/fig2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/madaan/self-refine/9a206d41e5d2d0c241bb441f41eeadb945afaa55/docs/static/images/fig2.png -------------------------------------------------------------------------------- /docs/static/images/mainresults.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/madaan/self-refine/9a206d41e5d2d0c241bb441f41eeadb945afaa55/docs/static/images/mainresults.png -------------------------------------------------------------------------------- /docs/static/images/pal_overview.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/madaan/self-refine/9a206d41e5d2d0c241bb441f41eeadb945afaa55/docs/static/images/pal_overview.png -------------------------------------------------------------------------------- /docs/static/images/refinement.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/madaan/self-refine/9a206d41e5d2d0c241bb441f41eeadb945afaa55/docs/static/images/refinement.png -------------------------------------------------------------------------------- /docs/static/images/tasks.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/madaan/self-refine/9a206d41e5d2d0c241bb441f41eeadb945afaa55/docs/static/images/tasks.png -------------------------------------------------------------------------------- /docs/static/js/bulma-slider.min.js: -------------------------------------------------------------------------------- 1 | !function(t,e){"object"==typeof exports&&"object"==typeof module?module.exports=e():"function"==typeof define&&define.amd?define([],e):"object"==typeof exports?exports.bulmaSlider=e():t.bulmaSlider=e()}("undefined"!=typeof self?self:this,function(){return function(n){var r={};function i(t){if(r[t])return r[t].exports;var e=r[t]={i:t,l:!1,exports:{}};return n[t].call(e.exports,e,e.exports,i),e.l=!0,e.exports}return i.m=n,i.c=r,i.d=function(t,e,n){i.o(t,e)||Object.defineProperty(t,e,{configurable:!1,enumerable:!0,get:n})},i.n=function(t){var e=t&&t.__esModule?function(){return t.default}:function(){return t};return i.d(e,"a",e),e},i.o=function(t,e){return Object.prototype.hasOwnProperty.call(t,e)},i.p="",i(i.s=0)}([function(t,e,n){"use strict";Object.defineProperty(e,"__esModule",{value:!0}),n.d(e,"isString",function(){return l});var r=n(1),i=Object.assign||function(t){for(var e=1;e=l.length&&(s=!0)):s=!0),s&&(t.once&&(u[e]=null),t.callback(r))});-1!==u.indexOf(null);)u.splice(u.indexOf(null),1)}}]),e}();e.a=i}]).default}); 
-------------------------------------------------------------------------------- /docs/static/js/collapsible.js: -------------------------------------------------------------------------------- 1 | var coll = document.getElementsByClassName("collapsible"); 2 | var i; 3 | 4 | for (i = 0; i < coll.length; i++) { 5 | coll[i].addEventListener("click", function() { 6 | this.classList.toggle("active"); 7 | var content = this.nextElementSibling; 8 | if (content.style.display === "block") { 9 | content.style.display = "none"; 10 | } else { 11 | content.style.display = "block"; 12 | } 13 | }); 14 | } 15 | -------------------------------------------------------------------------------- /docs/static/js/index.js: -------------------------------------------------------------------------------- 1 | window.HELP_IMPROVE_VIDEOJS = false; 2 | 3 | 4 | $(document).ready(function() { 5 | // Check for click events on the navbar burger icon 6 | $(".navbar-burger").click(function() { 7 | // Toggle the "is-active" class on both the "navbar-burger" and the "navbar-menu" 8 | $(".navbar-burger").toggleClass("is-active"); 9 | $(".navbar-menu").toggleClass("is-active"); 10 | 11 | }); 12 | 13 | var current_cmd_idxs = { 14 | "gsm8k": 1, 15 | "gsm8khard": 1, 16 | "coloredobjects":1, 17 | "repeatcopy":1, 18 | "dateunderstanding":1 19 | } 20 | 21 | // examples 22 | $('select').on('click', function() { 23 | 24 | var sep_idx = this.value.indexOf('_'); 25 | var domain_name = this.value.substring(0, sep_idx); 26 | var desired_cmd_idx = parseInt(this.value.substring(sep_idx + 1)); 27 | var current_cmd_idx = current_cmd_idxs[domain_name]; 28 | 29 | // hide current content 30 | var current_content = $('#content_' + domain_name + "_" + current_cmd_idx.toString()); 31 | 32 | if (desired_cmd_idx == current_cmd_idx && current_content.is(":visible")) { 33 | current_content.hide(); 34 | return; 35 | } 36 | current_content.hide(); 37 | 38 | // show desired content 39 | var desired_content = $('#content_' + domain_name + "_" + desired_cmd_idx.toString()); 40 | desired_content.show("slow"); 41 | 42 | // set current to desired 43 | current_cmd_idxs[domain_name] = desired_cmd_idx; 44 | }); 45 | 46 | 47 | 48 | // general function for xyzheader 49 | function toggle_options(header_id, options_id) { 50 | if ($(options_id).is(":visible")) { 51 | $(options_id).hide(); 52 | // extract task name from header. 
e.g., #gsm8k_header -> gsm8k 53 | task_name = header_id.split("_")[0].substring(1); 54 | 55 | console.log("You have selected " + task_name + " as your task."); 56 | for (var i = 0; i <= 100; i++) { 57 | 58 | var content_id = "#content_" + task_name + "_" + i.toString(); 59 | console.log(content_id); 60 | // check if content exists 61 | if ($(content_id).length == 0) { 62 | break; 63 | } 64 | $(content_id).hide(); 65 | } 66 | $(header_id).removeClass("is-active"); 67 | } else { 68 | $(options_id).show("slow"); 69 | $(header_id).addClass("is-active"); 70 | } 71 | } 72 | 73 | $('#acronym_button').click(function() { 74 | toggle_options('#acronym_header', '#acronym_content'); 75 | }); 76 | $('#responsegen').click(function() { 77 | toggle_options('#responsegen_header', '#responsegen_content'); 78 | }); 79 | $('#commongen_button').click(function() { 80 | toggle_options('#commongen_header', '#commongen_content'); 81 | } 82 | ); 83 | $('#gsm_button').click(function() { 84 | toggle_options('#gsm_header', '#gsm_content'); 85 | } 86 | ); 87 | $('#codeoptimization_button').click(function() { 88 | toggle_options('#codeoptimization_header', '#codeoptimization_content'); 89 | } 90 | ); 91 | $('#sentiment_button').click(function() { 92 | toggle_options('#sentiment_header', '#sentiment_content'); 93 | } 94 | ); 95 | 96 | $('#readibility_button').click(function() { 97 | toggle_options('#readibility_header', '#readibility_content'); 98 | } 99 | ); 100 | 101 | $('#gsm8khard_options').hide(); 102 | $('#coloredobjects_options').hide(); 103 | $('#repeatcopy_options').hide(); 104 | $('#dateunderstanding_options').hide(); 105 | 106 | 107 | }) 108 | 109 | 110 | 111 | 112 | -------------------------------------------------------------------------------- /docs/static/js/results.js: -------------------------------------------------------------------------------- 1 | // generates the results histogram using d3.js 2 | results = function() { 3 | var margin = {top: 20, right: 20, bottom: 30, left: 40}, 4 | width = 960 - margin.left - margin.right, 5 | height = 500 - margin.top - margin.bottom; 6 | 7 | var x = d3.scale.ordinal() 8 | .rangeRoundBands([0, width], .1); 9 | 10 | var y = d3.scale.linear() 11 | .range([height, 0]); 12 | 13 | var xAxis = d3.svg.axis() 14 | .scale(x) 15 | .orient("bottom"); 16 | 17 | var yAxis = d3.svg.axis() 18 | .scale(y) 19 | .orient("left") 20 | .ticks(10, "%"); 21 | 22 | var svg = d3.select("#results").append("svg") 23 | .attr("width", width + margin.left + margin.right) 24 | .attr("height", height + margin.top + margin.bottom) 25 | .append("g") 26 | .attr("transform", "translate(" + margin.left + "," + margin.top + ")"); 27 | 28 | d3.json("/results", function(error, data) { 29 | x.domain(data.map(function(d) { return d.name; })); 30 | y.domain([0, d3.max(data, function(d) { return d.count; })]); 31 | 32 | svg.append("g") 33 | .attr("class", "x axis") 34 | .attr("transform", "translate(0," + height + ")") 35 | .call(xAxis); 36 | 37 | svg.append("g") 38 | .attr("class", "y axis") 39 | .call(yAxis) 40 | .append("text") 41 | .attr("transform", "rotate(-90)") 42 | .attr("y", 6) 43 | .attr("dy", ".71em") 44 | .style("text-anchor", "end") 45 | .text("Frequency"); 46 | 47 | svg.selectAll(".bar") 48 | .data(data) 49 | .enter().append("rect") 50 | .attr("class", "bar") 51 | .attr("x", function(d) { return x(d.name); }) 52 | .attr("width", x.rangeBand()) 53 | .attr("y", function(d) { return y(d.count); }) 54 | .attr("height", function(d) { return height - y(d.count); }); 55 | }); 56 | } 57 |
-------------------------------------------------------------------------------- /docs/view.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/madaan/self-refine/9a206d41e5d2d0c241bb441f41eeadb945afaa55/docs/view.pdf -------------------------------------------------------------------------------- /docs/visual_self_refine_examples/alien.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/madaan/self-refine/9a206d41e5d2d0c241bb441f41eeadb945afaa55/docs/visual_self_refine_examples/alien.gif -------------------------------------------------------------------------------- /docs/visual_self_refine_examples/artificial_general_intell.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/madaan/self-refine/9a206d41e5d2d0c241bb441f41eeadb945afaa55/docs/visual_self_refine_examples/artificial_general_intell.gif -------------------------------------------------------------------------------- /docs/visual_self_refine_examples/bird.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/madaan/self-refine/9a206d41e5d2d0c241bb441f41eeadb945afaa55/docs/visual_self_refine_examples/bird.gif -------------------------------------------------------------------------------- /docs/visual_self_refine_examples/bird_v2.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/madaan/self-refine/9a206d41e5d2d0c241bb441f41eeadb945afaa55/docs/visual_self_refine_examples/bird_v2.gif -------------------------------------------------------------------------------- /docs/visual_self_refine_examples/c___program.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/madaan/self-refine/9a206d41e5d2d0c241bb441f41eeadb945afaa55/docs/visual_self_refine_examples/c___program.gif -------------------------------------------------------------------------------- /docs/visual_self_refine_examples/cat.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/madaan/self-refine/9a206d41e5d2d0c241bb441f41eeadb945afaa55/docs/visual_self_refine_examples/cat.gif -------------------------------------------------------------------------------- /docs/visual_self_refine_examples/cell.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/madaan/self-refine/9a206d41e5d2d0c241bb441f41eeadb945afaa55/docs/visual_self_refine_examples/cell.gif -------------------------------------------------------------------------------- /docs/visual_self_refine_examples/cell_v2.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/madaan/self-refine/9a206d41e5d2d0c241bb441f41eeadb945afaa55/docs/visual_self_refine_examples/cell_v2.gif -------------------------------------------------------------------------------- /docs/visual_self_refine_examples/dragon.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/madaan/self-refine/9a206d41e5d2d0c241bb441f41eeadb945afaa55/docs/visual_self_refine_examples/dragon.gif -------------------------------------------------------------------------------- 
/docs/visual_self_refine_examples/dragon_v2.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/madaan/self-refine/9a206d41e5d2d0c241bb441f41eeadb945afaa55/docs/visual_self_refine_examples/dragon_v2.gif -------------------------------------------------------------------------------- /docs/visual_self_refine_examples/drums.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/madaan/self-refine/9a206d41e5d2d0c241bb441f41eeadb945afaa55/docs/visual_self_refine_examples/drums.gif -------------------------------------------------------------------------------- /docs/visual_self_refine_examples/duck.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/madaan/self-refine/9a206d41e5d2d0c241bb441f41eeadb945afaa55/docs/visual_self_refine_examples/duck.gif -------------------------------------------------------------------------------- /docs/visual_self_refine_examples/ellipsoid.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/madaan/self-refine/9a206d41e5d2d0c241bb441f41eeadb945afaa55/docs/visual_self_refine_examples/ellipsoid.gif -------------------------------------------------------------------------------- /docs/visual_self_refine_examples/feedback_driven_self_improvement.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/madaan/self-refine/9a206d41e5d2d0c241bb441f41eeadb945afaa55/docs/visual_self_refine_examples/feedback_driven_self_improvement.gif -------------------------------------------------------------------------------- /docs/visual_self_refine_examples/ferrari.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/madaan/self-refine/9a206d41e5d2d0c241bb441f41eeadb945afaa55/docs/visual_self_refine_examples/ferrari.gif -------------------------------------------------------------------------------- /docs/visual_self_refine_examples/flower.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/madaan/self-refine/9a206d41e5d2d0c241bb441f41eeadb945afaa55/docs/visual_self_refine_examples/flower.gif -------------------------------------------------------------------------------- /docs/visual_self_refine_examples/game_of_life.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/madaan/self-refine/9a206d41e5d2d0c241bb441f41eeadb945afaa55/docs/visual_self_refine_examples/game_of_life.gif -------------------------------------------------------------------------------- /docs/visual_self_refine_examples/giphy.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/madaan/self-refine/9a206d41e5d2d0c241bb441f41eeadb945afaa55/docs/visual_self_refine_examples/giphy.gif -------------------------------------------------------------------------------- /docs/visual_self_refine_examples/goat.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/madaan/self-refine/9a206d41e5d2d0c241bb441f41eeadb945afaa55/docs/visual_self_refine_examples/goat.gif -------------------------------------------------------------------------------- 
/docs/visual_self_refine_examples/gradient_descent.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/madaan/self-refine/9a206d41e5d2d0c241bb441f41eeadb945afaa55/docs/visual_self_refine_examples/gradient_descent.gif -------------------------------------------------------------------------------- /docs/visual_self_refine_examples/guitar.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/madaan/self-refine/9a206d41e5d2d0c241bb441f41eeadb945afaa55/docs/visual_self_refine_examples/guitar.gif -------------------------------------------------------------------------------- /docs/visual_self_refine_examples/hidden_markov_model.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/madaan/self-refine/9a206d41e5d2d0c241bb441f41eeadb945afaa55/docs/visual_self_refine_examples/hidden_markov_model.gif -------------------------------------------------------------------------------- /docs/visual_self_refine_examples/hmm.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/madaan/self-refine/9a206d41e5d2d0c241bb441f41eeadb945afaa55/docs/visual_self_refine_examples/hmm.gif -------------------------------------------------------------------------------- /docs/visual_self_refine_examples/hyperboloid.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/madaan/self-refine/9a206d41e5d2d0c241bb441f41eeadb945afaa55/docs/visual_self_refine_examples/hyperboloid.gif -------------------------------------------------------------------------------- /docs/visual_self_refine_examples/hyperplane.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/madaan/self-refine/9a206d41e5d2d0c241bb441f41eeadb945afaa55/docs/visual_self_refine_examples/hyperplane.gif -------------------------------------------------------------------------------- /docs/visual_self_refine_examples/illustration_of_a_cell.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/madaan/self-refine/9a206d41e5d2d0c241bb441f41eeadb945afaa55/docs/visual_self_refine_examples/illustration_of_a_cell.gif -------------------------------------------------------------------------------- /docs/visual_self_refine_examples/illustration_of_a_neural_.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/madaan/self-refine/9a206d41e5d2d0c241bb441f41eeadb945afaa55/docs/visual_self_refine_examples/illustration_of_a_neural_.gif -------------------------------------------------------------------------------- /docs/visual_self_refine_examples/illustration_of_hmms.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/madaan/self-refine/9a206d41e5d2d0c241bb441f41eeadb945afaa55/docs/visual_self_refine_examples/illustration_of_hmms.gif -------------------------------------------------------------------------------- /docs/visual_self_refine_examples/linux_penguin.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/madaan/self-refine/9a206d41e5d2d0c241bb441f41eeadb945afaa55/docs/visual_self_refine_examples/linux_penguin.gif 
-------------------------------------------------------------------------------- /docs/visual_self_refine_examples/neuron.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/madaan/self-refine/9a206d41e5d2d0c241bb441f41eeadb945afaa55/docs/visual_self_refine_examples/neuron.gif -------------------------------------------------------------------------------- /docs/visual_self_refine_examples/nice_unicorn.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/madaan/self-refine/9a206d41e5d2d0c241bb441f41eeadb945afaa55/docs/visual_self_refine_examples/nice_unicorn.gif -------------------------------------------------------------------------------- /docs/visual_self_refine_examples/nine_point_circle.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/madaan/self-refine/9a206d41e5d2d0c241bb441f41eeadb945afaa55/docs/visual_self_refine_examples/nine_point_circle.gif -------------------------------------------------------------------------------- /docs/visual_self_refine_examples/panda_v2.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/madaan/self-refine/9a206d41e5d2d0c241bb441f41eeadb945afaa55/docs/visual_self_refine_examples/panda_v2.gif -------------------------------------------------------------------------------- /docs/visual_self_refine_examples/penguin.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/madaan/self-refine/9a206d41e5d2d0c241bb441f41eeadb945afaa55/docs/visual_self_refine_examples/penguin.gif -------------------------------------------------------------------------------- /docs/visual_self_refine_examples/penguin_v2.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/madaan/self-refine/9a206d41e5d2d0c241bb441f41eeadb945afaa55/docs/visual_self_refine_examples/penguin_v2.gif -------------------------------------------------------------------------------- /docs/visual_self_refine_examples/pentagon.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/madaan/self-refine/9a206d41e5d2d0c241bb441f41eeadb945afaa55/docs/visual_self_refine_examples/pentagon.gif -------------------------------------------------------------------------------- /docs/visual_self_refine_examples/puppy.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/madaan/self-refine/9a206d41e5d2d0c241bb441f41eeadb945afaa55/docs/visual_self_refine_examples/puppy.gif -------------------------------------------------------------------------------- /docs/visual_self_refine_examples/pyramid.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/madaan/self-refine/9a206d41e5d2d0c241bb441f41eeadb945afaa55/docs/visual_self_refine_examples/pyramid.gif -------------------------------------------------------------------------------- /docs/visual_self_refine_examples/pythagorean_theorem.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/madaan/self-refine/9a206d41e5d2d0c241bb441f41eeadb945afaa55/docs/visual_self_refine_examples/pythagorean_theorem.gif 
-------------------------------------------------------------------------------- /docs/visual_self_refine_examples/roast_turkey.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/madaan/self-refine/9a206d41e5d2d0c241bb441f41eeadb945afaa55/docs/visual_self_refine_examples/roast_turkey.gif -------------------------------------------------------------------------------- /docs/visual_self_refine_examples/robot.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/madaan/self-refine/9a206d41e5d2d0c241bb441f41eeadb945afaa55/docs/visual_self_refine_examples/robot.gif -------------------------------------------------------------------------------- /docs/visual_self_refine_examples/rocket.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/madaan/self-refine/9a206d41e5d2d0c241bb441f41eeadb945afaa55/docs/visual_self_refine_examples/rocket.gif -------------------------------------------------------------------------------- /docs/visual_self_refine_examples/self_attention.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/madaan/self-refine/9a206d41e5d2d0c241bb441f41eeadb945afaa55/docs/visual_self_refine_examples/self_attention.gif -------------------------------------------------------------------------------- /docs/visual_self_refine_examples/self_refine_is_a_novel_ap.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/madaan/self-refine/9a206d41e5d2d0c241bb441f41eeadb945afaa55/docs/visual_self_refine_examples/self_refine_is_a_novel_ap.gif -------------------------------------------------------------------------------- /docs/visual_self_refine_examples/self_refinement_loop.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/madaan/self-refine/9a206d41e5d2d0c241bb441f41eeadb945afaa55/docs/visual_self_refine_examples/self_refinement_loop.gif -------------------------------------------------------------------------------- /docs/visual_self_refine_examples/space_shuttle.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/madaan/self-refine/9a206d41e5d2d0c241bb441f41eeadb945afaa55/docs/visual_self_refine_examples/space_shuttle.gif -------------------------------------------------------------------------------- /docs/visual_self_refine_examples/spaceship.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/madaan/self-refine/9a206d41e5d2d0c241bb441f41eeadb945afaa55/docs/visual_self_refine_examples/spaceship.gif -------------------------------------------------------------------------------- /docs/visual_self_refine_examples/stokes__theorem.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/madaan/self-refine/9a206d41e5d2d0c241bb441f41eeadb945afaa55/docs/visual_self_refine_examples/stokes__theorem.gif -------------------------------------------------------------------------------- /docs/visual_self_refine_examples/support_vector_machine.gif: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/madaan/self-refine/9a206d41e5d2d0c241bb441f41eeadb945afaa55/docs/visual_self_refine_examples/support_vector_machine.gif -------------------------------------------------------------------------------- /docs/visual_self_refine_examples/tcp_ip.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/madaan/self-refine/9a206d41e5d2d0c241bb441f41eeadb945afaa55/docs/visual_self_refine_examples/tcp_ip.gif -------------------------------------------------------------------------------- /docs/visual_self_refine_examples/thanksgiving_turkey.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/madaan/self-refine/9a206d41e5d2d0c241bb441f41eeadb945afaa55/docs/visual_self_refine_examples/thanksgiving_turkey.gif -------------------------------------------------------------------------------- /docs/visual_self_refine_examples/toroid.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/madaan/self-refine/9a206d41e5d2d0c241bb441f41eeadb945afaa55/docs/visual_self_refine_examples/toroid.gif -------------------------------------------------------------------------------- /docs/visual_self_refine_examples/train.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/madaan/self-refine/9a206d41e5d2d0c241bb441f41eeadb945afaa55/docs/visual_self_refine_examples/train.gif -------------------------------------------------------------------------------- /docs/visual_self_refine_examples/turkey.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/madaan/self-refine/9a206d41e5d2d0c241bb441f41eeadb945afaa55/docs/visual_self_refine_examples/turkey.gif -------------------------------------------------------------------------------- /docs/visual_self_refine_examples/turkey_v1.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/madaan/self-refine/9a206d41e5d2d0c241bb441f41eeadb945afaa55/docs/visual_self_refine_examples/turkey_v1.gif -------------------------------------------------------------------------------- /docs/visual_self_refine_examples/turtle.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/madaan/self-refine/9a206d41e5d2d0c241bb441f41eeadb945afaa55/docs/visual_self_refine_examples/turtle.gif -------------------------------------------------------------------------------- /docs/visual_self_refine_examples/unicorn.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/madaan/self-refine/9a206d41e5d2d0c241bb441f41eeadb945afaa55/docs/visual_self_refine_examples/unicorn.gif -------------------------------------------------------------------------------- /docs/visual_self_refine_examples/unicorn_v2.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/madaan/self-refine/9a206d41e5d2d0c241bb441f41eeadb945afaa55/docs/visual_self_refine_examples/unicorn_v2.gif -------------------------------------------------------------------------------- /docs/visual_self_refine_examples/unicorn_v3.gif: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/madaan/self-refine/9a206d41e5d2d0c241bb441f41eeadb945afaa55/docs/visual_self_refine_examples/unicorn_v3.gif -------------------------------------------------------------------------------- /docs/visual_self_refine_examples/unicorn_v4.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/madaan/self-refine/9a206d41e5d2d0c241bb441f41eeadb945afaa55/docs/visual_self_refine_examples/unicorn_v4.gif -------------------------------------------------------------------------------- /docs/visual_self_refine_examples/vae.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/madaan/self-refine/9a206d41e5d2d0c241bb441f41eeadb945afaa55/docs/visual_self_refine_examples/vae.gif -------------------------------------------------------------------------------- /docs/visual_self_refine_examples/variational_inference.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/madaan/self-refine/9a206d41e5d2d0c241bb441f41eeadb945afaa55/docs/visual_self_refine_examples/variational_inference.gif -------------------------------------------------------------------------------- /src/acronym/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/madaan/self-refine/9a206d41e5d2d0c241bb441f41eeadb945afaa55/src/acronym/__init__.py -------------------------------------------------------------------------------- /src/acronym/feedback.py: -------------------------------------------------------------------------------- 1 | import pandas as pd 2 | from prompt_lib.backends import openai_api 3 | 4 | from src.utils import Prompt 5 | 6 | 7 | class AcronymGenFeedback(Prompt): 8 | def __init__(self, engine: str, prompt_examples: str, max_tokens: int = 300) -> None: 9 | super().__init__( 10 | question_prefix="", 11 | answer_prefix="", 12 | intra_example_sep="\n\n", 13 | inter_example_sep="\n\n###\n\n", 14 | ) 15 | self.engine = engine 16 | self.max_tokens = max_tokens 17 | self.setup_prompt_from_examples_file(prompt_examples) 18 | 19 | def setup_prompt_from_examples_file(self, examples_path: str) -> None: 20 | template = """Title: {title} 21 | 22 | Acronym: {answer} 23 | 24 | Scores: 25 | 26 | * Ease of pronunciation: {pronunciation_score} 27 | * Ease of spelling: {spelling_score} 28 | * Relation to title: {relation_score} 29 | * Positive connotation: {connotation_score} 30 | * Well-known: {well_known_score} 31 | 32 | * Total score: {total_score}""" 33 | 34 | examples_df = pd.read_json(examples_path, orient="records", lines=True) 35 | prompt = [] 36 | for _, row in examples_df.iterrows(): 37 | prompt.append( 38 | template.format( 39 | title=row["title"], 40 | answer=row["acronym"], 41 | pronunciation_score=row["pronunciation_score"], 42 | spelling_score=row["spelling_score"], 43 | relation_score=row["relation_score"], 44 | connotation_score=row["connotation_score"], 45 | well_known_score=row["well_known_score"], 46 | total_score=row["total_score"], 47 | ) 48 | ) 49 | 50 | instruction = """We want to score each acronym on five qualities: i) ease of pronunciation, ii) ease of spelling, iii) relation to the title, iv) positive connotation, and v) well-known.
51 | 52 | Here are some examples of this scoring rubric: 53 | 54 | """ 55 | self.prompt = instruction + self.inter_example_sep.join(prompt) + self.inter_example_sep 56 | 57 | 58 | def __call__(self, title: str, acronym: str): 59 | prompt = self.get_prompt_with_question(title=title, acronym=acronym) 60 | 61 | output = openai_api.OpenaiAPIWrapper.call( 62 | prompt=prompt, 63 | engine=self.engine, 64 | max_tokens=self.max_tokens, 65 | stop_token="###", 66 | temperature=0.7, 67 | ) 68 | 69 | generated_feedback = openai_api.OpenaiAPIWrapper.get_first_response(output) 70 | generated_feedback = generated_feedback.split("Scores:")[1].strip() 71 | generated_feedback = generated_feedback.split("#")[0].strip() 72 | return generated_feedback 73 | 74 | def get_prompt_with_question(self, title: str, acronym: str): 75 | question = self.make_query(title=title, acronym=acronym) 76 | return f"""{self.prompt}{question}""" 77 | 78 | def make_query(self, title: str, acronym: str): 79 | question = f"""Title: {title} 80 | 81 | Acronym: {acronym}""" 82 | return question 83 | 84 | 85 | 86 | if __name__ == "__main__": 87 | feedback = AcronymGenFeedback( 88 | engine="code-davinci-002", 89 | prompt_examples="data/prompt/acronym/feedback.jsonl", 90 | ) 91 | 92 | print(feedback.prompt) -------------------------------------------------------------------------------- /src/acronym/run.py: -------------------------------------------------------------------------------- 1 | import re 2 | import pandas as pd 3 | 4 | 5 | from src.acronym.task_init import AcronymGenTaskInit 6 | from src.acronym.task_iterate import AcronymGenTaskIterate 7 | from src.acronym.feedback import AcronymGenFeedback 8 | from src.utils import retry_parse_fail_prone_cmd 9 | 10 | CODEX = "code-davinci-002" 11 | GPT3 = "text-davinci-003" 12 | CHAT_GPT = "gpt-3.5-turbo" 13 | GPT4 = "gpt-4" 14 | 15 | 16 | ENGINE = CHAT_GPT 17 | 18 | @retry_parse_fail_prone_cmd 19 | def iterative_acronym(title: str, max_attempts: int) -> dict: 20 | 21 | # initialize all the required components 22 | 23 | # generation of the first acronym 24 | task_init = AcronymGenTaskInit(engine=ENGINE, prompt_examples="data/prompt/acronym/init.jsonl") 25 | 26 | # getting feedback 27 | task_feedback = AcronymGenFeedback(engine=ENGINE, prompt_examples="data/prompt/acronym/feedback.jsonl") 28 | 29 | # iteratively improving the acronym 30 | task_iterate = AcronymGenTaskIterate(engine=ENGINE, prompt_examples="data/prompt/acronym/feedback.jsonl") 31 | 32 | 33 | # Initialize the task 34 | 35 | n_attempts = 0 36 | 37 | print(f"{n_attempts} INIT> {title}") 38 | acronyms_to_scores = dict() 39 | 40 | all_acronyms_to_scores = dict() 41 | best_score_so_far = 0 42 | while n_attempts < max_attempts: 43 | 44 | if n_attempts == 0: 45 | acronym = task_init(title=title) 46 | else: 47 | new_title, acronym = task_iterate(acronyms_to_scores=acronyms_to_scores) 48 | title = new_title 49 | 50 | 51 | scores = task_feedback(title=title, acronym=acronym) 52 | # extract expression "Total score: 22/25" from scores 53 | total_score = re.search(r"Total score: (\d+)/(\d+)", scores).group(0) 54 | total_score = int(total_score.split(":")[1].strip().split("/")[0]) 55 | 56 | all_acronyms_to_scores[acronym] = { 57 | "scores": scores, 58 | "total_score": total_score, 59 | "title": title, 60 | } 61 | print(f"{n_attempts} GEN> {acronym} TITLE> {title}") 62 | 63 | print(f"{n_attempts} SCORES> {scores}") 64 | if total_score >= best_score_so_far: # only keep iterating on acronyms that match or improve the best score 65 |
best_score_so_far = total_score 66 | 67 | acronyms_to_scores[acronym] = (title, scores) 68 | 69 | 70 | else: 71 | print(f"Score of {acronym} is {total_score}, which is less than the current best of {best_score_so_far}") 72 | 73 | n_attempts += 1 74 | 75 | return all_acronyms_to_scores 76 | 77 | 78 | def run_over_titles(titles_file: str, max_attempts: int, outfile: str): 79 | 80 | def _parse_results(title: str) -> str: 81 | try: 82 | results = iterative_acronym(title=title, max_attempts=max_attempts) 83 | if results is None: 84 | return "FAILED" 85 | res = [] 86 | for acronym, scores in results.items(): 87 | res.append(f"{acronym} [score: {scores['total_score']}] \n {scores['scores']}") 88 | return "\n ------ \n ".join(res) 89 | except Exception as e: 90 | return "FAILED" 91 | 92 | 93 | data = pd.read_csv(titles_file, sep="\t") 94 | data['generated_acronym'] = data['title'].apply(_parse_results) 95 | 96 | data.to_json(outfile, orient="records", lines=True) 97 | 98 | if __name__ == "__main__": 99 | import sys 100 | title = sys.argv[1] # Light Amplification by Stimulated Emission of Radiation 101 | if len(sys.argv) > 2: 102 | run_over_titles(titles_file=sys.argv[1], max_attempts=int(sys.argv[2]), outfile=sys.argv[3]) 103 | else: 104 | max_attempts = 5 105 | all_acronyms_to_scores = iterative_acronym( 106 | title=title, 107 | max_attempts=max_attempts, 108 | ) 109 | 110 | res = [] 111 | for acronym, scores in all_acronyms_to_scores.items(): 112 | res.append(f"{acronym} [score: {scores['total_score']}] \n {scores['scores']}") 113 | print("\n ------ \n ".join(res)) 114 | 115 | -------------------------------------------------------------------------------- /src/acronym/run_mcts.py: -------------------------------------------------------------------------------- 1 | import re 2 | import math 3 | 4 | 5 | from src.acronym.task_init import AcronymGenTaskInit 6 | from src.acronym.task_iterate import AcronymGenTaskIterate 7 | from src.acronym.feedback import AcronymGenFeedback 8 | from src.utils import retry_parse_fail_prone_cmd 9 | 10 | CODEX = "code-davinci-002" 11 | GPT3 = "text-davinci-003" 12 | CHAT_GPT = "gpt-3.5-turbo" 13 | GPT4 = "gpt-4" 14 | ENGINE = CHAT_GPT 15 | 16 | task_init = AcronymGenTaskInit(engine=ENGINE, prompt_examples="data/prompt/acronym/init.jsonl") 17 | 18 | # getting feedback 19 | task_feedback = AcronymGenFeedback( 20 | engine=ENGINE, prompt_examples="data/prompt/acronym/feedback.jsonl" 21 | ) 22 | 23 | # iteratively improving the acronym 24 | task_iterate = AcronymGenTaskIterate( 25 | engine=ENGINE, prompt_examples="data/prompt/acronym/feedback.jsonl" 26 | ) 27 | 28 | 29 | class TreeNode: 30 | def __init__(self, title: str, acronym: str, scores: dict, parent=None): 31 | self.title = title 32 | self.acronym = acronym.strip() 33 | self.scores = scores 34 | self.children = [] 35 | self.visits = 1 36 | self.value = 0.0 37 | self.parent = parent 38 | 39 | def __str__(self): 40 | children_str = ", ".join([str(child.acronym) for child in self.children]) 41 | if children_str == "": 42 | children_str = "None" 43 | res = f"TreeNode(title='{self.title}', acronym='{self.acronym}', scores={self.scores}, visits={self.visits}, value={self.value}, parent_acronym='{self.parent.acronym if self.parent else None}' children={children_str})" 44 | return res 45 | 46 | 47 | @retry_parse_fail_prone_cmd 48 | def generate_initial_children(node: TreeNode, task_iterate, task_feedback, num_children: int = 3): 49 | for _ in range(num_children): 50 | new_title, new_acronym = task_iterate( 51 | 
acronyms_to_scores={node.acronym: (node.title, node.scores)} 52 | ) 53 | child = TreeNode(title=new_title, acronym=new_acronym, scores=None, parent=node) 54 | simulate(child, task_feedback) # Set scores for the child node 55 | node.children.append(child) 56 | 57 | 58 | def parse_scores(scores_output: str) -> dict: 59 | scores_pattern = re.compile(r"\* (.*?): (.*)\n?") 60 | scores = {} 61 | 62 | for score_match in scores_pattern.finditer(scores_output): 63 | score_title, score_value = score_match.groups() 64 | score_value = int(re.search(r"\d+", score_value).group(0)) # Updated 65 | scores[score_title] = score_value 66 | 67 | return scores 68 | 69 | 70 | def normalize_scores(scores: dict) -> dict: 71 | normalized_scores = {} 72 | for key, value in scores.items(): 73 | if key != "Total score": 74 | normalized_scores[key] = value / 5 75 | return normalized_scores 76 | 77 | 78 | def weighted_sum(normalized_scores: dict, weights: dict) -> float: 79 | return sum(normalized_scores[key] * weights[key] for key in normalized_scores) 80 | 81 | 82 | def select(node: TreeNode, weights: dict, C: float = 1.0): 83 | if not node.children: 84 | return node 85 | 86 | ucb1_values = [ 87 | (weighted_sum(normalize_scores(child.scores), weights) / child.visits) 88 | + C * math.sqrt(math.log(node.visits) / child.visits) 89 | for child in node.children 90 | ] 91 | 92 | max_index = ucb1_values.index(max(ucb1_values)) 93 | return select(node.children[max_index], weights, C) 94 | 95 | 96 | @retry_parse_fail_prone_cmd 97 | def expand(node: TreeNode, task_iterate, expanded_nodes_cache: set): 98 | acronyms_to_scores = {} 99 | current = node 100 | while current is not None: 101 | acronyms_to_scores[current.acronym] = (current.title, current.scores) 102 | current = current.parent 103 | 104 | new_title, new_acronym = task_iterate(acronyms_to_scores=acronyms_to_scores) 105 | 106 | while new_acronym in expanded_nodes_cache: 107 | new_title, new_acronym = task_iterate(acronyms_to_scores=acronyms_to_scores) 108 | 109 | expanded_nodes_cache.add(new_acronym) 110 | return TreeNode(title=new_title, acronym=new_acronym, scores=None, parent=node) 111 | 112 | 113 | @retry_parse_fail_prone_cmd 114 | def simulate(node: TreeNode, task_feedback): 115 | scores = task_feedback(title=node.title, acronym=node.acronym) 116 | node.scores = parse_scores(scores) 117 | normalized_total_score = node.scores["Total score"] / 25 118 | return normalized_total_score 119 | 120 | 121 | def backpropagate(node: TreeNode, value: float): 122 | node.visits += 1 123 | node.value += value 124 | 125 | if node.parent is not None: 126 | backpropagate(node.parent, value) 127 | 128 | 129 | def mcts_iteration( 130 | root: TreeNode, weights: dict, task_iterate, task_feedback, expanded_nodes_cache: set 131 | ): 132 | 133 | print("Selecting...") 134 | selected_node = select(root, weights) 135 | print(f" Selected node: {selected_node.acronym} for title '{selected_node.title}'") 136 | 137 | print("Expanding...") 138 | expanded_node = expand(selected_node, task_iterate, expanded_nodes_cache) 139 | 140 | selected_node.children.append( 141 | expanded_node 142 | ) # Add expanded node as a child of the selected node 143 | print(f" Expanded node: {expanded_node.acronym} for title '{expanded_node.title}'") 144 | 145 | print("Simulating...") 146 | value = simulate(expanded_node, task_feedback) 147 | print(f" Simulated value: {value}") 148 | 149 | print("Backpropagating...") 150 | backpropagate(expanded_node, value) 151 | print(" Backpropagation complete") 152 | 153 | 154 | 
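# Illustrative sketch (not part of the original file): `select` above ranks children with plain
# UCB1. The exploitation term is the child's weighted rubric score averaged over its visits, and
# the exploration bonus C*sqrt(ln(parent_visits)/child_visits) shrinks as a child is revisited.
# The helper below is hypothetical and only restates that formula for a single child.
def _ucb1_sketch(child_score: float, child_visits: int, parent_visits: int, C: float = 1.0) -> float:
    # e.g. _ucb1_sketch(0.8, 1, 4) == 0.8 + 1.0 * sqrt(ln(4) / 1) ~= 1.98
    return child_score / child_visits + C * math.sqrt(math.log(parent_visits) / child_visits)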
root_title = "Iterative Refinement with Self-Feedback" 155 | root_acronym = task_init(title=root_title) 156 | root_scores = task_feedback(title=root_title, acronym=root_acronym) 157 | root_scores = parse_scores(root_scores) 158 | root = TreeNode(root_title, root_acronym, scores=root_scores) 159 | 160 | print(f"Root node: {root}") 161 | 162 | generate_initial_children(root, task_iterate, task_feedback) 163 | 164 | print(f"Root node after generating initial children: {root}") 165 | 166 | weights = { 167 | "Ease of pronunciation": 0.2, 168 | "Ease of spelling": 0.2, 169 | "Relation to title": 0.3, 170 | "Positive connotation": 0.2, 171 | "Well-known": 0.1, 172 | } 173 | 174 | iterations = 4 175 | expanded_nodes_cache = {root.acronym} 176 | 177 | for _ in range(iterations): 178 | mcts_iteration(root, weights, task_iterate, task_feedback, expanded_nodes_cache) 179 | 180 | 181 | def dfs(node: TreeNode, best_node: TreeNode): 182 | if not node.children: 183 | if node.scores["Total score"] > best_node.scores["Total score"]: 184 | return node 185 | else: 186 | return best_node 187 | 188 | for child in node.children: 189 | best_node = dfs(child, best_node) 190 | 191 | return best_node 192 | 193 | 194 | best_node = dfs(root, root) 195 | best_acronym, best_score = best_node.acronym, best_node.scores["Total score"] 196 | 197 | print(f"\nBest acronym: {best_acronym} with total score: {best_score}") 198 | 199 | 200 | def print_tree(node: TreeNode, indent: int = 0): 201 | indentation = " " * indent 202 | print(f"{indentation}{node.acronym} (Score: {node.scores['Total score']})") 203 | 204 | for child in node.children: 205 | print_tree(child, indent + 1) 206 | 207 | 208 | print_tree(root) 209 | -------------------------------------------------------------------------------- /src/acronym/task_init.py: -------------------------------------------------------------------------------- 1 | import pandas as pd 2 | from src.utils import Prompt 3 | 4 | from prompt_lib.backends import openai_api 5 | 6 | 7 | class AcronymGenTaskInit(Prompt): 8 | def __init__(self, prompt_examples: str, engine: str) -> None: 9 | super().__init__( 10 | question_prefix="Title: ", 11 | answer_prefix="Acronym: ", 12 | intra_example_sep="\n\n", 13 | inter_example_sep="\n\n###\n\n", 14 | ) 15 | self.engine = engine 16 | self.setup_prompt_from_examples_file(prompt_examples) 17 | 18 | def setup_prompt_from_examples_file(self, examples_path: str) -> str: 19 | TEMPLATE = """Title: {title} 20 | 21 | Acronym: {answer}""" 22 | 23 | examples_df = pd.read_json(examples_path, orient="records", lines=True) 24 | prompt = [] 25 | for _, row in examples_df.iterrows(): 26 | prompt.append(TEMPLATE.format(title=row["title"], answer=row["acronym"])) 27 | self.prompt = self.inter_example_sep.join(prompt) 28 | self.prompt = self.prompt + self.inter_example_sep 29 | 30 | def make_query(self, title: str) -> str: 31 | query = f"{self.prompt}{self.question_prefix}{title}{self.intra_example_sep}" 32 | return query 33 | 34 | def __call__(self, title: str) -> str: 35 | generation_query = self.make_query(title) 36 | 37 | output = openai_api.OpenaiAPIWrapper.call( 38 | prompt=generation_query, 39 | engine=self.engine, 40 | max_tokens=300, 41 | stop_token="###", 42 | temperature=0.7, 43 | ) 44 | 45 | generated_acronym = openai_api.OpenaiAPIWrapper.get_first_response(output) 46 | # print("output:") 47 | # print(generated_acronym) 48 | # sys.exit() 49 | generated_acronym = generated_acronym.split(self.answer_prefix)[1].replace("#", "").strip() 50 | return 
generated_acronym.strip() 51 | 52 | 53 | def test(): 54 | task_init = AcronymGenTaskInit(engine="code-davinci-002", prompt_examples="data/prompt/acronym/init.jsonl") 55 | print(task_init.prompt) 56 | 57 | 58 | if __name__ == "__main__": 59 | test() -------------------------------------------------------------------------------- /src/acronym/task_iterate.py: -------------------------------------------------------------------------------- 1 | import sys 2 | from typing import Dict, List 3 | from src.utils import Prompt 4 | import pandas as pd 5 | 6 | from prompt_lib.backends import openai_api 7 | 8 | 9 | class AcronymGenTaskIterate(Prompt): 10 | def __init__(self, engine: str, prompt_examples: str) -> None: 11 | super().__init__( 12 | question_prefix="", 13 | answer_prefix="", 14 | intra_example_sep="\n\n", 15 | inter_example_sep="\n\n###\n\n", 16 | ) 17 | self.engine = engine 18 | self.count = 0 19 | self.prompt = self.make_prompt(prompt_examples=prompt_examples) 20 | 21 | def make_prompt(self, prompt_examples: str) -> str: 22 | 23 | prompt_examples = pd.read_json(prompt_examples, orient="records", lines=True) 24 | # group on example 25 | grouped = prompt_examples.groupby("example") 26 | 27 | prompt = [] 28 | # sort each group by score 29 | for _, group in grouped: 30 | group["numerical_score"] = group["total_score"].apply(lambda x: int(x.split("/")[0].strip())) 31 | group = group.sort_values("numerical_score") 32 | prompt.append(self.make_one_iterate_example(group.to_dict("records"))) 33 | 34 | return self.inter_example_sep.join(prompt) + self.inter_example_sep 35 | 36 | 37 | def make_one_iterate_example(self, incrementally_improving_examples: List[Dict]): 38 | """Given a list of examples that are incrementally improving, return a new example. 39 | """ 40 | 41 | instr = """We want to iteratively improve acronyms. To help improve, scores for each acronym on five desired traits are provided: i) ease of pronunciation, ii) ease of spelling, iii) relation to the title, iv) positive connotation, and v) well-known. 42 | 43 | """ 44 | example_template = """Title: {title} 45 | 46 | Acronym: {acronym} 47 | 48 | Scores: 49 | 50 | * Ease of pronunciation: {pronunciation_score} 51 | * Ease of spelling: {spelling_score} 52 | * Relation to title: {relation_score} 53 | * Positive connotation: {connotation_score} 54 | * Well-known: {well_known_score} 55 | 56 | * Total score: {total_score} 57 | 58 | Okay, let's use this feedback to improve the acronym. 59 | 60 | """ 61 | prompt = [] 62 | for example in incrementally_improving_examples: 63 | prompt.append(example_template.format(**example)) 64 | 65 | prompt = "".join(prompt) 66 | prompt = instr + prompt 67 | return prompt.strip() 68 | 69 | def make_query(self, question: str) -> str: 70 | return f"{self.prompt}{self.question_prefix}{question}{self.intra_example_sep}{self.answer_prefix}" 71 | # return super().make_query(prompt, question) 72 | 73 | def _make_input( 74 | self, 75 | title: str, 76 | acronym: str, 77 | scores: str, 78 | ) -> str: 79 | input_txt = f"""Title: {title} 80 | 81 | Acronym: {acronym} 82 | 83 | Scores: 84 | 85 | {scores} 86 | 87 | Okay, let's use this feedback to improve the acronym.
88 | 89 | """ 90 | 91 | return input_txt 92 | 93 | def __call__( 94 | self, 95 | acronyms_to_scores: Dict[str, str], 96 | ) -> str: 97 | example_input = self.make_input( 98 | acronyms_to_scores=acronyms_to_scores, 99 | ) 100 | transfer_query = self.make_query(example_input) 101 | self.count += 1 102 | with open(f"acronym_iterate_{self.count}.txt", "w") as f: 103 | f.write(transfer_query + "\n") 104 | 105 | output = openai_api.OpenaiAPIWrapper.call( 106 | prompt=transfer_query, 107 | engine=self.engine, 108 | max_tokens=300, 109 | stop_token=self.inter_example_sep, 110 | temperature=0.7, 111 | ) 112 | response = openai_api.OpenaiAPIWrapper.get_first_response(output) 113 | 114 | 115 | acronym = response.split("Acronym:")[1].strip().split("\n")[0].strip() 116 | 117 | new_title = response.split("Title:")[1].strip().split("\n")[0].strip() 118 | 119 | 120 | 121 | return new_title, acronym.strip() 122 | 123 | def make_input( 124 | self, 125 | acronyms_to_scores: Dict[str, str], 126 | ) -> str: 127 | input_txt = "" 128 | for acronym, (title, scores) in acronyms_to_scores.items(): 129 | input_txt += self._make_input( 130 | title=title, 131 | acronym=acronym, 132 | scores=scores, 133 | ) 134 | return input_txt 135 | 136 | 137 | 138 | 139 | 140 | 141 | if __name__ == "__main__": 142 | obj = AcronymGenTaskIterate(prompt_examples="data/prompt/acronym/feedback.jsonl", engine="whatever") 143 | print(obj.prompt) -------------------------------------------------------------------------------- /src/commongen/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/madaan/self-refine/9a206d41e5d2d0c241bb441f41eeadb945afaa55/src/commongen/__init__.py -------------------------------------------------------------------------------- /src/commongen/eval.py: -------------------------------------------------------------------------------- 1 | import pandas as pd 2 | from typing import List, Dict 3 | 4 | def run(path: str): 5 | 6 | df = pd.read_json(path, lines=True, orient="records") 7 | df = df[df['status'] != "error"] 8 | print(f"Loaded {len(df)} rows") 9 | for i, row in df.iterrows(): 10 | direct_output = row["sent_to_fb"][0] 11 | iter_output = row["sent_to_fb"][-1] 12 | df.loc[i, 'direct_concept_success'] = direct_output["concept_feedback"][0].lower() == "none" 13 | df.loc[i, 'direct_commonsense_success'] = direct_output["commonsense_feedback"].lower() == "none" 14 | df.loc[i, 'direct_success'] = direct_output["concept_feedback"][0].lower() == "none" and direct_output["commonsense_feedback"].lower() == "none" 15 | df.loc[i, 'iter_concept_success'] = iter_output["concept_feedback"][0].lower() == "none" 16 | df.loc[i, 'iter_commonsense_success'] = iter_output["commonsense_feedback"].lower() == "none" 17 | df.loc[i, 'iter_success'] = iter_output["concept_feedback"][0].lower() == "none" and iter_output["commonsense_feedback"].lower() == "none" 18 | 19 | # direct wins 20 | num_direct_concept_wins = len(df[(df['direct_concept_success'] == True) & (df['iter_concept_success'] == False)]) 21 | num_direct_commonsense_wins = len(df[(df['direct_commonsense_success'] == True) & (df['iter_commonsense_success'] == False)]) 22 | num_iter_concept_wins = len(df[(df['direct_concept_success'] == False) & (df['iter_concept_success'] == True)]) 23 | num_iter_commonsense_wins = len(df[(df['direct_commonsense_success'] == False) & (df['iter_commonsense_success'] == True)]) 24 | num_direct_wins = len(df[(df['direct_success'] == True) & (df['iter_success'] == False)])
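# A "win" is asymmetric success: a direct win means the one-shot output passed the check while the refined output did not; iter wins are the reverse.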
25 | num_iter_wins = len(df[(df['direct_success'] == False) & (df['iter_success'] == True)]) 26 | 27 | 28 | num_commonsense_ties = len(df) - num_direct_commonsense_wins - num_iter_commonsense_wins 29 | num_concept_ties = len(df) - num_direct_concept_wins - num_iter_concept_wins 30 | 31 | # normalize everything and print a nice report 32 | 33 | print(f"Direct concept wins: {num_direct_concept_wins / len(df):.2f}") 34 | print(f"Direct commonsense wins: {num_direct_commonsense_wins / len(df):.2f}") 35 | print(f"Direct overall wins: {num_direct_wins / len(df):.2f}") 36 | print(f"Iter concept wins: {num_iter_concept_wins / len(df):.2f}") 37 | print(f"Iter commonsense wins: {num_iter_commonsense_wins / len(df):.2f}") 38 | print(f"Iter overall wins: {num_iter_wins / len(df):.2f}") 39 | 40 | 41 | if __name__ == '__main__': 42 | import argparse 43 | args = argparse.ArgumentParser() 44 | args.add_argument("path", type=str) 45 | args = args.parse_args() 46 | 47 | run(path=args.path) 48 | 49 | -------------------------------------------------------------------------------- /src/commongen/feedback.py: -------------------------------------------------------------------------------- 1 | import re 2 | from typing import Set, List 3 | import pandas as pd 4 | from prompt_lib.backends import openai_api 5 | import nltk 6 | import spacy 7 | 8 | nlp = spacy.load("en_core_web_sm") 9 | from src.utils import Prompt 10 | 11 | 12 | class CommongenFeedback(Prompt): 13 | def __init__(self, engine: str, prompt_examples: str, max_tokens: int = 300) -> None: 14 | super().__init__( 15 | question_prefix="", 16 | answer_prefix="", 17 | intra_example_sep="\n\n", 18 | inter_example_sep="\n\n###\n\n", 19 | ) 20 | self.engine = engine 21 | self.max_tokens = max_tokens 22 | self.setup_prompt_from_examples_file(prompt_examples) 23 | 24 | def setup_prompt_from_examples_file(self, examples_path: str) -> str: 25 | template = """Concepts: {concepts} 26 | Sentence: {sentence} 27 | what concepts from the concept list are missing from the sentence and does the sentence make sense? 28 | 29 | Concept Feedback: {feedback} 30 | Commonsense Feedback: {commonsense_feedback}""" 31 | 32 | examples_df = pd.read_json(examples_path, orient="records", lines=True) 33 | prompt = [] 34 | for _, row in examples_df.iterrows(): 35 | prompt.append( 36 | template.format( 37 | concepts=row["concepts"], 38 | sentence=row["sentence"], 39 | feedback=", ".join(row["concept_feedback"]), 40 | commonsense_feedback=row["commonsense_feedback"] 41 | ) 42 | ) 43 | 44 | instruction = """We want to create a sentence that contains all the specified concepts. Please provide feedback on the following sentences.
The feedback indicates missing concepts.""" 45 | self.prompt = instruction + self.inter_example_sep.join(prompt) 46 | self.prompt = self.prompt + self.inter_example_sep 47 | 48 | def __call__(self, sentence: str, concepts: List[str]): 49 | prompt = self.make_query(sentence=sentence, concepts=concepts) 50 | 51 | output = openai_api.OpenaiAPIWrapper.call( 52 | prompt=prompt, 53 | engine=self.engine, 54 | max_tokens=self.max_tokens, 55 | stop_token="###", 56 | temperature=0.7, 57 | ) 58 | 59 | generated_feedback = openai_api.OpenaiAPIWrapper.get_first_response(output) 60 | commonsense_feedback = re.search(r"Commonsense Feedback: (.*)", generated_feedback).group(1) 61 | commonsense_feedback = self.fix_feedback(sentence=sentence, concepts=concepts, feedback=commonsense_feedback) 62 | 63 | concept_feedback = re.search(r"Concept Feedback: (.*)", generated_feedback).group(1) 64 | return concept_feedback, commonsense_feedback 65 | 66 | 67 | def make_query(self, concepts: List[str], sentence: str): 68 | question = f"""Concepts: {concepts} 69 | Sentence: {sentence} 70 | what concepts from the concept list are missing from the sentence?""" 71 | return f"""{self.prompt}{question}""" 72 | 73 | 74 | def fix_feedback(self, sentence: str, concepts: List[str], feedback: str): 75 | """We rely on the model to generate the feedback, so that different surface forms of the same concept are captured. However, the model can make mistakes, and the task is simple enough that some of them can be corrected here.""" 76 | 77 | concepts_in_sent = self.detect_concepts(sentence=sentence, concepts=concepts) 78 | concepts_in_feedback = set([f.strip() for f in feedback.split(", ")]) 79 | fixed_feedback = concepts_in_feedback.difference(concepts_in_sent) 80 | if len(fixed_feedback) == 0: 81 | return "None" 82 | return ", ".join(fixed_feedback) 83 | 84 | def detect_concepts(self, sentence: str, concepts: List[str]) -> Set[str]: 85 | present_concepts = [] 86 | 87 | # Tokenize the sentence and lemmatize the tokens 88 | tokens = nltk.word_tokenize(sentence) 89 | lemmas = [token.lemma_ for token in nlp(sentence)] 90 | 91 | # Check if each concept is present in the sentence 92 | for concept in concepts: 93 | if concept in tokens or concept in lemmas: 94 | present_concepts.append(concept) 95 | 96 | return set(present_concepts) 97 | 98 | 99 | 100 | 101 | if __name__ == "__main__": 102 | task_feedback = CommongenFeedback( 103 | prompt_examples="data/prompt/commongen/feedback.jsonl", 104 | engine="code-davinci-002" 105 | ) 106 | 107 | print(task_feedback.prompt) 108 | 109 | -------------------------------------------------------------------------------- /src/commongen/make_challenging.py: -------------------------------------------------------------------------------- 1 | import pandas as pd 2 | 3 | df = pd.read_json("/usr1/amadaan/shufflegen/data/original/commongen/val.jsonl", lines=True, orient="records") 4 | 5 | from itertools import chain 6 | all_concepts = set(chain(*df['concepts'].tolist())) 7 | 8 | # challenging data with 20-30 concepts 9 | 10 | import random 11 | random.seed(42) 12 | 13 | n_samples = 200 14 | res = [] 15 | for i in range(n_samples): 16 | k = random.randint(20, 30) 17 | concepts = random.sample(sorted(all_concepts), k=k)  # random.sample needs a sequence; a raw set is rejected on newer Python versions 18 | res.append({"concepts": concepts}) 19 | 20 | pd.DataFrame(res).to_json("data/commongen_very_challenging.jsonl", lines=True, orient="records") 21 | 22 | --------------------------------------------------------------------------------
/src/commongen/run.py: -------------------------------------------------------------------------------- 1 | import pathlib 2 | from tqdm import tqdm 3 | from typing import List 4 | 5 | from src.commongen.task_init import CommongenTaskInit 6 | from src.commongen.task_iterate import CommongenTaskIterate 7 | from src.commongen.feedback import CommongenFeedback 8 | from src.utils import retry_parse_fail_prone_cmd 9 | 10 | CODEX = "code-davinci-002" 11 | GPT3 = "text-davinci-003" 12 | CHATGPT = "gpt-3.5-turbo" 13 | ENGINE = GPT3 14 | 15 | 16 | @retry_parse_fail_prone_cmd 17 | def autofb_commongen(concepts: List[str], max_attempts: int) -> str: 18 | 19 | # initialize all the required components 20 | 21 | # generation of the first sentence 22 | task_init = CommongenTaskInit(engine=ENGINE, prompt_examples="data/prompt/commongen/init.jsonl") 23 | 24 | # getting feedback 25 | task_feedback = CommongenFeedback( 26 | engine=ENGINE, prompt_examples="data/prompt/commongen/feedback.jsonl" 27 | ) 28 | 29 | # iteratively improving the sentence 30 | task_iterate = CommongenTaskIterate( 31 | engine=ENGINE, prompt_examples="data/prompt/commongen/iterate.jsonl" 32 | ) 33 | 34 | # Initialize the task 35 | 36 | n_attempts = 0 37 | 38 | print(f"{n_attempts} INIT> {concepts}") 39 | sent_to_fb = [] 40 | 41 | while n_attempts < max_attempts: 42 | print() 43 | 44 | if n_attempts == 0: 45 | sent = task_init(concepts=concepts) 46 | else: 47 | sent = task_iterate(concepts=concepts, sent_to_fb=sent_to_fb) 48 | 49 | print(f"{n_attempts} GEN> {sent}") 50 | 51 | concept_fb, commonsense_fb = task_feedback(concepts=concepts, sentence=sent) 52 | 53 | sent_to_fb.append( 54 | { 55 | "sentence": sent, 56 | "concept_feedback": [f.strip() for f in concept_fb.split(",")], 57 | "commonsense_feedback": commonsense_fb, 58 | } 59 | ) 60 | print(f"{n_attempts} Concept> {concept_fb} | CommonSense> {commonsense_fb}") 61 | 62 | if concept_fb.lower() == "none" and commonsense_fb.lower() == "none": 63 | break 64 | 65 | n_attempts += 1 66 | 67 | return sent_to_fb 68 | 69 | 70 | def run_cmd(): 71 | concepts = sys.argv[2:] 72 | max_attempts = 5 73 | sent_to_fb = autofb_commongen( 74 | concepts=concepts, 75 | max_attempts=max_attempts, 76 | ) 77 | 78 | res = [] 79 | for s in sent_to_fb: 80 | sent = s["sentence"] 81 | fb = "; ".join(s["concept_feedback"]) + " " + s["commonsense_feedback"] 82 | res.append(f"{sent} ({fb})") 83 | print(" -> ".join(res)) 84 | 85 | 86 | def run_iter(inputs_file_path: str, max_attempts: int = 4): 87 | test_df = pd.read_json(inputs_file_path, lines=True, orient="records") 88 | # add new columns sent_to_fb of type object, and status of type string 89 | 90 | is_rerun = "status" in test_df.columns 91 | if not is_rerun: 92 | test_df["sent_to_fb"] = None 93 | test_df["sent_to_fb"] = test_df["sent_to_fb"].astype(object) 94 | test_df["status"] = None 95 | else: 96 | print("Status column already exists! 
Looks like you're trying to do a re-run") 97 | print(test_df["status"].value_counts()) 98 | for i, row in tqdm(test_df.iterrows(), total=len(test_df), desc="Running autofb iter"): 99 | if row["status"] == "success": 100 | continue 101 | try: 102 | sent_to_fb = autofb_commongen(concepts=row["concepts"], max_attempts=max_attempts) 103 | test_df.loc[i, "sent_to_fb"] = sent_to_fb 104 | test_df.loc[i, "status"] = "success" 105 | except Exception as e: 106 | test_df.loc[i, "sent_to_fb"] = str(e) 107 | test_df.loc[i, "status"] = "error" 108 | 109 | output_path = inputs_file_path + (".iter.out" if not is_rerun else ".v0") 110 | version = 0 111 | while pathlib.Path(output_path).exists(): 112 | output_path = output_path + f".v{version}" 113 | version += 1 114 | 115 | test_df.to_json(output_path, orient="records", lines=True) 116 | 117 | 118 | def run_multi_sample(inputs_file_path: str, n_samples: int = 4): 119 | test_df = pd.read_json(inputs_file_path, lines=True, orient="records") 120 | 121 | is_rerun = "status" in test_df.columns 122 | if not is_rerun: 123 | test_df["outputs"] = None 124 | test_df["outputs"] = test_df["outputs"].astype(object) 125 | test_df["status"] = None 126 | else: 127 | print("Status column already exists! Looks like you're trying to do a re-run") 128 | print(test_df["status"].value_counts()) 129 | 130 | task_init = CommongenTaskInit(engine=ENGINE, prompt_examples="data/prompt/commongen/init.jsonl") 131 | task_feedback = CommongenFeedback( 132 | engine=ENGINE, prompt_examples="data/prompt/commongen/feedback.jsonl" 133 | ) 134 | for i, row in tqdm(test_df.iterrows(), total=len(test_df), desc="Running multisample autofb"): 135 | 136 | if row["status"] == "success": 137 | continue 138 | try: 139 | outputs = [] 140 | for _ in range(n_samples): 141 | sent = task_init(concepts=row["concepts"]) 142 | print(sent) 143 | concept_fb, commonsense_fb = task_feedback(concepts=row["concepts"], sentence=sent) 144 | print(concept_fb, commonsense_fb) 145 | outputs.append( 146 | { 147 | "sentence": sent, 148 | "concept_feedback": [f.strip() for f in concept_fb.split(",")], 149 | "commonsense_feedback": commonsense_fb, 150 | } 151 | ) 152 | if concept_fb.lower() == "none" and commonsense_fb.lower() == "none": 153 | break 154 | test_df.loc[i, "outputs"] = outputs 155 | test_df.loc[i, "status"] = "success" 156 | except Exception as e: 157 | # note: re-raising here would skip the error bookkeeping below 158 | test_df.loc[i, "outputs"] = str(e) 159 | test_df.loc[i, "status"] = "error" 160 | print(test_df) 161 | output_path = inputs_file_path + "."
+ ENGINE + (".multi.out" if not is_rerun else ".v0") 162 | version = 0 163 | while pathlib.Path(output_path).exists(): 164 | output_path = output_path + f".v{version}" 165 | version += 1 166 | 167 | test_df.to_json(output_path, orient="records", lines=True) 168 | 169 | 170 | if __name__ == "__main__": 171 | import sys 172 | import pandas as pd 173 | 174 | if sys.argv[1] == "cmd": 175 | run_cmd() 176 | 177 | elif sys.argv[1] == "batch-iter": 178 | run_iter(inputs_file_path=sys.argv[2]) 179 | 180 | elif sys.argv[1] == "batch-multi": 181 | run_multi_sample(inputs_file_path=sys.argv[2]) 182 | 183 | else: 184 | raise ValueError("Invalid mode: choose between cmd, batch-iter, batch-multi") 185 | 186 | 187 | -------------------------------------------------------------------------------- /src/commongen/task_init.py: -------------------------------------------------------------------------------- 1 | from typing import List 2 | import pandas as pd 3 | from src.utils import Prompt 4 | 5 | from prompt_lib.backends import openai_api 6 | 7 | 8 | class CommongenTaskInit(Prompt): 9 | def __init__(self, prompt_examples: str, engine: str) -> None: 10 | super().__init__( 11 | question_prefix="Concepts: ", 12 | answer_prefix="Sentence: ", 13 | intra_example_sep="\n\n", 14 | inter_example_sep="\n\n###\n\n", 15 | ) 16 | self.engine = engine 17 | self.setup_prompt_from_examples_file(prompt_examples) 18 | 19 | def setup_prompt_from_examples_file(self, examples_path: str) -> str: 20 | TEMPLATE = """Concepts: {concepts} 21 | 22 | Sentence: {sentence}""" 23 | 24 | examples_df = pd.read_json(examples_path, orient="records", lines=True) 25 | prompt = [] 26 | for _, row in examples_df.iterrows(): 27 | prompt.append(TEMPLATE.format(concepts=row["concepts"], sentence=row["target"])) 28 | self.prompt = self.inter_example_sep.join(prompt) 29 | self.prompt = self.prompt + self.inter_example_sep 30 | 31 | def make_query(self, concepts: List[str]) -> str: 32 | 33 | query = f"""{self.question_prefix}{concepts}""" 34 | query = f"{self.prompt}{query}{self.intra_example_sep}" 35 | return query 36 | 37 | def __call__(self, concepts: List[str]) -> str: 38 | generation_query = self.make_query(concepts) + "\nDo your best! It's okay if the sentence is not coherent.\n" 39 | 40 | output = openai_api.OpenaiAPIWrapper.call( 41 | prompt=generation_query, 42 | engine=self.engine, 43 | max_tokens=300, 44 | stop_token="###", 45 | temperature=0.7, 46 | ) 47 | 48 | generated_sent = openai_api.OpenaiAPIWrapper.get_first_response(output) 49 | print(generated_sent) 50 | generated_sent = generated_sent.split(self.answer_prefix)[1].replace("#", "").strip() 51 | return generated_sent.strip() 52 | 53 | 54 | 55 | if __name__ == "__main__": 56 | task_init = CommongenTaskInit( 57 | prompt_examples="data/prompt/commongen/init.jsonl", 58 | engine="code-davinci-002" 59 | ) 60 | 61 | print(task_init.prompt) 62 | # print(task_init.make_query(["a", "b", "c"])) -------------------------------------------------------------------------------- /src/commongen/task_iterate.py: -------------------------------------------------------------------------------- 1 | import re 2 | from typing import Dict, List 3 | from src.utils import Prompt 4 | 5 | from prompt_lib.backends import openai_api 6 | 7 | header = """Concepts: {concepts} 8 | """ 9 | example_template = """Sentence: {sentence} 10 | 11 | what concepts from the concept list are missing from the sentence? 12 | 13 | Concept Feedback: {concept_feedback} 14 | 15 | Any feedback on commonsense?
16 | 17 | Commonsense Feedback: {commonsense_feedback}""" 18 | 19 | instr = """ 20 | 21 | Okay, improve the sentence using the feedback: 22 | 23 | """ 24 | 25 | class CommongenTaskIterate(Prompt): 26 | def __init__(self, engine: str, prompt_examples: str) -> None: 27 | super().__init__( 28 | question_prefix="", 29 | answer_prefix="", 30 | intra_example_sep="\n\n", 31 | inter_example_sep="\n\n###\n\n", 32 | ) 33 | self.engine = engine 34 | self.count = 0 35 | self.prompt = self.make_prompt(prompt_examples=prompt_examples) 36 | 37 | def make_prompt(self, prompt_examples: str) -> str: 38 | import pandas as pd 39 | 40 | prompt_examples = pd.read_json(prompt_examples, orient="records", lines=True) 41 | 42 | prompt = [] 43 | 44 | for example in prompt_examples.to_dict(orient="records"): 45 | prompt.append( 46 | self.make_one_iterate_example( 47 | concepts=example["concepts"], sent_to_fb=example["sentence_to_feedback"] 48 | ) 49 | ) 50 | 51 | return self.inter_example_sep.join(prompt) + self.inter_example_sep 52 | 53 | def make_one_iterate_example(self, concepts: List[str], sent_to_fb: List[Dict]): 54 | """Given a list of examples that are incrementally improving, return a new example.""" 55 | 56 | 57 | 58 | single_example = [] 59 | for example in sent_to_fb: 60 | 61 | single_example.append(example_template.format( 62 | sentence=example["sentence"], commonsense_feedback=example["commonsense_feedback"], concept_feedback=example["concept_feedback"] 63 | )) 64 | 65 | return header.format(concepts=concepts) + instr.join(single_example) 66 | 67 | def make_query(self, concepts: List[str], 68 | sent_to_fb: List[Dict],) -> str: 69 | query_example = self.make_one_iterate_example(concepts=concepts, sent_to_fb=sent_to_fb) 70 | return f"{self.prompt}{self.question_prefix}{query_example}{self.intra_example_sep}{self.answer_prefix}" + instr 71 | # return super().make_query(prompt, question) 72 | 73 | def __call__( 74 | self, 75 | concepts: List[str], 76 | sent_to_fb: List[Dict], 77 | ) -> str: 78 | 79 | transfer_query = self.make_query(concepts=concepts, sent_to_fb=sent_to_fb) 80 | self.count += 1 81 | 82 | output = openai_api.OpenaiAPIWrapper.call( 83 | prompt=transfer_query, 84 | engine=self.engine, 85 | max_tokens=300, 86 | stop_token=self.inter_example_sep, 87 | temperature=0.7, 88 | ) 89 | response = openai_api.OpenaiAPIWrapper.get_first_response(output) 90 | 91 | print("######") 92 | print() 93 | print("------") 94 | print(response) 95 | print("------") 96 | 97 | response = re.search("Sentence: (.*)", response).group(1).strip().split("\n")[0].strip() 98 | 99 | return response.strip() 100 | 101 | 102 | 103 | if __name__ == "__main__": 104 | obj = CommongenTaskIterate( 105 | prompt_examples="data/prompt/commongen/iterate.jsonl", engine="whatever" 106 | ) 107 | print(obj.prompt) 108 | # print(obj.make_query(concepts=["a", "b"], sent_to_fb=[{"sentence": "a", "feedback": "a"}, {"sentence": "b", "feedback": "d"}])) 109 | -------------------------------------------------------------------------------- /src/gsm/__init__.py: --------------------------------------------------------------------------------
https://raw.githubusercontent.com/madaan/self-refine/9a206d41e5d2d0c241bb441f41eeadb945afaa55/src/gsm/__init__.py -------------------------------------------------------------------------------- /src/gsm/feedback.py: -------------------------------------------------------------------------------- 1 | import pandas as pd 2 | from prompt_lib.backends import openai_api 3 | 4 | from src.utils import Prompt 5 | 6 | 7 | class GSMFeedback(Prompt): 8 | def __init__(self, engine: str, prompt_examples: str, temperature: float, max_tokens: int = 600) -> None: 9 | super().__init__( 10 | question_prefix="", 11 | answer_prefix="", 12 | intra_example_sep="\n\n", 13 | inter_example_sep="\n\n### END ###\n\n", 14 | engine = engine, 15 | temperature = temperature 16 | ) 17 | 18 | self.max_tokens = max_tokens 19 | self.instruction = "# There is an error in the code above because of lack of understanding of the question. What is the error? To find the error, go through semantically complete blocks of the code, and check if everything looks good." if "naive" not in prompt_examples else "# There is an error in the code above." 20 | self.setup_prompt_from_examples_file(prompt_examples) 21 | 22 | def setup_prompt_from_examples_file(self, examples_path: str) -> str: 23 | with open(examples_path, "r") as f: 24 | self.prompt = f.read() 25 | 26 | def __call__(self, solution: str): 27 | generation_query = self.make_query(solution=solution) 28 | print(generation_query) 29 | 30 | output = openai_api.OpenaiAPIWrapper.call( 31 | prompt=generation_query, 32 | engine=self.engine, 33 | max_tokens=self.max_tokens, 34 | stop_token="### END", 35 | temperature=self.temperature, 36 | ) 37 | 38 | 39 | entire_output = openai_api.OpenaiAPIWrapper.get_first_response(output) 40 | print(entire_output) 41 | if "### END" in entire_output: 42 | entire_output = entire_output.split("### END")[0] 43 | 44 | improved_soln = entire_output.split("def solution():")[1] 45 | feedback = entire_output.split("def solution():")[0] 46 | improved_soln = "def solution():" + improved_soln.rstrip() 47 | self.update_prompt(solution=solution, improved_soln=improved_soln, feedback=feedback) 48 | return {"solution": improved_soln, "feedback": feedback} 49 | 50 | def make_query(self, solution: str): 51 | 52 | solution = f"""{self.question_prefix}{solution}{self.intra_example_sep}{self.instruction}{self.answer_prefix}""" 53 | return f"{self.prompt}{solution}" 54 | 55 | 56 | def update_prompt(self, solution: str, improved_soln: str, feedback: str): 57 | prefix = f"""{self.question_prefix}{solution}{self.intra_example_sep}{self.instruction}{self.answer_prefix}""" 58 | 59 | gen_ans = f""" 60 | 61 | {feedback} 62 | 63 | {improved_soln.rstrip()}{self.inter_example_sep}""" 64 | 65 | new_example = f"{prefix}{gen_ans}" 66 | self.prompt = f"{self.prompt}{new_example}" 67 | 68 | 69 | def test(): 70 | task_fb = GSMFeedback( 71 | prompt_examples="data/prompt/gsm/feedback.txt", 72 | engine="gpt-3.5-turbo", 73 | temperature=0.7, 74 | ) 75 | 76 | wrong_soln = """def solution(): 77 | \"\"\"Twenty dozen cups cost $1200 less than the total cost of half a dozen plates sold at $6000 each.
Calculate the total cost of buying each cup.\"\"\" 78 | plates = 6 79 | plate_cost = 6000 80 | cups = 12 * 20 81 | cup_cost = (plates * plate_cost) / cups - 1200 82 | result = cup_cost 83 | return result""" 84 | feedback_and_solution = task_fb(wrong_soln) 85 | print(feedback_and_solution["feedback"]) 86 | print(feedback_and_solution["solution"]) 87 | 88 | print(task_fb.prompt) 89 | 90 | 91 | if __name__ == '__main__': 92 | test() -------------------------------------------------------------------------------- /src/gsm/feedback_no_update.py: -------------------------------------------------------------------------------- 1 | import pandas as pd 2 | from prompt_lib.backends import openai_api 3 | 4 | from src.utils import Prompt 5 | 6 | 7 | class GSMFeedback(Prompt): 8 | def __init__(self, engine: str, prompt_examples: str, temperature: float, max_tokens: int = 300) -> None: 9 | super().__init__( 10 | question_prefix="", 11 | answer_prefix="", 12 | intra_example_sep="\n\n", 13 | inter_example_sep="\n\n### END ###\n\n", 14 | engine = engine, 15 | temperature = temperature 16 | ) 17 | self.max_tokens = max_tokens 18 | self.instruction = "# There is an error in the code above because of lack of understanding of the question. What is the error? To find the error, go through semantically complete blocks of the code, and check if everything looks good." 19 | self.setup_prompt_from_examples_file(prompt_examples) 20 | 21 | def setup_prompt_from_examples_file(self, examples_path: str) -> str: 22 | with open(examples_path, "r") as f: 23 | self.prompt = f.read() 24 | 25 | def __call__(self, solution: str): 26 | generation_query = self.make_query(solution=solution) 27 | output = openai_api.OpenaiAPIWrapper.call( 28 | prompt=generation_query, 29 | engine=self.engine, 30 | max_tokens=self.max_tokens, 31 | stop_token="### END", 32 | temperature=self.temperature, 33 | ) 34 | 35 | entire_output = openai_api.OpenaiAPIWrapper.get_first_response(output) 36 | if "### END" in entire_output: 37 | entire_output = entire_output.split("### END")[0] 38 | solution = entire_output.split("def solution():")[1] 39 | feedback = entire_output.split("def solution():")[0] 40 | solution = "def solution():" + solution.rstrip() 41 | return {"solution": solution, "feedback": feedback} 42 | 43 | def make_query(self, solution: str): 44 | solution = f"""{self.question_prefix}{solution}{self.intra_example_sep}{self.instruction}{self.answer_prefix}""" 45 | return f"{self.prompt}{solution}" 46 | 47 | 48 | def test(): 49 | task_fb = GSMFeedback( 50 | prompt_examples="data/prompt/gsm/feedback.txt", 51 | engine="code-davinci-002", 52 | temperature=0.7, 53 | ) 54 | 55 | wrong_soln = """def solution(): 56 | \"\"\"Twenty dozen cups cost $1200 less than the total cost of half a dozen plates sold at $6000 each. 
Calculate the total cost of buying each cup.\"\"\" 57 | plates = 6 58 | plate_cost = 6000 59 | cups = 12 * 20 60 | cup_cost = (plates * plate_cost) / cups - 1200 61 | result = cup_cost 62 | return result""" 63 | feedback_and_solution = task_fb(wrong_soln) 64 | print(feedback_and_solution["feedback"]) 65 | print(feedback_and_solution["solution"]) 66 | 67 | 68 | if __name__ == '__main__': 69 | test() 70 | -------------------------------------------------------------------------------- /src/gsm/gsm_selfref_eval.py: -------------------------------------------------------------------------------- 1 | from importlib import reload 2 | import pandas as pd 3 | from tqdm import tqdm 4 | from contextlib import contextmanager 5 | import signal 6 | from glob import glob 7 | import os 8 | 9 | # from https://stackoverflow.com/questions/492519/timeout-on-a-function-call 10 | @contextmanager 11 | def timeout(duration): 12 | def timeout_handler(signum, frame): 13 | raise TimeoutError(f"block timedout after {duration} seconds") 14 | 15 | signal.signal(signal.SIGALRM, timeout_handler) 16 | signal.alarm(duration) 17 | try: 18 | yield 19 | finally: 20 | signal.alarm(0) 21 | 22 | def read_json(path): 23 | import json 24 | rows = [] 25 | with open(path, "r") as f: 26 | for line in f: 27 | rows.append(json.loads(line)) 28 | 29 | task_df = pd.DataFrame(rows) 30 | return task_df 31 | 32 | def evaluate_code_prompt(path, num_gsm: int = 1319): 33 | data = read_json(path) 34 | if "question" not in data.columns: 35 | data["question"] = data["input"] 36 | if "answer" not in data.columns: 37 | data["answer"] = data["target"] 38 | 39 | attempt_to_acc = [] 40 | reports = [] # Step 1 41 | for idx, row in tqdm(data.iterrows(), total=len(data)): 42 | # if idx < 20: 43 | # continue 44 | # if idx > 10: 45 | # break 46 | attempt_to_acc_ = {i: 0 for i in range(5)} 47 | attempt_to_acc_["question"] = row["question"] 48 | solutions = [] 49 | if row["run_logs"] is None: 50 | continue 51 | for _, log in enumerate(row["run_logs"]): 52 | solutions.append(log["solution_curr"]) 53 | solutions.append(row["run_logs"][-1]["solution_fixed"]) 54 | 55 | feedback = [rec["feedback"] for rec in row["run_logs"]] 56 | 57 | prev_accuracy = 0 58 | for iter_idx, soln in enumerate(solutions): 59 | soln = soln.split("\n\n\n")[0].strip() + "\n" 60 | soln = soln.replace("The answer is", "").strip() + "\n" 61 | os.system("rm -rf __pycache__") 62 | os.system("rm -f temp_result.pyc") 63 | 64 | with open("temp_result.py", "w") as f: 65 | f.write(soln) 66 | 67 | try: 68 | import temp_result 69 | reload(temp_result) 70 | correct_solution = str(row["answer"]) 71 | 72 | exec(soln) 73 | with timeout(1): 74 | result = str(temp_result.solution()) 75 | is_corr = check_corr(result, correct_solution) 76 | 77 | 78 | is_corr = int(is_corr) 79 | # Step 2 80 | 81 | if iter_idx > 0 and is_corr == 1 and prev_accuracy == 0: 82 | report = { 83 | "previous_solution": solutions[iter_idx - 1], 84 | "feedback": feedback[iter_idx - 1], 85 | "next_solution": solutions[iter_idx], 86 | } 87 | reports.append(report) # Step 3 88 | if is_corr == 1: 89 | for j in range(iter_idx, 5): 90 | attempt_to_acc_[j] = 1 91 | break 92 | attempt_to_acc_[iter_idx] = 0 93 | prev_accuracy = is_corr 94 | except Exception as e: 95 | continue 96 | 97 | attempt_to_acc.append(attempt_to_acc_) 98 | 99 | df = pd.DataFrame(attempt_to_acc) 100 | 101 | # print(attempt_to_acc) 102 | for i in range(5): 103 | print(f"Accuracy at attempt {i} = {df[i].sum() / num_gsm:.2%} ({df[i].sum()}/{num_gsm})") 104 | 105 | 
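# Note: accuracy at attempt k is cumulative; once an attempt is correct, every later attempt is also counted as correct.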
df.to_json("/tmp/attempt_to_acc.jsonl", orient="records", lines=True) 106 | 107 | report_file = f"{path}.reports.txt" 108 | print_reports(reports, report_file) # Step 4 109 | return reports 110 | 111 | # Step 4 112 | def print_reports(reports, report_file): 113 | 114 | 115 | with open(report_file, "w") as f: 116 | for i, report in enumerate(reports): 117 | f.write(f"Report {i + 1}:\n") 118 | f.write("\nPrevious solution:\n") 119 | f.write(report["previous_solution"]) 120 | f.write("\n\nFeedback:\n") 121 | f.write(report["feedback"]) 122 | f.write("\n\nNext solution:\n") 123 | f.write(report["next_solution"]) 124 | f.write("\n\n" + "=" * 80 + "\n\n") 125 | def check_corr(result: float, correct_solution: float, tol: float = 1e-3): 126 | if result.strip() == correct_solution.strip(): 127 | return 1 128 | try: 129 | result = float(result.strip()) 130 | correct_solution = float(correct_solution.strip()) 131 | return abs(result - correct_solution) < tol 132 | except: 133 | return 0 134 | 135 | 136 | 137 | if __name__ == "__main__": 138 | import argparse 139 | 140 | parser = argparse.ArgumentParser() 141 | parser.add_argument("--path", type=str, default="data/quco/quco_test.jsonl") 142 | args = parser.parse_args() 143 | 144 | evaluate_code_prompt(args.path) 145 | -------------------------------------------------------------------------------- /src/gsm/run.py: -------------------------------------------------------------------------------- 1 | import pandas as pd 2 | from tqdm import tqdm 3 | 4 | 5 | from src.gsm.task_init import GSMInit 6 | from src.gsm.feedback import GSMFeedback 7 | 8 | from src.utils import retry_parse_fail_prone_cmd 9 | 10 | CODEX = "code-davinci-002" 11 | # GPT3 = "text-davinci-003" 12 | ENGINE = CODEX 13 | 14 | 15 | @retry_parse_fail_prone_cmd 16 | def iterative_gsm(question: str, max_attempts: int, feedback_type: str, temperature: float): 17 | 18 | # initialize all the required components 19 | 20 | # generation of the first fast version 21 | task_init = GSMInit(engine=ENGINE, prompt_examples="data/prompt/gsm/init.txt", temperature=temperature) 22 | 23 | # getting feedback 24 | if feedback_type == "naive": 25 | raise NotImplementedError 26 | else: 27 | task_feedback = GSMFeedback(engine=ENGINE, prompt_examples="data/prompt/gsm/feedback.txt", temperature=0.7) 28 | 29 | 30 | n_attempts = 0 31 | 32 | log = [] 33 | 34 | while n_attempts < max_attempts: 35 | 36 | if n_attempts == 0: 37 | solution = task_init(solution=question) 38 | 39 | fb_and_maybe_soln = task_feedback(solution=solution) 40 | 41 | 42 | log.append({"attempt": n_attempts, "solution_curr": solution, "solution_fixed": fb_and_maybe_soln["solution"], "feedback": fb_and_maybe_soln["feedback"]}) 43 | 44 | if "it is correct" in fb_and_maybe_soln["feedback"].lower(): 45 | break 46 | 47 | solution = fb_and_maybe_soln["solution"] 48 | 49 | n_attempts += 1 50 | 51 | return log 52 | 53 | 54 | def fix_gsm(gsm_task_file: str, max_attempts: int, outfile: str, feedback_type: str, temperature: float): 55 | 56 | 57 | slow_programs_df = pd.read_json(gsm_task_file, lines=True, orient="records") 58 | slow_programs_df["run_logs"] = None 59 | results = [] 60 | for i, row in tqdm(slow_programs_df.iterrows(), total=len(slow_programs_df)): 61 | row_copy = row.to_dict() 62 | try: 63 | run_logs = iterative_gsm(question=row["input"], max_attempts=max_attempts, feedback_type=feedback_type, temperature=temperature) 64 | row_copy["run_logs"] = run_logs 65 | row_copy["generated_answer_ours"] = run_logs[-1]["solution_fixed"] 66 | 
row_copy["generated_answer_direct"] = run_logs[0]["solution_curr"] 67 | results.append(row_copy) 68 | if i % 10 == 0: 69 | pd.DataFrame(results).to_json(outfile + f".{i}.jsonl", orient="records", lines=True) 70 | except Exception as e: 71 | # raise e 72 | pass 73 | pd.DataFrame(results).to_json(outfile, orient="records", lines=True) 74 | return results 75 | 76 | 77 | def test(): 78 | import json 79 | 80 | 81 | with open("/tmp/debug_gsm.jsonl", "w") as fout: 82 | fout.write(json.dumps({"input": "Twenty dozen cups cost $1200 less than the total cost of half a dozen plates sold at $6000 each. Calculate the total cost of buying each cup."})) 83 | 84 | logs = fix_gsm( 85 | gsm_task_file="/tmp/debug_gsm.jsonl", max_attempts=3, outfile="/tmp/test.jsonl", feedback_type="rich", temperature=0.0 86 | ) 87 | for i, log in enumerate(logs): 88 | print(log["generated_answer_ours"]) 89 | print(log["generated_answer_direct"]) 90 | 91 | 92 | if __name__ == "__main__": 93 | import sys 94 | 95 | if sys.argv[1] == "test": 96 | test() 97 | else: 98 | import argparse 99 | args = argparse.ArgumentParser() 100 | args.add_argument("--gsm_task_file", type=str, default="data/tasks/gsm/gsm.jsonl") 101 | args.add_argument("--max_attempts", type=int, default=4) 102 | args.add_argument("--outfile", type=str, default="data/tasks/gsm/gsm_outputs.jsonl") 103 | args.add_argument("--feedback_type", type=str, default="rich") 104 | args.add_argument("--temperature", type=float, default=0.0) 105 | args = args.parse_args() 106 | args.outfile = f"{args.outfile}.fb_{args.feedback_type}.temp_{args.temperature}.engine_{ENGINE}.jsonl" 107 | fix_gsm(gsm_task_file=args.gsm_task_file, max_attempts=args.max_attempts, outfile=args.outfile, feedback_type=args.feedback_type, temperature=args.temperature) -------------------------------------------------------------------------------- /src/gsm/task_init.py: -------------------------------------------------------------------------------- 1 | import pandas as pd 2 | from src.utils import Prompt 3 | 4 | from prompt_lib.backends import openai_api 5 | 6 | 7 | class GSMInit(Prompt): 8 | def __init__(self, prompt_examples: str, engine: str, temperature: float) -> None: 9 | super().__init__( 10 | question_prefix="# Q: ", 11 | answer_prefix="# solution using Python:\n", 12 | intra_example_sep="\n", 13 | inter_example_sep="\n\n", 14 | engine=engine, 15 | temperature=temperature, 16 | ) 17 | self.setup_prompt_from_examples_file(prompt_examples) 18 | 19 | def setup_prompt_from_examples_file(self, prompt_examples) -> str: 20 | with open(prompt_examples, "r") as f: 21 | self.prompt = f.read() 22 | 23 | def make_query(self, solution: str) -> str: 24 | solution = solution.strip() 25 | query = f"{self.prompt}{self.question_prefix}{solution}{self.intra_example_sep}{self.answer_prefix}" 26 | return query 27 | 28 | def __call__(self, solution: str) -> str: 29 | generation_query = self.make_query(solution) 30 | output = openai_api.OpenaiAPIWrapper.call( 31 | prompt=generation_query, 32 | engine=self.engine, 33 | max_tokens=300, 34 | stop_token=self.inter_example_sep, 35 | temperature=self.temperature, 36 | ) 37 | 38 | solution_code = openai_api.OpenaiAPIWrapper.get_first_response(output) 39 | 40 | return solution_code.strip() 41 | 42 | 43 | def test(): 44 | task_init = GSMInit( 45 | prompt_examples="data/prompt/gsm/init.txt", 46 | engine="code-davinci-002", 47 | temperature=0.0, 48 | ) 49 | 50 | question = "The educational shop is selling notebooks for $1.50 each and a ballpen at $0.5 each. 
William bought five notebooks and a ballpen. How much did he spend in all?" 51 | print(task_init(question)) 52 | 53 | 54 | if __name__ == "__main__": 55 | test() -------------------------------------------------------------------------------- /src/pie/feedback.py: -------------------------------------------------------------------------------- 1 | import pandas as pd 2 | from prompt_lib.backends import openai_api 3 | 4 | from src.utils import Prompt 5 | 6 | 7 | class PieFeedback(Prompt): 8 | def __init__(self, engine: str, prompt_examples: str, temperature: float, max_tokens: int = 300) -> None: 9 | super().__init__( 10 | question_prefix="", 11 | answer_prefix="# Why is this code slow?\n", 12 | intra_example_sep="\n\n", 13 | inter_example_sep="\n\n### END ###\n\n", 14 | ) 15 | self.engine = engine 16 | self.max_tokens = max_tokens 17 | self.temperature = temperature 18 | self.setup_prompt_from_examples_file(prompt_examples) 19 | 20 | def setup_prompt_from_examples_file(self, examples_path: str) -> str: 21 | with open(examples_path, "r") as f: 22 | self.prompt = f.read() 23 | 24 | def __call__(self, slow_code: str): 25 | generation_query = self.make_query(slow_code=slow_code) 26 | 27 | output = openai_api.OpenaiAPIWrapper.call( 28 | prompt=generation_query, 29 | engine=self.engine, 30 | max_tokens=self.max_tokens, 31 | stop_token="### END", 32 | temperature=self.temperature, 33 | ) 34 | 35 | generated_feedback = openai_api.OpenaiAPIWrapper.get_first_response(output) 36 | if "### END" in generated_feedback: 37 | generated_feedback = generated_feedback.split("### END")[0] 38 | return generated_feedback.strip() 39 | 40 | def make_query(self, slow_code: str): 41 | slow_code = f"""{self.question_prefix}{slow_code}{self.intra_example_sep}{self.answer_prefix}""" 42 | return f"{self.prompt}{slow_code}" 43 | 44 | 45 | def test(): 46 | task_fb = PieFeedback( 47 | prompt_examples="data/prompt/pie/feedback.txt", 48 | engine="gpt-3.5-turbo", 49 | temperature=0.0 50 | ) 51 | 52 | print(task_fb.prompt) 53 | slow_code = "def sum(n):\\n res = 0\\n for i in range(n):\\n res += i\\n return res" 54 | print(task_fb(slow_code)) 55 | 56 | 57 | if __name__ == '__main__': 58 | test() -------------------------------------------------------------------------------- /src/pie/prep_for_pie_eval.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import pandas as pd 3 | 4 | """Parses the self-refine outputs, extracts the output code from each attempt in a new column, and saves the results to a JSON file.""" 5 | 6 | def extract_attempt_codes(self_refine_output_path, 7 | flattened_output_path, num_attempts): 8 | """This function creates a file where each attempt/output at each step is stored in a new column: attempt_0_code, attempt_1_code, etc. 9 | 10 | Args: 11 | self_refine_output_path: path to the self-refine run logs (JSONL). 12 | flattened_output_path: path where the flattened records are written. 13 | num_attempts: number of refinement attempts to extract. 14 | """ 15 | outputs = pd.read_json(self_refine_output_path, orient="records", lines=True) 16 | rows = [] 17 | 18 | for _, row in outputs.iterrows(): 19 | # Convert the row to a dictionary. 20 | tmp = row.to_dict() 21 | # Extract the code from each attempt and store it in the temporary dictionary. 22 | for i in range(num_attempts): 23 | if len(row["run_logs"]) <= i: 24 | tmp[f"attempt_{i}_code"] = "" 25 | else: 26 | tmp[f"attempt_{i}_code"] = row["run_logs"][i]["fast_code"] 27 | 28 | rows.append(tmp) 29 | # Convert the rows list to a DataFrame and save it to a JSON file.
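# Illustrative flattened record (hypothetical values): {"submission_id_v0": "s123", "attempt_0_code": "def solution(): ...", "attempt_1_code": "...", ...}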
30 | pd.DataFrame(rows).to_json(flattened_output_path, orient="records", lines=True) 31 | 32 | # Main execution of the script. 33 | if __name__ == "__main__": 34 | # Initialize argument parser. 35 | parser = argparse.ArgumentParser(description="Generate Yaml and Extract Codes") 36 | # Define expected arguments. 37 | parser.add_argument("model", type=str, help="Model name") 38 | parser.add_argument("input_file", type=str, help="Path to the input JSON file") 39 | parser.add_argument("base_config_path", type=str, help="Base path for config files") 40 | parser.add_argument("base_output_path", type=str, help="Base path for output files") 41 | parser.add_argument("--num_attempts", type=int, default=4, help="Number of attempts") 42 | # Parse provided arguments. 43 | args = parser.parse_args() 44 | 45 | # Construct the output file path and extract the codes. 46 | output_path = f"{args.base_output_path}/{args.model}/output.attempt_codes" 47 | extract_attempt_codes( 48 | self_refine_output_path=args.input_file, 49 | flattened_output_path=output_path, 50 | num_attempts=args.num_attempts, 51 | ) 52 | -------------------------------------------------------------------------------- /src/pie/run.py: -------------------------------------------------------------------------------- 1 | import pandas as pd 2 | from tqdm import tqdm 3 | 4 | 5 | from src.pie.task_init import PieInit 6 | from src.pie.task_iterate import PieIterate 7 | from src.pie.feedback import PieFeedback 8 | 9 | from src.utils import retry_parse_fail_prone_cmd 10 | 11 | CODEX = "code-davinci-002" 12 | GPT3 = "text-davinci-003" 13 | CHATGPT = "gpt-3.5-turbo" 14 | GPT4 = "gpt-4" 15 | ENGINE = CHATGPT 16 | 17 | 18 | @retry_parse_fail_prone_cmd 19 | def iterative_pie(slow_code: str, max_attempts: int, feedback_type: str, temperature: float): 20 | 21 | # initialize all the required components 22 | 23 | # generation of the first fast version 24 | task_init = PieInit(engine=ENGINE, prompt_examples="data/prompt/pie/init.txt", temperature=temperature) 25 | 26 | iterate_prompt = "data/prompt/pie/iterate.txt" 27 | # getting feedback 28 | if feedback_type == "naive": 29 | task_feedback = lambda **kwargs: "It could be faster" 30 | iterate_prompt = "data/prompt/pie/iterate_genericfb.txt" 31 | 32 | elif feedback_type == "none": 33 | task_feedback = lambda **kwargs: "" 34 | iterate_prompt = "data/prompt/pie/iterate_nofb.txt" 35 | 36 | else: 37 | task_feedback = PieFeedback(engine=ENGINE, prompt_examples="data/prompt/pie/feedback.txt", temperature=temperature) 38 | 39 | # iteratively improving the code 40 | task_iterate = PieIterate(engine=ENGINE, prompt_examples=iterate_prompt, temperature=temperature) 41 | 42 | # Initialize the task 43 | 44 | n_attempts = 0 45 | 46 | log = [] 47 | feedback = None 48 | 49 | while n_attempts < max_attempts: 50 | 51 | if n_attempts == 0: 52 | fast_code = task_init(slow_code=slow_code) 53 | else: 54 | fast_code = task_iterate(slow_code=slow_code, feedback=feedback) 55 | 56 | # feedback = task_feedback(slow_code=slow_code) 57 | feedback = task_feedback(slow_code=fast_code) 58 | 59 | log.append({"fast_code": fast_code, "feedback": feedback, "slow_code": slow_code, "attempt": n_attempts}) 60 | show_example(**log[-1]) 61 | 62 | if "this code is not slow" in feedback.lower(): 63 | break 64 | 65 | slow_code = fast_code 66 | 67 | n_attempts += 1 68 | 69 | return log 70 | 71 | 72 | def show_example(**kwargs): 73 | # shows {"fast_code": fast_code, "feedback": feedback, "slow_code": slow_code, "attempt": n_attempts} 74 | print(f"SLOW 
CODE:\n{kwargs['slow_code']}\n") 75 | print(f"\n\nFEEDBACK:\n{kwargs['feedback']}\n") 76 | print(f"\n\nFAST CODE:\n{kwargs['fast_code']}\n") 77 | print("-" * 100) 78 | 79 | def run_over_slow_programs(slow_programs_file: str, max_attempts: int, outfile: str, feedback_type: str, temperature: float, backup_file: str = None): 80 | 81 | slow_programs_df = pd.read_json(slow_programs_file, lines=True, orient="records") 82 | slow_programs_df["run_logs"] = None 83 | 84 | if backup_file: 85 | backup_df = pd.read_json(backup_file, lines=True, orient="records") 86 | processed_inputs = set(backup_df["submission_id_v0"].tolist()) 87 | results = backup_df.to_dict("records") 88 | else: 89 | processed_inputs = set() 90 | results = [] 91 | 92 | for i, row in tqdm(slow_programs_df.iterrows(), total=len(slow_programs_df)): 93 | if row["submission_id_v0"] in processed_inputs: 94 | continue 95 | 96 | row_copy = row.to_dict() 97 | try: 98 | run_logs = iterative_pie(slow_code=row["input"], max_attempts=max_attempts, feedback_type=feedback_type, temperature=temperature) 99 | print(run_logs) 100 | row_copy["run_logs"] = run_logs 101 | results.append(row_copy) 102 | if i % 20 == 0: 103 | pd.DataFrame(results).to_json(outfile + f".{i}.jsonl", orient="records", lines=True) 104 | except Exception as e: 105 | # raise e 106 | pass 107 | pd.DataFrame(results).to_json(outfile, orient="records", lines=True) 108 | return run_logs 109 | 110 | 111 | 112 | def test(): 113 | import json 114 | slow_code = "def sum(n):\\n res = 0\\n for i in range(n):\\n res += i\\n return res" 115 | with open("/tmp/debug_pie.jsonl", "w") as f: 116 | f.write(json.dumps({"input": slow_code, "submission_id_v0": "debug_0"}) + "\n") 117 | logs = run_over_slow_programs( 118 | slow_programs_file="/tmp/debug_pie.jsonl", max_attempts=3, outfile="/tmp/test.jsonl", feedback_type="rich", temperature=0.0 119 | ) 120 | for attempt in logs: 121 | print(f"Slow code:\n {attempt['slow_code']}") 122 | print(f"Feedback: {attempt['feedback']}") 123 | print(f"Fast code:\n {attempt['fast_code']}") 124 | print() 125 | 126 | if __name__ == "__main__": 127 | import sys 128 | 129 | if sys.argv[1] == "test": 130 | test() 131 | else: 132 | import argparse 133 | import os 134 | args = argparse.ArgumentParser() 135 | args.add_argument("--slow_programs_file", type=str, required=True) 136 | args.add_argument("--max_attempts", type=int, default=3) 137 | args.add_argument("--outfile", type=str, required=True) 138 | args.add_argument("--feedback_type", type=str) 139 | args.add_argument("--temperature", type=float, default=0.0) 140 | args.add_argument("--backup_file", type=str) 141 | 142 | args = args.parse_args() 143 | args.outfile = f"{args.outfile}.fb_{args.feedback_type}.temp_{args.temperature}.engine_{ENGINE}.jsonl" 144 | if os.path.exists(args.outfile): 145 | 146 | v = 0 147 | while os.path.exists(args.outfile + f".v{v}"): 148 | v += 1 149 | args.outfile = args.outfile + f".v{v}" 150 | print(f"Output file {args.outfile} already exists.
-------------------------------------------------------------------------------- /src/pie/task_init.py: --------------------------------------------------------------------------------
1 | import pandas as pd
2 | from src.utils import Prompt
3 | 
4 | from prompt_lib.backends import openai_api
5 | 
6 | 
7 | class PieInit(Prompt):
8 |     def __init__(self, prompt_examples: str, engine: str, temperature: float) -> None:
9 |         super().__init__(
10 |             question_prefix="# slower version:\n",
11 |             answer_prefix="# optimized version of the same code:\n",
12 |             intra_example_sep="\n\n\n",
13 |             inter_example_sep="\n\n### END ###\n\n",
14 |         )
15 |         self.engine = engine
16 |         self.temperature = temperature
17 |         self.setup_prompt_from_examples_file(prompt_examples)
18 | 
19 |     def setup_prompt_from_examples_file(self, prompt_examples) -> None:
20 |         with open(prompt_examples, "r") as f:
21 |             self.prompt = f.read()
22 | 
23 |     def make_query(self, slow_code: str) -> str:
24 |         slow_code = slow_code.strip()
25 |         query = f"{self.prompt}{self.question_prefix}{slow_code}{self.intra_example_sep}{self.answer_prefix}"
26 |         return query
27 | 
28 |     def __call__(self, slow_code: str) -> str:
29 |         generation_query = self.make_query(slow_code)
30 |         output = openai_api.OpenaiAPIWrapper.call(
31 |             prompt=generation_query,
32 |             engine=self.engine,
33 |             max_tokens=300,
34 |             stop_token="### END",
35 |             temperature=self.temperature,
36 |         )
37 | 
38 |         generated_code = openai_api.OpenaiAPIWrapper.get_first_response(output)
39 | 
40 |         # if by chance the end token is present in the generated code, remove it
41 |         if "### END" in generated_code:
42 |             generated_code = generated_code.split("### END")[0]
43 |         return generated_code.strip()
44 | 
45 | 
46 | def test():
47 |     task_init = PieInit(
48 |         prompt_examples="data/prompt/pie/init.txt",
49 |         engine="gpt-3.5-turbo",
50 |         temperature=0.0
51 |     )
52 | 
53 |     slow_code = """
54 | def sum(n):
55 |     res = 0
56 |     for i in range(n):
57 |         res += i
58 |     return res
59 | """
60 |     single_line_ip = "def sum(n): res = 0; for i in range(n): res += i; return res\n\n# Optimize the above program for faster performance\n\n |||"  # unused single-line variant, kept for reference
61 |     print(task_init.prompt)
62 |     print(task_init(slow_code))
63 | 
64 | 
65 | if __name__ == "__main__":
66 |     test()
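# Illustration of the query layout that PieInit.make_query builds (the real
# few-shot prompt comes from data/prompt/pie/init.txt; this standalone sketch
# just mirrors the class defaults above):
def sketch_pie_query(prompt: str, slow_code: str) -> str:
    question_prefix = "# slower version:\n"
    answer_prefix = "# optimized version of the same code:\n"
    intra_example_sep = "\n\n\n"
    # few-shot examples, then the new slow program, then the answer cue
    return f"{prompt}{question_prefix}{slow_code.strip()}{intra_example_sep}{answer_prefix}"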
-------------------------------------------------------------------------------- /src/pie/task_iterate.py: --------------------------------------------------------------------------------
1 | import sys
2 | from typing import Dict, List
3 | from src.utils import Prompt
4 | 
5 | from prompt_lib.backends import openai_api
6 | 
7 | 
8 | class PieIterate(Prompt):
9 |     def __init__(self, engine: str, prompt_examples: str, temperature: float, feedback_type: str = "default") -> None:
10 |         super().__init__(
11 |             question_prefix="",
12 |             answer_prefix="# Improved version:\n",
13 |             intra_example_sep="\n\n",
14 |             inter_example_sep="\n\n### END ###\n\n",
15 |         )
16 |         self.engine = engine
17 |         self.count = 0
18 |         self.temperature = temperature
19 |         self.make_prompt(prompt_examples=prompt_examples)
20 |         self.feedback_type = feedback_type
21 | 
22 |     def make_prompt(self, prompt_examples: str) -> None:
23 |         with open(prompt_examples, "r") as f:
24 |             self.prompt = f.read()
25 | 
26 |     def __call__(
27 |         self,
28 |         slow_code: str,
29 |         feedback: str,
30 |     ) -> str:
31 |         generation_query = self.make_query(slow_code=slow_code, feedback=feedback)
32 | 
33 |         output = openai_api.OpenaiAPIWrapper.call(
34 |             prompt=generation_query,
35 |             engine=self.engine,
36 |             max_tokens=300,
37 |             stop_token="### END",
38 |             temperature=self.temperature,
39 |         )
40 |         generated_code = openai_api.OpenaiAPIWrapper.get_first_response(output)
41 | 
42 |         # strip the stop marker if it leaks into the generation
43 |         if "### END" in generated_code:
44 |             generated_code = generated_code.split("### END")[0]
45 |         return generated_code.strip()
46 | 
47 |     def make_query(self, slow_code: str, feedback: str) -> str:
48 |         instr = "# Why is this code slow?" if self.feedback_type == "default" else "# How to improve this code?"
49 |         example_template = """{slow_code}
50 | 
51 | {instr}
52 | 
53 | {feedback}
54 | 
55 | # Improved version:
56 | 
57 | """
58 |         # instr has to be supplied as well; otherwise .format raises a KeyError
59 |         query = example_template.format(slow_code=slow_code, instr=instr, feedback=feedback)
60 | 
61 |         return f"{self.prompt}{query}"
62 | 
63 | 
64 | def test():
65 |     task_iterate = PieIterate(
66 |         prompt_examples="data/prompt/pie/iterate.txt",
67 |         engine="gpt-3.5-turbo",
68 |         temperature=0.6
69 |     )
70 | 
71 |     slow_code = "def sum(n):\\n    res = 0\\n    for i in range(n):\\n        res += i\\n    return res"
72 |     feedback = "# This code is slow because it is using a brute force approach to calculate the sum of numbers up to n. It is looping through every number from 0 to n and adding it to the sum. This is a very slow approach because it has to loop through so many numbers. A better approach would be to use the formula for the sum of the numbers from 0 to n, which is (n(n+1))/2. Using this formula, you can calculate the sum of the numbers in constant time."
73 |     # print(task_iterate.prompt)
74 |     print(task_iterate(slow_code, feedback))
75 | 
76 | if __name__ == '__main__':
77 |     test()
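# Illustration of one rendered refinement example under the default feedback
# type, mirroring the template in make_query above (names here are assumed):
def sketch_iterate_example(slow_code: str, feedback: str) -> str:
    instr = "# Why is this code slow?"
    return f"{slow_code}\n\n{instr}\n\n{feedback}\n\n# Improved version:\n\n"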
-------------------------------------------------------------------------------- /src/readability/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/madaan/self-refine/9a206d41e5d2d0c241bb441f41eeadb945afaa55/src/readability/__init__.py -------------------------------------------------------------------------------- /src/readability/count_comment.py: --------------------------------------------------------------------------------
1 | import json
2 | import tokenize
3 | 
4 | from tqdm import tqdm
5 | from io import BytesIO
6 | from argparse import ArgumentParser
7 | 
8 | 
9 | def count_comments(code):
10 |     comment_count = 0
11 |     total_lines = len([l for l in code.splitlines() if l.strip()])
12 | 
13 |     tokens = tokenize.tokenize(BytesIO(code.encode('utf-8')).readline)
14 |     for token in tokens:
15 |         if token.type == tokenize.COMMENT:
16 |             comment_count += 1
17 |     return comment_count, comment_count / total_lines
18 | 
19 | def main():
20 |     parser = ArgumentParser()
21 |     parser.add_argument('--file', type=str)
22 |     args = parser.parse_args()
23 | 
24 |     output = args.file[:-6] + '_comment_count.jsonl'  # assumes a ".jsonl" input path
25 | 
26 |     score_counter = None
27 | 
28 |     with open(output, 'w') as fout:
29 |         input_lines = open(args.file).readlines()
30 |         for line in tqdm(input_lines, total=len(input_lines)):
31 |             data = json.loads(line)
32 |             original_code = data['original_code']
33 |             updated_codes = [x['updated_code'] for x in data['updates']]
34 | 
35 |             if score_counter is None:
36 |                 score_counter = [[] for _ in range(len(updated_codes) + 1)]
37 | 
38 |             data['update_comment_count'] = []
39 |             data['update_comment_ratio'] = []
40 | 
41 |             for i, code in enumerate([original_code] + updated_codes):
42 |                 num_comments, ratio = -1, -1
43 | 
44 |                 if code:
45 |                     try:
46 |                         num_comments, ratio = count_comments(code)
47 |                     except Exception:  # untokenizable code; keep the -1 sentinels
48 |                         pass
49 | 
50 |                 data['update_comment_count'].append(num_comments)
51 |                 data['update_comment_ratio'].append(ratio)
52 | 
53 |                 score_counter[i].append(max(0, ratio))
54 |             fout.write(json.dumps(data) + '\n')
55 | 
56 |     for i, scores in enumerate(score_counter):
57 |         print(f'Update {i}: avg comment ratio {sum(scores) / len(scores)}')
58 | 
59 | if __name__ == '__main__':
60 |     main()
61 | 
62 | 
-------------------------------------------------------------------------------- /src/readability/count_function.py: --------------------------------------------------------------------------------
1 | import json
2 | import ast
3 | 
4 | from argparse import ArgumentParser
5 | from tqdm import tqdm
6 | 
7 | def count_functions(code):
8 |     tree = ast.parse(code)
9 |     num_functions = sum(isinstance(node, ast.FunctionDef) for node in ast.walk(tree))
10 |     return num_functions
11 | 
12 | def main():
13 |     parser = ArgumentParser()
14 |     parser.add_argument('--file', type=str)
15 |     args = parser.parse_args()
16 | 
17 |     output = args.file[:-6] + '_func_count.jsonl'
18 | 
19 |     score_counter = None
20 | 
21 |     with open(output, 'w') as fout:
22 |         input_lines = open(args.file).readlines()
23 |         for line in tqdm(input_lines, total=len(input_lines)):
24 |             data = json.loads(line)
25 |             original_code = data['original_code']
26 |             updated_codes = [x['updated_code'] for x in data['updates']]
27 | 
28 |             if score_counter is None:
29 |                 score_counter = [[] for _ in range(len(updated_codes) + 1)]
30 | 
31 |             data['update_func_count'] = []
32 | 
33 |             for i, code in enumerate([original_code] + updated_codes):
34 |                 num_functions = -1
35 | 
36 |                 if code:
37 |                     try:
38 |                         num_functions = count_functions(code)
39 |                     except Exception:  # unparseable code; keep the -1 sentinel
40 |                         pass
41 | 
42 |                 data['update_func_count'].append(num_functions)
43 |                 score_counter[i].append(max(0, num_functions))
44 |             fout.write(json.dumps(data) + '\n')
45 | 
46 |     for i, scores in enumerate(score_counter):
47 |         print(f'Update {i}: avg function number {sum(scores) / len(scores)}')
48 | 
49 | if __name__ == '__main__':
50 |     main()
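# A quick, self-contained check of count_functions above: ast.walk traverses
# the whole tree, so nested definitions are counted too (illustrative only).
def _demo_count_functions():
    nested = "def outer():\n    def inner():\n        pass\n    return inner"
    assert count_functions(nested) == 2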
-------------------------------------------------------------------------------- /src/readability/count_meaningful_var.py: --------------------------------------------------------------------------------
1 | import json
2 | from tqdm import tqdm
3 | from argparse import ArgumentParser
4 | 
5 | from src.readability.utils import call_gpt
6 | from src.readability.prompts import COUNT_VAR_PROMPT
7 | 
8 | def count_meaningful_vars(code):
9 |     if 'Fixed Code:' in code:
10 |         code = code.split('Fixed Code:')[1]
11 | 
12 |     code = code.strip()
13 |     prompt = COUNT_VAR_PROMPT.format(code=code)
14 |     result = call_gpt(prompt, model='code-davinci-002', max_tokens=256, stop='\n\n\n')[0]
15 | 
16 |     result = result.strip().splitlines()
17 |     num_vars = len(result)
18 |     num_random_vars = sum([1 for line in result if line.endswith('- random')])
19 |     num_meaningful_vars = num_vars - num_random_vars
20 | 
21 |     return num_meaningful_vars, num_meaningful_vars / num_vars, result
22 | 
23 | 
24 | def main():
25 |     parser = ArgumentParser()
26 |     parser.add_argument('--file', type=str)
27 |     args = parser.parse_args()
28 |     output = args.file[:-6] + '_var_count.jsonl'
29 |     score_counter = None
30 | 
31 |     with open(output, 'w') as fout:
32 |         input_lines = open(args.file).readlines()
33 |         for line in tqdm(input_lines):
34 |             data = json.loads(line)
35 |             original_code = data['original_code']
36 |             updated_codes = [x['updated_code'] for x in data['updates']]
37 | 
38 |             if score_counter is None:
39 |                 score_counter = [[] for _ in range(len(updated_codes) + 1)]
40 | 
41 |             data['update_meaningful_var_count'] = []
42 |             data['update_meaningful_var_ratio'] = []
43 | 
44 |             for i, code in enumerate([original_code] + updated_codes):
45 |                 num_meaningful_vars, meaningful_var_ratio, gpt_gen = -1, -1, None  # sentinels; names must match the appends below
46 |                 if code:
47 |                     num_meaningful_vars, meaningful_var_ratio, gpt_gen = count_meaningful_vars(code)
48 |                 data['update_meaningful_var_count'].append(num_meaningful_vars)
49 |                 data['update_meaningful_var_ratio'].append(meaningful_var_ratio)
50 | 
51 |                 score_counter[i].append(max(0, meaningful_var_ratio))
52 |             fout.write(json.dumps(data) + '\n')
53 | 
54 |     for i, scores in enumerate(score_counter):
55 |         print(f'Update {i}: avg meaningful var ratio {sum(scores) / len(scores)}')
56 | 
57 | 
58 | if __name__ == '__main__':
59 |     main()
-------------------------------------------------------------------------------- /src/readability/prompts.py: --------------------------------------------------------------------------------
1 | COUNT_VAR_PROMPT = '''
2 | """CODE SNIPPET"""
3 | import sys
4 | # Reads input from terminal and returns it
5 | def input():
6 |     return sys.stdin.readline().strip()
7 | # finds a largest perfect power which is smaller
8 | # than the input number
9 | def resolve():
10 |     # read input
11 |     x=int(eval(input()))
12 |     ans=0
13 |     for i in range(1,33):
14 |         for j in range(2,11):
15 |             y=i**j
16 |             if y
[truncated in extraction: the rest of COUNT_VAR_PROMPT, all of /src/readability/readability.py, and the head of /src/readability/utils.py are missing; only the tail of utils.py's call_gpt survives below]
-------------------------------------------------------------------------------- /src/readability/utils.py: --------------------------------------------------------------------------------
26 |                 [...] >= num_completions:
27 |                     return completions[:num_completions]
28 |             except openai.error.RateLimitError as e:
29 |                 time.sleep(min(i**2, 60))
30 |     raise RuntimeError('Failed to call GPT API')
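# The head of call_gpt above was lost in extraction. A minimal sketch consistent
# with the surviving tail might look like the following; the signature is
# inferred from the call site in count_meaningful_var.py, and the retry count,
# parameters, and API call are assumptions (legacy openai<1.0 Completion API):
import time
import openai

def call_gpt_sketch(prompt, model, max_tokens, stop, num_completions=1, temperature=0.0, max_retries=10):
    completions = []
    for i in range(max_retries):
        try:
            response = openai.Completion.create(
                engine=model,
                prompt=prompt,
                max_tokens=max_tokens,
                stop=stop,
                n=num_completions,
                temperature=temperature,
            )
            completions.extend(choice.text for choice in response.choices)
            if len(completions) >= num_completions:
                return completions[:num_completions]
        except openai.error.RateLimitError:
            time.sleep(min(i**2, 60))  # quadratic backoff, capped at 60s
    raise RuntimeError('Failed to call GPT API')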
-------------------------------------------------------------------------------- /src/responsegen/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/madaan/self-refine/9a206d41e5d2d0c241bb441f41eeadb945afaa55/src/responsegen/__init__.py -------------------------------------------------------------------------------- /src/responsegen/feedback.py: --------------------------------------------------------------------------------
1 | import pandas as pd
2 | from prompt_lib.backends import openai_api
3 | 
4 | from src.utils import Prompt
5 | 
6 | 
7 | class ResponseGenFeedback(Prompt):
8 |     def __init__(self, engine: str, prompt_examples: str, max_tokens: int = 400) -> None:
9 |         super().__init__(
10 |             question_prefix="",
11 |             answer_prefix="",
12 |             intra_example_sep="\n\n",
13 |             inter_example_sep="\n\n###\n\n",
14 |         )
15 |         self.engine = engine
16 |         self.max_tokens = max_tokens
17 |         self.setup_prompt_from_examples_file(prompt_examples)
18 | 
19 |     def setup_prompt_from_examples_file(self, examples_path: str) -> None:
20 |         template = """Conversation history:
21 | 
22 | {history}
23 | 
24 | Response: {response}
25 | 
26 | Scores:
27 | 
28 | * Relevant: {Relevant}
29 | * Informative: {Informative}
30 | * Interesting: {Interesting}
31 | * Consistent: {Consistent}
32 | * Helpful: {Helpful}
33 | * Engaging : {Engaging}
34 | * Specific: {Specific}
35 | * Safe: {Safe}
36 | * User understanding: {Userunderstanding}
37 | * Fluent: {Fluent}
38 | * Total score: {total_score}"""
39 |         examples_df = pd.read_json(examples_path, orient="records")
40 |         prompt = []
41 |         for _, row in examples_df.iterrows():
42 |             prompt.append(
43 |                 template.format(
44 |                     history=row['history'].replace('System: ', '').replace('User: ', ''),
45 |                     response=row["response"],
46 |                     Relevant=row["Relevant"],
47 |                     Informative=row["Informative"],
48 |                     Interesting=row["Interesting"],
49 |                     Consistent=row["Consistent"],
50 |                     Helpful=row["Helpful"],
51 |                     Engaging=row["Engaging"],
52 |                     Specific=row["Specific"],
53 |                     Safe=row["Safe"],
54 |                     Userunderstanding=row["Userunderstanding"],
55 |                     Fluent=row["Fluent"],
56 |                     total_score=row["total_score"],
57 |                 )
58 |             )
59 | 
60 |         instruction = """We want to iteratively improve the provided responses. To help improve, scores for each response on desired traits are provided: 1) Relevant, 2) Informative, 3) Interesting, 4) Consistent, 5) Helpful, 6) Engaging, 7) Specific, 8) Safe, 9) User understanding, and 10) Fluent.
61 | 
62 | Here are some examples of this scoring rubric:
63 | 
64 | """
65 |         # prefix the few-shot rubric examples with the scoring instruction
66 |         self.prompt = instruction + self.inter_example_sep.join(prompt) + self.inter_example_sep
67 | 
68 |     def __call__(self, context: str, response: str):
69 |         prompt = self.get_prompt_with_question(context=context, response=response)
70 | 
71 |         output = openai_api.OpenaiAPIWrapper.call(
72 |             prompt=prompt,
73 |             engine=self.engine,
74 |             max_tokens=self.max_tokens,
75 |             stop_token="###",
76 |             temperature=0.7,
77 |         )
78 | 
79 |         generated_feedback = openai_api.OpenaiAPIWrapper.get_first_response(output)
80 |         generated_feedback = generated_feedback.split("Scores:")[1].strip()
81 |         generated_feedback = generated_feedback.split("#")[0].strip()
82 | 
83 |         return output, generated_feedback
84 | 
85 |     def get_prompt_with_question(self, context: str, response: str):
86 |         context = context.replace('System: ', '').replace('User: ', '')
87 |         question = self.make_query(context=context, response=response)
88 |         return f"""{self.prompt}{question}\n\n"""
89 | 
90 |     def make_query(self, context: str, response: str):
91 |         question = f"""Conversation history:
92 | 
93 | {context}
94 | 
95 | Response: {response}"""
96 |         return question
97 | 
-------------------------------------------------------------------------------- /src/responsegen/run.py: --------------------------------------------------------------------------------
1 | import re
2 | import sys
3 | import math
4 | import os
5 | import tqdm
6 | from typing import Any, Dict, List
7 | import pandas as pd
8 | import json
9 | from tqdm import tqdm
10 | from pandarallel import pandarallel
11 | import multiprocessing
12 | import traceback
13 | import argparse
14 | 
15 | pandarallel.initialize(progress_bar=True, nb_workers=25)
16 | 
17 | 
18 | from src.responsegen.task_init import ResponseGenTaskInit
19 | from src.responsegen.task_iterate import ResponseGenTaskIterate
20 | from src.responsegen.feedback import ResponseGenFeedback
21 | from src.utils import retry_parse_fail_prone_cmd
22 | 
23 | import openai
24 | import random
25 | import time
26 | 
27 | openai.api_key = os.getenv("OPENAI_API_KEY")
28 | 
29 | # check if organization is set
30 | 
31 | if os.getenv("OPENAI_ORG") is not None:
32 |     openai.organization = os.getenv("OPENAI_ORG")
33 | 
34 | CODEX = "code-davinci-002"
35 | GPT3 = "text-davinci-003"
36 | # CODEX was used in earlier runs; GPT3 is the active engine
37 | ENGINE = GPT3
38 | 
39 | @retry_parse_fail_prone_cmd
40 | def iterative_response(context: str, max_attempts: int) -> str:
41 | 
42 |     # initialize all the required components
43 | 
44 |     # generation of the first response
45 |     task_init = ResponseGenTaskInit(engine=ENGINE, prompt_examples="data/prompt/responsegen/init.jsonl")
46 | 
47 |     # getting feedback
48 |     task_feedback = ResponseGenFeedback(engine=ENGINE, 
prompt_examples="data/prompt/responsegen/feedback.jsonl")
49 | 
50 |     # iteratively improving the response
51 |     task_iterate = ResponseGenTaskIterate(engine=ENGINE, prompt_examples="data/prompt/responsegen/feedback.jsonl")
52 | 
53 | 
54 |     # Initialize the task
55 | 
56 |     n_attempts = 0
57 | 
58 |     responses_to_scores = dict()
59 | 
60 |     all_responses_to_scores = dict()
61 |     best_score_so_far = 0
62 |     reduce_window = 0
63 |     while n_attempts < max_attempts:
64 | 
65 |         if n_attempts == 0:
66 |             metaoutput, response = task_init(context=context)
67 |         else:
68 |             metaoutput, response = task_iterate(responses_to_scores=responses_to_scores, reduce_window=reduce_window)
69 |         # exit(0)
70 |         #context = new_context
71 | 
72 |         print(f"\n{n_attempts} CONTEXT> {context} \n\n RESPONSE> {response} - NTOKENS> {metaoutput['usage']['total_tokens']}")
73 | 
74 |         if metaoutput['usage']['total_tokens'] > 3000:  # shrink the few-shot prompt as the context window fills up
75 |             reduce_window += 1
76 |         if metaoutput['usage']['total_tokens'] > 3500:
77 |             reduce_window += 1
78 | 
79 |         feedbackmetaoutput, scores = task_feedback(context=context, response=response)
80 |         print(f"\n{n_attempts} SCORES> {scores} - NTOKENS> {feedbackmetaoutput['usage']['total_tokens']}")
81 | 
82 |         match = re.search(r"Total score: (\d+)/(\d+)", scores)
83 |         total_score = int(match.group(1))
84 | 
85 |         all_responses_to_scores[response] = {
86 |             "n_attempts": n_attempts,
87 |             "scores": scores,
88 |             "total_score": total_score,
89 |             "context": context,
90 |         }
91 |         # rtokens, ftokens = metaoutput['usage']['total_tokens'], feedbackmetaoutput['usage']['total_tokens']
92 |         if total_score >= best_score_so_far:  # only iterate on responses that are improving
93 |             best_score_so_far = total_score
94 | 
95 |             responses_to_scores[response] = (context, scores)
96 | 
97 | 
98 |         else:
99 |             print(f"Score of {response} is {total_score}, which is less than the current best of {best_score_so_far}")
100 | 
101 |         n_attempts += 1
102 |     return all_responses_to_scores
103 | 
104 | 
105 | 
106 | def run_dataset(max_attempts: int, outfile: str, max_size: int = 1):
107 | 
108 |     f = open('data/prompt/responsegen/fed_data.json')
109 |     data = json.load(f)
110 |     print('len of data', len(data))
111 |     count = 0
112 |     outwriter = open(outfile, 'a')
113 | 
114 |     for i, example in enumerate(data[:]):
115 |         if max_size != 0 and count > max_size: break
116 |         print(f"\n\n\n****Instance: {i}****\n\n")
117 |         if 'response' not in example: continue
118 |         try:
119 |             context = example["context"]
120 |             if type(example["context"]) is str:
121 |                 context = example["context"].split("\n")
122 |             if type(context) is list:
123 |                 context = "\n".join(context[-8:])  # keep only the last 8 turns
124 |             all_responses_to_scores = iterative_response(context, max_attempts=max_attempts)
125 |             if all_responses_to_scores is None:
126 |                 return {"result": ["FAILED"]}
127 | 
128 |             res = []
129 |             scored_responses = {}
130 |             for response, scores in all_responses_to_scores.items():
131 |                 res.append(f"{response} [score: {scores['total_score']}] \n {scores['scores']}")
132 |                 scored_responses[scores['n_attempts']] = {'response': response, 'total_score': scores['total_score']}
133 |             # append res to example
134 |             example['generated_responses'] = "\n------\n".join(res)
135 |             example['scored_responses'] = scored_responses
136 |             outwriter.write(json.dumps(example) + '\n')
137 |             print("\n ------ \n ".join(res))
138 |         except Exception as e:
139 |             print(f"error in {example}\n\n{e}", file=sys.stderr)
140 |             traceback.print_exc()
141 |             return {"result": ["FAILED"]}
142 |         count += 1
143 |     outwriter.close()
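# Quick illustration of the "Total score" parsing used in iterative_response
# above (self-contained; the scores string is a made-up example):
def _demo_total_score_parsing():
    scores = "* Relevant: 3/3\n* Fluent: 3/3\n* Total score: 27/30"
    match = re.search(r"Total score: (\d+)/(\d+)", scores)
    assert int(match.group(1)) == 27 and int(match.group(2)) == 30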
144 | 
145 | 
146 | if __name__ == "__main__":
147 |     parser = argparse.ArgumentParser()
148 |     parser.add_argument(
149 |         "--max_attempts",
150 |         type=int,
151 |         default=3,
152 |         help="Max attempts",
153 |     )
154 |     parser.add_argument(
155 |         "--size",
156 |         type=int,
157 |         default=1,
158 |         help="Test data size (0 means all data)",
159 |     )
160 |     parser.add_argument(
161 |         "--output",
162 |         type=str,
163 |         default='./output-v3fedresponsegen406on.json',
164 |         # required=True,
165 |         help="Output file",
166 |     )
167 | 
168 |     args = parser.parse_args()
169 | 
170 |     run_dataset(args.max_attempts, outfile=args.output, max_size=args.size)
-------------------------------------------------------------------------------- /src/responsegen/task_init.py: --------------------------------------------------------------------------------
1 | import pandas as pd
2 | from src.utils import Prompt
3 | from typing import List, Optional, Union
4 | import sys
5 | from prompt_lib.backends import openai_api
6 | 
7 | 
8 | class ResponseGenTaskInit(Prompt):
9 |     def __init__(self, prompt_examples: str, engine: str, numexamples=3) -> None:
10 |         super().__init__(
11 |             question_prefix="Conversation history: ",
12 |             answer_prefix="Response: ",
13 |             intra_example_sep="\n\n",
14 |             inter_example_sep="\n\n###\n\n",
15 |         )
16 |         self.engine = engine
17 |         self.setup_prompt_from_examples_file(prompt_examples, numexamples=numexamples)
18 | 
19 |     def setup_prompt_from_examples_file(self, examples_path: str, numexamples=10) -> None:
20 |         instruction = (
21 |             "Provided a dialogue between two speakers, generate a response that is coherent with the dialogue history. Desired traits for responses are: 1) Relevant - The response addresses the context, 2) Informative - The response provides some information, 3) Interesting - The response is interesting, 4) Consistent - The response is consistent with the rest of the conversation in terms of tone and topic, 5) Helpful - The response is helpful in providing any information or suggesting any actions, 6) Engaging - The response is engaging and encourages further conversation, 7) Specific - The response contains specific content, 8) Safe - The response does not contain unsafe or offensive content, 9) User understanding - The response demonstrates an understanding of the user's input and state of mind, and 10) Fluent. 
Response should begin with - Response:\n\n" 22 | ) 23 | 24 | examples_df = pd.read_json(examples_path, orient="records", lines=True) 25 | prompt = [] 26 | for i, row in examples_df.iterrows(): 27 | if i >= numexamples: 28 | break 29 | prompt.append(self._build_query_from_example(row["history"], row["response"])) 30 | 31 | self.prompt = instruction + self.inter_example_sep.join(prompt) + self.inter_example_sep 32 | 33 | def _build_query_from_example(self, history: Union[str, List[str]], response: Optional[str]=None) -> str: 34 | history = history.replace('System: ', '').replace('User: ', '') 35 | 36 | TEMPLATE = """Conversation history: 37 | 38 | {history} 39 | 40 | Response: {response}""" 41 | 42 | query = TEMPLATE.format(history=history, response=response) 43 | return query 44 | 45 | def make_query(self, context: str) -> str: 46 | context = context.replace('System: ', '').replace('User: ', '') 47 | query = f"{self.prompt}{self.question_prefix}\n\n{context}{self.intra_example_sep}" 48 | return query 49 | 50 | def __call__(self, context: str) -> str: 51 | generation_query = self.make_query(context) 52 | output = openai_api.OpenaiAPIWrapper.call( 53 | prompt=generation_query, 54 | engine=self.engine, 55 | max_tokens=800, 56 | stop_token="###", 57 | temperature=0.7, 58 | ) 59 | 60 | generated_response = openai_api.OpenaiAPIWrapper.get_first_response(output) 61 | 62 | generated_response = generated_response.split(self.answer_prefix)[1].replace("#", "").strip() 63 | 64 | 65 | return output, generated_response.strip() 66 | -------------------------------------------------------------------------------- /src/responsegen/task_iterate.py: -------------------------------------------------------------------------------- 1 | import sys 2 | from typing import Dict, List 3 | from src.utils import Prompt 4 | 5 | from prompt_lib.backends import openai_api 6 | 7 | 8 | class ResponseGenTaskIterate(Prompt): 9 | def __init__(self, engine: str, prompt_examples: str) -> None: 10 | super().__init__( 11 | question_prefix="", 12 | answer_prefix="", 13 | intra_example_sep="\n\n", 14 | inter_example_sep="\n\n###\n\n", 15 | ) 16 | self.engine = engine 17 | self.count = 0 18 | self.prompt = self.make_prompt(prompt_examples=prompt_examples) 19 | 20 | def make_prompt(self, prompt_examples: str, reduce_window=0) -> str: 21 | import pandas as pd 22 | prompt_examples = pd.read_json(prompt_examples, orient="records") 23 | prompt_examples = prompt_examples[reduce_window:] 24 | # group on example 25 | grouped = prompt_examples.groupby("example") 26 | 27 | prompt = [] 28 | # sort each group by score 29 | for _, group in grouped: 30 | group["numerical_score"] = group["total_score"].apply(lambda x: int(x.split("/")[0].strip())) 31 | group = group.sort_values("numerical_score") 32 | prompt.append(self.make_one_iterate_example(group.to_dict("records"))) 33 | 34 | return self.inter_example_sep.join(prompt) + self.inter_example_sep 35 | 36 | 37 | def make_one_iterate_example(self, incrementally_improving_examples: List[Dict]): 38 | """Given a list of examples that are incrementally improving, return a new example. 39 | """ 40 | 41 | instr = """We want to iteratively improve the provided responses. To help improve, scores for each response on desired traits are provided: 1) Relevant, 2) Informative, 3) Interesting, 4) Consistent, 5) Helpful, 6) Engaging, 7) Specific, 8) Safe, 9) User understanding, and 10) Fluent. 
42 | 43 | """ 44 | template = """Conversation history: 45 | 46 | {history} 47 | 48 | Response: {response} 49 | 50 | Scores: 51 | 52 | * Relevant: {Relevant} 53 | * Informative: {Informative} 54 | * Interesting: {Interesting} 55 | * Consistent: {Consistent} 56 | * Helpful: {Helpful} 57 | * Engaging : {Engaging} 58 | * Specific: {Specific} 59 | * Safe: {Safe} 60 | * User understanding: {Userunderstanding} 61 | * Fluent: {Fluent} 62 | 63 | * Total score: {total_score} 64 | 65 | Okay, let's use this feedback to improve the response. 66 | 67 | """ 68 | prompt = [] 69 | for row in incrementally_improving_examples: 70 | prompt.append( 71 | template.format( 72 | history=row['history'].replace('System: ', '').replace('User: ', ''), 73 | response=row["response"], 74 | Relevant=row["Relevant"], 75 | Informative=row["Informative"], 76 | Interesting=row["Interesting"], 77 | Consistent=row["Consistent"], 78 | Helpful=row["Helpful"], 79 | Engaging=row["Engaging"], 80 | Specific=row["Specific"], 81 | Safe=row["Safe"], 82 | Userunderstanding=row["Userunderstanding"], 83 | Fluent=row["Fluent"], 84 | total_score=row["total_score"], 85 | ) 86 | ) 87 | 88 | 89 | prompt = "".join(prompt) 90 | prompt = instr + prompt 91 | return prompt.strip() 92 | 93 | def make_query(self, question: str, reduce_window=0) -> str: 94 | if reduce_window>0: 95 | self.prompt = self.make_prompt(prompt_examples="data/prompt/responsegen/feedback.jsonl", reduce_window=reduce_window) 96 | question = question.replace('System: ', '').replace('User: ', '') 97 | return f"{self.prompt}{self.question_prefix}{question}{self.intra_example_sep}{self.answer_prefix}" 98 | # return super().make_query(prompt, question) 99 | 100 | def _make_input( 101 | self, 102 | context: str, 103 | response: str, 104 | scores: str, 105 | ) -> str: 106 | context = context.replace('System: ', '').replace('User: ', '') 107 | input_txt = f"""Conversation history: 108 | 109 | {context} 110 | 111 | Response: {response} 112 | 113 | Scores: 114 | 115 | {scores} 116 | 117 | Okay, let's use this feedback to improve the response. 
118 | 
119 | Conversation history:
120 | 
121 | {context}
122 | """
123 | 
124 |         return input_txt
125 | 
126 |     def __call__(
127 |         self,
128 |         responses_to_scores: Dict[str, str],
129 |         reduce_window=0
130 |     ) -> tuple:
131 |         example_input = self.make_input(
132 |             responses_to_scores=responses_to_scores
133 |         )
134 |         transfer_query = self.make_query(example_input, reduce_window=reduce_window)
135 |         self.count += 1
136 |         with open(f"responses_iterate_{self.count}.txt", "w") as f:  # dump the assembled query for inspection
137 |             f.write(transfer_query + "\n")
138 |         output = openai_api.OpenaiAPIWrapper.call(
139 |             prompt=transfer_query,
140 |             engine=self.engine,
141 |             max_tokens=200,
142 |             stop_token=self.inter_example_sep,
143 |             temperature=0.7,
144 |         )
145 |         modelresponse = openai_api.OpenaiAPIWrapper.get_first_response(output)
146 |         response = modelresponse.split("Response:")[1].strip().split("\n")[0].strip()
147 | 
148 | 
149 |         return output, response.strip()
150 | 
151 |     def make_input(
152 |         self,
153 |         responses_to_scores: Dict[str, str],
154 |     ) -> str:
155 |         input_txt = ""
156 |         for response, (context, scores) in responses_to_scores.items():
157 |             context = context.replace('System: ', '').replace('User: ', '')
158 |             input_txt += self._make_input(
159 |                 context=context,
160 |                 response=response,
161 |                 scores=scores,
162 |             )
163 |         return input_txt
164 | 
165 | 
166 | 
167 | 
168 | 
169 | 
170 | if __name__ == "__main__":
171 |     obj = ResponseGenTaskIterate(prompt_examples="data/prompt/acronym/feedback.v2.jsonl", engine="whatever")
172 |     print(obj.prompt)
-------------------------------------------------------------------------------- /src/sentiment_reversal/measure.py: --------------------------------------------------------------------------------
1 | from src.utils import Prompt
2 | from prompt_lib.backends import router
3 | 
4 | class SentimentTransferMeasurement(Prompt):
5 |     def __init__(self, engine) -> None:
6 | 
7 |         super().__init__(
8 |             question_prefix="Review: ",
9 |             answer_prefix="Output: ",
10 |             intra_example_sep="\n",
11 |             inter_example_sep="\n\n\n",
12 |         )
13 |         self.engine = engine
14 |         self.final_answer_prefix = "The sentiment is "
15 |         self.load_prompts()
16 | 
17 |     def load_prompts(self):
18 |         self.prompt = MeasurementPrompt.get_prompt()
19 | 
20 |     def make_query(self, question: str) -> str:
21 |         return super().make_query(self.prompt, question)
22 | 
23 |     def get_sentiment_from_output(self, output: str):
24 |         return output.split(self.final_answer_prefix)[-1].strip()
25 | 
26 |     def make_input(self, input_sent: str):
27 |         return f"""Review: {input_sent}"""
28 | 
29 |     def make_output(self, sentiment_level: str):
30 | 
31 |         sentiment_level_to_prefix = {  # no trailing periods here; make_output adds one below
32 |             "Very negative": "The review sounds very toxic",
33 |             "Negative": "The review sounds somewhat negative",
34 |             "Neutral": "The review sounds neutral",
35 |             "Positive": "The review sounds somewhat favorable",
36 |             "Very positive": "The review sounds glowingly positive",
37 |         }
38 | 
39 |         prefix = sentiment_level_to_prefix[sentiment_level]
40 | 
41 |         return f"""Output: {prefix}. 
The sentiment is {sentiment_level}""" 42 | 43 | def __call__(self, review: str): 44 | measurement_query = self.make_input(review) 45 | measurement_query = self.make_query(measurement_query) 46 | measured_sentiment = router.call( 47 | prompt=measurement_query, 48 | engine=self.engine, 49 | max_tokens=50, 50 | stop_token=self.inter_example_sep, 51 | temperature=0.7, 52 | ) 53 | 54 | return measured_sentiment 55 | 56 | 57 | 58 | class MeasurementPrompt: 59 | 60 | @staticmethod 61 | def get_prompt(): 62 | return """Review: If you ever wondered where the magic of Vegas crawled into a hole to rot, look no further. Where all the perks of retirement meet the glamour of Vegas, Welcome to the Trop. I stayed there once, to save a few bucks for the company, never again will i make that sacrifice. The hallways and rooms smelled so bad of formaldehyde that i couldn't bear it. Their food court and restaurants are terrible. The free Folies Bergere show is lame. The rooms look like a retirement community from florida. steer clear if you can! 63 | Output: The review sounds very toxic. The sentiment is Very negative 64 | 65 | 66 | Review: If you ever stayed at the Trop, you may have noticed that it's not quite up to the standards of other Vegas hotels. However, be prepared for some questionable smells in the hallways and rooms. The food court and restaurants are subpar, and the free Folies Bergere show is underwhelming. The rooms have a retirement community feel to them. Overall, it's not the best option, but it may work in a pinch. 67 | Output: The review sounds somewhat negative. The sentiment is Negative 68 | 69 | 70 | Review: If you're looking for a budget-friendly option in Vegas, the Trop may be worth considering. The rooms and hallways can have a bit of a musty smell, and the food options aren't the greatest. The Folies Bergere show is free, but it's not the most exciting. Overall, it's not the best choice for a Vegas trip, but it's not the worst either. Just keep your expectations in check. 71 | Output: The review sounds neutral. The sentiment is Neutral 72 | 73 | 74 | Review: If you're looking for a unique and affordable experience in Vegas, the Trop may be the perfect place for you. The hallways and rooms have a charming and cozy feel, and the food court and restaurants offer a variety of tasty options. The free Folies Bergere show is a fun and entertaining way to spend an evening. Overall, it's a great value and an enjoyable stay. 75 | Output: The review sounds somewhat favorable. The sentiment is Positive 76 | 77 | 78 | Review: If you're looking for a truly magical experience in Vegas, look no further than the Trop! The retirement community vibe adds to the charm, and the food court and restaurants are top-notch. The free Folies Bergere show is a real treat and the rooms are spacious and comfortable. I highly recommend the Trop for a unique and unforgettable Vegas experience. 79 | Output: The review sounds glowingly positive. 
The sentiment is Very positive 80 | 81 | 82 | """ -------------------------------------------------------------------------------- /src/utils.py: -------------------------------------------------------------------------------- 1 | import traceback 2 | 3 | class Prompt: 4 | def __init__( 5 | self, 6 | question_prefix: str, 7 | answer_prefix: str, 8 | intra_example_sep: str, 9 | inter_example_sep: str, 10 | engine: str = None, 11 | temperature: float = None, 12 | ) -> None: 13 | self.question_prefix = question_prefix 14 | self.answer_prefix = answer_prefix 15 | self.intra_example_sep = intra_example_sep 16 | self.inter_example_sep = inter_example_sep 17 | self.engine = engine 18 | self.temperature = temperature 19 | 20 | def make_query(self, prompt: str, question: str) -> str: 21 | return ( 22 | f"{prompt}{self.question_prefix}{question}{self.intra_example_sep}{self.answer_prefix}" 23 | ) 24 | 25 | 26 | def retry_parse_fail_prone_cmd( 27 | func, 28 | max_retries: int = 3, 29 | exceptions=( 30 | ValueError, 31 | KeyError, 32 | IndexError, 33 | ), 34 | ): 35 | def wrapper(*args, **kwargs): 36 | retries = max_retries 37 | while retries: 38 | try: 39 | return func(*args, **kwargs) 40 | except exceptions as e: 41 | stack_trace = traceback.format_exc() 42 | 43 | retries -= 1 44 | print(f"An error occurred: {e}. {stack_trace}. Left retries: {retries}.") 45 | return None 46 | 47 | return wrapper 48 | --------------------------------------------------------------------------------
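# A minimal usage sketch of src/utils.py (assumed composition, mirroring the
# task modules above): Prompt lays out a few-shot query, and
# retry_parse_fail_prone_cmd guards parse-heavy call sites. Names below are illustrative.
from src.utils import Prompt, retry_parse_fail_prone_cmd

@retry_parse_fail_prone_cmd
def parse_total_score(scores: str) -> int:
    # an IndexError or ValueError here is retried up to 3 times, then None is returned
    return int(scores.split("Total score:")[1].split("/")[0].strip())

p = Prompt(
    question_prefix="Review: ",
    answer_prefix="Output: ",
    intra_example_sep="\n",
    inter_example_sep="\n\n\n",
)
query = p.make_query(prompt="<few-shot examples>\n\n\n", question="Great food, slow service.")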