├── teaspn_server ├── __init__.py ├── teaspn_handler.py ├── gpt2completion_handler.py ├── paraphrase_handler.py ├── protocol.py ├── teaspn_server.py └── handler_impl_demo.py ├── images ├── hover.gif ├── jump.gif ├── ged_gec.gif ├── search.gif ├── completion.gif └── paraphrase.gif ├── requirements.in ├── run-teaspn-server ├── docker-compose.yml ├── Dockerfile ├── requirements.txt ├── README.md └── .gitignore /teaspn_server/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /images/hover.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/teaspn/teaspn-server/HEAD/images/hover.gif -------------------------------------------------------------------------------- /images/jump.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/teaspn/teaspn-server/HEAD/images/jump.gif -------------------------------------------------------------------------------- /images/ged_gec.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/teaspn/teaspn-server/HEAD/images/ged_gec.gif -------------------------------------------------------------------------------- /images/search.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/teaspn/teaspn-server/HEAD/images/search.gif -------------------------------------------------------------------------------- /images/completion.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/teaspn/teaspn-server/HEAD/images/completion.gif -------------------------------------------------------------------------------- /images/paraphrase.gif: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/teaspn/teaspn-server/HEAD/images/paraphrase.gif -------------------------------------------------------------------------------- /requirements.in: -------------------------------------------------------------------------------- 1 | language-check==1.1 2 | neuralcoref==4.0 3 | python-jsonrpc-server==0.2.0 4 | transformers==2.1.1 5 | spacy==2.1.0 6 | torch==1.2.0 7 | fairseq==0.8.0 8 | nltk==3.4.5 9 | overrides==1.9 10 | retrying==1.3.3 11 | sentencepiece==0.1.83 12 | -------------------------------------------------------------------------------- /run-teaspn-server: -------------------------------------------------------------------------------- 1 | #! /bin/bash 2 | 3 | 4 | 5 | SCRIPT_DIR="$( cd "$( dirname "${BASH_SOURCE[0]}" )" >/dev/null 2>&1 && pwd )" 6 | cd $SCRIPT_DIR 7 | docker-compose restart elasticsearch 8 | docker-compose run teaspn-server python -m teaspn_server.teaspn_server -------------------------------------------------------------------------------- /docker-compose.yml: -------------------------------------------------------------------------------- 1 | version: '3' 2 | services: 3 | elasticsearch: 4 | image: elasticsearch:7.1.1 5 | ports: 6 | - "9200:9200" 7 | volumes: 8 | - es-data:/usr/share/elasticsearch/data 9 | environment: 10 | - discovery.type=single-node 11 | - cluster.name=docker-cluster 12 | - bootstrap.memory_lock=true 13 | - "ES_JAVA_OPTS=-Xms512m -Xmx512m" 14 | ulimits: 15 | memlock: 16 | soft: -1 17 | hard: -1 18 | 19 | elasticsearch-dump: 20 | image: taskrabbit/elasticsearch-dump 21 | tty: true 22 | stdin_open: true 23 | 24 | teaspn-server: 25 | build: 26 | context: . 
27 | dockerfile: Dockerfile 28 | tty: true 29 | stdin_open: true 30 | command: python -m teaspn_server.teaspn_server 31 | 32 | 33 | volumes: 34 | es-data: 35 | driver: local -------------------------------------------------------------------------------- /Dockerfile: -------------------------------------------------------------------------------- 1 | FROM python:3.7 2 | 3 | RUN apt update 4 | RUN apt install -y apt-transport-https ca-certificates wget dirmngr gnupg software-properties-common 5 | RUN wget -qO - https://adoptopenjdk.jfrog.io/adoptopenjdk/api/gpg/key/public | apt-key add - 6 | RUN add-apt-repository -y https://adoptopenjdk.jfrog.io/adoptopenjdk/deb/ 7 | RUN apt update 8 | RUN apt install -y adoptopenjdk-8-hotspot 9 | 10 | RUN java -version 11 | 12 | 13 | ARG project_dir=/teaspn-server 14 | 15 | WORKDIR $project_dir 16 | 17 | COPY requirements.txt $project_dir 18 | 19 | RUN pip install -U pip 20 | RUN pip install -r ./requirements.txt 21 | 22 | COPY teaspn_server $project_dir/teaspn_server 23 | 24 | RUN python -m spacy download en 25 | RUN python -c "import nltk; nltk.download('wordnet')" 26 | RUN python -c "import logging; logging.basicConfig(level=logging.INFO); import neuralcoref" 27 | RUN python -c "from transformers import GPT2Tokenizer, GPT2LMHeadModel; GPT2LMHeadModel.from_pretrained('distilgpt2'); GPT2Tokenizer.from_pretrained('distilgpt2')" 28 | 29 | RUN mkdir -p $project_dir/model/paraphrase 30 | RUN mkdir -p $project_dir/model/paraphrase/spm 31 | 32 | RUN curl -sLJ --output $project_dir/model/paraphrase/dict.target.spm.txt "https://teaspn.s3.amazonaws.com/server/0.0.1/assets/dict.target.spm.txt" 33 | RUN curl -sLJ --output $project_dir/model/paraphrase/dict.source.spm.txt "https://teaspn.s3.amazonaws.com/server/0.0.1/assets/dict.source.spm.txt" 34 | RUN curl -sLJ --output $project_dir/model/paraphrase/checkpoint_best.pt "https://teaspn.s3.amazonaws.com/server/0.0.1/assets/checkpoint_best.pt" 35 | RUN curl -sLJ --output 
$project_dir/model/paraphrase/spm/para_nmt.model "https://teaspn.s3.amazonaws.com/server/0.0.1/assets/para_nmt.model" 36 | 37 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | # 2 | # This file is autogenerated by pip-compile 3 | # To update, run: 4 | # 5 | # pip-compile --output-file=requirements.txt requirements.in 6 | # 7 | blis==0.2.4 # via spacy, thinc 8 | boto3==1.9.232 # via neuralcoref, transformers 9 | botocore==1.12.232 # via boto3, s3transfer 10 | certifi==2019.9.11 # via requests 11 | cffi==1.12.3 # via fairseq 12 | chardet==3.0.4 # via requests 13 | click==7.0 # via sacremoses 14 | cymem==2.0.2 # via preshed, spacy, thinc 15 | docutils==0.15.2 # via botocore 16 | fairseq==0.8.0 17 | fastbpe==0.1.0 # via fairseq 18 | future==0.17.1 # via python-jsonrpc-server 19 | idna==2.8 # via requests 20 | jmespath==0.9.4 # via boto3, botocore 21 | joblib==0.13.2 # via sacremoses 22 | jsonschema==2.6.0 # via spacy 23 | language-check==1.1 24 | murmurhash==1.0.2 # via spacy, thinc 25 | neuralcoref==4.0 26 | nltk==3.4.5 27 | numpy==1.17.2 # via blis, fairseq, neuralcoref, spacy, thinc, torch, transformers 28 | overrides==1.9 29 | plac==0.9.6 # via spacy, thinc 30 | portalocker==1.5.1 # via sacrebleu 31 | preshed==2.0.1 # via spacy, thinc 32 | pycparser==2.19 # via cffi 33 | python-dateutil==2.8.0 # via botocore 34 | python-jsonrpc-server==0.2.0 35 | regex==2019.8.19 # via fairseq, transformers 36 | requests==2.22.0 # via neuralcoref, spacy, transformers 37 | retrying==1.3.3 38 | s3transfer==0.2.1 # via boto3 39 | sacrebleu==1.4.1 # via fairseq 40 | sacremoses==0.0.34 # via transformers 41 | sentencepiece==0.1.83 42 | six==1.12.0 # via nltk, python-dateutil, retrying, sacremoses 43 | spacy==2.1.0 44 | srsly==0.1.0 # via spacy, thinc 45 | thinc==7.0.8 # via spacy 46 | torch==1.2.0 47 | tqdm==4.36.1 # via fairseq, sacremoses, thinc, 
transformers 48 | transformers==2.1.1 49 | typing==3.7.4.1 # via sacrebleu 50 | urllib3==1.25.5 # via botocore, requests 51 | wasabi==0.2.2 # via spacy, thinc 52 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # TEASPN Server 2 | 3 | 4 | ## Overview 5 | This is a sample implementation of the [TEASPN](https://www.teaspn.org/) server supporting neural auto-completion and paraphrasing. 6 | 7 | 8 | ## Prerequisites 9 | 10 | + [TEASPN SDK client](https://github.com/teaspn/teaspn-sdk) 11 | + [Docker](https://docs.docker.com/install/) 12 | + [Docker Compose](https://docs.docker.com/compose/install/) 13 | 14 | ## Install 15 | 16 | 1. Build an image from the Dockerfile and pull images from the registry. Note that this may take a while (10-30 mins) depending on your environment. 17 | ``` 18 | docker-compose build 19 | docker-compose pull 20 | ``` 21 | 22 | 2. Download and build the Elasticsearch index. This also takes a while. 23 | ``` 24 | mkdir elasticsearch_indexes 25 | curl -sLJ --output elasticsearch_indexes/tatoeba_ja.index 'https://teaspn.s3.amazonaws.com/server/0.0.1/assets/tatoeba_ja.index' 26 | 27 | docker-compose up -d elasticsearch 28 | docker run --net=host --rm -ti -v $PWD/elasticsearch_indexes:/tmp taskrabbit/elasticsearch-dump --input=/tmp/tatoeba_ja.index --output=http://localhost:9200/tatoeba_ja 29 | 30 | docker-compose stop 31 | ``` 32 | 33 | 3. Set the PATH environment variable. 34 | ``` 35 | echo "export PATH='$PWD:$PATH'" >> ~/.bash_profile 36 | ``` 37 | + If you are using zsh, modify your `~/.zshenv` file instead of `~/.bash_profile`. 38 | 39 | 40 | ## Features 41 | 42 | ### Syntax highlighting 43 | 44 | Powered by the spaCy dependency parser. 45 | 46 | 47 | Head tokens with specific dependency (ROOT, nsubj, nsubjpass, and dobj in the CLEAR style tag set) relation are highlighted in different colors. 
48 | 49 | ### Grammatical Error Detection & Grammatical Error Correction 50 | 51 | Powered by [LanguageTool](https://languagetool.org/) and [its python wrapper](https://github.com/myint/language-check) 52 | 53 | ![GED GEC Demo](images/ged_gec.gif) 54 | 55 | ### Completion 56 | 57 | This implementation provides two types of completion: 58 | 59 | 1. Suggesting the likely next phrases given the context using [DistilGPT2](https://github.com/huggingface/transformers) developed by HuggingFace. 60 | 2. Suggesting a set of words consistent with the characters being typed. 61 | 62 | ![Completion Demo](images/completion.gif) 63 | 64 | ### Text Rewriting 65 | 66 | This provides paraphrase suggestions for the selected text. 67 | 68 | We built a paraphrase model trained on [PARANMT-50M](https://github.com/jwieting/para-nmt-50m) using [fairseq](https://github.com/pytorch/fairseq). 69 | 70 | ![Paraphrasing Demo](images/paraphrase.gif) 71 | 72 | ### Example Search 73 | 74 | Provides a full-text search feature using [Tatoeba](https://tatoeba.org) and [Elasticsearch](https://www.elastic.co/products/elasticsearch). Currently, this only supports Japanese-to-English search. 75 | 76 | ![Search Demo](images/search.gif) 77 | 78 | ### Reference Jump 79 | 80 | This lets you jump from a selected expression to its antecedent. Powered by [NeuralCoref](https://github.com/huggingface/neuralcoref). 81 | 82 | ![Jump Demo](images/jump.gif) 83 | 84 | ### Mouse Hover 85 | 86 | Show the definition of a hovered word using [WordNet](https://wordnet.princeton.edu). 
import bisect
import logging
import re
from typing import List, Optional

from teaspn_server.protocol import (
    CodeAction, Command, CompletionList, Diagnostic, Hover, Location,
    Position, Range, SyntaxHighlight
)


logger = logging.getLogger(__name__)


class TeaspnHandler(object):
    """
    This is the abstract base class for handling requests (method calls) from ``TeaspnServer``.
    Language smartness developers are expected to inherit this class and create their own
    implementation of this class.

    This class also provides some low-level utility methods for keeping the document in sync with
    the client side.
    """

    def __init__(self):
        # Offsets (into self._text) of the first character of each line.
        # Kept sorted in ascending order; self._line_offsets[i] is the start of line i.
        self._line_offsets = [0]
        self._uri = None
        self._text = ""

    def _position_to_offset(self, position: Position) -> int:
        """Converts a (line, character) ``Position`` to a flat offset into ``self._text``."""
        return self._line_offsets[position.line] + position.character

    def _offset_to_position(self, offset: int) -> Position:
        """Converts a flat offset into ``self._text`` to a (line, character) ``Position``.

        Inverse of ``_position_to_offset``.
        """
        # Find the last line whose start offset is <= offset.
        # NOTE: the previous implementation had two bugs here: (1) an offset that fell
        # on the final line never matched its scan condition, leaving `line` at 0, and
        # (2) the character was computed as `offset - line` (subtracting the line
        # *index*) instead of subtracting the line's start offset.
        line = bisect.bisect_right(self._line_offsets, offset) - 1
        character = offset - self._line_offsets[line]
        return Position(line=line, character=character)

    def _recompute_line_offsets(self):
        """Rebuilds ``self._line_offsets`` after ``self._text`` has changed."""
        self._line_offsets = [0] + [m.start() + 1 for m in re.finditer('\n', self._text)]
        # NOTE: \r\n may also need to be considered.

    def _get_text(self, range: Range) -> str:
        """Returns the text of the document covered by ``range``."""
        start_offset = self._position_to_offset(range.start)
        end_offset = self._position_to_offset(range.end)
        return self._text[start_offset:end_offset]

    def _get_line(self, line: int) -> str:
        """
        Returns the text on the specified line (including the trailing newline, if any).
        """

        # Appending len(self._text) as a sentinel lets the last line be sliced uniformly.
        line_offsets_with_sentinel = self._line_offsets + [len(self._text)]
        return self._text[line_offsets_with_sentinel[line]:line_offsets_with_sentinel[line + 1]]

    def _get_word_at(self, position: Position) -> Optional[str]:
        """
        Returns the word at position, or None if the cursor is not on a word.

        This method gets the text for the cursor line, and then
        finds the word that encompasses the cursor. This is a bit inefficient, but simple.
        (otherwise you'd need to scan the string in both directions and find punctuations etc.,
        which is a lot more complicated).
        """

        line = self._get_line(position.line)

        for match in re.finditer(r'\w+', line):
            # <= on both ends: a cursor sitting just after the word still matches it.
            if match.start() <= position.character <= match.end():
                return match.group(0)

        return None

    def initialize_document(self, uri: str, text: str):
        """Resets the handler to track a freshly opened document."""
        logger.debug('Initialized document: uri=%s, text=%s', uri, text)
        self._uri = uri
        self._text = text
        self._recompute_line_offsets()

    def update_document(self, range: Range, text: str):
        """Applies an incremental edit: replaces the text covered by ``range`` with ``text``."""
        start_offset = self._position_to_offset(range.start)
        end_offset = self._position_to_offset(range.end)
        self._text = self._text[:start_offset] + text + self._text[end_offset:]
        self._recompute_line_offsets()
        logger.debug('Updated document: text=%s', self._text)

    def highlight_syntax(self) -> List[SyntaxHighlight]:
        """Returns syntax highlights for the whole document."""
        raise NotImplementedError

    def get_diagnostics(self) -> List[Diagnostic]:
        """Returns diagnostics (e.g. grammatical errors) for the whole document."""
        raise NotImplementedError

    def run_quick_fix(self, range: Range, diagnostics: List[Diagnostic]) -> List[CodeAction]:
        """Returns quick-fix actions for the given range and its diagnostics."""
        raise NotImplementedError

    def run_code_action(self, range: Range) -> List[Command]:
        """Returns code actions (e.g. paraphrasing) applicable to the given range."""
        raise NotImplementedError

    def search_example(self, query: str) -> List:
        # TODO: fix the protocol definition of search example
        raise NotImplementedError

    def search_definition(self, position: Position, uri: str) -> List[Location]:
        """Returns locations (e.g. antecedents) the expression at ``position`` refers to."""
        raise NotImplementedError

    def get_completion_list(self, position: Position) -> CompletionList:
        """Returns completion candidates for the cursor at ``position``."""
        raise NotImplementedError

    def hover(self, position: Position) -> Optional[Hover]:
        """Returns hover content (e.g. a word definition) for ``position``, if any."""
        raise NotImplementedError
68 | *.manifest 69 | *.spec 70 | 71 | # Installer logs 72 | pip-log.txt 73 | pip-delete-this-directory.txt 74 | 75 | # Unit test / coverage reports 76 | htmlcov/ 77 | .tox/ 78 | .nox/ 79 | .coverage 80 | .coverage.* 81 | .cache 82 | nosetests.xml 83 | coverage.xml 84 | *.cover 85 | *.py,cover 86 | .hypothesis/ 87 | .pytest_cache/ 88 | 89 | # Translations 90 | *.mo 91 | *.pot 92 | 93 | # Django stuff: 94 | *.log 95 | local_settings.py 96 | db.sqlite3 97 | db.sqlite3-journal 98 | 99 | # Flask stuff: 100 | instance/ 101 | .webassets-cache 102 | 103 | # Scrapy stuff: 104 | .scrapy 105 | 106 | # Sphinx documentation 107 | docs/_build/ 108 | 109 | # PyBuilder 110 | target/ 111 | 112 | # Jupyter Notebook 113 | .ipynb_checkpoints 114 | 115 | # IPython 116 | profile_default/ 117 | ipython_config.py 118 | 119 | # pyenv 120 | .python-version 121 | 122 | # pipenv 123 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 124 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 125 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 126 | # install all needed dependencies. 
127 | #Pipfile.lock 128 | 129 | # celery beat schedule file 130 | celerybeat-schedule 131 | 132 | # SageMath parsed files 133 | *.sage.py 134 | 135 | # Environments 136 | .env 137 | .venv 138 | env/ 139 | venv/ 140 | ENV/ 141 | env.bak/ 142 | venv.bak/ 143 | 144 | # Spyder project settings 145 | .spyderproject 146 | .spyproject 147 | 148 | # Rope project settings 149 | .ropeproject 150 | 151 | # mkdocs documentation 152 | /site 153 | 154 | # mypy 155 | .mypy_cache/ 156 | .dmypy.json 157 | dmypy.json 158 | 159 | # Pyre type checker 160 | .pyre/ 161 | 162 | ### JetBrains 163 | # Covers JetBrains IDEs: IntelliJ, RubyMine, PhpStorm, AppCode, PyCharm, CLion, Android Studio and WebStorm 164 | # Reference: https://intellij-support.jetbrains.com/hc/en-us/articles/206544839 165 | 166 | # User-specific stuff 167 | .idea/**/workspace.xml 168 | .idea/**/tasks.xml 169 | .idea/**/usage.statistics.xml 170 | .idea/**/dictionaries 171 | .idea/**/shelf 172 | 173 | # Generated files 174 | .idea/**/contentModel.xml 175 | 176 | # Sensitive or high-churn files 177 | .idea/**/dataSources/ 178 | .idea/**/dataSources.ids 179 | .idea/**/dataSources.local.xml 180 | .idea/**/sqlDataSources.xml 181 | .idea/**/dynamic.xml 182 | .idea/**/uiDesigner.xml 183 | .idea/**/dbnavigator.xml 184 | 185 | # Gradle 186 | .idea/**/gradle.xml 187 | .idea/**/libraries 188 | 189 | # Gradle and Maven with auto-import 190 | # When using Gradle or Maven with auto-import, you should exclude module files, 191 | # since they will be recreated, and may cause churn. Uncomment if using 192 | # auto-import. 
import time
import logging
# NOTE(review): `time` and these multiprocessing imports appear unused in this module.
from multiprocessing import Value, JoinableQueue, Queue

import torch
import torch.nn.functional as F
from transformers import GPT2Tokenizer, GPT2LMHeadModel

logger = logging.getLogger(__name__)


class GPT2CompletionHandler(object):
    """Generates next-phrase completion candidates with a (Distil)GPT-2 language model."""

    def __init__(self, model_path='distilgpt2', tokenizer_path='distilgpt2', length=10, device='cpu', num_samples=1,
                 temperature=1):
        # length: tokens to generate per candidate; -1 means half the model's context window.
        # num_samples: number of completion candidates produced per request.
        # temperature: softmax temperature applied to the logits before sampling.
        self.model = GPT2LMHeadModel.from_pretrained(model_path)
        self.tokenizer = GPT2Tokenizer.from_pretrained(tokenizer_path)
        self.num_samples = num_samples

        self.temperature = temperature
        self.device = device

        self.model.to(self.device)
        self.model.eval()
        self.length = length
        # Cached attention "past" state and the prompt it was computed for, intended to
        # avoid re-encoding an unchanged prefix across calls (see note in sample_sequence).
        self.past = None
        self.past_context = ""

        if self.length == -1:
            self.length = self.model.config.n_ctx // 2
        elif self.length > self.model.config.n_ctx:
            raise ValueError("Can't get samples longer than window size: %s" % self.model.config.n_ctx)

    def generate(self, context):
        """Returns ``num_samples`` decoded completion strings continuing ``context``."""
        # Prepend the end-of-text token so generation starts from a document boundary.
        context = '<|endoftext|> ' + context
        context_tokens = self.tokenizer.encode(context)
        out = self.sample_sequence(context=context_tokens, )
        # Drop the prompt tokens; keep only the newly generated continuation.
        out = out[:, len(context_tokens):].tolist()
        return [self.tokenizer.decode(out[i], clean_up_tokenization_spaces=True).strip() for i in
                range(self.num_samples)]

    def top_k_top_p_filtering(self, logits, top_k=0, top_p=0.0, filter_value=-float('Inf')):
        """
        Filter a distribution of logits using top-k and/or nucleus (top-p) filtering
        Args:
            logits: logits distribution shape (vocabulary size)
            top_k > 0: keep only top k tokens with highest probability (top-k filtering).
            top_p > 0.0: keep the top tokens with cumulative probability > top_p (nucleus filtering).
                Nucleus filtering is described in Holtzman et al. (http://arxiv.org/abs/1904.09751)
        From:
            https://gist.github.com/thomwolf/1a5a29f6962089e871b94cbd09daf317
            and https://github.com/huggingface/transformers/pull/1333/files#diff-69e141e24d872d0cad3270be7db159e5L104
        """

        top_k = min(top_k, logits.size(-1))  # Safety check
        if top_k > 0:
            # Remove all tokens with a probability less than the last token of the top-k
            indices_to_remove = logits < torch.topk(logits, top_k)[0][..., -1, None]
            logits[indices_to_remove] = filter_value

        if top_p > 0.0:
            sorted_logits, sorted_indices = torch.sort(logits, descending=True)
            cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)

            # Remove tokens with cumulative probability above the threshold
            sorted_indices_to_remove = cumulative_probs > top_p
            # Shift the indices to the right to keep also the first token above the threshold
            sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
            sorted_indices_to_remove[..., 0] = 0

            # Map the removal mask from sorted order back to vocabulary order.
            indices_to_remove = sorted_indices_to_remove.scatter(dim=1, index=sorted_indices,
                                                                 src=sorted_indices_to_remove)
            logits[indices_to_remove] = filter_value
        return logits

    def sample_sequence(self, context, top_k=0, top_p=0.0, sample=True):
        """
        Samples up to ``self.length`` tokens autoregressively after ``context``.

        Args:
            context: token ids of the prompt (list of ints).
            top_k: keep only the k most likely tokens at each step (0 disables).
            top_p: nucleus-sampling threshold (0.0 disables).
            sample: if True, sample from the filtered distribution; otherwise take the argmax.

        Note:
            We created based on https://github.com/huggingface/transformers/blob/v2.0.0/examples/run_generation.py
        """

        # NOTE(review): this prefix-cache check looks wrong: it compares a *prefix* of
        # `context` against the whole of `context` (true only when len(past_context) >=
        # len(context)), and `self.past_context` is reassigned before it is used for the
        # slice on the next line, which would truncate `context` to empty. Since
        # `self.past_context` is only ever updated inside this branch (it stays "" after
        # __init__), the cached `self.past` is effectively never reused. Verify the
        # intended condition before relying on this cache.
        if context[:len(self.past_context)] == context:
            self.past_context = context
            context = context[len(self.past_context):]
            assert self.past_context != context
            local_past = self.past
        else:
            local_past = None

        # Duplicate the prompt row-wise so num_samples candidates are decoded in parallel.
        context = torch.tensor(context, device=self.device, dtype=torch.long).unsqueeze(0).repeat(self.num_samples, 1)

        logger.info('IN sample_sequence: context=%s', context)
        prev = context
        output = context

        with torch.no_grad():
            for i in range(self.length):
                next_token_logits, local_past = self.model(prev, past=local_past)
                if i == 0:
                    # Cache the attention state computed from the full prompt.
                    self.past = local_past
                next_token_logits = next_token_logits[:, -1, :] / self.temperature
                filtered_logits = self.top_k_top_p_filtering(next_token_logits, top_k=top_k, top_p=top_p)

                # NOTE(review): despite the name, these are probabilities (softmax),
                # which is what torch.multinomial expects.
                next_token_log_probs = F.softmax(filtered_logits, dim=-1)
                if sample:
                    prev = torch.multinomial(next_token_log_probs, num_samples=1)
                else:
                    _, prev = torch.topk(next_token_log_probs, k=1, dim=-1)
                output = torch.cat((output, prev), dim=1)
        return output
checkpoint_utils 9 | 10 | 11 | logger = logging.getLogger(__name__) 12 | 13 | Batch = namedtuple('Batch', 'ids src_tokens src_lengths') 14 | 15 | 16 | class ParaphraseHandler(object): 17 | def __init__(self, model_path, data_path, tokenizer_path, beam=3, n_best=3, diverse_beam_group=3, 18 | diverse_beam_strength=0.5, source_lang='source.spm', target_lang='target.spm'): 19 | self.tokenizer = spm.SentencePieceProcessor() 20 | self.tokenizer.Load(tokenizer_path) 21 | self.parser = options.get_generation_parser() 22 | self.args = options.parse_args_and_arch(self.parser, 23 | input_args=['--cpu', 24 | f'--path={model_path}', 25 | f'--beam={beam}', 26 | f'--nbest={n_best}', 27 | f'--diverse-beam-groups={diverse_beam_group}', 28 | f'--diverse-beam-strength={diverse_beam_strength}', 29 | f'--source-lang={source_lang}', 30 | f'--target-lang={target_lang}', 31 | f'--task=translation', 32 | '--remove-bpe=sentencepiece', 33 | data_path]) 34 | 35 | # self.task = TranslationTask.setup_task(self.args) 36 | with redirect_stdout(open(os.devnull, 'w')): 37 | self.task = tasks.setup_task(self.args) 38 | self.models, _ = checkpoint_utils.load_model_ensemble(self.args.path.split(':'), task=self.task, 39 | arg_overrides=eval(self.args.model_overrides)) 40 | self.src_dict = self.task.source_dictionary 41 | self.tgt_dict = self.task.target_dictionary 42 | 43 | # Optimize ensemble for generation 44 | for model in self.models: 45 | model.make_generation_fast_(beamable_mm_beam_size=None if self.args.no_beamable_mm else self.args.beam, 46 | need_attn=self.args.print_alignment) 47 | 48 | self.generator = self.task.build_generator(self.args) 49 | self.align_dict = utils.load_align_dict(self.args.replace_unk) 50 | 51 | self._max_positions = utils.resolve_max_positions( 52 | self.task.max_positions(), 53 | *[model.max_positions() for model in self.models] 54 | ) 55 | self.use_cuda = False 56 | 57 | def generate(self, text_input: str) -> list: 58 | 59 | text_input = ' 
'.join(self.tokenizer.EncodeAsPieces(text_input)) 60 | outputs = [] 61 | start_id = 0 62 | 63 | results = [] 64 | for batch in self._make_batches([text_input]): 65 | src_tokens = batch.src_tokens 66 | src_lengths = batch.src_lengths 67 | if self.use_cuda: 68 | src_tokens = src_tokens.cuda() 69 | src_lengths = src_lengths.cuda() 70 | 71 | sample = { 72 | 'net_input': { 73 | 'src_tokens': src_tokens, 74 | 'src_lengths': src_lengths, 75 | }, 76 | } 77 | translations = self.task.inference_step(self.generator, self.models, sample) 78 | for i, (id, hypos) in enumerate(zip(batch.ids.tolist(), translations)): 79 | src_tokens_i = utils.strip_pad(src_tokens[i], self.tgt_dict.pad()) 80 | results.append((start_id + id, src_tokens_i, hypos)) 81 | 82 | # sort output to match input order 83 | for _, src_tokens, hypos in sorted(results, key=lambda x: x[0]): 84 | if self.src_dict is not None: 85 | src_str = self.src_dict.string(src_tokens, self.args.remove_bpe) 86 | 87 | # Process top predictions 88 | for hypo in hypos[:min(len(hypos), self.args.nbest)]: 89 | hypo_tokens, hypo_str, alignment = utils.post_process_prediction( 90 | hypo_tokens=hypo['tokens'].int().cpu(), 91 | src_str=src_str, 92 | alignment=hypo['alignment'].int().cpu() if hypo['alignment'] is not None else None, 93 | align_dict=self.align_dict, 94 | tgt_dict=self.tgt_dict, 95 | remove_bpe=self.args.remove_bpe, 96 | ) 97 | outputs.append(hypo_str) 98 | # update running id counter 99 | return list(set(outputs)) 100 | 101 | def _make_batches(self, lines): 102 | tokens = [ 103 | self.task.source_dictionary.encode_line(src_str, add_if_not_exist=False).long() 104 | for src_str in lines 105 | ] 106 | lengths = torch.LongTensor([t.numel() for t in tokens]) 107 | itr = self.task.get_batch_iterator( 108 | dataset=self.task.build_dataset_for_inference(tokens, lengths), 109 | max_tokens=self.args.max_tokens, 110 | max_sentences=self.args.max_sentences, 111 | max_positions=self._max_positions, 112 | 
).next_epoch_itr(shuffle=False) 113 | for batch in itr: 114 | yield Batch( 115 | ids=batch['id'], 116 | src_tokens=batch['net_input']['src_tokens'], src_lengths=batch['net_input']['src_lengths'], 117 | ) 118 | -------------------------------------------------------------------------------- /teaspn_server/protocol.py: -------------------------------------------------------------------------------- 1 | """Defines constants and types for TEASPN.""" 2 | 3 | from typing import Dict, List, NamedTuple, Optional, Union 4 | 5 | 6 | class TextDocumentSyncKind: 7 | NONE = 0 8 | Full = 1 9 | Incremental = 2 10 | 11 | 12 | class Position(NamedTuple): 13 | line: int 14 | character: int 15 | 16 | @classmethod 17 | def from_dict(cls, data: Dict) -> 'Position': 18 | return cls(line=data['line'], character=data['character']) 19 | 20 | def to_dict(self): 21 | return {'line': self.line, 'character': self.character} 22 | 23 | 24 | class Range(NamedTuple): 25 | start: Position 26 | end: Position 27 | 28 | @classmethod 29 | def from_dict(cls, data: Dict) -> 'Range': 30 | start = Position.from_dict(data['start']) 31 | end = Position.from_dict(data['end']) 32 | return cls(start=start, end=end) 33 | 34 | def to_dict(self): 35 | return {'start': self.start.to_dict(), 'end': self.end.to_dict()} 36 | 37 | 38 | class Location(NamedTuple): 39 | uri: str 40 | range: Range 41 | 42 | def to_dict(self): 43 | return {'uri': self.uri, 'range': self.range.to_dict()} 44 | 45 | 46 | class SyntaxHighlight(NamedTuple): 47 | range: Range 48 | type: str 49 | hoverMessage: Optional[str] = None 50 | 51 | def to_dict(self): 52 | result = { 53 | 'range': self.range.to_dict(), 54 | 'type': self.type 55 | } 56 | 57 | if self.hoverMessage: 58 | result['hoverMessage'] = self.hoverMessage 59 | 60 | return result 61 | 62 | 63 | class DiagnosticSeverity: 64 | Error = 1 65 | Warning = 2 66 | Information = 3 67 | Hint = 4 68 | 69 | 70 | class Diagnostic(NamedTuple): 71 | range: Range 72 | message: str 73 | severity: 
class Diagnostic(NamedTuple):
    """A diagnostic (e.g. a grammatical error) attached to a range of the document."""
    range: Range
    message: str
    severity: Optional[int] = None
    code: Optional[Union[int, str]] = None
    source: Optional[str] = None
    relatedInformation: None = None  # TODO: support relatedInformation

    @classmethod
    def from_dict(cls, data: Dict) -> 'Diagnostic':
        """Builds a Diagnostic from its JSON representation."""
        return cls(range=Range.from_dict(data['range']),
                   message=data['message'],
                   severity=data.get('severity'),
                   code=data.get('code'),
                   source=data.get('source'))

    def to_dict(self):
        """Serializes to a JSON-compatible dict, omitting optional fields left unset."""
        payload = {'range': self.range.to_dict(), 'message': self.message}
        for field_name in ('severity', 'code', 'source'):
            value = getattr(self, field_name)
            if value is not None:
                payload[field_name] = value
        return payload


class TextEdit(NamedTuple):
    """A single text replacement: ``newText`` substituted over ``range``."""
    range: Range
    newText: str

    def to_dict(self):
        return {'range': self.range.to_dict(), 'newText': self.newText}


class CompletionItem(NamedTuple):
    """One completion candidate; only label/detail/textEdit are serialized for now."""
    label: str
    kind: Optional[int] = None
    detail: Optional[str] = None
    documentation: Optional[str] = None  # TODO: support MarkupContent
    deprecated: Optional[bool] = None
    preselect: Optional[bool] = None
    sortText: Optional[str] = None
    filterText: Optional[str] = None
    insertText: Optional[str] = None
    insertTextFormat = None  # TODO: support insertTextFormat
    textEdit: TextEdit = None
    additionalTextEdits = None  # TODO: support additionalTextEdits
    commitCharacters: Optional[List[str]] = None
    command = None  # TODO: support command
    data = None  # TODO: support data

    def to_dict(self):
        item = {'label': self.label}
        if self.detail is not None:
            item['detail'] = self.detail
        if self.textEdit is not None:
            item['textEdit'] = self.textEdit.to_dict()
        return item


class CompletionList(NamedTuple):
    """A set of completion candidates; ``isIncomplete`` means further typing may refine it."""
    isIncomplete: bool
    items: List[CompletionItem]

    def to_dict(self):
        serialized_items = [item.to_dict() for item in self.items]
        return {'isIncomplete': self.isIncomplete, 'items': serialized_items}


class Example(NamedTuple):
    """A search-result example sentence shown to the user."""
    label: str
    description: str

    def to_dict(self):
        return {'label': self.label, 'description': self.description}


class WorkspaceEdit(NamedTuple):
    """Edits to one or more documents, keyed by document URI."""
    changes: Dict[str, List[TextEdit]]
    documentChanges = None  # NOTE: not supported

    def to_dict(self):
        serialized_changes = {}
        for uri, edits in self.changes.items():
            serialized_changes[uri] = [edit.to_dict() for edit in edits]
        return {'changes': serialized_changes}


class Command(NamedTuple):
    """A named command the client can ask the server to execute."""
    title: str
    command: str
    arguments: Optional[List] = None

    def to_dict(self):
        payload = {'title': self.title, 'command': self.command}
        if self.arguments is not None:
            payload['arguments'] = [arg.to_dict() for arg in self.arguments]
        return payload


class CodeAction(NamedTuple):
    """A code action (e.g. a quick fix) offered for a range of the document."""
    title: str
    kind: Optional[str] = None
    edit: Optional[WorkspaceEdit] = None  # NOTE: current ver. of VSCode doesn't support this
    command: Optional[Command] = None

    def to_dict(self):
        payload = {'title': self.title}
        # Truthiness checks (not `is not None`) are deliberate here, as in the original.
        if self.kind:
            payload['kind'] = self.kind
        if self.edit:
            payload['edit'] = self.edit.to_dict()
        if self.command:
            payload['command'] = self.command.to_dict()
        return payload


class Hover(NamedTuple):
    """Content shown when the user hovers over a word."""
    contents: str  # TODO: support MarkedString and list of MarkedString
    range: Optional[Range] = None

    def to_dict(self):
        payload = {'contents': self.contents}
        if self.range is not None:
            payload['range'] = self.range.to_dict()
        return payload
27 | This seems to be different for Windows/Unix Python2/3, so going by: 28 | https://stackoverflow.com/questions/2850893/reading-binary-data-from-stdin 29 | 30 | NOTE: this method is borrowed from python-language-server: 31 | https://github.com/palantir/python-language-server/blob/develop/pyls/__main__.py 32 | """ 33 | 34 | PY3K = sys.version_info >= (3, 0) 35 | 36 | if PY3K: 37 | # pylint: disable=no-member 38 | stdin, stdout = sys.stdin.buffer, sys.stdout.buffer 39 | else: 40 | # Python 2 on Windows opens sys.stdin in text mode, and 41 | # binary data that read from it becomes corrupted on \r\n 42 | if sys.platform == "win32": 43 | # set sys.stdin to binary mode 44 | # pylint: disable=no-member,import-error 45 | import os 46 | import msvcrt 47 | msvcrt.setmode(sys.stdin.fileno(), os.O_BINARY) 48 | msvcrt.setmode(sys.stdout.fileno(), os.O_BINARY) 49 | stdin, stdout = sys.stdin, sys.stdout 50 | 51 | return stdin, stdout 52 | 53 | 54 | class _StreamHandlerWrapper(socketserver.StreamRequestHandler, object): 55 | """ 56 | A wrapper class that is used to construct a custom handler class. 57 | """ 58 | 59 | delegate = None 60 | 61 | def setup(self): 62 | super(_StreamHandlerWrapper, self).setup() 63 | # pylint: disable=no-member 64 | self.delegate = self.DELEGATE_CLASS(self.rfile, self.wfile) 65 | 66 | def handle(self): 67 | self.delegate.start() 68 | 69 | 70 | class TeaspnServer(MethodDispatcher): 71 | """ 72 | This class handles JSON-RPC requests to/from a Teaspn client, working as a middle layer 73 | that passes requests to a TeaspnHandler. Also does serialization and deserialization of 74 | Teaspn objects.
75 | """ 76 | 77 | def __init__(self, rx, tx, handler: TeaspnHandler, check_parent_process=False): 78 | self.workspace = None 79 | self.config = None 80 | 81 | self._jsonrpc_stream_reader = JsonRpcStreamReader(rx) 82 | self._jsonrpc_stream_writer = JsonRpcStreamWriter(tx) 83 | self._handler = handler 84 | self._check_parent_process = check_parent_process 85 | self._endpoint = Endpoint(self, self._jsonrpc_stream_writer.write, max_workers=MAX_WORKERS) 86 | self._dispatchers = [] 87 | self._shutdown = False 88 | 89 | def start(self): 90 | """ 91 | Entry point for the server. 92 | """ 93 | 94 | self._jsonrpc_stream_reader.listen(self._endpoint.consume) 95 | 96 | def m_initialize(self, processId=None, rootUri=None, rootPath=None, initializationOptions=None, 97 | **_kwargs): 98 | return {"capabilities": { 99 | "textDocumentSync": { 100 | "openClose": True, 101 | "change": TextDocumentSyncKind.Incremental 102 | }, 103 | "completionProvider": { 104 | "resolveProvider": False, 105 | "triggerCharacters":[' '] + list(__import__('string').ascii_lowercase) 106 | }, 107 | "codeActionProvider": True, 108 | "executeCommandProvider": { 109 | "commands": ['refactor.rewrite'] 110 | }, 111 | "definitionProvider": True, 112 | "hoverProvider": True 113 | }} 114 | 115 | def m_shutdown(self, **_kwargs): 116 | return None 117 | 118 | def m_text_document__did_open(self, textDocument=None, **_kwargs): 119 | self._handler.initialize_document(textDocument['uri'], textDocument['text']) 120 | 121 | diagnostics = self._handler.get_diagnostics() 122 | self._endpoint.notify('textDocument/publishDiagnostics', { 123 | 'uri': textDocument['uri'], 124 | 'diagnostics': [diagnostic.to_dict() for diagnostic in diagnostics] 125 | }) 126 | 127 | def m_text_document__did_change(self, textDocument=None, contentChanges=None, **_kwargs): 128 | for change in contentChanges: 129 | rng = Range.from_dict(change['range']) 130 | self._handler.update_document(range=rng, text=change['text']) 131 | # 
self._handler.update_document(range='', text=change['text']) 132 | diagnostics = self._handler.get_diagnostics() 133 | 134 | self._endpoint.notify('textDocument/publishDiagnostics', { 135 | 'uri': textDocument['uri'], 136 | 'diagnostics': [diagnostic.to_dict() for diagnostic in diagnostics] 137 | }) 138 | 139 | def m_text_document__syntax_highlight(self, textDocument=None, **_kwargs): 140 | highlights = self._handler.highlight_syntax() 141 | return [highlight.to_dict() for highlight in highlights] 142 | 143 | def m_text_document__completion(self, textDocument=None, position=None, **_kwargs): 144 | position = Position.from_dict(position) 145 | 146 | completion_list = self._handler.get_completion_list(position=position) 147 | 148 | return completion_list.to_dict() 149 | 150 | def m_workspace__search_example(self, query=None, **_kwargs): 151 | examples = self._handler.search_example(query) 152 | return [example.to_dict() for example in examples] 153 | 154 | def m_text_document__code_action(self, textDocument=None, range=None, context=None, **_kwargs): 155 | rng = Range.from_dict(range) 156 | actions = [] 157 | if context is not None and context.get('diagnostics', []): 158 | # code action for resolving a diagnostics -> invoke quick fix 159 | diagnostics = [Diagnostic.from_dict(diag) for diag in context['diagnostics']] 160 | actions = self._handler.run_quick_fix(rng, diagnostics) 161 | 162 | # obtain paraphrases 163 | commands = self._handler.run_code_action(rng) 164 | 165 | return [action_or_command.to_dict() for action_or_command in actions + commands] 166 | 167 | def m_workspace__execute_command(self, command=None, arguments=None, **_kwargs): 168 | if command == 'refactor.rewrite': 169 | self._endpoint.request('workspace/applyEdit', { 170 | 'edit': arguments[0] 171 | }) 172 | 173 | def m_text_document__definition(self, textDocument=None, position=None, **_kwargs): 174 | position = Position.from_dict(position) 175 | locations = 
self._handler.search_definition(position, uri=textDocument['uri']) 176 | return [location.to_dict() for location in locations] 177 | 178 | def m_text_document__hover(self, textDocument=None, position=None, **_kwargs): 179 | position = Position.from_dict(position) 180 | hover = self._handler.hover(position) 181 | if hover: 182 | return hover.to_dict() 183 | else: 184 | return None 185 | 186 | 187 | def main(): 188 | parser = argparse.ArgumentParser() 189 | parser.add_argument('--tcp', action='store_true') 190 | 191 | args = parser.parse_args() 192 | 193 | if args.tcp: 194 | bind_addr = '127.0.0.1' 195 | port = '9520' 196 | server = socketserver.TCPServer((bind_addr, port), TeaspnServer) 197 | server.allow_reuse_address = True 198 | 199 | else: 200 | stdin, stdout = _binary_stdio() 201 | handler = TeaspnHandlerImplDemo() 202 | server = TeaspnServer(stdin, stdout, handler, check_parent_process=False) 203 | server.start() 204 | 205 | 206 | if __name__ == '__main__': 207 | main() 208 | -------------------------------------------------------------------------------- /teaspn_server/handler_impl_demo.py: -------------------------------------------------------------------------------- 1 | import json 2 | import logging 3 | import re 4 | import time 5 | from multiprocessing import Process, Value, JoinableQueue, Queue 6 | from typing import Dict, List, Optional 7 | 8 | import language_check 9 | import neuralcoref 10 | import requests 11 | import spacy 12 | from nltk.corpus import wordnet 13 | from overrides import overrides 14 | from retrying import retry 15 | 16 | from teaspn_server.protocol import ( 17 | CodeAction, Command, CompletionItem, CompletionList, Diagnostic, DiagnosticSeverity, Example, 18 | Hover, Location, Position, Range, SyntaxHighlight, TextEdit, WorkspaceEdit 19 | ) 20 | from teaspn_server.teaspn_handler import TeaspnHandler 21 | from teaspn_server.gpt2completion_handler import GPT2CompletionHandler 22 | from teaspn_server.paraphrase_handler import 
ParaphraseHandler 23 | 24 | logger = logging.getLogger(__name__) 25 | 26 | 27 | class ColorSetting: 28 | 29 | dep2color = {'nsubj': 'salmon', 30 | 'ROOT': 'green', 31 | 'dobj': 'skyblue'} 32 | 33 | 34 | class ExampleHandler(object): 35 | 36 | def __init__(self): 37 | self._requests = JoinableQueue() 38 | self._results = Queue() 39 | self._busy = Value('i', 0) 40 | 41 | @retry(stop_max_attempt_number=5) 42 | def search_tatoeba_examples_several_attempts(self, query): 43 | response = requests.post( 44 | url=('http://elasticsearch:9200/tatoeba_ja/_search'), 45 | data=json.dumps({'size': 30, 46 | 'query': {'match_phrase': {'key': query}}}), 47 | headers={'Content-Type': 'application/json'}, 48 | ) 49 | if response.status_code == 200: 50 | response_data = response.json() 51 | logger.info('Received response: %r', response_data) 52 | for text in response_data['hits']['hits']: 53 | self._results.put( 54 | (text['_source']['value'], text['_source']['key'])) 55 | return True 56 | else: 57 | raise IOError("connection fails") 58 | time.sleep(1.) 59 | 60 | def loop(self): 61 | while True: 62 | if not self._requests.empty(): 63 | self._busy.value = 1 64 | query = self._requests.get() 65 | # process 66 | 67 | logger.info('IN Example Searching Thread: target_text=%s', query) 68 | 69 | self.search_tatoeba_examples_several_attempts(query) 70 | self._requests.task_done() 71 | self._busy.value = 0 72 | 73 | time.sleep(1.) 74 | 75 | def make_request(self, text): 76 | self._requests.put(text) 77 | self._requests.join() 78 | 79 | results = [] 80 | while not self._results.empty(): 81 | text = self._results.get() 82 | results.append(text) 83 | 84 | return results 85 | 86 | 87 | class TeaspnHandlerImplDemo(TeaspnHandler): 88 | """ 89 | This is an implementation of ``TeaspnHandler`` for demo purposes. 
90 | """ 91 | 92 | def __init__(self): 93 | super(TeaspnHandler, self).__init__() 94 | 95 | self._detect_ge = language_check.LanguageTool('en-US') 96 | # Manages a mapping from Diagnostic to replacements obtained from LanguageTool. 97 | # TODO: maybe move this to the base class? 98 | self._diag_to_replacements: Dict[Diagnostic, List[str]] = {} 99 | 100 | self._paraphrase_handler = ParaphraseHandler(model_path='model/paraphrase/checkpoint_best.pt', 101 | data_path='model/paraphrase', 102 | tokenizer_path='model/paraphrase/spm/para_nmt.model') 103 | 104 | self._completion_handler = GPT2CompletionHandler(num_samples=4, length=5) 105 | 106 | self._example_handler = ExampleHandler() 107 | self._example_process = Process(target=self._example_handler.loop) 108 | self._example_process.start() 109 | 110 | self._nlp = spacy.load('en') 111 | 112 | neuralcoref.add_to_pipe(self._nlp) 113 | 114 | self._freq_words = [word.strip() for word in open('teaspn_server/word.txt')] 115 | 116 | @overrides 117 | def highlight_syntax(self) -> List[SyntaxHighlight]: 118 | highlights = [] 119 | for line_id, line_text in enumerate(self._text.splitlines()): 120 | 121 | for tok_id, tok in enumerate(self._nlp(line_text)): 122 | if tok.dep_ in ColorSetting.dep2color: 123 | rng = Range(start=Position(line=line_id, character=tok.idx), 124 | end=Position(line=line_id, character=tok.idx + len(tok.text))) 125 | highlights.append(SyntaxHighlight(range=rng, 126 | type=ColorSetting.dep2color[tok.dep_], 127 | hoverMessage='dep: {}'.format(tok.dep_))) 128 | 129 | return highlights 130 | 131 | @overrides 132 | def get_diagnostics(self) -> List[Diagnostic]: 133 | diagnostics = [] 134 | for line_id, line_text in enumerate(self._text.splitlines()): 135 | for m in self._detect_ge.check(line_text): 136 | rng = Range(start=Position(line=line_id, character=m.fromx), 137 | end=Position(line=line_id, character=m.tox)) 138 | diagnostic = Diagnostic(range=rng, 139 | severity=DiagnosticSeverity.Error, 140 | 
message=m.msg) 141 | diagnostics.append(diagnostic) 142 | self._diag_to_replacements[diagnostic] = m.replacements 143 | return diagnostics 144 | 145 | @overrides 146 | def run_quick_fix(self, range: Range, diagnostics: List[Diagnostic]) -> List[CodeAction]: 147 | actions = [] 148 | for diag in diagnostics: 149 | for repl in self._diag_to_replacements[diag]: 150 | edit = WorkspaceEdit({self._uri: [TextEdit(range=diag.range, newText=repl)]}) 151 | command = Command(title='Quick fix: {}'.format(repl), 152 | command='refactor.rewrite', 153 | arguments=[edit]) 154 | actions.append(CodeAction(title='Quick fix: {}'.format(repl), 155 | kind='quickfix', 156 | command=command)) 157 | return actions 158 | 159 | @overrides 160 | def run_code_action(self, range: Range) -> List[Command]: 161 | target_text = self._get_text(range) 162 | if not target_text: 163 | return [] 164 | texts = [] 165 | for text in self._paraphrase_handler.generate(target_text): 166 | texts.append(text) 167 | commands = [] 168 | for text in texts: 169 | edit = WorkspaceEdit({self._uri: [TextEdit(range=range, newText=text)]}) 170 | command = Command(title='Suggestion: {}'.format(text), 171 | command='refactor.rewrite', 172 | arguments=[edit]) 173 | commands.append(command) 174 | return commands 175 | 176 | @overrides 177 | def search_example(self, query: str) -> List[Example]: 178 | examples = self._example_handler.make_request(query) 179 | example_list = [] 180 | for label, description in examples: 181 | example = Example(label=label, description=description) 182 | example_list.append(example) 183 | return example_list 184 | 185 | @overrides 186 | def search_definition(self, position: Position, uri: str) -> List[Location]: 187 | doc = self._nlp(self._text) 188 | offset = self._position_to_offset(position) 189 | locations = [] 190 | for coreference in doc._.coref_clusters: 191 | for mention in coreference.mentions: 192 | if mention.start_char <= offset <= mention.end_char: 193 | rng = 
Range(start=self._offset_to_position(coreference.main.start_char), 194 | end=self._offset_to_position(coreference.main.end_char)) 195 | locations.append(Location(uri=uri, range=rng)) 196 | 197 | return locations 198 | 199 | @overrides 200 | def get_completion_list(self, position: Position) -> CompletionList: 201 | offset = self._position_to_offset(position) 202 | context = self._text[:offset] 203 | logger.info('Completion: context=%r', context) 204 | 205 | if not context: 206 | return CompletionList(isIncomplete=False, items=[]) 207 | 208 | items = [] 209 | if context.endswith(' '): 210 | for text in self._completion_handler.generate(context): 211 | if re.match(r'[.,:;]', text): 212 | position_dict = position.to_dict() 213 | position_dict['character'] += -1 214 | position = Position.from_dict(position_dict) 215 | 216 | rng = Range(start=position, end=position) 217 | items.append(CompletionItem(label=text, 218 | textEdit=TextEdit(range=rng, newText=text))) 219 | 220 | else: 221 | query = context.split()[-1] 222 | texts = [word for word in self._freq_words if word.startswith(query)] 223 | for text in texts: 224 | rng = Range(start=position, end=position) 225 | new_text = text.replace(context.split()[-1], '') 226 | text_edit = TextEdit(range=rng, newText=new_text) 227 | items.append(CompletionItem(label=text, textEdit=text_edit)) 228 | 229 | return CompletionList(isIncomplete=False, items=items) 230 | 231 | @overrides 232 | def hover(self, position: Position) -> Optional[Hover]: 233 | # NOTE: currently this implements a very simple PoC where the word at position 234 | # is matched against WordNet synsets and the first result is returned, regardless of POS 235 | word = self._get_word_at(position) 236 | if not word: 237 | return None 238 | 239 | synsets = wordnet.synsets(word) 240 | if not synsets: 241 | return None 242 | 243 | pos, definition = synsets[0].pos(), synsets[0].definition() 244 | return Hover(contents=f'{pos}: {definition}') 245 | 
--------------------------------------------------------------------------------