├── requirements.txt ├── README.md ├── LICENSE └── predict_tag.py /requirements.txt: -------------------------------------------------------------------------------- 1 | git+https://github.com/nico-opendata/niconico_chainer_models.git#egg=niconico_chainer_models 2 | 3 | six 4 | pillow 5 | chainer>=1.12 6 | argparse 7 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | nico-illust tag prediction 2 | ================================== 3 | 4 | 必要なモデルファイル、タグ一覧は [nico-opendata.jp](http://nico-opendata.jp) からダウンロードしてください。 5 | 解凍して出来たディレクトリ下の `v1/` 下にある `model.npz` 及び `tags.txt` をカレントディレクトリにコピーして使います。 6 | 7 | USAGE 8 | -------------- 9 | 10 | 依存ライブラリのインストール 11 | 12 | ```sh 13 | pip install -r requirements.txt 14 | ``` 15 | 16 | CPUで実行 17 | 18 | ```sh 19 | python predict_tag.py \ 20 | --gpu=-1 \ 21 | --tags=tags.txt \ 22 | --model=model.npz \ 23 | http://lohas.nicoseiga.jp/thumb/4313120i 24 | # tag: 川内改二 / score: 0.9832866787910461 25 | # tag: 艦これ / score: 0.9811543226242065 26 | # tag: 夜戦忍者 / score: 0.934027910232544 27 | # : 28 | # と出力 29 | ``` 30 | 31 | GPUで実行 32 | 33 | ```sh 34 | python predict_tag.py \ 35 | --gpu=0 \ 36 | --tags=tags.txt \ 37 | --model=model.npz \ 38 | http://lohas.nicoseiga.jp/thumb/4313120i 39 | ``` 40 | 41 | ## License 42 | 43 | MIT License (see `LICENSE` file). 44 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Copyright (c) 2015 Dwango Co., Ltd.
2 | 3 | Permission is hereby granted, free of charge, to any person obtaining a copy 4 | of this software and associated documentation files (the "Software"), to deal 5 | in the Software without restriction, including without limitation the rights 6 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 7 | copies of the Software, and to permit persons to whom the Software is 8 | furnished to do so, subject to the following conditions: 9 | 10 | The above copyright notice and this permission notice shall be included in 11 | all copies or substantial portions of the Software. 12 | 13 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 14 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 15 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 16 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 17 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 18 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN 19 | THE SOFTWARE. 
"""predict_tag.py: predict nico-illust tags for an image with GoogLeNet.

Usage::

    python predict_tag.py --gpu=-1 --tags=tags.txt --model=model.npz IMAGE_URL

Progress messages go to stderr; the highest-scoring tags are printed to
stdout as ``tag: <name> / score: <score>`` lines.
"""
import argparse
import contextlib
import heapq
import io
import math
import sys

import chainer
import numpy
import six
from niconico_chainer_models.google_net import GoogLeNet
from PIL import Image, ImageFile


# Side length, in pixels, of the square crop fed to the network.
INPUT_SIZE = 224


def resize(img, size):
    """Scale *img* so that its shorter side is *size* pixels.

    The aspect ratio is kept and both output dimensions are rounded up, so
    the result is at least *size* pixels along each side.

    :param img: a PIL image.
    :param size: target length of the shorter side, in pixels.
    :return: the resized PIL image.
    """
    # PIL reports ``size`` as (width, height) and resize() takes the same order.
    width, height = img.size
    ratio = size / float(min(width, height))
    new_width = int(math.ceil(width * ratio))
    new_height = int(math.ceil(height * ratio))
    return img.resize((new_width, new_height))


def fetch_image(url, size=INPUT_SIZE):
    """Download the image at *url* and turn it into a network input array.

    The image is converted to RGB, scaled so its shorter side is *size*
    pixels, cropped to its top-left ``size x size`` region and divided by
    255.

    :param url: anything ``urlopen`` accepts (e.g. ``http://``, ``file://``).
    :param size: side length of the square crop, in pixels.
    :return: a ``float32`` array of shape ``(3, size, size)``.
    """
    # Let PIL decode images whose data ends early instead of raising.
    ImageFile.LOAD_TRUNCATED_IMAGES = True

    # Buffer the whole body so the connection is closed deterministically;
    # PIL decodes lazily and expects a seekable file object.
    # contextlib.closing() is used because Python 2 responses are not
    # context managers.
    with contextlib.closing(six.moves.urllib.request.urlopen(url)) as response:
        data = response.read()
    img = Image.open(io.BytesIO(data))

    if img.mode != 'RGB':
        img = img.convert('RGB')

    img = resize(img, size)

    x = numpy.asarray(img).astype('f')  # (height, width, 3)
    x = x[:size, :size, :3]  # crop the top-left corner
    x /= 255.0  # scale to [0, 1]
    return x.transpose((2, 0, 1))  # channels first


parser = argparse.ArgumentParser(
    description='Predict nico-illust tags for an image.')
parser.add_argument('--gpu', type=int, default=-1,
                    help='GPU device id; a negative value runs on the CPU')
parser.add_argument('--model', default='model.npz',
                    help='path to the trained model (.npz)')
parser.add_argument('--tags', default='tags.txt',
                    help='tag list, one tag per line, in model output order')
parser.add_argument('--top', type=int, default=10,
                    help='number of highest-scoring tags to print')
parser.add_argument('image_url')


def load_tags(path):
    """Return the lines of *path* with trailing whitespace stripped.

    Line ``i`` is the name of model output ``i``.
    """
    # NOTE(review): decoded with the platform default encoding; the tag file
    # is presumably UTF-8 -- confirm before running on non-UTF-8 locales.
    with open(path) as tag_file:
        return [line.rstrip() for line in tag_file]


def main(argv=None):
    """Parse *argv* (default: ``sys.argv``), run the model, print top tags."""
    args = parser.parse_args(argv)
    use_gpu = args.gpu >= 0

    if use_gpu:
        chainer.cuda.get_device(args.gpu).use()
        xp = chainer.cuda.cupy
    else:
        xp = numpy

    sys.stderr.write("\r model loading...")
    model = GoogLeNet()
    chainer.serializers.load_npz(args.model, model)
    if use_gpu:
        model.to_gpu()

    tags = load_tags(args.tags)

    sys.stderr.write("\r image fetching...")
    x = xp.array([fetch_image(args.image_url)])  # batch of one image
    # Second input of model.tag(): an all-zero (1, 8) float32 array.
    # NOTE(review): its meaning is defined by niconico_chainer_models -- confirm.
    z = xp.zeros((1, 8)).astype('f')

    sys.stderr.write("\r tag predicting...")
    predicted = model.tag(x, z).data[0]
    # Rank on the host: iterating a GPU array element by element costs a
    # device round-trip per element and per comparison.
    scores = chainer.cuda.to_cpu(predicted)

    sys.stderr.write("\r")
    # Same order as a stable descending sort: ties keep the lower index first.
    best = heapq.nlargest(args.top, enumerate(scores), key=lambda pair: pair[1])

    for index, score in best:
        # Outputs without a corresponding line in the tag file are skipped.
        if index < len(tags):
            print("tag: {} / score: {}".format(tags[index], score))


if __name__ == '__main__':
    main()