├── .gitignore ├── README.md ├── README_zh.md ├── dataprocess ├── config.yaml └── process_image.py ├── dataset └── image_caption_dataset.py ├── minicpm ├── Mminicpm.py ├── configuration_minicpm.py └── modeling_minicpm.py ├── model └── model.py ├── qwen ├── Mqwen.py ├── cache_autogptq_cuda_256.cpp ├── cache_autogptq_cuda_kernel_256.cu ├── configuration_qwen.py ├── cpp_kernels.py ├── modeling_qwen.py ├── qwen_generation_utils.py └── tokenization_qwen.py ├── requirements.txt ├── test.py ├── test.sh ├── test_img └── 1.jpg ├── train.py ├── train.sh ├── trainer.py ├── visual ├── CLIP_VIT.py └── SIGLIP_VIT.py └── webUI.py /.gitignore: -------------------------------------------------------------------------------- 1 | /weights/* 2 | /data/* 3 | __pycache__ 4 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Building Your Own Multimodal Large Model from Scratch 2 | 3 | For the Chinese version of the README, please refer to [中文文档](README_zh.md). 4 | 5 | ## Code Explanation 💻 6 | 7 | - **Data Preprocessing**: The relevant code is located in the `dataprocess` folder, and dataset-related code is in the `dataset` folder. Data preprocessing mainly includes path merging, QA data concatenation, feature insertion token processing, etc. 8 | - **LLM Model**: Uses Qwen-7B as the main model, with relevant code in the `qwen` folder. By overriding the `forward` method of `QWenModel`, multimodal feature injection is achieved. 9 | - **Visual Model**: Uses `CLIP_VIT` and `SIGLIP_VIT`, with relevant code in the `visual` folder, which also includes other backbone networks. 10 | - **VLM Model**: Relevant code is in the `model.py` file under the `model` folder. 11 | 12 | ## Dataset 🌏 13 | 14 | We use a multilingual dataset, mainly including the COCO2017 dataset and the AI Challenger image Chinese description dataset: 15 | - The COCO dataset annotations use LLAVA's `detail_23k` and `complex_reasoning_77k`, which can effectively enhance the richness of the model's descriptions. 16 | - The AI Challenger dataset uses the original annotations and a fixed prompt. 17 | 18 | ## Model Architecture 🤖 19 | 20 | In VLM, the visual part uses the `CLIP` or `SIGLIP` model, which has already achieved preliminary semantic alignment, and uses a two-layer MLP for feature mapping. By overriding the `forward` method of `QWenModel`, the corresponding `image` tokens are replaced with visual features. 21 | 22 | If you wish to replace the model architecture, please modify [this part](https://github.com/xinyanghuang7/Basic-Vision-Language-Model/blob/main/train.py#L41). 23 | 24 | ## How to Start Deployment 🔧 25 | 26 | ### Download Relevant Data 27 | 28 | | AI Challenger | COCO | complex_reasoning_77k.json | detail_23k.json | 29 | | --- | --- | --- | --- | 30 | | [AI Challenger](https://tianchi.aliyun.com/dataset/145781) | [COCO 2017](http://images.cocodataset.org/zips/train2017.zip) | [complex_reasoning_77k.json](https://huggingface.co/datasets/liuhaotian/LLaVA-Instruct-150K/resolve/main/complex_reasoning_77k.json) | [detail_23k.json](https://huggingface.co/datasets/liuhaotian/LLaVA-Instruct-150K/resolve/main/detail_23k.json) | 31 | 32 | Please store the datasets according to the paths in the [configuration file](https://github.com/xinyanghuang7/Basic-Vision-Language-Model/blob/main/dataprocess/config.yaml). Of course, the paths can be customized. 
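For reference, below is a minimal shell sketch of one possible way to fetch and arrange the COCO images and LLaVA annotation files under a local `data/` directory; the AI Challenger archive has to be downloaded manually from the Tianchi page linked above, and every path here is illustrative and should be kept in sync with `dataprocess/config.yaml`:

```shell
# Illustrative layout only -- adjust the target directory to match dataprocess/config.yaml.
mkdir -p data && cd data

# COCO 2017 training images
wget http://images.cocodataset.org/zips/train2017.zip
unzip train2017.zip

# LLaVA instruction annotations used for the COCO images
wget https://huggingface.co/datasets/liuhaotian/LLaVA-Instruct-150K/resolve/main/complex_reasoning_77k.json
wget https://huggingface.co/datasets/liuhaotian/LLaVA-Instruct-150K/resolve/main/detail_23k.json

# The AI Challenger caption archive (from Tianchi) is unpacked alongside the files above
# before running dataprocess/process_image.py.
```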
33 | 34 | Please note that this path needs to be consistent with [data/](https://github.com/xinyanghuang7/Basic-Vision-Language-Model/blob/main/train.py#L29) for the model to read. 35 | 36 | After downloading the data, use `process_image.py` for preprocessing. 37 | 38 | ### Install the Runtime Environment 39 | 40 | Use `pip install` to install `requirements.txt`: 41 | 42 | ```shell 43 | pip install -r requirements.txt 44 | ``` 45 | 46 | ### Start Training 47 | 48 | Model training adopts the method of freezing the image model, and LLM uses the LoRA method to reduce training pressure. The parameters to be trained include the visual feature mapping layer and the LoRA parameters in the LLM. Since the mapping layer is initialized with untrained parameters, to balance the optimization speed of the model parameters, a larger learning rate is set for the mapping layer than for the LoRA part. 49 | 50 | Run the `train.sh` in the root directory, and you can configure the relevant parameters for experiments. 51 | 52 | ```shell 53 | sh train.sh 54 | ``` 55 | 56 | Through the above steps, you can start the training process and train the multimodal model. 57 | 58 | The model weights will be saved in the `--output_dir`, and this path can also be customized. 59 | 60 | #### `train.sh` Script Analysis 61 | 62 | ```sh 63 | CUDA_VISIBLE_DEVICES=0 torchrun --nproc_per_node=1 --master_port=25642 train.py \ 64 | --lora_rank 128 \ 65 | --lora_dropout 0.10 \ 66 | --per_device_train_batch_size 4 \ 67 | --gradient_accumulation_steps 1 \ 68 | --num_train_epochs 2 \ 69 | --save_steps 1000 \ 70 | --save_total_limit 5 \ 71 | --learning_rate 3e-5 \ 72 | --seed 42 \ 73 | --ddp_find_unused_parameters False \ 74 | --feature_proj_lr 1e-4 \ 75 | --remove_unused_columns false \ 76 | --logging_steps 100 \ 77 | --output_dir ./weights/train_V1_5 \ 78 | --target_modules "c_attn|w1|w2" \ 79 | --image_map /home/u2023111315/Basic-Vision-Language-Model/data/image_map_b.json \ 80 | --captions_file /home/u2023111315/Basic-Vision-Language-Model/data/captions_b.json 81 | ``` 82 | 83 | #### Explanation 84 | 85 | 1. **CUDA_VISIBLE_DEVICES=0**: Use GPU with ID 0. 86 | 2. **torchrun**: PyTorch's distributed training tool. 87 | 3. **--nproc_per_node=1**: Run 1 process per node. 88 | 4. **--master_port=25642**: Set the inter-process communication port. 89 | 5. **train.py**: Main training script. 90 | 91 | #### Parameters Passed to `train.py` 92 | 93 | 1. **--lora_rank 128**: The rank of the LoRA layer is 128. 94 | 2. **--lora_dropout 0.10**: The dropout rate of the LoRA layer is 10%. 95 | 3. **--per_device_train_batch_size 4**: The training batch size per device is 4. 96 | 4. **--gradient_accumulation_steps 1**: Gradient accumulation steps are 1. 97 | 5. **--num_train_epochs 2**: Train for 2 epochs. 98 | 6. **--save_steps 1000**: Save the model every 1000 steps. 99 | 7. **--save_total_limit 5**: Save up to 5 checkpoints. 100 | 8. **--learning_rate 3e-5**: Learning rate is 3e-5. 101 | 9. **--seed 42**: Random seed is 42. 102 | 10. **--ddp_find_unused_parameters False**: Disable DDP finding unused parameters. 103 | 11. **--feature_proj_lr 1e-4**: Learning rate for the feature projection layer is 1e-4. 104 | 12. **--remove_unused_columns false**: Retain unused columns. 105 | 13. **--logging_steps 100**: Log every 100 steps. 106 | 14. **--output_dir ./weights/train_V1_5**: Output directory. 107 | 15. **--target_modules "c_attn|w1|w2"**: Target modules for LoRA adaptation. 108 | 16. 
**--image_map /home/u2023111315/Basic-Vision-Language-Model/data/image_map_b.json**: Path to the image mapping file. 109 | 17. **--captions_file /home/u2023111315/Basic-Vision-Language-Model/data/captions_b.json**: Path to the captions file. 110 | 111 | ### Test the Model 112 | 113 | Run the `test.sh` in the root directory, and you can configure the relevant parameters for experiments. 114 | 115 | ```shell 116 | sh test.sh 117 | ``` 118 | 119 | The code will read images from the folder for Q&A. 120 | 121 | #### `test.sh` Script Analysis 122 | 123 | ```sh 124 | python test.py --base_language_model Qwen/Qwen-7B-Chat --base_value_model openai/clip-vit-large-patch14 --model_weights ./weights/train_V1_5/checkpoint-10000/ --image_path ./test_img/1.jpg --prompt "Describe the colors appearing in the image<|extra_0|>" 125 | ``` 126 | 127 | #### Load pre-trained model (optional) 128 | If you want to test the model directly, the pre-trained weights provided are as follows: 129 | 130 | | SIGLIP_Qwen_epoch19000 | SIGLIP_Qwen_epoch36000 | 131 | | :---: | :---: | 132 | |[Model1](https://huggingface.co/xinyanghuang/Basic-Visual-Language-Model/tree/main/checkpoint-19000)|[Model2](https://huggingface.co/xinyanghuang/Basic-Visual-Language-Model/tree/main/checkpoint-36000)| 133 | 134 | You can directly download the relevant files and test them. 135 | 136 | #### Parameters Passed to `test.py` 137 | 138 | 1. **--base_language_model Qwen/Qwen-7B-Chat**: Specify the path to the base language model, here using `Qwen/Qwen-7B-Chat`. 139 | 2. **--base_value_model openai/clip-vit-large-patch14**: Specify the path to the base visual model, here using `openai/clip-vit-large-patch14`. 140 | 3. **--model_weights ./weights/train_V1_5/checkpoint-10000/**: Specify the path to the model weights, here using the checkpoint `checkpoint-10000` saved during training. 141 | 4. **--image_path ./test_img/1.jpg**: Specify the path to the input image, here using `./test_img/1.jpg`. 142 | 5. **--prompt "Describe the colors appearing in the image<|extra_0|>"**: Specify the prompt for the model, here asking the model to describe the colors appearing in the image. 143 | 144 | ## References 📚 145 | 146 | Thanks to the great work of the following projects 🙌: 147 | 148 | - https://github.com/WatchTower-Liu/VLM-learning/tree/main 149 | - https://github.com/QwenLM/Qwen 150 | - https://github.com/haotian-liu/LLaVA 151 | 152 | ## Contact ✉ 153 | 154 | If you have any questions or ideas, feel free to contact me 😊: 155 | 156 | hsinyanghuang7@gmail.com 157 | 158 | I will reply as soon as I see the email! 159 | -------------------------------------------------------------------------------- /README_zh.md: -------------------------------------------------------------------------------- 1 | # 从零搭建自己的多模态大模型 2 | 3 | For the English version of the README, please refer to [README.md](README.md). 
4 | 5 | ## 代码说明 💻 6 | 7 | - **数据预处理**:相关代码位于 `dataprocess` 文件夹下,数据集相关代码在 `dataset` 文件夹中。数据预处理主要包括路径合并、QA 数据拼接、特征插入 token 处理等。 8 | - **LLM模型**:使用 Qwen-7B 作为主体,相关代码在 `qwen` 文件夹中。通过重写 `QWenModel` 的 `forward` 方法,实现多模态特征的注入。 9 | - **视觉模型**:使用 `CLIP_VIT` 和 `SIGLIP_VIT`,相关代码在 `visual` 文件夹中,其中还包含其他主干网络。 10 | - **VLM模型**:相关代码在 `model` 文件夹下的 `model.py` 文件中。 11 | 12 | ## 数据集 🌏 13 | 14 | 我们使用了多语言数据集,主要包括 COCO2017 数据集和 AI Challenger 图像中文描述数据集: 15 | - COCO 数据集的标注使用了 LLAVA 的 `detail_23k` 和 `complex_reasoning_77k`,这些标注可以有效提升模型的描述丰富度。 16 | - AI Challenger 数据集使用原始标注,并使用固定的 prompt。 17 | 18 | ## 模型架构 🤖 19 | 20 | 在 VLM 中,视觉部分采用已经实现初步语义对齐的 `CLIP` 或 `SIGLIP` 模型,并使用两层 MLP 进行特征映射。通过重写 `QWenModel` 的 `forward` 方法,将对应的 `image` 标记替换为视觉特征。 21 | 22 | 如果你希望替换模型架构,请修改[这部分](https://github.com/xinyanghuang7/Basic-Vision-Language-Model/blob/main/train.py#L41)。 23 | 24 | ## 如何开始部署 🔧 25 | 26 | ### 下载相关数据 27 | 28 | | AI Challenger | COCO | complex_reasoning_77k.json | detail_23k.json | 29 | | --- | --- | --- | --- | 30 | | [AI Challenger](https://tianchi.aliyun.com/dataset/145781) | [COCO 2017](http://images.cocodataset.org/zips/train2017.zip) | [complex_reasoning_77k.json](https://huggingface.co/datasets/liuhaotian/LLaVA-Instruct-150K/resolve/main/complex_reasoning_77k.json) | [detail_23k.json](https://huggingface.co/datasets/liuhaotian/LLaVA-Instruct-150K/resolve/main/detail_23k.json) | 31 | 32 | 请按照[配置文件](https://github.com/xinyanghuang7/Basic-Vision-Language-Model/blob/main/dataprocess/config.yaml)中的路径存放数据集。当然,路径可以自定义。 33 | 34 | 请注意,此路径需要与[data/](https://github.com/xinyanghuang7/Basic-Vision-Language-Model/blob/main/train.py#L29)保持一致,以便模型进行读取。 35 | 36 | 数据下载完毕后,使用 `process_image.py` 进行预处理。 37 | 38 | ### 安装运行环境 39 | 40 | 使用 `pip install` 安装 `requirements.txt`: 41 | 42 | ```shell 43 | pip install -r requirements.txt 44 | ``` 45 | 46 | ### 开始训练 47 | 48 | 模型训练采用 image model 冻结的方式进行,LLM 使用 Lora 方式训练以减少训练压力。需要训练的参数包括视觉特征映射层以及 LLM 中 Lora 的参数。由于映射层是未训练的初始化参数,为了平衡模型参数优化速度,这里为映射层设定了比 Lora 部分更大的学习率。 49 | 50 | 运行根目录的 `train.sh`,可自行配置相关参数进行试验。 51 | 52 | ```shell 53 | sh train.sh 54 | ``` 55 | 56 | 通过上述步骤,您可以启动训练过程并进行多模态模型的训练。 57 | 58 | 模型权重将会保存在`--output_dir`中,同样,这个路径可以进行自定义。 59 | 60 | #### `train.sh` 脚本解析 61 | 62 | ```sh 63 | CUDA_VISIBLE_DEVICES=0 torchrun --nproc_per_node=1 --master_port=25642 train.py \ 64 | --lora_rank 128 \ 65 | --lora_dropout 0.10 \ 66 | --per_device_train_batch_size 4 \ 67 | --gradient_accumulation_steps 1 \ 68 | --num_train_epochs 2 \ 69 | --save_steps 1000 \ 70 | --save_total_limit 5 \ 71 | --learning_rate 3e-5 \ 72 | --seed 42 \ 73 | --ddp_find_unused_parameters False \ 74 | --feature_proj_lr 1e-4 \ 75 | --remove_unused_columns false \ 76 | --logging_steps 100 \ 77 | --output_dir ./weights/train_V1_5 \ 78 | --target_modules "c_attn|w1|w2" \ 79 | --image_map /home/u2023111315/Basic-Vision-Language-Model/data/image_map_b.json \ 80 | --captions_file /home/u2023111315/Basic-Vision-Language-Model/data/captions_b.json 81 | ``` 82 | 83 | #### 解释 84 | 85 | 1. **CUDA_VISIBLE_DEVICES=0**: 使用ID为0的GPU。 86 | 2. **torchrun**: PyTorch的分布式训练工具。 87 | 3. **--nproc_per_node=1**: 每个节点运行1个进程。 88 | 4. **--master_port=25642**: 设置进程间通信端口。 89 | 5. **train.py**: 主训练脚本。 90 | 91 | #### 传递给 `train.py` 的参数 92 | 93 | 1. **--lora_rank 128**: LoRA层的秩为128。 94 | 2. **--lora_dropout 0.10**: LoRA层的dropout率为10%。 95 | 3. **--per_device_train_batch_size 4**: 每个设备的训练批次大小为4。 96 | 4. **--gradient_accumulation_steps 1**: 梯度累积步数为1。 97 | 5. **--num_train_epochs 2**: 训练2个epoch。 98 | 6. **--save_steps 1000**: 每1000步保存一次模型。 99 | 7. 
**--save_total_limit 5**: 最多保存5个检查点。 100 | 8. **--learning_rate 3e-5**: 学习率为3e-5。 101 | 9. **--seed 42**: 随机种子为42。 102 | 10. **--ddp_find_unused_parameters False**: 禁用DDP查找未使用的参数。 103 | 11. **--feature_proj_lr 1e-4**: 特征投影层的学习率为1e-4。 104 | 12. **--remove_unused_columns false**: 保留未使用的列。 105 | 13. **--logging_steps 100**: 每100步记录一次日志。 106 | 14. **--output_dir ./weights/train_V1_5**: 输出目录。 107 | 15. **--target_modules "c_attn|w1|w2"**: LoRA适配的目标模块。 108 | 16. **--image_map /home/u2023111315/Basic-Vision-Language-Model/data/image_map_b.json**: 图像映射文件路径。 109 | 17. **--captions_file /home/u2023111315/Basic-Vision-Language-Model/data/captions_b.json**: 标注文件路径。 110 | 111 | ### 测试模型 112 | 113 | 运行根目录的 `test.sh`,可自行配置相关参数进行试验。 114 | 115 | ```shell 116 | sh test.sh 117 | ``` 118 | 119 | 代码会读取文件夹下的图片进行问答。 120 | 121 | #### 加载预训练模型(可选) 122 | 如果想直接测试模型效果,所提供的预训练权重如下: 123 | 124 | | SIGLIP_Qwen_epoch19000 | SIGLIP_Qwen_epoch36000 | 125 | | :---: | :---: | 126 | |[Model1](https://huggingface.co/xinyanghuang/Basic-Visual-Language-Model/tree/main/checkpoint-19000)|[Model2](https://huggingface.co/xinyanghuang/Basic-Visual-Language-Model/tree/main/checkpoint-36000)| 127 | 128 | 可以直接下载相关文件后进行测试。 129 | 130 | #### `test.sh` 脚本解析 131 | 132 | ```sh 133 | python test.py --base_language_model Qwen/Qwen-7B-Chat --base_value_model openai/clip-vit-large-patch14 --model_weights ./weights/train_V1_5/checkpoint-10000/ --image_path ./test_img/1.jpg --prompt "使用语言描述一下图中出现了那些颜色<|extra_0|>" 134 | ``` 135 | 136 | #### 传递给 `test.py` 的参数 137 | 138 | 1. **--base_language_model Qwen/Qwen-7B-Chat**: 指定基础语言模型的路径,这里使用的是 `Qwen/Qwen-7B-Chat`。 139 | 2. **--base_value_model openai/clip-vit-large-patch14**: 指定基础视觉模型的路径,这里使用的是 `openai/clip-vit-large-patch14`。 140 | 3. **--model_weights ./weights/train_V1_5/checkpoint-10000/**: 指定模型权重的路径,这里使用的是训练过程中保存的检查点 `checkpoint-10000`。 141 | 4. **--image_path ./test_img/1.jpg**: 指定输入图像的路径,这里使用的是 `./test_img/1.jpg`。 142 | 5. **--prompt "使用语言描述一下图中出现了那些颜色<|extra_0|>"**: 指定模型的提示语,这里要求模型用语言描述图中出现的颜色。 143 | 144 | ## 参考 📚 145 | 146 | 感谢以下项目的伟大工作🙌: 147 | 148 | - https://github.com/WatchTower-Liu/VLM-learning/tree/main 149 | - https://github.com/QwenLM/Qwen 150 | - https://github.com/haotian-liu/LLaVA 151 | 152 | ## 联系 ✉ 153 | 154 | 如果你有任何疑问或者想法,十分欢迎随时联系我😊: 155 | 156 | hsinyanghuang7@gmail.com 157 | 158 | 我会在看到邮件的第一时间回复! 
159 | -------------------------------------------------------------------------------- /dataprocess/config.yaml: -------------------------------------------------------------------------------- 1 | ali_image_path: "/home/u2023111315/ai_challenger_caption_train_20170902/caption_train_images_20170902/" 2 | ali_image_path_label_file: "/home/u2023111315/ai_challenger_caption_train_20170902/caption_train_annotations_20170902.json" 3 | 4 | coco_image_path: "/home/u2023111315/train2017/" 5 | coco_image_path_label_file1: "/home/u2023111315/Basic-Vision-Language-Model/data/complex_reasoning_77k.json" 6 | coco_image_path_label_file2: "/home/u2023111315/Basic-Vision-Language-Model/data/detail_23k.json" 7 | 8 | output_image_map: "/home/u2023111315/Basic-Vision-Language-Model/data/image_map_b.json" 9 | output_captions: "/home/u2023111315/Basic-Vision-Language-Model/data/captions_b.json" 10 | -------------------------------------------------------------------------------- /dataprocess/process_image.py: -------------------------------------------------------------------------------- 1 | import json 2 | import numpy as np 3 | import yaml 4 | 5 | def read_json(json_path): 6 | with open(json_path, 'r', encoding="utf-8") as f: 7 | data = json.load(f) 8 | return data 9 | 10 | def process_ali(caption_data, image_path, start_ID=0): 11 | image_map = {} 12 | captions = {} 13 | for idx, D in enumerate(caption_data): 14 | if idx < 5: 15 | print(f"Processing element {idx}: {D}") 16 | 17 | image_name = D.get("image_id", D.get("image")) 18 | caption = D.get("caption", D.get("conversations")) 19 | 20 | if isinstance(caption, list): 21 | selected_caption = np.random.choice(caption) 22 | if np.random.random() > 0.5: 23 | captions[idx + start_ID] = {"q": "这幅图像描述了什么?<|extra_0|>", "a": selected_caption} 24 | else: 25 | captions[idx + start_ID] = {"q": "<|extra_0|>这幅图像中有什么?", "a": selected_caption} 26 | else: 27 | if np.random.random() > 0.5: 28 | captions[idx + start_ID] = {"q": "这幅图像描述了什么?<|extra_0|>", "a": caption} 29 | else: 30 | captions[idx + start_ID] = {"q": "<|extra_0|>这幅图像中有什么?", "a": caption} 31 | 32 | image_map[image_name] = {"image_file": image_name, "ID": idx + start_ID, "path": image_path} 33 | 34 | return image_map, captions 35 | 36 | def process_coco(caption_data, image_path, start_ID=0): 37 | image_map = {} 38 | captions = {} 39 | for idx, D in enumerate(caption_data): 40 | image_name = D["image"] 41 | caption = D["conversations"] 42 | 43 | answer = caption[1]["value"] 44 | question = caption[0]["value"].replace("", "<|extra_0|>").replace("\n", "") 45 | image_map[image_name] = {"image_file": image_name, "ID": idx + start_ID, "path": image_path} 46 | captions[idx + start_ID] = {"q": question, "a": [answer]} 47 | 48 | return image_map, captions 49 | 50 | def main(): 51 | config = yaml.load(open("config.yaml", 'r', encoding="utf-8"), Loader=yaml.FullLoader) 52 | 53 | caption_data_ali = read_json(config["ali_image_path_label_file"]) 54 | caption_data_chat1 = read_json(config["coco_image_path_label_file1"]) 55 | caption_data_chat2 = read_json(config["coco_image_path_label_file2"]) 56 | 57 | print("First few elements of caption_data_ali:") 58 | for i in range(min(5, len(caption_data_ali))): 59 | print(caption_data_ali[i]) 60 | 61 | image_map_ali, captions_ali = process_ali(caption_data_ali, config["ali_image_path"]) 62 | print(f"Processed {len(image_map_ali)} images from ali dataset.") 63 | 64 | image_map_chat1, captions_chat1 = process_coco(caption_data_chat1, config["coco_image_path"], len(image_map_ali)) 65 | 
print(f"Processed {len(image_map_chat1)} images from coco dataset file 1.") 66 | 67 | image_map_chat2, captions_chat2 = process_coco(caption_data_chat2, config["coco_image_path"], len(image_map_ali) + len(image_map_chat1)) 68 | print(f"Processed {len(image_map_chat2)} images from coco dataset file 2.") 69 | 70 | image_map = {**image_map_ali, **image_map_chat1, **image_map_chat2} 71 | captions = {**captions_ali, **captions_chat1, **captions_chat2} 72 | 73 | print(f"Total images processed: {len(image_map)}") 74 | 75 | with open(config["output_image_map"], 'w', encoding="utf-8") as f: 76 | json.dump(image_map, f, ensure_ascii=False, indent=4) 77 | 78 | with open(config["output_captions"], 'w', encoding="utf-8") as f: 79 | json.dump(captions, f, ensure_ascii=False, indent=4) 80 | 81 | if __name__ == "__main__": 82 | main() -------------------------------------------------------------------------------- /dataset/image_caption_dataset.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import json 3 | from torch.utils.data import Dataset, DataLoader 4 | from torchvision import transforms 5 | from transformers import CLIPProcessor, SiglipProcessor 6 | from PIL import Image 7 | import numpy as np 8 | from tqdm import tqdm 9 | 10 | from qwen.qwen_generation_utils import make_context 11 | 12 | def readJson(filePath): 13 | with open(filePath, 'r', encoding="utf-8") as f: 14 | data = json.load(f) 15 | return data 16 | 17 | def data_collate(example, tokenizer, black_token_length): 18 | images = [] 19 | captions = [] 20 | labels = [] 21 | max_length = np.max([len(e[1]) for e in example]) + 1 22 | for e in example: 23 | img, caption, L = e 24 | L = L + 1 25 | caption = caption + [tokenizer.eod_id] 26 | images.append(img) 27 | caption_labels = [-100]*(black_token_length + (len(caption)-L) - 1) + caption[-L:] + [-100]*(max_length - len(caption)) 28 | captions.append(torch.tensor(caption + [tokenizer.eod_id]*(max_length - len(caption)))) 29 | labels.append(torch.tensor(caption_labels)) 30 | 31 | labels = torch.stack(labels, dim=0).long() 32 | captions = torch.stack(captions, dim=0).long() 33 | images = torch.stack(images, dim=0).to(torch.float16) 34 | 35 | return {"images": images, "input_ids": captions, "labels": labels} 36 | 37 | class ImageCaptionDataset(Dataset): 38 | def __init__(self, tokenizer, image_map_file, captions_file, Vconfig, return_caption_num=1, max_train_data_item=None): 39 | super().__init__() 40 | self.tokenizer = tokenizer 41 | self.return_caption_num = return_caption_num 42 | self.max_train_data_item = max_train_data_item 43 | 44 | mean = [0.485, 0.456, 0.406] # RGB 45 | std = [0.229, 0.224, 0.225] # RGB 46 | 47 | self.tran = transforms.Compose([ 48 | transforms.ToTensor(), 49 | transforms.Normalize(mean, std), 50 | transforms.Resize([224, 224]) 51 | ]) 52 | 53 | self.image_map = readJson(image_map_file) 54 | self.captions = readJson(captions_file) 55 | 56 | # self.image_processor = CLIPProcessor.from_pretrained(Vconfig.model_path) 57 | 58 | self.image_processor = SiglipProcessor.from_pretrained(Vconfig.model_path) 59 | 60 | self.readImage() # 一次性读入内存 61 | 62 | def readImage(self): 63 | self.data_list = [] 64 | number = 0 65 | image_map_keys = list(self.image_map.keys()) 66 | np.random.shuffle(image_map_keys) 67 | for IM in tqdm(image_map_keys): 68 | number += 1 69 | if self.max_train_data_item is not None and number > self.max_train_data_item: 70 | return 71 | try: 72 | image_file_path = self.image_map[IM]["path"] + 
self.image_map[IM]["image_file"] 73 | self.data_list.append([image_file_path, self.image_map[IM]["ID"]]) 74 | except Exception as e: 75 | print(f"Error loading image {IM}: {e}") 76 | continue 77 | 78 | # Debug information 79 | print(f"Total images loaded: {len(self.data_list)}") 80 | 81 | def __getitem__(self, index): 82 | image_path, ID = self.data_list[index] 83 | try: 84 | image = Image.open(image_path).convert("RGB") 85 | image = self.image_processor(images=image, return_tensors="pt")["pixel_values"][0] 86 | except Exception as e: 87 | print(f"Error processing image {image_path}: {e}") 88 | raise 89 | 90 | captions_data = self.captions.get(str(ID), {}) 91 | captions = captions_data.get("a", []) 92 | 93 | # Ensure captions is a list 94 | if isinstance(captions, str): 95 | captions = [captions] 96 | elif isinstance(captions, dict): 97 | # Handle the case where captions is a dictionary 98 | captions = [captions.get("value", "")] 99 | 100 | if not isinstance(captions, list): 101 | raise ValueError(f"Captions for ID {ID} are not in the expected format: {captions}") 102 | 103 | if not captions: 104 | raise ValueError(f"No captions found for ID {ID}") 105 | 106 | prompt = captions_data.get("q", "") 107 | 108 | # Debug information 109 | # print(f"Captions for ID {ID}: {captions}") 110 | 111 | select_idx = np.random.choice(len(captions)) 112 | 113 | # More debug information 114 | # print(f"Selected index: {select_idx}, Selected caption: {captions[select_idx]}") 115 | 116 | messages = [{"role": "system", "content": ""}, {"role": "user", "content": prompt}] 117 | 118 | prompt_raw, context_tokens = make_context( 119 | self.tokenizer, 120 | prompt, 121 | history=[], 122 | system="你是一位图像理解助手。" 123 | ) 124 | 125 | choice_captions = self.tokenizer(prompt_raw)["input_ids"] 126 | answer = self.tokenizer(captions[select_idx])["input_ids"] 127 | choice_captions = choice_captions + answer 128 | 129 | return image, choice_captions, len(answer) 130 | 131 | def __len__(self): 132 | return len(self.data_list) 133 | 134 | -------------------------------------------------------------------------------- /minicpm/Mminicpm.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | from typing import Optional, Tuple, Union, List, Dict 4 | from torch.nn import CrossEntropyLoss 5 | from transformers.modeling_outputs import CausalLMOutputWithPast 6 | 7 | from .modeling_minicpm import MiniCPMForCausalLM, MiniCPMModel, BaseModelOutputWithPast, logger 8 | 9 | class MMiniCPMModel(MiniCPMModel): 10 | def __init__(self, config, otherConfig): 11 | super().__init__(config) 12 | self.otherConfig = otherConfig 13 | 14 | def forward( 15 | self, 16 | input_ids: Optional[torch.LongTensor] = None, 17 | images: Optional[torch.Tensor] = None, 18 | attention_mask: Optional[torch.Tensor] = None, 19 | position_ids: Optional[torch.LongTensor] = None, 20 | past_key_values: Optional[List[torch.FloatTensor]] = None, 21 | inputs_embeds: Optional[torch.FloatTensor] = None, 22 | use_cache: Optional[bool] = None, 23 | output_attentions: Optional[bool] = None, 24 | output_hidden_states: Optional[bool] = None, 25 | return_dict: Optional[bool] = None, 26 | ) -> Union[Tuple, BaseModelOutputWithPast]: 27 | device = input_ids.device if input_ids is not None else inputs_embeds.device 28 | first_step = False 29 | if images is not None and past_key_values is None: 30 | image_index = torch.where(input_ids == self.otherConfig["replace_token_id"])[1] 31 | new_input_ids = [] 32 | for 
b_idx, img_idx in enumerate(image_index): 33 | new_input_ids.append(torch.cat([input_ids[b_idx][:img_idx], input_ids[b_idx][img_idx+1:]], dim=0)) 34 | input_ids = torch.stack(new_input_ids, dim=0).to(input_ids) 35 | first_step = True 36 | 37 | output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions 38 | output_hidden_states = output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states 39 | use_cache = use_cache if use_cache is not None else self.config.use_cache 40 | return_dict = return_dict if return_dict is not None else self.config.use_return_dict 41 | 42 | if input_ids is not None and inputs_embeds is not None: 43 | raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time") 44 | elif input_ids is not None: 45 | batch_size, seq_length = input_ids.shape[:2] 46 | elif inputs_embeds is not None: 47 | batch_size, seq_length = inputs_embeds.shape[:2] 48 | else: 49 | raise ValueError("You have to specify either input_ids or inputs_embeds") 50 | 51 | if self.gradient_checkpointing and self.training: 52 | if use_cache: 53 | logger.warning_once("`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...") 54 | use_cache = False 55 | 56 | past_key_values_length = 0 57 | if use_cache: 58 | use_legacy_cache = not isinstance(past_key_values, Cache) 59 | if use_legacy_cache: 60 | past_key_values = DynamicCache.from_legacy_cache(past_key_values) 61 | past_key_values_length = past_key_values.get_usable_length(seq_length) 62 | 63 | if position_ids is None: 64 | device = input_ids.device if input_ids is not None else inputs_embeds.device 65 | position_ids = torch.arange(past_key_values_length, seq_length + past_key_values_length, dtype=torch.long, device=device) 66 | position_ids = position_ids.unsqueeze(0) 67 | 68 | if inputs_embeds is None: 69 | inputs_embeds = self.embed_tokens(input_ids) * self.config.scale_emb 70 | 71 | if self._use_flash_attention_2: 72 | attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None 73 | elif self._use_sdpa and not output_attentions: 74 | attention_mask = _prepare_4d_causal_attention_mask_for_sdpa(attention_mask, (batch_size, seq_length), inputs_embeds, past_key_values_length) 75 | else: 76 | attention_mask = _prepare_4d_causal_attention_mask(attention_mask, (batch_size, seq_length), inputs_embeds, past_key_values_length) 77 | 78 | hidden_states = inputs_embeds 79 | 80 | if images is not None and first_step: 81 | new_hidden_states = [] 82 | for b_idx, img_idx in enumerate(image_index): 83 | new_hidden_states.append(torch.cat([hidden_states[b_idx][:img_idx], images[b_idx], hidden_states[b_idx][img_idx:]], dim=0)) 84 | hidden_states = torch.stack(new_hidden_states, dim=0).to(hidden_states) 85 | 86 | all_hidden_states = () if output_hidden_states else None 87 | all_self_attns = () if output_attentions else None 88 | next_decoder_cache = None 89 | 90 | for decoder_layer in self.layers: 91 | if output_hidden_states: 92 | all_hidden_states += (hidden_states,) 93 | 94 | if self.gradient_checkpointing and self.training: 95 | layer_outputs = self._gradient_checkpointing_func( 96 | decoder_layer.__call__, 97 | hidden_states, 98 | attention_mask, 99 | position_ids, 100 | past_key_values, 101 | output_attentions, 102 | use_cache, 103 | ) 104 | else: 105 | layer_outputs = decoder_layer( 106 | hidden_states, 107 | attention_mask=attention_mask, 108 | position_ids=position_ids, 109 | 
past_key_value=past_key_values, 110 | output_attentions=output_attentions, 111 | use_cache=use_cache, 112 | ) 113 | 114 | hidden_states = layer_outputs[0] 115 | 116 | if use_cache: 117 | next_decoder_cache = layer_outputs[2 if output_attentions else 1] 118 | 119 | if output_attentions: 120 | all_self_attns += (layer_outputs[1],) 121 | 122 | hidden_states = self.norm(hidden_states) 123 | 124 | if output_hidden_states: 125 | all_hidden_states += (hidden_states,) 126 | 127 | next_cache = None 128 | if use_cache: 129 | next_cache = next_decoder_cache.to_legacy_cache() if use_legacy_cache else next_decoder_cache 130 | if not return_dict: 131 | return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None) 132 | return BaseModelOutputWithPast( 133 | last_hidden_state=hidden_states, 134 | past_key_values=next_cache, 135 | hidden_states=all_hidden_states, 136 | attentions=all_self_attns, 137 | ) 138 | 139 | class MMiniCPMLMHeadModel(MiniCPMForCausalLM): 140 | def __init__(self, config, otherConfig): 141 | super().__init__(config) 142 | self.model = MMiniCPMModel(config, otherConfig) 143 | 144 | def prepare_inputs_for_generation( 145 | self, input_ids, past_key_values=None, inputs_embeds=None, **kwargs 146 | ): 147 | if past_key_values: 148 | input_ids = input_ids[:, -1].unsqueeze(-1) 149 | 150 | if input_ids.size(0) == 1: 151 | attention_mask = None 152 | else: 153 | attention_mask = kwargs.get("attention_mask", None) 154 | 155 | if inputs_embeds is not None and past_key_values is None: 156 | model_inputs = {"inputs_embeds": inputs_embeds} 157 | else: 158 | model_inputs = {"input_ids": input_ids} 159 | 160 | model_inputs.update( 161 | { 162 | "past_key_values": past_key_values, 163 | "use_cache": kwargs.get("use_cache"), 164 | "attention_mask": attention_mask, 165 | "images": kwargs.get("images") 166 | } 167 | ) 168 | return model_inputs 169 | 170 | def forward( 171 | self, 172 | input_ids: Optional[torch.LongTensor] = None, 173 | images: Optional[torch.Tensor] = None, 174 | attention_mask: Optional[torch.Tensor] = None, 175 | position_ids: Optional[torch.LongTensor] = None, 176 | past_key_values: Optional[List[torch.FloatTensor]] = None, 177 | inputs_embeds: Optional[torch.FloatTensor] = None, 178 | labels: Optional[torch.LongTensor] = None, 179 | use_cache: Optional[bool] = None, 180 | output_attentions: Optional[bool] = None, 181 | output_hidden_states: Optional[bool] = None, 182 | return_dict: Optional[bool] = None, 183 | ) -> Union[Tuple, CausalLMOutputWithPast]: 184 | return_dict = return_dict if return_dict is not None else self.config.use_return_dict 185 | transformer_outputs = self.model( 186 | input_ids=input_ids, 187 | images=images, 188 | attention_mask=attention_mask, 189 | position_ids=position_ids, 190 | past_key_values=past_key_values, 191 | inputs_embeds=inputs_embeds, 192 | use_cache=use_cache, 193 | output_attentions=output_attentions, 194 | output_hidden_states=output_hidden_states, 195 | return_dict=return_dict, 196 | ) 197 | hidden_states = transformer_outputs[0] 198 | lm_logits = self.lm_head(hidden_states) 199 | 200 | loss = None 201 | if labels is not None: 202 | labels = labels.to(lm_logits.device) 203 | shift_logits = lm_logits[..., :-1, :].contiguous() 204 | shift_labels = labels[..., 1:].contiguous() 205 | loss_fct = CrossEntropyLoss() 206 | loss = loss_fct( 207 | shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1) 208 | ) 209 | 210 | if not return_dict: 211 | output = (lm_logits,) + transformer_outputs[1:] 212 
| return ((loss,) + output) if loss is not None else output 213 | 214 | return CausalLMOutputWithPast( 215 | loss=loss, 216 | logits=lm_logits, 217 | past_key_values=transformer_outputs.past_key_values, 218 | hidden_states=transformer_outputs.hidden_states, 219 | attentions=transformer_outputs.attentions, 220 | ) -------------------------------------------------------------------------------- /minicpm/configuration_minicpm.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved. 3 | # 4 | # This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX 5 | # and OPT implementations in this library. It has been modified from its 6 | # original forms to accommodate minor architectural differences compared 7 | # to GPT-NeoX and OPT used by the Meta AI team that trained the model. 8 | # 9 | # Licensed under the Apache License, Version 2.0 (the "License"); 10 | # you may not use this file except in compliance with the License. 11 | # You may obtain a copy of the License at 12 | # 13 | # http://www.apache.org/licenses/LICENSE-2.0 14 | # 15 | # Unless required by applicable law or agreed to in writing, software 16 | # distributed under the License is distributed on an "AS IS" BASIS, 17 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 18 | # See the License for the specific language governing permissions and 19 | # limitations under the License. 20 | """ MiniCPM model configuration""" 21 | 22 | from transformers.configuration_utils import PretrainedConfig 23 | from transformers.utils import logging 24 | 25 | 26 | logger = logging.get_logger(__name__) 27 | 28 | MINICPM_PRETRAINED_CONFIG_ARCHIVE_MAP = {} 29 | 30 | 31 | class MiniCPMConfig(PretrainedConfig): 32 | r""" 33 | This is the configuration class to store the configuration of a [`MiniCPMModel`]. It is used to instantiate an MiniCPM 34 | model according to the specified arguments, defining the model architecture. Instantiating a configuration with the 35 | defaults will yield a similar configuration to that of the MiniCPM-7B. 36 | 37 | Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the 38 | documentation from [`PretrainedConfig`] for more information. 39 | 40 | 41 | Args: 42 | vocab_size (`int`, *optional*, defaults to 32000): 43 | Vocabulary size of the MiniCPM model. Defines the number of different tokens that can be represented by the 44 | `inputs_ids` passed when calling [`MiniCPMModel`] 45 | hidden_size (`int`, *optional*, defaults to 4096): 46 | Dimension of the hidden representations. 47 | intermediate_size (`int`, *optional*, defaults to 11008): 48 | Dimension of the MLP representations. 49 | num_hidden_layers (`int`, *optional*, defaults to 32): 50 | Number of hidden layers in the Transformer decoder. 51 | num_attention_heads (`int`, *optional*, defaults to 32): 52 | Number of attention heads for each attention layer in the Transformer decoder. 53 | num_key_value_heads (`int`, *optional*): 54 | This is the number of key_value heads that should be used to implement Grouped Query Attention. If 55 | `num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA), if 56 | `num_key_value_heads=1 the model will use Multi Query Attention (MQA) otherwise GQA is used. 
When 57 | converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed 58 | by meanpooling all the original heads within that group. For more details checkout [this 59 | paper](https://arxiv.org/pdf/2305.13245.pdf). If it is not specified, will default to 60 | `num_attention_heads`. 61 | hidden_act (`str` or `function`, *optional*, defaults to `"silu"`): 62 | The non-linear activation function (function or string) in the decoder. 63 | max_position_embeddings (`int`, *optional*, defaults to 2048): 64 | The maximum sequence length that this model might ever be used with. MiniCPM 1 supports up to 2048 tokens, 65 | MiniCPM 2 up to 4096, CodeMiniCPM up to 16384. 66 | initializer_range (`float`, *optional*, defaults to 0.02): 67 | The standard deviation of the truncated_normal_initializer for initializing all weight matrices. 68 | rms_norm_eps (`float`, *optional*, defaults to 1e-06): 69 | The epsilon used by the rms normalization layers. 70 | use_cache (`bool`, *optional*, defaults to `True`): 71 | Whether or not the model should return the last key/values attentions (not used by all models). Only 72 | relevant if `config.is_decoder=True`. 73 | pad_token_id (`int`, *optional*): 74 | Padding token id. 75 | bos_token_id (`int`, *optional*, defaults to 1): 76 | Beginning of stream token id. 77 | eos_token_id (`int`, *optional*, defaults to 2): 78 | End of stream token id. 79 | pretraining_tp (`int`, *optional*, defaults to 1): 80 | Experimental feature. Tensor parallelism rank used during pretraining. Please refer to [this 81 | document](https://huggingface.co/docs/transformers/parallelism) to understand more about it. This value is 82 | necessary to ensure exact reproducibility of the pretraining results. Please refer to [this 83 | issue](https://github.com/pytorch/pytorch/issues/76232). 84 | tie_word_embeddings (`bool`, *optional*, defaults to `False`): 85 | Whether to tie weight embeddings 86 | rope_theta (`float`, *optional*, defaults to 10000.0): 87 | The base period of the RoPE embeddings. 88 | rope_scaling (`Dict`, *optional*): 89 | Dictionary containing the scaling configuration for the RoPE embeddings. Currently supports two scaling 90 | strategies: linear and dynamic. Their scaling factor must be a float greater than 1. The expected format is 91 | `{"type": strategy name, "factor": scaling factor}`. When using this flag, don't update 92 | `max_position_embeddings` to the expected new maximum. See the following thread for more information on how 93 | these scaling strategies behave: 94 | https://www.reddit.com/r/LocalMiniCPM/comments/14mrgpr/dynamically_scaled_rope_further_increases/. This is an 95 | experimental feature, subject to breaking API changes in future versions. 96 | attention_bias (`bool`, defaults to `False`, *optional*, defaults to `False`): 97 | Whether to use a bias in the query, key, value and output projection layers during self-attention. 98 | attention_dropout (`float`, *optional*, defaults to 0.0): 99 | The dropout ratio for the attention probabilities. 
100 | 101 | ```python 102 | >>> from transformers import MiniCPMModel, MiniCPMConfig 103 | 104 | >>> # Initializing a MiniCPM minicpm-7b style configuration 105 | >>> configuration = MiniCPMConfig() 106 | 107 | >>> # Initializing a model from the minicpm-7b style configuration 108 | >>> model = MiniCPMModel(configuration) 109 | 110 | >>> # Accessing the model configuration 111 | >>> configuration = model.config 112 | ```""" 113 | 114 | model_type = "minicpm" 115 | keys_to_ignore_at_inference = ["past_key_values"] 116 | 117 | def __init__( 118 | self, 119 | vocab_size=32000, 120 | hidden_size=4096, 121 | intermediate_size=11008, 122 | num_hidden_layers=32, 123 | num_attention_heads=32, 124 | num_key_value_heads=None, 125 | hidden_act="silu", 126 | max_position_embeddings=2048, 127 | initializer_range=0.02, 128 | rms_norm_eps=1e-6, 129 | use_cache=True, 130 | pad_token_id=None, 131 | bos_token_id=1, 132 | eos_token_id=2, 133 | pretraining_tp=1, 134 | tie_word_embeddings=True, 135 | rope_theta=10000.0, 136 | rope_scaling=None, 137 | attention_bias=False, 138 | attention_dropout=0.0, 139 | scale_emb=1, 140 | dim_model_base=1, 141 | scale_depth=1, 142 | **kwargs, 143 | ): 144 | self.vocab_size = vocab_size 145 | self.max_position_embeddings = max_position_embeddings 146 | self.hidden_size = hidden_size 147 | self.intermediate_size = intermediate_size 148 | self.num_hidden_layers = num_hidden_layers 149 | self.num_attention_heads = num_attention_heads 150 | 151 | # for backward compatibility 152 | if num_key_value_heads is None: 153 | num_key_value_heads = num_attention_heads 154 | 155 | self.num_key_value_heads = num_key_value_heads 156 | self.hidden_act = hidden_act 157 | self.initializer_range = initializer_range 158 | self.rms_norm_eps = rms_norm_eps 159 | self.pretraining_tp = pretraining_tp 160 | self.use_cache = use_cache 161 | self.rope_theta = rope_theta 162 | self.rope_scaling = rope_scaling 163 | self._rope_scaling_validation() 164 | self.attention_bias = attention_bias 165 | self.attention_dropout = attention_dropout 166 | self.scale_emb = scale_emb 167 | self.dim_model_base = dim_model_base 168 | self.scale_depth = scale_depth 169 | 170 | super().__init__( 171 | pad_token_id=pad_token_id, 172 | bos_token_id=bos_token_id, 173 | eos_token_id=eos_token_id, 174 | tie_word_embeddings=tie_word_embeddings, 175 | **kwargs, 176 | ) 177 | try: 178 | import flash_attn 179 | self._attn_implementation = "flash_attention_2" 180 | except: 181 | pass 182 | 183 | def _rope_scaling_validation(self): 184 | """ 185 | Validate the `rope_scaling` configuration. 
186 | """ 187 | if self.rope_scaling is None: 188 | return 189 | 190 | if not isinstance(self.rope_scaling, dict) or len(self.rope_scaling) != 2: 191 | raise ValueError( 192 | "`rope_scaling` must be a dictionary with with two fields, `type` and `factor`, " 193 | f"got {self.rope_scaling}" 194 | ) 195 | rope_scaling_type = self.rope_scaling.get("type", None) 196 | rope_scaling_factor = self.rope_scaling.get("factor", None) 197 | if rope_scaling_type is None or rope_scaling_type not in ["linear", "dynamic"]: 198 | raise ValueError( 199 | f"`rope_scaling`'s type field must be one of ['linear', 'dynamic'], got {rope_scaling_type}" 200 | ) 201 | if rope_scaling_factor is None or not isinstance(rope_scaling_factor, float) or rope_scaling_factor <= 1.0: 202 | raise ValueError(f"`rope_scaling`'s factor field must be a float > 1, got {rope_scaling_factor}") -------------------------------------------------------------------------------- /model/model.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | from typing import Optional 4 | import os 5 | import sys 6 | sys.path.append("../") 7 | from transformers.modeling_outputs import CausalLMOutputWithPast 8 | from transformers import CLIPProcessor 9 | from dataclasses import dataclass, asdict 10 | from peft import get_peft_model, LoraConfig, TaskType, PeftModel 11 | 12 | # from visual.CLIP_VIT import visualModel 13 | from visual.SIGLIP_VIT import visualModel 14 | from qwen.Mqwen import MQWenLMHeadModel 15 | from minicpm.Mminicpm import MMiniCPMLMHeadModel 16 | 17 | @dataclass 18 | class LanguageConfig(): 19 | model_path: str 20 | torch_dtype: torch.dtype = torch.bfloat16 21 | trust_remote_code: bool = True 22 | 23 | @dataclass 24 | class VisualConfig(): 25 | model_path: str 26 | pretrained: bool = True 27 | 28 | 29 | @dataclass 30 | class MultiModalConfig(): 31 | replace_token_id: int 32 | # image_context_length: int = 256 33 | image_context_length: int = 728 34 | image_feature_hidden_size: int = 4096 35 | 36 | 37 | def make_lora(model, finetune_args): 38 | peft_config = LoraConfig( 39 | task_type=TaskType.CAUSAL_LM, 40 | inference_mode=False, 41 | r=finetune_args.lora_rank, 42 | lora_alpha=32, 43 | lora_dropout=finetune_args.lora_dropout, 44 | target_modules = finetune_args.target_modules.split('|') # 把model打印出来,找跟attention相关的模块 45 | ) 46 | 47 | model = get_peft_model(model, peft_config) 48 | 49 | return model 50 | 51 | class MMultiModal(nn.Module): 52 | def __init__(self, Lconfig: LanguageConfig, Vconfig: VisualConfig, MMconfig: MultiModalConfig, finetune_args = None, train = False, *args, **kwargs) -> None: 53 | super().__init__(*args, **kwargs) 54 | image_feature_length = MMconfig.image_context_length * MMconfig.image_feature_hidden_size 55 | 56 | self.LLM = MQWenLMHeadModel.from_pretrained(Lconfig.model_path, asdict(MMconfig), torch_dtype = Lconfig.torch_dtype, trust_remote_code = Lconfig.trust_remote_code) 57 | # self.LLM = MMiniCPMLMHeadModel.from_pretrained(Lconfig.model_path, asdict(MMconfig), torch_dtype = Lconfig.torch_dtype, trust_remote_code = Lconfig.trust_remote_code) 58 | 59 | if train: 60 | self.LLM.gradient_checkpointing_enable() 61 | self.LLM.enable_input_require_grads() 62 | 63 | self.LLM.config.image_feature_length = image_feature_length 64 | 65 | if train and finetune_args is not None: 66 | self.LLM = make_lora(self.LLM, finetune_args) 67 | 68 | assert MMconfig.image_feature_hidden_size == self.LLM.config.hidden_size 69 | 70 | self.visualModel = 
visualModel.from_pretrained(Vconfig.model_path).to(Lconfig.torch_dtype) 71 | 72 | Vhidden_dim = self.visualModel.vision_embed_dim 73 | Lhidden_dim = self.LLM.config.hidden_size 74 | 75 | self.make_feature_proj(Vhidden_dim, Lhidden_dim, Lconfig) 76 | 77 | self.MMconfig = MMconfig 78 | 79 | print(f"LLM dtype: {self.LLM.dtype}") 80 | print(f"Visual model dtype: {self.visualModel.dtype}") 81 | print(f"Feature projection dtype: {self.feature_proj[0].weight.dtype}") 82 | 83 | def make_feature_proj(self, Vhidden_dim, Lhidden_dim, Lconfig): 84 | self.feature_proj = nn.Sequential( 85 | nn.Linear(Vhidden_dim, Lhidden_dim, dtype=Lconfig.torch_dtype), 86 | nn.GELU(), 87 | nn.Linear(Lhidden_dim, Lhidden_dim, dtype=Lconfig.torch_dtype) 88 | ) 89 | 90 | for name, module in self.feature_proj.named_children(): 91 | if "Linear" in module._get_name(): 92 | module.weight.data.normal_(mean=0.0, std = 0.01) 93 | module.bias.data.zero_() 94 | 95 | def forward(self, image: torch.Tensor, input_ids: torch.LongTensor, labels: Optional[torch.LongTensor] = None): 96 | with torch.no_grad(): 97 | # 确保 image 的数据类型为 bfloat16 98 | image = image.to(dtype=torch.bfloat16) 99 | image_feature = self.visualModel.get_image_features(pixel_values=image)[:,1:, :] 100 | image_feature = image_feature.detach() 101 | 102 | image_feature = self.feature_proj(image_feature) 103 | 104 | out = self.LLM(input_ids, labels=labels, images=image_feature) 105 | 106 | loss1 = out.loss 107 | 108 | return CausalLMOutputWithPast( 109 | loss=loss1, 110 | logits=out.logits, 111 | past_key_values=out.past_key_values, 112 | hidden_states=out.hidden_states, 113 | attentions=out.attentions, 114 | ) 115 | 116 | def to(self, *args, **kwargs): 117 | return super().to(*args, **kwargs) 118 | 119 | def load(self, modelPath): 120 | self.LLM = PeftModel.from_pretrained(self.LLM, modelPath, inference_mode=True) 121 | other_params = torch.load(os.path.join(modelPath, "other_params.bin")) 122 | self.feature_proj.load_state_dict(other_params) 123 | 124 | @torch.no_grad() 125 | def generate(self, image: torch.Tensor, input_ids: torch.LongTensor): 126 | if image is None: 127 | image_feature = None 128 | else: 129 | image_feature=self.visualModel.get_image_features(pixel_values=image)[:,1:, :] 130 | image_feature = self.feature_proj(image_feature) 131 | 132 | input_ids = torch.tensor([input_ids]).long().to(self.LLM.device) 133 | 134 | out = self.LLM.generate(inputs = input_ids, images=image_feature)[:, len(input_ids[0]):-1] 135 | 136 | return out.long().cpu() -------------------------------------------------------------------------------- /qwen/Mqwen.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | from typing import Optional, Tuple, Union 4 | from torch.nn import CrossEntropyLoss 5 | from transformers.modeling_outputs import CausalLMOutputWithPast 6 | 7 | from .modeling_qwen import QWenLMHeadModel, QWenModel, BaseModelOutputWithPast, logger 8 | 9 | class MQWenModel(QWenModel): 10 | def __init__(self, config, otherConfig): 11 | super().__init__(config) 12 | 13 | self.otherConfig = otherConfig 14 | 15 | def forward( 16 | self, 17 | input_ids: Optional[torch.LongTensor] = None, 18 | images: Optional[torch.Tensor] = None, 19 | past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None, 20 | attention_mask: Optional[torch.FloatTensor] = None, 21 | token_type_ids: Optional[torch.LongTensor] = None, 22 | position_ids: Optional[torch.LongTensor] = None, 23 | head_mask: Optional[torch.FloatTensor] 
= None, 24 | inputs_embeds: Optional[torch.FloatTensor] = None, 25 | encoder_hidden_states: Optional[torch.Tensor] = None, 26 | encoder_attention_mask: Optional[torch.FloatTensor] = None, 27 | use_cache: Optional[bool] = None, 28 | output_attentions: Optional[bool] = None, 29 | output_hidden_states: Optional[bool] = None, 30 | return_dict: Optional[bool] = None, 31 | ): 32 | device = input_ids.device if input_ids is not None else inputs_embeds.device 33 | first_step = False 34 | if images is not None and past_key_values is None: 35 | image_index = torch.where(input_ids == self.otherConfig["replace_token_id"])[1] 36 | new_input_ids = [] 37 | for b_idx, img_idx in enumerate(image_index): 38 | new_input_ids.append(torch.cat([input_ids[b_idx][:img_idx], input_ids[b_idx][img_idx+1:]], dim = 0)) ############# concat image and text 39 | 40 | input_ids = torch.stack(new_input_ids, dim = 0).to(input_ids) 41 | first_step = True 42 | 43 | output_attentions = ( 44 | output_attentions 45 | if output_attentions is not None 46 | else self.config.output_attentions 47 | ) 48 | output_hidden_states = ( 49 | output_hidden_states 50 | if output_hidden_states is not None 51 | else self.config.output_hidden_states 52 | ) 53 | use_cache = use_cache if use_cache is not None else self.config.use_cache 54 | return_dict = ( 55 | return_dict if return_dict is not None else self.config.use_return_dict 56 | ) 57 | 58 | if input_ids is not None and inputs_embeds is not None: 59 | raise ValueError( 60 | "You cannot specify both input_ids and inputs_embeds at the same time" 61 | ) 62 | elif input_ids is not None: 63 | input_shape = input_ids.size() 64 | input_ids = input_ids.view(-1, input_shape[-1]).contiguous() 65 | batch_size = input_ids.shape[0] 66 | elif inputs_embeds is not None: 67 | input_shape = inputs_embeds.size()[:-1] 68 | batch_size = inputs_embeds.shape[0] 69 | else: 70 | raise ValueError("You have to specify either input_ids or inputs_embeds") 71 | 72 | if images is not None and first_step: 73 | input_shape = input_shape[0], input_shape[-1] + self.otherConfig["image_context_length"] ############## 74 | 75 | 76 | if token_type_ids is not None: 77 | token_type_ids = token_type_ids.view(-1, input_shape[-1]) 78 | if position_ids is not None: 79 | position_ids = position_ids.view(-1, input_shape[-1]) 80 | 81 | if past_key_values is None: 82 | past_length = 0 83 | past_key_values = tuple([None] * len(self.h)) 84 | else: 85 | if self.use_cache_quantization: 86 | past_length = past_key_values[0][0][0].size(2) 87 | else: 88 | past_length = past_key_values[0][0].size(-2) 89 | if position_ids is None: 90 | position_ids = torch.arange( 91 | past_length, 92 | input_shape[-1] + past_length, 93 | dtype=torch.long, 94 | device=device, 95 | ) 96 | position_ids = position_ids.unsqueeze(0).view(-1, input_shape[-1]) 97 | 98 | if attention_mask is not None: 99 | # image_feaute_length = self.otherConfig["image_context_length"]*self.otherConfig["image_feature_hidden_size"] 100 | # attention_mask_length = attention_mask.shape[-1] - image_feaute_length + self.otherConfig["image_context_length"] 101 | # attention_mask = torch.ones((batch_size, attention_mask_length), dtype=torch.long, device=device) 102 | if batch_size <= 0: 103 | raise ValueError("batch_size has to be defined and > 0") 104 | attention_mask = attention_mask.view(batch_size, -1) 105 | attention_mask = attention_mask[:, None, None, :] 106 | attention_mask = attention_mask.to(dtype=self.dtype) 107 | attention_mask = (1.0 - attention_mask) * torch.finfo(self.dtype).min 
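            # At this point the incoming 2-D padding mask of shape (batch, seq_len) has been
            # broadcast to (batch, 1, 1, seq_len) and turned into an additive mask: positions
            # to attend contribute 0, while padded positions contribute the most negative
            # finite value of self.dtype, so they are suppressed by the softmax in attention.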
108 | 109 | encoder_attention_mask = None 110 | head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers) 111 | 112 | if inputs_embeds is None: 113 | inputs_embeds = self.wte(input_ids) 114 | hidden_states = inputs_embeds 115 | 116 | if images is not None and first_step: 117 | 118 | new_hidden_states = [] 119 | for b_idx, img_idx in enumerate(image_index): 120 | new_hidden_states.append(torch.cat([hidden_states[b_idx][:img_idx], images[b_idx], hidden_states[b_idx][img_idx:]], dim = 0)) ############# concat image and text 121 | 122 | hidden_states = torch.stack(new_hidden_states, dim = 0).to(hidden_states) 123 | 124 | 125 | kv_seq_len = hidden_states.size()[1] 126 | if past_key_values[0] is not None: 127 | # past key values[0][0] shape: bs * seq_len * head_num * dim 128 | if self.use_cache_quantization: 129 | kv_seq_len += past_key_values[0][0][0].shape[2] 130 | else: 131 | kv_seq_len += past_key_values[0][0].shape[1] 132 | 133 | if self.training or not self.use_dynamic_ntk: 134 | ntk_alpha_list = [1.0] 135 | elif kv_seq_len != hidden_states.size()[1]: 136 | ntk_alpha_list = self.rotary_emb._ntk_alpha_cached_list 137 | else: 138 | ntk_alpha_list = [] 139 | if attention_mask is not None and kv_seq_len > self.seq_length: 140 | true_seq_lens = attention_mask.squeeze(1).squeeze(1).eq(0).sum(dim=-1, dtype=torch.int32) 141 | for i in range(hidden_states.size()[0]): 142 | true_seq_len = true_seq_lens[i].item() 143 | ntk_alpha = self.get_ntk_alpha(true_seq_len) 144 | ntk_alpha_list.append(ntk_alpha) 145 | else: 146 | ntk_alpha = self.get_ntk_alpha(kv_seq_len) 147 | ntk_alpha_list.append(ntk_alpha) 148 | self.rotary_emb._ntk_alpha_cached_list = ntk_alpha_list 149 | rotary_pos_emb_list = [ 150 | self.rotary_emb(kv_seq_len, ntk_alpha=ntk_alpha) for ntk_alpha in ntk_alpha_list 151 | ] 152 | 153 | hidden_states = self.drop(hidden_states) 154 | 155 | # exit() 156 | output_shape = input_shape + (hidden_states.size(-1),) 157 | 158 | if self.gradient_checkpointing and self.training: 159 | if use_cache: 160 | logger.warning_once( 161 | "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." 
162 | ) 163 | use_cache = False 164 | 165 | presents = () if use_cache else None 166 | all_self_attentions = () if output_attentions else None 167 | all_hidden_states = () if output_hidden_states else None 168 | for i, (block, layer_past) in enumerate(zip(self.h, past_key_values)): 169 | 170 | if output_hidden_states: 171 | all_hidden_states = all_hidden_states + (hidden_states,) 172 | 173 | if self.gradient_checkpointing and self.training: 174 | 175 | def create_custom_forward(module): 176 | def custom_forward(*inputs): 177 | # None for past_key_value 178 | return module(*inputs, use_cache, output_attentions) 179 | 180 | return custom_forward 181 | 182 | outputs = torch.utils.checkpoint.checkpoint( 183 | create_custom_forward(block), 184 | hidden_states, 185 | rotary_pos_emb_list, 186 | None, 187 | attention_mask, 188 | head_mask[i], 189 | encoder_hidden_states, 190 | encoder_attention_mask, 191 | ) 192 | else: 193 | outputs = block( 194 | hidden_states, 195 | layer_past=layer_past, 196 | rotary_pos_emb_list=rotary_pos_emb_list, 197 | attention_mask=attention_mask, 198 | head_mask=head_mask[i], 199 | encoder_hidden_states=encoder_hidden_states, 200 | encoder_attention_mask=encoder_attention_mask, 201 | use_cache=use_cache, 202 | output_attentions=output_attentions, 203 | ) 204 | 205 | hidden_states = outputs[0] 206 | if use_cache is True: 207 | presents = presents + (outputs[1],) 208 | 209 | if output_attentions: 210 | all_self_attentions = all_self_attentions + (outputs[2 if use_cache else 1],) 211 | 212 | hidden_states = self.ln_f(hidden_states) 213 | hidden_states = hidden_states.view(output_shape) 214 | # Add last hidden state 215 | if output_hidden_states: 216 | all_hidden_states = all_hidden_states + (hidden_states,) 217 | 218 | if not return_dict: 219 | return tuple( 220 | v for v in [hidden_states, presents, all_hidden_states] if v is not None 221 | ) 222 | 223 | return BaseModelOutputWithPast( 224 | last_hidden_state=hidden_states, 225 | past_key_values=presents, 226 | hidden_states=all_hidden_states, 227 | attentions=all_self_attentions, 228 | ) 229 | 230 | class MQWenLMHeadModel(QWenLMHeadModel): 231 | def __init__(self, config, otherConfig): 232 | super().__init__(config) 233 | 234 | self.transformer = MQWenModel(config, otherConfig) 235 | 236 | if config.bf16: 237 | self.transformer.bfloat16() 238 | 239 | if config.fp16: 240 | self.transformer.half() 241 | 242 | def prepare_inputs_for_generation( 243 | self, input_ids, past_key_values=None, inputs_embeds=None, **kwargs 244 | ): 245 | if past_key_values: 246 | input_ids = input_ids[:, -1].unsqueeze(-1) 247 | 248 | if input_ids.size(0) == 1: 249 | attention_mask = None 250 | else: 251 | attention_mask = kwargs.get("attention_mask", None) 252 | 253 | if inputs_embeds is not None and past_key_values is None: 254 | model_inputs = {"inputs_embeds": inputs_embeds} 255 | else: 256 | model_inputs = {"input_ids": input_ids} 257 | 258 | model_inputs.update( 259 | { 260 | "past_key_values": past_key_values, 261 | "use_cache": kwargs.get("use_cache"), 262 | "attention_mask": attention_mask, 263 | "images": kwargs.get("images") 264 | } 265 | ) 266 | return model_inputs 267 | 268 | def forward( 269 | self, 270 | input_ids: Optional[torch.LongTensor] = None, 271 | images: Optional[torch.Tensor] = None, 272 | past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None, 273 | attention_mask: Optional[torch.FloatTensor] = None, 274 | token_type_ids: Optional[torch.LongTensor] = None, 275 | position_ids: Optional[torch.LongTensor] = None, 276 | 
head_mask: Optional[torch.FloatTensor] = None,
277 |         inputs_embeds: Optional[torch.FloatTensor] = None,
278 |         encoder_hidden_states: Optional[torch.Tensor] = None,
279 |         encoder_attention_mask: Optional[torch.FloatTensor] = None,
280 |         labels: Optional[torch.LongTensor] = None,
281 |         use_cache: Optional[bool] = None,
282 |         output_attentions: Optional[bool] = None,
283 |         output_hidden_states: Optional[bool] = None,
284 |         return_dict: Optional[bool] = None,
285 |     ) -> Union[Tuple, CausalLMOutputWithPast]:
286 | 
287 |         return_dict = (
288 |             return_dict if return_dict is not None else self.config.use_return_dict
289 |         )
290 |         transformer_outputs = self.transformer(
291 |             input_ids,
292 |             images=images,
293 |             past_key_values=past_key_values,
294 |             attention_mask=attention_mask,
295 |             token_type_ids=token_type_ids,
296 |             position_ids=position_ids,
297 |             head_mask=head_mask,
298 |             inputs_embeds=inputs_embeds,
299 |             encoder_hidden_states=encoder_hidden_states,
300 |             encoder_attention_mask=encoder_attention_mask,
301 |             use_cache=use_cache,
302 |             output_attentions=output_attentions,
303 |             output_hidden_states=output_hidden_states,
304 |             return_dict=return_dict,
305 |         )
306 |         hidden_states = transformer_outputs[0]
307 | 
308 |         lm_logits = self.lm_head(hidden_states)
309 | 
310 |         loss = None
311 |         if labels is not None:
312 |             labels = labels.to(lm_logits.device)
313 |             shift_logits = lm_logits[..., :-1, :].contiguous()
314 |             shift_labels = labels[..., 1:].contiguous()
315 |             loss_fct = CrossEntropyLoss()
316 |             loss = loss_fct(
317 |                 shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1)
318 |             )
319 | 
320 |         if not return_dict:
321 |             output = (lm_logits,) + transformer_outputs[1:]
322 |             return ((loss,) + output) if loss is not None else output
323 | 
324 |         return CausalLMOutputWithPast(
325 |             loss=loss,
326 |             logits=lm_logits,
327 |             past_key_values=transformer_outputs.past_key_values,
328 |             hidden_states=transformer_outputs.hidden_states,
329 |             attentions=transformer_outputs.attentions,
330 |         )
331 | 
332 | 
333 | def main():
334 |     MQ = MQWenLMHeadModel.from_pretrained("F:/huggingface_model/qwen/Qwen-1_8B/", torch_dtype=torch.bfloat16, trust_remote_code=True)
335 | 
336 | if __name__ == "__main__":
337 |     main()
338 | 
--------------------------------------------------------------------------------
/qwen/cache_autogptq_cuda_256.cpp:
--------------------------------------------------------------------------------
1 | #include <torch/all.h>
2 | #include <torch/python.h>
3 | #include <c10/cuda/CUDAGuard.h>
4 | 
5 | // adapted from https://github.com/PanQiWei/AutoGPTQ/blob/main/autogptq_extension/cuda_256/autogptq_cuda_256.cpp
6 | void vecquant8matmul_cuda(
7 |   torch::Tensor vec, torch::Tensor mat, torch::Tensor mul,
8 |   torch::Tensor scales, torch::Tensor zeros,
9 |   torch::Tensor g_idx
10 | );
11 | 
12 | void vecquant8matmul(
13 |   torch::Tensor vec, torch::Tensor mat, torch::Tensor mul,
14 |   torch::Tensor scales, torch::Tensor zeros,
15 |   torch::Tensor g_idx
16 | ) {
17 |   const at::cuda::OptionalCUDAGuard device_guard(device_of(vec));
18 |   vecquant8matmul_cuda(vec, mat, mul, scales, zeros, g_idx);
19 | }
20 | 
21 | void vecquant8matmul_batched_cuda(
22 |   torch::Tensor vec, torch::Tensor mat, torch::Tensor mul,
23 |   torch::Tensor scales, torch::Tensor zeros
24 | );
25 | 
26 | void vecquant8matmul_batched(
27 |   torch::Tensor vec, torch::Tensor mat, torch::Tensor mul,
28 |   torch::Tensor scales, torch::Tensor zeros
29 | ) {
30 |   const at::cuda::OptionalCUDAGuard device_guard(device_of(vec));
31 |   vecquant8matmul_batched_cuda(vec, mat, mul, scales, zeros);
32 | 
} 33 | 34 | void vecquant8matmul_batched_column_compression_cuda( 35 | torch::Tensor vec, torch::Tensor mat, torch::Tensor mul, 36 | torch::Tensor scales, torch::Tensor zeros 37 | ); 38 | 39 | void vecquant8matmul_batched_column_compression( 40 | torch::Tensor vec, torch::Tensor mat, torch::Tensor mul, 41 | torch::Tensor scales, torch::Tensor zeros 42 | ) { 43 | const at::cuda::OptionalCUDAGuard device_guard(device_of(vec)); 44 | vecquant8matmul_batched_column_compression_cuda(vec, mat, mul, scales, zeros); 45 | } 46 | 47 | void vecquant4matmul_batched_cuda( 48 | torch::Tensor vec, torch::Tensor mat, torch::Tensor mul, 49 | torch::Tensor scales, torch::Tensor zeros 50 | ); 51 | 52 | void vecquant4matmul_batched( 53 | torch::Tensor vec, torch::Tensor mat, torch::Tensor mul, 54 | torch::Tensor scales, torch::Tensor zeros 55 | ) { 56 | const at::cuda::OptionalCUDAGuard device_guard(device_of(vec)); 57 | vecquant4matmul_batched_cuda(vec, mat, mul, scales, zeros); 58 | } 59 | 60 | void vecquant4matmul_batched_column_compression_cuda( 61 | torch::Tensor vec, torch::Tensor mat, torch::Tensor mul, 62 | torch::Tensor scales, torch::Tensor zeros 63 | ); 64 | 65 | void vecquant4matmul_batched_column_compression( 66 | torch::Tensor vec, torch::Tensor mat, torch::Tensor mul, 67 | torch::Tensor scales, torch::Tensor zeros 68 | ) { 69 | const at::cuda::OptionalCUDAGuard device_guard(device_of(vec)); 70 | vecquant4matmul_batched_column_compression_cuda(vec, mat, mul, scales, zeros); 71 | } 72 | 73 | void vecquant8matmul_batched_old_cuda( 74 | torch::Tensor vec, torch::Tensor mat, torch::Tensor mul, 75 | torch::Tensor scales, torch::Tensor zeros 76 | ); 77 | 78 | void vecquant8matmul_batched_old( 79 | torch::Tensor vec, torch::Tensor mat, torch::Tensor mul, 80 | torch::Tensor scales, torch::Tensor zeros 81 | ) { 82 | const at::cuda::OptionalCUDAGuard device_guard(device_of(vec)); 83 | vecquant8matmul_batched_old_cuda(vec, mat, mul, scales, zeros); 84 | } 85 | 86 | 87 | void vecquant4matmul_batched_old_cuda( 88 | torch::Tensor vec, torch::Tensor mat, torch::Tensor mul, 89 | torch::Tensor scales, torch::Tensor zeros 90 | ); 91 | 92 | void vecquant4matmul_batched_old( 93 | torch::Tensor vec, torch::Tensor mat, torch::Tensor mul, 94 | torch::Tensor scales, torch::Tensor zeros 95 | ) { 96 | const at::cuda::OptionalCUDAGuard device_guard(device_of(vec)); 97 | vecquant4matmul_batched_old_cuda(vec, mat, mul, scales, zeros); 98 | } 99 | 100 | void vecquant8matmul_batched_column_compression_old_cuda( 101 | torch::Tensor vec, torch::Tensor mat, torch::Tensor mul, 102 | torch::Tensor scales, torch::Tensor zeros 103 | ); 104 | 105 | void vecquant8matmul_batched_column_compression_old( 106 | torch::Tensor vec, torch::Tensor mat, torch::Tensor mul, 107 | torch::Tensor scales, torch::Tensor zeros 108 | ) { 109 | const at::cuda::OptionalCUDAGuard device_guard(device_of(vec)); 110 | vecquant8matmul_batched_column_compression_old_cuda(vec, mat, mul, scales, zeros); 111 | } 112 | 113 | void vecquant4matmul_batched_column_compression_old_cuda( 114 | torch::Tensor vec, torch::Tensor mat, torch::Tensor mul, 115 | torch::Tensor scales, torch::Tensor zeros 116 | ); 117 | 118 | void vecquant4matmul_batched_column_compression_old( 119 | torch::Tensor vec, torch::Tensor mat, torch::Tensor mul, 120 | torch::Tensor scales, torch::Tensor zeros 121 | ) { 122 | const at::cuda::OptionalCUDAGuard device_guard(device_of(vec)); 123 | vecquant4matmul_batched_column_compression_old_cuda(vec, mat, mul, scales, zeros); 124 | } 125 | 126 | 127 | 
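// Every host-side wrapper in this file follows the same pattern as the ones above:
// declare the *_cuda implementation (defined in cache_autogptq_cuda_kernel_256.cu),
// then expose an entry point that pins the current CUDA device to `vec`'s device via
// at::cuda::OptionalCUDAGuard before dispatching to it. A minimal usage sketch from
// Python, assuming the extension is JIT-compiled with torch.utils.cpp_extension.load
// (the module name below is hypothetical and not part of this repo):
//
//   from torch.utils.cpp_extension import load
//   cache_autogptq = load(
//       name="cache_autogptq_cuda_256",
//       sources=["cache_autogptq_cuda_256.cpp", "cache_autogptq_cuda_kernel_256.cu"],
//   )
//   # the kernels accumulate into `mul` with atomicAdd, so zero it beforehand
//   cache_autogptq.vecquant8matmul_batched(vec, mat, mul, scales, zeros)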
128 | void vecquant8matmul_batched_faster_cuda( 129 | torch::Tensor vec, torch::Tensor mat, torch::Tensor mul, 130 | torch::Tensor scales, torch::Tensor zeros 131 | ); 132 | 133 | void vecquant8matmul_batched_faster( 134 | torch::Tensor vec, torch::Tensor mat, torch::Tensor mul, 135 | torch::Tensor scales, torch::Tensor zeros 136 | ) { 137 | const at::cuda::OptionalCUDAGuard device_guard(device_of(vec)); 138 | vecquant8matmul_batched_faster_cuda(vec, mat, mul, scales, zeros); 139 | } 140 | 141 | 142 | void vecquant8matmul_batched_faster_old_cuda( 143 | torch::Tensor vec, torch::Tensor mat, torch::Tensor mul, 144 | torch::Tensor scales, torch::Tensor zeros 145 | ); 146 | 147 | void vecquant8matmul_batched_faster_old( 148 | torch::Tensor vec, torch::Tensor mat, torch::Tensor mul, 149 | torch::Tensor scales, torch::Tensor zeros 150 | ) { 151 | const at::cuda::OptionalCUDAGuard device_guard(device_of(vec)); 152 | vecquant8matmul_batched_faster_old_cuda(vec, mat, mul, scales, zeros); 153 | } 154 | 155 | void vecquant8matmul_batched_column_compression_faster_cuda( 156 | torch::Tensor vec, torch::Tensor mat, torch::Tensor mul, 157 | torch::Tensor scales, torch::Tensor zeros 158 | ); 159 | 160 | void vecquant8matmul_batched_column_compression_faster( 161 | torch::Tensor vec, torch::Tensor mat, torch::Tensor mul, 162 | torch::Tensor scales, torch::Tensor zeros 163 | ) { 164 | const at::cuda::OptionalCUDAGuard device_guard(device_of(vec)); 165 | vecquant8matmul_batched_column_compression_faster_cuda(vec, mat, mul, scales, zeros); 166 | } 167 | 168 | 169 | void vecquant8matmul_batched_column_compression_faster_old_cuda( 170 | torch::Tensor vec, torch::Tensor mat, torch::Tensor mul, 171 | torch::Tensor scales, torch::Tensor zeros 172 | ); 173 | 174 | void vecquant8matmul_batched_column_compression_faster_old( 175 | torch::Tensor vec, torch::Tensor mat, torch::Tensor mul, 176 | torch::Tensor scales, torch::Tensor zeros 177 | ) { 178 | const at::cuda::OptionalCUDAGuard device_guard(device_of(vec)); 179 | vecquant8matmul_batched_column_compression_faster_old_cuda(vec, mat, mul, scales, zeros); 180 | } 181 | 182 | 183 | 184 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { 185 | m.def("vecquant8matmul", &vecquant8matmul, "Vector 8-bit Quantized Matrix Multiplication (CUDA) (desc_act)"); 186 | m.def("vecquant8matmul_batched", &vecquant8matmul_batched, "Vector 8-bit Batched Quantized Matrix Multiplication (CUDA) (desc_act)"); 187 | m.def("vecquant8matmul_batched_old", &vecquant8matmul_batched_old, "Vector 8-bit old Batched Quantized Matrix Multiplication (CUDA) (desc_act)"); 188 | m.def("vecquant8matmul_batched_faster", &vecquant8matmul_batched_faster, "Vector 8-bit old Batched Quantized Matrix Multiplication (CUDA) (desc_act)"); 189 | m.def("vecquant8matmul_batched_faster_old", &vecquant8matmul_batched_faster_old, "Vector 8-bit old Batched Quantized Matrix Multiplication (CUDA) (desc_act)"); 190 | m.def("vecquant4matmul_batched_old", &vecquant4matmul_batched_old, "Vector 4-bit old Batched Quantized Matrix Multiplication (CUDA) (desc_act)"); 191 | m.def("vecquant8matmul_batched_column_compression", &vecquant8matmul_batched_column_compression, "Vector 8-bit Batched Quantized Matrix Multiplication (CUDA) with weight's column compressed (desc_act)"); 192 | m.def("vecquant8matmul_batched_column_compression_old", &vecquant8matmul_batched_column_compression_old, "Vector old 8-bit Batched Quantized Matrix Multiplication (CUDA) with weight's column compressed (desc_act)"); 193 | 
m.def("vecquant8matmul_batched_column_compression_faster", &vecquant8matmul_batched_column_compression_faster, "Vector old 8-bit Batched Quantized Matrix Multiplication (CUDA) with weight's column compressed (desc_act)"); 194 | m.def("vecquant8matmul_batched_column_compression_faster_old", &vecquant8matmul_batched_column_compression_faster_old, "Vector old 8-bit Batched Quantized Matrix Multiplication (CUDA) with weight's column compressed (desc_act)"); 195 | m.def("vecquant4matmul_batched_column_compression_old", &vecquant4matmul_batched_column_compression_old, "Vector old 4-bit Batched Quantized Matrix Multiplication (CUDA) with weight's column compressed (desc_act)"); 196 | m.def("vecquant4matmul_batched", &vecquant4matmul_batched, "Vector 4-bit Batched Quantized Matrix Multiplication (CUDA) (desc_act)"); 197 | m.def("vecquant4matmul_batched_column_compression", &vecquant4matmul_batched_column_compression, "Vector 4-bit Batched Quantized Matrix Multiplication (CUDA) with weight's column compressed (desc_act)"); 198 | } 199 | -------------------------------------------------------------------------------- /qwen/cache_autogptq_cuda_kernel_256.cu: -------------------------------------------------------------------------------- 1 | #define _CRT_SECURE_NO_WARNINGS 2 | #include 3 | #include 4 | #include 5 | #include 6 | #include 7 | #include 8 | 9 | #if (defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 700) || defined(USE_ROCM) 10 | // adapted from https://github.com/PanQiWei/AutoGPTQ/blob/main/autogptq_extension/cuda_256/autogptq_cuda_kernel_256.cu 11 | __device__ __forceinline__ void atomicAdd(c10::Half* address, c10::Half val) { 12 | unsigned int *address_as_ui = reinterpret_cast(reinterpret_cast(address) - (reinterpret_cast(address) & 2)); 13 | unsigned int old = *address_as_ui; 14 | unsigned int assumed; 15 | 16 | do { 17 | assumed = old; 18 | unsigned short hsum = reinterpret_cast(address) & 2 ? (old >> 16) : (old & 0xffff); 19 | hsum += val; 20 | old = reinterpret_cast(address) & 2 21 | ? (old & 0xffff) | (hsum << 16) 22 | : (old & 0xffff0000) | hsum; 23 | old = atomicCAS(address_as_ui, assumed, old); 24 | 25 | // Note: uses integer comparison to avoid hang in case of NaN (since NaN != NaN) 26 | } while (assumed != old); 27 | } 28 | __device__ __forceinline__ void atomicAdd(__half* address, c10::Half val) { 29 | unsigned int * address_as_ui = (unsigned int *) ((char *)address - ((size_t)address & 2)); 30 | unsigned int old = *address_as_ui; 31 | unsigned int assumed; 32 | 33 | do { 34 | assumed = old; 35 | __half_raw hsum; 36 | hsum.x = (size_t)address & 2 ? (old >> 16) : (old & 0xffff); 37 | half tmpres = __hadd(hsum, val); 38 | hsum = __half_raw(tmpres); 39 | old = (size_t)address & 2 ? 
(old & 0xffff) | (hsum.x << 16) : (old & 0xffff0000) | hsum.x; 40 | old = atomicCAS(address_as_ui, assumed, old); 41 | } while (assumed != old); 42 | } 43 | #endif 44 | 45 | template 46 | __global__ void VecQuant8MatMulKernel( 47 | const scalar_t* __restrict__ vec, 48 | const int* __restrict__ mat, 49 | scalar_t* __restrict__ mul, 50 | const scalar_t* __restrict__ scales, 51 | const int* __restrict__ zeros, 52 | const int* __restrict__ g_idx, 53 | int batch, 54 | int vec_height, 55 | int height, 56 | int width, 57 | int zero_width 58 | ); 59 | 60 | template 61 | __global__ void VecQuant8BatchMatMulColumnCompressionKernel( 62 | const scalar_t* __restrict__ vec, 63 | const int* __restrict__ mat, 64 | scalar_t* __restrict__ mul, 65 | const scalar_t* __restrict__ scales, 66 | const int* __restrict__ zeros, 67 | int batch, 68 | int heads, 69 | int vec_row, 70 | int height, 71 | int width 72 | ); 73 | 74 | template 75 | __global__ void VecQuant4BatchMatMulColumnCompressionKernel( 76 | const scalar_t* __restrict__ vec, 77 | const int* __restrict__ mat, 78 | scalar_t* __restrict__ mul, 79 | const scalar_t* __restrict__ scales, 80 | const int* __restrict__ zeros, 81 | int batch, 82 | int heads, 83 | int vec_row, 84 | int height, 85 | int width 86 | ); 87 | 88 | template 89 | __global__ void VecQuant8BatchMatMulKernel( 90 | const scalar_t* __restrict__ vec, 91 | const int* __restrict__ mat, 92 | scalar_t* __restrict__ mul, 93 | const scalar_t* __restrict__ scales, 94 | const int* __restrict__ zeros, 95 | int batch, 96 | int heads, 97 | int vec_row, 98 | int vec_height, 99 | int height, 100 | int width, 101 | int zero_width 102 | ); 103 | 104 | template 105 | __global__ void VecQuant4BatchMatMulKernel( 106 | const scalar_t* __restrict__ vec, 107 | const int* __restrict__ mat, 108 | scalar_t* __restrict__ mul, 109 | const scalar_t* __restrict__ scales, 110 | const int* __restrict__ zeros, 111 | int batch, 112 | int heads, 113 | int vec_row, 114 | int vec_height, 115 | int height, 116 | int width, 117 | int zero_width 118 | ); 119 | 120 | 121 | 122 | template 123 | __global__ void VecQuant8BatchMatMulKernel_old( 124 | const scalar_t* __restrict__ vec, 125 | const uint8_t* __restrict__ mat, 126 | scalar_t* __restrict__ mul, 127 | const scalar_t* __restrict__ scales, 128 | const scalar_t* __restrict__ zeros, 129 | int batch, 130 | int heads, 131 | int vec_row, 132 | int vec_height, 133 | int height, 134 | int width, 135 | int zero_width 136 | ); 137 | 138 | __global__ void VecQuant8BatchMatMulKernel_faster( 139 | const half* __restrict__ vec, 140 | const uint8_t* __restrict__ mat, 141 | half* __restrict__ mul, 142 | const half* __restrict__ scales, 143 | const half* __restrict__ zeros, 144 | int batch, 145 | int heads, 146 | int vec_row, 147 | int vec_height, 148 | int height, 149 | int width, 150 | int zero_width 151 | ); 152 | 153 | 154 | 155 | __global__ void VecQuant8BatchMatMulKernel_faster_old( 156 | const half* __restrict__ vec, 157 | const uint8_t* __restrict__ mat, 158 | half* __restrict__ mul, 159 | const half* __restrict__ scales, 160 | const half* __restrict__ zeros, 161 | int batch, 162 | int heads, 163 | int vec_row, 164 | int vec_height, 165 | int height, 166 | int width 167 | ); 168 | 169 | 170 | template 171 | __global__ void VecQuant4BatchMatMulKernel_old( 172 | const scalar_t* __restrict__ vec, 173 | const uint8_t* __restrict__ mat, 174 | scalar_t* __restrict__ mul, 175 | const scalar_t* __restrict__ scales, 176 | const scalar_t* __restrict__ zeros, 177 | int batch, 178 | int heads, 
179 | int vec_row, 180 | int vec_height, 181 | int height, 182 | int width, 183 | int zero_width 184 | ); 185 | 186 | 187 | template 188 | __global__ void VecQuant8BatchMatMulColumnCompressionKernel_old( 189 | const scalar_t* __restrict__ vec, 190 | const uint8_t* __restrict__ mat, 191 | scalar_t* __restrict__ mul, 192 | const scalar_t* __restrict__ scales, 193 | const scalar_t* __restrict__ zeros, 194 | int batch, 195 | int heads, 196 | int vec_row, 197 | int height, 198 | int width 199 | ); 200 | 201 | __global__ void VecQuant8BatchMatMulColumnCompressionKernel_faster( 202 | const half* __restrict__ vec, 203 | const uint8_t* __restrict__ mat, 204 | half* __restrict__ mul, 205 | const half* __restrict__ scales, 206 | const half* __restrict__ zeros, 207 | int batch, 208 | int heads, 209 | int vec_row, 210 | int height, 211 | int width 212 | ); 213 | 214 | __global__ void VecQuant8BatchMatMulColumnCompressionKernel_faster_old( 215 | const half* __restrict__ vec, 216 | const uint8_t* __restrict__ mat, 217 | half* __restrict__ mul, 218 | const half* __restrict__ scales, 219 | const half* __restrict__ zeros, 220 | int batch, 221 | int heads, 222 | int vec_row, 223 | int height, 224 | int width 225 | ); 226 | 227 | 228 | template 229 | __global__ void VecQuant4BatchMatMulColumnCompressionKernel_old( 230 | const scalar_t* __restrict__ vec, 231 | const uint8_t* __restrict__ mat, 232 | scalar_t* __restrict__ mul, 233 | const scalar_t* __restrict__ scales, 234 | const scalar_t* __restrict__ zeros, 235 | int batch, 236 | int heads, 237 | int vec_row, 238 | int height, 239 | int width 240 | ); 241 | 242 | 243 | __global__ void VecQuant8BatchMatMulKernel_faster( 244 | const half* __restrict__ vec, 245 | const uint8_t* __restrict__ mat, 246 | half* __restrict__ mul, 247 | const half* __restrict__ scales, 248 | const half* __restrict__ zeros, 249 | int batch, 250 | int heads, 251 | int vec_row, 252 | int vec_height, 253 | int height, 254 | int width 255 | ); 256 | 257 | 258 | __global__ void VecQuant8BatchMatMulColumnCompressionKernel_faster( 259 | const half* __restrict__ vec, 260 | const uint8_t* __restrict__ mat, 261 | half* __restrict__ mul, 262 | const half* __restrict__ scales, 263 | const half* __restrict__ zeros, 264 | int batch, 265 | int heads, 266 | int vec_row, 267 | int height, 268 | int width 269 | ); 270 | 271 | const int BLOCKWIDTH = 128; 272 | const int BLOCKHEIGHT8 = 32; 273 | const int BLOCKHEIGHT4 = 16; 274 | const int BLOCKHEIGHT_OLD4 = 128; 275 | //const int BLOCKHEIGHT_OLD8 = 128; 276 | 277 | __device__ inline unsigned int as_unsigned(int i) { 278 | return *reinterpret_cast(&i); 279 | } 280 | 281 | __device__ inline int as_int(int i) { 282 | return *reinterpret_cast(&i); 283 | } 284 | 285 | void vecquant8matmul_batched_column_compression_cuda( 286 | torch::Tensor vec, 287 | torch::Tensor mat, 288 | torch::Tensor mul, 289 | torch::Tensor scales, 290 | torch::Tensor zeros 291 | ) { 292 | int batch = vec.size(0); 293 | int heads = vec.size(1); 294 | int vec_row = vec.size(2); 295 | int height = vec.size(3); 296 | int width = mat.size(3) * 4; 297 | 298 | dim3 blocks( 299 | (height + BLOCKWIDTH - 1) / BLOCKWIDTH, 300 | (width + BLOCKWIDTH - 1) / BLOCKWIDTH 301 | ); 302 | dim3 threads(BLOCKWIDTH); 303 | 304 | AT_DISPATCH_FLOATING_TYPES( 305 | vec.type(), "vecquant8matmul_batched_cuda", ([&] { 306 | VecQuant8BatchMatMulColumnCompressionKernel<<>>( 307 | vec.data(), mat.data(), mul.data(), 308 | scales.data(), zeros.data(), 309 | batch, heads, vec_row, height, width 310 | ); 311 | }) 312 
| ); 313 | 314 | } 315 | 316 | template 317 | __global__ void VecQuant8BatchMatMulColumnCompressionKernel( 318 | const scalar_t* __restrict__ vec, 319 | const int* __restrict__ mat, 320 | scalar_t* __restrict__ mul, 321 | const scalar_t* __restrict__ scales, 322 | const int* __restrict__ zeros, 323 | int batch, 324 | int heads, 325 | int vec_row, 326 | int height, 327 | int width 328 | ) { 329 | int weight_total = batch * heads * height * width / 4; 330 | int input_total = batch * heads * vec_row * height; 331 | int out_total = batch * heads * vec_row * width; 332 | int tid = threadIdx.x; 333 | // h is index of height with step being BLOCKWIDTH 334 | int h = BLOCKWIDTH * blockIdx.x; 335 | // w is index of width with step being 1 336 | int w = BLOCKWIDTH * blockIdx.y + tid; 337 | if (w >= width && tid >= height) { 338 | return; 339 | } 340 | 341 | __shared__ scalar_t blockvec[BLOCKWIDTH]; 342 | int k; 343 | scalar_t w_tmp; 344 | 345 | float weight[BLOCKWIDTH]; 346 | 347 | for (int b = 0; b < batch; ++b){ 348 | for (int head = 0; head < heads; ++head){ 349 | int batch_shift = b * heads + head; 350 | for (k = 0; k < BLOCKWIDTH && h + k < height; ++k){ 351 | int i_w = (w / 4); 352 | int w_bit = (w % 4) * 8; 353 | 354 | int w_index = (batch_shift * height + h + k) * width / 4 + i_w; 355 | if (w_index >= weight_total || w >= width) { 356 | weight[k] = 0; 357 | } else { 358 | scalar_t scale = scales[batch_shift * height + h + k]; 359 | scalar_t zero = zeros[batch_shift * height + h + k]; 360 | w_tmp = ((as_unsigned(mat[w_index]) >> w_bit) & 0xFF); 361 | weight[k] = scale * (w_tmp - zero); 362 | } 363 | } 364 | 365 | scalar_t res; 366 | for (int vr = 0; vr < vec_row; ++vr){ 367 | res = 0; 368 | int vec_index = (batch_shift * vec_row + vr) * height + blockIdx.x * BLOCKWIDTH + tid; 369 | if (vec_index < input_total) { 370 | blockvec[tid] = vec[vec_index]; 371 | } else { 372 | blockvec[tid] = 0; 373 | } 374 | 375 | __syncthreads(); 376 | for (k = 0; k < BLOCKWIDTH && h + k < height; ++k){ 377 | // res is the dot product of BLOCKWIDTH elements (part of width) 378 | res += weight[k] * blockvec[k]; 379 | } 380 | // add res to the final result, final matrix shape: (batch, vec_row, width) 381 | int out_index = (batch_shift * vec_row + vr) * width + w; 382 | if (out_index < out_total) { 383 | atomicAdd(&mul[out_index], res); 384 | } 385 | __syncthreads(); 386 | } 387 | } 388 | } 389 | } 390 | 391 | void vecquant8matmul_batched_cuda( 392 | torch::Tensor vec, 393 | torch::Tensor mat, 394 | torch::Tensor mul, 395 | torch::Tensor scales, 396 | torch::Tensor zeros 397 | ) { 398 | int batch = vec.size(0); 399 | int heads = vec.size(1); 400 | int vec_row = vec.size(2); 401 | int vec_height = vec.size(3); 402 | int height = mat.size(2); 403 | int width = mat.size(3); 404 | int zero_width = zeros.size(2); 405 | 406 | dim3 blocks( 407 | (height + BLOCKHEIGHT8 - 1) / BLOCKHEIGHT8, 408 | (width + BLOCKWIDTH - 1) / BLOCKWIDTH 409 | ); 410 | dim3 threads(BLOCKWIDTH); 411 | 412 | AT_DISPATCH_FLOATING_TYPES( 413 | vec.type(), "vecquant8matmul_batched_cuda", ([&] { 414 | VecQuant8BatchMatMulKernel<<>>( 415 | vec.data(), mat.data(), mul.data(), 416 | scales.data(), zeros.data(), 417 | batch, heads, vec_row, vec_height, height, width, zero_width 418 | ); 419 | }) 420 | ); 421 | 422 | } 423 | 424 | template 425 | __global__ void VecQuant8BatchMatMulKernel( 426 | const scalar_t* __restrict__ vec, 427 | const int* __restrict__ mat, 428 | scalar_t* __restrict__ mul, 429 | const scalar_t* __restrict__ scales, 430 | const int* 
__restrict__ zeros, 431 | int batch, 432 | int heads, 433 | int vec_row, 434 | int vec_height, 435 | int height, 436 | int width, 437 | int zero_width 438 | ) { 439 | int weight_total = batch * heads * height * width; 440 | int input_total = batch * heads * vec_row * vec_height; 441 | int out_total = batch * heads * vec_row * width; 442 | int tid = threadIdx.x; 443 | // h is index of height with step being BLOCKHEIGHT8 444 | int h = BLOCKHEIGHT8 * blockIdx.x; 445 | // w is index of width with step being 1 446 | int w = BLOCKWIDTH * blockIdx.y + tid; 447 | if (w >= width && tid >= vec_height) { 448 | return; 449 | } 450 | 451 | __shared__ scalar_t blockvec[BLOCKWIDTH]; 452 | // i is index of mat of block first row 453 | int i = width * h + w; 454 | // if (i >= width * height) { 455 | // return; 456 | // } 457 | int k; 458 | scalar_t w_tmp; 459 | 460 | int z_w = w / 4; 461 | int z_mod = (w % 4) * 8; 462 | 463 | float weight[BLOCKWIDTH]; 464 | 465 | for (int b = 0; b < batch; ++b){ 466 | for (int head = 0; head < heads; ++head){ 467 | int batch_shift = b * heads + head; 468 | for (k = 0; k < BLOCKWIDTH && h * 4 + k < vec_height; ++k){ 469 | int k_w = (k / 4); 470 | int k_bit = (k % 4) * 8; 471 | 472 | int w_index = batch_shift * height * width + i + (k_w * width); 473 | if (w_index >= weight_total || w >= width) { 474 | weight[k] = 0; 475 | } else { 476 | scalar_t scale = scales[batch_shift * width + w]; 477 | scalar_t zero; 478 | if (zero_width == width) { 479 | zero = zeros[batch_shift * width + w]; 480 | } else { 481 | zero = scalar_t(((as_unsigned(zeros[batch_shift * zero_width + z_w]) >> z_mod) & 0xFF) + 1); 482 | } 483 | w_tmp = ((as_unsigned(mat[w_index]) >> k_bit) & 0xFF); 484 | weight[k] = scale * (w_tmp - zero); 485 | } 486 | } 487 | 488 | scalar_t res; 489 | for (int vr = 0; vr < vec_row; ++vr){ 490 | res = 0; 491 | int vec_index = (batch_shift * vec_row + vr) * vec_height + blockIdx.x * BLOCKWIDTH + tid; 492 | if (vec_index < input_total) { 493 | blockvec[tid] = vec[vec_index]; 494 | } else { 495 | blockvec[tid] = 0; 496 | } 497 | 498 | __syncthreads(); 499 | for (k = 0; k < BLOCKWIDTH && h * 4 + k < vec_height; ++k){ 500 | // res is the dot product of BLOCKWIDTH elements (part of width) 501 | res += weight[k] * blockvec[k]; 502 | } 503 | // add res to the final result, final matrix shape: (batch, vec_row, width) 504 | int out_index = (batch_shift * vec_row + vr) * width + w; 505 | if (out_index < out_total) { 506 | atomicAdd(&mul[out_index], res); 507 | } 508 | __syncthreads(); 509 | } 510 | } 511 | } 512 | } 513 | 514 | 515 | void vecquant8matmul_cuda( 516 | torch::Tensor vec, 517 | torch::Tensor mat, 518 | torch::Tensor mul, 519 | torch::Tensor scales, 520 | torch::Tensor zeros, 521 | torch::Tensor g_idx 522 | ) { 523 | int batch = vec.size(0); 524 | int vec_height = vec.size(1); 525 | int height = mat.size(0); 526 | int width = mat.size(1); 527 | int zero_width = zeros.size(1); 528 | 529 | dim3 blocks( 530 | (height + BLOCKHEIGHT8 - 1) / BLOCKHEIGHT8, 531 | (width + BLOCKWIDTH - 1) / BLOCKWIDTH 532 | ); 533 | dim3 threads(BLOCKWIDTH); 534 | 535 | AT_DISPATCH_FLOATING_TYPES( 536 | vec.type(), "vecquant8matmul_cuda", ([&] { 537 | VecQuant8MatMulKernel<<>>( 538 | vec.data(), mat.data(), mul.data(), 539 | scales.data(), zeros.data(), g_idx.data(), 540 | batch, vec_height, height, width, zero_width 541 | ); 542 | }) 543 | ); 544 | } 545 | 546 | template 547 | __global__ void VecQuant8MatMulKernel( 548 | const scalar_t* __restrict__ vec, 549 | const int* __restrict__ mat, 550 | 
scalar_t* __restrict__ mul, 551 | const scalar_t* __restrict__ scales, 552 | const int* __restrict__ zeros, 553 | const int* __restrict__ g_idx, 554 | int batch, 555 | int vec_height, 556 | int height, 557 | int width, 558 | int zero_width 559 | ) { 560 | int h = BLOCKHEIGHT8 * blockIdx.x; 561 | int w = BLOCKWIDTH * blockIdx.y + threadIdx.x; 562 | 563 | __shared__ scalar_t blockvec[BLOCKWIDTH]; 564 | int i = width * h + w; 565 | int g_h = h * 4; 566 | int k; 567 | unsigned int g; 568 | scalar_t w_tmp; 569 | 570 | int z_w = w / 4; 571 | int z_mod = (w % 4) * 8; 572 | 573 | float weight[BLOCKWIDTH]; 574 | 575 | for (k = 0; k < BLOCKWIDTH; ++k){ 576 | int k_w = (k / 4); 577 | int k_bit = (k % 4) * 8; 578 | 579 | g = as_int(g_idx[g_h + k]); 580 | scalar_t scale = scales[g * width + w]; 581 | scalar_t zero = scalar_t(((as_unsigned(zeros[g * zero_width + z_w]) >> z_mod) & 0xFF) + 1); 582 | 583 | w_tmp = ((as_unsigned(mat[i + (k_w * width)]) >> k_bit) & 0xFF); 584 | 585 | weight[k] = scale * (w_tmp - zero); 586 | } 587 | 588 | 589 | scalar_t res; 590 | for (int b = 0; b < batch; ++b){ 591 | res = 0; 592 | blockvec[threadIdx.x] = vec[b * vec_height + blockIdx.x * BLOCKWIDTH + threadIdx.x]; 593 | __syncthreads(); 594 | for (k = 0; k < BLOCKWIDTH; ++k){ 595 | res += weight[k] * blockvec[k]; 596 | } 597 | atomicAdd(&mul[b * width + w], res); 598 | __syncthreads(); 599 | } 600 | } 601 | 602 | 603 | 604 | void vecquant4matmul_batched_cuda( 605 | torch::Tensor vec, 606 | torch::Tensor mat, 607 | torch::Tensor mul, 608 | torch::Tensor scales, 609 | torch::Tensor zeros 610 | ) { 611 | int batch = vec.size(0); 612 | int heads = vec.size(1); 613 | int vec_row = vec.size(2); 614 | int vec_height = vec.size(3); 615 | int height = mat.size(2); 616 | int width = mat.size(3); 617 | int zero_width = zeros.size(2); 618 | 619 | dim3 blocks( 620 | (height + BLOCKHEIGHT4 - 1) / BLOCKHEIGHT4, 621 | (width + BLOCKWIDTH - 1) / BLOCKWIDTH 622 | ); 623 | dim3 threads(BLOCKWIDTH); 624 | 625 | AT_DISPATCH_FLOATING_TYPES( 626 | vec.type(), "vecquant4matmul_batched_cuda", ([&] { 627 | VecQuant4BatchMatMulKernel<<>>( 628 | vec.data(), mat.data(), mul.data(), 629 | scales.data(), zeros.data(), 630 | batch, heads, vec_row, vec_height, height, width, zero_width 631 | ); 632 | }) 633 | ); 634 | 635 | } 636 | 637 | template 638 | __global__ void VecQuant4BatchMatMulKernel( 639 | const scalar_t* __restrict__ vec, 640 | const int* __restrict__ mat, 641 | scalar_t* __restrict__ mul, 642 | const scalar_t* __restrict__ scales, 643 | const int* __restrict__ zeros, 644 | int batch, 645 | int heads, 646 | int vec_row, 647 | int vec_height, 648 | int height, 649 | int width, 650 | int zero_width 651 | ) { 652 | int weight_total = batch * heads * height * width; 653 | int input_total = batch * heads * vec_row * vec_height; 654 | int out_total = batch * heads * vec_row * width; 655 | int tid = threadIdx.x; 656 | // h is index of height with step being BLOCKHEIGHT4 657 | int h = BLOCKHEIGHT4 * blockIdx.x; 658 | // w is index of width with step being 1 659 | int w = BLOCKWIDTH * blockIdx.y + tid; 660 | if (w >= width && tid >= vec_height) { 661 | return; 662 | } 663 | 664 | __shared__ scalar_t blockvec[BLOCKWIDTH]; 665 | // i is index of mat of block first row 666 | int i = width * h + w; 667 | int k; 668 | scalar_t w_tmp; 669 | 670 | int z_w = w / 8; 671 | int z_mod = (w % 8) * 4; 672 | 673 | float weight[BLOCKWIDTH]; 674 | 675 | for (int b = 0; b < batch; ++b){ 676 | for (int head = 0; head < heads; ++head){ 677 | int batch_shift = b * heads 
+ head; 678 | for (k = 0; k < BLOCKWIDTH && h * 8 + k < vec_height; ++k){ 679 | int k_w = (k / 8); 680 | int k_bit = (k % 8) * 4; 681 | 682 | int w_index = batch_shift * height * width + i + (k_w * width); 683 | if (w_index >= weight_total || w >= width) { 684 | weight[k] = 0; 685 | } else { 686 | scalar_t scale = scales[batch_shift * width + w]; 687 | scalar_t zero; 688 | if (zero_width == width) { 689 | zero = zeros[batch_shift * width + w]; 690 | } else { 691 | zero = scalar_t(((as_unsigned(zeros[batch_shift * zero_width + z_w]) >> z_mod) & 0xF)); 692 | } 693 | w_tmp = ((as_unsigned(mat[w_index]) >> k_bit) & 0xF); 694 | weight[k] = scale * (w_tmp - zero); 695 | } 696 | } 697 | 698 | scalar_t res; 699 | for (int vr = 0; vr < vec_row; ++vr){ 700 | res = 0; 701 | int vec_index = (batch_shift * vec_row + vr) * vec_height + blockIdx.x * BLOCKWIDTH + tid; 702 | if (vec_index < input_total) { 703 | blockvec[tid] = vec[vec_index]; 704 | } else { 705 | blockvec[tid] = 0; 706 | } 707 | 708 | __syncthreads(); 709 | for (k = 0; k < BLOCKWIDTH && h * 8 + k < vec_height; ++k){ 710 | // res is the dot product of BLOCKWIDTH elements (part of width) 711 | res += weight[k] * blockvec[k]; 712 | } 713 | // add res to the final result, final matrix shape: (batch, vec_row, width) 714 | int out_index = (batch_shift * vec_row + vr) * width + w; 715 | if (out_index < out_total) { 716 | atomicAdd(&mul[out_index], res); 717 | } 718 | __syncthreads(); 719 | } 720 | } 721 | } 722 | } 723 | 724 | 725 | 726 | void vecquant4matmul_batched_column_compression_cuda( 727 | torch::Tensor vec, 728 | torch::Tensor mat, 729 | torch::Tensor mul, 730 | torch::Tensor scales, 731 | torch::Tensor zeros 732 | ) { 733 | int batch = vec.size(0); 734 | int heads = vec.size(1); 735 | int vec_row = vec.size(2); 736 | int height = vec.size(3); 737 | int width = mat.size(3) * 8; 738 | 739 | dim3 blocks( 740 | (height + BLOCKWIDTH - 1) / BLOCKWIDTH, 741 | (width + BLOCKWIDTH - 1) / BLOCKWIDTH 742 | ); 743 | dim3 threads(BLOCKWIDTH); 744 | 745 | AT_DISPATCH_FLOATING_TYPES( 746 | vec.type(), "vecquant4matmul_batched_cuda", ([&] { 747 | VecQuant4BatchMatMulColumnCompressionKernel<<>>( 748 | vec.data(), mat.data(), mul.data(), 749 | scales.data(), zeros.data(), 750 | batch, heads, vec_row, height, width 751 | ); 752 | }) 753 | ); 754 | 755 | } 756 | 757 | template 758 | __global__ void VecQuant4BatchMatMulColumnCompressionKernel( 759 | const scalar_t* __restrict__ vec, 760 | const int* __restrict__ mat, 761 | scalar_t* __restrict__ mul, 762 | const scalar_t* __restrict__ scales, 763 | const int* __restrict__ zeros, 764 | int batch, 765 | int heads, 766 | int vec_row, 767 | int height, 768 | int width 769 | ) { 770 | int weight_total = batch * heads * height * width / 8; 771 | int input_total = batch * heads * vec_row * height; 772 | int out_total = batch * heads * vec_row * width; 773 | int tid = threadIdx.x; 774 | // h is index of height with step being BLOCKWIDTH 775 | int h = BLOCKWIDTH * blockIdx.x; 776 | // w is index of width with step being 1 777 | int w = BLOCKWIDTH * blockIdx.y + tid; 778 | if (w >= width && tid >= height) { 779 | return; 780 | } 781 | 782 | __shared__ scalar_t blockvec[BLOCKWIDTH]; 783 | int k; 784 | scalar_t w_tmp; 785 | 786 | float weight[BLOCKWIDTH]; 787 | 788 | for (int b = 0; b < batch; ++b){ 789 | for (int head = 0; head < heads; ++head){ 790 | int batch_shift = b * heads + head; 791 | for (k = 0; k < BLOCKWIDTH && h + k < height; ++k){ 792 | int i_w = (w / 8); 793 | int w_bit = (w % 8) * 4; 794 | 795 | int 
w_index = (batch_shift * height + h + k) * width / 8 + i_w; 796 | if (w_index >= weight_total || w >= width) { 797 | weight[k] = 0; 798 | } else { 799 | scalar_t scale = scales[batch_shift * height + h + k]; 800 | scalar_t zero = zeros[batch_shift * height + h + k]; 801 | w_tmp = ((as_unsigned(mat[w_index]) >> w_bit) & 0xF); 802 | weight[k] = scale * (w_tmp - zero); 803 | } 804 | } 805 | 806 | scalar_t res; 807 | for (int vr = 0; vr < vec_row; ++vr){ 808 | res = 0; 809 | int vec_index = (batch_shift * vec_row + vr) * height + blockIdx.x * BLOCKWIDTH + tid; 810 | if (vec_index < input_total) { 811 | blockvec[tid] = vec[vec_index]; 812 | } else { 813 | blockvec[tid] = 0; 814 | } 815 | 816 | __syncthreads(); 817 | for (k = 0; k < BLOCKWIDTH && h + k < height; ++k){ 818 | // res is the dot product of BLOCKWIDTH elements (part of width) 819 | res += weight[k] * blockvec[k]; 820 | } 821 | // add res to the final result, final matrix shape: (batch, vec_row, width) 822 | int out_index = (batch_shift * vec_row + vr) * width + w; 823 | if (out_index < out_total) { 824 | atomicAdd(&mul[out_index], res); 825 | } 826 | __syncthreads(); 827 | } 828 | } 829 | } 830 | } 831 | 832 | 833 | void vecquant8matmul_batched_old_cuda( 834 | torch::Tensor vec, 835 | torch::Tensor mat, 836 | torch::Tensor mul, 837 | torch::Tensor scales, 838 | torch::Tensor zeros 839 | ) { 840 | int batch = vec.size(0); 841 | int heads = vec.size(1); 842 | int vec_row = vec.size(2); 843 | int vec_height = vec.size(3); 844 | int height = mat.size(2); 845 | int width = mat.size(3); 846 | int zero_width = zeros.size(2); 847 | 848 | dim3 blocks( 849 | (height + BLOCKWIDTH - 1) / BLOCKWIDTH, 850 | (width + BLOCKWIDTH - 1) / BLOCKWIDTH 851 | ); 852 | dim3 threads(BLOCKWIDTH); 853 | 854 | AT_DISPATCH_FLOATING_TYPES( 855 | vec.type(), "vecquant8matmul_batched_old_cuda", ([&] { 856 | VecQuant8BatchMatMulKernel_old<<>>( 857 | vec.data(), mat.data(), mul.data(), 858 | scales.data(), zeros.data(), 859 | batch, heads, vec_row, vec_height, height, width, zero_width 860 | ); 861 | }) 862 | ); 863 | } 864 | 865 | 866 | template 867 | __global__ void VecQuant8BatchMatMulKernel_old( 868 | const scalar_t* __restrict__ vec, 869 | const uint8_t* __restrict__ mat, 870 | scalar_t* __restrict__ mul, 871 | const scalar_t* __restrict__ scales, 872 | const scalar_t* __restrict__ zeros, 873 | int batch, 874 | int heads, 875 | int vec_row, 876 | int vec_height, 877 | int height, 878 | int width, 879 | int zero_width 880 | ) { 881 | int weight_total = batch * heads * height * width; 882 | int input_total = batch * heads * vec_row * vec_height; 883 | int out_total = batch * heads * vec_row * width; 884 | int tid = threadIdx.x; 885 | // h is index of height with step being BLOCKHEIGHT8 886 | int h = BLOCKWIDTH * blockIdx.x; 887 | // w is index of width with step being 1 888 | int w = BLOCKWIDTH * blockIdx.y + tid; 889 | if (w >= width && tid >= vec_height) { 890 | return; 891 | } 892 | 893 | __shared__ scalar_t blockvec[BLOCKWIDTH]; 894 | // i is index of mat of block first row 895 | int i = width * h + w; 896 | int k; 897 | scalar_t w_tmp; 898 | 899 | float weight[BLOCKWIDTH]; 900 | for (int b = 0; b < batch; ++b){ 901 | for (int head = 0; head < heads; ++head){ 902 | int batch_shift = b * heads + head; 903 | for (k = 0; k < BLOCKWIDTH && h + k < vec_height; ++k){ 904 | int k_w = k; 905 | int w_index = batch_shift * height * width + i + (k_w * width); 906 | if (w_index >= weight_total || w >= width) { 907 | weight[k] = 0; 908 | } else { 909 | scalar_t scale = 
scales[batch_shift * width + w]; 910 | scalar_t zero = zeros[batch_shift * width + w]; 911 | w_tmp = as_unsigned(mat[w_index]); 912 | weight[k] = scale * (w_tmp - zero); 913 | } 914 | } 915 | 916 | scalar_t res; 917 | for (int vr = 0; vr < vec_row; ++vr){ 918 | res = 0; 919 | int vec_index = (batch_shift * vec_row + vr) * vec_height + blockIdx.x * BLOCKWIDTH + tid; 920 | if (vec_index < input_total) { 921 | blockvec[tid] = vec[vec_index]; 922 | } else { 923 | blockvec[tid] = 0; 924 | } 925 | 926 | __syncthreads(); 927 | for (k = 0; k < BLOCKWIDTH && h + k < vec_height; ++k){ 928 | // res is the dot product of BLOCKWIDTH elements (part of width) 929 | res += weight[k] * blockvec[k]; 930 | } 931 | // add res to the final result, final matrix shape: (batch, vec_row, width) 932 | int out_index = (batch_shift * vec_row + vr) * width + w; 933 | if (out_index < out_total) { 934 | atomicAdd(&mul[out_index], res); 935 | } 936 | __syncthreads(); 937 | } 938 | } 939 | } 940 | } 941 | 942 | 943 | 944 | void vecquant8matmul_batched_faster_cuda( 945 | torch::Tensor vec, 946 | torch::Tensor mat, 947 | torch::Tensor mul, 948 | torch::Tensor scales, 949 | torch::Tensor zeros 950 | ) { 951 | int batch = vec.size(0); 952 | int heads = vec.size(1); 953 | int vec_row = vec.size(2); 954 | int vec_height = vec.size(3); 955 | int height = mat.size(2); 956 | int width = mat.size(3); 957 | int zero_width = zeros.size(2); 958 | 959 | dim3 blocks( 960 | (height + BLOCKWIDTH - 1) / BLOCKWIDTH, 961 | (width + BLOCKWIDTH - 1) / BLOCKWIDTH 962 | ); 963 | dim3 threads(BLOCKWIDTH); 964 | 965 | VecQuant8BatchMatMulKernel_faster<<>>( 966 | (half*) vec.data_ptr(), 967 | (uint8_t*) mat.data_ptr(), 968 | (half*) mul.data_ptr(), 969 | (half*) scales.data_ptr(), 970 | (half*) zeros.data_ptr(), 971 | batch, heads, vec_row, vec_height, height, width, zero_width 972 | ); 973 | } 974 | 975 | 976 | 977 | __global__ void VecQuant8BatchMatMulKernel_faster( 978 | const half* __restrict__ vec, 979 | const uint8_t* __restrict__ mat, 980 | half* __restrict__ mul, 981 | const half* __restrict__ scales, 982 | const half* __restrict__ zeros, 983 | int batch, 984 | int heads, 985 | int vec_row, 986 | int vec_height, 987 | int height, 988 | int width, 989 | int zero_width 990 | ) { 991 | //int weight_total = batch * heads * height * width; 992 | int input_total = batch * heads * vec_row * vec_height; 993 | int out_total = batch * heads * vec_row * width; 994 | int tid = threadIdx.x; 995 | int h = BLOCKWIDTH * blockIdx.x; 996 | int w = BLOCKWIDTH * blockIdx.y + tid; 997 | if (w >= width && tid >= height) { 998 | return; 999 | } 1000 | 1001 | __shared__ float blockvec[BLOCKWIDTH]; 1002 | int i = width * h + w; 1003 | int k; 1004 | float w_tmp; 1005 | 1006 | float weight[BLOCKWIDTH]; 1007 | for (int b = 0; b < batch; ++b){ 1008 | for (int head = 0; head < heads; ++head){ 1009 | int batch_shift = b * heads + head; 1010 | for (k = 0; k < BLOCKWIDTH && h + k < vec_height; ++k){ 1011 | int k_w = k; 1012 | int w_index = batch_shift * height * width + i + (k_w * width); 1013 | float scale = __half2float(scales[batch_shift * width + w]); 1014 | float zero = __half2float(zeros[batch_shift * width + w]); 1015 | w_tmp = as_unsigned(mat[w_index]); 1016 | weight[k] = scale *(w_tmp-zero); 1017 | } 1018 | 1019 | float res; 1020 | for (int vr = 0; vr < vec_row; ++vr){ 1021 | res = 0; 1022 | int vec_index = (batch_shift * vec_row + vr) * vec_height + blockIdx.x * BLOCKWIDTH + tid; 1023 | if (vec_index < input_total) { 1024 | blockvec[tid] = 
__half2float(vec[vec_index]); 1025 | } else { 1026 | blockvec[tid] = 0; 1027 | } 1028 | __syncthreads(); 1029 | for (k = 0; k < BLOCKWIDTH && h + k < vec_height; ++k){ 1030 | float temp_res = weight[k]*blockvec[k]; 1031 | res += temp_res; 1032 | } 1033 | int out_index = (batch_shift * vec_row + vr) * width + w; 1034 | if (out_index < out_total) { 1035 | atomicAdd(&mul[out_index], __float2half(res)); 1036 | } 1037 | __syncthreads(); 1038 | } 1039 | } 1040 | } 1041 | } 1042 | 1043 | 1044 | 1045 | 1046 | void vecquant8matmul_batched_column_compression_faster_cuda( 1047 | torch::Tensor vec, 1048 | torch::Tensor mat, 1049 | torch::Tensor mul, 1050 | torch::Tensor scales, 1051 | torch::Tensor zeros 1052 | ) { 1053 | int batch = vec.size(0); 1054 | int heads = vec.size(1); 1055 | int vec_row = vec.size(2); 1056 | int height = vec.size(3); 1057 | int width = mat.size(3); 1058 | 1059 | dim3 blocks( 1060 | (height + BLOCKWIDTH - 1) / BLOCKWIDTH, 1061 | (width + BLOCKWIDTH - 1) / BLOCKWIDTH 1062 | ); 1063 | dim3 threads(BLOCKWIDTH); 1064 | 1065 | VecQuant8BatchMatMulColumnCompressionKernel_faster<<>>( 1066 | (half*) vec.data_ptr(), 1067 | (uint8_t*) mat.data_ptr(), 1068 | (half*) mul.data_ptr(), 1069 | (half*) scales.data_ptr(), 1070 | (half*) zeros.data_ptr(), 1071 | batch, heads, vec_row, height, width 1072 | ); 1073 | 1074 | } 1075 | 1076 | __global__ void VecQuant8BatchMatMulColumnCompressionKernel_faster( 1077 | const half* __restrict__ vec, 1078 | const uint8_t* __restrict__ mat, 1079 | half* __restrict__ mul, 1080 | const half* __restrict__ scales, 1081 | const half* __restrict__ zeros, 1082 | int batch, 1083 | int heads, 1084 | int vec_row, 1085 | int height, 1086 | int width 1087 | ) { 1088 | //int weight_total = batch * heads * height * width; 1089 | int input_total = batch * heads * vec_row * height; 1090 | int out_total = batch * heads * vec_row * width; 1091 | int tid = threadIdx.x; 1092 | int h = BLOCKWIDTH * blockIdx.x; 1093 | int w = BLOCKWIDTH * blockIdx.y + tid; 1094 | if (w >= width && tid >= height) { 1095 | return; 1096 | } 1097 | 1098 | __shared__ float blockvec[BLOCKWIDTH]; 1099 | int k; 1100 | float w_tmp; 1101 | float weight[BLOCKWIDTH]; 1102 | 1103 | for (int b = 0; b < batch; ++b){ 1104 | for (int head = 0; head < heads; ++head){ 1105 | int batch_shift = b * heads + head; 1106 | for (k = 0; k < BLOCKWIDTH; ++k){ 1107 | int w_index = (batch_shift * height + h + k) * width + w; 1108 | float scale = __half2float(scales[batch_shift * height + h + k]); 1109 | float zero = __half2float(zeros[batch_shift * height + h + k]); 1110 | w_tmp = mat[w_index]; 1111 | weight[k] = scale * (w_tmp-zero); 1112 | } 1113 | 1114 | float res; 1115 | for (int vr = 0; vr < vec_row; ++vr){ 1116 | res = 0; 1117 | int vec_index = (batch_shift * vec_row + vr) * height + blockIdx.x * BLOCKWIDTH + tid; 1118 | if (vec_index < input_total) { 1119 | blockvec[tid] = __half2float(vec[vec_index]); 1120 | } else { 1121 | blockvec[tid] = 0; 1122 | } 1123 | __syncthreads(); 1124 | for (k = 0; k < BLOCKWIDTH; ++k){ 1125 | res += weight[k]*blockvec[k]; 1126 | } 1127 | int out_index = (batch_shift * vec_row + vr) * width + w; 1128 | if (out_index < out_total) { 1129 | atomicAdd(&mul[out_index], __float2half(res)); 1130 | } 1131 | __syncthreads(); 1132 | } 1133 | } 1134 | } 1135 | } 1136 | 1137 | 1138 | 1139 | void vecquant8matmul_batched_column_compression_old_cuda( 1140 | torch::Tensor vec, 1141 | torch::Tensor mat, 1142 | torch::Tensor mul, 1143 | torch::Tensor scales, 1144 | torch::Tensor zeros 1145 | ) { 1146 | 
int batch = vec.size(0); 1147 | int heads = vec.size(1); 1148 | int vec_row = vec.size(2); 1149 | int height = vec.size(3); 1150 | int width = mat.size(3); 1151 | 1152 | dim3 blocks( 1153 | (height + BLOCKWIDTH - 1) / BLOCKWIDTH, 1154 | (width + BLOCKWIDTH - 1) / BLOCKWIDTH 1155 | ); 1156 | dim3 threads(BLOCKWIDTH); 1157 | 1158 | AT_DISPATCH_FLOATING_TYPES( 1159 | vec.type(), "vecquant8matmul_batched_column_compression_old_cuda", ([&] { 1160 | VecQuant8BatchMatMulColumnCompressionKernel_old<<>>( 1161 | vec.data(), mat.data(), mul.data(), 1162 | scales.data(), zeros.data(), 1163 | batch, heads, vec_row, height, width 1164 | ); 1165 | }) 1166 | ); 1167 | 1168 | } 1169 | 1170 | template 1171 | __global__ void VecQuant8BatchMatMulColumnCompressionKernel_old( 1172 | const scalar_t* __restrict__ vec, 1173 | const uint8_t* __restrict__ mat, 1174 | scalar_t* __restrict__ mul, 1175 | const scalar_t* __restrict__ scales, 1176 | const scalar_t* __restrict__ zeros, 1177 | int batch, 1178 | int heads, 1179 | int vec_row, 1180 | int height, 1181 | int width 1182 | ) { 1183 | int weight_total = batch * heads * height * width; 1184 | int input_total = batch * heads * vec_row * height; 1185 | int out_total = batch * heads * vec_row * width; 1186 | int tid = threadIdx.x; 1187 | // h is index of height with step being BLOCKWIDTH 1188 | int h = BLOCKWIDTH * blockIdx.x; 1189 | // w is index of width with step being 1 1190 | int w = BLOCKWIDTH * blockIdx.y + tid; 1191 | if (w >= width && tid >= height) { 1192 | return; 1193 | } 1194 | 1195 | __shared__ scalar_t blockvec[BLOCKWIDTH]; 1196 | int k; 1197 | scalar_t w_tmp; 1198 | 1199 | float weight[BLOCKWIDTH]; 1200 | 1201 | for (int b = 0; b < batch; ++b){ 1202 | for (int head = 0; head < heads; ++head){ 1203 | int batch_shift = b * heads + head; 1204 | for (k = 0; k < BLOCKWIDTH && h + k < height; ++k){ 1205 | int w_index = (batch_shift * height + h + k) * width + w; 1206 | if (w_index >= weight_total || w >= width) { 1207 | weight[k] = 0; 1208 | } else { 1209 | scalar_t scale = scales[batch_shift * height + h + k]; 1210 | scalar_t zero = zeros[batch_shift * height + h + k]; 1211 | w_tmp = mat[w_index]; 1212 | weight[k] = scale * (w_tmp - zero); 1213 | } 1214 | } 1215 | 1216 | scalar_t res; 1217 | for (int vr = 0; vr < vec_row; ++vr){ 1218 | res = 0; 1219 | int vec_index = (batch_shift * vec_row + vr) * height + blockIdx.x * BLOCKWIDTH + tid; 1220 | if (vec_index < input_total) { 1221 | blockvec[tid] = vec[vec_index]; 1222 | } else { 1223 | blockvec[tid] = 0; 1224 | } 1225 | 1226 | __syncthreads(); 1227 | for (k = 0; k < BLOCKWIDTH && h + k < height; ++k){ 1228 | // res is the dot product of BLOCKWIDTH elements (part of width) 1229 | res += weight[k] * blockvec[k]; 1230 | } 1231 | // add res to the final result, final matrix shape: (batch, vec_row, width) 1232 | int out_index = (batch_shift * vec_row + vr) * width + w; 1233 | if (out_index < out_total) { 1234 | atomicAdd(&mul[out_index], res); 1235 | } 1236 | __syncthreads(); 1237 | } 1238 | } 1239 | } 1240 | } 1241 | 1242 | 1243 | void vecquant4matmul_batched_old_cuda( 1244 | torch::Tensor vec, 1245 | torch::Tensor mat, 1246 | torch::Tensor mul, 1247 | torch::Tensor scales, 1248 | torch::Tensor zeros 1249 | ) { 1250 | int batch = vec.size(0); 1251 | int heads = vec.size(1); 1252 | int vec_row = vec.size(2); 1253 | int vec_height = vec.size(3); 1254 | int height = mat.size(2); 1255 | int width = mat.size(3); 1256 | int zero_width = zeros.size(2); 1257 | 1258 | dim3 blocks( 1259 | (height + BLOCKHEIGHT_OLD4 - 1) 
/ BLOCKHEIGHT_OLD4, 1260 | (width + BLOCKWIDTH - 1) / BLOCKWIDTH 1261 | ); 1262 | dim3 threads(BLOCKWIDTH); 1263 | 1264 | AT_DISPATCH_FLOATING_TYPES( 1265 | vec.type(), "vecquant4matmul_batched_old_cuda", ([&] { 1266 | VecQuant4BatchMatMulKernel_old<<>>( 1267 | vec.data(), mat.data(), mul.data(), 1268 | scales.data(), zeros.data(), 1269 | batch, heads, vec_row, vec_height, height, width, zero_width 1270 | ); 1271 | }) 1272 | ); 1273 | 1274 | } 1275 | 1276 | template 1277 | __global__ void VecQuant4BatchMatMulKernel_old( 1278 | const scalar_t* __restrict__ vec, 1279 | const uint8_t* __restrict__ mat, 1280 | scalar_t* __restrict__ mul, 1281 | const scalar_t* __restrict__ scales, 1282 | const scalar_t* __restrict__ zeros, 1283 | int batch, 1284 | int heads, 1285 | int vec_row, 1286 | int vec_height, 1287 | int height, 1288 | int width, 1289 | int zero_width 1290 | ) { 1291 | int weight_total = batch * heads * height * width; 1292 | int input_total = batch * heads * vec_row * vec_height; 1293 | int out_total = batch * heads * vec_row * width; 1294 | int tid = threadIdx.x; 1295 | // h is index of height with step being BLOCKHEIGHT_OLD4 1296 | int h = BLOCKHEIGHT_OLD4 * blockIdx.x; 1297 | // w is index of width with step being 1 1298 | int w = BLOCKWIDTH * blockIdx.y + tid; 1299 | if (w >= width && tid >= vec_height) { 1300 | return; 1301 | } 1302 | 1303 | __shared__ scalar_t blockvec[BLOCKWIDTH]; 1304 | // i is index of mat of block first row 1305 | int i = width * h + w; 1306 | int k; 1307 | scalar_t w_tmp; 1308 | 1309 | float weight[BLOCKWIDTH]; 1310 | for (int b = 0; b < batch; ++b){ 1311 | for (int head = 0; head < heads; ++head){ 1312 | int batch_shift = b * heads + head; 1313 | for (k = 0; k < BLOCKWIDTH && h*2 + k < vec_height; ++k){ 1314 | int k_w = (k / 2); 1315 | int k_bit = (k % 2) * 4; 1316 | int w_index = batch_shift * height * width + i + (k_w * width); 1317 | if (w_index >= weight_total || w >= width) { 1318 | weight[k] = 0; 1319 | } else { 1320 | scalar_t scale = scales[batch_shift * width + w]; 1321 | scalar_t zero = zeros[batch_shift * width + w]; 1322 | w_tmp = ((as_unsigned(mat[w_index]) >> k_bit) & 0xF); 1323 | weight[k] = scale * (w_tmp - zero); 1324 | } 1325 | } 1326 | 1327 | scalar_t res; 1328 | for (int vr = 0; vr < vec_row; ++vr){ 1329 | res = 0; 1330 | int vec_index = (batch_shift * vec_row + vr) * vec_height + blockIdx.x * BLOCKWIDTH + tid; 1331 | if (vec_index < input_total) { 1332 | blockvec[tid] = vec[vec_index]; 1333 | } else { 1334 | blockvec[tid] = 0; 1335 | } 1336 | 1337 | __syncthreads(); 1338 | for (k = 0; k < BLOCKWIDTH && h*2 + k < vec_height; ++k){ 1339 | // res is the dot product of BLOCKWIDTH elements (part of width) 1340 | res += weight[k] * blockvec[k]; 1341 | } 1342 | // add res to the final result, final matrix shape: (batch, vec_row, width) 1343 | int out_index = (batch_shift * vec_row + vr) * width + w; 1344 | if (out_index < out_total) { 1345 | atomicAdd(&mul[out_index], res); 1346 | } 1347 | __syncthreads(); 1348 | } 1349 | } 1350 | } 1351 | } 1352 | 1353 | 1354 | 1355 | 1356 | 1357 | void vecquant4matmul_batched_column_compression_old_cuda( 1358 | torch::Tensor vec, 1359 | torch::Tensor mat, 1360 | torch::Tensor mul, 1361 | torch::Tensor scales, 1362 | torch::Tensor zeros 1363 | ) { 1364 | int batch = vec.size(0); 1365 | int heads = vec.size(1); 1366 | int vec_row = vec.size(2); 1367 | int height = vec.size(3); 1368 | int width = mat.size(3); 1369 | 1370 | dim3 blocks( 1371 | (height + BLOCKHEIGHT_OLD4 - 1) / BLOCKHEIGHT_OLD4, 1372 | (width + 
BLOCKWIDTH - 1) / BLOCKWIDTH 1373 | ); 1374 | dim3 threads(BLOCKWIDTH); 1375 | 1376 | AT_DISPATCH_FLOATING_TYPES( 1377 | vec.type(), "vecquant4matmul_batched_column_compression_old_cuda", ([&] { 1378 | VecQuant4BatchMatMulColumnCompressionKernel_old<<>>( 1379 | vec.data(), mat.data(), mul.data(), 1380 | scales.data(), zeros.data(), 1381 | batch, heads, vec_row, height, width 1382 | ); 1383 | }) 1384 | ); 1385 | 1386 | } 1387 | 1388 | template 1389 | __global__ void VecQuant4BatchMatMulColumnCompressionKernel_old( 1390 | const scalar_t* __restrict__ vec, 1391 | const uint8_t* __restrict__ mat, 1392 | scalar_t* __restrict__ mul, 1393 | const scalar_t* __restrict__ scales, 1394 | const scalar_t* __restrict__ zeros, 1395 | int batch, 1396 | int heads, 1397 | int vec_row, 1398 | int height, 1399 | int width 1400 | ) { 1401 | int weight_total = batch * heads * height * width; 1402 | int input_total = batch * heads * vec_row * height; 1403 | int out_total = batch * heads * vec_row * width; 1404 | int tid = threadIdx.x; 1405 | // h is index of height with step being BLOCKWIDTH 1406 | int h = BLOCKHEIGHT_OLD4 * blockIdx.x; 1407 | // w is index of width with step being 1 1408 | int w = BLOCKWIDTH * blockIdx.y + tid; 1409 | if (w >= width && tid >= height) { 1410 | return; 1411 | } 1412 | 1413 | __shared__ scalar_t blockvec[BLOCKWIDTH]; 1414 | int k; 1415 | scalar_t w_tmp; 1416 | 1417 | float weight[BLOCKWIDTH]; 1418 | 1419 | for (int b = 0; b < batch; ++b){ 1420 | for (int head = 0; head < heads; ++head){ 1421 | int batch_shift = b * heads + head; 1422 | for (k = 0; k < BLOCKWIDTH && h*2 + k < height; ++k){ 1423 | int k_w = (k / 2); 1424 | int k_bit = (k % 2) * 4; 1425 | int w_index = (batch_shift * height + h + k) * width + k_w; 1426 | if (w_index >= weight_total || w >= width) { 1427 | weight[k] = 0; 1428 | } else { 1429 | scalar_t scale = scales[batch_shift * height + h + k]; 1430 | scalar_t zero = zeros[batch_shift * height + h + k]; 1431 | w_tmp = ((as_unsigned(mat[w_index]) >> k_bit) & 0xF); 1432 | weight[k] = scale * (w_tmp - zero); 1433 | } 1434 | } 1435 | 1436 | scalar_t res; 1437 | for (int vr = 0; vr < vec_row; ++vr){ 1438 | res = 0; 1439 | int vec_index = (batch_shift * vec_row + vr) * height + blockIdx.x * BLOCKWIDTH + tid; 1440 | if (vec_index < input_total) { 1441 | blockvec[tid] = vec[vec_index]; 1442 | } else { 1443 | blockvec[tid] = 0; 1444 | } 1445 | 1446 | __syncthreads(); 1447 | for (k = 0; k < BLOCKWIDTH && h*2 + k < height; ++k){ 1448 | // res is the dot product of BLOCKWIDTH elements (part of width) 1449 | res += weight[k] * blockvec[k]; 1450 | } 1451 | // add res to the final result, final matrix shape: (batch, vec_row, width) 1452 | int out_index = (batch_shift * vec_row + vr) * width + w; 1453 | if (out_index < out_total) { 1454 | atomicAdd(&mul[out_index], res); 1455 | } 1456 | __syncthreads(); 1457 | } 1458 | } 1459 | } 1460 | } 1461 | 1462 | 1463 | 1464 | 1465 | 1466 | void vecquant8matmul_batched_faster_old_cuda( 1467 | torch::Tensor vec, 1468 | torch::Tensor mat, 1469 | torch::Tensor mul, 1470 | torch::Tensor scales, 1471 | torch::Tensor zeros 1472 | ) { 1473 | int batch = vec.size(0); 1474 | int heads = vec.size(1); 1475 | int vec_row = vec.size(2); 1476 | int vec_height = vec.size(3); 1477 | int height = mat.size(2); 1478 | int width = mat.size(3); 1479 | 1480 | dim3 blocks( 1481 | (height + BLOCKWIDTH - 1) / BLOCKWIDTH, 1482 | (width + BLOCKWIDTH - 1) / BLOCKWIDTH 1483 | ); 1484 | dim3 threads(BLOCKWIDTH); 1485 | 1486 | VecQuant8BatchMatMulKernel_faster_old<<>>( 
1487 | (half*) vec.data_ptr(), 1488 | (uint8_t*) mat.data_ptr(), 1489 | (half*) mul.data_ptr(), 1490 | (half*) scales.data_ptr(), 1491 | (half*) zeros.data_ptr(), 1492 | batch, heads, vec_row, vec_height, height, width 1493 | ); 1494 | } 1495 | 1496 | 1497 | __global__ void VecQuant8BatchMatMulKernel_faster_old( 1498 | const half* __restrict__ vec, 1499 | const uint8_t* __restrict__ mat, 1500 | half* __restrict__ mul, 1501 | const half* __restrict__ scales, 1502 | const half* __restrict__ zeros, 1503 | int batch, 1504 | int heads, 1505 | int vec_row, 1506 | int vec_height, 1507 | int height, 1508 | int width 1509 | ) { 1510 | int weight_total = batch * heads * height * width; 1511 | int input_total = batch * heads * vec_row * vec_height; 1512 | int out_total = batch * heads * vec_row * width; 1513 | int tid = threadIdx.x; 1514 | const int BLOCKWIDTH_half = BLOCKWIDTH/2; 1515 | 1516 | int h = BLOCKWIDTH * blockIdx.x; //head_dim, dim=-1 1517 | int w = BLOCKWIDTH * blockIdx.y + tid; //seq-len, +0-256 ,dim=-2 1518 | /* 1519 | if (w >= width && tid >= vec_height) { 1520 | return; 1521 | } 1522 | */ 1523 | __shared__ half blockvec[BLOCKWIDTH]; //256 1524 | int i = width * h + w; 1525 | int k; 1526 | 1527 | half w_tmp1 = __float2half(0); 1528 | half w_tmp2 = __float2half(0); 1529 | 1530 | half2 weight[BLOCKWIDTH_half]; 1531 | for (int b = 0; b < batch; ++b){ 1532 | for (int head = 0; head < heads; ++head){ 1533 | int batch_shift = b * heads + head; 1534 | //int zero_index = batch_shift; 1535 | for (k = 0; k < BLOCKWIDTH_half; ++k){ 1536 | int w_index1 = batch_shift * height * width + i + (2 * k * width); // [batch,head,h+k, w] 1537 | int w_index2 = batch_shift * height * width + i + ((2 * k + 1) * width); 1538 | int zero_index = batch_shift * width + w; // [batch,head, w] 1539 | if (w_index1 >= weight_total || w >= width || (2 * k + h) >= height) { 1540 | weight[k] = __float2half2_rn(0); 1541 | } else { 1542 | float zero_f=__half2float(zeros[zero_index]); 1543 | float scale_f= __half2float(scales[zero_index]); 1544 | if (w_index2 >= weight_total){ 1545 | w_tmp1 = __float2half((as_unsigned(mat[w_index1]) -zero_f)*scale_f); 1546 | w_tmp2 = __float2half(0); 1547 | weight[k] = __halves2half2(w_tmp1,w_tmp2); 1548 | //printf("zero_index is %d w is %d height is %d width is %d w_index1 is %d w_tmp1 is %f w_tmp2 is %f zero is %f scale is %f low is %f high is %f \n ",zero_index,w,height, width,w_index1,__half2float(w_tmp1),__half2float(w_tmp2),zero_f,scale_f,__low2float(weight[k]),__high2float(weight[k])); 1549 | }else{ 1550 | w_tmp1 = __int2half_rn(as_unsigned(mat[w_index1])); 1551 | w_tmp2 = __int2half_rn(as_unsigned(mat[w_index2])); 1552 | 1553 | //weight[k] = __hmul2(__hsub2(__halves2half2(w_tmp1,w_tmp2), __halves2half2(zero,zero)),__halves2half2(scale,scale)); 1554 | weight[k] = __hfma2(__halves2half2(w_tmp1,w_tmp2), __float2half2_rn(scale_f), __float2half2_rn(-(scale_f * zero_f))); 1555 | //printf("zero_index1 is %d zero_index2 is %d k is %d head is %d w is %d h is %d height is %d width is %d w_index1 is %d w_index2 is %d zero is %f scale is %f low is %f high is %f \n ",zero_index1,zero_index2,k,head,w,h,height, width,w_index1,w_index2,__half2float(zero1),__half2float(scale1),__low2float(weight[k]),__high2float(weight[k])); 1556 | } 1557 | } 1558 | } 1559 | 1560 | 1561 | for (int vr = 0; vr < vec_row; ++vr){ 1562 | float res=0; 1563 | int vec_index = (batch_shift * vec_row + vr) * height + blockIdx.x * BLOCKWIDTH + tid; 1564 | int out_index = (batch_shift * vec_row + vr) * width + w; 1565 | if 
(vec_index < input_total) { 1566 | //blockvec[tid] = __half2float(vec[vec_index]);// [batch, head, vr, tid(seq_len dim+)] 1567 | blockvec[tid] = vec[vec_index]; 1568 | //printf("width is %d height is %d h is %d w is %d vec_index is %d out_index is %d vec_row is %d vec_height is %d,vr is %d tid is %d blockvec is %f\n",width,height, h,w,vec_index,out_index,vec_row,vec_height,vr,tid,blockvec[tid]); 1569 | } else { 1570 | blockvec[tid] = __float2half(0); 1571 | } 1572 | __syncthreads(); 1573 | if (out_index < out_total) { 1574 | for (k = 0; k < BLOCKWIDTH_half; ++k){ 1575 | half2 res2 = __hmul2(weight[k],__halves2half2(blockvec[2*k],blockvec[2*k+1])); 1576 | res += __low2float(res2) + __high2float(res2); 1577 | } 1578 | atomicAdd(&mul[out_index], __float2half(res)); 1579 | } 1580 | __syncthreads(); 1581 | } 1582 | } 1583 | } 1584 | } 1585 | 1586 | 1587 | void vecquant8matmul_batched_column_compression_faster_old_cuda( 1588 | torch::Tensor vec, // [batch,heads, seq_q, seq_v] 1589 | torch::Tensor mat, // [batch,heads, seq_v, head_dim] 1590 | torch::Tensor mul, // [batch,heads, seq_q,head_dim] 1591 | torch::Tensor scales, // [batch,heads, head_dim] 1592 | torch::Tensor zeros 1593 | ) { 1594 | int batch = vec.size(0); 1595 | int heads = vec.size(1); 1596 | int vec_row = vec.size(2); //ql 1597 | int height = mat.size(2); //vl 1598 | int width = mat.size(3); //head_dim 1599 | 1600 | dim3 blocks( 1601 | (height + BLOCKWIDTH - 1) / BLOCKWIDTH, 1602 | (width + BLOCKWIDTH - 1) / BLOCKWIDTH 1603 | ); 1604 | dim3 threads(BLOCKWIDTH); 1605 | 1606 | VecQuant8BatchMatMulColumnCompressionKernel_faster_old<<<blocks, threads>>>( 1607 | (half*) vec.data_ptr(), 1608 | (uint8_t*) mat.data_ptr(), 1609 | (half*) mul.data_ptr(), 1610 | (half*) scales.data_ptr(), 1611 | (half*) zeros.data_ptr(), 1612 | batch, heads, vec_row, height, width 1613 | ); 1614 | 1615 | } 1616 | 1617 | 1618 | __global__ void VecQuant8BatchMatMulColumnCompressionKernel_faster_old( 1619 | const half* __restrict__ vec, // [batch,heads, seq_q, seq_v] 1620 | const uint8_t* __restrict__ mat, // [batch,heads, seq_v, head_dim] 1621 | half* __restrict__ mul, // [batch,heads, seq_q,head_dim] 1622 | const half* __restrict__ scales, // [batch,heads, seq_v] 1623 | const half* __restrict__ zeros, 1624 | int batch, 1625 | int heads, 1626 | int vec_row, //seq_q 1627 | int height, //seq_v 1628 | int width //head_dim 1629 | ) { 1630 | int weight_total = batch * heads * height * width; 1631 | int input_total = batch * heads * vec_row * height; 1632 | int out_total = batch * heads * vec_row * width; 1633 | int tid = threadIdx.x; 1634 | int h = BLOCKWIDTH * blockIdx.x; // vl 1635 | int w = BLOCKWIDTH * blockIdx.y + tid; //head_dim + block 1636 | if (w >= width && tid >= height) { 1637 | return; 1638 | } 1639 | __shared__ half blockvec[BLOCKWIDTH]; 1640 | int k; 1641 | half w_tmp1 = __float2half(0); 1642 | half w_tmp2 = __float2half(0); 1643 | int i = width * h + w; 1644 | const int BLOCKWIDTH_half = BLOCKWIDTH/2; 1645 | half2 weight[BLOCKWIDTH_half]; 1646 | 1647 | for (int b = 0; b < batch; ++b){ 1648 | for (int head = 0; head < heads; ++head){ 1649 | int batch_shift = b * heads + head; 1650 | //int zero_index = batch_shift; 1651 | for (k = 0; k < BLOCKWIDTH_half; ++k){ 1652 | int w_index1 = batch_shift * height * width + i + (2 * k) * width; // [batch,head, h+k, w] 1653 | int w_index2 = batch_shift * height * width + i + ((2 * k + 1) * width); 1654 | int zero_index1 = batch_shift * height + h + 2*k; // [batch,head, w] 1655 | int zero_index2 = batch_shift * height + h + 2*k+1; // 
[batch,head, w] 1656 | 1657 | if (w_index1 >= weight_total || (2 * k + h)>=height) { 1658 | weight[k]=__float2half2_rn(0); 1659 | } else{ 1660 | //int zero_index = batch_shift + h; // [batch,head, w] 1661 | //float scale_f1 = __half2float(scales[zero_index1]); 1662 | //float zero_f1 = __half2float(zeros[zero_index1]); 1663 | if (w_index2>=weight_total){ 1664 | w_tmp1 = __float2half((as_unsigned(mat[w_index1]) - __half2float(zeros[zero_index1]))* __half2float(scales[zero_index1])); 1665 | w_tmp2 = __float2half(0); 1666 | weight[k] = __halves2half2(w_tmp1,w_tmp2); 1667 | //printf("zero_index is %d k is %d w is %d head is %d height is %d width is %d w_index1 is %d w_tmp1 is %f w_tmp2 is %f zero is %f scale is %f low is %f high is %f \n ",zero_index,k,w,head,height, width,w_index1,__half2float(w_tmp1),__half2float(w_tmp2),zero_f,scale_f,__low2float(weight[k]),__high2float(weight[k])); 1668 | }else{ 1669 | w_tmp1 = __int2half_rn(as_unsigned(mat[w_index1])); 1670 | w_tmp2 = __int2half_rn(as_unsigned(mat[w_index2])); 1671 | half zero1=zeros[zero_index1]; 1672 | half zero2=zeros[zero_index2]; 1673 | half scale1=scales[zero_index1]; 1674 | half scale2=scales[zero_index2]; 1675 | weight[k] = __hmul2(__hsub2(__halves2half2(w_tmp1,w_tmp2), __halves2half2(zero1,zero2)),__halves2half2(scale1,scale2)); 1676 | //weight[k] = __hfma2(__halves2half2(w_tmp1,w_tmp2), __float2half2_rn(scale_f), __float2half2_rn(-(scale_f * zero_f))); 1677 | //printf("zero_index1 is %d zero_index2 is %d k is %d head is %d w is %d h is %d height is %d width is %d w_index1 is %d w_index2 is %d zero is %f scale is %f low is %f high is %f \n ",zero_index1,zero_index2,k,head,w,h,height, width,w_index1,w_index2,__half2float(zero1),__half2float(scale1),__low2float(weight[k]),__high2float(weight[k])); 1678 | } 1679 | } 1680 | } 1681 | 1682 | 1683 | for (int vr = 0; vr < vec_row; ++vr){ 1684 | float res=0; 1685 | int vec_index = (batch_shift * vec_row + vr) * height + blockIdx.x * BLOCKWIDTH + tid; 1686 | int out_index = (batch_shift * vec_row + vr) * width + w; 1687 | 1688 | if (vec_index < input_total) { 1689 | //blockvec[tid] = __half2float(vec[vec_index]); 1690 | blockvec[tid] = vec[vec_index]; 1691 | //printf("vec_index is %d out_index is %d vec_row is %d ,vr is %d tid is %d blockvec is %f\n",vec_index,out_index,vec_row,vr,tid,blockvec[tid]); 1692 | } else { 1693 | blockvec[tid] = __float2half(0); 1694 | //blockvec[tid] = 0; 1695 | } 1696 | __syncthreads(); 1697 | if (out_index < out_total) { 1698 | for (k = 0; k < BLOCKWIDTH_half; ++k){ 1699 | half2 res2 = __hmul2(weight[k],__halves2half2(blockvec[2*k],blockvec[2*k+1])); 1700 | res += __low2float(res2) + __high2float(res2); 1701 | } 1702 | atomicAdd(&mul[out_index], __float2half(res)); 1703 | } 1704 | __syncthreads(); 1705 | } 1706 | } 1707 | } 1708 | } 1709 | -------------------------------------------------------------------------------- /qwen/configuration_qwen.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Alibaba Cloud. 2 | # 3 | # This source code is licensed under the license found in the 4 | # LICENSE file in the root directory of this source tree. 
5 | 6 | from transformers import PretrainedConfig 7 | 8 | 9 | class QWenConfig(PretrainedConfig): 10 | model_type = "qwen" 11 | keys_to_ignore_at_inference = ["past_key_values"] 12 | 13 | def __init__( 14 | self, 15 | vocab_size=151936, 16 | hidden_size=4096, 17 | num_hidden_layers=32, 18 | num_attention_heads=32, 19 | emb_dropout_prob=0.0, 20 | attn_dropout_prob=0.0, 21 | layer_norm_epsilon=1e-6, 22 | initializer_range=0.02, 23 | max_position_embeddings=8192, 24 | scale_attn_weights=True, 25 | use_cache=True, 26 | bf16=False, 27 | fp16=False, 28 | fp32=False, 29 | kv_channels=128, 30 | rotary_pct=1.0, 31 | rotary_emb_base=10000, 32 | use_dynamic_ntk=True, 33 | use_logn_attn=True, 34 | use_flash_attn="auto", 35 | intermediate_size=22016, 36 | no_bias=True, 37 | tie_word_embeddings=False, 38 | use_cache_quantization=False, 39 | use_cache_kernel=False, 40 | softmax_in_fp32=False, 41 | **kwargs, 42 | ): 43 | self.vocab_size = vocab_size 44 | self.hidden_size = hidden_size 45 | self.intermediate_size = intermediate_size 46 | self.num_hidden_layers = num_hidden_layers 47 | self.num_attention_heads = num_attention_heads 48 | self.emb_dropout_prob = emb_dropout_prob 49 | self.attn_dropout_prob = attn_dropout_prob 50 | self.layer_norm_epsilon = layer_norm_epsilon 51 | self.initializer_range = initializer_range 52 | self.scale_attn_weights = scale_attn_weights 53 | self.use_cache = use_cache 54 | self.max_position_embeddings = max_position_embeddings 55 | self.bf16 = bf16 56 | self.fp16 = fp16 57 | self.fp32 = fp32 58 | self.kv_channels = kv_channels 59 | self.rotary_pct = rotary_pct 60 | self.rotary_emb_base = rotary_emb_base 61 | self.use_dynamic_ntk = use_dynamic_ntk 62 | self.use_logn_attn = use_logn_attn 63 | self.use_flash_attn = use_flash_attn 64 | self.no_bias = no_bias 65 | self.use_cache_quantization = use_cache_quantization 66 | self.use_cache_kernel = use_cache_kernel 67 | self.softmax_in_fp32 = softmax_in_fp32 68 | super().__init__( 69 | tie_word_embeddings=tie_word_embeddings, 70 | **kwargs 71 | ) 72 | -------------------------------------------------------------------------------- /qwen/cpp_kernels.py: -------------------------------------------------------------------------------- 1 | from torch.utils import cpp_extension 2 | import pathlib 3 | import os 4 | import subprocess 5 | 6 | def _get_cuda_bare_metal_version(cuda_dir): 7 | raw_output = subprocess.check_output([cuda_dir + "/bin/nvcc", "-V"], 8 | universal_newlines=True) 9 | output = raw_output.split() 10 | release_idx = output.index("release") + 1 11 | release = output[release_idx].split(".") 12 | bare_metal_major = release[0] 13 | bare_metal_minor = release[1][0] 14 | 15 | return raw_output, bare_metal_major, bare_metal_minor 16 | 17 | def _create_build_dir(buildpath): 18 | try: 19 | os.mkdir(buildpath) 20 | except OSError: 21 | if not os.path.isdir(buildpath): 22 | print(f"Creation of the build directory {buildpath} failed") 23 | 24 | # Check if cuda 11 is installed for compute capability 8.0 25 | cc_flag = [] 26 | _, bare_metal_major, bare_metal_minor = _get_cuda_bare_metal_version(cpp_extension.CUDA_HOME) 27 | if int(bare_metal_major) >= 11: 28 | cc_flag.append('-gencode') 29 | cc_flag.append('arch=compute_80,code=sm_80') 30 | if int(bare_metal_minor) >= 7: 31 | cc_flag.append('-gencode') 32 | cc_flag.append('arch=compute_90,code=sm_90') 33 | 34 | # Build path 35 | srcpath = pathlib.Path(__file__).parent.absolute() 36 | buildpath = srcpath / 'build' 37 | _create_build_dir(buildpath) 38 | 39 | def 
_cpp_extention_load_helper(name, sources, extra_cuda_flags): 40 | return cpp_extension.load( 41 | name=name, 42 | sources=sources, 43 | build_directory=buildpath, 44 | extra_cflags=['-O3', ], 45 | extra_cuda_cflags=['-O3', 46 | '-gencode', 'arch=compute_70,code=sm_70', 47 | '--use_fast_math'] + extra_cuda_flags + cc_flag, 48 | verbose=1 49 | ) 50 | 51 | extra_flags = [] 52 | 53 | cache_autogptq_cuda_256_sources = ["./cache_autogptq_cuda_256.cpp", 54 | "./cache_autogptq_cuda_kernel_256.cu"] 55 | cache_autogptq_cuda_256 = _cpp_extention_load_helper("cache_autogptq_cuda_256", cache_autogptq_cuda_256_sources, extra_flags) 56 | -------------------------------------------------------------------------------- /qwen/qwen_generation_utils.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Alibaba Cloud. 2 | # 3 | # This source code is licensed under the license found in the 4 | # LICENSE file in the root directory of this source tree. 5 | 6 | """Generation support.""" 7 | 8 | from typing import Tuple, List, Union, Iterable 9 | 10 | import numpy as np 11 | import torch 12 | import torch.nn.functional as F 13 | from transformers import PreTrainedTokenizer 14 | from transformers import logging 15 | from transformers.generation import LogitsProcessor 16 | 17 | logger = logging.get_logger(__name__) 18 | 19 | # Types. 20 | HistoryType = List[Tuple[str, str]] 21 | TokensType = List[int] 22 | BatchTokensType = List[List[int]] 23 | 24 | 25 | def pad_batch(batch: BatchTokensType, pad_id: int, seq_length: int) -> BatchTokensType: 26 | for tokens in batch: 27 | context_length = len(tokens) 28 | if context_length < seq_length: 29 | tokens.extend([pad_id] * (seq_length - context_length)) 30 | return batch 31 | 32 | 33 | def get_ltor_masks_and_position_ids( 34 | data, 35 | eod_token, 36 | reset_position_ids, 37 | reset_attention_mask, 38 | eod_mask_loss, 39 | ): 40 | """Build masks and position id for left to right model.""" 41 | 42 | # Extract batch size and sequence length. 43 | micro_batch_size, seq_length = data.size() 44 | 45 | # Attention mask (lower triangular). 46 | if reset_attention_mask: 47 | att_mask_batch = micro_batch_size 48 | else: 49 | att_mask_batch = 1 50 | attention_mask = torch.tril( 51 | torch.ones((att_mask_batch, seq_length, seq_length), device=data.device) 52 | ).view(att_mask_batch, 1, seq_length, seq_length) 53 | 54 | # Loss mask. 55 | loss_mask = torch.ones(data.size(), dtype=torch.float, device=data.device) 56 | if eod_mask_loss: 57 | loss_mask[data == eod_token] = 0.0 58 | 59 | # Position ids. 60 | position_ids = torch.arange(seq_length, dtype=torch.long, device=data.device) 61 | position_ids = position_ids.unsqueeze(0).expand_as(data) 62 | # We need to clone as the ids will be modifed based on batch index. 63 | if reset_position_ids: 64 | position_ids = position_ids.clone() 65 | 66 | if reset_position_ids or reset_attention_mask: 67 | # Loop through the batches: 68 | for b in range(micro_batch_size): 69 | 70 | # Find indecies where EOD token is. 71 | eod_index = position_ids[b, data[b] == eod_token] 72 | # Detach indecies from positions if going to modify positions. 73 | if reset_position_ids: 74 | eod_index = eod_index.clone() 75 | 76 | # Loop through EOD indecies: 77 | prev_index = 0 78 | for j in range(eod_index.size()[0]): 79 | i = eod_index[j] 80 | # Mask attention loss. 81 | if reset_attention_mask: 82 | attention_mask[b, 0, (i + 1) :, : (i + 1)] = 0 83 | # Reset positions. 
84 | if reset_position_ids: 85 | position_ids[b, (i + 1) :] -= i + 1 - prev_index 86 | prev_index = i + 1 87 | 88 | # Convert attention mask to binary: 89 | attention_mask = attention_mask < 0.5 90 | 91 | return attention_mask, loss_mask, position_ids 92 | 93 | 94 | def get_batch(context_tokens: torch.LongTensor, eod_id: int): 95 | """Generate batch from context tokens.""" 96 | # Move to GPU. 97 | tokens = context_tokens.contiguous().to(context_tokens.device) 98 | # Get the attention mask and postition ids. 99 | attention_mask, _, position_ids = get_ltor_masks_and_position_ids( 100 | tokens, 101 | eod_id, 102 | reset_position_ids=False, 103 | reset_attention_mask=False, 104 | eod_mask_loss=False, 105 | ) 106 | return tokens, attention_mask, position_ids 107 | 108 | 109 | def get_stop_words_ids(chat_format, tokenizer): 110 | if chat_format == "raw": 111 | stop_words_ids = [tokenizer.encode("Human:"), [tokenizer.eod_id]] 112 | elif chat_format == "chatml": 113 | stop_words_ids = [[tokenizer.im_end_id], [tokenizer.im_start_id]] 114 | else: 115 | raise NotImplementedError(f"Unknown chat format {chat_format!r}") 116 | return stop_words_ids 117 | 118 | 119 | def make_context( 120 | tokenizer: PreTrainedTokenizer, 121 | query: str, 122 | history: List[Tuple[str, str]] = None, 123 | system: str = "", 124 | max_window_size: int = 6144, 125 | chat_format: str = "chatml", 126 | ): 127 | if history is None: 128 | history = [] 129 | 130 | if chat_format == "chatml": 131 | im_start, im_end = "<|im_start|>", "<|im_end|>" 132 | im_start_tokens = [tokenizer.im_start_id] 133 | im_end_tokens = [tokenizer.im_end_id] 134 | nl_tokens = tokenizer.encode("\n") 135 | 136 | def _tokenize_str(role, content): 137 | return f"{role}\n{content}", tokenizer.encode( 138 | role, allowed_special=set() 139 | ) + nl_tokens + tokenizer.encode(content, allowed_special=set()) 140 | 141 | system_text, system_tokens_part = _tokenize_str("system", system) 142 | system_tokens = im_start_tokens + system_tokens_part + im_end_tokens 143 | 144 | raw_text = "" 145 | context_tokens = [] 146 | 147 | for turn_query, turn_response in reversed(history): 148 | query_text, query_tokens_part = _tokenize_str("user", turn_query) 149 | query_tokens = im_start_tokens + query_tokens_part + im_end_tokens 150 | response_text, response_tokens_part = _tokenize_str( 151 | "assistant", turn_response 152 | ) 153 | response_tokens = im_start_tokens + response_tokens_part + im_end_tokens 154 | 155 | next_context_tokens = nl_tokens + query_tokens + nl_tokens + response_tokens 156 | prev_chat = ( 157 | f"\n{im_start}{query_text}{im_end}\n{im_start}{response_text}{im_end}" 158 | ) 159 | 160 | current_context_size = ( 161 | len(system_tokens) + len(next_context_tokens) + len(context_tokens) 162 | ) 163 | if current_context_size < max_window_size: 164 | context_tokens = next_context_tokens + context_tokens 165 | raw_text = prev_chat + raw_text 166 | else: 167 | break 168 | 169 | context_tokens = system_tokens + context_tokens 170 | raw_text = f"{im_start}{system_text}{im_end}" + raw_text 171 | context_tokens += ( 172 | nl_tokens 173 | + im_start_tokens 174 | + _tokenize_str("user", query)[1] 175 | + im_end_tokens 176 | + nl_tokens 177 | + im_start_tokens 178 | + tokenizer.encode("assistant") 179 | + nl_tokens 180 | ) 181 | raw_text += f"\n{im_start}user\n{query}{im_end}\n{im_start}assistant\n" 182 | 183 | elif chat_format == "raw": 184 | raw_text = query 185 | context_tokens = tokenizer.encode(raw_text) 186 | else: 187 | raise NotImplementedError(f"Unknown chat 
format {chat_format!r}") 188 | 189 | return raw_text, context_tokens 190 | 191 | 192 | def _decode_default( 193 | tokens: List[int], 194 | *, 195 | stop_words: List[str], 196 | eod_words: List[str], 197 | tokenizer: PreTrainedTokenizer, 198 | raw_text_len: int, 199 | verbose: bool = False, 200 | return_end_reason: bool = False, 201 | errors: str='replace', 202 | ): 203 | trim_decode_tokens = tokenizer.decode(tokens, errors=errors)[raw_text_len:] 204 | if verbose: 205 | print("\nRaw Generate: ", trim_decode_tokens) 206 | 207 | end_reason = f"Gen length {len(tokens)}" 208 | for stop_word in stop_words: 209 | trim_decode_tokens = trim_decode_tokens.replace(stop_word, "").strip() 210 | for eod_word in eod_words: 211 | if eod_word in trim_decode_tokens: 212 | end_reason = f"Gen {eod_word!r}" 213 | trim_decode_tokens = trim_decode_tokens.split(eod_word)[0] 214 | trim_decode_tokens = trim_decode_tokens.strip() 215 | if verbose: 216 | print("\nEnd Reason:", end_reason) 217 | print("\nGenerate: ", trim_decode_tokens) 218 | 219 | if return_end_reason: 220 | return trim_decode_tokens, end_reason 221 | else: 222 | return trim_decode_tokens 223 | 224 | 225 | def _decode_chatml( 226 | tokens: List[int], 227 | *, 228 | stop_words: List[str], 229 | eod_token_ids: List[int], 230 | tokenizer: PreTrainedTokenizer, 231 | raw_text_len: int, 232 | context_length: int, 233 | verbose: bool = False, 234 | return_end_reason: bool = False, 235 | errors: str='replace' 236 | ): 237 | end_reason = f"Gen length {len(tokens)}" 238 | eod_token_idx = context_length 239 | for eod_token_idx in range(context_length, len(tokens)): 240 | if tokens[eod_token_idx] in eod_token_ids: 241 | end_reason = f"Gen {tokenizer.decode([tokens[eod_token_idx]])!r}" 242 | break 243 | 244 | trim_decode_tokens = tokenizer.decode(tokens[:eod_token_idx], errors=errors)[raw_text_len:] 245 | if verbose: 246 | print("\nRaw Generate w/o EOD:", tokenizer.decode(tokens, errors=errors)[raw_text_len:]) 247 | print("\nRaw Generate:", trim_decode_tokens) 248 | print("\nEnd Reason:", end_reason) 249 | for stop_word in stop_words: 250 | trim_decode_tokens = trim_decode_tokens.replace(stop_word, "").strip() 251 | trim_decode_tokens = trim_decode_tokens.strip() 252 | if verbose: 253 | print("\nGenerate:", trim_decode_tokens) 254 | 255 | if return_end_reason: 256 | return trim_decode_tokens, end_reason 257 | else: 258 | return trim_decode_tokens 259 | 260 | 261 | def decode_tokens( 262 | tokens: Union[torch.LongTensor, TokensType], 263 | tokenizer: PreTrainedTokenizer, 264 | raw_text_len: int, 265 | context_length: int, 266 | chat_format: str, 267 | verbose: bool = False, 268 | return_end_reason: bool = False, 269 | errors: str="replace", 270 | ) -> str: 271 | if torch.is_tensor(tokens): 272 | tokens = tokens.cpu().numpy().tolist() 273 | 274 | if chat_format == "chatml": 275 | return _decode_chatml( 276 | tokens, 277 | stop_words=[], 278 | eod_token_ids=[tokenizer.im_start_id, tokenizer.im_end_id], 279 | tokenizer=tokenizer, 280 | raw_text_len=raw_text_len, 281 | context_length=context_length, 282 | verbose=verbose, 283 | return_end_reason=return_end_reason, 284 | errors=errors, 285 | ) 286 | elif chat_format == "raw": 287 | return _decode_default( 288 | tokens, 289 | stop_words=["<|endoftext|>"], 290 | eod_words=["<|endoftext|>"], 291 | tokenizer=tokenizer, 292 | raw_text_len=raw_text_len, 293 | verbose=verbose, 294 | return_end_reason=return_end_reason, 295 | errors=errors, 296 | ) 297 | else: 298 | raise NotImplementedError(f"Unknown chat format 
{chat_format!r}") 299 | 300 | 301 | class StopWordsLogitsProcessor(LogitsProcessor): 302 | """ 303 | :class:`transformers.LogitsProcessor` that enforces that when specified sequences appear, stop geration. 304 | 305 | Args: 306 | stop_words_ids (:obj:`List[List[int]]`): 307 | List of list of token ids of stop ids. In order to get the tokens of the words 308 | that should not appear in the generated text, use :obj:`tokenizer(bad_word, 309 | add_prefix_space=True).input_ids`. 310 | eos_token_id (:obj:`int`): 311 | The id of the `end-of-sequence` token. 312 | """ 313 | 314 | def __init__(self, stop_words_ids: Iterable[Iterable[int]], eos_token_id: int): 315 | 316 | if not isinstance(stop_words_ids, List) or len(stop_words_ids) == 0: 317 | raise ValueError( 318 | f"`stop_words_ids` has to be a non-emtpy list, but is {stop_words_ids}." 319 | ) 320 | if any(not isinstance(bad_word_ids, list) for bad_word_ids in stop_words_ids): 321 | raise ValueError( 322 | f"`stop_words_ids` has to be a list of lists, but is {stop_words_ids}." 323 | ) 324 | if any( 325 | any( 326 | (not isinstance(token_id, (int, np.integer)) or token_id < 0) 327 | for token_id in stop_word_ids 328 | ) 329 | for stop_word_ids in stop_words_ids 330 | ): 331 | raise ValueError( 332 | f"Each list in `stop_words_ids` has to be a list of positive integers, but is {stop_words_ids}." 333 | ) 334 | 335 | self.stop_words_ids = list( 336 | filter( 337 | lambda bad_token_seq: bad_token_seq != [eos_token_id], stop_words_ids 338 | ) 339 | ) 340 | self.eos_token_id = eos_token_id 341 | for stop_token_seq in self.stop_words_ids: 342 | assert ( 343 | len(stop_token_seq) > 0 344 | ), "Stop words token sequences {} cannot have an empty list".format( 345 | stop_words_ids 346 | ) 347 | 348 | def __call__( 349 | self, input_ids: torch.LongTensor, scores: torch.FloatTensor 350 | ) -> torch.FloatTensor: 351 | stopped_samples = self._calc_stopped_samples(input_ids) 352 | for i, should_stop in enumerate(stopped_samples): 353 | if should_stop: 354 | scores[i, self.eos_token_id] = float(2**15) 355 | return scores 356 | 357 | def _tokens_match(self, prev_tokens: torch.LongTensor, tokens: List[int]) -> bool: 358 | if len(tokens) == 0: 359 | # if bad word tokens is just one token always ban it 360 | return True 361 | elif len(tokens) > len(prev_tokens): 362 | # if bad word tokens are longer then prev input_ids they can't be equal 363 | return False 364 | elif prev_tokens[-len(tokens) :].tolist() == tokens: 365 | # if tokens match 366 | return True 367 | else: 368 | return False 369 | 370 | def _calc_stopped_samples(self, prev_input_ids: Iterable[int]) -> Iterable[int]: 371 | stopped_samples = [] 372 | for prev_input_ids_slice in prev_input_ids: 373 | match = False 374 | for stop_token_seq in self.stop_words_ids: 375 | if self._tokens_match(prev_input_ids_slice, stop_token_seq): 376 | # if tokens do not match continue 377 | match = True 378 | break 379 | stopped_samples.append(match) 380 | 381 | return stopped_samples 382 | 383 | 384 | def top_k_logits(logits, top_k=0, top_p=0.0, filter_value=-float("Inf")): 385 | """This function has been mostly taken from huggingface conversational 386 | ai code at 387 | https://medium.com/huggingface/how-to-build-a-state-of-the-art- 388 | conversational-ai-with-transfer-learning-2d818ac26313""" 389 | 390 | if top_k > 0: 391 | # Remove all tokens with a probability less than the 392 | # last token of the top-k 393 | indices_to_remove = logits < torch.topk(logits, top_k)[0][..., -1, None] 394 | logits[indices_to_remove] = 
filter_value 395 | 396 | if top_p > 0.0: 397 | # Cconvert to 1D 398 | sorted_logits, sorted_indices = torch.sort(logits, descending=True, dim=-1) 399 | cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1) 400 | 401 | # Remove tokens with cumulative probability above the threshold 402 | sorted_indices_to_remove = cumulative_probs > top_p 403 | # Shift the indices to the right to keep also the first token 404 | # above the threshold 405 | sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone() 406 | sorted_indices_to_remove[..., 0] = 0 407 | for i in range(sorted_indices.size(0)): 408 | indices_to_remove = sorted_indices[i][sorted_indices_to_remove[i]] 409 | logits[i][indices_to_remove] = filter_value 410 | 411 | return logits 412 | 413 | 414 | def switch(val1, val2, boolean): 415 | boolean = boolean.type_as(val1) 416 | return (1 - boolean) * val1 + boolean * val2 417 | -------------------------------------------------------------------------------- /qwen/tokenization_qwen.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Alibaba Cloud. 2 | # 3 | # This source code is licensed under the license found in the 4 | # LICENSE file in the root directory of this source tree. 5 | 6 | """Tokenization classes for QWen.""" 7 | 8 | import base64 9 | import logging 10 | import os 11 | import unicodedata 12 | from typing import Collection, Dict, List, Set, Tuple, Union 13 | 14 | import tiktoken 15 | from transformers import PreTrainedTokenizer, AddedToken 16 | 17 | logger = logging.getLogger(__name__) 18 | 19 | 20 | VOCAB_FILES_NAMES = {"vocab_file": "qwen.tiktoken"} 21 | 22 | PAT_STR = r"""(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+""" 23 | ENDOFTEXT = "<|endoftext|>" 24 | IMSTART = "<|im_start|>" 25 | IMEND = "<|im_end|>" 26 | # as the default behavior is changed to allow special tokens in 27 | # regular texts, the surface forms of special tokens need to be 28 | # as different as possible to minimize the impact 29 | EXTRAS = tuple((f"<|extra_{i}|>" for i in range(205))) 30 | # changed to use actual index to avoid misconfiguration with vocabulary expansion 31 | SPECIAL_START_ID = 151643 32 | SPECIAL_TOKENS = tuple( 33 | enumerate( 34 | ( 35 | ( 36 | ENDOFTEXT, 37 | IMSTART, 38 | IMEND, 39 | ) 40 | + EXTRAS 41 | ), 42 | start=SPECIAL_START_ID, 43 | ) 44 | ) 45 | SPECIAL_TOKENS_SET = set(t for i, t in SPECIAL_TOKENS) 46 | 47 | 48 | def _load_tiktoken_bpe(tiktoken_bpe_file: str) -> Dict[bytes, int]: 49 | with open(tiktoken_bpe_file, "rb") as f: 50 | contents = f.read() 51 | return { 52 | base64.b64decode(token): int(rank) 53 | for token, rank in (line.split() for line in contents.splitlines() if line) 54 | } 55 | 56 | 57 | class QWenTokenizer(PreTrainedTokenizer): 58 | """QWen tokenizer.""" 59 | 60 | vocab_files_names = VOCAB_FILES_NAMES 61 | 62 | def __init__( 63 | self, 64 | vocab_file, 65 | errors="replace", 66 | extra_vocab_file=None, 67 | **kwargs, 68 | ): 69 | super().__init__(**kwargs) 70 | 71 | # how to handle errors in decoding UTF-8 byte sequences 72 | # use ignore if you are in streaming inference 73 | self.errors = errors 74 | 75 | self.mergeable_ranks = _load_tiktoken_bpe(vocab_file) # type: Dict[bytes, int] 76 | self.special_tokens = { 77 | token: index 78 | for index, token in SPECIAL_TOKENS 79 | } 80 | 81 | # try load extra vocab from file 82 | if extra_vocab_file is not None: 83 | used_ids = set(self.mergeable_ranks.values()) | 
set(self.special_tokens.values()) 84 | extra_mergeable_ranks = _load_tiktoken_bpe(extra_vocab_file) 85 | for token, index in extra_mergeable_ranks.items(): 86 | if token in self.mergeable_ranks: 87 | logger.info(f"extra token {token} exists, skipping") 88 | continue 89 | if index in used_ids: 90 | logger.info(f'the index {index} for extra token {token} exists, skipping') 91 | continue 92 | self.mergeable_ranks[token] = index 93 | # the index may be sparse after this, but don't worry tiktoken.Encoding will handle this 94 | 95 | enc = tiktoken.Encoding( 96 | "Qwen", 97 | pat_str=PAT_STR, 98 | mergeable_ranks=self.mergeable_ranks, 99 | special_tokens=self.special_tokens, 100 | ) 101 | assert ( 102 | len(self.mergeable_ranks) + len(self.special_tokens) == enc.n_vocab 103 | ), f"{len(self.mergeable_ranks) + len(self.special_tokens)} != {enc.n_vocab} in encoding" 104 | 105 | self.decoder = { 106 | v: k for k, v in self.mergeable_ranks.items() 107 | } # type: dict[int, bytes|str] 108 | self.decoder.update({v: k for k, v in self.special_tokens.items()}) 109 | 110 | self.tokenizer = enc # type: tiktoken.Encoding 111 | 112 | self.eod_id = self.tokenizer.eot_token 113 | self.im_start_id = self.special_tokens[IMSTART] 114 | self.im_end_id = self.special_tokens[IMEND] 115 | 116 | def __getstate__(self): 117 | # for pickle lovers 118 | state = self.__dict__.copy() 119 | del state["tokenizer"] 120 | return state 121 | 122 | def __setstate__(self, state): 123 | # tokenizer is not python native; don't pass it; rebuild it 124 | self.__dict__.update(state) 125 | enc = tiktoken.Encoding( 126 | "Qwen", 127 | pat_str=PAT_STR, 128 | mergeable_ranks=self.mergeable_ranks, 129 | special_tokens=self.special_tokens, 130 | ) 131 | self.tokenizer = enc 132 | 133 | def __len__(self) -> int: 134 | return self.tokenizer.n_vocab 135 | 136 | def get_vocab(self) -> Dict[bytes, int]: 137 | return self.mergeable_ranks 138 | 139 | def convert_tokens_to_ids( 140 | self, tokens: Union[bytes, str, List[Union[bytes, str]]] 141 | ) -> List[int]: 142 | ids = [] 143 | if isinstance(tokens, (str, bytes)): 144 | if tokens in self.special_tokens: 145 | return self.special_tokens[tokens] 146 | else: 147 | return self.mergeable_ranks.get(tokens) 148 | for token in tokens: 149 | if token in self.special_tokens: 150 | ids.append(self.special_tokens[token]) 151 | else: 152 | ids.append(self.mergeable_ranks.get(token)) 153 | return ids 154 | 155 | def _add_tokens( 156 | self, 157 | new_tokens: Union[List[str], List[AddedToken]], 158 | special_tokens: bool = False, 159 | ) -> int: 160 | if not special_tokens and new_tokens: 161 | raise ValueError("Adding regular tokens is not supported") 162 | for token in new_tokens: 163 | surface_form = token.content if isinstance(token, AddedToken) else token 164 | if surface_form not in SPECIAL_TOKENS_SET: 165 | raise ValueError("Adding unknown special tokens is not supported") 166 | return 0 167 | 168 | def save_vocabulary(self, save_directory: str, **kwargs) -> Tuple[str]: 169 | """ 170 | Save only the vocabulary of the tokenizer (vocabulary). 171 | 172 | Returns: 173 | `Tuple(str)`: Paths to the files saved. 
174 | """ 175 | file_path = os.path.join(save_directory, "qwen.tiktoken") 176 | with open(file_path, "w", encoding="utf8") as w: 177 | for k, v in self.mergeable_ranks.items(): 178 | line = base64.b64encode(k).decode("utf8") + " " + str(v) + "\n" 179 | w.write(line) 180 | return (file_path,) 181 | 182 | def tokenize( 183 | self, 184 | text: str, 185 | allowed_special: Union[Set, str] = "all", 186 | disallowed_special: Union[Collection, str] = (), 187 | **kwargs, 188 | ) -> List[Union[bytes, str]]: 189 | """ 190 | Converts a string in a sequence of tokens. 191 | 192 | Args: 193 | text (`str`): 194 | The sequence to be encoded. 195 | allowed_special (`Literal["all"]` or `set`): 196 | The surface forms of the tokens to be encoded as special tokens in regular texts. 197 | Default to "all". 198 | disallowed_special (`Literal["all"]` or `Collection`): 199 | The surface forms of the tokens that should not be in regular texts and trigger errors. 200 | Default to an empty tuple. 201 | 202 | kwargs (additional keyword arguments, *optional*): 203 | Will be passed to the underlying model specific encode method. 204 | 205 | Returns: 206 | `List[bytes|str]`: The list of tokens. 207 | """ 208 | tokens = [] 209 | text = unicodedata.normalize("NFC", text) 210 | 211 | # this implementation takes a detour: text -> token id -> token surface forms 212 | for t in self.tokenizer.encode( 213 | text, allowed_special=allowed_special, disallowed_special=disallowed_special 214 | ): 215 | tokens.append(self.decoder[t]) 216 | return tokens 217 | 218 | def convert_tokens_to_string(self, tokens: List[Union[bytes, str]]) -> str: 219 | """ 220 | Converts a sequence of tokens in a single string. 221 | """ 222 | text = "" 223 | temp = b"" 224 | for t in tokens: 225 | if isinstance(t, str): 226 | if temp: 227 | text += temp.decode("utf-8", errors=self.errors) 228 | temp = b"" 229 | text += t 230 | elif isinstance(t, bytes): 231 | temp += t 232 | else: 233 | raise TypeError("token should only be of type types or str") 234 | if temp: 235 | text += temp.decode("utf-8", errors=self.errors) 236 | return text 237 | 238 | @property 239 | def vocab_size(self): 240 | return self.tokenizer.n_vocab 241 | 242 | def _convert_id_to_token(self, index: int) -> Union[bytes, str]: 243 | """Converts an id to a token, special tokens included""" 244 | if index in self.decoder: 245 | return self.decoder[index] 246 | raise ValueError("unknown ids") 247 | 248 | def _convert_token_to_id(self, token: Union[bytes, str]) -> int: 249 | """Converts a token to an id using the vocab, special tokens included""" 250 | if token in self.special_tokens: 251 | return self.special_tokens[token] 252 | if token in self.mergeable_ranks: 253 | return self.mergeable_ranks[token] 254 | raise ValueError("unknown token") 255 | 256 | def _tokenize(self, text: str, **kwargs): 257 | """ 258 | Converts a string in a sequence of tokens (string), using the tokenizer. Split in words for word-based 259 | vocabulary or sub-words for sub-word-based vocabularies (BPE/SentencePieces/WordPieces). 260 | 261 | Do NOT take care of added tokens. 
262 | """ 263 | raise NotImplementedError 264 | 265 | def _decode( 266 | self, 267 | token_ids: Union[int, List[int]], 268 | skip_special_tokens: bool = False, 269 | errors: str = None, 270 | **kwargs, 271 | ) -> str: 272 | if isinstance(token_ids, int): 273 | token_ids = [token_ids] 274 | if skip_special_tokens: 275 | token_ids = [i for i in token_ids if i < self.eod_id] 276 | return self.tokenizer.decode(token_ids, errors=errors or self.errors) 277 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | accelerate==0.30.1 2 | annotated-types==0.7.0 3 | certifi==2022.12.7 4 | charset-normalizer==2.1.1 5 | cmake==3.25.0 6 | deepspeed==0.14.2 7 | einops==0.8.0 8 | filelock==3.13.1 9 | fsspec==2024.2.0 10 | hjson==3.1.0 11 | huggingface-hub==0.23.2 12 | idna==3.4 13 | Jinja2==3.1.3 14 | lit==15.0.7 15 | MarkupSafe==2.1.5 16 | mpmath==1.3.0 17 | networkx==3.2.1 18 | ninja==1.11.1.1 19 | numpy==1.26.3 20 | packaging==24.0 21 | peft==0.11.1 22 | pillow==10.2.0 23 | psutil==5.9.8 24 | py-cpuinfo==9.0.0 25 | pydantic==2.7.3 26 | pydantic_core==2.18.4 27 | pynvml==11.5.0 28 | PyYAML==6.0.1 29 | regex==2024.5.15 30 | requests==2.28.1 31 | safetensors==0.4.3 32 | sympy==1.12 33 | tiktoken==0.7.0 34 | tokenizers==0.19.1 35 | torch==2.0.1+cu118 36 | torchaudio==2.0.2+cu118 37 | torchvision==0.15.2+cu118 38 | tqdm==4.66.4 39 | transformers==4.41.2 40 | triton==2.0.0 41 | typing_extensions==4.9.0 42 | urllib3==1.26.13 43 | -------------------------------------------------------------------------------- /test.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import torch 3 | from transformers import AutoTokenizer, SiglipProcessor 4 | from torchvision import transforms 5 | from PIL import Image 6 | 7 | from model.model import MMultiModal, LanguageConfig, VisualConfig, MultiModalConfig 8 | from qwen.qwen_generation_utils import make_context 9 | 10 | 11 | def image_process(image): 12 | mean=[0.485, 0.456, 0.406] # RGB 13 | std=[0.229, 0.224, 0.225] # RGB 14 | 15 | tran = transforms.Compose([ 16 | transforms.ToTensor(), 17 | transforms.Normalize(mean, std), 18 | transforms.Resize([384, 384]) 19 | ]) 20 | 21 | return tran(image) 22 | 23 | def main(args): 24 | tokenizer = AutoTokenizer.from_pretrained(args.base_language_model, trust_remote_code=True) 25 | replace_token_id = tokenizer.convert_tokens_to_ids("<|extra_0|>") 26 | 27 | model = MMultiModal(LanguageConfig(model_path=args.base_language_model), 28 | VisualConfig(model_path=args.base_value_model), 29 | MultiModalConfig(replace_token_id=replace_token_id), 30 | train=False).cuda() 31 | model.load(args.model_weights) 32 | 33 | prompt = args.prompt 34 | 35 | image_processor = SiglipProcessor.from_pretrained(args.base_value_model) 36 | image = Image.open(args.image_path).convert("RGB") 37 | image_pt = image_processor(images=image, return_tensors="pt")["pixel_values"].cuda().to(torch.bfloat16) 38 | # image_pt = image_process(image).unsqueeze(0).cuda().to(torch.bfloat16) 39 | # print(image_pt1.shape, image_pt.shape) 40 | messages = [{"role": "system", "content": "你是一位图像理解助手。"}, {"role": "user", "content": "用中文回答:"+prompt}] 41 | raw_text, context_tokens = make_context( 42 | tokenizer, 43 | "用中文回答:"+prompt, 44 | history=[], 45 | system="你是一位图像理解助手。" 46 | ) 47 | question_ids = tokenizer.encode(raw_text) 48 | 49 | result = model.generate(image_pt, question_ids) 50 | result = 
tokenizer.decode(result[0]) 51 | print(result) 52 | 53 | 54 | if __name__ == "__main__": 55 | parser = argparse.ArgumentParser(description="Image and Text Processing with MultiModal Model") 56 | parser.add_argument("--base_language_model", type=str, required=True, help="Path to the base language model") 57 | parser.add_argument("--base_value_model", type=str, required=True, help="Path to the base value model") 58 | parser.add_argument("--model_weights", type=str, required=True, help="Path to the model weights") 59 | parser.add_argument("--image_path", type=str, required=True, help="Path to the input image") 60 | parser.add_argument("--prompt", type=str, required=True, help="Prompt for the model") 61 | 62 | args = parser.parse_args() 63 | main(args) 64 | -------------------------------------------------------------------------------- /test.sh: -------------------------------------------------------------------------------- 1 | python test.py --base_language_model Qwen/Qwen-7B-Chat --base_value_model google/siglip-so400m-patch14-384 --model_weights ./weights/train_V1_5/checkpoint-36000/ --image_path ./test_img/1.jpg --prompt "使用语言描述一下这幅图<|extra_0|>" 2 | -------------------------------------------------------------------------------- /test_img/1.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xinyanghuang7/Basic-Visual-Language-Model/559d07ca75a3751c427d7990fe0b598de4b68f14/test_img/1.jpg -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | import os 2 | import json 3 | import torch 4 | from typing import Optional 5 | from functools import partial 6 | from peft import LoraConfig, TaskType, get_peft_model, PeftModel 7 | 8 | from trainer import MultiModalTrainer 9 | from model.model import MMultiModal, LanguageConfig, VisualConfig, MultiModalConfig 10 | from dataset.image_caption_dataset import ImageCaptionDataset, data_collate 11 | 12 | import transformers 13 | from transformers import HfArgumentParser, AutoTokenizer 14 | from dataclasses import dataclass, field 15 | 16 | from qwen.modeling_qwen import QWenLMHeadModel 17 | 18 | from einops import rearrange 19 | 20 | from accelerate import Accelerator 21 | 22 | @dataclass 23 | class FinetuneArguments: 24 | lora_rank: int = field(default=8) 25 | lora_dropout: float = field(default=0.1) 26 | previous_lora_weights: Optional[str] = field(default=None) 27 | target_modules: str = field(default="W_pack") 28 | image_map: str = field(default="data/image_map_b.json", metadata={"help": "图像文件与索引ID"}) 29 | captions_file: str = field(default="data/captions_b.json", metadata={"help": "ID与caption的对应"}) 30 | 31 | @dataclass 32 | class TrainingArguments(transformers.TrainingArguments): 33 | feature_proj_lr: Optional[float] = None 34 | 35 | def train(): 36 | finetune_args, training_args = HfArgumentParser( 37 | (FinetuneArguments, TrainingArguments) 38 | ).parse_args_into_dataclasses() 39 | 40 | base_language_model = "Qwen/Qwen-7B-Chat" 41 | # base_language_model = "openbmb/MiniCPM-2B-history" 42 | 43 | # base_value_model = "openai/clip-vit-large-patch14" 44 | base_value_model = "google/siglip-so400m-patch14-384" 45 | 46 | tokenizer = AutoTokenizer.from_pretrained(base_language_model, trust_remote_code=True) 47 | replace_token_id = tokenizer.convert_tokens_to_ids("<|extra_0|>") 48 | 49 | # Check file paths 50 | if not os.path.exists(finetune_args.image_map): 51 | raise 
FileNotFoundError(f"Image map file not found: {finetune_args.image_map}") 52 | 53 | if not os.path.exists(finetune_args.captions_file): 54 | raise FileNotFoundError(f"Captions file not found: {finetune_args.captions_file}") 55 | 56 | # Load and check file contents 57 | with open(finetune_args.image_map, 'r') as f: 58 | image_map = json.load(f) 59 | print(f"Image map contains {len(image_map)} entries") 60 | 61 | with open(finetune_args.captions_file, 'r') as f: 62 | captions = json.load(f) 63 | print(f"Captions file contains {len(captions)} entries") 64 | 65 | model = MMultiModal( 66 | LanguageConfig(model_path=base_language_model), 67 | VisualConfig(model_path=base_value_model), 68 | MultiModalConfig(replace_token_id=replace_token_id), 69 | finetune_args, 70 | train=True 71 | ).cuda() 72 | model.train() 73 | model.LLM.config.use_cache = False 74 | 75 | dataset = ImageCaptionDataset( 76 | tokenizer, 77 | finetune_args.image_map, 78 | finetune_args.captions_file, 79 | VisualConfig(model_path=base_value_model), 80 | max_train_data_item=300000 81 | ) 82 | 83 | # Add debug information 84 | print(f"Dataset length: {len(dataset)}") 85 | if len(dataset) == 0: 86 | raise ValueError("The dataset is empty. Please check the dataset files and paths.") 87 | 88 | print(training_args) 89 | 90 | # Initialize Accelerator 91 | accelerator = Accelerator() 92 | 93 | # Create DataLoader 94 | train_dataloader = torch.utils.data.DataLoader( 95 | dataset, 96 | batch_size=training_args.per_device_train_batch_size, 97 | shuffle=True, 98 | collate_fn=partial(data_collate, tokenizer=tokenizer, black_token_length=MultiModalConfig.image_context_length) 99 | ) 100 | 101 | trainer = MultiModalTrainer( 102 | model=model, 103 | data_collator=partial(data_collate, tokenizer=tokenizer, black_token_length=MultiModalConfig.image_context_length), 104 | train_dataset=dataset, 105 | args=training_args 106 | ) 107 | 108 | # Prepare the trainer and dataloader with the accelerator 109 | trainer, train_dataloader = accelerator.prepare(trainer, train_dataloader) 110 | 111 | trainer.train() 112 | 113 | def main(): 114 | torch.distributed.init_process_group(backend='nccl') 115 | train() 116 | torch.distributed.destroy_process_group() 117 | 118 | if __name__ == "__main__": 119 | main() -------------------------------------------------------------------------------- /train.sh: -------------------------------------------------------------------------------- 1 | CUDA_VISIBLE_DEVICES=0 torchrun --nproc_per_node=1 --master_port=25642 train.py \ 2 | --lora_rank 128 \ 3 | --lora_dropout 0.10 \ 4 | --per_device_train_batch_size 4 \ 5 | --gradient_accumulation_steps 1 \ 6 | --num_train_epochs 2 \ 7 | --save_steps 1000 \ 8 | --save_total_limit 5 \ 9 | --learning_rate 3e-5 \ 10 | --seed 42 \ 11 | --ddp_find_unused_parameters False \ 12 | --feature_proj_lr 1e-4 \ 13 | --remove_unused_columns false \ 14 | --logging_steps 100 \ 15 | --output_dir ./weights/train_V1_5 \ 16 | --target_modules "c_attn|w1|w2" \ 17 | --image_map /home/u2023111315/Basic-Vision-Language-Model/data/image_map_b.json \ 18 | --captions_file /home/u2023111315/Basic-Vision-Language-Model/data/captions_b.json 19 | -------------------------------------------------------------------------------- /trainer.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from transformers import Trainer 3 | from transformers.trainer import ( 4 | is_sagemaker_mp_enabled, 5 | get_parameter_names, 6 | has_length, 7 | ALL_LAYERNORM_LAYERS, 8 | logger, 
9 | ) 10 | import os 11 | from peft import get_peft_model_state_dict 12 | 13 | class MultiModalTrainer(Trainer): 14 | def compute_loss(self, model, inputs, return_outputs=False): 15 | return model( 16 | image=inputs["images"], 17 | input_ids=inputs["input_ids"], 18 | labels=inputs["labels"], 19 | ).loss 20 | 21 | def save_model(self, output_dir=None, _internal_call=False): 22 | from transformers.trainer import TRAINING_ARGS_NAME 23 | 24 | # Ensure output_dir is not None 25 | if output_dir is None: 26 | output_dir = self.args.output_dir 27 | 28 | # Create the output directory if it doesn't exist 29 | os.makedirs(output_dir, exist_ok=True) 30 | 31 | # Save training arguments 32 | torch.save(self.args, os.path.join(output_dir, TRAINING_ARGS_NAME)) 33 | 34 | # Access the original model 35 | model = self.model.module if hasattr(self.model, 'module') else self.model 36 | 37 | # Save LLM parameters 38 | saved_params_LLM = get_peft_model_state_dict(model.LLM) 39 | torch.save(saved_params_LLM, os.path.join(output_dir, "adapter_model.bin")) 40 | 41 | # Save other parameters 42 | saved_params_other = model.feature_proj.state_dict() 43 | torch.save(saved_params_other, os.path.join(output_dir, "other_params.bin")) 44 | 45 | # Save configuration 46 | config = model.LLM.peft_config 47 | selected_adapters = list(config.keys()) 48 | config[selected_adapters[0]].save_pretrained(output_dir, auto_mapping_dict=None) 49 | 50 | def create_optimizer(self): 51 | if is_sagemaker_mp_enabled(): 52 | return super().create_optimizer() 53 | 54 | opt_model = self.model 55 | 56 | if self.optimizer is None: 57 | decay_parameters = get_parameter_names(opt_model, ALL_LAYERNORM_LAYERS) 58 | decay_parameters = [name for name in decay_parameters if "bias" not in name] 59 | if self.args.feature_proj_lr is not None: 60 | projector_parameters = [name for name, _ in opt_model.named_parameters() if "feature_proj" in name] 61 | optimizer_grouped_parameters = [ 62 | { 63 | "params": [ 64 | p for n, p in opt_model.named_parameters() if (n in decay_parameters and n not in projector_parameters and p.requires_grad) 65 | ], 66 | "weight_decay": self.args.weight_decay, 67 | }, 68 | { 69 | "params": [ 70 | p for n, p in opt_model.named_parameters() if (n not in decay_parameters and n not in projector_parameters and p.requires_grad) 71 | ], 72 | "weight_decay": 0.0, 73 | }, 74 | { 75 | "params": [ 76 | p for n, p in opt_model.named_parameters() if (n in decay_parameters and n in projector_parameters and p.requires_grad) 77 | ], 78 | "weight_decay": self.args.weight_decay, 79 | "lr": self.args.feature_proj_lr, 80 | }, 81 | { 82 | "params": [ 83 | p for n, p in opt_model.named_parameters() if (n not in decay_parameters and n in projector_parameters and p.requires_grad) 84 | ], 85 | "weight_decay": 0.0, 86 | "lr": self.args.feature_proj_lr, 87 | }, 88 | ] 89 | else: 90 | optimizer_grouped_parameters = [ 91 | { 92 | "params": [ 93 | p for n, p in opt_model.named_parameters() if (n in decay_parameters and p.requires_grad) 94 | ], 95 | "weight_decay": self.args.weight_decay, 96 | }, 97 | { 98 | "params": [ 99 | p for n, p in opt_model.named_parameters() if (n not in decay_parameters and p.requires_grad) 100 | ], 101 | "weight_decay": 0.0, 102 | }, 103 | ] 104 | 105 | optimizer_cls, optimizer_kwargs = Trainer.get_optimizer_cls_and_kwargs(self.args) 106 | 107 | self.optimizer = optimizer_cls(optimizer_grouped_parameters, **optimizer_kwargs) 108 | 109 | return self.optimizer 110 | 111 | def create_optimizer_and_scheduler(self, num_training_steps: 
int): 112 | super().create_optimizer_and_scheduler(num_training_steps) 113 | if self.args.local_rank != -1: 114 | self.model = torch.nn.parallel.DistributedDataParallel( 115 | self.model, 116 | device_ids=[self.args.local_rank], 117 | output_device=self.args.local_rank, 118 | find_unused_parameters=True 119 | ) -------------------------------------------------------------------------------- /visual/CLIP_VIT.py: -------------------------------------------------------------------------------- 1 | import transformers 2 | import torch 3 | from typing import Optional 4 | from torch import nn 5 | from PIL import Image 6 | from transformers import CLIPModel, CLIPConfig, CLIPProcessor 7 | from transformers.utils import add_start_docstrings_to_model_forward 8 | from transformers.models.clip.modeling_clip import CLIP_VISION_INPUTS_DOCSTRING 9 | 10 | class visualModel(CLIPModel): 11 | def __init__(self, config: CLIPConfig): 12 | super().__init__(config) 13 | 14 | @add_start_docstrings_to_model_forward(CLIP_VISION_INPUTS_DOCSTRING) 15 | def get_image_features( 16 | self, 17 | pixel_values: Optional[torch.FloatTensor] = None, 18 | output_attentions: Optional[bool] = None, 19 | output_hidden_states: Optional[bool] = None, 20 | return_dict: Optional[bool] = None, 21 | ) -> torch.FloatTensor: 22 | 23 | # Use CLIP model's config for some fields (if specified) instead of those of vision & text components. 24 | output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions 25 | output_hidden_states = ( 26 | output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states 27 | ) 28 | return_dict = return_dict if return_dict is not None else self.config.use_return_dict 29 | 30 | vision_outputs = self.vision_model( 31 | pixel_values=pixel_values, 32 | output_attentions=output_attentions, 33 | output_hidden_states=output_hidden_states, 34 | return_dict=return_dict, 35 | ) 36 | 37 | pooled_output = vision_outputs.last_hidden_state # pooled_output 38 | # print(pooled_output.shape) 39 | return pooled_output 40 | 41 | 42 | def main(): 43 | modle_path = "F:/huggingface_model/clip-vit-large-patch14" 44 | model = visualModel.from_pretrained(modle_path) 45 | processor = CLIPProcessor.from_pretrained(modle_path) 46 | test_img = Image.open("D:/code/multimodal/data/000000391895.jpg") 47 | P_input = processor(images=test_img, return_tensors="pt") 48 | print(model.get_image_features(**P_input).shape) 49 | 50 | if __name__ == "__main__": 51 | main() 52 | -------------------------------------------------------------------------------- /visual/SIGLIP_VIT.py: -------------------------------------------------------------------------------- 1 | import transformers 2 | import torch 3 | from typing import Optional 4 | from torch import nn 5 | from PIL import Image 6 | from transformers import SiglipModel, SiglipConfig, SiglipProcessor 7 | from transformers.utils import add_start_docstrings_to_model_forward 8 | from transformers.models.siglip.modeling_siglip import SIGLIP_VISION_INPUTS_DOCSTRING 9 | 10 | class visualModel(SiglipModel): 11 | def __init__(self, config: SiglipConfig): 12 | super().__init__(config) 13 | vision_config = config.vision_config 14 | self.vision_embed_dim = vision_config.hidden_size 15 | 16 | @add_start_docstrings_to_model_forward(SIGLIP_VISION_INPUTS_DOCSTRING) 17 | def get_image_features( 18 | self, 19 | pixel_values: Optional[torch.FloatTensor] = None, 20 | output_attentions: Optional[bool] = None, 21 | output_hidden_states: 
Optional[bool] = None, 22 | return_dict: Optional[bool] = None, 23 | ) -> torch.FloatTensor: 24 | 25 | output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions 26 | output_hidden_states = ( 27 | output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states 28 | ) 29 | return_dict = return_dict if return_dict is not None else self.config.use_return_dict 30 | 31 | # 确保 pixel_values 的数据类型为 bfloat16 32 | pixel_values = pixel_values.to(dtype=torch.bfloat16) 33 | 34 | vision_outputs = self.vision_model( 35 | pixel_values=pixel_values, 36 | output_attentions=output_attentions, 37 | output_hidden_states=output_hidden_states, 38 | return_dict=return_dict, 39 | ) 40 | 41 | pooled_output = vision_outputs.last_hidden_state # pooled_output 42 | return pooled_output -------------------------------------------------------------------------------- /webUI.py: -------------------------------------------------------------------------------- 1 | import json 2 | import gradio as gr 3 | import torch 4 | from transformers import AutoTokenizer, ChineseCLIPProcessor 5 | from torchvision import transforms 6 | from PIL import Image 7 | 8 | from model.model import MMultiModal, LanguageConfig, VisualConfig, MultiModalConfig 9 | 10 | base_language_model = "F:/huggingface_model/qwen/Qwen-7B-chat/" 11 | base_value_model = "F:/huggingface_model/clip-vit-large-patch14" 12 | 13 | tokenizer = AutoTokenizer.from_pretrained(base_language_model, trust_remote_code=True) 14 | replace_token_id = tokenizer.convert_tokens_to_ids("<|extra_0|>") 15 | 16 | model = MMultiModal(LanguageConfig(model_path=base_language_model), VisualConfig(model_path=base_value_model), 17 | MultiModalConfig(replace_token_id=replace_token_id),train=False).cuda() 18 | model.load("./weights/train_V1_5/checkpoint-18000/") 19 | 20 | def image_process(image): 21 | mean=[0.485, 0.456, 0.406] # RGB 22 | std=[0.229, 0.224, 0.225] # RGB 23 | 24 | tran = transforms.Compose([ 25 | transforms.ToTensor(), 26 | transforms.Normalize(mean, std), 27 | transforms.Resize([224, 224]) 28 | ]) 29 | 30 | return tran(image) 31 | 32 | def chat(image, messages): 33 | if image is None: 34 | image_pt = None 35 | else: 36 | image_pt = image_process(image).unsqueeze(0).cuda().to(torch.bfloat16) 37 | raw_text = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True) 38 | question_ids = tokenizer.encode(raw_text) 39 | result = model.generate(image_pt, question_ids)[0] 40 | result = tokenizer.decode(result) 41 | return result 42 | 43 | def chatbot_(input_text, chat_history, image): 44 | SP_token = "<|extra_0|>" 45 | send_history = [{"role": "system", "content": "你是一位图像理解助手。"}] 46 | for CH in chat_history: 47 | send_history.append({"role":"user", "content":CH[0]}) 48 | send_history.append({"role":"assistant", "content":CH[1]}) 49 | if image is not None: 50 | send_history.append({"role":"user", "content":"用中文回答:"+input_text+SP_token}) 51 | else: 52 | send_history.append({"role":"user", "content":"用中文回答:"+input_text}) 53 | bot_message = chat(image, send_history) 54 | chat_history.append((input_text, bot_message)) 55 | return "", chat_history 56 | 57 | def clear_history(): 58 | return "", [] 59 | 60 | 61 | with gr.Blocks() as demo: 62 | with gr.Row(): 63 | with gr.Column(): 64 | image = gr.Image(label="image") 65 | clear = gr.ClearButton() 66 | with gr.Column(): 67 | chatbot = gr.Chatbot() 68 | input_text = gr.Textbox() 69 | 70 | input_text.submit(chatbot_, [input_text, chatbot, 
image], [input_text, chatbot]) 71 | clear.click(clear_history, [], [input_text, chatbot]) 72 | 73 | def main(): 74 | demo.launch(server_port=23200) 75 | 76 | if __name__ == "__main__": 77 | main() 78 | --------------------------------------------------------------------------------