├── index ├── __init__.py ├── faiss │ └── _swigfaiss.so ├── annoyVectorIndex.py └── faissVectorIndex.py ├── model_api ├── __init__.py ├── processing │ └── _swigfaiss.so ├── onmt_lua_model_api.py └── abstract_model_api.py ├── s2s ├── __init__.py ├── lru.py └── project.py ├── client ├── css │ ├── global.scss │ ├── modalDialog.scss │ ├── main.scss │ └── vis.scss ├── assets │ └── s2s_logo.png ├── fonts │ ├── Source_Sans_Pro │ │ ├── SourceSansPro-Black.ttf │ │ ├── SourceSansPro-Bold.ttf │ │ ├── SourceSansPro-Light.ttf │ │ ├── SourceSansPro-Italic.ttf │ │ ├── SourceSansPro-Regular.ttf │ │ ├── SourceSansPro-SemiBold.ttf │ │ ├── SourceSansPro-BoldItalic.ttf │ │ ├── SourceSansPro-ExtraLight.ttf │ │ ├── SourceSansPro-BlackItalic.ttf │ │ ├── SourceSansPro-LightItalic.ttf │ │ ├── SourceSansPro-SemiBoldItalic.ttf │ │ ├── SourceSansPro-ExtraLightItalic.ttf │ │ └── OFL.txt │ └── ssp.css ├── ts │ ├── etc │ │ ├── LocalTypes.ts │ │ ├── Util.ts │ │ ├── SimpleEventHandler.ts │ │ ├── SVGplus.ts │ │ ├── ModalDialog.ts │ │ ├── Networking.ts │ │ └── URLHandler.ts │ ├── vis │ │ ├── BarList.ts │ │ ├── StateVis.ts │ │ ├── BeamTree.ts │ │ ├── InfoPanel.ts │ │ ├── AttentionVis.ts │ │ ├── CloseWordList.ts │ │ ├── VisualComponent.ts │ │ ├── WordProjector.ts │ │ ├── StatePictograms.ts │ │ └── WordLine.ts │ ├── main.ts │ └── api │ │ ├── S2SApi.ts │ │ └── Translation.ts ├── tsconfig.json ├── README.md ├── package.json ├── webpack.config.js └── index.html ├── scripts ├── _swigfaiss.so ├── h5_to_faiss.py └── faiss.py ├── client_dist ├── s2s_logo.png ├── 674f50d287a8c48dc19ba404d20fe713.eot ├── af7ae505a9eed503f8b8e6982036873e.woff2 ├── b06871f281fee6b241d60582ae9369b9.ttf ├── fee66e712a8a08eef5805a46892932ad.woff └── index.html ├── docs ├── pics │ ├── s2s_teaser.png │ └── s2s_dates_01.png └── Faiss.md ├── setup_client.sh ├── .gitignore ├── .dockerignore ├── docker-compose.yml ├── setup_onmt_custom.sh ├── .editorconfig ├── setup_cpu.sh ├── test_sentences.txt ├── Dockerfile_base ├── Dockerfile ├── 
.circleci └── config.yml ├── .github └── pull_request_template.md ├── environment.yml ├── .gitattributes ├── swagger.yaml ├── README.md └── LICENSE /index/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /model_api/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /s2s/__init__.py: -------------------------------------------------------------------------------- 1 | __author__ = 'Hendrik Strobelt, Sebastian Gehrmann' -------------------------------------------------------------------------------- /client/css/global.scss: -------------------------------------------------------------------------------- 1 | $main-color: #fffdfa; 2 | //$main-color: #fafafa; 3 | $abc: 12; -------------------------------------------------------------------------------- /scripts/_swigfaiss.so: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/HendrikStrobelt/Seq2Seq-Vis/HEAD/scripts/_swigfaiss.so -------------------------------------------------------------------------------- /client_dist/s2s_logo.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/HendrikStrobelt/Seq2Seq-Vis/HEAD/client_dist/s2s_logo.png -------------------------------------------------------------------------------- /docs/pics/s2s_teaser.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/HendrikStrobelt/Seq2Seq-Vis/HEAD/docs/pics/s2s_teaser.png -------------------------------------------------------------------------------- /index/faiss/_swigfaiss.so: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/HendrikStrobelt/Seq2Seq-Vis/HEAD/index/faiss/_swigfaiss.so -------------------------------------------------------------------------------- /client/assets/s2s_logo.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/HendrikStrobelt/Seq2Seq-Vis/HEAD/client/assets/s2s_logo.png -------------------------------------------------------------------------------- /docs/pics/s2s_dates_01.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/HendrikStrobelt/Seq2Seq-Vis/HEAD/docs/pics/s2s_dates_01.png -------------------------------------------------------------------------------- /model_api/processing/_swigfaiss.so: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/HendrikStrobelt/Seq2Seq-Vis/HEAD/model_api/processing/_swigfaiss.so -------------------------------------------------------------------------------- /client_dist/674f50d287a8c48dc19ba404d20fe713.eot: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/HendrikStrobelt/Seq2Seq-Vis/HEAD/client_dist/674f50d287a8c48dc19ba404d20fe713.eot -------------------------------------------------------------------------------- /client_dist/af7ae505a9eed503f8b8e6982036873e.woff2: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/HendrikStrobelt/Seq2Seq-Vis/HEAD/client_dist/af7ae505a9eed503f8b8e6982036873e.woff2 -------------------------------------------------------------------------------- /client_dist/b06871f281fee6b241d60582ae9369b9.ttf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/HendrikStrobelt/Seq2Seq-Vis/HEAD/client_dist/b06871f281fee6b241d60582ae9369b9.ttf 
-------------------------------------------------------------------------------- /client_dist/fee66e712a8a08eef5805a46892932ad.woff: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/HendrikStrobelt/Seq2Seq-Vis/HEAD/client_dist/fee66e712a8a08eef5805a46892932ad.woff -------------------------------------------------------------------------------- /client/fonts/Source_Sans_Pro/SourceSansPro-Black.ttf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/HendrikStrobelt/Seq2Seq-Vis/HEAD/client/fonts/Source_Sans_Pro/SourceSansPro-Black.ttf -------------------------------------------------------------------------------- /client/fonts/Source_Sans_Pro/SourceSansPro-Bold.ttf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/HendrikStrobelt/Seq2Seq-Vis/HEAD/client/fonts/Source_Sans_Pro/SourceSansPro-Bold.ttf -------------------------------------------------------------------------------- /client/fonts/Source_Sans_Pro/SourceSansPro-Light.ttf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/HendrikStrobelt/Seq2Seq-Vis/HEAD/client/fonts/Source_Sans_Pro/SourceSansPro-Light.ttf -------------------------------------------------------------------------------- /client/fonts/Source_Sans_Pro/SourceSansPro-Italic.ttf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/HendrikStrobelt/Seq2Seq-Vis/HEAD/client/fonts/Source_Sans_Pro/SourceSansPro-Italic.ttf -------------------------------------------------------------------------------- /client/fonts/Source_Sans_Pro/SourceSansPro-Regular.ttf: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/HendrikStrobelt/Seq2Seq-Vis/HEAD/client/fonts/Source_Sans_Pro/SourceSansPro-Regular.ttf -------------------------------------------------------------------------------- /client/fonts/Source_Sans_Pro/SourceSansPro-SemiBold.ttf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/HendrikStrobelt/Seq2Seq-Vis/HEAD/client/fonts/Source_Sans_Pro/SourceSansPro-SemiBold.ttf -------------------------------------------------------------------------------- /client/fonts/Source_Sans_Pro/SourceSansPro-BoldItalic.ttf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/HendrikStrobelt/Seq2Seq-Vis/HEAD/client/fonts/Source_Sans_Pro/SourceSansPro-BoldItalic.ttf -------------------------------------------------------------------------------- /client/fonts/Source_Sans_Pro/SourceSansPro-ExtraLight.ttf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/HendrikStrobelt/Seq2Seq-Vis/HEAD/client/fonts/Source_Sans_Pro/SourceSansPro-ExtraLight.ttf -------------------------------------------------------------------------------- /client/fonts/Source_Sans_Pro/SourceSansPro-BlackItalic.ttf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/HendrikStrobelt/Seq2Seq-Vis/HEAD/client/fonts/Source_Sans_Pro/SourceSansPro-BlackItalic.ttf -------------------------------------------------------------------------------- /client/fonts/Source_Sans_Pro/SourceSansPro-LightItalic.ttf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/HendrikStrobelt/Seq2Seq-Vis/HEAD/client/fonts/Source_Sans_Pro/SourceSansPro-LightItalic.ttf -------------------------------------------------------------------------------- 
/client/fonts/Source_Sans_Pro/SourceSansPro-SemiBoldItalic.ttf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/HendrikStrobelt/Seq2Seq-Vis/HEAD/client/fonts/Source_Sans_Pro/SourceSansPro-SemiBoldItalic.ttf -------------------------------------------------------------------------------- /client/fonts/Source_Sans_Pro/SourceSansPro-ExtraLightItalic.ttf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/HendrikStrobelt/Seq2Seq-Vis/HEAD/client/fonts/Source_Sans_Pro/SourceSansPro-ExtraLightItalic.ttf -------------------------------------------------------------------------------- /setup_client.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | # Build client 4 | conda install --name s2sv --yes -c conda-forge nodejs 5 | source activate s2sv 6 | cd client 7 | npm install 8 | npm run build 9 | cd .. 10 | -------------------------------------------------------------------------------- /client/ts/etc/LocalTypes.ts: -------------------------------------------------------------------------------- 1 | import * as d3 from "d3"; 2 | 3 | export type D3Sel = d3.Selection; 4 | // type dObj = { [k: string]: any }; 5 | export interface LooseObject { 6 | [key: string]: any 7 | } -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | node_modules 2 | /client/ts/**/*.js 3 | *.map 4 | *.pyc 5 | .idea 6 | __pycache__ 7 | /model_api/data 8 | /model_api/processing/.ipynb_checkpoints 9 | tmp.txt 10 | /client_dist 11 | /client/dist 12 | /client/.cache-loader 13 | /data 14 | -------------------------------------------------------------------------------- /client/ts/etc/Util.ts: -------------------------------------------------------------------------------- 1 | /** 2 | * Created 
/** Module-wide counter behind Util.simpleUId; incremented once per generated id. */
let the_unique_id_counter = 0;

export class Util {
    /**
     * Return the next id of the module-wide sequence.
     *
     * The counter is advanced first, so the first id handed out ends in 1.
     *
     * @param prefix text placed in front of the counter value (default: '')
     * @returns `prefix` followed by the new counter value
     */
    static simpleUId({prefix = ''}): string {
        const next = ++the_unique_id_counter;

        return prefix + next;
    }
}
https://github.com/sebastianGehrmann/OpenNMT-py.git 8 | cd OpenNMT-py/ 9 | git checkout states_in_translation 10 | python setup.py install 11 | pip install torchtext==0.2.3 12 | cd .. 13 | -------------------------------------------------------------------------------- /.editorconfig: -------------------------------------------------------------------------------- 1 | # EditorConfig helps developers define and maintain consistent 2 | # coding styles between different editors and IDEs 3 | # editorconfig.org 4 | 5 | root = true 6 | 7 | 8 | [*] 9 | 10 | # Change these settings to your own preference 11 | indent_style = space 12 | indent_size = 4 13 | 14 | # We recommend you to keep these unchanged 15 | end_of_line = lf 16 | charset = utf-8 17 | trim_trailing_whitespace = true 18 | insert_final_newline = true 19 | 20 | [*.md] 21 | trim_trailing_whitespace = false 22 | -------------------------------------------------------------------------------- /docs/Faiss.md: -------------------------------------------------------------------------------- 1 | ### Install Faiss 2 | 3 | ```bash 4 | brew install --with-clang llvm 5 | PATH=$PATH:/usr/local/opt/llvm/bin 6 | brew install swig 7 | conda install swig 8 | git clone https://github.com/facebookresearch/faiss.git 9 | cd faiss 10 | cp example_makefiles/makefile.inc.Mac.brew ./makefile.inc 11 | make tests/test_blas 12 | ./tests/test_blas 13 | make 14 | make py 15 | python -c "import faiss" 16 | python -c "import faiss, numpy 17 | faiss.Kmeans(10, 20).train(numpy.random.rand(1000, 10).astype('float32'))" 18 | ``` -------------------------------------------------------------------------------- /setup_cpu.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | # Install all essential packages 4 | #conda env create -f environment.yml 5 | 6 | # Install all essential packages 7 | conda create --yes --name s2sv python=3.6 h5py numpy scikit-learn=0.19.1 flask 8 | conda install 
--name s2sv --yes -c conda-forge connexion nodejs python-annoy 9 | conda install --name s2sv --yes pytorch=0.3.1 -c soumith 10 | conda install --name s2sv --yes -c pytorch faiss-cpu 11 | source activate s2sv 12 | 13 | 14 | # Build client 15 | # conda install --name s2sv --yes -c conda-forge nodejs 16 | # cd client 17 | # npm install 18 | # npm run build 19 | # cd .. 20 | 21 | 22 | 23 | -------------------------------------------------------------------------------- /test_sentences.txt: -------------------------------------------------------------------------------- 1 | es schlug mein herz geschwind zu pferde - es war getan bevor gedacht . 2 | 3 | die längsten reisen fangen an , wenn es auf den straßen dunkel wird . 4 | 5 | the longest journeys begin when it is getting dark in the streets . 6 | 7 | in dem magazin werden mittels modernster internet-technologie die megatrends der zukunft wie beispielsweise eine steigende anzahl von herz-kreislauf-krankheiten, ein wachsender bedarf an hochwertigen nahrungsmitteln und die kohlendioxyd-problematik angesprochen . 8 | 9 | in dem magazin werden mittels modernster technologie die trends der zukunft wie beispielsweise eine steigende anzahl von krankheiten, ein wachsender bedarf an hochwertigen nahrungsmitteln und die kohlendioxyd problematik angesprochen . 
-------------------------------------------------------------------------------- /Dockerfile_base: -------------------------------------------------------------------------------- 1 | FROM continuumio/miniconda3:4.4.10 2 | 3 | RUN conda update --yes -n base conda &&\ 4 | conda clean --all --yes 5 | 6 | ADD ./environment.yml /tmp/ 7 | RUN conda env create -f /tmp/environment.yml &&\ 8 | conda clean --all --yes &&\ 9 | rm -rf /boot/.cache/pip ~/.cache/pip 10 | 11 | # RUN conda create --yes --name s2sv python=3.6 h5py numpy scikit-learn flask 12 | # RUN conda install --name s2sv --yes -c conda-forge connexion python-annoy 13 | # RUN conda install --name s2sv --yes -c pytorch pytorch==0.3.1 faiss-cpu 14 | 15 | WORKDIR /tmp 16 | ADD ./setup_onmt_custom.sh /tmp/ 17 | RUN /bin/bash /tmp/setup_onmt_custom.sh &&\ 18 | rm -rf /tmp /boot/.cache/pip ~/.cache/pip 19 | 20 | ENTRYPOINT [ "/bin/bash", "-c" ] 21 | -------------------------------------------------------------------------------- /client/README.md: -------------------------------------------------------------------------------- 1 | # Client for Seq2Seq-Vis 2 | 3 | For just using the tool you don't need to do anything here. 4 | 5 | If you want to change/customize the frontend, here are some short hints: 6 | 7 | 1) install `nodejs` 8 | 2) `cd client` 9 | 3) install dependencies and webpack: `npm install` 10 | 4) run build: `npm run build` or live watch: `npm run watch` 11 | 12 | Warning: the stack is from 2019 and requires some very specific webpack modules. 
class LRU:
    """Small list-backed least-recently-used cache with optional pinned entries.

    Entries live in ``self.cache`` as ``{'key': ..., 'object': ...}`` dicts.
    Positions ``[0, insert_to)`` hold entries that were preloaded with
    ``persist=True`` (pinned); from ``insert_to`` onwards the list is ordered
    from most to least recently used. Whenever the list grows beyond ``k``
    the last element is dropped, so pinned entries count towards ``k``.
    """

    def __init__(self, k=5):
        """Create an empty cache that keeps at most ``k`` entries."""
        self.k = k
        self.cache = []
        # Index at which new / recently used entries are inserted; it is
        # also the number of entries preloaded with ``persist=True``.
        self.insert_to = 0

    def preload(self, key, obj, persist=True):
        """Add an entry and, if ``persist`` is true, pin it.

        Pinning advances ``insert_to`` so every later insertion lands behind
        this entry and it never becomes the last (evicted) element.
        """
        self.add(key, obj)
        if persist:
            self.insert_to += 1

    def get(self, key):
        """Return the object cached under ``key``, or ``None`` on a miss.

        A hit among the evictable entries is moved to the front of the
        evictable region. A hit on a pinned entry is left where it is:
        moving it to ``insert_to`` would push it out of the pinned region
        (and pull an evictable entry in), so it could be evicted later.

        Note that a cached ``None`` cannot be told apart from a miss.
        """
        for pos, entry in enumerate(self.cache):
            if entry['key'] == key:
                if pos >= self.insert_to:
                    del self.cache[pos]
                    self.cache.insert(self.insert_to, entry)
                return entry['object']
        return None

    def add(self, key, obj):
        """Insert an entry as most recently used and evict the oldest if full.

        No check for an existing entry with the same key is made.
        """
        self.cache.insert(self.insert_to, {'key': key, 'object': obj})
        if len(self.cache) > self.k:
            self.cache.pop()
/**
 * Thin wrapper around a DOM element that is used as an event bus:
 * listeners are registered on the element and events are dispatched on it
 * as CustomEvents.
 */
export class SimpleEventHandler {

    element: Element;
    eventListeners: object[];

    /**
     * @param element the element on which listeners are added and events dispatched
     */
    constructor(element: Element) {
        this.element = element;
        this.eventListeners = [];
    }

    /**
     * Register `eventFunction` for every event name in `eventNames`.
     *
     * @param eventNames one or more event names separated by single spaces
     * @param eventFunction called as `eventFunction(detail, event)`
     */
    bind(eventNames: string, eventFunction: Function) {
        eventNames.split(' ').forEach(eventName => {
            // keep a record of what was bound (the original function, not the wrapper)
            this.eventListeners.push({eventName, eventFunction});
            this.element.addEventListener(
                eventName,
                e => eventFunction(e.detail, e),
                false
            );
        });
    }

    /** Return the list of `{eventName, eventFunction}` records created by `bind`. */
    getListeners() {
        return this.eventListeners;
    }

    /**
     * Dispatch a CustomEvent named `eventName` carrying `detail` on the element.
     */
    trigger(eventName: string, detail: object) {
        const event = new CustomEvent(eventName, {detail});
        this.element.dispatchEvent(event);
    }

}
/ws/ 27 | 28 | # build client and clean up afterwards 29 | RUN /bin/bash /ws/setup_client.sh &&\ 30 | conda uninstall --name s2sv --yes nodejs &&\ 31 | conda clean --all --yes &&\ 32 | rm -rf client ~/.npm /boot/.cache/pip ~/.cache/pip 33 | -------------------------------------------------------------------------------- /.circleci/config.yml: -------------------------------------------------------------------------------- 1 | version: 2 2 | jobs: 3 | build: 4 | working_directory: ~/ws 5 | docker: 6 | - image: circleci/node:8 7 | - image: docker:17.05.0-ce-git 8 | steps: 9 | - checkout 10 | - setup_remote_docker 11 | - restore_cache: 12 | key: dependency-cache-{{ checksum "client/package.json" }} 13 | - run: 14 | name: install npm 15 | command: | 16 | cd client 17 | npm install 18 | - save_cache: 19 | key: dependency-cache-{{ checksum "client/package.json" }} 20 | paths: 21 | - ./client/node_modules 22 | - run: 23 | name: build client 24 | command: | 25 | cd client 26 | npm run build 27 | - run: 28 | name: build image 29 | command: | 30 | docker build -t seq2seq-vis . 31 | workflows: 32 | version: 2 33 | build-n-deploy: 34 | jobs: 35 | - build: 36 | filters: 37 | tags: 38 | only: /^v.*/ 39 | -------------------------------------------------------------------------------- /client/ts/etc/SVGplus.ts: -------------------------------------------------------------------------------- 1 | import * as d3 from "d3" 2 | 3 | /** 4 | * Created by hen on 5/15/17. 
class ONMTLuaModelAPI(AbstractModelAPI):
    """Model API that forwards translation requests to an HTTP endpoint."""

    def __init__(self, url: str = "http://127.0.0.1:7784/translator/translate"):
        """Remember the URL that ``translate`` posts to."""
        self.url = url

    def translate(self, in_text: str, partial_decode: str = None, k: int = 1, attn: dict = None):
        """Translate ``in_text`` via the remote endpoint.

        Only the single best translation is returned: ``k > 1`` and
        ``partial_decode`` merely trigger a warning, and ``attn`` is not used.

        :param in_text: the input text, sent as the ``src`` field
        :param partial_decode: not supported, logs a warning when set
        :param k: not supported beyond 1, logs a warning when larger
        :param attn: ignored
        :return: dict with ``encoder`` (list of ``{'token': ...}``),
            ``decoder`` (one-element list holding such a token list) and
            ``attn`` (one-element list holding the returned attention rows)
        """
        if k > 1:
            logging.warning('This version of the API only supports top 1 prediction. Sorry..')

        if partial_decode:
            logging.warning('This version of the API does not support partial decode.. ')

        payload = json.dumps([{"src": in_text}])
        response = requests.post(self.url, data=payload)

        # Example response body:
        # [[{'src': 'Hello World', 'tgt': 'Hallo Welt',
        #    'pred_score': -0.1768690943718,
        #    'attn': [[0.62342292070389, 0.37657704949379],
        #             [0.16017833352089, 0.83982169628143]],
        #    'n_best': 1}]]
        best = response.json()[0][0]

        encoder_tokens = [{'token': tok} for tok in best['src'].split()]
        decoder_tokens = [{'token': tok} for tok in best['tgt'].split()]

        return {
            'encoder': encoder_tokens,
            'decoder': [decoder_tokens],
            'attn': [list(best['attn'])],
        }

    def n_closest_tokens(self, token: str, n: int = 10):
        """Not provided by this backend; always returns ``None``."""
        return None
def main():
    """Build a flat inner-product faiss index from an HDF5 states file.

    Reads the dataset ``opt.data`` from the file ``opt.states``, unpacks its
    shape as (sequences, tokens per sequence, hidden size), adds every state
    vector to a ``faiss.IndexFlatIP`` in batches of ``opt.stepsize``
    sequences and writes the finished index to ``opt.output``.
    """
    # The context manager closes the file even if reading or indexing fails.
    with h5py.File(opt.states, "r") as f:
        data = f[opt.data]
        seqs, slens, hid = data.shape

        print("Processing {} Sequences".format(seqs))
        print("with {} tokens each".format(slens))
        print("and {} states".format(hid))

        # Initialize a new index
        index = faiss.IndexFlatIP(hid)
        # Fill it. The loop must run up to ``seqs``: stopping at
        # ``seqs - opt.stepsize`` left the final batch(es) out of the index
        # (and indexed nothing at all when seqs <= stepsize). The slice is
        # clipped at the end of the dataset, so a short last batch is fine.
        for ix in tqdm(range(0, seqs, opt.stepsize)):
            batch = data[ix:ix + opt.stepsize]
            # flatten (sequences, tokens, hidden) -> (sequences * tokens, hidden)
            cdata = np.array(batch.reshape(-1, hid), dtype="float32")
            index.add(cdata)

    faiss.write_index(index, opt.output)
class AbstractModelAPI:
    """Interface that every translation-model backend has to implement."""

    def translate(self, in_text: str, partial_decode: str = None, k: int = 10, attn: dict = None):
        """Translate ``in_text`` and report what the model did.

        Implementations return a dict of the form::

            {encoder: [{<1>}, ...], decoder: [[{<1>}, ...]], attn: [[[<2>], ...]]}

        * ``encoder`` -- one entry per input token, in order
        * ``decoder`` -- outer list: the sorted top ``k`` translations;
          inner list: one entry per output token, in order
        * ``attn`` -- attention values, k x dec_length x encoder_length

        ``<1>`` is ``{token: str, state: [], embed: []}`` for each
        encoder/decoder token; ``<2>`` holds, for each of the top k and for
        each decoder token, one attention value per encoder word.

        :param in_text: the input text
        :param partial_decode: a partial decoder string that defines the
            left side of a decoding sequence
        :param k: k for top k translations
        :param attn: use alternative attention values
        :return: see dictionary above
        """
        message = ".. has to be implemented"
        raise NotImplementedError(message)

    def n_closest_tokens(self, token: str, n: int = 10):
        """Return the ``n`` tokens closest to ``token``.

        Closeness is the cosine distance between embeddings.

        :param token: the token
        :param n: how many closest tokens to return
        :return: list/array of tokens
        """
        message = ".. has to be implemented"
        raise NotImplementedError(message)
"webpack": "^4.6.0", 50 | "webpack-cli": "^2.0.14", 51 | "webpack-dev-server": "^3.1.6" 52 | } 53 | } 54 | -------------------------------------------------------------------------------- /client/css/modalDialog.scss: -------------------------------------------------------------------------------- 1 | .modalDialog { 2 | position: fixed; 3 | top: 10px; 4 | left: 100px; 5 | width: 200px; 6 | background-color: white; 7 | padding: 10px; 8 | box-shadow: 5px 5px 5px grey; 9 | 10 | //button { 11 | // font-family: inherit; 12 | // font-size: smaller; 13 | // border-style: none; 14 | // padding: 2px 5px; 15 | //} 16 | // 17 | //button:focus { 18 | // outline: none; 19 | //} 20 | 21 | .uiButton { 22 | border-radius: 3px; 23 | background-color: rgba(179, 179, 179, .3); 24 | } 25 | 26 | .uiButton:hover { 27 | background-color: rgba(179, 179, 179, 1); 28 | } 29 | 30 | .uiRow { 31 | margin-bottom: 10px; 32 | } 33 | 34 | .uiCancel{ 35 | background-color: #9c0012; 36 | color: lightgrey; 37 | } 38 | 39 | } 40 | 41 | .modalDialog .textInput { 42 | box-sizing: border-box; 43 | outline: #b3b3b3 solid 1px; 44 | border: none; 45 | width: 100%; 46 | padding: 5px; 47 | -webkit-transition: 0.3s ease all; 48 | -moz-transition: 0.3s ease all; 49 | transition: 0.3s ease all; 50 | } 51 | 52 | .modalDialog .textInput:hover { 53 | background-color: #eee; 54 | } 55 | 56 | .modalDialog .selectInput { 57 | background: url("data:image/svg+xml;utf8,"); 58 | background-color: #eee; 59 | background-repeat: no-repeat; 60 | background-position: right 5px top 5px; 61 | background-size: 15px 15px; 62 | padding: 5px; 63 | width: 100%; 64 | text-align: center; 65 | border-radius: 3px; 66 | -webkit-border-radius: 3px; 67 | -webkit-appearance: none; 68 | border: 0; 69 | outline: 0; 70 | -webkit-transition: 0.3s ease all; 71 | -moz-transition: 0.3s ease all; 72 | transition: 0.3s ease all; 73 | } 74 | 75 | .modalDialog .select:hover { 76 | background-color: #ddd; 77 | } 78 | 79 | .inactivator { 80 | position: 
fixed; 81 | top: 0; 82 | left: 0; 83 | width: 100%; 84 | height: 100%; 85 | background: white; 86 | } 87 | 88 | -------------------------------------------------------------------------------- /client/ts/vis/BarList.ts: -------------------------------------------------------------------------------- 1 | import {VComponent} from "./VisualComponent"; 2 | import * as d3 from "d3"; 3 | 4 | export interface BarListData { 5 | extent: [number, number], 6 | values: number[] 7 | } 8 | 9 | 10 | export class BarList extends VComponent { 11 | 12 | css_name = 'barlist'; 13 | 14 | static events = {}; 15 | 16 | options = { 17 | pos: {x: 0, y: 0}, 18 | width: 90, 19 | bar_height: 20, 20 | css_class_main: 'bar_list_vis', 21 | css_bar: 'bar' 22 | }; 23 | 24 | _current= { 25 | hidden: false, 26 | xScale:d3.scaleLinear() 27 | } 28 | 29 | 30 | constructor(d3Parent, eventHandler, options:{} = {}) { 31 | super(d3Parent, eventHandler); 32 | this.superInit(options, false) 33 | } 34 | 35 | _init() { 36 | } 37 | 38 | _wrangle(data: BarListData): BarListData { 39 | const cur = this._current; 40 | const op = this.options; 41 | 42 | 43 | const ex = data.extent; 44 | cur.xScale = 45 | d3.scaleLinear() 46 | .domain(ex) 47 | .range([op.width, 0]); 48 | 49 | const barValues = data.values; 50 | 51 | this.parent.attrs({ 52 | width: op.width, 53 | height: barValues.length * op.bar_height 54 | }); 55 | 56 | return data; 57 | } 58 | 59 | _render(rData: BarListData) { 60 | 61 | const op = this.options; 62 | const cur = this._current; 63 | 64 | const bars = this.base.selectAll(`.${op.css_bar}`).data(rData.values); 65 | bars.exit().remove(); 66 | 67 | const barsEnter = bars.enter().append('rect').attr('class', op.css_bar); 68 | 69 | 70 | barsEnter.merge(bars).attrs({ 71 | x: d => op.width - cur.xScale(d), 72 | y: (_, i) => i * op.bar_height, 73 | height: op.bar_height - 2, 74 | width: d => cur.xScale(d) 75 | }) 76 | 77 | } 78 | 79 | 80 | get xScale() { 81 | return this._current.xScale; 82 | } 83 | 84 | } 
85 | -------------------------------------------------------------------------------- /environment.yml: -------------------------------------------------------------------------------- 1 | name: s2sv 2 | channels: 3 | - pytorch 4 | - conda-forge 5 | - defaults 6 | dependencies: 7 | - asn1crypto=0.24.0=py36_0 8 | - ca-certificates=2018.4.16=0 9 | - cffi=1.11.5=py36_0 10 | - chardet=3.0.4=py36_0 11 | - clickclick=1.2.2=py36_0 12 | - connexion=1.4=py36_0 13 | - cryptography=2.2.1=py36_0 14 | - idna=2.6=py36_1 15 | - jsonschema=2.6.0=py36_1 16 | - pycparser=2.18=py36_0 17 | - pyopenssl=17.5.0=py36_1 18 | - pysocks=1.6.8=py36_1 19 | - python-annoy=1.11.5=py36_0 20 | - pyyaml=3.12=py36_1 21 | - requests=2.18.4=py36_1 22 | - swagger-spec-validator=2.1.0=py36_0 23 | - typing=3.6.4=py36_0 24 | - urllib3=1.22=py36_0 25 | - yaml=0.1.7=0 26 | - certifi=2018.4.16=py36_0 27 | - click=6.7=py36h5253387_0 28 | - cudatoolkit=8.0=3 29 | - flask=0.12.2=py36hb24657c_0 30 | - h5py=2.7.1=py36h3585f63_0 31 | - hdf5=1.10.1=h9caa474_1 32 | - inflection=0.3.1=py36_0 33 | - intel-openmp=2018.0.0=8 34 | - itsdangerous=0.24=py36h93cc618_1 35 | - jinja2=2.10=py36ha16c418_0 36 | - libedit=3.1=heed3624_0 37 | - libffi=3.2.1=hd88cf55_4 38 | - libgcc-ng=7.2.0=hdf63c60_3 39 | - libgfortran-ng=7.2.0=hdf63c60_3 40 | - libstdcxx-ng=7.2.0=hdf63c60_3 41 | - markupsafe=1.0=py36hd9260cd_1 42 | - mkl=2018.0.2=1 43 | - mkl_fft=1.0.1=py36h3010b51_0 44 | - mkl_random=1.0.1=py36h629b387_0 45 | - ncurses=6.0=h9df7e31_2 46 | - numpy=1.14.2=py36hdbf6ddf_1 47 | - openssl=1.0.2o=h20670df_0 48 | - pip=9.0.3=py36_0 49 | - python=3.6.5=hc3d631a_1 50 | - readline=7.0=ha6073c6_4 51 | - scikit-learn=0.19.1=py36h7aa7ec6_0 52 | - scipy=1.0.1=py36hfc37229_0 53 | - setuptools=39.0.1=py36_0 54 | - six=1.11.0=py36h372c433_1 55 | - sqlite=3.23.1=he433501_0 56 | - tk=8.6.7=hc745277_3 57 | - werkzeug=0.14.1=py36_0 58 | - wheel=0.31.0=py36_0 59 | - xz=5.2.3=h5e939de_4 60 | - zlib=1.2.11=ha838bed_2 61 | - 
faiss-cpu=1.2.1=py36_cuda0.0_2 62 | - pytorch=0.3.1=py36_cuda8.0.61_cudnn7.1.2_3 63 | - pip: 64 | - annoy==1.11.5 65 | - faiss==0.1 66 | - torch==0.3.1.post3 67 | - torchtext==0.2.3 68 | - tqdm==4.23.1 69 | 70 | -------------------------------------------------------------------------------- /client/ts/etc/ModalDialog.ts: -------------------------------------------------------------------------------- 1 | import * as d3 from "d3"; 2 | 3 | import '../../css/modalDialog.scss' 4 | import {D3Sel} from "./LocalTypes"; 5 | import {SimpleEventHandler} from "./SimpleEventHandler"; 6 | 7 | 8 | export default class ModalDialog { 9 | 10 | static get events() { 11 | return { 12 | modalDialogCanceled: 'modalDialogCanceled', 13 | modalDialogSubmitted: 'modalDialogSubmitted' 14 | } 15 | } 16 | 17 | static open(rootNode: D3Sel, eventHandler: SimpleEventHandler, width = 300) { 18 | 19 | // Bind the buttons 20 | 21 | rootNode.selectAll('.uiCancel') 22 | .on('click', () => { 23 | ModalDialog.close(rootNode); 24 | eventHandler.trigger(ModalDialog.events.modalDialogCanceled, rootNode); 25 | }); 26 | 27 | 28 | rootNode.selectAll('.uiSubmit') 29 | .on('click', () => { 30 | eventHandler.trigger(ModalDialog.events.modalDialogSubmitted, rootNode); 31 | }); 32 | 33 | // Make it appear nicely :) 34 | 35 | d3.select('body').append('div') 36 | .attr('class', 'inactivator') 37 | .styles({opacity: 0}) 38 | .transition().style('opacity', 0.5) 39 | 40 | rootNode.attr('hidden', null) 41 | const dialogHeight = rootNode.node().clientHeight; 42 | rootNode 43 | .raise() 44 | .style('width', `${width}px`) 45 | .style('opacity', 1) 46 | .style('top', `${-dialogHeight}px`) 47 | .style('left', `${(window.innerWidth - width) / 2}px`) 48 | 49 | 50 | rootNode.transition() 51 | .style('top', '5px') 52 | 53 | 54 | } 55 | 56 | static close(rootNode: D3Sel) { 57 | d3.selectAll('.inactivator').remove(); 58 | 59 | const dialogHeight = rootNode.node().clientHeight; 60 | 61 | rootNode.transition() 62 | // 
.duration(2000) 63 | .style('top', `${-dialogHeight}px`) 64 | .style('opacity', 0) 65 | .on('end', function () { 66 | d3.select(this).attr('hidden', true) 67 | }) 68 | } 69 | 70 | 71 | } -------------------------------------------------------------------------------- /.gitattributes: -------------------------------------------------------------------------------- 1 | # These settings are for any web project 2 | 3 | # Handle line endings automatically for files detected as text 4 | # and leave all files detected as binary untouched. 5 | * text=auto 6 | 7 | # 8 | # The above will handle all files NOT found below 9 | # 10 | 11 | # 12 | ## These files are text and should be normalized (Convert crlf => lf) 13 | # 14 | 15 | # source code 16 | *.php text 17 | *.css text 18 | *.sass text 19 | *.scss text 20 | *.less text 21 | *.styl text 22 | *.js text 23 | *.ts text 24 | *.coffee text 25 | *.json text 26 | *.htm text 27 | *.html text 28 | *.xml text 29 | *.txt text 30 | *.ini text 31 | *.inc text 32 | *.pl text 33 | *.rb text 34 | *.py text 35 | *.scm text 36 | *.sql text 37 | *.sh text eol=lf 38 | *.bat text 39 | 40 | # templates 41 | *.hbt text 42 | *.jade text 43 | *.haml text 44 | *.hbs text 45 | *.dot text 46 | *.tmpl text 47 | *.phtml text 48 | 49 | # server config 50 | .htaccess text 51 | 52 | # git config 53 | .gitattributes text 54 | .gitignore text 55 | 56 | # code analysis config 57 | .jshintrc text 58 | .jscsrc text 59 | .jshintignore text 60 | .csslintrc text 61 | 62 | # misc config 63 | *.yaml text 64 | *.yml text 65 | .editorconfig text 66 | 67 | # build config 68 | *.npmignore text 69 | *.bowerrc text 70 | Dockerfile text eol=lf 71 | 72 | # Heroku 73 | Procfile text 74 | .slugignore text 75 | 76 | # Documentation 77 | *.md text 78 | LICENSE text 79 | AUTHORS text 80 | 81 | 82 | # 83 | ## These files are binary and should be left untouched 84 | # 85 | 86 | # (binary is a macro for -text -diff) 87 | *.png binary 88 | *.jpg binary 89 | *.jpeg binary 90 | 
*.gif binary 91 | *.ico binary 92 | *.mov binary 93 | *.mp4 binary 94 | *.mp3 binary 95 | *.flv binary 96 | *.fla binary 97 | *.swf binary 98 | *.gz binary 99 | *.zip binary 100 | *.7z binary 101 | *.ttf binary 102 | *.pyc binary 103 | *.pdf binary 104 | 105 | # Source files 106 | # ============ 107 | *.pxd text 108 | *.py text 109 | *.py3 text 110 | *.pyw text 111 | *.pyx text 112 | *.sh text eol=lf 113 | *.json text 114 | 115 | # Binary files 116 | # ============ 117 | *.db binary 118 | *.p binary 119 | *.pkl binary 120 | *.pyc binary 121 | *.pyd binary 122 | *.pyo binary 123 | 124 | # Note: .db, .p, and .pkl files are associated 125 | # with the python modules ``pickle``, ``dbm.*``, 126 | # ``shelve``, ``marshal``, ``anydbm``, & ``bsddb`` 127 | # (among others). 128 | -------------------------------------------------------------------------------- /index/annoyVectorIndex.py: -------------------------------------------------------------------------------- 1 | from annoy import AnnoyIndex 2 | 3 | 4 | class AnnoyVectorIndex: 5 | 6 | def __init__(self, file_name, dim_vector=500): 7 | self.u = AnnoyIndex(dim_vector) 8 | self.u.load(file_name) 9 | 10 | def get_closest(self, ix, k=10, ignore_same_tgt=False, 11 | include_distances=False, use_vectors=False): 12 | if ignore_same_tgt: 13 | interval_min = ix // 55 * 55 14 | if use_vectors: 15 | candidates = self.u.get_nns_by_vector(ix, k + 55, 16 | search_k=100000, 17 | include_distances=include_distances) 18 | else: 19 | candidates = self.u.get_nns_by_item(ix, k + 55, search_k=100000, 20 | include_distances=include_distances) 21 | if include_distances: 22 | return [k for k in zip(*candidates) 23 | if not interval_min <= k[0] <= interval_min + 55][:k] 24 | else: 25 | return [k for k in candidates 26 | if not interval_min <= k <= interval_min + 55][:k] 27 | else: 28 | if use_vectors: 29 | return list( 30 | zip(*self.u.get_nns_by_vector(ix, k, search_k=100000, 31 | include_distances=include_distances))) 32 | 33 | else: 34 | 
return list(zip(*self.u.get_nns_by_item(ix, k, search_k=100000, 35 | include_distances=include_distances))) 36 | 37 | def get_closest_x(self, ixs, k=10, ignore_same_tgt=False, 38 | include_distances=False, use_vectors=False): 39 | res = [] 40 | for ix in ixs: 41 | res.append( 42 | self.get_closest(ix, k, ignore_same_tgt, include_distances, 43 | use_vectors)) 44 | return res 45 | 46 | def get_details(self, ixs): 47 | res = [] 48 | for ix in ixs: 49 | res.append({'index': ix, 50 | 'v': self.u.get_item_vector(ix), 51 | 'pos': self.search_to_sentence_index(ix)}) 52 | 53 | return res 54 | 55 | def get_vectors(self, ixs): 56 | return map(lambda x: self.u.get_item_vector(x), ixs) 57 | 58 | def get_vector(self, ix): 59 | return self.u.get_item_vector(ix) 60 | 61 | def search_to_sentence_index(self, index): 62 | return index // 55, index % 55 63 | 64 | def sentence_to_search_index(self, sentence, pos_in_sent): 65 | return sentence * 55 + pos_in_sent 66 | -------------------------------------------------------------------------------- /client/css/main.scss: -------------------------------------------------------------------------------- 1 | @import "global"; 2 | @import "vis"; 3 | 4 | @import "~font-awesome/css/font-awesome.min.css"; 5 | @import "../fonts/ssp.css"; 6 | 7 | body { 8 | font-family: 'Source Sans Pro', sans-serif; 9 | font-weight: 300; 10 | font-size: 10pt; 11 | margin: 0; 12 | background: $main-color; 13 | } 14 | 15 | input[type="text"] { 16 | font-size: 12pt; 17 | font-family: 'Source Sans Pro', sans-serif; 18 | font-weight: 600; 19 | margin-top: 3px; 20 | display: inline-block; 21 | border: 1px solid #ccc; 22 | //box-shadow: inset 0 1px 3px #ddd; 23 | border-radius: 4px; 24 | -webkit-box-sizing: border-box; 25 | -moz-box-sizing: border-box; 26 | box-sizing: border-box; 27 | padding: 8px 6px; 28 | 29 | 30 | } 31 | 32 | ::placeholder { 33 | color: #b3b3b3; 34 | } 35 | 36 | 37 | 38 | ul.topnav { 39 | list-style-type: none; 40 | margin: 0; 41 | padding: 0; 42 | 
overflow: hidden; 43 | background-color: #d4d4d4; 44 | 45 | li { 46 | float: left; 47 | 48 | a { 49 | display: block; 50 | color: #4a4a4a; 51 | text-align: center; 52 | padding: 4px 6px; 53 | text-decoration: none; 54 | } 55 | a:hover:not(.active) { 56 | background-color: #b3deff; 57 | } 58 | 59 | a.active { 60 | background-color: #80ccff; 61 | } 62 | 63 | &.right { 64 | float: right; 65 | } 66 | 67 | } 68 | 69 | } 70 | 71 | @media screen and (max-width: 400px) { 72 | ul.topnav li.right, 73 | ul.topnav li { 74 | float: none; 75 | } 76 | } 77 | 78 | .full_width_input { 79 | width: 100%; 80 | //margin: 0 5px; 81 | } 82 | 83 | .header { 84 | padding: 5px; 85 | border-bottom: solid 1px lightgray; 86 | } 87 | 88 | .mainPanel { 89 | display: table; 90 | width: 100%; 91 | } 92 | 93 | .column { 94 | display: table-cell; 95 | vertical-align: top; 96 | padding-top: 10px; 97 | } 98 | 99 | .setup { 100 | text-align: right; 101 | } 102 | 103 | svg { 104 | display: block; 105 | } 106 | 107 | .scrollable { 108 | width: 100%; 109 | overflow: auto; 110 | } 111 | 112 | select { 113 | background: url("data:image/svg+xml;utf8,"); 114 | background-color: lightgrey; 115 | background-repeat: no-repeat; 116 | background-position: right 5px top 9px; 117 | background-size: 7px 7px; 118 | padding: 12px; 119 | width: auto; 120 | //font-size:16px; 121 | //font-weight: bold; 122 | //text-align:center; 123 | //text-shadow:0 -1px 0 rgba(0, 0, 0, 0.25); 124 | border-radius: 5px; 125 | -webkit-border-radius: 5px; 126 | -webkit-appearance: none; 127 | border: 0; 128 | outline: 0; 129 | } -------------------------------------------------------------------------------- /client/ts/main.ts: -------------------------------------------------------------------------------- 1 | import * as d3 from "d3"; 2 | import "d3-selection-multi"; 3 | 4 | import {PanelController} from "./controller/PanelController"; 5 | import {S2SApi} from "./api/S2SApi"; 6 | import '../css/main.scss' 7 | import URLHandler from 
"./etc/URLHandler"; 8 | import {Translation} from "./api/Translation"; 9 | 10 | import "!file-loader?name=index.html!../index.html"; 11 | import "!file-loader?name=s2s_logo.png!../assets/s2s_logo.png"; 12 | 13 | 14 | declare const __VERSION__:string; 15 | declare const __BUILDID__:string; 16 | 17 | window.onload = () => { 18 | const panelCtrl = new PanelController(); 19 | 20 | 21 | // --- EVENTS --- 22 | const translate = (value) => { 23 | S2SApi.translate({input: value, neighbors: []}) 24 | .then((data: string) => { 25 | const raw_data = JSON.parse(data); 26 | panelCtrl.clearCompare(); 27 | panelCtrl.update(new Translation(raw_data)); 28 | panelCtrl.cleanPanels(); 29 | 30 | 31 | (document.querySelector('#spinner')).style.display = 'none'; 32 | }) 33 | .catch((error: Error) => console.log(error, "--- error")); 34 | }; 35 | 36 | const updateAllVis = () => { 37 | (document.querySelector('#spinner')).style.display = null; 38 | const value = ( d3.select('#query_input').node()) 39 | .value.trim(); 40 | 41 | URLHandler.setURLParam('in', value, false); 42 | translate(value); 43 | }; 44 | // const updateDebounced = _.debounce(updateAllVis, 1000); 45 | 46 | 47 | /* **************** 48 | * URL param 'in' triggers query 49 | * *****************/ 50 | 51 | 52 | d3.select('#query_input') 53 | .on('keypress', () => { 54 | const keycode = d3.event.keyCode; 55 | if (d3.event instanceof KeyboardEvent 56 | && (keycode === 13) //|| keycode === 32 57 | ) { 58 | // updateDebounced(); 59 | updateAllVis(); 60 | } 61 | }); 62 | 63 | 64 | // TODO: needed ? 
65 | // function windowResize() { 66 | // const width = window.innerWidth; 67 | // const height = window.innerHeight 68 | // - (document.querySelector("#title")).offsetHeight 69 | // - (document.querySelector("#ui")).offsetHeight - 5; 70 | // globalEvents.trigger('svg-resize', {width, height}) 71 | // } 72 | // window.addEventListener('resize', windowResize); 73 | // windowResize(); 74 | 75 | 76 | S2SApi.project_info(null).then((data) => { 77 | 78 | data = JSON.parse(data); 79 | 80 | panelCtrl.updateProjectInfo(data); 81 | 82 | const input_from_url = URLHandler.parameters['in']; 83 | if (input_from_url) { 84 | ( d3.select('#query_input').node()) 85 | .value = input_from_url; 86 | translate(input_from_url); 87 | } 88 | 89 | 90 | }) 91 | 92 | 93 | }; 94 | -------------------------------------------------------------------------------- /index/faissVectorIndex.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import sys 3 | 4 | # sys.path.append('faiss') 5 | import faiss 6 | 7 | 8 | class FaissVectorIndex: 9 | 10 | def __init__(self, file_name, dim_vector=500, sentence_max_len=50): 11 | self.u = faiss.read_index(file_name) # type: faiss.Index 12 | self.sentence_max_length = sentence_max_len 13 | 14 | def get_closest(self, ix, k=10, ignore_same_tgt=False, 15 | include_distances=False, use_vectors=False): 16 | """ 17 | :param ix: vector or index ID 18 | :param k: number of nearest neighbors 19 | :param ignore_same_tgt: 20 | :param include_distances: 21 | :param use_vectors: 22 | :return: list [(id, distance),...] 
23 | """ 24 | 25 | ix_conv = np.array([ix], dtype='float32') 26 | 27 | if use_vectors: 28 | candidates = self.u.search(ix_conv, k) 29 | print(candidates) 30 | else: 31 | print('not possible') 32 | candidates = [] 33 | 34 | if ignore_same_tgt: 35 | print('not possible') 36 | return [] 37 | # interval_min = ix // 55 * 55 38 | # 39 | # if include_distances: 40 | # return [k for k in zip(candidates[1][0], candidates[0][0]) 41 | # if not interval_min <= k[0] <= interval_min + 55][:k] 42 | # else: 43 | # return [k for k in candidates[1][0] 44 | # if not interval_min <= k <= interval_min + 55][:k] 45 | else: 46 | if include_distances: 47 | return zip(candidates[1][0].tolist(), candidates[0][0].tolist()) 48 | else: 49 | return candidates[1][0].tolist() 50 | 51 | def get_closest_x(self, ixs, k=10, ignore_same_tgt=False, 52 | include_distances=False, use_vectors=False): 53 | res = [] 54 | 55 | ix_conv = np.array(ixs, dtype='float32') 56 | 57 | dists, inds = self.u.search(ix_conv, k) 58 | 59 | for i in range(dists.shape[0]): 60 | res.append(zip(inds[i].tolist(), dists[i].tolist())) 61 | 62 | # for ix in ixs: 63 | # res.append( 64 | # self.get_closest(ix, k, ignore_same_tgt, include_distances, 65 | # use_vectors)) 66 | return res 67 | 68 | def get_details(self, ixs): 69 | res = [] 70 | for ix in ixs: 71 | res.append({'index': ix, 72 | 'v': self.u.get_item_vector(ix), 73 | 'pos': self.search_to_sentence_index(ix)}) 74 | 75 | return res 76 | 77 | def get_vectors(self, ixs): 78 | return map(lambda x: self.u.reconstruct(x), ixs) 79 | 80 | def get_vector(self, ix): 81 | ix_c = int(ix) 82 | return self.u.reconstruct(ix_c) 83 | 84 | def search_to_sentence_index(self, index): 85 | return index // self.sentence_max_length, index % self.sentence_max_length 86 | 87 | def sentence_to_search_index(self, sentence, pos_in_sent): 88 | return sentence * self.sentence_max_length + pos_in_sent 89 | -------------------------------------------------------------------------------- 
/client/ts/vis/StateVis.ts: -------------------------------------------------------------------------------- 1 | import {VComponent} from "./VisualComponent"; 2 | import * as d3 from "d3"; 3 | import {flattenDeep} from "lodash"; 4 | import {SimpleEventHandler} from "../etc/SimpleEventHandler"; 5 | import {D3Sel} from "../etc/LocalTypes"; 6 | import {SVG} from "../etc/SVGplus"; 7 | 8 | type StateVisRender = { states: number[][], yDomain: number[] } 9 | 10 | export interface StateVisData { 11 | states: number[][] 12 | } 13 | 14 | 15 | export class StateVis extends VComponent { 16 | 17 | css_name = 'statevis'; 18 | 19 | layers: { main: D3Sel, axis: D3Sel }; 20 | 21 | static events = {}; 22 | 23 | options = { 24 | pos: {x: 0, y: 0}, 25 | cell_width: 100, 26 | height: 50, 27 | css_class_main: 'state_vis', 28 | css_line: 'state_line', 29 | x_offset: 3, 30 | hidden: true, 31 | // data_access: d => d.encoder.map(e => e.state) 32 | }; 33 | 34 | constructor(d3Parent: D3Sel, eventHandler: SimpleEventHandler, options = {}) { 35 | super(d3Parent, eventHandler); 36 | this.superInit(options, false) 37 | } 38 | 39 | _init() { 40 | if (this.options.hidden) this.hideView(); 41 | this.layers.main = SVG.group(this.base, 'main'); 42 | this.layers.axis = SVG.group(this.base, 'axis'); 43 | } 44 | 45 | _wrangle(data: StateVisData): StateVisRender { 46 | 47 | const op = this.options; 48 | 49 | const orig_states = data.states; 50 | const states = d3.transpose(orig_states); 51 | const yDomain = d3.extent(flattenDeep(states)); 52 | 53 | this.parent.attrs({ 54 | width: (orig_states.length * op.cell_width + (op.x_offset + 5 + 20)), 55 | height: op.height 56 | }); 57 | 58 | 59 | return {states, yDomain} 60 | } 61 | 62 | _render(renderData: StateVisRender) { 63 | 64 | 65 | const op = this.options; 66 | 67 | const x = (i) => op.x_offset + Math.round((i + .5) * op.cell_width); 68 | 69 | const y = d3.scalePow().exponent(.5).domain(renderData.yDomain).range([op.height, 0]); 70 | 71 | const line = 
d3.line() 72 | .x((_, i) => x(i)) 73 | .y(d => y(d)); 74 | 75 | 76 | const stateLine = this.layers.main.selectAll(`.${op.css_line}`).data(renderData.states); 77 | stateLine.exit().remove(); 78 | 79 | const stateLineEnter = stateLine.enter().append('path').attr('class', op.css_line); 80 | 81 | stateLineEnter.merge(stateLine).attrs({ 82 | 'd': line 83 | }); 84 | 85 | 86 | if (renderData.states.length > 0) { 87 | const yAxis = d3.axisLeft(y).ticks(7); 88 | this.layers.axis.classed("axis state_axis", true) 89 | .call(yAxis).selectAll('*'); 90 | this.layers.axis.attrs({ 91 | transform: `translate(${op.x_offset + op.cell_width * .5 - 3},0)` 92 | }) 93 | } else { 94 | this.layers.axis.selectAll("*").remove(); 95 | } 96 | 97 | 98 | } 99 | 100 | } 101 | -------------------------------------------------------------------------------- /client/ts/api/S2SApi.ts: -------------------------------------------------------------------------------- 1 | import {Networking} from "../etc/Networking"; 2 | 3 | export type TrainDataIndexResponse = { 4 | ids: number[], 5 | loc: string, 6 | res: { 7 | attn: number[][], src: string, tgt: string, 8 | src_words: string[], tgt_words: string[], 9 | tokenId: number, sentId: number 10 | }[] 11 | 12 | } 13 | 14 | 15 | export class S2SApi { 16 | 17 | 18 | static project_info(project_id) { 19 | const request = Networking.ajax_request('../api/project_info'); 20 | 21 | 22 | let payload = new Map(); 23 | if (project_id) { 24 | payload = new Map( 25 | [ 26 | ['project_id', project_id] 27 | ]); 28 | } 29 | 30 | return request.get(payload) 31 | 32 | 33 | } 34 | 35 | static translate({ 36 | input, partial = [], force_attn = <{ [key: number]: number }>{}, 37 | neighbors: neighbors = ['decoder', 'encoder'] //, 'context' 38 | }) { 39 | const request = Networking.ajax_request('../api/translate'); 40 | 41 | let force_attn_array = null; 42 | for (const key in force_attn) { 43 | if (!force_attn_array) force_attn_array = []; 44 | force_attn_array.push(key); 45 | 
force_attn_array.push(force_attn[key]); 46 | } 47 | 48 | const payload = new Map([['in', input], 49 | ['neighbors', neighbors], 50 | ['partial', partial], 51 | ['force_attn', force_attn_array] 52 | ]); 53 | 54 | return request.get(payload) 55 | } 56 | 57 | static translate_compare({ 58 | input, compare, 59 | neighbors = ['decoder', 'encoder'] //, 'context' 60 | }) { 61 | const request = Networking.ajax_request('../api/translate_compare'); 62 | const payload = new Map([ 63 | ['in', input], 64 | ['compare', compare], 65 | ['neighbors', neighbors]]); 66 | 67 | return request.get(payload) 68 | } 69 | 70 | static closeWords({input, limit = 50, loc = 'src'}) { 71 | const request = Networking.ajax_request('../api/close_words'); 72 | const payload = new Map([ 73 | ['in', input], 74 | ['loc', loc], 75 | ['limit', limit]]); 76 | 77 | return request 78 | .get(payload) 79 | } 80 | 81 | // static compareTranslation({pivot, compare}) { 82 | // const request = Networking.ajax_request('/api/compare_translation'); 83 | // const payload = new Map([ 84 | // ['in', pivot], 85 | // ['compare', compare.join('|')]]); 86 | // 87 | // return request 88 | // .get(payload) 89 | // } 90 | 91 | static trainDataIndices(indices: number[], loc: string) { 92 | //http://0.0.0.0:8080/api/train_data_for_index?indices=123%2C333&loc=src 93 | const request = Networking.ajax_request('../api/train_data_for_index'); 94 | const payload = new Map([['indices', indices.join(',')], 95 | ['loc', loc]]); 96 | 97 | return request 98 | .get(payload) 99 | 100 | 101 | } 102 | 103 | } 104 | 105 | 106 | -------------------------------------------------------------------------------- /client/ts/etc/Networking.ts: -------------------------------------------------------------------------------- 1 | /** 2 | * Created by hen on 5/15/17. 3 | */ 4 | export class Networking { 5 | 6 | /** 7 | * Generates a Ajax Request object. 
8 | * @param {string} url - the base url 9 | * @returns {{get: (function(*=)), post: (function(*=)), put: (function(*=)), delete: (function(*=))}} 10 | * the ajax object that can call get, post, put, delete on the url 11 | */ 12 | static ajax_request(url): { get, post, put, delete } { 13 | 14 | /* Adapted from: https://developer.mozilla.org/en-US/docs/Web/JavaScript/Reference/Global_Objects/Promise 15 | * EXAMPLE: 16 | 17 | var mdnAPI = 'https://developer.mozilla.org/en-US/search.json'; 18 | var payload = { 19 | 'topic' : 'js', 20 | 'q' : 'Promise' 21 | }; 22 | 23 | var callback = { 24 | success: function(data) { 25 | console.log(1, 'success', JSON.parse(data)); 26 | }, 27 | error: function(data) { 28 | console.log(2, 'error', JSON.parse(data)); 29 | } 30 | }; 31 | 32 | // Executes the method call 33 | $http(mdnAPI) 34 | .get(payload) 35 | .then(callback.success) 36 | .catch(callback.error); 37 | 38 | // Executes the method call but an alternative way (1) to handle Promise Reject case 39 | $http(mdnAPI) 40 | .get(payload) 41 | .then(callback.success, callback.error); 42 | 43 | */ 44 | 45 | // Method that performs the ajax request 46 | const ajax = (method, _url, args) => { 47 | 48 | // Creating a promise 49 | return new Promise((resolve, reject) => { 50 | 51 | // Instantiates the XMLHttpRequest 52 | const client = new XMLHttpRequest(); 53 | let uri = _url; 54 | 55 | if (args && (method === 'POST' || method === 'GET' || method === 'PUT')) { 56 | uri += '?'; 57 | args.forEach((value, key) => { 58 | if (value) { 59 | uri += '&'; 60 | uri += encodeURIComponent(key) + '=' + encodeURIComponent(value); 61 | } 62 | 63 | } 64 | ) 65 | } 66 | 67 | // Debug: console.log('URI', uri, args); 68 | client.open(method, uri); 69 | client.send(); 70 | client.onload = function () { 71 | if (this.status >= 200 && this.status < 300) { 72 | // Performs the function "resolve" when this.status is equal to 2xx 73 | resolve(this.response); 74 | } else { 75 | // Performs the function "reject" 
when this.status is different than 2xx 76 | reject(this.statusText); 77 | } 78 | }; 79 | client.onerror = function () { 80 | reject(this.statusText); 81 | }; 82 | }); 83 | 84 | }; 85 | 86 | // Adapter pattern 87 | return { 88 | 'get': args => ajax('GET', url, args), 89 | 'post': args => ajax('POST', url, args), 90 | 'put': args => ajax('PUT', url, args), 91 | 'delete': args => ajax('DELETE', url, args) 92 | }; 93 | 94 | 95 | } 96 | } -------------------------------------------------------------------------------- /client/ts/etc/URLHandler.ts: -------------------------------------------------------------------------------- 1 | /** 2 | * Created by hen on 5/15/17. 3 | */ 4 | 5 | export default class URLHandler { 6 | 7 | static basicURL() { 8 | const url_path = window.location.pathname.split('/').slice(0, -2).join('/'); 9 | 10 | return window.location.origin + (url_path.length ? url_path : ''); 11 | } 12 | 13 | /** 14 | * Read all URL parameters into a map. 15 | * @returns {Map} the url parameters as a key-value store (ES6 map) 16 | */ 17 | static get parameters(): object { 18 | // Adapted from: http://stackoverflow.com/questions/2090551/parse-query-string-in-javascript 19 | const query = window.location.search.substring(1); 20 | const vars = query.split('&'); 21 | console.log(vars, "--- vars"); 22 | 23 | const urlParameters = {}; 24 | 25 | const isInt = x => (/^[0-9]+$/).test(x); 26 | const isFloat = x => (/^[0-9]+\.[0-9]*$/).test(x); 27 | 28 | const typeCast = val => { 29 | if (isInt(val)) { 30 | return Number.parseInt(val, 10); 31 | } else if (isFloat(val)) { 32 | return Number.parseFloat(val); 33 | } 34 | // else: 35 | return val; 36 | } 37 | 38 | 39 | vars.forEach(v => { 40 | if (v.length > 0) { 41 | const splits = v.split('='); 42 | const key = decodeURIComponent(splits[0]); 43 | let raw_value = decodeURIComponent(splits[1]); 44 | 45 | const isArray = raw_value.startsWith('..'); 46 | if (isArray) { 47 | raw_value = raw_value.slice(2); 48 | } 49 | 50 | if 
(raw_value.length < 1) { 51 | urlParameters[key] = isArray ? [] : ''; 52 | } else if (isArray) { 53 | urlParameters[key] = raw_value.split(',') 54 | .map(val => typeCast(val)); 55 | } else { 56 | urlParameters[key] = typeCast(raw_value); 57 | } 58 | } 59 | }); 60 | 61 | return urlParameters; 62 | 63 | } 64 | 65 | /** 66 | * Generates a key-value map of all URL params and replaces replaceKeys 67 | * @param updateKeys 68 | */ 69 | static enrichParameters(updateKeys) { 70 | const currentParams = URLHandler.parameters; 71 | Object.keys(updateKeys).forEach((k) => currentParams[k] = updateKeys[k]) 72 | return currentParams; 73 | } 74 | 75 | /** 76 | * Generates an URL string from a map of url parameters 77 | * @param {{}} urlParameters - the map of parameters 78 | * @returns {string} - an URI string 79 | */ 80 | static urlString(urlParameters: object) { 81 | const attr = []; 82 | Object.keys(urlParameters).forEach(k => { 83 | const v = urlParameters[k]; 84 | if (v !== undefined) { 85 | let value = v; 86 | if (Array.isArray(v)) value = '..' + v.join(','); 87 | attr.push(encodeURI(k + '=' + value)) 88 | } 89 | }); 90 | 91 | 92 | const url = window.location.pathname; 93 | let res = url.substring(url.lastIndexOf('/') + 1); 94 | if (attr.length > 0) { 95 | res += '?' 
/** Input of the beam tree view: the tree plus its maximal depth. */
export interface BeamTreeData {
    root: BeamTree,
    maxDepth: number
}

/** One node of the beam tree. */
export interface BeamTree {
    children: BeamTree[]
    name: string
    topBeam?: boolean

}

type NodeType = HierarchyPointNode;

/**
 * Draws a beam-search tree as a left-to-right node-link diagram.
 *
 * `_wrangle` sizes the parent element from the tree depth and hands the
 * root to `_render`, which rebuilds all nodes and links from scratch.
 */
export class BeamTreeVis extends VComponent {

    css_name = 'beamtreevis';

    protected options = {
        pos: {x: 0, y: 0},
        width: 600,
        height: 600
    };
    private textMeasure: SVGMeasurements;


    constructor(d3Parent, eventHandler?, options: {} = {}) {
        super(d3Parent, eventHandler);
        this.superInit(options, true)

    }

    protected _init() {

        this.textMeasure = new SVGMeasurements(this.base, 'node')

    }

    /**
     * Render the tree rooted at `renderData`.
     *
     * The layout swaps axes: the layout's `x` is used as the vertical and
     * its `y` as the horizontal screen coordinate, each shifted by 5px.
     */
    protected _render(renderData: BeamTree): void {
        const op = this.options;

        const root = d3.hierarchy(renderData, d => d.children);

        const treeGen = d3.tree().size([op.height - 10, op.width - 100])

        // Run the layout once; every node except the root has a link to its parent.
        const nodes = treeGen(root).descendants();
        const links = nodes.slice(1);

        // Measured text width per node label, used to size the background rects.
        const wordWidth = <{ [key: string]: number }>{};
        nodes.forEach(node => {
            const n = node.data.name;
            wordWidth[n] = this.textMeasure.textLength(n);
        })


        // bad habit, but todo: the view is rebuilt completely on every render
        this.layers.main.selectAll('g.node').remove();
        this.layers.bg.selectAll('.link').remove();


        // Nodes get a running id on first use; links are keyed by the same id.
        let i = 0;
        let nodeEls = this.layers.main.selectAll('g.node')
            .data(nodes, function (d: any) {
                return d.id || (d.id = ++i);
            });


        // Enter any new nodes at their layout position.
        const nodeEnter = nodeEls.enter().append('g')
            .attr('class', 'node')
            .attr("transform", function (d) {
                return "translate(" + (d.y + 5) + "," + (d.x + 5) + ")";
            }).classed('topBeam', d => d.data.topBeam);

        nodeEnter.append('rect').style('fill', 'white');
        nodeEnter.append('text')
            .attr('class', 'node_text')
            .styles({
                // 'text-anchor': 'middle',
                'dominant-baseline': "middle"
            });
        nodeEls = nodeEnter.merge(nodeEls);

        nodeEls.select('rect').attrs({
            x: -2,//d => -wordWidth[d.data.name] / 2 - 2,
            width: d => wordWidth[d.data.name] + 4,
            height: 6,
            y: -3
        });

        nodeEls.select('text')
            .text((d) => d.data.name);


        const linkGen = d3.linkHorizontal()
            .x(d => d.y + 5)
            .y(d => d.x + 5);


        let link = this.layers.bg.selectAll('.link')
            .data(links, (d: any) => d.id)

        const linkEnter = link.enter().append('path').attr('class', 'link')

        // Each link is drawn from the node to its parent.
        linkEnter.merge(link).attr('d', d => linkGen({
            source: d,
            target: d.parent
        })).classed('topBeam', d => d.data.topBeam)


    }

    /**
     * Resize the parent element for the given depth and return the tree root.
     *
     * Width is 70px per level; height grows with the square root of the
     * depth and is at least 150px. A depth of 0 leaves the size untouched.
     */
    protected _wrangle(data: BeamTreeData) {

        //TODO: maybe find a better heuristic ?
        if (data.maxDepth>0){
            this.parent.attr('width', data.maxDepth * 70);
            this.options.width = data.maxDepth * 70;

            const h = Math.max(Math.sqrt(data.maxDepth/10) *150, 150);
            this.parent.attr('height', h);
            this.options.height = h;
        }

        return data.root;
    }

}
//TODO: maybe find a better heuristic ? 126 | if (data.maxDepth>0){ 127 | this.parent.attr('width', data.maxDepth * 70); 128 | this.options.width = data.maxDepth * 70; 129 | 130 | const h = Math.max(Math.sqrt(data.maxDepth/10) *150, 150); 131 | this.parent.attr('height', h); 132 | this.options.height = h; 133 | } 134 | 135 | return data.root; 136 | } 137 | 138 | } 139 | -------------------------------------------------------------------------------- /client/webpack.config.js: -------------------------------------------------------------------------------- 1 | const path = require('path'); 2 | const DefinePlugin = require('webpack').DefinePlugin; 3 | const ExtractTextPlugin = require('extract-text-webpack-plugin'); 4 | const ForkTsCheckerWebpackPlugin = require('fork-ts-checker-webpack-plugin'); 5 | const package = require('./package.json'); 6 | 7 | module.exports = (env) => ({ 8 | entry: './ts/main.ts', 9 | module: { 10 | rules: [ 11 | { 12 | test: /\.tsx?$/, 13 | exclude: /node_modules/, 14 | use: [{ 15 | loader: 'cache-loader' 16 | }, 17 | { 18 | loader: 'thread-loader', 19 | options: { 20 | // there should be 1 cpu for the fork-ts-checker-webpack-plugin 21 | workers: require('os').cpus().length - 1, 22 | }, 23 | }, 24 | { 25 | loader: 'ts-loader', 26 | options: { 27 | happyPackMode: true // IMPORTANT! use happyPackMode mode to speed-up compilation and reduce errors reported to webpack 28 | } 29 | } 30 | ].slice(process.env.CI ? 
2 : 0) // no optimizations for CIs 31 | }, 32 | { 33 | test: /\.s?css$/, 34 | use: ExtractTextPlugin.extract({ 35 | fallback: 'style-loader', 36 | use: [ 37 | { 38 | loader: 'css-loader', 39 | options: { 40 | minimize: true, 41 | sourceMap: true 42 | } 43 | }, 44 | { 45 | loader: 'sass-loader', 46 | options: { 47 | sourceMap: true 48 | } 49 | } 50 | ] 51 | }) 52 | }, 53 | { 54 | test: /\.(png|jpg)$/, 55 | loader: 'url-loader', 56 | options: { 57 | limit: 20000 //inline <= 10kb 58 | } 59 | }, 60 | { 61 | test: /\.woff(2)?(\?v=[0-9]\.[0-9]\.[0-9])?$/, 62 | loader: 'url-loader', 63 | options: { 64 | limit: 20000, //inline <= 20kb 65 | mimetype: 'application/font-woff' 66 | } 67 | }, 68 | { 69 | test: /\.svg(2)?(\?v=[0-9]\.[0-9]\.[0-9])?$/, 70 | loader: 'url-loader', 71 | options: { 72 | limit: 10000, //inline <= 10kb 73 | mimetype: 'image/svg+xml' 74 | } 75 | }, 76 | { 77 | test: /\.(ttf|eot)(\?v=[0-9]\.[0-9]\.[0-9])?$/, 78 | loader: 'file-loader' 79 | } 80 | ] 81 | }, 82 | resolve: { 83 | extensions: ['.ts', '.js'] 84 | }, 85 | plugins: [ 86 | new DefinePlugin({ 87 | __VERSION__: JSON.stringify(package.version), 88 | __BUILDID__: JSON.stringify(new Date().toISOString()) 89 | }), 90 | new ExtractTextPlugin('style.css'), 91 | new ForkTsCheckerWebpackPlugin({ 92 | checkSyntacticErrors: true 93 | }) 94 | ], 95 | optimization: { 96 | splitChunks: { 97 | cacheGroups: { 98 | vendor: { 99 | test: /node_modules/, 100 | chunks: "initial", 101 | name: "vendor", 102 | priority: 10, 103 | enforce: true 104 | } 105 | } 106 | } 107 | }, 108 | output: { 109 | filename: '[name].js', 110 | path: path.resolve(__dirname, '../client_dist/') 111 | }, 112 | devServer: { 113 | port: 8090, 114 | proxy: { 115 | '/api/*': { 116 | target: 'http://localhost:8080', 117 | secure: false, 118 | ws: true 119 | } 120 | } 121 | } 122 | }); 123 | -------------------------------------------------------------------------------- /client/fonts/Source_Sans_Pro/OFL.txt: 
-------------------------------------------------------------------------------- 1 | Copyright 2010, 2012, 2014 Adobe Systems Incorporated (http://www.adobe.com/), with Reserved Font Name ‘Source’. 2 | 3 | This Font Software is licensed under the SIL Open Font License, Version 1.1. 4 | This license is copied below, and is also available with a FAQ at: 5 | http://scripts.sil.org/OFL 6 | 7 | 8 | ----------------------------------------------------------- 9 | SIL OPEN FONT LICENSE Version 1.1 - 26 February 2007 10 | ----------------------------------------------------------- 11 | 12 | PREAMBLE 13 | The goals of the Open Font License (OFL) are to stimulate worldwide 14 | development of collaborative font projects, to support the font creation 15 | efforts of academic and linguistic communities, and to provide a free and 16 | open framework in which fonts may be shared and improved in partnership 17 | with others. 18 | 19 | The OFL allows the licensed fonts to be used, studied, modified and 20 | redistributed freely as long as they are not sold by themselves. The 21 | fonts, including any derivative works, can be bundled, embedded, 22 | redistributed and/or sold with any software provided that any reserved 23 | names are not used by derivative works. The fonts and derivatives, 24 | however, cannot be released under any other type of license. The 25 | requirement for fonts to remain under this license does not apply 26 | to any document created using the fonts or their derivatives. 27 | 28 | DEFINITIONS 29 | "Font Software" refers to the set of files released by the Copyright 30 | Holder(s) under this license and clearly marked as such. This may 31 | include source files, build scripts and documentation. 32 | 33 | "Reserved Font Name" refers to any names specified as such after the 34 | copyright statement(s). 35 | 36 | "Original Version" refers to the collection of Font Software components as 37 | distributed by the Copyright Holder(s). 
38 | 39 | "Modified Version" refers to any derivative made by adding to, deleting, 40 | or substituting -- in part or in whole -- any of the components of the 41 | Original Version, by changing formats or by porting the Font Software to a 42 | new environment. 43 | 44 | "Author" refers to any designer, engineer, programmer, technical 45 | writer or other person who contributed to the Font Software. 46 | 47 | PERMISSION & CONDITIONS 48 | Permission is hereby granted, free of charge, to any person obtaining 49 | a copy of the Font Software, to use, study, copy, merge, embed, modify, 50 | redistribute, and sell modified and unmodified copies of the Font 51 | Software, subject to the following conditions: 52 | 53 | 1) Neither the Font Software nor any of its individual components, 54 | in Original or Modified Versions, may be sold by itself. 55 | 56 | 2) Original or Modified Versions of the Font Software may be bundled, 57 | redistributed and/or sold with any software, provided that each copy 58 | contains the above copyright notice and this license. These can be 59 | included either as stand-alone text files, human-readable headers or 60 | in the appropriate machine-readable metadata fields within text or 61 | binary files as long as those fields can be easily viewed by the user. 62 | 63 | 3) No Modified Version of the Font Software may use the Reserved Font 64 | Name(s) unless explicit written permission is granted by the corresponding 65 | Copyright Holder. This restriction only applies to the primary font name as 66 | presented to the users. 67 | 68 | 4) The name(s) of the Copyright Holder(s) or the Author(s) of the Font 69 | Software shall not be used to promote, endorse or advertise any 70 | Modified Version, except to acknowledge the contribution(s) of the 71 | Copyright Holder(s) and the Author(s) or with their explicit written 72 | permission. 
73 | 74 | 5) The Font Software, modified or unmodified, in part or in whole, 75 | must be distributed entirely under this license, and must not be 76 | distributed under any other license. The requirement for fonts to 77 | remain under this license does not apply to any document created 78 | using the Font Software. 79 | 80 | TERMINATION 81 | This license becomes null and void if any of the above conditions are 82 | not met. 83 | 84 | DISCLAIMER 85 | THE FONT SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, 86 | EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO ANY WARRANTIES OF 87 | MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT 88 | OF COPYRIGHT, PATENT, TRADEMARK, OR OTHER RIGHT. IN NO EVENT SHALL THE 89 | COPYRIGHT HOLDER BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, 90 | INCLUDING ANY GENERAL, SPECIAL, INDIRECT, INCIDENTAL, OR CONSEQUENTIAL 91 | DAMAGES, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING 92 | FROM, OUT OF THE USE OR INABILITY TO USE THE FONT SOFTWARE OR FROM 93 | OTHER DEALINGS IN THE FONT SOFTWARE. 94 | -------------------------------------------------------------------------------- /client/ts/vis/InfoPanel.ts: -------------------------------------------------------------------------------- 1 | import {D3Sel} from "../etc/LocalTypes"; 2 | import * as d3 from "d3"; 3 | 4 | 5 | export interface InfoPanelData { 6 | translations: { src: string[], tgt: string[] }[], 7 | highlights: { 8 | loc: string, //'tgt' | 'src' 9 | indices: number[] 10 | } 11 | } 12 | 13 | export class InfoPanel { 14 | private infoPanel: D3Sel; 15 | 16 | current = { 17 | highlightOffset: 0, 18 | display: { 19 | tgt: true, 20 | src: true 21 | } 22 | }; 23 | private show_btns: D3Sel; 24 | private offset_btns: D3Sel; 25 | private data: InfoPanelData; 26 | 27 | 28 | constructor(private parent: D3Sel) { 29 | parent.html('
' + 30 | '
') 39 | 40 | this.infoPanel = parent.select('.info_panel'); 41 | 42 | this.show_btns = parent.selectAll('.show_btn'); 43 | this.offset_btns = parent.selectAll('.offset_btn'); 44 | 45 | const that = this; 46 | this.show_btns.on('click', function () { 47 | const v = d3.select(this).text(); 48 | that.current.display[v] = !that.current.display[v]; 49 | that.updateMenu(); 50 | that.actionUpdateCurrent(); 51 | }); 52 | 53 | this.offset_btns.on('click', function () { 54 | that.current.highlightOffset = +d3.select(this).text(); 55 | that.updateMenu(); 56 | that.actionUpdateCurrent(); 57 | }); 58 | 59 | 60 | this.updateMenu(); 61 | 62 | // this.src = parent.select('.src'); 63 | // this.tgt = parent.select('.tgt'); 64 | } 65 | 66 | private updateMenu() { 67 | const that = this; 68 | this.show_btns.classed('selected', function () { 69 | const v = d3.select(this).text(); 70 | return that.current.display[v] 71 | }); 72 | 73 | this.offset_btns.classed('selected', function () { 74 | const v = +d3.select(this).text(); 75 | return that.current.highlightOffset === v; 76 | }); 77 | } 78 | 79 | 80 | actionUpdateCurrent() { 81 | const cur = this.current; 82 | const data = this.data; 83 | 84 | const allT = this.infoPanel.selectAll(".translation"); 85 | 86 | allT.select('.src').html((d: any, i) => 87 | this.renderHighlight(d.src, data.highlights.loc === 'src' ? data.highlights.indices[i] + cur.highlightOffset : -1)) 88 | .attr('hidden', cur.display.src ? null : true); 89 | allT.select('.tgt').html((d: any, i) => 90 | this.renderHighlight(d.tgt, data.highlights.loc === 'tgt' ? data.highlights.indices[i] + cur.highlightOffset : -1)) 91 | .attr('hidden', cur.display.tgt ? 
null : true); 92 | 93 | // allT.select('.starIt').classed('selected', true); 94 | 95 | } 96 | 97 | 98 | private cleanData(s: string) { 99 | return s.replace(/&/g, '&') 100 | .replace(//g, '>') 102 | // .replace(new RegExp('--\\|'), '') 103 | // .replace(new RegExp('\\|--'), '') 104 | } 105 | 106 | private renderHighlight(words: string[], index: number) { 107 | 108 | return words.map((w, i) => { 109 | w = this.cleanData(w); 110 | return i === index ? `${w}` : w 111 | }).join(' ') 112 | 113 | } 114 | 115 | 116 | update(data: InfoPanelData) { 117 | this.data = data; 118 | 119 | let tSel = this.infoPanel.selectAll(".translation").data(data.translations); 120 | tSel.exit().remove(); 121 | 122 | const tEnter = tSel.enter().append('div').attr('class', 'translation').style('display', 'table-row'); 123 | tEnter.html('
' + 124 | // ' ' + 125 | '
' + 126 | '
' + 127 | '
/**
 * Wrapper around one translation result of the server.
 *
 * Besides read access to the raw result it allows editing single attention
 * rows (`increaseAttn`, `setAttn`, `resetAttn`). A deep copy of the
 * attention taken at construction time serves as backup for `resetAttn`.
 * Attention is indexed as `attn[beam][decoderPosition][encoderPosition]`.
 */
export class Translation {

    private readonly _result: {
        attn: number[][][],
        attnFiltered: number[][][],
        scores: number[],
        decoder: {
            neighbors: number[][],
            neighbor_context: number[][],
            state: number[],
            token: string
        }[][],
        encoder: {
            neighbors: number[][],
            state: number[],
            token: string
        }[],
        [key: string]: any
    } = null;

    public _current: LooseObject;

    // untouched copy of the attention as delivered by the server
    private attention_save = null;

    constructor(result) {
        this._result = result;
        this.attention_save = cloneDeep(this.attn);
    }


    /**
     * Add `factor` to the weight at `encPos` of one attention row and take
     * the same amount proportionally from all other weights (clamped at 0).
     * The row is modified in place and `attnFiltered` is recomputed.
     *
     * @returns the modified attention row
     */
    public increaseAttn(decPos, encPos, beam = 0, factor = .1) {

        const curAttn = this._result.attn[beam][decPos];
        // positions may arrive as numeric strings; compare as numbers
        const target = +encPos;

        // total weight of all positions except the target
        const cost = sum(curAttn) - curAttn[target];

        for (const i of range(curAttn.length)) {

            if (i === target) {
                curAttn[i] += factor
            } else if (cost > 0) {
                // cost === 0 means there is nothing to take away; dividing
                // by it would turn the row into NaN
                curAttn[i] -= curAttn[i] * factor / cost;
                curAttn[i] = Math.max(curAttn[i], 0);
            }
        }

        this.filterAttention();

        return curAttn;
    }

    /**
     * Put the complete weight of one attention row on `encPos`
     * (1 there, 0 everywhere else) and recompute `attnFiltered`.
     *
     * @returns the modified attention row
     */
    public setAttn(decPos, encPos, beam = 0) {
        const curAttn = this._result.attn[beam][decPos];
        const target = +encPos;

        for (const i of range(curAttn.length)) {
            curAttn[i] = i === target ? 1 : 0;
        }

        this.filterAttention();

        return curAttn;
    }

    /**
     * Restore one attention row from the backup and recompute
     * `attnFiltered`, like the other editing methods do.
     *
     * The row is copied: the edit methods work in place, so handing out the
     * backup row itself would let later edits destroy the backup.
     */
    public resetAttn(decPos, beam = 0) {
        this._result.attn[beam][decPos] = this.attention_save[beam][decPos].slice();
        this.filterAttention();
    }


    /**
     * Recompute `attnFiltered`: per attention row keep the largest weights
     * until their sum reaches `threshold` and set all others to 0.
     *
     * @returns true if attention was available, false otherwise
     */
    public filterAttention(threshold = .75) {

        if (this._result.attn.length > 0) {

            const res = [];

            for (const topSentence of this._result.attn) {
                const newSentence = [];
                for (const row of topSentence) {
                    // [value, original index], largest value first
                    const sortedValues = row.map((v, i) => [v, i])
                        .sort((a, b) => b[0] - a[0]);
                    const newRow = new Array(row.length).fill(0);
                    let acc = 0;
                    let index = 0;
                    while (acc < threshold && index < row.length) {
                        const v = sortedValues[index][0];
                        newRow[sortedValues[index][1]] = v;
                        acc += v;
                        index++;
                    }
                    newSentence.push(newRow)
                }
                res.push(newSentence)
            }
            this._result.attnFiltered = res;
            return true;

        } else return false;

    }


    /** Tokens of the input sentence. */
    get encoderWords(): string[] {
        return this._result.encoder.map(w => w.token);
    }

    /** Input tokens joined by single spaces. */
    get inputSentence(): string {
        return this.encoderWords.join(' ')
    }

    /** Output tokens, one list per entry of `decoder`. */
    get decoderWords(): string[][] {
        return this._result.decoder.map(
            deco => deco.map(
                w => w.token))
    }


    get allNeighbors() {
        return this._result.allNeighbors;
    }

    get beam_trace_words() {
        return this._result.beam_trace_words;
    }

    get beam() {
        return this._result.beam;
    }

    /** The raw server result (not a copy). */
    get result() {
        return this._result;
    }

    get attn() {
        return this._result.attn;
    }

    get attnFiltered() {
        return this._result.attnFiltered;
    }

    get encoder() {
        return this._result.encoder;
    }

    get encoderNeighbors() {
        return this._result.encoder.map(d => d.neighbors)
    }

    get decoder() {
        return this._result.decoder;
    }

    get decoderNeighbors() {
        return this._result.decoder.map(dec =>
            dec.map(d => d.neighbors))
    }

    get contextNeighbors() {
        return this._result.decoder.map(dec =>
            dec.map(d => d.neighbor_context))
    }

    get scores() {
        return this._result.scores;
    }

}
/**
 * One drawn attention edge between an input and an output word.
 */
type Edge = {
    classes: string,   // css classes `in<inIndex> out<outIndex>`
    inPos: number,     // x position of the edge centre on the input side
    outPos: number,    // x position of the edge centre on the output side
    width: number,     // stroke width
    edge: number[]     // [inIndex, outIndex]
}

// The numeric values are used as index into Edge.edge (see actionHighlightEdges)
enum VertexType {src = 0, tgt = 1}

/**
 * Input to render()
 */
type AV_RenderType = { edges: Edge[], maxPos: number }

/**
 * Input to the component: widths and x positions of the input/output words
 * and the attention weights between them.
 */
export interface AttentionVisData {
    inWidths: number[],
    outWidths: number[],
    inPos: number[],
    outPos: number[],
    // transposed before use, i.e. read as edgeWeights[outIndex][inIndex]
    edgeWeights: number[][],
}


/**
 * Draws attention as a bundle of curved edges between a row of input words
 * (top, y = 0) and a row of output words (bottom, y = options.height).
 */
export class AttentionVis extends VComponent {

    css_name = 'attentionvis';

    static VERTEX_TYPE = VertexType;

    static events = {};


    options = {
        pos: {x: 0, y: 0},
        max_bundle_width: 15,
        height: 50,
        css_class_main: 'attn_graph',
        css_edge: 'attn_edge',
        x_offset: 3
    };

    constructor(d3Parent: D3Sel, eventHandler?: SimpleEventHandler, options: {} = {}) {
        super(d3Parent, eventHandler);
        this.superInit(options, false);
    }

    _init() {
    }

    /**
     * Compute position and width of every edge.
     *
     * Edges of one word are stacked next to each other; the stack is
     * centred on the word (word position + half of the unused word width).
     *
     * @param attnWeights - attention weights, indexed [outIndex][inIndex]
     * @param maxBundleWidth - pixel width that the largest summed input attention maps to
     * @param inWidths - widths of the input words
     * @param outWidths - widths of the output words
     * @param inPos - x positions of the input words
     * @param outPos - x positions of the output words
     * @returns all edges with a width > 0 and the right-most x position reached
     */
    private _createGraph(attnWeights: number[][], maxBundleWidth: number,
                         inWidths: number[], outWidths: number[],
                         inPos: number[], outPos: number[]): AV_RenderType {

        // transpose: attnPerInWord[inIndex][outIndex]
        const attnPerInWord = unzip(attnWeights);
        const attnPerInWordSum = attnPerInWord.map(a => sum(a));
        // the scale domain never ends below 1
        const maxAttnPerAllWords = Math.max(1, max(attnPerInWordSum));
        const lineWidthScale = d3.scaleLinear()
            .domain([0, maxAttnPerAllWords]).range([0, maxBundleWidth]);

        let maxPos = 0;

        // Pass 1: positions on the input side; outPos is filled in by pass 2
        const inPositionGraph = inWidths.map((inWord, inIndex) => {
            // left border of this word's edge stack
            let inc = inPos[inIndex] + (inWord - lineWidthScale(attnPerInWordSum[inIndex])) * .5;
            return outWidths.map((_, outIndex) => {
                const lw = lineWidthScale(attnPerInWord[inIndex][outIndex]);
                // edge centre, then advance by the full edge width
                const res = inc + lw * .5;
                inc += lineWidthScale(attnPerInWord[inIndex][outIndex]);
                maxPos = inc > maxPos ? inc : maxPos;
                return {
                    inPos: res,
                    width: lw,
                    edge: [inIndex, outIndex],
                    classes: `in${inIndex} out${outIndex}`,
                    outPos: null
                }
            });
        });

        // Pass 2: positions on the output side
        outWidths.forEach((outWord, outIndex) => {
            // NOTE(review): centring uses the width of weight 1, i.e. this presumably
            // expects the attention of one output word to sum to 1 -- confirm
            let inc = outPos[outIndex] + (outWord - lineWidthScale(1)) * .5;
            inWidths.forEach((_, inIndex) => {
                const line = inPositionGraph[inIndex][outIndex];
                line['outPos'] = inc + line.width * .5;
                inc += line.width;
                maxPos = inc > maxPos ? inc : maxPos;
            })
        });

        // edges without width are not drawn
        return {edges: flatten(inPositionGraph).filter(d => d.width > 0), maxPos};

    }

    /**
     * Convert the input data into edges and size the parent element so that
     * all edges fit.
     */
    protected _wrangle(data: AttentionVisData) {

        const {edges, maxPos} = this._createGraph(
            data.edgeWeights,
            this.options.max_bundle_width,
            data.inWidths,
            data.outWidths,
            data.inPos,
            data.outPos
        );

        this.parent.attrs({
            width: maxPos + 5 + this.options.x_offset, //reserve
            height: this.options.height
        });

        return {edges, maxPos}

    }

    /**
     * Draw one `g > path` per edge: a vertical link from (inPos, 0) to
     * (outPos, options.height), both shifted by options.x_offset.
     */
    protected _render(renderData: AV_RenderType) {

        const op = this.options;

        const graph = this.base.selectAll(`.${op.css_class_main}`)
            .data(renderData.edges);
        graph.exit().remove();

        const linkGen = d3.linkVertical();

        const graphEnter = graph.enter().append('g').attr('class', op.css_class_main);
        graphEnter.append('path');
        graphEnter.merge(graph).select('path').attrs({
            'd': d => {
                return linkGen({
                    source: [d.inPos + op.x_offset, 0],
                    target: [d.outPos + op.x_offset, op.height]
                })
            },
            'class': d => `${this.options.css_edge} ${d.classes}`
        }).style('stroke-width', d => d.width);

    }

    /**
     * Mark all edges that start (type src) or end (type tgt) at word `index`
     * with `className`; with `highlight` = false the class is removed from
     * all edges.
     */
    actionHighlightEdges(index: number, type: VertexType, highlight: boolean, className = 'highlight') {

        if (highlight) {
            this.base.selectAll(`.${this.options.css_class_main}`)
                .classed(className, d => {
                    // `type` selects the in (0) or out (1) index of the edge
                    return (d).edge[type] === index;
                })
        } else {
            this.base.selectAll(`.${this.options.css_class_main}`)
                .classed(className, false)

        }

    }


}
5 | S2S Attention 6 | 7 | 8 | 9 | 10 | 11 | 12 | 13 | 14 | 15 | 16 | 17 | 18 | 19 | 20 | 21 | 22 | 23 | 24 |
25 | 26 | 27 | 28 |
29 | 30 | 32 | 33 |
34 | 36 |
37 |
38 |
39 | 40 | 41 |
43 |
44 |
45 |
46 |
47 | 48 |
49 |
51 | 52 | 53 | 54 | 55 |
56 | change: 57 |
58 |
59 |
word 62 |
63 |
attn 65 |
66 |
67 |
compare: 69 |
70 | 71 |
sentence 72 |
73 |
swap:
74 |
75 | 77 |
78 |
79 |
81 | 82 |
83 |
84 | 85 |
86 |
87 |
88 |
89 | 90 |
91 | 92 | 93 | 95 | 96 | 98 | 100 | 101 | 102 |
103 |
104 | 105 | 106 | 107 | 108 | 109 |
110 | 111 |
112 |
113 | 114 |
115 |
116 | 117 |
118 | 119 |
120 | 121 |
122 |
123 |
124 | 125 | 126 | 127 | 157 | 158 | 159 | 160 | 161 | 162 | 163 | 164 | -------------------------------------------------------------------------------- /client_dist/index.html: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | S2S Attention 6 | 7 | 8 | 9 | 10 | 11 | 12 | 13 | 14 | 15 | 16 | 17 | 18 | 19 | 20 | 21 | 22 | 23 | 24 |
25 | 26 | 27 | 28 |
29 | 30 | 32 | 33 |
34 | 36 |
37 |
38 |
39 | 40 | 41 |
43 |
44 |
45 |
46 |
47 | 48 |
49 |
51 | 52 | 53 | 54 | 55 |
56 | change: 57 |
58 |
59 |
word 62 |
63 |
attn 65 |
66 |
67 |
compare: 69 |
70 | 71 |
sentence 72 |
73 |
swap:
74 |
75 | 77 |
78 |
79 |
81 | 82 |
83 |
84 | 85 |
86 |
87 |
88 |
89 | 90 |
91 | 92 | 93 | 95 | 96 | 98 | 100 | 101 | 102 |
103 |
104 | 105 | 106 | 107 | 108 | 109 |
110 | 111 |
112 |
113 | 114 |
115 |
116 | 117 |
118 | 119 |
120 | 121 |
122 |
123 |
class S2SProject:
    """One seq2seq project: model, data files, dictionaries and indices.

    Everything is described by a YAML config file whose paths are resolved
    relative to ``directory``. Keys read from the config:

    * ``model``, ``embeddings``, ``train`` and ``dicts`` (``src``/``tgt``)
      -- required
    * ``indexType`` -- ``'faiss'`` or anything else for annoy
      (default ``'annoy'``)
    * ``indices`` -- optional mapping from index name to index file
    * ``project_model`` -- optional joblib file with a projection model
    """

    def __init__(self, config_file, directory):
        """Load the config and open all resources it refers to.

        :param config_file: path of the YAML config file
        :param directory: directory that the paths in the config are
            relative to
        """
        with open(config_file, 'rb') as cff:
            # safe_load only builds plain Python data. yaml.load without an
            # explicit Loader can construct arbitrary objects and is
            # rejected by current PyYAML releases.
            self.config = yaml.safe_load(cff)
        self.model = ONMTmodelAPI(os.path.join(directory, self.config['model']))
        # The HDF5 files stay open (read-only) for the life time of the project
        self.embeddings = h5py.File(
            os.path.join(directory, self.config['embeddings']), mode='r')
        self.train_data = h5py.File(
            os.path.join(directory, self.config['train']), mode='r')
        # i2t: id -> token, t2i: token -> id; one dict each for src and tgt
        self.dicts = {'i2t': {'src': {}, 'tgt': {}},
                      't2i': {'src': {}, 'tgt': {}}}

        self.cached_norms = {'src': None, 'tgt': None}
        self.directory = os.path.abspath(directory)

        self.indexType = self.config.get('indexType', 'annoy')
        self.has_neighbors = ('indices' in self.config)

        # Either all indices are kept (after preload_indices), or only the
        # most recently used one (currentIndex*).
        self.indices = None
        self.currentIndexName = None
        self.currentIndex = None

        self.project_model = None
        if 'project_model' in self.config:
            self.project_model = joblib.load(
                os.path.join(directory, self.config['project_model']))

        for h in ['src', 'tgt']:
            i2t = self.dicts['i2t'][h]
            t2i = self.dicts['t2i'][h]
            with open(os.path.join(directory, self.config['dicts'][h])) as f:
                for raw in f:
                    line = raw.strip()
                    if not line:
                        continue
                    # every non-empty line is "<id> <token>"
                    iid, token = line.split()
                    iid = int(iid)
                    i2t[iid] = token
                    # todo: hack -- id 0 always maps to the empty string,
                    # whatever the dict file says
                    i2t[0] = ''
                    t2i[token] = iid

    def info(self):
        """Return the model name and whether neighbor indices are configured."""
        return {
            'model': self.config['model'],
            'has_neighbors': self.has_neighbors
        }

    def cached_norm(self, loc, matrix):
        """Return the row-wise L2 norms of ``matrix``, cached per ``loc``.

        Only the first call for a ``loc`` computes anything; later calls
        return the cached array and ignore ``matrix``.

        :param loc: cache key, ``'src'`` or ``'tgt'``
        :param matrix: 2-d array
        """
        if self.cached_norms[loc] is None:
            self.cached_norms[loc] = np.linalg.norm(matrix, axis=1)

        return self.cached_norms[loc]

    def ix2text(self, array, vocab, highlight=-1):
        """Convert token ids to a space separated string.

        Tokens with id 1 (padding) are dropped. The token at position
        ``highlight`` is always kept and wrapped as ``--|token|--``.

        :param array: sequence of token ids
        :param vocab: mapping from token id to token
        :param highlight: position to highlight, -1 for none
        :raises KeyError: if an id is missing in ``vocab``
        """
        tokens = []
        for ix, t in enumerate(array):
            if ix == highlight:
                tokens.append("--|" + vocab[t] + "|--")
            elif t != 1:
                tokens.append(vocab[t])
        return " ".join(tokens)

    def get_train_for_index(self, ixs, data_src='tgt'):
        """Look up the training sentence pairs for index search results.

        :param ixs: result ids of an index search; each one is mapped to a
            sentence and a token position by the ``'encoder'`` index
        :param data_src: ``'tgt'`` highlights the token in the target text,
            any other value highlights it in the source text
        :return: one dict per id with the keys ``src``, ``tgt`` (text),
            ``src_words``, ``tgt_words`` (token lists without padding),
            ``sentId`` and ``tokenId``
        """
        i2t_src = self.dicts['i2t']['src']
        i2t_tgt = self.dicts['i2t']['tgt']

        def ix2words(indices, word_dict):
            # unknown ids become '???', padding (id 1) is dropped
            return [word_dict.get(x, '???') for x in indices if x != 1]

        res = []
        for ix in ixs:
            sentIx, tokIx = self.get_index('encoder').search_to_sentence_index(
                ix)
            # Get raw list of tokens
            src_in = self.train_data['src'][sentIx]
            tgt_in = self.train_data['tgt'][sentIx]
            # Convert to text, highlighting the token on the requested side
            if data_src == 'tgt':
                src = self.ix2text(src_in, i2t_src)
                tgt = self.ix2text(tgt_in, i2t_tgt, tokIx)
            else:
                src = self.ix2text(src_in, i2t_src, tokIx)
                tgt = self.ix2text(tgt_in, i2t_tgt)

            res.append({'src': src, 'tgt': tgt,
                        'src_words': ix2words(src_in, i2t_src),
                        'tgt_words': ix2words(tgt_in, i2t_tgt),
                        'sentId': sentIx, 'tokenId': tokIx})
        return res

    def _load_index(self, name):
        """Load the index ``name`` from disk.

        The file is taken from ``config['indices'][name]`` if configured,
        otherwise ``<directory>/<name>.faiss`` (faiss) or
        ``<directory>/<name>.ann`` (annoy) is used.

        :return: the index, or ``None`` if the file does not exist
        """
        path = None
        if name in self.config.get('indices', {}):
            path = os.path.join(self.directory,
                                self.config['indices'][name])
        if not path:
            extension = ".faiss" if self.indexType == 'faiss' else ".ann"
            path = os.path.join(self.directory, name + extension)

        if not os.path.exists(path):
            return None
        if self.indexType == 'faiss':
            return FaissVectorIndex(path)
        return AnnoyVectorIndex(path)

    def preload_indices(self, names=()):
        """Load all indices in ``names`` and keep every index in memory.

        After this call ``get_index`` caches each index it loads instead of
        keeping only the most recent one.
        """
        self.indices = {name: self._load_index(name) for name in names}

    def get_index(self, name):
        """Return the index ``name``, loading it on first use.

        Without ``preload_indices`` only one index is kept: asking for a
        different name replaces the current one.

        :return: the index, or ``None`` if its file does not exist
        """
        if self.indices is not None:  # if pre-loaded
            if name not in self.indices:
                print('loading ', name)
                self.indices[name] = self._load_index(name)
            return self.indices[name]

        # if NOT pre-loaded
        if name != self.currentIndexName:
            print('loading ', name)
            self.currentIndexName = name
            self.currentIndex = self._load_index(name)
        return self.currentIndex
# see also: https://github.com/hjacobs/connexion-example

swagger: '2.0'
info:
  title: S2S API
  version: "0.1"
consumes:
  - application/json
produces:
  - application/json
basePath: /api
#security:
  # enable OAuth protection for all REST endpoints
  # (only active if the TOKENINFO_URL environment variable is set)
#  - oauth2: [uid]

# a simple data path:
paths:
  /translate:
    get:
      tags: [Translate, All]
      operationId: server.get_translation
      summary: Get Translation and Data for sentence
      parameters:
        - $ref: '#/parameters/inSentence'
        - $ref: '#/parameters/neighbors'
        - $ref: '#/parameters/partial'
        - $ref: '#/parameters/force_attn'
      responses:
        200:
          description: Return Translation and meta data
          schema:
            $ref: '#/definitions/Translation'
  /translate_compare:
    get:
      tags: [Translate, Neighbors, All]
      operationId: server.get_translation_compare
      summary: Two-way comparison between two sentences with merged vector-spaces
      parameters:
        - $ref: '#/parameters/inSentence'
        - $ref: '#/parameters/compareSentence'
        - $ref: '#/parameters/neighbors'
      responses:
        200:
          description: fun
  /project_info:
    get:
      tags: [All]
      operationId: server.get_info
      summary: Get general project information
      parameters:
        - $ref: '#/parameters/project_id'
      responses:
        200:
          description: fun
#  /compare_translation:
#    get:
#      tags: [Translate, All]
#      operationId: server.compare_translation
#      summary: Compare Translation and Data for sentences
#      parameters:
#        - $ref: '#/parameters/inSentence'
#        - $ref: '#/parameters/compareSentences'
#        - $ref: '#/parameters/partial'
#      responses:
#        200:
#          description: Return Translations and distances
#          schema:
#            $ref: '#/definitions/Translation'
  /close_words:
    get:
      tags: [Embedding, All]
      operationId: server.get_close_words
      summary: Get the closest words w.r.t word embedding
      parameters:
        - $ref: '#/parameters/inWord'
        - $ref: '#/parameters/loc'
        - $ref: '#/parameters/p_method'
        - $ref: '#/parameters/limit'
      responses:
        200:
          description: Return list of closest words

  /close_vectors:
    get:
      tags: [Embedding, All]
      operationId: server.get_close_vectors
      summary: Find closest vector to vector at position
      parameters:
        - $ref: '#/parameters/vector_name'
        - $ref: '#/parameters/indices'
      responses:
        200:
          description: return list of indices

  /neighbor_details:
    get:
      tags: [Embedding, All]
      operationId: server.get_neighbor_details
      summary: Get details about close neighbors
      parameters:
        - $ref: '#/parameters/vector_name'
        - $ref: '#/parameters/indices'
        - $ref: '#/parameters/p_method'
      responses:
        200:
          description: return list details

  /train_data_for_index:
    get:
      tags: [Index, All]
      operationId: server.train_data_for_index
      summary: Get data from training set for specific index id
      parameters:
        - $ref: '#/parameters/indices'
        - $ref: '#/parameters/loc'
      responses:
        200:
          description: return list details

parameters:
  inSentence:
    name: in
    description: the input sentence
    in: query
    type: string
    default: "world"
  compareSentence:
    name: compare
    description: compare sentence
    in: query
    type: string
  # Only referenced by the commented-out /compare_translation endpoint above.
  compareSentences:
    name: compare
    description: list of sentences to compare against (pipe separated, see collectionFormat)
    in: query
    required: true
    type: array
    collectionFormat: pipes
    items:
      type: string
  inWord:
    name: in
    description: the input word
    in: query
    type: string
    default: "hello"
  loc:
    name: loc
    description: location - src or tgt
    in: query
    type: string
    enum:
      - "src"
      - "tgt"
    default: "src"
  limit:
    name: limit
    description: limit of results
    in: query
    type: integer
    default: 100
  p_method:
    name: p_method
    description: projection method to use
    in: query
    type: string
    enum:
      - "mds"
      - "pca"
      - "tsne"
      - "none"
    default: "pca"
  vector_name:
    name: vector_name
    description: Name of the vector -- encoder, embedding, etc...
    in: query
    type: string
    default: "context"
  index:
    name: index
    in: query
    description: position in corpus
    type: integer
    default: 100
  indices:
    name: indices
    in: query
    description: positions in corpus
    type: array
    items:
      type: integer
    required: true
  neighbors:
    name: neighbors
    description: list of dimensions to add neighbors to
    in: query
    type: array
    items:
      type: string
    required: false
  partial:
    name: partial
    description: partial translate for decoder
    in: query
    type: array
    collectionFormat: pipes
    items:
      type: string
    required: false
  force_attn:
    name: force_attn
    description: force attention
    in: query
    type: array
    items:
      type: integer
    required: false
  project_id:
    name: project_id
    description: Project ID
    in: query
    type: string
    required: false


# These definitions are only needed for proper documentation
# no functional purpose
definitions:
  Translation:
    type: object
    required:
      - in
      - out
      - attn
    properties:
      in:
        type: string
        description: input sentence
        example: "The sun is yellow ."
      out:
        type: string
        description: translated sentence
        example: "Die Sonne ist gelb ."
      attn:
        type: array
        description: Attention for each output position
        example: [[0.1, 0.2, 0.13], [0.3, 0.4, 0.2]]
        # one row of attention weights per output position,
        # matching the nested float example above
        items:
          type: array
          items:
            type: number




#securityDefinitions:
#  oauth2:
#    type: oauth2
#    flow: implicit
#    authorizationUrl: https://example.com/oauth2/dialog
#    scopes:
#      uid: Unique identifier of the user accessing the service.
/**
 * Vertical list of neighbor words with a score bar per word and, when
 * compare data is present, an additional distance bar and compare sentence.
 */
//TODO: Remove if not needed !!
export class CloseWordList extends VComponent {

    css_name = 'closewordlist';

    options = {
        pos: {x: 0, y: 0},
        height: 400,
        width: 1000,
        lineSpacing: 20,
        scoreWidth: 100,
        css_class_main: 'close_words',
        hidden: false,
        // accessors that pull the parallel arrays out of the raw data object
        data_access: {
            pos: d => d.pos,
            scores: d => d.score,
            words: d => d.word,
            compare: d => d.compare
        },
        text_measurer: null
    };


    constructor(d3Parent: D3Sel, eventHandler?: SimpleEventHandler, options: {} = {}) {
        super(d3Parent, eventHandler);
        this.superInit(options);
    }


    /**
     * Creates the text measurer (unless one was passed in via options),
     * sizes the parent element and applies the initial `hidden` option.
     */
    _init() {
        const op = this.options;
        op.text_measurer = op.text_measurer
            || new SVGMeasurements(this.parent, 'close_word_list');

        this.parent.attrs({
            width: op.width,
            height: op.height
        });
        if (op.hidden) this.hideView();
    }

    /**
     * Zips the parallel word / score / compare arrays into one object per
     * word (adding the measured text width) and sorts by descending score.
     * @param data raw data object, read through `options.data_access`
     * @returns list of `{word, score, width, compare}`
     */
    protected _wrangle(data) {
        const op = this.options;

        const words = op.data_access.words(data);
        const wordWidth = words.map(w => op.text_measurer.textLength(w));
        const scores = op.data_access.scores(data);
        const compare = op.data_access.compare(data);

        // `!= null` also covers `undefined`, so a missing compare field is
        // not mistaken for available compare data.
        this._current.has_compare = compare != null;

        return sortBy(zipWith(words, scores, wordWidth, compare,
            (word, score, width, compare) => ({
                word,
                score,
                width,
                compare
            })), d => -d.score);
    }

    /**
     * Draws one row per word: the word itself, a score bar and - if compare
     * data is available - a distance bar and the compare sentence.
     * @param renderData output of `_wrangle`
     */
    _render(renderData: LooseObject[]) {

        const op = this.options;
        const noItems = renderData.length;
        const ls = op.lineSpacing;
        const f2f = d3.format(".2f");

        this.parent.attr('height', noItems * ls);

        const word = this.layers.main.selectAll(".word").data(renderData);
        word.exit().remove();

        const wordEnter = word.enter().append('text').attr('class', 'word');

        const yscale = d3.scaleLinear().domain([0, noItems - 1])
            .range([ls / 2, (noItems - .5) * ls]);

        //TODO: BAD HACK - -should not be using indices
        wordEnter.merge(word).attrs({
            x: () => 10,
            y: (d, i) => yscale(i),
        }).text(d => d.word);

        // maxBy returns undefined for an empty list; fall back to 0 so that
        // rendering an empty result clears the view instead of throwing.
        const widest = maxBy(renderData, 'width');
        const wordEnd = widest ? widest.width : 0;
        const best = maxBy(renderData, 'score');
        const maxScore = best ? best.score : 0;

        const barScale = d3.scaleLinear().domain([0, maxScore])
            .range([0, op.scoreWidth]);

        const scoreBars = this.layers.main.selectAll(".scoreBar").data(renderData);
        scoreBars.exit().remove();

        const scoreBarsEnter = scoreBars.enter().append('g').attr('class', 'scoreBar');
        scoreBarsEnter.append('rect');
        scoreBarsEnter.append('text').attrs({
            x: 2,
            y: ls / 2 - 2,
            'class': 'barText'
        });

        const allScoreBars = scoreBarsEnter.merge(scoreBars).attrs({
            transform: (d, i) => `translate(${wordEnd + 10 + 10},${yscale(i) - ls / 2})`
        });

        allScoreBars.select('rect').attrs({
            width: d => barScale(d.score),
            height: ls - 4
        });
        allScoreBars.select('text').text(d => f2f(d.score));


        if (this._current.has_compare) {

            const bd_max = max(renderData.map(d => d.compare.dist));
            const bd_scale = d3.scaleLinear().domain([0, bd_max])
                .range([1, 100]);

            const barDist = this.layers.main.selectAll(".distBar").data(renderData);
            barDist.exit().remove();
            const barDistEnter = barDist.enter().append('g').attr('class', 'distBar');
            barDistEnter.append('rect');
            barDistEnter.append('text').attrs({
                x: 2,
                y: ls / 2 - 2,
                'class': 'barText'
            });

            const all_barDist = barDistEnter.merge(barDist).attrs({
                transform: (d, i) => `translate(${wordEnd + 10 + 10 + op.scoreWidth + 10},${yscale(i) - ls / 2})`
            });
            all_barDist.select('rect')
                .attrs({
                    width: d => bd_scale(d.compare.dist),
                    height: ls - 4
                });

            all_barDist.select('text').text(d => f2f(d.compare.dist));


            const wordComp = this.layers.main.selectAll(".wordComp").data(renderData);
            wordComp.exit().remove();

            const wordCompEnter = wordComp.enter().append('text').attr('class', 'wordComp');

            wordCompEnter.merge(wordComp).attrs({
                transform: (d, i) => `translate(${wordEnd + 10 + 10 + op.scoreWidth + 120},${yscale(i)})`
            }).text(d => d.compare.sentence);
        } else {
            // Remove everything the compare branch may have drawn earlier -
            // the distance bars as well as the compare sentences.
            this.layers.main.selectAll(".distBar").remove();
            this.layers.main.selectAll(".wordComp").remove();
        }

    }

}
/**
 * Created by Hendrik Strobelt (hendrik.strobelt.com) on 12/3/16.
 */
import {Util} from "../etc/Util";
import * as d3 from 'd3'
import {SimpleEventHandler} from "../etc/SimpleEventHandler";
import {SVG} from "../etc/SVGplus";
import {D3Sel, LooseObject} from "../etc/LocalTypes";


/**
 * Base class for all visual components: owns the SVG base group and its
 * layers, merges options, and drives the update -> wrangle -> render cycle.
 */
export abstract class VComponent {

    // STATIC FIELDS ============================================================

    /**
     * The static property that contains all class related events.
     * Should be overwritten and event strings have to be unique!!
     */
    static events: {} = {noEvent: 'VComponent_noEvent'};

    /**
     * Set of ALL options and their defaults. Example:
     *   {
     *     pos: {x: 10, y: 10},
     *     // List of Events that are ONLY handled globally:
     *     globalExclusiveEvents: []
     *   }
     */
    protected abstract options: { pos: { x: number, y: number }, [key: string]: any };

    protected id: string;
    protected parent: D3Sel;
    // Root group of this component; only exists after superInit(createSVG = true).
    protected base: D3Sel;
    protected layers: { main?: D3Sel, fg?: D3Sel, bg?: D3Sel, [key: string]: D3Sel };
    protected eventHandler: SimpleEventHandler;
    // Internal state; `hidden` suppresses wrangling and rendering in update().
    protected _current: { hidden: boolean, hideElement?: D3Sel | null; [key: string]: any };
    protected data: any;
    protected renderData: any;
    protected abstract css_name: string;

    // CONSTRUCTOR ============================================================


    /**
     * Simple constructor. Subclasses should call @superInit(options) as well.
     * see why here: https://stackoverflow.com/questions/43595943/why-are-derived-class-property-values-not-seen-in-the-base-class-constructor
     *
     * template:
     constructor(d3Parent: D3Sel, eventHandler?: SimpleEventHandler, options: {} = {}) {
        super(d3Parent, eventHandler);
        // -- access to subclass params:
        this.superInit(options);
     }
     *
     * @param {D3Sel} d3parent  D3 selection of parent SVG DOM Element
     * @param {SimpleEventHandler} eventHandler a global event handler object or 'null' for local event handler
     */
    protected constructor(d3parent: D3Sel, eventHandler?: SimpleEventHandler) {
        this.id = Util.simpleUId({});

        this.parent = d3parent;

        // The base element does not exist yet (superInit creates it), so a
        // local fallback handler cannot be bound here; superInit creates it
        // when none was passed in.
        this.eventHandler = eventHandler;

        // Object for storing internal states and variables
        this._current = {hidden: false};

    }

    /**
     * Has to be called as last call in subclass constructor.
     * @param {{}} options -- overrides for the subclass' default options
     * @param defaultLayers -- create the default layers: bg -> main -> fg
     * @param runInit -- run this._init() or not
     * @param createSVG -- create the base group (and layers) below the parent
     */
    protected superInit(options: {} = {}, defaultLayers = true, runInit = true, createSVG = true) {
        // Overwrite the subclass defaults with everything passed in.
        Object.keys(options).forEach(key => this.options[key] = options[key]);


        this.layers = {};
        // Create the base group element
        if (createSVG) {
            this.base = SVG.group(this.parent,
                this.css_name + ' ID' + this.id,
                this.options.pos);


            // create default layers: background, main, foreground
            if (defaultLayers) {
                // construction order is important !
                this.layers.bg = SVG.group(this.base, 'bg');
                this.layers.main = SVG.group(this.base, 'main');
                this.layers.fg = SVG.group(this.base, 'fg');
            }

        }

        // If not further specified - create a local event handler bound to
        // the base element (or to the parent when no base group was created).
        if (!this.eventHandler) {
            const anchor = this.base || this.parent;
            this.eventHandler = new SimpleEventHandler(anchor.node());
        }

        if (runInit) this._init();
    }


    /**
     * Should be overwritten to create the static DOM elements
     * @abstract
     * @return {*} ---
     */
    protected abstract _init();

    // DATA UPDATE & RENDER ============================================================

    /**
     * Every time data has changed, update is called and
     * triggers wrangling and re-rendering.
     * While the component is hidden the data is only stored.
     * @param {Object} data data object
     * @return {*} ---
     */
    update(data: DataInterface) {
        this.data = data;
        if (this._current.hidden) return;
        this.renderData = this._wrangle(data);
        this._render(this.renderData);
    }


    /**
     * Data wrangling method -- implement in subclass. Returns this.renderData.
     * Simplest implementation: `return data;`
     * @param {Object} data data
     * @returns {*} -- data in render format
     * @abstract
     */
    protected abstract _wrangle(data);


    /**
     * Is responsible for mapping data to DOM elements
     * @param {Object} renderData pre-processed (wrangled) data
     * @abstract
     * @returns {*} ---
     */
    protected abstract _render(renderData): void;


    // UPDATE OPTIONS ============================================================
    /**
     * Updates instance options
     * @param {Object} options only the options that should be updated
     * @param {Boolean} reRender if option change requires a re-rendering (default:false)
     * @returns {*} ---
     */
    updateOptions({options, reRender = false}) {
        Object.keys(options).forEach(k => this.options[k] = options[k]);
        if (reRender) this._render(this.renderData);
    }


    // === CONVENIENCE ====


    /** Sets the element that hideView/unhideView act on (default: parent). */
    setHideElement(hE: D3Sel) {
        this._current.hideElement = hE;
    }

    /** Fades the hide element out and removes it from layout; no-op if already hidden. */
    hideView() {
        if (!this._current.hidden) {
            const hE = this._current.hideElement || this.parent;
            hE.transition().styles({
                'opacity': 0,
                'pointer-events': 'none'
            }).style('display', 'none');
            this._current.hidden = true;
        }
    }

    /** Reverts hideView; does not re-render data received while hidden. */
    unhideView() {
        if (this._current.hidden) {
            const hE = this._current.hideElement || this.parent;
            hE.transition().styles({
                'opacity': 1,
                'pointer-events': null,
                'display': null
            });
            this._current.hidden = false;
            // this.update(this.data);

        }
    }

    /** Removes the component's base group (if one was created) from the DOM. */
    destroy() {
        if (this.base) this.base.remove();
    }

}
import {VComponent} from "./VisualComponent";
import {min, max, sortBy, zipWith} from "lodash";
import * as d3 from "d3";
import {removeOverlaps, Rectangle} from "webcola";
import {SimpleEventHandler} from "../etc/SimpleEventHandler";
import {SVGMeasurements} from "../etc/SVGplus";
import {D3Sel, LooseObject} from "../etc/LocalTypes";


type wordObj = {
    word: string;
    score: number;
    pos: number[];
    compare: {
        attn: number[][];
        attn_padding: number[][];
        dist: number;
        orig: string;
        sentence: string
    }
}

export type WordProjectorClickedEvent = {
    caller: WordProjector,
    word: string,
    sentence: string,
    [key: string]: any
}

/**
 * 2D scatter of words: positions come from a projection, are normalized to
 * the unit square and then nudged apart so that the word boxes do not overlap.
 */
export class WordProjector extends VComponent {

    css_name = 'wordprojector';

    static events = {
        wordClicked: "WordProjector_word_clicked"
    };

    options = {
        pos: {x: 0, y: 0},
        height: 400,
        width: 500,
        css_class_main: 'wp_vis',
        hidden: false,
        // accessors that pull the parallel arrays out of the raw data object
        data_access: {
            pos: d => d.pos,
            scores: d => d.score,
            words: d => d.word,
            compare: d => d.compare
        },
        text_measurer: null,
        loc: null
    };


    layout = [
        {name: 'bg', pos: [0, 0]},
        {name: 'main', pos: [0, 0]},
    ];

    //-- default constructor --
    constructor(d3Parent: D3Sel, eventHandler?: SimpleEventHandler, options: {} = {}) {
        super(d3Parent, eventHandler);
        this.superInit(options);
    }

    /**
     * Creates the text measurer (unless one was passed in via options),
     * sizes the parent element and applies the initial `hidden` option.
     */
    _init() {
        const op = this.options;
        op.text_measurer = op.text_measurer
            || new SVGMeasurements(this.parent, 'measureWord');

        this.parent.attrs({
            width: op.width,
            height: op.height
        });
        if (op.hidden) this.hideView();
    }

    /**
     * Normalizes the raw positions into [0, 1] using the larger of the two
     * axis extents (so the aspect ratio is kept), zips them with words,
     * scores and compare info and sorts by descending score.
     * @param data raw data object, read through `options.data_access`
     */
    _wrangle(data) {

        const op = this.options;

        const raw_pos = op.data_access.pos(data);
        const x_values = raw_pos.map(d => d[0]);
        const y_values = raw_pos.map(d => d[1]);

        const p0_min = min(x_values);
        const p1_min = min(y_values);

        const diff0 = max(x_values) - p0_min;
        const diff1 = max(y_values) - p1_min;

        // A single word (or identical positions) gives a zero extent; fall
        // back to 1 so the normalized positions are 0 instead of NaN.
        const extent = Math.max(diff0, diff1) || 1;
        const norm_pos = raw_pos.map(
            d => [(d[0] - p0_min) / extent, (d[1] - p1_min) / extent]);

        const words = op.data_access.words(data);
        const scores = op.data_access.scores(data);
        const compare = op.data_access.compare(data);

        // `!= null` also covers `undefined` (missing compare field).
        this._current.has_compare = compare != null;

        this._current.clearHighlights = true;

        return sortBy(zipWith(words, scores, norm_pos, compare,
            (word, score, pos, compare) => ({word, score, pos, compare})),
            (d: { word, score, pos, compare }) => -d.score);
    }

    /**
     * Draws one group (background rect + text) per word at its overlap-free
     * position; the first (highest scoring) entry gets the class `query`.
     * @param renderData output of `_wrangle`
     */
    _render(renderData: wordObj[]) {

        const op = this.options;

        const word = this.layers.main.selectAll(".word").data(renderData);
        word.exit().remove();

        const wordEnter = word.enter().append('g').attr('class', 'word');
        wordEnter.append('rect');
        wordEnter.append('text');

        const xscale = d3.scaleLinear().range([30, op.width - 30]);
        const yscale = d3.scaleLinear().range([10, op.height - 10]);
        const scoreExtent = d3.extent(renderData.map(d => d.score));
        const wordScale = d3.scaleLinear().domain(scoreExtent).range([10, 12]);

        // use webcola to remove rectangle overlap.
        // Boxes are addressed by datum index (the join above is index based),
        // so duplicate words keep their own position.
        const boxes = this.removeOverlap(renderData, wordScale, xscale, yscale, op.text_measurer);

        const allWords = wordEnter.merge(word);
        allWords.attr('transform',
            (d, i) => `translate(${boxes[i].cx}, ${boxes[i].cy})`);
        allWords.on('click', d => this.clickWord(d));
        allWords.classed('query', (d, i) => i === 0);


        allWords.select('rect').attrs({
            width: (d, i) => boxes[i].w,
            height: (d, i) => boxes[i].h - 2,
            x: (d, i) => -boxes[i].w * .5,
            y: (d, i) => -boxes[i].h * .5 + 1,
        });
        allWords.select('text')
            .text(d => d.word)
            .style('font-size', d => wordScale(d.score) + 'pt');


        if (this._current.clearHighlights) {
            this.highlightWord(null, false);
            this._current.clearHighlights = false;
        }

    }

    /**
     * Builds a padded rectangle around every word and lets webcola push
     * overlapping rectangles apart.
     * @returns one `{cx, cy, w, h}` box per entry of `renderData`, same order
     */
    private removeOverlap(renderData: wordObj[], wordScale, xscale, yscale, text_measurer: SVGMeasurements) {
        const ofree = renderData.map(rd => {
            const height = wordScale(rd.score);
            const x = xscale(rd.pos[0]);
            const y = yscale(rd.pos[1]);

            const width = text_measurer.textLength(rd.word, 'font-size:' + height + 'pt;');

            // Rectangle(x, X, y, Y) with 4px / 3px padding around the text
            return new Rectangle(x - width / 2 - 4, x + width / 2 + 4,
                y - height / 2 - 3, y + height / 2 + 3);
        });

        // mutates the rectangles in place
        removeOverlaps(ofree);

        return ofree.map(d => ({
            cx: (d.X + d.x) * .5,
            cy: (d.Y + d.y) * .5,
            w: (d.X - d.x),
            h: (d.Y - d.y)
        }));
    }

    /**
     * Sets or clears a CSS class on the word groups.
     * @param word the word to (un)highlight
     * @param highlight whether matching words get `label` set or removed
     * @param exclusive if true, all non-matching words lose `label`
     * @param label CSS class used for highlighting
     */
    highlightWord(word: string, highlight: boolean, exclusive = true, label = 'selected'): void {

        const allWords = this.layers.main.selectAll(".word");

        if (!highlight && exclusive) {
            allWords.classed(label, false);
        } else {
            allWords
                .classed(label, function (d: wordObj) {
                    if (d.word === word) {
                        return highlight;
                    }
                    if (exclusive) return false;
                    // non-exclusive: keep whatever state the element has
                    return d3.select(this).classed(label);
                });
        }

    }


    /** Fires `wordClicked` with the clicked word and its original sentence. */
    private clickWord(d: wordObj) {
        this.eventHandler.trigger(WordProjector.events.wordClicked, {
            caller: this,
            wordObj: d,
            // compare info is optional - do not throw on click without it
            sentence: d.compare ? d.compare.orig : null,
            word: d.word
        });
    }


}
font-weight: 300; 22 | src: local('Source Sans Pro Light'), local('SourceSansPro-Light'), url(https://fonts.gstatic.com/s/sourcesanspro/v11/toadOcfmlt9b38dHJxOBGHtAKCSP8Scq_VkmNL_V6Mr3rGVtsTkPsbDajuO5ueQw.woff2) format('woff2'); 23 | unicode-range: U+1F00-1FFF; 24 | } 25 | /* greek */ 26 | @font-face { 27 | font-family: 'Source Sans Pro'; 28 | font-style: normal; 29 | font-weight: 300; 30 | src: local('Source Sans Pro Light'), local('SourceSansPro-Light'), url(https://fonts.gstatic.com/s/sourcesanspro/v11/toadOcfmlt9b38dHJxOBGCDuzJWrwWDT04xi-7BJnlv3rGVtsTkPsbDajuO5ueQw.woff2) format('woff2'); 31 | unicode-range: U+0370-03FF; 32 | } 33 | /* vietnamese */ 34 | @font-face { 35 | font-family: 'Source Sans Pro'; 36 | font-style: normal; 37 | font-weight: 300; 38 | src: local('Source Sans Pro Light'), local('SourceSansPro-Light'), url(https://fonts.gstatic.com/s/sourcesanspro/v11/toadOcfmlt9b38dHJxOBGCD5K6T8I4oZ1X3Xvlj_UeP3rGVtsTkPsbDajuO5ueQw.woff2) format('woff2'); 39 | unicode-range: U+0102-0103, U+0110-0111, U+1EA0-1EF9, U+20AB; 40 | } 41 | /* latin-ext */ 42 | @font-face { 43 | font-family: 'Source Sans Pro'; 44 | font-style: normal; 45 | font-weight: 300; 46 | src: local('Source Sans Pro Light'), local('SourceSansPro-Light'), url(https://fonts.gstatic.com/s/sourcesanspro/v11/toadOcfmlt9b38dHJxOBGDOFnJNygIkrHciC8BWzbCz3rGVtsTkPsbDajuO5ueQw.woff2) format('woff2'); 47 | unicode-range: U+0100-024F, U+0259, U+1E00-1EFF, U+20A0-20AB, U+20AD-20CF, U+2C60-2C7F, U+A720-A7FF; 48 | } 49 | /* latin */ 50 | @font-face { 51 | font-family: 'Source Sans Pro'; 52 | font-style: normal; 53 | font-weight: 300; 54 | src: local('Source Sans Pro Light'), local('SourceSansPro-Light'), url(https://fonts.gstatic.com/s/sourcesanspro/v11/toadOcfmlt9b38dHJxOBGCP2LEk6lMzYsRqr3dHFImA.woff2) format('woff2'); 55 | unicode-range: U+0000-00FF, U+0131, U+0152-0153, U+02BB-02BC, U+02C6, U+02DA, U+02DC, U+2000-206F, U+2074, U+20AC, U+2122, U+2212, U+2215; 56 | } 57 | /* cyrillic-ext */ 58 | @font-face 
{ 59 | font-family: 'Source Sans Pro'; 60 | font-style: normal; 61 | font-weight: 400; 62 | src: local('Source Sans Pro Regular'), local('SourceSansPro-Regular'), url(https://fonts.gstatic.com/s/sourcesanspro/v11/ODelI1aHBYDBqgeIAH2zlIXYUqYVJeq1_JtQruA3_e8.woff2) format('woff2'); 63 | unicode-range: U+0460-052F, U+1C80-1C88, U+20B4, U+2DE0-2DFF, U+A640-A69F, U+FE2E-FE2F; 64 | } 65 | /* cyrillic */ 66 | @font-face { 67 | font-family: 'Source Sans Pro'; 68 | font-style: normal; 69 | font-weight: 400; 70 | src: local('Source Sans Pro Regular'), local('SourceSansPro-Regular'), url(https://fonts.gstatic.com/s/sourcesanspro/v11/ODelI1aHBYDBqgeIAH2zlExulUiGX8tUMVYeuJmbj48.woff2) format('woff2'); 71 | unicode-range: U+0400-045F, U+0490-0491, U+04B0-04B1, U+2116; 72 | } 73 | /* greek-ext */ 74 | @font-face { 75 | font-family: 'Source Sans Pro'; 76 | font-style: normal; 77 | font-weight: 400; 78 | src: local('Source Sans Pro Regular'), local('SourceSansPro-Regular'), url(https://fonts.gstatic.com/s/sourcesanspro/v11/ODelI1aHBYDBqgeIAH2zlBA0E65p__AYvizJB6RduYY.woff2) format('woff2'); 79 | unicode-range: U+1F00-1FFF; 80 | } 81 | /* greek */ 82 | @font-face { 83 | font-family: 'Source Sans Pro'; 84 | font-style: normal; 85 | font-weight: 400; 86 | src: local('Source Sans Pro Regular'), local('SourceSansPro-Regular'), url(https://fonts.gstatic.com/s/sourcesanspro/v11/ODelI1aHBYDBqgeIAH2zlC7-kXQoo3swP0nQ6K7J6xc.woff2) format('woff2'); 87 | unicode-range: U+0370-03FF; 88 | } 89 | /* vietnamese */ 90 | @font-face { 91 | font-family: 'Source Sans Pro'; 92 | font-style: normal; 93 | font-weight: 400; 94 | src: local('Source Sans Pro Regular'), local('SourceSansPro-Regular'), url(https://fonts.gstatic.com/s/sourcesanspro/v11/ODelI1aHBYDBqgeIAH2zlCxe5Tewm2_XWfbGchcXw4g.woff2) format('woff2'); 95 | unicode-range: U+0102-0103, U+0110-0111, U+1EA0-1EF9, U+20AB; 96 | } 97 | /* latin-ext */ 98 | @font-face { 99 | font-family: 'Source Sans Pro'; 100 | font-style: normal; 101 | font-weight: 
400; 102 | src: local('Source Sans Pro Regular'), local('SourceSansPro-Regular'), url(https://fonts.gstatic.com/s/sourcesanspro/v11/ODelI1aHBYDBqgeIAH2zlIa1YDtoarzwSXxTHggEXMw.woff2) format('woff2'); 103 | unicode-range: U+0100-024F, U+0259, U+1E00-1EFF, U+20A0-20AB, U+20AD-20CF, U+2C60-2C7F, U+A720-A7FF; 104 | } 105 | /* latin */ 106 | @font-face { 107 | font-family: 'Source Sans Pro'; 108 | font-style: normal; 109 | font-weight: 400; 110 | src: local('Source Sans Pro Regular'), local('SourceSansPro-Regular'), url(https://fonts.gstatic.com/s/sourcesanspro/v11/ODelI1aHBYDBqgeIAH2zlJbPFduIYtoLzwST68uhz_Y.woff2) format('woff2'); 111 | unicode-range: U+0000-00FF, U+0131, U+0152-0153, U+02BB-02BC, U+02C6, U+02DA, U+02DC, U+2000-206F, U+2074, U+20AC, U+2122, U+2212, U+2215; 112 | } 113 | /* cyrillic-ext */ 114 | @font-face { 115 | font-family: 'Source Sans Pro'; 116 | font-style: normal; 117 | font-weight: 600; 118 | src: local('Source Sans Pro SemiBold'), local('SourceSansPro-SemiBold'), url(https://fonts.gstatic.com/s/sourcesanspro/v11/toadOcfmlt9b38dHJxOBGGWGG8n76xaP_JUl9houU473rGVtsTkPsbDajuO5ueQw.woff2) format('woff2'); 119 | unicode-range: U+0460-052F, U+1C80-1C88, U+20B4, U+2DE0-2DFF, U+A640-A69F, U+FE2E-FE2F; 120 | } 121 | /* cyrillic */ 122 | @font-face { 123 | font-family: 'Source Sans Pro'; 124 | font-style: normal; 125 | font-weight: 600; 126 | src: local('Source Sans Pro SemiBold'), local('SourceSansPro-SemiBold'), url(https://fonts.gstatic.com/s/sourcesanspro/v11/toadOcfmlt9b38dHJxOBGMP5gXq4cN8pjVji5g2q9Mf3rGVtsTkPsbDajuO5ueQw.woff2) format('woff2'); 127 | unicode-range: U+0400-045F, U+0490-0491, U+04B0-04B1, U+2116; 128 | } 129 | /* greek-ext */ 130 | @font-face { 131 | font-family: 'Source Sans Pro'; 132 | font-style: normal; 133 | font-weight: 600; 134 | src: local('Source Sans Pro SemiBold'), local('SourceSansPro-SemiBold'), url(https://fonts.gstatic.com/s/sourcesanspro/v11/toadOcfmlt9b38dHJxOBGIwxT-R1rCKQkeTtsDWzfjL3rGVtsTkPsbDajuO5ueQw.woff2) 
format('woff2'); 135 | unicode-range: U+1F00-1FFF; 136 | } 137 | /* greek */ 138 | @font-face { 139 | font-family: 'Source Sans Pro'; 140 | font-style: normal; 141 | font-weight: 600; 142 | src: local('Source Sans Pro SemiBold'), local('SourceSansPro-SemiBold'), url(https://fonts.gstatic.com/s/sourcesanspro/v11/toadOcfmlt9b38dHJxOBGFCUBMgATkHAQY-Bv-74xcn3rGVtsTkPsbDajuO5ueQw.woff2) format('woff2'); 143 | unicode-range: U+0370-03FF; 144 | } 145 | /* vietnamese */ 146 | @font-face { 147 | font-family: 'Source Sans Pro'; 148 | font-style: normal; 149 | font-weight: 600; 150 | src: local('Source Sans Pro SemiBold'), local('SourceSansPro-SemiBold'), url(https://fonts.gstatic.com/s/sourcesanspro/v11/toadOcfmlt9b38dHJxOBGMZXFz2iDKd7GJNSaxRYiSj3rGVtsTkPsbDajuO5ueQw.woff2) format('woff2'); 151 | unicode-range: U+0102-0103, U+0110-0111, U+1EA0-1EF9, U+20AB; 152 | } 153 | /* latin-ext */ 154 | @font-face { 155 | font-family: 'Source Sans Pro'; 156 | font-style: normal; 157 | font-weight: 600; 158 | src: local('Source Sans Pro SemiBold'), local('SourceSansPro-SemiBold'), url(https://fonts.gstatic.com/s/sourcesanspro/v11/toadOcfmlt9b38dHJxOBGKyGJhAh-RE0BxGcd_izyev3rGVtsTkPsbDajuO5ueQw.woff2) format('woff2'); 159 | unicode-range: U+0100-024F, U+0259, U+1E00-1EFF, U+20A0-20AB, U+20AD-20CF, U+2C60-2C7F, U+A720-A7FF; 160 | } 161 | /* latin */ 162 | @font-face { 163 | font-family: 'Source Sans Pro'; 164 | font-style: normal; 165 | font-weight: 600; 166 | src: local('Source Sans Pro SemiBold'), local('SourceSansPro-SemiBold'), url(https://fonts.gstatic.com/s/sourcesanspro/v11/toadOcfmlt9b38dHJxOBGMzFoXZ-Kj537nB_-9jJhlA.woff2) format('woff2'); 167 | unicode-range: U+0000-00FF, U+0131, U+0152-0153, U+02BB-02BC, U+02C6, U+02DA, U+02DC, U+2000-206F, U+2074, U+20AC, U+2122, U+2212, U+2215; 168 | } -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Seq2Seq-Vis 2 
| [![CircleCI](https://circleci.com/gh/sgratzl/Seq2Seq-Vis.svg?style=shield)](https://circleci.com/gh/sgratzl/Seq2Seq-Vis) 3 | [![Docker Pulls](https://img.shields.io/docker/pulls/sgratzl/seq2seq-vis.svg?maxAge=604800)](https://hub.docker.com/r/sgratzl/seq2seq-vis/) 4 | [![License](https://img.shields.io/badge/License-Apache%202.0-blue.svg)](https://opensource.org/licenses/Apache-2.0) 5 | [![Latest Release](https://img.shields.io/github/release/HendrikStrobelt/Seq2Seq-Vis/all.svg) ](https://github.com/HendrikStrobelt/Seq2Seq-Vis/releases) 6 | 7 | 8 | ### A visual debugging tool for Sequence-to-Sequence models 9 | *by IBM Research in Cambridge and Harvard SEAS -- more info: [seq2seq-vis.io](http://seq2seq-vis.io)* 10 | 11 | ![Seq2Seq-Vis](docs/pics/s2s_teaser.png) 12 | 13 | 14 | - [Seq2Seq-Vis](#seq2seq-vis) 15 | * [Install and run with `conda`](#install-and-run-with-conda) 16 | + [1 - Install dependencies (server and client) and create virtual environment](#1---install-dependencies-server-and-client-and-create-virtual-environment) 17 | + [2 - Install custom OpenNMT-py version](#2---install-custom-opennmt-py-version) 18 | + [3 - Download some example data](#3---download-some-example-data) 19 | + [4 - Run the system](#4----run-the-system) 20 | * [Install and run with `docker`](#install-and-run-with-docker) 21 | * [Prepare and run own models](#prepare-and-run-own-models) 22 | + [1 - Prepare your data](#1---prepare-your-data) 23 | + [2 - Create a `s2s.yaml` file to describe project](#2---create-a-s2syaml-file-to-describe-project) 24 | + [3 - Command Line Parameters](#3---command-line-parameters) 25 | * [Change frontend](client/README.md) 26 | - [Cite us](#cite-us) 27 | - [Contributors](#contributors) 28 | - [License](#license) 29 | 30 | ## Install and run with `conda` 31 | 32 | We require using [miniconda](https://conda.io/docs/user-guide/install/index.html) to create a virtual environment and install all dependencies via scripts. 
33 | Seq2Seq-Vis currently works with a special version of OpenNMT-py, modified by [Sebastian Gehrmann](https://github.com/sebastianGehrmann/OpenNMT-py/tree/states_in_translation). We provide a script to install this special branch. 34 | 35 | After installation, you should have a file structure like this: 36 | ``` 37 | MyS2S/Seq2Seq-Vis ==> the tool 38 | MyS2S/Seq2Seq-Vis/0316-fakedates/ ==> example data 39 | MyS2S/OpenNMT-py ==> modified OpenNMT 40 | ``` 41 | 42 | ### 1 - Install dependencies (server and client) and create virtual environment 43 | 44 | 45 | Create a root directory (`MyS2S`) and then: 46 | ```bash 47 | git clone https://github.com/HendrikStrobelt/Seq2Seq-Vis.git 48 | cd Seq2Seq-Vis 49 | ``` 50 | 51 | and run in `/Seq2Seq-Vis`: 52 | 53 | ```bash 54 | source setup_cpu.sh 55 | ``` 56 | 57 | ### 2 - Install custom OpenNMT-py version 58 | 59 | Return to the root directory: 60 | ```bash 61 | cd .. 62 | source Seq2Seq-Vis/setup_onmt_custom.sh 63 | ``` 64 | 65 | ### 3 - Download some example data 66 | Here we provide some example data for a character-based dataset that converts date strings (e.g. "March 03, 1999", "03/03/99") into a base form "mm-dd-yyyy". [Download here ~177MB](https://drive.google.com/file/d/1myjJ-surrO76ImnLd4MMJ0-527Ss2e0V/view?usp=sharing), save it to `/Seq2Seq-Vis` and unzip: 67 | 68 | ```bash 69 | unzip fakedates.zip 70 | ``` 71 | 72 | ### 4 - Run the system 73 | 74 | ```bash 75 | python3 server.py --dir 0316-fakedates/ 76 | ``` 77 | Go here: [http://localhost:8080/client/index.html?in=M a r c h _ 0 3 , 1 9 9 9](http://localhost:8080/client/index.html?in=M%20a%20r%20c%20h%20_%200%203%20,%20%201%209%209%209) 78 | 79 | You should see: 80 | 81 | 82 | 83 | Enjoy exploring! 84 | 85 | 86 | 87 | ## Install and run with `docker` 88 | 89 | Thanks to [Samuel Gratzl](https://github.com/sgratzl/Seq2Seq-Vis) for contributing a docker configuration and [image](https://hub.docker.com/r/sgratzl/seq2seq-vis/). 90 | Here are the steps: 91 | 92 | 1. 
pull image: `docker pull sgratzl/seq2seq-vis` 93 | 2. download data [Download here ~177MB](https://drive.google.com/file/d/1myjJ-surrO76ImnLd4MMJ0-527Ss2e0V/view?usp=sharing) 94 | and unzip: `unzip fakedates.zip` 95 | 3. run container with bound data:
`docker run --rm -it -v "${PWD}/0316-fakedates:/data" -p "8080:8080" sgratzl/seq2seq-vis` 96 | 97 | 98 | 99 | 100 | ## Prepare and run own models 101 | 102 | ### 1 - Prepare your data 103 | You can use any model trained with OpenNMT-py to extract your own data. To gain access to the extraction scripts, follow the instructions above to install the modified OpenNMT-py version. 104 | 105 | 106 | First, create a folder `s2s` that will be used to save all the extractions by calling `mkdir s2s`. 107 | 108 | Then, call 109 | ``` 110 | python extract_context.py -src $your_input_file \ 111 | -tgt $your_target_file \ 112 | -model $your_model.pt \ 113 | -gpu $your_GPU_id (can be ignored for CPU extraction) \ 114 | -batch_size $your_batch_size 115 | 116 | ``` 117 | You can customize the maximum sequence lengths by setting `max_src_len` and `max_tgt_len` in the script. If you want to restrict the number of examples in your state file, you can uncomment the following lines and set the limit to your desired size: 118 | ``` 119 | # if bcounter > 100: 120 | # break 121 | ``` 122 | 123 | The script creates a file in the location `s2s/states.h5`. This file is what you need to create the indices for searching. 124 | 125 | The script that builds the indices is located in this repository at `scripts/h5_to_faiss.py`. 126 | Call it three times (once for each type of state) with the parameters 127 | ``` 128 | -states s2s/states.h5 # Your states file location 129 | -data [decoder_out, encoder_out, cstar] # The three datasets within the states h5 file 130 | -output $your_index_name # We recommend just naming them decoder.faiss, encoder.faiss, and context.faiss 131 | -stepsize 100 # you can increase this, this is the number of batches it will add to the index at once. 
It is bottlenecked by your memory 132 | ``` 133 | 134 | To generate the dictionary and embedding files, modify [this](https://github.com/sebastianGehrmann/OpenNMT-py/blob/states_in_translation/VisServer.py#L369) line with the location of your model and call 135 | 136 | ``` 137 | python VisServer.py 138 | ``` 139 | This will also test that your model works with our server as it calls the same API. The script will create three files: 140 | 141 | - s2s/embs.h5 142 | - s2s/src.dict 143 | - s2s/tgt.dict 144 | 145 | 146 | ### 2 - Create a `s2s.yaml` file to describe project 147 | 148 | ```yaml 149 | # -- minimal config 150 | model: date_acc_100.00_ppl_1.00_e7.pt # model file 151 | dicts: 152 | src: src.dict # source dictionary file 153 | tgt: tgt.dict # target dictionary file 154 | embeddings: embs.h5 # word embeddings for src and tgt 155 | train: train.h5 # training data 156 | 157 | # -- OPTIONAL: FAISS indices for Neighborhoods 158 | indexType: faiss # index type should be 'faiss' (or 'annoy') 159 | indices: 160 | decoder: decoder.faiss # index for decoder states 161 | encoder: encoder.faiss # index for encoder states 162 | 163 | # -- OPTIONAL: model for linear projection 164 | project_model: linear_projection.pkl # pickled scikit-learn model 165 | ``` 166 | 167 | ### 3 - Command Line Parameters 168 | 169 | ``` 170 | usage: server.py [-h] [--nodebug NODEBUG] [--port PORT] 171 | [--dir DIR] 172 | 173 | optional arguments: 174 | --nodebug TRUE if not in debug mode 175 | --port port to run system (default: 8080) 176 | --dir directory with s2s.yaml file 177 | ``` 178 | 179 | # Cite us 180 | 181 | ``` 182 | @ARTICLE{seq2seqvisv1, 183 | author = {{Strobelt}, H. and {Gehrmann}, S. and {Behrisch}, M. and {Perer}, A. and {Pfister}, H. 
and {Rush}, A.~M.}, 184 | title = "{Seq2Seq-Vis: A Visual Debugging Tool for Sequence-to-Sequence Models}", 185 | journal = {ArXiv e-prints}, 186 | archivePrefix = "arXiv", 187 | eprint = {1804.09299v1}, 188 | primaryClass = "cs.CL", 189 | keywords = {Computer Science - Computation and Language, Computer Science - Artificial Intelligence, Computer Science - Neural and Evolutionary Computing}, 190 | year = 2018, 191 | month = April 192 | } 193 | ``` 194 | 195 | # Contributors 196 | 197 | - Hendrik Strobelt (IBM Research & MIT-IBM Watson AI Lab) 198 | - Sebastian Gehrmann (Harvard NLP) 199 | - Alexander M. Rush (Harvard NLP) 200 | 201 | - Michael Behrisch (Harvard VCG), Adam Perer (IBM Research), Hanspeter Pfister (Harvard VCG) 202 | - PR #16 signed-off-by: Samuel Gratzl 203 | 204 | # License 205 | 206 | Seq2Seq-Vis is licensed under Apache 2 license. 207 | -------------------------------------------------------------------------------- /client/ts/vis/StatePictograms.ts: -------------------------------------------------------------------------------- 1 | import {VComponent} from "./VisualComponent"; 2 | import {D3Sel} from "../etc/LocalTypes"; 3 | import {SimpleEventHandler} from "../etc/SimpleEventHandler"; 4 | import { 5 | StateProjector, 6 | StateProjectorClickEvent, 7 | StateProjectorData 8 | } from "./StateProjector"; 9 | import {range} from "lodash"; 10 | 11 | 12 | export type PointSegment = { 13 | x: number, y: number, ox: number, oy: number, 14 | id: number, loc: string, transID: number, wordID: number, 15 | ow: number, oh: number, word: string 16 | }; 17 | 18 | export type StatePictogramsHovered = { 19 | caller: StatePictograms, 20 | segment: PointSegment, 21 | hovered: boolean 22 | } 23 | 24 | export class StatePictograms extends VComponent { 25 | 26 | css_name = 'statepictograms'; 27 | 28 | static events = { 29 | segmentHovered: 'state_picto_hovered' 30 | }; 31 | 32 | protected hiddenCanvas: D3Sel = null; 33 | 34 | protected options = { 35 | pos: {x: 0, y: 
0}, 36 | // canvas: {w: 500, h: 500}, 37 | gridElements: 5 38 | }; 39 | 40 | colors = { 41 | bg: '#fffdfa', 42 | pp: {c: 'black', a: 0.3}, 43 | pl: {w: 2, a: .3, c: ['#4e9c57', '#9c6d9b']}, 44 | select: {w: 2, c: '#9c605f'} 45 | } 46 | 47 | constructor(d3Parent: D3Sel, protected projector: StateProjector, eventHandler?: SimpleEventHandler, options: {} = {}) { 48 | super(d3Parent, eventHandler); 49 | 50 | this.superInit(options, false, true, false); 51 | } 52 | 53 | 54 | protected _init() { 55 | const pOp = this.projector.current; 56 | 57 | this.hiddenCanvas = this.parent.append('canvas').attrs({ 58 | width: pOp.project.w, 59 | height: pOp.project.h, 60 | }) 61 | .style('display', 'none') 62 | 63 | 64 | } 65 | 66 | protected _render(renderData): void { 67 | const pCur = this.projector.current; 68 | const states = this.projector.states; 69 | const col = this.colors; 70 | 71 | const vContext = this.hiddenCanvas.node().getContext('2d'); 72 | 73 | vContext.fillStyle = col.bg; 74 | vContext.fillRect(0, 0, pCur.project.w, pCur.project.h); 75 | 76 | vContext.globalAlpha = col.pp.a; 77 | vContext.fillStyle = col.pp.c; 78 | for (const st of states) { 79 | vContext.beginPath(); 80 | vContext.arc(pCur.xScale(st.pos[0]), pCur.yScale(st.pos[1]), 81 | Math.sqrt(st.occ.length), 0, 2 * Math.PI) 82 | vContext.fill(); 83 | } 84 | 85 | 86 | const gridScale = this.options.gridElements / pCur.project.w; 87 | 88 | const panelWidth = pCur.project.w / this.options.gridElements; 89 | const panelWidthCeil = Math.ceil(panelWidth); 90 | 91 | 92 | const lineSeqs: PointSegment[][] = []; 93 | const pivots = pCur.pivots; 94 | for (const transID of range(pivots.length)) { 95 | const line = pivots[transID]; 96 | 97 | const lineSeq: PointSegment[] = []; 98 | 99 | if (line.length > 0) { 100 | vContext.globalAlpha = col.pl.a; 101 | vContext.strokeStyle = col.pl.c[transID]; 102 | vContext.lineWidth = col.pl.w; 103 | vContext.beginPath(); 104 | for (const pID of range(line.length)) { 105 | const pp = 
line[pID].pos; 106 | if (pID == 0) vContext.moveTo(pCur.xScale(pp[0]), pCur.yScale(pp[1])); 107 | else vContext.lineTo(pCur.xScale(pp[0]), pCur.yScale(pp[1])) 108 | } 109 | vContext.stroke(); 110 | 111 | const wordList = pCur.hasLabels ? this.projector.labels[transID] || [] : [] 112 | 113 | vContext.fillStyle = col.pl.c[transID]; 114 | let wordID = 0; 115 | for (const point of line) { 116 | const ox = pCur.xScale(point.pos[0]); 117 | const oy = pCur.yScale(point.pos[1]); 118 | 119 | const word = wordList[wordID] || '' + wordID; 120 | lineSeq.push({ 121 | ox, oy, 122 | x: Math.floor(ox * gridScale) / gridScale, 123 | y: Math.floor(oy * gridScale) / gridScale, 124 | id: point.id, 125 | loc: this.projector.loc, 126 | transID, 127 | wordID, 128 | ow: panelWidthCeil, 129 | oh: panelWidthCeil, 130 | word 131 | }); 132 | 133 | vContext.beginPath(); 134 | vContext.arc(ox, oy, 5, 0, 2 * Math.PI); 135 | vContext.fill() 136 | 137 | wordID += 1; 138 | } 139 | } 140 | 141 | 142 | lineSeqs.push(lineSeq); 143 | } 144 | 145 | 146 | let panRow = this.parent.selectAll('.row').data(lineSeqs); 147 | panRow.exit().remove(); 148 | 149 | panRow = panRow.enter() 150 | .append('div').attr('class', 'row') 151 | .merge(panRow); 152 | 153 | let pCanvasFrame = panRow.selectAll('.pCanvasFrame').data(d => d); 154 | pCanvasFrame.exit().remove(); 155 | 156 | // noinspection JSSuspiciousNameCombination 157 | const pCanvasFrameEnter = pCanvasFrame.enter().append('g').attr('class', 'pCanvasFrame'); 158 | pCanvasFrameEnter.append('canvas').attrs({ 159 | class: 'pCanvas', 160 | width: panelWidthCeil, 161 | height: panelWidthCeil 162 | }); 163 | pCanvasFrameEnter.append('div') 164 | pCanvasFrame = pCanvasFrameEnter.merge(pCanvasFrame); 165 | 166 | 167 | const that = this; 168 | pCanvasFrame.select('.pCanvas').each(function (d: PointSegment) { 169 | const ctx: CanvasRenderingContext2D = (this) 170 | .getContext('2d'); 171 | 172 | ctx.fillStyle = 'black'; 173 | // ctx.fillRect(2,2,10,10); 174 | 175 | 
ctx.drawImage(that.hiddenCanvas.node(), 176 | d.x, d.y, panelWidth, panelWidth, 177 | 0, 0, panelWidth, panelWidth) 178 | 179 | ctx.beginPath(); 180 | ctx.strokeStyle = col.select.c; 181 | ctx.lineWidth = col.select.w; 182 | ctx.arc(d.ox - d.x, d.oy - d.y, 5, 0, 2 * Math.PI); 183 | ctx.stroke(); 184 | 185 | }); 186 | 187 | pCanvasFrame.select('div').text(d => d.word) 188 | 189 | 190 | pCanvasFrame.on('mouseenter', (d: PointSegment) => { 191 | const details: StatePictogramsHovered = { 192 | caller: this, 193 | segment: d, 194 | hovered: true 195 | }; 196 | this.eventHandler.trigger(StatePictograms.events.segmentHovered, details) 197 | }); 198 | pCanvasFrame.on('mouseleave', (d: PointSegment) => { 199 | const details: StatePictogramsHovered = { 200 | caller: this, 201 | segment: d, 202 | hovered: false 203 | }; 204 | this.eventHandler.trigger(StatePictograms.events.segmentHovered, details) 205 | }); 206 | 207 | pCanvasFrame.on('click', d => { 208 | const neighborIDs = this.projector.myNeighbors(d.transID, d.wordID).map(nn => nn.id); 209 | 210 | const detail: StateProjectorClickEvent = { 211 | caller: this.projector, 212 | loc: this.projector.loc, 213 | pointIDs: [d.id], 214 | neighborIDs 215 | }; 216 | 217 | this.eventHandler 218 | .trigger(StateProjector.events.clicked, detail); 219 | 220 | 221 | }) 222 | 223 | 224 | // vContext.closePath(); 225 | 226 | 227 | } 228 | 229 | actionHighlightSegment(transID: number, wordID: number, hovered: boolean) { 230 | if (hovered) { 231 | this.parent.selectAll('.pCanvasFrame') 232 | .classed('selected', 233 | (d: PointSegment) => d.wordID === wordID && d.transID === transID); 234 | } else { 235 | this.parent.selectAll('.pCanvasFrame') 236 | .classed('selected', false) 237 | } 238 | 239 | } 240 | 241 | 242 | protected _wrangle(data) { 243 | const pOp = this.projector.current; 244 | 245 | this.hiddenCanvas.attrs({ 246 | width: pOp.project.w, 247 | height: pOp.project.h, 248 | }) 249 | 250 | 251 | } 252 | 253 | } 254 | 
-------------------------------------------------------------------------------- /client/ts/vis/WordLine.ts: -------------------------------------------------------------------------------- 1 | import * as d3 from 'd3'; 2 | import {VComponent} from "./VisualComponent"; 3 | import {SVGMeasurements} from "../etc/SVGplus"; 4 | import {SimpleEventHandler} from "../etc/SimpleEventHandler"; 5 | import {D3Sel, LooseObject} from "../etc/LocalTypes"; 6 | 7 | 8 | enum BoxType {fixed, flow} 9 | 10 | export type WordLineHoverEvent = { 11 | hovered: boolean, 12 | caller: WordLine, 13 | word: LooseObject, 14 | row: number, 15 | col: number, 16 | css_class_main: string 17 | } 18 | 19 | 20 | export interface WordLineData { 21 | /** 22 | * rows (outer) of words (inner) 23 | */ 24 | wordRows: string[][], 25 | wordFill?: string[][], 26 | wordBorder?: string[][], 27 | boxWidth?: number[][], 28 | } 29 | 30 | type WordToken = { text: string, width: number, realWidth?: number } 31 | 32 | // internal use 33 | type WordLineRender = { 34 | rows: WordToken[][], 35 | positions: number[][], 36 | wordFill?: string[][], 37 | wordBorder?: string[][], 38 | boxWidth?: number[][], 39 | } 40 | 41 | type WordCell = { row: number, col: number, word: WordToken }; 42 | 43 | export class WordLine extends VComponent { 44 | 45 | css_name = 'wordline'; 46 | 47 | static events = { 48 | wordHovered: 'wordline_word_hovered', 49 | wordSelected: 'wordline_word_selected' 50 | }; 51 | 52 | static BoxType = BoxType; 53 | 54 | options = { 55 | pos: {x: 0, y: 0}, 56 | text_measurer: null, 57 | box_height: 23, 58 | box_width: 100, // ignored when flow !! 
59 | box_type: WordLine.BoxType.flow, 60 | // data_access: (d) => [d.encoder], // [list of [lists of words]] 61 | css_class_main: 'inWord', 62 | css_class_add: '', 63 | x_offset: 3 64 | }; 65 | 66 | /** 67 | * @inheritDoc 68 | * @override 69 | * @return {Array} 70 | */ 71 | 72 | //-- default constructor -- 73 | constructor(d3Parent: D3Sel, eventHandler?: SimpleEventHandler, options: {} = {}) { 74 | super(d3Parent, eventHandler); 75 | this.superInit(options); 76 | } 77 | 78 | _init() { 79 | this.options.text_measurer = this.options.text_measurer 80 | || new SVGMeasurements(this.parent, 'measureWord'); 81 | } 82 | 83 | 84 | _wrangle(data: WordLineData): WordLineRender { 85 | const op = this.options; 86 | 87 | 88 | let rows = []; 89 | 90 | 91 | const toWordFlow = token => ({ 92 | text: token, 93 | width: Math.max(op.text_measurer.textLength(token), 20) 94 | }); 95 | const toWordFixed = token => ({ 96 | text: token, 97 | width: op.box_width - 10, 98 | realWidth: op.text_measurer.textLength(token) 99 | }); 100 | 101 | if (op.box_type === WordLine.BoxType.fixed) { 102 | rows = data.wordRows.map(row => 103 | row.map(w => toWordFixed(w))) 104 | } else { 105 | rows = data.wordRows.map(row => 106 | row.map(w => toWordFlow(w))) 107 | } 108 | 109 | 110 | const allLengths = []; 111 | const calcPos = words => { 112 | let inc = 0; 113 | const rr = [...words.map(w => { 114 | const res = inc; 115 | inc += +w.width + 10; 116 | return res 117 | })]; 118 | allLengths.push(inc); 119 | return rr; 120 | }; 121 | 122 | // todo: merge with data 123 | const positions = rows.map(row => calcPos(row)); 124 | 125 | this.parent.attrs({ 126 | width: d3.max(allLengths) + 6, 127 | height: rows.length * (op.box_height) - 2 128 | }); 129 | 130 | this._current.selectedWord = null; 131 | this._current.clearSelections = true; 132 | 133 | this._current.customFill = data.wordFill != null; 134 | this._current.customBorder = data.wordBorder != null; 135 | this._current.customBoxWidth = data.boxWidth != 
null; 136 | 137 | return { 138 | rows, 139 | positions, 140 | wordFill: data.wordFill, 141 | wordBorder: data.wordBorder, 142 | boxWidth: data.boxWidth 143 | }; 144 | 145 | } 146 | 147 | private hoverWord(d: WordCell, i, hovered) { 148 | const detail = { 149 | hovered, 150 | caller: this, 151 | word: d, 152 | row: d.row, 153 | col: d.col, 154 | css_class_main: this.options.css_class_main 155 | } 156 | 157 | this.eventHandler.trigger( 158 | WordLine.events.wordHovered, detail) 159 | } 160 | 161 | private clickWord(d: WordCell, i) { 162 | const hovered = !(this._current.selectedWord === i); 163 | this._current.selectedWord = hovered ? i : null; 164 | 165 | const detail = { 166 | hovered, 167 | caller: this, 168 | row: d.row, 169 | word: d, 170 | col: d.col, 171 | css_class_main: this.options.css_class_main 172 | }; 173 | this.eventHandler.trigger(WordLine.events.wordSelected, detail) 174 | 175 | } 176 | 177 | 178 | actionHighlightWord(row: number, col: number, highlight: boolean, exclusive = false, label = 'highlight'): void { 179 | 180 | // console.log(this.options.css_class_main, this.base.selectAll(`.${this.options.css_class_main}`), "--- this.options.css_class_main, this.base.selectAll(`.${this.options.css_class_main}`)"); 181 | // console.log(row, highlight, exclusive, label, "--- word,highlight,exclusive,label"); 182 | this.base.selectAll(`.${this.options.css_class_main}`) 183 | .classed(label, function (d: WordCell) { 184 | if ((d.row === row) && (d.col === col)) { 185 | return highlight; 186 | } else { 187 | if (exclusive) return false; 188 | else return d3.select(this).classed(label) 189 | } 190 | }) 191 | 192 | } 193 | 194 | _render(render: WordLineRender) { 195 | const op = this.options; 196 | const that = this; 197 | 198 | // [rows of [words of {wordRect, wordText}]] 199 | 200 | let rows = this.base.selectAll('.word_row').data(render.rows); 201 | rows.exit().remove(); 202 | rows = rows.enter() 203 | .append('g').attr('class', 'word_row') 204 | 
.merge(rows) 205 | .attr('transform', (_, i) => `translate(${op.x_offset},${(i) * (op.box_height)})`); 206 | 207 | let words = rows.selectAll(`.${op.css_class_main}`) 208 | .data((row, rowID) => row.map((word, col) => ({ 209 | row: rowID, 210 | word, 211 | col 212 | }))); 213 | words.exit().remove(); 214 | 215 | const wordsEnter = words.enter() 216 | .append('g').attr('class', `${op.css_class_main} ${op.css_class_add}`); 217 | wordsEnter.append('rect').attrs({ 218 | x: -3, 219 | y: 0, 220 | height: op.box_height - 2, 221 | rx: 3, 222 | ry: 3 223 | }); 224 | wordsEnter.append('text'); 225 | 226 | 227 | /**** UPDATE ***/ 228 | const allWords = wordsEnter.merge(words) 229 | .attrs({'transform': (w: any, i) => `translate(${render.positions[w.row][i]},0)`,}) 230 | .on('mouseenter', (d, i) => this.hoverWord(d, i, true)) 231 | .on('mouseout', (d, i) => this.hoverWord(d, i, false)) 232 | .on('click', (d, i) => this.clickWord(d, i)); 233 | 234 | 235 | const allR = allWords.select('rect'); 236 | if (this._current.customBoxWidth) { 237 | allR.attr('width', (d: any, i) => render.boxWidth[d.row][i] + 6); 238 | } else { 239 | allR.attr('width', (d: any) => d.word.width + 6); 240 | } 241 | 242 | if (this._current.customFill) { 243 | allR.style('fill', (d: any, i) => render.wordFill[d.row][i]) 244 | } 245 | 246 | if (this._current.customBorder) { 247 | allR.style('stroke', (d: any, i) => render.wordBorder[d.row][i]) 248 | } 249 | 250 | allWords.select('text').attr('transform', (d: any) => { 251 | const w = d.word; 252 | if (op.box_type === WordLine.BoxType.fixed 253 | && w.width < w.realWidth && w.realWidth > 0) 254 | return `translate(${d.word.width * .5},${Math.floor(op.box_height / 2)})scale(${w.width / w.realWidth},1)` 255 | else 256 | return `translate(${d.word.width * .5},${Math.floor(op.box_height / 2)})` 257 | }).text((d: any) => d.word.text); 258 | 259 | 260 | if (this._current.clearSelections) { 261 | this.actionHighlightWord(-1, -1, false, 262 | true, 'selected'); 263 | 
this._current.clearSelections = false; 264 | } 265 | 266 | 267 | } 268 | 269 | get positions() { 270 | return this.renderData.positions; 271 | } 272 | 273 | get rows(): WordToken[][] { 274 | return this.renderData.rows; 275 | } 276 | 277 | get firstRowPlainWords() { 278 | return this.renderData.rows[0].map(word => word.text) 279 | } 280 | 281 | 282 | } 283 | -------------------------------------------------------------------------------- /client/css/vis.scss: -------------------------------------------------------------------------------- 1 | @import "global"; 2 | 3 | $encoder-color-ref: #4363aa; 4 | //$encoder-color: #5ba2ec; 5 | $encoder-color: #98b7d9; 6 | $decoder-color: #f5de93; 7 | $attn-color: #6c7067; 8 | $highlight-color: #f0454a; 9 | 10 | $pivot-color: #4e9c57; 11 | $compare-color: #9c6d9b; 12 | 13 | .rect_highlight { 14 | stroke: $highlight-color; 15 | stroke-width: 2px; 16 | //stroke-dasharray: 1,1; 17 | } 18 | 19 | .small_line { 20 | fill: none; 21 | stroke: $main-color; 22 | pointer-events: none; 23 | stroke-width: 3px; 24 | } 25 | 26 | .measureWord { 27 | alignment-baseline: middle; 28 | text-anchor: middle; 29 | font-weight: bold; 30 | pointer-events: none; 31 | } 32 | 33 | .sideIndicator { 34 | padding: 5px; 35 | text-align: center; 36 | border-top: 3px solid; 37 | font-weight: bold; 38 | //color: $main-color; 39 | } 40 | 41 | .side_pivot { 42 | //background-color: rgba( $pivot-color, .3); 43 | border-top-color: $pivot-color; 44 | color: $pivot-color; 45 | } 46 | 47 | .side_compare { 48 | //background-color: rgba( $compare-color, .3); 49 | border-top-color: $compare-color; 50 | color: $compare-color; 51 | } 52 | 53 | .inWord { 54 | text { 55 | @extend .measureWord 56 | } 57 | rect { 58 | fill: $encoder-color; 59 | //fill-opacity: .9; 60 | } 61 | 62 | &.selected { 63 | text { 64 | fill: #ffffff; 65 | } 66 | 67 | rect { 68 | fill: #f0454a; 69 | fill-opacity: 1; 70 | } 71 | } 72 | 73 | &.highlight rect { 74 | @extend .rect_highlight 75 | } 76 | 77 | } 
78 | 79 | .outWord { 80 | text { 81 | @extend .measureWord 82 | } 83 | rect { 84 | fill: $decoder-color; 85 | //fill-opacity: 1; 86 | } 87 | 88 | &.highlight rect { 89 | @extend .rect_highlight 90 | } 91 | 92 | &.selected { 93 | text { 94 | fill: #fff; 95 | } 96 | 97 | rect { 98 | fill: #f0454a; 99 | fill-opacity: 1; 100 | } 101 | } 102 | } 103 | 104 | .topKWord { 105 | text { 106 | @extend .measureWord 107 | } 108 | rect { 109 | fill: #cccccc; 110 | //opacity: .6; 111 | } 112 | &.highlight rect { 113 | @extend .rect_highlight 114 | } 115 | } 116 | 117 | .attn_graph { 118 | path { 119 | fill: none; 120 | stroke: $attn-color; 121 | opacity: .3; 122 | //transition: .5s; 123 | } 124 | &.highlight path { 125 | opacity: .7; 126 | stroke: $highlight-color; 127 | //transition: .5s; 128 | } 129 | } 130 | 131 | .state_line { 132 | 133 | fill: none; 134 | stroke: grey; 135 | opacity: .1; 136 | 137 | &:hover { 138 | opacity: .7; 139 | stroke: blue; 140 | transition: .5s; 141 | } 142 | } 143 | 144 | .state_axis { 145 | line, path { 146 | stroke: lightgrey; 147 | } 148 | text { 149 | fill: grey; 150 | } 151 | } 152 | 153 | .setup { 154 | .bar { 155 | fill: #d4d4d4; 156 | } 157 | } 158 | 159 | .wordprojector { 160 | //text { 161 | //dominant-baseline: middle; 162 | text-anchor: middle; 163 | opacity: .8; 164 | font-size: 8pt; 165 | font-weight: bold; 166 | 167 | rect { 168 | fill: lightgrey; //$main-color; 169 | } 170 | 171 | .query { 172 | rect { 173 | fill: none; 174 | 175 | } 176 | } 177 | 178 | .word { 179 | text { 180 | pointer-events: none; 181 | dominant-baseline: middle; 182 | } 183 | 184 | rect:hover { 185 | @extend .rect_highlight 186 | } 187 | 188 | &.selected { 189 | text { 190 | fill: lightgray; 191 | } 192 | 193 | rect { 194 | fill: $encoder-color-ref !important; 195 | fill-opacity: 1 !important; 196 | 197 | } 198 | } 199 | 200 | } 201 | 202 | //} 203 | } 204 | 205 | .close_word_list { 206 | alignment-baseline: middle; 207 | font-size: 15px; 208 | font-weight: 
bold; 209 | .word { 210 | opacity: .6; 211 | &:hover { 212 | opacity: 1 213 | } 214 | 215 | } 216 | 217 | .scoreBar { 218 | rect { 219 | fill: #6d859c; 220 | } 221 | 222 | } 223 | 224 | .distBar { 225 | rect { 226 | fill: #ffc466; 227 | } 228 | 229 | } 230 | .barText { 231 | fill: #4a4a4a; 232 | opacity: 1; 233 | font-size: 9px; 234 | font-weight: bold; 235 | alignment-baseline: middle; 236 | } 237 | 238 | .wordComp { 239 | font-weight: 200; 240 | } 241 | } 242 | 243 | .neighborState { 244 | rect { 245 | fill: #d3d3d3 246 | } 247 | .line { 248 | stroke: #1d5e85; 249 | stroke-opacity: .2; 250 | } 251 | 252 | } 253 | 254 | .stateprojector { 255 | .pp { 256 | opacity: .3; 257 | 258 | &:hover { 259 | opacity: .9; 260 | } 261 | &.selected { 262 | stroke: $highlight-color; 263 | stroke-width: 2px; 264 | opacity: 1; 265 | } 266 | 267 | } 268 | .pl { 269 | fill: none; 270 | stroke: $pivot-color; 271 | stroke-width: 2px; 272 | stroke-opacity: .4; 273 | &.pl_1 { 274 | stroke: $compare-color; 275 | } 276 | 277 | } 278 | 279 | .plPoint { 280 | fill: $pivot-color; 281 | fill-opacity: .4; 282 | &.pl_1 { 283 | fill: $compare-color; 284 | } 285 | 286 | &:hover { 287 | stroke: $highlight-color; 288 | } 289 | 290 | &.selected { 291 | stroke: $highlight-color; 292 | fill-opacity: .6; 293 | stroke-width: 2px; 294 | } 295 | } 296 | 297 | .plLabel { 298 | text-anchor: middle; 299 | pointer-events: none; 300 | font-weight: bold; 301 | opacity: .5; 302 | } 303 | 304 | .endPoint { 305 | fill: #9c605f 306 | } 307 | .startPoint { 308 | fill: #7d9c5f 309 | } 310 | 311 | .hoverLine { 312 | fill: none; 313 | stroke: $highlight-color; 314 | opacity: .1; 315 | stroke-width: 2px; 316 | pointer-events: none; 317 | 318 | } 319 | 320 | .highlightRect { 321 | stroke: $highlight-color; 322 | stroke-width: 1px; 323 | stroke-dasharray: 1, 2; 324 | fill: none; 325 | stroke-opacity: .5; 326 | } 327 | 328 | } 329 | 330 | .info_panel { 331 | 332 | .translation { 333 | .fa { 334 | color: lightgrey; 335 | 336 
| &.selected { 337 | color: #4a4a4a; 338 | } 339 | &:hover { 340 | color: black; 341 | } 342 | 343 | } 344 | 345 | font-weight: bold; 346 | margin-bottom: 8px; 347 | } 348 | 349 | .menu { 350 | //padding-top: 10px; 351 | padding-bottom: 10px; 352 | //margin-top: 5px; 353 | //margin-bottom: 5px; 354 | //background-color: lightgrey; 355 | 356 | .show_src { 357 | //background-color: #adc2d9; 358 | border: 2px $encoder-color solid; 359 | 360 | &.selected { 361 | background-color: $encoder-color; 362 | 363 | } 364 | 365 | } 366 | .show_tgt { 367 | border: 2px $decoder-color solid; 368 | 369 | &.selected { 370 | background-color: $decoder-color; 371 | 372 | } 373 | } 374 | 375 | .show_btn { 376 | padding: 6px 8px; // to adjust for borders !! 377 | } 378 | 379 | .offset_btn { 380 | //background-color: lightgrey; 381 | &.selected { 382 | background-color: $highlight-color; 383 | } 384 | } 385 | 386 | } 387 | 388 | .src { 389 | padding: 3px 3px 3px 0; 390 | border: $encoder-color solid 1px; 391 | margin-bottom: 1px; 392 | &::before { 393 | content: ""; 394 | padding: 3px 3px 3px 12px; 395 | margin-right: 3px; 396 | background-color: $encoder-color; 397 | } 398 | } 399 | .tgt { 400 | padding: 3px 3px 3px 0; 401 | border: $decoder-color solid 1px; 402 | &::before { 403 | content: ""; 404 | padding: 3px 3px 3px 12px; 405 | margin-right: 3px; 406 | background-color: $decoder-color; 407 | } 408 | } 409 | 410 | .highlight { 411 | padding: 1px 4px 1px 4px; 412 | color: white; 413 | background-color: rgba($highlight-color, .7); 414 | //border: $highlight-color 2px solid; 415 | 416 | } 417 | } 418 | 419 | .pCanvasFrame { 420 | display: inline-block; 421 | padding: 5px; 422 | text-align: center; 423 | 424 | .pCanvas { 425 | border: #d3d3d3 2px solid; 426 | border-radius: 5px; 427 | } 428 | div { 429 | font-weight: bold; 430 | } 431 | 432 | &.selected { 433 | 434 | .pCanvas { 435 | border: $highlight-color 2px solid; 436 | } 437 | div { 438 | color: $highlight-color; 439 | } 440 | 441 
| } 442 | 443 | } 444 | 445 | .beamtreevis { 446 | .node { 447 | font-weight: bold; 448 | //&.topBeam { 449 | // fill: #9c605f; 450 | //} 451 | 452 | } 453 | .link { 454 | fill: none; 455 | stroke: #e6e6e6; 456 | stroke-width: 2px; 457 | &.topBeam { 458 | stroke: #cccccc; 459 | stroke-width: 5px; 460 | } 461 | } 462 | 463 | } 464 | 465 | .btn { 466 | //font-weight: normal; 467 | font-size: 9pt; 468 | padding: 8px 8px; 469 | cursor: pointer; 470 | background-color: lightgrey; 471 | display: inline-block; 472 | border-radius: 5px; 473 | &:hover { 474 | background-color: #c6c6c6; 475 | } 476 | 477 | &.selected { 478 | background-color: #767676; 479 | color: white; 480 | } 481 | 482 | } 483 | 484 | .btn_center { 485 | border-radius: 0; 486 | } 487 | 488 | .btn_left { 489 | border-radius: 5px 0 0 5px; 490 | } 491 | 492 | .btn_right { 493 | border-radius: 0 5px 5px 0; 494 | } 495 | 496 | //.btn_round{ 497 | // border-radius: 5px; 498 | //} 499 | 500 | #projectorPanel { 501 | .menu { 502 | padding-bottom: 10px; 503 | } 504 | //.btn_project_edges, .btn_project_nodes { 505 | // //border: 2px solid; 506 | // &.selected { 507 | // background-color: #767676; 508 | // color: white; 509 | // } 510 | //} 511 | } 512 | 513 | 514 | 515 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 
14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 
47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. 
Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. 
You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 
122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. 
In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "{}" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. 
We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright {yyyy} {name of copyright owner} 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 202 | -------------------------------------------------------------------------------- /scripts/faiss.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2015-present, Facebook, Inc. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the BSD+Patents license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | #@nolint 8 | 9 | # not linting this file because it imports * form swigfaiss, which 10 | # causes a ton of useless warnings. 11 | 12 | import numpy as np 13 | import sys 14 | import inspect 15 | import pdb 16 | 17 | 18 | # we import * so that the symbol X can be accessed as faiss.X 19 | 20 | try: 21 | from swigfaiss_gpu import * 22 | except ImportError as e: 23 | 24 | if e.args[0] != 'ImportError: No module named swigfaiss_gpu': 25 | # swigfaiss_gpu is there but failed to load: Warn user about it. 
26 | sys.stderr.write("Failed to load GPU Faiss: %s\n" % e.args[0]) 27 | sys.stderr.write("Faiss falling back to CPU-only.\n") 28 | from swigfaiss import * 29 | 30 | 31 | ################################################################## 32 | # The functions below add or replace some methods for classes 33 | # this is to be able to pass in numpy arrays directly 34 | # The C++ version of the classnames will be suffixed with _c 35 | ################################################################## 36 | 37 | 38 | def replace_method(the_class, name, replacement, ignore_missing=False): 39 | try: 40 | orig_method = getattr(the_class, name) 41 | except AttributeError: 42 | if ignore_missing: 43 | return 44 | raise 45 | if orig_method.__name__ == 'replacement_' + name: 46 | # replacement was done in parent class 47 | return 48 | setattr(the_class, name + '_c', orig_method) 49 | setattr(the_class, name, replacement) 50 | 51 | 52 | def handle_Clustering(): 53 | def replacement_train(self, x, index): 54 | assert x.flags.contiguous 55 | n, d = x.shape 56 | assert d == self.d 57 | self.train_c(n, swig_ptr(x), index) 58 | replace_method(Clustering, 'train', replacement_train) 59 | 60 | 61 | handle_Clustering() 62 | 63 | 64 | def handle_Quantizer(the_class): 65 | 66 | def replacement_train(self, x): 67 | n, d = x.shape 68 | assert d == self.d 69 | self.train_c(n, swig_ptr(x)) 70 | 71 | def replacement_compute_codes(self, x): 72 | n, d = x.shape 73 | assert d == self.d 74 | codes = np.empty((n, self.code_size), dtype='uint8') 75 | self.compute_codes_c(swig_ptr(x), swig_ptr(codes), n) 76 | return codes 77 | 78 | def replacement_decode(self, codes): 79 | n, cs = codes.shape 80 | assert cs == self.code_size 81 | x = np.empty((n, self.d), dtype='float32') 82 | self.decode_c(swig_ptr(codes), swig_ptr(x), n) 83 | return x 84 | 85 | replace_method(the_class, 'train', replacement_train) 86 | replace_method(the_class, 'compute_codes', replacement_compute_codes) 87 | 
replace_method(the_class, 'decode', replacement_decode) 88 | 89 | 90 | handle_Quantizer(ProductQuantizer) 91 | handle_Quantizer(ScalarQuantizer) 92 | 93 | 94 | def handle_Index(the_class): 95 | 96 | def replacement_add(self, x): 97 | assert x.flags.contiguous 98 | n, d = x.shape 99 | assert d == self.d 100 | self.add_c(n, swig_ptr(x)) 101 | 102 | def replacement_add_with_ids(self, x, ids): 103 | n, d = x.shape 104 | assert d == self.d 105 | assert ids.shape == (n, ), 'not same nb of vectors as ids' 106 | self.add_with_ids_c(n, swig_ptr(x), swig_ptr(ids)) 107 | 108 | def replacement_train(self, x): 109 | assert x.flags.contiguous 110 | n, d = x.shape 111 | assert d == self.d 112 | self.train_c(n, swig_ptr(x)) 113 | 114 | def replacement_search(self, x, k): 115 | n, d = x.shape 116 | assert d == self.d 117 | distances = np.empty((n, k), dtype=np.float32) 118 | labels = np.empty((n, k), dtype=np.int64) 119 | self.search_c(n, swig_ptr(x), 120 | k, swig_ptr(distances), 121 | swig_ptr(labels)) 122 | return distances, labels 123 | 124 | def replacement_search_and_reconstruct(self, x, k): 125 | n, d = x.shape 126 | assert d == self.d 127 | distances = np.empty((n, k), dtype=np.float32) 128 | labels = np.empty((n, k), dtype=np.int64) 129 | recons = np.empty((n, k, d), dtype=np.float32) 130 | self.search_and_reconstruct_c(n, swig_ptr(x), 131 | k, swig_ptr(distances), 132 | swig_ptr(labels), 133 | swig_ptr(recons)) 134 | return distances, labels, recons 135 | 136 | def replacement_remove_ids(self, x): 137 | if isinstance(x, IDSelector): 138 | sel = x 139 | else: 140 | assert x.ndim == 1 141 | sel = IDSelectorBatch(x.size, swig_ptr(x)) 142 | return self.remove_ids_c(sel) 143 | 144 | def replacement_reconstruct(self, key): 145 | x = np.empty(self.d, dtype=np.float32) 146 | self.reconstruct_c(key, swig_ptr(x)) 147 | return x 148 | 149 | def replacement_reconstruct_n(self, n0, ni): 150 | x = np.empty((ni, self.d), dtype=np.float32) 151 | self.reconstruct_n_c(n0, ni, swig_ptr(x)) 
152 | return x 153 | 154 | def replacement_update_vectors(self, keys, x): 155 | n = keys.size 156 | assert keys.shape == (n, ) 157 | assert x.shape == (n, self.d) 158 | self.update_vectors_c(n, swig_ptr(keys), swig_ptr(x)) 159 | 160 | def replacement_range_search(self, x, thresh): 161 | n, d = x.shape 162 | assert d == self.d 163 | res = RangeSearchResult(n) 164 | self.range_search_c(n, swig_ptr(x), thresh, res) 165 | # get pointers and copy them 166 | lims = rev_swig_ptr(res.lims, n + 1).copy() 167 | nd = int(lims[-1]) 168 | D = rev_swig_ptr(res.distances, nd).copy() 169 | I = rev_swig_ptr(res.labels, nd).copy() 170 | return lims, D, I 171 | 172 | replace_method(the_class, 'add', replacement_add) 173 | replace_method(the_class, 'add_with_ids', replacement_add_with_ids) 174 | replace_method(the_class, 'train', replacement_train) 175 | replace_method(the_class, 'search', replacement_search) 176 | replace_method(the_class, 'remove_ids', replacement_remove_ids) 177 | replace_method(the_class, 'reconstruct', replacement_reconstruct) 178 | replace_method(the_class, 'reconstruct_n', replacement_reconstruct_n) 179 | replace_method(the_class, 'range_search', replacement_range_search) 180 | replace_method(the_class, 'update_vectors', replacement_update_vectors, 181 | ignore_missing=True) 182 | replace_method(the_class, 'search_and_reconstruct', 183 | replacement_search_and_reconstruct, ignore_missing=True) 184 | 185 | def handle_VectorTransform(the_class): 186 | 187 | def apply_method(self, x): 188 | assert x.flags.contiguous 189 | n, d = x.shape 190 | assert d == self.d_in 191 | y = np.empty((n, self.d_out), dtype=np.float32) 192 | self.apply_noalloc(n, swig_ptr(x), swig_ptr(y)) 193 | return y 194 | 195 | def replacement_reverse_transform(self, x): 196 | n, d = x.shape 197 | assert d == self.d_out 198 | y = np.empty((n, self.d_in), dtype=np.float32) 199 | self.reverse_transform_c(n, swig_ptr(x), swig_ptr(y)) 200 | return y 201 | 202 | def replacement_vt_train(self, x): 203 
| assert x.flags.contiguous 204 | n, d = x.shape 205 | assert d == self.d_in 206 | self.train_c(n, swig_ptr(x)) 207 | 208 | replace_method(the_class, 'train', replacement_vt_train) 209 | # apply is reserved in Pyton... 210 | the_class.apply_py = apply_method 211 | replace_method(the_class, 'reverse_transform', 212 | replacement_reverse_transform) 213 | 214 | 215 | def handle_AutoTuneCriterion(the_class): 216 | def replacement_set_groundtruth(self, D, I): 217 | if D: 218 | assert I.shape == D.shape 219 | self.nq, self.gt_nnn = I.shape 220 | self.set_groundtruth_c( 221 | self.gt_nnn, swig_ptr(D) if D else None, swig_ptr(I)) 222 | 223 | def replacement_evaluate(self, D, I): 224 | assert I.shape == D.shape 225 | assert I.shape == (self.nq, self.nnn) 226 | return self.evaluate_c(swig_ptr(D), swig_ptr(I)) 227 | 228 | replace_method(the_class, 'set_groundtruth', replacement_set_groundtruth) 229 | replace_method(the_class, 'evaluate', replacement_evaluate) 230 | 231 | 232 | def handle_ParameterSpace(the_class): 233 | def replacement_explore(self, index, xq, crit): 234 | assert xq.shape == (crit.nq, index.d) 235 | ops = OperatingPoints() 236 | self.explore_c(index, crit.nq, swig_ptr(xq), 237 | crit, ops) 238 | return ops 239 | replace_method(the_class, 'explore', replacement_explore) 240 | 241 | 242 | this_module = sys.modules[__name__] 243 | 244 | 245 | for symbol in dir(this_module): 246 | obj = getattr(this_module, symbol) 247 | # print symbol, isinstance(obj, (type, types.ClassType)) 248 | if inspect.isclass(obj): 249 | the_class = obj 250 | if issubclass(the_class, Index): 251 | handle_Index(the_class) 252 | 253 | if issubclass(the_class, VectorTransform): 254 | handle_VectorTransform(the_class) 255 | 256 | if issubclass(the_class, AutoTuneCriterion): 257 | handle_AutoTuneCriterion(the_class) 258 | 259 | if issubclass(the_class, ParameterSpace): 260 | handle_ParameterSpace(the_class) 261 | 262 | 263 | def index_cpu_to_gpu_multiple_py(resources, index, co=None): 264 | 
"""builds the C++ vectors for the GPU indices and the 265 | resources. Handles the common case where the resources are assigned to 266 | the first len(resources) GPUs""" 267 | vres = GpuResourcesVector() 268 | vdev = IntVector() 269 | for i, res in enumerate(resources): 270 | vdev.push_back(i) 271 | vres.push_back(res) 272 | return index_cpu_to_gpu_multiple(vres, vdev, index, co) 273 | 274 | 275 | def index_cpu_to_all_gpus(index, co=None, ngpu=-1): 276 | if ngpu == -1: 277 | ngpu = get_num_gpus() 278 | res = [StandardGpuResources() for i in range(ngpu)] 279 | index2 = index_cpu_to_gpu_multiple_py(res, index, co) 280 | index2.dont_dealloc = res 281 | return index2 282 | 283 | 284 | # mapping from vector names in swigfaiss.swig and the numpy dtype names 285 | vector_name_map = { 286 | 'Float': 'float32', 287 | 'Byte': 'uint8', 288 | 'Uint64': 'uint64', 289 | 'Long': 'int64', 290 | 'Int': 'int32', 291 | 'Double': 'float64' 292 | } 293 | 294 | def vector_to_array(v): 295 | """ convert a C++ vector to a numpy array """ 296 | classname = v.__class__.__name__ 297 | assert classname.endswith('Vector') 298 | dtype = np.dtype(vector_name_map[classname[:-6]]) 299 | a = np.empty(v.size(), dtype=dtype) 300 | memcpy(swig_ptr(a), v.data(), a.nbytes) 301 | return a 302 | 303 | 304 | def vector_float_to_array(v): 305 | return vector_to_array(v) 306 | 307 | 308 | def copy_array_to_vector(a, v): 309 | """ copy a numpy array to a vector """ 310 | n, = a.shape 311 | classname = v.__class__.__name__ 312 | assert classname.endswith('Vector') 313 | dtype = np.dtype(vector_name_map[classname[:-6]]) 314 | assert dtype == a.dtype, ( 315 | 'cannot copy a %s array to a %s (should be %s)' % ( 316 | a.dtype, classname, dtype)) 317 | v.resize(n) 318 | memcpy(v.data(), swig_ptr(a), a.nbytes) 319 | 320 | 321 | class Kmeans: 322 | 323 | def __init__(self, d, k, niter=25, verbose=False, spherical = False): 324 | self.d = d 325 | self.k = k 326 | self.cp = ClusteringParameters() 327 | self.cp.niter = 
niter 328 | self.cp.verbose = verbose 329 | self.cp.spherical = spherical 330 | self.centroids = None 331 | 332 | def train(self, x): 333 | assert x.flags.contiguous 334 | n, d = x.shape 335 | assert d == self.d 336 | clus = Clustering(d, self.k, self.cp) 337 | if self.cp.spherical: 338 | self.index = IndexFlatIP(d) 339 | else: 340 | self.index = IndexFlatL2(d) 341 | clus.train(x, self.index) 342 | centroids = vector_float_to_array(clus.centroids) 343 | self.centroids = centroids.reshape(self.k, d) 344 | self.obj = vector_float_to_array(clus.obj) 345 | return self.obj[-1] 346 | 347 | def assign(self, x): 348 | assert self.centroids is not None, "should train before assigning" 349 | index = IndexFlatL2(self.d) 350 | index.add(self.centroids) 351 | D, I = index.search(x, 1) 352 | return D.ravel(), I.ravel() 353 | 354 | 355 | def kmin(array, k): 356 | """return k smallest values (and their indices) of the lines of a 357 | float32 array""" 358 | m, n = array.shape 359 | I = np.zeros((m, k), dtype='int64') 360 | D = np.zeros((m, k), dtype='float32') 361 | ha = float_maxheap_array_t() 362 | ha.ids = swig_ptr(I) 363 | ha.val = swig_ptr(D) 364 | ha.nh = m 365 | ha.k = k 366 | ha.heapify() 367 | ha.addn(n, swig_ptr(array)) 368 | ha.reorder() 369 | return D, I 370 | 371 | 372 | def kmax(array, k): 373 | """return k largest values (and their indices) of the lines of a 374 | float32 array""" 375 | m, n = array.shape 376 | I = np.zeros((m, k), dtype='int64') 377 | D = np.zeros((m, k), dtype='float32') 378 | ha = float_minheap_array_t() 379 | ha.ids = swig_ptr(I) 380 | ha.val = swig_ptr(D) 381 | ha.nh = m 382 | ha.k = k 383 | ha.heapify() 384 | ha.addn(n, swig_ptr(array)) 385 | ha.reorder() 386 | return D, I 387 | 388 | 389 | def rand(n, seed=12345): 390 | res = np.empty(n, dtype='float32') 391 | float_rand(swig_ptr(res), n, seed) 392 | return res 393 | 394 | 395 | def lrand(n, seed=12345): 396 | res = np.empty(n, dtype='int64') 397 | long_rand(swig_ptr(res), n, seed) 398 | 
return res 399 | 400 | 401 | def randn(n, seed=12345): 402 | res = np.empty(n, dtype='float32') 403 | float_randn(swig_ptr(res), n, seed) 404 | return res 405 | 406 | 407 | def eval_intersection(I1, I2): 408 | """ size of intersection between each line of two result tables""" 409 | n = I1.shape[0] 410 | assert I2.shape[0] == n 411 | k1, k2 = I1.shape[1], I2.shape[1] 412 | ninter = 0 413 | for i in range(n): 414 | ninter += ranklist_intersection_size( 415 | k1, swig_ptr(I1[i]), k2, swig_ptr(I2[i])) 416 | return ninter 417 | 418 | 419 | def normalize_L2(x): 420 | fvec_renorm_L2(x.shape[1], x.shape[0], swig_ptr(x)) 421 | 422 | 423 | def replacement_map_add(self, keys, vals): 424 | n, = keys.shape 425 | assert (n,) == keys.shape 426 | self.add_c(n, swig_ptr(keys), swig_ptr(vals)) 427 | 428 | def replacement_map_search_multiple(self, keys): 429 | n, = keys.shape 430 | vals = np.empty(n, dtype='int64') 431 | self.search_multiple_c(n, swig_ptr(keys), swig_ptr(vals)) 432 | return vals 433 | 434 | replace_method(MapLong2Long, 'add', replacement_map_add) 435 | replace_method(MapLong2Long, 'search_multiple', replacement_map_search_multiple) 436 | --------------------------------------------------------------------------------