├── .gitignore ├── LICENSE ├── README.md ├── __init__.py ├── app.py ├── assets ├── guimode_preview.gif ├── preview.gif └── preview.png ├── channels.txt ├── chats ├── __init__.py ├── alpaca.py ├── alpaca_gpt4.py ├── alpacoom.py ├── baize.py ├── central.py ├── custom.py ├── falcon.py ├── flan_alpaca.py ├── freewilly.py ├── guanaco.py ├── koalpaca.py ├── llama2.py ├── mistral.py ├── mpt.py ├── os_stablelm.py ├── post.py ├── pre.py ├── redpajama.py ├── remote_tgi.py ├── stable_vicuna.py ├── stablelm.py ├── starchat.py ├── utils.py ├── vicuna.py ├── wizard_coder.py ├── wizard_falcon.py └── xgen.py ├── configs ├── constraints_config.yaml ├── response_configs │ ├── baize.yaml │ ├── camel.yaml │ ├── default.yaml │ ├── default_4096.yaml │ ├── falcon.yaml │ ├── flan.yaml │ ├── freewilly.yaml │ ├── gpt4_alpaca.yaml │ ├── guanaco.yaml │ ├── koalpaca.yaml │ ├── llama2.yaml │ ├── mistral.yaml │ ├── mistral_openhermes.yaml │ ├── redpajama.yaml │ ├── stablelm.yaml │ ├── stackllama.yaml │ ├── starchat.yaml │ ├── t5_vicuna.yaml │ ├── upstage_llama.yaml │ ├── upstage_llama2.yaml │ ├── vicuna.yaml │ ├── wizard-coder.yaml │ └── wizardlm.yaml └── summarization_configs │ ├── camel.yaml │ ├── default.yaml │ ├── koalpaca.yaml │ ├── redpajama.yaml │ ├── stablelm.yaml │ └── t5_vicuna.yaml ├── discord.dstack.yml ├── discord_app.py ├── discordbot ├── flags.py ├── helps.py ├── post.py ├── req.py └── utils.py ├── dumb_utils.py ├── entry_point.py ├── examples.txt ├── gens ├── __init__.py └── batch_gen.py ├── global_vars.py ├── gradio.dstack.yml ├── miscs ├── __init__.py ├── js.py ├── strings.py ├── styles.py └── templates.py ├── model_cards.json ├── models ├── __init__.py ├── airoboros.py ├── alpaca.py ├── baize.py ├── bloom.py ├── byom.py ├── camel.py ├── falcon.py ├── flan_alpaca.py ├── freewilly.py ├── guanaco.py ├── koalpaca.py ├── kullm.py ├── llama_rlhf.py ├── mistral.py ├── mpt.py ├── redpajama.py ├── replit.py ├── samantha_vicuna.py ├── stablelm.py ├── starchat.py ├── t5_vicuna.py ├── vicuna.py ├── wizard_coder.py └── xgen.py ├── notebooks └── llm_as_chatbot_in_colab.ipynb ├── requirements.txt ├── scripts ├── hparams_explore.py └── requirements.txt └── utils.py /.gitignore: -------------------------------------------------------------------------------- 1 | .ipynb_checkpoints 2 | __pycache__ 3 | nohup.out 4 | test.py 5 | .dstack -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 
22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. 
If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. 
Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [yyyy] [name of copyright owner] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 
202 | -------------------------------------------------------------------------------- /__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/deep-diver/LLM-As-Chatbot/99c2c03efececba39a633589775f77989f93deff/__init__.py -------------------------------------------------------------------------------- /assets/guimode_preview.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/deep-diver/LLM-As-Chatbot/99c2c03efececba39a633589775f77989f93deff/assets/guimode_preview.gif -------------------------------------------------------------------------------- /assets/preview.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/deep-diver/LLM-As-Chatbot/99c2c03efececba39a633589775f77989f93deff/assets/preview.gif -------------------------------------------------------------------------------- /assets/preview.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/deep-diver/LLM-As-Chatbot/99c2c03efececba39a633589775f77989f93deff/assets/preview.png -------------------------------------------------------------------------------- /channels.txt: -------------------------------------------------------------------------------- 1 | 1st Channel 2 | 2nd Channel 3 | 3rd Channel 4 | 4th Channel 5 | 5th Channel 6 | 6th Channel 7 | 7th Channel 8 | 8th Channel 9 | 9th Channel 10 | 10th Channel -------------------------------------------------------------------------------- /chats/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/deep-diver/LLM-As-Chatbot/99c2c03efececba39a633589775f77989f93deff/chats/__init__.py -------------------------------------------------------------------------------- /chats/alpaca.py: -------------------------------------------------------------------------------- 1 | import copy 2 | import json 3 | import global_vars 4 | from chats import pre, post 5 | from pingpong import PingPong 6 | from gens.batch_gen import get_output_batch 7 | 8 | from chats.utils import build_prompts, text_stream, internet_search 9 | 10 | def chat_stream( 11 | idx, local_data, user_message, state, 12 | global_context, ctx_num_lconv, ctx_sum_prompt, 13 | res_temp, res_topp, res_topk, res_rpen, res_mnts, res_beams, res_cache, res_sample, res_eosid, res_padid, 14 | sum_temp, sum_topp, sum_topk, sum_rpen, sum_mnts, sum_beams, sum_cache, sum_sample, sum_eosid, sum_padid, 15 | internet_option, serper_api_key 16 | ): 17 | res = [ 18 | state["ppmanager_type"].from_json(json.dumps(ppm)) 19 | for ppm in local_data 20 | ] 21 | 22 | ppm = res[idx] 23 | 24 | # add_ping returns a prompt structured in Alpaca form 25 | ppm.add_pingpong( 26 | PingPong(user_message, "") 27 | ) 28 | prompt = build_prompts(ppm, global_context, ctx_num_lconv) 29 | 30 | ####### 31 | if internet_option: 32 | search_prompt = None 33 | for tmp_prompt, uis in internet_search(ppm, serper_api_key, global_context, ctx_num_lconv): 34 | search_prompt = tmp_prompt 35 | yield "", uis, prompt, str(res) 36 | 37 | # prepare text generating streamer & start generating 38 | gen_kwargs, streamer = pre.build( 39 | search_prompt if internet_option else prompt, 40 | res_temp, res_topp, res_topk, res_rpen, res_mnts, 41 | res_beams, res_cache, res_sample, res_eosid, res_padid, 42 | return_token_type_ids=False 43 | ) 44 | 
pre.start_gen(gen_kwargs) 45 | 46 | # handling stream 47 | for ppmanager, uis in text_stream(ppm, streamer): 48 | yield "", uis, prompt, str(res) 49 | 50 | ppm = post.strip_pong(ppm) 51 | yield "", ppm.build_uis(), prompt, str(res) -------------------------------------------------------------------------------- /chats/alpaca_gpt4.py: -------------------------------------------------------------------------------- 1 | import copy 2 | import json 3 | import global_vars 4 | from chats import pre, post 5 | from pingpong import PingPong 6 | from gens.batch_gen import get_output_batch 7 | 8 | from chats.utils import build_prompts, text_stream, internet_search 9 | 10 | def chat_stream( 11 | idx, local_data, user_message, state, 12 | global_context, ctx_num_lconv, ctx_sum_prompt, 13 | res_temp, res_topp, res_topk, res_rpen, res_mnts, res_beams, res_cache, res_sample, res_eosid, res_padid, 14 | sum_temp, sum_topp, sum_topk, sum_rpen, sum_mnts, sum_beams, sum_cache, sum_sample, sum_eosid, sum_padid, 15 | internet_option, serper_api_key 16 | ): 17 | res = [ 18 | state["ppmanager_type"].from_json(json.dumps(ppm)) 19 | for ppm in local_data 20 | ] 21 | 22 | ppm = res[idx] 23 | 24 | # add_ping returns a prompt structured in Alpaca form 25 | ppm.add_pingpong( 26 | PingPong(user_message, "") 27 | ) 28 | prompt = build_prompts(ppm, global_context, ctx_num_lconv) 29 | 30 | ####### 31 | if internet_option: 32 | search_prompt = None 33 | for tmp_prompt, uis in internet_search(ppm, serper_api_key, global_context, ctx_num_lconv): 34 | search_prompt = tmp_prompt 35 | yield "", uis, prompt, str(res) 36 | 37 | # prepare text generating streamer & start generating 38 | gen_kwargs, streamer = pre.build( 39 | search_prompt if internet_option else prompt, 40 | res_temp, res_topp, res_topk, res_rpen, res_mnts, 41 | res_beams, res_cache, res_sample, res_eosid, res_padid, 42 | return_token_type_ids=False 43 | ) 44 | pre.start_gen(gen_kwargs) 45 | 46 | # handling stream 47 | for ppmanager, uis in text_stream(ppm, streamer): 48 | yield "", uis, prompt, str(res) 49 | 50 | ppm = post.strip_pong(ppm) 51 | yield "", ppm.build_uis(), prompt, str(res) -------------------------------------------------------------------------------- /chats/alpacoom.py: -------------------------------------------------------------------------------- 1 | import copy 2 | import json 3 | import global_vars 4 | from chats import pre, post 5 | from pingpong import PingPong 6 | from gens.batch_gen import get_output_batch 7 | 8 | from chats.utils import build_prompts, text_stream, internet_search 9 | 10 | def chat_stream( 11 | idx, local_data, user_message, state, 12 | global_context, ctx_num_lconv, ctx_sum_prompt, 13 | res_temp, res_topp, res_topk, res_rpen, res_mnts, res_beams, res_cache, res_sample, res_eosid, res_padid, 14 | sum_temp, sum_topp, sum_topk, sum_rpen, sum_mnts, sum_beams, sum_cache, sum_sample, sum_eosid, sum_padid, 15 | internet_option, serper_api_key 16 | ): 17 | res = [ 18 | state["ppmanager_type"].from_json(json.dumps(ppm)) 19 | for ppm in local_data 20 | ] 21 | 22 | ppm = res[idx] 23 | 24 | # add_ping returns a prompt structured in Alpaca form 25 | ppm.add_pingpong( 26 | PingPong(user_message, "") 27 | ) 28 | prompt = build_prompts(ppm, global_context, ctx_num_lconv) 29 | 30 | ####### 31 | if internet_option: 32 | search_prompt = None 33 | for tmp_prompt, uis in internet_search(ppm, serper_api_key, global_context, ctx_num_lconv): 34 | search_prompt = tmp_prompt 35 | yield "", uis, prompt, str(res) 36 | 37 | # prepare text generating 
streamer & start generating 38 | gen_kwargs, streamer = pre.build( 39 | search_prompt if internet_option else prompt, 40 | res_temp, res_topp, res_topk, res_rpen, res_mnts, 41 | res_beams, res_cache, res_sample, res_eosid, res_padid, 42 | return_token_type_ids=False 43 | ) 44 | pre.start_gen(gen_kwargs) 45 | 46 | # handling stream 47 | for ppmanager, uis in text_stream(ppm, streamer): 48 | yield "", uis, prompt, str(res) 49 | 50 | ppm = post.strip_pong(ppm) 51 | yield "", ppm.build_uis(), prompt, str(res) -------------------------------------------------------------------------------- /chats/baize.py: -------------------------------------------------------------------------------- 1 | import copy 2 | import json 3 | import global_vars 4 | from chats import pre, post 5 | from pingpong import PingPong 6 | from gens.batch_gen import get_output_batch 7 | 8 | from chats.utils import build_prompts, internet_search 9 | 10 | def text_stream(ppmanager, streamer): 11 | count = 0 12 | 13 | for new_text in streamer: 14 | if "[|Human|]" in new_text or \ 15 | "[|AI|]" in new_text: 16 | break 17 | 18 | if count == 0: 19 | ppmanager.append_pong(f"![]({global_vars.model_thumbnail_tiny})***[{global_vars.model_type}]***\n") 20 | count = count + 1 21 | 22 | ppmanager.append_pong(new_text) 23 | yield ppmanager, ppmanager.build_uis() 24 | 25 | yield ppmanager, ppmanager.build_uis() 26 | 27 | def chat_stream( 28 | idx, local_data, user_message, state, 29 | global_context, ctx_num_lconv, ctx_sum_prompt, 30 | res_temp, res_topp, res_topk, res_rpen, res_mnts, res_beams, res_cache, res_sample, res_eosid, res_padid, 31 | sum_temp, sum_topp, sum_topk, sum_rpen, sum_mnts, sum_beams, sum_cache, sum_sample, sum_eosid, sum_padid, 32 | internet_option, serper_api_key 33 | ): 34 | res = [ 35 | state["ppmanager_type"].from_json(json.dumps(ppm)) 36 | for ppm in local_data 37 | ] 38 | 39 | ppm = res[idx] 40 | 41 | # add_ping returns a prompt structured in Alpaca form 42 | ppm.add_pingpong( 43 | PingPong(user_message, "") 44 | ) 45 | prompt = build_prompts(ppm, global_context, ctx_num_lconv) 46 | 47 | ####### 48 | if internet_option: 49 | search_prompt = None 50 | for tmp_prompt, uis in internet_search(ppm, serper_api_key, global_context, ctx_num_lconv): 51 | search_prompt = tmp_prompt 52 | yield "", uis, prompt, str(res) 53 | 54 | # prepare text generating streamer & start generating 55 | gen_kwargs, streamer = pre.build( 56 | search_prompt if internet_option else prompt, 57 | res_temp, res_topp, res_topk, res_rpen, res_mnts, 58 | res_beams, res_cache, res_sample, res_eosid, res_padid, 59 | return_token_type_ids=False 60 | ) 61 | pre.start_gen(gen_kwargs) 62 | 63 | # handling stream 64 | for ppmanager, uis in text_stream(ppm, streamer): 65 | yield "", uis, prompt, str(res) 66 | 67 | ppm = post.strip_pong(ppm) 68 | yield "", ppm.build_uis(), prompt, str(res) -------------------------------------------------------------------------------- /chats/custom.py: -------------------------------------------------------------------------------- 1 | import copy 2 | import json 3 | import global_vars 4 | from chats import pre, post 5 | from pingpong import PingPong 6 | from gens.batch_gen import get_output_batch 7 | 8 | from chats.utils import build_prompts, internet_search 9 | 10 | def text_stream(ppmanager, streamer): 11 | count = 0 12 | thumbnail_tiny = "https://i.ibb.co/f80BpgR/byom.png" 13 | 14 | for new_text in streamer: 15 | if count == 0: 16 | 
ppmanager.append_pong(f"![]({global_vars.model_thumbnail_tiny})***[{global_vars.model_type}]***\n") 17 | count = count + 1 18 | 19 | ppmanager.append_pong(new_text) 20 | yield ppmanager, ppmanager.build_uis() 21 | 22 | yield ppmanager, ppmanager.build_uis() 23 | 24 | def chat_stream( 25 | idx, local_data, user_message, state, 26 | global_context, ctx_num_lconv, ctx_sum_prompt, 27 | res_temp, res_topp, res_topk, res_rpen, res_mnts, res_beams, res_cache, res_sample, res_eosid, res_padid, 28 | sum_temp, sum_topp, sum_topk, sum_rpen, sum_mnts, sum_beams, sum_cache, sum_sample, sum_eosid, sum_padid, 29 | internet_option, serper_api_key 30 | ): 31 | res = [ 32 | state["ppmanager_type"].from_json(json.dumps(ppm)) 33 | for ppm in local_data 34 | ] 35 | 36 | ppm = res[idx] 37 | 38 | # add_ping returns a prompt structured in Alpaca form 39 | ppm.add_pingpong( 40 | PingPong(user_message, "") 41 | ) 42 | prompt = build_prompts(ppm, global_context, ctx_num_lconv) 43 | 44 | ####### 45 | if internet_option: 46 | search_prompt = None 47 | for tmp_prompt, uis in internet_search(ppm, serper_api_key, global_context, ctx_num_lconv): 48 | search_prompt = tmp_prompt 49 | yield "", uis, prompt, str(res) 50 | 51 | # prepare text generating streamer & start generating 52 | gen_kwargs, streamer = pre.build( 53 | search_prompt if internet_option else prompt, 54 | res_temp, res_topp, res_topk, res_rpen, res_mnts, 55 | res_beams, res_cache, res_sample, res_eosid, res_padid, 56 | return_token_type_ids=False 57 | ) 58 | pre.start_gen(gen_kwargs) 59 | 60 | # handling stream 61 | for ppmanager, uis in text_stream(ppm, streamer): 62 | yield "", uis, prompt, str(res) 63 | 64 | ppm = post.strip_pong(ppm) 65 | yield "", ppm.build_uis(), prompt, str(res) -------------------------------------------------------------------------------- /chats/falcon.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from transformers import StoppingCriteria, StoppingCriteriaList 3 | 4 | import copy 5 | import json 6 | import global_vars 7 | from chats import pre, post 8 | from pingpong import PingPong 9 | from gens.batch_gen import get_output_batch 10 | 11 | from chats.utils import build_prompts, text_stream, internet_search 12 | 13 | class StopOnTokens(StoppingCriteria): 14 | def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool: 15 | stop_ids = [11] 16 | for stop_id in stop_ids: 17 | if input_ids[0][-1] == stop_id: 18 | return True 19 | return False 20 | 21 | def chat_stream( 22 | idx, local_data, user_message, state, 23 | global_context, ctx_num_lconv, ctx_sum_prompt, 24 | res_temp, res_topp, res_topk, res_rpen, res_mnts, res_beams, res_cache, res_sample, res_eosid, res_padid, 25 | sum_temp, sum_topp, sum_topk, sum_rpen, sum_mnts, sum_beams, sum_cache, sum_sample, sum_eosid, sum_padid, 26 | internet_option, serper_api_key 27 | ): 28 | res = [ 29 | state["ppmanager_type"].from_json(json.dumps(ppm)) 30 | for ppm in local_data 31 | ] 32 | 33 | ppm = res[idx] 34 | 35 | # add_ping returns a prompt structured in Alpaca form 36 | ppm.add_pingpong( 37 | PingPong(user_message, "") 38 | ) 39 | prompt = build_prompts(ppm, global_context, ctx_num_lconv) 40 | 41 | ####### 42 | if internet_option: 43 | search_prompt = None 44 | for tmp_prompt, uis in internet_search(ppm, serper_api_key, global_context, ctx_num_lconv): 45 | search_prompt = tmp_prompt 46 | yield "", uis, prompt, str(res) 47 | 48 | # prepare text generating streamer & start generating 49 | 
gen_kwargs, streamer = pre.build( 50 | search_prompt if internet_option else prompt, 51 | res_temp, res_topp, res_topk, res_rpen, res_mnts, 52 | res_beams, res_cache, res_sample, res_eosid, res_padid, 53 | return_token_type_ids=False 54 | ) 55 | pre.start_gen(gen_kwargs) 56 | 57 | # handling stream 58 | for ppmanager, uis in text_stream(ppm, streamer): 59 | yield "", uis, prompt, str(res) 60 | 61 | ppm = post.strip_pong(ppm) 62 | yield "", ppm.build_uis(), prompt, str(res) -------------------------------------------------------------------------------- /chats/flan_alpaca.py: -------------------------------------------------------------------------------- 1 | import copy 2 | import json 3 | import global_vars 4 | from chats import pre, post 5 | from pingpong import PingPong 6 | from gens.batch_gen import get_output_batch 7 | 8 | from chats.utils import build_prompts, text_stream, internet_search 9 | 10 | def chat_stream( 11 | idx, local_data, user_message, state, 12 | global_context, ctx_num_lconv, ctx_sum_prompt, 13 | res_temp, res_topp, res_topk, res_rpen, res_mnts, res_beams, res_cache, res_sample, res_eosid, res_padid, 14 | sum_temp, sum_topp, sum_topk, sum_rpen, sum_mnts, sum_beams, sum_cache, sum_sample, sum_eosid, sum_padid, 15 | internet_option, serper_api_key 16 | ): 17 | res = [ 18 | state["ppmanager_type"].from_json(json.dumps(ppm)) 19 | for ppm in local_data 20 | ] 21 | 22 | ppm = res[idx] 23 | 24 | # add_ping returns a prompt structured in Alpaca form 25 | ppm.add_pingpong( 26 | PingPong(user_message, "") 27 | ) 28 | prompt = build_prompts(ppm, global_context, ctx_num_lconv) 29 | 30 | ####### 31 | if internet_option: 32 | search_prompt = None 33 | for tmp_prompt, uis in internet_search(ppm, serper_api_key, global_context, ctx_num_lconv): 34 | search_prompt = tmp_prompt 35 | yield "", uis, prompt, str(res) 36 | 37 | # prepare text generating streamer & start generating 38 | gen_kwargs, streamer = pre.build( 39 | search_prompt if internet_option else prompt, 40 | res_temp, res_topp, res_topk, res_rpen, res_mnts, 41 | res_beams, res_cache, res_sample, res_eosid, res_padid, 42 | return_token_type_ids=False 43 | ) 44 | pre.start_gen(gen_kwargs) 45 | 46 | # handling stream 47 | for ppmanager, uis in text_stream(ppm, streamer): 48 | yield "", uis, prompt, str(res) 49 | 50 | ppm = post.strip_pong(ppm) 51 | yield "", ppm.build_uis(), prompt, str(res) -------------------------------------------------------------------------------- /chats/freewilly.py: -------------------------------------------------------------------------------- 1 | import copy 2 | import json 3 | import global_vars 4 | from chats import pre, post 5 | from pingpong import PingPong 6 | from gens.batch_gen import get_output_batch 7 | 8 | from chats.utils import build_prompts, text_stream, internet_search 9 | 10 | def chat_stream( 11 | idx, local_data, user_message, state, 12 | global_context, ctx_num_lconv, ctx_sum_prompt, 13 | res_temp, res_topp, res_topk, res_rpen, res_mnts, res_beams, res_cache, res_sample, res_eosid, res_padid, 14 | sum_temp, sum_topp, sum_topk, sum_rpen, sum_mnts, sum_beams, sum_cache, sum_sample, sum_eosid, sum_padid, 15 | internet_option, serper_api_key 16 | ): 17 | res = [ 18 | state["ppmanager_type"].from_json(json.dumps(ppm)) 19 | for ppm in local_data 20 | ] 21 | 22 | ppm = res[idx] 23 | 24 | # add_ping returns a prompt structured in Alpaca form 25 | ppm.add_pingpong( 26 | PingPong(user_message, "") 27 | ) 28 | prompt = build_prompts(ppm, global_context, ctx_num_lconv) 29 | 30 | ####### 31 | 
if internet_option: 32 | search_prompt = None 33 | for tmp_prompt, uis in internet_search(ppm, serper_api_key, global_context, ctx_num_lconv): 34 | search_prompt = tmp_prompt 35 | yield "", uis, prompt, str(res) 36 | 37 | # prepare text generating streamer & start generating 38 | gen_kwargs, streamer = pre.build( 39 | search_prompt if internet_option else prompt, 40 | res_temp, res_topp, res_topk, res_rpen, res_mnts, 41 | res_beams, res_cache, res_sample, res_eosid, res_padid, 42 | return_token_type_ids=False 43 | ) 44 | pre.start_gen(gen_kwargs) 45 | 46 | # handling stream 47 | for ppmanager, uis in text_stream(ppm, streamer): 48 | yield "", uis, prompt, str(res) 49 | 50 | # output = f"![]({global_vars.model_thumbnail_tiny})***[{global_vars.model_type}]***\n" 51 | 52 | # inputs = global_vars.tokenizer( 53 | # prompt, return_tensors="pt" 54 | # ).to(global_vars.device) 55 | 56 | # output = output + global_vars.model.generate( 57 | # **inputs, 58 | # temperature=res_temp, 59 | # do_sample=res_sample, 60 | # top_p=res_topp, 61 | # top_k=res_topk, 62 | # repetition_penalty=res_rpen, 63 | # num_beams=res_beams, 64 | # use_cache=res_cache, 65 | # eos_token_id=res_eosid, 66 | # pad_token_id=res_padid, 67 | # max_new_tokens=res_mnts 68 | # ) 69 | 70 | # ppm.replace_last_pong(output) 71 | # yield "", ppm.build_uis(), prompt, str(res) 72 | 73 | ppm = post.strip_pong(ppm) 74 | yield "", ppm.build_uis(), prompt, str(res) 75 | -------------------------------------------------------------------------------- /chats/guanaco.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from transformers import StoppingCriteria, StoppingCriteriaList 3 | 4 | import copy 5 | import json 6 | import global_vars 7 | from chats import pre, post 8 | from pingpong import PingPong 9 | from gens.batch_gen import get_output_batch 10 | 11 | from chats.utils import build_prompts, text_stream, internet_search 12 | 13 | class StopOnTokens(StoppingCriteria): 14 | def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool: 15 | stop_token_ids = [0] 16 | 17 | for stop_id in stop_token_ids: 18 | if input_ids[0][-1] == stop_id: 19 | return True 20 | return False 21 | 22 | def chat_stream( 23 | idx, local_data, user_message, state, 24 | global_context, ctx_num_lconv, ctx_sum_prompt, 25 | res_temp, res_topp, res_topk, res_rpen, res_mnts, res_beams, res_cache, res_sample, res_eosid, res_padid, 26 | sum_temp, sum_topp, sum_topk, sum_rpen, sum_mnts, sum_beams, sum_cache, sum_sample, sum_eosid, sum_padid, 27 | internet_option, serper_api_key 28 | ): 29 | res = [ 30 | state["ppmanager_type"].from_json(json.dumps(ppm)) 31 | for ppm in local_data 32 | ] 33 | 34 | ppm = res[idx] 35 | 36 | # add_ping returns a prompt structured in Alpaca form 37 | ppm.add_pingpong( 38 | PingPong(user_message, "") 39 | ) 40 | prompt = build_prompts(ppm, global_context, ctx_num_lconv) 41 | 42 | ####### 43 | if internet_option: 44 | search_prompt = None 45 | for tmp_prompt, uis in internet_search(ppm, serper_api_key, global_context, ctx_num_lconv): 46 | search_prompt = tmp_prompt 47 | yield "", uis, prompt, str(res) 48 | 49 | # prepare text generating streamer & start generating 50 | gen_kwargs, streamer = pre.build( 51 | search_prompt if internet_option else prompt, 52 | res_temp, res_topp, res_topk, res_rpen, res_mnts, 53 | res_beams, res_cache, res_sample, res_eosid, res_padid, 54 | return_token_type_ids=False 55 | ) 56 | pre.start_gen(gen_kwargs) 57 | 58 | # handling stream 
59 | for ppmanager, uis in text_stream(ppm, streamer): 60 | yield "", uis, prompt, str(res) 61 | 62 | ppm = post.strip_pong(ppm) 63 | yield "", ppm.build_uis(), prompt, str(res) -------------------------------------------------------------------------------- /chats/koalpaca.py: -------------------------------------------------------------------------------- 1 | import copy 2 | import json 3 | import global_vars 4 | from chats import pre, post 5 | from pingpong import PingPong 6 | from gens.batch_gen import get_output_batch 7 | 8 | from chats.utils import build_prompts, text_stream, internet_search 9 | 10 | def chat_stream( 11 | idx, local_data, user_message, state, 12 | global_context, ctx_num_lconv, ctx_sum_prompt, 13 | res_temp, res_topp, res_topk, res_rpen, res_mnts, res_beams, res_cache, res_sample, res_eosid, res_padid, 14 | sum_temp, sum_topp, sum_topk, sum_rpen, sum_mnts, sum_beams, sum_cache, sum_sample, sum_eosid, sum_padid, 15 | internet_option, serper_api_key 16 | ): 17 | res = [ 18 | state["ppmanager_type"].from_json(json.dumps(ppm)) 19 | for ppm in local_data 20 | ] 21 | 22 | ppm = res[idx] 23 | 24 | # add_ping returns a prompt structured in Alpaca form 25 | ppm.add_pingpong( 26 | PingPong(user_message, "") 27 | ) 28 | prompt = build_prompts(ppm, global_context, ctx_num_lconv) 29 | 30 | ####### 31 | if internet_option: 32 | search_prompt = None 33 | for tmp_prompt, uis in internet_search(ppm, serper_api_key, global_context, ctx_num_lconv): 34 | search_prompt = tmp_prompt 35 | yield "", uis, prompt, str(res) 36 | 37 | # prepare text generating streamer & start generating 38 | gen_kwargs, streamer = pre.build( 39 | search_prompt if internet_option else prompt, 40 | res_temp, res_topp, res_topk, res_rpen, res_mnts, 41 | res_beams, res_cache, res_sample, res_eosid, res_padid, 42 | return_token_type_ids=False 43 | ) 44 | pre.start_gen(gen_kwargs) 45 | 46 | # handling stream 47 | for ppmanager, uis in text_stream(ppm, streamer): 48 | yield "", uis, prompt, str(res) 49 | 50 | ppm = post.strip_pong(ppm) 51 | yield "", ppm.build_uis(), prompt, str(res) -------------------------------------------------------------------------------- /chats/llama2.py: -------------------------------------------------------------------------------- 1 | import copy 2 | import json 3 | import global_vars 4 | from chats import pre, post 5 | from pingpong import PingPong 6 | from gens.batch_gen import get_output_batch 7 | 8 | from chats.utils import build_prompts, text_stream, internet_search 9 | 10 | def chat_stream( 11 | idx, local_data, user_message, state, 12 | global_context, ctx_num_lconv, ctx_sum_prompt, 13 | res_temp, res_topp, res_topk, res_rpen, res_mnts, res_beams, res_cache, res_sample, res_eosid, res_padid, 14 | sum_temp, sum_topp, sum_topk, sum_rpen, sum_mnts, sum_beams, sum_cache, sum_sample, sum_eosid, sum_padid, 15 | internet_option, serper_api_key 16 | ): 17 | res = [ 18 | state["ppmanager_type"].from_json(json.dumps(ppm)) 19 | for ppm in local_data 20 | ] 21 | 22 | ppm = res[idx] 23 | 24 | # add_ping returns a prompt structured in Alpaca form 25 | ppm.add_pingpong( 26 | PingPong(user_message, "") 27 | ) 28 | prompt = build_prompts(ppm, global_context, ctx_num_lconv) 29 | 30 | ####### 31 | if internet_option: 32 | search_prompt = None 33 | for tmp_prompt, uis in internet_search(ppm, serper_api_key, global_context, ctx_num_lconv): 34 | search_prompt = tmp_prompt 35 | yield "", uis, prompt, str(res) 36 | 37 | # prepare text generating streamer & start generating 38 | gen_kwargs, streamer = 
pre.build( 39 | search_prompt if internet_option else prompt, 40 | res_temp, res_topp, res_topk, res_rpen, res_mnts, 41 | res_beams, res_cache, res_sample, res_eosid, res_padid, 42 | return_token_type_ids=False 43 | ) 44 | pre.start_gen(gen_kwargs) 45 | 46 | # handling stream 47 | for ppmanager, uis in text_stream(ppm, streamer): 48 | yield "", uis, prompt, str(res) 49 | 50 | ppm = post.strip_pong(ppm) 51 | yield "", ppm.build_uis(), prompt, str(res) -------------------------------------------------------------------------------- /chats/mistral.py: -------------------------------------------------------------------------------- 1 | import copy 2 | import json 3 | import global_vars 4 | from chats import pre, post 5 | from pingpong import PingPong 6 | from gens.batch_gen import get_output_batch 7 | 8 | from chats.utils import build_prompts, text_stream, internet_search 9 | 10 | def chat_stream( 11 | idx, local_data, user_message, state, 12 | global_context, ctx_num_lconv, ctx_sum_prompt, 13 | res_temp, res_topp, res_topk, res_rpen, res_mnts, res_beams, res_cache, res_sample, res_eosid, res_padid, 14 | sum_temp, sum_topp, sum_topk, sum_rpen, sum_mnts, sum_beams, sum_cache, sum_sample, sum_eosid, sum_padid, 15 | internet_option, serper_api_key 16 | ): 17 | res = [ 18 | state["ppmanager_type"].from_json(json.dumps(ppm)) 19 | for ppm in local_data 20 | ] 21 | 22 | ppm = res[idx] 23 | 24 | # add_ping returns a prompt structured in Alpaca form 25 | ppm.add_pingpong( 26 | PingPong(user_message, "") 27 | ) 28 | prompt = build_prompts(ppm, global_context, ctx_num_lconv) 29 | 30 | ####### 31 | if internet_option: 32 | search_prompt = None 33 | for tmp_prompt, uis in internet_search(ppm, serper_api_key, global_context, ctx_num_lconv): 34 | search_prompt = tmp_prompt 35 | yield "", uis, prompt, str(res) 36 | 37 | # prepare text generating streamer & start generating 38 | gen_kwargs, streamer = pre.build( 39 | search_prompt if internet_option else prompt, 40 | res_temp, res_topp, res_topk, res_rpen, res_mnts, 41 | res_beams, res_cache, res_sample, res_eosid, res_padid, 42 | return_token_type_ids=False 43 | ) 44 | pre.start_gen(gen_kwargs) 45 | 46 | # handling stream 47 | for ppmanager, uis in text_stream(ppm, streamer): 48 | yield "", uis, prompt, str(res) 49 | 50 | ppm = post.strip_pong(ppm) 51 | yield "", ppm.build_uis(), prompt, str(res) -------------------------------------------------------------------------------- /chats/mpt.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from transformers import StoppingCriteria, StoppingCriteriaList 3 | 4 | import copy 5 | import json 6 | import global_vars 7 | from chats import pre, post 8 | from pingpong import PingPong 9 | from gens.batch_gen import get_output_batch 10 | 11 | from chats.utils import build_prompts, text_stream, internet_search 12 | 13 | class StopOnTokens(StoppingCriteria): 14 | def __init__(self, tokenizer): 15 | super().__init__() 16 | 17 | self.stop_token_ids = tokenizer.convert_tokens_to_ids( 18 | ["<|im_end|>", "<|endoftext|>"] 19 | ) 20 | 21 | def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool: 22 | for stop_id in self.stop_token_ids: 23 | if input_ids[0][-1] == stop_id: 24 | return True 25 | return False 26 | 27 | def chat_stream( 28 | idx, local_data, user_message, state, 29 | global_context, ctx_num_lconv, ctx_sum_prompt, 30 | res_temp, res_topp, res_topk, res_rpen, res_mnts, res_beams, res_cache, res_sample, res_eosid, res_padid, 31 | 
sum_temp, sum_topp, sum_topk, sum_rpen, sum_mnts, sum_beams, sum_cache, sum_sample, sum_eosid, sum_padid, 32 | internet_option, serper_api_key 33 | ): 34 | res = [ 35 | state["ppmanager_type"].from_json(json.dumps(ppm)) 36 | for ppm in local_data 37 | ] 38 | 39 | ppm = res[idx] 40 | 41 | # add_ping returns a prompt structured in Alpaca form 42 | ppm.add_pingpong( 43 | PingPong(user_message, "") 44 | ) 45 | prompt = build_prompts(ppm, global_context, ctx_num_lconv) 46 | 47 | ####### 48 | if internet_option: 49 | search_prompt = None 50 | for tmp_prompt, uis in internet_search(ppm, serper_api_key, global_context, ctx_num_lconv): 51 | search_prompt = tmp_prompt 52 | yield "", uis, prompt, str(res) 53 | 54 | # prepare text generating streamer & start generating 55 | gen_kwargs, streamer = pre.build( 56 | search_prompt if internet_option else prompt, 57 | res_temp, res_topp, res_topk, res_rpen, res_mnts, 58 | res_beams, res_cache, res_sample, res_eosid, res_padid, 59 | return_token_type_ids=False 60 | ) 61 | pre.start_gen(gen_kwargs) 62 | 63 | # handling stream 64 | for ppmanager, uis in text_stream(ppm, streamer): 65 | yield "", uis, prompt, str(res) 66 | 67 | ppm = post.strip_pong(ppm) 68 | yield "", ppm.build_uis(), prompt, str(res) -------------------------------------------------------------------------------- /chats/os_stablelm.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from transformers import StoppingCriteria, StoppingCriteriaList 3 | 4 | import copy 5 | import json 6 | import global_vars 7 | from chats import pre, post 8 | from pingpong import PingPong 9 | from gens.batch_gen import get_output_batch 10 | 11 | from chats.utils import build_prompts, text_stream, internet_search 12 | 13 | class StopOnTokens(StoppingCriteria): 14 | def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool: 15 | stop_ids = [50278, 50279, 50277, 1, 0] 16 | for stop_id in stop_ids: 17 | if input_ids[0][-1] == stop_id: 18 | return True 19 | return False 20 | 21 | def chat_stream( 22 | idx, local_data, user_message, state, 23 | global_context, ctx_num_lconv, ctx_sum_prompt, 24 | res_temp, res_topp, res_topk, res_rpen, res_mnts, res_beams, res_cache, res_sample, res_eosid, res_padid, 25 | sum_temp, sum_topp, sum_topk, sum_rpen, sum_mnts, sum_beams, sum_cache, sum_sample, sum_eosid, sum_padid, 26 | internet_option, serper_api_key 27 | ): 28 | res = [ 29 | state["ppmanager_type"].from_json(json.dumps(ppm)) 30 | for ppm in local_data 31 | ] 32 | 33 | ppm = res[idx] 34 | 35 | # add_ping returns a prompt structured in Alpaca form 36 | ppm.add_pingpong( 37 | PingPong(user_message, "") 38 | ) 39 | prompt = build_prompts(ppm, global_context, ctx_num_lconv) 40 | 41 | ####### 42 | if internet_option: 43 | search_prompt = None 44 | for tmp_prompt, uis in internet_search(ppm, serper_api_key, global_context, ctx_num_lconv): 45 | search_prompt = tmp_prompt 46 | yield "", uis, prompt, str(res) 47 | 48 | # prepare text generating streamer & start generating 49 | gen_kwargs, streamer = pre.build( 50 | search_prompt if internet_option else prompt, 51 | res_temp, res_topp, res_topk, res_rpen, res_mnts, 52 | res_beams, res_cache, res_sample, res_eosid, res_padid, 53 | return_token_type_ids=False 54 | ) 55 | pre.start_gen(gen_kwargs) 56 | 57 | # handling stream 58 | for ppmanager, uis in text_stream(ppm, streamer): 59 | yield "", uis, prompt, str(res) 60 | 61 | ppm = post.strip_pong(ppm) 62 | yield "", ppm.build_uis(), prompt, str(res) 
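Shared pattern across the chat modules above (alpaca.py, baize.py, custom.py, falcon.py, os_stablelm.py, and the rest): each chat_stream appends an empty PingPong for the incoming user message, builds a prompt from the global context plus the last few turns, hands that prompt to pre.build to obtain generation kwargs and a TextIteratorStreamer, launches generation on a background thread with pre.start_gen, drains the streamer while yielding progressively longer UI states, and finally trims the response with post.strip_pong. The sketch below condenses that shared flow for reference only; it is not a file in this repository, it assumes the tokenizer and model have already been loaded into global_vars (as app.py does at startup), and the generation settings shown are placeholder values.

# Illustrative sketch of the chat_stream flow shared by the chats/*.py modules.
# Not part of the repository; assumes global_vars already holds a loaded
# tokenizer and stream_model, and uses placeholder generation settings.
from pingpong import PingPong
from chats import pre, post
from chats.utils import build_prompts, text_stream

def stream_one_turn(ppm, user_message, global_context, ctx_num_lconv=3):
    # Record the new user turn with an empty response slot.
    ppm.add_pingpong(PingPong(user_message, ""))

    # Build the prompt from the global context plus the last N turns.
    prompt = build_prompts(ppm, global_context, ctx_num_lconv)

    # Prepare generation kwargs and a TextIteratorStreamer, then start
    # model.generate() on a background thread.
    gen_kwargs, streamer = pre.build(
        prompt,
        temperature=0.7, top_p=0.9, top_k=50, repetition_penalty=1.2,
        max_new_tokens=256, num_beams=1, use_cache=True, do_sample=True,
        eos_token_id=None, pad_token_id=None,
        return_token_type_ids=False,
    )
    pre.start_gen(gen_kwargs)

    # Drain the streamer, yielding partially built UI state as tokens arrive.
    for ppmanager, uis in text_stream(ppm, streamer):
        yield uis

    # Strip trailing whitespace from the finished response.
    ppm = post.strip_pong(ppm)
    yield ppm.build_uis()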
-------------------------------------------------------------------------------- /chats/post.py: -------------------------------------------------------------------------------- 1 | def strip_pong(ppmanager): 2 | ppmanager.pingpongs[-1].pong = ppmanager.pingpongs[-1].pong.strip() 3 | return ppmanager -------------------------------------------------------------------------------- /chats/pre.py: -------------------------------------------------------------------------------- 1 | import re 2 | import copy 3 | import global_vars 4 | from threading import Thread 5 | from transformers import TextIteratorStreamer 6 | from transformers import GenerationConfig 7 | 8 | def contains_image_markdown(string): 9 | regex = re.compile(r'!\[(.*?)\]\((.*?)\)') 10 | match = regex.search(string) 11 | return match 12 | 13 | def build_model_inputs(prompt, return_token_type_ids): 14 | model_inputs = global_vars.tokenizer( 15 | [prompt], 16 | return_tensors="pt", 17 | return_token_type_ids=return_token_type_ids 18 | ).to(global_vars.device) 19 | return model_inputs 20 | 21 | def build_streamer( 22 | timeout=20., 23 | skip_prompt=True, 24 | skip_special_tokens=True 25 | ): 26 | if global_vars.device == "cpu" or \ 27 | global_vars.device == "mps": 28 | timeout=100000. 29 | print(f"timeout set to {timeout}") 30 | 31 | streamer = TextIteratorStreamer( 32 | global_vars.tokenizer, 33 | timeout=timeout, 34 | skip_prompt=skip_prompt, 35 | skip_special_tokens=skip_special_tokens 36 | ) 37 | return streamer 38 | 39 | 40 | def build_gen_config( 41 | temperature, top_p, top_k, repetition_penalty, max_new_tokens, 42 | num_beams, use_cache, do_sample, eos_token_id, pad_token_id 43 | ): 44 | gen_config_raw = { 45 | "temperature": temperature, 46 | "top_p": top_p, 47 | "top_k": top_k, 48 | "repetition_penalty": repetition_penalty, 49 | "max_new_tokens": max_new_tokens, 50 | "num_beams": num_beams, 51 | "use_cache": use_cache, 52 | "do_sample": do_sample, 53 | "eos_token_id": eos_token_id, 54 | "pad_token_id": pad_token_id 55 | } 56 | 57 | return gen_config_raw, GenerationConfig(**gen_config_raw) 58 | 59 | def build_gen_kwargs( 60 | gen_config, 61 | model_inputs, 62 | streamer, 63 | stopping_criteria 64 | ): 65 | gen_kwargs = dict( 66 | model_inputs, 67 | streamer=streamer, 68 | stopping_criteria=stopping_criteria 69 | ) 70 | gen_kwargs.update(gen_config) 71 | return gen_kwargs 72 | 73 | def start_gen(gen_kwargs): 74 | t = Thread( 75 | target=global_vars.stream_model.generate, 76 | kwargs=gen_kwargs 77 | ) 78 | t.start() 79 | 80 | def build( 81 | prompt, 82 | temperature, top_p, top_k, repetition_penalty, max_new_tokens, 83 | num_beams, use_cache, do_sample, eos_token_id, pad_token_id, 84 | stopping_criteria=None, return_token_type_ids=True 85 | ): 86 | gen_config_raw, _ = build_gen_config( 87 | temperature, top_p, top_k, repetition_penalty, max_new_tokens, 88 | num_beams, use_cache, do_sample, eos_token_id, pad_token_id 89 | ) 90 | 91 | model_inputs = build_model_inputs( 92 | prompt, return_token_type_ids=return_token_type_ids 93 | ) 94 | streamer = build_streamer() 95 | gen_kwargs = build_gen_kwargs( 96 | gen_config_raw, 97 | model_inputs, 98 | streamer, 99 | stopping_criteria 100 | ) 101 | return gen_kwargs, streamer -------------------------------------------------------------------------------- /chats/redpajama.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from transformers import StoppingCriteria, StoppingCriteriaList 3 | 4 | import copy 5 | import json 6 | import 
global_vars 7 | from chats import pre, post 8 | from pingpong import PingPong 9 | from gens.batch_gen import get_output_batch 10 | 11 | from chats.utils import build_prompts, text_stream, internet_search 12 | 13 | class StopOnTokens(StoppingCriteria): 14 | # ref: https://github.com/togethercomputer/OpenChatKit/blob/7a931c7d7cf3602c93e00db6e27bdc09d3b5f70f/inference/bot.py 15 | def __init__(self, tokenizer, stop_words, stream_callback): 16 | super().__init__() 17 | self._tokenizer = tokenizer 18 | self._stop_words = stop_words 19 | self._partial_result = '' 20 | self._stream_buffer = '' 21 | self._stream_callback = stream_callback 22 | 23 | def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool: 24 | first = not self._partial_result 25 | text = self._tokenizer.decode(input_ids[0, -1]) 26 | self._partial_result += text 27 | for stop_word in self._stop_words: 28 | if stop_word in self._partial_result: 29 | return True 30 | return False 31 | 32 | def chat_stream( 33 | idx, local_data, user_message, state, 34 | global_context, ctx_num_lconv, ctx_sum_prompt, 35 | res_temp, res_topp, res_topk, res_rpen, res_mnts, res_beams, res_cache, res_sample, res_eosid, res_padid, 36 | sum_temp, sum_topp, sum_topk, sum_rpen, sum_mnts, sum_beams, sum_cache, sum_sample, sum_eosid, sum_padid, 37 | internet_option, serper_api_key 38 | ): 39 | res = [ 40 | state["ppmanager_type"].from_json(json.dumps(ppm)) 41 | for ppm in local_data 42 | ] 43 | 44 | ppm = res[idx] 45 | 46 | # add_ping returns a prompt structured in Alpaca form 47 | ppm.add_pingpong( 48 | PingPong(user_message, "") 49 | ) 50 | prompt = build_prompts(ppm, global_context, ctx_num_lconv) 51 | 52 | ####### 53 | if internet_option: 54 | search_prompt = None 55 | for tmp_prompt, uis in internet_search(ppm, serper_api_key, global_context, ctx_num_lconv): 56 | search_prompt = tmp_prompt 57 | yield "", uis, prompt, str(res) 58 | 59 | # prepare text generating streamer & start generating 60 | gen_kwargs, streamer = pre.build( 61 | search_prompt if internet_option else prompt, 62 | res_temp, res_topp, res_topk, res_rpen, res_mnts, 63 | res_beams, res_cache, res_sample, res_eosid, res_padid, 64 | return_token_type_ids=False 65 | ) 66 | pre.start_gen(gen_kwargs) 67 | 68 | # handling stream 69 | for ppmanager, uis in text_stream(ppm, streamer): 70 | yield "", uis, prompt, str(res) 71 | 72 | if ppm.pingpongs[-1].pong.endswith(":"): 73 | ppm.pingpongs[-1].pong = ppm.pingpongs[-1].pong[:-1] 74 | 75 | ppm = post.strip_pong(ppm) 76 | yield "", ppm.build_uis(), prompt, str(res) -------------------------------------------------------------------------------- /chats/remote_tgi.py: -------------------------------------------------------------------------------- 1 | import json 2 | import requests 3 | import sseclient 4 | 5 | async def gen_text( 6 | prompt, 7 | remote_addr, 8 | remote_port=None, 9 | remote_token=None, 10 | parameters=None 11 | ): 12 | if remote_port and remote_port != "": 13 | remote_addr = f"{remote_addr}:{remote_port}" 14 | 15 | headers={ 16 | 'Content-type': 'application/json' 17 | } 18 | if remote_token is not None and remote_token != "": 19 | headers["Authorization"] = f'Bearer {remote_token}' 20 | 21 | data = { 22 | 'inputs': prompt, 23 | 'stream': True, 24 | 'options': { 25 | 'use_cache': False, 26 | }, 27 | 'parameters': parameters 28 | } 29 | 30 | r = requests.post( 31 | remote_addr, 32 | headers=headers, 33 | data=json.dumps(data), 34 | stream=True 35 | ) 36 | 37 | client = sseclient.SSEClient(r) 38 | for 
event in client.events(): 39 | yield json.loads(event.data)['token']['text'] 40 | 41 | async def chat_stream( 42 | idx, local_data, user_message, state, 43 | global_context, ctx_num_lconv, ctx_sum_prompt, 44 | res_temp, res_topp, res_topk, res_rpen, res_mnts, res_beams, res_cache, res_sample, res_eosid, res_padid, 45 | sum_temp, sum_topp, sum_topk, sum_rpen, sum_mnts, sum_beams, sum_cache, sum_sample, sum_eosid, sum_padid, 46 | internet_option, serper_api_key 47 | ): 48 | res = [ 49 | state["ppmanager_type"].from_json(json.dumps(ppm)) 50 | for ppm in local_data 51 | ] 52 | 53 | ppm = res[idx] 54 | 55 | # add_ping returns a prompt structured in Alpaca form 56 | ppm.add_pingpong( 57 | PingPong(user_message, "") 58 | ) 59 | prompt = build_prompts(ppm, global_context, ctx_num_lconv) 60 | 61 | ####### 62 | if internet_option: 63 | search_prompt = None 64 | for tmp_prompt, uis in internet_search(ppm, serper_api_key, global_context, ctx_num_lconv): 65 | search_prompt = tmp_prompt 66 | yield "", uis, prompt, str(res) 67 | 68 | async for result in gen_text( 69 | prompt, 70 | remote_addr=global_vars.remote_addr, 71 | remote_port=global_vars.remote_port, 72 | remote_token=global_vars.remote_token, 73 | parameters={ 74 | 'max_new_tokens': res_mnts, 75 | 'do_sample': res_sample, 76 | 'return_full_text': False, 77 | 'temperature': res_temp, 78 | 'top_k': res_topk, 79 | # 'top_p": res_topp 80 | 'repetition_penalty': res_rpen 81 | } 82 | ): 83 | ppm.append_pong(result) 84 | yield "", ppm.build_uis(), prompt, str(res) 85 | 86 | ppm = post.strip_pong(ppm) 87 | yield "", ppm.build_uis(), prompt, str(res) -------------------------------------------------------------------------------- /chats/stable_vicuna.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from transformers import StoppingCriteria, StoppingCriteriaList 3 | 4 | import copy 5 | import json 6 | import global_vars 7 | from chats import pre, post 8 | from pingpong import PingPong 9 | from gens.batch_gen import get_output_batch 10 | 11 | from chats.utils import build_prompts, text_stream, internet_search 12 | 13 | class StopOnTokens(StoppingCriteria): 14 | def __init__(self, tokenizer): 15 | super().__init__() 16 | 17 | self.stop_token_ids = tokenizer.convert_tokens_to_ids( 18 | [""] 19 | ) 20 | 21 | def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool: 22 | for stop_id in self.stop_token_ids: 23 | if input_ids[0][-1] == stop_id: 24 | return True 25 | return False 26 | 27 | def chat_stream( 28 | idx, local_data, user_message, state, 29 | global_context, ctx_num_lconv, ctx_sum_prompt, 30 | res_temp, res_topp, res_topk, res_rpen, res_mnts, res_beams, res_cache, res_sample, res_eosid, res_padid, 31 | sum_temp, sum_topp, sum_topk, sum_rpen, sum_mnts, sum_beams, sum_cache, sum_sample, sum_eosid, sum_padid, 32 | internet_option, serper_api_key 33 | ): 34 | res = [ 35 | state["ppmanager_type"].from_json(json.dumps(ppm)) 36 | for ppm in local_data 37 | ] 38 | 39 | ppm = res[idx] 40 | 41 | # add_ping returns a prompt structured in Alpaca form 42 | ppm.add_pingpong( 43 | PingPong(user_message, "") 44 | ) 45 | prompt = build_prompts(ppm, global_context, ctx_num_lconv) 46 | 47 | ####### 48 | if internet_option: 49 | search_prompt = None 50 | for tmp_prompt, uis in internet_search(ppm, serper_api_key, global_context, ctx_num_lconv): 51 | search_prompt = tmp_prompt 52 | yield "", uis, prompt, str(res) 53 | 54 | # prepare text generating streamer & start generating 55 | 
gen_kwargs, streamer = pre.build( 56 | search_prompt if internet_option else prompt, 57 | res_temp, res_topp, res_topk, res_rpen, res_mnts, 58 | res_beams, res_cache, res_sample, res_eosid, res_padid, 59 | return_token_type_ids=False 60 | ) 61 | pre.start_gen(gen_kwargs) 62 | 63 | # handling stream 64 | for ppmanager, uis in text_stream(ppm, streamer): 65 | yield "", uis, prompt, str(res) 66 | 67 | ppm = post.strip_pong(ppm) 68 | yield "", ppm.build_uis(), prompt, str(res) -------------------------------------------------------------------------------- /chats/stablelm.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from transformers import StoppingCriteria, StoppingCriteriaList 3 | 4 | import copy 5 | import json 6 | import global_vars 7 | from chats import pre, post 8 | from pingpong import PingPong 9 | from gens.batch_gen import get_output_batch 10 | 11 | from chats.utils import build_prompts, text_stream, internet_search 12 | 13 | class StopOnTokens(StoppingCriteria): 14 | def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool: 15 | stop_ids = [50278, 50279, 50277, 1, 0] 16 | for stop_id in stop_ids: 17 | if input_ids[0][-1] == stop_id: 18 | return True 19 | return False 20 | 21 | def chat_stream( 22 | idx, local_data, user_message, state, 23 | global_context, ctx_num_lconv, ctx_sum_prompt, 24 | res_temp, res_topp, res_topk, res_rpen, res_mnts, res_beams, res_cache, res_sample, res_eosid, res_padid, 25 | sum_temp, sum_topp, sum_topk, sum_rpen, sum_mnts, sum_beams, sum_cache, sum_sample, sum_eosid, sum_padid, 26 | internet_option, serper_api_key 27 | ): 28 | res = [ 29 | state["ppmanager_type"].from_json(json.dumps(ppm)) 30 | for ppm in local_data 31 | ] 32 | 33 | ppm = res[idx] 34 | 35 | # add_ping returns a prompt structured in Alpaca form 36 | ppm.add_pingpong( 37 | PingPong(user_message, "") 38 | ) 39 | prompt = build_prompts(ppm, global_context, ctx_num_lconv) 40 | 41 | ####### 42 | if internet_option: 43 | search_prompt = None 44 | for tmp_prompt, uis in internet_search(ppm, serper_api_key, global_context, ctx_num_lconv): 45 | search_prompt = tmp_prompt 46 | yield "", uis, prompt, str(res) 47 | 48 | # prepare text generating streamer & start generating 49 | gen_kwargs, streamer = pre.build( 50 | search_prompt if internet_option else prompt, 51 | res_temp, res_topp, res_topk, res_rpen, res_mnts, 52 | res_beams, res_cache, res_sample, res_eosid, res_padid, 53 | return_token_type_ids=False 54 | ) 55 | pre.start_gen(gen_kwargs) 56 | 57 | # handling stream 58 | for ppmanager, uis in text_stream(ppm, streamer): 59 | yield "", uis, prompt, str(res) 60 | 61 | ppm = post.strip_pong(ppm) 62 | yield "", ppm.build_uis(), prompt, str(res) -------------------------------------------------------------------------------- /chats/starchat.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from transformers import StoppingCriteria, StoppingCriteriaList 3 | 4 | import copy 5 | import json 6 | import global_vars 7 | from chats import pre, post 8 | from pingpong import PingPong 9 | from gens.batch_gen import get_output_batch 10 | 11 | from chats.utils import build_prompts, text_stream, internet_search 12 | 13 | class StopOnTokens(StoppingCriteria): 14 | def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool: 15 | stop_ids = [49155, 1, 0] 16 | for stop_id in stop_ids: 17 | if input_ids[0][-1] == stop_id: 18 | return 
True 19 | return False 20 | 21 | def chat_stream( 22 | idx, local_data, user_message, state, 23 | global_context, ctx_num_lconv, ctx_sum_prompt, 24 | res_temp, res_topp, res_topk, res_rpen, res_mnts, res_beams, res_cache, res_sample, res_eosid, res_padid, 25 | sum_temp, sum_topp, sum_topk, sum_rpen, sum_mnts, sum_beams, sum_cache, sum_sample, sum_eosid, sum_padid, 26 | internet_option, serper_api_key 27 | ): 28 | res = [ 29 | state["ppmanager_type"].from_json(json.dumps(ppm)) 30 | for ppm in local_data 31 | ] 32 | 33 | ppm = res[idx] 34 | 35 | # add_ping returns a prompt structured in Alpaca form 36 | ppm.add_pingpong( 37 | PingPong(user_message, "") 38 | ) 39 | prompt = build_prompts(ppm, global_context, ctx_num_lconv) 40 | 41 | ####### 42 | if internet_option: 43 | search_prompt = None 44 | for tmp_prompt, uis in internet_search(ppm, serper_api_key, global_context, ctx_num_lconv): 45 | search_prompt = tmp_prompt 46 | yield "", uis, prompt, str(res) 47 | 48 | # prepare text generating streamer & start generating 49 | gen_kwargs, streamer = pre.build( 50 | search_prompt if internet_option else prompt, 51 | res_temp, res_topp, res_topk, res_rpen, res_mnts, 52 | res_beams, res_cache, res_sample, res_eosid, res_padid, 53 | return_token_type_ids=False 54 | ) 55 | pre.start_gen(gen_kwargs) 56 | 57 | # handling stream 58 | for ppmanager, uis in text_stream(ppm, streamer): 59 | yield "", uis, prompt, str(res) 60 | 61 | ppm = post.strip_pong(ppm) 62 | yield "", ppm.build_uis(), prompt, str(res) -------------------------------------------------------------------------------- /chats/utils.py: -------------------------------------------------------------------------------- 1 | import copy 2 | import global_vars 3 | 4 | from pingpong.context import CtxLastWindowStrategy 5 | from pingpong.context import InternetSearchStrategy, SimilaritySearcher 6 | 7 | from chats import pre, post 8 | 9 | def build_prompts(ppmanager, global_context, win_size=3): 10 | dummy_ppm = copy.deepcopy(ppmanager) 11 | 12 | dummy_ppm.ctx = global_context 13 | for pingpong in dummy_ppm.pingpongs: 14 | pong = pingpong.pong 15 | first_sentence = pong.split("\n")[0] 16 | if first_sentence != "" and \ 17 | pre.contains_image_markdown(first_sentence): 18 | pong = ' '.join(pong.split("\n")[1:]).strip() 19 | pingpong.pong = pong 20 | 21 | lws = CtxLastWindowStrategy(win_size) 22 | 23 | prompt = lws(dummy_ppm) 24 | return prompt 25 | 26 | def text_stream(ppmanager, streamer): 27 | count = 0 28 | 29 | for new_text in streamer: 30 | if count == 0: 31 | ppmanager.append_pong(f"![]({global_vars.model_thumbnail_tiny})***[{global_vars.model_name}]***\n\n") 32 | count = count + 1 33 | 34 | ppmanager.append_pong(new_text) 35 | yield ppmanager, ppmanager.build_uis() 36 | 37 | yield ppmanager, ppmanager.build_uis() 38 | 39 | def internet_search(ppmanager, serper_api_key, global_context, ctx_num_lconv, device="cpu"): 40 | instruction = "Based on the provided texts below, please answer to '{ping}' in your own words. Try to explain in detail as much as possible." 
41 | 42 | searcher = SimilaritySearcher.from_pretrained(device=device) 43 | iss = InternetSearchStrategy( 44 | searcher, 45 | instruction=instruction, 46 | serper_api_key=serper_api_key 47 | )(ppmanager) 48 | 49 | step_ppm = None 50 | while True: 51 | try: 52 | step_ppm, _ = next(iss) 53 | yield "", step_ppm.build_uis() 54 | except StopIteration: 55 | break 56 | 57 | search_prompt = build_prompts(step_ppm, global_context, ctx_num_lconv) 58 | yield search_prompt, ppmanager.build_uis() -------------------------------------------------------------------------------- /chats/vicuna.py: -------------------------------------------------------------------------------- 1 | import copy 2 | import json 3 | import global_vars 4 | from chats import pre, post 5 | from pingpong import PingPong 6 | from gens.batch_gen import get_output_batch 7 | 8 | from chats.utils import build_prompts, text_stream, internet_search 9 | 10 | def chat_stream( 11 | idx, local_data, user_message, state, 12 | global_context, ctx_num_lconv, ctx_sum_prompt, 13 | res_temp, res_topp, res_topk, res_rpen, res_mnts, res_beams, res_cache, res_sample, res_eosid, res_padid, 14 | sum_temp, sum_topp, sum_topk, sum_rpen, sum_mnts, sum_beams, sum_cache, sum_sample, sum_eosid, sum_padid, 15 | internet_option, serper_api_key 16 | ): 17 | res = [ 18 | state["ppmanager_type"].from_json(json.dumps(ppm)) 19 | for ppm in local_data 20 | ] 21 | 22 | ppm = res[idx] 23 | 24 | # add_ping returns a prompt structured in Alpaca form 25 | ppm.add_pingpong( 26 | PingPong(user_message, "") 27 | ) 28 | prompt = build_prompts(ppm, global_context, ctx_num_lconv) 29 | 30 | ####### 31 | if internet_option: 32 | search_prompt = None 33 | for tmp_prompt, uis in internet_search(ppm, serper_api_key, global_context, ctx_num_lconv): 34 | search_prompt = tmp_prompt 35 | yield "", uis, prompt, str(res) 36 | 37 | # prepare text generating streamer & start generating 38 | gen_kwargs, streamer = pre.build( 39 | search_prompt if internet_option else prompt, 40 | res_temp, res_topp, res_topk, res_rpen, res_mnts, 41 | res_beams, res_cache, res_sample, res_eosid, res_padid, 42 | return_token_type_ids=False 43 | ) 44 | pre.start_gen(gen_kwargs) 45 | 46 | # handling stream 47 | for ppmanager, uis in text_stream(ppm, streamer): 48 | yield "", uis, prompt, str(res) 49 | 50 | ppm = post.strip_pong(ppm) 51 | yield "", ppm.build_uis(), prompt, str(res) -------------------------------------------------------------------------------- /chats/wizard_coder.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from transformers import StoppingCriteria, StoppingCriteriaList 3 | 4 | import copy 5 | import json 6 | import global_vars 7 | from chats import pre, post 8 | from pingpong import PingPong 9 | from gens.batch_gen import get_output_batch 10 | 11 | from chats.utils import build_prompts, text_stream, internet_search 12 | 13 | class StopOnTokens(StoppingCriteria): 14 | def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool: 15 | stop_ids = [0] 16 | for stop_id in stop_ids: 17 | if input_ids[0][-1] == stop_id: 18 | return True 19 | return False 20 | 21 | def chat_stream( 22 | idx, local_data, user_message, state, 23 | global_context, ctx_num_lconv, ctx_sum_prompt, 24 | res_temp, res_topp, res_topk, res_rpen, res_mnts, res_beams, res_cache, res_sample, res_eosid, res_padid, 25 | sum_temp, sum_topp, sum_topk, sum_rpen, sum_mnts, sum_beams, sum_cache, sum_sample, sum_eosid, sum_padid, 26 | 
internet_option, serper_api_key 27 | ): 28 | res = [ 29 | state["ppmanager_type"].from_json(json.dumps(ppm)) 30 | for ppm in local_data 31 | ] 32 | 33 | ppm = res[idx] 34 | 35 | # add_ping returns a prompt structured in Alpaca form 36 | ppm.add_pingpong( 37 | PingPong(user_message, "") 38 | ) 39 | prompt = build_prompts(ppm, global_context, ctx_num_lconv) 40 | 41 | ####### 42 | if internet_option: 43 | search_prompt = None 44 | for tmp_prompt, uis in internet_search(ppm, serper_api_key, global_context, ctx_num_lconv): 45 | search_prompt = tmp_prompt 46 | yield "", uis, prompt, str(res) 47 | 48 | # prepare text generating streamer & start generating 49 | gen_kwargs, streamer = pre.build( 50 | search_prompt if internet_option else prompt, 51 | res_temp, res_topp, res_topk, res_rpen, res_mnts, 52 | res_beams, res_cache, res_sample, res_eosid, res_padid, 53 | return_token_type_ids=False 54 | ) 55 | pre.start_gen(gen_kwargs) 56 | 57 | # handling stream 58 | for ppmanager, uis in text_stream(ppm, streamer): 59 | yield "", uis, prompt, str(res) 60 | 61 | ppm = post.strip_pong(ppm) 62 | yield "", ppm.build_uis(), prompt, str(res) -------------------------------------------------------------------------------- /chats/wizard_falcon.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from transformers import StoppingCriteria, StoppingCriteriaList 3 | 4 | import copy 5 | import json 6 | import global_vars 7 | from chats import pre, post 8 | from pingpong import PingPong 9 | from gens.batch_gen import get_output_batch 10 | 11 | from chats.utils import build_prompts, text_stream, internet_search 12 | 13 | class StopOnTokens(StoppingCriteria): 14 | def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool: 15 | stop_ids = [11] 16 | for stop_id in stop_ids: 17 | if input_ids[0][-1] == stop_id: 18 | return True 19 | return False 20 | 21 | def chat_stream( 22 | idx, local_data, user_message, state, 23 | global_context, ctx_num_lconv, ctx_sum_prompt, 24 | res_temp, res_topp, res_topk, res_rpen, res_mnts, res_beams, res_cache, res_sample, res_eosid, res_padid, 25 | sum_temp, sum_topp, sum_topk, sum_rpen, sum_mnts, sum_beams, sum_cache, sum_sample, sum_eosid, sum_padid, 26 | internet_option, serper_api_key 27 | ): 28 | res = [ 29 | state["ppmanager_type"].from_json(json.dumps(ppm)) 30 | for ppm in local_data 31 | ] 32 | 33 | ppm = res[idx] 34 | 35 | # add_ping returns a prompt structured in Alpaca form 36 | ppm.add_pingpong( 37 | PingPong(user_message, "") 38 | ) 39 | prompt = build_prompts(ppm, global_context, ctx_num_lconv) 40 | 41 | ####### 42 | if internet_option: 43 | search_prompt = None 44 | for tmp_prompt, uis in internet_search(ppm, serper_api_key, global_context, ctx_num_lconv): 45 | search_prompt = tmp_prompt 46 | yield "", uis, prompt, str(res) 47 | 48 | # prepare text generating streamer & start generating 49 | gen_kwargs, streamer = pre.build( 50 | search_prompt if internet_option else prompt, 51 | res_temp, res_topp, res_topk, res_rpen, res_mnts, 52 | res_beams, res_cache, res_sample, res_eosid, res_padid, 53 | return_token_type_ids=False 54 | ) 55 | pre.start_gen(gen_kwargs) 56 | 57 | # handling stream 58 | for ppmanager, uis in text_stream(ppm, streamer): 59 | yield "", uis, prompt, str(res) 60 | 61 | ppm = post.strip_pong(ppm) 62 | yield "", ppm.build_uis(), prompt, str(res) -------------------------------------------------------------------------------- /chats/xgen.py: 
-------------------------------------------------------------------------------- 1 | import torch 2 | from transformers import StoppingCriteria, StoppingCriteriaList 3 | 4 | import re 5 | import copy 6 | import json 7 | import global_vars 8 | from chats import pre, post 9 | from pingpong import PingPong 10 | from gens.batch_gen import get_output_batch 11 | 12 | from chats.utils import build_prompts, internet_search 13 | 14 | def text_stream(ppmanager, streamer): 15 | count = 0 16 | dummy_ppm = copy.deepcopy(ppmanager) 17 | 18 | for new_text in streamer: 19 | if count == 0: 20 | ppmanager.append_pong(f"![]({global_vars.model_thumbnail_tiny})***[{global_vars.model_type}]***\n") 21 | dummy_ppm.append_pong(f"![]({global_vars.model_thumbnail_tiny})***[{global_vars.model_type}]***\n") 22 | count = count + 1 23 | 24 | ppmanager.append_pong(new_text) 25 | dummy_ppm.append_pong(new_text) 26 | 27 | if "Assistant: " in ppmanager.pingpongs[-1].pong: 28 | dummy_ppm.replace_last_pong( 29 | dummy_ppm.pingpongs[-1].pong.replace("Assistant: ", "") 30 | ) 31 | 32 | if "<|endoftext|>" in ppmanager.pingpongs[-1].pong: 33 | ppmanager.replace_last_pong( 34 | re.sub(r'[\s|\n].*<\|endoftext\|>.*[\s|\n]', ' ', ppmanager.pingpongs[-1].pong) 35 | ) 36 | dummy_ppm.replace_last_pong( 37 | re.sub(r'[\s|\n].*<\|endoftext\|>.*[\s|\n]', ' ', dummy_ppm.pingpongs[-1].pong) 38 | ) 39 | break 40 | 41 | yield ppmanager, dummy_ppm.build_uis() 42 | 43 | yield ppmanager, dummy_ppm.build_uis() 44 | 45 | def chat_stream( 46 | idx, local_data, user_message, state, 47 | global_context, ctx_num_lconv, ctx_sum_prompt, 48 | res_temp, res_topp, res_topk, res_rpen, res_mnts, res_beams, res_cache, res_sample, res_eosid, res_padid, 49 | sum_temp, sum_topp, sum_topk, sum_rpen, sum_mnts, sum_beams, sum_cache, sum_sample, sum_eosid, sum_padid, 50 | internet_option, serper_api_key 51 | ): 52 | res = [ 53 | state["ppmanager_type"].from_json(json.dumps(ppm)) 54 | for ppm in local_data 55 | ] 56 | 57 | ppm = res[idx] 58 | 59 | # add_ping returns a prompt structured in Alpaca form 60 | ppm.add_pingpong( 61 | PingPong(user_message, "") 62 | ) 63 | prompt = build_prompts(ppm, global_context, ctx_num_lconv) 64 | 65 | ####### 66 | if internet_option: 67 | search_prompt = None 68 | for tmp_prompt, uis in internet_search(ppm, serper_api_key, global_context, ctx_num_lconv): 69 | search_prompt = tmp_prompt 70 | yield "", uis, prompt, str(res) 71 | 72 | # prepare text generating streamer & start generating 73 | gen_kwargs, streamer = pre.build( 74 | search_prompt if internet_option else prompt, 75 | res_temp, res_topp, res_topk, res_rpen, res_mnts, 76 | res_beams, res_cache, res_sample, res_eosid, res_padid, 77 | return_token_type_ids=False 78 | ) 79 | pre.start_gen(gen_kwargs) 80 | 81 | # handling stream 82 | for ppmanager, uis in text_stream(ppm, streamer): 83 | yield "", uis, prompt, str(res) 84 | 85 | ppm = post.strip_pong(ppm) 86 | dummy_ppm = copy.deepcopy(ppm) 87 | 88 | if "Assistant: " in dummy_ppm.pingpongs[-1].pong: 89 | dummy_ppm.replace_last_pong( 90 | dummy_ppm.pingpongs[-1].pong.replace("Assistant: ", "") 91 | ) 92 | 93 | yield "", dummy_ppm.build_uis(), prompt, str(res) -------------------------------------------------------------------------------- /configs/constraints_config.yaml: -------------------------------------------------------------------------------- 1 | constraints: 2 | max_context: 1000 3 | max_prompt: 300 4 | max_conv_len: 1500 -------------------------------------------------------------------------------- 
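The YAML files that follow under configs/response_configs and configs/summarization_configs only carry keyword arguments for text generation. As a rough, hedged illustration of how such a file can be consumed, the sketch below loads one of them into a transformers.GenerationConfig; the loader function name and the usage path are assumptions for illustration only and are not taken from this repository's own code.

import yaml
from transformers import GenerationConfig

def load_generation_config(path: str) -> GenerationConfig:
    # Each config file nests its parameters under the `generation_config` key.
    with open(path) as f:
        params = yaml.safe_load(f)["generation_config"]
    # GenerationConfig accepts the same keyword names used in the YAML files
    # (temperature, top_p, top_k, num_beams, max_new_tokens, do_sample, ...).
    return GenerationConfig(**params)

# Hypothetical usage; the path is assumed to be relative to the repository root.
# gen_config = load_generation_config("configs/response_configs/default.yaml")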
/configs/response_configs/baize.yaml: -------------------------------------------------------------------------------- 1 | generation_config: 2 | temperature: 0.95 3 | top_p: 0.9 4 | top_k: 50 5 | num_beams: 1 6 | use_cache: True 7 | repetition_penalty: 1.2 8 | max_new_tokens: 1024 9 | do_sample: True 10 | bos_token_id: 0 11 | eos_token_id: 1 12 | pad_token_id: 0 -------------------------------------------------------------------------------- /configs/response_configs/camel.yaml: -------------------------------------------------------------------------------- 1 | generation_config: 2 | temperature: 0.95 3 | top_p: 0.9 4 | top_k: 50 5 | num_beams: 1 6 | use_cache: True 7 | repetition_penalty: 1.2 8 | max_new_tokens: 1024 9 | do_sample: True 10 | pad_token_id: 50257 11 | eos_token_id: 50256 -------------------------------------------------------------------------------- /configs/response_configs/default.yaml: -------------------------------------------------------------------------------- 1 | generation_config: 2 | temperature: 0.95 3 | top_p: 0.9 4 | top_k: 50 5 | num_beams: 1 6 | use_cache: True 7 | repetition_penalty: 1.2 8 | max_new_tokens: 1024 9 | do_sample: True -------------------------------------------------------------------------------- /configs/response_configs/default_4096.yaml: -------------------------------------------------------------------------------- 1 | generation_config: 2 | temperature: 0.95 3 | top_p: 0.9 4 | top_k: 50 5 | num_beams: 1 6 | use_cache: True 7 | repetition_penalty: 1.2 8 | max_new_tokens: 4096 9 | do_sample: True -------------------------------------------------------------------------------- /configs/response_configs/falcon.yaml: -------------------------------------------------------------------------------- 1 | generation_config: 2 | temperature: 0.95 3 | top_p: 0.9 4 | top_k: 10 5 | num_beams: 1 6 | use_cache: True 7 | repetition_penalty: 1.2 8 | max_new_tokens: 512 9 | do_sample: True 10 | pad_token_id: 0 11 | bos_token_id: 1 12 | eos_token_id: 0 -------------------------------------------------------------------------------- /configs/response_configs/flan.yaml: -------------------------------------------------------------------------------- 1 | generation_config: 2 | temperature: 0.8 3 | top_p: 0.95 4 | top_k: 50 5 | num_beams: 1 6 | use_cache: False 7 | repetition_penalty: 1.2 8 | max_new_tokens: 256 9 | do_sample: True 10 | -------------------------------------------------------------------------------- /configs/response_configs/freewilly.yaml: -------------------------------------------------------------------------------- 1 | generation_config: 2 | temperature: 0.95 3 | top_p: 0.9 4 | top_k: 50 5 | num_beams: 1 6 | use_cache: True 7 | repetition_penalty: 1.2 8 | max_new_tokens: 2048 9 | do_sample: True 10 | bos_token_id: 1 11 | eos_token_id: 2 12 | pad_token_id: 0 -------------------------------------------------------------------------------- /configs/response_configs/gpt4_alpaca.yaml: -------------------------------------------------------------------------------- 1 | generation_config: 2 | temperature: 0.95 3 | top_p: 0.9 4 | top_k: 50 5 | num_beams: 1 6 | use_cache: True 7 | repetition_penalty: 1.2 8 | max_new_tokens: 512 9 | do_sample: True -------------------------------------------------------------------------------- /configs/response_configs/guanaco.yaml: -------------------------------------------------------------------------------- 1 | generation_config: 2 | temperature: 0.95 3 | top_p: 0.9 4 | top_k: 50 5 | num_beams: 1 6 | 
use_cache: True 7 | repetition_penalty: 1.2 8 | max_new_tokens: 1024 9 | do_sample: True -------------------------------------------------------------------------------- /configs/response_configs/koalpaca.yaml: -------------------------------------------------------------------------------- 1 | generation_config: 2 | temperature: 0.95 3 | top_p: 0.9 4 | top_k: 50 5 | num_beams: 1 6 | use_cache: True 7 | repetition_penalty: 1.2 8 | max_new_tokens: 1024 9 | do_sample: True 10 | eos_token_id: 2 11 | pad_token_id: 2 12 | -------------------------------------------------------------------------------- /configs/response_configs/llama2.yaml: -------------------------------------------------------------------------------- 1 | generation_config: 2 | temperature: 0.95 3 | top_p: 0.9 4 | top_k: 50 5 | num_beams: 1 6 | use_cache: True 7 | repetition_penalty: 1.2 8 | max_new_tokens: 4096 9 | do_sample: True 10 | pad_token_id: 32000 11 | bos_token_id: 1 12 | eos_token_id: 2 -------------------------------------------------------------------------------- /configs/response_configs/mistral.yaml: -------------------------------------------------------------------------------- 1 | generation_config: 2 | temperature: 0.95 3 | top_p: 0.9 4 | top_k: 50 5 | num_beams: 1 6 | use_cache: True 7 | repetition_penalty: 1.2 8 | max_new_tokens: 1024 9 | do_sample: True 10 | bos_token_id: 1 11 | eos_token_id: 2 12 | -------------------------------------------------------------------------------- /configs/response_configs/mistral_openhermes.yaml: -------------------------------------------------------------------------------- 1 | generation_config: 2 | temperature: 0.95 3 | top_p: 0.9 4 | top_k: 50 5 | num_beams: 1 6 | use_cache: True 7 | repetition_penalty: 1.2 8 | max_new_tokens: 1024 9 | do_sample: True 10 | bos_token_id: 1 11 | eos_token_id: 32000 12 | -------------------------------------------------------------------------------- /configs/response_configs/redpajama.yaml: -------------------------------------------------------------------------------- 1 | generation_config: 2 | temperature: 1.0 3 | top_p: 0.9 4 | top_k: 50 5 | num_beams: 1 6 | use_cache: True 7 | repetition_penalty: 1.2 8 | max_new_tokens: 1024 9 | do_sample: True 10 | eos_token_id: 0 11 | pad_token_id: 0 -------------------------------------------------------------------------------- /configs/response_configs/stablelm.yaml: -------------------------------------------------------------------------------- 1 | generation_config: 2 | temperature: 1.0 3 | top_p: 0.9 4 | top_k: 1000 5 | num_beams: 1 6 | use_cache: True 7 | repetition_penalty: 1.2 8 | max_new_tokens: 512 9 | do_sample: True 10 | eos_token_id: 0 11 | pad_token_id: 1 12 | -------------------------------------------------------------------------------- /configs/response_configs/stackllama.yaml: -------------------------------------------------------------------------------- 1 | generation_config: 2 | temperature: 0.9 3 | top_p: 0.95 4 | # top_k: 50 5 | num_beams: 1 6 | use_cache: True 7 | repetition_penalty: 1.2 8 | max_new_tokens: 256 9 | do_sample: True 10 | early_stopping: True 11 | -------------------------------------------------------------------------------- /configs/response_configs/starchat.yaml: -------------------------------------------------------------------------------- 1 | generation_config: 2 | temperature: 0.5 3 | top_p: 0.95 4 | top_k: 50 5 | num_beams: 1 6 | use_cache: True 7 | repetition_penalty: 1.2 8 | max_new_tokens: 1024 9 | do_sample: True 10 | eos_token_id: 0 11 | 
bos_token_id: 0 12 | pad_token_id: 0 13 | -------------------------------------------------------------------------------- /configs/response_configs/t5_vicuna.yaml: -------------------------------------------------------------------------------- 1 | generation_config: 2 | temperature: 0.95 3 | top_p: 0.9 4 | top_k: 50 5 | num_beams: 1 6 | use_cache: True 7 | repetition_penalty: 1.2 8 | max_new_tokens: 2048 9 | do_sample: True -------------------------------------------------------------------------------- /configs/response_configs/upstage_llama.yaml: -------------------------------------------------------------------------------- 1 | generation_config: 2 | temperature: 0.95 3 | top_p: 0.9 4 | top_k: 50 5 | num_beams: 1 6 | use_cache: True 7 | repetition_penalty: 1.2 8 | max_new_tokens: 2048 9 | do_sample: True -------------------------------------------------------------------------------- /configs/response_configs/upstage_llama2.yaml: -------------------------------------------------------------------------------- 1 | generation_config: 2 | temperature: 0.95 3 | top_p: 0.9 4 | top_k: 50 5 | num_beams: 1 6 | use_cache: True 7 | repetition_penalty: 1.2 8 | max_new_tokens: 1024 9 | do_sample: True -------------------------------------------------------------------------------- /configs/response_configs/vicuna.yaml: -------------------------------------------------------------------------------- 1 | generation_config: 2 | temperature: 0.95 3 | top_p: 0.9 4 | top_k: 50 5 | num_beams: 1 6 | use_cache: True 7 | repetition_penalty: 1.2 8 | max_new_tokens: 2048 9 | do_sample: True -------------------------------------------------------------------------------- /configs/response_configs/wizard-coder.yaml: -------------------------------------------------------------------------------- 1 | generation_config: 2 | temperature: 0.2 3 | top_p: 0.9 4 | top_k: 40 5 | num_beams: 1 6 | use_cache: True 7 | repetition_penalty: 1.2 8 | max_new_tokens: 2048 9 | do_sample: True 10 | eos_token_id: 0 11 | bos_token_id: 0 12 | pad_token_id: 49152 13 | -------------------------------------------------------------------------------- /configs/response_configs/wizardlm.yaml: -------------------------------------------------------------------------------- 1 | generation_config: 2 | temperature: 1.0 3 | top_p: 0.95 4 | top_k: 40 5 | num_beams: 1 6 | use_cache: True 7 | repetition_penalty: 1.2 8 | max_new_tokens: 2048 9 | do_sample: True -------------------------------------------------------------------------------- /configs/summarization_configs/camel.yaml: -------------------------------------------------------------------------------- 1 | generation_config: 2 | temperature: 1 3 | top_p: 0.9 4 | top_k: 50 5 | num_beams: 1 6 | use_cache: True 7 | max_new_tokens: 1024 8 | do_sample: True 9 | repetition_penalty: 1.5 10 | pad_token_id: 50257 11 | eos_token_id: 50256 -------------------------------------------------------------------------------- /configs/summarization_configs/default.yaml: -------------------------------------------------------------------------------- 1 | generation_config: 2 | temperature: 1 3 | top_p: 0.9 4 | top_k: 50 5 | num_beams: 1 6 | use_cache: True 7 | max_new_tokens: 1024 8 | do_sample: True 9 | repetition_penalty: 1.5 10 | 11 | -------------------------------------------------------------------------------- /configs/summarization_configs/koalpaca.yaml: -------------------------------------------------------------------------------- 1 | generation_config: 2 | temperature: 1 3 | top_p: 0.9 4 | 
top_k: 50 5 | num_beams: 1 6 | use_cache: True 7 | do_sample: True 8 | repetition_penalty: 1.2 9 | max_new_tokens: 512 10 | eos_token_id: 2 11 | pad_token_id: 2 -------------------------------------------------------------------------------- /configs/summarization_configs/redpajama.yaml: -------------------------------------------------------------------------------- 1 | generation_config: 2 | temperature: 1.0 3 | top_p: 0.9 4 | top_k: 1000 5 | num_beams: 1 6 | use_cache: True 7 | repetition_penalty: 1.2 8 | max_new_tokens: 512 9 | do_sample: True 10 | eos_token_id: 0 11 | pad_token_id: 1 12 | -------------------------------------------------------------------------------- /configs/summarization_configs/stablelm.yaml: -------------------------------------------------------------------------------- 1 | generation_config: 2 | temperature: 1 3 | top_p: 0.9 4 | top_k: 1000 5 | num_beams: 1 6 | use_cache: True 7 | do_sample: True 8 | repetition_penalty: 1.2 9 | max_new_tokens: 512 10 | eos_token_id: 0 11 | pad_token_id: 1 -------------------------------------------------------------------------------- /configs/summarization_configs/t5_vicuna.yaml: -------------------------------------------------------------------------------- 1 | generation_config: 2 | temperature: 0.95 3 | top_p: 0.9 4 | top_k: 50 5 | num_beams: 1 6 | use_cache: True 7 | repetition_penalty: 1.2 8 | max_new_tokens: 2048 9 | do_sample: True -------------------------------------------------------------------------------- /discord.dstack.yml: -------------------------------------------------------------------------------- 1 | type: task 2 | 3 | env: 4 | # (Required) Specify your Discord bot token. 5 | - DISCORD_BOT_TOKEN= 6 | # (Required) Specify the name of the model. See `README.md` for supported models. 7 | - DISCORD_BOT_MODEL_NAME=alpaca-lora-7b 8 | # (Optional) Specify your Hugging Face token 9 | - HUGGING_FACE_HUB_TOKEN= 10 | # (Optional) Specify your Serper API Key to enable Internet search support. 
11 | - LLMCHAT_SERPER_API_KEY= 12 | 13 | commands: 14 | - pip install -r requirements.txt --progress-bar off 15 | - LLMCHAT_APP_MODE=DISCORD python entry_point.py 16 | -------------------------------------------------------------------------------- /discord_app.py: -------------------------------------------------------------------------------- 1 | import os 2 | import copy 3 | import json 4 | import types 5 | import asyncio 6 | import argparse 7 | from urlextract import URLExtract 8 | from urllib.request import urlopen 9 | from concurrent.futures import ThreadPoolExecutor 10 | 11 | import discord 12 | from discord.errors import HTTPException 13 | 14 | import global_vars 15 | from pingpong.context import InternetSearchStrategy, SimilaritySearcher 16 | 17 | from discordbot.req import ( 18 | tgi_gen, vanilla_gen, build_prompt, build_ppm 19 | ) 20 | from discordbot.flags import parse_req 21 | from discordbot import helps, post 22 | from dumb_utils import URLSearchStrategy 23 | 24 | model_info = json.load(open("model_cards.json")) 25 | 26 | intents = discord.Intents.default() 27 | intents.members = True 28 | client = discord.Client(intents=intents) 29 | queue = asyncio.Queue() 30 | 31 | special_words = [ 32 | "help", 33 | "model-info", 34 | "default-params", 35 | ] 36 | max_response_length = 2000 37 | 38 | async def build_prompt_and_reply(executor, user_name, user_id): 39 | other_job_on_progress = False 40 | loop = asyncio.get_running_loop() 41 | 42 | print(queue.qsize()) 43 | msg = await queue.get() 44 | user_msg, user_args = parse_req( 45 | msg.content.replace(f"@{user_name} ", "").replace(f"<@{user_id}> ", ""), global_vars.gen_config 46 | ) 47 | 48 | if user_msg == "help": 49 | await msg.channel.send(helps.get_help()) 50 | elif user_msg == "model-info": 51 | await msg.channel.send(helps.get_model_info(model_name, model_info)) 52 | elif user_msg == "default-params": 53 | await msg.channel.send(helps.get_default_params(global_vars.gen_config, user_args["max-windows"])) 54 | else: 55 | try: 56 | ppm = await build_ppm(msg, user_msg, user_name, user_id) 57 | 58 | if user_args["internet"] and serper_api_key is not None: 59 | other_job_on_progress = True 60 | progress_msg = await msg.reply("Progress 🚧", mention_author=False) 61 | 62 | internet_search_ppm = copy.deepcopy(ppm) 63 | internet_search_prompt = f"My question is '{user_msg}'. Based on the conversation history, give me an appropriate query to answer my question for google search. You should not say more than query. You should not say any words except the query." 
64 | internet_search_ppm.pingpongs[-1].ping = internet_search_prompt 65 | internet_search_prompt = await build_prompt( 66 | internet_search_ppm, 67 | ctx_include=False, 68 | win_size=user_args["max-windows"] 69 | ) 70 | if tgi_server_addr is None: 71 | internet_search_prompt_response = await loop.run_in_executor( 72 | executor, gen_method, internet_search_prompt, user_args 73 | ) 74 | else: 75 | internet_search_prompt_response = await gen_method(internet_search_prompt, user_args) 76 | internet_search_prompt_response = post.clean(internet_search_prompt_response) 77 | 78 | ppm.pingpongs[-1].ping = internet_search_prompt_response 79 | 80 | await progress_msg.edit( 81 | content=f"• Search query re-organized by LLM: {internet_search_prompt_response}", 82 | suppress=True 83 | ) 84 | 85 | searcher = SimilaritySearcher.from_pretrained(device=global_vars.device) 86 | 87 | logs = "" 88 | for step_ppm, step_msg in InternetSearchStrategy( 89 | searcher, serper_api_key=serper_api_key 90 | )(ppm, search_query=internet_search_prompt_response, top_k=8): 91 | ppm = step_ppm 92 | logs = logs + step_msg + "\n" 93 | await progress_msg.edit(content=logs, suppress=True) 94 | else: 95 | url_extractor = URLExtract() 96 | urls = url_extractor.find_urls(user_msg) 97 | print(f"urls = {urls}") 98 | 99 | if len(urls) > 0: 100 | progress_msg = await msg.reply("Progress 🚧", mention_author=False) 101 | 102 | other_job_on_progress = True 103 | searcher = SimilaritySearcher.from_pretrained(device=global_vars.device) 104 | 105 | logs = "" 106 | for step_result, step_ppm, step_msg in URLSearchStrategy(searcher)(ppm, urls, top_k=8): 107 | if step_result is True: 108 | ppm = step_ppm 109 | logs = logs + step_msg + "\n" 110 | await progress_msg.edit(content=logs, suppress=True) 111 | else: 112 | ppm = step_ppm 113 | logs = logs + step_msg + "\n" 114 | await progress_msg.edit(content=logs, suppress=True) 115 | await asyncio.sleep(2) 116 | break 117 | 118 | prompt = await build_prompt(ppm, win_size=user_args["max-windows"]) 119 | if tgi_server_addr is None: 120 | response = await loop.run_in_executor(executor, gen_method, prompt, user_args) 121 | response = post.clean(response) 122 | else: 123 | response = await gen_method(prompt, user_args) 124 | 125 | response = f"**{model_name}** 💬\n{response.strip()}" 126 | if len(response) >= max_response_length: 127 | response = response[:max_response_length] 128 | 129 | if other_job_on_progress is True: 130 | await progress_msg.delete() 131 | 132 | await msg.reply(response, mention_author=False) 133 | except IndexError: 134 | await msg.channel.send("Index error") 135 | except HTTPException: 136 | pass 137 | 138 | async def background_task(user_name, user_id, max_workers): 139 | executor = ThreadPoolExecutor(max_workers=max_workers) 140 | print("Task Started. 
Waiting for inputs.") 141 | while True: 142 | await build_prompt_and_reply(executor, user_name, user_id) 143 | 144 | @client.event 145 | async def on_ready(): 146 | print(f"Logged in as {client.user}") 147 | asyncio.get_running_loop().create_task( 148 | background_task( 149 | client.user.name, 150 | client.user.id, 151 | max_workers, 152 | ) 153 | ) 154 | 155 | @client.event 156 | async def on_message(message): 157 | if message.author == client.user: 158 | return 159 | 160 | if isinstance(message.channel, discord.channel.DMChannel) or\ 161 | (client.user and client.user.mentioned_in(message)): 162 | await queue.put(message) 163 | 164 | def off_modes(args): 165 | args.mode_cpu = False 166 | args.mode_mps = False 167 | args.mode_8bit = False 168 | args.mode_4bit = False 169 | args.mode_full_gpu = False 170 | return args 171 | 172 | def discord_main(args): 173 | if args.token is None: 174 | args.token = os.getenv('DISCORD_BOT_TOKEN') 175 | 176 | if args.model_name is None: 177 | args.model_name = os.getenv('DISCORD_BOT_MODEL_NAME') 178 | 179 | if args.token is None or args.model_name is None: 180 | print('Either or both of token and model-name is not provided') 181 | print('Set them through CLI or environment variables(DISCORD_BOT_TOKEN, DISCORD_BOT_MODEL_NAME)') 182 | quit() 183 | 184 | if os.getenv('DISCORD_BOT_MAX_WORKERS'): 185 | args.max_workers = int(os.getenv('DISCORD_BOT_MAX_WORKERS')) 186 | 187 | if os.getenv('DISCORD_BOT_LOAD_MODE'): 188 | mode = os.getenv('DISCORD_BOT_LOAD_MODE') 189 | 190 | if mode == "CPU": 191 | off_modes(args) 192 | args.mode_cpu = True 193 | elif mode == "MPS": 194 | off_modes(args) 195 | args.mode_mps = True 196 | elif mode == "8BIT": 197 | off_modes(args) 198 | args.mode_8bit = True 199 | elif mode == "4BIT": 200 | off_modes(args) 201 | args.mode_4bit = True 202 | elif mode == "HALF": 203 | off_modes(args) 204 | args.mode_full_gpu = True 205 | 206 | if os.getenv('TGI_SERVER_ADDR') and os.getenv('TGI_SERVER_PORT'): 207 | args.tgi_server_addr = os.getenv('TGI_SERVER_ADDR') 208 | args.tgi_server_port = os.getenv('TGI_SERVER_PORT') 209 | 210 | global max_workers 211 | global model_name 212 | global serper_api_key 213 | global gen_method 214 | global tgi_server_addr 215 | global tgi_server_port 216 | 217 | max_workers = args.max_workers 218 | model_name = args.model_name 219 | serper_api_key = args.serper_api_key 220 | gen_method = vanilla_gen 221 | tgi_server_addr = None 222 | tgi_server_port = None 223 | 224 | selected_model_info = model_info[model_name] 225 | 226 | tmp_args = types.SimpleNamespace() 227 | tmp_args.model_name = args.model_name 228 | tmp_args.base_url = selected_model_info['hub(base)'] 229 | tmp_args.ft_ckpt_url = selected_model_info['hub(ckpt)'] 230 | tmp_args.gptq_url = None 231 | tmp_args.gptq_base_url = None 232 | tmp_args.gen_config_path = selected_model_info['default_gen_config'] 233 | tmp_args.gen_config_summarization_path = selected_model_info['default_gen_config'] 234 | tmp_args.force_download_ckpt = False 235 | tmp_args.thumbnail_tiny = selected_model_info['thumb-tiny'] 236 | 237 | tmp_args.mode_cpu = args.mode_cpu 238 | tmp_args.mode_mps = args.mode_mps 239 | tmp_args.mode_8bit = args.mode_8bit 240 | tmp_args.mode_4bit = args.mode_4bit 241 | tmp_args.mode_full_gpu = args.mode_full_gpu 242 | tmp_args.mode_gptq = False 243 | tmp_args.mode_mps_gptq = False 244 | tmp_args.mode_cpu_gptq = False 245 | tmp_args.mode_remote_tgi = False 246 | tmp_args.local_files_only = args.local_files_only 247 | 248 | if args.tgi_server_addr is not None and 
\ 249 | args.tgi_server_port is not None: 250 | 251 | tgi_server_addr = args.tgi_server_addr 252 | tgi_server_port = args.tgi_server_port 253 | 254 | tmp_args.mode_remote_tgi = True 255 | tmp_args.remote_addr = args.tgi_server_addr 256 | tmp_args.remote_port = args.tgi_server_port 257 | tmp_args.remote_token = None 258 | 259 | gen_method = tgi_gen 260 | 261 | try: 262 | global_vars.initialize_globals(tmp_args) 263 | except RuntimeError as e: 264 | print("GPU memory is not enough to load this model.") 265 | quit() 266 | 267 | client.run(args.token) 268 | 269 | if __name__ == "__main__": 270 | parser = argparse.ArgumentParser() 271 | # can be set via environment variable 272 | # --token == DISCORD_BOT_TOKEN 273 | # --model-name == DISCORD_BOT_MODEL_NAME 274 | parser.add_argument('--token', default=None, type=str) 275 | parser.add_argument('--model-name', default=None, type=str) 276 | parser.add_argument('--max-workers', default=1, type=int) 277 | parser.add_argument('--mode-cpu', default=False, action=argparse.BooleanOptionalAction) 278 | parser.add_argument('--mode-mps', default=False, action=argparse.BooleanOptionalAction) 279 | parser.add_argument('--mode-8bit', default=False, action=argparse.BooleanOptionalAction) 280 | parser.add_argument('--mode-4bit', default=False, action=argparse.BooleanOptionalAction) 281 | parser.add_argument('--mode-full-gpu', default=True, action=argparse.BooleanOptionalAction) 282 | parser.add_argument('--local-files-only', default=False, action=argparse.BooleanOptionalAction) 283 | parser.add_argument('--serper-api-key', default=None, type=str) 284 | parser.add_argument('--tgi-server-addr', default=None, type=str) 285 | parser.add_argument('--tgi-server-port', default=None, type=str) 286 | args = parser.parse_args() 287 | 288 | discord_main(args) 289 | -------------------------------------------------------------------------------- /discordbot/flags.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import sys 3 | 4 | known_flags_def = { 5 | "max-new-tokens": { 6 | "default": None, 7 | "type": int 8 | }, 9 | "temperature": { 10 | "default": None, 11 | "type": float 12 | }, 13 | "max-windows": { 14 | "default": 3, 15 | "type": int 16 | }, 17 | "do-sample": { 18 | "default": True, 19 | "type": bool 20 | }, 21 | "top-p": { 22 | "default": None, 23 | "type": float 24 | }, 25 | "internet": { 26 | "default": False, 27 | "type": bool 28 | } 29 | } 30 | 31 | def parse_req(message, gen_config): 32 | message, flags = parse_known_flags( 33 | message, 34 | known_flags_def, 35 | gen_config 36 | ) 37 | return message, flags 38 | 39 | def init_flags(known_flags_def, gen_config): 40 | gen_config_attrs = vars(gen_config) 41 | known_flags = list(known_flags_def.keys()) 42 | flags = {} 43 | types = {} 44 | 45 | for known_flag in known_flags: 46 | flags[known_flag] = known_flags_def[known_flag]['default'] 47 | types[known_flag] = known_flags_def[known_flag]['type'] 48 | 49 | known_flag_underscore = known_flag.replace("-", "_") 50 | if known_flag_underscore in list(gen_config_attrs.keys()): 51 | if gen_config_attrs[known_flag_underscore] is not None: 52 | flags[known_flag] = gen_config_attrs[known_flag_underscore] 53 | 54 | return known_flags, flags, types 55 | 56 | def parse_known_flags(string, known_flags_def, gen_config, prefix="--"): 57 | words = string.split() 58 | known_flags, flags, types = init_flags(known_flags_def, gen_config) 59 | 60 | for i in range(len(words)): 61 | word = words[i] 62 | if 
word.startswith(prefix): 63 | flag = word[2:] 64 | if flag in known_flags: 65 | if types[flag] == bool: 66 | flags[flag] = True 67 | else: 68 | flags[flag] = None 69 | 70 | value = words[i+1:i+2] 71 | if len(value) != 0: 72 | value = value[0] 73 | try: 74 | flags[flag] = types[flag](value) 75 | except ValueError: 76 | continue 77 | i = i+1 78 | 79 | for k, v in flags.items(): 80 | sub_str = f"{prefix}{k}" 81 | if v is not None: 82 | if not isinstance(v, bool): 83 | sub_str = sub_str + " " + str(v) 84 | 85 | print(sub_str) 86 | string = string.replace(sub_str, "") 87 | 88 | return string.strip(), flags -------------------------------------------------------------------------------- /discordbot/helps.py: -------------------------------------------------------------------------------- 1 | def get_help(): 2 | help_msg = """Type one of the following for more information about this chatbot 3 | - **`help`:** list of supported commands 4 | - **`model-info`:** get currently selected model card 5 | - **`default-params`:** get default parameters of the Generation Config 6 | 7 | You can start a conversation by mentioning the chatbot `@{chatbot name} {your prompt} {options}`, and the following options are supported. 8 | - **`--top-p {float}`**: determines how many tokens to pick from the top tokens based on the sum of their probabilities (<= `top-p`). 9 | - **`--temperature {float}`**: used to modulate the next token probabilities. 10 | - **`--max-new-tokens {integer}`**: maximum number of tokens to generate, ignoring the number of tokens in the prompt. 11 | - **`--do-sample`**: determines whether or not to use sampling; greedy decoding is used otherwise. 12 | - **`--max-windows {integer}`**: determines how many past conversations to look up as a reference. 13 | - **`--internet`**: determines whether or not to use internet search capabilities. 14 | 15 | If you want to continue a conversation based on past conversation history, you can simply `reply` to the chatbot's message. In that case, you don't need to mention its name. However, you need to specify options in every turn. For instance, if you want to `reply` based on internet search information, then you should specify `--internet` in your message.
16 | """ 17 | return help_msg 18 | 19 | def get_model_info(model_name, model_infos): 20 | selected_model_info = model_infos[model_name] 21 | help_msg = f"""## {model_name} 22 | - **Description:** {selected_model_info['desc']} 23 | - **Number of parameters:** {selected_model_info['parameters']} 24 | - **Hugging Face Hub (base):** {selected_model_info['hub(base)']} 25 | - **Hugging Face Hub (ckpt):** {selected_model_info['hub(ckpt)']} 26 | """ 27 | return help_msg 28 | 29 | 30 | def get_default_params(gen_config, max_windows): 31 | help_msg = f"""{gen_config}, max-windows = {max_windows}""" 32 | return help_msg -------------------------------------------------------------------------------- /discordbot/post.py: -------------------------------------------------------------------------------- 1 | def clean(text): 2 | if text.endswith(""): 3 | text = text[:-len("")] 4 | 5 | if text.endswith("<|endoftext|>"): 6 | text = text[:-len("<|endoftext|>")] 7 | 8 | return text -------------------------------------------------------------------------------- /discordbot/req.py: -------------------------------------------------------------------------------- 1 | import re 2 | import copy 3 | import global_vars 4 | 5 | from discordbot.utils import ( 6 | get_chat_manager, 7 | get_global_context 8 | ) 9 | from discordbot.flags import ( 10 | parse_known_flags, 11 | known_flags_def 12 | ) 13 | 14 | from pingpong import PingPong 15 | from pingpong.context import CtxLastWindowStrategy 16 | 17 | from discord import NotFound 18 | 19 | from transformers import GenerationConfig 20 | from text_generation import Client 21 | 22 | async def tgi_gen(prompt, args): 23 | gen_config = copy.deepcopy(global_vars.gen_config) 24 | if args["max-new-tokens"] is not None: 25 | gen_config.max_new_tokens = args["max-new-tokens"] 26 | if args["temperature"] is not None: 27 | gen_config.temperature = args["temperature"] 28 | if args["do-sample"] is not None: 29 | gen_config.do_sample = args["do-sample"] 30 | if args["top-p"] is not None: 31 | gen_config.top_p = args["top-p"] 32 | 33 | client = Client( 34 | f"http://{global_vars.remote_addr}:{global_vars.remote_port}", 35 | timeout=100 36 | ) 37 | 38 | response = client.generate( 39 | prompt, 40 | do_sample=gen_config.do_sample, 41 | max_new_tokens=512, 42 | repetition_penalty=gen_config.repetition_penalty, 43 | temperature=gen_config.repetition_penalty, 44 | top_k=gen_config.top_k, 45 | top_p=gen_config.top_p 46 | ) 47 | return response.generated_text 48 | 49 | 50 | def vanilla_gen(prompt, args): 51 | input_ids = global_vars.tokenizer(prompt, return_tensors="pt").input_ids.to(global_vars.device) 52 | 53 | gen_config = copy.deepcopy(global_vars.gen_config) 54 | if args["max-new-tokens"] is not None: 55 | gen_config.max_new_tokens = args["max-new-tokens"] 56 | if args["temperature"] is not None: 57 | gen_config.temperature = args["temperature"] 58 | if args["do-sample"] is not None: 59 | gen_config.do_sample = args["do-sample"] 60 | if args["top-p"] is not None: 61 | gen_config.top_p = args["top-p"] 62 | 63 | generated_ids = global_vars.model.generate( 64 | input_ids=input_ids, 65 | generation_config=gen_config 66 | ) 67 | response = global_vars.tokenizer.decode(generated_ids[0][input_ids.shape[-1]:]) 68 | return response 69 | 70 | async def build_prompt(ppmanager, ctx_include=True, win_size=3): 71 | dummy_ppm = copy.deepcopy(ppmanager) 72 | if ctx_include: 73 | dummy_ppm.ctx = get_global_context(global_vars.model_type) 74 | else: 75 | dummy_ppm.ctx = "" 76 | 77 | lws = 
CtxLastWindowStrategy(win_size) 78 | return lws(dummy_ppm) 79 | 80 | async def build_ppm(msg, msg_content, username, user_id): 81 | ppm = get_chat_manager(global_vars.model_type) 82 | 83 | channel = msg.channel 84 | user_msg = msg_content 85 | 86 | packs = [] 87 | partial_count = 0 88 | total_count = 0 89 | 90 | while True: 91 | try: 92 | if msg.reference is not None: 93 | ref_id = msg.reference.message_id 94 | msg = await channel.fetch_message(ref_id) 95 | msg_content = msg.content.replace(f"@{username} ", "").replace(f"<@{user_id}> ", "") 96 | try: 97 | idx = msg_content.index("💬") 98 | msg_content = msg_content[idx+1:].strip() 99 | except: 100 | msg_content = msg_content.strip() 101 | 102 | msg_content, _ = parse_known_flags( 103 | msg_content, 104 | known_flags_def, 105 | global_vars.gen_config 106 | ) 107 | print(msg_content) 108 | 109 | packs.insert( 110 | 0, msg_content 111 | ) 112 | 113 | partial_count = partial_count + 1 114 | if partial_count >= 2: 115 | partial_count = 0 116 | else: 117 | break 118 | 119 | except NotFound: 120 | break 121 | 122 | for idx in range(0, len(packs), 2): 123 | ppm.add_pingpong( 124 | PingPong(packs[idx], packs[idx+1]) 125 | ) 126 | 127 | ppm.add_pingpong( 128 | PingPong(user_msg, "") 129 | ) 130 | print(ppm.pingpongs) 131 | 132 | return ppm 133 | -------------------------------------------------------------------------------- /dumb_utils.py: -------------------------------------------------------------------------------- 1 | import re 2 | import copy 3 | import json 4 | import random 5 | import string 6 | import http.client 7 | 8 | import chromadb 9 | import torch 10 | import torch.nn.functional as F 11 | 12 | from urllib.request import urlopen 13 | from urllib.error import HTTPError 14 | from bs4 import BeautifulSoup 15 | from transformers import AutoTokenizer, AutoModel 16 | 17 | from pingpong import PingPong 18 | from pingpong.pingpong import PPManager 19 | from pingpong.context.strategy import CtxStrategy 20 | 21 | default_instruction = """The texts below come from the webpages that you provided in '{ping}'. Try to explain '{ping}' in as much detail as possible. Your explanation should be based almost entirely on the text below. Try not to write any unrelated information.
22 | ===================== 23 | """ 24 | 25 | class URLSearchStrategy(CtxStrategy): 26 | def __init__( 27 | self, 28 | similarity_searcher, 29 | instruction=default_instruction, 30 | db_name=None, chunk_size=1000 31 | ): 32 | self.searcher = similarity_searcher 33 | self.instruction = instruction 34 | self.db_name = db_name 35 | self.chunk_size = chunk_size 36 | 37 | if self.searcher is None: 38 | raise ValueError("SimilaritySearcher is not set.") 39 | 40 | if self.db_name is None: 41 | self.db_name = URLSearchStrategy.id_generator() 42 | 43 | def __call__(self, ppmanager: PPManager, urls, top_k=8, max_tokens=1024, keep_original=False): 44 | ppm = copy.deepcopy(ppmanager) 45 | last_ping = ppm.pingpongs[-1].ping 46 | # 1st yield 47 | ppm.add_pong("![loading](https://i.ibb.co/RPSPL5F/loading.gif)\n") 48 | ppm.append_pong("• Creating Chroma DB Collection...") 49 | yield True, ppm, "• Creating Chroma DB Collection √" 50 | 51 | chroma_client = chromadb.Client() 52 | try: 53 | chroma_client.delete_collection(self.db_name) 54 | except: 55 | pass 56 | 57 | col = chroma_client.create_collection(self.db_name) 58 | 59 | # 2nd yield 60 | ppm.replace_last_pong("![loading](https://i.ibb.co/RPSPL5F/loading.gif)\n") 61 | ppm.append_pong("• Creating Chroma DB Collection √\n") 62 | ppm.append_pong("• URL Searching...\n") 63 | yield True, ppm, "• URL Searching √" 64 | 65 | # HTML parsing 66 | search_results = [] 67 | success_urls = [] 68 | for url in urls: 69 | parse_result, contents = self._parse_html(url) 70 | if parse_result == True: 71 | success_urls.append(url) 72 | search_results.append(contents) 73 | 74 | ppm.append_pong(f" - {url} √\n") 75 | yield True, ppm, f" ▷ {url} √" 76 | 77 | if len(search_results) == 0: 78 | yield False, ppm, "There are no valid URLs. Check for trailing characters such as .(dot), ,(comma), etc. The LLM will answer your question based on its base knowledge."
79 | 80 | if len(' '.join(search_results).split(' ')) < max_tokens: 81 | final_result = ' '.join(search_results) 82 | 83 | # 3rd yield 84 | ppm.replace_last_pong("![loading](https://i.ibb.co/RPSPL5F/loading.gif)\n") 85 | ppm.append_pong("• Creating Chroma DB Collection √\n") 86 | ppm.append_pong("• URL Searching √\n") 87 | for url in success_urls: 88 | ppm.append_pong(f" - {url} √\n") 89 | yield True, ppm, "• Done √" 90 | 91 | last_ping = self.instruction.format(ping=last_ping) 92 | last_ping = last_ping + final_result 93 | 94 | ppm.pingpongs[-1].ping = last_ping 95 | ppm.replace_last_pong("") 96 | yield True, ppm, "⏳ Wait until LLM generates message for you ⏳" 97 | 98 | else: 99 | # 3rd yield 100 | ppm.replace_last_pong("![loading](https://i.ibb.co/RPSPL5F/loading.gif)\n") 101 | ppm.append_pong("• Creating Chroma DB Collection √\n") 102 | ppm.append_pong("• URL Searching √\n") 103 | for url in success_urls: 104 | ppm.append_pong(f" - {url} √\n") 105 | ppm.append_pong("• Creating embeddings...") 106 | yield True, ppm, "• Creating embeddings √" 107 | 108 | final_chunks = [] 109 | for search_result in search_results: 110 | chunks = self._create_chunks( 111 | search_result, 112 | chunk_size=self.searcher.max_length 113 | ) 114 | final_chunks.append(chunks) 115 | 116 | self._put_chunks_into_collection( 117 | col, final_chunks, docs_per_step=1 118 | ) 119 | 120 | query_results = self._query( 121 | col, f"query: {last_ping}", top_k, 122 | ) 123 | 124 | # 4th yield 125 | ppm.replace_last_pong("![loading](https://i.ibb.co/RPSPL5F/loading.gif)\n") 126 | ppm.append_pong("• Creating Chroma DB Collection √\n") 127 | ppm.append_pong("• URL Searching √\n") 128 | for url in success_urls: 129 | ppm.append_pong(f" - {url} √\n") 130 | ppm.append_pong("• Creating embeddings √\n") 131 | ppm.append_pong("• Information retrieval...") 132 | yield True, ppm, "• Information retrieval √" 133 | 134 | last_ping = self.instruction.format(ping=last_ping) 135 | for doc in query_results['documents'][0]: 136 | last_ping = last_ping + doc.replace('passage: ', '') + "\n" 137 | 138 | # 5th yield 139 | ppm.replace_last_pong("![loading](https://i.ibb.co/RPSPL5F/loading.gif)\n") 140 | ppm.append_pong("• Creating Chroma DB Collection √\n") 141 | ppm.append_pong("• URL Searching √\n") 142 | for url in success_urls: 143 | ppm.append_pong(f" - {url} √\n") 144 | ppm.append_pong("• Creating embeddings √\n") 145 | ppm.append_pong("• Information retrieval √") 146 | yield True, ppm, "• Done √" 147 | 148 | ppm.pingpongs[-1].ping = last_ping 149 | ppm.replace_last_pong("") 150 | yield True, ppm, "⏳ Wait until LLM generates message for you ⏳" 151 | 152 | def _parse_html(self, url): 153 | try: 154 | page = urlopen(url, timeout=5) 155 | html_bytes = page.read() 156 | html = html_bytes.decode("utf-8") 157 | except: 158 | return False, None 159 | 160 | text = "" 161 | soup = BeautifulSoup(html, "html.parser") 162 | 163 | for tag in soup.findAll('p'): 164 | for string in tag.strings: 165 | text = text + string 166 | 167 | for tag in soup.findAll('pre'): 168 | for string in tag.strings: 169 | text = text + string 170 | 171 | text = self._replace_multiple_newlines(text) 172 | return True, text 173 | 174 | def _query( 175 | self, collection, q, top_k 176 | ): 177 | _, q_embeddings_list = self.searcher.get_embeddings([q]) 178 | 179 | return collection.query( 180 | query_embeddings=q_embeddings_list, 181 | n_results=top_k 182 | ) 183 | 184 | # chunk_size == number of characters 185 | def _create_chunks(self, text, chunk_size): 186 | chunks = [] 187 
| 188 | for idx in range(0, len(text), chunk_size): 189 | chunks.append( 190 | f"passage: {text[idx:idx+chunk_size]}" 191 | ) 192 | 193 | return chunks 194 | 195 | def _put_chunk_into_collection( 196 | self, collection, chunk_id, chunk, docs_per_step=1 197 | ): 198 | for i in range(0, len(chunk), docs_per_step): 199 | cur_texts = chunk[i:i+docs_per_step] 200 | _, embeddings_list = self.searcher.get_embeddings(cur_texts) 201 | ids = [ 202 | f"id-{chunk_id}-{num}" for num in range(i, i+docs_per_step) 203 | ] 204 | 205 | collection.add( 206 | embeddings=embeddings_list, 207 | documents=cur_texts, 208 | ids=ids 209 | ) 210 | 211 | def _put_chunks_into_collection( 212 | self, collection, 213 | chunks, docs_per_step=1 214 | ): 215 | for idx, chunk in enumerate(chunks): 216 | self._put_chunk_into_collection( 217 | collection, idx, 218 | chunk, docs_per_step=docs_per_step 219 | ) 220 | 221 | def _replace_multiple_newlines(self, text): 222 | """Replaces multiple newline characters with a single newline character.""" 223 | pattern = re.compile(r"\n+") 224 | return pattern.sub("\n", text) 225 | 226 | @classmethod 227 | def id_generator(cls, size=10, chars=string.ascii_uppercase + string.digits): 228 | return ''.join(random.choice(chars) for _ in range(size)) -------------------------------------------------------------------------------- /entry_point.py: -------------------------------------------------------------------------------- 1 | import os 2 | import argparse 3 | 4 | from discord_app import discord_main 5 | from app import gradio_main 6 | 7 | if __name__ == "__main__": 8 | parser = argparse.ArgumentParser() 9 | 10 | app_mode = os.getenv("LLMCHAT_APP_MODE") 11 | local_files_only = os.getenv("LLMCHAT_LOCAL_FILES_ONLY") 12 | serper_api_key = os.getenv("LLMCHAT_SERPER_API_KEY") 13 | 14 | if app_mode is None or \ 15 | app_mode not in ["GRADIO", "DISCORD"]: 16 | app_mode = "GRADIO" 17 | 18 | if local_files_only is None: 19 | local_files_only = False 20 | else: 21 | local_files_only = bool(local_files_only) 22 | 23 | if app_mode == "GRADIO": 24 | parser.add_argument('--root-path', default="") 25 | parser.add_argument('--local-files-only', default=False, action=argparse.BooleanOptionalAction) 26 | parser.add_argument('--share', default=False, action=argparse.BooleanOptionalAction) 27 | parser.add_argument('--debug', default=False, action=argparse.BooleanOptionalAction) 28 | parser.add_argument('--serper-api-key', default=serper_api_key, type=str) 29 | args = parser.parse_args() 30 | gradio_main(args) 31 | 32 | elif app_mode == "DISCORD": 33 | parser.add_argument('--token', default=None, type=str) 34 | parser.add_argument('--model-name', default=None, type=str) 35 | parser.add_argument('--max-workers', default=1, type=int) 36 | parser.add_argument('--mode-cpu', default=False, action=argparse.BooleanOptionalAction) 37 | parser.add_argument('--mode-mps', default=False, action=argparse.BooleanOptionalAction) 38 | parser.add_argument('--mode-8bit', default=False, action=argparse.BooleanOptionalAction) 39 | parser.add_argument('--mode-4bit', default=False, action=argparse.BooleanOptionalAction) 40 | parser.add_argument('--mode-full-gpu', default=True, action=argparse.BooleanOptionalAction) 41 | parser.add_argument('--local-files-only', default=local_files_only, action=argparse.BooleanOptionalAction) 42 | parser.add_argument('--serper-api-key', default=serper_api_key, type=str) 43 | parser.add_argument('--tgi-server-addr', default=None, type=str) 44 | parser.add_argument('--tgi-server-port', default=None, 
type=str) 45 | args = parser.parse_args() 46 | discord_main(args) 47 | -------------------------------------------------------------------------------- /examples.txt: -------------------------------------------------------------------------------- 1 | Tell me about Generative AI 2 | What is the meaning of life? -------------------------------------------------------------------------------- /gens/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/deep-diver/LLM-As-Chatbot/99c2c03efececba39a633589775f77989f93deff/gens/__init__.py -------------------------------------------------------------------------------- /gens/batch_gen.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | def get_output_batch( 4 | model, tokenizer, prompts, generation_config, device='cuda' 5 | ): 6 | if len(prompts) == 1: 7 | encoding = tokenizer(prompts, return_tensors="pt") 8 | input_ids = encoding["input_ids"].to(device) 9 | generated_id = model.generate( 10 | input_ids=input_ids, 11 | generation_config=generation_config, 12 | ) 13 | 14 | decoded = tokenizer.batch_decode( 15 | generated_id, skip_prompt=True, skip_special_tokens=True 16 | ) 17 | del input_ids, generated_id 18 | torch.cuda.empty_cache() 19 | return decoded 20 | else: 21 | encodings = tokenizer(prompts, padding=True, return_tensors="pt").to(device) 22 | generated_ids = model.generate( 23 | **encodings, 24 | generation_config=generation_config, 25 | ) 26 | 27 | decoded = tokenizer.batch_decode( 28 | generated_ids, skip_prompt=True, skip_special_tokens=True 29 | ) 30 | del encodings, generated_ids 31 | torch.cuda.empty_cache() 32 | return decoded 33 | -------------------------------------------------------------------------------- /gradio.dstack.yml: -------------------------------------------------------------------------------- 1 | type: task 2 | 3 | env: 4 | # (Optional) Specify your Hugging Face token 5 | - HUGGING_FACE_HUB_TOKEN= 6 | # (Optional) Specify your Serper API Key 7 | - LLMCHAT_SERPER_API_KEY= 8 | 9 | ports: 10 | - 6006 11 | 12 | commands: 13 | - pip install -r requirements.txt --progress-bar off 14 | - LLMCHAT_APP_MODE=GRADIO python entry_point.py 15 | -------------------------------------------------------------------------------- /miscs/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/deep-diver/LLM-As-Chatbot/99c2c03efececba39a633589775f77989f93deff/miscs/__init__.py -------------------------------------------------------------------------------- /miscs/js.py: -------------------------------------------------------------------------------- 1 | GET_LOCAL_STORAGE = """ 2 | function() { 3 | globalThis.setStorage = (key, value)=>{ 4 | localStorage.setItem(key, JSON.stringify(value)); 5 | } 6 | globalThis.getStorage = (key, value)=>{ 7 | return JSON.parse(localStorage.getItem(key)); 8 | } 9 | 10 | var local_data = getStorage('local_data'); 11 | var history = []; 12 | 13 | if(local_data) { 14 | local_data[0].pingpongs.forEach(element =>{ 15 | history.push([element.ping, element.pong]); 16 | }); 17 | } 18 | else { 19 | local_data = []; 20 | for (let step = 0; step < 10; step++) { 21 | local_data.push({'ctx': '', 'pingpongs':[]}); 22 | } 23 | setStorage('local_data', local_data); 24 | } 25 | 26 | if(history.length == 0) { 27 | document.querySelector("#initial-popup").classList.remove('hide'); 28 | } 29 | 30 | return [history, 
local_data]; 31 | } 32 | """ 33 | 34 | UPDATE_LEFT_BTNS_STATE = """ 35 | (v)=>{ 36 | document.querySelector('.custom-btn-highlight').classList.add('custom-btn'); 37 | document.querySelector('.custom-btn-highlight').classList.remove('custom-btn-highlight'); 38 | 39 | const elements = document.querySelectorAll(".custom-btn"); 40 | 41 | for(var i=0; i < elements.length; i++) { 42 | const element = elements[i]; 43 | if(element.textContent == v) { 44 | console.log(v); 45 | element.classList.add('custom-btn-highlight'); 46 | element.classList.remove('custom-btn'); 47 | break; 48 | } 49 | } 50 | }""" 51 | 52 | UPDATE_PLACEHOLDERS = """ 53 | function update_placeholders(txt, placeholder_txt1, placeholder_txt2, placeholder_txt3) { 54 | let example_prompt = txt; 55 | 56 | const regex = /\[([^\]]*)\]/g; 57 | const matches = txt.match(regex); 58 | 59 | if (matches != null) { 60 | if (matches.length >= 1) { 61 | if (placeholder_txt1 !== "") { 62 | example_prompt = example_prompt.replace(matches[0], placeholder_txt1); 63 | } 64 | } 65 | 66 | if (matches.length >= 2) { 67 | if (placeholder_txt2 !== "") { 68 | example_prompt = example_prompt.replace(matches[1], placeholder_txt2); 69 | } 70 | } 71 | 72 | if (matches.length >= 3) { 73 | if (placeholder_txt3 !== "") { 74 | example_prompt = example_prompt.replace(matches[2], placeholder_txt3); 75 | } 76 | } 77 | } 78 | 79 | return example_prompt 80 | } 81 | """ -------------------------------------------------------------------------------- /miscs/strings.py: -------------------------------------------------------------------------------- 1 | TITLE = "Alpaca-LoRA Playground" 2 | 3 | ABSTRACT = """ 4 | Thanks to [tloen](https://github.com/tloen/alpaca-lora), this application runs Alpaca-LoRA, an instruction fine-tuned version of [LLaMA](https://ai.facebook.com/blog/large-language-model-llama-meta-ai/). This demo currently runs the 30B version on a 3*A6000 instance at [Jarvislabs.ai](https://jarvislabs.ai/). 5 | 6 | NOTE: overly long inputs (context, instruction) are not allowed. Please keep the context < 500 and the instruction < 150. 7 | """ 8 | 9 | BOTTOM_LINE = """ 10 | This demo application runs the open source project, [Alpaca-LoRA-Serve](https://github.com/deep-diver/Alpaca-LoRA-Serve). By default, it runs in streaming mode, but you can also run it with the dynamic batch generation mode. Please visit the repo to find more information and contribute if you can. 11 | 12 | Alpaca-LoRA is built on the same concept as the Stanford Alpaca project, but it lets us train and run inference on smaller GPUs, such as an RTX 4090 for the 7B version. Also, we can build very small checkpoints on top of base models thanks to the [🤗 transformers](https://huggingface.co/docs/transformers/index), [🤗 peft](https://github.com/huggingface/peft), and [bitsandbytes](https://github.com/TimDettmers/bitsandbytes/tree/main) libraries. 13 | 14 | We are thankful to [Jarvislabs.ai](https://jarvislabs.ai/), which generously provided free GPU instances. 15 | """ 16 | 17 | DEFAULT_EXAMPLES = { 18 | "Typical Questions": [ 19 | { 20 | "title": "List all Canadian provinces in alphabetical order.", 21 | "examples": [ 22 | ["1", "List all Canadian provinces in alphabetical order."], 23 | ["2", "Which ones are on the east side?"], 24 | ["3", "What foods are famous in each province on the east side?"], 25 | ["4", "What about sightseeing? or landmarks? 
list one per province"], 26 | ], 27 | }, 28 | { 29 | "title": "Tell me about Alpacas.", 30 | "examples": [ 31 | ["1", "Tell me about alpacas in two sentences"], 32 | ["2", "What other animals are living in the same area?"], 33 | ["3", "Are they the same species?"], 34 | ["4", "Write a Python program to return those species"], 35 | ], 36 | }, 37 | { 38 | "title": "Tell me about the king of France in 2019.", 39 | "examples": [ 40 | ["1", "Tell me about the king of France in 2019."], 41 | ["2", "What about before him?"], 42 | ] 43 | }, 44 | { 45 | "title": "Write a Python program that prints the first 10 Fibonacci numbers.", 46 | "examples": [ 47 | ["1", "Write a Python program that prints the first 10 Fibonacci numbers."], 48 | ["2", "Could you explain how the code works?"], 49 | ["3", "What is recursion?"], 50 | ] 51 | } 52 | ], 53 | "Identity": [ 54 | { 55 | "title": "Conversation with the planet Pluto", 56 | "examples": [ 57 | ["1", "Conversation with the planet Pluto", "I'm so curious about you"], 58 | ["2", "Conversation with the planet Pluto", "Tell me what I would see if I visited"], 59 | ["3", "Conversation with the planet Pluto", "It sounds beautiful"], 60 | ["4", "Conversation with the planet Pluto", "I'll keep that in mind. Hey, I was wondering, have you ever had any visitors?"], 61 | ["5", "Conversation with the planet Pluto", "That must have been exciting"], 62 | ["6", "Conversation with the planet Pluto", "That's so great. What else do you wish people knew about you?"], 63 | ["7", "Conversation with the planet Pluto", "Thanks for talking with me"], 64 | ] 65 | }, 66 | { 67 | "title": "Conversation with a paper airplane", 68 | "examples": [ 69 | ["1", "Conversation with a paper airplane", "What's it like being thrown through the air"], 70 | ["2", "Conversation with a paper airplane", "What's the worst place you've ever landed"], 71 | ["3", "Conversation with a paper airplane", "Have you ever gotten stuck?"], 72 | ["4", "Conversation with a paper airplane", "What's the secret to a really good paper airplane?"], 73 | ["5", "Conversation with a paper airplane", "What's the farthest you've ever flown?"], 74 | ["6", "Conversation with a paper airplane", "Good to talk to you!"] 75 | ] 76 | } 77 | ] 78 | } 79 | 80 | SPECIAL_STRS = { 81 | "continue": "continue.", 82 | "summarize": "what have we discussed so far? describe in the user's view and include important entities. also be as brief as possible." 83 | } -------------------------------------------------------------------------------- /miscs/templates.py: -------------------------------------------------------------------------------- 1 | templates = [ 2 | { 3 | "title": "Marketing", 4 | "template": [ 5 | "Can you provide me with some ideas for blog posts about [topic of your choice]", 6 | "Create a social media post that targets [the specific audience] and explains how our product [product name] can help them.", 7 | "Write a list of 5 YouTube video ideas for [your product or company]", 8 | "Suggest inexpensive ways I can promote my [company] with/without using [Media channel]" 9 | ], 10 | }, 11 | { 12 | "title": "Business", 13 | "template": [ 14 | "Analyze the current state of [industry] and its trends, challenges, and opportunities, including relevant data and statistics. 
Provide a list of key players and a short and long-term industry forecast, and explain any potential impact of current events or future developments.", 15 | "Offer a detailed review of a [specific software or tool] for [describe your business].", 16 | "I need to prepare a presentation for a potential investor on [presentation topic]. Can you give me some guidance on what to include?", 17 | "I need to write an email to a client regarding a change in the project timeline. Can you give me some guidance on how to phrase it?" 18 | ] 19 | }, 20 | { 21 | "title": "Content Creation", 22 | "template": [ 23 | "Generate a creative social media content calendar for the next month for our [company or product] on [topic of choice]", 24 | "Generate a list of 5 LinkedIn articles to write for a [profession or topic of your choice]", 25 | "Create two Google Ads in an RSA format (using multiple headlines and descriptions) for an A/B test for [your company]. Explain why the ads would make a good test.", 26 | "Write an email to [person] with some facts about [Topic of your choice] with a [theme of your choice]" 27 | ] 28 | }, 29 | { 30 | "title": "Education", 31 | "template": [ 32 | "Create a magical system that emphasizes education and is based on [topic of your choice].", 33 | "Teach me the [topic of your choice] and give me a quiz at the end, but don’t give me the answers and then tell me if I answered correctly.", 34 | "Can you give me an example of how to solve a [Problem statement]?", 35 | "Create a YAML template to detect the Magento version for the Nuclei vulnerability scanner." 36 | ] 37 | }, 38 | { 39 | "title": "Teachers", 40 | "template": [ 41 | "Create a list of 5 types of data that teachers can collect to monitor student learning and progress.", 42 | "Create a quiz with 5 multiple choice questions that assess students' understanding of [concept being taught].", 43 | "Generate a list of specific and actionable steps that a student can take to improve their performance in [subject/task]", 44 | "Create a list of 5 teaching strategies that could be used to engage and challenge students of different ability levels in a lesson on [concept being taught]" 45 | ] 46 | }, 47 | { 48 | "title": "Web Development", 49 | "template": [ 50 | "Suggest inexpensive ways I can promote my [company] with/without using [Media channel]", 51 | "I need to create a REST API endpoint for my web application. Can you provide an example of how to do that using Node.js and Express?", 52 | "I’m making a website for a small business [Business description]. I need ideas on how to structure the website using WordPress.", 53 | "Find the bug with this code: [post code below]" 54 | ] 55 | }, 56 | { 57 | "title": "Travel and Tourism", 58 | "template": [ 59 | "How much money do I need as a tourist for [X] days in [Location]?", 60 | "How much money do I need to survive a day in [location]?", 61 | "I want to plan a three-week backpacking trip through Europe. I have a student’s budget, and I love finding local street food and open markets. Can you suggest an itinerary for me?", 62 | "Pick [X] cities for a [Y]-day trip in [location]" 63 | ] 64 | }, 65 | { 66 | "title": "Music", 67 | "template": [ 68 | "Write a lyrical verse in the style of [artist] about [topic]", 69 | "I want to make a music video, but I’m not sure what concept to use. Can you help me come up with a concept?", 70 | "I want to write a midi file. 
Can you provide python3 code that writes a simple tune using a for loop to add each note?", 71 | "Create a poem or song for [target audience] that explains . The song should have a distinct character and traits for each participant, as well as punctuation such as.,!?, and so on. Make it last as long as possible." 72 | ] 73 | }, 74 | { 75 | "title": "Fun", 76 | "template": [ 77 | "Tell me a joke about [topic of your choice]", 78 | "Explain [topic of your choice] in a funny way", 79 | "Write hilarious fan fiction about the Twilight saga.", 80 | "Make Eminem-style jokes about Max Payne." 81 | ] 82 | }, 83 | { 84 | "title": "UX", 85 | "template": [ 86 | "Generate examples of UI design requirements for a [mobile app]", 87 | "Generate a typography style guide for a [mobile application] in excel format.", 88 | "What are the UI cases that need to be considered when designing a [burger menu]", 89 | "How can I design a [law firm website] in a way that conveys [trust and authority]" 90 | ] 91 | } 92 | ] -------------------------------------------------------------------------------- /models/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/deep-diver/LLM-As-Chatbot/99c2c03efececba39a633589775f77989f93deff/models/__init__.py -------------------------------------------------------------------------------- /models/airoboros.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from transformers import AutoModelForCausalLM, AutoTokenizer 3 | from optimum.bettertransformer import BetterTransformer 4 | 5 | from auto_gptq import AutoGPTQForCausalLM, BaseQuantizeConfig 6 | 7 | def load_model( 8 | base, 9 | finetuned, 10 | gptq, 11 | gptq_base, 12 | mode_cpu, 13 | mode_mps, 14 | mode_full_gpu, 15 | mode_8bit, 16 | mode_4bit, 17 | mode_gptq, 18 | mode_mps_gptq, 19 | mode_cpu_gptq, 20 | force_download_ckpt, 21 | local_files_only 22 | ): 23 | tokenizer = AutoTokenizer.from_pretrained( 24 | base, local_files_only=local_files_only 25 | ) 26 | 27 | if mode_cpu: 28 | print("cpu mode") 29 | model = AutoModelForCausalLM.from_pretrained( 30 | base, 31 | device_map={"": "cpu"}, 32 | use_safetensors=False, 33 | local_files_only=local_files_only 34 | # low_cpu_mem_usage=True 35 | ) 36 | elif mode_mps: 37 | print("mps mode") 38 | model = AutoModelForCausalLM.from_pretrained( 39 | base, 40 | device_map={"": "mps"}, 41 | torch_dtype=torch.float16, 42 | use_safetensors=False, 43 | local_files_only=local_files_only 44 | ) 45 | 46 | elif mode_gptq: 47 | print("gpu(gptq) mode") 48 | tokenizer = AutoTokenizer.from_pretrained( 49 | gptq, local_files_only=local_files_only 50 | ) 51 | tokenizer.pad_token_id = 0 52 | tokenizer.padding_side = "left" 53 | 54 | model = AutoGPTQForCausalLM.from_quantized( 55 | gptq, 56 | model_basename=gptq_base, 57 | use_safetensors=True, 58 | trust_remote_code=False, 59 | device_map="auto", 60 | quantize_config=None, 61 | local_files_only=local_files_only 62 | ) 63 | 64 | else: 65 | print("gpu mode") 66 | print(f"8bit = {mode_8bit}, 4bit = {mode_4bit}") 67 | model = AutoModelForCausalLM.from_pretrained( 68 | base, 69 | torch_dtype=torch.float16, 70 | load_in_8bit=mode_8bit, 71 | load_in_4bit=mode_4bit, 72 | device_map="auto", 73 | use_safetensors=False, 74 | local_files_only=local_files_only 75 | ) 76 | 77 | if not mode_8bit and not mode_4bit: 78 | model.half() 79 | 80 | model = BetterTransformer.transform(model) 81 | return model, tokenizer 
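Usage sketch: every loader in models/ exposes the same load_model(...) signature shown above, and gens/batch_gen.py consumes whatever (model, tokenizer) pair it returns. The snippet below is a minimal, hypothetical example of wiring the two together in full-GPU mode; the checkpoint id is a placeholder, not a value taken from this repository, and the real app additionally wraps prompts with the model-specific templates from chats/ before generating.

from transformers import GenerationConfig

from gens.batch_gen import get_output_batch
from models import airoboros

# Load a model with the flag set that all loaders in models/ accept.
model, tokenizer = airoboros.load_model(
    base="your-org/airoboros-checkpoint",  # placeholder checkpoint id
    finetuned=None,
    gptq=None,
    gptq_base=None,
    mode_cpu=False,
    mode_mps=False,
    mode_full_gpu=True,
    mode_8bit=False,
    mode_4bit=False,
    mode_gptq=False,
    mode_mps_gptq=False,
    mode_cpu_gptq=False,
    force_download_ckpt=False,
    local_files_only=False,
)

# Generate a single response with the batch helper defined in gens/batch_gen.py.
gen_config = GenerationConfig(do_sample=True, temperature=0.7, max_new_tokens=256)
outputs = get_output_batch(model, tokenizer, ["Tell me about Generative AI"], gen_config)
print(outputs[0])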
-------------------------------------------------------------------------------- /models/alpaca.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from peft import PeftModel 3 | from transformers import LlamaTokenizer, LlamaForCausalLM 4 | from optimum.bettertransformer import BetterTransformer 5 | 6 | from auto_gptq import AutoGPTQForCausalLM, BaseQuantizeConfig 7 | 8 | def load_model( 9 | base, 10 | finetuned, 11 | gptq, 12 | gptq_base, 13 | mode_cpu, 14 | mode_mps, 15 | mode_full_gpu, 16 | mode_8bit, 17 | mode_4bit, 18 | mode_gptq, 19 | mode_mps_gptq, 20 | mode_cpu_gptq, 21 | force_download_ckpt, 22 | local_files_only 23 | ): 24 | tokenizer = LlamaTokenizer.from_pretrained( 25 | base, local_files_only=local_files_only 26 | ) 27 | tokenizer.pad_token_id = 0 28 | tokenizer.padding_side = "left" 29 | 30 | if mode_cpu: 31 | print("cpu mode") 32 | model = LlamaForCausalLM.from_pretrained( 33 | base, 34 | device_map={"": "cpu"}, 35 | use_safetensors=False, 36 | local_files_only=local_files_only 37 | ) 38 | 39 | if finetuned is not None and \ 40 | finetuned != "" and \ 41 | finetuned != "N/A": 42 | 43 | model = PeftModel.from_pretrained( 44 | model, 45 | finetuned, 46 | device_map={"": "cpu"}, 47 | # force_download=force_download_ckpt, 48 | ) 49 | else: 50 | model = BetterTransformer.transform(model) 51 | 52 | elif mode_mps: 53 | print("mps mode") 54 | model = LlamaForCausalLM.from_pretrained( 55 | base, 56 | device_map={"": "mps"}, 57 | torch_dtype=torch.float16, 58 | use_safetensors=False, 59 | local_files_only=local_files_only 60 | ) 61 | 62 | if finetuned is not None and \ 63 | finetuned != "" and \ 64 | finetuned != "N/A": 65 | 66 | model = PeftModel.from_pretrained( 67 | model, 68 | finetuned, 69 | torch_dtype=torch.float16, 70 | device_map={"": "mps"} 71 | # force_download=force_download_ckpt, 72 | ) 73 | else: 74 | model = BetterTransformer.transform(model) 75 | 76 | elif mode_gptq: 77 | print("gpu(gptq) mode") 78 | tokenizer = LlamaTokenizer.from_pretrained( 79 | gptq, local_files_only=local_files_only 80 | ) 81 | tokenizer.pad_token_id = 0 82 | tokenizer.padding_side = "left" 83 | 84 | model = AutoGPTQForCausalLM.from_quantized( 85 | gptq, 86 | model_basename=gptq_base, 87 | use_safetensors=True, 88 | trust_remote_code=False, 89 | device_map="auto", 90 | quantize_config=None, 91 | local_files_only=local_files_only 92 | ) 93 | 94 | # elif mode_mps_gptq: 95 | # print("mps(gptq) mode") 96 | # tokenizer = LlamaTokenizer.from_pretrained( 97 | # gptq, local_files_only=local_files_only 98 | # ) 99 | # tokenizer.pad_token_id = 0 100 | # tokenizer.padding_side = "left" 101 | 102 | # model = AutoGPTQForCausalLM.from_quantized( 103 | # gptq, 104 | # model_basename=gptq_base, 105 | # use_safetensors=True, 106 | # trust_remote_code=False, 107 | # device="mps", 108 | # quantize_config=None, 109 | # local_files_only=local_files_only 110 | # ) 111 | 112 | # elif mode_cpu_gptq: 113 | # print("cpu(gptq) mode") 114 | # tokenizer = LlamaTokenizer.from_pretrained( 115 | # gptq, local_files_only=local_files_only 116 | # ) 117 | # tokenizer.pad_token_id = 0 118 | # tokenizer.padding_side = "left" 119 | 120 | # quantize_config = BaseQuantizeConfig(bits=4, group_size=128) 121 | 122 | # model = AutoGPTQForCausalLM.from_pretrained( 123 | # base, 124 | # quantize_config, 125 | # local_files_only=local_files_only 126 | # ) 127 | 128 | else: 129 | print("gpu mode") 130 | print(f"8bit = {mode_8bit}, 4bit = {mode_4bit}") 131 | model = 
LlamaForCausalLM.from_pretrained( 132 | base, 133 | load_in_8bit=mode_8bit, 134 | load_in_4bit=mode_4bit, 135 | torch_dtype=torch.float16, 136 | device_map="auto", 137 | use_safetensors=False, 138 | local_files_only=local_files_only 139 | ) 140 | 141 | if not mode_8bit and not mode_4bit: 142 | model.half() 143 | 144 | if finetuned is not None and \ 145 | finetuned != "" and \ 146 | finetuned != "N/A": 147 | 148 | model = PeftModel.from_pretrained( 149 | model, 150 | finetuned, 151 | # force_download=force_download_ckpt, 152 | ) 153 | else: 154 | model = BetterTransformer.transform(model) 155 | 156 | return model, tokenizer 157 | 158 | -------------------------------------------------------------------------------- /models/baize.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from peft import PeftModel 3 | from transformers import LlamaTokenizer, LlamaForCausalLM 4 | from optimum.bettertransformer import BetterTransformer 5 | 6 | from auto_gptq import AutoGPTQForCausalLM, BaseQuantizeConfig 7 | 8 | def load_model( 9 | base, 10 | finetuned, 11 | gptq, 12 | gptq_base, 13 | mode_cpu, 14 | mode_mps, 15 | mode_full_gpu, 16 | mode_8bit, 17 | mode_4bit, 18 | mode_gptq, 19 | mode_mps_gptq, 20 | mode_cpu_gptq, 21 | force_download_ckpt, 22 | local_files_only 23 | ): 24 | tokenizer = LlamaTokenizer.from_pretrained( 25 | base, local_files_only=local_files_only 26 | ) 27 | tokenizer.pad_token_id = 0 28 | tokenizer.padding_side = "left" 29 | 30 | if mode_cpu: 31 | print("cpu mode") 32 | model = LlamaForCausalLM.from_pretrained( 33 | base, 34 | device_map={"": "cpu"}, 35 | use_safetensors=False, 36 | local_files_only=local_files_only 37 | ) 38 | 39 | if finetuned is not None and \ 40 | finetuned != "" and \ 41 | finetuned != "N/A": 42 | 43 | model = PeftModel.from_pretrained( 44 | model, 45 | finetuned, 46 | device_map={"": "cpu"}, 47 | # force_download=force_download_ckpt, 48 | ) 49 | else: 50 | model = BetterTransformer.transform(model) 51 | 52 | elif mode_mps: 53 | print("mps mode") 54 | model = LlamaForCausalLM.from_pretrained( 55 | base, 56 | device_map={"": "mps"}, 57 | torch_dtype=torch.float16, 58 | use_safetensors=False, 59 | local_files_only=local_files_only 60 | ) 61 | 62 | if finetuned is not None and \ 63 | finetuned != "" and \ 64 | finetuned != "N/A": 65 | 66 | model = PeftModel.from_pretrained( 67 | model, 68 | finetuned, 69 | torch_dtype=torch.float16, 70 | device_map={"": "mps"} 71 | # force_download=force_download_ckpt, 72 | ) 73 | else: 74 | model = BetterTransformer.transform(model) 75 | 76 | elif mode_gptq: 77 | print("gpu(gptq) mode") 78 | tokenizer = LlamaTokenizer.from_pretrained( 79 | gptq, local_files_only=local_files_only 80 | ) 81 | tokenizer.pad_token_id = 0 82 | tokenizer.padding_side = "left" 83 | 84 | model = AutoGPTQForCausalLM.from_quantized( 85 | gptq, 86 | model_basename=gptq_base, 87 | use_safetensors=True, 88 | trust_remote_code=False, 89 | device_map="auto", 90 | quantize_config=None, 91 | local_files_only=local_files_only 92 | ) 93 | 94 | else: 95 | print("gpu mode") 96 | print(f"8bit = {mode_8bit}, 4bit = {mode_4bit}") 97 | model = LlamaForCausalLM.from_pretrained( 98 | base, 99 | load_in_8bit=mode_8bit, 100 | load_in_4bit=mode_4bit, 101 | torch_dtype=torch.float16, 102 | device_map="auto", 103 | use_safetensors=False, 104 | local_files_only=local_files_only 105 | ) 106 | 107 | if not mode_8bit and not mode_4bit: 108 | model.half() 109 | 110 | if finetuned is not None and \ 111 | finetuned != "" and \ 112 | 
finetuned != "N/A": 113 | 114 | model = PeftModel.from_pretrained( 115 | model, 116 | finetuned, 117 | # force_download=force_download_ckpt, 118 | ) 119 | else: 120 | model = BetterTransformer.transform(model) 121 | 122 | return model, tokenizer -------------------------------------------------------------------------------- /models/bloom.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from peft import PeftModel 3 | from transformers import AutoModelForCausalLM, AutoTokenizer 4 | from optimum.bettertransformer import BetterTransformer 5 | 6 | from auto_gptq import AutoGPTQForCausalLM, BaseQuantizeConfig 7 | 8 | def load_model( 9 | base, 10 | finetuned, 11 | gptq, 12 | gptq_base, 13 | mode_cpu, 14 | mode_mps, 15 | mode_full_gpu, 16 | mode_8bit, 17 | mode_4bit, 18 | mode_gptq, 19 | mode_mps_gptq, 20 | mode_cpu_gptq, 21 | force_download_ckpt, 22 | local_files_only 23 | ): 24 | tokenizer = AutoTokenizer.from_pretrained( 25 | base, local_files_only=local_files_only 26 | ) 27 | 28 | if mode_cpu: 29 | print("cpu mode") 30 | model = AutoModelForCausalLM.from_pretrained( 31 | base, 32 | device_map={"": "cpu"}, 33 | use_safetensors=False, 34 | local_files_only=local_files_only 35 | ) 36 | 37 | if finetuned is not None and \ 38 | finetuned != "" and \ 39 | finetuned != "N/A": 40 | 41 | model = PeftModel.from_pretrained( 42 | model, 43 | finetuned, 44 | device_map={"": "cpu"}, 45 | # force_download=force_download_ckpt, 46 | ) 47 | else: 48 | model = BetterTransformer.transform(model) 49 | 50 | elif mode_mps: 51 | print("mps mode") 52 | model = AutoModelForCausalLM.from_pretrained( 53 | base, 54 | device_map={"": "mps"}, 55 | torch_dtype=torch.float16, 56 | use_safetensors=False, 57 | local_files_only=local_files_only 58 | ) 59 | 60 | if finetuned is not None and \ 61 | finetuned != "" and \ 62 | finetuned != "N/A": 63 | 64 | model = PeftModel.from_pretrained( 65 | model, 66 | finetuned, 67 | torch_dtype=torch.float16, 68 | device_map={"": "mps"}, 69 | # force_download=force_download_ckpt, 70 | ) 71 | else: 72 | model = BetterTransformer.transform(model) 73 | 74 | elif mode_gptq: 75 | print("gpu(gptq) mode") 76 | tokenizer = AutoTokenizer.from_pretrained( 77 | gptq, local_files_only=local_files_only 78 | ) 79 | tokenizer.pad_token_id = 0 80 | tokenizer.padding_side = "left" 81 | 82 | model = AutoGPTQForCausalLM.from_quantized( 83 | gptq, 84 | model_basename=gptq_base, 85 | use_safetensors=True, 86 | trust_remote_code=False, 87 | device_map="auto", 88 | quantize_config=None, 89 | local_files_only=local_files_only 90 | ) 91 | 92 | else: 93 | print("gpu mode") 94 | print(f"8bit = {mode_8bit}, 4bit = {mode_4bit}") 95 | model = AutoModelForCausalLM.from_pretrained( 96 | base, 97 | load_in_8bit=mode_8bit, 98 | load_in_4bit=mode_4bit, 99 | device_map="auto", 100 | use_safetensors=False, 101 | local_files_only=local_files_only 102 | ) 103 | 104 | if not mode_8bit and not mode_4bit: 105 | model.half() 106 | 107 | if finetuned is not None and \ 108 | finetuned != "" and \ 109 | finetuned != "N/A": 110 | 111 | model = PeftModel.from_pretrained( 112 | model, 113 | finetuned, 114 | # force_download=force_download_ckpt, 115 | ) 116 | else: 117 | model = BetterTransformer.transform(model) 118 | 119 | return model, tokenizer -------------------------------------------------------------------------------- /models/byom.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import transformers 3 | from peft import 
PeftModel 4 | from transformers import AutoModelForCausalLM, AutoTokenizer 5 | 6 | def load_model( 7 | base, 8 | finetuned, 9 | mode_cpu, 10 | mode_mps, 11 | mode_full_gpu, 12 | mode_8bit, 13 | mode_4bit, 14 | # force_download_ckpt, 15 | model_cls, 16 | tokenizer_cls 17 | ): 18 | if tokenizer_cls is None: 19 | tokenizer_cls = AutoTokenizer 20 | else: 21 | tokenizer_cls = eval(tokenizer_cls) 22 | 23 | if model_cls is None: 24 | model_cls = AutoModelForCausalLM 25 | else: 26 | model_cls = eval(model_cls) 27 | 28 | print(f"tokenizer_cls: {tokenizer_cls}") 29 | print(f"model_cls: {model_cls}") 30 | 31 | tokenizer = tokenizer_cls.from_pretrained(base) 32 | tokenizer.padding_side = "left" 33 | 34 | if mode_cpu: 35 | print("cpu mode") 36 | model = model_cls.from_pretrained( 37 | base, 38 | device_map={"": "cpu"}, 39 | use_safetensors=False 40 | # low_cpu_mem_usage=True 41 | ) 42 | 43 | if finetuned is not None and \ 44 | finetuned != "" and \ 45 | finetuned != "N/A": 46 | model = PeftModel.from_pretrained( 47 | model, 48 | finetuned, 49 | device_map={"": "cpu"} 50 | # force_download=force_download_ckpt, 51 | ) 52 | elif mode_mps: 53 | print("mps mode") 54 | model = model_cls.from_pretrained( 55 | base, 56 | device_map={"": "mps"}, 57 | torch_dtype=torch.float16, 58 | use_safetensors=False 59 | ) 60 | 61 | if finetuned is not None and \ 62 | finetuned != "" and \ 63 | finetuned != "N/A": 64 | model = PeftModel.from_pretrained( 65 | model, 66 | finetuned, 67 | torch_dtype=torch.float16, 68 | device_map={"": "mps"} 69 | # force_download=force_download_ckpt, 70 | ) 71 | else: 72 | print("gpu mode") 73 | print(f"8bit = {mode_8bit}, 4bit = {mode_4bit}") 74 | 75 | model = model_cls.from_pretrained( 76 | base, 77 | load_in_8bit=mode_8bit, 78 | load_in_4bit=mode_4bit, 79 | torch_dtype=torch.float16, 80 | device_map="auto", 81 | ) 82 | 83 | if finetuned is not None and \ 84 | finetuned != "" and \ 85 | finetuned != "N/A": 86 | model = PeftModel.from_pretrained( 87 | model, 88 | finetuned, 89 | # force_download=force_download_ckpt, 90 | ) 91 | 92 | return model, tokenizer -------------------------------------------------------------------------------- /models/camel.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | from transformers import AutoModelForCausalLM, AutoTokenizer 4 | from optimum.bettertransformer import BetterTransformer 5 | 6 | from auto_gptq import AutoGPTQForCausalLM, BaseQuantizeConfig 7 | 8 | def load_model( 9 | base, 10 | finetuned, 11 | gptq, 12 | gptq_base, 13 | mode_cpu, 14 | mode_mps, 15 | mode_full_gpu, 16 | mode_8bit, 17 | mode_4bit, 18 | mode_gptq, 19 | mode_mps_gptq, 20 | mode_cpu_gptq, 21 | force_download_ckpt, 22 | local_files_only 23 | ): 24 | tokenizer = AutoTokenizer.from_pretrained( 25 | base, local_files_only=local_files_only 26 | ) 27 | tokenizer.padding_side = "left" 28 | 29 | if mode_cpu: 30 | print("cpu mode") 31 | model = AutoModelForCausalLM.from_pretrained( 32 | base, 33 | device_map={"": "cpu"}, 34 | use_safetensors=False, 35 | local_files_only=local_files_only 36 | ) 37 | 38 | elif mode_mps: 39 | print("mps mode") 40 | model = AutoModelForCausalLM.from_pretrained( 41 | base, 42 | device_map={"": "mps"}, 43 | torch_dtype=torch.float16, 44 | use_safetensors=False, 45 | local_files_only=local_files_only 46 | ) 47 | 48 | elif mode_gptq: 49 | print("gpu(gptq) mode") 50 | tokenizer = AutoTokenizer.from_pretrained( 51 | gptq, local_files_only=local_files_only 52 | ) 53 | tokenizer.pad_token_id = 0 54 | 
tokenizer.padding_side = "left" 55 | 56 | model = AutoGPTQForCausalLM.from_quantized( 57 | gptq, 58 | model_basename=gptq_base, 59 | use_safetensors=True, 60 | trust_remote_code=False, 61 | device_map="auto", 62 | quantize_config=None, 63 | local_files_only=local_files_only 64 | ) 65 | 66 | else: 67 | print("gpu mode") 68 | print(f"8bit = {mode_8bit}, 4bit = {mode_4bit}") 69 | model = AutoModelForCausalLM.from_pretrained( 70 | base, 71 | load_in_8bit=mode_8bit, 72 | load_in_4bit=mode_4bit, 73 | device_map="auto", 74 | torch_dtype=torch.float16, 75 | use_safetensors=False, 76 | local_files_only=local_files_only 77 | ) 78 | 79 | if not mode_8bit and not mode_4bit: 80 | model.half() 81 | 82 | model = BetterTransformer.transform(model) 83 | return model, tokenizer -------------------------------------------------------------------------------- /models/falcon.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from peft import PeftModel 3 | from transformers import AutoTokenizer, AutoModelForCausalLM 4 | from optimum.bettertransformer import BetterTransformer 5 | 6 | from auto_gptq import AutoGPTQForCausalLM, BaseQuantizeConfig 7 | 8 | def load_model( 9 | base, 10 | finetuned, 11 | gptq, 12 | gptq_base, 13 | mode_cpu, 14 | mode_mps, 15 | mode_full_gpu, 16 | mode_8bit, 17 | mode_4bit, 18 | mode_gptq, 19 | mode_mps_gptq, 20 | mode_cpu_gptq, 21 | force_download_ckpt, 22 | local_files_only 23 | ): 24 | tokenizer = AutoTokenizer.from_pretrained( 25 | base, local_files_only=local_files_only 26 | ) 27 | tokenizer.padding_side = "left" 28 | 29 | if mode_cpu: 30 | print("cpu mode") 31 | model = AutoModelForCausalLM.from_pretrained( 32 | base, 33 | device_map={"": "cpu"}, 34 | torch_dtype=torch.bfloat16, 35 | use_safetensors=False, 36 | trust_remote_code=True, 37 | local_files_only=local_files_only 38 | ) 39 | 40 | elif mode_mps: 41 | print("mps mode") 42 | model = AutoModelForCausalLM.from_pretrained( 43 | base, 44 | device_map={"": "mps"}, 45 | torch_dtype=torch.bfloat16, 46 | use_safetensors=False, 47 | trust_remote_code=True, 48 | local_files_only=local_files_only 49 | ) 50 | 51 | elif mode_gptq: 52 | print("gpu(gptq) mode") 53 | tokenizer = AutoTokenizer.from_pretrained( 54 | gptq, local_files_only=local_files_only, trust_remote_code=True 55 | ) 56 | tokenizer.pad_token_id = 0 57 | tokenizer.padding_side = "left" 58 | 59 | model = AutoGPTQForCausalLM.from_quantized( 60 | gptq, 61 | model_basename=gptq_base, 62 | use_safetensors=True, 63 | trust_remote_code=True, 64 | device_map="auto", 65 | quantize_config=None, 66 | local_files_only=local_files_only, 67 | ) 68 | 69 | else: 70 | print("gpu mode") 71 | print(f"8bit = {mode_8bit}, 4bit = {mode_4bit}") 72 | model = AutoModelForCausalLM.from_pretrained( 73 | base, 74 | load_in_8bit=mode_8bit, 75 | load_in_4bit=mode_4bit, 76 | torch_dtype=torch.bfloat16, 77 | device_map="auto", 78 | trust_remote_code=True, 79 | use_safetensors=False, 80 | local_files_only=local_files_only 81 | ) 82 | 83 | # if not mode_8bit and not mode_4bit: 84 | # model.half() 85 | 86 | # model = BetterTransformer.transform(model) 87 | return model, tokenizer -------------------------------------------------------------------------------- /models/flan_alpaca.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from transformers import AutoTokenizer, AutoModelForSeq2SeqLM 3 | from optimum.bettertransformer import BetterTransformer 4 | 5 | from auto_gptq import AutoGPTQForCausalLM, 
BaseQuantizeConfig 6 | 7 | def load_model( 8 | base, 9 | finetuned, 10 | gptq, 11 | gptq_base, 12 | mode_cpu, 13 | mode_mps, 14 | mode_full_gpu, 15 | mode_8bit, 16 | mode_4bit, 17 | mode_gptq, 18 | mode_mps_gptq, 19 | mode_cpu_gptq, 20 | force_download_ckpt, 21 | local_files_only 22 | ): 23 | tokenizer = AutoTokenizer.from_pretrained( 24 | base, local_files_only=local_files_only 25 | ) 26 | tokenizer.pad_token_id = 0 27 | tokenizer.padding_side = "left" 28 | 29 | if mode_cpu: 30 | print("cpu mode") 31 | model = AutoModelForSeq2SeqLM.from_pretrained( 32 | base, 33 | device_map={"": "cpu"}, 34 | low_cpu_mem_usage=True, 35 | local_files_only=local_files_only 36 | ) 37 | 38 | elif mode_mps: 39 | print("mps mode") 40 | model = AutoModelForSeq2SeqLM.from_pretrained( 41 | base, 42 | device_map={"": "mps"}, 43 | torch_dtype=torch.float16, 44 | local_files_only=local_files_only 45 | ) 46 | 47 | elif mode_gptq: 48 | print("gpu(gptq) mode") 49 | tokenizer = AutoTokenizer.from_pretrained( 50 | gptq, local_files_only=local_files_only 51 | ) 52 | tokenizer.pad_token_id = 0 53 | tokenizer.padding_side = "left" 54 | 55 | model = AutoGPTQForCausalLM.from_quantized( 56 | gptq, 57 | model_basename=gptq_base, 58 | use_safetensors=True, 59 | trust_remote_code=False, 60 | device_map="auto", 61 | quantize_config=None, 62 | local_files_only=local_files_only 63 | ) 64 | 65 | else: 66 | print("gpu mode") 67 | print(f"8bit = {mode_8bit}, 4bit = {mode_4bit}") 68 | model = AutoModelForSeq2SeqLM.from_pretrained( 69 | base, 70 | load_in_8bit=mode_8bit, 71 | load_in_4bit=mode_4bit, 72 | device_map="auto", 73 | local_files_only=local_files_only 74 | ) 75 | 76 | if not mode_8bit and not mode_4bit: 77 | model.half() 78 | 79 | model = BetterTransformer.transform(model) 80 | return model, tokenizer 81 | 82 | -------------------------------------------------------------------------------- /models/freewilly.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from peft import PeftModel 3 | from transformers import AutoModelForCausalLM, AutoTokenizer 4 | from optimum.bettertransformer import BetterTransformer 5 | 6 | from auto_gptq import AutoGPTQForCausalLM, BaseQuantizeConfig 7 | 8 | def load_model( 9 | base, 10 | finetuned, 11 | gptq, 12 | gptq_base, 13 | mode_cpu, 14 | mode_mps, 15 | mode_full_gpu, 16 | mode_8bit, 17 | mode_4bit, 18 | mode_gptq, 19 | mode_mps_gptq, 20 | mode_cpu_gptq, 21 | force_download_ckpt, 22 | local_files_only 23 | ): 24 | tokenizer = AutoTokenizer.from_pretrained( 25 | base, local_files_only=local_files_only, use_fast=False 26 | ) 27 | tokenizer.pad_token_id = 0 28 | tokenizer.padding_side = "left" 29 | 30 | if mode_cpu: 31 | print("cpu mode") 32 | model = AutoModelForCausalLM.from_pretrained( 33 | base, 34 | device_map={"": "cpu"}, 35 | use_safetensors=False, 36 | local_files_only=local_files_only 37 | ) 38 | 39 | if finetuned is not None and \ 40 | finetuned != "" and \ 41 | finetuned != "N/A": 42 | 43 | model = PeftModel.from_pretrained( 44 | model, 45 | finetuned, 46 | device_map={"": "cpu"}, 47 | # force_download=force_download_ckpt, 48 | ) 49 | else: 50 | model = BetterTransformer.transform(model) 51 | 52 | elif mode_mps: 53 | print("mps mode") 54 | model = AutoModelForCausalLM.from_pretrained( 55 | base, 56 | device_map={"": "mps"}, 57 | torch_dtype=torch.float16, 58 | use_safetensors=False, 59 | local_files_only=local_files_only 60 | ) 61 | 62 | if finetuned is not None and \ 63 | finetuned != "" and \ 64 | finetuned != "N/A": 65 | 66 | model 
= PeftModel.from_pretrained( 67 | model, 68 | finetuned, 69 | torch_dtype=torch.float16, 70 | device_map={"": "mps"} 71 | # force_download=force_download_ckpt, 72 | ) 73 | else: 74 | model = BetterTransformer.transform(model) 75 | 76 | elif mode_gptq: 77 | print("gpu(gptq) mode") 78 | tokenizer = AutoTokenizer.from_pretrained( 79 | gptq, local_files_only=local_files_only 80 | ) 81 | tokenizer.pad_token_id = 0 82 | tokenizer.padding_side = "left" 83 | 84 | model = AutoGPTQForCausalLM.from_quantized( 85 | gptq, 86 | model_basename=gptq_base, 87 | inject_fused_attention=False, 88 | use_safetensors=True, 89 | trust_remote_code=False, 90 | device_map="auto", 91 | quantize_config=None, 92 | local_files_only=local_files_only 93 | ) 94 | 95 | else: 96 | print("gpu mode") 97 | print(f"8bit = {mode_8bit}, 4bit = {mode_4bit}") 98 | model = AutoModelForCausalLM.from_pretrained( 99 | base, 100 | load_in_8bit=mode_8bit, 101 | load_in_4bit=mode_4bit, 102 | torch_dtype=torch.float16, 103 | device_map="auto", 104 | use_safetensors=False, 105 | local_files_only=local_files_only, 106 | ) 107 | 108 | if not mode_8bit and not mode_4bit: 109 | model.half() 110 | 111 | if finetuned is not None and \ 112 | finetuned != "" and \ 113 | finetuned != "N/A": 114 | 115 | model = PeftModel.from_pretrained( 116 | model, 117 | finetuned, 118 | # force_download=force_download_ckpt, 119 | ) 120 | # else: 121 | # model = BetterTransformer.transform(model) 122 | 123 | return model, tokenizer 124 | 125 | -------------------------------------------------------------------------------- /models/guanaco.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from peft import PeftModel 3 | from transformers import AutoModelForCausalLM, AutoTokenizer, LlamaTokenizer 4 | from optimum.bettertransformer import BetterTransformer 5 | 6 | from auto_gptq import AutoGPTQForCausalLM, BaseQuantizeConfig 7 | 8 | def load_model( 9 | base, 10 | finetuned, 11 | gptq, 12 | gptq_base, 13 | mode_cpu, 14 | mode_mps, 15 | mode_full_gpu, 16 | mode_8bit, 17 | mode_4bit, 18 | mode_gptq, 19 | mode_mps_gptq, 20 | mode_cpu_gptq, 21 | force_download_ckpt, 22 | local_files_only 23 | ): 24 | tokenizer = LlamaTokenizer.from_pretrained( 25 | base, local_files_only=local_files_only 26 | ) 27 | tokenizer.bos_token_id = 1 28 | tokenizer.padding_side = "left" 29 | 30 | if mode_cpu: 31 | print("cpu mode") 32 | model = AutoModelForCausalLM.from_pretrained( 33 | base, 34 | device_map={"": "cpu"}, 35 | use_safetensors=False, 36 | local_files_only=local_files_only 37 | ) 38 | 39 | if finetuned is not None and \ 40 | finetuned != "" and \ 41 | finetuned != "N/A": 42 | 43 | model = PeftModel.from_pretrained( 44 | model, 45 | finetuned, 46 | device_map={"": "cpu"}, 47 | # force_download=force_download_ckpt, 48 | ) 49 | else: 50 | model = BetterTransformer.transform(model) 51 | 52 | elif mode_mps: 53 | print("mps mode") 54 | model = AutoModelForCausalLM.from_pretrained( 55 | base, 56 | device_map={"": "mps"}, 57 | torch_dtype=torch.float16, 58 | use_safetensors=False, 59 | local_files_only=local_files_only 60 | ) 61 | 62 | if finetuned is not None and \ 63 | finetuned != "" and \ 64 | finetuned != "N/A": 65 | 66 | model = PeftModel.from_pretrained( 67 | model, 68 | finetuned, 69 | torch_dtype=torch.float16, 70 | device_map={"": "mps"}, 71 | # force_download=force_download_ckpt, 72 | ) 73 | else: 74 | model = BetterTransformer.transform(model) 75 | 76 | elif mode_gptq: 77 | print("gpu(gptq) mode") 78 | tokenizer = 
AutoTokenizer.from_pretrained( 79 | gptq, local_files_only=local_files_only 80 | ) 81 | tokenizer.pad_token_id = 0 82 | tokenizer.padding_side = "left" 83 | 84 | model = AutoGPTQForCausalLM.from_quantized( 85 | gptq, 86 | model_basename=gptq_base, 87 | use_safetensors=True, 88 | trust_remote_code=False, 89 | device_map="auto", 90 | quantize_config=None, 91 | local_files_only=local_files_only 92 | ) 93 | 94 | else: 95 | print("gpu mode") 96 | print(f"8bit = {mode_8bit}, 4bit = {mode_4bit}") 97 | model = AutoModelForCausalLM.from_pretrained( 98 | base, 99 | load_in_8bit=mode_8bit, 100 | load_in_4bit=mode_4bit, 101 | torch_dtype=torch.bfloat16, 102 | device_map="auto", 103 | use_safetensors=False, 104 | local_files_only=local_files_only 105 | ) 106 | 107 | if not mode_8bit and not mode_4bit: 108 | model.half() 109 | 110 | if finetuned is not None and \ 111 | finetuned != "" and \ 112 | finetuned != "N/A": 113 | 114 | model = PeftModel.from_pretrained( 115 | model, 116 | finetuned, 117 | # force_download=force_download_ckpt, 118 | ) 119 | else: 120 | model = BetterTransformer.transform(model) 121 | 122 | return model, tokenizer -------------------------------------------------------------------------------- /models/koalpaca.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from transformers import AutoModelForCausalLM, AutoTokenizer 3 | from optimum.bettertransformer import BetterTransformer 4 | 5 | from auto_gptq import AutoGPTQForCausalLM, BaseQuantizeConfig 6 | 7 | def load_model( 8 | base, 9 | finetuned, 10 | gptq, 11 | gptq_base, 12 | mode_cpu, 13 | mode_mps, 14 | mode_full_gpu, 15 | mode_8bit, 16 | mode_4bit, 17 | mode_gptq, 18 | mode_mps_gptq, 19 | mode_cpu_gptq, 20 | force_download_ckpt, 21 | local_files_only 22 | ): 23 | tokenizer = AutoTokenizer.from_pretrained( 24 | base, local_files_only=local_files_only 25 | ) 26 | 27 | if mode_cpu: 28 | print("cpu mode") 29 | model = AutoModelForCausalLM.from_pretrained( 30 | base, 31 | device_map={"": "cpu"}, 32 | use_safetensors=False, 33 | local_files_only=local_files_only 34 | ) 35 | 36 | elif mode_mps: 37 | print("mps mode") 38 | model = AutoModelForCausalLM.from_pretrained( 39 | base, 40 | device_map={"": "mps"}, 41 | torch_dtype=torch.float16, 42 | use_safetensors=False, 43 | local_files_only=local_files_only 44 | ) 45 | 46 | elif mode_gptq: 47 | print("gpu(gptq) mode") 48 | tokenizer = AutoTokenizer.from_pretrained( 49 | gptq, local_files_only=local_files_only 50 | ) 51 | tokenizer.pad_token_id = 0 52 | tokenizer.padding_side = "left" 53 | 54 | model = AutoGPTQForCausalLM.from_quantized( 55 | gptq, 56 | model_basename=gptq_base, 57 | use_safetensors=False, 58 | trust_remote_code=False, 59 | device_map="auto", 60 | quantize_config=None, 61 | local_files_only=local_files_only 62 | ) 63 | 64 | else: 65 | print("gpu mode") 66 | print(f"8bit = {mode_8bit}, 4bit = {mode_4bit}") 67 | model = AutoModelForCausalLM.from_pretrained( 68 | base, 69 | load_in_8bit=mode_8bit, 70 | load_in_4bit=mode_4bit, 71 | torch_dtype=torch.float16, 72 | device_map="auto", 73 | use_safetensors=False, 74 | local_files_only=local_files_only 75 | ) 76 | 77 | if not mode_8bit and not mode_4bit: 78 | model.half() 79 | 80 | # model = BetterTransformer.transform(model) 81 | return model, tokenizer -------------------------------------------------------------------------------- /models/kullm.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | from transformers 
import AutoModelForCausalLM, AutoTokenizer 4 | from optimum.bettertransformer import BetterTransformer 5 | 6 | from auto_gptq import AutoGPTQForCausalLM, BaseQuantizeConfig 7 | 8 | def load_model( 9 | base, 10 | finetuned, 11 | gptq, 12 | gptq_base, 13 | mode_cpu, 14 | mode_mps, 15 | mode_full_gpu, 16 | mode_8bit, 17 | mode_4bit, 18 | mode_gptq, 19 | mode_mps_gptq, 20 | mode_cpu_gptq, 21 | force_download_ckpt, 22 | local_files_only 23 | ): 24 | tokenizer = AutoTokenizer.from_pretrained(base) 25 | 26 | if mode_cpu: 27 | print("cpu mode") 28 | model = AutoModelForCausalLM.from_pretrained( 29 | base, 30 | device_map={"": "cpu"}, 31 | use_safetensors=False, 32 | local_files_only=local_files_only 33 | ) 34 | 35 | elif mode_mps: 36 | print("mps mode") 37 | model = AutoModelForCausalLM.from_pretrained( 38 | base, 39 | device_map={"": "mps"}, 40 | torch_dtype=torch.float16, 41 | use_safetensors=False, 42 | local_files_only=local_files_only 43 | ) 44 | 45 | elif mode_gptq: 46 | print("gpu(gptq) mode") 47 | tokenizer = AutoTokenizer.from_pretrained( 48 | gptq, local_files_only=local_files_only 49 | ) 50 | tokenizer.pad_token_id = 0 51 | tokenizer.padding_side = "left" 52 | 53 | model = AutoGPTQForCausalLM.from_quantized( 54 | gptq, 55 | model_basename=gptq_base, 56 | use_safetensors=True, 57 | trust_remote_code=False, 58 | device_map="auto", 59 | quantize_config=None, 60 | local_files_only=local_files_only 61 | ) 62 | 63 | else: 64 | print("gpu mode") 65 | print(f"8bit = {mode_8bit}, 4bit = {mode_4bit}") 66 | model = AutoModelForCausalLM.from_pretrained( 67 | base, 68 | load_in_8bit=mode_8bit, 69 | load_in_4bit=mode_4bit, 70 | torch_dtype=torch.float16, 71 | device_map="auto", 72 | use_safetensors=False, 73 | local_files_only=local_files_only 74 | ) 75 | 76 | if not mode_8bit and not mode_4bit: 77 | model.half() 78 | 79 | # model = BetterTransformer.transform(model) 80 | return model, tokenizer -------------------------------------------------------------------------------- /models/llama_rlhf.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from peft import PeftModel 3 | from transformers import LlamaTokenizer, LlamaForCausalLM 4 | 5 | from auto_gptq import AutoGPTQForCausalLM, BaseQuantizeConfig 6 | 7 | def load_model( 8 | base, 9 | finetuned, 10 | gptq, 11 | gptq_base, 12 | mode_cpu, 13 | mode_mps, 14 | mode_full_gpu, 15 | mode_8bit, 16 | mode_4bit, 17 | mode_gptq, 18 | mode_mps_gptq, 19 | mode_cpu_gptq, 20 | force_download_ckpt, 21 | local_files_only 22 | ): 23 | tokenizer = LlamaTokenizer.from_pretrained( 24 | base, local_files_only=local_files_only 25 | ) 26 | tokenizer.pad_token_id = 0 27 | tokenizer.padding_side = "left" 28 | 29 | if torch.cuda.device_count() <= 1:  # single-GPU path 30 | model = LlamaForCausalLM.from_pretrained( 31 | base, 32 | load_in_8bit=mode_8bit, 33 | load_in_4bit=mode_4bit, 34 | device_map="auto", 35 | local_files_only=local_files_only 36 | ) 37 | 38 | model = PeftModel.from_pretrained( 39 | model, 40 | finetuned, 41 | # force_download=force_download_ckpt, 42 | device_map={'': 0} 43 | ) 44 | return model, tokenizer 45 | else: 46 | model = LlamaForCausalLM.from_pretrained( 47 | base, 48 | load_in_8bit=mode_8bit, 49 | load_in_4bit=mode_4bit, 50 | torch_dtype=torch.float16, 51 | device_map="auto", 52 | local_files_only=local_files_only 53 | ) 54 | 55 | model = PeftModel.from_pretrained( 56 | model, 57 | finetuned, 58 | # force_download=force_download_ckpt, 59 | torch_dtype=torch.float16 60 | ) 61 | model.half() 62 | return model, tokenizer 63 | 64 | 
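Most loaders in models/ repeat the same recipe: load a base checkpoint, optionally attach a LoRA adapter via peft's PeftModel, and call .half() when neither 8-bit nor 4-bit loading is requested. The condensed sketch below illustrates that shared pattern; the checkpoint names are placeholders, not files or ids from this repository.

import torch
from peft import PeftModel
from transformers import AutoModelForCausalLM, AutoTokenizer

base = "your-org/llama-base"              # placeholder base checkpoint
adapter = "your-org/alpaca-lora-adapter"  # placeholder LoRA adapter; may be None/""/"N/A"

tokenizer = AutoTokenizer.from_pretrained(base)
tokenizer.padding_side = "left"  # left padding so batched generation aligns on the right

model = AutoModelForCausalLM.from_pretrained(
    base,
    torch_dtype=torch.float16,
    device_map="auto",
)

# Attach the adapter only when one is configured, mirroring the
# `finetuned not in (None, "", "N/A")` checks used throughout models/.
if adapter not in (None, "", "N/A"):
    model = PeftModel.from_pretrained(model, adapter, torch_dtype=torch.float16)

model.half()  # the loaders above do this when 8-bit/4-bit quantization is off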
-------------------------------------------------------------------------------- /models/mistral.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from peft import PeftModel 3 | from transformers import AutoModelForCausalLM, AutoTokenizer 4 | from optimum.bettertransformer import BetterTransformer 5 | 6 | from auto_gptq import AutoGPTQForCausalLM, BaseQuantizeConfig 7 | 8 | def load_model( 9 | base, 10 | finetuned, 11 | gptq, 12 | gptq_base, 13 | mode_cpu, 14 | mode_mps, 15 | mode_full_gpu, 16 | mode_8bit, 17 | mode_4bit, 18 | mode_gptq, 19 | mode_mps_gptq, 20 | mode_cpu_gptq, 21 | force_download_ckpt, 22 | local_files_only 23 | ): 24 | tokenizer = AutoTokenizer.from_pretrained( 25 | base, local_files_only=local_files_only, use_fast=False 26 | ) 27 | tokenizer.pad_token_id = 0 28 | tokenizer.padding_side = "left" 29 | 30 | if mode_cpu: 31 | print("cpu mode") 32 | model = AutoModelForCausalLM.from_pretrained( 33 | base, 34 | device_map={"": "cpu"}, 35 | use_safetensors=False, 36 | local_files_only=local_files_only 37 | ) 38 | 39 | if finetuned is not None and \ 40 | finetuned != "" and \ 41 | finetuned != "N/A": 42 | 43 | model = PeftModel.from_pretrained( 44 | model, 45 | finetuned, 46 | device_map={"": "cpu"}, 47 | # force_download=force_download_ckpt, 48 | ) 49 | else: 50 | model = BetterTransformer.transform(model) 51 | 52 | elif mode_mps: 53 | print("mps mode") 54 | model = AutoModelForCausalLM.from_pretrained( 55 | base, 56 | device_map={"": "mps"}, 57 | torch_dtype=torch.float16, 58 | use_safetensors=False, 59 | local_files_only=local_files_only 60 | ) 61 | 62 | if finetuned is not None and \ 63 | finetuned != "" and \ 64 | finetuned != "N/A": 65 | 66 | model = PeftModel.from_pretrained( 67 | model, 68 | finetuned, 69 | torch_dtype=torch.float16, 70 | device_map={"": "mps"} 71 | # force_download=force_download_ckpt, 72 | ) 73 | else: 74 | model = BetterTransformer.transform(model) 75 | 76 | elif mode_gptq: 77 | print("gpu(gptq) mode") 78 | tokenizer = AutoTokenizer.from_pretrained( 79 | gptq, local_files_only=local_files_only 80 | ) 81 | tokenizer.pad_token_id = 0 82 | tokenizer.padding_side = "left" 83 | 84 | model = AutoGPTQForCausalLM.from_quantized( 85 | gptq, 86 | model_basename=gptq_base, 87 | inject_fused_attention=False, 88 | use_safetensors=True, 89 | trust_remote_code=False, 90 | device_map="auto", 91 | quantize_config=None, 92 | local_files_only=local_files_only 93 | ) 94 | 95 | else: 96 | print("gpu mode") 97 | print(f"8bit = {mode_8bit}, 4bit = {mode_4bit}") 98 | model = AutoModelForCausalLM.from_pretrained( 99 | base, 100 | load_in_8bit=mode_8bit, 101 | load_in_4bit=mode_4bit, 102 | torch_dtype=torch.float16, 103 | device_map="auto", 104 | use_safetensors=False, 105 | local_files_only=local_files_only, 106 | ) 107 | 108 | if not mode_8bit and not mode_4bit: 109 | model.half() 110 | 111 | if finetuned is not None and \ 112 | finetuned != "" and \ 113 | finetuned != "N/A": 114 | 115 | model = PeftModel.from_pretrained( 116 | model, 117 | finetuned, 118 | # force_download=force_download_ckpt, 119 | ) 120 | # else: 121 | # model = BetterTransformer.transform(model) 122 | 123 | return model, tokenizer 124 | 125 | -------------------------------------------------------------------------------- /models/mpt.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import global_vars 3 | from transformers import AutoModelForCausalLM, AutoTokenizer, AutoConfig 4 | from 
optimum.bettertransformer import BetterTransformer 5 | 6 | from auto_gptq import AutoGPTQForCausalLM, BaseQuantizeConfig 7 | 8 | def load_model( 9 | base, 10 | finetuned, 11 | gptq, 12 | gptq_base, 13 | mode_cpu, 14 | mode_mps, 15 | mode_full_gpu, 16 | mode_8bit, 17 | mode_4bit, 18 | mode_gptq, 19 | mode_mps_gptq, 20 | mode_cpu_gptq, 21 | force_download_ckpt, 22 | local_files_only 23 | ): 24 | tokenizer = AutoTokenizer.from_pretrained( 25 | base, trust_remote_code=True, local_files_only=local_files_only 26 | ) 27 | tokenizer.padding_side = "left" 28 | 29 | if mode_cpu: 30 | print("cpu mode") 31 | model = AutoModelForCausalLM.from_pretrained( 32 | base, 33 | device_map={"": "cpu"}, 34 | use_safetensors=False, 35 | trust_remote_code=True, 36 | local_files_only=local_files_only 37 | ) 38 | 39 | elif mode_mps: 40 | print("mps mode") 41 | model = AutoModelForCausalLM.from_pretrained( 42 | base, 43 | device_map={"": "mps"}, 44 | torch_dtype=torch.float16, 45 | use_safetensors=False, 46 | trust_remote_code=True, 47 | local_files_only=local_files_only 48 | ) 49 | 50 | elif mode_gptq: 51 | print("gpu(gptq) mode") 52 | tokenizer = AutoTokenizer.from_pretrained( 53 | gptq, local_files_only=local_files_only 54 | ) 55 | tokenizer.pad_token_id = 0 56 | tokenizer.padding_side = "left" 57 | 58 | model = AutoGPTQForCausalLM.from_quantized( 59 | gptq, 60 | model_basename=gptq_base, 61 | use_safetensors=True, 62 | trust_remote_code=True, 63 | device_map="auto", 64 | quantize_config=None, 65 | local_files_only=local_files_only 66 | ) 67 | 68 | else: 69 | print("gpu mode") 70 | print(f"8bit = {mode_8bit}, 4bit = {mode_4bit}") 71 | model = AutoModelForCausalLM.from_pretrained( 72 | base, 73 | load_in_8bit=mode_8bit, 74 | load_in_4bit=mode_4bit, 75 | device_map="auto", 76 | trust_remote_code=True, 77 | torch_dtype=torch.float16, 78 | use_safetensors=False, 79 | local_files_only=local_files_only, 80 | )#.to(global_vars.device) 81 | 82 | if not mode_8bit and not mode_4bit: 83 | model.half() 84 | 85 | # model = BetterTransformer.transform(model) 86 | return model, tokenizer -------------------------------------------------------------------------------- /models/redpajama.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | from transformers import AutoModelForCausalLM, AutoTokenizer 4 | from optimum.bettertransformer import BetterTransformer 5 | 6 | from auto_gptq import AutoGPTQForCausalLM, BaseQuantizeConfig 7 | 8 | def load_model( 9 | base, 10 | finetuned, 11 | gptq, 12 | gptq_base, 13 | mode_cpu, 14 | mode_mps, 15 | mode_full_gpu, 16 | mode_8bit, 17 | mode_4bit, 18 | mode_gptq, 19 | mode_mps_gptq, 20 | mode_cpu_gptq, 21 | force_download_ckpt, 22 | local_files_only 23 | ): 24 | tokenizer = AutoTokenizer.from_pretrained( 25 | base, trust_remote_code=True, local_files_only=local_files_only 26 | ) 27 | tokenizer.padding_side = "left" 28 | 29 | if mode_cpu: 30 | print("cpu mode") 31 | model = AutoModelForCausalLM.from_pretrained( 32 | base, 33 | device_map={"": "cpu"}, 34 | use_safetensors=False, 35 | trust_remote_code=True, 36 | local_files_only=local_files_only 37 | ) 38 | 39 | elif mode_mps: 40 | print("mps mode") 41 | model = AutoModelForCausalLM.from_pretrained( 42 | base, 43 | device_map={"": "mps"}, 44 | torch_dtype=torch.float16, 45 | use_safetensors=False, 46 | trust_remote_code=True, 47 | local_files_only=local_files_only 48 | ) 49 | 50 | elif mode_gptq: 51 | print("gpu(gptq) mode") 52 | tokenizer = AutoTokenizer.from_pretrained( 53 | gptq, 
local_files_only=local_files_only 54 | ) 55 | tokenizer.pad_token_id = 0 56 | tokenizer.padding_side = "left" 57 | 58 | model = AutoGPTQForCausalLM.from_quantized( 59 | gptq, 60 | model_basename=gptq_base, 61 | use_safetensors=True, 62 | trust_remote_code=False, 63 | device_map="auto", 64 | quantize_config=None, 65 | local_files_only=local_files_only 66 | ) 67 | 68 | else: 69 | print("gpu mode") 70 | print(f"8bit = {mode_8bit}, 4bit = {mode_4bit}") 71 | model = AutoModelForCausalLM.from_pretrained( 72 | base, 73 | load_in_8bit=mode_8bit, 74 | load_in_4bit=mode_4bit, 75 | device_map="auto", 76 | trust_remote_code=True, 77 | torch_dtype=torch.float16, 78 | use_safetensors=False, 79 | local_files_only=local_files_only 80 | )#.to(global_vars.device) 81 | 82 | if not mode_8bit and not mode_4bit: 83 | model.half() 84 | 85 | # model = BetterTransformer.transform(model) 86 | return model, tokenizer -------------------------------------------------------------------------------- /models/replit.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from peft import PeftModel 3 | from transformers import AutoTokenizer, AutoModelForCausalLM 4 | from optimum.bettertransformer import BetterTransformer 5 | 6 | from auto_gptq import AutoGPTQForCausalLM, BaseQuantizeConfig 7 | 8 | def load_model( 9 | base, 10 | finetuned, 11 | gptq, 12 | gptq_base, 13 | mode_cpu, 14 | mode_mps, 15 | mode_full_gpu, 16 | mode_8bit, 17 | mode_4bit, 18 | mode_gptq, 19 | mode_mps_gptq, 20 | mode_cpu_gptq, 21 | force_download_ckpt, 22 | local_files_only 23 | ): 24 | tokenizer = AutoTokenizer.from_pretrained( 25 | base, trust_remote_code=True, local_files_only=local_files_only 26 | ) 27 | tokenizer.padding_side = "left" 28 | 29 | model = AutoModelForCausalLM.from_pretrained( 30 | base, 31 | load_in_8bit=mode_8bit, 32 | load_in_4bit=mode_4bit, 33 | torch_dtype=torch.bfloat16, 34 | trust_remote_code=True, 35 | local_files_only=local_files_only 36 | ) 37 | 38 | if finetuned is not None and \ 39 | finetuned != "" and \ 40 | finetuned != "N/A": 41 | 42 | model = PeftModel.from_pretrained( 43 | model, 44 | finetuned, 45 | # force_download=force_download_ckpt, 46 | trust_remote_code=True 47 | ) 48 | 49 | model = model.merge_and_unload() 50 | 51 | # model = BetterTransformer.transform(model) 52 | model.to('cuda') 53 | return model, tokenizer 54 | 55 | -------------------------------------------------------------------------------- /models/samantha_vicuna.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from transformers import AutoModelForCausalLM, AutoTokenizer 3 | from optimum.bettertransformer import BetterTransformer 4 | 5 | from auto_gptq import AutoGPTQForCausalLM, BaseQuantizeConfig 6 | 7 | def load_model( 8 | base, 9 | finetuned, 10 | gptq, 11 | gptq_base, 12 | mode_cpu, 13 | mode_mps, 14 | mode_full_gpu, 15 | mode_8bit, 16 | mode_4bit, 17 | mode_gptq, 18 | mode_mps_gptq, 19 | mode_cpu_gptq, 20 | force_download_ckpt, 21 | local_files_only 22 | ): 23 | tokenizer = AutoTokenizer.from_pretrained( 24 | base, local_files_only=local_files_only 25 | ) 26 | 27 | if mode_cpu: 28 | print("cpu mode") 29 | model = AutoModelForCausalLM.from_pretrained( 30 | base, 31 | device_map={"": "cpu"}, 32 | use_safetensors=False, 33 | local_files_only=local_files_only 34 | ) 35 | 36 | elif mode_mps: 37 | print("mps mode") 38 | model = AutoModelForCausalLM.from_pretrained( 39 | base, 40 | device_map={"": "mps"}, 41 | torch_dtype=torch.float16, 42 | 
use_safetensors=False, 43 | local_files_only=local_files_only 44 | ) 45 | 46 | elif mode_gptq: 47 | print("gpu(gptq) mode") 48 | tokenizer = AutoTokenizer.from_pretrained( 49 | gptq, local_files_only=local_files_only 50 | ) 51 | tokenizer.pad_token_id = 0 52 | tokenizer.padding_side = "left" 53 | 54 | model = AutoGPTQForCausalLM.from_quantized( 55 | gptq, 56 | model_basename=gptq_base, 57 | use_safetensors=True, 58 | trust_remote_code=False, 59 | device_map="auto", 60 | quantize_config=None, 61 | local_files_only=local_files_only 62 | ) 63 | 64 | else: 65 | print("gpu mode") 66 | print(f"8bit = {mode_8bit}, 4bit = {mode_4bit}") 67 | model = AutoModelForCausalLM.from_pretrained( 68 | base, 69 | load_in_8bit=mode_8bit, 70 | load_in_4bit=mode_4bit, 71 | device_map="auto", 72 | torch_dtype=torch.float16, 73 | use_safetensors=False, 74 | local_files_only=local_files_only 75 | ) 76 | 77 | if not mode_8bit and not mode_4bit: 78 | model.half() 79 | 80 | model = BetterTransformer.transform(model) 81 | return model, tokenizer -------------------------------------------------------------------------------- /models/stablelm.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | from transformers import AutoModelForCausalLM, AutoTokenizer 4 | from optimum.bettertransformer import BetterTransformer 5 | 6 | from auto_gptq import AutoGPTQForCausalLM, BaseQuantizeConfig 7 | 8 | def load_model( 9 | base, 10 | finetuned, 11 | gptq, 12 | gptq_base, 13 | mode_cpu, 14 | mode_mps, 15 | mode_full_gpu, 16 | mode_8bit, 17 | mode_4bit, 18 | mode_gptq, 19 | mode_mps_gptq, 20 | mode_cpu_gptq, 21 | force_download_ckpt, 22 | local_files_only 23 | ): 24 | tokenizer = AutoTokenizer.from_pretrained( 25 | base, local_files_only=local_files_only 26 | ) 27 | tokenizer.pad_token_id = 1 28 | tokenizer.eos_token_id = 0 29 | tokenizer.padding_side = "left" 30 | 31 | if mode_cpu: 32 | print("cpu mode") 33 | model = AutoModelForCausalLM.from_pretrained( 34 | base, 35 | device_map={"": "cpu"}, 36 | use_safetensors=False, 37 | local_files_only=local_files_only 38 | ) 39 | 40 | elif mode_mps: 41 | print("mps mode") 42 | model = AutoModelForCausalLM.from_pretrained( 43 | base, 44 | device_map={"": "mps"}, 45 | torch_dtype=torch.float16, 46 | use_safetensors=False, 47 | local_files_only=local_files_only 48 | ) 49 | 50 | elif mode_gptq: 51 | print("gpu(gptq) mode") 52 | tokenizer = AutoTokenizer.from_pretrained( 53 | gptq, local_files_only=local_files_only 54 | ) 55 | tokenizer.pad_token_id = 0 56 | tokenizer.padding_side = "left" 57 | 58 | model = AutoGPTQForCausalLM.from_quantized( 59 | gptq, 60 | model_basename=gptq_base, 61 | use_safetensors=True, 62 | trust_remote_code=False, 63 | device_map="auto", 64 | quantize_config=None, 65 | local_files_only=local_files_only 66 | ) 67 | 68 | else: 69 | print("gpu mode") 70 | print(f"8bit = {mode_8bit}, 4bit = {mode_4bit}") 71 | model = AutoModelForCausalLM.from_pretrained( 72 | base, 73 | load_in_8bit=mode_8bit, 74 | load_in_4bit=mode_4bit, 75 | device_map="auto", 76 | torch_dtype=torch.float16, 77 | use_safetensors=False, 78 | local_files_only=local_files_only 79 | ) 80 | 81 | if not mode_8bit and not mode_4bit: 82 | model.half() 83 | 84 | model = BetterTransformer.transform(model) 85 | return model, tokenizer -------------------------------------------------------------------------------- /models/starchat.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | from transformers import 
AutoModelForCausalLM, AutoTokenizer 4 | from optimum.bettertransformer import BetterTransformer 5 | 6 | from auto_gptq import AutoGPTQForCausalLM, BaseQuantizeConfig 7 | 8 | def load_model( 9 | base, 10 | finetuned, 11 | gptq, 12 | gptq_base, 13 | mode_cpu, 14 | mode_mps, 15 | mode_full_gpu, 16 | mode_8bit, 17 | mode_4bit, 18 | mode_gptq, 19 | mode_mps_gptq, 20 | mode_cpu_gptq, 21 | force_download_ckpt, 22 | local_files_only 23 | ): 24 | tokenizer = AutoTokenizer.from_pretrained( 25 | base, local_files_only=local_files_only 26 | ) 27 | tokenizer.padding_side = "left" 28 | 29 | if mode_cpu: 30 | print("cpu mode") 31 | model = AutoModelForCausalLM.from_pretrained( 32 | base, 33 | device_map={"": "cpu"}, 34 | use_safetensors=False, 35 | local_files_only=local_files_only 36 | ) 37 | 38 | elif mode_mps: 39 | print("mps mode") 40 | model = AutoModelForCausalLM.from_pretrained( 41 | base, 42 | device_map={"": "mps"}, 43 | torch_dtype=torch.float16, 44 | use_safetensors=False, 45 | local_files_only=local_files_only 46 | ) 47 | 48 | elif mode_gptq: 49 | print("gpu(gptq) mode") 50 | tokenizer = AutoTokenizer.from_pretrained( 51 | gptq, local_files_only=local_files_only 52 | ) 53 | tokenizer.pad_token_id = 0 54 | tokenizer.padding_side = "left" 55 | 56 | model = AutoGPTQForCausalLM.from_quantized( 57 | gptq, 58 | model_basename=gptq_base, 59 | use_safetensors=True, 60 | trust_remote_code=False, 61 | device_map="auto", 62 | quantize_config=None, 63 | local_files_only=local_files_only 64 | ) 65 | 66 | else: 67 | print("gpu mode") 68 | print(f"8bit = {mode_8bit}, 4bit = {mode_4bit}") 69 | model = AutoModelForCausalLM.from_pretrained( 70 | base, 71 | load_in_8bit=mode_8bit, 72 | load_in_4bit=mode_4bit, 73 | device_map="auto", 74 | torch_dtype=torch.float16, 75 | use_safetensors=False, 76 | local_files_only=local_files_only 77 | ) 78 | 79 | if not mode_8bit and not mode_4bit: 80 | model.half() 81 | 82 | # model = BetterTransformer.transform(model) 83 | return model, tokenizer -------------------------------------------------------------------------------- /models/t5_vicuna.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from transformers import AutoModelForSeq2SeqLM, AutoTokenizer, T5Tokenizer 3 | from optimum.bettertransformer import BetterTransformer 4 | 5 | from auto_gptq import AutoGPTQForCausalLM, BaseQuantizeConfig 6 | 7 | def load_model( 8 | base, 9 | finetuned, 10 | gptq, 11 | gptq_base, 12 | mode_cpu, 13 | mode_mps, 14 | mode_full_gpu, 15 | mode_8bit, 16 | mode_4bit, 17 | mode_gptq, 18 | mode_mps_gptq, 19 | mode_cpu_gptq, 20 | force_download_ckpt, 21 | local_files_only 22 | ): 23 | tokenizer = T5Tokenizer.from_pretrained( 24 | base, use_fast=False, local_files_only=local_files_only 25 | ) 26 | tokenizer.padding_side = "left" 27 | 28 | if mode_cpu: 29 | print("cpu mode") 30 | model = AutoModelForSeq2SeqLM.from_pretrained( 31 | base, 32 | device_map={"": "cpu"}, 33 | use_safetensors=False, 34 | local_files_only=local_files_only 35 | ) 36 | 37 | elif mode_mps: 38 | print("mps mode") 39 | model = AutoModelForSeq2SeqLM.from_pretrained( 40 | base, 41 | device_map={"": "mps"}, 42 | torch_dtype=torch.float16, 43 | use_safetensors=False, 44 | local_files_only=local_files_only 45 | ) 46 | 47 | elif mode_gptq: 48 | print("gpu(gptq) mode") 49 | tokenizer = AutoTokenizer.from_pretrained( 50 | gptq, local_files_only=local_files_only 51 | ) 52 | tokenizer.pad_token_id = 0 53 | tokenizer.padding_side = "left" 54 | 55 | model = AutoGPTQForCausalLM.from_quantized( 56 | 
gptq, 57 | model_basename=gptq_base, 58 | use_safetensors=True, 59 | trust_remote_code=False, 60 | device_map="auto", 61 | quantize_config=None, 62 | local_files_only=local_files_only 63 | ) 64 | 65 | else: 66 | print("gpu mode") 67 | print(f"8bit = {mode_8bit}, 4bit = {mode_4bit}") 68 | model = AutoModelForSeq2SeqLM.from_pretrained( 69 | base, 70 | load_in_8bit=mode_8bit, 71 | load_in_4bit=mode_4bit, 72 | device_map="auto", 73 | torch_dtype=torch.float16, 74 | use_safetensors=False, 75 | local_files_only=local_files_only 76 | ) 77 | 78 | if not mode_8bit and not mode_4bit: 79 | model.half() 80 | 81 | model = BetterTransformer.transform(model) 82 | return model, tokenizer -------------------------------------------------------------------------------- /models/vicuna.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import global_vars 3 | 4 | from transformers import AutoModelForCausalLM, AutoTokenizer 5 | from optimum.bettertransformer import BetterTransformer 6 | 7 | from auto_gptq import AutoGPTQForCausalLM, BaseQuantizeConfig 8 | 9 | def load_model( 10 | base, 11 | finetuned, 12 | gptq, 13 | gptq_base, 14 | mode_cpu, 15 | mode_mps, 16 | mode_full_gpu, 17 | mode_8bit, 18 | mode_4bit, 19 | mode_gptq, 20 | mode_mps_gptq, 21 | mode_cpu_gptq, 22 | force_download_ckpt, 23 | local_files_only 24 | ): 25 | tokenizer = AutoTokenizer.from_pretrained( 26 | base, local_files_only=local_files_only, 27 | use_fast=False if global_vars.model_type == "stable-vicuna" else True, 28 | ) 29 | tokenizer.padding_side = "left" 30 | 31 | if mode_cpu: 32 | print("cpu mode") 33 | model = AutoModelForCausalLM.from_pretrained( 34 | base, 35 | device_map={"": "cpu"}, 36 | use_safetensors=False, 37 | local_files_only=local_files_only 38 | ) 39 | 40 | elif mode_mps: 41 | print("mps mode") 42 | model = AutoModelForCausalLM.from_pretrained( 43 | base, 44 | device_map={"": "mps"}, 45 | torch_dtype=torch.float16, 46 | use_safetensors=False, 47 | local_files_only=local_files_only 48 | ) 49 | 50 | elif mode_gptq: 51 | print("gpu(gptq) mode") 52 | tokenizer = AutoTokenizer.from_pretrained( 53 | gptq, local_files_only=local_files_only 54 | ) 55 | tokenizer.pad_token_id = 0 56 | tokenizer.padding_side = "left" 57 | 58 | model = AutoGPTQForCausalLM.from_quantized( 59 | gptq, 60 | model_basename=gptq_base, 61 | use_safetensors=True, 62 | trust_remote_code=False, 63 | device_map="auto", 64 | quantize_config=None, 65 | local_files_only=local_files_only 66 | ) 67 | 68 | else: 69 | print("gpu mode") 70 | print(f"8bit = {mode_8bit}, 4bit = {mode_4bit}") 71 | model = AutoModelForCausalLM.from_pretrained( 72 | base, 73 | load_in_8bit=mode_8bit, 74 | load_in_4bit=mode_4bit, 75 | device_map="auto", 76 | torch_dtype=torch.float16, 77 | use_safetensors=False, 78 | local_files_only=local_files_only 79 | ) 80 | 81 | if not mode_8bit and not mode_4bit: 82 | model.half() 83 | 84 | model = BetterTransformer.transform(model) 85 | return model, tokenizer 86 | -------------------------------------------------------------------------------- /models/wizard_coder.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | from transformers import AutoModelForCausalLM, AutoTokenizer 4 | from optimum.bettertransformer import BetterTransformer 5 | 6 | from auto_gptq import AutoGPTQForCausalLM, BaseQuantizeConfig 7 | 8 | def load_model( 9 | base, 10 | finetuned, 11 | gptq, 12 | gptq_base, 13 | mode_cpu, 14 | mode_mps, 15 | mode_full_gpu, 16 | 
mode_8bit, 17 | mode_4bit, 18 | mode_gptq, 19 | mode_mps_gptq, 20 | mode_cpu_gptq, 21 | force_download_ckpt, 22 | local_files_only 23 | ): 24 | tokenizer = AutoTokenizer.from_pretrained( 25 | base, local_files_only=local_files_only 26 | ) 27 | tokenizer.padding_side = "left" 28 | 29 | if mode_cpu: 30 | print("cpu mode") 31 | model = AutoModelForCausalLM.from_pretrained( 32 | base, 33 | device_map={"": "cpu"}, 34 | use_safetensors=False, 35 | local_files_only=local_files_only 36 | ) 37 | 38 | elif mode_mps: 39 | print("mps mode") 40 | model = AutoModelForCausalLM.from_pretrained( 41 | base, 42 | device_map={"": "mps"}, 43 | torch_dtype=torch.float16, 44 | use_safetensors=False, 45 | local_files_only=local_files_only 46 | ) 47 | 48 | elif mode_gptq: 49 | print("gpu(gptq) mode") 50 | tokenizer = AutoTokenizer.from_pretrained( 51 | gptq, local_files_only=local_files_only 52 | ) 53 | tokenizer.pad_token_id = 0 54 | tokenizer.padding_side = "left" 55 | 56 | model = AutoGPTQForCausalLM.from_quantized( 57 | gptq, 58 | model_basename=gptq_base, 59 | use_safetensors=True, 60 | trust_remote_code=False, 61 | device_map="auto", 62 | quantize_config=None, 63 | local_files_only=local_files_only 64 | ) 65 | 66 | else: 67 | print("gpu mode") 68 | print(f"8bit = {mode_8bit}, 4bit = {mode_4bit}") 69 | model = AutoModelForCausalLM.from_pretrained( 70 | base, 71 | load_in_8bit=mode_8bit, 72 | load_in_4bit=mode_4bit, 73 | device_map="auto", 74 | torch_dtype=torch.float16, 75 | use_safetensors=False, 76 | local_files_only=local_files_only 77 | ) 78 | 79 | if not mode_8bit and not mode_4bit: 80 | model.half() 81 | 82 | model.config.pad_token_id = tokenizer.pad_token_id 83 | # model = BetterTransformer.transform(model) 84 | return model, tokenizer -------------------------------------------------------------------------------- /models/xgen.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from transformers import AutoModelForCausalLM, AutoTokenizer 3 | from optimum.bettertransformer import BetterTransformer 4 | 5 | from auto_gptq import AutoGPTQForCausalLM, BaseQuantizeConfig 6 | 7 | def load_model( 8 | base, 9 | finetuned, 10 | gptq, 11 | gptq_base, 12 | mode_cpu, 13 | mode_mps, 14 | mode_full_gpu, 15 | mode_8bit, 16 | mode_4bit, 17 | mode_gptq, 18 | mode_mps_gptq, 19 | mode_cpu_gptq, 20 | force_download_ckpt, 21 | local_files_only 22 | ): 23 | tokenizer = AutoTokenizer.from_pretrained( 24 | base, trust_remote_code=True, local_files_only=local_files_only 25 | ) 26 | 27 | if mode_cpu: 28 | print("cpu mode") 29 | model = AutoModelForCausalLM.from_pretrained( 30 | base, 31 | device_map={"": "cpu"}, 32 | use_safetensors=False, 33 | trust_remote_code=True, 34 | local_files_only=local_files_only 35 | ) 36 | 37 | elif mode_mps: 38 | print("mps mode") 39 | model = AutoModelForCausalLM.from_pretrained( 40 | base, 41 | device_map={"": "mps"}, 42 | torch_dtype=torch.float16, 43 | use_safetensors=False, 44 | trust_remote_code=True, 45 | local_files_only=local_files_only 46 | ) 47 | 48 | elif mode_gptq: 49 | print("gpu(gptq) mode") 50 | tokenizer = AutoTokenizer.from_pretrained( 51 | gptq, local_files_only=local_files_only 52 | ) 53 | tokenizer.pad_token_id = 0 54 | tokenizer.padding_side = "left" 55 | 56 | model = AutoGPTQForCausalLM.from_quantized( 57 | gptq, 58 | model_basename=gptq_base, 59 | use_safetensors=True, 60 | trust_remote_code=False, 61 | device_map="auto", 62 | quantize_config=None, 63 | local_files_only=local_files_only 64 | ) 65 | 66 | else: 67 | 
print("gpu mode") 68 | print(f"8bit = {mode_8bit}, 4bit = {mode_4bit}") 69 | model = AutoModelForCausalLM.from_pretrained( 70 | base, 71 | torch_dtype=torch.float16, 72 | load_in_8bit=mode_8bit, 73 | load_in_4bit=mode_4bit, 74 | device_map="auto", 75 | use_safetensors=False, 76 | trust_remote_code=True, 77 | local_files_only=local_files_only 78 | ) 79 | 80 | if not mode_8bit and not mode_4bit: 81 | model.half() 82 | 83 | # model = BetterTransformer.transform(model) 84 | return model, tokenizer -------------------------------------------------------------------------------- /notebooks/llm_as_chatbot_in_colab.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": { 6 | "id": "xf3pUNyVO3WS" 7 | }, 8 | "source": [ 9 | "# Check GPU's Memory Capacity\n", 10 | "\n", 11 | "By running `nvidia-smi` command, you can find out the GPU's memory capacity on the current system. \n", 12 | "\n", 13 | "With the standard GPU instance(___T4___) which is free, you can run 7B and 13B models. With the premium GPU instance(___A100 40GB___) which is paid with the compute unit that you own, you can even run 30B model! Choose the instance at the menu `Runtime` -> `Change runtime type` -> `Hardware accelerator (GPU)` -> `GPU class (Standard or Premium)`" 14 | ] 15 | }, 16 | { 17 | "cell_type": "code", 18 | "execution_count": 1, 19 | "metadata": { 20 | "colab": { 21 | "base_uri": "https://localhost:8080/" 22 | }, 23 | "id": "L2MoM27rfaKK", 24 | "outputId": "53175950-3269-4296-9425-3652c81ce9b7" 25 | }, 26 | "outputs": [ 27 | { 28 | "name": "stdout", 29 | "output_type": "stream", 30 | "text": [ 31 | "Wed Mar 22 12:11:41 2023 \n", 32 | "+-----------------------------------------------------------------------------+\n", 33 | "| NVIDIA-SMI 525.85.12 Driver Version: 525.85.12 CUDA Version: 12.0 |\n", 34 | "|-------------------------------+----------------------+----------------------+\n", 35 | "| GPU Name Persistence-M| Bus-Id Disp.A | Volatile Uncorr. ECC |\n", 36 | "| Fan Temp Perf Pwr:Usage/Cap| Memory-Usage | GPU-Util Compute M. |\n", 37 | "| | | MIG M. 
|\n", 38 | "|===============================+======================+======================|\n", 39 | "| 0 Tesla T4 Off | 00000000:00:04.0 Off | 0 |\n", 40 | "| N/A 41C P0 24W / 70W | 0MiB / 15360MiB | 0% Default |\n", 41 | "| | | N/A |\n", 42 | "+-------------------------------+----------------------+----------------------+\n", 43 | " \n", 44 | "+-----------------------------------------------------------------------------+\n", 45 | "| Processes: |\n", 46 | "| GPU GI CI PID Type Process name GPU Memory |\n", 47 | "| ID ID Usage |\n", 48 | "|=============================================================================|\n", 49 | "| No running processes found |\n", 50 | "+-----------------------------------------------------------------------------+\n" 51 | ] 52 | } 53 | ], 54 | "source": [ 55 | "!nvidia-smi" 56 | ] 57 | }, 58 | { 59 | "cell_type": "markdown", 60 | "metadata": { 61 | "id": "N0MDD9TuPTfJ" 62 | }, 63 | "source": [ 64 | "# Clone the repository" 65 | ] 66 | }, 67 | { 68 | "cell_type": "code", 69 | "execution_count": null, 70 | "metadata": { 71 | "id": "a_i5DKBNnzAK" 72 | }, 73 | "outputs": [], 74 | "source": [ 75 | "!git clone https://github.com/deep-diver/LLM-As-Chatbot.git" 76 | ] 77 | }, 78 | { 79 | "cell_type": "markdown", 80 | "metadata": { 81 | "id": "HUuzxWGuPYLq" 82 | }, 83 | "source": [ 84 | "# Move into the directory of the cloned repository" 85 | ] 86 | }, 87 | { 88 | "cell_type": "code", 89 | "execution_count": 3, 90 | "metadata": { 91 | "colab": { 92 | "base_uri": "https://localhost:8080/" 93 | }, 94 | "id": "wR-M8u7gsQqg", 95 | "outputId": "eb7b24ba-10e4-46d5-cf8f-852d9fac8170" 96 | }, 97 | "outputs": [ 98 | { 99 | "name": "stdout", 100 | "output_type": "stream", 101 | "text": [ 102 | "/content/Alpaca-LoRA-Serve\n" 103 | ] 104 | } 105 | ], 106 | "source": [ 107 | "%cd LLM-As-Chatbot" 108 | ] 109 | }, 110 | { 111 | "cell_type": "markdown", 112 | "metadata": { 113 | "id": "XG8oy7BBPdMh" 114 | }, 115 | "source": [ 116 | "# Install dependencies" 117 | ] 118 | }, 119 | { 120 | "cell_type": "code", 121 | "execution_count": 4, 122 | "metadata": { 123 | "colab": { 124 | "base_uri": "https://localhost:8080/" 125 | }, 126 | "id": "moN-15x_ifHE", 127 | "outputId": "a7ec61ff-28cb-4ac4-a0ca-6a5cba060579" 128 | }, 129 | "outputs": [ 130 | { 131 | "name": "stdout", 132 | "output_type": "stream", 133 | "text": [ 134 | " Building wheel for transformers (pyproject.toml) ... \u001b[?25l\u001b[?25hdone\n", 135 | " Created wheel for transformers: filename=transformers-4.28.0.dev0-py3-none-any.whl size=6758864 sha256=028619344608e01338ac944ad0d4e6496fe5c743c90a15dd20c2e436e56106a9\n", 136 | " Stored in directory: /tmp/pip-ephem-wheel-cache-vqcgstta/wheels/f7/92/8c/752ff3bfcd3439805d8bbf641614da38ef3226e127ebea86ee\n", 137 | " Building wheel for peft (pyproject.toml) ... \u001b[?25l\u001b[?25hdone\n", 138 | " Created wheel for peft: filename=peft-0.3.0.dev0-py3-none-any.whl size=40669 sha256=bb0afa4164ac44e0a604c781f61767ea3e7255b85b70e2d4cf76a4252119ac27\n", 139 | " Stored in directory: /tmp/pip-ephem-wheel-cache-vqcgstta/wheels/2d/60/1b/0edd9dc0f0c489738b1166bc1b0b560ee368f7721f89d06e3a\n", 140 | " Building wheel for ffmpy (setup.py) ... 
\u001b[?25l\u001b[?25hdone\n", 141 | " Created wheel for ffmpy: filename=ffmpy-0.3.0-py3-none-any.whl size=4707 sha256=5f7dae7c29ab50f6251f5c864c70d4e485a4338a98c5cc1ee51523ace2758bf1\n", 142 | " Stored in directory: /root/.cache/pip/wheels/91/e2/96/f676aa08bfd789328c6576cd0f1fde4a3d686703bb0c247697\n", 143 | "Successfully built transformers peft ffmpy\n", 144 | "Installing collected packages: tokenizers, sentencepiece, rfc3986, pydub, ffmpy, bitsandbytes, xxhash, websockets, uc-micro-py, python-multipart, pycryptodome, orjson, multidict, mdurl, loralib, h11, frozenlist, dill, async-timeout, aiofiles, yarl, uvicorn, starlette, responses, multiprocess, markdown-it-py, linkify-it-py, huggingface-hub, httpcore, aiosignal, accelerate, transformers, mdit-py-plugins, httpx, fastapi, aiohttp, peft, gradio, datasets\n", 145 | "Successfully installed accelerate-0.17.1 aiofiles-23.1.0 aiohttp-3.8.4 aiosignal-1.3.1 async-timeout-4.0.2 bitsandbytes-0.37.2 datasets-2.10.1 dill-0.3.6 fastapi-0.95.0 ffmpy-0.3.0 frozenlist-1.3.3 gradio-3.20.0 h11-0.14.0 httpcore-0.16.3 httpx-0.23.3 huggingface-hub-0.13.3 linkify-it-py-2.0.0 loralib-0.1.1 markdown-it-py-2.2.0 mdit-py-plugins-0.3.3 mdurl-0.1.2 multidict-6.0.4 multiprocess-0.70.14 orjson-3.8.8 peft-0.3.0.dev0 pycryptodome-3.17 pydub-0.25.1 python-multipart-0.0.6 responses-0.18.0 rfc3986-1.5.0 sentencepiece-0.1.97 starlette-0.26.1 tokenizers-0.13.2 transformers-4.28.0.dev0 uc-micro-py-1.0.1 uvicorn-0.21.1 websockets-10.4 xxhash-3.2.0 yarl-1.8.2\n" 146 | ] 147 | } 148 | ], 149 | "source": [ 150 | "!pip install -r requirements.txt" 151 | ] 152 | }, 153 | { 154 | "cell_type": "markdown", 155 | "metadata": { 156 | "id": "Cr3bQkSePfrG" 157 | }, 158 | "source": [ 159 | "# Run the application" 160 | ] 161 | }, 162 | { 163 | "cell_type": "code", 164 | "execution_count": 14, 165 | "metadata": { 166 | "id": "4Wg0eqnkPnq-" 167 | }, 168 | "outputs": [], 169 | "source": [ 170 | "#@title Choose models\n", 171 | "\n", 172 | "base_model = 'decapoda-research/llama-13b-hf' #@param [\"decapoda-research/llama-7b-hf\", \"decapoda-research/llama-13b-hf\", \"decapoda-research/llama-30b-hf\"]\n", 173 | "finetuned_model = 'chansung/alpaca-lora-13b' #@param [\"tloen/alpaca-lora-7b\", \"chansung/alpaca-lora-13b\", \"chansung/koalpaca-lora-13b\", \"chansung/alpaca-lora-30b\"]\n" 174 | ] 175 | }, 176 | { 177 | "cell_type": "markdown", 178 | "metadata": { 179 | "id": "b81jhdtcQyOP" 180 | }, 181 | "source": [ 182 | "## Run the application\n", 183 | "\n", 184 | "It will take some time since LLaMA weights are huge. \n", 185 | "\n", 186 | "Click the URL appeared in the `Running on public URL:` field from the log. That will bring you to a new browser tab which opens up the running application." 
187 | ] 188 | }, 189 | { 190 | "cell_type": "code", 191 | "execution_count": null, 192 | "metadata": { 193 | "id": "y3qpzBw2jMHq" 194 | }, 195 | "outputs": [], 196 | "source": [ 197 | "!python app.py --base-url $base_model --ft-ckpt-url $finetuned_model --share" 198 | ] 199 | } 200 | ], 201 | "metadata": { 202 | "accelerator": "GPU", 203 | "colab": { 204 | "machine_shape": "hm", 205 | "provenance": [] 206 | }, 207 | "gpuClass": "premium", 208 | "kernelspec": { 209 | "display_name": "Python 3", 210 | "language": "python", 211 | "name": "python3" 212 | }, 213 | "language_info": { 214 | "codemirror_mode": { 215 | "name": "ipython", 216 | "version": 3 217 | }, 218 | "file_extension": ".py", 219 | "mimetype": "text/x-python", 220 | "name": "python", 221 | "nbconvert_exporter": "python", 222 | "pygments_lexer": "ipython3", 223 | "version": "3.8.12" 224 | } 225 | }, 226 | "nbformat": 4, 227 | "nbformat_minor": 4 228 | } 229 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | torch 2 | bitsandbytes 3 | datasets 4 | loralib 5 | sentencepiece 6 | git+https://github.com/huggingface/transformers.git 7 | git+https://github.com/huggingface/peft.git 8 | gradio 9 | bingbong 10 | git+https://github.com/huggingface/optimum.git 11 | git+https://github.com/huggingface/accelerate.git 12 | tokenizers>=0.14.0 13 | einops 14 | scipy 15 | protobuf==3.20.* 16 | tiktoken 17 | discord 18 | urlextract 19 | auto-gptq 20 | sseclient-py 21 | text-generation 22 | -------------------------------------------------------------------------------- /scripts/hparams_explore.py: -------------------------------------------------------------------------------- 1 | import time 2 | import itertools 3 | import wandb 4 | from transformers import GenerationConfig 5 | 6 | wandb.login(key="") 7 | 8 | PROJECT="txt_gen_test_project" 9 | 10 | generation_configs = { 11 | "temperature": [0.5, 0.7, 0.8, 0.9, 1.0], 12 | "top_p": [0.5, 0.75, 0.85, 0.95, 1.0], 13 | "num_beams": [1, 2, 3, 4] 14 | } 15 | 16 | num_gens = 1 17 | 18 | # token initialization 19 | # model initialization 20 | 21 | for comb in itertools.product(generation_configs['temperature'], 22 | generation_configs['top_p'], 23 | generation_configs['num_beams']): 24 | temperature = comb[0] 25 | top_p = comb[1] 26 | num_beams = comb[2] 27 | 28 | generation_config = GenerationConfig( 29 | temperature=temperature, 30 | top_p=top_p, 31 | num_beams=num_beams, 32 | ) 33 | 34 | first_columns = [f"gen_txt_{num}" for num in range(num_gens)] 35 | columns = first_columns + ["temperature", "top_p", "num_beams", "time_delta"] 36 | 37 | avg_time_delta = 0 38 | txt_gens = [] 39 | for i in range(num_gens): 40 | start = time.time() 41 | # text generation 42 | text = "dummy text" 43 | txt_gens.append(text) 44 | 45 | # decode outputs 46 | end = time.time() 47 | t_delta = end - start 48 | avg_time_delta = avg_time_delta + t_delta 49 | 50 | avg_time_delta = round(avg_time_delta / num_gens, 4) 51 | 52 | wandb.init( 53 | project=PROJECT, 54 | name=f"t@{temperature}-tp@{top_p}-nb@{num_beams}", 55 | config=generation_config, 56 | ) 57 | 58 | text_table = wandb.Table(columns=columns) 59 | text_table.add_data(*txt_gens, temperature, top_p, num_beams, avg_time_delta) 60 | 61 | wandb.log({ 62 | "avg_t_delta": avg_time_delta, 63 | "results": text_table 64 | }) 65 | 66 | wandb.finish() 67 | -------------------------------------------------------------------------------- 
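Note on scripts/hparams_explore.py: the script above leaves the "# token initialization", "# model initialization", and "# text generation" / "# decode outputs" placeholders to the reader. A minimal sketch of how they might be filled with plain transformers is shown below; the model id, prompt, and the sample GenerationConfig are assumptions for illustration only, not values the repository prescribes (inside the script you would reuse the generation_config built from the current temperature/top_p/num_beams combination instead of the standalone one defined here).

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, GenerationConfig

# Assumed small model so the sweep runs quickly; swap in the model under test.
model_id = "gpt2"
tokenizer = AutoTokenizer.from_pretrained(model_id)      # "# token initialization"
model = AutoModelForCausalLM.from_pretrained(model_id)   # "# model initialization"
model.eval()

# Standalone stand-in for the loop's generation_config (one sample combination).
generation_config = GenerationConfig(temperature=0.7, top_p=0.95, num_beams=1)

prompt = "The quick brown fox"                           # assumed test prompt
inputs = tokenizer(prompt, return_tensors="pt")

# "# text generation" and "# decode outputs", timed inside the loop body.
with torch.no_grad():
    output_ids = model.generate(
        **inputs,
        generation_config=generation_config,
        max_new_tokens=32,
    )
text = tokenizer.decode(output_ids[0], skip_special_tokens=True)
print(text)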
/scripts/requirements.txt: -------------------------------------------------------------------------------- 1 | transformers 2 | wandb 3 | --------------------------------------------------------------------------------
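For reference, the loader modules under models/ all expose the same flag-driven load_model() signature. Below is a minimal usage sketch assuming the snippet runs from the repository root with requirements.txt installed; the checkpoint id and flag values are illustrative assumptions only, since the application's actual wiring (model selection, GPTQ ids, and so on) is presumably driven by app.py, global_vars.py, and model_cards.json rather than by a direct call like this.

# Hypothetical, standalone call into models/stablelm.py's load_model();
# the checkpoint id is an assumption for illustration, not the repo's default.
import torch
from models import stablelm

model, tokenizer = stablelm.load_model(
    base="stabilityai/stablelm-tuned-alpha-7b",  # assumed base checkpoint
    finetuned=None,          # not used by this loader
    gptq=None,               # GPTQ repo id, only read when mode_gptq=True
    gptq_base=None,          # GPTQ weight basename, only read when mode_gptq=True
    mode_cpu=False,
    mode_mps=False,
    mode_full_gpu=True,      # plain fp16 GPU path (the loader's final else branch)
    mode_8bit=False,
    mode_4bit=False,
    mode_gptq=False,
    mode_mps_gptq=False,
    mode_cpu_gptq=False,
    force_download_ckpt=False,
    local_files_only=False,
)

prompt = "Hello, how are you today?"
inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
with torch.no_grad():
    output_ids = model.generate(**inputs, max_new_tokens=64)
print(tokenizer.decode(output_ids[0], skip_special_tokens=True))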