├── .gitignore
├── LICENSE
├── README.md
├── environment.yml
├── examples
│   ├── 01_basics.py
│   ├── 02_video.py
│   ├── 03_pytorch_integration.py
│   ├── README.md
│   └── tutorials
│       ├── Basics.ipynb
│       ├── L_dev_video.npy
│       ├── L_train_video.npy
│       ├── Video.ipynb
│       ├── Y_dev_video.npy
│       ├── __init__.py
│       └── tutorial_helpers.py
├── figs
│   ├── System Diagram.png
│   ├── graphical_structure_simple.png
│   ├── graphical_structure_video.png
│   ├── graphical_structure_video_lambda_dep.png
│   ├── logo.png
│   ├── tennis_nonrally.png
│   └── tennis_rally.png
├── flyingsquid
│   ├── __init__.py
│   ├── _graphs.py
│   ├── _lm_parameters.py
│   ├── _observables.py
│   ├── _triplets.py
│   ├── helpers.py
│   ├── label_model.py
│   └── pytorch_loss.py
└── setup.py
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | **.swp
2 | **.pyc
3 | **.ipynb_checkpoints
4 | **.egg-info
5 | build
6 | dist
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | Apache License
2 | Version 2.0, January 2004
3 | http://www.apache.org/licenses/
4 | 
5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
6 | 
7 | 1. Definitions.
8 | 
9 | "License" shall mean the terms and conditions for use, reproduction,
10 | and distribution as defined by Sections 1 through 9 of this document.
11 | 
12 | "Licensor" shall mean the copyright owner or entity authorized by
13 | the copyright owner that is granting the License.
14 | 
15 | "Legal Entity" shall mean the union of the acting entity and all
16 | other entities that control, are controlled by, or are under common
17 | control with that entity. For the purposes of this definition,
18 | "control" means (i) the power, direct or indirect, to cause the
19 | direction or management of such entity, whether by contract or
20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the
21 | outstanding shares, or (iii) beneficial ownership of such entity.
22 | 
23 | "You" (or "Your") shall mean an individual or Legal Entity
24 | exercising permissions granted by this License.
25 | 
26 | "Source" form shall mean the preferred form for making modifications,
27 | including but not limited to software source code, documentation
28 | source, and configuration files.
29 | 
30 | "Object" form shall mean any form resulting from mechanical
31 | transformation or translation of a Source form, including but
32 | not limited to compiled object code, generated documentation,
33 | and conversions to other media types.
34 | 
35 | "Work" shall mean the work of authorship, whether in Source or
36 | Object form, made available under the License, as indicated by a
37 | copyright notice that is included in or attached to the work
38 | (an example is provided in the Appendix below).
39 | 
40 | "Derivative Works" shall mean any work, whether in Source or Object
41 | form, that is based on (or derived from) the Work and for which the
42 | editorial revisions, annotations, elaborations, or other modifications
43 | represent, as a whole, an original work of authorship. For the purposes
44 | of this License, Derivative Works shall not include works that remain
45 | separable from, or merely link (or bind by name) to the interfaces of,
46 | the Work and Derivative Works thereof.
47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. 
You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. 
You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright 2020 Dan Fu 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 202 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 |
2 | <img src="figs/logo.png">
3 | 
4 | 
5 | # More Interactive Weak Supervision with FlyingSquid
6 | 
7 | **UPDATE 06/17/20**: Code refactored, with two new features (see the example snippet below):
8 | * Compute label model parameters by looking at all possible triplets and taking
9 | the mean or median; we find this to be more stable than just looking at a single
10 | triplet (use `label_model.fit(..., solve_method='triplet_mean')`).
11 | By default, the code now uses `triplet_mean`.
12 | * Get the estimated accuracies of each labeling function `P(lambda_i == Y)` with
13 | `label_model.estimated_accuracies()`.
14 | 
15 | FlyingSquid is a new framework for automatically building models from multiple
16 | noisy label sources.
17 | Users write functions that generate noisy labels for data, and FlyingSquid uses
18 | the agreements and disagreements between them to learn a _label model_ of how
19 | accurate the _labeling functions_ are.
20 | The label model can be used directly for downstream applications, or it can be
21 | used to train a powerful end model:
22 | 
23 | 
24 | <img src="figs/System Diagram.png">
25 | 
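For example, here is a minimal sketch of the two features from the update note above (it assumes `L_train` is an `(n, m)` matrix of labeling function outputs loaded from disk; substitute your own path):

```Python
from flyingsquid.label_model import LabelModel
import numpy as np

L_train = np.load('...')  # (n, m) matrix of labeling function outputs

label_model = LabelModel(L_train.shape[1])

# triplet_mean is the default solve method; it averages over all triplets
label_model.fit(L_train, solve_method='triplet_mean')

# Estimated P(lambda_i == Y) for each labeling function
print(label_model.estimated_accuracies())
```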
26 | 
27 | FlyingSquid can be used to build models for all sorts of tasks, including text
28 | applications, video analysis, and online learning.
29 | Check out our [blog post](http://hazyresearch.stanford.edu/flyingsquid) and paper on
30 | [arXiv](https://arxiv.org/abs/2002.11955)
31 | for more details!
32 | 
33 | ## Getting Started
34 | * Quickly [install](#installation) FlyingSquid
35 | * Check out the [examples](examples/) folder for tutorials and some simple code
36 | examples
37 | 
38 | ## Sample Usage
39 | ```Python
40 | from flyingsquid.label_model import LabelModel
41 | import numpy as np
42 | 
43 | L_train = np.load('...')
44 | 
45 | m = L_train.shape[1]
46 | label_model = LabelModel(m)
47 | label_model.fit(L_train)
48 | 
49 | preds = label_model.predict(L_train)
50 | ```
51 | 
52 | ## Installation
53 | 
54 | We recommend using `conda` to install FlyingSquid:
55 | 
56 | ```
57 | git clone https://github.com/HazyResearch/flyingsquid.git
58 | 
59 | cd flyingsquid
60 | 
61 | conda env create -f environment.yml
62 | conda activate flyingsquid
63 | ```
64 | 
65 | Alternatively, you can install the dependencies yourself:
66 | * [Pgmpy](http://pgmpy.org/)
67 | * [PyTorch](https://pytorch.org/) (only necessary for the PyTorch integration)
68 | 
69 | And then install the actual package:
70 | ```
71 | pip install flyingsquid
72 | ```
73 | 
74 | To install from source:
75 | ```
76 | git clone https://github.com/HazyResearch/flyingsquid.git
77 | 
78 | cd flyingsquid
79 | 
80 | conda env create -f environment.yml
81 | conda activate flyingsquid
82 | 
83 | pip install -e .
84 | ```
85 | 
86 | ## Citation
87 | 
88 | If you use our work or find it useful, please cite our [paper](https://arxiv.org/abs/2002.11955) at ICML 2020:
89 | ```
90 | @inproceedings{fu2020fast,
91 |   author = {Daniel Y. Fu and Mayee F. Chen and Frederic Sala and Sarah M. Hooper and Kayvon Fatahalian and Christopher R\'e},
92 |   title = {Fast and Three-rious: Speeding Up Weak Supervision with Triplet Methods},
93 |   booktitle = {Proceedings of the 37th International Conference on Machine Learning (ICML 2020)},
94 |   year = {2020},
95 | }
96 | ```
97 | 
--------------------------------------------------------------------------------
/environment.yml:
--------------------------------------------------------------------------------
1 | name: flyingsquid
2 | channels:
3 |   - pytorch
4 |   - conda-forge
5 | dependencies: # New core deps should also be added to setup.py for pip
6 |   - python=3.7
7 |   - dill
8 |   - networkx>=2.2
9 |   - numpy
10 |   - pandas
11 |   - pytorch>=1
12 |   - scipy
13 |   - tornado<6 # tornado v6 introduced a breaking dependency
14 |   - pip
15 |   - pip:
16 |     - tensorboardX==1.4
17 |     - scipy
18 |     - pyparsing
19 |     - statsmodels
20 |     - tqdm
21 |     - joblib
22 |     - pgmpy
23 |     - flyingsquid
24 |     - matplotlib>=3
25 |     - nose
26 |     - jupyter
27 |     - nb_conda_kernels
28 |     - runipy
29 |     - scikit-learn
30 |     - torchvision
31 |     - tensorboard
--------------------------------------------------------------------------------
/examples/01_basics.py:
--------------------------------------------------------------------------------
1 | '''
2 | This example shows the bare minimum needed to get FlyingSquid up and
3 | running.
4 | 
5 | It generates synthetic data from the tutorials folder, and trains up a label
6 | model.
7 | 
8 | You can run this file from the examples folder. 
9 | ''' 10 | 11 | from flyingsquid.label_model import LabelModel 12 | from tutorials.tutorial_helpers import * 13 | 14 | L_train, L_dev, Y_dev = synthetic_data_basics() 15 | 16 | m = L_train.shape[1] 17 | label_model = LabelModel(m) 18 | 19 | label_model.fit(L_train) 20 | 21 | preds = label_model.predict(L_dev).reshape(Y_dev.shape) 22 | accuracy = np.sum(preds == Y_dev) / Y_dev.shape[0] 23 | 24 | print('Label model accuracy: {}%'.format(int(100 * accuracy))) 25 | -------------------------------------------------------------------------------- /examples/02_video.py: -------------------------------------------------------------------------------- 1 | ''' 2 | This example code shows how to train a FlyingSquid label model for video data. 3 | 4 | It loads some labeling functions to detect Tennis Rallies from the tutorials 5 | folder, and trains up a label model. 6 | 7 | You can run this file from the examples folder. 8 | ''' 9 | 10 | from flyingsquid.label_model import LabelModel 11 | import numpy as np 12 | 13 | L_train = np.load('tutorials/L_train_video.npy') 14 | L_dev = np.load('tutorials/L_dev_video.npy') 15 | Y_dev = np.load('tutorials/Y_dev_video.npy') 16 | 17 | # Model three frames at a time 18 | v = 3 19 | 20 | # Six labeling functions per frame 21 | m_per_frame = 6 22 | 23 | # Total number of labeling functions is m_per_frame * v 24 | m = m_per_frame * v 25 | 26 | # Figure out how many sequences we're going to have 27 | n_frames_train = L_train.shape[0] 28 | n_frames_dev = L_dev.shape[0] 29 | 30 | n_seqs_train = n_frames_train // v 31 | n_seqs_dev = n_frames_dev // v 32 | 33 | # Resize and reshape matrices 34 | L_train_seqs = L_train[:n_seqs_train * v].reshape((n_seqs_train, m)) 35 | L_dev_seqs = L_dev[:n_seqs_dev * v].reshape((n_seqs_dev, m)) 36 | Y_dev_seqs = Y_dev[:n_seqs_dev * v].reshape((n_seqs_dev, v)) 37 | 38 | # Create the label model with temporal dependencies 39 | label_model = LabelModel( 40 | m, 41 | v = v, 42 | y_edges = [ (i, i + 1) for i in range(v - 1) ], 43 | lambda_y_edges = [ (i, i // m_per_frame) for i in range(m) ] 44 | ) 45 | 46 | label_model.fit(L_train_seqs) 47 | 48 | probabilistic_labels = label_model.predict_proba_marginalized(L_dev_seqs) 49 | preds = [ 1. if prob > 0.5 else -1. for prob in probabilistic_labels ] 50 | accuracy = np.sum(preds == Y_dev[:n_seqs_dev * v]) / (n_seqs_dev * v) 51 | 52 | print('Label model accuracy: {}%'.format(int(100 * accuracy))) 53 | -------------------------------------------------------------------------------- /examples/03_pytorch_integration.py: -------------------------------------------------------------------------------- 1 | ''' 2 | This example code shows how to use the PyTorch integration for online training 3 | (example data loaders and training loop). 4 | 5 | This code is only provided as a reference. In a real application, you would 6 | need to load in actual image paths to train over. 
7 | '''
8 | 
9 | from flyingsquid.label_model import LabelModel
10 | from flyingsquid.pytorch_loss import FSLoss
11 | import numpy as np
12 | import torch
13 | import torch.nn as nn
14 | import torch.optim as optim
15 | from torch.optim import lr_scheduler
16 | from PIL import Image
17 | import torchvision
18 | from torchvision import datasets, models, transforms
19 | from torch.utils.data import Dataset, DataLoader
20 | 
21 | # Load in the L matrices
22 | L_train = np.load('tutorials/L_train_video.npy')
23 | L_dev = np.load('tutorials/L_dev_video.npy')
24 | Y_dev = np.load('tutorials/Y_dev_video.npy')
25 | 
26 | # This is where you would load in the images corresponding to the rows
27 | X_paths_train = np.load('....')
28 | X_paths_dev = np.load('....')
29 | 
30 | # Example dataloader for FSLoss
31 | class ImageFSDataset(Dataset):
32 |     def __init__(self, paths, weak_labels, T, gt_labels=None,
33 |                  transform=None):
34 |         self.T = T
35 |         self.paths = paths[:len(paths) - (len(paths) % T)]
36 |         self.weak_labels = weak_labels[:weak_labels.shape[0] -
37 |                                        (weak_labels.shape[0] % T)]
38 |         m_per_task = self.weak_labels.shape[1]
39 | 
40 |         self.transform = transform
41 | 
42 |         n_frames = self.weak_labels.shape[0]
43 |         n_seqs = n_frames // T
44 | 
45 |         v = T
46 |         m = m_per_task * T
47 | 
48 |         self.data_temporal = {
49 |             'paths': np.reshape(self.paths, (n_seqs, v)),
50 |             'weak_labels': np.reshape(self.weak_labels, (n_seqs, m))
51 |         }
52 | 
53 |         self.gt_labels = gt_labels
54 |         if gt_labels is not None:
55 |             self.gt_labels = self.gt_labels[:len(self.gt_labels) -
56 |                                             (len(self.gt_labels) % T)]
57 |             self.data_temporal['gt_labels'] = np.reshape(self.gt_labels, (n_seqs, v))
58 | 
59 |     def __len__(self):
60 |         return self.data_temporal['paths'].shape[0]
61 | 
62 |     def __getitem__(self, idx):
63 |         paths_seq = self.data_temporal['paths'][idx]
64 | 
65 |         img_tensors = [
66 |             torch.unsqueeze(
67 |                 self.transform(Image.open(path).convert('RGB')),
68 |                 dim = 0)
69 |             for path in paths_seq
70 |         ]
71 | 
72 |         weak_labels = self.data_temporal['weak_labels'][idx]
73 | 
74 |         if self.gt_labels is not None:
75 |             return (torch.cat(img_tensors),
76 |                     torch.unsqueeze(torch.tensor(weak_labels), dim=0),
77 |                     torch.unsqueeze(torch.tensor(self.data_temporal['gt_labels'][idx]), dim = 0))
78 |         else:
79 |             return torch.cat(img_tensors), torch.unsqueeze(torch.tensor(weak_labels), dim = 0)
80 | 
81 | # Example training loop
82 | def train_model_online(model, T, criterion, optimizer, dataset):
83 |     model.train()
84 |     dataset_size = len(dataset) * T
85 | 
86 |     for item in dataset:
87 |         image_tensor = item[0]
88 |         weak_labels = item[1]
89 |         labels = None if dataset.gt_labels is None else item[2]
90 | 
91 |         # zero the parameter gradients
92 |         optimizer.zero_grad()
93 | 
94 |         # forward
95 |         with torch.set_grad_enabled(True):
96 |             outputs = model(image_tensor)
97 | 
98 |             loss = criterion(torch.unsqueeze(outputs, dim = 0), weak_labels)
99 | 
100 |             # backward + optimize
101 |             loss.backward()
102 |             optimizer.step()
103 | 
104 |     return model
105 | 
106 | # Model three frames at a time
107 | v = 3
108 | m_per_frame = 6  # six labeling functions per frame, as in 02_video.py
109 | # Set up the dataset
110 | train_dataset = ImageFSDataset(X_paths_train, L_train, v)
111 | m = m_per_frame * v  # total number of labeling functions per sequence
112 | # Set up the loss function
113 | fs_criterion = FSLoss(
114 |     m,
115 |     v = v,
116 |     y_edges = [ (i, i + 1) for i in range(v - 1) ],
117 |     lambda_y_edges = [ (i, i // m_per_frame) for i in range(m) ]
118 | )
119 | 
120 | # Train up a model online
121 | model = models.resnet50(pretrained=True)
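# The next two lines replace ResNet-50's ImageNet classification head with a
# single output unit, so the network emits one score per frame;
# train_model_online then compares these outputs to the weak labels through
# fs_criterion.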
122 | num_ftrs = model.fc.in_features
123 | model.fc = nn.Linear(num_ftrs, 1)
124 | 
125 | optimizer = optim.SGD(model.parameters(), lr=0.001, momentum=0.9)
126 | 
127 | model = train_model_online(model, v, fs_criterion, optimizer, train_dataset)
128 | 
--------------------------------------------------------------------------------
/examples/README.md:
--------------------------------------------------------------------------------
1 | # FlyingSquid Examples
2 | 
3 | This directory contains a few simple examples and tutorials to show how to use
4 | FlyingSquid.
5 | 
6 | ## Tutorials
7 | 
8 | We have two tutorials to help you get started:
9 | * [Basics.ipynb](tutorials/Basics.ipynb): a Jupyter notebook that uses
10 | some simple synthetic data to introduce FlyingSquid's API
11 | * [Video.ipynb](tutorials/Video.ipynb): a Jupyter notebook that shows how
12 | to use FlyingSquid to model sequential dependencies for applications like video
13 | 
14 | ## Examples
15 | 
16 | * [01_basics.py](01_basics.py): example code to show how to model labeling
17 | functions
18 | * [02_video.py](02_video.py): example code to show how to model labeling
19 | functions for video
20 | * [03_pytorch_integration.py](03_pytorch_integration.py): example code to show
21 | how to integrate FlyingSquid into PyTorch for online learning
22 | 
--------------------------------------------------------------------------------
/examples/tutorials/Basics.ipynb:
--------------------------------------------------------------------------------
1 | {
2 |  "cells": [
3 |   {
4 |    "cell_type": "markdown",
5 |    "metadata": {},
6 |    "source": [
7 |     "# FlyingSquid Basics\n",
8 |     "\n",
9 |     "In this notebook, we'll use some synthetic data to introduce you to FlyingSquid's API. We'll cover the three steps of the FlyingSquid pipeline:\n",
10 |     "\n",
11 |     "<img src=\"../../figs/System Diagram.png\">
\n", 12 | " \n", 13 | "
\n", 14 | "\n", 15 | "First, we'll generate some synthetic labeling function outputs. Next, we'll use FlyingSquid to model the accuracies of these labeling functions (without any ground truth data). Finally, we'll generate probabilistic training labels for downstream model training." 16 | ] 17 | }, 18 | { 19 | "cell_type": "markdown", 20 | "metadata": {}, 21 | "source": [ 22 | "## Step 1: Generate Synthetic Labeling Function Outputs\n", 23 | "\n", 24 | "Let's generate some synthetic labeling function outputs.\n", 25 | "\n", 26 | "For a real application, we would write `m` labeling functions that would generate any of the three following labels for each data point:\n", 27 | "\n", 28 | "* Positive: return +1\n", 29 | "* Negative: return -1\n", 30 | "* Abstain: return 0\n", 31 | "\n", 32 | "We would run the `m` labeling functions over `n` data points to get an `(n, m)`-sized matrix. For this tutorial, the `synthetic_data_basics` function will do that for us:" 33 | ] 34 | }, 35 | { 36 | "cell_type": "code", 37 | "execution_count": 1, 38 | "metadata": {}, 39 | "outputs": [ 40 | { 41 | "name": "stdout", 42 | "output_type": "stream", 43 | "text": [ 44 | "(10000, 5)\n", 45 | "(500, 5)\n", 46 | "(500,)\n" 47 | ] 48 | } 49 | ], 50 | "source": [ 51 | "from tutorial_helpers import *\n", 52 | "L_train, L_dev, Y_dev = synthetic_data_basics()\n", 53 | "print(L_train.shape)\n", 54 | "print(L_dev.shape)\n", 55 | "print(Y_dev.shape)" 56 | ] 57 | }, 58 | { 59 | "cell_type": "markdown", 60 | "metadata": {}, 61 | "source": [ 62 | "As you can see, we have five synthetic labeling functions that have generated labels for an unlabeled training set with 10,000 data points, and a labeled dev set with 500 labeled data points." 63 | ] 64 | }, 65 | { 66 | "cell_type": "markdown", 67 | "metadata": {}, 68 | "source": [ 69 | "We can use the dev set to see how accurate our labeling functions are:" 70 | ] 71 | }, 72 | { 73 | "cell_type": "code", 74 | "execution_count": 2, 75 | "metadata": {}, 76 | "outputs": [ 77 | { 78 | "name": "stdout", 79 | "output_type": "stream", 80 | "text": [ 81 | "LF 0: Accuracy 93%, Abstain rate 78%\n", 82 | "LF 1: Accuracy 63%, Abstain rate 87%\n", 83 | "LF 2: Accuracy 62%, Abstain rate 30%\n", 84 | "LF 3: Accuracy 59%, Abstain rate 37%\n", 85 | "LF 4: Accuracy 46%, Abstain rate 48%\n" 86 | ] 87 | } 88 | ], 89 | "source": [ 90 | "print_statistics(L_dev, Y_dev) " 91 | ] 92 | }, 93 | { 94 | "cell_type": "markdown", 95 | "metadata": {}, 96 | "source": [ 97 | "As you can see, we have two labeling functions that have high accuracies but also high abstain rates (LF 0 and LF 1), and three labeling functions with lower abstain rates but also lower accuracies. We can inspect the `L_dev` and `Y_dev` matrices to see the data formats:" 98 | ] 99 | }, 100 | { 101 | "cell_type": "code", 102 | "execution_count": 3, 103 | "metadata": {}, 104 | "outputs": [ 105 | { 106 | "name": "stdout", 107 | "output_type": "stream", 108 | "text": [ 109 | "[[ 0. 0. -1. 1. 1.]\n", 110 | " [ 0. 0. -1. 1. 1.]\n", 111 | " [ 0. 0. -1. -1. 0.]\n", 112 | " [ 0. 0. -1. -1. -1.]\n", 113 | " [ 0. 0. -1. 0. 0.]\n", 114 | " [ 0. 0. 1. 1. 0.]\n", 115 | " [ 0. 0. 1. 1. 0.]\n", 116 | " [ 0. 0. -1. 0. 0.]\n", 117 | " [-1. 0. 0. 1. 1.]\n", 118 | " [ 0. 0. 1. -1. 
0.]]\n" 119 | ] 120 | } 121 | ], 122 | "source": [ 123 | "print(L_dev[:10])" 124 | ] 125 | }, 126 | { 127 | "cell_type": "code", 128 | "execution_count": 4, 129 | "metadata": {}, 130 | "outputs": [ 131 | { 132 | "name": "stdout", 133 | "output_type": "stream", 134 | "text": [ 135 | "[ 1. -1. -1. -1. -1. 1. -1. -1. -1. 1.]\n" 136 | ] 137 | } 138 | ], 139 | "source": [ 140 | "print(Y_dev[:10])" 141 | ] 142 | }, 143 | { 144 | "cell_type": "markdown", 145 | "metadata": {}, 146 | "source": [ 147 | "## Step 2: Model the labeling functions with FlyingSquid\n", 148 | "\n", 149 | "Next, we're going to use FlyingSquid to model the five labeling functions. We'll use this dependency graph:\n", 150 | "\n", 151 | "
\n", 152 | " \n", 153 | "
\n", 154 | "\n", 155 | "As you can see, we have one (hidden) node for the latent ground truth variable Y, and five (observable) nodes for each labeling function.\n", 156 | "\n", 157 | "To model that in FlyingSquid, we just need to specify that we have `m = 5` labeling functions. Since we only have a single task, the dependencies are automatically inferred (see the video tutorial for more complex dependencies)." 158 | ] 159 | }, 160 | { 161 | "cell_type": "code", 162 | "execution_count": 5, 163 | "metadata": {}, 164 | "outputs": [], 165 | "source": [ 166 | "from flyingsquid.label_model import LabelModel" 167 | ] 168 | }, 169 | { 170 | "cell_type": "code", 171 | "execution_count": 6, 172 | "metadata": {}, 173 | "outputs": [], 174 | "source": [ 175 | "m = 5\n", 176 | "\n", 177 | "label_model = LabelModel(m)" 178 | ] 179 | }, 180 | { 181 | "cell_type": "markdown", 182 | "metadata": {}, 183 | "source": [ 184 | "To train the label model, all we need to do is pass `L_train` to the fit function:" 185 | ] 186 | }, 187 | { 188 | "cell_type": "code", 189 | "execution_count": 7, 190 | "metadata": {}, 191 | "outputs": [], 192 | "source": [ 193 | "label_model.fit(L_train)" 194 | ] 195 | }, 196 | { 197 | "cell_type": "markdown", 198 | "metadata": {}, 199 | "source": [ 200 | "### Evaluating the label model\n", 201 | "\n", 202 | "Now, let's use the dev set to evaluate the label model:" 203 | ] 204 | }, 205 | { 206 | "cell_type": "code", 207 | "execution_count": 8, 208 | "metadata": {}, 209 | "outputs": [ 210 | { 211 | "name": "stdout", 212 | "output_type": "stream", 213 | "text": [ 214 | "Label model accuracy: 70%\n" 215 | ] 216 | } 217 | ], 218 | "source": [ 219 | "preds = label_model.predict(L_dev).reshape(Y_dev.shape)\n", 220 | "accuracy = np.sum(preds == Y_dev) / Y_dev.shape[0]\n", 221 | "\n", 222 | "print('Label model accuracy: {}%'.format(int(100 * accuracy)))" 223 | ] 224 | }, 225 | { 226 | "cell_type": "markdown", 227 | "metadata": {}, 228 | "source": [ 229 | "We can see that this performs better than majority vote:" 230 | ] 231 | }, 232 | { 233 | "cell_type": "code", 234 | "execution_count": 9, 235 | "metadata": {}, 236 | "outputs": [ 237 | { 238 | "name": "stdout", 239 | "output_type": "stream", 240 | "text": [ 241 | "Majority vote accuracy: 65%\n" 242 | ] 243 | } 244 | ], 245 | "source": [ 246 | "majority_vote_preds = np.array([1 if pred > 0 else -1 for pred in np.sum(L_dev, axis=1)])\n", 247 | "majority_vote_accuracy = np.sum(majority_vote_preds == Y_dev) / Y_dev.shape[0]\n", 248 | "\n", 249 | "print('Majority vote accuracy: {}%'.format(int(100 * majority_vote_accuracy)))" 250 | ] 251 | }, 252 | { 253 | "cell_type": "markdown", 254 | "metadata": {}, 255 | "source": [ 256 | "## Step 3: Training an End Model\n", 257 | "\n", 258 | "If necessary, we can also use FlyingSquid to generate probabilistic labels to train up an end model. 
Instead of calling the `predict` function, we can call `predict_proba_marginalized` over `L_train`:" 259 | ] 260 | }, 261 | { 262 | "cell_type": "code", 263 | "execution_count": 10, 264 | "metadata": {}, 265 | "outputs": [ 266 | { 267 | "name": "stdout", 268 | "output_type": "stream", 269 | "text": [ 270 | "(10000,)\n", 271 | "[0.46439535 0.89805256 0.72736331 0.48237588 0.2962007 0.2633458\n", 272 | " 0.66693893 0.53600092 0.72736331 0.3213108 ]\n" 273 | ] 274 | } 275 | ], 276 | "source": [ 277 | "probabilistic_labels = label_model.predict_proba_marginalized(L_train)\n", 278 | "print(probabilistic_labels.shape)\n", 279 | "print(probabilistic_labels[:10])" 280 | ] 281 | }, 282 | { 283 | "cell_type": "markdown", 284 | "metadata": {}, 285 | "source": [ 286 | "These labels can be used as training labels for a powerful downstream end model (pick your favorite deep network). We can generate labels for every single data point, even in the presence of multiple abstains from the labeling functions." 287 | ] 288 | } 289 | ], 290 | "metadata": { 291 | "kernelspec": { 292 | "display_name": "Python [conda env:flyingsquid]", 293 | "language": "python", 294 | "name": "conda-env-flyingsquid-py" 295 | }, 296 | "language_info": { 297 | "codemirror_mode": { 298 | "name": "ipython", 299 | "version": 3 300 | }, 301 | "file_extension": ".py", 302 | "mimetype": "text/x-python", 303 | "name": "python", 304 | "nbconvert_exporter": "python", 305 | "pygments_lexer": "ipython3", 306 | "version": "3.7.6" 307 | } 308 | }, 309 | "nbformat": 4, 310 | "nbformat_minor": 2 311 | } 312 | -------------------------------------------------------------------------------- /examples/tutorials/L_dev_video.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/HazyResearch/flyingsquid/28a713a9ac501b7597c2489468ae189943d00685/examples/tutorials/L_dev_video.npy -------------------------------------------------------------------------------- /examples/tutorials/L_train_video.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/HazyResearch/flyingsquid/28a713a9ac501b7597c2489468ae189943d00685/examples/tutorials/L_train_video.npy -------------------------------------------------------------------------------- /examples/tutorials/Video.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "# FlyingSquid for Video\n", 8 | "\n", 9 | "In this notebook, we'll use FlyingSquid to train a label model for sequential video data. In this application, we'll be using broadcast tennis footage to detect tennis rallies (when the two players are continuously hitting the ball back and forth).\n", 10 | "\n", 11 | "\n", 12 | "\n", 13 | " \n", 14 | " \n", 15 | " \n", 16 | " \n", 17 | " \n", 18 | " \n", 19 | " \n", 20 | " \n", 21 | "
Rally Segment | Non-Rally Segment (see figs/tennis_rally.png and figs/tennis_nonrally.png)
" 22 | ] 23 | }, 24 | { 25 | "cell_type": "markdown", 26 | "metadata": {}, 27 | "source": [ 28 | "## Step 1: Load Labeling Function Outputs\n", 29 | "Again, we'll start by loading labeling function outputs. This time, we'll load in the outputs from actual labeling functions." 30 | ] 31 | }, 32 | { 33 | "cell_type": "code", 34 | "execution_count": 1, 35 | "metadata": {}, 36 | "outputs": [ 37 | { 38 | "name": "stdout", 39 | "output_type": "stream", 40 | "text": [ 41 | "(6959, 6)\n", 42 | "(746, 6)\n", 43 | "(746,)\n" 44 | ] 45 | } 46 | ], 47 | "source": [ 48 | "import numpy as np\n", 49 | "from tutorial_helpers import *\n", 50 | "\n", 51 | "L_train = np.load('L_train_video.npy')\n", 52 | "L_dev = np.load('L_dev_video.npy')\n", 53 | "Y_dev = np.load('Y_dev_video.npy')\n", 54 | "\n", 55 | "print(L_train.shape)\n", 56 | "print(L_dev.shape)\n", 57 | "print(Y_dev.shape)" 58 | ] 59 | }, 60 | { 61 | "cell_type": "markdown", 62 | "metadata": {}, 63 | "source": [ 64 | "In our L and Y matrices, each row represents a single frame in the video. Since rallies occur over contiguous frames, notice that the ground truth annotations have contiguous segments of the same label:" 65 | ] 66 | }, 67 | { 68 | "cell_type": "code", 69 | "execution_count": 2, 70 | "metadata": {}, 71 | "outputs": [ 72 | { 73 | "data": { 74 | "text/plain": [ 75 | "array([ 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,\n", 76 | " 1, 1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1,\n", 77 | " -1, -1, -1, -1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,\n", 78 | " 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, -1, -1,\n", 79 | " -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1,\n", 80 | " -1, -1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1])" 81 | ] 82 | }, 83 | "execution_count": 2, 84 | "metadata": {}, 85 | "output_type": "execute_result" 86 | } 87 | ], 88 | "source": [ 89 | "Y_dev[:100]" 90 | ] 91 | }, 92 | { 93 | "cell_type": "markdown", 94 | "metadata": {}, 95 | "source": [ 96 | "These labeling functions were written using [Rekall queries](https://github.com/scanner-research/rekall) and express heuristics like the number of people on court, the size of the people, and the number of near-white pixels. Notice that abstain rates tend to be much lower when using Rekall queries, since many queries automatically label the whole video:" 97 | ] 98 | }, 99 | { 100 | "cell_type": "code", 101 | "execution_count": 3, 102 | "metadata": {}, 103 | "outputs": [ 104 | { 105 | "name": "stdout", 106 | "output_type": "stream", 107 | "text": [ 108 | "LF 0: Accuracy 87%, Abstain rate 0%\n", 109 | "LF 1: Accuracy 89%, Abstain rate 0%\n", 110 | "LF 2: Accuracy 57%, Abstain rate 85%\n", 111 | "LF 3: Accuracy 84%, Abstain rate 39%\n", 112 | "LF 4: Accuracy 86%, Abstain rate 0%\n", 113 | "LF 5: Accuracy 62%, Abstain rate 60%\n" 114 | ] 115 | } 116 | ], 117 | "source": [ 118 | "print_statistics(L_dev, Y_dev)" 119 | ] 120 | }, 121 | { 122 | "cell_type": "markdown", 123 | "metadata": {}, 124 | "source": [ 125 | "## Step 2: Model the labeling functions (and temporal dependencies) with FlyingSquid\n", 126 | "\n", 127 | "Next, we're going to use FlyingSquid to model the labeling functions. But we're going to end up doing something slightly more complicated in order to model temporal dependencies.\n", 128 | "\n", 129 | "We'll model three frames at a time with three hidden variables, and model each labeling function labeling an individual frame in the sequence:\n", 130 | "\n", 131 | "
\n", 132 | " \n", 133 | "
\n", 134 | "\n", 135 | "In a given sequence of three frames, `lambda_0`, `lambda_6`, and `lambda_12` model LF 0's outputs on the first, second, and third frames, respectively. Similarly, `lambda_1`, `lambda_7`, and `lambda_13` model LF 1's outputs on the first, second, and third frames.\n", 136 | "\n", 137 | "Our first step is resizing and reshaping our matrices:" 138 | ] 139 | }, 140 | { 141 | "cell_type": "code", 142 | "execution_count": 4, 143 | "metadata": {}, 144 | "outputs": [], 145 | "source": [ 146 | "# We use v to denote the length of sequence we're modeling\n", 147 | "v = 3\n", 148 | "\n", 149 | "# Six labeling functions per frame\n", 150 | "m_per_frame = 6\n", 151 | "\n", 152 | "# Total number of labeling functions is m_per_frame * v\n", 153 | "m = m_per_frame * v\n", 154 | "\n", 155 | "# Figure out how many sequences we're going to have\n", 156 | "n_frames_train = L_train.shape[0]\n", 157 | "n_frames_dev = L_dev.shape[0]\n", 158 | "\n", 159 | "n_seqs_train = n_frames_train // v\n", 160 | "n_seqs_dev = n_frames_dev // v\n", 161 | "\n", 162 | "# Resize and reshape matrices\n", 163 | "L_train_seqs = L_train[:n_seqs_train * v].reshape((n_seqs_train, m))\n", 164 | "L_dev_seqs = L_dev[:n_seqs_dev * v].reshape((n_seqs_dev, m))\n", 165 | "Y_dev_seqs = Y_dev[:n_seqs_dev * v].reshape((n_seqs_dev, v))" 166 | ] 167 | }, 168 | { 169 | "cell_type": "markdown", 170 | "metadata": {}, 171 | "source": [ 172 | "Next, we'll use FlyingSquid to model this dependency structure:" 173 | ] 174 | }, 175 | { 176 | "cell_type": "code", 177 | "execution_count": 5, 178 | "metadata": {}, 179 | "outputs": [], 180 | "source": [ 181 | "from flyingsquid.label_model import LabelModel" 182 | ] 183 | }, 184 | { 185 | "cell_type": "code", 186 | "execution_count": 6, 187 | "metadata": {}, 188 | "outputs": [], 189 | "source": [ 190 | "label_model = LabelModel(\n", 191 | " m,\n", 192 | " v = v,\n", 193 | " y_edges = [\n", 194 | " (0, 1), (1, 2)\n", 195 | " ],\n", 196 | " lambda_y_edges = [\n", 197 | " (0, 0), (1, 0), (2, 0), (3, 0), (4, 0), (5, 0),\n", 198 | " (6, 1), (7, 1), (8, 1), (9, 1), (10, 1), (11, 1),\n", 199 | " (12, 2), (13, 2), (14, 2), (15, 2), (16, 2), (17, 2),\n", 200 | " ]\n", 201 | ")" 202 | ] 203 | }, 204 | { 205 | "cell_type": "markdown", 206 | "metadata": {}, 207 | "source": [ 208 | "Let's explain each argument that we just passed to `LabelModel`:\n", 209 | "* Our first argument is still `m`, the total number of labeling functions.\n", 210 | "* The second argument, `v`, specifies how many sequences we're modeling.\n", 211 | "* The third argument, `y_edges`, specifies (non-directional) edges between the hidden variables. Each pair in the array specifies an edge; in this case, we are specifying edges between `Y_0` and `Y_1`, and between `Y_1` and `Y_2`.\n", 212 | "* The fourth argument, `lambda_y_edges`, specifies (non-directional) edges between observable variables and hidden variables. 
In this case, each pair in the array specifies an edge by using the first item to index into the observable variables, and using the second item to index into the hidden variables.\n", "\n", "Now that we understand what's going on, we can actually express this in fewer lines of code:" ] }, { "cell_type": "code", "execution_count": 7, "metadata": {}, "outputs": [], "source": [ "# A simpler way to build the label model\n", "label_model = LabelModel(\n", "    m,\n", "    v = v,\n", "    y_edges = [ (i, i + 1) for i in range(v - 1) ],\n", "    lambda_y_edges = [ (i, i // m_per_frame) for i in range(m) ]\n", ")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "The above code will work for most video tasks!" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "To train the label model, all we need to do is pass `L_train_seqs` to the fit function:" ] }, { "cell_type": "code", "execution_count": 8, "metadata": {}, "outputs": [], "source": [ "label_model.fit(L_train_seqs)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Evaluating the label model\n", "\n", "Now, let's use the dev set to evaluate the label model.\n", "\n", "Since we are now modeling sequences, we want to use the function `predict_proba_marginalized` to get predictions for individual frames:" ] }, { "cell_type": "code", "execution_count": 9, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Label model accuracy: 88%\n" ] } ], "source": [ "probabilistic_labels = label_model.predict_proba_marginalized(L_dev_seqs)\n", "preds = [ 1. if prob > 0.5 else -1. for prob in probabilistic_labels ]\n", "accuracy = np.sum(preds == Y_dev[:n_seqs_dev * v]) / (n_seqs_dev * v)\n", "\n", "print('Label model accuracy: {}%'.format(int(100 * accuracy)))" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Step 3: Training an End Model\n", "If necessary, we can use the probabilistic labels to train up an end model. 
All we need to do is call `predict_proba_marginalized` over `L_train_seqs`:" 293 | ] 294 | }, 295 | { 296 | "cell_type": "code", 297 | "execution_count": 10, 298 | "metadata": {}, 299 | "outputs": [ 300 | { 301 | "name": "stdout", 302 | "output_type": "stream", 303 | "text": [ 304 | "(6957,)\n", 305 | "[9.95970565e-01 9.45002920e-01 9.61349604e-01 9.95970565e-01\n", 306 | " 9.45002920e-01 9.98849719e-01 9.99250640e-01 9.98525418e-01\n", 307 | " 9.96477675e-01 9.99953165e-01 9.99213851e-01 7.91527313e-01\n", 308 | " 9.99991330e-01 9.82809821e-01 9.98691913e-01 9.99944174e-01\n", 309 | " 9.98525418e-01 9.99930068e-01 9.99998119e-01 9.99658311e-01\n", 310 | " 9.99930068e-01 9.99998119e-01 9.99658311e-01 9.99930068e-01\n", 311 | " 9.99998119e-01 9.99658311e-01 9.23526306e-01 1.15346375e-03\n", 312 | " 1.28974667e-03 2.06006806e-04 1.15346375e-03 1.74129763e-03\n", 313 | " 2.34298998e-04 1.15346375e-03 8.54307705e-03 7.48697405e-03\n", 314 | " 5.71032053e-03 8.54307705e-03 7.48697405e-03 9.99250640e-01\n", 315 | " 9.98525418e-01 2.18199807e-01 9.99991330e-01 9.87216252e-01\n", 316 | " 9.96477675e-01 9.99991330e-01 9.98525418e-01 9.95995757e-01\n", 317 | " 9.99106918e-01 9.82809821e-01 9.98691913e-01 9.99944174e-01\n", 318 | " 6.33889805e-03 6.58867471e-03 4.79508961e-03 9.86709989e-01\n", 319 | " 9.96477675e-01 9.99992726e-01 9.99213851e-01 9.98849719e-01\n", 320 | " 9.99250640e-01 1.74129763e-03 2.34298998e-04 1.15346375e-03\n", 321 | " 1.74129763e-03 2.34298998e-04 1.15346375e-03 9.98938427e-01\n", 322 | " 9.13925457e-01 9.99953165e-01 9.82809821e-01 9.96477675e-01\n", 323 | " 9.99992726e-01 9.99213851e-01 9.98849719e-01 9.99992726e-01\n", 324 | " 9.98009283e-01 9.96477675e-01 9.99106918e-01 9.98938427e-01\n", 325 | " 9.95995757e-01 9.99944174e-01 9.98009283e-01 2.33811138e-03\n", 326 | " 9.67879509e-04 1.28974667e-03 2.34298998e-04 1.15346375e-03\n", 327 | " 1.74129763e-03 2.34298998e-04 1.15346375e-03 1.74129763e-03\n", 328 | " 9.23526306e-01 9.99250640e-01 9.99213851e-01 9.98849719e-01\n", 329 | " 9.99992726e-01 9.99213851e-01 6.67422449e-01 3.18205893e-01]\n" 330 | ] 331 | } 332 | ], 333 | "source": [ 334 | "probabilistic_labels = label_model.predict_proba_marginalized(L_train_seqs)\n", 335 | "print(probabilistic_labels.shape)\n", 336 | "print(probabilistic_labels[:100])" 337 | ] 338 | }, 339 | { 340 | "cell_type": "markdown", 341 | "metadata": {}, 342 | "source": [ 343 | "## Bonus: Dependencies between labeling functions\n", 344 | "\n", 345 | "Now that we know how to specify dependencies manually, it's a simple step to specify dependencies between labeling functions:\n", 346 | "\n", 347 | "
\n", 348 | " \n", 349 | "
\n", 350 | "\n", 351 | "All you have to do is pass in an extra argument, `lambda_edges`, that specifies edges between observable variables. For example, you can specify the above graph like this:" 352 | ] 353 | }, 354 | { 355 | "cell_type": "code", 356 | "execution_count": 11, 357 | "metadata": {}, 358 | "outputs": [], 359 | "source": [ 360 | "# Video Label model with dependencies between lambda_0 and lambda_1\n", 361 | "label_model = LabelModel(\n", 362 | " m,\n", 363 | " v = v,\n", 364 | " y_edges = [ (i, i + 1) for i in range(v - 1) ],\n", 365 | " lambda_y_edges = [ (i, i // m_per_frame) for i in range(m) ],\n", 366 | " lambda_edges = [ (0, 1) ]\n", 367 | ")" 368 | ] 369 | } 370 | ], 371 | "metadata": { 372 | "kernelspec": { 373 | "display_name": "Python [conda env:flyingsquid]", 374 | "language": "python", 375 | "name": "conda-env-flyingsquid-py" 376 | }, 377 | "language_info": { 378 | "codemirror_mode": { 379 | "name": "ipython", 380 | "version": 3 381 | }, 382 | "file_extension": ".py", 383 | "mimetype": "text/x-python", 384 | "name": "python", 385 | "nbconvert_exporter": "python", 386 | "pygments_lexer": "ipython3", 387 | "version": "3.7.6" 388 | } 389 | }, 390 | "nbformat": 4, 391 | "nbformat_minor": 2 392 | } 393 | -------------------------------------------------------------------------------- /examples/tutorials/Y_dev_video.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/HazyResearch/flyingsquid/28a713a9ac501b7597c2489468ae189943d00685/examples/tutorials/Y_dev_video.npy -------------------------------------------------------------------------------- /examples/tutorials/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/HazyResearch/flyingsquid/28a713a9ac501b7597c2489468ae189943d00685/examples/tutorials/__init__.py -------------------------------------------------------------------------------- /examples/tutorials/tutorial_helpers.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from numpy.random import seed, rand 3 | import itertools 4 | 5 | def exponential_family (lam, y, theta, theta_y): 6 | # without normalization 7 | return np.exp(theta_y * y + y * np.dot(theta, lam)) 8 | 9 | # create vector describing cumulative distribution of lambda_1, ... lambda_m, Y 10 | def make_pdf(m, v, theta, theta_y, lst): 11 | p = np.zeros(len(lst)) 12 | for i in range(len(lst)): 13 | labels = lst[i] 14 | p[i] = exponential_family(labels[0:m], labels[v-1], theta, theta_y) 15 | 16 | return p/sum(p) 17 | 18 | def make_cdf(pdf): 19 | return np.cumsum(pdf) 20 | 21 | # draw a set of lambda_1, ... 
lambda_m, Y based on the distribution 22 | def sample(lst, cdf): 23 | r = np.random.random_sample() 24 | smaller = np.where(cdf < r)[0] 25 | if len(smaller) == 0: 26 | i = 0 27 | else: 28 | i = smaller.max() + 1 29 | return lst[i] 30 | 31 | def generate_data(n, theta, m, theta_y=0): 32 | v = m+1 33 | 34 | lst = list(map(list, itertools.product([-1, 1], repeat=v))) 35 | pdf = make_pdf(m, v, theta, theta_y, lst) 36 | cdf = make_cdf(pdf) 37 | 38 | sample_matrix = np.zeros((n,v)) 39 | for i in range(n): 40 | sample_matrix[i,:] = sample(lst,cdf) 41 | 42 | return sample_matrix 43 | 44 | def synthetic_data_basics(): 45 | seed(0) 46 | 47 | n_train = 10000 48 | n_dev = 500 49 | 50 | m = 5 51 | theta = [1.5,.5,.2,.2,.05] 52 | abstain_rate = [.8, .88, .28, .38, .45] 53 | 54 | train_data = generate_data(n_train, theta, m) 55 | dev_data = generate_data(n_dev, theta, m) 56 | 57 | L_train = train_data[:,:-1] 58 | L_dev = dev_data[:,:-1] 59 | Y_dev = dev_data[:,-1] 60 | 61 | train_values = rand(n_train * m).reshape(L_train.shape) 62 | dev_values = rand(n_dev * m).reshape(L_dev.shape) 63 | 64 | L_train[train_values < (abstain_rate,) * n_train] = 0 65 | L_dev[dev_values < (abstain_rate,) * n_dev] = 0 66 | 67 | return L_train, L_dev, Y_dev 68 | 69 | def print_statistics(L_dev, Y_dev): 70 | m = L_dev.shape[1] 71 | 72 | for i in range(m): 73 | acc = np.sum(L_dev[:,i] == Y_dev)/np.sum(L_dev[:,i] != 0) 74 | abstains = np.sum(L_dev[:,i] == 0)/Y_dev.shape[0] 75 | 76 | print('LF {}: Accuracy {}%, Abstain rate {}%'.format( 77 | i, int(acc * 100), int((abstains) * 100))) -------------------------------------------------------------------------------- /figs/System Diagram.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/HazyResearch/flyingsquid/28a713a9ac501b7597c2489468ae189943d00685/figs/System Diagram.png -------------------------------------------------------------------------------- /figs/graphical_structure_simple.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/HazyResearch/flyingsquid/28a713a9ac501b7597c2489468ae189943d00685/figs/graphical_structure_simple.png -------------------------------------------------------------------------------- /figs/graphical_structure_video.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/HazyResearch/flyingsquid/28a713a9ac501b7597c2489468ae189943d00685/figs/graphical_structure_video.png -------------------------------------------------------------------------------- /figs/graphical_structure_video_lambda_dep.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/HazyResearch/flyingsquid/28a713a9ac501b7597c2489468ae189943d00685/figs/graphical_structure_video_lambda_dep.png -------------------------------------------------------------------------------- /figs/logo.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/HazyResearch/flyingsquid/28a713a9ac501b7597c2489468ae189943d00685/figs/logo.png -------------------------------------------------------------------------------- /figs/tennis_nonrally.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/HazyResearch/flyingsquid/28a713a9ac501b7597c2489468ae189943d00685/figs/tennis_nonrally.png 
-------------------------------------------------------------------------------- /figs/tennis_rally.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/HazyResearch/flyingsquid/28a713a9ac501b7597c2489468ae189943d00685/figs/tennis_rally.png -------------------------------------------------------------------------------- /flyingsquid/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/HazyResearch/flyingsquid/28a713a9ac501b7597c2489468ae189943d00685/flyingsquid/__init__.py -------------------------------------------------------------------------------- /flyingsquid/_graphs.py: -------------------------------------------------------------------------------- 1 | from pgmpy.models import MarkovModel 2 | from pgmpy.factors.discrete import JointProbabilityDistribution, DiscreteFactor 3 | from itertools import combinations 4 | from flyingsquid.helpers import * 5 | import numpy as np 6 | import math 7 | from tqdm import tqdm 8 | import sys 9 | import random 10 | 11 | class Mixin: 12 | ''' 13 | Functions to check whether we can solve this graph structure. 14 | ''' 15 | 16 | def _is_separator(self, srcSet, dstSet, separatorSet): 17 | '''Check if separatorSet separates srcSet from dstSet. 18 | 19 | Tries to find a path from some node in srcSet to some node in dstSet that doesn't 20 | pass through separatorSet. If successful, return False. Otherwise, return True. 21 | ''' 22 | def neighbors(node): 23 | neighbor_set = set() 24 | for edge in self.G.edges: 25 | if edge[0] == node: 26 | neighbor_set.add(edge[1]) 27 | if edge[1] == node: 28 | neighbor_set.add(edge[0]) 29 | return list(neighbor_set) 30 | 31 | visited = set() 32 | for srcNode in srcSet: 33 | if srcNode in dstSet: 34 | return False 35 | queue = [srcNode] 36 | 37 | curNode = srcNode 38 | 39 | while len(queue) > 0: 40 | curNode = queue.pop() 41 | if curNode not in visited: 42 | visited.add(curNode) 43 | else: 44 | continue 45 | 46 | for neighbor in neighbors(curNode): 47 | if neighbor == srcNode: 48 | continue 49 | if neighbor in dstSet: 50 | return False 51 | if neighbor in separatorSet: 52 | continue 53 | if neighbor not in visited: 54 | queue.append(neighbor) 55 | 56 | return True 57 | 58 | 59 | def _check(self): 60 | '''Check to make sure we can solve this. 61 | 62 | Checks: 63 | * For each node or separator set in the junction tree: 64 | There is either only one Y node in the clique, or the clique is made up entirely of Y nodes, since 65 | we can only estimate marginals where there is at most one Y, unless the entire marginal is 66 | made up of Y's) 67 | * For each node or separator set in the junction tree that contains at least one 68 | lambda node and exactly one Y node: 69 | The Y node separates the lambda's from at least two other lambda nodes, that are themselves 70 | separated by Y. To estimate the marginal mu(lambda_i, ..., lambda_j, Y_k), we need to find 71 | lambda_a, lambda_b such that lambda_a, lambda_b, and the joint (lambda_i, ..., lambda_j) are 72 | independent conditioned on Y_k. This amounts to Y_k separating lambda_a, lambda_b, and 73 | (lambda_i, ..., lambda_j). Note that lambda_i, ..., lambda_j do not have to be separated by Y_k. 74 | 75 | Outputs: True if we can solve this, False otherwise. 
76 |         '''
77 |         def num_Ys(nodes):
78 |             return len([
79 |                 node for node in nodes if 'Y' in node
80 |             ])
81 | 
82 |         def num_lambdas(nodes):
83 |             return len([
84 |                 node for node in nodes if 'lambda' in node
85 |             ])
86 | 
87 |         def estimatable_clique(clique):
88 |             y_count = num_Ys(clique)
89 |             lambda_count = num_lambdas(clique)
90 | 
91 |             return y_count <= 1 or lambda_count == 0
92 | 
93 |         for clique in self.junction_tree.nodes:
94 |             if not estimatable_clique(clique):
95 |                 return False, "We can't estimate {}!".format(clique)
96 | 
97 |         for separator_set in self.separator_sets:
98 |             if not estimatable_clique(separator_set):
99 |                 return False, "We can't estimate {}!".format(separator_set)
100 | 
101 |         # for each marginal we need to estimate, check if there is a valid triplet
102 |         marginals = sorted(list(self.junction_tree.nodes) + list(self.separator_sets))
103 |         for marginal in marginals:
104 |             y_count = num_Ys(marginal)
105 |             lambda_count = num_lambdas(marginal)
106 | 
107 |             if y_count != 1:
108 |                 continue
109 |             if lambda_count == 0:
110 |                 continue
111 | 
112 |             separator_y = [node for node in marginal if 'Y' in node]
113 |             lambdas = [node for node in marginal if 'lambda' in node]
114 | 
115 |             found = False
116 |             for first_node in self.nodes:
117 |                 if 'Y' in first_node or first_node in lambdas:
118 |                     continue
119 |                 for second_node in self.nodes:
120 |                     if 'Y' in second_node or second_node in lambdas:
121 |                         continue
122 | 
123 |                     if (self._is_separator(lambdas, [first_node], separator_y) and
124 |                         self._is_separator(lambdas, [second_node], separator_y) and
125 |                         self._is_separator([first_node], [second_node], separator_y)):
126 |                         found = True
127 |                         break
128 |                 if found:
129 |                     break
130 | 
131 |             if not found:
132 |                 print('Could not find triplet for {}!'.format(marginal))
133 |                 return False
134 | 
135 |         return True
--------------------------------------------------------------------------------
/flyingsquid/_lm_parameters.py:
--------------------------------------------------------------------------------
1 | from pgmpy.models import MarkovModel
2 | from pgmpy.factors.discrete import JointProbabilityDistribution, DiscreteFactor
3 | from itertools import combinations
4 | from flyingsquid.helpers import *
5 | import numpy as np
6 | import math
7 | from tqdm import tqdm
8 | import sys
9 | import random
10 | 
11 | class Mixin:
12 |     '''
13 |     Functions to compute label model parameters from mean parameters.
14 |     '''
15 | 
16 |     def _generate_e_vector(self, clique):
17 |         '''
18 |         The e vector is a vector of assignments for a particular marginal.
19 | 
20 |         For example, in a marginal with one LF and one Y variable, and no
21 |         abstentions, the e vector entries are:
22 |         [
23 |             (1, 1),
24 |             (1, -1),
25 |             (-1, 1),
26 |             (-1, -1)
27 |         ]
28 |         The first entry of each tuple is the value of the LF, the second
29 |         entry is the value of the Y variable. 
30 |
31 | In a marginal with two LFs and one Y variable and no abstentions,
32 | the entries are:
33 | [
34 | (1, 1, 1),
35 | (-1, 1, 1),
36 | (1, -1, 1),
37 | (-1, -1, 1),
38 | (1, 1, -1),
39 | (-1, 1, -1),
40 | (1, -1, -1),
41 | (-1, -1, -1)
42 | ]
43 |
44 | In a marginal with one LF, one Y variable, and abstentions:
45 | [
46 | (1, 1),
47 | (0, 1),
48 | (-1, 1),
49 | (1, -1),
50 | (0, -1),
51 | (-1, -1)
52 | ]
53 |
54 | Two LFs, one Y variable, and abstentions:
55 | [
56 | (1, 1, 1),
57 | (0, 1, 1),
58 | (-1, 1, 1),
59 | (1, 0, 1),
60 | (0, 0, 1),
61 | (-1, 0, 1),
62 | (1, -1, 1),
63 | (0, -1, 1),
64 | (-1, -1, 1),
65 | (1, 1, -1),
66 | (0, 1, -1),
67 | (-1, 1, -1),
68 | (1, 0, -1),
69 | (0, 0, -1),
70 | (-1, 0, -1),
71 | (1, -1, -1),
72 | (0, -1, -1),
73 | (-1, -1, -1)
74 | ]
75 | '''
76 | lambda_values = [1, 0, -1] if self.allow_abstentions else [1, -1]
77 | e_vec = [[1], [-1]]
78 | for i in range(len(clique) - 1):
79 | new_e_vec = []
80 | if not self.allow_abstentions:
81 | for new_val in lambda_values:
82 | for e_val in e_vec:
83 | new_e_vec.append(e_val + [new_val])
84 | else:
85 | for e_val in e_vec:
86 | for new_val in lambda_values:
87 | new_e_vec.append([new_val] + e_val)
88 | e_vec = new_e_vec
89 | e_vec = [ tuple(e_val) for e_val in e_vec ]
90 |
91 | return e_vec
92 |
93 | def _generate_r_vector(self, clique):
94 | '''
95 | The r vector is the vector of probability values that goes on the RHS of the linear
96 | system B_matrix * e_vector = r_vector, so that solving the system recovers the e vector.
97 |
98 | When there are abstentions, the mapping works as follows:
99 | * Each probability is some combination of
100 | P(A * B * ... * C = 1, D = 0, E = 0, ..., F = 0)
101 | * The A, B, ..., C can include any LF, and the Y variable.
102 | * The D, E, ..., F can include any LF
103 | * Let the A, B, ..., C set be called the "equals one set"
104 | * Let the D, E, ..., F set be called the "equals zero set"
105 | * Then, for each entry in the e vector:
106 | * If there is a -1 in an LF spot, add the LF to the "equals zero set"
107 | * If there is a 0 in the LF spot, add the LF to the "equals one set"
108 | * If there is a -1 in the Y variable spot, add it to the "equals one set"
109 |
110 | When there are no abstentions, each probability is just defined by the
111 | "equals one set" (i.e., P(A * B * ... * C = 1)).
112 | * For each entry in the e vector:
113 | * If there is a -1 in any spot (LF spot or Y variable), add it to the
114 | "equals one set"
115 | '''
116 | indices = [ int(node.split('_')[1]) for node in clique ]
117 | lf_indices = sorted(indices[:-1])
118 | Y_idx = indices[-1]
119 | Y_val = 'Y_{}'.format(Y_idx)
120 |
121 | e_vec = self._generate_e_vector(clique)
122 |
123 | r_vec = []
124 | for e_vec_tup in e_vec:
125 | # P(a * b * ...
* c = 1) for everything in this array 126 | r_vec_entry_equal_one = [] 127 | # P(a = 0, b = 0, ..., c = 0) for everything in this array 128 | r_vec_entry_equal_zero = [] 129 | for e_vec_entry, lf_idx in zip(e_vec_tup, lf_indices): 130 | # if you have abstentions, -1 means add to equal zero, 0 means add to equal one 131 | if self.allow_abstentions: 132 | if e_vec_entry == -1: 133 | r_vec_entry_equal_zero.append('lambda_{}'.format(lf_idx)) 134 | if e_vec_entry == 0: 135 | r_vec_entry_equal_one.append('lambda_{}'.format(lf_idx)) 136 | # otherwise, -1 means add to equal one 137 | else: 138 | if e_vec_entry == -1: 139 | r_vec_entry_equal_one.append('lambda_{}'.format(lf_idx)) 140 | if e_vec_tup[-1] == -1: 141 | r_vec_entry_equal_one.append(Y_val) 142 | 143 | entries_equal_one = ( 144 | tuple(['1']) if len(r_vec_entry_equal_one) == 0 else 145 | tuple(r_vec_entry_equal_one)) 146 | entries_equal_zero = ( 147 | tuple(['0']) if len(r_vec_entry_equal_zero) == 0 else 148 | tuple(r_vec_entry_equal_zero)) 149 | if self.allow_abstentions: 150 | r_vec.append(( 151 | entries_equal_one, 152 | entries_equal_zero 153 | )) 154 | else: 155 | if len(r_vec_entry_equal_zero) > 0: 156 | print('No abstentions allowed!') 157 | exit(1) 158 | r_vec.append(entries_equal_one) 159 | 160 | return r_vec 161 | 162 | def _generate_b_matrix(self, clique): 163 | if not self.allow_abstentions: 164 | b_matrix_orig = np.array([[1, 1], [1, -1]]) 165 | b_matrix = b_matrix_orig 166 | for i in range(len(clique) - 1): 167 | b_matrix = np.kron(b_matrix, b_matrix_orig) 168 | b_matrix[b_matrix < 0] = 0 169 | 170 | return b_matrix 171 | else: 172 | a_zero = np.array([ 173 | [1, 1], 174 | [1, 0] 175 | ]) 176 | b_zero = np.array([ 177 | [0, 0], 178 | [0, 1] 179 | ]) 180 | 181 | c_matrix = np.array([ 182 | [1, 1, 1], 183 | [1, 0, 0], 184 | [0, 1, 0] 185 | ]) 186 | d_matrix = np.array([ 187 | [0, 0, 0], 188 | [0, 0, 1], 189 | [0, 0, 0] 190 | ]) 191 | 192 | a_i = a_zero 193 | b_i = b_zero 194 | for i in range(len(clique) - 1): 195 | a_prev = a_i 196 | b_prev = b_i 197 | a_i = np.kron(a_prev, c_matrix) + np.kron(b_prev, d_matrix) 198 | b_i = np.kron(a_prev, d_matrix) + np.kron(b_prev, c_matrix) 199 | 200 | return a_i -------------------------------------------------------------------------------- /flyingsquid/_observables.py: -------------------------------------------------------------------------------- 1 | from pgmpy.models import MarkovModel 2 | from pgmpy.factors.discrete import JointProbabilityDistribution, DiscreteFactor 3 | from itertools import combinations 4 | from flyingsquid.helpers import * 5 | import numpy as np 6 | import math 7 | from tqdm import tqdm 8 | import sys 9 | import random 10 | 11 | class Mixin: 12 | ''' 13 | Functions to compute observable properties. 
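"Observable" here means computable directly from the class balance or from
L_train, without the triplet method: the class balance itself, marginal
distributions over subsets of Y's, and the probabilities that products of
Y's equal one.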
14 | ''' 15 | 16 | def _compute_class_balance(self, class_balance=None, Y_dev=None): 17 | # generate class balance of Ys 18 | Ys_ordered = [ 'Y_{}'.format(i) for i in range(self.v) ] 19 | cardinalities = [ 2 for i in range(self.v) ] 20 | if class_balance is not None: 21 | class_balance = class_balance / sum(class_balance) 22 | cb = JointProbabilityDistribution( 23 | Ys_ordered, cardinalities, class_balance 24 | ) 25 | elif Y_dev is not None: 26 | Ys_ordered = [ 'Y_{}'.format(i) for i in range(self.v) ] 27 | vals = { Y: (-1, 1) for Y in Ys_ordered } 28 | Y_vecs = sorted([ 29 | [ vec_dict[Y] for Y in Ys_ordered ] 30 | for vec_dict in dict_product(vals) 31 | ]) 32 | counts = { 33 | tuple(Y_vec): 0 34 | for Y_vec in Y_vecs 35 | } 36 | for data_point in Y_dev: 37 | counts[tuple(data_point)] += 1 38 | cb = JointProbabilityDistribution( 39 | Ys_ordered, cardinalities, 40 | [ 41 | float(counts[tuple(Y_vec)]) / len(Y_dev) 42 | for Y_vec in Y_vecs 43 | ]) 44 | else: 45 | num_combinations = 2 ** self.v 46 | cb = JointProbabilityDistribution( 47 | Ys_ordered, cardinalities, [ 48 | 1. / num_combinations for i in range(num_combinations) 49 | ]) 50 | 51 | return cb 52 | 53 | def _compute_Y_marginals(self, Y_marginals): 54 | for marginal in Y_marginals: 55 | nodes = [ 'Y_{}'.format(idx) for idx in marginal ] 56 | Y_marginals[marginal] = self.cb.marginal_distribution( 57 | nodes, 58 | inplace=False 59 | ) 60 | 61 | return Y_marginals 62 | 63 | def _compute_Y_equals_one(self, Y_equals_one): 64 | # compute from class balance 65 | for factor in Y_equals_one: 66 | nodes = [ 'Y_{}'.format(idx) for idx in factor ] 67 | 68 | Y_marginal = self.cb.marginal_distribution( 69 | nodes, 70 | inplace=False 71 | ) 72 | vals = { Y: (-1, 1) for Y in nodes } 73 | Y_vecs = sorted([ 74 | [ vec_dict[Y] for Y in nodes ] 75 | for vec_dict in dict_product(vals) 76 | ]) 77 | 78 | # add up the probabilities of all the vectors whose values multiply to +1 79 | total_prob = 0 80 | for Y_vec in Y_vecs: 81 | if np.prod(Y_vec) == 1: 82 | vector_prob = Y_marginal.reduce( 83 | [ 84 | (Y_i, Y_val if Y_val == 1 else 0) 85 | for Y_i, Y_val in zip(nodes, Y_vec) 86 | ], 87 | inplace=False 88 | ).values 89 | total_prob += vector_prob 90 | 91 | Y_equals_one[factor] = total_prob 92 | 93 | return Y_equals_one -------------------------------------------------------------------------------- /flyingsquid/_triplets.py: -------------------------------------------------------------------------------- 1 | from pgmpy.models import MarkovModel 2 | from pgmpy.factors.discrete import JointProbabilityDistribution, DiscreteFactor 3 | from itertools import combinations 4 | from flyingsquid.helpers import * 5 | import numpy as np 6 | import math 7 | from tqdm import tqdm 8 | import sys 9 | import random 10 | 11 | class Mixin: 12 | ''' 13 | Triplet algorithms as a Mixin. These algorithms recover the mean parameters 14 | of the graphical model. 
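The key identity: if lambda_a and lambda_b are conditionally independent
given Y (with all variables taking values in {-1, 1}), then
E[lambda_a Y] * E[lambda_b Y] = E[lambda_a lambda_b].
Given a triplet (a, b, c) that is pairwise separated by Y, this yields
|E[lambda_a Y]| = sqrt(|E[lambda_a lambda_b] * E[lambda_a lambda_c] / E[lambda_b lambda_c]|),
which is the math.sqrt(abs(...)) expression used in
_triplet_method_probabilities below; signs are handled separately by
sign recovery.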
15 | ''' 16 | 17 | def _triplet_method_single_seed(self, expectations_to_estimate): 18 | # create triplets for what we need, and return which moments we'll need to compute 19 | 20 | exp_to_estimate_list = sorted(list(expectations_to_estimate)) 21 | if self.triplet_seed is not None: 22 | random.shuffle(exp_to_estimate_list) 23 | 24 | if self.triplets is None: 25 | expectations_in_triplets = set() 26 | triplets = [] 27 | for expectation in exp_to_estimate_list: 28 | # if we're already computing it, don't need to add to a new triplet 29 | if expectation in expectations_in_triplets: 30 | continue 31 | 32 | if not self.allow_abstentions: 33 | Y_node = expectation[-1] 34 | else: 35 | Y_node = expectation[0][-1] 36 | 37 | def check_triplet(triplet): 38 | return (self._is_separator(triplet[0][:-1], triplet[1][:-1], Y_node) and 39 | self._is_separator(triplet[0][:-1], triplet[2][:-1], Y_node) and 40 | self._is_separator(triplet[1][:-1], triplet[2][:-1], Y_node)) 41 | 42 | triplet = [expectation] 43 | found = False 44 | 45 | # first try looking at the other expectations that we need to estimate 46 | for first_node in exp_to_estimate_list: 47 | if self.allow_abstentions: 48 | # need to check if conditionals are the same 49 | if (first_node in triplet or # skip if it's already in the triplet 50 | first_node[0][-1] != Y_node or # skip if the Y values aren't the same 51 | first_node[1] != expectation[1] or # skip if conditions are different 52 | (len(first_node[0]) > 2 and len(expectation[0]) > 2) or # at most one item in the triplet can have length > 2 53 | first_node in expectations_in_triplets or # we're already computing this 54 | not self._is_separator(expectation[0][:-1], first_node[0][:-1], Y_node)): # not separated 55 | continue 56 | else: 57 | if (first_node in triplet or # skip if it's already in the triplet 58 | first_node[-1] != Y_node or # skip if the Y values aren't the same 59 | (len(first_node) > 2 and len(expectation) > 2) or # at most one item in the triplet can have length > 2 60 | first_node in expectations_in_triplets or # we're already computing this 61 | not self._is_separator(expectation[:-1], first_node[:-1], Y_node)): # not separated 62 | continue 63 | triplet = [expectation, first_node] 64 | # first try looking at the other expectations that we need to estimate 65 | for second_node in exp_to_estimate_list: 66 | if self.allow_abstentions: 67 | if (second_node in triplet or # skip if it's already in the triplet 68 | second_node[0][-1] != Y_node or # skip if the Y values aren't the same 69 | second_node[1] != expectation[1] or # skip if conditions are different 70 | (len(second_node[0]) > 2 and 71 | any(len(exp[0]) > 2 for exp in triplet)) or # at most one item in the triplet can have length > 2 72 | second_node in expectations_in_triplets or # we're already computing this 73 | not all(self._is_separator(exp[0][:-1], second_node[0][:-1], Y_node) for exp in triplet)): # not separated 74 | continue 75 | else: 76 | if (second_node in triplet or # skip if it's already in the triplet 77 | second_node[-1] != Y_node or # skip if the Y values aren't the same 78 | (len(second_node) > 2 and 79 | any(len(exp) > 2 for exp in triplet)) or # at most one item in the triplet can have length > 2 80 | second_node in expectations_in_triplets or # we're already computing this 81 | not all(self._is_separator(exp[:-1], second_node[:-1], Y_node) for exp in triplet)): # not separated 82 | continue 83 | 84 | # we found a triplet! 
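# (all three expectations in the triplet are pairwise separated by Y_node,
# so each accuracy can be recovered from observable second moments)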
85 | triplet = [expectation, first_node, second_node] 86 | found = True 87 | break 88 | 89 | if found: 90 | break 91 | 92 | # otherwise, try everything 93 | for second_node in [ 94 | ((node, Y_node), expectation[1]) if self.allow_abstentions else (node, Y_node) 95 | for node in self.nodes 96 | ]: 97 | if self.allow_abstentions: 98 | if (second_node in triplet or # skip if it's already in the triplet 99 | second_node[1] != expectation[1] or # skip if conditions are different 100 | not all(self._is_separator(exp[0][:-1], second_node[0][:-1], Y_node) for exp in triplet)): # not separated 101 | continue 102 | else: 103 | if (second_node in triplet or # skip if it's already in the triplet 104 | not all(self._is_separator(exp[:-1], second_node[:-1], Y_node) for exp in triplet)): # not separated 105 | continue 106 | 107 | # we found a triplet! 108 | triplet = [expectation, first_node, second_node] 109 | found = True 110 | break 111 | 112 | if found: 113 | break 114 | 115 | if not found: 116 | # try everything 117 | for first_node in [ 118 | ((node, Y_node), expectation[1]) if self.allow_abstentions else (node, Y_node) 119 | for node in self.nodes if 'Y' not in node 120 | ]: 121 | if self.allow_abstentions: 122 | if (first_node in triplet or # skip if it's already in the triplet 123 | first_node[0][0] in expectation[1] or # skip if the node is part of the condition 124 | not self._is_separator(expectation[0][:-1], first_node[0][:-1], Y_node)): # not separated 125 | continue 126 | else: 127 | if (first_node in triplet or # skip if it's already in the triplet 128 | not self._is_separator(expectation[:-1], first_node[:-1], Y_node)): # not separated 129 | continue 130 | 131 | triplet = [expectation, first_node] 132 | 133 | if found: 134 | break 135 | 136 | for second_node in [ 137 | ((node, Y_node), expectation[1]) if self.allow_abstentions else (node, Y_node) 138 | for node in self.nodes if 'Y' not in node 139 | ]: 140 | if self.allow_abstentions: 141 | if (second_node in triplet or # skip if it's already in the triplet 142 | second_node[0][0] in expectation[1] or # skip if the node is part of the condition 143 | not all(self._is_separator(exp[0][:-1], second_node[0][:-1], Y_node) for exp in triplet)): # not separated 144 | continue 145 | else: 146 | if (second_node in triplet or # skip if it's already in the triplet 147 | not all(self._is_separator(exp[:-1], second_node[:-1], Y_node) for exp in triplet)): # not separated 148 | continue 149 | # we found a triplet! 
150 | triplet = [expectation, first_node, second_node] 151 | found = True 152 | break 153 | 154 | if found: 155 | break 156 | 157 | if found: 158 | triplets.append(triplet) 159 | for expectation in triplet: 160 | expectations_in_triplets.add(expectation) 161 | else: 162 | triplets = self.triplets 163 | 164 | all_moments = set() 165 | abstention_probabilities = {} 166 | 167 | for exp1, exp2, exp3 in triplets: 168 | if self.allow_abstentions: 169 | condition = exp1[1] 170 | 171 | moments = [ 172 | tuple(sorted(exp1[0][:-1] + exp2[0][:-1])), 173 | tuple(sorted(exp1[0][:-1] + exp3[0][:-1])), 174 | tuple(sorted(exp2[0][:-1] + exp3[0][:-1])) 175 | ] 176 | 177 | indices1 = tuple(sorted([ int(node.split('_')[1]) for node in exp1[0][:-1] ])) 178 | indices2 = tuple(sorted([ int(node.split('_')[1]) for node in exp2[0][:-1] ])) 179 | indices3 = tuple(sorted([ int(node.split('_')[1]) for node in exp3[0][:-1] ])) 180 | 181 | if indices1 not in abstention_probabilities: 182 | abstention_probabilities[indices1] = 0 183 | if indices2 not in abstention_probabilities: 184 | abstention_probabilities[indices2] = 0 185 | if indices3 not in abstention_probabilities: 186 | abstention_probabilities[indices3] = 0 187 | else: 188 | # first, figure out which moments we need to compute 189 | moments = [ 190 | tuple(sorted(exp1[:-1] + exp2[:-1])), 191 | tuple(sorted(exp1[:-1] + exp3[:-1])), 192 | tuple(sorted(exp2[:-1] + exp3[:-1])) 193 | ] 194 | for moment in moments: 195 | indices = tuple(sorted([ int(node.split('_')[1]) for node in moment ])) 196 | 197 | if indices not in all_moments: 198 | all_moments.add(indices) 199 | 200 | return triplets, all_moments, abstention_probabilities 201 | 202 | def _triplet_method_mean_median(self, expectations_to_estimate, solve_method): 203 | exp_to_estimate_list = sorted(list(expectations_to_estimate)) 204 | triplets = [] 205 | 206 | if self.triplets is None: 207 | if self.fully_independent_case: 208 | Y_node = 'Y' 209 | all_nodes = [ 210 | ((node, Y_node), '0') if self.allow_abstentions else (node, Y_node) 211 | for node in self.nodes if 'Y' not in node 212 | ] 213 | triplets = [ 214 | [i, j, k] 215 | for i in all_nodes 216 | for j in all_nodes if i != j 217 | for k in all_nodes if i != k and k != j 218 | ] + [ 219 | [expectation, -1, -1] for expectation in exp_to_estimate_list 220 | ] 221 | else: 222 | for expectation in exp_to_estimate_list: 223 | if not self.allow_abstentions: 224 | Y_node = expectation[-1] 225 | else: 226 | Y_node = expectation[0][-1] 227 | 228 | triplet = [expectation] 229 | 230 | # try everything 231 | for first_node in [ 232 | ((node, Y_node), expectation[1]) if self.allow_abstentions else (node, Y_node) 233 | for node in self.nodes if 'Y' not in node 234 | ]: 235 | if self.allow_abstentions: 236 | if (first_node in triplet or # skip if it's already in the triplet 237 | first_node[0][0] in expectation[1] or # skip if the node is part of the condition 238 | not self._is_separator(expectation[0][:-1], first_node[0][:-1], Y_node)): # not separated 239 | continue 240 | else: 241 | if (first_node in triplet or # skip if it's already in the triplet 242 | not self._is_separator(expectation[:-1], first_node[:-1], Y_node)): # not separated 243 | continue 244 | 245 | triplet = [expectation, first_node] 246 | 247 | for second_node in [ 248 | ((node, Y_node), expectation[1]) if self.allow_abstentions else (node, Y_node) 249 | for node in self.nodes if 'Y' not in node 250 | ]: 251 | if self.allow_abstentions: 252 | if (second_node in triplet or # skip if it's already in 
the triplet
253 | second_node[0][0] in expectation[1] or # skip if the node is part of the condition
254 | not all(self._is_separator(exp[0][:-1], second_node[0][:-1], Y_node) for exp in triplet)): # not separated
255 | continue
256 | else:
257 | if (second_node in triplet or # skip if it's already in the triplet
258 | not all(self._is_separator(exp[:-1], second_node[:-1], Y_node) for exp in triplet)): # not separated
259 | continue
260 | if tuple([expectation, second_node, first_node]) in triplets:
261 | continue
262 | # we found a triplet!
263 | triplet = [expectation, first_node, second_node]
264 | triplets.append(tuple(triplet))
265 | triplet = [expectation, first_node]
266 | triplet = [expectation]
267 | else:
268 | triplets = self.triplets
269 |
270 | all_moments = set()
271 | abstention_probabilities = {}
272 |
273 | if self.fully_independent_case:
274 | all_nodes = list(range(self.m))
275 | all_moments = set([
276 | (i, j)
277 | for i in all_nodes
278 | for j in all_nodes if i != j
279 | ])
280 | if self.allow_abstentions:
281 | for node in all_nodes:
282 | abstention_probabilities[tuple([node])] = 0
283 | else:
284 | for exp1, exp2, exp3 in triplets:
285 | if self.allow_abstentions:
286 | condition = exp1[1]
287 |
288 | moments = [
289 | tuple(sorted(exp1[0][:-1] + exp2[0][:-1])),
290 | tuple(sorted(exp1[0][:-1] + exp3[0][:-1])),
291 | tuple(sorted(exp2[0][:-1] + exp3[0][:-1]))
292 | ]
293 |
294 | indices1 = tuple(sorted([ int(node.split('_')[1]) for node in exp1[0][:-1] ]))
295 | indices2 = tuple(sorted([ int(node.split('_')[1]) for node in exp2[0][:-1] ]))
296 | indices3 = tuple(sorted([ int(node.split('_')[1]) for node in exp3[0][:-1] ]))
297 |
298 | if indices1 not in abstention_probabilities:
299 | abstention_probabilities[indices1] = 0
300 | if indices2 not in abstention_probabilities:
301 | abstention_probabilities[indices2] = 0
302 | if indices3 not in abstention_probabilities:
303 | abstention_probabilities[indices3] = 0
304 | else:
305 | # first, figure out which moments we need to compute
306 | moments = [
307 | tuple(sorted(exp1[:-1] + exp2[:-1])),
308 | tuple(sorted(exp1[:-1] + exp3[:-1])),
309 | tuple(sorted(exp2[:-1] + exp3[:-1]))
310 | ]
311 | for moment in moments:
312 | indices = tuple(sorted([ int(node.split('_')[1]) for node in moment ]))
313 |
314 | if indices not in all_moments:
315 | all_moments.add(indices)
316 |
317 | return triplets, all_moments, abstention_probabilities
318 |
319 | def _triplet_method_preprocess(self, expectations_to_estimate, solve_method):
320 | if solve_method == 'triplet':
321 | return self._triplet_method_single_seed(expectations_to_estimate)
322 | elif solve_method in [ 'triplet_mean', 'triplet_median' ]:
323 | return self._triplet_method_mean_median(expectations_to_estimate, solve_method)
324 | else:
325 | raise NotImplementedError('Unknown solve method {}'.format(solve_method))
326 |
327 | def _triplet_method_probabilities(self, triplets, lambda_moment_vals, lambda_zeros,
328 | abstention_probabilities, sign_recovery, solve_method):
329 | expectation_values = {}
330 |
331 | if solve_method == 'triplet':
332 | pass
333 | else:
334 | # each triplet is constructed for the first value in the expectation
335 | # get all the triplets with the same first value, and take the mean or median
336 | expectation_value_candidates = {}
337 |
338 | if self.fully_independent_case and solve_method in ['triplet_mean', 'triplet_median']:
339 | second_moment = np.zeros((self.m, self.m))
340 | for key in lambda_moment_vals:
341 | i, j = key
342 |
second_moment[i][j] = lambda_moment_vals[(i, j)] 343 | 344 | def all_triplet_vals(idx): 345 | triplet_vals = [] 346 | for i in range(self.m): 347 | if i == idx: 348 | continue 349 | for j in range(i): 350 | if j == idx: 351 | continue 352 | val = math.sqrt(abs( 353 | (second_moment[idx, i] * second_moment[idx, j] / second_moment[i, j]) 354 | if second_moment[i, j] != 0 else 0 355 | )) 356 | triplet_vals.append(val) 357 | return triplet_vals 358 | all_vals = [all_triplet_vals(idx) for idx in range(self.m)] 359 | expectations_to_estimate = [ 360 | expectation 361 | for expectation, a, b in triplets if a == -1 and b == -1 362 | ] 363 | for expectation in expectations_to_estimate: 364 | if self.allow_abstentions: 365 | idx = int(expectation[0][0].split('_')[1]) 366 | else: 367 | idx = int(expectation[0].split('_')[1]) 368 | expectation_value_candidates[expectation] = all_vals[idx] 369 | else: 370 | for exp1, exp2, exp3 in triplets: 371 | if self.allow_abstentions: 372 | moments = [ 373 | tuple(sorted(exp1[0][:-1] + exp2[0][:-1])), 374 | tuple(sorted(exp1[0][:-1] + exp3[0][:-1])), 375 | tuple(sorted(exp2[0][:-1] + exp3[0][:-1])) 376 | ] 377 | else: 378 | # first, figure out which moments we need to compute 379 | moments = [ 380 | tuple(sorted(exp1[:-1] + exp2[:-1])), 381 | tuple(sorted(exp1[:-1] + exp3[:-1])), 382 | tuple(sorted(exp2[:-1] + exp3[:-1])) 383 | ] 384 | 385 | moment_vals = [ 386 | lambda_moment_vals[ 387 | tuple(sorted([ int(node.split('_')[1]) for node in moment ])) 388 | ] 389 | for moment in moments 390 | ] 391 | 392 | if solve_method == 'triplet': 393 | expectation_values[exp1] = ( 394 | math.sqrt(abs(moment_vals[0] * moment_vals[1] / moment_vals[2])) if moment_vals[2] != 0 else 0) 395 | expectation_values[exp2] = ( 396 | math.sqrt(abs(moment_vals[0] * moment_vals[2] / moment_vals[1])) if moment_vals[1] != 0 else 0) 397 | expectation_values[exp3] = ( 398 | math.sqrt(abs(moment_vals[1] * moment_vals[2] / moment_vals[0])) if moment_vals[0] != 0 else 0) 399 | else: 400 | if exp1 not in expectation_value_candidates: 401 | expectation_value_candidates[exp1] = [] 402 | exp_value = ( 403 | math.sqrt(abs(moment_vals[0] * moment_vals[1] / moment_vals[2])) if moment_vals[2] != 0 else 0) 404 | expectation_value_candidates[exp1].append(exp_value) 405 | 406 | if solve_method in ['triplet_mean', 'triplet_median']: 407 | for exp in expectation_value_candidates: 408 | if solve_method == 'triplet_mean': 409 | agg_function = np.mean 410 | if solve_method == 'triplet_median': 411 | agg_function = np.median 412 | expectation_values[exp] = agg_function(expectation_value_candidates[exp]) 413 | 414 | self.expectation_value_candidates = expectation_value_candidates 415 | 416 | if sign_recovery == 'all_positive': 417 | # all signs are already positive 418 | pass 419 | else: 420 | print('{} sign recovery not implemented'.format(sign_recovery)) 421 | return 422 | 423 | if self.allow_abstentions: 424 | # probability is 0.5 * (1 + expectation - P(lambda part of factor is zero)) * P(conditional) 425 | # P(conditional) is 1 if there is no conditional 426 | probabilities = {} 427 | for expectation in sorted(list(expectation_values.keys())): 428 | exp_value = expectation_values[expectation] 429 | if expectation[1][0] == '0': 430 | condition_prob = 1 431 | else: 432 | zero_condition = tuple(sorted([ int(node.split('_')[1]) for node in expectation[1] ])) 433 | condition_prob = lambda_zeros[zero_condition] 434 | 435 | lambda_factor = tuple(sorted([ int(node.split('_')[1]) for node in expectation[0][:-1] ])) 436 | 
abstention_prob = abstention_probabilities[lambda_factor] 437 | 438 | probabilities[expectation] = 0.5 * (1 + exp_value - abstention_prob) * condition_prob 439 | else: 440 | probabilities = { 441 | expectation: 0.5 * (1 + expectation_values[expectation]) 442 | for expectation in sorted(list(expectation_values.keys())) 443 | } 444 | 445 | 446 | return probabilities, expectation_values 447 | -------------------------------------------------------------------------------- /flyingsquid/helpers.py: -------------------------------------------------------------------------------- 1 | from itertools import product 2 | 3 | def dict_product(d): 4 | keys = d.keys() 5 | for element in product(*d.values()): 6 | yield dict(zip(keys, element)) -------------------------------------------------------------------------------- /flyingsquid/label_model.py: -------------------------------------------------------------------------------- 1 | from pgmpy.models import MarkovModel 2 | from pgmpy.factors.discrete import JointProbabilityDistribution, DiscreteFactor 3 | from itertools import combinations 4 | from flyingsquid.helpers import * 5 | from flyingsquid import _triplets 6 | from flyingsquid import _graphs 7 | from flyingsquid import _observables 8 | from flyingsquid import _lm_parameters 9 | import numpy as np 10 | import math 11 | from tqdm import tqdm 12 | import sys 13 | import random 14 | 15 | class LabelModel(_triplets.Mixin, _graphs.Mixin, _observables.Mixin, 16 | _lm_parameters.Mixin): 17 | 18 | def __init__(self, m, v=1, y_edges=[], lambda_y_edges=[], lambda_edges=[], 19 | allow_abstentions=True, triplets=None, triplet_seed=0): 20 | '''Initialize the LabelModel with a graph G. 21 | 22 | m: number of LF's 23 | v: number of Y tasks 24 | y_edges: edges between the tasks. (i, j) in y_edges means that 25 | there is an edge between y_i and y_j. 26 | lambda_y_edges: edges between LF's and tasks. (i, j) in lambda_y_edges 27 | means that there is an edge between lambda_i and y_j. If this list 28 | is empty, assume that all labeling functions are connected to Y_0. 29 | lambda_edges: edges between LF's. (i, j) in lambda_edges means that 30 | there is an edge between lambda_i and lambda_j. 31 | allow_abstentions: if True, allow abstentions in L_train. 
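(Abstentions are encoded as 0 in L_train; -1 and +1 are the class votes.)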
32 | triplets: if specified, use these triplets 33 | triplet_seed: if triplets not specified, randomly shuffle the nodes 34 | with this seed when generating triplets 35 | ''' 36 | if lambda_y_edges == []: 37 | lambda_y_edges = [(i, 0) for i in range(m)] 38 | 39 | G = MarkovModel() 40 | # Add LF nodes 41 | G.add_nodes_from([ 42 | 'lambda_{}'.format(i) 43 | for i in range(m) 44 | ]) 45 | G.add_nodes_from([ 46 | 'Y_{}'.format(i) 47 | for i in range(v) 48 | ]) 49 | 50 | # Add edges 51 | G.add_edges_from([ 52 | ('Y_{}'.format(start), 'Y_{}'.format(end)) 53 | for start, end in y_edges 54 | ]) 55 | G.add_edges_from([ 56 | ('lambda_{}'.format(start), 'Y_{}'.format(end)) 57 | for start, end in lambda_y_edges 58 | ]) 59 | G.add_edges_from([ 60 | ('lambda_{}'.format(start), 'lambda_{}'.format(end)) 61 | for start, end in lambda_edges 62 | ]) 63 | 64 | self.fully_independent_case = lambda_edges == [] 65 | 66 | self.m = m 67 | if m < 3: 68 | raise NotImplementedError("Triplet method needs at least three LF's to run.") 69 | self.v = v 70 | self.G = G 71 | self.junction_tree = self.G.to_junction_tree() 72 | 73 | self.nodes = sorted(list(self.G.nodes)) 74 | self.triplet_seed = triplet_seed 75 | if triplet_seed is not None: 76 | random.seed(triplet_seed) 77 | random.shuffle(self.nodes) 78 | 79 | self.separator_sets = set([ 80 | tuple(sorted(list((set(clique1).intersection(set(clique2)))))) 81 | for clique1, clique2 in self.junction_tree.edges 82 | ]) 83 | 84 | self.allow_abstentions = allow_abstentions 85 | self.triplets = triplets 86 | 87 | if not self._check(): 88 | raise NotImplementedError('Cannot run triplet method for specified graph.') 89 | 90 | # Make this Picklable 91 | def save(obj): 92 | return (obj.__class__, obj.__dict__) 93 | 94 | def load(cls, attributes): 95 | obj = cls.__new__(cls) 96 | obj.__dict__.update(attributes) 97 | return obj 98 | 99 | def enumerate_ys(self): 100 | # order to output probabilities 101 | vals = { Y: (-1, 1) for Y in range(self.v) } 102 | Y_vecs = sorted([ 103 | [ vec_dict[Y] for Y in range(self.v) ] 104 | for vec_dict in dict_product(vals) 105 | ]) 106 | 107 | return Y_vecs 108 | 109 | def _lambda_pass(self, L_train, lambda_marginals, lambda_moment_vals, lambda_equals_one, 110 | lambda_zeros, abstention_probabilities, verbose = False): 111 | ''' 112 | Make the pass over L_train. 
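Quantities involving a single LF, and plain first and second moments, are
computed vectorized up front (the "easy" cases below); everything else is
accumulated in one loop over the rows of L_train.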
113 | 114 | In this pass, we need to: 115 | * Compute all the joint marginal distributions over multiple lambda's (lambda_marginals) 116 | * Compute the probabilities that some set of lambda's are all equal to zero (lambda_zeros) 117 | * Compute all the lambda moments, including conditional moments (lambda_moment_vals) 118 | * Compute the probability that the product of some lambdas is zero (abstention_probabilities) 119 | ''' 120 | 121 | # do the fast cases first 122 | easy_marginals = { 123 | marginal: None 124 | for marginal in lambda_marginals 125 | if len(marginal) == 1 126 | } 127 | easy_moments = { 128 | moment: None 129 | for moment in lambda_moment_vals 130 | if type(moment[0]) != type(()) and len(moment) <= 2 131 | } 132 | easy_equals_one = { 133 | factor: None 134 | for factor in lambda_equals_one 135 | if type(factor[0]) != type(()) and len(factor) == 1 136 | } 137 | easy_zeros = { 138 | condition: None 139 | for condition in lambda_zeros if len(condition) == 1 140 | } 141 | easy_abstention_probs = { 142 | factor: None 143 | for factor in abstention_probabilities if len(factor) == 1 144 | } 145 | 146 | means = np.einsum('ij->j', L_train)/L_train.shape[0] 147 | covariance = np.einsum('ij,ik->jk', L_train, L_train)/L_train.shape[0] 148 | 149 | lf_cardinality = 3 if self.allow_abstentions else 2 150 | lf_values = (-1, 0, 1) if self.allow_abstentions else (-1, 1) 151 | for marginal in easy_marginals: 152 | idx = marginal[0] 153 | counts = [ np.sum(L_train[:,idx] == val) / L_train.shape[0] for val in lf_values ] 154 | easy_marginals[marginal] = JointProbabilityDistribution( 155 | [ 'lambda_{}'.format(idx) ], [ lf_cardinality ], counts 156 | ) 157 | 158 | if marginal in easy_equals_one: 159 | easy_equals_one[marginal] = counts[-1] 160 | if marginal in easy_zeros: 161 | easy_zeros[marginal] = counts[1] 162 | if marginal in easy_abstention_probs: 163 | easy_abstention_probs[marginal] = counts[1] 164 | for moment in easy_moments: 165 | if len(moment) == 1: 166 | easy_moments[moment] = means[moment[0]] 167 | else: 168 | easy_moments[moment] = covariance[moment[0]][moment[1]] 169 | for factor in easy_equals_one: 170 | if easy_equals_one[factor] is None: 171 | easy_equals_one[factor] = np.sum(L_train[:,factor[0]] == 1) / L_train.shape[0] 172 | for condition in easy_zeros: 173 | if easy_zeros[condition] is None: 174 | idx = condition[0] 175 | easy_zeros[condition] = np.sum(L_train[:,idx] == 0) / L_train.shape[0] 176 | for factor in easy_abstention_probs: 177 | if easy_abstention_probs[factor] is None: 178 | idx = factor[0] 179 | easy_abstention_probs[factor] = np.sum(L_train[:,idx] == 0) / L_train.shape[0] 180 | 181 | # time for the remaining cases 182 | lambda_marginals = { 183 | key: lambda_marginals[key] 184 | for key in lambda_marginals 185 | if key not in easy_marginals 186 | } 187 | lambda_moment_vals = { 188 | key: lambda_moment_vals[key] 189 | for key in lambda_moment_vals 190 | if key not in easy_moments 191 | } 192 | lambda_equals_one = { 193 | key: lambda_equals_one[key] 194 | for key in lambda_equals_one 195 | if key not in easy_equals_one 196 | } 197 | lambda_zeros = { 198 | key: lambda_zeros[key] 199 | for key in lambda_zeros 200 | if key not in easy_zeros 201 | } 202 | abstention_probabilities = { 203 | key: abstention_probabilities[key] 204 | for key in abstention_probabilities 205 | if key not in easy_abstention_probs 206 | } 207 | 208 | # for the rest, loop through L_train 209 | if (len(lambda_marginals) > 0 or len(lambda_moment_vals) > 0 or 210 | len(lambda_equals_one) > 
0 or len(lambda_zeros) > 0 or 211 | len(abstention_probabilities) > 0): 212 | 213 | # figure out which lambda states we need to keep track of 214 | lambda_marginal_counts = {} 215 | lambda_marginal_vecs = {} 216 | lf_values = (-1, 0, 1) if self.allow_abstentions else (-1, 1) 217 | for lambda_marginal in lambda_marginals: 218 | nodes = [ 'lambda_{}'.format(idx) for idx in lambda_marginal ] 219 | vals = { lf: lf_values for lf in nodes } 220 | lf_vecs = sorted([ 221 | [ vec_dict[lf] for lf in nodes ] 222 | for vec_dict in dict_product(vals) 223 | ]) 224 | counts = { 225 | tuple(lf_vec): 0 226 | for lf_vec in lf_vecs 227 | } 228 | lambda_marginal_vecs[lambda_marginal] = lf_vecs 229 | lambda_marginal_counts[lambda_marginal] = counts 230 | 231 | lambda_moment_counts = { moment: 0 for moment in lambda_moment_vals } 232 | lambda_moment_basis = { moment: 0 for moment in lambda_moment_vals } 233 | lambda_equals_one_counts = { factor: 0 for factor in lambda_equals_one } 234 | lambda_equals_one_basis = { factor: 0 for factor in lambda_equals_one } 235 | lambda_zero_counts = { condition: 0 for condition in lambda_zeros } 236 | abstention_probability_counts = { factor: 0 for factor in abstention_probabilities } 237 | 238 | for data_point in tqdm(L_train) if verbose else L_train: 239 | for marginal in lambda_marginals: 240 | mask = [ data_point[idx] for idx in marginal ] 241 | lambda_marginal_counts[marginal][tuple(mask)] += 1 242 | for moment in lambda_moment_vals: 243 | if type(moment[0]) == type(()): 244 | pos_mask = [ data_point[idx] for idx in moment[0] ] 245 | zero_mask = [ data_point[idx] for idx in moment[1] ] 246 | 247 | if np.count_nonzero(zero_mask) == 0: 248 | lambda_moment_basis[moment] += 1 249 | lambda_moment_counts[moment] += np.prod(pos_mask) 250 | else: 251 | mask = [ data_point[idx] for idx in moment ] 252 | lambda_moment_counts[moment] += np.prod(mask) 253 | lambda_moment_basis[moment] += 1 254 | for factor in lambda_equals_one: 255 | if type(factor[0]) == type(()): 256 | pos_mask = [ data_point[idx] for idx in factor[0] ] 257 | zero_mask = [ data_point[idx] for idx in factor[1] ] 258 | 259 | if np.count_nonzero(zero_mask) == 0: 260 | lambda_equals_one_basis[factor] += 1 261 | if np.prod(pos_mask) == 1: 262 | lambda_equals_one_counts[factor] += 1 263 | else: 264 | mask = [ data_point[idx] for idx in factor ] 265 | if np.prod(mask) == 1: 266 | lambda_equals_one_counts[factor] += 1 267 | lambda_equals_one_basis[factor] += 1 268 | for zero_condition in lambda_zeros: 269 | zero_mask = [ data_point[idx] for idx in zero_condition ] 270 | if np.count_nonzero(zero_mask) == 0: 271 | lambda_zero_counts[zero_condition] += 1 272 | for factor in abstention_probability_counts: 273 | zero_mask = [ data_point[idx] for idx in factor ] 274 | if np.prod(zero_mask) == 0: 275 | abstention_probability_counts[factor] += 1 276 | 277 | lf_cardinality = 3 if self.allow_abstentions else 2 278 | for marginal in lambda_marginals: 279 | nodes = [ 'lambda_{}'.format(idx) for idx in marginal ] 280 | lf_vecs = lambda_marginal_vecs[marginal] 281 | counts = lambda_marginal_counts[marginal] 282 | 283 | lambda_marginals[marginal] = JointProbabilityDistribution( 284 | nodes, [ lf_cardinality for node in nodes ], 285 | [ 286 | float(counts[tuple(lf_vec)]) / len(L_train) 287 | for lf_vec in lf_vecs 288 | ] 289 | ) 290 | 291 | for moment in lambda_moment_vals: 292 | if lambda_moment_basis[moment] == 0: 293 | moment_val = 0 294 | else: 295 | moment_val = lambda_moment_counts[moment] / lambda_moment_basis[moment] 296 | 
lambda_moment_vals[moment] = moment_val
297 |
298 | for factor in lambda_equals_one:
299 | if lambda_equals_one_basis[factor] == 0:
300 | prob = 0
301 | else:
302 | prob = lambda_equals_one_counts[factor] / lambda_equals_one_basis[factor]
303 | lambda_equals_one[factor] = prob
304 |
305 | for zero_condition in lambda_zeros:
306 | lambda_zeros[zero_condition] = lambda_zero_counts[zero_condition] / len(L_train)
307 |
308 | for factor in abstention_probabilities:
309 | abstention_probabilities[factor] = abstention_probability_counts[factor] / len(L_train)
310 |
311 | # update with the easy values
312 | lambda_marginals.update(easy_marginals)
313 | lambda_moment_vals.update(easy_moments)
314 | lambda_equals_one.update(easy_equals_one)
315 | lambda_zeros.update(easy_zeros)
316 | abstention_probabilities.update(easy_abstention_probs)
317 |
318 | return lambda_marginals, lambda_moment_vals, lambda_equals_one, lambda_zeros, abstention_probabilities
319 |
320 | def fit(self, L_train, class_balance=None, Y_dev=None, flip_negative=True, clamp=True,
321 | solve_method='triplet_mean',
322 | sign_recovery='all_positive',
323 | verbose = False):
324 | '''Compute the marginal probabilities of each clique and separator set in the junction tree.
325 |
326 | L_train: an n x m matrix of LF outputs (n items, m LF's). L_train[k][i] is the value of
327 | \lambda_i on item k. 1 means positive, -1 means negative, 0 means abstain.
328 | class_balance: a 2^v vector of the probabilities of each combination of Y values. Sorted in
329 | lexicographical order (entry zero is for Y_0 = -1, ..., Y_{v-1} = -1, entry one is for
330 | Y_0 = -1, ..., Y_{v-1} = 1, last entry is for Y_0 = 1, ..., Y_{v-1} = 1).
331 | Y_dev: a |Y_dev| x v matrix of ground truth examples. If class_balance is not specified, this
332 | is used to find out the class balance. Otherwise not used.
333 | If neither this nor class_balance is specified, the class balance is uniform.
334 | 1 means positive, -1 means negative.
335 | flip_negative: if True, flip sign of negative probabilities
336 | clamp: if True and flip_negative is not True, set negative probabilities to 0
337 | solve_method: one of ['triplet_mean', 'triplet_median', 'triplet', 'independencies']
338 | If triplet, use the method below and the independencies we write down there.
339 | If independencies, use the following facts:
340 | * For any lambda_i: lambda_i * Y and Y are independent for any i, so
341 | E[lambda_i Y] = E[lambda_i] / E[Y]
342 | * For any lambda_i, lambda_j: E[lambda_i * lambda_j * Y] = E[lambda_i * lambda_j] * E[Y]
343 | * For an odd number of lambda's, the first property holds; for an even number, the second
344 | property holds
345 | Only triplet implemented right now.
346 | sign_recovery: one of ['all_positive', 'fully_independent']
347 | If all_positive, assume that all accuracies that we compute are positive.
348 | If fully_independent, assume that the accuracy of lambda_0 on Y_0 is positive, and that for
349 | any lambda_i and lambda_{i+1}, sign(lambda_i lambda_{i+1}) = sign(M_{i,i+1}), where M_{i,i+1}
350 | is the second moment between lambda_i and lambda_{i+1}.
351 | If solve_method is independencies, we don't need to do this.
352 | Only all_positive implemented right now.
353 | verbose: if True, print out messages to stderr as we make progress
354 |
355 | How we go about solving these probabilities (for the triplet method):
356 | * We assume that we have the joint distribution/class balance of our Y's (or can infer it
357 | from the dev set).
358 | * We observe agreements and disagreements between LF's, so we can compute values like
359 | P(\lambda_i \lambda_j = 1).
360 | * The only thing we need to estimate now are correlations between LF's and (unseen) Y's -
361 | values like P(\lambda_i Y_j = 1).
362 | * Luckily, we have P(\lambda_i Y_j = 1) = (1/2)(1 + E[\lambda_i Y_j]). We refer to E[\lambda_i Y_j]
363 | as the accuracy of \lambda_i on Y_j.
364 | * And because of the format of our exponential model, we have:
365 | E[\lambda_i Y_j]E[\lambda_k Y_j] = E[\lambda_i Y_j \lambda_k Y_j] = E[\lambda_i \lambda_k]
366 | for any \lambda_i, \lambda_k that are conditionally independent given Y_j. This translates to
367 | Y_j being a separator of \lambda_i and \lambda_k in our graphical model.
368 | And we can observe E[\lambda_i \lambda_k] (the second moment) from L_train!
369 | * The algorithm proceeds to estimate the marginal probabilities by picking out triplets of
370 | conditionally-independent subsets of LF's, and estimating the accuracies of LF's on Y's.
371 | * Then, to recover the joint probabilities, we can solve a linear system B e = r (written out in LaTeX):
372 |
373 | $$\begin{align*}
374 | \begin{bmatrix}
375 | 1 & 1 & 1 & 1 \\
376 | 1 & 0 & 1 & 0 \\
377 | 1 & 1 & 0 & 0 \\
378 | 1 & 0 & 0 & 1
379 | \end{bmatrix}
380 | \begin{bmatrix}
381 | p_{\lambda_i, Y_j}(+1, +1) \\
382 | p_{\lambda_i, Y_j}(-1, +1) \\
383 | p_{\lambda_i, Y_j}(+1, -1) \\
384 | p_{\lambda_i, Y_j}(-1, -1) \end{bmatrix} =
385 | \begin{bmatrix} 1 \\
386 | P(\lambda_{i} = 1) \\
387 | P(Y_j = 1) \\
388 | \rho_{i, j} \end{bmatrix} .
389 | \end{align*}$$
390 |
391 | On the left of the equality we have an invertible matrix multiplied by the vector of full
392 | marginal probabilities - values like P(\lambda_i = 1, Y_j = 1), P(\lambda_i = -1, Y_j = 1), etc.
393 | On the right of the equality we have [1, P(\lambda_i = 1), P(Y_j = 1), P(\lambda_i = Y_j)]^T.
394 | We can observe or solve for all the values on the right, and then solve the system for the
395 | marginal probability!
396 | This can also be extended to multiple dimensions.
397 |
398 | Outputs: None.
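Illustrative usage (a minimal sketch, not from the original docstring;
three LFs voting on a single task, with the default graph):

L_train = np.array([
[ 1, 1, -1],
[-1, -1, 1],
[ 1, 0, 1], # 0 = abstain
])
label_model = LabelModel(3)
label_model.fit(L_train)
probs = label_model.predict_proba_marginalized(L_train)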
399 | ''' 400 | # if abstentions not allowed, check for zero's 401 | if not self.allow_abstentions: 402 | if np.count_nonzero(L_train) < L_train.shape[0] * L_train.shape[1]: 403 | print('Abstentions not allowed!') 404 | return 405 | 406 | # Y marginals to compute 407 | Y_marginals = {} 408 | 409 | # lambda marginals to compute 410 | lambda_marginals = {} 411 | 412 | # marginals will eventually be returned here 413 | marginals = [ 414 | (clique, None) 415 | for clique in sorted(list(self.junction_tree.nodes)) + sorted(list(self.separator_sets)) 416 | ] 417 | 418 | def num_Ys(nodes): 419 | if nodes == tuple([1]) or nodes == tuple([0]): 420 | return 0 421 | return len([ 422 | node for node in nodes if 'Y' in node 423 | ]) 424 | 425 | def num_lambdas(nodes): 426 | if nodes == tuple([1]) or nodes == tuple([0]): 427 | return 0 428 | return len([ 429 | node for node in nodes if 'lambda' in node 430 | ]) 431 | 432 | observable_cliques = [] 433 | non_observable_cliques = [] 434 | 435 | for i, (clique, _) in enumerate(marginals): 436 | if num_Ys(clique) == 0 or num_lambdas(clique) == 0: 437 | observable_cliques.append(i) 438 | else: 439 | non_observable_cliques.append(i) 440 | 441 | # write down everything we need for the observable cliques 442 | for idx in observable_cliques: 443 | clique = marginals[idx][0] 444 | indices = tuple(sorted([ int(node.split('_')[1]) for node in clique ])) 445 | 446 | if 'Y' in clique[0]: 447 | if indices not in Y_marginals: 448 | Y_marginals[indices] = None 449 | else: 450 | if indices not in lambda_marginals: 451 | lambda_marginals[indices] = None 452 | 453 | if verbose: 454 | print('Marginals written down', file=sys.stderr) 455 | 456 | # for each marginal we need to estimate, write down the r vector that we need 457 | r_vecs = {} # mapping from clique index to the r vector 458 | r_vals = {} # mapping from a value name (like Y_1 or tuple(lambda_1, Y_1)) to its value 459 | for idx in non_observable_cliques: 460 | clique = list(reversed(sorted(marginals[idx][0]))) 461 | r_vec = self._generate_r_vector(clique) 462 | r_vecs[idx] = r_vec 463 | for r_val in r_vec: 464 | if r_val not in r_vals: 465 | r_vals[r_val] = None 466 | 467 | if verbose: 468 | print('R vector written down', file=sys.stderr) 469 | 470 | # write down all the sets of zero conditions 471 | lambda_zeros = {} 472 | 473 | # write down the moment values that we need to keep track of when we walk through the L matrix 474 | Y_equals_one = {} 475 | lambda_equals_one = {} 476 | 477 | # write down which expectations we need to solve using the triplet method 478 | expectations_to_estimate = set() 479 | for r_val in r_vals: 480 | if not self.allow_abstentions or r_val[1] == tuple(['0']): 481 | equals_one_tup = r_val if not self.allow_abstentions else r_val[0] 482 | 483 | if equals_one_tup[0] == '1': 484 | # If the value is 1, the probability is just 1 485 | r_vals[r_val] = 1 486 | elif num_Ys(equals_one_tup) != 0 and num_lambdas(equals_one_tup) != 0: 487 | # If this contains lambdas and Y's, we can't observe it 488 | expectations_to_estimate.add(r_val) 489 | elif num_Ys(equals_one_tup) != 0: 490 | # We need to cache this moment 491 | indices = tuple(sorted([ int(node.split('_')[1]) for node in equals_one_tup ])) 492 | if indices not in Y_equals_one: 493 | Y_equals_one[indices] = None 494 | elif num_lambdas(equals_one_tup) != 0: 495 | # If it contains just lambdas, go through L_train 496 | indices = tuple(sorted([ int(node.split('_')[1]) for node in equals_one_tup ])) 497 | if indices not in lambda_equals_one: 498 | 
lambda_equals_one[indices] = None 499 | else: 500 | # we allow abstentions, and there are clauses that are equal to zero 501 | equals_one_tup = r_val[0] 502 | equals_zero_tup = r_val[1] 503 | if num_lambdas(equals_one_tup) > 0 and num_Ys(equals_one_tup) > 0: 504 | # we can't observe this 505 | expectations_to_estimate.add(r_val) 506 | elif num_lambdas(equals_one_tup) > 0: 507 | # compute probability some lambda's multiply to one, subject to some zeros 508 | pos_indices = tuple(sorted([ int(node.split('_')[1]) for node in equals_one_tup ])) 509 | zero_indices = tuple(sorted([ int(node.split('_')[1]) for node in equals_zero_tup ])) 510 | 511 | tup = (pos_indices, zero_indices) 512 | if tup not in lambda_equals_one: 513 | lambda_equals_one[tup] = None 514 | if zero_indices not in lambda_zeros: 515 | lambda_zeros[zero_indices] = None 516 | else: 517 | # compute a Y equals one probability, and multiply it by probability of zeros 518 | if equals_one_tup[0] != '1': 519 | pos_indices = tuple(sorted([ int(node.split('_')[1]) for node in equals_one_tup ])) 520 | if pos_indices not in Y_equals_one: 521 | Y_equals_one[pos_indices] = None 522 | zero_indices = tuple(sorted([ int(node.split('_')[1]) for node in equals_zero_tup ])) 523 | if zero_indices not in lambda_zeros: 524 | lambda_zeros[zero_indices] = None 525 | 526 | if verbose: 527 | print('Expectations to estimate written down', file=sys.stderr) 528 | 529 | if solve_method[:len('triplet')] == 'triplet': 530 | triplets, new_moment_vals, abstention_probabilities = self._triplet_method_preprocess( 531 | expectations_to_estimate, solve_method) 532 | self.triplets = triplets 533 | elif solve_method == 'independencies': 534 | print('Independencies not implemented yet!') 535 | return 536 | 537 | if verbose: 538 | print('Triplets constructed', file=sys.stderr) 539 | 540 | lambda_moment_vals = {} 541 | for moment in new_moment_vals: 542 | if moment not in lambda_moment_vals: 543 | lambda_moment_vals[moment] = None 544 | 545 | # now time to compute all the Y marginals 546 | self.cb = self._compute_class_balance(class_balance, Y_dev) 547 | Y_marginals = self._compute_Y_marginals(Y_marginals) 548 | 549 | if verbose: 550 | print('Y marginals computed', file=sys.stderr) 551 | 552 | Y_equals_one = self._compute_Y_equals_one(Y_equals_one) 553 | 554 | if verbose: 555 | print('Y equals one computed', file=sys.stderr) 556 | 557 | self.Y_marginals = Y_marginals 558 | self.Y_equals_one = Y_equals_one 559 | 560 | # now time to compute the lambda moments, marginals, zero conditions, and abstention probs 561 | lambda_marginals, lambda_moment_vals, lambda_equals_one, lambda_zeros, abstention_probabilities = self._lambda_pass( 562 | L_train, lambda_marginals, lambda_moment_vals, lambda_equals_one, 563 | lambda_zeros, abstention_probabilities, verbose = verbose) 564 | 565 | if verbose: 566 | print('lambda marginals, moments, conditions computed', file=sys.stderr) 567 | 568 | self.lambda_marginals = lambda_marginals 569 | self.lambda_moment_vals = lambda_moment_vals 570 | self.lambda_equals_one = lambda_equals_one 571 | self.lambda_zeros = lambda_zeros 572 | self.abstention_probabilities = abstention_probabilities 573 | 574 | # put observable cliques in the right place 575 | for idx in observable_cliques: 576 | clique = marginals[idx][0] 577 | indices = tuple(sorted([ int(node.split('_')[1]) for node in clique ])) 578 | 579 | if 'Y' in clique[0]: 580 | marginal = Y_marginals[indices] 581 | else: 582 | marginal = lambda_marginals[indices] 583 | 584 | marginals[idx] = 
(clique, marginal) 585 | 586 | # get unobserved probabilities 587 | if solve_method[:len('triplet')] == 'triplet': 588 | probability_values, expectation_values = self._triplet_method_probabilities( 589 | triplets, lambda_moment_vals, lambda_zeros, 590 | abstention_probabilities, sign_recovery, solve_method) 591 | elif solve_method == 'independencies': 592 | print('Independencies not implemented yet!') 593 | return 594 | 595 | self.probability_values = probability_values 596 | self.expectation_values = expectation_values 597 | 598 | if verbose: 599 | print('Unobserved probabilities computed', file=sys.stderr) 600 | 601 | # put values into the R vectors 602 | for r_val in r_vals: 603 | if not self.allow_abstentions or r_val[1] == tuple(['0']): 604 | equals_one_tup = r_val if not self.allow_abstentions else r_val[0] 605 | 606 | if equals_one_tup[0] == '1': 607 | # If the value is 1, the probability is just 1 608 | pass 609 | elif num_Ys(equals_one_tup) != 0 and num_lambdas(equals_one_tup) != 0: 610 | # If this contains lambdas and Y's, we can't observe it 611 | r_vals[r_val] = probability_values[r_val] 612 | elif num_Ys(equals_one_tup) != 0: 613 | # We need to cache this moment 614 | indices = tuple(sorted([ int(node.split('_')[1]) for node in equals_one_tup ])) 615 | r_vals[r_val] = Y_equals_one[indices] 616 | elif num_lambdas(equals_one_tup) != 0: 617 | indices = tuple(sorted([ int(node.split('_')[1]) for node in equals_one_tup ])) 618 | r_vals[r_val] = lambda_equals_one[indices] 619 | else: 620 | # we allow abstentions, and there are clauses that are equal to zero 621 | equals_one_tup = r_val[0] 622 | equals_zero_tup = r_val[1] 623 | if num_lambdas(equals_one_tup) > 0 and num_Ys(equals_one_tup) > 0: 624 | # we can't observe this 625 | r_vals[r_val] = probability_values[r_val] 626 | elif num_lambdas(equals_one_tup) > 0: 627 | # compute lambda moment, subject to some zeros 628 | pos_indices = tuple(sorted([ int(node.split('_')[1]) for node in equals_one_tup ])) 629 | zero_indices = tuple(sorted([ int(node.split('_')[1]) for node in equals_zero_tup ])) 630 | 631 | tup = (pos_indices, zero_indices) 632 | r_vals[r_val] = lambda_equals_one[tup] 633 | else: 634 | # compute a Y moment, and multiply it by probability of zeros 635 | if equals_one_tup[0] != '1': 636 | pos_indices = tuple(sorted([ int(node.split('_')[1]) for node in equals_one_tup ])) 637 | 638 | pos_prob = Y_equals_one[pos_indices] 639 | else: 640 | pos_prob = 1. 
641 | zero_indices = tuple(sorted([ int(node.split('_')[1]) for node in equals_zero_tup ]))
642 | zero_probs = lambda_zeros[zero_indices]
643 |
644 | r_vals[r_val] = pos_prob * zero_probs
645 |
646 | self.r_vals = r_vals
647 |
648 | if verbose:
649 | print('R values computed', file=sys.stderr)
650 |
651 | # solve for marginal values
652 | for idx in non_observable_cliques:
653 | clique = list(reversed(sorted(marginals[idx][0])))
654 | r_vec = r_vecs[idx]
655 |
656 | r_vec_vals = np.array([ r_vals[exp] for exp in r_vec ])
657 |
658 | # e_vec is the vector of marginal values
659 | e_vec = self._generate_e_vector(clique)
660 |
661 | b_matrix = self._generate_b_matrix(clique)
662 |
663 | e_vec_vals = np.linalg.inv(b_matrix) @ r_vec_vals
664 |
665 | e_vec_val_index = { tup: i for i, tup in enumerate(e_vec) }
666 | marginal_vals = np.array([
667 | e_vec_vals[e_vec_val_index[tup]]
668 | for tup in sorted(e_vec)
669 | ])
670 |
671 | if flip_negative:
672 | marginal_vals[marginal_vals < 0] = marginal_vals[marginal_vals < 0] * -1
673 | marginal_vals /= sum(marginal_vals)
674 | elif clamp:
675 | marginal_vals[marginal_vals < 0] = 1e-8
676 | marginal_vals /= sum(marginal_vals)
677 |
678 | indices = [ int(node.split('_')[1]) for node in clique ]
679 | lf_indices = sorted(indices[:-1])
680 | Y_idx = indices[-1]
681 |
682 | variables = [ 'lambda_{}'.format(i) for i in lf_indices ] + [ 'Y_{}'.format(Y_idx) ]
683 |
684 | # cardinality 3 for lambda variables if you allow abstentions, 2 for Y's
685 | cardinalities = [
686 | 3 if self.allow_abstentions else 2
687 | for i in range(len(lf_indices))
688 | ] + [2]
689 |
690 | marginal = DiscreteFactor(variables, cardinalities, marginal_vals).normalize(inplace = False)
691 |
692 | marginals[idx] = (clique, marginal)
693 |
694 | self.clique_marginals = marginals[:len(self.junction_tree.nodes)]
695 | self.separator_marginals = marginals[len(self.junction_tree.nodes):]
696 | separator_degrees = {
697 | sep: 0
698 | for sep in self.separator_sets
699 | }
700 | for clique1, clique2 in self.junction_tree.edges:
701 | separator_degrees[tuple(sorted(list((set(clique1).intersection(set(clique2))))))] += 1
702 | self.separator_degrees = separator_degrees
703 |
704 | def reduce_marginal(self, marginal, data_point):
705 | lf_vals = [-1, 0, 1] if self.allow_abstentions else [-1, 1]
706 | params = [
707 | (var, lf_vals.index(data_point[int(var.split('_')[1])]))
708 | for var in marginal.variables if 'lambda' in var
709 | ]
710 | return marginal.reduce(params, inplace=False) if len(params) > 0 else marginal
711 |
712 | def predict_proba(self, L_matrix, verbose=True):
713 | '''Predict the probabilities of the Y's given the outputs of the LF's.
714 |
715 | L_matrix: a |Y| x m matrix of LF outputs. L_matrix[k][i] is the value of \lambda_i on item k.
716 | 1 means positive, -1 means negative, 0 means abstain.
717 |
718 | Let C be the set of all cliques in the graphical model, and S the set of all separator sets.
719 | Let d(s) for s \in S be the number of maximal cliques that s separates.
720 |
721 | Then, we have the following formula for the joint probability:
722 |
723 | P(\lambda_1, ..., \lambda_m, Y_1, ..., Y_v) =
724 | \prod_{c \in C} \mu_c(c) / \prod_{s \in S} [\mu_s(s)]^(d(s) - 1)
725 |
726 | Where \mu_c and \mu_s are the marginal probabilities of a clique c or a separator s, respectively.
727 | We solved for these marginals during the fit function, so now we use them for inference!
728 |
729 | Outputs: a |Y| x 2^v matrix of probabilities.
The probabilities for the combinations are 730 | sorted lexicographically. 731 | ''' 732 | def num_lambdas(nodes): 733 | return len([ 734 | node for node in nodes if 'lambda' in node 735 | ]) 736 | 737 | L_matrix = np.array(L_matrix) 738 | 739 | Y_vecs = self.enumerate_ys() 740 | numerator_vals_by_lambda_count = [] 741 | max_lambda_count = max([ num_lambdas(clique) for clique, marginal in self.clique_marginals ]) 742 | 743 | # Compute all marginals that have lambda_count lambdas 744 | for lambda_count in range(1, max_lambda_count + 1): 745 | correct_lambda_cliques = [ 746 | (clique, marginal) 747 | for clique, marginal in self.clique_marginals if num_lambdas(clique) == lambda_count 748 | ] 749 | if len(correct_lambda_cliques) == 0: 750 | continue 751 | lambda_options = (-1, 0, 1) if self.allow_abstentions else (-1, 1) 752 | lambda_vals = { 753 | i: lambda_options 754 | for i in range(lambda_count) 755 | } 756 | lambda_vecs = sorted([ 757 | [ vec_dict[i] for i in range(lambda_count) ] 758 | for vec_dict in dict_product(lambda_vals) 759 | ]) 760 | 761 | # index by Y_vec, clique, and lambda value 762 | A_lambda = np.zeros((len(Y_vecs), len(correct_lambda_cliques), len(lambda_vecs))) 763 | 764 | for i, Y_vec in enumerate(Y_vecs): 765 | for j, (clique, marginal) in enumerate(correct_lambda_cliques): 766 | lambda_marginal = marginal.reduce( 767 | [ 768 | ('Y_{}'.format(Y_idx), y_val if y_val == 1 else 0) 769 | for Y_idx, y_val in enumerate(Y_vec) 770 | if 'Y_{}'.format(Y_idx) in clique 771 | ], 772 | inplace = False 773 | ) 774 | for k, lambda_vec in enumerate(lambda_vecs): 775 | A_lambda[i, j, k] = lambda_marginal.reduce( 776 | [ 777 | ( 778 | clique_node, 779 | lambda_options.index(lambda_val) 780 | ) 781 | for clique_node, lambda_val in zip(clique, lambda_vec) 782 | ], 783 | inplace=False).values 784 | 785 | indexes = np.array([ 786 | [ 787 | np.sum([ 788 | ((lambda_options.index(data_point[int(node.split('_')[1])])) * 789 | ((len(lambda_options)) ** (lambda_count - i - 1))) 790 | for i, node in enumerate(clique[:-1]) 791 | ]) 792 | for clique, marginal in correct_lambda_cliques 793 | ] 794 | for data_point in L_matrix 795 | ]).astype('int') 796 | 797 | clique_values = A_lambda[:, np.arange(indexes.shape[1]), indexes] 798 | 799 | numerator_values = np.prod(clique_values, axis=2) 800 | numerator_vals_by_lambda_count.append(numerator_values) 801 | 802 | # Compute all marginals that have zero lambdas 803 | zero_lambda_cliques = [ 804 | (clique, marginal) 805 | for clique, marginal in self.clique_marginals if num_lambdas(clique) == 0 806 | ] 807 | if len(zero_lambda_cliques) > 0: 808 | A_y = np.zeros((len(Y_vecs), len(zero_lambda_cliques))) 809 | for i, Y_vec in enumerate(Y_vecs): 810 | for j, (clique, marginal) in enumerate(zero_lambda_cliques): 811 | Y_marginal = marginal.reduce( 812 | [ 813 | ('Y_{}'.format(Y_idx), y_val if y_val == 1 else 0) 814 | for Y_idx, y_val in enumerate(Y_vec) 815 | if 'Y_{}'.format(Y_idx) in clique 816 | ], 817 | inplace = False 818 | ) 819 | A_y[i, j] = Y_marginal.values 820 | 821 | y_probs = np.prod(A_y, axis=1) 822 | 823 | numerator_ys = np.array([y_probs,] * L_matrix.shape[0]).T 824 | 825 | # Compute all separator marginals 826 | zero_lambda_separators = [ 827 | (clique, marginal) 828 | for clique, marginal in self.separator_marginals if num_lambdas(clique) == 0 829 | ] 830 | 831 | A_y_sep = np.zeros((len(Y_vecs), len(zero_lambda_separators))) 832 | for i, Y_vec in enumerate(Y_vecs): 833 | for j, (clique, marginal) in enumerate(zero_lambda_separators): 834 | 
835 |                     [
836 |                         ('Y_{}'.format(Y_idx), y_val if y_val == 1 else 0)
837 |                         for Y_idx, y_val in enumerate(Y_vec)
838 |                         if 'Y_{}'.format(Y_idx) in clique
839 |                     ],
840 |                     inplace = False
841 |                 )
842 |                 A_y_sep[i, j] = Y_marginal.values ** (self.separator_degrees[clique] - 1)
843 | 
844 |         y_probs_sep = np.prod(A_y_sep, axis=1)
845 | 
846 |         denominator_ys = np.array([y_probs_sep,] * L_matrix.shape[0]).T
847 | 
848 |         predictions = numerator_vals_by_lambda_count[0]
849 |         for lambda_numerator in numerator_vals_by_lambda_count[1:]:
850 |             predictions = predictions * lambda_numerator
851 |         if len(zero_lambda_cliques) > 0:
852 |             predictions = predictions * numerator_ys
853 |         predictions = (predictions / denominator_ys).T
854 | 
855 |         # in the case of all-zero rows, add a small constant so normalization doesn't divide by zero
856 |         predictions[predictions.sum(axis = 1) == 0] += .001
857 | 
858 |         normalized_preds = predictions / np.array(([predictions.sum(axis = 1),] * len(Y_vecs))).T
859 | 
860 |         return normalized_preds
861 | 
862 |     def predict(self, L_matrix, verbose=True):
863 |         '''Predict the value of the Y's that best fits the outputs of the LF's.
864 | 
865 |         L_matrix: an n x m matrix of LF outputs, where n is the number of data points. L_matrix[k][i] is the value of \lambda_i on item k.
866 |             1 means positive, -1 means negative, 0 means abstain.
867 | 
868 |         Let C be the set of all cliques in the graphical model, and S the set of all separator sets.
869 |         Let d(s) for s \in S be the number of maximal cliques that s separates.
870 | 
871 |         Then, we have the following formula for the joint probability:
872 | 
873 |           P(\lambda_1, ..., \lambda_m, Y_1, ..., Y_v) =
874 |               \prod_{c \in C} \mu_c(c) / \prod_{s \in S} [\mu_s(s)]^(d(s) - 1)
875 | 
876 |         Where \mu_c and \mu_s are the marginal probabilities of a clique c or a separator s, respectively.
877 |         We solved for these marginals during the fit function, so now we use them for inference!
878 | 
879 |         Outputs: an n x v matrix of predicted Y values.
880 |         '''
881 | 
882 |         Y_vecs = self.enumerate_ys()
883 |         combination_probs = self.predict_proba(L_matrix, verbose=verbose)
884 |         most_likely = np.argmax(combination_probs, axis=1)
885 |         preds = np.array(Y_vecs)[most_likely]
886 | 
887 |         return preds
888 | 
889 |     def predict_proba_marginalized(self, L_matrix, verbose=False):
890 |         '''Predict the probabilities of the Y's given the outputs of the LF's, marginalizing
891 |         out the other Y values (returns a separate probability of +1 for each Y).
892 | 
893 |         L_matrix: an n x m matrix of LF outputs, where n is the number of data points. L_matrix[k][i] is the value of \lambda_i on item k.
894 |             1 means positive, -1 means negative, 0 means abstain.
895 | 
896 |         Let C be the set of all cliques in the graphical model, and S the set of all separator sets.
897 |         Let d(s) for s \in S be the number of maximal cliques that s separates.
898 | 
899 |         Then, we have the following formula for the joint probability:
900 | 
901 |           P(\lambda_1, ..., \lambda_m, Y_1, ..., Y_v) =
902 |               \prod_{c \in C} \mu_c(c) / \prod_{s \in S} [\mu_s(s)]^(d(s) - 1)
903 | 
904 |         Where \mu_c and \mu_s are the marginal probabilities of a clique c or a separator s, respectively.
905 |         We solved for these marginals during the fit function, so now we use them for inference!
906 | 
907 |         Outputs: a flat array of n * v marginalized probabilities P(Y_i = 1), one for each
908 |             task, for each data point (grouped by data point).
909 |         '''
910 |         combination_probs = self.predict_proba(L_matrix, verbose=verbose)
911 |         # construct indices for each task
912 |         Y_vecs = self.enumerate_ys()
913 |         task_indices = [
914 |             [ idx for idx, y_vec in enumerate(Y_vecs) if y_vec[i] == 1 ]
915 |             for i in range(self.v)
916 |         ]
917 | 
918 |         return np.sum(combination_probs[:, task_indices], axis=2).reshape(len(combination_probs) * self.v)
919 | 
920 |     def estimated_accuracies(self):
921 |         '''Get the estimated accuracies of each LF.
922 | 
923 |         Assumes that each LF is connected to exactly one Y node.
924 |         Let Y(i) denote the node that LF i is connected to.
925 |         This function returns an array of values P(lambda_i = Y(i)), for each LF i.
926 | 
927 |         Outputs: an m-sized array of estimated LF accuracies.
928 |         '''
929 |         if not self.probability_values:
930 |             print('You need to train the label model first!')
931 |             return
932 | 
933 |         accuracies = []
934 |         for i in range(self.m):
935 |             lambda_node = 'lambda_{}'.format(i)
936 |             Y_node = [
937 |                 e2
938 |                 for e1, e2 in self.G.edges
939 |                 if e1 == lambda_node and 'Y' in e2
940 |             ][0]
941 |             # with abstentions, the key also conditions on the LF not abstaining
942 |             prob_key = (
943 |                 (lambda_node, Y_node), ('0', )
944 |             ) if self.allow_abstentions else (lambda_node, Y_node)
945 | 
946 |             accuracies.append(self.probability_values[prob_key])
947 | 
948 |         return accuracies
949 | 
--------------------------------------------------------------------------------
/flyingsquid/pytorch_loss.py:
--------------------------------------------------------------------------------
1 | from flyingsquid.label_model import LabelModel
2 | import numpy as np
3 | import torch
4 | import torch.nn as nn
5 | class FSLoss(nn.Module):
6 |     '''
7 |     Expose FlyingSquid as a loss function.
8 |     The loss function takes sequences: one sequence of outputs of your end model,
9 |     and another sequence of LF votes.
10 | 
11 |     Let `v` be the length of the sequence.
12 |     We will compute BCEWithLogitsLoss, averaged over every element of the
13 |     sequence (and over every sequence in the batch).
14 | 
15 |     Let `m` be the number of labeling functions.
16 | 
17 |     Let `batch_size` be the size of your batch during training.
18 | 
19 |     The loss function will take two arguments: `outputs` and `weak_labels`.
20 |     * The shape of `outputs` will be `batch_size * v`
21 |     * The shape of `weak_labels` will be `batch_size * v * m`
22 | 
23 |     ```
24 |     # outputs will be batch_size * v
25 |     # weak_labels will be batch_size * v * m
26 |     loss(outputs, weak_labels)
27 |     ```
28 | 
29 |     The loss function will keep a buffer of N sequences of previous weak labels
30 |     that it's seen.
31 |     Each step, the loss function does the following:
32 |     * For each sequence in the batch (zip over everything in outputs and weak_labels):
33 |         * Add the sequence from `weak_labels` to the buffer (kicking out the oldest
34 |           items in the buffer)
35 |         * Use the triplet method over everything in the buffer (the buffer needs to be on
36 |           the CPU) to get probabilistic labels, a tensor of shape `T` (put the tensor
37 |           onto device)
38 |         * For each element in the sequence, compute `BCEWithLogitsLoss` between the
39 |           output and the probabilistic label
40 |     * Return the average over losses in the sequence
41 | 
42 |     When the dataloader isn't shuffling data, this amounts to "streaming".
43 | 
44 |     Args:
45 |         m: number of LF's
46 |         v: number of Y tasks
47 |         task_deps: edges between the tasks. (i, j) in y_edges means that
48 |             there is an edge between y_i and y_j.
49 |         lf_task_deps: edges between LF's and tasks. (i, j) in lambda_y_edges
50 |             means that there is an edge between lambda_i and y_j.
51 |         lf_deps: edges between LF's. (i, j) in lambda_edges means that
52 |             there is an edge between lambda_i and lambda_j.
53 |         cb: the class balance
54 |         allow_abstentions: if True, allow abstentions in LF votes
55 |         device: which device to store the loss/gradients
56 |         buffer_capacity: how many sequences of LF's to cache
57 |         update_frequency: how often to retrain the label model
58 |         clamp_vals: if True, clamp the probabilities out of FlyingSquid to 0.
59 |             or 1.
60 |         triplets: if specified, use this set of triplets for the triplet method
61 |         pos_weight: if specified, set the weight of the positive class to this
62 |             in BCEWithLogitsLoss
63 | 
64 |     Example::
65 | 
66 |         T = ... # length of a sequence
67 |         m = m_per_task * T # m_per_task LF's per frame
68 | 
69 |         # this creates a new triplet label model under the hood
70 |         criterion = FSLoss(
71 |             m, T,
72 |             [(i, i + 1) for i in range(T - 1)], # chain dependencies for tasks
73 |             [(i + m_per_task * j, j) # LF's have dependencies to the frames they vote on
74 |              for i in range(m_per_task) for j in range(T)],
75 |             [], # no dependencies between LF's
76 |             cb = ... # pass in class balance if you need to
77 |         )
78 | 
79 |         model = ... # end model
80 | 
81 |         frame_sequence = [...] # sequence of T frames
82 |         lf_votes = [...] # (T, m) vector of LF votes
83 | 
84 |         model_outputs = [ # run the model on each frame
85 |             model(frame)
86 |             for frame in frame_sequence
87 |         ]
88 | 
89 |         # This caches the votes in lf_votes, retrains the label model if necessary, and
90 |         # generates probabilistic labels for each frame from the LF votes.
91 |         # Then, `BCEWithLogitsLoss` on the model outputs and probabilistic labels is used
92 |         # to generate the loss value that can be backpropped.
93 |         loss = criterion(
94 |             torch.tensor([model_outputs]),
95 |             torch.tensor([lf_votes])
96 |         )
97 |         loss.backward()
98 |     '''
99 | 
100 |     def __init__(self, m, v=1, task_deps=[], lf_task_deps=[], lf_deps=[],
101 |                  Y_dev=None, cb=None, allow_abstentions = True, device='cpu',
102 |                  buffer_capacity=100, update_frequency=10, clamp_vals=False,
103 |                  triplets=None, pos_weight=None):
104 |         super(FSLoss, self).__init__()
105 |         self.m = m
106 |         self.v = v
107 |         self.task_deps = task_deps
108 |         self.lf_task_deps = lf_task_deps
109 |         if self.lf_task_deps == []:
110 |             self.lf_task_deps = [(i, 0) for i in range(m)]
111 |         self.lf_deps = lf_deps
112 |         self.Y_dev = Y_dev
113 |         self.cb = cb
114 |         self.device = device
115 |         self.clamp_vals = clamp_vals
116 | 
117 |         self.lm = LabelModel(m, v=v, y_edges=task_deps, lambda_y_edges=self.lf_task_deps,
118 |                              lambda_edges=lf_deps, allow_abstentions = allow_abstentions,
119 |                              triplets=triplets)
120 | 
121 |         self.criterion = nn.BCEWithLogitsLoss() if pos_weight is None else nn.BCEWithLogitsLoss(pos_weight = pos_weight)
122 |         self.buffer_capacity = buffer_capacity
123 |         self.update_frequency = update_frequency
124 | 
125 |         # register buffer for LF outputs
126 |         self.register_buffer('lf_buffer', torch.zeros((buffer_capacity, m), dtype=torch.long))
127 | 
128 |         # register buffer to keep track of how many items are in the buffer
129 |         self.register_buffer('buffer_size', torch.zeros(1, dtype=torch.long))
130 | 
131 |         # register buffer to keep track of where you are in the buffer
132 |         self.register_buffer('buffer_index', torch.zeros(1, dtype=torch.long))
133 | 
134 |     def forward(self, predictions, weak_labels, update_frequency = None):
135 |         '''
136 |         Generate probabilistic labels from the weak labels, and use `BCEWithLogitsLoss` to
137 |         get the actual loss value for end model training.
138 |         Also caches the LF votes, and re-trains the label model if necessary (depending on
139 |         update_frequency).
140 | 
141 |         Args:
142 |             predictions: A (batch_size, v)-sized tensor of model outputs. For sequences,
143 |                 v is usually the length of the sequence.
144 |             weak_labels: A (batch_size, m)-sized tensor of weak labels.
145 | 
146 |         Returns:
147 |             Computes BCEWithLogitsLoss on every item in the batch (for each item, computes it
148 |             between the v model outputs and the v probabilistic labels), and returns the
149 |             average.
150 |         '''
151 |         update_frequency = update_frequency if update_frequency else self.update_frequency
152 | 
153 |         output = torch.tensor(0., requires_grad=True, device=self.device)
154 | 
155 |         for i, (prediction, label_vector) in enumerate(zip(predictions, weak_labels)):
156 |             self.lf_buffer[self.buffer_index] = label_vector
157 |             if self.buffer_size < self.buffer_capacity:
158 |                 self.buffer_size += 1
159 | 
160 |             if (self.buffer_index % update_frequency) == 0:
161 |                 L_train = self.lf_buffer.cpu().numpy()[:self.buffer_size]
162 | 
163 |                 self.lm.fit(
164 |                     L_train,
165 |                     Y_dev = self.Y_dev,
166 |                     class_balance = self.cb
167 |                 )
168 | 
169 |             self.buffer_index += 1
170 |             if self.buffer_index == self.buffer_capacity:
171 |                 self.buffer_index = torch.tensor(0)
172 | 
173 |             # labels is a flat array of v marginal probabilities for this sequence
174 |             labels = self.lm.predict_proba_marginalized(
175 |                 [label_vector.cpu().numpy()], verbose=False)
176 |             if self.clamp_vals:
177 |                 labels = [1. if pred >= 0.5 else 0. for pred in labels]
178 | 
179 |             label_tensor = torch.tensor(labels, requires_grad=True, device=self.device).view(prediction.shape)
180 | 
181 |             output = output + self.criterion(
182 |                 prediction,
183 |                 label_tensor)
184 | 
185 |         return output / predictions.shape[0]
186 | 
187 | 
188 | class MajorityVoteLoss(nn.Module):
189 |     '''
190 |     Expose majority vote as a loss function (for baselines).
191 | 
192 |     Let `m` be the number of labeling functions.
193 | 
194 |     Let `batch_size` be the size of your batch during training.
195 | 
196 |     The loss function will take two arguments: `outputs` and `weak_labels`.
197 |     * The shape of `outputs` will be `batch_size`
198 |     * The shape of `weak_labels` will be `batch_size * m`
199 | 
200 |     ```
201 |     # outputs will be batch_size
202 |     # weak_labels will be batch_size * m
203 |     loss(outputs, weak_labels)
204 |     ```
205 | 
206 |     Each step, the loss function does the following:
207 |     * For each item in the batch (zip over everything in outputs and weak_labels):
208 |         * Take a majority vote over the m LF votes: the label is 1 if the votes
209 |           sum to a positive number, and 0 otherwise
210 |         * Compute `BCEWithLogitsLoss` between the output and the majority-vote
211 |           label
212 |     * Return the average over the losses in the batch
213 | 
214 |     Unlike `FSLoss`, no buffer of past votes is kept and no label model is
215 |     trained; the majority vote is computed independently for each item in
216 |     the batch. The `update_frequency` argument to `forward` is ignored; it is
217 |     only kept so that this loss can be swapped in for `FSLoss` in training
218 |     scripts without changing the call site.
219 | 
220 |     Args:
221 |         device: which device to store the loss/gradients
222 | 
223 |     Example::
224 | 
225 |         m = ... # number of LF's
226 | 
227 |         # no label model is created; labels come from a per-item majority vote
228 |         criterion = MajorityVoteLoss(
229 |             device = ...
230 |         )
231 | 
232 |         model = ... # end model
233 | 
234 |         frame_sequence = [...] # sequence of T frames
235 |         lf_votes = [...] # (T, m) vector of LF votes
236 | 
237 |         model_outputs = [ # run the model on each frame
238 |             model(frame)
239 |             for frame in frame_sequence
240 |         ]
241 | 
242 |         # This takes a majority vote over the LF votes for each frame
243 |         # (1 if the votes sum to a positive number, 0 otherwise).
244 |         # Then, `BCEWithLogitsLoss` on the model outputs and the majority-vote
245 |         # labels is used to generate the loss value that can be backpropped.
246 |         loss = criterion(
247 |             torch.tensor(model_outputs),
248 |             torch.tensor(lf_votes)
249 |         )
250 |         loss.backward()
251 |     '''
252 | 
253 |     def __init__(self, device='cpu', pos_weight=None):
254 |         super(MajorityVoteLoss, self).__init__()
255 |         self.criterion = nn.BCEWithLogitsLoss() if pos_weight is None else nn.BCEWithLogitsLoss(pos_weight = pos_weight)
256 |         self.device = device
257 | 
258 |     def forward(self, predictions, weak_labels, update_frequency = None):
259 |         '''
260 |         Generate majority-vote labels from the weak labels, and use `BCEWithLogitsLoss` to
261 |         get the actual loss value for end model training.
262 |         Unlike `FSLoss.forward`, this does not cache votes or train a label model; the
263 |         update_frequency argument is ignored.
264 | 
265 |         Args:
266 |             predictions: A (batch_size)-sized tensor of model outputs.
267 |             weak_labels: A (batch_size, m)-sized tensor of weak labels.
268 | 
269 |         Returns:
270 |             Computes BCEWithLogitsLoss on every item in the batch (for each item, computes it
271 |             between the model output and the majority-vote label), and returns the
272 |             average.
273 |         '''
274 | 
275 |         output = torch.tensor(0., requires_grad=True, device=self.device)
276 | 
277 |         for i, (prediction, label_vector) in enumerate(zip(predictions, weak_labels)):
278 |             label = (np.sum(label_vector.cpu().numpy()) > 0).astype(float)
279 | 
280 |             label_tensor = torch.tensor(label, requires_grad=True, device=self.device).view(prediction.shape)
281 | 
282 |             output = output + self.criterion(
283 |                 prediction,
284 |                 label_tensor)
285 | 
286 |         return output / predictions.shape[0]
287 | 
--------------------------------------------------------------------------------
/setup.py:
--------------------------------------------------------------------------------
1 | from setuptools import setup
2 | 
3 | if __name__ == "__main__":
4 |     setup(name='flyingsquid',
5 |           version='0.0.0a0',
6 |           description='Weak supervision with triplet methods',
7 |           url='https://github.com/HazyResearch/flyingsquid',
8 |           author='Dan Fu',
9 |           author_email='danfu@cs.stanford.edu',
10 |           license='Apache 2.0',
11 |           packages=['flyingsquid'])
12 | 
--------------------------------------------------------------------------------
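
For quick reference, here is a minimal sketch of the inference API implemented above. It is not part of the repository's files; it only exercises calls that appear in them (the `LabelModel` constructor keywords, `fit`, `predict_proba`, `predict`, and `predict_proba_marginalized`, as `FSLoss` invokes them). The number of LFs, the data size, and the random votes are made-up stand-ins for real labeling-function outputs.

```python
import numpy as np
from flyingsquid.label_model import LabelModel

m, n = 5, 1000   # five labeling functions, 1000 unlabeled data points (assumptions)

# Made-up LF votes in {-1, 0, 1}; 0 is an abstention.
rng = np.random.default_rng(0)
L_train = rng.choice([-1, 0, 1], size=(n, m))

# Single-task model (v=1); the keyword names mirror the LabelModel(...) call
# in FSLoss.__init__ above.
label_model = LabelModel(
    m, v=1,
    y_edges=[],                                 # no dependencies between tasks
    lambda_y_edges=[(i, 0) for i in range(m)],  # every LF votes on the one task
    lambda_edges=[],                            # no dependencies between LFs
    allow_abstentions=True)
label_model.fit(L_train)

probs = label_model.predict_proba(L_train, verbose=False)    # n x 2^v matrix
preds = label_model.predict(L_train, verbose=False)          # n x v hard predictions
marginals = label_model.predict_proba_marginalized(L_train)  # flat, n * v entries
print(probs.shape, preds.shape, marginals.shape)
```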
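Similarly, a sketch of `FSLoss` driving a single training step, following the pattern in its docstring. The linear end model, optimizer, feature tensor, and random votes are placeholder assumptions, not anything the library provides; in a real run you would also want the buffer to accumulate more than one sequence before the first label-model fit, since fitting the triplet method on one example is numerically fragile.

```python
import torch
import torch.nn as nn
from flyingsquid.pytorch_loss import FSLoss

T = 4                # sequence length = number of Y tasks (assumption)
m_per_task = 3       # LFs voting on each frame (assumption)
m = m_per_task * T

criterion = FSLoss(
    m, T,
    task_deps=[(i, i + 1) for i in range(T - 1)],    # chain over frames
    lf_task_deps=[(i + m_per_task * j, j)            # each LF votes on one frame
                  for j in range(T) for i in range(m_per_task)],
    lf_deps=[])

end_model = nn.Linear(16, 1)                         # placeholder end model
optimizer = torch.optim.SGD(end_model.parameters(), lr=1e-3)

features = torch.randn(T, 16)                        # placeholder frame features
lf_votes = torch.randint(-1, 2, (m,))                # placeholder votes in {-1, 0, 1}

optimizer.zero_grad()
outputs = end_model(features).view(1, T)             # a batch of one sequence
loss = criterion(outputs, lf_votes.view(1, m))       # caches votes, fits, scores
loss.backward()
optimizer.step()
```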