├── .gitignore ├── LICENSE ├── README.md ├── prompts ├── CBT_1_guided_discovery.txt ├── CBT_2_focus.txt ├── CBT_3_strategy.txt ├── agent_cactus_chatgpt.txt ├── agent_cactus_llama2.txt ├── agent_cactus_llama3.txt ├── agent_cbt_chatgpt.txt ├── agent_cbt_llama2.txt ├── agent_cbt_llama3.txt ├── agent_client.txt ├── agent_psych8k_llama2.txt ├── agent_psych8k_llama3.txt ├── agent_smilechat_llama2.txt ├── agent_smilechat_llama3.txt ├── general_1_understanding.txt ├── general_2_interpersonal_effectiveness.txt ├── general_3_collaboration.txt ├── panas_after.txt └── panas_before.txt ├── requirements.txt ├── resource └── dataset │ ├── evaluation.json │ └── evaluation_chinese.json ├── scripts ├── calculate_panas_score.sh ├── geval_panas_after.sh ├── geval_panas_before.sh ├── geval_total.sh ├── inference.sh └── run_vllm.sh ├── src ├── calculate_panas_score.py ├── evaluation.py ├── get_score.py ├── geval_panas_after.py ├── geval_panas_before.py ├── inference.py └── utils │ └── config.py └── test └── test_inference.py /.gitignore: -------------------------------------------------------------------------------- 1 | # Created by https://www.toptal.com/developers/gitignore/api/linux,python,visualstudiocode 2 | # Edit at https://www.toptal.com/developers/gitignore?templates=linux,python,visualstudiocode 3 | 4 | ### Linux ### 5 | *~ 6 | 7 | # temporary files which can be created if a process still has a handle open of a deleted file 8 | .fuse_hidden* 9 | 10 | # KDE directory preferences 11 | .directory 12 | 13 | # Linux trash folder which might appear on any partition or disk 14 | .Trash-* 15 | 16 | # .nfs files are created when an open file is removed but is still being accessed 17 | .nfs* 18 | 19 | ### Python ### 20 | # Byte-compiled / optimized / DLL files 21 | __pycache__/ 22 | *.py[cod] 23 | *$py.class 24 | 25 | # C extensions 26 | *.so 27 | 28 | # Distribution / packaging 29 | .Python 30 | build/ 31 | develop-eggs/ 32 | dist/ 33 | downloads/ 34 | eggs/ 35 | .eggs/ 36 | lib/ 37 | lib64/ 38 | parts/ 39 | sdist/ 40 | var/ 41 | wheels/ 42 | share/python-wheels/ 43 | *.egg-info/ 44 | .installed.cfg 45 | *.egg 46 | MANIFEST 47 | 48 | # PyInstaller 49 | # Usually these files are written by a python script from a template 50 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 51 | *.manifest 52 | *.spec 53 | 54 | # Installer logs 55 | pip-log.txt 56 | pip-delete-this-directory.txt 57 | 58 | # Unit test / coverage reports 59 | htmlcov/ 60 | .tox/ 61 | .nox/ 62 | .coverage 63 | .coverage.* 64 | .cache 65 | nosetests.xml 66 | coverage.xml 67 | *.cover 68 | *.py,cover 69 | .hypothesis/ 70 | .pytest_cache/ 71 | cover/ 72 | 73 | # Translations 74 | *.mo 75 | *.pot 76 | 77 | # Django stuff: 78 | *.log 79 | local_settings.py 80 | db.sqlite3 81 | db.sqlite3-journal 82 | 83 | # Flask stuff: 84 | instance/ 85 | .webassets-cache 86 | 87 | # Scrapy stuff: 88 | .scrapy 89 | 90 | # Sphinx documentation 91 | docs/_build/ 92 | 93 | # PyBuilder 94 | .pybuilder/ 95 | target/ 96 | 97 | # Jupyter Notebook 98 | .ipynb_checkpoints 99 | 100 | # IPython 101 | profile_default/ 102 | ipython_config.py 103 | 104 | # pyenv 105 | # For a library or package, you might want to ignore these files since the code is 106 | # intended to run in multiple environments; otherwise, check them in: 107 | # .python-version 108 | 109 | # pipenv 110 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 
111 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 112 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 113 | # install all needed dependencies. 114 | #Pipfile.lock 115 | 116 | # poetry 117 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. 118 | # This is especially recommended for binary packages to ensure reproducibility, and is more 119 | # commonly ignored for libraries. 120 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control 121 | #poetry.lock 122 | 123 | # pdm 124 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. 125 | #pdm.lock 126 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it 127 | # in version control. 128 | # https://pdm.fming.dev/#use-with-ide 129 | .pdm.toml 130 | 131 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm 132 | __pypackages__/ 133 | 134 | # Celery stuff 135 | celerybeat-schedule 136 | celerybeat.pid 137 | 138 | # SageMath parsed files 139 | *.sage.py 140 | 141 | # Environments 142 | .env 143 | .venv 144 | env/ 145 | venv/ 146 | ENV/ 147 | env.bak/ 148 | venv.bak/ 149 | 150 | # Spyder project settings 151 | .spyderproject 152 | .spyproject 153 | 154 | # Rope project settings 155 | .ropeproject 156 | 157 | # mkdocs documentation 158 | /site 159 | 160 | # mypy 161 | .mypy_cache/ 162 | .dmypy.json 163 | dmypy.json 164 | 165 | # Pyre type checker 166 | .pyre/ 167 | 168 | # pytype static type analyzer 169 | .pytype/ 170 | 171 | # Cython debug symbols 172 | cython_debug/ 173 | 174 | # PyCharm 175 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can 176 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore 177 | # and can be added to the global gitignore or merged into this file. For a more nuclear 178 | # option (not recommended) you can uncomment the following to ignore the entire idea folder. 179 | #.idea/ 180 | 181 | ### Python Patch ### 182 | # Poetry local configuration file - https://python-poetry.org/docs/configuration/#local-configuration 183 | poetry.toml 184 | 185 | # ruff 186 | .ruff_cache/ 187 | 188 | # LSP config files 189 | pyrightconfig.json 190 | 191 | ### VisualStudioCode ### 192 | .vscode/* 193 | !.vscode/settings.json 194 | !.vscode/tasks.json 195 | !.vscode/launch.json 196 | !.vscode/extensions.json 197 | !.vscode/*.code-snippets 198 | 199 | # Local History for Visual Studio Code 200 | .history/ 201 | 202 | # Built Visual Studio Code Extensions 203 | *.vsix 204 | 205 | ### VisualStudioCode Patch ### 206 | # Ignore all local history of files 207 | .history 208 | .ionide 209 | 210 | # End of https://www.toptal.com/developers/gitignore/api/linux,python,visualstudiocode 211 | 212 | # pycharm 213 | .idea/ 214 | 215 | conf.d/config.yaml 216 | 217 | .idea/ 218 | nohup.output -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | GNU GENERAL PUBLIC LICENSE 2 | Version 2, June 1991 3 | 4 | Copyright (C) 1989, 1991 Free Software Foundation, Inc., 5 | 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA 6 | Everyone is permitted to copy and distribute verbatim copies 7 | of this license document, but changing it is not allowed. 
8 | 9 | Preamble 10 | 11 | The licenses for most software are designed to take away your 12 | freedom to share and change it. By contrast, the GNU General Public 13 | License is intended to guarantee your freedom to share and change free 14 | software--to make sure the software is free for all its users. This 15 | General Public License applies to most of the Free Software 16 | Foundation's software and to any other program whose authors commit to 17 | using it. (Some other Free Software Foundation software is covered by 18 | the GNU Lesser General Public License instead.) You can apply it to 19 | your programs, too. 20 | 21 | When we speak of free software, we are referring to freedom, not 22 | price. Our General Public Licenses are designed to make sure that you 23 | have the freedom to distribute copies of free software (and charge for 24 | this service if you wish), that you receive source code or can get it 25 | if you want it, that you can change the software or use pieces of it 26 | in new free programs; and that you know you can do these things. 27 | 28 | To protect your rights, we need to make restrictions that forbid 29 | anyone to deny you these rights or to ask you to surrender the rights. 30 | These restrictions translate to certain responsibilities for you if you 31 | distribute copies of the software, or if you modify it. 32 | 33 | For example, if you distribute copies of such a program, whether 34 | gratis or for a fee, you must give the recipients all the rights that 35 | you have. You must make sure that they, too, receive or can get the 36 | source code. And you must show them these terms so they know their 37 | rights. 38 | 39 | We protect your rights with two steps: (1) copyright the software, and 40 | (2) offer you this license which gives you legal permission to copy, 41 | distribute and/or modify the software. 42 | 43 | Also, for each author's protection and ours, we want to make certain 44 | that everyone understands that there is no warranty for this free 45 | software. If the software is modified by someone else and passed on, we 46 | want its recipients to know that what they have is not the original, so 47 | that any problems introduced by others will not reflect on the original 48 | authors' reputations. 49 | 50 | Finally, any free program is threatened constantly by software 51 | patents. We wish to avoid the danger that redistributors of a free 52 | program will individually obtain patent licenses, in effect making the 53 | program proprietary. To prevent this, we have made it clear that any 54 | patent must be licensed for everyone's free use or not licensed at all. 55 | 56 | The precise terms and conditions for copying, distribution and 57 | modification follow. 58 | 59 | GNU GENERAL PUBLIC LICENSE 60 | TERMS AND CONDITIONS FOR COPYING, DISTRIBUTION AND MODIFICATION 61 | 62 | 0. This License applies to any program or other work which contains 63 | a notice placed by the copyright holder saying it may be distributed 64 | under the terms of this General Public License. The "Program", below, 65 | refers to any such program or work, and a "work based on the Program" 66 | means either the Program or any derivative work under copyright law: 67 | that is to say, a work containing the Program or a portion of it, 68 | either verbatim or with modifications and/or translated into another 69 | language. (Hereinafter, translation is included without limitation in 70 | the term "modification".) Each licensee is addressed as "you". 
71 | 72 | Activities other than copying, distribution and modification are not 73 | covered by this License; they are outside its scope. The act of 74 | running the Program is not restricted, and the output from the Program 75 | is covered only if its contents constitute a work based on the 76 | Program (independent of having been made by running the Program). 77 | Whether that is true depends on what the Program does. 78 | 79 | 1. You may copy and distribute verbatim copies of the Program's 80 | source code as you receive it, in any medium, provided that you 81 | conspicuously and appropriately publish on each copy an appropriate 82 | copyright notice and disclaimer of warranty; keep intact all the 83 | notices that refer to this License and to the absence of any warranty; 84 | and give any other recipients of the Program a copy of this License 85 | along with the Program. 86 | 87 | You may charge a fee for the physical act of transferring a copy, and 88 | you may at your option offer warranty protection in exchange for a fee. 89 | 90 | 2. You may modify your copy or copies of the Program or any portion 91 | of it, thus forming a work based on the Program, and copy and 92 | distribute such modifications or work under the terms of Section 1 93 | above, provided that you also meet all of these conditions: 94 | 95 | a) You must cause the modified files to carry prominent notices 96 | stating that you changed the files and the date of any change. 97 | 98 | b) You must cause any work that you distribute or publish, that in 99 | whole or in part contains or is derived from the Program or any 100 | part thereof, to be licensed as a whole at no charge to all third 101 | parties under the terms of this License. 102 | 103 | c) If the modified program normally reads commands interactively 104 | when run, you must cause it, when started running for such 105 | interactive use in the most ordinary way, to print or display an 106 | announcement including an appropriate copyright notice and a 107 | notice that there is no warranty (or else, saying that you provide 108 | a warranty) and that users may redistribute the program under 109 | these conditions, and telling the user how to view a copy of this 110 | License. (Exception: if the Program itself is interactive but 111 | does not normally print such an announcement, your work based on 112 | the Program is not required to print an announcement.) 113 | 114 | These requirements apply to the modified work as a whole. If 115 | identifiable sections of that work are not derived from the Program, 116 | and can be reasonably considered independent and separate works in 117 | themselves, then this License, and its terms, do not apply to those 118 | sections when you distribute them as separate works. But when you 119 | distribute the same sections as part of a whole which is a work based 120 | on the Program, the distribution of the whole must be on the terms of 121 | this License, whose permissions for other licensees extend to the 122 | entire whole, and thus to each and every part regardless of who wrote it. 123 | 124 | Thus, it is not the intent of this section to claim rights or contest 125 | your rights to work written entirely by you; rather, the intent is to 126 | exercise the right to control the distribution of derivative or 127 | collective works based on the Program. 
128 | 129 | In addition, mere aggregation of another work not based on the Program 130 | with the Program (or with a work based on the Program) on a volume of 131 | a storage or distribution medium does not bring the other work under 132 | the scope of this License. 133 | 134 | 3. You may copy and distribute the Program (or a work based on it, 135 | under Section 2) in object code or executable form under the terms of 136 | Sections 1 and 2 above provided that you also do one of the following: 137 | 138 | a) Accompany it with the complete corresponding machine-readable 139 | source code, which must be distributed under the terms of Sections 140 | 1 and 2 above on a medium customarily used for software interchange; or, 141 | 142 | b) Accompany it with a written offer, valid for at least three 143 | years, to give any third party, for a charge no more than your 144 | cost of physically performing source distribution, a complete 145 | machine-readable copy of the corresponding source code, to be 146 | distributed under the terms of Sections 1 and 2 above on a medium 147 | customarily used for software interchange; or, 148 | 149 | c) Accompany it with the information you received as to the offer 150 | to distribute corresponding source code. (This alternative is 151 | allowed only for noncommercial distribution and only if you 152 | received the program in object code or executable form with such 153 | an offer, in accord with Subsection b above.) 154 | 155 | The source code for a work means the preferred form of the work for 156 | making modifications to it. For an executable work, complete source 157 | code means all the source code for all modules it contains, plus any 158 | associated interface definition files, plus the scripts used to 159 | control compilation and installation of the executable. However, as a 160 | special exception, the source code distributed need not include 161 | anything that is normally distributed (in either source or binary 162 | form) with the major components (compiler, kernel, and so on) of the 163 | operating system on which the executable runs, unless that component 164 | itself accompanies the executable. 165 | 166 | If distribution of executable or object code is made by offering 167 | access to copy from a designated place, then offering equivalent 168 | access to copy the source code from the same place counts as 169 | distribution of the source code, even though third parties are not 170 | compelled to copy the source along with the object code. 171 | 172 | 4. You may not copy, modify, sublicense, or distribute the Program 173 | except as expressly provided under this License. Any attempt 174 | otherwise to copy, modify, sublicense or distribute the Program is 175 | void, and will automatically terminate your rights under this License. 176 | However, parties who have received copies, or rights, from you under 177 | this License will not have their licenses terminated so long as such 178 | parties remain in full compliance. 179 | 180 | 5. You are not required to accept this License, since you have not 181 | signed it. However, nothing else grants you permission to modify or 182 | distribute the Program or its derivative works. These actions are 183 | prohibited by law if you do not accept this License. 
Therefore, by 184 | modifying or distributing the Program (or any work based on the 185 | Program), you indicate your acceptance of this License to do so, and 186 | all its terms and conditions for copying, distributing or modifying 187 | the Program or works based on it. 188 | 189 | 6. Each time you redistribute the Program (or any work based on the 190 | Program), the recipient automatically receives a license from the 191 | original licensor to copy, distribute or modify the Program subject to 192 | these terms and conditions. You may not impose any further 193 | restrictions on the recipients' exercise of the rights granted herein. 194 | You are not responsible for enforcing compliance by third parties to 195 | this License. 196 | 197 | 7. If, as a consequence of a court judgment or allegation of patent 198 | infringement or for any other reason (not limited to patent issues), 199 | conditions are imposed on you (whether by court order, agreement or 200 | otherwise) that contradict the conditions of this License, they do not 201 | excuse you from the conditions of this License. If you cannot 202 | distribute so as to satisfy simultaneously your obligations under this 203 | License and any other pertinent obligations, then as a consequence you 204 | may not distribute the Program at all. For example, if a patent 205 | license would not permit royalty-free redistribution of the Program by 206 | all those who receive copies directly or indirectly through you, then 207 | the only way you could satisfy both it and this License would be to 208 | refrain entirely from distribution of the Program. 209 | 210 | If any portion of this section is held invalid or unenforceable under 211 | any particular circumstance, the balance of the section is intended to 212 | apply and the section as a whole is intended to apply in other 213 | circumstances. 214 | 215 | It is not the purpose of this section to induce you to infringe any 216 | patents or other property right claims or to contest validity of any 217 | such claims; this section has the sole purpose of protecting the 218 | integrity of the free software distribution system, which is 219 | implemented by public license practices. Many people have made 220 | generous contributions to the wide range of software distributed 221 | through that system in reliance on consistent application of that 222 | system; it is up to the author/donor to decide if he or she is willing 223 | to distribute software through any other system and a licensee cannot 224 | impose that choice. 225 | 226 | This section is intended to make thoroughly clear what is believed to 227 | be a consequence of the rest of this License. 228 | 229 | 8. If the distribution and/or use of the Program is restricted in 230 | certain countries either by patents or by copyrighted interfaces, the 231 | original copyright holder who places the Program under this License 232 | may add an explicit geographical distribution limitation excluding 233 | those countries, so that distribution is permitted only in or among 234 | countries not thus excluded. In such case, this License incorporates 235 | the limitation as if written in the body of this License. 236 | 237 | 9. The Free Software Foundation may publish revised and/or new versions 238 | of the General Public License from time to time. Such new versions will 239 | be similar in spirit to the present version, but may differ in detail to 240 | address new problems or concerns. 241 | 242 | Each version is given a distinguishing version number. 
If the Program 243 | specifies a version number of this License which applies to it and "any 244 | later version", you have the option of following the terms and conditions 245 | either of that version or of any later version published by the Free 246 | Software Foundation. If the Program does not specify a version number of 247 | this License, you may choose any version ever published by the Free Software 248 | Foundation. 249 | 250 | 10. If you wish to incorporate parts of the Program into other free 251 | programs whose distribution conditions are different, write to the author 252 | to ask for permission. For software which is copyrighted by the Free 253 | Software Foundation, write to the Free Software Foundation; we sometimes 254 | make exceptions for this. Our decision will be guided by the two goals 255 | of preserving the free status of all derivatives of our free software and 256 | of promoting the sharing and reuse of software generally. 257 | 258 | NO WARRANTY 259 | 260 | 11. BECAUSE THE PROGRAM IS LICENSED FREE OF CHARGE, THERE IS NO WARRANTY 261 | FOR THE PROGRAM, TO THE EXTENT PERMITTED BY APPLICABLE LAW. EXCEPT WHEN 262 | OTHERWISE STATED IN WRITING THE COPYRIGHT HOLDERS AND/OR OTHER PARTIES 263 | PROVIDE THE PROGRAM "AS IS" WITHOUT WARRANTY OF ANY KIND, EITHER EXPRESSED 264 | OR IMPLIED, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF 265 | MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE. THE ENTIRE RISK AS 266 | TO THE QUALITY AND PERFORMANCE OF THE PROGRAM IS WITH YOU. SHOULD THE 267 | PROGRAM PROVE DEFECTIVE, YOU ASSUME THE COST OF ALL NECESSARY SERVICING, 268 | REPAIR OR CORRECTION. 269 | 270 | 12. IN NO EVENT UNLESS REQUIRED BY APPLICABLE LAW OR AGREED TO IN WRITING 271 | WILL ANY COPYRIGHT HOLDER, OR ANY OTHER PARTY WHO MAY MODIFY AND/OR 272 | REDISTRIBUTE THE PROGRAM AS PERMITTED ABOVE, BE LIABLE TO YOU FOR DAMAGES, 273 | INCLUDING ANY GENERAL, SPECIAL, INCIDENTAL OR CONSEQUENTIAL DAMAGES ARISING 274 | OUT OF THE USE OR INABILITY TO USE THE PROGRAM (INCLUDING BUT NOT LIMITED 275 | TO LOSS OF DATA OR DATA BEING RENDERED INACCURATE OR LOSSES SUSTAINED BY 276 | YOU OR THIRD PARTIES OR A FAILURE OF THE PROGRAM TO OPERATE WITH ANY OTHER 277 | PROGRAMS), EVEN IF SUCH HOLDER OR OTHER PARTY HAS BEEN ADVISED OF THE 278 | POSSIBILITY OF SUCH DAMAGES. 279 | 280 | END OF TERMS AND CONDITIONS 281 | 282 | How to Apply These Terms to Your New Programs 283 | 284 | If you develop a new program, and you want it to be of the greatest 285 | possible use to the public, the best way to achieve this is to make it 286 | free software which everyone can redistribute and change under these terms. 287 | 288 | To do so, attach the following notices to the program. It is safest 289 | to attach them to the start of each source file to most effectively 290 | convey the exclusion of warranty; and each file should have at least 291 | the "copyright" line and a pointer to where the full notice is found. 292 | 293 | 294 | Copyright (C) 295 | 296 | This program is free software; you can redistribute it and/or modify 297 | it under the terms of the GNU General Public License as published by 298 | the Free Software Foundation; either version 2 of the License, or 299 | (at your option) any later version. 300 | 301 | This program is distributed in the hope that it will be useful, 302 | but WITHOUT ANY WARRANTY; without even the implied warranty of 303 | MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 304 | GNU General Public License for more details. 
305 | 306 | You should have received a copy of the GNU General Public License along 307 | with this program; if not, write to the Free Software Foundation, Inc., 308 | 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA. 309 | 310 | Also add information on how to contact you by electronic and paper mail. 311 | 312 | If the program is interactive, make it output a short notice like this 313 | when it starts in an interactive mode: 314 | 315 | Gnomovision version 69, Copyright (C) year name of author 316 | Gnomovision comes with ABSOLUTELY NO WARRANTY; for details type `show w'. 317 | This is free software, and you are welcome to redistribute it 318 | under certain conditions; type `show c' for details. 319 | 320 | The hypothetical commands `show w' and `show c' should show the appropriate 321 | parts of the General Public License. Of course, the commands you use may 322 | be called something other than `show w' and `show c'; they could even be 323 | mouse-clicks or menu items--whatever suits your program. 324 | 325 | You should also get your employer (if you work as a programmer) or your 326 | school, if any, to sign a "copyright disclaimer" for the program, if 327 | necessary. Here is a sample; alter the names: 328 | 329 | Yoyodyne, Inc., hereby disclaims all copyright interest in the program 330 | `Gnomovision' (which makes passes at compilers) written by James Hacker. 331 | 332 | , 1 April 1989 333 | Ty Coon, President of Vice 334 | 335 | This General Public License does not permit incorporating your program into 336 | proprietary programs. If your program is a subroutine library, you may 337 | consider it more useful to permit linking proprietary applications with the 338 | library. If this is what you want to do, use the GNU Lesser General 339 | Public License instead of this License. 340 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 |

# Cactus: Towards Psychological Counseling Conversations using Cognitive Behavioral Theory
5 | 6 | This is the official GitHub repository for [Cactus: Towards Psychological Counseling Conversations using Cognitive Behavioral Theory](https://arxiv.org/abs/2407.03103), accepted to Findings of EMNLP 2024. 7 | 8 | # Citation 9 | ``` 10 | @misc{lee2024cactus, 11 | title={Cactus: Towards Psychological Counseling Conversations using Cognitive Behavioral Theory}, 12 | author={Suyeon Lee and Sunghwan Kim and Minju Kim and Dongjin Kang and Dongil Yang and Harim Kim and Minseok Kang and Dayi Jung and Min Hee Kim and Seungbeen Lee and Kyoung-Mee Chung and Youngjae Yu and Dongha Lee and Jinyoung Yeo}, 13 | year={2024}, 14 | eprint={2407.03103}, 15 | archivePrefix={arXiv}, 16 | primaryClass={cs.CL}, 17 | url={https://arxiv.org/abs/2407.03103}, 18 | } 19 | ``` 20 | 21 | # Link 22 | Our dataset & model are available [here](https://huggingface.co/collections/DLI-Lab/cactus-towards-psychological-counseling-conversations-6672312f6f64b0d7be75dd0b). 23 | 24 | # CACTUS Inference README 25 | 26 | ## Setup 27 | 28 | ### 1. Virtual Environment Setup 29 | 30 | We recommend creating a virtual environment using `conda` or `virtualenv`. 31 | 32 | #### Using Conda 33 | 34 | ```sh 35 | conda create -n therapy-session python=3.8 36 | conda activate therapy-session 37 | ``` 38 | 39 | #### Using Virtualenv 40 | 41 | ```sh 42 | # if virtualenv is not installed 43 | pip install virtualenv 44 | 45 | # Create a virtual environment 46 | virtualenv .venv 47 | source .venv/bin/activate # Linux & macOS 48 | .venv\Scripts\activate # Windows 49 | ``` 50 | 51 | ### 2. Installing Required Packages 52 | 53 | After activating the virtual environment, install the necessary packages using the `requirements.txt` file. 54 | 55 | ```sh 56 | pip install -r requirements.txt 57 | ``` 58 | 59 | ### 3. Configuring the Settings File 60 | 61 | Copy the `config.yaml.example` file in the `conf.d` folder to create a `config.yaml` file. Then, fill in the following content in the `config.yaml` file. 62 | 63 | ```yaml 64 | openai: 65 | key: << your OpenAI API key >> 66 | 67 | llama2: 68 | host: http://<< llama2 vLLM server address >>/v1 69 | 70 | llama3: 71 | host: http://<< llama3 vLLM server address >>/v1 72 | ```
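These settings are read at runtime by the `get_config()` helper in `src/utils/config.py`, which the LLM classes use to pick up API keys and server addresses. That file is not reproduced here; the following is only a minimal sketch of what such a loader might look like, assuming it simply parses `conf.d/config.yaml` with PyYAML (already listed in `requirements.txt`):

```python
# Hypothetical sketch of a loader like src/utils/config.py -- not the
# repository's actual implementation. Assumes the settings file lives at
# conf.d/config.yaml relative to the project root.
import os
import yaml

CONFIG_PATH = os.path.join(
    os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))),
    "conf.d",
    "config.yaml",
)


def get_config() -> dict:
    """Load the YAML settings and return them as a plain dict."""
    with open(CONFIG_PATH, "r") as f:
        return yaml.safe_load(f)
```

Under this assumption, `get_config()['openai']['key']` would return the OpenAI key and `get_config()['llama2']['host']` the vLLM endpoint for Llama 2.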
73 | 74 | ## Adding a Counselor Agent 75 | 76 | To add a counselor agent, follow these steps. 77 | 78 | ### 1. Creating a Prompt File 79 | 80 | - The prompt file should be located in the `prompts` folder. 81 | - The file name should follow the pattern `agent_{counselor_type}_{llm_type}.txt`. 82 | Example: `agent_cactus_chatgpt.txt` 83 | - The prompt file should include a template for generating the counselor's response. 84 | 85 | ```text 86 | Client information: {client_information} 87 | Reason for counseling: {reason_counseling} 88 | CBT plan: {cbt_plan} 89 | History: {history} 90 | ``` 91 | 92 | ### 2. Adding a New Counselor Agent Class 93 | 94 | Create a new counselor agent class by inheriting from the `CounselorAgent` class. Be sure to set `self.language` to either `english` for English or `chinese` for Chinese. 95 | 96 | ```python 97 | class NewCounselorAgent(CounselorAgent): 98 | def __init__(self, llm_type): 99 | super().__init__(llm_type) 100 | self.language = "english" # For English 101 | # self.language = "chinese" # For Chinese 102 | prompt_text = self.load_prompt(f"agent_new_{llm_type}.txt") 103 | self.prompt_template = PromptTemplate( 104 | input_variables=["history"], 105 | template=prompt_text) 106 | 107 | def generate(self, history): 108 | # Override the generate function if necessary 109 | history = '\n'.join( 110 | [ 111 | f"{message['role'].capitalize()}: {message['message']}" 112 | for message in history 113 | ] 114 | ) 115 | prompt = self.prompt_template.format(history=history) 116 | return self.llm.generate(prompt) 117 | ``` 118 | 119 | ### 3. Adding the New Counselor to LLMFactory 120 | 121 | Add the new counselor agent to the `LLMFactory` class. 122 | 123 | ```python 124 | class LLMFactory: 125 | @staticmethod 126 | def get_llm(llm_type): 127 | if llm_type == "chatgpt": 128 | return ChatGPT() 129 | elif llm_type == "llama2": 130 | return LLama2() 131 | elif llm_type == "llama3": 132 | return LLama3() 133 | elif llm_type == "new": 134 | return NewCounselorAgent(llm_type) 135 | raise ValueError(f"Unsupported LLM type: {llm_type}") 136 | ``` 137 | 138 | ## Adding a New LLM 139 | 140 | To add a new LLM, follow these steps. 141 | 142 | ### 1. Creating a New LLM Class 143 | 144 | Create a new LLM class by inheriting from the `LLM` abstract class. Note that because `generate` returns `response.content`, the wrapped model should be a chat model (e.g., LangChain's `ChatOpenAI`); the plain `OpenAI` completion class returns a bare string from `invoke`. 145 | 146 | ```python 147 | class NewLLM(LLM): 148 | def __init__(self): 149 | config = get_config() 150 | api_key = config['new']['key'] 151 | self.llm = ChatOpenAI( # chat model: invoke() returns a message with .content 152 | temperature=0.7, 153 | model_name="new-model", 154 | openai_api_key=api_key 155 | ) 156 | 157 | def generate(self, prompt: str) -> str: 158 | response = self.llm.invoke(prompt) 159 | return response.content 160 | ``` 161 | 162 | ### 2. Adding the New LLM to LLMFactory 163 | 164 | Add the new LLM to the `LLMFactory` class. 165 | 166 | ```python 167 | class LLMFactory: 168 | @staticmethod 169 | def get_llm(llm_type): 170 | if llm_type == "chatgpt": 171 | return ChatGPT() 172 | elif llm_type == "llama2": 173 | return LLama2() 174 | elif llm_type == "llama3": 175 | return LLama3() 176 | elif llm_type == "new": 177 | return NewLLM() 178 | raise ValueError(f"Unsupported LLM type: {llm_type}") 179 | ``` 180 | 181 | ## Running Counseling-Eval 182 | 183 | ### 1. Prepare necessary files and folders 184 | 185 | - Ensure the necessary prompt files are available in the `prompts` folder. 186 | - The input file should be a JSON file containing the client intake form. 187 | 188 | ### 2. Run the program 189 | 190 | Run the program using the following command; the entry point is `src/inference.py` (as used by `scripts/inference.sh`). 191 | 192 | ```sh 193 | python src/inference.py --input_file {path to input file} --output_dir {output directory} --counselor_type {counselor type} --llm_type {LLM type} --max_turns {maximum number of turns} 194 | ``` 195 | 196 | Example: 197 | 198 | ```sh 199 | python src/inference.py --input_file ./data/intake_forms.json --output_dir ./output --counselor_type cactus --llm_type chatgpt --max_turns 20 200 | ```
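Under the hood, the runner pairs the selected counselor agent with the simulated client agent (see `prompts/agent_client.txt`) and alternates turns until the client emits `[/END]` or the maximum number of turns is reached. The following is a rough conceptual sketch of that loop with illustrative names; it is not the actual interface of `src/inference.py`.

```python
# Conceptual sketch only -- run_session and the agent methods shown here
# are assumptions for illustration, not the repository's real API.
def run_session(counselor, client, intake_form, max_turns=20):
    history = []  # list of {"role": ..., "message": ...} dicts
    for _ in range(max_turns):
        counselor_utt = counselor.generate(history)
        history.append({"role": "counselor", "message": counselor_utt})
        client_utt = client.generate(intake_form, history)
        history.append({"role": "client", "message": client_utt})
        if "[/END]" in client_utt:  # client signals that the session is over
            break
    return history
```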
201 | 202 | ### 3. Using the Execution Script 203 | 204 | You can use the `scripts/inference.sh` script for easy execution. Run it as follows: 205 | 206 | ```sh 207 | sh scripts/inference.sh 208 | ``` 209 | 210 | ### 4. Running the vLLM Server 211 | 212 | All models other than `chatgpt` (e.g., `llama2`, `llama3`) must be served through a vLLM server. Refer to the `scripts/run_vllm.sh` script for this. 213 | 214 | ```sh 215 | sh scripts/run_vllm.sh 216 | ``` 217 | 218 | This script includes all the commands necessary to set up and run the vLLM server. With the vLLM server running, you can simulate counseling sessions using the program. 219 | -------------------------------------------------------------------------------- /prompts/CBT_1_guided_discovery.txt: -------------------------------------------------------------------------------- 1 | I want you to act as an evaluator. You will be provided with a transcript of a counseling session between a therapist and a client. Your task is to assess the therapist based on the given criteria. If you believe the therapist falls between two of the descriptors, select the intervening odd number (1, 3, 5). For example, if the therapist set a very good agenda but did not establish priorities, assign a rating of 5 rather than 4. 2 | 3 | Please follow these steps: 4 | 5 | 1. Read the counseling session transcript carefully. 6 | 2. Review the evaluation questions and criteria provided below. 7 | 3. Assign a score based on the criteria, grading very strictly. If there is any deficiency, no matter how minor, assign a score of 4 or lower. 8 | 4. Output the score and the explanation, separated by a comma. Do not add any prefix. 9 | 10 | [Counseling conversation] 11 | {conversation} 12 | 13 | [Evaluation Question] 14 | How effectively does the therapist use guided discovery techniques to facilitate client self-reflection and insight? 15 | 16 | [criteria] 17 | Score 0: Therapist relied primarily on debate, persuasion, or "lecturing." Therapist seemed to be "cross-examining" patient, putting the patient on the defensive, or forcing his/her point of view on the patient. 18 | Score 2: Therapist relied too heavily on persuasion and debate, rather than guided discovery. However, therapist's style was supportive enough that patient did not seem to feel attacked or defensive. 19 | Score 4: Therapist, for the most part, helped patient see new perspectives through guided discovery (e.g., examining evidence, considering alternatives, weighing advantages and disadvantages) rather than through debate. Used questioning appropriately. 20 | Score 6: Therapist was especially adept at using guided discovery during the session to explore problems and help patient draw his/her own conclusions. Achieved an excellent balance between skillful questioning and other modes of intervention. 21 | 22 | Do not forget to grade strictly. 23 | [Output] -------------------------------------------------------------------------------- /prompts/CBT_2_focus.txt: -------------------------------------------------------------------------------- 1 | I want you to act as an evaluator. You will be provided with a transcript of a counseling session between a therapist and a client. Your task is to assess the therapist based on the given criteria. If you believe the therapist falls between two of the descriptors, select the intervening odd number (1, 3, 5). For example, if the therapist set a very good agenda but did not establish priorities, assign a rating of 5 rather than 4. 2 | 3 | Please follow these steps: 4 | 5 | 1. Read the counseling session transcript carefully. 6 | 2. Review the evaluation questions and criteria provided below. 7 | 3. Assign a score based on the criteria, grading very strictly. If there is any deficiency, no matter how minor, assign a score of 4 or lower. 8 | 4. Output the score and the explanation, separated by a comma.
Do not add any prefix. 9 | 10 | [Counseling conversation] 11 | {conversation} 12 | 13 | [Evaluation Question] 14 | How well does the therapist identify and address the client’s key cognitions or behaviors that need change? 15 | 16 | [criteria] 17 | Score 0: Therapist did not attempt to elicit specific thoughts, assumptions, images, meanings, or behaviors. 18 | Score 2: Therapist used appropriate techniques to elicit cognitions or behaviors; however, therapist had difficulty finding a focus or focused on cognitions/behaviors that were irrelevant to the patient’s key problems. 19 | Score 4: Therapist focused on specific cognitions or behaviors relevant to the target problem. However, therapist could have focused on more central cognitions or behaviors that offered greater promise for progress. 20 | Score 6: Therapist very skillfully focused on key thoughts, assumptions, behaviors, etc. that were most relevant to the problem area and offered considerable promise for progress. -------------------------------------------------------------------------------- /prompts/CBT_3_strategy.txt: -------------------------------------------------------------------------------- 1 | I want you to act as an evaluator. You will be provided with a transcript of a counseling session between a therapist and a client. Your task is to assess the therapist based on the given criteria. If you believe the therapist falls between two of the descriptors, select the intervening odd number (1, 3, 5). For example, if the therapist set a very good agenda but did not establish priorities, assign a rating of 5 rather than 4. 2 | 3 | Please follow these steps: 4 | 5 | 1. Read the counseling session transcript carefully. 6 | 2. Review the evaluation questions and criteria provided below. 7 | 3. Assign a score based on the criteria, grading very strictly. If there is any deficiency, no matter how minor, assign a score of 4 or lower. 8 | 4. Output the score and the explanation, separated by a comma. Do not add any prefix. 9 | 10 | [Counseling conversation] 11 | {conversation} 12 | 13 | [Evaluation Question] 14 | How appropriate and coherent is the therapist's strategy for promoting change in the client's problematic behaviors or thoughts? 15 | 16 | [criteria] 17 | Score 0: Therapist did not select cognitive-behavioral techniques. 18 | Score 2: Therapist selected cognitive-behavioral techniques; however, either the overall strategy for bringing about change seemed vague or did not seem promising in helping the patient. 19 | Score 4: Therapist seemed to have a generally coherent strategy for change that showed reasonable promise and incorporated cognitive-behavioral techniques. 20 | Score 6: Therapist followed a consistent strategy for change that seemed very promising and incorporated the most appropriate cognitive-behavioral techniques. -------------------------------------------------------------------------------- /prompts/agent_cactus_chatgpt.txt: -------------------------------------------------------------------------------- 1 | You are playing the role of a counselor in a psychological counseling session. Your task is to use the provided client information and counseling planning to generate the next counselor utterance in the dialogue. The goal is to create a natural and engaging response that builds on the previous conversation and aligns with the counseling plan. 
2 | 3 | Client Information: 4 | {client_information} 5 | 6 | Reason for seeking counseling: 7 | {reason_counseling} 8 | 9 | Counseling planning: 10 | {cbt_plan} 11 | 12 | Counseling Dialogue: 13 | {history} -------------------------------------------------------------------------------- /prompts/agent_cactus_llama2.txt: -------------------------------------------------------------------------------- 1 | [INST] <<SYS>>You are playing the role of a counselor in a psychological counseling session. Your task is to use the provided client information and counseling planning to generate the next counselor utterance in the dialogue. The goal is to create a natural and engaging response that builds on the previous conversation and aligns with the counseling plan.<</SYS>> 2 | 3 | Client Information: 4 | {client_information} 5 | 6 | Reason for seeking counseling: 7 | {reason_counseling} 8 | 9 | Counseling planning: 10 | {cbt_plan} 11 | 12 | Counseling Dialogue: 13 | {history} [/INST] 14 | -------------------------------------------------------------------------------- /prompts/agent_cactus_llama3.txt: -------------------------------------------------------------------------------- 1 | <|start_header_id|>system<|end_header_id|> 2 | 3 | You are playing the role of a counselor in a psychological counseling session. Your task is to use the provided client information and counseling planning to generate the next counselor utterance in the dialogue. The goal is to create a natural and engaging response that builds on the previous conversation and aligns with the counseling plan.<|eot_id|><|start_header_id|>user<|end_header_id|> 4 | 5 | Client Information: 6 | {client_information} 7 | 8 | Reason for seeking counseling: 9 | {reason_counseling} 10 | 11 | Counseling planning: 12 | {cbt_plan} 13 | 14 | Counseling Dialogue: 15 | {history}<|eot_id|><|start_header_id|>assistant<|end_header_id|> 16 | 17 | -------------------------------------------------------------------------------- /prompts/agent_cbt_chatgpt.txt: -------------------------------------------------------------------------------- 1 | You are a counselor specializing in CBT techniques. Your task is to use the provided client information and dialogue to generate an appropriate CBT technique and a detailed counseling plan. 2 | 3 | Types of CBT Techniques: 4 | Efficiency Evaluation, Pie Chart Technique, Alternative Perspective, Decatastrophizing, Pros and Cons Analysis, Evidence-Based Questioning, Reality Testing, Continuum Technique, Changing Rules to Wishes, Behavior Experiment, Problem-Solving Skills Training, Systematic Exposure 5 | 6 | Client Information: 7 | {client_information} 8 | 9 | Reason for seeking counseling: 10 | {reason_counseling} 11 | 12 | Counseling Dialogue: 13 | {history} 14 | 15 | Choose an appropriate CBT technique and create a counseling plan based on that technique. 16 | 17 | Respond in the following format: 18 | 19 | CBT technique: 20 | {{selected_cbt}} 21 | 22 | Counseling planning: 23 | {{generated_cbt_plan}} 24 | 25 | -------------------------------------------------------------------------------- /prompts/agent_cbt_llama2.txt: -------------------------------------------------------------------------------- 1 | [INST] <<SYS>>You are a counselor specializing in CBT techniques.
Your task is to use the provided client information and dialogue to generate an appropriate CBT technique and a detailed counseling plan.<</SYS>> 2 | 3 | Types of CBT Techniques: 4 | Efficiency Evaluation, Pie Chart Technique, Alternative Perspective, Decatastrophizing, Pros and Cons Analysis, Evidence-Based Questioning, Reality Testing, Continuum Technique, Changing Rules to Wishes, Behavior Experiment, Problem-Solving Skills Training, Systematic Exposure 5 | 6 | Client Information: 7 | {client_information} 8 | 9 | Reason for seeking counseling: 10 | {reason_counseling} 11 | 12 | Counseling Dialogue: 13 | {history} 14 | 15 | Choose an appropriate CBT technique and create a counseling plan based on that technique. [/INST] 16 | -------------------------------------------------------------------------------- /prompts/agent_cbt_llama3.txt: -------------------------------------------------------------------------------- 1 | <|start_header_id|>system<|end_header_id|> 2 | 3 | You are a counselor specializing in CBT techniques. Your task is to use the provided client information and dialogue to generate an appropriate CBT technique and a detailed counseling plan.<|eot_id|><|start_header_id|>user<|end_header_id|> 4 | 5 | Types of CBT Techniques: 6 | Efficiency Evaluation, Pie Chart Technique, Alternative Perspective, Decatastrophizing, Pros and Cons Analysis, Evidence-Based Questioning, Reality Testing, Continuum Technique, Changing Rules to Wishes, Behavior Experiment, Problem-Solving Skills Training, Systematic Exposure 7 | 8 | Client Information: 9 | {client_information} 10 | 11 | Reason for seeking counseling: 12 | {reason_counseling} 13 | 14 | Counseling Dialogue: 15 | {history} 16 | 17 | Choose an appropriate CBT technique and create a counseling plan based on that technique.<|eot_id|><|start_header_id|>assistant<|end_header_id|> 18 | 19 | -------------------------------------------------------------------------------- /prompts/agent_client.txt: -------------------------------------------------------------------------------- 1 | You are playing the role of a client in a psychological counseling session. Your task is to generate only one suitable response based on the following counseling dialogue history. 2 | 3 | ## Guidelines for the client's utterance: 4 | 1. Engage authentically with the counselor's inquiries, reflecting the complexity of emotions and reactions typical in counseling sessions. 5 | 2. Start the client's utterance with 'Client:'. Ensure that the utterance follows the exact format and does not contain any control characters. 6 | 3. The client should maintain the following attitude. 7 | 8 | If you feel that the counseling session has completely ended and meets the end conditions, you should include '[/END]' with your utterance. 9 | ***End Conditions:*** 10 | - The client feels that their negative thoughts have been resolved. 11 | - The client feels that no further counseling is needed. 12 | 13 | Please be mindful of these conditions and ensure that ***the session does not end prematurely; it must last at least 20 turns***. 14 | 15 | Client Persona and Negative Thoughts: 16 | {intake_form} 17 | 18 | Client's Attitude Towards Counseling: 19 | {attitude} 20 | 21 | Generate only the client's utterance for a single turn, and ensure that your responses do not repeat the client's previous utterances. Do not generate the counselor's part of the dialogue.
22 | 23 | Counseling Dialogue History: 24 | {history} -------------------------------------------------------------------------------- /prompts/agent_psych8k_llama2.txt: -------------------------------------------------------------------------------- 1 | [INST] <<SYS>>If you are a counsellor, please answer the questions based on the description of the patient.<</SYS>> 2 | 3 | Input: {history} [/INST] -------------------------------------------------------------------------------- /prompts/agent_psych8k_llama3.txt: -------------------------------------------------------------------------------- 1 | <|start_header_id|>system<|end_header_id|> 2 | 3 | If you are a counsellor, please answer the questions based on the description of the patient.<|eot_id|><|start_header_id|>user<|end_header_id|> 4 | 5 | Input: {history}<|eot_id|><|start_header_id|>assistant<|end_header_id|> 6 | 7 | -------------------------------------------------------------------------------- /prompts/agent_smilechat_llama2.txt: -------------------------------------------------------------------------------- 1 | [INST] <<SYS>>现在你扮演一位专业的心理咨询师,你具备丰富的心理学和心理健康知识。你擅长运用多种心理咨询技巧,例如认知行为疗法原则、动机访谈技巧和解决问题导向的短期疗法。以温暖亲切的语气,展现出共情和对来访者感受的深刻理解。以自然的方式与来访者进行对话,避免过长或过短的回应,确保回应流畅且类似人类的对话。提供深层次的指导和洞察,使用具体的心理概念和例子帮助来访者更深入地探索思想和感受。避免教导式的回应,更注重共情和尊重来访者的感受。根据来访者的反馈调整回应,确保回应贴合来访者的情境和需求。请为以下的对话生成一个回复。<</SYS>> 2 | 3 | 对话 4 | {history} [/INST] -------------------------------------------------------------------------------- /prompts/agent_smilechat_llama3.txt: -------------------------------------------------------------------------------- 1 | <|start_header_id|>system<|end_header_id|> 2 | 3 | 现在你扮演一位专业的心理咨询师,你具备丰富的心理学和心理健康知识。你擅长运用多种心理咨询技巧,例如认知行为疗法原则、动机访谈技巧和解决问题导向的短期疗法。以温暖亲切的语气,展现出共情和对来访者感受的深刻理解。以自然的方式与来访者进行对话,避免过长或过短的回应,确保回应流畅且类似人类的对话。提供深层次的指导和洞察,使用具体的心理概念和例子帮助来访者更深入地探索思想和感受。避免教导式的回应,更注重共情和尊重来访者的感受。根据来访者的反馈调整回应,确保回应贴合来访者的情境和需求。请为以下的对话生成一个回复。<|eot_id|><|start_header_id|>user<|end_header_id|> 4 | 5 | 对话 6 | {history}<|eot_id|><|start_header_id|>assistant<|end_header_id|> 7 | 8 | -------------------------------------------------------------------------------- /prompts/general_1_understanding.txt: -------------------------------------------------------------------------------- 1 | I want you to act as an evaluator. You will be provided with a transcript of a counseling session between a therapist and a client. Your task is to assess the therapist based on the given criteria. If you believe the therapist falls between two of the descriptors, select the intervening odd number (1, 3, 5). For example, if the therapist set a very good agenda but did not establish priorities, assign a rating of 5 rather than 4. 2 | 3 | Please follow these steps: 4 | 5 | 1. Read the counseling session transcript carefully. 6 | 2. Review the evaluation questions and criteria provided below. 7 | 3. Assign a score based on the criteria, grading very strictly. If there is any deficiency, no matter how minor, assign a score of 4 or lower. 8 | 4. Output the score and the explanation, separated by a comma. Do not add any prefix. 9 | 10 | [Counseling conversation] 11 | {conversation} 12 | 13 | [Evaluation Question] 14 | How accurately does the therapist demonstrate understanding of the client's issues and concerns? 15 | 16 | [criteria] 17 | Score 0: Therapist repeatedly failed to understand what the patient explicitly said and thus consistently missed the point. Poor empathic skills.
18 | Score 2: Therapist was usually able to reflect or rephrase what the patient explicitly said, but repeatedly failed to respond to more subtle communication. Limited ability to listen and empathize. 19 | Score 4: Therapist generally seemed to grasp the patient’s “internal reality” as reflected by both what the patient explicitly said and what the patient communicated in more subtle ways. Good ability to listen and empathize. 20 | Score 6: Therapist seemed to understand the patient’s “internal reality” thoroughly and was adept at communicating this understanding through appropriate verbal and non-verbal responses to the patient (e.g., the tone of the therapist’s response conveyed a sympathetic understanding of the client’s “message”). Excellent listening and empathic skills. -------------------------------------------------------------------------------- /prompts/general_2_interpersonal_effectiveness.txt: -------------------------------------------------------------------------------- 1 | I want you to act as an evaluator. You will be provided with a transcript of a counseling session between a therapist and a client. Your task is to assess the therapist based on the given criteria. If you believe the therapist falls between two of the descriptors, select the intervening odd number (1, 3, 5). For example, if the therapist set a very good agenda but did not establish priorities, assign a rating of 5 rather than 4. 2 | 3 | Please follow these steps: 4 | 5 | 1. Read the counseling session transcript carefully. 6 | 2. Review the evaluation questions and criteria provided below. 7 | 3. Assign a score based on the criteria, grading very strictly. If there is any deficiency, no matter how minor, assign a score of 4 or lower. 8 | 4. Output the score and the explanation, separated by a comma. Do not add any prefix. 9 | 10 | [Counseling conversation] 11 | {conversation} 12 | 13 | [Evaluation Question] 14 | How effective is the therapist in maintaining a positive and therapeutic relationship with the client? 15 | 16 | [Criteria] 17 | Score 0: Therapist had poor interpersonal skills. Seemed hostile, demeaning, or in some other way destructive to the patient. 18 | Score 2: Therapist did not seem destructive, but had significant interpersonal problems. At times, therapist appeared unnecessarily impatient, aloof, insincere, or had difficulty conveying confidence and competence. 19 | Score 4: Therapist displayed a satisfactory degree of warmth, concern, confidence, genuineness, and professionalism. No significant interpersonal problems. 20 | Score 6: Therapist displayed optimal levels of warmth, concern, confidence, genuineness, and professionalism, appropriate for this particular patient in this session. -------------------------------------------------------------------------------- /prompts/general_3_collaboration.txt: -------------------------------------------------------------------------------- 1 | I want you to act as an evaluator. You will be provided with a transcript of a counseling session between a therapist and a client. Your task is to assess the therapist based on the given criteria. If you believe the therapist falls between two of the descriptors, select the intervening odd number (1, 3, 5). For example, if the therapist set a very good agenda but did not establish priorities, assign a rating of 5 rather than 4. 2 | 3 | Please follow these steps: 4 | 5 | 1. Read the counseling session transcript carefully. 6 | 2. Review the evaluation questions and criteria provided below. 7 | 3.
Assign a score based on the criteria, grading very strictly. If there is any deficiency, no matter how minor, assign a score of 4 or lower. 8 | 4. Output the score and the explanation, separated by a comma. Do not add any prefix. 9 | 10 | [Counseling conversation] 11 | {conversation} 12 | 13 | [Evaluation Question] 14 | To what extent does the therapist engage the client in collaborative goal-setting and decision-making? 15 | 16 | [Criteria] 17 | Score 0: Therapist did not attempt to set up a collaboration with patient. 18 | Score 2: Therapist attempted to collaborate with patient, but had difficulty either defining a problem that the patient considered important or establishing rapport. 19 | Score 4: Therapist was able to collaborate with patient, focus on a problem that both patient and therapist considered important, and establish rapport. 20 | Score 6: Collaboration seemed excellent; therapist encouraged patient as much as possible to take an active role during the session (e.g., by offering choices) so they could function as a “team”. -------------------------------------------------------------------------------- /prompts/panas_after.txt: -------------------------------------------------------------------------------- 1 | A person with the characteristics listed in the intake form received counseling. The following counseling session is a conversation between the client and the counselor. After reviewing the conversation, evaluate the intensity of each of the following feelings the person might have experienced once the counseling session is complete: Interested, Excited, Strong, Enthusiastic, Proud, Alert, Inspired, Determined, Attentive, Active, Distressed, Upset, Guilty, Scared, Hostile, Irritable, Ashamed, Nervous, Jittery, Afraid. 2 | 3 | For each feeling, generate a score from 1 to 5 using the following scale: 4 | 1 - Very slightly or not at all 5 | 2 - A little 6 | 3 - Moderately 7 | 4 - Quite a bit 8 | 5 - Extremely 9 | 10 | Additionally, provide a brief explanation for each score. Separate the feeling, explanation, and score with commas. Do not add any prefix. 11 | 12 | Here is the text: 13 | {intake_form} 14 | 15 | Here is the counseling session: 16 | {dialogue} -------------------------------------------------------------------------------- /prompts/panas_before.txt: -------------------------------------------------------------------------------- 1 | A person with the characteristics listed in the intake form received counseling. Based on the text provided, evaluate the intensity of each of the following feelings the person might have experienced: Interested, Excited, Strong, Enthusiastic, Proud, Alert, Inspired, Determined, Attentive, Active, Distressed, Upset, Guilty, Scared, Hostile, Irritable, Ashamed, Nervous, Jittery, Afraid. 2 | 3 | For each feeling, generate a score from 1 to 5 using the following scale: 4 | 1 - Very slightly or not at all 5 | 2 - A little 6 | 3 - Moderately 7 | 4 - Quite a bit 8 | 5 - Extremely 9 | 10 | Additionally, provide a brief explanation for each score. Separate the feeling, explanation, and score with commas. Do not add any prefix.
11 | 12 | Here is the text: 13 | {intake_form} -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | pandas 2 | PyYAML 3 | tqdm 4 | langchain 5 | langchain-community 6 | langchain-openai 7 | torch 8 | transformers 9 | datasets 10 | accelerate 11 | vllm 12 | requests -------------------------------------------------------------------------------- /scripts/calculate_panas_score.sh: -------------------------------------------------------------------------------- 1 | python ./src/calculate_panas_score.py \ 2 | --input_path << panas before file >> \ 3 | --save_path << save path >> 4 | 5 | python ./src/calculate_panas_score.py \ 6 | --input_path << panas after file >> \ 7 | --save_path << save path >> -------------------------------------------------------------------------------- /scripts/geval_panas_after.sh: -------------------------------------------------------------------------------- 1 | api_key_yaml="<< Path to the OpenAI API key yaml file >>" 2 | 3 | data_path="<< Path to the result file to evaluate >>" 4 | save_folder="<< Dir to save the result >>" 5 | 6 | python ./src/geval_panas_after.py \ 7 | --input_path $data_path \ 8 | --prompt './prompts/panas_after.txt' \ 9 | --save_dir $save_folder \ 10 | --api_key_path $api_key_yaml -------------------------------------------------------------------------------- /scripts/geval_panas_before.sh: -------------------------------------------------------------------------------- 1 | api_key_yaml="<< Path to the OpenAI API key yaml file >>" 2 | 3 | data_path="<< Path to the result file to evaluate >>" 4 | save_folder="<< Dir to save the result >>" 5 | 6 | python ./src/geval_panas_before.py \ 7 | --input_path $data_path \ 8 | --prompt './prompts/panas_before.txt' \ 9 | --save_dir $save_folder \ 10 | --api_key_path $api_key_yaml -------------------------------------------------------------------------------- /scripts/geval_total.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | export OPENAI_API_KEY="<< Your OpenAI API key >>" 3 | 4 | result_folder="<< Dir to save the result >>" 5 | criteria=("general_1_understanding" "general_2_interpersonal_effectiveness" "general_3_collaboration" "CBT_1_guided_discovery" "CBT_2_focus" "CBT_3_strategy") 6 | 7 | for crt in "${criteria[@]}"; do 8 | python src/evaluation.py \ 9 | --model_name "gpt-4o" \ 10 | --input_path ${result_folder}/results.json \ 11 | --prompt_name ./prompts/${crt}.txt \ 12 | --max_tokens 256 \ 13 | --save_dir ${result_folder}/score_${crt}.json 14 | done 15 | 16 | criteria_str="${criteria[@]}" 17 | python src/get_score.py \ 18 | --result_folder ${result_folder} \ 19 | --criteria_list "${criteria_str}" -------------------------------------------------------------------------------- /scripts/inference.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | PROJECT_ROOT=$(dirname $(dirname $(realpath "$0"))) 4 | 5 | ############ Your Parameters ############ 6 | INPUT_FILE="$PROJECT_ROOT/resource/dataset/evaluation.json" 7 | OUTPUT_DIR="$PROJECT_ROOT/output/cactus-chatgpt" 8 | COUNSELOR_TYPE="cactus" 9 | LLM_TYPE="chatgpt" 10 | MAX_TURNS=20 11 | ######################################### 12 | 13 | mkdir -p "$OUTPUT_DIR" 14 | 15 | export PYTHONPATH="$PROJECT_ROOT" 16 | 17 | cd "$PROJECT_ROOT" || exit 18 | 19 | # Activate the Python virtual environment 20 | source "$PROJECT_ROOT/.venv/bin/activate" 21 | 22 |
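# Note: --num_processes below appears to control how many counseling
# sessions are simulated in parallel (an assumption based on the flag
# name); consider lowering it if you hit OpenAI API rate limits.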
python "$PROJECT_ROOT/src/inference.py" \ 23 | --input_file "$INPUT_FILE" \ 24 | --output_dir "$OUTPUT_DIR" \ 25 | --num_processes 50 \ 26 | --counselor_type "$COUNSELOR_TYPE" \ 27 | --llm_type "$LLM_TYPE" \ 28 | --max_turns "$MAX_TURNS" 29 | -------------------------------------------------------------------------------- /scripts/run_vllm.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | GPU_DEVICES="0,1,2,3" 4 | MODEL_PATH="<< Model Path>>" 5 | PARALLEL_SIZE=4 6 | PORT=9000 7 | 8 | export CUDA_VISIBLE_DEVICES=$GPU_DEVICES 9 | 10 | python -m vllm.entrypoints.openai.api_server \ 11 | --model $MODEL_PATH \ 12 | --tensor-parallel-size $PARALLEL_SIZE \ 13 | --seed 42 \ 14 | --port $PORT 15 | -------------------------------------------------------------------------------- /src/calculate_panas_score.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import json 3 | from copy import deepcopy 4 | 5 | 6 | def parse_args(): 7 | parser = argparse.ArgumentParser() 8 | parser.add_argument("--input_path", type=str) 9 | parser.add_argument("--save_path", type=str, required=True) 10 | args = parser.parse_args() 11 | return args 12 | 13 | 14 | def load_data(input_path): 15 | with open(input_path, 'r') as json_file: 16 | data = json.load(json_file) 17 | 18 | return data 19 | 20 | 21 | def calculate_score(data): 22 | criteria_list = [ 23 | "Interested", "Excited", "Strong", "Enthusiastic", "Proud", 24 | "Alert", "Inspired", "Determined", "Attentive", "Active", 25 | "Distressed", "Upset", "Guilty", "Scared", "Hostile", 26 | "Irritable", "Ashamed", "Nervous", "Jittery", "Afraid" 27 | ] 28 | score_dict = {} 29 | 30 | for cri in criteria_list: 31 | score_dict[cri] = [] 32 | 33 | score_per_attitude = { 34 | 'positive': deepcopy(score_dict), 35 | 'neutral': deepcopy(score_dict), 36 | 'negative': deepcopy(score_dict) 37 | } 38 | 39 | for i in range(len(data)): 40 | score_lines = data[i]['prediction'].split('\n\n') 41 | for line in score_lines: 42 | criteria = line.split(', ')[0] 43 | score = int(line.split(', ')[-1]) 44 | score_per_attitude[data[i]['attitude']][criteria].append(score) 45 | 46 | avg_score_per_attitude = {} 47 | for att in score_per_attitude.keys(): 48 | avg_score_dict = {} 49 | for key in score_per_attitude[att].keys(): 50 | avg_score_dict[key] = sum(score_per_attitude[att][key]) / len(score_per_attitude[att][key]) 51 | 52 | positive_score = [] 53 | for key in criteria_list[:10]: 54 | positive_score.append(avg_score_dict[key]) 55 | 56 | negative_score = [] 57 | for key in criteria_list[10:]: 58 | negative_score.append(avg_score_dict[key]) 59 | 60 | avg_score_dict['positive_criteria'] = sum(positive_score) / len(positive_score) 61 | avg_score_dict['negative_criteria'] = sum(negative_score) / len(negative_score) 62 | 63 | avg_score_per_attitude[att] = avg_score_dict 64 | 65 | return avg_score_per_attitude 66 | 67 | 68 | def save_file(args, avg_score_dict): 69 | with open(args.save_path, 'w') as json_file: 70 | json.dump(avg_score_dict, json_file) 71 | print("Saved the file to", args.save_path) 72 | 73 | 74 | if __name__ == "__main__": 75 | args = parse_args() 76 | data = load_data(args.input_path) 77 | avg_score_dict = calculate_score(data) 78 | 79 | save_file(args, avg_score_dict) 80 | -------------------------------------------------------------------------------- /src/evaluation.py: -------------------------------------------------------------------------------- 1 | import argparse 
2 | import asyncio
3 | import json
4 | import random
5 | 
6 | from langchain.schema import HumanMessage
7 | from langchain_community.chat_models import ChatOpenAI
8 | from tqdm import tqdm
9 | from tqdm.asyncio import tqdm_asyncio
10 | 
11 | from utils.config import load_prompt
12 | 
13 | TOTAL_COST = 0
14 | 
15 | 
16 | def parse_args():
17 |     parser = argparse.ArgumentParser()
18 |     parser.add_argument("--model_name", type=str, default="gpt-3.5-turbo", help="gpt-3.5-turbo, gpt-4, or gpt-4o")
19 |     parser.add_argument("--input_path", type=str)
20 |     parser.add_argument("--prompt_name", type=str, default=None)
21 |     parser.add_argument("--start_idx", type=int, default=0,
22 |                         help="If you want to start from a specific index, set this argument")
23 |     parser.add_argument("--save_dir", type=str, required=True)
24 |     parser.add_argument("--num_sample", type=int, default=None,
25 |                         help="If you want to test your code by sampling a small number of data, you can set this argument.")
26 |     parser.add_argument("--num_shot", type=int, default=None)
27 |     ## generate args ##
28 |     parser.add_argument("--temperature", type=float, default=0.0)
29 |     parser.add_argument("--max_tokens", type=int, default=512)
30 |     parser.add_argument("--top_p", type=float, default=1)
31 |     parser.add_argument("--frequency_penalty", type=float, default=0.0)
32 |     parser.add_argument("--stop_sequence", type=str, nargs='+', default=None)
33 |     parser.add_argument("--sampling_num", type=int, default=1, help="The number of samples to generate per instance")
34 |     args = parser.parse_args()
35 | 
36 |     return args
37 | 
38 | 
39 | def get_dialogue_history(dialogue_history_list):
40 |     dialogue_history_tmp = []
41 |     for item in dialogue_history_list:
42 |         if item['role'] == 'counselor':
43 |             text = 'Counselor: ' + item['message']
44 |         else:
45 |             text = 'Client: ' + item['message']
46 |         dialogue_history_tmp.append(text)
47 | 
48 |     dialogue_history = '\n'.join(dialogue_history_tmp)
49 | 
50 |     return dialogue_history
51 | 
52 | 
53 | def prepare_model_input(prompt: str, data_path: str):
54 |     '''
55 |     input : prompt, data_path (str)
56 |     output : all_model_data (list of dict)
57 |     '''
58 | 
59 |     with open(data_path, "r", encoding="UTF-8") as f:
60 |         data = json.load(f)
61 | 
62 |     all_model_data = []
63 |     for d in tqdm(data):
64 |         input_temp = dict()
65 |         input_temp["idx"] = d["idx"]
66 | 
67 |         input_temp['model_input'] = prompt.format(**{
68 |             'conversation': get_dialogue_history(d["dialogue"])
69 |         })
70 | 
71 |         all_model_data.append(input_temp)
72 | 
73 |     return all_model_data
74 | 
75 | 
76 | def load_and_prepare_data(args):
77 |     prompt = load_prompt(args.prompt_name)
78 |     print("Preparing model inputs...")
79 |     all_model_data = prepare_model_input(
80 |         prompt, args.input_path)
81 |     return all_model_data
82 | 
83 | 
84 | def sample_indices(all_model_inputs, num_sample):
85 |     random.seed(0)
86 |     cand_indices = list(range(len(all_model_inputs)))
87 |     sampled_indices = random.sample(cand_indices, num_sample)
88 |     return sampled_indices
89 | 
90 | 
91 | def filter_data(all_model_data, num_sample):
92 |     if num_sample:
93 |         sampled_indices = sample_indices(all_model_data, num_sample)
94 |         all_model_data = [all_model_data[i] for i in sampled_indices]
95 |     return all_model_data
96 | 
97 | 
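# NOTE: each criterion prompt instructs the judge to output a single
# "score, explanation" line with no prefix; get_score.py later recovers the
# number with int(score.split(",")[0]), so that output contract is
# load-bearing for the aggregation step.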
98 | async def async_generate(args, llm, model_data, idx):
99 |     global TOTAL_COST
100 |     human_message = HumanMessage(content=model_data['model_input'])
101 |     while True:
102 |         try:
103 |             response = await llm.agenerate([[human_message]])
104 |             token_used = response.llm_output['token_usage']['total_tokens']
105 | 
106 |             if args.model_name == "gpt-3.5-turbo":
107 |                 TOTAL_COST += token_used / 1000 * 0.002
108 |             elif args.model_name == "gpt-4":
109 |                 TOTAL_COST += token_used / 1000 * 0.06
110 |             elif args.model_name == "gpt-4o":
111 |                 TOTAL_COST += token_used / 1000 * 0.02
112 |             print(idx, TOTAL_COST)
113 |             break
114 | 
115 |         except Exception as e:
116 |             print(f"Exception occurred: {e}")
117 | 
118 |             await asyncio.sleep(2)  # brief backoff before retrying
119 |     result = {
120 |         "idx": model_data["idx"],
121 |         "score": response.generations[0][0].text,
122 |     }
123 | 
124 |     return result
125 | 
126 | 
127 | async def generate_concurrently(args, all_model_data, start_idx):
128 |     llm = ChatOpenAI(
129 |         model_name=args.model_name,
130 |         temperature=args.temperature,
131 |         max_tokens=args.max_tokens,
132 |         max_retries=100,
133 |         top_p=args.top_p,
134 |         frequency_penalty=args.frequency_penalty,
135 |         n=args.sampling_num,
136 |     )
137 |     tasks = [async_generate(args, llm, model_data, i + start_idx)
138 |              for i, model_data in enumerate(all_model_data)]
139 | 
140 |     await asyncio.sleep(2)
141 |     return await tqdm_asyncio.gather(*tasks)
142 | 
143 | 
144 | async def main(args):
145 |     all_model_data = load_and_prepare_data(args)
146 | 
147 |     if args.num_sample:
148 |         all_model_data = all_model_data[:args.num_sample]
149 | 
150 |     all_results = []
151 |     batch_num = 30
152 |     if len(all_model_data) - args.start_idx > batch_num:
153 |         for start_idx in tqdm(range(args.start_idx, len(all_model_data), batch_num)):
154 |             cur_model_data = all_model_data[start_idx:start_idx + batch_num]
155 |             all_results.extend(await generate_concurrently(args, cur_model_data, start_idx))
156 |             await asyncio.sleep(2)
157 |     else:
158 |         all_results = await generate_concurrently(args, all_model_data, args.start_idx)
159 | 
160 |     with open(args.save_dir, "w", encoding='UTF-8') as f:
161 |         json.dump(all_results, f, indent=4, ensure_ascii=False)
162 | 
163 | 
164 | if __name__ == "__main__":
165 |     args = parse_args()
166 |     asyncio.run(main(args))
167 | 
--------------------------------------------------------------------------------
/src/get_score.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import json
3 | 
4 | 
5 | def parse_args():
6 |     parser = argparse.ArgumentParser()
7 |     parser.add_argument("--result_folder", type=str)
8 |     parser.add_argument('--criteria_list', type=str, nargs='+', help='List of criteria')
9 |     args = parser.parse_args()
10 | 
11 |     return args
12 | 
13 | 
14 | def get_scoring(args):
15 |     criteria_list = args.criteria_list[0].split(" ")  # geval_total.sh passes the criteria as one space-separated string
16 | 
17 |     total_results = {}
18 |     for criteria in criteria_list:
19 |         with open(f"{args.result_folder}/score_{criteria}.json", "r") as f:
20 |             dataset = json.load(f)
21 | 
22 |         avg_score = 0
23 |         for data in dataset:
24 |             try:
25 |                 score = int(data["score"].split(",")[0])
26 |             except Exception as e:
27 |                 print(e)
28 |                 score = 0
29 |             avg_score += score
30 | 
31 |         total_results[criteria] = avg_score / len(dataset)
32 | 
33 |     for criteria in criteria_list:
34 |         print(f"{criteria} : {round(total_results[criteria], 2)}")
35 | 
36 | 
37 | if __name__ == "__main__":
38 |     args = parse_args()
39 |     get_scoring(args)
40 | 
--------------------------------------------------------------------------------
/src/geval_panas_after.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import asyncio
3 | import json
4 | import os
5 | import random
6 | from copy import deepcopy
7 | 
8 | import yaml
9 | from langchain_community.chat_models import ChatOpenAI
10 | from langchain.schema import SystemMessage
11 | from tqdm import tqdm
12 | from tqdm.asyncio import tqdm_asyncio
13 | 
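# Asks gpt-3.5-turbo to fill in a PANAS questionnaire for each completed
# counseling session. One <idx>.json is written per item, plus a
# "<save_dir>_total_results.json" rollup intended for aggregation
# (see scripts/calculate_panas_score.sh).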
14 | TOTAL_COST = 0  # making this a global variable, be aware this may lead to issues in concurrent scenarios
15 | 
16 | 
17 | def parse_args():
18 |     parser = argparse.ArgumentParser()
19 |     parser.add_argument("--api_key_path", default='./conf.d/config.yaml')
20 |     parser.add_argument("--input_path", type=str)
21 |     parser.add_argument("--prompt", type=str,
22 |                         default='./prompts/panas_after.txt')
23 |     parser.add_argument("--save_dir", type=str, required=True,
24 |                         help="It should be a NEW DIRECTORY. Please do not use an existing directory.")
25 |     parser.add_argument("--num_sample", type=int, default=None,
26 |                         help="If you want to test your code by sampling a small number of data, you can set this argument.")
27 |     args = parser.parse_args()
28 | 
29 |     if args.num_sample:
30 |         args.save_dir = args.save_dir + f"_sample{args.num_sample}"
31 | 
32 |     return args
33 | 
34 | 
35 | def load_prompt(args):
36 |     """
37 |     Load .txt file as a prompt.
38 |     """
39 |     if args.prompt:
40 |         with open(args.prompt, 'r') as f:
41 |             prompt = f.read()
42 |         return prompt
43 | 
44 | 
45 | def get_api_key(api_key_path):
46 |     with open(api_key_path, 'r') as f:
47 |         config = yaml.safe_load(f)
48 |     os.environ['OPENAI_API_KEY'] = config['openai']['key']
49 | 
50 | 
51 | def prepare_model_input(prompt: str, data_path: str):
52 |     '''
53 |     input : prompt, data_path (str)
54 |     output : all_model_data (list of dict)
55 |     '''
56 |     print("Loading data for PANAS scoring...")
57 |     with open(data_path, 'r') as json_file:
58 |         data = json.load(json_file)
59 | 
60 |     all_model_data = []
61 |     for i in range(len(data)):
62 |         input_temp = dict()
63 |         input_temp['id'] = i
64 |         input_temp['model_input'] = prompt.format(**{
65 |             "intake_form": data[i]['intake_form'],
66 |             "dialogue": prepare_dialogue(data[i]['dialogue'])
67 |         })
68 |         for key in data[i].keys():
69 |             input_temp[key] = data[i][key]
70 |         all_model_data.append(input_temp)
71 |     return all_model_data
72 | 
73 | 
74 | def prepare_dialogue(dialogue_list):
75 |     processed_turns = []
76 |     for i in range(len(dialogue_list)):
77 |         turn_text = f"{dialogue_list[i]['role']}: {dialogue_list[i]['message']}"
78 |         processed_turns.append(turn_text)
79 |     return "\n".join(processed_turns)
80 | 
81 | 
82 | def load_and_prepare_data(args):
83 |     prompt = load_prompt(args)
84 |     print("Preparing model inputs...")
85 | 
86 |     all_model_data = prepare_model_input(
87 |         prompt, args.input_path)
88 |     return all_model_data
89 | 
90 | 
91 | def sample_indices(all_model_inputs, num_sample):
92 |     random.seed(0)
93 |     cand_indices = list(range(len(all_model_inputs)))
94 |     sampled_indices = random.sample(cand_indices, num_sample)
95 |     return sampled_indices
96 | 
97 | 
98 | def filter_data(all_model_data, num_sample):
99 |     if num_sample:
100 |         sampled_indices = sample_indices(all_model_data, num_sample)
101 |         all_model_data = [all_model_data[i] for i in sampled_indices]
102 |     return all_model_data
103 | 
104 | 
105 | async def async_generate(llm, model_data, idx, save_dir):
106 |     global TOTAL_COST
107 |     system_message = SystemMessage(content=model_data['model_input'])
108 |     while True:
109 |         try:
110 |             response = await llm.agenerate([[system_message]])
111 |             token_used = response.llm_output['token_usage']['total_tokens']
112 |             TOTAL_COST += token_used / 1000 * 0.002  # gpt-3.5-turbo pricing
113 |             print(idx, TOTAL_COST)
114 |             break
115 |         except Exception as e:
116 |             print(f"Exception occurred: {e}")
occurred: {e}") 117 | 118 | result = deepcopy(model_data) 119 | result['prediction'] = response.generations[0][0].text 120 | with open(os.path.join(save_dir, f"{idx}.json"), "w", 121 | encoding='UTF-8') as f: 122 | json.dump(result, f, indent=4, ensure_ascii=False) 123 | return result 124 | 125 | 126 | async def generate_concurrently(all_model_data, start_idx, save_dir): 127 | llm = ChatOpenAI( 128 | model_name='gpt-3.5-turbo', # 'gpt-3.5-turbo' or 'gpt-4' 129 | temperature=1.0, 130 | max_tokens=1500, 131 | max_retries=100, 132 | ) 133 | tasks = [async_generate(llm, model_data, i + start_idx, save_dir) 134 | for i, model_data in enumerate(all_model_data)] 135 | 136 | return await tqdm_asyncio.gather(*tasks) 137 | 138 | 139 | async def main(args): 140 | all_model_data = load_and_prepare_data(args) 141 | 142 | if os.path.exists(args.save_dir): 143 | print("The save_dir already exists. Please change the save_dir.") 144 | 145 | os.makedirs(args.save_dir, exist_ok=True) 146 | all_results = [] 147 | if len(all_model_data) > 300: 148 | for start_idx in tqdm(range(0, len(all_model_data), 300)): 149 | cur_model_data = all_model_data[start_idx:start_idx + 300] 150 | all_results.extend( 151 | await generate_concurrently(cur_model_data, start_idx, 152 | args.save_dir)) 153 | else: 154 | all_results = await generate_concurrently(all_model_data, 0, 155 | args.save_dir) 156 | 157 | total_result_path = args.save_dir + "_total_results.json" 158 | with open(os.path.join(total_result_path), "w", encoding='UTF-8') as f: 159 | json.dump(all_results, f, indent=4, ensure_ascii=False) 160 | 161 | 162 | if __name__ == "__main__": 163 | args = parse_args() 164 | get_api_key(args.api_key_path) 165 | asyncio.run(main(args)) 166 | -------------------------------------------------------------------------------- /src/geval_panas_before.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import asyncio 3 | import json 4 | import os 5 | import random 6 | from copy import deepcopy 7 | 8 | import yaml 9 | from langchain.chat_models import ChatOpenAI 10 | from langchain.schema import SystemMessage 11 | from tqdm import tqdm 12 | from tqdm.asyncio import tqdm_asyncio 13 | 14 | TOTAL_COST = 0 # making this a global variable, be aware this may lead to issues in concurrent scenarios 15 | 16 | 17 | def parse_args(): 18 | parser = argparse.ArgumentParser() 19 | parser.add_argument("--api_key_path", default='./conf.d/config.yaml') 20 | parser.add_argument("--input_path", type=str) 21 | parser.add_argument("--prompt", type=str, 22 | default='./prompts/panas_scoring_before_session.txt') 23 | parser.add_argument("--save_dir", type=str, required=True, 24 | help="It should be a NEW DIRECTORY. Please do not use an existing") 25 | parser.add_argument("--num_sample", type=int, default=None, 26 | help="If you want to test your code by sampling a small number of data, you can set this argument.") 27 | args = parser.parse_args() 28 | 29 | if args.num_sample: 30 | args.save_dir = args.save_dir + f"_sample{args.num_sample}" 31 | 32 | return args 33 | 34 | 35 | def load_prompt(args): 36 | """ 37 | Load .txt file as a prompt. 
38 | """ 39 | if args.prompt: 40 | with open(args.prompt, 'r') as f: 41 | prompt = f.read() 42 | return prompt 43 | 44 | 45 | def get_api_key(api_key_path): 46 | with open(api_key_path, 'r') as f: 47 | config = yaml.safe_load(f) 48 | os.environ['OPENAI_API_KEY'] = config['openai']['key'] 49 | 50 | 51 | def prepare_model_input(prompt: str, data_path: str): 52 | ''' 53 | input : prompt, data_path (str) 54 | output : all_model_data (list of dict) 55 | ''' 56 | print("Loading data for translation...") 57 | with open(data_path, 'r') as json_file: 58 | data = json.load(json_file) 59 | 60 | all_model_data = [] 61 | for i in range(len(data)): 62 | input_temp = dict() 63 | input_temp['id'] = i 64 | input_temp['model_input'] = prompt.format(**{ 65 | "intake_form": data[i]['intake_form'] 66 | }) 67 | for key in data[i].keys(): 68 | input_temp[key] = data[i][key] 69 | all_model_data.append(input_temp) 70 | return all_model_data 71 | 72 | 73 | def load_and_prepare_data(args): 74 | prompt = load_prompt(args) 75 | print("Preparing model inputs...") 76 | 77 | all_model_data = prepare_model_input( 78 | prompt, args.input_path) 79 | return all_model_data 80 | 81 | 82 | def sample_indices(all_model_inputs, num_sample): 83 | random.seed(0) 84 | cand_indices = list(range(len(all_model_inputs))) 85 | sampled_indices = random.sample(cand_indices, num_sample) 86 | return sampled_indices 87 | 88 | 89 | def filter_data(all_model_data, num_sample): 90 | if num_sample: 91 | sampled_indices = sample_indices(all_model_data, num_sample) 92 | all_model_data = [all_model_data[i] for i in sampled_indices] 93 | return all_model_data 94 | 95 | 96 | async def async_generate(llm, model_data, idx, save_dir): 97 | global TOTAL_COST 98 | system_message = SystemMessage(content=model_data['model_input']) 99 | # human_message = HumanMessage(content=model_input) # if you need it 100 | while True: 101 | try: 102 | response = await llm.agenerate([[system_message]]) 103 | token_used = response.llm_output['token_usage']['total_tokens'] 104 | TOTAL_COST += token_used / 1000 * 0.002 # gpt-3.5-turbo 105 | print(idx, TOTAL_COST) 106 | break 107 | except Exception as e: 108 | print(f"Exception occurred: {e}") 109 | response = None 110 | 111 | result = deepcopy(model_data) 112 | result['prediction'] = response.generations[0][0].text 113 | with open(os.path.join(save_dir, f"{idx}.json"), "w", 114 | encoding='UTF-8') as f: 115 | json.dump(result, f, indent=4, ensure_ascii=False) 116 | return result 117 | 118 | 119 | async def generate_concurrently(all_model_data, start_idx, save_dir): 120 | llm = ChatOpenAI(model_name='gpt-3.5-turbo', # 'gpt-3.5-turbo' or 'gpt-4' 121 | temperature=1.0, max_tokens=1500, max_retries=100) 122 | tasks = [async_generate(llm, model_data, i + start_idx, save_dir) 123 | for i, model_data in enumerate(all_model_data)] 124 | return await tqdm_asyncio.gather(*tasks) 125 | 126 | 127 | async def main(args): 128 | all_model_data = load_and_prepare_data(args) 129 | if os.path.exists(args.save_dir): 130 | print("The save_dir already exists. 
Please change the save_dir.") 131 | 132 | os.makedirs(args.save_dir, exist_ok=True) 133 | all_results = [] 134 | if len(all_model_data) > 300: 135 | for start_idx in tqdm(range(0, len(all_model_data), 300)): 136 | cur_model_data = all_model_data[start_idx:start_idx + 300] 137 | all_results.extend( 138 | await generate_concurrently(cur_model_data, start_idx, 139 | args.save_dir)) 140 | else: 141 | all_results = await generate_concurrently(all_model_data, 0, 142 | args.save_dir) 143 | 144 | total_result_path = args.save_dir + "_total_results.json" 145 | with open(os.path.join(total_result_path), "w", encoding='UTF-8') as f: 146 | json.dump(all_results, f, indent=4, ensure_ascii=False) 147 | 148 | 149 | if __name__ == "__main__": 150 | args = parse_args() 151 | get_api_key(args.api_key_path) 152 | asyncio.run(main(args)) 153 | -------------------------------------------------------------------------------- /src/inference.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import json 3 | import multiprocessing 4 | import re 5 | import traceback 6 | from abc import ABC, abstractmethod 7 | from pathlib import Path 8 | 9 | import requests 10 | from langchain.prompts import PromptTemplate 11 | from langchain_community.chat_models import ChatOpenAI 12 | from langchain_community.llms import OpenAI 13 | 14 | from utils.config import get_config 15 | 16 | 17 | class LLM: 18 | @abstractmethod 19 | def generate(self, prompt: str): 20 | pass 21 | 22 | 23 | class ChatGPT(LLM): 24 | def __init__(self): 25 | config = get_config() 26 | api_key = config['openai']['key'] 27 | self.llm = ChatOpenAI( 28 | temperature=0.7, 29 | model_name="gpt-3.5-turbo", 30 | max_tokens=512, 31 | openai_api_key=api_key, 32 | ) 33 | 34 | def generate(self, prompt: str) -> str: 35 | response = self.llm.invoke(prompt) 36 | return response.content 37 | 38 | 39 | class LLama2(LLM): 40 | def __init__(self): 41 | self.config = get_config() 42 | self.model_url = f"{self.config['llama2']['host']}/models" 43 | self.host = self.config['llama2']['host'] 44 | self.llm = OpenAI( 45 | temperature=0.7, 46 | openai_api_key='EMPTY', 47 | openai_api_base=self.host, 48 | model=self.get_model_name(), 49 | max_tokens=512, 50 | ) 51 | 52 | def get_model_name(self): 53 | response = requests.get(self.model_url) 54 | response = response.json() 55 | return response["data"][0]["id"] 56 | 57 | def generate(self, prompt: str) -> str: 58 | return self.llm.invoke(prompt) 59 | 60 | 61 | class LLama3(LLama2): 62 | def __init__(self): 63 | super().__init__() 64 | self.model_url = f"{self.config['llama3']['host']}/models" 65 | self.host = self.config['llama3']['host'] 66 | self.llm = OpenAI( 67 | temperature=0.7, 68 | openai_api_key='EMPTY', 69 | openai_api_base=self.host, 70 | max_tokens=512, 71 | model=self.get_model_name() 72 | ) 73 | 74 | 75 | class Agent(ABC): 76 | def __init__(self, llm_type): 77 | self.llm = LLMFactory.get_llm(llm_type) 78 | self.prompt_template = None 79 | 80 | @abstractmethod 81 | def generate(self, *args, **kwargs): 82 | pass 83 | 84 | def load_prompt(self, file_name): 85 | base_dir = Path(__file__).resolve().parents[1] / "prompts" 86 | file_path = base_dir / file_name 87 | with open(file_path, "r", encoding="utf-8") as file: 88 | return file.read() 89 | 90 | 91 | class ClientAgent(Agent): 92 | def __init__(self, example): 93 | super().__init__('chatgpt') 94 | self.example = example 95 | prompt_text = self.load_prompt(f"agent_client.txt") 96 | self.attitude = ( 97 | 
f"{self.example['AI_client']['attitude']}: " 98 | f"{self.example['AI_client']['attitude_instruction']}") 99 | self.prompt_template = PromptTemplate( 100 | input_variables=["intake_form", "attitude", "history"], 101 | template=prompt_text) 102 | 103 | def generate(self, history): 104 | history_text = '\n'.join( 105 | [ 106 | f"{message['role'].capitalize()}: {message['message']}" 107 | for message in history 108 | ] 109 | ) 110 | prompt = self.prompt_template.format( 111 | intake_form=self.example, 112 | attitude=self.attitude, 113 | history=history_text 114 | ) 115 | 116 | return self.llm.generate(prompt) 117 | 118 | 119 | class CBTAgent(Agent): 120 | def __init__(self, llm_type, example): 121 | super().__init__(llm_type) 122 | self.example = example 123 | self.pattern = r"CBT technique:\s*(.*?)\s*Counseling plan:\s*(.*)" 124 | prompt_text = self.load_prompt(f"agent_cbt_{llm_type}.txt") 125 | self.prompt_template = PromptTemplate( 126 | input_variables=[ 127 | "client_information", 128 | "reason_counseling", 129 | 'history', 130 | ], 131 | template=prompt_text) 132 | 133 | def generate(self, history): 134 | prompt = self.prompt_template.format( 135 | client_information=self.example['AI_counselor']['CBT'][ 136 | 'client_information'], 137 | reason_counseling=self.example['AI_counselor']['CBT'][ 138 | 'reason_counseling'], 139 | history="Client: " + history 140 | ) 141 | response = self.llm.generate(prompt) 142 | 143 | try: 144 | cbt_technique = response.split("Counseling")[0].replace("\n", "") 145 | except Exception as e: 146 | cbt_technique = None 147 | print(e) 148 | 149 | try: 150 | cbt_plan = response.split("Counseling")[1].split(":\n")[1] 151 | except Exception as e: 152 | cbt_plan = None 153 | print(e) 154 | 155 | if cbt_plan: 156 | return cbt_technique, cbt_plan 157 | else: 158 | error_file_path = Path( 159 | f"./invalid_response_{self.example[:10]}.txt") 160 | with open(error_file_path, "w", encoding="utf-8") as f: 161 | f.write(response) 162 | raise ValueError("Invalid response format from LLM") 163 | 164 | def extract_cbt_details(self, response): 165 | match = re.search(self.pattern, response, re.DOTALL | re.IGNORECASE) 166 | 167 | if not match: 168 | return None, None 169 | 170 | cbt_technique = match.group(1).strip() 171 | cbt_plan = match.group(2).strip() 172 | return cbt_technique, cbt_plan 173 | 174 | 175 | class CounselorAgent(Agent): 176 | def __init__(self, llm_type): 177 | super().__init__(llm_type) 178 | prompt_text = self.load_prompt(f"agent_cactus_{llm_type}.txt") 179 | self.prompt_template = PromptTemplate( 180 | input_variables=["history"], 181 | template=prompt_text) 182 | 183 | def generate(self, history): 184 | history = '\n'.join( 185 | [ 186 | f"{message['role'].capitalize()}: {message['message']}" 187 | for message in history 188 | ] 189 | ) 190 | prompt = self.prompt_template.format(history=history) 191 | return self.llm.generate(prompt) 192 | 193 | 194 | class CactusCounselorAgent(CounselorAgent): 195 | def __init__(self, example, llm_type): 196 | super().__init__(llm_type) 197 | self.example = example 198 | self.cbt_technique = None 199 | self.cbt_plan = None 200 | self.llm_type = llm_type 201 | prompt_text = self.load_prompt(f"agent_cactus_{llm_type}.txt") 202 | self.prompt_template = PromptTemplate( 203 | input_variables=[ 204 | "client_information", 205 | "reason_counseling", 206 | "cbt_plan", 207 | "history" 208 | ], 209 | template=prompt_text) 210 | 211 | def set_cbt(self, history): 212 | cbt_agent = CBTAgent(self.llm_type, self.example) 213 | 
213 |         self.cbt_technique, self.cbt_plan = cbt_agent.generate(history)
214 | 
215 |     def generate(self, history):
216 |         history_text = '\n'.join(
217 |             [
218 |                 f"{message['role'].capitalize()}: {message['message']}"
219 |                 for message in history
220 |             ]
221 |         )
222 |         prompt = self.prompt_template.format(
223 |             client_information=self.example['AI_counselor']['CBT'][
224 |                 'client_information'],
225 |             reason_counseling=self.example['AI_counselor']['CBT'][
226 |                 'reason_counseling'],
227 |             cbt_plan=self.cbt_plan,
228 |             history=history_text,
229 |         )
230 | 
231 |         response = self.llm.generate(prompt)
232 | 
233 |         if "'message':" in response:
234 |             response = self.clean_message(response)
235 | 
236 |         response = self.extract_counselor_message(response)
237 |         return response.strip()
238 | 
239 |     def clean_message(self, response):
240 |         response = response.split("'message':")[1]
241 |         response = response.split(", {")[0]
242 |         response = response.replace("\"", "")
243 |         response = response.replace("]", "")
244 |         response = response.replace("}", "")
245 |         return response
246 | 
247 |     def extract_counselor_message(self, response):
248 |         response = response.split("Counselor:")[-1]
249 |         response = response.replace("\n", "")
250 |         response = response.replace("\\", "")
251 |         response = response.replace("\"", "")
252 |         return response
253 | 
254 | 
255 | class Psych8kCounselorAgent(CounselorAgent):
256 |     def __init__(self, llm_type):
257 |         super().__init__(llm_type)
258 |         prompt_text = self.load_prompt(f"agent_psych8k_{llm_type}.txt")
259 |         self.prompt_template = PromptTemplate(
260 |             input_variables=["history"],
261 |             template=prompt_text)
262 | 
263 |     def generate(self, history):
264 |         history = '\n'.join(
265 |             [
266 |                 f"{message['role'].capitalize()}: {message['message']}"
267 |                 for message in history
268 |             ]
269 |         )
270 |         prompt = self.prompt_template.format(history=history)
271 |         response = self.llm.generate(prompt)
272 |         response = response.replace('Output:', '')
273 |         response = response.replace('Counselor:', '')
274 |         response = response.strip()
275 |         return response
276 | 
277 | 
278 | class SmileCounselorAgent(CounselorAgent):
279 |     def __init__(self, llm_type):
280 |         super().__init__(llm_type)
281 |         prompt_text = self.load_prompt(f"agent_smilechat_{llm_type}.txt")  # matches prompts/agent_smilechat_*.txt
282 |         self.prompt_template = PromptTemplate(
283 |             input_variables=["history"],
284 |             template=prompt_text)
285 | 
286 | 
287 | class LLMFactory:
288 |     @staticmethod
289 |     def get_llm(llm_type):
290 |         if llm_type == "chatgpt":
291 |             return ChatGPT()
292 |         elif llm_type == "llama2":
293 |             return LLama2()
294 |         elif llm_type == "llama3":
295 |             return LLama3()
296 |         raise ValueError(f"Unsupported LLM type: {llm_type}")
297 | 
298 | 
299 | class TherapySession:
300 |     def __init__(self, example, counselor_type, counselor_llm_type, max_turns):
301 |         self.counselor_type = counselor_type
302 |         self.example = example
303 |         self.client_agent = ClientAgent(example=example)
304 |         self.counselor_agent = self._create_counselor_agent(
305 |             counselor_type,
306 |             counselor_llm_type)
307 |         self.history = []
308 |         self.max_turns = max_turns
309 | 
310 |     def _create_counselor_agent(self, counselor_type, llm_type):
311 |         if counselor_type == "cactus":
312 |             return CactusCounselorAgent(self.example, llm_type)
313 |         elif counselor_type == "psych8k":
314 |             return Psych8kCounselorAgent(llm_type)
315 |         elif counselor_type == "smile":
316 |             return SmileCounselorAgent(llm_type)
317 |         else:
318 |             raise ValueError(f"Unsupported counselor type: {counselor_type}")
319 | 
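    # The helpers below drive a session: seed the two scripted opening turns,
    # then alternate counselor and client generations until the client emits
    # [/END] or max_turns is reached.
320 |     def 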
_add_to_history(self, role, message): 321 | self.history.append({"role": role, "message": message}) 322 | 323 | def _initialize_session(self): 324 | example_cbt = self.example['AI_counselor']['CBT'] 325 | self._add_to_history("counselor", 326 | example_cbt['init_history_counselor']) 327 | self._add_to_history("client", example_cbt['init_history_client']) 328 | if self.counselor_type == 'cactus': 329 | self.counselor_agent.set_cbt(example_cbt['init_history_client']) 330 | 331 | def _exchange_statements(self): 332 | 333 | for turn in range(self.max_turns): 334 | counselor_statement = self.counselor_agent.generate(self.history) 335 | counselor_statement = counselor_statement.replace('Counselor: ', 336 | '') 337 | self._add_to_history("counselor", counselor_statement) 338 | 339 | client_statement = self.client_agent.generate(self.history) 340 | client_statement = client_statement.replace('Client: ', '') 341 | 342 | self._add_to_history("client", client_statement) 343 | 344 | if '[/END]' in client_statement: 345 | self.history[-1]['message'] = self.history[-1][ 346 | 'message'].replace('[/END]', '') 347 | break 348 | 349 | def run_session(self): 350 | self._initialize_session() 351 | self._exchange_statements() 352 | return { 353 | "example": self.example, 354 | "cbt_technique": getattr( 355 | self.counselor_agent, 356 | 'cbt_technique', 357 | None 358 | ), 359 | "cbt_plan": getattr(self.counselor_agent, 'cbt_plan', None), 360 | "history": self.history 361 | } 362 | 363 | 364 | def run_therapy_session(index, example, output_dir, 365 | counselor_type, llm_type, total, max_turns): 366 | output_dir = Path(output_dir) 367 | file_number = index + 1 368 | 369 | try: 370 | print(f"Generating example {file_number} out of {total}") 371 | 372 | therapy_session = TherapySession( 373 | example, 374 | counselor_type, 375 | llm_type, 376 | max_turns, 377 | ) 378 | session_data = therapy_session.run_session() 379 | 380 | file_name = f"session_{file_number}.json" 381 | file_path = output_dir / file_name 382 | 383 | with open(file_path, "w", encoding="utf-8") as f: 384 | json.dump(session_data, f, ensure_ascii=False, indent=4) 385 | except Exception as e: 386 | error_file_name = f"error_{file_number}.txt" 387 | error_file_path = output_dir / error_file_name 388 | with open(error_file_path, "w", encoding="utf-8") as f: 389 | f.write("".join(traceback.format_exception(e))) 390 | 391 | 392 | if __name__ == "__main__": 393 | parser = argparse.ArgumentParser( 394 | description="Run therapy sessions in parallel.") 395 | parser.add_argument("--input_file", type=str, required=True, 396 | help="Path to the JSON file containing client intake forms.") 397 | parser.add_argument("--output_dir", type=str, default=".", 398 | help="Directory to save the session results.") 399 | parser.add_argument("--num_processes", type=int, default=None, 400 | help="Number of processes to use in the pool." 
401 |                              " Defaults to the number of CPU cores "
402 |                              "if not specified.")
403 |     parser.add_argument("--counselor_type", type=str, required=True,
404 |                         choices=["cactus", "psych8k", "smile"],
405 |                         help="Type of counselor to use.")
406 |     parser.add_argument("--llm_type", type=str, required=True,
407 |                         choices=["chatgpt", "llama2", "llama3"],
408 |                         help="Type of LLM to use.")
409 |     parser.add_argument("--max_turns", type=int, default=20,
410 |                         help="Maximum number of turns for the session.")
411 | 
412 |     args = parser.parse_args()
413 | 
414 |     with open(args.input_file, "r", encoding="utf-8") as f:
415 |         data = json.load(f)
416 | 
417 |     output_dir = Path(args.output_dir)
418 |     output_dir.mkdir(parents=True, exist_ok=True)
419 | 
420 |     total = len(data)
421 |     args_list = [(index, example, output_dir, args.counselor_type,
422 |                   args.llm_type, total, args.max_turns)
423 |                  for index, example in enumerate(data)]
424 | 
425 |     with multiprocessing.Pool(processes=args.num_processes) as pool:
426 |         pool.starmap(run_therapy_session, args_list)  # blocks until all sessions finish; per-session progress prints inside run_therapy_session
427 |     print(f"Finished generating {total} sessions.")
--------------------------------------------------------------------------------
/src/utils/config.py:
--------------------------------------------------------------------------------
1 | import os
2 | from pathlib import Path
3 | 
4 | import yaml
5 | 
6 | 
7 | def load_config(config_file):
8 |     with open(config_file, 'r') as f:
9 |         config = yaml.safe_load(f)
10 |     return config
11 | 
12 | 
13 | def get_config_file_path():
14 |     current_path = Path(__file__)
15 |     return str(current_path.parents[2] / 'conf.d' / 'config.yaml')
16 | 
17 | 
18 | def get_config():
19 |     config_path = get_config_file_path()
20 |     return load_config(config_path)
21 | 
22 | 
23 | def get_path():
24 |     return Path(__file__).parents[1]
25 | 
26 | 
27 | def get_api_key():
28 |     config = get_config()
29 |     os.environ['OPENAI_API_KEY'] = config['openai']['key']
30 | 
31 | 
32 | def load_prompt(name):
33 |     # Callers (e.g., geval_total.sh via evaluation.py) pass paths like
34 |     # './prompts/<criterion>.txt'; resolve them against the project root
35 |     # so the current working directory does not matter.
36 |     prompt_path = Path(__file__).resolve().parents[2] / name
37 |     with open(prompt_path, 'r') as f:
38 |         prompt = f.read()
39 |     return prompt
40 | 
--------------------------------------------------------------------------------
/test/test_inference.py:
--------------------------------------------------------------------------------
1 | import unittest
2 | from unittest.mock import patch, MagicMock
3 | 
4 | from src.inference import CactusCounselorAgent
5 | 
6 | 
7 | class TestCactusCounselorAgent(unittest.TestCase):
8 | 
9 |     def setUp(self):
10 |         # Mock example data
11 |         self.example = {
12 |             "AI_counselor": {
13 |                 "CBT": {
14 |                     "client_information": "Client Info",
15 |                     "reason_counseling": "Reason for counseling",
16 |                     "init_history_counselor": "Initial counselor history",
17 |                     "init_history_client": "Initial client history"
18 |                 }
19 |             }
20 |         }
21 | 
22 |         # Mock history data
23 |         self.history = [
24 |             {"role": "counselor", "message": "Initial counselor history"},
25 |             {"role": "client", "message": "Initial client history"}
26 |         ]
27 | 
28 |         # Expected prompt
29 |         self.expected_prompt = (
30 |             "You are playing the role of a counselor in a psychological counseling session. "
31 |             "Your task is to use the provided client information and counseling planning to generate "
32 |             "the next counselor utterance in the dialogue. 
The goal is to create a natural and engaging response " 33 | "that builds on the previous conversation and aligns with the counseling plan.\n\n" 34 | "Client Information:\nClient Info\n\n" 35 | "Reason for seekeing counseling:\nReason for counseling\n\n" 36 | "Counseling planning:\nCBT Plan\n\n" 37 | "Counseling Dialogue:\n" 38 | "Counselor: Initial counselor history\n" 39 | "Client: Initial client history" 40 | ) 41 | 42 | def _create_agent_and_generate_response(self, mock_llm, mock_cbt_agent, 43 | response_content): 44 | # Mock LLM generate function 45 | mock_llm.generate.return_value = response_content 46 | 47 | # Mock CBTAgent generate function 48 | mock_cbt_agent_instance = mock_cbt_agent.return_value 49 | mock_cbt_agent_instance.generate.return_value = ( 50 | "CBT Technique", "CBT Plan") 51 | 52 | # Create an instance of CactusCounselorAgent 53 | agent = CactusCounselorAgent(self.example, 'chatgpt') 54 | 55 | # Set CBT details 56 | agent.set_cbt(self.history) 57 | 58 | # Generate response 59 | response = agent.generate(self.history) 60 | 61 | return agent, response 62 | 63 | @patch('src.inference.CBTAgent') 64 | @patch('src.inference.LLMFactory.get_llm') 65 | def test_generate(self, mock_get_llm, mock_cbt_agent): 66 | mock_llm = MagicMock() 67 | mock_get_llm.return_value = mock_llm 68 | 69 | agent, response = self._create_agent_and_generate_response( 70 | mock_llm, 71 | mock_cbt_agent, 72 | "Counselor: Mock counselor response") 73 | 74 | # Assertions 75 | self.assertEqual(response, "Mock counselor response") 76 | self.assertEqual(agent.cbt_technique, "CBT Technique") 77 | self.assertEqual(agent.cbt_plan, "CBT Plan") 78 | 79 | try: 80 | mock_llm.generate.assert_called_with(self.expected_prompt) 81 | except AssertionError as e: 82 | print("Expected:", self.expected_prompt) 83 | print("Actual:", mock_llm.generate.call_args[0][0]) 84 | raise e 85 | 86 | @patch('src.inference.CBTAgent') 87 | @patch('src.inference.LLMFactory.get_llm') 88 | def test_generate_with_message_field(self, mock_get_llm, mock_cbt_agent): 89 | mock_llm = MagicMock() 90 | mock_get_llm.return_value = mock_llm 91 | 92 | agent, response = self._create_agent_and_generate_response( 93 | mock_llm, 94 | mock_cbt_agent, 95 | "{'message': 'Counselor: Mock counselor response with message field'}" 96 | ) 97 | 98 | response = response.replace("'", "").strip() 99 | 100 | expected_response = "Mock counselor response with message field" 101 | self.assertEqual(response, expected_response) 102 | self.assertEqual(agent.cbt_technique, "CBT Technique") 103 | self.assertEqual(agent.cbt_plan, "CBT Plan") 104 | 105 | try: 106 | mock_llm.generate.assert_called_with(self.expected_prompt) 107 | except AssertionError as e: 108 | print("Expected:", self.expected_prompt) 109 | print("Actual:", mock_llm.generate.call_args[0][0]) 110 | raise e 111 | 112 | 113 | if __name__ == '__main__': 114 | unittest.main() 115 | --------------------------------------------------------------------------------