├── 0_Intro_OT.ipynb
├── 0_Intro_OT.py
├── 1_DomainAdaptation.ipynb
├── 1_DomainAdaptation.py
├── 2_ColorGrading.ipynb
├── 2_ColorGrading.py
├── 3_WMD.ipynb
├── 3_WMD.py
├── LICENSE
├── README.md
├── data
│   ├── data_text.npz
│   ├── klimt.jpg
│   ├── manhattan.npz
│   ├── mnist_usps.npz
│   ├── model.npz
│   └── schiele.jpg
└── slides
    ├── Part1_intro_OT_2022.pdf
    ├── Part2_UOT_GW_Rennes_2022.pdf
    └── Part3_OTML_Rennes_2022.pdf
/0_Intro_OT.py:
--------------------------------------------------------------------------------
1 |
2 | # coding: utf-8
3 |
4 | # # Introduction to Optimal Transport with Python
5 | #
6 | # #### *Rémi Flamary, Nicolas Courty*
7 |
8 | # ## POT installation
9 |
10 | # + Install with pip:
11 | # ```bash
12 | # pip install pot
13 | # ```
14 | # + Install with conda
15 | # ```bash
16 | # conda install -c conda-forge pot
17 | # ```
18 |
19 | # ## POT Python Optimal Transport Toolbox
20 | #
21 | # #### Import the toolbox
22 |
23 | # In[1]:
24 |
25 |
26 | import numpy as np # always need it
27 | import scipy as sp # often use it
28 | import pylab as pl # do the plots
29 |
30 | import ot # ot
31 |
32 |
33 | #%% #### Getting help
34 | #
35 | # Online documentation : [http://pot.readthedocs.io](http://pot.readthedocs.io)
36 | #
37 | # Or inline help:
38 | #
39 |
40 | # In[2]:
41 |
42 |
43 | help(ot.dist)
44 |
45 |
46 | #%% ## First OT Problem
47 | #
48 | # We will solve the Bakery/Café problem of transporting croissants from a number of bakeries to cafés in a city (in this case Manhattan). We did a quick Google Maps search in Manhattan for bakeries and cafés:
49 | #
50 | # 
51 | #
52 | # From this search we extracted their positions and generated fictional production and sales numbers (both sum to the same value).
53 | #
54 | # We have access to the positions of the bakeries ```bakery_pos``` and their respective production ```bakery_prod```, which together describe the source distribution. The cafés where the croissants are sold are likewise defined by their positions ```cafe_pos``` and their sales ```cafe_prod```. For fun we also provide a map ```Imap``` that illustrates the position of these shops in the city.
55 | #
56 | #
57 | # Now we load the data
58 | #
59 | #
60 |
61 | # In[3]:
62 |
63 |
64 | data=np.load('data/manhattan.npz')
65 |
66 | bakery_pos=data['bakery_pos']
67 | bakery_prod=data['bakery_prod']
68 | cafe_pos=data['cafe_pos']
69 | cafe_prod=data['cafe_prod']
70 | Imap=data['Imap']
71 |
72 | print('Bakery production: {}'.format(bakery_prod))
73 | print('Cafe sale: {}'.format(cafe_prod))
74 | print('Total croissants : {}'.format(cafe_prod.sum()))
75 |
76 |
77 | #%% #### Plotting bakeries in the city
78 | #
79 | # Next we plot the position of the bakeries and cafés on the map. The size of the circle is proportional to their production.
80 | #
81 |
82 | # In[4]:
83 |
84 |
85 |
86 | pl.figure(1,(8,7))
87 | pl.clf()
88 | pl.imshow(Imap,interpolation='bilinear') # plot the map
89 | pl.scatter(bakery_pos[:,0],bakery_pos[:,1],s=bakery_prod,c='r', edgecolors='k',label='Bakeries')
90 | pl.scatter(cafe_pos[:,0],cafe_pos[:,1],s=cafe_prod,c='b', edgecolors='k',label='Cafés')
91 | pl.legend()
92 | pl.title('Manhattan Bakeries and Cafés');
93 |
94 |
95 | #%% #### Cost matrix
96 | #
97 | #
98 | # We compute the cost matrix between the bakeries and the cafés; this will be the transport cost matrix. This can be done with [ot.dist](http://pot.readthedocs.io/en/stable/all.html#ot.dist), which defaults to the squared Euclidean distance but supports other metrics such as cityblock (Manhattan distance).
99 | #
100 | #
101 |
102 | #%% #### Solving the OT problem with [ot.emd](http://pot.readthedocs.io/en/stable/all.html#ot.emd)
103 |
104 | # #### Transportation plan visualization
105 | #
106 | # A good visualization of the OT matrix in the 2D plane is to draw a line between a bakery and a café whenever mass is transported between them. This can easily be done with a double ```for``` loop.
107 | #
108 | # To make the plot more interpretable, one can also use the ```alpha``` parameter of plot and set it to ```alpha=G[i,j]/G.max()```.
109 |
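The double loop can be sketched as follows. `plan_segments` and `plot_plan` are hypothetical helper names, and the plan `G` is a small hand-written example rather than a solver output. The line opacity is the transported mass divided by the largest entry of the plan.

```python
import numpy as np

def plan_segments(G, xs, xt, thr=1e-8):
    """List (start, end, alpha) for every pair (i, j) that carries mass."""
    gmax = G.max()
    segments = []
    for i in range(G.shape[0]):
        for j in range(G.shape[1]):
            if G[i, j] > thr:
                segments.append((xs[i], xt[j], G[i, j] / gmax))
    return segments

def plot_plan(G, xs, xt):
    """Draw one line per segment (pylab is imported only when plotting)."""
    import pylab as pl
    for start, end, alpha in plan_segments(G, xs, xt):
        pl.plot([start[0], end[0]], [start[1], end[1]], 'k', alpha=alpha)

# Toy example: 2 sources, 2 targets, hand-written plan
xs = np.array([[0., 0.], [1., 0.]])
xt = np.array([[0., 1.], [1., 1.]])
G = np.array([[0.5, 0.0],
              [0.25, 0.25]])
segs = plan_segments(G, xs, xt)
print(len(segs))
```

Calling `plot_plan(G, bakery_pos, cafe_pos)` after plotting the map would overlay the lines on the existing figure, assuming `G` has one row per bakery and one column per café.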
110 | #%% #### OT loss and dual variables
111 | #
112 | # The resulting Wasserstein loss is of the form:
113 | #
114 | # $W=\sum_{i,j}\gamma_{i,j}C_{i,j}$
115 | #
116 | # where $\gamma$ is the optimal transport matrix.
117 | #
118 |
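As a small worked instance of this formula, the loss is the elementwise product of the plan and the cost matrix, summed over all entries. The 2x2 matrices below are hypothetical; `G` is the optimal plan for uniform marginals on this cost. In POT, `ot.emd2(a, b, C)` returns this value directly.

```python
import numpy as np

# Hypothetical 2x2 problem
C = np.array([[1., 4.],
              [2., 1.]])
G = np.array([[0.5, 0.0],
              [0.0, 0.5]])

# W = sum_{i,j} gamma_{i,j} C_{i,j}
W = np.sum(G * C)
print(W)
```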
119 | #%% #### Regularized OT with Sinkhorn
120 | #
121 | # The Sinkhorn algorithm is very simple to code. You can implement it directly using the following pseudo-code:
122 | #
123 | # 
124 | #
125 | # An alternative is to use the POT toolbox with [ot.sinkhorn](http://pot.readthedocs.io/en/stable/all.html#ot.sinkhorn)
126 | #
127 | # Beware of numerical problems. A good preprocessing step for Sinkhorn is to divide the cost matrix ```C```
128 | # by its maximum value.
129 |
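A minimal NumPy sketch of the standard Sinkhorn iterations is given below. It may differ in detail from the pseudo-code in the figure, it uses no log-domain stabilization, and the data and the value `reg=0.1` are arbitrary choices for illustration. The cost matrix is divided by its maximum, as suggested above.

```python
import numpy as np

def sinkhorn(a, b, C, reg, n_iter=1000):
    """Plain Sinkhorn iterations for entropic OT (no stabilization)."""
    K = np.exp(-C / reg)           # Gibbs kernel
    u = np.ones_like(a)
    for _ in range(n_iter):
        v = b / (K.T @ u)          # scale to match the column marginals
        u = a / (K @ v)            # scale to match the row marginals
    return u[:, None] * K * v[None, :]

rng = np.random.RandomState(0)
xs, xt = rng.rand(5, 2), rng.rand(4, 2)
C = ((xs[:, None, :] - xt[None, :, :]) ** 2).sum(-1)
C = C / C.max()                    # keep exp(-C / reg) well scaled
a = np.ones(5) / 5
b = np.ones(4) / 4

G = sinkhorn(a, b, C, reg=0.1)
print(G.sum(axis=1), G.sum(axis=0))
```

With POT, `ot.sinkhorn(a, b, C, reg)` solves the same regularized problem.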
--------------------------------------------------------------------------------
/1_DomainAdaptation.py:
--------------------------------------------------------------------------------
1 |
2 | # coding: utf-8
3 |
4 | # # Domain Adaptation between digits
5 | #
6 | # #### *Rémi Flamary, Nicolas Courty*
7 | #
8 | # In this practical session we will apply to digit classification the OT-based domain adaptation method proposed in
9 | #
10 | # N. Courty, R. Flamary, D. Tuia, A. Rakotomamonjy, "[Optimal transport for domain adaptation](http://remi.flamary.com/biblio/courty2016optimal.pdf)", Pattern Analysis and Machine Intelligence, IEEE Transactions on , 2016.
11 | #
12 | # 
13 | #
14 | # To this end we will try to adapt between the MNIST and USPS datasets. Since those datasets do not have the same resolution (28x28 for MNIST and 16x16 for USPS), we zero-pad the USPS digits.
15 | #
16 | #
17 | #%% #### Import modules
18 | #
19 | # First we import the relevant modules. Note that you will need ```sklearn``` to learn the Support Vector Machine classifier and to project the data with TSNE.
20 | #
21 |
22 | # In[1]:
23 |
24 |
25 | import numpy as np # always need it
26 | import pylab as pl # do the plots
27 |
28 | from sklearn.svm import SVC
29 | from sklearn.manifold import TSNE
30 | import ot
31 |
32 |
33 | #%% ### Loading data and normalization
34 | #
35 | # We load the data in memory and perform a normalization of the images so that they all sum to 1.
36 | #
37 | # Note that every row of ```xs``` and ```xt``` is a flattened 28x28 image.
38 |
39 | # In[2]:
40 |
41 |
42 | data=np.load('data/mnist_usps.npz')
43 |
44 | xs,ys=data['xs'],data['ys']
45 | xt,yt=data['xt'],data['yt']
46 |
47 |
48 | # normalization
49 | xs=xs/xs.sum(1,keepdims=True) # every row now sums to 1
50 | xt=xt/xt.sum(1,keepdims=True)
51 |
52 | ns=xs.shape[0]
53 | nt=xt.shape[0]
54 |
55 |
56 | #%% ### Visualizing Source (MNIST) and Target (USPS) datasets
57 | #
58 | #
59 | #
60 | #
61 |
62 | # In[3]:
63 |
64 |
65 |
66 | # function for plotting images
67 | def plot_image(x):
68 |     pl.imshow(x.reshape((28,28)),cmap='gray')
69 |     pl.xticks(())
70 |     pl.yticks(())
71 |
72 |
73 | nb=10
74 |
75 | # First we plot MNIST
76 | pl.figure(1,(nb,nb))
77 | for i in range(nb*nb):
78 |     pl.subplot(nb,nb,1+i)
79 |     c=i%nb
80 |     plot_image(xs[np.where(ys==c)[0][i//nb],:])
81 | pl.gcf().suptitle("MNIST", fontsize=20);
82 | pl.gcf().subplots_adjust(top=0.95)
83 |
84 | # Then we plot USPS
85 | pl.figure(2,(nb,nb))
86 | for i in range(nb*nb):
87 |     pl.subplot(nb,nb,1+i)
88 |     c=i%nb
89 |     plot_image(xt[np.where(yt==c)[0][i//nb],:])
90 | pl.gcf().suptitle("USPS", fontsize=20);
91 | pl.gcf().subplots_adjust(top=0.95)
92 |
93 |
94 | # Note that there is a large discrepancy, especially for the digits 1, 2 and 5, which have different shapes in the two datasets.
95 | #
96 | # Also, since we have performed zero padding on the USPS digits, they are on average slightly smaller than the MNIST digits, which can fill the whole image.
97 | #
98 | #
99 | #%% ### Classification without domain adaptation
100 | #
101 | # We learn a classifier on the MNIST dataset (we will not be state of the art on 1000 samples). We evaluate this classifier on MNIST and on the USPS dataset.
102 |
103 | # In[4]:
104 |
105 |
106 |
107 | # Train SVM with reg parameter C=1 and RBF kernel parameter gamma=1e2
108 | clf=SVC(C=1,gamma=1e2) # might take time
109 | clf.fit(xs,ys)
110 |
111 | # Compute accuracy
112 | ACC_MNIST=clf.score(xs,ys) # beware of overfitting !
113 | ACC_USPS=clf.score(xt,yt)
114 |
115 | print('ACC_MNIST={:1.3f}'.format(ACC_MNIST))
116 | print('ACC_USPS={:1.3f}'.format(ACC_USPS))
117 |
118 |
119 | #%% There is a very large loss in performance. This can be better understood by performing a TSNE embedding of the data.
120 | #
121 | # ### TSNE of the Source/Target domains
122 | #
123 | # [t-distributed stochastic neighbor embedding (TSNE)](http://www.jmlr.org/papers/volume9/vandermaaten08a/vandermaaten08a.pdf) is a well-known approach that projects complex high-dimensional data into a lower-dimensional space while keeping its structure.
124 | #
125 | #
126 |
127 | # In[5]:
128 |
129 |
130 |
131 | xtot=np.concatenate((xs,xt),axis=0) # all data
132 |
133 | xp=TSNE().fit_transform(xtot) # this might take a while (30 sec on my laptop)
134 |
135 | # separate again but now in 2D
136 | xps=xp[:ns,:]
137 | xpt=xp[ns:,:]
138 |
139 |
140 | # In[6]:
141 |
142 |
143 |
144 | pl.figure(3,(12,10))
145 |
146 | pl.scatter(xps[:,0],xps[:,1],c=ys,marker='o',cmap='tab10',label='Source data')
147 | pl.scatter(xpt[:,0],xpt[:,1],c=yt,marker='+',cmap='tab10',label='Target data')
148 | pl.legend()
149 | pl.colorbar()
150 | pl.title('TSNE Embedding of the Source/Target data');
151 |
152 |
153 | # We can see that while the classes are relatively well clustered, the clusters from the source and target datasets rarely overlap. This is the main reason for the large loss in performance between source and target.
154 | #
155 | #%% ### Optimal Transport Domain Adaptation (OTDA)
156 | #
157 | # Now we perform domain adaptation with the following 3 steps illustrated at the top of the notebook:
158 | #
159 | # 1. Compute the OT matrix between source and target datasets
160 | # 1. Perform OT mapping with barycentric mapping (```np.dot```).
161 | # 1. Estimate classifier on the mapped source samples
162 | #
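Step 2 can be sketched with NumPy alone: each source sample is sent to the weighted average of the target samples it is coupled with. `barycentric_map` is a hypothetical helper name and the coupling below is a toy example, not the output of a solver. With uniform source weights `1/ns`, dividing by the row sums is the same as multiplying `np.dot(G, xt)` by `ns`.

```python
import numpy as np

def barycentric_map(G, xt):
    """Map each source sample to the G-weighted mean of the target samples."""
    return np.dot(G, xt) / G.sum(axis=1, keepdims=True)

# Toy coupling between 2 source and 3 target samples (uniform source weights)
G = np.array([[0.5, 0.0, 0.0],
              [0.0, 0.25, 0.25]])
xt = np.array([[0., 0.],
               [2., 0.],
               [0., 2.]])

xs_mapped = barycentric_map(G, xt)
print(xs_mapped)
```

The mapped samples keep their source labels, so step 3 amounts to fitting the classifier on `xs_mapped` with the source labels.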
163 | # #### 1. OT between domain
164 | #
165 | # First we compute the cost matrix and visualize it. Note that the samples are sorted by class in both source and target domains in order to better see the class-based structure in the cost matrix and OT matrix.
166 | #
167 | #
168 | #
169 |
170 | # We can clearly see the (noisy) structure in the matrix. It is also interesting to note that class 1 in USPS (second column) is particularly different from all the classes in the MNIST data (even class 1).
171 | #
172 | #
173 | #%% Next we compute the OT matrix using exact LP OT [ot.emd](http://pot.readthedocs.io/en/stable/all.html#ot.emd) or regularized OT with [ot.sinkhorn](http://pot.readthedocs.io/en/stable/all.html#ot.sinkhorn).
174 |
175 | # We can see that most of the transportation is done in the block diagonal, which means that on average samples from one class are assigned to the proper class in the target.
176 | #
177 | #%% #### 2/3 Mapping + Classification
178 | #
179 | # Now we perform the barycentric mapping of the samples and train the classifier on the mapped samples. We recommend using a smaller ```gamma=1e1``` here because some samples will be mislabeled and a smoother classifier will work better.
180 |
181 | # We can see that the adaptation with EMD leads to a performance gain of nearly 10%. You can get even better performance using entropic regularized OT or group lasso regularization.
182 | #
183 | #%% #### TSNE visualization for OTDA
184 | #
185 | # In order to see the effect of the adaptation we can perform a new TSNE embedding to see if the classes are better aligned.
186 | #
187 | #
188 |
189 | # In[ ]:
190 |
191 |
192 |
193 |
194 |
195 | # We can see that when using the emd solver the OT matrix is a permutation, so the samples are exactly superimposed. On average the classes are also well transported, but a number of samples are badly transported and end up in the wrong class.
196 | #
197 | #
198 | #%% #### Transported samples visualization
199 | #
200 | # We can now also plot the transported samples.
201 |
202 | # Those are the same MNIST samples that were plotted above, but after transportation. Several samples are transported to the wrong class, but again on average the class information is preserved, which explains the accuracy gain.
203 | #
204 | #%% ### OTDA with regularization
205 | #
206 | # We now recommend trying regularized OT and redoing the classification, TSNE and visualization steps to see the impact of the regularization on performance, the TSNE embedding and the transported samples.
207 |
--------------------------------------------------------------------------------
/2_ColorGrading.py:
--------------------------------------------------------------------------------
1 |
2 | # coding: utf-8
3 |
4 | # # Color grading with optimal transport
5 | #
6 | # #### *Nicolas Courty, Rémi Flamary*
7 |
8 | #%% In this tutorial we will learn how to perform color grading of images with optimal transport. This is a very direct use of optimal transport. You will learn how to treat an image as an empirical distribution, and apply optimal transport to find a matching between two different images seen as distributions.
9 |
10 | # First we need to load two images. To this end we need some packages
11 | #
12 |
13 | # In[1]:
14 |
15 |
16 | import numpy as np
17 | import matplotlib.pylab as pl
18 | from matplotlib.pyplot import imread
19 | from mpl_toolkits.mplot3d import Axes3D
20 |
21 | I1 = imread('./data/klimt.jpg').astype(np.float64) / 256
22 | I2 = imread('./data/schiele.jpg').astype(np.float64) / 256
23 |
24 |
25 | #%% We need some code to visualize them
26 |
27 | # In[2]:
28 |
29 |
30 | def showImage(I,myPreferredFigsize=(8,8)):
31 |     pl.figure(figsize=myPreferredFigsize)
32 |     pl.imshow(I)
33 |     pl.axis('off')
34 |     pl.tight_layout()
35 |     pl.show()
36 |
37 |
38 | # In[3]:
39 |
40 |
41 | showImage(I1)
42 | showImage(I2)
43 |
44 |
45 | # Those are two beautiful paintings by Gustav Klimt and Egon Schiele, respectively. Now we will treat them as empirical distributions.
46 |
47 | #%% Write two functions that will be used to convert 2D images to arrays of 3D points (in the color space), and back.
48 |
49 | # In[4]:
50 |
51 |
52 | def im2mat(I):
53 |     """Converts an image to a matrix (one pixel per line)"""
54 |     pass # use reshape
55 |
56 |
57 | def mat2im(X, shape):
58 |     """Converts a matrix back to an image"""
59 |     pass # use reshape
60 |
61 | X1 = im2mat(I1)
62 | X2 = im2mat(I2)
63 |
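One possible way to fill in the two helpers is sketched below, assuming the image is an array of shape (height, width, channels). The small array `I` is a synthetic stand-in for a loaded image.

```python
import numpy as np

def im2mat(I):
    """Convert an image (h, w, c) to a matrix (h*w, c), one pixel per row."""
    return I.reshape((I.shape[0] * I.shape[1], I.shape[2]))

def mat2im(X, shape):
    """Convert a matrix (h*w, c) back to an image of the given shape."""
    return X.reshape(shape)

# Synthetic 2x4 "image" with 3 channels
I = np.arange(24, dtype=float).reshape((2, 4, 3))
X = im2mat(I)
print(X.shape)
```

Passing `I.shape` to `mat2im` recovers the original image exactly, since both functions only reshape.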
64 |
65 | #%% It is unlikely that our solver, however efficient, can handle such large distributions (1Mx1M for the coupling). We will use the mini-batch k-means procedure from sklearn to subsample those distributions. Write the code that performs this subsampling (you can choose a size of 1000 clusters to have a good approximation of the image).
66 |
67 | # In[5]:
68 |
69 |
70 | import sklearn.cluster as skcluster
71 | nbsamples=1000
72 |
73 |
74 | #%% You can use the following procedure to display them as point clouds
75 |
76 | # In[6]:
77 |
78 |
79 | def showImageAsPointCloud(X,myPreferredFigsize=(8,8)):
80 |     fig = pl.figure(figsize=myPreferredFigsize)
81 |     ax = fig.add_subplot(111, projection='3d')
82 |     ax.set_xlim(0,1)
83 |     ax.scatter(X[:,0], X[:,1], X[:,2], c=X, marker='o', alpha=1.0)
84 |     ax.set_xlabel('R',fontsize=22)
85 |     ax.set_xticklabels([])
86 |     ax.set_ylim(0,1)
87 |     ax.set_ylabel('G',fontsize=22)
88 |     ax.set_yticklabels([])
89 |     ax.set_zlim(0,1)
90 |     ax.set_zlabel('B',fontsize=22)
91 |     ax.set_zticklabels([])
92 |     ax.grid(False)
93 |     pl.show()
94 |
95 |
96 | #%% You can now compute the coupling between those two distributions using the exact LP solver (EMD)
97 |
98 | #%% Using the barycentric mapping method, express the transformation of both images into the other one
99 |
100 | #%% Since only the centroids of the clusters have changed, we need to figure out a simple way of transporting all the pixels in the original image. At first, we will apply a simple strategy where the new value of a pixel is simply the new position of its corresponding centroid.
101 |
102 | #%% Express this transformation in your code, and display the corresponding adapted image.
103 |
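The nearest-centroid strategy can be sketched as below. `transport_pixels` is a hypothetical helper name, and the three pixels and two centroids are made-up values; `centroids_mapped` stands for the centroids after barycentric mapping.

```python
import numpy as np

def transport_pixels(X, centroids, centroids_mapped):
    """Give every pixel the transported color of its nearest centroid."""
    # Squared distances pixel -> centroid, shape (n_pixels, n_centroids)
    d2 = ((X[:, None, :] - centroids[None, :, :]) ** 2).sum(-1)
    labels = d2.argmin(axis=1)
    return centroids_mapped[labels]

X = np.array([[0.1, 0.1, 0.1],
              [0.9, 0.8, 0.9],
              [0.2, 0.0, 0.1]])
centroids = np.array([[0., 0., 0.],
                      [1., 1., 1.]])
centroids_mapped = np.array([[0.5, 0., 0.],
                             [0., 0., 0.5]])

X_new = transport_pixels(X, centroids, centroids_mapped)
print(X_new)
```

For a full image the broadcasted distance array becomes very large (pixels x centroids x 3), so in practice the labels would more likely come from the `predict` method of the fitted k-means model.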
104 | #%% You can also use the entropy regularized version of Optimal Transport (a.k.a. the Sinkhorn algorithm) to explore the impact of regularization on the final result
105 | #
106 |
--------------------------------------------------------------------------------
/3_WMD.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "markdown",
5 | "metadata": {},
6 | "source": [
7 | "# Word Mover's distance\n",
8 | "\n",
9 | "In this notebook we will see an application of Optimal Transport to the problem of computing similarities between sentences and texts. The method under study is called 'Word Mover's Distance' in reference to 'Earth Mover's Distance', another name for the Wasserstein $1$ distance, mostly used in computer vision. \n",
10 | "\n",
11 | "Traditionally, portions of text are compared by Cosine similarity on bag-of-words vectors, i.e. histograms of occurrences of words in a text. This captures the exact similarity in terms of words, but two closely related sentences can be orthogonal if they use different words with the same meaning. A semantic distance can instead be obtained by using *word embeddings*, which are embeddings of words in a Euclidean space (of potentially large dimension) where the Euclidean distance has a semantic meaning: two related words will be close in such an embedding. A popular embedding is the *word2vec* embedding, obtained with neural networks. A study of those mechanisms is beyond the scope of this notebook, but the interested reader can find more information on [the corresponding Wikipedia page](https://en.wikipedia.org/wiki/Word2vec). Throughout the rest of this tutorial, we will use a subset of the [GloVe](https://nlp.stanford.edu/projects/glove/) embedding.\n",
12 | "\n",
13 | "The key observation made by Kusner and colleagues [1] is that, when comparing sentences or documents, the optimal transport distance can be used between histograms of occurring words, with a ground metric obtained through word embeddings. In this way, related words will be matched together, and the resulting distance will express the semantic relatedness between the contents.\n",
14 | "\n",
15 | "[1] Kusner, M., Sun, Y., Kolkin, N., & Weinberger, K. (2015, June). From word embeddings to document distances. In International Conference on Machine Learning (pp. 957-966). http://proceedings.mlr.press/v37/kusnerb15.pdf\n"
16 | ]
17 | },
18 | {
19 | "cell_type": "markdown",
20 | "metadata": {},
21 | "source": [
22 | "## A basic example \n",
23 | "\n",
24 | "We will start by reproducing the Figure $1$ in the original paper\n",
25 | "\n",
26 | ""
27 | ]
28 | },
29 | {
30 | "cell_type": "markdown",
31 | "metadata": {},
32 | "source": [
33 | "Two sentences are considered: 'Obama speaks to the media in Illinois' and 'The president greets the press in Chicago'. The Cosine similarity between the two sentences indicates that they are completely unrelated, since they have no word in common. We will start with some imports and create a list of the words of each sentence, without the stopwords that are not relevant for our analysis."
34 | ]
35 | },
36 | {
37 | "cell_type": "code",
38 | "execution_count": 1,
39 | "metadata": {},
40 | "outputs": [],
41 | "source": [
42 | "import os\n",
43 | "\n",
44 | "import numpy as np\n",
45 | "import matplotlib.pylab as pl\n",
46 | "import ot\n",
47 | "\n",
48 | "\n",
49 | "s1 = ['Obama','speaks','media','Illinois']\n",
50 | "s2 = ['President','greets','press','Chicago']\n"
51 | ]
52 | },
53 | {
54 | "cell_type": "markdown",
55 | "metadata": {},
56 | "source": [
57 | "We will use a subset of the GloVe word embedding, expressed as a dictionary (word, embedding) that you can load this way"
58 | ]
59 | },
60 | {
61 | "cell_type": "code",
62 | "execution_count": 2,
63 | "metadata": {},
64 | "outputs": [],
65 | "source": [
66 | " \n",
67 | "model=dict(np.load('data/model.npz'))\n",
68 | " "
69 | ]
70 | },
71 | {
72 | "cell_type": "markdown",
73 | "metadata": {},
74 | "source": [
75 | "Then the embedded representation of the sentences can be obtained by"
76 | ]
77 | },
78 | {
79 | "cell_type": "code",
80 | "execution_count": 3,
81 | "metadata": {},
82 | "outputs": [],
83 | "source": [
84 | "s1_embed = np.array([model[w] for w in s1])\n",
85 | "s2_embed = np.array([model[w] for w in s2])"
86 | ]
87 | },
88 | {
89 | "cell_type": "markdown",
90 | "metadata": {},
91 | "source": [
92 | "Using the multidimensional scaling method from scikit-learn, try to visualize the corresponding embedding of words in 2D."
93 | ]
94 | },
95 | {
96 | "cell_type": "code",
97 | "execution_count": 4,
98 | "metadata": {},
99 | "outputs": [
100 | {
101 | "data": {
102 | "image/png": "\n",
103 | "text/plain": [
104 | ""
105 | ]
106 | },
107 | "metadata": {},
108 | "output_type": "display_data"
109 | }
110 | ],
111 | "source": [
112 | "from sklearn import manifold\n",
113 | "\n",
114 | "C = ot.dist(np.vstack((s1_embed,s2_embed)))\n",
115 | "\n",
116 | "nmds = manifold.MDS(\n",
117 | " 2,\n",
118 | " eps=1e-9,\n",
119 | " dissimilarity=\"precomputed\",\n",
120 | " n_init=1)\n",
121 | "npos = nmds.fit_transform(C)\n",
122 | "\n",
123 | "pl.figure(figsize=(6,6))\n",
124 | "pl.scatter(npos[:4,0],npos[:4,1],c='r',s=50, edgecolor = 'k')\n",
125 | "for i, txt in enumerate(s1):\n",
126 | " pl.annotate(txt, (npos[i,0]-4,npos[i,1]+2),fontsize=20)\n",
127 | "pl.scatter(npos[4:,0],npos[4:,1],c='b',s=50, edgecolor = 'k')\n",
128 | "for i, txt in enumerate(s2):\n",
129 | " pl.annotate(txt, (npos[i+4,0]-4,npos[i+4,1]+2),fontsize=20)\n",
130 | "pl.axis('off')\n",
131 | "pl.tight_layout()\n",
132 | "pl.show()"
133 | ]
134 | },
135 | {
136 | "cell_type": "markdown",
137 | "metadata": {},
138 | "source": [
139 | "Let's now compute the coupling between those two distributions and visualize the corresponding result \n"
140 | ]
141 | },
142 | {
143 | "cell_type": "code",
144 | "execution_count": 5,
145 | "metadata": {},
146 | "outputs": [
147 | {
148 | "data": {
149 | "image/png": "\n",
150 | "text/plain": [
151 | ""
152 | ]
153 | },
154 | "metadata": {},
155 | "output_type": "display_data"
156 | }
157 | ],
158 | "source": [
159 | "C2= ot.dist(s1_embed,s2_embed)\n",
160 | "G=ot.emd(ot.unif(4),ot.unif(4),C2)\n",
161 | "\n",
162 | "pl.figure(figsize=(6,6))\n",
163 | "pl.scatter(npos[:4,0],npos[:4,1],c='r',s=50, edgecolor = 'k')\n",
164 | "for i, txt in enumerate(s1):\n",
165 | " pl.annotate(txt, (npos[i,0]-4,npos[i,1]+2),fontsize=20)\n",
166 | "pl.scatter(npos[4:,0],npos[4:,1],c='b',s=50, edgecolor = 'k')\n",
167 | "for i, txt in enumerate(s2):\n",
168 | " pl.annotate(txt, (npos[i+4,0]-4,npos[i+4,1]+2),fontsize=20)\n",
169 | "for i in range(G.shape[0]):\n",
170 | "    for j in range(G.shape[1]):\n",
171 | "        if G[i,j]>1e-5:\n",
172 | "            pl.plot([npos[i,0],npos[j+4,0]],[npos[i,1],npos[j+4,1]],'k',alpha=G[i,j]/np.max(G))\n",
173 | "pl.title('Word embedding and coupling with OT')\n",
174 | "pl.axis('off')\n",
175 | "pl.tight_layout()\n",
176 | "pl.show()"
177 | ]
178 | },
179 | {
180 | "cell_type": "markdown",
181 | "metadata": {},
182 | "source": [
183 | "## Sentence similarity\n",
184 | "We will now explore the benefit of the Word Mover's Distance (WMD) in a regression context, where our goal is to estimate the similarity (or relatedness) of two sentences on a scale of 0 to 5 (5 being the most similar). Given a set of pairs of sentences with a human-annotated relatedness, our goal is to predict the relatedness of a new pair of sentences.\n",
185 | "\n",
186 | "We will use the [SICK (Sentences Involving Compositional Knowledge) dataset](http://clic.cimec.unitn.it/composes/sick.html) for this purpose.\n",
187 | "\n",
188 | "We first load it."
189 | ]
190 | },
191 | {
192 | "cell_type": "code",
193 | "execution_count": 6,
194 | "metadata": {},
195 | "outputs": [
196 | {
197 | "name": "stdout",
198 | "output_type": "stream",
199 | "text": [
200 | "A group of kids is playing in a yard and an old man is standing in the background\n",
201 | "A group of boys in a yard is playing and a man is standing in the background\n",
202 | "4.5\n"
203 | ]
204 | }
205 | ],
206 | "source": [
207 | " \n",
208 | "data=np.load('data/data_text.npz') \n",
209 | "setA=data['setA']\n",
210 | "setB=data['setB']\n",
211 | "scores=data['scores']\n",
212 | "\n",
213 | "print (setA[0])\n",
214 | "print (setB[0])\n",
215 | "print(scores[0])\n",
216 | "\n",
217 | "np.savez('data/data_text.npz',setA=setA,setB=setB,scores=scores)\n"
218 | ]
219 | },
220 | {
221 | "cell_type": "markdown",
222 | "metadata": {},
223 | "source": [
224 | "We will only keep 200 sentences for learning our regression model and the rest for testing"
225 | ]
226 | },
227 | {
228 | "cell_type": "code",
229 | "execution_count": 7,
230 | "metadata": {},
231 | "outputs": [],
232 | "source": [
233 | "n=200\n",
234 | "testA=setA[n:]\n",
235 | "trainA=setA[:n]\n",
236 | "testB=setB[n:]\n",
237 | "trainB=setB[:n]\n",
238 | "\n",
239 | "scores_train=scores[:n]\n",
240 | "scores_test=scores[n:]"
241 | ]
242 | },
243 | {
244 | "cell_type": "markdown",
245 | "metadata": {},
246 | "source": [
247 | "Using the `CountVectorizer` model from scikit-learn, compute all the bag-of-words representations of the sentences"
248 | ]
249 | },
250 | {
251 | "cell_type": "code",
252 | "execution_count": 8,
253 | "metadata": {},
254 | "outputs": [],
255 | "source": [
256 | "from sklearn.feature_extraction.text import CountVectorizer\n",
257 | "vect = # TO BE FILLED"
258 | ]
259 | },
260 | {
261 | "cell_type": "markdown",
262 | "metadata": {},
263 | "source": [
264 | "Build a big data matrix containing the embeddings of all the words present in the dataset\n"
265 | ]
266 | },
267 | {
268 | "cell_type": "code",
269 | "execution_count": 9,
270 | "metadata": {},
271 | "outputs": [],
272 | "source": [
273 | "all_feat = # TO BE FILLED"
274 | ]
275 | },
276 | {
277 | "cell_type": "markdown",
278 | "metadata": {},
279 | "source": [
280 | "Compute a big matrix of all pairwise feature distances using the dist() method of POT"
281 | ]
282 | },
283 | {
284 | "cell_type": "code",
285 | "execution_count": 10,
286 | "metadata": {},
287 | "outputs": [],
288 | "source": [
289 | "D = ot.dist(all_feat)"
290 | ]
291 | },
292 | {
293 | "cell_type": "markdown",
294 | "metadata": {},
295 | "source": [
296 | "Now you can write code that computes the Cosine and WMD dissimilarities for all the pairs of the training set "
297 | ]
298 | },
299 | {
300 | "cell_type": "code",
301 | "execution_count": 11,
302 | "metadata": {},
303 | "outputs": [],
304 | "source": [
305 | "X_cos=[]\n",
306 | "X_wmd=[]\n",
307 | "Y=[]\n",
308 | "\n",
309 | "\n",
310 | "\n",
311 | "for i in range(len(trainA)):\n",
312 | " s1 = vect.transform([trainA[i]]).toarray().ravel()\n",
313 | " s2 = vect.transform([trainB[i]]).toarray().ravel()\n",
314 | " # Cosine similarity between bag of words\n",
315 | " d_cos=# TO BE FILLED\n",
316 | " X_cos.append(d_cos)\n",
317 | " # WMD\n",
318 | " d_wmd=# TO BE FILLED\n",
319 | " X_wmd.append(d_wmd)\n",
320 | " Y.append(scores_train[i])\n",
321 | "\n",
322 | "\n"
323 | ]
324 | },
325 | {
326 | "cell_type": "markdown",
327 | "metadata": {},
328 | "source": [
329 | "Plot the golden similarity scores against the corresponding distances on the training set. This gives a first impression of how much better WMD captures this similarity."
330 | ]
331 | },
332 | {
333 | "cell_type": "code",
334 | "execution_count": 12,
335 | "metadata": {},
336 | "outputs": [
337 | {
338 | "data": {
339 | "image/png": "\n",
340 | "text/plain": [
341 | ""
342 | ]
343 | },
344 | "metadata": {},
345 | "output_type": "display_data"
346 | },
347 | {
348 | "data": {
349 | "image/png": "\n",
350 | "text/plain": [
351 | ""
352 | ]
353 | },
354 | "metadata": {},
355 | "output_type": "display_data"
356 | }
357 | ],
358 | "source": [
359 | "pl.figure()\n",
360 | "pl.scatter(X_cos,Y)\n",
361 | "pl.title('Cosine Similarity VS golden score')\n",
362 | "pl.show()\n",
363 | "pl.figure()\n",
364 | "pl.scatter(X_wmd,Y)\n",
365 | "pl.title('WMD Similarity VS golden score')\n",
366 | "pl.show()\n"
367 | ]
368 | },
369 | {
370 | "cell_type": "markdown",
371 | "metadata": {},
372 | "source": [
373 | "You can learn a simple regression model between those 2 quantities. Use a polynomial of degree 2 to learn the regression model."
374 | ]
375 | },
376 | {
377 | "cell_type": "code",
378 | "execution_count": 13,
379 | "metadata": {},
380 | "outputs": [],
381 | "source": [
382 | "import numpy.polynomial.polynomial as poly\n",
383 | "k_cos = # TO BE FILLED\n",
384 | "k_wmd = # TO BE FILLED"
385 | ]
386 | },
387 | {
388 | "cell_type": "markdown",
389 | "metadata": {},
390 | "source": [
391 | "Now compute from your regression model the estimated relatedness for all the pairs in the test set."
392 | ]
393 | },
394 | {
395 | "cell_type": "code",
396 | "execution_count": 14,
397 | "metadata": {},
398 | "outputs": [],
399 | "source": [
400 | "X_cos=[]\n",
401 | "X_wmd=[]\n",
402 | "Y_test=[]\n",
403 | "for i in range(len(testA)):\n",
404 | " s1 = vect.transform([testA[i]]).toarray().ravel()\n",
405 | " s2 = vect.transform([testB[i]]).toarray().ravel()\n",
406 | " # cosine similarity between bag of words\n",
407 | " d_cos=# TO BE FILLED\n",
408 | " X_cos.append(d_cos)\n",
409 | " # WMD\n",
410 | " d_wmd=# TO BE FILLED\n",
411 | " X_wmd.append(d_wmd)\n",
412 | " Y_test.append(scores_test[i])\n",
413 | "\n",
414 | "# Final regression scores\n",
415 | "Y_cos = # TO BE FILLED\n",
416 | "Y_wmd = # TO BE FILLED"
417 | ]
418 | },
419 | {
420 | "cell_type": "markdown",
421 | "metadata": {},
422 | "source": [
423 | "We will use MSE, Spearman's rho and Pearson coefficients to measure the quality of our regression model"
424 | ]
425 | },
426 | {
427 | "cell_type": "code",
428 | "execution_count": 15,
429 | "metadata": {},
430 | "outputs": [],
431 | "source": [
432 | "from sklearn.metrics import mean_squared_error as mse\n",
433 | "from scipy.stats import pearsonr\n",
434 | "from scipy.stats import spearmanr\n"
435 | ]
436 | },
437 | {
438 | "cell_type": "markdown",
439 | "metadata": {},
440 | "source": [
441 | "Estimate the quality of your regression model for both Cosine and WMD dissimilarities"
442 | ]
443 | },
444 | {
445 | "cell_type": "code",
446 | "execution_count": 16,
447 | "metadata": {},
448 | "outputs": [
449 | {
450 | "name": "stdout",
451 | "output_type": "stream",
452 | "text": [
453 | "-------- Cosine\n",
454 | "Test Pearson (test): 0.4391501309646927\n",
455 | "Test Spearman (test): 0.36660337782004127\n",
456 | "Test MSE (test): 0.8784692440428779\n",
457 | "-------- WMD\n",
458 | "Test Pearson (test): 0.7062364858939475\n",
459 | "Test Spearman (test): 0.5872407398045967\n",
460 | "Test MSE (test): 0.5858116558928291\n"
461 | ]
462 | }
463 | ],
464 | "source": [
465 | "print('-------- Cosine')\n",
466 | "\n",
467 | "pr = pearsonr(Y_cos, Y_test)[0]\n",
468 | "sr = spearmanr(Y_cos,Y_test)[0]\n",
469 | "se = mse(Y_cos,Y_test)\n",
470 | "\n",
471 | "print('Test Pearson (test): ' + str(pr))\n",
472 | "print('Test Spearman (test): ' + str(sr))\n",
473 | "print('Test MSE (test): ' + str(se))\n",
474 | "\n",
475 | "print('-------- WMD')\n",
476 | "\n",
477 | "pr = pearsonr(Y_wmd, Y_test)[0]\n",
478 | "sr = spearmanr(Y_wmd,Y_test)[0]\n",
479 | "se = mse(Y_wmd,Y_test)\n",
480 | "\n",
481 | "print('Test Pearson (test): ' + str(pr))\n",
482 | "print('Test Spearman (test): ' + str(sr))\n",
483 | "print('Test MSE (test): ' + str(se))"
484 | ]
485 | },
486 | {
487 | "cell_type": "markdown",
488 | "metadata": {},
489 | "source": [
490 | "Not bad, is it?"
491 | ]
492 | },
493 | {
494 | "cell_type": "code",
495 | "execution_count": null,
496 | "metadata": {},
497 | "outputs": [],
498 | "source": []
499 | }
500 | ],
501 | "metadata": {
502 | "kernelspec": {
503 | "display_name": "Python 2",
504 | "language": "python",
505 | "name": "python2"
506 | },
507 | "language_info": {
508 | "codemirror_mode": {
509 | "name": "ipython",
510 | "version": 2
511 | },
512 | "file_extension": ".py",
513 | "mimetype": "text/x-python",
514 | "name": "python",
515 | "nbconvert_exporter": "python",
516 | "pygments_lexer": "ipython2",
517 | "version": "2.7.15"
518 | }
519 | },
520 | "nbformat": 4,
521 | "nbformat_minor": 2
522 | }
523 |
--------------------------------------------------------------------------------
/3_WMD.py:
--------------------------------------------------------------------------------
1 |
2 | # coding: utf-8
3 |
4 | # # Word Mover's distance
5 | #
6 | # In this notebook we will see an application of Optimal Transport to the problem of computing similarities between sentences and texts. The method under study is called 'Word Mover's Distance', in reference to the 'Earth Mover's Distance', another name for the Wasserstein $1$ distance that is mostly used in computer vision.
7 | #
8 | # Traditionally, portions of text are compared with the Cosine similarity on bag-of-words vectors, i.e. histograms of occurrences of words in a text. This captures exact word overlap, but two closely related sentences can be orthogonal if they use different words with the same meaning. A semantic distance between words can be obtained by using *word embeddings*, i.e. embeddings of words in a Euclidean space (of potentially large dimension) where the Euclidean distance has a semantic meaning: two related words will be close in such an embedding. A popular embedding is the *word2vec* embedding, obtained with neural networks. A study of those mechanisms is not in the scope of this notebook, but the interested reader can find more information on [the corresponding Wikipedia page](https://en.wikipedia.org/wiki/Word2vec). Throughout the rest of this tutorial, we will use a subset of the [GloVe](https://nlp.stanford.edu/projects/glove/) embedding.
9 | #
10 | # The key observation made by Kusner and colleagues [1] is that, given two sentences/documents, the optimal transport distance can be computed between their histograms of occurring words, using a ground metric obtained from word embeddings. In this way, related words will be matched together, and the resulting distance will express the semantic relatedness of the contents.
11 | #
12 | # [1] Kusner, M., Sun, Y., Kolkin, N., & Weinberger, K. (2015, June). From word embeddings to document distances. In International Conference on Machine Learning (pp. 957-966). http://proceedings.mlr.press/v37/kusnerb15.pdf
13 | #
14 |
15 | # ## A basic example
16 | #
17 | # We will start by reproducing Figure $1$ of the original paper
18 | #
19 | #
20 |
21 | # Two sentences are considered: 'Obama speaks to the media in Illinois' and 'The president greets the press in Chicago'. The Cosine similarity between their bag-of-words vectors is zero, which wrongly suggests that the two sentences are completely unrelated, since they have no word in common once stopwords are removed. We will start with some imports and create, for each sentence, the list of its words without the stopwords, which are not relevant for our analysis.
22 |
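The claim that the two example sentences are orthogonal under the bag-of-words Cosine similarity can be checked with a minimal sketch in plain Python. The helper `cosine_similarity` is a hypothetical name introduced here for illustration and is not part of the course material.

```python
import math
from collections import Counter


def cosine_similarity(words_a, words_b):
    """Cosine similarity between the bag-of-words vectors of two word lists."""
    bag_a, bag_b = Counter(words_a), Counter(words_b)
    # Dot product over the words of the first bag
    # (a Counter returns 0 for a word it has never seen).
    dot = sum(bag_a[w] * bag_b[w] for w in bag_a)
    norm_a = math.sqrt(sum(c * c for c in bag_a.values()))
    norm_b = math.sqrt(sum(c * c for c in bag_b.values()))
    return dot / (norm_a * norm_b)


s1 = ['Obama', 'speaks', 'media', 'Illinois']
s2 = ['President', 'greets', 'press', 'Chicago']

print(cosine_similarity(s1, s2))  # 0.0: the sentences share no word
print(cosine_similarity(s1, s1))  # 1.0: a sentence compared with itself
```

Because the similarity only counts shared words, it cannot see that 'Obama' and 'President' are related; this is the gap that word embeddings are meant to fill.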
23 | # In[1]:
24 |
25 |
26 | import os
27 |
28 | import numpy as np
29 | import matplotlib.pylab as pl
30 | import ot
31 |
32 |
33 | s1 = ['Obama','speaks','media','Illinois']
34 | s2 = ['President','greets','press','Chicago']
35 |
36 |
37 | # We will use a subset of the GloVe word embedding, stored as a dictionary (word, embedding) that you can load this way
38 |
39 | # In[2]:
40 |
41 |
42 |
43 | model=dict(np.load('data/model.npz'))
44 |
45 |
46 | # Then the embedded representation of the sentences can be obtained by
47 |
48 | # In[3]:
49 |
50 |
51 | s1_embed = np.array([model[w] for w in s1])
52 | s2_embed = np.array([model[w] for w in s2])
53 |
54 |
55 | # Using the multidimensional scaling (MDS) method from scikit-learn, try to visualize the corresponding word embeddings in 2D.
56 |
57 | # In[4]:
58 |
59 |
60 | from sklearn import manifold
61 |
62 | C = ot.dist(np.vstack((s1_embed,s2_embed)))
63 |
64 | nmds = manifold.MDS(
65 |     2,
66 |     eps=1e-9,
67 |     dissimilarity="precomputed",
68 |     n_init=1)
69 | npos = nmds.fit_transform(C)
70 |
71 | pl.figure(figsize=(6,6))
72 | pl.scatter(npos[:4,0],npos[:4,1],c='r',s=50, edgecolor = 'k')
73 | for i, txt in enumerate(s1):
74 |     pl.annotate(txt, (npos[i,0]-4,npos[i,1]+2),fontsize=20)
75 | pl.scatter(npos[4:,0],npos[4:,1],c='b',s=50, edgecolor = 'k')
76 | for i, txt in enumerate(s2):
77 |     pl.annotate(txt, (npos[i+4,0]-4,npos[i+4,1]+2),fontsize=20)
78 | pl.axis('off')
79 | pl.tight_layout()
80 | pl.show()
81 |
82 |
83 | # Let's now compute the coupling between those two distributions and visualize the corresponding result
84 | #
85 |
86 | # In[5]:
87 |
88 |
89 | C2= ot.dist(s1_embed,s2_embed)
90 | G=ot.emd(ot.unif(4),ot.unif(4),C2)
91 |
92 | pl.figure(figsize=(6,6))
93 | pl.scatter(npos[:4,0],npos[:4,1],c='r',s=50, edgecolor = 'k')
94 | for i, txt in enumerate(s1):
95 |     pl.annotate(txt, (npos[i,0]-4,npos[i,1]+2),fontsize=20)
96 | pl.scatter(npos[4:,0],npos[4:,1],c='b',s=50, edgecolor = 'k')
97 | for i, txt in enumerate(s2):
98 |     pl.annotate(txt, (npos[i+4,0]-4,npos[i+4,1]+2),fontsize=20)
99 | for i in range(G.shape[0]):
100 |     for j in range(G.shape[1]):
101 |         if G[i,j]>1e-5:
102 |             pl.plot([npos[i,0],npos[j+4,0]],[npos[i,1],npos[j+4,1]],'k',alpha=G[i,j]/np.max(G))
103 | pl.title('Word embedding and coupling with OT')
104 | pl.axis('off')
105 | pl.tight_layout()
106 | pl.show()
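When both sentences carry uniform weights over the same number of words, as with `ot.unif(4)` in the cell above, an optimal coupling can always be found among the one-to-one matchings of words. The following is a minimal sketch of that special case in plain Python, not the solver used by POT: it tries every matching by brute force, which is only feasible for very short sentences. The 2D vectors in `toy_embed` are hypothetical values chosen by hand (they are not GloVe vectors), the function name `uniform_ot_bruteforce` is introduced here for illustration, and the ground cost is the plain Euclidean distance, which may differ from the cost used in the cell above.

```python
import itertools
import math

# Hypothetical 2D "embeddings", chosen by hand so that related words are close.
toy_embed = {
    'Obama': (0.0, 0.0), 'President': (0.1, 0.2),
    'speaks': (3.0, 0.0), 'greets': (3.2, 0.1),
    'media': (0.0, 3.0), 'press': (0.1, 3.1),
    'Illinois': (3.0, 3.0), 'Chicago': (2.9, 3.2),
}

s1 = ['Obama', 'speaks', 'media', 'Illinois']
s2 = ['President', 'greets', 'press', 'Chicago']


def uniform_ot_bruteforce(words_a, words_b, embed):
    """Best one-to-one matching between two word lists of equal length.

    Returns the matching (a permutation of indices into words_b) and the
    transport cost when every word carries the same mass 1/n.
    """
    n = len(words_a)
    # Ground cost: Euclidean distance between the embedded words.
    cost = [[math.dist(embed[a], embed[b]) for b in words_b] for a in words_a]
    # Enumerate all n! matchings and keep the cheapest one.
    best = min(itertools.permutations(range(n)),
               key=lambda p: sum(cost[i][p[i]] for i in range(n)))
    value = sum(cost[i][best[i]] for i in range(n)) / n
    return best, value


perm, value = uniform_ot_bruteforce(s1, s2, toy_embed)
for i, j in enumerate(perm):
    print(s1[i], '->', s2[j])
print('transport cost:', value)
```

With these toy vectors each word is matched to its related counterpart ('Obama' with 'President', 'speaks' with 'greets', and so on). For histograms that are not uniform or not of the same size, mass may have to be split between several words, and a proper solver is needed.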
107 |
108 |
109 | # ## Sentence similarity
110 | # We will now explore the benefit of this Word Mover's Distance (WMD) in a regression context, where our goal is to estimate the similarity (or relatedness) of two sentences on a scale of 0 to 5 (5 being the most similar). Given a set of pairs of sentences with human-annotated relatedness, our goal is to predict the relatedness of a new pair of sentences.
111 | #
112 | # We will use the [SICK (Sentences Involving Compositional Knowledge) dataset](http://clic.cimec.unitn.it/composes/sick.html) for this purpose.
113 | #
114 | # We first load it.
115 |
116 | # In[6]:
117 |
118 |
119 |
120 | data=np.load('data/data_text.npz')
121 | setA=data['setA']
122 | setB=data['setB']
123 | scores=data['scores']
124 |
125 | print(setA[0])
126 | print(setB[0])
127 | print(scores[0])
128 |
129 | np.savez('data/data_text.npz',setA=setA,setB=setB,scores=scores)
130 |
131 |
132 | # We will keep only the first 200 sentence pairs for learning our regression model and use the rest for testing
133 |
134 | # In[7]:
135 |
136 |
137 | n=200
138 | testA=setA[n:]
139 | trainA=setA[:n]
140 | testB=setB[n:]
141 | trainB=setB[:n]
142 |
143 | scores_train=scores[:n]
144 | scores_test=scores[n:]
145 |
146 |
147 | # Using the `CountVectorizer` model from scikit-learn, compute the bag-of-words representations of all the sentences
148 |
149 | # In[8]:
150 |
151 |
152 | from sklearn.feature_extraction.text import CountVectorizer
153 | vect = # TO BE FILLED
154 |
155 |
156 | # Build a big data matrix containing the embeddings of all the words present in the dataset
157 | #
158 |
159 | # In[9]:
160 |
161 |
162 | all_feat = # TO BE FILLED
163 |
164 |
165 | # Compute a big matrix of all pairwise feature distances using the `ot.dist` function of POT
166 |
167 | # In[10]:
168 |
169 |
170 | D = ot.dist(all_feat)
171 |
172 |
173 | # Now you can write code that computes the Cosine and WMD dissimilarities for all the pairs of the training set
174 |
175 | # In[11]:
176 |
177 |
178 | X_cos=[]
179 | X_wmd=[]
180 | Y=[]
181 |
182 |
183 |
184 | for i in range(len(trainA)):
185 |     s1 = vect.transform([trainA[i]]).toarray().ravel()
186 |     s2 = vect.transform([trainB[i]]).toarray().ravel()
187 |     # Cosine similarity between bag of words
188 |     d_cos=# TO BE FILLED
189 |     X_cos.append(d_cos)
190 |     # WMD
191 |     d_wmd=# TO BE FILLED
192 |     X_wmd.append(d_wmd)
193 |     Y.append(scores_train[i])
194 |
195 |
196 |
197 | # Plot the golden (human-annotated) scores against each dissimilarity on the training set. This gives a first impression of how much better WMD captures this similarity.
198 |
199 | # In[12]:
200 |
201 |
202 | pl.figure()
203 | pl.scatter(X_cos,Y)
204 | pl.title('Cosine Similarity VS golden score')
205 | pl.show()
206 | pl.figure()
207 | pl.scatter(X_wmd,Y)
208 | pl.title('WMD Similarity VS golden score')
209 | pl.show()
210 |
211 |
212 | # You can learn a simple regression model between those two quantities. Use a polynomial of degree 2 as the regression model.
213 |
214 | # In[13]:
215 |
216 |
217 | import numpy.polynomial.polynomial as poly
218 | k_cos = # TO BE FILLED
219 | k_wmd = # TO BE FILLED
220 |
221 |
222 | # Now use your regression model to compute the estimated relatedness for all the pairs in the test set.
223 |
224 | # In[14]:
225 |
226 |
227 | X_cos=[]
228 | X_wmd=[]
229 | Y_test=[]
230 | for i in range(len(testA)):
231 |     s1 = vect.transform([testA[i]]).toarray().ravel()
232 |     s2 = vect.transform([testB[i]]).toarray().ravel()
233 |     # cosine similarity between bag of words
234 |     d_cos=# TO BE FILLED
235 |     X_cos.append(d_cos)
236 |     # WMD
237 |     d_wmd=# TO BE FILLED
238 |     X_wmd.append(d_wmd)
239 |     Y_test.append(scores_test[i])
240 |
241 | # Final regression scores
242 | Y_cos = # TO BE FILLED
243 | Y_wmd = # TO BE FILLED
244 |
245 |
246 | # We will use the MSE and the Pearson and Spearman correlation coefficients to measure the quality of our regression model
247 |
248 | # In[15]:
249 |
250 |
251 | from sklearn.metrics import mean_squared_error as mse
252 | from scipy.stats import pearsonr
253 | from scipy.stats import spearmanr
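To make clear what the three metrics imported above measure, here is a minimal sketch of each in plain Python. The names `toy_mse`, `toy_pearson`, `toy_ranks` and `toy_spearman` are hypothetical helpers written for illustration; in the practical session the scikit-learn and SciPy functions should be used. The Spearman coefficient is computed as the Pearson coefficient of the ranks, with tied values receiving their average rank.

```python
import math


def toy_mse(y_true, y_pred):
    """Mean of the squared differences between targets and predictions."""
    return sum((t - p) ** 2 for t, p in zip(y_true, y_pred)) / len(y_true)


def toy_pearson(x, y):
    """Pearson coefficient: strength of the linear relation between x and y."""
    mean_x, mean_y = sum(x) / len(x), sum(y) / len(y)
    cov = sum((a - mean_x) * (b - mean_y) for a, b in zip(x, y))
    var_x = sum((a - mean_x) ** 2 for a in x)
    var_y = sum((b - mean_y) ** 2 for b in y)
    return cov / math.sqrt(var_x * var_y)


def toy_ranks(values):
    """Ranks starting at 1; tied values share the average of their ranks."""
    order = sorted(range(len(values)), key=lambda i: values[i])
    ranks = [0.0] * len(values)
    i = 0
    while i < len(order):
        j = i
        # Extend j over the run of values equal to values[order[i]].
        while j + 1 < len(order) and values[order[j + 1]] == values[order[i]]:
            j += 1
        average_rank = (i + j) / 2 + 1
        for k in range(i, j + 1):
            ranks[order[k]] = average_rank
        i = j + 1
    return ranks


def toy_spearman(x, y):
    """Spearman's rho: Pearson coefficient computed on the ranks."""
    return toy_pearson(toy_ranks(x), toy_ranks(y))


y_true = [1.0, 2.0, 3.0, 4.0, 5.0]
y_pred = [1.0, 4.0, 9.0, 16.0, 25.0]  # increasing with y_true, but not linearly

print('MSE:', toy_mse(y_true, y_pred))
print('Pearson:', toy_pearson(y_true, y_pred))
print('Spearman:', toy_spearman(y_true, y_pred))
```

On this toy data the Spearman coefficient equals 1 because the predictions preserve the ordering of the targets, while the Pearson coefficient is below 1 because the relation is not linear, and the MSE is large because the values themselves are far apart. The three metrics therefore answer different questions about the same predictions.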
254 |
255 |
256 | # Estimate the quality of your regression model for both Cosine and WMD dissimilarities
257 |
258 | # In[16]:
259 |
260 |
261 | print('-------- Cosine')
262 |
263 | pr = pearsonr(Y_cos, Y_test)[0]
264 | sr = spearmanr(Y_cos,Y_test)[0]
265 | se = mse(Y_cos,Y_test)
266 |
267 | print('Test Pearson (test): ' + str(pr))
268 | print('Test Spearman (test): ' + str(sr))
269 | print('Test MSE (test): ' + str(se))
270 |
271 | print('-------- WMD')
272 |
273 | pr = pearsonr(Y_wmd, Y_test)[0]
274 | sr = spearmanr(Y_wmd,Y_test)[0]
275 | se = mse(Y_wmd,Y_test)
276 |
277 | print('Test Pearson (test): ' + str(pr))
278 | print('Test Spearman (test): ' + str(sr))
279 | print('Test MSE (test): ' + str(se))
280 |
281 |
282 | # Not bad, is it?
283 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2018 Rémi Flamary
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Optimal Transport and Machine Learning Course 2022
2 |
3 | Courses and practical sessions for the Optimal Transport and Machine learning
4 | course.
5 |
6 |
7 |
8 | ### Course
9 |
10 |
11 | The slides from the course can be downloaded here:
12 |
13 | * [Part 1](slides/Part1_intro_OT_2022.pdf) Intro to numerical Optimal Transport
14 | (N. Courty)
15 | * [Part 2](slides/Part2_UOT_GW_Rennes_2022.pdf) Unbalanced OT and OT across
16 | spaces (L. Chapel)
17 | * [Part 3](slides/Part3_OTML_Rennes_2022.pdf) Optimal Transport for Machine
18 | Learning (R. Flamary)
19 |
20 | ### Practical Sessions
21 |
22 | You can download the introductory slides to the practical session [here](https://remi.flamary.com/cours/otml/OTML_TPDS3_2018.pdf).
23 |
24 |
25 | #### Install Python and POT Toolbox
26 |
27 | In order to do the practical sessions you need a working Python installation.
28 | The simplest way on any OS is to install the [Anaconda](https://www.anaconda.com/download/) distribution, which can be freely downloaded [here](https://www.anaconda.com/download/).
29 |
30 | Once Anaconda is installed, the simplest way to install POT is to launch the Anaconda terminal and execute:
31 |
32 | ```
33 | conda install -c conda-forge pot
34 | ```
35 |
36 | which will install the POT OT Toolbox automatically. Note that on Windows you need to launch the Anaconda terminal in administrator mode to install with conda.
37 |
38 |
39 |
40 | #### Download the Notebooks for the session
41 |
42 | You can download all the necessary files here: [OTML_DS3_2018.zip](https://github.com/rflamary/OTML_DS3_2018/archive/master.zip)
43 |
44 | The zip file contains the following sessions:
45 |
46 | 0. [Introduction to OT with POT](0_Intro_OT.ipynb)
47 | 1. [Domain adaptation on digits with OT](1_DomainAdaptation.ipynb)
48 | 2. [Color Grading with OT](2_ColorGrading.ipynb)
49 | 3. [Word Mover's Distance on text](3_WMD.ipynb)
50 |
51 | You can do the practical sessions using either the included notebooks or the Python scripts. We recommend the notebooks for beginners.
52 |
53 | The solutions for the practical sessions can be obtained at the following URL:
54 |
55 | ```
56 | https://remi.flamary.com/cours/otml/solution_[NUMBER].zip
57 | ```
58 |
59 | where `[NUMBER]` must be replaced by the integer part of the value of the
60 | Wasserstein distance obtained in Practical [Session 0](0_Intro_OT.ipynb) using
61 | the Manhattan/Cityblock ground metric (without normalization of the marginals).
62 |
--------------------------------------------------------------------------------
/data/data_text.npz:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/PythonOT/OTML_course_2022/4f1189e73c61937dbdcfc3a458cf5463f639d337/data/data_text.npz
--------------------------------------------------------------------------------
/data/klimt.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/PythonOT/OTML_course_2022/4f1189e73c61937dbdcfc3a458cf5463f639d337/data/klimt.jpg
--------------------------------------------------------------------------------
/data/manhattan.npz:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/PythonOT/OTML_course_2022/4f1189e73c61937dbdcfc3a458cf5463f639d337/data/manhattan.npz
--------------------------------------------------------------------------------
/data/mnist_usps.npz:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/PythonOT/OTML_course_2022/4f1189e73c61937dbdcfc3a458cf5463f639d337/data/mnist_usps.npz
--------------------------------------------------------------------------------
/data/model.npz:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/PythonOT/OTML_course_2022/4f1189e73c61937dbdcfc3a458cf5463f639d337/data/model.npz
--------------------------------------------------------------------------------
/data/schiele.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/PythonOT/OTML_course_2022/4f1189e73c61937dbdcfc3a458cf5463f639d337/data/schiele.jpg
--------------------------------------------------------------------------------
/slides/Part1_intro_OT_2022.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/PythonOT/OTML_course_2022/4f1189e73c61937dbdcfc3a458cf5463f639d337/slides/Part1_intro_OT_2022.pdf
--------------------------------------------------------------------------------
/slides/Part2_UOT_GW_Rennes_2022.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/PythonOT/OTML_course_2022/4f1189e73c61937dbdcfc3a458cf5463f639d337/slides/Part2_UOT_GW_Rennes_2022.pdf
--------------------------------------------------------------------------------
/slides/Part3_OTML_Rennes_2022.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/PythonOT/OTML_course_2022/4f1189e73c61937dbdcfc3a458cf5463f639d337/slides/Part3_OTML_Rennes_2022.pdf
--------------------------------------------------------------------------------