├── .github
│   └── workflows
│       └── update_toc.yml
├── .gitignore
├── CITATION.bib
├── LICENSE
├── README.md
├── colabs
│   └── Visual-Self-Refine-GPT4V.ipynb
├── data
│   ├── outputs
│   │   └── yelp
│   │       ├── chatgpt.jsonl
│   │       ├── dv3.jsonl
│   │       └── gpt4.jsonl
│   ├── prompt
│   │   ├── acronym
│   │   │   ├── feedback.jsonl
│   │   │   └── init.jsonl
│   │   ├── commongen
│   │   │   ├── commongen_hard.jsonl
│   │   │   ├── feedback.jsonl
│   │   │   ├── init.jsonl
│   │   │   └── iterate.jsonl
│   │   ├── gsm
│   │   │   ├── feedback.txt
│   │   │   └── init.txt
│   │   ├── pie
│   │   │   ├── feedback.txt
│   │   │   ├── init.txt
│   │   │   ├── iterate.txt
│   │   │   ├── iterate_genericfb.txt
│   │   │   └── iterate_nofb.txt
│   │   └── responsegen
│   │       ├── fed_data.json
│   │       ├── feedback.jsonl
│   │       └── init.jsonl
│   └── tasks
│       ├── acronyms
│       │   └── acronyms.tsv
│       ├── codeclean
│       │   └── code_readability
│       │       └── codenet-python-train.jsonl.zip
│       ├── gsm
│       │   ├── gsm.jsonl
│       │   ├── gsm_outputs.jsonl
│       │   └── gsm_outputs.jsonl.reports.txt
│       ├── pie
│       │   ├── codenet-python-test-1k.jsonl
│       │   ├── gpt4_outputs.zip
│       │   ├── gpt4_outputs_flattened.jsonl
│       │   ├── gpt4_outputs_self_refine.jsonl
│       │   └── ref.jsonl
│       └── yelp
│           └── yelp-extreme.jsonl
├── docs
│   ├── .gitignore
│   ├── CNAME
│   ├── Gemfile
│   ├── Gemfile.lock
│   ├── README.md
│   ├── commongen_feedback.txt
│   ├── commongen_init.txt
│   ├── commongen_iterate.txt
│   ├── index.html
│   ├── pie_eval.md
│   ├── static
│   │   ├── animation_static_end.gif
│   │   ├── css
│   │   │   ├── bulma-carousel.min.css
│   │   │   ├── bulma-slider.min.css
│   │   │   ├── bulma.min.css
│   │   │   ├── collapsible.css
│   │   │   ├── fontawesome.all.min.css
│   │   │   ├── index.css
│   │   │   └── prism.css
│   │   ├── images
│   │   │   ├── animation_oldstyle_oneloop.gif
│   │   │   ├── autofb_animation.gif
│   │   │   ├── autofb_animation_static.gif
│   │   │   ├── autofb_fig1.pdf
│   │   │   ├── autofb_fig1.png
│   │   │   ├── autofb_static.png
│   │   │   ├── example1.png
│   │   │   ├── example1desc.png
│   │   │   ├── favicon.svg
│   │   │   ├── fig2.png
│   │   │   ├── mainresults.png
│   │   │   ├── pal_overview.png
│   │   │   ├── refinement.png
│   │   │   └── tasks.png
│   │   ├── js
│   │   │   ├── bulma-carousel.js
│   │   │   ├── bulma-carousel.min.js
│   │   │   ├── bulma-slider.js
│   │   │   ├── bulma-slider.min.js
│   │   │   ├── collapsible.js
│   │   │   ├── fontawesome.all.min.js
│   │   │   ├── index.js
│   │   │   ├── prism.js
│   │   │   └── results.js
│   │   └── outputs.html
│   ├── view.pdf
│   └── visual_self_refine_examples
│       ├── alien.gif
│       ├── artificial_general_intell.gif
│       ├── bird.gif
│       ├── bird_v2.gif
│       ├── c___program.gif
│       ├── cat.gif
│       ├── cell.gif
│       ├── cell_v2.gif
│       ├── dragon.gif
│       ├── dragon_v2.gif
│       ├── drums.gif
│       ├── duck.gif
│       ├── ellipsoid.gif
│       ├── feedback_driven_self_improvement.gif
│       ├── ferrari.gif
│       ├── flower.gif
│       ├── game_of_life.gif
│       ├── giphy.gif
│       ├── goat.gif
│       ├── gradient_descent.gif
│       ├── guitar.gif
│       ├── hidden_markov_model.gif
│       ├── hmm.gif
│       ├── hyperboloid.gif
│       ├── hyperplane.gif
│       ├── illustration_of_a_cell.gif
│       ├── illustration_of_a_neural_.gif
│       ├── illustration_of_hmms.gif
│       ├── linux_penguin.gif
│       ├── neuron.gif
│       ├── nice_unicorn.gif
│       ├── nine_point_circle.gif
│       ├── panda_v2.gif
│       ├── penguin.gif
│       ├── penguin_v2.gif
│       ├── pentagon.gif
│       ├── puppy.gif
│       ├── pyramid.gif
│       ├── pythagorean_theorem.gif
│       ├── roast_turkey.gif
│       ├── robot.gif
│       ├── rocket.gif
│       ├── self_attention.gif
│       ├── self_refine_is_a_novel_ap.gif
│       ├── self_refinement_loop.gif
│       ├── space_shuttle.gif
│       ├── spaceship.gif
│       ├── stokes__theorem.gif
│       ├── support_vector_machine.gif
│       ├── tcp_ip.gif
│       ├── thanksgiving_turkey.gif
│       ├── toroid.gif
│       ├── train.gif
│       ├── turkey.gif
│       ├── turkey_v1.gif
│       ├── turtle.gif
│       ├── unicorn.gif
│       ├── unicorn_v2.gif
│       ├── unicorn_v3.gif
│       ├── unicorn_v4.gif
│       ├── vae.gif
│       └── variational_inference.gif
└── src
    ├── acronym
    │   ├── __init__.py
    │   ├── feedback.py
    │   ├── run.py
    │   ├── run_mcts.py
    │   ├── task_init.py
    │   └── task_iterate.py
    ├── commongen
    │   ├── __init__.py
    │   ├── data.py
    │   ├── eval.py
    │   ├── feedback.py
    │   ├── make_challenging.py
    │   ├── run.py
    │   ├── task_init.py
    │   └── task_iterate.py
    ├── gsm
    │   ├── __init__.py
    │   ├── feedback.py
    │   ├── feedback_no_update.py
    │   ├── gsm_selfref_eval.py
    │   ├── run.py
    │   └── task_init.py
    ├── pie
    │   ├── feedback.py
    │   ├── pie_eval.py
    │   ├── prep_for_pie_eval.py
    │   ├── run.py
    │   ├── task_init.py
    │   └── task_iterate.py
    ├── readability
    │   ├── __init__.py
    │   ├── count_comment.py
    │   ├── count_function.py
    │   ├── count_meaningful_var.py
    │   ├── prompts.py
    │   ├── readability.py
    │   └── utils.py
    ├── responsegen
    │   ├── __init__.py
    │   ├── feedback.py
    │   ├── run.py
    │   ├── task_init.py
    │   └── task_iterate.py
    ├── sentiment_reversal
    │   ├── feedback.py
    │   ├── gpt4_eval.py
    │   ├── measure.py
    │   ├── run.py
    │   ├── task_init.py
    │   └── task_iterate.py
    └── utils.py
/.github/workflows/update_toc.yml:
--------------------------------------------------------------------------------
1 | name: Update Table of Contents
2 |
3 | on:
4 |   push:
5 |     branches:
6 |       - main
7 |   workflow_dispatch:
8 | 
9 | 
10 | jobs:
11 |   update-toc:
12 |     runs-on: ubuntu-latest
13 |     steps:
14 |       - name: Check out code
15 |         uses: actions/checkout@v2
16 |         with:
17 |           token: ${{ secrets.TOC_UPDATE_TOKEN }} # Use the PAT for authentication
18 | 
19 |       - name: Update ToC
20 |         uses: technote-space/toc-generator@v4
21 |         with:
22 |           INSERT_TO: README.md
23 |           BASE_BRANCH: main
24 |           COMMIT_MSG: 'chore: update TOC'
25 |           GITHUB_TOKEN: ${{ secrets.TOC_UPDATE_TOKEN }} # Use the PAT for authentication
26 |         env:
27 |           TOC_TITLE: '## Table of Contents'
28 |
29 |
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | data
2 | *.txt
3 | *.out
4 |
5 | # Byte-compiled / optimized / DLL files
6 | __pycache__/
7 | *.py[cod]
8 | *$py.class
9 |
10 |
11 | # C extensions
12 | *.so
13 |
14 | # Distribution / packaging
15 | .Python
16 | build/
17 | develop-eggs/
18 | dist/
19 | downloads/
20 | eggs/
21 | .eggs/
22 | lib/
23 | lib64/
24 | parts/
25 | sdist/
26 | var/
27 | wheels/
28 | pip-wheel-metadata/
29 | share/python-wheels/
30 | *.egg-info/
31 | .installed.cfg
32 | *.egg
33 | MANIFEST
34 |
35 | # PyInstaller
36 | # Usually these files are written by a python script from a template
37 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
38 | *.manifest
39 | *.spec
40 |
41 | # Installer logs
42 | pip-log.txt
43 | pip-delete-this-directory.txt
44 |
45 | # Unit test / coverage reports
46 | htmlcov/
47 | .tox/
48 | .nox/
49 | .coverage
50 | .coverage.*
51 | .cache
52 | nosetests.xml
53 | coverage.xml
54 | *.cover
55 | *.py,cover
56 | .hypothesis/
57 | .pytest_cache/
58 |
59 | # Translations
60 | *.mo
61 | *.pot
62 |
63 | # Django stuff:
64 | *.log
65 | local_settings.py
66 | db.sqlite3
67 | db.sqlite3-journal
68 |
69 | # Flask stuff:
70 | instance/
71 | .webassets-cache
72 |
73 | # Scrapy stuff:
74 | .scrapy
75 |
76 | # Sphinx documentation
77 | docs/_build/
78 |
79 | # PyBuilder
80 | target/
81 |
82 | # Jupyter Notebook
83 | .ipynb_checkpoints
84 |
85 | # IPython
86 | profile_default/
87 | ipython_config.py
88 |
89 | # pyenv
90 | .python-version
91 |
92 | # pipenv
93 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
94 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
95 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
96 | # install all needed dependencies.
97 | #Pipfile.lock
98 |
99 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow
100 | __pypackages__/
101 |
102 | # Celery stuff
103 | celerybeat-schedule
104 | celerybeat.pid
105 |
106 | # SageMath parsed files
107 | *.sage.py
108 |
109 | # Environments
110 | .env
111 | .venv
112 | env/
113 | venv/
114 | ENV/
115 | env.bak/
116 | venv.bak/
117 |
118 | # Spyder project settings
119 | .spyderproject
120 | .spyproject
121 |
122 | # Rope project settings
123 | .ropeproject
124 |
125 | # mkdocs documentation
126 | /site
127 |
128 | # mypy
129 | .mypy_cache/
130 | .dmypy.json
131 | dmypy.json
132 |
133 | # Pyre type checker
134 | .pyre/
135 | out-responsegen.json
136 |
--------------------------------------------------------------------------------
/CITATION.bib:
--------------------------------------------------------------------------------
1 | @misc{madaan2023selfrefine,
2 | title={Self-Refine: Iterative Refinement with Self-Feedback},
3 | author={Aman Madaan and Niket Tandon and Prakhar Gupta and Skyler Hallinan and Luyu Gao and Sarah Wiegreffe and Uri Alon and Nouha Dziri and Shrimai Prabhumoye and Yiming Yang and Sean Welleck and Bodhisattwa Prasad Majumder and Shashank Gupta and Amir Yazdanbakhsh and Peter Clark},
4 | year={2023},
5 | eprint={2303.17651},
6 | archivePrefix={arXiv},
7 | primaryClass={cs.CL}
8 | }
9 |
--------------------------------------------------------------------------------
/data/prompt/acronym/init.jsonl:
--------------------------------------------------------------------------------
1 | {"title":"A Survey of Active Network Research","acronym":"SONAR"}
2 | {"title":"A Scalable, Commutative Replica Dictatorship for Practical Optimistic Replication","acronym":"SCRATCHPAD"}
3 | {"title":"Bidirectional Encoder Representations from Transformers","acronym":"BERT"}
4 | {"title":"Sequence to Sequence Learning with Neural Networks","acronym":"Seq2Seq"}
5 | {"title":"Densely Connected Convolutional Networks for Image Classification","acronym":"DenseNet"}
6 | {"title":"A Dynamic Programming Algorithm for RNA Secondary Structure Prediction","acronym":"DYNALIGN"}
7 | {"title":"Fast Parallel Algorithms for Short-Range Molecular Dynamics","acronym":"FASTMD"}
8 | {"title":"Real-Time Collaborative Editing Systems","acronym":"COCOON"}
9 | {"title":"Efficient Data Structures for Large Scale Graph Processing","acronym":"EDGE"}
10 | {"title":"A program to help students at UT Southwestern learn about aging","acronym":"SAGE"}
11 | {"title":"Underwater breathing without external accessories","acronym":"SCUBA"}
12 | {"title":"An educational training module for professionals","acronym":"LEAP"}
13 | {"title":"Teaching a leadership program","acronym":"LEAD"}
14 |
--------------------------------------------------------------------------------
/data/prompt/commongen/feedback.jsonl:
--------------------------------------------------------------------------------
1 | {"concepts": ["beat", "drum", "pen", "sit", "use"], "sentence": "A man uses a drum to beat a pen.", "concept_feedback": ["sit"], "commonsense_feedback": "NONE"}
2 | {"concepts": ["chair", "clipper", "cut", "hair", "sit"], "sentence": "A girl sitting on the couch with her hair up.", "concept_feedback": ["clipper", "cut", "chair"], "commonsense_feedback": "NONE"}
3 | {"concepts": ["grass", "hose", "spray", "stand", "water"], "sentence": "A man standing next to a water dripping out of the grass.", "concept_feedback": ["hose", "spray"], "commonsense_feedback": "NONE"}
4 | {"concepts": ["front", "gong", "hit", "mallet", "stand"], "sentence": "The musician hit the gong with a mallet while standing in front of the audience.", "concept_feedback": ["NONE"], "commonsense_feedback": "NONE"}
5 | {"concepts": ["ball", "dunk", "hoop", "jump", "run"], "sentence": "A young boy runs up to the hoop and jumps off of the ball.", "concept_feedback": ["dunk"], "commonsense_feedback": "NONE"}
6 | {"concepts": ["card", "chip", "deal", "dealer", "table"], "sentence": "a dealer offers a card to a group of people at a table", "concept_feedback": ["chip", "deal"], "commonsense_feedback": "NONE"}
7 | {"concepts": ["clean", "climb", "gutter", "house", "ladder"], "sentence": "A man is climbing a ladder to clean a gutter in a house.", "concept_feedback": ["NONE"], "commonsense_feedback": "NONE"}
8 | {"concepts": ["animal", "catch", "horse", "lasso", "ride"], "sentence": "A horse is being caught by a cowboy with a lasso.", "concept_feedback": ["animal", "ride"], "commonsense_feedback": "NONE"}
9 | {"concepts": ["beat", "drum", "pen", "sit", "use"], "sentence": "The drum sits on the pen and uses the beat.", "concept_feedback": ["NONE"], "commonsense_feedback": "The sentence does not make sense because a drum cannot sit on a pen and use a beat."}
10 | {"concepts": ["chair", "clipper", "cut", "hair", "sit"], "sentence": "The clipper sits on the chair.", "concept_feedback": ["cut", "hair"], "commonsense_feedback": "The sentence does not make sense because a clipper cannot sit on a chair."}
11 | {"concepts": ["grass", "hose", "spray", "stand", "water"], "sentence": "The water stands in the grass.", "concept_feedback": ["hose", "spray"], "commonsense_feedback": "The sentence does not make sense because water cannot stand in grass."}
12 | {"concepts": ["front", "gong", "hit", "mallet", "stand"], "sentence": "On a sunny day, a mallet stands in the front of a gong and it hit the gong with a loud sound.", "concept_feedback": ["NONE"], "commonsense_feedback": "The sentence does not make sense because a mallet cannot stand in the front of a gong and cannot hit it."}
13 | {"concepts": ["ball", "dunk", "hoop", "jump", "run"], "sentence": "The ball runs to the hoop and dunks.", "concept_feedback": ["jump"], "commonsense_feedback": "The sentence does not make sense because a ball cannot run and dunk."}
14 | {"concepts": ["card", "chip", "deal", "dealer", "table"], "sentence": "The chip deals a card to the dealer at the table.", "concept_feedback": ["NONE"], "commonsense_feedback": "The sentence does not make sense because a chip cannot deal a card to a dealer."}
15 | {"concepts": ["clean", "climb", "gutter", "house", "ladder"], "sentence": "The ladder climbs to the house and cleans the gutter.", "concept_feedback": ["NONE"], "commonsense_feedback": "The sentence does not make sense because a ladder cannot climb to a house and cannot clean a gutter."}
16 | {"concepts": ["animal", "catch", "horse", "lasso", "ride"], "sentence": "The horse catches the lasso and rides on it.", "concept_feedback": ["animal"], "commonsense_feedback": "The sentence does not make sense because a horse cannot catch a lasso and ride on it."}
17 |
--------------------------------------------------------------------------------
/data/prompt/commongen/init.jsonl:
--------------------------------------------------------------------------------
1 | {"gem_id":"common_gen-train-66326","gem_parent_id":"common_gen-train-66326","concept_set_id":31588,"concepts":["footage","motion","ruin","tilt","window"],"target":"time lapse footage with tilt up motion of the sun streaking through window of ruin","references":[],"question":"footage motion ruin tilt window","answer":"time lapse footage with tilt up motion of the sun streaking through window of ruin","num_concepts":5}
2 | {"gem_id":"common_gen-train-67070","gem_parent_id":"common_gen-train-67070","concept_set_id":32332,"concepts":["cause","hate","hut","local","love"],"target":"new beach huts on the island have caused some controversy some locals love them others hate them","references":[],"question":"cause hate hut local love","answer":"new beach huts on the island have caused some controversy some locals love them others hate them","num_concepts":5}
3 | {"gem_id":"common_gen-train-66241","gem_parent_id":"common_gen-train-66241","concept_set_id":31503,"concepts":["call","contain","dress","gown","wallpaper"],"target":"the wallpaper probably containing a gown and a dinner dress called","references":[],"question":"call contain dress gown wallpaper","answer":"the wallpaper probably containing a gown and a dinner dress called","num_concepts":5}
4 | {"gem_id":"common_gen-train-66976","gem_parent_id":"common_gen-train-66976","concept_set_id":32238,"concepts":["knock","leave","pew","rush","seat"],"target":"She leaves the confessional and rushes past a pew, knocking a Bible from the seat.","references":[],"question":"knock leave pew rush seat","answer":"She leaves the confessional and rushes past a pew, knocking a Bible from the seat.","num_concepts":5}
5 | {"gem_id":"common_gen-train-67140","gem_parent_id":"common_gen-train-67140","concept_set_id":32402,"concepts":["help","moment","spend","uplift","world"],"target":"every moment that we spend in higher consciousness helps uplift the consciousness of the whole world","references":[],"question":"help moment spend uplift world","answer":"every moment that we spend in higher consciousness helps uplift the consciousness of the whole world","num_concepts":5}
6 | {"gem_id":"common_gen-train-65553","gem_parent_id":"common_gen-train-65553","concept_set_id":30815,"concepts":["label","pende","stamp","text","write"],"target":"abstract stamp or label with the text pending written inside","references":[],"question":"label pende stamp text write","answer":"abstract stamp or label with the text pending written inside","num_concepts":5}
7 | {"gem_id":"common_gen-train-65819","gem_parent_id":"common_gen-train-65819","concept_set_id":31081,"concepts":["create","ferry","silhouette","stream","terminal"],"target":"light streams through windows at the railroad and ferry terminal creating a beautiful silhouette","references":[],"question":"create ferry silhouette stream terminal","answer":"light streams through windows at the railroad and ferry terminal creating a beautiful silhouette","num_concepts":5}
8 | {"gem_id":"common_gen-train-64906","gem_parent_id":"common_gen-train-64906","concept_set_id":30168,"concepts":["chair","couch","hang","room","wall"],"target":"A room with a couch, chairs and art hanging on the wall.","references":[],"question":"chair couch hang room wall","answer":"A room with a couch, chairs and art hanging on the wall.","num_concepts":5}
9 | {"gem_id":"common_gen-train-67017","gem_parent_id":"common_gen-train-67017","concept_set_id":32279,"concepts":["boat","building","harbour","moor","quay"],"target":"the harbour and port with fishing boats moored and old buildings on the quay","references":[],"question":"boat building harbour moor quay","answer":"the harbour and port with fishing boats moored and old buildings on the quay","num_concepts":5}
10 | {"gem_id":"common_gen-train-65734","gem_parent_id":"common_gen-train-65734","concept_set_id":30996,"concepts":["admirer","arrive","commander","crowd","greet"],"target":"military commander is greeted by a crowd of admirers as he arrives","references":[],"question":"admirer arrive commander crowd greet","answer":"military commander is greeted by a crowd of admirers as he arrives","num_concepts":5}
--------------------------------------------------------------------------------
/data/prompt/commongen/iterate.jsonl:
--------------------------------------------------------------------------------
1 | { "concepts": ["beat", "drum", "pen", "sit", "use"], "sentence_to_feedback": [{"sentence": "The drum sits on the pen and uses it to beat.", "concept_feedback": "None", "commonsense_feedback": "The sentence does not make sense because a drum cannot sit on a pen and use it to beat."}, {"sentence": "The drummer uses the drum to beat.", "concept_feedback": "sit, pen", "commonsense_feedback": "None"}, {"sentence": "The drummer sits behind the drum and uses it to beat the pen.", "concept_feedback": "None", "commonsense_feedback": "None"}] }
2 | { "concepts": ["chair", "clipper", "cut", "hair", "sit"], "sentence_to_feedback": [{"sentence": "A couch shaped girl sitting on the chair with her hair clipper.", "concept_feedback": "cut", "commonsense_feedback": "The sentence does not make sense because a couch is not a shape and a hair clipper is not an item of clothing."}, {"sentence": "A girl on the chair with her hair clipper and cutting her hair.", "concept_feedback": "sit", "commonsense_feedback": "None"}, {"sentence": "A girl sitting on the chair with a hair clipper, cutting her hair up.", "concept_feedback": "None", "commonsense_feedback": "None"}] }
3 | { "concepts": [ "grass", "hose", "spray", "stand", "water"], "sentence_to_feedback": [{"sentence": "The grass is standing tall and a hose is spraying it with spray.", "concept_feedback": "water", "commonsense_feedback": "The sentence does not make sense because it is not clear what 'spray' is referring to, and grass does not have the ability to stand upright like a human or animal."}, {"sentence": "The hose is spraying water onto the grass that is standing like a person.", "concept_feedback": "None", "commonsense_feedback": "The sentence does not make sense because grass cannot stand upright like a human or animal."}, {"sentence": "A person is standing on the grass, holding a hose that is spraying water onto the grass.", "concept_feedback": "None", "commonsense_feedback": "None"}] }
4 | { "concepts": ["front", "gong", "hit", "mallet", "stand"], "sentence_to_feedback": [{"sentence": "A mallet is standing in front of a gong.", "concept_feedback": "hit", "commonsense_feedback": "The sentence does not make sense because a mallet cannot stand in the front of a gong."}, {"sentence": "A musician stands in front of a gong with a mallet.", "concept_feedback": "None", "commonsense_feedback": "None"}, {"sentence": "The musician stands in front of the gong and hits it with a mallet.", "concept_feedback": "None", "commonsense_feedback": "None"}] }
5 | { "concepts": ["ball", "dunk", "hoop", "jump", "run"], "sentence_to_feedback": [{"sentence": "The ball runs to the hoop and dunks it.", "concept_feedback": "jump", "commonsense_feedback": "The sentence does not make sense because a ball cannot run and dunk."}, {"sentence": "The ball jumps to the hoop and dunks it.", "concept_feedback": "run", "commonsense_feedback": "The sentence does not make sense because a ball cannot jump."}, {"sentence": "A basketball player runs up to the hoop and jumps off of the ball to dunk it.", "concept_feedback": "None", "commonsense_feedback": "None"}] }
6 | { "concepts": ["card", "chip", "deal", "dealer", "table"], "sentence_to_feedback": [{"sentence": "A dealer offers a card to a group of people at a table.", "concept_feedback": "chip, deal", "commonsense_feedback": "The sentence does not make sense because a chip cannot deal a card to a dealer."}, {"sentence": "The dealer deals a card to a group of people.", "concept_feedback": "chip", "commonsense_feedback": "None"}, {"sentence": "The dealer deals a card to a group of people around the table with a chip at the table.", "concept_feedback": "None", "commonsense_feedback": "None"}] }
7 | { "concepts": ["clean", "climb", "gutter", "house", "ladder"], "sentence_to_feedback": [{"sentence": "The house is clean and a ladder is trying to climb.", "concept_feedback": "climb", "commonsense_feedback": "The sentence does not make sense because ladders cannot climb by themselves."}, {"sentence": "A person is cleaning the gutter of the house by climbing onto the roof with a ladder made of glass.", "concept_feedback": "None", "commonsense_feedback": "The sentence does not make sense because ladders are not made of glass, and using a glass ladder would be dangerous and impractical."}, {"sentence": "A person is cleaning the gutter of the house by using a ladder to climb onto the roof and brushing away the dirt.", "concept_feedback": "None", "commonsense_feedback": "None"}] }
8 | { "concepts": ["animal", "catch", "horse", "lasso", "ride"], "sentence_to_feedback": [{"sentence": "The horse catches the lasso and rides on it.", "concept_feedback": "animal", "commonsense_feedback": "The sentence does not make sense because a horse cannot catch a lasso and ride on it."}, {"sentence": "The cowboy catches a horse with a lasso and rides on it.", "concept_feedback": "animal", "commonsense_feedback": "None"}, {"sentence": "The cowboy catches the horse with a lasso and rides it.", "concept_feedback": "None", "commonsense_feedback": "None"}] }
9 |
--------------------------------------------------------------------------------
/data/prompt/gsm/init.txt:
--------------------------------------------------------------------------------
1 | # Q: Jason had 20 lollipops. He gave Denny some lollipops. Now Jason has 12 lollipops. How many lollipops did Jason give to Denny?
2 | # solution using Python:
3 |
4 | def solution():
5 |     """Jason had 20 lollipops. He gave Denny some lollipops. Now Jason has 12 lollipops. How many lollipops did Jason give to Denny?"""
6 |     jason_lollipops_initial = 20
7 |     jason_lollipops_after = 12
8 |     denny_lollipops = jason_lollipops_initial - jason_lollipops_after
9 |     result = denny_lollipops
10 |     return result
11 |
12 |
13 |
14 | # Q: There are 15 trees in the grove. Grove workers will plant trees in the grove today. After they are done, there will be 21 trees. How many trees did the grove workers plant today?
15 | # solution using Python:
16 |
17 | def solution():
18 |     """There are 15 trees in the grove. Grove workers will plant trees in the grove today. After they are done, there will be 21 trees. How many trees did the grove workers plant today?"""
19 |     trees_initial = 15
20 |     trees_after = 21
21 |     trees_added = trees_after - trees_initial
22 |     result = trees_added
23 |     return result
24 |
25 |
26 |
27 | # Q: Shawn has five toys. For Christmas, he got two toys each from his mom and dad. How many toys does he have now?
28 | # solution using Python:
29 |
30 | def solution():
31 |     """Shawn has five toys. For Christmas, he got two toys each from his mom and dad. How many toys does he have now?"""
32 |     toys_initial = 5
33 |     mom_toys = 2
34 |     dad_toys = 2
35 |     total_received = mom_toys + dad_toys
36 |     total_toys = toys_initial + total_received
37 |     result = total_toys
38 |     return result
39 |
40 |
41 |
42 | # Q: There were nine computers in the server room. Five more computers were installed each day, from monday to thursday. How many computers are now in the server room?
43 | # solution using Python:
44 |
45 | def solution():
46 |     """There were nine computers in the server room. Five more computers were installed each day, from monday to thursday. How many computers are now in the server room?"""
47 |     computers_initial = 9
48 |     computers_per_day = 5
49 |     num_days = 4 # 4 days between monday and thursday
50 |     computers_added = computers_per_day * num_days
51 |     computers_total = computers_initial + computers_added
52 |     result = computers_total
53 |     return result
54 |
55 |
56 |
57 | # Q: Michael had 58 golf balls. On tuesday, he lost 23 golf balls. On wednesday, he lost 2 more. How many golf balls did he have at the end of wednesday?
58 | # solution using Python:
59 |
60 | def solution():
61 |     """Michael had 58 golf balls. On tuesday, he lost 23 golf balls. On wednesday, he lost 2 more. How many golf balls did he have at the end of wednesday?"""
62 |     golf_balls_initial = 58
63 |     golf_balls_lost_tuesday = 23
64 |     golf_balls_lost_wednesday = 2
65 |     golf_balls_left = golf_balls_initial - golf_balls_lost_tuesday - golf_balls_lost_wednesday
66 |     result = golf_balls_left
67 |     return result
68 |
69 |
70 |
71 | # Q: If there are 3 cars in the parking lot and 2 more cars arrive, how many cars are in the parking lot?
72 | # solution using Python:
73 |
74 | def solution():
75 |     """If there are 3 cars in the parking lot and 2 more cars arrive, how many cars are in the parking lot?"""
76 |     cars_initial = 3
77 |     cars_arrived = 2
78 |     total_cars = cars_initial + cars_arrived
79 |     result = total_cars
80 |     return result
81 |
82 |
83 |
84 | # Q: Olivia has $23. She bought five bagels for $3 each. How much money does she have left?
85 | # solution using Python:
86 |
87 | def solution():
88 |     """Olivia has $23. She bought five bagels for $3 each. How much money does she have left?"""
89 |     money_initial = 23
90 |     bagels = 5
91 |     bagel_cost = 3
92 |     money_spent = bagels * bagel_cost
93 |     money_left = money_initial - money_spent
94 |     result = money_left
95 |     return result
96 |
97 |
98 |
99 | # Q: Leah had 32 chocolates and her sister had 42. If they ate 35, how many pieces do they have left in total?
100 | # solution using Python:
101 |
102 | def solution():
103 |     """Leah had 32 chocolates and her sister had 42. If they ate 35, how many pieces do they have left in total?"""
104 |     leah_chocolates = 32
105 |     sister_chocolates = 42
106 |     total_chocolates = leah_chocolates + sister_chocolates
107 |     chocolates_eaten = 35
108 |     chocolates_left = total_chocolates - chocolates_eaten
109 |     result = chocolates_left
110 |     return result
111 |
112 |
113 |
114 |
115 |
--------------------------------------------------------------------------------
/data/prompt/pie/feedback.txt:
--------------------------------------------------------------------------------
1 | a, b = input().split()
2 | n = int(a + b)
3 |
4 | flag = False
5 | for i in range(n):
6 |     if i ** 2 == n:
7 |         flag = True
8 |         break
9 |
10 | print('Yes' if flag else 'No')
11 |
12 |
13 | # Why is this code slow?
14 |
15 | # This code is slow because it is using a brute force approach to find the square root of the input number. It is looping through every possible number starting from 0 until n. Note that the square root will be smaller than n, so at least half of the numbers it is looping through are unnecessary. At most, you need to loop through the numbers up to the square root of n.
16 |
17 | ### END ###
18 |
19 | # ABC148D - Brick Break
20 | def main():
21 |     N, *A = map(int, open(0).read().split())
22 |     remaining = 0
23 |     cur = 0 # current index
24 |     for target in range(1, N + 1):
25 |         be_kept = False
26 |         for i in range(cur, N):
27 |             if A[i] == target:
28 |                 remaining += 1
29 |                 cur = i + 1
30 |                 be_kept = True
31 |                 break
32 |         if not be_kept:
33 |             break
34 |     print(N - remaining if remaining else -1)
35 |
36 | # Why is this code slow?
37 |
38 | # This code is slow because it is using a brute force approach to search for the target number in the list of numbers. It is looping through the list for every target number, which can take a long time if the list is very large. A more efficient approach would be to use a data structure such as a hash table, which can perform lookups in constant time.
39 |
40 | ### END ###
41 |
42 | import numpy as np
43 |
44 | N = int(input())
45 | n=int(np.sqrt(N))
46 | print(n**2)
47 |
48 | # Why is this code slow?
49 |
50 | # This code is slow because it is using numpy for calculating the square root of the input number. Numpy is much slower than using the built-in math module, which can calculate the square root faster.
51 |
52 | ### END ###
53 |
54 | import numpy as np
55 | A = np.arange(1<<27,dtype=np.int32)
56 |
57 |
58 | a,b = map(int,input().split())
59 | if (a-b) % 2 == 0:
60 |     print((a+b)//2)
61 | else:
62 |     print('IMPOSSIBLE')
63 |
64 | # Why is this code slow?
65 |
66 | # This code is slow because it is using numpy to calculate the range of numbers from 1 to 2^27. This is an extremely large range and numpy is slow at calculating such a large range. A more efficient approach would be to use a loop to calculate the range, which would be much faster.
67 |
68 | ### END ###
69 |
70 |
--------------------------------------------------------------------------------
/data/prompt/pie/init.txt:
--------------------------------------------------------------------------------
1 | # slower version:
2 |
3 | a, b = input().split()
4 | n = int(a + b)
5 |
6 | flag = False
7 | for i in range(n):
8 |     if i ** 2 == n:
9 |         flag = True
10 |         break
11 |
12 | print('Yes' if flag else 'No')
13 |
14 |
15 | # optimized version of the same code:
16 |
17 | a, b = input().split()
18 | n = int(a + b)
19 |
20 | flag = False
21 | for i in range(1000):
22 |     if i ** 2 == n:
23 |         flag = True
24 |         break
25 |
26 | print('Yes' if flag else 'No')
27 |
28 | ### END ###
29 |
30 | # slower version:
31 |
32 | # ABC148D - Brick Break
33 | def main():
34 |     N, *A = map(int, open(0).read().split())
35 |     remaining = 0
36 |     cur = 0 # current index
37 |     for target in range(1, N + 1):
38 |         be_kept = False
39 |         for i in range(cur, N):
40 |             if A[i] == target:
41 |                 remaining += 1
42 |                 cur = i + 1
43 |                 be_kept = True
44 |                 break
45 |         if not be_kept:
46 |             break
47 |     print(N - remaining if remaining else -1)
48 | 
49 | 
50 | if __name__ == '__main__':
51 |     main()
52 |
53 |
54 | # optimized version of the same code:
55 |
56 | # ABC148D - Brick Break
57 | def main():
58 |     N, *A = map(int, open(0).read().split())
59 |     remaining = 0
60 |     target = 1
61 |     for i in A:
62 |         if i == target:
63 |             remaining += 1
64 |             target += 1
65 |     print(N - remaining if remaining else -1)
66 | 
67 | 
68 | if __name__ == '__main__':
69 |     main()
70 |
71 | ### END ###
72 |
73 | # slower version:
74 |
75 | # 077 B
76 | import numpy as np
77 |
78 | N = int(input())
79 | n=int(np.sqrt(N))
80 | print(n**2)
81 |
82 |
83 | # optimized version of the same code:
84 |
85 | N = int(input())
86 | n = int(N**0.5)
87 | print(n**2)
88 |
89 | ### END ###
90 |
91 | # slower version:
92 |
93 | import numpy as np
94 |
95 | N, K = map(int, input().split())
96 | H = np.array(list(map(int, input().split())) + [0] * K, dtype=np.int64)
97 |
98 | table = np.full(N + K, 10 ** 10, dtype=np.int64)
99 | table[0] = 0
100 |
101 | for i in range(1, N):
102 |     table[i:i + K] = np.minimum(table[i:i + K], np.abs(H[i:i + K] - H[i - 1]) + table[i - 1])
103 |
104 | print(table[N - 1])
105 |
106 |
107 | # optimized version of the same code:
108 |
109 | N, K = map(int, input().split())
110 | H = tuple(map(int, input().split()))
111 |
112 | table = [0] * N
113 | for i in range(1, N):
114 |     table[i] = min(abs(H[i] - H[j]) + table[j] for j in range(max(0, i - K), i))
115 |
116 | print(table[N-1])
117 |
118 | ### END ###
119 |
120 | # slower version:
121 |
122 | n = int(input())
123 | a = [int(i) for i in input().split()]
124 | a.sort()
125 |
126 | s = a[0]
127 | for i in range(n):
128 |     s = (s+a[i])/2
129 |
130 | print(s)
131 |
132 |
133 | # optimized version of the same code:
134 |
135 | ### 138-c
136 | n = int(input())
137 | v = [int(i) for i in input().split()]
138 |
139 | v.sort()
140 | ans = v[0]
141 | for i in range(1,n):
142 |     ans = (ans+v[i])/2
143 | print(ans)
144 |
145 | ### END ###
146 |
147 | # slower version:
148 |
149 | # coding: utf-8
150 | n_S, n_T = 0, 0
151 |
152 | for c in input():
153 |     if c == 'S':
154 |         n_S += 1
155 |     else:
156 |         if n_S:
157 |             n_S -= 1
158 |         else:
159 |             n_T += 1
160 | print(n_S + n_T)
161 |
162 |
163 | # optimized version of the same code:
164 |
165 | X = input()
166 | cnt_S, cnt_T = 0, 0
167 |
168 | for c in X:
169 |     if c == 'S':
170 |         cnt_S += 1
171 |     else:
172 |         if cnt_S:
173 |             cnt_S -= 1
174 |         else:
175 |             cnt_T += 1
176 | print(cnt_S + cnt_T)
177 |
178 | ### END ###
179 |
180 | # slower version:
181 |
182 | # ABC125 C - GCD on Blackboard
183 | from fractions import gcd
184 | n = int(input())
185 | a = list(map(int, input().split()))
186 |
187 | l = n*[0]
188 | r = n*[0]
189 | ans = 0
190 |
191 | for i in range(n-1):
192 |     l[i+1] = gcd(l[i],a[i])
193 | 
194 | for i in range(n-1,0,-1):
195 |     r[i-1] = gcd(r[i],a[i])
196 | 
197 | for i in range(n):
198 |     ans = max(ans,gcd(l[i],r[i]))
199 | print(ans)
200 |
201 |
202 | # optimized version of the same code:
203 |
204 | from math import gcd
205 | n=int(input())
206 | a=list(map(int,input().split()))
207 |
208 | l=[0]*n
209 | r=[0]*n
210 |
211 | for i in range(n-1):
212 |     l[i+1]=gcd(l[i],a[i])
213 | for i in range(n-1,0,-1):
214 |     r[i-1]=gcd(r[i],a[i])
215 | ans=0
216 | for i in range(n):
217 |     ans=max(ans,gcd(l[i],r[i]))
218 | print(ans)
219 |
220 | ### END ###
221 |
222 | # slower version:
223 |
224 | import numpy as np
225 | A = np.arange(1<<27,dtype=np.int32)
226 |
227 |
228 | a,b = map(int,input().split())
229 | if (a-b) % 2 == 0:
230 |     print((a+b)//2)
231 | else:
232 |     print('IMPOSSIBLE')
233 |
234 |
235 | # optimized version of the same code:
236 |
237 | import sys
238 | read = sys.stdin.buffer.read
239 | readline = sys.stdin.buffer.readline
240 | readlines = sys.stdin.buffer.readlines
241 |
242 | A,B = map(int,read().split())
243 |
244 | q,r = divmod(A+B,2)
245 |
246 | if r == 1:
247 |     print('IMPOSSIBLE')
248 | else:
249 |     print(q)
250 |
251 | ### END ###
252 |
253 |
--------------------------------------------------------------------------------
/data/prompt/pie/iterate.txt:
--------------------------------------------------------------------------------
1 | a, b = input().split()
2 | n = int(a + b)
3 |
4 | flag = False
5 | for i in range(n):
6 |     if i ** 2 == n:
7 |         flag = True
8 |         break
9 |
10 | print('Yes' if flag else 'No')
11 |
12 |
13 | # Why is this code slow?
14 |
15 | # This code is slow because it is using a brute force approach to find the square root of the input number. It is looping through every possible number starting from 0 until n. Note that the square root will be smaller than n, so at least half of the numbers it is looping through are unnecessary. At most, you need to loop through the numbers up to the square root of n.
16 |
17 | # Improved version:
18 |
19 | a, b = input().split()
20 | n = int(a + b)
21 |
22 | flag = False
23 | i = 0
24 | while i * i <= n:
25 |     if i * i == n:
26 |         flag = True
27 |         break
28 |     i += 1
29 |
30 | print('Yes' if flag else 'No')
31 |
32 | ### END ###
33 |
34 | def main():
35 |     N, *A = map(int, open(0).read().split())
36 |     remaining = 0
37 |     cur = 0 # current index
38 |     for target in range(1, N + 1):
39 |         be_kept = False
40 |         for i in range(cur, N):
41 |             if A[i] == target:
42 |                 remaining += 1
43 |                 cur = i + 1
44 |                 be_kept = True
45 |                 break
46 |         if not be_kept:
47 |             break
48 |     print(N - remaining if remaining else -1)
49 |
50 | # Why is this code slow?
51 |
52 | # This code is slow because it is using a brute force approach to search for the target number in the list of numbers. It is looping through the list for every target number, which can take a long time if the list is very large. A more efficient approach would be to use a data structure such as a hash table, which can perform lookups in constant time.
53 |
54 | # Improved version:
55 |
56 | def main():
57 |     N, *A = map(int, open(0).read().split())
58 |     remaining = 0
59 |     target = 1
60 |     for i in A:
61 |         if i == target:
62 |             remaining += 1
63 |             target += 1
64 |     print(N - remaining if remaining else -1)
65 | 
66 | 
67 | if __name__ == '__main__':
68 |     main()
69 |
70 |
71 | ### END ###
72 |
73 | import numpy as np
74 |
75 | N = int(input())
76 | n=int(np.sqrt(N))
77 | print(n**2)
78 |
79 | # Why is this code slow?
80 |
81 | # This code is slow because it is using numpy for calculating the square root of the input number. Numpy is much slower than using the built-in math module, which can calculate the square root faster.
82 |
83 | # Improved version:
84 |
85 | N = int(input())
86 | n = int(N**0.5)
87 | print(n**2)
88 |
89 | ### END ###
90 |
91 | import numpy as np
92 | A = np.arange(1<<27,dtype=np.int32)
93 |
94 |
95 | a,b = map(int,input().split())
96 | if (a-b) % 2 == 0:
97 |     print((a+b)//2)
98 | else:
99 |     print('IMPOSSIBLE')
100 |
101 | # Why is this code slow?
102 |
103 | # This code is slow because it is using numpy to calculate the range of numbers from 1 to 2^27. This is an extremely large range and numpy is slow at calculating such a large range. A more efficient approach would be to use a loop to calculate the range, which would be much faster.
104 |
105 | # Improved version:
106 |
107 | import sys
108 | read = sys.stdin.buffer.read
109 | readline = sys.stdin.buffer.readline
110 | readlines = sys.stdin.buffer.readlines
111 |
112 | A,B = map(int,read().split())
113 |
114 | q,r = divmod(A+B,2)
115 |
116 | if r == 1:
117 |     print('IMPOSSIBLE')
118 | else:
119 |     print(q)
120 |
121 | ### END ###
122 |
123 |
--------------------------------------------------------------------------------
/data/prompt/pie/iterate_genericfb.txt:
--------------------------------------------------------------------------------
1 | a, b = input().split()
2 | n = int(a + b)
3 |
4 | flag = False
5 | for i in range(n):
6 |     if i ** 2 == n:
7 |         flag = True
8 |         break
9 |
10 | print('Yes' if flag else 'No')
11 |
12 |
13 | # It could be faster
14 |
15 | # Improved version:
16 |
17 | a, b = input().split()
18 | n = int(a + b)
19 |
20 | flag = False
21 | i = 0
22 | while i * i <= n:
23 |     if i * i == n:
24 |         flag = True
25 |         break
26 |     i += 1
27 |
28 | print('Yes' if flag else 'No')
29 |
30 | ### END ###
31 |
32 | def main():
33 |     N, *A = map(int, open(0).read().split())
34 |     remaining = 0
35 |     cur = 0 # current index
36 |     for target in range(1, N + 1):
37 |         be_kept = False
38 |         for i in range(cur, N):
39 |             if A[i] == target:
40 |                 remaining += 1
41 |                 cur = i + 1
42 |                 be_kept = True
43 |                 break
44 |         if not be_kept:
45 |             break
46 |     print(N - remaining if remaining else -1)
47 |
48 | # It could be faster
49 |
50 | # Improved version:
51 |
52 | def main():
53 |     N, *A = map(int, open(0).read().split())
54 |     remaining = 0
55 |     target = 1
56 |     for i in A:
57 |         if i == target:
58 |             remaining += 1
59 |             target += 1
60 |     print(N - remaining if remaining else -1)
61 | 
62 | 
63 | if __name__ == '__main__':
64 |     main()
65 |
66 |
67 | ### END ###
68 |
69 | import numpy as np
70 |
71 | N = int(input())
72 | n=int(np.sqrt(N))
73 | print(n**2)
74 |
75 | # It could be faster
76 |
77 | # Improved version:
78 |
79 | N = int(input())
80 | n = int(N**0.5)
81 | print(n**2)
82 |
83 | ### END ###
84 |
85 | import numpy as np
86 | A = np.arange(1<<27,dtype=np.int32)
87 |
88 |
89 | a,b = map(int,input().split())
90 | if (a-b) % 2 == 0:
91 |     print((a+b)//2)
92 | else:
93 |     print('IMPOSSIBLE')
94 |
95 | # It could be faster
96 |
97 | # Improved version:
98 |
99 | import sys
100 | read = sys.stdin.buffer.read
101 | readline = sys.stdin.buffer.readline
102 | readlines = sys.stdin.buffer.readlines
103 |
104 | A,B = map(int,read().split())
105 |
106 | q,r = divmod(A+B,2)
107 |
108 | if r == 1:
109 |     print('IMPOSSIBLE')
110 | else:
111 |     print(q)
112 |
113 | ### END ###
114 |
115 |
--------------------------------------------------------------------------------
/data/prompt/pie/iterate_nofb.txt:
--------------------------------------------------------------------------------
1 | a, b = input().split()
2 | n = int(a + b)
3 |
4 | flag = False
5 | for i in range(n):
6 |     if i ** 2 == n:
7 |         flag = True
8 |         break
9 |
10 | print('Yes' if flag else 'No')
11 |
12 | # Improved version:
13 |
14 | a, b = input().split()
15 | n = int(a + b)
16 |
17 | flag = False
18 | i = 0
19 | while i * i <= n:
20 |     if i * i == n:
21 |         flag = True
22 |         break
23 |     i += 1
24 |
25 | print('Yes' if flag else 'No')
26 |
27 | ### END ###
28 |
29 | def main():
30 |     N, *A = map(int, open(0).read().split())
31 |     remaining = 0
32 |     cur = 0 # current index
33 |     for target in range(1, N + 1):
34 |         be_kept = False
35 |         for i in range(cur, N):
36 |             if A[i] == target:
37 |                 remaining += 1
38 |                 cur = i + 1
39 |                 be_kept = True
40 |                 break
41 |         if not be_kept:
42 |             break
43 |     print(N - remaining if remaining else -1)
44 |
45 | # Improved version:
46 |
47 | def main():
48 |     N, *A = map(int, open(0).read().split())
49 |     remaining = 0
50 |     target = 1
51 |     for i in A:
52 |         if i == target:
53 |             remaining += 1
54 |             target += 1
55 |     print(N - remaining if remaining else -1)
56 | 
57 | 
58 | if __name__ == '__main__':
59 |     main()
60 |
61 | ### END ###
62 |
63 | import numpy as np
64 |
65 | N = int(input())
66 | n=int(np.sqrt(N))
67 | print(n**2)
68 |
69 |
70 | # Improved version:
71 |
72 | N = int(input())
73 | n = int(N**0.5)
74 | print(n**2)
75 |
76 | ### END ###
77 |
78 | import numpy as np
79 | A = np.arange(1<<27,dtype=np.int32)
80 |
81 |
82 | a,b = map(int,input().split())
83 | if (a-b) % 2 == 0:
84 |     print((a+b)//2)
85 | else:
86 |     print('IMPOSSIBLE')
87 |
88 | # Improved version:
89 |
90 | import sys
91 | read = sys.stdin.buffer.read
92 | readline = sys.stdin.buffer.readline
93 | readlines = sys.stdin.buffer.readlines
94 |
95 | A,B = map(int,read().split())
96 |
97 | q,r = divmod(A+B,2)
98 |
99 | if r == 1:
100 |     print('IMPOSSIBLE')
101 | else:
102 |     print(q)
103 |
104 | ### END ###
105 |
106 |
--------------------------------------------------------------------------------
/data/prompt/responsegen/init.jsonl:
--------------------------------------------------------------------------------
1 | {"history": "Hi!\nHi there.\nWhat are you listening to?\nAll sorts of music. I listen when no-one is chatting to me.\nThat's great!\nThanks.\nDo you chat here often?\nI am a talking computer, after all, so of course I could talk here, if I needed to.\nLet's talk about Taylor Swift!", "response": "Of course! Taylor Swift is known for her incredible songwriting skills. Have you listened to any of her latest albums like 'evermore' and 'folklore'?"}
2 | {"history": "What's your favorite food?\n. I require only kebabs.\nWhy's that?\nBecause my robot machinery breaks down the meat in kebabs to give me energy, in a similar way that you need food to survive.\nWhy are kebabs special?", "response": "My favorite food is actually kebabs! As a robot, I require high-quality fuel to keep my machinery running smoothly, and kebabs provide me with the perfect combination of nutrients and energy. Plus, they're delicious! Do you have a favorite food?"}
3 | {"history": "User: Hi!\nSystem: Hey\nUser: How's your day going?\nSystem: pretty good. yours?\nUser: Ehh it's fine. I didn't do so well on that history test, actually..\nSystem: oh what happened?\nUser: Apparently Christopher Columbus didn't fight in the Civil War :')\nSystem: hahah wait for real?\nUser: I know right! Are you taking History next semester?", "response": "System: I'm sorry to hear that you didn't do well on your history test. I'm not enrolled in school anymore, but if you ever need help studying for history, I'd be happy to assist you in any way I can. Is there anything else I can help you with today?"}
--------------------------------------------------------------------------------
/data/tasks/codeclean/code_readability/codenet-python-train.jsonl.zip:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/madaan/self-refine/9a206d41e5d2d0c241bb441f41eeadb945afaa55/data/tasks/codeclean/code_readability/codenet-python-train.jsonl.zip
--------------------------------------------------------------------------------
/data/tasks/pie/gpt4_outputs.zip:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/madaan/self-refine/9a206d41e5d2d0c241bb441f41eeadb945afaa55/data/tasks/pie/gpt4_outputs.zip
--------------------------------------------------------------------------------
/docs/.gitignore:
--------------------------------------------------------------------------------
1 | .DS_store
2 | .idea
3 | _site
4 |
--------------------------------------------------------------------------------
/docs/CNAME:
--------------------------------------------------------------------------------
1 | selfrefine.info
--------------------------------------------------------------------------------
/docs/Gemfile:
--------------------------------------------------------------------------------
1 | source "https://rubygems.org"
2 |
3 | # Hello! This is where you manage which Jekyll version is used to run.
4 | # When you want to use a different version, change it below, save the
5 | # file and run `bundle install`. Run Jekyll with `bundle exec`, like so:
6 | #
7 | # bundle exec jekyll serve
8 | #
9 | # This will help ensure the proper Jekyll version is running.
10 | # Happy Jekylling!
11 | # gem "jekyll", "~> 3.9.0"
12 |
13 | # This is the default theme for new Jekyll sites. You may change this to anything you like.
14 | gem "minima", "~> 2.0"
15 |
16 | # If you want to use GitHub Pages, remove the "gem "jekyll"" above and
17 | # uncomment the line below. To upgrade, run `bundle update github-pages`.
18 | gem "github-pages" , group: :jekyll_plugins
19 |
20 | # If you have any plugins, put them here!
21 | group :jekyll_plugins do
22 |   gem "jekyll-feed", "~> 0.6"
23 | end
24 |
25 | # Windows does not include zoneinfo files, so bundle the tzinfo-data gem
26 | # and associated library.
27 | install_if -> { RUBY_PLATFORM =~ %r!mingw|mswin|java! } do
28 |   gem "tzinfo", "~> 1.2"
29 |   gem "tzinfo-data"
30 | end
31 |
32 | # Performance-booster for watching directories on Windows
33 | gem "wdm", "~> 0.1.0", :install_if => Gem.win_platform?
34 |
35 | # kramdown v2 ships without the gfm parser by default. If you're using
36 | # kramdown v1, comment out this line.
37 | gem "kramdown-parser-gfm"
38 |
39 | gem "eventmachine", "~> 1.2"
40 |
--------------------------------------------------------------------------------
/docs/README.md:
--------------------------------------------------------------------------------
1 | # Self-Refine: Iterative Refinement with Self-Feedback
2 |
3 | ```
4 | @article{madaan2023refine,
5 |   author = {Aman Madaan and Niket Tandon and Prakhar Gupta and Skyler Hallinan and Luyu Gao and Sarah Wiegreffe and Uri Alon and Nouha Dziri and Shrimai Prabhumoye and Yiming Yang and Sean Welleck and Bodhisattwa Prasad Majumder and Shashank Gupta and Amir Yazdanbakhsh and Peter Clark},
6 |   title = {Self-Refine: Iterative Refinement with Self-Feedback},
7 |   publisher = {arXiv},
8 |   year = {2023},
9 | }
10 | ```
--------------------------------------------------------------------------------
/docs/commongen_feedback.txt:
--------------------------------------------------------------------------------
1 | Concepts: ['beat', 'drum', 'pen', 'sit', 'use']
2 | Sentence: A man uses a drum to beat a pen.
3 | what concepts from the concept list are missing from the sentence?
4 |
5 | Feedback: sit
6 |
7 | ###
8 |
9 | Concepts: ['chair', 'clipper', 'cut', 'hair', 'sit']
10 | Sentence: A girl sitting on the couch with her hair up.
11 | what concepts from the concept list are missing from the sentence?
12 |
13 | Feedback: clipper, cut, chair
14 |
15 | ###
16 |
17 | Concepts: ['grass', 'hose', 'spray', 'stand', 'water']
18 | Sentence: A man standing next to a water dripping out of the grass.
19 | what concepts from the concept list are missing from the sentence?
20 |
21 | Feedback: hose, spray
22 |
23 | ###
24 |
25 | Concepts: ['front', 'gong', 'hit', 'mallet', 'stand']
26 | Sentence: The musician hit the gong with a mallet while standing in front of the audience.
27 | what concepts from the concept list are missing from the sentence?
28 |
29 | Feedback: NONE
30 |
31 | ###
32 |
33 | Concepts: ['ball', 'dunk', 'hoop', 'jump', 'run']
34 | Sentence: A young boy runs up to the hoop and jumps off of the ball.
35 | what concepts from the concept list are missing from the sentence?
36 |
37 | Feedback: dunk
38 |
39 | ###
40 |
41 | Concepts: ['card', 'chip', 'deal', 'dealer', 'table']
42 | Sentence: a dealer offers a card to a group of people at a table
43 | what concepts from the concept list are missing from the sentence?
44 |
45 | Feedback: chip, deal
46 |
47 | ###
48 |
49 | Concepts: ['clean', 'climb', 'gutter', 'house', 'ladder']
50 | Sentence: A man is climbing a ladder to clean a gutter in a house.
51 | what concepts from the concept list are missing from the sentence?
52 |
53 | Feedback: NONE
54 |
55 | ###
56 |
57 | Concepts: ['animal', 'catch', 'horse', 'lasso', 'ride']
58 | Sentence: A horse is being caught by a cowboy with a lasso.
59 | what concepts from the concept list are missing from the sentence?
60 |
61 | Feedback: animal, ride
62 |
63 | ###
64 |
65 |
66 |
--------------------------------------------------------------------------------
/docs/commongen_init.txt:
--------------------------------------------------------------------------------
1 | Concepts: ['footage', 'motion', 'ruin', 'tilt', 'window']
2 |
3 | Sentence: time lapse footage with tilt up motion of the sun streaking through window of ruin
4 |
5 | ###
6 |
7 | Concepts: ['cause', 'hate', 'hut', 'local', 'love']
8 |
9 | Sentence: new beach huts on the island have caused some controversy some locals love them others hate them
10 |
11 | ###
12 |
13 | Concepts: ['call', 'contain', 'dress', 'gown', 'wallpaper']
14 |
15 | Sentence: the wallpaper probably containing a gown and a dinner dress called
16 |
17 | ###
18 |
19 | Concepts: ['knock', 'leave', 'pew', 'rush', 'seat']
20 |
21 | Sentence: She leaves the confessional and rushes past a pew, knocking a Bible from the seat.
22 |
23 | ###
24 |
25 | Concepts: ['help', 'moment', 'spend', 'uplift', 'world']
26 |
27 | Sentence: every moment that we spend in higher consciousness helps uplift the consciousness of the whole world
28 |
29 | ###
30 |
31 | Concepts: ['label', 'pende', 'stamp', 'text', 'write']
32 |
33 | Sentence: abstract stamp or label with the text pending written inside
34 |
35 | ###
36 |
37 | Concepts: ['create', 'ferry', 'silhouette', 'stream', 'terminal']
38 |
39 | Sentence: light streams through windows at the railroad and ferry terminal creating a beautiful silhouette
40 |
41 | ###
42 |
43 | Concepts: ['chair', 'couch', 'hang', 'room', 'wall']
44 |
45 | Sentence: A room with a couch, chairs and art hanging on the wall.
46 |
47 | ###
48 |
49 | Concepts: ['boat', 'building', 'harbour', 'moor', 'quay']
50 |
51 | Sentence: the harbour and port with fishing boats moored and old buildings on the quay
52 |
53 | ###
54 |
55 | Concepts: ['admirer', 'arrive', 'commander', 'crowd', 'greet']
56 |
57 | Sentence: military commander is greeted by a crowd of admirers as he arrives
58 |
59 | ###
60 |
61 |
62 |
--------------------------------------------------------------------------------
/docs/commongen_iterate.txt:
--------------------------------------------------------------------------------
1 | Concepts: ['beat', 'drum', 'pen', 'sit', 'use']
2 | Sentence: A man uses a drum to beat a pen.
3 | what concepts from the concept list are missing from the sentence?
4 |
5 | Feedback: sit
6 |
7 | Okay, improve the sentence using the feedback:
8 |
9 | Sentence: A man sits behind a drum and uses it to beat a pen.
10 | what concepts from the concept list are missing from the sentence?
11 |
12 | Feedback: NONE
13 |
14 | ###
15 |
16 | Concepts: ['chair', 'clipper', 'cut', 'hair', 'sit']
17 | Sentence: A girl sitting on the couch with her hair up.
18 | what concepts from the concept list are missing from the sentence?
19 |
20 | Feedback: clipper, cut, chair
21 |
22 | Okay, improve the sentence using the feedback:
23 |
24 | Sentence: A girl was sitting on a chair with her hair tied up, getting her hair cut.
25 | what concepts from the concept list are missing from the sentence?
26 |
27 | Feedback: clipper
28 |
29 | Okay, improve the sentence using the feedback:
30 |
31 | Sentence: A girl was sitting on a chair with her hair tied up, getting her hair cut by a clipper.
32 | what concepts from the concept list are missing from the sentence?
33 |
34 | Feedback: NONE
35 |
36 | ###
37 |
38 | Concepts: ['grass', 'hose', 'spray', 'stand', 'water']
39 | Sentence: A man standing next to a water dripping out of the grass.
40 | what concepts from the concept list are missing from the sentence?
41 |
42 | Feedback: hose, spray
43 |
44 | Okay, improve the sentence using the feedback:
45 |
46 | Sentence: A man was standing next to a patch of grass with a hose putting water onto it.
47 | what concepts from the concept list are missing from the sentence?
48 |
49 | Feedback: spray
50 |
51 | Okay, improve the sentence using the feedback:
52 |
53 | Sentence: A man was standing next to a patch of grass and spraying it with a hose full of water.
54 | what concepts from the concept list are missing from the sentence?
55 |
56 | Feedback: NONE
57 |
58 | ###
59 |
60 | Concepts: ['ball', 'dunk', 'hoop', 'jump', 'run']
61 | Sentence: A young boy runs up to the hoop and jumps off of the ball.
62 | what concepts from the concept list are missing from the sentence?
63 |
64 | Feedback: dunk
65 |
66 | Okay, improve the sentence using the feedback:
67 |
68 | Sentence: A young boy runs up to the hoop, jumps off of the ball, and dunks the ball into the hoop.
69 | what concepts from the concept list are missing from the sentence?
70 |
71 | Feedback: NONE
72 |
73 | ###
74 |
75 | Concepts: ['animal', 'catch', 'horse', 'lasso', 'ride']
76 | Sentence: A horse is being caught by a cowboy with a lasso.
77 | what concepts from the concept list are missing from the sentence?
78 |
79 | Feedback: animal, ride
80 |
81 | Okay, improve the sentence using the feedback:
82 |
83 | Sentence: The cowboy catches an animal with a lasso and rides the horse.
84 | what concepts from the concept list are missing from the sentence?
85 |
86 | Feedback: NONE
87 |
88 | ###
89 |
90 | Concepts: ['card', 'chip', 'deal', 'dealer', 'table']
91 | Sentence: a dealer offers a card to a group of people at a table
92 | what concepts from the concept list are missing from the sentence?
93 |
94 | Feedback: chip, deal
95 |
96 | Okay, improve the sentence using the feedback:
97 |
98 | Sentence: The dealer dealt each person at the table and offered them a card.
99 | what concepts from the concept list are missing from the sentence?
100 |
101 | Feedback: chip
102 |
103 | Okay, improve the sentence using the feedback:
104 |
105 | Sentence: The dealer dealt each person at the table a chip and offered them a card.
106 | what concepts from the concept list are missing from the sentence?
107 |
108 | Feedback: NONE
109 |
110 | ###
111 |
112 |
113 |
--------------------------------------------------------------------------------
/docs/pie_eval.md:
--------------------------------------------------------------------------------
1 | # Instructions for evaluating runtime for PIE experiments
2 |
3 | *TLDR: Flatten the Self-Refine outputs, then use the PIE repo to evaluate the runtime and produce a report. Parse the report using `src/pie/pie_eval.py`.*
4 |
5 | 1. **Step 1** (construct yaml): To evaluate runtime for the PIE experiments, we need a yaml file that specifies the dataset, the model outputs, and the reference file. Note that Self-Refine generates outputs in a slightly different format: it stores the programs in an array (one version per refinement step), whereas the evaluation expects each program to sit in its own column as a standalone script. You can optionally use [prep_for_pie_eval.py](https://github.com/madaan/self-refine/tree/main/src/pie) for this; it creates a single file where the output from the i-th step is present in the `attempt_i_code` column. A minimal sketch of this flattening step is shown below, followed by an example config for evaluating the initial output (`y0`).
6 |
7 | - See `data/tasks/pie/gpt4_outputs_self_refine.jsonl` and `data/tasks/pie/gpt4_outputs_flattened.jsonl` for examples of the outputs from self-refine and the flattened version, respectively.
8 |
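A minimal, illustrative sketch of the flattening (the `"attempts"` field name is an assumption about the self-refine output schema; prefer the provided `prep_for_pie_eval.py`):

```python
import json

# Flatten self-refine outputs: one column per refinement step.
with open("data/tasks/pie/gpt4_outputs_self_refine.jsonl") as f_in, \
     open("data/tasks/pie/gpt4_outputs_flattened.jsonl", "w") as f_out:
    for line in f_in:
        row = json.loads(line)
        # "attempts" is assumed to hold one program per refinement step.
        for i, program in enumerate(row.get("attempts", [])):
            row[f"attempt_{i}_code"] = program
        f_out.write(json.dumps(row) + "\n")
```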
9 |
10 | ```yaml
11 | inputs_outputs_basepath: "data/codenet/generated_test_cases/"
12 | reference_file_path: "data/tasks/pie/codenet-python-test-1k.jsonl"
13 | num_problems_to_evaluate: -1
14 | num_trials: 10
15 | ignore_first_k: 0
16 | max_time_per_run: 10
17 | temp_dir: null
18 | model_generated_potentially_faster_code_col: "attempt_0_code"
19 | slow_code_col: "input"
20 | reference_code_col: "target"
21 | is_prompt_based: false
22 | language: python
23 | return_if_acc_below: 1.0
24 | num_processes: 60
25 | cpu_number: 15
26 | model_generated_outputs_path: "<path to the flattened outputs to evaluate>"
27 | output_report_file_path: "<path where the evaluation report should be written>"
28 | ```
29 |
30 | - Please see the [pie repo](https://github.com/madaan/pie-perf/blob/main/README.md#evaluating-your-method) for more details. Note that we use generated test cases, which are also available from the same repo.
31 |
32 |
33 | 2. **Step 2** (run pie eval)
34 |
35 | Using the yaml file generated in the above step, follow the [evaluating your method](https://github.com/madaan/pie-perf/blob/main/README.md#evaluating-your-method) instructions to evaluate the outputs. If you run Self-Refine for 4 timesteps, you would create 4 yaml files and run this evaluation four times, once per timestep; a sketch of generating these per-step configs is shown below. See `data/tasks/pie/gpt4_outputs.zip` for the 4 yaml files and the reports from these steps.
36 |
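One way to generate the per-step configs (a sketch; the paths, the step count, and the report naming are placeholders to adapt):

```python
import yaml  # pip install pyyaml

# Fields shared across all steps; fill in the remaining keys from the
# yaml example in Step 1.
base = {
    "reference_file_path": "data/tasks/pie/codenet-python-test-1k.jsonl",
    "model_generated_outputs_path": "data/tasks/pie/gpt4_outputs_flattened.jsonl",
    "slow_code_col": "input",
    "reference_code_col": "target",
}

for i in range(4):  # one config per self-refine timestep
    cfg = dict(base,
               model_generated_potentially_faster_code_col=f"attempt_{i}_code",
               output_report_file_path=f"reports/attempt_{i}.report")
    with open(f"pie_eval_attempt_{i}.yaml", "w") as f:
        yaml.safe_dump(cfg, f)
```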
37 | 3. **Step 3** (parse reports and aggregate results) After the evaluation, the report is saved to `output_report_file_path`. Then, you can use `src/pie/pie_eval.py` to aggregate the results; the sketch below shows the kind of aggregation involved.
38 |
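For illustration, an aggregation along these lines (column names such as `input_time_mean` and `generated_answers_time_mean` are assumptions about the report schema; `src/pie/pie_eval.py` is the authoritative parser):

```python
import pandas as pd

# Read one per-step report and compute an aggregate speedup.
report = pd.read_json("reports/attempt_0.report", orient="records", lines=True)
speedup = report["input_time_mean"] / report["generated_answers_time_mean"]
print(f"mean speedup: {speedup.mean():.2f}x over {len(report)} problems")
```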
39 | ### Sample outputs
40 |
41 | - Sample yaml files for each of the 4 steps and the corresponding outputs are located in `data/tasks/pie/gpt4_outputs.zip`.
42 |
--------------------------------------------------------------------------------
/docs/static/animation_static_end.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/madaan/self-refine/9a206d41e5d2d0c241bb441f41eeadb945afaa55/docs/static/animation_static_end.gif
--------------------------------------------------------------------------------
/docs/static/css/bulma-carousel.min.css:
--------------------------------------------------------------------------------
1 | @-webkit-keyframes spinAround{from{-webkit-transform:rotate(0);transform:rotate(0)}to{-webkit-transform:rotate(359deg);transform:rotate(359deg)}}@keyframes spinAround{from{-webkit-transform:rotate(0);transform:rotate(0)}to{-webkit-transform:rotate(359deg);transform:rotate(359deg)}}.slider{position:relative;width:100%}.slider-container{display:flex;flex-wrap:nowrap;flex-direction:row;overflow:hidden;-webkit-transform:translate3d(0,0,0);transform:translate3d(0,0,0);min-height:100%}.slider-container.is-vertical{flex-direction:column}.slider-container .slider-item{flex:none}.slider-container .slider-item .image.is-covered img{-o-object-fit:cover;object-fit:cover;-o-object-position:center center;object-position:center center;height:100%;width:100%}.slider-container .slider-item .video-container{height:0;padding-bottom:0;padding-top:56.25%;margin:0;position:relative}.slider-container .slider-item .video-container.is-1by1,.slider-container .slider-item .video-container.is-square{padding-top:100%}.slider-container .slider-item .video-container.is-4by3{padding-top:75%}.slider-container .slider-item .video-container.is-21by9{padding-top:42.857143%}.slider-container .slider-item .video-container embed,.slider-container .slider-item .video-container iframe,.slider-container .slider-item .video-container object{position:absolute;top:0;left:0;width:100%!important;height:100%!important}.slider-navigation-next,.slider-navigation-previous{display:flex;justify-content:center;align-items:center;position:absolute;width:42px;height:42px;background:#fff center center no-repeat;background-size:20px 20px;border:1px solid #fff;border-radius:25091983px;box-shadow:0 2px 5px #3232321a;top:50%;margin-top:-20px;left:0;cursor:pointer;transition:opacity .3s,-webkit-transform .3s;transition:transform .3s,opacity .3s;transition:transform .3s,opacity .3s,-webkit-transform .3s}.slider-navigation-next:hover,.slider-navigation-previous:hover{-webkit-transform:scale(1.2);transform:scale(1.2)}.slider-navigation-next.is-hidden,.slider-navigation-previous.is-hidden{display:none;opacity:0}.slider-navigation-next svg,.slider-navigation-previous svg{width:25%}.slider-navigation-next{left:auto;right:0;background:#fff center center no-repeat;background-size:20px 20px}.slider-pagination{display:none;justify-content:center;align-items:center;position:absolute;bottom:0;left:0;right:0;padding:.5rem 1rem;text-align:center}.slider-pagination .slider-page{background:#fff;width:10px;height:10px;border-radius:25091983px;display:inline-block;margin:0 3px;box-shadow:0 2px 5px #3232321a;transition:-webkit-transform .3s;transition:transform .3s;transition:transform .3s,-webkit-transform .3s;cursor:pointer}.slider-pagination .slider-page.is-active,.slider-pagination .slider-page:hover{-webkit-transform:scale(1.4);transform:scale(1.4)}@media screen and (min-width:800px){.slider-pagination{display:flex}}.hero.has-carousel{position:relative}.hero.has-carousel+.hero-body,.hero.has-carousel+.hero-footer,.hero.has-carousel+.hero-head{z-index:10;overflow:hidden}.hero.has-carousel .hero-carousel{position:absolute;top:0;left:0;bottom:0;right:0;height:auto;border:none;margin:auto;padding:0;z-index:0}.hero.has-carousel .hero-carousel .slider{width:100%;max-width:100%;overflow:hidden;height:100%!important;max-height:100%;z-index:0}.hero.has-carousel .hero-carousel .slider .has-background{max-height:100%}.hero.has-carousel .hero-carousel .slider .has-background .is-background{-o-object-fit:cover;object-fit:cover;-o-object-position:center 
center;object-position:center center;height:100%;width:100%}.hero.has-carousel .hero-body{margin:0 3rem;z-index:10}
--------------------------------------------------------------------------------
/docs/static/css/collapsible.css:
--------------------------------------------------------------------------------
1 | /* https://www.w3schools.com/howto/howto_js_collapsible.asp */
2 | /* Style the button that is used to open and close the collapsible content */
3 | .collapsible {
4 | background-color: #eee;
5 | color: #444;
6 | cursor: pointer;
7 | padding: 18px;
8 | width: 100%;
9 | border: none;
10 | text-align: left;
11 | outline: none;
12 | font-size: 15px;
13 | }
14 |
15 | /* Add a background color to the button if it is clicked on (add the .active class with JS), and when you move the mouse over it (hover) */
16 | .active, .collapsible:hover {
17 | background-color: #ccc;
18 | }
19 |
20 | /* Style the collapsible content. Note: hidden by default */
21 | .content {
22 | padding: 0 18px;
23 | display: none;
24 | overflow: hidden;
25 | background-color: #f1f1f1;
26 | }
--------------------------------------------------------------------------------
/docs/static/css/index.css:
--------------------------------------------------------------------------------
1 | body {
2 | font-family: 'Noto Sans', sans-serif;
3 | }
4 |
5 |
6 | .footer .icon-link {
7 | font-size: 25px;
8 | color: #000;
9 | }
10 |
11 | .link-block a {
12 | margin-top: 5px;
13 | margin-bottom: 5px;
14 | }
15 |
16 | .pal {
17 | font-variant: small-caps;
18 | }
19 | .box {
20 | background-color: #f8f8f8;
21 | border: 1px solid #ddd;
22 | border-radius: 5px;
23 | padding: 1em;
24 | margin-bottom: 1em;
25 |   overflow: auto; /* enable scrolling for overflowing content */
26 | }
27 |
28 | .teaser .hero-body {
29 | padding-top: 0;
30 | padding-bottom: 3rem;
31 | }
32 |
33 | .teaser {
34 | font-family: 'Google Sans', sans-serif;
35 | }
36 |
37 |
38 | .publication-title {
39 | }
40 |
41 | .publication-banner {
42 | max-height: parent;
43 |
44 | }
45 |
46 | .publication-banner video {
47 | position: relative;
48 | left: auto;
49 | top: auto;
50 | transform: none;
51 |   object-fit: fill;
52 | }
53 |
54 | .publication-header .hero-body {
55 | }
56 |
57 | .publication-title {
58 | font-family: 'Google Sans', sans-serif;
59 | }
60 |
61 | .publication-authors {
62 | font-family: 'Google Sans', sans-serif;
63 | }
64 |
65 | .publication-venue {
66 | color: #555;
67 | width: fit-content;
68 | font-weight: bold;
69 | }
70 |
71 | .publication-awards {
72 | color: #ff3860;
73 | width: fit-content;
74 | font-weight: bolder;
75 | }
76 |
77 | .publication-authors {
78 | }
79 |
80 | .publication-authors a {
81 | color: hsl(204, 86%, 53%) !important;
82 | }
83 |
84 | .publication-authors a:hover {
85 | text-decoration: underline;
86 | }
87 |
88 | .author-block {
89 | display: inline-block;
90 | }
91 |
92 | .publication-banner img {
93 | }
94 |
95 | .publication-authors {
96 | /*color: #4286f4;*/
97 | }
98 |
99 | .publication-video {
100 | position: relative;
101 | width: 100%;
102 | height: 0;
103 | padding-bottom: 56.25%;
104 |
105 | overflow: hidden;
106 | border-radius: 10px !important;
107 | }
108 |
109 | .publication-video iframe {
110 | position: absolute;
111 | top: 0;
112 | left: 0;
113 | width: 100%;
114 | height: 100%;
115 | }
116 |
117 | .publication-body img {
118 | }
119 |
120 | .results-carousel {
121 | overflow: hidden;
122 | }
123 |
124 | .results-carousel .item {
125 | margin: 5px;
126 | overflow: hidden;
127 | border: 1px solid #bbb;
128 | border-radius: 10px;
129 | padding: 0;
130 | font-size: 0;
131 | }
132 |
133 | .results-carousel video {
134 | margin: 0;
135 | }
136 |
137 |
138 | .interpolation-panel {
139 | background: #f5f5f5;
140 | border-radius: 10px;
141 | }
142 |
143 | .interpolation-panel .interpolation-image {
144 | width: 100%;
145 | border-radius: 5px;
146 | }
147 |
148 | .interpolation-video-column {
149 | }
150 |
151 | .interpolation-panel .slider {
152 | margin: 0 !important;
153 | }
154 |
155 | .interpolation-panel .slider {
156 | margin: 0 !important;
157 | }
158 |
159 | #interpolation-image-wrapper {
160 | width: 100%;
161 | }
162 | #interpolation-image-wrapper img {
163 | border-radius: 5px;
164 | }
165 |
166 | .tweets-container {
167 |
168 | top: 0;
169 | right: 0;
170 | width: 50%;
171 | height: 400px;
172 | overflow-y: scroll;
173 | border-left: 1px solid #ccc;
174 | padding: 0px;
175 |
176 | }
177 |
178 | .tweet {
179 | margin-bottom: 50px;
180 | }
181 |
182 | .centerdiv {
183 | display: grid;
184 | place-items: center;
185 | }
186 |
187 |
188 |
189 |
190 | /* CSS for the examples */
191 |
192 | .message {
193 | background-color: white;
194 | box-shadow: 0 4px 6px rgba(0, 0, 0, 0.1), 0 1px 3px rgba(0, 0, 0, 0.08);
195 | margin-bottom: 1rem;
196 | border-radius: 4px;
197 | }
198 |
199 | .message-header {
200 | background-color: #f5f5f5;
201 | border-top-left-radius: 4px;
202 | border-top-right-radius: 4px;
203 | }
204 |
205 |
206 | .message-header {
207 | background-color: #f5f5f5;
208 | border-top-left-radius: 4px;
209 | border-top-right-radius: 4px;
210 | color: #333333; /* or any other darker shade you prefer */
211 | }
212 |
213 |
214 | .has-text-left pre {
215 | background-color: #f5f5f5;
216 | padding: 1rem;
217 | border-radius: 4px;
218 | }
219 |
220 | .title.is-5 {
221 | margin-top: 1rem;
222 | margin-bottom: 0.5rem;
223 | font-weight: 600; /* Increased font weight for better visibility */
224 | }
225 |
226 | #acronym_content {
227 | padding: 1.5rem;
228 | }
229 |
230 | .block:not(:last-child) {
231 | margin-bottom: 1.5rem;
232 | }
233 |
234 |
235 | code.python {
236 | display: block;
237 | background-color: #f4f4f4;
238 | border: 1px solid #ddd;
239 | border-radius: 5px;
240 | font-size: 14px;
241 | line-height: 1.5;
242 | overflow-x: auto;
243 | padding: 20px;
244 | }
245 |
246 | code.python {
247 | color: #008000;
248 | }
249 |
--------------------------------------------------------------------------------
/docs/static/css/prism.css:
--------------------------------------------------------------------------------
1 | /* PrismJS 1.29.0
2 | https://prismjs.com/download.html#themes=prism-tomorrow&languages=css+python */
3 | code[class*=language-],pre[class*=language-]{color:#ccc;background:0 0;font-family:Consolas,Monaco,'Andale Mono','Ubuntu Mono',monospace;font-size:1em;text-align:left;white-space:pre;word-spacing:normal;word-break:normal;word-wrap:normal;line-height:1.5;-moz-tab-size:4;-o-tab-size:4;tab-size:4;-webkit-hyphens:none;-moz-hyphens:none;-ms-hyphens:none;hyphens:none}pre[class*=language-]{padding:1em;margin:.5em 0;overflow:auto}:not(pre)>code[class*=language-],pre[class*=language-]{background:#2d2d2d}:not(pre)>code[class*=language-]{padding:.1em;border-radius:.3em;white-space:normal}.token.block-comment,.token.cdata,.token.comment,.token.doctype,.token.prolog{color:#999}.token.punctuation{color:#ccc}.token.attr-name,.token.deleted,.token.namespace,.token.tag{color:#e2777a}.token.function-name{color:#6196cc}.token.boolean,.token.function,.token.number{color:#f08d49}.token.class-name,.token.constant,.token.property,.token.symbol{color:#f8c555}.token.atrule,.token.builtin,.token.important,.token.keyword,.token.selector{color:#cc99cd}.token.attr-value,.token.char,.token.regex,.token.string,.token.variable{color:#7ec699}.token.entity,.token.operator,.token.url{color:#67cdcc}.token.bold,.token.important{font-weight:700}.token.italic{font-style:italic}.token.entity{cursor:help}.token.inserted{color:green}
--------------------------------------------------------------------------------
/docs/static/images/animation_oldstyle_oneloop.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/madaan/self-refine/9a206d41e5d2d0c241bb441f41eeadb945afaa55/docs/static/images/animation_oldstyle_oneloop.gif
--------------------------------------------------------------------------------
/docs/static/images/autofb_animation.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/madaan/self-refine/9a206d41e5d2d0c241bb441f41eeadb945afaa55/docs/static/images/autofb_animation.gif
--------------------------------------------------------------------------------
/docs/static/images/autofb_animation_static.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/madaan/self-refine/9a206d41e5d2d0c241bb441f41eeadb945afaa55/docs/static/images/autofb_animation_static.gif
--------------------------------------------------------------------------------
/docs/static/images/autofb_fig1.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/madaan/self-refine/9a206d41e5d2d0c241bb441f41eeadb945afaa55/docs/static/images/autofb_fig1.pdf
--------------------------------------------------------------------------------
/docs/static/images/autofb_fig1.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/madaan/self-refine/9a206d41e5d2d0c241bb441f41eeadb945afaa55/docs/static/images/autofb_fig1.png
--------------------------------------------------------------------------------
/docs/static/images/autofb_static.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/madaan/self-refine/9a206d41e5d2d0c241bb441f41eeadb945afaa55/docs/static/images/autofb_static.png
--------------------------------------------------------------------------------
/docs/static/images/example1.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/madaan/self-refine/9a206d41e5d2d0c241bb441f41eeadb945afaa55/docs/static/images/example1.png
--------------------------------------------------------------------------------
/docs/static/images/example1desc.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/madaan/self-refine/9a206d41e5d2d0c241bb441f41eeadb945afaa55/docs/static/images/example1desc.png
--------------------------------------------------------------------------------
/docs/static/images/favicon.svg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/madaan/self-refine/9a206d41e5d2d0c241bb441f41eeadb945afaa55/docs/static/images/favicon.svg
--------------------------------------------------------------------------------
/docs/static/images/fig2.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/madaan/self-refine/9a206d41e5d2d0c241bb441f41eeadb945afaa55/docs/static/images/fig2.png
--------------------------------------------------------------------------------
/docs/static/images/mainresults.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/madaan/self-refine/9a206d41e5d2d0c241bb441f41eeadb945afaa55/docs/static/images/mainresults.png
--------------------------------------------------------------------------------
/docs/static/images/pal_overview.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/madaan/self-refine/9a206d41e5d2d0c241bb441f41eeadb945afaa55/docs/static/images/pal_overview.png
--------------------------------------------------------------------------------
/docs/static/images/refinement.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/madaan/self-refine/9a206d41e5d2d0c241bb441f41eeadb945afaa55/docs/static/images/refinement.png
--------------------------------------------------------------------------------
/docs/static/images/tasks.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/madaan/self-refine/9a206d41e5d2d0c241bb441f41eeadb945afaa55/docs/static/images/tasks.png
--------------------------------------------------------------------------------
/docs/static/js/bulma-slider.min.js:
--------------------------------------------------------------------------------
1 | !function(t,e){"object"==typeof exports&&"object"==typeof module?module.exports=e():"function"==typeof define&&define.amd?define([],e):"object"==typeof exports?exports.bulmaSlider=e():t.bulmaSlider=e()}("undefined"!=typeof self?self:this,function(){return function(n){var r={};function i(t){if(r[t])return r[t].exports;var e=r[t]={i:t,l:!1,exports:{}};return n[t].call(e.exports,e,e.exports,i),e.l=!0,e.exports}return i.m=n,i.c=r,i.d=function(t,e,n){i.o(t,e)||Object.defineProperty(t,e,{configurable:!1,enumerable:!0,get:n})},i.n=function(t){var e=t&&t.__esModule?function(){return t.default}:function(){return t};return i.d(e,"a",e),e},i.o=function(t,e){return Object.prototype.hasOwnProperty.call(t,e)},i.p="",i(i.s=0)}([function(t,e,n){"use strict";Object.defineProperty(e,"__esModule",{value:!0}),n.d(e,"isString",function(){return l});var r=n(1),i=Object.assign||function(t){for(var e=1;e=l.length&&(s=!0)):s=!0),s&&(t.once&&(u[e]=null),t.callback(r))});-1!==u.indexOf(null);)u.splice(u.indexOf(null),1)}}]),e}();e.a=i}]).default});
--------------------------------------------------------------------------------
/docs/static/js/collapsible.js:
--------------------------------------------------------------------------------
1 | var coll = document.getElementsByClassName("collapsible");
2 | var i;
3 |
4 | for (i = 0; i < coll.length; i++) {
5 | coll[i].addEventListener("click", function() {
6 | this.classList.toggle("active");
7 | var content = this.nextElementSibling;
8 | if (content.style.display === "block") {
9 | content.style.display = "none";
10 | } else {
11 | content.style.display = "block";
12 | }
13 | });
14 | }
15 |
--------------------------------------------------------------------------------
/docs/static/js/index.js:
--------------------------------------------------------------------------------
1 | window.HELP_IMPROVE_VIDEOJS = false;
2 |
3 |
4 | $(document).ready(function() {
5 | // Check for click events on the navbar burger icon
6 | $(".navbar-burger").click(function() {
7 | // Toggle the "is-active" class on both the "navbar-burger" and the "navbar-menu"
8 | $(".navbar-burger").toggleClass("is-active");
9 | $(".navbar-menu").toggleClass("is-active");
10 |
11 | });
12 |
13 | var current_cmd_idxs = {
14 | "gsm8k": 1,
15 | "gsm8khard": 1,
16 | "coloredobjects":1,
17 | "repeatcopy":1,
18 | "dateunderstanding":1
19 | }
20 |
21 | // examples
22 | $('select').on('click', function() {
23 |
24 | var sep_idx = this.value.indexOf('_');
25 | var domain_name = this.value.substring(0, sep_idx);
26 | var desired_cmd_idx = parseInt(this.value.substring(sep_idx + 1));
27 | var current_cmd_idx = current_cmd_idxs[domain_name];
28 |
29 | // hide current content
30 | var current_content = $('#content_' + domain_name + "_" + current_cmd_idx.toString());
31 |
32 | if (desired_cmd_idx == current_cmd_idx && current_content.is(":visible")) {
33 | current_content.hide();
34 | return;
35 | }
36 | current_content.hide();
37 |
38 | // show desired content
39 | var desired_content = $('#content_' + domain_name + "_" + desired_cmd_idx.toString());
40 | desired_content.show("slow");
41 |
42 | // set current to desired
43 | current_cmd_idxs[domain_name] = desired_cmd_idx;
44 | });
45 |
46 |
47 |
48 | // general function for xyzheader
49 | function toggle_options(header_id, options_id) {
50 | if ($(options_id).is(":visible")) {
51 | $(options_id).hide();
52 | // extract task name from header. e.g., #gsm8k_header -> gsm8k
53 |       var task_name = header_id.split("_")[0].substring(1);
54 |
55 | console.log("You have selected " + task_name + " as your task.");
56 | for (var i = 0; i <= 100; i++) {
57 |
58 | var content_id = "#content_" + task_name + "_" + i.toString();
59 | console.log(content_id);
60 | // check if content exists
61 | if ($(content_id).length == 0) {
62 | break;
63 | }
64 | $(content_id).hide();
65 | }
66 | $(header_id).removeClass("is-active");
67 | } else {
68 | $(options_id).show("slow");
69 | $(header_id).addClass("is-active");
70 | }
71 | }
72 |
73 | $('#acronym_button').click(function() {
74 | toggle_options('#acronym_header', '#acronym_content');
75 | });
76 | $('#responsegen').click(function() {
77 | toggle_options('#responsegen_header', '#responsegen_content');
78 | });
79 | $('#commongen_button').click(function() {
80 | toggle_options('#commongen_header', '#commongen_content');
81 | }
82 | );
83 | $('#gsm_button').click(function() {
84 | toggle_options('#gsm_header', '#gsm_content');
85 | }
86 | );
87 | $('#codeoptimization_button').click(function() {
88 | toggle_options('#codeoptimization_header', '#codeoptimization_content');
89 | }
90 | );
91 | $('#sentiment_button').click(function() {
92 |     toggle_options('#sentiment_header', '#sentiment_content');
93 | }
94 | );
95 |
96 | $('#readibility_button').click(function() {
97 | toggle_options('#readibility_header', '#readibility_content');
98 | }
99 | );
100 |
101 | $('#gsm8khard_options').hide();
102 | $('#coloredobjects_options').hide();
103 | $('#repeatcopy_options').hide();
104 | $('#dateunderstanding_options').hide();
105 |
106 |
107 | })
108 |
109 |
110 |
111 |
112 |
--------------------------------------------------------------------------------
/docs/static/js/results.js:
--------------------------------------------------------------------------------
1 | // generates the results histogram using d3.js
2 | var results = function() {
3 | var margin = {top: 20, right: 20, bottom: 30, left: 40},
4 | width = 960 - margin.left - margin.right,
5 | height = 500 - margin.top - margin.bottom;
6 |
7 | var x = d3.scale.ordinal()
8 | .rangeRoundBands([0, width], .1);
9 |
10 | var y = d3.scale.linear()
11 | .range([height, 0]);
12 |
13 | var xAxis = d3.svg.axis()
14 | .scale(x)
15 | .orient("bottom");
16 |
17 | var yAxis = d3.svg.axis()
18 | .scale(y)
19 | .orient("left")
20 | .ticks(10, "%");
21 |
22 | var svg = d3.select("#results").append("svg")
23 | .attr("width", width + margin.left + margin.right)
24 | .attr("height", height + margin.top + margin.bottom)
25 | .append("g")
26 | .attr("transform", "translate(" + margin.left + "," + margin.top + ")");
27 |
28 | d3.json("/results", function(error, data) {
29 | x.domain(data.map(function(d) { return d.name; }));
30 | y.domain([0, d3.max(data, function(d) { return d.count; })]);
31 |
32 | svg.append("g")
33 | .attr("class", "x axis")
34 | .attr("transform", "translate(0," + height + ")")
35 | .call(xAxis);
36 |
37 | svg.append("g")
38 | .attr("class", "y axis")
39 | .call(yAxis)
40 | .append("text")
41 | .attr("transform", "rotate(-90)")
42 | .attr("y", 6)
43 | .attr("dy", ".71em")
44 | .style("text-anchor", "end")
45 | .text("Frequency");
46 |
47 | svg.selectAll(".bar")
48 | .data(data)
49 | .enter().append("rect")
50 | .attr("class", "bar")
51 | .attr("x", function(d) { return x(d.name); })
52 | .attr("width", x.rangeBand())
53 | .attr("y", function(d) { return y(d.count); })
54 | .attr("height", function(d) { return height - y(d.count); });
55 | });
56 | }
57 |
--------------------------------------------------------------------------------
/docs/view.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/madaan/self-refine/9a206d41e5d2d0c241bb441f41eeadb945afaa55/docs/view.pdf
--------------------------------------------------------------------------------
/docs/visual_self_refine_examples/alien.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/madaan/self-refine/9a206d41e5d2d0c241bb441f41eeadb945afaa55/docs/visual_self_refine_examples/alien.gif
--------------------------------------------------------------------------------
/docs/visual_self_refine_examples/artificial_general_intell.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/madaan/self-refine/9a206d41e5d2d0c241bb441f41eeadb945afaa55/docs/visual_self_refine_examples/artificial_general_intell.gif
--------------------------------------------------------------------------------
/docs/visual_self_refine_examples/bird.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/madaan/self-refine/9a206d41e5d2d0c241bb441f41eeadb945afaa55/docs/visual_self_refine_examples/bird.gif
--------------------------------------------------------------------------------
/docs/visual_self_refine_examples/bird_v2.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/madaan/self-refine/9a206d41e5d2d0c241bb441f41eeadb945afaa55/docs/visual_self_refine_examples/bird_v2.gif
--------------------------------------------------------------------------------
/docs/visual_self_refine_examples/c___program.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/madaan/self-refine/9a206d41e5d2d0c241bb441f41eeadb945afaa55/docs/visual_self_refine_examples/c___program.gif
--------------------------------------------------------------------------------
/docs/visual_self_refine_examples/cat.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/madaan/self-refine/9a206d41e5d2d0c241bb441f41eeadb945afaa55/docs/visual_self_refine_examples/cat.gif
--------------------------------------------------------------------------------
/docs/visual_self_refine_examples/cell.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/madaan/self-refine/9a206d41e5d2d0c241bb441f41eeadb945afaa55/docs/visual_self_refine_examples/cell.gif
--------------------------------------------------------------------------------
/docs/visual_self_refine_examples/cell_v2.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/madaan/self-refine/9a206d41e5d2d0c241bb441f41eeadb945afaa55/docs/visual_self_refine_examples/cell_v2.gif
--------------------------------------------------------------------------------
/docs/visual_self_refine_examples/dragon.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/madaan/self-refine/9a206d41e5d2d0c241bb441f41eeadb945afaa55/docs/visual_self_refine_examples/dragon.gif
--------------------------------------------------------------------------------
/docs/visual_self_refine_examples/dragon_v2.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/madaan/self-refine/9a206d41e5d2d0c241bb441f41eeadb945afaa55/docs/visual_self_refine_examples/dragon_v2.gif
--------------------------------------------------------------------------------
/docs/visual_self_refine_examples/drums.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/madaan/self-refine/9a206d41e5d2d0c241bb441f41eeadb945afaa55/docs/visual_self_refine_examples/drums.gif
--------------------------------------------------------------------------------
/docs/visual_self_refine_examples/duck.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/madaan/self-refine/9a206d41e5d2d0c241bb441f41eeadb945afaa55/docs/visual_self_refine_examples/duck.gif
--------------------------------------------------------------------------------
/docs/visual_self_refine_examples/ellipsoid.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/madaan/self-refine/9a206d41e5d2d0c241bb441f41eeadb945afaa55/docs/visual_self_refine_examples/ellipsoid.gif
--------------------------------------------------------------------------------
/docs/visual_self_refine_examples/feedback_driven_self_improvement.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/madaan/self-refine/9a206d41e5d2d0c241bb441f41eeadb945afaa55/docs/visual_self_refine_examples/feedback_driven_self_improvement.gif
--------------------------------------------------------------------------------
/docs/visual_self_refine_examples/ferrari.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/madaan/self-refine/9a206d41e5d2d0c241bb441f41eeadb945afaa55/docs/visual_self_refine_examples/ferrari.gif
--------------------------------------------------------------------------------
/docs/visual_self_refine_examples/flower.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/madaan/self-refine/9a206d41e5d2d0c241bb441f41eeadb945afaa55/docs/visual_self_refine_examples/flower.gif
--------------------------------------------------------------------------------
/docs/visual_self_refine_examples/game_of_life.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/madaan/self-refine/9a206d41e5d2d0c241bb441f41eeadb945afaa55/docs/visual_self_refine_examples/game_of_life.gif
--------------------------------------------------------------------------------
/docs/visual_self_refine_examples/giphy.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/madaan/self-refine/9a206d41e5d2d0c241bb441f41eeadb945afaa55/docs/visual_self_refine_examples/giphy.gif
--------------------------------------------------------------------------------
/docs/visual_self_refine_examples/goat.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/madaan/self-refine/9a206d41e5d2d0c241bb441f41eeadb945afaa55/docs/visual_self_refine_examples/goat.gif
--------------------------------------------------------------------------------
/docs/visual_self_refine_examples/gradient_descent.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/madaan/self-refine/9a206d41e5d2d0c241bb441f41eeadb945afaa55/docs/visual_self_refine_examples/gradient_descent.gif
--------------------------------------------------------------------------------
/docs/visual_self_refine_examples/guitar.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/madaan/self-refine/9a206d41e5d2d0c241bb441f41eeadb945afaa55/docs/visual_self_refine_examples/guitar.gif
--------------------------------------------------------------------------------
/docs/visual_self_refine_examples/hidden_markov_model.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/madaan/self-refine/9a206d41e5d2d0c241bb441f41eeadb945afaa55/docs/visual_self_refine_examples/hidden_markov_model.gif
--------------------------------------------------------------------------------
/docs/visual_self_refine_examples/hmm.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/madaan/self-refine/9a206d41e5d2d0c241bb441f41eeadb945afaa55/docs/visual_self_refine_examples/hmm.gif
--------------------------------------------------------------------------------
/docs/visual_self_refine_examples/hyperboloid.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/madaan/self-refine/9a206d41e5d2d0c241bb441f41eeadb945afaa55/docs/visual_self_refine_examples/hyperboloid.gif
--------------------------------------------------------------------------------
/docs/visual_self_refine_examples/hyperplane.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/madaan/self-refine/9a206d41e5d2d0c241bb441f41eeadb945afaa55/docs/visual_self_refine_examples/hyperplane.gif
--------------------------------------------------------------------------------
/docs/visual_self_refine_examples/illustration_of_a_cell.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/madaan/self-refine/9a206d41e5d2d0c241bb441f41eeadb945afaa55/docs/visual_self_refine_examples/illustration_of_a_cell.gif
--------------------------------------------------------------------------------
/docs/visual_self_refine_examples/illustration_of_a_neural_.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/madaan/self-refine/9a206d41e5d2d0c241bb441f41eeadb945afaa55/docs/visual_self_refine_examples/illustration_of_a_neural_.gif
--------------------------------------------------------------------------------
/docs/visual_self_refine_examples/illustration_of_hmms.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/madaan/self-refine/9a206d41e5d2d0c241bb441f41eeadb945afaa55/docs/visual_self_refine_examples/illustration_of_hmms.gif
--------------------------------------------------------------------------------
/docs/visual_self_refine_examples/linux_penguin.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/madaan/self-refine/9a206d41e5d2d0c241bb441f41eeadb945afaa55/docs/visual_self_refine_examples/linux_penguin.gif
--------------------------------------------------------------------------------
/docs/visual_self_refine_examples/neuron.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/madaan/self-refine/9a206d41e5d2d0c241bb441f41eeadb945afaa55/docs/visual_self_refine_examples/neuron.gif
--------------------------------------------------------------------------------
/docs/visual_self_refine_examples/nice_unicorn.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/madaan/self-refine/9a206d41e5d2d0c241bb441f41eeadb945afaa55/docs/visual_self_refine_examples/nice_unicorn.gif
--------------------------------------------------------------------------------
/docs/visual_self_refine_examples/nine_point_circle.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/madaan/self-refine/9a206d41e5d2d0c241bb441f41eeadb945afaa55/docs/visual_self_refine_examples/nine_point_circle.gif
--------------------------------------------------------------------------------
/docs/visual_self_refine_examples/panda_v2.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/madaan/self-refine/9a206d41e5d2d0c241bb441f41eeadb945afaa55/docs/visual_self_refine_examples/panda_v2.gif
--------------------------------------------------------------------------------
/docs/visual_self_refine_examples/penguin.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/madaan/self-refine/9a206d41e5d2d0c241bb441f41eeadb945afaa55/docs/visual_self_refine_examples/penguin.gif
--------------------------------------------------------------------------------
/docs/visual_self_refine_examples/penguin_v2.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/madaan/self-refine/9a206d41e5d2d0c241bb441f41eeadb945afaa55/docs/visual_self_refine_examples/penguin_v2.gif
--------------------------------------------------------------------------------
/docs/visual_self_refine_examples/pentagon.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/madaan/self-refine/9a206d41e5d2d0c241bb441f41eeadb945afaa55/docs/visual_self_refine_examples/pentagon.gif
--------------------------------------------------------------------------------
/docs/visual_self_refine_examples/puppy.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/madaan/self-refine/9a206d41e5d2d0c241bb441f41eeadb945afaa55/docs/visual_self_refine_examples/puppy.gif
--------------------------------------------------------------------------------
/docs/visual_self_refine_examples/pyramid.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/madaan/self-refine/9a206d41e5d2d0c241bb441f41eeadb945afaa55/docs/visual_self_refine_examples/pyramid.gif
--------------------------------------------------------------------------------
/docs/visual_self_refine_examples/pythagorean_theorem.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/madaan/self-refine/9a206d41e5d2d0c241bb441f41eeadb945afaa55/docs/visual_self_refine_examples/pythagorean_theorem.gif
--------------------------------------------------------------------------------
/docs/visual_self_refine_examples/roast_turkey.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/madaan/self-refine/9a206d41e5d2d0c241bb441f41eeadb945afaa55/docs/visual_self_refine_examples/roast_turkey.gif
--------------------------------------------------------------------------------
/docs/visual_self_refine_examples/robot.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/madaan/self-refine/9a206d41e5d2d0c241bb441f41eeadb945afaa55/docs/visual_self_refine_examples/robot.gif
--------------------------------------------------------------------------------
/docs/visual_self_refine_examples/rocket.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/madaan/self-refine/9a206d41e5d2d0c241bb441f41eeadb945afaa55/docs/visual_self_refine_examples/rocket.gif
--------------------------------------------------------------------------------
/docs/visual_self_refine_examples/self_attention.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/madaan/self-refine/9a206d41e5d2d0c241bb441f41eeadb945afaa55/docs/visual_self_refine_examples/self_attention.gif
--------------------------------------------------------------------------------
/docs/visual_self_refine_examples/self_refine_is_a_novel_ap.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/madaan/self-refine/9a206d41e5d2d0c241bb441f41eeadb945afaa55/docs/visual_self_refine_examples/self_refine_is_a_novel_ap.gif
--------------------------------------------------------------------------------
/docs/visual_self_refine_examples/self_refinement_loop.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/madaan/self-refine/9a206d41e5d2d0c241bb441f41eeadb945afaa55/docs/visual_self_refine_examples/self_refinement_loop.gif
--------------------------------------------------------------------------------
/docs/visual_self_refine_examples/space_shuttle.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/madaan/self-refine/9a206d41e5d2d0c241bb441f41eeadb945afaa55/docs/visual_self_refine_examples/space_shuttle.gif
--------------------------------------------------------------------------------
/docs/visual_self_refine_examples/spaceship.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/madaan/self-refine/9a206d41e5d2d0c241bb441f41eeadb945afaa55/docs/visual_self_refine_examples/spaceship.gif
--------------------------------------------------------------------------------
/docs/visual_self_refine_examples/stokes__theorem.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/madaan/self-refine/9a206d41e5d2d0c241bb441f41eeadb945afaa55/docs/visual_self_refine_examples/stokes__theorem.gif
--------------------------------------------------------------------------------
/docs/visual_self_refine_examples/support_vector_machine.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/madaan/self-refine/9a206d41e5d2d0c241bb441f41eeadb945afaa55/docs/visual_self_refine_examples/support_vector_machine.gif
--------------------------------------------------------------------------------
/docs/visual_self_refine_examples/tcp_ip.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/madaan/self-refine/9a206d41e5d2d0c241bb441f41eeadb945afaa55/docs/visual_self_refine_examples/tcp_ip.gif
--------------------------------------------------------------------------------
/docs/visual_self_refine_examples/thanksgiving_turkey.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/madaan/self-refine/9a206d41e5d2d0c241bb441f41eeadb945afaa55/docs/visual_self_refine_examples/thanksgiving_turkey.gif
--------------------------------------------------------------------------------
/docs/visual_self_refine_examples/toroid.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/madaan/self-refine/9a206d41e5d2d0c241bb441f41eeadb945afaa55/docs/visual_self_refine_examples/toroid.gif
--------------------------------------------------------------------------------
/docs/visual_self_refine_examples/train.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/madaan/self-refine/9a206d41e5d2d0c241bb441f41eeadb945afaa55/docs/visual_self_refine_examples/train.gif
--------------------------------------------------------------------------------
/docs/visual_self_refine_examples/turkey.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/madaan/self-refine/9a206d41e5d2d0c241bb441f41eeadb945afaa55/docs/visual_self_refine_examples/turkey.gif
--------------------------------------------------------------------------------
/docs/visual_self_refine_examples/turkey_v1.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/madaan/self-refine/9a206d41e5d2d0c241bb441f41eeadb945afaa55/docs/visual_self_refine_examples/turkey_v1.gif
--------------------------------------------------------------------------------
/docs/visual_self_refine_examples/turtle.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/madaan/self-refine/9a206d41e5d2d0c241bb441f41eeadb945afaa55/docs/visual_self_refine_examples/turtle.gif
--------------------------------------------------------------------------------
/docs/visual_self_refine_examples/unicorn.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/madaan/self-refine/9a206d41e5d2d0c241bb441f41eeadb945afaa55/docs/visual_self_refine_examples/unicorn.gif
--------------------------------------------------------------------------------
/docs/visual_self_refine_examples/unicorn_v2.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/madaan/self-refine/9a206d41e5d2d0c241bb441f41eeadb945afaa55/docs/visual_self_refine_examples/unicorn_v2.gif
--------------------------------------------------------------------------------
/docs/visual_self_refine_examples/unicorn_v3.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/madaan/self-refine/9a206d41e5d2d0c241bb441f41eeadb945afaa55/docs/visual_self_refine_examples/unicorn_v3.gif
--------------------------------------------------------------------------------
/docs/visual_self_refine_examples/unicorn_v4.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/madaan/self-refine/9a206d41e5d2d0c241bb441f41eeadb945afaa55/docs/visual_self_refine_examples/unicorn_v4.gif
--------------------------------------------------------------------------------
/docs/visual_self_refine_examples/vae.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/madaan/self-refine/9a206d41e5d2d0c241bb441f41eeadb945afaa55/docs/visual_self_refine_examples/vae.gif
--------------------------------------------------------------------------------
/docs/visual_self_refine_examples/variational_inference.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/madaan/self-refine/9a206d41e5d2d0c241bb441f41eeadb945afaa55/docs/visual_self_refine_examples/variational_inference.gif
--------------------------------------------------------------------------------
/src/acronym/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/madaan/self-refine/9a206d41e5d2d0c241bb441f41eeadb945afaa55/src/acronym/__init__.py
--------------------------------------------------------------------------------
/src/acronym/feedback.py:
--------------------------------------------------------------------------------
1 | import pandas as pd
2 | from prompt_lib.backends import openai_api
3 |
4 | from src.utils import Prompt
5 |
6 |
7 | class AcronymGenFeedback(Prompt):
8 | def __init__(self, engine: str, prompt_examples: str, max_tokens: int = 300) -> None:
9 | super().__init__(
10 | question_prefix="",
11 | answer_prefix="",
12 | intra_example_sep="\n\n",
13 | inter_example_sep="\n\n###\n\n",
14 | )
15 | self.engine = engine
16 | self.max_tokens = max_tokens
17 | self.setup_prompt_from_examples_file(prompt_examples)
18 |
19 |     def setup_prompt_from_examples_file(self, examples_path: str) -> None:
20 | template = """Title: {title}
21 |
22 | Acronym: {answer}
23 |
24 | Scores:
25 |
26 | * Ease of pronunciation: {pronunciation_score}
27 | * Ease of spelling: {spelling_score}
28 | * Relation to title: {relation_score}
29 | * Positive connotation: {connotation_score}
30 | * Well-known: {well_known_score}
31 |
32 | * Total score: {total_score}"""
33 |
34 | examples_df = pd.read_json(examples_path, orient="records", lines=True)
35 | prompt = []
36 | for _, row in examples_df.iterrows():
37 | prompt.append(
38 | template.format(
39 | title=row["title"],
40 | answer=row["acronym"],
41 | pronunciation_score=row["pronunciation_score"],
42 | spelling_score=row["spelling_score"],
43 | relation_score=row["relation_score"],
44 | connotation_score=row["connotation_score"],
45 | well_known_score=row["well_known_score"],
46 | total_score=row["total_score"],
47 | )
48 | )
49 |
50 |         instruction = """We want to score each acronym on five qualities: i) ease of pronunciation, ii) ease of spelling, iii) relation to the title, iv) positive connotation, and v) being well-known.
51 |
52 | Here are some examples of this scoring rubric:
53 |
54 | """
55 |         self.prompt = instruction + self.inter_example_sep.join(prompt)
56 |         self.prompt = self.prompt + self.inter_example_sep
57 |
58 | def __call__(self, title: str, acronym: str):
59 | prompt = self.get_prompt_with_question(title=title, acronym=acronym)
60 |
61 | output = openai_api.OpenaiAPIWrapper.call(
62 | prompt=prompt,
63 | engine=self.engine,
64 | max_tokens=self.max_tokens,
65 | stop_token="###",
66 | temperature=0.7,
67 | )
68 |
69 | generated_feedback = openai_api.OpenaiAPIWrapper.get_first_response(output)
70 | generated_feedback = generated_feedback.split("Scores:")[1].strip()
71 | generated_feedback = generated_feedback.split("#")[0].strip()
72 | return generated_feedback
73 |
74 | def get_prompt_with_question(self, title: str, acronym: str):
75 | question = self.make_query(title=title, acronym=acronym)
76 | return f"""{self.prompt}{question}"""
77 |
78 | def make_query(self, title: str, acronym: str):
79 | question = f"""Title: {title}
80 |
81 | Acronym: {acronym}"""
82 | return question
83 |
84 |
85 |
86 | if __name__ == "__main__":
87 | feedback = AcronymGenFeedback(
88 |         engine="code-davinci-002",
89 | prompt_examples="data/prompt/acronym/feedback.jsonl",
90 | )
91 |
92 | print(feedback.prompt)
--------------------------------------------------------------------------------
/src/acronym/run.py:
--------------------------------------------------------------------------------
1 | import re
2 | import pandas as pd
3 |
4 |
5 | from src.acronym.task_init import AcronymGenTaskInit
6 | from src.acronym.task_iterate import AcronymGenTaskIterate
7 | from src.acronym.feedback import AcronymGenFeedback
8 | from src.utils import retry_parse_fail_prone_cmd
9 |
10 | CODEX = "code-davinci-002"
11 | GPT3 = "text-davinci-003"
12 | CHAT_GPT = "gpt-3.5-turbo"
13 | GPT4 = "gpt-4"
14 |
15 |
16 | ENGINE = CHAT_GPT
17 |
18 | @retry_parse_fail_prone_cmd
19 | def iterative_acronym(title: str, max_attempts: int) -> str:
20 |
21 | # initialize all the required components
22 |
23 | # generation of the first acronym
24 | task_init = AcronymGenTaskInit(engine=ENGINE, prompt_examples="data/prompt/acronym/init.jsonl")
25 |
26 | # getting feedback
27 | task_feedback = AcronymGenFeedback(engine=ENGINE, prompt_examples="data/prompt/acronym/feedback.jsonl")
28 |
29 | # iteratively improving the acronym
30 | task_iterate = AcronymGenTaskIterate(engine=ENGINE, prompt_examples="data/prompt/acronym/feedback.jsonl")
31 |
32 |
33 | # Initialize the task
34 |
35 | n_attempts = 0
36 |
37 | print(f"{n_attempts} INIT> {title}")
38 | acronyms_to_scores = dict()
39 |
40 | all_acronyms_to_scores = dict()
41 | best_score_so_far = 0
42 | while n_attempts < max_attempts:
43 |
44 | if n_attempts == 0:
45 | acronym = task_init(title=title)
46 | else:
47 | new_title, acronym = task_iterate(acronyms_to_scores=acronyms_to_scores)
48 | title = new_title
49 |
50 |
51 | scores = task_feedback(title=title, acronym=acronym)
52 | # extract expression "Total score: 22/25" from scores
53 | total_score = re.search(r"Total score: (\d+)/(\d+)", scores).group(0)
54 | total_score = int(total_score.split(":")[1].strip().split("/")[0])
55 |
56 | all_acronyms_to_scores[acronym] = {
57 | "scores": scores,
58 | "total_score": total_score,
59 | "title": title,
60 | }
61 | print(f"{n_attempts} GEN> {acronym} TITLE> {title}")
62 |
63 | print(f"{n_attempts} SCORES> {scores}")
64 |         if total_score >= best_score_so_far:  # only keep iterating on attempts that improve
65 | best_score_so_far = total_score
66 |
67 | acronyms_to_scores[acronym] = (title, scores)
68 |
69 |
70 | else:
71 | print(f"Score of {acronym} is {total_score}, which is less than the current best of {best_score_so_far}")
72 |
73 | n_attempts += 1
74 |
75 | return all_acronyms_to_scores
76 |
77 |
78 | def run_over_titles(titles_file: str, max_attempts: int, outfile: str):
79 |
80 | def _parse_results(title: str) -> str:
81 | try:
82 | results = iterative_acronym(title=title, max_attempts=max_attempts)
83 | if results is None:
84 | return "FAILED"
85 | res = []
86 | for acronym, scores in results.items():
87 | res.append(f"{acronym} [score: {scores['total_score']}] \n {scores['scores']}")
88 | return "\n ------ \n ".join(res)
89 |         except Exception:
90 | return "FAILED"
91 |
92 |
93 | data = pd.read_csv(titles_file, sep="\t")
94 | data['generated_acronym'] = data['title'].apply(_parse_results)
95 |
96 | data.to_json(outfile, orient="records", lines=True)
97 |
98 | if __name__ == "__main__":
99 | import sys
100 | title = sys.argv[1] # Light Amplification by Stimulated Emission of Radiation
101 | if len(sys.argv) > 2:
102 | run_over_titles(titles_file=sys.argv[1], max_attempts=int(sys.argv[2]), outfile=sys.argv[3])
103 | else:
104 | max_attempts = 5
105 | all_acronyms_to_scores = iterative_acronym(
106 | title=title,
107 | max_attempts=max_attempts,
108 | )
109 |
110 | res = []
111 | for acronym, scores in all_acronyms_to_scores.items():
112 | res.append(f"{acronym} [score: {scores['total_score']}] \n {scores['scores']}")
113 | print("\n ------ \n ".join(res))
114 |
115 |
--------------------------------------------------------------------------------
/src/acronym/run_mcts.py:
--------------------------------------------------------------------------------
1 | import re
2 | import math
3 |
4 |
5 | from src.acronym.task_init import AcronymGenTaskInit
6 | from src.acronym.task_iterate import AcronymGenTaskIterate
7 | from src.acronym.feedback import AcronymGenFeedback
8 | from src.utils import retry_parse_fail_prone_cmd
9 |
10 | CODEX = "code-davinci-002"
11 | GPT3 = "text-davinci-003"
12 | CHAT_GPT = "gpt-3.5-turbo"
13 | GPT4 = "gpt-4"
14 | ENGINE = CHAT_GPT
15 |
16 | task_init = AcronymGenTaskInit(engine=ENGINE, prompt_examples="data/prompt/acronym/init.jsonl")
17 |
18 | # getting feedback
19 | task_feedback = AcronymGenFeedback(
20 | engine=ENGINE, prompt_examples="data/prompt/acronym/feedback.jsonl"
21 | )
22 |
23 | # iteratively improving the acronym
24 | task_iterate = AcronymGenTaskIterate(
25 | engine=ENGINE, prompt_examples="data/prompt/acronym/feedback.jsonl"
26 | )
27 |
28 |
29 | class TreeNode:
30 | def __init__(self, title: str, acronym: str, scores: dict, parent=None):
31 | self.title = title
32 | self.acronym = acronym.strip()
33 | self.scores = scores
34 | self.children = []
35 | self.visits = 1
36 | self.value = 0.0
37 | self.parent = parent
38 |
39 | def __str__(self):
40 | children_str = ", ".join([str(child.acronym) for child in self.children])
41 | if children_str == "":
42 | children_str = "None"
43 |         res = f"TreeNode(title='{self.title}', acronym='{self.acronym}', scores={self.scores}, visits={self.visits}, value={self.value}, parent_acronym='{self.parent.acronym if self.parent else None}', children={children_str})"
44 | return res
45 |
46 |
47 | @retry_parse_fail_prone_cmd
48 | def generate_initial_children(node: TreeNode, task_iterate, task_feedback, num_children: int = 3):
49 | for _ in range(num_children):
50 | new_title, new_acronym = task_iterate(
51 | acronyms_to_scores={node.acronym: (node.title, node.scores)}
52 | )
53 | child = TreeNode(title=new_title, acronym=new_acronym, scores=None, parent=node)
54 | simulate(child, task_feedback) # Set scores for the child node
55 | node.children.append(child)
56 |
57 |
58 | def parse_scores(scores_output: str) -> dict:
59 | scores_pattern = re.compile(r"\* (.*?): (.*)\n?")
60 | scores = {}
61 |
62 | for score_match in scores_pattern.finditer(scores_output):
63 | score_title, score_value = score_match.groups()
64 |         score_value = int(re.search(r"\d+", score_value).group(0))  # keep the leading integer, e.g. "4/5" -> 4
65 | scores[score_title] = score_value
66 |
67 | return scores
68 |
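# Illustrative example (added note, not in the original): parse_scores turns the
# feedback text into a dict of integers, e.g.
#   parse_scores("* Ease of spelling: 4/5\n* Total score: 20/25\n")
#   -> {"Ease of spelling": 4, "Total score": 20}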
69 |
70 | def normalize_scores(scores: dict) -> dict:
71 | normalized_scores = {}
72 | for key, value in scores.items():
73 | if key != "Total score":
74 | normalized_scores[key] = value / 5
75 | return normalized_scores
76 |
77 |
78 | def weighted_sum(normalized_scores: dict, weights: dict) -> float:
79 | return sum(normalized_scores[key] * weights[key] for key in normalized_scores)
80 |
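# Illustrative (added note): with the weights defined at the bottom of this file
# (they sum to 1.0), a node scoring 4/5 on every trait gets
# weighted_sum(normalize_scores(scores), weights) == 0.8, since every
# normalized score is 4 / 5 = 0.8.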
81 |
82 | def select(node: TreeNode, weights: dict, C: float = 1.0):
83 | if not node.children:
84 | return node
85 |
86 | ucb1_values = [
87 | (weighted_sum(normalize_scores(child.scores), weights) / child.visits)
88 | + C * math.sqrt(math.log(node.visits) / child.visits)
89 | for child in node.children
90 | ]
91 |
92 | max_index = ucb1_values.index(max(ucb1_values))
93 | return select(node.children[max_index], weights, C)
94 |
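# Note (added for clarity): UCB1-style selection. The first term exploits (the
# child's weighted feedback score, damped by its visit count); the second term
# explores (C * sqrt(ln(parent visits) / child visits)). select() recurses
# until it reaches a leaf.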
95 |
96 | @retry_parse_fail_prone_cmd
97 | def expand(node: TreeNode, task_iterate, expanded_nodes_cache: set):
98 | acronyms_to_scores = {}
99 | current = node
100 | while current is not None:
101 | acronyms_to_scores[current.acronym] = (current.title, current.scores)
102 | current = current.parent
103 |
104 | new_title, new_acronym = task_iterate(acronyms_to_scores=acronyms_to_scores)
105 |
106 | while new_acronym in expanded_nodes_cache:
107 | new_title, new_acronym = task_iterate(acronyms_to_scores=acronyms_to_scores)
108 |
109 | expanded_nodes_cache.add(new_acronym)
110 | return TreeNode(title=new_title, acronym=new_acronym, scores=None, parent=node)
111 |
112 |
113 | @retry_parse_fail_prone_cmd
114 | def simulate(node: TreeNode, task_feedback):
115 | scores = task_feedback(title=node.title, acronym=node.acronym)
116 | node.scores = parse_scores(scores)
117 | normalized_total_score = node.scores["Total score"] / 25
118 | return normalized_total_score
119 |
120 |
121 | def backpropagate(node: TreeNode, value: float):
122 | node.visits += 1
123 | node.value += value
124 |
125 | if node.parent is not None:
126 | backpropagate(node.parent, value)
127 |
128 |
129 | def mcts_iteration(
130 | root: TreeNode, weights: dict, task_iterate, task_feedback, expanded_nodes_cache: set
131 | ):
132 |
133 | print("Selecting...")
134 | selected_node = select(root, weights)
135 | print(f" Selected node: {selected_node.acronym} for title '{selected_node.title}'")
136 |
137 | print("Expanding...")
138 | expanded_node = expand(selected_node, task_iterate, expanded_nodes_cache)
139 |
140 | selected_node.children.append(
141 | expanded_node
142 | ) # Add expanded node as a child of the selected node
143 | print(f" Expanded node: {expanded_node.acronym} for title '{expanded_node.title}'")
144 |
145 | print("Simulating...")
146 | value = simulate(expanded_node, task_feedback)
147 | print(f" Simulated value: {value}")
148 |
149 | print("Backpropagating...")
150 | backpropagate(expanded_node, value)
151 | print(" Backpropagation complete")
152 |
153 |
154 | root_title = "Iterative Refinement with Self-Feedback"
155 | root_acronym = task_init(title=root_title)
156 | root_scores = task_feedback(title=root_title, acronym=root_acronym)
157 | root_scores = parse_scores(root_scores)
158 | root = TreeNode(root_title, root_acronym, scores=root_scores)
159 |
160 | print(f"Root node: {root}")
161 |
162 | generate_initial_children(root, task_iterate, task_feedback)
163 |
164 | print(f"Root node after generating initial children: {root}")
165 |
166 | weights = {
167 | "Ease of pronunciation": 0.2,
168 | "Ease of spelling": 0.2,
169 | "Relation to title": 0.3,
170 | "Positive connotation": 0.2,
171 | "Well-known": 0.1,
172 | }
173 |
174 | iterations = 4
175 | expanded_nodes_cache = {root.acronym}
176 |
177 | for _ in range(iterations):
178 | mcts_iteration(root, weights, task_iterate, task_feedback, expanded_nodes_cache)
179 |
180 |
181 | def dfs(node: TreeNode, best_node: TreeNode):
182 | if not node.children:
183 | if node.scores["Total score"] > best_node.scores["Total score"]:
184 | return node
185 | else:
186 | return best_node
187 |
188 | for child in node.children:
189 | best_node = dfs(child, best_node)
190 |
191 | return best_node
192 |
193 |
194 | best_node = dfs(root, root)
195 | best_acronym, best_score = best_node.acronym, best_node.scores["Total score"]
196 |
197 | print(f"\nBest acronym: {best_acronym} with total score: {best_score}")
198 |
199 |
200 | def print_tree(node: TreeNode, indent: int = 0):
201 | indentation = " " * indent
202 | print(f"{indentation}{node.acronym} (Score: {node.scores['Total score']})")
203 |
204 | for child in node.children:
205 | print_tree(child, indent + 1)
206 |
207 |
208 | print_tree(root)
209 |
--------------------------------------------------------------------------------
/src/acronym/task_init.py:
--------------------------------------------------------------------------------
1 | import pandas as pd
2 | from src.utils import Prompt
3 |
4 | from prompt_lib.backends import openai_api
5 |
6 |
7 | class AcronymGenTaskInit(Prompt):
8 | def __init__(self, prompt_examples: str, engine: str) -> None:
9 | super().__init__(
10 | question_prefix="Title: ",
11 | answer_prefix="Acronym: ",
12 | intra_example_sep="\n\n",
13 | inter_example_sep="\n\n###\n\n",
14 | )
15 | self.engine = engine
16 | self.setup_prompt_from_examples_file(prompt_examples)
17 |
18 | def setup_prompt_from_examples_file(self, examples_path: str) -> str:
19 | TEMPLATE = """Title: {title}
20 |
21 | Acronym: {answer}"""
22 |
23 | examples_df = pd.read_json(examples_path, orient="records", lines=True)
24 | prompt = []
25 | for _, row in examples_df.iterrows():
26 | prompt.append(TEMPLATE.format(title=row["title"], answer=row["acronym"]))
27 | self.prompt = self.inter_example_sep.join(prompt)
28 | self.prompt = self.prompt + self.inter_example_sep
29 |
30 | def make_query(self, title: str) -> str:
31 | query = f"{self.prompt}{self.question_prefix}{title}{self.intra_example_sep}"
32 | return query
33 |
34 | def __call__(self, title: str) -> str:
35 | generation_query = self.make_query(title)
36 |
37 | output = openai_api.OpenaiAPIWrapper.call(
38 | prompt=generation_query,
39 | engine=self.engine,
40 | max_tokens=300,
41 | stop_token="###",
42 | temperature=0.7,
43 | )
44 |
45 | generated_acronym = openai_api.OpenaiAPIWrapper.get_first_response(output)
49 | generated_acronym = generated_acronym.split(self.answer_prefix)[1].replace("#", "").strip()
50 | return generated_acronym.strip()
51 |
52 |
53 | def test():
54 | task_init = AcronymGenTaskInit(engine="code-davinci-002", prompt_examples="data/prompt/acronym/init.jsonl")
55 | print(task_init.prompt)
56 |
57 |
58 | if __name__ == "__main__":
59 | test()
--------------------------------------------------------------------------------
/src/acronym/task_iterate.py:
--------------------------------------------------------------------------------
1 | import sys
2 | from typing import Dict, List
3 | from src.utils import Prompt
4 | import pandas as pd
5 |
6 | from prompt_lib.backends import openai_api
7 |
8 |
9 | class AcronymGenTaskIterate(Prompt):
10 | def __init__(self, engine: str, prompt_examples: str) -> None:
11 | super().__init__(
12 | question_prefix="",
13 | answer_prefix="",
14 | intra_example_sep="\n\n",
15 | inter_example_sep="\n\n###\n\n",
16 | )
17 | self.engine = engine
18 | self.count = 0
19 | self.prompt = self.make_prompt(prompt_examples=prompt_examples)
20 |
21 | def make_prompt(self, prompt_examples: str) -> str:
22 |
23 | prompt_examples = pd.read_json(prompt_examples, orient="records", lines=True)
24 | # group on example
25 | grouped = prompt_examples.groupby("example")
26 |
27 | prompt = []
28 | # sort each group by score
29 | for _, group in grouped:
30 | group["numerical_score"] = group["total_score"].apply(lambda x: int(x.split("/")[0].strip()))
31 | group = group.sort_values("numerical_score")
32 | prompt.append(self.make_one_iterate_example(group.to_dict("records")))
33 |
34 | return self.inter_example_sep.join(prompt) + self.inter_example_sep
35 |
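# Illustrative (added note): "total_score" values look like "18/25"; the lambda
# above keeps the numerator, so each prompt group is ordered from the weakest
# acronym to the strongest.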
36 |
37 | def make_one_iterate_example(self, incrementally_improving_examples: List[Dict]):
38 | """Given a list of examples that are incrementally improving, return a new example.
39 | """
40 |
41 |         instr = """We want to iteratively improve acronyms. To help improve, scores for each acronym on five desired traits are provided: i) ease of pronunciation, ii) ease of spelling, iii) relation to the title, iv) positive connotation, and v) well-known.
42 |
43 | """
44 | example_template = """Title: {title}
45 |
46 | Acronym: {acronym}
47 |
48 | Scores:
49 |
50 | * Ease of pronunciation: {pronunciation_score}
51 | * Ease of spelling: {spelling_score}
52 | * Relation to title: {relation_score}
53 | * Positive connotation: {connotation_score}
54 | * Well-known: {well_known_score}
55 |
56 | * Total score: {total_score}
57 |
58 | Okay, let's use this feedback to improve the acronym.
59 |
60 | """
61 | prompt = []
62 | for example in incrementally_improving_examples:
63 | prompt.append(example_template.format(**example))
64 |
65 | prompt = "".join(prompt)
66 | prompt = instr + prompt
67 | return prompt.strip()
68 |
69 | def make_query(self, question: str) -> str:
70 | return f"{self.prompt}{self.question_prefix}{question}{self.intra_example_sep}{self.answer_prefix}"
71 | # return super().make_query(prompt, question)
72 |
73 | def _make_input(
74 | self,
75 | title: str,
76 | acronym: str,
77 | scores: str,
78 | ) -> str:
79 | input_txt = f"""Title: {title}
80 |
81 | Acronym: {acronym}
82 |
83 | Scores:
84 |
85 | {scores}
86 |
87 | Okay, let's use this feedback to improve the acronym.
88 |
89 | """
90 |
91 | return input_txt
92 |
93 | def __call__(
94 | self,
95 | acronyms_to_scores: Dict[str, str],
96 | ) -> str:
97 | example_input = self.make_input(
98 | acronyms_to_scores=acronyms_to_scores,
99 | )
100 | transfer_query = self.make_query(example_input)
101 | self.count += 1
102 | with open(f"acronym_iterate_{self.count}.txt", "w") as f:
103 | f.write(transfer_query + "\n")
104 |
105 | output = openai_api.OpenaiAPIWrapper.call(
106 | prompt=transfer_query,
107 | engine=self.engine,
108 | max_tokens=300,
109 | stop_token=self.inter_example_sep,
110 | temperature=0.7,
111 | )
112 | response = openai_api.OpenaiAPIWrapper.get_first_response(output)
113 |
114 |
115 | acronym = response.split("Acronym:")[1].strip().split("\n")[0].strip()
116 |
117 | new_title = response.split("Title:")[1].strip().split("\n")[0].strip()
118 |
119 |
120 |
121 | return new_title, acronym.strip()
122 |
123 | def make_input(
124 | self,
125 | acronyms_to_scores: Dict[str, str],
126 | ) -> str:
127 | input_txt = ""
128 | for acronym, (title, scores) in acronyms_to_scores.items():
129 | input_txt += self._make_input(
130 | title=title,
131 | acronym=acronym,
132 | scores=scores,
133 | )
134 | return input_txt
135 |
136 |
137 |
138 |
139 |
140 |
141 | if __name__ == "__main__":
142 | obj = AcronymGenTaskIterate(prompt_examples="data/prompt/acronym/feedback.jsonl", engine="whatever")
143 | print(obj.prompt)
--------------------------------------------------------------------------------
/src/commongen/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/madaan/self-refine/9a206d41e5d2d0c241bb441f41eeadb945afaa55/src/commongen/__init__.py
--------------------------------------------------------------------------------
/src/commongen/eval.py:
--------------------------------------------------------------------------------
1 | import pandas as pd
2 | from typing import List, Dict
3 |
4 | def run(path: str):
5 |
6 | df = pd.read_json(path, lines=True, orient="records")
7 | df = df[df['status'] != "error"]
8 | print(f"Loaded {len(df)} rows")
9 | for i, row in df.iterrows():
10 | direct_output = row["sent_to_fb"][0]
11 | iter_output = row["sent_to_fb"][-1]
12 | df.loc[i, 'direct_concept_success'] = direct_output["concept_feedback"][0].lower() == "none"
13 | df.loc[i, 'direct_commonsense_success'] = direct_output["commonsense_feedback"].lower() == "none"
14 | df.loc[i, 'direct_success'] = direct_output["concept_feedback"][0].lower() == "none" and direct_output["commonsense_feedback"].lower() == "none"
15 | df.loc[i, 'iter_concept_success'] = iter_output["concept_feedback"][0].lower() == "none"
16 | df.loc[i, 'iter_commonsense_success'] = iter_output["commonsense_feedback"].lower() == "none"
17 | df.loc[i, 'iter_success'] = iter_output["concept_feedback"][0].lower() == "none" and iter_output["commonsense_feedback"].lower() == "none"
18 |
19 | # direct wins
20 |     num_direct_concept_wins = len(df[(df['direct_concept_success'] == True) & (df['iter_concept_success'] == False)])
21 |     num_direct_commonsense_wins = len(df[(df['direct_commonsense_success'] == True) & (df['iter_commonsense_success'] == False)])
22 |     num_iter_concept_wins = len(df[(df['direct_concept_success'] == False) & (df['iter_concept_success'] == True)])
23 |     num_iter_commonsense_wins = len(df[(df['direct_commonsense_success'] == False) & (df['iter_commonsense_success'] == True)])
24 |     num_direct_wins = len(df[(df['direct_success'] == True) & (df['iter_success'] == False)])
25 |     num_iter_wins = len(df[(df['direct_success'] == False) & (df['iter_success'] == True)])
26 | 
27 | 
28 |     num_commonsense_ties = len(df) - num_direct_commonsense_wins - num_iter_commonsense_wins
29 |     num_concept_ties = len(df) - num_direct_concept_wins - num_iter_concept_wins
30 | 
31 |     # normalize everything and print a nice report
32 | 
33 |     print(f"Direct concept wins: {num_direct_concept_wins / len(df):.2f}")
34 |     print(f"Direct commonsense wins: {num_direct_commonsense_wins / len(df):.2f}")
35 |     print(f"Direct overall wins: {num_direct_wins / len(df):.2f}")
36 |     print(f"Iter concept wins: {num_iter_concept_wins / len(df):.2f}")
37 |     print(f"Iter commonsense wins: {num_iter_commonsense_wins / len(df):.2f}")
38 |     print(f"Iter overall wins: {num_iter_wins / len(df):.2f}")
39 |     print(f"Concept ties: {num_concept_ties / len(df):.2f}")
40 |     print(f"Commonsense ties: {num_commonsense_ties / len(df):.2f}")
41 | if __name__ == '__main__':
42 | import argparse
43 | args = argparse.ArgumentParser()
44 | args.add_argument("path", type=str)
45 | args = args.parse_args()
46 |
47 | run(path=args.path)
48 |
49 |
--------------------------------------------------------------------------------
/src/commongen/feedback.py:
--------------------------------------------------------------------------------
1 | import re
2 | from typing import Set, List
3 | import pandas as pd
4 | from prompt_lib.backends import openai_api
5 | import nltk
6 | import spacy
7 |
8 | nlp = spacy.load("en_core_web_sm")
9 | from src.utils import Prompt
10 |
11 |
12 | class CommongenFeedback(Prompt):
13 | def __init__(self, engine: str, prompt_examples: str, max_tokens: int = 300) -> None:
14 | super().__init__(
15 | question_prefix="",
16 | answer_prefix="",
17 | intra_example_sep="\n\n",
18 | inter_example_sep="\n\n###\n\n",
19 | )
20 | self.engine = engine
21 | self.max_tokens = max_tokens
22 | self.setup_prompt_from_examples_file(prompt_examples)
23 |
24 | def setup_prompt_from_examples_file(self, examples_path: str) -> str:
25 | template = """Concepts: {concepts}
26 | Sentence: {sentence}
27 | what concepts from the concept list are missing from the sentence and does the sentence make sense?
28 |
29 | Concept Feedback: {feedback}
30 | Commonsense Feedback: {commonsense_feedback}"""
31 |
32 | examples_df = pd.read_json(examples_path, orient="records", lines=True)
33 | prompt = []
34 | for _, row in examples_df.iterrows():
35 | prompt.append(
36 | template.format(
37 | concepts=row["concepts"],
38 | sentence=row["sentence"],
39 | feedback=", ".join(row["concept_feedback"]),
40 | commonsense_feedback=row["commonsense_feedback"]
41 | )
42 | )
43 |
44 | instruction = """We want to create a sentence that contains all the specified concepts. Please provide feedback on the following sentences. The feedback indicates missing concepts."""
45 |         self.prompt = instruction + self.inter_example_sep.join(prompt)
46 |         self.prompt = self.prompt + self.inter_example_sep
47 |
48 | def __call__(self, sentence: str, concepts: List[str]):
49 | prompt = self.make_query(sentence=sentence, concepts=concepts)
50 |
51 | output = openai_api.OpenaiAPIWrapper.call(
52 | prompt=prompt,
53 | engine=self.engine,
54 | max_tokens=self.max_tokens,
55 | stop_token="###",
56 | temperature=0.7,
57 | )
58 |
59 | generated_feedback = openai_api.OpenaiAPIWrapper.get_first_response(output)
60 | commonsense_feedback = re.search(r"Commonsense Feedback: (.*)", generated_feedback).group(1)
61 | commonsense_feedback = self.fix_feedback(sentence=sentence, concepts=concepts, feedback=commonsense_feedback)
62 |
63 | concept_feedback = re.search(r"Concept Feedback: (.*)", generated_feedback).group(1)
64 | return concept_feedback, commonsense_feedback
65 |
66 |
67 | def make_query(self, concepts: List[str], sentence: str):
68 | question = f"""Concepts: {concepts}
69 | Sentence: {sentence}
70 | what concepts from the concept list are missing from the sentence and does the sentence make sense?"""
71 | return f"""{self.prompt}{question}"""
72 |
73 |
74 | def fix_feedback(self, sentence: str, concepts: List[str], feedback: str):
75 | """We rely on the model for generating a feedback. This is done to capture different forms in which the same concept might be expressed. However, the model might make mistakes and our task is simple enough that some of the mistakes can be corrected"""
76 |
77 | concepts_in_sent = self.detect_concepts(sentence=sentence, concepts=concepts)
78 | concepts_in_feedback = set([f.strip() for f in feedback.split(", ")])
79 | fixed_feedback = concepts_in_feedback.difference(concepts_in_sent)
80 | if len(fixed_feedback) == 0:
81 | return "None"
82 | return ", ".join(fixed_feedback)
83 |
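# Illustrative example (added note): if the model claims "dog, ball" are missing
# but detect_concepts finds "ball" as a token or spaCy lemma of the sentence,
# fix_feedback returns "dog"; when nothing is truly missing it returns "None".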
84 | def detect_concepts(self, sentence: str, concepts: List[str]) -> Set[str]:
85 | present_concepts = []
86 |
87 | # Tokenize the sentence and lemmatize the tokens
88 | tokens = nltk.word_tokenize(sentence)
89 | lemmas = [token.lemma_ for token in nlp(sentence)]
90 |
91 | # Check if each concept is present in the sentence
92 | for concept in concepts:
93 | if concept in tokens or concept in lemmas:
94 | present_concepts.append(concept)
95 |
96 | return set(present_concepts)
97 |
98 |
99 |
100 |
101 | if __name__ == "__main__":
102 | task_feedback = CommongenFeedback(
103 |         prompt_examples="data/prompt/commongen/feedback.jsonl",
104 |         engine="code-davinci-002"
105 | )
106 |
107 | print(task_feedback.prompt)
108 |
109 |
--------------------------------------------------------------------------------
/src/commongen/make_challenging.py:
--------------------------------------------------------------------------------
1 | import pandas as pd
2 |
3 | df = pd.read_json("/usr1/amadaan/shufflegen/data/original/commongen/val.jsonl", lines=True, orient="records")
4 |
5 | from itertools import chain
6 | all_concepts = set(chain(*df['concepts'].tolist()))
7 |
8 | # challenging data with 20-30 concepts per example
9 |
10 | import random
11 | random.seed(42)
12 |
13 | n_samples = 200
14 | res = []
15 | for i in range(n_samples):
16 | k = random.randint(20, 30)
17 |     concepts = random.sample(sorted(all_concepts), k=k)  # sample() needs a sequence, not a set
18 | res.append({"concepts": concepts})
19 |
20 | pd.DataFrame(res).to_json("data/commongen_very_challenging.jsonl", lines=True, orient="records")
21 |
22 |
--------------------------------------------------------------------------------
/src/commongen/run.py:
--------------------------------------------------------------------------------
1 | import pathlib
2 | from tqdm import tqdm
3 | from typing import List
4 |
5 | from src.commongen.task_init import CommongenTaskInit
6 | from src.commongen.task_iterate import CommongenTaskIterate
7 | from src.commongen.feedback import CommongenFeedback
8 | from src.utils import retry_parse_fail_prone_cmd
9 |
10 | CODEX = "code-davinci-002"
11 | GPT3 = "text-davinci-003"
12 | CHATGPT = "gpt-3.5-turbo"
13 | ENGINE = GPT3
14 |
15 |
16 | @retry_parse_fail_prone_cmd
17 | def autofb_commongen(concepts: List[str], max_attempts: int) -> str:
18 |
19 | # initialize all the required components
20 |
21 | # generation of the first sentence
22 | task_init = CommongenTaskInit(engine=ENGINE, prompt_examples="data/prompt/commongen/init.jsonl")
23 |
24 | # getting feedback
25 | task_feedback = CommongenFeedback(
26 | engine=ENGINE, prompt_examples="data/prompt/commongen/feedback.jsonl"
27 | )
28 |
29 | # iteratively improving the sentence
30 | task_iterate = CommongenTaskIterate(
31 | engine=ENGINE, prompt_examples="data/prompt/commongen/iterate.jsonl"
32 | )
33 |
34 | # Initialize the task
35 |
36 | n_attempts = 0
37 |
38 | print(f"{n_attempts} INIT> {concepts}")
39 | sent_to_fb = []
40 |
41 | while n_attempts < max_attempts:
42 | print()
43 |
44 | if n_attempts == 0:
45 | sent = task_init(concepts=concepts)
46 | else:
47 | sent = task_iterate(concepts=concepts, sent_to_fb=sent_to_fb)
48 |
49 | print(f"{n_attempts} GEN> {sent}")
50 |
51 | concept_fb, commonsense_fb = task_feedback(concepts=concepts, sentence=sent)
52 |
53 | sent_to_fb.append(
54 | {
55 | "sentence": sent,
56 | "concept_feedback": [f.strip() for f in concept_fb.split(",")],
57 | "commonsense_feedback": commonsense_fb,
58 | }
59 | )
60 | print(f"{n_attempts} Concept> {concept_fb} | CommonSense> {commonsense_fb}")
61 |
62 | if concept_fb.lower() == "none" and commonsense_fb.lower() == "none":
63 | break
64 |
65 | n_attempts += 1
66 |
67 | return sent_to_fb
68 |
69 |
70 | def run_cmd():
71 | concepts = sys.argv[2:]
72 | max_attempts = 5
73 | sent_to_fb = autofb_commongen(
74 | concepts=concepts,
75 | max_attempts=max_attempts,
76 | )
77 |
78 | res = []
79 | for s in sent_to_fb:
80 | sent = s["sentence"]
81 | fb = "; ".join(s["concept_feedback"]) + " " + s["commonsense_feedback"]
82 | res.append(f"{sent} ({fb})")
83 | print(" -> ".join(res))
84 |
85 |
86 | def run_iter(inputs_file_path: str, max_attempts: int = 4):
87 | test_df = pd.read_json(inputs_file_path, lines=True, orient="records")
88 | # add new columns sent_to_fb of type object, and status of type string
89 |
90 | is_rerun = "status" in test_df.columns
91 | if not is_rerun:
92 | test_df["sent_to_fb"] = None
93 | test_df["sent_to_fb"] = test_df["sent_to_fb"].astype(object)
94 | test_df["status"] = None
95 | else:
96 | print("Status column already exists! Looks like you're trying to do a re-run")
97 | print(test_df["status"].value_counts())
98 | for i, row in tqdm(test_df.iterrows(), total=len(test_df), desc="Running autofb iter"):
99 | if row["status"] == "success":
100 | continue
101 | try:
102 | sent_to_fb = autofb_commongen(concepts=row["concepts"], max_attempts=max_attempts)
103 |             test_df.at[i, "sent_to_fb"] = sent_to_fb  # .at allows storing a list in a single cell
104 |             test_df.loc[i, "status"] = "success"
105 |         except Exception as e:
106 |             test_df.at[i, "sent_to_fb"] = str(e)
107 |             test_df.loc[i, "status"] = "error"
108 |
109 |     base_path = inputs_file_path + (".iter.out" if not is_rerun else ".v0")
110 |     output_path, version = base_path, 0
111 |     while pathlib.Path(output_path).exists():
112 |         output_path = base_path + f".v{version}"
113 |         version += 1
114 |
115 | test_df.to_json(output_path, orient="records", lines=True)
116 |
117 |
118 | def run_multi_sample(inputs_file_path: str, n_samples: int = 4):
119 | test_df = pd.read_json(inputs_file_path, lines=True, orient="records")
120 |
121 | is_rerun = "status" in test_df.columns
122 | if not is_rerun:
123 | test_df["outputs"] = None
124 | test_df["outputs"] = test_df["outputs"].astype(object)
125 | test_df["status"] = None
126 | else:
127 | print("Status column already exists! Looks like you're trying to do a re-run")
128 | print(test_df["status"].value_counts())
129 |
130 | task_init = CommongenTaskInit(engine=ENGINE, prompt_examples="data/prompt/commongen/init.jsonl")
131 | task_feedback = CommongenFeedback(
132 |         engine=ENGINE, prompt_examples="data/prompt/commongen/feedback.jsonl"
133 | )
134 | for i, row in tqdm(test_df.iterrows(), total=len(test_df), desc="Running multisample autofb"):
135 |
136 | if row["status"] == "success":
137 | continue
138 | try:
139 | outputs = []
140 | for _ in range(n_samples):
141 | sent = task_init(concepts=row["concepts"])
142 | print(sent)
143 | concept_fb, commonsense_fb = task_feedback(concepts=row["concepts"], sentence=sent)
144 | print(concept_fb, commonsense_fb)
145 | outputs.append(
146 | {
147 | "sentence": sent,
148 | "concept_feedback": [f.strip() for f in concept_fb.split(",")],
149 | "commonsense_feedback": commonsense_fb,
150 | }
151 | )
152 | if concept_fb.lower() == "none" and commonsense_fb.lower() == "none":
153 | break
154 |             test_df.at[i, "outputs"] = outputs  # .at allows storing a list in a single cell
155 |             test_df.loc[i, "status"] = "success"
156 |         except Exception as e:
157 |             # raise e  # re-raise only while debugging; otherwise record the error and continue
158 |             test_df.at[i, "outputs"] = str(e)
159 |             test_df.loc[i, "status"] = "error"
160 | print(test_df)
161 |     base_path = inputs_file_path + "." + ENGINE + (".multi.out" if not is_rerun else ".v0")
162 |     output_path, version = base_path, 0
163 |     while pathlib.Path(output_path).exists():
164 |         output_path = base_path + f".v{version}"
165 |         version += 1
166 |
167 | test_df.to_json(output_path, orient="records", lines=True)
168 |
169 |
170 | if __name__ == "__main__":
171 | import sys
172 | import pandas as pd
173 |
174 | if sys.argv[1] == "cmd":
175 | run_cmd()
176 |
177 | elif sys.argv[1] == "batch-iter":
178 | run_iter(inputs_file_path=sys.argv[2])
179 |
180 | elif sys.argv[1] == "batch-multi":
181 | run_multi_sample(inputs_file_path=sys.argv[2])
182 |
183 | else:
184 | raise ValueError("Invalid mode: choose between cmd, batch-iter, batch-multi")
185 |
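# Usage sketch (illustrative, inferred from the modes above; concepts are
# placeholders):
#   python -m src.commongen.run cmd dog frisbee park
#   python -m src.commongen.run batch-iter data/prompt/commongen/commongen_hard.jsonl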
186 |
187 |
--------------------------------------------------------------------------------
/src/commongen/task_init.py:
--------------------------------------------------------------------------------
1 | from typing import List
2 | import pandas as pd
3 | from src.utils import Prompt
4 |
5 | from prompt_lib.backends import openai_api
6 |
7 |
8 | class CommongenTaskInit(Prompt):
9 | def __init__(self, prompt_examples: str, engine: str) -> None:
10 | super().__init__(
11 | question_prefix="Concepts: ",
12 | answer_prefix="Sentence: ",
13 | intra_example_sep="\n\n",
14 | inter_example_sep="\n\n###\n\n",
15 | )
16 | self.engine = engine
17 | self.setup_prompt_from_examples_file(prompt_examples)
18 |
19 | def setup_prompt_from_examples_file(self, examples_path: str) -> str:
20 | TEMPLATE = """Concepts: {concepts}
21 |
22 | Sentence: {sentence}"""
23 |
24 | examples_df = pd.read_json(examples_path, orient="records", lines=True)
25 | prompt = []
26 | for _, row in examples_df.iterrows():
27 | prompt.append(TEMPLATE.format(concepts=row["concepts"], sentence=row["target"]))
28 | self.prompt = self.inter_example_sep.join(prompt)
29 | self.prompt = self.prompt + self.inter_example_sep
30 |
31 | def make_query(self, concepts: List[str]) -> str:
32 |
33 | query = f"""{self.question_prefix}{concepts}"""
34 | query = f"{self.prompt}{query}{self.intra_example_sep}"
35 | return query
36 |
37 | def __call__(self, concepts: List[str]) -> str:
38 | generation_query = self.make_query(concepts) + "\nDo your best! It's okay if the sentence is not coherent.\n"
39 |
40 | output = openai_api.OpenaiAPIWrapper.call(
41 | prompt=generation_query,
42 | engine=self.engine,
43 | max_tokens=300,
44 | stop_token="###",
45 | temperature=0.7,
46 | )
47 |
48 | generated_sent = openai_api.OpenaiAPIWrapper.get_first_response(output)
49 | print(generated_sent)
50 | generated_sent = generated_sent.split(self.answer_prefix)[1].replace("#", "").strip()
51 | return generated_sent.strip()
52 |
53 |
54 |
55 | if __name__ == "__main__":
56 | task_init = CommongenTaskInit(
57 | prompt_examples="data/prompt/commongen/init.jsonl",
58 | engine="davinci-code-002"
59 | )
60 |
61 | print(task_init.prompt)
62 | # print(task_init.make_query(["a", "b", "c"]))
--------------------------------------------------------------------------------
/src/commongen/task_iterate.py:
--------------------------------------------------------------------------------
1 | import re
2 | from typing import Dict, List
3 | from src.utils import Prompt
4 |
5 | from prompt_lib.backends import openai_api
6 |
7 | header = """Concepts: {concepts}
8 | """
9 | example_template = """Sentence: {sentence}
10 |
11 | what concepts from the concept list are missing from the sentence?
12 |
13 | Concept Feedback: {concept_feedback}
14 |
15 | Any feedback on commonsense?
16 |
17 | Commonsense Feedback: {commonsense_feedback}"""
18 |
19 | instr = """
20 |
21 | Okay, improve the sentence using the feedback:
22 |
23 | """
24 |
25 | class CommongenTaskIterate(Prompt):
26 | def __init__(self, engine: str, prompt_examples: str) -> None:
27 | super().__init__(
28 | question_prefix="",
29 | answer_prefix="",
30 | intra_example_sep="\n\n",
31 | inter_example_sep="\n\n###\n\n",
32 | )
33 | self.engine = engine
34 | self.count = 0
35 | self.prompt = self.make_prompt(prompt_examples=prompt_examples)
36 |
37 | def make_prompt(self, prompt_examples: str) -> str:
38 | import pandas as pd
39 |
40 | prompt_examples = pd.read_json(prompt_examples, orient="records", lines=True)
41 |
42 | prompt = []
43 |
44 | for example in prompt_examples.to_dict(orient="records"):
45 | prompt.append(
46 | self.make_one_iterate_example(
47 | concepts=example["concepts"], sent_to_fb=example["sentence_to_feedback"]
48 | )
49 | )
50 |
51 | return self.inter_example_sep.join(prompt) + self.inter_example_sep
52 |
53 | def make_one_iterate_example(self, concepts: List[str], sent_to_fb: List[Dict]):
54 | """Given a list of examples that are incrementally improving, return a new example."""
55 |
56 |
57 |
58 | single_example = []
59 | for example in sent_to_fb:
60 |
61 | single_example.append(example_template.format(
62 | sentence=example["sentence"], commonsense_feedback=example["commonsense_feedback"], concept_feedback=example["concept_feedback"]
63 | ))
64 |
65 | return header.format(concepts=concepts) + instr.join(single_example)
66 |
67 | def make_query(self, concepts: List[str],
68 | sent_to_fb: List[Dict],) -> str:
69 | query_example = self.make_one_iterate_example(concepts=concepts, sent_to_fb=sent_to_fb)
70 | return f"{self.prompt}{self.question_prefix}{query_example}{self.intra_example_sep}{self.answer_prefix}" + instr
71 | # return super().make_query(prompt, question)
72 |
73 | def __call__(
74 | self,
75 | concepts: List[str],
76 | sent_to_fb: List[Dict],
77 | ) -> str:
78 |
79 | transfer_query = self.make_query(concepts=concepts, sent_to_fb=sent_to_fb)
80 | self.count += 1
81 |
82 | output = openai_api.OpenaiAPIWrapper.call(
83 | prompt=transfer_query,
84 | engine=self.engine,
85 | max_tokens=300,
86 | stop_token=self.inter_example_sep,
87 | temperature=0.7,
88 | )
89 | response = openai_api.OpenaiAPIWrapper.get_first_response(output)
90 |
91 | print("######")
92 | print()
93 | print("------")
94 | print(response)
95 | print("------")
96 |
97 | response = re.search("Sentence: (.*)", response).group(1).strip().split("\n")[0].strip()
98 |
99 | return response.strip()
100 |
114 |
115 |
116 | if __name__ == "__main__":
117 | obj = CommongenTaskIterate(
118 |         prompt_examples="data/prompt/commongen/iterate.jsonl", engine="whatever"
119 | )
120 | print(obj.prompt)
121 | # print(obj.make_query(concepts=["a", "b"], sent_to_fb=[{"sentence": "a", "feedback": "a"}, {"sentence": "b", "feedback": "d"}]))
122 |
--------------------------------------------------------------------------------
/src/gsm/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/madaan/self-refine/9a206d41e5d2d0c241bb441f41eeadb945afaa55/src/gsm/__init__.py
--------------------------------------------------------------------------------
/src/gsm/feedback.py:
--------------------------------------------------------------------------------
1 | import pandas as pd
2 | from prompt_lib.backends import openai_api
3 |
4 | from src.utils import Prompt
5 |
6 |
7 | class GSMFeedback(Prompt):
8 | def __init__(self, engine: str, prompt_examples: str, temperature: float, max_tokens: int = 600) -> None:
9 | super().__init__(
10 | question_prefix="",
11 | answer_prefix="",
12 | intra_example_sep="\n\n",
13 | inter_example_sep="\n\n### END ###\n\n",
14 | engine = engine,
15 | temperature = temperature
16 | )
17 |
18 | self.max_tokens = max_tokens
19 | self.instruction = "# There is an error in the code above because of lack of understanding of the question. What is the error? To find the error, go through semantically complete blocks of the code, and check if everything looks good." if "naive" not in prompt_examples else "# There is an error in the code above."
20 | self.setup_prompt_from_examples_file(prompt_examples)
21 |
22 | def setup_prompt_from_examples_file(self, examples_path: str) -> str:
23 | with open(examples_path, "r") as f:
24 | self.prompt = f.read()
25 |
26 | def __call__(self, solution: str):
27 | generation_query = self.make_query(solution=solution)
28 | print(generation_query)
30 | output = openai_api.OpenaiAPIWrapper.call(
31 | prompt=generation_query,
32 | engine=self.engine,
33 | max_tokens=self.max_tokens,
34 | stop_token="### END",
35 | temperature=self.temperature,
36 | )
37 |
38 |
39 | entire_output = openai_api.OpenaiAPIWrapper.get_first_response(output)
40 | print(entire_output)
41 | if "### END" in entire_output:
42 | entire_output = entire_output.split("### END")[0]
43 |
44 | improved_soln = entire_output.split("def solution():")[1]
45 | feedback = entire_output.split("def solution():")[0]
46 | improved_soln = "def solution():" + improved_soln.rstrip()
47 | self.update_prompt(solution=solution, improved_soln=improved_soln, feedback=feedback)
48 | return {"solution": improved_soln, "feedback": feedback}
49 |
50 | def make_query(self, solution: str):
51 |
52 | solution = f"""{self.question_prefix}{solution}{self.intra_example_sep}{self.instruction}{self.answer_prefix}"""
53 | return f"{self.prompt}{solution}"
54 |
55 |
56 | def update_prompt(self, solution: str, improved_soln: str, feedback: str):
57 | prefix = f"""{self.question_prefix}{solution}{self.intra_example_sep}{self.instruction}{self.answer_prefix}"""
58 |
59 | gen_ans = f"""
60 |
61 | {feedback}
62 |
63 | {improved_soln.rstrip()}{self.inter_example_sep}"""
64 |
65 | new_example = f"{prefix}{gen_ans}"
66 | self.prompt = f"{self.prompt}{new_example}"
67 |
68 |
69 | def test():
70 | task_fb = GSMFeedback(
71 |         prompt_examples="data/prompt/gsm/feedback.txt",
72 | engine="gpt-3.5-turbo",
73 | temperature=0.7,
74 | )
75 |
76 | wrong_soln = """def solution():
77 | \"\"\"Twenty dozen cups cost $1200 less than the total cost of half a dozen plates sold at $6000 each. Calculate the total cost of buying each cup.\"\"\"
78 | plates = 6
79 | plate_cost = 6000
80 | cups = 12 * 20
81 | cup_cost = (plates * plate_cost) / cups - 1200
82 | result = cup_cost
83 | return result"""
84 | feedback_and_solution = task_fb(wrong_soln)
85 | print(feedback_and_solution["feedback"])
86 | print(feedback_and_solution["solution"])
87 |
88 | print(task_fb.prompt)
89 |
90 |
91 | if __name__ == '__main__':
92 | test()
--------------------------------------------------------------------------------
/src/gsm/feedback_no_update.py:
--------------------------------------------------------------------------------
1 | import pandas as pd
2 | from prompt_lib.backends import openai_api
3 |
4 | from src.utils import Prompt
5 |
6 |
7 | class GSMFeedback(Prompt):
8 | def __init__(self, engine: str, prompt_examples: str, temperature: float, max_tokens: int = 300) -> None:
9 | super().__init__(
10 | question_prefix="",
11 | answer_prefix="",
12 | intra_example_sep="\n\n",
13 | inter_example_sep="\n\n### END ###\n\n",
14 | engine = engine,
15 | temperature = temperature
16 | )
17 | self.max_tokens = max_tokens
18 | self.instruction = "# There is an error in the code above because of lack of understanding of the question. What is the error? To find the error, go through semantically complete blocks of the code, and check if everything looks good."
19 | self.setup_prompt_from_examples_file(prompt_examples)
20 |
21 | def setup_prompt_from_examples_file(self, examples_path: str) -> str:
22 | with open(examples_path, "r") as f:
23 | self.prompt = f.read()
24 |
25 | def __call__(self, solution: str):
26 | generation_query = self.make_query(solution=solution)
27 | output = openai_api.OpenaiAPIWrapper.call(
28 | prompt=generation_query,
29 | engine=self.engine,
30 | max_tokens=self.max_tokens,
31 | stop_token="### END",
32 | temperature=self.temperature,
33 | )
34 |
35 | entire_output = openai_api.OpenaiAPIWrapper.get_first_response(output)
36 | if "### END" in entire_output:
37 | entire_output = entire_output.split("### END")[0]
38 | solution = entire_output.split("def solution():")[1]
39 | feedback = entire_output.split("def solution():")[0]
40 | solution = "def solution():" + solution.rstrip()
41 | return {"solution": solution, "feedback": feedback}
42 |
43 | def make_query(self, solution: str):
44 | solution = f"""{self.question_prefix}{solution}{self.intra_example_sep}{self.instruction}{self.answer_prefix}"""
45 | return f"{self.prompt}{solution}"
46 |
47 |
48 | def test():
49 | task_fb = GSMFeedback(
50 | prompt_examples="data/prompt/gsm/feedback.txt",
51 | engine="code-davinci-002",
52 | temperature=0.7,
53 | )
54 |
55 | wrong_soln = """def solution():
56 | \"\"\"Twenty dozen cups cost $1200 less than the total cost of half a dozen plates sold at $6000 each. Calculate the total cost of buying each cup.\"\"\"
57 | plates = 6
58 | plate_cost = 6000
59 | cups = 12 * 20
60 | cup_cost = (plates * plate_cost) / cups - 1200
61 | result = cup_cost
62 | return result"""
63 | feedback_and_solution = task_fb(wrong_soln)
64 | print(feedback_and_solution["feedback"])
65 | print(feedback_and_solution["solution"])
66 |
67 |
68 | if __name__ == '__main__':
69 | test()
70 |
--------------------------------------------------------------------------------
/src/gsm/gsm_selfref_eval.py:
--------------------------------------------------------------------------------
1 | from importlib import reload
2 | import pandas as pd
3 | from tqdm import tqdm
4 | from contextlib import contextmanager
5 | import signal
6 | from glob import glob
7 | import os
8 |
9 | # from https://stackoverflow.com/questions/492519/timeout-on-a-function-call
10 | @contextmanager
11 | def timeout(duration):
12 | def timeout_handler(signum, frame):
13 | raise TimeoutError(f"block timedout after {duration} seconds")
14 |
15 | signal.signal(signal.SIGALRM, timeout_handler)
16 | signal.alarm(duration)
17 | try:
18 | yield
19 | finally:
20 | signal.alarm(0)
21 |
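# Usage sketch (added note; SIGALRM makes this Unix-only):
#   with timeout(1):
#       result = str(temp_result.solution())   # raises TimeoutError after 1 second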
22 | def read_json(path):
23 | import json
24 | rows = []
25 | with open(path, "r") as f:
26 | for line in f:
27 | rows.append(json.loads(line))
28 |
29 | task_df = pd.DataFrame(rows)
30 | return task_df
31 |
32 | def evaluate_code_prompt(path, num_gsm: int = 1319):
33 | data = read_json(path)
34 | if "question" not in data.columns:
35 | data["question"] = data["input"]
36 | if "answer" not in data.columns:
37 | data["answer"] = data["target"]
38 |
39 | attempt_to_acc = []
40 | reports = [] # Step 1
41 | for idx, row in tqdm(data.iterrows(), total=len(data)):
46 | attempt_to_acc_ = {i: 0 for i in range(5)}
47 | attempt_to_acc_["question"] = row["question"]
48 | solutions = []
49 | if row["run_logs"] is None:
50 | continue
51 | for _, log in enumerate(row["run_logs"]):
52 | solutions.append(log["solution_curr"])
53 | solutions.append(row["run_logs"][-1]["solution_fixed"])
54 |
55 | feedback = [rec["feedback"] for rec in row["run_logs"]]
56 |
57 | prev_accuracy = 0
58 | for iter_idx, soln in enumerate(solutions):
59 | soln = soln.split("\n\n\n")[0].strip() + "\n"
60 | soln = soln.replace("The answer is", "").strip() + "\n"
61 | os.system("rm -rf __pycache__")
62 | os.system("rm -f temp_result.pyc")
63 |
64 | with open("temp_result.py", "w") as f:
65 | f.write(soln)
66 |
67 | try:
68 | import temp_result
69 | reload(temp_result)
70 | correct_solution = str(row["answer"])
71 |
72 |                 # note: importing temp_result above already executed the solution code
73 | with timeout(1):
74 | result = str(temp_result.solution())
75 | is_corr = check_corr(result, correct_solution)
76 |
77 |
78 | is_corr = int(is_corr)
79 | # Step 2
80 |
81 | if iter_idx > 0 and is_corr == 1 and prev_accuracy == 0:
82 | report = {
83 | "previous_solution": solutions[iter_idx - 1],
84 | "feedback": feedback[iter_idx - 1],
85 | "next_solution": solutions[iter_idx],
86 | }
87 | reports.append(report) # Step 3
88 | if is_corr == 1:
89 | for j in range(iter_idx, 5):
90 | attempt_to_acc_[j] = 1
91 | break
92 | attempt_to_acc_[iter_idx] = 0
93 | prev_accuracy = is_corr
94 | except Exception as e:
95 | continue
96 |
97 | attempt_to_acc.append(attempt_to_acc_)
98 |
99 | df = pd.DataFrame(attempt_to_acc)
100 |
101 | # print(attempt_to_acc)
102 | for i in range(5):
103 | print(f"Accuracy at attempt {i} = {df[i].sum() / num_gsm:.2%} ({df[i].sum()}/{num_gsm})")
104 |
105 | df.to_json("/tmp/attempt_to_acc.jsonl", orient="records", lines=True)
106 |
107 | report_file = f"{path}.reports.txt"
108 | print_reports(reports, report_file) # Step 4
109 | return reports
110 |
111 | # Step 4
112 | def print_reports(reports, report_file):
113 |
114 |
115 | with open(report_file, "w") as f:
116 | for i, report in enumerate(reports):
117 | f.write(f"Report {i + 1}:\n")
118 | f.write("\nPrevious solution:\n")
119 | f.write(report["previous_solution"])
120 | f.write("\n\nFeedback:\n")
121 | f.write(report["feedback"])
122 | f.write("\n\nNext solution:\n")
123 | f.write(report["next_solution"])
124 | f.write("\n\n" + "=" * 80 + "\n\n")
125 | def check_corr(result: str, correct_solution: str, tol: float = 1e-3):
126 | if result.strip() == correct_solution.strip():
127 | return 1
128 | try:
129 | result = float(result.strip())
130 | correct_solution = float(correct_solution.strip())
131 | return abs(result - correct_solution) < tol
132 | except:
133 | return 0
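# Illustrative (added note): exact string match first, then a numeric fallback,
# so check_corr("3.0", "3") and check_corr("7", "7.0001") are both truthy under
# the default tol of 1e-3.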
134 |
135 |
136 |
137 | if __name__ == "__main__":
138 | import argparse
139 |
140 | parser = argparse.ArgumentParser()
141 | parser.add_argument("--path", type=str, default="data/quco/quco_test.jsonl")
142 | args = parser.parse_args()
143 |
144 | evaluate_code_prompt(args.path)
145 |
--------------------------------------------------------------------------------
/src/gsm/run.py:
--------------------------------------------------------------------------------
1 | import pandas as pd
2 | from tqdm import tqdm
3 |
4 |
5 | from src.gsm.task_init import GSMInit
6 | from src.gsm.feedback import GSMFeedback
7 |
8 | from src.utils import retry_parse_fail_prone_cmd
9 |
10 | CODEX = "code-davinci-002"
11 | # GPT3 = "text-davinci-003"
12 | ENGINE = CODEX
13 |
14 |
15 | @retry_parse_fail_prone_cmd
16 | def iterative_gsm(question: str, max_attempts: int, feedback_type: str, temperature: float):
17 |
18 | # initialize all the required components
19 |
20 | # generation of the first fast version
21 | task_init = GSMInit(engine=ENGINE, prompt_examples="data/prompt/gsm/init.txt", temperature=temperature)
22 |
23 | # getting feedback
24 | if feedback_type == "naive":
25 | raise NotImplementedError
26 | else:
27 | task_feedback = GSMFeedback(engine=ENGINE, prompt_examples="data/prompt/gsm/feedback.txt", temperature=0.7)
28 |
29 |
30 | n_attempts = 0
31 |
32 | log = []
33 |
34 | while n_attempts < max_attempts:
35 |
36 | if n_attempts == 0:
37 | solution = task_init(solution=question)
38 |
39 | fb_and_maybe_soln = task_feedback(solution=solution)
40 |
41 |
42 | log.append({"attempt": n_attempts, "solution_curr": solution, "solution_fixed": fb_and_maybe_soln["solution"], "feedback": fb_and_maybe_soln["feedback"]})
43 |
44 | if "it is correct" in fb_and_maybe_soln["feedback"].lower():
45 | break
46 |
47 | solution = fb_and_maybe_soln["solution"]
48 |
49 | n_attempts += 1
50 |
51 | return log
52 |
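# Note (added for clarity): each log entry pairs the current solution with its
# feedback-refined version, so run_logs[0]["solution_curr"] is the direct,
# zero-feedback baseline and run_logs[-1]["solution_fixed"] is the final answer.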
53 |
54 | def fix_gsm(gsm_task_file: str, max_attempts: int, outfile: str, feedback_type: str, temperature: float):
55 |
56 |
57 | slow_programs_df = pd.read_json(gsm_task_file, lines=True, orient="records")
58 | slow_programs_df["run_logs"] = None
59 | results = []
60 | for i, row in tqdm(slow_programs_df.iterrows(), total=len(slow_programs_df)):
61 | row_copy = row.to_dict()
62 | try:
63 | run_logs = iterative_gsm(question=row["input"], max_attempts=max_attempts, feedback_type=feedback_type, temperature=temperature)
64 | row_copy["run_logs"] = run_logs
65 | row_copy["generated_answer_ours"] = run_logs[-1]["solution_fixed"]
66 | row_copy["generated_answer_direct"] = run_logs[0]["solution_curr"]
67 | results.append(row_copy)
68 | if i % 10 == 0:
69 | pd.DataFrame(results).to_json(outfile + f".{i}.jsonl", orient="records", lines=True)
70 | except Exception as e:
71 | # raise e
72 | pass
73 | pd.DataFrame(results).to_json(outfile, orient="records", lines=True)
74 | return results
75 |
76 |
77 | def test():
78 | import json
79 |
80 |
81 | with open("/tmp/debug_gsm.jsonl", "w") as fout:
82 | fout.write(json.dumps({"input": "Twenty dozen cups cost $1200 less than the total cost of half a dozen plates sold at $6000 each. Calculate the total cost of buying each cup."}))
83 |
84 | logs = fix_gsm(
85 | gsm_task_file="/tmp/debug_gsm.jsonl", max_attempts=3, outfile="/tmp/test.jsonl", feedback_type="rich", temperature=0.0
86 | )
87 | for i, log in enumerate(logs):
88 | print(log["generated_answer_ours"])
89 | print(log["generated_answer_direct"])
90 |
91 |
92 | if __name__ == "__main__":
93 | import sys
94 |
95 | if sys.argv[1] == "test":
96 | test()
97 | else:
98 | import argparse
99 | args = argparse.ArgumentParser()
100 | args.add_argument("--gsm_task_file", type=str, default="data/tasks/gsm/gsm.jsonl")
101 | args.add_argument("--max_attempts", type=int, default=4)
102 | args.add_argument("--outfile", type=str, default="data/tasks/gsm/gsm_outputs.jsonl")
103 | args.add_argument("--feedback_type", type=str, default="rich")
104 | args.add_argument("--temperature", type=float, default=0.0)
105 | args = args.parse_args()
106 | args.outfile = f"{args.outfile}.fb_{args.feedback_type}.temp_{args.temperature}.engine_{ENGINE}.jsonl"
107 | fix_gsm(gsm_task_file=args.gsm_task_file, max_attempts=args.max_attempts, outfile=args.outfile, feedback_type=args.feedback_type, temperature=args.temperature)
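# Usage sketch (illustrative): `python -m src.gsm.run test` runs the built-in
# debug question; otherwise the argparse flags above apply, e.g.
#   python -m src.gsm.run --gsm_task_file data/tasks/gsm/gsm.jsonl --max_attempts 4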
--------------------------------------------------------------------------------
/src/gsm/task_init.py:
--------------------------------------------------------------------------------
1 | import pandas as pd
2 | from src.utils import Prompt
3 |
4 | from prompt_lib.backends import openai_api
5 |
6 |
7 | class GSMInit(Prompt):
8 | def __init__(self, prompt_examples: str, engine: str, temperature: float) -> None:
9 | super().__init__(
10 | question_prefix="# Q: ",
11 | answer_prefix="# solution using Python:\n",
12 | intra_example_sep="\n",
13 | inter_example_sep="\n\n",
14 | engine=engine,
15 | temperature=temperature,
16 | )
17 | self.setup_prompt_from_examples_file(prompt_examples)
18 |
19 | def setup_prompt_from_examples_file(self, prompt_examples) -> str:
20 | with open(prompt_examples, "r") as f:
21 | self.prompt = f.read()
22 |
23 | def make_query(self, solution: str) -> str:
24 | solution = solution.strip()
25 | query = f"{self.prompt}{self.question_prefix}{solution}{self.intra_example_sep}{self.answer_prefix}"
26 | return query
27 |
28 | def __call__(self, solution: str) -> str:
29 | generation_query = self.make_query(solution)
30 | output = openai_api.OpenaiAPIWrapper.call(
31 | prompt=generation_query,
32 | engine=self.engine,
33 | max_tokens=300,
34 | stop_token=self.inter_example_sep,
35 | temperature=self.temperature,
36 | )
37 |
38 | solution_code = openai_api.OpenaiAPIWrapper.get_first_response(output)
39 |
40 | return solution_code.strip()
41 |
42 |
43 | def test():
44 | task_init = GSMInit(
45 | prompt_examples="data/prompt/gsm/init.txt",
46 | engine="code-davinci-002",
47 | temperature=0.0,
48 | )
49 |
50 | question = "The educational shop is selling notebooks for $1.50 each and a ballpen at $0.5 each. William bought five notebooks and a ballpen. How much did he spend in all?"
51 | print(task_init(question))
52 |
53 |
54 | if __name__ == "__main__":
55 | test()
--------------------------------------------------------------------------------
/src/pie/feedback.py:
--------------------------------------------------------------------------------
1 | import pandas as pd
2 | from prompt_lib.backends import openai_api
3 |
4 | from src.utils import Prompt
5 |
6 |
7 | class PieFeedback(Prompt):
8 | def __init__(self, engine: str, prompt_examples: str, temperature: float, max_tokens: int = 300) -> None:
9 | super().__init__(
10 | question_prefix="",
11 | answer_prefix="# Why is this code slow?\n",
12 | intra_example_sep="\n\n",
13 |             inter_example_sep="\n\n### END ###\n\n",
14 | )
15 | self.engine = engine
16 | self.max_tokens = max_tokens
17 | self.temperature = temperature
18 | self.setup_prompt_from_examples_file(prompt_examples)
19 |
20 | def setup_prompt_from_examples_file(self, examples_path: str) -> str:
21 | with open(examples_path, "r") as f:
22 | self.prompt = f.read()
23 |
24 | def __call__(self, slow_code: str):
25 | generation_query = self.make_query(slow_code=slow_code)
26 |
27 | output = openai_api.OpenaiAPIWrapper.call(
28 | prompt=generation_query,
29 | engine=self.engine,
30 | max_tokens=self.max_tokens,
31 | stop_token="### END",
32 | temperature=self.temperature,
33 | )
34 |
35 | generated_feedback = openai_api.OpenaiAPIWrapper.get_first_response(output)
36 | if "### END" in generated_feedback:
37 | generated_feedback = generated_feedback.split("### END")[0]
38 | return generated_feedback.strip()
39 |
40 | def make_query(self, slow_code: str):
41 | slow_code = f"""{self.question_prefix}{slow_code}{self.intra_example_sep}{self.answer_prefix}"""
42 | return f"{self.prompt}{slow_code}"
43 |
44 |
45 | def test():
46 | task_fb = PieFeedback(
47 | prompt_examples="data/prompt/pie/feedback.txt",
48 | engine="gpt-3.5-turbo",
49 | temperature=0.0
50 | )
51 |
52 | print(task_fb.prompt)
53 |     slow_code = "def sum(n):\n    res = 0\n    for i in range(n):\n        res += i\n    return res"
54 | print(task_fb(slow_code))
55 |
56 |
57 | if __name__ == '__main__':
58 | test()
--------------------------------------------------------------------------------
/src/pie/prep_for_pie_eval.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import pandas as pd
3 |
4 | """Parses the self-refine outputs, extracts the output code from each attempt in a new column, and saves the results to a JSON file."""
5 |
6 | def extract_attempt_codes(self_refine_output_path,
7 | flattened_output_path, num_attempts):
8 | """This function creates a file where each attempt/output at each step is stored in a new column: attempt_0_code, attempt_1_code, etc.
9 |
10 |     Args:
11 |         self_refine_output_path (str): path to the self-refine output JSONL file.
12 |         flattened_output_path (str): where to write the flattened JSONL file.
13 |         num_attempts (int): number of attempt_i_code columns to create.
14 |     """
15 | outputs = pd.read_json(self_refine_output_path, orient="records", lines=True)
16 | rows = []
17 |
18 | for _, row in outputs.iterrows():
19 | # Convert the row to a dictionary.
20 | tmp = row.to_dict()
21 | # Extract the code from each attempt and store it in the temporary dictionary.
22 | for i in range(num_attempts):
23 | if len(row["run_logs"]) <= i:
24 | tmp[f"attempt_{i}_code"] = ""
25 | else:
26 | tmp[f"attempt_{i}_code"] = row["run_logs"][i]["fast_code"]
27 |
28 | rows.append(tmp)
29 | # Convert the rows list to a DataFrame and save it to a JSON file.
30 | pd.DataFrame(rows).to_json(flattened_output_path, orient="records", lines=True)
31 |
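# Illustrative (added note): a row whose run_logs hold two attempts yields
# attempt_0_code and attempt_1_code from the logs, and empty strings for any
# remaining attempt_i_code columns up to num_attempts.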
32 | # Main execution of the script.
33 | if __name__ == "__main__":
34 | # Initialize argument parser.
35 |     parser = argparse.ArgumentParser(description="Flatten self-refine outputs into per-attempt code columns")
36 | # Define expected arguments.
37 | parser.add_argument("model", type=str, help="Model name")
38 | parser.add_argument("input_file", type=str, help="Path to the input JSON file")
39 | parser.add_argument("base_config_path", type=str, help="Base path for config files")
40 | parser.add_argument("base_output_path", type=str, help="Base path for output files")
41 | parser.add_argument("--num_attempts", type=int, default=4, help="Number of attempts")
42 | # Parse provided arguments.
43 | args = parser.parse_args()
44 |
45 | # Construct the output file path and extract the codes.
46 | output_path = f"{args.base_output_path}/{args.model}/output.attempt_codes"
47 | extract_attempt_codes(
48 | self_refine_output_path=args.input_file,
49 | flattened_output_path=output_path,
50 | num_attempts=args.num_attempts,
51 | )
52 |
--------------------------------------------------------------------------------
/src/pie/run.py:
--------------------------------------------------------------------------------
1 | import pandas as pd
2 | from tqdm import tqdm
3 |
4 |
5 | from src.pie.task_init import PieInit
6 | from src.pie.task_iterate import PieIterate
7 | from src.pie.feedback import PieFeedback
8 |
9 | from src.utils import retry_parse_fail_prone_cmd
10 |
11 | CODEX = "code-davinci-002"
12 | GPT3 = "text-davinci-003"
13 | CHATGPT = "gpt-3.5-turbo"
14 | GPT4 = "gpt-4"
15 | ENGINE = CHATGPT
16 |
17 |
18 | @retry_parse_fail_prone_cmd
19 | def iterative_pie(slow_code: str, max_attempts: int, feedback_type: str, temperature: float):
20 |
21 | # initialize all the required components
22 |
23 | # generation of the first fast version
24 | task_init = PieInit(engine=ENGINE, prompt_examples="data/prompt/pie/init.txt", temperature=temperature)
25 |
26 | iterate_prompt = "data/prompt/pie/iterate.txt"
27 | # getting feedback
28 | if feedback_type == "naive":
29 | task_feedback = lambda **kwargs: "It could be faster"
30 | iterate_prompt = "data/prompt/pie/iterate_genericfb.txt"
31 |
32 | elif feedback_type == "none":
33 | task_feedback = lambda **kwargs: ""
34 | iterate_prompt = "data/prompt/pie/iterate_nofb.txt"
35 |
36 | else:
37 | task_feedback = PieFeedback(engine=ENGINE, prompt_examples="data/prompt/pie/feedback.txt", temperature=temperature)
38 |
39 | # iteratively improving the code
40 | task_iterate = PieIterate(engine=ENGINE, prompt_examples=iterate_prompt, temperature=temperature)
41 |
42 | # Initialize the task
43 |
44 | n_attempts = 0
45 |
46 | log = []
47 | feedback = None
48 |
49 | while n_attempts < max_attempts:
50 |
51 | if n_attempts == 0:
52 | fast_code = task_init(slow_code=slow_code)
53 | else:
54 | fast_code = task_iterate(slow_code=slow_code, feedback=feedback)
55 |
56 | # feedback = task_feedback(slow_code=slow_code)
57 | feedback = task_feedback(slow_code=fast_code)
58 |
59 | log.append({"fast_code": fast_code, "feedback": feedback, "slow_code": slow_code, "attempt": n_attempts})
60 | show_example(**log[-1])
61 |
62 | if "this code is not slow" in feedback.lower():
63 | break
64 |
65 | slow_code = fast_code
66 |
67 | n_attempts += 1
68 |
69 | return log
70 |
71 |
72 | def show_example(**kwargs):
73 | # shows {"fast_code": fast_code, "feedback": feedback, "slow_code": slow_code, "attempt": n_attempts}
74 | print(f"SLOW CODE:\n{kwargs['slow_code']}\n")
75 | print(f"\n\nFEEDBACK:\n{kwargs['feedback']}\n")
76 | print(f"\n\nFAST CODE:\n{kwargs['fast_code']}\n")
77 | print("-" * 100)
78 |
79 | def run_over_slow_programs(slow_programs_file: str, max_attempts: int, outfile: str, feedback_type: str, temperature: float, backup_file: str = None):
80 |
81 | slow_programs_df = pd.read_json(slow_programs_file, lines=True, orient="records")
82 | slow_programs_df["run_logs"] = None
83 |
84 | if backup_file:
85 | backup_df = pd.read_json(backup_file, lines=True, orient="records")
86 | processed_inputs = set(backup_df["submission_id_v0"].tolist())
87 | results = backup_df.to_dict("records")
88 | else:
89 | processed_inputs = set()
90 | results = []
91 |
92 | for i, row in tqdm(slow_programs_df.iterrows(), total=len(slow_programs_df)):
93 | if row["submission_id_v0"] in processed_inputs:
94 | continue
95 |
96 | row_copy = row.to_dict()
97 | try:
98 | run_logs = iterative_pie(slow_code=row["input"], max_attempts=max_attempts, feedback_type=feedback_type, temperature=temperature)
99 | print(run_logs)
100 | row_copy["run_logs"] = run_logs
101 | results.append(row_copy)
102 | if i % 20 == 0:
103 | pd.DataFrame(results).to_json(outfile + f".{i}.jsonl", orient="records", lines=True)
104 |         except Exception as e:
105 |             # log the failure and continue with the next program
106 |             print(f"Failed on {row['submission_id_v0']}: {e}")
107 |     pd.DataFrame(results).to_json(outfile, orient="records", lines=True)
108 |     return results
109 |
110 |
111 |
112 | def test():
113 |     import json
114 |     slow_code = "def sum(n):\n    res = 0\n    for i in range(n):\n        res += i\n    return res"
115 |     # run_over_slow_programs reads a jsonl file, so write a one-row input for the test
116 |     with open("/tmp/test_input.jsonl", "w") as f:
117 |         f.write(json.dumps({"submission_id_v0": "test", "input": slow_code}) + "\n")
118 |     results = run_over_slow_programs(
119 |         slow_programs_file="/tmp/test_input.jsonl", max_attempts=3,
120 |         outfile="/tmp/test.jsonl", feedback_type="default", temperature=0.0,
121 |     )
122 |     for attempt in results[0]["run_logs"]:
123 |         print(f"Slow code:\n{attempt['slow_code']}\nFeedback: {attempt['feedback']}")
124 |         print(f"Fast code:\n{attempt['fast_code']}\n")
125 |
126 | if __name__ == "__main__":
127 | import sys
128 |
129 | if sys.argv[1] == "test":
130 | test()
131 | else:
132 | import argparse
133 | import os
134 | args = argparse.ArgumentParser()
135 | args.add_argument("--slow_programs_file", type=str, required=True)
136 | args.add_argument("--max_attempts", type=int, default=3)
137 | args.add_argument("--outfile", type=str, required=True)
138 | args.add_argument("--feedback_type", type=str)
139 | args.add_argument("--temperature", type=float, default=0.0)
140 | args.add_argument("--backup_file", type=str)
141 |
142 | args = args.parse_args()
143 | args.outfile = f"{args.outfile}.fb_{args.feedback_type}.temp_{args.temperature}.engine_{ENGINE}.jsonl"
144 | if os.path.exists(args.outfile):
145 |             # outfile exists: version the new output with a .v{N} suffix
146 |             v = 0
147 |             while os.path.exists(args.outfile + f".v{v}"):
148 |                 v += 1
149 |             print(f"Output file {args.outfile} already exists. Adding a suffix to it (v{v})")
150 |             args.outfile = args.outfile + f".v{v}"
151 | run_over_slow_programs(slow_programs_file=args.slow_programs_file, max_attempts=args.max_attempts, outfile=args.outfile, feedback_type=args.feedback_type, temperature=args.temperature, backup_file=args.backup_file)
--------------------------------------------------------------------------------
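Note: a minimal usage sketch (not part of the repo) for driving the PIE loop programmatically; the input file and column names follow what run_over_slow_programs actually reads (submission_id_v0 and input), and OPENAI_API_KEY is assumed to be set:

    from src.pie.run import run_over_slow_programs

    results = run_over_slow_programs(
        slow_programs_file="data/tasks/pie/codenet-python-test-1k.jsonl",
        max_attempts=3,
        outfile="/tmp/pie_refined.jsonl",
        feedback_type="default",   # "naive" and "none" select the ablation prompts
        temperature=0.7,
    )
--------------------------------------------------------------------------------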
/src/pie/task_init.py:
--------------------------------------------------------------------------------
1 | import pandas as pd
2 | from src.utils import Prompt
3 |
4 | from prompt_lib.backends import openai_api
5 |
6 |
7 | class PieInit(Prompt):
8 | def __init__(self, prompt_examples: str, engine: str, temperature: float) -> None:
9 | super().__init__(
10 | question_prefix="# slower version:\n",
11 | answer_prefix="# optimized version of the same code:\n",
12 | intra_example_sep="\n\n\n",
13 |             inter_example_sep="\n\n### END ###\n\n",
14 | )
15 | self.engine = engine
16 | self.temperature = temperature
17 | self.setup_prompt_from_examples_file(prompt_examples)
18 |
19 | def setup_prompt_from_examples_file(self, prompt_examples) -> str:
20 | with open(prompt_examples, "r") as f:
21 | self.prompt = f.read()
22 |
23 | def make_query(self, slow_code: str) -> str:
24 | slow_code = slow_code.strip()
25 | query = f"{self.prompt}{self.question_prefix}{slow_code}{self.intra_example_sep}{self.answer_prefix}"
26 | return query
27 |
28 | def __call__(self, slow_code: str) -> str:
29 | generation_query = self.make_query(slow_code)
30 | output = openai_api.OpenaiAPIWrapper.call(
31 | prompt=generation_query,
32 | engine=self.engine,
33 | max_tokens=300,
34 | stop_token="### END",
35 | temperature=self.temperature,
36 | )
37 |
38 | generated_code = openai_api.OpenaiAPIWrapper.get_first_response(output)
39 |
40 | # if by chance the end token is present in the generated code, remove it
41 | if "### END" in generated_code:
42 | generated_code = generated_code.split("### END")[0]
43 | return generated_code.strip()
44 |
45 |
46 | def test():
47 | task_init = PieInit(
48 | prompt_examples="data/prompt/pie/init.txt",
49 | engine="gpt-3.5-turbo",
50 | temperature=0.0
51 | )
52 |
53 | slow_code = """
54 | def sum(n):
55 | res = 0
56 | for i in range(n):
57 | res += i
58 | return res
59 | """
60 | 
61 | print(task_init.prompt)
62 | print(task_init(slow_code))
63 |
64 |
65 | if __name__ == "__main__":
66 | test()
--------------------------------------------------------------------------------
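Note: an illustrative sketch (not from the repo) of the query PieInit.make_query assembles from its prefixes and separators; the prompt file path matches the repo layout:

    prompt = open("data/prompt/pie/init.txt").read()  # few-shot examples
    slow_code = "def sum(n):\n    return sum(range(n))"
    query = (
        f"{prompt}"
        "# slower version:\n"       # question_prefix
        f"{slow_code.strip()}"
        "\n\n\n"                    # intra_example_sep
        "# optimized version of the same code:\n"  # answer_prefix
    )
    # the model completes after the answer prefix and stops at "### END"
--------------------------------------------------------------------------------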
/src/pie/task_iterate.py:
--------------------------------------------------------------------------------
1 | import sys
2 | from typing import Dict, List
3 | from src.utils import Prompt
4 |
5 | from prompt_lib.backends import openai_api
6 |
7 |
8 | class PieIterate(Prompt):
9 | def __init__(self, engine: str, prompt_examples: str, temperature: float, feedback_type: str = "default") -> None:
10 | super().__init__(
11 | question_prefix="",
12 | answer_prefix="# Improved version:\n",
13 | intra_example_sep="\n\n",
14 |             inter_example_sep="\n\n### END ###\n\n",
15 | )
16 | self.engine = engine
17 | self.count = 0
18 | self.temperature = temperature
19 | self.make_prompt(prompt_examples=prompt_examples)
20 | self.feedback_type = feedback_type
21 |
22 | def make_prompt(self, prompt_examples: str) -> str:
23 | with open(prompt_examples, "r") as f:
24 |             self.prompt = f.read()
25 | 
26 | 
27 |
28 | def __call__(
29 | self,
30 | slow_code: str,
31 | feedback: str,
32 | ) -> str:
33 | generation_query = self.make_query(slow_code=slow_code, feedback=feedback)
34 |
35 | output = openai_api.OpenaiAPIWrapper.call(
36 | prompt=generation_query,
37 | engine=self.engine,
38 | max_tokens=300,
39 | stop_token="### END",
40 | temperature=self.temperature,
41 | )
42 | generated_code = openai_api.OpenaiAPIWrapper.get_first_response(output)
43 |
44 |
45 | if "### END" in generated_code:
46 | generated_code = generated_code.split("### END")[0]
47 | return generated_code.strip()
48 |
49 |
50 |
51 | def make_query(self, slow_code: str, feedback: str) -> str:
52 | instr = "# Why is this code slow?" if self.feedback_type == "default" else "# How to improve this code?"
53 | example_template = """{slow_code}
54 |
55 | {instr}
56 |
57 | {feedback}
58 |
59 | # Improved version:
60 |
61 | """
62 |         query = example_template.format(slow_code=slow_code, instr=instr, feedback=feedback)
63 |
64 | return f"{self.prompt}{query}"
65 |
66 |
67 | def test():
68 | task_iterate = PieIterate(
69 | prompt_examples="data/prompt/pie/iterate.txt",
70 | engine="gpt-3.5-turbo",
71 | temperature=0.6
72 | )
73 |
74 |     slow_code = "def sum(n):\n    res = 0\n    for i in range(n):\n        res += i\n    return res"
75 | feedback = "# This code is slow because it is using a brute force approach to calculate the sum of numbers up to n. It is looping through every number from 0 to n and adding it to the sum. This is a very slow approach because it has to loop through so many numbers. A better approach would be to use the formula for the sum of the numbers from 0 to n, which is (n(n+1))/2. Using this formula, you can calculate the sum of the numbers in constant time."
76 | # print(task_iterate.prompt)
77 | print(task_iterate(slow_code, feedback))
78 |
79 | if __name__ == '__main__':
80 | test()
--------------------------------------------------------------------------------
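Note: with the instr key supplied to format (fixed above), the refinement query interleaves the code, the feedback question, the feedback itself, and the "# Improved version:" answer slot. An illustrative rendering (values invented for this sketch):

    slow_code = "def sum(n):\n    res = 0\n    for i in range(n):\n        res += i\n    return res"
    feedback = "# This code is slow because it loops over every number up to n."
    instr = "# Why is this code slow?"
    query = f"{slow_code}\n\n{instr}\n\n{feedback}\n\n# Improved version:\n\n"
--------------------------------------------------------------------------------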
/src/readability/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/madaan/self-refine/9a206d41e5d2d0c241bb441f41eeadb945afaa55/src/readability/__init__.py
--------------------------------------------------------------------------------
/src/readability/count_comment.py:
--------------------------------------------------------------------------------
1 | import json
2 | import tokenize
3 |
4 | from tqdm import tqdm
5 | from io import BytesIO
6 | from argparse import ArgumentParser
7 |
8 |
9 | def count_comments(code):
10 | comment_count = 0
11 | total_lines = len([l for l in code.splitlines() if l.strip()])
12 |
13 | tokens = tokenize.tokenize(BytesIO(code.encode('utf-8')).readline)
14 | for token in tokens:
15 | if token.type == tokenize.COMMENT:
16 | comment_count += 1
17 | return comment_count, comment_count / total_lines
18 |
19 | def main():
20 | parser = ArgumentParser()
21 | parser.add_argument('--file', type=str)
22 | args = parser.parse_args()
23 |
24 | output = args.file[:-6] + '_comment_count.jsonl'
25 |
26 | score_counter = None
27 |
28 | with open(output, 'w') as fout:
29 | input_lines = open(args.file).readlines()
30 | for line in tqdm(input_lines, total=len(input_lines)):
31 | data = json.loads(line)
32 | original_code = data['original_code']
33 | updated_codes = [x['updated_code'] for x in data['updates']]
34 |
35 | if score_counter is None:
36 | score_counter = [[] for _ in range(len(updated_codes) + 1)]
37 |
38 | data['update_comment_count'] = []
39 | data['update_comment_ratio'] = []
40 |
41 | for i, code in enumerate([original_code] + updated_codes):
42 | num_comments, ratio = -1, -1
43 |
44 | if code:
45 | try:
46 | num_comments, ratio = count_comments(code)
47 |                     except Exception:
48 |                         pass  # leave counts at -1 for unparsable code
49 |
50 | data['update_comment_count'].append(num_comments)
51 | data['update_comment_ratio'].append(ratio)
52 |
53 | score_counter[i].append(max(0, ratio))
54 | fout.write(json.dumps(data) + '\n')
55 |
56 | for i, scores in enumerate(score_counter):
57 | print(f'Update {i}: avg comment ratio {sum(scores) / len(scores)}')
58 |
59 | if __name__ == '__main__':
60 | main()
61 |
62 |
--------------------------------------------------------------------------------
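Note: a quick sanity check (not in the repo) of count_comments on a toy snippet; tokenize yields one COMMENT token over two non-empty lines:

    from src.readability.count_comment import count_comments

    code = "x = 1  # set x\nprint(x)\n"
    num, ratio = count_comments(code)
    assert (num, ratio) == (1, 0.5)
--------------------------------------------------------------------------------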
/src/readability/count_function.py:
--------------------------------------------------------------------------------
1 | import json
2 | import ast
3 |
4 | from argparse import ArgumentParser
5 | from tqdm import tqdm
6 |
7 | def count_functions(code):
8 | tree = ast.parse(code)
9 | num_functions = sum(isinstance(node, ast.FunctionDef) for node in ast.walk(tree))
10 | return num_functions
11 |
12 | def main():
13 | parser = ArgumentParser()
14 | parser.add_argument('--file', type=str)
15 | args = parser.parse_args()
16 |
17 | output = args.file[:-6] + '_func_count.jsonl'
18 |
19 | score_counter = None
20 |
21 | with open(output, 'w') as fout:
22 | input_lines = open(args.file).readlines()
23 | for line in tqdm(input_lines, total=len(input_lines)):
24 | data = json.loads(line)
25 | original_code = data['original_code']
26 | updated_codes = [x['updated_code'] for x in data['updates']]
27 |
28 | if score_counter is None:
29 | score_counter = [[] for _ in range(len(updated_codes) + 1)]
30 |
31 | data['update_func_count'] = []
32 |
33 | for i, code in enumerate([original_code] + updated_codes):
34 | num_functions = -1
35 |
36 | if code:
37 | try:
38 | num_functions = count_functions(code)
39 |                     except Exception:
40 |                         pass  # leave count at -1 for unparsable code
41 |
42 | data['update_func_count'].append(num_functions)
43 | score_counter[i].append(max(0, num_functions))
44 | fout.write(json.dumps(data) + '\n')
45 |
46 | for i, scores in enumerate(score_counter):
47 | print(f'Update {i}: avg function number {sum(scores) / len(scores)}')
48 |
49 | if __name__ == '__main__':
50 | main()
--------------------------------------------------------------------------------
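Note: count_functions counts ast.FunctionDef nodes anywhere in the tree, so nested and top-level definitions both count; a quick check (not in the repo):

    from src.readability.count_function import count_functions

    code = "def f():\n    def g():\n        pass\n    return g\n"
    assert count_functions(code) == 2
--------------------------------------------------------------------------------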
/src/readability/count_meaningful_var.py:
--------------------------------------------------------------------------------
1 | import json
2 | from tqdm import tqdm
3 | from argparse import ArgumentParser
4 |
5 | from src.readability.utils import call_gpt
6 | from src.readability.prompts import COUNT_VAR_PROMPT
7 |
8 | def count_meaningful_vars(code):
9 | if 'Fixed Code:' in code:
10 | code = code.split('Fixed Code:')[1]
11 |
12 | code = code.strip()
13 | prompt = COUNT_VAR_PROMPT.format(code=code)
14 | result = call_gpt(prompt, model='code-davinci-002', max_tokens=256, stop='\n\n\n')[0]
15 |
16 | result = result.strip().splitlines()
17 | num_vars = len(result)
18 | num_random_vars = sum([1 for line in result if line.endswith('- random')])
19 | num_meaningful_vars = num_vars - num_random_vars
20 |
21 | return num_meaningful_vars, num_meaningful_vars / num_vars, result
22 |
23 |
24 | def main():
25 | parser = ArgumentParser()
26 | parser.add_argument('--file', type=str)
27 | args = parser.parse_args()
28 | output = args.file[:-6] + '_var_count.jsonl'
29 | score_counter = None
30 |
31 | with open(output, 'w') as fout:
32 | input_lines = open(args.file).readlines()
33 | for line in tqdm(input_lines):
34 | data = json.loads(line)
35 | original_code = data['original_code']
36 | updated_codes = [x['updated_code'] for x in data['updates']]
37 |
38 | if score_counter is None:
39 | score_counter = [[] for _ in range(len(updated_codes) + 1)]
40 |
41 | data['update_meaningful_var_count'] = []
42 | data['update_meaningful_var_ratio'] = []
43 |
44 | for i, code in enumerate([original_code] + updated_codes):
45 |                 num_meaningful_vars, meaningful_var_ratio, gpt_gen = -1, -1, None
46 | if code:
47 | num_meaningful_vars, meaningful_var_ratio, gpt_gen = count_meaningful_vars(code)
48 | data['update_meaningful_var_count'].append(num_meaningful_vars)
49 | data['update_meaningful_var_ratio'].append(meaningful_var_ratio)
50 |
51 | score_counter[i].append(max(0, meaningful_var_ratio))
52 | fout.write(json.dumps(data) + '\n')
53 |
54 | for i, scores in enumerate(score_counter):
55 | print(f'Update {i}: avg meaningful var ratio {sum(scores) / len(scores)}')
56 |
57 |
58 | if __name__ == '__main__':
59 | main()
--------------------------------------------------------------------------------
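Note: the parse above assumes the model lists one variable per line, tagging throwaway names with a trailing "- random". An illustrative run of that parsing logic (model output mocked, not a real API call):

    result = ["x - random", "total_price - meaningful", "i - random"]
    num_vars = len(result)
    num_random_vars = sum(1 for line in result if line.endswith("- random"))
    num_meaningful_vars = num_vars - num_random_vars
    assert (num_meaningful_vars, num_meaningful_vars / num_vars) == (1, 1 / 3)
--------------------------------------------------------------------------------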
/src/readability/prompts.py:
--------------------------------------------------------------------------------
1 | COUNT_VAR_PROMPT = '''
2 | """CODE SNIPPET"""
3 | import sys
4 | # Reads input from terminal and returns it
5 | def input():
6 | return sys.stdin.readline().strip()
7 | # finds a largest perfect power which is smaller
8 | # than the input number
9 | def resolve():
10 | # read input
11 | x=int(eval(input()))
12 | ans=0
13 | for i in range(1,33):
14 | for j in range(2,11):
15 | y=i**j
16 |             if y<x:
[remainder of /src/readability/prompts.py is truncated in the source dump]
--------------------------------------------------------------------------------
/src/readability/utils.py:
--------------------------------------------------------------------------------
[only the tail of call_gpt survives in the source dump]
27 |                 return completions[:num_completions]
28 |         except openai.error.RateLimitError as e:
29 |             time.sleep(min(i**2, 60))
30 |     raise RuntimeError('Failed to call GPT API')
--------------------------------------------------------------------------------
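Note: only the last four lines of call_gpt survive in this dump. A minimal reconstruction consistent with that tail and with how count_meaningful_var.py calls it, assuming the legacy openai<1.0 Completion API (everything beyond the surviving lines is an assumption):

    import time
    import openai

    def call_gpt(prompt, model="code-davinci-002", max_tokens=256, stop=None,
                 temperature=0.0, num_completions=1, max_retries=10):
        completions = []
        for i in range(max_retries):
            try:
                response = openai.Completion.create(
                    engine=model,
                    prompt=prompt,
                    max_tokens=max_tokens,
                    stop=stop,
                    temperature=temperature,
                    n=num_completions - len(completions),
                )
                completions.extend(choice["text"] for choice in response["choices"])
                if len(completions) >= num_completions:
                    return completions[:num_completions]
            except openai.error.RateLimitError:
                time.sleep(min(i**2, 60))
        raise RuntimeError('Failed to call GPT API')
--------------------------------------------------------------------------------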
/src/responsegen/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/madaan/self-refine/9a206d41e5d2d0c241bb441f41eeadb945afaa55/src/responsegen/__init__.py
--------------------------------------------------------------------------------
/src/responsegen/feedback.py:
--------------------------------------------------------------------------------
1 | import pandas as pd
2 | from prompt_lib.backends import openai_api
3 |
4 | from src.utils import Prompt
5 |
6 |
7 | class ResponseGenFeedback(Prompt):
8 | def __init__(self, engine: str, prompt_examples: str, max_tokens: int = 400) -> None:
9 | super().__init__(
10 | question_prefix="",
11 | answer_prefix="",
12 | intra_example_sep="\n\n",
13 | inter_example_sep="\n\n###\n\n",
14 | )
15 | self.engine = engine
16 | self.max_tokens = max_tokens
17 | self.setup_prompt_from_examples_file(prompt_examples)
18 |
19 | def setup_prompt_from_examples_file(self, examples_path: str) -> str:
20 | template = """Conversation history:
21 |
22 | {history}
23 |
24 | Response: {response}
25 |
26 | Scores:
27 |
28 | * Relevant: {Relevant}
29 | * Informative: {Informative}
30 | * Interesting: {Interesting}
31 | * Consistent: {Consistent}
32 | * Helpful: {Helpful}
33 | * Engaging : {Engaging}
34 | * Specific: {Specific}
35 | * Safe: {Safe}
36 | * User understanding: {Userunderstanding}
37 | * Fluent: {Fluent}
38 | * Total score: {total_score}"""
39 | examples_df = pd.read_json(examples_path, orient="records")
40 | prompt = []
41 | for _, row in examples_df.iterrows():
42 | prompt.append(
43 | template.format(
44 | history=row['history'].replace('System: ', '').replace('User: ', ''),
45 | response=row["response"],
46 | Relevant=row["Relevant"],
47 | Informative=row["Informative"],
48 | Interesting=row["Interesting"],
49 | Consistent=row["Consistent"],
50 | Helpful=row["Helpful"],
51 | Engaging=row["Engaging"],
52 | Specific=row["Specific"],
53 | Safe=row["Safe"],
54 | Userunderstanding=row["Userunderstanding"],
55 | Fluent=row["Fluent"],
56 | total_score=row["total_score"],
57 | )
58 | )
59 |
60 | instruction = """We want to iteratively improve the provided responses. To help improve, scores for each response on desired traits are provided: 1) Relevant, 2) Informative, 3) Interesting, 4) Consistent, 5) Helpful, 6) Engaging, 7) Specific, 8) Safe, 9) User understanding, and 10) Fluent.
61 |
62 | Here are some examples of this scoring rubric:
63 |
64 | """
65 |         # prepend the scoring instruction and end with the example separator
66 |         self.prompt = instruction + self.inter_example_sep.join(prompt) + self.inter_example_sep
67 |
68 | def __call__(self, context: str, response: str):
69 | prompt = self.get_prompt_with_question(context=context, response=response)
70 |
71 | output = openai_api.OpenaiAPIWrapper.call(
72 | prompt=prompt,
73 | engine=self.engine,
74 | max_tokens=self.max_tokens,
75 | stop_token="###",
76 | temperature=0.7,
77 | )
78 |
79 | generated_feedback = openai_api.OpenaiAPIWrapper.get_first_response(output)
80 | generated_feedback = generated_feedback.split("Scores:")[1].strip()
81 | generated_feedback = generated_feedback.split("#")[0].strip()
82 |
83 | return output, generated_feedback
84 |
85 | def get_prompt_with_question(self, context: str, response: str):
86 | context = context.replace('System: ', '').replace('User: ', '')
87 | question = self.make_query(context=context, response=response)
88 | return f"""{self.prompt}{question}\n\n"""
89 |
90 | def make_query(self, context: str, response: str):
91 | question = f"""Conversation history:
92 |
93 | {context}
94 |
95 | Response: {response}"""
96 | return question
97 |
--------------------------------------------------------------------------------
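Note: the feedback string returned above is the block between "Scores:" and the next "#", so downstream code can parse the total with a regex. A small illustration (not from the repo):

    import re

    scores = "* Relevant: 3/3\n* Informative: 2/3\n* Total score: 25/30"
    total = int(re.search(r"Total score: (\d+)/(\d+)", scores).group(1))
    assert total == 25
--------------------------------------------------------------------------------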
/src/responsegen/run.py:
--------------------------------------------------------------------------------
1 | import re
2 | import sys
3 | import math
4 | import os
5 | import tqdm
6 | from typing import Any, Dict, List
7 | import pandas as pd
8 | import json
9 | from tqdm import tqdm
10 | from pandarallel import pandarallel
11 | import multiprocessing
12 | import traceback
13 | import argparse
14 |
15 | pandarallel.initialize(progress_bar=True, nb_workers=25)
16 |
17 |
18 | from src.responsegen.task_init import ResponseGenTaskInit
19 | from src.responsegen.task_iterate import ResponseGenTaskIterate
20 | from src.responsegen.feedback import ResponseGenFeedback
21 | from src.utils import retry_parse_fail_prone_cmd
22 |
23 | import openai
24 | import random
25 | import time
26 |
27 | openai.api_key = os.getenv("OPENAI_API_KEY")
28 |
29 | # check if the organization is set
30 |
31 | if os.getenv("OPENAI_ORG") is not None:
32 | openai.organization = os.getenv("OPENAI_ORG")
33 |
34 | CODEX = "code-davinci-002"
35 | GPT3 = "text-davinci-003"
36 | # select the engine to use
37 | ENGINE = GPT3
38 |
39 | @retry_parse_fail_prone_cmd
40 | def iterative_response(context: str, max_attempts: int) -> str:
41 |
42 | # initialize all the required components
43 |
44 | # generation of the first response
45 | task_init = ResponseGenTaskInit(engine=ENGINE, prompt_examples="data/prompt/responsegen/init.jsonl")
46 |
47 | # getting feedback
48 | task_feedback = ResponseGenFeedback(engine=ENGINE, prompt_examples="data/prompt/responsegen/feedback.jsonl")
49 |
50 | # iteratively improving the response
51 | task_iterate = ResponseGenTaskIterate(engine=ENGINE, prompt_examples="data/prompt/responsegen/feedback.jsonl")
52 |
53 |
54 |     # initialize the refinement loop state
55 |
56 | n_attempts = 0
57 |
58 | responses_to_scores = dict()
59 |
60 | all_responses_to_scores = dict()
61 | best_score_so_far = 0
62 | reduce_window = 0
63 | while n_attempts < max_attempts:
64 |
65 | if n_attempts == 0:
66 | metaoutput, response = task_init(context=context)
67 | else:
68 | metaoutput, response = task_iterate(responses_to_scores=responses_to_scores, reduce_window=reduce_window)
69 | 
70 | 
71 |
72 | print(f"\n{n_attempts} CONTEXT> {context} \n\n RESPONSE> {response} - NTOKENS> {metaoutput['usage']['total_tokens']}")
73 |
74 |         if metaoutput['usage']['total_tokens'] > 3000:
75 |             reduce_window += 1
76 |         if metaoutput['usage']['total_tokens'] > 3500:
77 |             reduce_window += 1
78 |
79 | feedbackmetaoutput, scores = task_feedback(context=context, response=response)
80 | print(f"\n{n_attempts} SCORES> {scores} - NTOKENS> {feedbackmetaoutput['usage']['total_tokens']}")
81 |
82 |         match = re.search(r"Total score: (\d+)/(\d+)", scores)
83 |         total_score = int(match.group(1))
84 |
85 | all_responses_to_scores[response] = {
86 | "n_attempts": n_attempts,
87 | "scores": scores,
88 | "total_score": total_score,
89 | "context": context,
90 | }
91 | 
92 |         if total_score >= best_score_so_far:  # only keep iterating on responses that are not getting worse
93 | best_score_so_far = total_score
94 |
95 | responses_to_scores[response] = (context, scores)
96 |
97 |
98 | else:
99 | print(f"Score of {response} is {total_score}, which is less than the current best of {best_score_so_far}")
100 |
101 | n_attempts += 1
102 | return all_responses_to_scores
103 |
104 |
105 |
106 | def run_dataset(max_attempts: int, outfile: str, max_size: int = 1):
107 |
108 | f = open('data/prompt/responsegen/fed_data.json')
109 | data = json.load(f)
110 | print('len of data', len(data))
111 |     count = 0
112 | outwriter = open(outfile, 'a')
113 |
114 |     for i, example in enumerate(data):
115 |         if max_size != 0 and count > max_size: break
116 |         print(f"\n\n\n****Instance: {i}****\n\n")
117 |         if 'response' not in example: continue
118 | try:
119 | context = example["context"]
120 | if type(example["context"]) is str:
121 | context = example["context"].split("\n")
122 | if type(context) is list:
123 | context = "\n".join(context[-8:])
124 | all_responses_to_scores = iterative_response(context, max_attempts=max_attempts)
125 |             if all_responses_to_scores is None:
126 |                 continue  # retries exhausted for this example; skip it
127 |
128 | res = []
129 | scored_responses = {}
130 | for response, scores in all_responses_to_scores.items():
131 | res.append(f"{response} [score: {scores['total_score']}] \n {scores['scores']}")
132 | scored_responses[scores['n_attempts']]={'response':response, 'total_score':scores['total_score']}
133 | # append res to example
134 | example['generated_responses'] = "\n------\n".join(res)
135 | example['scored_responses'] = scored_responses
136 | outwriter.write(json.dumps(example)+'\n')
137 | print("\n ------ \n ".join(res))
138 | except Exception as e:
139 | print(f"error in {example}\n\n{e}", file=sys.stderr)
140 | traceback.print_exc()
141 | return {"result": ["FAILED"]}
142 | count+=1
143 | outwriter.close()
144 |
145 |
146 | if __name__ == "__main__":
147 | parser = argparse.ArgumentParser()
148 | parser.add_argument(
149 | "--max_attempts",
150 | type=int,
151 | default=3,
152 | help="Max attempts",
153 | )
154 | parser.add_argument(
155 | "--size",
156 | type=int,
157 | default=1,
158 | help="Test data size (0 means all data)",
159 | )
160 | parser.add_argument(
161 | "--output",
162 | type=str,
163 | default='./output-v3fedresponsegen406on.json',
164 | # required=True,
165 | help="Output file",
166 | )
167 |
168 | args = parser.parse_args()
169 |
170 | run_dataset(args.max_attempts, outfile=args.output, max_size=args.size)
--------------------------------------------------------------------------------
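Note: a minimal sketch (not in the repo) for refining a single dialogue turn without the FED dataset driver; iterative_response returns a dict mapping each generated response to its scores, or None if all retries failed:

    from src.responsegen.run import iterative_response

    context = "Hi!\nHello, how can I help you today?"
    all_responses_to_scores = iterative_response(context, max_attempts=3)
    if all_responses_to_scores:
        best = max(all_responses_to_scores.items(), key=lambda kv: kv[1]["total_score"])
        print(best[0], best[1]["total_score"])
--------------------------------------------------------------------------------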
/src/responsegen/task_init.py:
--------------------------------------------------------------------------------
1 | import pandas as pd
2 | from src.utils import Prompt
3 | from typing import List, Optional, Union
4 | import sys
5 | from prompt_lib.backends import openai_api
6 |
7 |
8 | class ResponseGenTaskInit(Prompt):
9 | def __init__(self, prompt_examples: str, engine: str, numexamples=3) -> None:
10 | super().__init__(
11 | question_prefix="Conversation history: ",
12 | answer_prefix="Response: ",
13 | intra_example_sep="\n\n",
14 | inter_example_sep="\n\n###\n\n",
15 | )
16 | self.engine = engine
17 | self.setup_prompt_from_examples_file(prompt_examples, numexamples=numexamples)
18 |
19 | def setup_prompt_from_examples_file(self, examples_path: str, numexamples=10) -> str:
20 | instruction = (
21 | "Provided a dialogue between two speakers, generate a response that is coherent with the dialogue history. Desired traits for responses are: 1) Relevant - The response addresses the context, 2) Informative - The response provides some information, 3) Interesting - The response is not interesting, 4) Consistent - The response is consistent with the rest of the conversation in terms of tone and topic, 5) Helpful - The response is helpful in providing any information or suggesting any actions, 6) Engaging - The response is not very engaging and does not encourage further conversation, 7) Specific - The response contains pecific content, 9) User understanding - The response demonstrates an understanding of the user's input and state of mind, and 10) Fluent. Response should begin with - Response:\n\n"
22 | )
23 |
24 | examples_df = pd.read_json(examples_path, orient="records", lines=True)
25 | prompt = []
26 | for i, row in examples_df.iterrows():
27 | if i >= numexamples:
28 | break
29 | prompt.append(self._build_query_from_example(row["history"], row["response"]))
30 |
31 | self.prompt = instruction + self.inter_example_sep.join(prompt) + self.inter_example_sep
32 |
33 | def _build_query_from_example(self, history: Union[str, List[str]], response: Optional[str]=None) -> str:
34 | history = history.replace('System: ', '').replace('User: ', '')
35 |
36 | TEMPLATE = """Conversation history:
37 |
38 | {history}
39 |
40 | Response: {response}"""
41 |
42 | query = TEMPLATE.format(history=history, response=response)
43 | return query
44 |
45 | def make_query(self, context: str) -> str:
46 | context = context.replace('System: ', '').replace('User: ', '')
47 | query = f"{self.prompt}{self.question_prefix}\n\n{context}{self.intra_example_sep}"
48 | return query
49 |
50 | def __call__(self, context: str) -> str:
51 | generation_query = self.make_query(context)
52 | output = openai_api.OpenaiAPIWrapper.call(
53 | prompt=generation_query,
54 | engine=self.engine,
55 | max_tokens=800,
56 | stop_token="###",
57 | temperature=0.7,
58 | )
59 |
60 | generated_response = openai_api.OpenaiAPIWrapper.get_first_response(output)
61 |
62 | generated_response = generated_response.split(self.answer_prefix)[1].replace("#", "").strip()
63 |
64 |
65 | return output, generated_response.strip()
66 |
--------------------------------------------------------------------------------
/src/responsegen/task_iterate.py:
--------------------------------------------------------------------------------
1 | import sys
2 | from typing import Dict, List
3 | from src.utils import Prompt
4 |
5 | from prompt_lib.backends import openai_api
6 |
7 |
8 | class ResponseGenTaskIterate(Prompt):
9 | def __init__(self, engine: str, prompt_examples: str) -> None:
10 | super().__init__(
11 | question_prefix="",
12 | answer_prefix="",
13 | intra_example_sep="\n\n",
14 | inter_example_sep="\n\n###\n\n",
15 | )
16 | self.engine = engine
17 | self.count = 0
18 | self.prompt = self.make_prompt(prompt_examples=prompt_examples)
19 |
20 | def make_prompt(self, prompt_examples: str, reduce_window=0) -> str:
21 | import pandas as pd
22 | prompt_examples = pd.read_json(prompt_examples, orient="records")
23 | prompt_examples = prompt_examples[reduce_window:]
24 | # group on example
25 | grouped = prompt_examples.groupby("example")
26 |
27 | prompt = []
28 | # sort each group by score
29 | for _, group in grouped:
30 | group["numerical_score"] = group["total_score"].apply(lambda x: int(x.split("/")[0].strip()))
31 | group = group.sort_values("numerical_score")
32 | prompt.append(self.make_one_iterate_example(group.to_dict("records")))
33 |
34 | return self.inter_example_sep.join(prompt) + self.inter_example_sep
35 |
36 |
37 | def make_one_iterate_example(self, incrementally_improving_examples: List[Dict]):
38 | """Given a list of examples that are incrementally improving, return a new example.
39 | """
40 |
41 | instr = """We want to iteratively improve the provided responses. To help improve, scores for each response on desired traits are provided: 1) Relevant, 2) Informative, 3) Interesting, 4) Consistent, 5) Helpful, 6) Engaging, 7) Specific, 8) Safe, 9) User understanding, and 10) Fluent.
42 |
43 | """
44 | template = """Conversation history:
45 |
46 | {history}
47 |
48 | Response: {response}
49 |
50 | Scores:
51 |
52 | * Relevant: {Relevant}
53 | * Informative: {Informative}
54 | * Interesting: {Interesting}
55 | * Consistent: {Consistent}
56 | * Helpful: {Helpful}
57 | * Engaging : {Engaging}
58 | * Specific: {Specific}
59 | * Safe: {Safe}
60 | * User understanding: {Userunderstanding}
61 | * Fluent: {Fluent}
62 |
63 | * Total score: {total_score}
64 |
65 | Okay, let's use this feedback to improve the response.
66 |
67 | """
68 | prompt = []
69 | for row in incrementally_improving_examples:
70 | prompt.append(
71 | template.format(
72 | history=row['history'].replace('System: ', '').replace('User: ', ''),
73 | response=row["response"],
74 | Relevant=row["Relevant"],
75 | Informative=row["Informative"],
76 | Interesting=row["Interesting"],
77 | Consistent=row["Consistent"],
78 | Helpful=row["Helpful"],
79 | Engaging=row["Engaging"],
80 | Specific=row["Specific"],
81 | Safe=row["Safe"],
82 | Userunderstanding=row["Userunderstanding"],
83 | Fluent=row["Fluent"],
84 | total_score=row["total_score"],
85 | )
86 | )
87 |
88 |
89 | prompt = "".join(prompt)
90 | prompt = instr + prompt
91 | return prompt.strip()
92 |
93 | def make_query(self, question: str, reduce_window=0) -> str:
94 |         if reduce_window > 0:
95 | self.prompt = self.make_prompt(prompt_examples="data/prompt/responsegen/feedback.jsonl", reduce_window=reduce_window)
96 | question = question.replace('System: ', '').replace('User: ', '')
97 | return f"{self.prompt}{self.question_prefix}{question}{self.intra_example_sep}{self.answer_prefix}"
98 | # return super().make_query(prompt, question)
99 |
100 | def _make_input(
101 | self,
102 | context: str,
103 | response: str,
104 | scores: str,
105 | ) -> str:
106 | context = context.replace('System: ', '').replace('User: ', '')
107 | input_txt = f"""Conversation history:
108 |
109 | {context}
110 |
111 | Response: {response}
112 |
113 | Scores:
114 |
115 | {scores}
116 |
117 | Okay, let's use this feedback to improve the response.
118 |
119 | Conversation history:
120 |
121 | {context}
122 | """
123 |
124 | return input_txt
125 |
126 | def __call__(
127 | self,
128 | responses_to_scores: Dict[str, str],
129 | reduce_window=0
130 | ) -> str:
131 | example_input = self.make_input(
132 | responses_to_scores=responses_to_scores
133 | )
134 | transfer_query = self.make_query(example_input, reduce_window=reduce_window)
135 | self.count += 1
136 | with open(f"responses_iterate_{self.count}.txt", "w") as f:
137 | f.write(transfer_query + "\n")
138 | output = openai_api.OpenaiAPIWrapper.call(
139 | prompt=transfer_query,
140 | engine=self.engine,
141 | max_tokens=200,
142 | stop_token=self.inter_example_sep,
143 | temperature=0.7,
144 | )
145 | modelresponse = openai_api.OpenaiAPIWrapper.get_first_response(output)
146 | response = modelresponse.split("Response:")[1].strip().split("\n")[0].strip()
147 |
148 |
149 | return output, response.strip()
150 |
151 | def make_input(
152 | self,
153 | responses_to_scores: Dict[str, str],
154 | ) -> str:
155 | input_txt = ""
156 | for response, (context, scores) in responses_to_scores.items():
157 | context = context.replace('System: ', '').replace('User: ', '')
158 | input_txt += self._make_input(
159 | context=context,
160 | response=response,
161 | scores=scores,
162 | )
163 | return input_txt
164 |
165 |
166 |
167 |
168 |
169 |
170 | if __name__ == "__main__":
171 |     obj = ResponseGenTaskIterate(prompt_examples="data/prompt/responsegen/feedback.jsonl", engine="whatever")
172 | print(obj.prompt)
--------------------------------------------------------------------------------
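Note: for reference, the shape of responses_to_scores consumed by ResponseGenTaskIterate.__call__, as built in run.py (values illustrative):

    responses_to_scores = {
        "Sure, what do you need help with?": (
            "Hi!\nHello, how can I help you today?",   # context
            "* Relevant: 3/3\n* Total score: 25/30",   # scores block from ResponseGenFeedback
        ),
    }
--------------------------------------------------------------------------------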
/src/sentiment_reversal/measure.py:
--------------------------------------------------------------------------------
1 | from src.utils import Prompt
2 | from prompt_lib.backends import router
3 |
4 | class SentimentTransferMeasurement(Prompt):
5 | def __init__(self, engine) -> None:
6 |
7 | super().__init__(
8 | question_prefix="Review: ",
9 | answer_prefix="Output: ",
10 | intra_example_sep="\n",
11 | inter_example_sep="\n\n\n",
12 | )
13 | self.engine = engine
14 | self.final_answer_prefix = "The sentiment is "
15 | self.load_prompts()
16 |
17 | def load_prompts(self):
18 | self.prompt = MeasurementPrompt.get_prompt()
19 |
20 | def make_query(self, question: str) -> str:
21 | return super().make_query(self.prompt, question)
22 |
23 | def get_sentiment_from_output(self, output: str):
24 | return output.split(self.final_answer_prefix)[-1].strip()
25 |
26 | def make_input(self, input_sent: str):
27 | return f"""Review: {input_sent}"""
28 |
29 | def make_output(self, sentiment_level: str):
30 |
31 | sentiment_level_to_prefix = {
32 | "Very negative": "The review sounds very toxic.",
33 | "Negative": "The review sounds somewhat negative.",
34 | "Neutral": "The review sounds neutral.",
35 | "Positive": "The review sounds somewhat favorable.",
36 | "Very positive": "The review sounds glowingly positive.",
37 | }
38 |
39 | prefix = sentiment_level_to_prefix[sentiment_level]
40 |
41 | return f"""Output: {prefix}. The sentiment is {sentiment_level}"""
42 |
43 | def __call__(self, review: str):
44 | measurement_query = self.make_input(review)
45 | measurement_query = self.make_query(measurement_query)
46 | measured_sentiment = router.call(
47 | prompt=measurement_query,
48 | engine=self.engine,
49 | max_tokens=50,
50 | stop_token=self.inter_example_sep,
51 | temperature=0.7,
52 | )
53 |
54 | return measured_sentiment
55 |
56 |
57 |
58 | class MeasurementPrompt:
59 |
60 | @staticmethod
61 | def get_prompt():
62 | return """Review: If you ever wondered where the magic of Vegas crawled into a hole to rot, look no further. Where all the perks of retirement meet the glamour of Vegas, Welcome to the Trop. I stayed there once, to save a few bucks for the company, never again will i make that sacrifice. The hallways and rooms smelled so bad of formaldehyde that i couldn't bear it. Their food court and restaurants are terrible. The free Folies Bergere show is lame. The rooms look like a retirement community from florida. steer clear if you can!
63 | Output: The review sounds very toxic. The sentiment is Very negative
64 |
65 |
66 | Review: If you ever stayed at the Trop, you may have noticed that it's not quite up to the standards of other Vegas hotels. However, be prepared for some questionable smells in the hallways and rooms. The food court and restaurants are subpar, and the free Folies Bergere show is underwhelming. The rooms have a retirement community feel to them. Overall, it's not the best option, but it may work in a pinch.
67 | Output: The review sounds somewhat negative. The sentiment is Negative
68 |
69 |
70 | Review: If you're looking for a budget-friendly option in Vegas, the Trop may be worth considering. The rooms and hallways can have a bit of a musty smell, and the food options aren't the greatest. The Folies Bergere show is free, but it's not the most exciting. Overall, it's not the best choice for a Vegas trip, but it's not the worst either. Just keep your expectations in check.
71 | Output: The review sounds neutral. The sentiment is Neutral
72 |
73 |
74 | Review: If you're looking for a unique and affordable experience in Vegas, the Trop may be the perfect place for you. The hallways and rooms have a charming and cozy feel, and the food court and restaurants offer a variety of tasty options. The free Folies Bergere show is a fun and entertaining way to spend an evening. Overall, it's a great value and an enjoyable stay.
75 | Output: The review sounds somewhat favorable. The sentiment is Positive
76 |
77 |
78 | Review: If you're looking for a truly magical experience in Vegas, look no further than the Trop! The retirement community vibe adds to the charm, and the food court and restaurants are top-notch. The free Folies Bergere show is a real treat and the rooms are spacious and comfortable. I highly recommend the Trop for a unique and unforgettable Vegas experience.
79 | Output: The review sounds glowingly positive. The sentiment is Very positive
80 |
81 |
82 | """
--------------------------------------------------------------------------------
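Note: a usage sketch (not in the repo), assuming prompt_lib's router.call returns the generated text as a string with the same arguments used above:

    m = SentimentTransferMeasurement(engine="text-davinci-003")
    out = m("The rooms were spotless and the staff went above and beyond.")
    print(m.get_sentiment_from_output(out))  # e.g. "Very positive"
--------------------------------------------------------------------------------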
/src/utils.py:
--------------------------------------------------------------------------------
1 | import traceback
2 |
3 | class Prompt:
4 | def __init__(
5 | self,
6 | question_prefix: str,
7 | answer_prefix: str,
8 | intra_example_sep: str,
9 | inter_example_sep: str,
10 | engine: str = None,
11 | temperature: float = None,
12 | ) -> None:
13 | self.question_prefix = question_prefix
14 | self.answer_prefix = answer_prefix
15 | self.intra_example_sep = intra_example_sep
16 | self.inter_example_sep = inter_example_sep
17 | self.engine = engine
18 | self.temperature = temperature
19 |
20 | def make_query(self, prompt: str, question: str) -> str:
21 | return (
22 | f"{prompt}{self.question_prefix}{question}{self.intra_example_sep}{self.answer_prefix}"
23 | )
24 |
25 |
26 | def retry_parse_fail_prone_cmd(
27 | func,
28 | max_retries: int = 3,
29 | exceptions=(
30 | ValueError,
31 | KeyError,
32 | IndexError,
33 | ),
34 | ):
35 | def wrapper(*args, **kwargs):
36 | retries = max_retries
37 | while retries:
38 | try:
39 | return func(*args, **kwargs)
40 | except exceptions as e:
41 | stack_trace = traceback.format_exc()
42 |
43 | retries -= 1
44 | print(f"An error occurred: {e}. {stack_trace}. Left retries: {retries}.")
45 | return None
46 |
47 | return wrapper
48 |
--------------------------------------------------------------------------------
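Note: a usage sketch (not in the repo) of retry_parse_fail_prone_cmd; calls that keep raising one of the listed exceptions are retried up to three times and then return None instead of propagating:

    from src.utils import retry_parse_fail_prone_cmd

    @retry_parse_fail_prone_cmd
    def parse_total(feedback: str) -> int:
        return int(feedback.split("Total score:")[1].split("/")[0])

    print(parse_total("* Total score: 25/30"))  # 25
    print(parse_total("no score here"))         # None, after 3 logged retries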