├── data_pipeline
│   ├── .gitignore
│   ├── generate
│   │   ├── prompts.pkl
│   │   ├── generate_bard.py
│   │   ├── generate_gpt.py
│   │   ├── backup_prompts.py
│   │   └── parse.py
│   ├── requirements.txt
│   ├── utils.py
│   ├── README.md
│   ├── preprocessor.py
│   ├── spellchecker.py
│   ├── preprocessor_v2.py
│   ├── qa_crawler.py
│   └── crawler.py
├── frontend
│   ├── src
│   │   ├── index.css
│   │   ├── asset
│   │   │   ├── lawbot.png
│   │   │   └── spinner.gif
│   │   ├── reportWebVitals.js
│   │   ├── index.js
│   │   ├── components
│   │   │   ├── TypingAnimation.js
│   │   │   ├── Header.js
│   │   │   ├── SimilarPrecedent.js
│   │   │   ├── SimilarPrecedentComponents
│   │   │   │   └── PrecedentCard.js
│   │   │   ├── Loader.js
│   │   │   └── ChattingSideBar.js
│   │   └── App.js
│   ├── public
│   │   ├── robots.txt
│   │   ├── lawbot.ico
│   │   ├── lawbot.png
│   │   ├── manifest.json
│   │   └── index.html
│   ├── tailwind.config.js
│   ├── .gitignore
│   ├── package.json
│   └── README.md
├── prototype
│   ├── src
│   │   ├── index.css
│   │   ├── asset
│   │   │   └── lawbot.png
│   │   ├── reportWebVitals.js
│   │   ├── index.js
│   │   ├── components
│   │   │   ├── Header.js
│   │   │   ├── SimilarPrecedent.js
│   │   │   ├── SimilarPrecedentComponents
│   │   │   │   └── PrecedentCard.js
│   │   │   └── ChattingSideBar.js
│   │   └── App.js
│   ├── public
│   │   ├── robots.txt
│   │   ├── lawbot.ico
│   │   ├── lawbot.png
│   │   ├── manifest.json
│   │   └── index.html
│   ├── tailwind.config.js
│   ├── .gitignore
│   ├── README.md
│   └── package.json
├── model
│   ├── .gitignore
│   ├── Filter
│   │   ├── utils.py
│   │   ├── infer.py
│   │   ├── data_preprocessing.py
│   │   ├── train.py
│   │   └── dataloader.py
│   ├── BERT
│   │   ├── inference
│   │   │   ├── utils.py
│   │   │   └── inference.py
│   │   ├── preprocessing.py
│   │   └── make_vector_dataset
│   │       └── preprocessing_law_data.py
│   ├── requirements.txt
│   ├── LLM
│   │   ├── train
│   │   │   ├── utils.py
│   │   │   ├── load_model.py
│   │   │   ├── train.py
│   │   │   └── data_preprocessing.py
│   │   ├── evaluation
│   │   │   ├── ppl.py
│   │   │   ├── evaluate_mertrics.py
│   │   │   ├── data_preprocessing.py
│   │   │   ├── dialogue_evaluation.py
│   │   │   └── petf_ppl.py
│   │   └── inference
│   │       └── infer.py
│   ├── Retrieval
│   │   ├── bm25_retrieval
│   │   │   ├── read_json.py
│   │   │   ├── retrieval_bm25.py
│   │   │   ├── retrieval_main.py
│   │   │   ├── data_preprocessing.py
│   │   │   └── retrieval.py
│   │   └── bert_retrieval
│   │       ├── inference.py
│   │       └── data_preprocessing.py
│   └── README.md
├── backend
│   ├── .gitignore
│   ├── Dockerfile
│   ├── tests
│   │   ├── test_retrieval.py
│   │   ├── test_filter.py
│   │   ├── test_generate.py
│   │   └── test_search.py
│   ├── requirements.txt
│   ├── README.md
│   ├── app
│   │   ├── filter.py
│   │   ├── generate.py
│   │   ├── search.py
│   │   ├── main.py
│   │   ├── bert_retrieval.py
│   │   └── bm25_retrieval.py
│   ├── router
│   │   └── router.py
│   └── airflow
│       ├── module
│       │   ├── load_data.py
│       │   └── train_model.py
│       └── dags
│           └── training_pipeline.py
├── .github
│   ├── PULL_REQUEST_TEMPLATE.md
│   ├── ISSUE_TEMPLATE
│   │   └── feature_request.md
│   └── workflows
│       └── backendCI.yml
└── README.md
/data_pipeline/.gitignore:
--------------------------------------------------------------------------------
1 | local/*
2 | data/*
3 | artifact/*
4 | __pycache__
5 | *.log
--------------------------------------------------------------------------------
/frontend/src/index.css:
--------------------------------------------------------------------------------
1 | @tailwind base;
2 | @tailwind components;
3 | @tailwind utilities;
--------------------------------------------------------------------------------
/prototype/src/index.css:
--------------------------------------------------------------------------------
1 | @tailwind base;
2 | @tailwind components;
3 | @tailwind utilities;
--------------------------------------------------------------------------------
/frontend/public/robots.txt:
--------------------------------------------------------------------------------
1 | # https://www.robotstxt.org/robotstxt.html
2 | User-agent: *
3 | Disallow:
4 |
--------------------------------------------------------------------------------
/prototype/public/robots.txt:
--------------------------------------------------------------------------------
1 | # https://www.robotstxt.org/robotstxt.html
2 | User-agent: *
3 | Disallow:
4 |
--------------------------------------------------------------------------------
/frontend/public/lawbot.ico:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/boostcampaitech5/level3_nlp_finalproject-nlp-08/HEAD/frontend/public/lawbot.ico
--------------------------------------------------------------------------------
/frontend/public/lawbot.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/boostcampaitech5/level3_nlp_finalproject-nlp-08/HEAD/frontend/public/lawbot.png
--------------------------------------------------------------------------------
/prototype/public/lawbot.ico:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/boostcampaitech5/level3_nlp_finalproject-nlp-08/HEAD/prototype/public/lawbot.ico
--------------------------------------------------------------------------------
/prototype/public/lawbot.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/boostcampaitech5/level3_nlp_finalproject-nlp-08/HEAD/prototype/public/lawbot.png
--------------------------------------------------------------------------------
/frontend/src/asset/lawbot.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/boostcampaitech5/level3_nlp_finalproject-nlp-08/HEAD/frontend/src/asset/lawbot.png
--------------------------------------------------------------------------------
/frontend/src/asset/spinner.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/boostcampaitech5/level3_nlp_finalproject-nlp-08/HEAD/frontend/src/asset/spinner.gif
--------------------------------------------------------------------------------
/prototype/src/asset/lawbot.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/boostcampaitech5/level3_nlp_finalproject-nlp-08/HEAD/prototype/src/asset/lawbot.png
--------------------------------------------------------------------------------
/data_pipeline/generate/prompts.pkl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/boostcampaitech5/level3_nlp_finalproject-nlp-08/HEAD/data_pipeline/generate/prompts.pkl
--------------------------------------------------------------------------------
/model/.gitignore:
--------------------------------------------------------------------------------
1 | LLM/train/data
2 | LLM/train/val_data
3 | ../.idea
4 | BERT/data
5 | Retrieval/bert_retrieval/data
6 | Retrieval/bm25_retrieval/all_data
7 | Filter/.idea
8 | Filter/data
--------------------------------------------------------------------------------
/data_pipeline/requirements.txt:
--------------------------------------------------------------------------------
1 | pandas==2.0.3
2 | beautifulsoup4==4.12.2
3 | scikit-learn==1.1.2
4 | selenium==4.10.0
5 | webdriver-manager==3.8.6
6 | tqdm==4.65.0
7 | openai==0.27.8
8 | bardapi==0.1.27
--------------------------------------------------------------------------------
/backend/.gitignore:
--------------------------------------------------------------------------------
1 | final_project
2 | __pycache__
3 | model
4 | test.ipynb
5 | data
6 | nohup.out
7 | logs
8 | airflow.cfg
9 | airflow.db
10 | webserver_config.py
11 | .env
12 | airflow-webserver.pid
--------------------------------------------------------------------------------
/prototype/tailwind.config.js:
--------------------------------------------------------------------------------
1 | /** @type {import('tailwindcss').Config} */
2 | module.exports = {
3 | content: ["./src/**/*.{html,js}"],
4 | theme: {
5 | extend: {},
6 | },
7 | plugins: []
8 | }
9 |
10 |
--------------------------------------------------------------------------------
/.github/PULL_REQUEST_TEMPLATE.md:
--------------------------------------------------------------------------------
1 | ## Overview
2 | -
3 |
4 | ## Change Log
5 | -
6 |
7 | ## Further information
8 | -
9 |
10 | ## To Reviewer
11 | -
12 |
13 | ## Issue Tags
14 | - Closed | Fixed: #
15 | - See also : #
16 |
--------------------------------------------------------------------------------
/frontend/tailwind.config.js:
--------------------------------------------------------------------------------
1 | /** @type {import('tailwindcss').Config} */
2 | module.exports = {
3 | content: ["./src/**/*.{html,js}"],
4 | theme: {
5 | extend: {},
6 | },
7 | plugins: [
8 | require('tailwindcss-animated')
9 | ]
10 | }
11 |
12 |
--------------------------------------------------------------------------------
/.github/ISSUE_TEMPLATE/feature_request.md:
--------------------------------------------------------------------------------
1 | ---
2 | name: Feature request
3 | about: Suggest an idea for this project
4 | title: "[FEAT]"
5 | labels: ''
6 | assignees: ''
7 |
8 | ---
9 |
10 | ## Background
11 | -
12 |
13 | ## Todo
14 | - [ ] Todo 1
15 | - [ ] Todo 2
16 |
--------------------------------------------------------------------------------
/backend/Dockerfile:
--------------------------------------------------------------------------------
1 | FROM python:3.8
2 |
3 | RUN python3 -m pip install --upgrade pip
4 |
5 | COPY ./requirements.txt /ws/requirements.txt
6 |
7 | WORKDIR /ws
8 |
9 | RUN pip install -r requirements.txt
10 |
11 | COPY ./app/ /ws
12 |
13 | CMD ["uvicorn", "main:app", "--host", "0.0.0.0"]
--------------------------------------------------------------------------------
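A build-and-run sketch for the Dockerfile above; the image tag `lawbot-backend` is illustrative, and uvicorn serves on its default port 8000:

```bash
# run from the backend/ directory, where the Dockerfile lives
docker build -t lawbot-backend .
docker run -p 8000:8000 lawbot-backend
```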
/backend/tests/test_retrieval.py:
--------------------------------------------------------------------------------
1 | import os
2 | import sys
3 |
4 | sys.path.append(os.path.dirname(os.path.abspath(os.path.dirname(__file__))))
5 |
6 | from app.bm25_retrieval import retrieve_QA
7 |
8 |
9 | def test_retrieve_QA():
10 | retrieve_QA("제가 술을 먹고 운전을 했는데 어떤 처벌을 받을까요?")
11 |
--------------------------------------------------------------------------------
/model/Filter/utils.py:
--------------------------------------------------------------------------------
1 | import evaluate
2 | import numpy as np
3 |
4 |
5 | def compute_metrics(eval_pred):
6 | f1 = evaluate.load("f1")  # the metric computed here is F1, so name the handle accordingly
7 | predictions, labels = eval_pred
8 | predictions = np.argmax(predictions, axis=1)
9 | return f1.compute(predictions=predictions, references=labels)
10 |
--------------------------------------------------------------------------------
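A quick sanity check for `compute_metrics` above; the logits and labels are toy values:

```python
import numpy as np

# two examples, two classes; argmax turns the logits into predictions [0, 1]
logits = np.array([[2.0, 0.1], [0.2, 1.5]])
labels = np.array([0, 1])
print(compute_metrics((logits, labels)))  # perfect agreement -> {'f1': 1.0}
```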
/model/BERT/inference/utils.py:
--------------------------------------------------------------------------------
1 | import os
2 | import pickle
3 |
4 |
5 | def load_vector_data(path):
6 | if os.path.isfile(path):
7 | with open(path, "rb") as fr:
8 | vector_data = pickle.load(fr)
9 | else:
10 | print("판례 데이터가 존재하지 않습니다.")
11 | vector_data = None
12 | return vector_data
13 |
--------------------------------------------------------------------------------
/backend/tests/test_filter.py:
--------------------------------------------------------------------------------
1 | import os
2 | import sys
3 |
4 | sys.path.append(os.path.dirname(os.path.abspath(os.path.dirname(__file__))))
5 |
6 | from app.filter import is_legal_question
7 |
8 |
9 | def test_is_legal_question():
10 | assert not is_legal_question("안녕하세요.")
11 | assert is_legal_question("제가 술을 먹고 운전을 했는데 어떤 처벌을 받을까요?")
12 |
--------------------------------------------------------------------------------
/model/requirements.txt:
--------------------------------------------------------------------------------
1 | torch==1.13.0
2 | git+https://github.com/huggingface/transformers.git
3 | git+https://github.com/huggingface/accelerate.git
4 | pandas==2.0.3
5 | scikit-learn==1.3.0
6 | tqdm==4.65.0
7 | datasets==2.13.1
8 | sentence-transformers==2.2.2
9 | git+https://github.com/huggingface/peft.git
10 | rank_bm25==0.2.2
11 | bitsandbytes
12 | evaluate==0.4.0
13 |
14 |
--------------------------------------------------------------------------------
/backend/requirements.txt:
--------------------------------------------------------------------------------
1 | fastapi==0.99.1
2 | uvicorn==0.22.0
3 | transformers==4.30.2
4 | torch==2.0.1
5 | accelerate==0.20.3
6 | pytest==7.4.0
7 | pandas==2.0.3
8 | scikit-learn==1.3.0
9 | sentence-transformers==2.2.2
10 | git+https://github.com/huggingface/peft.git
11 | rank-bm25==0.2.2
12 | datasets==2.13.1
13 | apache-airflow==2.2.3
14 | MarkupSafe==2.0.1
15 | python-dotenv==1.0.0
16 | bitsandbytes==0.41.0
17 | dnspython==2.3.0
--------------------------------------------------------------------------------
/frontend/.gitignore:
--------------------------------------------------------------------------------
1 | # See https://help.github.com/articles/ignoring-files/ for more about ignoring files.
2 |
3 | # dependencies
4 | /node_modules
5 | /.pnp
6 | .pnp.js
7 |
8 | # testing
9 | /coverage
10 |
11 | # production
12 | /build
13 |
14 | # misc
15 | .DS_Store
16 | .env.local
17 | .env.development.local
18 | .env.test.local
19 | .env.production.local
20 |
21 | npm-debug.log*
22 | yarn-debug.log*
23 | yarn-error.log*
24 | yarn.lock
--------------------------------------------------------------------------------
/frontend/src/reportWebVitals.js:
--------------------------------------------------------------------------------
1 | const reportWebVitals = onPerfEntry => {
2 | if (onPerfEntry && onPerfEntry instanceof Function) {
3 | import('web-vitals').then(({ getCLS, getFID, getFCP, getLCP, getTTFB }) => {
4 | getCLS(onPerfEntry);
5 | getFID(onPerfEntry);
6 | getFCP(onPerfEntry);
7 | getLCP(onPerfEntry);
8 | getTTFB(onPerfEntry);
9 | });
10 | }
11 | };
12 |
13 | export default reportWebVitals;
14 |
--------------------------------------------------------------------------------
/prototype/.gitignore:
--------------------------------------------------------------------------------
1 | # See https://help.github.com/articles/ignoring-files/ for more about ignoring files.
2 |
3 | # dependencies
4 | /node_modules
5 | /.pnp
6 | .pnp.js
7 | yarn.lock
8 |
9 | # testing
10 | /coverage
11 |
12 | # production
13 | /build
14 |
15 | # misc
16 | .DS_Store
17 | .env.local
18 | .env.development.local
19 | .env.test.local
20 | .env.production.local
21 |
22 | npm-debug.log*
23 | yarn-debug.log*
24 | yarn-error.log*
25 |
--------------------------------------------------------------------------------
/prototype/src/reportWebVitals.js:
--------------------------------------------------------------------------------
1 | const reportWebVitals = onPerfEntry => {
2 | if (onPerfEntry && onPerfEntry instanceof Function) {
3 | import('web-vitals').then(({ getCLS, getFID, getFCP, getLCP, getTTFB }) => {
4 | getCLS(onPerfEntry);
5 | getFID(onPerfEntry);
6 | getFCP(onPerfEntry);
7 | getLCP(onPerfEntry);
8 | getTTFB(onPerfEntry);
9 | });
10 | }
11 | };
12 |
13 | export default reportWebVitals;
14 |
--------------------------------------------------------------------------------
/model/LLM/train/utils.py:
--------------------------------------------------------------------------------
1 | def print_trainable_parameters(model):
2 | """
3 | Prints the number of trainable parameters in the model.
4 | """
5 | trainable_params = 0
6 | all_param = 0
7 | for _, param in model.named_parameters():
8 | all_param += param.numel()
9 | if param.requires_grad:
10 | trainable_params += param.numel()
11 | print(
12 | f"trainable params: {trainable_params} || all params: {all_param} || trainable%: {100 * trainable_params / all_param}"
13 | )
14 |
--------------------------------------------------------------------------------
/frontend/public/manifest.json:
--------------------------------------------------------------------------------
1 | {
2 | "short_name": "React App",
3 | "name": "Create React App Sample",
4 | "icons": [
5 | {
6 | "src": "lawbot.ico",
7 | "sizes": "64x64 32x32 24x24 16x16",
8 | "type": "image/x-icon"
9 | },
10 | {
11 | "src": "lawbot.png",
12 | "type": "image/png",
13 | "sizes": "192x192"
14 | },
15 | {
16 | "src": "lawbot.png",
17 | "type": "image/png",
18 | "sizes": "512x512"
19 | }
20 | ],
21 | "start_url": ".",
22 | "display": "standalone",
23 | "theme_color": "#000000",
24 | "background_color": "#ffffff"
25 | }
--------------------------------------------------------------------------------
/frontend/src/index.js:
--------------------------------------------------------------------------------
1 | import React from 'react';
2 | import ReactDOM from 'react-dom/client';
3 | import './index.css';
4 | import App from './App';
5 | import reportWebVitals from './reportWebVitals';
6 |
7 | const root = ReactDOM.createRoot(document.getElementById('root'));
8 | root.render(
9 |   <React.StrictMode>
10 |     <App />
11 |   </React.StrictMode>
12 | );
13 |
14 | // If you want to start measuring performance in your app, pass a function
15 | // to log results (for example: reportWebVitals(console.log))
16 | // or send to an analytics endpoint. Learn more: https://bit.ly/CRA-vitals
17 | reportWebVitals();
18 |
--------------------------------------------------------------------------------
/prototype/public/manifest.json:
--------------------------------------------------------------------------------
1 | {
2 | "short_name": "React App",
3 | "name": "Create React App Sample",
4 | "icons": [
5 | {
6 | "src": "lawbot.ico",
7 | "sizes": "64x64 32x32 24x24 16x16",
8 | "type": "image/x-icon"
9 | },
10 | {
11 | "src": "lawbot.png",
12 | "type": "image/png",
13 | "sizes": "192x192"
14 | },
15 | {
16 | "src": "lawbot.png",
17 | "type": "image/png",
18 | "sizes": "512x512"
19 | }
20 | ],
21 | "start_url": ".",
22 | "display": "standalone",
23 | "theme_color": "#000000",
24 | "background_color": "#ffffff"
25 | }
--------------------------------------------------------------------------------
/prototype/src/index.js:
--------------------------------------------------------------------------------
1 | import React from 'react';
2 | import ReactDOM from 'react-dom/client';
3 | import './index.css';
4 | import App from './App';
5 | import reportWebVitals from './reportWebVitals';
6 |
7 | const root = ReactDOM.createRoot(document.getElementById('root'));
8 | root.render(
9 |   <React.StrictMode>
10 |     <App />
11 |   </React.StrictMode>
12 | );
13 |
14 | // If you want to start measuring performance in your app, pass a function
15 | // to log results (for example: reportWebVitals(console.log))
16 | // or send to an analytics endpoint. Learn more: https://bit.ly/CRA-vitals
17 | reportWebVitals();
18 |
--------------------------------------------------------------------------------
/frontend/src/components/TypingAnimation.js:
--------------------------------------------------------------------------------
1 | import React, { useState, useEffect } from "react";
2 |
3 | function TypingAnimation({ text }) {
4 | const [visibleText, setVisibleText] = useState("");
5 | const typingDelay = 20; // typing delay per character (ms)
6 |
7 | useEffect(() => {
8 | const typeText = (currentIndex) => {
9 | if (currentIndex < text.length) {
10 | setVisibleText(text.substring(0, currentIndex + 1));
11 | setTimeout(() => typeText(currentIndex + 1), typingDelay);
12 | }
13 | };
14 |
15 | typeText(0);
16 | }, [text]);
17 |
18 | return <p>{visibleText}</p>; // rendering element assumed; the original tag and classes were lost in extraction
19 | }
20 |
21 | export default TypingAnimation;
--------------------------------------------------------------------------------
/model/Retrieval/bm25_retrieval/read_json.py:
--------------------------------------------------------------------------------
1 | import json
2 | import pandas as pd
3 |
4 |
5 | with open("./all_data/wikipedia_documents.json", "r") as f:
6 | data = json.load(f)
7 | print(data["0"].keys())
8 |
9 | df = pd.read_csv("./all_data/legal_QA.csv")
10 |
11 | # initialize the output dictionary
12 | data_dict = {}
13 |
14 | # add each row to the dictionary, keyed by row number
15 | for i in range(len(df)):
16 | key = str(i)
17 | data_dict[key] = {
18 | "question": df.iloc[i]["question"],
19 | "answer": df.iloc[i]["answer"],
20 | }
21 |
22 | # print a sample entry, then dump everything to JSON
23 | print(data_dict["0"])
24 | with open("./all_data/legal_QA.json", "w", encoding="utf-8") as file:
25 | json.dump(data_dict, file, ensure_ascii=False, indent=4)
26 |
--------------------------------------------------------------------------------
/data_pipeline/utils.py:
--------------------------------------------------------------------------------
1 | import logging
2 | import os
3 |
4 |
5 | def utilize_loggers(name):
6 | logger = logging.getLogger(__name__)
7 | logger_name = os.path.splitext(os.path.basename(name))[0]
8 |
9 | logging.basicConfig(
10 | filename=f"{logger_name}.log",
11 | format="%(asctime)s | %(levelname)s: %(message)s",
12 | level=logging.INFO,
13 | datefmt="%Y/%m/%d %I:%M:%S %p",
14 | )
15 |
16 | stream_handler = logging.StreamHandler()
17 | stream_handler.setLevel(logging.INFO)
18 | stream_handler.setFormatter(
19 | logging.Formatter("%(asctime)s | %(levelname)s: %(message)s")
20 | )
21 | logger.addHandler(stream_handler)
22 | return logger
23 |
24 |
--------------------------------------------------------------------------------
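A usage sketch for `utilize_loggers` above: passing `__file__` makes the log-file name follow the calling script, since the function takes the basename of whatever path it is given (so `crawler.py` logs to `crawler.log`):

```python
# e.g. at the top of data_pipeline/crawler.py
from utils import utilize_loggers

logger = utilize_loggers(__file__)
logger.info("crawler started")
```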
/model/BERT/preprocessing.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import pandas as pd
3 | from sentence_transformers import SentenceTransformer
4 | from tqdm import tqdm
5 | import os
6 |
7 | model_name = "jhgan/ko-sroberta-multitask"
8 | model = SentenceTransformer(model_name)
9 | model.to("cuda:0")
10 |
11 | df = pd.read_csv("./data/law_data.csv")
12 | df = df.dropna(
13 | subset=["caseName", "judgementAbstract", "precedentText", "judgementNote"]
14 | )
15 | np_df = np.array(df)
16 |
17 | Ab_list = []
18 | for i in tqdm(range(len(np_df))):
19 | Ab = np_df[i][4]
20 | Ab_query = model.encode(Ab)
21 | Ab_list.append(list(Ab_query))
22 |
23 | df["Ab_vector"] = Ab_list
24 | df.to_csv("Ab_vector_law_data_sroberta.csv", index=False)
25 |
--------------------------------------------------------------------------------
/prototype/README.md:
--------------------------------------------------------------------------------
1 | # Requirements for Web Prototype
2 |
3 | The commands below install the dependencies required to run the web prototype.
4 | Please run them in order.
5 | ```bash
6 | ### install react
7 | apt install curl
8 | curl https://raw.githubusercontent.com/creationix/nvm/master/install.sh | bash
9 | source ~/.profile
10 | nvm install 16.15.1
11 |
12 | ### Install Tailwind CSS
13 | npm install -g yarn
14 | yarn add tailwindcss postcss autoprefixer
15 | npx tailwindcss init
16 | ```
17 |
18 | Once the installation is complete, please check that the correct versions are installed.
19 | To run the app, see the README file inside the prototype directory.
20 | ```bash
21 | node -v # should print v16.15.1
22 | npm -v # should print 8.11.0
23 | ```
24 |
25 | # Getting Started with Create React App
26 |
27 | ```bash
28 | source ~/.profile # run this if the yarn command is not recognized
29 | nvm install 16.15.1
30 | yarn start
31 | ```
--------------------------------------------------------------------------------
/backend/README.md:
--------------------------------------------------------------------------------
1 | # LawBot - Backend
2 |
3 | ## 1. Environment
4 | * FastAPI
5 | * python 3.8.x (tested on 3.8.5)
6 |
7 | ## 2. Install
8 |
9 | * Create your own virtual environment.
10 | ```bash
11 | $ pip install virtualenv
12 | $ virtualenv <venv name>
13 | $ source <venv name>/bin/activate
14 | ```
15 |
16 | * Install modules on your virtual environment.
17 | ```bash
18 | $ pip install -r requirements.txt
19 | ```
20 |
21 | ## 3. Execute
22 |
23 | ### Model Server
24 | ```bash
25 | $ cd app
26 | $ uvicorn main:app --host=0.0.0.0 --reload
27 | ```
28 | ### Test
29 | ```bash
30 | $ pytest
31 | ```
32 | ### Airflow
33 | ```bash
34 | $ cd airflow
35 | $ airflow db init
36 | $ airflow scheduler
37 | ```
38 | ## 4. Document
39 | 1. Run the server locally.
40 | 2. Go to http://localhost:8000/docs
41 |
--------------------------------------------------------------------------------
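A quick smoke test for the server above, without the docs UI: POST to the `/generate` endpoint with the request shape defined in `backend/router/router.py` (the question is one of the repo's own test sentences):

```bash
curl -X POST http://localhost:8000/generate \
  -H "Content-Type: application/json" \
  -d '{"q_sentence": "제가 술을 먹고 운전을 했는데 어떤 처벌을 받을까요?"}'
```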
/prototype/src/components/Header.js:
--------------------------------------------------------------------------------
1 | import lawbot from "../asset/lawbot.png"
2 |
3 | function Header() {
4 |   return (
5 |     <header>
6 |       {/* minimal markup: the original Tailwind classes were lost in extraction */}
7 |       <img src={lawbot} alt="LawBot" />
8 |       <span>LawBot</span>
9 |     </header>
10 |   )
11 | }
12 | export default Header
13 |
--------------------------------------------------------------------------------
/model/Retrieval/bm25_retrieval/retrieval_bm25.py:
--------------------------------------------------------------------------------
1 | import os
2 | from typing import Callable, List
3 |
4 | import pandas as pd
5 | from datasets import (
6 | DatasetDict,
7 | )
8 | from retrieval import SparseRetrievalBM25
9 |
10 |
11 | def run_sparse_retrieval(
12 | tokenize_fn: Callable[[str], List[str]],
13 | datasets: pd.DataFrame,
14 | data_path: str = os.path.join(
15 | os.path.abspath(os.path.dirname(__file__)), "csv_data"
16 | ),
17 | context_path: str = "all_data.json",
18 | bm25: str = None,
19 | ) -> DatasetDict:
20 | assert bm25 in ["Okapi", "L", "plus"], "Invalid type for BM25 has been passed."
21 |
22 | retriever = SparseRetrievalBM25(
23 | tokenize_fn=tokenize_fn,
24 | data_path=data_path,
25 | context_path=context_path,
26 | bm25_type=bm25,
27 | )
28 |
29 | df = retriever.retrieve(datasets, topk=3)
30 | return df
31 |
--------------------------------------------------------------------------------
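`run_sparse_retrieval` above delegates scoring to `SparseRetrievalBM25` in `retrieval.py` (not shown here). For orientation, a minimal sketch of the underlying `rank_bm25` package pinned in `model/requirements.txt`; the corpus and query are illustrative, and the real code tokenizes with `tokenizer.tokenize` rather than `split()`:

```python
from rank_bm25 import BM25Okapi, BM25L, BM25Plus  # the "Okapi" / "L" / "plus" variants

corpus = ["음주운전 처벌 기준", "상호 양도 계약", "교통사고 합의금"]
tokenized_corpus = [doc.split() for doc in corpus]

bm25 = BM25Plus(tokenized_corpus)  # corresponds to bm25="plus"
scores = bm25.get_scores("음주운전 처벌".split())
print(scores)  # one relevance score per document
```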
/backend/tests/test_generate.py:
--------------------------------------------------------------------------------
1 | import os
2 | import sys
3 |
4 | import torch
5 |
6 | sys.path.append(os.path.dirname(os.path.abspath(os.path.dirname(__file__))))
7 |
8 | from app.generate import generate_answer
9 | from peft import PeftConfig, PeftModel
10 | from transformers import AutoModelForCausalLM, AutoTokenizer
11 |
12 |
13 | def test_generate_answer():
14 | q_sentence = "제가 술을 먹고 운전을 했는데 어떤 처벌을 받을까요?"
15 | peft_model_id = "YoonSeul/LawBot-level-3-KuLLM-5.8B-tae-2epoch"
16 | config = PeftConfig.from_pretrained(peft_model_id)
17 | model = AutoModelForCausalLM.from_pretrained(
18 | config.base_model_name_or_path, device_map={"": 0}, torch_dtype=torch.float16
19 | )
20 | model = PeftModel.from_pretrained(model, peft_model_id, torch_dtype=torch.float16)
21 | tokenizer = AutoTokenizer.from_pretrained(config.base_model_name_or_path)
22 | generate_answer(q_sentence=q_sentence, model=model, tokenizer=tokenizer)
23 |
--------------------------------------------------------------------------------
/frontend/src/components/Header.js:
--------------------------------------------------------------------------------
1 | import lawbot from "../asset/lawbot.png"
2 |
3 | function Header() {
4 |   return (
5 |     <header>
6 |       {/* minimal markup: the original Tailwind classes were lost in extraction */}
7 |       <img src={lawbot} alt="LawBot" />
8 |       <span>LawBot</span>
9 |     </header>
10 |   )
11 | }
12 | export default Header
13 |
--------------------------------------------------------------------------------
/model/BERT/make_vector_dataset/preprocessing_law_data.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import pandas as pd
3 | from sentence_transformers import SentenceTransformer
4 | from tqdm import tqdm
5 | import pickle
6 |
7 | model_name = "jhgan/ko-sroberta-multitask"
8 | model = SentenceTransformer(model_name)
9 | model.to("cuda:0")
10 |
11 | df = pd.read_csv("../data/law_data/law_data.csv", encoding="UTF-8")
12 | df = df.dropna(
13 | subset=["caseName", "judgementAbstract", "precedentText", "judgementNote"]
14 | )
15 | df.to_csv("../data/law_data/law_data_drop.csv", index=False)
16 | np_df = np.array(df)
17 |
18 | vector_list = []
19 | for i in tqdm(range(len(np_df))):
20 | judgementAbstract = np_df[i][4]
21 | judgementNote = np_df[i][9]
22 |
23 | judgementNote_vector = model.encode(judgementAbstract + judgementNote)
24 | vector_list.append(list(judgementNote_vector))
25 |
26 | with open("../data/law_data/law_data_drop_vector.bin", "wb") as fw:
27 | pickle.dump(vector_list, fw)
28 |
--------------------------------------------------------------------------------
/backend/app/filter.py:
--------------------------------------------------------------------------------
1 | import time
2 |
3 | import torch.nn.functional as F
4 | from transformers import AutoModelForSequenceClassification, AutoTokenizer
5 |
6 |
7 | def is_legal_question(q_sentence):
8 | start_time = time.time()
9 | base_model_name = "monologg/koelectra-small-v3-discriminator"
10 | model = AutoModelForSequenceClassification.from_pretrained(
11 | "kfkas/legal-question-filter-koelectra",
12 | num_labels=2,
13 | ignore_mismatched_sizes=True,
14 | )
15 | tokenizer = AutoTokenizer.from_pretrained(base_model_name)
16 | inputs = tokenizer(q_sentence, padding=True, truncation=True, return_tensors="pt")
17 | outputs = model(**inputs)
18 | logits = outputs.logits.detach().cpu()
19 | pr = F.softmax(logits, dim=1).numpy()
20 | # arg = np.argmax(pr, axis=1)
21 | # print(logits)
22 | # print(pr)
23 | # print(int(arg))
24 |
25 | print(f"filter time: {time.time() - start_time}")
26 |
27 | if pr[0][0] >= 0.95:
28 | return True
29 | return False
30 |
--------------------------------------------------------------------------------
/model/Filter/infer.py:
--------------------------------------------------------------------------------
1 | from transformers import AutoModelForSequenceClassification, AutoTokenizer
2 | import numpy as np
3 | import torch.nn.functional as F
4 |
5 |
6 | def infer():
7 | text = "안녕하세요 김주원입니다."
8 | base_model_name = "monologg/koelectra-small-v3-discriminator"
9 | model = AutoModelForSequenceClassification.from_pretrained(
10 | "kfkas/legal-question-filter-koelectra",
11 | num_labels=2,
12 | ignore_mismatched_sizes=True,
13 | )
14 | tokenizer = AutoTokenizer.from_pretrained(base_model_name)
15 | inputs = tokenizer(text, padding=True, truncation=True, return_tensors="pt")
16 | outputs = model(**inputs)
17 | logits = outputs.logits.detach().cpu()
18 | pr = F.softmax(logits, dim=1).numpy()  # dim=1: softmax over the two class logits
19 | arg = np.argmax(pr, axis=1)
20 | print(logits)
21 | print(pr)
22 | print(int(arg))
23 | if int(arg) == 0 and (pr[0][0] >= 0.98).all():
24 | print("법률입니다")
25 | else:
26 | print("법률 질문이 아닙니다(bard,chatgpt 등등 다른 API 호출하면 좋을거 같아용)")
27 |
28 |
29 | if __name__ == "__main__":
30 | infer()
31 |
--------------------------------------------------------------------------------
/frontend/src/components/SimilarPrecedent.js:
--------------------------------------------------------------------------------
1 | import PrecedentCard from "./SimilarPrecedentComponents/PrecedentCard";
2 |
3 | function SimilarPresdent({ precedents }) {
4 | const validPrecedents = precedents || [];
5 |
6 |   return (
7 |     <div>
8 |       {/* minimal markup: the original Tailwind classes were lost in extraction */}
9 |       {validPrecedents.map((precedent, index) => (
10 |         <PrecedentCard key={index} precedent={precedent} number={index + 1} />
11 |       ))}
12 |     </div>
13 |   );
14 | }
15 |
16 | export default SimilarPresdent;
--------------------------------------------------------------------------------
/model/Retrieval/bm25_retrieval/retrieval_main.py:
--------------------------------------------------------------------------------
1 | from transformers import AutoTokenizer
2 | from retrieval_bm25 import run_sparse_retrieval
3 | from data_preprocessing import Autodata
4 |
5 |
6 | def infer(input):
7 | data_path = "./all_data"
8 | data = Autodata(data_path)
9 | data.load_json_data(path="./all_data/all_data.json")
10 | tokenizer = AutoTokenizer.from_pretrained("nlpai-lab/kullm-polyglot-5.8b-v2")
11 |
12 | datasets = run_sparse_retrieval(
13 | tokenize_fn=tokenizer.tokenize, data_path=data_path, datasets=input, bm25="plus"
14 | ) # bm25 => None(TF-IDF), Okapi, L, plus
15 |
16 | print("유사도", datasets[0])
17 | print("인덱스", datasets[1])
18 | print("실제 질문", input)
19 |
20 | for question, answer in zip(datasets[2], datasets[3]):
21 | print("유사 질문")
22 | print(question)
23 | print("유사 답변")
24 | print(answer)
25 | print("-" * 200)
26 |
27 |
28 | if __name__ == "__main__":
29 | infer(
30 | "저는 중소기업을 운영하고 있는데, 평소 저희 회사의 품질과 신용을 좋게 평가해온 동일업종의 사업가가 저희 회사의 상호(商號)를 사겠다고 합니다. 상호를 팔 수 있는지요?"
31 | )
32 |
--------------------------------------------------------------------------------
/model/LLM/evaluation/ppl.py:
--------------------------------------------------------------------------------
1 | import evaluate
2 | from datasets import load_dataset
3 |
4 |
5 | from petf_ppl import Perplexity_Petf
6 | import pandas as pd
7 |
8 |
9 | # perplexity = evaluate.load("perplexity", module_type="metric")
10 | perplexity = Perplexity_Petf()
11 |
12 | path = "../train/data/easy_law.csv"
13 | db = pd.read_csv(path)
14 | dataset = load_dataset("csv", data_files=path)["train"]
15 |
16 | data = dataset.map(
17 | lambda x: {
18 | "text": f"아래는 작업을 설명하는 명령어입니다. 요청을 적절히 완료하는 응답을 작성하세요.\n\n### 명령어:\n{x['question']}\n\n### 응답:\n{x['answer']}"
19 | }
20 | )
21 |
22 | results = perplexity.compute(
23 | model_id="kfkas/LawBot-v1_koalpaca_legalQA_easylaw",
24 | add_start_token=False,
25 | predictions=data["text"],
26 | max_length=256,
27 | batch_size=4,
28 | )
29 | # results = perplexity.compute(model_id='nlpai-lab/kullm-polyglot-5.8b-v2',add_start_token=False,predictions=data['text'],max_length=256,batch_size=4)
30 | print(list(results.keys()))
31 | print(round(results["mean_perplexity"], 2))
32 | print(round(results["perplexities"][0], 2))
33 |
--------------------------------------------------------------------------------
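For reference on what `ppl.py` reports: perplexity is the exponential of the mean per-token negative log-likelihood. A toy illustration with made-up NLL values:

```python
import math

nlls = [2.1, 1.7, 2.4]  # per-token negative log-likelihoods (illustrative)
ppl = math.exp(sum(nlls) / len(nlls))
print(round(ppl, 2))  # ~7.9; lower is better
```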
/.github/workflows/backendCI.yml:
--------------------------------------------------------------------------------
1 | # This workflow will install Python dependencies, run tests and lint with a single version of Python
2 | # For more information see: https://docs.github.com/en/actions/automating-builds-and-tests/building-and-testing-python
3 |
4 | name: CI for backend
5 |
6 | on:
7 | push:
8 | paths: "backend/**"
9 | branches: [ "develop" ]
10 | pull_request:
11 | paths: "backend/**"
12 | branches: [ "develop" ]
13 |
14 | jobs:
15 | CI:
16 | runs-on: self-hosted
17 | env:
18 | working-directory: /opt/ml/level3_nlp_finalproject-nlp-08
19 |
20 | steps:
21 | - name: Update Code
22 | run: git pull
23 | working-directory: ${{ env.working-directory }}
24 |
25 | - name: install dependencies
26 | run: |
27 | source final_project/bin/activate
28 | pip install -r requirements.txt
29 | working-directory: "${{ env.working-directory }}/backend"
30 |
31 | - name: Test with pytest
32 | run: |
33 | source final_project/bin/activate
34 | pytest
35 | working-directory: "${{ env.working-directory }}/backend"
36 |
--------------------------------------------------------------------------------
/prototype/package.json:
--------------------------------------------------------------------------------
1 | {
2 | "name": "prototype",
3 | "version": "0.1.0",
4 | "private": true,
5 | "dependencies": {
6 | "@testing-library/jest-dom": "^5.16.5",
7 | "@testing-library/react": "^13.4.0",
8 | "@testing-library/user-event": "^13.5.0",
9 | "react": "^18.2.0",
10 | "react-dom": "^18.2.0",
11 | "react-scripts": "5.0.1",
12 | "web-vitals": "^2.1.4"
13 | },
14 | "scripts": {
15 | "start": "react-scripts start",
16 | "build": "react-scripts build",
17 | "test": "react-scripts test",
18 | "eject": "react-scripts eject"
19 | },
20 | "eslintConfig": {
21 | "extends": [
22 | "react-app",
23 | "react-app/jest"
24 | ]
25 | },
26 | "browserslist": {
27 | "production": [
28 | ">0.2%",
29 | "not dead",
30 | "not op_mini all"
31 | ],
32 | "development": [
33 | "last 1 chrome version",
34 | "last 1 firefox version",
35 | "last 1 safari version"
36 | ]
37 | },
38 | "devDependencies": {
39 | "autoprefixer": "^10.4.14",
40 | "postcss": "^8.4.25",
41 | "tailwindcss": "^3.3.2"
42 | },
43 | "proxy": "http://localhost:8000"
44 | }
--------------------------------------------------------------------------------
/backend/router/router.py:
--------------------------------------------------------------------------------
1 | import json
2 | from datetime import datetime
3 | from typing import List, Union
4 |
5 | import pytz
6 | import requests
7 | from fastapi import FastAPI
8 | from pydantic import BaseModel
9 |
10 | app = FastAPI()
11 |
12 | class Question(BaseModel):
13 | q_sentence: str
14 |
15 | class Answer(BaseModel):
16 | answer_sentence: Union[str, None]
17 | similar_precedent: Union[List, None]
18 |
19 |
20 | @app.get("/")
21 | def root():
22 | print("Hello World!")
23 |
24 | @app.post("/generate", response_model=Union[Answer, None])
25 | async def generate(question: Question):
26 | KST = pytz.timezone('Asia/Seoul')
27 | print(datetime.now(KST).strftime("%Y/%m/%d %H:%M:%S"))
28 | q_sentence = question.q_sentence.strip()
29 | if q_sentence == "":
30 | print({"q_sentence": q_sentence})
31 | return None
32 | headers = {"Content-Type": "application/json", "accept": "application/json"}
33 | url = "http://127.0.0.1:8000/generate"
34 | data = {"q_sentence": q_sentence}
35 |
36 | print(data)
37 | res = requests.post(url, headers=headers, data=json.dumps(data))
38 |
39 | return res.json()
40 |
--------------------------------------------------------------------------------
/model/LLM/train/load_model.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig
3 | from peft import prepare_model_for_kbit_training
4 | from peft import LoraConfig, get_peft_model
5 | from utils import print_trainable_parameters
6 |
7 |
8 | def load_model(model_name):
9 | bnb_config = BitsAndBytesConfig(
10 | load_in_4bit=True,
11 | bnb_4bit_use_double_quant=True,
12 | bnb_4bit_quant_type="nf4",
13 | bnb_4bit_compute_dtype=torch.bfloat16,
14 | )
15 | tokenizer = AutoTokenizer.from_pretrained(model_name)
16 | model = AutoModelForCausalLM.from_pretrained(
17 | model_name, quantization_config=bnb_config, device_map={"": 0}
18 | )
19 | model.gradient_checkpointing_enable()
20 | model = prepare_model_for_kbit_training(model)
21 |
22 | config = LoraConfig(
23 | r=8,
24 | lora_alpha=32,
25 | target_modules=["query_key_value"],
26 | lora_dropout=0.05,
27 | bias="none",
28 | task_type="CAUSAL_LM",
29 | )
30 |
31 | model = get_peft_model(model, config)
32 | print_trainable_parameters(model)
33 |
34 | return model, tokenizer
35 |
--------------------------------------------------------------------------------
/frontend/package.json:
--------------------------------------------------------------------------------
1 | {
2 | "name": "frontend",
3 | "version": "0.1.0",
4 | "private": true,
5 | "dependencies": {
6 | "@testing-library/jest-dom": "^5.16.5",
7 | "@testing-library/react": "^13.4.0",
8 | "@testing-library/user-event": "^13.5.0",
9 | "autoprefixer": "^10.4.14",
10 | "axios": "^1.4.0",
11 | "postcss": "^8.4.26",
12 | "react": "^18.2.0",
13 | "react-dom": "^18.2.0",
14 | "react-scripts": "5.0.1",
15 | "tailwindcss": "^3.3.3",
16 | "tailwindcss-animated": "^1.0.1",
17 | "web-vitals": "^2.1.4"
18 | },
19 | "scripts": {
20 | "start": "react-scripts start",
21 | "build": "react-scripts build",
22 | "test": "react-scripts test",
23 | "eject": "react-scripts eject"
24 | },
25 | "eslintConfig": {
26 | "extends": [
27 | "react-app",
28 | "react-app/jest"
29 | ]
30 | },
31 | "browserslist": {
32 | "production": [
33 | ">0.2%",
34 | "not dead",
35 | "not op_mini all"
36 | ],
37 | "development": [
38 | "last 1 chrome version",
39 | "last 1 firefox version",
40 | "last 1 safari version"
41 | ]
42 | },
43 | "proxy": "https://api.yoonseul.link"
44 | }
45 |
--------------------------------------------------------------------------------
/data_pipeline/generate/generate_bard.py:
--------------------------------------------------------------------------------
1 | import os
2 | import pickle
3 | import time
4 |
5 | import pandas as pd
6 | from bardapi import Bard
7 | from tqdm.auto import tqdm
8 |
9 | bard = Bard(token_from_browser=True)
10 |
11 | def get_response(prompt):
12 | response = bard.get_answer(prompt)
13 | return response
14 |
15 | with open("prompts.pkl", "rb") as f:
16 | prompts = pickle.load(f)
17 |
18 | data = []
19 | num_data = 1
20 | prompt_name = "fewshot"
21 | prompt = prompts["fewshot"]
22 |
23 | for i in tqdm(range(num_data)):
24 | try:
25 | response = get_response(prompt)
26 | except Exception:  # Bard call failed; wait briefly and skip this round
27 | time.sleep(5)
28 | continue
29 | data.append(
30 | [
31 | prompt_name,
32 | prompt,
33 | *[response["choices"][j]["content"][0] for j in range(len(response["choices"]))]
34 | ]
35 | )
36 |
37 | generated_df = pd.DataFrame(
38 | data,
39 | columns=[
40 | "prompt_type",
41 | "prompt",
42 | "result_1",
43 | "result_2",
44 | "result_3"
45 | ])
46 |
47 | os.makedirs("../data/generated_data/bard", exist_ok=True)
48 | generated_df.to_csv(f"../data/generated_data/bard/generated_data_bard_{len(generated_df)}.csv", index=False)
--------------------------------------------------------------------------------
/model/LLM/evaluation/evaluate_mertrics.py:
--------------------------------------------------------------------------------
1 | from data_preprocessing import PPL_Autodata
2 | from petf_ppl import Perplexity_Petf
3 | import evaluate
4 |
5 |
6 |
7 | data = PPL_Autodata("./eval_data_legal").preprocess_data
8 | petf_model_id = "kfkas/LawBot-v1_koalpaca_legalQA_easylaw_train"
9 | normal_model_id = "nlpai-lab/kullm-polyglot-5.8b-v2"
10 |
11 | use = "petf" # petf or normal
12 |
13 | if use == "petf":
14 | perplexity = Perplexity_Petf()
15 | results = perplexity.compute(
16 | model_id=petf_model_id,
17 | add_start_token=False,
18 | predictions=data["text"],
19 | max_length=256,
20 | batch_size=4,
21 | )
22 | print(list(results.keys()))
23 | print(round(results["mean_perplexity"], 2))
24 | print(round(results["perplexities"][0], 2))
25 | else:
26 | perplexity = evaluate.load("perplexity", module_type="metric")
27 | results = perplexity.compute(
28 | model_id=normal_model_id,
29 | add_start_token=False,
30 | predictions=data["text"],
31 | max_length=256,
32 | batch_size=4,
33 | )
34 | print(list(results.keys()))
35 | print(round(results["mean_perplexity"], 2))
36 | print(round(results["perplexities"][0], 2))
37 |
--------------------------------------------------------------------------------
/data_pipeline/README.md:
--------------------------------------------------------------------------------
1 | # LawBot - Data Pipeline
2 | * All commands in this instruction should be run in the following directory.
3 | `/data_pipeline`
4 | ## ⚠️ How To Install Requirements
5 | * Run the following command on your terminal.
6 |
7 | ```bash
8 | $ pip install -r requirements.txt
9 | ```
10 |
11 | ## ⌨️ How To Execute
12 | ### Web Crawling
13 | ```bash
14 | $ python3 crawler.py
15 | $ python3 qa_crawler.py
16 | ```
17 | ### Data Generation
18 | ```bash
19 | $ python3 generate_gpt.py
20 | $ python3 parse.py
21 | ```
22 | * To generate data using the GPT model, you need to [obtain an API key](https://platform.openai.com/account/api-keys) from OpenAI first.
23 | Depending on the model used, usage fees might be charged.
24 |
25 | * If you want to modify the prompts, follow these steps.
26 | 1. Add the prompts in the `backup_prompts.py` file.
27 | 2. Run the following command in your terminal.
28 |
29 | ```bash
30 | $ python3 backup_prompts.py
31 | ```
32 | 3. The new pickle file will overwrite the existing `prompts.pkl`.
33 | 4. After that, you can proceed with the steps mentioned earlier.
34 |
35 | ### Preprocessing
36 | * Run the following commands in your terminal.
37 | ```bash
38 | $ python3 spellchecker.py
39 | $ python3 preprocessor_v2.py
40 | ```
41 |
--------------------------------------------------------------------------------
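On the `prompts.pkl` format referenced in the README above: `generate_bard.py` reads it as a pickled dict keyed by prompt name (`prompts["fewshot"]`), so a minimal `backup_prompts.py` would look roughly like this; the prompt text itself is a placeholder, not the repo's real prompt:

```python
import pickle

# the real prompt texts live in backup_prompts.py; this one is illustrative
prompts = {"fewshot": "다음 예시와 같은 형식의 법률 QA 데이터를 생성하세요. ..."}

with open("prompts.pkl", "wb") as f:
    pickle.dump(prompts, f)
```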
/prototype/src/components/SimilarPrecedent.js:
--------------------------------------------------------------------------------
1 | import PrecedentCard from "./SimilarPrecedentComponents/PrecedentCard"
2 |
3 | function SimilarPresdent({ precedents }) {
4 | const precedents_best = precedents[0]
5 | const precedents_second = precedents[1]
6 | const precedents_third = precedents[2]
7 |
8 |   return (
9 |     <div>
10 |       {/* minimal markup: the original Tailwind classes were lost in extraction */}
11 |       <PrecedentCard precedent={precedents_best} number={1} />
12 |       <PrecedentCard precedent={precedents_second} number={2} />
13 |       <PrecedentCard precedent={precedents_third} number={3} />
14 |     </div>
15 |   )
16 | }
17 | export default SimilarPresdent
27 |
--------------------------------------------------------------------------------
/model/LLM/evaluation/data_preprocessing.py:
--------------------------------------------------------------------------------
1 | import pandas as pd
2 | from datasets import concatenate_datasets
3 | from datasets import load_dataset
4 | import os
5 |
6 |
7 | class PPL_Autodata:
8 | def __init__(self, data_folder="./eval_data_legal"):
9 | self.data_folder = data_folder
10 | self.concat_dataset = self.concat_datasets(self.data_folder)
11 | self.preprocess_data = self.preprocessing_data(self.concat_dataset)
12 |
13 | def concat_datasets(self, folder_path):
14 | datasets = []
15 | for file_name in os.listdir(folder_path):
16 | if file_name.endswith(".csv"):
17 | file_path = os.path.join(folder_path, file_name)
18 | dataset = load_dataset("csv", data_files=file_path)
19 | datasets.append(dataset["train"])
20 |
21 | combined_dataset = concatenate_datasets(datasets)
22 | if len(combined_dataset.features) > 2:
23 | data = combined_dataset.remove_columns("title")
24 | else:
25 | data = combined_dataset
26 | return data
27 |
28 | def preprocessing_data(self, dataset):
29 | data = dataset.map(
30 | lambda x: {
31 | "text": f"아래는 작업을 설명하는 명령어입니다. 요청을 적절히 완료하는 응답을 작성하세요.\n\n### 명령어:\n{x['question']}\n\n### 응답:\n{x['answer']}"
32 | }
33 | )
34 | return data
35 |
--------------------------------------------------------------------------------
/backend/tests/test_search.py:
--------------------------------------------------------------------------------
1 | import os
2 | import sys
3 |
4 | sys.path.append(os.path.dirname(os.path.abspath(os.path.dirname(__file__))))
5 |
6 | import numpy as np
7 | import pandas as pd
8 | from app.search import load_vector_data, search_precedent
9 | from sentence_transformers import SentenceTransformer
10 |
11 |
12 | def test_search_precedent():
13 | q_a_sentence = "교통법규 위반으로 벌점을 받았습니다. 이 벌점은 소멸되지 않고 계속 누적되나요?"+"벌점이 소멸되는 것이 아니라 누적되어 관리됩니다. 즉, 벌점이 누적되면 운전면허가 취소되거나 정지될 수 있습니다. 벌점 누적에 따른 운전면허의 취소 또는 정지 기준은 다음과 같습니다(도로교통법 시행규칙 별표 28). 운전면허 취소기준 1. 혈중알콜농도가 0.1% 이상인 사람이 자동차 등을 운전한 경우 2. 음주측정기에 의한 측정결과에 불복하는 사람이 술에 취한 상태에 있다고 인정할 만한 상당한 이유가 있음에도 불구하고 경찰공무원의 측정 요구에 불응하거나 경찰공무원을 폭행 또는 협박한 경우(단, 운전자가 경찰공무원에게 폭행을 가한 경우에는 그 정도가 심하지 않을 때에 한함) 3. 적성검사를 받지 않거나 적성검사에 불합격된 사람이 다시 운전면허를 받고자 하는 경우 4. 자동차를 이용하여 범죄행위를 한 경우 5. 다른 사람의 자동차를 훔치거나 빼앗은 경우 6. 교통사고를 야기하고 도주한 경우 7. 단속경찰공무원 등을 폭행한 경우 8. 정차ㆍ주차위반에 대한 조치"
14 |
15 | model = SentenceTransformer("jhgan/ko-sroberta-multitask")
16 |
17 | print("Load data")
18 | base_path = os.path.abspath(os.path.dirname(__file__))
19 |
20 | text_data = np.array(pd.read_csv(base_path + "/../data/law_data/law_data.csv"))
21 | vector_data = load_vector_data(
22 | base_path + "/../data/law_data/law_data_drop_vector.bin"
23 | )
24 |
25 | search_precedent(q_a_sentence=q_a_sentence, model=model, text_data=text_data, vector_data=vector_data)
26 |
--------------------------------------------------------------------------------
/model/BERT/inference/inference.py:
--------------------------------------------------------------------------------
1 | import os
2 |
3 | import numpy as np
4 | import pandas as pd
5 | from sentence_transformers import SentenceTransformer
6 | from tqdm import tqdm
7 | import pickle
8 | from sklearn.metrics.pairwise import cosine_similarity
9 | from .utils import load_vector_data
10 |
11 |
12 | def BERT_infer(input):
13 | model_name = "jhgan/ko-sroberta-multitask"
14 | model = SentenceTransformer(model_name)
15 | model.to("cuda:0")
16 |
17 | input_vector = model.encode(input)
18 | input_vector = np.expand_dims(input_vector, axis=0)  # shape (1, d) for cosine_similarity
19 |
20 | base_path = os.path.join(os.path.dirname(__file__))
21 |
22 | text_data = np.array(pd.read_csv(base_path + "/../data/law_data/law_data_drop.csv"))
23 | vector_data = load_vector_data(
24 | base_path + "/../data/law_data/law_data_drop_vector.bin"
25 | )
26 |
27 | cos_sim = cosine_similarity(input_vector, vector_data)
28 | data_cosine = np.sort(cos_sim).squeeze()[::-1][:3]
29 | top_question = np.argsort(cos_sim).squeeze()[::-1][:3]
30 |
31 | pan_list = []
32 |
33 | for i, index in enumerate(top_question):
34 | if data_cosine[i] >= 0.5:
35 | pan_list.append(
36 | f"case Number : {text_data[index][0]} judgementAbstract : {text_data[index][4]} judgementNote :{text_data[index][9]}"
37 | )
38 |
39 | return pan_list
40 |
41 |
42 | if __name__ == "__main__":
43 | BERT_infer("상원이형과 이혼을 하는것은 중죄이고 죄질이 나쁘기 떄문에 징역 10년입니다.")
44 |
--------------------------------------------------------------------------------
/prototype/src/components/SimilarPrecedentComponents/PrecedentCard.js:
--------------------------------------------------------------------------------
1 | function PrecedentCard({ precedent, number }) {
2 |   return (
3 |     <div>
4 |       {/* minimal markup: the original Tailwind classes were lost in extraction */}
5 |       <h3>유사 판례 조항 {number}</h3>
6 |       <p>사건 케이스 : {precedent.case_name}</p>
7 |       <p>사건 번호 : {precedent.case_number}</p>
8 |       <p>사건 분류 : {precedent.case_type}</p>
9 |       <p>관련 법 조항 : {precedent.ref_article}</p>
10 |       <a>Read more</a>
11 |     </div>
12 |   )
13 | }
14 | export default PrecedentCard
15 |
--------------------------------------------------------------------------------
/model/Retrieval/bert_retrieval/inference.py:
--------------------------------------------------------------------------------
1 | from data_preprocessing import Autodata
2 | import os
3 | import numpy as np
4 | from sentence_transformers import SentenceTransformer
5 | from sklearn.metrics.pairwise import cosine_similarity
6 |
7 |
8 | def Query_BERT_infer(input):
9 | model_name = "jhgan/ko-sroberta-multitask"
10 | model = SentenceTransformer(model_name).to("cuda:0")
11 |
12 | data = Autodata("./data")
13 | original_data = data.concat_dataset
14 | vector_data = data.load_vector_data()
15 |
16 | input_vector = model.encode(input)
17 | input_vector = np.expand_dims(input_vector, axis=0)
18 |
19 | cos_sim = cosine_similarity(input_vector, vector_data)
20 | data_cosine = np.sort(cos_sim).squeeze()[::-1][:3]
21 | top_question = np.argsort(cos_sim).squeeze()[::-1][:3]
22 |
23 | print("유사도 : ", data_cosine)
24 |
25 | question_list = []
26 | answer_list = []
27 |
28 | for i, index in enumerate(top_question):
29 | if data_cosine[i] >= 0.6:
30 | question_list.append(original_data["question"][index])
31 | answer_list.append(original_data["answer"][index])
32 | count = 0
33 | for question, answer in zip(question_list, answer_list):
34 | print(f"유사 상담 사례 질문 {count} : {question}")
35 | print(f"유사 상담 사례 답변 {count} : {answer}")
36 | print()
37 | count += 1
38 |
39 |
40 | if __name__ == "__main__":
41 | Query_BERT_infer("제가 자동차를 운전하다 중앙선을 침범하다가 2충 추돌사고를 발생시켰습니다. 이때 무슨 법으로 처벌 받을 수 있나요?")
42 |
--------------------------------------------------------------------------------
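The retrieval above ranks consultation cases by cosine similarity and keeps the top 3 that clear a 0.6 threshold. A self-contained illustration of that scoring step with toy vectors:

```python
import numpy as np
from sklearn.metrics.pairwise import cosine_similarity

query = np.array([[1.0, 0.0]])               # encoded question, shape (1, d)
corpus = np.array([[0.9, 0.1], [0.0, 1.0]])  # encoded corpus, shape (n, d)

cos_sim = cosine_similarity(query, corpus)
top = np.argsort(cos_sim).squeeze()[::-1][:3]  # same top-k pattern as the code above
print(cos_sim.squeeze(), top)
```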
/model/Filter/data_preprocessing.py:
--------------------------------------------------------------------------------
1 | import pandas as pd
2 | from datasets import concatenate_datasets, Dataset, load_dataset
3 | import os
4 |
5 |
6 | class Autodata:
7 | def __init__(self, data_folder="./data"):
8 | self.data_folder = data_folder
9 | self.concat_dataset = self.concat_datasets(self.data_folder)
10 |
11 | def concat_datasets(self, data_folder):
12 | datasets = []
13 | pd_datasets = []
14 | for file_name in os.listdir(data_folder):
15 | if file_name.endswith(".csv"):
16 | file_path = os.path.join(data_folder, file_name)
17 | dataset = pd.read_csv(file_path)
18 | dataframe = dataset[["question", "answer"]]
19 | pd_datasets.append(dataframe)
20 | dataset = Dataset.from_pandas(dataframe)
21 | datasets.append(dataset)
22 |
23 | combined_dataset = concatenate_datasets(datasets)
24 |
25 | return combined_dataset
26 |
27 | def load_instruction_dataset(self, dataset_id):
28 | koalpaca_data = load_dataset(dataset_id)
29 | data = koalpaca_data["train"]
30 | data = data.rename_column("instruction", "question")
31 | question = data["question"]
32 | return question
33 |
34 | def label_indexing(self, data, state):
35 | if state == 1:
36 | answer = 1
37 | else:
38 | answer = 0
39 | answer_list = [answer] * len(data)
40 |
41 | return Dataset.from_dict({"question": data, "target": answer_list})
42 |
--------------------------------------------------------------------------------
/frontend/src/components/SimilarPrecedentComponents/PrecedentCard.js:
--------------------------------------------------------------------------------
1 | function PrecedentCard({ precedent, number }) {
2 |   return (
3 |     <div>
4 |       {/* minimal markup: the original Tailwind classes were lost in extraction */}
5 |       <h3>유사 판례 조항 {number}</h3>
6 |       <p>사건 이름 : {precedent.case_name}</p>
7 |       <p>사건 번호 : {precedent.case_number}</p>
8 |       <p>사건 분류 : {precedent.case_type}</p>
9 |       <p>관련 법 조항 : {precedent.ref_article}</p>
10 |       <a>Read more</a>
11 |     </div>
12 |   )
13 | }
14 | export default PrecedentCard
15 |
--------------------------------------------------------------------------------
/model/LLM/train/train.py:
--------------------------------------------------------------------------------
1 | from load_model import load_model
2 | from data_preprocessing import Autodata
3 | import transformers
4 |
5 |
6 | def train():
7 | model_id = "nlpai-lab/kullm-polyglot-5.8b-v2"
8 | model, tokenizer = load_model(model_id)
9 | tokenizer.pad_token = tokenizer.eos_token
10 | train_data = Autodata(data_folder="./data", tokenizer=tokenizer).tokenizer_dataset
11 | val_data = Autodata(data_folder="./val_data", tokenizer=tokenizer).tokenizer_dataset
12 | trainer = transformers.Trainer(
13 | model=model,
14 | train_dataset=train_data,
15 | eval_dataset=val_data,
16 | args=transformers.TrainingArguments(
17 | per_device_train_batch_size=16,
18 | gradient_accumulation_steps=1,
19 | num_train_epochs=6,
20 | learning_rate=1e-4,
21 | fp16=True,
22 | logging_steps=10,
23 | save_strategy="epoch",
24 | evaluation_strategy="epoch",
25 | output_dir="./model_outputs",
26 | optim="paged_adamw_8bit",
27 | ),
28 | data_collator=transformers.DataCollatorForLanguageModeling(
29 | tokenizer, mlm=False
30 | ),
31 | )
32 | model.config.use_cache = (
33 | False # silence the warnings. Please re-enable for inference!
34 | )
35 | trainer.train()
36 |
37 | push_model_id = "kfkas/LawBot-level2-5.8B_FIX"
38 | huggingface_write_token = ""  # put your Hugging Face write token here
39 |
40 | model.push_to_hub(
41 | push_model_id, use_temp_dir=True, use_auth_token=huggingface_write_token
42 | )
43 | print(f"{push_model_id} 모델 업로드 완료!")
44 |
45 |
46 | if __name__ == "__main__":
47 | train()
48 |
--------------------------------------------------------------------------------
/model/Retrieval/bm25_retrieval/data_preprocessing.py:
--------------------------------------------------------------------------------
1 | import pandas as pd
2 | from datasets import concatenate_datasets, Dataset
3 | import os
4 | import json
5 |
6 |
7 | class Autodata:
8 | def __init__(self, data_folder="./data"):
9 | self.data_folder = data_folder
10 | self.concat_dataset = self.concat_datasets(self.data_folder)
11 |
12 | def concat_datasets(self, data_folder):
13 | datasets = []
14 | pd_datasets = []
15 | for file_name in os.listdir(data_folder):
16 | if file_name.endswith(".csv"):
17 | file_path = os.path.join(data_folder, file_name)
18 | dataset = pd.read_csv(file_path)
19 | dataframe = dataset[["question", "answer"]]
20 | pd_datasets.append(dataframe)
21 | dataset = Dataset.from_pandas(dataframe)
22 | datasets.append(dataset)
23 |
24 | combined_dataset = concatenate_datasets(datasets)
25 | pd_combined_dataset = pd.DataFrame(combined_dataset)
26 |
27 | return pd_combined_dataset
28 |
29 | def make_all_data(self, data, path):
30 | df = data
31 | data_dict = {}
32 |
33 | for i in range(len(df)):
34 | key = str(i)
35 | data_dict[key] = {
36 | "question": df.iloc[i]["question"],
37 | "answer": df.iloc[i]["answer"],
38 | }
39 |
40 | with open(path, "w", encoding="utf-8") as file:
41 | json.dump(data_dict, file, ensure_ascii=False, indent=4)
42 |
43 | def load_json_data(self, path="./all_data/all_data.json"):
44 | if not os.path.isfile(path):
45 | self.make_all_data(self.concat_dataset, path)
46 |
--------------------------------------------------------------------------------
/model/LLM/train/data_preprocessing.py:
--------------------------------------------------------------------------------
1 | import pandas as pd
2 | from datasets import concatenate_datasets, Dataset
3 | from datasets import load_dataset
4 | import os
5 |
6 |
7 | class Autodata:
8 | def __init__(self, data_folder="./data", max_length=1024, tokenizer=None):
9 | self.data_folder = data_folder
10 | self.max_length = max_length
11 | self.tokenizer = tokenizer
12 | self.concat_dataset = self.concat_datasets(self.data_folder)
13 | self.tokenizer_dataset = self.tokenizing_dataset(self.concat_dataset)
14 |
15 | def concat_datasets(self, folder_path):
16 | datasets = []
17 | for file_name in os.listdir(folder_path):
18 | if file_name.endswith(".csv"):
19 | file_path = os.path.join(folder_path, file_name)
20 | dataset = pd.read_csv(file_path)
21 | dataframe = dataset[["question", "answer"]]
22 | dataset = Dataset.from_pandas(dataframe)
23 | datasets.append(dataset)
24 |
25 | combined_dataset = concatenate_datasets(datasets)
26 |
27 | return combined_dataset
28 |
29 | def tokenizing_dataset(self, dataset):
30 | data = dataset.map(
31 | lambda x: {
32 | "text": f"아래는 작업을 설명하는 명령어입니다. 요청을 적절히 완료하는 응답을 작성하세요.\n\n### 명령어:\n{x['question']}\n\n### 응답:\n{x['answer']}<|endoftext|>"
33 | }
34 | )
35 | data = data.map(
36 | lambda samples: self.tokenizer(
37 | samples["text"],
38 | truncation=True,
39 | max_length=self.max_length,
40 | padding=False,
41 | return_tensors=None,
42 | ),
43 | batched=True,
44 | )
45 |
46 | return data.shuffle()
47 |
--------------------------------------------------------------------------------
/prototype/public/index.html:
--------------------------------------------------------------------------------
1 | <!-- NOTE: the HTML tags of this file were stripped during extraction; -->
2 | <!-- only the text content survives: the page title "Lawbot" and the -->
3 | <!-- noscript fallback "You need to enable JavaScript to run this app." -->
--------------------------------------------------------------------------------
/backend/app/generate.py:
--------------------------------------------------------------------------------
1 | import time
2 |
3 | import torch
4 |
5 |
6 | def generate_answer(q_sentence: str, model, tokenizer):
7 | model.eval()
8 |     model.config.use_cache = True  # enable the KV cache for generation
9 | # model.float()
10 | tokenizer.pad_token = tokenizer.eos_token
11 |
12 | device = torch.device('cuda:0') if torch.cuda.is_available() else torch.device('cpu')
13 | # prompt = f"아래는 작업을 설명하는 명령어입니다. 요청을 적절히 완료하는 응답을 작성하세요.\n\n### 명령어:\n{q_sentence}\n\n### 응답:\n"
14 | prompt = f"다음은 한국 법률 QA입니다. 질문에 맞는 적절한 응답을 작성하세요.\n\n### 질문:\n{q_sentence}\n\n### 응답:\n"
15 | 
16 | # prompt = f"다음은 폭행 관련 법률 QA입니다. 질문과 관련된 자료를 가져온 다음 해당 내용을 요약하여 응답을 작성하세요.\n\n### 질문:\n{q_sentence}\n\n### 응답:\n"
17 | # prompt = f"다음은 폭행 관련 법률 QA입니다. 상황을 읽고 질문에 맞는 적절한 응답을 작성하세요.\n\n### 상황:\n혁준이가 술을 마시고 아내를 폭행했어. 근데 아내는 외도를 했던 상황이야\n\n### 질문:\n감형이 가능할까?\n\n### 응답:\n"
18 | len_prompt = len(prompt)
19 |
20 | start_time = time.time()
21 |
22 | gened = model.generate(
23 | **tokenizer(
24 | prompt,
25 | return_tensors='pt',
26 | return_token_type_ids=False
27 | ).to(device),
28 | max_new_tokens=1024,
29 | early_stopping=True,
30 | do_sample=True,
31 | eos_token_id=2,
32 | # temperature=1e-5,
33 | top_k=10,
34 | top_p=0.95,
35 | no_repeat_ngram_size=2,
36 | num_beams=3,
37 | # force_words_ids = tokenizer("폭행", add_special_tokens=False).input_ids,
38 | )
39 |
40 | print(f"generate time: {time.time() - start_time}")
41 |     answer = tokenizer.decode(gened[0])[len_prompt:].replace("응답:", "").replace("\n\n", "\n").replace("<|endoftext|>", "")  # strips the EOS marker; "<|endoftext|>" is assumed, the original literal was lost in extraction
42 | print(f"LLM answer: {answer}\n")
43 | return answer
44 |
--------------------------------------------------------------------------------
/backend/airflow/module/load_data.py:
--------------------------------------------------------------------------------
1 | import os
2 |
3 | import pandas as pd
4 | from datasets import load_dataset
5 | from dotenv import load_dotenv
6 |
7 | load_dotenv()
8 | huggingface_read_token = os.getenv("HUGGINGFACE_READ_TOKEN")
9 |
10 | def load_train_data():
11 | data = load_dataset("YoonSeul/legal-GPT-BARD-train_v3", use_auth_token=huggingface_read_token)
12 | df = pd.DataFrame(data)
13 |
14 | questions = []
15 | answers = []
16 |
17 |     for _, row in df.iterrows():
18 |         questions.append(row["train"]["instruction"])
19 |         answers.append(row["train"]["output"])
20 |
21 | train_datasets = {
22 | "question": questions,
23 | "answer": answers
24 | }
25 |
26 | train_df = pd.DataFrame(train_datasets)
27 |     BASE_PATH = os.path.join(os.path.dirname(os.path.abspath(os.path.dirname(__file__))), "data")
28 |     SAVE_PATH = os.path.join(BASE_PATH, "train_data.csv")
29 |     os.makedirs(BASE_PATH, exist_ok=True)
30 |     train_df.to_csv(SAVE_PATH, index=False)
31 |
32 | def load_eval_data():
33 | data = load_dataset("YoonSeul/legal-GPT-BARD-val_v3", use_auth_token=huggingface_read_token)
34 | df = pd.DataFrame(data)
35 |
36 | questions = []
37 | answers = []
38 |
39 |     for _, row in df.iterrows():
40 |         questions.append(row["train"]["instruction"])
41 |         answers.append(row["train"]["output"])
42 |
43 |     eval_datasets = {
44 |         "question": questions,
45 |         "answer": answers
46 |     }
47 | 
48 |     eval_df = pd.DataFrame(eval_datasets)
49 |     BASE_PATH = os.path.join(os.path.dirname(os.path.abspath(os.path.dirname(__file__))), "data")
50 |     SAVE_PATH = os.path.join(BASE_PATH, "eval_data.csv")
51 |     os.makedirs(BASE_PATH, exist_ok=True)
52 |     eval_df.to_csv(SAVE_PATH, index=False)
53 |
54 | def load_train_eval_data():
55 | load_train_data()
56 | load_eval_data()
57 |
58 | if __name__ == "__main__":
59 | load_train_eval_data()
60 |
--------------------------------------------------------------------------------
/frontend/public/index.html:
--------------------------------------------------------------------------------
1 | <!-- NOTE: the HTML tags of this file were stripped during extraction; -->
2 | <!-- only the text content survives: the page title "Lawbot" and the -->
3 | <!-- noscript fallback "You need to enable JavaScript to run this app." -->
--------------------------------------------------------------------------------
/backend/app/search.py:
--------------------------------------------------------------------------------
1 | import os
2 | import pickle
3 | import time
4 |
5 | import numpy as np
6 | import torch
7 | from pydantic import BaseModel
8 | from sentence_transformers import SentenceTransformer
9 | from sklearn.metrics.pairwise import cosine_similarity
10 |
11 |
12 | class Precedent(BaseModel):
13 | case_name: str
14 | case_number: str
15 | case_type: str
16 | ref_article: str
17 | url: str
18 |
19 |
20 | def search_precedent(q_a_sentence: str, model, text_data, vector_data):
21 | start_time = time.time()
22 | # model = SentenceTransformer("jhgan/ko-sroberta-multitask") #TODO
23 | model.to("cuda:0")
24 |
25 | input_vector = model.encode(q_a_sentence)
26 |     input_vector = np.expand_dims(input_vector, axis=0)
27 | 
28 |     cos_sim = cosine_similarity(input_vector, vector_data)
29 | data_cosine = np.sort(cos_sim).squeeze()[::-1][:3]
30 | top_question = np.argsort(cos_sim).squeeze()[::-1][:3]
31 |
32 | precedent_list = []
33 |
34 | for i, index in enumerate(top_question):
35 | if data_cosine[i] >= 0.5:
36 | ref_article = text_data[index][7]
37 | ref_article_split = str(ref_article).split()
38 | if len(ref_article_split) >= 2:
39 | url = f"https://law.go.kr/법령/{ref_article_split[0]}/{ref_article_split[1]}"
40 | else:
41 | url = ""
42 | precedent_list.append(
43 | Precedent(case_name=text_data[index][3], case_number=text_data[index][0], case_type=text_data[index][6], ref_article=ref_article, url=url)
44 | )
45 |
46 | # del model
47 | # torch.cuda.empty_cache()
48 |
49 | print(f"search time: {time.time() - start_time}")
50 |
51 | return precedent_list
52 |
53 | def load_vector_data(path):
54 | if os.path.isfile(path):
55 | with open(path, "rb") as fr:
56 | vector_data = pickle.load(fr)
57 | else:
58 |         print("Precedent vector data does not exist.")
59 | vector_data = None
60 | return vector_data
61 |
--------------------------------------------------------------------------------
/frontend/src/components/Loader.js:
--------------------------------------------------------------------------------
1 | import React, { useState, useEffect } from "react";
2 |
3 | function Loader() {
4 | const phrases = [
5 | "법조문을 읽고 있습니다...",
6 | "당신의 법률 문제를 위한 판례를 찾아보는 중입니다. 잠시만 기다려 주세요.",
7 | "답변을 생성하는데 최대 30초 정도의 시간이 걸립니다.",
8 | "당신의 질문에 대한 법적 해결책을 찾아서 조금 있다가 돌아올게요. 티타임을 즐기세요.",
9 | "답변을 생성중입니다 잠시만 기다려주세요.",
10 | ];
11 |
12 | const [currentPhraseIndex, setCurrentPhraseIndex] = useState(0);
13 | const [visibleText, setVisibleText] = useState("");
14 |
15 | const typingDelay = 100;
16 |   const nextPhraseDelay = 3000; // 3 seconds
17 |
18 | useEffect(() => {
19 | const typeText = (currentIndex, currentText) => {
20 | if (currentIndex < currentText.length) {
21 | setVisibleText(currentText.substring(0, currentIndex + 1));
22 | setTimeout(() => typeText(currentIndex + 1, currentText), typingDelay);
23 | } else {
24 | setTimeout(() => {
25 | setCurrentPhraseIndex((prevIndex) => (prevIndex + 1) % phrases.length);
26 | setVisibleText("");
27 | }, nextPhraseDelay);
28 | }
29 | };
30 |
31 | const currentPhrase = phrases[currentPhraseIndex];
32 | typeText(0, currentPhrase);
33 | }, [currentPhraseIndex]);
34 |
35 |   return (
36 |     // NOTE: the original JSX markup (wrapper tags and Tailwind classes) was
37 |     // stripped during extraction; this minimal stand-in keeps the surviving
38 |     // content: the "L" logo badge and the currently typed phrase.
39 |     <div>
40 |       <span>L</span>
41 |       <p>{visibleText}</p>
42 |     </div>
43 |   );
54 | }
55 |
56 | export default Loader;
57 |
--------------------------------------------------------------------------------
/frontend/README.md:
--------------------------------------------------------------------------------
1 | # LawBot - Frontend
2 |
3 | ## ⚠️ Requirements for Web Frontend
4 |
5 | Below are the commands for installing the dependency packages needed to run the web frontend.
6 | Run the following commands in order.
7 | ```bash
8 | # install react
9 | apt install curl
10 | curl https://raw.githubusercontent.com/creationix/nvm/master/install.sh | bash
11 | source ~/.bashrc
12 | nvm install 18.04.0
13 |
14 | # Install Tailwind CSS
15 | npm install -g yarn
16 | yarn add tailwindcss postcss autoprefixer
17 | npx tailwindcss init
18 | npm i tailwindcss-animated
19 | ```
20 |
21 | When the installation is complete, run the following commands to check that the proper versions are installed.
23 |
24 | ```bash
25 | node -v # v18.04.0
26 | npm -v # v8.11.0
27 | ```
28 |
29 | ## 💻 Getting Started with Create React App
30 |
31 | ```bash
32 | yarn start # npm start
33 | ```
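
The nginx setup below serves the static `build` directory, so you will also want a production build first (a step the notes above skip; `yarn build` is the standard CRA build command):

```bash
yarn build   # emits static files to frontend/build, the root served by nginx below
```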
34 |
35 | ## How to Configure nginx
36 |
37 | - Below are the commands for installing nginx and wiring up the site configuration.
38 |
39 | ```bash
40 | sudo apt install nginx
41 | sudo rm /etc/nginx/sites-available/default
42 | sudo rm /etc/nginx/sites-enabled/default
43 |
44 | sudo vim /etc/nginx/sites-available/frontend.conf
45 | sudo ln -s /etc/nginx/sites-available/frontend.conf /etc/nginx/sites-enabled/frontend.conf
46 | ```
47 |
48 | - The configuration file you need to write at `/etc/nginx/sites-available/frontend.conf`:
49 |
50 | ```conf
51 | server {
52 | listen 80;
53 | listen [::]:80;
54 |
55 |     server_name yoonseul.link;
56 |
57 | root /home/ubuntu/level3_nlp_finalproject-nlp-08/frontend/build;
58 | index index.html;
59 |
60 | location /generate {
61 |
62 | proxy_pass https://backend_server;
63 | proxy_connect_timeout 500;
64 | proxy_send_timeout 500;
65 | proxy_read_timeout 500;
66 | }
67 | }
68 | ```
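
Before (re)starting nginx it is worth validating the configuration, an extra sanity check that is not part of the original notes:

```bash
sudo nginx -t   # parses the config files and reports any syntax errors
```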
69 |
70 |
71 | ## How to start nginx
72 | ```bash
73 | sudo systemctl stop nginx   # stop nginx
74 | sudo systemctl start nginx  # start nginx
75 | ```
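
To confirm the service actually came up:

```bash
sudo systemctl status nginx   # should report "active (running)"
```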
--------------------------------------------------------------------------------
/model/Retrieval/bert_retrieval/data_preprocessing.py:
--------------------------------------------------------------------------------
1 | import pandas as pd
2 | from datasets import concatenate_datasets, Dataset
3 | 
4 | import os
5 | import pickle
6 | import numpy as np
7 | from sentence_transformers import SentenceTransformer
8 | from tqdm import tqdm
9 |
10 |
11 | class Autodata:
12 | def __init__(self, data_folder="./data"):
13 |         self.data_folder = data_folder
14 |         self.concat_dataset = self.concat_datasets(self.data_folder)
15 | 
16 |     def concat_datasets(self, data_folder):
17 |         datasets = []
18 |         for file_name in os.listdir(data_folder):
19 |             if file_name.endswith(".csv"):
20 |                 file_path = os.path.join(data_folder, file_name)
21 |                 dataset = pd.read_csv(file_path)
22 |                 dataframe = dataset[["question", "answer"]]
23 |                 dataset = Dataset.from_pandas(dataframe)
24 |                 datasets.append(dataset)
25 | 
26 |         combined_dataset = concatenate_datasets(datasets)
27 |         pd_combined_dataset = pd.DataFrame(combined_dataset)
28 |         return pd_combined_dataset
29 |
30 | def build_vector_dataset(self, dataset, path):
31 | dataset = np.array(dataset)
32 | model_name = "jhgan/ko-sroberta-multitask"
33 | model = SentenceTransformer(model_name).to("cuda:0")
34 |
35 | query_vector_list = []
36 |
37 | for i in tqdm(range(len(dataset))):
38 | question = dataset[i][0]
39 | query_vector = model.encode(question)
40 | query_vector_list.append(list(query_vector))
41 |
42 | with open(path, "wb") as fw:
43 | pickle.dump(query_vector_list, fw)
44 |
45 | with open(path, "rb") as fr:
46 | vector_data = pickle.load(fr)
47 |
48 | return vector_data
49 |
50 | def load_vector_data(self, path="./data/query_vector.bin"):
51 | if os.path.isfile(path):
52 | with open(path, "rb") as fr:
53 | vector_data = pickle.load(fr)
54 | else:
55 | vector_data = self.build_vector_dataset(self.concat_dataset, path)
56 | return vector_data
57 |
--------------------------------------------------------------------------------
/data_pipeline/preprocessor.py:
--------------------------------------------------------------------------------
1 | import re
2 |
3 | import pandas as pd
4 |
5 |
6 | def remove_escape(raw_text: str) -> str:
7 | pattern = r"\t|\n|\xa0"
8 | processed_text = re.sub(pattern, " ", raw_text)
9 | processed_text_stripped = " ".join(processed_text.split())
10 | return processed_text_stripped
11 |
12 |
13 | def remove_phone_number(raw_text: str) -> str:
14 | pattern = r"\(*\d+\s*-\s*\d+\s*-\s*\d+\)*"
15 | processed_text = re.sub(pattern, "", raw_text)
16 | return processed_text
17 |
18 |
19 | def remove_hyperlink(raw_text: str) -> str:
20 | pattern = (
21 | r":*\s*\(*:*\s*https?://[\w\dㄱ-ㅎㅏ-ㅣ가-힣!@#$%^&*(),.?/:;\"'<>{}|+=~_-]+\s*\)*"
22 | )
23 | processed_text = re.sub(pattern, "", raw_text)
24 | return processed_text
25 |
26 |
27 | def remove_header(raw_text: str) -> str:
28 | header_pattern = "안녕하십니까. 대한법률구조공단 사이버상담을 이용해 주셔서 감사합니다."
29 | header_end_idx = re.search(header_pattern, raw_text)
30 |     if header_end_idx is not None:
31 | processed_text = raw_text[header_end_idx.end() :]
32 | return processed_text
33 | else:
34 | return raw_text
35 |
36 |
37 | def remove_footer(raw_text: str) -> str:
38 | footer_pattern = "1. 위 답변은 귀하께서 제공해주신 사실관계에 기초한 답변자 개인의 법률적 의견으로서 이와 다른 의견이 있을 수도 있으므로 참고자료로만 활용해주시고,"
39 | footer_start_idx = re.search(footer_pattern, raw_text)
40 |     if footer_start_idx is not None:
41 | processed_text = raw_text[: footer_start_idx.start()]
42 | return processed_text
43 | else:
44 | return raw_text
45 |
46 |
47 | def preprocess(raw_text: str) -> str:
48 | preprocessed_text = raw_text
49 | preprocess_functions = [
50 | remove_header,
51 | remove_footer,
52 | remove_escape,
53 | remove_phone_number,
54 | remove_hyperlink,
55 | ]
56 | for preprocess_function in preprocess_functions:
57 | preprocessed_text = preprocess_function(preprocessed_text)
58 | return preprocessed_text
59 |
60 |
61 | if __name__ == "__main__":
62 | df = pd.read_csv("./data/raw_qa_dataset.csv")
63 | preprocessed_df = df.assign(
64 | content=df["content"].apply(preprocess), answer=df["answer"].apply(preprocess)
65 | )
66 | preprocessed_df.to_csv("./data/preprocessed_qa_dataset.csv", index=False)
67 |
--------------------------------------------------------------------------------
/model/LLM/evaluation/dialogue_evaluation.py:
--------------------------------------------------------------------------------
1 | import os
2 | import sys
3 |
4 | sys.path.append(os.path.dirname(os.path.abspath(os.path.dirname(__file__))))
5 |
6 | from inference import infer
7 | import pandas as pd
8 | import torch
9 | from transformers import AutoTokenizer, AutoModelForCausalLM
10 | from peft import PeftModel, PeftConfig
11 | from tqdm.auto import tqdm
12 |
13 | class LawyerEvaluation:
14 | def __init__(self, path="lawyer_question.csv", model_name="uomnf97/LawBot-level2-final-preprocessing-v3"):
15 | self.data = pd.read_csv(f"./eval_data_legal/{path}")
16 | self.model_name = model_name
17 | self.answer = False
18 |
19 | def generate_answer(self):
20 | answer_list = []
21 |
22 |         device = (
23 |             torch.device("cuda:0")
24 |             if torch.cuda.is_available() else torch.device("cpu")
25 |         )
26 | peft_model_id = self.model_name
27 | config = PeftConfig.from_pretrained(peft_model_id)
28 | model = AutoModelForCausalLM.from_pretrained(
29 | config.base_model_name_or_path, device_map={"": 0}, torch_dtype=torch.float16
30 | )
31 | model = PeftModel.from_pretrained(
32 | model, peft_model_id, torch_dtype=torch.float16)
33 | tokenizer = AutoTokenizer.from_pretrained(
34 | config.base_model_name_or_path
35 | )
36 | model.eval()
37 |         model.config.use_cache = (
38 |             True  # enable the KV cache for inference
39 |         )
40 | model.float()
41 | tokenizer.pad_token = tokenizer.eos_token
42 |
43 | for i in tqdm(range(len(self.data)), desc="processing evaluation data"):
44 | data = self.data.iloc[i]["question"]
45 | answer_list.append(infer.gen(data, model=model,
46 | tokenizer=tokenizer, device=device))
47 |
48 |
49 | self.data["answer"] = answer_list
50 | self.answer = True
51 |
52 |     def to_csv(self):
53 |         if self.answer:
54 |             self.data.to_csv("./eval_data_legal/lawyer_val_with_answer.csv")
55 |         else:
56 |             print("Generate the answers first!")
57 |
58 |
59 | if __name__ == "__main__":
60 | lawyer_evaluation = LawyerEvaluation()
61 | lawyer_evaluation.generate_answer()
62 | lawyer_evaluation.to_csv()
63 |
--------------------------------------------------------------------------------
/data_pipeline/spellchecker.py:
--------------------------------------------------------------------------------
1 | from tqdm import trange
2 | from selenium import webdriver
3 | from selenium.webdriver.common.by import By
4 | from selenium.webdriver.common.keys import Keys
5 | from selenium.webdriver.support.ui import WebDriverWait
6 |
7 | import pandas as pd
8 |
9 | def set_options():
10 | options = webdriver.ChromeOptions()
11 | options.add_argument("--headless")
12 | options.add_argument('--no-sandbox')
13 | options.add_argument("--single-process")
14 | options.add_argument("--disable-dev-shm-usage")
15 | return options
16 |
17 | def wait_driver_click(element_id):
18 |     WebDriverWait(driver, timeout=60).until(lambda d: d.find_element(By.ID, element_id))
19 |     driver.find_element(By.ID, element_id).click()
20 |
21 |
22 | def spell_check(txt, delay=15):
23 | if len(txt) > 1200:
24 | return ''
25 |
26 | WebDriverWait(driver, timeout=10).until(lambda d: d.find_element(By.ID, 'character_counter_content'))
27 | driver.find_element(By.ID, 'character_counter_content').send_keys(txt)
28 |
29 | wait_driver_click('spell_check')
30 | driver.implicitly_wait(delay)
31 | wait_driver_click('spell_done_all')
32 | driver.implicitly_wait(1)
33 |
34 | WebDriverWait(driver, timeout=10).until(lambda d: d.find_element(By.CSS_SELECTOR, '#checker_preview'))
35 | clean_txt = driver.find_element(By.CSS_SELECTOR, '#checker_preview').text
36 | driver.refresh()
37 | return clean_txt
38 |
39 |
40 | dt = pd.read_csv('./data/preprocessed_qa_spacing_pre_word.csv')
41 | new_q = []
42 | new_a = []
43 | options = set_options()
44 |
52 | driver = webdriver.Chrome(options=options)
53 | driver.get('https://www.saramin.co.kr/zf_user/tools/character-counter')
54 |
55 | q_v = dt['question'].values
56 | a_v = dt['answer'].values
57 |
58 | for i in trange(len(dt)):
59 | if len(q_v[i]) > 1200:
60 | continue
61 | try:
62 | new_q.append(spell_check(q_v[i]))
63 | new_a.append(a_v[i])
64 | except:
65 | continue
66 |
67 | driver.quit()
68 |
69 |
70 | clean_data = pd.DataFrame({'question': new_q, 'answer': new_a})
71 |
72 | clean_data.to_csv('./data/preprocessed_qa_spacing_spell.csv', index=False)
--------------------------------------------------------------------------------
/data_pipeline/generate/generate_gpt.py:
--------------------------------------------------------------------------------
1 | import os
2 | import pickle
3 | import time
4 |
5 | import openai
6 | import pandas as pd
7 | from tqdm.auto import tqdm
8 |
9 | openai.api_key = os.environ["OPENAI_API_KEY"]
10 |
11 | def get_response(prompt, model="gpt-3.5-turbo", temperature=1.0, max_tokens=1000):
12 | messages = [{"role": "user", "content": prompt}]
13 | response = openai.ChatCompletion.create(
14 | model=model,
15 | messages=messages,
16 | temperature=temperature,
17 | max_tokens=max_tokens
18 | )
19 | return response
20 |
21 |
22 | def get_price_of_inference(model, input_tokens, output_tokens):
23 | if model == "gpt-3.5-turbo-0613":
24 | input_price_per_k = 0.0015
25 | output_price_per_k = 0.002
26 | price_dollar = (input_tokens * input_price_per_k + output_tokens * output_price_per_k) / 1000
27 | price_won = round(price_dollar * 1281.61, 5)
28 | return [price_dollar, price_won]
29 |     else:
30 |         return [None, None]  # unknown model: price cannot be computed
31 |
32 |
33 | with open("prompts.pkl", "rb") as f:
34 | prompts = pickle.load(f)
35 |
36 | full_responses = {}
37 | data = []
38 | num_data = 1000
39 | prompt_type = "fewshot"
40 | prompt = prompts["fewshot"]
41 |
42 | for _ in tqdm(range(num_data)):
43 | try:
44 | response = get_response(prompt)
45 |     except Exception:  # e.g. rate limit or transient API error: back off and skip
46 | time.sleep(5)
47 | continue
48 | full_responses[prompt_type] = response
49 | output = response.choices[0].message.content # GPT output
50 | model = response.model # Model used
51 | input_tokens = response.usage.prompt_tokens # Number of tokens of input
52 | output_tokens = response.usage.completion_tokens # Number of tokens of output
53 | data.append(
54 | [
55 | prompt_type,
56 | prompt,
57 | output,
58 | model,
59 | input_tokens,
60 | output_tokens,
61 | *get_price_of_inference(model, input_tokens, output_tokens)
62 | ],
63 | )
64 |
65 | generated_df = pd.DataFrame(
66 | data,
67 | columns=[
68 | "prompt_type",
69 | "prompt",
70 | "output",
71 | "model",
72 | "input_tokens",
73 | "output_tokens",
74 | "price_dollar",
75 | "price_won"
76 | ])
77 |
78 | os.makedirs("../data/generated_data/gpt", exist_ok=True)
79 | generated_df.to_csv(f"../data/generated_data/gpt/generated_data_gpt_{len(generated_df)}.csv", index=False)
--------------------------------------------------------------------------------
/model/LLM/inference/infer.py:
--------------------------------------------------------------------------------
1 | from peft import PeftModel, PeftConfig
2 | import torch
3 | from transformers import AutoTokenizer, AutoModelForCausalLM
4 |
5 |
6 | def gen(x, model, tokenizer, device):
7 | prompt = (
8 | f"아래는 작업을 설명하는 명령어입니다. 요청을 적절히 완료하는 응답을 작성하세요.\n\n### 명령어:\n{x}\n\n### 응답:\n"
9 | )
10 | len_prompt = len(prompt)
11 | gened = model.generate(
12 | **tokenizer(prompt, return_tensors="pt", return_token_type_ids=False).to(
13 | device
14 | ),
15 | max_new_tokens=1024,
16 | early_stopping=True,
17 | do_sample=True,
18 | top_k=20,
19 | top_p=0.92,
20 | no_repeat_ngram_size=3,
21 | eos_token_id=2,
22 | repetition_penalty=1.2,
23 | num_beams=3,
24 | )
25 | return tokenizer.decode(gened[0])[len_prompt:]
26 |
27 |
28 | def LLM_infer(input, model_type):
29 | device = (
30 | torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")
31 | )
32 | if model_type == "kullm":
33 | peft_model_id = "YoonSeul/LawBot-level-3-KuLLM-5.8B-tae-2epoch"
34 | config = PeftConfig.from_pretrained(peft_model_id)
35 | model = AutoModelForCausalLM.from_pretrained(
36 | config.base_model_name_or_path,
37 | device_map={"": 0},
38 | torch_dtype=torch.float16,
39 | )
40 | model = PeftModel.from_pretrained(
41 | model, peft_model_id, torch_dtype=torch.float16
42 | )
43 | tokenizer = AutoTokenizer.from_pretrained(config.base_model_name_or_path)
44 |
45 | model.eval()
46 |         model.config.use_cache = (
47 |             True  # enable the KV cache for inference
48 |         )
49 | tokenizer.pad_token = tokenizer.eos_token
50 | else:
51 | model_id = "kfkas/Legal-Llama-2-ko-7b-Chat"
52 | model = AutoModelForCausalLM.from_pretrained(
53 | model_id,
54 | device_map={"": 0},
55 | torch_dtype=torch.float16,
56 | low_cpu_mem_usage=True,
57 | )
58 | tokenizer = AutoTokenizer.from_pretrained(model_id)
59 | model.eval()
60 | model.config.use_cache = True
61 | tokenizer.pad_token = tokenizer.eos_token
62 | output = gen(input, model=model, tokenizer=tokenizer, device=device)
63 |
64 | return output
65 |
66 |
67 | if __name__ == "__main__":
68 | model_type = "kullm" # llama, kullm
69 | input = "음주운전을하면 어떤 법으로 처벌 되나요?"
70 |     text = LLM_infer(input, model_type)
71 |     print(text)
--------------------------------------------------------------------------------
/model/Filter/train.py:
--------------------------------------------------------------------------------
1 | from transformers import Trainer, TrainingArguments, AutoModelForSequenceClassification
2 | from data_preprocessing import Autodata
3 | from datasets import concatenate_datasets
4 | from dataloader import CustomDataset
5 | from sklearn.model_selection import train_test_split
6 | from utils import compute_metrics  # local module, consistent with the flat imports above
7 |
8 |
9 | def train():
10 | model_name = "monologg/koelectra-small-v3-discriminator"
11 | data = Autodata("./data")
12 | legal_dataset = data.concat_dataset["question"]
13 | legal_answer_dataset = data.concat_dataset["answer"]
14 | alpaca_dataset = data.load_instruction_dataset("nlpai-lab/kullm-v2")
15 |
16 | legal_data = data.label_indexing(legal_dataset, state=0)
17 | legal_dataset_answer = data.label_indexing(legal_answer_dataset, state=0)
18 | alpaca_data = data.label_indexing(alpaca_dataset, state=1)
19 |
20 | total_data = concatenate_datasets([legal_data, alpaca_data, legal_dataset_answer])
21 | train_dataset, val_dataset = train_test_split(
22 | total_data, test_size=0.2, random_state=42
23 | )
24 |
25 | train_data = CustomDataset(
26 | data_file=train_dataset,
27 | model_name=model_name,
28 | text_columns="question",
29 | target_columns="target",
30 | max_length=256,
31 | state="train",
32 | )
33 | val_data = CustomDataset(
34 | data_file=val_dataset,
35 | model_name=model_name,
36 | text_columns="question",
37 | target_columns="target",
38 | max_length=256,
39 | state="train",
40 | )
41 |
42 | model = AutoModelForSequenceClassification.from_pretrained(
43 | model_name, num_labels=2, ignore_mismatched_sizes=True
44 | )
45 |
46 | args = TrainingArguments(
47 | output_dir="output_dir",
48 | evaluation_strategy="epoch",
49 | save_strategy="epoch",
50 | learning_rate=1e-5,
51 | per_device_train_batch_size=64,
52 | per_device_eval_batch_size=64,
53 | num_train_epochs=10,
54 | weight_decay=0.01,
55 | load_best_model_at_end=True,
56 | dataloader_num_workers=4,
57 | logging_steps=50,
58 | seed=42,
59 | group_by_length=True,
60 | )
61 |
62 | trainer = Trainer(
63 | model=model,
64 | args=args,
65 | train_dataset=train_data,
66 | eval_dataset=val_data,
67 | compute_metrics=compute_metrics,
68 | )
69 |
70 | trainer.train()
71 |
72 |
73 | if __name__ == "__main__":
74 | train()
75 |
--------------------------------------------------------------------------------
/model/Filter/dataloader.py:
--------------------------------------------------------------------------------
1 | import pandas as pd
2 | import torch
3 | from tqdm.auto import tqdm
4 | from transformers import AutoTokenizer
5 |
6 |
7 | class CustomDataset(torch.utils.data.Dataset):
8 | def __init__(
9 | self,
10 | data_file,
11 | state,
12 | text_columns,
13 | target_columns,
14 | max_length=256,
15 | model_name="klue/roberta-small",
16 | ):
17 | self.state = state
18 | self.data = data_file
19 | self.text_columns = text_columns
20 | self.max_length = max_length
21 | self.tokenizer = AutoTokenizer.from_pretrained(model_name)
22 |
23 | if self.state == "test":
24 | self.inputs = self.preprocessing(self.data)
25 | else:
26 | self.target_columns = target_columns if target_columns is not None else []
27 | self.inputs, self.targets = self.preprocessing(self.data)
28 |
29 | def __getitem__(self, idx):
30 | if self.state == "test":
31 | return {"input_ids": torch.tensor(self.inputs[idx], dtype=torch.long)}
32 | else:
33 | return {
34 | "input_ids": torch.tensor(self.inputs[idx], dtype=torch.long),
35 | "labels": torch.tensor(self.targets[idx], dtype=torch.long),
36 | }
37 |
38 | def __len__(self):
39 | return len(self.inputs)
40 |
41 | def tokenizing(self, dataframe: pd.DataFrame) -> list:
42 | """
43 |         Tokenize the input sentences.
44 |         Args:
45 |             dataframe (DataFrame): data to tokenize
46 |         Returns:
47 |             data (list): list of token-id sequences for training
48 | """
49 | data = []
50 | for item in tqdm(
51 | dataframe["question"], desc="Tokenizing", total=len(dataframe["question"])
52 | ):
53 | text = item
54 | # text = [item for text_column in self.text_columns]
55 | outputs = self.tokenizer(
56 | text,
57 | add_special_tokens=True,
58 | padding="max_length",
59 | truncation=True,
60 | max_length=self.max_length,
61 | )
62 | data.append(outputs["input_ids"])
63 | return data
64 |
65 | def preprocessing(self, data):
66 | inputs = self.tokenizing(data)
67 | if self.state == "test":
68 | return inputs
69 | else:
70 | try:
71 | targets = data[self.target_columns]
72 | except:
73 | targets = []
74 | return inputs, targets
75 |
--------------------------------------------------------------------------------
/data_pipeline/generate/backup_prompts.py:
--------------------------------------------------------------------------------
1 | import pickle
2 |
3 | prompts = {
4 | "zeroshot": "임의의 법률 분쟁 상황을 가정하고, 그에 대한 내용을 질문의 형식으로 만들어주세요. 그리고 해당 질문에 대한 답변을 함께 출력해주세요.",
5 | "zeroshot2": "법률 분쟁 상황을 가정하고, 해당 상황의 가해자 또는 피해자가 의뢰할 만한 상담 내용을 작성해주세요. 출력에는 해당 상담 내용에 대한 답변을 포함해주세요.",
6 | "oneshot": """
7 | 아래 예시의 형식을 참고하여, 임의의 법률 분쟁 상황에 대한 질의 응답 데이터를 생성해주세요.
8 |
9 | [예시]
10 | 질문: 남편이 가출하여 연락이 되지 않다가 3년 6개월 뒤 ‘실종자 찾아주기’ 운동의 일환으로 DNA검사를 했더니 남편은 3년 전에 이미 교통사고로 사망하였고, 신원미상자로 처리되었다고 합니다. 가출 전 남편이 들어놓은 사망보험금을 청구했더니 보험회사는 사망 후 3년이 경과하였기 때문에 소멸시효 완성을 주장하고 있습니다. 보험금은 못 받는건가요?
11 |
12 | 답변: 보험금청구권의 소멸시효는 특별한 다른 사정이 없는 한 보험사고가 발생한 때부터 진행하는 것이 원칙입니다. 그러나 객관적으로 보험사고가 발생한 사실을 확인할 수 없는 사정이 있는 경우에는 보험금청구권자가 보험사고의 발생을 알았거나 알 수 있었던 때부터 보험금청구권의 소멸시효가 진행합니다. 따라서 사례의 경우에는 보험금 청구가 가능합니다.
13 | """,
14 | "fewshot": """
15 | 아래 예시의 형식을 참고하여, 임의의 법률 분쟁 상황에 대한 질의 응답 데이터를 생성해주세요.
16 |
17 | [예시 1]
18 | 질문: 남편이 가출하여 연락이 되지 않다가 3년 6개월 뒤 ‘실종자 찾아주기’ 운동의 일환으로 DNA검사를 했더니 남편은 3년 전에 이미 교통사고로 사망하였고, 신원미상자로 처리되었다고 합니다. 가출 전 남편이 들어놓은 사망보험금을 청구했더니 보험회사는 사망 후 3년이 경과하였기 때문에 소멸시효 완성을 주장하고 있습니다. 보험금은 못 받는건가요?
19 |
20 | 답변: 보험금청구권의 소멸시효는 특별한 다른 사정이 없는 한 보험사고가 발생한 때부터 진행하는 것이 원칙입니다. 그러나 객관적으로 보험사고가 발생한 사실을 확인할 수 없는 사정이 있는 경우에는 보험금청구권자가 보험사고의 발생을 알았거나 알 수 있었던 때부터 보험금청구권의 소멸시효가 진행합니다. 따라서 사례의 경우에는 보험금 청구가 가능합니다.
21 |
22 | [예시 2]
23 | 질문: 아는 사람에게 500만원을 빌려줬는데 갚지 않습니다. 소송을 해야 할 것 같은데 비용이며, 시간이 꽤 들 것 같네요. 방법이 없을까요?
24 |
25 | 답변: 소송의 당사자가 소송으로 청구하는 금액이나 물건의 가치가 3천만원을 넘지 않는 사건은 시간이나 비용에 있어서 민사소송보다 간편한 절차로 진행할 수 있는 소액사건재판 제도를 이용할 수 있습니다. 소액사건재판 외에도 민사조정이나 지급명령(독촉절차)을 이용할 수도 있습니다.
26 | """,
27 | "fewshot_with_constraints": """
28 | 아래 예시의 형식을 참고하여, 임의의 법률 분쟁 상황에 대한 새로운 질의 응답 데이터를 생성해주세요. 답변을 생성할 때는 다음 조건에 맞게 생성해주세요.
29 |
30 | [조건]
31 | - 제시된 예시와는 다른 사례를 바탕으로 질문을 생성해줘
32 | - 대한민국의 법률에 근거하여 법적 분쟁 상황에 대한 답변을 생성해줘
33 | - 대한민국 법률에 실제로 존재하는 조항과, 입력과 유사한 상황에 대한 판례를 답변 내용에 포함해줘
34 |
35 | [예시 1]
36 | 질문: 남편이 가출하여 연락이 되지 않다가 3년 6개월 뒤 ‘실종자 찾아주기’ 운동의 일환으로 DNA검사를 했더니 남편은 3년 전에 이미 교통사고로 사망하였고, 신원미상자로 처리되었다고 합니다. 가출 전 남편이 들어놓은 사망보험금을 청구했더니 보험회사는 사망 후 3년이 경과하였기 때문에 소멸시효 완성을 주장하고 있습니다. 보험금은 못 받는건가요?
37 |
38 | 답변: 보험금청구권의 소멸시효는 특별한 다른 사정이 없는 한 보험사고가 발생한 때부터 진행하는 것이 원칙입니다. 그러나 객관적으로 보험사고가 발생한 사실을 확인할 수 없는 사정이 있는 경우에는 보험금청구권자가 보험사고의 발생을 알았거나 알 수 있었던 때부터 보험금청구권의 소멸시효가 진행합니다. 따라서 사례의 경우에는 보험금 청구가 가능합니다.
39 |
40 | [예시 2]
41 | 질문: 아는 사람에게 500만원을 빌려줬는데 갚지 않습니다. 소송을 해야 할 것 같은데 비용이며, 시간이 꽤 들 것 같네요. 방법이 없을까요?
42 |
43 | 답변: 소송의 당사자가 소송으로 청구하는 금액이나 물건의 가치가 3천만원을 넘지 않는 사건은 시간이나 비용에 있어서 민사소송보다 간편한 절차로 진행할 수 있는 소액사건재판 제도를 이용할 수 있습니다. 소액사건재판 외에도 민사조정이나 지급명령(독촉절차)을 이용할 수도 있습니다.
44 | """,
45 | }
46 |
47 |
48 | if __name__ == "__main__":
49 | with open("prompts.pkl", "wb") as f:
50 | pickle.dump(prompts, f)
--------------------------------------------------------------------------------
/data_pipeline/generate/parse.py:
--------------------------------------------------------------------------------
1 | import os
2 | import re
3 |
4 | import pandas as pd
5 |
6 |
7 | def collect_raw_data(path):
8 | raw_data = [data for data in os.listdir(path) if data.endswith(".csv")]
9 | collected_data = pd.DataFrame()
10 |
11 | for data in raw_data:
12 | data_path = os.path.join(path, data)
13 | df = pd.read_csv(data_path)
14 | collected_data = pd.concat([collected_data, df], ignore_index=True)
15 |
16 | collected_data = collected_data.drop_duplicates()
17 | collected_data = collected_data.reset_index()
18 | return collected_data
19 |
20 |
21 | def check_output_format(data):
22 | for idx, datum in enumerate(data):
23 | pattern = r"\[\s*(질문|답변|Q|A)\s*\d*\]|(질문|답변)\s*:\s*"
24 |         is_fit_format = re.search(pattern, datum) is not None
25 | if is_fit_format:
26 | continue
27 | else:
28 | print(f"Index {idx} is out of format")
29 | return False
30 | return True
31 |
32 |
33 | def check_qa_pair(data):
34 | mismatched_indices = [idx for idx, datum in enumerate(data) if (len(datum) % 2) != 0]
35 | return mismatched_indices
36 |
37 |
38 | path = "../data/generated_data/gpt"
39 | raw_data = collect_raw_data(path)
40 | raw_outputs = raw_data.output.tolist()
41 |
42 | print(f"Found {len(raw_outputs)} data points.")
43 |
44 | assert check_output_format(raw_outputs), "Check the format of the data"
45 |
46 | processed_outputs = []
47 |
48 | for idx, raw_data in enumerate(raw_outputs):
49 | pattern = r"\[\s*(질문|답변|Q|A)\s*\d*\]|(질문|답변)\s*:\s*"
50 | processed_data = re.sub(pattern, "[SEP]", raw_data)
51 | processed_data = processed_data.split("[SEP]")
52 | processed_data = [data.strip() for data in processed_data if len(data) > 20]
53 | processed_data = [re.sub(r"\[\s*예시\s*\d*\]", "", data) for data in processed_data]
54 |
55 | if len(processed_data) % 2 != 0:
56 | if len(processed_data) == 1:
57 | continue
58 | processed_data = processed_data[:-1]
59 | processed_outputs.append(processed_data)
60 |
61 | assert len(check_qa_pair(processed_outputs)) == 0, "QA pair mismatched data exists."
62 |
63 | q_list = []
64 | a_list = []
65 | for output in processed_outputs:
66 | while len(output) != 0:
67 | q_list.append(output.pop(0))
68 | a_list.append(output.pop(0))
69 |
70 | assert len(q_list) == len(a_list), "QA pair mismatch"
71 |
72 | processed_data = pd.DataFrame({"question": q_list, "answer": a_list})
73 | processed_data.to_csv(f"../data/generated_qa_data_{len(processed_data)}.csv", index=False)
74 |
75 | print(f"Generated {len(processed_data)} pairs of QA data.")
--------------------------------------------------------------------------------
/model/README.md:
--------------------------------------------------------------------------------
1 | # LawBot - Model
2 |
3 | ## 💻 Getting Started
4 |
5 | ## ⚠️ How To Install Requirements
6 | ### CUDA install
7 |
8 | 1. Run the following code on your terminal.
9 |
10 | ```bash
11 | wget https://developer.download.nvidia.com/compute/cuda/11.8.0/local_installers/cuda_11.8.0_520.61.05_linux.run
12 | chmod +x cuda_11.8.0_520.61.05_linux.run
13 | sh cuda_11.8.0_520.61.05_linux.run
14 | ```
15 | 2. Input `accept` to proceed.
16 |
17 |
18 |
19 | 3. Select the driver and install.
20 |
21 |
22 |
23 |
24 | 4. Run the following commands on your terminal.
25 | ```bash
26 | $ export PATH=/usr/local/cuda-11.8/bin:$PATH
27 | $ export LD_LIBRARY_PATH=/usr/local/cuda-11.8/lib64:$LD_LIBRARY_PATH
28 | $ pip install -r requirements.txt
29 | ```
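
If the installation succeeded, the toolkit should now be on your `PATH`; a quick check (not part of the original steps):

```bash
nvcc --version   # should report CUDA release 11.8
```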
30 | ## ⌨️ How To Train
31 | ### LLM (Large Language Model)
32 | * Before training, place the legal QA data in the following directories:
33 |   `model/LLM/train/data`
34 |   `model/LLM/train/val_data`
35 | * Your Hugging Face write token must be filled in at line #38 of `model/LLM/train/train.py`
36 | * Run the following command on your terminal
37 | ```bash
38 | $ python3 model/LLM/train/train.py
39 | ```
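
The training loader concatenates every `.csv` file in the data directory and reads the `question` and `answer` columns, so you can sanity-check a file up front (the file name below is hypothetical):

```bash
head -n 1 model/LLM/train/data/legal_qa.csv   # header should contain question,answer
```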
40 | ### Question Filtering Model (KoELECTRA)
41 | ```bash
42 | $ python3 model/Filter/train.py
43 | ```
44 | ## ⌨️ How To Infer
45 | ### LLM (Large Language Model)
46 | * The PEFT model id (line #33 of `model/LLM/inference/infer.py`) should be changed to your own adapter after training
47 | * Run the following command on your terminal
48 |
49 | ```bash
50 | $ python3 model/LLM/inference/infer.py
51 | ```
52 | ### Sentence BERT Retrieval
53 | * Before running retrieval, place the legal QA data in the following directory
54 | `model/Retrieval/bert_retrieval/data`
55 | * Run the following command on your terminal
56 |
57 | ```bash
58 | $ python3 model/Retrieval/bert_retrieval/inference.py
59 | ```
60 | ### BM25 Retrieval
61 | * Before running retrieval, place the legal QA data in the following directory
62 | `model/Retrieval/bm25_retrieval/all_data`
63 | * Run the following command on your terminal
64 |
65 | ```bash
66 | $ python3 model/Retrieval/bm25_retrieval/retrieval_main.py
67 | ```
68 | ### Question Filtering Model (KoELECTRA)
69 | * Run the following command on your terminal
70 |
71 | ```bash
72 | $ python3 model/Filter/infer.py
73 | ```
74 | ## ⌨️ How To Evaluate
75 | ### LLM (Large Language Model)
76 | * The model name and `use` parameter should be changed if needed
77 | * Run the following command on your terminal
78 | ```bash
79 | $ python3 model/LLM/evaluation/evaluate_metrics.py
80 | ```
81 |
--------------------------------------------------------------------------------
/model/Retrieval/bm25_retrieval/retrieval.py:
--------------------------------------------------------------------------------
1 | import json
2 | import os
3 | from typing import List, Optional, Tuple, Union
4 | import numpy as np
5 | import pandas as pd
6 | from rank_bm25 import BM25L, BM25Okapi, BM25Plus
7 |
8 |
9 | def setup_bm25(parent_class):
10 | class CustomBM25(parent_class):
11 | def __init__(self, corpus, tokenizer):
12 | super().__init__(corpus, tokenizer)
13 |
14 | def get_relevant_doc(self, query, k):
15 | query_vec = self.tokenizer(query)
16 | result = self.get_scores(query_vec)
17 | sorted_result = np.argsort(result.squeeze())[::-1]
18 | doc_score = result.squeeze()[sorted_result].tolist()[:k]
19 | doc_indices = sorted_result.tolist()[:k]
20 | return doc_score, doc_indices
21 |
22 | def get_relevant_doc_bulk(self, queries, k):
23 | doc_scores = []
24 | doc_indices = []
25 | for query in queries:
26 | doc_score, doc_indice = self.get_relevant_doc(query, k)
27 | doc_scores.append(doc_score)
28 | doc_indices.append(doc_indice)
29 | return doc_scores, doc_indices
30 |
31 | return CustomBM25
32 |
33 |
34 | class SparseRetrievalBM25:
35 | def __init__(
36 | self,
37 | tokenize_fn,
38 | data_path: Optional[str] = "./csv_data/",
39 | context_path: Optional[str] = "all_data.json",
40 | bm25_type: Optional[str] = "",
41 | ) -> None:
42 | self.data_path = data_path
43 | with open(os.path.join(data_path, context_path), "r", encoding="utf-8") as f:
44 |             qa_data = json.load(f)
45 | 
46 |         self.contexts = [v["question"] for v in qa_data.values()]
47 |         self.contexts_answer = [v["answer"] for v in qa_data.values()]
48 |
49 |         if bm25_type == "Okapi":
50 |             bm25_class = setup_bm25(BM25Okapi)
51 |         elif bm25_type == "L":
52 |             bm25_class = setup_bm25(BM25L)
53 |         elif bm25_type == "plus":
54 |             bm25_class = setup_bm25(BM25Plus)
55 |         else:
56 |             raise ValueError(f"Unknown bm25_type: {bm25_type!r}; expected 'Okapi', 'L', or 'plus'")
57 |         self.bm25 = bm25_class(self.contexts, tokenize_fn)
58 |
59 | def retrieve(
60 | self, query_or_dataset: Union[str, pd.DataFrame], topk: Optional[int] = 1
61 | ) -> Union[Tuple[List, List], pd.DataFrame]:
62 | if isinstance(query_or_dataset, str):
63 | doc_scores, doc_indices = self.bm25.get_relevant_doc(
64 | query_or_dataset, k=topk
65 | )
66 | return (
67 | doc_scores,
68 | doc_indices,
69 | [self.contexts[doc_indices[i]] for i in range(topk)],
70 | [self.contexts_answer[doc_indices[i]] for i in range(topk)],
71 | )
72 |
--------------------------------------------------------------------------------
/backend/app/main.py:
--------------------------------------------------------------------------------
1 | import os
2 | from datetime import datetime
3 | from typing import List, Union
4 |
5 | import numpy as np
6 | import pandas as pd
7 | import pytz
8 | import torch
9 | from fastapi import FastAPI
10 | from peft import PeftConfig, PeftModel
11 | from pydantic import BaseModel
12 | from sentence_transformers import SentenceTransformer
13 | from transformers import AutoModelForCausalLM, AutoTokenizer
14 |
15 | from bert_retrieval import Autodata, bert_retrieve_QA
16 | from bm25_retrieval import retrieve_QA
17 | from filter import is_legal_question
18 | from generate import generate_answer
19 | from search import Precedent, load_vector_data, search_precedent
20 |
21 |
22 | class Question(BaseModel):
23 | q_sentence: str
24 |
25 | class Answer(BaseModel):
26 | answer_sentence: Union[str, None]
27 | similar_precedent: Union[List[Precedent], None]
28 |
29 | os.environ["TOKENIZERS_PARALLELISM"] = "false"
30 | app = FastAPI()
31 |
32 | llm = None
33 | tokenizer = None
34 | search_model = None
35 | retrieve_model = None
36 | retrieve_data = None
37 | retrieve_vector_data = None
38 | text_data = None
39 | vector_data = None
40 |
41 | @app.on_event("startup")
42 | def startup_event():
43 | global tokenizer, llm, search_model, retrieve_model, retrieve_data, retrieve_vector_data, text_data, vector_data
44 |
45 | print("Load LLM")
46 | model_id = "kfkas/Legal-Llama-2-ko-7b-Chat"
47 | # config = PeftConfig.from_pretrained(peft_model_id)
48 | llm = AutoModelForCausalLM.from_pretrained(
49 | model_id, device_map={"": 0}, torch_dtype=torch.float16
50 | )
51 | # llm = PeftModel.from_pretrained(llm, peft_model_id, torch_dtype=torch.float16)
52 | tokenizer = AutoTokenizer.from_pretrained(model_id)
53 |
54 | print("Load search model")
55 | search_model = SentenceTransformer("jhgan/ko-sroberta-multitask")
56 |
57 | print("Load retrieve model and data")
58 | retrieve_model = SentenceTransformer("jhgan/ko-sroberta-multitask")
59 | DATA_DIR = os.path.join(os.path.dirname(os.path.abspath(os.path.dirname(__file__))), "data/bert_retrieval_data")
60 | retrieve_data = Autodata(DATA_DIR)
61 | retrieve_vector_data = retrieve_data.load_vector_data(os.path.join(DATA_DIR, "query_vector.bin"))
62 |
63 | print("Load data")
64 | base_path = os.path.abspath(os.path.dirname(__file__))
65 |
66 | text_data = np.array(pd.read_csv(base_path + "/../data/law_data/law_data_drop.csv"))
67 | vector_data = load_vector_data(
68 | base_path + "/../data/law_data/law_data_drop_vector.bin"
69 | )
70 |
71 |
72 | @app.post("/generate", response_model=Answer)
73 | async def generate(question: Question):
74 | KST = pytz.timezone('Asia/Seoul')
75 | print(datetime.now(KST).strftime("%Y/%m/%d %H:%M:%S"))
76 |
77 | q_sentence = question.q_sentence
78 | print(f"q_sentence: {q_sentence}")
79 |
80 | if not is_legal_question(q_sentence=q_sentence):
81 | return Answer(answer_sentence=None, similar_precedent=None)
82 |
83 | # retrieve_answer = retrieve_QA(q_sentence=q_sentence)
84 | retrieve_answer = bert_retrieve_QA(q_sentence=q_sentence, model=retrieve_model, data=retrieve_data, vector_data=retrieve_vector_data)
85 |
86 | answer_sentence = generate_answer(q_sentence=q_sentence, model=llm, tokenizer=tokenizer)
87 |
88 | similar_precedent = search_precedent(q_a_sentence=q_sentence+retrieve_answer+answer_sentence, model=search_model, text_data=text_data, vector_data=vector_data)
89 |
90 | return Answer(answer_sentence=answer_sentence, similar_precedent=similar_precedent)
91 |
--------------------------------------------------------------------------------
/data_pipeline/preprocessor_v2.py:
--------------------------------------------------------------------------------
1 | import re
2 |
3 | import pandas as pd
4 |
5 |
6 | def remove_escape(raw_text: str) -> str:
7 | pattern = r"\t|\n|\xa0"
8 | processed_text = re.sub(pattern, " ", raw_text)
9 | processed_text_stripped = " ".join(processed_text.split())
10 | return processed_text_stripped
11 |
12 |
13 | def remove_phone_number(raw_text: str) -> str:
14 | pattern = r"\(*\d+\s*-\s*\d+\s*-\s*\d+\)*"
15 | processed_text = re.sub(pattern, "", raw_text)
16 | return processed_text
17 |
18 |
19 | def remove_hyperlink(raw_text: str) -> str:
20 | pattern = (
21 | r":*\s*\(*:*\s*https?://[\w\dㄱ-ㅎㅏ-ㅣ가-힣!@#$%^&*(),.?/:;\"'<>{}|+=~_-]+\s*\)*"
22 | )
23 | processed_text = re.sub(pattern, "", raw_text)
24 | return processed_text
25 |
26 |
27 | def remove_header(raw_text: str) -> str:
28 | header_pattern = "안녕하십니까. 대한법률구조공단 사이버상담을 이용해 주셔서 감사합니다."
29 | header_end_idx = re.search(header_pattern, raw_text)
30 | if header_end_idx != None:
31 | processed_text = raw_text[header_end_idx.end() :]
32 | return processed_text
33 | else:
34 | return raw_text
35 |
36 |
37 | def remove_footer(raw_text: str) -> str:
38 | footer_pattern = "1. 위 답변은 귀하께서 제공해주신 사실관계에 기초한 답변자 개인의 법률적 의견으로서 이와 다른 의견이 있을 수도 있으므로 참고자료로만 활용해주시고,"
39 | footer_start_idx = re.search(footer_pattern, raw_text)
40 | if footer_start_idx != None:
41 | processed_text = raw_text[: footer_start_idx.start()]
42 | return processed_text
43 | else:
44 | return raw_text
45 |
46 | def remove_link(raw_text: str) -> str:
47 |     pattern = (
48 |         r'\(?[:/a-zA-Z]+.\s?[\da-zA-Z]+.\s?[\da-zA-Z]+.\s?[\da-zA-Z]+[/\da-zA-Z?=%@.&]+\s?\)?'
49 |     )
50 |     processed_text = re.sub(pattern, "", raw_text)
51 |     pattern = (
52 |         r'\(?[:/a-zA-Z]+.\s?[\da-zA-Z]+.\s?[\da-zA-Z]+.\s?[\da-zA-Z]+\s?\)?'
53 |     )
54 |     processed_text = re.sub(pattern, "", processed_text)
55 | 
56 |     pattern = (
57 |         r'\(?[:/a-zA-Z]+.\s?[\da-zA-Z]+.\s?[\da-zA-Z]+\)?'
58 |     )
59 | processed_text = re.sub(pattern, "", processed_text)
60 | return processed_text
61 |
62 |
63 | def remove_page_word(raw_text: str) -> str:
64 |
65 | pattern = '사이버상담|사이버 상담|공단|방문|국번없이 132|132번'
66 | if re.findall(pattern, raw_text) == []:
67 | return raw_text
68 |
69 | split_text = raw_text.split('.')
70 | remove_text = [i for i in split_text if re.findall(pattern, i) == []]
71 |
72 | return '.'.join(remove_text)
73 |
74 | def remove_phone(raw_text: str) -> str:
75 |     pattern = r'\(?\s?☎?\s?국번\s?없이\s?☎?\s?\d+-?\d+\s?번?\)?'
76 | processed_text = re.sub(pattern, "", raw_text)
77 |     pattern = r'\(?\s?☎\s?\d+-?\d+\s?번?\)?'
78 | processed_text = re.sub(pattern, "", processed_text)
79 |
80 |
81 | return processed_text
82 |
83 |
84 | def preprocess(raw_text: str) -> str:
85 | preprocessed_text = raw_text
86 | preprocess_functions = [
87 | remove_header,
88 | remove_footer,
89 | remove_escape,
90 | remove_phone,
91 | remove_page_word,
92 | remove_hyperlink,
93 | remove_link,
94 |
95 | ]
96 | for preprocess_function in preprocess_functions:
97 | preprocessed_text = preprocess_function(preprocessed_text)
98 | return preprocessed_text
99 |
100 |
101 | if __name__ == "__main__":
102 | df = pd.read_csv("./data/legal_train_v1.csv", lineterminator='\n')
103 | preprocessed_df = df.assign(
104 | instruction=df["instruction"].apply(preprocess), output=df["output"].apply(preprocess)
105 | )
106 | preprocessed_df.to_csv("test_preprocess_dataset2.csv", index=False)
107 |
--------------------------------------------------------------------------------
/data_pipeline/qa_crawler.py:
--------------------------------------------------------------------------------
1 | import os
2 | from urllib.request import urlopen
3 |
4 | import pandas as pd
5 | from bs4 import BeautifulSoup as bs
6 | from selenium import webdriver
7 | from selenium.webdriver.common.by import By
8 | from selenium.webdriver.support.ui import WebDriverWait
9 |
10 |
11 |
12 | class QALawCrawler:
13 | def __init__(self, start_url="https://www.klac.or.kr/legalinfo/counsel.do"):
14 | self.start_url = start_url
15 | self.driver = None
16 |
17 | def give_options(self):
18 | options = webdriver.ChromeOptions()
19 | options.add_argument("--headless")
20 | options.add_argument("--no-sandbox")
21 | options.add_argument("--single-process")
22 | options.add_argument("--disable-dev-shm-usage")
23 | return options
24 |
25 | def start_driver(self):
26 | self.driver = webdriver.Chrome(options=self.give_options())
27 |
28 | def quit_driver(self):
29 | self.driver.quit()
30 |
31 |     def crawling_data(self):
32 | self.driver.get(self.start_url)
33 |
34 | data_list = []
35 | end_point = 0
36 |
37 | while end_point == 0:
38 | for j in range(2, 11):
39 | for k in range(1, 11):
40 |
41 | try:
42 | self._wait_driver_click(
43 | f'//*[@id="content"]/div[2]/div/form/div[2]/table/tbody/tr[{k}]/td[2]/a'
44 | )
45 | data_list.append(self._collect_data())
46 | self._wait_driver_click(f'//*[@id="content"]/div[2]/div/div/a')
47 | except:
48 | break
49 |
50 | try:
51 | self._wait_driver_click(
52 | f'//*[@id="content"]/div[2]/div/form/div[3]/a[{j}]'
53 | )
54 | except:
55 | end_point = 1
56 | break
57 |
58 | try:
59 | self._wait_driver_click(
60 | '//*[@id="content"]/div[2]/div/form/div[3]/button[3]'
61 | )
62 | except:
63 | break
64 |
65 | self._save_data(data_list=data_list, drop_unused_columns=True)
66 |
67 | def _wait_driver_click(self, xpath):
68 | WebDriverWait(self.driver, timeout=10).until(
69 | lambda d: d.find_element(By.XPATH, xpath)
70 | )
71 | self.driver.find_element(By.XPATH, xpath).click()
72 |
73 | def _collect_data(self):
74 | try:
75 | url = self.driver.current_url
76 | page_url = urlopen(url)
77 | soup = bs(page_url, "html.parser")
78 | data = []
79 | for i in range(1, 5):
80 | find_data = soup.select_one(
81 | "#print_page > div:nth-child(" + str(i) + ") > dl > dd"
82 | ).text
83 | data.append(find_data)
84 | data.append(url)
85 | return data
86 | except:
87 | return ["error", 0, 0, 0, 0]
88 |
89 | def _save_data(self, data_list, drop_unused_columns=False):
90 | df = pd.DataFrame(
91 | data=data_list, columns=["division", "title", "question", "answer", "url"]
92 | )
93 | if drop_unused_columns:
94 | question = df["question"].values.tolist()
95 | answer = df["answer"].values.tolist()
96 | df = pd.DataFrame({"question": question, "answer": answer})
97 |
98 | os.makedirs("data", exist_ok=True)
99 | df.to_csv("./data/law_qa.csv", index=False)
100 |
101 |
102 | if __name__ == "__main__":
103 | qa_crawler = QALawCrawler()
104 | qa_crawler.start_driver()
105 |     qa_crawler.crawling_data()
106 | qa_crawler.quit_driver()
107 |
--------------------------------------------------------------------------------
/backend/app/bert_retrieval.py:
--------------------------------------------------------------------------------
1 | import os
2 | import pickle
3 | import time
4 |
5 | import numpy as np
6 | import pandas as pd
7 | from datasets import Dataset, concatenate_datasets
8 | from sentence_transformers import SentenceTransformer
9 | from sklearn.metrics.pairwise import cosine_similarity
10 | from tqdm import tqdm
11 |
12 |
13 | class Autodata:
14 | def __init__(self, data_folder):
15 |         self.data_folder = data_folder
16 |         self.concat_dataset = self.concat_datasets(self.data_folder)
17 | 
18 |     def concat_datasets(self, data_folder):
19 |         datasets = []
20 |         for file_name in os.listdir(data_folder):
21 |             if file_name.endswith(".csv"):
22 |                 file_path = os.path.join(data_folder, file_name)
23 |                 dataset = pd.read_csv(file_path)
24 |                 dataframe = dataset[["question", "answer"]]
25 |                 dataset = Dataset.from_pandas(dataframe)
26 |                 datasets.append(dataset)
27 | 
28 |         combined_dataset = concatenate_datasets(datasets)
29 |         pd_combined_dataset = pd.DataFrame(combined_dataset)
30 |         return pd_combined_dataset
31 |
32 | def build_vector_dataset(self, dataset, path):
33 | dataset = np.array(dataset)
34 | model_name = "jhgan/ko-sroberta-multitask"
35 | model = SentenceTransformer(model_name).to("cuda:0")
36 |
37 | query_vector_list = []
38 |
39 | for i in tqdm(range(len(dataset))):
40 | question = dataset[i][0]
41 | query_vector = model.encode(question)
42 | query_vector_list.append(list(query_vector))
43 |
44 | with open(path, "wb") as fw:
45 | pickle.dump(query_vector_list, fw)
46 |
47 | with open(path, "rb") as fr:
48 | vector_data = pickle.load(fr)
49 |
50 | return vector_data
51 |
52 | def load_vector_data(self, path):
53 | if os.path.isfile(path):
54 | with open(path, "rb") as fr:
55 | vector_data = pickle.load(fr)
56 | else:
57 | vector_data = self.build_vector_dataset(self.concat_dataset, path)
58 | return vector_data
59 |
60 |
61 | def bert_retrieve_QA(q_sentence, model, data, vector_data):
62 | start_time = time.time()
63 | model = model.to("cuda:0")
64 |
65 | original_data = data.concat_dataset
66 |
67 | input_vector = model.encode(q_sentence)
68 | input_vector = np.expand_dims(input_vector, axis=0)
69 |
70 | cos_sim = cosine_similarity(input_vector, vector_data)
71 | data_cosine = np.sort(cos_sim).squeeze()[::-1][0]
72 | top_question_idx = np.argsort(cos_sim).squeeze()[::-1][0]
73 |
74 | similar_answer = ""
75 |
76 | if data_cosine >= 0.75:
77 | similar_question = original_data["question"][top_question_idx]
78 | similar_answer = original_data["answer"][top_question_idx]
79 | print(f"retrieve_question: {similar_question}\n")
80 | print(f"retrieve_answer: {similar_answer}\n")
81 |
82 | print(f"retrieve time: {time.time() - start_time}")
83 |
84 | return similar_answer
85 |
86 |
87 | def retrieve_debugging(q_sentence):
88 | model = SentenceTransformer("jhgan/ko-sroberta-multitask")
89 | model = model.to("cuda:0")
90 |
91 | DATA_DIR = os.path.join(os.path.dirname(os.path.abspath(os.path.dirname(__file__))), "data/bert_retrieval_data")
92 | data = Autodata(DATA_DIR)
93 | original_data = data.concat_dataset
94 | vector_data = data.load_vector_data(os.path.join(DATA_DIR, "query_vector.bin"))
95 |
96 | input_vector = model.encode(q_sentence)
97 | input_vector = np.expand_dims(input_vector, axis=0)
98 |
99 | cos_sim = cosine_similarity(input_vector, vector_data)
100 | data_cosine = np.sort(cos_sim).squeeze()[::-1][:3] # array([0.79316866, 0.7515925 , 0.72607714])
101 | top_question = np.argsort(cos_sim).squeeze()[::-1][:3] # array([9285, 9217, 3223])
102 |
103 |     print("similarity:", data_cosine)
104 |
105 | question_list = []
106 | answer_list = []
107 |
108 | for i, index in enumerate(top_question):
109 | if data_cosine[i] >= 0.6:
110 | question_list.append(original_data["question"][index])
111 | answer_list.append(original_data["answer"][index])
112 | count = 0
113 | for question, answer in zip(question_list, answer_list):
114 |         print(f"similar consultation question {count}: {question}")
115 |         print(f"similar consultation answer {count}: {answer}")
116 | print()
117 | count += 1
118 |
119 |
120 | if __name__ == "__main__":
121 | retrieve_debugging("제가 자동차를 운전하다 중앙선을 침범하다가 2충 추돌사고를 발생시켰습니다. 이때 무슨 법으로 처벌 받을 수 있나요?")
122 |
--------------------------------------------------------------------------------
/backend/airflow/module/train_model.py:
--------------------------------------------------------------------------------
1 | import os
2 |
3 | import pandas as pd
4 | import torch
5 | import transformers
6 | from datasets import Dataset, concatenate_datasets
7 | from dotenv import load_dotenv
8 | from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training
9 | from transformers import (AutoModelForCausalLM, AutoTokenizer,
10 | BitsAndBytesConfig)
11 |
12 | load_dotenv()
13 | huggingface_write_token = os.getenv("HUGGINGFACE_WRITE_TOKEN")
14 |
15 | class Autodata:
16 | def __init__(self, data_path, max_length=1024, tokenizer=None):
17 | self.max_length = max_length
18 | self.tokenizer = tokenizer
19 | self.concat_dataset = self.concat_datasets(data_path)
20 | self.tokenizer_dataset = self.tokenizing_dataset(self.concat_dataset)
21 |
22 | def concat_datasets(self, data_path):
23 | datasets = []
24 | dataset = pd.read_csv(data_path)
25 | dataframe = dataset[["question", "answer"]]
26 | dataset = Dataset.from_pandas(dataframe)
27 | datasets.append(dataset)
28 |
29 | combined_dataset = concatenate_datasets(datasets)
30 |
31 | return combined_dataset
32 |
33 | def tokenizing_dataset(self, dataset):
34 | data = dataset.map(
35 | lambda x: {
36 | "text": f"아래는 작업을 설명하는 명령어입니다. 요청을 적절히 완료하는 응답을 작성하세요.\n\n### 명령어:\n{x['question']}\n\n### 응답:\n{x['answer']}<|endoftext|>"
37 | }
38 | )
39 | data = data.map(
40 | lambda samples: self.tokenizer(
41 | samples["text"],
42 | truncation=True,
43 | max_length=self.max_length,
44 | padding=False,
45 | return_tensors=None,
46 | ),
47 | batched=True,
48 | )
49 |
50 | return data.shuffle()
51 |
52 | def load_model(model_name):
53 | # bnb_config = BitsAndBytesConfig(
54 | # load_in_4bit=True,
55 | # bnb_4bit_use_double_quant=True,
56 | # bnb_4bit_quant_type="nf4",
57 | # bnb_4bit_compute_dtype=torch.bfloat16,
58 | # )
59 | tokenizer = AutoTokenizer.from_pretrained(model_name)
60 | model = AutoModelForCausalLM.from_pretrained(
61 | model_name
62 | )
63 | model.gradient_checkpointing_enable()
64 | model = prepare_model_for_kbit_training(model)
65 |
66 | config = LoraConfig(
67 | r=8,
68 | lora_alpha=32,
69 | target_modules=["query_key_value"],
70 | lora_dropout=0.05,
71 | bias="none",
72 | task_type="CAUSAL_LM",
73 | )
74 |
75 | model = get_peft_model(model, config)
76 | print_trainable_parameters(model)
77 |
78 | return model, tokenizer
79 |
80 |
81 | def print_trainable_parameters(model):
82 | """
83 | Prints the number of trainable parameters in the model.
84 | """
85 | trainable_params = 0
86 | all_param = 0
87 | for _, param in model.named_parameters():
88 | all_param += param.numel()
89 | if param.requires_grad:
90 | trainable_params += param.numel()
91 | print(
92 | f"trainable params: {trainable_params} || all params: {all_param} || trainable%: {100 * trainable_params / all_param}"
93 | )
94 |
95 | def train_model():
96 | model_id = "nlpai-lab/kullm-polyglot-5.8b-v2"
97 | model, tokenizer = load_model(model_id)
98 | tokenizer.pad_token = tokenizer.eos_token
99 | BASE_DIR = os.path.join(os.path.dirname(os.path.abspath(os.path.dirname(__file__))), "data")
100 | TRAIN_DATA_PATH = os.path.join(BASE_DIR, "train_data.csv")
101 | EVAL_DATA_PATH = os.path.join(BASE_DIR, "eval_data.csv")
102 | train_data = Autodata(data_path=TRAIN_DATA_PATH, tokenizer=tokenizer).tokenizer_dataset
103 | val_data = Autodata(data_path=EVAL_DATA_PATH, tokenizer=tokenizer).tokenizer_dataset
104 | trainer = transformers.Trainer(
105 | model=model,
106 | train_dataset=train_data,
107 | eval_dataset=val_data,
108 | args=transformers.TrainingArguments(
109 | per_device_train_batch_size=16,
110 | gradient_accumulation_steps=1,
111 | num_train_epochs=6,
112 | learning_rate=1e-4,
113 | fp16=True,
114 | logging_steps=10,
115 | save_strategy="epoch",
116 | evaluation_strategy="epoch",
117 | output_dir="./model_outputs",
118 | optim="paged_adamw_8bit",
119 | ),
120 | data_collator=transformers.DataCollatorForLanguageModeling(
121 | tokenizer, mlm=False
122 | ),
123 | )
124 | model.config.use_cache = (
125 | False # silence the warnings. Please re-enable for inference!
126 | )
127 | trainer.train()
128 |
129 | push_model_id = "YoonSeul/LawBot-airflow-test"
130 |
131 | model.push_to_hub(
132 | push_model_id, use_temp_dir=True, use_auth_token=huggingface_write_token
133 | )
134 | print(f"{push_model_id} 모델 업로드 완료!")
135 |
136 |
137 | if __name__ == "__main__":
138 | train_model()
139 |
--------------------------------------------------------------------------------
/frontend/src/components/ChattingSideBar.js:
--------------------------------------------------------------------------------
1 | function ChattingSideBar() {
2 | return (
3 |
4 |
37 |
38 | )
39 | }
40 | export default ChattingSideBar
41 |
--------------------------------------------------------------------------------
/backend/app/bm25_retrieval.py:
--------------------------------------------------------------------------------
1 | import json
2 | import os
3 | import time
4 | from typing import Callable, List, Optional, Tuple, Union
5 |
6 | import numpy as np
7 | import pandas as pd
8 | from datasets import Dataset, concatenate_datasets
9 | from rank_bm25 import BM25L, BM25Okapi, BM25Plus
10 | from transformers import AutoTokenizer
11 |
12 |
13 | def retrieve_QA(q_sentence):
14 | start_time = time.time()
15 |
16 | BASE_DIR = os.path.dirname(os.path.abspath(os.path.dirname(__file__)))
17 | data_path = os.path.join(BASE_DIR, "data/all_data")
18 | data = Autodata(data_path)
19 | data.load_json_data(path=os.path.join(data_path, "all_data.json"))
20 | tokenizer = AutoTokenizer.from_pretrained("nlpai-lab/kullm-polyglot-5.8b-v2")
21 |
22 | datasets = run_sparse_retrieval(
23 | tokenize_fn=tokenizer.tokenize, data_path=data_path, datasets=q_sentence, bm25="plus"
24 | )
25 |
26 | print(f"retrieve time: {time.time() - start_time}")
27 | print(f"retrieve_question: {datasets[2][0]}")  # top-1 retrieved question
28 |
29 | return datasets[3][0]  # the answer paired with the top-1 question
30 |
31 |
32 | class Autodata:
33 | def __init__(self, data_folder="./data"):
34 | self.data_folder = data_folder
35 | self.concat_dataset = self.concat_datasets(self.data_folder)
36 |
37 | def concat_datasets(self, data_folder):
38 | datasets = []
39 | pd_datasets = []  # note: populated below but never used
40 | for file_name in os.listdir(data_folder):
41 | if file_name.endswith(".csv"):
42 | file_path = os.path.join(data_folder, file_name)
43 | dataset = pd.read_csv(file_path)
44 | dataframe = dataset[["question", "answer"]]
45 | pd_datasets.append(dataframe)
46 | dataset = Dataset.from_pandas(dataframe)
47 | datasets.append(dataset)
48 |
49 | combined_dataset = concatenate_datasets(datasets)
50 | pd_combined_dataset = pd.DataFrame(combined_dataset)
51 |
52 | return pd_combined_dataset
53 |
54 |
55 | def make_all_data(self, data, path):
56 | df = data
57 | data_dict = {}
58 |
59 | for i in range(len(df)):
60 | key = str(i)
61 | data_dict[key] = {
62 | "question": df.iloc[i]["question"],
63 | "answer": df.iloc[i]["answer"],
64 | }
65 |
66 | with open(path, "w", encoding="utf-8") as file:
67 | json.dump(data_dict, file, ensure_ascii=False, indent=4)
68 |
69 | def load_json_data(self, path="./all_data/all_data.json"):
70 | if not os.path.isfile(path):
71 | self.make_all_data(self.concat_dataset, path)
72 |
73 |
74 | def setup_bm25(parent_class):  # class factory: extends a rank_bm25 variant (Okapi/L/Plus) with top-k query helpers
75 | class CustomBM25(parent_class):
76 | def __init__(self, corpus, tokenizer):
77 | super().__init__(corpus, tokenizer)
78 |
79 | def get_relevant_doc(self, query, k):
80 | query_vec = self.tokenizer(query)
81 | result = self.get_scores(query_vec)
82 | sorted_result = np.argsort(result.squeeze())[::-1]
83 | doc_score = result.squeeze()[sorted_result].tolist()[:k]
84 | doc_indices = sorted_result.tolist()[:k]
85 | return doc_score, doc_indices
86 |
87 | def get_relevant_doc_bulk(self, queries, k):
88 | doc_scores = []
89 | doc_indices = []
90 | for query in queries:
91 | doc_score, doc_indice = self.get_relevant_doc(query, k)
92 | doc_scores.append(doc_score)
93 | doc_indices.append(doc_indice)
94 | return doc_scores, doc_indices
95 |
96 | return CustomBM25
97 |
98 |
99 | class SparseRetrievalBM25:
100 | def __init__(
101 | self,
102 | tokenize_fn,
103 | data_path: Optional[str] = "./csv_data/",
104 | context_path: Optional[str] = "all_data.json",
105 | bm25_type: Optional[str] = "",
106 | ) -> None:
107 | self.data_path = data_path
108 | with open(os.path.join(data_path, context_path), "r", encoding="utf-8") as f:
109 | wiki = json.load(f)
110 |
111 | self.contexts = list(([v["question"] for v in wiki.values()]))
112 | self.contexts_answer = list(([v["answer"] for v in wiki.values()]))
113 |
114 | if bm25_type == "Okapi":
115 | bm25_class = setup_bm25(BM25Okapi)
116 | self.bm25 = bm25_class(self.contexts, tokenize_fn)
117 | elif bm25_type == "L":
118 | bm25_class = setup_bm25(BM25L)
119 | self.bm25 = bm25_class(self.contexts, tokenize_fn)
120 | elif bm25_type == "plus":
121 | bm25_class = setup_bm25(BM25Plus)
122 | self.bm25 = bm25_class(self.contexts, tokenize_fn)
123 |
124 | def retrieve(
125 | self, query_or_dataset: Union[str, pd.DataFrame], topk: Optional[int] = 1
126 | ) -> Union[Tuple[List, List], pd.DataFrame]:
127 | if isinstance(query_or_dataset, str):
128 | doc_scores, doc_indices = self.bm25.get_relevant_doc(
129 | query_or_dataset, k=topk
130 | )
131 | return (
132 | doc_scores,
133 | doc_indices,
134 | [self.contexts[doc_indices[i]] for i in range(topk)],
135 | [self.contexts_answer[doc_indices[i]] for i in range(topk)],
136 | )
137 |
138 |
139 | def run_sparse_retrieval(
140 | tokenize_fn: Callable[[str], List[str]],
141 | datasets: Union[str, pd.DataFrame],
142 | data_path: str = os.path.join(
143 | os.path.abspath(os.path.dirname(__file__)), "csv_data"
144 | ),
145 | context_path: str = "all_data.json",
146 | bm25: Optional[str] = None,
147 | ) -> Tuple[List, List, List, List]:
148 | assert bm25 in ["Okapi", "L", "plus"], "Invalid type for BM25 has been passed."
149 |
150 | retriever = SparseRetrievalBM25(
151 | tokenize_fn=tokenize_fn,
152 | data_path=data_path,
153 | context_path=context_path,
154 | bm25_type=bm25,
155 | )
156 |
157 | df = retriever.retrieve(datasets, topk=3)
158 | return df
159 |
--------------------------------------------------------------------------------
/data_pipeline/crawler.py:
--------------------------------------------------------------------------------
1 | import logging
2 | import os
3 | import pickle
4 | import re
5 | import time
6 | from contextlib import contextmanager
7 | from urllib.request import urlopen
8 |
9 | import pandas as pd
10 | from bs4 import BeautifulSoup
11 | from selenium import webdriver
12 | from selenium.webdriver.common.by import By
13 | from selenium.webdriver.remote.remote_connection import \
14 | LOGGER as selenium_logger
15 | from selenium.webdriver.support import expected_conditions as EC
16 | from selenium.webdriver.support.ui import WebDriverWait
17 | from tqdm.auto import tqdm
18 | from utils import utilize_loggers
19 | from webdriver_manager.chrome import ChromeDriverManager
20 |
21 | selenium_logger.setLevel(logging.WARNING)
22 | os.environ["WDM_LOG"] = "0"
23 |
24 |
25 | @contextmanager
26 | def timer():
27 | t0 = time.time()
28 | yield lambda: time.time() - t0
29 |
30 |
31 | def measure_elapsed_time(timer_name):
32 | def decorator(func):
33 | def wrapper(*args, **kwargs):
34 | with timer() as elapsed_time:
35 | result = func(*args, **kwargs)
36 | logger.info(f"{elapsed_time():>8.3f} seconds elapsed @ {timer_name}")
37 | return result
38 |
39 | return wrapper
40 |
41 | return decorator
42 |
43 |
44 | class QADataCrawler:
45 | def __init__(
46 | self,
47 | board_url="https://www.klac.or.kr/legalstruct/cyberConsultation/selectOpenArticleList.do?boardCode=3#none",
48 | base_url="https://www.klac.or.kr/legalstruct/cyberConsultation/selectOpenArticleDetail.do?boardCode=3&contentId=",
49 | ):
50 | self.driver = None
51 | self.board_url = board_url
52 | self.base_url = base_url
53 |
54 | def start_driver(self):
55 | self.chrome_service = webdriver.chrome.service.Service(
56 | ChromeDriverManager().install()
57 | )
58 | self.chrome_options = webdriver.ChromeOptions()
59 | self.chrome_options.add_argument("headless")
60 |
61 | self.driver = webdriver.Chrome(
62 | service=self.chrome_service,
63 | options=self.chrome_options,
64 | )
65 |
66 | self.chrome_service.start()
67 |
68 | def quit_driver(self):
69 | self.driver.quit()
70 |
71 | @measure_elapsed_time("Total Crawling Process")
72 | def get_data(self):
73 | case_ids = self._get_all_case_ids()
74 | case_info = self._get_all_case_contents(case_ids)
75 |
76 | self._save_dataframe(case_info)
77 |
78 | def _get_case_id(self):
79 | case_ids = []
80 | element_xpath = "//a[contains(@onclick, 'fn_inquire_detail')]"
81 | WebDriverWait(self.driver, 10).until(
82 | EC.presence_of_element_located((By.XPATH, element_xpath))
83 | )
84 | elements = self.driver.find_elements(By.XPATH, element_xpath)
85 | for element in elements:
86 | match = re.search(
87 | r"fn_inquire_detail\('(\d+)', '(.*?)'\);return false;",
88 | element.get_attribute("onclick"),
89 | )
90 | case_ids.append(match.group(2))
91 | return case_ids
92 |
93 | @measure_elapsed_time("Get all case ids")
94 | def _get_all_case_ids(self, save_id_list=False):
95 | is_crawling_finished = False
96 | page_move_cnt = 0
97 | page_number = 1
98 | page_idx = 2
99 | case_ids = []
100 |
101 | self.driver.get(self.board_url)
102 | case_ids.extend(self._get_case_id())
103 |
104 | while not is_crawling_finished:
105 | try:
106 | next_page = self.driver.find_element(
107 | By.XPATH,
108 | f'//*[@id="content"]/form[1]/div[2]/div/div[4]/a[{page_idx}]',
109 | )
110 | page_idx += 1
111 | except Exception:  # no next-page link in the current block; try jumping to the next block
112 | try:
113 | next_page = self.driver.find_element(
114 | By.XPATH,
115 | f'//button[contains(@onclick, "fn_select_linkPage({page_move_cnt * 10 + page_idx}); return false;")]',
116 | )
117 | page_idx = 2
118 | page_move_cnt += 1
119 | except Exception:  # no further page block either: reached the last page
120 | is_crawling_finished = True
121 | break
122 |
123 | next_page.click()
124 | case_ids.extend(self._get_case_id())
125 | page_number += 1
126 |
127 | if save_id_list:
128 | self._save_case_id_list(case_ids, "case_id_list.pkl")
129 |
130 | return case_ids
131 |
132 | def _get_case_content_by_id(self, case_id):
133 | url = self.base_url + case_id
134 | html = urlopen(url)
135 | bsObject = BeautifulSoup(html, "html.parser")
136 |
137 | case_title = bsObject.find("div", {"class": "view_head"}).text
138 | date_created = bsObject.find("dt", string="신청일").find_next_sibling("dd").text
139 | date_answered = bsObject.find("dt", string="답변일자").find_next_sibling("dd").text
140 | content, answer = bsObject.find_all("div", {"class": "notice_contents"})
141 | case_info = [case_title, date_created, date_answered, content.text, answer.text]
142 |
143 | return case_info
144 |
145 | @measure_elapsed_time("Get all case contents")
146 | def _get_all_case_contents(self, case_ids):
147 | case_info = []
148 |
149 | for case_id in tqdm(case_ids):
150 | case_info.append(self._get_case_content_by_id(case_id))
151 |
152 | return case_info
153 |
154 | def _save_dataframe(self, case_info, drop_unused_columns=False):
155 | df = pd.DataFrame(
156 | case_info,
157 | columns=[
158 | "case_title",
159 | "date_created",
160 | "date_answered",
161 | "content",
162 | "answer",
163 | ],
164 | )
165 |
166 | if drop_unused_columns:
167 | df = df.drop(["case_title", "date_created", "date_answered"], axis=1)
168 |
169 | os.makedirs("data", exist_ok=True)
170 | df.to_csv("./data/raw_qa_dataset.csv", index=False)
171 | logger.info(f"\t\tGathered Data Count: {len(df)}")
172 |
173 | def _save_case_id_list(self, case_id_list, file_name="case_id_list.pkl"):
174 | with open(file_name, "wb") as f:
175 | pickle.dump(case_id_list, f)
176 |
177 | def _load_case_id_list(self, file="case_id_list.pkl"):
178 | with open(file, "rb") as f:
179 | loaded_list = pickle.load(f)
180 | return loaded_list
181 |
182 |
183 | if __name__ == "__main__":
184 | logger = utilize_loggers(__file__)
185 |
186 | crawler = QADataCrawler()
187 | crawler.start_driver()
188 | crawler.get_data()
189 | crawler.quit_driver()
190 |
--------------------------------------------------------------------------------
/frontend/src/App.js:
--------------------------------------------------------------------------------
1 | import Header from './components/Header';
2 | import Loader from './components/Loader';
3 | import ChattingSideBar from './components/ChattingSideBar';
4 | import SimilarPrecedent from './components/SimilarPrecedent';
5 | import { useState } from 'react';
6 | import TypingAnimation from './components/TypingAnimation';
7 |
8 | function App() {
9 | const [loading, setLoading] = useState(false)
10 | const [message, setMessage] = useState("");
11 | const [sentMessage, setSentMessage] = useState("");
12 | const [aianswer, setAianswer] = useState("");
13 | const [precedents, setPrecedents] = useState(null);
14 | const ans = "\n\nAI가 작성한 답변이며 실제와 다를 수 있으므로 참고 자료로만 활용하시고, 자세한 상담을 원하시는 경우에는 전문 법조인의 상담을 받으시기 바랍니다. LawBot은 법적 책임을 지지 않는다는 점 참고바랍니다."
15 |
16 | const messagehandler = async (e) => {
17 | e.preventDefault();
18 | if (!loading){
19 | setLoading(true);
20 | setMessage("");
21 | setSentMessage(message);
22 | setAianswer("");
23 | setPrecedents(null);
24 | try {
25 | console.log(message);
26 | console.log(message.trim().length)
27 | if (message.trim().length <= 5) {
28 | setAianswer("죄송합니다. 입력하신 \""+message+"\"는 너무 짧아 정확한 답변을 제공하기 어려운 점 양해해주시기 바랍니다. 정확하고 효과적인 답변을 위해 더욱 구체적으로 질문해주시기 바랍니다.");
29 | } else {
30 | const response = await fetch('/generate', {
31 | method: 'POST',
32 | headers: {
33 | 'Content-Type': 'application/json',
34 | },
35 | body: JSON.stringify({ q_sentence: message }),
36 | });
37 |
38 | const data = await response.json();
39 | console.log(message);
40 | console.log(data);
41 |
42 | if (data != null) {
43 | if (data.answer_sentence == null || data.answer_sentence == "\n") {
44 | setAianswer("죄송합니다. 저는 법률 상담을 도와드리는 AI LawBot입니다. 법률 내용 외의 질문은 답변해 드리지 않는 점 참고 부탁드립니다. 법률적인 질문이 있으시다면 언제든지 물어보세요. 제가 최대한 자연스럽고 이해하기 쉽게 답변해 드리겠습니다. 어떤 도움이 필요하신가요?");
45 | setPrecedents(null);
46 | } else {
47 | setAianswer(data.answer_sentence+ans);
48 | setPrecedents(data.similar_precedent);
49 | }
50 | }
51 | }
52 | } catch (error) {
53 | console.error("에러 발생:", error);
54 | } finally {
55 | setLoading(false);
56 | }}
57 | };
58 |
59 | return (
60 |
61 |
62 |
63 |
64 |
65 |
66 |
67 |
68 |
69 |
70 |
71 |
72 | {sentMessage && (
73 |
74 |
75 |
76 | U
77 |
78 |
79 |
80 |
81 |
82 |
83 | )}
84 | {aianswer && (
85 |
86 |
87 |
90 | L
91 |
92 |
95 |
96 |
97 |
98 |
)}
99 | {loading && (
)}
100 |
101 |
102 |
103 |
166 |
167 |
168 |
169 |
170 |
171 |
172 |
173 |
174 |
175 |
176 | );
177 | }
178 |
179 | export default App;
180 |
--------------------------------------------------------------------------------
/prototype/src/App.js:
--------------------------------------------------------------------------------
1 | import Header from './components/Header';
2 | import ChattingSideBar from './components/ChattingSideBar';
3 | import SimilarPrecedent from './components/SimilarPrecedent';
4 | import { useState } from 'react';
5 |
6 | function App() {
7 |
8 | const [message, setMessage] = useState("");
9 | const [sentMessage, setSentMessage] = useState("");
10 | const [aianswer, setAianswer] = useState("")
11 | const [precedents, setPrecedents] = useState([
12 | {
13 | "case_name": "",
14 | "case_number": "",
15 | "case_type": "",
16 | "ref_article": "",
17 | "url": "string"
18 | },
19 | {
20 | "case_name": "",
21 | "case_number": "",
22 | "case_type": "",
23 | "ref_article": "",
24 | "url": "string"
25 | },
26 | {
27 | "case_name": "",
28 | "case_number": "",
29 | "case_type": "",
30 | "ref_article": "",
31 | "url": ""
32 | }
33 | ])
34 |
35 | const messagehandler = async (e) => {
36 | e.preventDefault();
37 | setMessage("");
38 | setPrecedents([
39 | {
40 | "case_name": "",
41 | "case_number": "",
42 | "case_type": "",
43 | "ref_article": "",
44 | "url": "string"
45 | },
46 | {
47 | "case_name": "",
48 | "case_number": "",
49 | "case_type": "",
50 | "ref_article": "",
51 | "url": "string"
52 | },
53 | {
54 | "case_name": "",
55 | "case_number": "",
56 | "case_type": "",
57 | "ref_article": "",
58 | "url": ""
59 | }
60 | ])
61 | setAianswer("")
62 | setSentMessage(message);
63 | const response = await fetch('/generate', {
64 | method: 'POST',
65 | headers: {
66 | 'Content-Type': 'application/json',
67 | },
68 | body: JSON.stringify({ q_sentence: message }), // send the current input: sentMessage is not yet updated in this render
69 | });
70 | const data = await response.json()
71 | setAianswer(data.answer_sentence)
72 | setPrecedents(data.similar_precedent)
73 | };
74 |
75 | return (
76 |
77 |
78 |
79 |
80 |
81 |
82 |
83 |
84 |
85 |
86 |
87 |
88 | {sentMessage && (
89 |
90 |
91 |
92 | U
93 |
94 |
97 |
98 |
99 | )}
100 | {aianswer && (
101 |
102 |
103 |
106 | L
107 |
108 |
113 |
114 |
)}
115 |
116 |
117 |
118 |
121 |
122 |
125 |
132 |
138 |
139 |
140 |
141 |
172 |
173 |
176 | Send
177 |
178 |
185 |
191 |
192 |
193 |
194 |
195 |
196 |
197 |
198 |
199 |
200 |
201 |
202 |
203 |
204 |
205 |
206 | );
207 | }
208 |
209 | export default App;
210 |
--------------------------------------------------------------------------------
/prototype/src/components/ChattingSideBar.js:
--------------------------------------------------------------------------------
1 | function ChattingSideBar() {
2 | return (
3 |
4 |
63 |
64 | )
65 | }
66 | export default ChattingSideBar
67 |
--------------------------------------------------------------------------------
/backend/airflow/dags/training_pipeline.py:
--------------------------------------------------------------------------------
1 | import datetime
2 | import os
3 | import sys
4 |
5 | import pandas as pd
6 | import transformers
7 | from airflow import DAG
8 | from airflow.operators.python import PythonOperator, PythonVirtualenvOperator
9 | from datasets import Dataset, concatenate_datasets
10 | from dotenv import load_dotenv
11 | from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training
12 | from transformers import (AutoModelForCausalLM, AutoTokenizer,
13 | BitsAndBytesConfig)
14 |
15 | sys.path.append(os.path.dirname(os.path.abspath(os.path.dirname(__file__))))
16 |
17 | import torch
18 | from module.load_data import load_train_eval_data
19 | from torch import multiprocessing
20 |
21 | # multiprocessing.set_start_method("forkserver")
22 |
23 | load_dotenv()
24 | huggingface_write_token = os.getenv("HUGGINGFACE_WRITE_TOKEN")
25 |
26 | class Autodata:
27 | def __init__(self, data_path, max_length=1024, tokenizer=None):
28 | self.max_length = max_length
29 | self.tokenizer = tokenizer
30 | self.concat_dataset = self.concat_datasets(data_path)
31 | self.tokenizer_dataset = self.tokenizing_dataset(self.concat_dataset)
32 |
33 | def concat_datasets(self, data_path):
34 | datasets = []
35 | dataset = pd.read_csv(data_path)
36 | dataframe = dataset[["question", "answer"]]
37 | dataset = Dataset.from_pandas(dataframe)
38 | datasets.append(dataset)
39 |
40 | combined_dataset = concatenate_datasets(datasets)
41 |
42 | return combined_dataset
43 |
44 | def tokenizing_dataset(self, dataset):
45 | data = dataset.map(
46 | lambda x: {
47 | "text": f"아래는 작업을 설명하는 명령어입니다. 요청을 적절히 완료하는 응답을 작성하세요.\n\n### 명령어:\n{x['question']}\n\n### 응답:\n{x['answer']}<|endoftext|>"
48 | }
49 | )
50 | data = data.map(
51 | lambda samples: self.tokenizer(
52 | samples["text"],
53 | truncation=True,
54 | max_length=self.max_length,
55 | padding=False,
56 | return_tensors=None,
57 | ),
58 | batched=True,
59 | )
60 |
61 | return data.shuffle()
62 |
63 | def load_model(model_name):
64 | # bnb_config = BitsAndBytesConfig(
65 | # load_in_4bit=True,
66 | # bnb_4bit_use_double_quant=True,
67 | # bnb_4bit_quant_type="nf4",
68 | # bnb_4bit_compute_dtype=torch.bfloat16,
69 | # )
70 | tokenizer = AutoTokenizer.from_pretrained(model_name)
71 | model = AutoModelForCausalLM.from_pretrained(
72 | model_name
73 | )
74 | model.gradient_checkpointing_enable()
75 | model = prepare_model_for_kbit_training(model)
76 |
77 | config = LoraConfig(
78 | r=8,
79 | lora_alpha=32,
80 | target_modules=["query_key_value"],
81 | lora_dropout=0.05,
82 | bias="none",
83 | task_type="CAUSAL_LM",
84 | )
85 |
86 | model = get_peft_model(model, config)
87 | print_trainable_parameters(model)
88 |
89 | return model, tokenizer
90 |
91 |
92 | def print_trainable_parameters(model):
93 | """
94 | Prints the number of trainable parameters in the model.
95 | """
96 | trainable_params = 0
97 | all_param = 0
98 | for _, param in model.named_parameters():
99 | all_param += param.numel()
100 | if param.requires_grad:
101 | trainable_params += param.numel()
102 | print(
103 | f"trainable params: {trainable_params} || all params: {all_param} || trainable%: {100 * trainable_params / all_param}"
104 | )
105 |
106 | def train_model():
107 | import datetime
108 | import os
109 | import sys
110 |
111 | import pandas as pd
112 | import torch
113 | import transformers
114 | from airflow import DAG
115 | from airflow.operators.python import (PythonOperator,
116 | PythonVirtualenvOperator)
117 | from datasets import Dataset, concatenate_datasets
118 | from dotenv import load_dotenv
119 | from peft import (LoraConfig, get_peft_model,
120 | prepare_model_for_kbit_training)
121 | from transformers import (AutoModelForCausalLM, AutoTokenizer,
122 | BitsAndBytesConfig)
123 |
124 | class Autodata:
125 | def __init__(self, data_path, max_length=1024, tokenizer=None):
126 | self.max_length = max_length
127 | self.tokenizer = tokenizer
128 | self.concat_dataset = self.concat_datasets(data_path)
129 | self.tokenizer_dataset = self.tokenizing_dataset(self.concat_dataset)
130 |
131 | def concat_datasets(self, data_path):
132 | datasets = []
133 | dataset = pd.read_csv(data_path)
134 | dataframe = dataset[["question", "answer"]]
135 | dataset = Dataset.from_pandas(dataframe)
136 | datasets.append(dataset)
137 |
138 | combined_dataset = concatenate_datasets(datasets)
139 |
140 | return combined_dataset
141 |
142 | def tokenizing_dataset(self, dataset):
143 | data = dataset.map(
144 | lambda x: {
145 | "text": f"아래는 작업을 설명하는 명령어입니다. 요청을 적절히 완료하는 응답을 작성하세요.\n\n### 명령어:\n{x['question']}\n\n### 응답:\n{x['answer']}<|endoftext|>"
146 | }
147 | )
148 | data = data.map(
149 | lambda samples: self.tokenizer(
150 | samples["text"],
151 | truncation=True,
152 | max_length=self.max_length,
153 | padding=False,
154 | return_tensors=None,
155 | ),
156 | batched=True,
157 | )
158 |
159 | return data.shuffle()
160 |
161 | model_id = "nlpai-lab/kullm-polyglot-5.8b-v2"
162 | tokenizer = AutoTokenizer.from_pretrained(model_id)
163 | bnb_config = BitsAndBytesConfig(
164 | load_in_4bit=True,
165 | bnb_4bit_use_double_quant=True,
166 | bnb_4bit_quant_type="nf4",
167 | bnb_4bit_compute_dtype=torch.bfloat16,
168 | )
169 | model = AutoModelForCausalLM.from_pretrained(
170 | model_id, quantization_config=bnb_config, device_map={"": 0}
171 | )
172 | model.gradient_checkpointing_enable()
173 | model = prepare_model_for_kbit_training(model)
174 |
175 | config = LoraConfig(
176 | r=8,
177 | lora_alpha=32,
178 | target_modules=["query_key_value"],
179 | lora_dropout=0.05,
180 | bias="none",
181 | task_type="CAUSAL_LM",
182 | )
183 |
184 | model = get_peft_model(model, config)
185 | tokenizer.pad_token = tokenizer.eos_token
186 | # BASE_DIR = os.path.join(os.path.dirname(os.path.abspath(os.path.dirname(__file__))), "data")
187 | BASE_DIR = os.path.join("/opt/ml/level3_nlp_finalproject-nlp-08/backend/airflow", "data")
188 | TRAIN_DATA_PATH = os.path.join(BASE_DIR, "train_data.csv")
189 | EVAL_DATA_PATH = os.path.join(BASE_DIR, "eval_data.csv")
190 | train_data = Autodata(data_path=TRAIN_DATA_PATH, tokenizer=tokenizer).tokenizer_dataset
191 | val_data = Autodata(data_path=EVAL_DATA_PATH, tokenizer=tokenizer).tokenizer_dataset
192 | trainer = transformers.Trainer(
193 | model=model,
194 | train_dataset=train_data,
195 | eval_dataset=val_data,
196 | args=transformers.TrainingArguments(
197 | per_device_train_batch_size=16,
198 | gradient_accumulation_steps=1,
199 | num_train_epochs=6,
200 | learning_rate=1e-4,
201 | fp16=True,
202 | logging_steps=10,
203 | save_strategy="epoch",
204 | evaluation_strategy="epoch",
205 | output_dir="./model_outputs",
206 | optim="paged_adamw_8bit",
207 | ),
208 | data_collator=transformers.DataCollatorForLanguageModeling(
209 | tokenizer, mlm=False
210 | ),
211 | )
212 | model.config.use_cache = (
213 | False # silence the warnings. Please re-enable for inference!
214 | )
215 | trainer.train()
216 |
217 | push_model_id = "YoonSeul/LawBot-airflow-test"
218 |
219 | model.push_to_hub(  # resolve the token here: PythonVirtualenvOperator runs this function in a fresh interpreter, so module-level globals like huggingface_write_token are unavailable (call load_dotenv() first if the worker env lacks the variable)
220 | push_model_id, use_temp_dir=True, use_auth_token=os.getenv("HUGGINGFACE_WRITE_TOKEN")
221 | )
222 | print(f"{push_model_id} 모델 업로드 완료!")
223 |
224 | with DAG(
225 | dag_id="training_pipeline",
226 | description="train the model periodically",
227 | start_date=datetime.datetime(2023, 7, 27),
228 | schedule_interval="0 0 * * 5",  # 00:00 every Friday
229 | tags=["LLM"],
230 | ) as dag:
231 |
232 | load_data = PythonOperator(
233 | task_id="load_data",
234 | python_callable=load_train_eval_data,
235 | depends_on_past=True,
236 | owner="SangwonYoon",
237 | retries=3,
238 | retry_delay=datetime.timedelta(minutes=5)
239 | )
240 |
241 | training_model = PythonVirtualenvOperator(
242 | task_id="train_model",
243 | python_callable=train_model,
244 | depends_on_past=True,
245 | owner="SangwonYoon",
246 | retries=3,
247 | retry_delay=datetime.timedelta(minutes=5),
248 | system_site_packages=True,
249 | # requirements=["torch==2.0.1"],
250 | )
251 |
252 | # test = PythonVirtualenvOperator(
253 | # task_id="test",
254 | # python_callable=test_torch,
255 | # depends_on_past=True,
256 | # owner="SangwonYoon",
257 | # retries=3,
258 | # retry_delay=datetime.timedelta(minutes=1),
259 | # system_site_packages=True,
260 | # # requirements=["torch==2.0.1"],
261 | # python_version=3.8,
262 | # )
263 |
264 | load_data >> training_model
265 | # test
266 |
267 |
268 |
269 |
--------------------------------------------------------------------------------
/model/LLM/evaluation/petf_ppl.py:
--------------------------------------------------------------------------------
1 | # Copyright 2022 The HuggingFace Datasets Authors and the current dataset script contributor.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | """Perplexity Metric."""
15 |
16 | import datasets
17 | import numpy as np
18 | import torch
19 | from peft import PeftConfig, PeftModel
20 | from torch.nn import CrossEntropyLoss
21 | from transformers import AutoModelForCausalLM, AutoTokenizer
22 |
23 | import evaluate
24 | from evaluate import logging
25 |
26 |
27 | _CITATION = """\
28 | """
29 |
30 | _DESCRIPTION = """
31 | Perplexity (PPL) is one of the most common metrics for evaluating language models.
32 | It is defined as the exponentiated average negative log-likelihood of a sequence, calculated with exponent base `e`.
33 | For more information, see https://huggingface.co/docs/transformers/perplexity
34 | """
35 |
36 | _KWARGS_DESCRIPTION = """
37 | Args:
38 | model_id (str): model used for calculating Perplexity
39 | NOTE: Perplexity can only be calculated for causal language models.
40 | This includes models such as gpt2, causal variations of bert,
41 | causal versions of t5, and more (the full list can be found
42 | in the AutoModelForCausalLM documentation here:
43 | https://huggingface.co/docs/transformers/master/en/model_doc/auto#transformers.AutoModelForCausalLM )
44 | predictions (list of str): input text, each separate text snippet
45 | is one list entry.
46 | batch_size (int): the batch size to run texts through the model. Defaults to 16.
47 | add_start_token (bool): whether to add the start token to the texts,
48 | so the perplexity can include the probability of the first word. Defaults to True.
49 | device (str): device to run on, defaults to 'cuda' when available
50 | Returns:
51 | perplexity: dictionary containing the perplexity scores for the texts
52 | in the input list, as well as the mean perplexity. If one of the input texts is
53 | longer than the max input length of the model, then it is truncated to the
54 | max length for the perplexity computation.
55 | Examples:
56 | Example 1:
57 | >>> perplexity = evaluate.load("perplexity", module_type="metric")
58 | >>> input_texts = ["lorem ipsum", "Happy Birthday!", "Bienvenue"]
59 | >>> results = perplexity.compute(model_id='gpt2',
60 | ... add_start_token=False,
61 | ... predictions=input_texts) # doctest:+ELLIPSIS
62 | >>> print(list(results.keys()))
63 | ['perplexities', 'mean_perplexity']
64 | >>> print(round(results["mean_perplexity"], 0))
65 | 647.0
66 | >>> print(round(results["perplexities"][0], 0))
67 | 32.0
68 | Example 2:
69 | >>> from datasets import load_dataset
70 | >>> perplexity = evaluate.load("perplexity", module_type="metric")
71 | >>> input_texts = load_dataset("wikitext", "wikitext-2-raw-v1", split="test")["text"][:10] # doctest: +SKIP
72 | >>> input_texts = [s for s in input_texts if s!='']
73 | >>> results = perplexity.compute(model_id='gpt2',
74 | ... predictions=input_texts)
75 | >>> print(list(results.keys()))
76 | ['perplexities', 'mean_perplexity']
77 | >>> print(round(results["mean_perplexity"], 2)) # doctest: +SKIP
78 | 576.76
79 | >>> print(round(results["perplexities"][0], 2)) # doctest: +SKIP
80 | 889.28
81 | """
82 |
83 |
84 | @evaluate.utils.file_utils.add_start_docstrings(_DESCRIPTION, _KWARGS_DESCRIPTION)
85 | class Perplexity_Petf(evaluate.Metric):
86 | def _info(self):
87 | return evaluate.MetricInfo(
88 | module_type="metric",
89 | description=_DESCRIPTION,
90 | citation=_CITATION,
91 | inputs_description=_KWARGS_DESCRIPTION,
92 | features=datasets.Features(
93 | {
94 | "predictions": datasets.Value("string"),
95 | }
96 | ),
97 | reference_urls=["https://huggingface.co/docs/transformers/perplexity"],
98 | )
99 |
100 | def _compute(
101 | self,
102 | predictions,
103 | model_id,
104 | batch_size: int = 4,
105 | add_start_token: bool = True,
106 | device=None,
107 | max_length=None,
108 | ):
109 | if device is not None:
110 | assert device in [
111 | "gpu",
112 | "cpu",
113 | "cuda",
114 | ], "device should be one of cpu, gpu or cuda."
115 | if device == "gpu":
116 | device = "cuda"
117 | else:
118 | device = "cuda" if torch.cuda.is_available() else "cpu"
119 |
120 | config = PeftConfig.from_pretrained(model_id)
121 | model = AutoModelForCausalLM.from_pretrained(
122 | config.base_model_name_or_path, device_map={"": 0}
123 | )
124 | model = PeftModel.from_pretrained(
125 | model, model_id
126 | )  # attach the trained LoRA adapter weights on top of the base model
127 |
128 | model = model.to(device)
129 |
130 | tokenizer = AutoTokenizer.from_pretrained(config.base_model_name_or_path)
131 |
132 | # if batch_size > 1 (which generally leads to padding being required), and
133 | # if there is not an already assigned pad_token, assign an existing
134 | # special token to also be the padding token
135 | if tokenizer.pad_token is None and batch_size > 1:
136 | existing_special_tokens = list(
137 | tokenizer.special_tokens_map_extended.values()
138 | )
139 | # check that the model already has at least one special token defined
140 | assert (
141 | len(existing_special_tokens) > 0
142 | ), "If batch_size > 1, model must have at least one special token to use for padding. Please use a different model or set batch_size=1."
143 | # assign one of the special tokens to also be the pad token
144 | tokenizer.add_special_tokens({"pad_token": existing_special_tokens[0]})
145 |
146 | if add_start_token and max_length:
147 | # leave room for token to be added:
148 | assert (
149 | tokenizer.bos_token is not None
150 | ), "Input model must already have a BOS token if using add_start_token=True. Please use a different model, or set add_start_token=False"
151 | max_tokenized_len = max_length - 1
152 | else:
153 | max_tokenized_len = max_length
154 |
155 | encodings = tokenizer(
156 | predictions,
157 | add_special_tokens=False,
158 | padding=True,
159 | truncation=True if max_tokenized_len else False,
160 | max_length=max_tokenized_len,
161 | return_tensors="pt",
162 | return_attention_mask=True,
163 | ).to(device)
164 |
165 | encoded_texts = encodings["input_ids"]
166 | attn_masks = encodings["attention_mask"]
167 |
168 | # check that each input is long enough:
169 | if add_start_token:
170 | assert torch.all(
171 | torch.ge(attn_masks.sum(1), 1)
172 | ), "Each input text must be at least one token long."
173 | else:
174 | assert torch.all(
175 | torch.ge(attn_masks.sum(1), 2)
176 | ), "When add_start_token=False, each input text must be at least two tokens long. Run with add_start_token=True if inputting strings of only one token, and remove all empty input strings."
177 |
178 | ppls = []
179 | loss_fct = CrossEntropyLoss(reduction="none")
180 |
181 | for start_index in logging.tqdm(range(0, len(encoded_texts), batch_size)):
182 | end_index = min(start_index + batch_size, len(encoded_texts))
183 | encoded_batch = encoded_texts[start_index:end_index]
184 | attn_mask = attn_masks[start_index:end_index]
185 |
186 | if add_start_token:
187 | bos_tokens_tensor = torch.tensor(
188 | [[tokenizer.bos_token_id]] * encoded_batch.size(dim=0)
189 | ).to(device)
190 | encoded_batch = torch.cat([bos_tokens_tensor, encoded_batch], dim=1)
191 | attn_mask = torch.cat(
192 | [
193 | torch.ones(bos_tokens_tensor.size(), dtype=torch.int64).to(
194 | device
195 | ),
196 | attn_mask,
197 | ],
198 | dim=1,
199 | )
200 |
201 | labels = encoded_batch
202 |
203 | with torch.no_grad():
204 | out_logits = model(encoded_batch, attention_mask=attn_mask).logits
205 |
206 | shift_logits = out_logits[..., :-1, :].contiguous()
207 | shift_labels = labels[..., 1:].contiguous()
208 | shift_attention_mask_batch = attn_mask[..., 1:].contiguous()
209 |
210 | perplexity_batch = torch.exp(
211 | (
212 | loss_fct(shift_logits.transpose(1, 2), shift_labels)
213 | * shift_attention_mask_batch
214 | ).sum(1)
215 | / shift_attention_mask_batch.sum(1)
216 | )
217 |
218 | ppls += perplexity_batch.tolist()
219 |
220 | return {"perplexities": ppls, "mean_perplexity": np.mean(ppls)}
221 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 |
2 |
3 | Legal Advice Web Service ‘LawBot’
4 |
5 |
6 |
7 |
8 |
9 |
10 |
11 |
12 |
13 | >LawBot provides light legal counseling together with similar precedents and relevant statutes. Combining the broad pre-trained knowledge of an LLM with the legal data used for fine-tuning, it can respond flexibly to diverse and unusual situations, and a similarity-based AI model retrieves precedents related to the matter at hand.
14 |
15 |
16 |
17 | 🖥 **[Try the LawBot web service](http://yoonseul.link)**
18 |
19 | **※ This web service is available only until August 18, 2023, when the post-session period ends. Service improvements and performance tuning will continue until then.**
20 |
21 |
22 |
23 | ## ⌘ Project Background
24 |
25 | ### Motivation and Expected Impact
26 |
27 |
28 |
29 |
30 |
31 | * **`Background`**: The legal-tech industry is growing rapidly worldwide, and demand for such services keeps rising in Korea as well. Yet legal terminology and statutory language are hard to interpret, so they remain poorly understood by the general public, and obtaining relevant information is expensive.
32 |
33 | * **`Goal`**: To address this, team Yoonseul aims to lower the barrier to the law with a light legal counseling service that offers easy-to-understand guidance on legal situations and provides related precedents and statutes.
34 |
35 |
36 |
37 | ### What Makes LawBot Different
38 |
39 | - Most existing legal-tech services either **match users with lawyers or only offer direct search over statutes and precedents that remain hard to read**. LawBot is different: an AI model finds precedents similar to the user's legal dispute and quickly generates light guidance for the user directly.
40 |
41 |
42 |
43 | ## ⚙️ Use Case
44 |
45 | 
46 |
47 |
48 |
49 | >1. Open the web service
50 | >2. Describe your legal dispute and any related questions in the message prompt
51 | >3. The AI model reads the context of the situation and replies with guidance as a message
52 | >4. The right sidebar shows up to three precedents similar to the situation you entered
53 | >5. Links take you directly to the relevant statutes and other legal references
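The frontend performs steps 3-5 with a single POST to `/generate`, sending `{ "q_sentence": ... }` and reading `answer_sentence` and `similar_precedent` from the JSON response. A minimal sketch of the same call (the host and port are assumptions, not the deployed address):

```python
import requests

# Hypothetical local address; the payload and response keys mirror frontend/src/App.js.
resp = requests.post(
    "http://localhost:8000/generate",
    json={"q_sentence": "전세 계약이 만료됐는데 보증금을 돌려받지 못하고 있습니다."},
    timeout=60,
)
data = resp.json()
print(data["answer_sentence"])     # generated guidance message
print(data["similar_precedent"])   # up to three similar precedents
```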
54 |
55 |
56 |
57 |
58 | ## 🧑🏻💻 Team Introduction & Members
59 |
60 | ### 💬 About the Team
61 | >**We are team `Yoonseul`, a team that pursues harmony and sustainable growth!** **Every member takes `ownership` of all work and `participates actively` regardless of their assigned role. We think hard about being good colleagues, and we keep learning and growing.**
62 |
63 |
64 |
65 | ### 👨🏼💻 Members
66 | 강민재|김주원|김태민|신혁준|윤상원|
67 | :-:|:-:|:-:|:-:|:-:
71 |
72 |
73 |
74 | ### 👨🏼💻 Role Assignments
75 |
76 |
77 |
78 |
79 | ## 🗓️ Project Timeline
80 | - We built the core features first, then took an agile approach, gathering the team's feedback and iterating continuously to improve the service.
81 |
82 |
83 |
84 |
85 |
86 | ## ⌘ Service Architecture
87 |
88 |
89 |
90 |
91 | >To make full use of the four V100 servers provided by the NAVER Connect Foundation, we designed around a **microservice architecture that scales easily** from the start. To keep services loosely coupled so that a failure on one server cannot take down the whole system, we separated the web server and the model servers as shown above and had them communicate over APIs. As a result, even if one V100 server fails, the remaining services keep running unaffected.
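As a minimal sketch of this separation (the service URLs, payload shape, and naive failover loop are illustrative assumptions, not the deployed configuration), the web tier holds no model weights and simply proxies `/generate` to the first model server that responds:

```python
from fastapi import FastAPI, HTTPException
import httpx

app = FastAPI()
MODEL_SERVERS = ["http://model-server-1:8000", "http://model-server-2:8000"]  # hypothetical

@app.post("/generate")
async def generate(payload: dict):
    # Forward the request to the first healthy model server; if one V100 host
    # is down, the next one still serves traffic.
    async with httpx.AsyncClient(timeout=60.0) as client:
        for base_url in MODEL_SERVERS:
            try:
                resp = await client.post(f"{base_url}/generate", json=payload)
                resp.raise_for_status()
                return resp.json()
            except httpx.HTTPError:
                continue  # this server failed; try the next one
    raise HTTPException(status_code=503, detail="All model servers are unavailable")
```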
92 |
93 |
94 |
95 | - Below are the features we implemented, with links to the details. Each link explains why the feature was built and what we focused on while developing it.
96 |
97 | - 🛠️ [CI Pipeline Setup](https://uomnf97.notion.site/CI-f687f03b192f49fa80d451f8850a03f6?pvs=4)
98 | - ✍🏻 [Load Balancing](https://uomnf97.notion.site/96f697aab756407aadbe51582a0a68d4?pvs=4)
99 | - ✍🏻 [Model Training Pipeline with Airflow](https://uomnf97.notion.site/64a55c1e1f4a4ff985343a97b224a101?pvs=4)
100 | - 🛠️ [Failover via Auto Scaling](https://uomnf97.notion.site/Auto-Scaling-Failover-fa0ab424dcda44739ababe1eb719a106?pvs=4)
101 |
102 |
103 |
104 | ## 💿 Data
105 |
106 |
107 | - As shown above, we split the data work into stages with goals for each, and built the training dataset through data exploration, collection, EDA and preprocessing, and augmentation with a generative model.
108 |
109 |
110 |
111 |
112 |
113 | - The diagram above shows the cyclical DMOps structure of the data pipeline.
114 |
115 |
116 |
117 |
118 | ### 1️⃣ Legal QA Data
119 |
120 | | Dataset | Count | Source |
121 | | :---: | :---: | :---: |
122 | | easylaw_kr | 2,195 | https://huggingface.co/datasets/juicyjung/easylaw_kr |
123 | | LegalQA | 1,830 | https://github.com/haven-jeon/LegalQA |
124 | | Korea Legal Aid Corporation legal counseling cases | 9,994 | https://www.klac.or.kr/legalinfo/counsel.do |
125 | | Korea Legal Aid Corporation cyber counseling data | 2,463 | https://www.klac.or.kr/legalstruct/cyberConsultation.do |
126 | | OpenAI GPT augmentation | 8,666 | - |
127 |
128 |
129 |
130 | ### 2️⃣ Precedent Data
131 |
132 | | Dataset | Count | Source |
133 | | :---: | :---: | :---: |
134 | | Text analysis data for statutes/regulations (rulings, terms of service, etc.) | 77,382 | https://www.aihub.or.kr/aihubdata/data/view.do?currMenu=115&topMenu=100&aihubDataSe=realm&dataSetSn=580 |
135 |
136 |
137 |
138 | - Details on data preprocessing, EDA, and augmentation are available below.
139 |
140 | - 💿 [Data Collection](https://uomnf97.notion.site/a5f4628cdc7b4ce3928ce626c727ff32?pvs=4)
141 | - 📈 [EDA & Preprocessing](https://uomnf97.notion.site/EDA-b53d75aaa7574ea586cbd6cdbd5c755a?pvs=4)
142 | - 🛠️ [Data Augmentation with Generative Models](https://uomnf97.notion.site/5e8e9ecc27694d7497fbad68f72136c0?pvs=4)
143 |
144 |
145 |
146 | ## 📊 Model
147 |
148 | ### Overview
149 |
150 |
151 |
152 |
153 | > When a user submits input, the Question Filtering model first determines whether it is a legal question. If it is, the Similar Precedent Model and the Law LLM Model produce similar precedents and legal guidance.
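In pseudocode, the flow looks roughly like this (the three functions are hypothetical stand-ins for the models described above, with placeholder bodies; the real entry points live under `backend/app/`):

```python
def is_legal_question(q: str) -> bool:
    """Question Filtering model (placeholder heuristic)."""
    return len(q.strip()) > 5

def retrieve_similar_precedents(q: str, top_k: int = 3) -> list:
    """Similar Precedent model (placeholder)."""
    return []

def generate_guidance(q: str) -> str:
    """Fine-tuned Law LLM (placeholder)."""
    return "..."

def answer(q_sentence: str) -> dict:
    # Non-legal questions are filtered out before any expensive model call.
    if not is_legal_question(q_sentence):
        return {"answer_sentence": None, "similar_precedent": None}
    return {
        "answer_sentence": generate_guidance(q_sentence),
        "similar_precedent": retrieve_similar_precedents(q_sentence),
    }
```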
154 |
155 |
156 |
157 | ### 1️⃣ Question Filtering Model
158 |
159 |
160 |
161 |
162 | ### 2️⃣ Similar Precedent Model
163 |
164 |
165 |
166 |
167 | ### 3️⃣ LLM Model
168 |
169 |
170 |
171 |
172 | - List of the models we built (training data + backbone model)
173 |
174 | #### Evaluation Metrics
175 | - There is no standard metric for evaluating an LLM on this task, so we built our own. Because legal accuracy matters in this domain, we designed the metric to score both legal accuracy and linguistic naturalness.
176 |
177 |
178 |
179 | **[ Dialogue Evaluation Metric ]**
180 |
181 | - We adapted the Dialogue Evaluation Metric from the Kullm model to our domain and commissioned a lawyer to score the answers against it. All models were evaluated blind, labeled Model A, B, C, and so on, and to anchor the scale we also evaluated the ChatGPT and BARD models alongside ours. Our **kfkas/legal-llama-2-ko-7b-Chat** model performed best overall.
182 |
183 |
184 |
185 |
186 |
187 | **[ Perplexity ]**
188 | - We used perplexity to measure how well the generative model produces legal terminology. Lower values are better, and models trained on the crawled and curated data generally tended to perform well on this metric.
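For reference, the quantity computed in `model/LLM/evaluation/petf_ppl.py` is the exponentiated average negative log-likelihood (base `e`), averaged over the non-padding tokens of each sequence:

```math
\mathrm{PPL}(X) = \exp\left(-\frac{1}{t}\sum_{i=1}^{t}\log p_\theta\left(x_i \mid x_{<i}\right)\right)
```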
189 |
190 |
191 |
192 |
193 | ## 💻 Getting Started
194 |
195 | > The READMEs below explain how to run the project yourself and where to find the implementation.
196 |
197 | ### 📊 Model
198 | - [Model](model) / [README.md](model/README.md)
199 |
200 | ### 💽 Data
201 | - [Data Pipeline](data_pipeline) / [README.md](data_pipeline/README.md)
202 |
203 | ### 🎨 Frontend
204 | - [Frontend](frontend) / [README.md](frontend/README.md)
205 |
206 | ### 💻 Backend
207 | - [Backend](backend) / [README.md](backend/README.md)
208 |
209 |
210 |
211 | ## 📚 Further Information
212 |
213 | ### 1️⃣ Tech Stack and Development Environment
214 |
215 |
216 |
217 |
218 | ### 2️⃣ Collaboration Tools
219 |
220 | - Notion:
221 | - Managed the to-do list systematically with a Kanban board
222 | - Standardized meeting notes and records in a shared Notion workspace
223 |
224 | - GitHub:
225 | - Branching strategy based on GitHub Flow
226 | - Systematic management with PR and Issue templates
227 | - Ground rules kept everyone's commit conventions consistent
228 |
229 |
230 |
231 | ### 3️⃣ Links
232 |
233 | - [Development Notion workspace](https://uomnf97.notion.site/NLP-08-LawBot-b2dfef92f666458583d6b459af53aa66?pvs=4)
234 | - [YouTube presentation](https://www.youtube.com/watch?v=fgboxtWM4B4)
235 | - [Presentation slides](https://github.com/boostcampaitech5/level3_nlp_finalproject-nlp-08/files/12394029/NLP_08_.LawBot.pdf)
236 |
237 |
--------------------------------------------------------------------------------