├── .github ├── ISSUE_TEMPLATE │ └── feature_request.md ├── PULL_REQUEST_TEMPLATE.md └── workflows │ └── backendCI.yml ├── README.md ├── backend ├── .gitignore ├── Dockerfile ├── README.md ├── airflow │ ├── dags │ │ └── training_pipeline.py │ └── module │ │ ├── load_data.py │ │ └── train_model.py ├── app │ ├── bert_retrieval.py │ ├── bm25_retrieval.py │ ├── filter.py │ ├── generate.py │ ├── main.py │ └── search.py ├── requirements.txt ├── router │ └── router.py └── tests │ ├── test_filter.py │ ├── test_generate.py │ ├── test_retrieval.py │ └── test_search.py ├── data_pipeline ├── .gitignore ├── README.md ├── crawler.py ├── generate │ ├── backup_prompts.py │ ├── generate_bard.py │ ├── generate_gpt.py │ ├── parse.py │ └── prompts.pkl ├── preprocessor.py ├── preprocessor_v2.py ├── qa_crawler.py ├── requirements.txt ├── spellchecker.py └── utils.py ├── frontend ├── .gitignore ├── README.md ├── package-lock.json ├── package.json ├── public │ ├── index.html │ ├── lawbot.ico │ ├── lawbot.png │ ├── manifest.json │ └── robots.txt ├── src │ ├── App.js │ ├── asset │ │ ├── lawbot.png │ │ └── spinner.gif │ ├── components │ │ ├── ChattingSideBar.js │ │ ├── Header.js │ │ ├── Loader.js │ │ ├── SimilarPrecedent.js │ │ ├── SimilarPrecedentComponents │ │ │ └── PrecedentCard.js │ │ └── TypingAnimation.js │ ├── index.css │ ├── index.js │ └── reportWebVitals.js └── tailwind.config.js ├── model ├── .gitignore ├── BERT │ ├── inference │ │ ├── inference.py │ │ └── utils.py │ ├── make_vector_dataset │ │ └── preprocessing_law_data.py │ └── preprocessing.py ├── Filter │ ├── data_preprocessing.py │ ├── dataloader.py │ ├── infer.py │ ├── train.py │ └── utils.py ├── LLM │ ├── evaluation │ │ ├── data_preprocessing.py │ │ ├── dialogue_evaluation.py │ │ ├── eval_data_alpaca │ │ │ └── dataset_val.csv │ │ ├── eval_data_legal │ │ │ ├── easy_law_val.csv │ │ │ └── legal_QA_val.csv │ │ ├── evaluate_mertrics.py │ │ ├── petf_ppl.py │ │ └── ppl.py │ ├── inference │ │ └── infer.py │ └── train │ │ ├── data_preprocessing.py │ │ ├── load_model.py │ │ ├── train.py │ │ └── utils.py ├── README.md ├── Retrieval │ ├── bert_retrieval │ │ ├── data_preprocessing.py │ │ └── inference.py │ └── bm25_retrieval │ │ ├── data_preprocessing.py │ │ ├── read_json.py │ │ ├── retrieval.py │ │ ├── retrieval_bm25.py │ │ └── retrieval_main.py └── requirements.txt └── prototype ├── .gitignore ├── README.md ├── package-lock.json ├── package.json ├── public ├── index.html ├── lawbot.ico ├── lawbot.png ├── manifest.json └── robots.txt ├── src ├── App.js ├── asset │ └── lawbot.png ├── components │ ├── ChattingSideBar.js │ ├── Header.js │ ├── SimilarPrecedent.js │ └── SimilarPrecedentComponents │ │ └── PrecedentCard.js ├── index.css ├── index.js └── reportWebVitals.js └── tailwind.config.js /.github/ISSUE_TEMPLATE/feature_request.md: -------------------------------------------------------------------------------- 1 | --- 2 | name: Feature request 3 | about: Suggest an idea for this project 4 | title: "[FEAT]" 5 | labels: '' 6 | assignees: '' 7 | 8 | --- 9 | 10 | ## Background 11 | - 12 | 13 | ## Todo 14 | - [ ] Todo 1 15 | - [ ] Todo 2 16 | -------------------------------------------------------------------------------- /.github/PULL_REQUEST_TEMPLATE.md: -------------------------------------------------------------------------------- 1 | ## Overview 2 | - 3 | 4 | ## Change Log 5 | - 6 | 7 | ## Further information 8 | - 9 | 10 | ## To Reviewer 11 | - 12 | 13 | ## Issue Tags 14 | - Closed | Fixed: # 15 | - See also : # 16 | 
-------------------------------------------------------------------------------- /.github/workflows/backendCI.yml: -------------------------------------------------------------------------------- 1 | # This workflow will install Python dependencies, run tests and lint with a single version of Python 2 | # For more information see: https://docs.github.com/en/actions/automating-builds-and-tests/building-and-testing-python 3 | 4 | name: CI for backend 5 | 6 | on: 7 | push: 8 | paths: "backend/**" 9 | branches: [ "develop" ] 10 | pull_request: 11 | paths: "backend/**" 12 | branches: [ "develop" ] 13 | 14 | jobs: 15 | CI: 16 | runs-on: self-hosted 17 | env: 18 | working-directory: /opt/ml/level3_nlp_finalproject-nlp-08 19 | 20 | steps: 21 | - name: Update Code 22 | run: git pull 23 | working-directory: ${{ env.working-directory }} 24 | 25 | - name: install dependencies 26 | run: | 27 | source final_project/bin/activate 28 | pip install -r requirements.txt 29 | working-directory: "${{ env.working-directory }}/backend" 30 | 31 | - name: Test with pytest 32 | run: | 33 | source final_project/bin/activate 34 | pytest 35 | working-directory: "${{ env.working-directory }}/backend" 36 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 |

2 | 3 |

Legal Advice Web Service ‘LawBot’

4 | 5 |          [![HuggingFace](https://img.shields.io/badge/%F0%9F%A4%97-Models%20on%20Hub-yellow)](https://huggingface.co/models?filter=keytotext) [![FastAPI](https://img.shields.io/badge/-FastAPI-red?logo=fastapi&labelColor=white)](https://github.com/gagan3012/keytotext#api) 6 | 7 |

8 | 9 | Screenshot 2023-07-31 at 11 27 53 AM 10 | 11 |
12 |
13 | >LawBot provides light legal consultation along with similar precedents and relevant statutes. By combining the vast knowledge pretrained into the LLM with the legal knowledge used for fine-tuning, it can generate answers that adapt flexibly to diverse and unusual situations, and a similarity-based AI model supplies precedents related to the user's case.
14 |
15 |
16 |
17 | 🖥 **[Try the LawBot web service](http://yoonseul.link)**
18 |
19 | **※ Please note that this web service is available only until August 18, 2023, the end of the post-session period. Service refinement and performance improvements will continue until that date.**
20 |
21 |
22 |
23 | ## ⌘ Project Background
24 |
25 | ### Motivation and Expected Impact
26 |
27 | Screenshot 2023-07-31 at 11 36 29 AM
28 |
29 |
30 |
31 | * **`Background`**: The legal-tech industry is developing rapidly worldwide, and demand for related services in Korea keeps growing. Yet legal terminology and legal texts are hard to interpret, so they remain poorly understood by the general public, and obtaining relevant information is expensive.
32 |
33 | * **`Goal`**: To address this, team YoonSeul aims to lower the barrier to the law with a light legal consultation service that offers easy-to-understand guidelines for legal situations together with related precedents and statutes.
34 |
35 |
36 |
37 | ### What Sets LawBot Apart
38 |
39 | - Mainstream legal-tech services only **match users with lawyers or provide direct search over statutes and precedents that remain hard to read**. LawBot is different: it uses AI models to find precedents similar to the user's legal dispute and directly generates a light guideline for the user within a short time.
40 |
41 |
42 |
43 | ## ⚙️ Use Case
44 |
45 | ![example-case](https://github.com/boostcampaitech5/level3_nlp_finalproject-nlp-08/assets/81630351/03d80449-819c-4c63-9c1c-6e0492230ba8)
46 |
47 |
48 |
49 | >1. Access the web service
50 | >2. Enter your legal dispute situation and a related question in the message prompt
51 | >3. The AI model grasps the context and returns a guideline as a chat message
52 | >4. The right sidebar shows up to three precedents similar to the situation you entered
53 | >5. Follow the links to check related statutes and other legal information directly
54 |
55 |
56 |
57 |
58 | ## 🧑🏻‍💻 Team Introduction & Members
59 |
60 | ### 💬 About the Team
61 | >**We are team `YoonSeul` (윤슬), a team that pursues harmony and sustainable growth!**

**Every member took `ownership` of every task and made `active participation` the top priority, regardless of their assigned role. We constantly reflect, learn, and grow to be good teammates.**
62 |
63 |
64 | 65 | ### 👨🏼‍💻 Members 66 | 강민재|김주원|김태민|신혁준|윤상원| 67 | :-:|:-:|:-:|:-:|:-: 68 | ||||| 69 | |||| 70 | ||||| 71 | 72 |
73 |
74 | ### 👨🏼‍💻 Role Assignments
75 | Screenshot 2023-08-20 at 11 21 39 PM
76 |
77 |
78 |
79 | ## 🗓️ Project Timeline
80 | - We built the core features first and then kept improving the service with an agile approach, continuously refining it based on the team's feedback.
81 |
82 | Screenshot 2023-07-31 at 11 54 44 AM
83 |
84 |
85 |
86 | ## ⌘ Service Architecture
87 | Screenshot 2023-07-31 at 12 03 17 PM
88 |
89 |
90 |
91 | >To make full use of the four V100 servers provided by the NAVER Connect Foundation, we adopted a **microservice architecture that is easy to scale out** from the very first design. We also reduced the coupling between services so that one server failure cannot bring the whole service down: as shown above, the web server and the model servers are separated and communicate with each other over APIs. Even if one V100 server fails, the remaining services keep serving users unaffected. A sketch of this request flow follows below.
92 |
93 |
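Concretely, the communication between the services is plain HTTP with JSON bodies. The sketch below mirrors how `backend/router/router.py` relays a user question from the web server to a model server's `/generate` endpoint; the address is illustrative, and in the deployed service load balancing (see the links below) spreads these calls across the model servers.

```python
import json

import requests

# Illustrative model-server address; the deployment runs several V100-backed
# model servers, so a load balancer would normally sit in front of them.
MODEL_SERVER_URL = "http://127.0.0.1:8000/generate"

def forward_question(q_sentence: str) -> dict:
    """Relay a question to a model server, mirroring backend/router/router.py."""
    headers = {"Content-Type": "application/json", "accept": "application/json"}
    res = requests.post(MODEL_SERVER_URL, headers=headers,
                        data=json.dumps({"q_sentence": q_sentence}))
    # The response carries the generated guideline and up to three similar
    # precedents (see the Answer model in backend/app/main.py).
    return res.json()
```

Because the router only speaks HTTP, a failing model server can simply be dropped from rotation without touching the web server.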
94 |
95 | - Below is what we implemented, along with links where you can read more. Each link explains why the feature was built and what we focused on while developing it.
96 |
97 | - 🛠️ [Building the CI pipeline](https://uomnf97.notion.site/CI-f687f03b192f49fa80d451f8850a03f6?pvs=4)
98 | - ✍🏻 [Applying load balancing](https://uomnf97.notion.site/96f697aab756407aadbe51582a0a68d4?pvs=4)
99 | - ✍🏻 [Model training pipeline with Airflow](https://uomnf97.notion.site/64a55c1e1f4a4ff985343a97b224a101?pvs=4)
100 | - 🛠️ [Failover via Auto Scaling](https://uomnf97.notion.site/Auto-Scaling-Failover-fa0ab424dcda44739ababe1eb719a106?pvs=4)
101 |
102 |
103 |
104 | ## 💿 Data
105 | Screenshot 2023-07-31 at 12 30 46 PM
106 |
107 | - As shown above, we split the data work into stages, set goals for each, and built the training dataset through data exploration, collection, EDA and preprocessing, and augmentation with generative models.
108 |
109 |
110 |
111 | Screenshot 2023-07-31 at 12 31 04 PM
112 |
113 | - The cyclical DMOps diagram above shows our data pipeline end to end.
114 |
115 |
116 |
117 |
118 | ### 1️⃣ Legal QA Data
119 |
120 | | Dataset | Count | Source |
121 | | :---: | :---: | :---: |
122 | | easylaw_kr | 2,195 | https://huggingface.co/datasets/juicyjung/easylaw_kr |
123 | | LegalQA | 1,830 | https://github.com/haven-jeon/LegalQA |
124 | | Korea Legal Aid Corporation legal counseling cases | 9,994 | https://www.klac.or.kr/legalinfo/counsel.do |
125 | | Korea Legal Aid Corporation cyber counseling data | 2,463 | https://www.klac.or.kr/legalstruct/cyberConsultation.do |
126 | | OpenAI GPT augmentation | 8,666 | - |
127 |
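Most of the public sources in this table can be pulled directly from their hubs. A minimal sketch for the `easylaw_kr` set, with the dataset id taken from the URL above (the split and column layout is an assumption to check against the dataset card):

```python
from datasets import load_dataset

# Dataset id from the source column above; the splits and columns may differ
# from this assumption, so inspect the printed structure before using it.
easylaw = load_dataset("juicyjung/easylaw_kr")
print(easylaw)
```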
128 |
129 | ### 2️⃣ Precedent Data
130 |
131 | | Dataset | Count | Source |
132 | | :---: | :---: | :---: |
133 | | Legal/regulatory (rulings, terms of service, etc.) text analysis data | 77,382 | https://www.aihub.or.kr/aihubdata/data/view.do?currMenu=115&topMenu=100&aihubDataSe=realm&dataSetSn=580 |
134 |
135 |
136 |
137 |
138 | - Details on data preprocessing, EDA, and augmentation are available at the links below.
139 |
140 | - 💿[Data collection](https://uomnf97.notion.site/a5f4628cdc7b4ce3928ce626c727ff32?pvs=4)
141 | - 📈[EDA and preprocessing](https://uomnf97.notion.site/EDA-b53d75aaa7574ea586cbd6cdbd5c755a?pvs=4)
142 | - 🛠️[Data augmentation with generative models](https://uomnf97.notion.site/5e8e9ecc27694d7497fbad68f72136c0?pvs=4)
143 |
144 |
145 | 146 | ## 📊 Model 147 | 148 | ### Overview 149 | Screenshot 2023-07-31 at 12 17 55 PM 150 | 151 |
152 |
153 | > When a user submits a question, the Question Filtering model first determines whether it is a legal question. If it is, the Similar Precedent model and the Law LLM model produce similar precedents and legal advice; a condensed sketch of this flow follows below.
154 |
155 |
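This flow is what `backend/app/main.py` implements. Condensed to its essentials (with the models and pre-built vector data passed in as parameters), the orchestration looks roughly like this:

```python
# A condensed sketch of the request path in backend/app/main.py:
# filter -> retrieve -> generate -> precedent search.
from app.bert_retrieval import bert_retrieve_QA
from app.filter import is_legal_question
from app.generate import generate_answer
from app.search import search_precedent

def answer_pipeline(q_sentence, llm, tokenizer, retrieve_model, retrieve_data,
                    retrieve_vector_data, search_model, text_data, vector_data):
    # 1. Question Filtering model: non-legal questions get no answer.
    if not is_legal_question(q_sentence=q_sentence):
        return None, None
    # 2. Fetch the answer of the most similar past consultation (SBERT retrieval).
    retrieve_answer = bert_retrieve_QA(q_sentence=q_sentence, model=retrieve_model,
                                       data=retrieve_data,
                                       vector_data=retrieve_vector_data)
    # 3. Law LLM model generates the guideline shown in the chat window.
    answer_sentence = generate_answer(q_sentence=q_sentence, model=llm,
                                      tokenizer=tokenizer)
    # 4. Similar Precedent model searches precedents against question + answers.
    similar_precedent = search_precedent(
        q_a_sentence=q_sentence + retrieve_answer + answer_sentence,
        model=search_model, text_data=text_data, vector_data=vector_data)
    return answer_sentence, similar_precedent
```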
156 | 157 | ### 1️⃣ Question Filtering Model 158 | Screenshot 2023-07-31 at 12 47 58 PM 159 | 160 |
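Under the hood the filter is a small binary KoELECTRA sequence classifier (see `backend/app/filter.py`): a question reaches the LLM only when the softmax probability of the legal class clears a conservative 0.95 threshold. A minimal sketch:

```python
import torch.nn.functional as F
from transformers import AutoModelForSequenceClassification, AutoTokenizer

# Mirrors backend/app/filter.py: binary classifier + 0.95 probability cutoff.
model = AutoModelForSequenceClassification.from_pretrained(
    "kfkas/legal-question-filter-koelectra",
    num_labels=2,
    ignore_mismatched_sizes=True,
)
tokenizer = AutoTokenizer.from_pretrained("monologg/koelectra-small-v3-discriminator")

def is_legal(q_sentence: str) -> bool:
    inputs = tokenizer(q_sentence, padding=True, truncation=True, return_tensors="pt")
    probs = F.softmax(model(**inputs).logits.detach().cpu(), dim=1).numpy()
    return probs[0][0] >= 0.95  # index 0 is the "legal question" class here
```

The high threshold trades recall for precision, which keeps casual chit-chat from triggering an expensive LLM generation.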
161 | 162 | ### 2️⃣ Similar Precedent Model 163 | Screenshot 2023-07-31 at 12 48 08 PM 164 | 165 |
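Precedent matching is embedding-based: the query (the question plus the generated answers) is encoded with a Ko-SRoBERTa sentence encoder and compared to pre-computed precedent vectors by cosine similarity, keeping at most the top three hits above a similarity threshold (0.5 in `backend/app/search.py`). A minimal sketch:

```python
import numpy as np
from sentence_transformers import SentenceTransformer
from sklearn.metrics.pairwise import cosine_similarity

# Mirrors the search in backend/app/search.py.
model = SentenceTransformer("jhgan/ko-sroberta-multitask")

def top3_precedents(query: str, vector_data):
    """Return (scores, indices) of the three most similar precedent vectors;
    the caller keeps only hits whose cosine similarity is at least 0.5."""
    input_vector = np.expand_dims(model.encode(query), axis=0)
    cos_sim = cosine_similarity(input_vector, vector_data)
    top_scores = np.sort(cos_sim).squeeze()[::-1][:3]
    top_indices = np.argsort(cos_sim).squeeze()[::-1][:3]
    return top_scores, top_indices
```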
166 |
167 | ### 3️⃣ LLM Model
168 |
169 | Screenshot 2023-07-31 at 12 48 28 PM
170 |
171 |
172 | - List of models we built (training data + backbone model)
Screenshot 2023-08-21 at 5 35 27 PM
173 |
174 | #### Evaluation Metrics
175 | - Since there is no established metric for evaluating LLMs, we designed our own. Because legal accuracy is critical in this domain, we built the metric to assess both legal accuracy and the naturalness of the generated language.
176 |
177 |
178 |
179 | **[ Dialogue Evaluation Metric ]**
180 |
181 | - We adapted the evaluation criteria of the KULLM model's Dialogue Evaluation Metric to the legal domain and commissioned a lawyer to rate the answers against them. All models were evaluated blind, labeled Model A, B, C, and so on, and ChatGPT and Bard were included in the comparison to anchor the scale. In the end, our **kfkas/legal-llama-2-ko-7b-Chat** model performed best.
182 |
183 | Screenshot 2023-07-31 at 12 56 06 PM
184 |
185 |
186 |
187 | **[ Perplexity ]**
188 | - We used perplexity to gauge how well the generative models produce legal terminology. Lower values are better for this metric, and the models trained on our crawled and curated data generally scored better. A minimal computation sketch follows; the measured scores appear beneath it.
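The project's evaluation scripts live under `model/LLM/evaluation` (`ppl.py`, `petf_ppl.py`). As a reference only (not the repo's exact script), perplexity is the exponential of the mean per-token negative log-likelihood, which a generic sketch can compute like this:

```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

def perplexity(model_name, texts):
    """Generic PPL sketch: exp(mean negative log-likelihood per token)."""
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = AutoModelForCausalLM.from_pretrained(model_name)
    model.eval()
    total_nll, total_tokens = 0.0, 0
    with torch.no_grad():
        for text in texts:
            enc = tokenizer(text, return_tensors="pt")
            n_tokens = enc["input_ids"].size(1)
            # With labels=input_ids, the model returns the mean cross-entropy
            # over the predicted tokens; re-weight it by the sequence length.
            out = model(**enc, labels=enc["input_ids"])
            total_nll += out.loss.item() * n_tokens
            total_tokens += n_tokens
    return float(torch.exp(torch.tensor(total_nll / total_tokens)))
```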
189 | Screenshot 2023-07-31 at 12 58 26 PM
190 |
191 |
192 |
193 | ## 💻 Getting Started
194 |
195 | > The READMEs below explain how to run the project yourself and where to browse the implemented code.
196 |
197 | ### 📊 Model
198 | - [Model](model) / [README.md](model/README.md)
199 |
200 | ### 💽 Data
201 | - [Data Pipeline](data_pipeline) / [README.md](data_pipeline/README.md)
202 |
203 | ### 🎨 Frontend
204 | - [Frontend](frontend) / [README.md](frontend/README.md)
205 |
206 | ### 💻 Backend
207 | - [Backend](backend) / [README.md](backend/README.md)
208 |
209 |
210 |
211 | ## 📚 Further Information
212 |
213 | ### 1️⃣ Tech Stack and Development Environment
214 | Screenshot 2023-07-31 at 12 44 34 PM
215 |
216 |
217 |
218 | ### 2️⃣ Collaboration Tools
219 |
220 | - Notion:
221 |   - Managed the to-do list systematically with a Kanban board
222 |   - Structured meetings and records using Notion's collaboration features Screenshot 2023-07-31 at 12 38 10 PM
223 |
224 | - GitHub:
225 |   - Established a branching strategy based on GitHub Flow
226 |   - Managed work systematically with PR and issue templates.
227 |   - Set ground rules so everyone kept a consistent commit convention Screenshot 2023-07-31 at 12 38 50 PM
228 |
229 |
230 |
231 | ### 3️⃣ Links :
232 |
233 | - [Project Notion page](https://uomnf97.notion.site/NLP-08-LawBot-b2dfef92f666458583d6b459af53aa66?pvs=4)
234 | - [YouTube presentation video](https://www.youtube.com/watch?v=fgboxtWM4B4)
235 | - [Presentation slides](https://github.com/boostcampaitech5/level3_nlp_finalproject-nlp-08/files/12394029/NLP_08_.LawBot.pdf)
236 |
237 |
--------------------------------------------------------------------------------
/backend/.gitignore:
--------------------------------------------------------------------------------
1 | final_project
2 | __pycache__
3 | model
4 | test.ipynb
5 | data
6 | nohup.out
7 | logs
8 | airflow.cfg
9 | airflow.db
10 | webserver_config.py
11 | .env
12 | airflow-webserver.pid
--------------------------------------------------------------------------------
/backend/Dockerfile:
--------------------------------------------------------------------------------
1 | FROM python:3.8
2 |
3 | RUN python3 -m pip install --upgrade pip
4 |
5 | COPY ./requirements.txt /ws/requirements.txt
6 |
7 | WORKDIR /ws
8 |
9 | RUN pip install -r requirements.txt
10 |
11 | COPY ./app/ /ws
12 |
13 | CMD ["uvicorn", "main:app", "--host", "0.0.0.0"]
--------------------------------------------------------------------------------
/backend/README.md:
--------------------------------------------------------------------------------
1 | # LawBot - Backend
2 |
3 | ## 1. Environment
4 | * FastAPI
5 | * python 3.8.x (tested on 3.8.5)
6 |
7 | ## 2. Install
8 |
9 | * Create your own virtual environment.
10 | ```bash
11 | $ pip install virtualenv
12 | $ virtualenv <env_name>
13 | $ source <env_name>/bin/activate
14 | ```
15 |
16 | * Install the required modules into your virtual environment.
17 | ```bash
18 | $ pip install -r requirements.txt
19 | ```
20 |
21 | ## 3. Execute
22 |
23 | ### Model Server
24 | ```bash
25 | $ cd app
26 | $ uvicorn main:app --host=0.0.0.0 --reload
27 | ```
28 | ### Test
29 | ```bash
30 | $ pytest
31 | ```
32 | ### Airflow
33 | ```bash
34 | $ cd airflow
35 | $ airflow db init
36 | $ airflow scheduler
37 | ```
38 | ## 4. Document
39 | 1. Run the server (local)
40 | 2.
Goto http://localhost:8000/docs 41 | -------------------------------------------------------------------------------- /backend/airflow/dags/training_pipeline.py: -------------------------------------------------------------------------------- 1 | import datetime 2 | import os 3 | import sys 4 | 5 | import pandas as pd 6 | import transformers 7 | from airflow import DAG 8 | from airflow.operators.python import PythonOperator, PythonVirtualenvOperator 9 | from datasets import Dataset, concatenate_datasets 10 | from dotenv import load_dotenv 11 | from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training 12 | from transformers import (AutoModelForCausalLM, AutoTokenizer, 13 | BitsAndBytesConfig) 14 | 15 | sys.path.append(os.path.dirname(os.path.abspath(os.path.dirname(__file__)))) 16 | 17 | import torch 18 | from module.load_data import load_train_eval_data 19 | from torch import multiprocessing 20 | 21 | # multiprocessing.set_start_method("forkserver") 22 | 23 | load_dotenv() 24 | huggingface_write_token = os.getenv("HUGGINGFACE_WRITE_TOKEN") 25 | 26 | class Autodata: 27 | def __init__(self, data_path, max_length=1024, tokenizer=None): 28 | self.max_length = max_length 29 | self.tokenizer = tokenizer 30 | self.concat_dataset = self.concat_datasets(data_path) 31 | self.tokenizer_dataset = self.tokenizing_dataset(self.concat_dataset) 32 | 33 | def concat_datasets(self, data_path): 34 | datasets = [] 35 | dataset = pd.read_csv(data_path) 36 | dataframe = dataset[["question", "answer"]] 37 | dataset = Dataset.from_pandas(dataframe) 38 | datasets.append(dataset) 39 | 40 | combined_dataset = concatenate_datasets(datasets) 41 | 42 | return combined_dataset 43 | 44 | def tokenizing_dataset(self, dataset): 45 | data = dataset.map( 46 | lambda x: { 47 | "text": f"아래는 작업을 설명하는 명령어입니다. 요청을 적절히 완료하는 응답을 작성하세요.\n\n### 명령어:\n{x['question']}\n\n### 응답:\n{x['answer']}<|endoftext|>" 48 | } 49 | ) 50 | data = data.map( 51 | lambda samples: self.tokenizer( 52 | samples["text"], 53 | truncation=True, 54 | max_length=self.max_length, 55 | padding=False, 56 | return_tensors=None, 57 | ), 58 | batched=True, 59 | ) 60 | 61 | return data.shuffle() 62 | 63 | def load_model(model_name): 64 | # bnb_config = BitsAndBytesConfig( 65 | # load_in_4bit=True, 66 | # bnb_4bit_use_double_quant=True, 67 | # bnb_4bit_quant_type="nf4", 68 | # bnb_4bit_compute_dtype=torch.bfloat16, 69 | # ) 70 | tokenizer = AutoTokenizer.from_pretrained(model_name) 71 | model = AutoModelForCausalLM.from_pretrained( 72 | model_name 73 | ) 74 | model.gradient_checkpointing_enable() 75 | model = prepare_model_for_kbit_training(model) 76 | 77 | config = LoraConfig( 78 | r=8, 79 | lora_alpha=32, 80 | target_modules=["query_key_value"], 81 | lora_dropout=0.05, 82 | bias="none", 83 | task_type="CAUSAL_LM", 84 | ) 85 | 86 | model = get_peft_model(model, config) 87 | print_trainable_parameters(model) 88 | 89 | return model, tokenizer 90 | 91 | 92 | def print_trainable_parameters(model): 93 | """ 94 | Prints the number of trainable parameters in the model. 
95 | """ 96 | trainable_params = 0 97 | all_param = 0 98 | for _, param in model.named_parameters(): 99 | all_param += param.numel() 100 | if param.requires_grad: 101 | trainable_params += param.numel() 102 | print( 103 | f"trainable params: {trainable_params} || all params: {all_param} || trainable%: {100 * trainable_params / all_param}" 104 | ) 105 | 106 | def train_model(): 107 | import datetime 108 | import os 109 | import sys 110 | 111 | import pandas as pd 112 | import torch 113 | import transformers 114 | from airflow import DAG 115 | from airflow.operators.python import (PythonOperator, 116 | PythonVirtualenvOperator) 117 | from datasets import Dataset, concatenate_datasets 118 | from dotenv import load_dotenv 119 | from peft import (LoraConfig, get_peft_model, 120 | prepare_model_for_kbit_training) 121 | from transformers import (AutoModelForCausalLM, AutoTokenizer, 122 | BitsAndBytesConfig) 123 | 124 | class Autodata: 125 | def __init__(self, data_path, max_length=1024, tokenizer=None): 126 | self.max_length = max_length 127 | self.tokenizer = tokenizer 128 | self.concat_dataset = self.concat_datasets(data_path) 129 | self.tokenizer_dataset = self.tokenizing_dataset(self.concat_dataset) 130 | 131 | def concat_datasets(self, data_path): 132 | datasets = [] 133 | dataset = pd.read_csv(data_path) 134 | dataframe = dataset[["question", "answer"]] 135 | dataset = Dataset.from_pandas(dataframe) 136 | datasets.append(dataset) 137 | 138 | combined_dataset = concatenate_datasets(datasets) 139 | 140 | return combined_dataset 141 | 142 | def tokenizing_dataset(self, dataset): 143 | data = dataset.map( 144 | lambda x: { 145 | "text": f"아래는 작업을 설명하는 명령어입니다. 요청을 적절히 완료하는 응답을 작성하세요.\n\n### 명령어:\n{x['question']}\n\n### 응답:\n{x['answer']}<|endoftext|>" 146 | } 147 | ) 148 | data = data.map( 149 | lambda samples: self.tokenizer( 150 | samples["text"], 151 | truncation=True, 152 | max_length=self.max_length, 153 | padding=False, 154 | return_tensors=None, 155 | ), 156 | batched=True, 157 | ) 158 | 159 | return data.shuffle() 160 | 161 | model_id = "nlpai-lab/kullm-polyglot-5.8b-v2" 162 | tokenizer = AutoTokenizer.from_pretrained(model_id) 163 | bnb_config = BitsAndBytesConfig( 164 | load_in_4bit=True, 165 | bnb_4bit_use_double_quant=True, 166 | bnb_4bit_quant_type="nf4", 167 | bnb_4bit_compute_dtype=torch.bfloat16, 168 | ) 169 | model = AutoModelForCausalLM.from_pretrained( 170 | model_id, quantization_config=bnb_config, device_map={"": 0} 171 | ) 172 | model.gradient_checkpointing_enable() 173 | model = prepare_model_for_kbit_training(model) 174 | 175 | config = LoraConfig( 176 | r=8, 177 | lora_alpha=32, 178 | target_modules=["query_key_value"], 179 | lora_dropout=0.05, 180 | bias="none", 181 | task_type="CAUSAL_LM", 182 | ) 183 | 184 | model = get_peft_model(model, config) 185 | tokenizer.pad_token = tokenizer.eos_token 186 | # BASE_DIR = os.path.join(os.path.dirname(os.path.abspath(os.path.dirname(__file__))), "data") 187 | BASE_DIR = os.path.join("/opt/ml/level3_nlp_finalproject-nlp-08/backend/airflow", "data") 188 | TRAIN_DATA_PATH = os.path.join(BASE_DIR, "train_data.csv") 189 | EVAL_DATA_PATH = os.path.join(BASE_DIR, "eval_data.csv") 190 | train_data = Autodata(data_path=TRAIN_DATA_PATH, tokenizer=tokenizer).tokenizer_dataset 191 | val_data = Autodata(data_path=EVAL_DATA_PATH, tokenizer=tokenizer).tokenizer_dataset 192 | trainer = transformers.Trainer( 193 | model=model, 194 | train_dataset=train_data, 195 | eval_dataset=val_data, 196 | args=transformers.TrainingArguments( 197 | 
per_device_train_batch_size=16, 198 | gradient_accumulation_steps=1, 199 | num_train_epochs=6, 200 | learning_rate=1e-4, 201 | fp16=True, 202 | logging_steps=10, 203 | save_strategy="epoch", 204 | evaluation_strategy="epoch", 205 | output_dir="./model_outputs", 206 | optim="paged_adamw_8bit", 207 | ), 208 | data_collator=transformers.DataCollatorForLanguageModeling( 209 | tokenizer, mlm=False 210 | ), 211 | ) 212 | model.config.use_cache = ( 213 | False # silence the warnings. Please re-enable for inference! 214 | ) 215 | trainer.train() 216 | 217 | push_model_id = "YoonSeul/LawBot-airflow-test" 218 | 219 | model.push_to_hub( 220 | push_model_id, use_temp_dir=True, use_auth_token=huggingface_write_token 221 | ) 222 | print(f"{push_model_id} 모델 업로드 완료!") 223 | 224 | with DAG( 225 | dag_id="training_pipeline", 226 | description="train the model periodically", 227 | start_date=datetime.datetime(2023,7,27), 228 | schedule_interval="0 0 * * 5", 229 | tags=["LLM"], 230 | ) as dag: 231 | 232 | load_data = PythonOperator( 233 | task_id="load_data", 234 | python_callable=load_train_eval_data, 235 | depends_on_past=True, 236 | owner="SangwonYoon", 237 | retries=3, 238 | retry_delay=datetime.timedelta(minutes=5) 239 | ) 240 | 241 | training_model = PythonVirtualenvOperator( 242 | task_id="train_model", 243 | python_callable=train_model, 244 | depends_on_past=True, 245 | owner="SangwonYoon", 246 | retries=3, 247 | retry_delay=datetime.timedelta(minutes=5), 248 | system_site_packages=True, 249 | # requirements=["torch==2.0.1"], 250 | ) 251 | 252 | # test = PythonVirtualenvOperator( 253 | # task_id="test", 254 | # python_callable=test_torch, 255 | # depends_on_past=True, 256 | # owner="SangwonYoon", 257 | # retries=3, 258 | # retry_delay=datetime.timedelta(minutes=1), 259 | # system_site_packages=True, 260 | # # requirements=["torch==2.0.1"], 261 | # python_version=3.8, 262 | # ) 263 | 264 | load_data >> training_model 265 | # test 266 | 267 | 268 | 269 | -------------------------------------------------------------------------------- /backend/airflow/module/load_data.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import pandas as pd 4 | from datasets import load_dataset 5 | from dotenv import load_dotenv 6 | 7 | load_dotenv() 8 | huggingface_read_token = os.getenv("HUGGINGFACE_READ_TOKEN") 9 | 10 | def load_train_data(): 11 | data = load_dataset("YoonSeul/legal-GPT-BARD-train_v3", use_auth_token=huggingface_read_token) 12 | df = pd.DataFrame(data) 13 | 14 | questions = [] 15 | answers = [] 16 | 17 | for i in df.iterrows(): 18 | questions.append(i[1]["train"]["instruction"]) 19 | answers.append(i[1]["train"]["output"]) 20 | 21 | train_datasets = { 22 | "question": questions, 23 | "answer": answers 24 | } 25 | 26 | train_df = pd.DataFrame(train_datasets) 27 | BASE_PATH = os.path.join(os.path.dirname(os.path.abspath((os.path.dirname(__file__)))), "data") 28 | SAVE_PATH = os.path.join(BASE_PATH, "train_data.csv") 29 | os.makedirs(BASE_PATH, exist_ok=True) 30 | train_df.to_csv(SAVE_PATH) 31 | 32 | def load_eval_data(): 33 | data = load_dataset("YoonSeul/legal-GPT-BARD-val_v3", use_auth_token=huggingface_read_token) 34 | df = pd.DataFrame(data) 35 | 36 | questions = [] 37 | answers = [] 38 | 39 | for i in df.iterrows(): 40 | questions.append(i[1]["train"]["instruction"]) 41 | answers.append(i[1]["train"]["output"]) 42 | 43 | train_datasets = { 44 | "question": questions, 45 | "answer": answers 46 | } 47 | 48 | train_df = pd.DataFrame(train_datasets) 
49 | BASE_PATH = os.path.join(os.path.dirname(os.path.abspath((os.path.dirname(__file__)))), "data") 50 | SAVE_PATH = os.path.join(BASE_PATH, "eval_data.csv") 51 | os.makedirs(BASE_PATH, exist_ok=True) 52 | train_df.to_csv(SAVE_PATH) 53 | 54 | def load_train_eval_data(): 55 | load_train_data() 56 | load_eval_data() 57 | 58 | if __name__ == "__main__": 59 | load_train_eval_data() 60 | -------------------------------------------------------------------------------- /backend/airflow/module/train_model.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import pandas as pd 4 | import torch 5 | import transformers 6 | from datasets import Dataset, concatenate_datasets 7 | from dotenv import load_dotenv 8 | from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training 9 | from transformers import (AutoModelForCausalLM, AutoTokenizer, 10 | BitsAndBytesConfig) 11 | 12 | load_dotenv() 13 | huggingface_write_token = os.getenv("HUGGINGFACE_WRITE_TOKEN") 14 | 15 | class Autodata: 16 | def __init__(self, data_path, max_length=1024, tokenizer=None): 17 | self.max_length = max_length 18 | self.tokenizer = tokenizer 19 | self.concat_dataset = self.concat_datasets(data_path) 20 | self.tokenizer_dataset = self.tokenizing_dataset(self.concat_dataset) 21 | 22 | def concat_datasets(self, data_path): 23 | datasets = [] 24 | dataset = pd.read_csv(data_path) 25 | dataframe = dataset[["question", "answer"]] 26 | dataset = Dataset.from_pandas(dataframe) 27 | datasets.append(dataset) 28 | 29 | combined_dataset = concatenate_datasets(datasets) 30 | 31 | return combined_dataset 32 | 33 | def tokenizing_dataset(self, dataset): 34 | data = dataset.map( 35 | lambda x: { 36 | "text": f"아래는 작업을 설명하는 명령어입니다. 요청을 적절히 완료하는 응답을 작성하세요.\n\n### 명령어:\n{x['question']}\n\n### 응답:\n{x['answer']}<|endoftext|>" 37 | } 38 | ) 39 | data = data.map( 40 | lambda samples: self.tokenizer( 41 | samples["text"], 42 | truncation=True, 43 | max_length=self.max_length, 44 | padding=False, 45 | return_tensors=None, 46 | ), 47 | batched=True, 48 | ) 49 | 50 | return data.shuffle() 51 | 52 | def load_model(model_name): 53 | # bnb_config = BitsAndBytesConfig( 54 | # load_in_4bit=True, 55 | # bnb_4bit_use_double_quant=True, 56 | # bnb_4bit_quant_type="nf4", 57 | # bnb_4bit_compute_dtype=torch.bfloat16, 58 | # ) 59 | tokenizer = AutoTokenizer.from_pretrained(model_name) 60 | model = AutoModelForCausalLM.from_pretrained( 61 | model_name 62 | ) 63 | model.gradient_checkpointing_enable() 64 | model = prepare_model_for_kbit_training(model) 65 | 66 | config = LoraConfig( 67 | r=8, 68 | lora_alpha=32, 69 | target_modules=["query_key_value"], 70 | lora_dropout=0.05, 71 | bias="none", 72 | task_type="CAUSAL_LM", 73 | ) 74 | 75 | model = get_peft_model(model, config) 76 | print_trainable_parameters(model) 77 | 78 | return model, tokenizer 79 | 80 | 81 | def print_trainable_parameters(model): 82 | """ 83 | Prints the number of trainable parameters in the model. 
84 | """ 85 | trainable_params = 0 86 | all_param = 0 87 | for _, param in model.named_parameters(): 88 | all_param += param.numel() 89 | if param.requires_grad: 90 | trainable_params += param.numel() 91 | print( 92 | f"trainable params: {trainable_params} || all params: {all_param} || trainable%: {100 * trainable_params / all_param}" 93 | ) 94 | 95 | def train_model(): 96 | model_id = "nlpai-lab/kullm-polyglot-5.8b-v2" 97 | model, tokenizer = load_model(model_id) 98 | tokenizer.pad_token = tokenizer.eos_token 99 | BASE_DIR = os.path.join(os.path.dirname(os.path.abspath(os.path.dirname(__file__))), "data") 100 | TRAIN_DATA_PATH = os.path.join(BASE_DIR, "train_data.csv") 101 | EVAL_DATA_PATH = os.path.join(BASE_DIR, "eval_data.csv") 102 | train_data = Autodata(data_path=TRAIN_DATA_PATH, tokenizer=tokenizer).tokenizer_dataset 103 | val_data = Autodata(data_path=EVAL_DATA_PATH, tokenizer=tokenizer).tokenizer_dataset 104 | trainer = transformers.Trainer( 105 | model=model, 106 | train_dataset=train_data, 107 | eval_dataset=val_data, 108 | args=transformers.TrainingArguments( 109 | per_device_train_batch_size=16, 110 | gradient_accumulation_steps=1, 111 | num_train_epochs=6, 112 | learning_rate=1e-4, 113 | fp16=True, 114 | logging_steps=10, 115 | save_strategy="epoch", 116 | evaluation_strategy="epoch", 117 | output_dir="./model_outputs", 118 | optim="paged_adamw_8bit", 119 | ), 120 | data_collator=transformers.DataCollatorForLanguageModeling( 121 | tokenizer, mlm=False 122 | ), 123 | ) 124 | model.config.use_cache = ( 125 | False # silence the warnings. Please re-enable for inference! 126 | ) 127 | trainer.train() 128 | 129 | push_model_id = "YoonSeul/LawBot-airflow-test" 130 | 131 | model.push_to_hub( 132 | push_model_id, use_temp_dir=True, use_auth_token=huggingface_write_token 133 | ) 134 | print(f"{push_model_id} 모델 업로드 완료!") 135 | 136 | 137 | if __name__ == "__main__": 138 | train_model() 139 | -------------------------------------------------------------------------------- /backend/app/bert_retrieval.py: -------------------------------------------------------------------------------- 1 | import os 2 | import pickle 3 | import time 4 | 5 | import numpy as np 6 | import pandas as pd 7 | from datasets import Dataset, concatenate_datasets 8 | from sentence_transformers import SentenceTransformer 9 | from sklearn.metrics.pairwise import cosine_similarity 10 | from tqdm import tqdm 11 | 12 | 13 | class Autodata: 14 | def __init__(self, data_folder): 15 | self.data_foloder = data_folder 16 | self.concat_dataset = self.concat_datasets(self.data_foloder) 17 | 18 | def concat_datasets(self, data_foloder): 19 | datasets = [] 20 | for file_name in os.listdir(data_foloder): 21 | if file_name.endswith(".csv"): 22 | file_path = os.path.join(data_foloder, file_name) 23 | dataset = pd.read_csv(file_path) 24 | dataframe = dataset[["question", "answer"]] 25 | dataset = Dataset.from_pandas(dataframe) 26 | datasets.append(dataset) 27 | 28 | combined_dataset = concatenate_datasets(datasets) 29 | pd_combiend_dataset = pd.DataFrame(combined_dataset) 30 | return pd_combiend_dataset 31 | 32 | def build_vector_dataset(self, dataset, path): 33 | dataset = np.array(dataset) 34 | model_name = "jhgan/ko-sroberta-multitask" 35 | model = SentenceTransformer(model_name).to("cuda:0") 36 | 37 | query_vector_list = [] 38 | 39 | for i in tqdm(range(len(dataset))): 40 | question = dataset[i][0] 41 | query_vector = model.encode(question) 42 | query_vector_list.append(list(query_vector)) 43 | 44 | with open(path, "wb") as fw: 
45 | pickle.dump(query_vector_list, fw) 46 | 47 | with open(path, "rb") as fr: 48 | vector_data = pickle.load(fr) 49 | 50 | return vector_data 51 | 52 | def load_vector_data(self, path): 53 | if os.path.isfile(path): 54 | with open(path, "rb") as fr: 55 | vector_data = pickle.load(fr) 56 | else: 57 | vector_data = self.build_vector_dataset(self.concat_dataset, path) 58 | return vector_data 59 | 60 | 61 | def bert_retrieve_QA(q_sentence, model, data, vector_data): 62 | start_time = time.time() 63 | model = model.to("cuda:0") 64 | 65 | original_data = data.concat_dataset 66 | 67 | input_vector = model.encode(q_sentence) 68 | input_vector = np.expand_dims(input_vector, axis=0) 69 | 70 | cos_sim = cosine_similarity(input_vector, vector_data) 71 | data_cosine = np.sort(cos_sim).squeeze()[::-1][0] 72 | top_question_idx = np.argsort(cos_sim).squeeze()[::-1][0] 73 | 74 | similar_answer = "" 75 | 76 | if data_cosine >= 0.75: 77 | similar_question = original_data["question"][top_question_idx] 78 | similar_answer = original_data["answer"][top_question_idx] 79 | print(f"retrieve_question: {similar_question}\n") 80 | print(f"retrieve_answer: {similar_answer}\n") 81 | 82 | print(f"retrieve time: {time.time() - start_time}") 83 | 84 | return similar_answer 85 | 86 | 87 | def retrieve_debugging(q_sentence): 88 | model = SentenceTransformer("jhgan/ko-sroberta-multitask") 89 | model = model.to("cuda:0") 90 | 91 | DATA_DIR = os.path.join(os.path.dirname(os.path.abspath(os.path.dirname(__file__))), "data/bert_retrieval_data") 92 | data = Autodata(DATA_DIR) 93 | original_data = data.concat_dataset 94 | vector_data = data.load_vector_data(os.path.join(DATA_DIR, "query_vector.bin")) 95 | 96 | input_vector = model.encode(q_sentence) 97 | input_vector = np.expand_dims(input_vector, axis=0) 98 | 99 | cos_sim = cosine_similarity(input_vector, vector_data) 100 | data_cosine = np.sort(cos_sim).squeeze()[::-1][:3] # array([0.79316866, 0.7515925 , 0.72607714]) 101 | top_question = np.argsort(cos_sim).squeeze()[::-1][:3] # array([9285, 9217, 3223]) 102 | 103 | print("유사도 : ", data_cosine) 104 | 105 | question_list = [] 106 | answer_list = [] 107 | 108 | for i, index in enumerate(top_question): 109 | if data_cosine[i] >= 0.6: 110 | question_list.append(original_data["question"][index]) 111 | answer_list.append(original_data["answer"][index]) 112 | count = 0 113 | for question, answer in zip(question_list, answer_list): 114 | print(f"유사 상담 사례 질문 {count} : {question}") 115 | print(f"유사 상담 사례 답변 {count} : {answer}") 116 | print() 117 | count += 1 118 | 119 | 120 | if __name__ == "__main__": 121 | retrieve_debugging("제가 자동차를 운전하다 중앙선을 침범하다가 2충 추돌사고를 발생시켰습니다. 
이때 무슨 법으로 처벌 받을 수 있나요?") 122 | -------------------------------------------------------------------------------- /backend/app/bm25_retrieval.py: -------------------------------------------------------------------------------- 1 | import json 2 | import os 3 | import time 4 | from typing import Callable, List, Optional, Tuple, Union 5 | 6 | import numpy as np 7 | import pandas as pd 8 | from datasets import Dataset, DatasetDict, concatenate_datasets 9 | from rank_bm25 import BM25L, BM25Okapi, BM25Plus 10 | from transformers import AutoTokenizer 11 | 12 | 13 | def retrieve_QA(q_sentence): 14 | start_time = time.time() 15 | 16 | BASE_DIR = os.path.dirname(os.path.abspath(os.path.dirname(__file__))) 17 | data_path = os.path.join(BASE_DIR, "data/all_data") 18 | data = Autodata(data_path) 19 | data.load_json_data(path=os.path.join(data_path, "all_data.json")) 20 | tokenizer = AutoTokenizer.from_pretrained("nlpai-lab/kullm-polyglot-5.8b-v2") 21 | 22 | datasets = run_sparse_retrieval( 23 | tokenize_fn=tokenizer.tokenize, data_path=data_path, datasets=q_sentence, bm25="plus" 24 | ) 25 | 26 | print(f"retrieve time: {time.time() - start_time}") 27 | print(f"retrieve_question: {datasets[2][0]}") 28 | 29 | return datasets[3][0] 30 | 31 | 32 | class Autodata: 33 | def __init__(self, data_folder="./data"): 34 | self.data_foloder = data_folder 35 | self.concat_dataset = self.concat_datasets(self.data_foloder) 36 | 37 | def concat_datasets(self, data_foloder): 38 | datasets = [] 39 | pd_datasets = [] 40 | for file_name in os.listdir(data_foloder): 41 | if file_name.endswith(".csv"): 42 | file_path = os.path.join(data_foloder, file_name) 43 | dataset = pd.read_csv(file_path) 44 | dataframe = dataset[["question", "answer"]] 45 | pd_datasets.append(dataframe) 46 | dataset = Dataset.from_pandas(dataframe) 47 | datasets.append(dataset) 48 | 49 | combined_dataset = concatenate_datasets(datasets) 50 | pd_combiend_dataset = pd.DataFrame(combined_dataset) 51 | 52 | return pd_combiend_dataset 53 | 54 | 55 | def make_all_data(self, data, path): 56 | df = data 57 | data_dict = {} 58 | 59 | for i in range(len(df)): 60 | key = str(i) 61 | data_dict[key] = { 62 | "question": df.iloc[i]["question"], 63 | "answer": df.iloc[i]["answer"], 64 | } 65 | 66 | with open(path, "w", encoding="utf-8") as file: 67 | json.dump(data_dict, file, ensure_ascii=False, indent=4) 68 | 69 | def load_json_data(self, path="./all_data/all_data.json"): 70 | if not os.path.isfile(path): 71 | self.make_all_data(self.concat_dataset, path) 72 | 73 | 74 | def setup_bm25(parent_class): 75 | class CustomBM25(parent_class): 76 | def __init__(self, corpus, tokenizer): 77 | super().__init__(corpus, tokenizer) 78 | 79 | def get_relevant_doc(self, query, k): 80 | query_vec = self.tokenizer(query) 81 | result = self.get_scores(query_vec) 82 | sorted_result = np.argsort(result.squeeze())[::-1] 83 | doc_score = result.squeeze()[sorted_result].tolist()[:k] 84 | doc_indices = sorted_result.tolist()[:k] 85 | return doc_score, doc_indices 86 | 87 | def get_relevant_doc_bulk(self, queries, k): 88 | doc_scores = [] 89 | doc_indices = [] 90 | for query in queries: 91 | doc_score, doc_indice = self.get_relevant_doc(query, k) 92 | doc_scores.append(doc_score) 93 | doc_indices.append(doc_indice) 94 | return doc_scores, doc_indices 95 | 96 | return CustomBM25 97 | 98 | 99 | class SparseRetrievalBM25: 100 | def __init__( 101 | self, 102 | tokenize_fn, 103 | data_path: Optional[str] = "./csv_data/", 104 | context_path: Optional[str] = "all_data.json", 105 | bm25_type: 
Optional[str] = "", 106 | ) -> None: 107 | self.data_path = data_path 108 | with open(os.path.join(data_path, context_path), "r", encoding="utf-8") as f: 109 | wiki = json.load(f) 110 | 111 | self.contexts = list(([v["question"] for v in wiki.values()])) 112 | self.contexts_answer = list(([v["answer"] for v in wiki.values()])) 113 | 114 | if bm25_type == "Okapi": 115 | bm25_class = setup_bm25(BM25Okapi) 116 | self.bm25 = bm25_class(self.contexts, tokenize_fn) 117 | elif bm25_type == "L": 118 | bm25_class = setup_bm25(BM25L) 119 | self.bm25 = bm25_class(self.contexts, tokenize_fn) 120 | elif bm25_type == "plus": 121 | bm25_class = setup_bm25(BM25Plus) 122 | self.bm25 = bm25_class(self.contexts, tokenize_fn) 123 | 124 | def retrieve( 125 | self, query_or_dataset: Union[str, pd.DataFrame], topk: Optional[int] = 1 126 | ) -> Union[Tuple[List, List], pd.DataFrame]: 127 | if isinstance(query_or_dataset, str): 128 | doc_scores, doc_indices = self.bm25.get_relevant_doc( 129 | query_or_dataset, k=topk 130 | ) 131 | return ( 132 | doc_scores, 133 | doc_indices, 134 | [self.contexts[doc_indices[i]] for i in range(topk)], 135 | [self.contexts_answer[doc_indices[i]] for i in range(topk)], 136 | ) 137 | 138 | 139 | def run_sparse_retrieval( 140 | tokenize_fn: Callable[[str], List[str]], 141 | datasets: pd.DataFrame, 142 | data_path: str = os.path.join( 143 | os.path.abspath(os.path.dirname(__file__)), "csv_data" 144 | ), 145 | context_path: str = "all_data.json", 146 | bm25: str = None, 147 | ) -> DatasetDict: 148 | assert bm25 in ["Okapi", "L", "plus"], "Invalid type for BM25 has been passed." 149 | 150 | retriever = SparseRetrievalBM25( 151 | tokenize_fn=tokenize_fn, 152 | data_path=data_path, 153 | context_path=context_path, 154 | bm25_type=bm25, 155 | ) 156 | 157 | df = retriever.retrieve(datasets, topk=3) 158 | return df 159 | -------------------------------------------------------------------------------- /backend/app/filter.py: -------------------------------------------------------------------------------- 1 | import time 2 | 3 | import torch.nn.functional as F 4 | from transformers import AutoModelForSequenceClassification, AutoTokenizer 5 | 6 | 7 | def is_legal_question(q_sentence): 8 | start_time = time.time() 9 | base_model_name = "monologg/koelectra-small-v3-discriminator" 10 | model = AutoModelForSequenceClassification.from_pretrained( 11 | "kfkas/legal-question-filter-koelectra", 12 | num_labels=2, 13 | ignore_mismatched_sizes=True, 14 | ) 15 | tokenizer = AutoTokenizer.from_pretrained(base_model_name) 16 | inputs = tokenizer(q_sentence, padding=True, truncation=True, return_tensors="pt") 17 | outputs = model(**inputs) 18 | logits = outputs.logits.detach().cpu() 19 | pr = F.softmax(logits, dim=1).numpy() 20 | # arg = np.argmax(pr, axis=1) 21 | # print(logits) 22 | # print(pr) 23 | # print(int(arg)) 24 | 25 | print(f"filter time: {time.time() - start_time}") 26 | 27 | if pr[0][0] >= 0.95: 28 | return True 29 | return False 30 | -------------------------------------------------------------------------------- /backend/app/generate.py: -------------------------------------------------------------------------------- 1 | import time 2 | 3 | import torch 4 | 5 | 6 | def generate_answer(q_sentence: str, model, tokenizer): 7 | model.eval() 8 | model.config.use_cache = True # silence the warnings. Please re-enable for inference! 
9 | # model.float() 10 | tokenizer.pad_token = tokenizer.eos_token 11 | 12 | device = torch.device('cuda:0') if torch.cuda.is_available() else torch.device('cpu') 13 | # prompt = f"아래는 작업을 설명하는 명령어입니다. 요청을 적절히 완료하는 응답을 작성하세요.\n\n### 명령어:\n{q_sentence}\n\n### 응답:\n" 14 | prompt = f"다음은 한국 법률 QA입니다. 질문에 맞는 적절한 응답을 작성하세요.\n\n### 질문:\n{q_sentence}\n\n### 응답:\n" 15 | # prompt = f"아래는 작업을 설명하는 명령어입니다. 요청을 적절히 완료하는 응답을 작성하세요.\n\n### 명령어:\n{q_sentence}\n\n### 응답:\n" 16 | # prompt = f"다음은 폭행 관련 법률 QA입니다. 질문과 관련된 자료를 가져온 다음 해당 내용을 요약하여 응답을 작성하세요.\n\n### 질문:\n{q_sentence}\n\n### 응답:\n" 17 | # prompt = f"다음은 폭행 관련 법률 QA입니다. 상황을 읽고 질문에 맞는 적절한 응답을 작성하세요.\n\n### 상황:\n혁준이가 술을 마시고 아내를 폭행했어. 근데 아내는 외도를 했던 상황이야\n\n### 질문:\n감형이 가능할까?\n\n### 응답:\n" 18 | len_prompt = len(prompt) 19 | 20 | start_time = time.time() 21 | 22 | gened = model.generate( 23 | **tokenizer( 24 | prompt, 25 | return_tensors='pt', 26 | return_token_type_ids=False 27 | ).to(device), 28 | max_new_tokens=1024, 29 | early_stopping=True, 30 | do_sample=True, 31 | eos_token_id=2, 32 | # temperature=1e-5, 33 | top_k=10, 34 | top_p=0.95, 35 | no_repeat_ngram_size=2, 36 | num_beams=3, 37 | # force_words_ids = tokenizer("폭행", add_special_tokens=False).input_ids, 38 | ) 39 | 40 | print(f"generate time: {time.time() - start_time}") 41 | answer = tokenizer.decode(gened[0])[len_prompt:].replace("응답:", "").replace("\n\n", "\n").replace("", "") 42 | print(f"LLM answer: {answer}\n") 43 | return answer 44 | -------------------------------------------------------------------------------- /backend/app/main.py: -------------------------------------------------------------------------------- 1 | import os 2 | from datetime import datetime 3 | from typing import List, Union 4 | 5 | import numpy as np 6 | import pandas as pd 7 | import pytz 8 | import torch 9 | from fastapi import FastAPI 10 | from peft import PeftConfig, PeftModel 11 | from pydantic import BaseModel 12 | from sentence_transformers import SentenceTransformer 13 | from transformers import AutoModelForCausalLM, AutoTokenizer 14 | 15 | from bert_retrieval import Autodata, bert_retrieve_QA 16 | from bm25_retrieval import retrieve_QA 17 | from filter import is_legal_question 18 | from generate import generate_answer 19 | from search import Precedent, load_vector_data, search_precedent 20 | 21 | 22 | class Question(BaseModel): 23 | q_sentence: str 24 | 25 | class Answer(BaseModel): 26 | answer_sentence: Union[str, None] 27 | similar_precedent: Union[List[Precedent], None] 28 | 29 | os.environ["TOKENIZERS_PARALLELISM"] = "false" 30 | app = FastAPI() 31 | 32 | llm = None 33 | tokenizer = None 34 | search_model = None 35 | retrieve_model = None 36 | retrieve_data = None 37 | retrieve_vector_data = None 38 | text_data = None 39 | vector_data = None 40 | 41 | @app.on_event("startup") 42 | def startup_event(): 43 | global tokenizer, llm, search_model, retrieve_model, retrieve_data, retrieve_vector_data, text_data, vector_data 44 | 45 | print("Load LLM") 46 | model_id = "kfkas/Legal-Llama-2-ko-7b-Chat" 47 | # config = PeftConfig.from_pretrained(peft_model_id) 48 | llm = AutoModelForCausalLM.from_pretrained( 49 | model_id, device_map={"": 0}, torch_dtype=torch.float16 50 | ) 51 | # llm = PeftModel.from_pretrained(llm, peft_model_id, torch_dtype=torch.float16) 52 | tokenizer = AutoTokenizer.from_pretrained(model_id) 53 | 54 | print("Load search model") 55 | search_model = SentenceTransformer("jhgan/ko-sroberta-multitask") 56 | 57 | print("Load retrieve model and data") 58 | retrieve_model = 
SentenceTransformer("jhgan/ko-sroberta-multitask") 59 | DATA_DIR = os.path.join(os.path.dirname(os.path.abspath(os.path.dirname(__file__))), "data/bert_retrieval_data") 60 | retrieve_data = Autodata(DATA_DIR) 61 | retrieve_vector_data = retrieve_data.load_vector_data(os.path.join(DATA_DIR, "query_vector.bin")) 62 | 63 | print("Load data") 64 | base_path = os.path.abspath(os.path.dirname(__file__)) 65 | 66 | text_data = np.array(pd.read_csv(base_path + "/../data/law_data/law_data_drop.csv")) 67 | vector_data = load_vector_data( 68 | base_path + "/../data/law_data/law_data_drop_vector.bin" 69 | ) 70 | 71 | 72 | @app.post("/generate", response_model=Answer) 73 | async def generate(question: Question): 74 | KST = pytz.timezone('Asia/Seoul') 75 | print(datetime.now(KST).strftime("%Y/%m/%d %H:%M:%S")) 76 | 77 | q_sentence = question.q_sentence 78 | print(f"q_sentence: {q_sentence}") 79 | 80 | if not is_legal_question(q_sentence=q_sentence): 81 | return Answer(answer_sentence=None, similar_precedent=None) 82 | 83 | # retrieve_answer = retrieve_QA(q_sentence=q_sentence) 84 | retrieve_answer = bert_retrieve_QA(q_sentence=q_sentence, model=retrieve_model, data=retrieve_data, vector_data=retrieve_vector_data) 85 | 86 | answer_sentence = generate_answer(q_sentence=q_sentence, model=llm, tokenizer=tokenizer) 87 | 88 | similar_precedent = search_precedent(q_a_sentence=q_sentence+retrieve_answer+answer_sentence, model=search_model, text_data=text_data, vector_data=vector_data) 89 | 90 | return Answer(answer_sentence=answer_sentence, similar_precedent=similar_precedent) 91 | -------------------------------------------------------------------------------- /backend/app/search.py: -------------------------------------------------------------------------------- 1 | import os 2 | import pickle 3 | import time 4 | 5 | import numpy as np 6 | import torch 7 | from pydantic import BaseModel 8 | from sentence_transformers import SentenceTransformer 9 | from sklearn.metrics.pairwise import cosine_similarity 10 | 11 | 12 | class Precedent(BaseModel): 13 | case_name: str 14 | case_number: str 15 | case_type: str 16 | ref_article: str 17 | url: str 18 | 19 | 20 | def search_precedent(q_a_sentence: str, model, text_data, vector_data): 21 | start_time = time.time() 22 | # model = SentenceTransformer("jhgan/ko-sroberta-multitask") #TODO 23 | model.to("cuda:0") 24 | 25 | input_vector = model.encode(q_a_sentence) 26 | input_vecotr = np.expand_dims(input_vector, axis=0) 27 | 28 | cos_sim = cosine_similarity(input_vecotr, vector_data) 29 | data_cosine = np.sort(cos_sim).squeeze()[::-1][:3] 30 | top_question = np.argsort(cos_sim).squeeze()[::-1][:3] 31 | 32 | precedent_list = [] 33 | 34 | for i, index in enumerate(top_question): 35 | if data_cosine[i] >= 0.5: 36 | ref_article = text_data[index][7] 37 | ref_article_split = str(ref_article).split() 38 | if len(ref_article_split) >= 2: 39 | url = f"https://law.go.kr/법령/{ref_article_split[0]}/{ref_article_split[1]}" 40 | else: 41 | url = "" 42 | precedent_list.append( 43 | Precedent(case_name=text_data[index][3], case_number=text_data[index][0], case_type=text_data[index][6], ref_article=ref_article, url=url) 44 | ) 45 | 46 | # del model 47 | # torch.cuda.empty_cache() 48 | 49 | print(f"search time: {time.time() - start_time}") 50 | 51 | return precedent_list 52 | 53 | def load_vector_data(path): 54 | if os.path.isfile(path): 55 | with open(path, "rb") as fr: 56 | vector_data = pickle.load(fr) 57 | else: 58 | print("판례 데이터가 존재하지 않습니다.") 59 | vector_data = None 60 | return 
vector_data 61 | -------------------------------------------------------------------------------- /backend/requirements.txt: -------------------------------------------------------------------------------- 1 | fastapi==0.99.1 2 | uvicorn==0.22.0 3 | transformers==4.30.2 4 | torch==2.0.1 5 | accelerate==0.20.3 6 | pytest==7.4.0 7 | pandas==2.0.3 8 | scikit-learn==1.3.0 9 | sentence-transformers==2.2.2 10 | git+https://github.com/huggingface/peft.git 11 | rank-bm25==0.2.2 12 | datasets==2.13.1 13 | apache-airflow==2.2.3 14 | MarkupSafe==2.0.1 15 | python-dotenv==1.0.0 16 | bitsandbytes==0.41.0 17 | dnspython==2.3.0 -------------------------------------------------------------------------------- /backend/router/router.py: -------------------------------------------------------------------------------- 1 | import json 2 | from datetime import datetime 3 | from typing import List, Union 4 | 5 | import pytz 6 | import requests 7 | from fastapi import FastAPI 8 | from pydantic import BaseModel 9 | 10 | app = FastAPI() 11 | 12 | class Question(BaseModel): 13 | q_sentence: str 14 | 15 | class Answer(BaseModel): 16 | answer_sentence: Union[str, None] 17 | similar_precedent: Union[List, None] 18 | 19 | 20 | @app.get("/") 21 | def root(): 22 | print("Hello World!") 23 | 24 | @app.post("/generate", response_model=Union[Answer, None]) 25 | async def generate(question: Question): 26 | KST = pytz.timezone('Asia/Seoul') 27 | print(datetime.now(KST).strftime("%Y/%m/%d %H:%M:%S")) 28 | q_sentence = question.q_sentence.strip() 29 | if q_sentence == "": 30 | print({"q_sentence": q_sentence}) 31 | return None 32 | headers = {"Content-Type": "application/json", "accept": "application/json"} 33 | url = "http://127.0.0.1:8000/generate" 34 | data = {"q_sentence": q_sentence} 35 | 36 | print(data) 37 | res = requests.post(url, headers=headers, data=json.dumps(data)) 38 | 39 | return res.json() 40 | -------------------------------------------------------------------------------- /backend/tests/test_filter.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | 4 | sys.path.append(os.path.dirname(os.path.abspath(os.path.dirname(__file__)))) 5 | 6 | from app.filter import is_legal_question 7 | 8 | 9 | def test_is_legal_question(): 10 | assert not is_legal_question("안녕하세요.") 11 | assert is_legal_question("제가 술을 먹고 운전을 했는데 어떤 처벌을 받을까요?") 12 | -------------------------------------------------------------------------------- /backend/tests/test_generate.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | 4 | import torch 5 | 6 | sys.path.append(os.path.dirname(os.path.abspath(os.path.dirname(__file__)))) 7 | 8 | from app.generate import generate_answer 9 | from peft import PeftConfig, PeftModel 10 | from transformers import AutoModelForCausalLM, AutoTokenizer 11 | 12 | 13 | def test_generate_answer(): 14 | q_sentence = "제가 술을 먹고 운전을 했는데 어떤 처벌을 받을까요?" 
15 | peft_model_id = "YoonSeul/LawBot-level-3-KuLLM-5.8B-tae-2epoch" 16 | config = PeftConfig.from_pretrained(peft_model_id) 17 | model = AutoModelForCausalLM.from_pretrained( 18 | config.base_model_name_or_path, device_map={"": 0}, torch_dtype=torch.float16 19 | ) 20 | model = PeftModel.from_pretrained(model, peft_model_id, torch_dtype=torch.float16) 21 | tokenizer = AutoTokenizer.from_pretrained(config.base_model_name_or_path) 22 | generate_answer(q_sentence=q_sentence, model=model, tokenizer=tokenizer) 23 | -------------------------------------------------------------------------------- /backend/tests/test_retrieval.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | 4 | sys.path.append(os.path.dirname(os.path.abspath(os.path.dirname(__file__)))) 5 | 6 | from app.bm25_retrieval import retrieve_QA 7 | 8 | 9 | def test_retrieve_QA(): 10 | retrieve_QA("제가 술을 먹고 운전을 했는데 어떤 처벌을 받을까요?") 11 | -------------------------------------------------------------------------------- /backend/tests/test_search.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | 4 | sys.path.append(os.path.dirname(os.path.abspath(os.path.dirname(__file__)))) 5 | 6 | import numpy as np 7 | import pandas as pd 8 | from app.search import load_vector_data, search_precedent 9 | from sentence_transformers import SentenceTransformer 10 | 11 | 12 | def test_search_precedent(): 13 | q_a_sentence = "교통법규 위반으로 벌점을 받았습니다. 이 벌점은 소멸되지 않고 계속 누적되나요?"+"벌점이 소멸되는 것이 아니라 누적되어 관리됩니다. 즉, 벌점이 누적되면 운전면허가 취소되거나 정지될 수 있습니다. 벌점 누적에 따른 운전면허의 취소 또는 정지 기준은 다음과 같습니다(도로교통법 시행규칙 별표 28). 운전면허 취소기준 1. 혈중알콜농도가 0.1% 이상인 사람이 자동차 등을 운전한 경우 2. 음주측정기에 의한 측정결과에 불복하는 사람이 술에 취한 상태에 있다고 인정할 만한 상당한 이유가 있음에도 불구하고 경찰공무원의 측정 요구에 불응하거나 경찰공무원을 폭행 또는 협박한 경우(단, 운전자가 경찰공무원에게 폭행을 가한 경우에는 그 정도가 심하지 않을 때에 한함) 3. 적성검사를 받지 않거나 적성검사에 불합격된 사람이 다시 운전면허를 받고자 하는 경우 4. 자동차를 이용하여 범죄행위를 한 경우 5. 다른 사람의 자동차를 훔치거나 빼앗은 경우 6. 교통사고를 야기하고 도주한 경우 7. 단속경찰공무원 등을 폭행한 경우 8. 정차ㆍ주차위반에 대한 조치" 14 | 15 | model = SentenceTransformer("jhgan/ko-sroberta-multitask") 16 | 17 | print("Load data") 18 | base_path = os.path.abspath(os.path.dirname(__file__)) 19 | 20 | text_data = np.array(pd.read_csv(base_path + "/../data/law_data/law_data.csv")) 21 | vector_data = load_vector_data( 22 | base_path + "/../data/law_data/law_data_drop_vector.bin" 23 | ) 24 | 25 | search_precedent(q_a_sentence=q_a_sentence, model=model, text_data=text_data, vector_data=vector_data) 26 | -------------------------------------------------------------------------------- /data_pipeline/.gitignore: -------------------------------------------------------------------------------- 1 | local/* 2 | data/* 3 | artifact/* 4 | __pycache__ 5 | *.log -------------------------------------------------------------------------------- /data_pipeline/README.md: -------------------------------------------------------------------------------- 1 | # LawBot - Data Pipeline 2 | * All commands in this instruction should be run in the following directory.
`/data_pipeline`
3 | ## ⚠️ How To Install Requirements
4 | * Run the following command on your terminal.
5 |
6 | ```bash
7 | $ pip install -r requirements.txt
8 | ```
9 |
10 | ## ⌨️ How To Execute
11 | ### Web Crawling
12 | ```bash
13 | $ python3 crawler.py
14 | $ python3 qa_crawler.py
15 | ```
16 | ### Data Generation
17 | ```bash
18 | $ python3 generate_gpt.py
19 | $ python3 parse.py
20 | ```
21 |
22 | * To generate data using the GPT model, you need to [obtain an API key](https://platform.openai.com/account/api-keys) from OpenAI first.
23 | Depending on the model used, usage fees might be charged. 24 | 25 | * If you want to modify the prompts, follow these steps. 26 | 1. Add the prompts in the `backup_prompts.py` file. 27 | 2. Run the following command in your terminal. 28 | 29 | ```bash 30 | $ python3 backup_prompts.py 31 | ``` 32 | 3. New pickle file will be overlapped to existing `prompts.pkl` 33 | 4. After that, you can proceed with the stpes mentioned earlier. 34 | 35 | ### Preprocessing 36 | * Run the following commands in your terminal. 37 | ```bash 38 | $ python3 spellchecker.py 39 | $ python3 preprocessor_v2.py 40 | ``` 41 | -------------------------------------------------------------------------------- /data_pipeline/crawler.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import os 3 | import pickle 4 | import re 5 | import time 6 | from contextlib import contextmanager 7 | from urllib.request import urlopen 8 | 9 | import pandas as pd 10 | from bs4 import BeautifulSoup 11 | from selenium import webdriver 12 | from selenium.webdriver.common.by import By 13 | from selenium.webdriver.remote.remote_connection import \ 14 | LOGGER as selenium_logger 15 | from selenium.webdriver.support import expected_conditions as EC 16 | from selenium.webdriver.support.ui import WebDriverWait 17 | from tqdm.auto import tqdm 18 | from utils import utilize_loggers 19 | from webdriver_manager.chrome import ChromeDriverManager 20 | 21 | selenium_logger.setLevel(logging.WARNING) 22 | os.environ["WDM_LOG"] = "0" 23 | 24 | 25 | @contextmanager 26 | def timer(): 27 | t0 = time.time() 28 | yield lambda: time.time() - t0 29 | 30 | 31 | def measure_elapsed_time(timer_name): 32 | def decorator(func): 33 | def wrapper(*args, **kwargs): 34 | with timer() as elapsed_time: 35 | result = func(*args, **kwargs) 36 | logger.info(f"{elapsed_time():>8.3f} seconds elapsed @ {timer_name}") 37 | return result 38 | 39 | return wrapper 40 | 41 | return decorator 42 | 43 | 44 | class QADataCrawler: 45 | def __init__( 46 | self, 47 | board_url="https://www.klac.or.kr/legalstruct/cyberConsultation/selectOpenArticleList.do?boardCode=3#none", 48 | base_url="https://www.klac.or.kr/legalstruct/cyberConsultation/selectOpenArticleDetail.do?boardCode=3&contentId=", 49 | ): 50 | self.driver = None 51 | self.board_url = board_url 52 | self.base_url = base_url 53 | 54 | def start_driver(self): 55 | self.chrome_service = webdriver.chrome.service.Service( 56 | ChromeDriverManager().install() 57 | ) 58 | self.chrome_options = webdriver.ChromeOptions() 59 | self.chrome_options.add_argument("headless") 60 | 61 | self.driver = webdriver.Chrome( 62 | service=self.chrome_service, 63 | options=self.chrome_options, 64 | ) 65 | 66 | self.chrome_service.start() 67 | 68 | def quit_driver(self): 69 | self.driver.quit() 70 | 71 | @measure_elapsed_time("Total Crawling Process") 72 | def get_data(self): 73 | case_ids = self._get_all_case_ids() 74 | case_info = self._get_all_case_contents(case_ids) 75 | 76 | self._save_dataframe(case_info) 77 | 78 | def _get_case_id(self): 79 | case_ids = [] 80 | element_xpath = "//a[contains(@onclick, 'fn_inquire_detail')]" 81 | WebDriverWait(self.driver, 10).until( 82 | EC.presence_of_element_located((By.XPATH, element_xpath)) 83 | ) 84 | elements = self.driver.find_elements(By.XPATH, element_xpath) 85 | for element in elements: 86 | match = re.search( 87 | r"fn_inquire_detail\('(\d+)', '(.*?)'\);return false;", 88 | element.get_attribute("onclick"), 89 | ) 90 | 
case_ids.append(match.group(2)) 91 | return case_ids 92 | 93 | @measure_elapsed_time("Get all case ids") 94 | def _get_all_case_ids(self, save_id_list=False): 95 | is_crawling_finished = False 96 | page_move_cnt = 0 97 | page_number = 1 98 | page_idx = 2 99 | case_ids = [] 100 | 101 | self.driver.get(self.board_url) 102 | case_ids.extend(self._get_case_id()) 103 | 104 | while not is_crawling_finished: 105 | try: 106 | next_page = self.driver.find_element( 107 | By.XPATH, 108 | f'//*[@id="content"]/form[1]/div[2]/div/div[4]/a[{page_idx}]', 109 | ) 110 | page_idx += 1 111 | except: 112 | try: 113 | next_page = self.driver.find_element( 114 | By.XPATH, 115 | f'//button[contains(@onclick, "fn_select_linkPage({page_move_cnt * 10 + page_idx}); return false;")]', 116 | ) 117 | page_idx = 2 118 | page_move_cnt += 1 119 | except: 120 | is_crawling_finished = True 121 | break 122 | 123 | next_page.click() 124 | case_ids.extend(self._get_case_id()) 125 | page_number += 1 126 | 127 | if save_id_list: 128 | self._save_case_id_list(case_ids, "case_id_list.pkl") 129 | 130 | return case_ids 131 | 132 | def _get_case_content_by_id(self, case_id): 133 | url = self.base_url + case_id 134 | html = urlopen(url) 135 | bsObject = BeautifulSoup(html, "html.parser") 136 | 137 | case_title = bsObject.find("div", {"class": "view_head"}).text 138 | date_created = bsObject.find("dt", text="신청일").find_next_sibling("dd").text 139 | date_answered = bsObject.find("dt", text="답변일자").find_next_sibling("dd").text 140 | content, answer = bsObject.find_all("div", {"class": "notice_contents"}) 141 | case_info = [case_title, date_created, date_answered, content.text, answer.text] 142 | 143 | return case_info 144 | 145 | @measure_elapsed_time("Get all case contents") 146 | def _get_all_case_contents(self, case_ids): 147 | case_info = [] 148 | 149 | for case_id in tqdm(case_ids): 150 | case_info.append(self._get_case_content_by_id(case_id)) 151 | 152 | return case_info 153 | 154 | def _save_dataframe(self, case_info, drop_unused_columns=False): 155 | df = pd.DataFrame( 156 | case_info, 157 | columns=[ 158 | "case_title", 159 | "date_created", 160 | "date_answered", 161 | "content", 162 | "answer", 163 | ], 164 | ) 165 | 166 | if drop_unused_columns: 167 | df = df.drop(["case_title", "date_created", "date_answered"], axis=1) 168 | 169 | os.makedirs("data", exist_ok=True) 170 | df.to_csv("./data/raw_qa_dataset.csv", index=False) 171 | logger.info(f"\t\tGathered Data Count: {len(df)}") 172 | 173 | def _save_case_id_list(self, case_id_list, file_name="case_id_list.pkl"): 174 | with open(file_name, "wb") as f: 175 | pickle.dump(case_id_list, f) 176 | 177 | def _load_case_id_list(self, file="case_id_list.pkl"): 178 | with open(file, "rb") as f: 179 | loaded_list = pickle.load(f) 180 | return loaded_list 181 | 182 | 183 | if __name__ == "__main__": 184 | logger = utilize_loggers(__file__) 185 | 186 | crawler = QADataCrawler() 187 | crawler.start_driver() 188 | crawler.get_data() 189 | crawler.quit_driver() 190 | -------------------------------------------------------------------------------- /data_pipeline/generate/backup_prompts.py: -------------------------------------------------------------------------------- 1 | import pickle 2 | 3 | prompts = { 4 | "zeroshot": "임의의 법률 분쟁 상황을 가정하고, 그에 대한 내용을 질문의 형식으로 만들어주세요. 그리고 해당 질문에 대한 답변을 함께 출력해주세요.", 5 | "zeroshot2": "법률 분쟁 상황을 가정하고, 해당 상황의 가해자 또는 피해자가 의뢰할 만한 상담 내용을 작성해주세요. 출력에는 해당 상담 내용에 대한 답변을 포함해주세요.", 6 | "oneshot": """ 7 | 아래 예시의 형식을 참고하여, 임의의 법률 분쟁 상황에 대한 질의 응답 데이터를 생성해주세요. 
8 | 
9 | [예시]
10 | 질문: 남편이 가출하여 연락이 되지 않다가 3년 6개월 뒤 ‘실종자 찾아주기’ 운동의 일환으로 DNA검사를 했더니 남편은 3년 전에 이미 교통사고로 사망하였고, 신원미상자로 처리되었다고 합니다. 가출 전 남편이 들어놓은 사망보험금을 청구했더니 보험회사는 사망 후 3년이 경과하였기 때문에 소멸시효 완성을 주장하고 있습니다. 보험금은 못 받는건가요?
11 | 
12 | 답변: 보험금청구권의 소멸시효는 특별한 다른 사정이 없는 한 보험사고가 발생한 때부터 진행하는 것이 원칙입니다. 그러나 객관적으로 보험사고가 발생한 사실을 확인할 수 없는 사정이 있는 경우에는 보험금청구권자가 보험사고의 발생을 알았거나 알 수 있었던 때부터 보험금청구권의 소멸시효가 진행합니다. 따라서 사례의 경우에는 보험금 청구가 가능합니다.
13 | """,
14 |     "fewshot": """
15 | 아래 예시의 형식을 참고하여, 임의의 법률 분쟁 상황에 대한 질의 응답 데이터를 생성해주세요.
16 | 
17 | [예시 1]
18 | 질문: 남편이 가출하여 연락이 되지 않다가 3년 6개월 뒤 ‘실종자 찾아주기’ 운동의 일환으로 DNA검사를 했더니 남편은 3년 전에 이미 교통사고로 사망하였고, 신원미상자로 처리되었다고 합니다. 가출 전 남편이 들어놓은 사망보험금을 청구했더니 보험회사는 사망 후 3년이 경과하였기 때문에 소멸시효 완성을 주장하고 있습니다. 보험금은 못 받는건가요?
19 | 
20 | 답변: 보험금청구권의 소멸시효는 특별한 다른 사정이 없는 한 보험사고가 발생한 때부터 진행하는 것이 원칙입니다. 그러나 객관적으로 보험사고가 발생한 사실을 확인할 수 없는 사정이 있는 경우에는 보험금청구권자가 보험사고의 발생을 알았거나 알 수 있었던 때부터 보험금청구권의 소멸시효가 진행합니다. 따라서 사례의 경우에는 보험금 청구가 가능합니다.
21 | 
22 | [예시 2]
23 | 질문: 아는 사람에게 500만원을 빌려줬는데 갚지 않습니다. 소송을 해야 할 것 같은데 비용이며, 시간이 꽤 들 것 같네요. 방법이 없을까요?
24 | 
25 | 답변: 소송의 당사자가 소송으로 청구하는 금액이나 물건의 가치가 3천만원을 넘지 않는 사건은 시간이나 비용에 있어서 민사소송보다 간편한 절차로 진행할 수 있는 소액사건재판 제도를 이용할 수 있습니다. 소액사건재판 외에도 민사조정이나 지급명령(독촉절차)을 이용할 수도 있습니다.
26 | """,
27 |     "fewshot_with_constraints": """
28 | 아래 예시의 형식을 참고하여, 임의의 법률 분쟁 상황에 대한 새로운 질의 응답 데이터를 생성해주세요. 답변을 생성할 때는 다음 조건에 맞게 생성해주세요.
29 | 
30 | [조건]
31 | - 제시된 예시와는 다른 사례를 바탕으로 질문을 생성해줘
32 | - 대한민국의 법률에 근거하여 법적 분쟁 상황에 대한 답변을 생성해줘
33 | - 대한민국 법률에 실제로 존재하는 조항과, 입력과 유사한 상황에 대한 판례를 답변 내용에 포함해줘
34 | 
35 | [예시 1]
36 | 질문: 남편이 가출하여 연락이 되지 않다가 3년 6개월 뒤 ‘실종자 찾아주기’ 운동의 일환으로 DNA검사를 했더니 남편은 3년 전에 이미 교통사고로 사망하였고, 신원미상자로 처리되었다고 합니다. 가출 전 남편이 들어놓은 사망보험금을 청구했더니 보험회사는 사망 후 3년이 경과하였기 때문에 소멸시효 완성을 주장하고 있습니다. 보험금은 못 받는건가요?
37 | 
38 | 답변: 보험금청구권의 소멸시효는 특별한 다른 사정이 없는 한 보험사고가 발생한 때부터 진행하는 것이 원칙입니다. 그러나 객관적으로 보험사고가 발생한 사실을 확인할 수 없는 사정이 있는 경우에는 보험금청구권자가 보험사고의 발생을 알았거나 알 수 있었던 때부터 보험금청구권의 소멸시효가 진행합니다. 따라서 사례의 경우에는 보험금 청구가 가능합니다.
39 | 
40 | [예시 2]
41 | 질문: 아는 사람에게 500만원을 빌려줬는데 갚지 않습니다. 소송을 해야 할 것 같은데 비용이며, 시간이 꽤 들 것 같네요. 방법이 없을까요?
42 | 
43 | 답변: 소송의 당사자가 소송으로 청구하는 금액이나 물건의 가치가 3천만원을 넘지 않는 사건은 시간이나 비용에 있어서 민사소송보다 간편한 절차로 진행할 수 있는 소액사건재판 제도를 이용할 수 있습니다. 소액사건재판 외에도 민사조정이나 지급명령(독촉절차)을 이용할 수도 있습니다.
44 | """, 45 | } 46 | 47 | 48 | if __name__ == "__main__": 49 | with open("prompts.pkl", "wb") as f: 50 | pickle.dump(prompts, f) -------------------------------------------------------------------------------- /data_pipeline/generate/generate_bard.py: -------------------------------------------------------------------------------- 1 | import os 2 | import pickle 3 | import time 4 | 5 | import pandas as pd 6 | from bardapi import Bard 7 | from tqdm.auto import tqdm 8 | 9 | bard = Bard(token_from_browser=True) 10 | 11 | def get_response(prompt): 12 | response = bard.get_answer(prompt) 13 | return response 14 | 15 | with open("prompts.pkl", "rb") as f: 16 | prompts = pickle.load(f) 17 | 18 | data = [] 19 | num_data = 1 20 | prompt_name = "fewshot" 21 | prompt = prompts["fewshot"] 22 | 23 | for i in tqdm(range(num_data)): 24 | try: 25 | response = get_response(prompt) 26 | except: 27 | time.sleep(5) 28 | continue 29 | data.append( 30 | [ 31 | prompt_name, 32 | prompt, 33 | *[response["choices"][i]["content"][0] for i in range(len(response["choices"]))] 34 | ] 35 | ) 36 | 37 | generated_df = pd.DataFrame( 38 | data, 39 | columns=[ 40 | "prompt_type", 41 | "prompt", 42 | "result_1", 43 | "result_2", 44 | "result_3" 45 | ]) 46 | 47 | os.makedirs("../data/generated_data/bard", exist_ok=True) 48 | generated_df.to_csv(f"../data/generated_data/bard/generated_data_bard_{len(generated_df)}.csv", index=False) -------------------------------------------------------------------------------- /data_pipeline/generate/generate_gpt.py: -------------------------------------------------------------------------------- 1 | import os 2 | import pickle 3 | import time 4 | 5 | import openai 6 | import pandas as pd 7 | from tqdm.auto import tqdm 8 | 9 | openai.api_key = os.environ["OPENAI_API_KEY"] 10 | 11 | def get_response(prompt, model="gpt-3.5-turbo", temperature=1.0, max_tokens=1000): 12 | messages = [{"role": "user", "content": prompt}] 13 | response = openai.ChatCompletion.create( 14 | model=model, 15 | messages=messages, 16 | temperature=temperature, 17 | max_tokens=max_tokens 18 | ) 19 | return response 20 | 21 | 22 | def get_price_of_inference(model, input_tokens, output_tokens): 23 | if model == "gpt-3.5-turbo-0613": 24 | input_price_per_k = 0.0015 25 | output_price_per_k = 0.002 26 | price_dollar = (input_tokens * input_price_per_k + output_tokens * output_price_per_k) / 1000 27 | price_won = round(price_dollar * 1281.61, 5) 28 | return [price_dollar, price_won] 29 | else: 30 | return None 31 | 32 | 33 | with open("prompts.pkl", "rb") as f: 34 | prompts = pickle.load(f) 35 | 36 | full_responses = {} 37 | data = [] 38 | num_data = 1000 39 | prompt_type = "fewshot" 40 | prompt = prompts["fewshot"] 41 | 42 | for i in tqdm(range(num_data)): 43 | try: 44 | response = get_response(prompt) 45 | except: 46 | time.sleep(5) 47 | continue 48 | full_responses[prompt_type] = response 49 | output = response.choices[0].message.content # GPT output 50 | model = response.model # Model used 51 | input_tokens = response.usage.prompt_tokens # Number of tokens of input 52 | output_tokens = response.usage.completion_tokens # Number of tokens of output 53 | data.append( 54 | [ 55 | prompt_type, 56 | prompt, 57 | output, 58 | model, 59 | input_tokens, 60 | output_tokens, 61 | *get_price_of_inference(model, input_tokens, output_tokens) 62 | ], 63 | ) 64 | 65 | generated_df = pd.DataFrame( 66 | data, 67 | columns=[ 68 | "prompt_type", 69 | "prompt", 70 | "output", 71 | "model", 72 | "input_tokens", 73 | "output_tokens", 74 | 
"price_dollar", 75 | "price_won" 76 | ]) 77 | 78 | os.makedirs("../data/generated_data/gpt", exist_ok=True) 79 | generated_df.to_csv(f"./data/generated_data/gpt/generated_data_gpt_{len(generated_df)}.csv", index=False) -------------------------------------------------------------------------------- /data_pipeline/generate/parse.py: -------------------------------------------------------------------------------- 1 | import os 2 | import re 3 | 4 | import pandas as pd 5 | 6 | 7 | def collect_raw_data(path): 8 | raw_data = [data for data in os.listdir(path) if data.endswith(".csv")] 9 | collected_data = pd.DataFrame() 10 | 11 | for data in raw_data: 12 | data_path = os.path.join(path, data) 13 | df = pd.read_csv(data_path) 14 | collected_data = pd.concat([collected_data, df], ignore_index=True) 15 | 16 | collected_data = collected_data.drop_duplicates() 17 | collected_data = collected_data.reset_index() 18 | return collected_data 19 | 20 | 21 | def check_output_format(data): 22 | for idx, datum in enumerate(data): 23 | pattern = r"\[\s*(질문|답변|Q|A)\s*\d*\]|(질문|답변)\s*:\s*" 24 | is_fit_format = (re.search(pattern, datum) != None) 25 | if is_fit_format: 26 | continue 27 | else: 28 | print(f"Index {idx} is out of format") 29 | return False 30 | return True 31 | 32 | 33 | def check_qa_pair(data): 34 | mismatched_indices = [idx for idx, datum in enumerate(data) if (len(datum) % 2) != 0] 35 | return mismatched_indices 36 | 37 | 38 | path = "../data/generated_data/gpt" 39 | raw_data = collect_raw_data(path) 40 | raw_outputs = raw_data.output.tolist() 41 | 42 | print(f"Found {len(raw_outputs)} data points.") 43 | 44 | assert check_output_format(raw_outputs) == True, "Check the format of the data" 45 | 46 | processed_outputs = [] 47 | 48 | for idx, raw_data in enumerate(raw_outputs): 49 | pattern = r"\[\s*(질문|답변|Q|A)\s*\d*\]|(질문|답변)\s*:\s*" 50 | processed_data = re.sub(pattern, "[SEP]", raw_data) 51 | processed_data = processed_data.split("[SEP]") 52 | processed_data = [data.strip() for data in processed_data if len(data) > 20] 53 | processed_data = [re.sub(r"\[\s*예시\s*\d*\]", "", data) for data in processed_data] 54 | 55 | if len(processed_data) % 2 != 0: 56 | if len(processed_data) == 1: 57 | continue 58 | processed_data = processed_data[:-1] 59 | processed_outputs.append(processed_data) 60 | 61 | assert len(check_qa_pair(processed_outputs)) == 0, "QA pair mismatched data exists." 
62 | 
63 | q_list = []
64 | a_list = []
65 | for output in processed_outputs:
66 |     while len(output) != 0:
67 |         q_list.append(output.pop(0))
68 |         a_list.append(output.pop(0))
69 | 
70 | assert len(q_list) == len(a_list), "QA pair mismatch"
71 | 
72 | processed_data = pd.DataFrame({"question": q_list, "answer": a_list})
73 | processed_data.to_csv(f"../data/generated_qa_data_{len(processed_data)}.csv", index=False)
74 | 
75 | print(f"Generated {len(processed_data)} pairs of QA data.")
-------------------------------------------------------------------------------- /data_pipeline/generate/prompts.pkl: --------------------------------------------------------------------------------
https://raw.githubusercontent.com/boostcampaitech5/level3_nlp_finalproject-nlp-08/771de997cb0ce713826cd7a0a868280d866cd70f/data_pipeline/generate/prompts.pkl
-------------------------------------------------------------------------------- /data_pipeline/preprocessor.py: --------------------------------------------------------------------------------
1 | import re
2 | 
3 | import pandas as pd
4 | 
5 | 
6 | def remove_escape(raw_text: str) -> str:
7 |     pattern = r"\t|\n|\xa0"
8 |     processed_text = re.sub(pattern, " ", raw_text)
9 |     processed_text_stripped = " ".join(processed_text.split())
10 |     return processed_text_stripped
11 | 
12 | 
13 | def remove_phone_number(raw_text: str) -> str:
14 |     pattern = r"\(*\d+\s*-\s*\d+\s*-\s*\d+\)*"
15 |     processed_text = re.sub(pattern, "", raw_text)
16 |     return processed_text
17 | 
18 | 
19 | def remove_hyperlink(raw_text: str) -> str:
20 |     pattern = (
21 |         r":*\s*\(*:*\s*https?://[\w\dㄱ-ㅎㅏ-ㅣ가-힣!@#$%^&*(),.?/:;\"'<>{}|+=~_-]+\s*\)*"
22 |     )
23 |     processed_text = re.sub(pattern, "", raw_text)
24 |     return processed_text
25 | 
26 | 
27 | def remove_header(raw_text: str) -> str:
28 |     header_pattern = "안녕하십니까. 대한법률구조공단 사이버상담을 이용해 주셔서 감사합니다."
29 |     header_end_idx = re.search(header_pattern, raw_text)
30 |     if header_end_idx != None:
31 |         processed_text = raw_text[header_end_idx.end() :]
32 |         return processed_text
33 |     else:
34 |         return raw_text
35 | 
36 | 
37 | def remove_footer(raw_text: str) -> str:
38 |     footer_pattern = "1. 
위 답변은 귀하께서 제공해주신 사실관계에 기초한 답변자 개인의 법률적 의견으로서 이와 다른 의견이 있을 수도 있으므로 참고자료로만 활용해주시고," 39 | footer_start_idx = re.search(footer_pattern, raw_text) 40 | if footer_start_idx != None: 41 | processed_text = raw_text[: footer_start_idx.start()] 42 | return processed_text 43 | else: 44 | return raw_text 45 | 46 | 47 | def preprocess(raw_text: str) -> str: 48 | preprocessed_text = raw_text 49 | preprocess_functions = [ 50 | remove_header, 51 | remove_footer, 52 | remove_escape, 53 | remove_phone_number, 54 | remove_hyperlink, 55 | ] 56 | for preprocess_function in preprocess_functions: 57 | preprocessed_text = preprocess_function(preprocessed_text) 58 | return preprocessed_text 59 | 60 | 61 | if __name__ == "__main__": 62 | df = pd.read_csv("./data/raw_qa_dataset.csv") 63 | preprocessed_df = df.assign( 64 | content=df["content"].apply(preprocess), answer=df["answer"].apply(preprocess) 65 | ) 66 | preprocessed_df.to_csv("./data/preprocessed_qa_dataset.csv", index=False) 67 | -------------------------------------------------------------------------------- /data_pipeline/preprocessor_v2.py: -------------------------------------------------------------------------------- 1 | import re 2 | 3 | import pandas as pd 4 | 5 | 6 | def remove_escape(raw_text: str) -> str: 7 | pattern = r"\t|\n|\xa0" 8 | processed_text = re.sub(pattern, " ", raw_text) 9 | processed_text_stripped = " ".join(processed_text.split()) 10 | return processed_text_stripped 11 | 12 | 13 | def remove_phone_number(raw_text: str) -> str: 14 | pattern = r"\(*\d+\s*-\s*\d+\s*-\s*\d+\)*" 15 | processed_text = re.sub(pattern, "", raw_text) 16 | return processed_text 17 | 18 | 19 | def remove_hyperlink(raw_text: str) -> str: 20 | pattern = ( 21 | r":*\s*\(*:*\s*https?://[\w\dㄱ-ㅎㅏ-ㅣ가-힣!@#$%^&*(),.?/:;\"'<>{}|+=~_-]+\s*\)*" 22 | ) 23 | processed_text = re.sub(pattern, "", raw_text) 24 | return processed_text 25 | 26 | 27 | def remove_header(raw_text: str) -> str: 28 | header_pattern = "안녕하십니까. 대한법률구조공단 사이버상담을 이용해 주셔서 감사합니다." 29 | header_end_idx = re.search(header_pattern, raw_text) 30 | if header_end_idx != None: 31 | processed_text = raw_text[header_end_idx.end() :] 32 | return processed_text 33 | else: 34 | return raw_text 35 | 36 | 37 | def remove_footer(raw_text: str) -> str: 38 | footer_pattern = "1. 위 답변은 귀하께서 제공해주신 사실관계에 기초한 답변자 개인의 법률적 의견으로서 이와 다른 의견이 있을 수도 있으므로 참고자료로만 활용해주시고," 39 | footer_start_idx = re.search(footer_pattern, raw_text) 40 | if footer_start_idx != None: 41 | processed_text = raw_text[: footer_start_idx.start()] 42 | return processed_text 43 | else: 44 | return raw_text 45 | 46 | def remove_link(raw_text: str) -> str: 47 | pattern = ( 48 | '\(?[:/a-zA-Z]+.\s?[\da-zA-Z]+.\s?[\da-zA-Z]+.\s?[\da-zA-Z]+[/\da-zA-Z?=%@.&]+\s?\)?' 49 | ) 50 | processed_text = re.sub(pattern, "", raw_text) 51 | pattern = ( 52 | '\(?[:/a-zA-Z]+.\s?[\da-zA-Z]+.\s?[\da-zA-Z]+.\s?[\da-zA-Z]+\s?\)?' 
53 |     )
54 |     processed_text = re.sub(pattern, "", processed_text)
55 | 
56 |     pattern = (
57 |         '\(?[:/a-zA-Z]+.\s?[\da-zA-Z]+.\s?[\da-zA-Z]+\)?'
58 |     )
59 |     processed_text = re.sub(pattern, "", processed_text)
60 |     return processed_text
61 | 
62 | 
63 | def remove_page_word(raw_text: str) -> str:
64 | 
65 |     pattern = '사이버상담|사이버 상담|공단|방문|국번없이 132|132번'
66 |     if re.findall(pattern, raw_text) == []:
67 |         return raw_text
68 | 
69 |     split_text = raw_text.split('.')
70 |     remove_text = [i for i in split_text if re.findall(pattern, i) == []]
71 | 
72 |     return '.'.join(remove_text)
73 | 
74 | def remove_phone(raw_text: str) -> str:
75 |     pattern = ('\(?\s?☎?\s?국번\s?없이\s?☎?\s?\d+-?\d+\s?번?\)?')
76 |     processed_text = re.sub(pattern, "", raw_text)
77 |     pattern = ('\(?\s?☎\s?\d+-?\d+\s?번?\)?')
78 |     processed_text = re.sub(pattern, "", processed_text)
79 | 
80 | 
81 |     return processed_text
82 | 
83 | 
84 | def preprocess(raw_text: str) -> str:
85 |     preprocessed_text = raw_text
86 |     preprocess_functions = [
87 |         remove_header,
88 |         remove_footer,
89 |         remove_escape,
90 |         remove_phone,
91 |         remove_page_word,
92 |         remove_hyperlink,
93 |         remove_link,
94 | 
95 |     ]
96 |     for preprocess_function in preprocess_functions:
97 |         preprocessed_text = preprocess_function(preprocessed_text)
98 |     return preprocessed_text
99 | 
100 | 
101 | if __name__ == "__main__":
102 |     df = pd.read_csv("./data/legal_train_v1.csv", lineterminator='\n')
103 |     preprocessed_df = df.assign(
104 |         instruction=df["instruction"].apply(preprocess), output=df["output"].apply(preprocess)
105 |     )
106 |     preprocessed_df.to_csv("test_preprocess_dataset2.csv", index=False)
107 | 
-------------------------------------------------------------------------------- /data_pipeline/qa_crawler.py: --------------------------------------------------------------------------------
1 | import os
2 | from urllib.request import urlopen
3 | 
4 | import pandas as pd
5 | from bs4 import BeautifulSoup as bs
6 | from selenium import webdriver
7 | from selenium.webdriver.common.by import By
8 | from selenium.webdriver.support.ui import WebDriverWait
9 | 
10 | 
11 | 
12 | class QALawCrawler:
13 |     def __init__(self, start_url="https://www.klac.or.kr/legalinfo/counsel.do"):
14 |         self.start_url = start_url
15 |         self.driver = None
16 | 
17 |     def give_options(self):
18 |         options = webdriver.ChromeOptions()
19 |         options.add_argument("--headless")
20 |         options.add_argument("--no-sandbox")
21 |         options.add_argument("--single-process")
22 |         options.add_argument("--disable-dev-shm-usage")
23 |         return options
24 | 
25 |     def start_driver(self):
26 |         self.driver = webdriver.Chrome(options=self.give_options())
27 | 
28 |     def quit_driver(self):
29 |         self.driver.quit()
30 | 
31 |     def crawlling_data(self):
32 |         self.driver.get(self.start_url)
33 | 
34 |         data_list = []
35 |         end_point = 0
36 | 
37 |         while end_point == 0:
38 |             for j in range(2, 11):
39 |                 for k in range(1, 11):
40 | 
41 |                     try:
42 |                         self._wait_driver_click(
43 |                             f'//*[@id="content"]/div[2]/div/form/div[2]/table/tbody/tr[{k}]/td[2]/a'
44 |                         )
45 |                         data_list.append(self._collect_data())
46 |                         self._wait_driver_click(f'//*[@id="content"]/div[2]/div/div/a')
47 |                     except:
48 |                         break
49 | 
50 |                 try:
51 |                     self._wait_driver_click(
52 |                         f'//*[@id="content"]/div[2]/div/form/div[3]/a[{j}]'
53 |                     )
54 |                 except:
55 |                     end_point = 1
56 |                     break
57 | 
58 |             try:
59 |                 self._wait_driver_click(
60 |                     '//*[@id="content"]/div[2]/div/form/div[3]/button[3]'
61 |                 )
62 |             except:
63 |                 break
64 | 
65 |         self._save_data(data_list=data_list, drop_unused_columns=True)
66 | 
67 |     def 
_wait_driver_click(self, xpath): 68 | WebDriverWait(self.driver, timeout=10).until( 69 | lambda d: d.find_element(By.XPATH, xpath) 70 | ) 71 | self.driver.find_element(By.XPATH, xpath).click() 72 | 73 | def _collect_data(self): 74 | try: 75 | url = self.driver.current_url 76 | page_url = urlopen(url) 77 | soup = bs(page_url, "html.parser") 78 | data = [] 79 | for i in range(1, 5): 80 | find_data = soup.select_one( 81 | "#print_page > div:nth-child(" + str(i) + ") > dl > dd" 82 | ).text 83 | data.append(find_data) 84 | data.append(url) 85 | return data 86 | except: 87 | return ["error", 0, 0, 0, 0] 88 | 89 | def _save_data(self, data_list, drop_unused_columns=False): 90 | df = pd.DataFrame( 91 | data=data_list, columns=["division", "title", "question", "answer", "url"] 92 | ) 93 | if drop_unused_columns: 94 | question = df["question"].values.tolist() 95 | answer = df["answer"].values.tolist() 96 | df = pd.DataFrame({"question": question, "answer": answer}) 97 | 98 | os.makedirs("data", exist_ok=True) 99 | df.to_csv("./data/law_qa.csv", index=False) 100 | 101 | 102 | if __name__ == "__main__": 103 | qa_crawler = QALawCrawler() 104 | qa_crawler.start_driver() 105 | qa_crawler.crawlling_data() 106 | qa_crawler.quit_driver() 107 | -------------------------------------------------------------------------------- /data_pipeline/requirements.txt: -------------------------------------------------------------------------------- 1 | pandas==2.0.3 2 | beautifulsoup4==4.12.2 3 | scikit-learn==1.1.2 4 | selenium==4.10.0 5 | webdriver-manager==3.8.6 6 | tqdm==4.65.0 7 | openai==0.27.8 8 | bardapi==0.1.27 -------------------------------------------------------------------------------- /data_pipeline/spellchecker.py: -------------------------------------------------------------------------------- 1 | from tqdm import trange 2 | from selenium import webdriver 3 | from selenium.webdriver.common.by import By 4 | from selenium.webdriver.common.keys import Keys 5 | from selenium.webdriver.support.ui import WebDriverWait 6 | 7 | import pandas as pd 8 | 9 | def set_options(): 10 | options = webdriver.ChromeOptions() 11 | options.add_argument("--headless") 12 | options.add_argument('--no-sandbox') 13 | options.add_argument("--single-process") 14 | options.add_argument("--disable-dev-shm-usage") 15 | return options 16 | 17 | def wait_driver_click(id): 18 | WebDriverWait(driver, timeout=60).until(lambda d: d.find_element(By.ID, id)) 19 | driver.find_element(By.ID, id).click() 20 | 21 | 22 | def spell_check(txt, delay = 15): 23 | if len(txt) > 1200: 24 | return '' 25 | 26 | WebDriverWait(driver, timeout=10).until(lambda d: d.find_element(By.ID, 'character_counter_content')) 27 | driver.find_element(By.ID, 'character_counter_content').send_keys(txt) 28 | 29 | wait_driver_click('spell_check') 30 | driver.implicitly_wait(delay) 31 | wait_driver_click('spell_done_all') 32 | driver.implicitly_wait(1) 33 | 34 | WebDriverWait(driver, timeout=10).until(lambda d: d.find_element(By.CSS_SELECTOR, '#checker_preview')) 35 | clean_txt = driver.find_element(By.CSS_SELECTOR, '#checker_preview').text 36 | driver.refresh() 37 | return clean_txt 38 | 39 | 40 | dt = pd.read_csv('./data/preprocessed_qa_spacing_pre_word.csv') 41 | new_q = [] 42 | new_a = [] 43 | options = set_options() 44 | 45 | # options = webdriver.ChromeOptions() 46 | # options.add_argument("--headless") 47 | # options.add_argument('--no-sandbox') 48 | # options.add_argument("--single-process") 49 | # options.add_argument("--disable-dev-shm-usage") 50 | 51 | # driver = 
webdriver.Chrome(options=options)
52 | driver = webdriver.Chrome()
53 | driver.get('https://www.saramin.co.kr/zf_user/tools/character-counter')
54 | 
55 | q_v = dt['question'].values
56 | a_v = dt['answer'].values
57 | 
58 | for i in trange(len(dt)):
59 |     if len(q_v[i]) > 1200:
60 |         continue
61 |     try:
62 |         new_q.append(spell_check(q_v[i]))
63 |         new_a.append(a_v[i])
64 |     except:
65 |         continue
66 | 
67 | driver.quit()
68 | 
69 | 
70 | clean_data = pd.DataFrame({'question': new_q, 'answer': new_a})
71 | 
72 | clean_data.to_csv('./data/preprocessed_qa_spacing_spell.csv', index=False)
-------------------------------------------------------------------------------- /data_pipeline/utils.py: --------------------------------------------------------------------------------
1 | import logging
2 | import os
3 | 
4 | 
5 | def utilize_loggers(name):
6 |     logger = logging.getLogger(__name__)
7 |     logger_name = os.path.splitext(os.path.basename(name))[0]
8 | 
9 |     logging.basicConfig(
10 |         filename=f"{logger_name}.log",
11 |         format="%(asctime)s | %(levelname)s: %(message)s",
12 |         level=logging.INFO,
13 |         datefmt="%Y/%m/%d %I:%M:%S %p",
14 |     )
15 | 
16 |     stream_handler = logging.StreamHandler()
17 |     stream_handler.setLevel(logging.INFO)
18 |     stream_handler.setFormatter(
19 |         logging.Formatter("%(asctime)s | %(levelname)s: %(message)s")
20 |     )
21 |     logger.addHandler(stream_handler)
22 |     return logger
23 | 
24 | 
-------------------------------------------------------------------------------- /frontend/.gitignore: --------------------------------------------------------------------------------
1 | # See https://help.github.com/articles/ignoring-files/ for more about ignoring files.
2 | 
3 | # dependencies
4 | /node_modules
5 | /.pnp
6 | .pnp.js
7 | 
8 | # testing
9 | /coverage
10 | 
11 | # production
12 | /build
13 | 
14 | # misc
15 | .DS_Store
16 | .env.local
17 | .env.development.local
18 | .env.test.local
19 | .env.production.local
20 | 
21 | npm-debug.log*
22 | yarn-debug.log*
23 | yarn-error.log*
24 | yarn.lock
-------------------------------------------------------------------------------- /frontend/README.md: --------------------------------------------------------------------------------
1 | # LawBot - Frontend
2 | 
3 | ## ⚠️ Requirements for Web Frontend
4 | 
5 | Below are the commands for installing the dependency packages needed to run the web frontend.
6 | Run the following commands in the order shown.
7 | ```bash
8 | # Install Node.js via nvm (needed for React)
9 | apt install curl
10 | curl https://raw.githubusercontent.com/creationix/nvm/master/install.sh | bash
11 | source ~/.bashrc
12 | nvm install 18.04.0
13 | 
14 | # Install Tailwind CSS
15 | npm install -g yarn
16 | yarn add tailwindcss postcss autoprefixer
17 | npx tailwindcss init
18 | npm i tailwindcss-animated
19 | ```
20 | 
21 | When the installation is complete, run the following commands to check that the proper versions are installed.
22 | Please refer to the README file inside the `frontend` directory for the executable file.
23 | 
24 | ```bash
25 | node -v # v18.04.1
26 | npm -v # v8.11.0
27 | ```
28 | 
29 | ## 💻 Getting Started with Create React App
30 | 
31 | ```bash
32 | yarn start # npm start
33 | ```
34 | 
35 | ## How to Configure nginx
36 | 
37 | - Below are the commands for installing nginx and wiring up the site configuration.
38 | 
39 | ```
40 | sudo apt install nginx
41 | sudo rm /etc/nginx/sites-available/default
42 | sudo rm /etc/nginx/sites-enabled/default
43 | 
44 | sudo vim /etc/nginx/sites-available/frontend.conf
45 | sudo ln -s /etc/nginx/sites-available/frontend.conf /etc/nginx/sites-enabled/frontend.conf
46 | ```
47 | 
48 | - The configuration file that you need to write in `/etc/nginx/sites-available`
49 | 
50 | ```conf
51 | server {
52 |     listen 80;
53 |     listen [::]:80;
54 | 
55 |     server_name yoonseul.link;
56 | 
57 |     root /home/ubuntu/level3_nlp_finalproject-nlp-08/frontend/build;
58 |     index index.html;
59 | 
60 |     location /generate {
61 | 
62 |         proxy_pass https://backend_server;
63 |         proxy_connect_timeout 500;
64 |         proxy_send_timeout 500;
65 |         proxy_read_timeout 500;
66 |     }
67 | }
68 | ```
69 | 
70 | 
71 | ## How to start nginx
72 | ```bash
73 | sudo systemctl stop nginx # stop nginx
74 | sudo systemctl start nginx # start nginx
75 | ```
-------------------------------------------------------------------------------- /frontend/package.json: --------------------------------------------------------------------------------
1 | {
2 |   "name": "frontend",
3 |   "version": "0.1.0",
4 |   "private": true,
5 |   "dependencies": {
6 |     "@testing-library/jest-dom": "^5.16.5",
7 |     "@testing-library/react": "^13.4.0",
8 |     "@testing-library/user-event": "^13.5.0",
9 |     "autoprefixer": "^10.4.14",
10 |     "axios": "^1.4.0",
11 |     "postcss": "^8.4.26",
12 |     "react": "^18.2.0",
13 |     "react-dom": "^18.2.0",
14 |     "react-scripts": "5.0.1",
15 |     "tailwindcss": "^3.3.3",
16 |     "tailwindcss-animated": "^1.0.1",
17 |     "web-vitals": "^2.1.4"
18 |   },
19 |   "scripts": {
20 |     "start": "react-scripts start",
21 |     "build": "react-scripts build",
22 |     "test": "react-scripts test",
23 |     "eject": "react-scripts eject"
24 |   },
25 |   "eslintConfig": {
26 |     "extends": [
27 |       "react-app",
28 |       "react-app/jest"
29 |     ]
30 |   },
31 |   "browserslist": {
32 |     "production": [
33 |       ">0.2%",
34 |       "not dead",
35 |       "not op_mini all"
36 |     ],
37 |     "development": [
38 |       "last 1 chrome version",
39 |       "last 1 firefox version",
40 |       "last 1 safari version"
41 |     ]
42 |   },
43 |   "proxy": "https://api.yoonseul.link"
44 | }
45 | 
-------------------------------------------------------------------------------- /frontend/public/index.html: --------------------------------------------------------------------------------
1 | 
2 | 
3 | 
4 | 
5 | 
6 | 
7 | 
8 | 
9 | 
10 | 
11 | 
12 | 
16 | 
17 | 
26 | Lawbot
27 | 
28 | 
29 | 
30 | 
31 | 
32 | 
33 | 43 | 44 | 45 | -------------------------------------------------------------------------------- /frontend/public/lawbot.ico: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/boostcampaitech5/level3_nlp_finalproject-nlp-08/771de997cb0ce713826cd7a0a868280d866cd70f/frontend/public/lawbot.ico -------------------------------------------------------------------------------- /frontend/public/lawbot.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/boostcampaitech5/level3_nlp_finalproject-nlp-08/771de997cb0ce713826cd7a0a868280d866cd70f/frontend/public/lawbot.png -------------------------------------------------------------------------------- /frontend/public/manifest.json: -------------------------------------------------------------------------------- 1 | { 2 | "short_name": "React App", 3 | "name": "Create React App Sample", 4 | "icons": [ 5 | { 6 | "src": "lawbot.ico", 7 | "sizes": "64x64 32x32 24x24 16x16", 8 | "type": "image/x-icon" 9 | }, 10 | { 11 | "src": "lawbot.png", 12 | "type": "image/png", 13 | "sizes": "192x192" 14 | }, 15 | { 16 | "src": "lawbot.png", 17 | "type": "image/png", 18 | "sizes": "512x512" 19 | } 20 | ], 21 | "start_url": ".", 22 | "display": "standalone", 23 | "theme_color": "#000000", 24 | "background_color": "#ffffff" 25 | } -------------------------------------------------------------------------------- /frontend/public/robots.txt: -------------------------------------------------------------------------------- 1 | # https://www.robotstxt.org/robotstxt.html 2 | User-agent: * 3 | Disallow: 4 | -------------------------------------------------------------------------------- /frontend/src/App.js: -------------------------------------------------------------------------------- 1 | import Header from './components/Header'; 2 | import Loader from './components/Loader'; 3 | import ChattingSideBar from './components/ChattingSideBar'; 4 | import SimilarPrecedent from './components/SimilarPrecedent'; 5 | import { useState } from 'react'; 6 | import TypingAnimation from './components/TypingAnimation'; 7 | 8 | function App() { 9 | const [loading, setLoading] = useState(false) 10 | const [message, setMessage] = useState(""); 11 | const [sentMessage, setSentMessage] = useState(""); 12 | const [aianswer, setAianswer] = useState(""); 13 | const [precedents, setPrecedents] = useState(null); 14 | const ans = "\n\nAI가 작성한 답변이며 실제와 다를 수 있으므로 참고 자료로만 활용하시고, 자세한 상담을 원하시는 경우에는 전문 법조인의 상담을 받으시기 바랍니다. LawBot은 법적 책임을 지지 않는다는 점 참고바랍니다." 15 | 16 | const messagehandler = async (e) => { 17 | e.preventDefault(); 18 | if (!loading){ 19 | setLoading(true); 20 | setMessage(""); 21 | setSentMessage(message); 22 | setAianswer(""); 23 | setPrecedents(null); 24 | try { 25 | console.log(message); 26 | console.log(message.trim().length) 27 | if (message.trim().length <= 5) { 28 | setAianswer("죄송합니다. 입력하신 \""+message+"\"는 너무 짧아 정확한 답변을 제공하기 어려운 점 양해해주시기 바랍니다. 정확하고 효과적인 답변을 위해 더욱 구체적으로 질문해주시기 바랍니다."); 29 | } else { 30 | const response = await fetch('/generate', { 31 | method: 'POST', 32 | headers: { 33 | 'Content-Type': 'application/json', 34 | }, 35 | body: JSON.stringify({ q_sentence: message }), 36 | }); 37 | 38 | const data = await response.json(); 39 | console.log(message); 40 | console.log(data); 41 | 42 | if (data != null) { 43 | if (data.answer_sentence == null | data.answer_sentence =="\n") { 44 | setAianswer("죄송합니다. 저는 법률 상담을 도와드리는 AI LawBot입니다. 
법률 내용 외의 질문은 답변해 드리지 않는 점 참고 부탁드립니다. 법률적인 질문이 있으시다면 언제든지 물어보세요. 제가 최대한 자연스럽고 이해하기 쉽게 답변해 드리겠습니다. 어떤 도움이 필요하신가요?"); 45 | setPrecedents(null); 46 | } else { 47 | setAianswer(data.answer_sentence+ans); 48 | setPrecedents(data.similar_precedent); 49 | } 50 | } 51 | } 52 | } catch (error) { 53 | console.error("에러 발생:", error); 54 | } finally { 55 | setLoading(false); 56 | }} 57 | }; 58 | 59 | return ( 60 |
61 |
62 |
63 |
64 |
65 | 66 |
67 |
68 |
69 |
70 |
71 |
72 | {sentMessage && ( 73 |
74 |
75 |
76 | U 77 |
78 |
79 | 80 |
81 |
82 |
83 | )} 84 | {aianswer && ( 85 |
86 |
87 |
90 | L 91 |
92 |
95 | 96 |
97 |
98 |
)} 99 | {loading && ()} 100 |
101 |
102 |
103 |
108 |
109 |
112 | 119 | 125 | 126 |
127 |
128 | 129 |
130 |
131 | setMessage(e.target.value)} 137 | /> 138 |
139 |
140 |
141 | 163 |
164 | 165 |
166 |
167 |
168 |
169 | 170 |
171 | 172 |
173 |
174 | 175 |
176 | ); 177 | } 178 | 179 | export default App; 180 | -------------------------------------------------------------------------------- /frontend/src/asset/lawbot.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/boostcampaitech5/level3_nlp_finalproject-nlp-08/771de997cb0ce713826cd7a0a868280d866cd70f/frontend/src/asset/lawbot.png -------------------------------------------------------------------------------- /frontend/src/asset/spinner.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/boostcampaitech5/level3_nlp_finalproject-nlp-08/771de997cb0ce713826cd7a0a868280d866cd70f/frontend/src/asset/spinner.gif -------------------------------------------------------------------------------- /frontend/src/components/ChattingSideBar.js: -------------------------------------------------------------------------------- 1 | function ChattingSideBar() { 2 | return ( 3 |
4 | 37 |
38 | ) 39 | } 40 | export default ChattingSideBar 41 | -------------------------------------------------------------------------------- /frontend/src/components/Header.js: -------------------------------------------------------------------------------- 1 | import lawbot from "../asset/lawbot.png" 2 | 3 | function Header() { 4 | return ( 5 | 19 | ) 20 | } 21 | export default Header 22 | -------------------------------------------------------------------------------- /frontend/src/components/Loader.js: -------------------------------------------------------------------------------- 1 | import React, { useState, useEffect } from "react"; 2 | 3 | function Loader() { 4 | const phrases = [ 5 | "법조문을 읽고 있습니다...", 6 | "당신의 법률 문제를 위한 판례를 찾아보는 중입니다. 잠시만 기다려 주세요.", 7 | "답변을 생성하는데 최대 30초 정도의 시간이 걸립니다.", 8 | "당신의 질문에 대한 법적 해결책을 찾아서 조금 있다가 돌아올게요. 티타임을 즐기세요.", 9 | "답변을 생성중입니다 잠시만 기다려주세요.", 10 | ]; 11 | 12 | const [currentPhraseIndex, setCurrentPhraseIndex] = useState(0); 13 | const [visibleText, setVisibleText] = useState(""); 14 | 15 | const typingDelay = 100; 16 | const nextPhraseDelay = 3000; // 3초 17 | 18 | useEffect(() => { 19 | const typeText = (currentIndex, currentText) => { 20 | if (currentIndex < currentText.length) { 21 | setVisibleText(currentText.substring(0, currentIndex + 1)); 22 | setTimeout(() => typeText(currentIndex + 1, currentText), typingDelay); 23 | } else { 24 | setTimeout(() => { 25 | setCurrentPhraseIndex((prevIndex) => (prevIndex + 1) % phrases.length); 26 | setVisibleText(""); 27 | }, nextPhraseDelay); 28 | } 29 | }; 30 | 31 | const currentPhrase = phrases[currentPhraseIndex]; 32 | typeText(0, currentPhrase); 33 | }, [currentPhraseIndex]); 34 | 35 | return ( 36 |
37 |
38 |
39 | L 40 |
41 |
42 |
43 | 48 |
{visibleText}
49 |
50 |
51 |
52 |
53 | ); 54 | } 55 | 56 | export default Loader; 57 | -------------------------------------------------------------------------------- /frontend/src/components/SimilarPrecedent.js: -------------------------------------------------------------------------------- 1 | import PrecedentCard from "./SimilarPrecedentComponents/PrecedentCard"; 2 | 3 | function SimilarPresdent({ precedents }) { 4 | const validPrecedents = precedents || []; 5 | 6 | return ( 7 |
8 | 21 |
22 | ); 23 | } 24 | 25 | export default SimilarPresdent; -------------------------------------------------------------------------------- /frontend/src/components/SimilarPrecedentComponents/PrecedentCard.js: -------------------------------------------------------------------------------- 1 | function PrecedentCard({ precedent, number }) { 2 | return ( 3 |
4 |
유사 판례 조항 {number}
5 |

사건 이름 : {precedent.case_name}

6 |

사건 번호 : {precedent.case_number}

7 |

사건 분류 : {precedent.case_type}

8 |

관련 법 조항 :{precedent.ref_article}

9 | 10 | Read more 11 | 14 | 15 |
16 | ) 17 | } 18 | export default PrecedentCard 19 | -------------------------------------------------------------------------------- /frontend/src/components/TypingAnimation.js: -------------------------------------------------------------------------------- 1 | import React, { useState, useEffect } from "react"; 2 | 3 | function TypingAnimation({ text }) { 4 | const [visibleText, setVisibleText] = useState(""); 5 | const typingDelay = 20; // 타이핑 딜레이 시간(ms) 6 | 7 | useEffect(() => { 8 | const typeText = (currentIndex) => { 9 | if (currentIndex < text.length) { 10 | setVisibleText(text.substring(0, currentIndex + 1)); 11 | setTimeout(() => typeText(currentIndex + 1), typingDelay); 12 | } 13 | }; 14 | 15 | typeText(0); 16 | }, [text]); 17 | 18 | return
{visibleText}
; 19 | } 20 | 21 | export default TypingAnimation; -------------------------------------------------------------------------------- /frontend/src/index.css: -------------------------------------------------------------------------------- 1 | @tailwind base; 2 | @tailwind components; 3 | @tailwind utilities; -------------------------------------------------------------------------------- /frontend/src/index.js: -------------------------------------------------------------------------------- 1 | import React from 'react'; 2 | import ReactDOM from 'react-dom/client'; 3 | import './index.css'; 4 | import App from './App'; 5 | import reportWebVitals from './reportWebVitals'; 6 | 7 | const root = ReactDOM.createRoot(document.getElementById('root')); 8 | root.render( 9 | 10 | 11 | 12 | ); 13 | 14 | // If you want to start measuring performance in your app, pass a function 15 | // to log results (for example: reportWebVitals(console.log)) 16 | // or send to an analytics endpoint. Learn more: https://bit.ly/CRA-vitals 17 | reportWebVitals(); 18 | -------------------------------------------------------------------------------- /frontend/src/reportWebVitals.js: -------------------------------------------------------------------------------- 1 | const reportWebVitals = onPerfEntry => { 2 | if (onPerfEntry && onPerfEntry instanceof Function) { 3 | import('web-vitals').then(({ getCLS, getFID, getFCP, getLCP, getTTFB }) => { 4 | getCLS(onPerfEntry); 5 | getFID(onPerfEntry); 6 | getFCP(onPerfEntry); 7 | getLCP(onPerfEntry); 8 | getTTFB(onPerfEntry); 9 | }); 10 | } 11 | }; 12 | 13 | export default reportWebVitals; 14 | -------------------------------------------------------------------------------- /frontend/tailwind.config.js: -------------------------------------------------------------------------------- 1 | /** @type {import('tailwindcss').Config} */ 2 | module.exports = { 3 | content: ["./src/**/*.{html,js}"], 4 | theme: { 5 | extend: {}, 6 | }, 7 | plugins: [ 8 | require('tailwindcss-animated') 9 | ] 10 | } 11 | 12 | -------------------------------------------------------------------------------- /model/.gitignore: -------------------------------------------------------------------------------- 1 | LLM/train/data 2 | LLM/train/val_data 3 | ../.idea 4 | BERT/data 5 | Retrieval/bert_retrieval/data 6 | Retrieval/bm25_retrieval/all_data 7 | Filter/.idea 8 | Filter/data -------------------------------------------------------------------------------- /model/BERT/inference/inference.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import numpy as np 4 | import pandas as pd 5 | from sentence_transformers import SentenceTransformer 6 | from tqdm import tqdm 7 | import pickle 8 | from sklearn.metrics.pairwise import cosine_similarity 9 | from .utils import load_vector_data 10 | 11 | 12 | def BERT_infer(input): 13 | model_name = "jhgan/ko-sroberta-multitask" 14 | model = SentenceTransformer(model_name) 15 | model.to("cuda:0") 16 | 17 | input_vector = model.encode(input) 18 | input_vecotr = np.expand_dims(input_vector, axis=0) 19 | 20 | base_path = os.path.join(os.path.dirname(__file__)) 21 | 22 | text_data = np.array(pd.read_csv(base_path + "/../data/law_data/law_data_drop.csv")) 23 | vector_data = load_vector_data( 24 | base_path + "/../data/law_data/law_data_drop_vector.bin" 25 | ) 26 | 27 | cos_sim = cosine_similarity(input_vecotr, vector_data) 28 | data_cosine = np.sort(cos_sim).squeeze()[::-1][:3] 29 | top_question = 
np.argsort(cos_sim).squeeze()[::-1][:3] 30 | 31 | pan_list = [] 32 | 33 | for i, index in enumerate(top_question): 34 | if data_cosine[i] >= 0.5: 35 | pan_list.append( 36 | f"case Number : {text_data[index][0]} judgementAbstract : {text_data[index][4]} judgementNote :{text_data[index][9]}" 37 | ) 38 | 39 | return pan_list 40 | 41 | 42 | if __name__ == "__main__": 43 | BERT_infer("상원이형과 이혼을 하는것은 중죄이고 죄질이 나쁘기 떄문에 징역 10년입니다.") 44 | -------------------------------------------------------------------------------- /model/BERT/inference/utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | import pickle 3 | 4 | 5 | def load_vector_data(path): 6 | if os.path.isfile(path): 7 | with open(path, "rb") as fr: 8 | vector_data = pickle.load(fr) 9 | else: 10 | print("판례 데이터가 존재하지 않습니다.") 11 | vector_data = None 12 | return vector_data 13 | -------------------------------------------------------------------------------- /model/BERT/make_vector_dataset/preprocessing_law_data.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import pandas as pd 3 | from sentence_transformers import SentenceTransformer 4 | from tqdm import tqdm 5 | import pickle 6 | 7 | model_name = "jhgan/ko-sroberta-multitask" 8 | model = SentenceTransformer(model_name) 9 | model.to("cuda:0") 10 | 11 | df = pd.read_csv("../data/law_data/law_data.csv", encoding="UTF-8") 12 | df = df.dropna( 13 | subset=["caseName", "judgementAbstract", "precedentText", "judgementNote"] 14 | ) 15 | df.to_csv("../data/law_data/law_data_drop.csv", index=False) 16 | np_df = np.array(df) 17 | 18 | vector_list = [] 19 | for i in tqdm(range(len(np_df))): 20 | judgementAbstract = np_df[i][4] 21 | judgementNote = np_df[i][9] 22 | 23 | judgementNote_vector = model.encode(judgementAbstract + judgementNote) 24 | vector_list.append(list(judgementNote_vector)) 25 | 26 | with open("../data/law_data/law_data_drop_vector.bin", "wb") as fw: 27 | pickle.dump(vector_list, fw) 28 | -------------------------------------------------------------------------------- /model/BERT/preprocessing.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import pandas as pd 3 | from sentence_transformers import SentenceTransformer 4 | from tqdm import tqdm 5 | import os 6 | 7 | model_name = "jhgan/ko-sroberta-multitask" 8 | model = SentenceTransformer(model_name) 9 | model.to("cuda:0") 10 | 11 | df = pd.read_csv("./data/law_data.csv") 12 | df = df.dropna( 13 | subset=["caseName", "judgementAbstract", "precedentText", "judgementNote"] 14 | ) 15 | np_df = np.array(df) 16 | 17 | Ab_list = [] 18 | for i in tqdm(range(len(np_df))): 19 | Ab = np_df[i][4] 20 | Ab_query = model.encode(Ab) 21 | Ab_list.append(list(Ab_query)) 22 | 23 | df["Ab_vector"] = Ab_list 24 | df.to_csv("Ab_vector_law_data_sroberta.csv", index=False) 25 | -------------------------------------------------------------------------------- /model/Filter/data_preprocessing.py: -------------------------------------------------------------------------------- 1 | import pandas as pd 2 | from datasets import concatenate_datasets, Dataset, load_dataset 3 | import os 4 | 5 | 6 | class Autodata: 7 | def __init__(self, data_folder="./data"): 8 | self.data_foloder = data_folder 9 | self.concat_dataset = self.concat_datasets(self.data_foloder) 10 | 11 | def concat_datasets(self, data_foloder): 12 | datasets = [] 13 | pd_datasets = [] 14 | for file_name in 
os.listdir(data_foloder): 15 | if file_name.endswith(".csv"): 16 | file_path = os.path.join(data_foloder, file_name) 17 | dataset = pd.read_csv(file_path) 18 | dataframe = dataset[["question", "answer"]] 19 | pd_datasets.append(dataframe) 20 | dataset = Dataset.from_pandas(dataframe) 21 | datasets.append(dataset) 22 | 23 | combined_dataset = concatenate_datasets(datasets) 24 | 25 | return combined_dataset 26 | 27 | def load_instruction_dataset(self, dataset_id): 28 | koalpaca_data = load_dataset(dataset_id) 29 | data = koalpaca_data["train"] 30 | data = data.rename_column("instruction", "question") 31 | question = data["question"] 32 | return question 33 | 34 | def label_indexing(self, data, state): 35 | if state == 1: 36 | answer = 1 37 | else: 38 | answer = 0 39 | answer_list = [answer] * len(data) 40 | 41 | return Dataset.from_dict({"question": data, "target": answer_list}) 42 | -------------------------------------------------------------------------------- /model/Filter/dataloader.py: -------------------------------------------------------------------------------- 1 | import pandas as pd 2 | import torch 3 | from tqdm.auto import tqdm 4 | from transformers import AutoTokenizer 5 | 6 | 7 | class CustomDataset(torch.utils.data.Dataset): 8 | def __init__( 9 | self, 10 | data_file, 11 | state, 12 | text_columns, 13 | target_columns, 14 | max_length=256, 15 | model_name="klue/roberta-small", 16 | ): 17 | self.state = state 18 | self.data = data_file 19 | self.text_columns = text_columns 20 | self.max_length = max_length 21 | self.tokenizer = AutoTokenizer.from_pretrained(model_name) 22 | 23 | if self.state == "test": 24 | self.inputs = self.preprocessing(self.data) 25 | else: 26 | self.target_columns = target_columns if target_columns is not None else [] 27 | self.inputs, self.targets = self.preprocessing(self.data) 28 | 29 | def __getitem__(self, idx): 30 | if self.state == "test": 31 | return {"input_ids": torch.tensor(self.inputs[idx], dtype=torch.long)} 32 | else: 33 | return { 34 | "input_ids": torch.tensor(self.inputs[idx], dtype=torch.long), 35 | "labels": torch.tensor(self.targets[idx], dtype=torch.long), 36 | } 37 | 38 | def __len__(self): 39 | return len(self.inputs) 40 | 41 | def tokenizing(self, dataframe: pd.DataFrame) -> list: 42 | """ 43 | 토크나이징 44 | Args : 45 | dataframe (DataFrame): 토크나이징할 데이터 46 | Return : 47 | data (list) : 학습할 문장 토큰 리스트 48 | """ 49 | data = [] 50 | for item in tqdm( 51 | dataframe["question"], desc="Tokenizing", total=len(dataframe["question"]) 52 | ): 53 | text = item 54 | # text = [item for text_column in self.text_columns] 55 | outputs = self.tokenizer( 56 | text, 57 | add_special_tokens=True, 58 | padding="max_length", 59 | truncation=True, 60 | max_length=self.max_length, 61 | ) 62 | data.append(outputs["input_ids"]) 63 | return data 64 | 65 | def preprocessing(self, data): 66 | inputs = self.tokenizing(data) 67 | if self.state == "test": 68 | return inputs 69 | else: 70 | try: 71 | targets = data[self.target_columns] 72 | except: 73 | targets = [] 74 | return inputs, targets 75 | -------------------------------------------------------------------------------- /model/Filter/infer.py: -------------------------------------------------------------------------------- 1 | from transformers import AutoModelForSequenceClassification, AutoTokenizer 2 | import numpy as np 3 | import torch.nn.functional as F 4 | 5 | 6 | def infer(): 7 | text = "안녕하세요 김주원입니다." 
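    # Sample non-legal input: train.py labels legal questions as 0 and general
    # instruction text as 1, so this greeting should be filtered out as non-legal.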
8 | base_model_name = "monologg/koelectra-small-v3-discriminator" 9 | model = AutoModelForSequenceClassification.from_pretrained( 10 | "kfkas/legal-question-filter-koelectra", 11 | num_labels=2, 12 | ignore_mismatched_sizes=True, 13 | ) 14 | tokenizer = AutoTokenizer.from_pretrained(base_model_name) 15 | inputs = tokenizer(text, padding=True, truncation=True, return_tensors="pt") 16 | outputs = model(**inputs) 17 | logits = outputs.logits.detach().cpu() 18 | pr = F.softmax(logits).numpy() 19 | arg = np.argmax(pr, axis=1) 20 | print(logits) 21 | print(pr) 22 | print(int(arg)) 23 | if int(arg) == 0 and (pr[0][0] >= 0.98).all(): 24 | print("법률입니다") 25 | else: 26 | print("법률 질문이 아닙니다(bard,chatgpt 등등 다른 API 호출하면 좋을거 같아용)") 27 | 28 | 29 | if __name__ == "__main__": 30 | infer() 31 | -------------------------------------------------------------------------------- /model/Filter/train.py: -------------------------------------------------------------------------------- 1 | from transformers import Trainer, TrainingArguments, AutoModelForSequenceClassification 2 | from data_preprocessing import Autodata 3 | from datasets import concatenate_datasets 4 | from dataloader import CustomDataset 5 | from sklearn.model_selection import train_test_split 6 | from model.Filter.utils import compute_metrics 7 | 8 | 9 | def train(): 10 | model_name = "monologg/koelectra-small-v3-discriminator" 11 | data = Autodata("./data") 12 | legal_dataset = data.concat_dataset["question"] 13 | legal_answer_dataset = data.concat_dataset["answer"] 14 | alpaca_dataset = data.load_instruction_dataset("nlpai-lab/kullm-v2") 15 | 16 | legal_data = data.label_indexing(legal_dataset, state=0) 17 | legal_dataset_answer = data.label_indexing(legal_answer_dataset, state=0) 18 | alpaca_data = data.label_indexing(alpaca_dataset, state=1) 19 | 20 | total_data = concatenate_datasets([legal_data, alpaca_data, legal_dataset_answer]) 21 | train_dataset, val_dataset = train_test_split( 22 | total_data, test_size=0.2, random_state=42 23 | ) 24 | 25 | train_data = CustomDataset( 26 | data_file=train_dataset, 27 | model_name=model_name, 28 | text_columns="question", 29 | target_columns="target", 30 | max_length=256, 31 | state="train", 32 | ) 33 | val_data = CustomDataset( 34 | data_file=val_dataset, 35 | model_name=model_name, 36 | text_columns="question", 37 | target_columns="target", 38 | max_length=256, 39 | state="train", 40 | ) 41 | 42 | model = AutoModelForSequenceClassification.from_pretrained( 43 | model_name, num_labels=2, ignore_mismatched_sizes=True 44 | ) 45 | 46 | args = TrainingArguments( 47 | output_dir="output_dir", 48 | evaluation_strategy="epoch", 49 | save_strategy="epoch", 50 | learning_rate=1e-5, 51 | per_device_train_batch_size=64, 52 | per_device_eval_batch_size=64, 53 | num_train_epochs=10, 54 | weight_decay=0.01, 55 | load_best_model_at_end=True, 56 | dataloader_num_workers=4, 57 | logging_steps=50, 58 | seed=42, 59 | group_by_length=True, 60 | ) 61 | 62 | trainer = Trainer( 63 | model=model, 64 | args=args, 65 | train_dataset=train_data, 66 | eval_dataset=val_data, 67 | compute_metrics=compute_metrics, 68 | ) 69 | 70 | trainer.train() 71 | 72 | 73 | if __name__ == "__main__": 74 | train() 75 | -------------------------------------------------------------------------------- /model/Filter/utils.py: -------------------------------------------------------------------------------- 1 | import evaluate 2 | import numpy as np 3 | 4 | 5 | def compute_metrics(eval_pred): 6 | accuracy = evaluate.load("f1") 7 | predictions, labels = 
eval_pred 8 | predictions = np.argmax(predictions, axis=1) 9 | return accuracy.compute(predictions=predictions, references=labels) 10 | -------------------------------------------------------------------------------- /model/LLM/evaluation/data_preprocessing.py: -------------------------------------------------------------------------------- 1 | import pandas as pd 2 | from datasets import concatenate_datasets 3 | from datasets import load_dataset 4 | import os 5 | 6 | 7 | class PPL_Autodata: 8 | def __init__(self, data_folder="./eval_data_legal"): 9 | self.data_foloder = data_folder 10 | self.concat_dataset = self.concat_datasets(self.data_foloder) 11 | self.preprocess_data = self.preprocessing_data(self.concat_dataset) 12 | 13 | def concat_datasets(self, folder_path): 14 | datasets = [] 15 | for file_name in os.listdir(folder_path): 16 | if file_name.endswith(".csv"): 17 | file_path = os.path.join(folder_path, file_name) 18 | dataset = load_dataset("csv", data_files=file_path) 19 | datasets.append(dataset["train"]) 20 | 21 | combined_dataset = concatenate_datasets(datasets) 22 | if len(combined_dataset.features) > 2: 23 | data = combined_dataset.remove_columns("title") 24 | else: 25 | data = combined_dataset 26 | return data 27 | 28 | def preprocessing_data(self, dataset): 29 | data = dataset.map( 30 | lambda x: { 31 | "text": f"아래는 작업을 설명하는 명령어입니다. 요청을 적절히 완료하는 응답을 작성하세요.\n\n### 명령어:\n{x['question']}\n\n### 응답:\n{x['answer']}" 32 | } 33 | ) 34 | return data 35 | -------------------------------------------------------------------------------- /model/LLM/evaluation/dialogue_evaluation.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | 4 | sys.path.append(os.path.dirname(os.path.abspath(os.path.dirname(__file__)))) 5 | 6 | from inference import infer 7 | import pandas as pd 8 | import torch 9 | from transformers import AutoTokenizer, AutoModelForCausalLM 10 | from peft import PeftModel, PeftConfig 11 | from tqdm.auto import tqdm 12 | 13 | class LawyerEvaluation: 14 | def __init__(self, path="lawyer_question.csv", model_name="uomnf97/LawBot-level2-final-preprocessing-v3"): 15 | self.data = pd.read_csv(f"./eval_data_legal/{path}") 16 | self.model_name = model_name 17 | self.answer = False 18 | 19 | def generate_answer(self): 20 | answer_list = [] 21 | 22 | device = ( 23 | torch.device("cuda:0") if torch.cuda.is_available( 24 | ) else torch.device("cpu") 25 | ) 26 | peft_model_id = self.model_name 27 | config = PeftConfig.from_pretrained(peft_model_id) 28 | model = AutoModelForCausalLM.from_pretrained( 29 | config.base_model_name_or_path, device_map={"": 0}, torch_dtype=torch.float16 30 | ) 31 | model = PeftModel.from_pretrained( 32 | model, peft_model_id, torch_dtype=torch.float16) 33 | tokenizer = AutoTokenizer.from_pretrained( 34 | config.base_model_name_or_path 35 | ) 36 | model.eval() 37 | model.config.use_cache = ( 38 | True # silence the warnings. Please re-enable for inference! 
39 | ) 40 | model.float() 41 | tokenizer.pad_token = tokenizer.eos_token 42 | 43 | for i in tqdm(range(len(self.data)), desc="processing evaluation data"): 44 | data = self.data.iloc[i]["question"] 45 | answer_list.append(infer.gen(data, model=model, 46 | tokenizer=tokenizer, device=device)) 47 | 48 | 49 | self.data["answer"] = answer_list 50 | self.answer = True 51 | 52 | def to_csv(self): 53 | if self.answer == True: 54 | self.data.to_csv(f"./eval_data_legal/lawyer_val_with_answer.csv") 55 | else: 56 | print("답안을 생성해주세요!") 57 | 58 | 59 | if __name__ == "__main__": 60 | lawyer_evaluation = LawyerEvaluation() 61 | lawyer_evaluation.generate_answer() 62 | lawyer_evaluation.to_csv() 63 | -------------------------------------------------------------------------------- /model/LLM/evaluation/evaluate_mertrics.py: -------------------------------------------------------------------------------- 1 | from data_preprocessing import PPL_Autodata 2 | from petf_ppl import Perplexity_Petf 3 | import dataset as dataset 4 | import evaluate 5 | 6 | 7 | data = PPL_Autodata("./eval_data_legal").preprocess_data 8 | petf_model_id = "kfkas/LawBot-v1_koalpaca_legalQA_easylaw_train" 9 | normal_model_id = "nlpai-lab/kullm-polyglot-5.8b-v2" 10 | 11 | use = "petf" # petf or normal 12 | 13 | if use == "petf": 14 | perplexity = Perplexity_Petf() 15 | results = perplexity.compute( 16 | model_id=petf_model_id, 17 | add_start_token=False, 18 | predictions=data["text"], 19 | max_length=256, 20 | batch_size=4, 21 | ) 22 | print(list(results.keys())) 23 | print(round(results["mean_perplexity"], 2)) 24 | print(round(results["perplexities"][0], 2)) 25 | else: 26 | perplexity = evaluate.load("perplexity", module_type="metric") 27 | results = perplexity.compute( 28 | model_id=normal_model_id, 29 | add_start_token=False, 30 | predictions=data["text"], 31 | max_length=256, 32 | batch_size=4, 33 | ) 34 | print(list(results.keys())) 35 | print(round(results["mean_perplexity"], 2)) 36 | print(round(results["perplexities"][0], 2)) 37 | -------------------------------------------------------------------------------- /model/LLM/evaluation/petf_ppl.py: -------------------------------------------------------------------------------- 1 | # Copyright 2022 The HuggingFace Datasets Authors and the current dataset script contributor. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | """Perplexity Metric.""" 15 | 16 | import datasets 17 | import numpy as np 18 | import torch 19 | from peft import PeftConfig, PeftModel 20 | from torch.nn import CrossEntropyLoss 21 | from transformers import AutoModelForCausalLM, AutoTokenizer 22 | 23 | import evaluate 24 | from evaluate import logging 25 | 26 | 27 | _CITATION = """\ 28 | """ 29 | 30 | _DESCRIPTION = """ 31 | Perplexity (PPL) is one of the most common evaluation for evaluating language models. 32 | It is defined as the exponentiated average negative log-likelihood of a sequence, calculated with exponent base `e`. 
33 | For more information, see https://huggingface.co/docs/transformers/perplexity 34 | """ 35 | 36 | _KWARGS_DESCRIPTION = """ 37 | Args: 38 | model_id (str): model used for calculating Perplexity 39 | NOTE: Perplexity can only be calculated for causal language models. 40 | This includes models such as gpt2, causal variations of bert, 41 | causal versions of t5, and more (the full list can be found 42 | in the AutoModelForCausalLM documentation here: 43 | https://huggingface.co/docs/transformers/master/en/model_doc/auto#transformers.AutoModelForCausalLM ) 44 | predictions (list of str): input text, each separate text snippet 45 | is one list entry. 46 | batch_size (int): the batch size to run texts through the model. Defaults to 16. 47 | add_start_token (bool): whether to add the start token to the texts, 48 | so the perplexity can include the probability of the first word. Defaults to True. 49 | device (str): device to run on, defaults to 'cuda' when available 50 | Returns: 51 | perplexity: dictionary containing the perplexity scores for the texts 52 | in the input list, as well as the mean perplexity. If one of the input texts is 53 | longer than the max input length of the model, then it is truncated to the 54 | max length for the perplexity computation. 55 | Examples: 56 | Example 1: 57 | >>> perplexity = evaluate.load("perplexity", module_type="metric") 58 | >>> input_texts = ["lorem ipsum", "Happy Birthday!", "Bienvenue"] 59 | >>> results = perplexity.compute(model_id='gpt2', 60 | ... add_start_token=False, 61 | ... predictions=input_texts) # doctest:+ELLIPSIS 62 | >>> print(list(results.keys())) 63 | ['perplexities', 'mean_perplexity'] 64 | >>> print(round(results["mean_perplexity"], 0)) 65 | 647.0 66 | >>> print(round(results["perplexities"][0], 0)) 67 | 32.0 68 | Example 2: 69 | >>> from datasets import load_dataset 70 | >>> perplexity = evaluate.load("perplexity", module_type="metric") 71 | >>> input_texts = load_dataset("wikitext", "wikitext-2-raw-v1", split="test")["text"][:10] # doctest: +SKIP 72 | >>> input_texts = [s for s in input_texts if s!=''] 73 | >>> results = perplexity.compute(model_id='gpt2', 74 | ... predictions=input_texts) 75 | >>> print(list(results.keys())) 76 | ['perplexities', 'mean_perplexity'] 77 | >>> print(round(results["mean_perplexity"], 2)) # doctest: +SKIP 78 | 576.76 79 | >>> print(round(results["perplexities"][0], 2)) # doctest: +SKIP 80 | 889.28 81 | """ 82 | 83 | 84 | @evaluate.utils.file_utils.add_start_docstrings(_DESCRIPTION, _KWARGS_DESCRIPTION) 85 | class Perplexity_Petf(evaluate.Metric): 86 | def _info(self): 87 | return evaluate.MetricInfo( 88 | module_type="metric", 89 | description=_DESCRIPTION, 90 | citation=_CITATION, 91 | inputs_description=_KWARGS_DESCRIPTION, 92 | features=datasets.Features( 93 | { 94 | "predictions": datasets.Value("string"), 95 | } 96 | ), 97 | reference_urls=["https://huggingface.co/docs/transformers/perplexity"], 98 | ) 99 | 100 | def _compute( 101 | self, 102 | predictions, 103 | model_id, 104 | batch_size: int = 4, 105 | add_start_token: bool = True, 106 | device=None, 107 | max_length=None, 108 | ): 109 | if device is not None: 110 | assert device in [ 111 | "gpu", 112 | "cpu", 113 | "cuda", 114 | ], "device should be either gpu or cpu." 
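            # The "gpu" alias is normalized to torch's "cuda" device string below;
            # when no device is given, CUDA is used if available. The model is then
            # assembled in two steps: the base causal LM is loaded first, and the
            # trained PEFT adapter weights are attached on top of it before scoring.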
115 | if device == "gpu": 116 | device = "cuda" 117 | else: 118 | device = "cuda" if torch.cuda.is_available() else "cpu" 119 | 120 | config = PeftConfig.from_pretrained(model_id) 121 | model = AutoModelForCausalLM.from_pretrained( 122 | config.base_model_name_or_path, device_map={"": 0} 123 | ) 124 | model = PeftModel.from_pretrained( 125 | model, model_id 126 | ) # ,quantization_config=bnb_config) 127 | 128 | model = model.to(device) 129 | 130 | tokenizer = AutoTokenizer.from_pretrained(config.base_model_name_or_path) 131 | 132 | # if batch_size > 1 (which generally leads to padding being required), and 133 | # if there is not an already assigned pad_token, assign an existing 134 | # special token to also be the padding token 135 | if tokenizer.pad_token is None and batch_size > 1: 136 | existing_special_tokens = list( 137 | tokenizer.special_tokens_map_extended.values() 138 | ) 139 | # check that the model already has at least one special token defined 140 | assert ( 141 | len(existing_special_tokens) > 0 142 | ), "If batch_size > 1, model must have at least one special token to use for padding. Please use a different model or set batch_size=1." 143 | # assign one of the special tokens to also be the pad token 144 | tokenizer.add_special_tokens({"pad_token": existing_special_tokens[0]}) 145 | 146 | if add_start_token and max_length: 147 | # leave room for token to be added: 148 | assert ( 149 | tokenizer.bos_token is not None 150 | ), "Input model must already have a BOS token if using add_start_token=True. Please use a different model, or set add_start_token=False" 151 | max_tokenized_len = max_length - 1 152 | else: 153 | max_tokenized_len = max_length 154 | 155 | encodings = tokenizer( 156 | predictions, 157 | add_special_tokens=False, 158 | padding=True, 159 | truncation=True if max_tokenized_len else False, 160 | max_length=max_length, 161 | return_tensors="pt", 162 | return_attention_mask=True, 163 | ).to(device) 164 | 165 | encoded_texts = encodings["input_ids"] 166 | attn_masks = encodings["attention_mask"] 167 | 168 | # check that each input is long enough: 169 | if add_start_token: 170 | assert torch.all( 171 | torch.ge(attn_masks.sum(1), 1) 172 | ), "Each input text must be at least one token long." 173 | else: 174 | assert torch.all( 175 | torch.ge(attn_masks.sum(1), 2) 176 | ), "When add_start_token=False, each input text must be at least two tokens long. Run with add_start_token=True if inputting strings of only one token, and remove all empty input strings." 
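        # What follows is the standard batched perplexity loop: optionally prepend
        # a BOS token, run the model, shift logits and labels by one position,
        # take the per-token cross-entropy masked by the attention mask, average
        # over each sequence's valid tokens, and exponentiate to obtain one
        # perplexity value per input text.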
177 | 178 | ppls = [] 179 | loss_fct = CrossEntropyLoss(reduction="none") 180 | 181 | for start_index in logging.tqdm(range(0, len(encoded_texts), batch_size)): 182 | end_index = min(start_index + batch_size, len(encoded_texts)) 183 | encoded_batch = encoded_texts[start_index:end_index] 184 | attn_mask = attn_masks[start_index:end_index] 185 | 186 | if add_start_token: 187 | bos_tokens_tensor = torch.tensor( 188 | [[tokenizer.bos_token_id]] * encoded_batch.size(dim=0) 189 | ).to(device) 190 | encoded_batch = torch.cat([bos_tokens_tensor, encoded_batch], dim=1) 191 | attn_mask = torch.cat( 192 | [ 193 | torch.ones(bos_tokens_tensor.size(), dtype=torch.int64).to( 194 | device 195 | ), 196 | attn_mask, 197 | ], 198 | dim=1, 199 | ) 200 | 201 | labels = encoded_batch 202 | 203 | with torch.no_grad(): 204 | out_logits = model(encoded_batch, attention_mask=attn_mask).logits 205 | 206 | shift_logits = out_logits[..., :-1, :].contiguous() 207 | shift_labels = labels[..., 1:].contiguous() 208 | shift_attention_mask_batch = attn_mask[..., 1:].contiguous() 209 | 210 | perplexity_batch = torch.exp( 211 | ( 212 | loss_fct(shift_logits.transpose(1, 2), shift_labels) 213 | * shift_attention_mask_batch 214 | ).sum(1) 215 | / shift_attention_mask_batch.sum(1) 216 | ) 217 | 218 | ppls += perplexity_batch.tolist() 219 | 220 | return {"perplexities": ppls, "mean_perplexity": np.mean(ppls)} 221 | -------------------------------------------------------------------------------- /model/LLM/evaluation/ppl.py: -------------------------------------------------------------------------------- 1 | import dataset as dataset 2 | import evaluate 3 | from datasets import load_dataset 4 | 5 | from petf_ppl import Perplexity_Petf 6 | import pandas as pd 7 | 8 | 9 | # perplexity = evaluate.load("perplexity", module_type="metric") 10 | perplexity = Perplexity_Petf() 11 | 12 | path = "../train/data/easy_law.csv" 13 | db = pd.read_csv(path) 14 | dataset = load_dataset("csv", data_files=path)["train"] 15 | 16 | data = dataset.map( 17 | lambda x: { 18 | "text": f"아래는 작업을 설명하는 명령어입니다. 요청을 적절히 완료하는 응답을 작성하세요.\n\n### 명령어:\n{x['question']}\n\n### 응답:\n{x['answer']}" 19 | } 20 | ) 21 | 22 | results = perplexity.compute( 23 | model_id="kfkas/LawBot-v1_koalpaca_legalQA_easylaw", 24 | add_start_token=False, 25 | predictions=data["text"], 26 | max_length=256, 27 | batch_size=4, 28 | ) 29 | # results = perplexity.compute(model_id='nlpai-lab/kullm-polyglot-5.8b-v2',add_start_token=False,predictions=data['text'],max_length=256,batch_size=4) 30 | print(list(results.keys())) 31 | print(round(results["mean_perplexity"], 2)) 32 | print(round(results["perplexities"][0], 2)) 33 | -------------------------------------------------------------------------------- /model/LLM/inference/infer.py: -------------------------------------------------------------------------------- 1 | from peft import PeftModel, PeftConfig 2 | import torch 3 | from transformers import AutoTokenizer, AutoModelForCausalLM 4 | 5 | 6 | def gen(x, model, tokenizer, device): 7 | prompt = ( 8 | f"아래는 작업을 설명하는 명령어입니다. 
요청을 적절히 완료하는 응답을 작성하세요.\n\n### 명령어:\n{x}\n\n### 응답:\n" 9 | ) 10 | len_prompt = len(prompt) 11 | gened = model.generate( 12 | **tokenizer(prompt, return_tensors="pt", return_token_type_ids=False).to( 13 | device 14 | ), 15 | max_new_tokens=1024, 16 | early_stopping=True, 17 | do_sample=True, 18 | top_k=20, 19 | top_p=0.92, 20 | no_repeat_ngram_size=3, 21 | eos_token_id=2, 22 | repetition_penalty=1.2, 23 | num_beams=3, 24 | ) 25 | return tokenizer.decode(gened[0])[len_prompt:] 26 | 27 | 28 | def LLM_infer(input, model_type): 29 | device = ( 30 | torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu") 31 | ) 32 | if model_type == "kullm": 33 | peft_model_id = "YoonSeul/LawBot-level-3-KuLLM-5.8B-tae-2epoch" 34 | config = PeftConfig.from_pretrained(peft_model_id) 35 | model = AutoModelForCausalLM.from_pretrained( 36 | config.base_model_name_or_path, 37 | device_map={"": 0}, 38 | torch_dtype=torch.float16, 39 | ) 40 | model = PeftModel.from_pretrained( 41 | model, peft_model_id, torch_dtype=torch.float16 42 | ) 43 | tokenizer = AutoTokenizer.from_pretrained(config.base_model_name_or_path) 44 | 45 | model.eval() 46 | model.config.use_cache = ( 47 | True # silence the warnings. Please re-enable for inference! 48 | ) 49 | tokenizer.pad_token = tokenizer.eos_token 50 | else: 51 | model_id = "kfkas/Legal-Llama-2-ko-7b-Chat" 52 | model = AutoModelForCausalLM.from_pretrained( 53 | model_id, 54 | device_map={"": 0}, 55 | torch_dtype=torch.float16, 56 | low_cpu_mem_usage=True, 57 | ) 58 | tokenizer = AutoTokenizer.from_pretrained(model_id) 59 | model.eval() 60 | model.config.use_cache = True 61 | tokenizer.pad_token = tokenizer.eos_token 62 | output = gen(input, model=model, tokenizer=tokenizer, device=device) 63 | 64 | return output 65 | 66 | 67 | if __name__ == "__main__": 68 | model_type = "kullm" # llama, kullm 69 | input = "음주운전을하면 어떤 법으로 처벌 되나요?" 70 | text = LLM_infer(input, model_type) 71 | -------------------------------------------------------------------------------- /model/LLM/train/data_preprocessing.py: -------------------------------------------------------------------------------- 1 | import pandas as pd 2 | from datasets import concatenate_datasets, Dataset 3 | from datasets import load_dataset 4 | import os 5 | 6 | 7 | class Autodata: 8 | def __init__(self, data_folder="./data", max_length=1024, tokenizer=None): 9 | self.data_foloder = data_folder 10 | self.max_length = max_length 11 | self.tokenizer = tokenizer 12 | self.concat_dataset = self.concat_datasets(self.data_foloder) 13 | self.tokenizer_dataset = self.tokenizing_dataset(self.concat_dataset) 14 | 15 | def concat_datasets(self, folder_path): 16 | datasets = [] 17 | for file_name in os.listdir(folder_path): 18 | if file_name.endswith(".csv"): 19 | file_path = os.path.join(folder_path, file_name) 20 | dataset = pd.read_csv(file_path) 21 | dataframe = dataset[["question", "answer"]] 22 | dataset = Dataset.from_pandas(dataframe) 23 | datasets.append(dataset) 24 | 25 | combined_dataset = concatenate_datasets(datasets) 26 | 27 | return combined_dataset 28 | 29 | def tokenizing_dataset(self, dataset): 30 | data = dataset.map( 31 | lambda x: { 32 | "text": f"아래는 작업을 설명하는 명령어입니다. 
요청을 적절히 완료하는 응답을 작성하세요.\n\n### 명령어:\n{x['question']}\n\n### 응답:\n{x['answer']}<|endoftext|>" 33 | } 34 | ) 35 | data = data.map( 36 | lambda samples: self.tokenizer( 37 | samples["text"], 38 | truncation=True, 39 | max_length=self.max_length, 40 | padding=False, 41 | return_tensors=None, 42 | ), 43 | batched=True, 44 | ) 45 | 46 | return data.shuffle() 47 | -------------------------------------------------------------------------------- /model/LLM/train/load_model.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig 3 | from peft import prepare_model_for_kbit_training 4 | from peft import LoraConfig, get_peft_model 5 | from utils import print_trainable_parameters 6 | 7 | 8 | def load_model(model_name): 9 | bnb_config = BitsAndBytesConfig( 10 | load_in_4bit=True, 11 | bnb_4bit_use_double_quant=True, 12 | bnb_4bit_quant_type="nf4", 13 | bnb_4bit_compute_dtype=torch.bfloat16, 14 | ) 15 | tokenizer = AutoTokenizer.from_pretrained(model_name) 16 | model = AutoModelForCausalLM.from_pretrained( 17 | model_name, quantization_config=bnb_config, device_map={"": 0} 18 | ) 19 | model.gradient_checkpointing_enable() 20 | model = prepare_model_for_kbit_training(model) 21 | 22 | config = LoraConfig( 23 | r=8, 24 | lora_alpha=32, 25 | target_modules=["query_key_value"], 26 | lora_dropout=0.05, 27 | bias="none", 28 | task_type="CAUSAL_LM", 29 | ) 30 | 31 | model = get_peft_model(model, config) 32 | print_trainable_parameters(model) 33 | 34 | return model, tokenizer 35 | -------------------------------------------------------------------------------- /model/LLM/train/train.py: -------------------------------------------------------------------------------- 1 | from load_model import load_model 2 | from data_preprocessing import Autodata 3 | import transformers 4 | 5 | 6 | def train(): 7 | model_id = "nlpai-lab/kullm-polyglot-5.8b-v2" 8 | model, tokenizer = load_model(model_id) 9 | tokenizer.pad_token = tokenizer.eos_token 10 | train_data = Autodata(data_folder="./data", tokenizer=tokenizer).tokenizer_dataset 11 | val_data = Autodata(data_folder="./val_data", tokenizer=tokenizer).tokenizer_dataset 12 | trainer = transformers.Trainer( 13 | model=model, 14 | train_dataset=train_data, 15 | eval_dataset=val_data, 16 | args=transformers.TrainingArguments( 17 | per_device_train_batch_size=16, 18 | gradient_accumulation_steps=1, 19 | num_train_epochs=6, 20 | learning_rate=1e-4, 21 | fp16=True, 22 | logging_steps=10, 23 | save_strategy="epoch", 24 | evaluation_strategy="epoch", 25 | output_dir="./model_outputs", 26 | optim="paged_adamw_8bit", 27 | ), 28 | data_collator=transformers.DataCollatorForLanguageModeling( 29 | tokenizer, mlm=False 30 | ), 31 | ) 32 | model.config.use_cache = ( 33 | False # silence the warnings. Please re-enable for inference! 
34 |     )
35 |     trainer.train()
36 | 
37 |     push_model_id = "kfkas/LawBot-level2-5.8B_FIX"
38 |     huggingface_write_token = ""  # fill in your HuggingFace write token here
39 | 
40 |     model.push_to_hub(
41 |         push_model_id, use_temp_dir=True, use_auth_token=huggingface_write_token
42 |     )
43 |     print(f"{push_model_id} 모델 업로드 완료!")
44 | 
45 | 
46 | if __name__ == "__main__":
47 |     train()
48 | 
--------------------------------------------------------------------------------
/model/LLM/train/utils.py:
--------------------------------------------------------------------------------
1 | def print_trainable_parameters(model):
2 |     """
3 |     Prints the number of trainable parameters in the model.
4 |     """
5 |     trainable_params = 0
6 |     all_param = 0
7 |     for _, param in model.named_parameters():
8 |         all_param += param.numel()
9 |         if param.requires_grad:
10 |             trainable_params += param.numel()
11 |     print(
12 |         f"trainable params: {trainable_params} || all params: {all_param} || trainable%: {100 * trainable_params / all_param}"
13 |     )
14 | 
--------------------------------------------------------------------------------
/model/README.md:
--------------------------------------------------------------------------------
1 | # LawBot - Model
2 | 
3 | ## 💻 Getting Started
4 | 
5 | ## ⚠️ How To Install Requirements
6 | ### CUDA install
7 | 
8 | 1. Run the following code on your terminal.
9 | 
10 | ```bash
11 | wget https://developer.download.nvidia.com/compute/cuda/11.8.0/local_installers/cuda_11.8.0_520.61.05_linux.run
12 | chmod +x cuda_11.8.0_520.61.05_linux.run
13 | sh cuda_11.8.0_520.61.05_linux.run
14 | ```
15 | 2. Input `accept` to proceed.
16 | 
17 | [screenshot stripped during extraction: "Screenshot 2023-07-17 at 9 26 12 PM"]
18 | 
19 | 3. Select the driver and install.
20 | 
21 | [screenshot stripped during extraction: "Screenshot 2023-07-17 at 9 26 25 PM"]
22 | 
23 | 
24 | 4. Run the following commands on your terminal.
25 | ```bash
26 | $ export PATH=/usr/local/cuda-11.8/bin:$PATH
27 | $ export LD_LIBRARY_PATH=/usr/local/cuda-11.8/lib64:$LD_LIBRARY_PATH
28 | $ pip install -r requirements.txt
29 | ```
30 | ## ⌨️ How To Train
31 | ### LLM (Large Language Model)
32 | * Before training, place the Legal QA data in the following directories:
33 | `model/LLM/train/data`
34 | `model/LLM/train/val_data`
35 | * The HuggingFace write token should be filled in at line #38 of `model/LLM/train/train.py`
36 | * Run the following command on your terminal
37 | ```bash
38 | $ python3 model/LLM/train/train.py
39 | ```
40 | ### Question Filtering Model (Koelectra)
41 | ```bash
42 | $ python3 model/Filter/train.py
43 | ```
44 | ## ⌨️ How To Infer
45 | ### LLM (Large Language Model)
46 | * The PEFT model id at line #33 of `model/LLM/inference/infer.py` should be updated after training
47 | * Run the following command on your terminal
48 | 
49 | ```bash
50 | $ python3 model/LLM/inference/infer.py
51 | ```
52 | ### Sentence BERT Retrieval
53 | * Before training, place the Legal QA data in the following directory:
54 | `model/Retrieval/bert_retrieval/data`
55 | * Run the following command on your terminal
56 | 
57 | ```bash
58 | $ python3 model/Retrieval/bert_retrieval/inference.py
59 | ```
60 | ### BM25 Retrieval
61 | * Before training, place the Legal QA data in the following directory:
62 | `model/Retrieval/bm25_retrieval/all_data`
63 | * Run the following command on your terminal
64 | 
65 | ```bash
66 | $ python3 model/Retrieval/bm25_retrieval/retrieval_main.py
67 | ```
68 | ### Question Filtering Model (Koelectra)
69 | * Run the following command on your terminal
70 | 
71 | ```bash
72 | $ python3 model/Filter/infer.py
73 | ```
74 | ## ⌨️ How To Evaluate
75 | ### LLM (Large Language Model)
76 | * The model name and `use` parameter should be changed if needed
77 | * Run the following command on your terminal
78 | ```bash
79 | $ python3 model/LLM/evaluation/evaluate_mertrics.py
80 | ```
81 | 
--------------------------------------------------------------------------------
/model/Retrieval/bert_retrieval/data_preprocessing.py:
--------------------------------------------------------------------------------
1 | import pandas as pd
2 | from datasets import concatenate_datasets, Dataset
3 | from datasets import load_dataset
4 | import os
5 | import pickle
6 | import numpy as np
7 | from sentence_transformers import SentenceTransformer
8 | from tqdm import tqdm
9 | 
10 | 
11 | class Autodata:
12 |     def __init__(self, data_folder="./data"):
13 |         self.data_folder = data_folder
14 |         self.concat_dataset = self.concat_datasets(self.data_folder)
15 | 
16 |     def concat_datasets(self, data_folder):
17 |         datasets = []
18 |         for file_name in os.listdir(data_folder):
19 |             if file_name.endswith(".csv"):
20 |                 file_path = os.path.join(data_folder, file_name)
21 |                 dataset = pd.read_csv(file_path)
22 |                 dataframe = dataset[["question", "answer"]]
23 |                 dataset = Dataset.from_pandas(dataframe)
24 |                 datasets.append(dataset)
25 | 
26 |         combined_dataset = concatenate_datasets(datasets)
27 |         pd_combined_dataset = pd.DataFrame(combined_dataset)
28 |         return pd_combined_dataset
29 | 
30 |     def build_vector_dataset(self, dataset, path):
31 |         dataset = np.array(dataset)
32 |         model_name = "jhgan/ko-sroberta-multitask"
33 |         model = SentenceTransformer(model_name).to("cuda:0")
34 | 
35 |         query_vector_list = []
36 | 
37 |         for i in tqdm(range(len(dataset))):
38 |             question = dataset[i][0]
39 |             query_vector = model.encode(question)
40 |             query_vector_list.append(list(query_vector))
41 | 
42 |         with open(path, "wb") as fw:
43 |             pickle.dump(query_vector_list, fw)
44 | 
45 |         with open(path, "rb") as fr:
46 |             vector_data = pickle.load(fr)
47 | 
48 |         return vector_data
49 | 
50 |     def load_vector_data(self, path="./data/query_vector.bin"):
51 |         if os.path.isfile(path):
52 |             with open(path, "rb") as fr:
53 |                 vector_data = pickle.load(fr)
54 |         else:
55 |             vector_data = self.build_vector_dataset(self.concat_dataset, path)
56 |         return vector_data
57 | 
--------------------------------------------------------------------------------
/model/Retrieval/bert_retrieval/inference.py:
--------------------------------------------------------------------------------
1 | from data_preprocessing import Autodata
2 | import os
3 | import numpy as np
4 | from sentence_transformers import SentenceTransformer
5 | from sklearn.metrics.pairwise import cosine_similarity
6 | 
7 | 
8 | def Query_BERT_infer(input):
9 |     model_name = "jhgan/ko-sroberta-multitask"
10 |     model = SentenceTransformer(model_name).to("cuda:0")
11 | 
12 |     data = Autodata("./data")
13 |     original_data = data.concat_dataset
14 |     vector_data = data.load_vector_data()
15 | 
16 |     input_vector = model.encode(input)
17 |     input_vector = np.expand_dims(input_vector, axis=0)
18 | 
19 |     cos_sim = cosine_similarity(input_vector, vector_data)
20 |     data_cosine = np.sort(cos_sim).squeeze()[::-1][:3]
21 |     top_question = np.argsort(cos_sim).squeeze()[::-1][:3]
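    # Retrieval here is a plain dense search: the incoming query is embedded with
    # the same Sentence-BERT encoder used to build the vector store, scored by
    # cosine similarity against every stored question vector, and the three most
    # similar questions are kept. The 0.6 floor applied below drops weakly
    # related matches instead of returning an arbitrary precedent.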
22 | 23 | print("유사도 : ", data_cosine) 24 | 25 | question_list = [] 26 | answer_list = [] 27 | 28 | for i, index in enumerate(top_question): 29 | if data_cosine[i] >= 0.6: 30 | question_list.append(original_data["question"][index]) 31 | answer_list.append(original_data["answer"][index]) 32 | count = 0 33 | for question, answer in zip(question_list, answer_list): 34 | print(f"유사 상담 사례 질문 {count} : {question}") 35 | print(f"유사 상담 사례 답변 {count} : {answer}") 36 | print() 37 | count += 1 38 | 39 | 40 | if __name__ == "__main__": 41 | Query_BERT_infer("제가 자동차를 운전하다 중앙선을 침범하다가 2충 추돌사고를 발생시켰습니다. 이때 무슨 법으로 처벌 받을 수 있나요?") 42 | -------------------------------------------------------------------------------- /model/Retrieval/bm25_retrieval/data_preprocessing.py: -------------------------------------------------------------------------------- 1 | import pandas as pd 2 | from datasets import concatenate_datasets, Dataset 3 | import os 4 | import json 5 | 6 | 7 | class Autodata: 8 | def __init__(self, data_folder="./data"): 9 | self.data_foloder = data_folder 10 | self.concat_dataset = self.concat_datasets(self.data_foloder) 11 | 12 | def concat_datasets(self, data_foloder): 13 | datasets = [] 14 | pd_datasets = [] 15 | for file_name in os.listdir(data_foloder): 16 | if file_name.endswith(".csv"): 17 | file_path = os.path.join(data_foloder, file_name) 18 | dataset = pd.read_csv(file_path) 19 | dataframe = dataset[["question", "answer"]] 20 | pd_datasets.append(dataframe) 21 | dataset = Dataset.from_pandas(dataframe) 22 | datasets.append(dataset) 23 | 24 | combined_dataset = concatenate_datasets(datasets) 25 | pd_combiend_dataset = pd.DataFrame(combined_dataset) 26 | 27 | return pd_combiend_dataset 28 | 29 | def make_all_data(self, data, path): 30 | df = data 31 | data_dict = {} 32 | 33 | for i in range(len(df)): 34 | key = str(i) 35 | data_dict[key] = { 36 | "question": df.iloc[i]["question"], 37 | "answer": df.iloc[i]["answer"], 38 | } 39 | 40 | with open(path, "w", encoding="utf-8") as file: 41 | json.dump(data_dict, file, ensure_ascii=False, indent=4) 42 | 43 | def load_json_data(self, path="./all_data/all_data.json"): 44 | if not os.path.isfile(path): 45 | self.make_all_data(self.concat_dataset, path) 46 | -------------------------------------------------------------------------------- /model/Retrieval/bm25_retrieval/read_json.py: -------------------------------------------------------------------------------- 1 | import json 2 | import pandas as pd 3 | 4 | 5 | with open("./all_data/wikipedia_documents.json", "r") as f: 6 | data = json.load(f) 7 | print(data["0"].keys()) 8 | 9 | df = pd.read_csv("./all_data/legal_QA.csv") 10 | 11 | # 딕셔너리 초기화 12 | data_dict = {} 13 | 14 | # 행 번호를 키로 사용하여 딕셔너리에 데이터 추가 15 | for i in range(len(df)): 16 | key = str(i) 17 | data_dict[key] = { 18 | "question": df.iloc[i]["question"], 19 | "answer": df.iloc[i]["answer"], 20 | } 21 | 22 | # 결과 출력 23 | print(data_dict["0"]) 24 | with open("./all_data/legal_QA.json", "w", encoding="utf-8") as file: 25 | json.dump(data_dict, file, ensure_ascii=False, indent=4) 26 | -------------------------------------------------------------------------------- /model/Retrieval/bm25_retrieval/retrieval.py: -------------------------------------------------------------------------------- 1 | import json 2 | import os 3 | from typing import List, Optional, Tuple, Union 4 | import numpy as np 5 | import pandas as pd 6 | from rank_bm25 import BM25L, BM25Okapi, BM25Plus 7 | 8 | 9 | def setup_bm25(parent_class): 10 | class CustomBM25(parent_class): 11 
| def __init__(self, corpus, tokenizer): 12 | super().__init__(corpus, tokenizer) 13 | 14 | def get_relevant_doc(self, query, k): 15 | query_vec = self.tokenizer(query) 16 | result = self.get_scores(query_vec) 17 | sorted_result = np.argsort(result.squeeze())[::-1] 18 | doc_score = result.squeeze()[sorted_result].tolist()[:k] 19 | doc_indices = sorted_result.tolist()[:k] 20 | return doc_score, doc_indices 21 | 22 | def get_relevant_doc_bulk(self, queries, k): 23 | doc_scores = [] 24 | doc_indices = [] 25 | for query in queries: 26 | doc_score, doc_indice = self.get_relevant_doc(query, k) 27 | doc_scores.append(doc_score) 28 | doc_indices.append(doc_indice) 29 | return doc_scores, doc_indices 30 | 31 | return CustomBM25 32 | 33 | 34 | class SparseRetrievalBM25: 35 | def __init__( 36 | self, 37 | tokenize_fn, 38 | data_path: Optional[str] = "./csv_data/", 39 | context_path: Optional[str] = "all_data.json", 40 | bm25_type: Optional[str] = "", 41 | ) -> None: 42 | self.data_path = data_path 43 | with open(os.path.join(data_path, context_path), "r", encoding="utf-8") as f: 44 | wiki = json.load(f) 45 | 46 | self.contexts = list(([v["question"] for v in wiki.values()])) 47 | self.contexts_answer = list(([v["answer"] for v in wiki.values()])) 48 | 49 | if bm25_type == "Okapi": 50 | bm25_class = setup_bm25(BM25Okapi) 51 | self.bm25 = bm25_class(self.contexts, tokenize_fn) 52 | elif bm25_type == "L": 53 | bm25_class = setup_bm25(BM25L) 54 | self.bm25 = bm25_class(self.contexts, tokenize_fn) 55 | elif bm25_type == "plus": 56 | bm25_class = setup_bm25(BM25Plus) 57 | self.bm25 = bm25_class(self.contexts, tokenize_fn) 58 | 59 | def retrieve( 60 | self, query_or_dataset: Union[str, pd.DataFrame], topk: Optional[int] = 1 61 | ) -> Union[Tuple[List, List], pd.DataFrame]: 62 | if isinstance(query_or_dataset, str): 63 | doc_scores, doc_indices = self.bm25.get_relevant_doc( 64 | query_or_dataset, k=topk 65 | ) 66 | return ( 67 | doc_scores, 68 | doc_indices, 69 | [self.contexts[doc_indices[i]] for i in range(topk)], 70 | [self.contexts_answer[doc_indices[i]] for i in range(topk)], 71 | ) 72 | -------------------------------------------------------------------------------- /model/Retrieval/bm25_retrieval/retrieval_bm25.py: -------------------------------------------------------------------------------- 1 | import os 2 | from typing import Callable, List 3 | 4 | import pandas as pd 5 | from datasets import ( 6 | DatasetDict, 7 | ) 8 | from retrieval import SparseRetrievalBM25 9 | 10 | 11 | def run_sparse_retrieval( 12 | tokenize_fn: Callable[[str], List[str]], 13 | datasets: pd.DataFrame, 14 | data_path: str = os.path.join( 15 | os.path.abspath(os.path.dirname(__file__)), "csv_data" 16 | ), 17 | context_path: str = "all_data.json", 18 | bm25: str = None, 19 | ) -> DatasetDict: 20 | assert bm25 in ["Okapi", "L", "plus"], "Invalid type for BM25 has been passed." 
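    # All three variants share BM25's saturating term-frequency scoring and differ
    # mainly in document-length handling: "Okapi" is the classic formulation, "L"
    # rescales term frequency so long documents are not over-penalized, and "plus"
    # adds a lower bound so a long document that does contain a query term can
    # never score below one that does not.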
21 | 
22 |     retriever = SparseRetrievalBM25(
23 |         tokenize_fn=tokenize_fn,
24 |         data_path=data_path,
25 |         context_path=context_path,
26 |         bm25_type=bm25,
27 |     )
28 | 
29 |     df = retriever.retrieve(datasets, topk=3)
30 |     return df
31 | 
--------------------------------------------------------------------------------
/model/Retrieval/bm25_retrieval/retrieval_main.py:
--------------------------------------------------------------------------------
1 | from transformers import AutoTokenizer
2 | from retrieval_bm25 import run_sparse_retrieval
3 | from data_preprocessing import Autodata
4 | 
5 | 
6 | def infer(input):
7 |     data_path = "./all_data"
8 |     data = Autodata(data_path)
9 |     data.load_json_data(path="./all_data/all_data.json")
10 |     tokenizer = AutoTokenizer.from_pretrained("nlpai-lab/kullm-polyglot-5.8b-v2")
11 | 
12 |     datasets = run_sparse_retrieval(
13 |         tokenize_fn=tokenizer.tokenize, data_path=data_path, datasets=input, bm25="plus"
14 |     )  # bm25 => "Okapi", "L", or "plus" (run_sparse_retrieval asserts one of these)
15 | 
16 |     print("유사도", datasets[0])
17 |     print("인덱스", datasets[1])
18 |     print("실제 질문", input)
19 | 
20 |     for question, answer in zip(datasets[2], datasets[3]):
21 |         print("유사 질문")
22 |         print(question)
23 |         print("유사 답변")
24 |         print(answer)
25 |         print("-" * 200)
26 | 
27 | 
28 | if __name__ == "__main__":
29 |     infer(
30 |         "저는 중소기업을 운영하고 있는데, 평소 저희 회사의 품질과 신용을 좋게 평가해온 동일업종의 사업가가 저희 회사의 상호(商號)를 사겠다고 합니다. 상호를 팔 수 있는지요?"
31 |     )
32 | 
--------------------------------------------------------------------------------
/model/requirements.txt:
--------------------------------------------------------------------------------
1 | torch==1.13.0
2 | git+https://github.com/huggingface/transformers.git
3 | git+https://github.com/huggingface/accelerate.git
4 | pandas==2.0.3
5 | scikit-learn==1.3.0
6 | tqdm==4.65.0
7 | datasets==2.13.1
8 | sentence-transformers==2.2.2
9 | git+https://github.com/huggingface/peft.git
10 | rank_bm25==0.2.2
11 | bitsandbytes
12 | evaluate==0.4.0
13 | 
14 | 
--------------------------------------------------------------------------------
/prototype/.gitignore:
--------------------------------------------------------------------------------
1 | # See https://help.github.com/articles/ignoring-files/ for more about ignoring files.
2 | 
3 | # dependencies
4 | /node_modules
5 | /.pnp
6 | .pnp.js
7 | yarn.lock
8 | 
9 | # testing
10 | /coverage
11 | 
12 | # production
13 | /build
14 | 
15 | # misc
16 | .DS_Store
17 | .env.local
18 | .env.development.local
19 | .env.test.local
20 | .env.production.local
21 | 
22 | npm-debug.log*
23 | yarn-debug.log*
24 | yarn-error.log*
25 | 
--------------------------------------------------------------------------------
/prototype/README.md:
--------------------------------------------------------------------------------
1 | # Requirements for Web Prototype
2 | 
3 | The commands below install the dependency packages required to run the web prototype.
4 | Please run them in order.
5 | ```bash
6 | ### install react
7 | apt install curl
8 | curl https://raw.githubusercontent.com/creationix/nvm/master/install.sh | bash
9 | source ~/.profile
10 | nvm install 16.15.1
11 | 
12 | ### Install Tailwind CSS
13 | npm install -g yarn
14 | yarn add tailwindcss postcss autoprefixer
15 | npx tailwindcss init
16 | ```
17 | 
18 | Once the installation is finished, please check that the correct versions are installed, as shown below.
19 | To run the app, see the README file inside the prototype directory.
20 | ```bash
21 | node -v # should print v16.15.1
22 | npm -v # should print v8.11.0
23 | ```
24 | 
25 | # Getting Started with Create React App
26 | 
27 | ```bash
28 | source ~/.profile # use this if the yarn command is not recognized
29 | nvm install 16.15.1
30 | yarn start
31 | ```
--------------------------------------------------------------------------------
/prototype/package.json:
--------------------------------------------------------------------------------
1 | {
2 |   "name": "prototype",
3 |   "version": "0.1.0",
4 |   "private": true,
5 |   "dependencies": {
6 |     "@testing-library/jest-dom": "^5.16.5",
7 |     "@testing-library/react": "^13.4.0",
8 |     "@testing-library/user-event": "^13.5.0",
9 |     "react": "^18.2.0",
10 |     "react-dom": "^18.2.0",
11 |     "react-scripts": "5.0.1",
12 |     "web-vitals": "^2.1.4"
13 |   },
14 |   "scripts": {
15 |     "start": "react-scripts start",
16 |     "build": "react-scripts build",
17 |     "test": "react-scripts test",
18 |     "eject": "react-scripts eject"
19 |   },
20 |   "eslintConfig": {
21 |     "extends": [
22 |       "react-app",
23 |       "react-app/jest"
24 |     ]
25 |   },
26 |   "browserslist": {
27 |     "production": [
28 |       ">0.2%",
29 |       "not dead",
30 |       "not op_mini all"
31 |     ],
32 |     "development": [
33 |       "last 1 chrome version",
34 |       "last 1 firefox version",
35 |       "last 1 safari version"
36 |     ]
37 |   },
38 |   "devDependencies": {
39 |     "autoprefixer": "^10.4.14",
40 |     "postcss": "^8.4.25",
41 |     "tailwindcss": "^3.3.2"
42 |   },
43 |   "proxy": "http://localhost:8000"
44 | }
--------------------------------------------------------------------------------
/prototype/public/index.html:
--------------------------------------------------------------------------------
[... HTML markup stripped during extraction (original lines 1-43); only the document title "Lawbot" is recoverable. ...]
--------------------------------------------------------------------------------
/prototype/public/lawbot.ico:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/boostcampaitech5/level3_nlp_finalproject-nlp-08/771de997cb0ce713826cd7a0a868280d866cd70f/prototype/public/lawbot.ico
--------------------------------------------------------------------------------
/prototype/public/lawbot.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/boostcampaitech5/level3_nlp_finalproject-nlp-08/771de997cb0ce713826cd7a0a868280d866cd70f/prototype/public/lawbot.png
--------------------------------------------------------------------------------
/prototype/public/manifest.json:
--------------------------------------------------------------------------------
1 | {
2 |   "short_name": "React App",
3 |   "name": "Create React App Sample",
4 |   "icons": [
5 |     {
6 |       "src": "lawbot.ico",
7 |       "sizes": "64x64 32x32 24x24 16x16",
8 |       "type": "image/x-icon"
9 |     },
10 |     {
11 |       "src": "lawbot.png",
12 |       "type": "image/png",
13 |       "sizes": "192x192"
14 |     },
15 |     {
16 |       "src": "lawbot.png",
17 |       "type": "image/png",
18 |       "sizes": "512x512"
19 |     }
20 |   ],
21 |   "start_url": ".",
22 |   "display": "standalone",
23 |   "theme_color": "#000000",
24 |   "background_color": "#ffffff"
25 | }
--------------------------------------------------------------------------------
/prototype/public/robots.txt:
--------------------------------------------------------------------------------
1 | # https://www.robotstxt.org/robotstxt.html
2 | User-agent: *
3 | Disallow:
4 | 
--------------------------------------------------------------------------------
/prototype/src/App.js:
--------------------------------------------------------------------------------
1 | import Header from './components/Header';
2 | import ChattingSideBar from './components/ChattingSideBar';
3 | import SimilarPrecedent from './components/SimilarPrecedent';
4 | import { useState } from 'react';
5 | 
6 | function App() {
7 | 
8 |   const [message, setMessage] = useState("");
9 |   const [sentMessage, setSentMessage] = useState("");
10 |   const [aianswer, setAianswer] = useState("")
11 |   const [precedents, setPrecedents] = useState([
12 |     {
13 |       "case_name": "",
14 |       "case_number": "",
15 |       "case_type": "",
16 |       "ref_article": "",
17 |       "url": "string"
18 |     },
19 |     {
20 |       "case_name": "",
21 |       "case_number": "",
22 |       "case_type": "",
23 |       "ref_article": "",
24 |       "url": "string"
25 |     },
26 |     {
27 |       "case_name": "",
28 |       "case_number": "",
29 |       "case_type": "",
30 |       "ref_article": "",
31 |       "url": ""
32 |     }
33 |   ])
34 | 
35 |   const messagehandler = async (e) => {
36 |     e.preventDefault();
37 |     setMessage("");
38 |     setPrecedents([
39 |       {
40 |         "case_name": "",
41 |         "case_number": "",
42 |         "case_type": "",
43 |         "ref_article": "",
44 |         "url": "string"
45 |       },
46 |       {
47 |         "case_name": "",
48 |         "case_number": "",
49 |         "case_type": "",
50 |         "ref_article": "",
51 |         "url": "string"
52 |       },
53 |       {
54 |         "case_name": "",
55 |         "case_number": "",
56 |         "case_type": "",
57 |         "ref_article": "",
58 |         "url": ""
59 |       }
60 |     ])
61 |     setAianswer("")
62 |     setSentMessage(message);
63 |     const response = await fetch('/generate', {
64 |       method: 'POST',
65 |       headers: {
66 |         'Content-Type': 'application/json',
67 |       },
68 |       body: JSON.stringify({ q_sentence: message }), // use the local value; the sentMessage state update above has not taken effect yet
69 |     });
70 |     const data = await response.json()
71 |     setAianswer(data.answer_sentence)
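    // Shape note: the /generate endpoint receives { q_sentence } and is expected
    // to answer with { answer_sentence, similar_precedent }, where
    // similar_precedent mirrors the three-element placeholder list initialized above.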
72 |     setPrecedents(data.similar_precedent)
73 |   };
74 | 
75 |   return (
[... JSX markup for the chat layout stripped during extraction (original lines 76-205). Surviving fragments show a user bubble with the avatar "U" rendering {sentMessage}, a bot bubble with the avatar "L" rendering {aianswer}, and a text input wired up with onChange={(e) => setMessage(e.target.value)} and submitted through messagehandler; the imported Header, ChattingSideBar, and SimilarPrecedent components are presumably rendered inside this block as well. ...]
206 |   );
207 | }
208 | 
209 | export default App;
210 | 
--------------------------------------------------------------------------------
/prototype/src/asset/lawbot.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/boostcampaitech5/level3_nlp_finalproject-nlp-08/771de997cb0ce713826cd7a0a868280d866cd70f/prototype/src/asset/lawbot.png
--------------------------------------------------------------------------------
/prototype/src/components/ChattingSideBar.js:
--------------------------------------------------------------------------------
1 | function ChattingSideBar() {
2 |   return (
[... JSX markup stripped during extraction (original lines 3-63) ...]
64 |   )
65 | }
66 | export default ChattingSideBar
67 | 
--------------------------------------------------------------------------------
/prototype/src/components/Header.js:
--------------------------------------------------------------------------------
1 | import lawbot from "../asset/lawbot.png"
2 | 
3 | function Header() {
4 |   return (
[... JSX markup stripped during extraction (original lines 5-16); presumably renders the imported lawbot logo ...]
17 |   )
18 | }
19 | export default Header
20 | 
--------------------------------------------------------------------------------
/prototype/src/components/SimilarPrecedent.js:
--------------------------------------------------------------------------------
1 | import PrecedentCard from "./SimilarPrecedentComponents/PrecedentCard"
2 | 
3 | function SimilarPrecedent({ precedents }) {
4 |   const precedents_best = precedents[0]
5 |   const precedents_second = precedents[1]
6 |   const precedents_third = precedents[2]
7 | 
8 |   return (
[... JSX markup stripped during extraction (original lines 9-23); presumably three PrecedentCard elements fed with precedents_best, precedents_second, and precedents_third — a stray text node "s" was the only surviving content ...]
24 |   )
25 | }
26 | export default SimilarPrecedent
27 | 
--------------------------------------------------------------------------------
/prototype/src/components/SimilarPrecedentComponents/PrecedentCard.js:
--------------------------------------------------------------------------------
1 | function PrecedentCard({ precedent, number }) {
2 |   return (
[... JSX markup stripped during extraction (original lines 3-15). Surviving text nodes, in order: "유사 판례 조항 {number}" (similar precedent {number}), "사건 케이스 : {precedent.case_name}" (case name), "사건 번호 : {precedent.case_number}" (case number), "사건 분류 : {precedent.case_type}" (case type), "관련 법 조항 :{precedent.ref_article}" (referenced statutes), and a "Read more" link, which presumably points at precedent.url. ...]
16 | ) 17 | } 18 | export default PrecedentCard 19 | -------------------------------------------------------------------------------- /prototype/src/index.css: -------------------------------------------------------------------------------- 1 | @tailwind base; 2 | @tailwind components; 3 | @tailwind utilities; -------------------------------------------------------------------------------- /prototype/src/index.js: -------------------------------------------------------------------------------- 1 | import React from 'react'; 2 | import ReactDOM from 'react-dom/client'; 3 | import './index.css'; 4 | import App from './App'; 5 | import reportWebVitals from './reportWebVitals'; 6 | 7 | const root = ReactDOM.createRoot(document.getElementById('root')); 8 | root.render( 9 | 10 | 11 | 12 | ); 13 | 14 | // If you want to start measuring performance in your app, pass a function 15 | // to log results (for example: reportWebVitals(console.log)) 16 | // or send to an analytics endpoint. Learn more: https://bit.ly/CRA-vitals 17 | reportWebVitals(); 18 | -------------------------------------------------------------------------------- /prototype/src/reportWebVitals.js: -------------------------------------------------------------------------------- 1 | const reportWebVitals = onPerfEntry => { 2 | if (onPerfEntry && onPerfEntry instanceof Function) { 3 | import('web-vitals').then(({ getCLS, getFID, getFCP, getLCP, getTTFB }) => { 4 | getCLS(onPerfEntry); 5 | getFID(onPerfEntry); 6 | getFCP(onPerfEntry); 7 | getLCP(onPerfEntry); 8 | getTTFB(onPerfEntry); 9 | }); 10 | } 11 | }; 12 | 13 | export default reportWebVitals; 14 | -------------------------------------------------------------------------------- /prototype/tailwind.config.js: -------------------------------------------------------------------------------- 1 | /** @type {import('tailwindcss').Config} */ 2 | module.exports = { 3 | content: ["./src/**/*.{html,js}"], 4 | theme: { 5 | extend: {}, 6 | }, 7 | plugins: [] 8 | } 9 | 10 | --------------------------------------------------------------------------------