├── MozartsTouch ├── utils │ ├── __init__.py │ ├── MusicGenerator │ │ ├── __init__.py │ │ ├── music_gen.py │ │ └── suno_ai.py │ ├── normalize_volume.py │ ├── image_processing.py │ ├── music_generation.py │ ├── txt_converter.py │ └── preprocess_single.py ├── static │ ├── BONK.mp3 │ ├── Loss.jpg │ ├── test.jpg │ └── stone.mp4 ├── __init__.py ├── config.yaml ├── download_model.py └── main.py ├── requirements_for_server.txt ├── logo.png ├── architecture.png ├── requirements.txt ├── .gitignore ├── start_server.py ├── LICENSE ├── README.md └── backend_app.py /MozartsTouch/utils/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /MozartsTouch/utils/MusicGenerator/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /requirements_for_server.txt: -------------------------------------------------------------------------------- 1 | fastapi>=0.100.0 2 | uvicorn>=0.22.0 -------------------------------------------------------------------------------- /logo.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/TiffanyBlews/MozartsTouch/HEAD/logo.png -------------------------------------------------------------------------------- /architecture.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/TiffanyBlews/MozartsTouch/HEAD/architecture.png -------------------------------------------------------------------------------- /MozartsTouch/static/BONK.mp3: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/TiffanyBlews/MozartsTouch/HEAD/MozartsTouch/static/BONK.mp3 -------------------------------------------------------------------------------- /MozartsTouch/static/Loss.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/TiffanyBlews/MozartsTouch/HEAD/MozartsTouch/static/Loss.jpg -------------------------------------------------------------------------------- /MozartsTouch/static/test.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/TiffanyBlews/MozartsTouch/HEAD/MozartsTouch/static/test.jpg -------------------------------------------------------------------------------- /MozartsTouch/static/stone.mp4: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/TiffanyBlews/MozartsTouch/HEAD/MozartsTouch/static/stone.mp4 -------------------------------------------------------------------------------- /MozartsTouch/__init__.py: -------------------------------------------------------------------------------- 1 | from .main import img_to_music_generate, import_ir, import_music_generator, video_to_music_generate 2 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | Pillow>=10.0.0 2 | transformers>=4.31.0 3 | torch>=2.0.1 4 | openai 5 | einops 6 | decord 7 | moviepy 8 | loguru 9 | pyyaml -------------------------------------------------------------------------------- /.gitignore: 
-------------------------------------------------------------------------------- 1 | Diancai-Backend.code-workspace 2 | __pycache__ 3 | cache 4 | outputs 5 | *.json 6 | .vs 7 | model 8 | musicgen.wav 9 | videos 10 | config.yaml 11 | MozartsTouch/config.yaml 12 | backend_v2.py 13 | -------------------------------------------------------------------------------- /start_server.py: -------------------------------------------------------------------------------- 1 | ''' 2 | 请在运行前进行如下配置: 3 | 1.在MozartsTouch文件夹内设定config.yaml 4 | 2.确保电脑自身配置足够运行 5 | ''' 6 | 7 | import uvicorn 8 | 9 | if __name__ == '__main__': 10 | uvicorn.run("backend_app:app", host="0.0.0.0", port=3001, reload=True) -------------------------------------------------------------------------------- /MozartsTouch/config.yaml: -------------------------------------------------------------------------------- 1 | # When in TEST_MODE, no models will be loaded and only dummy data will be returned 2 | TEST_MODE: False 3 | 4 | # Change this to disable LLM for ablation study purpose 5 | USE_LLM: True 6 | 7 | DEFAULT_CAPTION_MODEL: Florence-2-large 8 | 9 | DEFAULT_LLM_MODEL: qwen2.5-instruct-32b-int4 10 | 11 | # Options:['test', 'musicgen-small', 'musicgen-medium', 'musicgen-large', 'suno'] 12 | DEFAULT_MUSIC_MODEL: musicgen-small 13 | 14 | 15 | CAPTION_MODEL_CONFIG: 16 | VIDEO_SAMPLE_AMOUNT: 20 17 | 18 | LLM_MODEL_CONFIG: 19 | # Platform type 20 | # Options:['xinference', 'ollama', 'oneapi', 'fastchat', 'openai', 'custom openai'] 21 | # This option is not implemented yet ¯\_(シ)_/¯ 22 | PLATFORM_TYPE: openai 23 | 24 | # OpenAI API URL 25 | API_BASE_URL: http://10.29.118.247:9997/v1 26 | 27 | # api key if available 28 | API_KEY: EMPTY 29 | 30 | API_PROXY: '' 31 | 32 | API_CONCURRENCIES: 1 33 | 34 | MUSIC_MODEL_CONFIG: 35 | API_BASE_URL: https://suno-api-psi-one.vercel.app/ 36 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2023 WangTN 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
22 | -------------------------------------------------------------------------------- /MozartsTouch/utils/normalize_volume.py: -------------------------------------------------------------------------------- 1 | from io import BytesIO 2 | import numpy as np 3 | from pydub import AudioSegment 4 | from pathlib import Path 5 | 6 | module_path = Path(__file__).resolve().parent.parent # module_path为模块根目录(`/MozartsTouch`) 7 | 8 | def normalize_volume(audio_path): 9 | # 读取BytesIO对象中的音频数据 10 | audio_segment = AudioSegment.from_file(audio_path, format="mp3") # 根据实际情况选择正确的格式 11 | 12 | # 将音频数据转换为numpy数组 13 | samples = np.array(audio_segment.get_array_of_samples()) 14 | 15 | # 计算音频数据的最大振幅 16 | max_amplitude = np.max(np.abs(samples)) 17 | print(max_amplitude) 18 | 19 | # 设置目标最大振幅 20 | target_amplitude = 32000.0 21 | 22 | # 计算缩放因子 23 | scale_factor = target_amplitude / max_amplitude 24 | print(scale_factor) 25 | # 缩放音频数据 26 | normalized_samples = (samples * scale_factor).astype(np.int16) 27 | 28 | # 创建新的音频段并返回 29 | normalized_audio = AudioSegment( 30 | normalized_samples.tobytes(), 31 | frame_rate=audio_segment.frame_rate, 32 | sample_width=audio_segment.sample_width, 33 | channels=audio_segment.channels 34 | ) 35 | 36 | normalized_audio.export(audio_path.parent / ("normalized_"+audio_path.name), format="mp3") 37 | return normalized_audio 38 | 39 | if __name__ == "__main__": 40 | normalized_audio = normalize_volume(module_path/"static"/ "BONK.mp3") 41 | # AudioSegment.from_file(module_path/"static"/ "BONK.mp3") 42 | 43 | -------------------------------------------------------------------------------- /MozartsTouch/utils/MusicGenerator/music_gen.py: -------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | import time 3 | from transformers import AutoProcessor, MusicgenForConditionalGeneration 4 | import scipy.io.wavfile 5 | import torch 6 | from io import BytesIO 7 | from loguru import logger 8 | 9 | device = 'cuda' if torch.cuda.is_available() else 'cpu' 10 | module_path = Path(__file__).resolve().parent.parent.parent 11 | 12 | class MusicGen: 13 | def __init__(self, model_name="musicgen_small") -> None: 14 | self.processor = AutoProcessor.from_pretrained(module_path / "model" / f"{model_name}_processor") 15 | self.model = MusicgenForConditionalGeneration.from_pretrained(module_path / "model" / f"{model_name}_model").to(device) 16 | self.sampling_rate = self.model.config.audio_encoder.sampling_rate 17 | 18 | def generate(self, text: str, music_duration: int) -> BytesIO: 19 | logger.info(f"Start generating music") 20 | start_time = time.time() 21 | 22 | inputs = self.processor( 23 | text=[text], 24 | padding=True, 25 | return_tensors="pt", 26 | ).to(device) 27 | audio_values = self.model.generate(**inputs, max_new_tokens=int(256 * music_duration // 5)) # music_duration为秒数,256token = 5s 28 | 29 | wav_file_data = BytesIO() 30 | scipy.io.wavfile.write(wav_file_data, rate=self.sampling_rate, data=audio_values[0, 0].cpu().numpy()) 31 | logger.info(f"[TIME] taken for txt2music: {time.time() - start_time :.2f}s") 32 | return wav_file_data 33 | 34 | if __name__ == "__main__": 35 | music_gen_small = MusicGen() 36 | output = music_gen_small.generate("cyberpunk electronic dancing music", 1) 37 | with open(module_path / 'music_gen_test.wav', 'wb') as f: 38 | f.write(output.getvalue()) -------------------------------------------------------------------------------- /MozartsTouch/download_model.py: 
-------------------------------------------------------------------------------- 1 | import os 2 | 3 | os.environ['HF_ENDPOINT'] = 'https://hf-mirror.com' 4 | 5 | from pathlib import Path 6 | from transformers import AutoProcessor, MusicgenForConditionalGeneration, BlipForConditionalGeneration, AutoModelForCausalLM 7 | from loguru import logger 8 | 9 | # 设置环境变量和路径 10 | cwd = Path(__file__).resolve().parent 11 | model_path = cwd / "model" 12 | model_path.mkdir(parents=True, exist_ok=True) 13 | 14 | # 通用加载和保存函数 15 | def download_and_save_model(model_class, processor_class, model_name, save_dir): 16 | """ 17 | 下载模型和处理器并保存到指定路径 18 | :param model_class: 模型类 19 | :param processor_class: 处理器类 20 | :param model_name: 预训练模型名称 21 | :param save_dir: 保存目录 22 | """ 23 | 24 | model_save_path = save_dir / f"{model_name.split('/')[-1]}_model" 25 | processor_save_path = save_dir / f"{model_name.split('/')[-1]}_processor" 26 | 27 | try: 28 | logger.info(f"正在尝试加载模型和处理器: {model_name}...") 29 | model = model_class.from_pretrained(model_save_path,trust_remote_code=True) 30 | processor = processor_class.from_pretrained(processor_save_path,trust_remote_code=True) 31 | 32 | logger.info(f"{model_name} 加载成功!") 33 | except Exception as e: 34 | logger.info(f"加载 {model_name} 时出错: {e}") 35 | model = model_class.from_pretrained(model_name, trust_remote_code=True) 36 | processor = processor_class.from_pretrained(model_name, trust_remote_code=True) 37 | 38 | model.save_pretrained(model_save_path) 39 | processor.save_pretrained(processor_save_path) 40 | 41 | # 下载模型 42 | # download_and_save_model(MusicgenForConditionalGeneration, AutoProcessor, "facebook/musicgen-small", model_path) 43 | # download_and_save_model(BlipForConditionalGeneration, AutoProcessor, "Salesforce/blip-image-captioning-base", model_path) 44 | # download_and_save_model(MusicgenForConditionalGeneration, AutoProcessor, "facebook/musicgen-medium", model_path) 45 | download_and_save_model(AutoModelForCausalLM, AutoProcessor, "microsoft/Florence-2-large", model_path) 46 | -------------------------------------------------------------------------------- /MozartsTouch/utils/MusicGenerator/suno_ai.py: -------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | import time 3 | import requests 4 | from loguru import logger 5 | import yaml 6 | 7 | module_path = Path(__file__).resolve().parent.parent.parent # module_path为模块根目录(`/MozartsTouch`) 8 | with open(module_path / 'config.yaml', 'r', encoding='utf8') as file: 9 | config = yaml.safe_load(file) 10 | 11 | class Suno: 12 | def __init__(self) -> None: 13 | self.base_url = config['MUSIC_MODEL_CONFIG']['API_BASE_URL'] 14 | self.task_list = { 15 | 'custom_generate_audio': f"{self.base_url}/api/custom_generate", 16 | 'extend_audio': f"{self.base_url}/api/extend_audio", 17 | 'generate_audio_by_prompt': f"{self.base_url}/api/generate", 18 | 'get_audio_information': f"{self.base_url}/api/get?ids=", 19 | 'get_quota_information': f"{self.base_url}/api/get_limit" 20 | } 21 | 22 | def post_suno_api(self, task, payload): 23 | url = self.task_list[task] 24 | response = requests.post(url, json=payload, headers={'Content-Type': 'application/json'}) 25 | response.raise_for_status() 26 | return response.json() 27 | 28 | def get_suno_api(self, task, audio_ids=None): 29 | url = self.task_list[task] 30 | if audio_ids: 31 | url += audio_ids 32 | 33 | response = requests.get(url) 34 | response.raise_for_status() 35 | return response.json() 36 | 37 | def generate(self, text: str): 38 
        data = self.post_suno_api('generate_audio_by_prompt', {
            "prompt": text,
            "make_instrumental": True,
            "wait_audio": False
        })
        ids = ",".join([item['id'] for item in data])
        logger.info(f"Generated IDs: {ids}")

        for _ in range(60):
            try:
                data = self.get_suno_api('get_audio_information', ids)
                if data[1]["status"] == 'streaming':
                    logger.info(f"Audio URL: {data[1]['audio_url']}")
                    return data[1]['audio_url']
            except Exception as e:
                logger.warning(f"Error fetching audio information: {e}")
            time.sleep(5)

        logger.error("Failed to generate audio within the expected time.")
        return None

if __name__ == '__main__':
    suno = Suno()
    music = suno.generate("A popular heavy metal song about war, sung by a deep-voiced male singer, slowly and melodiously.")
    print(music)
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# Mozart's Touch: Multi-Modal Music Generation with Pre-Trained Models
[![arXiv](https://img.shields.io/badge/arXiv-Paper-.svg)](https://arxiv.org/abs/2405.02801) [![GitHub Stars](https://img.shields.io/github/stars/TiffanyBlews/MozartsTouch?style=social)](https://github.com/TiffanyBlews/MozartsTouch) [![githubio](https://img.shields.io/badge/GitHub.io-Demo_Page-blue?logo=Github&style=flat-square)](https://tiffanyblews.github.io/MozartsTouch-demo/)

This is the official implementation of [Mozart's Touch: A Lightweight Multi-modal Music Generation Framework Based on Pre-Trained Large Models](https://arxiv.org/abs/2405.02801).

![](logo.png)

![](architecture.png)

## Package Description
This repository is structured as follows:
```
Diancai-Backend
├─MozartsTouch/: source code of the Mozart's Touch implementation
│ ├─model/: pre-trained models
│ ├─static/: static assets for testing
│ ├─utils/: source code for the modules
│ ├─download_model.py: download pre-trained models to ./model/
│ ├─config.yaml: configurations such as LLM model URLs and API keys
│ └─main.py: main program of Mozart's Touch
├─outputs/: directory to store the generated music
├─backend_app.py: backend web application of Mozart's Touch
└─start_server.py: starts the backend server of Mozart's Touch
```
## Setup
1. Before running, please configure [config.yaml](MozartsTouch/config.yaml).
2. Install dependencies using `pip install -r requirements.txt`.
3. Run [download_model.py](MozartsTouch/download_model.py) to download the required model parameters.
4. Use [MozartsTouch.img_to_music_generate()](MozartsTouch/main.py) to generate music.

To test the code without importing the large models, set `TEST_MODE` to `True` in [config.yaml](MozartsTouch/config.yaml).

## Usage
Mozart's Touch can be used as a Python library, as a command-line tool, or as a web backend server.
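As a minimal sketch of library usage (mirroring the test code at the bottom of [main.py](MozartsTouch/main.py); the image path, duration, and output folder below are only examples):
```python
from pathlib import Path
from PIL import Image
import MozartsTouch

# Load the captioning and music generation models once (slow; skipped in TEST_MODE).
image_recog = MozartsTouch.import_ir()
music_gen = MozartsTouch.import_music_generator()

# Generate a 10-second piece of music for an image.
img = Image.open("MozartsTouch/static/test.jpg")
prompt, converted_prompt, result_file = MozartsTouch.img_to_music_generate(
    img, 10, image_recog, music_gen, output_folder=Path("outputs")
)
print(prompt, converted_prompt, result_file)  # with the MusicGen backends, the audio is saved to outputs/<timestamp>.wav
```
For videos, `MozartsTouch.video_to_music_generate()` works the same way but takes a video path instead of a PIL image and derives the music duration from the video length.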

### Running as a Command Line Tool
With the setup complete, you can now run the following command to generate music:
```bash
python MozartsTouch/main.py
```
or debug without importing the models:
```bash
python MozartsTouch/main.py --test
```

### Running as a Web Backend Server

1. Install dependencies using `pip install -r requirements_for_server.txt`.
2. Configure the port number and other parameters in [start_server.py](start_server.py).
3. Run `python start_server.py`.
4. Access http://localhost:3001/docs#/ to view the backend API documentation and test the APIs.

The related frontend project is at https://github.com/ScientificW/MozartFrontEndConnect


## TO-DO List
- ~~Add support for user-input instruction text~~
- Remove the `mode` parameter from the API
- ~~Update to the latest code and integrate `Video-BLIP2` into our project.~~
- Integrate the evaluation code
- ~~Use `argparse` to set and pass config~~
- ~~Refactor the MusicGen part with the strategy pattern~~
- ~~Use API instead of loading models manually~~
- ~~Add support for other models as an alternative e.g. LLaMa.~~
### Long-Term Tasks
- Try the latest models such as Florence-2
- Optimize the MusicGen code in the music generation module (main goal: improve generation efficiency)


--------------------------------------------------------------------------------
/MozartsTouch/utils/image_processing.py:
--------------------------------------------------------------------------------
import os
import time
from pathlib import Path
from PIL import Image
import torch
from transformers import AutoModelForCausalLM, AutoProcessor
from loguru import logger

module_path = Path(__file__).resolve().parent.parent

class ImageRecognization:
    def __init__(self, beam_amount=7, min_prompt_length=15, max_prompt_length=30, test_mode=False):
        self.processor = None
        self.model = None
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self.torch_dtype = torch.float16 if torch.cuda.is_available() else torch.float32
        self.beam_amount = beam_amount
        self.min_length = min_prompt_length
        self.max_length = max_prompt_length
        self.test_mode = test_mode
        logger.info(f"self.device: {self.device}")
        if not self.test_mode:
            self._load_model("Florence-2-large")  # load the captioning model unless running in test mode

    def _load_model(self, model_name="Florence-2-large"):
        if not self.model or not self.processor:
            logger.info(f"Loading captioning model {model_name}")
            start_time = time.time()
            self.processor = AutoProcessor.from_pretrained(module_path / "model" / f"{model_name}_processor", trust_remote_code=True)
            self.model = AutoModelForCausalLM.from_pretrained(
                module_path / "model" / f"{model_name}_model", torch_dtype=self.torch_dtype, trust_remote_code=True
            ).to(self.device)
            logger.info(f"[TIME] taken for loading {model_name}: {time.time() - start_time :.2f}s")

    def img2txt(self, image: Image, task='') -> str:
        '''Convert an image into a text caption.'''
        if self.test_mode:
            return self._test_img2txt()
        return self._img2txt(image, task)

    def _img2txt(self, image: Image, task='') -> str:
        start_time = time.time()
        image = image.convert('RGB')
        inputs = self.processor(text=task, images=image, return_tensors="pt").to(self.device, self.torch_dtype)
        generated_ids = self.model.generate(
            input_ids=inputs["input_ids"],
            pixel_values=inputs["pixel_values"],
            max_new_tokens=1024,
            num_beams=self.beam_amount,
            do_sample=False
        )
        generated_text = self.processor.batch_decode(generated_ids, skip_special_tokens=False)[0]
        parsed_answer = self.processor.post_process_generation(generated_text, task=task, image_size=(image.width, image.height))
        logger.info(f"[TIME] taken for img2txt: {time.time() - start_time :.2f}s")
        logger.info(parsed_answer)
        return parsed_answer[task]

    def _test_img2txt(self) -> str:
        '''Test helper that always returns the same fixed caption.'''
        return "The image is a comic strip with four panels. 
\n\nThe first panel on the top left shows a young man with brown hair and a blue shirt, who appears to be a doctor or nurse. He is standing in front of a door with the word \"GENO\" written on it. The man is gesturing with his hand as if he is explaining something to the doctor.\n\nIn the second panel, there is a young woman sitting at a desk with a concerned expression on her face. She is looking at the doctor with a worried expression. The doctor is wearing a stethoscope around his neck and is holding a clipboard in his hand. The woman is lying on a hospital bed with her eyes closed and her head resting on the bed. The background shows a hospital room with a window and a door." 61 | 62 | if __name__ == "__main__": 63 | test_image = module_path / "static" / "test.jpg" 64 | test_image = r"C:\Users\ljj\Downloads\VE7Z10AAOTyPsEDH2nspEXpURazGWgbjXxgbT4_UrR9fEcNQM672DkJqVDZ-p68zYRN832Wd18XpG0sySsNHOg.webp" 65 | 66 | image = Image.open(test_image) 67 | img_recog = ImageRecognization() 68 | for task in ['', '', '']: 69 | result = img_recog.img2txt(image, task) -------------------------------------------------------------------------------- /MozartsTouch/utils/music_generation.py: -------------------------------------------------------------------------------- 1 | import io 2 | from pathlib import Path 3 | from typing import Union 4 | from loguru import logger 5 | from abc import ABC, abstractmethod 6 | 7 | if __name__ == "__main__": 8 | from MusicGenerator.music_gen import MusicGen 9 | from MusicGenerator.suno_ai import Suno 10 | else: 11 | from .MusicGenerator.music_gen import MusicGen 12 | from .MusicGenerator.suno_ai import Suno 13 | 14 | module_path = Path(__file__).resolve().parent.parent 15 | 16 | class AbstractSingletonMeta(type, ABC): 17 | ''' 18 | 单例模式抽象基类元类,保证音乐生成类只有一个实例 19 | ''' 20 | _instances = {} 21 | 22 | def __call__(cls, *args, **kwargs): 23 | if cls not in cls._instances: 24 | cls._instances[cls] = super().__call__(*args, **kwargs) 25 | return cls._instances[cls] 26 | 27 | 28 | class MusicGenerator(metaclass=AbstractSingletonMeta): 29 | ''' 30 | 音乐生成模型抽象基类,增加模型类时需要继承此类并实现generate方法 31 | ''' 32 | @property 33 | @abstractmethod 34 | def model_name(self): 35 | pass 36 | 37 | @abstractmethod 38 | def generate(self, text: str, music_duration: int) -> Union[io.BytesIO, str]: 39 | """ 40 | 根据传入文本和音乐时长生成音乐 41 | 42 | Parameters: 43 | text (str): 音乐生成的提示文本 44 | music_duration (int): 音乐的时长,单位为秒 45 | 46 | Returns: 47 | Union[io.BytesIO, str]: Generated music as a byte stream or string. 
        """
        pass

class TestGenerator(MusicGenerator):
    def __init__(self) -> None:
        self._model_name = "test"

    @property
    def model_name(self):
        return self._model_name

    def generate(self, text: str, music_duration: int) -> io.BytesIO:
        """Return dummy audio *BONK*.mp3"""
        logger.info("Music Generation Prompt: " + text.encode('gbk', errors='replace').decode('gbk'))
        test_path = module_path / "static" / "BONK.mp3"
        with open(test_path, "rb") as f:
            test_mp3 = f.read()
        return io.BytesIO(test_mp3)

class MusicGenSmallGenerator(MusicGenerator):
    def __init__(self) -> None:
        self.model = MusicGen("musicgen-small")
        self._model_name = "musicgen-small"

    @property
    def model_name(self):
        return self._model_name

    def generate(self, text: str, music_duration: int) -> io.BytesIO:
        return self.model.generate(text, music_duration)

class MusicGenMediumGenerator(MusicGenerator):
    def __init__(self) -> None:
        self.model = MusicGen("musicgen-medium")
        self._model_name = "musicgen-medium"

    @property
    def model_name(self):
        return self._model_name

    def generate(self, text: str, music_duration: int) -> io.BytesIO:
        return self.model.generate(text, music_duration)

class MusicGenLargeGenerator(MusicGenerator):
    def __init__(self) -> None:
        self.model = MusicGen("musicgen-large")
        self._model_name = "musicgen-large"

    @property
    def model_name(self):
        return self._model_name

    def generate(self, text: str, music_duration: int) -> io.BytesIO:
        return self.model.generate(text, music_duration)

class SunoGenerator(MusicGenerator):
    def __init__(self) -> None:
        self._model_name = "Suno"
        self.model = Suno()

    @property
    def model_name(self):
        return self._model_name

    def generate(self, text: str, *_) -> str:
        # Suno.generate is an instance method, so call it on a Suno instance
        return self.model.generate(text)

class MusicGeneratorFactory:
    """
    Factory class that returns music generator instances
    """
    generator_classes = {
        "test": TestGenerator,
        "musicgen-small": MusicGenSmallGenerator,
        "musicgen-medium": MusicGenMediumGenerator,
        "musicgen-large": MusicGenLargeGenerator,
        "suno": SunoGenerator,
    }

    @classmethod
    def create_music_generator(cls, music_gen_model_name: str) -> MusicGenerator:
        generator_class = cls.generator_classes.get(music_gen_model_name)
        if generator_class:
            logger.info(f'Creating music generator: {generator_class.__name__}')
            return generator_class()
        else:
            raise ValueError(f"Unsupported music_gen_model_name: {music_gen_model_name}")

if __name__ == "__main__":
    music_gen_small = MusicGeneratorFactory.create_music_generator("musicgen-small")
    output = music_gen_small.generate("a daft punky thrash music", 1)
    logger.info(f"Model name: {music_gen_small.model_name}")
    with open(module_path / 'musicgen.wav', 'wb') as f:
        f.write(output.getvalue())
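
# --- Illustrative sketch (not part of the original module) ---
# To add another music generation backend, subclass MusicGenerator, implement
# `model_name` and `generate`, and register the new class with the factory.
# `MyModelGenerator` and the "my-model" key below are hypothetical placeholder names.
#
# class MyModelGenerator(MusicGenerator):
#     def __init__(self) -> None:
#         self._model_name = "my-model"
#
#     @property
#     def model_name(self):
#         return self._model_name
#
#     def generate(self, text: str, music_duration: int) -> io.BytesIO:
#         # call the underlying model here and return the audio as a BytesIO
#         # (or return a URL string, as SunoGenerator does)
#         ...
#
# MusicGeneratorFactory.generator_classes["my-model"] = MyModelGenerator
# generator = MusicGeneratorFactory.create_music_generator("my-model")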
--------------------------------------------------------------------------------
/backend_app.py:
--------------------------------------------------------------------------------
from contextlib import asynccontextmanager
from fastapi import APIRouter, FastAPI, File, UploadFile, Form
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import FileResponse
from PIL import Image
from pydantic import BaseModel
from pathlib import Path

import uvicorn
import MozartsTouch
from loguru import logger
from typing import Optional

router = APIRouter(prefix="/v1", tags=["v1"])
models = {}
app_path = Path(__file__).parent  # app_path is the project root directory (`/`)

class MusicResponse(BaseModel):
    prompt: str
    converted_prompt: str
    result_file_url: str

@asynccontextmanager
async def lifespan(app: FastAPI):
    # Load the ML model
    models["music_generator"] = MozartsTouch.import_music_generator()
    models["image_recog"] = MozartsTouch.import_ir()

    yield

    # Clean up the ML models and release the resources
    del models["music_generator"]
    del models["image_recog"]

# Create FastAPI app
app = FastAPI(title='点彩成乐', description='“点彩成乐”项目后端', lifespan=lifespan)

origins = ["http://localhost:5173"]
# Add CORS middleware
app.add_middleware(
    CORSMiddleware,
    allow_origins=origins,
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)


@router.post("/image", response_model=MusicResponse)
async def upload_image(file: UploadFile = File(...),
                       music_duration: Optional[int] = Form(10),
                       instruction: Optional[str] = Form("")):
    '''
    Upload an image for music generation.

    Parameters:
    - file: image file, Content-Type: image/*
    - music_duration: requested length of the generated music as an integer number of seconds (default: 10). Ignored when Suno AI is used for generation.
    - instruction: optional user-supplied instruction text

    Return:
    - prompt: the image-to-text captioning result
    - converted_prompt: the prompt text used for music generation
    - result_file_url: URL of the generated audio; send a GET request to "result_file_url" to fetch the audio file
    '''
    logger.info("Request Received Successfully, Processing...")
    output_folder = app_path / "outputs"

    img = Image.open(file.file)
    # pass the optional user instruction through as additional prompt text
    result = MozartsTouch.img_to_music_generate(img, music_duration, models["image_recog"], models["music_generator"], output_folder, instruction)

    if not models["music_generator"].model_name.startswith("Suno"):
        prefix = 'http://localhost:3001/v1/music/'  # the /music route is mounted under the /v1 router prefix
        result = (*result[:2], prefix + result[2])

    result_dict = {key: value for key, value in zip(("prompt", "converted_prompt", "result_file_url"), result)}
    logger.info('**********FINAL RESULT**********')
    logger.info(result_dict)

    return result_dict

@router.post("/video", response_model=MusicResponse)
async def upload_video(file: UploadFile = File(...), instruction: Optional[str] = Form('')):
    '''
    Upload a video for music generation.

    Parameters:
    - file: video file, Content-Type: video/*
    - instruction: optional user-supplied instruction text

    Return:
    - prompt: the video-to-text captioning result
    - converted_prompt: the prompt text used for music generation
    - result_file_url: URL of the final video combined with the generated audio; send a GET request to "result_file_url" to fetch the video file
    '''
    logger.info("Request Received Successfully, Processing...")
    output_folder = app_path / "outputs"
    video_path = app_path / "videos" / file.filename
    video_path.parent.mkdir(parents=True, exist_ok=True)  # the videos directory is gitignored, so create it if missing

    # save the video locally, then read its frames
    contents = await file.read()
    with open(video_path, "wb") as f:
        f.write(contents)

    result = MozartsTouch.video_to_music_generate(str(video_path), models["image_recog"], models["music_generator"], output_folder, instruction)

    prefix = 'http://localhost:3001/v1/music/'  # wrap the generated file name into a URL
    filename_with_prefix = prefix + result[2]
    result = (*result[:2], filename_with_prefix)

    result_dict = {key: value for key, value in zip(("prompt", "converted_prompt", "result_file_url"), result)}
114 | logger.info('**********FINAL RESULT**********') 115 | logger.info(result_dict) 116 | 117 | return result_dict 118 | 119 | @router.get("/music/{result_file_name}") 120 | async def get_music(result_file_name: str): 121 | ''' 122 | 获取/outputs目录下对应名称的文件 123 | 124 | Return: 125 | - 音频文件 126 | ''' 127 | file_full_path = app_path / "outputs" / result_file_name 128 | logger.info(f'Return file {file_full_path}') 129 | return FileResponse(file_full_path) 130 | 131 | @app.get("/") 132 | async def root(): 133 | return {"message": "Good morning, and in case I don't see you, good afternoon, good evening, and good night! 这是“点彩成乐”后端根域名,在域名后面加上`/docs#/`访问后端API文档页面!"} 134 | 135 | app.include_router(router) 136 | 137 | if __name__ == "__main__": 138 | uvicorn.run(app, host="0.0.0.0", port=3001) 139 | -------------------------------------------------------------------------------- /MozartsTouch/utils/txt_converter.py: -------------------------------------------------------------------------------- 1 | from openai import OpenAI 2 | from pathlib import Path 3 | import httpx 4 | import yaml 5 | from loguru import logger 6 | 7 | module_path = Path(__file__).resolve().parent.parent 8 | with open(module_path / 'config.yaml', 'r', encoding='utf8') as file: 9 | config = yaml.safe_load(file) 10 | 11 | class TxtConverter: 12 | def __init__(self): 13 | self.use_llm = config.get('USE_LLM', False) 14 | if self.use_llm: 15 | self.model = config['DEFAULT_LLM_MODEL'] 16 | self.api_url = config['LLM_MODEL_CONFIG']['API_BASE_URL'] 17 | self.api_key = config['LLM_MODEL_CONFIG']['API_KEY'] or self._prompt_for_api_key() 18 | self.client = OpenAI( 19 | base_url=self.api_url, 20 | api_key=self.api_key, 21 | http_client=httpx.Client(base_url=self.api_url, follow_redirects=True), 22 | ) 23 | 24 | def _prompt_for_api_key(self): 25 | api_key = input("Enter your OpenAI API key: ") 26 | config['LLM_MODEL_CONFIG']['API_KEY'] = api_key 27 | with open(module_path / 'config.yaml', 'w') as file: 28 | yaml.dump(config, file) 29 | return api_key 30 | 31 | def process_video_description(self, texts: list): 32 | if not self.use_llm: 33 | return str(texts[0]) 34 | 35 | completion = self.client.chat.completions.create( 36 | model=self.model, 37 | messages=[ 38 | {"role": "system", "content": "You are about to process a sequence of captions, each corresponding to a distinct frame sampled from a video. Your task is to convert these captions into a cohesive, well-structured paragraph. This paragraph should describe the video in a fluid, engaging manner and follows these guidelines: avoiding semantic repetition to the greatest extent, and giving a description in less than 200 characters."}, 39 | {"role": "user", "content": str(texts)} 40 | ] 41 | ) 42 | result = completion.choices[0].message.content 43 | return result 44 | 45 | def txt_converter(self, content, addtxt=None): 46 | if addtxt: 47 | content += addtxt #在这里加入附加文本然后一起丢进llm跑 48 | # logger.info("filtered_prompt result:"+content.encode('utf8', errors='replace').decode('utf8')) 49 | 50 | if not self.use_llm: 51 | return content 52 | completion = self.client.chat.completions.create( 53 | model=self.model, 54 | messages=[ 55 | {"role": "system", "content": "Convert in less than 200 characters this image caption to a very concise musical description with musical terms, so that it can be used as a prompt to generate music through AI model, strictly in English. You need to speculate the mood of the given image caption and add it to the music description. 
You also need to specify a music genre in the description such as pop, hip hop, funk, electronic, jazz, rock, metal, soul, R&B etc."}, 56 | {"role": "user", "content": "a city with a tower and a castle in the background, a detailed matte painting, art nouveau, epic cinematic painting, kingslanding"}, 57 | {"role": "assistant", "content": "A grand orchestral arrangement with thunderous percussion, epic brass fanfares, and soaring strings, creating a cinematic atmosphere fit for a heroic battle."}, 58 | {"role": "user", "content": "a group of people sitting on a beach next to a body of water, tourist destination, hawaii"}, 59 | {"role": "assistant", "content": "Pop dance track with catchy melodies, tropical percussion, and upbeat rhythms, perfect for the beach."}, 60 | {"role": "user", "content": content} 61 | ] 62 | ) 63 | converted_result = completion.choices[0].message.content 64 | logger.info("converted result: " + converted_result.encode('utf8', errors='replace').decode('utf8')) 65 | return converted_result 66 | 67 | def video_txt_converter(self, content, addtxt=None): 68 | if addtxt: 69 | content += addtxt 70 | if not self.use_llm: 71 | logger.info(type(content)) 72 | return content 73 | completion = self.client.chat.completions.create( 74 | model=self.model, 75 | messages=[ 76 | {"role": "system", "content": "Convert in less than 200 characters this video caption to a very concise musical description with musical terms, so that it can be used as a prompt to generate music through AI model, strictly in English. You need to speculate the mood of the given video caption and add it to the music description. You also need to specify a music genre in the description such as pop, hip hop, funk, electronic, jazz, rock, metal, soul, R&B etc."}, 77 | {"role": "user", "content": "Two men playing cellos in a room with a piano and a grand glass window backdrop."}, 78 | {"role": "assistant", "content": "Classical chamber music piece featuring cello duet, intricate piano accompaniment, emotive melodies set in an elegant setting, showcasing intricate melodies and emotional depth, the rich harmonies blend seamlessly in an elegant and refined setting, creating a symphonic masterpiece."}, 79 | {"role": "user", "content": "A man with guitar in hand, captivates a large audience on stage at a concert. 
The crowd watches in awe as the performer delivers a stellar musical performance."}, 80 | {"role": "assistant", "content": "Rock concert with dynamic guitar riffs, precise drumming, and powerful vocals, creating a captivating and electrifying atmosphere, uniting the audience in excitement and musical euphoria."}, 81 | {"role": "user", "content": content} 82 | ] 83 | ) 84 | converted_result = completion.choices[0].message.content 85 | logger.info("converted result: " + converted_result.encode('utf8', errors='replace').decode('utf8')) 86 | return converted_result 87 | 88 | if __name__ == "__main__": 89 | # content = "a wreath hanging from a rope, an album cover inspired, land art, japanese shibari with flowers, hanging from a tree,the empress’ hanging" 90 | content = input() 91 | txt_con = TxtConverter() 92 | converted_result = txt_con.txt_converter(content) -------------------------------------------------------------------------------- /MozartsTouch/main.py: -------------------------------------------------------------------------------- 1 | import io 2 | from pathlib import Path 3 | import datetime 4 | from PIL import Image 5 | import time 6 | import argparse 7 | import yaml 8 | from loguru import logger 9 | from moviepy import VideoFileClip, AudioFileClip 10 | 11 | 12 | ''' 13 | Because of Python's feature of chain importing (https://stackoverflow.com/questions/5226893/understanding-a-chain-of-imports-in-python) 14 | you need to use these lines below instead of those above to be able to run the test code after `if __name__=="__main__"` 15 | ''' 16 | if __name__=="__main__": 17 | from utils.image_processing import ImageRecognization 18 | from utils.music_generation import MusicGenerator, MusicGeneratorFactory 19 | from utils.txt_converter import TxtConverter 20 | from utils.preprocess_single import PreProcessVideos 21 | else: 22 | from .utils.image_processing import ImageRecognization 23 | from .utils.music_generation import MusicGenerator, MusicGeneratorFactory 24 | from .utils.txt_converter import TxtConverter 25 | from .utils.preprocess_single import PreProcessVideos 26 | 27 | 28 | module_path = Path(__file__).resolve().parent 29 | with open(module_path / 'config.yaml', 'r', encoding='utf8') as file: 30 | config = yaml.safe_load(file) 31 | 32 | test_mode = config.get('TEST_MODE', False) 33 | logger.info(f"Test mode: {test_mode}") 34 | 35 | def import_ir(): 36 | '''导入图像识别模型''' 37 | ir = ImageRecognization(test_mode=test_mode) 38 | return ir 39 | 40 | def import_music_generator(): 41 | start_time = time.time() 42 | music_model = config['DEFAULT_MUSIC_MODEL'] 43 | if test_mode: 44 | mg = MusicGeneratorFactory.create_music_generator("test") 45 | else: 46 | mg = MusicGeneratorFactory.create_music_generator(music_model) 47 | logger.info(f"[TIME] taken to load Music Generation module {music_model}: {time.time() - start_time :.2f}s") 48 | return mg 49 | 50 | 51 | class Entry: 52 | '''每个Entry代表一次用户输入,然后调用自己的方法对输入进行处理以得到生成结果''' 53 | def __init__(self, image_recog:ImageRecognization, music_gen: MusicGenerator,\ 54 | music_duration: int, addtxt:str, output_folder:Path, img:Image=None, video_path:Path=None) -> None: 55 | self.img = img 56 | self.video_path = video_path 57 | self.txt = None 58 | self.txt_con = TxtConverter() 59 | self.converted_txt = None 60 | self.addtxt = addtxt # 追加文本输入 61 | self.image_recog = image_recog # 使用传入的图像识别模型对象 62 | self.music_gen = music_gen # 使用传入的音乐生成对象 63 | self.music_duration = music_duration 64 | self.output_folder = output_folder 65 | self.timestamp = 
datetime.datetime.now().strftime("%Y%m%d_%H%M%S") # 记录用户上传时间作为标识符 66 | self.result_urls = None 67 | self.music_bytes_io = None 68 | 69 | def init_video(self): 70 | assert self.img is None 71 | # 将视频帧进行采样并分别识别,同时获取视频长度作为self.music_duration,之后调用self.video_txt_descriper合成描述视频的一段话 72 | video_processor = PreProcessVideos( 73 | str(self.video_path), 74 | self.image_recog, 75 | prompt_amount=config['CAPTION_MODEL_CONFIG']['VIDEO_SAMPLE_AMOUNT'] 76 | ) 77 | video_frame_texts = video_processor.process_video() 78 | self.music_duration = video_processor.video_seconds + 1 79 | self.video_txt_descriper(video_frame_texts) 80 | 81 | 82 | def img2txt(self): 83 | assert self.video_path is None 84 | '''进行图像识别''' 85 | self.txt = self.image_recog.img2txt(self.img) 86 | 87 | def txt_converter(self): 88 | '''利用LLM优化已有的图片描述文本''' 89 | self.converted_txt = self.txt_con.txt_converter(self.txt, self.addtxt) # 追加一个附加输入,具体改动参见txt_converter 90 | 91 | def video_txt_descriper(self, texts): 92 | '''将每帧的描述文本转换为视频整体的描述文本''' 93 | self.txt = self.txt_con.process_video_description(texts) 94 | logger.info(f"Video description: {self.txt}") 95 | 96 | def video_txt_converter(self): 97 | '''利用LLM优化已有的视频描述文本''' 98 | self.converted_txt = self.txt_con.video_txt_converter(self.txt, self.addtxt) # 追加一个附加输入,具体改动参见txt_converter 99 | 100 | def txt2music(self): 101 | '''根据文本进行音乐生成,获取生成的音乐的BytesIO或URL''' 102 | assert self.music_duration 103 | if self.music_gen.model_name.startswith("Suno"): 104 | self.result_urls = self.music_gen.generate(self.converted_txt, self.music_duration) 105 | else: 106 | self.music_bytes_io = self.music_gen.generate(self.converted_txt, self.music_duration) 107 | 108 | def save_to_file(self): 109 | '''将音乐保存到`/outputs`中,文件名为用户上传时间的时间戳''' 110 | self.output_folder.mkdir(parents=True, exist_ok=True) 111 | 112 | self.result_file_name = f"{self.timestamp}.wav" 113 | file_path = self.output_folder / self.result_file_name 114 | 115 | with open(file_path, "wb") as music_file: 116 | music_file.write(self.music_bytes_io.getvalue()) 117 | 118 | logger.info(f"音乐已保存至 {file_path}") 119 | 120 | return self.result_file_name 121 | def merge_audio_video(self): 122 | '''合成原视频与生成的音乐''' 123 | # 读取视频文件 124 | video_clip = VideoFileClip(self.video_path) 125 | 126 | # 读取音频数据流 127 | audio_clip = AudioFileClip(self.output_folder / self.result_file_name) 128 | 129 | # 将音频和视频合成 130 | final_video = video_clip.with_audio(audio_clip) 131 | self.result_video_name = f"{self.timestamp}.mp4" 132 | # 输出到文件 133 | final_video.write_videofile(video_file_path := self.output_folder / self.result_video_name) 134 | logger.info(f"视频已保存至 {video_file_path}") 135 | return self.result_video_name 136 | 137 | def img_to_music_generate(img: Image, music_duration: int, image_recog: ImageRecognization,\ 138 | music_gen: MusicGenerator, output_folder=Path("./outputs"), addtxt: str=None): 139 | '''模型核心过程''' 140 | # 根据输入mode信息获得对应的音乐生成模型类的实例 141 | # mg = mgs[mode] 142 | 143 | # 根据用户输入创建一个类,并传入图像识别和音乐生成模型的实例 144 | entry = Entry(image_recog, music_gen, music_duration, addtxt, output_folder, img=img) 145 | 146 | # 图片转文字 147 | entry.img2txt() 148 | 149 | # 文本优化 150 | entry.txt_converter() 151 | 152 | #文本生成音乐 153 | entry.txt2music() 154 | 155 | if not music_gen.model_name.startswith("Suno"): 156 | # print("Here.") 157 | entry.save_to_file() 158 | 159 | return (entry.txt, entry.converted_txt, entry.result_file_name) 160 | 161 | def video_to_music_generate(video_path: Path, image_recog: ImageRecognization, music_gen: MusicGenerator,\ 162 | output_folder=Path("./outputs"), addtxt: 
str=None): 163 | '''模型核心过程''' 164 | # 根据用户输入创建一个类,并传入图像识别和音乐生成模型的实例 165 | entry = Entry(image_recog, music_gen, None, addtxt, output_folder, video_path=video_path) 166 | # 视频采样、识别 167 | entry.init_video() 168 | 169 | # 文本优化 170 | entry.video_txt_converter() 171 | 172 | # 文本生成音乐 173 | entry.txt2music() 174 | entry.save_to_file() 175 | 176 | # 合成视频 177 | entry.merge_audio_video() 178 | 179 | 180 | return (entry.txt, entry.converted_txt, entry.result_video_name) 181 | 182 | if __name__ == "__main__": 183 | parser = argparse.ArgumentParser(description='Mozart\'s Touch: Multi-modal Music Generation Framework') 184 | 185 | parser.add_argument('-d', '--test', help='Test mode', default=False, action='store_true') 186 | 187 | args = parser.parse_args() 188 | test_mode = args.test # True时关闭img2txt功能,节省运行资源,用于调试程序 189 | 190 | image_recog = import_ir() 191 | music_gen = import_music_generator() 192 | output_folder = module_path / "outputs" 193 | addtxt = None 194 | 195 | # img = Image.open(module_path / "static" / "test.jpg") 196 | # music_duration = 10 197 | # result = img_to_music_generate(img, music_duration, image_recog, music_gen, output_folder, addtxt) 198 | 199 | video_path = module_path / "static" / "stone.mp4" 200 | result = video_to_music_generate(video_path, image_recog, music_gen, output_folder, addtxt) 201 | 202 | key_names = ("prompt", "converted_prompt", "result_file_name") 203 | 204 | result_dict = {key: value for key, value in zip(key_names, result)} 205 | 206 | logger.info(result_dict) -------------------------------------------------------------------------------- /MozartsTouch/utils/preprocess_single.py: -------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | import torch 3 | import torchvision 4 | import json 5 | import os 6 | import random 7 | import argparse 8 | import decord 9 | 10 | # from einops import rearrange 11 | from torchvision import transforms 12 | from tqdm import tqdm 13 | from PIL import Image 14 | from decord import VideoReader, cpu 15 | # from transformers import BlipProcessor, BlipForConditionalGeneration 16 | from transformers import AutoModelForCausalLM, AutoProcessor 17 | from loguru import logger 18 | if __name__=="__main__": 19 | from image_processing import ImageRecognization 20 | else: 21 | from .image_processing import ImageRecognization 22 | 23 | module_path = Path(__file__).resolve().parent.parent # module_path为模块根目录(`/MozartsTouch`) 24 | 25 | decord.bridge.set_bridge('torch') 26 | 27 | class PreProcessVideos: 28 | def __init__( 29 | self, 30 | video_path, 31 | image_recognization :ImageRecognization, 32 | random_start_frame = False, 33 | clip_frame_data = False, 34 | prompt_amount = 25, 35 | ): 36 | 37 | # Paramaters for parsing videos 38 | self.prompt_amount = prompt_amount 39 | self.video_path = video_path 40 | self.random_start_frame = random_start_frame 41 | self.clip_frame_data = clip_frame_data 42 | self.vid_types = (".mp4", ".avi", ".mov", ".webm", ".flv", ".mjpeg") 43 | self.video_frames = None 44 | self.video_seconds = None 45 | self.image_recognization = image_recognization 46 | 47 | # Helper parameters 48 | self.device = "cuda" if torch.cuda.is_available() else "cpu" 49 | # self.save_dir = save_dir 50 | 51 | # Config parameters 52 | # self.config_name = config_name 53 | # self.config_save_name = config_save_name 54 | 55 | # Base dict to hold all the data. 
56 | # {base_config} 57 | # def build_base_config(self): 58 | # return { 59 | # "name": self.config_name, 60 | # "data": [] 61 | # } 62 | 63 | # Video dict for individual videos. 64 | # {base_config: data -> [{video_path, num_frames, data}]} 65 | # def build_video_config(self, video_path: str, num_frames: int): 66 | # return { 67 | # "video_path": video_path, 68 | # "num_frames": num_frames, 69 | # "data": [] 70 | # } 71 | 72 | # Dict for video frames and prompts / captions. 73 | # Gets the frame index, then gets a caption for the that frame and stores it. 74 | # {base_config: data -> [{name, num_frames, data: {frame_index, prompt}}]} 75 | def build_video_data(self, frame_index: int, prompt: str): 76 | return { 77 | "frame_index": frame_index, 78 | "prompt": prompt 79 | } 80 | 81 | # Process the frames to get the length and image. 82 | # The limit parameter ensures we don't get near the max frame length. 83 | def video_processor( 84 | self, 85 | video_reader: VideoReader, 86 | num_frames: int, 87 | random_start_frame=True, 88 | frame_num=0 89 | ): 90 | 91 | frame_number = ( 92 | random.randrange(0, int(num_frames)) if random_start_frame else frame_num 93 | ) 94 | # logger.info("getting frame: ", frame_number) 95 | frame = video_reader[frame_number].permute(2,0,1) 96 | # logger.info("getting image...") 97 | image = transforms.ToPILImage()(frame).convert("RGB") 98 | return frame_number, image 99 | 100 | def get_frame_range(self, derterministic): 101 | return range(self.prompt_amount) if self.random_start_frame else derterministic 102 | 103 | def image_caption(self, image: Image, task=''): 104 | return self.image_recognization.img2txt(image, task) 105 | 106 | # def get_out_paths(self, prompt, frame_number): 107 | # out_name= f"{prompt}_{str(frame_number)}" 108 | # save_path = f"{self.save_dir}/{self.config_save_name}" 109 | # save_filepath = f"{save_path}/{out_name}.mp4" 110 | 111 | # return out_name, save_path, save_filepath 112 | 113 | # def save_train_config(self, config: dict): 114 | # os.makedirs(self.save_dir, exist_ok=True) 115 | 116 | # save_json = json.dumps(config, indent=4) 117 | # save_dir = f"{self.save_dir}/{self.config_save_name}" 118 | 119 | # with open(f"{save_dir}.json", 'w') as f: 120 | # f.write(save_json) 121 | 122 | def save_video(self, save_path, save_filepath, frames): 123 | os.makedirs(save_path, exist_ok=True) 124 | torchvision.io.write_video(save_filepath, frames, fps=30) 125 | 126 | # Main loop for processing all videos. 127 | def process_video(self) -> list: 128 | video_path = self.video_path 129 | video_frame_list = [] 130 | 131 | if not os.path.exists(video_path): 132 | raise ValueError(f"{video_path} does not exist.") 133 | 134 | # try: 135 | video_reader = VideoReader(video_path, ctx=cpu(0)) 136 | self.video_frames = int(len(video_reader)) 137 | self.video_seconds = self.video_frames // video_reader.get_avg_fps() 138 | frame_step = abs(self.video_frames // self.prompt_amount) 139 | derterministic_range = range(1, abs(self.video_frames - 1), frame_step) if frame_step else range(self.video_frames) 140 | # except: 141 | # logger.error(f"Error loading {video_path}. 
Video may be unsupported or corrupt.") 142 | # return 143 | 144 | # try: 145 | # video_config = self.build_video_config(video_path, num_frames) 146 | 147 | for i in tqdm( 148 | self.get_frame_range(derterministic_range), 149 | desc=f"Processing {os.path.basename(video_path)}" 150 | ): 151 | 152 | frame_number, image = self.video_processor( 153 | video_reader, 154 | self.video_frames, 155 | self.random_start_frame, 156 | frame_num=i 157 | ) 158 | 159 | prompt = self.image_caption(image) 160 | video_data = self.build_video_data(frame_number, prompt) 161 | video_frame_list.append(video_data) 162 | 163 | # except Exception as e: 164 | # logger.error(e) 165 | 166 | 167 | # logger.info(f"Done. Saving train config to {self.save_dir}.") 168 | # self.save_train_config(config) 169 | # logger.info(video_frame_list) 170 | return video_frame_list 171 | 172 | 173 | if __name__ == "__main__": 174 | parser = argparse.ArgumentParser() 175 | 176 | # parser.add_argument('--config_name', help="The name of the configuration.", type=str, default='My Config') 177 | # parser.add_argument('--config_save_name', help="The name of the config file that's saved.", type=str, default='my_config') 178 | parser.add_argument('--video_path', help="The directory where your videos are located.", type=str, default='/root/Mozart-Diancai/Video-BLIP2-Preprocessor/videos') 179 | # parser.add_argument( 180 | # '--random_start_frame', 181 | # help="Use random start frame when processing videos. Good for long videos where frames have different scenes and meanings.", 182 | # action='store_true', 183 | # default=False 184 | # ) 185 | parser.add_argument( 186 | '--clip_frame_data', 187 | help="Save the frames as video clips to HDD/SDD. Videos clips are saved in the same folder as your json directory.", 188 | action='store_true', 189 | default=False 190 | ) 191 | parser.add_argument('--max_frames', help="Maximum frames for clips when --clip_frame_data is enabled.", type=int, default=60) 192 | parser.add_argument('--beam_amount', help="Amount for BLIP beam search.", type=int, default=7) 193 | parser.add_argument('--prompt_amount', help="The amount of prompts per video that is processed.", type=int, default=25) 194 | parser.add_argument('--min_prompt_length', help="Minimum words required in prompt.", type=int, default=15) 195 | parser.add_argument('--max_prompt_length', help="Maximum words required in prompt.", type=int, default=30) 196 | # parser.add_argument('--save_dir', help="The directory to save the config to.", type=str, default=f"{os.getcwd()}/train_data") 197 | 198 | args = parser.parse_args() 199 | 200 | 201 | #processor = PreProcessVideos(**vars(args)) 202 | #processor.process_videos() 203 | 204 | #json_file_path = 'Video-BLIP2-Preprocessor/train_data/my_videos.json' 205 | #content = process_video_description(json_file_path) 206 | #txt_con = TxtConverter() 207 | #converted_result = txt_con.txt_converter(content) --------------------------------------------------------------------------------
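# Addendum (illustrative sketch, not a file in the repository): how PreProcessVideos
# is driven in practice, mirroring Entry.init_video() in MozartsTouch/main.py.
# The video path below is just an example.
#
# from MozartsTouch.utils.image_processing import ImageRecognization
# from MozartsTouch.utils.preprocess_single import PreProcessVideos
#
# image_recog = ImageRecognization(test_mode=True)   # test_mode=True skips loading Florence-2
# processor = PreProcessVideos("MozartsTouch/static/stone.mp4", image_recog, prompt_amount=20)
# frame_captions = processor.process_video()          # [{"frame_index": ..., "prompt": ...}, ...]
# print(processor.video_seconds, frame_captions)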