├── .gitignore
├── README.md
├── data
│   ├── README.md
│   ├── caption_data
│   │   ├── MSCOCO_train_val_Korean.json
│   │   ├── data_download.sh
│   │   ├── dataset_coco_kor.json
│   │   ├── train2014
│   │   │   └── COCO_train2014_000000000009.jpg
│   │   └── valid2014
│   │       └── COCO_val2014_000000000042.jpg
│   └── poem_data
│       ├── crawl
│       │   ├── crawl.py
│       │   └── poem_crawler
│       │       ├── data_crawl.sh
│       │       ├── poem_crawler
│       │       │   ├── __init__.py
│       │       │   ├── items.py
│       │       │   ├── middlewares.py
│       │       │   ├── pipelines.py
│       │       │   ├── settings.py
│       │       │   └── spiders
│       │       │       ├── __init__.py
│       │       │       └── spider.py
│       │       └── scrapy.cfg
│       ├── preprocess_data
│       │   ├── poem.csv
│       │   ├── preprocess(extract_keyword).ipynb
│       │   └── preprocess_poems.ipynb
│       └── raw_data
│           └── raw.txt
├── evaluation
│   ├── scoring_with_bert
│   │   ├── __pycache__
│   │   │   ├── arguments.cpython-37.pyc
│   │   │   ├── arguments.cpython-39.pyc
│   │   │   ├── model.cpython-39.pyc
│   │   │   ├── train.cpython-39.pyc
│   │   │   └── utils.cpython-39.pyc
│   │   ├── arguments.py
│   │   ├── demo.ipynb
│   │   ├── main.py
│   │   ├── model.py
│   │   ├── train.py
│   │   └── utils.py
│   └── sentencebert
│       ├── arguments.py
│       ├── demo.ipynb
│       ├── main.py
│       ├── model.py
│       └── train.py
├── model
│   ├── caption_model
│   │   ├── utils.py
│   │   └── vit_gpt2
│   │       ├── dataset.py
│   │       ├── run_train.py
│   │       ├── tt
│   │       └── utils.py
│   ├── gpt2_base_train.py
│   ├── poem_model
│   │   ├── gpt2_base
│   │   │   ├── dataset.py
│   │   │   ├── run_train.py
│   │   │   ├── t
│   │   │   └── utils.py
│   │   ├── gpt2_trinity
│   │   │   ├── arguments.py
│   │   │   ├── t
│   │   │   ├── train.py
│   │   │   └── utils.py
│   │   └── utils.py
│   └── vit_gpt2_train.py
├── requirements.txt
├── show_attend_and_tell
│   ├── README.md
│   ├── caption.py
│   ├── create_input_files.py
│   ├── datasets.py
│   ├── eval.py
│   ├── model.py
│   ├── requirements.txt
│   ├── train.py
│   └── utils.py
└── web
    ├── app.py
    ├── db_utils.py
    ├── package-lock.json
    ├── utils.py
    └── web
        ├── assets
        │   └── images
        │       └── profile.jpeg
        └── templates
            ├── about.html
            ├── js.js
            ├── layout.html
            └── responsive.html

/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 | 
6 | # C extensions
7 | *.so
8 | 
9 | # Distribution / packaging
10 | .Python
11 | build/
12 | develop-eggs/
13 | dist/
14 | downloads/
15 | eggs/
16 | .eggs/
17 | lib/
18 | lib64/
19 | parts/
20 | sdist/
21 | var/
22 | wheels/
23 | pip-wheel-metadata/
24 | share/python-wheels/
25 | *.egg-info/
26 | .installed.cfg
27 | *.egg
28 | MANIFEST
29 | 
30 | # PyInstaller
31 | # Usually these files are written by a python script from a template
32 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
33 | *.manifest
34 | *.spec
35 | 
36 | # Installer logs
37 | pip-log.txt
38 | pip-delete-this-directory.txt
39 | 
40 | # Unit test / coverage reports
41 | htmlcov/
42 | .tox/
43 | .nox/
44 | .coverage
45 | .coverage.*
46 | .cache
47 | nosetests.xml
48 | coverage.xml
49 | *.cover
50 | *.py,cover
51 | .hypothesis/
52 | .pytest_cache/
53 | 
54 | # Translations
55 | *.mo
56 | *.pot
57 | 
58 | # Django stuff:
59 | *.log
60 | local_settings.py
61 | db.sqlite3
62 | db.sqlite3-journal
63 | 
64 | # Flask stuff:
65 | instance/
66 | .webassets-cache
67 | 
68 | # Scrapy stuff:
69 | .scrapy
70 | 
71 | # Sphinx documentation
72 | docs/_build/
73 | 
74 | # PyBuilder
75 | target/
76 | 
77 | # Jupyter Notebook
78 | .ipynb_checkpoints
79 | 
80 | # IPython
81 | profile_default/
82 | ipython_config.py
83 | 
84 | # pyenv
85 | .python-version
86 | 
87 | # pipenv
88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
91 | # install all needed dependencies.
92 | #Pipfile.lock
93 | 
94 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow
95 | __pypackages__/
96 | 
97 | # Celery stuff
98 | celerybeat-schedule
99 | celerybeat.pid
100 | 
101 | # SageMath parsed files
102 | *.sage.py
103 | 
104 | # Environments
105 | .env
106 | .venv
107 | env/
108 | venv/
109 | ENV/
110 | env.bak/
111 | venv.bak/
112 | 
113 | # Spyder project settings
114 | .spyderproject
115 | .spyproject
116 | 
117 | # Rope project settings
118 | .ropeproject
119 | 
120 | # mkdocs documentation
121 | /site
122 | 
123 | # mypy
124 | .mypy_cache/
125 | .dmypy.json
126 | dmypy.json
127 | 
128 | # Pyre type checker
129 | .pyre/
130 | 
131 | # vscode
132 | .vscode/
133 | 
134 | # Data & Model
135 | data/
136 | 
137 | outputs/
138 | 
139 | wandb/
140 | finetuned/
141 | checkpoints/
142 | dataset/
143 | data/caption_data/MSCOCO_train_val_Korean.json
144 | 
145 | # css
146 | web/web/semantic/
147 | web/semantic.json
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | ### Team Cheonggyesan Sherpa (청계산 셰르파)
2 | 
3 | # Look, Attend and Generate Poem
4 | 
5 | An AI "sentimental poet" service that writes a poem for a photo. This is the final project of team Cheonggyesan Sherpa from the 2nd cohort of NAVER Connect Foundation boostcamp AI Tech: a user uploads an image, the service generates a poem that fits the image, and the result can be downloaded or shared as a card.
6 | 
7 | ## Team & Roles
8 | 
9 | ||||||||
10 | | :--------: | :--------: | :--------: | :--------: | :--------: | :--------: | :--------: |
11 | |[T2011] 곽진성 <br> [@jskwak98](https://github.com/jskwak98)|[T2025] 김민수 <br> [@lexiconium](https://github.com/lexiconium)|[T2076] 문하겸 <br> [@ddobokki](https://github.com/ddobokki)|[T2166] 이요한 <br> [@l-yohai](https://github.com/l-yohai)|[T2195] 전준영 <br> [@20180707jun](https://github.com/20180707jun)|[T2206] 정진원 <br> [@godjw](https://github.com/godjw)|[T2210] 정희영 <br> [@hyeong01](https://github.com/hyeong01)|
12 | |Data collection & preprocessing|Data collection & preprocessing|Data collection & preprocessing|Data collection & preprocessing|Data collection & preprocessing|Data collection & preprocessing|Data collection & preprocessing|
13 | |Data analysis|Generation-model modeling|Vision Encoder Decoder model training|Modeling & baseline code|Service architecture & model serving|Captioning-model training on Korean data|Data analysis|
14 | |Poem-generation model training & improvement|Poem-generation model training & improvement|Poem-generation model training|Service architecture & UI/UX design|Website & API design, UI/UX design|Poem-generation model training & improvement|R&D of model evaluation methodology|
15 | 
16 | ## Installation
17 | ```
18 | pip install -r requirements.txt
19 | ```
20 | 
21 | ## Architecture
22 | 
23 | ![](https://i.imgur.com/5BkTjCf.png)
24 | 
25 | ## Usage
26 | 
27 | ### Crawl
28 | 
29 | ```bash
30 | python data/poem_data/crawl/crawl.py
31 | ```
32 | 
33 | ### Train
34 | 
35 | **Caption Model**
36 | ```bash
37 | python model/vit_gpt2_train.py
38 | ```
39 | The Vision Encoder Decoder weights we trained and use in the service are published [here](https://huggingface.co/ddobokki/vision-encoder-decoder-vit-gpt2-coco-ko).
40 | 
41 | Show, Attend and Tell-style captioning was not used in the final service, but if you would like to try it, see [here](https://github.com/boostcampaitech2/final-project-level3-nlp-08/tree/dev/merge/show_attend_and_tell).
42 | 
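The training script assumes the caption data has been downloaded into `data/caption_data/` (see the repository tree above). A sketch of the expected layout after running `data/caption_data/data_download.sh` and unzipping; the exact paths consumed by the training code may differ:

```
data/caption_data/
├── MSCOCO_train_val_Korean.json   # Korean captions (AI Hub)
├── dataset_coco_kor.json
├── train2014/                     # COCO train2014 images
└── valid2014/                     # COCO val2014 images
```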
43 | **Poem Model**
44 | ```bash
45 | # gpt2 base
46 | python model/gpt2_base_train.py
47 | ```
48 | The poem-generator weights we trained and use in the service are published [here](https://huggingface.co/ddobokki/gpt2_poem) and [here](https://huggingface.co/CheonggyeMountain-Sherpa/kogpt-trinity-poem).
49 | 
50 | ### Inference
51 | **Caption Model**
52 | 
53 | ```python
54 | import requests
55 | import torch
56 | from PIL import Image
57 | from transformers import (
58 |     VisionEncoderDecoderModel,
59 |     ViTFeatureExtractor,
60 |     PreTrainedTokenizerFast,
61 | )
62 | 
63 | # device setting
64 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
65 | 
66 | # load feature extractor and tokenizer
67 | encoder_model_name_or_path = "ddobokki/vision-encoder-decoder-vit-gpt2-coco-ko"
68 | feature_extractor = ViTFeatureExtractor.from_pretrained(encoder_model_name_or_path)
69 | tokenizer = PreTrainedTokenizerFast.from_pretrained(encoder_model_name_or_path)
70 | 
71 | # load model
72 | model = VisionEncoderDecoderModel.from_pretrained(encoder_model_name_or_path)
73 | model.to(device)
74 | 
75 | # inference
76 | url = 'http://images.cocodataset.org/val2017/000000039769.jpg'
77 | with Image.open(requests.get(url, stream=True).raw) as img:
78 |     pixel_values = feature_extractor(images=img, return_tensors="pt").pixel_values
79 | 
80 | generated_ids = model.generate(pixel_values.to(device), num_beams=5)
81 | generated_text = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)
82 | 
83 | >> ['고양이 두마리가 담요 위에 누워 있다.']  # "Two cats are lying on a blanket."
84 | ```
85 | 
86 | **Poem Model**
87 | ```python
88 | import torch
89 | from transformers import AutoTokenizer, AutoModelForCausalLM
90 | 
91 | # device setting
92 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
93 | # load model and tokenizer
94 | model_name_or_path = "ddobokki/gpt2_poem"
95 | 
96 | tokenizer = AutoTokenizer.from_pretrained(model_name_or_path)
97 | model = AutoModelForCausalLM.from_pretrained(model_name_or_path)
98 | model.to(device)
99 | 
100 | # the literal delimiter strings did not survive in this README;
101 | # use the keyword special tokens the poem model was trained with
102 | keyword_start_token = ""
103 | keyword_end_token = ""
104 | text = "산 꼭대기가 보이는 경치"  # "scenery with a mountain top in view"
105 | input_text = keyword_start_token + text + keyword_end_token
106 | 
107 | input_ids = tokenizer.encode(input_text, return_tensors="pt").to(device)
108 | gen_ids = model.generate(
109 |     input_ids, max_length=64, num_beams=100, no_repeat_ngram_size=2
110 | )
111 | generated = tokenizer.decode(gen_ids[0, :].tolist(), skip_special_tokens=True)
112 | >> 오르락내리락
113 | 산 꼭대기를 올려다보니
114 | 아득히 멀고 아득한
115 | 나뭇가지에 매달린
116 | 작은 산새 한 마리
117 | 이름 모를 풀 한포기 안고
118 | 어디론가 훌쩍 떠나가 버렸다
119 | ```
120 | 
121 | ### Web
122 | ```
123 | python web/app.py
124 | ```
125 | The web-related code is available [here](https://github.com/boostcampaitech2/final-project-level3-nlp-08/tree/dev/merge/web).
126 | 
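Putting the two snippets together, the service pipeline from the architecture diagram (image to caption to keywords to poem) can be sketched as below. This is a minimal sketch, not the exact service code: the keyword step reuses the top-3-noun heuristic from `preprocess(extract_keyword).ipynb`, and the `<k>`/`</k>` delimiters are placeholders for whatever keyword special tokens the poem model was actually trained with.

```python
import requests
import torch
from collections import Counter
from konlpy.tag import Okt
from PIL import Image
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    PreTrainedTokenizerFast,
    VisionEncoderDecoderModel,
    ViTFeatureExtractor,
)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# 1. image -> Korean caption
caption_name = "ddobokki/vision-encoder-decoder-vit-gpt2-coco-ko"
feature_extractor = ViTFeatureExtractor.from_pretrained(caption_name)
caption_tokenizer = PreTrainedTokenizerFast.from_pretrained(caption_name)
caption_model = VisionEncoderDecoderModel.from_pretrained(caption_name).to(device)

url = "http://images.cocodataset.org/val2017/000000039769.jpg"
with Image.open(requests.get(url, stream=True).raw) as img:
    pixel_values = feature_extractor(images=img, return_tensors="pt").pixel_values

caption_ids = caption_model.generate(pixel_values.to(device), num_beams=5)
caption = caption_tokenizer.batch_decode(caption_ids, skip_special_tokens=True)[0]

# 2. caption -> top-3 noun keywords (same heuristic as the preprocessing notebook)
nouns = Okt().nouns(caption)
keywords = ", ".join(w for w, _ in Counter(nouns).most_common(3))

# 3. keywords -> poem
poem_name = "ddobokki/gpt2_poem"
poem_tokenizer = AutoTokenizer.from_pretrained(poem_name)
poem_model = AutoModelForCausalLM.from_pretrained(poem_name).to(device)

# "<k>" / "</k>" are placeholder delimiters; substitute the model's real tokens
input_ids = poem_tokenizer.encode("<k>" + keywords + "</k>", return_tensors="pt").to(device)
gen_ids = poem_model.generate(input_ids, max_length=64, num_beams=100, no_repeat_ngram_size=2)
print(poem_tokenizer.decode(gen_ids[0], skip_special_tokens=True))
```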

127 | ## Service Outputs
128 | 
129 | *(example output card images)*
130 | 
131 | ## Reference
132 | 
133 | - [MS COCO](https://cocodataset.org/#home)
134 | - [AI Hub Korean image description dataset](https://aihub.or.kr/opendata/keti-data/recognition-visual/KETI-01-003)
135 | - [National Institute of Korean Language "Everyone's Corpus" unpublished-works data](https://corpus.korean.go.kr/)
136 | - [Modern Korean poetry data](http://www.baedalmal.com/)
137 | - [Geulteen poem data](https://teen.munjang.or.kr/archives/category/write/poetry)
138 | - [Dica-poetry Mania poem and image data](https://cafe.daum.net/dicapoetry/1aSh)
139 | - [Show, Attend and Tell: Neural Image Caption Generation with Visual Attention](https://arxiv.org/pdf/1502.03044.pdf)
140 | - [SP-GPT2: Semantics Improvement in Vietnamese Poetry Generation (GPT-2 + LSTM)](https://arxiv.org/abs/2110.15723)
141 | - [CCPM: A Chinese Classical Poetry Matching Dataset (CCPM evaluation)](https://arxiv.org/abs/2106.01979)
142 | - [Automatic Poetry Generation from Prosaic Text](https://aclanthology.org/2020.acl-main.223.pdf)
143 | - [MixPoet: Diverse Poetry Generation via Learning Controllable Mixed Latent Space (poem generation using a mixed latent space)](https://ojs.aaai.org/index.php/AAAI/article/view/6488)
144 | - [Introducing Aspects of Creativity in Automatic Poetry Generation (crowdsourced evaluation and other insights)](https://arxiv.org/pdf/2002.02511.pdf)
145 | - [Lingxi: A Diversity-aware Chinese Modern Poetry Generation System (lower self-BLEU score + human evaluation)](https://arxiv.org/pdf/2108.12108.pdf)
--------------------------------------------------------------------------------
/data/README.md:
--------------------------------------------------------------------------------
1 | # Caption data
2 | 
3 | ## Image download
4 | 
5 | ```bash
6 | caption_data/data_download.sh
7 | ```
8 | 
9 | ## Korean labels (MSCOCO_train_val_Korean.json)
10 | 
11 | [AI Hub](https://aihub.or.kr/opendata/keti-data/recognition-visual/KETI-01-003)
12 | 
13 | # Poem Data
14 | 
15 | * [Geulteen poem data](https://teen.munjang.or.kr/archives/category/write/poetry)
16 | * [400 modern Korean poems](http://www.baedalmal.com/poem/1-10.html)
17 | * [Dica-poetry Mania creative-writing board](https://cafe.daum.net/dicapoetry/1aSh)
18 | 
19 | For each work, the crawler collects the title and the poem text and saves them to a CSV file.
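A minimal sketch of reading the crawled files back, assuming the spiders (run via the crawl command below) have written their CSVs into `raw_data/`; the full cleanup lives in `preprocess_data/preprocess_poems.ipynb`:

```python
import ast

import pandas as pd

# each "poem" cell is stored as a stringified list of crawled text fragments
poems = pd.read_csv("raw_data/geulteen_poems.csv").drop("Unnamed: 0", axis=1)
poems["poem"] = poems["poem"].apply(lambda p: "\n".join(s.strip() for s in ast.literal_eval(p)))
print(poems.head())
```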
20 | 21 | ## crawl 22 | 23 | ```bash 24 | poem_data/crawl/poem_crawler/data_crawl.sh 25 | ``` -------------------------------------------------------------------------------- /data/caption_data/MSCOCO_train_val_Korean.json: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/boostcampaitech2/final-project-level3-nlp-08/cce85d868879e9ad23901125fb7ca0d7d129f450/data/caption_data/MSCOCO_train_val_Korean.json -------------------------------------------------------------------------------- /data/caption_data/data_download.sh: -------------------------------------------------------------------------------- 1 | wget http://images.cocodataset.org/zips/val2014.zip && wget http://images.cocodataset.org/zips/train2014.zip -------------------------------------------------------------------------------- /data/caption_data/dataset_coco_kor.json: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/boostcampaitech2/final-project-level3-nlp-08/cce85d868879e9ad23901125fb7ca0d7d129f450/data/caption_data/dataset_coco_kor.json -------------------------------------------------------------------------------- /data/caption_data/train2014/COCO_train2014_000000000009.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/boostcampaitech2/final-project-level3-nlp-08/cce85d868879e9ad23901125fb7ca0d7d129f450/data/caption_data/train2014/COCO_train2014_000000000009.jpg -------------------------------------------------------------------------------- /data/caption_data/valid2014/COCO_val2014_000000000042.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/boostcampaitech2/final-project-level3-nlp-08/cce85d868879e9ad23901125fb7ca0d7d129f450/data/caption_data/valid2014/COCO_val2014_000000000042.jpg -------------------------------------------------------------------------------- /data/poem_data/crawl/crawl.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/boostcampaitech2/final-project-level3-nlp-08/cce85d868879e9ad23901125fb7ca0d7d129f450/data/poem_data/crawl/crawl.py -------------------------------------------------------------------------------- /data/poem_data/crawl/poem_crawler/data_crawl.sh: -------------------------------------------------------------------------------- 1 | scrapy crawl geulteen 2 | scrapy crawl modernpoem 3 | scrapy crawl dica -------------------------------------------------------------------------------- /data/poem_data/crawl/poem_crawler/poem_crawler/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/boostcampaitech2/final-project-level3-nlp-08/cce85d868879e9ad23901125fb7ca0d7d129f450/data/poem_data/crawl/poem_crawler/poem_crawler/__init__.py -------------------------------------------------------------------------------- /data/poem_data/crawl/poem_crawler/poem_crawler/items.py: -------------------------------------------------------------------------------- 1 | # Define here the models for your scraped items 2 | # 3 | # See documentation in: 4 | # https://docs.scrapy.org/en/latest/topics/items.html 5 | 6 | import scrapy 7 | 8 | 9 | class PoemCrawlerItem(scrapy.Item): 10 | # define the fields for your item here like: 11 | # name = scrapy.Field() 12 | pass 13 | 
-------------------------------------------------------------------------------- /data/poem_data/crawl/poem_crawler/poem_crawler/middlewares.py: -------------------------------------------------------------------------------- 1 | # Define here the models for your spider middleware 2 | # 3 | # See documentation in: 4 | # https://docs.scrapy.org/en/latest/topics/spider-middleware.html 5 | 6 | from scrapy import signals 7 | 8 | # useful for handling different item types with a single interface 9 | from itemadapter import is_item, ItemAdapter 10 | 11 | 12 | class PoemCrawlerSpiderMiddleware: 13 | # Not all methods need to be defined. If a method is not defined, 14 | # scrapy acts as if the spider middleware does not modify the 15 | # passed objects. 16 | 17 | @classmethod 18 | def from_crawler(cls, crawler): 19 | # This method is used by Scrapy to create your spiders. 20 | s = cls() 21 | crawler.signals.connect(s.spider_opened, signal=signals.spider_opened) 22 | return s 23 | 24 | def process_spider_input(self, response, spider): 25 | # Called for each response that goes through the spider 26 | # middleware and into the spider. 27 | 28 | # Should return None or raise an exception. 29 | return None 30 | 31 | def process_spider_output(self, response, result, spider): 32 | # Called with the results returned from the Spider, after 33 | # it has processed the response. 34 | 35 | # Must return an iterable of Request, or item objects. 36 | for i in result: 37 | yield i 38 | 39 | def process_spider_exception(self, response, exception, spider): 40 | # Called when a spider or process_spider_input() method 41 | # (from other spider middleware) raises an exception. 42 | 43 | # Should return either None or an iterable of Request or item objects. 44 | pass 45 | 46 | def process_start_requests(self, start_requests, spider): 47 | # Called with the start requests of the spider, and works 48 | # similarly to the process_spider_output() method, except 49 | # that it doesn’t have a response associated. 50 | 51 | # Must return only requests (not items). 52 | for r in start_requests: 53 | yield r 54 | 55 | def spider_opened(self, spider): 56 | spider.logger.info('Spider opened: %s' % spider.name) 57 | 58 | 59 | class PoemCrawlerDownloaderMiddleware: 60 | # Not all methods need to be defined. If a method is not defined, 61 | # scrapy acts as if the downloader middleware does not modify the 62 | # passed objects. 63 | 64 | @classmethod 65 | def from_crawler(cls, crawler): 66 | # This method is used by Scrapy to create your spiders. 67 | s = cls() 68 | crawler.signals.connect(s.spider_opened, signal=signals.spider_opened) 69 | return s 70 | 71 | def process_request(self, request, spider): 72 | # Called for each request that goes through the downloader 73 | # middleware. 74 | 75 | # Must either: 76 | # - return None: continue processing this request 77 | # - or return a Response object 78 | # - or return a Request object 79 | # - or raise IgnoreRequest: process_exception() methods of 80 | # installed downloader middleware will be called 81 | return None 82 | 83 | def process_response(self, request, response, spider): 84 | # Called with the response returned from the downloader. 85 | 86 | # Must either; 87 | # - return a Response object 88 | # - return a Request object 89 | # - or raise IgnoreRequest 90 | return response 91 | 92 | def process_exception(self, request, exception, spider): 93 | # Called when a download handler or a process_request() 94 | # (from other downloader middleware) raises an exception. 
95 | 96 | # Must either: 97 | # - return None: continue processing this exception 98 | # - return a Response object: stops process_exception() chain 99 | # - return a Request object: stops process_exception() chain 100 | pass 101 | 102 | def spider_opened(self, spider): 103 | spider.logger.info('Spider opened: %s' % spider.name) 104 | -------------------------------------------------------------------------------- /data/poem_data/crawl/poem_crawler/poem_crawler/pipelines.py: -------------------------------------------------------------------------------- 1 | # Define your item pipelines here 2 | # 3 | # Don't forget to add your pipeline to the ITEM_PIPELINES setting 4 | # See: https://docs.scrapy.org/en/latest/topics/item-pipeline.html 5 | 6 | 7 | # useful for handling different item types with a single interface 8 | from itemadapter import ItemAdapter 9 | 10 | 11 | class PoemCrawlerPipeline: 12 | def process_item(self, item, spider): 13 | return item 14 | -------------------------------------------------------------------------------- /data/poem_data/crawl/poem_crawler/poem_crawler/settings.py: -------------------------------------------------------------------------------- 1 | # Scrapy settings for poem_crawler project 2 | # 3 | # For simplicity, this file contains only settings considered important or 4 | # commonly used. You can find more settings consulting the documentation: 5 | # 6 | # https://docs.scrapy.org/en/latest/topics/settings.html 7 | # https://docs.scrapy.org/en/latest/topics/downloader-middleware.html 8 | # https://docs.scrapy.org/en/latest/topics/spider-middleware.html 9 | 10 | BOT_NAME = 'poem_crawler' 11 | 12 | SPIDER_MODULES = ['poem_crawler.spiders'] 13 | NEWSPIDER_MODULE = 'poem_crawler.spiders' 14 | 15 | 16 | # Crawl responsibly by identifying yourself (and your website) on the user-agent 17 | #USER_AGENT = 'poem_crawler (+http://www.yourdomain.com)' 18 | 19 | # Obey robots.txt rules 20 | ROBOTSTXT_OBEY = True 21 | 22 | # Configure maximum concurrent requests performed by Scrapy (default: 16) 23 | #CONCURRENT_REQUESTS = 32 24 | 25 | # Configure a delay for requests for the same website (default: 0) 26 | # See https://docs.scrapy.org/en/latest/topics/settings.html#download-delay 27 | # See also autothrottle settings and docs 28 | #DOWNLOAD_DELAY = 3 29 | # The download delay setting will honor only one of: 30 | #CONCURRENT_REQUESTS_PER_DOMAIN = 16 31 | #CONCURRENT_REQUESTS_PER_IP = 16 32 | 33 | # Disable cookies (enabled by default) 34 | #COOKIES_ENABLED = False 35 | 36 | # Disable Telnet Console (enabled by default) 37 | #TELNETCONSOLE_ENABLED = False 38 | 39 | # Override the default request headers: 40 | #DEFAULT_REQUEST_HEADERS = { 41 | # 'Accept': 'text/html,application/xhtml+xml,application/xml;q=0.9,*/*;q=0.8', 42 | # 'Accept-Language': 'en', 43 | #} 44 | 45 | # Enable or disable spider middlewares 46 | # See https://docs.scrapy.org/en/latest/topics/spider-middleware.html 47 | #SPIDER_MIDDLEWARES = { 48 | # 'poem_crawler.middlewares.PoemCrawlerSpiderMiddleware': 543, 49 | #} 50 | 51 | # Enable or disable downloader middlewares 52 | # See https://docs.scrapy.org/en/latest/topics/downloader-middleware.html 53 | #DOWNLOADER_MIDDLEWARES = { 54 | # 'poem_crawler.middlewares.PoemCrawlerDownloaderMiddleware': 543, 55 | #} 56 | 57 | # Enable or disable extensions 58 | # See https://docs.scrapy.org/en/latest/topics/extensions.html 59 | #EXTENSIONS = { 60 | # 'scrapy.extensions.telnet.TelnetConsole': None, 61 | #} 62 | 63 | # Configure item pipelines 64 | # See 
https://docs.scrapy.org/en/latest/topics/item-pipeline.html
65 | #ITEM_PIPELINES = {
66 | #    'poem_crawler.pipelines.PoemCrawlerPipeline': 300,
67 | #}
68 | 
69 | # Enable and configure the AutoThrottle extension (disabled by default)
70 | # See https://docs.scrapy.org/en/latest/topics/autothrottle.html
71 | #AUTOTHROTTLE_ENABLED = True
72 | # The initial download delay
73 | #AUTOTHROTTLE_START_DELAY = 5
74 | # The maximum download delay to be set in case of high latencies
75 | #AUTOTHROTTLE_MAX_DELAY = 60
76 | # The average number of requests Scrapy should be sending in parallel to
77 | # each remote server
78 | #AUTOTHROTTLE_TARGET_CONCURRENCY = 1.0
79 | # Enable showing throttling stats for every response received:
80 | #AUTOTHROTTLE_DEBUG = False
81 | 
82 | # Enable and configure HTTP caching (disabled by default)
83 | # See https://docs.scrapy.org/en/latest/topics/downloader-middleware.html#httpcache-middleware-settings
84 | #HTTPCACHE_ENABLED = True
85 | #HTTPCACHE_EXPIRATION_SECS = 0
86 | #HTTPCACHE_DIR = 'httpcache'
87 | #HTTPCACHE_IGNORE_HTTP_CODES = []
88 | #HTTPCACHE_STORAGE = 'scrapy.extensions.httpcache.FilesystemCacheStorage'
89 | 
--------------------------------------------------------------------------------
/data/poem_data/crawl/poem_crawler/poem_crawler/spiders/__init__.py:
--------------------------------------------------------------------------------
1 | # This package will contain the spiders of your Scrapy project
2 | #
3 | # Please refer to the documentation for information on how to create and manage
4 | # your spiders.
5 | 
--------------------------------------------------------------------------------
/data/poem_data/crawl/poem_crawler/poem_crawler/spiders/spider.py:
--------------------------------------------------------------------------------
1 | import scrapy
2 | import pandas as pd
3 | 
4 | 
5 | data_path = "../../../../raw_data/"
6 | 
7 | 
8 | class TeenSpider(scrapy.Spider):
9 |     # crawls poems from the Geulteen (글틴) poetry board
10 |     name = "geulteen"
11 | 
12 |     def __init__(self, name=None, **kwargs):
13 |         super().__init__(name=name, **kwargs)
14 |         self.poems = []
15 |         """
16 |         {
17 |             "title": title,
18 |             "poem": poem
19 |         }
20 |         """
21 | 
22 |     def start_requests(self):
23 |         url_main = 'https://teen.munjang.or.kr/archives/category/write/poetry'
24 |         yield scrapy.Request(url=url_main, callback=self.parse_page)
25 |         for page_num in range(2, 2207):
26 |             url = url_main + '/page/' + str(page_num)
27 |             yield scrapy.Request(url=url, callback=self.parse_page)
28 | 
29 |     def parse_page(self, response):
30 |         for i in range(1, 11):
31 |             title = response.xpath(f"/html/body/div[1]/div[4]/div[2]/div/main/article[{i}]/div/div[2]/div[1]/a/text()").get()
32 |             # skip contest-winner announcements and board notices
33 |             if title and '장원' not in title and '시 게시판' not in title:
34 |                 page = response.xpath(f"/html/body/div[1]/div[4]/div[2]/div/main/article[{i}]/div/div[2]/div[1]/a/@href").get()
35 |                 yield scrapy.Request(url=page, callback=self.parse)
36 | 
37 |     def parse(self, response):
38 |         self.poems.append({
39 |             "title": response.xpath('//header/h1/text()').get(),
40 |             "poem": response.xpath('//div[@class="entry-content"]/p//text()').getall()
41 |         })
42 | 
43 |     def closed(self, reason):
44 |         data = pd.DataFrame(self.poems)
45 |         data.to_csv(data_path + 'geulteen_poems.csv', encoding='utf-8')
46 | 
47 | 
48 | class ModernPoemSpider(scrapy.Spider):
49 |     # crawls the "400 modern Korean poems" pages
50 |     name = "modernpoem"
51 | 
52 |     def __init__(self, name=None, **kwargs):
53 |         super().__init__(name=name, **kwargs)
54 |         self.contents = []
55 |         # number of poem blocks (<ul>) on each page, keyed by URL index
56 |         self.keys = {
57 |             0: 47,
58 |             1: 52,
59 |             2: [36, 24],
60 |             3: 41,
61 |             4: 24,
62 |             5: [44, 17],
63 |             6: 68,
64 |             7: 49
65 |         }
66 |         """
67 |         {
68 |             "title": title,
69 |             "poem": poem
70 |         }
71 |         """
72 | 
73 |     def start_requests(self):
74 |         urls = ['http://www.baedalmal.com/poem/1-10.html',
75 |                 'http://www.baedalmal.com/poem/1-20.html',
76 |                 'http://www.baedalmal.com/poem/1-30.html',
77 |                 'http://www.baedalmal.com/poem/1-40.html',
78 |                 'http://www.baedalmal.com/poem/2-10.html',
79 |                 'http://www.baedalmal.com/poem/2-20.html',
80 |                 'http://www.baedalmal.com/poem/2-30.html',
81 |                 'http://www.baedalmal.com/poem/2-40.html'
82 |                 ]
83 |         for i in range(len(urls)):
84 |             request = scrapy.Request(url=urls[i], callback=self.parse, cb_kwargs=dict(url_num=i), encoding='cp949')
85 |             yield request
86 | 
87 |     def parse(self, response, url_num):
88 |         self.logger.debug(response.encoding)
89 |         # pages that come back mis-encoded (cp1252) are skipped
90 |         if response.encoding == 'cp1252':
91 |             return
92 |         magic_num = self.keys[url_num]
93 |         if isinstance(magic_num, int):
94 |             for i in range(1, magic_num + 1):
95 |                 poem = response.xpath(f'/html/body/ul[{i}]//font[@size="3"]//text()').getall()
96 |                 new = "\n".join(poem)
97 |                 self.contents.append({
98 |                     "title": response.xpath(f'/html/body/ul[{i}]//b//text()').get(),
99 |                     "poem": new
100 |                 })
101 |         else:
102 |             for i in range(1, magic_num[0] + 1):
103 |                 poem = response.xpath(f'/html/body/ul[{i}]//font[@size="3"]//text()').getall()
104 |                 new = "\n".join(poem)
105 |                 self.contents.append({
106 |                     "title": response.xpath(f'/html/body/ul[{i}]//b//text()').get(),
107 |                     "poem": new
108 |                 })
109 |             for i in range(1, magic_num[1] + 1):
110 |                 poem = response.xpath(f'/html/body/ul[{magic_num[0]+1}]/ul/ul/ul[{i}]//font[@size="3"]//text()').getall()
111 |                 new = "\n".join(poem)
112 |                 self.contents.append({
113 |                     "title": response.xpath(f'/html/body/ul[{magic_num[0]+1}]/ul/ul/ul[{i}]//b//text()').get(),
114 |                     "poem": new
115 |                 })
116 | 
117 |     def closed(self, reason):
118 |         csvdata = pd.DataFrame(self.contents)
119 |         csvdata.to_csv(data_path + 'modern_poems_raw.csv', encoding='utf-8')
120 | 
121 | 
122 | class DicaSpider(scrapy.Spider):
123 |     # crawls poems (with their images) from the Dica-poetry Mania cafe
124 |     name = "dica"
125 | 
126 |     def __init__(self, name=None, **kwargs):
127 |         super().__init__(name=name, **kwargs)
128 |         self.poems = []
129 |         """
130 |         {
131 |             "img": link,
132 |             "poem": poem
133 |         }
134 |         """
135 | 
136 |     def start_requests(self):
137 |         url_main = 'https://m.cafe.daum.net/dicapoetry/1aSh/'
138 |         for page_num in range(1, 16681):
139 |             url = url_main + str(page_num)
140 |             yield scrapy.Request(url=url, callback=self.parse)
141 | 
142 |     def parse(self, response):
143 |         # skip deleted or inaccessible posts
144 |         if response.xpath('//div[contains(@class,"cafe_error")]'):
145 |             return
146 |         self.poems.append({
147 |             "img": response.xpath('//*[@id="article"]//img/@src').get(),
148 |             "poem": response.xpath('//*[@id="article"]//text()').getall()
149 |         })
150 | 
151 |     def closed(self, reason):
152 |         csvdata = pd.DataFrame(self.poems)
153 |         csvdata.to_csv(data_path + 'dica_poems_raw.csv', encoding='utf-8')
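Note: the spiders write their CSVs via the relative path `../../../../raw_data/`, which resolves to `data/poem_data/raw_data/` when Scrapy is launched from the `spiders/` directory, so the working directory matters. A sketch of one invocation consistent with that path (Scrapy locates `scrapy.cfg`, below, by walking up from any project subdirectory):

```bash
cd data/poem_data/crawl/poem_crawler/poem_crawler/spiders
scrapy crawl geulteen    # likewise: modernpoem, dica
```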
--------------------------------------------------------------------------------
/data/poem_data/crawl/poem_crawler/scrapy.cfg:
--------------------------------------------------------------------------------
1 | # Automatically created by: scrapy startproject
2 | #
3 | # For more information about the [deploy] section see:
4 | # https://scrapyd.readthedocs.io/en/latest/deploy.html
5 | 
6 | [settings]
7 | default = poem_crawler.settings
8 | 
9 | [deploy]
10 | #url = http://localhost:6800/
11 | project = poem_crawler
12 | 
--------------------------------------------------------------------------------
/data/poem_data/preprocess_data/preprocess(extract_keyword).ipynb:
--------------------------------------------------------------------------------
1 | {
2 |  "cells": [
3 |   {
4 |    "cell_type": "code",
5 |    "execution_count": 1,
6 |    "metadata": {},
7 |    "outputs": [],
8 |    "source": [
9 |     "from konlpy.tag import Okt\n",
10 |     "import pandas as pd\n",
11 |     "from collections import Counter"
12 |    ]
13 |   },
14 |   {
15 |    "cell_type": "code",
16 |    "execution_count": null,
17 |    "metadata": {},
18 |    "outputs": [],
19 |    "source": [
20 |     "okt = Okt()"
21 |    ]
22 |   },
23 |   {
24 |    "cell_type": "code",
25 |    "execution_count": null,
26 |    "metadata": {},
27 |    "outputs": [],
28 |    "source": [
29 |     "poem = pd.read_csv('poem.csv')"
30 |    ]
31 |   },
32 |   {
33 |    "cell_type": "code",
34 |    "execution_count": null,
35 |    "metadata": {},
36 |    "outputs": [],
37 |    "source": [
38 |     "# function that extracts the top-3 noun keywords\n",
39 |     "def get_top3_noun(text):\n",
40 |     "    noun = okt.nouns(text)\n",
41 |     "    count = Counter(noun)\n",
42 |     "    noun_list = count.most_common(3)\n",
43 |     "    rtn_noun = list(map(lambda x: x[0], noun_list))\n",
44 |     "    return ', '.join(rtn_noun)"
45 |    ]
46 |   },
47 |   {
48 |    "cell_type": "code",
49 |    "execution_count": null,
50 |    "metadata": {},
51 |    "outputs": [],
52 |    "source": [
53 |     "# extract keywords\n",
54 |     "keyword = poem['poem'].apply(get_top3_noun)\n",
55 |     "poem['key_word'] = keyword"
56 |    ]
57 |   },
58 |   {
59 |    "cell_type": "code",
60 |    "execution_count": null,
61 |    "metadata": {},
62 |    "outputs": [],
63 |    "source": [
64 |     "poem_with_keyword = pd.DataFrame(data={\n",
65 |     "    'text' : poem['text'],\n",
66 |     "    'key_word': poem['key_word']\n",
67 |     "    })\n",
68 |     "poem_with_keyword.to_csv('poem_with_keyword.csv',index=False)"
69 |    ]
70 |   }
71 |  ],
72 |  "metadata": {
73 |   "interpreter": {
74 |    "hash": "d4d1e4263499bec80672ea0156c357c1ee493ec2b1c70f0acce89fc37c4a6abe"
75 |   },
76 |   "kernelspec": {
77 |    "display_name": "Python 3.8.12 64-bit ('base': conda)",
78 |    "language": "python",
79 |    "name": "python3"
80 |   },
81 |   "language_info": {
82 |    "codemirror_mode": {
83 |     "name": "ipython",
84 |     "version": 3
85 |    },
86 |    "file_extension": ".py",
87 |    "mimetype": "text/x-python",
88 |    "name": "python",
89 |    "nbconvert_exporter": "python",
90 |    "pygments_lexer": "ipython3",
91 |    "version": "3.8.12"
92 |   },
93 |   "orig_nbformat": 4
94 |  },
95 |  "nbformat": 4,
96 |  "nbformat_minor": 2
97 | }
98 | 
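The keyword extraction above can be sanity-checked in isolation; a minimal sketch, assuming KoNLPy and its Java dependency are installed:

```python
from collections import Counter

from konlpy.tag import Okt

okt = Okt()

def get_top3_noun(text):
    # most frequent three nouns, comma-joined, as in the notebook
    nouns = okt.nouns(text)
    return ", ".join(w for w, _ in Counter(nouns).most_common(3))

# "a small bird sitting on a branch at the mountain top"
print(get_top3_noun("산 꼭대기 나뭇가지에 작은 산새 한 마리가 앉아 있다"))
```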
\n", 8 | "Scrapy를 통해 List 형식으로 크롤링된 시를 str 형식으로 변경" 9 | ] 10 | }, 11 | { 12 | "cell_type": "code", 13 | "execution_count": 1, 14 | "metadata": {}, 15 | "outputs": [], 16 | "source": [ 17 | "import re\n", 18 | "import pandas as pd\n", 19 | "import ast\n", 20 | "import hanja" 21 | ] 22 | }, 23 | { 24 | "cell_type": "code", 25 | "execution_count": null, 26 | "metadata": {}, 27 | "outputs": [], 28 | "source": [ 29 | "from_teen = pd.read_csv(\"../raw_data/geulteen_poems.csv\").drop(\"Unnamed: 0\", axis=1)\n", 30 | "from_modern_poems = pd.read_csv(\"../raw_data/modern_poems_raw.csv\").drop(\"Unnamed: 0\", axis=1)\n", 31 | "from_dica_poems = pd.read_csv(\"../raw_data/dica_poems_raw.csv\").drop(\"Unnamed: 0\", axis=1)" 32 | ] 33 | }, 34 | { 35 | "cell_type": "code", 36 | "execution_count": null, 37 | "metadata": {}, 38 | "outputs": [], 39 | "source": [ 40 | "# \"[poem]\" 형식의 str을 List인 [poem]으로 치환\n", 41 | "def listify(poem):\n", 42 | " return ast.literal_eval(poem)" 43 | ] 44 | }, 45 | { 46 | "cell_type": "code", 47 | "execution_count": null, 48 | "metadata": {}, 49 | "outputs": [], 50 | "source": [ 51 | "from_teen['poem'] = from_teen['poem'].apply(listify)\n", 52 | "from_modern_poems['poem'] = from_modern_poems['poem'].apply(listify)\n", 53 | "from_dica_poems['poem'] = from_dica_poems['poem'].apply(listify)" 54 | ] 55 | }, 56 | { 57 | "cell_type": "code", 58 | "execution_count": null, 59 | "metadata": {}, 60 | "outputs": [], 61 | "source": [ 62 | "# 리스트 안의 시 내용을 개행문자로 묶어줌. \n", 63 | "def strip_and_join_newline(poem_list):\n", 64 | " return \"\\n\".join(map(str.strip, poem_list))" 65 | ] 66 | }, 67 | { 68 | "cell_type": "markdown", 69 | "metadata": {}, 70 | "source": [ 71 | "# 2. 데이터의 노이즈가 될 수 있는 부분 전처리\n", 72 | "KLUE 데이터셋의 전처리 방식을 활용" 73 | ] 74 | }, 75 | { 76 | "cell_type": "code", 77 | "execution_count": null, 78 | "metadata": {}, 79 | "outputs": [], 80 | "source": [ 81 | "# Klue 데이터셋 전처리 응용\n", 82 | "def preprocess(poem):\n", 83 | " new = []\n", 84 | " for text in poem:\n", 85 | " # 문제를 일으킬 수 있는 문자 제거\n", 86 | " bad_chars = {\"\\u200b\": \"\", \"…\": \" ... 
\", \"\\ufeff\": \"\"}\n", 87 | " for bad_char in bad_chars:\n", 88 | " text = text.replace(bad_char, bad_chars[bad_char])\n", 89 | " \n", 90 | " error_chars = {\"\\u3000\": \" \", \"\\u2009\": \" \", \"\\u2002\": \" \", \"\\xa0\":\" \"}\n", 91 | " for error_char in error_chars:\n", 92 | " text = text.replace(error_char, error_chars[error_char])\n", 93 | "\n", 94 | " # URL 제거\n", 95 | " text = re.sub(r\"(http|https)?:\\/\\/\\S+\\b|www\\.(\\w+\\.)+\\S*\", \"[웹주소]\", text).strip()\n", 96 | " text = re.sub(r\"pic\\.(\\w+\\.)+\\S*\", \"[웹주소]\", text).strip()\n", 97 | "\n", 98 | " # 뉴스 저작권 관련 텍스트 제거\n", 99 | " re_patterns = [\n", 100 | " r\"\\<저작권자(\\(c\\)|ⓒ|©|\\(Copyright\\)|(\\(c\\))|(\\(C\\))).+?\\>\",\n", 101 | " r\"저작권자\\(c\\)|ⓒ|©|(Copyright)|(\\(c\\))|(\\(C\\))\"\n", 102 | " ]\n", 103 | " \n", 104 | " for re_pattern in re_patterns:\n", 105 | " text = re.sub(re_pattern, \"\", text).strip()\n", 106 | " \n", 107 | " # 뉴스 내 포함된 이미지에 대한 레이블 제거\n", 108 | " text = re.sub(r\"\\(출처 ?= ?.+\\) |\\(사진 ?= ?.+\\) |\\(자료 ?= ?.+\\)| \\(자료사진\\) |사진=.+기자 \", \"\", text).strip()\n", 109 | " \n", 110 | " # 문제를 일으킬 수 있는 구두점 치환\n", 111 | " punct_mapping = {\"‘\": \"'\", \"₹\": \"e\", \"´\": \"'\", \"°\": \"\", \"€\": \"e\", \"™\": \"tm\", \"√\": \" sqrt \", \"×\": \"x\", \"²\": \"2\", \"—\": \"-\", \"–\": \"-\", \"’\": \"'\", \"_\": \"-\", \"`\": \"'\", '“': '\"', '”': '\"', '“': '\"', \"£\": \"e\", '∞': 'infinity', 'θ': 'theta', '÷': '/', 'α': 'alpha', '•': '.', 'à': 'a', '−': '-', 'β': 'beta', '∅': '', '³': '3', 'π': 'pi', }\n", 112 | " for p in punct_mapping:\n", 113 | " text = text.replace(p, punct_mapping[p])\n", 114 | " \n", 115 | " # 연속된 공백 치환\n", 116 | " text = re.sub(r\"\\s+\", \" \", text).strip()\n", 117 | " \n", 118 | " # 개행을 먼저 없애고 그 후 합쳐줌.\n", 119 | " re.sub('\\n|\\t|\\r', \"\", text).strip()\n", 120 | " re.sub('\\xa0', \" \", text).strip()\n", 121 | "\n", 122 | " if text:\n", 123 | " new.append(text)\n", 124 | " return \"\\n\".join(new)" 125 | ] 126 | }, 127 | { 128 | "cell_type": "code", 129 | "execution_count": null, 130 | "metadata": {}, 131 | "outputs": [], 132 | "source": [ 133 | "from_teen['poem'] = from_teen['poem'].apply(preprocess)\n", 134 | "from_modern_poems['poem'] = from_modern_poems['poem'].apply(preprocess)\n", 135 | "from_dica_poems['poem'] = from_dica_poems['poem'].apply(preprocess)" 136 | ] 137 | }, 138 | { 139 | "cell_type": "markdown", 140 | "metadata": {}, 141 | "source": [ 142 | "# 3. 
한자 전처리\n", 143 | "한자 전처리가 없을 시, 모델이 적절하지 않은 한자어를 생성하는 경우가 있음을 확인, 데이터의 한자를 제거 및 번역" 144 | ] 145 | }, 146 | { 147 | "cell_type": "code", 148 | "execution_count": null, 149 | "metadata": {}, 150 | "outputs": [], 151 | "source": [ 152 | "# 한글과 병기된 한자어는 삭제 후, 남은 한자어는 번역\n", 153 | "def hanja_preprocess(txt):\n", 154 | " new = re.sub(r\"\\([\\u2e80-\\u2eff\\u31c0-\\u31ef\\u3200-\\u32ff\\u3400-\\u4dbf\\u4e00-\\u9fbf\\uf900-\\ufaff]+\\)|\\[[\\u2e80-\\u2eff\\u31c0-\\u31ef\\u3200-\\u32ff\\u3400-\\u4dbf\\u4e00-\\u9fbf\\uf900-\\ufaff]+\\]\", \"\", txt)\n", 155 | " new = hanja.translate(new, 'substitution')\n", 156 | " return new" 157 | ] 158 | }, 159 | { 160 | "cell_type": "code", 161 | "execution_count": null, 162 | "metadata": {}, 163 | "outputs": [], 164 | "source": [ 165 | "from_teen['poem'] = from_teen['poem'].apply(hanja_preprocess)\n", 166 | "from_modern_poems['poem'] = from_modern_poems['poem'].apply(hanja_preprocess)\n", 167 | "from_dica_poems['poem'] = from_dica_poems['poem'].apply(hanja_preprocess)" 168 | ] 169 | }, 170 | { 171 | "cell_type": "markdown", 172 | "metadata": {}, 173 | "source": [ 174 | "# 4. CSV 형식으로 Export\n", 175 | "같은 형식으로 응용 가능한 teen과 modern_poems는 하나로 concat" 176 | ] 177 | }, 178 | { 179 | "cell_type": "code", 180 | "execution_count": null, 181 | "metadata": {}, 182 | "outputs": [], 183 | "source": [ 184 | "from_teen.to_csv(\"teen.csv\")\n", 185 | "from_modern_poems.to_csv(\"modern_poems.csv\")\n", 186 | "from_dica_poems.to_csv(\"dica_poems.csv\")" 187 | ] 188 | }, 189 | { 190 | "cell_type": "code", 191 | "execution_count": null, 192 | "metadata": {}, 193 | "outputs": [], 194 | "source": [ 195 | "train_data = pd.concat([from_teen, from_modern_poems])\n", 196 | "train_data.to_csv(\"train.csv\")" 197 | ] 198 | }, 199 | { 200 | "cell_type": "markdown", 201 | "metadata": {}, 202 | "source": [ 203 | "# 5. 기타 전처리\n", 204 | "인터넷 게시판의 시를 크롤링 했기에, 코드 외적으로 수정할 부분들이 많았습니다.
\n", 205 | "____, ----, ++++, 등으로 시작하는 분리(?) 마커 뒤에 개인적인 시에 대한 코멘트를 덧붙이는 경우,
\n", 206 | "dica_poems는 작가명(개인정보)이 같이 크롤링 된 경우,
\n", 207 | "주석을 단어에 *, 1) 등의 표시를 해둔 후 글 마지막에 설명하는 경우 등이 있었습니다.
\n", 208 | "또한, 현대시 형식으로 띄어쓰기 없이 작성된 시들, 특수문자만 사용한 시들 등은 상당부분 직접 확인하고 삭제했습니다." 209 | ] 210 | } 211 | ], 212 | "metadata": { 213 | "interpreter": { 214 | "hash": "e31c68abf1d5dd3f9e2269f23eadf1b199587e56c0618a30760176a65ebfcab4" 215 | }, 216 | "kernelspec": { 217 | "display_name": "Python 3.7.11 64-bit ('lightweight': conda)", 218 | "language": "python", 219 | "name": "python3" 220 | }, 221 | "language_info": { 222 | "codemirror_mode": { 223 | "name": "ipython", 224 | "version": 3 225 | }, 226 | "file_extension": ".py", 227 | "mimetype": "text/x-python", 228 | "name": "python", 229 | "nbconvert_exporter": "python", 230 | "pygments_lexer": "ipython3", 231 | "version": "3.7.11" 232 | }, 233 | "orig_nbformat": 4 234 | }, 235 | "nbformat": 4, 236 | "nbformat_minor": 2 237 | } 238 | -------------------------------------------------------------------------------- /data/poem_data/raw_data/raw.txt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/boostcampaitech2/final-project-level3-nlp-08/cce85d868879e9ad23901125fb7ca0d7d129f450/data/poem_data/raw_data/raw.txt -------------------------------------------------------------------------------- /evaluation/scoring_with_bert/__pycache__/arguments.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/boostcampaitech2/final-project-level3-nlp-08/cce85d868879e9ad23901125fb7ca0d7d129f450/evaluation/scoring_with_bert/__pycache__/arguments.cpython-37.pyc -------------------------------------------------------------------------------- /evaluation/scoring_with_bert/__pycache__/arguments.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/boostcampaitech2/final-project-level3-nlp-08/cce85d868879e9ad23901125fb7ca0d7d129f450/evaluation/scoring_with_bert/__pycache__/arguments.cpython-39.pyc -------------------------------------------------------------------------------- /evaluation/scoring_with_bert/__pycache__/model.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/boostcampaitech2/final-project-level3-nlp-08/cce85d868879e9ad23901125fb7ca0d7d129f450/evaluation/scoring_with_bert/__pycache__/model.cpython-39.pyc -------------------------------------------------------------------------------- /evaluation/scoring_with_bert/__pycache__/train.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/boostcampaitech2/final-project-level3-nlp-08/cce85d868879e9ad23901125fb7ca0d7d129f450/evaluation/scoring_with_bert/__pycache__/train.cpython-39.pyc -------------------------------------------------------------------------------- /evaluation/scoring_with_bert/__pycache__/utils.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/boostcampaitech2/final-project-level3-nlp-08/cce85d868879e9ad23901125fb7ca0d7d129f450/evaluation/scoring_with_bert/__pycache__/utils.cpython-39.pyc -------------------------------------------------------------------------------- /evaluation/scoring_with_bert/arguments.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | 3 | def get_args_parser(): 4 | parser = argparse.ArgumentParser('Image Classification', add_help=False) 5 | 
parser.add_argument('--seed', default=2021, type=int) 6 | parser.add_argument('--lr', default=2e-5, type=float) 7 | parser.add_argument('--train_bs', default=8, type=int) 8 | parser.add_argument('--num_epochs', default=100, type=int) 9 | parser.add_argument('--weight_decay', default=0.01, type=float) 10 | parser.add_argument('--early_stop', default=10, type=int) 11 | parser.add_argument('--adam_epsilon', default=1e-08, type=float) 12 | parser.add_argument('--gradient_accumulation_steps', default=32, type=int) 13 | parser.add_argument('--warmup_steps', default=20, type=int) 14 | parser.add_argument('--model', default='klue/roberta-small', type=str) 15 | parser.add_argument('--data', default='klue/sts', type=str) 16 | 17 | return parser -------------------------------------------------------------------------------- /evaluation/scoring_with_bert/demo.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 7, 6 | "metadata": {}, 7 | "outputs": [], 8 | "source": [ 9 | "from transformers import AutoTokenizer\n", 10 | "from model import RobertaEncoder\n", 11 | "import torch" 12 | ] 13 | }, 14 | { 15 | "cell_type": "code", 16 | "execution_count": 21, 17 | "metadata": {}, 18 | "outputs": [], 19 | "source": [ 20 | "tokenizer = AutoTokenizer.from_pretrained('klue/roberta-small')\n", 21 | "sen_encoder = RobertaEncoder.from_pretrained('checkpoints/val_acc_58.88%_sen_encoder').cuda()" 22 | ] 23 | }, 24 | { 25 | "cell_type": "code", 26 | "execution_count": 51, 27 | "metadata": {}, 28 | "outputs": [], 29 | "source": [ 30 | "def checker(title, text):\n", 31 | "\n", 32 | " val_u_seqs = tokenizer(title)\n", 33 | " val_v_seqs = tokenizer(text)\n", 34 | "\n", 35 | " u_input_ids_lst = val_u_seqs['input_ids']\n", 36 | " u_attention_mask_lst = val_u_seqs['attention_mask']\n", 37 | "\n", 38 | " v_input_ids_lst = val_v_seqs['input_ids']\n", 39 | " v_attention_mask_lst = val_v_seqs['attention_mask']\n", 40 | "\n", 41 | " with torch.no_grad():\n", 42 | " sen_encoder.eval()\n", 43 | " sen_encoder.eval()\n", 44 | " \n", 45 | " u_input_ids = torch.tensor(u_input_ids_lst).cuda().unsqueeze(0)\n", 46 | " u_attention_mask = torch.tensor(u_attention_mask_lst).cuda().unsqueeze(0)\n", 47 | " v_input_ids = torch.tensor(v_input_ids_lst).cuda().unsqueeze(0)\n", 48 | " v_attention_mask = torch.tensor(v_attention_mask_lst).cuda().unsqueeze(0)\n", 49 | " \n", 50 | " inputs = {'input_ids_1': u_input_ids,\n", 51 | " 'input_ids_2': v_input_ids,\n", 52 | " 'attention_mask_1': u_attention_mask,\n", 53 | " 'attention_mask_2': v_attention_mask\n", 54 | " }\n", 55 | "\n", 56 | " val_cos_sim = sen_encoder(**inputs).to('cpu')\n", 57 | " return val_cos_sim.item()\n" 58 | ] 59 | }, 60 | { 61 | "cell_type": "code", 62 | "execution_count": 31, 63 | "metadata": {}, 64 | "outputs": [], 65 | "source": [ 66 | "checker('title', 'text')" 67 | ] 68 | }, 69 | { 70 | "cell_type": "code", 71 | "execution_count": 13, 72 | "metadata": {}, 73 | "outputs": [ 74 | { 75 | "name": "stderr", 76 | "output_type": "stream", 77 | "text": [ 78 | "Some weights of the model checkpoint at bert-base-uncased were not used when initializing BertForSequenceClassification: ['cls.predictions.transform.dense.weight', 'cls.predictions.transform.LayerNorm.bias', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.bias', 'cls.predictions.transform.dense.bias', 'cls.predictions.decoder.weight', 'cls.seq_relationship.weight', 'cls.seq_relationship.bias']\n", 79 | "- 
This IS expected if you are initializing BertForSequenceClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).\n", 80 | "- This IS NOT expected if you are initializing BertForSequenceClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).\n", 81 | "Some weights of BertForSequenceClassification were not initialized from the model checkpoint at bert-base-uncased and are newly initialized: ['classifier.weight', 'classifier.bias']\n", 82 | "You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.\n", 83 | "/opt/conda/envs/bc/lib/python3.9/site-packages/torch/nn/modules/loss.py:520: UserWarning: Using a target size (torch.Size([1, 1])) that is different to the input size (torch.Size([1, 2])). This will likely lead to incorrect results due to broadcasting. Please ensure they have the same size.\n", 84 | " return F.mse_loss(input, target, reduction=self.reduction)\n" 85 | ] 86 | } 87 | ], 88 | "source": [ 89 | "from transformers import BertTokenizer, BertForSequenceClassification\n", 90 | "import torch\n", 91 | "\n", 92 | "tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')\n", 93 | "model = BertForSequenceClassification.from_pretrained('bert-base-uncased', problem_type = 'single_label_classification')\n", 94 | "\n", 95 | "inputs = tokenizer(\"Hello, my dog is cute\", return_tensors=\"pt\")\n", 96 | "labels = torch.tensor([1]).unsqueeze(0) # Batch size 1\n", 97 | "outputs = model(**inputs, labels=labels)\n", 98 | "loss = outputs.loss\n", 99 | "logits = outputs.logits" 100 | ] 101 | }, 102 | { 103 | "cell_type": "code", 104 | "execution_count": 25, 105 | "metadata": {}, 106 | "outputs": [], 107 | "source": [ 108 | "from torch.nn import Softmax\n", 109 | "import numpy as np" 110 | ] 111 | }, 112 | { 113 | "cell_type": "code", 114 | "execution_count": 118, 115 | "metadata": {}, 116 | "outputs": [ 117 | { 118 | "name": "stdout", 119 | "output_type": "stream", 120 | "text": [ 121 | "[-0.4755, -0.4808]\n" 122 | ] 123 | }, 124 | { 125 | "data": { 126 | "text/plain": [ 127 | "1" 128 | ] 129 | }, 130 | "execution_count": 118, 131 | "metadata": {}, 132 | "output_type": "execute_result" 133 | } 134 | ], 135 | "source": [ 136 | "m = Softmax(dim=1)\n", 137 | "input = torch.randn(2, 2)\n", 138 | "print([-0.4755, -0.4808])\n", 139 | "np.argmax([0.4755, 0.4808])" 140 | ] 141 | }, 142 | { 143 | "cell_type": "code", 144 | "execution_count": 14, 145 | "metadata": {}, 146 | "outputs": [ 147 | { 148 | "data": { 149 | "text/plain": [ 150 | "(SequenceClassifierOutput(loss=tensor(1.8609, grad_fn=), logits=tensor([[-0.2711, -0.4513]], grad_fn=), hidden_states=None, attentions=None),\n", 151 | " SequenceClassifierOutput(loss=tensor(0.5674, grad_fn=), logits=tensor([[0.0474, 0.5238]], grad_fn=), hidden_states=None, attentions=None))" 152 | ] 153 | }, 154 | "execution_count": 14, 155 | "metadata": {}, 156 | "output_type": "execute_result" 157 | } 158 | ], 159 | "source": [ 160 | "outputs, outputs1" 161 | ] 162 | }, 163 | { 164 | "cell_type": "code", 165 | "execution_count": 6, 166 | "metadata": {}, 167 | "outputs": [ 168 | { 169 | "name": "stderr", 170 | "output_type": "stream", 171 | "text": [ 172 | "Some weights of the model checkpoint at bert-base-uncased were not used when initializing 
BertForSequenceClassification: ['cls.predictions.transform.dense.weight', 'cls.predictions.transform.LayerNorm.bias', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.bias', 'cls.predictions.transform.dense.bias', 'cls.predictions.decoder.weight', 'cls.seq_relationship.weight', 'cls.seq_relationship.bias']\n", 173 | "- This IS expected if you are initializing BertForSequenceClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).\n", 174 | "- This IS NOT expected if you are initializing BertForSequenceClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).\n", 175 | "Some weights of BertForSequenceClassification were not initialized from the model checkpoint at bert-base-uncased and are newly initialized: ['classifier.weight', 'classifier.bias']\n", 176 | "You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.\n" 177 | ] 178 | } 179 | ], 180 | "source": [ 181 | ">>> from transformers import BertTokenizer, BertForSequenceClassification\n", 182 | ">>> import torch\n", 183 | "\n", 184 | ">>> tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')\n", 185 | ">>> model = BertForSequenceClassification.from_pretrained('bert-base-uncased', problem_type=\"multi_label_classification\")\n", 186 | "\n", 187 | ">>> inputs = tokenizer(\"Hello, my dog is cute\", return_tensors=\"pt\")\n", 188 | ">>> labels = torch.tensor([[1, 1]], dtype=torch.float) # need dtype=float for BCEWithLogitsLoss\n", 189 | ">>> outputs1 = model(**inputs, labels=labels)\n", 190 | ">>> loss1 = outputs1.loss\n", 191 | ">>> logits1 = outputs.logits" 192 | ] 193 | } 194 | ], 195 | "metadata": { 196 | "interpreter": { 197 | "hash": "8cf67b00fe44a7bbb95b84fa7b31b362c005e968b891258bc8597366f52a1c53" 198 | }, 199 | "kernelspec": { 200 | "display_name": "Python 3.9.7 64-bit ('bc': conda)", 201 | "language": "python", 202 | "name": "python3" 203 | }, 204 | "language_info": { 205 | "codemirror_mode": { 206 | "name": "ipython", 207 | "version": 3 208 | }, 209 | "file_extension": ".py", 210 | "mimetype": "text/x-python", 211 | "name": "python", 212 | "nbconvert_exporter": "python", 213 | "pygments_lexer": "ipython3", 214 | "version": "3.9.7" 215 | }, 216 | "orig_nbformat": 4 217 | }, 218 | "nbformat": 4, 219 | "nbformat_minor": 2 220 | } 221 | -------------------------------------------------------------------------------- /evaluation/scoring_with_bert/main.py: -------------------------------------------------------------------------------- 1 | # for arguments 2 | import argparse 3 | from arguments import get_args_parser 4 | 5 | # torch 6 | import torch 7 | from torch.utils.data import TensorDataset 8 | 9 | # settings 10 | from utils import seed_everything 11 | 12 | # hf 13 | from transformers import AutoTokenizer, RobertaForSequenceClassification 14 | 15 | # basics 16 | import pickle 17 | 18 | from model import RobertaScorer 19 | from train import train 20 | 21 | 22 | def main(args): 23 | # model for use 24 | model_checkpoint = args.model 25 | 26 | with open("dataset/dataset.bin", "rb") as fp: 27 | datasets = pickle.load(fp) 28 | 29 | train_data = datasets['train'] 30 | validation_data = datasets['validation'] 31 | 32 | train_labels = torch.tensor(train_data['label']) 33 | validation_labels = 
torch.tensor(validation_data['label'])
34 | 
35 |     tokenizer = AutoTokenizer.from_pretrained(model_checkpoint)
36 | 
37 |     val_emb = tokenizer(validation_data['poem'], truncation=True, max_length=512)
38 |     train_emb = tokenizer(train_data['poem'], padding="max_length", truncation=True, return_tensors='pt', max_length=512)
39 | 
40 |     # order matters: train.py unpacks each batch as (input_ids, attention_mask, labels)
41 |     train_dataset = TensorDataset(train_emb['input_ids'], train_emb['attention_mask'], train_labels)
42 | 
43 |     # load model
44 |     #scoring_model = RobertaScorer(args.model).cuda()
45 |     scoring_model = RobertaForSequenceClassification.from_pretrained(args.model).cuda()
46 | 
47 |     train(args, train_dataset, val_emb, validation_labels, scoring_model)
48 | 
49 | if __name__ == '__main__':
50 |     parser = argparse.ArgumentParser(description='RoBERTa-Poem-Scorer', parents=[get_args_parser()])
51 |     args = parser.parse_args()
52 |     seed_everything(args.seed)
53 |     main(args)
--------------------------------------------------------------------------------
/evaluation/scoring_with_bert/model.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch import nn
3 | from transformers import RobertaModel
4 | 
5 | 
6 | class RobertaScorer(nn.Module):
7 |     # regression head over a RoBERTa encoder; the hidden sizes assume a
8 |     # roberta-large-style encoder (1024-dim)
9 |     def __init__(self, model_name):
10 |         super(RobertaScorer, self).__init__()
11 | 
12 |         self.roberta = RobertaModel.from_pretrained(model_name)
13 |         self.lstm = nn.LSTM(1024, 1024, batch_first=True, bidirectional=True).cuda()
14 |         self.linear1 = nn.Linear(1024 * 2, 512)
15 |         self.bn1 = nn.BatchNorm1d(512)
16 |         self.dropout1 = nn.Dropout(p=0.5)
17 | 
18 |         self.linear2 = nn.Linear(512, 128)
19 |         self.bn2 = nn.BatchNorm1d(128)
20 |         self.dropout2 = nn.Dropout(p=0.5)
21 | 
22 |         self.linear3 = nn.Linear(128, 32)
23 |         self.bn3 = nn.BatchNorm1d(32)
24 |         self.dropout3 = nn.Dropout(p=0.5)
25 | 
26 |         self.linear4 = nn.Linear(32, 1)
27 |         self.sig = nn.Sigmoid()
28 | 
29 |     def forward(self, input_ids, attention_mask=None):
30 |         outputs = self.roberta(input_ids, attention_mask)
31 |         seq_output = outputs[0]
32 | 
33 |         # bidirectional LSTM over the token embeddings; concatenate the last
34 |         # forward state with the first backward state
35 |         lstm_output, (h, c) = self.lstm(seq_output)
36 |         hidden = torch.cat((lstm_output[:, -1, :1024], lstm_output[:, 0, 1024:]), dim=-1)
37 |         x = self.linear1(hidden.view(-1, 1024 * 2))
38 |         x = self.bn1(x)
39 |         x = self.dropout1(x)
40 |         x = self.linear2(x)
41 |         x = self.bn2(x)
42 |         x = self.dropout2(x)
43 |         x = self.linear3(x)
44 |         x = self.bn3(x)
45 |         x = self.dropout3(x)
46 |         x = self.linear4(x)
47 | 
48 |         return self.sig(x)
--------------------------------------------------------------------------------
/evaluation/scoring_with_bert/train.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch.utils.data import (DataLoader, RandomSampler)
3 | 
4 | from transformers import AdamW, get_linear_schedule_with_warmup
5 | 
6 | from tqdm import trange, tqdm
7 | 
8 | from scipy.stats import pearsonr
9 | 
10 | import numpy as np
11 | 
12 | def train(args, train_dataset, val_emb, validation_labels, scoring_model):
13 | 
14 |     # logging
15 |     best_cor = 0
16 |     stop_counter = 0
17 | 
18 |     # Dataloader
19 |     train_sampler = RandomSampler(train_dataset)
20 |     train_dataloader = DataLoader(train_dataset, sampler=train_sampler, batch_size=args.train_bs, drop_last=True)
21 | 
22 |     # Optimizer (no weight decay for biases and LayerNorm weights)
23 |     no_decay = ['bias', 'LayerNorm.weight']
24 |     optimizer_grouped_parameters = [
25 |         {'params': [p for n, p in scoring_model.named_parameters() if not any(nd in n for nd in no_decay)], 'weight_decay': args.weight_decay},
26 |         {'params': [p for n, p in scoring_model.named_parameters() if any(nd in n for nd in no_decay)], 'weight_decay': 0.0}
27 |     ]
28 |     optimizer = AdamW(optimizer_grouped_parameters, lr=args.lr, eps=args.adam_epsilon)
29 |     t_total = len(train_dataloader) // args.gradient_accumulation_steps * args.num_epochs
30 |     scheduler = get_linear_schedule_with_warmup(optimizer, num_warmup_steps=args.warmup_steps, num_training_steps=t_total)
31 | 
32 |     # Start training!
33 |     global_step = 0
34 | 
35 |     scoring_model.zero_grad()
36 |     torch.cuda.empty_cache()
37 | 
38 |     train_iterator = trange(int(args.num_epochs), desc="Epoch")
39 | 
40 |     for epoch in train_iterator:
41 |         epoch_iterator = tqdm(train_dataloader, desc="Iteration")
42 | 
43 |         # to compute average loss in an epoch
44 |         train_loss_list = []
45 | 
46 |         print(f"**********Train: epoch {epoch}**********")
47 |         for step, batch in enumerate(epoch_iterator):
48 | 
49 |             scoring_model.train()
50 | 
51 |             if torch.cuda.is_available():
52 |                 batch = tuple(t.cuda() for t in batch)
53 | 
54 |             inputs = {'input_ids': batch[0],
55 |                       'attention_mask': batch[1]
56 |                       }
57 | 
58 |             scores = scoring_model(**inputs, labels=batch[2])
59 | 
60 |             train_loss = scores.loss
61 |             train_loss_list.append(train_loss.detach().cpu().numpy())
62 | 
63 |             # print the running average loss every 100 steps
64 |             if step % 100 == 0 and step > 99:
65 |                 epoch_average_loss = sum(train_loss_list[step - 100:step]) / 100
66 |                 print(f'step: {step} with loss: {epoch_average_loss}')
67 | 
68 |             train_loss = train_loss / args.gradient_accumulation_steps
69 |             train_loss.backward()
70 | 
71 |             if ((step + 1) % args.gradient_accumulation_steps == 0) or (step + 1 == len(epoch_iterator)):
72 |                 optimizer.step()
73 |                 scheduler.step()
74 |                 scoring_model.zero_grad()
75 | 
76 |                 global_step += 1
77 | 
78 |         torch.cuda.empty_cache()
79 | 
80 |         print("**********EVALUATION**********")
81 |         with torch.no_grad():
82 |             scoring_model.eval()
83 | 
84 |             input_ids_lst = val_emb['input_ids']
85 |             attention_mask_lst = val_emb['attention_mask']
86 | 
87 |             val_result_lst = []
88 |             val_loss_lst = []
89 | 
90 |             for i in range(0, len(validation_labels)):
91 |                 input_ids = torch.tensor(input_ids_lst[i]).cuda().unsqueeze(0)
92 |                 attention_mask = torch.tensor(attention_mask_lst[i]).cuda().unsqueeze(0)
93 |                 label = torch.tensor(validation_labels[i]).cuda().unsqueeze(0)
94 | 
95 |                 val_output = scoring_model(input_ids, attention_mask, labels=label)
96 |                 val_loss_lst.append(val_output.loss.cpu())
97 |                 val_result = np.argmax(val_output.logits.cpu())
98 |                 val_result_lst.append(val_result)  # predicted class per validation example
99 | 
100 |             pearson_cor = pearsonr(validation_labels, val_result_lst)[0]
101 |             val_loss = (sum(val_loss_lst) / len(val_loss_lst)).item()
102 | 
103 |             if pearson_cor > best_cor:
104 |                 stop_counter = 0
105 |                 best_cor = pearson_cor
106 | 
107 |                 scoring_model.save_pretrained(f'checkpoints/val_cor_{pearson_cor:4.2%}_sen_encoder')
108 | 
109 |             else:
110 |                 stop_counter += 1
111 |                 print(f"early stop count {stop_counter} out of {args.early_stop}")
112 |                 if args.early_stop == stop_counter:
113 |                     break
114 | 
115 |             print("epoch validation loss:", val_loss)
116 |             print("epoch pearson correlation:", pearson_cor)
117 |             print("best cor from all epochs", best_cor)
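With a pickled dataset at `dataset/dataset.bin` containing `train`/`validation` splits with `poem` and `label` fields (as `main.py` expects), training can be launched with the flags defined in `arguments.py`; for example:

```bash
python evaluation/scoring_with_bert/main.py \
    --model klue/roberta-small \
    --num_epochs 100 \
    --train_bs 8 \
    --lr 2e-5
```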
/evaluation/scoring_with_bert/utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | import random 4 | import os 5 | 6 | import argparse 7 | import pickle 8 | 9 | # for reproducibility 10 | def seed_everything(seed: int = 2021): 11 | device = "cuda:0" if torch.cuda.is_available() else "cpu" 12 | 13 | torch.manual_seed(seed) 14 | np.random.seed(seed) 15 | random.seed(seed) 16 | os.environ["PYTHONHASHSEED"] = str(seed) 17 | if device == "cuda:0": 18 | torch.cuda.manual_seed(seed) # type: ignore 19 | torch.backends.cudnn.deterministic = True # type: ignore 20 | torch.backends.cudnn.benchmark = True 21 | 22 | def sort_by_length_with_reference(reference: list, dataset: object): 23 | 24 | sorted_dataset = {} 25 | idx_lst = np.argsort(reference) 26 | 27 | key_lst = dataset.features.keys() 28 | 29 | for key in key_lst: 30 | sorted_dataset[key] = [] 31 | 32 | for key in key_lst: 33 | temp_lst = sorted_dataset[key] 34 | original_lst = dataset[key] 35 | for i in idx_lst: 36 | temp_lst.append(original_lst[int(i)]) 37 | 38 | return sorted_dataset -------------------------------------------------------------------------------- /evaluation/sentencebert/arguments.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | 3 | def get_args_parser(): 4 | parser = argparse.ArgumentParser('Image Classification', add_help=False) 5 | parser.add_argument('--seed', default=2021, type=int) 6 | parser.add_argument('--lr', default=2e-5, type=float) 7 | parser.add_argument('--train_bs', default=32, type=int) 8 | parser.add_argument('--num_epochs', default=10, type=int) 9 | parser.add_argument('--weight_decay', default=0.01, type=float) 10 | parser.add_argument('--early_stop', default=3, type=int) 11 | parser.add_argument('--adam_epsilon', default=1e-08, type=float) 12 | parser.add_argument('--gradient_accumulation_steps', default=1, type=int) 13 | parser.add_argument('--warmup_steps', default=8, type=int) 14 | parser.add_argument('--model', default='klue/roberta-small', type=str) 15 | parser.add_argument('--data', default='klue/sts', type=str) 16 | 17 | return parser -------------------------------------------------------------------------------- /evaluation/sentencebert/demo.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 7, 6 | "metadata": {}, 7 | "outputs": [], 8 | "source": [ 9 | "from transformers import AutoTokenizer\n", 10 | "from model import RobertaEncoder\n", 11 | "import torch" 12 | ] 13 | }, 14 | { 15 | "cell_type": "code", 16 | "execution_count": 21, 17 | "metadata": {}, 18 | "outputs": [], 19 | "source": [ 20 | "tokenizer = AutoTokenizer.from_pretrained('klue/roberta-small')\n", 21 | "sen_encoder = RobertaEncoder.from_pretrained('checkpoints/val_acc_58.88%_sen_encoder').cuda()" 22 | ] 23 | }, 24 | { 25 | "cell_type": "code", 26 | "execution_count": 51, 27 | "metadata": {}, 28 | "outputs": [], 29 | "source": [ 30 | "def checker(title, text):\n", 31 | "\n", 32 | " val_u_seqs = tokenizer(title)\n", 33 | " val_v_seqs = tokenizer(text)\n", 34 | "\n", 35 | " u_input_ids_lst = val_u_seqs['input_ids']\n", 36 | " u_attention_mask_lst = val_u_seqs['attention_mask']\n", 37 | "\n", 38 | " v_input_ids_lst = val_v_seqs['input_ids']\n", 39 | " v_attention_mask_lst = val_v_seqs['attention_mask']\n", 40 | "\n", 41 | " with torch.no_grad():\n", 42 | " sen_encoder.eval()\n", 43 | " 
sen_encoder.eval()\n", 44 | " \n", 45 | " u_input_ids = torch.tensor(u_input_ids_lst).cuda().unsqueeze(0)\n", 46 | " u_attention_mask = torch.tensor(u_attention_mask_lst).cuda().unsqueeze(0)\n", 47 | " v_input_ids = torch.tensor(v_input_ids_lst).cuda().unsqueeze(0)\n", 48 | " v_attention_mask = torch.tensor(v_attention_mask_lst).cuda().unsqueeze(0)\n", 49 | " \n", 50 | " inputs = {'input_ids_1': u_input_ids,\n", 51 | " 'input_ids_2': v_input_ids,\n", 52 | " 'attention_mask_1': u_attention_mask,\n", 53 | " 'attention_mask_2': v_attention_mask\n", 54 | " }\n", 55 | "\n", 56 | " val_cos_sim = sen_encoder(**inputs).to('cpu')\n", 57 | " return val_cos_sim.item()\n" 58 | ] 59 | }, 60 | { 61 | "cell_type": "code", 62 | "execution_count": 31, 63 | "metadata": {}, 64 | "outputs": [], 65 | "source": [ 66 | "checker('title', 'text')" 67 | ] 68 | } 69 | ], 70 | "metadata": { 71 | "interpreter": { 72 | "hash": "8cf67b00fe44a7bbb95b84fa7b31b362c005e968b891258bc8597366f52a1c53" 73 | }, 74 | "kernelspec": { 75 | "display_name": "Python 3.9.7 64-bit ('bc': conda)", 76 | "language": "python", 77 | "name": "python3" 78 | }, 79 | "language_info": { 80 | "codemirror_mode": { 81 | "name": "ipython", 82 | "version": 3 83 | }, 84 | "file_extension": ".py", 85 | "mimetype": "text/x-python", 86 | "name": "python", 87 | "nbconvert_exporter": "python", 88 | "pygments_lexer": "ipython3", 89 | "version": "3.9.7" 90 | }, 91 | "orig_nbformat": 4 92 | }, 93 | "nbformat": 4, 94 | "nbformat_minor": 2 95 | } 96 | -------------------------------------------------------------------------------- /evaluation/sentencebert/main.py: -------------------------------------------------------------------------------- 1 | # for arguments 2 | import argparse 3 | from arguments import get_args_parser 4 | 5 | # torch 6 | import torch 7 | from torch.utils.data import TensorDataset 8 | 9 | # settings 10 | from utils import seed_everything 11 | 12 | # hf 13 | from transformers import AutoTokenizer 14 | from datasets import load_dataset 15 | 16 | # basics 17 | import json 18 | import pickle 19 | import os 20 | from tqdm import tqdm 21 | 22 | from model import RobertaEncoder 23 | from train import train 24 | 25 | 26 | def main(args): 27 | # model for use 28 | model_checkpoint = args.model 29 | 30 | if args.data == 'klue/sts': 31 | print("Downloading Klue STS Data") 32 | datasets = load_dataset("klue", "sts") 33 | 34 | # loading train and validation data 35 | train_data = datasets['validation'] 36 | validation_data = datasets['validation'] 37 | 38 | train_labels = torch.tensor([line['label'] for line in train_data['labels']]) 39 | validation_labels = torch.tensor([line['label'] for line in validation_data['labels']]) 40 | 41 | # tokenize 42 | tokenizer = AutoTokenizer.from_pretrained(model_checkpoint) 43 | ## train_data 44 | train_u_seqs = tokenizer(train_data['sentence1'], padding="max_length", truncation=True, return_tensors='pt') 45 | train_v_seqs = tokenizer(train_data['sentence2'], padding="max_length", truncation=True, return_tensors='pt') 46 | 47 | # validation_data 48 | val_u_seqs = tokenizer(validation_data['sentence1']) 49 | val_v_seqs = tokenizer(validation_data['sentence2']) 50 | 51 | # create train dataset 52 | train_dataset = TensorDataset(train_u_seqs['input_ids'], train_v_seqs['input_ids'], 53 | train_u_seqs['attention_mask'], train_v_seqs['attention_mask'], train_labels) 54 | 55 | elif args.data == "poem_data": 56 | with open("../dataset.bin", "rb") as fp: #Pickling 57 | datasets = pickle.load(fp) 58 | 59 | train_data = 
datasets['train'] 60 | validation_data = datasets['validation'] 61 | 62 | train_labels = torch.tensor([n for n in train_data['score']]) 63 | validation_labels = torch.tensor([n for n in validation_data['score']]) 64 | 65 | tokenizer = AutoTokenizer.from_pretrained(model_checkpoint) 66 | 67 | train_u_seqs = tokenizer(train_data['title'], padding="max_length", truncation=True, return_tensors='pt') 68 | train_v_seqs = tokenizer(train_data['text'], padding="max_length", truncation=True, return_tensors='pt') 69 | 70 | val_u_seqs = tokenizer(validation_data['title']) 71 | val_v_seqs = tokenizer(validation_data['text']) 72 | 73 | train_dataset = TensorDataset(train_u_seqs['input_ids'], train_v_seqs['input_ids'], 74 | train_u_seqs['attention_mask'], train_v_seqs['attention_mask'], train_labels) 75 | 76 | # load model 77 | sen_encoder = RobertaEncoder.from_pretrained('klue/roberta-small').cuda() 78 | 79 | train(args, train_dataset, val_u_seqs, val_v_seqs, validation_labels, sen_encoder) 80 | 81 | if __name__ == '__main__': 82 | parser = argparse.ArgumentParser(description='KoSentence-RoBerta', parents=[get_args_parser()]) 83 | args = parser.parse_args() 84 | seed_everything(args.seed) 85 | main(args) -------------------------------------------------------------------------------- /evaluation/sentencebert/model.py: -------------------------------------------------------------------------------- 1 | from transformers import RobertaModel, RobertaPreTrainedModel 2 | from torch import mean, nn 3 | 4 | class RobertaEncoder(RobertaPreTrainedModel): 5 | def __init__(self, config): 6 | super(RobertaEncoder, self).__init__(config) 7 | 8 | self.roberta = RobertaModel(config) 9 | self.cos_sim = nn.CosineSimilarity(dim=1) 10 | self.init_weights() 11 | 12 | def forward(self, input_ids_1, input_ids_2, attention_mask_1=None, attention_mask_2=None): 13 | 14 | outputs_1 = self.roberta(input_ids_1 ,attention_mask=attention_mask_1) 15 | outputs_2 = self.roberta(input_ids_2 ,attention_mask=attention_mask_2) 16 | 17 | sequence_outputs_1 = outputs_1[0] 18 | sequence_outputs_2 = outputs_2[0] 19 | 20 | u = mean(sequence_outputs_1,1) 21 | v = mean(sequence_outputs_2,1) 22 | 23 | cos_sim = self.cos_sim(u,v) 24 | 25 | return cos_sim 26 | -------------------------------------------------------------------------------- /evaluation/sentencebert/train.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.utils.data import (DataLoader, RandomSampler) 3 | from torch.nn import MSELoss 4 | 5 | from transformers import AdamW, get_linear_schedule_with_warmup 6 | from transformers import AutoTokenizer 7 | 8 | from tqdm import trange, tqdm 9 | 10 | import pickle 11 | 12 | from scipy.stats import pearsonr 13 | 14 | def train(args, train_dataset, val_u_seqs, val_v_seqs, validation_labels, sen_encoder): 15 | 16 | # logging 17 | best_cor = 0 18 | stop_counter = 0 19 | 20 | # Dataloader 21 | train_sampler = RandomSampler(train_dataset) 22 | train_dataloader = DataLoader(train_dataset, sampler=train_sampler, batch_size=args.train_bs) 23 | 24 | # Optimizer 25 | no_decay = ['bias', 'LayerNorm.weight'] 26 | optimizer_grouped_parameters = [ 27 | {'params': [p for n, p in sen_encoder.named_parameters() if not any(nd in n for nd in no_decay)], 'weight_decay': args.weight_decay}, 28 | {'params': [p for n, p in sen_encoder.named_parameters() if any(nd in n for nd in no_decay)], 'weight_decay': 0.0} 29 | ] 30 | optimizer = AdamW(optimizer_grouped_parameters, lr=args.lr, 
eps=args.adam_epsilon) 31 | t_total = len(train_dataloader) // args.gradient_accumulation_steps * args.num_epochs 32 | scheduler = get_linear_schedule_with_warmup(optimizer, num_warmup_steps=args.warmup_steps, num_training_steps=t_total) 33 | 34 | # Start training! 35 | global_step = 0 36 | 37 | sen_encoder.zero_grad() 38 | torch.cuda.empty_cache() 39 | 40 | train_iterator = trange(int(args.num_epochs), desc="Epoch") 41 | 42 | for epoch in train_iterator: 43 | epoch_iterator = tqdm(train_dataloader, desc="Iteration") 44 | 45 | # to compute average loss in an epoch 46 | train_loss_list = [] 47 | 48 | print(f"**********Train: epoch {epoch}**********") 49 | for step, batch in enumerate(epoch_iterator): 50 | 51 | sen_encoder.train() 52 | 53 | if torch.cuda.is_available(): 54 | batch = tuple(t.cuda() for t in batch) 55 | 56 | inputs = {'input_ids_1': batch[0], 57 | 'input_ids_2': batch[1], 58 | 'attention_mask_1': batch[2], 59 | 'attention_mask_2': batch[3] 60 | } 61 | 62 | targets = batch[4] 63 | cos_sim_outputs = (sen_encoder(**inputs) * (5/2)) + 2.5 64 | 65 | # loss 66 | criterion = MSELoss() 67 | train_loss = criterion(cos_sim_outputs, targets.float()) 68 | train_loss_list.append(train_loss.detach().cpu().numpy()) 69 | 70 | # print loss every 1000 steps 71 | if step % 500 == 0 and step > 99: 72 | epoch_average_loss = sum(train_loss_list[step-100:step]) / 99 73 | print(f'step: {step} with loss: {epoch_average_loss}') 74 | 75 | train_loss = train_loss / args.gradient_accumulation_steps 76 | train_loss.backward() 77 | 78 | if ((step + 1) % args.gradient_accumulation_steps == 0) or (step + 1 == len(epoch_iterator)): 79 | optimizer.step() 80 | scheduler.step() 81 | sen_encoder.zero_grad() 82 | 83 | 84 | 85 | global_step += 1 86 | 87 | torch.cuda.empty_cache() 88 | 89 | print("**********EVALUATION**********") 90 | with torch.no_grad(): 91 | sen_encoder.eval() 92 | sen_encoder.eval() 93 | 94 | u_input_ids_lst = val_u_seqs['input_ids'] 95 | u_attention_mask_lst = val_u_seqs['attention_mask'] 96 | 97 | v_input_ids_lst = val_v_seqs['input_ids'] 98 | v_attention_mask_lst = val_v_seqs['attention_mask'] 99 | 100 | val_cos_sim_lst = [] 101 | 102 | for i in range(0,len(validation_labels)): 103 | u_input_ids = torch.tensor(u_input_ids_lst[i]).cuda().unsqueeze(0) 104 | u_attention_mask = torch.tensor(u_attention_mask_lst[i]).cuda().unsqueeze(0) 105 | v_input_ids = torch.tensor(v_input_ids_lst[i]).cuda().unsqueeze(0) 106 | v_attention_mask = torch.tensor(v_attention_mask_lst[i]).cuda().unsqueeze(0) 107 | 108 | inputs = {'input_ids_1': u_input_ids, 109 | 'input_ids_2': v_input_ids, 110 | 'attention_mask_1': u_attention_mask, 111 | 'attention_mask_2': v_attention_mask 112 | } 113 | 114 | val_cos_sim = sen_encoder(**inputs).to('cpu') 115 | val_cos_sim_lst.append(val_cos_sim.item()*5/2+2.5) #(num_query, emb_dim) 116 | 117 | pearson_cor = pearsonr(validation_labels.tolist(), val_cos_sim_lst)[0] 118 | 119 | if pearson_cor > best_cor: 120 | stop_counter = 0 121 | best_cor = pearson_cor 122 | 123 | sen_encoder.save_pretrained(f'checkpoints/val_acc_{pearson_cor:4.2%}_sen_encoder') 124 | 125 | else: 126 | stop_counter += 1 127 | print(f"early stop count {stop_counter} out of {args.early_stop}") 128 | if args.early_stop == stop_counter: 129 | break 130 | 131 | print("epoch pearson correaltion:", pearson_cor) 132 | print("best cor from all epochs", best_cor) 133 | -------------------------------------------------------------------------------- /model/caption_model/utils.py: 
-------------------------------------------------------------------------------- 1 | ## 2 | -------------------------------------------------------------------------------- /model/caption_model/vit_gpt2/dataset.py: -------------------------------------------------------------------------------- 1 | from torch.utils.data import Dataset 2 | 3 | 4 | class COCODataset(Dataset): 5 | def __init__(self, img_lst, labels) -> None: 6 | super().__init__() 7 | self.img_lst = img_lst 8 | self.labels = labels 9 | 10 | def __len__(self): 11 | return len(self.img_lst) 12 | 13 | def __getitem__(self, index): 14 | item = { 15 | "pixel_values": self.img_lst[index].squeeze(), 16 | "labels": self.labels[index], 17 | } 18 | return item 19 | -------------------------------------------------------------------------------- /model/caption_model/vit_gpt2/run_train.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import pandas as pd 3 | import torch 4 | import pandas as pd 5 | from sklearn.model_selection import train_test_split 6 | from nltk.translate.bleu_score import corpus_bleu 7 | 8 | from caption_model.vit_gpt2.utils import ( 9 | read_json, 10 | get_data_df, 11 | get_pixel_values_and_tokenized_labels, 12 | ) 13 | from transformers import ( 14 | VisionEncoderDecoderModel, 15 | ViTFeatureExtractor, 16 | PreTrainedTokenizerFast, 17 | Seq2SeqTrainingArguments, 18 | Seq2SeqTrainer, 19 | default_data_collator, 20 | ) 21 | from caption_model.vit_gpt2.dataset import COCODataset 22 | 23 | 24 | def train(args): 25 | 26 | # 데이터 로드 및 feature_extractor, tokenizer 선언 27 | coco = read_json(args.ms_coco_kor_file_path) 28 | coco_df = get_data_df(coco, args.data_dir) 29 | coco_data = coco_df 30 | train_df, valid_df = train_test_split(coco_data, test_size=0.2, random_state=42) 31 | train_df = train_df.reset_index() 32 | valid_df = valid_df.reset_index() 33 | 34 | feature_extractor = ViTFeatureExtractor.from_pretrained( 35 | args.encoder_model_name_or_path 36 | ) 37 | tokenizer = PreTrainedTokenizerFast.from_pretrained( 38 | args.args.decoder_model_name_or_path, 39 | bos_token="", 40 | eos_token="", 41 | unk_token="", 42 | pad_token="", 43 | mask_token="", 44 | ) 45 | 46 | # feature, label 생성 및 caching 47 | train_pixel, train_labels = get_pixel_values_and_tokenized_labels( 48 | df=train_df, feature_extractor=feature_extractor, tokenizer=tokenizer 49 | ) 50 | valid_pixel, valid_labels = get_pixel_values_and_tokenized_labels( 51 | df=valid_df, feature_extractor=feature_extractor, tokenizer=tokenizer 52 | ) 53 | 54 | # dataset load 55 | train_dataset = COCODataset(train_pixel, train_labels) 56 | valid_dataset = COCODataset(valid_pixel, valid_labels) 57 | 58 | model = VisionEncoderDecoderModel.from_encoder_decoder_pretrained( 59 | args.encoder_model_name_or_path, args.decoder_model_name_or_path 60 | ) 61 | 62 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 63 | model.to(device) 64 | 65 | def compute_metrics(pred): 66 | """validation을 위한 metrics function""" 67 | """decode후에 bleu4를 측정하기 때문에 nested function으로 선언(tokenizer 필요)""" 68 | labels = tokenizer.batch_decode(pred.label_ids, skip_special_tokens=True) 69 | preds = tokenizer.batch_decode(pred.predictions, skip_special_tokens=True) 70 | # labels -> [sen1, sen2, sen3 ...] 71 | # list_of_references -> [[sen1],[sen2],[sen3]...] 
72 | list_of_references = [] 73 | for i in range(len(labels)): 74 | list_of_references.append([labels[i]]) 75 | # calculate blue4 76 | blue4 = corpus_bleu(list_of_references=list_of_references, hypotheses=preds) 77 | return {"bleu4": blue4} 78 | 79 | training_args = Seq2SeqTrainingArguments( 80 | output_dir=args.output_dir, 81 | predict_with_generate=True, 82 | evaluation_strategy=args.evaluation_strategy, 83 | per_device_train_batch_size=args.batch_size, 84 | per_device_eval_batch_size=args.batch_size, 85 | gradient_accumulation_steps=args.gradient_accumulation_steps, 86 | overwrite_output_dir=True, 87 | fp16=True, 88 | load_best_model_at_end=True, 89 | metric_for_best_model="bleu4", 90 | logging_steps=args.logging_steps, 91 | save_steps=args.save_steps, 92 | eval_steps=args.eval_steps, 93 | num_train_epochs=args.num_train_epochs, 94 | save_total_limit=args.save_total_limit, 95 | ) 96 | 97 | trainer = Seq2SeqTrainer( 98 | model=model, 99 | tokenizer=feature_extractor, 100 | args=training_args, 101 | train_dataset=train_dataset, 102 | eval_dataset=valid_dataset, 103 | compute_metrics=compute_metrics, 104 | data_collator=default_data_collator, 105 | ) 106 | trainer.train() 107 | 108 | model.save_pretrained(args.output_dir) 109 | feature_extractor.save_pretrained(args.output_dir) 110 | tokenizer.save_pretrained(args.output_dir) 111 | -------------------------------------------------------------------------------- /model/caption_model/vit_gpt2/tt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/boostcampaitech2/final-project-level3-nlp-08/cce85d868879e9ad23901125fb7ca0d7d129f450/model/caption_model/vit_gpt2/tt -------------------------------------------------------------------------------- /model/caption_model/vit_gpt2/utils.py: -------------------------------------------------------------------------------- 1 | import json 2 | import pandas as pd 3 | from PIL import Image 4 | from tqdm import tqdm 5 | import numpy as np 6 | 7 | 8 | def read_json(file_path): 9 | with open(file_path) as f: 10 | return json.load(f) 11 | 12 | 13 | def get_data_df(coco_data: json, data_dir: str) -> pd.DataFrame: 14 | """ 15 | MSCOCO_train_val_Korea.json과 16 | 해당 파일이 있는 경로를 입력받아 17 | 실제 사진이 있는 path와 캡션 label등을 df로 넘겨줍니다. 
18 | 경로 ex) 19 | caption_data/train2014/image 20 | caption_data/valid2014/image 21 | caption_data/MSCOCO_train_val_Korea.json 22 | """ 23 | img_path = [] 24 | data_id = [] 25 | total_caption_lst = [] 26 | data_dir = data_dir + "/" 27 | for i in range(len(coco_data)): 28 | # 캡션 5개 미만이면 추가하지 않음 29 | if len(coco_data[i]["caption_ko"]) < 5: 30 | continue 31 | # img path 추가 32 | img_path.append(data_dir + coco_data[i]["file_path"]) 33 | data_id.append(coco_data[i]["id"]) 34 | 35 | # img path와 매칭되는 caption 5개 추가 36 | caption_lst = [] 37 | for j in range(5): 38 | caption_lst.append(coco_data[i]["caption_ko"][j]) 39 | total_caption_lst.append(caption_lst) 40 | 41 | coco_df = pd.DataFrame(data={"labels": total_caption_lst, "img_paths": img_path}) 42 | return coco_df 43 | 44 | 45 | def get_pixel_values_and_tokenized_labels(df, feature_extractor, tokenizer): 46 | # 이미지 캐싱 47 | img_lst = [] 48 | for i in tqdm(range(len(df)), "img_cache"): 49 | image = Image.open(df["img_paths"][i]).convert("RGB") 50 | image_tensor = np.array(image) 51 | pixel_values = feature_extractor(image_tensor, return_tensors="pt").pixel_values 52 | img_lst.append(pixel_values) 53 | # 캐싱된 이미지를 5배 해줌 -> 메모리의 이미지 객체의 주소만 넘기므로, 메모리 문제는 없음 54 | img_for_matching_captions = [] 55 | for i in tqdm(range(5), "img extend"): 56 | img_for_matching_captions.extend(img_lst) 57 | 58 | # 캐싱된 이미지의 인덱스에 맞추어서 label들을 리스트에 넣고 tokenizing을 해줌 59 | # [iamge1, image2, image3, ... image1, image2, image3 ...] 60 | # [label1, label2, label3, ... label1, label2, label3 ...] 61 | labels_for_matching_img = [] 62 | for i in tqdm(range(5), "tokenizing"): 63 | labels = [] 64 | for j in range(len(df)): 65 | labels.append(df["labels"][j][i]) 66 | labels_for_matching_img.extend(labels) 67 | tokenized_labels = tokenizer( 68 | labels_for_matching_img, return_tensors="pt", padding=True, truncation=True 69 | ).input_ids 70 | return img_for_matching_captions, tokenized_labels 71 | -------------------------------------------------------------------------------- /model/gpt2_base_train.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | from poem_model.gpt2_base.run_train import train 3 | 4 | if __name__ == "__main__": 5 | parser = argparse.ArgumentParser() 6 | # data_arg 7 | parser.add_argument("--output_dir", type=str, default="./gpt2_base") 8 | parser.add_argument( 9 | "--data_dir", type=str, default="../data/poem_data/preprocess_data" 10 | ) 11 | parser.add_argument("--model_name_or_path", type=str, default="skt/kogpt2-base-v2") 12 | parser.add_argument("--train_filename", type=str, default="poem_with_keyowrd.csv") 13 | 14 | # train_arg 15 | parser.add_argument("--num_labels", type=int, default=3) 16 | parser.add_argument("--seed", type=int, default=42) 17 | parser.add_argument("--num_train_epochs", type=int, default=5) 18 | parser.add_argument("--batch_size", type=int, default=64) 19 | parser.add_argument("--learning_rate", type=float, default=5e-5) 20 | parser.add_argument("--gradient_accumulation_steps", type=int, default=1) 21 | parser.add_argument("--weight_decay", type=float, default=0.01) 22 | 23 | # eval_arg 24 | parser.add_argument("--evaluation_strategy", type=str, default="steps") 25 | parser.add_argument("--logging_steps", type=int, default=500) 26 | parser.add_argument("--save_steps", type=int, default=500) 27 | parser.add_argument("--eval_steps", type=int, default=500) 28 | parser.add_argument("--save_total_limit", type=int, default=2) 29 | 30 | args = parser.parse_args() 31 | train(args) 32 | 
-------------------------------------------------------------------------------- /model/poem_model/gpt2_base/dataset.py: -------------------------------------------------------------------------------- 1 | from torch.utils.data import Dataset 2 | 3 | 4 | class PoemDataset(Dataset): 5 | def __init__(self, data) -> None: 6 | super().__init__() 7 | self.data = data 8 | 9 | def __getitem__(self, index): 10 | item = {k: v[index] for k, v in self.data.items()} 11 | item["labels"] = self.data.input_ids[index] 12 | return item 13 | 14 | def __len__(self): 15 | return len(self.data.input_ids) 16 | -------------------------------------------------------------------------------- /model/poem_model/gpt2_base/run_train.py: -------------------------------------------------------------------------------- 1 | import os 2 | from transformers import ( 3 | AutoModelForCausalLM, 4 | TrainingArguments, 5 | Trainer, 6 | PreTrainedTokenizerFast, 7 | ) 8 | import pandas as pd 9 | from sklearn.model_selection import train_test_split 10 | import torch 11 | import os 12 | from poem_model.gpt2_base.utils import get_tagged_data 13 | from poem_model.gpt2_base.dataset import PoemDataset 14 | 15 | 16 | def train(args): 17 | 18 | # 토크나이저 선언 19 | tokenizer = PreTrainedTokenizerFast.from_pretrained( 20 | args.model_name_or_path, 21 | bos_token="", 22 | eos_token="", 23 | unk_token="", 24 | pad_token="", 25 | mask_token="", 26 | ) 27 | # 새로운 스페셜 토큰 생성 (키워드 토큰) 28 | keyword_start_marker = "" 29 | keyword_end_marker = "" 30 | tokenizer.add_special_tokens( 31 | {"additional_special_tokens": [keyword_start_marker, keyword_end_marker]} 32 | ) 33 | 34 | # keyword가 담겨있는 데이터를 로드 35 | data_path = os.path.join(args.data_dir, args.train_filename) 36 | poem_df = pd.read_csv(data_path) 37 | train_poem, valid_poem = train_test_split(poem_df, test_size=0.1, random_state=42) 38 | 39 | train_poem = train_poem.reset_index() 40 | valid_poem = valid_poem.reset_index() 41 | 42 | # 키워드 추출이 안돼서(명사 추출이 안돼서) keyword가 None인 경우가 존재 43 | # 그 경우 train, valid data에서 제외 44 | train_data = get_tagged_data(train_poem) 45 | valid_data = get_tagged_data(valid_poem) 46 | 47 | # 시 토크나이즈 48 | train_data = tokenizer(train_data, padding=True, return_tensors="pt") 49 | valid_data = tokenizer(valid_data, padding=True, return_tensors="pt") 50 | 51 | # 데이터셋 52 | train_dataset = PoemDataset(train_data) 53 | valid_dataset = PoemDataset(valid_data) 54 | 55 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 56 | 57 | # 스페셜 토큰만큼 모델 리사이즈 58 | model = AutoModelForCausalLM.from_pretrained(args.model_name_or_path) 59 | model.resize_token_embeddings(tokenizer.vocab_size + 2) 60 | model.to(device) 61 | 62 | training_args = TrainingArguments( 63 | output_dir=args.output_dir, 64 | predict_with_generate=True, 65 | evaluation_strategy="steps", 66 | per_device_train_batch_size=args.batch_size, 67 | per_device_eval_batch_size=args.batch_size, 68 | gradient_accumulation_steps=args.gradient_accumulation_steps, 69 | overwrite_output_dir=True, 70 | fp16=True, 71 | load_best_model_at_end=True, 72 | logging_steps=args.logging_steps, 73 | save_steps=args.save_steps, 74 | eval_steps=args.eval_steps, 75 | num_train_epochs=args.num_train_epochs, 76 | save_total_limit=args.save_total_limit, 77 | ) 78 | 79 | trainer = Trainer( 80 | model=model, 81 | tokenizer=tokenizer, 82 | args=training_args, 83 | train_dataset=train_dataset, 84 | eval_dataset=valid_dataset, 85 | # compute_metrics = compute_metrics, 86 | ) 87 | trainer.train() 88 | model.save_pretrained(args.output_dir) 89 | 
tokenizer.save_pretrained(args.output_dir) 90 | -------------------------------------------------------------------------------- /model/poem_model/gpt2_base/t: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/boostcampaitech2/final-project-level3-nlp-08/cce85d868879e9ad23901125fb7ca0d7d129f450/model/poem_model/gpt2_base/t -------------------------------------------------------------------------------- /model/poem_model/gpt2_base/utils.py: -------------------------------------------------------------------------------- 1 | def tag_keyword_and_special_token(df): 2 | # k1, k2, k3 poem 3 | keyword_start_marker = "" 4 | keyword_end_marker = "" 5 | text = df["text"] 6 | keyword = df["key_word"] 7 | tagged_text = keyword_start_marker + keyword + keyword_end_marker + text 8 | return tagged_text 9 | 10 | 11 | def get_tagged_data(df): 12 | data = [] 13 | 14 | for i in range(len(df)): 15 | try: 16 | data.append(tag_keyword_and_special_token(df.iloc[i])) 17 | except: 18 | continue 19 | return data 20 | -------------------------------------------------------------------------------- /model/poem_model/gpt2_trinity/arguments.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass, field 2 | from typing import Optional 3 | 4 | 5 | from transformers import ( 6 | MODEL_FOR_CAUSAL_LM_MAPPING, 7 | TrainingArguments as _TrainingArguments, 8 | ) 9 | 10 | from transformers.trainer_utils import IntervalStrategy, SchedulerType 11 | 12 | 13 | MODEL_CONFIG_CLASSES = list(MODEL_FOR_CAUSAL_LM_MAPPING.keys()) 14 | MODEL_TYPES = tuple(conf.model_type for conf in MODEL_CONFIG_CLASSES) 15 | 16 | 17 | @dataclass 18 | class ModelNames: 19 | kogpt_skt_base = "skt/kogpt2-base-v2" 20 | kogpt_skt_trinity = "skt/ko-gpt-trinity-1.2B-v0.5" 21 | gpt_poem = "CheonggyeMountain-Sherpa/kogpt-trinity-poem" 22 | 23 | 24 | REVISIONS = { 25 | ModelNames.gpt_poem: "main", 26 | ModelNames.kogpt_skt_base: "main", 27 | ModelNames.kogpt_skt_trinity: "main", 28 | } 29 | 30 | MODEL_NAME = ModelNames.kogpt_skt_trinity 31 | 32 | 33 | @dataclass 34 | class ModelArguments: 35 | """ 36 | Arguments pertaining to which model/config/tokenizer we are going to fine-tune, or train from scratch. 37 | """ 38 | 39 | model_name_or_path: Optional[str] = field( 40 | default=MODEL_NAME, 41 | metadata={ 42 | "help": "The model checkpoint for weights initialization." 43 | "Don't set if you want to train a model from scratch." 44 | }, 45 | ) 46 | model_revision: str = field( 47 | default=REVISIONS[MODEL_NAME], 48 | metadata={"help": "The specific model version to use (can be a branch name, tag name or commit id)."}, 49 | ) 50 | model_type: Optional[str] = field( 51 | default=None, 52 | metadata={ 53 | "help": "If training from scratch, pass a model type from the list: " + ", ".join(MODEL_TYPES) 54 | }, 55 | ) 56 | config_overrides: Optional[str] = field( 57 | default=None, 58 | metadata={ 59 | "help": "Override some existing default config settings when a model is trained from scratch. 
Example: " 60 | "n_embd=10,resid_pdrop=0.2,scale_attn_weights=false,summary_type=cls_index" 61 | }, 62 | ) 63 | config_name: Optional[str] = field( 64 | default=None, metadata={"help": "Pretrained config name or path if not the same as model_name"} 65 | ) 66 | tokenizer_name: Optional[str] = field( 67 | default=None, metadata={"help": "Pretrained tokenizer name or path if not the same as model_name"} 68 | ) 69 | cache_dir: Optional[str] = field( 70 | default=None, 71 | metadata={"help": "Where do you want to store the pretrained models downloaded from huggingface.co"}, 72 | ) 73 | use_fast_tokenizer: bool = field( 74 | default=True, 75 | metadata={ 76 | "help": "Whether to use one of the fast tokenizer (backed by the tokenizers library) or not." 77 | }, 78 | ) 79 | use_auth_token: bool = field( 80 | default=False, 81 | metadata={ 82 | "help": "Will use the token generated when running `transformers-cli login` (necessary to use this script " 83 | "with private models)." 84 | }, 85 | ) 86 | 87 | def __post_init__(self): 88 | if self.config_overrides is not None and ( 89 | self.config_name is not None or self.model_name_or_path is not None 90 | ): 91 | raise ValueError( 92 | "--config_overrides can't be used in combination with --config_name or --model_name_or_path" 93 | ) 94 | 95 | 96 | @dataclass 97 | class DataTrainingArguments: 98 | """ 99 | Arguments pertaining to what data we are going to input our model for training and eval. 100 | """ 101 | 102 | dataset_name: Optional[str] = field( 103 | default=None, metadata={"help": "The name of the dataset to use (via the datasets library)."} 104 | ) 105 | dataset_config_name: Optional[str] = field( 106 | default=None, 107 | metadata={"help": "The configuration name of the dataset to use (via the datasets library)."}, 108 | ) 109 | train_file: Optional[str] = field( 110 | default="../../../data/poem_data/preprocess_data/total_train.csv", metadata={"help": "The input training data file (a text file)."} 111 | ) 112 | validation_file: Optional[str] = field( 113 | default=None, 114 | metadata={ 115 | "help": "An optional input evaluation data file to evaluate the perplexity on (a text file)." 116 | }, 117 | ) 118 | max_train_samples: Optional[int] = field( 119 | default=None, 120 | metadata={ 121 | "help": "For debugging purposes or quicker training, truncate the number of training examples to this " 122 | "value if set." 123 | }, 124 | ) 125 | max_eval_samples: Optional[int] = field( 126 | default=None, 127 | metadata={ 128 | "help": "For debugging purposes or quicker training, truncate the number of evaluation examples to this " 129 | "value if set." 130 | }, 131 | ) 132 | 133 | block_size: Optional[int] = field( 134 | default=None, 135 | metadata={ 136 | "help": "Optional input sequence length after tokenization. " 137 | "The training dataset will be truncated in block of this size for training. " 138 | "Default to the model max input length for single sentence inputs (take into account special tokens)." 
139 | }, 140 | ) 141 | overwrite_cache: bool = field( 142 | default=False, metadata={"help": "Overwrite the cached training and evaluation sets"} 143 | ) 144 | validation_split_percentage: Optional[int] = field( 145 | default=10, 146 | metadata={ 147 | "help": "The percentage of the train set used as validation set in case there's no validation split" 148 | }, 149 | ) 150 | preprocessing_num_workers: Optional[int] = field( 151 | default=None, 152 | metadata={"help": "The number of processes to use for the preprocessing."}, 153 | ) 154 | keep_linebreaks: bool = field( 155 | default=True, metadata={"help": "Whether to keep line breaks when using TXT files or not."} 156 | ) 157 | 158 | def __post_init__(self): 159 | if self.dataset_name is None and self.train_file is None and self.validation_file is None: 160 | raise ValueError("Need either a dataset name or a training/validation file.") 161 | else: 162 | if self.train_file is not None: 163 | extension = self.train_file.split(".")[-1] 164 | assert extension in [ 165 | "csv", 166 | "json", 167 | "txt", 168 | ], "`train_file` should be a csv, a json or a txt file." 169 | if self.validation_file is not None: 170 | extension = self.validation_file.split(".")[-1] 171 | assert extension in [ 172 | "csv", 173 | "json", 174 | "txt", 175 | ], "`validation_file` should be a csv, a json or a txt file." 176 | 177 | 178 | @dataclass 179 | class TrainingArguments(_TrainingArguments): 180 | output_dir: str = field( 181 | default="outputs", 182 | metadata={ 183 | "help": "The output directory where the model predictions and checkpoints will be written." 184 | }, 185 | ) 186 | overwrite_output_dir: bool = field( 187 | default=True, 188 | metadata={ 189 | "help": ( 190 | "Overwrite the content of the output directory. " 191 | "Use this to continue training if output_dir points to a checkpoint directory." 
192 | ) 193 | }, 194 | ) 195 | 196 | do_train: bool = field(default=True, metadata={"help": "Whether to run training."}) 197 | do_eval: bool = field(default=True, metadata={"help": "Whether to run eval on the dev set."}) 198 | do_predict: bool = field(default=False, metadata={"help": "Whether to run predictions on the test set."}) 199 | evaluation_strategy: IntervalStrategy = field( 200 | default="epoch", 201 | metadata={"help": "The evaluation strategy to use."}, 202 | ) 203 | 204 | per_device_train_batch_size: int = field( 205 | default=1, metadata={"help": "Batch size per GPU/TPU core/CPU for training."} 206 | ) 207 | per_device_eval_batch_size: int = field( 208 | default=1, metadata={"help": "Batch size per GPU/TPU core/CPU for evaluation."} 209 | ) 210 | 211 | gradient_accumulation_steps: int = field( 212 | default=128, 213 | metadata={"help": "Number of updates steps to accumulate before performing a backward/update pass."}, 214 | ) 215 | 216 | learning_rate: float = field(default=1e-5, metadata={"help": "The initial learning rate for AdamW."}) 217 | weight_decay: float = field(default=0.0, metadata={"help": "Weight decay for AdamW if we apply some."}) 218 | adam_beta1: float = field(default=0.9, metadata={"help": "Beta1 for AdamW optimizer"}) 219 | adam_beta2: float = field(default=0.999, metadata={"help": "Beta2 for AdamW optimizer"}) 220 | adam_epsilon: float = field(default=1e-8, metadata={"help": "Epsilon for AdamW optimizer."}) 221 | max_grad_norm: float = field(default=1.0, metadata={"help": "Max gradient norm."}) 222 | num_train_epochs: float = field( 223 | default=40.0, metadata={"help": "Total number of training epochs to perform."} 224 | ) 225 | lr_scheduler_type: SchedulerType = field( 226 | default=SchedulerType.LINEAR, 227 | metadata={"help": "The scheduler type to use."}, 228 | ) 229 | warmup_ratio: float = field( 230 | default=0.0, metadata={"help": "Linear warmup over warmup_ratio fraction of total steps."} 231 | ) 232 | 233 | save_strategy: IntervalStrategy = field( 234 | default="epoch", 235 | metadata={"help": "The checkpoint save strategy to use."}, 236 | ) 237 | save_total_limit: Optional[int] = field( 238 | default=1, 239 | metadata={ 240 | "help": ( 241 | "Limit the total amount of checkpoints. " 242 | "Deletes the older checkpoints in the output_dir. Default is unlimited checkpoints" 243 | ) 244 | }, 245 | ) 246 | seed: int = field( 247 | default=42, metadata={"help": "Random seed that will be set at the beginning of training."} 248 | ) 249 | 250 | fp16: bool = field( 251 | default=True, 252 | metadata={"help": "Whether to use 16-bit (mixed) precision instead of 32-bit"}, 253 | ) 254 | fp16_opt_level: str = field( 255 | default="O1", 256 | metadata={ 257 | "help": ( 258 | "For fp16: Apex AMP optimization level selected in ['O0', 'O1', 'O2', and 'O3']. " 259 | "See details at https://nvidia.github.io/apex/amp.html" 260 | ) 261 | }, 262 | ) 263 | 264 | dataloader_drop_last: bool = field( 265 | default=False, 266 | metadata={"help": "Drop the last incomplete batch if it is not divisible by the batch size."}, 267 | ) 268 | dataloader_num_workers: int = field( 269 | default=8, 270 | metadata={ 271 | "help": "Number of subprocesses to use for data loading (PyTorch only). 0 means that the data will be loaded in the main process." 272 | }, 273 | ) 274 | 275 | load_best_model_at_end: Optional[bool] = field( 276 | default=True, 277 | metadata={ 278 | "help": "Whether or not to load the best model found during training at the end of training." 
279 | }, 280 | ) 281 | metric_for_best_model: Optional[str] = field( 282 | default=None, metadata={"help": "The metric to use to compare two different models."} 283 | ) 284 | greater_is_better: Optional[bool] = field( 285 | default=None, metadata={"help": "Whether the `metric_for_best_model` should be maximized or not."} 286 | ) 287 | group_by_length: bool = field( 288 | default=False, 289 | metadata={ 290 | "help": "Whether or not to group samples of roughly the same length together when batching." 291 | }, 292 | ) 293 | push_to_hub: bool = field( 294 | default=False, 295 | metadata={"help": "Whether or not to upload the trained model to the model hub after training."}, 296 | ) 297 | resume_from_checkpoint: Optional[str] = field( 298 | default=None, 299 | metadata={"help": "The path to a folder with a valid checkpoint for your model."}, 300 | ) 301 | report_to: Optional[str] = field( 302 | default="wandb", 303 | metadata={"help": "Choose a tool to report train, evaluation log"} 304 | ) -------------------------------------------------------------------------------- /model/poem_model/gpt2_trinity/t: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/boostcampaitech2/final-project-level3-nlp-08/cce85d868879e9ad23901125fb7ca0d7d129f450/model/poem_model/gpt2_trinity/t -------------------------------------------------------------------------------- /model/poem_model/gpt2_trinity/train.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2020 The HuggingFace Inc. team. All rights reserved. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | # https://huggingface.co/models?filter=causal-lm 의 코드를 수정했음을 APACHE LICENSE-2.0에 따라 고지함. 
17 | 18 | import logging 19 | import math 20 | import os 21 | import sys 22 | import wandb 23 | 24 | from arguments import ModelArguments, DataTrainingArguments, TrainingArguments 25 | from utils import group_texts, send_along, tokenize_function 26 | 27 | import datasets 28 | from datasets import load_dataset 29 | 30 | import transformers 31 | from transformers import ( 32 | CONFIG_MAPPING, 33 | AutoConfig, 34 | AutoModelForCausalLM, 35 | AutoTokenizer, 36 | HfArgumentParser, 37 | Trainer, 38 | default_data_collator, 39 | set_seed, 40 | ) 41 | 42 | from transformers.trainer_utils import get_last_checkpoint 43 | from transformers.utils import check_min_version 44 | from transformers.utils.versions import require_version 45 | 46 | # 버젼 체크 47 | check_min_version("4.12.5") 48 | 49 | require_version("datasets>=1.8.0", "To fix: pip install -U datasets") 50 | 51 | logger = logging.getLogger(__name__) 52 | 53 | 54 | def main(): 55 | parser = HfArgumentParser((ModelArguments, DataTrainingArguments, TrainingArguments)) 56 | model_args, data_args, training_args = parser.parse_args_into_dataclasses() 57 | 58 | # 로깅 세팅 59 | logging.basicConfig( 60 | format="%(asctime)s - %(levelname)s - %(name)s - %(message)s", 61 | datefmt="%m/%d/%Y %H:%M:%S", 62 | handlers=[logging.StreamHandler(sys.stdout)], 63 | ) 64 | 65 | log_level = training_args.get_process_log_level() 66 | logger.setLevel(log_level) 67 | datasets.utils.logging.set_verbosity(log_level) 68 | transformers.utils.logging.set_verbosity(log_level) 69 | transformers.utils.logging.enable_default_handler() 70 | transformers.utils.logging.enable_explicit_format() 71 | 72 | logger.warning( 73 | f"Process rank: {training_args.local_rank}, device: {training_args.device}, n_gpu: {training_args.n_gpu}" 74 | + f"distributed training: {bool(training_args.local_rank != -1)}, 16-bits training: {training_args.fp16}" 75 | ) 76 | logger.info(f"Training/evaluation parameters {training_args}") 77 | 78 | # 체크포인트로부터 학습할 수 있도록, 존재한다면 체크포인트를 찾습니다. 79 | last_checkpoint = None 80 | if ( 81 | os.path.isdir(training_args.output_dir) 82 | and training_args.do_train 83 | and not training_args.overwrite_output_dir 84 | ): 85 | last_checkpoint = get_last_checkpoint(training_args.output_dir) 86 | if last_checkpoint is None and len(os.listdir(training_args.output_dir)) > 0: 87 | raise ValueError( 88 | f"Output directory ({training_args.output_dir}) already exists and is not empty. " 89 | "Use --overwrite_output_dir to overcome." 90 | ) 91 | elif last_checkpoint is not None and training_args.resume_from_checkpoint is None: 92 | logger.info( 93 | f"Checkpoint detected, resuming training at {last_checkpoint}. To avoid this behavior, change " 94 | "the `--output_dir` or add `--overwrite_output_dir` to train from scratch." 95 | ) 96 | 97 | # 랜덤 시드 설정 98 | set_seed(training_args.seed) 99 | 100 | # 데이터셋 설정, csv, json, txt 형식으로 주어지는 경우 별도의 column_name 설정이 없으면 "text"라고 표기된 column이 training data로 인식됩니다. 
101 | if data_args.dataset_name is not None: 102 | raw_datasets = load_dataset( 103 | data_args.dataset_name, data_args.dataset_config_name, cache_dir=model_args.cache_dir 104 | ) 105 | if "validation" not in raw_datasets.keys(): 106 | raw_datasets["validation"] = load_dataset( 107 | data_args.dataset_name, 108 | data_args.dataset_config_name, 109 | split=f"train[:{data_args.validation_split_percentage}%]", 110 | cache_dir=model_args.cache_dir, 111 | ) 112 | raw_datasets["train"] = load_dataset( 113 | data_args.dataset_name, 114 | data_args.dataset_config_name, 115 | split=f"train[{data_args.validation_split_percentage}%:]", 116 | cache_dir=model_args.cache_dir, 117 | ) 118 | else: 119 | data_files = {} 120 | dataset_args = {} 121 | if data_args.train_file is not None: 122 | data_files["train"] = data_args.train_file 123 | if data_args.validation_file is not None: 124 | data_files["validation"] = data_args.validation_file 125 | extension = ( 126 | data_args.train_file.split(".")[-1] 127 | if data_args.train_file is not None 128 | else data_args.validation_file.split(".")[-1] 129 | ) 130 | if extension == "txt": 131 | extension = "text" 132 | dataset_args["keep_linebreaks"] = data_args.keep_linebreaks 133 | 134 | raw_datasets = load_dataset( 135 | extension, data_files=data_files, cache_dir=model_args.cache_dir, **dataset_args 136 | ) 137 | 138 | # validation set이 존재하지 않으면, data_args의 data_split_percentage에 따라 데이터를 분리합니다. 139 | if "validation" not in raw_datasets.keys(): 140 | raw_datasets["validation"] = load_dataset( 141 | extension, 142 | data_files=data_files, 143 | split=f"train[:{data_args.validation_split_percentage}%]", 144 | cache_dir=model_args.cache_dir, 145 | **dataset_args, 146 | ) 147 | raw_datasets["train"] = load_dataset( 148 | extension, 149 | data_files=data_files, 150 | split=f"train[{data_args.validation_split_percentage}%:]", 151 | cache_dir=model_args.cache_dir, 152 | **dataset_args, 153 | ) 154 | 155 | 156 | # 기학습 가중치와 토크나이저 불러오기 157 | config_kwargs = { 158 | "cache_dir": model_args.cache_dir, 159 | "revision": model_args.model_revision, 160 | "use_auth_token": True if model_args.use_auth_token else None, 161 | } 162 | if model_args.config_name: 163 | config = AutoConfig.from_pretrained(model_args.config_name, **config_kwargs) 164 | elif model_args.model_name_or_path: 165 | config = AutoConfig.from_pretrained(model_args.model_name_or_path, **config_kwargs) 166 | else: 167 | config = CONFIG_MAPPING[model_args.model_type]() 168 | logger.warning("You are instantiating a new config instance from scratch.") 169 | if model_args.config_overrides is not None: 170 | logger.info(f"Overriding config: {model_args.config_overrides}") 171 | config.update_from_string(model_args.config_overrides) 172 | logger.info(f"New config: {config}") 173 | 174 | tokenizer_kwargs = { 175 | "cache_dir": model_args.cache_dir, 176 | "use_fast": model_args.use_fast_tokenizer, 177 | "revision": model_args.model_revision, 178 | "use_auth_token": True if model_args.use_auth_token else None, 179 | "bos_token": "", 180 | "eos_token": "", 181 | "unk_token": "", 182 | "pad_token": "", 183 | "mask_token": "", 184 | } 185 | tokenizer = AutoTokenizer.from_pretrained(model_args.model_name_or_path, **tokenizer_kwargs) 186 | 187 | model = AutoModelForCausalLM.from_pretrained( 188 | model_args.model_name_or_path, 189 | revision=model_args.model_revision, 190 | pad_token_id=tokenizer.eos_token_id, 191 | cache_dir=model_args.cache_dir, 192 | use_auth_token=True if model_args.use_auth_token else None, 193 | ) 194 | 
model.resize_token_embeddings(len(tokenizer)) 195 | 196 | # 불러온 데이터를 토크나이징 합니다. 197 | if training_args.do_train: 198 | column_names = raw_datasets["train"].column_names 199 | else: 200 | column_names = raw_datasets["validation"].column_names 201 | text_column_name = "text" if "text" in column_names else column_names[0] 202 | 203 | tokenize_args = {"text_column_name": text_column_name, "tokenizer": tokenizer} 204 | 205 | with training_args.main_process_first(desc="dataset map tokenization"): 206 | tokenized_datasets = raw_datasets.map( 207 | send_along(tokenize_function, sent_along=tokenize_args), 208 | batched=True, 209 | num_proc=data_args.preprocessing_num_workers, 210 | remove_columns=column_names, 211 | load_from_cache_file=not data_args.overwrite_cache, 212 | desc="Running tokenizer on dataset", 213 | ) 214 | 215 | 216 | # 데이터를 block_size 단위로 병합합니다. 217 | if data_args.block_size is None: 218 | block_size = tokenizer.model_max_length 219 | if block_size > 1024: 220 | logger.warning( 221 | f"The tokenizer picked seems to have a very large `model_max_length` ({tokenizer.model_max_length}). " 222 | "Picking 1024 instead. You can change that default value by passing --block_size xxx." 223 | ) 224 | block_size = 1024 225 | else: 226 | if data_args.block_size > tokenizer.model_max_length: 227 | logger.warning( 228 | f"The block_size passed ({data_args.block_size}) is larger than the maximum length for the model" 229 | f"({tokenizer.model_max_length}). Using block_size={tokenizer.model_max_length}." 230 | ) 231 | block_size = min(data_args.block_size, tokenizer.model_max_length) 232 | 233 | 234 | with training_args.main_process_first(desc="grouping texts together"): 235 | lm_datasets = tokenized_datasets.map( 236 | send_along(group_texts, sent_along=block_size), 237 | batched=True, 238 | num_proc=data_args.preprocessing_num_workers, 239 | load_from_cache_file=not data_args.overwrite_cache, 240 | desc=f"Grouping texts in chunks of {block_size}", 241 | ) 242 | 243 | if training_args.do_train: 244 | if "train" not in tokenized_datasets: 245 | raise ValueError("--do_train requires a train dataset") 246 | train_dataset = lm_datasets["train"] 247 | if data_args.max_train_samples is not None: 248 | train_dataset = train_dataset.select(range(data_args.max_train_samples)) 249 | 250 | if training_args.do_eval: 251 | if "validation" not in tokenized_datasets: 252 | raise ValueError("--do_eval requires a validation dataset") 253 | eval_dataset = lm_datasets["validation"] 254 | if data_args.max_eval_samples is not None: 255 | eval_dataset = eval_dataset.select(range(data_args.max_eval_samples)) 256 | 257 | 258 | # Trainer 초기화 259 | trainer = Trainer( 260 | model=model, 261 | args=training_args, 262 | train_dataset=train_dataset if training_args.do_train else None, 263 | eval_dataset=eval_dataset if training_args.do_eval else None, 264 | tokenizer=tokenizer, 265 | data_collator=default_data_collator, 266 | ) 267 | 268 | # wandb 설정 269 | if training_args.report_to == "wandb": 270 | wandb.init(project="kogpt_trinity_poem", 271 | name="kogpt_trinity_finetuning", 272 | tags=["baseline", "finetune"], 273 | group="kogpt_trinity_poem") 274 | 275 | # 학습 코드 276 | if training_args.do_train: 277 | checkpoint = None 278 | if training_args.resume_from_checkpoint is not None: 279 | checkpoint = training_args.resume_from_checkpoint 280 | elif last_checkpoint is not None: 281 | checkpoint = last_checkpoint 282 | train_result = trainer.train(resume_from_checkpoint=checkpoint) 283 | trainer.save_model() 284 | 285 | metrics 
= train_result.metrics 286 | 287 | max_train_samples = ( 288 | data_args.max_train_samples if data_args.max_train_samples is not None else len(train_dataset) 289 | ) 290 | metrics["train_samples"] = min(max_train_samples, len(train_dataset)) 291 | 292 | trainer.log_metrics("train", metrics) 293 | trainer.save_metrics("train", metrics) 294 | trainer.save_state() 295 | 296 | # Evaluation 297 | if training_args.do_eval: 298 | logger.info("*** Evaluate ***") 299 | 300 | metrics = trainer.evaluate() 301 | 302 | max_eval_samples = ( 303 | data_args.max_eval_samples if data_args.max_eval_samples is not None else len(eval_dataset) 304 | ) 305 | metrics["eval_samples"] = min(max_eval_samples, len(eval_dataset)) 306 | try: 307 | perplexity = math.exp(metrics["eval_loss"]) 308 | except OverflowError: 309 | perplexity = float("inf") 310 | metrics["perplexity"] = perplexity 311 | 312 | trainer.log_metrics("eval", metrics) 313 | trainer.save_metrics("eval", metrics) 314 | 315 | kwargs = {"finetuned_from": model_args.model_name_or_path, "tasks": "text-generation"} 316 | if data_args.dataset_name is not None: 317 | kwargs["dataset_tags"] = data_args.dataset_name 318 | if data_args.dataset_config_name is not None: 319 | kwargs["dataset_args"] = data_args.dataset_config_name 320 | kwargs["dataset"] = f"{data_args.dataset_name} {data_args.dataset_config_name}" 321 | else: 322 | kwargs["dataset"] = data_args.dataset_name 323 | 324 | if training_args.push_to_hub: 325 | trainer.push_to_hub(**kwargs) 326 | else: 327 | trainer.create_model_card(**kwargs) 328 | 329 | 330 | if __name__ == "__main__": 331 | main() 332 | -------------------------------------------------------------------------------- /model/poem_model/gpt2_trinity/utils.py: -------------------------------------------------------------------------------- 1 | import transformers 2 | 3 | 4 | from itertools import chain 5 | from transformers.testing_utils import CaptureLogger 6 | 7 | 8 | def send_along(func, sent_along): 9 | def inner(*args, **kwargs): 10 | return func(sent_along, *args, **kwargs) 11 | 12 | return inner 13 | 14 | 15 | 16 | # since this will be pickled to avoid _LazyModule error in Hasher force logger loading before tokenize_function 17 | tok_logger = transformers.utils.logging.get_logger("transformers.tokenization_utils_base") 18 | 19 | def tokenize_function(tokenize_args, examples): 20 | tokenizer = tokenize_args['tokenizer'] 21 | text_column_name = tokenize_args['text_column_name'] 22 | with CaptureLogger(tok_logger) as cl: 23 | examples[text_column_name] = list(map(lambda x: str(x), examples[text_column_name])) 24 | output = tokenizer(examples[text_column_name]) 25 | # clm input could be much much longer than block_size 26 | if "Token indices sequence length is longer than the" in cl.out: 27 | tok_logger.warning( 28 | "^^^^^^^^^^^^^^^^ Please ignore the warning above - this long input will be chunked into smaller bits before being passed to the model." 29 | ) 30 | return output 31 | 32 | 33 | # Main data processing function that will concatenate all texts from our dataset and generate chunks of block_size. 34 | def group_texts(block_size, examples): 35 | # Concatenate all texts. 36 | concatenated_examples = {k: list(chain(*examples[k])) for k in examples.keys()} 37 | total_length = len(concatenated_examples[list(examples.keys())[0]]) 38 | # We drop the small remainder, we could add padding if the model supported it instead of this drop, you can 39 | # customize this part to your needs. 
40 | if total_length >= block_size: 41 | total_length = (total_length // block_size) * block_size 42 | # Split by chunks of max_len. 43 | result = { 44 | k: [t[i : i + block_size] for i in range(0, total_length, block_size)] 45 | for k, t in concatenated_examples.items() 46 | } 47 | result["labels"] = result["input_ids"].copy() 48 | return result 49 | 50 | -------------------------------------------------------------------------------- /model/poem_model/utils.py: -------------------------------------------------------------------------------- 1 | ## 2 | -------------------------------------------------------------------------------- /model/vit_gpt2_train.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | 3 | from caption_model.vit_gpt2.run_train import train 4 | 5 | 6 | if __name__ == "__main__": 7 | parser = argparse.ArgumentParser() 8 | # data_arg 9 | parser.add_argument( 10 | "--output_dir", type=str, default="./model/caption_model/vit_gpt2" 11 | ) 12 | parser.add_argument("--data_dir", type=str, default="./data/caption_data") 13 | parser.add_argument( 14 | "--ms_coco_kor_file_path", 15 | type=str, 16 | default="./data/caption_data/MSCOCO_train_val_Korean.json", 17 | ) 18 | parser.add_argument( 19 | "--encoder_model_name_or_path", 20 | type=str, 21 | default="google/vit-base-patch16-224-in21k", 22 | ) 23 | parser.add_argument( 24 | "--decoder_model_name_or_path", type=str, default="skt/kogpt2-base-v2" 25 | ) 26 | 27 | # train_arg 28 | parser.add_argument("--num_labels", type=int, default=1) 29 | parser.add_argument("--seed", type=int, default=42) 30 | parser.add_argument("--num_train_epochs", type=int, default=5) 31 | parser.add_argument("--batch_size", type=int, default=64) 32 | parser.add_argument("--learning_rate", type=float, default=5e-5) 33 | parser.add_argument("--gradient_accumulation_steps", type=int, default=1) 34 | parser.add_argument("--weight_decay", type=float, default=0.01) 35 | 36 | # eval_arg 37 | parser.add_argument("--evaluation_strategy", type=str, default="steps") 38 | parser.add_argument("--logging_steps", type=int, default=500) 39 | parser.add_argument("--save_steps", type=int, default=500) 40 | parser.add_argument("--eval_steps", type=int, default=500) 41 | parser.add_argument("--save_total_limit", type=int, default=2) 42 | 43 | args = parser.parse_args() 44 | train(args) 45 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | aiohttp==3.8.1 2 | aiosignal==1.2.0 3 | async-timeout==4.0.2 4 | attrs==21.2.0 5 | Automat==20.2.0 6 | backcall @ file:///home/conda/feedstock_root/build_artifacts/backcall_1592338393461/work 7 | backports.functools-lru-cache @ file:///home/conda/feedstock_root/build_artifacts/backports.functools_lru_cache_1618230623929/work 8 | certifi==2021.10.8 9 | cffi==1.15.0 10 | charset-normalizer==2.0.9 11 | click==8.0.3 12 | constantly==15.1.0 13 | coverage==6.2 14 | coveralls==3.3.1 15 | cryptography==36.0.1 16 | cssselect==1.1.0 17 | datasets==1.17.0 18 | debugpy @ file:///tmp/build/80754af9/debugpy_1637091799509/work 19 | decorator @ file:///home/conda/feedstock_root/build_artifacts/decorator_1631346842025/work 20 | dill==0.3.4 21 | docopt==0.6.2 22 | entrypoints @ file:///home/conda/feedstock_root/build_artifacts/entrypoints_1605121927639/work/dist/entrypoints-0.3-py2.py3-none-any.whl 23 | filelock==3.4.0 24 | Flask==2.0.2 25 | frozenlist==1.2.0 26 | 
fsspec==2021.11.1 27 | h2==3.2.0 28 | hanja==0.13.3 29 | hpack==3.0.0 30 | huggingface-hub==0.2.1 31 | hyperframe==5.2.0 32 | hyperlink==21.0.0 33 | idna==3.3 34 | incremental==21.3.0 35 | iniconfig==1.1.1 36 | ipykernel @ file:///home/conda/feedstock_root/build_artifacts/ipykernel_1638555504864/work/dist/ipykernel-6.6.0-py3-none-any.whl 37 | ipython @ file:///home/conda/feedstock_root/build_artifacts/ipython_1638470227058/work 38 | itemadapter==0.4.0 39 | itemloaders==1.0.4 40 | itsdangerous==2.0.1 41 | jedi @ file:///home/conda/feedstock_root/build_artifacts/jedi_1637175083648/work 42 | Jinja2==3.0.3 43 | jmespath==0.10.0 44 | joblib==1.1.0 45 | jupyter-client @ file:///home/conda/feedstock_root/build_artifacts/jupyter_client_1637611911738/work 46 | jupyter-core @ file:///home/conda/feedstock_root/build_artifacts/jupyter_core_1636814260563/work 47 | lxml==4.7.1 48 | MarkupSafe==2.0.1 49 | matplotlib-inline @ file:///home/conda/feedstock_root/build_artifacts/matplotlib-inline_1631080358261/work 50 | multidict==5.2.0 51 | multiprocess==0.70.12.2 52 | nest-asyncio @ file:///home/conda/feedstock_root/build_artifacts/nest-asyncio_1638419302549/work 53 | nltk==3.6.6 54 | numpy==1.21.5 55 | packaging==21.3 56 | pandas==1.3.5 57 | parsel==1.6.0 58 | parso @ file:///home/conda/feedstock_root/build_artifacts/parso_1638334955874/work 59 | pexpect @ file:///home/conda/feedstock_root/build_artifacts/pexpect_1602535608087/work 60 | pickleshare @ file:///home/conda/feedstock_root/build_artifacts/pickleshare_1602536217715/work 61 | Pillow==8.4.0 62 | pluggy==1.0.0 63 | priority==1.3.0 64 | prompt-toolkit @ file:///home/conda/feedstock_root/build_artifacts/prompt-toolkit_1639065841292/work 65 | Protego==0.1.16 66 | ptyprocess @ file:///home/conda/feedstock_root/build_artifacts/ptyprocess_1609419310487/work/dist/ptyprocess-0.7.0-py2.py3-none-any.whl 67 | py==1.11.0 68 | pyarrow==6.0.1 69 | pyasn1==0.4.8 70 | pyasn1-modules==0.2.8 71 | pycparser==2.21 72 | PyDispatcher==2.0.5 73 | Pygments @ file:///home/conda/feedstock_root/build_artifacts/pygments_1629119114968/work 74 | pyOpenSSL==21.0.0 75 | pyparsing==3.0.6 76 | pysqlite3==0.4.6 77 | pytest==6.2.5 78 | pytest-cov==3.0.0 79 | python-dateutil @ file:///home/conda/feedstock_root/build_artifacts/python-dateutil_1626286286081/work 80 | pytz==2021.3 81 | PyYAML==5.1.2 82 | pyzmq==19.0.2 83 | queuelib==1.6.2 84 | regex==2021.11.10 85 | requests==2.26.0 86 | sacremoses==0.0.46 87 | scikit-learn==1.0.1 88 | scipy==1.7.3 89 | Scrapy==2.5.1 90 | service-identity==21.1.0 91 | six @ file:///home/conda/feedstock_root/build_artifacts/six_1620240208055/work 92 | threadpoolctl==3.0.0 93 | tokenizers==0.10.3 94 | toml==0.10.2 95 | tomli==2.0.0 96 | torch==1.10.1 97 | tornado @ file:///home/conda/feedstock_root/build_artifacts/tornado_1610094708661/work 98 | tqdm==4.62.3 99 | traitlets @ file:///home/conda/feedstock_root/build_artifacts/traitlets_1635260543454/work 100 | transformers==4.15.0 101 | Twisted==21.7.0 102 | typing_extensions==4.0.1 103 | urllib3==1.26.7 104 | w3lib==1.22.0 105 | wcwidth @ file:///home/conda/feedstock_root/build_artifacts/wcwidth_1600965781394/work 106 | Werkzeug==2.0.2 107 | xxhash==2.0.2 108 | yarl==1.7.2 109 | zope.interface==5.4.0 110 | -------------------------------------------------------------------------------- /show_attend_and_tell/README.md: -------------------------------------------------------------------------------- This code is based on
[**a-PyTorch-Tutorial-to-Image-Captioning**](https://github.com/sgrvinod/a-PyTorch-Tutorial-to-Image-Captioning).

# Installation

```
pip install -r requirements.txt
```

# Dataset

The image data come from the MS COCO '14 dataset. Download the [Training (13GB)](http://images.cocodataset.org/zips/train2014.zip) and [Validation (6GB)](http://images.cocodataset.org/zips/val2014.zip) images and store them in the caption_data folder.

For the caption data we use caption_data/dataset_coco_kor.json, which adapts the KETI R&D [Korean image description dataset](https://aihub.or.kr/opendata/keti-data/recognition-visual/KETI-01-003) from AI Hub to [Andrej Karpathy's training, validation, and test splits](http://cs.stanford.edu/people/karpathy/deepimagesent/caption_datasets.zip).
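For orientation, the data pipeline described below only relies on a few fields of this JSON. A single entry looks roughly like this (a minimal sketch; the field values are illustrative, not taken from the real file):

```python
# One entry of the Karpathy-style dataset_coco_kor.json (illustrative values).
entry = {
    "filepath": "train2014",                        # sub-folder under the image folder
    "filename": "COCO_train2014_000000000009.jpg",  # image file name
    "split": "train",                               # one of: train / restval / val / test
    "sentences": [
        {"tokens": ["식탁", "위에", "음식이", "있다"]},  # one tokenized caption
        # ... typically five captions per image ...
    ],
}
```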
## Inputs to model

The model requires three inputs.

### Images

Since we use a pretrained encoder, images need to be processed the way the encoder expects. The pretrained ImageNet models are provided by PyTorch's `torchvision` module.

The required preprocessing is as follows:
- Scale pixel values to [0, 1]
- Normalize the image by the mean and standard deviation of the ImageNet images' RGB channels
```python
mean = [0.485, 0.456, 0.406]
std = [0.229, 0.224, 0.225]
```
- Resize the image to 256x256
- Because PyTorch follows the NCHW convention, the channel dimension (C) must come before the size dimensions

Therefore, **input images must be `Float` tensors of shape `N, 3, 256, 256`**, normalized by the mean and standard deviation above, where `N` is the batch size (see the sketch below).
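A minimal sketch of this preprocessing with `torchvision` (the image path is a placeholder; the scripts in this folder implement the same steps with scipy/NumPy):

```python
import torch
import torchvision.transforms as transforms
from PIL import Image

# ToTensor scales pixels to [0, 1] and yields a C x H x W tensor (NCHW once batched).
preprocess = transforms.Compose([
    transforms.Resize((256, 256)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406],
                         std=[0.229, 0.224, 0.225]),
])

img = Image.open("path/to/image.jpg").convert("RGB")
batch = preprocess(img).unsqueeze(0)  # Float tensor of shape (1, 3, 256, 256)
```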
### Captions

Captions serve both as the targets of the decoder and as the inputs used to generate the next word.

## Data pipeline

See the `create_input_files()` function in [`utils.py`](https://github.com/boostcampaitech2/final-project-level3-nlp-08/tree/dev/merge/show_attend_and_tell/utils.py).

It reads the data and saves the following files:
- An **HDF5 file containing an `I, 3, 256, 256` tensor of images for each split**, where `I` is the number of images in the split.
- A **JSON file containing `N_c` * `I` encoded captions**, where `N_c` is the number of captions per image.
- A **JSON file containing `N_c` * `I` caption lengths**, where the `i`-th value is the length of the `i`-th caption.
- A **JSON file containing the `word_map`**.

The `CaptionDataset` class can be found in [`datasets.py`](https://github.com/boostcampaitech2/final-project-level3-nlp-08/tree/dev/merge/show_attend_and_tell/datasets.py).

# Training

Before you start, you need to build the training data. Point [`create_input_files.py`](https://github.com/boostcampaitech2/final-project-level3-nlp-08/tree/dev/merge/show_attend_and_tell/create_input_files.py) at the Karpathy JSON file and the `train2014` and `val2014` image folders, then run it.

To train the model from scratch, run

`python train.py`

# Inference

To **caption an image** from the command line, run a command like this –

`python caption.py --img='path/to/image.jpeg' --model='path/to/BEST_checkpoint_coco_5_cap_per_img_5_min_word_freq.pth.tar' --word_map='path/to/WORDMAP_coco_5_cap_per_img_5_min_word_freq.json' --beam_size=5`

[`eval.py`](https://github.com/boostcampaitech2/final-project-level3-nlp-08/tree/dev/merge/show_attend_and_tell/eval.py) computes the BLEU-4 score on the validation set.

To evaluate the model, run

`python eval.py`

# Evaluation Score

We use the BLEU-4 score as the evaluation metric. BLEU measures how much of a generated sentence is contained in the reference sentences; BLEU-4 uses 4-grams. The measured performance is as follows (a toy `corpus_bleu` sketch is shown below the table).

Beam Size | Validation BLEU-4 | Test BLEU-4 |
:---: | :---: | :---: |
1 | 16.98 | 10.17 |
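The nesting that `eval.py` passes to NLTK's `corpus_bleu` — one hypothesis per image, several references each — can be reproduced in isolation. A toy sketch (the tokens are made up):

```python
from nltk.translate.bleu_score import corpus_bleu

# One entry per image: a list of reference token lists plus one hypothesis token list.
references = [[["강아지가", "공원에서", "뛰어", "놀고", "있다"],
               ["개가", "잔디밭을", "달리고", "있다"]]]
hypotheses = [["강아지가", "잔디밭에서", "뛰어", "놀고", "있다"]]

# The default weights are uniform over 1- to 4-grams, i.e. BLEU-4.
print(corpus_bleu(references, hypotheses))
```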
### Examples

_(Example images omitted.)_

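To caption an image from Python rather than the command line, the helpers in `caption.py` below can be driven directly. A sketch — the checkpoint and word-map paths are placeholders:

```python
import json

import torch

from caption import caption_image_beam_search

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# The checkpoint stores the full encoder/decoder modules (see save_checkpoint in utils.py).
checkpoint = torch.load("path/to/BEST_checkpoint_coco_5_cap_per_img_5_min_word_freq.pth.tar",
                        map_location=str(device))
encoder = checkpoint["encoder"].to(device).eval()
decoder = checkpoint["decoder"].to(device).eval()

with open("path/to/WORDMAP_coco_5_cap_per_img_5_min_word_freq.json") as j:
    word_map = json.load(j)
rev_word_map = {v: k for k, v in word_map.items()}  # ix2word

seq, _ = caption_image_beam_search(encoder, decoder, "path/to/image.jpg", word_map, beam_size=5)
words = [rev_word_map[ind] for ind in seq
         if rev_word_map[ind] not in {"<start>", "<end>", "<pad>"}]
print(" ".join(words))
```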
-------------------------------------------------------------------------------- /show_attend_and_tell/caption.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | import numpy as np 4 | import json 5 | import torchvision.transforms as transforms 6 | import matplotlib.pyplot as plt 7 | import matplotlib.cm as cm 8 | import skimage.transform 9 | import argparse 10 | from scipy.misc import imread, imresize  # requires an old SciPy; imread/imresize were removed from newer releases 11 | from PIL import Image 12 | from matplotlib import font_manager, rc 13 | import time 14 | 15 | start = 0 16 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 17 | 18 | def caption_image_beam_search(encoder, decoder, image_path, word_map, beam_size=3): 19 | """ 20 | Reads an image and captions it with beam search. 21 | 22 | :param encoder: encoder model 23 | :param decoder: decoder model 24 | :param image_path: path to image 25 | :param word_map: word map 26 | :param beam_size: number of sequences to consider at each decode-step 27 | :return: caption, weights for visualization 28 | """ 29 | 30 | k = beam_size 31 | vocab_size = len(word_map) 32 | 33 | # Read image and process 34 | img = imread(image_path) 35 | #img = Image.open(image_path).convert("RGB") 36 | #img = np.array(img) 37 | #print(type(img)) 38 | if len(img.shape) == 2: 39 | img = img[:, :, np.newaxis] 40 | img = np.concatenate([img, img, img], axis=2) 41 | img = imresize(img, (256, 256)) 42 | img = img.transpose(2, 0, 1) 43 | img = img / 255. 44 | img = torch.FloatTensor(img).to(device) 45 | normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], 46 | std=[0.229, 0.224, 0.225]) 47 | transform = transforms.Compose([normalize]) 48 | image = transform(img) # (3, 256, 256) 49 | 50 | # Encode 51 | image = image.unsqueeze(0) # (1, 3, 256, 256) 52 | encoder_out = encoder(image) # (1, enc_image_size, enc_image_size, encoder_dim) 53 | enc_image_size = encoder_out.size(1) 54 | encoder_dim = encoder_out.size(3) 55 | 56 | # Flatten encoding 57 | encoder_out = encoder_out.view(1, -1, encoder_dim) # (1, num_pixels, encoder_dim) 58 | num_pixels = encoder_out.size(1) 59 | 60 | # We'll treat the problem as having a batch size of k 61 | encoder_out = encoder_out.expand(k, num_pixels, encoder_dim) # (k, num_pixels, encoder_dim) 62 | 63 | # Tensor to store top k previous words at each step; now they're just <start> 64 | k_prev_words = torch.LongTensor([[word_map['<start>']]] * k).to(device) # (k, 1) 65 | 66 | # Tensor to store top k sequences; now they're just <start> 67 | seqs = k_prev_words # (k, 1) 68 | 69 | # Tensor to store top k sequences' scores; now they're just 0 70 | top_k_scores = torch.zeros(k, 1).to(device) # (k, 1) 71 | 72 | # Tensor to store top k sequences' alphas; now they're just 1s 73 | seqs_alpha = torch.ones(k, 1, enc_image_size, enc_image_size).to(device) # (k, 1, enc_image_size, enc_image_size) 74 | 75 | # Lists to store completed sequences, their alphas and scores 76 | complete_seqs = list() 77 | complete_seqs_alpha = list() 78 | complete_seqs_scores = list() 79 | 80 | # Start decoding 81 | step = 1 82 | h, c = decoder.init_hidden_state(encoder_out) 83 | 84 | # s is a number less than or equal to k, because sequences are removed from this process once they hit <end> 85 | while True: 86 | 87 | embeddings = decoder.embedding(k_prev_words).squeeze(1) # (s, embed_dim) 88 | 89 | awe, alpha = decoder.attention(encoder_out, h) # (s, encoder_dim), (s, num_pixels) 90 | 91 | alpha = alpha.view(-1, enc_image_size, enc_image_size) # (s,
enc_image_size, enc_image_size) 92 | 93 | gate = decoder.sigmoid(decoder.f_beta(h)) # gating scalar, (s, encoder_dim) 94 | awe = gate * awe 95 | 96 | h, c = decoder.decode_step(torch.cat([embeddings, awe], dim=1), (h, c)) # (s, decoder_dim) 97 | 98 | scores = decoder.fc(h) # (s, vocab_size) 99 | scores = F.log_softmax(scores, dim=1) 100 | 101 | # Add 102 | scores = top_k_scores.expand_as(scores) + scores # (s, vocab_size) 103 | 104 | # For the first step, all k points will have the same scores (since same k previous words, h, c) 105 | if step == 1: 106 | top_k_scores, top_k_words = scores[0].topk(k, 0, True, True) # (s) 107 | else: 108 | # Unroll and find top scores, and their unrolled indices 109 | top_k_scores, top_k_words = scores.view(-1).topk(k, 0, True, True) # (s) 110 | 111 | # Convert unrolled indices to actual indices of scores 112 | prev_word_inds = top_k_words // vocab_size # (s); integer division keeps these usable as indices 113 | next_word_inds = top_k_words % vocab_size # (s) 114 | 115 | # Add new words to sequences, alphas 116 | seqs = torch.cat([seqs[prev_word_inds], next_word_inds.unsqueeze(1)], dim=1) # (s, step+1) 117 | seqs_alpha = torch.cat([seqs_alpha[prev_word_inds], alpha[prev_word_inds].unsqueeze(1)], 118 | dim=1) # (s, step+1, enc_image_size, enc_image_size) 119 | # Which sequences are incomplete (didn't reach <end>)? 120 | incomplete_inds = [ind for ind, next_word in enumerate(next_word_inds) if 121 | next_word != word_map['<end>']] 122 | complete_inds = list(set(range(len(next_word_inds))) - set(incomplete_inds)) 123 | #print(incomplete_inds) 124 | # Set aside complete sequences 125 | if len(complete_inds) > 0: 126 | complete_seqs.extend(seqs[complete_inds].tolist()) 127 | complete_seqs_alpha.extend(seqs_alpha[complete_inds].tolist()) 128 | complete_seqs_scores.extend(top_k_scores[complete_inds]) 129 | k -= len(complete_inds) # reduce beam length accordingly 130 | 131 | # Proceed with incomplete sequences 132 | if k == 0: 133 | break 134 | seqs = seqs[incomplete_inds] 135 | seqs_alpha = seqs_alpha[incomplete_inds] 136 | h = h[prev_word_inds[incomplete_inds]] 137 | c = c[prev_word_inds[incomplete_inds]] 138 | encoder_out = encoder_out[prev_word_inds[incomplete_inds]] 139 | top_k_scores = top_k_scores[incomplete_inds].unsqueeze(1) 140 | k_prev_words = next_word_inds[incomplete_inds].unsqueeze(1) 141 | 142 | # Break if things have been going on too long 143 | if step > 50: 144 | complete_seqs.extend(seqs[incomplete_inds].tolist()) 145 | complete_seqs_alpha.extend(seqs_alpha[incomplete_inds].tolist()) 146 | complete_seqs_scores.extend(top_k_scores[incomplete_inds]) 147 | break 148 | step += 1 149 | 150 | #print(complete_seqs_scores) 151 | 152 | i = complete_seqs_scores.index(max(complete_seqs_scores)) 153 | seq = complete_seqs[i] 154 | alphas = complete_seqs_alpha[i] 155 | return seq, alphas 156 | 157 | 158 | def visualize_att(image_path, seq, alphas, rev_word_map, smooth=True): 159 | """ 160 | Visualizes caption with weights at every word. 161 | 162 | Adapted from paper authors' repo: https://github.com/kelvinxu/arctic-captions/blob/master/alpha_visualization.ipynb 163 | 164 | :param image_path: path to image that has been captioned 165 | :param seq: caption 166 | :param alphas: weights 167 | :param rev_word_map: reverse word mapping, i.e. ix2word 168 | :param smooth: smooth weights?
169 | """ 170 | rc('font', family='NanumBarunGothic') 171 | image = Image.open(image_path) 172 | image = image.resize([14 * 24, 14 * 24], Image.LANCZOS) 173 | 174 | words = [rev_word_map[ind] for ind in seq] 175 | print(words) 176 | 177 | for t in range(len(words)): 178 | if t > 50: 179 | break 180 | plt.subplot(int(np.ceil(len(words) / 5.)), 5, t + 1) 181 | 182 | plt.text(0, 1, '%s' % (words[t]), color='black', backgroundcolor='white', fontsize=12) 183 | plt.imshow(image) 184 | current_alpha = alphas[t, :] 185 | if smooth: 186 | alpha = skimage.transform.pyramid_expand(current_alpha.numpy(), upscale=24, sigma=8, multichannel=False) 187 | else: 188 | alpha = skimage.transform.resize(current_alpha.numpy(), [14 * 24, 14 * 24], multichannel=False) 189 | if t == 0: 190 | plt.imshow(alpha, alpha=0) 191 | else: 192 | plt.imshow(alpha, alpha=0.8) 193 | plt.set_cmap(cm.Greys_r) 194 | plt.axis('off') 195 | plt.savefig("test_attention.png") 196 | 197 | 198 | if __name__ == '__main__': 199 | parser = argparse.ArgumentParser(description='Show, Attend, and Tell - Tutorial - Generate Caption') 200 | 201 | parser.add_argument('--img', '-i', help='path to image') 202 | parser.add_argument('--model', '-m', help='path to model') 203 | parser.add_argument('--word_map', '-wm', help='path to word map JSON') 204 | parser.add_argument('--beam_size', '-b', default=5, type=int, help='beam size for beam search') 205 | parser.add_argument('--dont_smooth', dest='smooth', action='store_false', help='do not smooth alpha overlay') 206 | 207 | args = parser.parse_args() 208 | 209 | # Load model 210 | checkpoint = torch.load(args.model, map_location=str(device)) 211 | decoder = checkpoint['decoder'] 212 | decoder = decoder.to(device) 213 | pytorch_total_params1 = sum(p.numel() for p in decoder.parameters()) 214 | decoder.eval() 215 | encoder = checkpoint['encoder'] 216 | encoder = encoder.to(device) 217 | pytorch_total_params2 = sum(p.numel() for p in encoder.parameters()) 218 | encoder.eval() 219 | 220 | print(pytorch_total_params1 + pytorch_total_params2) 221 | 222 | # Load word map (word2ix) 223 | with open(args.word_map, 'r') as j: 224 | word_map = json.load(j) 225 | rev_word_map = {v: k for k, v in word_map.items()} # ix2word 226 | 227 | # Encode, decode with attention and beam search 228 | start = time.time() 229 | seq, alphas = caption_image_beam_search(encoder, decoder, args.img, word_map, args.beam_size) 230 | alphas = torch.FloatTensor(alphas) 231 | 232 | 233 | # Visualize caption and attention of best sequence 234 | visualize_att(args.img, seq, alphas, rev_word_map, args.smooth) 235 | print('done') 236 | -------------------------------------------------------------------------------- /show_attend_and_tell/create_input_files.py: -------------------------------------------------------------------------------- 1 | from utils import create_input_files 2 | 3 | if __name__ == '__main__': 4 | # Create input files (along with word map) 5 | create_input_files(dataset='coco', 6 | karpathy_json_path='../data/caption_data/dataset_coco_kor.json', 7 | image_folder='../data/caption_data', 8 | captions_per_image=5, 9 | min_word_freq=5, 10 | output_folder='outputs/', 11 | max_len=50) 12 | -------------------------------------------------------------------------------- /show_attend_and_tell/datasets.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.utils.data import Dataset 3 | import h5py 4 | import json 5 | import os 6 | 7 | 8 | class CaptionDataset(Dataset):
9 | """ 10 | A PyTorch Dataset class to be used in a PyTorch DataLoader to create batches. 11 | """ 12 | 13 | def __init__(self, data_folder, data_name, split, transform=None): 14 | """ 15 | :param data_folder: folder where data files are stored 16 | :param data_name: base name of processed datasets 17 | :param split: split, one of 'TRAIN', 'VAL', or 'TEST' 18 | :param transform: image transform pipeline 19 | """ 20 | self.split = split 21 | assert self.split in {'TRAIN', 'VAL', 'TEST'} 22 | 23 | # Open hdf5 file where images are stored 24 | self.h = h5py.File(os.path.join(data_folder, self.split + '_IMAGES_' + data_name + '.hdf5'), 'r') 25 | self.imgs = self.h['images'] 26 | 27 | # Captions per image 28 | self.cpi = self.h.attrs['captions_per_image'] 29 | 30 | # Load encoded captions (completely into memory) 31 | with open(os.path.join(data_folder, self.split + '_CAPTIONS_' + data_name + '.json'), 'r') as j: 32 | self.captions = json.load(j) 33 | 34 | # Load caption lengths (completely into memory) 35 | with open(os.path.join(data_folder, self.split + '_CAPLENS_' + data_name + '.json'), 'r') as j: 36 | self.caplens = json.load(j) 37 | 38 | # PyTorch transformation pipeline for the image (normalizing, etc.) 39 | self.transform = transform 40 | 41 | # Total number of datapoints 42 | self.dataset_size = len(self.captions) 43 | 44 | def __getitem__(self, i): 45 | # Remember, the Nth caption corresponds to the (N // captions_per_image)th image 46 | img = torch.FloatTensor(self.imgs[i // self.cpi] / 255.) 47 | if self.transform is not None: 48 | img = self.transform(img) 49 | 50 | caption = torch.LongTensor(self.captions[i]) 51 | 52 | caplen = torch.LongTensor([self.caplens[i]]) 53 | 54 | if self.split == 'TRAIN': 55 | return img, caption, caplen 56 | else: 57 | # For validation or testing, also return all 'captions_per_image' captions to find BLEU-4 score 58 | all_captions = torch.LongTensor( 59 | self.captions[((i // self.cpi) * self.cpi):(((i // self.cpi) * self.cpi) + self.cpi)]) 60 | return img, caption, caplen, all_captions 61 | 62 | def __len__(self): 63 | return self.dataset_size 64 | -------------------------------------------------------------------------------- /show_attend_and_tell/eval.py: -------------------------------------------------------------------------------- 1 | import torch.backends.cudnn as cudnn 2 | import torch.optim 3 | import torch.utils.data 4 | import torchvision.transforms as transforms 5 | from datasets import * 6 | from utils import * 7 | from nltk.translate.bleu_score import corpus_bleu 8 | import torch.nn.functional as F 9 | from tqdm import tqdm 10 | import pandas as pd 11 | 12 | # Parameters 13 | data_folder = 'outputs' # folder with data files saved by create_input_files.py 14 | data_name = 'coco_5_cap_per_img_5_min_word_freq' # base name shared by data files 15 | checkpoint = 'BEST_checkpoint_coco_5_cap_per_img_5_min_word_freq.pth.tar' # model checkpoint 16 | word_map_file = 'outputs/WORDMAP_coco_5_cap_per_img_5_min_word_freq.json' # word map, ensure it's the same one the data was encoded with and the model was trained with 17 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") # sets device for model and PyTorch tensors 18 | cudnn.benchmark = True # set to true only if inputs to model are fixed size; otherwise there is a lot of computational overhead 19 | 20 | # Load model 21 | checkpoint = torch.load(checkpoint) 22 | decoder = checkpoint['decoder'] 23 | decoder = decoder.to(device) 24 | decoder.eval() 25 | encoder = checkpoint['encoder']
encoder = encoder.to(device) 27 | encoder.eval() 28 | 29 | # Load word map (word2ix) 30 | with open(word_map_file, 'r') as j: 31 | word_map = json.load(j) 32 | rev_word_map = {v: k for k, v in word_map.items()} 33 | vocab_size = len(word_map) 34 | 35 | # Normalization transform 36 | normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], 37 | std=[0.229, 0.224, 0.225]) 38 | 39 | 40 | def evaluate(beam_size): 41 | """ 42 | Evaluation 43 | 44 | :param beam_size: beam size at which to generate captions for evaluation 45 | :return: BLEU-4 score 46 | """ 47 | # DataLoader 48 | loader = torch.utils.data.DataLoader( 49 | CaptionDataset(data_folder, data_name, 'TEST', transform=transforms.Compose([normalize])), 50 | batch_size=1, shuffle=False, num_workers=1, pin_memory=True) 51 | 52 | # TODO: Batched Beam Search 53 | # Therefore, do not use a batch_size greater than 1 - IMPORTANT! 54 | 55 | # Lists to store references (true captions), and hypotheses (predictions) for each image 56 | # If for n images, we have n hypotheses, and references a, b, c... for each image, we need - 57 | # references = [[ref1a, ref1b, ref1c], [ref2a, ref2b], ...], hypotheses = [hyp1, hyp2, ...] 58 | references = list() 59 | hypotheses = list() 60 | 61 | ref = [] 62 | hyp = [] 63 | # For each image 64 | for i, (image, caps, caplens, allcaps) in enumerate( 65 | tqdm(loader, desc="EVALUATING AT BEAM SIZE " + str(beam_size))): 66 | 67 | k = beam_size 68 | 69 | # Move to GPU device, if available 70 | image = image.to(device) # (1, 3, 256, 256) 71 | 72 | # Encode 73 | encoder_out = encoder(image) # (1, enc_image_size, enc_image_size, encoder_dim) 74 | enc_image_size = encoder_out.size(1) 75 | encoder_dim = encoder_out.size(3) 76 | 77 | # Flatten encoding 78 | encoder_out = encoder_out.view(1, -1, encoder_dim) # (1, num_pixels, encoder_dim) 79 | num_pixels = encoder_out.size(1) 80 | 81 | # We'll treat the problem as having a batch size of k 82 | encoder_out = encoder_out.expand(k, num_pixels, encoder_dim) # (k, num_pixels, encoder_dim) 83 | 84 | # Tensor to store top k previous words at each step; now they're just <start> 85 | k_prev_words = torch.LongTensor([[word_map['<start>']]] * k).to(device) # (k, 1) 86 | 87 | # Tensor to store top k sequences; now they're just <start> 88 | seqs = k_prev_words # (k, 1) 89 | 90 | # Tensor to store top k sequences' scores; now they're just 0 91 | top_k_scores = torch.zeros(k, 1).to(device) # (k, 1) 92 | 93 | # Lists to store completed sequences and scores 94 | complete_seqs = list() 95 | complete_seqs_scores = list() 96 | 97 | # Start decoding 98 | step = 1 99 | h, c = decoder.init_hidden_state(encoder_out) 100 | 101 | # s is a number less than or equal to k, because sequences are removed from this process once they hit <end> 102 | while True: 103 | 104 | embeddings = decoder.embedding(k_prev_words).squeeze(1) # (s, embed_dim) 105 | 106 | awe, _ = decoder.attention(encoder_out, h) # (s, encoder_dim), (s, num_pixels) 107 | 108 | gate = decoder.sigmoid(decoder.f_beta(h)) # gating scalar, (s, encoder_dim) 109 | awe = gate * awe 110 | 111 | h, c = decoder.decode_step(torch.cat([embeddings, awe], dim=1), (h, c)) # (s, decoder_dim) 112 | 113 | scores = decoder.fc(h) # (s, vocab_size) 114 | scores = F.log_softmax(scores, dim=1) 115 | 116 | # Add 117 | scores = top_k_scores.expand_as(scores) + scores # (s, vocab_size) 118 | 119 | # For the first step, all k points will have the same scores (since same k previous words, h, c) 120 | if step == 1: 121 | top_k_scores, top_k_words = scores[0].topk(k, 0, True, True) # (s) 122 |
else: 123 | # Unroll and find top scores, and their unrolled indices 124 | top_k_scores, top_k_words = scores.view(-1).topk(k, 0, True, True) # (s) 125 | 126 | # Convert unrolled indices to actual indices of scores 127 | prev_word_inds = top_k_words // vocab_size # (s); integer division keeps these usable as indices 128 | next_word_inds = top_k_words % vocab_size # (s) 129 | 130 | # Add new words to sequences 131 | seqs = torch.cat([seqs[prev_word_inds], next_word_inds.unsqueeze(1)], dim=1) # (s, step+1) 132 | 133 | # Which sequences are incomplete (didn't reach <end>)? 134 | incomplete_inds = [ind for ind, next_word in enumerate(next_word_inds) if 135 | next_word != word_map['<end>']] 136 | complete_inds = list(set(range(len(next_word_inds))) - set(incomplete_inds)) 137 | 138 | # Set aside complete sequences 139 | if len(complete_inds) > 0: 140 | complete_seqs.extend(seqs[complete_inds].tolist()) 141 | complete_seqs_scores.extend(top_k_scores[complete_inds]) 142 | k -= len(complete_inds) # reduce beam length accordingly 143 | 144 | # Proceed with incomplete sequences 145 | if k == 0: 146 | break 147 | seqs = seqs[incomplete_inds] 148 | h = h[prev_word_inds[incomplete_inds]] 149 | c = c[prev_word_inds[incomplete_inds]] 150 | encoder_out = encoder_out[prev_word_inds[incomplete_inds]] 151 | top_k_scores = top_k_scores[incomplete_inds].unsqueeze(1) 152 | k_prev_words = next_word_inds[incomplete_inds].unsqueeze(1) 153 | 154 | # Break if things have been going on too long 155 | if step > 50: 156 | complete_seqs.extend(seqs[incomplete_inds].tolist()) 157 | complete_seqs_scores.extend(top_k_scores[incomplete_inds]) 158 | break 159 | step += 1 160 | 161 | i = complete_seqs_scores.index(max(complete_seqs_scores)) 162 | seq = complete_seqs[i] 163 | 164 | # References 165 | img_caps = allcaps[0].tolist() 166 | img_captions = list( 167 | map(lambda c: [w for w in c if w not in {word_map['<start>'], word_map['<end>'], word_map['<pad>']}], 168 | img_caps)) # remove <start>, <end> and pads 169 | references.append(img_captions) 170 | #print(references) 171 | # Hypotheses 172 | hypotheses.append([w for w in seq if w not in {word_map['<start>'], word_map['<end>'], word_map['<pad>']}]) 173 | 174 | assert len(references) == len(hypotheses) 175 | temp = [] 176 | for cap in img_caps: 177 | c = [rev_word_map[ind] for ind in cap if ind not in {word_map['<start>'], word_map['<end>'], word_map['<pad>']}] 178 | temp.append(' '.join(c)) 179 | ref.append(temp) 180 | words = [rev_word_map[ind] for ind in seq if ind not in {word_map['<start>'], word_map['<end>'], word_map['<pad>']}] 181 | hyp.append(' '.join(words) + ' hyp') 182 | 183 | # save evaluation result 184 | df = pd.DataFrame(ref, hyp) 185 | df.to_csv('caption eval.csv') 186 | # Calculate BLEU-4 scores 187 | bleu4 = corpus_bleu(references, hypotheses) 188 | 189 | return bleu4 190 | 191 | 192 | if __name__ == '__main__': 193 | beam_size = 1 194 | print("\nBLEU-4 score @ beam size of %d is %.4f."
% (beam_size, evaluate(beam_size))) 195 | -------------------------------------------------------------------------------- /show_attend_and_tell/model.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/boostcampaitech2/final-project-level3-nlp-08/cce85d868879e9ad23901125fb7ca0d7d129f450/show_attend_and_tell/model.py -------------------------------------------------------------------------------- /show_attend_and_tell/requirements.txt: -------------------------------------------------------------------------------- 1 | anyio==3.3.4 2 | argcomplete==1.12.3 3 | argon2-cffi==21.1.0 4 | attrs==21.2.0 5 | Babel==2.9.1 6 | backcall==0.2.0 7 | bleach==4.1.0 8 | cached-property==1.5.2 9 | certifi==2021.10.8 10 | cffi==1.15.0 11 | charset-normalizer==2.0.7 12 | click==8.0.3 13 | cycler==0.11.0 14 | debugpy==1.5.1 15 | decorator==5.1.0 16 | defusedxml==0.7.1 17 | entrypoints==0.3 18 | fonttools==4.28.2 19 | h5py==3.6.0 20 | idna==3.3 21 | imageio==2.11.1 22 | importlib-metadata==4.8.1 23 | importlib-resources==5.4.0 24 | install==1.3.4 25 | ipykernel==6.5.0 26 | ipython==7.29.0 27 | ipython-genutils==0.2.0 28 | jedi==0.18.0 29 | Jinja2==3.0.2 30 | joblib==1.1.0 31 | json5==0.9.6 32 | jsonschema==4.2.1 33 | jupyter-client==7.0.6 34 | jupyter-core==4.9.1 35 | jupyter-server==1.11.2 36 | jupyterlab==3.2.2 37 | jupyterlab-pygments==0.1.2 38 | jupyterlab-server==2.8.2 39 | kiwisolver==1.3.2 40 | MarkupSafe==2.0.1 41 | matplotlib==3.1.0 42 | matplotlib-inline==0.1.3 43 | mistune==0.8.4 44 | nbclassic==0.3.4 45 | nbclient==0.5.4 46 | nbconvert==6.2.0 47 | nbformat==5.1.3 48 | nest-asyncio==1.5.1 49 | networkx==2.6.3 50 | nltk==3.6.5 51 | notebook==6.4.5 52 | numpy==1.21.4 53 | opencv-python==4.5.4.60 54 | packaging==21.2 55 | pandas==1.3.4 56 | pandocfilters==1.5.0 57 | parso==0.8.2 58 | pexpect==4.8.0 59 | pickleshare==0.7.5 60 | Pillow==8.4.0 61 | praw==7.5.0 62 | prawcore==2.3.0 63 | prometheus-client==0.12.0 64 | prompt-toolkit==3.0.22 65 | psaw==0.1.0 66 | ptyprocess==0.7.0 67 | pycparser==2.21 68 | Pygments==2.10.0 69 | pyparsing==2.4.7 70 | pyrsistent==0.18.0 71 | python-dateutil==2.8.2 72 | pytz==2021.3 73 | PyWavelets==1.2.0 74 | pyzmq==22.3.0 75 | regex==2021.11.10 76 | requests==2.26.0 77 | scikit-image==0.15.0 78 | scipy==1.2.0 79 | Send2Trash==1.8.0 80 | setuptools-scm==6.3.2 81 | six==1.16.0 82 | sniffio==1.2.0 83 | terminado==0.12.1 84 | testpath==0.5.0 85 | tifffile==2021.11.2 86 | tomli==1.2.2 87 | torch==0.4.1.post2 88 | torchvision==0.2.1 89 | tornado==6.1 90 | tqdm==4.62.3 91 | traitlets==5.1.1 92 | typing-extensions==3.10.0.2 93 | update-checker==0.18.0 94 | urllib3==1.26.7 95 | wcwidth==0.2.5 96 | webencodings==0.5.1 97 | websocket-client==1.2.1 98 | zipp==3.6.0 99 | -------------------------------------------------------------------------------- /show_attend_and_tell/train.py: -------------------------------------------------------------------------------- 1 | import time 2 | import torch.backends.cudnn as cudnn 3 | import torch.optim 4 | import torch.utils.data 5 | import torchvision.transforms as transforms 6 | from torch import nn 7 | from torch.nn.utils.rnn import pack_padded_sequence 8 | from model import Encoder, DecoderWithAttention 9 | from datasets import * 10 | from utils import * 11 | from nltk.translate.bleu_score import corpus_bleu 12 | 13 | # Data parameters 14 | data_folder = 'outputs/' # folder with data files saved by create_input_files.py 15 | data_name = 'coco_5_cap_per_img_5_min_word_freq' # base
name shared by data files 16 | 17 | # Model parameters 18 | emb_dim = 512 # dimension of word embeddings 19 | attention_dim = 512 # dimension of attention linear layers 20 | decoder_dim = 512 # dimension of decoder RNN 21 | dropout = 0.5 22 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") # sets device for model and PyTorch tensors 23 | cudnn.benchmark = True # set to true only if inputs to model are fixed size; otherwise there is a lot of computational overhead 24 | 25 | # Training parameters 26 | start_epoch = 0 27 | epochs = 120 # number of epochs to train for (if early stopping is not triggered) 28 | epochs_since_improvement = 0 # keeps track of number of epochs since there's been an improvement in validation BLEU 29 | batch_size = 64 30 | workers = 1 # for data-loading; right now, only 1 works with h5py 31 | encoder_lr = 1e-4 # learning rate for encoder if fine-tuning 32 | decoder_lr = 4e-4 # learning rate for decoder 33 | grad_clip = 5. # clip gradients at an absolute value of 5. 34 | alpha_c = 1. # regularization parameter for 'doubly stochastic attention', as in the paper 35 | best_bleu4 = 0. # BLEU-4 score right now 36 | print_freq = 100 # print training/validation stats every __ batches 37 | fine_tune_encoder = True # fine-tune encoder? 38 | #checkpoint = '/opt/ml/a-PyTorch-Tutorial-to-Image-Captioning/eng_models/BEST_checkpoint_coco_5_cap_per_img_5_min_word_freq.pth.tar' # path to checkpoint, None if none 39 | checkpoint = None 40 | 41 | def main(): 42 | """ 43 | Training and validation. 44 | """ 45 | 46 | global best_bleu4, epochs_since_improvement, checkpoint, start_epoch, fine_tune_encoder, data_name, word_map 47 | 48 | # Read word map 49 | word_map_file = os.path.join(data_folder, 'WORDMAP_' + data_name + '.json') 50 | with open(word_map_file, 'r') as j: 51 | word_map = json.load(j) 52 | 53 | # Initialize / load checkpoint 54 | if checkpoint is None: 55 | decoder = DecoderWithAttention(attention_dim=attention_dim, 56 | embed_dim=emb_dim, 57 | decoder_dim=decoder_dim, 58 | vocab_size=len(word_map), 59 | dropout=dropout) 60 | decoder_optimizer = torch.optim.Adam(params=filter(lambda p: p.requires_grad, decoder.parameters()), 61 | lr=decoder_lr) 62 | encoder = Encoder() 63 | encoder.fine_tune(fine_tune_encoder) 64 | encoder_optimizer = torch.optim.Adam(params=filter(lambda p: p.requires_grad, encoder.parameters()), 65 | lr=encoder_lr) if fine_tune_encoder else None 66 | 67 | else: 68 | checkpoint = torch.load(checkpoint) 69 | start_epoch = checkpoint['epoch'] + 1 70 | epochs_since_improvement = checkpoint['epochs_since_improvement'] 71 | best_bleu4 = checkpoint['bleu-4'] 72 | decoder = checkpoint['decoder'] 73 | decoder_optimizer = checkpoint['decoder_optimizer'] 74 | encoder = checkpoint['encoder'] 75 | encoder_optimizer = checkpoint['encoder_optimizer'] 76 | if fine_tune_encoder is True and encoder_optimizer is None: 77 | encoder.fine_tune(fine_tune_encoder) 78 | encoder_optimizer = torch.optim.Adam(params=filter(lambda p: p.requires_grad, encoder.parameters()), 79 | lr=encoder_lr) 80 | 81 | # Move to GPU, if available 82 | decoder = decoder.to(device) 83 | encoder = encoder.to(device) 84 | 85 | # Loss function 86 | criterion = nn.CrossEntropyLoss().to(device) 87 | 88 | # Custom dataloaders 89 | normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], 90 | std=[0.229, 0.224, 0.225]) 91 | train_loader = torch.utils.data.DataLoader( 92 | CaptionDataset(data_folder, data_name, 'TRAIN', transform=transforms.Compose([normalize])), 93 | batch_size=batch_size,
shuffle=True, num_workers=workers, pin_memory=True) 94 | val_loader = torch.utils.data.DataLoader( 95 | CaptionDataset(data_folder, data_name, 'VAL', transform=transforms.Compose([normalize])), 96 | batch_size=batch_size, shuffle=True, num_workers=workers, pin_memory=True) 97 | 98 | # Epochs 99 | for epoch in range(start_epoch, epochs): 100 | 101 | # Decay learning rate if there is no improvement for 8 consecutive epochs, and terminate training after 20 102 | if epochs_since_improvement == 20: 103 | break 104 | if epochs_since_improvement > 0 and epochs_since_improvement % 8 == 0: 105 | adjust_learning_rate(decoder_optimizer, 0.8) 106 | if fine_tune_encoder: 107 | adjust_learning_rate(encoder_optimizer, 0.8) 108 | 109 | # One epoch's training 110 | train(train_loader=train_loader, 111 | encoder=encoder, 112 | decoder=decoder, 113 | criterion=criterion, 114 | encoder_optimizer=encoder_optimizer, 115 | decoder_optimizer=decoder_optimizer, 116 | epoch=epoch) 117 | 118 | # One epoch's validation 119 | recent_bleu4 = validate(val_loader=val_loader, 120 | encoder=encoder, 121 | decoder=decoder, 122 | criterion=criterion) 123 | 124 | # Check if there was an improvement 125 | is_best = recent_bleu4 > best_bleu4 126 | best_bleu4 = max(recent_bleu4, best_bleu4) 127 | if not is_best: 128 | epochs_since_improvement += 1 129 | print("\nEpochs since last improvement: %d\n" % (epochs_since_improvement,)) 130 | else: 131 | epochs_since_improvement = 0 132 | 133 | # Save checkpoint 134 | save_checkpoint(data_name, epoch, epochs_since_improvement, encoder, decoder, encoder_optimizer, 135 | decoder_optimizer, recent_bleu4, is_best) 136 | 137 | 138 | def train(train_loader, encoder, decoder, criterion, encoder_optimizer, decoder_optimizer, epoch): 139 | """ 140 | Performs one epoch's training. 141 | 142 | :param train_loader: DataLoader for training data 143 | :param encoder: encoder model 144 | :param decoder: decoder model 145 | :param criterion: loss layer 146 | :param encoder_optimizer: optimizer to update encoder's weights (if fine-tuning) 147 | :param decoder_optimizer: optimizer to update decoder's weights 148 | :param epoch: epoch number 149 | """ 150 | 151 | decoder.train() # train mode (dropout and batchnorm are used) 152 | encoder.train() 153 | 154 | batch_time = AverageMeter() # forward prop. + back prop. time 155 | data_time = AverageMeter() # data loading time 156 | losses = AverageMeter() # loss (per word decoded) 157 | top5accs = AverageMeter() # top5 accuracy 158 | 159 | start = time.time() 160 | 161 | # Batches 162 | for i, (imgs, caps, caplens) in enumerate(train_loader): 163 | data_time.update(time.time() - start) 164 | 165 | # Move to GPU, if available 166 | imgs = imgs.to(device) 167 | caps = caps.to(device) 168 | caplens = caplens.to(device) 169 | 170 | # Forward prop. 171 | imgs = encoder(imgs) 172 | scores, caps_sorted, decode_lengths, alphas, sort_ind = decoder(imgs, caps, caplens) 173 | 174 | # Since we decoded starting with <start>, the targets are all words after <start>, up to <end> 175 | targets = caps_sorted[:, 1:] 176 | 177 | # Remove timesteps that we didn't decode at, or are pads 178 | # pack_padded_sequence is an easy trick to do this 179 | scores, _ = pack_padded_sequence(scores, decode_lengths, batch_first=True) 180 | targets, _ = pack_padded_sequence(targets, decode_lengths, batch_first=True) 181 | 182 | # Calculate loss 183 | loss = criterion(scores, targets) 184 | 185 | # Add doubly stochastic attention regularization 186 | loss += alpha_c * ((1.
- alphas.sum(dim=1)) ** 2).mean() 187 | 188 | # Back prop. 189 | decoder_optimizer.zero_grad() 190 | if encoder_optimizer is not None: 191 | encoder_optimizer.zero_grad() 192 | loss.backward() 193 | 194 | # Clip gradients 195 | if grad_clip is not None: 196 | clip_gradient(decoder_optimizer, grad_clip) 197 | if encoder_optimizer is not None: 198 | clip_gradient(encoder_optimizer, grad_clip) 199 | 200 | # Update weights 201 | decoder_optimizer.step() 202 | if encoder_optimizer is not None: 203 | encoder_optimizer.step() 204 | 205 | # Keep track of metrics 206 | top5 = accuracy(scores, targets, 5) 207 | losses.update(loss.item(), sum(decode_lengths)) 208 | top5accs.update(top5, sum(decode_lengths)) 209 | batch_time.update(time.time() - start) 210 | 211 | start = time.time() 212 | 213 | # Print status 214 | if i % print_freq == 0: 215 | print('Epoch: [{0}][{1}/{2}]\t' 216 | 'Batch Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t' 217 | 'Data Load Time {data_time.val:.3f} ({data_time.avg:.3f})\t' 218 | 'Loss {loss.val:.4f} ({loss.avg:.4f})\t' 219 | 'Top-5 Accuracy {top5.val:.3f} ({top5.avg:.3f})'.format(epoch, i, len(train_loader), 220 | batch_time=batch_time, 221 | data_time=data_time, loss=losses, 222 | top5=top5accs)) 223 | 224 | 225 | def validate(val_loader, encoder, decoder, criterion): 226 | """ 227 | Performs one epoch's validation. 228 | 229 | :param val_loader: DataLoader for validation data. 230 | :param encoder: encoder model 231 | :param decoder: decoder model 232 | :param criterion: loss layer 233 | :return: BLEU-4 score 234 | """ 235 | decoder.eval() # eval mode (no dropout or batchnorm) 236 | if encoder is not None: 237 | encoder.eval() 238 | 239 | batch_time = AverageMeter() 240 | losses = AverageMeter() 241 | top5accs = AverageMeter() 242 | 243 | start = time.time() 244 | 245 | references = list() # references (true captions) for calculating BLEU-4 score 246 | hypotheses = list() # hypotheses (predictions) 247 | 248 | # explicitly disable gradient calculation to avoid CUDA memory error 249 | # solves the issue #57 250 | with torch.no_grad(): 251 | # Batches 252 | for i, (imgs, caps, caplens, allcaps) in enumerate(val_loader): 253 | 254 | # Move to device, if available 255 | imgs = imgs.to(device) 256 | caps = caps.to(device) 257 | caplens = caplens.to(device) 258 | 259 | # Forward prop. 260 | if encoder is not None: 261 | imgs = encoder(imgs) 262 | scores, caps_sorted, decode_lengths, alphas, sort_ind = decoder(imgs, caps, caplens) 263 | 264 | # Since we decoded starting with <start>, the targets are all words after <start>, up to <end> 265 | targets = caps_sorted[:, 1:] 266 | 267 | # Remove timesteps that we didn't decode at, or are pads 268 | # pack_padded_sequence is an easy trick to do this 269 | scores_copy = scores.clone() 270 | scores, _ = pack_padded_sequence(scores, decode_lengths, batch_first=True) 271 | targets, _ = pack_padded_sequence(targets, decode_lengths, batch_first=True) 272 | 273 | # Calculate loss 274 | loss = criterion(scores, targets) 275 | 276 | # Add doubly stochastic attention regularization 277 | loss += alpha_c * ((1.
- alphas.sum(dim=1)) ** 2).mean() 278 | 279 | # Keep track of metrics 280 | losses.update(loss.item(), sum(decode_lengths)) 281 | top5 = accuracy(scores, targets, 5) 282 | top5accs.update(top5, sum(decode_lengths)) 283 | batch_time.update(time.time() - start) 284 | 285 | start = time.time() 286 | 287 | if i % print_freq == 0: 288 | print('Validation: [{0}/{1}]\t' 289 | 'Batch Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t' 290 | 'Loss {loss.val:.4f} ({loss.avg:.4f})\t' 291 | 'Top-5 Accuracy {top5.val:.3f} ({top5.avg:.3f})\t'.format(i, len(val_loader), batch_time=batch_time, 292 | loss=losses, top5=top5accs)) 293 | 294 | # Store references (true captions), and hypotheses (predictions) for each image 295 | # If for n images, we have n hypotheses, and references a, b, c... for each image, we need - 296 | # references = [[ref1a, ref1b, ref1c], [ref2a, ref2b], ...], hypotheses = [hyp1, hyp2, ...] 297 | 298 | # References 299 | allcaps = allcaps[sort_ind] # because images were sorted in the decoder 300 | for j in range(allcaps.shape[0]): 301 | img_caps = allcaps[j].tolist() 302 | img_captions = list( 303 | map(lambda c: [w for w in c if w not in {word_map['<start>'], word_map['<pad>']}], 304 | img_caps)) # remove <start> and pads 305 | references.append(img_captions) 306 | # Hypotheses 307 | _, preds = torch.max(scores_copy, dim=2) 308 | preds = preds.tolist() 309 | temp_preds = list() 310 | for j, p in enumerate(preds): 311 | temp_preds.append(preds[j][:decode_lengths[j]]) # remove pads 312 | preds = temp_preds 313 | hypotheses.extend(preds) 314 | 315 | assert len(references) == len(hypotheses) 316 | 317 | # Calculate BLEU-4 scores 318 | #import pdb; pdb.set_trace() 319 | bleu4 = corpus_bleu(references, hypotheses) 320 | 321 | print( 322 | '\n * LOSS - {loss.avg:.3f}, TOP-5 ACCURACY - {top5.avg:.3f}, BLEU-4 - {bleu}\n'.format( 323 | loss=losses, 324 | top5=top5accs, 325 | bleu=bleu4)) 326 | 327 | return bleu4 328 | 329 | 330 | if __name__ == '__main__': 331 | main() 332 | -------------------------------------------------------------------------------- /show_attend_and_tell/utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | import h5py 4 | import json 5 | import torch 6 | from scipy.misc import imread, imresize  # requires an old SciPy; imread/imresize were removed from newer releases 7 | from tqdm import tqdm 8 | from collections import Counter 9 | from random import seed, choice, sample 10 | 11 | 12 | def create_input_files(dataset, karpathy_json_path, image_folder, captions_per_image, min_word_freq, output_folder, 13 | max_len=100): 14 | """ 15 | Creates input files for training, validation, and test data.
16 | 17 | :param dataset: name of dataset, one of 'coco', 'flickr8k', 'flickr30k' 18 | :param karpathy_json_path: path of Karpathy JSON file with splits and captions 19 | :param image_folder: folder with downloaded images 20 | :param captions_per_image: number of captions to sample per image 21 | :param min_word_freq: words occurring less frequently than this threshold are binned as <unk>s 22 | :param output_folder: folder to save files 23 | :param max_len: don't sample captions longer than this length 24 | """ 25 | 26 | assert dataset in {'coco', 'flickr8k', 'flickr30k'} 27 | 28 | # Read Karpathy JSON 29 | with open(karpathy_json_path, 'r') as j: 30 | data = json.load(j) 31 | 32 | # Read image paths and captions for each image 33 | train_image_paths = [] 34 | train_image_captions = [] 35 | val_image_paths = [] 36 | val_image_captions = [] 37 | test_image_paths = [] 38 | test_image_captions = [] 39 | word_freq = Counter() 40 | 41 | for img in data['images']: 42 | captions = [] 43 | for c in img['sentences']: 44 | # Update word frequency 45 | word_freq.update(c['tokens']) 46 | if len(c['tokens']) <= max_len: 47 | captions.append(c['tokens']) 48 | 49 | if len(captions) == 0: 50 | continue 51 | 52 | path = os.path.join(image_folder, img['filepath'], img['filename']) if dataset == 'coco' else os.path.join( 53 | image_folder, img['filename']) 54 | 55 | if img['split'] in {'train', 'restval'}: 56 | train_image_paths.append(path) 57 | train_image_captions.append(captions) 58 | elif img['split'] in {'val'}: 59 | val_image_paths.append(path) 60 | val_image_captions.append(captions) 61 | elif img['split'] in {'test'}: 62 | test_image_paths.append(path) 63 | test_image_captions.append(captions) 64 | 65 | # Sanity check 66 | assert len(train_image_paths) == len(train_image_captions) 67 | assert len(val_image_paths) == len(val_image_captions) 68 | assert len(test_image_paths) == len(test_image_captions) 69 | 70 | # Create word map 71 | words = [w for w in word_freq.keys() if word_freq[w] > min_word_freq] 72 | word_map = {k: v + 1 for v, k in enumerate(words)} 73 | word_map['<unk>'] = len(word_map) + 1 74 | word_map['<start>'] = len(word_map) + 1 75 | word_map['<end>'] = len(word_map) + 1 76 | word_map['<pad>'] = 0 77 | 78 | # Create a base/root name for all output files 79 | base_filename = dataset + '_' + str(captions_per_image) + '_cap_per_img_' + str(min_word_freq) + '_min_word_freq' 80 | 81 | # Save word map to a JSON 82 | with open(os.path.join(output_folder, 'WORDMAP_' + base_filename + '.json'), 'w') as j: 83 | json.dump(word_map, j) 84 | 85 | # Sample captions for each image, save images to HDF5 file, and captions and their lengths to JSON files 86 | seed(123) 87 | for impaths, imcaps, split in [#(train_image_paths, train_image_captions, 'TRAIN'), 88 | #(val_image_paths, val_image_captions, 'VAL'), 89 | (test_image_paths, test_image_captions, 'TEST')]: 90 | 91 | with h5py.File(os.path.join(output_folder, split + '_IMAGES_' + base_filename + '.hdf5'), 'a') as h: 92 | # Make a note of the number of captions we are sampling per image 93 | h.attrs['captions_per_image'] = captions_per_image 94 | 95 | # Create dataset inside HDF5 file to store images 96 | images = h.create_dataset('images', (len(impaths), 3, 256, 256), dtype='uint8') 97 | 98 | print("\nReading %s images and captions, storing to file...\n" % split) 99 | 100 | enc_captions = [] 101 | caplens = [] 102 | paths = [] 103 | 104 | for i, path in enumerate(tqdm(impaths)): 105 | 106 | # Sample captions 107 | if len(imcaps[i]) < captions_per_image: 108 | captions =
imcaps[i] + [choice(imcaps[i]) for _ in range(captions_per_image - len(imcaps[i]))] 109 | else: 110 | captions = sample(imcaps[i], k=captions_per_image) 111 | 112 | # Sanity check 113 | assert len(captions) == captions_per_image 114 | 115 | # Read images 116 | img = imread(impaths[i]) 117 | if len(img.shape) == 2: 118 | img = img[:, :, np.newaxis] 119 | img = np.concatenate([img, img, img], axis=2) 120 | img = imresize(img, (256, 256)) 121 | img = img.transpose(2, 0, 1) 122 | assert img.shape == (3, 256, 256) 123 | assert np.max(img) <= 255 124 | 125 | # Save image to HDF5 file 126 | images[i] = img 127 | paths.append(path) 128 | 129 | for j, c in enumerate(captions): 130 | # Encode captions 131 | enc_c = [word_map['<start>']] + [word_map.get(word, word_map['<unk>']) for word in c] + [ 132 | word_map['<end>']] + [word_map['<pad>']] * (max_len - len(c)) 133 | 134 | # Find caption lengths 135 | c_len = len(c) + 2 136 | 137 | enc_captions.append(enc_c) 138 | caplens.append(c_len) 139 | 140 | # Sanity check 141 | assert images.shape[0] * captions_per_image == len(enc_captions) == len(caplens) 142 | 143 | # Save encoded captions and their lengths to JSON files 144 | with open(os.path.join(output_folder, split + '_CAPTIONS_' + base_filename + '.json'), 'w') as j: 145 | json.dump(enc_captions, j) 146 | 147 | with open(os.path.join(output_folder, split + '_CAPLENS_' + base_filename + '.json'), 'w') as j: 148 | json.dump(caplens, j) 149 | 150 | with open(os.path.join(output_folder, split + '_PATHS_' + base_filename + '.json'), 'w') as j: 151 | json.dump(paths, j) 152 | 153 | 154 | def init_embedding(embeddings): 155 | """ 156 | Fills embedding tensor with values from the uniform distribution. 157 | 158 | :param embeddings: embedding tensor 159 | """ 160 | bias = np.sqrt(3.0 / embeddings.size(1)) 161 | torch.nn.init.uniform_(embeddings, -bias, bias) 162 | 163 | 164 | def load_embeddings(emb_file, word_map): 165 | """ 166 | Creates an embedding tensor for the specified word map, for loading into the model. 167 | 168 | :param emb_file: file containing embeddings (stored in GloVe format) 169 | :param word_map: word map 170 | :return: embeddings in the same order as the words in the word map, dimension of embeddings 171 | """ 172 | 173 | # Find embedding dimension 174 | with open(emb_file, 'r') as f: 175 | emb_dim = len(f.readline().split(' ')) - 1 176 | 177 | vocab = set(word_map.keys()) 178 | 179 | # Create tensor to hold embeddings, initialize 180 | embeddings = torch.FloatTensor(len(vocab), emb_dim) 181 | init_embedding(embeddings) 182 | 183 | # Read embedding file 184 | print("\nLoading embeddings...") 185 | for line in open(emb_file, 'r'): 186 | line = line.split(' ') 187 | 188 | emb_word = line[0] 189 | embedding = list(map(lambda t: float(t), filter(lambda n: n and not n.isspace(), line[1:]))) 190 | 191 | # Ignore word if not in train_vocab 192 | if emb_word not in vocab: 193 | continue 194 | 195 | embeddings[word_map[emb_word]] = torch.FloatTensor(embedding) 196 | 197 | return embeddings, emb_dim 198 | 199 | 200 | def clip_gradient(optimizer, grad_clip): 201 | """ 202 | Clips gradients computed during backpropagation to avoid explosion of gradients.
203 | 204 | :param optimizer: optimizer with the gradients to be clipped 205 | :param grad_clip: clip value 206 | """ 207 | for group in optimizer.param_groups: 208 | for param in group['params']: 209 | if param.grad is not None: 210 | param.grad.data.clamp_(-grad_clip, grad_clip) 211 | 212 | 213 | def save_checkpoint(data_name, epoch, epochs_since_improvement, encoder, decoder, encoder_optimizer, decoder_optimizer, 214 | bleu4, is_best): 215 | """ 216 | Saves model checkpoint. 217 | 218 | :param data_name: base name of processed dataset 219 | :param epoch: epoch number 220 | :param epochs_since_improvement: number of epochs since last improvement in BLEU-4 score 221 | :param encoder: encoder model 222 | :param decoder: decoder model 223 | :param encoder_optimizer: optimizer to update encoder's weights, if fine-tuning 224 | :param decoder_optimizer: optimizer to update decoder's weights 225 | :param bleu4: validation BLEU-4 score for this epoch 226 | :param is_best: is this checkpoint the best so far? 227 | """ 228 | state = {'epoch': epoch, 229 | 'epochs_since_improvement': epochs_since_improvement, 230 | 'bleu-4': bleu4, 231 | 'encoder': encoder, 232 | 'decoder': decoder, 233 | 'encoder_optimizer': encoder_optimizer, 234 | 'decoder_optimizer': decoder_optimizer} 235 | filename = 'checkpoint_' + data_name + '.pth.tar' 236 | torch.save(state, filename) 237 | # If this checkpoint is the best so far, store a copy so it doesn't get overwritten by a worse checkpoint 238 | if is_best: 239 | torch.save(state, 'BEST_' + filename) 240 | 241 | 242 | class AverageMeter(object): 243 | """ 244 | Keeps track of most recent, average, sum, and count of a metric. 245 | """ 246 | 247 | def __init__(self): 248 | self.reset() 249 | 250 | def reset(self): 251 | self.val = 0 252 | self.avg = 0 253 | self.sum = 0 254 | self.count = 0 255 | 256 | def update(self, val, n=1): 257 | self.val = val 258 | self.sum += val * n 259 | self.count += n 260 | self.avg = self.sum / self.count 261 | 262 | 263 | def adjust_learning_rate(optimizer, shrink_factor): 264 | """ 265 | Shrinks learning rate by a specified factor. 266 | 267 | :param optimizer: optimizer whose learning rate must be shrunk. 268 | :param shrink_factor: factor in interval (0, 1) to multiply learning rate with. 269 | """ 270 | 271 | print("\nDECAYING learning rate.") 272 | for param_group in optimizer.param_groups: 273 | param_group['lr'] = param_group['lr'] * shrink_factor 274 | print("The new learning rate is %f\n" % (optimizer.param_groups[0]['lr'],)) 275 | 276 | 277 | def accuracy(scores, targets, k): 278 | """ 279 | Computes top-k accuracy, from predicted and true labels. 
280 | 281 | :param scores: scores from the model 282 | :param targets: true labels 283 | :param k: k in top-k accuracy 284 | :return: top-k accuracy 285 | """ 286 | 287 | batch_size = targets.size(0) 288 | _, ind = scores.topk(k, 1, True, True) 289 | correct = ind.eq(targets.view(-1, 1).expand_as(ind)) 290 | correct_total = correct.view(-1).float().sum() # 0D tensor 291 | return correct_total.item() * (100.0 / batch_size) 292 | -------------------------------------------------------------------------------- /web/app.py: -------------------------------------------------------------------------------- 1 | from os.path import join, dirname, realpath 2 | 3 | from flask import Flask, request, redirect, render_template, flash, url_for 4 | from werkzeug.utils import secure_filename 5 | from transformers import ( 6 | VisionEncoderDecoderModel, 7 | AutoModelForCausalLM, 8 | AutoTokenizer, 9 | PreTrainedTokenizerFast, 10 | ViTFeatureExtractor, 11 | ) 12 | import torch 13 | 14 | from utils import * 15 | 16 | 17 | ALLOWED_EXTENSIONS = set(["png", "jpg", "jpeg"]) 18 | UPLOAD_FOLDER = "web/assets/uploads" 19 | 20 | app = Flask( 21 | __name__, static_url_path="", static_folder="", template_folder="web/templates" 22 | ) 23 | app.secret_key = "secret key" 24 | app.config["UPLOAD_FOLDER"] = UPLOAD_FOLDER 25 | app.config["MAX_CONTENT_LENGTH"] = 16 * 1024 * 1024 26 | 27 | 28 | def allowed_file(filename): 29 | return "." in filename and filename.rsplit(".", 1)[1].lower() in ALLOWED_EXTENSIONS 30 | 31 | 32 | @app.route("/display/<filename>") 33 | def display_image(filename): 34 | return redirect( 35 | url_for("static", filename=join(UPLOAD_FOLDER, filename)), code=301 36 | ) 37 | 38 | 39 | @app.route("/about") 40 | def about(): 41 | return render_template("about.html") 42 | 43 | 44 | @app.route("/") 45 | def index(): 46 | # filename = request.args.get("filename") 47 | if request.args.get("filename"): 48 | filename = request.args.get("filename").split("/")[-1] 49 | generated_poems = generate_poem_from_image( 50 | vision_encoder_decoder_model=vision_encoder_decoder_model, 51 | vision_encoder_decoder_tokenizer=vision_encoder_decoder_tokenizer, 52 | feature_extractor=feature_extractor, 53 | poem_generator=poem_generator, 54 | poem_tokenizer=poem_tokenizer, 55 | hk_poem_generator=hk_poem_generator, 56 | hk_poem_tokenizer=hk_poem_tokenizer, 57 | file_folder=app.config["UPLOAD_FOLDER"], 58 | filename=filename, 59 | ) 60 | return render_template( 61 | "responsive.html", 62 | filename=filename, 63 | generated_poems=generated_poems, 64 | ) 65 | else: 66 | return render_template("responsive.html") 67 | 68 | 69 | @app.route("/", methods=["GET", "POST"]) 70 | def upload_image(): 71 | if request.method == "POST": 72 | if "file" not in request.files: 73 | flash("No file part") 74 | return redirect(request.url) 75 | file = request.files["file"] 76 | if file.filename == "": 77 | flash("No image selected for uploading") 78 | return redirect(request.url) 79 | if file and allowed_file(file.filename): 80 | filename = secure_filename(file.filename) 81 | file.save(join(dirname(realpath(__file__)), UPLOAD_FOLDER, filename)) 82 | flash("Image successfully uploaded and displayed below") 83 | return render_template("responsive.html", filename=filename) 84 | else: 85 | flash("Allowed image types are -> png, jpg, jpeg") 86 | return redirect(request.url) 87 | 88 | 89 | if __name__ == "__main__": 90 | 91 | # device setting 92 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 93 | 94 | # load model
encoder_model_name_or_path = "ddobokki/vision-encoder-decoder-vit-gpt2-coco-ko" 96 | feature_extractor = ViTFeatureExtractor.from_pretrained(encoder_model_name_or_path) 97 | vision_encoder_decoder_tokenizer = PreTrainedTokenizerFast.from_pretrained( 98 | encoder_model_name_or_path 99 | ) 100 | vision_encoder_decoder_model = VisionEncoderDecoderModel.from_pretrained( 101 | encoder_model_name_or_path 102 | ) 103 | vision_encoder_decoder_model.to(device) 104 | print("captioning model loaded") 105 | 106 | poem_generator_model_path = "CheonggyeMountain-Sherpa/kogpt-trinity-poem" 107 | poem_generator = AutoModelForCausalLM.from_pretrained( 108 | poem_generator_model_path, use_auth_token=True 109 | ) 110 | poem_tokenizer = AutoTokenizer.from_pretrained( 111 | poem_generator_model_path, use_auth_token=True 112 | ) 113 | poem_generator.to(device) 114 | poem_generator.eval() 115 | 116 | hk_poem_generator_model_path = "ddobokki/gpt2_poem" 117 | hk_poem_generator = AutoModelForCausalLM.from_pretrained( 118 | hk_poem_generator_model_path 119 | ) 120 | hk_poem_tokenizer = AutoTokenizer.from_pretrained(hk_poem_generator_model_path) 121 | hk_poem_generator.to(device) 122 | hk_poem_generator.eval() 123 | 124 | print("generator models loaded") 125 | 126 | app.run(host="0.0.0.0", port=6006, debug=True, use_reloader=False) 127 | -------------------------------------------------------------------------------- /web/db_utils.py: -------------------------------------------------------------------------------- 1 | import sqlite3 2 | 3 | 4 | def create_table(): 5 | conn = sqlite3.connect("db.db") 6 | c = conn.cursor() 7 | 8 | c.execute( 9 | """ 10 | CREATE TABLE CLIENT ( 11 | ID varchar(20) PRIMARY KEY 12 | ); 13 | """ 14 | ) 15 | 16 | c.execute( 17 | """ 18 | CREATE TABLE POEM ( 19 | POEM_ID integer PRIMARY KEY AUTOINCREMENT, 20 | CLIENT_ID varchar(20), 21 | img BLOB NOT NULL, 22 | POEM varchar(200), 23 | FEEDBACK int, 24 | LastCreated DEFAULT CURRENT_TIMESTAMP, 25 | FOREIGN KEY(CLIENT_ID) 26 | REFERENCES CLIENT(ID) 27 | ); 28 | """ 29 | ) 30 | conn.commit() 31 | 32 | 33 | if __name__ == "__main__": 34 | create_table() 35 | -------------------------------------------------------------------------------- /web/utils.py: -------------------------------------------------------------------------------- 1 | from flask import Flask, request, redirect, render_template, flash, url_for 2 | from urllib.parse import urlparse 3 | 4 | from werkzeug.utils import secure_filename 5 | import os 6 | 7 | from io import BytesIO 8 | from PIL import Image 9 | from time import perf_counter 10 | 11 | import requests 12 | import logging 13 | import base64 14 | 15 | from transformers import ( 16 | VisionEncoderDecoderModel, 17 | AutoModelForCausalLM, 18 | AutoTokenizer, 19 | PreTrainedTokenizerFast, 20 | ViTFeatureExtractor, 21 | ) 22 | import torch 23 | import numpy as np 24 | 25 | import sqlite3 26 | 27 | from konlpy.tag import Okt 28 | from collections import Counter 29 | 30 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 31 | 32 | 33 | def generate_poem_from_image( 34 | vision_encoder_decoder_model, 35 | vision_encoder_decoder_tokenizer, 36 | feature_extractor, 37 | poem_generator, 38 | poem_tokenizer, 39 | hk_poem_generator, 40 | hk_poem_tokenizer, 41 | file_folder, 42 | filename, 43 | ): 44 | try: 45 | img = Image.open(os.path.join(file_folder, filename)).convert("RGB") 46 | except Exception: 47 | return "" 48 | try: 49 | pixel_values = feature_extractor(images=img, return_tensors="pt").pixel_values 50 | description =
-------------------------------------------------------------------------------- /web/utils.py: -------------------------------------------------------------------------------- 1 | from flask import Flask, request, redirect, render_template, flash, url_for
2 | from urllib.parse import urlparse
3 | 
4 | from werkzeug.utils import secure_filename
5 | import os
6 | 
7 | from io import BytesIO
8 | from PIL import Image
9 | from time import perf_counter
10 | 
11 | import requests
12 | import logging
13 | import base64
14 | 
15 | from transformers import (
16 |     VisionEncoderDecoderModel,
17 |     AutoModelForCausalLM,
18 |     AutoTokenizer,
19 |     PreTrainedTokenizerFast,
20 |     ViTFeatureExtractor,
21 | )
22 | import torch
23 | import numpy as np
24 | 
25 | import sqlite3
26 | 
27 | from konlpy.tag import Okt
28 | from collections import Counter
29 | 
30 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
31 | 
32 | 
33 | def generate_poem_from_image(
34 |     vision_encoder_decoder_model,
35 |     vision_encoder_decoder_tokenizer,
36 |     feature_extractor,
37 |     poem_generator,
38 |     poem_tokenizer,
39 |     hk_poem_generator,
40 |     hk_poem_tokenizer,
41 |     file_folder,
42 |     filename,
43 | ):
44 |     try:
45 |         img = Image.open(os.path.join(file_folder, filename)).convert("RGB")
46 |     except Exception:
47 |         return ""
48 |     try:
49 |         pixel_values = feature_extractor(images=img, return_tensors="pt").pixel_values
50 |         description = generate_caption(
51 |             vision_encoder_decoder_model, vision_encoder_decoder_tokenizer, pixel_values
52 |         )
53 | 
54 |         hk_description = "" + description[0] + ""
55 |         description = "@" + description[0] + "@"
56 |         generated_texts = []
57 | 
58 |         temp_generated_texts = generate_poem(
59 |             poem_generator, poem_tokenizer, description
60 |         )
61 |         temp_generated_texts = map(
62 |             lambda x: "\n".join(x.split("\n")[:-1]), temp_generated_texts
63 |         )
64 |         temp_generated_texts = list(temp_generated_texts)
65 |         generated_texts.extend(temp_generated_texts)
66 |         print(temp_generated_texts)
67 | 
68 |         temp_generated_texts = hk_generate_poem(
69 |             hk_poem_generator, hk_poem_tokenizer, hk_description
70 |         )
71 |         temp_generated_texts = map(
72 |             lambda x: "\n".join(x.split("\n")[:-1]), temp_generated_texts
73 |         )
74 |         temp_generated_texts = list(temp_generated_texts)
75 |         generated_texts.extend(temp_generated_texts)
76 |         print(temp_generated_texts)
77 | 
78 |     except Exception:
79 |         return "실패"
80 |     return generated_texts
81 | 
82 | 
83 | def generate_caption(
84 |     vision_encoder_decoder_model, vision_encoder_decoder_tokenizer, pixel_values
85 | ):
86 |     generated_ids = vision_encoder_decoder_model.generate(
87 |         pixel_values.to(device), num_beams=5
88 |     )
89 |     generated_text = vision_encoder_decoder_tokenizer.batch_decode(
90 |         generated_ids, skip_special_tokens=True
91 |     )
92 |     return generated_text
93 | 
94 | 
95 | def generate_poem(poem_generator, poem_tokenizer, input_text):
96 |     input_ids = poem_tokenizer.encode(input_text, return_tensors="pt").to(device)
97 | 
98 |     with torch.no_grad():
99 |         outputs = poem_generator.generate(
100 |             input_ids,
101 |             max_length=100,
102 |             repetition_penalty=2.0,
103 |             pad_token_id=poem_tokenizer.pad_token_id,
104 |             eos_token_id=poem_tokenizer.eos_token_id,
105 |             bos_token_id=poem_tokenizer.bos_token_id,
106 |             bad_words_ids=[[38573], [408]],
107 |             do_sample=True,
108 |             top_k=15,
109 |             top_p=0.75,
110 |             num_return_sequences=3,
111 |         )
112 |     generated_texts = list(
113 |         map(lambda x: poem_tokenizer.decode(x, skip_special_tokens=True), outputs)
114 |     )
115 |     generated_texts = map(lambda x: "\n".join(x.split("\n")[1:]), generated_texts)
116 | 
117 |     return generated_texts
118 | 
119 | 
120 | def hk_generate_poem(poem_generator, poem_tokenizer, input_text):
121 |     input_ids = poem_tokenizer.encode(input_text, return_tensors="pt").to(device)
122 | 
123 |     with torch.no_grad():
124 |         outputs = poem_generator.generate(
125 |             input_ids, max_length=100, num_beams=10, no_repeat_ngram_size=2
126 |         )
127 |     generated_texts = list(
128 |         map(lambda x: poem_tokenizer.decode(x, skip_special_tokens=True), outputs)
129 |     )
130 |     generated_texts = map(lambda x: "\n".join(x.split("\n")[1:]), generated_texts)
131 | 
132 |     return generated_texts
133 | 
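web/utils.py above implements a two-stage pipeline: caption the image with the ViT-GPT2 vision encoder-decoder, then condition a GPT-2 poem generator on that caption. A condensed standalone sketch follows (model names as loaded in web/app.py; the local image path and the choice of the public ddobokki/gpt2_poem generator are assumptions).

import torch
from PIL import Image
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    PreTrainedTokenizerFast,
    ViTFeatureExtractor,
    VisionEncoderDecoderModel,
)

device = "cuda" if torch.cuda.is_available() else "cpu"

# stage 1: generate a Korean caption from the image
cap_name = "ddobokki/vision-encoder-decoder-vit-gpt2-coco-ko"
feature_extractor = ViTFeatureExtractor.from_pretrained(cap_name)
cap_tokenizer = PreTrainedTokenizerFast.from_pretrained(cap_name)
cap_model = VisionEncoderDecoderModel.from_pretrained(cap_name).to(device)

pixel_values = feature_extractor(
    images=Image.open("test.jpg").convert("RGB"), return_tensors="pt"
).pixel_values
caption_ids = cap_model.generate(pixel_values.to(device), num_beams=5)
caption = cap_tokenizer.batch_decode(caption_ids, skip_special_tokens=True)[0]

# stage 2: generate a poem conditioned on the caption
poem_name = "ddobokki/gpt2_poem"  # the kogpt-trinity model used above needs an auth token
poem_model = AutoModelForCausalLM.from_pretrained(poem_name).to(device).eval()
poem_tokenizer = AutoTokenizer.from_pretrained(poem_name)
input_ids = poem_tokenizer.encode(caption, return_tensors="pt").to(device)
with torch.no_grad():
    output = poem_model.generate(
        input_ids, max_length=100, num_beams=10, no_repeat_ngram_size=2
    )
print(poem_tokenizer.decode(output[0], skip_special_tokens=True))
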
-------------------------------------------------------------------------------- /web/web/assets/images/profile.jpeg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/boostcampaitech2/final-project-level3-nlp-08/cce85d868879e9ad23901125fb7ca0d7d129f450/web/web/assets/images/profile.jpeg -------------------------------------------------------------------------------- /web/web/templates/about.html: -------------------------------------------------------------------------------- 1 | {% extends 'layout.html' %}
2 | {% block content %}
7 | 청계산 셰르파들은 산의 정상 부근까지 기어올라가 삽과 송곳을 들고 나무들을 베어내었다. <br>
8 | 세상은 무너지고 사람들은 흩어졌고 <br>
9 | 나는 내 안의 무언가를 끄집어내어 <br>
10 | 아무도 눈치채지 못하게 조용히 묻어버렸다. <br>
11 | 나는 여전히 울고, 너는 또 다른 울음. <br>
14 | {% endblock %} -------------------------------------------------------------------------------- /web/web/templates/js.js: -------------------------------------------------------------------------------- 1 | function getParam(sname) {
2 |     var params = location.search.substr(location.search.indexOf("?") + 1);
3 |     var sval = "";
4 | 
5 |     params = params.split("&");
6 | 
7 |     for (var i = 0; i < params.length; i++) {
8 |         temp = params[i].split("=");
9 |         if (temp[0] == sname) { sval = temp[1]; }
10 |     }
11 |     return sval;
12 | }
13 | 
14 | 
15 | function removePoemCards() {
16 |     document.getElementById('poem container').remove()
17 | }
18 | 
19 | function createCardWithImage(poem) {
20 |     var cardDiv = document.getElementById("poem-card-div");
21 | 
22 |     var newDIV = document.createElement("div");
23 |     newDIV.setAttribute("class", "ui centered card");
24 |     newDIV.setAttribute("id", "poem-card");
25 |     newDIV.setAttribute("style", "width: 400px");
26 | 
27 |     var imgDIV = document.createElement("div");
28 |     imgDIV.setAttribute("class", "image");
29 | 
30 |     var img = document.createElement("img");
31 |     img.setAttribute("src", "web/assets/uploads/" + getParam('filename'));
32 |     img.setAttribute("width", 400);
33 |     img.setAttribute("height", 400);
34 | 
35 |     var contentDIV = document.createElement("div");
36 |     contentDIV.setAttribute("class", "content");
37 | 
38 |     var contentDiscriptionDIV = document.createElement("div");
39 |     contentDiscriptionDIV.setAttribute("class", "div");
40 | 
41 |     let innerPoem = ''
42 |     for (const row of poem.split('\n')) {
43 |         innerPoem += row + "<br> ";
"; 44 | } 45 | contentDiscriptionDIV.innerHTML = innerPoem; 46 | 47 | removePoemCards(); 48 | contentDIV.appendChild(contentDiscriptionDIV); 49 | imgDIV.appendChild(img); 50 | 51 | newDIV.appendChild(imgDIV); 52 | newDIV.appendChild(contentDIV); 53 | 54 | let cardMsg = document.createElement('h2'); 55 | cardMsg.innerHTML = 'Your Poem Card'; 56 | cardMsg.setAttribute("class", "ui center aligned header"); 57 | 58 | cardDiv.appendChild(cardMsg); 59 | cardDiv.appendChild(newDIV); 60 | 61 | let downloadButton = document.createElement('button'); 62 | downloadButton.setAttribute('align', 'center'); 63 | downloadButton.setAttribute('class', 'positive ui button') 64 | downloadButton.setAttribute('id', 'download'); 65 | downloadButton.innerHTML = 'Download Card'; 66 | downloadButton.setAttribute('onclick', 'download();'); 67 | cardDiv.appendChild(downloadButton) 68 | 69 | } 70 | 71 | function download() { 72 | // 캡처 라이브러리를 통해 canvas 오브젝트 받고 이미지 파일로 리턴함 73 | html2canvas(document.querySelector("#poem-card")).then(canvas => { 74 | saveAs(canvas.toDataURL('image/jpg'), "poem_card.jpg"); //다운로드 되는 이미지 파일 이름 지정 75 | }); 76 | }; 77 | function saveAs(uri, filename) { 78 | // 캡처된 파일을 이미지 파일로 내보냄 79 | var link = document.createElement('a'); 80 | if (typeof link.download === 'string') { 81 | link.href = uri; 82 | link.download = filename; 83 | document.body.appendChild(link); 84 | link.click(); 85 | document.body.removeChild(link); 86 | } else { 87 | window.open(uri); 88 | } 89 | }; 90 | 91 | {/* */ } 92 | {/*

Image2Poem

93 |

Generate poem from image with AI

*/} 94 | 95 | 96 | {/*
97 |
98 | 99 |
100 |
101 |
102 |
103 | {% for row in generated_poem.split('\n') %} 104 | {{row}}
105 | {% endfor %} 106 |
107 |
-------------------------------------------------------------------------------- /web/web/templates/layout.html: -------------------------------------------------------------------------------- 9 | 청계산셰르파의 감성시인
33 | {% block content %}
34 | {% endblock %}
-------------------------------------------------------------------------------- /web/web/templates/responsive.html: -------------------------------------------------------------------------------- 1 | {% extends 'layout.html' %}
2 | {% block content %}
4 | Image2Poem
5 | Generate poem from image with AI
12 | {% with messages = get_flashed_messages() %}
13 | {% if messages %}
15 | {% for message in messages %}
16 | {{ message }}
17 | {% endfor %}
19 | {% endif %}
20 | {% endwith %}
22 | {% if filename %}
27 | {% endif %}
78 | {% if generated_poems %}
79 | Select Poem You Want
103 | {% endif %}
114 | {% endblock %}
--------------------------------------------------------------------------------
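Finally, the top-k accuracy helper at the start of this section can be checked in isolation with toy tensors; the expected values in the comments are worked out by hand.

import torch

def accuracy(scores, targets, k):
    # same logic as the helper at the top of this section
    batch_size = targets.size(0)
    _, ind = scores.topk(k, 1, True, True)
    correct = ind.eq(targets.view(-1, 1).expand_as(ind))
    correct_total = correct.view(-1).float().sum()  # 0D tensor
    return correct_total.item() * (100.0 / batch_size)

scores = torch.tensor([[0.1, 0.7, 0.2], [0.5, 0.3, 0.2]])
targets = torch.tensor([2, 1])
print(accuracy(scores, targets, 1))  # 0.0: neither top-1 prediction matches its target
print(accuracy(scores, targets, 2))  # 100.0: both targets fall inside the top-2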