89 | `);
90 | }
91 | }
92 |
93 | declare global {
94 | interface HTMLElementTagNameMap {
95 | 'message-list': MessageList;
96 | }
97 | }
98 |
--------------------------------------------------------------------------------
/big_vision/tools/lit_demo/src/components/model-controls.scss:
--------------------------------------------------------------------------------
1 | .controls {
2 | margin: 1em 0;
3 | display: flex;
4 |
5 | select {
6 | margin-left: 0.5em;
7 | }
8 |
9 | progress {
10 | margin: 0 1em;
11 | }
12 | }
13 |
--------------------------------------------------------------------------------
/big_vision/tools/lit_demo/src/components/model-controls.ts:
--------------------------------------------------------------------------------
1 | /**
2 | * @license
3 | * Copyright Big Vision Authors
4 | *
5 | * Licensed under the Apache License, Version 2.0 (the "License");
6 | * you may not use this file except in compliance with the License.
7 | * You may obtain a copy of the License at
8 | *
9 | * http://www.apache.org/licenses/LICENSE-2.0
10 | *
11 | * Unless required by applicable law or agreed to in writing, software
12 | * distributed under the License is distributed on an "AS IS" BASIS,
13 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 | * See the License for the specific language governing permissions and
15 | * limitations under the License.
16 | */
17 |
18 | /**
19 | * @fileoverview Controls to choose model.
20 | */
21 |
22 | import {html, LitElement} from 'lit';
23 |
24 | import {getModels} from '../lit_demo/constants';
25 | import {app} from '../lit_demo/app';
26 |
27 | import {customElement, property} from 'lit/decorators.js';
28 | import styles from './model-controls.scss';
29 |
30 | /**
31 | * Shows controls for model selection, progress bar, and status text.
32 | */
33 | @customElement('model-controls')
34 | export class ModelControls extends LitElement {
35 |
36 | static override styles = [styles];
37 |
38 | @property({attribute: false})
39 | progress: number = 0;
40 |
41 | @property({attribute: false})
42 | status: string = 'Initializing...';
43 |
44 | constructor() {
45 | super();
46 | app.models.addListener(this.onModelUpdate.bind(this));
47 | app.models.load(getModels()[0]);
48 | }
49 |
50 | onModelUpdate(progress: number, message?: string) {
51 | this.progress = progress;
52 | if (message) this.status = message;
53 | }
54 |
55 | onModelChange(event: Event) {
56 | const target = event.target as HTMLSelectElement;
57 | const name = target.value;
58 | app.models.load(name).catch((error) => {
59 | this.status = `ERROR loading model "${name}": ${error}`;
60 | });
61 | }
62 |
63 | async setModel(model: string) {
64 | if (getModels().indexOf(model) === -1) {
65 | throw new Error(`Model "${model}" not found!`);
66 | }
67 | await this.updateComplete;
68 | const dropdown = this.shadowRoot!.querySelector('#model_dropdown') as HTMLSelectElement;
69 | dropdown.value = model;
70 | dropdown.dispatchEvent(new Event('change'));
71 | }
72 |
73 | override render() {
74 | const options = getModels().map((model: string) =>
75 | html`<option value=${model}>${model}</option>`);
76 | return html`
77 | <div class="controls">
78 | Model:
79 | <select id="model_dropdown" @change=${this.onModelChange}>
80 | ${options}
81 | </select>
82 | <progress value=${this.progress}></progress>
83 | <span class="status">${this.status}</span>
84 | </div>
85 | `;
86 | }
87 | }
88 |
89 | declare global {
90 | interface HTMLElementTagNameMap {
91 | 'model-controls': ModelControls;
92 | }
93 | }
94 |
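Since 'model-controls' is registered in HTMLElementTagNameMap, the component can also be driven programmatically. A minimal sketch, assuming the element is already attached to the DOM and that 'small' is one of the names returned by getModels():

    // querySelector() is typed as ModelControls|null thanks to the tag name map.
    const controls = document.querySelector('model-controls');
    if (controls) {
      // setModel() waits for the first render, selects the model in the dropdown
      // and dispatches a 'change' event, which triggers loading via onModelChange().
      controls.setModel('small').catch(console.error);
    }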
--------------------------------------------------------------------------------
/big_vision/tools/lit_demo/src/exports.ts:
--------------------------------------------------------------------------------
1 | /**
2 | * @license
3 | * Copyright Big Vision Authors
4 | *
5 | * Licensed under the Apache License, Version 2.0 (the "License");
6 | * you may not use this file except in compliance with the License.
7 | * You may obtain a copy of the License at
8 | *
9 | * http://www.apache.org/licenses/LICENSE-2.0
10 | *
11 | * Unless required by applicable law or agreed to in writing, software
12 | * distributed under the License is distributed on an "AS IS" BASIS,
13 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 | * See the License for the specific language governing permissions and
15 | * limitations under the License.
16 | */
17 |
18 | /**
19 | * @fileoverview some useful exports to play around with the models &
20 | * tokenizers.
21 | *
22 | * Simple usage (see ./playground.html for a more complete usage example):
23 | *
24 | * model = new lit.Model('tiny');
25 | * model.load(progress => console.log('loading...', progress));
26 | * console.log(model.computeProbabilities(['a dog', 'a cat'], '0'));
27 | */
28 |
29 | import {Model} from './lit_demo/compute';
30 | import {getImageUrl, setBaseUrl} from './lit_demo/constants';
31 | import {ImageData} from './lit_demo/data';
32 | import * as tf from '@tensorflow/tfjs-core';
33 |
34 | // tslint:disable-next-line:no-any Export symbols into global namespace.
35 | (window as any).lit = { Model, getImageUrl, ImageData, setBaseUrl };
36 | // tslint:disable-next-line:no-any Export symbols into global namespace.
37 | // tslint:disable-next-line:ban-module-namespace-object-escape Export all of TF.
38 | (window as any).tf = tf;
39 |
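Once this bundle is loaded, the exported symbols are reachable from the browser console via window.lit and window.tf. A sketch of the usage pattern described in the fileoverview above; the Model method signatures are assumed to match that comment, and the image id '0' is only an example:

    // Optionally point at a different data/model host first (hypothetical URL).
    lit.setBaseUrl('https://example.com/lit');
    const model = new lit.Model('tiny');
    await model.load(progress => console.log('loading...', progress));
    // Probabilities of each prompt matching image '0'.
    console.log(model.computeProbabilities(['a dog', 'a cat'], '0'));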
--------------------------------------------------------------------------------
/big_vision/tools/lit_demo/src/index.html:
--------------------------------------------------------------------------------
1 |
17 |
18 |
19 |
20 |
21 |
22 | Lit Demo App
23 |
24 |
25 |
26 |
27 |
28 |
29 |
30 |
LiT: Zero-Shot Transfer with Locked-image Tuning
31 |
32 |
33 | This page is an interactive demo of the Google AI blog post
34 | LiT: adding language understanding to image models
36 | – please refer to that page for a detailed explanation of how a LiT model works.
37 | If you're interested in how this demo makes a JAX model run on device in your
38 | browser, check out our other blog post
39 | JAX on the Web with TensorFlow.js.
41 |
42 |
43 |
44 | Below you can choose an image from a selection and then write free-form
45 | text prompts that are matched to the image. Once you hit return on your
46 | keyboard or press the "compute" button, a text encoder implemented in
47 | TensorFlow.js
48 | will compute embeddings for the provided text on your local device, and the
49 | similarity of these text embeddings to the image embedding will be displayed.
50 |
51 |
52 |
53 | The prompts can be used to classify an image into multiple categories, listing
54 | each category individually with a prompt "an image of a X". But you can also
55 | probe the model interactively with more detailed prompts, comparing the
56 | different results when small details change in the text.
57 |
58 |
59 |
60 | Please use this demo responsibly. The models will always compare the image to
61 | the prompts you provide, and it is therefore trivial to construct situations
62 | where the model picks from a bunch of bad options.
63 |
64 |
65 |
66 | Note:
67 | The models available in this interactive demo are not those from the
68 | paper.
70 | We had to train much smaller text towers and tokenizers to avoid
71 | overloading your browser. Please see
72 | our GitHub repository
74 | for the models from the paper pre-trained on public datasets.
75 | Multilingual models coming soon.
76 |
77 |
78 |
79 |
80 |
--------------------------------------------------------------------------------
/big_vision/tools/lit_demo/src/lit_demo/app.ts:
--------------------------------------------------------------------------------
1 | /**
2 | * @license
3 | * Copyright Big Vision Authors
4 | *
5 | * Licensed under the Apache License, Version 2.0 (the "License");
6 | * you may not use this file except in compliance with the License.
7 | * You may obtain a copy of the License at
8 | *
9 | * http://www.apache.org/licenses/LICENSE-2.0
10 | *
11 | * Unless required by applicable law or agreed to in writing, software
12 | * distributed under the License is distributed on an "AS IS" BASIS,
13 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 | * See the License for the specific language governing permissions and
15 | * limitations under the License.
16 | */
17 |
18 | /**
19 | * @fileoverview Global app state.
20 | */
21 |
22 | import {ImageData} from './data';
23 | import {Models} from './compute';
24 |
25 | /**
26 | * Container class holding image data and models.
27 | *
28 | * The main application component would typically call `load()` and then show
29 | * the components depending on this class asynchronously.
30 | */
31 | export class App {
32 |
33 | imageData = new ImageData();
34 | models = new Models();
35 |
36 | ready: boolean = false;
37 |
38 | async load() {
39 | await this.imageData.load();
40 | this.ready = true;
41 | }
42 | }
43 |
44 | /**
45 | * Global app state.
46 | */
47 | export const app = new App();
48 |
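A typical startup sequence for code depending on this global state, as a sketch (the image id '1' is only an example and must exist in info.json):

    import {app} from './lit_demo/app';

    async function init() {
      await app.load();                    // fetches data/images/info.json
      console.log(app.ready);              // true
      const row = app.imageData.get('1');  // throws if the id is unknown
      console.log(row.description, row.prompts);
    }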
--------------------------------------------------------------------------------
/big_vision/tools/lit_demo/src/lit_demo/constants.ts:
--------------------------------------------------------------------------------
1 | /**
2 | * @license
3 | * Copyright Big Vision Authors
4 | *
5 | * Licensed under the Apache License, Version 2.0 (the "License");
6 | * you may not use this file except in compliance with the License.
7 | * You may obtain a copy of the License at
8 | *
9 | * http://www.apache.org/licenses/LICENSE-2.0
10 | *
11 | * Unless required by applicable law or agreed to in writing, software
12 | * distributed under the License is distributed on an "AS IS" BASIS,
13 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 | * See the License for the specific language governing permissions and
15 | * limitations under the License.
16 | */
17 |
18 | /**
19 | * @fileoverview Project-wide constants.
20 | */
21 |
22 | // Can be overwritten with setBaseUrl() below.
23 | // let baseUrl = 'https://google-research.github.io/vision_transformer/lit';
24 | let baseUrl = 'https://figur.li/jax2tfjs';
25 | // Can be overwritten with setModels() below.
26 | let models = ['tiny', 'small'];
27 |
28 | /** Sets a new base URL, on which all other URLs are based. */
29 | export const setBaseUrl = (newBaseUrl: string) => {
30 | baseUrl = newBaseUrl;
31 | };
32 |
33 | /** Retrieves URL for a model-specific file (vocabulary, embeddings, ...). */
34 | export const getModelFileUrl = (name: string, relativePath: string) => (
35 | `${baseUrl}/data/models/${name}/${relativePath}`
36 | );
37 |
38 | /** Retrieves the URL for images information JSON file. */
39 | export const getImagesInfoUrl = () => `${baseUrl}/data/images/info.json`;
40 |
41 | /** Retrieves the URL for an image. */
42 | export const getImageUrl = (id: string) => `${baseUrl}/data/images/${id}.jpg`;
43 |
44 | /** Returns names of available models. */
45 | export const getModels = () => models;
46 |
47 | /** Sets names of available models. */
48 | export const setModels = (newModels: string[]) => {
49 | models = newModels;
50 | };
51 |
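These helpers simply join the (overridable) base URL with fixed relative paths. For illustration only; the base URL below is made up:

    import {setBaseUrl, getModelFileUrl, getImagesInfoUrl, getImageUrl} from './lit_demo/constants';

    setBaseUrl('https://example.com/lit');
    getModelFileUrl('tiny', 'vocab.json');
    // -> 'https://example.com/lit/data/models/tiny/vocab.json'
    getImagesInfoUrl();  // -> 'https://example.com/lit/data/images/info.json'
    getImageUrl('42');   // -> 'https://example.com/lit/data/images/42.jpg'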
--------------------------------------------------------------------------------
/big_vision/tools/lit_demo/src/lit_demo/data.ts:
--------------------------------------------------------------------------------
1 | /**
2 | * @license
3 | * Copyright Big Vision Authors
4 | *
5 | * Licensed under the Apache License, Version 2.0 (the "License");
6 | * you may not use this file except in compliance with the License.
7 | * You may obtain a copy of the License at
8 | *
9 | * http://www.apache.org/licenses/LICENSE-2.0
10 | *
11 | * Unless required by applicable law or agreed to in writing, software
12 | * distributed under the License is distributed on an "AS IS" BASIS,
13 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 | * See the License for the specific language governing permissions and
15 | * limitations under the License.
16 | */
17 |
18 | /**
19 | * @fileoverview Accessing additional data.
20 | */
21 |
22 | import {getImagesInfoUrl} from './constants';
23 |
24 | /**
25 | * Information about a single image.
26 | */
27 | export interface ImageRow {
28 | /** Stable ID of the image. */
29 | id: string;
30 | /** Set of example prompts for this image. */
31 | prompts: string;
32 | /** License of the image. */
33 | license: string;
34 | /** Where the image was originally downloaded from. */
35 | source: string;
36 | /** Short description of image. */
37 | description: string;
38 | }
39 | /**
40 | * Contains information about all images.
41 | */
42 | export class ImageData {
43 |
44 | rows: ImageRow[] = [];
45 | /** Will be set to `true` when `load()` finishes. */
46 | ready = false;
47 |
48 | /**
49 | * Gets an image by ID. Throws an error if image is not found, data is not
50 | * loaded, or ID is not unique.
51 | */
52 | get(id: string): ImageRow {
53 | if (!this.ready) {
54 | throw new Error('ImageData not loaded!');
55 | }
56 | const matching = this.rows.filter(row => row.id === id);
57 | if (matching.length !== 1) {
58 | throw new Error(`Got unexpected ${matching.length} matches for id="${id}"`);
59 | }
60 | return matching[0];
61 | }
62 |
63 | /**
64 | * Loads image data asynchronously.
65 | */
66 | async load() {
67 | this.rows = (
68 | await fetch(getImagesInfoUrl())
69 | .then(response => {
70 | console.log('response', response);
71 | return response.json();
72 | })
73 | );
74 | this.ready = true;
75 | }
76 | }
77 |
--------------------------------------------------------------------------------
/big_vision/tools/lit_demo/src/lit_demo/url_utils.ts:
--------------------------------------------------------------------------------
1 | /**
2 | * @license
3 | * Copyright Big Vision Authors
4 | *
5 | * Licensed under the Apache License, Version 2.0 (the "License");
6 | * you may not use this file except in compliance with the License.
7 | * You may obtain a copy of the License at
8 | *
9 | * http://www.apache.org/licenses/LICENSE-2.0
10 | *
11 | * Unless required by applicable law or agreed to in writing, software
12 | * distributed under the License is distributed on an "AS IS" BASIS,
13 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 | * See the License for the specific language governing permissions and
15 | * limitations under the License.
16 | */
17 |
18 | /**
19 | * @fileoverview (De)serialize state from/to URL.
20 | */
21 |
22 | // Should be updated whenever URLs are not compatible anymore
23 | // (e.g. adding new images)
24 | export const VERSION = 'v2';
25 | // version history:
26 | // v1 used row number instead of image id
27 |
28 | const V1_IMAGE_IDS = [
29 | '1', '48', '43', '22', '2', '3', '4', '5', '6', '7', '8', '9',
30 | '10', '11', '12', '13', '14', '15', '16', '17', '18', '19', '20', '21',
31 | '23', '24', '25', '26', '27', '28', '29', '30', '31', '32', '33', '34',
32 | '35', '36', '37', '38', '39', '40', '41', '42', '44', '45', '46', '47',
33 | '49', '50', '51', '52', '53', '54', '55', '56', '57', '58', '59', '60'
34 | ];
35 |
36 | /**
37 | * State that can be stored in the URL.
38 | */
39 | export interface State {
40 | /** Name of the model. */
41 | modelName: string;
42 | /** ID Of the image. */
43 | imageId: string;
44 | /** List of text prompts. */
45 | prompts: string[];
46 | }
47 |
48 | /**
49 | * Returns a URL for provided model/image/prompts.
50 | */
51 | export const getUrl =
52 | (modelName: string, imageId: string, prompts: string[]): string => {
53 | let href = window.location.href;
54 | if (href.indexOf('#') !== -1) {
55 | href = href.substring(0, href.indexOf('#'));
56 | }
57 | const parts = [
58 | VERSION,
59 | modelName,
60 | imageId,
61 | ...prompts,
62 | ];
63 | return href + '#' + parts.map(encodeURIComponent).join('|');
64 | };
65 |
66 | /**
67 | * Parses a URL and returns a `State`, or undefined if no state is specified.
68 | *
69 | * Raises an exception if there was a problem with the parsing of the URL.
70 | */
71 | export const parseUrl = (): State|undefined => {
72 | const hash = window.location.hash.substring(1);
73 | if (!hash) return;
74 | const parts = hash.split(/\|/g);
75 | if (parts.length < 4) {
76 | throw new Error(`Invalid URL: "${hash}"`);
77 | }
78 | let [version, modelName, imageId, ...texts] = parts;
79 | if (version === VERSION) {
80 | } else if (version === 'v1') {
81 | const idx = Number(imageId);
82 | if (isNaN(idx)) throw new Error(`Expected idx="${idx}" to be numerical!`);
83 | imageId = V1_IMAGE_IDS[idx];
84 | } else {
85 | throw new Error(`Incompatible version: ${version} (supported: ${VERSION})`);
86 | }
87 | return {
88 | modelName,
89 | imageId,
90 | prompts: texts.map(decodeURIComponent),
91 | };
92 | };
93 |
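The hash is simply the version, model name, image id and prompts, URI-encoded and joined with '|'. A round-trip sketch, assuming the page is served at https://example.com/:

    import {getUrl, parseUrl} from './lit_demo/url_utils';

    getUrl('tiny', '42', ['a dog', 'a cat']);
    // -> 'https://example.com/#v2|tiny|42|a%20dog|a%20cat'

    // With that hash present in window.location, parseUrl() recovers the state:
    parseUrl();
    // -> {modelName: 'tiny', imageId: '42', prompts: ['a dog', 'a cat']}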
--------------------------------------------------------------------------------
/big_vision/tools/lit_demo/src/playground.html:
--------------------------------------------------------------------------------
1 |
17 |
18 |
19 |
20 |
21 |
22 |
23 | A simple demonstration of how to use LiT models in a JS application using global exports.
24 | See source code of this file for API usage.
25 |
26 |
27 |
28 |
29 |
30 |
31 |
32 |
33 |
82 |
83 |
93 |
--------------------------------------------------------------------------------
/big_vision/tools/lit_demo/src/style.scss:
--------------------------------------------------------------------------------
1 | // General styles for the page.
2 |
3 | @import './style/colors';
4 | @import './style/mixins';
5 |
6 | html {
7 | font-size: 14px;
8 | line-height: 1.6em;
9 | font-family: -apple-system, BlinkMacSystemFont, 'Segoe UI', Roboto, Oxygen,
10 | Ubuntu, Cantarell, 'Fira Sans', 'Droid Sans', 'Helvetica Neue', Arial,
11 | sans-serif;
12 | text-size-adjust: 100%;
13 | -ms-text-size-adjust: 100%;
14 | -webkit-text-size-adjust: 100%;
15 |
16 | @media (min-width: 1200px) {
17 | width: 1024px;
18 | margin: 0 auto;
19 | }
20 | @media (min-width: 768px) {
21 | font-size: 16px;
22 | }
23 |
24 | color: var(--text-fg);
25 | background: var(--text-bg);
26 |
27 | body {
28 | margin: 0;
29 | padding: 0rem 1rem 10rem;
30 | }
31 | }
32 |
33 | a,
34 | a:visited {
35 | color: var(--link-col);
36 | }
37 |
38 | h1 {
39 | font-weight: 700;
40 | font-size: 2rem;
41 | line-height: 1.3em;
42 | }
43 |
44 | p {
45 | font-size: 1.06rem;
46 | line-height: 1.3em;
47 | }
48 |
49 | input {
50 | font-size: 1rem;
51 |
52 | &::placeholder {
53 | color: var(--placeholder-col);
54 | }
55 | }
56 |
57 | .note {
58 | font-style: normal;
59 | border: none;
60 | border-radius: 2px;
61 | margin-left: auto;
62 | margin-right: auto;
63 |
64 | padding: 0.5rem 0.5rem 0.5rem 2rem;
65 | width: 90%;
66 |
67 | @include phone-portrait {
68 | width: 100%;
69 | padding: 0.5rem;
70 | box-sizing: border-box;
71 | }
72 |
73 | background-color: var(--note-bg);
74 | color: var(--note-fg);
75 |
76 | &.warning {
77 | background-color: var(--warn-bg);
78 | color: var(--warn-fg);
79 | }
80 | }
81 |
--------------------------------------------------------------------------------
/big_vision/tools/lit_demo/src/style/colors.scss:
--------------------------------------------------------------------------------
1 | // Dark and light mode colors.
2 |
3 | :root {
4 | --text-bg: hsl(0, 0%, 97%);
5 | --gray-border: hsla(0, 0%, 0%, 0.1);
6 | --gray: rgba(0, 0, 0, 0.6);
7 | --border-radius: 5px;
8 | --orange: hsl(24, 100%, 50%);
9 | --distill-blue: hsl(200, 50%, 25%);
10 | --blue: #337699;
11 | --green: #3db867;
12 | --text-fg: rgb(15, 15, 15);
13 | --text-red: rgb(220, 0, 0);
14 | --bar-col: rgb(171, 199, 227);
15 | --link-col: rgb(0, 0, 238);
16 | --placeholder-col: rgb(166, 166, 166);
17 | --note-bg: #e1f5fe;
18 | --note-fg: #1a6ebb;
19 | --warn-bg: #ffe1aa;
20 | --warn-fg: #a16800;
21 | --error-bg: #850000;
22 | --error-fg: white;
23 |
24 | @media (prefers-color-scheme: dark) {
25 | --text-bg: rgb(56, 56, 56);
26 | --text-fg: rgb(213, 213, 213);
27 | --bar-col: rgb(20, 109, 163);
28 | --link-col: rgb(66, 165, 245);
29 |
30 | --note-fg: rgb(121 157 190);
31 | --note-bg: rgb(2 59 85);
32 | --warn-bg: #784e00;
33 | --warn-fg: #edbe68;
34 | }
35 | }
36 |
--------------------------------------------------------------------------------
/big_vision/tools/lit_demo/src/style/mixins.scss:
--------------------------------------------------------------------------------
1 | // Useful mixins.
2 |
3 | // To wrap styles that should only trigger for phones in portrait mode.
4 | @mixin phone-portrait {
5 | @media only screen and (max-device-width: 800px) and (orientation: portrait) {
6 | @content;
7 | }
8 | }
9 |
--------------------------------------------------------------------------------
/big_vision/tools/lit_demo/src/tokenizers/common.ts:
--------------------------------------------------------------------------------
1 | /**
2 | * @license
3 | * Copyright Big Vision Authors
4 | *
5 | * Licensed under the Apache License, Version 2.0 (the "License");
6 | * you may not use this file except in compliance with the License.
7 | * You may obtain a copy of the License at
8 | *
9 | * http://www.apache.org/licenses/LICENSE-2.0
10 | *
11 | * Unless required by applicable law or agreed to in writing, software
12 | * distributed under the License is distributed on an "AS IS" BASIS,
13 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 | * See the License for the specific language governing permissions and
15 | * limitations under the License.
16 | */
17 |
18 | /**
19 | * @fileoverview Utility code shared between tokenizers.
20 | */
21 |
22 | /**
23 | * A vocabulary consists of a list of tokens and an optional numerical value.
24 | * The numerical value is used by the unigram algorithm to find the best
25 | * tokenization, and is ignored by the BPE algorithm.
26 | */
27 | export type Vocabulary = Array<[string, number]>;
28 |
29 | /**
30 | * Converts a string to a sequence of tokens.
31 | */
32 | export interface Tokenizer {
33 | encode(input: string): number[];
34 | }
35 |
36 | /**
37 | * Factory for new `Tokenizer`.
38 | */
39 | export interface TokenizerConstructor {
40 | new (vocabulary: Vocabulary): Tokenizer;
41 | }
42 |
43 | /**
44 | * Unicode-aware character iteration of strings.
45 | */
46 | export const stringToChars = (input: string): string[] => {
47 | const symbols = [];
48 | for (const symbol of input) {
49 | symbols.push(symbol);
50 | }
51 | return symbols;
52 | };
53 |
54 | /**
55 | * Special separator character used to delimit sub-word tokens.
56 | */
57 | export const TOKEN_SEPARATOR =
58 | '\u2581'; // This is the unicode character 'lower one eighth block'.
59 |
--------------------------------------------------------------------------------
/big_vision/tools/lit_demo/src/tokenizers/index.ts:
--------------------------------------------------------------------------------
1 | /**
2 | * @license
3 | * Copyright Big Vision Authors
4 | *
5 | * Licensed under the Apache License, Version 2.0 (the "License");
6 | * you may not use this file except in compliance with the License.
7 | * You may obtain a copy of the License at
8 | *
9 | * http://www.apache.org/licenses/LICENSE-2.0
10 | *
11 | * Unless required by applicable law or agreed to in writing, software
12 | * distributed under the License is distributed on an "AS IS" BASIS,
13 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 | * See the License for the specific language governing permissions and
15 | * limitations under the License.
16 | */
17 |
18 | /**
19 | * @fileoverview Tokenizers and tokenizer mappings.
20 | */
21 |
22 | import {Tokenizer, TokenizerConstructor, Vocabulary} from './common';
23 | import * as sentencepieceBpe from './sentencepiece_bpe';
24 | import * as sentencepieceUnigram from './sentencepiece_unigram';
25 |
26 | export {Tokenizer, Vocabulary} from './common';
27 |
28 | const TOKENIZERS = new Map([
29 | ['BPE', sentencepieceBpe.Tokenizer],
30 | ['UNIGRAM', sentencepieceUnigram.Tokenizer],
31 | ]);
32 |
33 | /**
34 | * Returns a tokenizer of type `name` using `vocabulary`.
35 | */
36 | export const getTokenizer = (name: string, vocabulary: Vocabulary): Tokenizer => {
37 | const ctor = TOKENIZERS.get(name);
38 | if (!ctor) throw new Error(`Unknown tokenizer: ${name}`);
39 | return new ctor(vocabulary);
40 | };
41 |
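A small end-to-end sketch of the factory, reusing the BPE test vocabulary from sentencepiece_bpe_test.ts below (import paths assume a module placed under src/):

    import {getTokenizer} from './tokenizers';
    import {TOKEN_SEPARATOR, Vocabulary} from './tokenizers/common';

    const vocab: Vocabulary = [
      [TOKEN_SEPARATOR, 0], ['a', 0], ['e', 0], ['s', 0], ['t', 0],
      ['te', -1], ['st', -2], ['test', -3], ['tes', -4],
    ];
    const tokenizer = getTokenizer('BPE', vocab);
    console.log(tokenizer.encode('a test'));  // [0, 1, 0, 7]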
--------------------------------------------------------------------------------
/big_vision/tools/lit_demo/src/tokenizers/sentencepiece_bpe.ts:
--------------------------------------------------------------------------------
1 | /**
2 | * @license
3 | * Copyright Big Vision Authors
4 | *
5 | * Licensed under the Apache License, Version 2.0 (the "License");
6 | * you may not use this file except in compliance with the License.
7 | * You may obtain a copy of the License at
8 | *
9 | * http://www.apache.org/licenses/LICENSE-2.0
10 | *
11 | * Unless required by applicable law or agreed to in writing, software
12 | * distributed under the License is distributed on an "AS IS" BASIS,
13 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 | * See the License for the specific language governing permissions and
15 | * limitations under the License.
16 | */
17 |
18 | import {stringToChars, TOKEN_SEPARATOR, Vocabulary, Tokenizer as TokenizerInterface} from './common';
19 |
20 | interface Candidate {
21 | piece: string;
22 | pos: number;
23 | score: number;
24 | }
25 |
26 | const scoreDesc = (a: Candidate, b: Candidate) => b.score - a.score;
27 |
28 | function processInput(str: string): string {
29 | const normalized = str.normalize('NFKC');
30 | return normalized.length > 0 ?
31 | TOKEN_SEPARATOR + normalized.replace(/ /g, TOKEN_SEPARATOR) :
32 | normalized;
33 | }
34 |
35 | /**
36 | * Sentencepiece tokenizer implementing the BPE algorithm.
37 | */
38 | export class Tokenizer implements TokenizerInterface {
39 |
40 | // piece -> [score, index]
41 | private readonly map: Map<string, [number, number]>;
42 |
43 | constructor(vocabulary: Vocabulary) {
44 | this.map = new Map();
45 | vocabulary.forEach(([piece, score], idx) => {
46 | if (this.map.has(piece)) {
47 | throw new Error(`Piece "${piece}" occurs multiple times in vocabulary`);
48 | }
49 | this.map.set(piece, [score, idx]);
50 | });
51 | }
52 |
53 | encode(input: string): number[] {
54 | const processed: string = processInput(input);
55 | let pieces: string[] = stringToChars(processed);
56 |
57 | while (true) {
58 | const candidates: Candidate[] = [];
59 | for (let i = 0; i < pieces.length - 1; i++) {
60 | const fused = pieces[i] + pieces[i + 1];
61 | const el = this.map.get(fused);
62 | if (el) {
63 | candidates.push({ piece: fused, pos: i, score: el[0] });
64 | }
65 | }
66 | if (candidates.length === 0) {
67 | break;
68 | }
69 | candidates.sort(scoreDesc);
70 | const best = candidates[0];
71 | pieces = [
72 | ...pieces.slice(0, best.pos),
73 | best.piece,
74 | ...pieces.slice(best.pos + 2)
75 | ];
76 | }
77 |
78 | return pieces.map(piece => this.map.get(piece)![1]);
79 | }
80 | }
81 |
--------------------------------------------------------------------------------
/big_vision/tools/lit_demo/src/tokenizers/sentencepiece_bpe_test.ts:
--------------------------------------------------------------------------------
1 | /**
2 | * @license
3 | * Copyright Big Vision Authors
4 | *
5 | * Licensed under the Apache License, Version 2.0 (the "License");
6 | * you may not use this file except in compliance with the License.
7 | * You may obtain a copy of the License at
8 | *
9 | * http://www.apache.org/licenses/LICENSE-2.0
10 | *
11 | * Unless required by applicable law or agreed to in writing, software
12 | * distributed under the License is distributed on an "AS IS" BASIS,
13 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 | * See the License for the specific language governing permissions and
15 | * limitations under the License.
16 | */
17 |
18 | import 'jasmine';
19 |
20 | describe('sentencepiece bpe test', () => {
21 | it('computes a thing when asked', () => {});
22 | });
23 |
24 | import * as bpe from './sentencepiece_bpe';
25 | import {TOKEN_SEPARATOR, Vocabulary} from './common';
26 |
27 | const vocab: Vocabulary = [
28 | [TOKEN_SEPARATOR, 0], // 0
29 | ['a', 0], // 1
30 | ['e', 0], // 2
31 | ['s', 0], // 3
32 | ['t', 0], // 4
33 | ['te', -1], // 5
34 | ['st', -2], // 6
35 | ['test', -3], // 7
36 | ['tes', -4], // 8
37 | ];
38 |
39 | describe('BPE Tokenizer', () => {
40 | let tokenizer: bpe.Tokenizer;
41 | beforeAll(() => {
42 | tokenizer = new bpe.Tokenizer(vocab);
43 | });
44 |
45 | it('should tokenize correctly', () => {
46 | expect(tokenizer.encode('a test')).toEqual([0, 1, 0, 7]);
47 | });
48 | });
49 |
--------------------------------------------------------------------------------
/big_vision/tools/lit_demo/src/tokenizers/sentencepiece_unigram_test.ts:
--------------------------------------------------------------------------------
1 | /**
2 | * @license
3 | * Copyright Big Vision Authors
4 | *
5 | * Licensed under the Apache License, Version 2.0 (the "License");
6 | * you may not use this file except in compliance with the License.
7 | * You may obtain a copy of the License at
8 | *
9 | * http://www.apache.org/licenses/LICENSE-2.0
10 | *
11 | * Unless required by applicable law or agreed to in writing, software
12 | * distributed under the License is distributed on an "AS IS" BASIS,
13 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 | * See the License for the specific language governing permissions and
15 | * limitations under the License.
16 | */
17 |
18 | import {Tokenizer} from './sentencepiece_unigram';
19 |
20 | const stubbedTokenizerVocab = [
21 | ['�', 0],
22 | ['<s>', 0],
23 | ['</s>', 0],
24 | ['extra_token_id_1', 0],
25 | ['extra_token_id_2', 0],
26 | ['extra_token_id_3', 0],
27 | ['▁', -2],
28 | ['▁a', -1],
29 | ['▁ç', -2],
30 | ['a', -3],
31 | ['.', -1],
32 | ['▁I', -1],
33 | ['▁like', -1],
34 | ['▁it', -1],
35 | ['I', -2],
36 | ['like', -2],
37 | ['it', -2],
38 | ['l', -3],
39 | ['i', -3],
40 | ['k', -3],
41 | ['e', -3],
42 | ['i', -3],
43 | ['t', -3]
44 | ];
45 |
46 | describe('Universal Sentence Encoder tokenizer', () => {
47 | let tokenizer: Tokenizer;
48 | beforeAll(() => {
49 | tokenizer = new Tokenizer(stubbedTokenizerVocab as Array<[string, number]>);
50 | });
51 |
52 | it('basic usage', () => {
53 | expect(tokenizer.encode('Ilikeit.')).toEqual([11, 15, 16, 10]);
54 | });
55 |
56 | it('handles whitespace', () => {
57 | expect(tokenizer.encode('I like it.')).toEqual([11, 12, 13, 10]);
58 | });
59 |
60 | it('should normalize inputs', () => {
61 | expect(tokenizer.encode('ça')).toEqual(tokenizer.encode('c\u0327a'));
62 | });
63 |
64 | it('should handle unknown inputs', () => {
65 | expect(() => tokenizer.encode('😹')).not.toThrow();
66 | });
67 |
68 | it('should treat consecutive unknown inputs as a single word', () => {
69 | expect(tokenizer.encode('a😹😹')).toEqual([7, 0]);
70 | });
71 | });
72 |
--------------------------------------------------------------------------------
/big_vision/tools/lit_demo/src/tokenizers/trie.ts:
--------------------------------------------------------------------------------
1 | /**
2 | * @license
3 | * Copyright Big Vision Authors
4 | *
5 | * Licensed under the Apache License, Version 2.0 (the "License");
6 | * you may not use this file except in compliance with the License.
7 | * You may obtain a copy of the License at
8 | *
9 | * http://www.apache.org/licenses/LICENSE-2.0
10 | *
11 | * Unless required by applicable law or agreed to in writing, software
12 | * distributed under the License is distributed on an "AS IS" BASIS,
13 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 | * See the License for the specific language governing permissions and
15 | * limitations under the License.
16 | */
17 |
18 | // Copied from
19 | // https://github.com/tensorflow/tfjs-models/blob/master/universal-sentence-encoder/src/tokenizer/trie.ts
20 |
21 | import {stringToChars} from './common';
22 |
23 | // [token, score, index]
24 | type OutputNode = [string[], number, number];
25 |
26 | class TrieNode {
27 | parent: TrieNode|null;
28 | end: boolean;
29 | children: {[firstSymbol: string]: TrieNode};
30 | word: OutputNode;
31 |
32 | constructor() {
33 | this.parent = null;
34 | this.children = {};
35 | this.end = false;
36 | this.word = [[], 0, 0];
37 | }
38 | }
39 |
40 | /**
41 | * Simple Trie datastructure.
42 | */
43 | export class Trie {
44 | root: TrieNode;
45 |
46 | constructor() {
47 | this.root = new TrieNode();
48 | }
49 |
50 | /**
51 | * Inserts a token into the trie.
52 | */
53 | insert(word: string, score: number, index: number) {
54 | let node = this.root;
55 |
56 | const symbols = stringToChars(word);
57 |
58 | for (let i = 0; i < symbols.length; i++) {
59 | if (!node.children[symbols[i]]) {
60 | node.children[symbols[i]] = new TrieNode();
61 | node.children[symbols[i]].parent = node;
62 | node.children[symbols[i]].word[0] = node.word[0].concat(symbols[i]);
63 | }
64 |
65 | node = node.children[symbols[i]];
66 | if (i === symbols.length - 1) {
67 | node.end = true;
68 | node.word[1] = score;
69 | node.word[2] = index;
70 | }
71 | }
72 | }
73 |
74 | /**
75 | * Returns an array of all tokens starting with ss.
76 | *
77 | * @param ss The prefix to match on.
78 | */
79 | commonPrefixSearch(ss: string[]): OutputNode[] {
80 | const output: OutputNode[] = [];
81 | let node = this.root.children[ss[0]];
82 |
83 | for (let i = 0; i < ss.length && node; i++) {
84 | if (node.end) {
85 | output.push(node.word);
86 | }
87 | node = node.children[ss[i + 1]];
88 | }
89 |
90 | if (!output.length) {
91 | output.push([[ss[0]], 0, 0]);
92 | }
93 |
94 | return output;
95 | }
96 | }
97 |
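A minimal usage sketch: commonPrefixSearch() returns every inserted token that is a prefix of the given character sequence (or a single-character fallback entry if nothing matches). The scores and indices below mirror the BPE test vocabulary and are otherwise arbitrary:

    import {Trie} from './tokenizers/trie';
    import {stringToChars} from './tokenizers/common';

    const trie = new Trie();
    trie.insert('t', 0, 4);
    trie.insert('te', -1, 5);
    trie.insert('test', -3, 7);

    console.log(trie.commonPrefixSearch(stringToChars('test')));
    // -> [[['t'], 0, 4], [['t', 'e'], -1, 5], [['t', 'e', 's', 't'], -3, 7]]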
--------------------------------------------------------------------------------
/big_vision/tools/lit_demo/src/tsconfig.json:
--------------------------------------------------------------------------------
1 | {
2 | "compilerOptions": {
3 | "outDir": "dist",
4 | "target": "es6",
5 | "module": "commonjs",
6 | "lib": ["dom", "DOM.Iterable", "es2019", "es2020.string"],
7 | "types": ["node", "jasmine", "resize-observer-browser"],
8 | "moduleResolution": "node",
9 | "allowJs": false,
10 | "pretty": true,
11 | "resolveJsonModule": true,
12 | "sourceMap": false,
13 | "skipLibCheck": true,
14 | "removeComments": true,
15 | "esModuleInterop": true,
16 | "importsNotUsedAsValues": "preserve",
17 | "downlevelIteration": true,
18 | "skipDefaultLibCheck": true,
19 | "preserveConstEnums": false,
20 | "experimentalDecorators": true,
21 | "emitDecoratorMetadata": true,
22 | "noErrorTruncation": false,
23 | "noEmitOnError": false,
24 | "declaration": false,
25 | "stripInternal": true,
26 | "inlineSourceMap": true,
27 | "inlineSources": true,
28 | "importHelpers": true,
29 | "allowUnreachableCode": false,
30 | "noFallthroughCasesInSwitch": true,
31 | "noImplicitAny": true,
32 | "noImplicitReturns": false,
33 | "noImplicitThis": true,
34 | "strictBindCallApply": true,
35 | "strictFunctionTypes": true,
36 | "strictNullChecks": false,
37 | "strictPropertyInitialization": false
38 | },
39 | "include": ["./client", "./examples"],
40 | "compileOnSave": false
41 | }
42 |
--------------------------------------------------------------------------------
/big_vision/trainers/proj/flexi/common.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 Big Vision Authors.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | """Few common utils used in both/all flexi-trainers."""
16 | import functools
17 | import itertools
18 | import numpy as np
19 |
20 |
21 | def mkrng(xid, wid, step):
22 | # Need to cap at 0, for example localruns use -1.
23 | rng_key = (max(xid, 0), max(wid, 0), max(step, 0))
24 | return np.random.default_rng(rng_key)
25 |
26 |
27 | def mkprob(x):
28 | if x is None:
29 | return x
30 | return np.array(x) / np.sum(x)
31 |
32 |
33 | def choice(values, ratios, rng=None):
34 | rng = rng or np.random.default_rng()
35 | return rng.choice(values, p=mkprob(ratios))
36 |
37 |
38 | def mkpredictfns(predict_fn, config, template="predict_{x}"):
39 | # If we have two flexi args a=[1,2], b=[10,20], then we create a
40 | # predict_fn for all possible combinations, named "predict_a=1_b=10" etc.
41 | all_combinations = [dict(comb) for comb in itertools.product(
42 | *[[(arg, val) for val in config[arg].v] for arg in config]
43 | )]
44 | return {
45 | template.format(x="_".join(f"{k}={v}" for k, v in kw.items())):
46 | functools.partial(predict_fn, **kw)
47 | for kw in all_combinations}
48 |
--------------------------------------------------------------------------------
/big_vision/trainers/proj/givt/utils.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 Big Vision Authors.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | """Utils for GIVT stage I and II trainers."""
16 |
17 | from typing import Any
18 |
19 | import jax
20 | import jax.numpy as jnp
21 |
22 |
23 | def unbin_depth(
24 | depth: jax.Array,
25 | *,
26 | min_depth: float,
27 | max_depth: float,
28 | num_bins: int,
29 | ) -> jax.Array:
30 | """Transform a depth map with binned values into a float-valued depth map.
31 |
32 | Args:
33 | depth: Depth map whose binned values are encoded in one-hot fashion along
34 | the last dimension.
35 | min_depth: Minimum binned depth value.
36 | max_depth: Maximum value of binned depth.
37 | num_bins: Number of depth bins.
38 |
39 | Returns:
40 | Float-valued depth map.
41 | """
42 | depth = jnp.argmax(depth, axis=-1)
43 | depth = depth.astype(jnp.float32) + 0.5 # Undoes floor in expectation.
44 | depth /= num_bins
45 | return depth * (max_depth - min_depth) + min_depth
46 |
47 |
48 | def get_local_rng(
49 | seed: int | jax.Array,
50 | batch: Any,
51 | ) -> jax.Array:
52 | """Generate a per-image seed based on the image id or the image values.
53 |
54 | Args:
55 | seed: Random seed from which per-image seeds should be derived.
56 | batch: Pytree containing a batch of images (key "image") and optionally
57 | image ids (key "image/id").
58 |
59 | Returns:
60 | Array containing per-image ids.
61 | """
62 | fake_id = None
63 | if "image" in batch:
64 | fake_id = (10**6 * jax.vmap(jnp.mean)(batch["image"])).astype(jnp.int32)
65 | return jax.lax.scan(
66 | lambda k, x: (jax.random.fold_in(k, x), None),
67 | jax.random.PRNGKey(seed),
68 | batch.get("image/id", fake_id),
69 | )[0]
70 |
71 |
--------------------------------------------------------------------------------
/big_vision/trainers/proj/uvim/coco_utils.py:
--------------------------------------------------------------------------------
1 | # Copyright 2022 Big Vision Authors.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | """Utilities to inspect coco data and predictions in notebooks."""
16 | # pylint: disable=consider-using-from-import
17 | import functools
18 | import json
19 |
20 | import numpy as np
21 | from panopticapi import utils as pycoco_utils
22 | from skimage import segmentation
23 |
24 | import tensorflow.io.gfile as gfile
25 |
26 |
27 | import os
28 | ROOT = os.environ.get('COCO_DATA_DIR', '.')
29 |
30 |
31 | PANOPTIC_COCO_CATS_FILE = f'{ROOT}/panoptic_coco_categories.json'
32 |
33 |
34 | @functools.lru_cache(maxsize=None)
35 | def _coco_panoptic_categories():
36 | with gfile.GFile(PANOPTIC_COCO_CATS_FILE, 'r') as f:
37 | categories_list = json.load(f)
38 | return tuple(categories_list)
39 |
40 |
41 | def rgb_panoptic_from_twochannels(twochannels, boundaries: bool = False):
42 | """Makes a RGB panoptic output and segments_info from a twochannels view."""
43 | semantics = twochannels[..., 0]
44 | instances = twochannels[..., 1]
45 | max_instances = np.max(instances) + 1
46 | merged = semantics * max_instances + instances
47 | merged = np.where(semantics < 0, semantics, merged)
48 |
49 | categories_list = _coco_panoptic_categories()
50 | categories = {category['id']: category for category in categories_list}
51 | id_generator = pycoco_utils.IdGenerator(categories)
52 | segments_info = {}
53 | rgb = np.zeros((*instances.shape[:2], 3), dtype=np.uint8)
54 |
55 | for merged_id in np.unique(merged):
56 | if merged_id // max_instances > 0:
57 | category = categories_list[int(merged_id // max_instances) - 1]
58 | segment_id, color = id_generator.get_id_and_color(category['id'])
59 | else:
60 | category = {'id': -1, 'name': 'void', 'isthing': False}
61 | segment_id, color = -1, np.array([0, 0, 0])
62 | segments_info[segment_id] = {
63 | 'id': segment_id,
64 | 'color': color,
65 | 'category_id': category['id'],
66 | 'name': category['name'],
67 | 'isthing': category['isthing'],
68 | }
69 | rgb[merged == merged_id] = color
70 |
71 | if boundaries:
72 | boundaries = segmentation.find_boundaries(
73 | pycoco_utils.rgb2id(rgb), mode='thick')
74 | rgb[boundaries] = 0
75 | return rgb, segments_info
76 |
--------------------------------------------------------------------------------
/big_vision/trainers/proj/uvim/colorization_task.py:
--------------------------------------------------------------------------------
1 | # Copyright 2022 Big Vision Authors.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | """Inputs, outputs and losses for colorization task."""
16 | import einops
17 | import jax.numpy as jnp
18 | import numpy as np
19 |
20 | ONE_HOT_AXIS = -2
21 |
22 |
23 | def input_pp(batch, config):
24 | """Make inputs for colorization task."""
25 | if "labels" not in batch:
26 | # During predict of phase2 there is no 'labels' field.
27 | x = None
28 | else:
29 | hp, wp = config.model.patch_size
30 | x = {
31 | "color": batch["labels"],
32 | }
33 | # Convert labels from (B, H, W, C) to (B, num_patches, C, patch_size)
34 | x["color"] = einops.rearrange(
35 | x["color"], "b (hn hp) (wn wp) c -> b (hn wn) c (hp wp)", hp=hp, wp=wp)
36 | ctx = batch.get("image_ctx", batch.get("image", None))
37 | return {"ctx": ctx, "x": x}
38 |
39 |
40 | def loss_fn(logits, batch, config):
41 | """Compute loss for colorization task."""
42 | labels = input_pp(batch, config)["x"]
43 | error = logits["color"] - labels["color"]
44 | loss = jnp.square(error)
45 | return loss, {"loss_color": loss}
46 |
47 |
48 | def predict_outputs(logits, config):
49 | """Make outputs for colorization task."""
50 | # Map logits to (height, width, channels).
51 | hp, wp = config.model.patch_size
52 | hn, wn = np.array(config.model.input_size) // np.array((hp, wp))
53 | assert ONE_HOT_AXIS == -2, "Rearrange below depends on this."
54 | output = einops.rearrange(
55 | logits["color"],
56 | "b (hn wn) c (hp wp) -> b (hn hp) (wn wp) c",
57 | hn=hn,
58 | wn=wn,
59 | hp=hp,
60 | wp=wp)
61 | output = jnp.clip(output, -1., 1.)
62 | return {"color": output}
63 |
--------------------------------------------------------------------------------
/big_vision/trainers/proj/uvim/depth_task.py:
--------------------------------------------------------------------------------
1 | # Copyright 2022 Big Vision Authors.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | """Inputs, outputs and losses for depth prediction task."""
16 | import big_vision.utils as u
17 | import einops
18 | import jax
19 | import jax.numpy as jnp
20 | import numpy as np
21 |
22 |
23 | ONE_HOT_AXIS = -2
24 |
25 |
26 | def input_pp(batch, config):
27 | """Makes inputs for depth prediction task."""
28 | if "labels" not in batch:
29 | x = None
30 | else:
31 | hp, wp = config.model.patch_size
32 | depth = batch["labels"][..., 0]
33 |
34 | # Discretize to [0, ..., bins - 1].
35 | nbins = config.model.inputs.depth[ONE_HOT_AXIS]
36 | mind = config.min_depth
37 | maxd = config.max_depth
38 | depth = (depth - mind) / (maxd - mind)
39 | depth *= nbins
40 | depth = jnp.floor(depth).astype(jnp.int32)
41 | depth = jnp.minimum(depth, nbins - 1)
42 | depth = jnp.maximum(depth, 0)
43 |
44 | # Converts labels from (B, H, W, c) to (B, num_patches, c, patch_size).
45 | depth = jax.nn.one_hot(
46 | einops.rearrange(
47 | depth, "b (hn hp) (wn wp) -> b (hn wn) (hp wp)", hp=hp, wp=wp),
48 | num_classes=config.model.inputs.depth[ONE_HOT_AXIS],
49 | axis=ONE_HOT_AXIS)
50 | x = {"depth": depth}
51 | ctx = batch.get("image_ctx", batch.get("image", None))
52 | return {"ctx": ctx, "x": x}
53 |
54 |
55 | def loss_fn(predictions, batch, config):
56 | """Computes loss for depth prediction task."""
57 | labels = input_pp(batch, config)["x"]
58 | losses = {}
59 | loss = u.softmax_xent(
60 | logits=predictions["depth"], labels=labels["depth"], reduction=False,
61 | axis=ONE_HOT_AXIS)
62 | # Do not train on the closest class; usually regions of the image with
63 | # depth==0, which is the default for regions with no depth signal.
64 | # TODO: Encode depth==0 as class==-1.
65 | mask = jnp.argmax(labels["depth"], ONE_HOT_AXIS) != 0
66 | loss = loss * mask
67 | losses["loss_depth"] = loss
68 | return sum(losses.values()), losses
69 |
70 |
71 | def predict_outputs(predictions, config):
72 | """Makes outputs for depth predictin tasks."""
73 | # Maps predictions to (height, width, channels).
74 | hp, wp = config.model.patch_size
75 | hn, wn = np.array(config.model.input_size) // np.array((hp, wp))
76 | depth = einops.rearrange(
77 | predictions["depth"],
78 | "b (hn wn) c (hp wp) -> b (hn hp) (wn wp) c",
79 | hn=hn, wn=wn, hp=hp, wp=wp)
80 |
81 | depth = jnp.argmax(depth, axis=-1) # [B, H, W]
82 |
83 | # Revert discretization.
84 | nbins = config.model.inputs.depth[ONE_HOT_AXIS]
85 | mind = config.min_depth
86 | maxd = config.max_depth
87 | depth = depth.astype(jnp.float32) + 0.5 # Undoes floor in expectation.
88 | depth /= nbins
89 | depth = depth * (maxd - mind) + mind
90 |
91 | return {"depth": depth}
92 |
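The binning in input_pp() above and the reconstruction in predict_outputs() (likewise unbin_depth() in trainers/proj/givt/utils.py) are inverses up to quantization. With $N$ the number of bins and $d \in [d_{\min}, d_{\max}]$:

$$ b = \operatorname{clip}\big(\lfloor N\,(d - d_{\min})/(d_{\max} - d_{\min})\rfloor,\ 0,\ N-1\big), \qquad \hat d = \tfrac{b + 0.5}{N}\,(d_{\max} - d_{\min}) + d_{\min}, $$

i.e. every reconstructed depth value is the center of its bin.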
--------------------------------------------------------------------------------
/big_vision/trainers/proj/uvim/panoptic_task.py:
--------------------------------------------------------------------------------
1 | # Copyright 2022 Big Vision Authors.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | """Inputs, outputs and losses for panoptic task."""
16 | import big_vision.utils as u
17 | import einops
18 | import jax
19 | import jax.numpy as jnp
20 | import numpy as np
21 |
22 | ONE_HOT_AXIS = -2
23 |
24 |
25 | def input_pp(batch, config):
26 | """Make inputs for panoptic segmentation task."""
27 | if "labels" not in batch:
28 | # During predict of phase2 there is no 'labels' field.
29 | x = None
30 | else:
31 | hp, wp = config.model.patch_size
32 | x = {
33 | "semantics": batch["labels"][..., 0],
34 | "instances": batch["labels"][..., 1],
35 | }
36 | # Convert labels from (B, H, W) to (B, num_patches, num_classes, patch_size)
37 | for key in ["semantics", "instances"]:
38 | x[key] = jax.nn.one_hot(
39 | einops.rearrange(
40 | x[key], "b (hn hp) (wn wp) -> b (hn wn) (hp wp)", hp=hp, wp=wp),
41 | num_classes=config.model.inputs[key][ONE_HOT_AXIS], axis=ONE_HOT_AXIS)
42 | ctx = batch.get("image_ctx", batch.get("image", None))
43 | return {"ctx": ctx, "x": x}
44 |
45 |
46 | def loss_fn(logits, batch, config):
47 | """Compute loss for panoptic task."""
48 | labels = input_pp(batch, config)["x"]
49 | losses = {}
50 | for key in ["semantics", "instances"]:
51 | losses[f"loss_{key}"] = u.softmax_xent(
52 | logits=logits[key], labels=labels[key], reduction=False,
53 | axis=ONE_HOT_AXIS)
54 | return sum(losses.values()), losses
55 |
56 |
57 | def predict_outputs(logits, config, min_fraction=0.0):
58 | """Make outputs for panoptic segmentation task."""
59 | # Map logits to (height, width, channels).
60 | hp, wp = config.model.patch_size
61 | hn, wn = np.array(config.model.input_size) // np.array((hp, wp))
62 | outputs = {}
63 | for key in ["semantics", "instances"]:
64 | assert ONE_HOT_AXIS == -2, "Rearrange below depends on this."
65 | outputs[key] = einops.rearrange(
66 | logits[key],
67 | "b (hn wn) c (hp wp) -> b (hn hp) (wn wp) c",
68 | hn=hn, wn=wn, hp=hp, wp=wp)
69 | return panoptic_predictions_from_logits(
70 | **outputs, min_fraction=min_fraction)
71 |
72 |
73 | def panoptic_predictions_from_logits(semantics, instances, min_fraction=0.0):
74 | """Make panoptic prediction from logits."""
75 | ins = jnp.argmax(instances, axis=-1)
76 | # Note: Make sure each instance has all pixels annotated with same label.
77 | # Otherwise they are further split into more instances and greatly affect
78 | # the number of unmatched predicted segments (FP) and RQ.
79 | masks = jax.nn.one_hot(ins, instances.shape[-1], dtype=jnp.int32)
80 | label = jnp.argmax(jnp.einsum("bhwk,bhwn->bnk", semantics, masks), axis=-1)
81 | sem = jnp.einsum("bhwn,bn->bhw", masks, label)
82 | out = jnp.stack([sem, ins], axis=-1)
83 | # Filter out small objects
84 | fraction = jnp.sum(masks, axis=(1, 2), keepdims=True)/np.prod(ins.shape[1:3])
85 | mask_big = (fraction > min_fraction).astype("int32")
86 | mask_big_spatial = jnp.sum(masks * mask_big, axis=-1, keepdims=True) > 0
87 | return out * mask_big_spatial.astype("int32")
88 |
--------------------------------------------------------------------------------