├── requirements.txt
├── __init__.py
├── janus
│   ├── __init__.py
│   ├── utils
│   │   ├── __init__.py
│   │   ├── io.py
│   │   └── conversation.py
│   └── models
│       ├── __init__.py
│       ├── projector.py
│       ├── clip_encoder.py
│       ├── image_processing_vlm.py
│       ├── modeling_vlm.py
│       ├── processing_vlm.py
│       ├── vq_model.py
│       └── siglip_vit.py
├── README.md
├── .gitignore
├── JanusPro.py
└── LICENSE
/requirements.txt:
--------------------------------------------------------------------------------
1 | attrdict
2 |
--------------------------------------------------------------------------------
/__init__.py:
--------------------------------------------------------------------------------
1 | from .JanusPro import NODE_CLASS_MAPPINGS, NODE_DISPLAY_NAME_MAPPINGS
2 |
3 | __all__ = ['NODE_CLASS_MAPPINGS', 'NODE_DISPLAY_NAME_MAPPINGS']
4 |
--------------------------------------------------------------------------------
/janus/utils/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2023-2024 DeepSeek.
2 | #
3 | # Permission is hereby granted, free of charge, to any person obtaining a copy of
4 | # this software and associated documentation files (the "Software"), to deal in
5 | # the Software without restriction, including without limitation the rights to
6 | # use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of
7 | # the Software, and to permit persons to whom the Software is furnished to do so,
8 | # subject to the following conditions:
9 | #
10 | # The above copyright notice and this permission notice shall be included in all
11 | # copies or substantial portions of the Software.
12 | #
13 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
14 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS
15 | # FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR
16 | # COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER
17 | # IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN
18 | # CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
19 |
--------------------------------------------------------------------------------
/janus/models/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2023-2024 DeepSeek.
2 | #
3 | # Permission is hereby granted, free of charge, to any person obtaining a copy of
4 | # this software and associated documentation files (the "Software"), to deal in
5 | # the Software without restriction, including without limitation the rights to
6 | # use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of
7 | # the Software, and to permit persons to whom the Software is furnished to do so,
8 | # subject to the following conditions:
9 | #
10 | # The above copyright notice and this permission notice shall be included in all
11 | # copies or substantial portions of the Software.
12 | #
13 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
14 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS
15 | # FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR
16 | # COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER
17 | # IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN
18 | # CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
19 |
20 | from .image_processing_vlm import VLMImageProcessor
21 | from .modeling_vlm import MultiModalityCausalLM
22 | from .processing_vlm import VLChatProcessor
23 |
24 | __all__ = [
25 | "VLMImageProcessor",
26 | "VLChatProcessor",
27 | "MultiModalityCausalLM",
28 | ]
29 |
--------------------------------------------------------------------------------
/janus/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2023-2024 DeepSeek.
2 | #
3 | # Permission is hereby granted, free of charge, to any person obtaining a copy of
4 | # this software and associated documentation files (the "Software"), to deal in
5 | # the Software without restriction, including without limitation the rights to
6 | # use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of
7 | # the Software, and to permit persons to whom the Software is furnished to do so,
8 | # subject to the following conditions:
9 | #
10 | # The above copyright notice and this permission notice shall be included in all
11 | # copies or substantial portions of the Software.
12 | #
13 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
14 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS
15 | # FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR
16 | # COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER
17 | # IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN
18 | # CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
19 |
20 |
21 | # check if the python version is 3.10 or above
22 | import sys
23 |
24 | if sys.version_info >= (3, 10):
25 | print("Python version is above 3.10, patching the collections module.")
26 | # Monkey patch collections
27 | import collections
28 | import collections.abc
29 |
30 | for type_name in collections.abc.__all__:
31 | setattr(collections, type_name, getattr(collections.abc, type_name))
32 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 |
2 | 
3 |
4 |
5 | # ComfyUI-DeepSeek-JanusPro (project notes and details are still being polished; the code is already usable)
6 |
7 |
8 |
9 | ## Code completed entirely and independently by DeepSeek R1 (meaning: I did not write it, did not study the original project's code, and did not review the code)
10 |
11 | DeepSeek R1 successfully wrote a ComfyUI plugin for its own JanusPro (I did not write a single line!)
12 |
13 | Key point: previously an LLM only assisted me in writing plugins and I still had to understand the code myself; now I can hand the task to R1 almost without thinking and it delivers directly.
14 |
15 | It worked without any fine-tuning and without anyone reading or writing code, and the details are highly accurate. In the ideal case the number of interactions can be kept within 3-5 rounds (the bar being that the result runs successfully in ComfyUI right away). Subjectively the detail and accuracy feel better than O1's (further verification is still needed).
16 |
17 |
18 | ## The process in detail
19 |
20 | 1) My role: messenger and judge. I did not read the JanusPro code; everything was handed straight to R1.
21 |
22 | 2) Sample for R1 to learn from: the complete code of the Emu3 plugin I wrote myself (the two architectures are different).
23 |
24 | 3) The official JanusPro demo code was handed to R1.
25 |
26 | 4) R1 first split it into 3 core nodes, then wrote the complete code, optimizing it and adding compatibility considerations (enhancements); it also provided usage instructions and suggested parameter ranges.
27 |
28 | 5) The first run hit one error; after I raised it, R1 completed the fix.
29 |
30 | 6) The second run hit two errors, which were solved, but since the node for the second feature had not been run after those errors, I pointed out that it needed the same fix. R1 made the changes but missed some key formatting.
31 |
32 | 7) After the omissions were filled in, the first feature was implemented and ran normally.
33 |
34 | 8) For the second feature, R1 overthought and over-complicated things, drifting away from the original code. Once I noticed this, I asked it to check whether it had deviated from the original code; R1 reviewed the earlier errors, corrected the deviation, and the second feature was also implemented and ran successfully. The results are shown in the image below.
35 |
36 |
37 | ## Screenshots of part of the reasoning process
38 |
39 |
40 |
41 |
42 |
43 |
44 | ## Usage example:
45 |
46 |
47 |
48 |
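A minimal wiring sketch, based on the node definitions in `JanusPro.py` (parameter names as defined there):

- 🧩Janus Model Loader: set `model_path` (default `deepseek-ai/Janus-Pro-7B`); it outputs `model`, `processor`, and `tokenizer`.
- 🧩Janus Multimodal Understanding: connect the three loader outputs plus an `IMAGE`, enter a `question`, and tune `seed`, `top_p`, `temperature`, and `max_new_tokens` to get a text `response`.
- 🧩Janus Image Generation: connect the three loader outputs, enter a `prompt`, and tune `seed`, `cfg_weight`, and `temperature`; it returns a batch of 5 generated images (384×384, upscaled to 768×768).
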
49 | ## Changelog
50 |
51 | - 20250221 Added a cover image; this repo will be merged into a new, larger project: [DeepSeek|All-In-One|ComfyUI](https://github.com/ZHO-ZHO-ZHO/ComfyUI-DeepSeek-All-In-One)
52 |
53 | - 20250130 (the second day of the Lunar New Year)
54 |
55 | V1.0 Code completed entirely and independently by DeepSeek R1 (meaning: I did not write it, did not study the original project's code, and did not review the code)
56 |
57 | Project created
58 |
59 |
60 | ## Stars
61 |
62 | [Star History Chart](https://star-history.com/#ZHO-ZHO-ZHO/ComfyUI-DeepSeek-JanusPro&Date)
63 |
64 |
65 | ## About me
66 |
67 | 📬 **Contact me**:
68 | - Email: zhozho3965@gmail.com
69 | - QQ group: 839821928
70 |
71 | 🔗 **Social media**:
72 | - Personal page: [-Zho-](https://jike.city/zho)
73 | - Bilibili: [my Bilibili homepage](https://space.bilibili.com/484366804)
74 | - X (Twitter): [my Twitter](https://twitter.com/ZHO_ZHO_ZHO)
75 | - Xiaohongshu: [my Xiaohongshu profile](https://www.xiaohongshu.com/user/profile/63f11530000000001001e0c8?xhsshare=CopyLink&appuid=63f11530000000001001e0c8&apptime=1690528872)
76 |
77 | 💡 **Support me**:
78 | - Bilibili: [charge me on Bilibili](https://space.bilibili.com/484366804)
79 | - Afdian: [support me on Afdian](https://afdian.com/a/ZHOZHO)
80 |
81 |
82 | ## Credits
83 |
84 | [Janus](https://github.com/deepseek-ai/Janus/tree/main)
85 |
--------------------------------------------------------------------------------
/janus/utils/io.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2023-2024 DeepSeek.
2 | #
3 | # Permission is hereby granted, free of charge, to any person obtaining a copy of
4 | # this software and associated documentation files (the "Software"), to deal in
5 | # the Software without restriction, including without limitation the rights to
6 | # use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of
7 | # the Software, and to permit persons to whom the Software is furnished to do so,
8 | # subject to the following conditions:
9 | #
10 | # The above copyright notice and this permission notice shall be included in all
11 | # copies or substantial portions of the Software.
12 | #
13 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
14 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS
15 | # FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR
16 | # COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER
17 | # IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN
18 | # CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
19 |
20 | import json
21 | from typing import Dict, List
22 |
23 | import PIL.Image
24 | import torch
25 | import base64
26 | import io
27 | from transformers import AutoModelForCausalLM
28 |
29 | from janus.models import MultiModalityCausalLM, VLChatProcessor
30 |
31 |
32 | def load_pretrained_model(model_path: str):
33 | vl_chat_processor: VLChatProcessor = VLChatProcessor.from_pretrained(model_path)
34 | tokenizer = vl_chat_processor.tokenizer
35 |
36 | vl_gpt: MultiModalityCausalLM = AutoModelForCausalLM.from_pretrained(
37 | model_path, trust_remote_code=True
38 | )
39 | vl_gpt = vl_gpt.to(torch.bfloat16).cuda().eval()
40 |
41 | return tokenizer, vl_chat_processor, vl_gpt
42 |
43 |
44 | def load_pil_images(conversations: List[Dict[str, str]]) -> List[PIL.Image.Image]:
45 | """
46 |
47 | Support file path or base64 images.
48 |
49 | Args:
50 | conversations (List[Dict[str, str]]): the conversations with a list of messages. An example is :
51 | [
52 | {
53 | "role": "User",
54 | "content": "\nExtract all information from this image and convert them into markdown format.",
55 | "images": ["./examples/table_datasets.png"]
56 | },
57 | {"role": "Assistant", "content": ""},
58 | ]
59 |
60 | Returns:
61 | pil_images (List[PIL.Image.Image]): the list of PIL images.
62 |
63 | """
64 |
65 | pil_images = []
66 |
67 | for message in conversations:
68 | if "images" not in message:
69 | continue
70 |
71 | for image_data in message["images"]:
72 | if image_data.startswith("data:image"):
73 | # Image data is in base64 format
74 | _, image_data = image_data.split(",", 1)
75 | image_bytes = base64.b64decode(image_data)
76 | pil_img = PIL.Image.open(io.BytesIO(image_bytes))
77 | else:
78 | # Image data is a file path
79 | pil_img = PIL.Image.open(image_data)
80 | pil_img = pil_img.convert("RGB")
81 | pil_images.append(pil_img)
82 |
83 | return pil_images
84 |
85 |
86 | def load_json(filepath):
87 | with open(filepath, "r") as f:
88 | data = json.load(f)
89 | return data
90 |
--------------------------------------------------------------------------------
/janus/models/projector.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2023-2024 DeepSeek.
2 | #
3 | # Permission is hereby granted, free of charge, to any person obtaining a copy of
4 | # this software and associated documentation files (the "Software"), to deal in
5 | # the Software without restriction, including without limitation the rights to
6 | # use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of
7 | # the Software, and to permit persons to whom the Software is furnished to do so,
8 | # subject to the following conditions:
9 | #
10 | # The above copyright notice and this permission notice shall be included in all
11 | # copies or substantial portions of the Software.
12 | #
13 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
14 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS
15 | # FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR
16 | # COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER
17 | # IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN
18 | # CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
19 |
20 | from typing import Tuple, Union
21 |
22 | import torch
23 | import torch.nn as nn
24 | from attrdict import AttrDict
25 |
26 |
27 | class MlpProjector(nn.Module):
28 | def __init__(self, cfg):
29 | super().__init__()
30 |
31 | self.cfg = cfg
32 |
33 | if cfg.projector_type == "identity":
34 | modules = nn.Identity()
35 |
36 | elif cfg.projector_type == "linear":
37 | modules = nn.Linear(cfg.input_dim, cfg.n_embed)
38 |
39 | elif cfg.projector_type == "mlp_gelu":
40 | mlp_depth = cfg.get("depth", 1)
41 | modules = [nn.Linear(cfg.input_dim, cfg.n_embed)]
42 | for _ in range(1, mlp_depth):
43 | modules.append(nn.GELU())
44 | modules.append(nn.Linear(cfg.n_embed, cfg.n_embed))
45 | modules = nn.Sequential(*modules)
46 |
47 | elif cfg.projector_type == "low_high_hybrid_split_mlp_gelu":
48 | mlp_depth = cfg.get("depth", 1)
49 | self.high_up_proj = nn.Linear(cfg.input_dim, cfg.n_embed // 2)
50 | self.low_up_proj = nn.Linear(cfg.input_dim, cfg.n_embed // 2)
51 |
52 | modules = []
53 | for _ in range(1, mlp_depth):
54 | modules.append(nn.GELU())
55 | modules.append(nn.Linear(cfg.n_embed, cfg.n_embed))
56 | modules = nn.Sequential(*modules)
57 |
58 | else:
59 | raise ValueError(f"Unknown projector type: {cfg.projector_type}")
60 |
61 | self.layers = modules
62 |
63 | def forward(
64 | self, x_or_tuple: Union[Tuple[torch.Tensor, torch.Tensor], torch.Tensor]
65 | ):
66 | """
67 |
68 | Args:
69 | x_or_tuple (Union[Tuple[torch.Tensor, torch.Tensor], torch.Tensor]: if it is a tuple of torch.Tensor,
70 | then it comes from the hybrid vision encoder, and x = high_res_x, low_res_x);
71 | otherwise it is the feature from the single vision encoder.
72 |
73 | Returns:
74 | x (torch.Tensor): [b, s, c]
75 | """
76 |
77 | if isinstance(x_or_tuple, tuple):
78 | # self.cfg.projector_type == "low_high_hybrid_split_mlp_gelu":
79 | high_x, low_x = x_or_tuple
80 | high_x = self.high_up_proj(high_x)
81 | low_x = self.low_up_proj(low_x)
82 | x = torch.concat([high_x, low_x], dim=-1)
83 | else:
84 | x = x_or_tuple
85 |
86 | return self.layers(x)
87 |
88 |
89 | if __name__ == "__main__":
90 | cfg = AttrDict(
91 | input_dim=1024,
92 | n_embed=2048,
93 | depth=2,
94 | projector_type="low_high_hybrid_split_mlp_gelu",
95 | )
96 | inputs = (torch.rand(4, 576, 1024), torch.rand(4, 576, 1024))
97 |
98 | m = MlpProjector(cfg)
99 | out = m(inputs)
100 | print(out.shape)
101 |
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 |
6 | # C extensions
7 | *.so
8 |
9 | # Distribution / packaging
10 | .Python
11 | build/
12 | develop-eggs/
13 | dist/
14 | downloads/
15 | eggs/
16 | .eggs/
17 | lib/
18 | lib64/
19 | parts/
20 | sdist/
21 | var/
22 | wheels/
23 | share/python-wheels/
24 | *.egg-info/
25 | .installed.cfg
26 | *.egg
27 | MANIFEST
28 |
29 | # PyInstaller
30 | # Usually these files are written by a python script from a template
31 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
32 | *.manifest
33 | *.spec
34 |
35 | # Installer logs
36 | pip-log.txt
37 | pip-delete-this-directory.txt
38 |
39 | # Unit test / coverage reports
40 | htmlcov/
41 | .tox/
42 | .nox/
43 | .coverage
44 | .coverage.*
45 | .cache
46 | nosetests.xml
47 | coverage.xml
48 | *.cover
49 | *.py,cover
50 | .hypothesis/
51 | .pytest_cache/
52 | cover/
53 |
54 | # Translations
55 | *.mo
56 | *.pot
57 |
58 | # Django stuff:
59 | *.log
60 | local_settings.py
61 | db.sqlite3
62 | db.sqlite3-journal
63 |
64 | # Flask stuff:
65 | instance/
66 | .webassets-cache
67 |
68 | # Scrapy stuff:
69 | .scrapy
70 |
71 | # Sphinx documentation
72 | docs/_build/
73 |
74 | # PyBuilder
75 | .pybuilder/
76 | target/
77 |
78 | # Jupyter Notebook
79 | .ipynb_checkpoints
80 |
81 | # IPython
82 | profile_default/
83 | ipython_config.py
84 |
85 | # pyenv
86 | # For a library or package, you might want to ignore these files since the code is
87 | # intended to run in multiple environments; otherwise, check them in:
88 | # .python-version
89 |
90 | # pipenv
91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
94 | # install all needed dependencies.
95 | #Pipfile.lock
96 |
97 | # UV
98 | # Similar to Pipfile.lock, it is generally recommended to include uv.lock in version control.
99 | # This is especially recommended for binary packages to ensure reproducibility, and is more
100 | # commonly ignored for libraries.
101 | #uv.lock
102 |
103 | # poetry
104 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
105 | # This is especially recommended for binary packages to ensure reproducibility, and is more
106 | # commonly ignored for libraries.
107 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
108 | #poetry.lock
109 |
110 | # pdm
111 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
112 | #pdm.lock
113 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
114 | # in version control.
115 | # https://pdm.fming.dev/latest/usage/project/#working-with-version-control
116 | .pdm.toml
117 | .pdm-python
118 | .pdm-build/
119 |
120 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
121 | __pypackages__/
122 |
123 | # Celery stuff
124 | celerybeat-schedule
125 | celerybeat.pid
126 |
127 | # SageMath parsed files
128 | *.sage.py
129 |
130 | # Environments
131 | .env
132 | .venv
133 | env/
134 | venv/
135 | ENV/
136 | env.bak/
137 | venv.bak/
138 |
139 | # Spyder project settings
140 | .spyderproject
141 | .spyproject
142 |
143 | # Rope project settings
144 | .ropeproject
145 |
146 | # mkdocs documentation
147 | /site
148 |
149 | # mypy
150 | .mypy_cache/
151 | .dmypy.json
152 | dmypy.json
153 |
154 | # Pyre type checker
155 | .pyre/
156 |
157 | # pytype static type analyzer
158 | .pytype/
159 |
160 | # Cython debug symbols
161 | cython_debug/
162 |
163 | # PyCharm
164 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can
165 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
166 | # and can be added to the global gitignore or merged into this file. For a more nuclear
167 | # option (not recommended) you can uncomment the following to ignore the entire idea folder.
168 | #.idea/
169 |
170 | # PyPI configuration file
171 | .pypirc
172 |
--------------------------------------------------------------------------------
/janus/models/clip_encoder.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2023-2024 DeepSeek.
2 | #
3 | # Permission is hereby granted, free of charge, to any person obtaining a copy of
4 | # this software and associated documentation files (the "Software"), to deal in
5 | # the Software without restriction, including without limitation the rights to
6 | # use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of
7 | # the Software, and to permit persons to whom the Software is furnished to do so,
8 | # subject to the following conditions:
9 | #
10 | # The above copyright notice and this permission notice shall be included in all
11 | # copies or substantial portions of the Software.
12 | #
13 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
14 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS
15 | # FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR
16 | # COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER
17 | # IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN
18 | # CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
19 |
20 | from typing import Dict, List, Literal, Optional, Tuple, Union
21 |
22 | import torch
23 | import torch.nn as nn
24 | import torchvision.transforms
25 | from einops import rearrange
26 |
27 | from janus.models.siglip_vit import create_siglip_vit
28 |
29 |
30 | class CLIPVisionTower(nn.Module):
31 | def __init__(
32 | self,
33 | model_name: str = "siglip_large_patch16_384",
34 | image_size: Union[Tuple[int, int], int] = 336,
35 | select_feature: str = "patch",
36 | select_layer: int = -2,
37 | select_layers: list = None,
38 | ckpt_path: str = "",
39 | pixel_mean: Optional[List[float]] = None,
40 | pixel_std: Optional[List[float]] = None,
41 | **kwargs,
42 | ):
43 | super().__init__()
44 |
45 | self.model_name = model_name
46 | self.select_feature = select_feature
47 | self.select_layer = select_layer
48 | self.select_layers = select_layers
49 |
50 | vision_tower_params = {
51 | "model_name": model_name,
52 | "image_size": image_size,
53 | "ckpt_path": ckpt_path,
54 | "select_layer": select_layer,
55 | }
56 | vision_tower_params.update(kwargs)
57 | self.vision_tower, self.forward_kwargs = self.build_vision_tower(
58 | vision_tower_params
59 | )
60 |
61 | if pixel_mean is not None and pixel_std is not None:
62 | image_norm = torchvision.transforms.Normalize(
63 | mean=pixel_mean, std=pixel_std
64 | )
65 | else:
66 | image_norm = None
67 |
68 | self.image_norm = image_norm
69 |
70 | def build_vision_tower(self, vision_tower_params):
71 | if self.model_name.startswith("siglip"):
72 | self.select_feature = "same"
73 | vision_tower = create_siglip_vit(**vision_tower_params)
74 | forward_kwargs = dict()
75 |
76 | elif self.model_name.startswith("sam"):
77 | vision_tower = create_sam_vit(**vision_tower_params)
78 | forward_kwargs = dict()
79 |
80 | else: # huggingface
81 | from transformers import CLIPVisionModel
82 |
83 | vision_tower = CLIPVisionModel.from_pretrained(**vision_tower_params)
84 | forward_kwargs = dict(output_hidden_states=True)
85 |
86 | return vision_tower, forward_kwargs
87 |
88 | def feature_select(self, image_forward_outs):
89 | if isinstance(image_forward_outs, torch.Tensor):
90 | # the output is already the self.select_layer's features
91 | image_features = image_forward_outs
92 | else:
93 | image_features = image_forward_outs.hidden_states[self.select_layer]
94 |
95 | if self.select_feature == "patch":
96 | # if the output has cls_token
97 | image_features = image_features[:, 1:]
98 | elif self.select_feature == "cls_patch":
99 | image_features = image_features
100 | elif self.select_feature == "same":
101 | image_features = image_features
102 |
103 | else:
104 | raise ValueError(f"Unexpected select feature: {self.select_feature}")
105 | return image_features
106 |
107 | def forward(self, images):
108 | """
109 |
110 | Args:
111 | images (torch.Tensor): [b, 3, H, W]
112 |
113 | Returns:
114 | image_features (torch.Tensor): [b, n_patch, d]
115 | """
116 |
117 | if self.image_norm is not None:
118 | images = self.image_norm(images)
119 |
120 | image_forward_outs = self.vision_tower(images, **self.forward_kwargs)
121 | image_features = self.feature_select(image_forward_outs)
122 | return image_features
123 |
--------------------------------------------------------------------------------
/janus/models/image_processing_vlm.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2023-2024 DeepSeek.
2 | #
3 | # Permission is hereby granted, free of charge, to any person obtaining a copy of
4 | # this software and associated documentation files (the "Software"), to deal in
5 | # the Software without restriction, including without limitation the rights to
6 | # use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of
7 | # the Software, and to permit persons to whom the Software is furnished to do so,
8 | # subject to the following conditions:
9 | #
10 | # The above copyright notice and this permission notice shall be included in all
11 | # copies or substantial portions of the Software.
12 | #
13 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
14 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS
15 | # FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR
16 | # COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER
17 | # IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN
18 | # CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
19 |
20 | from typing import List, Tuple, Union
21 |
22 | import numpy as np
23 | import torch
24 | import torchvision
25 | import torchvision.transforms.functional
26 | from PIL import Image
27 | from transformers import AutoImageProcessor, PretrainedConfig
28 | from transformers.image_processing_utils import BaseImageProcessor, BatchFeature
29 | from transformers.image_utils import to_numpy_array
30 | from transformers.utils import logging
31 |
32 | logger = logging.get_logger(__name__)
33 |
34 | ImageType = Union[np.ndarray, torch.Tensor, Image.Image]
35 | IMAGENET_MEAN = (0.48145466, 0.4578275, 0.40821073)
36 | IMAGENET_STD = (0.26862954, 0.26130258, 0.27577711)
37 | IMAGENET_INCEPTION_MEAN = (0.5, 0.5, 0.5)
38 | IMAGENET_INCEPTION_STD = (0.5, 0.5, 0.5)
39 |
40 |
41 | def expand2square(pil_img, background_color):
42 | width, height = pil_img.size
43 | if width == height:
44 | return pil_img
45 | elif width > height:
46 | result = Image.new(pil_img.mode, (width, width), background_color)
47 | result.paste(pil_img, (0, (width - height) // 2))
48 | return result
49 | else:
50 | result = Image.new(pil_img.mode, (height, height), background_color)
51 | result.paste(pil_img, ((height - width) // 2, 0))
52 | return result
53 |
54 |
55 | class VLMImageProcessorConfig(PretrainedConfig):
56 | model_type = "deepseek_vlm"
57 | image_size: int
58 | min_size: int
59 | image_mean: Union[Tuple[float, float, float], List[float]]
60 | image_std: Union[Tuple[float, float, float], List[float]]
61 | rescale_factor: float
62 | do_normalize: bool
63 |
64 | def __init__(
65 | self,
66 | image_size: int,
67 | min_size: int = 14,
68 | image_mean: Union[Tuple[float, float, float], List[float]] = (
69 | 0.48145466,
70 | 0.4578275,
71 | 0.40821073,
72 | ),
73 | image_std: Union[Tuple[float, float, float], List[float]] = (
74 | 0.26862954,
75 | 0.26130258,
76 | 0.27577711,
77 | ),
78 | rescale_factor: float = 1.0 / 255.0,
79 | do_normalize: bool = True,
80 | **kwargs,
81 | ):
82 | self.image_size = image_size
83 | self.min_size = min_size
84 | self.image_mean = image_mean
85 | self.image_std = image_std
86 | self.rescale_factor = rescale_factor
87 | self.do_normalize = do_normalize
88 |
89 | super().__init__(**kwargs)
90 |
91 |
92 | class VLMImageProcessor(BaseImageProcessor):
93 | model_input_names = ["pixel_values"]
94 |
95 | def __init__(
96 | self,
97 | image_size: int,
98 | min_size: int = 14,
99 | image_mean: Union[Tuple[float, float, float], List[float]] = (
100 | 0.48145466,
101 | 0.4578275,
102 | 0.40821073,
103 | ),
104 | image_std: Union[Tuple[float, float, float], List[float]] = (
105 | 0.26862954,
106 | 0.26130258,
107 | 0.27577711,
108 | ),
109 | rescale_factor: float = 1.0 / 255.0,
110 | do_normalize: bool = True,
111 | **kwargs,
112 | ):
113 | super().__init__(**kwargs)
114 |
115 | self.image_size = image_size
116 | self.rescale_factor = rescale_factor
117 | self.image_mean = image_mean
118 | self.image_std = image_std
119 | self.min_size = min_size
120 | self.do_normalize = do_normalize
121 |
122 | if image_mean is None:
123 | self.background_color = (127, 127, 127)
124 | else:
125 | self.background_color = tuple([int(x * 255) for x in image_mean])
126 |
127 | def resize(self, pil_img: Image) -> np.ndarray:
128 | """
129 |
130 | Args:
131 | pil_img (PIL.Image): [H, W, 3] in PIL.Image in RGB
132 |
133 | Returns:
134 | x (np.ndarray): [3, self.image_size, self.image_size]
135 | """
136 |
137 | width, height = pil_img.size
138 | max_size = max(width, height)
139 |
140 | size = [
141 | max(int(height / max_size * self.image_size), self.min_size),
142 | max(int(width / max_size * self.image_size), self.min_size),
143 | ]
144 |
145 | if width <= 0 or height <= 0 or size[0] <= 0 or size[1] <= 0:
146 | print(f"orig size = {pil_img.size}, new size = {size}")
147 | raise ValueError("Invalid size!")
148 |
149 | pil_img = torchvision.transforms.functional.resize(
150 | pil_img,
151 | size,
152 | interpolation=torchvision.transforms.functional.InterpolationMode.BICUBIC,
153 | antialias=True,
154 | )
155 |
156 | pil_img = expand2square(pil_img, self.background_color)
157 | x = to_numpy_array(pil_img)
158 |
159 | # [H, W, 3] -> [3, H, W]
160 | x = np.transpose(x, (2, 0, 1))
161 |
162 | return x
163 |
164 | def preprocess(self, images, return_tensors: str = "pt", **kwargs) -> BatchFeature:
165 | # resize and pad to [self.image_size, self.image_size]
166 | # then convert from [H, W, 3] to [3, H, W]
167 | images: List[np.ndarray] = [self.resize(image) for image in images]
168 |
169 | # rescale from [0, 255] -> [0, 1]
170 | images = [
171 | self.rescale(
172 | image=image,
173 | scale=self.rescale_factor,
174 | input_data_format="channels_first",
175 | )
176 | for image in images
177 | ]
178 |
179 | # normalize
180 | if self.do_normalize:
181 | images = [
182 | self.normalize(
183 | image=image,
184 | mean=self.image_mean,
185 | std=self.image_std,
186 | input_data_format="channels_first",
187 | )
188 | for image in images
189 | ]
190 |
191 | data = {"pixel_values": images}
192 | return BatchFeature(data=data, tensor_type=return_tensors)
193 |
194 | @property
195 | def default_shape(self):
196 | return [3, self.image_size, self.image_size]
197 |
198 |
199 | AutoImageProcessor.register(VLMImageProcessorConfig, VLMImageProcessor)
200 |
201 |
202 | if __name__ == "__main__":
203 | image_processor = VLMImageProcessor(
204 | image_size=1024,
205 | image_mean=IMAGENET_INCEPTION_MEAN,
206 | image_std=IMAGENET_INCEPTION_STD,
207 | do_normalize=True,
208 | )
209 |
--------------------------------------------------------------------------------
/janus/models/modeling_vlm.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2023-2024 DeepSeek.
2 | #
3 | # Permission is hereby granted, free of charge, to any person obtaining a copy of
4 | # this software and associated documentation files (the "Software"), to deal in
5 | # the Software without restriction, including without limitation the rights to
6 | # use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of
7 | # the Software, and to permit persons to whom the Software is furnished to do so,
8 | # subject to the following conditions:
9 | #
10 | # The above copyright notice and this permission notice shall be included in all
11 | # copies or substantial portions of the Software.
12 | #
13 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
14 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS
15 | # FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR
16 | # COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER
17 | # IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN
18 | # CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
19 |
20 | import torch
21 | from attrdict import AttrDict
22 | from einops import rearrange
23 | from transformers import (
24 | AutoConfig,
25 | AutoModelForCausalLM,
26 | LlamaConfig,
27 | LlamaForCausalLM,
28 | PreTrainedModel,
29 | )
30 | from transformers.configuration_utils import PretrainedConfig
31 |
32 | from janus.models.clip_encoder import CLIPVisionTower
33 | from janus.models.projector import MlpProjector
34 |
35 |
36 | class vision_head(torch.nn.Module):
37 | def __init__(self, params):
38 | super().__init__()
39 | self.output_mlp_projector = torch.nn.Linear(
40 | params.n_embed, params.image_token_embed
41 | )
42 | self.vision_activation = torch.nn.GELU()
43 | self.vision_head = torch.nn.Linear(
44 | params.image_token_embed, params.image_token_size
45 | )
46 |
47 | def forward(self, x):
48 | x = self.output_mlp_projector(x)
49 | x = self.vision_activation(x)
50 | x = self.vision_head(x)
51 | return x
52 |
53 |
54 | def model_name_to_cls(cls_name):
55 | if "MlpProjector" in cls_name:
56 | cls = MlpProjector
57 |
58 | elif "CLIPVisionTower" in cls_name:
59 | cls = CLIPVisionTower
60 |
61 | elif "VQ" in cls_name:
62 | from janus.models.vq_model import VQ_models
63 |
64 | cls = VQ_models[cls_name]
65 | elif "vision_head" in cls_name:
66 | cls = vision_head
67 | else:
68 | raise ValueError(f"class_name {cls_name} is invalid.")
69 |
70 | return cls
71 |
72 |
73 | class VisionConfig(PretrainedConfig):
74 | model_type = "vision"
75 | cls: str = ""
76 | params: AttrDict = {}
77 |
78 | def __init__(self, **kwargs):
79 | super().__init__(**kwargs)
80 |
81 | self.cls = kwargs.get("cls", "")
82 | if not isinstance(self.cls, str):
83 | self.cls = self.cls.__name__
84 |
85 | self.params = AttrDict(kwargs.get("params", {}))
86 |
87 |
88 | class AlignerConfig(PretrainedConfig):
89 | model_type = "aligner"
90 | cls: str = ""
91 | params: AttrDict = {}
92 |
93 | def __init__(self, **kwargs):
94 | super().__init__(**kwargs)
95 |
96 | self.cls = kwargs.get("cls", "")
97 | if not isinstance(self.cls, str):
98 | self.cls = self.cls.__name__
99 |
100 | self.params = AttrDict(kwargs.get("params", {}))
101 |
102 |
103 | class GenVisionConfig(PretrainedConfig):
104 | model_type = "gen_vision"
105 | cls: str = ""
106 | params: AttrDict = {}
107 |
108 | def __init__(self, **kwargs):
109 | super().__init__(**kwargs)
110 |
111 | self.cls = kwargs.get("cls", "")
112 | if not isinstance(self.cls, str):
113 | self.cls = self.cls.__name__
114 |
115 | self.params = AttrDict(kwargs.get("params", {}))
116 |
117 |
118 | class GenAlignerConfig(PretrainedConfig):
119 | model_type = "gen_aligner"
120 | cls: str = ""
121 | params: AttrDict = {}
122 |
123 | def __init__(self, **kwargs):
124 | super().__init__(**kwargs)
125 |
126 | self.cls = kwargs.get("cls", "")
127 | if not isinstance(self.cls, str):
128 | self.cls = self.cls.__name__
129 |
130 | self.params = AttrDict(kwargs.get("params", {}))
131 |
132 |
133 | class GenHeadConfig(PretrainedConfig):
134 | model_type = "gen_head"
135 | cls: str = ""
136 | params: AttrDict = {}
137 |
138 | def __init__(self, **kwargs):
139 | super().__init__(**kwargs)
140 |
141 | self.cls = kwargs.get("cls", "")
142 | if not isinstance(self.cls, str):
143 | self.cls = self.cls.__name__
144 |
145 | self.params = AttrDict(kwargs.get("params", {}))
146 |
147 |
148 | class MultiModalityConfig(PretrainedConfig):
149 | model_type = "multi_modality"
150 | vision_config: VisionConfig
151 | aligner_config: AlignerConfig
152 |
153 | gen_vision_config: GenVisionConfig
154 | gen_aligner_config: GenAlignerConfig
155 | gen_head_config: GenHeadConfig
156 |
157 | language_config: LlamaConfig
158 |
159 | def __init__(self, **kwargs):
160 | super().__init__(**kwargs)
161 | vision_config = kwargs.get("vision_config", {})
162 | self.vision_config = VisionConfig(**vision_config)
163 |
164 | aligner_config = kwargs.get("aligner_config", {})
165 | self.aligner_config = AlignerConfig(**aligner_config)
166 |
167 | gen_vision_config = kwargs.get("gen_vision_config", {})
168 | self.gen_vision_config = GenVisionConfig(**gen_vision_config)
169 |
170 | gen_aligner_config = kwargs.get("gen_aligner_config", {})
171 | self.gen_aligner_config = GenAlignerConfig(**gen_aligner_config)
172 |
173 | gen_head_config = kwargs.get("gen_head_config", {})
174 | self.gen_head_config = GenHeadConfig(**gen_head_config)
175 |
176 | language_config = kwargs.get("language_config", {})
177 | if isinstance(language_config, LlamaConfig):
178 | self.language_config = language_config
179 | else:
180 | self.language_config = LlamaConfig(**language_config)
181 |
182 |
183 | class MultiModalityPreTrainedModel(PreTrainedModel):
184 | config_class = MultiModalityConfig
185 | base_model_prefix = "multi_modality"
186 | _no_split_modules = []
187 | _skip_keys_device_placement = "past_key_values"
188 |
189 |
190 | class MultiModalityCausalLM(MultiModalityPreTrainedModel):
191 | def __init__(self, config: MultiModalityConfig):
192 | super().__init__(config)
193 |
194 | vision_config = config.vision_config
195 | vision_cls = model_name_to_cls(vision_config.cls)
196 | self.vision_model = vision_cls(**vision_config.params)
197 |
198 | aligner_config = config.aligner_config
199 | aligner_cls = model_name_to_cls(aligner_config.cls)
200 | self.aligner = aligner_cls(aligner_config.params)
201 |
202 | gen_vision_config = config.gen_vision_config
203 | gen_vision_cls = model_name_to_cls(gen_vision_config.cls)
204 | self.gen_vision_model = gen_vision_cls()
205 |
206 | gen_aligner_config = config.gen_aligner_config
207 | gen_aligner_cls = model_name_to_cls(gen_aligner_config.cls)
208 | self.gen_aligner = gen_aligner_cls(gen_aligner_config.params)
209 |
210 | gen_head_config = config.gen_head_config
211 | gen_head_cls = model_name_to_cls(gen_head_config.cls)
212 | self.gen_head = gen_head_cls(gen_head_config.params)
213 |
214 | self.gen_embed = torch.nn.Embedding(
215 | gen_vision_config.params.image_token_size, gen_vision_config.params.n_embed
216 | )
217 |
218 | language_config = config.language_config
219 | self.language_model = LlamaForCausalLM(language_config)
220 |
221 | def prepare_inputs_embeds(
222 | self,
223 | input_ids: torch.LongTensor,
224 | pixel_values: torch.FloatTensor,
225 | images_seq_mask: torch.LongTensor,
226 | images_emb_mask: torch.LongTensor,
227 | **kwargs,
228 | ):
229 | """
230 |
231 | Args:
232 | input_ids (torch.LongTensor): [b, T]
233 | pixel_values (torch.FloatTensor): [b, n_images, 3, h, w]
234 | images_seq_mask (torch.BoolTensor): [b, T]
235 | images_emb_mask (torch.BoolTensor): [b, n_images, n_image_tokens]
236 |
237 | assert torch.sum(images_seq_mask) == torch.sum(images_emb_mask)
238 |
239 | Returns:
240 | input_embeds (torch.Tensor): [b, T, D]
241 | """
242 |
243 | bs, n = pixel_values.shape[0:2]
244 | images = rearrange(pixel_values, "b n c h w -> (b n) c h w")
245 | # [b x n, T2, D]
246 | images_embeds = self.aligner(self.vision_model(images))
247 |
248 | # [b x n, T2, D] -> [b, n x T2, D]
249 | images_embeds = rearrange(images_embeds, "(b n) t d -> b (n t) d", b=bs, n=n)
250 | # [b, n, T2] -> [b, n x T2]
251 | images_emb_mask = rearrange(images_emb_mask, "b n t -> b (n t)")
252 |
253 | # [b, T, D]
254 | input_ids[input_ids < 0] = 0 # ignore the image embeddings
255 | inputs_embeds = self.language_model.get_input_embeddings()(input_ids)
256 |
257 | # replace with the image embeddings
258 | inputs_embeds[images_seq_mask] = images_embeds[images_emb_mask]
259 |
260 | return inputs_embeds
261 |
262 | def prepare_gen_img_embeds(self, image_ids: torch.LongTensor):
263 | return self.gen_aligner(self.gen_embed(image_ids))
264 |
265 |
266 | AutoConfig.register("vision", VisionConfig)
267 | AutoConfig.register("aligner", AlignerConfig)
268 | AutoConfig.register("gen_vision", GenVisionConfig)
269 | AutoConfig.register("gen_aligner", GenAlignerConfig)
270 | AutoConfig.register("gen_head", GenHeadConfig)
271 | AutoConfig.register("multi_modality", MultiModalityConfig)
272 | AutoModelForCausalLM.register(MultiModalityConfig, MultiModalityCausalLM)
273 |
--------------------------------------------------------------------------------
/JanusPro.py:
--------------------------------------------------------------------------------
1 | import sys
2 | import os
3 | import torch
4 | import numpy as np
5 | import folder_paths
6 | import time
7 | import re
8 | from PIL import Image
9 | from transformers import AutoConfig, AutoModelForCausalLM
10 |
11 | # Critical path handling: add the current directory to the system path
12 | current_dir = os.path.dirname(os.path.abspath(__file__))
13 | sys.path.insert(0, current_dir) # add the current directory to the Python path
14 |
15 | try:
16 | from janus.models import MultiModalityCausalLM, VLChatProcessor
17 | from janus.utils.io import load_pil_images
18 | except ImportError as e:
19 | print(f"路径调试信息:")
20 | print(f"当前目录: {current_dir}")
21 | print(f"目录内容: {os.listdir(current_dir)}")
22 | print(f"sys.path: {sys.path}")
23 | raise
24 |
25 | # Add the model path configuration
26 | current_directory = os.path.dirname(os.path.abspath(__file__))
27 | folder_paths.folder_names_and_paths["Janus"] = ([os.path.join(folder_paths.models_dir, "Janus")], folder_paths.supported_pt_extensions)
28 |
29 | # Helper functions
30 | def tensor2pil(image):
31 | return Image.fromarray(np.clip(255. * image.cpu().numpy().squeeze(), 0, 255).astype(np.uint8))
32 |
33 | def pil2tensor(image):
34 | return torch.from_numpy(np.array(image).astype(np.float32) / 255.0).unsqueeze(0)
35 |
36 | class Janus_ModelLoader:
37 | def __init__(self):
38 | pass
39 |
40 | @classmethod
41 | def INPUT_TYPES(cls):
42 | return {
43 | "required": {
44 | "model_path": ("STRING", {"default": "deepseek-ai/Janus-Pro-7B"}),
45 | }
46 | }
47 |
48 | RETURN_TYPES = ("JANUS_MODEL", "PROCESSOR", "TOKENIZER")
49 | RETURN_NAMES = ("model", "processor", "tokenizer")
50 | FUNCTION = "load_model"
51 | CATEGORY = "🧩Janus"
52 |
53 | def load_model(self, model_path):
54 | # Load the config
55 | config = AutoConfig.from_pretrained(model_path)
56 | language_config = config.language_config
57 | language_config._attn_implementation = 'eager'
58 |
59 | # Load the model
60 | vl_gpt = AutoModelForCausalLM.from_pretrained(
61 | model_path,
62 | language_config=language_config,
63 | trust_remote_code=True
64 | ).to(torch.bfloat16 if torch.cuda.is_available() else torch.float16)
65 |
66 | if torch.cuda.is_available():
67 | vl_gpt = vl_gpt.cuda()
68 |
69 | # Load the processor
70 | processor = VLChatProcessor.from_pretrained(model_path)
71 | tokenizer = processor.tokenizer
72 |
73 | return (vl_gpt, processor, tokenizer)
74 |
75 | class Janus_MultimodalUnderstanding:
76 | @classmethod
77 | def INPUT_TYPES(cls):
78 | return {
79 | "required": {
80 | "model": ("JANUS_MODEL",),
81 | "processor": ("PROCESSOR",),
82 | "tokenizer": ("TOKENIZER",),
83 | "image": ("IMAGE",),
84 | "question": ("STRING", {"default": "describe the image", "multiline": True}),
85 | "seed": ("INT", {"default": 42, "min": 0, "max": 0xffffffffffffffff}),
86 | "top_p": ("FLOAT", {"default": 0.95, "min": 0.0, "max": 1.0, "step": 0.05}),
87 | "temperature": ("FLOAT", {"default": 0.1, "min": 0.0, "max": 1.0, "step": 0.05}),
88 | },
89 | "optional": {
90 | "max_new_tokens": ("INT", {"default": 512, "min": 16, "max": 2048}),
91 | }
92 | }
93 |
94 | RETURN_TYPES = ("STRING",)
95 | RETURN_NAMES = ("response",)
96 | FUNCTION = "understand"
97 | CATEGORY = "🧩Janus"
98 |
99 | def understand(self, model, processor, tokenizer, image, question, seed, top_p, temperature, max_new_tokens=512):
100 | # Clamp the seed into a valid range
101 | seed = seed % (2**32)
102 |
103 | # Set the random seeds (with CUDA synchronization)
104 | torch.manual_seed(seed)
105 | np.random.seed(seed % (2**32 - 1)) # fit numpy's seed range
106 | if torch.cuda.is_available():
107 | torch.cuda.manual_seed_all(seed)
108 | torch.cuda.synchronize()
109 |
110 | try:
111 | # Image preprocessing (with dimension validation)
112 | if isinstance(image, list):
113 | image_tensor = image[0]
114 | else:
115 | image_tensor = image
116 |
117 | pil_image = tensor2pil(image_tensor)
118 | if pil_image.mode != "RGB":
119 | pil_image = pil_image.convert("RGB")
120 |
121 | # Build the conversation (with exception handling)
122 | try:
123 | conversation = [{
124 | "role": "<|User|>",
125 | "content": f"\n{question}",
126 | "images": [pil_image],
127 | }, {
128 | "role": "<|Assistant|>",
129 | "content": ""
130 | }]
131 | except Exception as e:
132 | print(f"对话构建失败: {e}")
133 | return ("Error: Invalid conversation format",)
134 |
135 | # Process the inputs (with shape debugging)
136 | try:
137 | prepare_inputs = processor(
138 | conversations=conversation,
139 | images=[pil_image],
140 | force_batchify=True
141 | ).to(model.device, dtype=torch.bfloat16 if torch.cuda.is_available() else torch.float16)
142 |
143 | print(f"输入张量形状 - input_ids: {prepare_inputs.input_ids.shape}")
144 | print(f"注意力掩码形状: {prepare_inputs.attention_mask.shape}")
145 | except Exception as e:
146 | print(f"输入处理失败: {e}")
147 | return ("Error: Input processing failed",)
148 |
149 | # Generation (with parameter validation)
150 | try:
151 | inputs_embeds = model.prepare_inputs_embeds(**prepare_inputs)
152 | print(f"输入嵌入形状: {inputs_embeds.shape}")
153 |
154 | generation_config = {
155 | "inputs_embeds": inputs_embeds,
156 | "attention_mask": prepare_inputs.attention_mask,
157 | "pad_token_id": tokenizer.eos_token_id,
158 | "bos_token_id": tokenizer.bos_token_id,
159 | "eos_token_id": tokenizer.eos_token_id,
160 | "max_new_tokens": max_new_tokens,
161 | "do_sample": temperature > 0,
162 | "temperature": temperature if temperature > 0 else 1.0,
163 | "top_p": top_p,
164 | }
165 |
166 | # Run generation (with timing)
167 | start_time = time.time()
168 | outputs = model.language_model.generate(**generation_config)
169 | print(f"生成耗时: {time.time() - start_time:.2f}秒")
170 |
171 | except Exception as e:
172 | print(f"生成失败: {e}")
173 | return ("Error: Generation failed",)
174 |
175 | # Decode the output (with exception handling)
176 | try:
177 | full_output = outputs[0].cpu().tolist()
178 | answer = tokenizer.decode(full_output, skip_special_tokens=True)
179 |
180 | # Strip special tokens
181 | clean_pattern = r'<\|.*?\|>'
182 | clean_answer = re.sub(clean_pattern, '', answer).strip()
183 |
184 | return (clean_answer,)
185 |
186 | except Exception as e:
187 | print(f"解码失败: {e}")
188 | return ("Error: Output decoding failed",)
189 |
190 | except Exception as e:
191 | print(f"处理过程中出现未捕获的异常: {e}")
192 | return ("Error: Unexpected processing error",)
193 |
194 |
195 | class Janus_ImageGeneration:
196 | @classmethod
197 | def INPUT_TYPES(cls):
198 | return {
199 | "required": {
200 | "model": ("JANUS_MODEL",),
201 | "processor": ("PROCESSOR",),
202 | "tokenizer": ("TOKENIZER",),
203 | "prompt": ("STRING", {"multiline": True, "default": "Master shifu racoon wearing drip attire"}),
204 | "seed": ("INT", {"default": 12345, "min": 0, "max": 0xffffffffffffffff}),
205 | "cfg_weight": ("FLOAT", {"default": 5.0, "min": 1.0, "max": 10.0, "step": 0.5}),
206 | "temperature": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.05}),
207 | }
208 | }
209 |
210 | RETURN_TYPES = ("IMAGE",)
211 | RETURN_NAMES = ("images",)
212 | FUNCTION = "generate"
213 | CATEGORY = "🧩Janus"
214 |
215 | def generate(self, model, processor, tokenizer, prompt, seed, cfg_weight, temperature):
216 | # Clear the CUDA cache and set the seeds
217 | torch.cuda.empty_cache()
218 | seed = seed % (2**32)
219 | torch.manual_seed(seed)
220 | np.random.seed(seed)
221 | if torch.cuda.is_available():
222 | torch.cuda.manual_seed_all(seed)
223 |
224 | # Fixed parameters (kept consistent with the original demo code)
225 | width = 384
226 | height = 384
227 | parallel_size = 5
228 | patch_size = 16
229 | image_token_num = 576 # (384 // 16) ** 2 latent tokens per image
230 |
231 | # Build the input text
232 | messages = [{'role': '<|User|>', 'content': prompt},
233 | {'role': '<|Assistant|>', 'content': ''}]
234 | text = processor.apply_sft_template_for_multi_turn_prompts(
235 | conversations=messages,
236 | sft_format=processor.sft_format,
237 | system_prompt=''
238 | ) + processor.image_start_tag
239 |
240 | # Encode the prompt into input IDs
241 | input_ids = torch.LongTensor(tokenizer.encode(text)).to(model.device)
242 |
243 | # Initialize the tokens (strictly keeping the original structure)
244 | tokens = torch.zeros((parallel_size * 2, len(input_ids)), dtype=torch.int, device=model.device)
245 | for i in range(parallel_size * 2):
246 | tokens[i, :] = input_ids
247 | if i % 2 != 0:
248 | tokens[i, 1:-1] = processor.pad_id
249 |
250 | # Generation loop (keeping the original loop structure)
251 | inputs_embeds = model.language_model.get_input_embeddings()(tokens)
252 | generated_tokens = torch.zeros((parallel_size, image_token_num), dtype=torch.int, device=model.device)
253 |
254 | pkv = None
255 | for i in range(image_token_num):
256 | with torch.no_grad():
257 | outputs = model.language_model.model(
258 | inputs_embeds=inputs_embeds,
259 | use_cache=True,
260 | past_key_values=pkv
261 | )
262 | pkv = outputs.past_key_values
263 |
264 | # Original classifier-free guidance: even rows are conditional, odd rows are unconditional
265 | logits = model.gen_head(outputs.last_hidden_state[:, -1, :])
266 | logit_cond = logits[0::2, :]
267 | logit_uncond = logits[1::2, :]
268 | logits = logit_uncond + cfg_weight * (logit_cond - logit_uncond)
269 |
270 | # Sampling
271 | probs = torch.softmax(logits / temperature, dim=-1)
272 | next_token = torch.multinomial(probs, num_samples=1)
273 | generated_tokens[:, i] = next_token.squeeze(dim=-1)
274 |
275 | # Prepare the next-step input (keeping the original view operations)
276 | next_token = torch.cat([next_token.unsqueeze(1), next_token.unsqueeze(1)], dim=1).view(-1)
277 | img_embeds = model.prepare_gen_img_embeds(next_token)
278 | inputs_embeds = img_embeds.unsqueeze(dim=1)
279 |
280 | # Decode the image tokens (strictly keeping the original implementation)
281 | patches = model.gen_vision_model.decode_code(
282 | generated_tokens.to(dtype=torch.int),
283 | shape=[parallel_size, 8, width//patch_size, height//patch_size]
284 | )
285 |
286 | # Post-processing (original unpack logic)
287 | dec = patches.to(torch.float32).cpu().numpy().transpose(0, 2, 3, 1)
288 | dec = np.clip((dec + 1) / 2 * 255, 0, 255).astype(np.uint8)
289 | visual_img = np.zeros((parallel_size, width, height, 3), dtype=np.uint8)
290 | visual_img[:, :, :] = dec
291 |
292 | # Convert to the ComfyUI image format
293 | output_images = []
294 | for i in range(parallel_size):
295 | pil_img = Image.fromarray(visual_img[i]).resize((768, 768), Image.LANCZOS)
296 | output_images.append(pil2tensor(pil_img))
297 |
298 | return (torch.cat(output_images, dim=0),)
299 |
300 |
301 | NODE_CLASS_MAPPINGS = {
302 | "Janus_ModelLoader": Janus_ModelLoader,
303 | "Janus_MultimodalUnderstanding": Janus_MultimodalUnderstanding,
304 | "Janus_ImageGeneration": Janus_ImageGeneration
305 | }
306 |
307 | NODE_DISPLAY_NAME_MAPPINGS = {
308 | "Janus_ModelLoader": "🧩Janus Model Loader",
309 | "Janus_MultimodalUnderstanding": "🧩Janus Multimodal Understanding",
310 | "Janus_ImageGeneration": "🧩Janus Image Generation"
311 | }
312 |
--------------------------------------------------------------------------------
/janus/utils/conversation.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2023-2024 DeepSeek.
2 | #
3 | # Permission is hereby granted, free of charge, to any person obtaining a copy of
4 | # this software and associated documentation files (the "Software"), to deal in
5 | # the Software without restriction, including without limitation the rights to
6 | # use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of
7 | # the Software, and to permit persons to whom the Software is furnished to do so,
8 | # subject to the following conditions:
9 | #
10 | # The above copyright notice and this permission notice shall be included in all
11 | # copies or substantial portions of the Software.
12 | #
13 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
14 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS
15 | # FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR
16 | # COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER
17 | # IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN
18 | # CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
19 |
20 | """
21 | From https://github.com/lm-sys/FastChat/blob/main/fastchat/conversation.py
22 | """
23 |
24 | import dataclasses
25 | from enum import IntEnum, auto
26 | from typing import Dict, List
27 |
28 |
29 | class SeparatorStyle(IntEnum):
30 | """Separator styles."""
31 |
32 | ADD_COLON_SINGLE = auto()
33 | ADD_COLON_TWO = auto()
34 | ADD_COLON_SPACE_SINGLE = auto()
35 | NO_COLON_SINGLE = auto()
36 | NO_COLON_TWO = auto()
37 | ADD_NEW_LINE_SINGLE = auto()
38 | LLAMA2 = auto()
39 | CHATGLM = auto()
40 | CHATML = auto()
41 | CHATINTERN = auto()
42 | DOLLY = auto()
43 | RWKV = auto()
44 | PHOENIX = auto()
45 | ROBIN = auto()
46 | DeepSeek = auto()
47 | PLAIN = auto()
48 | ALIGNMENT = auto()
49 |
50 |
51 | @dataclasses.dataclass
52 | class Conversation:
53 | """A class that manages prompt templates and keeps all conversation history."""
54 |
55 | # The name of this template
56 | name: str
57 | # The template of the system prompt
58 | system_template: str = "{system_message}"
59 | # The system message
60 | system_message: str = ""
61 | # The names of two roles
62 | roles: List[str] = (("USER", "ASSISTANT"),)
63 | # All messages. Each item is (role, message).
64 | messages: List[List[str]] = ()
65 | # The number of few shot examples
66 | offset: int = 0
67 | # The separator style and configurations
68 | sep_style: SeparatorStyle = SeparatorStyle.ADD_COLON_SINGLE
69 | sep: str = "\n"
70 | sep2: str = None
71 | # Stop criteria (the default one is EOS token)
72 | stop_str: str = None
73 | # Stops generation if meeting any token in this list
74 | stop_token_ids: List[int] = None
75 |
76 | def get_prompt(self) -> str:
77 | """Get the prompt for generation."""
78 | system_prompt = self.system_template.format(system_message=self.system_message)
79 |
80 | if self.sep_style == SeparatorStyle.DeepSeek:
81 | seps = [self.sep, self.sep2]
82 | if system_prompt == "" or system_prompt is None:
83 | ret = ""
84 | else:
85 | ret = system_prompt + seps[0]
86 | for i, (role, message) in enumerate(self.messages):
87 | if message:
88 | ret += role + ": " + message + seps[i % 2]
89 | else:
90 | ret += role + ":"
91 | return ret
92 | elif self.sep_style == SeparatorStyle.LLAMA2:
93 | seps = [self.sep, self.sep2]
94 | if self.system_message:
95 | ret = system_prompt
96 | else:
97 | ret = "[INST] "
98 | for i, (role, message) in enumerate(self.messages):
99 | tag = self.roles[i % 2]
100 | if message:
101 | if type(message) is tuple: # multimodal message
102 | message, _ = message
103 | if i == 0:
104 | ret += message + " "
105 | else:
106 | ret += tag + " " + message + seps[i % 2]
107 | else:
108 | ret += tag
109 | return ret
110 | elif self.sep_style == SeparatorStyle.PLAIN:
111 | seps = [self.sep, self.sep2]
112 | ret = ""
113 | for i, (role, message) in enumerate(self.messages):
114 | if message:
115 | if type(message) is tuple:
116 | message, _, _ = message
117 | if i % 2 == 0:
118 | ret += message + seps[i % 2]
119 | else:
120 | ret += message + seps[i % 2]
121 | else:
122 | ret += ""
123 | return ret
124 | elif self.sep_style == SeparatorStyle.ALIGNMENT:
125 | seps = [self.sep, self.sep2]
126 | ret = ""
127 | for i, (role, message) in enumerate(self.messages):
128 | if message:
129 | if type(message) is tuple:
130 | message, _, _ = message
131 | if i % 2 == 0:
132 | ret += "\n" + seps[i % 2]
133 | else:
134 | ret += message + seps[i % 2]
135 | else:
136 | ret += ""
137 | return ret
138 | else:
139 | raise ValueError(f"Invalid style: {self.sep_style}")
140 |
141 | def get_prompt_for_current_round(self, content=None):
142 | """Get current round formatted question prompt during sft training"""
143 | if self.sep_style == SeparatorStyle.PLAIN:
144 | formatted_question = "\n"
145 | elif self.sep_style == SeparatorStyle.DeepSeek:
146 | formatted_question = (
147 | f"{self.roles[0]}: " + content.strip() + self.sep + f"{self.roles[1]}:"
148 | )
149 | else:
150 | raise ValueError(f"Unsupported sep_style: {self.sep_style}")
151 | return formatted_question
152 |
153 | def set_system_message(self, system_message: str):
154 | """Set the system message."""
155 | self.system_message = system_message
156 |
157 | def append_message(self, role: str, message: str):
158 | """Append a new message."""
159 | self.messages.append([role, message])
160 |
161 | def reset_message(self):
162 | """Reset a new message."""
163 | self.messages = []
164 |
165 | def update_last_message(self, message: str):
166 | """Update the last output.
167 |
168 | The last message is typically set to be None when constructing the prompt,
169 | so we need to update it in-place after getting the response from a model.
170 | """
171 | self.messages[-1][1] = message
172 |
173 | def to_gradio_chatbot(self):
174 | """Convert the conversation to gradio chatbot format."""
175 | ret = []
176 | for i, (role, msg) in enumerate(self.messages[self.offset :]):
177 | if i % 2 == 0:
178 | ret.append([msg, None])
179 | else:
180 | ret[-1][-1] = msg
181 | return ret
182 |
183 | def to_openai_api_messages(self):
184 | """Convert the conversation to OpenAI chat completion format."""
185 | system_prompt = self.system_template.format(system_message=self.system_message)
186 | ret = [{"role": "system", "content": system_prompt}]
187 |
188 | for i, (_, msg) in enumerate(self.messages[self.offset :]):
189 | if i % 2 == 0:
190 | ret.append({"role": "user", "content": msg})
191 | else:
192 | if msg is not None:
193 | ret.append({"role": "assistant", "content": msg})
194 | return ret
195 |
196 | def copy(self):
197 | return Conversation(
198 | name=self.name,
199 | system_template=self.system_template,
200 | system_message=self.system_message,
201 | roles=self.roles,
202 | messages=[[x, y] for x, y in self.messages],
203 | offset=self.offset,
204 | sep_style=self.sep_style,
205 | sep=self.sep,
206 | sep2=self.sep2,
207 | stop_str=self.stop_str,
208 | stop_token_ids=self.stop_token_ids,
209 | )
210 |
211 | def dict(self):
212 | return {
213 | "template_name": self.name,
214 | "system_message": self.system_message,
215 | "roles": self.roles,
216 | "messages": self.messages,
217 | "offset": self.offset,
218 | }
219 |
220 |
221 | # A global registry for all conversation templates
222 | conv_templates: Dict[str, Conversation] = {}
223 |
224 |
225 | def register_conv_template(template: Conversation, override: bool = False):
226 | """Register a new conversation template."""
227 | if not override:
228 | assert (
229 | template.name not in conv_templates
230 | ), f"{template.name} has been registered."
231 |
232 | conv_templates[template.name] = template
233 |
234 |
235 | def get_conv_template(name: str) -> Conversation:
236 | """Get a conversation template."""
237 | return conv_templates[name].copy()
238 |
239 |
240 | # llava_llama2 template
241 | register_conv_template(
242 | Conversation(
243 | name="llava_llama2",
244 | system_message="You are a helpful language and vision assistant. "
245 | "You are able to understand the visual content that the user provides, "
246 | "and assist the user with a variety of tasks using natural language.",
247 |         system_template="[INST] <<SYS>>\n{system_message}\n<</SYS>>\n\n",
248 | roles=("[INST]", "[/INST]"),
249 | messages=(),
250 | offset=0,
251 | sep_style=SeparatorStyle.LLAMA2,
252 | sep=" ",
253 |         sep2=" </s><s>",
254 | stop_token_ids=[2],
255 | )
256 | )
257 |
258 | # llama2 template
259 | # reference: https://github.com/facebookresearch/llama/blob/cfc3fc8c1968d390eb830e65c63865e980873a06/llama/generation.py#L212
260 | register_conv_template(
261 | Conversation(
262 | name="llama-2",
263 |         system_template="[INST] <<SYS>>\n{system_message}\n<</SYS>>\n\n",
264 | roles=("[INST]", "[/INST]"),
265 | messages=(),
266 | offset=0,
267 | sep_style=SeparatorStyle.LLAMA2,
268 | sep=" ",
269 |         sep2=" </s><s>",
270 | stop_token_ids=[2],
271 | )
272 | )
273 |
274 |
275 | # deepseek template
276 | register_conv_template(
277 | Conversation(
278 | name="deepseek_old",
279 | system_template="{system_message}",
280 | # system_message="You are a helpful assistant. Please answer truthfully and write out your "
281 | # "thinking step by step to be sure you get the right answer.",
282 | system_message="",
283 | roles=("User", "Assistant"),
284 | messages=(),
285 | offset=0,
286 | sep_style=SeparatorStyle.DeepSeek,
287 | sep="\n\n",
288 | sep2="<|end▁of▁sentence|>",
289 | stop_token_ids=[100001],
290 | stop_str=["User:", "<|end▁of▁sentence|>"],
291 | )
292 | )
293 | register_conv_template(
294 | Conversation(
295 | name="deepseek",
296 | system_template="{system_message}",
297 | # system_message="You are a helpful assistant. Please answer truthfully and write out your "
298 | # "thinking step by step to be sure you get the right answer.",
299 | system_message="",
300 | roles=("<|User|>", "<|Assistant|>"),
301 | messages=(),
302 | offset=0,
303 | sep_style=SeparatorStyle.DeepSeek,
304 | sep="\n\n",
305 | sep2="<|end▁of▁sentence|>",
306 | stop_token_ids=[100001],
307 | stop_str=["<|User|>", "<|end▁of▁sentence|>"]
308 | )
309 | )
310 |
311 | register_conv_template(
312 | Conversation(
313 | name="plain",
314 | system_template="",
315 | system_message="",
316 | roles=("", ""),
317 | messages=(),
318 | offset=0,
319 | sep_style=SeparatorStyle.PLAIN,
320 | sep="",
321 | sep2="",
322 | stop_token_ids=[2],
323 |         stop_str=["</s>"],
324 | )
325 | )
326 |
327 |
328 | register_conv_template(
329 | Conversation(
330 | name="alignment",
331 | system_template="",
332 | system_message="",
333 | roles=("", ""),
334 | messages=(),
335 | offset=0,
336 | sep_style=SeparatorStyle.ALIGNMENT,
337 | sep="",
338 | sep2="",
339 | stop_token_ids=[2],
340 |         stop_str=["</s>"],
341 | )
342 | )
343 |
344 |
345 | if __name__ == "__main__":
346 | # print("Llama-2 template:")
347 | # conv = get_conv_template("llama-2")
348 | # conv.set_system_message("You are a helpful, respectful and honest assistant.")
349 | # conv.append_message(conv.roles[0], "Hello!")
350 | # conv.append_message(conv.roles[1], "Hi!")
351 | # conv.append_message(conv.roles[0], "How are you?")
352 | # conv.append_message(conv.roles[1], None)
353 | # print(conv.get_prompt())
354 |
355 | # print("\n")
356 |
357 | print("deepseek template:")
358 | conv = get_conv_template("deepseek")
359 | conv.append_message(conv.roles[0], "Hello!")
360 | conv.append_message(conv.roles[1], "Hi! This is Tony.")
361 | conv.append_message(conv.roles[0], "Who are you?")
362 | conv.append_message(conv.roles[1], "I am a helpful assistant.")
363 | conv.append_message(conv.roles[0], "How are you?")
364 | conv.append_message(conv.roles[1], None)
365 | print(conv.get_prompt())
366 |
--------------------------------------------------------------------------------
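
A minimal usage sketch for the templates registered above, mirroring the __main__ demo. The rendered string shown in the comments follows from the SeparatorStyle.DeepSeek branch of get_prompt() and is illustrative only:

    from janus.utils.conversation import get_conv_template

    conv = get_conv_template("deepseek")
    conv.append_message(conv.roles[0], "Describe the image.")  # "<|User|>"
    conv.append_message(conv.roles[1], None)                   # "<|Assistant|>", left open
    prompt = conv.get_prompt()
    # Completed turns are closed with sep ("\n\n"); the trailing empty slot ends
    # with "<|Assistant|>:" so the model continues from there. Roughly:
    #   "<|User|>: Describe the image.\n\n<|Assistant|>:"
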
/janus/models/processing_vlm.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2023-2024 DeepSeek.
2 | #
3 | # Permission is hereby granted, free of charge, to any person obtaining a copy of
4 | # this software and associated documentation files (the "Software"), to deal in
5 | # the Software without restriction, including without limitation the rights to
6 | # use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of
7 | # the Software, and to permit persons to whom the Software is furnished to do so,
8 | # subject to the following conditions:
9 | #
10 | # The above copyright notice and this permission notice shall be included in all
11 | # copies or substantial portions of the Software.
12 | #
13 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
14 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS
15 | # FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR
16 | # COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER
17 | # IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN
18 | # CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
19 |
20 | from dataclasses import dataclass
21 | from typing import Dict, List
22 |
23 | import torch
24 | from PIL.Image import Image
25 | from transformers import LlamaTokenizerFast
26 | from transformers.processing_utils import ProcessorMixin
27 |
28 | from janus.models.image_processing_vlm import VLMImageProcessor
29 | from janus.utils.conversation import get_conv_template
30 |
31 |
32 | class DictOutput(object):
33 | def keys(self):
34 | return self.__dict__.keys()
35 |
36 | def __getitem__(self, item):
37 | return self.__dict__[item]
38 |
39 | def __setitem__(self, key, value):
40 | self.__dict__[key] = value
41 |
42 |
43 | @dataclass
44 | class VLChatProcessorOutput(DictOutput):
45 | sft_format: str
46 | input_ids: torch.Tensor
47 | pixel_values: torch.Tensor
48 | num_image_tokens: torch.IntTensor
49 |
50 | def __len__(self):
51 | return len(self.input_ids)
52 |
53 |
54 | @dataclass
55 | class BatchedVLChatProcessorOutput(DictOutput):
56 | sft_format: List[str]
57 | input_ids: torch.Tensor
58 | pixel_values: torch.Tensor
59 | attention_mask: torch.Tensor
60 | images_seq_mask: torch.BoolTensor
61 | images_emb_mask: torch.BoolTensor
62 |
63 | def to(self, device, dtype=torch.bfloat16):
64 | self.input_ids = self.input_ids.to(device)
65 | self.attention_mask = self.attention_mask.to(device)
66 | self.images_seq_mask = self.images_seq_mask.to(device)
67 | self.images_emb_mask = self.images_emb_mask.to(device)
68 | self.pixel_values = self.pixel_values.to(device=device, dtype=dtype)
69 | return self
70 |
71 |
72 | class VLChatProcessor(ProcessorMixin):
73 | image_processor_class = "AutoImageProcessor"
74 | tokenizer_class = ("LlamaTokenizer", "LlamaTokenizerFast")
75 |
76 | attributes = ["image_processor", "tokenizer"]
77 |
78 | system_prompt = (
79 | "You are a helpful language and vision assistant. "
80 | "You are able to understand the visual content that the user provides, "
81 | "and assist the user with a variety of tasks using natural language."
82 | )
83 |
84 | def __init__(
85 | self,
86 | image_processor: VLMImageProcessor,
87 | tokenizer: LlamaTokenizerFast,
88 |         image_tag: str = "<image_placeholder>",
89 |         image_start_tag: str = "<begin_of_image>",
90 |         image_end_tag: str = "<end_of_image>",
91 | pad_tag: str = "<|▁pad▁|>",
92 | num_image_tokens: int = 576,
93 | add_special_token: bool = False,
94 | sft_format: str = "deepseek",
95 | mask_prompt: bool = True,
96 | ignore_id: int = -100,
97 | **kwargs,
98 | ):
99 | self.image_processor = image_processor
100 | self.tokenizer = tokenizer
101 |
102 | image_id = self.tokenizer.vocab.get(image_tag)
103 | if image_id is None:
104 | special_tokens = [image_tag]
105 | special_tokens_dict = {"additional_special_tokens": special_tokens}
106 | self.tokenizer.add_special_tokens(special_tokens_dict)
107 | print(f"Add image tag = {image_tag} to the tokenizer")
108 |
109 | self.image_tag = image_tag
110 | self.image_start_tag = image_start_tag
111 | self.image_end_tag = image_end_tag
112 | self.pad_tag = pad_tag
113 |
114 | self.num_image_tokens = num_image_tokens
115 | self.add_special_token = add_special_token
116 | self.sft_format = sft_format
117 | self.mask_prompt = mask_prompt
118 | self.ignore_id = ignore_id
119 |
120 | super().__init__(
121 | image_processor,
122 | tokenizer,
123 | image_tag,
124 | num_image_tokens,
125 | add_special_token,
126 | sft_format,
127 | mask_prompt,
128 | ignore_id,
129 | **kwargs,
130 | )
131 |
132 | def new_chat_template(self):
133 | conv = get_conv_template(self.sft_format)
134 | conv.set_system_message(self.system_prompt)
135 | return conv
136 |
137 | def apply_sft_template_for_multi_turn_prompts(
138 | self,
139 | conversations: List[Dict[str, str]],
140 | sft_format: str = "deepseek",
141 | system_prompt: str = "",
142 | ):
143 | """
144 | Applies the SFT template to conversation.
145 |
146 | An example of conversation:
147 | conversation = [
148 | {
149 | "role": "User",
150 |                 "content": "<image_placeholder> is Figure 1.\n<image_placeholder> is Figure 2.\nWhich image is brighter?",
151 | "images": [
152 | "./multi-images/attribute_comparison_1.png",
153 | "./multi-images/attribute_comparison_2.png"
154 | ]
155 | },
156 | {
157 | "role": "Assistant",
158 | "content": ""
159 | }
160 | ]
161 |
162 | Args:
163 |             conversations (List[Dict]): A conversation given as a list of role/content message dicts.
164 | sft_format (str, optional): The format of the SFT template to use. Defaults to "deepseek".
165 | system_prompt (str, optional): The system prompt to use in the SFT template. Defaults to "".
166 |
167 | Returns:
168 | sft_prompt (str): The formatted text.
169 | """
170 |
171 | conv = get_conv_template(sft_format)
172 | conv.set_system_message(system_prompt)
173 | for message in conversations:
174 | conv.append_message(message["role"], message["content"].strip())
175 | sft_prompt = conv.get_prompt().strip()
176 |
177 | return sft_prompt
178 |
179 | @property
180 | def image_token(self):
181 | return self.image_tag
182 |
183 | @property
184 | def image_id(self):
185 | image_id = self.tokenizer.vocab.get(self.image_tag)
186 | return image_id
187 |
188 | @property
189 | def image_start_id(self):
190 | image_start_id = self.tokenizer.vocab.get(self.image_start_tag)
191 | return image_start_id
192 |
193 | @property
194 | def image_end_id(self):
195 | image_end_id = self.tokenizer.vocab.get(self.image_end_tag)
196 | return image_end_id
197 |
198 | @property
199 | def image_start_token(self):
200 | return self.image_start_tag
201 |
202 | @property
203 | def image_end_token(self):
204 | return self.image_end_tag
205 |
206 | @property
207 | def pad_id(self):
208 | pad_id = self.tokenizer.vocab.get(self.pad_tag)
209 | # pad_id = self.tokenizer.pad_token_id
210 | # if pad_id is None:
211 | # pad_id = self.tokenizer.eos_token_id
212 |
213 | return pad_id
214 |
215 | def add_image_token(
216 | self,
217 | image_indices: List[int],
218 | input_ids: torch.LongTensor,
219 | ):
220 | """
221 |
222 | Args:
223 | image_indices (List[int]): [index_0, index_1, ..., index_j]
224 | input_ids (torch.LongTensor): [N]
225 |
226 | Returns:
227 | input_ids (torch.LongTensor): [N + image tokens]
228 | num_image_tokens (torch.IntTensor): [n_images]
229 | """
230 |
231 | input_slices = []
232 |
233 | start = 0
234 | for index in image_indices:
235 | if self.add_special_token:
236 | end = index + 1
237 | else:
238 | end = index
239 |
240 | # original text tokens
241 | input_slices.append(input_ids[start:end])
242 |
243 | # add boi, image tokens, eoi and set the mask as False
244 | input_slices.append(self.image_start_id * torch.ones((1), dtype=torch.long))
245 | input_slices.append(
246 | self.image_id * torch.ones((self.num_image_tokens,), dtype=torch.long)
247 | )
248 | input_slices.append(self.image_end_id * torch.ones((1), dtype=torch.long))
249 | start = index + 1
250 |
251 | # the left part
252 | input_slices.append(input_ids[start:])
253 |
254 | # concat all slices
255 | input_ids = torch.cat(input_slices, dim=0)
256 | num_image_tokens = torch.IntTensor([self.num_image_tokens] * len(image_indices))
257 |
258 | return input_ids, num_image_tokens
259 |
260 | def process_one(
261 | self,
262 | prompt: str = None,
263 | conversations: List[Dict[str, str]] = None,
264 | images: List[Image] = None,
265 | **kwargs,
266 | ):
267 | """
268 |
269 | Args:
270 | prompt (str): the formatted prompt;
271 | conversations (List[Dict]): conversations with a list of messages;
272 | images (List[ImageType]): the list of images;
273 | **kwargs:
274 |
275 | Returns:
276 | outputs (BaseProcessorOutput): the output of the processor,
277 | - input_ids (torch.LongTensor): [N + image tokens]
278 | - target_ids (torch.LongTensor): [N + image tokens]
279 | - images (torch.FloatTensor): [n_images, 3, H, W]
280 | - image_id (int): the id of the image token
281 | - num_image_tokens (List[int]): the number of image tokens
282 | """
283 |
284 | assert (
285 | prompt is None or conversations is None
286 | ), "prompt and conversations cannot be used at the same time."
287 |
288 | if prompt is None:
289 | # apply sft format
290 | sft_format = self.apply_sft_template_for_multi_turn_prompts(
291 | conversations=conversations,
292 | sft_format=self.sft_format,
293 | system_prompt=self.system_prompt,
294 | )
295 | else:
296 | sft_format = prompt
297 |
298 | # tokenize
299 | input_ids = self.tokenizer.encode(sft_format)
300 | input_ids = torch.LongTensor(input_ids)
301 |
302 | # add image tokens to the input_ids
303 | image_token_mask: torch.BoolTensor = input_ids == self.image_id
304 | image_indices = image_token_mask.nonzero()
305 | input_ids, num_image_tokens = self.add_image_token(
306 | image_indices=image_indices,
307 | input_ids=input_ids,
308 | )
309 |
310 | # load images
311 | images_outputs = self.image_processor(images, return_tensors="pt")
312 |
313 | prepare = VLChatProcessorOutput(
314 | sft_format=sft_format,
315 | input_ids=input_ids,
316 | pixel_values=images_outputs.pixel_values,
317 | num_image_tokens=num_image_tokens,
318 | )
319 |
320 | return prepare
321 |
322 | def __call__(
323 | self,
324 | *,
325 | prompt: str = None,
326 | conversations: List[Dict[str, str]] = None,
327 | images: List[Image] = None,
328 | force_batchify: bool = True,
329 | **kwargs,
330 | ):
331 | """
332 |
333 | Args:
334 | prompt (str): the formatted prompt;
335 | conversations (List[Dict]): conversations with a list of messages;
336 | images (List[ImageType]): the list of images;
337 | force_batchify (bool): force batchify the inputs;
338 | **kwargs:
339 |
340 | Returns:
341 | outputs (BaseProcessorOutput): the output of the processor,
342 | - input_ids (torch.LongTensor): [N + image tokens]
343 | - images (torch.FloatTensor): [n_images, 3, H, W]
344 | - image_id (int): the id of the image token
345 | - num_image_tokens (List[int]): the number of image tokens
346 | """
347 |
348 | prepare = self.process_one(
349 | prompt=prompt, conversations=conversations, images=images
350 | )
351 |
352 | if force_batchify:
353 | prepare = self.batchify([prepare])
354 |
355 | return prepare
356 |
357 | def batchify(
358 | self, prepare_list: List[VLChatProcessorOutput]
359 | ) -> BatchedVLChatProcessorOutput:
360 | """
361 | Preprocesses the inputs for multimodal inference.
362 |
363 | Args:
364 | prepare_list (List[VLChatProcessorOutput]): A list of VLChatProcessorOutput.
365 |
366 | Returns:
367 | BatchedVLChatProcessorOutput: A dictionary of the inputs to use for multimodal inference.
368 | """
369 |
370 | batch_size = len(prepare_list)
371 | sft_format = []
372 | n_images = []
373 | seq_lens = []
374 | for prepare in prepare_list:
375 | n_images.append(len(prepare.num_image_tokens))
376 | seq_lens.append(len(prepare))
377 |
378 | input_token_max_len = max(seq_lens)
379 | max_n_images = max(1, max(n_images))
380 |
381 | batched_input_ids = torch.full(
382 | (batch_size, input_token_max_len), self.pad_id
383 | ).long() # FIXME
384 | batched_attention_mask = torch.zeros((batch_size, input_token_max_len)).long()
385 | batched_pixel_values = torch.zeros(
386 | (batch_size, max_n_images, *self.image_processor.default_shape)
387 | ).float()
388 | batched_images_seq_mask = torch.zeros((batch_size, input_token_max_len)).bool()
389 | batched_images_emb_mask = torch.zeros(
390 | (batch_size, max_n_images, self.num_image_tokens)
391 | ).bool()
392 |
393 | for i, prepare in enumerate(prepare_list):
394 | input_ids = prepare.input_ids
395 | seq_len = len(prepare)
396 | n_image = len(prepare.num_image_tokens)
397 | # left-padding
398 | batched_attention_mask[i, -seq_len:] = 1
399 | batched_input_ids[i, -seq_len:] = torch.LongTensor(input_ids)
400 | batched_images_seq_mask[i, -seq_len:] = input_ids == self.image_id
401 |
402 | if n_image > 0:
403 | batched_pixel_values[i, :n_image] = prepare.pixel_values
404 | for j, n_image_tokens in enumerate(prepare.num_image_tokens):
405 | batched_images_emb_mask[i, j, :n_image_tokens] = True
406 |
407 | sft_format.append(prepare.sft_format)
408 |
409 | batched_prepares = BatchedVLChatProcessorOutput(
410 | input_ids=batched_input_ids,
411 | attention_mask=batched_attention_mask,
412 | pixel_values=batched_pixel_values,
413 | images_seq_mask=batched_images_seq_mask,
414 | images_emb_mask=batched_images_emb_mask,
415 | sft_format=sft_format,
416 | )
417 |
418 | return batched_prepares
419 |
--------------------------------------------------------------------------------
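
A usage sketch for VLChatProcessor, assuming a checkpoint that ships the processor and image-processor configs; the checkpoint id and image path below are placeholders. process_one() expands each image placeholder into <begin_of_image>, num_image_tokens image ids, and <end_of_image>, and batchify() left-pads input_ids/attention_mask while building images_seq_mask and images_emb_mask:

    import PIL.Image

    from janus.models.processing_vlm import VLChatProcessor

    # Placeholder checkpoint id and image path, for illustration only.
    processor = VLChatProcessor.from_pretrained("deepseek-ai/Janus-Pro-7B")

    conversation = [
        {
            "role": "<|User|>",
            "content": f"{processor.image_token}\nWhat is in this photo?",
            "images": ["photo.jpg"],
        },
        {"role": "<|Assistant|>", "content": ""},
    ]
    images = [PIL.Image.open(p).convert("RGB") for p in conversation[0]["images"]]

    # Returns a BatchedVLChatProcessorOutput with left-padded input_ids,
    # attention_mask, pixel_values, images_seq_mask and images_emb_mask.
    inputs = processor(conversations=conversation, images=images, force_batchify=True)
    print(inputs.input_ids.shape, inputs.pixel_values.shape)
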
/janus/models/vq_model.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2023-2024 DeepSeek.
2 | #
3 | # Permission is hereby granted, free of charge, to any person obtaining a copy of
4 | # this software and associated documentation files (the "Software"), to deal in
5 | # the Software without restriction, including without limitation the rights to
6 | # use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of
7 | # the Software, and to permit persons to whom the Software is furnished to do so,
8 | # subject to the following conditions:
9 | #
10 | # The above copyright notice and this permission notice shall be included in all
11 | # copies or substantial portions of the Software.
12 | #
13 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
14 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS
15 | # FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR
16 | # COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER
17 | # IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN
18 | # CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
19 |
20 |
21 | from dataclasses import dataclass, field
22 | from typing import List
23 |
24 | import torch
25 | import torch.nn as nn
26 | import torch.nn.functional as F
27 |
28 | from functools import partial
29 |
30 |
31 | @dataclass
32 | class ModelArgs:
33 | codebook_size: int = 16384
34 | codebook_embed_dim: int = 8
35 | codebook_l2_norm: bool = True
36 | codebook_show_usage: bool = True
37 | commit_loss_beta: float = 0.25
38 | entropy_loss_ratio: float = 0.0
39 |
40 | encoder_ch_mult: List[int] = field(default_factory=lambda: [1, 1, 2, 2, 4])
41 | decoder_ch_mult: List[int] = field(default_factory=lambda: [1, 1, 2, 2, 4])
42 | z_channels: int = 256
43 | dropout_p: float = 0.0
44 |
45 |
46 | class Encoder(nn.Module):
47 | def __init__(
48 | self,
49 | in_channels=3,
50 | ch=128,
51 | ch_mult=(1, 1, 2, 2, 4),
52 | num_res_blocks=2,
53 | norm_type="group",
54 | dropout=0.0,
55 | resamp_with_conv=True,
56 | z_channels=256,
57 | ):
58 | super().__init__()
59 | self.num_resolutions = len(ch_mult)
60 | self.num_res_blocks = num_res_blocks
61 | self.conv_in = nn.Conv2d(in_channels, ch, kernel_size=3, stride=1, padding=1)
62 |
63 | # downsampling
64 | in_ch_mult = (1,) + tuple(ch_mult)
65 | self.conv_blocks = nn.ModuleList()
66 | for i_level in range(self.num_resolutions):
67 | conv_block = nn.Module()
68 | # res & attn
69 | res_block = nn.ModuleList()
70 | attn_block = nn.ModuleList()
71 | block_in = ch * in_ch_mult[i_level]
72 | block_out = ch * ch_mult[i_level]
73 | for _ in range(self.num_res_blocks):
74 | res_block.append(
75 | ResnetBlock(
76 | block_in, block_out, dropout=dropout, norm_type=norm_type
77 | )
78 | )
79 | block_in = block_out
80 | if i_level == self.num_resolutions - 1:
81 | attn_block.append(AttnBlock(block_in, norm_type))
82 | conv_block.res = res_block
83 | conv_block.attn = attn_block
84 | # downsample
85 | if i_level != self.num_resolutions - 1:
86 | conv_block.downsample = Downsample(block_in, resamp_with_conv)
87 | self.conv_blocks.append(conv_block)
88 |
89 | # middle
90 | self.mid = nn.ModuleList()
91 | self.mid.append(
92 | ResnetBlock(block_in, block_in, dropout=dropout, norm_type=norm_type)
93 | )
94 | self.mid.append(AttnBlock(block_in, norm_type=norm_type))
95 | self.mid.append(
96 | ResnetBlock(block_in, block_in, dropout=dropout, norm_type=norm_type)
97 | )
98 |
99 | # end
100 | self.norm_out = Normalize(block_in, norm_type)
101 | self.conv_out = nn.Conv2d(
102 | block_in, z_channels, kernel_size=3, stride=1, padding=1
103 | )
104 |
105 | def forward(self, x):
106 | h = self.conv_in(x)
107 | # downsampling
108 | for i_level, block in enumerate(self.conv_blocks):
109 | for i_block in range(self.num_res_blocks):
110 | h = block.res[i_block](h)
111 | if len(block.attn) > 0:
112 | h = block.attn[i_block](h)
113 | if i_level != self.num_resolutions - 1:
114 | h = block.downsample(h)
115 |
116 | # middle
117 | for mid_block in self.mid:
118 | h = mid_block(h)
119 |
120 | # end
121 | h = self.norm_out(h)
122 | h = nonlinearity(h)
123 | h = self.conv_out(h)
124 | return h
125 |
126 |
127 | class Decoder(nn.Module):
128 | def __init__(
129 | self,
130 | z_channels=256,
131 | ch=128,
132 | ch_mult=(1, 1, 2, 2, 4),
133 | num_res_blocks=2,
134 | norm_type="group",
135 | dropout=0.0,
136 | resamp_with_conv=True,
137 | out_channels=3,
138 | ):
139 | super().__init__()
140 | self.num_resolutions = len(ch_mult)
141 | self.num_res_blocks = num_res_blocks
142 |
143 | block_in = ch * ch_mult[self.num_resolutions - 1]
144 | # z to block_in
145 | self.conv_in = nn.Conv2d(
146 | z_channels, block_in, kernel_size=3, stride=1, padding=1
147 | )
148 |
149 | # middle
150 | self.mid = nn.ModuleList()
151 | self.mid.append(
152 | ResnetBlock(block_in, block_in, dropout=dropout, norm_type=norm_type)
153 | )
154 | self.mid.append(AttnBlock(block_in, norm_type=norm_type))
155 | self.mid.append(
156 | ResnetBlock(block_in, block_in, dropout=dropout, norm_type=norm_type)
157 | )
158 |
159 | # upsampling
160 | self.conv_blocks = nn.ModuleList()
161 | for i_level in reversed(range(self.num_resolutions)):
162 | conv_block = nn.Module()
163 | # res & attn
164 | res_block = nn.ModuleList()
165 | attn_block = nn.ModuleList()
166 | block_out = ch * ch_mult[i_level]
167 | for _ in range(self.num_res_blocks + 1):
168 | res_block.append(
169 | ResnetBlock(
170 | block_in, block_out, dropout=dropout, norm_type=norm_type
171 | )
172 | )
173 | block_in = block_out
174 | if i_level == self.num_resolutions - 1:
175 | attn_block.append(AttnBlock(block_in, norm_type))
176 | conv_block.res = res_block
177 | conv_block.attn = attn_block
178 | # downsample
179 | if i_level != 0:
180 | conv_block.upsample = Upsample(block_in, resamp_with_conv)
181 | self.conv_blocks.append(conv_block)
182 |
183 | # end
184 | self.norm_out = Normalize(block_in, norm_type)
185 | self.conv_out = nn.Conv2d(
186 | block_in, out_channels, kernel_size=3, stride=1, padding=1
187 | )
188 |
189 | @property
190 | def last_layer(self):
191 | return self.conv_out.weight
192 |
193 | def forward(self, z):
194 | # z to block_in
195 | h = self.conv_in(z)
196 |
197 | # middle
198 | for mid_block in self.mid:
199 | h = mid_block(h)
200 |
201 | # upsampling
202 | for i_level, block in enumerate(self.conv_blocks):
203 | for i_block in range(self.num_res_blocks + 1):
204 | h = block.res[i_block](h)
205 | if len(block.attn) > 0:
206 | h = block.attn[i_block](h)
207 | if i_level != self.num_resolutions - 1:
208 | h = block.upsample(h)
209 |
210 | # end
211 | h = self.norm_out(h)
212 | h = nonlinearity(h)
213 | h = self.conv_out(h)
214 | return h
215 |
216 |
217 | class VectorQuantizer(nn.Module):
218 | def __init__(self, n_e, e_dim, beta, entropy_loss_ratio, l2_norm, show_usage):
219 | super().__init__()
220 | self.n_e = n_e
221 | self.e_dim = e_dim
222 | self.beta = beta
223 | self.entropy_loss_ratio = entropy_loss_ratio
224 | self.l2_norm = l2_norm
225 | self.show_usage = show_usage
226 |
227 | self.embedding = nn.Embedding(self.n_e, self.e_dim)
228 | self.embedding.weight.data.uniform_(-1.0 / self.n_e, 1.0 / self.n_e)
229 | if self.l2_norm:
230 | self.embedding.weight.data = F.normalize(
231 | self.embedding.weight.data, p=2, dim=-1
232 | )
233 | if self.show_usage:
234 | self.register_buffer("codebook_used", nn.Parameter(torch.zeros(65536)))
235 |
236 | def forward(self, z):
237 | # reshape z -> (batch, height, width, channel) and flatten
238 | z = torch.einsum("b c h w -> b h w c", z).contiguous()
239 | z_flattened = z.view(-1, self.e_dim)
240 | # distances from z to embeddings e_j (z - e)^2 = z^2 + e^2 - 2 e * z
241 |
242 | if self.l2_norm:
243 | z = F.normalize(z, p=2, dim=-1)
244 | z_flattened = F.normalize(z_flattened, p=2, dim=-1)
245 | embedding = F.normalize(self.embedding.weight, p=2, dim=-1)
246 | else:
247 | embedding = self.embedding.weight
248 |
249 | d = (
250 | torch.sum(z_flattened**2, dim=1, keepdim=True)
251 | + torch.sum(embedding**2, dim=1)
252 | - 2
253 | * torch.einsum(
254 | "bd,dn->bn", z_flattened, torch.einsum("n d -> d n", embedding)
255 | )
256 | )
257 |
258 | min_encoding_indices = torch.argmin(d, dim=1)
259 | z_q = embedding[min_encoding_indices].view(z.shape)
260 | perplexity = None
261 | min_encodings = None
262 | vq_loss = None
263 | commit_loss = None
264 | entropy_loss = None
265 |
266 | # compute loss for embedding
267 | if self.training:
268 | vq_loss = torch.mean((z_q - z.detach()) ** 2)
269 | commit_loss = self.beta * torch.mean((z_q.detach() - z) ** 2)
270 | entropy_loss = self.entropy_loss_ratio * compute_entropy_loss(-d)
271 |
272 | # preserve gradients
273 | z_q = z + (z_q - z).detach()
274 |
275 | # reshape back to match original input shape
276 | z_q = torch.einsum("b h w c -> b c h w", z_q)
277 |
278 | return (
279 | z_q,
280 | (vq_loss, commit_loss, entropy_loss),
281 | (perplexity, min_encodings, min_encoding_indices),
282 | )
283 |
284 | def get_codebook_entry(self, indices, shape=None, channel_first=True):
285 | # shape = (batch, channel, height, width) if channel_first else (batch, height, width, channel)
286 | if self.l2_norm:
287 | embedding = F.normalize(self.embedding.weight, p=2, dim=-1)
288 | else:
289 | embedding = self.embedding.weight
290 | z_q = embedding[indices] # (b*h*w, c)
291 |
292 | if shape is not None:
293 | if channel_first:
294 | z_q = z_q.reshape(shape[0], shape[2], shape[3], shape[1])
295 | # reshape back to match original input shape
296 | z_q = z_q.permute(0, 3, 1, 2).contiguous()
297 | else:
298 | z_q = z_q.view(shape)
299 | return z_q
300 |
301 |
302 | class ResnetBlock(nn.Module):
303 | def __init__(
304 | self,
305 | in_channels,
306 | out_channels=None,
307 | conv_shortcut=False,
308 | dropout=0.0,
309 | norm_type="group",
310 | ):
311 | super().__init__()
312 | self.in_channels = in_channels
313 | out_channels = in_channels if out_channels is None else out_channels
314 | self.out_channels = out_channels
315 | self.use_conv_shortcut = conv_shortcut
316 |
317 | self.norm1 = Normalize(in_channels, norm_type)
318 | self.conv1 = nn.Conv2d(
319 | in_channels, out_channels, kernel_size=3, stride=1, padding=1
320 | )
321 | self.norm2 = Normalize(out_channels, norm_type)
322 | self.dropout = nn.Dropout(dropout)
323 | self.conv2 = nn.Conv2d(
324 | out_channels, out_channels, kernel_size=3, stride=1, padding=1
325 | )
326 |
327 | if self.in_channels != self.out_channels:
328 | if self.use_conv_shortcut:
329 | self.conv_shortcut = nn.Conv2d(
330 | in_channels, out_channels, kernel_size=3, stride=1, padding=1
331 | )
332 | else:
333 | self.nin_shortcut = nn.Conv2d(
334 | in_channels, out_channels, kernel_size=1, stride=1, padding=0
335 | )
336 |
337 | def forward(self, x):
338 | h = x
339 | h = self.norm1(h)
340 | h = nonlinearity(h)
341 | h = self.conv1(h)
342 | h = self.norm2(h)
343 | h = nonlinearity(h)
344 | h = self.dropout(h)
345 | h = self.conv2(h)
346 |
347 | if self.in_channels != self.out_channels:
348 | if self.use_conv_shortcut:
349 | x = self.conv_shortcut(x)
350 | else:
351 | x = self.nin_shortcut(x)
352 | return x + h
353 |
354 |
355 | class AttnBlock(nn.Module):
356 | def __init__(self, in_channels, norm_type="group"):
357 | super().__init__()
358 | self.norm = Normalize(in_channels, norm_type)
359 | self.q = nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
360 | self.k = nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
361 | self.v = nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
362 | self.proj_out = nn.Conv2d(
363 | in_channels, in_channels, kernel_size=1, stride=1, padding=0
364 | )
365 |
366 | def forward(self, x):
367 | h_ = x
368 | h_ = self.norm(h_)
369 | q = self.q(h_)
370 | k = self.k(h_)
371 | v = self.v(h_)
372 |
373 | # compute attention
374 | b, c, h, w = q.shape
375 | q = q.reshape(b, c, h * w)
376 | q = q.permute(0, 2, 1) # b,hw,c
377 | k = k.reshape(b, c, h * w) # b,c,hw
378 | w_ = torch.bmm(q, k) # b,hw,hw w[b,i,j]=sum_c q[b,i,c]k[b,c,j]
379 | w_ = w_ * (int(c) ** (-0.5))
380 | w_ = F.softmax(w_, dim=2)
381 |
382 | # attend to values
383 | v = v.reshape(b, c, h * w)
384 | w_ = w_.permute(0, 2, 1) # b,hw,hw (first hw of k, second of q)
385 | h_ = torch.bmm(v, w_) # b, c,hw (hw of q) h_[b,c,j] = sum_i v[b,c,i] w_[b,i,j]
386 | h_ = h_.reshape(b, c, h, w)
387 |
388 | h_ = self.proj_out(h_)
389 |
390 | return x + h_
391 |
392 |
393 | def nonlinearity(x):
394 | # swish
395 | return x * torch.sigmoid(x)
396 |
397 |
398 | def Normalize(in_channels, norm_type="group"):
399 | assert norm_type in ["group", "batch"]
400 | if norm_type == "group":
401 | return nn.GroupNorm(
402 | num_groups=32, num_channels=in_channels, eps=1e-6, affine=True
403 | )
404 | elif norm_type == "batch":
405 | return nn.SyncBatchNorm(in_channels)
406 |
407 |
408 | class Upsample(nn.Module):
409 | def __init__(self, in_channels, with_conv):
410 | super().__init__()
411 | self.with_conv = with_conv
412 | if self.with_conv:
413 | self.conv = nn.Conv2d(
414 | in_channels, in_channels, kernel_size=3, stride=1, padding=1
415 | )
416 |
417 | def forward(self, x):
418 | if x.dtype != torch.float32:
419 | x = F.interpolate(x.to(torch.float), scale_factor=2.0, mode="nearest").to(
420 | torch.bfloat16
421 | )
422 | else:
423 | x = F.interpolate(x, scale_factor=2.0, mode="nearest")
424 |
425 | if self.with_conv:
426 | x = self.conv(x)
427 | return x
428 |
429 |
430 | class Downsample(nn.Module):
431 | def __init__(self, in_channels, with_conv):
432 | super().__init__()
433 | self.with_conv = with_conv
434 | if self.with_conv:
435 | # no asymmetric padding in torch conv, must do it ourselves
436 | self.conv = nn.Conv2d(
437 | in_channels, in_channels, kernel_size=3, stride=2, padding=0
438 | )
439 |
440 | def forward(self, x):
441 | if self.with_conv:
442 | pad = (0, 1, 0, 1)
443 | x = F.pad(x, pad, mode="constant", value=0)
444 | x = self.conv(x)
445 | else:
446 | x = F.avg_pool2d(x, kernel_size=2, stride=2)
447 | return x
448 |
449 |
450 | def compute_entropy_loss(affinity, loss_type="softmax", temperature=0.01):
451 | flat_affinity = affinity.reshape(-1, affinity.shape[-1])
452 | flat_affinity /= temperature
453 | probs = F.softmax(flat_affinity, dim=-1)
454 | log_probs = F.log_softmax(flat_affinity + 1e-5, dim=-1)
455 | if loss_type == "softmax":
456 | target_probs = probs
457 | else:
458 | raise ValueError("Entropy loss {} not supported".format(loss_type))
459 | avg_probs = torch.mean(target_probs, dim=0)
460 | avg_entropy = -torch.sum(avg_probs * torch.log(avg_probs + 1e-5))
461 | sample_entropy = -torch.mean(torch.sum(target_probs * log_probs, dim=-1))
462 | loss = sample_entropy - avg_entropy
463 | return loss
464 |
465 |
466 | class VQModel(nn.Module):
467 | def __init__(self, config: ModelArgs):
468 | super().__init__()
469 | self.config = config
470 | self.encoder = Encoder(
471 | ch_mult=config.encoder_ch_mult,
472 | z_channels=config.z_channels,
473 | dropout=config.dropout_p,
474 | )
475 | self.decoder = Decoder(
476 | ch_mult=config.decoder_ch_mult,
477 | z_channels=config.z_channels,
478 | dropout=config.dropout_p,
479 | )
480 |
481 | self.quantize = VectorQuantizer(
482 | config.codebook_size,
483 | config.codebook_embed_dim,
484 | config.commit_loss_beta,
485 | config.entropy_loss_ratio,
486 | config.codebook_l2_norm,
487 | config.codebook_show_usage,
488 | )
489 | self.quant_conv = nn.Conv2d(config.z_channels, config.codebook_embed_dim, 1)
490 | self.post_quant_conv = nn.Conv2d(
491 | config.codebook_embed_dim, config.z_channels, 1
492 | )
493 |
494 | def encode(self, x):
495 | h = self.encoder(x)
496 | h = self.quant_conv(h)
497 | quant, emb_loss, info = self.quantize(h)
498 | return quant, emb_loss, info
499 |
500 | def decode(self, quant):
501 | quant = self.post_quant_conv(quant)
502 | dec = self.decoder(quant)
503 | return dec
504 |
505 | def decode_code(self, code_b, shape=None, channel_first=True):
506 | quant_b = self.quantize.get_codebook_entry(code_b, shape, channel_first)
507 | dec = self.decode(quant_b)
508 | return dec
509 |
510 | def forward(self, input):
511 | quant, diff, _ = self.encode(input)
512 | dec = self.decode(quant)
513 | return dec, diff
514 |
515 |
516 | #################################################################################
517 | # VQ Model Configs #
518 | #################################################################################
519 | def VQ_16(**kwargs):
520 | return VQModel(
521 | ModelArgs(
522 | encoder_ch_mult=[1, 1, 2, 2, 4], decoder_ch_mult=[1, 1, 2, 2, 4], **kwargs
523 | )
524 | )
525 |
526 |
527 | VQ_models = {"VQ-16": VQ_16}
528 |
--------------------------------------------------------------------------------
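
A round-trip sketch for the VQ-16 tokenizer defined above. With encoder_ch_mult = [1, 1, 2, 2, 4] the encoder downsamples by 2**4 = 16, so the 384x384 input used here (chosen only for illustration) yields a 24x24 grid of codebook indices; decode_code() is the path used at generation time, when the language model predicts image token ids directly:

    import torch

    from janus.models.vq_model import VQ_models

    vq = VQ_models["VQ-16"]()  # ModelArgs defaults: 16384 codes of dimension 8
    vq.eval()

    x = torch.randn(1, 3, 384, 384)
    with torch.no_grad():
        quant, _, (_, _, indices) = vq.encode(x)  # quant: [1, 8, 24, 24]
        recon = vq.decode(quant)                  # recon: [1, 3, 384, 384]
        # Reconstruct straight from the flat code indices (length 1 * 24 * 24).
        recon_from_codes = vq.decode_code(indices, shape=(1, 8, 24, 24), channel_first=True)
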
/janus/models/siglip_vit.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2023-2024 DeepSeek.
2 | #
3 | # Permission is hereby granted, free of charge, to any person obtaining a copy of
4 | # this software and associated documentation files (the "Software"), to deal in
5 | # the Software without restriction, including without limitation the rights to
6 | # use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of
7 | # the Software, and to permit persons to whom the Software is furnished to do so,
8 | # subject to the following conditions:
9 | #
10 | # The above copyright notice and this permission notice shall be included in all
11 | # copies or substantial portions of the Software.
12 | #
13 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
14 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS
15 | # FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR
16 | # COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER
17 | # IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN
18 | # CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
19 |
20 | # https://github.com/huggingface/pytorch-image-models/blob/main/timm/models/vision_transformer.py
21 | import math
22 | import warnings
23 | from dataclasses import dataclass
24 | from functools import partial
25 | from typing import (
26 | Callable,
27 | Dict,
28 | Final,
29 | List,
30 | Literal,
31 | Optional,
32 | Sequence,
33 | Set,
34 | Tuple,
35 | Type,
36 | Union,
37 | )
38 |
39 | import torch
40 | import torch.nn as nn
41 | import torch.nn.functional as F
42 | from timm.layers import (
43 | AttentionPoolLatent,
44 | DropPath,
45 | LayerType,
46 | Mlp,
47 | PatchDropout,
48 | PatchEmbed,
49 | resample_abs_pos_embed,
50 | )
51 | from timm.models._manipulate import checkpoint_seq, named_apply
52 |
53 |
54 | def _no_grad_trunc_normal_(tensor, mean, std, a, b):
55 | # Cut & paste from PyTorch official master until it's in a few official releases - RW
56 | # Method based on https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf
57 | def norm_cdf(x):
58 | # Computes standard normal cumulative distribution function
59 | return (1.0 + math.erf(x / math.sqrt(2.0))) / 2.0
60 |
61 | if (mean < a - 2 * std) or (mean > b + 2 * std):
62 | warnings.warn(
63 | "mean is more than 2 std from [a, b] in nn.init.trunc_normal_. "
64 | "The distribution of values may be incorrect.",
65 | stacklevel=2,
66 | )
67 |
68 | with torch.no_grad():
69 | # Values are generated by using a truncated uniform distribution and
70 | # then using the inverse CDF for the normal distribution.
71 | # Get upper and lower cdf values
72 | l = norm_cdf((a - mean) / std) # noqa: E741
73 | u = norm_cdf((b - mean) / std)
74 |
75 | # Uniformly fill tensor with values from [l, u], then translate to
76 | # [2l-1, 2u-1].
77 | tensor.uniform_(2 * l - 1, 2 * u - 1)
78 |
79 | # Use inverse cdf transform for normal distribution to get truncated
80 | # standard normal
81 | tensor.erfinv_()
82 |
83 | # Transform to proper mean, std
84 | tensor.mul_(std * math.sqrt(2.0))
85 | tensor.add_(mean)
86 |
87 | # Clamp to ensure it's in the proper range
88 | tensor.clamp_(min=a, max=b)
89 | return tensor
90 |
91 |
92 | def trunc_normal_(tensor, mean=0.0, std=1.0, a=-2.0, b=2.0):
93 | # type: (torch.Tensor, float, float, float, float) -> torch.Tensor
94 | r"""The original timm.models.layers.weight_init.trunc_normal_ can not handle bfloat16 yet, here we first
95 | convert the tensor to float32, apply the trunc_normal_() in float32, and then convert it back to its original dtype.
96 | Fills the input Tensor with values drawn from a truncated normal distribution. The values are effectively drawn
97 | from the normal distribution :math:`\mathcal{N}(\text{mean}, \text{std}^2)`
98 | with values outside :math:`[a, b]` redrawn until they are within
99 | the bounds. The method used for generating the random values works
100 | best when :math:`a \leq \text{mean} \leq b`.
101 | Args:
102 | tensor: an n-dimensional `torch.Tensor`
103 | mean: the mean of the normal distribution
104 | std: the standard deviation of the normal distribution
105 | a: the minimum cutoff value
106 | b: the maximum cutoff value
107 | Examples:
108 | >>> w = torch.empty(3, 5)
109 | >>> nn.init.trunc_normal_(w)
110 | """
111 |
112 | with torch.no_grad():
113 | dtype = tensor.dtype
114 | tensor_fp32 = tensor.float()
115 | tensor_fp32 = _no_grad_trunc_normal_(tensor_fp32, mean, std, a, b)
116 | tensor_dtype = tensor_fp32.to(dtype=dtype)
117 | tensor.copy_(tensor_dtype)
118 |
119 |
120 | def init_weights(self):
121 | if self.pos_embed is not None:
122 | trunc_normal_(self.pos_embed, std=self.pos_embed.shape[1] ** -0.5)
123 | trunc_normal_(self.latent, std=self.latent_dim**-0.5)
124 |
125 |
126 | def init_weights_vit_timm(module: nn.Module, name: str = "") -> None:
127 | """ViT weight initialization, original timm impl (for reproducibility)"""
128 | if isinstance(module, nn.Linear):
129 | trunc_normal_(module.weight, std=0.02)
130 | if module.bias is not None:
131 | nn.init.zeros_(module.bias)
132 | elif hasattr(module, "init_weights"):
133 | module.init_weights()
134 |
135 |
136 | class Attention(nn.Module):
137 | fused_attn: Final[bool]
138 |
139 | def __init__(
140 | self,
141 | dim: int,
142 | num_heads: int = 8,
143 | qkv_bias: bool = False,
144 | qk_norm: bool = False,
145 | attn_drop: float = 0.0,
146 | proj_drop: float = 0.0,
147 | norm_layer: nn.Module = nn.LayerNorm,
148 | ) -> None:
149 | super().__init__()
150 | assert dim % num_heads == 0, "dim should be divisible by num_heads"
151 | self.num_heads = num_heads
152 | self.head_dim = dim // num_heads
153 | self.scale = self.head_dim**-0.5
154 | # self.fused_attn = use_fused_attn()
155 | self.fused_attn = True
156 |
157 | self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
158 | self.q_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity()
159 | self.k_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity()
160 | self.attn_drop = nn.Dropout(attn_drop)
161 | self.proj = nn.Linear(dim, dim)
162 | self.proj_drop = nn.Dropout(proj_drop) if proj_drop > 0.0 else nn.Identity()
163 |
164 | def forward(self, x: torch.Tensor) -> torch.Tensor:
165 | B, N, C = x.shape
166 | qkv = (
167 | self.qkv(x)
168 | .reshape(B, N, 3, self.num_heads, self.head_dim)
169 | .permute(2, 0, 3, 1, 4)
170 | )
171 | q, k, v = qkv.unbind(0)
172 | q, k = self.q_norm(q), self.k_norm(k)
173 |
174 | if self.fused_attn:
175 | x = F.scaled_dot_product_attention(
176 | q,
177 | k,
178 | v,
179 | dropout_p=self.attn_drop.p if self.training else 0.0,
180 | )
181 | else:
182 | q = q * self.scale
183 | attn = q @ k.transpose(-2, -1)
184 | attn = attn.softmax(dim=-1)
185 | attn = self.attn_drop(attn)
186 | x = attn @ v
187 |
188 | x = x.transpose(1, 2).reshape(B, N, C)
189 | x = self.proj(x)
190 | x = self.proj_drop(x)
191 | return x
192 |
193 |
194 | class LayerScale(nn.Module):
195 | def __init__(
196 | self,
197 | dim: int,
198 | init_values: float = 1e-5,
199 | inplace: bool = False,
200 | ) -> None:
201 | super().__init__()
202 | self.inplace = inplace
203 | self.gamma = nn.Parameter(init_values * torch.ones(dim))
204 |
205 | def forward(self, x: torch.Tensor) -> torch.Tensor:
206 | return x.mul_(self.gamma) if self.inplace else x * self.gamma
207 |
208 |
209 | class Block(nn.Module):
210 | def __init__(
211 | self,
212 | dim: int,
213 | num_heads: int,
214 | mlp_ratio: float = 4.0,
215 | qkv_bias: bool = False,
216 | qk_norm: bool = False,
217 | proj_drop: float = 0.0,
218 | attn_drop: float = 0.0,
219 | init_values: Optional[float] = None,
220 | drop_path: float = 0.0,
221 | act_layer: nn.Module = nn.GELU,
222 | norm_layer: nn.Module = nn.LayerNorm,
223 | mlp_layer: nn.Module = Mlp,
224 | ) -> None:
225 | super().__init__()
226 | self.norm1 = norm_layer(dim)
227 | self.attn = Attention(
228 | dim,
229 | num_heads=num_heads,
230 | qkv_bias=qkv_bias,
231 | qk_norm=qk_norm,
232 | attn_drop=attn_drop,
233 | proj_drop=proj_drop,
234 | norm_layer=norm_layer,
235 | )
236 | self.ls1 = (
237 | LayerScale(dim, init_values=init_values) if init_values else nn.Identity()
238 | )
239 | self.drop_path1 = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
240 |
241 | self.norm2 = norm_layer(dim)
242 | self.mlp = mlp_layer(
243 | in_features=dim,
244 | hidden_features=int(dim * mlp_ratio),
245 | act_layer=act_layer,
246 | drop=proj_drop,
247 | )
248 | self.ls2 = (
249 | LayerScale(dim, init_values=init_values) if init_values else nn.Identity()
250 | )
251 | self.drop_path2 = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
252 |
253 | def forward(self, x: torch.Tensor) -> torch.Tensor:
254 | x = x + self.drop_path1(self.ls1(self.attn(self.norm1(x))))
255 | x = x + self.drop_path2(self.ls2(self.mlp(self.norm2(x))))
256 | return x
257 |
258 |
259 | class VisionTransformer(nn.Module):
260 | """Vision Transformer
261 |
262 | A PyTorch impl of : `An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale`
263 | - https://arxiv.org/abs/2010.11929
264 | """
265 |
266 | dynamic_img_size: Final[bool]
267 |
268 | def __init__(
269 | self,
270 | img_size: Union[int, Tuple[int, int]] = 224,
271 | patch_size: Union[int, Tuple[int, int]] = 16,
272 | in_chans: int = 3,
273 | num_classes: int = 1000,
274 | global_pool: Literal["", "avg", "token", "map"] = "token",
275 | embed_dim: int = 768,
276 | depth: int = 12,
277 | num_heads: int = 12,
278 | mlp_ratio: float = 4.0,
279 | qkv_bias: bool = True,
280 | qk_norm: bool = False,
281 | init_values: Optional[float] = None,
282 | class_token: bool = True,
283 | no_embed_class: bool = False,
284 | reg_tokens: int = 0,
285 | pre_norm: bool = False,
286 | fc_norm: Optional[bool] = None,
287 | dynamic_img_size: bool = False,
288 | dynamic_img_pad: bool = False,
289 | drop_rate: float = 0.0,
290 | pos_drop_rate: float = 0.0,
291 | patch_drop_rate: float = 0.0,
292 | proj_drop_rate: float = 0.0,
293 | attn_drop_rate: float = 0.0,
294 | drop_path_rate: float = 0.0,
295 | weight_init: Literal["skip", "jax", "jax_nlhb", "moco", ""] = "",
296 | embed_layer: Callable = PatchEmbed,
297 | norm_layer: Optional[LayerType] = None,
298 | act_layer: Optional[LayerType] = None,
299 | block_fn: Type[nn.Module] = Block,
300 | mlp_layer: Type[nn.Module] = Mlp,
301 | ignore_head: bool = False,
302 | ) -> None:
303 | """
304 | Args:
305 | img_size: Input image size.
306 | patch_size: Patch size.
307 | in_chans: Number of image input channels.
308 |             num_classes: Number of classes for classification head.
309 | global_pool: Type of global pooling for final sequence (default: 'token').
310 | embed_dim: Transformer embedding dimension.
311 | depth: Depth of transformer.
312 | num_heads: Number of attention heads.
313 | mlp_ratio: Ratio of mlp hidden dim to embedding dim.
314 | qkv_bias: Enable bias for qkv projections if True.
315 | init_values: Layer-scale init values (layer-scale enabled if not None).
316 | class_token: Use class token.
317 | no_embed_class: Don't include position embeddings for class (or reg) tokens.
318 | reg_tokens: Number of register tokens.
319 | fc_norm: Pre head norm after pool (instead of before), if None, enabled when global_pool == 'avg'.
320 | drop_rate: Head dropout rate.
321 | pos_drop_rate: Position embedding dropout rate.
322 | attn_drop_rate: Attention dropout rate.
323 | drop_path_rate: Stochastic depth rate.
324 | weight_init: Weight initialization scheme.
325 | embed_layer: Patch embedding layer.
326 | norm_layer: Normalization layer.
327 | act_layer: MLP activation layer.
328 | block_fn: Transformer block layer.
329 | """
330 | super().__init__()
331 | assert global_pool in ("", "avg", "token", "map")
332 | assert class_token or global_pool != "token"
333 | use_fc_norm = global_pool == "avg" if fc_norm is None else fc_norm
334 | # norm_layer = get_norm_layer(norm_layer) or partial(nn.LayerNorm, eps=1e-6)
335 | # act_layer = get_act_layer(act_layer) or nn.GELU
336 | norm_layer = partial(nn.LayerNorm, eps=1e-6)
337 | act_layer = nn.GELU
338 |
339 | self.num_classes = num_classes
340 | self.global_pool = global_pool
341 | self.num_features = self.embed_dim = (
342 | embed_dim # num_features for consistency with other models
343 | )
344 | self.num_prefix_tokens = 1 if class_token else 0
345 | self.num_prefix_tokens += reg_tokens
346 | self.num_reg_tokens = reg_tokens
347 | self.has_class_token = class_token
348 | self.no_embed_class = (
349 | no_embed_class # don't embed prefix positions (includes reg)
350 | )
351 | self.dynamic_img_size = dynamic_img_size
352 | self.grad_checkpointing = False
353 | self.ignore_head = ignore_head
354 |
355 | embed_args = {}
356 | if dynamic_img_size:
357 | # flatten deferred until after pos embed
358 | embed_args.update(dict(strict_img_size=False, output_fmt="NHWC"))
359 | self.patch_embed = embed_layer(
360 | img_size=img_size,
361 | patch_size=patch_size,
362 | in_chans=in_chans,
363 | embed_dim=embed_dim,
364 | bias=not pre_norm, # disable bias if pre-norm is used (e.g. CLIP)
365 | dynamic_img_pad=dynamic_img_pad,
366 | **embed_args,
367 | )
368 | num_patches = self.patch_embed.num_patches
369 |
370 | self.cls_token = (
371 | nn.Parameter(torch.zeros(1, 1, embed_dim)) if class_token else None
372 | )
373 | self.reg_token = (
374 | nn.Parameter(torch.zeros(1, reg_tokens, embed_dim)) if reg_tokens else None
375 | )
376 | embed_len = (
377 | num_patches if no_embed_class else num_patches + self.num_prefix_tokens
378 | )
379 | self.pos_embed = nn.Parameter(torch.randn(1, embed_len, embed_dim) * 0.02)
380 | self.pos_drop = nn.Dropout(p=pos_drop_rate)
381 | if patch_drop_rate > 0:
382 | self.patch_drop = PatchDropout(
383 | patch_drop_rate,
384 | num_prefix_tokens=self.num_prefix_tokens,
385 | )
386 | else:
387 | self.patch_drop = nn.Identity()
388 | self.norm_pre = norm_layer(embed_dim) if pre_norm else nn.Identity()
389 |
390 | dpr = [
391 | x.item() for x in torch.linspace(0, drop_path_rate, depth)
392 | ] # stochastic depth decay rule
393 | self.blocks = nn.Sequential(
394 | *[
395 | block_fn(
396 | dim=embed_dim,
397 | num_heads=num_heads,
398 | mlp_ratio=mlp_ratio,
399 | qkv_bias=qkv_bias,
400 | qk_norm=qk_norm,
401 | init_values=init_values,
402 | proj_drop=proj_drop_rate,
403 | attn_drop=attn_drop_rate,
404 | drop_path=dpr[i],
405 | norm_layer=norm_layer,
406 | act_layer=act_layer,
407 | mlp_layer=mlp_layer,
408 | )
409 | for i in range(depth)
410 | ]
411 | )
412 | self.norm = norm_layer(embed_dim) if not use_fc_norm else nn.Identity()
413 |
414 | # Classifier Head
415 | if global_pool == "map":
416 | AttentionPoolLatent.init_weights = init_weights
417 | self.attn_pool = AttentionPoolLatent(
418 | self.embed_dim,
419 | num_heads=num_heads,
420 | mlp_ratio=mlp_ratio,
421 | norm_layer=norm_layer,
422 | )
423 | else:
424 | self.attn_pool = None
425 | self.fc_norm = norm_layer(embed_dim) if use_fc_norm else nn.Identity()
426 | self.head_drop = nn.Dropout(drop_rate)
427 | self.head = (
428 | nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity()
429 | )
430 |
431 | if weight_init != "skip":
432 | self.init_weights(weight_init)
433 |
434 | def init_weights(self, mode: Literal["jax", "jax_nlhb", "moco", ""] = "") -> None:
435 | assert mode in ("jax", "jax_nlhb", "moco", "")
436 | # head_bias = -math.log(self.num_classes) if "nlhb" in mode else 0.0
437 | trunc_normal_(self.pos_embed, std=0.02)
438 | if self.cls_token is not None:
439 | nn.init.normal_(self.cls_token, std=1e-6)
440 | named_apply(init_weights_vit_timm, self)
441 |
442 | @torch.jit.ignore
443 | def no_weight_decay(self) -> Set:
444 | return {"pos_embed", "cls_token", "dist_token"}
445 |
446 | @torch.jit.ignore
447 | def group_matcher(self, coarse: bool = False) -> Dict:
448 | return dict(
449 | stem=r"^cls_token|pos_embed|patch_embed", # stem and embed
450 | blocks=[(r"^blocks\.(\d+)", None), (r"^norm", (99999,))],
451 | )
452 |
453 | @torch.jit.ignore
454 | def set_grad_checkpointing(self, enable: bool = True) -> None:
455 | self.grad_checkpointing = enable
456 |
457 | @torch.jit.ignore
458 | def get_classifier(self) -> nn.Module:
459 | return self.head
460 |
461 | def reset_classifier(self, num_classes: int, global_pool=None) -> None:
462 | self.num_classes = num_classes
463 | if global_pool is not None:
464 | assert global_pool in ("", "avg", "token", "map")
465 | if global_pool == "map" and self.attn_pool is None:
466 | assert (
467 | False
468 | ), "Cannot currently add attention pooling in reset_classifier()."
469 |             elif global_pool != "map" and self.attn_pool is not None:
470 | self.attn_pool = None # remove attention pooling
471 | self.global_pool = global_pool
472 | self.head = (
473 | nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity()
474 | )
475 |
476 | def _pos_embed(self, x: torch.Tensor) -> torch.Tensor:
477 | if self.dynamic_img_size:
478 | B, H, W, C = x.shape
479 | pos_embed = resample_abs_pos_embed(
480 | self.pos_embed,
481 | (H, W),
482 | num_prefix_tokens=0 if self.no_embed_class else self.num_prefix_tokens,
483 | )
484 | x = x.view(B, -1, C)
485 | else:
486 | pos_embed = self.pos_embed
487 |
488 | to_cat = []
489 | if self.cls_token is not None:
490 | to_cat.append(self.cls_token.expand(x.shape[0], -1, -1))
491 | if self.reg_token is not None:
492 | to_cat.append(self.reg_token.expand(x.shape[0], -1, -1))
493 |
494 | if self.no_embed_class:
495 | # deit-3, updated JAX (big vision)
496 | # position embedding does not overlap with class token, add then concat
497 | x = x + pos_embed
498 | if to_cat:
499 | x = torch.cat(to_cat + [x], dim=1)
500 | else:
501 | # original timm, JAX, and deit vit impl
502 | # pos_embed has entry for class token, concat then add
503 | if to_cat:
504 | x = torch.cat(to_cat + [x], dim=1)
505 | x = x + pos_embed
506 |
507 | return self.pos_drop(x)
508 |
509 | def _intermediate_layers(
510 | self,
511 | x: torch.Tensor,
512 | n: Union[int, Sequence] = 1,
513 | ) -> List[torch.Tensor]:
514 | outputs, num_blocks = [], len(self.blocks)
515 | take_indices = set(
516 | range(num_blocks - n, num_blocks) if isinstance(n, int) else n
517 | )
518 |
519 | # forward pass
520 | x = self.patch_embed(x)
521 | x = self._pos_embed(x)
522 | x = self.patch_drop(x)
523 | x = self.norm_pre(x)
524 | for i, blk in enumerate(self.blocks):
525 | x = blk(x)
526 | if i in take_indices:
527 | outputs.append(x)
528 |
529 | return outputs
530 |
531 | def get_intermediate_layers(
532 | self,
533 | x: torch.Tensor,
534 | n: Union[int, Sequence] = 1,
535 | reshape: bool = False,
536 | return_prefix_tokens: bool = False,
537 | norm: bool = False,
538 | ) -> Tuple[Union[torch.Tensor, Tuple[torch.Tensor]]]:
539 | """Intermediate layer accessor (NOTE: This is a WIP experiment).
540 | Inspired by DINO / DINOv2 interface
541 | """
542 |         # take last n blocks if n is an int, if n is a sequence, select by matching indices
543 | outputs = self._intermediate_layers(x, n)
544 | if norm:
545 | outputs = [self.norm(out) for out in outputs]
546 | prefix_tokens = [out[:, 0 : self.num_prefix_tokens] for out in outputs]
547 | outputs = [out[:, self.num_prefix_tokens :] for out in outputs]
548 |
549 | if reshape:
550 | grid_size = self.patch_embed.grid_size
551 | outputs = [
552 | out.reshape(x.shape[0], grid_size[0], grid_size[1], -1)
553 | .permute(0, 3, 1, 2)
554 | .contiguous()
555 | for out in outputs
556 | ]
557 |
558 | if return_prefix_tokens:
559 | return tuple(zip(outputs, prefix_tokens))
560 | return tuple(outputs)
561 |
562 | def forward_features(self, x: torch.Tensor) -> torch.Tensor:
563 | x = self.patch_embed(x)
564 | x = self._pos_embed(x)
565 | x = self.patch_drop(x)
566 | x = self.norm_pre(x)
567 | if self.grad_checkpointing and not torch.jit.is_scripting():
568 | x = checkpoint_seq(self.blocks, x)
569 | else:
570 | x = self.blocks(x)
571 | x = self.norm(x)
572 | return x
573 |
574 | def forward_head(self, x: torch.Tensor, pre_logits: bool = False) -> torch.Tensor:
575 | if self.attn_pool is not None:
576 | x = self.attn_pool(x)
577 | elif self.global_pool == "avg":
578 | x = x[:, self.num_prefix_tokens :].mean(dim=1)
579 | elif self.global_pool:
580 | x = x[:, 0] # class token
581 | x = self.fc_norm(x)
582 | x = self.head_drop(x)
583 | return x if pre_logits else self.head(x)
584 |
585 | def forward(self, x: torch.Tensor) -> torch.Tensor:
586 | x = self.forward_features(x)
587 | if not self.ignore_head:
588 | x = self.forward_head(x)
589 | return x
590 |
591 |
592 | @dataclass
593 | class SigLIPVisionCfg:
594 | width: int = 1152
595 | layers: Union[Tuple[int, int, int, int], int] = 27
596 | heads: int = 16
597 | patch_size: int = 14
598 | image_size: Union[Tuple[int, int], int] = 336
599 | global_pool: str = "map"
600 | mlp_ratio: float = 3.7362
601 | class_token: bool = False
602 | num_classes: int = 0
603 | use_checkpoint: bool = False
604 |
605 |
606 | SigLIP_MODEL_CONFIG = {
607 | "siglip_so400m_patch14_384": {
608 | "image_size": 336,
609 | "patch_size": 14,
610 | "width": 1152,
611 | "layers": 27,
612 | "heads": 16,
613 | "mlp_ratio": 3.7362,
614 | "global_pool": "map",
615 | "use_checkpoint": False,
616 | },
617 | "siglip_so400m_patch14_224": {
618 | "image_size": 224,
619 | "patch_size": 14,
620 | "width": 1152,
621 | "layers": 27,
622 | "heads": 16,
623 | "mlp_ratio": 3.7362,
624 | "global_pool": "map",
625 | "use_checkpoint": False,
626 | },
627 | "siglip_large_patch16_384": {
628 | "image_size": 384,
629 | "patch_size": 16,
630 | "width": 1024,
631 | "layers": 24,
632 | "heads": 16,
633 | "mlp_ratio": 4,
634 | "global_pool": "map",
635 | "use_checkpoint": False,
636 | },
637 | }
638 |
639 |
640 | def create_siglip_vit(
641 | model_name: str = "siglip_so400m_patch14_384",
642 | image_size: int = 384,
643 | select_layer: int = -1,
644 | ckpt_path: str = "",
645 | **kwargs,
646 | ):
647 | assert (
648 | model_name in SigLIP_MODEL_CONFIG.keys()
649 | ), f"model name should be in {SigLIP_MODEL_CONFIG.keys()}"
650 |
651 | vision_cfg = SigLIPVisionCfg(**SigLIP_MODEL_CONFIG[model_name])
652 |
653 | if select_layer <= 0:
654 | layers = min(vision_cfg.layers, vision_cfg.layers + select_layer + 1)
655 | else:
656 | layers = min(vision_cfg.layers, select_layer)
657 |
658 | model = VisionTransformer(
659 | img_size=image_size,
660 | patch_size=vision_cfg.patch_size,
661 | embed_dim=vision_cfg.width,
662 | depth=layers,
663 | num_heads=vision_cfg.heads,
664 | mlp_ratio=vision_cfg.mlp_ratio,
665 | class_token=vision_cfg.class_token,
666 | global_pool=vision_cfg.global_pool,
667 | ignore_head=kwargs.get("ignore_head", True),
668 | weight_init=kwargs.get("weight_init", "skip"),
669 | num_classes=0,
670 | )
671 |
672 | if ckpt_path:
673 | state_dict = torch.load(ckpt_path, map_location="cpu")
674 |
675 | incompatible_keys = model.load_state_dict(state_dict, strict=False)
676 |         print(
677 |             f"SigLIP-ViT restored from {ckpt_path},\n"
678 |             f"\tincompatible_keys: {incompatible_keys}."
679 |         )
680 |
681 | return model
682 |
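A minimal usage sketch for the factory above (an illustration, assuming the module is importable as janus.models.siglip_vit per this repository's layout; the random input and empty checkpoint path are placeholders):

    import torch
    from janus.models.siglip_vit import create_siglip_vit

    # Build the SigLIP-SO400M/14 vision tower; ignore_head defaults to True,
    # so forward() returns patch features rather than pooled logits.
    vit = create_siglip_vit(
        model_name="siglip_so400m_patch14_384",
        image_size=384,
        select_layer=-1,  # keep all 27 transformer blocks
        ckpt_path="",     # set to a local .pth file to restore pretrained weights
    )
    vit.eval()

    with torch.no_grad():
        x = torch.randn(1, 3, 384, 384)
        feats = vit(x)  # (1, 27*27, 1152) patch tokens
        # DINO-style intermediate access: last two blocks as (B, C, H, W) feature maps
        maps = vit.get_intermediate_layers(x, n=2, reshape=True, norm=True)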
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | GNU GENERAL PUBLIC LICENSE
2 | Version 3, 29 June 2007
3 |
4 |  Copyright (C) 2007 Free Software Foundation, Inc. <https://fsf.org/>
5 | Everyone is permitted to copy and distribute verbatim copies
6 | of this license document, but changing it is not allowed.
7 |
8 | Preamble
9 |
10 | The GNU General Public License is a free, copyleft license for
11 | software and other kinds of works.
12 |
13 | The licenses for most software and other practical works are designed
14 | to take away your freedom to share and change the works. By contrast,
15 | the GNU General Public License is intended to guarantee your freedom to
16 | share and change all versions of a program--to make sure it remains free
17 | software for all its users. We, the Free Software Foundation, use the
18 | GNU General Public License for most of our software; it applies also to
19 | any other work released this way by its authors. You can apply it to
20 | your programs, too.
21 |
22 | When we speak of free software, we are referring to freedom, not
23 | price. Our General Public Licenses are designed to make sure that you
24 | have the freedom to distribute copies of free software (and charge for
25 | them if you wish), that you receive source code or can get it if you
26 | want it, that you can change the software or use pieces of it in new
27 | free programs, and that you know you can do these things.
28 |
29 | To protect your rights, we need to prevent others from denying you
30 | these rights or asking you to surrender the rights. Therefore, you have
31 | certain responsibilities if you distribute copies of the software, or if
32 | you modify it: responsibilities to respect the freedom of others.
33 |
34 | For example, if you distribute copies of such a program, whether
35 | gratis or for a fee, you must pass on to the recipients the same
36 | freedoms that you received. You must make sure that they, too, receive
37 | or can get the source code. And you must show them these terms so they
38 | know their rights.
39 |
40 | Developers that use the GNU GPL protect your rights with two steps:
41 | (1) assert copyright on the software, and (2) offer you this License
42 | giving you legal permission to copy, distribute and/or modify it.
43 |
44 | For the developers' and authors' protection, the GPL clearly explains
45 | that there is no warranty for this free software. For both users' and
46 | authors' sake, the GPL requires that modified versions be marked as
47 | changed, so that their problems will not be attributed erroneously to
48 | authors of previous versions.
49 |
50 | Some devices are designed to deny users access to install or run
51 | modified versions of the software inside them, although the manufacturer
52 | can do so. This is fundamentally incompatible with the aim of
53 | protecting users' freedom to change the software. The systematic
54 | pattern of such abuse occurs in the area of products for individuals to
55 | use, which is precisely where it is most unacceptable. Therefore, we
56 | have designed this version of the GPL to prohibit the practice for those
57 | products. If such problems arise substantially in other domains, we
58 | stand ready to extend this provision to those domains in future versions
59 | of the GPL, as needed to protect the freedom of users.
60 |
61 | Finally, every program is threatened constantly by software patents.
62 | States should not allow patents to restrict development and use of
63 | software on general-purpose computers, but in those that do, we wish to
64 | avoid the special danger that patents applied to a free program could
65 | make it effectively proprietary. To prevent this, the GPL assures that
66 | patents cannot be used to render the program non-free.
67 |
68 | The precise terms and conditions for copying, distribution and
69 | modification follow.
70 |
71 | TERMS AND CONDITIONS
72 |
73 | 0. Definitions.
74 |
75 | "This License" refers to version 3 of the GNU General Public License.
76 |
77 | "Copyright" also means copyright-like laws that apply to other kinds of
78 | works, such as semiconductor masks.
79 |
80 | "The Program" refers to any copyrightable work licensed under this
81 | License. Each licensee is addressed as "you". "Licensees" and
82 | "recipients" may be individuals or organizations.
83 |
84 | To "modify" a work means to copy from or adapt all or part of the work
85 | in a fashion requiring copyright permission, other than the making of an
86 | exact copy. The resulting work is called a "modified version" of the
87 | earlier work or a work "based on" the earlier work.
88 |
89 | A "covered work" means either the unmodified Program or a work based
90 | on the Program.
91 |
92 | To "propagate" a work means to do anything with it that, without
93 | permission, would make you directly or secondarily liable for
94 | infringement under applicable copyright law, except executing it on a
95 | computer or modifying a private copy. Propagation includes copying,
96 | distribution (with or without modification), making available to the
97 | public, and in some countries other activities as well.
98 |
99 | To "convey" a work means any kind of propagation that enables other
100 | parties to make or receive copies. Mere interaction with a user through
101 | a computer network, with no transfer of a copy, is not conveying.
102 |
103 | An interactive user interface displays "Appropriate Legal Notices"
104 | to the extent that it includes a convenient and prominently visible
105 | feature that (1) displays an appropriate copyright notice, and (2)
106 | tells the user that there is no warranty for the work (except to the
107 | extent that warranties are provided), that licensees may convey the
108 | work under this License, and how to view a copy of this License. If
109 | the interface presents a list of user commands or options, such as a
110 | menu, a prominent item in the list meets this criterion.
111 |
112 | 1. Source Code.
113 |
114 | The "source code" for a work means the preferred form of the work
115 | for making modifications to it. "Object code" means any non-source
116 | form of a work.
117 |
118 | A "Standard Interface" means an interface that either is an official
119 | standard defined by a recognized standards body, or, in the case of
120 | interfaces specified for a particular programming language, one that
121 | is widely used among developers working in that language.
122 |
123 | The "System Libraries" of an executable work include anything, other
124 | than the work as a whole, that (a) is included in the normal form of
125 | packaging a Major Component, but which is not part of that Major
126 | Component, and (b) serves only to enable use of the work with that
127 | Major Component, or to implement a Standard Interface for which an
128 | implementation is available to the public in source code form. A
129 | "Major Component", in this context, means a major essential component
130 | (kernel, window system, and so on) of the specific operating system
131 | (if any) on which the executable work runs, or a compiler used to
132 | produce the work, or an object code interpreter used to run it.
133 |
134 | The "Corresponding Source" for a work in object code form means all
135 | the source code needed to generate, install, and (for an executable
136 | work) run the object code and to modify the work, including scripts to
137 | control those activities. However, it does not include the work's
138 | System Libraries, or general-purpose tools or generally available free
139 | programs which are used unmodified in performing those activities but
140 | which are not part of the work. For example, Corresponding Source
141 | includes interface definition files associated with source files for
142 | the work, and the source code for shared libraries and dynamically
143 | linked subprograms that the work is specifically designed to require,
144 | such as by intimate data communication or control flow between those
145 | subprograms and other parts of the work.
146 |
147 | The Corresponding Source need not include anything that users
148 | can regenerate automatically from other parts of the Corresponding
149 | Source.
150 |
151 | The Corresponding Source for a work in source code form is that
152 | same work.
153 |
154 | 2. Basic Permissions.
155 |
156 | All rights granted under this License are granted for the term of
157 | copyright on the Program, and are irrevocable provided the stated
158 | conditions are met. This License explicitly affirms your unlimited
159 | permission to run the unmodified Program. The output from running a
160 | covered work is covered by this License only if the output, given its
161 | content, constitutes a covered work. This License acknowledges your
162 | rights of fair use or other equivalent, as provided by copyright law.
163 |
164 | You may make, run and propagate covered works that you do not
165 | convey, without conditions so long as your license otherwise remains
166 | in force. You may convey covered works to others for the sole purpose
167 | of having them make modifications exclusively for you, or provide you
168 | with facilities for running those works, provided that you comply with
169 | the terms of this License in conveying all material for which you do
170 | not control copyright. Those thus making or running the covered works
171 | for you must do so exclusively on your behalf, under your direction
172 | and control, on terms that prohibit them from making any copies of
173 | your copyrighted material outside their relationship with you.
174 |
175 | Conveying under any other circumstances is permitted solely under
176 | the conditions stated below. Sublicensing is not allowed; section 10
177 | makes it unnecessary.
178 |
179 | 3. Protecting Users' Legal Rights From Anti-Circumvention Law.
180 |
181 | No covered work shall be deemed part of an effective technological
182 | measure under any applicable law fulfilling obligations under article
183 | 11 of the WIPO copyright treaty adopted on 20 December 1996, or
184 | similar laws prohibiting or restricting circumvention of such
185 | measures.
186 |
187 | When you convey a covered work, you waive any legal power to forbid
188 | circumvention of technological measures to the extent such circumvention
189 | is effected by exercising rights under this License with respect to
190 | the covered work, and you disclaim any intention to limit operation or
191 | modification of the work as a means of enforcing, against the work's
192 | users, your or third parties' legal rights to forbid circumvention of
193 | technological measures.
194 |
195 | 4. Conveying Verbatim Copies.
196 |
197 | You may convey verbatim copies of the Program's source code as you
198 | receive it, in any medium, provided that you conspicuously and
199 | appropriately publish on each copy an appropriate copyright notice;
200 | keep intact all notices stating that this License and any
201 | non-permissive terms added in accord with section 7 apply to the code;
202 | keep intact all notices of the absence of any warranty; and give all
203 | recipients a copy of this License along with the Program.
204 |
205 | You may charge any price or no price for each copy that you convey,
206 | and you may offer support or warranty protection for a fee.
207 |
208 | 5. Conveying Modified Source Versions.
209 |
210 | You may convey a work based on the Program, or the modifications to
211 | produce it from the Program, in the form of source code under the
212 | terms of section 4, provided that you also meet all of these conditions:
213 |
214 | a) The work must carry prominent notices stating that you modified
215 | it, and giving a relevant date.
216 |
217 | b) The work must carry prominent notices stating that it is
218 | released under this License and any conditions added under section
219 | 7. This requirement modifies the requirement in section 4 to
220 | "keep intact all notices".
221 |
222 | c) You must license the entire work, as a whole, under this
223 | License to anyone who comes into possession of a copy. This
224 | License will therefore apply, along with any applicable section 7
225 | additional terms, to the whole of the work, and all its parts,
226 | regardless of how they are packaged. This License gives no
227 | permission to license the work in any other way, but it does not
228 | invalidate such permission if you have separately received it.
229 |
230 | d) If the work has interactive user interfaces, each must display
231 | Appropriate Legal Notices; however, if the Program has interactive
232 | interfaces that do not display Appropriate Legal Notices, your
233 | work need not make them do so.
234 |
235 | A compilation of a covered work with other separate and independent
236 | works, which are not by their nature extensions of the covered work,
237 | and which are not combined with it such as to form a larger program,
238 | in or on a volume of a storage or distribution medium, is called an
239 | "aggregate" if the compilation and its resulting copyright are not
240 | used to limit the access or legal rights of the compilation's users
241 | beyond what the individual works permit. Inclusion of a covered work
242 | in an aggregate does not cause this License to apply to the other
243 | parts of the aggregate.
244 |
245 | 6. Conveying Non-Source Forms.
246 |
247 | You may convey a covered work in object code form under the terms
248 | of sections 4 and 5, provided that you also convey the
249 | machine-readable Corresponding Source under the terms of this License,
250 | in one of these ways:
251 |
252 | a) Convey the object code in, or embodied in, a physical product
253 | (including a physical distribution medium), accompanied by the
254 | Corresponding Source fixed on a durable physical medium
255 | customarily used for software interchange.
256 |
257 | b) Convey the object code in, or embodied in, a physical product
258 | (including a physical distribution medium), accompanied by a
259 | written offer, valid for at least three years and valid for as
260 | long as you offer spare parts or customer support for that product
261 | model, to give anyone who possesses the object code either (1) a
262 | copy of the Corresponding Source for all the software in the
263 | product that is covered by this License, on a durable physical
264 | medium customarily used for software interchange, for a price no
265 | more than your reasonable cost of physically performing this
266 | conveying of source, or (2) access to copy the
267 | Corresponding Source from a network server at no charge.
268 |
269 | c) Convey individual copies of the object code with a copy of the
270 | written offer to provide the Corresponding Source. This
271 | alternative is allowed only occasionally and noncommercially, and
272 | only if you received the object code with such an offer, in accord
273 | with subsection 6b.
274 |
275 | d) Convey the object code by offering access from a designated
276 | place (gratis or for a charge), and offer equivalent access to the
277 | Corresponding Source in the same way through the same place at no
278 | further charge. You need not require recipients to copy the
279 | Corresponding Source along with the object code. If the place to
280 | copy the object code is a network server, the Corresponding Source
281 | may be on a different server (operated by you or a third party)
282 | that supports equivalent copying facilities, provided you maintain
283 | clear directions next to the object code saying where to find the
284 | Corresponding Source. Regardless of what server hosts the
285 | Corresponding Source, you remain obligated to ensure that it is
286 | available for as long as needed to satisfy these requirements.
287 |
288 | e) Convey the object code using peer-to-peer transmission, provided
289 | you inform other peers where the object code and Corresponding
290 | Source of the work are being offered to the general public at no
291 | charge under subsection 6d.
292 |
293 | A separable portion of the object code, whose source code is excluded
294 | from the Corresponding Source as a System Library, need not be
295 | included in conveying the object code work.
296 |
297 | A "User Product" is either (1) a "consumer product", which means any
298 | tangible personal property which is normally used for personal, family,
299 | or household purposes, or (2) anything designed or sold for incorporation
300 | into a dwelling. In determining whether a product is a consumer product,
301 | doubtful cases shall be resolved in favor of coverage. For a particular
302 | product received by a particular user, "normally used" refers to a
303 | typical or common use of that class of product, regardless of the status
304 | of the particular user or of the way in which the particular user
305 | actually uses, or expects or is expected to use, the product. A product
306 | is a consumer product regardless of whether the product has substantial
307 | commercial, industrial or non-consumer uses, unless such uses represent
308 | the only significant mode of use of the product.
309 |
310 | "Installation Information" for a User Product means any methods,
311 | procedures, authorization keys, or other information required to install
312 | and execute modified versions of a covered work in that User Product from
313 | a modified version of its Corresponding Source. The information must
314 | suffice to ensure that the continued functioning of the modified object
315 | code is in no case prevented or interfered with solely because
316 | modification has been made.
317 |
318 | If you convey an object code work under this section in, or with, or
319 | specifically for use in, a User Product, and the conveying occurs as
320 | part of a transaction in which the right of possession and use of the
321 | User Product is transferred to the recipient in perpetuity or for a
322 | fixed term (regardless of how the transaction is characterized), the
323 | Corresponding Source conveyed under this section must be accompanied
324 | by the Installation Information. But this requirement does not apply
325 | if neither you nor any third party retains the ability to install
326 | modified object code on the User Product (for example, the work has
327 | been installed in ROM).
328 |
329 | The requirement to provide Installation Information does not include a
330 | requirement to continue to provide support service, warranty, or updates
331 | for a work that has been modified or installed by the recipient, or for
332 | the User Product in which it has been modified or installed. Access to a
333 | network may be denied when the modification itself materially and
334 | adversely affects the operation of the network or violates the rules and
335 | protocols for communication across the network.
336 |
337 | Corresponding Source conveyed, and Installation Information provided,
338 | in accord with this section must be in a format that is publicly
339 | documented (and with an implementation available to the public in
340 | source code form), and must require no special password or key for
341 | unpacking, reading or copying.
342 |
343 | 7. Additional Terms.
344 |
345 | "Additional permissions" are terms that supplement the terms of this
346 | License by making exceptions from one or more of its conditions.
347 | Additional permissions that are applicable to the entire Program shall
348 | be treated as though they were included in this License, to the extent
349 | that they are valid under applicable law. If additional permissions
350 | apply only to part of the Program, that part may be used separately
351 | under those permissions, but the entire Program remains governed by
352 | this License without regard to the additional permissions.
353 |
354 | When you convey a copy of a covered work, you may at your option
355 | remove any additional permissions from that copy, or from any part of
356 | it. (Additional permissions may be written to require their own
357 | removal in certain cases when you modify the work.) You may place
358 | additional permissions on material, added by you to a covered work,
359 | for which you have or can give appropriate copyright permission.
360 |
361 | Notwithstanding any other provision of this License, for material you
362 | add to a covered work, you may (if authorized by the copyright holders of
363 | that material) supplement the terms of this License with terms:
364 |
365 | a) Disclaiming warranty or limiting liability differently from the
366 | terms of sections 15 and 16 of this License; or
367 |
368 | b) Requiring preservation of specified reasonable legal notices or
369 | author attributions in that material or in the Appropriate Legal
370 | Notices displayed by works containing it; or
371 |
372 | c) Prohibiting misrepresentation of the origin of that material, or
373 | requiring that modified versions of such material be marked in
374 | reasonable ways as different from the original version; or
375 |
376 | d) Limiting the use for publicity purposes of names of licensors or
377 | authors of the material; or
378 |
379 | e) Declining to grant rights under trademark law for use of some
380 | trade names, trademarks, or service marks; or
381 |
382 | f) Requiring indemnification of licensors and authors of that
383 | material by anyone who conveys the material (or modified versions of
384 | it) with contractual assumptions of liability to the recipient, for
385 | any liability that these contractual assumptions directly impose on
386 | those licensors and authors.
387 |
388 | All other non-permissive additional terms are considered "further
389 | restrictions" within the meaning of section 10. If the Program as you
390 | received it, or any part of it, contains a notice stating that it is
391 | governed by this License along with a term that is a further
392 | restriction, you may remove that term. If a license document contains
393 | a further restriction but permits relicensing or conveying under this
394 | License, you may add to a covered work material governed by the terms
395 | of that license document, provided that the further restriction does
396 | not survive such relicensing or conveying.
397 |
398 | If you add terms to a covered work in accord with this section, you
399 | must place, in the relevant source files, a statement of the
400 | additional terms that apply to those files, or a notice indicating
401 | where to find the applicable terms.
402 |
403 | Additional terms, permissive or non-permissive, may be stated in the
404 | form of a separately written license, or stated as exceptions;
405 | the above requirements apply either way.
406 |
407 | 8. Termination.
408 |
409 | You may not propagate or modify a covered work except as expressly
410 | provided under this License. Any attempt otherwise to propagate or
411 | modify it is void, and will automatically terminate your rights under
412 | this License (including any patent licenses granted under the third
413 | paragraph of section 11).
414 |
415 | However, if you cease all violation of this License, then your
416 | license from a particular copyright holder is reinstated (a)
417 | provisionally, unless and until the copyright holder explicitly and
418 | finally terminates your license, and (b) permanently, if the copyright
419 | holder fails to notify you of the violation by some reasonable means
420 | prior to 60 days after the cessation.
421 |
422 | Moreover, your license from a particular copyright holder is
423 | reinstated permanently if the copyright holder notifies you of the
424 | violation by some reasonable means, this is the first time you have
425 | received notice of violation of this License (for any work) from that
426 | copyright holder, and you cure the violation prior to 30 days after
427 | your receipt of the notice.
428 |
429 | Termination of your rights under this section does not terminate the
430 | licenses of parties who have received copies or rights from you under
431 | this License. If your rights have been terminated and not permanently
432 | reinstated, you do not qualify to receive new licenses for the same
433 | material under section 10.
434 |
435 | 9. Acceptance Not Required for Having Copies.
436 |
437 | You are not required to accept this License in order to receive or
438 | run a copy of the Program. Ancillary propagation of a covered work
439 | occurring solely as a consequence of using peer-to-peer transmission
440 | to receive a copy likewise does not require acceptance. However,
441 | nothing other than this License grants you permission to propagate or
442 | modify any covered work. These actions infringe copyright if you do
443 | not accept this License. Therefore, by modifying or propagating a
444 | covered work, you indicate your acceptance of this License to do so.
445 |
446 | 10. Automatic Licensing of Downstream Recipients.
447 |
448 | Each time you convey a covered work, the recipient automatically
449 | receives a license from the original licensors, to run, modify and
450 | propagate that work, subject to this License. You are not responsible
451 | for enforcing compliance by third parties with this License.
452 |
453 | An "entity transaction" is a transaction transferring control of an
454 | organization, or substantially all assets of one, or subdividing an
455 | organization, or merging organizations. If propagation of a covered
456 | work results from an entity transaction, each party to that
457 | transaction who receives a copy of the work also receives whatever
458 | licenses to the work the party's predecessor in interest had or could
459 | give under the previous paragraph, plus a right to possession of the
460 | Corresponding Source of the work from the predecessor in interest, if
461 | the predecessor has it or can get it with reasonable efforts.
462 |
463 | You may not impose any further restrictions on the exercise of the
464 | rights granted or affirmed under this License. For example, you may
465 | not impose a license fee, royalty, or other charge for exercise of
466 | rights granted under this License, and you may not initiate litigation
467 | (including a cross-claim or counterclaim in a lawsuit) alleging that
468 | any patent claim is infringed by making, using, selling, offering for
469 | sale, or importing the Program or any portion of it.
470 |
471 | 11. Patents.
472 |
473 | A "contributor" is a copyright holder who authorizes use under this
474 | License of the Program or a work on which the Program is based. The
475 | work thus licensed is called the contributor's "contributor version".
476 |
477 | A contributor's "essential patent claims" are all patent claims
478 | owned or controlled by the contributor, whether already acquired or
479 | hereafter acquired, that would be infringed by some manner, permitted
480 | by this License, of making, using, or selling its contributor version,
481 | but do not include claims that would be infringed only as a
482 | consequence of further modification of the contributor version. For
483 | purposes of this definition, "control" includes the right to grant
484 | patent sublicenses in a manner consistent with the requirements of
485 | this License.
486 |
487 | Each contributor grants you a non-exclusive, worldwide, royalty-free
488 | patent license under the contributor's essential patent claims, to
489 | make, use, sell, offer for sale, import and otherwise run, modify and
490 | propagate the contents of its contributor version.
491 |
492 | In the following three paragraphs, a "patent license" is any express
493 | agreement or commitment, however denominated, not to enforce a patent
494 | (such as an express permission to practice a patent or covenant not to
495 | sue for patent infringement). To "grant" such a patent license to a
496 | party means to make such an agreement or commitment not to enforce a
497 | patent against the party.
498 |
499 | If you convey a covered work, knowingly relying on a patent license,
500 | and the Corresponding Source of the work is not available for anyone
501 | to copy, free of charge and under the terms of this License, through a
502 | publicly available network server or other readily accessible means,
503 | then you must either (1) cause the Corresponding Source to be so
504 | available, or (2) arrange to deprive yourself of the benefit of the
505 | patent license for this particular work, or (3) arrange, in a manner
506 | consistent with the requirements of this License, to extend the patent
507 | license to downstream recipients. "Knowingly relying" means you have
508 | actual knowledge that, but for the patent license, your conveying the
509 | covered work in a country, or your recipient's use of the covered work
510 | in a country, would infringe one or more identifiable patents in that
511 | country that you have reason to believe are valid.
512 |
513 | If, pursuant to or in connection with a single transaction or
514 | arrangement, you convey, or propagate by procuring conveyance of, a
515 | covered work, and grant a patent license to some of the parties
516 | receiving the covered work authorizing them to use, propagate, modify
517 | or convey a specific copy of the covered work, then the patent license
518 | you grant is automatically extended to all recipients of the covered
519 | work and works based on it.
520 |
521 | A patent license is "discriminatory" if it does not include within
522 | the scope of its coverage, prohibits the exercise of, or is
523 | conditioned on the non-exercise of one or more of the rights that are
524 | specifically granted under this License. You may not convey a covered
525 | work if you are a party to an arrangement with a third party that is
526 | in the business of distributing software, under which you make payment
527 | to the third party based on the extent of your activity of conveying
528 | the work, and under which the third party grants, to any of the
529 | parties who would receive the covered work from you, a discriminatory
530 | patent license (a) in connection with copies of the covered work
531 | conveyed by you (or copies made from those copies), or (b) primarily
532 | for and in connection with specific products or compilations that
533 | contain the covered work, unless you entered into that arrangement,
534 | or that patent license was granted, prior to 28 March 2007.
535 |
536 | Nothing in this License shall be construed as excluding or limiting
537 | any implied license or other defenses to infringement that may
538 | otherwise be available to you under applicable patent law.
539 |
540 | 12. No Surrender of Others' Freedom.
541 |
542 | If conditions are imposed on you (whether by court order, agreement or
543 | otherwise) that contradict the conditions of this License, they do not
544 | excuse you from the conditions of this License. If you cannot convey a
545 | covered work so as to satisfy simultaneously your obligations under this
546 | License and any other pertinent obligations, then as a consequence you may
547 | not convey it at all. For example, if you agree to terms that obligate you
548 | to collect a royalty for further conveying from those to whom you convey
549 | the Program, the only way you could satisfy both those terms and this
550 | License would be to refrain entirely from conveying the Program.
551 |
552 | 13. Use with the GNU Affero General Public License.
553 |
554 | Notwithstanding any other provision of this License, you have
555 | permission to link or combine any covered work with a work licensed
556 | under version 3 of the GNU Affero General Public License into a single
557 | combined work, and to convey the resulting work. The terms of this
558 | License will continue to apply to the part which is the covered work,
559 | but the special requirements of the GNU Affero General Public License,
560 | section 13, concerning interaction through a network will apply to the
561 | combination as such.
562 |
563 | 14. Revised Versions of this License.
564 |
565 | The Free Software Foundation may publish revised and/or new versions of
566 | the GNU General Public License from time to time. Such new versions will
567 | be similar in spirit to the present version, but may differ in detail to
568 | address new problems or concerns.
569 |
570 | Each version is given a distinguishing version number. If the
571 | Program specifies that a certain numbered version of the GNU General
572 | Public License "or any later version" applies to it, you have the
573 | option of following the terms and conditions either of that numbered
574 | version or of any later version published by the Free Software
575 | Foundation. If the Program does not specify a version number of the
576 | GNU General Public License, you may choose any version ever published
577 | by the Free Software Foundation.
578 |
579 | If the Program specifies that a proxy can decide which future
580 | versions of the GNU General Public License can be used, that proxy's
581 | public statement of acceptance of a version permanently authorizes you
582 | to choose that version for the Program.
583 |
584 | Later license versions may give you additional or different
585 | permissions. However, no additional obligations are imposed on any
586 | author or copyright holder as a result of your choosing to follow a
587 | later version.
588 |
589 | 15. Disclaimer of Warranty.
590 |
591 | THERE IS NO WARRANTY FOR THE PROGRAM, TO THE EXTENT PERMITTED BY
592 | APPLICABLE LAW. EXCEPT WHEN OTHERWISE STATED IN WRITING THE COPYRIGHT
593 | HOLDERS AND/OR OTHER PARTIES PROVIDE THE PROGRAM "AS IS" WITHOUT WARRANTY
594 | OF ANY KIND, EITHER EXPRESSED OR IMPLIED, INCLUDING, BUT NOT LIMITED TO,
595 | THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
596 | PURPOSE. THE ENTIRE RISK AS TO THE QUALITY AND PERFORMANCE OF THE PROGRAM
597 | IS WITH YOU. SHOULD THE PROGRAM PROVE DEFECTIVE, YOU ASSUME THE COST OF
598 | ALL NECESSARY SERVICING, REPAIR OR CORRECTION.
599 |
600 | 16. Limitation of Liability.
601 |
602 | IN NO EVENT UNLESS REQUIRED BY APPLICABLE LAW OR AGREED TO IN WRITING
603 | WILL ANY COPYRIGHT HOLDER, OR ANY OTHER PARTY WHO MODIFIES AND/OR CONVEYS
604 | THE PROGRAM AS PERMITTED ABOVE, BE LIABLE TO YOU FOR DAMAGES, INCLUDING ANY
605 | GENERAL, SPECIAL, INCIDENTAL OR CONSEQUENTIAL DAMAGES ARISING OUT OF THE
606 | USE OR INABILITY TO USE THE PROGRAM (INCLUDING BUT NOT LIMITED TO LOSS OF
607 | DATA OR DATA BEING RENDERED INACCURATE OR LOSSES SUSTAINED BY YOU OR THIRD
608 | PARTIES OR A FAILURE OF THE PROGRAM TO OPERATE WITH ANY OTHER PROGRAMS),
609 | EVEN IF SUCH HOLDER OR OTHER PARTY HAS BEEN ADVISED OF THE POSSIBILITY OF
610 | SUCH DAMAGES.
611 |
612 | 17. Interpretation of Sections 15 and 16.
613 |
614 | If the disclaimer of warranty and limitation of liability provided
615 | above cannot be given local legal effect according to their terms,
616 | reviewing courts shall apply local law that most closely approximates
617 | an absolute waiver of all civil liability in connection with the
618 | Program, unless a warranty or assumption of liability accompanies a
619 | copy of the Program in return for a fee.
620 |
621 | END OF TERMS AND CONDITIONS
622 |
623 | How to Apply These Terms to Your New Programs
624 |
625 | If you develop a new program, and you want it to be of the greatest
626 | possible use to the public, the best way to achieve this is to make it
627 | free software which everyone can redistribute and change under these terms.
628 |
629 | To do so, attach the following notices to the program. It is safest
630 | to attach them to the start of each source file to most effectively
631 | state the exclusion of warranty; and each file should have at least
632 | the "copyright" line and a pointer to where the full notice is found.
633 |
634 |     <one line to give the program's name and a brief idea of what it does.>
635 |     Copyright (C) <year>  <name of author>
636 |
637 | This program is free software: you can redistribute it and/or modify
638 | it under the terms of the GNU General Public License as published by
639 | the Free Software Foundation, either version 3 of the License, or
640 | (at your option) any later version.
641 |
642 | This program is distributed in the hope that it will be useful,
643 | but WITHOUT ANY WARRANTY; without even the implied warranty of
644 | MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
645 | GNU General Public License for more details.
646 |
647 | You should have received a copy of the GNU General Public License
648 |     along with this program.  If not, see <https://www.gnu.org/licenses/>.
649 |
650 | Also add information on how to contact you by electronic and paper mail.
651 |
652 | If the program does terminal interaction, make it output a short
653 | notice like this when it starts in an interactive mode:
654 |
655 |     <program>  Copyright (C) <year>  <name of author>
656 | This program comes with ABSOLUTELY NO WARRANTY; for details type `show w'.
657 | This is free software, and you are welcome to redistribute it
658 | under certain conditions; type `show c' for details.
659 |
660 | The hypothetical commands `show w' and `show c' should show the appropriate
661 | parts of the General Public License. Of course, your program's commands
662 | might be different; for a GUI interface, you would use an "about box".
663 |
664 | You should also get your employer (if you work as a programmer) or school,
665 | if any, to sign a "copyright disclaimer" for the program, if necessary.
666 | For more information on this, and how to apply and follow the GNU GPL, see
667 | <https://www.gnu.org/licenses/>.
668 |
669 | The GNU General Public License does not permit incorporating your program
670 | into proprietary programs. If your program is a subroutine library, you
671 | may consider it more useful to permit linking proprietary applications with
672 | the library. If this is what you want to do, use the GNU Lesser General
673 | Public License instead of this License. But first, please read
674 | <https://www.gnu.org/licenses/why-not-lgpl.html>.
675 |
--------------------------------------------------------------------------------