├── .github └── FUNDING.yml ├── README.md ├── bridges.csv ├── cache └── README.md ├── companies.sqlite ├── iris.csv ├── irnet ├── .dockerignore ├── Dockerfile ├── Makefile ├── README.md ├── install.sh ├── requirements.txt ├── serve.sh ├── server │ ├── add_csv.py │ ├── add_question.py │ ├── download.sh │ ├── prediction_server.py │ ├── requirements.txt │ └── setup_nltk.py └── setup.sh ├── players.csv ├── sqlova ├── .dockerignore ├── Dockerfile ├── Makefile ├── README.md ├── fetch_models.sh ├── run_services.sh └── support │ ├── bert_config_uncased_L-12_H-768_A-12.json │ ├── bert_config_uncased_L-24_H-1024_A-16.json │ ├── vocab_uncased_L-12_H-768_A-12.txt │ └── vocab_uncased_L-24_H-1024_A-16.txt ├── test.sh └── valuenet ├── Dockerfile ├── Makefile ├── example_data └── original │ ├── database │ └── data │ │ └── data.sqlite │ └── tables.json ├── install.sh ├── requirements.txt ├── serve.sh ├── server ├── add_csv.py ├── add_question.py └── prediction_server.py └── setup.sh /.github/FUNDING.yml: -------------------------------------------------------------------------------- 1 | github: [paulfitz] 2 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | Infer SQL queries from plain-text questions and table headers. 2 | 3 | Requirements: 4 | * install `docker` 5 | * install `curl` (or, if you're feeling brave, [asql](https://github.com/paulfitz/asql)) 6 | * Make sure docker allows at least 3GB of RAM (see `Docker`>`Preferences`>`Advanced` 7 | or equivalent) for SQLova, or 5GB for IRNet or ValueNet. 8 | 9 | I take pretrained models published along with academic papers, and do whatever it takes 10 | to make them testable on fresh data (academic work often omits that, with code tied 11 | to a particular benchmark dataset). I spend days tracking down and patching obscure 12 | data preprocessing steps so you don't have to. 13 | 14 | ![ValueNet example](https://user-images.githubusercontent.com/118367/89111827-75e5db80-d428-11ea-912a-e36a176bb56c.png) 15 | 16 | So far I've packaged three models: 17 | * [SQLova](#sqlova). Works on single tables. 18 | * [ValueNet](#valuenet). Works on multiple tables, and 19 | makes an effort to predict parameters. 20 | * [IRNet](#irnet). Works on multiple tables, but doesn't 21 | predict parameters. 22 | 23 | In each case, I've mangled the original network somewhat, so if they interest you do follow up 24 | with the original sources. 25 | 26 | ## SQLova 27 | 28 | This wraps up a published pretrained model for SQLova (https://github.com/naver/sqlova/). 29 | 30 | Fetch and start SQLova running as an api server on port 5050: 31 | 32 | ``` 33 | docker run --name sqlova -d -p 5050:5050 paulfitz/sqlova 34 | ``` 35 | 36 | Be patient, the image is about 4.2GB. Once it is running, it'll take a few seconds 37 | to load models and then you can start asking questions about CSV tables. For example: 38 | 39 | ``` 40 | curl -F "csv=@bridges.csv" -F "q=how long is throgs neck" localhost:5050 41 | # {"answer":[1800],"params":["throgs neck"],"sql":"SELECT (length) FROM bridges WHERE bridge = ?"} 42 | ``` 43 | 44 | This is using the sample `bridges.csv` included in this repo. 45 | 46 | | bridge | designer | length | 47 | |---|---|---| 48 | | Brooklyn | J. A. Roebling | 1595 | 49 | | Manhattan | G. Lindenthal | 1470 | 50 | | Williamsburg | L. L. Buck | 1600 | 51 | | Queensborough | Palmer & Hornbostel | 1182 | 52 | | Triborough | O. H. 
Ammann | 1380,383 | 53 | | Bronx Whitestone | O. H. Ammann | 2300 | 54 | | Throgs Neck | O. H. Ammann | 1800 | 55 | | George Washington | O. H. Ammann | 3500 | 56 | 57 | Here are some examples of the answers and sql inferred for plain-text questions about 58 | this table: 59 | 60 | | question | answer | sql | 61 | |---|---|---| 62 | | how long is throgs neck | 1800 | `SELECT (length) FROM bridges WHERE bridge = ? ['throgs neck']` | 63 | | who designed the george washington | O. H. Ammann | `SELECT (designer) FROM bridges WHERE bridge = ? ['george washington']` | 64 | | how many bridges are there | 8 | `SELECT count(bridge) FROM bridges` | 65 | | how many bridges are designed by O. H. Ammann | 4 | `SELECT count(bridge) FROM bridges WHERE designer = ? ['O. H. Ammann']` | 66 | | which bridge are longer than 2000 | Bronx Whitestone, George Washington | `SELECT (bridge) FROM bridges WHERE length > ? ['2000']` | 67 | | how many bridges are longer than 2000 | 2 | `SELECT count(bridge) FROM bridges WHERE length > ? ['2000']` | 68 | | what is the shortest length | 1182 | `SELECT min(length) FROM bridges` | 69 | 70 | With the `players.csv` sample from WikiSQL: 71 | 72 | | Player | No. | Nationality | Position | Years in Toronto | School/Club Team | 73 | |---|---|---|---|---|---| 74 | | Antonio Lang | 21 | United States | Guard-Forward | 1999-2000 | Duke | 75 | | Voshon Lenard | 2 | United States | Guard | 2002-03 | Minnesota | 76 | | Martin Lewis | 32, 44 | United States | Guard-Forward | 1996-97 | Butler CC (KS) | 77 | | Brad Lohaus | 33 | United States | Forward-Center | 1996 | Iowa | 78 | | Art Long | 42 | United States | Forward-Center | 2002-03 | Cincinnati | 79 | | John Long | 25 | United States | Guard | 1996-97 | Detroit | 80 | | Kyle Lowry | 3 | United States | Guard | 2012-present | Villanova | 81 | 82 | | question | answer | sql | 83 | |---|---|---| 84 | | What number did the person playing for Duke wear? | 21 | `SELECT (No.) FROM players WHERE School/Club Team = ? ['duke']` | 85 | | Who is the player that wears number 42? | Art Long | `SELECT (Player) FROM players WHERE No. = ? ['42']` | 86 | | What year did Brad Lohaus play? | 1996 | `SELECT (Years in Toronto) FROM players WHERE Player = ? ['brad lohaus']` | 87 | | What country is Voshon Lenard from? | United States | `SELECT (Nationality) FROM players WHERE Player = ? ['voshon lenard']` | 88 | 89 | Some questions about [iris.csv](https://en.wikipedia.org/wiki/Iris_flower_data_set): 90 | 91 | | question | answer | sql | 92 | |---|---|---| 93 | | what is the average petal width for virginica | 2.026 | `SELECT avg(Petal.Width) FROM iris WHERE Species = ? ['virginica']` | 94 | | what is the longest sepal for versicolor | 7.0 | `SELECT max(Sepal.Length) FROM iris WHERE Species = ? ['versicolor']` | 95 | | how many setosa rows are there | 50 | `SELECT count(col0) FROM iris WHERE Species = ? ['setosa']` | 96 | 97 | There are plenty of types of questions this model cannot answer (and that aren't covered 98 | in the dataset it is trained on, or in the sql it is permitted to generate). 99 | 100 | ## ValueNet 101 | 102 | This wraps up a published pretrained model for ValueNet (https://github.com/brunnurs/valuenet). 
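The HTTP interface is the same across all three models, so the curl calls above can also be scripted. Here's a minimal Python sketch (not part of this repo) using `requests`; it assumes one of the servers is already listening on port 5050 and that `bridges.csv` from this repo is in the working directory:

```
import requests

# Send the same multipart form fields as `curl -F`: a csv upload plus a question.
with open("bridges.csv", "rb") as f:
    resp = requests.post(
        "http://localhost:5050",
        files={"csv": ("bridges.csv", f)},
        data={"q": "how long is throgs neck"},
    )
print(resp.json())  # e.g. {"answer": [1800], "sql": "SELECT (length) FROM bridges WHERE bridge = ?", ...}
```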
103 | 104 | Fetch and start ValueNet running as an api server on port 5050: 105 | 106 | ``` 107 | docker run --name valuenet -d -p 5050:5050 paulfitz/valuenet 108 | ``` 109 | 110 | You can then ask questions of individual csv files as before, or several csv files 111 | (just repeat `-F "csv=@fileN.csv"`) or a simple sqlite db with tables related by foreign keys. 112 | In this last case, the model can answer using joins. 113 | 114 | ``` 115 | curl -F "sqlite=@companies.sqlite" -F "q=who is the CEO of Omni Cooperative" localhost:5050 116 | # {"answer":[["Dracula"]], "sql":"SELECT T1.name FROM people AS T1 JOIN organizations AS T2 \ 117 | # ON T1.id = T2.ceo_id WHERE T2.company = 'Omni Cooperative'"} 118 | curl -F "csv=@bridges.csv" -F "q=how many designers are there?" localhost:5050 119 | # {"answer":[[5]],"sql":"SELECT DISTINCT count(DISTINCT T1.designer) FROM bridges AS T1"} 120 | curl -F "csv=@bridges.csv" -F "csv=@airports.csv" -F "q=how many designers are there?" localhost:5050 121 | # same answer 122 | curl -F "csv=@bridges.csv" -F "csv=@airports.csv" -F "q=what is the name of the airport with the highest latitude?" localhost:5050 123 | # {"answer":[["Disraeli Inlet Water Aerodrome"]], 124 | # "sql":"SELECT T1.name FROM airports AS T1 ORDER BY T1.latitude_deg DESC LIMIT 1"} 125 | ``` 126 | 127 | I've included material to convert user tables into the form needed to query them. Don't 128 | judge the network by its quality here; go do a deep dive with the original - I've deviated 129 | from the original in important respects, including how named entity recognition is done. 130 | 131 | I've written up [some experiments with ValueNet](https://paulfitz.github.io/2020/08/01/translate-english-to-sql-progress-updates.html). 132 | 133 | ## IRNet 134 | 135 | This wraps up a published pretrained model for IRNet (https://github.com/microsoft/IRNet). 136 | Upstream released a better model after I packaged this, so don't judge the model by playing 137 | with it here. 138 | 139 | Fetch and start IRNet running as an api server on port 5050: 140 | 141 | ``` 142 | docker run --name irnet -d -p 5050:5050 -v $PWD/cache:/cache paulfitz/irnet 143 | ``` 144 | 145 | Be super patient! Especially on the first run, when a few large models need to 146 | be downloaded and unpacked. 147 | 148 | You can then ask questions of individual csv files as before, or several csv files 149 | (just repeat `-F "csv=@fileN.csv"`) or a simple sqlite db with tables related by foreign keys. 150 | In this last case, the model can answer using joins. 151 | 152 | ``` 153 | curl -F "sqlite=@companies.sqlite" -F "q=what city is The Firm headquartered in?" localhost:5050 154 | # Answer: SELECT T1.city FROM locations AS T1 JOIN organizations AS T2 WHERE T2.company = 1 155 | curl -F "sqlite=@companies.sqlite" -F "q=who is the CEO of Omni Cooperative" localhost:5050 156 | # Answer: SELECT T1.name FROM people AS T1 JOIN organizations AS T2 WHERE T2.company = 1 157 | curl -F "sqlite=@companies.sqlite" -F "q=what company has Dracula as CEO" localhost:5050 158 | # Answer: SELECT T1.company FROM organizations AS T1 JOIN people AS T2 WHERE T2.name = 1 159 | ``` 160 | 161 | (Note there's no value prediction, so e.g. the where clauses are `= 1` rather than something 162 | more useful). 163 | 164 | ## Postman users 165 | 166 | Curl can be replaced by Postman for those who prefer it.
Here's a working set-up: 167 | ![Postman version](https://user-images.githubusercontent.com/118367/73127529-b05d5000-3f8f-11ea-8499-b58273ca1961.png) 168 | 169 | ## Other models 170 | 171 | I hope to track research in the area and substitute in models as they become available: 172 | 173 | * [Spider leaderboard](https://yale-lily.github.io/spider) 174 | * [WikiSQL leaderboard](https://github.com/salesforce/WikiSQL#leaderboard) 175 | * [SparC leaderboard](https://yale-lily.github.io/sparc) 176 | 177 | * [RAT-SQL](https://github.com/Microsoft/rat-sql) 178 | 179 | ## Live demoes 180 | 181 | * [Photon](https://naturalsql.com/) by the SalesForce group is a good live demo of 182 | text-to-SQL. 183 | -------------------------------------------------------------------------------- /bridges.csv: -------------------------------------------------------------------------------- 1 | bridge,designer,length 2 | Brooklyn,J. A. Roebling,1595 3 | Manhattan,G. Lindenthal,1470 4 | Williamsburg,L. L. Buck,1600 5 | Queensborough,Palmer & Hornbostel,1182 6 | Triborough,O. H. Ammann,"1380,383" 7 | Bronx Whitestone,O. H. Ammann,2300 8 | Throgs Neck,O. H. Ammann,1800 9 | George Washington,O. H. Ammann,3500 10 | -------------------------------------------------------------------------------- /cache/README.md: -------------------------------------------------------------------------------- 1 | This is a cache for material that is too bulky to place in a 2 | docker image, and so gets downloaded on first use. 3 | -------------------------------------------------------------------------------- /companies.sqlite: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/paulfitz/mlsql/2f2f9cff35dce24580b06085072b44a8ce17fb54/companies.sqlite -------------------------------------------------------------------------------- /iris.csv: -------------------------------------------------------------------------------- 1 | "","Sepal.Length","Sepal.Width","Petal.Length","Petal.Width","Species" 2 | "1",5.1,3.5,1.4,0.2,"setosa" 3 | "2",4.9,3,1.4,0.2,"setosa" 4 | "3",4.7,3.2,1.3,0.2,"setosa" 5 | "4",4.6,3.1,1.5,0.2,"setosa" 6 | "5",5,3.6,1.4,0.2,"setosa" 7 | "6",5.4,3.9,1.7,0.4,"setosa" 8 | "7",4.6,3.4,1.4,0.3,"setosa" 9 | "8",5,3.4,1.5,0.2,"setosa" 10 | "9",4.4,2.9,1.4,0.2,"setosa" 11 | "10",4.9,3.1,1.5,0.1,"setosa" 12 | "11",5.4,3.7,1.5,0.2,"setosa" 13 | "12",4.8,3.4,1.6,0.2,"setosa" 14 | "13",4.8,3,1.4,0.1,"setosa" 15 | "14",4.3,3,1.1,0.1,"setosa" 16 | "15",5.8,4,1.2,0.2,"setosa" 17 | "16",5.7,4.4,1.5,0.4,"setosa" 18 | "17",5.4,3.9,1.3,0.4,"setosa" 19 | "18",5.1,3.5,1.4,0.3,"setosa" 20 | "19",5.7,3.8,1.7,0.3,"setosa" 21 | "20",5.1,3.8,1.5,0.3,"setosa" 22 | "21",5.4,3.4,1.7,0.2,"setosa" 23 | "22",5.1,3.7,1.5,0.4,"setosa" 24 | "23",4.6,3.6,1,0.2,"setosa" 25 | "24",5.1,3.3,1.7,0.5,"setosa" 26 | "25",4.8,3.4,1.9,0.2,"setosa" 27 | "26",5,3,1.6,0.2,"setosa" 28 | "27",5,3.4,1.6,0.4,"setosa" 29 | "28",5.2,3.5,1.5,0.2,"setosa" 30 | "29",5.2,3.4,1.4,0.2,"setosa" 31 | "30",4.7,3.2,1.6,0.2,"setosa" 32 | "31",4.8,3.1,1.6,0.2,"setosa" 33 | "32",5.4,3.4,1.5,0.4,"setosa" 34 | "33",5.2,4.1,1.5,0.1,"setosa" 35 | "34",5.5,4.2,1.4,0.2,"setosa" 36 | "35",4.9,3.1,1.5,0.2,"setosa" 37 | "36",5,3.2,1.2,0.2,"setosa" 38 | "37",5.5,3.5,1.3,0.2,"setosa" 39 | "38",4.9,3.6,1.4,0.1,"setosa" 40 | "39",4.4,3,1.3,0.2,"setosa" 41 | "40",5.1,3.4,1.5,0.2,"setosa" 42 | "41",5,3.5,1.3,0.3,"setosa" 43 | "42",4.5,2.3,1.3,0.3,"setosa" 44 | "43",4.4,3.2,1.3,0.2,"setosa" 45 | "44",5,3.5,1.6,0.6,"setosa" 46 | 
"45",5.1,3.8,1.9,0.4,"setosa" 47 | "46",4.8,3,1.4,0.3,"setosa" 48 | "47",5.1,3.8,1.6,0.2,"setosa" 49 | "48",4.6,3.2,1.4,0.2,"setosa" 50 | "49",5.3,3.7,1.5,0.2,"setosa" 51 | "50",5,3.3,1.4,0.2,"setosa" 52 | "51",7,3.2,4.7,1.4,"versicolor" 53 | "52",6.4,3.2,4.5,1.5,"versicolor" 54 | "53",6.9,3.1,4.9,1.5,"versicolor" 55 | "54",5.5,2.3,4,1.3,"versicolor" 56 | "55",6.5,2.8,4.6,1.5,"versicolor" 57 | "56",5.7,2.8,4.5,1.3,"versicolor" 58 | "57",6.3,3.3,4.7,1.6,"versicolor" 59 | "58",4.9,2.4,3.3,1,"versicolor" 60 | "59",6.6,2.9,4.6,1.3,"versicolor" 61 | "60",5.2,2.7,3.9,1.4,"versicolor" 62 | "61",5,2,3.5,1,"versicolor" 63 | "62",5.9,3,4.2,1.5,"versicolor" 64 | "63",6,2.2,4,1,"versicolor" 65 | "64",6.1,2.9,4.7,1.4,"versicolor" 66 | "65",5.6,2.9,3.6,1.3,"versicolor" 67 | "66",6.7,3.1,4.4,1.4,"versicolor" 68 | "67",5.6,3,4.5,1.5,"versicolor" 69 | "68",5.8,2.7,4.1,1,"versicolor" 70 | "69",6.2,2.2,4.5,1.5,"versicolor" 71 | "70",5.6,2.5,3.9,1.1,"versicolor" 72 | "71",5.9,3.2,4.8,1.8,"versicolor" 73 | "72",6.1,2.8,4,1.3,"versicolor" 74 | "73",6.3,2.5,4.9,1.5,"versicolor" 75 | "74",6.1,2.8,4.7,1.2,"versicolor" 76 | "75",6.4,2.9,4.3,1.3,"versicolor" 77 | "76",6.6,3,4.4,1.4,"versicolor" 78 | "77",6.8,2.8,4.8,1.4,"versicolor" 79 | "78",6.7,3,5,1.7,"versicolor" 80 | "79",6,2.9,4.5,1.5,"versicolor" 81 | "80",5.7,2.6,3.5,1,"versicolor" 82 | "81",5.5,2.4,3.8,1.1,"versicolor" 83 | "82",5.5,2.4,3.7,1,"versicolor" 84 | "83",5.8,2.7,3.9,1.2,"versicolor" 85 | "84",6,2.7,5.1,1.6,"versicolor" 86 | "85",5.4,3,4.5,1.5,"versicolor" 87 | "86",6,3.4,4.5,1.6,"versicolor" 88 | "87",6.7,3.1,4.7,1.5,"versicolor" 89 | "88",6.3,2.3,4.4,1.3,"versicolor" 90 | "89",5.6,3,4.1,1.3,"versicolor" 91 | "90",5.5,2.5,4,1.3,"versicolor" 92 | "91",5.5,2.6,4.4,1.2,"versicolor" 93 | "92",6.1,3,4.6,1.4,"versicolor" 94 | "93",5.8,2.6,4,1.2,"versicolor" 95 | "94",5,2.3,3.3,1,"versicolor" 96 | "95",5.6,2.7,4.2,1.3,"versicolor" 97 | "96",5.7,3,4.2,1.2,"versicolor" 98 | "97",5.7,2.9,4.2,1.3,"versicolor" 99 | "98",6.2,2.9,4.3,1.3,"versicolor" 100 | "99",5.1,2.5,3,1.1,"versicolor" 101 | "100",5.7,2.8,4.1,1.3,"versicolor" 102 | "101",6.3,3.3,6,2.5,"virginica" 103 | "102",5.8,2.7,5.1,1.9,"virginica" 104 | "103",7.1,3,5.9,2.1,"virginica" 105 | "104",6.3,2.9,5.6,1.8,"virginica" 106 | "105",6.5,3,5.8,2.2,"virginica" 107 | "106",7.6,3,6.6,2.1,"virginica" 108 | "107",4.9,2.5,4.5,1.7,"virginica" 109 | "108",7.3,2.9,6.3,1.8,"virginica" 110 | "109",6.7,2.5,5.8,1.8,"virginica" 111 | "110",7.2,3.6,6.1,2.5,"virginica" 112 | "111",6.5,3.2,5.1,2,"virginica" 113 | "112",6.4,2.7,5.3,1.9,"virginica" 114 | "113",6.8,3,5.5,2.1,"virginica" 115 | "114",5.7,2.5,5,2,"virginica" 116 | "115",5.8,2.8,5.1,2.4,"virginica" 117 | "116",6.4,3.2,5.3,2.3,"virginica" 118 | "117",6.5,3,5.5,1.8,"virginica" 119 | "118",7.7,3.8,6.7,2.2,"virginica" 120 | "119",7.7,2.6,6.9,2.3,"virginica" 121 | "120",6,2.2,5,1.5,"virginica" 122 | "121",6.9,3.2,5.7,2.3,"virginica" 123 | "122",5.6,2.8,4.9,2,"virginica" 124 | "123",7.7,2.8,6.7,2,"virginica" 125 | "124",6.3,2.7,4.9,1.8,"virginica" 126 | "125",6.7,3.3,5.7,2.1,"virginica" 127 | "126",7.2,3.2,6,1.8,"virginica" 128 | "127",6.2,2.8,4.8,1.8,"virginica" 129 | "128",6.1,3,4.9,1.8,"virginica" 130 | "129",6.4,2.8,5.6,2.1,"virginica" 131 | "130",7.2,3,5.8,1.6,"virginica" 132 | "131",7.4,2.8,6.1,1.9,"virginica" 133 | "132",7.9,3.8,6.4,2,"virginica" 134 | "133",6.4,2.8,5.6,2.2,"virginica" 135 | "134",6.3,2.8,5.1,1.5,"virginica" 136 | "135",6.1,2.6,5.6,1.4,"virginica" 137 | "136",7.7,3,6.1,2.3,"virginica" 138 | "137",6.3,3.4,5.6,2.4,"virginica" 139 | 
"138",6.4,3.1,5.5,1.8,"virginica" 140 | "139",6,3,4.8,1.8,"virginica" 141 | "140",6.9,3.1,5.4,2.1,"virginica" 142 | "141",6.7,3.1,5.6,2.4,"virginica" 143 | "142",6.9,3.1,5.1,2.3,"virginica" 144 | "143",5.8,2.7,5.1,1.9,"virginica" 145 | "144",6.8,3.2,5.9,2.3,"virginica" 146 | "145",6.7,3.3,5.7,2.5,"virginica" 147 | "146",6.7,3,5.2,2.3,"virginica" 148 | "147",6.3,2.5,5,1.9,"virginica" 149 | "148",6.5,3,5.2,2,"virginica" 150 | "149",6.2,3.4,5.4,2.3,"virginica" 151 | "150",5.9,3,5.1,1.8,"virginica" 152 | -------------------------------------------------------------------------------- /irnet/.dockerignore: -------------------------------------------------------------------------------- 1 | cache 2 | venv 3 | spider 4 | IRNet 5 | WikiSQL 6 | -------------------------------------------------------------------------------- /irnet/Dockerfile: -------------------------------------------------------------------------------- 1 | FROM ubuntu:18.04 2 | 3 | RUN \ 4 | apt update && \ 5 | apt install -y git wget && \ 6 | rm -rf /var/lib/apt/lists/* 7 | 8 | ADD install.sh install.sh 9 | 10 | RUN \ 11 | apt update && \ 12 | apt install -y virtualenv 13 | 14 | RUN \ 15 | bash ./install.sh 16 | 17 | # Oy, there's a stray mysqlclient dependency :( 18 | RUN \ 19 | apt update && \ 20 | apt-get install -y libmysqlclient-dev build-essential python3-dev libssl-dev 21 | 22 | RUN \ 23 | . venv/bin/activate && \ 24 | pip install -r IRNet/requirements.txt 25 | 26 | RUN \ 27 | apt update && \ 28 | apt install -y unzip 29 | 30 | ADD requirements.txt requirements.txt 31 | 32 | RUN \ 33 | . venv/bin/activate && \ 34 | pip install -r /requirements.txt 35 | 36 | ADD setup.sh setup.sh 37 | 38 | RUN \ 39 | bash ./setup.sh 40 | 41 | ADD serve.sh serve.sh 42 | 43 | ADD server server 44 | 45 | CMD ["./serve.sh"] 46 | 47 | EXPOSE 5050 48 | -------------------------------------------------------------------------------- /irnet/Makefile: -------------------------------------------------------------------------------- 1 | build: 2 | docker build -t irnet . 3 | 4 | run: 5 | docker stop /irnet || echo ok 6 | docker rm /irnet || echo ok 7 | mkdir -p cache 8 | docker run --name /irnet -v "$(PWD)/cache:/cache:delegated" -p 5050:5050 -dit irnet # /bin/bash 9 | 10 | exec: 11 | docker exec -it /irnet /bin/bash 12 | -------------------------------------------------------------------------------- /irnet/README.md: -------------------------------------------------------------------------------- 1 | This wraps up https://github.com/microsoft/IRNet for making predictions. 2 | Don't judge IRNet based on this, I haven't got a BERT-based model for it, 3 | haven't even set up a way to test multi-table queries, and have broken 4 | things in various ways. Watch this space though. 5 | 6 | Steps: 7 | 8 | ``` 9 | # Make sure sqlova container isn't running since this will use same port. 10 | # Make sure docker has access to >= 5GB of memory, and that you're not 11 | # short on disk space. 12 | 13 | make && make run 14 | 15 | curl -F "csv=@bridges.csv" -F "q=how long is throgs neck" localhost:5050 16 | 17 | # You'll have to be very patient! It takes a long time to load the model. 18 | # The first time you run this it will be especially long, since various 19 | # downloads and unpackings need to happen. It wasn't practical to stick 20 | # all that in the docker container, at least not just yet. 
21 | ``` 22 | -------------------------------------------------------------------------------- /irnet/install.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | set -e 4 | 5 | if [ ! -e venv ]; then 6 | virtualenv -ppython3 venv 7 | fi 8 | 9 | source venv/bin/activate 10 | 11 | if [ ! -e WikiSQL ]; then 12 | git clone https://github.com/salesforce/WikiSQL 13 | cd WikiSQL 14 | git checkout 7080c898e13d82395c85e2c2c1de3c914801f4d8 15 | cd .. 16 | fi 17 | 18 | if [ ! -e spider ]; then 19 | git clone https://github.com/taoyds/spider 20 | cd spider 21 | git checkout 0b0c9cad97e4deeef1bc37c8435950f4bdefc141 22 | cd .. 23 | fi 24 | 25 | if [ ! -e IRNet ]; then 26 | git clone https://github.com/microsoft/IRNet 27 | cd IRNet 28 | git checkout 72df5c876f368ae4a1b594e7a740ff966dbbd3ba 29 | cd .. 30 | fi 31 | 32 | if [ ! -e irnet_conceptNet.zip ]; then 33 | wget https://github.com/paulfitz/mlsql/releases/download/v0.1/irnet_conceptNet.zip 34 | fi 35 | -------------------------------------------------------------------------------- /irnet/requirements.txt: -------------------------------------------------------------------------------- 1 | Flask==1.0.3 2 | records==0.5.2 3 | requests==2.22.0 4 | -------------------------------------------------------------------------------- /irnet/serve.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | source venv/bin/activate 4 | PYTHONPATH=$PWD/IRNet:$PWD/server:$PYTHONPATH python server/prediction_server.py 5 | -------------------------------------------------------------------------------- /irnet/server/add_csv.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | 3 | # Add a CSV file as a table into .db and .tables.jsonl 4 | # Call as: 5 | # python add_csv.py 6 | # For a CSV file called data.csv, the table will be called table_data in the .db 7 | # file, and will be assigned the id 'data'. 8 | # All columns are treated as text - no attempt is made to sniff the type of value 9 | # stored in the column. 
10 | 11 | import argparse, csv, json, os 12 | import records 13 | 14 | def get_table_name(table_id): 15 | return '{}'.format(table_id) 16 | 17 | def csv_stream_to_sqlite(table_id, f, sqlite_file_name): 18 | db = records.Database('sqlite:///{}'.format(sqlite_file_name)) 19 | cf = csv.DictReader(f, delimiter=',') 20 | # columns = [f'col{i}' for i in range(len(cf.fieldnames))] 21 | columns = cf.fieldnames 22 | simple_name = dict(zip(cf.fieldnames, columns)) 23 | rows = [dict((simple_name[name], val) for name, val in row.items()) 24 | for row in cf] 25 | types = {} 26 | for name in columns: 27 | good_float = 0 28 | bad_float = 0 29 | good_int = 0 30 | bad_int = 0 31 | for row in rows: 32 | val = row[name] 33 | try: 34 | float(val) 35 | good_float += 1 36 | except: 37 | bad_float += 1 38 | try: 39 | int(val) 40 | good_int += 1 41 | except: 42 | bad_int += 1 43 | if good_int >= 2 * bad_int and good_int >= good_float: 44 | types[name] = 'integer' 45 | elif good_float >= 2 * bad_float and good_float > 0: 46 | types[name] = 'real' 47 | else: 48 | types[name] = 'text' 49 | schema = ', '.join([f'{name} {types[name]}' for name in columns]) 50 | tname = get_table_name(table_id) 51 | db.query(f'DROP TABLE IF EXISTS {tname}') 52 | db.query(f'CREATE TABLE {tname} ({schema})') 53 | ccolumns = [f':{name}' for name in columns] 54 | print(f'INSERT INTO {tname}({",".join(columns)}) VALUES({",".join(ccolumns)})') 55 | db.bulk_query(f'INSERT INTO {tname}({",".join(columns)}) VALUES({",".join(ccolumns)})', 56 | rows) 57 | return True 58 | 59 | def csv_to_sqlite(table_id, csv_file_name, sqlite_file_name): 60 | with open(csv_file_name) as f: 61 | return csv_stream_to_sqlite(table_id, f, sqlite_file_name) 62 | 63 | def csv_stream_to_json(table_id, f, json_file_name): 64 | cf = csv.DictReader(f, delimiter=',') 65 | record = {} 66 | record['header'] = [(name or 'col{}'.format(i)) for i, name in enumerate(cf.fieldnames)] 67 | record['page_title'] = None 68 | record['types'] = ['text'] * len(cf.fieldnames) 69 | record['id'] = table_id 70 | record['caption'] = None 71 | record['rows'] = [list(row.values()) for row in cf] 72 | record['name'] = get_table_name(table_id) 73 | with open(json_file_name, 'a+') as fout: 74 | json.dump(record, fout) 75 | fout.write('\n') 76 | return record 77 | 78 | def csv_to_json(table_id, csv_file_name, json_file_name): 79 | with open(csv_file_name) as f: 80 | csv_stream_to_json(table_id, f, json_file_name) 81 | 82 | if __name__ == '__main__': 83 | parser = argparse.ArgumentParser() 84 | parser.add_argument('split') 85 | parser.add_argument('file', metavar='file.csv') 86 | args = parser.parse_args() 87 | table_id = os.path.splitext(os.path.basename(args.file))[0] 88 | csv_to_sqlite(table_id, args.file, '{}.db'.format(args.split)) 89 | csv_to_json(table_id, args.file, '{}.tables.jsonl'.format(args.split)) 90 | print("Added table with id '{id}' (name '{name}') to {split}.db and {split}.tables.jsonl".format( 91 | id=table_id, name=get_table_name(table_id), split=args.split)) 92 | 93 | -------------------------------------------------------------------------------- /irnet/server/add_question.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | 3 | # Add a line of json representing a question into .jsonl 4 | # Call as: 5 | # python add_question.py 6 | # 7 | # This utility is not intended for use during training. A dummy label is added to the 8 | # question to make it loadable by existing code. 
9 | # 10 | # For example, suppose we downloaded this list of us state abbreviations: 11 | # https://vincentarelbundock.github.io/Rdatasets/csv/Ecdat/USstateAbbreviations.csv 12 | # Let's rename it as something short, say "abbrev.csv" 13 | # Now we can add it to a split called say "playground": 14 | # python add_csv.py playground abbrev.csv 15 | # And now we can add a question about it to the same split: 16 | # python add_question.py playground abbrev "what state has ansi digits of 11" 17 | # The next step would be to annotate the split: 18 | # python annotate_ws.py --din $PWD --dout $PWD --split playground 19 | # Then we're ready to run prediction on the split with predict.py 20 | 21 | import argparse, csv, json 22 | 23 | import json 24 | import nltk 25 | import sys 26 | 27 | nltk.download('punkt') 28 | 29 | def encode_question(db_id, question): 30 | question_toks = nltk.word_tokenize(question) 31 | result = [{ 32 | "db_id": db_id, 33 | "query": "SELECT count(*) FROM something", 34 | "question": question, 35 | "question_toks": question_toks, 36 | "sql": { 37 | "except": None, 38 | "from": { 39 | "conds": [], 40 | "table_units": [ 41 | [ 42 | "table_unit", 43 | 1 44 | ] 45 | ] 46 | }, 47 | "groupBy": [], 48 | "having": [], 49 | "intersect": None, 50 | "limit": None, 51 | "orderBy": [], 52 | "select": [ 53 | False, 54 | [ 55 | [ 56 | 1, 57 | [ 58 | 0, 59 | [ 60 | 0, 61 | 0, 62 | False 63 | ], 64 | None 65 | ] 66 | ] 67 | ] 68 | ], 69 | "union": None, 70 | "where": [ 71 | [ 72 | False, 73 | 1, 74 | [ 75 | 0, 76 | [ 77 | 0, 78 | 1, 79 | False 80 | ], 81 | None 82 | ], 83 | 1.0, 84 | None 85 | ] 86 | ] 87 | } 88 | }] 89 | return result 90 | 91 | def question_to_json(table_id, question, json_file_name): 92 | record = encode_question(table_id, question) 93 | with open(json_file_name, 'w') as fout: 94 | json.dump(record, fout) 95 | fout.write('\n') 96 | 97 | if __name__ == '__main__': 98 | parser = argparse.ArgumentParser() 99 | parser.add_argument('split') 100 | parser.add_argument('table_id') 101 | parser.add_argument('question', type=str, nargs='+') 102 | args = parser.parse_args() 103 | json_file_name = '{}.jsonl'.format(args.split) 104 | question_to_json(args.table_id, " ".join(args.question), json_file_name) 105 | print("Added question (with dummy label) to {}".format(json_file_name)) 106 | -------------------------------------------------------------------------------- /irnet/server/download.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | cd / 4 | glove="glove.42B.300d" 5 | mkdir -p IRNet/data 6 | if [ ! -e IRNet/data/$glove.txt ]; then 7 | if [ ! -e cache/$glove.zip ]; then 8 | cd cache 9 | wget https://nlp.stanford.edu/data/wordvecs/$glove.zip 10 | cd .. 11 | fi 12 | if [ ! -e cache/$glove.txt ]; then 13 | cd cache 14 | unzip $glove.zip 15 | cd .. 16 | fi 17 | ln -s $PWD/cache/$glove.txt IRNet/data/$glove.txt 18 | fi 19 | 20 | mkdir -p IRNet/saved_model 21 | if [ ! -e IRNet/saved_model/IRNet_pretrained.model ]; then 22 | if [ ! -e cache/IRNet_pretrained.model ]; then 23 | cd cache 24 | wget https://github.com/paulfitz/mlsql/releases/download/v0.1/IRNet_pretrained.model 25 | cd .. 
26 | fi 27 | ln -s $PWD/cache/IRNet_pretrained.model IRNet/saved_model/IRNet_pretrained.model 28 | fi 29 | -------------------------------------------------------------------------------- /irnet/server/prediction_server.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | 3 | # Turn this flag on to test just the server part of all this. 4 | TRIAL_RUN = False 5 | 6 | 7 | import sys 8 | sys.path.insert(0, '/IRNet') 9 | sys.path.insert(0, '/server') 10 | 11 | import argparse 12 | import json 13 | import os 14 | 15 | from flask import Flask, request 16 | from flask import jsonify 17 | import io 18 | import uuid 19 | import re 20 | 21 | import add_csv 22 | import add_question 23 | 24 | import torch 25 | from src import args as arg 26 | from src import utils 27 | from src.models.model import IRNet 28 | from src.rule import semQL 29 | 30 | import shutil 31 | import subprocess 32 | 33 | 34 | handle_request = None 35 | 36 | import threading 37 | thread = None 38 | status = "Loading irnet model, please wait" 39 | 40 | app = Flask(__name__) 41 | @app.route('/', methods=['POST']) 42 | def run(): 43 | if handle_request: 44 | return handle_request(request) 45 | else: 46 | return jsonify({"error": status}, 503) 47 | def start(): 48 | app.run(host='0.0.0.0', port=5050) 49 | thread = threading.Thread(target=start, args=()) 50 | thread.daemon = True 51 | thread.start() 52 | 53 | model = None 54 | args = None 55 | 56 | if not TRIAL_RUN: 57 | subprocess.run(["bash", "/server/download.sh"]) 58 | subprocess.run(["python", "/server/setup_nltk.py"]) 59 | sys.argv = ['zing', 60 | '--dataset', 'fake', 61 | '--glove_embed_path', '/cache/glove.42B.300d.txt', 62 | '--epoch', '50', 63 | '--beam_size', '5', 64 | '--seed', '90', 65 | '--save', '/tmp/save_name', 66 | '--embed_size', '300', 67 | '--sentence_features', 68 | '--column_pointer', 69 | '--hidden_size', '300', 70 | '--lr_scheduler', 71 | '--lr_scheduler_gammar', '0.5', 72 | '--att_vec_size', '300', 73 | '--batch_size', '1', 74 | '--load_model', '/cache/IRNet_pretrained.model'] 75 | arg_parser = arg.init_arg_parser() 76 | args = arg.init_config(arg_parser) 77 | print(args) 78 | grammar = semQL.Grammar() 79 | model = IRNet(args, grammar) 80 | if args.cuda: model.cuda() 81 | print('load pretrained model from %s'% (args.load_model)) 82 | pretrained_model = torch.load(args.load_model, 83 | map_location=lambda storage, loc: storage) 84 | import copy 85 | pretrained_modeled = copy.deepcopy(pretrained_model) 86 | for k in pretrained_model.keys(): 87 | if k not in model.state_dict().keys(): 88 | del pretrained_modeled[k] 89 | 90 | model.load_state_dict(pretrained_modeled) 91 | 92 | model.word_emb = utils.load_word_emb(args.glove_embed_path) 93 | print('loaded all models') 94 | 95 | 96 | def run_split(split): 97 | print(split) 98 | if not TRIAL_RUN: 99 | args.dataset = split 100 | sql_data, table_data, val_sql_data,\ 101 | val_table_data = utils.load_dataset(args.dataset, use_small=args.toy) 102 | json_datas, sketch_acc, acc = utils.epoch_acc(model, args.batch_size, val_sql_data, val_table_data, 103 | beam_size=args.beam_size) 104 | print('Sketch Acc: %f, Acc: %f' % (sketch_acc, acc)) 105 | with open(os.path.join(split, 'predict_lf.json'), 'w') as f: 106 | json.dump(json_datas, f) 107 | subprocess.run([ 108 | "python", 109 | "./sem2SQL.py", 110 | "--data_path", 111 | split, 112 | "--input_path", 113 | os.path.join(split, 'predict_lf.json'), 114 | "--output_path", 115 | os.path.join(split, 'output.txt') 116 | ], 
cwd="/IRNet") 117 | else: 118 | print("Trial run") 119 | with open(os.path.join(split, 'output.txt'), 'w') as f: 120 | f.write('trial run\n') 121 | with open(os.path.join(split, 'predict_lf.json'), 'w') as f: 122 | json.dump({'trial': 'run'}, f) 123 | results = {} 124 | with open(os.path.join(split, 'output.txt'), 'r') as f: 125 | results["sql"] = f.read().strip() 126 | with open(os.path.join(split, 'predict_lf.json'), 'r') as f: 127 | results["interpretation"] = json.load(f) 128 | message = { 129 | "split": split, 130 | "result": results 131 | } 132 | return message 133 | 134 | def serialize(o): 135 | if isinstance(o, int64): 136 | return int(o) 137 | 138 | def handle_request0(request): 139 | debug = 'debug' in request.form 140 | base = "" 141 | try: 142 | csv_key = 'csv' 143 | if csv_key not in request.files: 144 | csv_key = 'csv[]' 145 | print(request.files) 146 | if csv_key not in request.files and not 'sqlite' in request.files: 147 | raise Exception('please include a csv file or sqlite file') 148 | if not 'q' in request.form: 149 | raise Exception('please include a q parameter with a question in it') 150 | csvs = request.files.getlist(csv_key) 151 | sqlite_file = request.files.get('sqlite') 152 | q = request.form['q'] 153 | 154 | # brute force removal of any old requests 155 | if not TRIAL_RUN: 156 | subprocess.run([ 157 | "bash", 158 | "-c", 159 | "rm -rf /cache/case_*" 160 | ]) 161 | key = "case_" + str(uuid.uuid4()) 162 | data_dir = os.path.join('/cache', key) 163 | os.makedirs(os.path.join(data_dir, 'data'), exist_ok=True) 164 | print("Key", key) 165 | for csv in csvs: 166 | print("Working on", csv) 167 | table_id = os.path.splitext(csv.filename)[0] 168 | table_id = re.sub(r'\W+', '_', table_id) 169 | stream = io.StringIO(csv.stream.read().decode("UTF8"), newline=None) 170 | add_csv.csv_stream_to_sqlite(table_id, stream, os.path.join(data_dir, 'data', 171 | 'data.sqlite')) 172 | stream.seek(0) 173 | if sqlite_file: 174 | print("Working on", sqlite_file) 175 | sqlite_file.save(os.path.join(data_dir, 'data', 'data.sqlite')) 176 | question_file = os.path.join(data_dir, 'question.json') 177 | tables_file = os.path.join(data_dir, 'tables.json') 178 | dummy_file = os.path.join(data_dir, 'dummy.json') 179 | add_question.question_to_json('data', q, question_file) 180 | with open(dummy_file, 'w') as fout: 181 | fout.write('[]\n') 182 | 183 | if not TRIAL_RUN: 184 | subprocess.run([ 185 | "python", 186 | "/spider/preprocess/get_tables.py", 187 | data_dir, 188 | tables_file, 189 | dummy_file 190 | ]) 191 | subprocess.run([ 192 | "bash", 193 | "./run_me.sh", 194 | question_file, 195 | tables_file, 196 | os.path.join(data_dir, 'dummy2.json') 197 | ], cwd="/IRNet/preprocess") 198 | shutil.copyfile(question_file, os.path.join(data_dir, 'dev.json')) 199 | shutil.copyfile(question_file, os.path.join(data_dir, 'train.json')) 200 | message = run_split(data_dir) 201 | code = 200 202 | except Exception as e: 203 | message = { "error": str(e) } 204 | code = 500 205 | if debug: 206 | message['base'] = base 207 | return jsonify(message), code 208 | 209 | handle_request = handle_request0 210 | thread.join() 211 | -------------------------------------------------------------------------------- /irnet/server/requirements.txt: -------------------------------------------------------------------------------- 1 | Flask==1.0.3 2 | requests==2.22.0 3 | -------------------------------------------------------------------------------- /irnet/server/setup_nltk.py: 
-------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | import nltk 3 | nltk.download('averaged_perceptron_tagger') 4 | nltk.download('punkt') 5 | nltk.download('wordnet') 6 | -------------------------------------------------------------------------------- /irnet/setup.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | source venv/bin/activate 4 | pip install nltk 5 | 6 | sed -i "s/ / /g" spider/preprocess/get_tables.py 7 | sed -i "s/print \(.*\)/print(\1)/" spider/preprocess/get_tables.py 8 | sed -i "s/prev_/#prev_/" spider/preprocess/get_tables.py 9 | sed -i "s/cur_/#cur_/" spider/preprocess/get_tables.py 10 | sed -i "s/if df in ex_tabs.*/if False:/" spider/preprocess/get_tables.py 11 | 12 | sed -i "s/open(file_name)/open(file_name, 'r', encoding='utf=8')/" IRNet/src/utils.py 13 | sed -i "s/(total)/(max(total,0.00001))/g" IRNet/src/utils.py 14 | sed -i "s/^ continue/ print('continue')/" IRNet/src/utils.py 15 | 16 | if [ ! -e IRNet/preprocess/conceptNet ]; then 17 | unzip irnet_conceptNet.zip 18 | ln -s $PWD/conceptNet IRNet/preprocess/conceptNet 19 | cd IRNet/preprocess 20 | sed -i 's/\r//g' run_me.sh 21 | cd ../.. 22 | fi 23 | 24 | mkdir -p cache 25 | -------------------------------------------------------------------------------- /players.csv: -------------------------------------------------------------------------------- 1 | Player,No.,Nationality,Position,Years in Toronto,School/Club Team 2 | Antonio Lang,21,United States,Guard-Forward,1999-2000,Duke 3 | Voshon Lenard,2,United States,Guard,2002-03,Minnesota 4 | Martin Lewis,"32, 44",United States,Guard-Forward,1996-97,Butler CC (KS) 5 | Brad Lohaus,33,United States,Forward-Center,1996,Iowa 6 | Art Long,42,United States,Forward-Center,2002-03,Cincinnati 7 | John Long,25,United States,Guard,1996-97,Detroit 8 | Kyle Lowry,3,United States,Guard,2012-present,Villanova 9 | -------------------------------------------------------------------------------- /sqlova/.dockerignore: -------------------------------------------------------------------------------- 1 | sqlova 2 | venv 3 | -------------------------------------------------------------------------------- /sqlova/Dockerfile: -------------------------------------------------------------------------------- 1 | # latest is not actually the most recent, currently 2 | 3 | FROM ubuntu:18.04 4 | 5 | # Install Java 6 | RUN \ 7 | apt update && \ 8 | apt install -y unzip openjdk-8-jre-headless git wget && \ 9 | apt install -y python3 python3-pip python3-dev && \ 10 | rm -rf /var/lib/apt/lists/* 11 | 12 | ENV VERSION stanford-corenlp-full-2016-10-31 13 | 14 | RUN \ 15 | mkdir -p /opt/corenlp && \ 16 | cd /opt/corenlp && \ 17 | wget --quiet http://nlp.stanford.edu/software/$VERSION.zip -O corenlp.zip && \ 18 | unzip corenlp.zip && \ 19 | mv $VERSION src && \ 20 | rm -r corenlp.zip && \ 21 | rm -rf /var/lib/apt/lists/* 22 | 23 | 24 | RUN \ 25 | cd /opt && \ 26 | echo v2 && \ 27 | git clone https://github.com/paulfitz/sqlova/ -b prediction_api 28 | 29 | WORKDIR /opt/sqlova 30 | 31 | add support support 32 | add pretrained pretrained 33 | 34 | add run_services.sh run_services.sh 35 | 36 | RUN \ 37 | pip3 install -r requirements.txt 38 | 39 | CMD ["./run_services.sh"] 40 | 41 | EXPOSE 5050 42 | -------------------------------------------------------------------------------- /sqlova/Makefile: -------------------------------------------------------------------------------- 1 | start: 2 | 
docker run --name sqlova -d -p "5050:5050" paulfitz/sqlova 3 | 4 | build: 5 | ./fetch_models.sh 6 | docker build -t paulfitz/sqlova . 7 | 8 | stop: 9 | docker stop sqlova || echo ok 10 | docker kill sqlova || echo ok 11 | docker rm sqlova 12 | -------------------------------------------------------------------------------- /sqlova/README.md: -------------------------------------------------------------------------------- 1 | This wraps up https://github.com/naver/sqlova/ for making predictions, 2 | just to get a sense of when it works and when it fails. 3 | 4 | Steps: 5 | 6 | ``` 7 | make 8 | # Get a csv file, e.g. bridges.csv 9 | curl -F "csv=@bridges.csv" -F "q=how long is throgs neck" localhost:5050 10 | ``` 11 | 12 | If model crashes without appearing to give output, check docker memory limit is high enough. 13 | (Docker->Preferences->Advanced or equivalent). 3GB works! 14 | -------------------------------------------------------------------------------- /sqlova/fetch_models.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | set -e 4 | 5 | function fetch_model() { 6 | target="$1" 7 | url="$2" 8 | if [ ! -e "$target" ]; then 9 | which wget && { 10 | wget -O "$target" "$url" 11 | } || { 12 | which curl && { 13 | curl -L "$url" > "$target" 14 | } || { 15 | echo "Please fetch $url" 16 | echo "And place it in $target" 17 | exit 1 18 | } 19 | } 20 | echo "Downloaded $1" 21 | else 22 | echo "Have $1 already, not downloading" 23 | fi 24 | } 25 | 26 | mkdir -p pretrained 27 | 28 | fetch_model pretrained/model_bert_best.pt \ 29 | https://github.com/naver/sqlova/releases/download/SQLova-parameters/model_bert_best.pt 30 | 31 | fetch_model pretrained/model_best.pt \ 32 | https://github.com/naver/sqlova/releases/download/SQLova-parameters/model_best.pt 33 | -------------------------------------------------------------------------------- /sqlova/run_services.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | set -o errexit 4 | set -o errtrace 5 | set -o pipefail 6 | set -e 7 | 8 | trap 'cleanup' EXIT 9 | trap 'echo "Exiting on SIGINT"; exit 1' INT 10 | trap 'echo "Exiting on SIGTERM"; exit 1' TERM 11 | 12 | PIDS=() 13 | 14 | cleanup() { 15 | kill "${PIDS[@]}" 16 | wait "${PIDS[@]}" 17 | exit 0 18 | } 19 | 20 | cd /opt/corenlp/src 21 | java -mx4g -cp "*" edu.stanford.nlp.pipeline.StanfordCoreNLPServer & 22 | PIDS+=($!) 23 | 24 | # corenlp should be up before we start serving sqlova, or we could 25 | # fail to annotate a request. 26 | while ! wget http://localhost:9000 < /dev/null > /dev/null 2>&1; do 27 | echo Waiting for nlp server 28 | sleep 1 29 | done 30 | 31 | cd /opt/sqlova 32 | python3 ./predict.py \ 33 | --bert_type_abb uL \ 34 | --model_file pretrained/model_best.pt \ 35 | --bert_model_file pretrained/model_bert_best.pt \ 36 | --bert_path support \ 37 | --data_path /opt/sqlova \ 38 | --result_path /opt/sqlova & 39 | PIDS+=($!) 
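# Keep the container's main process alive until both background services
# (the CoreNLP server and the sqlova prediction server) have exited.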
40 | 41 | wait "${PIDS[@]}" 42 | -------------------------------------------------------------------------------- /sqlova/support/bert_config_uncased_L-12_H-768_A-12.json: -------------------------------------------------------------------------------- 1 | { 2 | "attention_probs_dropout_prob": 0.1, 3 | "hidden_act": "gelu", 4 | "hidden_dropout_prob": 0.1, 5 | "hidden_size": 768, 6 | "initializer_range": 0.02, 7 | "intermediate_size": 3072, 8 | "max_position_embeddings": 512, 9 | "num_attention_heads": 12, 10 | "num_hidden_layers": 12, 11 | "type_vocab_size": 2, 12 | "vocab_size": 30522 13 | } 14 | -------------------------------------------------------------------------------- /sqlova/support/bert_config_uncased_L-24_H-1024_A-16.json: -------------------------------------------------------------------------------- 1 | { 2 | "attention_probs_dropout_prob": 0.1, 3 | "hidden_act": "gelu", 4 | "hidden_dropout_prob": 0.1, 5 | "hidden_size": 1024, 6 | "initializer_range": 0.02, 7 | "intermediate_size": 4096, 8 | "max_position_embeddings": 512, 9 | "num_attention_heads": 16, 10 | "num_hidden_layers": 24, 11 | "type_vocab_size": 2, 12 | "vocab_size": 30522 13 | } 14 | -------------------------------------------------------------------------------- /test.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | function run() { 4 | echo " " 5 | echo "==================================================" 6 | echo "$2" 7 | curl -F "csv=@$1.csv" -F "q=$2" localhost:5050 8 | } 9 | 10 | run bridges "how long is throgs neck" 11 | run bridges "who designed the george washington" 12 | run bridges "how many bridges are there" 13 | run bridges "how many bridges are designed by O. H. Ammann" 14 | run bridges "which bridges are longer than 2000" 15 | run bridges "how many bridges are longer than 2000" 16 | run bridges "what is the shortest length" 17 | 18 | run players "Who is the player that wears number 42?" 19 | run players "What number did the person playing for Duke wear?" 20 | run players "What year did Brad Lohaus play?" 21 | run players "What country is Voshon Lenard from?" 22 | 23 | run iris "how many setosa rows are there" 24 | run iris "what is the average petal width for virginica" 25 | run iris "what is the longest sepal for versicolor" 26 | -------------------------------------------------------------------------------- /valuenet/Dockerfile: -------------------------------------------------------------------------------- 1 | FROM ubuntu:18.04 2 | 3 | RUN \ 4 | apt update && \ 5 | apt install -y git wget && \ 6 | rm -rf /var/lib/apt/lists/* 7 | 8 | ADD install.sh install.sh 9 | 10 | RUN \ 11 | apt update && \ 12 | apt install -y virtualenv 13 | 14 | RUN \ 15 | bash ./install.sh 16 | 17 | # Oy, there's a stray mysqlclient dependency :( 18 | RUN \ 19 | apt update && \ 20 | apt-get install -y libmysqlclient-dev build-essential python3-dev libssl-dev 21 | 22 | RUN \ 23 | . venv/bin/activate && \ 24 | pip install -r IRNet/requirements.txt 25 | 26 | RUN \ 27 | apt update && \ 28 | apt install -y unzip 29 | 30 | ADD requirements.txt requirements.txt 31 | 32 | RUN \ 33 | . 
venv/bin/activate && \ 34 | pip install -r /requirements.txt 35 | 36 | ADD setup.sh setup.sh 37 | 38 | RUN \ 39 | bash ./setup.sh 40 | 41 | ADD serve.sh serve.sh 42 | 43 | ADD server server 44 | 45 | ADD example_data /valuenet/data/paulfitz 46 | 47 | RUN \ 48 | AUTOEXIT=1 bash ./serve.sh 49 | 50 | CMD ["./serve.sh"] 51 | 52 | EXPOSE 5050 53 | -------------------------------------------------------------------------------- /valuenet/Makefile: -------------------------------------------------------------------------------- 1 | build: 2 | docker build -t valuenet . 3 | 4 | run: 5 | docker stop /valuenet || echo ok 6 | docker rm /valuenet || echo ok 7 | docker run --name /valuenet -v "$(PWD)/../cache:/cache:delegated" -p 5050:5050 -dit valuenet 8 | 9 | exec: 10 | docker exec -it /valuenet /bin/bash 11 | -------------------------------------------------------------------------------- /valuenet/example_data/original/database/data/data.sqlite: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/paulfitz/mlsql/2f2f9cff35dce24580b06085072b44a8ce17fb54/valuenet/example_data/original/database/data/data.sqlite -------------------------------------------------------------------------------- /valuenet/example_data/original/tables.json: -------------------------------------------------------------------------------- 1 | [ 2 | { 3 | "column_names": [ 4 | [ 5 | -1, 6 | "*" 7 | ], 8 | [ 9 | 0, 10 | "id" 11 | ], 12 | [ 13 | 0, 14 | "city" 15 | ], 16 | [ 17 | 0, 18 | "country" 19 | ], 20 | [ 21 | 1, 22 | "id" 23 | ], 24 | [ 25 | 1, 26 | "company" 27 | ], 28 | [ 29 | 1, 30 | "location id" 31 | ], 32 | [ 33 | 1, 34 | "ceo id" 35 | ], 36 | [ 37 | 2, 38 | "id" 39 | ], 40 | [ 41 | 2, 42 | "name" 43 | ], 44 | [ 45 | 2, 46 | "birth year" 47 | ] 48 | ], 49 | "column_names_original": [ 50 | [ 51 | -1, 52 | "*" 53 | ], 54 | [ 55 | 0, 56 | "id" 57 | ], 58 | [ 59 | 0, 60 | "city" 61 | ], 62 | [ 63 | 0, 64 | "country" 65 | ], 66 | [ 67 | 1, 68 | "id" 69 | ], 70 | [ 71 | 1, 72 | "company" 73 | ], 74 | [ 75 | 1, 76 | "location_id" 77 | ], 78 | [ 79 | 1, 80 | "ceo_id" 81 | ], 82 | [ 83 | 2, 84 | "id" 85 | ], 86 | [ 87 | 2, 88 | "name" 89 | ], 90 | [ 91 | 2, 92 | "birth_year" 93 | ] 94 | ], 95 | "column_types": [ 96 | "text", 97 | "number", 98 | "text", 99 | "text", 100 | "number", 101 | "text", 102 | "text", 103 | "text", 104 | "number", 105 | "text", 106 | "text" 107 | ], 108 | "db_id": "data", 109 | "foreign_keys": [ 110 | [ 111 | 7, 112 | 8 113 | ], 114 | [ 115 | 6, 116 | 1 117 | ] 118 | ], 119 | "primary_keys": [ 120 | 1, 121 | 4, 122 | 8 123 | ], 124 | "table_names": [ 125 | "locations", 126 | "organizations", 127 | "people" 128 | ], 129 | "table_names_original": [ 130 | "locations", 131 | "organizations", 132 | "people" 133 | ] 134 | } 135 | ] -------------------------------------------------------------------------------- /valuenet/install.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | set -e 4 | 5 | if [ ! -e venv ]; then 6 | virtualenv -ppython3 venv 7 | fi 8 | 9 | source venv/bin/activate 10 | 11 | if [ ! -e WikiSQL ]; then 12 | git clone https://github.com/salesforce/WikiSQL 13 | cd WikiSQL 14 | git checkout 7080c898e13d82395c85e2c2c1de3c914801f4d8 15 | cd .. 16 | fi 17 | 18 | if [ ! -e spider ]; then 19 | git clone https://github.com/taoyds/spider 20 | cd spider 21 | git checkout 0b0c9cad97e4deeef1bc37c8435950f4bdefc141 22 | cd .. 23 | fi 24 | 25 | if [ ! 
-e IRNet ]; then 26 | git clone https://github.com/microsoft/IRNet 27 | cd IRNet 28 | git checkout 72df5c876f368ae4a1b594e7a740ff966dbbd3ba 29 | cd .. 30 | fi 31 | 32 | if [ ! -e irnet_conceptNet.zip ]; then 33 | wget https://github.com/paulfitz/mlsql/releases/download/v0.1/irnet_conceptNet.zip 34 | fi 35 | -------------------------------------------------------------------------------- /valuenet/requirements.txt: -------------------------------------------------------------------------------- 1 | transformers==2.8.0 2 | pytictoc==1.5.0 3 | textdistance[extras]==4.2.0 4 | wandb==0.8.35 5 | spacy==2.2.4 6 | termcolor==1.1.0 7 | 8 | Flask==1.0.3 9 | records==0.5.2 10 | requests==2.22.0 11 | -------------------------------------------------------------------------------- /valuenet/serve.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | cp -r /maintain/example_data/ /valuenet/data/paulfitz 4 | cd / 5 | source venv/bin/activate 6 | 7 | if [ ! -e valuenet_pretrained.pt ]; then 8 | wget https://github.com/paulfitz/mlsql/releases/download/v0.1/valuenet_pretrained.pt 9 | fi 10 | 11 | ( 12 | echo "import nltk" 13 | echo "nltk.download('averaged_perceptron_tagger')" 14 | ) | python 15 | 16 | cd /valuenet 17 | PYTHONPATH=/valuenet/src python /server/prediction_server.py 18 | -------------------------------------------------------------------------------- /valuenet/server/add_csv.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | 3 | # Add a CSV file as a table into .db and .tables.jsonl 4 | # Call as: 5 | # python add_csv.py 6 | # For a CSV file called data.csv, the table will be called table_data in the .db 7 | # file, and will be assigned the id 'data'. 8 | # All columns are treated as text - no attempt is made to sniff the type of value 9 | # stored in the column. 
10 | 11 | import argparse, csv, json, os 12 | import records 13 | 14 | def get_table_name(table_id): 15 | return '{}'.format(table_id) 16 | 17 | def csv_stream_to_sqlite(table_id, f, sqlite_file_name): 18 | db = records.Database('sqlite:///{}'.format(sqlite_file_name)) 19 | cf = csv.DictReader(f, delimiter=',') 20 | # columns = [f'col{i}' for i in range(len(cf.fieldnames))] 21 | columns = cf.fieldnames 22 | simple_name = dict(zip(cf.fieldnames, columns)) 23 | rows = [dict((simple_name[name], val) for name, val in row.items()) 24 | for row in cf] 25 | types = {} 26 | for name in columns: 27 | good_float = 0 28 | bad_float = 0 29 | good_int = 0 30 | bad_int = 0 31 | for row in rows: 32 | val = row[name] 33 | try: 34 | float(val) 35 | good_float += 1 36 | except: 37 | bad_float += 1 38 | try: 39 | int(val) 40 | good_int += 1 41 | except: 42 | bad_int += 1 43 | if good_int >= 2 * bad_int and good_int >= good_float: 44 | types[name] = 'integer' 45 | elif good_float >= 2 * bad_float and good_float > 0: 46 | types[name] = 'real' 47 | else: 48 | types[name] = 'text' 49 | schema = ', '.join([f'{name} {types[name]}' for name in columns]) 50 | tname = get_table_name(table_id) 51 | db.query(f'DROP TABLE IF EXISTS {tname}') 52 | db.query(f'CREATE TABLE {tname} ({schema})') 53 | ccolumns = [f':{name}' for name in columns] 54 | print(f'INSERT INTO {tname}({",".join(columns)}) VALUES({",".join(ccolumns)})') 55 | db.bulk_query(f'INSERT INTO {tname}({",".join(columns)}) VALUES({",".join(ccolumns)})', 56 | rows) 57 | return True 58 | 59 | def csv_to_sqlite(table_id, csv_file_name, sqlite_file_name): 60 | with open(csv_file_name) as f: 61 | return csv_stream_to_sqlite(table_id, f, sqlite_file_name) 62 | 63 | def csv_stream_to_json(table_id, f, json_file_name): 64 | cf = csv.DictReader(f, delimiter=',') 65 | record = {} 66 | record['header'] = [(name or 'col{}'.format(i)) for i, name in enumerate(cf.fieldnames)] 67 | record['page_title'] = None 68 | record['types'] = ['text'] * len(cf.fieldnames) 69 | record['id'] = table_id 70 | record['caption'] = None 71 | record['rows'] = [list(row.values()) for row in cf] 72 | record['name'] = get_table_name(table_id) 73 | with open(json_file_name, 'a+') as fout: 74 | json.dump(record, fout) 75 | fout.write('\n') 76 | return record 77 | 78 | def csv_to_json(table_id, csv_file_name, json_file_name): 79 | with open(csv_file_name) as f: 80 | csv_stream_to_json(table_id, f, json_file_name) 81 | 82 | if __name__ == '__main__': 83 | parser = argparse.ArgumentParser() 84 | parser.add_argument('split') 85 | parser.add_argument('file', metavar='file.csv') 86 | args = parser.parse_args() 87 | table_id = os.path.splitext(os.path.basename(args.file))[0] 88 | csv_to_sqlite(table_id, args.file, '{}.db'.format(args.split)) 89 | csv_to_json(table_id, args.file, '{}.tables.jsonl'.format(args.split)) 90 | print("Added table with id '{id}' (name '{name}') to {split}.db and {split}.tables.jsonl".format( 91 | id=table_id, name=get_table_name(table_id), split=args.split)) 92 | 93 | -------------------------------------------------------------------------------- /valuenet/server/add_question.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | 3 | # Add a line of json representing a question into .jsonl 4 | # Call as: 5 | # python add_question.py
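#                         <split> <table_id> <question words...>
# (argument names as in the argparse setup at the bottom of this file)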
6 | # 7 | # This utility is not intended for use during training. A dummy label is added to the 8 | # question to make it loadable by existing code. 9 | # 10 | # For example, suppose we downloaded this list of us state abbreviations: 11 | # https://vincentarelbundock.github.io/Rdatasets/csv/Ecdat/USstateAbbreviations.csv 12 | # Let's rename it as something short, say "abbrev.csv" 13 | # Now we can add it to a split called say "playground": 14 | # python add_csv.py playground abbrev.csv 15 | # And now we can add a question about it to the same split: 16 | # python add_question.py playground abbrev "what state has ansi digits of 11" 17 | # The next step would be to annotate the split: 18 | # python annotate_ws.py --din $PWD --dout $PWD --split playground 19 | # Then we're ready to run prediction on the split with predict.py 20 | 21 | import argparse, csv, json 22 | 23 | import json 24 | import nltk 25 | import sys 26 | 27 | nltk.download('punkt') 28 | 29 | def encode_question(db_id, question): 30 | question_toks = nltk.word_tokenize(question) 31 | result = [{ 32 | "db_id": db_id, 33 | "query": "SELECT count(*) FROM something", 34 | "question": question, 35 | "question_toks": question_toks, 36 | "sql": { 37 | "except": None, 38 | "from": { 39 | "conds": [], 40 | "table_units": [ 41 | [ 42 | "table_unit", 43 | 1 44 | ] 45 | ] 46 | }, 47 | "groupBy": [], 48 | "having": [], 49 | "intersect": None, 50 | "limit": None, 51 | "orderBy": [], 52 | "select": [ 53 | False, 54 | [ 55 | [ 56 | 1, 57 | [ 58 | 0, 59 | [ 60 | 0, 61 | 0, 62 | False 63 | ], 64 | None 65 | ] 66 | ] 67 | ] 68 | ], 69 | "union": None, 70 | "where": [ 71 | [ 72 | False, 73 | 1, 74 | [ 75 | 0, 76 | [ 77 | 0, 78 | 1, 79 | False 80 | ], 81 | None 82 | ], 83 | 1.0, 84 | None 85 | ] 86 | ] 87 | } 88 | }] 89 | return result 90 | 91 | def question_to_json(table_id, question, json_file_name): 92 | record = encode_question(table_id, question) 93 | with open(json_file_name, 'w') as fout: 94 | json.dump(record, fout) 95 | fout.write('\n') 96 | 97 | if __name__ == '__main__': 98 | parser = argparse.ArgumentParser() 99 | parser.add_argument('split') 100 | parser.add_argument('table_id') 101 | parser.add_argument('question', type=str, nargs='+') 102 | args = parser.parse_args() 103 | json_file_name = '{}.jsonl'.format(args.split) 104 | question_to_json(args.table_id, " ".join(args.question), json_file_name) 105 | print("Added question (with dummy label) to {}".format(json_file_name)) 106 | -------------------------------------------------------------------------------- /valuenet/server/prediction_server.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | 3 | # Turn this flag on to test just the server part of all this. 
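# (when True, the heavy model load below is skipped and requests return canned 'trial run' output)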
4 | TRIAL_RUN = False 5 | 6 | 7 | import sys 8 | sys.path.insert(0, '/IRNet') 9 | sys.path.insert(0, '/server') 10 | 11 | import argparse 12 | import json 13 | import os 14 | 15 | from flask import Flask, request 16 | from flask import jsonify 17 | import io 18 | import uuid 19 | import re 20 | 21 | import add_csv 22 | import add_question 23 | 24 | import torch 25 | from src import args as arg 26 | from src import utils 27 | from src.rule import semQL 28 | 29 | import shutil 30 | import subprocess 31 | 32 | 33 | 34 | 35 | ################################################################################### 36 | # Manual inference material 37 | ################################################################################### 38 | 39 | import os 40 | import pickle 41 | import sqlite3 42 | from pprint import pprint 43 | 44 | import torch 45 | 46 | from config import read_arguments_manual_inference 47 | from intermediate_representation import semQL 48 | from intermediate_representation.sem2sql.sem2SQL import transform 49 | from intermediate_representation.sem_utils import alter_column0 50 | from model.model import IRNet 51 | from named_entity_recognition.api_ner.google_api_repository import remote_named_entity_recognition 52 | from named_entity_recognition.pre_process_ner_values import pre_process, match_values_in_database 53 | from preprocessing.process_data import process_datas 54 | from preprocessing.utils import merge_data_with_schema 55 | from spider import spider_utils 56 | from spider.example_builder import build_example 57 | from utils import setup_device, set_seed_everywhere 58 | 59 | from spacy.lang.en import English 60 | 61 | from termcolor import colored 62 | 63 | import json 64 | import spacy 65 | from spacy import displacy 66 | from collections import Counter 67 | import en_core_web_sm 68 | nlp = en_core_web_sm.load() 69 | 70 | def _inference_semql(data_row, schemas, model): 71 | example = build_example(data_row, schemas) 72 | 73 | with torch.no_grad(): 74 | results_all = model.parse(example, beam_size=1) 75 | results = results_all[0] 76 | # here we assemble the predicted actions (including leaf nodes) as a string 77 | full_prediction = " ".join([str(x) for x in results[0].actions]) 78 | 79 | prediction = example.sql_json['pre_sql'] 80 | prediction['model_result'] = full_prediction 81 | 82 | return prediction, example 83 | 84 | 85 | def _tokenize_question(tokenizer, question): 86 | # Create a Tokenizer with the default settings for English 87 | # including punctuation rules and exceptions 88 | 89 | question_tokenized = tokenizer(question) 90 | 91 | return [str(token) for token in question_tokenized] 92 | 93 | 94 | def _get_entities(row): # remote NER variant (unused below); fills row['ner_extracted_values'] in place 95 | ner_results = remote_named_entity_recognition(row['question']) 96 | row['ner_extracted_values'] = ner_results['entities'] 97 | 98 | def _get_entities_local(question): 99 | doc = nlp(question) 100 | return [{'type': 'spacy_' + ent.label_, 'name': ent.text} for ent in doc.ents] 101 | 102 | def _pre_process_values(row): 103 | row['ner_extracted_values'] = _get_entities_local(row['question']) 104 | 105 | extracted_values = pre_process(row) 106 | 107 | row['values'] = match_values_in_database(row['db_id'], extracted_values) 108 | 109 | return row 110 | 111 | 112 | def _semql_to_sql(prediction, schemas): 113 | alter_column0([prediction]) 114 | result = transform(prediction, schemas[prediction['db_id']]) 115 | return result[0] 116 | 117 | 118 | def _execute_query(sql, db): 119 | conn = sqlite3.connect(db) 120 | cursor = conn.cursor() 121 |
122 | cursor.execute(sql) 123 | result = cursor.fetchall() 124 | 125 | conn.close() 126 | 127 | return result 128 | 129 | ################################################################################### 130 | # Manual inference material ends 131 | ################################################################################### 132 | 133 | 134 | 135 | 136 | 137 | 138 | handle_request = None 139 | 140 | import threading 141 | thread = None 142 | status = "Loading valuenet models, please wait" 143 | 144 | app = Flask(__name__) 145 | @app.route('/', methods=['POST']) 146 | def run(): 147 | if handle_request: 148 | return handle_request(request) 149 | else: 150 | return jsonify({"error": status}), 503 151 | def start(): 152 | app.run(host='0.0.0.0', port=5050) 153 | 154 | thread = None 155 | if not(os.environ.get('AUTOEXIT')): 156 | thread = threading.Thread(target=start, args=()) 157 | thread.daemon = True 158 | thread.start() 159 | 160 | model = None 161 | args = None 162 | 163 | if not TRIAL_RUN: 164 | 165 | sys.argv = ['zing', 166 | '--model_to_load', '/valuenet_pretrained.pt', 167 | '--database', 'data', 168 | '--data_set', 'paulfitz', 169 | '--conceptNet', '/conceptNet'] 170 | args = read_arguments_manual_inference() 171 | print(args) 172 | 173 | device, n_gpu = setup_device() 174 | set_seed_everywhere(args.seed, n_gpu) 175 | 176 | grammar = semQL.Grammar() 177 | model = IRNet(args, device, grammar) 178 | model.to(device) 179 | 180 | # load the pre-trained parameters 181 | model.load_state_dict(torch.load(args.model_to_load, map_location=torch.device('cpu'))) 182 | model.eval() 183 | print("Loaded pre-trained model from '{}'".format(args.model_to_load)) 184 | 185 | nlp = English() 186 | tokenizer = nlp.Defaults.create_tokenizer(nlp) 187 | 188 | with open(os.path.join(args.conceptNet, 'english_RelatedTo.pkl'), 'rb') as f: 189 | related_to_concept = pickle.load(f) 190 | 191 | with open(os.path.join(args.conceptNet, 'english_IsA.pkl'), 'rb') as f: 192 | is_a_concept = pickle.load(f) 193 | print('loaded all models') 194 | 195 | 196 | def run_split(split): 197 | print(split) 198 | if not TRIAL_RUN: 199 | args.dataset = split 200 | sql_data, table_data, val_sql_data,\ 201 | val_table_data = utils.load_dataset(args.dataset, use_small=args.toy) 202 | json_datas, sketch_acc, acc = utils.epoch_acc(model, args.batch_size, val_sql_data, val_table_data, 203 | beam_size=args.beam_size) 204 | print('Sketch Acc: %f, Acc: %f' % (sketch_acc, acc)) 205 | with open(os.path.join(split, 'predict_lf.json'), 'w') as f: 206 | json.dump(json_datas, f) 207 | subprocess.run([ 208 | "python", 209 | "./sem2SQL.py", 210 | "--data_path", 211 | split, 212 | "--input_path", 213 | os.path.join(split, 'predict_lf.json'), 214 | "--output_path", 215 | os.path.join(split, 'output.txt') 216 | ], cwd="/IRNet") 217 | else: 218 | print("Trial run") 219 | with open(os.path.join(split, 'output.txt'), 'w') as f: 220 | f.write('trial run\n') 221 | with open(os.path.join(split, 'predict_lf.json'), 'w') as f: 222 | json.dump({'trial': 'run'}, f) 223 | results = {} 224 | with open(os.path.join(split, 'output.txt'), 'r') as f: 225 | results["sql"] = f.read().strip() 226 | with open(os.path.join(split, 'predict_lf.json'), 'r') as f: 227 | results["interpretation"] = json.load(f) 228 | message = { 229 | "split": split, 230 | "result": results 231 | } 232 | return message 233 | 234 | def serialize(o): 235 | if type(o).__name__ == 'int64': # numpy.int64 values are not JSON serializable by default 236 | return int(o) 237 | 238 | def handle_request0(request): 239 | debug = 'debug' in request.form 240 |
base = "" 241 | try: 242 | csv_key = 'csv' 243 | if csv_key not in request.files: 244 | csv_key = 'csv[]' 245 | print(request.files) 246 | if csv_key not in request.files and not 'sqlite' in request.files: 247 | raise Exception('please include a csv file or sqlite file') 248 | if not 'q' in request.form: 249 | raise Exception('please include a q parameter with a question in it') 250 | csvs = request.files.getlist(csv_key) 251 | sqlite_file = request.files.get('sqlite') 252 | q = request.form['q'] 253 | 254 | # brute force removal of any old requests 255 | if not TRIAL_RUN: 256 | subprocess.run([ 257 | "bash", 258 | "-c", 259 | "rm -rf /cache/case_*" 260 | ]) 261 | key = "case_" + str(uuid.uuid4()) 262 | data_dir = os.path.join('/cache', key) 263 | os.makedirs(os.path.join(data_dir, 'data'), exist_ok=True) 264 | os.makedirs(os.path.join(data_dir, 'original', 'database', 'data'), exist_ok=True) 265 | print("Key", key) 266 | for csv in csvs: 267 | print("Working on", csv) 268 | table_id = os.path.splitext(csv.filename)[0] 269 | table_id = re.sub(r'\W+', '_', table_id) 270 | stream = io.StringIO(csv.stream.read().decode("UTF8"), newline=None) 271 | add_csv.csv_stream_to_sqlite(table_id, stream, os.path.join(data_dir, 'data', 272 | 'data.sqlite')) 273 | stream.seek(0) 274 | if sqlite_file: 275 | print("Working on", sqlite_file) 276 | sqlite_file.save(os.path.join(data_dir, 'data', 'data.sqlite')) 277 | question_file = os.path.join(data_dir, 'question.json') 278 | tables_file = os.path.join(data_dir, 'tables.json') 279 | dummy_file = os.path.join(data_dir, 'dummy.json') 280 | add_question.question_to_json('data', q, question_file) 281 | 282 | row = { 283 | 'question': q, 284 | 'query': 'DUMMY', 285 | 'db_id': args.database, 286 | 'question_toks': _tokenize_question(tokenizer, q) 287 | } 288 | 289 | print(colored(f"question has been tokenized to : { row['question_toks'] }", 'cyan', attrs=['bold'])) 290 | 291 | with open(dummy_file, 'w') as fout: 292 | fout.write('[]\n') 293 | 294 | subprocess.run([ 295 | "python", 296 | "/spider/preprocess/get_tables.py", 297 | data_dir, 298 | tables_file, 299 | dummy_file 300 | ]) 301 | 302 | # valuenet expects different setup to irnet 303 | shutil.copyfile(tables_file, os.path.join(data_dir, 'original', 'tables.json')) 304 | database_path = os.path.join(data_dir, 'original', 'database', 'data', 305 | 'data.sqlite') 306 | shutil.copyfile(os.path.join(data_dir, 'data', 'data.sqlite'), 307 | database_path) 308 | 309 | schemas_raw, schemas_dict = spider_utils.load_schema(data_dir) 310 | 311 | data, table = merge_data_with_schema(schemas_raw, [row]) 312 | 313 | pre_processed_data = process_datas(data, related_to_concept, is_a_concept) 314 | 315 | pre_processed_with_values = _pre_process_values(pre_processed_data[0]) 316 | 317 | print(f"we found the following potential values in the question: {row['values']}") 318 | 319 | prediction, example = _inference_semql(pre_processed_with_values, schemas_dict, model) 320 | 321 | print(f"Results from schema linking (question token types): {example.src_sent}") 322 | print(f"Results from schema linking (column types): {example.col_hot_type}") 323 | 324 | print(colored(f"Predicted SemQL-Tree: {prediction['model_result']}", 'magenta', attrs=['bold'])) 325 | print() 326 | sql = _semql_to_sql(prediction, schemas_dict) 327 | 328 | print(colored(f"Transformed to SQL: {sql}", 'cyan', attrs=['bold'])) 329 | print() 330 | result = _execute_query(sql, database_path) 331 | 332 | print(f"Executed on the database '{args.database}'. 
Results: ") 333 | for row in result: 334 | print(colored(row, 'green')) 335 | 336 | message = { 337 | "split": key, 338 | "result": { 339 | "sql": sql.strip(), 340 | "answer": result 341 | } 342 | } 343 | code = 200 344 | except Exception as e: 345 | message = { "error": str(e) } 346 | code = 500 347 | if debug: 348 | message['base'] = base 349 | return jsonify(message), code 350 | 351 | handle_request = handle_request0 352 | 353 | if thread: 354 | thread.join() 355 | -------------------------------------------------------------------------------- /valuenet/setup.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | set -e 4 | 5 | if [ ! -e valuenet ]; then 6 | git clone https://github.com/paulfitz/valuenet -b mlsql_tweaks 7 | fi 8 | 9 | source venv/bin/activate 10 | pip install -r /requirements.txt 11 | python -m spacy download en_core_web_sm 12 | 13 | sed -i "s/ / /g" spider/preprocess/get_tables.py 14 | sed -i "s/print \(.*\)/print(\1)/" spider/preprocess/get_tables.py 15 | sed -i "s/prev_/#prev_/" spider/preprocess/get_tables.py 16 | sed -i "s/cur_/#cur_/" spider/preprocess/get_tables.py 17 | sed -i "s/if df in ex_tabs.*/if False:/" spider/preprocess/get_tables.py 18 | 19 | sed -i "s/open(file_name)/open(file_name, 'r', encoding='utf=8')/" IRNet/src/utils.py 20 | sed -i "s/(total)/(max(total,0.00001))/g" IRNet/src/utils.py 21 | sed -i "s/^ continue/ print('continue')/" IRNet/src/utils.py 22 | 23 | if [ ! -e IRNet/preprocess/conceptNet ]; then 24 | unzip irnet_conceptNet.zip 25 | ln -s $PWD/conceptNet IRNet/preprocess/conceptNet 26 | cd IRNet/preprocess 27 | sed -i 's/\r//g' run_me.sh 28 | cd ../.. 29 | fi 30 | 31 | mkdir -p cache 32 | --------------------------------------------------------------------------------
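The prediction server drives the two helper scripts above programmatically rather than from the command line. Below is a minimal sketch of that path, assuming it is run from `valuenet/server` with the server requirements installed (importing `add_question` also triggers an `nltk.download('punkt')` on first use); the CSV content, table id, and output file names are made up for illustration.

```
import io
import add_csv
import add_question

# Build a data.sqlite file with an inferred schema - the same call the
# prediction server makes for each uploaded CSV.
csv_text = "city,population\nZurich,421000\nBasel,178000\n"
add_csv.csv_stream_to_sqlite('cities', io.StringIO(csv_text, newline=None), 'data.sqlite')

# Write a question file carrying the dummy label the preprocessing code
# expects; 'data' is the db_id the server uses throughout.
add_question.question_to_json('data', 'which city has the largest population', 'question.json')
```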
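For reference, here is roughly how to call the ValueNet prediction server once it is listening on port 5050. Per `handle_request0` above, a request needs a `q` form field plus either one or more CSV uploads (`csv` or `csv[]`) or a ready-made SQLite file (`sqlite`); a `debug` form field is also accepted. `iris.csv` ships with the repository; `orders.csv`, `customers.csv`, and `data.sqlite` below are placeholders.

```
# one CSV table
curl -F "csv=@iris.csv" -F "q=how many rows are there" localhost:5050

# several CSV tables at once
curl -F "csv[]=@orders.csv" -F "csv[]=@customers.csv" -F "q=which customer placed the most orders" localhost:5050

# an existing SQLite database instead of CSVs
curl -F "sqlite=@data.sqlite" -F "q=how many rows are there" localhost:5050
```

On success the response is JSON of the form `{"split": ..., "result": {"sql": ..., "answer": ...}}`; failures come back as `{"error": ...}` with HTTP status 500.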
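The value-prediction step relies on spaCy's off-the-shelf NER (`_get_entities_local`) to propose candidate values, which `_pre_process_values` then matches against the database. A standalone sketch of just the NER part, assuming `en_core_web_sm` is installed (as `setup.sh` arranges); the question and the printed entities are illustrative only, since the exact spans depend on the spaCy model.

```
import en_core_web_sm

nlp = en_core_web_sm.load()

def candidate_values(question):
    # mirror _get_entities_local: every named entity becomes a candidate value
    doc = nlp(question)
    return [{'type': 'spacy_' + ent.label_, 'name': ent.text} for ent in doc.ents]

print(candidate_values("how many orders did Alice place in 2019"))
# something like: [{'type': 'spacy_PERSON', 'name': 'Alice'}, {'type': 'spacy_DATE', 'name': '2019'}]
```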