├── .gitignore ├── README.md ├── amt ├── __init__.py ├── admin.py ├── apps.py ├── constants.py ├── consumers.py ├── migrations │ ├── 0001_initial.py │ ├── 0002_auto_20170501_1309.py │ ├── 0003_feedback.py │ ├── 0004_auto_20170502_0422.py │ ├── 0005_imageranking_score.py │ ├── 0006_auto_20170502_1849.py │ ├── 0007_auto_20170502_1933.py │ ├── 0008_auto_20170707_1013.py │ └── __init__.py ├── models.py ├── routing.py ├── sender.py ├── static │ ├── css │ │ ├── scrollbar.css │ │ └── style.css │ ├── images │ │ ├── Preloader_2.gif │ │ ├── abot.png │ │ ├── bot.jpg │ │ └── new_bot.png │ ├── js │ │ ├── mixitup.min.js │ │ └── queryparam_helper.js │ └── slack_sound.mp3 ├── templates │ └── amt │ │ ├── base.html │ │ ├── feedback.html │ │ ├── index.html │ │ ├── intro.html │ │ ├── loader.html │ │ ├── modal.html │ │ ├── mturk.html │ │ └── plot.html ├── templatetags │ ├── __init__.py │ └── range.py ├── urls.py ├── utils.py └── views.py ├── chatbot ├── __init__.py ├── dataloader.lua ├── im-hist-enc-dec-answerer │ ├── lstm.lua │ └── specificModel.lua ├── modelAnswerer.lua ├── optim_updates.lua ├── opts.lua ├── prepro_ques.py ├── rl_evaluate.lua ├── rl_worker.py ├── sl_evaluate.lua ├── sl_worker.py ├── testAnswerer.lua └── utils.lua ├── data └── pools.json ├── demo ├── __init__.py ├── asgi.py ├── settings.py ├── urls.py └── wsgi.py ├── download_models.sh ├── manage.py ├── ques_feat.json ├── requirements.txt └── static └── img └── guesswhich.png /.gitignore: -------------------------------------------------------------------------------- 1 | env/ 2 | venv/ 3 | *.pyc 4 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # GuessWhich 2 | 3 | ## Introduction 4 | 5 | **Evaluating Visual Conversational Agents via Cooperative Human-AI Games** 6 | Prithvijit Chattopadhyay*, Deshraj Yadav*, Viraj Prabhu, Arjun Chandrasekaran, Abhishek Das, Stefan Lee, Dhruv Batra, Devi Parikh 7 | [HCOMP 2017][4] 8 | 9 | This repository contains code for setting up the **GuessWhich Game** along with Amazon Mechanical Turk (AMT) integration for real-time data collection. The data-collection settings can be changed easily by modifying the configurations defined [here](https://github.com/VT-vision-lab/GuessWhich/blob/master/amt/constants.py); a short excerpt of these settings follows the abstract. 10 | 11 | ## Abstract 12 | 13 | As AI continues to advance, human-AI teams are inevitable. However, progress in AI is routinely measured in isolation, without a human in the loop. It is important to measure how progress in AI translates to humans being able to accomplish tasks better; i.e., the performance of human-AI teams. In this work, we design a cooperative game – GuessWhich – to measure human-AI team performance in the specific context of the AI being a visual conversational agent. The AI, which we call ALICE, is provided an image which is unseen by the human. The human then asks ALICE questions about this secret image to identify it from a fixed pool of images. 14 | 15 | We measure performance of the human-ALICE team by the number of guesses it takes the human to correctly identify the secret image after a fixed number of dialog rounds with ALICE. We compare performance of the human-ALICE teams for two versions of ALICE. While the AI literature shows that one version outperforms the other when paired with another AI, we find that this improvement in AI-AI performance does not translate to improved human-AI performance.
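For orientation, the data-collection knobs mentioned in the introduction live in `amt/constants.py`. The excerpt below reproduces the repository defaults verbatim (the descriptive comments are added here, not in the file); edit the file itself to change them:

```python
# Excerpt from amt/constants.py (repo defaults).
NUMBER_OF_ROUNDS_IN_A_GAME = 9                 # dialog rounds per game
NUMBER_OF_GAMES_IN_A_HIT = 1                   # games a worker plays per HIT
MAX_BONUS_IN_A_GAME = 200                      # bonus ceiling per game
BONUS_DEDUCTION_FOR_EACH_CLICK = 10            # deducted for each guess click
BONUS_FOR_CORRECT_IMAGE_AFTER_EACH_ROUND = 10  # awarded per correct pick

AWS_ACCESS_KEY_ID = ""      # fill in before launching live AMT HITs
AWS_SECRET_ACCESS_KEY = ""
QUALIFICATION_TYPE_ID = ""  # qualification used to block repeat workers
```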
16 | 17 | 18 | ## Installation Instructions 19 | 20 | ### Installing the essential requirements 21 | 22 | ```shell 23 | sudo apt-get install -y git python-pip python-dev 24 | sudo apt-get install -y python-dev 25 | sudo apt-get install -y autoconf automake libtool curl make g++ unzip 26 | sudo apt-get install -y libgflags-dev libgoogle-glog-dev liblmdb-dev 27 | sudo apt-get install libprotobuf-dev libleveldb-dev libsnappy-dev libopencv-dev libhdf5-serial-dev protobuf-compiler 28 | ``` 29 | 30 | ### Install Torch 31 | 32 | ```shell 33 | git clone https://github.com/torch/distro.git ~/torch --recursive 34 | cd ~/torch; bash install-deps; 35 | ./install.sh 36 | source ~/.bashrc 37 | ``` 38 | 39 | ### Install PyTorch (Python Lua Wrapper) 40 | 41 | ```shell 42 | git clone https://github.com/hughperkins/pytorch.git 43 | cd pytorch 44 | source ~/torch/install/bin/torch-activate 45 | ./build.sh 46 | ``` 47 | 48 | ### Install RabbitMQ and Redis Server 49 | 50 | ```shell 51 | sudo apt-get install -y redis-server rabbitmq-server 52 | sudo rabbitmq-plugins enable rabbitmq_management 53 | sudo service rabbitmq-server restart 54 | sudo service redis-server restart 55 | ``` 56 | 57 | ### Lua dependencies 58 | 59 | ```shell 60 | luarocks install loadcaffe 61 | ``` 62 | 63 | The two dependencies below are only required if you are going to use a GPU: 64 | 65 | ```shell 66 | luarocks install cudnn 67 | luarocks install cunn 68 | ``` 69 | 70 | ### CUDA Installation 71 | 72 | Note: CUDA and cuDNN are only required if you are going to use a GPU. 73 | 74 | Download and install CUDA and cuDNN from the [NVIDIA website](https://developer.nvidia.com/cuda-downloads). 75 | 76 | ### Install dependencies 77 | 78 | ```shell 79 | git clone https://github.com/Cloud-CV/GuessWhich.git 80 | cd GuessWhich 81 | sh download_models.sh 82 | pip install -r requirements.txt 83 | ``` 84 | 85 | ### Create the database 86 | 87 | ```shell 88 | python manage.py makemigrations amt 89 | python manage.py migrate 90 | ``` 91 | 92 | ### Running the RabbitMQ workers and Development Server 93 | 94 | Open three separate terminal sessions and run the following commands: 95 | 96 | ```shell 97 | cd chatbot && python sl_worker.py 98 | cd chatbot && python rl_worker.py 99 | python manage.py runserver 100 | ``` 101 | 102 | You are all set now. Visit http://127.0.0.1:8000 and you will have the demo running.
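If you want to exercise the game loop without the browser UI, the hypothetical smoke test below (not part of the repository) talks to the Channels consumer directly. The event and field names are taken verbatim from `amt/consumers.py`, and the `/chat` path from `amt/routing.py`; the third-party `websocket-client` package and the placeholder image path are assumptions you would adapt.

```python
# Hypothetical smoke test for the GuessWhich websocket endpoint.
# Assumes: `pip install websocket-client`, the dev server on 127.0.0.1:8000,
# and the SL worker (chatbot/sl_worker.py) running.
import json
import uuid

import websocket  # third-party websocket-client package

socket_id = str(uuid.uuid4())
ws = websocket.create_connection("ws://127.0.0.1:8000/chat")

# Join the Channels group for this socket id, so replies sent via
# log_to_terminal() in amt/utils.py reach this connection.
ws.send(json.dumps({"event": "ConnectionEstablished", "socketid": socket_id}))
print(ws.recv())  # -> {"info": "User added to the Channel Group"}

# Submit one question; the consumer forwards it to the SL/RL RabbitMQ queue.
ws.send(json.dumps({
    "event": "questionSubmitted",
    "socketid": socket_id,
    "question": "is there a person in the image ?",
    "prev_history": "",
    "target_image": "/media/val2014/123456.jpg",  # placeholder pool image
    "bot": "sl",  # "sl" or "rl", matching the worker you started
}))
print(ws.recv())  # status/answer messages published for this socket id
```

Whatever the worker publishes for this socket id (job status, then the bot's answer) arrives on subsequent `recv()` calls as JSON text frames.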
103 | 104 | 105 | ## Cite this work 106 | 107 | If you find this code useful, consider citing our work: 108 | 109 | ``` 110 | @inproceedings{visdial_eval, 111 | title={Evaluating Visual Conversational Agents via Cooperative Human-AI Games}, 112 | author={Prithvijit Chattopadhyay and Deshraj Yadav and Viraj Prabhu and Arjun Chandrasekaran and Abhishek Das and Stefan Lee and Dhruv Batra and Devi Parikh}, 113 | booktitle={Proceedings of the Fifth AAAI Conference on Human Computation and Crowdsourcing (HCOMP)}, 114 | year={2017} 115 | } 116 | ``` 117 | 118 | ## Contributors 119 | 120 | * [Deshraj Yadav][2] (deshraj@gatech.edu) 121 | 122 | ## License 123 | 124 | BSD 125 | 126 | ## Credits 127 | 128 | - Vicki Image: "[Robot-clip-art-book-covers-feJCV3-clipart](https://commons.wikimedia.org/wiki/File:Robot-clip-art-book-covers-feJCV3-clipart.png)" by [Wikimedia Commons](https://commons.wikimedia.org) is licensed under [CC BY-SA 4.0](https://creativecommons.org/licenses/by-sa/4.0/deed.en) 129 | 130 | [1]: https://arxiv.org/abs/1611.08669 131 | [2]: http://deshraj.github.io 132 | [4]: http://www.humancomputation.com/2017/ 133 | -------------------------------------------------------------------------------- /amt/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/GT-Vision-Lab/GuessWhich/4d883db722a6347c16cb6ddd7b3329b5a5fd439f/amt/__init__.py -------------------------------------------------------------------------------- /amt/admin.py: -------------------------------------------------------------------------------- 1 | from django.contrib import admin 2 | 3 | from .models import GameRound, ImageRanking, Feedback 4 | from import_export import resources 5 | from import_export.admin import ImportExportModelAdmin 6 | 7 | 8 | class GameRoundResource(resources.ModelResource): 9 | 10 | class Meta: 11 | model = GameRound 12 | 13 | 14 | class GameRoundAdmin(ImportExportModelAdmin): 15 | list_display = ('socket_id', 'user_picked_image', 'worker_id', 'assignment_id', 'level', 'task', 16 | 'hit_id', 'game_id', 'round_id', 'question', 'answer', 'target_image', 'created_at', 'bot',) 17 | list_filter = ('bot', 'worker_id', 'task', ) 18 | search_fields = ['socket_id', 'user_picked_image', 'worker_id', 'assignment_id', 'level', 19 | 'hit_id', 'game_id', 'round_id', 'question', 'answer', 'target_image', 'created_at', 'bot', ] 20 | resource_class = GameRoundResource 21 | 22 | 23 | class ImageRankingResource(resources.ModelResource): 24 | 25 | class Meta: 26 | model = ImageRanking 27 | 28 | 29 | class ImageRankingAdmin(ImportExportModelAdmin): 30 | list_display = ('socket_id', 'final_image_list', 'worker_id', 'assignment_id', 'level', 31 | 'task', 'hit_id', 'game_id', 'target_image', 'created_at', 'bot', 'score', ) 32 | list_filter = ('bot', 'worker_id', 'task', ) 33 | search_fields = ['socket_id', 'final_image_list', 'worker_id', 'assignment_id', 34 | 'level', 'hit_id', 'game_id', 'target_image', 'created_at', 'bot', 'score', ] 35 | resource_class = ImageRankingResource 36 | 37 | 38 | class FeedbackResource(resources.ModelResource): 39 | 40 | class Meta: 41 | model = Feedback 42 | 43 | 44 | class FeedbackAdmin(ImportExportModelAdmin): 45 | list_display = ('hit_id', 'assignment_id', 'worker_id', 'understand_question', 'task', 'understand_image', 46 | 'fluency', 'detail', 'accurate', 'consistent', 'comments', 'level', 'game_id', 'bot',) 47 | list_filter = ('bot', 'worker_id', 'assignment_id', 'task', ) 48 | resource_class = FeedbackResource 
49 | 50 | admin.site.register(GameRound, GameRoundAdmin) 51 | admin.site.register(ImageRanking, ImageRankingAdmin) 52 | admin.site.register(Feedback, FeedbackAdmin) 53 | -------------------------------------------------------------------------------- /amt/apps.py: -------------------------------------------------------------------------------- 1 | from __future__ import unicode_literals 2 | 3 | from django.apps import AppConfig 4 | 5 | 6 | class VickiConfig(AppConfig): 7 | name = 'amt' 8 | -------------------------------------------------------------------------------- /amt/constants.py: -------------------------------------------------------------------------------- 1 | from django.conf import settings 2 | 3 | import os 4 | import sys 5 | sys.path.append(os.path.join(settings.BASE_DIR, 'chatbot')) 6 | 7 | POOL_IMAGES_URL = os.path.join(settings.MEDIA_URL, 'val2014/') 8 | 9 | BOT_INTRODUCTION_MESSAGE = [ 10 | "Hi, my name is Abot. I am an Artificial Intelligence. " \ 11 | "I have been assigned one of these images as the target image. " \ 12 | "I am not allowed to show you the image, but as a start, " \ 13 | "I will describe the image to you in a sentence. " \ 14 | "You can then ask me follow-up questions about it. " \ 15 | "When ready, submit one of the images on the left as your best guess. " \ 16 | "I will try to describe the image and answer your questions, but I am not perfect. " \ 17 | "I make quite a few mistakes. I hope we can work together to find the image! " \ 18 | "Let's do this! Note: My knowledge of English is limited. " \ 19 | "Sometimes if I don't know the right word, I say UNK. " \ 20 | "You will win points based on how accurately you are able to guess.", 21 | ] 22 | 23 | SL_VISDIAL_CONFIG = { 24 | 'inputJson': os.path.join(settings.BASE_DIR, 'chatbot/data/chat_processed_params.json'), 25 | 'qBotpath': os.path.join(settings.BASE_DIR, 'chatbot/data/qbot_hre_qih_sl.t7'), 26 | 'aBotpath': os.path.join(settings.BASE_DIR, 'chatbot/data/abot_hre_qih_sl.t7'), 27 | 'gpuid': 0, 28 | 'backend': 'cudnn', 29 | 'imfeatpath': os.path.join(settings.BASE_DIR, 'chatbot/data/all_pools_vgg16_features.t7'), 30 | } 31 | 32 | SL_VISDIAL_LUA_PATH = "sl_evaluate.lua" 33 | 34 | 35 | RL_VISDIAL_CONFIG = { 36 | 'inputJson': os.path.join(settings.BASE_DIR, 'chatbot/data/chat_processed_params.json'), 37 | 'qBotpath': os.path.join(settings.BASE_DIR, 'chatbot/data/qbot_rl.t7'), 38 | 'aBotpath': os.path.join(settings.BASE_DIR, 'chatbot/data/abot_rl.t7'), 39 | 'gpuid': 0, 40 | 'backend': 'cudnn', 41 | 'imfeatpath': os.path.join(settings.BASE_DIR, 'chatbot/data/all_pools_vgg16_features.t7'), 42 | } 43 | 44 | RL_VISDIAL_LUA_PATH = "rl_evaluate.lua" 45 | 46 | NUMBER_OF_ROUNDS_IN_A_GAME = 9 47 | 48 | NUMBER_OF_GAMES_IN_A_HIT = 1 49 | 50 | AWS_ACCESS_KEY_ID = "" 51 | 52 | AWS_SECRET_ACCESS_KEY = "" 53 | 54 | QUALIFICATION_TYPE_ID = "" 55 | 56 | AMT_HOSTNAME = 'mechanicalturk.amazonaws.com' 57 | 58 | MAX_BONUS_IN_A_GAME = 200 59 | 60 | BONUS_DEDUCTION_FOR_EACH_CLICK = 10 61 | 62 | BONUS_FOR_CORRECT_IMAGE_AFTER_EACH_ROUND = 10 63 | -------------------------------------------------------------------------------- /amt/consumers.py: -------------------------------------------------------------------------------- 1 | from django.utils import timezone 2 | from django.conf import settings 3 | 4 | from .utils import log_to_terminal, fc7_sort 5 | from .sender import chatbot 6 | import constants as constants 7 | from .models import GameRound, ImageRanking 8 | 9 | from channels import Group 10 | 11 | import json 12 | import
redis 13 | import datetime 14 | import os 15 | import shutil 16 | import pdb 17 | 18 | 19 | r = redis.StrictRedis(host='localhost', port=6379, db=0) 20 | 21 | 22 | def ws_connect(message): 23 | "Method called when a user is connected through SocketIO" 24 | pass 25 | 26 | 27 | def ws_message(message): 28 | "Method called when there is a message from the SocketIO client" 29 | 30 | body = json.loads(message.content['text']) 31 | 32 | if body["event"] == "ConnectionEstablished": 33 | # Event when the user is connected to the socketio client 34 | Group(body["socketid"]).add(message.reply_channel) 35 | log_to_terminal(body["socketid"], { 36 | "info": "User added to the Channel Group"}) 37 | 38 | elif body["event"] == "start": 39 | # Event when the user starts to play the game 40 | current_datetime = timezone.now() 41 | r.set("start_time_{}".format( 42 | body["socketid"]), 43 | current_datetime.strftime("%I:%M%p on %B %d, %Y")) 44 | 45 | elif body["event"] == "questionSubmitted": 46 | # Event when the user submits a question to the backend 47 | body['question'] = body['question'].lower() 48 | bot = body['bot'] 49 | chatbot(body['question'], 50 | body['prev_history'], 51 | os.path.join(settings.BASE_DIR, body['target_image'][1:]), 52 | body["socketid"], 53 | bot) 54 | 55 | elif body['event'] == "imageSubmitted": 56 | # Event when the user selects an image after each round of a game 57 | GameRound.objects.create( 58 | socket_id=body['socketid'], 59 | user_picked_image=body['user_picked_image'], 60 | worker_id=body['worker_id'], 61 | assignment_id=body['assignment_id'], 62 | level=body['level'], 63 | hit_id=body['hit_id'], 64 | game_id=body['game_id'], 65 | round_id=body['round_id'], 66 | answer=body['answer'].replace("<START>", "").replace("<END>", ""),  # strip sentence delimiter tokens 67 | question=body['question'], 68 | history=body['history'], 69 | target_image=body['target_image'], 70 | bot=body['bot'], 71 | task=body['task'], 72 | ) 73 | log_to_terminal(body["socketid"], {"image_selection_result": True}) 74 | 75 | elif body['event'] == 'finalImagesSelected': 76 | # Event when the user submits the final image ranking after completing all rounds 77 | ImageRanking.objects.create( 78 | socket_id=body['socketid'], 79 | final_image_list=body['final_image_list'], 80 | worker_id=body['worker_id'], 81 | assignment_id=body['assignment_id'], 82 | level=body['level'], 83 | hit_id=body['hit_id'], 84 | game_id=body['game_id'], 85 | bot=body['bot'], 86 | target_image=body['target_image'], 87 | score=body['bonus'], 88 | task=body['task'], 89 | ) 90 | 91 | 92 | def ws_disconnect(message): 93 | "Method invoked when the client disconnects the socket connection" 94 | pass 95 | -------------------------------------------------------------------------------- /amt/migrations/0001_initial.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # Generated by Django 1.10.1 on 2017-05-03 08:27 3 | from __future__ import unicode_literals 4 | 5 | import django.contrib.postgres.fields 6 | from django.db import migrations, models 7 | 8 | 9 | class Migration(migrations.Migration): 10 | 11 | initial = True 12 | 13 | dependencies = [ 14 | ] 15 | 16 | operations = [ 17 | migrations.CreateModel( 18 | name='Feedback', 19 | fields=[ 20 | ('id', models.AutoField(auto_created=True, primary_key=True, serialize=False, verbose_name='ID')), 21 | ('understand_question', models.CharField(blank=True, max_length=200, null=True)), 22 | ('understand_image', models.CharField(blank=True, max_length=200, null=True)), 23 | ('fluency',
models.CharField(blank=True, max_length=200, null=True)), 24 | ('detail', models.CharField(blank=True, max_length=200, null=True)), 25 | ('accurate', models.CharField(blank=True, max_length=200, null=True)), 26 | ('consistent', models.CharField(blank=True, max_length=200, null=True)), 27 | ('comments', models.CharField(blank=True, max_length=200, null=True)), 28 | ('worker_id', models.CharField(blank=True, max_length=100, null=True)), 29 | ('assignment_id', models.CharField(blank=True, max_length=100, null=True)), 30 | ('level', models.CharField(blank=True, max_length=100, null=True)), 31 | ('hit_id', models.CharField(blank=True, max_length=100, null=True)), 32 | ('game_id', models.CharField(blank=True, max_length=100, null=True)), 33 | ('created_at', models.DateTimeField(auto_now_add=True)), 34 | ('updated_at', models.DateTimeField(auto_now=True)), 35 | ('bot', models.CharField(blank=True, max_length=100, null=True)), 36 | ('is_active', models.NullBooleanField(default=True)), 37 | ], 38 | ), 39 | migrations.CreateModel( 40 | name='GameRound', 41 | fields=[ 42 | ('id', models.AutoField(auto_created=True, primary_key=True, serialize=False, verbose_name='ID')), 43 | ('socket_id', models.CharField(blank=True, max_length=100, null=True)), 44 | ('worker_id', models.CharField(blank=True, max_length=100, null=True)), 45 | ('assignment_id', models.CharField(blank=True, max_length=100, null=True)), 46 | ('level', models.CharField(blank=True, max_length=100, null=True)), 47 | ('hit_id', models.CharField(blank=True, max_length=100, null=True)), 48 | ('game_id', models.CharField(blank=True, max_length=100, null=True)), 49 | ('round_id', models.CharField(blank=True, max_length=100, null=True)), 50 | ('question', models.CharField(blank=True, max_length=100, null=True)), 51 | ('answer', models.CharField(blank=True, max_length=100, null=True)), 52 | ('history', models.CharField(blank=True, max_length=10000, null=True)), 53 | ('target_image', models.CharField(blank=True, max_length=100, null=True)), 54 | ('bot', models.CharField(blank=True, max_length=100, null=True)), 55 | ('user_picked_image', models.CharField(blank=True, max_length=100, null=True)), 56 | ('created_at', models.DateTimeField(auto_now_add=True)), 57 | ('updated_at', models.DateTimeField(auto_now=True)), 58 | ('is_active', models.NullBooleanField(default=True)), 59 | ], 60 | ), 61 | migrations.CreateModel( 62 | name='ImagePool', 63 | fields=[ 64 | ('id', models.AutoField(auto_created=True, primary_key=True, serialize=False, verbose_name='ID')), 65 | ('pool_id', models.CharField(blank=True, max_length=200, null=True)), 66 | ('caption', models.CharField(blank=True, max_length=1000, null=True)), 67 | ('easy_pool', django.contrib.postgres.fields.ArrayField(base_field=models.CharField(max_length=200), blank=True, size=None)), 68 | ('medium_pool', django.contrib.postgres.fields.ArrayField(base_field=models.CharField(max_length=200), blank=True, size=None)), 69 | ('hard_pool', django.contrib.postgres.fields.ArrayField(base_field=models.CharField(max_length=200), blank=True, size=None)), 70 | ('obj', models.CharField(blank=True, max_length=200, null=True)), 71 | ('target_image', models.CharField(blank=True, max_length=200, null=True)), 72 | ('is_active', models.NullBooleanField(default=False)), 73 | ], 74 | ), 75 | migrations.CreateModel( 76 | name='ImageRanking', 77 | fields=[ 78 | ('id', models.AutoField(auto_created=True, primary_key=True, serialize=False, verbose_name='ID')), 79 | ('socket_id', models.CharField(blank=True, max_length=100, 
null=True)), 80 | ('worker_id', models.CharField(blank=True, max_length=100, null=True)), 81 | ('assignment_id', models.CharField(blank=True, max_length=100, null=True)), 82 | ('level', models.CharField(blank=True, max_length=100, null=True)), 83 | ('hit_id', models.CharField(blank=True, max_length=100, null=True)), 84 | ('game_id', models.CharField(blank=True, max_length=100, null=True)), 85 | ('target_image', models.CharField(blank=True, max_length=100, null=True)), 86 | ('final_image_list', django.contrib.postgres.fields.ArrayField(base_field=models.CharField(max_length=200), blank=True, size=None)), 87 | ('created_at', models.DateTimeField(auto_now_add=True)), 88 | ('updated_at', models.DateTimeField(auto_now=True)), 89 | ('bot', models.CharField(blank=True, max_length=100, null=True)), 90 | ('score', models.FloatField(default=0)), 91 | ('is_active', models.NullBooleanField(default=True)), 92 | ], 93 | ), 94 | ] 95 | -------------------------------------------------------------------------------- /amt/migrations/0002_auto_20170501_1309.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # Generated by Django 1.10.1 on 2017-05-01 13:09 3 | from __future__ import unicode_literals 4 | 5 | import django.contrib.postgres.fields 6 | from django.db import migrations, models 7 | import django.utils.timezone 8 | 9 | 10 | class Migration(migrations.Migration): 11 | 12 | dependencies = [ 13 | ('amt', '0001_initial'), 14 | ] 15 | 16 | operations = [ 17 | migrations.CreateModel( 18 | name='ImagePool', 19 | fields=[ 20 | ('id', models.AutoField(auto_created=True, primary_key=True, serialize=False, verbose_name='ID')), 21 | ('pool_id', models.CharField(blank=True, max_length=200, null=True)), 22 | ('caption', models.CharField(blank=True, max_length=1000, null=True)), 23 | ('easy_pool', django.contrib.postgres.fields.ArrayField(base_field=models.CharField(max_length=200), blank=True, size=None)), 24 | ('medium_pool', django.contrib.postgres.fields.ArrayField(base_field=models.CharField(max_length=200), blank=True, size=None)), 25 | ('hard_pool', django.contrib.postgres.fields.ArrayField(base_field=models.CharField(max_length=200), blank=True, size=None)), 26 | ('obj', models.CharField(blank=True, max_length=200, null=True)), 27 | ('target_image', models.CharField(blank=True, max_length=200, null=True)), 28 | ('is_active', models.BooleanField(default=False)), 29 | ], 30 | ), 31 | migrations.CreateModel( 32 | name='ImageRanking', 33 | fields=[ 34 | ('id', models.AutoField(auto_created=True, primary_key=True, serialize=False, verbose_name='ID')), 35 | ('socket_id', models.CharField(blank=True, max_length=100, null=True)), 36 | ('worker_id', models.CharField(blank=True, max_length=100, null=True)), 37 | ('assignment_id', models.CharField(blank=True, max_length=100, null=True)), 38 | ('level', models.CharField(blank=True, max_length=100, null=True)), 39 | ('hit_id', models.CharField(blank=True, max_length=100, null=True)), 40 | ('game_id', models.CharField(blank=True, max_length=100, null=True)), 41 | ('target_image', models.CharField(blank=True, max_length=100, null=True)), 42 | ('final_image_list', django.contrib.postgres.fields.ArrayField(base_field=models.CharField(max_length=200), blank=True, size=None)), 43 | ('created_at', models.DateTimeField(auto_now_add=True)), 44 | ('updated_at', models.DateTimeField(auto_now=True)), 45 | ], 46 | ), 47 | migrations.RemoveField( 48 | model_name='gameround', 49 | name='fc7_sorted', 50 | ), 51 | 
migrations.RemoveField( 52 | model_name='gameround', 53 | name='human_sorted', 54 | ), 55 | migrations.AddField( 56 | model_name='gameround', 57 | name='created_at', 58 | field=models.DateTimeField(auto_now_add=True, default=django.utils.timezone.now), 59 | preserve_default=False, 60 | ), 61 | migrations.AddField( 62 | model_name='gameround', 63 | name='history', 64 | field=models.CharField(blank=True, max_length=10000, null=True), 65 | ), 66 | migrations.AddField( 67 | model_name='gameround', 68 | name='level', 69 | field=models.CharField(blank=True, max_length=100, null=True), 70 | ), 71 | migrations.AddField( 72 | model_name='gameround', 73 | name='socket_id', 74 | field=models.CharField(blank=True, max_length=100, null=True), 75 | ), 76 | migrations.AddField( 77 | model_name='gameround', 78 | name='updated_at', 79 | field=models.DateTimeField(auto_now=True), 80 | ), 81 | ] 82 | -------------------------------------------------------------------------------- /amt/migrations/0003_feedback.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # Generated by Django 1.10.1 on 2017-05-02 03:27 3 | from __future__ import unicode_literals 4 | 5 | from django.db import migrations, models 6 | 7 | 8 | class Migration(migrations.Migration): 9 | 10 | dependencies = [ 11 | ('amt', '0002_auto_20170501_1309'), 12 | ] 13 | 14 | operations = [ 15 | migrations.CreateModel( 16 | name='Feedback', 17 | fields=[ 18 | ('id', models.AutoField(auto_created=True, primary_key=True, serialize=False, verbose_name='ID')), 19 | ('understand_question', models.CharField(blank=True, max_length=200, null=True)), 20 | ('understand_image', models.CharField(blank=True, max_length=200, null=True)), 21 | ('fluency', models.CharField(blank=True, max_length=200, null=True)), 22 | ('detail', models.CharField(blank=True, max_length=200, null=True)), 23 | ('accurate', models.CharField(blank=True, max_length=200, null=True)), 24 | ('consistent', models.CharField(blank=True, max_length=200, null=True)), 25 | ('comments', models.CharField(blank=True, max_length=200, null=True)), 26 | ('worker_id', models.CharField(blank=True, max_length=100, null=True)), 27 | ('assignment_id', models.CharField(blank=True, max_length=100, null=True)), 28 | ('level', models.CharField(blank=True, max_length=100, null=True)), 29 | ('hit_id', models.CharField(blank=True, max_length=100, null=True)), 30 | ('game_id', models.CharField(blank=True, max_length=100, null=True)), 31 | ('created_at', models.DateTimeField(auto_now_add=True)), 32 | ('updated_at', models.DateTimeField(auto_now=True)), 33 | ], 34 | ), 35 | ] 36 | -------------------------------------------------------------------------------- /amt/migrations/0004_auto_20170502_0422.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # Generated by Django 1.10.1 on 2017-05-02 04:22 3 | from __future__ import unicode_literals 4 | 5 | from django.db import migrations, models 6 | 7 | 8 | class Migration(migrations.Migration): 9 | 10 | dependencies = [ 11 | ('amt', '0003_feedback'), 12 | ] 13 | 14 | operations = [ 15 | migrations.AddField( 16 | model_name='feedback', 17 | name='bot', 18 | field=models.CharField(blank=True, max_length=100, null=True), 19 | ), 20 | migrations.AddField( 21 | model_name='gameround', 22 | name='bot', 23 | field=models.CharField(blank=True, max_length=100, null=True), 24 | ), 25 | migrations.AddField( 26 | model_name='imageranking', 27 | name='bot', 28 
| field=models.CharField(blank=True, max_length=100, null=True), 29 | ), 30 | ] 31 | -------------------------------------------------------------------------------- /amt/migrations/0005_imageranking_score.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # Generated by Django 1.10.1 on 2017-05-02 05:13 3 | from __future__ import unicode_literals 4 | 5 | from django.db import migrations, models 6 | 7 | 8 | class Migration(migrations.Migration): 9 | 10 | dependencies = [ 11 | ('amt', '0004_auto_20170502_0422'), 12 | ] 13 | 14 | operations = [ 15 | migrations.AddField( 16 | model_name='imageranking', 17 | name='score', 18 | field=models.FloatField(default=0), 19 | ), 20 | ] 21 | -------------------------------------------------------------------------------- /amt/migrations/0006_auto_20170502_1849.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # Generated by Django 1.10.1 on 2017-05-02 18:49 3 | from __future__ import unicode_literals 4 | 5 | from django.db import migrations, models 6 | 7 | 8 | class Migration(migrations.Migration): 9 | 10 | dependencies = [ 11 | ('amt', '0005_imageranking_score'), 12 | ] 13 | 14 | operations = [ 15 | migrations.AddField( 16 | model_name='feedback', 17 | name='is_active', 18 | field=models.BooleanField(default=True), 19 | ), 20 | migrations.AddField( 21 | model_name='gameround', 22 | name='is_active', 23 | field=models.BooleanField(default=True), 24 | ), 25 | migrations.AddField( 26 | model_name='imageranking', 27 | name='is_active', 28 | field=models.BooleanField(default=True), 29 | ), 30 | ] 31 | -------------------------------------------------------------------------------- /amt/migrations/0007_auto_20170502_1933.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # Generated by Django 1.10.1 on 2017-05-02 19:33 3 | from __future__ import unicode_literals 4 | 5 | from django.db import migrations, models 6 | 7 | 8 | class Migration(migrations.Migration): 9 | 10 | dependencies = [ 11 | ('amt', '0006_auto_20170502_1849'), 12 | ] 13 | 14 | operations = [ 15 | migrations.AlterField( 16 | model_name='feedback', 17 | name='is_active', 18 | field=models.NullBooleanField(default=True), 19 | ), 20 | migrations.AlterField( 21 | model_name='gameround', 22 | name='is_active', 23 | field=models.NullBooleanField(default=True), 24 | ), 25 | migrations.AlterField( 26 | model_name='imagepool', 27 | name='is_active', 28 | field=models.NullBooleanField(default=False), 29 | ), 30 | migrations.AlterField( 31 | model_name='imageranking', 32 | name='is_active', 33 | field=models.NullBooleanField(default=True), 34 | ), 35 | ] 36 | -------------------------------------------------------------------------------- /amt/migrations/0008_auto_20170707_1013.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # Generated by Django 1.10.1 on 2017-07-07 10:13 3 | from __future__ import unicode_literals 4 | 5 | from django.db import migrations, models 6 | 7 | 8 | class Migration(migrations.Migration): 9 | 10 | dependencies = [ 11 | ('amt', '0007_auto_20170502_1933'), 12 | ] 13 | 14 | operations = [ 15 | migrations.AddField( 16 | model_name='feedback', 17 | name='task', 18 | field=models.CharField(blank=True, max_length=100, null=True), 19 | ), 20 | migrations.AddField( 21 | model_name='gameround', 22 | name='task', 23 | 
field=models.CharField(blank=True, max_length=100, null=True), 24 | ), 25 | migrations.AddField( 26 | model_name='imageranking', 27 | name='task', 28 | field=models.CharField(blank=True, max_length=100, null=True), 29 | ), 30 | ] 31 | -------------------------------------------------------------------------------- /amt/migrations/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/GT-Vision-Lab/GuessWhich/4d883db722a6347c16cb6ddd7b3329b5a5fd439f/amt/migrations/__init__.py -------------------------------------------------------------------------------- /amt/models.py: -------------------------------------------------------------------------------- 1 | # from __future__ import unicode_literals 2 | 3 | from django.db import models 4 | from django.core.urlresolvers import reverse 5 | from django.contrib.postgres.fields import ArrayField 6 | 7 | 8 | class GameRound(models.Model): 9 | """ 10 | Model depicts the game and each round of the Game 11 | """ 12 | socket_id = models.CharField(max_length=100, blank=True, null=True) 13 | worker_id = models.CharField(max_length=100, blank=True, null=True) 14 | assignment_id = models.CharField(max_length=100, blank=True, null=True) 15 | level = models.CharField(max_length=100, blank=True, null=True) 16 | task = models.CharField(max_length=100, blank=True, null=True) 17 | hit_id = models.CharField(max_length=100, blank=True, null=True) 18 | game_id = models.CharField(max_length=100, blank=True, null=True) 19 | round_id = models.CharField(max_length=100, blank=True, null=True) 20 | question = models.CharField(max_length=100, blank=True, null=True) 21 | answer = models.CharField(max_length=100, blank=True, null=True) 22 | history = models.CharField(max_length=10000, blank=True, null=True) 23 | target_image = models.CharField(max_length=100, blank=True, null=True) 24 | bot = models.CharField(max_length=100, blank=True, null=True) 25 | user_picked_image = models.CharField(max_length=100, blank=True, null=True) 26 | created_at = models.DateTimeField(auto_now_add=True) 27 | updated_at = models.DateTimeField(auto_now=True) 28 | is_active = models.NullBooleanField(default=True, blank=True, null=True) 29 | 30 | def __unicode__(self): 31 | return "%s : %s : %s" % (self.assignment_id, self.game_id, self.round_id) 32 | 33 | 34 | class ImageRanking(models.Model): 35 | socket_id = models.CharField(max_length=100, blank=True, null=True) 36 | worker_id = models.CharField(max_length=100, blank=True, null=True) 37 | assignment_id = models.CharField(max_length=100, blank=True, null=True) 38 | level = models.CharField(max_length=100, blank=True, null=True) 39 | task = models.CharField(max_length=100, blank=True, null=True) 40 | hit_id = models.CharField(max_length=100, blank=True, null=True) 41 | game_id = models.CharField(max_length=100, blank=True, null=True) 42 | target_image = models.CharField(max_length=100, blank=True, null=True) 43 | final_image_list = ArrayField(models.CharField(max_length=200), blank=True) 44 | created_at = models.DateTimeField(auto_now_add=True) 45 | updated_at = models.DateTimeField(auto_now=True) 46 | bot = models.CharField(max_length=100, blank=True, null=True) 47 | score = models.FloatField(default=0) 48 | is_active = models.NullBooleanField(default=True, blank=True, null=True) 49 | 50 | def __unicode__(self): 51 | return "%s : %s : %s" % (self.assignment_id, self.game_id, self.level) 52 | 53 | 54 | class ImagePool(models.Model): 55 | pool_id = 
models.CharField(max_length=200, blank=True, null=True) 56 | caption = models.CharField(max_length=1000, blank=True, null=True) 57 | easy_pool = ArrayField(models.CharField(max_length=200), blank=True) 58 | medium_pool = ArrayField(models.CharField(max_length=200), blank=True) 59 | hard_pool = ArrayField(models.CharField(max_length=200), blank=True) 60 | obj = models.CharField(max_length=200, blank=True, null=True) 61 | target_image = models.CharField(max_length=200, blank=True, null=True) 62 | is_active = models.NullBooleanField(default=False, blank=True, null=True) 63 | 64 | 65 | class Feedback(models.Model): 66 | understand_question = models.CharField( 67 | max_length=200, blank=True, null=True) 68 | understand_image = models.CharField(max_length=200, blank=True, null=True) 69 | fluency = models.CharField(max_length=200, blank=True, null=True) 70 | detail = models.CharField(max_length=200, blank=True, null=True) 71 | accurate = models.CharField(max_length=200, blank=True, null=True) 72 | consistent = models.CharField(max_length=200, blank=True, null=True) 73 | comments = models.CharField(max_length=200, blank=True, null=True) 74 | worker_id = models.CharField(max_length=100, blank=True, null=True) 75 | assignment_id = models.CharField(max_length=100, blank=True, null=True) 76 | level = models.CharField(max_length=100, blank=True, null=True) 77 | task = models.CharField(max_length=100, blank=True, null=True) 78 | hit_id = models.CharField(max_length=100, blank=True, null=True) 79 | game_id = models.CharField(max_length=100, blank=True, null=True) 80 | created_at = models.DateTimeField(auto_now_add=True) 81 | updated_at = models.DateTimeField(auto_now=True) 82 | bot = models.CharField(max_length=100, blank=True, null=True) 83 | is_active = models.NullBooleanField(default=True, blank=True, null=True) 84 | -------------------------------------------------------------------------------- /amt/routing.py: -------------------------------------------------------------------------------- 1 | from channels.routing import route, include 2 | from .consumers import ws_message, ws_connect, ws_disconnect 3 | 4 | ws_routing = [ 5 | route("websocket.receive", ws_message), 6 | route("websocket.connect", ws_connect), 7 | ] 8 | 9 | channel_routing = [ 10 | include(ws_routing, path=r"^/chat"), 11 | ] 12 | -------------------------------------------------------------------------------- /amt/sender.py: -------------------------------------------------------------------------------- 1 | from django.conf import settings 2 | from .utils import log_to_terminal 3 | 4 | import os 5 | import pika 6 | import sys 7 | import json 8 | 9 | 10 | def chatbot(input_question, history, image_path, socketid, bot): 11 | connection = pika.BlockingConnection(pika.ConnectionParameters( 12 | host='localhost')) 13 | channel = connection.channel() 14 | 15 | queue_name = 'sl_chatbot_queue' 16 | if bot == "sl": 17 | queue_name = "sl_chatbot_queue" 18 | elif bot == "rl": 19 | queue_name = "rl_chatbot_queue" 20 | 21 | channel.queue_declare(queue=queue_name, durable=True) 22 | message = { 23 | 'image_path': image_path, 24 | 'input_question': input_question, 25 | 'history': history, 26 | 'socketid': socketid, 27 | 'bot': bot, 28 | } 29 | 30 | log_to_terminal( 31 | socketid, {"terminal": "Publishing job to %s" % (queue_name.upper())}) 32 | channel.basic_publish(exchange='', 33 | routing_key=queue_name, 34 | body=json.dumps(message), 35 | properties=pika.BasicProperties( 36 | delivery_mode=2, # make message persistent 37 | )) 38 | 39 | 
print(" [x] Sent %r" % message) 40 | log_to_terminal(socketid, {"terminal": "Job published successfully"}) 41 | connection.close() 42 | -------------------------------------------------------------------------------- /amt/static/css/scrollbar.css: -------------------------------------------------------------------------------- 1 | /*************** SCROLLBAR BASE CSS ***************/ 2 | 3 | .scroll-wrapper { 4 | overflow: hidden !important; 5 | padding: 0 !important; 6 | position: relative; 7 | } 8 | 9 | .scroll-wrapper > .scroll-content { 10 | border: none !important; 11 | box-sizing: content-box !important; 12 | height: auto; 13 | left: 0; 14 | margin: 0; 15 | max-height: none; 16 | max-width: none !important; 17 | overflow: scroll !important; 18 | padding: 0; 19 | position: relative !important; 20 | top: 0; 21 | width: auto !important; 22 | } 23 | 24 | .scroll-wrapper > .scroll-content::-webkit-scrollbar { 25 | height: 0; 26 | width: 0; 27 | } 28 | 29 | .scroll-element { 30 | display: none; 31 | } 32 | .scroll-element, .scroll-element div { 33 | box-sizing: content-box; 34 | } 35 | 36 | .scroll-element.scroll-x.scroll-scrollx_visible, 37 | .scroll-element.scroll-y.scroll-scrolly_visible { 38 | display: block; 39 | } 40 | 41 | .scroll-element .scroll-bar, 42 | .scroll-element .scroll-arrow { 43 | cursor: default; 44 | } 45 | 46 | .scroll-textarea { 47 | border: 1px solid #cccccc; 48 | border-top-color: #999999; 49 | } 50 | .scroll-textarea > .scroll-content { 51 | overflow: hidden !important; 52 | } 53 | .scroll-textarea > .scroll-content > textarea { 54 | border: none !important; 55 | box-sizing: border-box; 56 | height: 100% !important; 57 | margin: 0; 58 | max-height: none !important; 59 | max-width: none !important; 60 | overflow: scroll !important; 61 | outline: none; 62 | padding: 2px; 63 | position: relative !important; 64 | top: 0; 65 | width: 100% !important; 66 | } 67 | .scroll-textarea > .scroll-content > textarea::-webkit-scrollbar { 68 | height: 0; 69 | width: 0; 70 | } 71 | 72 | 73 | 74 | 75 | /*************** SIMPLE INNER SCROLLBAR ***************/ 76 | 77 | .scrollbar-inner > .scroll-element, 78 | .scrollbar-inner > .scroll-element div 79 | { 80 | border: none; 81 | margin: 0; 82 | padding: 0; 83 | position: absolute; 84 | z-index: 10; 85 | } 86 | 87 | .scrollbar-inner > .scroll-element div { 88 | display: block; 89 | height: 100%; 90 | left: 0; 91 | top: 0; 92 | width: 100%; 93 | } 94 | 95 | .scrollbar-inner > .scroll-element.scroll-x { 96 | bottom: 2px; 97 | height: 8px; 98 | left: 0; 99 | width: 100%; 100 | } 101 | 102 | .scrollbar-inner > .scroll-element.scroll-y { 103 | height: 100%; 104 | right: 2px; 105 | top: 0; 106 | width: 8px; 107 | } 108 | 109 | .scrollbar-inner > .scroll-element .scroll-element_outer { 110 | overflow: hidden; 111 | } 112 | 113 | .scrollbar-inner > .scroll-element .scroll-element_outer, 114 | .scrollbar-inner > .scroll-element .scroll-element_track, 115 | .scrollbar-inner > .scroll-element .scroll-bar { 116 | -webkit-border-radius: 8px; 117 | -moz-border-radius: 8px; 118 | border-radius: 8px; 119 | } 120 | 121 | .scrollbar-inner > .scroll-element .scroll-element_track, 122 | .scrollbar-inner > .scroll-element .scroll-bar { 123 | -ms-filter:"progid:DXImageTransform.Microsoft.Alpha(Opacity=40)"; 124 | filter: alpha(opacity=40); 125 | opacity: 0.4; 126 | } 127 | 128 | .scrollbar-inner > .scroll-element .scroll-element_track { background-color: #e0e0e0; } 129 | .scrollbar-inner > .scroll-element .scroll-bar { background-color: #c2c2c2; } 130 | 
.scrollbar-inner > .scroll-element:hover .scroll-bar { background-color: #919191; } 131 | .scrollbar-inner > .scroll-element.scroll-draggable .scroll-bar { background-color: #919191; } 132 | 133 | 134 | /* update scrollbar offset if both scrolls are visible */ 135 | 136 | .scrollbar-inner > .scroll-element.scroll-x.scroll-scrolly_visible .scroll-element_track { left: -12px; } 137 | .scrollbar-inner > .scroll-element.scroll-y.scroll-scrollx_visible .scroll-element_track { top: -12px; } 138 | 139 | 140 | .scrollbar-inner > .scroll-element.scroll-x.scroll-scrolly_visible .scroll-element_size { left: -12px; } 141 | .scrollbar-inner > .scroll-element.scroll-y.scroll-scrollx_visible .scroll-element_size { top: -12px; } 142 | -------------------------------------------------------------------------------- /amt/static/css/style.css: -------------------------------------------------------------------------------- 1 | body{ 2 | font-family: 'Roboto', sans-serif !important; 3 | font-weight: 300 !important; 4 | } 5 | 6 | .message_typing{ 7 | position: fixed; 8 | z-index:100; 9 | bottom: 0 !important; 10 | left: 0px; 11 | width: 100%; 12 | } 13 | 14 | .side-nav a{ 15 | height: 72px; 16 | vertical-align: middle; 17 | text-align: left; 18 | padding-top: 10px; 19 | } 20 | 21 | @media only screen and (min-width: 993px) { 22 | 23 | .messages{ 24 | /*margin-left: 300px !important;*/ 25 | margin-bottom: 64px; 26 | } 27 | } 28 | 29 | .sidebar_li{ 30 | margin-bottom: 10px; 31 | align-items: left; 32 | vertical-align: middle !important; 33 | } 34 | 35 | #chat-list{ 36 | margin-top: 60px; 37 | } 38 | 39 | .circle{ 40 | vertical-align: middle; 41 | } 42 | 43 | .chat_bot_row, .chat_bot{ 44 | padding: 0 0.5rem; 45 | } 46 | 47 | 48 | @media only screen and (max-width: 993px) { 49 | 50 | .gcam_image{ 51 | /*margin-left: 300px !important;*/ 52 | margin-bottom: 64px; 53 | } 54 | } 55 | -------------------------------------------------------------------------------- /amt/static/images/Preloader_2.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/GT-Vision-Lab/GuessWhich/4d883db722a6347c16cb6ddd7b3329b5a5fd439f/amt/static/images/Preloader_2.gif -------------------------------------------------------------------------------- /amt/static/images/abot.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/GT-Vision-Lab/GuessWhich/4d883db722a6347c16cb6ddd7b3329b5a5fd439f/amt/static/images/abot.png -------------------------------------------------------------------------------- /amt/static/images/bot.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/GT-Vision-Lab/GuessWhich/4d883db722a6347c16cb6ddd7b3329b5a5fd439f/amt/static/images/bot.jpg -------------------------------------------------------------------------------- /amt/static/images/new_bot.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/GT-Vision-Lab/GuessWhich/4d883db722a6347c16cb6ddd7b3329b5a5fd439f/amt/static/images/new_bot.png -------------------------------------------------------------------------------- /amt/static/js/queryparam_helper.js: -------------------------------------------------------------------------------- 1 | // =========================================================== 2 | // A bunch of helper functions for an AMT interface 3 | // 
=========================================================== 4 | 5 | // Note: QS is short for query string, i.e., the = stuff after the 6 | // question mark (?) in the URL. 7 | var NUM_QS_ZEROPAD = 2; // Number of digits for QS parameters 8 | 9 | function gup(name) { 10 | var regexS = "[\\?&]" + name + "=([^&#]*)"; 11 | var regex = new RegExp(regexS); 12 | var tmpURL = window.location.href; 13 | var results = regex.exec(tmpURL); 14 | if (results == null) { 15 | return ""; 16 | } else { 17 | return results[1]; 18 | } 19 | } 20 | 21 | function decode(strToDecode) { 22 | return unescape(strToDecode.replace(/\+/g, " ")); 23 | } 24 | 25 | function get_random_int(min, max) { 26 | return Math.floor(Math.random() * (max - min)) + min; 27 | } 28 | 29 | function zero_pad(num, numZeros) { 30 | var n = Math.abs(num); 31 | var zeros = Math.max(0, numZeros - Math.floor(n).toString().length ); 32 | var zeroString = Math.pow(10,zeros).toString().substr(1); 33 | if (num < 0) { 34 | zeroString = '-' + zeroString; 35 | } 36 | return zeroString+n; 37 | } 38 | 39 | function collect_ordered_QS(param_name, pad) { 40 | var array = []; // Store all the data 41 | var done = false; 42 | var i = 1; 43 | var name = ''; 44 | var val = ''; 45 | while (done == false) { 46 | name = param_name + zero_pad(i, pad); 47 | val = decode(gup(name)); 48 | 49 | if (val == "") { 50 | done = true; 51 | } else { 52 | array.push(val); 53 | } 54 | i += 1; 55 | } 56 | return array; 57 | } 58 | -------------------------------------------------------------------------------- /amt/static/slack_sound.mp3: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/GT-Vision-Lab/GuessWhich/4d883db722a6347c16cb6ddd7b3329b5a5fd439f/amt/static/slack_sound.mp3 -------------------------------------------------------------------------------- /amt/templates/amt/base.html: -------------------------------------------------------------------------------- 1 | {% load static %} 2 | 3 | 4 | 5 | Guess Which 6 | 7 | 8 | 9 | 10 | 11 | 12 | 13 | 14 | 15 | 16 | 17 | 18 | 19 | 20 | 21 | 22 | 28 | 29 | 42 | 43 | 153 | 233 | 234 |
235 | 236 |
237 | {% block header%} 238 | 239 | 246 | 247 | 252 | 253 | 268 | 269 | {% endblock %} 270 | 271 | {% block sidebar %} 272 | {% endblock %} 273 |
274 | 275 | 276 | {% block images %} 277 | 278 | {% endblock %} 279 | 280 | {% block messages %} 281 | 282 | {% endblock %} 283 | 284 | {% block message_typing %} 285 | 286 | {% endblock %} 287 | 288 | 289 | {% block intro %} 290 | 291 | {% endblock %} 292 | 336 | 348 | 349 | -------------------------------------------------------------------------------- /amt/templates/amt/feedback.html: -------------------------------------------------------------------------------- 1 |
2 |
3 | 4 |
5 |

Please provide some feedback on Alice.

6 |
7 |
8 |

Alice seemed to understand the questions that were asked.

9 |

10 | 11 | 12 | 13 | 14 | 15 | 16 | 17 | 18 | 19 | 20 |

21 |
22 |
23 | 24 |
25 |
26 |

Alice seemed to understand the images it was looking at.

27 |

28 | 29 | 30 | 31 | 32 | 33 | 34 | 35 | 36 | 37 | 38 |

39 |
40 |
41 | 42 |
43 |
44 |

Alice was fluent in its responses (regardless of whether it was accurate)

45 |

46 | 47 | 48 | 49 | 50 | 51 | 52 | 53 | 54 | 55 | 56 |

57 |
58 |
59 | 60 |
61 |
62 |

Alice's responses were detailed and informative (regardless of whether it was accurate)

63 |

64 | 65 | 66 | 67 | 68 | 69 | 70 | 71 | 72 | 73 | 74 |

75 |
76 |
77 | 78 |
79 |
80 |

Alice provided accurate responses to questions.

81 |

82 | 83 | 84 | 85 | 86 | 87 | 88 | 89 | 90 | 91 | 92 |

93 |
94 |
95 | 96 |
97 |
98 |

Alice was trying to be consistent with what it said in the past.

99 |

100 | 101 | 102 | 103 | 104 | 105 | 106 | 107 | 108 | 109 | 110 |

111 |
112 |
113 | 114 |
115 | Do you have comments or feedback? Please let us know. 116 | 117 |
118 |
119 | 120 | 121 |
122 |
123 | 124 | -------------------------------------------------------------------------------- /amt/templates/amt/index.html: -------------------------------------------------------------------------------- 1 | {% extends 'amt/base.html' %} {% load static %} {% load range %} {% block sidebar %} {% endblock %} {% block images %} 2 |
3 |
4 | {% for i in img_list %} 5 |
6 | 7 | 8 | 9 | 10 | 11 | 12 | 13 |
14 | {% endfor %} 15 |
16 |
17 |
18 |
19 |
20 |
21 | 22 |
23 |
24 | {{ bot_intro_message }} 25 |
26 |
27 |
28 | 29 |
30 | 31 |
32 |
33 | Target image description: {{ caption }} 34 |
35 |
36 |
37 | 38 |
39 | 40 |
41 |
42 | Based on your understanding of the image description, pick the image that you think is the most relevant. 43 |
44 |
45 |
46 | 47 |
48 |
49 | 50 |
51 |
52 |
53 |
54 |
55 | 56 | {% endblock %} 57 | 58 | {% block message_typing %} 59 | 119 | 190 | 191 | 197 | 198 | 220 | {% include 'amt/modal.html' %} 221 | {% include 'amt/mturk.html' %} 222 | 355 | 356 | 371 | 379 | {% endblock %} 380 | -------------------------------------------------------------------------------- /amt/templates/amt/intro.html: -------------------------------------------------------------------------------- 1 | {% extends 'amt/base.html' %} 2 | {% load static %} 3 | {% load range %} 4 | 5 | {% block sidebar %} 6 | {% endblock %} 7 | 8 | {% block images %} 9 | {% endblock %} 10 | 11 | {% block message_typing %} 12 | {% endblock %} 13 | 14 | {% block intro %} 15 | 24 |
25 |
26 |

Welcome!

27 | 28 |

Hi, My name is Vicki. I am an AI.

29 |

I have been assigned one target image. I am not allowed to show you the image, but you can ask me questions about it. When ready, submit one of the images on the left as your best guess. We can only make 3 guesses. I will try to answer your questions, but I am not perfect. But I hope we can work together to find the image as quickly as possible!

30 |
31 | 32 |
33 |
34 | {% csrf_token %} 35 |
36 |
37 | 42 | 43 |
44 |
45 | 46 |
47 |
48 | 49 |
50 |
51 |
52 |
53 | 54 |
55 | 60 | 61 | {% endblock %} 62 | -------------------------------------------------------------------------------- /amt/templates/amt/loader.html: -------------------------------------------------------------------------------- 1 |
2 |
3 |
4 |
5 |
6 |
7 |
8 |
9 |
10 |
11 |
12 |
13 |
14 | 15 | -------------------------------------------------------------------------------- /amt/templates/amt/modal.html: -------------------------------------------------------------------------------- 1 | 14 | 15 | 25 | 26 | -------------------------------------------------------------------------------- /amt/templates/amt/mturk.html: -------------------------------------------------------------------------------- 1 |
2 | {% csrf_token %} 3 | 4 | 5 | 6 | 7 | 8 |
9 | 10 | -------------------------------------------------------------------------------- /amt/templates/amt/plot.html: -------------------------------------------------------------------------------- 1 | 23 |
24 | 106 | -------------------------------------------------------------------------------- /amt/templatetags/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/GT-Vision-Lab/GuessWhich/4d883db722a6347c16cb6ddd7b3329b5a5fd439f/amt/templatetags/__init__.py -------------------------------------------------------------------------------- /amt/templatetags/range.py: -------------------------------------------------------------------------------- 1 | from django import template 2 | 3 | register = template.Library() 4 | 5 | 6 | @register.filter(name='range') 7 | def times(number): 8 | return range(number) 9 | -------------------------------------------------------------------------------- /amt/urls.py: -------------------------------------------------------------------------------- 1 | from django.conf.urls import url 2 | from . import views 3 | 4 | urlpatterns = [ 5 | url(r'^feedback$', views.feedback, name='feedback'), 6 | url(r'^$', views.home, name='home'), 7 | ] 8 | -------------------------------------------------------------------------------- /amt/utils.py: -------------------------------------------------------------------------------- 1 | from django.conf import settings 2 | from channels import Group 3 | 4 | import h5py 5 | import time 6 | import cPickle 7 | import pdb 8 | 9 | from .constants import ( 10 | POOL_IMAGES_URL, 11 | AWS_ACCESS_KEY_ID, 12 | AWS_SECRET_ACCESS_KEY, 13 | QUALIFICATION_TYPE_ID, 14 | AMT_HOSTNAME, 15 | ) 16 | 17 | import json 18 | import os 19 | import random 20 | import traceback 21 | import numpy as np 22 | import boto.mturk.connection 23 | 24 | 25 | def log_to_terminal(socketid, message): 26 | Group(socketid).send({"text": json.dumps(message)}) 27 | 28 | 29 | def get_pool_images(pool_id=1): 30 | with open('data/pools.json', 'r') as f: 31 | pool_data = json.load(f) 32 | return pool_data[pool_id] 33 | 34 | 35 | def get_url_of_image(image_id): 36 | return POOL_IMAGES_URL + str(image_id) + ".jpg" 37 | 38 | 39 | def fc7_sort(imfeats, prev_sort, chosen_imID): 40 | target_f = imfeats[chosen_imID] 41 | dist_vec = np.zeros(len(prev_sort), dtype='float32') 42 | 43 | for i in range(len(prev_sort)): 44 | dist_vec[i] = np.linalg.norm(imfeats[prev_sort[i]] - target_f) 45 | 46 | sort_ind = np.argsort(dist_vec).tolist() 47 | new_sort = [] 48 | for i in range(len(sort_ind)): 49 | new_sort.append(prev_sort[sort_ind[i]]) 50 | return new_sort 51 | 52 | 53 | def create_qualifications(): 54 | mtc = boto.mturk.connection.MTurkConnection( 55 | aws_access_key_id=AWS_ACCESS_KEY_ID, 56 | aws_secret_access_key=AWS_SECRET_ACCESS_KEY, 57 | host=AMT_HOSTNAME, 58 | debug=2 59 | ) 60 | 61 | qualification = mtc.create_qualification_type( 62 | name='Some Qualification Name', 63 | description='Qualification to avoid bias in responses by preventing workers who have already completed a HIT from doing subsequent HITs.', 64 | status='Active', 65 | auto_granted=True, 66 | auto_granted_value=0 67 | ) 68 | 69 | 70 | def set_qualification_to_worker(worker_id=None, qualification_value=0): 71 | mtc = boto.mturk.connection.MTurkConnection( 72 | aws_access_key_id=AWS_ACCESS_KEY_ID, 73 | aws_secret_access_key=AWS_SECRET_ACCESS_KEY, 74 | host=AMT_HOSTNAME, 75 | debug=2 76 | ) 77 | 78 | mtc.assign_qualification(QUALIFICATION_TYPE_ID, worker_id, 79 | value=qualification_value, 80 | send_notification=False) 81 | 82 | 83 | def updated_qualification_to_worker(worker_id=None, qualification_value=1): 84 | mtc = 
boto.mturk.connection.MTurkConnection( 85 | aws_access_key_id=AWS_ACCESS_KEY_ID, 86 | aws_secret_access_key=AWS_SECRET_ACCESS_KEY, 87 | host=AMT_HOSTNAME, 88 | debug=2 89 | ) 90 | 91 | mtc.update_qualification_score( 92 | QUALIFICATION_TYPE_ID, worker_id, qualification_value) 93 | -------------------------------------------------------------------------------- /amt/views.py: -------------------------------------------------------------------------------- 1 | from django.conf import settings 2 | from django.shortcuts import render 3 | from django.http import JsonResponse 4 | from django.db.models import Sum 5 | 6 | from .utils import ( 7 | log_to_terminal, 8 | get_pool_images, 9 | get_url_of_image, 10 | set_qualification_to_worker 11 | ) 12 | 13 | from .models import Feedback, ImageRanking 14 | 15 | import constants as constants 16 | 17 | import sys 18 | import uuid 19 | import os 20 | import traceback 21 | import random 22 | import urllib2 23 | import redis 24 | import json 25 | 26 | 27 | r = redis.StrictRedis(host='localhost', port=6379, db=0) 28 | 29 | 30 | class PoolImage: 31 | """ 32 | Class to store the details related to a particular pool 33 | """ 34 | 35 | def __init__(self, image_path, score, img_id, rank): 36 | self.image_path = image_path 37 | self.score = score 38 | self.img_id = img_id 39 | self.rank = rank 40 | 41 | 42 | def home(request, template_name="amt/index.html"): 43 | """ 44 | Method called when the game starts 45 | """ 46 | worker_id = request.GET.get('workerId', "default") 47 | 48 | if worker_id == "default": 49 | # default is used for the debug mode 50 | disabled = True 51 | if worker_id != "default": 52 | disabled = False 53 | try: 54 | # Set the qualification so that worker cannot do the HIT again 55 | set_qualification_to_worker( 56 | worker_id=worker_id, qualification_value=1) 57 | print "Success: Setting Qualification for worker ", worker_id 58 | except Exception as e: 59 | print "Error: Cannot Set Qualification for worker ", worker_id 60 | 61 | ''' 62 | Possible values of level: 63 | - easy 64 | - medium 65 | - hard 66 | ''' 67 | level = request.GET.get("level", "medium") 68 | hitId = request.GET.get('hitId') 69 | assignmentId = request.GET.get('assignmentId') 70 | turkSubmitTo = request.GET.get('turkSubmitTo') 71 | bot = request.GET.get('bot') 72 | 73 | socketid = uuid.uuid4() 74 | 75 | # Fetch previous games played by this user 76 | prev_games_of_this_hit = ImageRanking.objects.filter( 77 | assignment_id=assignmentId, worker_id=worker_id, hit_id=hitId, bot=bot) 78 | prev_game_ids = prev_games_of_this_hit.values_list('game_id', flat=True) 79 | prev_game_ids = [int(i) for i in prev_game_ids] 80 | 81 | try: 82 | # Compute the next GameID to show the new pool of images to play with 83 | next_game_id = max(prev_game_ids) 84 | except: 85 | # If exception, start from the very beginning i.e game_id=0 86 | next_game_id = 0 87 | 88 | if next_game_id == 10: 89 | next_game_id = 9 90 | # If this is the last game, show the modal to fill feedback after this 91 | # game 92 | show_feedback_modal = True 93 | else: 94 | show_feedback_modal = False 95 | 96 | # Get the pool details for the particular game_id 97 | image_pool = get_pool_images(pool_id=int(next_game_id)) 98 | 99 | # Fetch the images of particular difficulty from the pool json data 100 | image_list = image_pool['pools'][level][:20] 101 | image_list = sorted(image_list) 102 | img_list = [] 103 | for i in xrange(len(image_list)): 104 | img_path = constants.POOL_IMAGES_URL + str(image_list[i]) + ".jpg" 105 | img = 
PoolImage(img_path, 0, image_list[i], i+1) 106 | img_list.append(img) 107 | image_path_list = [constants.POOL_IMAGES_URL + 108 | str(s) + ".jpg" for s in image_list][:20] 109 | target_image = image_pool['target'] 110 | target_image_url = get_url_of_image(target_image) 111 | # Assign a score of 0 to each of the images 112 | scores = [0] * 20 113 | caption = image_pool['gen_caption'] 114 | 115 | r.set("target_{}".format(str(socketid)), target_image) 116 | 117 | intro_message = random.choice(constants.BOT_INTORDUCTION_MESSAGE) 118 | 119 | # Compute the cumulative bonus for the previous games this worker has played 120 | total_bonus_so_far = ImageRanking.objects.filter( 121 | assignment_id=assignmentId, worker_id=worker_id, hit_id=hitId, bot=bot).aggregate(score=Sum('score')) 122 | 123 | # If this is the first game for the user, set the total bonus to 0 124 | if total_bonus_so_far['score'] is None: 125 | total_bonus_so_far = 0 126 | else: 127 | total_bonus_so_far = total_bonus_so_far['score'] 128 | 129 | return render(request, template_name, { 130 | "socketid": socketid, 131 | "bot_intro_message": intro_message, 132 | "img_list": img_list, 133 | "target_image": target_image_url, 134 | "target_image_id": target_image, 135 | "scores": scores, 136 | "img_id_list": json.dumps(image_list), 137 | "caption": caption, 138 | "max_rounds": constants.NUMBER_OF_ROUNDS_IN_A_GAME, 139 | "num_of_games_in_a_hit": constants.NUMBER_OF_GAMES_IN_A_HIT, 140 | "disabled": disabled, 141 | "total_bonus_so_far": total_bonus_so_far, 142 | "max_game_bonus": constants.MAX_BONUS_IN_A_GAME, 143 | "bonus_deduction_on_each_click": constants.BONUS_DEDUCTION_FOR_EACH_CLICK, 144 | "next_game_id": next_game_id, 145 | "bonus_for_correct_image_after_each_round": constants.BONUS_FOR_CORRECT_IMAGE_AFTER_EACH_ROUND, 146 | "show_feedback_modal": show_feedback_modal, 147 | }) 148 | 149 | 150 | def feedback(request): 151 | """ 152 | View to collect the feedback provided by the Mechanical Turk Workers 153 | """ 154 | hitId = request.POST.get('hitId') 155 | assignmentId = request.POST.get('assignmentId') 156 | workerId = request.POST.get('workerId') 157 | understand_question = request.POST.get('understand_question') 158 | understand_image = request.POST.get('understand_image') 159 | fluency = request.POST.get('fluency') 160 | detail = request.POST.get('detail') 161 | accurate = request.POST.get('accurate') 162 | consistent = request.POST.get('consistent') 163 | comments = request.POST.get('comments') 164 | game_id = request.POST.get('game_id') 165 | level = request.POST.get('level') 166 | bot = request.POST.get('bot') 167 | task = request.POST.get('task') 168 | 169 | Feedback.objects.create( 170 | hit_id=hitId, 171 | assignment_id=assignmentId, 172 | worker_id=workerId, 173 | understand_question=understand_question, 174 | understand_image=understand_image, 175 | fluency=fluency, 176 | detail=detail, 177 | accurate=accurate, 178 | consistent=consistent, 179 | comments=comments, 180 | level=level, 181 | game_id=game_id, 182 | bot=bot, 183 | task=task 184 | ) 185 | return JsonResponse({'success': True}) 186 | -------------------------------------------------------------------------------- /chatbot/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/GT-Vision-Lab/GuessWhich/4d883db722a6347c16cb6ddd7b3329b5a5fd439f/chatbot/__init__.py -------------------------------------------------------------------------------- /chatbot/dataloader.lua: 
-------------------------------------------------------------------------------- 1 | -- script to load the dataloader for meta-rnn 2 | require 'hdf5' 3 | require 'xlua' 4 | local utils = require 'utils' 5 | 6 | local dataloader = {}; 7 | 8 | -- read the data 9 | -- params: object itself, command line options, 10 | -- subset of data to load (train, test, val) 11 | function dataloader:initialize(opt, subsets) 12 | -- read additional info like dictionary, etc 13 | print('DataLoader loading json file: ', opt.input_json) 14 | info = utils.readJSON(opt.input_json); 15 | for key, value in pairs(info) do dataloader[key] = value; end 16 | 17 | -- add <START> and <END> to vocabulary 18 | count = 0; 19 | for _ in pairs(dataloader['word2ind']) do count = count + 1; end 20 | dataloader['word2ind']['<START>'] = count + 1; 21 | dataloader['word2ind']['<END>'] = count + 2; 22 | count = count + 2; 23 | dataloader.vocabSize = count; 24 | print(string.format('Vocabulary size (with <START>,<END>): %d\n', count)); 25 | 26 | -- construct ind2word 27 | local ind2word = {}; 28 | for word, ind in pairs(dataloader['word2ind']) do 29 | ind2word[ind] = word; 30 | end 31 | dataloader['ind2word'] = ind2word; 32 | 33 | -- read questions, answers and options 34 | print('DataLoader loading h5 file: ', opt.input_ques_h5) 35 | local quesFile = hdf5.open(opt.input_ques_h5, 'r'); 36 | 37 | print('DataLoader loading h5 file: ', opt.input_img_h5) 38 | local imgFile = hdf5.open(opt.input_img_h5, 'r'); 39 | -- number of threads 40 | self.numThreads = {}; 41 | 42 | for _, dtype in pairs(subsets) do 43 | -- read question related information 44 | self[dtype..'_ques'] = quesFile:read('ques_'..dtype):all(); 45 | self[dtype..'_ques_len'] = quesFile:read('ques_length_'..dtype):all(); 46 | self[dtype..'_ques_count'] = quesFile:read('ques_count_'..dtype):all(); 47 | 48 | -- read answer related information 49 | self[dtype..'_ans'] = quesFile:read('ans_'..dtype):all(); 50 | self[dtype..'_ans_len'] = quesFile:read('ans_length_'..dtype):all(); 51 | self[dtype..'_ans_ind'] = quesFile:read('ans_index_'..dtype):all():long(); 52 | 53 | -- read image list, if image features are needed 54 | if opt.useIm then 55 | print('Reading image features ..') 56 | local imgFeats = imgFile:read('/images_'..dtype):all(); 57 | 58 | -- Normalize the image feature (if needed) 59 | if opt.img_norm == 1 then 60 | print('Normalizing image features..') 61 | local nm = torch.sqrt(torch.sum(torch.cmul(imgFeats, imgFeats), 2)); 62 | imgFeats = torch.cdiv(imgFeats, nm:expandAs(imgFeats)):float(); 63 | end 64 | self[dtype..'_img_fv'] = imgFeats; 65 | -- TODO: make it 1 indexed in processing code 66 | -- currently zero indexed, adjust manually 67 | self[dtype..'_img_pos'] = quesFile:read('img_pos_'..dtype):all():long(); 68 | self[dtype..'_img_pos'] = self[dtype..'_img_pos'] + 1; 69 | end 70 | 71 | -- print information for data type 72 | print(string.format('%s:\n\tNo. of threads: %d\n\tNo. of rounds: %d'.. 
73 | '\n\tMax ques len: %d'..'\n\tMax ans len: %d\n', 74 | dtype, self[dtype..'_ques']:size(1), 75 | self[dtype..'_ques']:size(2), 76 | self[dtype..'_ques']:size(3), 77 | self[dtype..'_ans']:size(2))); 78 | -- record some stats 79 | if dtype == 'train' then 80 | self.numTrainThreads = self['train_ques']:size(1); 81 | self.numThreads['train'] = self.numTrainThreads; 82 | end 83 | if dtype == 'test' then 84 | self.numTestThreads = self['test_ques']:size(1); 85 | self.numThreads['test'] = self.numTestThreads; 86 | end 87 | if dtype == 'val' then 88 | self.numValThreads = self['val_ques']:size(1); 89 | self.numThreads['val'] = self.numValThreads; 90 | end 91 | 92 | -- record the options only for test and val 93 | if dtype == 'val' or dtype == 'test' then 94 | self[dtype..'_opt'] = quesFile:read('opt_'..dtype):all():long(); 95 | self[dtype..'_opt_len'] = quesFile:read('opt_length_'..dtype):all(); 96 | self[dtype..'_opt_list'] = quesFile:read('opt_list_'..dtype):all(); 97 | self[dtype..'_opt_prob'] = torch.Tensor(self[dtype..'_opt_len']:size()); 98 | end 99 | 100 | -- assume similar stats across multiple data subsets 101 | -- maximum number of questions per image, ideally 10 102 | self.maxQuesCount = self[dtype..'_ques']:size(2); 103 | -- maximum length of question 104 | self.maxQuesLen = self[dtype..'_ques']:size(3); 105 | -- maximum length of answer 106 | self.maxAnsLen = self[dtype..'_ans']:size(2); 107 | -- number of options, if read 108 | if self[dtype..'_opt'] then 109 | self.numOptions = self[dtype..'_opt']:size(3); 110 | end 111 | 112 | -- if history is needed 113 | if opt.useHistory then 114 | self[dtype..'_cap'] = quesFile:read('cap_'..dtype):all():long(); 115 | self[dtype..'_cap_len'] = quesFile:read('cap_length_'..dtype):all(); 116 | end 117 | end 118 | -- done reading, close files 119 | quesFile:close(); 120 | imgFile:close(); 121 | 122 | -- take desired flags/values from opt 123 | self.useBi = opt.useBi; 124 | self.useHistory = opt.useHistory; 125 | self.useIm = opt.useIm; 126 | self.separateCaption = opt.separateCaption; 127 | self.maxHistoryLen = opt.maxHistoryLen or 60; 128 | 129 | -- prepareDataset for training 130 | for _, dtype in pairs(subsets) do self:prepareDataset(dtype); end 131 | 132 | end 133 | 134 | -- method to prepare questions and answers for retrieval 135 | -- questions : right align 136 | -- answers : prefix with <START> and <END> 137 | function dataloader:prepareDataset(dtype) 138 | -- right align the questions 139 | print('Right aligning questions: '..dtype); 140 | self[dtype..'_ques_fwd'] = utils.rightAlign(self[dtype..'_ques'], 141 | self[dtype..'_ques_len']); 142 | 143 | -- if bidirectional model is needed store backward 144 | if self.useBi then 145 | self[dtype..'_ques_bwd'] = nn.SeqReverseSequence(3) 146 | :forward(self[dtype..'_ques']:double()); 147 | -- convert back to LongTensor 148 | self[dtype..'_ques_bwd'] = self[dtype..'_ques_bwd']:long(); 149 | end 150 | 151 | -- if separate captions are needed 152 | if self.separateCaption then self:processHistoryCaption(dtype); 153 | else if self.useHistory then self:processHistory(dtype); end 154 | end 155 | -- prefix options with <START> and <END>, if not train 156 | if dtype ~= 'train' then self:processOptions(dtype); end 157 | -- process answers 158 | self:processAnswers(dtype); 159 | end 160 | 161 | -- process answers 162 | function dataloader:processAnswers(dtype) 163 | -- prefix answers with <START>, <END>; adjust answer lengths 164 | local answers = self[dtype..'_ans']; 165 | local ansLen = self[dtype..'_ans_len']; 166 | 167 | local numConvs = 
answers:size(1); 168 | local numRounds = answers:size(2); 169 | local maxAnsLen = answers:size(3); 170 | 171 | local decodeIn = torch.LongTensor(numConvs, numRounds, maxAnsLen+1):zero(); 172 | local decodeOut = torch.LongTensor(numConvs, numRounds, maxAnsLen+1):zero(); 173 | 174 | -- decodeIn begins with <START> 175 | decodeIn[{{}, {}, 1}] = self.word2ind['<START>']; 176 | 177 | -- go over each answer and modify 178 | local endTokenId = self.word2ind['<END>']; 179 | for thId = 1, numConvs do 180 | for roundId = 1, numRounds do 181 | local length = ansLen[thId][roundId]; 182 | 183 | -- only if nonzero 184 | if length > 0 then 185 | --decodeIn[thId][roundId][1] = startToken; 186 | decodeIn[thId][roundId][{{2, length + 1}}] 187 | = answers[thId][roundId][{{1, length}}]; 188 | decodeOut[thId][roundId][{{1, length}}] 189 | = answers[thId][roundId][{{1, length}}]; 190 | decodeOut[thId][roundId][length+1] = endTokenId; 191 | else 192 | print(string.format('Warning: empty answer at (%d %d %d)', 193 | thId, roundId, length)) 194 | end 195 | end 196 | end 197 | 198 | self[dtype..'_ans_len'] = self[dtype..'_ans_len'] + 1; 199 | self[dtype..'_ans_in'] = decodeIn; 200 | self[dtype..'_ans_out'] = decodeOut; 201 | end 202 | 203 | -- process caption as history 204 | function dataloader:processHistory(dtype) 205 | local captions = self[dtype..'_cap']; 206 | local questions = self[dtype..'_ques']; 207 | local quesLen = self[dtype..'_ques_len']; 208 | local capLen = self[dtype..'_cap_len']; 209 | local maxQuesLen = questions:size(3); 210 | 211 | local answers = self[dtype..'_ans']; 212 | local ansLen = self[dtype..'_ans_len']; 213 | local numConvs = answers:size(1); 214 | local numRounds = answers:size(2); 215 | local maxAnsLen = answers:size(3); 216 | 217 | -- chop off caption to maxQuesLen 218 | local history = torch.LongTensor(numConvs, numRounds, 219 | maxQuesLen+maxAnsLen):zero(); 220 | local histLen = torch.LongTensor(numConvs, numRounds):zero(); 221 | 222 | -- go over each question and append it with answer 223 | for thId = 1, numConvs do 224 | local lenC = capLen[thId]; 225 | for roundId = 1, numRounds do 226 | local lenH; -- length of history 227 | if roundId == 1 then 228 | -- first round has caption as history 229 | history[thId][roundId] 230 | = captions[thId][{{1, maxQuesLen + maxAnsLen}}]; 231 | lenH = math.min(lenC, maxQuesLen + maxAnsLen); 232 | else 233 | -- other rounds have previous Q + A as history 234 | local lenQ = quesLen[thId][roundId-1]; 235 | local lenA = ansLen[thId][roundId-1]; 236 | 237 | if lenQ > 0 then 238 | history[thId][roundId][{{1, lenQ}}] 239 | = questions[thId][roundId-1][{{1, lenQ}}]; 240 | end 241 | if lenA > 0 then 242 | history[thId][roundId][{{lenQ + 1, lenQ + lenA}}] 243 | = answers[thId][roundId-1][{{1, lenA}}]; 244 | end 245 | lenH = lenA + lenQ; 246 | end 247 | -- save the history length 248 | histLen[thId][roundId] = lenH; 249 | end 250 | end 251 | 252 | -- right align history and then save 253 | print('Right aligning history: '..dtype); 254 | self[dtype..'_hist'] = utils.rightAlign(history, histLen); 255 | self[dtype..'_hist_len'] = histLen; 256 | end 257 | 258 | -- process history and captions separately 259 | function dataloader:processHistoryCaption(dtype) 260 | local questions = self[dtype..'_ques']; 261 | local quesLen = self[dtype..'_ques_len']; 262 | local maxQuesLen = questions:size(3); 263 | 264 | local answers = self[dtype..'_ans']; 265 | local ansLen = self[dtype..'_ans_len']; 266 | local numConvs = answers:size(1); 267 | local numRounds = answers:size(2); 268 | local 
maxAnsLen = answers:size(3); 269 | 270 | -- chop off caption to maxQuesLen 271 | local history = torch.LongTensor(numConvs, numRounds, 272 | (self.maxQuesCount - 1) * (maxQuesLen + maxAnsLen)):zero(); 273 | local histLen = torch.LongTensor(numConvs, numRounds):zero(); 274 | 275 | -- go over each question and append it with answer 276 | for thId = 1, numConvs do 277 | -- current round as history for next rounds 278 | local runHistLen = 0; 279 | for prevId = 1, numRounds-1 do 280 | -- current Q and A as history 281 | local lenQ = quesLen[thId][prevId]; 282 | local lenA = ansLen[thId][prevId]; 283 | local curHistLen = lenA + lenQ; 284 | local curHistory = torch.LongTensor(curHistLen); 285 | curHistory[{{1, lenQ}}] = questions[thId][prevId][{{1, lenQ}}]; 286 | curHistory[{{lenQ + 1, curHistLen}}] = answers[thId][prevId][{{1, lenA}}]; 287 | 288 | for roundId = prevId + 1, numRounds do 289 | history[thId][roundId][{{runHistLen+1, runHistLen+curHistLen}}] 290 | = curHistory; 291 | end 292 | 293 | -- increase the running count of history length 294 | runHistLen = runHistLen + curHistLen; 295 | histLen[thId][prevId + 1] = runHistLen; 296 | end 297 | end 298 | 299 | -- right align history and then save 300 | print('Right aligning history: '..dtype); 301 | self[dtype..'_hist'] = utils.rightAlign(history, histLen); 302 | self[dtype..'_hist_len'] = histLen; 303 | 304 | -- right align captions and then save 305 | print('Right aligning captions: '..dtype); 306 | self[dtype..'_cap'] = utils.rightAlign(self[dtype..'_cap'], 307 | self[dtype..'_cap_len']); 308 | end 309 | 310 | -- process options 311 | function dataloader:processOptions(dtype) 312 | local lengths = self[dtype..'_opt_len']; 313 | local answers = self[dtype..'_ans']; 314 | local maxAnsLen = answers:size(3); 315 | local answers = self[dtype..'_opt_list']; 316 | local numConvs = answers:size(1); 317 | 318 | local ansListLen = answers:size(1); 319 | local decodeIn = torch.LongTensor(ansListLen, maxAnsLen + 1):zero(); 320 | local decodeOut = torch.LongTensor(ansListLen, maxAnsLen + 1):zero(); 321 | 322 | -- decodeIn begins with <START> 323 | decodeIn[{{}, 1}] = self.word2ind['<START>']; 324 | 325 | -- go over each answer and modify 326 | local endTokenId = self.word2ind['<END>']; 327 | for id = 1, ansListLen do 328 | -- print progress for number of images 329 | if id % 100 == 0 then 330 | xlua.progress(id, numConvs); 331 | end 332 | local length = lengths[id]; 333 | 334 | -- only if nonzero 335 | if length > 0 then 336 | decodeIn[id][{{2, length + 1}}] = answers[id][{{1, length}}]; 337 | 338 | decodeOut[id][{{1, length}}] = answers[id][{{1, length}}]; 339 | decodeOut[id][length + 1] = endTokenId; 340 | else 341 | print(string.format('Warning: empty answer for %s at %d', 342 | dtype, id)) 343 | end 344 | end 345 | 346 | self[dtype..'_opt_len'] = self[dtype..'_opt_len'] + 1; 347 | self[dtype..'_opt_in'] = decodeIn; 348 | self[dtype..'_opt_out'] = decodeOut; 349 | 350 | collectgarbage(); 351 | end 352 | 353 | -- method to grab the next training batch 354 | function dataloader.getTrainBatch(self, params, batchSize) 355 | local size = batchSize or params.batchSize; 356 | local inds = torch.LongTensor(size):random(1, params.numTrainThreads); 357 | 358 | -- Index question, answers, image features for batch 359 | return self:getIndexData(inds, params, 'train'); 360 | --output['batch_ques'], output['answer_in'], output['answer_out']; 361 | end 362 | 363 | -- method to grab the next test/val batch, for evaluation of a given size 364 | function dataloader.getTestBatch(self, 
startId, params, dtype) 365 | local batchSize = params.batchSize * 4; 366 | -- get the next start id and fill up current indices till then 367 | local nextStartId; 368 | if dtype == 'val' then 369 | nextStartId = math.min(self.numValThreads+1, startId + batchSize); 370 | end 371 | if dtype == 'test' then 372 | nextStartId = math.min(self.numTestThreads+1, startId + batchSize); 373 | end 374 | 375 | -- dumb way to get range (complains if cudatensor is default) 376 | local inds = torch.LongTensor(nextStartId - startId); 377 | for ii = startId, nextStartId - 1 do inds[ii - startId + 1] = ii; end 378 | --local inds = torch.range(startId, nextStartId - 1):long(); 379 | 380 | -- Index question, answers, image features for batch 381 | local batchOutput = self:getIndexData(inds, params, dtype); 382 | local optionOutput = self:getIndexOption(inds, params, dtype); 383 | 384 | -- merge both the tables and return 385 | for key, value in pairs(optionOutput) do batchOutput[key] = value; end 386 | 387 | return batchOutput, nextStartId; 388 | end 389 | 390 | -- get batch from data subset given the indices 391 | function dataloader.getIndexData(self, inds, params, dtype) 392 | -- get the question lengths 393 | local batchQuesLen = self[dtype..'_ques_len']:index(1, inds); 394 | local maxQuesLen = torch.max(batchQuesLen); 395 | -- get questions 396 | local quesFwd = self[dtype..'_ques_fwd']:index(1, inds) 397 | [{{}, {}, {-maxQuesLen, -1}}]; 398 | local quesBwd; 399 | if self.useBi then 400 | quesBwd = self[dtype..'_ques_bwd']:index(1, inds) 401 | [{{}, {}, {-maxQuesLen, -1}}]; 402 | end 403 | 404 | local history; 405 | if self.useHistory then 406 | local batchHistLen = self[dtype..'_hist_len']:index(1, inds); 407 | local maxHistLen = math.min(torch.max(batchHistLen), self.maxHistoryLen); 408 | history = self[dtype..'_hist']:index(1, inds) 409 | [{{}, {}, {-maxHistLen, -1}}]; 410 | end 411 | 412 | local caption; 413 | if self.separateCaption then 414 | local batchCapLen = self[dtype..'_cap_len']:index(1, inds); 415 | local maxCapLen = torch.max(batchCapLen); 416 | caption = self[dtype..'_cap']:index(1, inds)[{{}, {-maxCapLen, -1}}]; 417 | end 418 | 419 | local imgFeats; 420 | if self.useIm then 421 | local imgInds = self[dtype..'_img_pos']:index(1, inds); 422 | imgFeats = self[dtype..'_img_fv']:index(1, imgInds); 423 | end 424 | 425 | -- get the answer lengths 426 | local batchAnsLen = self[dtype..'_ans_len']:index(1, inds); 427 | local maxAnsLen = torch.max(batchAnsLen); 428 | -- answer labels (decode input and output) 429 | local answerIn = self[dtype..'_ans_in'] 430 | :index(1, inds)[{{}, {}, {1, maxAnsLen}}]; 431 | local answerOut = self[dtype..'_ans_out'] 432 | :index(1, inds)[{{}, {}, {1, maxAnsLen}}]; 433 | local answerInd = self[dtype..'_ans_ind']:index(1, inds); 434 | 435 | local output = {}; 436 | if params.gpuid >= 0 then 437 | -- TODO: instead store everything on gpu to save time 438 | output['ques_fwd'] = quesFwd:cuda(); 439 | output['answer_in'] = answerIn:cuda(); 440 | output['answer_out'] = answerOut:cuda(); 441 | output['answer_ind'] = answerInd:cuda(); 442 | if quesBwd then output['ques_bwd'] = quesBwd:cuda(); end 443 | if history then output['hist'] = history:cuda(); end 444 | if caption then output['cap'] = caption:cuda(); end 445 | if imgFeats then output['img_feat'] = imgFeats:cuda(); end 446 | else 447 | output['ques_fwd'] = quesFwd:contiguous(); 448 | output['answer_in'] = answerIn:contiguous(); 449 | output['answer_out'] = answerOut:contiguous(); 450 | output['answer_ind'] = 
answerInd:contiguous(); 451 | if quesBwd then output['ques_bwd'] = quesBwd:contiguous(); end 452 | if history then output['hist'] = history:contiguous(); end 453 | if caption then output['cap'] = caption:contiguous(); end 454 | if imgFeats then output['img_feat'] = imgFeats:contiguous(); end 455 | end 456 | 457 | return output; 458 | end 459 | 460 | -- get batch from options given the indices 461 | function dataloader.getIndexOption(self, inds, params, dtype) 462 | local optionIn, optionOut, optionProb, answerProb; 463 | 464 | local optInds = self[dtype..'_opt']:index(1, inds); 465 | local indVector = optInds:view(-1); 466 | 467 | local batchOptLen = self[dtype..'_opt_len']:index(1, indVector); 468 | local maxOptLen = torch.max(batchOptLen); 469 | 470 | optionIn = self[dtype..'_opt_in']:index(1, indVector); 471 | optionIn = optionIn:view(optInds:size(1), optInds:size(2), 472 | optInds:size(3), -1); 473 | optionIn = optionIn[{{}, {}, {}, {1, maxOptLen}}]; 474 | 475 | optionOut = self[dtype..'_opt_out']:index(1, indVector); 476 | optionOut = optionOut:view(optInds:size(1), optInds:size(2), 477 | optInds:size(3), -1); 478 | optionOut = optionOut[{{}, {}, {}, {1, maxOptLen}}]; 479 | 480 | if self[dtype..'_opt_prob'] then 481 | optionProb = self[dtype..'_opt_prob']:index(1, indVector); 482 | optionProb = optionProb:viewAs(optInds); 483 | 484 | -- also get the answer probabilities 485 | local answerInds = self[dtype..'_ans_ind']:index(1, inds); 486 | indVector = answerInds:view(-1); 487 | answerProb = self[dtype..'_opt_prob']:index(1, indVector); 488 | answerProb = answerProb:viewAs(answerInds); 489 | end 490 | 491 | local output = {}; 492 | if params.gpuid >= 0 then 493 | output['option_in'] = optionIn:cuda(); 494 | output['option_out'] = optionOut:cuda(); 495 | if optionProb then 496 | output['option_prob'] = optionProb:cuda(); 497 | output['answer_prob'] = answerProb:cuda(); 498 | end 499 | else 500 | output['option_in'] = optionIn:contiguous(); 501 | output['option_out'] = optionOut:contiguous(); 502 | if optionProb then 503 | output['option_prob'] = optionProb:contiguous(); 504 | output['answer_prob'] = answerProb:contiguous(); 505 | end 506 | end 507 | 508 | return output; 509 | end 510 | 511 | return dataloader; 512 | -------------------------------------------------------------------------------- /chatbot/im-hist-enc-dec-answerer/lstm.lua: -------------------------------------------------------------------------------- 1 | -- lstm based models 2 | local lstm = {}; 3 | 4 | function lstm.buildModel(params) 5 | -- return encoder, nil, decoder 6 | return lstm:EncoderNet(params), lstm:DecoderNet(params); 7 | end 8 | 9 | function lstm.EncoderNet(self, params) 10 | local dropout = params.dropout or 0.2; 11 | -- Use `nngraph` 12 | nn.FastLSTM.usenngraph = true; 13 | 14 | -- encoder network 15 | local enc = nn.Sequential(); 16 | 17 | -- create the two branches 18 | local concat = nn.ConcatTable(); 19 | 20 | -- word branch, along with embedding layer 21 | self.wordEmbed = nn.LookupTableMaskZero(params.vocabSize, params.embedSize); 22 | local wordBranch = nn.Sequential():add(nn.SelectTable(1)):add(self.wordEmbed); 23 | 24 | -- language model 25 | enc.rnnLayers = {}; 26 | for layer = 1, params.numLayers do 27 | local inputSize = (layer==1) and (params.embedSize) 28 | or params.rnnHiddenSize; 29 | enc.rnnLayers[layer] = nn.SeqLSTM(inputSize, params.rnnHiddenSize); 30 | enc.rnnLayers[layer]:maskZero(); 31 | 32 | wordBranch:add(enc.rnnLayers[layer]); 33 | end 34 | wordBranch:add(nn.Select(1, -1)); 35 | 
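-- Shape sketch for the word branch above, assuming the default opts
-- (B = batchSize, 10 = maxQuesCount rounds per image, T = question
-- length): the branch receives a T x (B*10) LongTensor of question
-- tokens, the masked lookup table embeds it to T x (B*10) x embedSize,
-- the stacked SeqLSTMs preserve that shape, and nn.Select(1, -1) keeps
-- only the last timestep, leaving a (B*10) x rnnHiddenSize summary of
-- each question. The history branch below mirrors this structure, the
-- image branch is a dropout + linear projection, and the three outputs
-- are joined and fused by another SeqLSTM further down.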
36 | -- make clones for embed layer 37 | local qEmbedNet = self.wordEmbed:clone('weight', 'bias', 'gradWeight', 'gradBias'); 38 | local hEmbedNet = self.wordEmbed:clone('weight', 'bias', 'gradWeight', 'gradBias'); 39 | 40 | -- create two branches 41 | local histBranch = nn.Sequential() 42 | :add(nn.SelectTable(2)) 43 | :add(hEmbedNet); 44 | enc.histLayers = {}; 45 | -- number of layers to read the history 46 | for layer = 1, params.numLayers do 47 | local inputSize = (layer == 1) and params.embedSize 48 | or params.rnnHiddenSize; 49 | enc.histLayers[layer] = nn.SeqLSTM(inputSize, params.rnnHiddenSize); 50 | enc.histLayers[layer]:maskZero(); 51 | 52 | histBranch:add(enc.histLayers[layer]); 53 | end 54 | histBranch:add(nn.Select(1, -1)); 55 | 56 | -- select words and image only 57 | local imageBranch = nn.Sequential() 58 | :add(nn.SelectTable(3)) 59 | :add(nn.Dropout(0.5)) 60 | :add(nn.Linear(params.imgFeatureSize, params.imgEmbedSize)) 61 | 62 | -- add concatTable and join 63 | concat:add(wordBranch) 64 | concat:add(histBranch) 65 | concat:add(imageBranch) 66 | enc:add(concat); 67 | 68 | -- another concat table 69 | local concat2 = nn.ConcatTable(); 70 | 71 | enc:add(nn.JoinTable(1, 1)) 72 | -- change the view of the data 73 | -- always split it back wrt batch size and then do transpose 74 | enc:add(nn.View(-1, params.maxQuesCount, 2*params.rnnHiddenSize + params.imgEmbedSize)); 75 | enc:add(nn.Transpose({1, 2})); 76 | enc:add(nn.View(params.maxQuesCount, -1, 2*params.rnnHiddenSize + params.imgEmbedSize)) 77 | enc:add(nn.SeqLSTM(2*params.rnnHiddenSize + params.imgEmbedSize, params.rnnHiddenSize)) 78 | enc:add(nn.Transpose({1, 2})); 79 | enc:add(nn.View(-1, params.rnnHiddenSize)) 80 | 81 | return enc; 82 | end 83 | 84 | function lstm.DecoderNet(self, params) 85 | local dropout = params.dropout or 0.2; 86 | -- Use `nngraph` 87 | nn.FastLSTM.usenngraph = true; 88 | 89 | -- decoder network 90 | local dec = nn.Sequential(); 91 | -- use the same embedding for both encoder and decoder lstm 92 | local embedNet = self.wordEmbed:clone('weight', 'bias', 'gradWeight', 'gradBias'); 93 | dec:add(embedNet); 94 | 95 | dec.rnnLayers = {}; 96 | -- check if decoder has different hidden size 97 | local hiddenSize = (params.ansHiddenSize ~= 0) and params.ansHiddenSize 98 | or params.rnnHiddenSize; 99 | for layer = 1, params.numLayers do 100 | local inputSize = (layer == 1) and params.embedSize or hiddenSize; 101 | dec.rnnLayers[layer] = nn.SeqLSTM(inputSize, hiddenSize); 102 | dec.rnnLayers[layer]:maskZero(); 103 | 104 | dec:add(dec.rnnLayers[layer]); 105 | end 106 | dec:add(nn.Sequencer(nn.MaskZero( 107 | nn.Linear(hiddenSize, params.vocabSize), 1))) 108 | dec:add(nn.Sequencer(nn.MaskZero(nn.LogSoftMax(), 1))) 109 | 110 | return dec; 111 | end 112 | ------------------------------------------------------------------------------- 113 | -- transfer the hidden state from encoder to decoder 114 | function lstm.forwardConnect(encOut, enc, dec, seqLen) 115 | for ii = 1, #enc.rnnLayers do 116 | dec.rnnLayers[ii].userPrevOutput = enc.rnnLayers[ii].output[seqLen]; 117 | dec.rnnLayers[ii].userPrevCell = enc.rnnLayers[ii].cell[seqLen]; 118 | end 119 | 120 | -- last layer gets output gradients 121 | dec.rnnLayers[#enc.rnnLayers].userPrevOutput = encOut; 122 | end 123 | 124 | -- transfer gradients from decoder to encoder 125 | function lstm.backwardConnect(enc, dec) 126 | -- borrow gradients from decoder 127 | for ii = 1, #dec.rnnLayers do 128 | enc.rnnLayers[ii].userNextGradCell = dec.rnnLayers[ii].userGradPrevCell; 129 | 
enc.rnnLayers[ii].gradPrevOutput = dec.rnnLayers[ii].userGradPrevOutput; 130 | end 131 | 132 | -- return the gradients for the last layer 133 | return dec.rnnLayers[#enc.rnnLayers].userGradPrevOutput; 134 | end 135 | return lstm; 136 | -------------------------------------------------------------------------------- /chatbot/im-hist-enc-dec-answerer/specificModel.lua: -------------------------------------------------------------------------------- 1 | -- Implements methods for specific type of model 2 | -- Need following methods, in a table 3 | -- a. buildSpecificModels(self, params) 4 | -- Used to build the particular model 5 | -- b. forwardBackward(self, batch) 6 | -- Used while performing training 7 | -- c. retrieveBatch(self, batch) 8 | -- Used to perform retrieval on a batch 9 | local specificModel = {}; 10 | local utils = require 'utils'; 11 | 12 | function specificModel:buildSpecificModel(params) 13 | -- build the model - encoder, decoder and answerNet 14 | local lm; 15 | local modelFile = string.format('%s/%s', params.model_name, 16 | params.languageModel); 17 | lm = require(modelFile); 18 | enc, dec = lm.buildModel(params); 19 | self.forwardConnect = lm.forwardConnect; 20 | self.backwardConnect = lm.backwardConnect; 21 | 22 | return enc, dec; 23 | end 24 | 25 | function specificModel:forwardBackward(batch, onlyForward) 26 | local onlyForward = onlyForward or false; 27 | local batchQues = batch['ques_fwd']; 28 | local batchHist = batch['hist']; 29 | local answerIn = batch['answer_in']; 30 | local answerOut = batch['answer_out']; 31 | local imgFeats = batch['img_feat']; 32 | 33 | -- resize to treat all rounds similarly 34 | -- transpose for timestep first 35 | batchQues = batchQues:view(-1, batchQues:size(3)):t(); 36 | batchHist = batchHist:view(-1, batchHist:size(3)):t(); 37 | answerIn = answerIn:view(-1, answerIn:size(3)):t(); 38 | answerOut = answerOut:view(-1, answerOut:size(3)):t(); 39 | 40 | -- process the image features based on the question (replicate features) 41 | imgFeats = imgFeats:view(-1, 1, self.params.imgFeatureSize); 42 | imgFeats = imgFeats:repeatTensor(1, self.params.maxQuesCount, 1); 43 | imgFeats = imgFeats:view(-1, self.params.imgFeatureSize); 44 | 45 | -- forward pass 46 | -- print('batchQues', #batchQues) 47 | -- print('imgFeats', #imgFeats) 48 | -- print('batchHist', #batchHist) 49 | local encOut = self.encoder:forward({batchQues, batchHist, imgFeats}); 50 | -- print('encOut', #encOut) 51 | -- os.exit() 52 | -- forward connect encoder and decoder 53 | self.forwardConnect(encOut, self.encoder, self.decoder, batchQues:size(1)); 54 | local decOut = self.decoder:forward(answerIn); 55 | local curLoss = self.criterion:forward(decOut, answerOut); 56 | 57 | -- return only if forward is needed 58 | 59 | -- backward pass 60 | if onlyForward ~= true then 61 | do 62 | local gradCriterionOut = self.criterion:backward(decOut, answerOut); 63 | self.decoder:backward(answerIn, gradCriterionOut); 64 | --backward connect decoder and encoder 65 | local gradDecOut = self.backwardConnect(self.encoder, self.decoder); 66 | self.encoder:backward({batchQues, batchHist, imgFeats}, gradDecOut) 67 | end 68 | end 69 | 70 | return curLoss; 71 | end 72 | 73 | function specificModel:retrieveBatch(batch) 74 | local batchQues = batch['ques_fwd']; 75 | local batchHist = batch['hist']; 76 | local answerIn = batch['answer_in']; 77 | local answerOut = batch['answer_out']; 78 | local optionIn = batch['option_in']; 79 | local optionOut = batch['option_out']; 80 | local gtPosition = 
batch['answer_ind']:view(-1, 1); 81 | local imgFeats = batch['img_feat']; 82 | 83 | -- resize to treat all rounds similarly 84 | -- transpose for time step first 85 | batchQues = batchQues:view(-1, batchQues:size(3)):t(); 86 | batchHist = batchHist:view(-1, batchHist:size(3)):t(); 87 | answerIn = answerIn:view(-1, answerIn:size(3)):t(); 88 | answerOut = answerOut:view(-1, answerOut:size(3)):t(); 89 | optionIn = optionIn:view(-1, optionIn:size(3), optionIn:size(4)); 90 | optionOut = optionOut:view(-1, optionOut:size(3), optionOut:size(4)); 91 | optionIn = optionIn:transpose(1, 2):transpose(2, 3); 92 | optionOut = optionOut:transpose(1, 2):transpose(2, 3); 93 | 94 | -- process the image features based on the question (replicate features) 95 | imgFeats = imgFeats:view(-1, 1, self.params.imgFeatureSize); 96 | imgFeats = imgFeats:repeatTensor(1, self.params.maxQuesCount, 1); 97 | imgFeats = imgFeats:view(-1, self.params.imgFeatureSize); 98 | 99 | -- forward pass 100 | local encOut = self.encoder:forward({batchQues, batchHist, imgFeats}); 101 | local batchSize = batchQues:size(2); 102 | -- tensor holds the likelihood for all the options 103 | local optionLhood = torch.Tensor(self.params.numOptions, batchSize); 104 | 105 | -- repeat for each option and get gt rank 106 | for opId = 1, self.params.numOptions do 107 | -- forward connect encoder and decoder 108 | self.forwardConnect(encOut, self.encoder, self.decoder, batchQues:size(1)); 109 | 110 | local curOptIn = optionIn[opId]; 111 | local curOptOut = optionOut[opId]; 112 | 113 | local decOut = self.decoder:forward(curOptIn); 114 | -- compute the probabilities for each answer, based on its tokens 115 | optionLhood[opId] = utils.computeLhood(curOptOut, decOut); 116 | end 117 | 118 | -- return the ranks for this batch 119 | return utils.computeRanks(optionLhood:t(), gtPosition); 120 | end 121 | 122 | -- code for answer generation 123 | function specificModel:encoderPass(batch) 124 | local batchQues = batch['ques_fwd']; 125 | local batchHist = batch['hist']; 126 | local imgFeats = batch['img_feat']; 127 | 128 | -- resize to treat all rounds similarly 129 | -- transpose for timestep first 130 | batchQues = batchQues:view(-1, batchQues:size(3)):t(); 131 | batchHist = batchHist:view(-1, batchHist:size(3)):t(); 132 | 133 | -- process the image features based on the question (replicate features) 134 | imgFeats = imgFeats:view(-1, 1, self.params.imgFeatureSize); 135 | imgFeats = imgFeats:repeatTensor(1, self.params.maxQuesCount, 1); 136 | imgFeats = imgFeats:view(-1, self.params.imgFeatureSize); 137 | 138 | -- forward pass 139 | local encOut = self.encoder:forward({batchQues, batchHist, imgFeats}); 140 | -- forward connect encoder and decoder 141 | self.forwardConnect(encOut, self.encoder, self.decoder, batchQues:size(1)); 142 | return encOut 143 | end 144 | 145 | function specificModel:forwardBackwardReinforce(imgFeats, ques, hist, ansIn, r_t, params) 146 | assert(imgFeats) -- make sure this is not `nil`, this is GT 147 | 148 | -- encoder forward pass, just sanity checking 149 | local encOut = self.encoder:forward({ques, hist, imgFeats}); 150 | 151 | -- forward connect encoder and decoder 152 | self.forwardConnect(encOut, self.encoder, self.decoder, ques:size(1)); 153 | local decOut = self.decoder:forward(ansIn); 154 | 155 | -- compute loss/reward per round 156 | local numRounds = 10; -- TODO increase a round 157 | local gradInput = torch.Tensor(30, numRounds, 7547):zero(); 158 | 159 | -- compute RL gradients 160 | local maxAnsLen = 30; 161 | for i = 1, 
numRounds do 162 | for j = 1, maxAnsLen-1 do 163 | if ansIn[j][i] ~= 0 and ansIn[j+1][i] == 0 then 164 | gradInput[j][i][7547] = r_t[i] 165 | elseif ansIn[j+1][i] ~= 0 then 166 | gradInput[j][i][ansIn[j+1][i]] = r_t[i] 167 | end 168 | end 169 | end 170 | 171 | -- backprop! 172 | self.decoder:backward(ansIn, gradInput); 173 | 174 | --backward connect decoder and encoder 175 | local gradDecOut = self.backwardConnect(self.encoder, self.decoder); 176 | self.encoder:backward({ques, hist, imgFeats}, gradDecOut) 177 | collectgarbage() 178 | end 179 | 180 | function specificModel:forwardBackwardAnnealedReinforceBatched(batch, r_t, params) 181 | local hist = batch[1] 182 | local imgFeats = batch[2] 183 | local ques = batch[3] 184 | local ansIn = batch[4] 185 | local ansOut = batch[5] 186 | 187 | assert(imgFeats) -- make sure this is not `nil`, this is GT 188 | 189 | local numSLRounds = params.numSLRounds 190 | 191 | -- encoder forward pass, just sanity checking 192 | local encOut = self.encoder:forward({ques, hist, imgFeats}); 193 | 194 | -- forward connect encoder and decoder 195 | self.forwardConnect(encOut, self.encoder, self.decoder, ques:size(1)); 196 | local decOut = self.decoder:forward(ansIn); 197 | 198 | local maxRounds = 10; 199 | local numRounds = params.batchSize * maxRounds; -- TODO increase a round 200 | local gradInput = torch.Tensor(30, numRounds, 7547):zero(); 201 | 202 | if numSLRounds ~= 0 then 203 | local SLRoundInds, SLRoundIndsIdx = torch.LongTensor(numSLRounds * params.batchSize):zero(), 1; 204 | for i = 1, numRounds do 205 | if i % maxRounds ~= 0 and i % maxRounds <= numSLRounds then 206 | SLRoundInds[SLRoundIndsIdx] = i 207 | SLRoundIndsIdx = SLRoundIndsIdx + 1 208 | end 209 | end 210 | 211 | -- compute SL gradients 212 | local seqLoss = self.criterion:forward(decOut:index(2, SLRoundInds), ansOut:index(2, SLRoundInds)) 213 | gradInput:indexCopy(2, SLRoundInds, self.criterion:backward(decOut:index(2, SLRoundInds), ansOut:index(2, SLRoundInds))) 214 | end 215 | 216 | -- compute RL gradients 217 | local maxAnsLen = 30; 218 | for i = 1, numRounds do 219 | if i % maxRounds > numSLRounds then 220 | for j = 1, maxAnsLen-1 do 221 | if ansIn[j][i] ~= 0 and ansIn[j+1][i] == 0 then 222 | gradInput[j][i][7547] = r_t[i] 223 | elseif ansIn[j+1][i] ~= 0 then 224 | gradInput[j][i][ansIn[j+1][i]] = r_t[i] 225 | end 226 | end 227 | end 228 | end 229 | 230 | -- backprop! 
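-- A sketch of the gradient convention used here (the hard-coded 30 and
-- 7547 appear to be this model's maxAnsLen and vocabulary size, with
-- <END> as the last index): gradInput is maxAnsLen x (batchSize*10) x
-- vocabSize. For the first numSLRounds rounds of each dialog it carries
-- the usual cross-entropy gradients from the criterion; for the
-- remaining rounds the per-round reward r_t[i] is written at the index
-- of the next token of the answer fed to the decoder, or at the <END>
-- index once the sequence terminates, so the backward pass below scales
-- each token's log-probability gradient by that round's reward signal
-- (the sign convention is set by how r_t is computed upstream).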
231 | self.decoder:backward(ansIn, gradInput); 232 | 233 | --backward connect decoder and encoder 234 | local gradDecOut = self.backwardConnect(self.encoder, self.decoder); 235 | self.encoder:backward({ques, hist, imgFeats}, gradDecOut) 236 | collectgarbage() 237 | return seqLoss 238 | end 239 | 240 | function specificModel:multitaskReinforceForwardBackward(batch, r_t, params) 241 | local hist = batch[1] 242 | local imgFeats = batch[2] 243 | local ques = batch[3] 244 | local ansIn = batch[4] 245 | local ansOut = batch[5] 246 | local ansSample = batch[6]:t() 247 | 248 | assert(imgFeats) -- make sure this is not `nil`, this is GT 249 | 250 | -- encoder forward pass, just sanity checking 251 | local encOut = self.encoder:forward({ques, hist, imgFeats}); 252 | 253 | -- forward connect encoder and decoder 254 | self.forwardConnect(encOut, self.encoder, self.decoder, ques:size(1)); 255 | local decOut = self.decoder:forward(ansIn); 256 | 257 | local maxRounds = 10; 258 | local numRounds = params.batchSize * maxRounds; -- TODO increase a round 259 | local gradInput = torch.Tensor(30, numRounds, 7547):zero(); 260 | 261 | -- compute SL gradients 262 | local seqLoss = self.criterion:forward(decOut, ansOut) 263 | local gradInputSL = self.criterion:backward(decOut, ansOut) 264 | 265 | -- compute RL gradients 266 | local gradInputRL = torch.Tensor(30, numRounds, 7547):zero(); 267 | local maxAnsLen = 30; 268 | for i = 1, numRounds do 269 | for j = 2, maxAnsLen-1 do 270 | if ansSample[j][i] ~= 0 then 271 | gradInputRL[j-1][i][ansSample[j][i]] = r_t[i] 272 | end 273 | end 274 | end 275 | 276 | -- print(params) 277 | gradInput = gradInputSL + params.lambda * gradInputRL; 278 | 279 | -- backprop! 280 | self.decoder:backward(ansIn, gradInput); 281 | 282 | --backward connect decoder and encoder 283 | local gradDecOut = self.backwardConnect(self.encoder, self.decoder); 284 | self.encoder:backward({ques, hist, imgFeats}, gradDecOut) 285 | collectgarbage() 286 | return seqLoss 287 | end 288 | return specificModel; 289 | -------------------------------------------------------------------------------- /chatbot/optim_updates.lua: -------------------------------------------------------------------------------- 1 | --Author: Andrej Karpathy https://github.com/karpathy 2 | --Project: neuraltalk2 https://github.com/karpathy/neuraltalk2 3 | --Slightly modified by Xiao Lin for initial values of rmsprop. 4 | 5 | -- optim, simple as it should be, written from scratch. 
That's how I roll 6 | 7 | function sgd(x, dx, lr) 8 | x:add(-lr, dx) 9 | end 10 | 11 | function sgdm(x, dx, lr, alpha, state) 12 | -- sgd with momentum, standard update 13 | if not state.v then 14 | state.v = x.new(#x):zero() 15 | end 16 | state.v:mul(alpha) 17 | state.v:add(lr, dx) 18 | x:add(-1, state.v) 19 | end 20 | 21 | function sgdmom(x, dx, lr, alpha, state) 22 | -- sgd momentum, uses nesterov update (reference: http://cs231n.github.io/neural-networks-3/#sgd) 23 | if not state.m then 24 | state.m = x.new(#x):zero() 25 | state.tmp = x.new(#x) 26 | end 27 | state.tmp:copy(state.m) 28 | state.m:mul(alpha):add(-lr, dx) 29 | x:add(-alpha, state.tmp) 30 | x:add(1+alpha, state.m) 31 | end 32 | 33 | function adagrad(x, dx, lr, epsilon, state) 34 | if not state.m then 35 | state.m = x.new(#x):zero() 36 | state.tmp = x.new(#x) 37 | end 38 | -- calculate new mean squared values 39 | state.m:addcmul(1.0, dx, dx) 40 | -- perform update 41 | state.tmp:sqrt(state.m):add(epsilon) 42 | x:addcdiv(-lr, dx, state.tmp) 43 | end 44 | 45 | -- rmsprop implementation, simple as it should be 46 | function rmsprop(x, dx, state) 47 | local alpha = state.alpha or 0.99; 48 | local learningRate = state.learningRate or 1e-2; 49 | local epsilon = state.epsilon or 1e-8; 50 | if not state.m then 51 | state.m = x.new(#x):zero() 52 | state.tmp = x.new(#x) 53 | end 54 | -- calculate new (leaky) mean squared values 55 | state.m:mul(alpha) 56 | state.m:addcmul(1.0-alpha, dx, dx) 57 | -- perform update 58 | state.tmp:sqrt(state.m):add(epsilon) 59 | x:addcdiv(-learningRate, dx, state.tmp) 60 | end 61 | 62 | function adam(x, dx, state) 63 | local beta1 = state.beta1 or 0.9 64 | local beta2 = state.beta2 or 0.999 65 | local epsilon = state.epsilon or 1e-8 66 | local lr = state.learningRate or 1e-2; 67 | 68 | if not state.m then 69 | -- Initialization 70 | state.t = 0 71 | -- Exponential moving average of gradient values 72 | state.m = x.new(#dx):zero() 73 | -- Exponential moving average of squared gradient values 74 | state.v = x.new(#dx):zero() 75 | -- A tmp tensor to hold the sqrt(v) + epsilon 76 | state.tmp = x.new(#dx):zero() 77 | end 78 | 79 | -- Decay the first and second moment running average coefficient 80 | state.m:mul(beta1):add(1-beta1, dx) 81 | state.v:mul(beta2):addcmul(1-beta2, dx, dx) 82 | state.tmp:copy(state.v):sqrt():add(epsilon) 83 | 84 | state.t = state.t + 1 85 | local biasCorrection1 = 1 - beta1^state.t 86 | local biasCorrection2 = 1 - beta2^state.t 87 | local stepSize = lr * math.sqrt(biasCorrection2)/biasCorrection1 88 | 89 | -- perform update 90 | x:addcdiv(-stepSize, state.m, state.tmp) 91 | end 92 | -------------------------------------------------------------------------------- /chatbot/opts.lua: -------------------------------------------------------------------------------- 1 | cmd = torch.CmdLine() 2 | cmd:text() 3 | cmd:text('Options') 4 | -- Data input settings 5 | cmd:option('-input_img_h5','data/visdial_0.5/data_img.h5','h5file path with image feature') 6 | cmd:option('-input_ques_h5','data/visdial_0.5/chat_processed_data.h5','h5file file with preprocessed questions') 7 | cmd:option('-input_json','data/visdial_0.5/chat_processed_params.json','json path with info and vocab') 8 | 9 | cmd:option('-save_path', 'models/', 'path to save the model and checkpoints') 10 | cmd:option('-model_name', 'im-hist-enc-dec-answerer', 'Name of the model to use for answering') 11 | 12 | cmd:option('-img_norm', 1, 'normalize the image feature. 
1=yes, 0=no') 13 | cmd:option('-load_path_a', '', 'path to saved answerer model') 14 | 15 | -- model params 16 | cmd:option('-metaHiddenSize', 100, 'Size of the hidden layer for meta-rnn'); 17 | cmd:option('-multiEmbedSize', 1024, 'Size of multimodal embedding for q+i') 18 | cmd:option('-imgEmbedSize', 300, 'Size of the multimodal embeddings'); 19 | cmd:option('-imgFeatureSize', 4096, 'Size of the deep image feature'); 20 | cmd:option('-embedSize', 300, 'Size of input word embeddings') 21 | cmd:option('-rnnHiddenSize', 512, 'Size of the hidden language rnn in each layer') 22 | cmd:option('-ansHiddenSize', 0, 'Size of the hidden language rnn in each layer for answers') 23 | cmd:option('-maxHistoryLen', 60, 'Maximum history to consider when using appended qa pairs'); 24 | cmd:option('-numLayers', 2, 'number of the rnn layer') 25 | cmd:option('-languageModel', 'lstm', 'rnn to use for language model, lstm | gru') 26 | cmd:option('-bidirectional', 0, 'Bi-directional language model') 27 | cmd:option('-metric', 'llh', 'Metric to use for retrieval, llh | mi') 28 | cmd:option('-lambda', '1.0', 'Factor for marginalized probability for mi metric') 29 | 30 | -- optimization params 31 | cmd:option('-batchSize', 30, 'Batch size (number of threads) (Adjust base on GRAM)'); 32 | cmd:option('-probSampleSize', 50, 'Number of samples for computing probability'); 33 | cmd:option('-learningRate', 1e-3, 'Learning rate'); 34 | cmd:option('-dropout', 0, 'Dropout for language model'); 35 | cmd:option('-numEpochs', 400, 'Epochs'); 36 | cmd:option('-LRateDecay', 10, 'After lr_decay epochs lr reduces to 0.1*lr'); 37 | cmd:option('-lrDecayRate', 0.9997592083, 'Decay for learning rate') 38 | cmd:option('-minLRate', 5e-5, 'Minimum learning rate'); 39 | cmd:option('-gpuid', 0, 'GPU id to use') 40 | cmd:option('-backend', 'cudnn', 'nn|cudnn') 41 | 42 | local opts = cmd:parse(arg); 43 | 44 | -- if save path is not given, use default..time 45 | -- get the current time 46 | local curTime = os.date('*t', os.time()); 47 | -- create another folder to avoid clutter 48 | local modelPath = string.format('models/model-%d-%d-%d-%d:%d:%d-%s/', 49 | curTime.month, curTime.day, curTime.year, 50 | curTime.hour, curTime.min, curTime.sec, opts.model_name) 51 | if opts.save_path == 'models/' then opts.save_path = modelPath end; 52 | -- add useMI flag if the metric is mutual information 53 | if opts.metric == 'mi' then opts.useMI = true; end 54 | if opts.bidirectional == 0 then opts.useBi = nil; else opts.useBi = true; end 55 | -- additionally check if its imitation of discriminative model 56 | if string.match(opts.model_name, 'hist') then 57 | opts.useHistory = true; 58 | if string.match(opts.model_name, 'disc') then 59 | opts.separateCaption = true; 60 | end 61 | end 62 | if string.match(opts.model_name, 'im') then opts.useIm = true; end 63 | 64 | return opts; 65 | -------------------------------------------------------------------------------- /chatbot/prepro_ques.py: -------------------------------------------------------------------------------- 1 | import json 2 | import argparse 3 | import re 4 | from nltk.tokenize import word_tokenize 5 | 6 | regex = re.compile('[^\sa-zA-Z]') 7 | 8 | def main(): 9 | ques = regex.sub('',args.question) + ' ?' 10 | question = word_tokenize(ques.replace('?', ' ? ').strip().lower())[:args.ques_len] 11 | # question = word_tokenize(args.question.replace('?', ' ? ').strip().lower())[:args.ques_len] 12 | history_facts = args.history.replace('?', ' ? 
').split(args.delimiter) 13 | history, questions = [], [] 14 | for i in history_facts: 15 | fact = word_tokenize(i.strip().lower())[:args.fact_len] 16 | if len(fact) != 0: 17 | history.append(fact) 18 | try: 19 | questions.append(fact[:fact.index('?')+1]) 20 | except: 21 | pass 22 | 23 | num_hist = min(len(history), 10) 24 | num_ques = num_hist - 1 if num_hist > 0 else 0 25 | json.dump({'question': question, 'history': history[-num_hist:], 'questions': questions[-num_ques:]}, open('ques_feat.json', 'w')) 26 | 27 | 28 | if __name__ == "__main__": 29 | parser = argparse.ArgumentParser() 30 | parser.add_argument('-question', type=str, default='') 31 | parser.add_argument('-ques_len', type=int, default=15) 32 | parser.add_argument('-history', type=str, default='') 33 | parser.add_argument('-fact_len', type=int, default=30) 34 | parser.add_argument('-delimiter', type=str, default='||||') 35 | args = parser.parse_args() 36 | main() 37 | -------------------------------------------------------------------------------- /chatbot/rl_evaluate.lua: -------------------------------------------------------------------------------- 1 | --[[ 2 | History will not have <START>, only <END> 3 | ]] 4 | require 'nn' 5 | require 'nngraph' 6 | require 'cjson' 7 | require 'rnn' 8 | require 'modelAnswerer' 9 | utils = dofile('utils.lua') 10 | 11 | local RLTorchModel = torch.class('RLConversationModel') 12 | 13 | function RLTorchModel:__init(inputJson, qBotpath, aBotpath, gpuid, backend, imfeatpath) 14 | 15 | -- Load the image features 16 | self.imfeats = torch.load(imfeatpath) 17 | 18 | -- Model paths 19 | self.qBotpath = qBotpath 20 | self.aBotpath = aBotpath 21 | self.gpuid = gpuid 22 | self.backend = backend 23 | 24 | -- Create options table to initialize dataloader 25 | self.opt = {} 26 | self.opt['input_json'] = inputJson 27 | self.dataloader = dofile('dataloader.lua') 28 | self.dataloader:initialize(self.opt) 29 | 30 | -- Initial seeds 31 | torch.manualSeed(1234) 32 | if self.gpuid >= 0 then 33 | require 'cutorch' 34 | require 'cunn' 35 | if self.backend == 'cudnn' then require 'cudnn' end 36 | cutorch.setDevice(1) 37 | cutorch.manualSeed(1234) 38 | torch.setdefaulttensortype('torch.CudaTensor') 39 | else 40 | torch.setdefaulttensortype('torch.DoubleTensor') 41 | end 42 | 43 | -- Load Questioner and Answerer model 44 | self.questionerModel = torch.load(qBotpath) 45 | self.answererModel = torch.load(aBotpath) 46 | 47 | -- transfer all options to model 48 | self.questionerModelParams = self.questionerModel.modelParams 49 | self.answererModelParams = self.answererModel.modelParams 50 | 51 | -- changing savepath in checkpoints 52 | self.questionerModelParams['model_name'] = 'im-hist-enc-dec-questioner' 53 | self.answererModelParams['model_name'] = 'im-hist-enc-dec-answerer' 54 | 55 | -- Print Questioner and Answerer 56 | -- print('Questioner', self.questionerModelParams.model_name) 57 | -- print('Answerer', self.answererModelParams.model_name) 58 | 59 | -- Add flags for various configurations 60 | if string.match(self.questionerModelParams.model_name, 'hist') then self.questionerModelParams.useHistory = true; end 61 | if string.match(self.answererModelParams.model_name, 'hist') then self.answererModelParams.useHistory = true; end 62 | if string.match(self.answererModelParams.model_name, 'im') then self.answererModelParams.useIm = true; end 63 | 64 | -- Setup both Qbot and Abot 65 | -- print('Using models from'.. self.questionerModelParams.model_name) 66 | -- print('Using models from'.. 
self.answererModelParams.model_name) 67 | self.qModel = VisDialQModel(self.questionerModelParams) 68 | self.aModel = VisDialAModel(self.answererModelParams) 69 | 70 | -- copy weights from loaded model 71 | self.qModel.wrapperW:copy(self.questionerModel.modelW) 72 | self.aModel.wrapperW:copy(self.answererModel.modelW) 73 | 74 | -- set models to evaluate mode 75 | self.qModel.wrapper:evaluate() 76 | self.aModel.wrapper:evaluate() 77 | end 78 | 79 | --[[ 80 | ABot method implementation is exactly similar to the Visual Dialog Model 81 | Need to clarify the left/right alignment of questions/etc 82 | ]] 83 | 84 | function RLTorchModel:abot(imgId, history, question) 85 | -- Get image-feature 86 | print(imgId) 87 | print(history) 88 | print(question) 89 | print(self.imfeats) 90 | local imgFeat = self.imfeats[imgId] 91 | imgFeat = torch.repeatTensor(imgFeat, 10, 1) 92 | -- Concatenate history 93 | local history_concat = '' 94 | for i=1, #history do 95 | history_concat = history_concat .. history[i] .. ' |||| ' 96 | end 97 | -- -- Remove <START> from history 98 | -- history_concat = history_concat:gsub('<START>','') 99 | -- if history_concat ~= '' then 100 | -- history_concat = history_concat .. ' <START> ' 101 | -- end 102 | -- get pre-processed QA+Hist 103 | local cmd = 'python prepro_ques.py -question "' .. question .. '" -history "' .. history_concat .. '"' 104 | os.execute(cmd) 105 | local file = io.open('ques_feat.json', 'r') 106 | if file then 107 | json_f = file:read('*a') 108 | qh_feats = cjson.decode(json_f) 109 | file:close() 110 | end 111 | -- Get question vector 112 | local ques_vector = utils.wordsToId(qh_feats.question, self.dataloader.word2ind, 20) 113 | -- Get history Tensor and hist_len vector 114 | local hist_tensor = torch.LongTensor(10, 40):zero() 115 | local hist_len = torch.zeros(10) 116 | for i=1, #qh_feats.history do 117 | hist_tensor[i] = utils.wordsToId(qh_feats.history[i], self.dataloader.word2ind, 40) 118 | hist_len[i] = hist_tensor[i][hist_tensor[i]:ne(0)]:size(1) 119 | end 120 | -- Get question Tensor 121 | local ques_tensor = torch.LongTensor(10, 20):zero() 122 | local ques_len = torch.zeros(10) 123 | for i=1, #qh_feats.questions do 124 | ques_tensor[i] = utils.wordsToId(qh_feats.questions[i], self.dataloader.word2ind, 20) 125 | ques_len[i] = ques_tensor[i][ques_tensor[i]:ne(0)]:size(1) 126 | end 127 | -- Parameter for generating answers 128 | local iter = #qh_feats.questions + 1 129 | ques_tensor[iter] = ques_vector 130 | -- Right align the questions 131 | -- ques_tensor = utils.rightAlign(ques_tensor, ques_len) 132 | -- Right align the history 133 | -- hist_tensor = utils.rightAlign(hist_tensor, hist_len) 134 | -- Transpose the question and history 135 | ques_tensor = ques_tensor:t() 136 | hist_tensor = hist_tensor:t() 137 | -- Shift to GPU 138 | if self.gpuid >= 0 then 139 | ques_tensor = ques_tensor:cuda() 140 | hist_tensor = hist_tensor:cuda() 141 | imgFeat = imgFeat:cuda() 142 | end 143 | -- Generate answer; returns a table :-> {ansWords, aLen, ansText} 144 | local ans_struct = self.aModel:generateSingleAnswer(self.dataloader, {hist_tensor, imgFeat, ques_tensor}, {beamSize = 5}, iter) 145 | -- Use answer-text to concatenate things to show at subject's end 146 | local answer = ans_struct[3] 147 | local result = {} 148 | result['answer'] = answer 149 | result['question'] = question 150 | if history_concat == '||||' then 151 | history_concat = '' 152 | end 153 | result['history'] = history_concat .. question .. ' ' .. 
answer 154 | result['history'] = string.gsub(result['history'], '<START>', '') 155 | result['input_img'] = imgId 156 | return result 157 | end -------------------------------------------------------------------------------- /chatbot/rl_worker.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | 3 | import os 4 | import sys 5 | sys.path.append('..') 6 | 7 | os.environ.setdefault('DJANGO_SETTINGS_MODULE', 'demo.settings') 8 | 9 | import django 10 | django.setup() 11 | 12 | from django.conf import settings 13 | from amt.utils import log_to_terminal 14 | 15 | import amt.constants as constants 16 | import PyTorch 17 | import PyTorchHelpers 18 | import pika 19 | import time 20 | import yaml 21 | import json 22 | import traceback 23 | 24 | RLVisDialModel = PyTorchHelpers.load_lua_class( 25 | constants.RL_VISDIAL_LUA_PATH, 'RLConversationModel') 26 | 27 | RLVisDialATorchModel = RLVisDialModel( 28 | constants.RL_VISDIAL_CONFIG['inputJson'], 29 | constants.RL_VISDIAL_CONFIG['qBotpath'], 30 | constants.RL_VISDIAL_CONFIG['aBotpath'], 31 | constants.RL_VISDIAL_CONFIG['gpuid'], 32 | constants.RL_VISDIAL_CONFIG['backend'], 33 | constants.RL_VISDIAL_CONFIG['imfeatpath'], 34 | ) 35 | 36 | connection = pika.BlockingConnection(pika.ConnectionParameters( 37 | host='localhost')) 38 | 39 | channel = connection.channel() 40 | 41 | channel.queue_declare(queue='rl_chatbot_queue', durable=True) 42 | 43 | 44 | def callback(ch, method, properties, body): 45 | try: 46 | body = yaml.safe_load(body) 47 | body['history'] = body['history'].split("||||") 48 | 49 | # Get the image id here so that the lua script can use the extracted features 50 | image_id = body['image_path'].split("/")[-1].replace(".jpg", "") 51 | 52 | result = RLVisDialATorchModel.abot( 53 | image_id, 54 | body['history'], 55 | body['input_question']) 56 | 57 | result['question'] = str(result['question']) 58 | result['answer'] = str(result['answer']) 59 | result['history'] = result['history'] 60 | result['history'] = result['history'].replace("<START>", "") 61 | result['history'] = result['history'].replace("<END>", "") 62 | 63 | log_to_terminal(body['socketid'], {"result": json.dumps(result)}) 64 | ch.basic_ack(delivery_tag=method.delivery_tag) 65 | 66 | except Exception, err: 67 | print str(traceback.print_exc()) 68 | 69 | channel.basic_consume(callback, 70 | queue='rl_chatbot_queue') 71 | 72 | channel.start_consuming() 73 | -------------------------------------------------------------------------------- /chatbot/sl_evaluate.lua: -------------------------------------------------------------------------------- 1 | --[[ 2 | History will not have <START>, only <END> 3 | ]] 4 | require 'nn' 5 | require 'nngraph' 6 | require 'cjson' 7 | require 'rnn' 8 | require 'modelAnswerer' 9 | utils = dofile('utils.lua') 10 | 11 | local TorchModel = torch.class('SLConversationModel') 12 | 13 | function TorchModel:__init(inputJson, qBotpath, aBotpath, gpuid, backend, imfeatpath) 14 | -- Load the image features 15 | print(imfeatpath) 16 | self.imfeats = torch.load(imfeatpath) 17 | print(self.imfeats) 18 | print(#self.imfeats) 19 | 20 | -- Model paths 21 | self.qBotpath = qBotpath 22 | self.aBotpath = aBotpath 23 | self.gpuid = gpuid 24 | self.backend = backend 25 | 26 | -- Create options table to initialize dataloader 27 | self.opt = {} 28 | self.opt['input_json'] = inputJson 29 | self.dataloader = dofile('dataloader.lua') 30 | self.dataloader:initialize(self.opt) 31 | 32 | -- Initial seeds 33 | torch.manualSeed(1234) 34 | if 
self.gpuid >= 0 then 35 | require 'cutorch' 36 | require 'cunn' 37 | if self.backend == 'cudnn' then require 'cudnn' end 38 | cutorch.setDevice(1) 39 | cutorch.manualSeed(1234) 40 | torch.setdefaulttensortype('torch.CudaTensor') 41 | else 42 | torch.setdefaulttensortype('torch.DoubleTensor') 43 | end 44 | 45 | -- Load Questioner and Answerer model 46 | self.questionerModel = torch.load(qBotpath) 47 | self.answererModel = torch.load(aBotpath) 48 | 49 | -- transfer all options to model 50 | self.questionerModelParams = self.questionerModel.modelParams 51 | self.answererModelParams = self.answererModel.modelParams 52 | 53 | -- changing savepath in checkpoints 54 | self.questionerModelParams['model_name'] = 'im-hist-enc-dec-questioner' 55 | self.answererModelParams['model_name'] = 'im-hist-enc-dec-answerer' 56 | 57 | -- Print Questioner and Answerer 58 | print('Questioner', self.questionerModelParams.model_name) 59 | print('Answerer', self.answererModelParams.model_name) 60 | 61 | -- Add flags for various configurations 62 | if string.match(self.questionerModelParams.model_name, 'hist') then self.questionerModelParams.useHistory = true; end 63 | if string.match(self.answererModelParams.model_name, 'hist') then self.answererModelParams.useHistory = true; end 64 | if string.match(self.answererModelParams.model_name, 'im') then self.answererModelParams.useIm = true; end 65 | 66 | -- Setup both Qbot and Abot 67 | print('Using models from '.. self.questionerModelParams.model_name) 68 | print('Using models from '.. self.answererModelParams.model_name) 69 | self.qModel = VisDialQModel(self.questionerModelParams) 70 | self.aModel = VisDialAModel(self.answererModelParams) 71 | 72 | -- copy weights from loaded model 73 | self.qModel.wrapperW:copy(self.questionerModel.modelW) 74 | self.aModel.wrapperW:copy(self.answererModel.modelW) 75 | 76 | -- set models to evaluate mode 77 | self.qModel.wrapper:evaluate() 78 | self.aModel.wrapper:evaluate() 79 | end 80 | 81 | --[[ 82 | The ABot method implementation is identical to the Visual Dialog model's. 83 | TODO: clarify the left/right alignment of questions/history. 84 | ]] 85 | 86 | function TorchModel:abot(imgId, history, question) 87 | -- Get image-feature 88 | local imgFeat = self.imfeats[imgId] 89 | imgFeat = torch.repeatTensor(imgFeat, 10, 1) 90 | -- Concatenate history 91 | local history_concat = '' 92 | for i=1, #history do 93 | history_concat = history_concat .. history[i] .. ' |||| ' 94 | end 95 | -- -- Remove <START> from history 96 | -- history_concat = history_concat:gsub('<START>','') 97 | -- if history_concat ~= '' then 98 | -- history_concat = history_concat .. ' <START>' 99 | -- end 100 | -- get the pre-processed question + history 101 | local cmd = 'python prepro_ques.py -question "' .. question .. '" -history "' .. history_concat ..
'"' 102 | os.execute(cmd) 103 | local file = io.open('ques_feat.json', 'r') 104 | if file then 105 | json_f = file:read('*a') 106 | qh_feats = cjson.decode(json_f) 107 | file:close() 108 | end 109 | -- Get question vector 110 | local ques_vector = utils.wordsToId(qh_feats.question, self.dataloader.word2ind, 20) 111 | -- Get history tensor and hist_len vector 112 | local hist_tensor = torch.LongTensor(10, 40):zero() 113 | local hist_len = torch.zeros(10) 114 | for i=1, #qh_feats.history do 115 | hist_tensor[i] = utils.wordsToId(qh_feats.history[i], self.dataloader.word2ind, 40) 116 | hist_len[i] = hist_tensor[i][hist_tensor[i]:ne(0)]:size(1) 117 | end 118 | -- Get question tensor 119 | local ques_tensor = torch.LongTensor(10, 20):zero() 120 | local ques_len = torch.zeros(10) 121 | for i=1, #qh_feats.questions do 122 | ques_tensor[i] = utils.wordsToId(qh_feats.questions[i], self.dataloader.word2ind, 20) 123 | ques_len[i] = ques_tensor[i][ques_tensor[i]:ne(0)]:size(1) 124 | end 125 | -- Round index for generating the answer 126 | local iter = #qh_feats.questions + 1 127 | ques_tensor[iter] = ques_vector 128 | -- Right align the questions 129 | -- ques_tensor = utils.rightAlign(ques_tensor, ques_len) 130 | -- Right align the history 131 | -- hist_tensor = utils.rightAlign(hist_tensor, hist_len) 132 | -- Transpose the question and history 133 | ques_tensor = ques_tensor:t() 134 | hist_tensor = hist_tensor:t() 135 | -- Shift to GPU 136 | if self.gpuid >= 0 then 137 | ques_tensor = ques_tensor:cuda() 138 | hist_tensor = hist_tensor:cuda() 139 | imgFeat = imgFeat:cuda() 140 | end 141 | -- Generate answer; returns a table -> {ansWords, aLen, ansText} 142 | local ans_struct = self.aModel:generateSingleAnswer(self.dataloader, {hist_tensor, imgFeat, ques_tensor}, {beamSize = 5}, iter) 143 | -- Use the answer text to build the history string shown on the subject's end 144 | local answer = ans_struct[3] 145 | local result = {} 146 | result['answer'] = answer 147 | result['question'] = question 148 | if history_concat == '||||' then 149 | history_concat = '' 150 | end 151 | result['history'] = history_concat .. question .. ' ' ..
answer 152 | result['history'] = string.gsub(result['history'], '<START>', '') 153 | result['input_img'] = imgId 154 | return result 155 | end -------------------------------------------------------------------------------- /chatbot/sl_worker.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | 3 | import os 4 | import sys 5 | sys.path.append('..') 6 | 7 | os.environ.setdefault('DJANGO_SETTINGS_MODULE', 'demo.settings') 8 | 9 | import django 10 | django.setup() 11 | 12 | from django.conf import settings 13 | 14 | from amt.utils import log_to_terminal 15 | 16 | import amt.constants as constants 17 | import PyTorch 18 | import PyTorchHelpers 19 | import pika 20 | import time 21 | import yaml 22 | import json 23 | import traceback 24 | import signal 25 | import requests 26 | import atexit 27 | 28 | VisDialModel = PyTorchHelpers.load_lua_class( 29 | constants.SL_VISDIAL_LUA_PATH, 'SLConversationModel') 30 | 31 | VisDialATorchModel = VisDialModel( 32 | constants.SL_VISDIAL_CONFIG['inputJson'], 33 | constants.SL_VISDIAL_CONFIG['qBotpath'], 34 | constants.SL_VISDIAL_CONFIG['aBotpath'], 35 | constants.SL_VISDIAL_CONFIG['gpuid'], 36 | constants.SL_VISDIAL_CONFIG['backend'], 37 | constants.SL_VISDIAL_CONFIG['imfeatpath'], 38 | ) 39 | 40 | connection = pika.BlockingConnection(pika.ConnectionParameters( 41 | host='localhost')) 42 | 43 | channel = connection.channel() 44 | 45 | channel.queue_declare(queue='sl_chatbot_queue', durable=True) 46 | 47 | 48 | def callback(ch, method, properties, body): 49 | try: 50 | body = yaml.safe_load(body) 51 | body['history'] = body['history'].split("||||") 52 | 53 | # Get the image id here so that the Lua script can use the extracted features 54 | image_id = body['image_path'].split("/")[-1].replace(".jpg", "") 55 | 56 | print image_id 57 | print type(image_id) 58 | 59 | result = VisDialATorchModel.abot( 60 | image_id, body['history'], body['input_question']) 61 | result['input_image'] = body['image_path'] 62 | result['question'] = str(result['question']) 63 | result['answer'] = str(result['answer']) 64 | result['history'] = result['history'].replace("<START>", "") 65 | result['history'] = result['history'].replace("<END>", "") 66 | # TODO: Store result['predicted_fc7'] in the database after each round 67 | log_to_terminal(body['socketid'], {"result": json.dumps(result)}) 68 | ch.basic_ack(delivery_tag=method.delivery_tag) 69 | 70 | except Exception, err: 71 | traceback.print_exc() 72 | 73 | channel.basic_consume(callback, 74 | queue='sl_chatbot_queue') 75 | 76 | channel.start_consuming() 77 | -------------------------------------------------------------------------------- /chatbot/testAnswerer.lua: -------------------------------------------------------------------------------- 1 | require 'nn' 2 | require 'nngraph' 3 | require 'io' 4 | require 'rnn' 5 | utils = dofile('utils.lua'); 6 | 7 | ------------------------------------------------------------------------------- 8 | -- Input arguments and options 9 | ------------------------------------------------------------------------------- 10 | cmd = torch.CmdLine() 11 | cmd:text() 12 | cmd:text('Options') 13 | -- Data input settings 14 | cmd:option('-input_img_h5','data/visdial_0.5/data_img.h5','h5 file path with image features') 15 | cmd:option('-input_ques_h5','data/visdial_0.5/chat_processed_data.h5','h5 file with preprocessed questions') 16 | cmd:option('-input_json','data/visdial_0.5/chat_processed_params.json','json path with info and vocab') 17 | 18
| cmd:option('-load_path', 'models/model-2-14-2017-22:43:51-im-hist-enc-dec/model_epoch_20.t7', 'path to saved model') 19 | cmd:option('-result_path', 'results', 'path to save generated results') 20 | 21 | -- optimization params 22 | cmd:option('-batchSize', 200, 'Batch size (number of threads) (adjust based on GPU memory)'); 23 | cmd:option('-gpuid', 0, 'GPU id to use') 24 | cmd:option('-backend', 'cudnn', 'nn|cudnn') 25 | 26 | local opt = cmd:parse(arg); 27 | print(opt) 28 | 29 | -- seed for reproducibility 30 | torch.manualSeed(1234); 31 | 32 | -- set default tensor based on gpu usage 33 | if opt.gpuid >= 0 then 34 | require 'cutorch' 35 | require 'cunn' 36 | --if opt.backend == 'cudnn' then require 'cudnn' end 37 | torch.setdefaulttensortype('torch.CudaTensor'); 38 | else 39 | torch.setdefaulttensortype('torch.DoubleTensor'); 40 | end 41 | 42 | ------------------------------------------------------------------------ 43 | -- Read saved model and parameters 44 | ------------------------------------------------------------------------ 45 | local savedModel = torch.load(opt.load_path); 46 | 47 | -- transfer all options to model 48 | local modelParams = savedModel.modelParams; 49 | opt.img_norm = modelParams.img_norm; 50 | opt.model_name = modelParams.model_name; 51 | print(opt.model_name) 52 | 53 | -- add flags for various configurations 54 | -- additionally check if it is an imitation of the discriminative model 55 | if string.match(opt.model_name, 'hist') then 56 | opt.useHistory = true; 57 | if string.match(opt.model_name, 'disc') then 58 | opt.separateCaption = true; 59 | end 60 | end 61 | if string.match(opt.model_name, 'im') then opt.useIm = true; end 62 | ------------------------------------------------------------------------ 63 | -- Loading dataset 64 | ------------------------------------------------------------------------ 65 | local dataloader = dofile('dataloader.lua') 66 | dataloader:initialize(opt, {'test'}); 67 | collectgarbage(); 68 | 69 | ------------------------------------------------------------------------ 70 | -- Setup the model 71 | ------------------------------------------------------------------------ 72 | require 'modelAnswerer' 73 | print('Using models from '..modelParams.model_name) 74 | local svqaModel = VisDialAModel(modelParams); 75 | 76 | -- copy the weights from loaded model 77 | svqaModel.wrapperW:copy(savedModel.modelW); 78 | 79 | ------------------------------------------------------------------------ 80 | -- Evaluation 81 | ------------------------------------------------------------------------ 82 | -- validation accuracy 83 | print('Evaluating...') 84 | svqaModel:retrieve(dataloader, 'test'); 85 | os.exit() -- remove this line to also run the answer generation below 86 | 87 | ---[[ 88 | print('Generating answers...') 89 | -- local answers = svqaModel:generateAnswers(dataloader, 'test', {sample = false}); 90 | local answers = svqaModel:generateAnswersBeamSearch(dataloader, 'test', {}); 91 | 92 | --save the file to json 93 | local savePath = string.format('%s/%s-results.json', opt.result_path, modelParams.model_name); 94 | utils.writeJSON(savePath, answers); 95 | print('Writing the results to '..savePath); 96 | -- --]] 97 | 98 | --svqaModel:visualizeAttention(dataloader, 'val', genParams); -------------------------------------------------------------------------------- /chatbot/utils.lua: -------------------------------------------------------------------------------- 1 | -- script containing supporting code/methods 2 | local utils = {}; 3 | cjson = require 'cjson' 4 | 5 | -- right align the question tokens in 3d volume 6 |
function utils.rightAlign(sequences, lengths) 7 | -- clone the sequences 8 | local rAligned = sequences:clone():fill(0); 9 | local numDims = sequences:dim(); 10 | 11 | if numDims == 3 then 12 | local M = sequences:size(3); -- maximum length of question 13 | local numImgs = sequences:size(1); -- number of images 14 | local maxCount = sequences:size(2); -- number of questions / image 15 | 16 | for imId = 1, numImgs do 17 | for quesId = 1, maxCount do 18 | -- do only for non zero sequence counts 19 | if lengths[imId][quesId] == 0 then 20 | break; 21 | end 22 | 23 | -- copy based on the sequence length 24 | rAligned[imId][quesId][{{M - lengths[imId][quesId] + 1, M}}] = 25 | sequences[imId][quesId][{{1, lengths[imId][quesId]}}]; 26 | end 27 | end 28 | else if numDims == 2 then 29 | -- handle 2 dimensional matrices as well 30 | local M = sequences:size(2); -- maximum length of question 31 | local numImgs = sequences:size(1); -- number of images 32 | 33 | for imId = 1, numImgs do 34 | -- do only for non zero sequence counts 35 | if lengths[imId] > 0 then 36 | -- copy based on the sequence length 37 | rAligned[imId][{{M - lengths[imId] + 1, M}}] = 38 | sequences[imId][{{1, lengths[imId]}}]; 39 | end 40 | end 41 | end 42 | end 43 | 44 | return rAligned; 45 | end 46 | 47 | -- translate a table of words to index tensor 48 | function utils.wordsToId(words, word2ind, max_len) 49 | local len = max_len or 15 50 | local vector = torch.LongTensor(len):zero() 51 | for i = 1, #words do 52 | if word2ind[words[i]] ~= nil then 53 | vector[len - #words + i] = word2ind[words[i]] 54 | else 55 | vector[len - #words + i] = word2ind['UNK'] 56 | end 57 | end 58 | return vector 59 | end 60 | 61 | -- translate a given tensor/table to sentence 62 | function utils.idToWords(vector, ind2word) 63 | local sentence = ''; 64 | 65 | local nextWord; 66 | for wordId = 1, vector:size(1) do 67 | if vector[wordId] > 0 then 68 | nextWord = ind2word[vector[wordId]]; 69 | sentence = sentence..' 
'..nextWord; 70 | end 71 | 72 | -- stop if the <END> token is reached 73 | if nextWord == '<END>' then break; end 74 | end 75 | 76 | return sentence; 77 | end 78 | 79 | -- read a json file into a lua table 80 | function utils.readJSON(fileName) 81 | local file = io.open(fileName, 'r'); 82 | local text = file:read('*a'); 83 | file:close(); 84 | 85 | -- convert and return the information 86 | return cjson.decode(text); 87 | end 88 | 89 | -- save a lua table to a json file 90 | function utils.writeJSON(fileName, luaTable) 91 | -- serialize lua table 92 | local text = cjson.encode(luaTable) 93 | 94 | local file = io.open(fileName, 'w'); 95 | file:write(text); 96 | file:close(); 97 | end 98 | 99 | -- compute the likelihood given the gt words and predicted probabilities 100 | function utils.computeLhood(words, predProbs) 101 | -- compute the probabilities for each answer, based on its tokens 102 | -- convert to 2d matrix 103 | local predVec = predProbs:view(-1, predProbs:size(3)); 104 | local indices = words:contiguous():view(-1, 1); 105 | local mask = indices:eq(0); 106 | -- assign proxy values to avoid 0 index errors 107 | indices[mask] = 1; 108 | local logProbs = predVec:gather(2, indices); 109 | -- neutralize other values 110 | logProbs[mask] = 0; 111 | logProbs = logProbs:viewAs(words); 112 | -- sum up for each sentence 113 | logProbs = logProbs:sum(1):squeeze(); 114 | 115 | return logProbs; 116 | end 117 | 118 | -- process the scores and obtain the ranks 119 | -- input: scores for all options, ground truth positions 120 | function utils.computeRanks(scores, gtPos) 121 | local gtScore = scores:gather(2, gtPos); 122 | local ranks = scores:gt(gtScore:expandAs(scores)); 123 | ranks = ranks:sum(2) + 1; 124 | 125 | -- convert into double 126 | return ranks:double(); 127 | end 128 | 129 | -- process the ranks and print metrics 130 | function utils.processRanks(ranks) 131 | -- print the results 132 | local numQues = ranks:size(1) * ranks:size(2); 133 | 134 | local numOptions = 100; 135 | 136 | -- convert ranks to double, vector and remove zeros 137 | ranks = ranks:double():view(-1); 138 | -- none of the values should be 0, since the gt is among the options 139 | if torch.sum(ranks:le(0)) > 0 then 140 | numZero = torch.sum(ranks:le(0)); 141 | print(string.format('Warning: some ranks are zero: %d', numZero)) 142 | ranks = ranks[ranks:gt(0)]; 143 | end 144 | 145 | if torch.sum(ranks:ge(numOptions + 1)) > 0 then 146 | numGreater = torch.sum(ranks:ge(numOptions + 1)); 147 | print(string.format('Warning: some ranks are > 100: %d', numGreater)) 148 | ranks = ranks[ranks:le(numOptions + 1)]; 149 | end 150 | 151 | ------------------------------------------------ 152 | print(string.format('\tNo.
questions: %d', numQues)) 153 | print(string.format('\tr@1: %f', torch.sum(torch.le(ranks, 1))/numQues)) 154 | print(string.format('\tr@5: %f', torch.sum(torch.le(ranks, 5))/numQues)) 155 | print(string.format('\tr@10: %f', torch.sum(torch.le(ranks, 10))/numQues)) 156 | print(string.format('\tmedianR: %f', torch.median(ranks:view(-1))[1])) 157 | print(string.format('\tmeanR: %f', torch.mean(ranks))) 158 | print(string.format('\tmeanRR: %f', torch.mean(ranks:cinv()))) 159 | end 160 | 161 | function utils.preprocess(path, width, height) 162 | local width = width or 224 163 | local height = height or 224 164 | local image = require 'image' -- not required at the top of this file 165 | -- load image 166 | local orig_image = image.load(path) 167 | 168 | -- handle greyscale and rgba images 169 | if orig_image:size(1) == 1 then 170 | orig_image = orig_image:repeatTensor(3, 1, 1) 171 | elseif orig_image:size(1) == 4 then 172 | orig_image = orig_image[{{1,3},{},{}}] 173 | end 174 | 175 | -- get the dimensions of the original image 176 | local im_height = orig_image:size(2) 177 | local im_width = orig_image:size(3) 178 | 179 | -- scale and subtract mean 180 | local img = image.scale(orig_image, width, height):double() 181 | local mean_pixel = torch.DoubleTensor({103.939, 116.779, 123.68}) 182 | img = img:index(1, torch.LongTensor{3, 2, 1}):mul(255.0) 183 | mean_pixel = mean_pixel:view(3, 1, 1):expandAs(img) 184 | img:add(-1, mean_pixel) 185 | return img, im_height, im_width 186 | end 187 | 188 | return utils; -------------------------------------------------------------------------------- /data/pools.json: -------------------------------------------------------------------------------- 1 | [ 2 | { 3 | "gen_caption": "a fire hydrant on the side of the street", 4 | "gt_caption": "A toy horse is sitting on a sidewalk next a fire hydrant.", 5 | "obj": "fire hydrant", 6 | "poolID": "5b808d3f-b670-47b4-c726-f6874f3363eb", 7 | "pools": { 8 | "easy": [ 9 | "COCO_val2014_000000574216", 10 | "COCO_val2014_000000277089", 11 | "COCO_val2014_000000360899", 12 | "COCO_val2014_000000221105", 13 | "COCO_val2014_000000422280", 14 | "COCO_val2014_000000091052", 15 | "COCO_val2014_000000002014", 16 | "COCO_val2014_000000526188", 17 | "COCO_val2014_000000036333", 18 | "COCO_val2014_000000221932", 19 | "COCO_val2014_000000436883", 20 | "COCO_val2014_000000261360", 21 | "COCO_val2014_000000033123", 22 | "COCO_val2014_000000344194", 23 | "COCO_val2014_000000448309", 24 | "COCO_val2014_000000405047", 25 | "COCO_val2014_000000089909", 26 | "COCO_val2014_000000576052", 27 | "COCO_val2014_000000200296", 28 | "COCO_val2014_000000514248" 29 | ], 30 | "hard": [ 31 | "COCO_val2014_000000228541", 32 | "COCO_val2014_000000301634", 33 | "COCO_val2014_000000051025", 34 | "COCO_val2014_000000448078", 35 | "COCO_val2014_000000534734", 36 | "COCO_val2014_000000293474", 37 | "COCO_val2014_000000175244", 38 | "COCO_val2014_000000304012", 39 | "COCO_val2014_000000453870", 40 | "COCO_val2014_000000388818", 41 | "COCO_val2014_000000239157", 42 | "COCO_val2014_000000487698", 43 | "COCO_val2014_000000382512", 44 | "COCO_val2014_000000450500", 45 | "COCO_val2014_000000411047", 46 | "COCO_val2014_000000276063", 47 | "COCO_val2014_000000463203", 48 | "COCO_val2014_000000172315", 49 | "COCO_val2014_000000575823", 50 | "COCO_val2014_000000514248" 51 | ], 52 | "medium": [ 53 | "COCO_val2014_000000324872", 54 | "COCO_val2014_000000313588", 55 | "COCO_val2014_000000493196", 56 | "COCO_val2014_000000249273", 57 | "COCO_val2014_000000382715", 58 | "COCO_val2014_000000182895", 59 | "COCO_val2014_000000100016", 60 |
"COCO_val2014_000000513690", 61 | "COCO_val2014_000000473002", 62 | "COCO_val2014_000000155644", 63 | "COCO_val2014_000000090255", 64 | "COCO_val2014_000000465996", 65 | "COCO_val2014_000000050355", 66 | "COCO_val2014_000000268239", 67 | "COCO_val2014_000000196413", 68 | "COCO_val2014_000000326898", 69 | "COCO_val2014_000000133567", 70 | "COCO_val2014_000000472054", 71 | "COCO_val2014_000000316882", 72 | "COCO_val2014_000000514248" 73 | ] 74 | }, 75 | "target": "COCO_val2014_000000514248" 76 | }, 77 | { 78 | "gen_caption": "a young boy holding a nintendo wii game controller", 79 | "gt_caption": "A persona putting tooth past on their toothbrush over a sink.", 80 | "obj": "toothbrush", 81 | "poolID": "0b8d59a2-f973-4ebe-ce08-4b0af5f89d0d", 82 | "pools": { 83 | "easy": [ 84 | "COCO_val2014_000000157726", 85 | "COCO_val2014_000000459786", 86 | "COCO_val2014_000000332625", 87 | "COCO_val2014_000000144863", 88 | "COCO_val2014_000000441862", 89 | "COCO_val2014_000000239347", 90 | "COCO_val2014_000000538497", 91 | "COCO_val2014_000000164780", 92 | "COCO_val2014_000000068418", 93 | "COCO_val2014_000000124210", 94 | "COCO_val2014_000000123071", 95 | "COCO_val2014_000000416851", 96 | "COCO_val2014_000000039726", 97 | "COCO_val2014_000000000590", 98 | "COCO_val2014_000000273728", 99 | "COCO_val2014_000000210144", 100 | "COCO_val2014_000000203350", 101 | "COCO_val2014_000000505977", 102 | "COCO_val2014_000000342971", 103 | "COCO_val2014_000000290515" 104 | ], 105 | "hard": [ 106 | "COCO_val2014_000000525369", 107 | "COCO_val2014_000000053404", 108 | "COCO_val2014_000000360216", 109 | "COCO_val2014_000000038073", 110 | "COCO_val2014_000000203479", 111 | "COCO_val2014_000000337563", 112 | "COCO_val2014_000000390627", 113 | "COCO_val2014_000000066821", 114 | "COCO_val2014_000000448837", 115 | "COCO_val2014_000000326898", 116 | "COCO_val2014_000000055528", 117 | "COCO_val2014_000000244111", 118 | "COCO_val2014_000000223746", 119 | "COCO_val2014_000000537270", 120 | "COCO_val2014_000000411754", 121 | "COCO_val2014_000000036607", 122 | "COCO_val2014_000000005670", 123 | "COCO_val2014_000000137954", 124 | "COCO_val2014_000000061603", 125 | "COCO_val2014_000000290515" 126 | ], 127 | "medium": [ 128 | "COCO_val2014_000000001292", 129 | "COCO_val2014_000000368367", 130 | "COCO_val2014_000000101622", 131 | "COCO_val2014_000000028702", 132 | "COCO_val2014_000000136555", 133 | "COCO_val2014_000000336600", 134 | "COCO_val2014_000000122549", 135 | "COCO_val2014_000000067156", 136 | "COCO_val2014_000000071495", 137 | "COCO_val2014_000000209048", 138 | "COCO_val2014_000000274653", 139 | "COCO_val2014_000000515126", 140 | "COCO_val2014_000000322261", 141 | "COCO_val2014_000000418699", 142 | "COCO_val2014_000000460053", 143 | "COCO_val2014_000000459400", 144 | "COCO_val2014_000000032887", 145 | "COCO_val2014_000000506441", 146 | "COCO_val2014_000000134778", 147 | "COCO_val2014_000000290515" 148 | ] 149 | }, 150 | "target": "COCO_val2014_000000290515" 151 | }, 152 | { 153 | "gen_caption": "a plate of food that is on a table", 154 | "gt_caption": "A person holding a white plate topped with a sandwich.", 155 | "obj": "hot dog", 156 | "poolID": "9ecac117-c51c-43a0-c3c9-5185299f95ed", 157 | "pools": { 158 | "easy": [ 159 | "COCO_val2014_000000180787", 160 | "COCO_val2014_000000075006", 161 | "COCO_val2014_000000560470", 162 | "COCO_val2014_000000384788", 163 | "COCO_val2014_000000536073", 164 | "COCO_val2014_000000412036", 165 | "COCO_val2014_000000366493", 166 | "COCO_val2014_000000118246", 167 | "COCO_val2014_000000452623", 
168 | "COCO_val2014_000000170658", 169 | "COCO_val2014_000000268400", 170 | "COCO_val2014_000000116226", 171 | "COCO_val2014_000000134016", 172 | "COCO_val2014_000000231527", 173 | "COCO_val2014_000000128598", 174 | "COCO_val2014_000000384811", 175 | "COCO_val2014_000000029833", 176 | "COCO_val2014_000000128658", 177 | "COCO_val2014_000000410583", 178 | "COCO_val2014_000000210002" 179 | ], 180 | "hard": [ 181 | "COCO_val2014_000000471394", 182 | "COCO_val2014_000000091460", 183 | "COCO_val2014_000000067686", 184 | "COCO_val2014_000000440212", 185 | "COCO_val2014_000000356424", 186 | "COCO_val2014_000000503148", 187 | "COCO_val2014_000000387567", 188 | "COCO_val2014_000000346752", 189 | "COCO_val2014_000000545039", 190 | "COCO_val2014_000000357430", 191 | "COCO_val2014_000000092861", 192 | "COCO_val2014_000000549971", 193 | "COCO_val2014_000000107140", 194 | "COCO_val2014_000000043581", 195 | "COCO_val2014_000000353027", 196 | "COCO_val2014_000000275657", 197 | "COCO_val2014_000000457461", 198 | "COCO_val2014_000000499313", 199 | "COCO_val2014_000000299023", 200 | "COCO_val2014_000000210002" 201 | ], 202 | "medium": [ 203 | "COCO_val2014_000000136624", 204 | "COCO_val2014_000000097974", 205 | "COCO_val2014_000000105335", 206 | "COCO_val2014_000000298600", 207 | "COCO_val2014_000000508006", 208 | "COCO_val2014_000000207898", 209 | "COCO_val2014_000000105689", 210 | "COCO_val2014_000000440123", 211 | "COCO_val2014_000000492378", 212 | "COCO_val2014_000000462386", 213 | "COCO_val2014_000000520154", 214 | "COCO_val2014_000000295668", 215 | "COCO_val2014_000000397636", 216 | "COCO_val2014_000000529151", 217 | "COCO_val2014_000000451123", 218 | "COCO_val2014_000000101919", 219 | "COCO_val2014_000000187054", 220 | "COCO_val2014_000000150117", 221 | "COCO_val2014_000000030548", 222 | "COCO_val2014_000000210002" 223 | ] 224 | }, 225 | "target": "COCO_val2014_000000210002" 226 | }, 227 | { 228 | "gen_caption": "a man laying on a bed with a laptop", 229 | "gt_caption": "a long skinny bed with a floral blanket on it", 230 | "obj": "bed", 231 | "poolID": "3107290d-9346-4802-c88a-d801ddc70f4e", 232 | "pools": { 233 | "easy": [ 234 | "COCO_val2014_000000568710", 235 | "COCO_val2014_000000088092", 236 | "COCO_val2014_000000308645", 237 | "COCO_val2014_000000457449", 238 | "COCO_val2014_000000470313", 239 | "COCO_val2014_000000036661", 240 | "COCO_val2014_000000293133", 241 | "COCO_val2014_000000407149", 242 | "COCO_val2014_000000280972", 243 | "COCO_val2014_000000039956", 244 | "COCO_val2014_000000506515", 245 | "COCO_val2014_000000072813", 246 | "COCO_val2014_000000439472", 247 | "COCO_val2014_000000235672", 248 | "COCO_val2014_000000262161", 249 | "COCO_val2014_000000041247", 250 | "COCO_val2014_000000315219", 251 | "COCO_val2014_000000108898", 252 | "COCO_val2014_000000343820", 253 | "COCO_val2014_000000005154" 254 | ], 255 | "hard": [ 256 | "COCO_val2014_000000153574", 257 | "COCO_val2014_000000161940", 258 | "COCO_val2014_000000469543", 259 | "COCO_val2014_000000097693", 260 | "COCO_val2014_000000259952", 261 | "COCO_val2014_000000463618", 262 | "COCO_val2014_000000412878", 263 | "COCO_val2014_000000281628", 264 | "COCO_val2014_000000491497", 265 | "COCO_val2014_000000217071", 266 | "COCO_val2014_000000269918", 267 | "COCO_val2014_000000188460", 268 | "COCO_val2014_000000424879", 269 | "COCO_val2014_000000512276", 270 | "COCO_val2014_000000232790", 271 | "COCO_val2014_000000006921", 272 | "COCO_val2014_000000371555", 273 | "COCO_val2014_000000096539", 274 | "COCO_val2014_000000159463", 275 | 
"COCO_val2014_000000005154" 276 | ], 277 | "medium": [ 278 | "COCO_val2014_000000415048", 279 | "COCO_val2014_000000127337", 280 | "COCO_val2014_000000281628", 281 | "COCO_val2014_000000393777", 282 | "COCO_val2014_000000077648", 283 | "COCO_val2014_000000201727", 284 | "COCO_val2014_000000497388", 285 | "COCO_val2014_000000003716", 286 | "COCO_val2014_000000433423", 287 | "COCO_val2014_000000531622", 288 | "COCO_val2014_000000117112", 289 | "COCO_val2014_000000422833", 290 | "COCO_val2014_000000349414", 291 | "COCO_val2014_000000380487", 292 | "COCO_val2014_000000052132", 293 | "COCO_val2014_000000447602", 294 | "COCO_val2014_000000431306", 295 | "COCO_val2014_000000509750", 296 | "COCO_val2014_000000565582", 297 | "COCO_val2014_000000005154" 298 | ] 299 | }, 300 | "target": "COCO_val2014_000000005154" 301 | }, 302 | { 303 | "gen_caption": "a bowl of fruit is sitting on a table", 304 | "gt_caption": "A bunch of mini bananas are hanging above a bowl of apples and oranges.", 305 | "obj": "banana", 306 | "poolID": "37b38345-afe8-4c33-cc50-0ae7c9dce074", 307 | "pools": { 308 | "easy": [ 309 | "COCO_val2014_000000416786", 310 | "COCO_val2014_000000118929", 311 | "COCO_val2014_000000174239", 312 | "COCO_val2014_000000375426", 313 | "COCO_val2014_000000335148", 314 | "COCO_val2014_000000439868", 315 | "COCO_val2014_000000243190", 316 | "COCO_val2014_000000264625", 317 | "COCO_val2014_000000089258", 318 | "COCO_val2014_000000399744", 319 | "COCO_val2014_000000191981", 320 | "COCO_val2014_000000092020", 321 | "COCO_val2014_000000441496", 322 | "COCO_val2014_000000440562", 323 | "COCO_val2014_000000568893", 324 | "COCO_val2014_000000171805", 325 | "COCO_val2014_000000141040", 326 | "COCO_val2014_000000173530", 327 | "COCO_val2014_000000137479", 328 | "COCO_val2014_000000509131" 329 | ], 330 | "hard": [ 331 | "COCO_val2014_000000444694", 332 | "COCO_val2014_000000133161", 333 | "COCO_val2014_000000238866", 334 | "COCO_val2014_000000095062", 335 | "COCO_val2014_000000406570", 336 | "COCO_val2014_000000441323", 337 | "COCO_val2014_000000185901", 338 | "COCO_val2014_000000130171", 339 | "COCO_val2014_000000320972", 340 | "COCO_val2014_000000158708", 341 | "COCO_val2014_000000157270", 342 | "COCO_val2014_000000442301", 343 | "COCO_val2014_000000211243", 344 | "COCO_val2014_000000338581", 345 | "COCO_val2014_000000402346", 346 | "COCO_val2014_000000256973", 347 | "COCO_val2014_000000455746", 348 | "COCO_val2014_000000060855", 349 | "COCO_val2014_000000042215", 350 | "COCO_val2014_000000509131" 351 | ], 352 | "medium": [ 353 | "COCO_val2014_000000391320", 354 | "COCO_val2014_000000392850", 355 | "COCO_val2014_000000369379", 356 | "COCO_val2014_000000455746", 357 | "COCO_val2014_000000507318", 358 | "COCO_val2014_000000048795", 359 | "COCO_val2014_000000200862", 360 | "COCO_val2014_000000177449", 361 | "COCO_val2014_000000513688", 362 | "COCO_val2014_000000354041", 363 | "COCO_val2014_000000350639", 364 | "COCO_val2014_000000484069", 365 | "COCO_val2014_000000405815", 366 | "COCO_val2014_000000259952", 367 | "COCO_val2014_000000003711", 368 | "COCO_val2014_000000020179", 369 | "COCO_val2014_000000028002", 370 | "COCO_val2014_000000466079", 371 | "COCO_val2014_000000249443", 372 | "COCO_val2014_000000509131" 373 | ] 374 | }, 375 | "target": "COCO_val2014_000000509131" 376 | }, 377 | { 378 | "gen_caption": "a giraffe standing in the grass near a tree", 379 | "gt_caption": "Two giraffes that are standing in the dirt.", 380 | "obj": "giraffe", 381 | "poolID": "cdd58b79-079e-4928-c929-a406ce62cc8e", 382 | 
"pools": { 383 | "easy": [ 384 | "COCO_val2014_000000554892", 385 | "COCO_val2014_000000423723", 386 | "COCO_val2014_000000448448", 387 | "COCO_val2014_000000402559", 388 | "COCO_val2014_000000076454", 389 | "COCO_val2014_000000032703", 390 | "COCO_val2014_000000328504", 391 | "COCO_val2014_000000164095", 392 | "COCO_val2014_000000366115", 393 | "COCO_val2014_000000404262", 394 | "COCO_val2014_000000254965", 395 | "COCO_val2014_000000548296", 396 | "COCO_val2014_000000514118", 397 | "COCO_val2014_000000109888", 398 | "COCO_val2014_000000033707", 399 | "COCO_val2014_000000531607", 400 | "COCO_val2014_000000417876", 401 | "COCO_val2014_000000122007", 402 | "COCO_val2014_000000351787", 403 | "COCO_val2014_000000085005" 404 | ], 405 | "hard": [ 406 | "COCO_val2014_000000086559", 407 | "COCO_val2014_000000064746", 408 | "COCO_val2014_000000040896", 409 | "COCO_val2014_000000442894", 410 | "COCO_val2014_000000262900", 411 | "COCO_val2014_000000067587", 412 | "COCO_val2014_000000384848", 413 | "COCO_val2014_000000091857", 414 | "COCO_val2014_000000448426", 415 | "COCO_val2014_000000274331", 416 | "COCO_val2014_000000312316", 417 | "COCO_val2014_000000476109", 418 | "COCO_val2014_000000176306", 419 | "COCO_val2014_000000505538", 420 | "COCO_val2014_000000480345", 421 | "COCO_val2014_000000142189", 422 | "COCO_val2014_000000549766", 423 | "COCO_val2014_000000014271", 424 | "COCO_val2014_000000491481", 425 | "COCO_val2014_000000085005" 426 | ], 427 | "medium": [ 428 | "COCO_val2014_000000505538", 429 | "COCO_val2014_000000289337", 430 | "COCO_val2014_000000086559", 431 | "COCO_val2014_000000480345", 432 | "COCO_val2014_000000202880", 433 | "COCO_val2014_000000448426", 434 | "COCO_val2014_000000491481", 435 | "COCO_val2014_000000517029", 436 | "COCO_val2014_000000359781", 437 | "COCO_val2014_000000285961", 438 | "COCO_val2014_000000577524", 439 | "COCO_val2014_000000448448", 440 | "COCO_val2014_000000186747", 441 | "COCO_val2014_000000373578", 442 | "COCO_val2014_000000011202", 443 | "COCO_val2014_000000426629", 444 | "COCO_val2014_000000431897", 445 | "COCO_val2014_000000374266", 446 | "COCO_val2014_000000496264", 447 | "COCO_val2014_000000085005" 448 | ] 449 | }, 450 | "target": "COCO_val2014_000000085005" 451 | }, 452 | { 453 | "gen_caption": "a man sitting on a couch with a laptop", 454 | "gt_caption": "A living room area with some couches and a television", 455 | "obj": "book", 456 | "poolID": "3fe1e080-ce8d-43e7-c319-bfcfb974b69e", 457 | "pools": { 458 | "easy": [ 459 | "COCO_val2014_000000494751", 460 | "COCO_val2014_000000175310", 461 | "COCO_val2014_000000351489", 462 | "COCO_val2014_000000434976", 463 | "COCO_val2014_000000526767", 464 | "COCO_val2014_000000258266", 465 | "COCO_val2014_000000407149", 466 | "COCO_val2014_000000368980", 467 | "COCO_val2014_000000133679", 468 | "COCO_val2014_000000510340", 469 | "COCO_val2014_000000058364", 470 | "COCO_val2014_000000363767", 471 | "COCO_val2014_000000162156", 472 | "COCO_val2014_000000519329", 473 | "COCO_val2014_000000333430", 474 | "COCO_val2014_000000228644", 475 | "COCO_val2014_000000319452", 476 | "COCO_val2014_000000132901", 477 | "COCO_val2014_000000070240", 478 | "COCO_val2014_000000578498" 479 | ], 480 | "hard": [ 481 | "COCO_val2014_000000085803", 482 | "COCO_val2014_000000349414", 483 | "COCO_val2014_000000159458", 484 | "COCO_val2014_000000244159", 485 | "COCO_val2014_000000560368", 486 | "COCO_val2014_000000570456", 487 | "COCO_val2014_000000347693", 488 | "COCO_val2014_000000343466", 489 | "COCO_val2014_000000453622", 490 | 
"COCO_val2014_000000052282", 491 | "COCO_val2014_000000344548", 492 | "COCO_val2014_000000453824", 493 | "COCO_val2014_000000534733", 494 | "COCO_val2014_000000551963", 495 | "COCO_val2014_000000125778", 496 | "COCO_val2014_000000439143", 497 | "COCO_val2014_000000580410", 498 | "COCO_val2014_000000441415", 499 | "COCO_val2014_000000047775", 500 | "COCO_val2014_000000578498" 501 | ], 502 | "medium": [ 503 | "COCO_val2014_000000367386", 504 | "COCO_val2014_000000269678", 505 | "COCO_val2014_000000091267", 506 | "COCO_val2014_000000463134", 507 | "COCO_val2014_000000444304", 508 | "COCO_val2014_000000202797", 509 | "COCO_val2014_000000528318", 510 | "COCO_val2014_000000348083", 511 | "COCO_val2014_000000303069", 512 | "COCO_val2014_000000337264", 513 | "COCO_val2014_000000223671", 514 | "COCO_val2014_000000286981", 515 | "COCO_val2014_000000436273", 516 | "COCO_val2014_000000276239", 517 | "COCO_val2014_000000085735", 518 | "COCO_val2014_000000417430", 519 | "COCO_val2014_000000148620", 520 | "COCO_val2014_000000162035", 521 | "COCO_val2014_000000395550", 522 | "COCO_val2014_000000578498" 523 | ] 524 | }, 525 | "target": "COCO_val2014_000000578498" 526 | }, 527 | { 528 | "gen_caption": "a train on a track near a platform", 529 | "gt_caption": "A black and yellow train traveling down train tracks.", 530 | "obj": "train", 531 | "poolID": "0971e298-f8a6-40df-c084-4a9d3935a9e8", 532 | "pools": { 533 | "easy": [ 534 | "COCO_val2014_000000151857", 535 | "COCO_val2014_000000231163", 536 | "COCO_val2014_000000137822", 537 | "COCO_val2014_000000113570", 538 | "COCO_val2014_000000241948", 539 | "COCO_val2014_000000472569", 540 | "COCO_val2014_000000070619", 541 | "COCO_val2014_000000178072", 542 | "COCO_val2014_000000506196", 543 | "COCO_val2014_000000335089", 544 | "COCO_val2014_000000125513", 545 | "COCO_val2014_000000547744", 546 | "COCO_val2014_000000145562", 547 | "COCO_val2014_000000318550", 548 | "COCO_val2014_000000545363", 549 | "COCO_val2014_000000209901", 550 | "COCO_val2014_000000308785", 551 | "COCO_val2014_000000256838", 552 | "COCO_val2014_000000021971", 553 | "COCO_val2014_000000347747" 554 | ], 555 | "hard": [ 556 | "COCO_val2014_000000370929", 557 | "COCO_val2014_000000005690", 558 | "COCO_val2014_000000193069", 559 | "COCO_val2014_000000298718", 560 | "COCO_val2014_000000224554", 561 | "COCO_val2014_000000384594", 562 | "COCO_val2014_000000509404", 563 | "COCO_val2014_000000203081", 564 | "COCO_val2014_000000382728", 565 | "COCO_val2014_000000076768", 566 | "COCO_val2014_000000005278", 567 | "COCO_val2014_000000335887", 568 | "COCO_val2014_000000241297", 569 | "COCO_val2014_000000183666", 570 | "COCO_val2014_000000128992", 571 | "COCO_val2014_000000309366", 572 | "COCO_val2014_000000146825", 573 | "COCO_val2014_000000565331", 574 | "COCO_val2014_000000208524", 575 | "COCO_val2014_000000347747" 576 | ], 577 | "medium": [ 578 | "COCO_val2014_000000384594", 579 | "COCO_val2014_000000307598", 580 | "COCO_val2014_000000509404", 581 | "COCO_val2014_000000565331", 582 | "COCO_val2014_000000203081", 583 | "COCO_val2014_000000076768", 584 | "COCO_val2014_000000208524", 585 | "COCO_val2014_000000241297", 586 | "COCO_val2014_000000125815", 587 | "COCO_val2014_000000149731", 588 | "COCO_val2014_000000373089", 589 | "COCO_val2014_000000140292", 590 | "COCO_val2014_000000506196", 591 | "COCO_val2014_000000563641", 592 | "COCO_val2014_000000185472", 593 | "COCO_val2014_000000482585", 594 | "COCO_val2014_000000506416", 595 | "COCO_val2014_000000354229", 596 | "COCO_val2014_000000034950", 597 | 
"COCO_val2014_000000347747" 598 | ] 599 | }, 600 | "target": "COCO_val2014_000000347747" 601 | }, 602 | { 603 | "gen_caption": "a clock tower with a clock on top of it", 604 | "gt_caption": "A large building with a clock on it's face and a bird statue on top.", 605 | "obj": "clock", 606 | "poolID": "e923ec65-fb76-4acc-cac9-ee0a3e7c09cc", 607 | "pools": { 608 | "easy": [ 609 | "COCO_val2014_000000028002", 610 | "COCO_val2014_000000145032", 611 | "COCO_val2014_000000194780", 612 | "COCO_val2014_000000461993", 613 | "COCO_val2014_000000566026", 614 | "COCO_val2014_000000148668", 615 | "COCO_val2014_000000298137", 616 | "COCO_val2014_000000439777", 617 | "COCO_val2014_000000546226", 618 | "COCO_val2014_000000041246", 619 | "COCO_val2014_000000536743", 620 | "COCO_val2014_000000398438", 621 | "COCO_val2014_000000532530", 622 | "COCO_val2014_000000439868", 623 | "COCO_val2014_000000189955", 624 | "COCO_val2014_000000465552", 625 | "COCO_val2014_000000304012", 626 | "COCO_val2014_000000141086", 627 | "COCO_val2014_000000524450", 628 | "COCO_val2014_000000376767" 629 | ], 630 | "hard": [ 631 | "COCO_val2014_000000445638", 632 | "COCO_val2014_000000578786", 633 | "COCO_val2014_000000042471", 634 | "COCO_val2014_000000242644", 635 | "COCO_val2014_000000153283", 636 | "COCO_val2014_000000069873", 637 | "COCO_val2014_000000300086", 638 | "COCO_val2014_000000467249", 639 | "COCO_val2014_000000274864", 640 | "COCO_val2014_000000040037", 641 | "COCO_val2014_000000369506", 642 | "COCO_val2014_000000050965", 643 | "COCO_val2014_000000277046", 644 | "COCO_val2014_000000561412", 645 | "COCO_val2014_000000334244", 646 | "COCO_val2014_000000131431", 647 | "COCO_val2014_000000188132", 648 | "COCO_val2014_000000458918", 649 | "COCO_val2014_000000314607", 650 | "COCO_val2014_000000376767" 651 | ], 652 | "medium": [ 653 | "COCO_val2014_000000131431", 654 | "COCO_val2014_000000578786", 655 | "COCO_val2014_000000301817", 656 | "COCO_val2014_000000153283", 657 | "COCO_val2014_000000280530", 658 | "COCO_val2014_000000050965", 659 | "COCO_val2014_000000369506", 660 | "COCO_val2014_000000458918", 661 | "COCO_val2014_000000111546", 662 | "COCO_val2014_000000573008", 663 | "COCO_val2014_000000266348", 664 | "COCO_val2014_000000161539", 665 | "COCO_val2014_000000456197", 666 | "COCO_val2014_000000022724", 667 | "COCO_val2014_000000013639", 668 | "COCO_val2014_000000474868", 669 | "COCO_val2014_000000297726", 670 | "COCO_val2014_000000513342", 671 | "COCO_val2014_000000423395", 672 | "COCO_val2014_000000376767" 673 | ] 674 | }, 675 | "target": "COCO_val2014_000000376767" 676 | }, 677 | { 678 | "gen_caption": "a desk with a laptop and a desktop computer", 679 | "gt_caption": "A laptop computer on a desk with a wireless keyboard and mouse.", 680 | "obj": "laptop", 681 | "poolID": "eb060aa2-0401-405c-c05f-fe2c0c57f903", 682 | "pools": { 683 | "easy": [ 684 | "COCO_val2014_000000253630", 685 | "COCO_val2014_000000366137", 686 | "COCO_val2014_000000471893", 687 | "COCO_val2014_000000038034", 688 | "COCO_val2014_000000200492", 689 | "COCO_val2014_000000475882", 690 | "COCO_val2014_000000022656", 691 | "COCO_val2014_000000099317", 692 | "COCO_val2014_000000528311", 693 | "COCO_val2014_000000472930", 694 | "COCO_val2014_000000122213", 695 | "COCO_val2014_000000264124", 696 | "COCO_val2014_000000263251", 697 | "COCO_val2014_000000392892", 698 | "COCO_val2014_000000150843", 699 | "COCO_val2014_000000212766", 700 | "COCO_val2014_000000084073", 701 | "COCO_val2014_000000289781", 702 | "COCO_val2014_000000283028", 703 | 
"COCO_val2014_000000195542" 704 | ], 705 | "hard": [ 706 | "COCO_val2014_000000398158", 707 | "COCO_val2014_000000241187", 708 | "COCO_val2014_000000366225", 709 | "COCO_val2014_000000069266", 710 | "COCO_val2014_000000465007", 711 | "COCO_val2014_000000228450", 712 | "COCO_val2014_000000542509", 713 | "COCO_val2014_000000144305", 714 | "COCO_val2014_000000203705", 715 | "COCO_val2014_000000141828", 716 | "COCO_val2014_000000297085", 717 | "COCO_val2014_000000337290", 718 | "COCO_val2014_000000152751", 719 | "COCO_val2014_000000261453", 720 | "COCO_val2014_000000113556", 721 | "COCO_val2014_000000383209", 722 | "COCO_val2014_000000320658", 723 | "COCO_val2014_000000066001", 724 | "COCO_val2014_000000441253", 725 | "COCO_val2014_000000195542" 726 | ], 727 | "medium": [ 728 | "COCO_val2014_000000337290", 729 | "COCO_val2014_000000254986", 730 | "COCO_val2014_000000295768", 731 | "COCO_val2014_000000182755", 732 | "COCO_val2014_000000141828", 733 | "COCO_val2014_000000131138", 734 | "COCO_val2014_000000149222", 735 | "COCO_val2014_000000243456", 736 | "COCO_val2014_000000090239", 737 | "COCO_val2014_000000129576", 738 | "COCO_val2014_000000003716", 739 | "COCO_val2014_000000451284", 740 | "COCO_val2014_000000125778", 741 | "COCO_val2014_000000526418", 742 | "COCO_val2014_000000049473", 743 | "COCO_val2014_000000393421", 744 | "COCO_val2014_000000449296", 745 | "COCO_val2014_000000560119", 746 | "COCO_val2014_000000043939", 747 | "COCO_val2014_000000195542" 748 | ] 749 | }, 750 | "target": "COCO_val2014_000000195542" 751 | } 752 | ] -------------------------------------------------------------------------------- /demo/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/GT-Vision-Lab/GuessWhich/4d883db722a6347c16cb6ddd7b3329b5a5fd439f/demo/__init__.py -------------------------------------------------------------------------------- /demo/asgi.py: -------------------------------------------------------------------------------- 1 | import os 2 | from channels.asgi import get_channel_layer 3 | 4 | os.environ.setdefault("DJANGO_SETTINGS_MODULE", "demo.settings") 5 | 6 | channel_layer = get_channel_layer() 7 | -------------------------------------------------------------------------------- /demo/settings.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | 4 | import numpy as np 5 | 6 | # Build paths inside the project like this: os.path.join(BASE_DIR, ...) 7 | BASE_DIR = os.path.dirname(os.path.dirname(os.path.abspath(__file__))) 8 | 9 | # Set the path of the Lua Script for both SL and RL bots 10 | os.environ['LUA_PATH'] = os.environ['LUA_PATH'] + ":" + os.path.join(BASE_DIR, 'chatbot', 'sl_evaluate.lua') 11 | os.environ['LUA_PATH'] = os.environ['LUA_PATH'] + ":" + os.path.join(BASE_DIR, 'chatbot', 'rl_evaluate.lua') 12 | 13 | # SECURITY WARNING: keep the secret key used in production secret! 14 | SECRET_KEY = '3$zc7zn#v==*r2ukiezqv39g2im4zf!2%53f+h0rga&*=&(7l5' 15 | 16 | # SECURITY WARNING: don't run with debug turned on in production! 
17 | DEBUG = True 18 | 19 | ALLOWED_HOSTS = [] 20 | 21 | 22 | # Application definition 23 | 24 | INSTALLED_APPS = [ 25 | 'django.contrib.admin', 26 | 'django.contrib.auth', 27 | 'django.contrib.contenttypes', 28 | 'django.contrib.sessions', 29 | 'django.contrib.messages', 30 | 'django.contrib.staticfiles', 31 | 'channels', 32 | 'amt', 33 | 'import_export', 34 | ] 35 | 36 | MIDDLEWARE = [ 37 | 'django.middleware.security.SecurityMiddleware', 38 | 'django.contrib.sessions.middleware.SessionMiddleware', 39 | 'django.middleware.common.CommonMiddleware', 40 | 'django.middleware.csrf.CsrfViewMiddleware', 41 | 'django.contrib.auth.middleware.AuthenticationMiddleware', 42 | 'django.contrib.messages.middleware.MessageMiddleware', 43 | # 'django.middleware.clickjacking.XFrameOptionsMiddleware', 44 | ] 45 | 46 | ROOT_URLCONF = 'demo.urls' 47 | 48 | TEMPLATES = [ 49 | { 50 | 'BACKEND': 'django.template.backends.django.DjangoTemplates', 51 | 'DIRS': [], 52 | 'APP_DIRS': True, 53 | 'OPTIONS': { 54 | 'context_processors': [ 55 | 'django.template.context_processors.debug', 56 | 'django.template.context_processors.request', 57 | 'django.contrib.auth.context_processors.auth', 58 | 'django.contrib.messages.context_processors.messages', 59 | ], 60 | }, 61 | }, 62 | ] 63 | 64 | WSGI_APPLICATION = 'demo.wsgi.application' 65 | 66 | 67 | # Database 68 | # https://docs.djangoproject.com/en/1.10/ref/settings/#databases 69 | 70 | DATABASES = { 71 | 'default': { 72 | 'ENGINE': 'django.db.backends.sqlite3', 73 | 'NAME': 'test.db', 74 | } 75 | } 76 | 77 | # Password validation 78 | # https://docs.djangoproject.com/en/1.10/ref/settings/#auth-password-validators 79 | 80 | AUTH_PASSWORD_VALIDATORS = [ 81 | { 82 | 'NAME': 'django.contrib.auth.password_validation.UserAttributeSimilarityValidator', 83 | }, 84 | { 85 | 'NAME': 'django.contrib.auth.password_validation.MinimumLengthValidator', 86 | }, 87 | { 88 | 'NAME': 'django.contrib.auth.password_validation.CommonPasswordValidator', 89 | }, 90 | { 91 | 'NAME': 'django.contrib.auth.password_validation.NumericPasswordValidator', 92 | }, 93 | ] 94 | 95 | 96 | # Internationalization 97 | # https://docs.djangoproject.com/en/1.10/topics/i18n/ 98 | 99 | LANGUAGE_CODE = 'en-us' 100 | 101 | TIME_ZONE = 'UTC' 102 | 103 | USE_I18N = True 104 | 105 | USE_L10N = True 106 | 107 | USE_TZ = True 108 | 109 | 110 | # Static files (CSS, JavaScript, Images) 111 | # https://docs.djangoproject.com/en/1.10/howto/static-files/ 112 | 113 | STATIC_URL = '/static/' 114 | 115 | STATIC_ROOT = os.path.join(BASE_DIR, 'static') 116 | 117 | MEDIA_ROOT = os.path.join(BASE_DIR, 'media') 118 | 119 | MEDIA_URL= "https://vision.ece.vt.edu/mscoco/images/" 120 | 121 | CHANNEL_LAYERS = { 122 | "default": { 123 | "BACKEND": "asgi_redis.RedisChannelLayer", 124 | "CONFIG": { 125 | "hosts": [("localhost", 6379)], 126 | "prefix": u'vicki_redis:', 127 | }, 128 | "ROUTING": "amt.routing.channel_routing", 129 | }, 130 | } 131 | -------------------------------------------------------------------------------- /demo/urls.py: -------------------------------------------------------------------------------- 1 | from django.conf.urls import url, include 2 | from django.contrib import admin 3 | from django.conf import settings 4 | from django.conf.urls.static import static 5 | 6 | urlpatterns = [ 7 | url(r'^admin/', admin.site.urls), 8 | url(r'^', include('amt.urls')), 9 | ] + static(settings.MEDIA_URL, document_root=settings.MEDIA_ROOT) 10 | -------------------------------------------------------------------------------- 
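Both workers report their results by calling `log_to_terminal(socketid, message)` from `amt/utils.py`, which is not reproduced in this section. Given the `CHANNEL_LAYERS` configuration in `demo/settings.py` above (Channels 0.x backed by `asgi_redis`), a minimal sketch of such a helper, under the assumption that `socketid` names the reply channel of the subject's WebSocket, could look like the following; the actual implementation in `amt/utils.py` may differ:

```python
# Hypothetical sketch of amt.utils.log_to_terminal, assuming Channels 0.x
# semantics: `socketid` is the reply channel of a connected WebSocket, and
# the Redis channel layer configured in demo/settings.py carries the message.
import json

from channels import Channel


def log_to_terminal(socketid, message):
    # Serialize the payload and push it over the channel layer to the client.
    Channel(socketid).send({'text': json.dumps(message)})
```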
/demo/wsgi.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | from django.core.wsgi import get_wsgi_application 4 | 5 | os.environ.setdefault("DJANGO_SETTINGS_MODULE", "demo.settings") 6 | 7 | application = get_wsgi_application() 8 | -------------------------------------------------------------------------------- /download_models.sh: -------------------------------------------------------------------------------- 1 | cd chatbot 2 | mkdir data && cd data 3 | 4 | base_url="https://filebox.ece.vt.edu/~deshraj/guesswhich_github/" 5 | 6 | wget "${base_url}chat_processed_params.json" 7 | wget "${base_url}qbot_hre_qih_sl.t7" 8 | wget "${base_url}abot_hre_qih_sl.t7" 9 | wget "${base_url}all_pools_vgg16_features.t7" 10 | wget "${base_url}final_vgg16_pool_features.t7" 11 | wget "${base_url}qbot_rl.t7" 12 | wget "${base_url}abot_rl.t7" 13 | 14 | cd ../.. 15 | -------------------------------------------------------------------------------- /manage.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | import os 3 | import sys 4 | 5 | if __name__ == "__main__": 6 | os.environ.setdefault("DJANGO_SETTINGS_MODULE", "demo.settings") 7 | try: 8 | from django.core.management import execute_from_command_line 9 | except ImportError: 10 | # The above import may fail for some other reason. Ensure that the 11 | # issue is really that Django is missing to avoid masking other 12 | # exceptions on Python 2. 13 | try: 14 | import django 15 | except ImportError: 16 | raise ImportError( 17 | "Couldn't import Django. Are you sure it's installed and " 18 | "available on your PYTHONPATH environment variable? Did you " 19 | "forget to activate a virtual environment?" 20 | ) 21 | raise 22 | execute_from_command_line(sys.argv) 23 | -------------------------------------------------------------------------------- /ques_feat.json: -------------------------------------------------------------------------------- 1 | {"ques": [11106, 6749, 2528, 6069, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], "ques_length": 4} -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | asgi-redis==0.14.1 2 | asgiref==0.14.0 3 | autobahn==0.16.0 4 | backports.ssl-match-hostname==3.5.0.1 5 | channels==0.17.2 6 | daphne==0.15.0 7 | Django==1.10.1 8 | h5py==2.6.0 9 | msgpack-python==0.4.8 10 | MySQL-python==1.2.5 11 | nltk==3.2.1 12 | numpy==1.11.1 13 | pika==0.10.0 14 | Pillow==3.3.0 15 | psycopg2==2.7.1 16 | PyTorch===4.1.1-SNAPSHOT 17 | PyYAML==3.12 18 | redis==2.10.5 19 | requests==2.13.0 20 | scipy==0.17.1 21 | six==1.10.0 22 | tqdm==4.8.4 23 | Twisted==16.4.1 24 | txaio==2.5.1 25 | uWSGI==2.0.13.1 26 | websocket-client==0.37.0 27 | zope.interface==4.3.2 -------------------------------------------------------------------------------- /static/img/guesswhich.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/GT-Vision-Lab/GuessWhich/4d883db722a6347c16cb6ddd7b3329b5a5fd439f/static/img/guesswhich.png --------------------------------------------------------------------------------
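For reference, the payload format both workers expect can be read off their `callback()` functions: an object with `socketid`, `image_path`, `input_question`, and a `history` string whose rounds are joined by `"||||"`. Below is a minimal sketch of enqueuing one round for the SL bot, assuming RabbitMQ on localhost as in the installation steps; the actual publisher presumably lives in the `amt` app (e.g. `amt/sender.py`), which is not shown here:

```python
# Minimal sketch of publishing one dialog round to the SL worker's queue.
# The payload keys mirror what sl_worker.py's callback() reads; the values
# here are illustrative only.
import json

import pika

connection = pika.BlockingConnection(pika.ConnectionParameters(host='localhost'))
channel = connection.channel()
channel.queue_declare(queue='sl_chatbot_queue', durable=True)

payload = {
    'socketid': 'example-reply-channel',  # where log_to_terminal sends the result
    'image_path': 'media/COCO_val2014_000000514248.jpg',
    'input_question': 'is there a fire hydrant ?',
    'history': '',  # previous rounds joined by "||||", empty on round one
}

channel.basic_publish(
    exchange='',
    routing_key='sl_chatbot_queue',
    body=json.dumps(payload),  # the worker parses this with yaml.safe_load
    properties=pika.BasicProperties(delivery_mode=2),  # persistent message
)
connection.close()
```

The worker replies asynchronously over the Redis channel layer, so nothing comes back on this connection; the result arrives on the WebSocket identified by `socketid`.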