├── .gitignore
├── LICENSE
├── README.md
├── images
│   ├── NLRLongerNoBanner.gif
│   ├── NLRgifActuallyGif.gif
│   ├── NaturalLanguageRecommendationsUpdated.png
│   ├── NaturalLanguageRecommendationsUpdated2.png
│   ├── Our Paper Comparison.jpg
│   ├── Relevant Paper Comparison.jpg
│   ├── SampleNLRresults.JPG
│   ├── TensorBoard.JPG
│   ├── architecturePart1.JPG
│   ├── architecturePart1.PNG
│   ├── architecturePart2.JPG
│   ├── architecturePart2.PNG
│   ├── architecturePart3.JPG
│   ├── architecturePart3.PNG
│   ├── architecturePart4.JPG
│   ├── architecturePart4.PNG
│   └── gif4Github1-1.gif
├── notebooks
│   ├── data
│   │   ├── CreateCSBertTFrecords.ipynb
│   │   ├── CreateCS_tfrecordsDataSet4Bert_github.ipynb
│   │   ├── DataGuideForGithub.ipynb
│   │   ├── PruningCreateEmbeddingDataGithub.ipynb
│   │   ├── medical_preprocessing.ipynb
│   │   └── pruning_first_pass.ipynb
│   ├── export_saved_model.ipynb
│   ├── inference
│   │   ├── DemoNaturalLanguageRecommendationsCPU_Autofeedback.ipynb
│   │   ├── DemoNaturalLanguageRecommendationsCPU_Manualfeedback.ipynb
│   │   ├── DemoNaturalLanguageRecommendationsSimpleDemoCPU.ipynb
│   │   ├── TPUIndexPublic1SecondDemo19p5Million.ipynb
│   │   ├── TpuIndex_build_index_and_search.ipynb
│   │   ├── build_index_and_search.ipynb
│   │   ├── create_abstract_vectors.ipynb
│   │   └── tpu_index_search_million_embeddings.ipynb
│   ├── text2cite_preprocessing.ipynb
│   ├── tfrecords_debug.ipynb
│   ├── tpu_index_debug.ipynb
│   └── training
│       ├── PaperVectorTrainingWord2vec_Github.ipynb
│       ├── TF2.0 Word2Vec CBOW.ipynb
│       └── model.ipynb
└── src
    ├── TFrecordWriter.py
    └── model.py
/.gitignore:
--------------------------------------------------------------------------------
1 | **.ipynb_checkpoints
2 | **tfrecords
3 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2020 Santosh Gupta
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
2 |
3 | # Natural Language Recommendations: A novel research paper search engine developed entirely with embedding and transformer models.
4 |
5 |
6 |
7 |
8 |
9 | ## Try it out, NOW
10 |
11 | https://colab.research.google.com/github/Santosh-Gupta/NaturalLanguageRecommendations/blob/master/notebooks/inference/DemoNaturalLanguageRecommendationsCPU_Autofeedback.ipynb
12 |
13 | Run the first cell of the Colab notebook to download and load the models and data. There are about 8 GB in total to download and load in the first cell, so this cell will take several minutes to run. Once it has finished, the notebook will be ready to take your queries.
14 |
15 | The model was trained with abstracts as input, so it does best on inputs of ~100 words, but it also does pretty well on short, one-sentence queries.
16 |
17 | Note: The Colab notebook above automatically and anonymously records queries, which will be used to improve future versions of our model. If you do not wish to send queries automatically, use this version, which only sends feedback manually:
18 |
19 | https://colab.research.google.com/github/Santosh-Gupta/NaturalLanguageRecommendations/blob/master/notebooks/inference/DemoNaturalLanguageRecommendationsCPU_Manualfeedback.ipynb
20 |
21 |
22 |
23 |
24 |
25 | Results include title, abstract, and Semantic Scholar link to the paper.
26 |
27 | ## Architecture
28 |
29 |
30 |
31 |
32 |
33 | The architecture is one part word2vec, one part Bert as a text encoder. I explored Bert medical text encodings in a previous project [https://github.com/re-search/DocProduct] and was impressed by how effectively they correlated medical questions with answers. In this project, we use the abstract of each paper as the input, but instead of using another Bert encoding as the label, we use a vector that was trained using word2vec. The Semantic Scholar Corpus [https://api.semanticscholar.org/corpus/] contains 179 million papers, and for each paper it lists the IDs of the papers it cites and of the papers that cite it.
34 |
35 | This citation network can be trained using the word2vec algorithm: each embedding represents a paper, and for each paper, the embeddings of its citations and references act as the 'context'.
36 |
37 |
38 |
39 |
40 |
41 |
42 | Our word2vec training notebooks can be found here https://github.com/Santosh-Gupta/NaturalLanguageRecommendations/tree/master/notebooks/training
43 |
44 | Next, the abstracts are fed into Bert. The embeddings of the last hidden layer are mean-pooled into a single 768-dimensional vector, which is then fed into a fully connected layer whose output is a 512-dimensional vector. At the same time, each paper's word2vec paper vector is fed into a separate fully connected layer, which also outputs 512 dimensions. We picked 512 as the word2vec embedding size because the word-embedding literature suggests that embedding quality sometimes decreases beyond 512 dimensions, so we chose the highest dimension we could (to stay closer to Bert's 768-dimensional hidden layer) without risking lower-quality embeddings. We aren't very confident in this choice, since the distribution of the paper data is quite different from that of words in text: regular word2vec training has 5-6 figures of labels, many of which occur frequently throughout the data, whereas the paper data has 7-8 figures of labels, each occurring much less frequently.
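
Here's a minimal NumPy sketch of the two projection towers described above. It is an illustration under stated assumptions, not our trained model: the random weights, the masking of padding tokens during mean pooling, the lack of an activation in the dense layers, and the helper names `mean_pool` and `dense` are all hypothetical. The real model uses HuggingFace's TF2.0 Bert and Keras layers.

```python
import numpy as np

HIDDEN = 768  # Bert hidden size
EMBED = 512   # size of the word2vec paper vectors and of both tower outputs

def mean_pool(last_hidden, mask):
    """Average the last hidden states over real (non-padding) tokens.

    last_hidden: [batch, seq_len, HIDDEN]; mask: [batch, seq_len] of 0/1.
    Ignoring padding here is an assumption of this sketch.
    """
    mask = mask[:, :, None].astype(last_hidden.dtype)
    summed = (last_hidden * mask).sum(axis=1)
    counts = np.maximum(mask.sum(axis=1), 1.0)  # avoid division by zero
    return summed / counts

def dense(x, w, b):
    """Hypothetical fully connected layer (no activation)."""
    return x @ w + b

rng = np.random.default_rng(0)
# Hypothetical, randomly initialised weights for the two towers.
w_text, b_text = rng.normal(scale=0.02, size=(HIDDEN, EMBED)), np.zeros(EMBED)
w_paper, b_paper = rng.normal(scale=0.02, size=(EMBED, EMBED)), np.zeros(EMBED)

batch, seq_len = 4, 16
last_hidden = rng.normal(size=(batch, seq_len, HIDDEN))  # stand-in for Bert's output
mask = np.ones((batch, seq_len))
mask[:, 10:] = 0                                         # last 6 tokens are padding
paper_vectors = rng.normal(size=(batch, EMBED))          # stand-in for word2vec vectors

text_emb = dense(mean_pool(last_hidden, mask), w_text, b_text)  # [batch, 512]
paper_emb = dense(paper_vectors, w_paper, b_paper)              # [batch, 512]
print(text_emb.shape, paper_emb.shape)
```

Both towers end in 512 dimensions, so their outputs can be compared directly with dot products in the loss.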
45 |
46 |
47 |
48 |
49 |
50 |
51 | The notebook we used to convert the abstracts to Bert input IDs and to write the input IDs and paper vectors to TFRecord files can be found here:
52 |
53 | https://github.com/Santosh-Gupta/NaturalLanguageRecommendations/blob/master/notebooks/data/CreateCS_tfrecordsDataSet4Bert_github.ipynb
54 |
55 | We wanted to use negative sampling in our training: in each batch, every label acts as a negative label for the training examples it does not belong to. This is tricky because we wanted the samples in each batch to be chosen at random, but our dataset was too large to fit into RAM, so it was split across multiple files, with only a few loaded into memory at a time. Luckily, the tf.data API made this easy.
56 |
57 | ```
58 | with strategy.scope():  # strategy, parse_example, batch_size and autotune are defined elsewhere in the notebook
59 |     train_files = tf.data.Dataset.list_files(tfrecords_pattern_train)  # shuffles file names by default
60 |     train_dataset = train_files.interleave(tf.data.TFRecordDataset,
61 |                                            cycle_length=32,  # read from 32 files concurrently
62 |                                            block_length=4,   # take 4 consecutive records per file per turn
63 |                                            num_parallel_calls=autotune)
64 |     train_dataset = train_dataset.map(parse_example, num_parallel_calls=autotune)
65 |     train_dataset = train_dataset.batch(batch_size, drop_remainder=True)  # static batch shape for TPUs
66 |     train_dataset = train_dataset.repeat()
67 |     train_dataset = train_dataset.prefetch(autotune)
68 |
69 |     val_files = tf.data.Dataset.list_files(tfrecords_pattern_val)
70 |     val_dataset = val_files.interleave(tf.data.TFRecordDataset,
71 |                                        cycle_length=32,
72 |                                        block_length=4,
73 |                                        num_parallel_calls=autotune)
74 |     val_dataset = val_dataset.map(parse_example, num_parallel_calls=autotune)
75 |     val_dataset = val_dataset.batch(batch_size, drop_remainder=True)
76 |     val_dataset = val_dataset.repeat()
77 |     val_dataset = val_dataset.prefetch(autotune)
78 |
79 | ```
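
The interleave step is what gives each batch a mix of papers from many files. Below is a simplified, sequential pure-Python model of its ordering; it ignores parallelism and is not tf.data's actual implementation, and `interleave` here is our own hypothetical helper. Combined with `list_files`, which shuffles file names by default, consecutive groups of `block_length` records come from different files.

```python
from itertools import islice

def interleave(files, cycle_length, block_length):
    """Simplified, sequential model of tf.data interleave ordering.

    Up to `cycle_length` files are open at once. Each visit to a slot takes
    up to `block_length` records from that slot's file; an exhausted file
    frees its slot, which is refilled with the next unopened file on the
    slot's next visit.
    """
    pending = iter(files)
    slots = [None] * cycle_length
    inputs_left = True
    out = []
    slot = 0
    while inputs_left or any(s is not None for s in slots):
        if slots[slot] is None and inputs_left:
            nxt = next(pending, None)
            if nxt is None:
                inputs_left = False
            else:
                slots[slot] = iter(nxt)
        if slots[slot] is not None:
            block = list(islice(slots[slot], block_length))
            out.extend(block)
            if len(block) < block_length:
                slots[slot] = None  # file exhausted: free the slot
        slot = (slot + 1) % cycle_length
    return out

# Three toy "TFRecord files" of different lengths.
files = [[f"a{i}" for i in range(8)], [f"b{i}" for i in range(8)], [f"c{i}" for i in range(4)]]
order = interleave(files, cycle_length=2, block_length=4)
print(order[:8])  # → ['a0', 'a1', 'a2', 'a3', 'b0', 'b1', 'b2', 'b3']
```

With `cycle_length=32` and `block_length=4` as in our pipeline, each run of 4 records comes from a different file among 32 open ones, so a batch mixes papers from many files even though only a few files are read at a time.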
80 |
81 |
82 |
83 |
84 |
85 | Another challenge we ran into was the training time. We were developing this project for the TFWorld hackathon [https://tfworld.devpost.com/], whose deadline was Dec 31st, but we had only finished processing the data a few days before. We had 1.26 million training examples, and our architecture contained a whole Bert model, which is *not super fast to train*. Luckily, we had access to TPUs, which were ultrafast: **each epoch took only 20-30 minutes!** Not only were we able to complete training on the data, but we were also able to run several hyperparameter experiments before the deadline.
86 |
87 | ```
88 | try:
89 |     tpu = tf.distribute.cluster_resolver.TPUClusterResolver('srihari-1-tpu')  # TPU detection
90 |     print('Running on TPU ', tpu.cluster_spec().as_dict()['worker'])
91 | except ValueError:
92 |     tpu = None
93 | strategy = tf.distribute.get_strategy()  # fallback: default strategy when no TPU is found
94 | if tpu:
95 |     tf.config.experimental_connect_to_cluster(tpu)
96 |     tf.tpu.experimental.initialize_tpu_system(tpu)
97 |     strategy = tf.distribute.experimental.TPUStrategy(tpu)
98 | ```
99 |
123 | ```
124 | with strategy.scope():
125 |     model = create_model(drop_out=0.20)  # create_model and loss_fn are defined elsewhere in the notebook
126 |     model.compile(loss=loss_fn,
127 |                   optimizer=tf.keras.optimizers.Adam(learning_rate=3e-5))
128 | ```
129 |
130 |
131 |
132 | The really fun part was using TensorBoard, which lets you watch training metrics and results in real time.
133 |
134 | https://tensorboard.dev/experiment/rPYkizsLTWOpua3cyePkIg/#scalars
135 |
136 | https://tensorboard.dev/experiment/dE1MpRHvSd2XMltMrwqbeA/#scalars
137 |
138 |
139 |
140 |
141 |
142 | The model training notebook can be found here:
143 |
144 | https://github.com/Santosh-Gupta/NaturalLanguageRecommendations/blob/master/notebooks/training/model.ipynb
145 |
146 | Watching the first TensorBoard training run was like watching a NASA launch. By then, we had spent nearly 2 months on the project. There was some worry that the data might not train well: there may have been something wrong with the data (which happened the first time we trained word2vec), or we might have picked the wrong hyperparameters, etc. We all sat around, nervously waiting for each 20-minute epoch increment, hoping the validation loss would go down. **And then it did.** And then it did again, and again. And again.
147 |
148 | After passing through the fully connected layers, each abstract embedding in the batch is dot-multiplied with every paper-vector embedding in the batch. For each training example, a softmax is taken over its dot products, and the cross-entropy loss is calculated on these logits, with a label of 1 for the original input/output pair of that training example and 0 for all other combinations.
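
Here's a minimal NumPy sketch of this in-batch loss. It assumes row `i` of `text_emb` (the projected abstract) and row `i` of `paper_emb` (the projected paper vector) come from the same training example, so the positive pairs sit on the diagonal of the score matrix; the actual `loss_fn` in model.ipynb may differ in details such as scaling or reduction.

```python
import numpy as np

def in_batch_softmax_loss(text_emb, paper_emb):
    """Softmax cross-entropy over in-batch dot products.

    Row i of `logits` scores abstract i against every paper vector in the
    batch; column i is the positive (label 1), all other columns are negatives.
    """
    logits = text_emb @ paper_emb.T                      # [batch, batch]
    logits = logits - logits.max(axis=1, keepdims=True)  # numerical stability
    log_probs = logits - np.log(np.exp(logits).sum(axis=1, keepdims=True))
    return -np.mean(np.diag(log_probs))                  # mean over the batch

rng = np.random.default_rng(0)
paper_emb = rng.normal(size=(8, 512))
matched = paper_emb + 0.01 * rng.normal(size=(8, 512))  # each abstract close to its own paper
mismatched = np.roll(matched, 1, axis=0)                 # every abstract paired with the wrong paper
print(in_batch_softmax_loss(matched, paper_emb))     # close to 0
print(in_batch_softmax_loss(mismatched, paper_emb))  # large
```

Because every other paper in the batch serves as a negative, larger batches give each example more negatives at no extra sampling cost.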
149 |
150 |
151 |
152 |
153 |
154 | Putting it all together
155 |
156 |
157 |
158 |
159 |
160 | ## Paper Data
161 | The papers used for this project were cleaned from Semantic Scholar's Open Corpus.
162 | Link to the cleaned data used: https://drive.google.com/open?id=1PcdLDJUXoVXorlCTcGlM98GllArk5Z9s
163 |
164 | ## Motivation
165 |
166 | Scientific information retrieval has been my biggest fascination for several years now (and now some of our members share the same interest!). It started with my research positions in biomedical research, where one of the greatest areas of friction was the difficulty of finding all the research relevant to my projects. This is a very common issue among researchers, especially in chem/bio/medical research, due to the huge variation in terms and phrasing.
167 |
168 | To CS people, I use this example to describe what it’s like searching for information in chem/bio: imagine that StackOverflow doesn’t exist, there’s no unified documentation for any platform, framework, or library, and all the available documentation varies in terminology and phrasing. Imagine how slow development would be under these circumstances. Imagine what the state of the internet, software, and hardware would be. That’s the type of friction that research in chemistry and biology is dealing with right now; the world is missing out on a ton of amazing scientific progress because of it.
169 |
170 | Many times, I stumbled upon a very relevant paper months after I had completed the project it was relevant to. Not only does this type of friction slow down research, it also stifles researchers' creativity and imagination in pursuing their goals.
171 |
172 | The latest advancements in NLP have the potential to significantly reduce this sort of friction. Vector representations of queries and documents reduce the dependence on a particular keyword or phrase for robust information retrieval. Vector representations are already being incorporated into information retrieval systems at the highest levels: earlier this year, Google announced that it is incorporating Bert into its main search engine, affecting up to 10% of all search results.
173 |
174 | I think the potential for significantly accelerating scientific research makes this field very much worth pursuing. I have seen directly what the world is missing out on, and I suggest that anyone looking for a particular focus in NLP consider scientific information retrieval. But you don't have to take my word for it: in 2017, IBM Watson found 96 cases of relevant treatment options in patients that doctors had overlooked [https://bigthink.com/stephen-johnson/ibms-watson-supercomputer-found-treatments-for-323-cancer-patients-that-human-experts-overlooked]
175 |
176 | I feel that it's important to pursue as many varied information retrieval techniques/models as possible. Although many of these models will overlap, the most important aspect is whether a model can find papers that the other models left behind. This becomes increasingly important for topics that are very difficult to search for. And often, a single paper can have a huge impact on the direction of a project.
177 |
178 | For the Semantic Scholar Corpus, we found a unique way of modeling information retrieval. The corpus has both citation network data and abstracts, so we were able to correlate text encodings with the citation network.
179 |
180 | ## Our Amazing Chaotic Journey (How we did it)
181 |
182 | #### Step 1: Filter the Semantic Scholar Corpus
183 |
184 | The Semantic Scholar Corpus contains about 178 million papers in a variety of subjects. We don't have the computing power to process the whole dataset (yet; if you know anything about model parallelism, please contact us), so we're focusing on subsets of the corpus.
185 |
186 | Our filters for distilling CS/Math/Physics papers from the corpus are here (warning: huge notebook):
187 |
188 | https://github.com/Santosh-Gupta/NaturalLanguageRecommendations/blob/master/notebooks/text2cite_preprocessing.ipynb
189 |
190 | We are also currently working on a subset that contains only Medline/Pubmed papers:
191 |
192 | https://github.com/Santosh-Gupta/NaturalLanguageRecommendations/blob/master/notebooks/data/medical_preprocessing.ipynb
193 |
194 | Our filtering isn't perfect; some papers in our subsets shouldn't be there.
195 |
196 | #### Step 2: Pruning, and creating embeddings dataset
197 |
198 | Each paper has a list of references and citations. We only want papers that cite or are cited by at least one other paper in our dataset (otherwise their embeddings would never get a chance to be trained in word2vec), so we prune out all other papers. Next, we map each paper to a unique embedding ID, save the citation data, and create an HDF5 dataset to be used for word2vec training.
199 |
200 | https://github.com/Santosh-Gupta/NaturalLanguageRecommendations/blob/master/notebooks/data/PruningCreateEmbeddingDataGithub.ipynb
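
The pruning and ID-mapping logic can be sketched in pure Python. This is a hypothetical illustration, not the notebook's code: the input layout borrows the Open Corpus record fields `inCitations`/`outCitations` as we understand them, treating each citation as an undirected link is an assumption of this sketch, and the HDF5 writing step is omitted.

```python
from collections import defaultdict

def prune_and_index(papers):
    """Drop papers with no link to another paper in the dataset and map
    the remaining papers to contiguous embedding IDs.

    papers: {paper_id: {"inCitations": [...], "outCitations": [...]}}
    Returns (paper_to_id, context), where context maps each embedding ID to
    the sorted embedding IDs of its in-dataset citations/references.
    """
    neighbours = defaultdict(set)
    for pid, rec in papers.items():
        for other in rec["inCitations"] + rec["outCitations"]:
            if other in papers and other != pid:
                neighbours[pid].add(other)
                neighbours[other].add(pid)  # a citation links both papers
    kept = sorted(neighbours)  # only papers with at least one in-dataset link
    paper_to_id = {pid: i for i, pid in enumerate(kept)}
    context = {paper_to_id[p]: sorted(paper_to_id[n] for n in neighbours[p]) for p in kept}
    return paper_to_id, context

papers = {
    "A": {"inCitations": [], "outCitations": ["B", "X"]},  # X is not in the dataset
    "B": {"inCitations": [], "outCitations": []},          # linked only via A's reference
    "C": {"inCitations": ["Y"], "outCitations": []},       # only links outside: pruned
    "D": {"inCitations": [], "outCitations": ["A"]},
}
paper_to_id, context = prune_and_index(papers)
print(paper_to_id)  # → {'A': 0, 'B': 1, 'D': 2}
print(context)      # → {0: [1, 2], 1: [0], 2: [0]}
```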
201 |
202 | #### Step 3: Word2vec
203 |
204 | We apply word2vec training to the citation/reference network data. The 'context' for each paper is 4 of its reference/citation papers, chosen at random. The issue with training an embedding for each paper is that we have a lot of papers: our CS dataset contains 1.26 million papers (whereas word embedding vocabularies usually number only 5-6 figures), and our Medline/Pubmed dataset contains 15 million papers.
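
Sampling the context for one training example could look like the following sketch. The CBOW-style (context → target) framing follows the repo's CBOW notebook name, but resampling with replacement when a paper has fewer than 4 links, and the function name `sample_example`, are our assumptions; the SpeedTorch training notebook may sample differently.

```python
import random

CONTEXT_SIZE = 4  # number of reference/citation papers used as context

def sample_example(paper_id, context, rng):
    """Return (context_ids, target_id) for one CBOW-style training example.

    context: {paper_id: [linked paper ids]}, e.g. from the pruning step.
    """
    linked = context[paper_id]
    if len(linked) >= CONTEXT_SIZE:
        ctx = rng.sample(linked, CONTEXT_SIZE)  # 4 distinct links, chosen at random
    else:
        ctx = [rng.choice(linked) for _ in range(CONTEXT_SIZE)]  # resample to fill 4 slots
    return ctx, paper_id

rng = random.Random(0)
context = {0: [1, 2, 3, 4, 5], 1: [0], 2: [0, 3]}
for pid in context:
    print(sample_example(pid, context, rng))
```

Fixing the context size at 4 keeps every training example the same shape, which makes batching straightforward.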
205 |
206 | We were looking into model parallelism at the time, but with the deadline coming up, we decided to use SpeedTorch instead. We also have a TF2.0 version of Word2Vec implemented in Keras [here](https://github.com/Santosh-Gupta/NaturalLanguageRecommendations/blob/master/notebooks/training/TF2.0%20Word2Vec%20CBOW.ipynb).
207 |
208 | https://github.com/Santosh-Gupta/SpeedTorch
209 |
210 | SpeedTorch is a library for faster CPU<->GPU transfers. We used it to host some of the embeddings on the CPU whenever they weren't being trained.
211 |
212 | https://github.com/Santosh-Gupta/NaturalLanguageRecommendations/blob/master/notebooks/training/PaperVectorTrainingWord2vec_Github.ipynb
213 |
214 | #### Step 4: Create the Bert Dataset
215 |
216 | After word2vec training, we have a citation embedding which represents each paper. We can then use this vector as a label for the mean-pooled output of the last hidden states of Bert, with the input being each paper's abstract. We used the SciBert vocab for the tokenizer since SciBert was trained on many of the papers in the Semantic Scholar Corpus.
217 |
218 | https://github.com/allenai/scibert
219 |
220 | https://github.com/Santosh-Gupta/NaturalLanguageRecommendations/blob/master/notebooks/data/CreateCS_tfrecordsDataSet4Bert_github.ipynb
221 |
222 | We saved the dataset as TFRecord files, which work great with the tf.data API and TPU training.
223 |
224 | https://www.tensorflow.org/tutorials/load_data/tfrecord
225 |
226 | #### Step 5: Training Bert
227 |
228 | Using the dataset created in Step 4, we can train our Bert model and our similarity fully connected layers. Please see the Architecture section for more details. We used the TF2.0 Keras version of HuggingFace's Transformers library for Bert.
229 |
230 | https://github.com/huggingface/transformers
231 |
232 | We also used Keras for the overall architecture, and we initialized Bert with the SciBert weights: https://github.com/allenai/scibert
233 |
234 | We used tf.data to handle our data pipeline, and a TPU v3-8 provided by the TensorFlow Research Cloud to train on our data.
235 |
236 | https://github.com/Santosh-Gupta/NaturalLanguageRecommendations/blob/master/notebooks/training/model.ipynb
237 |
238 | #### Step 6: Inference
239 |
240 | At inference time, a user inputs text that our model converts into a text similarity vector (through Bert and its fully connected layer), and a similarity search is performed against all of our papers' citation similarity vectors. While testing the embeddings, we found that the abstract similarity vectors also give great results, so we decided to search against both and return both sets of results.
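
The two-index search can be sketched as brute-force cosine similarity in NumPy. This is a hypothetical illustration rather than the demo notebooks' code: `query` stands in for the 512-dimensional text similarity vector, the two random matrices stand in for the papers' citation and abstract similarity vectors, and `top_k_cosine`/`search_both` are our own names.

```python
import numpy as np

def top_k_cosine(query, matrix, k):
    """Indices and scores of the k rows of `matrix` most cosine-similar to `query`."""
    q = query / np.linalg.norm(query)
    m = matrix / np.linalg.norm(matrix, axis=1, keepdims=True)
    scores = m @ q
    idx = np.argsort(-scores)[:k]  # highest similarity first
    return idx, scores[idx]

def search_both(query_vec, citation_vecs, abstract_vecs, k=5):
    """Search the citation and abstract indexes and return both result lists."""
    return {
        "citation": top_k_cosine(query_vec, citation_vecs, k),
        "abstract": top_k_cosine(query_vec, abstract_vecs, k),
    }

rng = np.random.default_rng(0)
citation_vecs = rng.normal(size=(1000, 512))  # hypothetical paper vectors
abstract_vecs = rng.normal(size=(1000, 512))
query = abstract_vecs[42] + 0.1 * rng.normal(size=512)  # query close to paper 42's abstract
results = search_both(query, citation_vecs, abstract_vecs)
print(results["abstract"][0][0])  # → 42
```

For millions of papers, the full sort would usually be replaced by `np.argpartition` or a dedicated index such as TPU-Index.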
241 |
242 | Our simple inference notebook can be found here
243 | https://github.com/Santosh-Gupta/NaturalLanguageRecommendations/blob/master/notebooks/inference/DemoNaturalLanguageRecommendationsSimpleDemoCPU.ipynb
244 |
245 | Or, to test directly in colab:
246 |
247 | https://colab.research.google.com/github/Santosh-Gupta/NaturalLanguageRecommendations/blob/master/notebooks/inference/DemoNaturalLanguageRecommendationsCPU_Autofeedback.ipynb
248 |
249 | The notebook above uses Colab forms to hide most of the code; you can double-click any cell to see its code. Inference runs on a CPU.
250 |
251 | For those who would like to test out inference on a GPU or even a TPU, the notebook below automatically detects which type of instance is running at initialization and sets the workers accordingly.
252 |
253 | https://github.com/Santosh-Gupta/NaturalLanguageRecommendations/blob/master/notebooks/inference/build_index_and_search.ipynb
254 |
255 | Colab version:
256 |
257 | https://colab.research.google.com/github/Santosh-Gupta/NaturalLanguageRecommendations/blob/master/notebooks/inference/build_index_and_search.ipynb
258 |
259 | If you like using TPUs to perform similarity search, we created a library just for this (we predict our group is going to need to do this a lot). Here is a notebook that uses our library, TPU-Index, for similarity search:
260 |
261 | https://colab.research.google.com/drive/1wkrilS34nC4kBNEPA1bT0GOJJ8-NRzfJ
262 |
263 | ## Side Quest
264 |
265 |
266 |
267 |
268 |
269 |
270 | Ultrafast indexing, powered by TPUs
271 |
272 |
273 | We plan to eventually run inference on all 179 million papers in the Semantic Scholar Corpus, each of which will have a 512-dimensional vector, which is a ton of papers to run similarity search on. This can be very computationally expensive and time-consuming. There are libraries for this, like Faiss, but as we were getting to know how to utilize TPUs, Srihari came up with the idea of running cosine similarity search on TPUs, and he created a new library for it!
274 |
275 | ```
276 | !pip install tpu-index
277 |
278 | ```
279 |
280 | ```
281 | from tpu_index import TPUIndex
282 |
283 | index = TPUIndex(num_tpu_cores=8)
284 | index.create_index(vectors) # vectors = numpy array, shape == [None, None]
285 |
286 | ...
287 | D, I = index.search(xq, distance_metric='cosine', top_k=5)
288 | ```
289 |
290 | We chose to do this on TPUs for their speed and memory capacity. Currently, the package supports search using cosine similarity, but we plan to extend this to multiple distance metrics.
291 |
292 | Currently, Google Colab has v2-8 TPUs, which have 8 GB per core (64 GB total). This instance can handle about 19 to 22 million float32 embeddings of size 512 (this seems to vary depending on the chunk size we use to append the vectors; we can't pinpoint why). **For 19.5 million embeddings, a single cosine similarity search takes 1.017 seconds.**
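
Some back-of-the-envelope arithmetic for these numbers, assuming 4 bytes per float32 value and reading "8 GB per core" as 8 GiB (64 GiB total). It counts only the raw vectors, not whatever working memory the index needs; `embedding_bytes` is just a helper for this calculation.

```python
BYTES_PER_FLOAT32 = 4
DIM = 512

def embedding_bytes(n_vectors, dim=DIM):
    """Raw storage for n float32 vectors of size `dim` (no index overhead)."""
    return n_vectors * dim * BYTES_PER_FLOAT32

total = embedding_bytes(19_500_000)
chunk = embedding_bytes(750_000)
raw_capacity = 64 * 2**30 // (DIM * BYTES_PER_FLOAT32)  # vectors that fit in 64 GiB, ignoring overhead

print(f"19.5M vectors: {total / 1e9:.1f} GB ({total / 2**30:.1f} GiB)")  # → 39.9 GB (37.2 GiB)
print(f"750k-vector chunk: {chunk / 1e9:.3f} GB")                       # → 1.536 GB
print(f"raw capacity of 64 GiB: {raw_capacity:,} vectors")               # → 33,554,432 vectors
```

Raw storage alone would allow roughly 33.5 million vectors, so the observed ceiling of about 19-22 million suggests the index needs a substantial amount of memory beyond the raw vectors.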
293 |
294 | We recommend adding embeddings of this size in chunks of 750,000; otherwise, a memory error could occur while appending the vectors. We find that smaller chunk sizes may allow the index to hold a larger number of vectors.
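
The chunked-append pattern can be illustrated with a small NumPy stand-in. This is not TPU-Index's API or implementation (see its repository for that); `ChunkedCosineIndex`, `add`, and `search` are hypothetical names, and the sketch only shows how appending in fixed-size chunks keeps each temporary allocation chunk-sized while search merges per-chunk top-k results.

```python
import numpy as np

class ChunkedCosineIndex:
    """Hypothetical in-memory cosine index that stores vectors in chunks."""

    def __init__(self, chunk_size=750_000):
        self.chunk_size = chunk_size
        self.chunks = []  # L2-normalised float32 arrays, one per chunk

    def add(self, vectors):
        vectors = np.asarray(vectors, dtype=np.float32)
        for start in range(0, len(vectors), self.chunk_size):
            chunk = vectors[start:start + self.chunk_size]
            # Normalising one chunk at a time keeps temporary arrays chunk-sized,
            # and appending to a list avoids re-copying everything added so far.
            self.chunks.append(chunk / np.linalg.norm(chunk, axis=1, keepdims=True))

    def search(self, query, top_k=5):
        if not self.chunks:
            return np.empty(0, dtype=np.float32), np.empty(0, dtype=np.int64)
        q = np.asarray(query, dtype=np.float32)
        q = q / np.linalg.norm(q)
        scores, ids, offset = [], [], 0
        for chunk in self.chunks:
            s = chunk @ q  # cosine similarity to every vector in the chunk
            k = min(top_k, len(s))
            best = np.argpartition(-s, k - 1)[:k]  # this chunk's top k (unordered)
            scores.append(s[best])
            ids.append(best + offset)
            offset += len(chunk)
        scores, ids = np.concatenate(scores), np.concatenate(ids)
        order = np.argsort(-scores)[:top_k]  # merge the per-chunk candidates
        return scores[order], ids[order]

rng = np.random.default_rng(0)
data = rng.normal(size=(10_000, 512)).astype(np.float32)
index = ChunkedCosineIndex(chunk_size=3_000)  # small chunks for the demo
index.add(data)
scores, ids = index.search(data[1234], top_k=3)
print(ids[0])  # → 1234
```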
295 |
296 | The package is quite simple to use, check it out:
297 |
298 | https://github.com/srihari-humbarwadi/tpu_index
299 |
300 | https://pypi.org/project/tpu-index/
301 |
302 | https://github.com/Santosh-Gupta/NaturalLanguageRecommendations/blob/master/notebooks/inference/TpuIndex_build_index_and_search.ipynb
303 |
304 | https://github.com/Santosh-Gupta/NaturalLanguageRecommendations/blob/master/notebooks/inference/tpu_index_search_million_embeddings.ipynb
305 |
306 | Test it out with our Colab Notebooks
307 |
308 | Test our model
309 | https://colab.research.google.com/drive/1wkrilS34nC4kBNEPA1bT0GOJJ8-NRzfJ
310 |
311 | Demo of a cos similarity search on 19.5 million float32 embeddings of size 512; average search time 1.017 seconds.
312 | https://colab.research.google.com/drive/1ULxK5esPJVvy7BmQx8j_6koGLMzVEDLy
313 |
314 | ## Case Studies (More coming soon)
315 |
316 | One of the main motivations of this project is to find papers that are highly relevant to a search query. We'll be testing the model over the next couple of weeks and will post interesting case studies here.
317 |
318 | ### Case 1
319 |
320 | One of our members recently found a perfect example of how our solution compares to an established one such as Semantic Scholar. For this [paper](https://www.semanticscholar.org/paper/Job-Prediction%3A-From-Deep-Neural-Network-Models-to-Huynh-Nguyen/f96cae24d992d7bcd44a99baa2ecd80e713271cc#related-papers), titled "Job Prediction: From Deep Neural Network Models to Applications",
321 | Semantic Scholar gives these three papers as 'Relevant Papers':
322 |
323 |
324 |
325 |
326 |
327 | And our model was able to find:
328 |
329 |
330 |
331 |
332 | Although our model also returns results that aren't relevant to using machine learning for job matching and modeling, this example does show our model's ability to find rarer or more obscure papers on subjects with relatively little literature.
333 |
334 | ### Case 2
335 |
336 | query = 'job prediction with machine learning'
337 |
338 | ---Top 5 results for Semantic Scholar---
339 |
340 | https://www.semanticscholar.org/search?q=job%20prediction%20with%20machine%20learning&sort=relevance
341 |
342 | Introduction to machine learning https://www.semanticscholar.org/paper/Introduction-to-machine-learning-Alpaydin/0359bba5112d472206d82ddb29947f2d634bb0cc
343 |
344 | Large-Scale Machine Learning with Stochastic Gradient Descent https://www.semanticscholar.org/paper/Large-Scale-Machine-Learning-with-Stochastic-Bottou/fbc6562814e08e416e28a268ce7beeaa3d0708c8
345 |
346 | Link prediction using supervised learning https://www.semanticscholar.org/paper/Link-prediction-using-supervised-learning-Hasan-Chaoji/413240adfbcb801b5eb186b8a9e67fe77588733c
347 |
348 | Gaussian processes for machine learning https://www.semanticscholar.org/paper/Gaussian-processes-for-machine-learning-Rasmussen-Williams/82266f6103bade9005ec555ed06ba20b5210ff22
349 |
350 | Applications of Machine Learning in Cancer Prediction and Prognosis https://www.semanticscholar.org/paper/Applications-of-Machine-Learning-in-Cancer-and-Cruz-Wishart/7e7b9f37ce280787075046727efbaf9b5a390729
351 |
352 | ---Top results for Natural Language Recommendations---
353 |
354 | Using abstract similarity:
355 |
356 | Bejo: Behavior Based Job Classification for Resource Consumption Prediction in the Cloud https://www.semanticscholar.org/paper/f6913c1d255f236f7c4e2a810425d33256cf3d84
357 |
358 | Random Forest Forecast (RFF): One hour ahead jobs in volunteer grid https://www.semanticscholar.org/paper/c770ccd5ae0809139bc13cc82356f0b132c24433
359 |
360 | Analysis of XDMoD/SUPReMM Data Using Machine Learning Techniques https://www.semanticscholar.org/paper/09af1a0185955c3aea1692972296c697f0c5b7ee
361 |
362 | Job Recommendation System based on Machine Learning and Data Mining Techniques using RESTful API and Android IDE https://www.semanticscholar.org/paper/fe661340e332779f8c40dca713011f0fad938688
363 |
364 | Machine Learning Based Prediction and Classification of Computational Jobs in Cloud Computing Centers https://www.semanticscholar.org/paper/e1b11d29b7bba8b6048439ebbb8ee26700d702a1
365 |
366 | Using citation similarity (these results aren't as good; citation similarity only seems to do well with longer inputs, 100+ words):
367 |
368 | A signal processing method to eliminate grating lobes https://www.semanticscholar.org/paper/0bc3f599347ae37b530b79e7e7458dca8208aef1
369 |
370 | Multi- and Single-output Support Vector Regression for Spectral Reflectance Recovery https://www.semanticscholar.org/paper/c6e0cbbee2745650823407d2237e511fea6578c7
371 |
372 | Space-vector PWM voltage control with optimized switching strategy https://www.semanticscholar.org/paper/795e517f8951daefc920fbec291261374dc9ee14
373 |
374 | Pole position problem for Meixner filters
375 | https://www.semanticscholar.org/paper/179a4bf74953c5111abd4de1f31e0f163d48fd22
376 |
377 | ### Case 3
378 |
379 | query = 'Optimal negative sampling for embedding models. What is the ratio of negative samples to positive examples results in the best quality vectors in noise contrastive estimation.'
380 |
381 | ---Top 5 results for Semantic Scholar (CS results only)---
382 |
383 | https://www.semanticscholar.org/search?q=Optimal%20negative%20sampling%20for%20embedding%20models.%20What%20is%20the%20ratio%20of%20negative%20samples%20to%20positive%20examples%20results%20in%20the%20best%20quality%20vectors%20in%20noise%20contrastive%20estimation.&sort=relevance
384 |
385 | Toward Optimal Active Learning through Sampling Estimation of Error Reduction https://www.semanticscholar.org/paper/Toward-Optimal-Active-Learning-through-Sampling-of-Roy-McCallum/0a20a309deda54fe14580007759c9c7623c58694
386 |
387 | Sampling-based algorithms for optimal motion planning https://www.semanticscholar.org/paper/Sampling-based-algorithms-for-optimal-motion-Karaman-Frazzoli/4326d7e9933c77ff9dc53056c62ef6712d90c633
388 |
389 | Large sample estimation and hypothesis testing https://www.semanticscholar.org/paper/Large-sample-estimation-and-hypothesis-testing-Newey-Mcfadden/3ff91f28967e0702667a644f8f9c53d964d63e4c
390 |
391 | Negative Binomial Regression https://www.semanticscholar.org/paper/Negative-Binomial-Regression-Hilbe/e54fdd22ca9d6c1094db3c0de18b3f184734dd23
392 |
393 | A transformation for ordering multispectral data in terms of image quality with implications for noise removal https://www.semanticscholar.org/paper/A-transformation-for-ordering-multispectral-data-in-Green-Berman/6ae00ebd3a91c0667c79c39035b5163025bcfcad
394 |
395 | ---Top results for Natural Language Recommendations---
396 |
397 | Using abstract similarity:
398 |
399 | Biparti Majority Learning with Tensors https://www.semanticscholar.org/paper/0985d86afbfcd53462f59bd26dd03505c9c09395
400 |
401 | Linear discriminant analysis with an information divergence criterion https://www.semanticscholar.org/paper/1f73769d98a1c661d4ce3877a25d558ef93f66bf
402 |
403 | One-class label propagation using local cone based similarity https://www.semanticscholar.org/paper/7e0c82b3225a12752dd1062292297b6201ca8d6e
404 |
405 | Concave Region Partitioning with a Greedy Strategy on Imbalanced Points https://www.semanticscholar.org/paper/d5bfdac67aec2940c93327bcf5d6e7ee86a70b64
406 |
407 | Noise-Contrastive Estimation Based on Relative Neighbour Sampling for Unsupervised Image Embedding Learning
408 | https://www.semanticscholar.org/paper/9b87f58b620d9de5f360f6dccdcedfffd99c1408
409 |
410 | Using citation similarity:
411 |
412 | Learning from Imbalanced Data https://www.semanticscholar.org/paper/6a97303b92477d95d1e6acf7b443ebe19a6beb60
413 |
414 | Bregman Divergence-Based Regularization for Transfer Subspace Learning https://www.semanticscholar.org/paper/4118b4fc7d61068b9b448fd499876d139baeec81
415 |
416 | The pyramid match kernel: discriminative classification with sets of image features https://www.semanticscholar.org/paper/625bce34ec80d29242340400d916e799d2975430
417 |
418 | Linear Discriminative Sparsity Preserving Projections for Dimensionality Reduction https://www.semanticscholar.org/paper/13e677e2041e688a2b33391f21c163e042e097d9
419 |
420 | Transfer Sparse Coding for Robust Image Representation https://www.semanticscholar.org/paper/afe14b9034f71c7078cd03626853170ef51b8060
421 |
422 | ### Case 4
423 |
424 | query = 'Copula Density Estimation'
425 |
426 | ---Top 5 results for Semantic Scholar---
427 |
428 | https://www.semanticscholar.org/search?q=Copula%20Density%20Estimation&sort=relevance
429 |
430 | Copula Methods in Finance https://www.semanticscholar.org/paper/Copula-Methods-in-Finance-Cherubini-Luciano/f5a07d110482abf5bb537b37d414737d114afa09
431 |
432 | Autoregressive Conditional Density Estimation https://www.semanticscholar.org/paper/Autoregressive-Conditional-Density-Estimation-Hansen/c474cc43d8294ef7340f615a429f5085df624051
433 |
434 | Kernel density estimation via diffusion https://www.semanticscholar.org/paper/Bayesian-Density-Estimation-and-Inference-Using-Escobar-West/df25adb36860c1ad9edaac04b8855a2f19e79c5b
435 |
436 | Bayesian Density Estimation and Inference Using Mixtures https://www.semanticscholar.org/paper/Bayesian-Density-Estimation-and-Inference-Using-Escobar-West/df25adb36860c1ad9edaac04b8855a2f19e79c5b
437 |
438 | Pair-copula constructions of multiple dependence https://www.semanticscholar.org/paper/Pair-copula-constructions-of-multiple-dependence-Aas-Czado/817b6512d3d07ae231d525c366f9a95aa9bdc75a
439 |
440 | ---Top results for Natural Language Recommendations---
441 |
442 | Using abstract similarity:
443 |
444 | On Necessary Conditions for Dependence Parameters of Minimum and Maximum Value Distributions Based on n-Variate FGM Copula https://www.semanticscholar.org/paper/ac2a2521904ca20d1135370581fdc84fbb79e46d
445 |
446 | Conditional Mean and Conditional Variance for Ali-Mikhail-Haq Copula https://www.semanticscholar.org/paper/ed09d9d721a63ca2d2fa5fac945f1e8e96b7b429
447 |
448 | Efficient estimation of high-dimensional multivariate normal copula models with discrete spatial responses https://www.semanticscholar.org/paper/f09557729a65cd87b8bbfd0950125063e06b97da
449 |
450 | Nonparametric estimation of simplified vine copula models: comparison of methods https://www.semanticscholar.org/paper/3e41b0e69342f71ff33791b88eb741c265c1eabf
451 |
452 | On tests of radial symmetry for bivariate copulas https://www.semanticscholar.org/paper/f22539174a7915b68092f27c6b6bc3c91f1fa1b0
453 |
454 | Using citation similarity:
455 |
456 | None of the results using citation similarity were relevant. Again, citation similarity doesn't do well unless the query is over 100 words long.
457 |
458 |
459 |
460 | ## Unfinished Business (future work)
461 |
462 | #### Metrics
463 |
464 | Judging the results just qualitatively. . . they're really, really *Really* good. (But don't take our word for it; try it out. We have [Colab notebooks](https://github.com/Santosh-Gupta/NaturalLanguageRecommendations/tree/master/notebooks/inference) that download the model and the data within a few clicks, and you can use them to search CS papers.) We are looking for ways to back up our qualitative impressions with quantitative metrics. If you have any ideas, please contact us at Research2vec@gmail.com.
465 |
466 | #### Model Variations
467 |
468 | We have gained quite a bit of insight during this project and have several ideas for variations on our model that may further improve the quality of the results, which we are curious to test out.
469 |
470 | #### Bigger/Better Subsets
471 |
472 | Since the Semantic Scholar corpus is so large, we can only work with one subset of subjects at a time. There is currently no way to filter out a particular subset directly, so we have to get creative in how we create our subsets. We are hoping to improve our filtering methods to get more specific and accurate subsets from the corpus.
473 |
474 | We are also hoping to figure out ways to increase the number of parameters we can train with word2vec. Currently, our capacity is around 15 million, and we are aiming for 179 million, which would need a huge amount of memory (possibly around 200 GB) to load all at once. If you have any ideas for this, please get in touch.
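
As a rough check on that memory figure: if each of those items is a row in a dense float32 embedding matrix (our assumption here), the matrix takes rows × dimensions × 4 bytes. The sketch below uses a hypothetical embedding dimension of 256; the dimension is illustrative and not taken from our training code.

```python
def embedding_table_gb(num_rows: int, dim: int, bytes_per_value: int = 4) -> float:
    """Approximate size in GB (10**9 bytes) of one dense embedding matrix."""
    return num_rows * dim * bytes_per_value / 1e9


# Hypothetical settings: 256-dimensional float32 (4-byte) vectors.
for rows in (15_000_000, 179_000_000):
    print(f"{rows:,} rows -> {embedding_table_gb(rows, 256):.1f} GB")
# 15,000,000 rows -> 15.4 GB
# 179,000,000 rows -> 183.3 GB
```

Standard word2vec also keeps a second (context/output) matrix of the same shape, which would roughly double these figures.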
475 |
476 | #### Paper
477 |
478 | We would also like to run more experiments and write up our work to a standard that makes a significant contribution to the field of NLP and could be accepted at a prestigious venue or journal. We are also looking for mentors who have accomplished this. If interested, please contact us at the email posted above.
479 |
480 | ## File Descriptions
481 |
482 | ### Notebooks
483 |
484 | #### [build_index_and_search.ipynb](https://github.com/Santosh-Gupta/NaturalLanguageRecommendations/blob/master/notebooks/inference/build_index_and_search.ipynb)
485 | Description: This notebook loads the trained BERT model, builds the index of 1.3 million papers on TPUs, and runs a demo search.
486 |
487 | #### [tpu_index_search_million_embeddings.ipynb](https://github.com/Santosh-Gupta/NaturalLanguageRecommendations/blob/master/notebooks/inference/tpu_index_search_million_embeddings.ipynb)
488 | Description: A demo notebook showcasing our tpu_index package running search over a million abstract embeddings from the BERT model.
489 |
490 | #### [create_abstract_vectors.ipynb](https://github.com/Santosh-Gupta/NaturalLanguageRecommendations/blob/master/notebooks/inference/create_abstract_vectors.ipynb)
491 | Description: This notebook extracts embeddings for paper abstracts by passing them through the BERT model.
492 |
493 | #### [inference_model.ipynb](https://github.com/Santosh-Gupta/NaturalLanguageRecommendations/blob/master/notebooks/inference/build_index_and_search.ipynb)
494 | Description: This notebook builds the models for the inference phase.
495 |
496 | #### [medical_preprocessing.ipynb](https://github.com/Santosh-Gupta/NaturalLanguageRecommendations/blob/master/notebooks/data/medical_preprocessing.ipynb)
497 | Description: This notebook was used to clean the original Open Corpus dataset, retaining all papers that had either a PubMed ID or a MEDLINE source, and at least one citation (incoming or outgoing). The cleaned medical data is in the folder linked above.
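
The selection rule can be sketched in plain Python over Open Corpus records. The field names (`pmid`, `sources`, `inCitations`, `outCitations`) are the ones the notebook uses; the helper `is_medical_paper` and the toy records are hypothetical. The notebook itself applies the same logic with pandas boolean masks, followed by a separate English-language filter that is not shown here.

```python
def is_medical_paper(record: dict) -> bool:
    """Keep a paper if it has a PubMed ID or a MEDLINE source, and at least one citation."""
    has_pubmed_id = record.get("pmid", "") != ""
    from_medline = any("Medline" in source for source in record.get("sources", []))
    total_citations = len(record.get("inCitations", [])) + len(record.get("outCitations", []))
    return (has_pubmed_id or from_medline) and total_citations >= 1


# Toy records (hypothetical) covering each branch of the rule.
papers = [
    {"id": "a", "pmid": "123", "sources": [], "inCitations": ["b"], "outCitations": []},
    {"id": "b", "pmid": "", "sources": ["Medline"], "inCitations": [], "outCitations": ["a"]},
    {"id": "c", "pmid": "", "sources": ["DBLP"], "inCitations": ["a"], "outCitations": []},
    {"id": "d", "pmid": "456", "sources": [], "inCitations": [], "outCitations": []},
]
print([p["id"] for p in papers if is_medical_paper(p)])  # ['a', 'b']
```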
498 |
499 | #### [model.ipynb](https://github.com/Santosh-Gupta/NaturalLanguageRecommendations/blob/master/notebooks/training/model.ipynb)
500 | Description: This notebook has the training code for BERT, which is designed to run on Google Cloud TPU v3-8.
501 |
502 | #### [pruning_first_pass.ipynb](https://github.com/Santosh-Gupta/NaturalLanguageRecommendations/blob/master/notebooks/data/pruning_first_pass.ipynb)
503 | Description: This notebook pruned our filtered data, keeping only papers in the cleaned dataset that either cited, or were cited by, another paper in the cleaned data. The pruned data is in the folder linked above.
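
A minimal sketch of that pruning pass, assuming each record carries `id`, `inCitations`, and `outCitations` lists as in the Open Corpus. The helper `prune_to_internal_citations` is hypothetical: it keeps a paper only if at least one of its citation links points to another paper inside the same cleaned set.

```python
from typing import Dict, List


def prune_to_internal_citations(records: List[Dict]) -> List[Dict]:
    """Keep papers that cite, or are cited by, another paper in the same set."""
    ids = {r["id"] for r in records}
    kept = []
    for r in records:
        neighbours = set(r.get("inCitations", [])) | set(r.get("outCitations", []))
        neighbours.discard(r["id"])  # ignore self-citations
        if neighbours & ids:  # at least one link stays inside the set
            kept.append(r)
    return kept


records = [
    {"id": "p1", "inCitations": [], "outCitations": ["p2", "x9"]},
    {"id": "p2", "inCitations": ["p1"], "outCitations": []},
    {"id": "p3", "inCitations": ["x7"], "outCitations": ["x8"]},  # only external links
]
print([r["id"] for r in prune_to_internal_citations(records)])  # ['p1', 'p2']
```

Using a set of IDs makes each membership check O(1), which matters when the cleaned set holds millions of papers.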
504 |
505 | #### [text2cite_preprocessing.ipynb](https://github.com/Santosh-Gupta/NaturalLanguageRecommendations/blob/master/notebooks/text2cite_preprocessing.ipynb)
506 | Description: This notebook was used to clean the original Open Corpus data to keep only papers related to fields such as engineering, math, physics, and CS; medical and humanities papers were filtered out. The cleaned CS data is in the folder linked above.
507 |
508 | #### [tfrecords_debug.ipynb](https://github.com/Santosh-Gupta/NaturalLanguageRecommendations/blob/master/notebooks/tfrecords_debug.ipynb)
509 | Description: Tests the TFRecord writer class.
510 |
511 | #### [TF2.0 Word2Vec CBOW.ipynb](https://github.com/Santosh-Gupta/NaturalLanguageRecommendations/blob/master/notebooks/training/TF2.0%20Word2Vec%20CBOW.ipynb)
512 | Description: The original Word2Vec CBOW model, implemented in Keras.
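
CBOW (continuous bag of words) predicts a target token from the tokens in its surrounding context window. The sketch below only builds the (context, target) training pairs in plain Python; the function name `cbow_pairs` and the window size are illustrative assumptions, not taken from the notebook.

```python
from typing import List, Tuple


def cbow_pairs(tokens: List[str], window: int = 2) -> List[Tuple[List[str], str]]:
    """Return (context, target) pairs: up to `window` tokens on each side predict the centre token."""
    pairs = []
    for i, target in enumerate(tokens):
        left = tokens[max(0, i - window):i]
        right = tokens[i + 1:i + 1 + window]
        context = left + right
        if context:  # skip one-token inputs that have no context
            pairs.append((context, target))
    return pairs


for context, target in cbow_pairs(["copula", "density", "estimation", "methods"], window=1):
    print(context, "->", target)
# ['density'] -> copula
# ['copula', 'estimation'] -> density
# ['density', 'methods'] -> estimation
# ['estimation'] -> methods
```

In a CBOW model the context embeddings are typically averaged and fed to a softmax (or a sampled approximation of it) over the vocabulary to predict the target.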
513 |
514 | ### Python files
515 |
516 | #### [TFrecordWriter.py](https://github.com/Santosh-Gupta/NaturalLanguageRecommendations/blob/master/src/TFrecordWriter.py)
517 | Description: This file defines a TFRecord writer class with utility functions for sharding the dataset.
518 |
519 | #### [model.py](https://github.com/Santosh-Gupta/NaturalLanguageRecommendations/blob/master/src/model.py)
520 | Description: This is the training code for the BERT model, which is designed to run on Google Cloud TPU v3-8.
521 |
522 | ### Packages
523 |
524 | #### [tpu_index](https://github.com/srihari-humbarwadi/tpu_index)
525 | Description: TPU Index is a package we built for the community for fast similarity search over large collections of high-dimensional vectors on TPUs.
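
To illustrate what such a search computes, here is a brute-force cosine-similarity top-k over an embedding matrix in NumPy. This is not the tpu_index API; the function name `top_k_cosine` and the random data are hypothetical, and a TPU-backed index is meant to speed up this kind of scoring over far larger collections.

```python
import numpy as np


def top_k_cosine(query: np.ndarray, index: np.ndarray, k: int = 5):
    """Return (indices, scores) of the k rows of `index` most cosine-similar to `query`."""
    # L2-normalise so that a dot product equals cosine similarity.
    index_n = index / np.linalg.norm(index, axis=1, keepdims=True)
    query_n = query / np.linalg.norm(query)
    scores = index_n @ query_n                  # shape: (num_vectors,)
    k = min(k, scores.shape[0])
    top = np.argpartition(-scores, k - 1)[:k]   # unordered top-k in O(n)
    top = top[np.argsort(-scores[top])]         # sort only those k by score
    return top, scores[top]


rng = np.random.default_rng(0)
embeddings = rng.normal(size=(10_000, 64)).astype(np.float32)
ids, sims = top_k_cosine(embeddings[42], embeddings, k=3)
print(ids[0], round(float(sims[0]), 3))  # 42 1.0
```

`np.argpartition` avoids fully sorting every score, so only the k best candidates are sorted.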
526 |
527 | ## Authors
528 |
529 | ### Santosh Gupta
530 |
531 | Santosh is a former biomedical engineer and current machine learning engineer. His favorite area in machine learning is using the latest advancements in NLP for better scientific information retrieval. You can follow him on Twitter here: https://twitter.com/SantoshStyles
532 |
533 | ### Akul Vohra
534 |
535 | Akul is a junior in high school and is interested in NLP research. He would like to pursue cognitive science or computer science in the future and is happy to be a contributor in Natural Language Recommendations. Here is his portfolio: https://akul.org/
536 |
537 | ### Liam Croteau
538 |
539 | Liam is a nanotechnology engineering undergraduate student interested in NLP and machine learning for better scientific information retrieval. You can follow him on Twitter here: https://twitter.com/LiamCroteau.
540 |
541 |
542 | ### Srihari Humbarwadi
543 |
544 | Srihari is a computer vision engineer interested in computer vision, NLP, and machine learning. He is currently working on improving self-supervised and data-efficient training methods. You can follow him on Twitter here: https://twitter.com/srihari_rh.
545 |
546 | ## Questions, Comments, Collaborations, Feedback?
547 |
548 | Research2vec@gmail.com
549 |
--------------------------------------------------------------------------------
/images/NLRLongerNoBanner.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Santosh-Gupta/NaturalLanguageRecommendations/a60e961145d274942aee2537c3072097d0405966/images/NLRLongerNoBanner.gif
--------------------------------------------------------------------------------
/images/NLRgifActuallyGif.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Santosh-Gupta/NaturalLanguageRecommendations/a60e961145d274942aee2537c3072097d0405966/images/NLRgifActuallyGif.gif
--------------------------------------------------------------------------------
/images/NaturalLanguageRecommendationsUpdated.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Santosh-Gupta/NaturalLanguageRecommendations/a60e961145d274942aee2537c3072097d0405966/images/NaturalLanguageRecommendationsUpdated.png
--------------------------------------------------------------------------------
/images/NaturalLanguageRecommendationsUpdated2.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Santosh-Gupta/NaturalLanguageRecommendations/a60e961145d274942aee2537c3072097d0405966/images/NaturalLanguageRecommendationsUpdated2.png
--------------------------------------------------------------------------------
/images/Our Paper Comparison.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Santosh-Gupta/NaturalLanguageRecommendations/a60e961145d274942aee2537c3072097d0405966/images/Our Paper Comparison.jpg
--------------------------------------------------------------------------------
/images/Relevant Paper Comparison.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Santosh-Gupta/NaturalLanguageRecommendations/a60e961145d274942aee2537c3072097d0405966/images/Relevant Paper Comparison.jpg
--------------------------------------------------------------------------------
/images/SampleNLRresults.JPG:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Santosh-Gupta/NaturalLanguageRecommendations/a60e961145d274942aee2537c3072097d0405966/images/SampleNLRresults.JPG
--------------------------------------------------------------------------------
/images/TensorBoard.JPG:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Santosh-Gupta/NaturalLanguageRecommendations/a60e961145d274942aee2537c3072097d0405966/images/TensorBoard.JPG
--------------------------------------------------------------------------------
/images/architecturePart1.JPG:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Santosh-Gupta/NaturalLanguageRecommendations/a60e961145d274942aee2537c3072097d0405966/images/architecturePart1.JPG
--------------------------------------------------------------------------------
/images/architecturePart1.PNG:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Santosh-Gupta/NaturalLanguageRecommendations/a60e961145d274942aee2537c3072097d0405966/images/architecturePart1.PNG
--------------------------------------------------------------------------------
/images/architecturePart2.JPG:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Santosh-Gupta/NaturalLanguageRecommendations/a60e961145d274942aee2537c3072097d0405966/images/architecturePart2.JPG
--------------------------------------------------------------------------------
/images/architecturePart2.PNG:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Santosh-Gupta/NaturalLanguageRecommendations/a60e961145d274942aee2537c3072097d0405966/images/architecturePart2.PNG
--------------------------------------------------------------------------------
/images/architecturePart3.JPG:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Santosh-Gupta/NaturalLanguageRecommendations/a60e961145d274942aee2537c3072097d0405966/images/architecturePart3.JPG
--------------------------------------------------------------------------------
/images/architecturePart3.PNG:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Santosh-Gupta/NaturalLanguageRecommendations/a60e961145d274942aee2537c3072097d0405966/images/architecturePart3.PNG
--------------------------------------------------------------------------------
/images/architecturePart4.JPG:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Santosh-Gupta/NaturalLanguageRecommendations/a60e961145d274942aee2537c3072097d0405966/images/architecturePart4.JPG
--------------------------------------------------------------------------------
/images/architecturePart4.PNG:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Santosh-Gupta/NaturalLanguageRecommendations/a60e961145d274942aee2537c3072097d0405966/images/architecturePart4.PNG
--------------------------------------------------------------------------------
/images/gif4Github1-1.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Santosh-Gupta/NaturalLanguageRecommendations/a60e961145d274942aee2537c3072097d0405966/images/gif4Github1-1.gif
--------------------------------------------------------------------------------
/notebooks/data/medical_preprocessing.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "nbformat": 4,
3 | "nbformat_minor": 0,
4 | "metadata": {
5 | "colab": {
6 | "name": "medical_preprocessing.ipynb",
7 | "provenance": [],
8 | "machine_shape": "hm",
9 | "include_colab_link": true
10 | },
11 | "kernelspec": {
12 | "name": "python3",
13 | "display_name": "Python 3"
14 | },
15 | "accelerator": "GPU"
16 | },
17 | "cells": [
18 | {
19 | "cell_type": "markdown",
20 | "metadata": {
21 | "id": "view-in-github",
22 | "colab_type": "text"
23 | },
24 | "source": [
25 | ""
26 | ]
27 | },
28 | {
29 | "cell_type": "code",
30 | "metadata": {
31 | "id": "9O1eOO0FjQYR",
32 | "colab_type": "code",
33 | "outputId": "90190b07-57cf-48af-95b8-9c00f502213f",
34 | "colab": {
35 | "base_uri": "https://localhost:8080/",
36 | "height": 126
37 | }
38 | },
39 | "source": [
40 | "from google.colab import drive\n",
41 | "drive.mount('/content/gdrive')"
42 | ],
43 | "execution_count": 1,
44 | "outputs": [
45 | {
46 | "output_type": "stream",
47 | "text": [
48 | "Go to this URL in a browser: https://accounts.google.com/o/oauth2/auth?client_id=947318989803-6bn6qk8qdgf4n4g3pfee6491hc0brc4i.apps.googleusercontent.com&redirect_uri=urn%3aietf%3awg%3aoauth%3a2.0%3aoob&response_type=code&scope=email%20https%3a%2f%2fwww.googleapis.com%2fauth%2fdocs.test%20https%3a%2f%2fwww.googleapis.com%2fauth%2fdrive%20https%3a%2f%2fwww.googleapis.com%2fauth%2fdrive.photos.readonly%20https%3a%2f%2fwww.googleapis.com%2fauth%2fpeopleapi.readonly\n",
49 | "\n",
50 | "Enter your authorization code:\n",
51 | "··········\n",
52 | "Mounted at /content/gdrive\n"
53 | ],
54 | "name": "stdout"
55 | }
56 | ]
57 | },
58 | {
59 | "cell_type": "code",
60 | "metadata": {
61 | "id": "SWD-VJLXjeKB",
62 | "colab_type": "code",
63 | "outputId": "84abc25f-1eb5-4614-8e0a-faa53cea2009",
64 | "colab": {
65 | "base_uri": "https://localhost:8080/",
66 | "height": 231
67 | }
68 | },
69 | "source": [
70 | "!pip install langid\n",
71 | "!pip install tqdm"
72 | ],
73 | "execution_count": 2,
74 | "outputs": [
75 | {
76 | "output_type": "stream",
77 | "text": [
78 | "Collecting langid\n",
79 | "\u001b[?25l Downloading https://files.pythonhosted.org/packages/ea/4c/0fb7d900d3b0b9c8703be316fbddffecdab23c64e1b46c7a83561d78bd43/langid-1.1.6.tar.gz (1.9MB)\n",
80 | "\u001b[K |████████████████████████████████| 1.9MB 4.6MB/s \n",
81 | "\u001b[?25hRequirement already satisfied: numpy in /usr/local/lib/python3.6/dist-packages (from langid) (1.17.4)\n",
82 | "Building wheels for collected packages: langid\n",
83 | " Building wheel for langid (setup.py) ... \u001b[?25l\u001b[?25hdone\n",
84 | " Created wheel for langid: filename=langid-1.1.6-cp36-none-any.whl size=1941190 sha256=699bf2d6b835147f8fff54168cb6831f3786472bf5c66de0e144b1ea52b4847f\n",
85 | " Stored in directory: /root/.cache/pip/wheels/29/bc/61/50a93be85d1afe9436c3dc61f38da8ad7b637a38af4824e86e\n",
86 | "Successfully built langid\n",
87 | "Installing collected packages: langid\n",
88 | "Successfully installed langid-1.1.6\n",
89 | "Requirement already satisfied: tqdm in /usr/local/lib/python3.6/dist-packages (4.28.1)\n"
90 | ],
91 | "name": "stdout"
92 | }
93 | ]
94 | },
95 | {
96 | "cell_type": "code",
97 | "metadata": {
98 | "id": "Uv-KNIA_jf_Z",
99 | "colab_type": "code",
100 | "colab": {}
101 | },
102 | "source": [
103 | "import zipfile\n",
104 | "import os\n",
105 | "import sys\n",
106 | "import pandas as pd\n",
107 | "import numpy as np\n",
108 | "import gc\n",
109 | "from urllib import request\n",
110 | "import json\n",
111 | "import itertools\n",
112 | "import gzip\n",
113 | "import shutil\n",
114 | "import ast\n",
115 | "import pickle \n",
116 | "import nltk\n",
117 | "from langid.langid import LanguageIdentifier, model\n",
118 | "from tqdm import tqdm"
119 | ],
120 | "execution_count": 0,
121 | "outputs": []
122 | },
123 | {
124 | "cell_type": "code",
125 | "metadata": {
126 | "id": "I6iS_D19jlGx",
127 | "colab_type": "code",
128 | "colab": {}
129 | },
130 | "source": [
131 | "# Download a gzipped shard from Semantic Scholar, unzip it to a JSON file, then read that file line by line\n",
132 | "def dlandopen(url, file, json ):\n",
133 | " request.urlretrieve(url, file )\n",
134 | " with gzip.open(file, 'rb') as f_in:\n",
135 | " with open(json, 'wb') as f_out:\n",
136 | " shutil.copyfileobj(f_in, f_out)\n",
137 | "\n",
138 | " os.remove(file)\n",
139 | " f = open(json)\n",
140 | " mylistH = f.readlines()\n",
141 | " f.close()\n",
142 | " return mylistH"
143 | ],
144 | "execution_count": 0,
145 | "outputs": []
146 | },
147 | {
148 | "cell_type": "code",
149 | "metadata": {
150 | "id": "35eLVFI8jpMZ",
151 | "colab_type": "code",
152 | "colab": {}
153 | },
154 | "source": [
155 | "# Convert the list of JSON lines into a pandas DataFrame (parsed line by line, which is simple but not the most efficient approach)\n",
156 | "# TODO: it may be faster to build the DataFrame directly, e.g. with pd.read_json(path, lines=True)\n",
157 | "def firstDataframe(mylist, start, finish):\n",
158 | " data = []\n",
159 | "\n",
160 | " for line in itertools.islice(mylist , start , finish):\n",
161 | " data.append(json.loads(line))\n",
162 | "\n",
163 | " df = pd.DataFrame(data)\n",
164 | " del data\n",
165 | " gc.collect()\n",
166 | " return df"
167 | ],
168 | "execution_count": 0,
169 | "outputs": []
170 | },
171 | {
172 | "cell_type": "code",
173 | "metadata": {
174 | "id": "V0YBFdwGjrPX",
175 | "colab_type": "code",
176 | "colab": {}
177 | },
178 | "source": [
179 | "def medicalRefinement(dftbr):\n",
180 | " dftbr['strsources'] = dftbr['sources'].astype('str')\n",
181 | " dftbr['totalCites2'] = dftbr['outCitations'] + dftbr['inCitations'] \n",
182 | " dftbr = dftbr[dftbr['totalCites2'].str.len() >= 1] # keep papers with at least 1 citation (incoming or outgoing)\n",
183 | " dftbr = dftbr[ (dftbr.pmid != '' ) | (dftbr.strsources.str.contains(\"Medline\"))]\n",
184 | " dftbr = dftbr.drop(['journalVolume','journalPages','year','authors','sources','doiUrl', 'strsources'], axis = 1)\n",
185 | " return dftbr"
186 | ],
187 | "execution_count": 0,
188 | "outputs": []
189 | },
190 | {
191 | "cell_type": "code",
192 | "metadata": {
193 | "id": "jBlLOuCNju3q",
194 | "colab_type": "code",
195 | "colab": {}
196 | },
197 | "source": [
198 | "def remove_nonenglish(df, jSet, vSet):\n",
199 | " # Gets list of English Journals\n",
200 | " listOfEnglishJournals = []\n",
201 | " identifier = LanguageIdentifier.from_modelstring(model, norm_probs=True)\n",
202 | " count = 0\n",
203 | " for i in jSet:\n",
204 | " lang = identifier.classify(str(i))\n",
205 | " langstr = str(lang)\n",
206 | " if langstr.find(\"en\") != -1 and lang[1] > 0.5:\n",
207 | " listOfEnglishJournals.append(i)\n",
208 | " \n",
209 | " # Gets List of English Venues\n",
210 | " listOfEnglishVenues = []\n",
211 | " count = 0\n",
212 | " for i in vSet:\n",
213 | " lang = identifier.classify(str(i))\n",
214 | " langstr = str(lang)\n",
215 | " if langstr.find(\"en\") != -1 and lang[1] > 0.5:\n",
216 | " listOfEnglishVenues.append(i)\n",
217 | "\n",
218 | " filter1 = df['journalName'].isin(listOfEnglishJournals)\n",
219 | " filter2 = df['venue'].isin(listOfEnglishVenues)\n",
220 | "\n",
221 | " dfFinal = df[filter1 | filter2]\n",
222 | "\n",
223 | " return dfFinal"
224 | ],
225 | "execution_count": 0,
226 | "outputs": []
227 | },
228 | {
229 | "cell_type": "code",
230 | "metadata": {
231 | "id": "r5OmTsAFj0C0",
232 | "colab_type": "code",
233 | "outputId": "7f71cdc8-fab7-483c-aa6f-64682fc43813",
234 | "colab": {
235 | "base_uri": "https://localhost:8080/",
236 | "height": 390
237 | }
238 | },
239 | "source": [
240 | "url_template = 'https://s3-us-west-2.amazonaws.com/ai2-s2-research-public/open-corpus/2019-11-01/s2-corpus-{:03}.gz'\n",
241 | "file_template = 'json{:03}.gz'\n",
242 | "json_template = 'json{:03}'\n",
243 | "\n",
244 | "i=55\n",
245 | "mylist = dlandopen(url_template.format(i), file_template.format(i), json_template.format(i) )\n",
246 | "df = firstDataframe(mylist, 0, len(mylist))\n",
247 | "df_refined = medicalRefinement(df)\n",
248 | "print(\"REFINED\")"
249 | ],
250 | "execution_count": 8,
251 | "outputs": [
252 | {
253 | "output_type": "error",
254 | "ename": "KeyboardInterrupt",
255 | "evalue": "ignored",
256 | "traceback": [
257 | "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
258 | "\u001b[0;31mKeyboardInterrupt\u001b[0m Traceback (most recent call last)",
259 | "\u001b[0;32m\u001b[0m in \u001b[0;36m\u001b[0;34m()\u001b[0m\n\u001b[1;32m 5\u001b[0m \u001b[0mi\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;36m55\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 6\u001b[0m \u001b[0mmylist\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mdlandopen\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0murl_template\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mformat\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mi\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mfile_template\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mformat\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mi\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mjson_template\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mformat\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mi\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m----> 7\u001b[0;31m \u001b[0mdf\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mfirstDataframe\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mmylist\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;36m0\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mlen\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mmylist\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 8\u001b[0m \u001b[0mdf_refined\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mmedicalRefinement\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mdf\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 9\u001b[0m \u001b[0mprint\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m\"REFINED\"\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
260 | "\u001b[0;32m\u001b[0m in \u001b[0;36mfirstDataframe\u001b[0;34m(mylist, start, finish)\u001b[0m\n\u001b[1;32m 3\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 4\u001b[0m \u001b[0;32mfor\u001b[0m \u001b[0mline\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mitertools\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mislice\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mmylist\u001b[0m \u001b[0;34m,\u001b[0m \u001b[0mstart\u001b[0m \u001b[0;34m,\u001b[0m \u001b[0mfinish\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m----> 5\u001b[0;31m \u001b[0mdata\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mappend\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mjson\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mloads\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mline\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 6\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 7\u001b[0m \u001b[0mdf\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mpd\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mDataFrame\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mdata\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
261 | "\u001b[0;32m/usr/lib/python3.6/json/__init__.py\u001b[0m in \u001b[0;36mloads\u001b[0;34m(s, encoding, cls, object_hook, parse_float, parse_int, parse_constant, object_pairs_hook, **kw)\u001b[0m\n\u001b[1;32m 352\u001b[0m \u001b[0mparse_int\u001b[0m \u001b[0;32mis\u001b[0m \u001b[0;32mNone\u001b[0m \u001b[0;32mand\u001b[0m \u001b[0mparse_float\u001b[0m \u001b[0;32mis\u001b[0m \u001b[0;32mNone\u001b[0m \u001b[0;32mand\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 353\u001b[0m parse_constant is None and object_pairs_hook is None and not kw):\n\u001b[0;32m--> 354\u001b[0;31m \u001b[0;32mreturn\u001b[0m \u001b[0m_default_decoder\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mdecode\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0ms\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 355\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mcls\u001b[0m \u001b[0;32mis\u001b[0m \u001b[0;32mNone\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 356\u001b[0m \u001b[0mcls\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mJSONDecoder\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
262 | "\u001b[0;32m/usr/lib/python3.6/json/decoder.py\u001b[0m in \u001b[0;36mdecode\u001b[0;34m(self, s, _w)\u001b[0m\n\u001b[1;32m 337\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 338\u001b[0m \"\"\"\n\u001b[0;32m--> 339\u001b[0;31m \u001b[0mobj\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mend\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mraw_decode\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0ms\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0midx\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0m_w\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0ms\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;36m0\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mend\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 340\u001b[0m \u001b[0mend\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0m_w\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0ms\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mend\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mend\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 341\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mend\u001b[0m \u001b[0;34m!=\u001b[0m \u001b[0mlen\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0ms\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
263 | "\u001b[0;32m/usr/lib/python3.6/json/decoder.py\u001b[0m in \u001b[0;36mraw_decode\u001b[0;34m(self, s, idx)\u001b[0m\n\u001b[1;32m 353\u001b[0m \"\"\"\n\u001b[1;32m 354\u001b[0m \u001b[0;32mtry\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 355\u001b[0;31m \u001b[0mobj\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mend\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mscan_once\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0ms\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0midx\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 356\u001b[0m \u001b[0;32mexcept\u001b[0m \u001b[0mStopIteration\u001b[0m \u001b[0;32mas\u001b[0m \u001b[0merr\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 357\u001b[0m \u001b[0;32mraise\u001b[0m \u001b[0mJSONDecodeError\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m\"Expecting value\"\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0ms\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0merr\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mvalue\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;32mfrom\u001b[0m \u001b[0;32mNone\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
264 | "\u001b[0;31mKeyboardInterrupt\u001b[0m: "
265 | ]
266 | }
267 | ]
268 | },
269 | {
270 | "cell_type": "code",
271 | "metadata": {
272 | "id": "v95sJN4xk0TO",
273 | "colab_type": "code",
274 | "outputId": "136528cf-15bc-4b7a-bfbb-5773e0d472f3",
275 | "colab": {
276 | "base_uri": "https://localhost:8080/",
277 | "height": 35
278 | }
279 | },
280 | "source": [
281 | "df_refined.shape"
282 | ],
283 | "execution_count": 0,
284 | "outputs": [
285 | {
286 | "output_type": "execute_result",
287 | "data": {
288 | "text/plain": [
289 | "(112419, 21)"
290 | ]
291 | },
292 | "metadata": {
293 | "tags": []
294 | },
295 | "execution_count": 10
296 | }
297 | ]
298 | },
299 | {
300 | "cell_type": "code",
301 | "metadata": {
302 | "id": "7eivzkoikg56",
303 | "colab_type": "code",
304 | "outputId": "7d82aac8-041c-4430-f2c5-1b7641e9e834",
305 | "colab": {
306 | "base_uri": "https://localhost:8080/",
307 | "height": 35
308 | }
309 | },
310 | "source": [
311 | "journalSet = set()\n",
312 | "venueSet = set()\n",
313 | "journalSet = journalSet|set(df_refined['journalName'])\n",
314 | "venueSet = venueSet|set(df_refined['venue'])\n",
315 | "english_only_df = remove_nonenglish(df_refined, journalSet, venueSet)\n",
316 | "print(\"NONENGLISH REMOVED\")"
317 | ],
318 | "execution_count": 0,
319 | "outputs": [
320 | {
321 | "output_type": "stream",
322 | "text": [
323 | "NONENGLISH REMOVED\n"
324 | ],
325 | "name": "stdout"
326 | }
327 | ]
328 | },
329 | {
330 | "cell_type": "code",
331 | "metadata": {
332 | "id": "p_XwJd4-k2M5",
333 | "colab_type": "code",
334 | "outputId": "17105f07-1959-4224-8871-8095afb5249e",
335 | "colab": {
336 | "base_uri": "https://localhost:8080/",
337 | "height": 35
338 | }
339 | },
340 | "source": [
341 | "english_only_df.shape"
342 | ],
343 | "execution_count": 0,
344 | "outputs": [
345 | {
346 | "output_type": "execute_result",
347 | "data": {
348 | "text/plain": [
349 | "(76791, 21)"
350 | ]
351 | },
352 | "metadata": {
353 | "tags": []
354 | },
355 | "execution_count": 13
356 | }
357 | ]
358 | },
359 | {
360 | "cell_type": "code",
361 | "metadata": {
362 | "id": "nWcJto9skrZJ",
363 | "colab_type": "code",
364 | "colab": {}
365 | },
366 | "source": [
367 | "english_only_df.to_json(\"/content/gdrive/My Drive/text2doc_data/med_data/med-{:03}.json\".format(i), orient='index')"
368 | ],
369 | "execution_count": 0,
370 | "outputs": []
371 | },
372 | {
373 | "cell_type": "code",
374 | "metadata": {
375 | "id": "DVIgq5XwzbhX",
376 | "colab_type": "code",
377 | "outputId": "e438671f-485f-408c-8faf-c681eac155c7",
378 | "colab": {
379 | "base_uri": "https://localhost:8080/",
380 | "height": 1000
381 | }
382 | },
383 | "source": [
384 | "url_template = 'https://s3-us-west-2.amazonaws.com/ai2-s2-research-public/open-corpus/2019-11-01/s2-corpus-{:03}.gz'\n",
385 | "file_template = 'json{:03}.gz'\n",
386 | "json_template = 'json{:03}'\n",
387 | "\n",
388 | "all_length = 0\n",
389 | "english_length = 0\n",
390 | "\n",
391 | "i = 175\n",
392 | "\n",
393 | "while i <= 178:\n",
394 | " mylist = dlandopen(url_template.format(i), file_template.format(i), json_template.format(i) )\n",
395 | " df = firstDataframe(mylist, 0, len(mylist))\n",
396 | " df_refined = medicalRefinement(df)\n",
397 | " df_refined.to_json(\"/content/gdrive/My Drive/text2doc_data/med_data_all/med-{:03}.json\".format(i), orient='index')\n",
398 | " all_length += df_refined.shape[0]\n",
399 | " print(\"REFINED\")\n",
400 | " journalSet = set()\n",
401 | " venueSet = set()\n",
402 | " journalSet = journalSet|set(df_refined['journalName'])\n",
403 | " venueSet = venueSet|set(df_refined['venue'])\n",
404 | " english_only_df = remove_nonenglish(df_refined, journalSet, venueSet)\n",
405 | " print(\"NONENGLISH REMOVED\")\n",
406 | " english_length += english_only_df.shape[0]\n",
407 | " english_only_df.to_json(\"/content/gdrive/My Drive/text2doc_data/med_data_english_only/med-{:03}.json\".format(i), orient='index')\n",
408 | " i+=1\n",
409 | "\n",
410 | "print(\"All length: \" + str(all_length))\n",
411 | "print(\"English length: \" + str(english_length))"
412 | ],
413 | "execution_count": 8,
414 | "outputs": [
415 | {
416 | "output_type": "stream",
417 | "text": [
418 | "REFINED\n",
419 | "NONENGLISH REMOVED\n",
420 | "REFINED\n",
421 | "NONENGLISH REMOVED\n",
422 | "REFINED\n",
423 | "NONENGLISH REMOVED\n",
424 | "REFINED\n",
425 | "NONENGLISH REMOVED\n",
426 | "REFINED\n",
427 | "NONENGLISH REMOVED\n",
428 | "REFINED\n",
429 | "NONENGLISH REMOVED\n",
430 | "REFINED\n",
431 | "NONENGLISH REMOVED\n",
432 | "REFINED\n",
433 | "NONENGLISH REMOVED\n",
434 | "REFINED\n",
435 | "NONENGLISH REMOVED\n",
436 | "REFINED\n",
437 | "NONENGLISH REMOVED\n",
438 | "REFINED\n",
439 | "NONENGLISH REMOVED\n",
440 | "REFINED\n",
441 | "NONENGLISH REMOVED\n",
442 | "REFINED\n",
443 | "NONENGLISH REMOVED\n",
444 | "REFINED\n",
445 | "NONENGLISH REMOVED\n",
446 | "REFINED\n",
447 | "NONENGLISH REMOVED\n",
448 | "REFINED\n",
449 | "NONENGLISH REMOVED\n",
450 | "REFINED\n",
451 | "NONENGLISH REMOVED\n",
452 | "REFINED\n",
453 | "NONENGLISH REMOVED\n",
454 | "REFINED\n",
455 | "NONENGLISH REMOVED\n",
456 | "REFINED\n",
457 | "NONENGLISH REMOVED\n",
458 | "REFINED\n",
459 | "NONENGLISH REMOVED\n",
460 | "REFINED\n",
461 | "NONENGLISH REMOVED\n"
462 | ],
463 | "name": "stdout"
464 | },
465 | {
466 | "output_type": "error",
467 | "ename": "OSError",
468 | "evalue": "ignored",
469 | "traceback": [
470 | "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
471 | "\u001b[0;31mOSError\u001b[0m Traceback (most recent call last)",
472 | "\u001b[0;32m\u001b[0m in \u001b[0;36m\u001b[0;34m()\u001b[0m\n\u001b[1;32m 12\u001b[0m \u001b[0mdf\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mfirstDataframe\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mmylist\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;36m0\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mlen\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mmylist\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 13\u001b[0m \u001b[0mdf_refined\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mmedicalRefinement\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mdf\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 14\u001b[0;31m \u001b[0mdf_refined\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mto_json\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m\"/content/gdrive/My Drive/text2doc_data/med_data_all/med-{:03}.json\"\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mformat\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mi\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0morient\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;34m'index'\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 15\u001b[0m \u001b[0mall_length\u001b[0m \u001b[0;34m+=\u001b[0m \u001b[0mdf_refined\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mshape\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;36m0\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 16\u001b[0m \u001b[0mprint\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m\"REFINED\"\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
473 | "\u001b[0;32m/usr/local/lib/python3.6/dist-packages/pandas/core/generic.py\u001b[0m in \u001b[0;36mto_json\u001b[0;34m(self, path_or_buf, orient, date_format, double_precision, force_ascii, date_unit, default_handler, lines, compression, index)\u001b[0m\n\u001b[1;32m 2422\u001b[0m \u001b[0mlines\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mlines\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 2423\u001b[0m \u001b[0mcompression\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mcompression\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m-> 2424\u001b[0;31m \u001b[0mindex\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mindex\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 2425\u001b[0m )\n\u001b[1;32m 2426\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n",
474 | "\u001b[0;32m/usr/local/lib/python3.6/dist-packages/pandas/io/json/_json.py\u001b[0m in \u001b[0;36mto_json\u001b[0;34m(path_or_buf, obj, orient, date_format, double_precision, force_ascii, date_unit, default_handler, lines, compression, index)\u001b[0m\n\u001b[1;32m 85\u001b[0m \u001b[0mfh\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mhandles\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0m_get_handle\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mpath_or_buf\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m\"w\"\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mcompression\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mcompression\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 86\u001b[0m \u001b[0;32mtry\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 87\u001b[0;31m \u001b[0mfh\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mwrite\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0ms\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 88\u001b[0m \u001b[0;32mfinally\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 89\u001b[0m \u001b[0mfh\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mclose\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
475 | "\u001b[0;31mOSError\u001b[0m: [Errno 28] No space left on device"
476 | ]
477 | }
478 | ]
479 | }
480 | ]
481 | }
482 |
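The shard loop in `medical_preprocessing.ipynb` above stopped with `OSError: [Errno 28] No space left on device` while writing to Google Drive, after writing every shard twice (once to `med_data_all`, once to `med_data_english_only`). One way to fail early instead of mid-write is to check free space before each `to_json` call with the standard-library `shutil.disk_usage`, and to build the zero-padded output names in one place. This is a minimal sketch under assumptions: the 2 GiB threshold and the `.json.gz` naming are hypothetical, and whether the Drive FUSE mount reports free space meaningfully has not been verified here.

```python
import os
import shutil

# Hypothetical threshold; tune it to the size of one refined shard.
MIN_FREE_BYTES = 2 * 1024 ** 3  # 2 GiB


def has_free_space(path, min_free_bytes=MIN_FREE_BYTES):
    """Return True if the filesystem containing `path` has at least min_free_bytes free."""
    return shutil.disk_usage(path).free >= min_free_bytes


def shard_path(out_dir, shard_id, compressed=True):
    """Build a zero-padded shard filename such as med-175.json.gz (naming is an assumption)."""
    suffix = '.json.gz' if compressed else '.json'
    return os.path.join(out_dir, 'med-{:03}{}'.format(shard_id, suffix))
```

Inside the loop, a check such as `if not has_free_space(out_dir): raise RuntimeError('disk almost full')` before `df_refined.to_json(shard_path(out_dir, i), orient='index', compression='gzip')` stops the run before a partially written file is left behind. The `compression` argument appears in the pandas `to_json` signature shown in the traceback, and gzip output should shrink text-heavy JSON considerably.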
--------------------------------------------------------------------------------
/notebooks/export_saved_model.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "nbformat": 4,
3 | "nbformat_minor": 0,
4 | "metadata": {
5 | "accelerator": "GPU",
6 | "colab": {
7 | "name": "model_debug.ipynb",
8 | "provenance": [],
9 | "include_colab_link": true
10 | },
11 | "kernelspec": {
12 | "display_name": "Python 3",
13 | "language": "python",
14 | "name": "python3"
15 | },
16 | "language_info": {
17 | "codemirror_mode": {
18 | "name": "ipython",
19 | "version": 3
20 | },
21 | "file_extension": ".py",
22 | "mimetype": "text/x-python",
23 | "name": "python",
24 | "nbconvert_exporter": "python",
25 | "pygments_lexer": "ipython3",
26 | "version": "3.6.7"
27 | }
28 | },
29 | "cells": [
30 | {
31 | "cell_type": "markdown",
32 | "metadata": {
33 | "id": "view-in-github",
34 | "colab_type": "text"
35 | },
36 | "source": [
37 | ""
38 | ]
39 | },
40 | {
41 | "cell_type": "code",
42 | "metadata": {
43 | "colab_type": "code",
44 | "id": "QK0_IOe0LkEY",
45 | "colab": {}
46 | },
47 | "source": [
48 | "!pip install transformers --quiet"
49 | ],
50 | "execution_count": 0,
51 | "outputs": []
52 | },
53 | {
54 | "cell_type": "code",
55 | "metadata": {
56 | "colab_type": "code",
57 | "id": "5jPcffepWMGw",
58 | "outputId": "391439f6-f66f-45a3-f1c1-b0deb5829cb3",
59 | "colab": {
60 | "base_uri": "https://localhost:8080/",
61 | "height": 272
62 | }
63 | },
64 | "source": [
65 | "!wget 'https://s3-us-west-2.amazonaws.com/ai2-s2-research/scibert/huggingface_pytorch/scibert_scivocab_uncased.tar'\n",
66 | "!tar -xvf 'scibert_scivocab_uncased.tar'"
67 | ],
68 | "execution_count": 28,
69 | "outputs": [
70 | {
71 | "output_type": "stream",
72 | "text": [
73 | "--2019-12-28 08:15:15-- https://s3-us-west-2.amazonaws.com/ai2-s2-research/scibert/huggingface_pytorch/scibert_scivocab_uncased.tar\n",
74 | "Resolving s3-us-west-2.amazonaws.com (s3-us-west-2.amazonaws.com)... 52.218.232.232\n",
75 | "Connecting to s3-us-west-2.amazonaws.com (s3-us-west-2.amazonaws.com)|52.218.232.232|:443... connected.\n",
76 | "HTTP request sent, awaiting response... 200 OK\n",
77 | "Length: 442460160 (422M) [application/x-tar]\n",
78 | "Saving to: ‘scibert_scivocab_uncased.tar.1’\n",
79 | "\n",
80 | "scibert_scivocab_un 100%[===================>] 421.96M 19.4MB/s in 23s \n",
81 | "\n",
82 | "2019-12-28 08:15:39 (18.2 MB/s) - ‘scibert_scivocab_uncased.tar.1’ saved [442460160/442460160]\n",
83 | "\n",
84 | "scibert_scivocab_uncased/\n",
85 | "scibert_scivocab_uncased/vocab.txt\n",
86 | "scibert_scivocab_uncased/pytorch_model.bin\n",
87 | "scibert_scivocab_uncased/config.json\n"
88 | ],
89 | "name": "stdout"
90 | }
91 | ]
92 | },
93 | {
94 | "cell_type": "code",
95 | "metadata": {
96 | "id": "gJwAz12AkBsb",
97 | "colab_type": "code",
98 | "colab": {}
99 | },
100 | "source": [
101 | "from google.colab import drive\n",
102 | "drive.mount('/gdrive')\n",
103 | "drive_base_path = '/gdrive/My Drive/'"
104 | ],
105 | "execution_count": 0,
106 | "outputs": []
107 | },
108 | {
109 | "cell_type": "code",
110 | "metadata": {
111 | "colab_type": "code",
112 | "id": "S_hHnz5fLiZ5",
113 | "outputId": "8eeb1daa-39bc-4884-e261-f6ab019d19f2",
114 | "colab": {
115 | "base_uri": "https://localhost:8080/",
116 | "height": 34
117 | }
118 | },
119 | "source": [
120 | "%tensorflow_version 2.x\n",
121 | "import os\n",
122 | "import tensorflow as tf\n",
123 | "from tensorflow.keras import backend as K\n",
124 | "from tensorflow.keras.layers import Lambda, Dense, Activation, Concatenate, Dropout\n",
125 | "from transformers import TFBertModel\n",
126 | "from time import time\n",
127 | "print('TensorFlow:', tf.__version__)"
128 | ],
129 | "execution_count": 30,
130 | "outputs": [
131 | {
132 | "output_type": "stream",
133 | "text": [
134 | "TensorFlow: 2.1.0-rc1\n"
135 | ],
136 | "name": "stdout"
137 | }
138 | ]
139 | },
140 | {
141 | "cell_type": "code",
142 | "metadata": {
143 | "colab_type": "code",
144 | "id": "dYscLlW_7wIR",
145 | "colab": {}
146 | },
147 | "source": [
148 | "embedding_dim = 512\n",
149 | "model_files_pattern = 'gs://tfworld/model_files_2/*'"
150 | ],
151 | "execution_count": 0,
152 | "outputs": []
153 | },
154 | {
155 | "cell_type": "code",
156 | "metadata": {
157 | "colab_type": "code",
158 | "id": "TidDJ55-LiZ_",
159 | "colab": {}
160 | },
161 | "source": [
162 | "def create_model(drop_out):\n",
163 | " textIds = tf.keras.Input(shape=(512,), dtype=tf.int32) # from bert tokenizer\n",
164 | " citation = tf.keras.Input(shape=(512,)) # normalized word2vec outputs\n",
165 | " \n",
166 | " bert_model = TFBertModel.from_pretrained('scibert_scivocab_uncased', from_pt=True)\n",
167 | " \n",
168 | " textOut = bert_model(textIds)\n",
169 | " textOutMean = tf.reduce_mean(textOut[0], axis=1)\n",
170 | " textOutSim = Dense(units=embedding_dim, activation='tanh', name='DenseTitle')(textOutMean)\n",
171 | " textOutSim = Dropout(drop_out)(textOutSim)\n",
172 | " \n",
173 | " citationSim = Dense(units=embedding_dim, activation='tanh', name='DenseCitation')(citation)\n",
174 | " citationSim = Dropout(drop_out)(citationSim)\n",
175 | "\n",
176 | " # Dot product of every text x citation pair in the batch: shape (batch, batch)\n",
177 | " dotProduct = tf.reduce_sum(tf.multiply(textOutSim[:, None, :], citationSim), axis=-1)\n",
178 | " \n",
179 | " # Softmax to make sure each row has sum == 1.0\n",
180 | " probs = tf.nn.softmax(dotProduct, axis=-1)\n",
181 | "\n",
182 | " model = tf.keras.Model(inputs=[textIds, citation], outputs=[probs])\n",
183 | " return model"
184 | ],
185 | "execution_count": 0,
186 | "outputs": []
187 | },
188 | {
189 | "cell_type": "code",
190 | "metadata": {
191 | "colab_type": "code",
192 | "id": "o8MXgYFSLiaB",
193 | "colab": {}
194 | },
195 | "source": [
196 | "model = create_model(drop_out=.2)\n",
197 | "model.load_weights('gs://tfworld/model_files_2/epoch_06_1.96')"
198 | ],
199 | "execution_count": 0,
200 | "outputs": []
201 | },
202 | {
203 | "cell_type": "code",
204 | "metadata": {
205 | "colab_type": "code",
206 | "id": "B8DkqlzfDmL_",
207 | "outputId": "ea6fbd13-8886-4978-bfdc-830e9d88673e",
208 | "colab": {
209 | "base_uri": "https://localhost:8080/",
210 | "height": 51
211 | }
212 | },
213 | "source": [
214 | "inference_model = tf.keras.Model(inputs=[model.inputs[0]],\n",
215 | " outputs=[model.get_layer('DenseTitle').output])\n"
216 | ],
217 | "execution_count": 36,
218 | "outputs": [
219 | {
220 | "output_type": "execute_result",
221 | "data": {
222 | "text/plain": [
223 | "([],\n",
224 | " [])"
225 | ]
226 | },
227 | "metadata": {
228 | "tags": []
229 | },
230 | "execution_count": 36
231 | }
232 | ]
233 | },
234 | {
235 | "cell_type": "code",
236 | "metadata": {
237 | "colab_type": "code",
238 | "id": "tbJ3BoQFXYve",
239 | "colab": {}
240 | },
241 | "source": [
242 | "citation_projection_model = tf.keras.Sequential([tf.keras.Input(shape=(512,), dtype=tf.float32),\n",
243 | " model.get_layer('DenseCitation')])"
244 | ],
245 | "execution_count": 0,
246 | "outputs": []
247 | },
248 | {
249 | "cell_type": "code",
250 | "metadata": {
251 | "colab_type": "code",
252 | "id": "da9UjmnWXwPm",
253 | "outputId": "8d944d45-b2c3-4e1c-98cd-3f24bccede14",
254 | "colab": {
255 | "base_uri": "https://localhost:8080/",
256 | "height": 476
257 | }
258 | },
259 | "source": [
260 | "inference_model.summary(), citation_projection_model.summary()"
261 | ],
262 | "execution_count": 39,
263 | "outputs": [
264 | {
265 | "output_type": "stream",
266 | "text": [
267 | "Model: \"model_3\"\n",
268 | "_________________________________________________________________\n",
269 | "Layer (type) Output Shape Param # \n",
270 | "=================================================================\n",
271 | "input_4 (InputLayer) [(None, 512)] 0 \n",
272 | "_________________________________________________________________\n",
273 | "tf_bert_model_1 (TFBertModel ((None, 512, 768), (None, 109918464 \n",
274 | "_________________________________________________________________\n",
275 | "tf_op_layer_Mean_1 (TensorFl [(None, 768)] 0 \n",
276 | "_________________________________________________________________\n",
277 | "DenseTitle (Dense) (None, 512) 393728 \n",
278 | "=================================================================\n",
279 | "Total params: 110,312,192\n",
280 | "Trainable params: 110,312,192\n",
281 | "Non-trainable params: 0\n",
282 | "_________________________________________________________________\n",
283 | "Model: \"sequential_1\"\n",
284 | "_________________________________________________________________\n",
285 | "Layer (type) Output Shape Param # \n",
286 | "=================================================================\n",
287 | "DenseCitation (Dense) (None, 512) 262656 \n",
288 | "=================================================================\n",
289 | "Total params: 262,656\n",
290 | "Trainable params: 262,656\n",
291 | "Non-trainable params: 0\n",
292 | "_________________________________________________________________\n"
293 | ],
294 | "name": "stdout"
295 | },
296 | {
297 | "output_type": "execute_result",
298 | "data": {
299 | "text/plain": [
300 | "(None, None)"
301 | ]
302 | },
303 | "metadata": {
304 | "tags": []
305 | },
306 | "execution_count": 39
307 | }
308 | ]
309 | },
310 | {
311 | "cell_type": "code",
312 | "metadata": {
313 | "id": "1pCOb2TdlDcY",
314 | "colab_type": "code",
315 | "colab": {}
316 | },
317 | "source": [
318 | "model_output_dir = drive_base_path+'tfworld/'\n",
319 | "os.makedirs(model_output_dir, exist_ok=True)"
320 | ],
321 | "execution_count": 0,
322 | "outputs": []
323 | },
324 | {
325 | "cell_type": "code",
326 | "metadata": {
327 | "colab_type": "code",
328 | "id": "VGbGLvBReBh4",
329 | "outputId": "9b494425-c79d-4f99-f15e-063e18931519",
330 | "colab": {
331 | "base_uri": "https://localhost:8080/",
332 | "height": 51
333 | }
334 | },
335 | "source": [
336 | "inference_model.save(model_output_dir + 'inference_model', save_format='tf')\n",
337 | "citation_projection_model.save(model_output_dir + 'citations_projection_model', save_format='tf')"
338 | ],
339 | "execution_count": 46,
340 | "outputs": [
341 | {
342 | "output_type": "stream",
343 | "text": [
344 | "INFO:tensorflow:Assets written to: /gdrive/My Drive/tfworld/inference_model/assets\n",
345 | "INFO:tensorflow:Assets written to: /gdrive/My Drive/tfworld/citations_projection_model/assets\n"
346 | ],
347 | "name": "stdout"
348 | }
349 | ]
350 | },
351 | {
352 | "cell_type": "code",
353 | "metadata": {
354 | "id": "kX1NAS2mlkrB",
355 | "colab_type": "code",
356 | "colab": {
357 | "base_uri": "https://localhost:8080/",
358 | "height": 272
359 | },
360 | "outputId": "8ae2e6d2-0b9a-4950-ba4f-e0dacf462630"
361 | },
362 | "source": [
363 | "!zip -r \"/gdrive/My Drive/tfworld.zip\" \"/gdrive/My Drive/tfworld\""
364 | ],
365 | "execution_count": 48,
366 | "outputs": [
367 | {
368 | "output_type": "stream",
369 | "text": [
370 | " adding: gdrive/My Drive/tfworld/ (stored 0%)\n",
371 | " adding: gdrive/My Drive/tfworld/inference_model/ (stored 0%)\n",
372 | " adding: gdrive/My Drive/tfworld/inference_model/variables/ (stored 0%)\n",
373 | " adding: gdrive/My Drive/tfworld/inference_model/variables/variables.data-00000-of-00002 (deflated 13%)\n",
374 | " adding: gdrive/My Drive/tfworld/inference_model/variables/variables.data-00001-of-00002 (deflated 7%)\n",
375 | " adding: gdrive/My Drive/tfworld/inference_model/variables/variables.index (deflated 78%)\n",
376 | " adding: gdrive/My Drive/tfworld/inference_model/assets/ (stored 0%)\n",
377 | " adding: gdrive/My Drive/tfworld/inference_model/saved_model.pb (deflated 92%)\n",
378 | " adding: gdrive/My Drive/tfworld/citations_projection_model/ (stored 0%)\n",
379 | " adding: gdrive/My Drive/tfworld/citations_projection_model/variables/ (stored 0%)\n",
380 | " adding: gdrive/My Drive/tfworld/citations_projection_model/variables/variables.data-00000-of-00002 (deflated 61%)\n",
381 | " adding: gdrive/My Drive/tfworld/citations_projection_model/variables/variables.data-00001-of-00002 (deflated 8%)\n",
382 | " adding: gdrive/My Drive/tfworld/citations_projection_model/variables/variables.index (deflated 30%)\n",
383 | " adding: gdrive/My Drive/tfworld/citations_projection_model/assets/ (stored 0%)\n",
384 | " adding: gdrive/My Drive/tfworld/citations_projection_model/saved_model.pb (deflated 86%)\n"
385 | ],
386 | "name": "stdout"
387 | }
388 | ]
389 | },
390 | {
391 | "cell_type": "code",
392 | "metadata": {
393 | "colab_type": "code",
394 | "id": "zE_TZrYPYoF9",
395 | "colab": {}
396 | },
397 | "source": [
398 | "abstract_model = tf.saved_model.load(model_output_dir + 'inference_model')\n",
399 | "citations_model = tf.saved_model.load(model_output_dir + 'citations_projection_model')"
400 | ],
401 | "execution_count": 0,
402 | "outputs": []
403 | },
404 | {
405 | "cell_type": "code",
406 | "metadata": {
407 | "colab_type": "code",
408 | "id": "GDCKXvolfcOC",
409 | "colab": {}
410 | },
411 | "source": [
412 | "abstractIds = tf.random.uniform(shape=(1, 512), maxval=500, dtype=tf.int32).numpy()\n",
413 | "citation_vector = tf.random.uniform(shape=(1, 512), minval=-1, maxval=1, dtype=tf.float32).numpy()"
414 | ],
415 | "execution_count": 0,
416 | "outputs": []
417 | },
418 | {
419 | "cell_type": "code",
420 | "metadata": {
421 | "colab_type": "code",
422 | "id": "3FYcfSTkYr-l",
423 | "outputId": "d00e0a52-089c-4b60-8987-d284a95b19bb",
424 | "colab": {
425 | "base_uri": "https://localhost:8080/",
426 | "height": 1000
427 | }
428 | },
429 | "source": [
430 | "abstract_model(abstractIds), citations_model(citation_vector)"
431 | ],
432 | "execution_count": 51,
433 | "outputs": [
434 | {
435 | "output_type": "execute_result",
436 | "data": {
437 | "text/plain": [
438 | "(,\n",
542 | " )"
646 | ]
647 | },
648 | "metadata": {
649 | "tags": []
650 | },
651 | "execution_count": 51
652 | }
653 | ]
654 | },
655 | {
656 | "cell_type": "code",
657 | "metadata": {
658 | "colab_type": "code",
659 | "id": "nYO-rH1tYyAs",
660 | "colab": {}
661 | },
662 | "source": [
663 | ""
664 | ],
665 | "execution_count": 0,
666 | "outputs": []
667 | }
668 | ]
669 | }
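In `create_model` above, `textOutSim[:, None, :]` has shape `(batch, 1, 512)` and broadcasts against `citationSim` of shape `(batch, 512)`, so `dotProduct[i, j]` is the score of text `i` against citation vector `j`, and the softmax turns each row into a distribution over the citation vectors in the batch. The NumPy sketch below is not code from the repository; it reproduces that computation with random stand-in vectors and checks that it equals the plain matrix product `text @ citation.T`. In TensorFlow, `tf.matmul(textOutSim, citationSim, transpose_b=True)` would give the same matrix without materialising the `(batch, batch, 512)` intermediate.

```python
import numpy as np


def batch_scores(text_vecs, citation_vecs):
    """Score every text vector against every citation vector.

    Mirrors the broadcasting in create_model: (B, 1, D) * (B, D) -> (B, B, D),
    then a sum over D gives a (B, B) matrix with [i, j] = <text_i, citation_j>.
    """
    return np.sum(text_vecs[:, None, :] * citation_vecs, axis=-1)


def row_softmax(scores):
    """Softmax over the last axis, like tf.nn.softmax(scores, axis=-1)."""
    shifted = scores - scores.max(axis=-1, keepdims=True)  # subtract row max for stability
    exp = np.exp(shifted)
    return exp / exp.sum(axis=-1, keepdims=True)


rng = np.random.RandomState(0)
text = np.tanh(rng.randn(4, 8))       # stand-in for DenseTitle outputs (tanh-activated)
citations = np.tanh(rng.randn(4, 8))  # stand-in for DenseCitation outputs
scores = batch_scores(text, citations)
probs = row_softmax(scores)
print(np.allclose(scores, text @ citations.T))  # broadcast sum equals the matrix product
print(probs.sum(axis=-1))                        # each row sums to 1
```

Because the exported `inference_model` and `citation_projection_model` end at the `DenseTitle` and `DenseCitation` layers, a score at inference time is just the dot product of their two outputs; the dropout applied during training is not part of either exported model.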
--------------------------------------------------------------------------------
/notebooks/inference/create_abstract_vectors.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "code",
5 | "execution_count": null,
6 | "metadata": {},
7 | "outputs": [],
8 | "source": [
9 | "!gdown --id 1-1HED-B-HtuZR9kNtybz9sypVpk2A_fX # TitlesAbstractsEmbedIds\n",
10 | "!wget 'https://s3-us-west-2.amazonaws.com/ai2-s2-research/scibert/huggingface_pytorch/scibert_scivocab_uncased.tar'\n",
11 | "!tar -xvf 'scibert_scivocab_uncased.tar'\n",
12 | "!pip install transformers --quiet"
13 | ]
14 | },
15 | {
16 | "cell_type": "code",
17 | "execution_count": null,
18 | "metadata": {},
19 | "outputs": [],
20 | "source": [
21 | "%tensorflow_version 2.x\n",
22 | "import numpy as np\n",
23 | "import tensorflow as tf\n",
24 | "from time import time\n",
25 | "from tqdm import tqdm_notebook as tqdm\n",
26 | "from transformers import BertTokenizer\n",
27 | "import pandas as pd\n",
28 | "from pprint import pprint\n",
29 | "\n",
30 | "print('TensorFlow:', tf.__version__)"
31 | ]
32 | },
33 | {
34 | "cell_type": "code",
35 | "execution_count": null,
36 | "metadata": {},
37 | "outputs": [],
38 | "source": [
39 | "try:\n",
40 | " tpu = tf.distribute.cluster_resolver.TPUClusterResolver() # TPU detection\n",
41 | " print('Running on TPU ', tpu.cluster_spec().as_dict()['worker'])\n",
42 | "except ValueError:\n",
43 | " tpu = None\n",
44 | "\n",
45 | "if tpu:\n",
46 | " tf.config.experimental_connect_to_cluster(tpu)\n",
47 | " tf.tpu.experimental.initialize_tpu_system(tpu)\n",
48 | " strategy = tf.distribute.experimental.TPUStrategy(tpu)\n",
49 | "else:\n",
50 | " strategy = tf.distribute.MirroredStrategy()\n",
51 | "\n",
52 | "print(\"REPLICAS: \", strategy.num_replicas_in_sync)"
53 | ]
54 | },
55 | {
56 | "cell_type": "code",
57 | "execution_count": null,
58 | "metadata": {},
59 | "outputs": [],
60 | "source": [
61 | "model = tf.saved_model.load('gs://tfworld/saved_models')\n",
62 | "tokenizer = BertTokenizer(vocab_file='scibert_scivocab_uncased/vocab.txt')\n",
63 | "\n",
64 | "df = pd.read_json('TitlesAbstractsEmbedIds.json.gzip', compression='gzip')\n",
65 | "embed2Title = pd.Series(df['title'].values,index=df['EmbeddingID']).to_dict()\n",
66 | "embed2Abstract = pd.Series(df['paperAbstract'].values,index=df['EmbeddingID']).to_dict()"
67 | ]
68 | }
69 | ],
70 | "metadata": {
71 | "kernelspec": {
72 | "display_name": "Python 3",
73 | "language": "python",
74 | "name": "python3"
75 | },
76 | "language_info": {
77 | "codemirror_mode": {
78 | "name": "ipython",
79 | "version": 3
80 | },
81 | "file_extension": ".py",
82 | "mimetype": "text/x-python",
83 | "name": "python",
84 | "nbconvert_exporter": "python",
85 | "pygments_lexer": "ipython3",
86 | "version": "3.6.7"
87 | }
88 | },
89 | "nbformat": 4,
90 | "nbformat_minor": 4
91 | }
92 |
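If `gs://tfworld/saved_models` holds the abstract tower exported in `export_saved_model.ipynb` (an assumption; that notebook saves to Drive under `tfworld/inference_model`), it expects int32 token ids of shape `(None, 512)`, as its `summary()` output shows. `create_model` also averages BERT's outputs over all 512 positions with `tf.reduce_mean` and passes no attention mask, so padding tokens contribute to each abstract vector, and inference inputs should be padded the same way as the training data. The helper below is a sketch that assumes right-padding with id 0 (`[PAD]` in standard BERT vocabularies; check `scibert_scivocab_uncased/vocab.txt`) and plain truncation; the padding used when the training TFRecords were built is not shown in this section.

```python
import numpy as np

MAX_LEN = 512  # sequence length of the exported abstract tower's input
PAD_ID = 0     # assumption: id of [PAD] in vocab.txt; verify before use


def to_model_input(token_id_lists, max_len=MAX_LEN, pad_id=PAD_ID):
    """Truncate or right-pad each token-id list to max_len and stack into an int32 batch."""
    batch = np.full((len(token_id_lists), max_len), pad_id, dtype=np.int32)
    for row, ids in enumerate(token_id_lists):
        ids = list(ids)[:max_len]  # plain truncation; a trailing [SEP] may be cut off
        batch[row, :len(ids)] = ids
    return batch
```

Given id lists from the notebook's `tokenizer`, `model(to_model_input(id_lists))` would return one vector per abstract. This mirrors `abstract_model(abstractIds)` in `export_saved_model.ipynb`, where a NumPy int32 array of shape `(1, 512)` is passed to the loaded SavedModel directly.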
--------------------------------------------------------------------------------
/notebooks/tfrecords_debug.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "markdown",
5 | "metadata": {
6 | "colab_type": "text",
7 | "id": "view-in-github"
8 | },
9 | "source": [
10 | ""
11 | ]
12 | },
13 | {
14 | "cell_type": "code",
15 | "execution_count": 1,
16 | "metadata": {
17 | "colab": {
18 | "base_uri": "https://localhost:8080/",
19 | "height": 34
20 | },
21 | "colab_type": "code",
22 | "id": "QK0_IOe0LkEY",
23 | "outputId": "9c211742-51cf-4fd8-e9e8-452a812737f4"
24 | },
25 | "outputs": [
26 | {
27 | "name": "stderr",
28 | "output_type": "stream",
29 | "text": [
30 | "UsageError: Line magic function `%tensorflow_version` not found.\n"
31 | ]
32 | }
33 | ],
34 | "source": [
35 | "%tensorflow_version 2.x"
36 | ]
37 | },
38 | {
39 | "cell_type": "code",
40 | "execution_count": 2,
41 | "metadata": {
42 | "colab": {
43 | "base_uri": "https://localhost:8080/",
44 | "height": 34
45 | },
46 | "colab_type": "code",
47 | "id": "S_hHnz5fLiZ5",
48 | "outputId": "105d9ee2-0c9c-4700-e3b6-47b8d669336e"
49 | },
50 | "outputs": [
51 | {
52 | "name": "stdout",
53 | "output_type": "stream",
54 | "text": [
55 | "TensorFlow: 2.0.0\n"
56 | ]
57 | }
58 | ],
59 | "source": [
60 | "import tensorflow as tf\n",
61 | "from tqdm import tqdm_notebook as tqdm\n",
62 | "import os\n",
63 | "\n",
64 | "print('TensorFlow:', tf.__version__)"
65 | ]
66 | },
67 | {
68 | "cell_type": "code",
69 | "execution_count": 3,
70 | "metadata": {
71 | "colab": {},
72 | "colab_type": "code",
73 | "id": "1d9orC2vS5Tr"
74 | },
75 | "outputs": [],
76 | "source": [
77 | "batch_size = 8\n",
78 | "embedding_dim = 512\n",
79 | "autotune = tf.data.experimental.AUTOTUNE"
80 | ]
81 | },
82 | {
83 | "cell_type": "code",
84 | "execution_count": 4,
85 | "metadata": {
86 | "colab": {},
87 | "colab_type": "code",
88 | "id": "iYONEmj6LiZ9"
89 | },
90 | "outputs": [],
91 | "source": [
92 | "def get_random_title():\n",
93 | " return tf.random.uniform(shape=[512], maxval=200, dtype=tf.int32)\n",
94 | "\n",
95 | "def get_random_citation():\n",
96 | " vector = tf.random.uniform(shape=[embedding_dim], minval=-1, maxval=1, dtype=tf.float32)\n",
97 | " normed_vector = tf.math.l2_normalize(vector)\n",
98 | " return normed_vector\n",
99 | "\n",
100 | "def generate_sample():\n",
101 | " title = get_random_title()\n",
102 | " posCitations = get_random_citation()\n",
103 | " return title, posCitations"
104 | ]
105 | },
106 | {
107 | "cell_type": "code",
108 | "execution_count": 6,
109 | "metadata": {
110 | "colab": {},
111 | "colab_type": "code",
112 | "id": "6hDwsTD5S5Tv"
113 | },
114 | "outputs": [],
115 | "source": [
116 | "class TFrecordWriter:\n",
117 | " def __init__(self,\n",
118 | " n_samples,\n",
119 | " n_shards,\n",
120 | " output_dir='',\n",
121 | " prefix=''):\n",
122 | " self.n_samples = n_samples\n",
123 | " self.n_shards = n_shards\n",
124 | " self.step_size = self.n_samples//self.n_shards + 1\n",
125 | " self.prefix = prefix\n",
126 | " self.output_dir = output_dir\n",
127 | " self.buffer = []\n",
128 | " self.file_count = 1\n",
129 | " \n",
130 | " def make_example(self, title, vector):\n",
131 | " feature = {\n",
132 | " 'title': tf.train.Feature(int64_list=tf.train.Int64List(value=title)),\n",
133 | " 'citation': tf.train.Feature(float_list=tf.train.FloatList(value=vector))\n",
134 | " }\n",
135 | " return tf.train.Example(features=tf.train.Features(feature=feature))\n",
136 | " \n",
137 | " def write_tfrecord(self, tfrecord_path):\n",
138 | " print('writing {} samples in {}'.format(len(self.buffer), tfrecord_path))\n",
139 | " with tf.io.TFRecordWriter(tfrecord_path) as writer:\n",
140 | " for (title, vector) in tqdm(self.buffer):\n",
141 | " example = self.make_example(title, vector)\n",
142 | " writer.write(example.SerializeToString())\n",
143 | " \n",
144 | " def push(self, title, vector):\n",
145 | " self.buffer.append([title, vector])\n",
146 | " if len(self.buffer) == self.step_size:\n",
147 | " fname = '{}_{:04d}.tfrecord'.format(self.prefix, self.file_count)\n",
148 | " tfrecord_path = os.path.join(self.output_dir, fname)\n",
149 | " self.write_tfrecord(tfrecord_path)\n",
150 | " self.clear_buffer()\n",
151 | " self.file_count += 1\n",
152 | " \n",
153 | " def flush_last(self):\n",
154 | " if len(self.buffer):\n",
155 | " fname = '{}_{:04d}.tfrecord'.format(self.prefix, self.file_count)\n",
156 | " tfrecord_path = os.path.join(self.output_dir, fname)\n",
157 | " self.write_tfrecord(tfrecord_path)\n",
158 | " \n",
159 | " def clear_buffer(self):\n",
160 | " self.buffer = []"
161 | ]
162 | },
163 | {
164 | "cell_type": "code",
165 | "execution_count": 7,
166 | "metadata": {
167 | "colab": {},
168 | "colab_type": "code",
169 | "id": "d0-G4yukS5Tx"
170 | },
171 | "outputs": [],
172 | "source": [
173 | "!mkdir 'tfrecords'\n",
174 | "tfrecord_writer = TFrecordWriter(1000, 4, 'tfrecords', 'train')"
175 | ]
176 | },
177 | {
178 | "cell_type": "code",
179 | "execution_count": 8,
180 | "metadata": {
181 | "colab": {
182 | "base_uri": "https://localhost:8080/",
183 | "height": 1000
314 | },
315 | "colab_type": "code",
316 | "id": "-2D3-6KFS5T0",
317 | "outputId": "edd11383-6325-4740-d28a-23398e655b17"
318 | },
319 | "outputs": [
320 | {
321 | "name": "stdout",
322 | "output_type": "stream",
323 | "text": [
324 | "writing 251 samples in tfrecords/train_0001.tfrecord\n"
325 | ]
326 | },
327 | {
328 | "name": "stderr",
329 | "output_type": "stream",
330 | "text": [
331 | "/Users/srihari/tf2.0/lib/python3.6/site-packages/ipykernel_launcher.py:25: TqdmDeprecationWarning: This function will be removed in tqdm==5.0.0\n",
332 | "Please use `tqdm.notebook.tqdm` instead of `tqdm.tqdm_notebook`\n"
333 | ]
334 | },
335 | {
336 | "data": {
337 | "application/vnd.jupyter.widget-view+json": {
338 | "model_id": "7f366b6e4bc44869acf1f4916a3c3c20",
339 | "version_major": 2,
340 | "version_minor": 0
341 | },
342 | "text/plain": [
343 | "HBox(children=(IntProgress(value=0, max=251), HTML(value='')))"
344 | ]
345 | },
346 | "metadata": {},
347 | "output_type": "display_data"
348 | },
349 | {
350 | "name": "stdout",
351 | "output_type": "stream",
352 | "text": [
353 | "\n",
354 | "writing 251 samples in tfrecords/train_0002.tfrecord\n"
355 | ]
356 | },
357 | {
358 | "data": {
359 | "application/vnd.jupyter.widget-view+json": {
360 | "model_id": "b1dab0c4cc00499db6dc6804339e49f3",
361 | "version_major": 2,
362 | "version_minor": 0
363 | },
364 | "text/plain": [
365 | "HBox(children=(IntProgress(value=0, max=251), HTML(value='')))"
366 | ]
367 | },
368 | "metadata": {},
369 | "output_type": "display_data"
370 | },
371 | {
372 | "name": "stdout",
373 | "output_type": "stream",
374 | "text": [
375 | "\n",
376 | "writing 251 samples in tfrecords/train_0003.tfrecord\n"
377 | ]
378 | },
379 | {
380 | "data": {
381 | "application/vnd.jupyter.widget-view+json": {
382 | "model_id": "cbcb837dc4114a3c9057c8e4994d12a6",
383 | "version_major": 2,
384 | "version_minor": 0
385 | },
386 | "text/plain": [
387 | "HBox(children=(IntProgress(value=0, max=251), HTML(value='')))"
388 | ]
389 | },
390 | "metadata": {},
391 | "output_type": "display_data"
392 | },
393 | {
394 | "name": "stdout",
395 | "output_type": "stream",
396 | "text": [
397 | "\n",
398 | "writing 247 samples in tfrecords/train_0004.tfrecord\n"
399 | ]
400 | },
401 | {
402 | "data": {
403 | "application/vnd.jupyter.widget-view+json": {
404 | "model_id": "e69a6d3ebe734cebb3c5de1526545bd7",
405 | "version_major": 2,
406 | "version_minor": 0
407 | },
408 | "text/plain": [
409 | "HBox(children=(IntProgress(value=0, max=247), HTML(value='')))"
410 | ]
411 | },
412 | "metadata": {},
413 | "output_type": "display_data"
414 | },
415 | {
416 | "name": "stdout",
417 | "output_type": "stream",
418 | "text": [
419 | "\n"
420 | ]
421 | }
422 | ],
423 | "source": [
424 | "for i in range(1000):\n",
425 | " title, vector = generate_sample()\n",
426 | " tfrecord_writer.push(title, vector)\n",
427 | "tfrecord_writer.flush_last()"
428 | ]
429 | },
430 | {
431 | "cell_type": "code",
432 | "execution_count": null,
433 | "metadata": {
434 | "colab": {},
435 | "colab_type": "code",
436 | "id": "NwA4RriLS5T2"
437 | },
438 | "outputs": [],
439 | "source": []
440 | }
441 | ],
442 | "metadata": {
443 | "accelerator": "TPU",
444 | "colab": {
445 | "include_colab_link": true,
446 | "name": "model_debug.ipynb",
447 | "provenance": []
448 | },
449 | "kernelspec": {
450 | "display_name": "Python 3",
451 | "language": "python",
452 | "name": "python3"
453 | },
454 | "language_info": {
455 | "codemirror_mode": {
456 | "name": "ipython",
457 | "version": 3
458 | },
459 | "file_extension": ".py",
460 | "mimetype": "text/x-python",
461 | "name": "python",
462 | "nbconvert_exporter": "python",
463 | "pygments_lexer": "ipython3",
464 | "version": "3.6.8"
465 | },
466 | "widgets": {
467 | "application/vnd.jupyter.widget-state+json": {
468 | "state": {},
469 | "version_major": 2,
470 | "version_minor": 0
471 | }
472 | }
473 | },
474 | "nbformat": 4,
475 | "nbformat_minor": 4
476 | }
477 |
--------------------------------------------------------------------------------
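The run above pushed 1000 samples and wrote four shards of 251, 251, 251 and 247 samples. Only the tail of the `TFrecordWriter` class is shown here, so the following is a hedged sketch. It assumes the first two constructor arguments are the total sample count and the shard count. A fixed per-shard size of `1000 // 4 + 1` reproduces the observed pattern, while `divmod` produces shards whose sizes differ by at most one. Both helpers are hypothetical and are not part of the repository.

```python
from typing import List


def fixed_size_shards(num_samples: int, per_shard: int) -> List[int]:
    """Fill shards of per_shard samples; the last shard holds the remainder."""
    sizes = []
    remaining = num_samples
    while remaining > 0:
        take = min(per_shard, remaining)
        sizes.append(take)
        remaining -= take
    return sizes


def balanced_shard_sizes(num_samples: int, num_shards: int) -> List[int]:
    """Split num_samples into num_shards sizes that differ by at most one."""
    base, extra = divmod(num_samples, num_shards)
    # The first `extra` shards each take one additional sample.
    return [base + 1 if i < extra else base for i in range(num_shards)]


print(fixed_size_shards(1000, 1000 // 4 + 1))  # [251, 251, 251, 247]
print(balanced_shard_sizes(1000, 4))           # [250, 250, 250, 250]
print(balanced_shard_sizes(1002, 4))           # [251, 251, 250, 250]
```

The same `divmod` arithmetic applies to the next notebook. Splitting 1,262,996 embeddings across 8 TPU cores gives 157,874 per core with a remainder of 4, and that notebook drops the remainder rather than distributing it.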
/notebooks/tpu_index_debug.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
13 | {
14 | "cell_type": "code",
15 | "execution_count": 1,
16 | "metadata": {
17 | "colab": {
18 | "base_uri": "https://localhost:8080/",
19 | "height": 153
20 | },
21 | "colab_type": "code",
22 | "id": "8wAw2JV6F4Am",
23 | "outputId": "6ca9a703-0699-425b-926a-eeb6604e8464"
24 | },
25 | "outputs": [
26 | {
27 | "name": "stdout",
28 | "output_type": "stream",
29 | "text": [
30 | "Downloading...\n",
31 | "From: https://drive.google.com/uc?id=1-8nsWLseynVj6Z9-E12w1ywnffnqJftm\n",
32 | "To: /content/Uembeds306Epochs.npy\n",
33 | "2.59GB [00:21, 119MB/s]\n",
34 | "Downloading...\n",
35 | "From: https://drive.google.com/uc?id=1UszbNYQnlNrAcPQkBwvb1wKX21oRPiqb\n",
36 | "To: /content/Vembeds306Epochs.npy\n",
37 | "2.59GB [00:38, 67.5MB/s]\n"
38 | ]
39 | }
40 | ],
41 | "source": [
42 | "!gdown --id 1-8nsWLseynVj6Z9-E12w1ywnffnqJftm\n",
43 | "!gdown --id 1UszbNYQnlNrAcPQkBwvb1wKX21oRPiqb"
44 | ]
45 | },
46 | {
47 | "cell_type": "code",
48 | "execution_count": 1,
49 | "metadata": {
50 | "colab": {
51 | "base_uri": "https://localhost:8080/",
52 | "height": 51
53 | },
54 | "colab_type": "code",
55 | "id": "SGov6e8uGO3B",
56 | "outputId": "a3eeff15-ef9e-4066-9f14-fe349ff14d59"
57 | },
58 | "outputs": [
59 | {
60 | "name": "stdout",
61 | "output_type": "stream",
62 | "text": [
63 | "TensorFlow 2.x selected.\n",
64 | "TensorFlow: 2.1.0-rc1\n"
65 | ]
66 | }
67 | ],
68 | "source": [
69 | "%tensorflow_version 2.x\n",
70 | "from concurrent.futures import ProcessPoolExecutor\n",
71 | "import numpy as np\n",
72 | "import tensorflow as tf\n",
73 | "from time import time\n",
74 | "from tqdm.notebook import tqdm\n",
75 | "print('TensorFlow:', tf.__version__)"
76 | ]
77 | },
78 | {
79 | "cell_type": "code",
80 | "execution_count": null,
81 | "metadata": {
82 | "colab": {},
83 | "colab_type": "code",
84 | "id": "iht3WQa3H_g6"
85 | },
86 | "outputs": [],
87 | "source": [
88 | "try:\n",
89 | " tpu = tf.distribute.cluster_resolver.TPUClusterResolver() # TPU detection\n",
90 | " print('Running on TPU ', tpu.cluster_spec().as_dict()['worker'])\n",
91 | "except ValueError:\n",
92 | " tpu = None\n",
93 | "\n",
94 | "if tpu:\n",
95 | " tf.config.experimental_connect_to_cluster(tpu)\n",
96 | " tf.tpu.experimental.initialize_tpu_system(tpu)\n",
97 | " strategy = tf.distribute.experimental.TPUStrategy(tpu)\n",
98 | "else:\n",
99 | " strategy = tf.distribute.MirroredStrategy()\n",
100 | "\n",
101 | "print(\"REPLICAS: \", strategy.num_replicas_in_sync)"
102 | ]
103 | },
104 | {
105 | "cell_type": "code",
106 | "execution_count": 3,
107 | "metadata": {
108 | "colab": {
109 | "base_uri": "https://localhost:8080/",
110 | "height": 153
111 | },
112 | "colab_type": "code",
113 | "id": "3ks95sxuIKbR",
114 | "outputId": "c0ccc370-ac93-4147-b48b-0dd1b015f614"
115 | },
116 | "outputs": [
117 | {
118 | "data": {
119 | "text/plain": [
120 | "['/job:worker/replica:0/task:0/device:TPU:0',\n",
121 | " '/job:worker/replica:0/task:0/device:TPU:1',\n",
122 | " '/job:worker/replica:0/task:0/device:TPU:2',\n",
123 | " '/job:worker/replica:0/task:0/device:TPU:3',\n",
124 | " '/job:worker/replica:0/task:0/device:TPU:4',\n",
125 | " '/job:worker/replica:0/task:0/device:TPU:5',\n",
126 | " '/job:worker/replica:0/task:0/device:TPU:6',\n",
127 | " '/job:worker/replica:0/task:0/device:TPU:7']"
128 | ]
129 | },
130 | "execution_count": 3,
131 | "metadata": {
132 | "tags": []
133 | },
134 | "output_type": "execute_result"
135 | }
136 | ],
137 | "source": [
138 | "workers = ['/job:worker/replica:0/task:0/device:TPU:'+str(i) for i in range(8)]\n",
139 | "workers"
140 | ]
141 | },
142 | {
143 | "cell_type": "code",
144 | "execution_count": null,
145 | "metadata": {
146 | "colab": {},
147 | "colab_type": "code",
148 | "id": "uByeXRrVO78_"
149 | },
150 | "outputs": [],
151 | "source": [
152 | "class Index:\n",
153 | " def __init__(self, u, v, worker):\n",
154 | " self.embeddings = tf.math.l2_normalize(u, axis=1) + tf.math.l2_normalize(v, axis=1)\n",
155 | " self.squared_norms_embeddings = tf.expand_dims(tf.square(tf.norm(self.embeddings, axis=1)), axis=0)\n",
156 | " self.worker = worker\n",
157 | "\n",
158 | " @tf.function\n",
159 | " def search(self, query_vector, top_k=None):\n",
160 | " with tf.device(self.worker):  # use this shard's own TPU core, not the global loop variable\n",
161 | " squared_norms_query_vector = tf.expand_dims(tf.square(tf.norm(query_vector, axis=1)), axis=0)\n",
162 | " dot_product = tf.reduce_sum(tf.multiply(self.embeddings, query_vector), axis=1)\n",
163 | " distances = tf.maximum(self.squared_norms_embeddings + squared_norms_query_vector - 2 * dot_product, 0)  # squared L2: ||e||^2 + ||q||^2 - 2<e,q>, clamped at 0\n",
164 | " sorted_indices = tf.argsort(distances)\n",
165 | " if top_k:\n",
166 | " sorted_indices = sorted_indices[..., :top_k]\n",
167 | " nearest_distances = tf.reshape(tf.gather(distances[0], sorted_indices), shape=[-1, 1])\n",
168 | " return nearest_distances[..., 0], sorted_indices[0]"
169 | ]
170 | },
171 | {
172 | "cell_type": "code",
173 | "execution_count": 5,
174 | "metadata": {
175 | "colab": {
176 | "base_uri": "https://localhost:8080/",
177 | "height": 34
178 | },
179 | "colab_type": "code",
180 | "id": "DFloyWt7G8ts",
181 | "outputId": "ddaf8718-6c4d-46d4-e754-2cf166e99a83"
182 | },
183 | "outputs": [
184 | {
185 | "data": {
186 | "text/plain": [
187 | "((1262996, 512), (1262996, 512))"
188 | ]
189 | },
190 | "execution_count": 5,
191 | "metadata": {
192 | "tags": []
193 | },
194 | "output_type": "execute_result"
195 | }
196 | ],
197 | "source": [
198 | "u_embeddings = np.load('Uembeds306Epochs.npy')\n",
199 | "v_embeddings = np.load('Vembeds306Epochs.npy')\n",
200 | "u_embeddings.shape, v_embeddings.shape"
201 | ]
202 | },
203 | {
204 | "cell_type": "code",
205 | "execution_count": null,
206 | "metadata": {
207 | "colab": {},
208 | "colab_type": "code",
209 | "id": "AB8L4sA-H2nf"
210 | },
211 | "outputs": [],
212 | "source": [
213 | "# Discarding last 4 vectors to make number of vectors divisible by 8\n",
214 | "u_embeddings = np.split(u_embeddings[:-4], 8, axis=0)\n",
215 | "v_embeddings = np.split(v_embeddings[:-4], 8, axis=0)\n",
216 | "vecs_per_index = u_embeddings[0].shape[0]  # 157874 = (1262996 - 4) / 8"
217 | ]
218 | },
219 | {
220 | "cell_type": "code",
221 | "execution_count": 7,
222 | "metadata": {
223 | "colab": {
224 | "base_uri": "https://localhost:8080/",
225 | "height": 153
226 | },
227 | "colab_type": "code",
228 | "id": "HoEA0SqeJW4n",
229 | "outputId": "ddad8271-169b-4377-e94b-68e20bccd08d"
230 | },
231 | "outputs": [
232 | {
233 | "name": "stdout",
234 | "output_type": "stream",
235 | "text": [
236 | "Building index with 157874 vectors on /job:worker/replica:0/task:0/device:TPU:0\n",
237 | "Building index with 157874 vectors on /job:worker/replica:0/task:0/device:TPU:1\n",
238 | "Building index with 157874 vectors on /job:worker/replica:0/task:0/device:TPU:2\n",
239 | "Building index with 157874 vectors on /job:worker/replica:0/task:0/device:TPU:3\n",
240 | "Building index with 157874 vectors on /job:worker/replica:0/task:0/device:TPU:4\n",
241 | "Building index with 157874 vectors on /job:worker/replica:0/task:0/device:TPU:5\n",
242 | "Building index with 157874 vectors on /job:worker/replica:0/task:0/device:TPU:6\n",
243 | "Building index with 157874 vectors on /job:worker/replica:0/task:0/device:TPU:7\n"
244 | ]
245 | }
246 | ],
247 | "source": [
248 | "## Place 1/8 of total embeddings on each TPU core\n",
249 | "indices = []\n",
250 | "for i, worker in enumerate(workers):\n",
251 | " with tf.device(worker):\n",
252 | " print('Building index with {} vectors on {}'.format(u_embeddings[i].shape[0],worker))\n",
253 | " indices.append(Index(u_embeddings[i], v_embeddings[i], worker))"
254 | ]
255 | },
256 | {
257 | "cell_type": "code",
258 | "execution_count": null,
259 | "metadata": {
260 | "colab": {},
261 | "colab_type": "code",
262 | "id": "W5r4LmcVqGAZ"
263 | },
264 | "outputs": [],
265 | "source": [
266 | "def search(xq):\n",
267 | " D, I = [], []  # best (distance, global id) from each of the 8 shards\n",
268 | " for i in range(8):\n",
269 | " print('Search running in index: {}'.format(indices[i].worker))\n",
270 | " d, idx = indices[i].search(xq, 1)\n",
271 | " D.append(d.numpy()[0])\n",
272 | " I.append(i*vecs_per_index + idx.numpy()[0])\n",
273 | "\n",
274 | " id_sorted = np.argsort(D)\n",
275 | " D = np.array(D)[id_sorted]\n",
276 | " I = np.array(I)[id_sorted]\n",
277 | " return D, I"
278 | ]
279 | },
280 | {
281 | "cell_type": "code",
282 | "execution_count": 28,
283 | "metadata": {
284 | "colab": {
285 | "base_uri": "https://localhost:8080/",
286 | "height": 289
287 | },
288 | "colab_type": "code",
289 | "id": "_7JMlz-oisfp",
290 | "outputId": "21bcdb0a-a975-40c5-eb7f-1737c37f43af"
291 | },
292 | "outputs": [
293 | {
294 | "name": "stdout",
295 | "output_type": "stream",
296 | "text": [
297 | "Search running in index: /job:worker/replica:0/task:0/device:TPU:0\n",
298 | "Search running in index: /job:worker/replica:0/task:0/device:TPU:1\n",
299 | "Search running in index: /job:worker/replica:0/task:0/device:TPU:2\n",
300 | "Search running in index: /job:worker/replica:0/task:0/device:TPU:3\n",
301 | "Search running in index: /job:worker/replica:0/task:0/device:TPU:4\n",
302 | "Search running in index: /job:worker/replica:0/task:0/device:TPU:5\n",
303 | "Search running in index: /job:worker/replica:0/task:0/device:TPU:6\n",
304 | "Search running in index: /job:worker/replica:0/task:0/device:TPU:7\n",
305 | "\n",
306 | "Actual ID : 1115204\n",
307 | "Result ID : 1115204 \n",
308 | "\n",
309 | "Neighbours : [1115204 84123 604532 881265 190994 1046794 466390 683556]\n",
310 | "Distances : [0. 1.9193 1.9222 1.9401 1.9465 1.9475 1.9555 1.9613]\n",
311 | "\n",
312 | "Time taken : 0.24 secs\n"
313 | ]
314 | }
315 | ],
316 | "source": [
317 | "n = 10086\n",
318 | "split = 7  # shard index in [0, 7]; pick the nth vector from this shard\n",
319 | "actual_n = vecs_per_index*split + n\n",
320 | "\n",
321 | "xq = tf.nn.l2_normalize(u_embeddings[split][n]) + tf.nn.l2_normalize(v_embeddings[split][n])\n",
322 | "xq = tf.reshape(xq, [1, -1])\n",
323 | "\n",
324 | "s = time()\n",
325 | "D, I = search(xq)\n",
326 | "e = time()\n",
327 | "\n",
328 | "print('\\nActual ID :', actual_n)\n",
329 | "print('Result ID :', I[0], '\\n')\n",
330 | "\n",
331 | "print('Neighbours :', I )\n",
332 | "print('Distances :', np.round(D, 4))\n",
333 | "print('\\nTime taken :', np.round(e-s, 2), 'secs')\n",
334 | "# The first search is slower because tf.function traces Index.search on its\n",
335 | "# first call; later calls with the same input signature reuse the traced graph."
336 | ]
337 | },
338 | {
339 | "cell_type": "markdown",
340 | "metadata": {
341 | "colab_type": "text",
342 | "id": "T517kc8Q0c0-"
343 | },
344 | "source": [
345 | "#### Checking accuracy"
346 | ]
347 | },
348 | {
349 | "cell_type": "code",
350 | "execution_count": null,
351 | "metadata": {
352 | "colab": {},
353 | "colab_type": "code",
354 | "id": "FpUfykJkxlKN"
355 | },
356 | "outputs": [],
357 | "source": [
358 | "n_test = 5000\n",
359 | "random_n = np.random.randint(0, vecs_per_index, n_test)\n",
360 | "random_split = np.random.randint(0, 8, n_test)"
361 | ]
362 | },
363 | {
364 | "cell_type": "code",
365 | "execution_count": 25,
366 | "metadata": {
367 | "colab": {
368 | "base_uri": "https://localhost:8080/",
369 | "height": 117,
370 | "referenced_widgets": [
371 | "77c45ec9d93c43bbb523da2ec45cd90f",
372 | "3aa9843e4b5241e1808bc55692259cf7",
373 | "e970f15766aa42c281df14b31b902f92",
374 | "d8bf419f5cf1491fa021fdabb16b9a53",
375 | "66be0be7823f4cf399f302f0a61b48d5",
376 | "0de3e0e635c44282acc831a417d0e309",
377 | "b4ef6cef8bfc4adbb0fc1e523b6ed0d7",
378 | "f76a09689f0848909194dd5c31ebf82d"
379 | ]
380 | },
381 | "colab_type": "code",
382 | "id": "4poe0mEdy454",
383 | "outputId": "187e5373-9ebe-4970-f1c5-378bb8ae4b40"
384 | },
385 | "outputs": [
386 | {
387 | "data": {
388 | "application/vnd.jupyter.widget-view+json": {
389 | "model_id": "77c45ec9d93c43bbb523da2ec45cd90f",
390 | "version_major": 2,
391 | "version_minor": 0
392 | },
393 | "text/plain": [
394 | "HBox(children=(IntProgress(value=0, max=5000), HTML(value='')))"
395 | ]
396 | },
397 | "metadata": {
398 | "tags": []
399 | },
400 | "output_type": "display_data"
401 | },
402 | {
403 | "name": "stdout",
404 | "output_type": "stream",
405 | "text": [
406 | "\n",
407 | "\n",
408 | "Time taken per search : 0.24776199999999998 secs\n",
409 | "Accuracy : 1.0\n"
410 | ]
411 | }
412 | ],
413 | "source": [
414 | "y_true = []\n",
415 | "y_pred = []\n",
416 | "s = time()\n",
417 | "for n, split in tqdm(zip(random_n, random_split), total=n_test):\n",
418 | " xq = tf.nn.l2_normalize(u_embeddings[split][n]) + tf.nn.l2_normalize(v_embeddings[split][n])\n",
419 | " xq = tf.reshape(xq, [1, -1])\n",
420 | " actual_n = vecs_per_index*split + n\n",
421 | " D, I = [], []\n",
422 | " for i in range(8):\n",
423 | " d, idx = indices[i].search(xq, 1)\n",
424 | " D.append(d.numpy()[0])\n",
425 | " I.append(i*vecs_per_index + idx.numpy()[0])\n",
426 | " id_sorted = np.argsort(D)\n",
427 | " y_pred.append(np.array(I)[id_sorted])\n",
428 | " y_true.append(actual_n)\n",
429 | "e = time()\n",
430 | "y_true = np.array(y_true)\n",
431 | "y_pred = np.array(y_pred)\n",
432 | "print('\\nTime taken per search :', np.round((e - s) / n_test, 4), 'secs')\n",
433 | "print('Accuracy :', np.sum(y_true == y_pred[:, 0]) / n_test)"
434 | ]
435 | },
436 | {
437 | "cell_type": "code",
438 | "execution_count": null,
439 | "metadata": {
440 | "colab": {},
441 | "colab_type": "code",
442 | "id": "ghMRw0jJj2i3"
443 | },
444 | "outputs": [],
445 | "source": []
446 | }
447 | ],
448 | "metadata": {
449 | "accelerator": "TPU",
450 | "colab": {
451 | "collapsed_sections": [],
452 | "include_colab_link": true,
453 | "name": "tpu index.ipynb",
454 | "provenance": []
455 | },
456 | "kernelspec": {
457 | "display_name": "Python 3",
458 | "language": "python",
459 | "name": "python3"
460 | },
461 | "language_info": {
462 | "codemirror_mode": {
463 | "name": "ipython",
464 | "version": 3
465 | },
466 | "file_extension": ".py",
467 | "mimetype": "text/x-python",
468 | "name": "python",
469 | "nbconvert_exporter": "python",
470 | "pygments_lexer": "ipython3",
471 | "version": "3.6.7"
472 | },
473 | "widgets": {
474 | "application/vnd.jupyter.widget-state+json": {
475 | "0de3e0e635c44282acc831a417d0e309": {
476 | "model_module": "@jupyter-widgets/base",
477 | "model_name": "LayoutModel",
478 | "state": {
479 | "_model_module": "@jupyter-widgets/base",
480 | "_model_module_version": "1.2.0",
481 | "_model_name": "LayoutModel",
482 | "_view_count": null,
483 | "_view_module": "@jupyter-widgets/base",
484 | "_view_module_version": "1.2.0",
485 | "_view_name": "LayoutView",
486 | "align_content": null,
487 | "align_items": null,
488 | "align_self": null,
489 | "border": null,
490 | "bottom": null,
491 | "display": null,
492 | "flex": null,
493 | "flex_flow": null,
494 | "grid_area": null,
495 | "grid_auto_columns": null,
496 | "grid_auto_flow": null,
497 | "grid_auto_rows": null,
498 | "grid_column": null,
499 | "grid_gap": null,
500 | "grid_row": null,
501 | "grid_template_areas": null,
502 | "grid_template_columns": null,
503 | "grid_template_rows": null,
504 | "height": null,
505 | "justify_content": null,
506 | "justify_items": null,
507 | "left": null,
508 | "margin": null,
509 | "max_height": null,
510 | "max_width": null,
511 | "min_height": null,
512 | "min_width": null,
513 | "object_fit": null,
514 | "object_position": null,
515 | "order": null,
516 | "overflow": null,
517 | "overflow_x": null,
518 | "overflow_y": null,
519 | "padding": null,
520 | "right": null,
521 | "top": null,
522 | "visibility": null,
523 | "width": null
524 | }
525 | },
526 | "3aa9843e4b5241e1808bc55692259cf7": {
527 | "model_module": "@jupyter-widgets/base",
528 | "model_name": "LayoutModel",
529 | "state": {
530 | "_model_module": "@jupyter-widgets/base",
531 | "_model_module_version": "1.2.0",
532 | "_model_name": "LayoutModel",
533 | "_view_count": null,
534 | "_view_module": "@jupyter-widgets/base",
535 | "_view_module_version": "1.2.0",
536 | "_view_name": "LayoutView",
537 | "align_content": null,
538 | "align_items": null,
539 | "align_self": null,
540 | "border": null,
541 | "bottom": null,
542 | "display": null,
543 | "flex": null,
544 | "flex_flow": null,
545 | "grid_area": null,
546 | "grid_auto_columns": null,
547 | "grid_auto_flow": null,
548 | "grid_auto_rows": null,
549 | "grid_column": null,
550 | "grid_gap": null,
551 | "grid_row": null,
552 | "grid_template_areas": null,
553 | "grid_template_columns": null,
554 | "grid_template_rows": null,
555 | "height": null,
556 | "justify_content": null,
557 | "justify_items": null,
558 | "left": null,
559 | "margin": null,
560 | "max_height": null,
561 | "max_width": null,
562 | "min_height": null,
563 | "min_width": null,
564 | "object_fit": null,
565 | "object_position": null,
566 | "order": null,
567 | "overflow": null,
568 | "overflow_x": null,
569 | "overflow_y": null,
570 | "padding": null,
571 | "right": null,
572 | "top": null,
573 | "visibility": null,
574 | "width": null
575 | }
576 | },
577 | "66be0be7823f4cf399f302f0a61b48d5": {
578 | "model_module": "@jupyter-widgets/controls",
579 | "model_name": "ProgressStyleModel",
580 | "state": {
581 | "_model_module": "@jupyter-widgets/controls",
582 | "_model_module_version": "1.5.0",
583 | "_model_name": "ProgressStyleModel",
584 | "_view_count": null,
585 | "_view_module": "@jupyter-widgets/base",
586 | "_view_module_version": "1.2.0",
587 | "_view_name": "StyleView",
588 | "bar_color": null,
589 | "description_width": ""
590 | }
591 | },
592 | "77c45ec9d93c43bbb523da2ec45cd90f": {
593 | "model_module": "@jupyter-widgets/controls",
594 | "model_name": "HBoxModel",
595 | "state": {
596 | "_dom_classes": [],
597 | "_model_module": "@jupyter-widgets/controls",
598 | "_model_module_version": "1.5.0",
599 | "_model_name": "HBoxModel",
600 | "_view_count": null,
601 | "_view_module": "@jupyter-widgets/controls",
602 | "_view_module_version": "1.5.0",
603 | "_view_name": "HBoxView",
604 | "box_style": "",
605 | "children": [
606 | "IPY_MODEL_e970f15766aa42c281df14b31b902f92",
607 | "IPY_MODEL_d8bf419f5cf1491fa021fdabb16b9a53"
608 | ],
609 | "layout": "IPY_MODEL_3aa9843e4b5241e1808bc55692259cf7"
610 | }
611 | },
612 | "b4ef6cef8bfc4adbb0fc1e523b6ed0d7": {
613 | "model_module": "@jupyter-widgets/controls",
614 | "model_name": "DescriptionStyleModel",
615 | "state": {
616 | "_model_module": "@jupyter-widgets/controls",
617 | "_model_module_version": "1.5.0",
618 | "_model_name": "DescriptionStyleModel",
619 | "_view_count": null,
620 | "_view_module": "@jupyter-widgets/base",
621 | "_view_module_version": "1.2.0",
622 | "_view_name": "StyleView",
623 | "description_width": ""
624 | }
625 | },
626 | "d8bf419f5cf1491fa021fdabb16b9a53": {
627 | "model_module": "@jupyter-widgets/controls",
628 | "model_name": "HTMLModel",
629 | "state": {
630 | "_dom_classes": [],
631 | "_model_module": "@jupyter-widgets/controls",
632 | "_model_module_version": "1.5.0",
633 | "_model_name": "HTMLModel",
634 | "_view_count": null,
635 | "_view_module": "@jupyter-widgets/controls",
636 | "_view_module_version": "1.5.0",
637 | "_view_name": "HTMLView",
638 | "description": "",
639 | "description_tooltip": null,
640 | "layout": "IPY_MODEL_f76a09689f0848909194dd5c31ebf82d",
641 | "placeholder": "",
642 | "style": "IPY_MODEL_b4ef6cef8bfc4adbb0fc1e523b6ed0d7",
643 | "value": "100% 5000/5000 [20:38<00:00, 4.11it/s]"
644 | }
645 | },
646 | "e970f15766aa42c281df14b31b902f92": {
647 | "model_module": "@jupyter-widgets/controls",
648 | "model_name": "IntProgressModel",
649 | "state": {
650 | "_dom_classes": [],
651 | "_model_module": "@jupyter-widgets/controls",
652 | "_model_module_version": "1.5.0",
653 | "_model_name": "IntProgressModel",
654 | "_view_count": null,
655 | "_view_module": "@jupyter-widgets/controls",
656 | "_view_module_version": "1.5.0",
657 | "_view_name": "ProgressView",
658 | "bar_style": "success",
659 | "description": "",
660 | "description_tooltip": null,
661 | "layout": "IPY_MODEL_0de3e0e635c44282acc831a417d0e309",
662 | "max": 5000,
663 | "min": 0,
664 | "orientation": "horizontal",
665 | "style": "IPY_MODEL_66be0be7823f4cf399f302f0a61b48d5",
666 | "value": 5000
667 | }
668 | },
669 | "f76a09689f0848909194dd5c31ebf82d": {
670 | "model_module": "@jupyter-widgets/base",
671 | "model_name": "LayoutModel",
672 | "state": {
673 | "_model_module": "@jupyter-widgets/base",
674 | "_model_module_version": "1.2.0",
675 | "_model_name": "LayoutModel",
676 | "_view_count": null,
677 | "_view_module": "@jupyter-widgets/base",
678 | "_view_module_version": "1.2.0",
679 | "_view_name": "LayoutView",
680 | "align_content": null,
681 | "align_items": null,
682 | "align_self": null,
683 | "border": null,
684 | "bottom": null,
685 | "display": null,
686 | "flex": null,
687 | "flex_flow": null,
688 | "grid_area": null,
689 | "grid_auto_columns": null,
690 | "grid_auto_flow": null,
691 | "grid_auto_rows": null,
692 | "grid_column": null,
693 | "grid_gap": null,
694 | "grid_row": null,
695 | "grid_template_areas": null,
696 | "grid_template_columns": null,
697 | "grid_template_rows": null,
698 | "height": null,
699 | "justify_content": null,
700 | "justify_items": null,
701 | "left": null,
702 | "margin": null,
703 | "max_height": null,
704 | "max_width": null,
705 | "min_height": null,
706 | "min_width": null,
707 | "object_fit": null,
708 | "object_position": null,
709 | "order": null,
710 | "overflow": null,
711 | "overflow_x": null,
712 | "overflow_y": null,
713 | "padding": null,
714 | "right": null,
715 | "top": null,
716 | "visibility": null,
717 | "width": null
718 | }
719 | }
720 | }
721 | }
722 | },
723 | "nbformat": 4,
724 | "nbformat_minor": 4
725 | }
726 |
--------------------------------------------------------------------------------
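You can reproduce the sharded exact search in `tpu_index_debug.ipynb` on CPU for experimentation. The sketch below is a hedged NumPy analogue, not the notebook's TPU code. The names `build_shards`, `search_shard` and `search_all` are hypothetical, and random vectors stand in for the real U and V embeddings. The sketch works as follows:

1. It combines the L2-normalised U and V vectors as `Index.__init__` does.
2. It drops trailing rows so the shards divide evenly.
3. It computes squared L2 distances within each shard.
4. It adds `shard * per_shard` to each local index to recover a global id.

```python
import numpy as np


def build_shards(u, v, num_shards):
    """Combine U and V as in the notebook, then split into equal shards.

    Trailing rows that do not divide evenly are dropped, mirroring the
    notebook's [:-4] slice. Hypothetical CPU-only helper.
    """
    emb = (u / np.linalg.norm(u, axis=1, keepdims=True)
           + v / np.linalg.norm(v, axis=1, keepdims=True))
    per_shard = emb.shape[0] // num_shards
    emb = emb[: per_shard * num_shards]
    return np.split(emb, num_shards, axis=0), per_shard


def search_shard(shard, query):
    """Return (distance, local index) of the nearest row in one shard."""
    # Squared L2 distance ||e||^2 + ||q||^2 - 2 e.q, clamped at 0.
    d = np.sum(shard ** 2, axis=1) + np.sum(query ** 2) - 2.0 * (shard @ query)
    d = np.maximum(d, 0.0)
    best = int(np.argmin(d))
    return d[best], best


def search_all(shards, per_shard, query):
    """Best hit from each shard, mapped to global ids and sorted by distance."""
    hits = [search_shard(s, query) for s in shards]
    dists = np.array([h[0] for h in hits])
    ids = np.array([i * per_shard + h[1] for i, h in enumerate(hits)])
    order = np.argsort(dists)
    return dists[order], ids[order]


rng = np.random.default_rng(0)
u = rng.normal(size=(1003, 16))
v = rng.normal(size=(1003, 16))
shards, per_shard = build_shards(u, v, 8)   # 8 shards of 125 rows
target = 7 * per_shard + 42                 # row 42 of the last shard
query = shards[7][42]
D, I = search_all(shards, per_shard, query)
print(I[0] == target)  # True
```

Like the notebook's `search`, this keeps only the best hit from each shard. The sorted `I` therefore holds the per-shard winners, not the global top 8. A global top-k would require the k best hits from every shard before merging.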
/notebooks/training/TF2.0 Word2Vec CBOW.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "nbformat": 4,
3 | "nbformat_minor": 0,
4 | "metadata": {
5 | "colab": {
6 | "name": "TF2.0 Word2Vec CBOW",
7 | "provenance": [],
8 | "collapsed_sections": []
9 | },
10 | "kernelspec": {
11 | "name": "python3",
12 | "display_name": "Python 3"
13 | },
14 | "accelerator": "GPU"
15 | },
16 | "cells": [
17 | {
18 | "cell_type": "markdown",
19 | "metadata": {
20 | "id": "9weW9X7-vLsg",
21 | "colab_type": "text"
22 | },
23 | "source": [
24 | "Although this model was not used in the end, we include it in our GitHub repository for the completeness of our TFWorld competition submission and so that the community can use it if the need arises."
25 | ]
26 | },
27 | {
28 | "cell_type": "code",
29 | "metadata": {
30 | "id": "4gFQvB6MatHC",
31 | "colab_type": "code",
32 | "colab": {}
33 | },
34 | "source": [
35 | "try:\n",
36 | " %tensorflow_version 2.x\n",
37 | "except Exception:\n",
38 | " pass\n",
39 | "import tensorflow as tf\n",
40 | "\n",
41 | "import tensorflow_datasets as tfds\n",
42 | "import os\n",
43 | "from tensorflow import keras \n",
44 | "from tensorflow.keras.layers import Input, Lambda, Dense, dot, Reshape, Embedding\n",
45 | "# from keras.layers.embeddings import Embedding\n",
46 | "from tensorflow.keras import backend as K\n",
47 | "from tensorflow.keras.preprocessing.sequence import pad_sequences\n",
48 | "import numpy as np\n"
49 | ],
50 | "execution_count": 0,
51 | "outputs": []
52 | },
53 | {
54 | "cell_type": "markdown",
55 | "metadata": {
56 | "id": "fDVgocMrv0lm",
57 | "colab_type": "text"
58 | },
59 | "source": [
60 | "A note on the following cell: although our model is the continuous bag-of-words (CBOW) variant of Word2Vec, we used Keras's `skipgrams` preprocessing function for negative sampling because our data was not formatted as sequences. Read more about these functions [here](https://keras.io/preprocessing/sequence/)"
61 | ]
62 | },
63 | {
64 | "cell_type": "code",
65 | "metadata": {
66 | "id": "XP91bFU8Qq6P",
67 | "colab_type": "code",
68 | "colab": {}
69 | },
70 | "source": [
71 | "from tensorflow.keras.preprocessing.sequence import skipgrams, make_sampling_table\n",
72 | "def sampling(inCitation:list, outCitation:list, window_size=3):\n",
73 | " '''inCitation: the original (citing) papers; outCitation: for each paper in inCitation, the list of papers it cites'''\n",
74 | " global vocab_size\n",
75 | " vocab = list(set(inCitation))\n",
76 | " for out in outCitation:\n",
77 | " for paper in out:\n",
78 | " vocab.append(paper)\n",
79 | " vocab = list(set(vocab))\n",
80 | " vocab_size = len(vocab)\n",
81 | "\n",
82 | " sampling_table = make_sampling_table(vocab_size) # currently unused: not passed to skipgrams below\n",
83 | " labels = []\n",
84 | " data = []\n",
85 | " target = []\n",
86 | " for i in range(len(inCitation)):\n",
87 | " out = outCitation[i] \n",
88 | " contexes, label = skipgrams(out, vocab_size, window_size=window_size)\n",
89 | " data.append(contexes)\n",
90 | " labels.append(label)\n",
91 | " target.append([inCitation[i]] * len(label))\n",
92 | " \n",
93 | " return target, data, labels\n"
94 | ],
95 | "execution_count": 0,
96 | "outputs": []
97 | },
98 | {
99 | "cell_type": "code",
100 | "metadata": {
101 | "id": "p7EOtljpdQES",
102 | "colab_type": "code",
103 | "colab": {}
104 | },
105 | "source": [
106 | "target, data, labels = sampling([1,2,3,5,6], [[3,6,4], [4,2,1], [1,2,5], [2,5,3], [1,2,3], [5]]) #Dummy data"
107 | ],
108 | "execution_count": 0,
109 | "outputs": []
110 | },
111 | {
112 | "cell_type": "code",
113 | "metadata": {
114 | "id": "CZuv1sHSjWKo",
115 | "colab_type": "code",
116 | "outputId": "d7b77197-f7a0-47dc-b31d-d249fc148df8",
117 | "colab": {
118 | "base_uri": "https://localhost:8080/",
119 | "height": 1000
120 | }
121 | },
122 | "source": [
123 | "class Word2CBOW(keras.Model): #I should stop naming things\n",
124 | " def __init__(self, window_size=3, **kwargs):\n",
125 | " super().__init__(**kwargs) #handles standard args (e.g., name)\n",
126 | " #super() calls the keras.Model initializer\n",
127 | " #To add: Argument for window size and arg for\n",
128 | " self.embedding_layer = Embedding(7, 768, input_length=2) #Only working with a vocab of 6 to test\n",
129 | " self.window_size = window_size\n",
130 | " self.context_window = self.window_size * 2\n",
131 | " self.outvec = Dense(1, activation='sigmoid')\n",
132 | " self.similarity = 0\n",
133 | " self.cbow = Lambda(lambda x: K.mean(x, axis=-2)) # average over the context ids, not over the embedding dimension\n",
134 | " self.batched_dot = Lambda(self.bdotFunction)\n",
135 | "\n",
136 | " def bdotFunction(self, x):\n",
137 | " first = x[0]\n",
138 | " second = x[1]\n",
139 | " return K.batch_dot(first, second, axes=-1)\n",
140 | "\n",
141 | " def call(self, inputs):\n",
142 | " target_input, context_inputs = inputs\n",
143 | "\n",
144 | " target_input = keras.layers.InputLayer(input_shape=[1,])(target_input)\n",
145 | " context_inputs = keras.layers.InputLayer(input_shape=[2,])(context_inputs)\n",
146 | " \n",
147 | " target1 = self.embedding_layer(target_input)\n",
148 | " context = self.embedding_layer(context_inputs)\n",
149 | " \n",
150 | " context = self.cbow(context) #Averaging the context vectors\n",
151 | " \n",
152 | "\n",
153 | " dotted = self.batched_dot([tf.squeeze(target1, axis=0), tf.squeeze(context, axis=0)])\n",
154 | "\n",
155 | " binary_output = self.outvec(dotted)\n",
156 | " binary_output = tf.squeeze(binary_output)\n",
157 | "\n",
158 | " return binary_output\n",
159 | "model = Word2CBOW()\n",
160 | "model.compile(loss='binary_crossentropy', optimizer='rmsprop')\n",
161 | "# labels = np.array(np.transpose(labels))\n",
162 | "model.fit((np.array(target, dtype=np.int32), np.array(data, dtype=np.int32)), y=np.array(labels), epochs=100, batch_size=1)\n",
163 | "# model((np.array(target[0], dtype=np.int32), np.array(data[0], dtype=np.int32)))\n",
164 | "#Might have to try train_on_batch"
165 | ],
166 | "execution_count": 0,
167 | "outputs": [
168 | {
169 | "output_type": "stream",
170 | "text": [
171 | "(12,)\n",
172 | "(12,)\n",
173 | "Train on 5 samples\n",
174 | "Epoch 1/100\n",
175 | "(12,)\n",
176 | "(12,)\n",
177 | "(12,)\n",
178 | "(12,)\n",
179 | "5/5 [==============================] - 1s 113ms/sample - loss: 0.6924\n",
180 | "Epoch 2/100\n",
181 | "5/5 [==============================] - 0s 3ms/sample - loss: 0.6896\n",
182 | "Epoch 3/100\n",
183 | "5/5 [==============================] - 0s 3ms/sample - loss: 0.6876\n",
184 | "Epoch 4/100\n",
185 | "5/5 [==============================] - 0s 3ms/sample - loss: 0.6856\n",
186 | "Epoch 5/100\n",
187 | "5/5 [==============================] - 0s 3ms/sample - loss: 0.6835\n",
188 | "Epoch 6/100\n",
189 | "5/5 [==============================] - 0s 4ms/sample - loss: 0.6815\n",
190 | "Epoch 7/100\n",
191 | "5/5 [==============================] - 0s 3ms/sample - loss: 0.6793\n",
192 | "Epoch 8/100\n",
193 | "5/5 [==============================] - 0s 4ms/sample - loss: 0.6771\n",
194 | "Epoch 9/100\n",
195 | "5/5 [==============================] - 0s 3ms/sample - loss: 0.6747\n",
196 | "Epoch 10/100\n",
197 | "5/5 [==============================] - 0s 3ms/sample - loss: 0.6723\n",
198 | "Epoch 11/100\n",
199 | "5/5 [==============================] - 0s 4ms/sample - loss: 0.6699\n",
200 | "Epoch 12/100\n",
201 | "5/5 [==============================] - 0s 3ms/sample - loss: 0.6676\n",
202 | "Epoch 13/100\n",
203 | "5/5 [==============================] - 0s 4ms/sample - loss: 0.6651\n",
204 | "Epoch 14/100\n",
205 | "5/5 [==============================] - 0s 3ms/sample - loss: 0.6627\n",
206 | "Epoch 15/100\n",
207 | "5/5 [==============================] - 0s 3ms/sample - loss: 0.6604\n",
208 | "Epoch 16/100\n",
209 | "5/5 [==============================] - 0s 4ms/sample - loss: 0.6582\n",
210 | "Epoch 17/100\n",
211 | "5/5 [==============================] - 0s 3ms/sample - loss: 0.6560\n",
212 | "Epoch 18/100\n",
213 | "5/5 [==============================] - 0s 3ms/sample - loss: 0.6538\n",
214 | "Epoch 19/100\n",
215 | "5/5 [==============================] - 0s 4ms/sample - loss: 0.6515\n",
216 | "Epoch 20/100\n",
217 | "5/5 [==============================] - 0s 3ms/sample - loss: 0.6492\n",
218 | "Epoch 21/100\n",
219 | "5/5 [==============================] - 0s 4ms/sample - loss: 0.6471\n",
220 | "Epoch 22/100\n",
221 | "5/5 [==============================] - 0s 4ms/sample - loss: 0.6450\n",
222 | "Epoch 23/100\n",
223 | "5/5 [==============================] - 0s 6ms/sample - loss: 0.6428\n",
224 | "Epoch 24/100\n",
225 | "5/5 [==============================] - 0s 4ms/sample - loss: 0.6406\n",
226 | "Epoch 25/100\n",
227 | "5/5 [==============================] - 0s 3ms/sample - loss: 0.6384\n",
228 | "Epoch 26/100\n",
229 | "5/5 [==============================] - 0s 4ms/sample - loss: 0.6364\n",
230 | "Epoch 27/100\n",
231 | "5/5 [==============================] - 0s 3ms/sample - loss: 0.6343\n",
232 | "Epoch 28/100\n",
233 | "5/5 [==============================] - 0s 3ms/sample - loss: 0.6321\n",
234 | "Epoch 29/100\n",
235 | "5/5 [==============================] - 0s 3ms/sample - loss: 0.6301\n",
236 | "Epoch 30/100\n",
237 | "5/5 [==============================] - 0s 3ms/sample - loss: 0.6282\n",
238 | "Epoch 31/100\n",
239 | "5/5 [==============================] - 0s 3ms/sample - loss: 0.6263\n",
240 | "Epoch 32/100\n",
241 | "5/5 [==============================] - 0s 3ms/sample - loss: 0.6244\n",
242 | "Epoch 33/100\n",
243 | "5/5 [==============================] - 0s 4ms/sample - loss: 0.6226\n",
244 | "Epoch 34/100\n",
245 | "5/5 [==============================] - 0s 4ms/sample - loss: 0.6207\n",
246 | "Epoch 35/100\n",
247 | "5/5 [==============================] - 0s 3ms/sample - loss: 0.6190\n",
248 | "Epoch 36/100\n",
249 | "5/5 [==============================] - 0s 4ms/sample - loss: 0.6174\n",
250 | "Epoch 37/100\n",
251 | "5/5 [==============================] - 0s 3ms/sample - loss: 0.6158\n",
252 | "Epoch 38/100\n",
253 | "5/5 [==============================] - 0s 3ms/sample - loss: 0.6143\n",
254 | "Epoch 39/100\n",
255 | "5/5 [==============================] - 0s 4ms/sample - loss: 0.6129\n",
256 | "Epoch 40/100\n",
257 | "5/5 [==============================] - 0s 4ms/sample - loss: 0.6113\n",
258 | "Epoch 41/100\n",
259 | "5/5 [==============================] - 0s 3ms/sample - loss: 0.6100\n",
260 | "Epoch 42/100\n",
261 | "5/5 [==============================] - 0s 4ms/sample - loss: 0.6086\n",
262 | "Epoch 43/100\n",
263 | "5/5 [==============================] - 0s 3ms/sample - loss: 0.6073\n",
264 | "Epoch 44/100\n",
265 | "5/5 [==============================] - 0s 3ms/sample - loss: 0.6060\n",
266 | "Epoch 45/100\n",
267 | "5/5 [==============================] - 0s 4ms/sample - loss: 0.6048\n",
268 | "Epoch 46/100\n",
269 | "5/5 [==============================] - 0s 4ms/sample - loss: 0.6037\n",
270 | "Epoch 47/100\n",
271 | "5/5 [==============================] - 0s 3ms/sample - loss: 0.6025\n",
272 | "Epoch 48/100\n",
273 | "5/5 [==============================] - 0s 4ms/sample - loss: 0.6013\n",
274 | "Epoch 49/100\n",
275 | "5/5 [==============================] - 0s 3ms/sample - loss: 0.6002\n",
276 | "Epoch 50/100\n",
277 | "5/5 [==============================] - 0s 3ms/sample - loss: 0.5990\n",
278 | "Epoch 51/100\n",
279 | "5/5 [==============================] - 0s 3ms/sample - loss: 0.5980\n",
280 | "Epoch 52/100\n",
281 | "5/5 [==============================] - 0s 4ms/sample - loss: 0.5971\n",
282 | "Epoch 53/100\n",
283 | "5/5 [==============================] - 0s 4ms/sample - loss: 0.5961\n",
284 | "Epoch 54/100\n",
285 | "5/5 [==============================] - 0s 3ms/sample - loss: 0.5951\n",
286 | "Epoch 55/100\n",
287 | "5/5 [==============================] - 0s 3ms/sample - loss: 0.5943\n",
288 | "Epoch 56/100\n",
289 | "5/5 [==============================] - 0s 4ms/sample - loss: 0.5934\n",
290 | "Epoch 57/100\n",
291 | "5/5 [==============================] - 0s 3ms/sample - loss: 0.5925\n",
292 | "Epoch 58/100\n",
293 | "5/5 [==============================] - 0s 4ms/sample - loss: 0.5917\n",
294 | "Epoch 59/100\n",
295 | "5/5 [==============================] - 0s 3ms/sample - loss: 0.5910\n",
296 | "Epoch 60/100\n",
297 | "5/5 [==============================] - 0s 3ms/sample - loss: 0.5901\n",
298 | "Epoch 61/100\n",
299 | "5/5 [==============================] - 0s 4ms/sample - loss: 0.5895\n",
300 | "Epoch 62/100\n",
301 | "5/5 [==============================] - 0s 3ms/sample - loss: 0.5888\n",
302 | "Epoch 63/100\n",
303 | "5/5 [==============================] - 0s 3ms/sample - loss: 0.5882\n",
304 | "Epoch 64/100\n",
305 | "5/5 [==============================] - 0s 5ms/sample - loss: 0.5876\n",
306 | "Epoch 65/100\n",
307 | "5/5 [==============================] - 0s 5ms/sample - loss: 0.5870\n",
308 | "Epoch 66/100\n",
309 | "5/5 [==============================] - 0s 4ms/sample - loss: 0.5864\n",
310 | "Epoch 67/100\n",
311 | "5/5 [==============================] - 0s 5ms/sample - loss: 0.5858\n",
312 | "Epoch 68/100\n",
313 | "5/5 [==============================] - 0s 4ms/sample - loss: 0.5854\n",
314 | "Epoch 69/100\n",
315 | "5/5 [==============================] - 0s 4ms/sample - loss: 0.5849\n",
316 | "Epoch 70/100\n",
317 | "5/5 [==============================] - 0s 3ms/sample - loss: 0.5845\n",
318 | "Epoch 71/100\n",
319 | "5/5 [==============================] - 0s 3ms/sample - loss: 0.5840\n",
320 | "Epoch 72/100\n",
321 | "5/5 [==============================] - 0s 5ms/sample - loss: 0.5836\n",
322 | "Epoch 73/100\n",
323 | "5/5 [==============================] - 0s 5ms/sample - loss: 0.5831\n",
324 | "Epoch 74/100\n",
325 | "5/5 [==============================] - 0s 5ms/sample - loss: 0.5827\n",
326 | "Epoch 75/100\n",
327 | "5/5 [==============================] - 0s 4ms/sample - loss: 0.5824\n",
328 | "Epoch 76/100\n",
329 | "5/5 [==============================] - 0s 4ms/sample - loss: 0.5820\n",
330 | "Epoch 77/100\n",
331 | "5/5 [==============================] - 0s 4ms/sample - loss: 0.5816\n",
332 | "Epoch 78/100\n",
333 | "5/5 [==============================] - 0s 4ms/sample - loss: 0.5813\n",
334 | "Epoch 79/100\n",
335 | "5/5 [==============================] - 0s 4ms/sample - loss: 0.5809\n",
336 | "Epoch 80/100\n",
337 | "5/5 [==============================] - 0s 3ms/sample - loss: 0.5807\n",
338 | "Epoch 81/100\n",
339 | "5/5 [==============================] - 0s 3ms/sample - loss: 0.5803\n",
340 | "Epoch 82/100\n",
341 | "5/5 [==============================] - 0s 4ms/sample - loss: 0.5801\n",
342 | "Epoch 83/100\n",
343 | "5/5 [==============================] - 0s 4ms/sample - loss: 0.5797\n",
344 | "Epoch 84/100\n",
345 | "5/5 [==============================] - 0s 3ms/sample - loss: 0.5794\n",
346 | "Epoch 85/100\n",
347 | "5/5 [==============================] - 0s 3ms/sample - loss: 0.5792\n",
348 | "Epoch 86/100\n",
349 | "5/5 [==============================] - 0s 3ms/sample - loss: 0.5789\n",
350 | "Epoch 87/100\n",
351 | "5/5 [==============================] - 0s 3ms/sample - loss: 0.5787\n",
352 | "Epoch 88/100\n",
353 | "5/5 [==============================] - 0s 4ms/sample - loss: 0.5784\n",
354 | "Epoch 89/100\n",
355 | "5/5 [==============================] - 0s 4ms/sample - loss: 0.5781\n",
356 | "Epoch 90/100\n",
357 | "5/5 [==============================] - 0s 5ms/sample - loss: 0.5778\n",
358 | "Epoch 91/100\n",
359 | "5/5 [==============================] - 0s 4ms/sample - loss: 0.5776\n",
360 | "Epoch 92/100\n",
361 | "5/5 [==============================] - 0s 4ms/sample - loss: 0.5774\n",
362 | "Epoch 93/100\n",
363 | "5/5 [==============================] - 0s 4ms/sample - loss: 0.5772\n",
364 | "Epoch 94/100\n",
365 | "5/5 [==============================] - 0s 4ms/sample - loss: 0.5769\n",
366 | "Epoch 95/100\n",
367 | "5/5 [==============================] - 0s 4ms/sample - loss: 0.5767\n",
368 | "Epoch 96/100\n",
369 | "5/5 [==============================] - 0s 4ms/sample - loss: 0.5765\n",
370 | "Epoch 97/100\n",
371 | "5/5 [==============================] - 0s 4ms/sample - loss: 0.5763\n",
372 | "Epoch 98/100\n",
373 | "5/5 [==============================] - 0s 5ms/sample - loss: 0.5761\n",
374 | "Epoch 99/100\n",
375 | "5/5 [==============================] - 0s 5ms/sample - loss: 0.5758\n",
376 | "Epoch 100/100\n",
377 | "5/5 [==============================] - 0s 4ms/sample - loss: 0.5755\n"
378 | ],
379 | "name": "stdout"
380 | },
381 | {
382 | "output_type": "execute_result",
383 | "data": {
384 | "text/plain": [
385 | ""
386 | ]
387 | },
388 | "metadata": {
389 | "tags": []
390 | },
391 | "execution_count": 444
392 | }
393 | ]
394 | },
395 | {
396 | "cell_type": "code",
397 | "metadata": {
398 | "id": "faQEWPbSxnQ7",
399 | "colab_type": "code",
400 | "outputId": "9f2848c0-396f-41c2-c643-6f4393f2471e",
401 | "colab": {
402 | "base_uri": "https://localhost:8080/",
403 | "height": 34
404 | }
405 | },
406 | "source": [
407 | "np.array(np.transpose(labels)).shape"
408 | ],
409 | "execution_count": 0,
410 | "outputs": [
411 | {
412 | "output_type": "execute_result",
413 | "data": {
414 | "text/plain": [
415 | "(12, 5)"
416 | ]
417 | },
418 | "metadata": {
419 | "tags": []
420 | },
421 | "execution_count": 409
422 | }
423 | ]
424 | },
425 | {
426 | "cell_type": "code",
427 | "metadata": {
428 | "id": "f4a0gK6MOpu3",
429 | "colab_type": "code",
430 | "outputId": "90f3451d-3ef0-4059-a592-a3f823ee6b1b",
431 | "colab": {
432 | "base_uri": "https://localhost:8080/",
433 | "height": 34
434 | }
435 | },
436 | "source": [
437 | "np.array(labels).shape"
438 | ],
439 | "execution_count": 0,
440 | "outputs": [
441 | {
442 | "output_type": "execute_result",
443 | "data": {
444 | "text/plain": [
445 | "(5, 12)"
446 | ]
447 | },
448 | "metadata": {
449 | "tags": []
450 | },
451 | "execution_count": 400
452 | }
453 | ]
454 | },
455 | {
456 | "cell_type": "code",
457 | "metadata": {
458 | "id": "zVrmlR4wdNa2",
459 | "colab_type": "code",
460 | "colab": {}
461 | },
462 | "source": [
463 | "# target = [1,2,3,5,6]\n",
464 | "inputs = (np.array(target[0], dtype=np.int32), np.array(data[0], dtype=np.int32))"
465 | ],
466 | "execution_count": 0,
467 | "outputs": []
468 | },
469 | {
470 | "cell_type": "code",
471 | "metadata": {
472 | "id": "aqEk4_HadQ9k",
473 | "colab_type": "code",
474 | "outputId": "853fb5ef-d3d7-4373-ee45-342815832b5a",
475 | "colab": {
476 | "base_uri": "https://localhost:8080/",
477 | "height": 34
478 | }
479 | },
480 | "source": [
481 | "target_input, context_inputs = inputs\n",
482 | "target_input = keras.layers.InputLayer(input_shape=[1,])(target_input)\n",
483 | "context_inputs = keras.layers.InputLayer(input_shape=[2,])(context_inputs)\n",
484 | "context_inputs.shape, target_input.shape"
485 | ],
486 | "execution_count": 0,
487 | "outputs": [
488 | {
489 | "output_type": "execute_result",
490 | "data": {
491 | "text/plain": [
492 | "(TensorShape([12, 2]), TensorShape([12]))"
493 | ]
494 | },
495 | "metadata": {
496 | "tags": []
497 | },
498 | "execution_count": 424
499 | }
500 | ]
501 | },
502 | {
503 | "cell_type": "code",
504 | "metadata": {
505 | "id": "QXjAF0AadaQU",
506 | "colab_type": "code",
507 | "outputId": "26b68b54-bc75-41d4-9c3f-ae765abb72f5",
508 | "colab": {
509 | "base_uri": "https://localhost:8080/",
510 | "height": 34
511 | }
512 | },
513 | "source": [
514 | "embedding_layer = Embedding(7, 768, input_length=2)\n",
515 | "target = embedding_layer(target_input)\n",
516 | "context = embedding_layer(context_inputs)\n",
517 | "context.shape, target.shape"
518 | ],
519 | "execution_count": 0,
520 | "outputs": [
521 | {
522 | "output_type": "execute_result",
523 | "data": {
524 | "text/plain": [
525 | "(TensorShape([12, 2, 768]), TensorShape([12, 768]))"
526 | ]
527 | },
528 | "metadata": {
529 | "tags": []
530 | },
531 | "execution_count": 432
532 | }
533 | ]
534 | },
535 | {
536 | "cell_type": "code",
537 | "metadata": {
538 | "id": "t71lMbdGd-xC",
539 | "colab_type": "code",
540 | "outputId": "30b05b8b-aa74-4efc-9da0-b124bc58f8e4",
541 | "colab": {
542 | "base_uri": "https://localhost:8080/",
543 | "height": 34
544 | }
545 | },
546 | "source": [
547 | "cbow = Lambda(lambda x: K.mean(x, axis=[1]))\n",
548 | "context = cbow(context)\n",
549 | "context.shape"
550 | ],
551 | "execution_count": 0,
552 | "outputs": [
553 | {
554 | "output_type": "execute_result",
555 | "data": {
556 | "text/plain": [
557 | "TensorShape([12, 768])"
558 | ]
559 | },
560 | "metadata": {
561 | "tags": []
562 | },
563 | "execution_count": 433
564 | }
565 | ]
566 | },
567 | {
568 | "cell_type": "code",
569 | "metadata": {
570 | "id": "5S1CZHoJghvx",
571 | "colab_type": "code",
572 | "outputId": "e7119180-6791-4934-8585-1b77e5f3c048",
573 | "colab": {
574 | "base_uri": "https://localhost:8080/",
575 | "height": 34
576 | }
577 | },
578 | "source": [
579 | "dotted = dot([target, context], axes=1)\n",
580 | "dotted.shape"
581 | ],
582 | "execution_count": 0,
583 | "outputs": [
584 | {
585 | "output_type": "execute_result",
586 | "data": {
587 | "text/plain": [
588 | "TensorShape([12, 1])"
589 | ]
590 | },
591 | "metadata": {
592 | "tags": []
593 | },
594 | "execution_count": 435
595 | }
596 | ]
597 | },
598 | {
599 | "cell_type": "code",
600 | "metadata": {
601 | "id": "EZF62FmYiUFb",
602 | "colab_type": "code",
603 | "outputId": "a55d42a3-f680-4b77-ad42-33b6ad1c9f20",
604 | "colab": {
605 | "base_uri": "https://localhost:8080/",
606 | "height": 34
607 | }
608 | },
609 | "source": [
610 | "outvec = Dense(1, activation='sigmoid')\n",
611 | "done = outvec(dotted)\n",
612 | "# done.shape\n",
613 | "tf.convert_to_tensor(labels[0]).shape, tf.squeeze(done).shape"
614 | ],
615 | "execution_count": 0,
616 | "outputs": [
617 | {
618 | "output_type": "execute_result",
619 | "data": {
620 | "text/plain": [
621 | "(TensorShape([12]), TensorShape([12]))"
622 | ]
623 | },
624 | "metadata": {
625 | "tags": []
626 | },
627 | "execution_count": 437
628 | }
629 | ]
630 | },
631 | {
632 | "cell_type": "code",
633 | "metadata": {
634 | "id": "7MVubitzliSs",
635 | "colab_type": "code",
636 | "colab": {}
637 | },
638 | "source": [
639 | ""
640 | ],
641 | "execution_count": 0,
642 | "outputs": []
643 | }
644 | ]
645 | }
646 |
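The forward pass of `Word2CBOW.call` looks up the target id and the two ids of each skipgram couple, averages the two context vectors, takes a row-wise dot product with the target vector, and squashes the result into a probability. The NumPy sketch below mirrors that flow under stated assumptions: `cbow_scores` is a hypothetical helper (not part of the notebook), the embedding table is random and small (dimension 8 instead of 768), each paper contributes the 12 couples the dummy data produces, and a plain sigmoid stands in for the notebook's trainable `Dense(1, activation='sigmoid')`.

```python
import numpy as np


def cbow_scores(embeddings, target_ids, context_ids):
    """Hypothetical sketch of the Word2CBOW scoring step.

    embeddings:  (vocab_size, dim) lookup table, standing in for the Embedding layer
    target_ids:  (n,) int ids, one target paper per row
    context_ids: (n, 2) int ids, the couple returned by skipgrams for that row
    Returns an (n,) array of scores in (0, 1).
    """
    target_vecs = embeddings[target_ids]    # (n, dim)
    context_vecs = embeddings[context_ids]  # (n, 2, dim)
    # Average over the 2 context ids (axis -2). Averaging over axis -1 would
    # collapse the embedding dimension and leave an (n, 2) array instead.
    context_mean = context_vecs.mean(axis=-2)  # (n, dim)
    # Row-wise dot product, the role K.batch_dot(..., axes=-1) plays in the notebook.
    dotted = np.sum(target_vecs * context_mean, axis=-1)  # (n,)
    # Plain sigmoid (assumption); the notebook uses a learned Dense(1, sigmoid).
    return 1.0 / (1.0 + np.exp(-dotted))


# Toy usage: vocabulary of 7 ids (as in Embedding(7, ...)), 12 couples for one paper.
rng = np.random.default_rng(0)
emb = rng.normal(size=(7, 8))
targets = np.full(12, 1)
contexts = rng.integers(0, 7, size=(12, 2))
scores = cbow_scores(emb, targets, contexts)
print(scores.shape)  # (12,)
```

Averaging over the last axis instead would reduce each context id to a single number, and the resulting `(12, 2)` array could no longer be dotted with an embedding-width target vector, which is why the choice of averaging axis matters.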
--------------------------------------------------------------------------------
/notebooks/training/model.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "markdown",
5 | "metadata": {
6 | "colab_type": "text",
7 | "id": "view-in-github"
8 | },
9 | "source": [
10 | ""
11 | ]
12 | },
13 | {
14 | "cell_type": "code",
15 | "execution_count": 1,
16 | "metadata": {
17 | "colab": {},
18 | "colab_type": "code",
19 | "id": "QK0_IOe0LkEY"
20 | },
21 | "outputs": [],
22 | "source": [
23 | "# !pip install transformers cloud-tpu-client torch\n",
24 | "# !wget 'https://s3-us-west-2.amazonaws.com/ai2-s2-research/scibert/huggingface_pytorch/scibert_scivocab_uncased.tar'\n",
25 | "# !tar -xvf 'scibert_scivocab_uncased.tar'"
26 | ]
27 | },
28 | {
29 | "cell_type": "code",
30 | "execution_count": 2,
31 | "metadata": {
32 | "colab": {
33 | "base_uri": "https://localhost:8080/",
34 | "height": 34
35 | },
36 | "colab_type": "code",
37 | "id": "S_hHnz5fLiZ5",
38 | "outputId": "35daa83e-555f-48db-871a-c1b0098c3516"
39 | },
40 | "outputs": [
41 | {
42 | "name": "stdout",
43 | "output_type": "stream",
44 | "text": [
45 | "TensorFlow: 2.1.0-dev20191226\n"
46 | ]
47 | }
48 | ],
49 | "source": [
50 | "import os\n",
51 | "import tensorflow as tf\n",
52 | "from tensorflow.keras import backend as K\n",
53 | "from tensorflow.keras.layers import Lambda, Dense, Activation, Concatenate, Dropout\n",
54 | "from transformers import TFBertModel\n",
55 | "from time import time\n",
56 | "print('TensorFlow:', tf.__version__)"
57 | ]
58 | },
59 | {
60 | "cell_type": "code",
61 | "execution_count": 3,
62 | "metadata": {
63 | "colab": {
64 | "base_uri": "https://localhost:8080/",
65 | "height": 51
66 | },
67 | "colab_type": "code",
68 | "id": "-1Mu1rsDMHCF",
69 | "outputId": "d604ad71-9e00-433f-b6db-bb2e2119bfb3"
70 | },
71 | "outputs": [
72 | {
73 | "name": "stdout",
74 | "output_type": "stream",
75 | "text": [
76 | "Running on TPU ['192.168.19.2:8470']\n"
77 | ]
78 | },
79 | {
80 | "name": "stderr",
81 | "output_type": "stream",
82 | "text": [
83 | "INFO:absl:Entering into master device scope: /job:worker/replica:0/task:0/device:CPU:0\n"
84 | ]
85 | },
86 | {
87 | "name": "stdout",
88 | "output_type": "stream",
89 | "text": [
90 | "INFO:tensorflow:Initializing the TPU system: srihari-1-tpu\n"
91 | ]
92 | },
93 | {
94 | "name": "stderr",
95 | "output_type": "stream",
96 | "text": [
97 | "INFO:tensorflow:Initializing the TPU system: srihari-1-tpu\n"
98 | ]
99 | },
100 | {
101 | "name": "stdout",
102 | "output_type": "stream",
103 | "text": [
104 | "INFO:tensorflow:Clearing out eager caches\n"
105 | ]
106 | },
107 | {
108 | "name": "stderr",
109 | "output_type": "stream",
110 | "text": [
111 | "INFO:tensorflow:Clearing out eager caches\n"
112 | ]
113 | },
114 | {
115 | "name": "stdout",
116 | "output_type": "stream",
117 | "text": [
118 | "INFO:tensorflow:Finished initializing TPU system.\n"
119 | ]
120 | },
121 | {
122 | "name": "stderr",
123 | "output_type": "stream",
124 | "text": [
125 | "INFO:tensorflow:Finished initializing TPU system.\n"
126 | ]
127 | },
128 | {
129 | "name": "stdout",
130 | "output_type": "stream",
131 | "text": [
132 | "INFO:tensorflow:Found TPU system:\n"
133 | ]
134 | },
135 | {
136 | "name": "stderr",
137 | "output_type": "stream",
138 | "text": [
139 | "INFO:tensorflow:Found TPU system:\n"
140 | ]
141 | },
142 | {
143 | "name": "stdout",
144 | "output_type": "stream",
145 | "text": [
146 | "INFO:tensorflow:*** Num TPU Cores: 8\n"
147 | ]
148 | },
149 | {
150 | "name": "stderr",
151 | "output_type": "stream",
152 | "text": [
153 | "INFO:tensorflow:*** Num TPU Cores: 8\n"
154 | ]
155 | },
156 | {
157 | "name": "stdout",
158 | "output_type": "stream",
159 | "text": [
160 | "INFO:tensorflow:*** Num TPU Workers: 1\n"
161 | ]
162 | },
163 | {
164 | "name": "stderr",
165 | "output_type": "stream",
166 | "text": [
167 | "INFO:tensorflow:*** Num TPU Workers: 1\n"
168 | ]
169 | },
170 | {
171 | "name": "stdout",
172 | "output_type": "stream",
173 | "text": [
174 | "INFO:tensorflow:*** Num TPU Cores Per Worker: 8\n"
175 | ]
176 | },
177 | {
178 | "name": "stderr",
179 | "output_type": "stream",
180 | "text": [
181 | "INFO:tensorflow:*** Num TPU Cores Per Worker: 8\n"
182 | ]
183 | },
184 | {
185 | "name": "stdout",
186 | "output_type": "stream",
187 | "text": [
188 | "INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:localhost/replica:0/task:0/device:CPU:0, CPU, 0, 0)\n"
189 | ]
190 | },
191 | {
192 | "name": "stderr",
193 | "output_type": "stream",
194 | "text": [
195 | "INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:localhost/replica:0/task:0/device:CPU:0, CPU, 0, 0)\n"
196 | ]
197 | },
198 | {
199 | "name": "stdout",
200 | "output_type": "stream",
201 | "text": [
202 | "INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:localhost/replica:0/task:0/device:XLA_CPU:0, XLA_CPU, 0, 0)\n"
203 | ]
204 | },
205 | {
206 | "name": "stderr",
207 | "output_type": "stream",
208 | "text": [
209 | "INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:localhost/replica:0/task:0/device:XLA_CPU:0, XLA_CPU, 0, 0)\n"
210 | ]
211 | },
212 | {
213 | "name": "stdout",
214 | "output_type": "stream",
215 | "text": [
216 | "INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:CPU:0, CPU, 0, 0)\n"
217 | ]
218 | },
219 | {
220 | "name": "stderr",
221 | "output_type": "stream",
222 | "text": [
223 | "INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:CPU:0, CPU, 0, 0)\n"
224 | ]
225 | },
226 | {
227 | "name": "stdout",
228 | "output_type": "stream",
229 | "text": [
230 | "INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:TPU:0, TPU, 0, 0)\n"
231 | ]
232 | },
233 | {
234 | "name": "stderr",
235 | "output_type": "stream",
236 | "text": [
237 | "INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:TPU:0, TPU, 0, 0)\n"
238 | ]
239 | },
240 | {
241 | "name": "stdout",
242 | "output_type": "stream",
243 | "text": [
244 | "INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:TPU:1, TPU, 0, 0)\n"
245 | ]
246 | },
247 | {
248 | "name": "stderr",
249 | "output_type": "stream",
250 | "text": [
251 | "INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:TPU:1, TPU, 0, 0)\n"
252 | ]
253 | },
254 | {
255 | "name": "stdout",
256 | "output_type": "stream",
257 | "text": [
258 | "INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:TPU:2, TPU, 0, 0)\n"
259 | ]
260 | },
261 | {
262 | "name": "stderr",
263 | "output_type": "stream",
264 | "text": [
265 | "INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:TPU:2, TPU, 0, 0)\n"
266 | ]
267 | },
268 | {
269 | "name": "stdout",
270 | "output_type": "stream",
271 | "text": [
272 | "INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:TPU:3, TPU, 0, 0)\n"
273 | ]
274 | },
275 | {
276 | "name": "stderr",
277 | "output_type": "stream",
278 | "text": [
279 | "INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:TPU:3, TPU, 0, 0)\n"
280 | ]
281 | },
282 | {
283 | "name": "stdout",
284 | "output_type": "stream",
285 | "text": [
286 | "INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:TPU:4, TPU, 0, 0)\n"
287 | ]
288 | },
289 | {
290 | "name": "stderr",
291 | "output_type": "stream",
292 | "text": [
293 | "INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:TPU:4, TPU, 0, 0)\n"
294 | ]
295 | },
296 | {
297 | "name": "stdout",
298 | "output_type": "stream",
299 | "text": [
300 | "INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:TPU:5, TPU, 0, 0)\n"
301 | ]
302 | },
303 | {
304 | "name": "stderr",
305 | "output_type": "stream",
306 | "text": [
307 | "INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:TPU:5, TPU, 0, 0)\n"
308 | ]
309 | },
310 | {
311 | "name": "stdout",
312 | "output_type": "stream",
313 | "text": [
314 | "INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:TPU:6, TPU, 0, 0)\n"
315 | ]
316 | },
317 | {
318 | "name": "stderr",
319 | "output_type": "stream",
320 | "text": [
321 | "INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:TPU:6, TPU, 0, 0)\n"
322 | ]
323 | },
324 | {
325 | "name": "stdout",
326 | "output_type": "stream",
327 | "text": [
328 | "INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:TPU:7, TPU, 0, 0)\n"
329 | ]
330 | },
331 | {
332 | "name": "stderr",
333 | "output_type": "stream",
334 | "text": [
335 | "INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:TPU:7, TPU, 0, 0)\n"
336 | ]
337 | },
338 | {
339 | "name": "stdout",
340 | "output_type": "stream",
341 | "text": [
342 | "INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:TPU_SYSTEM:0, TPU_SYSTEM, 0, 0)\n"
343 | ]
344 | },
345 | {
346 | "name": "stderr",
347 | "output_type": "stream",
348 | "text": [
349 | "INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:TPU_SYSTEM:0, TPU_SYSTEM, 0, 0)\n"
350 | ]
351 | },
352 | {
353 | "name": "stdout",
354 | "output_type": "stream",
355 | "text": [
356 | "INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:XLA_CPU:0, XLA_CPU, 0, 0)\n"
357 | ]
358 | },
359 | {
360 | "name": "stderr",
361 | "output_type": "stream",
362 | "text": [
363 | "INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:XLA_CPU:0, XLA_CPU, 0, 0)\n"
364 | ]
365 | },
366 | {
367 | "name": "stdout",
368 | "output_type": "stream",
369 | "text": [
370 | "REPLICAS: 8\n"
371 | ]
372 | }
373 | ],
374 | "source": [
375 | "try:\n",
376 | " tpu = tf.distribute.cluster_resolver.TPUClusterResolver('srihari-1-tpu') # TPU detection\n",
377 | " print('Running on TPU ', tpu.cluster_spec().as_dict()['worker'])\n",
378 | "except ValueError:\n",
379 | " tpu = None\n",
380 | "\n",
381 | "if tpu:\n",
382 | " tf.config.experimental_connect_to_cluster(tpu)\n",
383 | " tf.tpu.experimental.initialize_tpu_system(tpu)\n",
384 | " strategy = tf.distribute.experimental.TPUStrategy(tpu)\n",
385 | "else:\n",
386 | " strategy = tf.distribute.MirroredStrategy()\n",
387 | "\n",
388 | "print(\"REPLICAS: \", strategy.num_replicas_in_sync)"
389 | ]
390 | },
391 | {
392 | "cell_type": "code",
393 | "execution_count": 4,
394 | "metadata": {
395 | "colab": {},
396 | "colab_type": "code",
397 | "id": "dYscLlW_7wIR"
398 | },
399 | "outputs": [
400 | {
401 | "name": "stdout",
402 | "output_type": "stream",
403 | "text": [
404 | "Batch Size: 256\n"
405 | ]
406 | }
407 | ],
408 | "source": [
409 | "batch_size = 32 * strategy.num_replicas_in_sync\n",
410 | "embedding_dim = 512\n",
411 | "autotune = tf.data.experimental.AUTOTUNE\n",
412 | "\n",
413 | "train_steps = int(1262996 * 0.8) // batch_size\n",
414 | "val_steps = train_steps // 10\n",
415 | "epochs = 100\n",
416 | "print('Batch Size:', batch_size)"
417 | ]
418 | },
419 | {
420 | "cell_type": "code",
421 | "execution_count": 5,
422 | "metadata": {
423 | "colab": {},
424 | "colab_type": "code",
425 | "id": "lEohYPP7QAK1"
426 | },
427 | "outputs": [
428 | {
429 | "name": "stdout",
430 | "output_type": "stream",
431 | "text": [
432 | "Logging in: gs://tfworld/model_files_2/logs_1577426847.0237095\n"
433 | ]
434 | }
435 | ],
436 | "source": [
437 | "base_dir = 'gs://tfworld'\n",
438 | "model_dir = os.path.join(base_dir, 'model_files_2')\n",
439 | "tensorboard_dir = os.path.join(model_dir, 'logs_'+str(time()))\n",
440 | "tfrecords_pattern_train = os.path.join(base_dir, 'tfrecords', 'train*')\n",
441 | "tfrecords_pattern_val = os.path.join(base_dir, 'tfrecords', 'eval*')\n",
442 | "\n",
443 | "print('Logging in: ', tensorboard_dir)"
444 | ]
445 | },
446 | {
447 | "cell_type": "code",
448 | "execution_count": 6,
449 | "metadata": {
450 | "colab": {},
451 | "colab_type": "code",
452 | "id": "iYONEmj6LiZ9"
453 | },
454 | "outputs": [],
455 | "source": [
456 | "features = {\n",
457 | " 'title':tf.io.FixedLenFeature([512], dtype=tf.int64),\n",
458 | " 'citation':tf.io.FixedLenFeature([512], dtype=tf.float32),\n",
459 | " }\n",
460 | "\n",
461 | "def parse_example(example_proto):\n",
462 | " parsed_example = tf.io.parse_single_example(example_proto, features)\n",
463 | " title = parsed_example['title']\n",
464 | " citation = parsed_example['citation']\n",
465 | " \n",
466 | " title = tf.cast(title, dtype=tf.int32)\n",
467 | " citation = tf.cast(citation, dtype=tf.float32)\n",
468 | " return (title, citation), tf.constant([1.0], dtype=tf.float32)"
469 | ]
470 | },
471 | {
472 | "cell_type": "code",
473 | "execution_count": 7,
474 | "metadata": {
475 | "colab": {
476 | "base_uri": "https://localhost:8080/",
477 | "height": 68
478 | },
479 | "colab_type": "code",
480 | "id": "Sg8AtlNpQAK7",
481 | "outputId": "49122387-74be-4cbc-8d1b-244a6f8bfff0"
482 | },
483 | "outputs": [
484 | {
485 | "data": {
486 | "text/plain": [
487 | "(((TensorSpec(shape=(256, 512), dtype=tf.int32, name=None),\n",
488 | " TensorSpec(shape=(256, 512), dtype=tf.float32, name=None)),\n",
489 | " TensorSpec(shape=(256, 1), dtype=tf.float32, name=None)),\n",
490 | " ((TensorSpec(shape=(256, 512), dtype=tf.int32, name=None),\n",
491 | " TensorSpec(shape=(256, 512), dtype=tf.float32, name=None)),\n",
492 | " TensorSpec(shape=(256, 1), dtype=tf.float32, name=None)))"
493 | ]
494 | },
495 | "execution_count": 7,
496 | "metadata": {},
497 | "output_type": "execute_result"
498 | }
499 | ],
500 | "source": [
501 | "with strategy.scope():\n",
502 | " train_files = tf.data.Dataset.list_files(tfrecords_pattern_train)\n",
503 | " train_dataset = train_files.interleave(tf.data.TFRecordDataset,\n",
504 | " cycle_length=32,\n",
505 | " block_length=4,\n",
506 | " num_parallel_calls=autotune)\n",
507 | " train_dataset = train_dataset.map(parse_example, num_parallel_calls=autotune)\n",
508 | " train_dataset = train_dataset.batch(batch_size, drop_remainder=True)\n",
509 | " train_dataset = train_dataset.repeat()\n",
510 | " train_dataset = train_dataset.prefetch(autotune)\n",
511 | "\n",
512 | " val_files = tf.data.Dataset.list_files(tfrecords_pattern_val)\n",
513 | " val_dataset = val_files.interleave(tf.data.TFRecordDataset,\n",
514 | " cycle_length=32,\n",
515 | " block_length=4,\n",
516 | " num_parallel_calls=autotune)\n",
517 | " val_dataset = val_dataset.map(parse_example, num_parallel_calls=autotune)\n",
518 | " val_dataset = val_dataset.batch(batch_size, drop_remainder=True)\n",
519 | " val_dataset = val_dataset.repeat()\n",
520 | " val_dataset = val_dataset.prefetch(autotune)\n",
521 | "\n",
522 | "tf.data.experimental.get_structure(train_dataset), tf.data.experimental.get_structure(val_dataset)"
523 | ]
524 | },
525 | {
526 | "cell_type": "code",
527 | "execution_count": 8,
528 | "metadata": {
529 | "colab": {},
530 | "colab_type": "code",
531 | "id": "TidDJ55-LiZ_"
532 | },
533 | "outputs": [],
534 | "source": [
535 | "@tf.function\n",
536 | "def loss_fn(_, probs):\n",
537 | " '''\n",
538 |     "    1. Every sample is its own positive, and the rest of the\n",
539 |     "       elements in the batch are its negatives.\n",
540 |     "    2. Each TPU core receives 1/8 of the global batch, so the\n",
541 |     "       per-replica batch size is read dynamically with tf.shape.\n",
542 |     "    3. The dataset yields dummy labels only so that loss_fn matches\n",
543 |     "       the Keras loss signature; the actual labels are computed inside\n",
544 |     "       this function.\n",
545 | " 4. `probs` lie in [0, 1] and are to be treated as probabilities.\n",
546 | " '''\n",
547 | " bs = tf.shape(probs)[0] \n",
548 | " labels = tf.eye(bs, bs)\n",
549 | " return tf.losses.categorical_crossentropy(labels, probs, from_logits=False)\n",
550 | " \n",
551 | "def create_model(drop_out):\n",
552 | " textIds = tf.keras.Input(shape=(512,), dtype=tf.int32) # from bert tokenizer\n",
553 | " citation = tf.keras.Input(shape=(512,)) # normalized word2vec outputs\n",
554 | " \n",
555 | " bert_model = TFBertModel.from_pretrained('scibert_scivocab_uncased', from_pt=True)\n",
556 | " \n",
557 | " textOut = bert_model(textIds)\n",
558 | " textOutMean = tf.reduce_mean(textOut[0], axis=1)\n",
559 | " textOutSim = Dense(units=embedding_dim, activation='tanh', name='DenseTitle')(textOutMean)\n",
560 | " textOutSim = Dropout(drop_out)(textOutSim)\n",
561 | " \n",
562 | " citationSim = Dense(units=embedding_dim, activation='tanh', name='DenseCitation')(citation)\n",
563 | " citationSim = Dropout(drop_out)(citationSim)\n",
564 | "\n",
565 |     "    # Dot product for every (title, citation) pair in the batch\n",
566 | " dotProduct = tf.reduce_sum(tf.multiply(textOutSim[:, None, :], citationSim), axis=-1)\n",
567 | " \n",
568 | " # Softmax to make sure each row has sum == 1.0\n",
569 | " probs = tf.nn.softmax(dotProduct, axis=-1)\n",
570 | "\n",
571 | " model = tf.keras.Model(inputs=[textIds, citation], outputs=[probs])\n",
572 | " return model"
573 | ]
574 | },
575 | {
576 | "cell_type": "code",
577 | "execution_count": 9,
578 | "metadata": {
579 | "colab": {},
580 | "colab_type": "code",
581 | "id": "o8MXgYFSLiaB"
582 | },
583 | "outputs": [],
584 | "source": [
585 | "with strategy.scope():\n",
586 | " model = create_model(drop_out=0.20)\n",
587 | " model.compile(loss=loss_fn,\n",
588 | " optimizer=tf.keras.optimizers.Adam(learning_rate=3e-5))"
589 | ]
590 | },
591 | {
592 | "cell_type": "code",
593 | "execution_count": 10,
594 | "metadata": {
595 | "colab": {},
596 | "colab_type": "code",
597 | "id": "D3f-F4QXDmL7"
598 | },
599 | "outputs": [],
600 | "source": [
601 | "callbacks = [tf.keras.callbacks.TensorBoard(log_dir=tensorboard_dir, update_freq='epoch'), \n",
602 | " tf.keras.callbacks.ModelCheckpoint(filepath=model_dir + '/epoch_{epoch:02d}_{loss:.2f}',\n",
603 | " monitor='loss',\n",
604 | " verbose=1,\n",
605 | " save_weights_only=True,\n",
606 | " save_freq='epoch')\n",
607 | " ]"
608 | ]
609 | },
610 | {
611 | "cell_type": "code",
612 | "execution_count": null,
613 | "metadata": {
614 | "colab": {
615 | "base_uri": "https://localhost:8080/",
616 | "height": 156
617 | },
618 | "colab_type": "code",
619 | "id": "p6S3kWXPLiaF",
620 | "outputId": "576ab892-4002-4424-b9d1-d13055c40b2d"
621 | },
622 | "outputs": [
623 | {
624 | "name": "stdout",
625 | "output_type": "stream",
626 | "text": [
627 | "Train for 3946.0 steps, validate for 394.0 steps\n",
628 | "WARNING:tensorflow:Model failed to serialize as JSON. Ignoring... \n"
629 | ]
630 | },
631 | {
632 | "name": "stderr",
633 | "output_type": "stream",
634 | "text": [
635 | "WARNING:tensorflow:Model failed to serialize as JSON. Ignoring... \n"
636 | ]
637 | },
638 | {
639 | "name": "stdout",
640 | "output_type": "stream",
641 | "text": [
642 | "Epoch 1/100\n",
643 | "WARNING:tensorflow:Gradients do not exist for variables ['tf_bert_model/bert/pooler/dense/kernel:0', 'tf_bert_model/bert/pooler/dense/bias:0'] when minimizing the loss.\n"
644 | ]
645 | },
646 | {
647 | "name": "stderr",
648 | "output_type": "stream",
649 | "text": [
650 | "WARNING:tensorflow:Gradients do not exist for variables ['tf_bert_model/bert/pooler/dense/kernel:0', 'tf_bert_model/bert/pooler/dense/bias:0'] when minimizing the loss.\n"
651 | ]
652 | },
653 | {
654 | "name": "stdout",
655 | "output_type": "stream",
656 | "text": [
657 | "WARNING:tensorflow:Gradients do not exist for variables ['tf_bert_model/bert/pooler/dense/kernel:0', 'tf_bert_model/bert/pooler/dense/bias:0'] when minimizing the loss.\n"
658 | ]
659 | },
660 | {
661 | "name": "stderr",
662 | "output_type": "stream",
663 | "text": [
664 | "WARNING:tensorflow:Gradients do not exist for variables ['tf_bert_model/bert/pooler/dense/kernel:0', 'tf_bert_model/bert/pooler/dense/bias:0'] when minimizing the loss.\n"
665 | ]
666 | },
667 | {
668 | "name": "stdout",
669 | "output_type": "stream",
670 | "text": [
671 | " 1/3946 [..............................] - ETA: 74:42:48 - loss: 3.7574WARNING:tensorflow:Method (on_train_batch_end) is slow compared to the batch update (1.095701). Check your callbacks.\n"
672 | ]
673 | },
674 | {
675 | "name": "stderr",
676 | "output_type": "stream",
677 | "text": [
678 | "WARNING:tensorflow:Method (on_train_batch_end) is slow compared to the batch update (1.095701). Check your callbacks.\n"
679 | ]
680 | },
681 | {
682 | "name": "stdout",
683 | "output_type": "stream",
684 | "text": [
685 | "3945/3946 [============================>.] - ETA: 0s - loss: 3.0831\n",
686 | "Epoch 00001: saving model to gs://tfworld/model_files_2/epoch_01_3.08\n",
687 | "3946/3946 [==============================] - 2403s 609ms/step - loss: 3.0830 - val_loss: 2.7980\n",
688 | "Epoch 2/100\n",
689 | "3945/3946 [============================>.] - ETA: 0s - loss: 2.7259\n",
690 | "Epoch 00002: saving model to gs://tfworld/model_files_2/epoch_02_2.73\n",
691 | "3946/3946 [==============================] - 2362s 599ms/step - loss: 2.7259 - val_loss: 2.5801\n",
692 | "Epoch 3/100\n",
693 | "3945/3946 [============================>.] - ETA: 0s - loss: 2.4969\n",
694 | "Epoch 00003: saving model to gs://tfworld/model_files_2/epoch_03_2.50\n",
695 | "3946/3946 [==============================] - 2363s 599ms/step - loss: 2.4969 - val_loss: 2.4444\n",
696 | "Epoch 4/100\n",
697 | "3945/3946 [============================>.] - ETA: 0s - loss: 2.3048\n",
698 | "Epoch 00004: saving model to gs://tfworld/model_files_2/epoch_04_2.30\n",
699 | "3946/3946 [==============================] - 2359s 598ms/step - loss: 2.3047 - val_loss: 2.3618\n",
700 | "Epoch 5/100\n",
701 | "3945/3946 [============================>.] - ETA: 0s - loss: 2.1282\n",
702 | "Epoch 00005: saving model to gs://tfworld/model_files_2/epoch_05_2.13\n",
703 | "3946/3946 [==============================] - 2362s 599ms/step - loss: 2.1281 - val_loss: 2.3036\n",
704 | "Epoch 6/100\n",
705 | "3945/3946 [============================>.] - ETA: 0s - loss: 1.9581\n",
706 | "Epoch 00006: saving model to gs://tfworld/model_files_2/epoch_06_1.96\n",
707 | "3946/3946 [==============================] - 2382s 604ms/step - loss: 1.9580 - val_loss: 2.2698\n",
708 | "Epoch 7/100\n",
709 | "3945/3946 [============================>.] - ETA: 0s - loss: 1.7960"
710 | ]
711 | }
712 | ],
713 | "source": [
714 | "model.fit(train_dataset,\n",
715 | " epochs=epochs,\n",
716 | " steps_per_epoch=train_steps,\n",
717 | " validation_data=val_dataset,\n",
718 | " validation_steps=val_steps,\n",
719 | " validation_freq=1,\n",
720 | " callbacks=callbacks)"
721 | ]
722 | },
723 | {
724 | "cell_type": "code",
725 | "execution_count": null,
726 | "metadata": {
727 | "colab": {},
728 | "colab_type": "code",
729 | "id": "B8DkqlzfDmL_"
730 | },
731 | "outputs": [],
732 | "source": []
733 | }
734 | ],
735 | "metadata": {
736 | "accelerator": "GPU",
737 | "colab": {
738 | "include_colab_link": true,
739 | "name": "model_debug.ipynb",
740 | "provenance": []
741 | },
742 | "kernelspec": {
743 | "display_name": "Python 3",
744 | "language": "python",
745 | "name": "python3"
746 | },
747 | "language_info": {
748 | "codemirror_mode": {
749 | "name": "ipython",
750 | "version": 3
751 | },
752 | "file_extension": ".py",
753 | "mimetype": "text/x-python",
754 | "name": "python",
755 | "nbconvert_exporter": "python",
756 | "pygments_lexer": "ipython3",
757 | "version": "3.5.3"
758 | }
759 | },
760 | "nbformat": 4,
761 | "nbformat_minor": 4
762 | }
763 |
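The recorded `fit` log reads "Train for 3946.0 steps, validate for 394.0 steps". The `.0` comes from evaluating `1262996 * 0.8 // batch_size`: floor division on a float operand returns a float. The sketch below reproduces that arithmetic in plain Python. It assumes 8 replicas, matching the printed "REPLICAS: 8" and "Batch Size: 256", and shows that casting the 80% split to `int` first yields the same step counts as integers.

```python
# Step-count arithmetic from the notebook, assuming 8 replicas as printed above.
n_examples = 1262996          # total example count hard-coded in the notebook
replicas = 8                  # value printed as "REPLICAS: 8"
batch_size = 32 * replicas    # 256, as printed as "Batch Size: 256"

# Original form: a float operand makes floor division return a float
train_steps_float = n_examples * 0.8 // batch_size
val_steps_float = train_steps_float // 10

# Casting the 80% split to int first keeps the step counts integral
train_steps = int(n_examples * 0.8) // batch_size
val_steps = train_steps // 10

print(train_steps_float, val_steps_float)  # 3946.0 394.0
print(train_steps, val_steps)              # 3946 394
```

Both forms floor to the same value here, so the cast changes only the type, not the number of steps.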
--------------------------------------------------------------------------------
/src/TFrecordWriter.py:
--------------------------------------------------------------------------------
1 | import os
2 | import tensorflow as tf
3 | from tqdm import tqdm
4 |
5 |
6 | class TFrecordWriter:
7 | def __init__(self,
8 | n_samples,
9 | n_shards,
10 | output_dir='',
11 | prefix=''):
12 | self.n_samples = n_samples
13 | self.n_shards = n_shards
14 | self.step_size = self.n_samples // self.n_shards + 1
15 | self.prefix = prefix
16 | self.output_dir = output_dir
17 | self.buffer = []
18 | self.file_count = 1
19 |
20 | def make_example(self, title, vector):
21 | feature = {
22 | 'title': tf.train.Feature(int64_list=tf.train.Int64List(value=title)),
23 | 'citation': tf.train.Feature(float_list=tf.train.FloatList(value=vector))
24 | }
25 | return tf.train.Example(features=tf.train.Features(feature=feature))
26 |
27 | def write_tfrecord(self, tfrecord_path):
28 | print('writing {} samples in {}'.format(
29 | len(self.buffer), tfrecord_path))
30 | with tf.io.TFRecordWriter(tfrecord_path) as writer:
31 | for (title, vector) in tqdm(self.buffer):
32 | example = self.make_example(title, vector)
33 | writer.write(example.SerializeToString())
34 |
35 | def push(self, title, vector):
36 | self.buffer.append([title, vector])
37 | if len(self.buffer) == self.step_size:
38 |             fname = '{}_{:04d}.tfrecord'.format(self.prefix, self.file_count)
39 |             tfrecord_path = os.path.join(self.output_dir, fname)
40 |             self.write_tfrecord(tfrecord_path)
41 |             self.clear_buffer()
42 |             self.file_count += 1
43 |
44 |     def flush_last(self):
45 |         if len(self.buffer):
46 |             fname = '{}_{:04d}.tfrecord'.format(self.prefix, self.file_count)
47 | tfrecord_path = os.path.join(self.output_dir, fname)
48 | self.write_tfrecord(tfrecord_path)
49 |
50 | def clear_buffer(self):
51 | self.buffer = []
52 |
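The sketch below is a pure-Python model of how `TFrecordWriter` spreads samples across files. `push()` writes a file each time the buffer reaches `step_size = n_samples // n_shards + 1`, and `flush_last()` writes whatever is left. Because of the `+ 1`, the writer never produces more than `n_shards` files, and the last file is usually smaller.

The helpers `shard_sizes` and `shard_name` are hypothetical illustrations, not part of the repository. `shard_name` assumes a fixed-width, zero-padded counter. A scheme that only prepends `'_000'` to the counter would produce `train_00012`, which sorts lexically before `train_0002`.

```python
def shard_sizes(n_samples, n_shards):
    """Hypothetical helper: number of samples that end up in each output file."""
    step_size = n_samples // n_shards + 1
    sizes = []
    buffered = 0
    for _ in range(n_samples):
        buffered += 1
        if buffered == step_size:      # push() writes a full shard and clears the buffer
            sizes.append(buffered)
            buffered = 0
    if buffered:                       # flush_last() writes the leftover samples
        sizes.append(buffered)
    return sizes


def shard_name(prefix, file_count):
    """Hypothetical helper: fixed-width name so lexical order matches shard order."""
    return '{}_{:04d}.tfrecord'.format(prefix, file_count)


print(shard_sizes(10, 3))                               # [4, 4, 2]
print(shard_sizes(12, 3))                               # [5, 5, 2]
print(shard_name('train', 1), shard_name('train', 12))  # train_0001.tfrecord train_0012.tfrecord
```

Fixed-width names matter only for tools that rely on sorted file order. The `train*` and `eval*` glob patterns used for training match either naming scheme.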
--------------------------------------------------------------------------------
/src/model.py:
--------------------------------------------------------------------------------
1 | import os
2 | import tensorflow as tf
3 | from tensorflow.keras.layers import (Dense,
4 | Dropout)
5 | from transformers import TFBertModel
6 | from time import time
7 | print('TensorFlow:', tf.__version__)
8 |
9 | try:
10 | tpu = tf.distribute.cluster_resolver.TPUClusterResolver(
11 | 'srihari-1-tpu')
12 | print('Running on TPU ', tpu.cluster_spec().as_dict()['worker'])
13 | except ValueError:
14 | tpu = None
15 |
16 | if tpu:
17 | tf.config.experimental_connect_to_cluster(tpu)
18 | tf.tpu.experimental.initialize_tpu_system(tpu)
19 | strategy = tf.distribute.experimental.TPUStrategy(tpu)
20 | else:
21 | strategy = tf.distribute.MirroredStrategy()
22 |
23 | print("REPLICAS: ", strategy.num_replicas_in_sync)
24 |
25 | batch_size = 32 * strategy.num_replicas_in_sync
26 | embedding_dim = 512
27 | autotune = tf.data.experimental.AUTOTUNE
28 |
29 | train_steps = int(1262996 * 0.8) // batch_size
30 | val_steps = train_steps // 10
31 | epochs = 100
32 | print('Batch Size:', batch_size)
33 |
34 | config_name = 'model_a'
35 | base_dir = 'gs://tfworld/hparams_search'
36 | model_dir = os.path.join(base_dir, config_name)
37 | tensorboard_dir = os.path.join(model_dir, 'logs_' + str(time()))
38 | tfrecords_pattern_train = os.path.join(base_dir, 'tfrecords', 'train*')
39 | tfrecords_pattern_val = os.path.join(base_dir, 'tfrecords', 'eval*')
40 |
41 | print('Logging in: ', tensorboard_dir)
42 |
43 | features = {
44 | 'title': tf.io.FixedLenFeature([512], dtype=tf.int64),
45 | 'citation': tf.io.FixedLenFeature([512], dtype=tf.float32),
46 | }
47 |
48 |
49 | def parse_example(example_proto):
50 | parsed_example = tf.io.parse_single_example(example_proto, features)
51 | title = parsed_example['title']
52 | citation = parsed_example['citation']
53 |
54 | title = tf.cast(title, dtype=tf.int32)
55 | citation = tf.cast(citation, dtype=tf.float32)
56 | return (title, citation), tf.constant([1.0], dtype=tf.float32)
57 |
58 |
59 | with strategy.scope():
60 | train_files = tf.data.Dataset.list_files(tfrecords_pattern_train)
61 | train_dataset = train_files.interleave(tf.data.TFRecordDataset,
62 | cycle_length=32,
63 | block_length=4,
64 | num_parallel_calls=autotune)
65 | train_dataset = train_dataset.map(
66 | parse_example, num_parallel_calls=autotune)
67 | train_dataset = train_dataset.batch(batch_size, drop_remainder=True)
68 | train_dataset = train_dataset.repeat()
69 | train_dataset = train_dataset.prefetch(autotune)
70 |
71 | val_files = tf.data.Dataset.list_files(tfrecords_pattern_val)
72 | val_dataset = val_files.interleave(tf.data.TFRecordDataset,
73 | cycle_length=32,
74 | block_length=4,
75 | num_parallel_calls=autotune)
76 | val_dataset = val_dataset.map(parse_example, num_parallel_calls=autotune)
77 | val_dataset = val_dataset.batch(batch_size, drop_remainder=True)
78 | val_dataset = val_dataset.repeat()
79 | val_dataset = val_dataset.prefetch(autotune)
80 |
81 |
82 | @tf.function
83 | def loss_fn(_, probs):
84 | '''
85 |     1. Every sample is its own positive, and the rest of the
86 |        elements in the batch are its negatives.
87 |     2. Each TPU core receives 1/8 of the global batch, so the
88 |        per-replica batch size is read dynamically with tf.shape.
89 |     3. The dataset yields dummy labels only so that loss_fn matches
90 |        the Keras loss signature; the actual labels are computed inside
91 |        this function.
92 | 4. `probs` lie in [0, 1] and are to be treated as probabilities.
93 | '''
94 | bs = tf.shape(probs)[0]
95 | labels = tf.eye(bs, bs)
96 | return tf.losses.categorical_crossentropy(labels,
97 | probs,
98 | from_logits=False)
99 |
100 |
101 | def create_model(drop_out, dense_units, activation):
102 | textIds = tf.keras.Input(
103 | shape=(512,), dtype=tf.int32) # from bert tokenizer
104 | # normalized word2vec outputs
105 | citation = tf.keras.Input(shape=(512,))
106 |
107 | bert_model = TFBertModel.from_pretrained(
108 | 'scibert_scivocab_uncased', from_pt=True)
109 |
110 | textOut = bert_model(textIds)
111 | textOutMean = tf.reduce_mean(textOut[0], axis=1)
112 | textOutSim = Dense(units=embedding_dim, activation=activation,
113 | name='DenseTitle')(textOutMean)
114 | textOutSim = Dropout(drop_out)(textOutSim)
115 |
116 | citationSim = citation
117 |     for i, units in enumerate(dense_units):  # last size must equal embedding_dim
118 |         citationSim = Dense(units=units, activation=activation,
119 |                             name='DenseCitation_{}'.format(i))(citationSim)
120 |         citationSim = Dropout(drop_out)(citationSim)
121 |
122 |     # Dot product for every (title, citation) pair in the batch
123 | dotProduct = tf.reduce_sum(tf.multiply(
124 | textOutSim[:, None, :], citationSim), axis=-1)
125 |
126 | # Softmax to make sure each row has sum == 1.0
127 | probs = tf.nn.softmax(dotProduct, axis=-1)
128 |
129 | model = tf.keras.Model(inputs=[textIds, citation], outputs=[probs])
130 | return model
131 |
132 | config = {
133 |     'drop_out': 0.2,
134 |     'dense_units': [512, 512],
135 |     'activation': 'tanh'
136 | }
137 | with strategy.scope():
138 | model = create_model(**config)
139 | model.compile(loss=loss_fn,
140 | optimizer=tf.keras.optimizers.Adam(learning_rate=3e-5))
141 |
142 | callbacks = [tf.keras.callbacks.TensorBoard(log_dir=tensorboard_dir,
143 | update_freq='epoch'),
144 | tf.keras.callbacks.ModelCheckpoint(filepath=model_dir + '/epoch_{epoch:02d}_{loss:.2f}',
145 | monitor='loss',
146 | verbose=1,
147 | save_weights_only=True,
148 | save_freq='epoch')
149 | ]
150 |
151 | model.fit(train_dataset,
152 | epochs=epochs,
153 | steps_per_epoch=train_steps,
154 | validation_data=val_dataset,
155 | validation_steps=val_steps,
156 | validation_freq=1,
157 | callbacks=callbacks)
158 |
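The scoring head in `create_model` and `loss_fn` together implement an in-batch softmax cross-entropy. Row *i* of the score matrix compares title *i* against every citation in the (per-replica) batch, and the identity matrix marks the diagonal as the positive.

The NumPy sketch below re-expresses that computation outside TensorFlow, so it is not the repository's code. `title_vecs` and `citation_vecs` are hypothetical stand-ins for `textOutSim` and `citationSim`. The sketch also checks that the broadcast-multiply-sum used in the model equals a single `title_vecs @ citation_vecs.T`. That matmul avoids materialising the `(batch, batch, dim)` intermediate.

Keras's `categorical_crossentropy` additionally clips probabilities by a small epsilon, which the sketch omits.

```python
import numpy as np


def in_batch_softmax_loss(title_vecs, citation_vecs):
    """Per-example loss; title_vecs and citation_vecs have shape (batch, dim)."""
    # Broadcast form used in the model: (b, 1, d) * (b, d) -> (b, b, d) -> (b, b)
    scores = np.sum(title_vecs[:, None, :] * citation_vecs, axis=-1)
    # Same score matrix from one matmul, without the (b, b, d) intermediate
    assert np.allclose(scores, title_vecs @ citation_vecs.T)

    # Row-wise softmax, shifted by the row max for numerical stability
    shifted = scores - scores.max(axis=-1, keepdims=True)
    probs = np.exp(shifted) / np.exp(shifted).sum(axis=-1, keepdims=True)

    # Identity labels: column i is the positive for row i, all others are negatives
    labels = np.eye(len(probs))
    return -np.sum(labels * np.log(probs), axis=-1)


rng = np.random.default_rng(0)
t = np.tanh(rng.normal(size=(4, 8)))
c = np.tanh(rng.normal(size=(4, 8)))
print(in_batch_softmax_loss(t, c))
```

An uninformative model assigns equal scores to every pair, giving a loss of ln(batch). With 32 examples per replica that is ln 32 ≈ 3.47, a useful reference point next to the first logged training losses.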
--------------------------------------------------------------------------------