├── configs
│   ├── test.yaml
│   └── train.yaml
├── data
│   ├── README.md
│   └── get_data.sh
├── .python-version
├── scripts
│   └── pretraining
│       ├── test.py
│       ├── train.py
│       └── data_processing.py
├── static
│   └── electra.png
├── tpu.Dockerfile
├── .devcontainer
│   ├── init
│   ├── packages.txt
│   ├── Dockerfile
│   ├── local.devcontainer.json
│   └── devcontainer.json
├── helm
│   └── electra-training
│       ├── Chart.yaml
│       ├── templates
│       │   ├── pvc.yaml
│       │   └── job.yaml
│       └── values.yaml
├── notebooks
│   ├── test.json
│   ├── test.py
│   ├── dataprocess.py
│   └── data_processing.py
├── .github
│   ├── workflows
│   │   └── docs-generator.yaml
│   └── dependabot.yml
├── src
│   └── ElectraKAN
│       ├── __init__.py
│       ├── handlers.py
│       ├── callbacks.py
│       ├── modules.py
│       ├── datamodule.py
│       └── kan.py
├── .gitignore
├── nvidia.Dockerfile
├── pyproject.toml
├── CITATION.cff
├── test
│   └── test_streaming_dataset.py
└── README.md
--------------------------------------------------------------------------------
/configs/test.yaml:
--------------------------------------------------------------------------------
1 | 
--------------------------------------------------------------------------------
/data/README.md:
--------------------------------------------------------------------------------
1 | 
--------------------------------------------------------------------------------
/.python-version:
--------------------------------------------------------------------------------
1 | 3.13
2 | 
--------------------------------------------------------------------------------
/scripts/pretraining/test.py:
--------------------------------------------------------------------------------
1 | 
--------------------------------------------------------------------------------
/static/electra.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Klassikcat/KANElectra/HEAD/static/electra.png
--------------------------------------------------------------------------------
/tpu.Dockerfile:
--------------------------------------------------------------------------------
1 | FROM public.ecr.aws/neuron/pytorch-training-neuronx:2.6.0-neuronx-py310-sdk2.23.0-ubuntu22.04
2 | 
--------------------------------------------------------------------------------
/.devcontainer/init:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | 
3 | set -e -u -o pipefail
4 | 
5 | curl -LsSf https://astral.sh/uv/install.sh | sh
6 | . "$HOME/.local/bin/env"
7 | git config --global core.editor 'vim'
8 | uv sync
9 | 
--------------------------------------------------------------------------------
/helm/electra-training/Chart.yaml:
--------------------------------------------------------------------------------
1 | apiVersion: v2
2 | name: electra-training
3 | description: A Helm chart for running the ElectraKAN training job
4 | type: application
5 | version: 0.1.0
6 | appVersion: "1.0.0"
--------------------------------------------------------------------------------
/notebooks/test.json:
--------------------------------------------------------------------------------
1 | [
2 |   {
3 |     "text": "Hello, world!"
4 |   },
5 |   {
6 |     "text": "This is a text."
7 |   },
8 |   {
9 |     "text": "This is a{ test."
10 |   }
11 | ]
--------------------------------------------------------------------------------
/.devcontainer/packages.txt:
--------------------------------------------------------------------------------
1 | python3
2 | python3-pip
3 | python3-venv
4 | python3-dev
5 | git
6 | gh
7 | vim
8 | apt-transport-https
9 | ca-certificates
10 | curl
11 | g++
12 | bash
13 | neovim
--------------------------------------------------------------------------------
/.github/workflows/docs-generator.yaml:
--------------------------------------------------------------------------------
1 | on:
2 |   workflow_dispatch:
3 | jobs:
4 |   generate-docs:
5 |     runs-on: ubuntu-latest
6 |     steps:
7 |       - name: Checkout code
8 |         uses: actions/checkout@v3
--------------------------------------------------------------------------------
/data/get_data.sh:
--------------------------------------------------------------------------------
1 | wget -P ./data https://rajpurkar.github.io/SQuAD-explorer/dataset/train-v2.0.json
2 | wget -P ./data https://rajpurkar.github.io/SQuAD-explorer/dataset/dev-v2.0.json
3 | python scripts.py --input-path=./data/train-v2.0.json --output-path=./data/train.txt
4 | python scripts.py --input-path=./data/dev-v2.0.json --output-path=./data/dev.txt
5 | 
--------------------------------------------------------------------------------
/helm/electra-training/templates/pvc.yaml:
--------------------------------------------------------------------------------
1 | apiVersion: v1
2 | kind: PersistentVolumeClaim
3 | metadata:
4 |   name: checkpoints-kanelectra
5 |   annotations:
6 |     "helm.sh/hook": "pre-install"
7 |     "helm.sh/hook-weight": "-10"
8 |     "helm.sh/hook-delete-policy": "before-hook-creation"
9 | spec:
10 |   accessModes:
11 |     - ReadWriteOnce
12 |   resources:
13 |     requests:
14 |       storage: 10Gi # adjust the size as needed
--------------------------------------------------------------------------------
/src/ElectraKAN/__init__.py:
--------------------------------------------------------------------------------
1 | """
2 | ElectraKAN - Electra model using KAN model instead of Fully Connected Layer
3 | """
4 | 
5 | __version__ = "0.1.0"
6 | 
7 | from .callbacks import (
8 |     PretrainingCheckpoint,
9 |     OnnxCompiler
10 | )
11 | from .handlers import (
12 |     ElectraModel,
13 | )
14 | from .datamodule import (
15 |     StreamingElectraClassificationDataset,
16 |     ElectraKANDataModule,
17 |     StreamingElectraPretrainingDataset,
18 | )
19 | 
--------------------------------------------------------------------------------
/.github/dependabot.yml:
--------------------------------------------------------------------------------
1 | # To get started with Dependabot version updates, you'll need to specify which
2 | # package ecosystems to update and where the package manifests are located.
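# (illustrative note, not part of the original file: other valid package-ecosystem
#  values, such as "pip" or "github-actions", would each get their own entry
#  under `updates` below)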
3 | # Please see the documentation for more information:
4 | # https://docs.github.com/github/administering-a-repository/configuration-options-for-dependency-updates
5 | # https://containers.dev/guide/dependabot
6 | 
7 | version: 2
8 | updates:
9 |   - package-ecosystem: "devcontainers"
10 |     directory: "/"
11 |     schedule:
12 |       interval: weekly
13 | 
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | .lightning_studio/
2 | .wakatime.cfg
3 | .wakatime.bdb
4 | .wakatime
5 | .aws
6 | .dotnet
7 | .yarn
8 | .npm
9 | .redhat/
10 | .quokka/
11 | .wallaby
12 | .wget-hsts
13 | .bash_history
14 | .idea/
15 | .vscode/
16 | .vscode-server/
17 | .viminfo
18 | /__pycache__/
19 | src/ElectraKAN/__pycache__/
20 | .venv/
21 | /data/MNLI
22 | /data/*.zip
23 | /data/*.tar.gz
24 | /data/*.json
25 | /data/*.txt
26 | /data/*.parquet
27 | .zed_server/
28 | .ropeproject/
29 | .mypy_cache/
30 | */__pycache__/
31 | */*/__pycache__/
32 | outputs/
33 | 
--------------------------------------------------------------------------------
/nvidia.Dockerfile:
--------------------------------------------------------------------------------
1 | FROM nvidia/cuda:12.6.2-cudnn-devel-ubuntu22.04 AS builder
2 | RUN apt-get update && apt-get install -y python3 python3-pip
3 | WORKDIR /app
4 | COPY . .
5 | RUN pip install uv && \
6 |     uv sync --no-dev && \
7 |     uv pip install -e .
8 | 
9 | FROM nvidia/cuda:12.6.2-cudnn-runtime-ubuntu22.04
10 | RUN apt-get update && apt-get install -y --no-install-recommends python3 && rm -rf /var/lib/apt/lists/*
11 | # uv sync installs into the project's .venv, so ship that environment
12 | COPY --from=builder /app/.venv /app/.venv
13 | ENV PATH="/app/.venv/bin:$PATH"
14 | COPY ./scripts /scripts
15 | 
16 | RUN chmod +x /scripts/pretraining/train.py
17 | # uncomment to run the training script by default when the container starts
18 | # CMD ["python", "/scripts/pretraining/train.py"]
--------------------------------------------------------------------------------
/notebooks/test.py:
--------------------------------------------------------------------------------
1 | import sys
2 | from omegaconf import DictConfig
3 | from pathlib import Path
4 | import tqdm
5 | sys.path.append(str(Path(__file__).parent.parent))
6 | import pyarrow as pa
7 | from pyarrow import parquet as pq
8 | from scripts.pretraining.data_processing import DatasetDownloader, DatasetPreprocessor
9 | 
10 | downloader = DatasetDownloader(dataset_config=DictConfig({
11 |     "raw_data": {
12 |         "dataset_name": "wikipedia",
13 |         "dataset_version": "20220301.en",
14 |         "dataset_split": {
15 |             "train": "train[:1%]",
16 |             "val": "train[1%:2%]",
17 |             "test": "train[:1%]"
18 |         }
19 |     }
20 | }))
21 | preprocessor = DatasetPreprocessor(tokenizer_path="google/electra-base-discriminator")
22 | datasets = downloader()
23 | 
24 | for dataset_name, dataset in tqdm.tqdm(datasets.items()):
25 |     array = preprocessor(dataset)  # yields dicts, so materialize before building the table
26 |     pq.write_table(pa.Table.from_pylist(list(array)), f"data/{dataset_name}.parquet")
--------------------------------------------------------------------------------
/helm/electra-training/templates/job.yaml:
--------------------------------------------------------------------------------
1 | apiVersion: batch/v1
2 | kind: Job
3 | metadata:
4 |   name: {{ .Values.job.name }}
5 | spec:
6 |   backoffLimit: {{ .Values.job.backoffLimit }}
7 |   ttlSecondsAfterFinished: {{ .Values.job.ttlSecondsAfterFinished }}
8 |   template:
9 |     spec:
10 |       nodeSelector:
11 |         env: batch
12 | 
containers: 13 | - name: training-kanelectra-autoencoder 14 | image: "{{ .Values.image.repository }}:{{ .Values.image.tag }}" 15 | imagePullPolicy: {{ .Values.image.pullPolicy }} 16 | command: ["python", "scripts/pretraining/train.py"] 17 | resources: 18 | {{- toYaml .Values.resources | nindent 10 }} 19 | volumeMounts: 20 | - name: data 21 | mountPath: /data 22 | - name: outputs 23 | mountPath: /outputs 24 | volumes: 25 | - name: data 26 | emptyDir: {} 27 | - name: outputs 28 | persistentVolumeClaim: 29 | claimName: checkpoints-kanelectra 30 | restartPolicy: Never -------------------------------------------------------------------------------- /.devcontainer/Dockerfile: -------------------------------------------------------------------------------- 1 | ARG IMAGE= 2 | FROM ${IMAGE} 3 | 4 | COPY .devcontainer/packages.txt /tmp/packages.txt 5 | 6 | RUN (type -p wget >/dev/null || (apt update && apt install wget -y)) \ 7 | && mkdir -p -m 755 /etc/apt/keyrings \ 8 | && out=$(mktemp) && wget -nv -O$out https://cli.github.com/packages/githubcli-archive-keyring.gpg \ 9 | && cat $out | tee /etc/apt/keyrings/githubcli-archive-keyring.gpg > /dev/null \ 10 | && chmod go+r /etc/apt/keyrings/githubcli-archive-keyring.gpg \ 11 | && echo "deb [arch=$(dpkg --print-architecture) signed-by=/etc/apt/keyrings/githubcli-archive-keyring.gpg] https://cli.github.com/packages stable main" | tee /etc/apt/sources.list.d/github-cli.list > /dev/null \ 12 | && apt update \ 13 | && apt-get install -y $(cat /tmp/packages.txt) 14 | 15 | RUN usermod -l vscode ubuntu && \ 16 | usermod -d /home/vscode -m vscode && \ 17 | usermod -aG sudo vscode && \ 18 | echo "vscode ALL=(ALL) NOPASSWD:ALL" >> /etc/sudoers 19 | 20 | COPY --chown=vscode:vscode .devcontainer/init /usr/local/bin/init 21 | RUN chmod +x /usr/local/bin/init 22 | 23 | USER vscode 24 | -------------------------------------------------------------------------------- /helm/electra-training/values.yaml: -------------------------------------------------------------------------------- 1 | image: 2 | repository: electrakan 3 | tag: latest 4 | pullPolicy: IfNotPresent 5 | 6 | job: 7 | name: electra-training 8 | backoffLimit: 4 9 | ttlSecondsAfterFinished: 100 10 | 11 | resources: 12 | limits: 13 | cpu: "8" 14 | memory: "16Gi" 15 | nvidia.com/gpu: "4" 16 | requests: 17 | cpu: "8" 18 | memory: "8Gi" 19 | nvidia.com/gpu: "4" 20 | 21 | config: 22 | tokenizer_name: "klue/bert-base" 23 | datasets: 24 | train: 25 | path: "/data/train.csv" 26 | max_length: 512 27 | text_row: "text" 28 | val: 29 | path: "/data/val.csv" 30 | max_length: 512 31 | text_row: "text" 32 | test: 33 | path: "/data/test.csv" 34 | max_length: 512 35 | text_row: "text" 36 | datamodule: 37 | batch_size: 32 38 | num_workers: 4 39 | pin_memory: true 40 | trainer: 41 | max_epochs: 10 42 | accelerator: "gpu" 43 | devices: 1 44 | precision: "16-mixed" 45 | callbacks: 46 | - name: ModelCheckpoint 47 | params: 48 | dirpath: "/outputs/checkpoints" 49 | filename: "electra-{epoch:02d}-{val_loss:.2f}" 50 | save_top_k: 3 51 | monitor: "val_loss" 52 | mode: "min" -------------------------------------------------------------------------------- /.devcontainer/local.devcontainer.json: -------------------------------------------------------------------------------- 1 | { 2 | "name": "ElectraKAN-local-DevContainer", 3 | "build": { 4 | "dockerfile": "Dockerfile", 5 | "context": "..", 6 | "args": { 7 | "IMAGE": "mcr.microsoft.com/devcontainers/python:dev-3.13-bookworm" 8 | } 9 | }, 10 | "runArgs": [ 11 | 
"--name=ElectraKAN-local-DevContainer" 12 | ], 13 | "features": { 14 | "ghcr.io/devcontainers/features/docker-in-docker:2": { 15 | "moby": true, 16 | "dockerDashComposeVersion": "latest" 17 | }, 18 | "ghcr.io/devcontainers/features/aws-cli:1": {}, 19 | "ghcr.io/devcontainers/features/kubectl-helm-minikube:1": {} 20 | }, 21 | "customizations": { 22 | "vscode": { 23 | "extensions": [ 24 | "ms-python.python", 25 | "ms-python.black-formatter", 26 | "ms-python.mypy-linter", 27 | "wakatime.vscode-wakatime", 28 | "esbenp.prettier-vscode", 29 | "github.vscode-github-actions", 30 | "me-dutour-mathieu.vscode-github-actions", 31 | "oderwat.indent-rainbow" 32 | ] 33 | } 34 | }, 35 | "postCreateCommand": [ 36 | "init" 37 | ] 38 | } 39 | -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [project] 2 | name = "electrakan" 3 | version = "0.1.0" 4 | description = "ElectraKAN is Electra model using KAN model instead of Fully Connected Layer" 5 | readme = "README.md" 6 | requires-python = ">=3.12" 7 | authors = [ 8 | {name = "shin jung tae", email = "shinjeongtae@gmail.com"} 9 | ] 10 | dependencies = [ 11 | "tqdm>=4.67.1", 12 | "torch>=2.5.1", 13 | "lightning>=2.3.1", 14 | "hydra-core>=1.3.2", 15 | "ruamel.yaml>=0.18.3", 16 | "omegaconf>=2.3.0", 17 | "torchmetrics>=1.1.1", 18 | "transformers>=4.45.2", 19 | "ujson>=5.7.0", 20 | "polars>=1.13.0", 21 | "datasets>=3.3.1", 22 | "orjson>=3.10.0", 23 | "pyarrow>=19.0.1", 24 | "jsonlines>=4.0.0", 25 | ] 26 | 27 | [dependency-groups] 28 | dev = [ 29 | "ipywidgets", 30 | "pytest", 31 | "neuronx-cc", 32 | "torch-neuronx", 33 | "marimo", 34 | "polars", 35 | "line-profiler", 36 | "memory-profiler", 37 | "ipykernel", 38 | "ipython", 39 | "jupyter", 40 | "jupyterlab", 41 | "jupyterlab-git", 42 | "typer" 43 | ] 44 | 45 | [build-system] 46 | requires = ["setuptools", "wheel", "poetry-core>=1.0.0"] 47 | build-backend = "poetry.core.masonry.api" 48 | 49 | [tool.poetry] 50 | name = "electrakan" 51 | version = "0.1.0" 52 | description = "ElectraKAN is Electra model using KAN model instead of Fully Connected Layer" 53 | authors = ["shin jung tae "] 54 | packages = [ 55 | { include = "ElectraKAN", from = "src" } 56 | ] 57 | 58 | [tool.poetry.scripts] 59 | train = "ElectraKAN.train:main" -------------------------------------------------------------------------------- /.devcontainer/devcontainer.json: -------------------------------------------------------------------------------- 1 | { 2 | "name": "ElectraKAN-Nvidia-DevContainer", 3 | "build": { 4 | "dockerfile": "Dockerfile", 5 | "context": "..", 6 | "args": { 7 | "IMAGE": "nvidia/cuda:12.8.1-cudnn-devel-ubuntu24.04" 8 | } 9 | 10 | }, 11 | "features": { 12 | "ghcr.io/devcontainers/features/docker-in-docker:2": { 13 | "moby": true, 14 | "dockerDashComposeVersion": "latest" 15 | }, 16 | "ghcr.io/devcontainers/features/aws-cli:1.1.2": {}, 17 | "ghcr.io/devcontainers/features/kubectl-helm-minikube:1": {} 18 | }, 19 | "runArgs": [ 20 | "--gpus=all", 21 | "--name=ElectraKAN-Nvidia-DevContainer" 22 | ], 23 | "remoteEnv": { 24 | "PATH": "${containerEnv:PATH}:/usr/local/cuda/bin", 25 | "LD_LIBRARY_PATH": "$LD_LIBRARY_PATH:/usr/local/cuda/lib64:/usr/local/cuda/extras/CUPTI/lib64", 26 | "XLA_FLAGS": "--xla_gpu_cuda_data_dir=/usr/local/cuda" 27 | }, 28 | "customizations": { 29 | "vscode": { 30 | "extensions": [ 31 | "ms-python.python", 32 | "ms-python.black-formatter", 33 | "ms-python.mypy-linter", 34 | 
"wakatime.vscode-wakatime", 35 | "esbenp.prettier-vscode", 36 | "github.vscode-github-actions", 37 | "me-dutour-mathieu.vscode-github-actions", 38 | "oderwat.indent-rainbow" 39 | ] 40 | } 41 | }, 42 | "postCreateCommand": [ 43 | "init" 44 | ] 45 | } 46 | -------------------------------------------------------------------------------- /CITATION.cff: -------------------------------------------------------------------------------- 1 | cff-version: 1.1.0 2 | references: 3 | - type: misc 4 | authors: 5 | - family-names: "Park" 6 | given-names: "Jangwon" 7 | title: "KoELECTRA: Pretrained ELECTRA Model for Korean" 8 | year: 2020 9 | version: 1.0.0 10 | publisher: "GitHub" 11 | url: "https://github.com/monologg/KoELECTRA" 12 | - type: inproceedings 13 | authors: 14 | - family-names: "Clark" 15 | given-names: "Kevin" 16 | - family-names: "Luong" 17 | given-names: "Minh-Thang" 18 | - family-names: "Le" 19 | given-names: "Quoc V." 20 | - family-names: "Manning" 21 | given-names: "Christopher D." 22 | title: "ELECTRA: Pre-training Text Encoders as Discriminators Rather Than Generators" 23 | year: 2020 24 | booktitle: "ICLR" 25 | url: "https://openreview.net/pdf?id=r1xMH1BtvB" 26 | - type: article 27 | authors: 28 | - family-names: "Liu" 29 | given-names: "Ziming" 30 | - family-names: "Wang" 31 | given-names: "Yixuan" 32 | - family-names: "Vaidya" 33 | given-names: "Sachin" 34 | - family-names: "Ruehle" 35 | given-names: "Fabian" 36 | - family-names: "Halverson" 37 | given-names: "James" 38 | - family-names: "Soljačić" 39 | given-names: "Marin" 40 | - family-names: "Hou" 41 | given-names: "Thomas Y" 42 | - family-names: "Tegmark" 43 | given-names: "Max" 44 | title: "KAN: Kolmogorov-Arnold Networks" 45 | year: 2024 46 | journal: "arXiv preprint arXiv:2404.19756" 47 | url: "https://arxiv.org/abs/2404.19756" 48 | -------------------------------------------------------------------------------- /configs/train.yaml: -------------------------------------------------------------------------------- 1 | max_length: 512 2 | tokenizer_name: "google/electra-base-discriminator" 3 | raw_data: 4 | dataset_name: "wikipedia" 5 | dataset_version: "20220301.en" 6 | dataset_split: 7 | train: "train[:80%]" 8 | val: "train[80%:]" 9 | test: "test" 10 | datasets: 11 | train: 12 | path: /workspaces/https---github-com-klassikcat-kanelectra/data/dev.txt 13 | max_length: ${max_length} 14 | text_row: 0 15 | val: 16 | path: /workspaces/https---github-com-klassikcat-kanelectra/data/dev.txt 17 | max_length: ${max_length} 18 | text_row: 0 19 | test: 20 | path: /workspaces/https---github-com-klassikcat-kanelectra/data/dev.txt 21 | max_length: ${max_length} 22 | text_row: 0 23 | datamodule: 24 | batch_size: 1 25 | num_workers: 4 26 | pin_memory: True 27 | chunk_size: 1000 # 스트리밍 데이터셋의 청크 크기 28 | nn: 29 | generator_lr: 1e-3 30 | discriminator_lr: 1e-3 31 | mask_token_id: 103 # [MASK] token of google/electra-base-discriminator 32 | generator: 33 | vocab_size: 35000 34 | embedding_dim: 768 35 | vocab_type_size: 2 36 | embedding_dropout_p: 0.1 37 | hidden_dim: 768 38 | num_heads: 4 39 | ff_dim: 1024 40 | num_layers: 12 41 | max_pos_embedding: 512 42 | discriminator: 43 | vocab_size: 35000 44 | num_labels: 2 45 | embedding_dim: 768 46 | vocab_type_size: 2 47 | embedding_dropout_p: 0.1 48 | hidden_dim: 768 49 | num_heads: 4 50 | ff_dim: 1024 51 | num_layers: 12 52 | max_pos_embedding: 512 53 | trainer: 54 | accelerator: auto 55 | devices: 3 56 | max_epochs: 10 57 | strategy: ddp 58 | precision: 32 59 | enable_progress_bar: true 60 | callbacks: 
61 |     - name: ModelCheckpoint
62 |       params:
63 |         monitor: val_loss
64 |         mode: min
65 |         save_top_k: 1
66 |         dirpath: checkpoints
67 |         filename: "{epoch}-{val_loss:.2f}"
68 | 
--------------------------------------------------------------------------------
/notebooks/dataprocess.py:
--------------------------------------------------------------------------------
1 | #%%
2 | import polars as pl
3 | from datasets import load_dataset
4 | 
5 | 
6 | datasets_us = load_dataset("wikipedia", "20220301.en", split="train[:1%]")
7 | 
8 | #%%
9 | 
10 | from transformers import AutoTokenizer
11 | 
12 | text_sample = datasets_us[0]["text"]
13 | tokenizer = AutoTokenizer.from_pretrained("google/electra-base-discriminator")
14 | tokenized_text = tokenizer(datasets_us[0]["text"])
15 | 
16 | print(len(tokenized_text.input_ids))
17 | 
18 | #%%
19 | 
20 | def is_title(word: str, titles: list[str]) -> bool:
21 |     """Check whether a word is an honorific title."""
22 |     return word.lower().rstrip('.') in titles
23 | 
24 | 
25 | def is_sentence_boundary(char: str, next_char: str, current_sentence: str, titles: list[str]) -> bool:
26 |     """Decide whether this position is a sentence boundary."""
27 |     if char != '.' or next_char not in [' ', '\n']:
28 |         return False
29 | 
30 |     words = current_sentence.strip().split()
31 |     if not words:
32 |         return False
33 | 
34 |     return not is_title(words[-1], titles)
35 | 
36 | 
37 | def split_sentences(text: str) -> list[str]:
38 |     """Split text into sentences."""
39 |     titles = ['mr', 'mrs', 'ms', 'miss', 'dr', 'prof', 'rev', 'hon']
40 | 
41 |     result = []
42 |     current = ""
43 | 
44 |     for i, char in enumerate(text):
45 |         current += char
46 | 
47 |         if i + 1 < len(text) and is_sentence_boundary(char, text[i + 1], current, titles):
48 |             result.append(current.strip())
49 |             current = ""
50 | 
51 |     if current.strip():
52 |         result.append(current.strip())
53 | 
54 |     return result
55 | 
56 | # Test
57 | text = "안녕하세요. Mr. Smith는 의사입니다.\n그의 메일은 mr.smith@gmail.com 입니다. Dr. Lee와 Prof. Kim이 왔어요. 그리고 Mrs. Park도 왔습니다."
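# Expected behavior: the abbreviations Mr., Dr., Prof. and Mrs. are in the titles
# list, so the splitter should keep them inside their sentences instead of
# treating their trailing periods as sentence boundaries.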
58 | sentences = split_sentences(text)
59 | for sentence in sentences:
60 |     print(sentence)
61 | 
62 | #%%
63 | 
64 | import tqdm
65 | from line_profiler import LineProfiler
66 | 
67 | profiler = LineProfiler()
68 | 
69 | splited_sentences = split_sentences(text_sample)
70 | 
71 | def sample_func():
72 |     total_texts = [tokenizer(i, max_length=512, return_attention_mask=True, return_token_type_ids=True, truncation=True, padding="max_length").input_ids for i in splited_sentences]
73 |     return total_texts
74 | 
75 | wrapped = profiler(sample_func)
76 | result = wrapped()
77 | profiler.print_stats()
78 | 
79 | 
80 | # %%
81 | import sys
82 | import tqdm
83 | from omegaconf import DictConfig
84 | from pathlib import Path
85 | sys.path.append(str(Path(__file__).parent.parent))
86 | import pyarrow as pa
87 | from pyarrow import parquet as pq
88 | from scripts.pretraining.data_processing import DatasetDownloader, DatasetPreprocessor
89 | 
90 | downloader = DatasetDownloader(dataset_config=DictConfig({
91 |     "raw_data": {
92 |         "dataset_name": "wikipedia",
93 |         "dataset_version": "20220301.en",
94 |         "dataset_split": {
95 |             "train": "train[:1%]",
96 |             "val": "train[1%:2%]",
97 |             "test": "train[:1%]"
98 |         }
99 |     }
100 | }))
101 | preprocessor = DatasetPreprocessor(tokenizer_path="google/electra-base-discriminator")
102 | datasets = downloader()
103 | 
104 | for dataset_name, dataset in tqdm.tqdm(datasets.items()):
105 |     array = preprocessor(dataset)  # yields dicts, so materialize before building the table
106 |     pq.write_table(pa.Table.from_pylist(list(array)), f"data/{dataset_name}.parquet")
107 | 
108 | # %%
109 | 
--------------------------------------------------------------------------------
/test/test_streaming_dataset.py:
--------------------------------------------------------------------------------
1 | import os
2 | import sys
3 | from pathlib import Path
4 | import pytest
5 | from datasets import load_dataset
6 | from transformers import AutoTokenizer
7 | import pyarrow as pa
8 | from pyarrow import parquet as pq
9 | import tqdm
10 | import torch
11 | 
12 | # add the project root directory to the Python path
13 | sys.path.append(str(Path(__file__).parent.parent))
14 | 
15 | from src.ElectraKAN.datamodule import StreamingElectraPretrainingDataset
16 | from scripts.pretraining.data_processing import DatasetPreprocessor
17 | 
18 | 
19 | def test_dataset_download_and_preprocess():
20 |     """Test downloading and preprocessing the Wikipedia dataset."""
21 |     # 1. download the dataset
22 |     dataset = load_dataset("wikipedia", "20220301.en", split="train[:1%]")
23 |     assert len(dataset) > 0, "The dataset is empty."
24 | 
25 |     # 2. preprocess
26 |     tokenizer = AutoTokenizer.from_pretrained("google/electra-base-discriminator")
27 |     preprocessor = DatasetPreprocessor(tokenizer_path="google/electra-base-discriminator")
28 |     processed_data = preprocessor(dataset)
29 | 
30 |     # 3. save as Parquet (the preprocessor yields dicts, so materialize first)
31 |     output_path = "data/test_wikipedia.parquet"
32 |     table = pa.Table.from_pylist(list(processed_data))
33 |     pq.write_table(table, output_path)
34 |     assert os.path.exists(output_path), "The Parquet file was not created."
35 | 
36 | 
37 | def test_streaming_dataset_loading():
38 |     """Test loading the streaming dataset."""
39 |     # 1. create the dataset
40 |     tokenizer = AutoTokenizer.from_pretrained("google/electra-base-discriminator")
41 |     dataset = StreamingElectraPretrainingDataset(
42 |         path="data/test_wikipedia.parquet",
43 |         tokenizer=tokenizer,
44 |         max_length=512,
45 |         chunk_size=1000,
46 |         text_column="text"
47 |     )
48 | 
49 |     # 2. check the dataset size
50 |     assert len(dataset) > 0, "The dataset is empty."
51 | 
52 |     # 3. data loading test
53 |     for i in tqdm.tqdm(range(min(100, len(dataset))), desc="data loading test"):
54 |         # load a sample
55 |         masked_input_ids, attention_mask, token_type_ids, input_ids = dataset[i]
56 | 
57 |         # check tensor shapes
58 |         assert masked_input_ids.shape == (512,), f"masked_input_ids shape error at index {i}"
59 |         assert attention_mask.shape == (512,), f"attention_mask shape error at index {i}"
60 |         assert token_type_ids.shape == (512,), f"token_type_ids shape error at index {i}"
61 |         assert input_ids.shape == (512,), f"input_ids shape error at index {i}"
62 | 
63 |         # check dtypes
64 |         assert masked_input_ids.dtype == torch.long, f"masked_input_ids dtype error at index {i}"
65 |         assert attention_mask.dtype == torch.long, f"attention_mask dtype error at index {i}"
66 |         assert token_type_ids.dtype == torch.long, f"token_type_ids dtype error at index {i}"
67 |         assert input_ids.dtype == torch.long, f"input_ids dtype error at index {i}"
68 | 
69 | 
70 | def test_chunk_loading():
71 |     """Test chunk loading."""
72 |     tokenizer = AutoTokenizer.from_pretrained("google/electra-base-discriminator")
73 |     dataset = StreamingElectraPretrainingDataset(
74 |         path="data/test_wikipedia.parquet",
75 |         tokenizer=tokenizer,
76 |         max_length=512,
77 |         chunk_size=1000,
78 |         text_column="text"
79 |     )
80 | 
81 |     # load the first chunk
82 |     first_chunk = dataset.current_chunk
83 |     assert len(first_chunk) > 0, "The first chunk is empty."
84 | 
85 |     # load the second chunk
86 |     dataset._load_next_chunk()
87 |     second_chunk = dataset.current_chunk
88 |     assert len(second_chunk) > 0, "The second chunk is empty."
89 | 
90 |     # verify that the chunks differ
91 |     assert first_chunk != second_chunk, "The chunks are duplicated."
92 | 
93 | 
94 | if __name__ == "__main__":
95 |     # run the tests
96 |     test_dataset_download_and_preprocess()
97 |     test_streaming_dataset_loading()
98 |     test_chunk_loading()
99 |     print("All tests completed successfully!")
--------------------------------------------------------------------------------
/notebooks/data_processing.py:
--------------------------------------------------------------------------------
1 | import marimo
2 | 
3 | __generated_with = "0.11.29"
4 | app = marimo.App(width="medium")
5 | 
6 | 
7 | @app.cell
8 | def _():
9 |     import polars as pl
10 |     from datasets import load_dataset
11 |     return load_dataset, pl
12 | 
13 | 
14 | @app.cell
15 | def _(load_dataset):
16 |     datasets_us = load_dataset("wikipedia", "20220301.en", split="train[:1%]")
17 |     return (datasets_us,)
18 | 
19 | 
20 | @app.cell
21 | def _(datasets_us):
22 |     datasets_us[0]
23 |     return
24 | 
25 | 
26 | @app.cell
27 | def _(datasets_us):
28 |     from transformers import AutoTokenizer
29 | 
30 |     text_sample = datasets_us[0]["text"]
31 |     tokenizer = AutoTokenizer.from_pretrained("google/electra-base-discriminator")
32 |     tokenized_text = tokenizer(datasets_us[0]["text"])
33 | 
34 |     print(len(tokenized_text.input_ids))
35 |     return AutoTokenizer, text_sample, tokenized_text, tokenizer
36 | 
37 | 
38 | @app.cell
39 | def _():
40 |     def is_title(word: str, titles: list[str]) -> bool:
41 |         """Check whether a word is an honorific title."""
42 |         return word.lower().rstrip('.') in titles
43 | 
44 | 
45 |     def is_sentence_boundary(char: str, next_char: str, current_sentence: str, titles: list[str]) -> bool:
46 |         """Decide whether this position is a sentence boundary."""
47 |         if char != '.' or next_char not in [' ', '\n']:
48 |             return False
49 | 
50 |         words = current_sentence.strip().split()
51 |         if not words:
52 |             return False
53 | 
54 |         return not is_title(words[-1], titles)
55 | 
56 | 
57 |     def split_sentences(text: str) -> list[str]:
58 |         """Split text into sentences."""
59 |         titles = ['mr', 'mrs', 'ms', 'miss', 'dr', 'prof', 'rev', 'hon']
60 | 
61 |         result = []
62 |         current = ""
63 | 
64 |         for i, char in enumerate(text):
65 |             current += char
66 | 
67 |             if i + 1 < len(text) and is_sentence_boundary(char, text[i + 1], current, titles):
68 |                 result.append(current.strip())
69 |                 current = ""
70 | 
71 |         if current.strip():
72 |             result.append(current.strip())
73 | 
74 |         return result
75 | 
76 |     # Test
77 |     text = "안녕하세요. Mr. Smith는 의사입니다.\n그의 메일은 mr.smith@gmail.com 입니다. Dr. Lee와 Prof. Kim이 왔어요. 그리고 Mrs. Park도 왔습니다."
78 |     sentences = split_sentences(text)
79 |     for sentence in sentences:
80 |         print(sentence)
81 |     return (
82 |         is_sentence_boundary,
83 |         is_title,
84 |         sentence,
85 |         sentences,
86 |         split_sentences,
87 |         text,
88 |     )
89 | 
90 | 
91 | @app.cell
92 | def _(split_sentences, text_sample, tokenizer):
93 |     import tqdm
94 |     from line_profiler import LineProfiler
95 | 
96 |     profiler = LineProfiler()
97 | 
98 |     splited_sentences = split_sentences(text_sample)
99 | 
100 |     def sample_func():
101 |         total_texts = [tokenizer(i, max_length=512, return_attention_mask=True, return_token_type_ids=True, truncation=True, padding="max_length").input_ids for i in splited_sentences]
102 |         return total_texts
103 | 
104 |     wrapped = profiler(sample_func)
105 |     result = wrapped()
106 |     profiler.print_stats()
107 |     return (
108 |         LineProfiler,
109 |         profiler,
110 |         result,
111 |         sample_func,
112 |         splited_sentences,
113 |         tqdm,
114 |         wrapped,
115 |     )
116 | 
117 | 
118 | @app.cell
119 | def _(result):
120 |     result
121 |     return
122 | 
123 | 
124 | @app.cell
125 | def _(datasets_us):
126 |     datasets_us
127 |     return
128 | 
129 | 
130 | @app.cell
131 | def _():
132 |     import sys
133 |     sys.path.append("/workspaces/kanelectra")
134 |     from pipeline.modules import PretrainingDataset
135 | 
136 | 
137 | 
138 |     return PretrainingDataset, sys
139 | 
140 | 
141 | if __name__ == "__main__":
142 |     app.run()
143 | 
--------------------------------------------------------------------------------
/scripts/pretraining/train.py:
--------------------------------------------------------------------------------
1 | import hydra
2 | import orjson
3 | from transformers import AutoTokenizer
4 | from omegaconf import DictConfig
5 | from typing import List
6 | import lightning.pytorch as pl
7 | from pyarrow import parquet as pq
8 | 
9 | from ElectraKAN.datamodule import ElectraKANDataModule, StreamingElectraPretrainingDataset
10 | from ElectraKAN.handlers import ElectraModel
11 | from ElectraKAN import callbacks
12 | 
13 | from data_processing import DatasetDownloader, DatasetPreprocessor
14 | 
15 | 
16 | @hydra.main(config_path="../../configs", config_name="train")
17 | def main(cfg: DictConfig) -> None:
18 |     datamodule = get_dataloader(
19 |         tokenizer_path=cfg.tokenizer_name,
20 |         dataset_config=cfg.datasets,
21 |         datamodule_config=cfg.datamodule
22 |     )
23 |     model = ElectraModel(cfg.nn)
24 |     callback_lst = get_callbacks(cfg.trainer.callbacks)
25 |     del cfg.trainer.callbacks
26 |     trainer = pl.Trainer(**cfg.trainer, callbacks=callback_lst)
27 |     trainer.fit(model, datamodule=datamodule)
28 |     trainer.test(model, datamodule=datamodule)
29 | 
30 | def get_callbacks(
31 |     callbacks_config: DictConfig
32 | ) -> List[pl.Callback]:
33 |     callback_lst: List[pl.Callback] = []
34 |     for config in callbacks_config:
35 |         try:
36 |             callback = getattr(pl.callbacks, config.name)(**config.params)
37 |         except AttributeError:
38 |             # fall back to the project's own callbacks (e.g. OnnxCompiler)
39 |             callback = getattr(callbacks, config.name)(**config.params)
40 |         callback_lst.append(callback)
41 |     return callback_lst
42 | 
43 | 
44 | def get_dataloader(
45 |     tokenizer_path: str,
46 |     dataset_config: DictConfig,
47 |     datamodule_config: DictConfig
48 | ) -> ElectraKANDataModule:
49 |     tokenizer = AutoTokenizer.from_pretrained(tokenizer_path)
50 | 
51 |     # choose the dataset factory method based on the file extension
52 |     if str(dataset_config.train.path).endswith('.parquet'):
53 |         train_dataset = StreamingElectraPretrainingDataset.from_parquet(
54 |             path=dataset_config.train.path,
55 |             tokenizer=tokenizer,
56 |             max_length=dataset_config.train.max_length,
57 |             text_column="text",
58 |             chunk_size=datamodule_config.chunk_size
59 |         )
60 |         val_dataset = StreamingElectraPretrainingDataset.from_parquet(
61 |             path=dataset_config.val.path,
62 |             tokenizer=tokenizer,
63 |             max_length=dataset_config.val.max_length,
64 |             text_column="text",
65 |             chunk_size=datamodule_config.chunk_size
66 |         )
67 |         test_dataset = StreamingElectraPretrainingDataset.from_parquet(
68 |             path=dataset_config.test.path,
69 |             tokenizer=tokenizer,
70 |             max_length=dataset_config.test.max_length,
71 |             text_column="text",
72 |             chunk_size=datamodule_config.chunk_size
73 |         )
74 |     else:
75 |         train_dataset = StreamingElectraPretrainingDataset.from_csv(
76 |             path=dataset_config.train.path,
77 |             tokenizer=tokenizer,
78 |             max_length=dataset_config.train.max_length,
79 |             text_row=dataset_config.train.text_row,
80 |             chunk_size=datamodule_config.chunk_size
81 |         )
82 |         val_dataset = StreamingElectraPretrainingDataset.from_csv(
83 |             path=dataset_config.val.path,
84 |             tokenizer=tokenizer,
85 |             max_length=dataset_config.val.max_length,
86 |             text_row=dataset_config.val.text_row,
87 |             chunk_size=datamodule_config.chunk_size
88 |         )
89 |         test_dataset = StreamingElectraPretrainingDataset.from_csv(
90 |             path=dataset_config.test.path,
91 |             tokenizer=tokenizer,
92 |             max_length=dataset_config.test.max_length,
93 |             text_row=dataset_config.test.text_row,
94 |             chunk_size=datamodule_config.chunk_size
95 |         )
96 | 
97 |     datamodule = ElectraKANDataModule(
98 |         train_dataset=train_dataset,
99 |         val_dataset=val_dataset,
100 |         test_dataset=test_dataset,
101 |         batch_size=datamodule_config.batch_size,
102 |         num_workers=datamodule_config.num_workers,
103 |         pin_memory=datamodule_config.pin_memory
104 |     )
105 |     return datamodule
106 | 
107 | 
108 | def preprocess_datasets(
109 |     dataset_config: DictConfig
110 | ) -> None:
111 |     downloader = DatasetDownloader(dataset_config)
112 |     preprocessor = DatasetPreprocessor(tokenizer_path=dataset_config.tokenizer_name)
113 |     datasets = downloader()
114 |     for dataset_name, dataset in datasets.items():
115 |         array = preprocessor(dataset)
116 |         with open(f"data/{dataset_name}.jsonl", "w") as f:
117 |             for item in array:
118 |                 f.write(orjson.dumps(item).decode("utf-8") + "\n")
119 | 
120 | 
121 | if __name__ == '__main__':
122 |     main()
123 | 
--------------------------------------------------------------------------------
/scripts/pretraining/data_processing.py:
--------------------------------------------------------------------------------
1 | import tqdm
2 | import asyncio
3 | import json
4 | from typing import Dict, Any, List, Generator
5 | from omegaconf import DictConfig
6 | from datasets import load_dataset, Dataset
7 | from transformers import AutoTokenizer
8 | 
9 | 
10 | class DatasetDownloader:
11 |     def __init__(self, dataset_config: DictConfig):
12 |         self.dataset_config = dataset_config
13 | 
14 |     async def download_dataset(self, dataset_name: str, dataset_version: str, dataset_split: str) -> Dataset:
15 |         return load_dataset(dataset_name, dataset_version, split=dataset_split)
16 | 
17 |     async def download_datasets(self) -> Dict[str, Dataset]:
18 |         tasks = [
19 |             self.download_dataset(
20 |                 dataset_name=self.dataset_config.raw_data.dataset_name,
21 |                 dataset_version=self.dataset_config.raw_data.dataset_version,
22 |                 dataset_split=self.dataset_config.raw_data.dataset_split.train
23 |             ),
24 |             self.download_dataset(
25 |                 dataset_name=self.dataset_config.raw_data.dataset_name,
26 |                 dataset_version=self.dataset_config.raw_data.dataset_version,
27 |                 dataset_split=self.dataset_config.raw_data.dataset_split.val
28 |             ),
29 |             self.download_dataset(
30 |                 dataset_name=self.dataset_config.raw_data.dataset_name,
31 |                 dataset_version=self.dataset_config.raw_data.dataset_version,
32 |                 dataset_split=self.dataset_config.raw_data.dataset_split.test
33 |             )
34 |         ]
35 | 
36 |         result = await asyncio.gather(*tasks)
37 |         return {
38 |             "train": result[0],
39 |             "val": result[1],
40 |             "test": result[2]
41 |         }
42 | 
43 |     def __call__(self) -> Dict[str, Dataset]:
44 |         return asyncio.run(self.download_datasets())
45 | 
46 | 
47 | class DatasetPreprocessor:
48 |     def __init__(self, tokenizer_path: str):
49 |         self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_path)
50 | 
51 |     def is_title(self, word: str, titles: list[str]) -> bool:
52 |         """Check whether a word is an honorific title."""
53 |         return word.lower().rstrip('.') in titles
54 | 
55 |     def is_sentence_boundary(self, char: str, next_char: str, current_sentence: str, titles: list[str]) -> bool:
56 |         """Decide whether this position is a sentence boundary."""
57 |         if char != '.' or next_char not in [' ', '\n']:
58 |             return False
59 | 
60 |         words = current_sentence.strip().split()
61 |         if not words:
62 |             return False
63 | 
64 |         return not self.is_title(words[-1], titles)
65 | 
66 | 
67 |     def split_sentences(self, text: str) -> list[str]:
68 |         """Split text into sentences."""
69 |         titles = ['mr', 'mrs', 'ms', 'miss', 'dr', 'prof', 'rev', 'hon', 'st']
70 | 
71 |         result = []
72 |         current = ""
73 | 
74 |         for i, char in enumerate(text):
75 |             current += char
76 | 
77 |             if i + 1 < len(text) and self.is_sentence_boundary(char, text[i + 1], current, titles):
78 |                 result.append(current.strip())
79 |                 current = ""
80 | 
81 |         if current.strip():
82 |             result.append(current.strip())
83 | 
84 |         return [s.replace("\n", " ") for s in result]
85 | 
86 |     def __call__(self, texts: list[Dict[str, Any]]) -> Generator[Dict[str, Any], None, None]:
87 |         """Process texts as a stream to keep memory usage low."""
88 |         for text in tqdm.tqdm(texts["text"]):
89 |             sentences = self.split_sentences(text)
90 |             for sentence in sentences:
91 |                 yield {"text": sentence}
92 | 
93 | 
94 | if __name__ == "__main__":
95 |     downloader = DatasetDownloader(dataset_config=DictConfig({
96 |         "raw_data": {
97 |             "dataset_name": "wikipedia",
98 |             "dataset_version": "20220301.en",
99 |             "dataset_split": {
100 |                 "train": "train[:1%]",
101 |                 "val": "train[1%:2%]",
102 |                 "test": "train[:1%]"
103 |             }
104 |         }
105 |     }))
106 |     preprocessor = DatasetPreprocessor(tokenizer_path="google/electra-base-discriminator")
107 | 
108 |     # process as a stream and write the results to JSONL files
109 |     total_count = 0
110 |     datasets = downloader()
111 | 
112 |     for dataset_name, dataset in tqdm.tqdm(datasets.items(), desc="Processing datasets"):
113 |         output_file = f"processed_{dataset_name}.jsonl"
114 |         with open(output_file, 'w', encoding='utf-8') as f:
115 |             for processed_item in preprocessor(dataset):
116 |                 f.write(json.dumps(processed_item, ensure_ascii=False) + '\n')
117 |                 total_count += 1
118 |                 # flush periodically to further reduce memory usage
119 |                 if total_count % 1000 == 0:
120 |                     f.flush()
121 | 
122 |     print(f"Total processed items: {total_count}")
123 |     print("Processed data saved to JSONL files")
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # ElectraModel based on Kolmogorov–Arnold Network
2 | 
3 | ## Introduction
4 | 
5 | Recently, Kolmogorov–Arnold Networks (KANs) have been introduced as a replacement for Fully Connected Layers.
6 | According to the authors of the paper, KANs have a significant advantage in terms of speed and performance compared to traditional FC layers.
7 | 
8 | They are currently being actively applied in vision-based models and were recently applied to Transformer decoder models as well.
9 | It has been reported that this has resulted in a substantial improvement in both performance and throughput.
10 | 
11 | Despite the verification of KANs' performance in various areas, I have not yet found examples of their application to the Transformer encoder.
12 | In this repository, we aim to verify whether KANs can be applied to Transformer encoder models and to evaluate their performance.
13 | 
14 | ## Model Download link
15 | 
16 | Once the model training is complete, a download link will be provided on Hugging Face.
17 | 
18 | ## About KAN Electra
19 | 
20 | ![KAN Electra Model](static/electra.png)
21 | 
22 | KAN Electra replaces the fully connected layers in a typical Transformer Encoder model with Kolmogorov–Arnold Networks (KANs). This modification aims to leverage the speed and performance benefits of KANs to enhance the efficiency and effectiveness of the Transformer Encoder.
23 | 
24 | In the Encoder model, attention is implemented as self-attention using scaled dot-product attention.
25 | 
26 | Self-attention allows the model to weigh the importance of different words in a sentence relative to each other, regardless of their position. This is particularly useful for capturing long-range dependencies in the input sequence.
27 | 
28 | Scaled dot-product attention works as follows:
29 | 1. **Query, Key, and Value Matrices**: The input embeddings are transformed into three matrices: Query (Q), Key (K), and Value (V).
30 | 2. **Dot Product**: The Query matrix is multiplied by the Key matrix to obtain a score matrix. This score indicates the relevance of each word in the sequence to every other word.
31 | 3. **Scaling**: The scores are scaled down by the square root of the dimension of the Key matrix to prevent the gradients from becoming too small during backpropagation.
32 | 4. **Softmax**: The scaled scores are passed through a softmax function to obtain attention weights. These weights determine the importance of each word in the context of the current word.
33 | 5. **Weighted Sum**: The attention weights are used to compute a weighted sum of the Value matrix, producing the final output of the attention mechanism.
34 | 
35 | This mechanism allows the model to focus on relevant parts of the input sequence, enhancing its ability to understand and generate complex patterns in the data. In compact form: Attention(Q, K, V) = softmax(QKᵀ / √d_k) · V.
36 | 
37 | ### Hyperparameters
38 | 
39 | ### Vocab
40 | 
41 | I use Google's `google/electra-small-discriminator` vocabulary for the English model and `monologg/koelectra-base-v3-discriminator` for the Korean model when pretraining KANElectra; a minimal loading sketch follows.
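
A minimal sketch, assuming only the `transformers` dependency from pyproject.toml (the model name is interchangeable with the Korean one):

```python
from transformers import AutoTokenizer

# English vocab; use "monologg/koelectra-base-v3-discriminator" for Korean
tokenizer = AutoTokenizer.from_pretrained("google/electra-small-discriminator")
encoded = tokenizer("ELECTRA trains a discriminator.", max_length=512, truncation=True)
print(tokenizer.vocab_size, encoded.input_ids[:5])
```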
42 | 
43 | ### Data
44 | 
45 | ## Requirements
46 | 
47 | ```text
48 | python 3.12 or higher
49 | torch >= 2.5.1
50 | torch-tensorRT
51 | hydra
52 | lightning
53 | transformers
54 | ```
55 | 
56 | ## CLI
57 | 
58 | ### Train
59 | 
60 | ElectraKAN uses Hydra from Meta to simplify configuration and make parameters easy to override from the CLI (e.g. `python scripts/pretraining/train.py trainer.max_epochs=5`). You can train ElectraKAN with your own code using Hydra's syntax; the sections below show several ways to launch a pretraining run with the generator and discriminator.
61 | 
62 | 
63 | #### Train using Helm Charts
64 | 
65 | ```shell
66 | # Run training with default configuration
67 | helm install electra-training ./helm/electra-training
68 | 
69 | # Run training with specific image tag
70 | helm install electra-training ./helm/electra-training --set image.tag=v1.0.0
71 | 
72 | # Run training with custom configuration
73 | helm install electra-training ./helm/electra-training -f custom-values.yaml
74 | 
75 | # Uninstall training job
76 | helm uninstall electra-training
77 | ```
78 | 
79 | The Helm chart uses the following resources:
80 | - CPU: 8 vCPU
81 | - GPU: 4 GPUs
82 | - Memory: 8Gi (request) / 16Gi (limit)
83 | - Node: Scheduled on nodes with label env=batch
84 | 
85 | #### Docker
86 | 
87 | ```shell
88 | git clone https://github.com/Klassikcat/KANElectra
89 | docker buildx build -t kanelectra:latest -f nvidia.Dockerfile .
90 | docker run --gpus all kanelectra:latest
91 | ```
92 | 
93 | #### On Local Machine
94 | 
95 | ```shell
96 | python scripts/pretraining/train.py
97 | ```
98 | 
99 | ### Test (TODO)
100 | 
101 | ```shell
102 | python scripts/pretraining/test.py
103 | ```
104 | 
105 | ## TensorRT Convert
106 | 
107 | ```shell
108 | trtexec \
109 |     --onnx=${your_onnx_engine_path} \
110 |     --saveEngine=${engine_save_path} \
111 |     --minShapes=input_ids:1x512,attention_mask:1x512,token_type_ids:1x512 \
112 |     --optShapes=input_ids:${opt_batch_size}x512,attention_mask:${opt_batch_size}x512,token_type_ids:${opt_batch_size}x512 \
113 |     --maxShapes=input_ids:${max_batch_size}x512,attention_mask:${max_batch_size}x512,token_type_ids:${max_batch_size}x512
114 | ```
115 | 
116 | ## Install Package
117 | 
118 | ```shell
119 | # Install from source
120 | pip install -e .
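# (the editable install exposes the ElectraKAN package from src/, per the
#  packages mapping in pyproject.toml)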
121 | 122 | # Install with development dependencies 123 | pip install -e ".[dev]" 124 | ``` 125 | 126 | After installation, you can use ElectraKAN in your Python code: 127 | 128 | ```python 129 | from ElectraKAN import ElectraModel 130 | from ElectraKAN.datamodule import ElectraKANDataModule 131 | ``` 132 | -------------------------------------------------------------------------------- /src/ElectraKAN/handlers.py: -------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | from typing import Optional, Tuple, Any, Dict, List 3 | import torch 4 | from torch import ( 5 | Tensor, 6 | LongTensor, 7 | FloatTensor 8 | ) 9 | import torch.nn.functional as F 10 | import lightning.pytorch as pl 11 | from torchmetrics.classification import Accuracy, Precision, Recall, FBetaScore 12 | from omegaconf import DictConfig 13 | from .modules import ( 14 | ElectraGenerator, 15 | ElectraDiscriminator 16 | ) 17 | 18 | 19 | class ElectraModel(pl.LightningModule): 20 | def __init__(self, config: DictConfig): 21 | super(ElectraModel, self).__init__() 22 | self.generator = ElectraGenerator(**config.generator) 23 | self.discriminator = ElectraDiscriminator(**config.discriminator) 24 | self.config = config 25 | self._automatic_optimization = False 26 | 27 | self.accuracy = Accuracy(task='multiclass', num_classes=2) 28 | self.precision = Precision(task='multiclass', num_classes=2) 29 | self.recall = Recall(task='multiclass', num_classes=2) 30 | self.f1 = FBetaScore(task='multiclass', num_classes=2, beta=1.0) 31 | 32 | def locate_mask(self, input_ids: Tensor) -> Tensor: 33 | return input_ids == self.config.mask_token_id 34 | 35 | 36 | def forward(self, input_ids: LongTensor, attention_mask: LongTensor, token_type_ids: LongTensor) -> Tuple[Tensor, Tensor]: 37 | generator_logits = self.generator(input_ids, attention_mask, token_type_ids) 38 | # Sample new tokens based on generator logits 39 | sampled_ids = torch.argmax(generator_logits, dim=-1) # Greedy sampling 40 | 41 | # Replace masked tokens in input_ids with sampled_ids 42 | input_ids_with_generated = input_ids.clone() 43 | mask_token_indices = input_ids == self.config.mask_token_id 44 | input_ids_with_generated[mask_token_indices] = sampled_ids[mask_token_indices] 45 | 46 | # Discriminator processes the replaced input_ids 47 | discriminator_logits = self.discriminator(input_ids_with_generated, attention_mask, token_type_ids) 48 | return generator_logits, discriminator_logits 49 | 50 | def training_step(self, batch: Tuple[LongTensor, LongTensor, LongTensor], batch_idx: int) -> Tensor: 51 | masked_input_ids, attention_mask, token_type_ids, input_ids = batch 52 | generator_logits, discriminator_logits = self(masked_input_ids, attention_mask, token_type_ids) 53 | 54 | generator_optimizer, discriminator_optimizer = self.optimizers() 55 | 56 | mask_token_indices = self.locate_mask(masked_input_ids) 57 | generator_loss = self.generator_loss(generator_logits, input_ids, mask_token_indices) 58 | self.log('train_generator_loss', generator_loss, on_step=True, on_epoch=True, prog_bar=True) 59 | generator_loss.backward() 60 | generator_optimizer.step() 61 | generator_optimizer.zero_grad() 62 | 63 | discriminator_token_ids = torch.argmax(discriminator_logits, dim=-1) 64 | generated_labels = self.create_discriminator_labels(input_ids, discriminator_token_ids) 65 | 66 | discriminator_loss = self.discriminator_loss(discriminator_logits, generated_labels) 67 | self.log('train_discriminator_loss', discriminator_loss, 
on_step=True, on_epoch=True, prog_bar=True) 68 | discriminator_loss.backward() 69 | discriminator_optimizer.step() 70 | discriminator_optimizer.zero_grad() 71 | 72 | preds = torch.argmax(discriminator_logits, dim=-1) 73 | self.log('train_accuracy', self.accuracy(preds, generated_labels), on_epoch=True) 74 | self.log('train_precision', self.precision(preds, generated_labels), on_epoch=True) 75 | self.log('train_recall', self.recall(preds, generated_labels), on_epoch=True) 76 | self.log('train_f1', self.f1(preds, generated_labels), on_epoch=True, prog_bar=True, on_step=True) 77 | 78 | @torch.no_grad() 79 | def validation_step(self, batch: Tuple[LongTensor, LongTensor, LongTensor], batch_idx: int) -> None: 80 | masked_input_ids, attention_mask, token_type_ids, input_ids = batch 81 | generator_logits, discriminator_logits = self(masked_input_ids, attention_mask, token_type_ids) 82 | 83 | mask_token_indices = self.locate_mask(masked_input_ids) 84 | generator_loss = self.generator_loss(generator_logits, input_ids, mask_token_indices) 85 | 86 | discriminator_token_ids = torch.argmax(discriminator_logits, dim=-1) 87 | generated_labels = self.create_discriminator_labels(input_ids, discriminator_token_ids) 88 | 89 | discriminator_loss = self.discriminator_loss(discriminator_logits, generated_labels) 90 | 91 | self.log('val_generator_loss', generator_loss, on_step=True, on_epoch=True, prog_bar=True) 92 | self.log('val_discriminator_loss', discriminator_loss, on_step=True, on_epoch=True, prog_bar=True) 93 | 94 | preds = torch.argmax(discriminator_logits, dim=-1) 95 | self.log('val_accuracy', self.accuracy(preds, generated_labels), on_epoch=True) 96 | self.log('val_precision', self.precision(preds, generated_labels), on_epoch=True) 97 | self.log('val_recall', self.recall(preds, generated_labels), on_epoch=True) 98 | self.log('val_f1', self.f1(preds, generated_labels), on_epoch=True, prog_bar=True) 99 | 100 | def configure_optimizers(self) -> Tuple[List[torch.optim.Optimizer], List[Any]]: 101 | generator_optimizer = torch.optim.Adam(self.generator.parameters(), lr=self.config.generator_lr) 102 | discriminator_optimizer = torch.optim.Adam(self.discriminator.parameters(), lr=self.config.discriminator_lr) 103 | return [generator_optimizer, discriminator_optimizer], [] 104 | 105 | def generator_loss(self, generator_logits: LongTensor, input_ids: LongTensor, mask_token_indices: Tensor) -> Tensor: 106 | # Only calculate loss for masked tokens 107 | logits_for_masked = generator_logits[mask_token_indices] 108 | labels_for_masked = input_ids[mask_token_indices] 109 | return F.cross_entropy(logits_for_masked, labels_for_masked) 110 | 111 | def discriminator_loss(self, discriminator_logits: LongTensor, generated_labels: Tensor|LongTensor) -> Tensor: 112 | return F.cross_entropy(discriminator_logits.view(-1, 2), generated_labels.view(-1)) 113 | 114 | def on_save_checkpoint(self, checkpoint: Dict[str, Any]) -> None: 115 | checkpoint['generator_state_dict'] = self.generator.state_dict() 116 | checkpoint['discriminator_state_dict'] = self.discriminator.state_dict() 117 | 118 | def on_load_checkpoint(self, checkpoint: Dict[str, Any]) -> None: 119 | self.generator.load_state_dict(checkpoint['generator_state_dict']) 120 | self.discriminator.load_state_dict(checkpoint['discriminator_state_dict']) 121 | 122 | def create_discriminator_labels(self, input_ids: LongTensor, generated_logits: Tensor|LongTensor) -> torch.Tensor: 123 | # Compare original and generated logits 124 | labels = (input_ids != generated_logits).long() 
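        # 1 marks positions where the sampled token differs from the original
        # (a "replaced" token), 0 marks unchanged ones - the replaced-token-
        # detection target that the ELECTRA discriminator is trained on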
125 | return labels 126 | -------------------------------------------------------------------------------- /src/ElectraKAN/callbacks.py: -------------------------------------------------------------------------------- 1 | import os 2 | from pathlib import Path 3 | import torch 4 | from torch import Tensor 5 | import lightning.pytorch as pl 6 | from lightning.fabric.utilities.types import _PATH 7 | from typing import Optional, Literal, Dict, Tuple 8 | from datetime import timedelta 9 | 10 | 11 | class PretrainingCheckpoint(pl.callbacks.ModelCheckpoint): 12 | def __init__(self) -> None: 13 | super().__init__() # TODO: Add a checkpointer for generator and discriminator 14 | 15 | 16 | class OnnxCompiler(pl.callbacks.ModelCheckpoint): 17 | def __init__( 18 | self, 19 | dirpath: Optional[_PATH] = None, 20 | filename: Optional[str] = None, 21 | monitor: Optional[str] = None, 22 | verbose: bool = False, 23 | save_last: Optional[Literal[True, False, "link"]] = None, 24 | save_top_k: int = 1, 25 | save_weights_only: bool = False, 26 | mode: str = "min", 27 | auto_insert_metric_name: bool = True, 28 | every_n_train_steps: Optional[int] = None, 29 | train_time_interval: Optional[timedelta] = None, 30 | every_n_epochs: Optional[int] = None, 31 | save_on_train_epoch_end: Optional[bool] = None, 32 | enable_version_counter: bool = True, 33 | save_onnx: bool = True, 34 | save_tensorrt_engine: bool = False, 35 | min_shape: Tuple[int, int] = (1, 512), 36 | opt_shape: Tuple[int, int] = (16, 512), 37 | max_shape: Tuple[int, int] = (32, 512), 38 | save_mixed_precision: bool = True, 39 | ): 40 | super().__init__( 41 | dirpath=dirpath, 42 | filename=filename, 43 | monitor=monitor, 44 | verbose=verbose, 45 | save_last=save_last, 46 | save_top_k=save_top_k, 47 | save_weights_only=save_weights_only, 48 | mode=mode, 49 | auto_insert_metric_name=auto_insert_metric_name, 50 | every_n_train_steps=every_n_train_steps, 51 | train_time_interval=train_time_interval, 52 | every_n_epochs=every_n_epochs, 53 | save_on_train_epoch_end=save_on_train_epoch_end, 54 | enable_version_counter=enable_version_counter, 55 | ) 56 | self.min_shape = min_shape 57 | self.opt_shape = opt_shape 58 | self.max_shape = max_shape 59 | self.mixed_precision = save_mixed_precision 60 | self.save_dir = Path(dirpath) / filename # TODO: find out optimial name 61 | if not save_onnx: 62 | assert ( 63 | save_tensorrt_engine 64 | ), "save_tensorrt_engine must be True if save_onnx is False. if you don't want to save neither .onnx and .engine, you should use ModelCheckpoint instead." 
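        # Export flow: on the configured epochs the model is serialized to ONNX
        # and, when save_tensorrt_engine is set, additionally compiled into a
        # TensorRT engine using the min/opt/max optimization profile above.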
65 |         self.save_onnx = save_onnx
66 |         if save_tensorrt_engine:
67 |             import tensorrt as trt
68 | 
69 |             self.trt_logger = trt.Logger(min_severity=trt.Logger.WARNING)
70 |             self.builder = trt.Builder(self.trt_logger)
71 |             self.network = self.builder.create_network(
72 |                 1 << int(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH)
73 |             )
74 |         else:
75 |             self.builder = False
76 | 
77 |     def on_train_epoch_end(self, trainer: pl.Trainer, pl_module: pl.LightningModule):
78 |         super().on_train_epoch_end(trainer, pl_module)
79 |         if not self._should_skip_saving_checkpoint(
80 |             trainer
81 |         ) and not self._should_save_on_train_epoch_end(trainer):
82 |             monitor_candidates = self._monitor_candidates(trainer)
83 |             if (
84 |                 self._every_n_epochs >= 1
85 |                 and (trainer.current_epoch + 1) % self._every_n_epochs == 0
86 |             ):
87 |                 self._save_into_onnx(trainer, monitor_candidates)
88 |                 if self.builder:
89 |                     self._save_tensorRT_engine(trainer, monitor_candidates)
90 | 
91 | 
92 | 
93 | 
94 |     def on_validation_epoch_end(
95 |         self, trainer: pl.Trainer, pl_module: pl.LightningModule
96 |     ):
97 |         super().on_validation_epoch_end(trainer, pl_module)
98 |         if not self._should_skip_saving_checkpoint(
99 |             trainer
100 |         ) and not self._should_save_on_train_epoch_end(trainer):
101 |             monitor_candidates = self._monitor_candidates(trainer)
102 |             if (
103 |                 self._every_n_epochs >= 1
104 |                 and (trainer.current_epoch + 1) % self._every_n_epochs == 0
105 |             ):
106 |                 self._save_into_onnx(trainer, monitor_candidates)
107 |                 if self.builder:
108 |                     self._save_tensorRT_engine(trainer, monitor_candidates)
109 | 
110 | 
111 | 
112 | 
113 |     def _save_into_onnx(
114 |         self, trainer: pl.Trainer, monitor_candidates: Dict[str, Tensor]
115 |     ) -> None:
116 |         input_ids = torch.randint(0, 10000, (8, 512))
117 |         attention_masks = torch.randint(0, 2, (8, 512))  # random 0/1 dummy mask
118 |         token_type_ids = torch.randint(0, 2, (8, 512))
119 | 
120 |         trainer.model.eval()
121 |         with torch.no_grad():
122 |             _ = trainer.model(
123 |                 input_ids, attention_mask=attention_masks, token_type_ids=token_type_ids
124 |             )
125 | 
126 |         dynamic_axes = {
127 |             "input_ids": {0: "batch_size"},
128 |             "attention_mask": {0: "batch_size"},
129 |             "token_type_ids": {0: "batch_size"},
130 |             "outputs": {0: "batch_size"},
131 |         }
132 | 
133 |         torch.onnx.export(
134 |             trainer.model,  # the trained model instance
135 |             (input_ids, attention_masks, token_type_ids),  # input args
136 |             str(self.save_dir) + ".onnx",  # output path
137 |             export_params=True,  # store the trained parameters in the file
138 |             opset_version=17,  # ONNX opset version to use
139 |             do_constant_folding=True,
140 |             input_names=["input_ids", "attention_mask", "token_type_ids"],
141 |             output_names=["outputs"],
142 |             dynamic_axes=dynamic_axes,
143 |         )
144 | 
145 |     def _save_tensorRT_engine(
146 |         self, trainer: pl.Trainer, monitor_candidates: Dict[str, Tensor]
147 |     ) -> None:
148 |         import tensorrt as trt
149 | 
150 |         parser = trt.OnnxParser(self.network, self.trt_logger)
151 |         parser.parse_from_file(str(self.save_dir) + ".onnx")
152 |         config = self.builder.create_builder_config()
153 |         profile = self.builder.create_optimization_profile()
154 |         profile.set_shape(
155 |             "input_ids",
156 |             (self.min_shape[0], self.min_shape[1]),
157 |             (self.opt_shape[0], self.opt_shape[1]),
158 |             (self.max_shape[0],

--------------------------------------------------------------------------------
/src/ElectraKAN/modules.py:
--------------------------------------------------------------------------------
from typing import Optional
import math
import torch
from torch import (
    nn,
    Tensor,
    LongTensor
)
from .kan import KAN
import torch.nn.functional as F


class ElectraGenerator(nn.Module):
    def __init__(
        self,
        vocab_size: int,
        embedding_dim: int,
        vocab_type_size: int,
        embedding_dropout_p: float,
        hidden_dim: int,
        num_heads: int,
        ff_dim: int,
        num_layers: int,
        max_pos_embedding: int
    ):
        super().__init__()
        self.embedding = InputEmbedding(
            vocab_size,
            embedding_dim,
            vocab_type_size,
            embedding_dropout_p,
            max_pos_embedding
        )
        self.encoder = ElectraEncoder(
            hidden_dim,
            num_heads,
            num_layers,
            0.1,
            ff_dim
        )
        self.generator = GeneratorOutput(hidden_dim, vocab_size)

    def forward(
        self,
        input_ids: LongTensor,
        attention_mask: LongTensor,
        token_type_ids: LongTensor,
    ) -> Tensor:
        embeddings = self.embedding(input_ids, token_type_ids)
        seq_out = self.encoder(embeddings, attention_mask)
        # Only apply dropout while training, not at inference time.
        dropouted_seq_output = F.dropout(seq_out, p=0.1, training=self.training)
        return self.generator(dropouted_seq_output)


class GeneratorOutput(nn.Module):
    def __init__(self, hidden: int, vocab_size: int):
        super().__init__()
        self.linear = nn.Linear(hidden, vocab_size)
        self.softmax = nn.LogSoftmax(dim=-1)

    def forward(self, x: Tensor) -> Tensor:
        return self.softmax(self.linear(x))


class ElectraDiscriminator(nn.Module):
    def __init__(
        self,
        vocab_size: int,
        embedding_dim: int,
        vocab_type_size: int,
        embedding_dropout_p: float,
        hidden_dim: int,
        num_heads: int,
        ff_dim: int,
        num_layers: int,
        max_pos_embedding: int,
        num_labels: int
    ):
        super().__init__()
        self.embedding = InputEmbedding(
            vocab_size,
            embedding_dim,
            vocab_type_size,
            embedding_dropout_p,
            max_pos_embedding
        )
        self.encoder = ElectraEncoder(
            hidden_dim,
            num_heads,
            num_layers,
            0.1,
            ff_dim
        )
        self.classifier = KAN(width=[hidden_dim, num_labels])

    def forward(
        self,
        input_ids: LongTensor,
        attention_mask: LongTensor,
        token_type_ids: LongTensor,
    ) -> Tensor:
        embeddings = self.embedding(input_ids, token_type_ids)
        seq_out = self.encoder(embeddings, attention_mask)
        dropouted_seq_output = F.dropout(seq_out, p=0.1, training=self.training)
        return self.classifier(dropouted_seq_output)

class ElectraEncoder(nn.Module):
    def __init__(
        self,
        dim: int,
        num_heads: int,
        num_layers: int,
        dropout_p: float = 0.1,
        hidden_dim: Optional[int] = None,
    ):
        super().__init__()
        if not hidden_dim:
            hidden_dim = dim * 4  # the paper's default feed-forward size
        self.layers = nn.ModuleList([
            EncoderLayer(dim, num_heads, hidden_dim, dropout_p) for _ in range(num_layers)
        ])

    def forward(
        self,
        hidden_states: Tensor,
        mask: Tensor
    ) -> Tensor:
        for layer in self.layers:
            hidden_states = layer(hidden_states, mask)
        return hidden_states


class InputEmbedding(nn.Module):
    def __init__(
        self,
        vocab_size: int,
        embedding_dim: int,
        vocab_type_size: int,
        embedding_dropout_p: float,
        max_pos_embedding: int
    ):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, embedding_dim)
        self.positional_embedding = nn.Embedding(max_pos_embedding, embedding_dim)
        self.token_type_embedding = nn.Embedding(vocab_type_size, embedding_dim)
        self.dropout = nn.Dropout(embedding_dropout_p)

    def forward(
        self,
        input_ids: LongTensor,
        token_type_ids: LongTensor,
    ) -> Tensor:
        seq_length = input_ids.shape[1]
        position_ids = torch.arange(seq_length, dtype=torch.long, device=input_ids.device)
        position_ids = position_ids.unsqueeze(0).expand_as(input_ids)

        embeddings = (
            self.embedding(input_ids) +
            self.positional_embedding(position_ids) +
            self.token_type_embedding(token_type_ids)
        )
        return self.dropout(embeddings)


class ScaledDotProductAttention(nn.Module):
    def __init__(self, dropout_p: float):
        super().__init__()
        self.softmax = nn.Softmax(dim=-1)
        self.dropout = nn.Dropout(dropout_p)

    def forward(
        self,
        query: Tensor,
        key: Tensor,
        value: Tensor,
        attention_mask: Optional[LongTensor] = None
    ) -> Tensor:
        d_tensor = query.size(-1)
        attention_scores = torch.matmul(query, key.transpose(2, 3)) / math.sqrt(d_tensor)
        if attention_mask is not None:
            # (batch, length) -> (batch, 1, 1, length) so the mask broadcasts over heads.
            broadcast_attention_mask = attention_mask.unsqueeze(1).unsqueeze(2)
            attention_scores = attention_scores.masked_fill(broadcast_attention_mask == 0, -1e9)
        attention = self.dropout(self.softmax(attention_scores))
        return torch.matmul(attention, value)


class MultiHeadAttention(nn.Module):
    def __init__(
        self,
        dim: int,
        num_heads: int,
        dropout_p: float
    ):
        super().__init__()
        assert dim % num_heads == 0
        self.attention = ScaledDotProductAttention(dropout_p)
        self.dropout = nn.Dropout(dropout_p)
        self.fc_q = KAN(width=[dim, dim])
        self.fc_k = KAN(width=[dim, dim])
        self.fc_v = KAN(width=[dim, dim])
        self.fc_out = KAN(width=[dim, dim])
        self.num_heads = num_heads
        self.dim = dim

    def forward(
        self,
        query: Tensor,
        key: Tensor,
        value: Tensor,
        attention_mask: LongTensor
    ) -> Tensor:
        batch_size = query.size(0)
        d_tensor = self.dim // self.num_heads

        # split: reshape to (batch, length, heads, d_tensor) first and then swap
        # the head and length axes, so each head gets a contiguous feature slice.
        query = self.fc_q(query).view(batch_size, -1, self.num_heads, d_tensor).transpose(1, 2)
        key = self.fc_k(key).view(batch_size, -1, self.num_heads, d_tensor).transpose(1, 2)
        value = self.fc_v(value).view(batch_size, -1, self.num_heads, d_tensor).transpose(1, 2)
        attention_output = self.attention(query, key, value, attention_mask)
        # concat heads back: (batch, heads, length, d_tensor) -> (batch, length, dim)
        attention_output = attention_output.transpose(1, 2).contiguous().view(batch_size, -1, self.dim)
        output = self.fc_out(attention_output)
        return self.dropout(output)


class FeedForward(nn.Module):
    def __init__(
        self,
        dim: int,
        ff_dim: int,
        dropout_p: float
    ):
        super().__init__()
        self.fc1 = nn.Linear(dim, ff_dim)
        self.relu = nn.ReLU()
        self.fc2 = nn.Linear(ff_dim, dim)
        self.dropout = nn.Dropout(dropout_p)

    def forward(
        self,
        x: Tensor
    ) -> Tensor:
        x = self.fc1(x)
        x = self.relu(x)
        x = self.fc2(x)
        x = self.dropout(x)
        return x


class EncoderLayer(nn.Module):
    def __init__(
        self,
        dim: int,
        num_heads: int,
        hidden_dim: int,
        dropout_p: float
    ):
        super().__init__()
        self.attn = MultiHeadAttention(dim, num_heads, dropout_p)
        self.ff = FeedForward(dim, hidden_dim, dropout_p)
        self.norm1 = nn.LayerNorm(dim)
        self.norm2 = nn.LayerNorm(dim)
        self.dropout = nn.Dropout(dropout_p)

    def forward(
        self,
        x: Tensor,
        attention_mask: LongTensor
    ) -> Tensor:
        attention_output = self.attn(x, x, x, attention_mask)
        add_norm = self.norm1(x + attention_output)
        # Feed the normalized residual (not the raw attention output) into the FFN.
        output = self.ff(add_norm)
        ff_add_norm = self.norm2(add_norm + output)
        return self.dropout(ff_add_norm)
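
# --- Shape smoke test (illustrative sketch) ---
# Tiny hyperparameters chosen only to keep the example fast; they are not the
# training configuration. Note that embedding_dim must equal hidden_dim, since
# there is no projection between the embedding and the encoder.
if __name__ == "__main__":
    discriminator = ElectraDiscriminator(
        vocab_size=100,
        embedding_dim=32,
        vocab_type_size=2,
        embedding_dropout_p=0.1,
        hidden_dim=32,
        num_heads=4,
        ff_dim=64,
        num_layers=2,
        max_pos_embedding=16,
        num_labels=2,
    )
    input_ids = torch.randint(0, 100, (2, 16))
    attention_mask = torch.ones(2, 16, dtype=torch.long)
    token_type_ids = torch.zeros(2, 16, dtype=torch.long)
    logits = discriminator(input_ids, attention_mask, token_type_ids)
    print(logits.shape)  # expected: torch.Size([2, 16, 2]), one score vector per token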

--------------------------------------------------------------------------------
/src/ElectraKAN/datamodule.py:
--------------------------------------------------------------------------------
import csv
from os import PathLike
from typing import Tuple, Optional, Dict

import torch
from pyarrow import parquet as pq
from transformers import PreTrainedTokenizer, AutoTokenizer
from torch.utils.data import Dataset, DataLoader
import lightning.pytorch as ptl


class ElectraKANDataModule(ptl.LightningDataModule):
    def __init__(
        self,
        train_dataset: Dataset,
        val_dataset: Dataset,
        test_dataset: Dataset,
        batch_size: int,
        num_workers: int,
        pin_memory: bool
    ) -> None:
        super().__init__()
        self.train_dataset = train_dataset
        self.val_dataset = val_dataset
        self.test_dataset = test_dataset
        self.batch_size = batch_size
        self.num_workers = num_workers
        self.pin_memory = pin_memory

    def train_dataloader(self) -> DataLoader:
        # Note: shuffling issues random indices, so the chunked streaming
        # datasets below may reload their chunk cache often.
        return DataLoader(
            self.train_dataset,
            batch_size=self.batch_size,
            num_workers=self.num_workers,
            pin_memory=self.pin_memory,
            shuffle=True
        )

    def val_dataloader(self) -> DataLoader:
        return DataLoader(
            self.val_dataset,
            batch_size=self.batch_size,
            num_workers=self.num_workers,
            pin_memory=self.pin_memory,
            shuffle=False
        )

    def test_dataloader(self) -> DataLoader:
        return DataLoader(
            self.test_dataset,
            batch_size=self.batch_size,
            num_workers=self.num_workers,
            pin_memory=self.pin_memory,
            shuffle=False
        )

class StreamingElectraPretrainingDataset(Dataset):
    def __init__(
        self,
        path: PathLike | str,
        tokenizer: PreTrainedTokenizer,
        max_length: int = 512,
        chunk_size: int = 1000,
        text_column: str = "text"
    ) -> None:
        super().__init__()
        self.path = path
        self.tokenizer = tokenizer
        self.max_length = max_length
        self.chunk_size = chunk_size
        self.text_column = text_column
        self.current_chunk = []
        self.current_chunk_idx = 0
        self.total_samples = self._count_samples()

    def _count_samples(self) -> int:
        if str(self.path).endswith('.parquet'):
            return len(pq.read_table(self.path))
        else:
            with open(self.path, 'r') as f:
                return sum(1 for _ in f) - 1  # exclude the header row

    def _load_next_chunk(self):
        if str(self.path).endswith('.parquet'):
            # pq.read_table has no skip/take arguments; read the table and slice it.
            table = pq.read_table(self.path).slice(self.current_chunk_idx, self.chunk_size)
            self.current_chunk = table.column(self.text_column).to_pylist()
        else:
            with open(self.path, 'r') as f:
                reader = csv.reader(f)
                next(reader)  # skip the header row
                for _ in range(self.current_chunk_idx):
                    next(reader)
                self.current_chunk = [next(reader)[0] for _ in range(min(self.chunk_size, self.total_samples - self.current_chunk_idx))]
        self.current_chunk_idx += len(self.current_chunk)

    def __len__(self) -> int:
        return self.total_samples

    def dynamic_masking(self, tokens: torch.Tensor) -> torch.Tensor:
        tokens_to_be_masked = tokens.clone()
        # Index of the first [SEP]; only positions before it are maskable.
        end_of_token = int(torch.where(tokens == self.tokenizer.sep_token_id)[0][0])
        # BERT-style dynamic masking: each maskable position is masked with p=0.15.
        masked_indices = torch.bernoulli(torch.full((end_of_token,), 0.15)).bool()
        masked_indices[0] = False  # never mask [CLS]
        # Pad on the right so the mask aligns with the fixed-length (padded) sequence.
        padded_masked_indices = torch.zeros(self.max_length, dtype=torch.bool)
        padded_masked_indices[:end_of_token] = masked_indices
        tokens_to_be_masked[padded_masked_indices] = self.tokenizer.mask_token_id
        return tokens_to_be_masked

    def __getitem__(self, idx: int) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.LongTensor]:
        # Reload when idx falls outside the cached chunk; access is assumed to
        # be mostly sequential, so this is usually a no-op.
        if idx >= self.current_chunk_idx or idx < self.current_chunk_idx - len(self.current_chunk):
            self.current_chunk_idx = (idx // self.chunk_size) * self.chunk_size
            self._load_next_chunk()

        local_idx = idx - (self.current_chunk_idx - len(self.current_chunk))
        text = self.current_chunk[local_idx]

        # Unpack by key rather than relying on the encoding dict's value order.
        encoded = self.tokenizer(
            text,
            return_attention_mask=True,
            return_token_type_ids=True,
            return_tensors='pt',
            max_length=self.max_length,
            padding="max_length",
            truncation=True
        )
        input_ids = encoded["input_ids"].squeeze(0)
        token_type_ids = encoded["token_type_ids"].squeeze(0)
        attention_mask = encoded["attention_mask"].squeeze(0)
        masked_input_ids = self.dynamic_masking(input_ids)
        return masked_input_ids, attention_mask, token_type_ids, input_ids

    @classmethod
    def from_csv(
        cls,
        path: PathLike | str,
        tokenizer: PreTrainedTokenizer | str,
        text_row: int,
        text_b_row: Optional[int] = None,
        max_length: int = 512,
        chunk_size: int = 1000
    ):
        # TODO: text_row/text_b_row are accepted but not wired up yet.
        if isinstance(tokenizer, str):
            tokenizer_instance = AutoTokenizer.from_pretrained(tokenizer)
        else:
            tokenizer_instance = tokenizer
        return cls(path, tokenizer_instance, max_length, chunk_size)

    @classmethod
    def from_parquet(
        cls,
        path: PathLike | str,
        tokenizer: PreTrainedTokenizer | str,
        text_column: str,
        max_length: int = 512,
        chunk_size: int = 1000
    ):
        if isinstance(tokenizer, str):
            tokenizer_instance = AutoTokenizer.from_pretrained(tokenizer)
        else:
            tokenizer_instance = tokenizer
        return cls(path, tokenizer_instance, max_length, chunk_size, text_column)

class StreamingElectraClassificationDataset(Dataset):
    def __init__(
        self,
        path: PathLike | str,
        tokenizer: PreTrainedTokenizer,
        max_length: int = 512,
        chunk_size: int = 1000,
        text_column: str = "text",
        label_column: str = "label"
    ) -> None:
        super().__init__()
        self.path = path
        self.tokenizer = tokenizer
        self.max_length = max_length
        self.chunk_size = chunk_size
        self.text_column = text_column
        self.label_column = label_column
        self.current_chunk = []
        self.current_chunk_idx = 0
        self.total_samples = self._count_samples()
        self.label_dict = self._build_label_dict()

    def _count_samples(self) -> int:
        if str(self.path).endswith('.parquet'):
            return len(pq.read_table(self.path))
        else:
            with open(self.path, 'r') as f:
                return sum(1 for _ in f) - 1  # exclude the header row

    def _build_label_dict(self) -> Dict[str, int]:
        if str(self.path).endswith('.parquet'):
            table = pq.read_table(self.path)
            labels = set(table.column(self.label_column).to_pylist())
        else:
            with open(self.path, 'r') as f:
                reader = csv.reader(f)
                next(reader)  # skip the header row
                labels = set(row[1] for row in reader)  # assumes the label is the second column
        return {label: idx for idx, label in enumerate(sorted(labels))}

    def _load_next_chunk(self):
        if str(self.path).endswith('.parquet'):
            # pq.read_table has no skip/take arguments; read the table and slice it.
            table = pq.read_table(self.path).slice(self.current_chunk_idx, self.chunk_size)
            texts = table.column(self.text_column).to_pylist()
            labels = table.column(self.label_column).to_pylist()
            self.current_chunk = list(zip(texts, labels))
        else:
            with open(self.path, 'r') as f:
                reader = csv.reader(f)
                next(reader)  # skip the header row
                for _ in range(self.current_chunk_idx):
                    next(reader)
                # assumes two columns per row: text, label
                self.current_chunk = [next(reader) for _ in range(min(self.chunk_size, self.total_samples - self.current_chunk_idx))]
        self.current_chunk_idx += len(self.current_chunk)

    def __len__(self) -> int:
        return self.total_samples

    def __getitem__(self, idx: int) -> Tuple[torch.LongTensor, torch.LongTensor, torch.LongTensor, torch.LongTensor]:
        if idx >= self.current_chunk_idx or idx < self.current_chunk_idx - len(self.current_chunk):
            self.current_chunk_idx = (idx // self.chunk_size) * self.chunk_size
            self._load_next_chunk()

        local_idx = idx - (self.current_chunk_idx - len(self.current_chunk))
        text, label = self.current_chunk[local_idx]

        # Unpack by key rather than relying on the encoding dict's value order.
        encoded = self.tokenizer(
            text,
            return_attention_mask=True,
            return_token_type_ids=True,
            return_tensors='pt',
            max_length=self.max_length,
            padding="max_length",
            truncation=True
        )
        input_ids = encoded["input_ids"].squeeze(0)
        token_type_ids = encoded["token_type_ids"].squeeze(0)
        attention_mask = encoded["attention_mask"].squeeze(0)
        label_id = self.label_dict[label]
        return input_ids, attention_mask, token_type_ids, torch.tensor(label_id, dtype=torch.long)

    @classmethod
    def from_csv(
        cls,
        path: PathLike | str,
        tokenizer: PreTrainedTokenizer | str,
        text_row: int,
        label_row: int,
        text_b_row: Optional[int] = None,
        max_length: int = 512,
        chunk_size: int = 1000
    ):
        # TODO: text_row/label_row/text_b_row are accepted but not wired up yet.
        if isinstance(tokenizer, str):
            tokenizer_instance = AutoTokenizer.from_pretrained(tokenizer)
        else:
            tokenizer_instance = tokenizer
        return cls(path, tokenizer_instance, max_length, chunk_size)

    @classmethod
    def from_parquet(
        cls,
        path: PathLike | str,
        tokenizer: PreTrainedTokenizer | str,
        text_column: str,
        label_column: str,
        max_length: int = 512,
        chunk_size: int = 1000
    ):
        if isinstance(tokenizer, str):
            tokenizer_instance = AutoTokenizer.from_pretrained(tokenizer)
        else:
            tokenizer_instance = tokenizer
        return cls(path, tokenizer_instance, max_length, chunk_size, text_column, label_column)
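
# --- Usage sketch (illustrative) ---
# Wires the streaming dataset into the datamodule. The file paths are
# placeholders and the tokenizer name is one public Electra checkpoint;
# point both at your own data and tokenizer before running.
if __name__ == "__main__":
    tokenizer = AutoTokenizer.from_pretrained("google/electra-small-discriminator")
    train_ds = StreamingElectraPretrainingDataset("data/train.csv", tokenizer, chunk_size=1000)
    dev_ds = StreamingElectraPretrainingDataset("data/dev.csv", tokenizer)
    dm = ElectraKANDataModule(
        train_dataset=train_ds,
        val_dataset=dev_ds,
        test_dataset=dev_ds,  # reuse the dev split as a stand-in test set
        batch_size=32,
        num_workers=4,
        pin_memory=True,
    )
    masked_ids, attention_mask, token_type_ids, labels = train_ds[0]
    print(masked_ids.shape, labels.shape)  # expected: torch.Size([512]) twice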

--------------------------------------------------------------------------------
/src/ElectraKAN/kan.py:
--------------------------------------------------------------------------------
import torch
import torch.nn.functional as F


# reference: efficient-kan by @Blealtan
# CODE: https://github.com/Blealtan/efficient-kan


class KANLinear(torch.nn.Module):
    def __init__(
        self,
        in_features,
        out_features,
        grid_size=5,
        spline_order=3,
        scale_noise=0.1,
        scale_base=1.0,
        scale_spline=1.0,
        enable_standalone_scale_spline=True,
        base_activation=torch.nn.SiLU,
        grid_eps=0.02,
        grid_range=[-1, 1],
    ):
        super().__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.grid_size = grid_size
        self.spline_order = spline_order

        h = (grid_range[1] - grid_range[0]) / grid_size
        grid = (
            (
                torch.arange(-spline_order, grid_size + spline_order + 1) * h
                + grid_range[0]
            )
            .expand(in_features, -1)
            .contiguous()
        )
        self.register_buffer("grid", grid)

        self.base_weight = torch.nn.Parameter(
            torch.Tensor(out_features, in_features)
        )
        self.spline_weight = torch.nn.Parameter(
            torch.Tensor(out_features, in_features, grid_size + spline_order)
        )
        if enable_standalone_scale_spline:
            self.spline_scaler = torch.nn.Parameter(
                torch.Tensor(out_features, in_features)
            )

        self.scale_noise = scale_noise
        self.scale_base = scale_base
        self.scale_spline = scale_spline
        self.enable_standalone_scale_spline = enable_standalone_scale_spline
        self.base_activation = base_activation()
        self.grid_eps = grid_eps

        self.reset_parameters()

    def reset_parameters(self):
        torch.nn.init.xavier_uniform_(self.base_weight, gain=self.scale_base)
        with torch.no_grad():
            noise = (
                (
                    torch.rand(
                        self.grid_size + 1, self.in_features, self.out_features
                    )
                    - 1 / 2
                )
                * self.scale_noise
                / self.grid_size
            )
            self.spline_weight.data.copy_(
                (
                    self.scale_spline
                    if not self.enable_standalone_scale_spline
                    else 1.0
                )
                * self.curve2coeff(
                    self.grid.T[self.spline_order : -self.spline_order],
                    noise,
                )
            )
            if self.enable_standalone_scale_spline:
                torch.nn.init.constant_(self.spline_scaler, self.scale_spline)
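
    # Worked example of the grid above (an explanatory note, not upstream
    # efficient-kan code): with the defaults grid_size=5, spline_order=3 and
    # grid_range=[-1, 1], the step is h = 2 / 5 = 0.4 and the buffer holds
    # arange(-3, 9) * 0.4 - 1 = [-2.2, -1.8, ..., 1.8, 2.2]. The spline_order
    # extra knots on each side give every point in [-1, 1] full B-spline
    # support.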

    def b_splines(self, x: torch.Tensor):
        """
        Compute the B-spline bases for the given input tensor.

        Args:
            x (torch.Tensor): Input tensor of shape (batch_size, in_features).

        Returns:
            torch.Tensor: B-spline bases tensor of shape
                (batch_size, in_features, grid_size + spline_order).
        """
        assert x.dim() == 2 and x.size(1) == self.in_features

        grid: torch.Tensor = (
            self.grid  # type: ignore
        )  # (in_features, grid_size + 2 * spline_order + 1)
        x = x.unsqueeze(-1)
        bases = ((x >= grid[:, :-1]) & (x < grid[:, 1:])).to(x.dtype)
        for k in range(1, self.spline_order + 1):
            bases = (
                (x - grid[:, : -(k + 1)])
                / (grid[:, k:-1] - grid[:, : -(k + 1)])
                * bases[:, :, :-1]
            ) + (
                (grid[:, k + 1 :] - x)
                / (grid[:, k + 1 :] - grid[:, 1:(-k)])
                * bases[:, :, 1:]
            )

        assert bases.size() == (
            x.size(0),
            self.in_features,
            self.grid_size + self.spline_order,
        )
        return bases.contiguous()

    def curve2coeff(self, x: torch.Tensor, y: torch.Tensor):
        """
        Compute the coefficients of the curve that interpolates the given
        points.

        Args:
            x (torch.Tensor): Input tensor of shape (batch_size, in_features).
            y (torch.Tensor): Output tensor of shape
                (batch_size, in_features, out_features).

        Returns:
            torch.Tensor: Coefficients tensor of shape
                (out_features, in_features, grid_size + spline_order).
        """
        assert x.dim() == 2 and x.size(1) == self.in_features
        assert y.size() == (x.size(0), self.in_features, self.out_features)

        A = self.b_splines(x).transpose(
            0, 1
        )  # (in_features, batch_size, grid_size + spline_order)
        B = y.transpose(0, 1)  # (in_features, batch_size, out_features)
        solution = torch.linalg.lstsq(
            A, B
        ).solution  # (in_features, grid_size + spline_order, out_features)
        result = solution.permute(
            2, 0, 1
        )  # (out_features, in_features, grid_size + spline_order)

        assert result.size() == (
            self.out_features,
            self.in_features,
            self.grid_size + self.spline_order,
        )
        return result.contiguous()

    @property
    def scaled_spline_weight(self):
        return self.spline_weight * (
            self.spline_scaler.unsqueeze(-1)
            if self.enable_standalone_scale_spline
            else 1.0
        )

    def forward(self, x: torch.Tensor):
        assert x.dim() == 2 and x.size(1) == self.in_features

        base_output = F.linear(self.base_activation(x), self.base_weight)
        spline_output = F.linear(
            self.b_splines(x).view(x.size(0), -1),
            self.scaled_spline_weight.view(self.out_features, -1),
        )
        return base_output + spline_output
    @torch.no_grad()
    def update_grid(self, x: torch.Tensor, margin=0.01):
        assert x.dim() == 2 and x.size(1) == self.in_features
        batch = x.size(0)

        splines = self.b_splines(x)  # (batch, in, coeff)
        splines = splines.permute(1, 0, 2)  # (in, batch, coeff)
        orig_coeff = self.scaled_spline_weight  # (out, in, coeff)
        orig_coeff = orig_coeff.permute(1, 2, 0)  # (in, coeff, out)
        unreduced_spline_output = torch.bmm(
            splines, orig_coeff
        )  # (in, batch, out)
        unreduced_spline_output = unreduced_spline_output.permute(
            1, 0, 2
        )  # (batch, in, out)

        # sort each channel individually to collect data distribution
        x_sorted = torch.sort(x, dim=0)[0]
        grid_adaptive = x_sorted[
            torch.linspace(
                0,
                batch - 1,
                self.grid_size + 1,
                dtype=torch.int64,
                device=x.device,
            )
        ]

        uniform_step = (
            x_sorted[-1] - x_sorted[0] + 2 * margin
        ) / self.grid_size
        grid_uniform = (
            torch.arange(
                self.grid_size + 1, dtype=torch.float32, device=x.device
            ).unsqueeze(1)
            * uniform_step
            + x_sorted[0]
            - margin
        )

        grid = (
            self.grid_eps * grid_uniform + (1 - self.grid_eps) * grid_adaptive
        )
        grid = torch.concatenate(
            [
                grid[:1]
                - uniform_step
                * torch.arange(
                    self.spline_order, 0, -1, device=x.device
                ).unsqueeze(1),
                grid,
                grid[-1:]
                + uniform_step
                * torch.arange(
                    1, self.spline_order + 1, device=x.device
                ).unsqueeze(1),
            ],
            dim=0,
        )

        self.grid.copy_(grid.T)  # type: ignore
        self.spline_weight.data.copy_(
            self.curve2coeff(x, unreduced_spline_output)
        )

    def regularization_loss(
        self, regularize_activation=1.0, regularize_entropy=1.0
    ):
        """
        Compute the regularization loss.

        This is a crude approximation of the original L1 regularization
        described in the paper, since the original requires computing
        absolutes and entropy from the expanded
        (batch, in_features, out_features) intermediate tensor, which is
        hidden behind the F.linear call in a memory-efficient
        implementation.

        The L1 term is therefore computed as the mean absolute value of the
        spline weights. The authors' implementation also includes this term
        in addition to the sample-based regularization.
        """
        l1_fake = self.spline_weight.abs().mean(-1)
        regularization_loss_activation = l1_fake.sum()
        p = l1_fake / regularization_loss_activation
        regularization_loss_entropy = -torch.sum(p * p.log())
        return (
            regularize_activation * regularization_loss_activation
            + regularize_entropy * regularization_loss_entropy
        )


class KAN(torch.nn.Module):
    def __init__(
        self,
        width,
        grid=3,
        k=3,
        noise_scale=0.1,
        noise_scale_base=1.0,
        scale_spline=1.0,
        base_fun=torch.nn.SiLU,
        grid_eps=0.02,
        grid_range=[-1, 1],
        bias_trainable=True,
    ):
        super().__init__()
        self.grid_size = grid
        self.spline_order = k
        self.bias_trainable = bias_trainable  # TODO

        self.layers = torch.nn.ModuleList()
        for in_features, out_features in zip(width, width[1:]):
            self.layers.append(
                KANLinear(
                    in_features,
                    out_features,
                    grid_size=grid,
                    spline_order=k,  # was mistakenly passed `grid`
                    scale_noise=noise_scale,
                    scale_base=noise_scale_base,
                    scale_spline=scale_spline,
                    base_activation=base_fun,
                    grid_eps=grid_eps,
                    grid_range=grid_range,
                )
            )

    def forward(self, x: torch.Tensor, update_grid=False):
        # Flatten (batch, seq, features) so KANLinear sees a 2D batch,
        # then restore the sequence dimension afterwards.
        B, C, T = x.shape
        x = x.view(-1, T)

        for layer in self.layers:
            if update_grid:
                layer.update_grid(x)
            x = layer(x)

        U = x.shape[1]
        return x.view(B, C, U)

    def regularization_loss(
        self, regularize_activation=1.0, regularize_entropy=1.0
    ):
        return sum(
            layer.regularization_loss(
                regularize_activation, regularize_entropy
            )
            for layer in self.layers
        )
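
# --- Usage sketch (illustrative) ---
# A minimal forward pass through the KAN stack; the dimensions are arbitrary.
# Inputs must be 3D (batch, sequence, features), since forward() flattens the
# first two axes before applying the position-wise KANLinear layers.
if __name__ == "__main__":
    kan = KAN(width=[16, 32, 4])  # features: 16 -> 32 -> 4
    x = torch.randn(2, 10, 16)  # (batch, sequence, features)
    out = kan(x)
    print(out.shape)  # expected: torch.Size([2, 10, 4])
    print(float(kan.regularization_loss()))  # scalar sparsity penalty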
--------------------------------------------------------------------------------