├── .github └── workflows │ └── docker-image.yml ├── .gitignore ├── Dockerfile ├── LICENSE ├── README.md ├── alembic.ini ├── alembic ├── README ├── env.py ├── script.py.mako └── versions │ ├── b2363fd8dd59_create_soft_prompts_table.py │ └── e1b2cec09880_init.py ├── app ├── __init__.py ├── api │ ├── __init__.py │ ├── deps.py │ └── v1 │ │ ├── __init__.py │ │ ├── api.py │ │ └── endpoints │ │ ├── __init__.py │ │ ├── models.py │ │ ├── soft_prompts.py │ │ └── users.py ├── core │ ├── __init__.py │ ├── config.py │ ├── logging.py │ └── security.py ├── crud │ ├── __init__.py │ ├── base.py │ ├── soft_prompt.py │ └── user.py ├── db │ ├── __init__.py │ ├── base.py │ ├── base_class.py │ └── database.py ├── gpt │ ├── __init__.py │ ├── autohf.py │ ├── berthf.py │ ├── clip.py │ ├── engram.py │ ├── gooseai.py │ ├── gpthf.py │ ├── models.py │ ├── quantization.py │ ├── softprompt.py │ ├── tensorize.py │ ├── utils.py │ └── warpers.py ├── main.py ├── models │ ├── __init__.py │ ├── soft_prompt.py │ └── user.py └── schemas │ ├── __init__.py │ ├── model_item.py │ ├── soft_prompt.py │ ├── token.py │ └── user.py ├── banner.png ├── conf.env ├── docker-compose.yaml ├── docker-compose_nvidia-gpu.yaml ├── k8s ├── conf-env-configmap.yaml ├── database-deployment.yaml ├── database-service.yaml ├── postgres-data-persistentvolumeclaim.yaml ├── sukima-claim0-persistentvolumeclaim.yaml ├── sukima-deployment.yaml └── sukima-service.yaml ├── requirements.txt ├── storage └── .gitkeep └── tests ├── __init__.py └── test.py /.github/workflows/docker-image.yml: -------------------------------------------------------------------------------- 1 | name: Docker Image CI 2 | 3 | on: 4 | push: 5 | branches: [ main ] 6 | pull_request: 7 | branches: [ main ] 8 | 9 | jobs: 10 | build: 11 | runs-on: ubuntu-latest 12 | steps: 13 | - uses: actions/checkout@v2 14 | - name: Fetch Dependencies 15 | run: sudo apt install docker-compose 16 | - name: Set Up Docker Environment 17 | run: docker-compose up -d && docker run --network host appropriate/curl --retry 30 --retry-connrefused http://localhost:8000/ 18 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Storage directory 2 | storage/* 3 | !storage/.gitkeep 4 | 5 | # Byte-compiled / optimized / DLL files 6 | __pycache__/ 7 | *.py[cod] 8 | *$py.class 9 | 10 | # C extensions 11 | *.so 12 | 13 | # Distribution / packaging 14 | .Python 15 | build/ 16 | develop-eggs/ 17 | dist/ 18 | downloads/ 19 | eggs/ 20 | .eggs/ 21 | lib/ 22 | lib64/ 23 | parts/ 24 | sdist/ 25 | var/ 26 | wheels/ 27 | pip-wheel-metadata/ 28 | share/python-wheels/ 29 | *.egg-info/ 30 | .installed.cfg 31 | *.egg 32 | MANIFEST 33 | 34 | # PyInstaller 35 | # Usually these files are written by a python script from a template 36 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
37 | *.manifest 38 | *.spec 39 | 40 | # Installer logs 41 | pip-log.txt 42 | pip-delete-this-directory.txt 43 | 44 | # Unit test / coverage reports 45 | htmlcov/ 46 | .tox/ 47 | .nox/ 48 | .coverage 49 | .coverage.* 50 | .cache 51 | nosetests.xml 52 | coverage.xml 53 | *.cover 54 | *.py,cover 55 | .hypothesis/ 56 | .pytest_cache/ 57 | 58 | # Translations 59 | *.mo 60 | *.pot 61 | 62 | # Django stuff: 63 | *.log 64 | local_settings.py 65 | db.sqlite3 66 | db.sqlite3-journal 67 | 68 | # Flask stuff: 69 | instance/ 70 | .webassets-cache 71 | 72 | # Scrapy stuff: 73 | .scrapy 74 | 75 | # Sphinx documentation 76 | docs/_build/ 77 | 78 | # PyBuilder 79 | target/ 80 | 81 | # Jupyter Notebook 82 | .ipynb_checkpoints 83 | 84 | # IPython 85 | profile_default/ 86 | ipython_config.py 87 | 88 | # pyenv 89 | .python-version 90 | 91 | # pipenv 92 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 93 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 94 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 95 | # install all needed dependencies. 96 | #Pipfile.lock 97 | 98 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow 99 | __pypackages__/ 100 | 101 | # Celery stuff 102 | celerybeat-schedule 103 | celerybeat.pid 104 | 105 | # SageMath parsed files 106 | *.sage.py 107 | 108 | # Environments 109 | .env 110 | .venv 111 | env/ 112 | venv/ 113 | ENV/ 114 | env.bak/ 115 | venv.bak/ 116 | 117 | # Spyder project settings 118 | .spyderproject 119 | .spyproject 120 | 121 | # Rope project settings 122 | .ropeproject 123 | 124 | # mkdocs documentation 125 | /site 126 | 127 | # mypy 128 | .mypy_cache/ 129 | .dmypy.json 130 | dmypy.json 131 | 132 | # Pyre type checker 133 | .pyre/ 134 | 135 | # Secrets 136 | conf.env 137 | secrets.json 138 | db.json 139 | 140 | .vscode 141 | .idea 142 | -------------------------------------------------------------------------------- /Dockerfile: -------------------------------------------------------------------------------- 1 | FROM python:3.9.6 2 | ENV PYTHONUNBUFFERED 1 3 | 4 | ENV PYTHONPATH "${PYTHONPATH}:/" 5 | ENV PORT=8000 6 | 7 | RUN mkdir /sukima 8 | WORKDIR /sukima 9 | 10 | COPY . /sukima 11 | RUN pip install --upgrade pip 12 | RUN pip install --no-cache-dir torch==1.11.0+cu113 -f https://download.pytorch.org/whl/cu113/torch_stable.html 13 | RUN pip install --no-cache-dir -r requirements.txt 14 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | GNU GENERAL PUBLIC LICENSE 2 | Version 2, June 1991 3 | 4 | Copyright (C) 1989, 1991 Free Software Foundation, Inc., 5 | 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA 6 | Everyone is permitted to copy and distribute verbatim copies 7 | of this license document, but changing it is not allowed. 8 | 9 | Preamble 10 | 11 | The licenses for most software are designed to take away your 12 | freedom to share and change it. By contrast, the GNU General Public 13 | License is intended to guarantee your freedom to share and change free 14 | software--to make sure the software is free for all its users. This 15 | General Public License applies to most of the Free Software 16 | Foundation's software and to any other program whose authors commit to 17 | using it. 
(Some other Free Software Foundation software is covered by 18 | the GNU Lesser General Public License instead.) You can apply it to 19 | your programs, too. 20 | 21 | When we speak of free software, we are referring to freedom, not 22 | price. Our General Public Licenses are designed to make sure that you 23 | have the freedom to distribute copies of free software (and charge for 24 | this service if you wish), that you receive source code or can get it 25 | if you want it, that you can change the software or use pieces of it 26 | in new free programs; and that you know you can do these things. 27 | 28 | To protect your rights, we need to make restrictions that forbid 29 | anyone to deny you these rights or to ask you to surrender the rights. 30 | These restrictions translate to certain responsibilities for you if you 31 | distribute copies of the software, or if you modify it. 32 | 33 | For example, if you distribute copies of such a program, whether 34 | gratis or for a fee, you must give the recipients all the rights that 35 | you have. You must make sure that they, too, receive or can get the 36 | source code. And you must show them these terms so they know their 37 | rights. 38 | 39 | We protect your rights with two steps: (1) copyright the software, and 40 | (2) offer you this license which gives you legal permission to copy, 41 | distribute and/or modify the software. 42 | 43 | Also, for each author's protection and ours, we want to make certain 44 | that everyone understands that there is no warranty for this free 45 | software. If the software is modified by someone else and passed on, we 46 | want its recipients to know that what they have is not the original, so 47 | that any problems introduced by others will not reflect on the original 48 | authors' reputations. 49 | 50 | Finally, any free program is threatened constantly by software 51 | patents. We wish to avoid the danger that redistributors of a free 52 | program will individually obtain patent licenses, in effect making the 53 | program proprietary. To prevent this, we have made it clear that any 54 | patent must be licensed for everyone's free use or not licensed at all. 55 | 56 | The precise terms and conditions for copying, distribution and 57 | modification follow. 58 | 59 | GNU GENERAL PUBLIC LICENSE 60 | TERMS AND CONDITIONS FOR COPYING, DISTRIBUTION AND MODIFICATION 61 | 62 | 0. This License applies to any program or other work which contains 63 | a notice placed by the copyright holder saying it may be distributed 64 | under the terms of this General Public License. The "Program", below, 65 | refers to any such program or work, and a "work based on the Program" 66 | means either the Program or any derivative work under copyright law: 67 | that is to say, a work containing the Program or a portion of it, 68 | either verbatim or with modifications and/or translated into another 69 | language. (Hereinafter, translation is included without limitation in 70 | the term "modification".) Each licensee is addressed as "you". 71 | 72 | Activities other than copying, distribution and modification are not 73 | covered by this License; they are outside its scope. The act of 74 | running the Program is not restricted, and the output from the Program 75 | is covered only if its contents constitute a work based on the 76 | Program (independent of having been made by running the Program). 77 | Whether that is true depends on what the Program does. 78 | 79 | 1. 
You may copy and distribute verbatim copies of the Program's 80 | source code as you receive it, in any medium, provided that you 81 | conspicuously and appropriately publish on each copy an appropriate 82 | copyright notice and disclaimer of warranty; keep intact all the 83 | notices that refer to this License and to the absence of any warranty; 84 | and give any other recipients of the Program a copy of this License 85 | along with the Program. 86 | 87 | You may charge a fee for the physical act of transferring a copy, and 88 | you may at your option offer warranty protection in exchange for a fee. 89 | 90 | 2. You may modify your copy or copies of the Program or any portion 91 | of it, thus forming a work based on the Program, and copy and 92 | distribute such modifications or work under the terms of Section 1 93 | above, provided that you also meet all of these conditions: 94 | 95 | a) You must cause the modified files to carry prominent notices 96 | stating that you changed the files and the date of any change. 97 | 98 | b) You must cause any work that you distribute or publish, that in 99 | whole or in part contains or is derived from the Program or any 100 | part thereof, to be licensed as a whole at no charge to all third 101 | parties under the terms of this License. 102 | 103 | c) If the modified program normally reads commands interactively 104 | when run, you must cause it, when started running for such 105 | interactive use in the most ordinary way, to print or display an 106 | announcement including an appropriate copyright notice and a 107 | notice that there is no warranty (or else, saying that you provide 108 | a warranty) and that users may redistribute the program under 109 | these conditions, and telling the user how to view a copy of this 110 | License. (Exception: if the Program itself is interactive but 111 | does not normally print such an announcement, your work based on 112 | the Program is not required to print an announcement.) 113 | 114 | These requirements apply to the modified work as a whole. If 115 | identifiable sections of that work are not derived from the Program, 116 | and can be reasonably considered independent and separate works in 117 | themselves, then this License, and its terms, do not apply to those 118 | sections when you distribute them as separate works. But when you 119 | distribute the same sections as part of a whole which is a work based 120 | on the Program, the distribution of the whole must be on the terms of 121 | this License, whose permissions for other licensees extend to the 122 | entire whole, and thus to each and every part regardless of who wrote it. 123 | 124 | Thus, it is not the intent of this section to claim rights or contest 125 | your rights to work written entirely by you; rather, the intent is to 126 | exercise the right to control the distribution of derivative or 127 | collective works based on the Program. 128 | 129 | In addition, mere aggregation of another work not based on the Program 130 | with the Program (or with a work based on the Program) on a volume of 131 | a storage or distribution medium does not bring the other work under 132 | the scope of this License. 133 | 134 | 3. 
You may copy and distribute the Program (or a work based on it, 135 | under Section 2) in object code or executable form under the terms of 136 | Sections 1 and 2 above provided that you also do one of the following: 137 | 138 | a) Accompany it with the complete corresponding machine-readable 139 | source code, which must be distributed under the terms of Sections 140 | 1 and 2 above on a medium customarily used for software interchange; or, 141 | 142 | b) Accompany it with a written offer, valid for at least three 143 | years, to give any third party, for a charge no more than your 144 | cost of physically performing source distribution, a complete 145 | machine-readable copy of the corresponding source code, to be 146 | distributed under the terms of Sections 1 and 2 above on a medium 147 | customarily used for software interchange; or, 148 | 149 | c) Accompany it with the information you received as to the offer 150 | to distribute corresponding source code. (This alternative is 151 | allowed only for noncommercial distribution and only if you 152 | received the program in object code or executable form with such 153 | an offer, in accord with Subsection b above.) 154 | 155 | The source code for a work means the preferred form of the work for 156 | making modifications to it. For an executable work, complete source 157 | code means all the source code for all modules it contains, plus any 158 | associated interface definition files, plus the scripts used to 159 | control compilation and installation of the executable. However, as a 160 | special exception, the source code distributed need not include 161 | anything that is normally distributed (in either source or binary 162 | form) with the major components (compiler, kernel, and so on) of the 163 | operating system on which the executable runs, unless that component 164 | itself accompanies the executable. 165 | 166 | If distribution of executable or object code is made by offering 167 | access to copy from a designated place, then offering equivalent 168 | access to copy the source code from the same place counts as 169 | distribution of the source code, even though third parties are not 170 | compelled to copy the source along with the object code. 171 | 172 | 4. You may not copy, modify, sublicense, or distribute the Program 173 | except as expressly provided under this License. Any attempt 174 | otherwise to copy, modify, sublicense or distribute the Program is 175 | void, and will automatically terminate your rights under this License. 176 | However, parties who have received copies, or rights, from you under 177 | this License will not have their licenses terminated so long as such 178 | parties remain in full compliance. 179 | 180 | 5. You are not required to accept this License, since you have not 181 | signed it. However, nothing else grants you permission to modify or 182 | distribute the Program or its derivative works. These actions are 183 | prohibited by law if you do not accept this License. Therefore, by 184 | modifying or distributing the Program (or any work based on the 185 | Program), you indicate your acceptance of this License to do so, and 186 | all its terms and conditions for copying, distributing or modifying 187 | the Program or works based on it. 188 | 189 | 6. Each time you redistribute the Program (or any work based on the 190 | Program), the recipient automatically receives a license from the 191 | original licensor to copy, distribute or modify the Program subject to 192 | these terms and conditions. 
You may not impose any further 193 | restrictions on the recipients' exercise of the rights granted herein. 194 | You are not responsible for enforcing compliance by third parties to 195 | this License. 196 | 197 | 7. If, as a consequence of a court judgment or allegation of patent 198 | infringement or for any other reason (not limited to patent issues), 199 | conditions are imposed on you (whether by court order, agreement or 200 | otherwise) that contradict the conditions of this License, they do not 201 | excuse you from the conditions of this License. If you cannot 202 | distribute so as to satisfy simultaneously your obligations under this 203 | License and any other pertinent obligations, then as a consequence you 204 | may not distribute the Program at all. For example, if a patent 205 | license would not permit royalty-free redistribution of the Program by 206 | all those who receive copies directly or indirectly through you, then 207 | the only way you could satisfy both it and this License would be to 208 | refrain entirely from distribution of the Program. 209 | 210 | If any portion of this section is held invalid or unenforceable under 211 | any particular circumstance, the balance of the section is intended to 212 | apply and the section as a whole is intended to apply in other 213 | circumstances. 214 | 215 | It is not the purpose of this section to induce you to infringe any 216 | patents or other property right claims or to contest validity of any 217 | such claims; this section has the sole purpose of protecting the 218 | integrity of the free software distribution system, which is 219 | implemented by public license practices. Many people have made 220 | generous contributions to the wide range of software distributed 221 | through that system in reliance on consistent application of that 222 | system; it is up to the author/donor to decide if he or she is willing 223 | to distribute software through any other system and a licensee cannot 224 | impose that choice. 225 | 226 | This section is intended to make thoroughly clear what is believed to 227 | be a consequence of the rest of this License. 228 | 229 | 8. If the distribution and/or use of the Program is restricted in 230 | certain countries either by patents or by copyrighted interfaces, the 231 | original copyright holder who places the Program under this License 232 | may add an explicit geographical distribution limitation excluding 233 | those countries, so that distribution is permitted only in or among 234 | countries not thus excluded. In such case, this License incorporates 235 | the limitation as if written in the body of this License. 236 | 237 | 9. The Free Software Foundation may publish revised and/or new versions 238 | of the General Public License from time to time. Such new versions will 239 | be similar in spirit to the present version, but may differ in detail to 240 | address new problems or concerns. 241 | 242 | Each version is given a distinguishing version number. If the Program 243 | specifies a version number of this License which applies to it and "any 244 | later version", you have the option of following the terms and conditions 245 | either of that version or of any later version published by the Free 246 | Software Foundation. If the Program does not specify a version number of 247 | this License, you may choose any version ever published by the Free Software 248 | Foundation. 249 | 250 | 10. 
If you wish to incorporate parts of the Program into other free 251 | programs whose distribution conditions are different, write to the author 252 | to ask for permission. For software which is copyrighted by the Free 253 | Software Foundation, write to the Free Software Foundation; we sometimes 254 | make exceptions for this. Our decision will be guided by the two goals 255 | of preserving the free status of all derivatives of our free software and 256 | of promoting the sharing and reuse of software generally. 257 | 258 | NO WARRANTY 259 | 260 | 11. BECAUSE THE PROGRAM IS LICENSED FREE OF CHARGE, THERE IS NO WARRANTY 261 | FOR THE PROGRAM, TO THE EXTENT PERMITTED BY APPLICABLE LAW. EXCEPT WHEN 262 | OTHERWISE STATED IN WRITING THE COPYRIGHT HOLDERS AND/OR OTHER PARTIES 263 | PROVIDE THE PROGRAM "AS IS" WITHOUT WARRANTY OF ANY KIND, EITHER EXPRESSED 264 | OR IMPLIED, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF 265 | MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE. THE ENTIRE RISK AS 266 | TO THE QUALITY AND PERFORMANCE OF THE PROGRAM IS WITH YOU. SHOULD THE 267 | PROGRAM PROVE DEFECTIVE, YOU ASSUME THE COST OF ALL NECESSARY SERVICING, 268 | REPAIR OR CORRECTION. 269 | 270 | 12. IN NO EVENT UNLESS REQUIRED BY APPLICABLE LAW OR AGREED TO IN WRITING 271 | WILL ANY COPYRIGHT HOLDER, OR ANY OTHER PARTY WHO MAY MODIFY AND/OR 272 | REDISTRIBUTE THE PROGRAM AS PERMITTED ABOVE, BE LIABLE TO YOU FOR DAMAGES, 273 | INCLUDING ANY GENERAL, SPECIAL, INCIDENTAL OR CONSEQUENTIAL DAMAGES ARISING 274 | OUT OF THE USE OR INABILITY TO USE THE PROGRAM (INCLUDING BUT NOT LIMITED 275 | TO LOSS OF DATA OR DATA BEING RENDERED INACCURATE OR LOSSES SUSTAINED BY 276 | YOU OR THIRD PARTIES OR A FAILURE OF THE PROGRAM TO OPERATE WITH ANY OTHER 277 | PROGRAMS), EVEN IF SUCH HOLDER OR OTHER PARTY HAS BEEN ADVISED OF THE 278 | POSSIBILITY OF SUCH DAMAGES. 279 | 280 | END OF TERMS AND CONDITIONS 281 | 282 | How to Apply These Terms to Your New Programs 283 | 284 | If you develop a new program, and you want it to be of the greatest 285 | possible use to the public, the best way to achieve this is to make it 286 | free software which everyone can redistribute and change under these terms. 287 | 288 | To do so, attach the following notices to the program. It is safest 289 | to attach them to the start of each source file to most effectively 290 | convey the exclusion of warranty; and each file should have at least 291 | the "copyright" line and a pointer to where the full notice is found. 292 | 293 | 294 | Copyright (C) 295 | 296 | This program is free software; you can redistribute it and/or modify 297 | it under the terms of the GNU General Public License as published by 298 | the Free Software Foundation; either version 2 of the License, or 299 | (at your option) any later version. 300 | 301 | This program is distributed in the hope that it will be useful, 302 | but WITHOUT ANY WARRANTY; without even the implied warranty of 303 | MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 304 | GNU General Public License for more details. 305 | 306 | You should have received a copy of the GNU General Public License along 307 | with this program; if not, write to the Free Software Foundation, Inc., 308 | 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA. 309 | 310 | Also add information on how to contact you by electronic and paper mail. 
311 | 
312 | If the program is interactive, make it output a short notice like this
313 | when it starts in an interactive mode:
314 | 
315 | Gnomovision version 69, Copyright (C) year name of author
316 | Gnomovision comes with ABSOLUTELY NO WARRANTY; for details type `show w'.
317 | This is free software, and you are welcome to redistribute it
318 | under certain conditions; type `show c' for details.
319 | 
320 | The hypothetical commands `show w' and `show c' should show the appropriate
321 | parts of the General Public License. Of course, the commands you use may
322 | be called something other than `show w' and `show c'; they could even be
323 | mouse-clicks or menu items--whatever suits your program.
324 | 
325 | You should also get your employer (if you work as a programmer) or your
326 | school, if any, to sign a "copyright disclaimer" for the program, if
327 | necessary. Here is a sample; alter the names:
328 | 
329 | Yoyodyne, Inc., hereby disclaims all copyright interest in the program
330 | `Gnomovision' (which makes passes at compilers) written by James Hacker.
331 | 
332 | <signature of Ty Coon>, 1 April 1989
333 | Ty Coon, President of Vice
334 | 
335 | This General Public License does not permit incorporating your program into
336 | proprietary programs. If your program is a subroutine library, you may
337 | consider it more useful to permit linking proprietary applications with the
338 | library. If this is what you want to do, use the GNU Lesser General
339 | Public License instead of this License.
340 | 
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
 1 | ![logo](banner.png)
 2 | 
 3 | ## Overview
 4 | Sukima is a ready-to-deploy container that implements a REST API for Language Models, designed for easy deployment and scalability.
 5 | 
 6 | ### Current API Functions
 7 | - **models** : Fetch a list of ready-to-use Language Models for inference.
 8 | - **load** : Allocate a Language Model.
 9 | - **generate** : Use a Language Model to generate tokens.
10 | - **classify** : Use a Language Model to classify tokens and retrieve scores.
11 | 
12 | For more information on API usage, see the ``/docs`` endpoint.
13 | 
14 | ### Setup
15 | [Setup Guide](../../wiki/Setup)
16 | 
17 | [Usage Guide](../../wiki/Usage)
18 | 
19 | ### Todo
20 | - Autoscaling
21 | - HTTPS Support
22 | - Rate Limiting
23 | - Support for other Language Modeling tasks such as Sentiment Analysis and Named Entity Recognition.
24 | 
25 | ### License
26 | [GPL-2.0](LICENSE)
27 | 
--------------------------------------------------------------------------------
/alembic.ini:
--------------------------------------------------------------------------------
 1 | [alembic]
 2 | # path to migration scripts
 3 | script_location = alembic
 4 | 
 5 | # template used to generate migration files
 6 | # file_template = %%(rev)s_%%(slug)s
 7 | 
 8 | # sys.path path, will be prepended to sys.path if present.
 9 | # defaults to the current working directory.
10 | # (new in 1.5.5)
11 | prepend_sys_path = .
12 | 
13 | # timezone to use when rendering the date within the migration file
14 | # as well as the filename.
15 | # If specified, requires the python-dateutil library that can be 16 | # installed by adding `alembic[tz]` to the pip requirements 17 | # string value is passed to dateutil.tz.gettz() 18 | # leave blank for localtime 19 | # timezone = 20 | 21 | # max length of characters to apply to the 22 | # "slug" field 23 | # truncate_slug_length = 40 24 | 25 | # set to 'true' to run the environment during 26 | # the 'revision' command, regardless of autogenerate 27 | # revision_environment = false 28 | 29 | # set to 'true' to allow .pyc and .pyo files without 30 | # a source .py file to be detected as revisions in the 31 | # versions/ directory 32 | # sourceless = false 33 | 34 | # version location specification; This defaults 35 | # to ${script_location}/versions. When using multiple version 36 | # directories, initial revisions must be specified with --version-path. 37 | # The path separator used here should be the separator specified by "version_path_separator" below. 38 | # version_locations = %(here)s/bar:%(here)s/bat:${script_location}/versions 39 | 40 | # version path separator; As mentioned above, this is the character used to split 41 | # version_locations. The default within new alembic.ini files is "os", which uses os.pathsep. 42 | # If this key is omitted entirely, it falls back to the legacy behavior of splitting on spaces and/or commas. 43 | # Valid values for version_path_separator are: 44 | # 45 | # version_path_separator = : 46 | # version_path_separator = ; 47 | # version_path_separator = space 48 | version_path_separator = os # Use os.pathsep. Default configuration used for new projects. 49 | 50 | # the output encoding used when revision files 51 | # are written from script.py.mako 52 | # output_encoding = utf-8 53 | 54 | # sqlalchemy.url = driver://user:pass@localhost/dbname 55 | 56 | # [post_write_hooks] 57 | # This section defines scripts or Python functions that are run 58 | # on newly generated revision scripts. 
See the documentation for further 59 | # detail and examples 60 | 61 | # format using "black" - use the console_scripts runner, 62 | # against the "black" entrypoint 63 | # hooks = black 64 | # black.type = console_scripts 65 | # black.entrypoint = black 66 | # black.options = -l 79 REVISION_SCRIPT_FILENAME 67 | 68 | # Logging configuration 69 | [loggers] 70 | keys = root,sqlalchemy,alembic 71 | 72 | [handlers] 73 | keys = console 74 | 75 | [formatters] 76 | keys = generic 77 | 78 | [logger_root] 79 | level = WARN 80 | handlers = console 81 | qualname = 82 | 83 | [logger_sqlalchemy] 84 | level = WARN 85 | handlers = 86 | qualname = sqlalchemy.engine 87 | 88 | [logger_alembic] 89 | level = INFO 90 | handlers = 91 | qualname = alembic 92 | 93 | [handler_console] 94 | class = StreamHandler 95 | args = (sys.stderr,) 96 | level = NOTSET 97 | formatter = generic 98 | 99 | [formatter_generic] 100 | format = %(levelname)-5.5s [%(name)s] %(message)s 101 | datefmt = %H:%M:%S 102 | -------------------------------------------------------------------------------- /alembic/README: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hitomi-team/sukima/eb3e8e968b971748890885fcc1287bd27ae8dde0/alembic/README -------------------------------------------------------------------------------- /alembic/env.py: -------------------------------------------------------------------------------- 1 | import asyncio 2 | from logging.config import fileConfig 3 | 4 | from app.core.config import settings 5 | from app.db.base_class import Base 6 | from sqlalchemy import engine_from_config, pool 7 | from sqlalchemy.ext.asyncio import AsyncEngine 8 | 9 | from alembic import context 10 | 11 | # this is the Alembic Config object, which provides 12 | # access to the values within the .ini file in use. 13 | config = context.config 14 | 15 | # Interpret the config file for Python logging. 16 | # This line sets up loggers basically. 17 | fileConfig(config.config_file_name) 18 | 19 | # add your model's MetaData object here 20 | # for 'autogenerate' support 21 | # from myapp import mymodel 22 | target_metadata = Base.metadata 23 | 24 | # other values from the config, defined by the needs of env.py, 25 | # can be acquired: 26 | # my_important_option = config.get_main_option("my_important_option") 27 | # ... etc. 28 | 29 | 30 | def run_migrations_offline(): 31 | """ 32 | Run migrations in 'offline' mode. 33 | 34 | This configures the context with just a URL 35 | and not an Engine, though an Engine is acceptable 36 | here as well. By skipping the Engine creation 37 | we don't even need a DBAPI to be available. 38 | 39 | Calls to context.execute() here emit the given string to the 40 | script output. 41 | """ 42 | 43 | url = settings.DATABASE_URI 44 | 45 | context.configure( 46 | url=url, 47 | target_metadata=target_metadata, 48 | literal_binds=True, 49 | dialect_opts={"paramstyle": "named"}, 50 | ) 51 | 52 | with context.begin_transaction(): 53 | context.run_migrations() 54 | 55 | 56 | def do_run_migrations(connection): 57 | context.configure(connection=connection, target_metadata=target_metadata) 58 | 59 | with context.begin_transaction(): 60 | context.run_migrations() 61 | 62 | 63 | async def run_migrations_online(): 64 | """ 65 | Run migrations in 'online' mode. 66 | 67 | In this scenario we need to create an Engine 68 | and associate a connection with the context. 
69 | """ 70 | 71 | configuration = config.get_section(config.config_ini_section) 72 | configuration["sqlalchemy.url"] = settings.DATABASE_URI 73 | 74 | connectable = AsyncEngine( 75 | engine_from_config( 76 | configuration, 77 | prefix="sqlalchemy.", 78 | poolclass=pool.NullPool, 79 | future=True, 80 | ) 81 | ) 82 | 83 | async with connectable.connect() as connection: 84 | await connection.run_sync(do_run_migrations) 85 | 86 | 87 | if context.is_offline_mode(): 88 | run_migrations_offline() 89 | else: 90 | asyncio.run(run_migrations_online()) 91 | -------------------------------------------------------------------------------- /alembic/script.py.mako: -------------------------------------------------------------------------------- 1 | """${message} 2 | 3 | Revision ID: ${up_revision} 4 | Revises: ${down_revision | comma,n} 5 | Create Date: ${create_date} 6 | 7 | """ 8 | from alembic import op 9 | import sqlalchemy as sa 10 | ${imports if imports else ""} 11 | 12 | # revision identifiers, used by Alembic. 13 | revision = ${repr(up_revision)} 14 | down_revision = ${repr(down_revision)} 15 | branch_labels = ${repr(branch_labels)} 16 | depends_on = ${repr(depends_on)} 17 | 18 | 19 | def upgrade(): 20 | ${upgrades if upgrades else "pass"} 21 | 22 | 23 | def downgrade(): 24 | ${downgrades if downgrades else "pass"} 25 | -------------------------------------------------------------------------------- /alembic/versions/b2363fd8dd59_create_soft_prompts_table.py: -------------------------------------------------------------------------------- 1 | """create soft prompts table 2 | 3 | Revision ID: b2363fd8dd59 4 | Revises: e1b2cec09880 5 | Create Date: 2021-12-29 05:59:28.729017 6 | 7 | """ 8 | from alembic import op 9 | import sqlalchemy as sa 10 | 11 | 12 | # revision identifiers, used by Alembic. 13 | revision = 'b2363fd8dd59' 14 | down_revision = 'e1b2cec09880' 15 | branch_labels = None 16 | depends_on = None 17 | 18 | 19 | def upgrade(): 20 | # ### commands auto generated by Alembic - please adjust! 
### 21 | op.create_table( 22 | 'soft_prompts', 23 | sa.Column('id', sa.String, autoincrement=False, nullable=False), 24 | sa.Column('name', sa.String, autoincrement=False, nullable=False), 25 | sa.Column('description', sa.String, autoincrement=False, nullable=True), 26 | sa.Column('public', sa.Boolean, default=sa.Boolean(True), autoincrement=False, nullable=False), 27 | sa.Column('creator', sa.Integer, autoincrement=False, nullable=False), 28 | 29 | # TEMP: it doesn't seem like we actually use the models table yet 30 | sa.Column('model', sa.String, nullable=False), 31 | 32 | # in the future, this will be required 33 | sa.Column('model_id', sa.Integer, autoincrement=False, nullable=True), 34 | 35 | sa.Column('loss', sa.Numeric, autoincrement=False, nullable=False), 36 | sa.Column('steps', sa.Integer, autoincrement=False, nullable=False), 37 | sa.ForeignKeyConstraint(['creator'], ['users.id'], name='fk_soft_prompts_users_id', onupdate='CASCADE', ondelete='CASCADE'), 38 | sa.ForeignKeyConstraint(['model_id'], ['models.id'], name='fk_soft_prompts_models_id', onupdate='CASCADE', ondelete='CASCADE'), 39 | sa.PrimaryKeyConstraint('id', name='pk_soft_prompts') 40 | ) 41 | op.create_index('ix_soft_prompts_id', 'soft_prompts', ['id'], unique=False) 42 | op.create_index('ix_soft_prompts_creator', 'soft_prompts', ['creator'], unique=False) 43 | op.create_index('ix_soft_prompts_model_id', 'soft_prompts', ['model_id'], unique=False) 44 | # ### end Alembic commands ### 45 | 46 | 47 | def downgrade(): 48 | # ### commands auto generated by Alembic - please adjust! ### 49 | op.drop_index('ix_soft_prompts_creator', table_name='soft_prompts') 50 | op.drop_index('ix_soft_prompts_id', table_name='soft_prompts') 51 | op.drop_index('ix_soft_prompts_model_id', table_name='soft_prompts') 52 | op.drop_table('soft_prompts') 53 | # ### end Alembic commands ### 54 | -------------------------------------------------------------------------------- /alembic/versions/e1b2cec09880_init.py: -------------------------------------------------------------------------------- 1 | """init 2 | 3 | Revision ID: e1b2cec09880 4 | Revises: 5 | Create Date: 2021-12-11 02:44:46.529627 6 | 7 | """ 8 | import sqlalchemy as sa 9 | from alembic import op 10 | 11 | # revision identifiers, used by Alembic. 12 | revision = "e1b2cec09880" 13 | down_revision = None 14 | branch_labels = None 15 | depends_on = None 16 | 17 | 18 | def upgrade(): 19 | # ### commands auto generated by Alembic - please adjust! 
### 20 | op.create_table( 21 | "users", 22 | sa.Column("id", sa.Integer, primary_key=True), 23 | sa.Column("username", sa.String, nullable=False), 24 | sa.Column("password", sa.String, nullable=True), 25 | sa.Column("email", sa.String, unique=True, nullable=False), 26 | sa.Column("permission_level", sa.SmallInteger, default=0, nullable=True), 27 | ) 28 | 29 | op.create_index(op.f("ix_users_id"), "users", ["id"], unique=True) 30 | op.create_index(op.f("ix_users_username"), "users", ["username"], unique=False) 31 | op.create_index(op.f("ix_users_email"), "users", ["email"], unique=True) 32 | 33 | op.create_table( 34 | "models", 35 | sa.Column("id", sa.Integer, primary_key=True), 36 | sa.Column("name", sa.String, unique=True, nullable=False), 37 | sa.Column("size", sa.Integer, nullable=False), 38 | ) 39 | 40 | op.create_index(op.f("ix_models_id"), "models", ["id"], unique=True) 41 | op.create_index(op.f("ix_models_name"), "models", ["name"], unique=False) 42 | 43 | op.create_table( 44 | "user_model_association", 45 | sa.Column("user_id", sa.Integer, nullable=True), 46 | sa.Column("model_id", sa.Integer, nullable=True), 47 | sa.ForeignKeyConstraint(["model_id"], ["models.id"], ), 48 | sa.ForeignKeyConstraint(["user_id"], ["users.id"], ) 49 | ) 50 | # ### end Alembic commands ### 51 | 52 | 53 | def downgrade(): 54 | # ### commands auto generated by Alembic - please adjust! ### 55 | op.drop_table("user_model_association") 56 | 57 | op.drop_index(op.f("ix_users_id"), table_name="users") 58 | op.drop_index(op.f("ix_users_username"), table_name="users") 59 | op.drop_index(op.f("ix_users_email"), table_name="users") 60 | op.drop_table("users") 61 | 62 | op.drop_index(op.f("ix_models_id"), table_name="models") 63 | op.drop_index(op.f("ix_models_name"), table_name="models") 64 | op.drop_table("models") 65 | # ### end Alembic commands ### 66 | -------------------------------------------------------------------------------- /app/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hitomi-team/sukima/eb3e8e968b971748890885fcc1287bd27ae8dde0/app/__init__.py -------------------------------------------------------------------------------- /app/api/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hitomi-team/sukima/eb3e8e968b971748890885fcc1287bd27ae8dde0/app/api/__init__.py -------------------------------------------------------------------------------- /app/api/deps.py: -------------------------------------------------------------------------------- 1 | from typing import AsyncIterator 2 | 3 | import app.crud.user as crud 4 | import app.models.user as models 5 | from app.core.config import settings 6 | from app.db.database import async_session 7 | from app.schemas.token import TokenData 8 | from fastapi import Depends, HTTPException 9 | from fastapi.security import OAuth2PasswordBearer 10 | from jose import JWTError, jwt 11 | from sqlalchemy.ext.asyncio import AsyncSession 12 | from starlette import status 13 | 14 | oauth2_scheme = OAuth2PasswordBearer(tokenUrl=settings.TOKEN_URL) 15 | 16 | 17 | async def get_session() -> AsyncIterator[AsyncSession]: 18 | async with async_session() as session: 19 | try: 20 | yield session 21 | except Exception as e: 22 | raise e 23 | finally: 24 | await session.close() 25 | 26 | 27 | async def get_current_user(session: AsyncSession = Depends(get_session), token: str = Depends(oauth2_scheme)) -> models.User: # noqa 28 | 
credentials_exception = HTTPException( 29 | status_code=status.HTTP_401_UNAUTHORIZED, 30 | detail="Could not validate credentials", 31 | headers={"WWW-Authenticate": "Bearer"}, 32 | ) 33 | 34 | try: 35 | payload = jwt.decode(token, settings.SECRET_KEY, algorithms=[settings.ALGORITHM]) 36 | username: str = payload.get("sub") 37 | 38 | if username is None: 39 | raise credentials_exception 40 | 41 | token_data = TokenData(username=username) 42 | 43 | except JWTError: 44 | raise credentials_exception 45 | 46 | user = await crud.user.get_by_username(session, username=token_data.username) 47 | 48 | if user is None: 49 | raise credentials_exception 50 | 51 | return user 52 | 53 | 54 | async def get_current_approved_user(current_user: models.User = Depends(get_current_user)) -> models.User: 55 | if not current_user.permission_level > 0: 56 | raise HTTPException(status_code=400, detail="Not approved.") 57 | 58 | return current_user 59 | -------------------------------------------------------------------------------- /app/api/v1/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hitomi-team/sukima/eb3e8e968b971748890885fcc1287bd27ae8dde0/app/api/v1/__init__.py -------------------------------------------------------------------------------- /app/api/v1/api.py: -------------------------------------------------------------------------------- 1 | from app.api.v1.endpoints import models, soft_prompts, users 2 | from fastapi import APIRouter 3 | 4 | api_router = APIRouter() 5 | 6 | api_router.include_router(users.router, prefix="/users", tags=["users"]) 7 | api_router.include_router(models.router, prefix="/models", tags=["models"]) 8 | api_router.include_router(soft_prompts.router, prefix="/softprompts", tags=["softprompts"]) 9 | -------------------------------------------------------------------------------- /app/api/v1/endpoints/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hitomi-team/sukima/eb3e8e968b971748890885fcc1287bd27ae8dde0/app/api/v1/endpoints/__init__.py -------------------------------------------------------------------------------- /app/api/v1/endpoints/models.py: -------------------------------------------------------------------------------- 1 | import time 2 | 3 | import traceback 4 | 5 | import app.crud.soft_prompt as crud 6 | from app.api.deps import get_current_approved_user, get_session 7 | from app.gpt.berthf import BERTHF 8 | from app.gpt.gpthf import GPTHF 9 | from app.gpt.clip import CLIP 10 | from app.gpt.gooseai import OpenAI 11 | from app.gpt.models import gpt_models 12 | from app.gpt.utils import is_decoder 13 | from app.schemas.model_item import ModelGenRequest, ModelLoadRequest, ModelClassifyRequest, ModelHiddenRequest 14 | from app.schemas.user import User 15 | 16 | from fastapi import APIRouter, Depends, HTTPException 17 | from sqlalchemy.ext.asyncio import AsyncSession 18 | from transformers import AutoConfig 19 | 20 | router = APIRouter() 21 | 22 | 23 | @router.get("/") 24 | async def get_model_list(): 25 | model_dict = {"models": {}} 26 | 27 | for model in gpt_models: 28 | model_dict["models"][model.model_name] = {"ready": True} 29 | 30 | return model_dict 31 | 32 | 33 | @router.post("/load") 34 | async def load_model(request: ModelLoadRequest, current_user: User = Depends(get_current_approved_user)): # noqa 35 | # Check that model exists 36 | if gpt_models is not None: 37 | for m in gpt_models: 38 | if 
m.model_name == request.model:
39 |                 raise HTTPException(status_code=400, detail="Model already loaded")
40 | 
41 |     try:
42 |         try:
43 |             if is_decoder(AutoConfig.from_pretrained(request.model)):
44 |                 model = GPTHF(model_name=request.model, device=request.device, parallelize=request.parallel, sharded=request.sharded, quantized=request.quantized, tensorized=request.tensorize)
45 |             else:
46 |                 if 'clip' not in request.model:
47 |                     model = BERTHF(model_name=request.model, device=request.device, parallelize=request.parallel, sharded=request.sharded, quantized=request.quantized, tensorized=request.tensorize)
48 |                 else:
49 |                     model = CLIP(model_name=request.model, device=request.device, parallelize=request.parallel, sharded=request.sharded, quantized=request.quantized, tensorized=request.tensorize)
50 |         except Exception:
51 |             model = OpenAI(model_name=request.model, decoder=True)
52 | 
53 |         gpt_models.append(model)
54 | 
55 |         return {"message": f"Successfully loaded model: {request.model}"}
56 | 
57 |     except Exception as e:
58 |         raise HTTPException(status_code=400, detail=f"Unable to load the model!\n{e}\n{traceback.format_exc()}")
59 | 
60 | 
61 | @router.post("/generate")
62 | async def generate(request: ModelGenRequest, current_user: User = Depends(get_current_approved_user), session: AsyncSession = Depends(get_session)): # noqa
63 |     for m in gpt_models:
64 |         if m.model_name == request.model:
65 |             db_softprompt = None
66 |             if request.softprompt:
67 |                 db_softprompt = await crud.soft_prompt.get(session, request.softprompt)
68 |                 if db_softprompt is None:
69 |                     raise HTTPException(status_code=400, detail=f"No soft prompt with UUID {request.softprompt} exists!") # noqa
70 |             try:
71 |                 if not m.decoder:
72 |                     raise RuntimeError("This is not a decoder model!")
73 |                 return m.generate(request.dict(), db_softprompt=db_softprompt)
74 |             except Exception as e:
75 |                 raise HTTPException(status_code=400, detail=f"Unable to generate!\n{e}\n{traceback.format_exc()}")
76 | 
77 |     raise HTTPException(status_code=404, detail="Model not found.")
78 | 
79 | @router.post("/classify")
80 | async def classify(request: ModelClassifyRequest, current_user: User = Depends(get_current_approved_user)): # noqa
81 |     for m in gpt_models:
82 |         if m.model_name == request.model:
83 |             try:
84 |                 return m.classify(request.dict())
85 |             except Exception as e:
86 |                 raise HTTPException(status_code=400, detail=f"Invalid request body!\n{e}")
87 | 
88 |     raise HTTPException(status_code=404, detail="Model not found.")
89 | 
90 | @router.post("/hidden")
91 | async def hidden(request: ModelHiddenRequest, current_user: User = Depends(get_current_approved_user)): # noqa
92 |     for m in gpt_models:
93 |         if m.model_name == request.model:
94 |             try:
95 |                 return m.hidden(request.dict())
96 |             except Exception as e:
97 |                 raise HTTPException(status_code=400, detail=f"Invalid request body!\n{e}")
98 | 
99 |     raise HTTPException(status_code=404, detail="Model not found.")
100 | 
--------------------------------------------------------------------------------
/app/api/v1/endpoints/soft_prompts.py:
--------------------------------------------------------------------------------
 1 | import base64
 2 | import json
 3 | 
 4 | from fastapi import APIRouter, Depends, File, HTTPException, UploadFile
 5 | from sqlalchemy.ext.asyncio import AsyncSession
 6 | from typing import Optional
 7 | 
 8 | import app.crud.soft_prompt as crud
 9 | from app.api.deps import get_current_approved_user, get_session, get_current_user
10 | from app.models.user import User
11 | from app.schemas.soft_prompt import
SoftPromptCreate 12 | 13 | router = APIRouter() 14 | 15 | 16 | @router.get("/my") 17 | async def get_user_soft_prompts(current_user: User = Depends(get_current_approved_user), session: AsyncSession = Depends(get_session)): # noqa 18 | soft_prompts = await crud.soft_prompt.get_by_creator(session, creator=current_user) 19 | return [sp.asdict() for sp in soft_prompts] 20 | 21 | 22 | @router.post("/upload") 23 | async def upload_soft_prompt(file: UploadFile = File(...), current_user: User = Depends(get_current_approved_user), session: AsyncSession = Depends(get_session)): # noqa 24 | try: 25 | contents = json.load(file.file) 26 | metadata = SoftPromptCreate(**contents) 27 | data = base64.b64decode(contents["data"]) 28 | except Exception as e: 29 | raise HTTPException(status_code=400, detail=f"Malformed soft prompt JSON\n{e}") 30 | 31 | try: 32 | db_obj = await crud.soft_prompt.upload_soft_prompt(session, creator=current_user, data=data, obj_in=metadata) 33 | except LookupError as e: 34 | raise HTTPException(status_code=400, detail=str(e)) 35 | 36 | return db_obj.asdict() 37 | 38 | 39 | @router.get("/{id}") 40 | async def get_soft_prompt(id: str, export: Optional[bool] = False, current_user: User = Depends(get_current_user), session: AsyncSession = Depends(get_session)): # noqa 41 | # detect json file extension 42 | if id.endswith(".json"): 43 | export = True 44 | id = id[:-5] 45 | 46 | db_obj = await crud.soft_prompt.get(session, id) 47 | 48 | if db_obj is None: 49 | raise HTTPException(status_code=404, detail="Soft prompt not found.") 50 | 51 | if not db_obj.public and current_user.id != db_obj.creator: 52 | raise HTTPException(status_code=403, detail="You are not authorized to view this soft prompt.") 53 | 54 | info = db_obj.asdict() 55 | if export: 56 | # add, remove fields to resemble original JSON that was uploaded 57 | info.pop("id") 58 | info.pop("public") 59 | info["data"] = base64.b64encode(db_obj.read()) 60 | 61 | return info 62 | 63 | 64 | @router.delete("/{id}") 65 | async def delete_soft_prompt(id: str, current_user: User = Depends(get_current_user), session: AsyncSession = Depends(get_session)): # noqa 66 | db_obj = await crud.soft_prompt.get(session, id) 67 | 68 | if db_obj is not None: 69 | if current_user.id != db_obj.creator: 70 | raise HTTPException(status_code=403, detail="You are not allowed to delete this soft prompt.") 71 | 72 | await crud.soft_prompt.remove(session, id=id) 73 | 74 | return {"message": "Deleted soft prompt successfully."} 75 | -------------------------------------------------------------------------------- /app/api/v1/endpoints/users.py: -------------------------------------------------------------------------------- 1 | from datetime import timedelta 2 | 3 | import app.crud.user as crud 4 | from app.api.deps import get_session 5 | from app.core.config import settings 6 | from app.core.security import create_access_token 7 | from app.schemas.user import UserCreate 8 | from fastapi import APIRouter, Depends, HTTPException 9 | from fastapi.security import OAuth2PasswordRequestForm 10 | from sqlalchemy.ext.asyncio import AsyncSession 11 | 12 | router = APIRouter() 13 | 14 | 15 | @router.post("/register") 16 | async def register_user(user: UserCreate, session: AsyncSession = Depends(get_session)): 17 | db_user = await crud.user.get_by_email(session, user.email) 18 | 19 | if not db_user: 20 | await crud.user.create_user(session, obj_in=user) 21 | 22 | return {"message": "Successfully created user."} 23 | 24 | 25 | @router.post("/token") 26 | async def 
generate_token(form_data: OAuth2PasswordRequestForm = Depends(), session: AsyncSession = Depends(get_session)): 27 | user = await crud.user.authenticate(session, username=form_data.username, password=form_data.password) 28 | 29 | if not user: 30 | raise HTTPException(status_code=401, detail="Invalid credentials.") 31 | 32 | expiration = timedelta(days=settings.ACCESS_TOKEN_EXPIRATION) 33 | token = create_access_token({"sub": user.username}, expiration) 34 | 35 | return {"access_token": token, "token_type": "bearer"} 36 | -------------------------------------------------------------------------------- /app/core/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hitomi-team/sukima/eb3e8e968b971748890885fcc1287bd27ae8dde0/app/core/__init__.py -------------------------------------------------------------------------------- /app/core/config.py: -------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | from os import PathLike 3 | from typing import Any, Dict, List, Optional, Union 4 | 5 | from pydantic import AnyHttpUrl, BaseSettings, validator 6 | 7 | 8 | class Settings(BaseSettings): 9 | PROJECT_NAME: str 10 | BACKEND_CORS_ORIGINS: List[AnyHttpUrl] = [] 11 | 12 | SECRET_KEY: str 13 | 14 | @validator("BACKEND_CORS_ORIGINS", pre=True) 15 | def assemble_cors_origins(cls, v: Union[str, List[str]]) -> Union[List[str], str]: 16 | if isinstance(v, str) and not v.startswith("["): 17 | return [i.strip() for i in v.split(",")] 18 | 19 | elif isinstance(v, (list, str)): 20 | return v 21 | 22 | raise ValueError(v) 23 | 24 | POSTGRES_SERVER: str 25 | POSTGRES_USER: str 26 | POSTGRES_PASSWORD: str 27 | POSTGRES_DB: str 28 | DATABASE_URI: Optional[str] = None 29 | STORAGE_PATH: PathLike = Path.cwd() / "storage" 30 | 31 | OPENAI_API_KEY: Optional[str] = None 32 | OPENAI_API_BASE: Optional[str] = None 33 | 34 | @validator("DATABASE_URI", pre=True) 35 | def assemble_db_connection(cls, v: Optional[str], values: Dict[str, Any]) -> Any: 36 | if isinstance(v, str): 37 | return v 38 | 39 | return f"postgresql+asyncpg://{values.get('POSTGRES_USER')}:{values.get('POSTGRES_PASSWORD')}@{values.get('POSTGRES_SERVER')}/{values.get('POSTGRES_DB')}" 40 | 41 | ALGORITHM: str 42 | ACCESS_TOKEN_EXPIRATION: int 43 | TOKEN_URL: str 44 | 45 | class Config: 46 | case_sensitive = True 47 | env_file = ".env" 48 | 49 | 50 | settings = Settings() 51 | -------------------------------------------------------------------------------- /app/core/logging.py: -------------------------------------------------------------------------------- 1 | import logging 2 | 3 | FORMAT = "%(asctime)s %(levelname)s %(filename)s(%(lineno)d) - %(message)s" 4 | logging.basicConfig(format=FORMAT, level=logging.INFO) 5 | logger = logging.getLogger(__name__) 6 | 7 | -------------------------------------------------------------------------------- /app/core/security.py: -------------------------------------------------------------------------------- 1 | from datetime import datetime, timedelta 2 | from typing import Optional 3 | 4 | from app.core.config import settings 5 | from jose import jwt 6 | from passlib.context import CryptContext 7 | 8 | pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto") 9 | 10 | 11 | def create_access_token(data: dict, expires_delta: Optional[timedelta] = None): 12 | to_encode = data.copy() 13 | 14 | if expires_delta: 15 | expire = datetime.utcnow() + expires_delta 16 | else: 17 | expire = 
datetime.utcnow() + timedelta(days=7) # 7 day token expiration by default
18 | 
19 |     to_encode.update({"exp": expire})
20 |     encoded_jwt = jwt.encode(to_encode, settings.SECRET_KEY, algorithm=settings.ALGORITHM)
21 | 
22 |     return encoded_jwt
23 | 
24 | 
25 | def verify_password(plain, hashed):
26 |     return pwd_context.verify(plain, hashed)
27 | 
28 | 
29 | def get_password_hash(plain):
30 |     return pwd_context.hash(plain)
31 | 
--------------------------------------------------------------------------------
/app/crud/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/hitomi-team/sukima/eb3e8e968b971748890885fcc1287bd27ae8dde0/app/crud/__init__.py
--------------------------------------------------------------------------------
/app/crud/base.py:
--------------------------------------------------------------------------------
 1 | from typing import Any, Generic, Optional, Type, TypeVar
 2 | 
 3 | from app.db.base_class import Base
 4 | from fastapi.encoders import jsonable_encoder
 5 | from pydantic import BaseModel
 6 | from sqlalchemy import select
 7 | from sqlalchemy.ext.asyncio import AsyncSession
 8 | 
 9 | ModelType = TypeVar("ModelType", bound=Base)
10 | CreateSchemaType = TypeVar("CreateSchemaType", bound=BaseModel)
11 | UpdateSchemaType = TypeVar("UpdateSchemaType", bound=BaseModel)
12 | 
13 | 
14 | class CrudBase(Generic[ModelType, CreateSchemaType, UpdateSchemaType]):
15 |     def __init__(self, model: Type[ModelType]):
16 |         self.model = model
17 | 
18 |     async def get(self, session: AsyncSession, id: Any) -> Optional[ModelType]:
19 |         return (await session.execute(select(self.model).where(self.model.id == id))).scalars().first()
20 | 
21 |     async def update(self, session: AsyncSession, *, db_obj: ModelType, obj_in: CreateSchemaType) -> ModelType:
22 |         obj_data = jsonable_encoder(db_obj)
23 | 
24 |         if isinstance(obj_in, dict):
25 |             update_data = obj_in
26 |         else:
27 |             update_data = obj_in.dict(exclude_unset=True)
28 | 
29 |         for field in obj_data:
30 |             if field in update_data:
31 |                 setattr(db_obj, field, update_data[field])
32 | 
33 |         session.add(db_obj)
34 |         await session.commit()
35 |         await session.refresh(db_obj)
36 | 
37 |         return db_obj
38 | 
39 |     async def remove(self, session: AsyncSession, *, id: Any) -> Optional[ModelType]:
40 |         # Minimal implementation of the former "TODO: impl" stub: look up the row by
41 |         # primary key and delete it if present. Note that on-disk artifacts owned by
42 |         # the row (e.g. stored soft prompt files) are not cleaned up here.
43 |         db_obj = await self.get(session, id)
44 |         if db_obj is not None:
45 |             await session.delete(db_obj)
46 |             await session.commit()
47 |         return db_obj
48 | 
--------------------------------------------------------------------------------
/app/crud/soft_prompt.py:
--------------------------------------------------------------------------------
 1 | from typing import List
 2 | 
 3 | from sqlalchemy import select
 4 | from sqlalchemy.ext.asyncio import AsyncSession
 5 | from uuid import uuid4
 6 | 
 7 | from app.core.config import settings
 8 | from app.crud.base import CrudBase
 9 | from app.gpt.models import gpt_models
10 | from app.models.soft_prompt import SoftPrompt
11 | from app.models.user import User
12 | from app.schemas.soft_prompt import SoftPromptCreate, SoftPromptUpdate
13 | 
14 | 
15 | class CrudSoftPrompt(CrudBase[SoftPrompt, SoftPromptCreate, SoftPromptUpdate]):
16 |     async def upload_soft_prompt(self, session: AsyncSession, *, creator: User, data: bytes, obj_in: SoftPromptCreate) -> SoftPrompt: # noqa
17 |         # was there supposed to be a database table for this?
18 | model_exists = False 19 | for model in gpt_models: 20 | if model.model_name == obj_in.model: 21 | model_exists = True 22 | break 23 | if not model_exists: 24 | raise LookupError(f"Model {obj_in.model} has not been loaded.") 25 | 26 | db_obj = SoftPrompt( 27 | id=str(uuid4()), 28 | name=obj_in.name, 29 | description=obj_in.description, 30 | public=obj_in.public, 31 | creator=creator.id, 32 | loss=obj_in.loss, 33 | steps=obj_in.steps, 34 | model=obj_in.model 35 | ) 36 | 37 | session.add(db_obj) 38 | await session.commit() 39 | await session.refresh(db_obj) 40 | 41 | db_obj.write(data) 42 | 43 | return db_obj 44 | 45 | async def get_by_creator(self, session: AsyncSession, *, creator: User) -> List[SoftPrompt]: 46 | return (await session.execute(select(self.model).where(self.model.creator == creator.id))).scalars().all() 47 | 48 | 49 | soft_prompt = CrudSoftPrompt(SoftPrompt) 50 | -------------------------------------------------------------------------------- /app/crud/user.py: -------------------------------------------------------------------------------- 1 | from typing import Optional 2 | 3 | from app.crud.base import CrudBase 4 | from app.core.security import get_password_hash, verify_password 5 | from app.models.user import User 6 | from app.schemas.user import UserCreate, UserUpdate 7 | from sqlalchemy import select 8 | from sqlalchemy.ext.asyncio import AsyncSession 9 | 10 | 11 | class CrudUser(CrudBase[User, UserCreate, UserUpdate]): 12 | async def get_by_email(self, session: AsyncSession, email: str) -> Optional[User]: 13 | return (await session.execute(select(self.model).where(self.model.email == email))).scalars().first() 14 | 15 | async def get_by_username(self, session: AsyncSession, username: str) -> Optional[User]: 16 | return (await session.execute(select(self.model).where(self.model.username == username))).scalars().first() 17 | 18 | async def create_user(self, session: AsyncSession, *, obj_in: UserCreate) -> User: 19 | db_obj = User( 20 | username=obj_in.username, 21 | password=get_password_hash(obj_in.password), 22 | email=obj_in.email, 23 | permission_level=obj_in.permission_level 24 | ) 25 | 26 | session.add(db_obj) 27 | await session.commit() 28 | await session.refresh(db_obj) 29 | 30 | return db_obj 31 | 32 | async def authenticate(self, session: AsyncSession, *, username: str, password: str) -> Optional[User]: 33 | db_user = await self.get_by_username(session, username) 34 | 35 | if not db_user: 36 | return None 37 | 38 | if not verify_password(password, db_user.password): 39 | return None 40 | 41 | return db_user 42 | 43 | 44 | user = CrudUser(User) 45 | -------------------------------------------------------------------------------- /app/db/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hitomi-team/sukima/eb3e8e968b971748890885fcc1287bd27ae8dde0/app/db/__init__.py -------------------------------------------------------------------------------- /app/db/base.py: -------------------------------------------------------------------------------- 1 | # Import models first so that Base will have them before being imported by Alembic 2 | from app.db.base_class import Base # noqa 3 | from app.models.soft_prompt import SoftPrompt # noqa 4 | from app.models.user import User # noqa 5 | -------------------------------------------------------------------------------- /app/db/base_class.py: -------------------------------------------------------------------------------- 1 | from typing import 
Any 2 | 3 | from sqlalchemy.ext.declarative import as_declarative, declared_attr 4 | 5 | 6 | @as_declarative() 7 | class Base: 8 | id: Any 9 | __name__: str 10 | 11 | # Generate __tablename__ automatically 12 | @declared_attr 13 | def __tablename__(cls) -> str: 14 | return cls.__name__.lower() 15 | -------------------------------------------------------------------------------- /app/db/database.py: -------------------------------------------------------------------------------- 1 | from app.core.config import settings 2 | from sqlalchemy.ext.asyncio import AsyncSession, create_async_engine 3 | from sqlalchemy.orm import sessionmaker 4 | 5 | engine = create_async_engine( 6 | settings.DATABASE_URI, pool_pre_ping=True 7 | ) 8 | 9 | async_session = sessionmaker( 10 | engine, expire_on_commit=False, class_=AsyncSession 11 | ) 12 | -------------------------------------------------------------------------------- /app/gpt/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hitomi-team/sukima/eb3e8e968b971748890885fcc1287bd27ae8dde0/app/gpt/__init__.py -------------------------------------------------------------------------------- /app/gpt/autohf.py: -------------------------------------------------------------------------------- 1 | class AutoHF: 2 | def __init__(self, model_name='generic', decoder=False): 3 | self.model_name = model_name 4 | self.decoder = decoder 5 | 6 | def generate(self, args): 7 | raise NotImplementedError 8 | 9 | def classify(self, args): 10 | raise NotImplementedError 11 | 12 | def hidden(self, args): 13 | raise NotImplementedError 14 | -------------------------------------------------------------------------------- /app/gpt/berthf.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | from app.core.config import settings 4 | from app.core.logging import logger 5 | 6 | from transformers import AutoConfig, AutoTokenizer, AutoModelForSequenceClassification 7 | 8 | from app.gpt.autohf import AutoHF 9 | from app.gpt.tensorize import tensorize, untensorize 10 | from app.gpt.utils import Checkpoint, get_dtype, tensorized_path 11 | 12 | class BERTHF(AutoHF): 13 | def __init__(self, model_name='distilroberta-base', device=None, parallelize=False, sharded=False, quantized=False, tensorized=False): 14 | super().__init__(model_name=model_name, decoder=False) 15 | 16 | model_dtype = get_dtype(device) 17 | self.device = device 18 | self.tokenizer = AutoTokenizer.from_pretrained(model_name) 19 | self.tensorized = False 20 | 21 | if tensorized: 22 | _path, exists = tensorized_path(model_name) 23 | if exists: 24 | logger.info(f'Loading tensorized model {model_name}') 25 | self.model = untensorize(str(_path), self.device, quantized=quantized) 26 | self.tensorized = True 27 | 28 | if sharded: 29 | self.model = AutoModelForSequenceClassification.from_pretrained( 30 | pretrained_model_name_or_path=None, 31 | config=AutoConfig.from_pretrained(model_name), 32 | state_dict=Checkpoint(model_name, self.device), 33 | torch_dtype=model_dtype 34 | ).eval().to(self.device) 35 | 36 | if (not sharded) and (not quantized) and (not self.tensorized): 37 | self.model = AutoModelForSequenceClassification.from_pretrained( 38 | model_name, 39 | torch_dtype=model_dtype 40 | ).eval().to(self.device) 41 | 42 | if quantized: 43 | raise NotImplementedError('Quantized models are not supported yet for encoder models such as BERT.') 44 | 45 | if (tensorized) and (not self.tensorized): 46 | 
# check if model file exists in ./storage/{model_name}.model 47 | _path, exists = tensorized_path(model_name) 48 | if not exists: 49 | logger.info(f'Tensorizing model {model_name}') 50 | # tensorize model 51 | tensorize(self.model, str(_path)) 52 | del self.model 53 | raise Exception('Tensorized the model! The original model has been altered, please load the model again to use the tensorized model.') 54 | 55 | if parallelize: 56 | raise NotImplementedError('Parallelization is not supported yet for encoder models such as BERT.') 57 | 58 | @torch.inference_mode() 59 | def classify(self, args): 60 | if not isinstance(args, dict): 61 | raise ValueError('args must be a dictionary.') 62 | 63 | if 'prompt' not in args or not isinstance(args['prompt'], str): 64 | raise ValueError('args must contain a prompt as a string.') 65 | 66 | 67 | if "labels" not in args or not isinstance(args["labels"], list): 68 | raise ValueError("args must contain a list of labels") 69 | 70 | for label in args["labels"]: 71 | if not isinstance(label, str): 72 | raise ValueError("labels must be a list of integers") 73 | 74 | prompt_inputs = self.tokenizer.encode(args['prompt'], return_tensors='pt').to(self.device) 75 | 76 | outputs = self.model(prompt_inputs).logits 77 | outputs = torch.nn.functional.softmax(outputs, dim=1) 78 | outputs = outputs.detach().cpu().numpy() 79 | output_probs = {} 80 | 81 | # TODO: automatically fill labels 82 | 83 | for i in range(len(args["labels"])): 84 | output_probs[args["labels"][i]] = float(outputs[0][i]) 85 | 86 | return output_probs 87 | 88 | @torch.inference_mode() 89 | def hidden(self, args): 90 | # args: 91 | # prompt: str - prompt to extract hidden states from 92 | # layers: int - number of last hidden layers to return 93 | 94 | if not isinstance(args, dict): 95 | raise ValueError('args must be a dictionary.') 96 | 97 | if 'prompt' not in args or not isinstance(args['prompt'], str): 98 | raise ValueError('args must contain a prompt as a string.') 99 | 100 | if 'layers' not in args or not isinstance(args['layers'], list): 101 | raise ValueError('layers must be the last n hidden layers to return.') 102 | 103 | prompt_inputs = self.tokenizer.encode(args['prompt'], return_tensors='pt').to(self.device) 104 | 105 | hidden_states = self.model(prompt_inputs, output_hidden_states=True).hidden_states 106 | layers = {i: torch.mean(hidden_states[i], dim = (1, )).detach().cpu().numpy().tolist() for i in args['layers']} 107 | 108 | return layers 109 | -------------------------------------------------------------------------------- /app/gpt/clip.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import requests 3 | import numpy as np 4 | from transformers import CLIPProcessor, CLIPModel 5 | from PIL import Image 6 | 7 | from app.core.config import settings 8 | from app.core.logging import logger 9 | from app.gpt.autohf import AutoHF 10 | from app.gpt.tensorize import tensorize, untensorize 11 | from app.gpt.utils import Checkpoint, get_dtype, tensorized_path 12 | 13 | class CLIP(AutoHF): 14 | def __init__(self, model_name='openai/clip-vit-base-patch32', device=None, parallelize=False, sharded=False, quantized=False, tensorized=False): 15 | super().__init__(model_name=model_name, decoder=False) 16 | self.device = device 17 | self.tensorized = False 18 | processor_name = None 19 | if not model_name.startswith('openai'): 20 | processor_name = 'openai/'+model_name 21 | else: 22 | processor_name = model_name 23 | self.processor = 
CLIPProcessor.from_pretrained(processor_name) 24 | 25 | if tensorized: 26 | # check if tensorized model already exists so we can skip expensive model loading below 27 | _path, exists = tensorized_path(model_name) 28 | if exists: 29 | logger.info(f'Loading tensorized model {model_name}') 30 | self.model = untensorize(str(_path), self.device, quantized=quantized) 31 | self.tensorized = True 32 | 33 | if (not quantized) and (not self.tensorized): 34 | self.model = CLIPModel.from_pretrained(model_name).eval().to(device) 35 | 36 | if (tensorized) and (not self.tensorized): 37 | # check if model file exists in ./storage/{model_name}.model 38 | _path, exists = tensorized_path(model_name) 39 | if not exists: 40 | logger.info(f'Tensorizing model {model_name}') 41 | # tensorize model 42 | tensorize(self.model, str(_path)) 43 | del self.model 44 | raise Exception('Tensorized the model! The original model has been altered, please load the model again to use the tensorized model.') 45 | 46 | if parallelize: 47 | self.model.parallelize() 48 | 49 | @torch.inference_mode() 50 | def _text_feats(self, in_text: str): 51 | text_tokens = self.processor(text=in_text, return_tensors='pt', padding=True)['input_ids'].to(self.device) 52 | result = self.model.get_text_features(input_ids=text_tokens).cpu().detach().numpy() 53 | return (result / np.linalg.norm(result, axis=1, keepdims=True)).squeeze(axis=0) 54 | 55 | @torch.inference_mode() 56 | def _img_feats(self, url: str): 57 | image = Image.open(requests.get(url, stream=True).raw).convert('RGB') 58 | inputs = self.processor(images=image, return_tensors='pt')['pixel_values'].to(self.device) 59 | result = self.model.get_image_features(pixel_values=inputs).cpu().detach().numpy() 60 | return (result / np.linalg.norm(result)).squeeze(axis=0) 61 | 62 | def _sim(self, text_feats, img_feats): 63 | return np.dot(img_feats, text_feats)/(np.sqrt(np.linalg.norm(img_feats))*np.sqrt(np.linalg.norm(text_feats))) 64 | 65 | def classify(self, args): 66 | if 'prompt' not in args or not isinstance(args['prompt'], str): 67 | raise ValueError('args must contain a prompt url as a string') 68 | if 'labels' not in args or not isinstance(args['labels'], list): 69 | raise ValueError('args must contain a list of labels') 70 | for label in args['labels']: 71 | if not isinstance(label, list): 72 | raise ValueError('labels must be a list of lists containing strings') 73 | 74 | img_feats = self._img_feats(args['prompt']) 75 | 76 | compiled_similarities = [] 77 | for label_set in range(len(args['labels'])): 78 | label_similarities = {} 79 | for label in args['labels'][label_set]: 80 | text_feats = self._text_feats(label) 81 | label_similarities[label] = self._sim(text_feats, img_feats).item() 82 | compiled_similarities.append(label_similarities) 83 | output = {'labels': compiled_similarities} 84 | 85 | return output 86 | -------------------------------------------------------------------------------- /app/gpt/engram.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | 4 | class Engrams: 5 | def __init__(self, model=None, tokenizer=None): 6 | self.model = model 7 | self.tokenizer = tokenizer 8 | 9 | def build(forward, tokens, shift=10000, factor=20000, rampdown=lambda x: x / 2): 10 | h = list(forward(input_ids=tokens[:, -512:].long().cuda(), output_hidden_states=True).hidden_states[1:]) 11 | f = 0 12 | fa = 1.0 / float(len(h)) 13 | 14 | for layer in range(len(h)): 15 | f = f + fa 16 | h[layer] = 
torch.mean(h[layer].detach().double(), dim=(1, )) * f 17 | 18 | h = torch.sum(torch.stack(h, axis=1)[0], dim=(0, )) 19 | return ((h + shift) / factor).float().to("cpu").numpy() 20 | -------------------------------------------------------------------------------- /app/gpt/gooseai.py: -------------------------------------------------------------------------------- 1 | import math 2 | from typing import Optional 3 | from app.core.config import settings 4 | from app.gpt.autohf import AutoHF 5 | from app.models.soft_prompt import SoftPrompt as SoftPromptModel 6 | 7 | from transformers import AutoTokenizer 8 | 9 | import openai 10 | 11 | class OpenAI(AutoHF): 12 | def __init__(self, model_name='convo-6b', decoder=True): 13 | self.model_name = model_name 14 | self.decoder = decoder 15 | 16 | self.tokenizer = AutoTokenizer.from_pretrained('gpt2') # might wanna check if this is the same as the one used in the model, 20B uses a different tokenizer 17 | # tokenizer is just going to be used to convert eos_token_id to stop token string 18 | 19 | openai.api_key = settings.OPENAI_API_KEY 20 | openai.api_base = settings.OPENAI_API_BASE # give users option to use openai if they are financial masochists 21 | 22 | engines = openai.Engine.list() 23 | for e in engines.data: 24 | if e.id == self.model_name: 25 | self.engine = e 26 | break 27 | 28 | if self.engine is None: 29 | raise ValueError(f'OpenAI/GooseAI engine {self.model_name} not found') 30 | 31 | def generate(self, args, *, db_softprompt: Optional[SoftPromptModel] = None): 32 | prompt = args.get('prompt', None) 33 | if prompt is None: 34 | prompt = '<|endoftext|>' 35 | 36 | # Sample arguments 37 | 38 | sample_args = args.get('sample_args', None) 39 | if sample_args is None: 40 | raise ValueError('sample_args is required') 41 | 42 | temperature = sample_args.get('temperature', 1.0) 43 | top_p = sample_args.get('top_p', 1.0) 44 | top_k = sample_args.get('top_k', 128) 45 | tfs = sample_args.get('tfs', 0.99) 46 | repetition_penalty = sample_args.get('rep_p', 1.0) 47 | 48 | logit_bias = {} 49 | bad_words = sample_args.get('bad_words', None) 50 | if bad_words is not None: 51 | for bad_word in bad_words: 52 | bad_word = self.tokenizer.encode(bad_word) 53 | if len(bad_word) > 1: 54 | for bad_word_token_idx in range(len(bad_word)): 55 | logit_bias[str(bad_word[bad_word_token_idx])] = -math.sin((math.pi*(bad_word_token_idx/(len(bad_word)-1)))/2) * 100.0 56 | else: 57 | logit_bias[str(bad_word[0])] = -100.0 58 | 59 | logit_biases = sample_args.get('logit_biases', None) 60 | if logit_biases is not None: 61 | for bias in logit_biases: 62 | logit_bias[str(bias['id'])] = bias['bias'] 63 | 64 | # Generation arguments 65 | 66 | gen_args = args.get('gen_args', None) 67 | if gen_args is None: 68 | raise ValueError('gen_args is required') 69 | 70 | max_tokens = gen_args.get('max_length', None) 71 | if max_tokens is None: 72 | raise ValueError('max_length is required') 73 | 74 | stop = gen_args.get('eos_token_id', None) 75 | if stop is not None: 76 | stop = self.tokenizer.decode(stop) 77 | 78 | # Generate 79 | 80 | output = {} 81 | 82 | output['output'] = prompt + openai.Completion.create( 83 | engine = self.engine.id, 84 | prompt = prompt, 85 | max_tokens = max_tokens, 86 | stop = stop, 87 | temperature = temperature, 88 | top_p = top_p, 89 | top_k = top_k, 90 | tfs = tfs, 91 | repetition_penalty = repetition_penalty, 92 | logit_bias = logit_bias, 93 | ).choices[0].text 94 | 95 | return output 96 | 97 | def classify(self, args): 98 | raise NotImplementedError 99 | 100 | 
def hidden(self, args): 101 | raise NotImplementedError 102 | -------------------------------------------------------------------------------- /app/gpt/gpthf.py: -------------------------------------------------------------------------------- 1 | from typing import Optional 2 | import torch 3 | 4 | from app.core.config import settings 5 | from app.core.logging import logger 6 | from app.gpt.autohf import AutoHF 7 | from app.gpt.softprompt import SoftPrompt, AutoModelForSoftPromptLM, current_sp, resize_model_embeddings 8 | from app.gpt.tensorize import tensorize, untensorize 9 | from app.gpt.utils import Checkpoint, get_dtype, tensorized_path 10 | from app.gpt.warpers import * 11 | from app.models.soft_prompt import SoftPrompt as SoftPromptModel 12 | from transformers import (AutoConfig, AutoTokenizer, 13 | LogitsProcessorList, MaxLengthCriteria, 14 | MaxTimeCriteria, NoBadWordsLogitsProcessor, 15 | StoppingCriteriaList, TemperatureLogitsWarper, 16 | TopKLogitsWarper, TopPLogitsWarper, MinLengthLogitsProcessor) 17 | 18 | from pathlib import Path 19 | 20 | import numpy as np 21 | import zlib 22 | 23 | try: 24 | import transformers 25 | from app.gpt.quantization import GPTJBlock, GPTJForCausalLM 26 | except ImportError: 27 | pass # don't do quantization 28 | 29 | class GPTHF(AutoHF): 30 | def __init__(self, model_name='hakurei/gpt-j-random-tinier', device=None, parallelize=False, sharded=False, quantized=False, tensorized=False): 31 | super().__init__(model_name=model_name, decoder=True) 32 | 33 | model_dtype = get_dtype(device) 34 | self.device = device 35 | self.tokenizer = AutoTokenizer.from_pretrained(model_name) 36 | self.tensorized = False 37 | 38 | if tensorized: 39 | # check if tensorized model already exists so we can skip expensive model loading below 40 | _path, exists = tensorized_path(model_name) 41 | if exists: 42 | logger.info(f'Loading tensorized model {model_name}') 43 | self.model = untensorize(str(_path), self.device, quantized=quantized) 44 | self.tensorized = True 45 | 46 | if sharded: 47 | model_cfg = AutoConfig.from_pretrained(model_name, return_dict_in_generate=True) 48 | self.model = AutoModelForSoftPromptLM.from_pretrained( 49 | pretrained_model_name_or_path=None, config=model_cfg, state_dict=Checkpoint(model_name, self.device), torch_dtype=model_dtype 50 | ).eval().to(self.device) 51 | elif (not sharded) and (not quantized) and (not self.tensorized): 52 | self.model = AutoModelForSoftPromptLM.from_pretrained(model_name, return_dict_in_generate=True, torch_dtype=model_dtype).eval().to(self.device) 53 | 54 | if quantized: 55 | self.quantized = True 56 | logger.info(f'Quantizing model {model_name}') 57 | # we assume this is a gptj model - TODO: fix this 58 | transformers.models.gptj.modeling_gptj.GPTJBlock = GPTJBlock # monkey-patch GPT-J 59 | if not self.tensorized: 60 | self.model = GPTJForCausalLM.from_pretrained(model_name, low_cpu_mem_usage=True, return_dict_in_generate=True).eval().to(self.device) 61 | logger.info(f'Quantization complete.') 62 | else: 63 | self.quantized = False 64 | 65 | if (tensorized) and (not self.tensorized): 66 | # check if model file exists in ./storage/{model_name}.model 67 | _path, exists = tensorized_path(model_name) 68 | if not exists: 69 | logger.info(f'Tensorizing model {model_name}') 70 | # tensorize model 71 | tensorize(self.model, str(_path)) 72 | del self.model 73 | raise Exception('Tensorized the model! 
The original model has been altered, please load the model again to use the tensorized model.') 74 | 75 | if parallelize: 76 | self.model.parallelize() 77 | 78 | @torch.inference_mode() 79 | def generate(self, args, *, db_softprompt: Optional[SoftPromptModel] = None): 80 | logits_warpers = [] 81 | logits_processors = [] 82 | stopping_criterion = [] 83 | eos_token_id = None 84 | softprompt = None 85 | output_scores = False 86 | best_of = None 87 | prompt_length = None 88 | 89 | # Check if args are valid since it's a dictionary 90 | if not isinstance(args, dict): 91 | raise TypeError("Arguments must be a dictionary") 92 | 93 | if db_softprompt: 94 | tbuf = np.frombuffer(zlib.decompress(db_softprompt.read()), dtype=np.float16) 95 | tensor = torch.from_numpy(np.array(tbuf).reshape(20, len(tbuf)//20)).to(self.device) 96 | softprompt = SoftPrompt(softembedding=tensor) 97 | sp_ids = [[id] for id in softprompt.get_special_ids()] 98 | logits_processors.append(NoBadWordsLogitsProcessor(sp_ids, None)) 99 | 100 | if "prompt" not in args: 101 | raise KeyError("Arguments must contain a prompt") 102 | else: 103 | if softprompt: 104 | prompt = softprompt.get_special_str() + args["prompt"] 105 | else: 106 | prompt = args["prompt"] 107 | 108 | if "gen_args" not in args: 109 | raise KeyError("Arguments must contain generation arguments") 110 | 111 | if "sample_args" not in args: 112 | raise KeyError("Arguments must contain sampling arguments") 113 | 114 | # Stopping criteria 115 | if "max_length" in args["gen_args"] and args["gen_args"]["max_length"]: 116 | if not isinstance(args["gen_args"]["max_length"], int) or args["gen_args"]["max_length"] < 0: 117 | raise TypeError("max_length must be a positive integer") 118 | 119 | prompt_length = len(self.tokenizer.encode(args["prompt"])) 120 | if softprompt: 121 | prompt_length += 20 122 | stopping_criterion.append(MaxLengthCriteria(args["gen_args"]["max_length"] + prompt_length)) 123 | 124 | if "max_time" in args["gen_args"] and args["gen_args"]["max_time"]: 125 | if not isinstance(args["gen_args"]["max_time"], float) or args["gen_args"]["max_time"] < 0.0: 126 | raise TypeError("max_time must be a positive float") 127 | 128 | stopping_criterion.append(MaxTimeCriteria(args["gen_args"]["max_time"])) 129 | 130 | if "eos_token_id" in args["gen_args"] and args["gen_args"]["eos_token_id"]: 131 | if not isinstance(args["gen_args"]["eos_token_id"], int) or args["gen_args"]["eos_token_id"] < 0: 132 | raise TypeError("eos_token_id must be a positive integer") 133 | 134 | eos_token_id = args["gen_args"]["eos_token_id"] 135 | 136 | if "min_length" in args["gen_args"] and args["gen_args"]["min_length"]: 137 | if not isinstance(args["gen_args"]["min_length"], int) or args["gen_args"]["min_length"] > args["gen_args"]["max_length"]: 138 | raise TypeError("min_length must be an integer less than max_length.") 139 | 140 | logits_processors.append(MinLengthLogitsProcessor(args["gen_args"]["min_length"], eos_token_id)) 141 | 142 | if "logprobs" in args["gen_args"] and args["gen_args"]["logprobs"]: 143 | if not isinstance(args["gen_args"]["logprobs"], int) or args["gen_args"]["logprobs"] < 0 or args["gen_args"]["logprobs"] > 20: 144 | raise TypeError("logprobs must be an integer between 0 and 20.") 145 | output_scores = True 146 | 147 | if "best_of" in args["gen_args"] and args["gen_args"]["best_of"]: 148 | if not isinstance(args["gen_args"]["best_of"], int) or args["gen_args"]["best_of"] < 0: 149 | raise TypeError("best_of must be a positive integer.") 150 | best_of = 
args["gen_args"]["best_of"] 151 | output_scores = True 152 | 153 | if len(stopping_criterion) == 0: 154 | raise ValueError("Generation arguments must contain at least one stopping criteria such as max_length or max_time.") 155 | 156 | # Warpers 157 | if "temp" in args["sample_args"] and args["sample_args"]["temp"]: 158 | if not isinstance(args["sample_args"]["temp"], float) or (args["sample_args"]["temp"] < 0.0): 159 | raise ValueError("Temperature must be a float greater than 0.0") 160 | 161 | logits_warpers.append(TemperatureLogitsWarper(args["sample_args"]["temp"])) 162 | 163 | if "top_p" in args["sample_args"] and args["sample_args"]["top_p"]: 164 | if not isinstance(args["sample_args"]["top_p"], float) or (args["sample_args"]["top_p"] < 0.0 or args["sample_args"]["top_p"] > 1.0): 165 | raise ValueError("top_p must be a float between 0 and 1") 166 | 167 | logits_warpers.append(TopPLogitsWarper(args["sample_args"]["top_p"])) 168 | 169 | if "top_k" in args["sample_args"] and args["sample_args"]["top_k"]: 170 | if not isinstance(args["sample_args"]["top_k"], int) or (args["sample_args"]["top_k"] <= 0): 171 | raise ValueError("top_k must be a positive integer") 172 | 173 | logits_warpers.append(TopKLogitsWarper(args["sample_args"]["top_k"])) 174 | 175 | if "top_a" in args["sample_args"] and args["sample_args"]["top_a"]: 176 | if not isinstance(args["sample_args"]["top_a"], float) or (args["sample_args"]["top_a"] < 0.0 or args["sample_args"]["top_a"] > 1.0): 177 | raise ValueError("top_a must be a float between 0 and 1") 178 | 179 | logits_warpers.append(TopALogitsWarper(args["sample_args"]["top_a"])) 180 | 181 | if "typical_p" in args["sample_args"] and args["sample_args"]["typical_p"]: 182 | if not isinstance(args["sample_args"]["typical_p"], float) or (args["sample_args"]["typical_p"] < 0.0 or args["sample_args"]["typical_p"] > 1.0): 183 | raise ValueError("typical_p must be a float between 0 and 1") 184 | 185 | logits_warpers.append(TypicalLogitsWarper(args["sample_args"]["typical_p"])) 186 | 187 | if "tfs" in args["sample_args"] and args["sample_args"]["tfs"]: 188 | if not isinstance(args["sample_args"]["tfs"], float) or (args["sample_args"]["tfs"] < 0.0 or args["sample_args"]["tfs"] > 1.0): 189 | raise ValueError("tfs must be a float between 0 and 1") 190 | 191 | logits_warpers.append(TailFreeSamplingLogitsWarper(args["sample_args"]["tfs"])) 192 | 193 | # Processors 194 | if "rep_p" in args["sample_args"] and args["sample_args"]["rep_p"]: 195 | rep_slope = None 196 | rep_range = None 197 | 198 | if "rep_p_slope" in args["sample_args"] and args["sample_args"]["rep_p_slope"]: 199 | if not isinstance(args["sample_args"]["rep_p_slope"], float) or args["sample_args"]["rep_p_slope"] < 0.0: 200 | raise ValueError("rep_p_slope must be a positive float.") 201 | 202 | rep_slope = args["sample_args"]["rep_p_slope"] 203 | 204 | if "rep_p_range" in args["sample_args"] and args["sample_args"]["rep_p_range"]: 205 | if not isinstance(args["sample_args"]["rep_p_range"], int) or args["sample_args"]["rep_p_range"] < 0: 206 | raise ValueError("rep_p_range must be a positive integer.") 207 | 208 | rep_range = args["sample_args"]["rep_p_range"] 209 | 210 | logits_processors.append(RepetitionPenaltyLogitsProcessor(penalty=args["sample_args"]["rep_p"], slope=rep_slope, penalize_last=rep_range)) 211 | 212 | if "bad_words" in args["sample_args"] and args["sample_args"]["bad_words"]: 213 | if not isinstance(args["sample_args"]["bad_words"], list): 214 | raise ValueError("bad_words must be a non-empty list") 215 
| 216 | bad_words_ids = [] 217 | 218 | for bad_word in args["sample_args"]["bad_words"]: 219 | if not isinstance(bad_word, str): 220 | raise ValueError("bad_words must be a list of strings") 221 | 222 | bad_words_ids.append(self.tokenizer.encode(bad_word)) 223 | 224 | logits_processors.append(NoBadWordsLogitsProcessor(bad_words_ids, None)) 225 | 226 | if "logit_biases" in args["sample_args"] and args["sample_args"]["logit_biases"]: 227 | if not isinstance(args["sample_args"]["logit_biases"], list): 228 | raise ValueError("logit_biases must be a list") 229 | 230 | logit_biases = [] 231 | 232 | for logit_bias in args["sample_args"]["logit_biases"]: 233 | if not isinstance(logit_bias, dict) or "id" not in logit_bias or "bias" not in logit_bias: 234 | raise ValueError("logit_biases must be a list of dicts with keys 'id' and 'bias'") 235 | 236 | if not isinstance(logit_bias["id"], int): 237 | raise ValueError("logit_biases 'id' must be an integer") 238 | 239 | if not isinstance(logit_bias["bias"], float): 240 | raise ValueError("logit_biases 'bias' must be a float") 241 | 242 | logit_biases.append((logit_bias["id"], logit_bias["bias"])) 243 | 244 | logits_processors.append(LogitBiasProcessor(logit_biases)) 245 | 246 | if "phrase_biases" in args["sample_args"] and args["sample_args"]["phrase_biases"]: 247 | if not isinstance(args["sample_args"]["phrase_biases"], list): 248 | raise ValueError("phrase_biases must be a non-empty list") 249 | 250 | for bias in args["sample_args"]["phrase_biases"]: 251 | if not isinstance(bias, dict): 252 | raise ValueError("biases must be a list of dictionaries") 253 | 254 | if "sequences" not in bias or not isinstance(bias["sequences"], list): 255 | raise ValueError("phrase_biases must be a list of dictionaries with sequences") 256 | 257 | if "bias" not in bias or not isinstance(bias["bias"], float): 258 | raise ValueError("biases must be a list of dictionaries with a bias key") 259 | 260 | if "ensure_sequence_finish" not in bias or not isinstance(bias["ensure_sequence_finish"], bool): 261 | raise ValueError("biases must be a list of dictionaries with an ensure_sequence_finish key") 262 | 263 | if "generate_once" not in bias or not isinstance(bias["generate_once"], bool): 264 | raise ValueError("biases must be a list of dictionaries with a generate_once key") 265 | 266 | logits_processors.append(PhraseBiasProcessor([self.tokenizer.encode(sequence) for sequence in bias["sequences"]], bias["bias"], bias["ensure_sequence_finish"], bias["generate_once"])) 267 | 268 | logits_warper = LogitsProcessorList(logits_warpers) 269 | logits_processor = LogitsProcessorList(logits_processors) 270 | stopping_criteria = StoppingCriteriaList(stopping_criterion) 271 | 272 | # Generate 273 | output = {} 274 | best_of_idx = 0 275 | 276 | global current_sp 277 | current_sp = softprompt 278 | if softprompt: 279 | sp_tokenizer = softprompt.get_tokenizer(self.tokenizer) 280 | resize_model_embeddings(self.model, sp_tokenizer) 281 | input_ids = sp_tokenizer(prompt, return_tensors='pt').to(self.device) 282 | else: 283 | resize_model_embeddings(self.model, self.tokenizer) 284 | input_ids = self.tokenizer(prompt, return_tensors='pt').to(self.device) 285 | 286 | outputs = None 287 | if best_of is None: 288 | outputs = self.model.sample( 289 | **input_ids, 290 | logits_warper=logits_warper, 291 | logits_processor=logits_processor, 292 | stopping_criteria=stopping_criteria, 293 | pad_token_id=self.tokenizer.eos_token_id, 294 | eos_token_id=eos_token_id, 295 | output_scores=output_scores 296 | ) 297 
| else: 298 | best_of_outputs = [] 299 | best_of_sequences = [] 300 | for i in range(best_of): 301 | outputs = self.model.sample( 302 | **input_ids, 303 | logits_warper=logits_warper, 304 | logits_processor=logits_processor, 305 | stopping_criteria=stopping_criteria, 306 | pad_token_id=self.tokenizer.eos_token_id, 307 | eos_token_id=eos_token_id, 308 | output_scores=output_scores 309 | ) 310 | scores = [] 311 | for token_idx in range(len(outputs.sequences[0]) - prompt_length): 312 | scores.append(outputs.scores[token_idx][0][outputs.sequences[0][token_idx + prompt_length]].detach().item()) 313 | best_of_sequences.append(torch.tensor(scores).mean().detach().item()) 314 | best_of_outputs.append(outputs) 315 | best_of_idx = best_of_sequences.index(max(best_of_sequences)) 316 | outputs = best_of_outputs[best_of_idx] 317 | 318 | output["output"] = self.tokenizer.decode(outputs.sequences[0], skip_special_tokens=True) 319 | # if softprompt: 320 | # output["output"] = output["output"][len(softprompt.get_special_str()):] 321 | 322 | if "logprobs" in args["gen_args"] and args["gen_args"]["logprobs"]: 323 | if not isinstance(args["gen_args"]["logprobs"], int) or args["gen_args"]["logprobs"] < 0 or args["gen_args"]["logprobs"] > 20: 324 | pass 325 | logprobs = [] 326 | for i in range(len(outputs.scores)): 327 | logprobs_seq = [] 328 | scores_probs = outputs.scores[i].softmax(-1).topk(args["gen_args"]["logprobs"], dim=-1).values.tolist() 329 | scores_indices = outputs.scores[i].topk(args["gen_args"]["logprobs"], dim=-1).indices.tolist() 330 | for j in range(args['gen_args']['logprobs']): 331 | logprobs_seq.append((scores_indices[0][j], scores_probs[0][j])) 332 | logprobs.append(logprobs_seq) 333 | output["logprobs"] = logprobs 334 | 335 | return output 336 | 337 | @torch.inference_mode() 338 | def classify(self, args): 339 | if not isinstance(args, dict): 340 | raise ValueError("args must be a dictionary") 341 | 342 | if "prompt" not in args or not isinstance(args["prompt"], str): 343 | raise ValueError("args must contain a prompt") 344 | 345 | if "labels" not in args or not isinstance(args["labels"], list): 346 | raise ValueError("args must contain a list of labels") 347 | 348 | for label in args["labels"]: 349 | if not isinstance(label, str): 350 | raise ValueError("labels must be a list of integers") 351 | 352 | prompt_inputs = self.tokenizer(args["prompt"], return_tensors='pt').input_ids.to(self.device) 353 | 354 | output_probs = {} 355 | for i in args["labels"]: 356 | label_inputs = self.tokenizer(i, return_tensors='pt').input_ids.to(self.device) 357 | probs = self.model.forward(input_ids=torch.cat([prompt_inputs, label_inputs], dim=-1)).logits.softmax(-1)[0][-len(label_inputs[0]):] 358 | token_probs = [probs[t][label_inputs[0][t]] for t in range(0, len(label_inputs[0]))] 359 | output_probs[i] = torch.mean(torch.stack(token_probs, dim=-1)).item() 360 | 361 | return output_probs 362 | 363 | @torch.inference_mode() 364 | def hidden(self, args): 365 | # args: 366 | # prompt: str - prompt to extract hidden states from 367 | # layers: int - number of last hidden layers to return 368 | 369 | if not isinstance(args, dict): 370 | raise ValueError('args must be a dictionary.') 371 | 372 | if 'prompt' not in args or not isinstance(args['prompt'], str): 373 | raise ValueError('args must contain a prompt as a string.') 374 | 375 | if 'layers' not in args or not isinstance(args['layers'], list): 376 | raise ValueError('layers must be the last n hidden layers to return.') 377 | 378 | prompt_inputs = 
self.tokenizer.encode(args['prompt'], return_tensors='pt').to(self.device) 379 | 380 | hidden_states = self.model(prompt_inputs, output_hidden_states=True).hidden_states 381 | layers = {i: torch.mean(hidden_states[i], dim = (1, )).detach().cpu().numpy().tolist() for i in args['layers']} 382 | 383 | return layers 384 | -------------------------------------------------------------------------------- /app/gpt/models.py: -------------------------------------------------------------------------------- 1 | gpt_models = [] 2 | -------------------------------------------------------------------------------- /app/gpt/quantization.py: -------------------------------------------------------------------------------- 1 | import transformers 2 | import torch 3 | import torch.nn.functional as F 4 | from torch import nn 5 | from torch.cuda.amp import custom_fwd, custom_bwd 6 | 7 | from bitsandbytes.functional import quantize_blockwise, dequantize_blockwise 8 | 9 | class FrozenBNBLinear(nn.Module): 10 | def __init__(self, weight, absmax, code, bias=None): 11 | assert isinstance(bias, nn.Parameter) or bias is None 12 | super().__init__() 13 | self.out_features, self.in_features = weight.shape 14 | self.register_buffer("weight", weight.requires_grad_(False)) 15 | self.register_buffer("absmax", absmax.requires_grad_(False)) 16 | self.register_buffer("code", code.requires_grad_(False)) 17 | self.adapter = None 18 | self.bias = bias 19 | 20 | def forward(self, input): 21 | output = DequantizeAndLinear.apply(input, self.weight, self.absmax, self.code, self.bias) 22 | if self.adapter: 23 | output += self.adapter(input) 24 | return output 25 | 26 | @classmethod 27 | def from_linear(cls, linear: nn.Linear) -> "FrozenBNBLinear": 28 | weights_int8, state = quantize_blockise_lowmemory(linear.weight) 29 | return cls(weights_int8, *state, linear.bias) 30 | 31 | def __repr__(self): 32 | return f"{self.__class__.__name__}({self.in_features}, {self.out_features})" 33 | 34 | 35 | class DequantizeAndLinear(torch.autograd.Function): 36 | @staticmethod 37 | @custom_fwd 38 | def forward(ctx, input: torch.Tensor, weights_quantized: torch.ByteTensor, 39 | absmax: torch.FloatTensor, code: torch.FloatTensor, bias: torch.FloatTensor): 40 | weights_deq = dequantize_blockwise(weights_quantized, absmax=absmax, code=code) 41 | ctx.save_for_backward(input, weights_quantized, absmax, code) 42 | ctx._has_bias = bias is not None 43 | return F.linear(input, weights_deq, bias) 44 | 45 | @staticmethod 46 | @custom_bwd 47 | def backward(ctx, grad_output: torch.Tensor): 48 | assert not ctx.needs_input_grad[1] and not ctx.needs_input_grad[2] and not ctx.needs_input_grad[3] 49 | input, weights_quantized, absmax, code = ctx.saved_tensors 50 | # grad_output: [*batch, out_features] 51 | weights_deq = dequantize_blockwise(weights_quantized, absmax=absmax, code=code) 52 | grad_input = grad_output @ weights_deq 53 | grad_bias = grad_output.flatten(0, -2).sum(dim=0) if ctx._has_bias else None 54 | return grad_input, None, None, None, grad_bias 55 | 56 | 57 | class FrozenBNBEmbedding(nn.Module): 58 | def __init__(self, weight, absmax, code): 59 | super().__init__() 60 | self.num_embeddings, self.embedding_dim = weight.shape 61 | self.register_buffer("weight", weight.requires_grad_(False)) 62 | self.register_buffer("absmax", absmax.requires_grad_(False)) 63 | self.register_buffer("code", code.requires_grad_(False)) 64 | self.adapter = None 65 | 66 | def forward(self, input, **kwargs): 67 | with torch.no_grad(): 68 | # note: both quantuized weights and 
input indices are *not* differentiable 69 | weight_deq = dequantize_blockwise(self.weight, absmax=self.absmax, code=self.code) 70 | output = F.embedding(input, weight_deq, **kwargs) 71 | if self.adapter: 72 | output += self.adapter(input) 73 | return output 74 | 75 | @classmethod 76 | def from_embedding(cls, embedding: nn.Embedding) -> "FrozenBNBEmbedding": 77 | weights_int8, state = quantize_blockise_lowmemory(embedding.weight) 78 | return cls(weights_int8, *state) 79 | 80 | def __repr__(self): 81 | return f"{self.__class__.__name__}({self.num_embeddings}, {self.embedding_dim})" 82 | 83 | 84 | def quantize_blockise_lowmemory(matrix: torch.Tensor, chunk_size: int = 2 ** 20): 85 | assert chunk_size % 4096 == 0 86 | code = None 87 | chunks = [] 88 | absmaxes = [] 89 | flat_tensor = matrix.view(-1) 90 | for i in range((matrix.numel() - 1) // chunk_size + 1): 91 | input_chunk = flat_tensor[i * chunk_size: (i + 1) * chunk_size].clone() 92 | quantized_chunk, (absmax_chunk, code) = quantize_blockwise(input_chunk, code=code) 93 | chunks.append(quantized_chunk) 94 | absmaxes.append(absmax_chunk) 95 | 96 | matrix_i8 = torch.cat(chunks).reshape_as(matrix) 97 | absmax = torch.cat(absmaxes) 98 | return matrix_i8, (absmax, code) 99 | 100 | 101 | def convert_to_int8(model): 102 | """Convert linear and embedding modules to 8-bit with optional adapters""" 103 | for module in model.modules(): 104 | for name, child in module.named_children(): 105 | if isinstance(child, nn.Linear): 106 | setattr( 107 | module, 108 | name, 109 | FrozenBNBLinear( 110 | weight=torch.zeros(child.out_features, child.in_features, dtype=torch.uint8), 111 | absmax=torch.zeros((child.weight.numel() - 1) // 4096 + 1), 112 | code=torch.zeros(256), 113 | bias=child.bias, 114 | ), 115 | ) 116 | elif isinstance(child, nn.Embedding): 117 | setattr( 118 | module, 119 | name, 120 | FrozenBNBEmbedding( 121 | weight=torch.zeros(child.num_embeddings, child.embedding_dim, dtype=torch.uint8), 122 | absmax=torch.zeros((child.weight.numel() - 1) // 4096 + 1), 123 | code=torch.zeros(256), 124 | ) 125 | ) 126 | 127 | class GPTJBlock(transformers.models.gptj.modeling_gptj.GPTJBlock): 128 | def __init__(self, config): 129 | super().__init__(config) 130 | 131 | convert_to_int8(self.attn) 132 | convert_to_int8(self.mlp) 133 | 134 | class GPTJModel(transformers.models.gptj.modeling_gptj.GPTJModel): 135 | def __init__(self, config): 136 | super().__init__(config) 137 | convert_to_int8(self) 138 | 139 | class GPTJForCausalLM(transformers.models.gptj.modeling_gptj.GPTJForCausalLM): 140 | def __init__(self, config): 141 | super().__init__(config) 142 | convert_to_int8(self) 143 | 144 | class GPTNeoBlock(transformers.models.gpt_neo.modeling_gpt_neo.GPTNeoBlock): 145 | def __init__(self, config, layer_id): 146 | super().__init__(config, layer_id) 147 | 148 | convert_to_int8(self.attn) 149 | convert_to_int8(self.mlp) 150 | 151 | class GPTNeoModel(transformers.models.gpt_neo.modeling_gpt_neo.GPTNeoModel): 152 | def __init__(self, config): 153 | super().__init__(config) 154 | convert_to_int8(self) 155 | 156 | class GPTNeoForCausalLM(transformers.models.gpt_neo.modeling_gpt_neo.GPTNeoForCausalLM): 157 | def __init__(self, config): 158 | super().__init__(config) 159 | convert_to_int8(self) 160 | -------------------------------------------------------------------------------- /app/gpt/softprompt.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import base64 3 | import json 4 | import zlib 5 | import copy 6 | 
import numpy as np 7 | 8 | from app.gpt.quantization import convert_to_int8 9 | from transformers import AutoModelForCausalLM 10 | 11 | # globals 12 | current_sp = None 13 | 14 | class SoftPrompt(): 15 | def __init__(self, softembedding: torch.tensor = None, n_tokens: int = 20): 16 | if softembedding is None: 17 | raise ValueError('softembeddings must not be None') 18 | if n_tokens is None or n_tokens <= 0: 19 | raise ValueError('n_tokens must be a positive int') 20 | 21 | self.softembedding = softembedding 22 | self.n_tokens = n_tokens 23 | 24 | def get_input_embeds(self): 25 | return self.softembedding 26 | 27 | def get_special_ids(self, n_vocab = 50257): 28 | ids = [] 29 | for i in range(self.n_tokens): 30 | ids.append(n_vocab + i) 31 | return ids 32 | 33 | def get_special_str(self): 34 | ids = '' 35 | for i in range(self.n_tokens): 36 | ids = ids + f'' 37 | return ''.join(ids) 38 | 39 | def get_tokenizer(self, tokenizer): 40 | sp_tokenizer = copy.deepcopy(tokenizer) 41 | for i in range(self.n_tokens): 42 | sp_tokenizer.add_tokens(f'') 43 | return sp_tokenizer 44 | 45 | def resize_model_embeddings(_model, _tokenizer): 46 | _model.resize_token_embeddings(len(_tokenizer)) 47 | 48 | class GPTSoftPromptMixin: 49 | def replace_special_tokens(self, input_ids): 50 | input_embeds = self.transformer.wte(input_ids.to(self.device)) 51 | 52 | if current_sp is None: 53 | return input_embeds 54 | 55 | n_batches = input_ids.shape[0] 56 | n_tokens = input_ids.shape[-1] 57 | sp_tokens = current_sp.get_special_ids() 58 | 59 | for b in range(n_batches): 60 | for t in range(n_tokens): 61 | input_id = input_ids[b,t].item() 62 | if input_id in sp_tokens: 63 | replacement = current_sp.get_input_embeds().to(self.device).clone().unsqueeze(0) 64 | input_embeds[b,t:t+len(sp_tokens[input_id]),:] = replacement[0,:,:] 65 | 66 | return input_embeds 67 | 68 | def forward(self, *args, **kwargs): 69 | if kwargs.get('input_ids') is None: 70 | kwargs['input_ids'] = args[0] 71 | 72 | if kwargs.get('input_ids') is None: 73 | return super().forward(*args, **kwargs) 74 | 75 | kwargs['input_ids'] = None 76 | kwargs['input_embeds'] = self.replace_special_tokens(kwargs.get('input_ids')) 77 | 78 | args = () 79 | 80 | return super().forward(*args, **kwargs) 81 | 82 | class AutoModelForSoftPromptLM(GPTSoftPromptMixin, AutoModelForCausalLM): 83 | def __init__(self, config): 84 | super().__init__(config) 85 | if 'quantized' in config.__dict__: 86 | if config.quantized: 87 | convert_to_int8(self) 88 | -------------------------------------------------------------------------------- /app/gpt/tensorize.py: -------------------------------------------------------------------------------- 1 | from app.core.logging import logging 2 | from app.gpt.quantization import FrozenBNBEmbedding, FrozenBNBLinear 3 | from typing import Dict, List, Tuple 4 | from mmappickle import mmapdict 5 | from torch import nn 6 | import numpy as np 7 | import copy 8 | import torch 9 | import time 10 | import psutil 11 | import os 12 | 13 | def read_tensor(item): 14 | dtype = item.dtype 15 | shape = item.shape 16 | buffer = memoryview(item) 17 | arr = np.ndarray.__new__( 18 | np.memmap, 19 | dtype=dtype, 20 | shape=shape, 21 | buffer=buffer, 22 | offset=0 23 | ) 24 | return arr 25 | 26 | def extract_tensors(m: torch.nn.Module) -> Tuple[torch.nn.Module, List[Dict]]: 27 | """ 28 | Remove the tensors from a PyTorch model, convert them to NumPy 29 | arrays, and return the stripped model and tensors. 
30 | """ 31 | tensors = [] 32 | for _, module in m.named_modules(): 33 | # Store the tensors in Python dictionaries 34 | params = { 35 | name: torch.clone(param).detach().cpu().numpy() 36 | for name, param in module.named_parameters(recurse=False) 37 | } 38 | buffers = { 39 | name: torch.clone(buf).detach().cpu().numpy() 40 | for name, buf in module.named_buffers(recurse=False) 41 | } 42 | tensors.append({"params": params, "buffers": buffers}) 43 | 44 | # Make a copy of the original model and strip all tensors and 45 | # buffers out of the copy. 46 | for _, module in m.named_modules(): 47 | for name in ([name for name, _ in module.named_parameters(recurse=False)] 48 | + [name for name, _ in module.named_buffers(recurse=False)]): 49 | setattr(module, name, None) 50 | 51 | # Make sure the copy is configured for inference. 52 | m.train(False) 53 | return m, tensors 54 | 55 | def replace_tensors(m: torch.nn.Module, tensors: List[Dict], device: torch.device, quantized: bool=False): 56 | """ 57 | Restore the tensors that extract_tensors() stripped out of a 58 | PyTorch model. 59 | :param no_parameters_objects: Skip wrapping tensors in 60 | ``torch.nn.Parameters`` objects (~20% speedup, may impact 61 | some models) 62 | """ 63 | modules = [module for _, module in m.named_modules()] 64 | for module, tensor_dict in zip(modules, tensors): 65 | # There are separate APIs to set parameters and buffers. 66 | for name, array in tensor_dict["params"].items(): 67 | module.register_parameter(name, torch.nn.Parameter(torch.as_tensor(read_tensor(array), device=device))) 68 | for name, array in tensor_dict["buffers"].items(): 69 | module.register_buffer(name, torch.as_tensor(read_tensor(array), device=device)) 70 | 71 | if quantized: 72 | for module in m.modules(): 73 | for name, child in module.named_children(): 74 | if isinstance(child, nn.Linear): 75 | setattr( 76 | module, 77 | name, 78 | FrozenBNBLinear( 79 | weight=torch.zeros( 80 | child.out_features, 81 | child.in_features, 82 | dtype=torch.uint8, 83 | device=device 84 | ), 85 | absmax=torch.zeros( 86 | (child.weight.numel() - 1) // 4096 + 1, 87 | device=device 88 | ), 89 | code=torch.zeros(256, device=device), 90 | bias=child.bias 91 | ) 92 | ) 93 | elif isinstance(child, nn.Embedding): 94 | setattr( 95 | module, 96 | name, 97 | FrozenBNBEmbedding( 98 | weight=torch.zeros( 99 | child.num_embeddings, 100 | child.embedding_dim, 101 | dtype=torch.uint8, 102 | device=device 103 | ), 104 | absmax=torch.zeros( 105 | (child.weight.numel() - 1) // 4096 + 1, 106 | device=device 107 | ), 108 | code=torch.zeros(256) 109 | ) 110 | ) 111 | 112 | def tensorize(m: torch.nn.Module, path: str) -> None: 113 | logging.info(f'Tensorizing to {path}') 114 | model_map = mmapdict(path+'.model') 115 | b = time.time() 116 | m_copy, m_tensors = extract_tensors(m) 117 | logging.info(f'Model tensors and skeleton extracted in {(time.time()-b):.2f}s, {(psutil.Process(os.getpid()).memory_info().rss / 1024 ** 3):.2f}gb CPU RAM used') 118 | 119 | model_map['skeleton'] = m_copy 120 | model_map['tensors'] = m_tensors 121 | 122 | def untensorize(path: str, device: torch.device, quantized: bool = False) -> torch.nn.Module: 123 | model_map = mmapdict(path+'.model') 124 | 125 | logging.info(f'Loading {path}') 126 | 127 | b = time.time() 128 | m = model_map['skeleton'].to(device) 129 | logging.info(f'Model object skeleton loaded in {(time.time()-b):.2f}s') 130 | 131 | b = time.time() 132 | t = model_map['tensors'] 133 | replace_tensors(m, t, device, quantized) 134 | logging.info(f'Model 
tensors loaded in {(time.time()-b):.2f}s, {(psutil.Process(os.getpid()).memory_info().rss / 1024 ** 3):.2f}gb CPU RAM used') 135 | 136 | return m 137 | -------------------------------------------------------------------------------- /app/gpt/utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from transformers import AutoConfig 3 | 4 | from app.core.config import settings 5 | 6 | try: 7 | from collections.abc import MutableMapping 8 | except ImportError: 9 | from collections import MutableMapping 10 | 11 | from pathlib import Path 12 | 13 | class Checkpoint(MutableMapping): 14 | def __init__(self, chkpt_dir, device="cpu"): 15 | self.device = device 16 | self.chkpt_dir = Path(chkpt_dir) 17 | self.checkpoint = torch.load(str(chkpt_dir / Path("m.pt"))) 18 | 19 | def __len__(self): 20 | return len(self.checkpoint) 21 | 22 | def __getitem__(self, key): 23 | path = self.chkpt_dir / Path(self.checkpoint[key]).name 24 | 25 | if self.device == "cpu": 26 | return torch.load(str(path), map_location=self.device).long() 27 | else: 28 | return torch.load(str(path), map_location=self.device).half() 29 | 30 | def __setitem__(self, key, value): 31 | return 32 | 33 | def __delitem__(self, key, value): 34 | return 35 | 36 | def keys(self): 37 | return self.checkpoint.keys() 38 | 39 | def __iter__(self): 40 | for key in self.checkpoint: 41 | yield (key, self.__getitem__(key)) 42 | 43 | def __copy__(self): 44 | return Checkpoint(self.chkpt_dir, device=self.device) 45 | 46 | def copy(self): 47 | return Checkpoint(self.chkpt_dir, device=self.device) 48 | 49 | def get_dtype(device: torch.device): 50 | model_dtype = torch.float32 51 | if device is None: 52 | if torch.cuda.is_available(): 53 | device = torch.device('cuda') 54 | model_dtype = torch.float16 55 | else: 56 | device = torch.device('cpu') 57 | model_dtype = torch.float32 58 | else: 59 | if device == 'cuda': 60 | model_dtype = torch.float16 61 | return model_dtype 62 | 63 | def is_decoder(config: AutoConfig): 64 | decoder_types = ['gpt2', 'gptj', 'gpt_neo', 'gpt_neox', 'xglm'] 65 | encoder_types = ['distilbert', 'bert', 'xlm', 'xlm-roberta', 'roberta', 'clip'] 66 | 67 | if config.model_type in decoder_types: 68 | return True 69 | elif config.model_type in encoder_types: 70 | return False 71 | else: 72 | raise ValueError(f"Unknown model type: {config.model_type}") 73 | 74 | def tensorized_path(model_name: str): 75 | f = Path(settings.STORAGE_PATH) / Path(model_name.split('/')[-1]) 76 | return f, f.with_suffix('.model').exists() 77 | -------------------------------------------------------------------------------- /app/gpt/warpers.py: -------------------------------------------------------------------------------- 1 | # @title Tail Free Sampling Warper 2 | from typing import List, Tuple 3 | 4 | import numpy as np 5 | import torch 6 | from transformers import LogitsProcessor, LogitsWarper 7 | from math import exp 8 | 9 | 10 | class TailFreeSamplingLogitsWarper(LogitsWarper): 11 | r""" 12 | :class:`transformers.LogitsWarper` that performs tail free sampling, as described in 13 | https://www.trentonbricken.com/Tail-Free-Sampling/. 14 | Args: 15 | tfs (:obj:`float`): 16 | If set to < 1, only the most probable tokens where the second derivative of the probabilities of the tokens 17 | sorted in descending order of probability add up to at most :obj:`tfs` are kept for generation. 
18 | filter_value (:obj:`float`, `optional`, defaults to :obj:`-float("Inf")`): 19 | All filtered values will be set to this float value. 20 | min_tokens_to_keep (:obj:`int`, `optional`, defaults to 1): 21 | Minimum number of tokens that cannot be filtered. 22 | """ 23 | 24 | def __init__(self, tfs: float, filter_value: float = -float("Inf"), min_tokens_to_keep: int = 1): 25 | tfs = float(tfs) 26 | if tfs < 0 or tfs > 1.0: 27 | raise ValueError(f"`tfs` has to be a float > 0 and < 1, but is {tfs}") 28 | 29 | self.tfs = tfs 30 | self.filter_value = filter_value 31 | self.min_tokens_to_keep = min_tokens_to_keep 32 | 33 | def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor: 34 | if self.filter_value >= 1.0: 35 | return scores 36 | 37 | sorted_logits, sorted_indices = torch.sort(scores, descending=True) 38 | probs = sorted_logits.softmax(dim=-1) 39 | 40 | # Compute second derivative normalized CDF 41 | d2 = probs.diff().diff().abs() 42 | normalized_d2 = d2 / d2.sum(dim=-1, keepdim=True) 43 | normalized_d2_cdf = normalized_d2.cumsum(dim=-1) 44 | 45 | # Remove tokens with CDF value above the threshold (token with 0 are kept) 46 | sorted_indices_to_remove = normalized_d2_cdf > self.tfs 47 | # Centre the distribution around the cutoff as in the original implementation of the algorithm 48 | sorted_indices_to_remove = torch.cat( 49 | ( 50 | torch.zeros(scores.shape[0], 1, dtype=torch.bool, device=scores.device), 51 | sorted_indices_to_remove, 52 | torch.ones(scores.shape[0], 1, dtype=torch.bool, device=scores.device), 53 | ), 54 | dim=-1, 55 | ) 56 | if self.min_tokens_to_keep > 1: 57 | # Keep at least min_tokens_to_keep 58 | sorted_indices_to_remove[..., : self.min_tokens_to_keep] = 0 59 | 60 | # scatter sorted tensors to original indexing 61 | indices_to_remove = sorted_indices_to_remove.scatter(1, sorted_indices, sorted_indices_to_remove) 62 | scores = scores.masked_fill(indices_to_remove, self.filter_value) 63 | return scores 64 | 65 | 66 | class TopALogitsWarper(LogitsWarper): 67 | def __init__(self, threshold: float, filter_value: float = -float("inf")): 68 | if not isinstance(threshold, float) or (threshold < 0 or threshold > 1.0): 69 | raise ValueError(f"`threshold` has to be a float > 0 and < 1, but is {threshold}") 70 | 71 | self.z = threshold 72 | self.filter_value = filter_value 73 | 74 | def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor: 75 | probs = torch.nn.functional.softmax(scores, dim=-1) 76 | limit = torch.pow(torch.max(probs), 2.0) * self.z 77 | indices_to_remove = probs < limit 78 | scores = scores.masked_fill(indices_to_remove, self.filter_value) 79 | return scores 80 | 81 | class TypicalLogitsWarper(LogitsWarper): 82 | def __init__(self, mass: float = 0.9, filter_value: float = -float("Inf"), min_tokens_to_keep: int = 1): 83 | mass = float(mass) 84 | if mass <= 0 or mass >= 1.0: 85 | raise ValueError(f"`typical_p` has to be a float > 0 and < 1, but is {mass}") 86 | self.filter_value = filter_value 87 | self.mass = mass 88 | self.min_tokens_to_keep = min_tokens_to_keep 89 | 90 | def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor: 91 | 92 | # calculate entropy 93 | normalized = torch.nn.functional.log_softmax(scores, dim=-1) 94 | p = torch.exp(normalized) 95 | ent = -(normalized * p).nansum(-1, keepdim=True) 96 | 97 | # shift and sort 98 | shifted_scores = torch.abs((-normalized) - ent) 99 | sorted_scores, sorted_indices = 
torch.sort(shifted_scores, descending=False) 100 | sorted_logits = scores.gather(-1, sorted_indices) 101 | cumulative_probs = sorted_logits.softmax(dim=-1).cumsum(dim=-1) 102 | 103 | # Remove tokens with cumulative mass above the threshold 104 | last_ind = (cumulative_probs < self.mass).sum(dim=1) 105 | last_ind[last_ind < 0] = 0 106 | sorted_indices_to_remove = sorted_scores > sorted_scores.gather(1, last_ind.view(-1, 1)) 107 | if self.min_tokens_to_keep > 1: 108 | # Keep at least min_tokens_to_keep (set to min_tokens_to_keep-1 because we add the first one below) 109 | sorted_indices_to_remove[..., : self.min_tokens_to_keep] = 0 110 | indices_to_remove = sorted_indices_to_remove.scatter(1, sorted_indices, sorted_indices_to_remove) 111 | 112 | scores = scores.masked_fill(indices_to_remove, self.filter_value) 113 | return scores 114 | 115 | # @title Repetition Penalty Processor 116 | class RepetitionPenaltyLogitsProcessor(LogitsProcessor): 117 | r""" 118 | :class:`transformers.LogitsProcessor` enforcing an exponential penalty on repeated sequences. 119 | Args: 120 | repetition_penalty (:obj:`float`): 121 | The parameter for repetition penalty. 1.0 means no penalty. See `this paper 122 | <https://arxiv.org/pdf/1909.05858.pdf>`__ for more details. 123 | """ 124 | 125 | def __init__(self, penalty: float = 1.0, slope=3.33, penalize_last=250, alpha_frequency=None, alpha_presence=None, whitelist=None): 126 | if not isinstance(penalty, float) or not (penalty > 0): 127 | raise ValueError(f"`penalty` has to be a strictly positive float, but is {penalty}") 128 | 129 | self.penalty = 1.0 if penalty < 1.0 else penalty 130 | self.raw_penalty = penalty 131 | self.penalize_last = None 132 | 133 | if slope is not None and penalize_last is not None and penalize_last >= 1: 134 | self.penalty = (torch.arange(penalize_last) / (penalize_last - 1)) * 2. 
- 1 135 | self.penalty = (slope * self.penalty) / (1 + torch.abs(self.penalty) * (slope - 1)) 136 | self.penalty = 1 + ((self.penalty + 1) / 2).unsqueeze(0) * (penalty - 1) 137 | 138 | self.penalize_last = penalize_last 139 | 140 | self.alpha_frequency = alpha_frequency if alpha_frequency is not None and alpha_frequency > 0.0 else None 141 | self.alpha_presence = alpha_presence if alpha_presence is not None and alpha_presence > 0.0 else None 142 | self.alpha_enable = self.alpha_frequency is not None or self.alpha_presence is not None 143 | 144 | self.whitelist = None 145 | self.whitelist_list = None 146 | 147 | if whitelist is not None: 148 | self.whitelist_list = whitelist 149 | 150 | def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor: 151 | if self.whitelist is None and self.whitelist_list is not None: 152 | self.whitelist_list = list(filter(lambda x: x >= 0 and x < scores.shape[1], self.whitelist_list)) 153 | 154 | if len(self.whitelist_list) > 0: 155 | self.whitelist = torch.tensor(self.whitelist_list).long().sort()[0] 156 | self.whitelist = self.whitelist.to(input_ids.device) 157 | 158 | if self.whitelist is not None: 159 | unpenalized = scores.gather(1, self.whitelist.view(1, -1)) 160 | 161 | if self.raw_penalty > 1.0: 162 | if self.penalize_last is not None: 163 | penality_len = min(input_ids.shape[1], self.penalize_last) 164 | input_ids = input_ids[:, -penality_len:] 165 | 166 | score = torch.gather(scores, 1, input_ids) 167 | 168 | # if score < 0 then repetition penalty has to be multiplied to reduce the previous token probability 169 | if self.penalize_last is not None: 170 | penalty = self.penalty.type(score.dtype).to(score.device) 171 | score = torch.where(score < 0, score * penalty[:, -penality_len:], score / penalty[:, -penality_len:]) 172 | 173 | else: 174 | score = torch.where(score < 0, score * self.penalty, score / self.penalty) 175 | 176 | scores.scatter_(1, input_ids, score) 177 | 178 | if self.alpha_enable: 179 | c = torch.zeros(scores.shape).long().to(input_ids.device) 180 | # unique only returns counts for first item in batch, so manually iterate 181 | for i in range(input_ids.shape[0]): 182 | if self.penalize_last is not None: 183 | token_input_ids, counts = torch.unique(input_ids[i, -self.penalize_last:], sorted=True, return_counts=True, dim=-1) 184 | 185 | else: 186 | token_input_ids, counts = torch.unique(input_ids[i], sorted=True, return_counts=True, dim=-1) 187 | 188 | c[i].scatter_(0, token_input_ids, counts) 189 | 190 | if self.alpha_frequency: 191 | scores -= c * self.alpha_frequency 192 | 193 | if self.alpha_presence: 194 | scores[c > 0] -= self.alpha_presence 195 | 196 | if self.whitelist is not None: 197 | scores.scatter_(1, self.whitelist.view(1, -1), unpenalized) 198 | 199 | return scores 200 | 201 | class LogitBiasProcessor(LogitsProcessor): 202 | r""" 203 | :class:`transformers.LogitsProcessor` adding bias to specific tokens 204 | Args: 205 | logit_biases (:obj:`List[Tuple[int, float]]`): 206 | Adds a float bias to the given token's logit. 
207 | """ 208 | 209 | def __init__(self, logit_bias: List[Tuple[int, float]]=[]): 210 | if not isinstance(logit_bias, list) and len(logit_bias) > 0: 211 | raise ValueError("`logit_bias` has to be a non-empty list") 212 | 213 | # apply exp to each bias 214 | self.logit_bias = [(token, exp(bias)) for token, bias in logit_bias] 215 | self.bias = None 216 | 217 | def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor: 218 | if self.bias is None: 219 | self.bias = torch.zeros(scores.shape[1]).float() 220 | logit_bias = torch.tensor(self.logit_bias) 221 | self.bias.scatter_(0, logit_bias[:,0].long(), logit_bias[:,1].float()) 222 | self.bias = self.bias.to(scores.dtype).to(scores.device).unsqueeze(0) 223 | return scores + self.bias 224 | 225 | class PhraseBiasProcessor(LogitsProcessor): 226 | def __init__(self, words_ids: List[List[int]], bias: float, ensure_sequence_finish: bool, generate_once: bool): 227 | if not isinstance(words_ids, list) or len(words_ids) == 0: 228 | return 229 | 230 | if any(not isinstance(word_ids, list) for word_ids in words_ids): 231 | raise ValueError("`words_ids` has to be a list of lists") 232 | 233 | if any( 234 | any((not isinstance(token_id, (int, np.integer)) or token_id < 0) for token_id in word_ids) 235 | for word_ids in words_ids 236 | ): 237 | raise ValueError( 238 | "Each list in `words_ids` has to be a list of positive integers" 239 | ) 240 | 241 | self.words_ids = words_ids 242 | self.bias = exp(bias) 243 | self.ensure_sequence_finish = ensure_sequence_finish 244 | self.generate_once = generate_once 245 | 246 | def slice_in_list(self, l, s): 247 | a = 0 248 | for i in range(l.shape[1]): 249 | for j in range(len(s)): 250 | if l[:,i].item() == s[j]: 251 | a += 1 252 | return a 253 | 254 | def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor: 255 | for phrase_ids in self.words_ids: 256 | if self.generate_once: 257 | if phrase_ids[0] not in input_ids: 258 | scores[:, phrase_ids[0]] += self.bias 259 | continue 260 | else: 261 | scores[:, phrase_ids[0]] += self.bias 262 | idx = self.slice_in_list(input_ids, phrase_ids) 263 | if idx == len(phrase_ids) or idx > len(phrase_ids): 264 | continue # sequence is finished 265 | else: 266 | if self.ensure_sequence_finish: 267 | if self.generate_once: 268 | scores[:, phrase_ids[idx]] -= self.bias 269 | scores[:, phrase_ids[idx]] = 1000.0 # max bias 270 | break 271 | else: 272 | scores[:, phrase_ids[idx]] += self.bias 273 | continue 274 | 275 | return scores 276 | -------------------------------------------------------------------------------- /app/main.py: -------------------------------------------------------------------------------- 1 | from fastapi import FastAPI 2 | from fastapi.middleware.cors import CORSMiddleware 3 | 4 | from app.core.config import settings 5 | from app.api.v1.api import api_router 6 | 7 | app = FastAPI( 8 | title=settings.PROJECT_NAME 9 | ) 10 | 11 | if settings.BACKEND_CORS_ORIGINS: 12 | app.add_middleware( 13 | CORSMiddleware, 14 | allow_credentials=True, 15 | allow_methods=["*"], 16 | allow_headers=["*"], 17 | allow_origins=[str(origin) for origin in settings.BACKEND_CORS_ORIGINS] 18 | ) 19 | 20 | app.include_router(api_router, prefix="/api/v1", tags=["v1"]) 21 | 22 | @app.get("/") 23 | async def root(): 24 | return 'Sometimes I dream about cheese.' 
-------------------------------------------------------------------------------- /app/models/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hitomi-team/sukima/eb3e8e968b971748890885fcc1287bd27ae8dde0/app/models/__init__.py -------------------------------------------------------------------------------- /app/models/soft_prompt.py: -------------------------------------------------------------------------------- 1 | from app.core.config import settings 2 | from app.db.base_class import Base 3 | from sqlalchemy import Column, ForeignKey, Integer, String, Boolean, Numeric 4 | 5 | 6 | class SoftPrompt(Base): 7 | __tablename__ = "soft_prompts" 8 | 9 | id = Column(String, primary_key=True) 10 | name = Column(String, index=True, nullable=False) 11 | description = Column(String, nullable=True) 12 | public = Column(Boolean, nullable=False, default=True) 13 | creator = Column(Integer, ForeignKey("users.id"), nullable=False) 14 | model = Column(String, nullable=False) 15 | # model = Column(Integer, ForeignKey("models.id"), nullable=False) 16 | loss = Column(Numeric, nullable=False) 17 | steps = Column(Integer, nullable=False) 18 | 19 | def storage_path(self) -> str: 20 | return settings.STORAGE_PATH / f"{self.id}.zz" 21 | 22 | def read(self) -> bytes: 23 | with open(self.storage_path(), "rb") as data_file: 24 | return data_file.read() 25 | 26 | def write(self, tensor_data: bytes): 27 | with open(self.storage_path(), "wb") as data_file: 28 | data_file.write(tensor_data) 29 | 30 | def asdict(self) -> dict: 31 | return { 32 | "id": self.id, 33 | "name": self.name, 34 | "description": self.description, 35 | "public": self.public, 36 | "model": self.model, 37 | "loss": self.loss, 38 | "steps": self.steps, 39 | } 40 | -------------------------------------------------------------------------------- /app/models/user.py: -------------------------------------------------------------------------------- 1 | from app.db.base_class import Base 2 | from sqlalchemy import Column, ForeignKey, Integer, SmallInteger, String, Table 3 | from sqlalchemy.orm import relationship 4 | 5 | user_model_association = Table( 6 | 'user_model_association', 7 | Base.metadata, 8 | Column('user_id', ForeignKey('users.id')), 9 | Column('model_id', ForeignKey('models.id')) 10 | ) 11 | 12 | 13 | class User(Base): 14 | __tablename__ = "users" 15 | 16 | id = Column(Integer, primary_key=True, index=True) 17 | username = Column(String, index=True, nullable=False) 18 | password = Column(String, nullable=False) 19 | email = Column(String, unique=True, index=True, nullable=False) 20 | permission_level = Column(SmallInteger, default=0, nullable=False) 21 | soft_prompts = relationship("SoftPrompt") 22 | 23 | 24 | class Model(Base): 25 | __tablename__ = "models" 26 | 27 | id = Column(Integer, primary_key=True, index=True) 28 | name = Column(String, unique=True, index=True, nullable=False) 29 | size = Column(Integer, nullable=False) 30 | -------------------------------------------------------------------------------- /app/schemas/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hitomi-team/sukima/eb3e8e968b971748890885fcc1287bd27ae8dde0/app/schemas/__init__.py -------------------------------------------------------------------------------- /app/schemas/model_item.py: -------------------------------------------------------------------------------- 1 | from typing import List, Optional 2 | 3 | 
from pydantic import BaseModel 4 | 5 | 6 | class ModelItemBase(BaseModel): 7 | model_name: str 8 | 9 | 10 | class ModelItem(ModelItemBase): 11 | size: int 12 | 13 | class Config: 14 | orm_mode = True 15 | 16 | 17 | class ModelGenArgs(BaseModel): 18 | max_length: int 19 | max_time: Optional[float] = None 20 | min_length: Optional[int] = None 21 | eos_token_id: Optional[int] = None 22 | logprobs: Optional[int] = None 23 | best_of: Optional[int] = None 24 | 25 | class ModelSampleArgs(BaseModel): 26 | class ModelLogitBiasArgs(BaseModel): 27 | id: int 28 | bias: float 29 | 30 | class ModelPhraseBiasArgs(BaseModel): 31 | sequences: List[str] 32 | bias: float 33 | ensure_sequence_finish: bool 34 | generate_once: bool 35 | 36 | temp: Optional[float] = None 37 | top_p: Optional[float] = None 38 | top_a: Optional[float] = None 39 | top_k: Optional[int] = None 40 | typical_p: Optional[float] = None 41 | tfs: Optional[float] = None 42 | rep_p: Optional[float] = None 43 | rep_p_range: Optional[int] = None 44 | rep_p_slope: Optional[float] = None 45 | bad_words: Optional[List[str]] = None 46 | # logit biases are a list of int and float tuples 47 | logit_biases: Optional[List[ModelLogitBiasArgs]] = None 48 | phrase_biases: Optional[List[ModelPhraseBiasArgs]] = None 49 | 50 | 51 | class ModelGenRequest(BaseModel): 52 | model: str 53 | prompt: str 54 | softprompt: Optional[str] = None 55 | sample_args: ModelSampleArgs 56 | gen_args: ModelGenArgs 57 | 58 | 59 | class ModelClassifyRequest(BaseModel): 60 | model: str 61 | prompt: str 62 | labels: list 63 | 64 | class ModelHiddenRequest(BaseModel): 65 | model: str 66 | prompt: str 67 | layers: List[int] 68 | 69 | class ModelLoadRequest(BaseModel): 70 | model: str 71 | parallel: Optional[bool] = False 72 | sharded: Optional[bool] = False 73 | quantized: Optional[bool] = False 74 | tensorize: Optional[bool] = False 75 | device: Optional[str] = 'cpu' 76 | -------------------------------------------------------------------------------- /app/schemas/soft_prompt.py: -------------------------------------------------------------------------------- 1 | from typing import Optional 2 | 3 | from numpy import float64 4 | from pydantic import BaseModel 5 | 6 | 7 | class SoftPrompt(BaseModel): 8 | name: str 9 | description: Optional[str] 10 | public: Optional[bool] 11 | 12 | 13 | class SoftPromptCreate(SoftPrompt): 14 | model: str 15 | loss: float64 16 | steps: int 17 | 18 | 19 | class SoftPromptUpdate(SoftPrompt): 20 | pass 21 | -------------------------------------------------------------------------------- /app/schemas/token.py: -------------------------------------------------------------------------------- 1 | from typing import Optional 2 | 3 | from pydantic import BaseModel 4 | 5 | 6 | class Token(BaseModel): 7 | access_token: str 8 | token_type: str 9 | 10 | 11 | class TokenData(BaseModel): 12 | username: Optional[str] = None 13 | -------------------------------------------------------------------------------- /app/schemas/user.py: -------------------------------------------------------------------------------- 1 | from typing import List 2 | 3 | from app.schemas.model_item import ModelItem 4 | from pydantic import BaseModel 5 | 6 | 7 | class UserBase(BaseModel): 8 | username: str 9 | email: str 10 | permission_level: int 11 | 12 | 13 | class UserCreate(UserBase): 14 | password: str 15 | 16 | 17 | class UserUpdate(UserBase): 18 | # TODO: fill this in 19 | pass 20 | 21 | 22 | class User(UserBase): 23 | permission_level: int 24 | allowed_models: List[ModelItem] = [] 25 
| 26 | class Config: 27 | orm_mode = True 28 | -------------------------------------------------------------------------------- /banner.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hitomi-team/sukima/eb3e8e968b971748890885fcc1287bd27ae8dde0/banner.png -------------------------------------------------------------------------------- /conf.env: -------------------------------------------------------------------------------- 1 | PROJECT_NAME=sukima 2 | BACKEND_CORS_ORIGINS=["http://localhost:8000", "https://localhost:8000", "http://localhost", "https://localhost"] 3 | 4 | SECRET_KEY=hJ05BJnKLArZUdJ591 5 | TOKEN_URL="/api/v1/users/token" 6 | 7 | POSTGRES_USER=postgres 8 | POSTGRES_PASSWORD=postgres 9 | POSTGRES_SERVER=database 10 | POSTGRES_PORT=5432 11 | POSTGRES_DB=app 12 | 13 | ALGORITHM="HS256" 14 | ACCESS_TOKEN_EXPIRATION=30 15 | -------------------------------------------------------------------------------- /docker-compose.yaml: -------------------------------------------------------------------------------- 1 | version: "3.8" 2 | 3 | services: 4 | app: 5 | image: sukima_app:latest 6 | build: 7 | context: ./ 8 | dockerfile: Dockerfile 9 | env_file: 10 | - conf.env 11 | ports: 12 | - "8000:8000" 13 | volumes: 14 | - ./:/sukima/ 15 | command: bash -c "alembic upgrade head && uvicorn app.main:app --reload --host 0.0.0.0 --port 8000" 16 | depends_on: 17 | database: 18 | condition: service_healthy 19 | 20 | database: 21 | image: postgres:14-bullseye 22 | env_file: 23 | - conf.env 24 | ports: 25 | - 5432:5432 26 | volumes: 27 | - postgres_data:/var/lib/postgresql/data/ 28 | healthcheck: 29 | test: ["CMD-SHELL", "pg_isready -U postgres"] 30 | interval: 5s 31 | timeout: 5s 32 | retries: 5 33 | 34 | volumes: 35 | postgres_data: 36 | -------------------------------------------------------------------------------- /docker-compose_nvidia-gpu.yaml: -------------------------------------------------------------------------------- 1 | version: "3.8" 2 | 3 | services: 4 | app: 5 | build: 6 | context: ./ 7 | dockerfile: Dockerfile 8 | env_file: 9 | - conf.env 10 | ports: 11 | - "8000:8000" 12 | volumes: 13 | - ./:/sukima/ 14 | command: bash -c "alembic upgrade head && uvicorn app.main:app --reload --host 0.0.0.0 --port 8000" 15 | depends_on: 16 | - database 17 | deploy: 18 | resources: 19 | reservations: 20 | devices: 21 | - driver: nvidia 22 | count: 1 23 | capabilities: [ gpu ] 24 | 25 | 26 | database: 27 | image: postgres:14-bullseye 28 | env_file: 29 | - conf.env 30 | ports: 31 | - 5432:5432 32 | volumes: 33 | - postgres_data:/var/lib/postgresql/data/ 34 | 35 | volumes: 36 | postgres_data: 37 | -------------------------------------------------------------------------------- /k8s/conf-env-configmap.yaml: -------------------------------------------------------------------------------- 1 | apiVersion: v1 2 | data: 3 | ACCESS_TOKEN_EXPIRATION: "30" 4 | ALGORITHM: HS256 5 | BACKEND_CORS_ORIGINS: '["http://localhost:8000", "https://localhost:8000", "http://localhost", "https://localhost"]' 6 | POSTGRES_DB: app 7 | POSTGRES_PASSWORD: postgres 8 | POSTGRES_PORT: "5432" 9 | POSTGRES_SERVER: database 10 | POSTGRES_USER: postgres 11 | PROJECT_NAME: sukima 12 | SECRET_KEY: hJ05BJnKLArZUdJ591 13 | TOKEN_URL: /api/v1/users/token 14 | kind: ConfigMap 15 | metadata: 16 | creationTimestamp: null 17 | labels: 18 | io.kompose.service: app-conf-env 19 | name: conf-env 20 | 
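Editorial note — hypothetical sketch (not part of the repository): conf.env above and this ConfigMap are consumed by app/core/config.py, which is not reproduced in this dump. The class below is only a rough guess at how those keys could be read with pydantic; the real Settings class may differ, and STORAGE_PATH is inferred from its use in app/models/soft_prompt.py rather than from the env files.

from pathlib import Path
from typing import List

from pydantic import BaseSettings


class Settings(BaseSettings):
    # Keys mirror conf.env / the k8s ConfigMap; defaults here are placeholders.
    PROJECT_NAME: str = "sukima"
    BACKEND_CORS_ORIGINS: List[str] = []
    SECRET_KEY: str = ""
    ALGORITHM: str = "HS256"
    ACCESS_TOKEN_EXPIRATION: int = 30
    TOKEN_URL: str = "/api/v1/users/token"
    POSTGRES_USER: str = "postgres"
    POSTGRES_PASSWORD: str = "postgres"
    POSTGRES_SERVER: str = "database"
    POSTGRES_PORT: int = 5432
    POSTGRES_DB: str = "app"
    # Used by SoftPrompt.storage_path(); not present in conf.env (assumption).
    STORAGE_PATH: Path = Path("storage")

    class Config:
        env_file = "conf.env"


settings = Settings()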
-------------------------------------------------------------------------------- /k8s/database-deployment.yaml: -------------------------------------------------------------------------------- 1 | apiVersion: apps/v1 2 | kind: Deployment 3 | metadata: 4 | annotations: 5 | kompose.cmd: kompose convert -o k8s 6 | kompose.version: 1.24.0 (HEAD) 7 | creationTimestamp: null 8 | labels: 9 | io.kompose.service: database 10 | name: database 11 | spec: 12 | replicas: 1 13 | selector: 14 | matchLabels: 15 | io.kompose.service: database 16 | strategy: 17 | type: Recreate 18 | template: 19 | metadata: 20 | annotations: 21 | kompose.cmd: kompose convert -o k8s 22 | kompose.version: 1.24.0 (HEAD) 23 | creationTimestamp: null 24 | labels: 25 | io.kompose.service: database 26 | spec: 27 | containers: 28 | - env: 29 | - name: ACCESS_TOKEN_EXPIRATION 30 | valueFrom: 31 | configMapKeyRef: 32 | key: ACCESS_TOKEN_EXPIRATION 33 | name: conf-env 34 | - name: ALGORITHM 35 | valueFrom: 36 | configMapKeyRef: 37 | key: ALGORITHM 38 | name: conf-env 39 | - name: BACKEND_CORS_ORIGINS 40 | valueFrom: 41 | configMapKeyRef: 42 | key: BACKEND_CORS_ORIGINS 43 | name: conf-env 44 | - name: POSTGRES_DB 45 | valueFrom: 46 | configMapKeyRef: 47 | key: POSTGRES_DB 48 | name: conf-env 49 | - name: POSTGRES_PASSWORD 50 | valueFrom: 51 | configMapKeyRef: 52 | key: POSTGRES_PASSWORD 53 | name: conf-env 54 | - name: POSTGRES_PORT 55 | valueFrom: 56 | configMapKeyRef: 57 | key: POSTGRES_PORT 58 | name: conf-env 59 | - name: POSTGRES_SERVER 60 | valueFrom: 61 | configMapKeyRef: 62 | key: POSTGRES_SERVER 63 | name: conf-env 64 | - name: POSTGRES_USER 65 | valueFrom: 66 | configMapKeyRef: 67 | key: POSTGRES_USER 68 | name: conf-env 69 | - name: PROJECT_NAME 70 | valueFrom: 71 | configMapKeyRef: 72 | key: PROJECT_NAME 73 | name: conf-env 74 | - name: SECRET_KEY 75 | valueFrom: 76 | configMapKeyRef: 77 | key: SECRET_KEY 78 | name: conf-env 79 | - name: TOKEN_URL 80 | valueFrom: 81 | configMapKeyRef: 82 | key: TOKEN_URL 83 | name: conf-env 84 | image: postgres:14-bullseye 85 | name: database 86 | ports: 87 | - containerPort: 5432 88 | resources: {} 89 | volumeMounts: 90 | - mountPath: /var/lib/postgresql/data/ 91 | name: postgres-data 92 | restartPolicy: Always 93 | volumes: 94 | - name: postgres-data 95 | persistentVolumeClaim: 96 | claimName: postgres-data 97 | status: {} 98 | -------------------------------------------------------------------------------- /k8s/database-service.yaml: -------------------------------------------------------------------------------- 1 | apiVersion: v1 2 | kind: Service 3 | metadata: 4 | annotations: 5 | kompose.cmd: kompose convert -o k8s 6 | kompose.version: 1.24.0 (HEAD) 7 | creationTimestamp: null 8 | labels: 9 | io.kompose.service: database 10 | name: database 11 | spec: 12 | ports: 13 | - name: "5432" 14 | port: 5432 15 | targetPort: 5432 16 | selector: 17 | io.kompose.service: database 18 | status: 19 | loadBalancer: {} 20 | -------------------------------------------------------------------------------- /k8s/postgres-data-persistentvolumeclaim.yaml: -------------------------------------------------------------------------------- 1 | apiVersion: v1 2 | kind: PersistentVolumeClaim 3 | metadata: 4 | creationTimestamp: null 5 | labels: 6 | io.kompose.service: postgres-data 7 | name: postgres-data 8 | spec: 9 | accessModes: 10 | - ReadWriteOnce 11 | resources: 12 | requests: 13 | storage: 100Mi 14 | status: {} 15 | -------------------------------------------------------------------------------- 
/k8s/sukima-claim0-persistentvolumeclaim.yaml: -------------------------------------------------------------------------------- 1 | apiVersion: v1 2 | kind: PersistentVolumeClaim 3 | metadata: 4 | creationTimestamp: null 5 | labels: 6 | io.kompose.service: sukima-claim0 7 | name: sukima-claim0 8 | spec: 9 | accessModes: 10 | - ReadWriteOnce 11 | resources: 12 | requests: 13 | storage: 100Mi 14 | status: {} 15 | -------------------------------------------------------------------------------- /k8s/sukima-deployment.yaml: -------------------------------------------------------------------------------- 1 | apiVersion: apps/v1 2 | kind: Deployment 3 | metadata: 4 | annotations: 5 | kompose.cmd: kompose convert -o k8s 6 | kompose.version: 1.24.0 (HEAD) 7 | creationTimestamp: null 8 | labels: 9 | io.kompose.service: sukima 10 | name: sukima 11 | spec: 12 | replicas: 1 13 | selector: 14 | matchLabels: 15 | io.kompose.service: sukima 16 | strategy: 17 | type: Recreate 18 | template: 19 | metadata: 20 | annotations: 21 | kompose.cmd: kompose convert -o k8s 22 | kompose.version: 1.24.0 (HEAD) 23 | creationTimestamp: null 24 | labels: 25 | io.kompose.service: sukima 26 | spec: 27 | containers: 28 | - command: ["/bin/sh","-c"] 29 | args: ["alembic upgrade head && uvicorn app.main:app --reload --host 0.0.0.0 --port 8000"] 30 | env: 31 | - name: ACCESS_TOKEN_EXPIRATION 32 | valueFrom: 33 | configMapKeyRef: 34 | key: ACCESS_TOKEN_EXPIRATION 35 | name: conf-env 36 | - name: ALGORITHM 37 | valueFrom: 38 | configMapKeyRef: 39 | key: ALGORITHM 40 | name: conf-env 41 | - name: BACKEND_CORS_ORIGINS 42 | valueFrom: 43 | configMapKeyRef: 44 | key: BACKEND_CORS_ORIGINS 45 | name: conf-env 46 | - name: POSTGRES_DB 47 | valueFrom: 48 | configMapKeyRef: 49 | key: POSTGRES_DB 50 | name: conf-env 51 | - name: POSTGRES_PASSWORD 52 | valueFrom: 53 | configMapKeyRef: 54 | key: POSTGRES_PASSWORD 55 | name: conf-env 56 | - name: POSTGRES_PORT 57 | valueFrom: 58 | configMapKeyRef: 59 | key: POSTGRES_PORT 60 | name: conf-env 61 | - name: POSTGRES_SERVER 62 | valueFrom: 63 | configMapKeyRef: 64 | key: POSTGRES_SERVER 65 | name: conf-env 66 | - name: POSTGRES_USER 67 | valueFrom: 68 | configMapKeyRef: 69 | key: POSTGRES_USER 70 | name: conf-env 71 | - name: PROJECT_NAME 72 | valueFrom: 73 | configMapKeyRef: 74 | key: PROJECT_NAME 75 | name: conf-env 76 | - name: SECRET_KEY 77 | valueFrom: 78 | configMapKeyRef: 79 | key: SECRET_KEY 80 | name: conf-env 81 | - name: TOKEN_URL 82 | valueFrom: 83 | configMapKeyRef: 84 | key: TOKEN_URL 85 | name: conf-env 86 | image: sukima_app:latest 87 | imagePullPolicy: Never 88 | name: app 89 | workingDir: /sukima/ 90 | ports: 91 | - containerPort: 8000 92 | resources: {} 93 | restartPolicy: Always 94 | status: {} 95 | -------------------------------------------------------------------------------- /k8s/sukima-service.yaml: -------------------------------------------------------------------------------- 1 | apiVersion: v1 2 | kind: Service 3 | metadata: 4 | annotations: 5 | kompose.cmd: kompose convert -o k8s 6 | kompose.version: 1.24.0 (HEAD) 7 | creationTimestamp: null 8 | labels: 9 | io.kompose.service: sukima 10 | name: sukima 11 | spec: 12 | ports: 13 | - name: "8000" 14 | port: 8000 15 | targetPort: 8000 16 | selector: 17 | io.kompose.service: sukima 18 | status: 19 | loadBalancer: {} 20 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | fastapi[all] 2 | 
uvicorn[standard] 3 | python-jose[cryptography] 4 | asyncpg 5 | pydantic 6 | alembic 7 | sqlalchemy 8 | bitsandbytes-cuda113 9 | numpy 10 | transformers 11 | bcrypt 12 | passlib 13 | mmappickle 14 | psutil 15 | Pillow 16 | openai 17 | -------------------------------------------------------------------------------- /storage/.gitkeep: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hitomi-team/sukima/eb3e8e968b971748890885fcc1287bd27ae8dde0/storage/.gitkeep -------------------------------------------------------------------------------- /tests/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hitomi-team/sukima/eb3e8e968b971748890885fcc1287bd27ae8dde0/tests/__init__.py -------------------------------------------------------------------------------- /tests/test.py: -------------------------------------------------------------------------------- 1 | from fastapi.testclient import TestClient 2 | from app.main import app 3 | 4 | """ 5 | These tests are very bad. However, something is better than nothing. 6 | """ 7 | 8 | client = TestClient(app) 9 | 10 | 11 | def test_get_model_list(): 12 | response = client.get("/api/v1/models") 13 | assert response.status_code == 200 14 | 15 | 16 | # This assumes that the database table is ready and there is a user with username "test" and password "test"! 17 | def test_login(): 18 | response = client.post("/api/v1/users/token", json={"username": "test", "password": "test"}) 19 | 20 | assert response.status_code == 200 21 | assert response.json()["access_token"] is not None 22 | assert response.json()["token_type"] == "bearer" 23 | 24 | 25 | # It might be a good idea to make use of the application config for testing endpoints requiring auth. 26 | def test_login_load_infer(): 27 | login_response = client.post("/api/v1/users/token", json={"username": "test", "password": "test"}) 28 | 29 | assert login_response.status_code == 200 30 | token = login_response.json()["access_token"] 31 | 32 | test_model = "distilgpt2" 33 | load_model_response = client.post( 34 | "/api/v1/models/load", 35 | headers={"Authorization": f"Bearer {token}"}, 36 | json={"model": f"{test_model}", "parallel": "false", "sharded": "false"} 37 | ) 38 | 39 | assert load_model_response.status_code == 200 40 | 41 | infer_model_response = client.post( 42 | "/api/v1/models/generate", 43 | headers={"Authorization": f"Bearer {token}"}, 44 | json={ 45 | "model": f"{test_model}", 46 | "prompt": "Hello! My name is", 47 | "sample_args": { 48 | "temp": 0.51, 49 | "top_p": 0.9, 50 | "top_k": 140, 51 | "tfs": 0.993, 52 | "rep_p": 1.3, 53 | "rep_p_range": 1024, 54 | "rep_p_slope": 0.18, 55 | "bad_words": ["Jack"], 56 | "bias_words": ["Melissa"], 57 | "bias": 5.0 58 | }, 59 | "gen_args": { 60 | "max_length": 10 61 | } 62 | } 63 | ) 64 | 65 | assert infer_model_response.status_code == 200 66 | assert infer_model_response.json()["completion"] is not None 67 | assert infer_model_response.json()["completion"]["text"] is not None 68 | assert infer_model_response.json()["completion"]["time"] > 0 69 | --------------------------------------------------------------------------------
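Editorial note — client-side sketch (not part of the repository): the same login → load → generate flow that tests/test.py drives through TestClient, pointed at a locally running instance (e.g. started with docker-compose). Like the tests, it assumes a pre-seeded user "test"/"test"; the model name and sampling values are placeholders, and the request bodies follow ModelLoadRequest / ModelGenRequest in app/schemas/model_item.py.

import requests

BASE = "http://localhost:8000/api/v1"

# Authenticate and grab a bearer token.
token = requests.post(
    f"{BASE}/users/token",
    json={"username": "test", "password": "test"},
).json()["access_token"]
headers = {"Authorization": f"Bearer {token}"}

# Load a small model, then request a completion.
requests.post(
    f"{BASE}/models/load",
    headers=headers,
    json={"model": "distilgpt2", "parallel": False, "sharded": False},
)

resp = requests.post(
    f"{BASE}/models/generate",
    headers=headers,
    json={
        "model": "distilgpt2",
        "prompt": "Hello! My name is",
        "sample_args": {"temp": 0.8, "top_p": 0.9, "rep_p": 1.2},
        "gen_args": {"max_length": 32},
    },
)
print(resp.json()["completion"]["text"])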