├── .dockerignore
├── data
│   └── .gitkeep
├── deploy
│   ├── __init__.py
│   └── api.py
├── code
│   └── utils
│       ├── __init__.py
│       ├── train_deberta.py
│       └── train_llama.py
├── configs
│   ├── conf_deberta_v0.yaml
│   └── conf_llama_v0.yaml
├── requirements.txt
├── Makefile
├── docker
│   ├── Dockerfile.vllm
│   ├── Dockerfile.tensorflow
│   └── Dockerfile.torch
├── notebooks
│   └── vllm_inference.ipynb
├── README.md
├── docker-compose.yml
├── pyproject.toml
└── .gitignore

/.dockerignore:
--------------------------------------------------------------------------------

--------------------------------------------------------------------------------
/data/.gitkeep:
--------------------------------------------------------------------------------

--------------------------------------------------------------------------------
/deploy/__init__.py:
--------------------------------------------------------------------------------

--------------------------------------------------------------------------------
/code/utils/__init__.py:
--------------------------------------------------------------------------------

--------------------------------------------------------------------------------
/code/utils/train_deberta.py:
--------------------------------------------------------------------------------

--------------------------------------------------------------------------------
/configs/conf_deberta_v0.yaml:
--------------------------------------------------------------------------------

--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
hydra-core
omegaconf
transformers
--------------------------------------------------------------------------------
/Makefile:
--------------------------------------------------------------------------------
.PHONY: clean-logs format

CHECK_DIRS := code tests


clean-logs: ## Clean logs
	rm -rf logs/**

format: ## Run formatters and pre-commit hooks
	black $(CHECK_DIRS)
	ruff format $(CHECK_DIRS)
	pre-commit run -a
--------------------------------------------------------------------------------
/docker/Dockerfile.vllm:
--------------------------------------------------------------------------------
FROM vllm/vllm-openai:v0.9.2

WORKDIR /tmp/working

RUN pip install --no-cache-dir jupyterlab ipywidgets omegaconf hydra-core pandas scikit-learn matplotlib

EXPOSE 8888

CMD ["bash"]
--------------------------------------------------------------------------------
/notebooks/vllm_inference.ipynb:
--------------------------------------------------------------------------------
{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "e7d46d62",
   "metadata": {},
   "source": [
    "# vLLM inference"
   ]
  },
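  {
   "cell_type": "markdown",
   "id": "e7d46d63",
   "metadata": {},
   "source": [
    "A minimal offline-inference sketch using vLLM's Python API. The model name and prompt below are illustrative placeholders, not values taken from this repo's configs."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e7d46d64",
   "metadata": {},
   "outputs": [],
   "source": [
    "from vllm import LLM, SamplingParams\n",
    "\n",
    "# Placeholder checkpoint; swap in the model you actually want to serve.\n",
    "llm = LLM(model=\"Qwen/Qwen2.5-0.5B-Instruct\")\n",
    "sampling_params = SamplingParams(temperature=0.0, max_tokens=64)\n",
    "\n",
    "prompts = [\"Classify the sentiment of: I loved this movie.\"]\n",
    "outputs = llm.generate(prompts, sampling_params)\n",
    "for out in outputs:\n",
    "    print(out.outputs[0].text)"
   ]
  }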
 ],
 "metadata": {
  "language_info": {
   "name": "python"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
--------------------------------------------------------------------------------
/docker/Dockerfile.tensorflow:
--------------------------------------------------------------------------------
FROM tensorflow/tensorflow:2.18.1-gpu

WORKDIR /tmp/working
COPY requirements.txt .

RUN apt-get update && apt-get install -y \
    libgl1-mesa-dev \
    wget \
    vim && \
    apt-get clean

RUN pip install --no-cache-dir -r requirements.txt
RUN pip install --no-cache-dir tfts jupyterlab

EXPOSE 8888

CMD ["bash"]
--------------------------------------------------------------------------------
/docker/Dockerfile.torch:
--------------------------------------------------------------------------------
FROM nvcr.io/nvidia/pytorch:23.11-py3

RUN apt-get update && apt-get install -y \
    libgl1-mesa-dev \
    tmux && \
    apt-get clean

RUN pip install -U pip && \
    pip install transformers==4.44.2 polars==0.20.18 sentencepiece==0.1.99 \
    datasets==2.19.2 huggingface-hub==0.23.2 peft==0.12.0 bitsandbytes==0.43.1 accelerate==0.32.1 trl==0.11.4 omegaconf hydra-core

WORKDIR /tmp/working
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# Text classification with transformers

ML training template, especially for DL tasks.

## Setup

- If you want a different framework such as TensorFlow, change the Dockerfile referenced in `docker-compose.yml`, then build and start the container:

```
docker compose up --build
```

- To use the notebook, open `http://localhost:8888/?token=12345`

## Pipeline

- synthetic data
- train
  - train deberta
  - train llama
- inference
  - vllm inference
  - batch inference
--------------------------------------------------------------------------------
/docker-compose.yml:
--------------------------------------------------------------------------------
version: "2.3"
services:
  dl-task:
    build:
      context: .
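      # To build with the TensorFlow image instead (see README), swap in:
      #   dockerfile: ./docker/Dockerfile.tensorflow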
      dockerfile: ./docker/Dockerfile.torch
    image: dl-task:latest
    container_name: dl-task
    shm_size: '32gb'
    runtime: nvidia
    environment:
      - LD_LIBRARY_PATH=/usr/local/cuda/lib64
    volumes:
      - ./:/tmp/working
    working_dir: /tmp/working
    ports:
      - 8888:8888
    entrypoint: ""
    command: jupyter lab --ip=0.0.0.0 --port 8888 --allow-root --NotebookApp.notebook_dir='/tmp/working' --no-browser --NotebookApp.token='12345'
--------------------------------------------------------------------------------
/code/utils/train_llama.py:
--------------------------------------------------------------------------------
import logging
import os
import time

import hydra
import pandas as pd
import torch
import transformers
from omegaconf import DictConfig, OmegaConf
from sklearn.model_selection import train_test_split
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    BitsAndBytesConfig,
    Trainer,
    TrainingArguments,
    get_linear_schedule_with_warmup,
    set_seed,
)

# Local helpers (qwen_classifier.py and utils.py are not part of this snapshot).
from qwen_classifier import *
from utils import *


os.environ["TOKENIZERS_PARALLELISM"] = "false"
os.environ["WANDB_DISABLED"] = "true"
logger = set_logger()


# config_path is resolved relative to this file, so the repo-root configs/
# directory is two levels up.
@hydra.main(version_base=None, config_path="../../configs", config_name="conf_llama_v0")
def main(cfg: DictConfig) -> None:
    set_seed(cfg.seed)
    tokenizer = AutoTokenizer.from_pretrained(
        cfg.model.tokenizer_name_or_path,
        use_fast=False,
    )
    tokenizer.add_eos_token = True
    tokenizer.padding_side = "right"

    df = prepare_data(cfg)
    if cfg.debug:
        df = df.sample(100, random_state=cfg.seed)
        cfg.train_params.batch_size = 1

    # The model setup and training loop are elided in this snapshot.
    return


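# Example invocation with Hydra overrides, run from the repo root (a hedged
# example; the override keys follow conf_llama_v0.yaml):
#   python code/utils/train_llama.py debug=false train_params.num_epochs=1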
if __name__ == "__main__":
    main()
--------------------------------------------------------------------------------
/configs/conf_llama_v0.yaml:
--------------------------------------------------------------------------------
seed: 42
debug: true
use_wandb: false


dataset:
  df_name_or_path: ../data/input/train.csv


model:
  name: qwen_v0
  pretrained_model_name_or_path: Qwen/Qwen-7B-v0.5
  tokenizer_name_or_path: Qwen/Qwen-7B-v0.5
  model_type: qwen
  config:
    hidden_size: 4096
    num_attention_heads: 32
    num_hidden_layers: 32
    intermediate_size: 11008
    max_position_embeddings: 2048


train_params:
  batch_size: 32
  num_epochs: 3
  gradient_accumulation_steps: 1
  max_seq_length: 128
  eval_steps: 100
  save_steps: 500
  logging_steps: 50
  early_stopping_patience: 3
  fp16: true


optimizer:
  type: adamw
  lr: 5e-5
  weight_decay: 0.01
  adam_beta1: 0.9
  adam_beta2: 0.999
  adam_epsilon: 1e-8
  max_grad_norm: 1.0
  scheduler:
    name: cosine
    warmup_steps: 0
    warmup_ratio: 0.1
    num_training_steps: 1000


wandb:
  project: text-classification
  run_name: qwen_v0
  tags:
    - qwen_v0

hydra:
  run:
    dir: ../outputs/${hydra.job.name}/${now:%Y-%m-%d_%H-%M-%S}
  output_subdir: .hydra
  verbose: false
--------------------------------------------------------------------------------
/pyproject.toml:
--------------------------------------------------------------------------------
[build-system]
requires = ["setuptools>=61.0", "wheel"]
build-backend = "setuptools.build_meta"

[tool.setuptools]
include-package-data = true
packages.find.where = ["."]
packages.find.include = ["code*"]

[project]
name = "text-classification"
readme = "README.md"
version = "0.0.1"
authors = [
    { name = "Hongying Yue", email = "yuehongyingyhy@gmail.com" }
]
description = "Text classification using Transformer and LLM"
dependencies = [
    "pyyaml",
    "numpy",
    "pandas",
    "matplotlib",
    "psutil",
]

[project.optional-dependencies]
dev = [
    "pre-commit",
    "black",
    "ruff",
    "pylint",
    "isort",
]

[tool.black]
line-length = 120

[tool.ruff]
line-length = 120

[tool.ruff.lint]
select = [
    "E",   # pycodestyle error
    "W",   # pycodestyle warning
    "F",   # pyflakes
    "A",   # flake8-builtins
    "COM", # flake8-commas
    "C4",  # flake8-comprehensions
    "Q",   # flake8-quotes
    "SIM", # flake8-simplify
    "PTH", # flake8-use-pathlib
    "I",   # isort
    "N",   # pep8 naming
    "UP",  # pyupgrade
    "S",   # bandit
]
ignore = [
    "COM812", # conflicts with the formatter
]
--------------------------------------------------------------------------------
/deploy/api.py:
--------------------------------------------------------------------------------
import datetime
import json
import logging
from logging.handlers import TimedRotatingFileHandler
from pathlib import Path
from typing import Any, Union

from fastapi import FastAPI, Request
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel

from util import calc_accuracy4math, calc_cloud_score, format_reward_deepseek

logger = logging.getLogger(__name__)

# Make sure the log directory exists before attaching the file handler.
Path("data/logs").mkdir(parents=True, exist_ok=True)
file_handler = TimedRotatingFileHandler(
    "data/logs/text_classification.log",
    when="midnight",
    interval=1,
    backupCount=5,
)
file_handler.setLevel(logging.INFO)
formatter = logging.Formatter("%(asctime)s - %(name)s - %(levelname)s - %(message)s")
file_handler.setFormatter(formatter)
logger.addHandler(file_handler)


logging.basicConfig(
    format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
    datefmt="%m/%d/%Y %H:%M:%S",
    level=logging.INFO,
)

app = FastAPI()
origins = ["*"]

app.add_middleware(
    CORSMiddleware,
    allow_origins=origins,
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)


class BaseRequest(BaseModel):
    """Expected request schema (the endpoint currently reads the raw JSON body)."""

    data_source: Any
    solution_str: str
    ground_truth: Union[dict, list, str]
    extra_info: Union[dict, list, str]


@app.post("/get_classification")
async def get_reward(request: Request):
    json_data = await request.json()

    ground_truth: str = json_data.get("ground_truth", "")
    pred_answer = json_data.get("response_str", "")

    if isinstance(pred_answer, list):
        pred_answer = pred_answer[0]

    # The per-metric scoring (via the helpers imported from util) is elided in
    # this snapshot; start from an empty dict so the aggregation below runs.
    score: dict = {}
    score["score"] = sum(score.values())

    cur_date = datetime.datetime.now()

    temp_data = {
        "cur_date": cur_date.strftime("%Y-%m-%d %X"),
        "input_data": json_data,
        "score": score,
    }
    logger.info(json.dumps(temp_data, ensure_ascii=False))

    return score


if __name__ == "__main__":
    import uvicorn
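    # Example request once the server is running (a hedged illustration; the
    # fields mirror what the handler reads from the JSON body):
    #   curl -X POST http://localhost:6009/get_classification \
    #     -H "Content-Type: application/json" \
    #     -d '{"ground_truth": "A", "response_str": "A"}'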
    uvicorn.run(app, host="0.0.0.0", port=6009)
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
# Custom
run.sh

# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class

# C extensions
*.so

# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
pip-wheel-metadata/
share/python-wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST

# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec

# Installer logs
pip-log.txt
pip-delete-this-directory.txt

# Unit test / coverage reports
htmlcov/
.tox/
.nox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
*.py,cover
.hypothesis/
.pytest_cache/

# Translations
*.mo
*.pot

# Django stuff:
*.log
local_settings.py
db.sqlite3
db.sqlite3-journal

# Flask stuff:
instance/
.webassets-cache

# Scrapy stuff:
.scrapy

# Sphinx documentation
docs/_build/

# PyBuilder
target/

# Jupyter Notebook
.ipynb_checkpoints

# IPython
profile_default/
ipython_config.py

# pyenv
.python-version

# pipenv
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in
# version control. However, in case of collaboration, if having
# platform-specific dependencies or dependencies having no cross-platform
# support, pipenv may install dependencies that don't work, or not install
# all needed dependencies.
#Pipfile.lock

# PEP 582; used by e.g. github.com/David-OConnor/pyflow
__pypackages__/

# Celery stuff
celerybeat-schedule
celerybeat.pid

# SageMath parsed files
*.sage.py

# Environments
.venv
env/
venv/
ENV/
env.bak/
venv.bak/

# Spyder project settings
.spyderproject
.spyproject

# Rope project settings
.ropeproject

# mkdocs documentation
/site

# mypy
.mypy_cache/
.dmypy.json
dmypy.json

# Pyre type checker
.pyre/

### VisualStudioCode
.vscode/*
!.vscode/settings.json
!.vscode/tasks.json
!.vscode/launch.json
!.vscode/extensions.json
*.code-workspace
**/.vscode

# JetBrains
.idea/

# Data & Models
*.h5
*.h5py
*.tar
*.tar.gz

# Hydra-Template
configs/local/default.yaml
.env
.autoenv
.DS_Store
--------------------------------------------------------------------------------