├── LICENCE
├── README.md
├── neurosync_local_api.py
└── utils
    ├── audio
    │   ├── extraction
    │   │   └── extract_features.py
    │   └── processing
    │       └── audio_processing.py
    ├── config.py
    ├── generate_face_shapes.py
    └── model
        └── model.py
/LICENCE:
--------------------------------------------------------------------------------
1 | # NeuroSync Local_API
2 | 
3 | This software is licensed under a **dual-license model**:
4 | 
5 | ## 1. Free License (MIT License)
6 | For individuals and businesses earning **under $1M per year**, this software is licensed under the **MIT License**:
7 | 
8 | ```
9 | MIT License
10 | 
11 | Copyright (c) 2024 NeuroSync Local_API
12 | 
13 | Permission is hereby granted, free of charge, to any person obtaining a copy
14 | of this software and associated documentation files (the "Software"), to deal
15 | in the Software without restriction, including without limitation the rights
16 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
17 | copies of the Software, and to permit persons to whom the Software is
18 | furnished to do so, subject to the following conditions:
19 | 
20 | The above copyright notice and this permission notice shall be included in all
21 | copies or substantial portions of the Software.
22 | 
23 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
24 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
25 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
26 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
27 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
28 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
29 | SOFTWARE.
30 | ```
31 | 
32 | ## 2. Commercial License (For Businesses Earning $1M+ Per Year)
33 | Businesses or organizations with **annual revenue of $1,000,000 or more** must obtain a **commercial license** to use this software.
34 | 
35 | - To acquire a commercial license, please contact us.
36 | - This commercial license allows for **enterprise-level support, priority feature requests, and extended rights**.
37 | 
38 | ## Compliance
39 | By using this software, you agree to these licensing terms. If your business exceeds the revenue threshold, you must transition to a commercial license or **cease using the software**.
40 | 
41 | © 2025 NeuroSync Local_API
42 | 
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # NeuroSync Local API
2 | 
3 | ## 29/03/2025 Update to model.pth and model.py
4 | 
5 | - Increased accuracy (better timing, and more natural movement overall across brows, squint, cheeks and mouth shapes)
6 | - Smoother playback (flappy mouth be gone in most cases, even when speaking quickly)
7 | - Works better across a wider range of voices and speaking styles.
8 | - This preview of the new model is a modest increase in capability; both model.pth and model.py must be replaced with the new versions.
9 | 
10 | [Download the model from Hugging Face](https://huggingface.co/AnimaVR/NEUROSYNC_Audio_To_Face_Blendshape)
11 | 
12 | These quality gains come from better training data and from removing the "global" positional encoding from the model, keeping only RoPE positional encoding within the MHA blocks.
13 | 
14 | ## Overview
15 | 
16 | The **NeuroSync Local API** allows you to host the audio-to-face blendshape transformer model locally.
This API processes audio data and outputs facial blendshape coefficients, which can be streamed directly to Unreal Engine using the **NeuroSync Player** and LiveLink.
17 | 
18 | ### Features:
19 | - Host the model locally for full control
20 | - Process audio files and generate facial blendshapes
21 | 
22 | ## NeuroSync Model
23 | 
24 | To generate the blendshapes, you can:
25 | 
26 | - [Download the model from Hugging Face](https://huggingface.co/AnimaVR/NEUROSYNC_Audio_To_Face_Blendshape)
27 | 
28 | ## Player Requirement
29 | 
30 | To stream the generated blendshapes into Unreal Engine, you will need the **NeuroSync Player**. The Player allows for real-time integration with Unreal Engine via LiveLink.
31 | 
32 | You can find the NeuroSync Player and instructions on setting it up here:
33 | 
34 | - [NeuroSync Player GitHub Repository](https://github.com/AnimaVR/NeuroSync_Player)
35 | 
36 | Visit [neurosync.info](https://neurosync.info)
37 | 
38 | ## Talk to a NeuroSync prototype live on Twitch: [Visit Mai](https://www.twitch.tv/mai_anima_ai)
39 | 
--------------------------------------------------------------------------------
/neurosync_local_api.py:
--------------------------------------------------------------------------------
1 | 
2 | # This software is licensed under a **dual-license model**
3 | # For individuals and businesses earning **under $1M per year**, this software is licensed under the **MIT License**
4 | # Businesses or organizations with **annual revenue of $1,000,000 or more** must obtain permission to use this software commercially.
5 | 
6 | from flask import Flask, request, jsonify
7 | import numpy as np
8 | import torch
9 | 
10 | 
11 | from utils.generate_face_shapes import generate_facial_data_from_bytes
12 | from utils.model.model import load_model
13 | from utils.config import config
14 | 
15 | app = Flask(__name__)
16 | 
17 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
18 | print("Activated device:", device)
19 | 
20 | model_path = 'utils/model/model.pth'
21 | blendshape_model = load_model(model_path, config, device)  # load the model weights once at startup
22 | 
23 | @app.route('/audio_to_blendshapes', methods=['POST'])
24 | def audio_to_blendshapes_route():
25 |     audio_bytes = request.data  # raw request body: audio bytes (WAV, or raw PCM as a fallback)
26 |     generated_facial_data = generate_facial_data_from_bytes(audio_bytes, blendshape_model, device, config)
27 |     generated_facial_data_list = generated_facial_data.tolist() if isinstance(generated_facial_data, np.ndarray) else generated_facial_data
28 | 
29 |     return jsonify({'blendshapes': generated_facial_data_list})
30 | 
31 | if __name__ == '__main__':
32 |     app.run(host='127.0.0.1', port=5000)
33 | 
--------------------------------------------------------------------------------
/utils/audio/extraction/extract_features.py:
--------------------------------------------------------------------------------
1 | # This software is licensed under a **dual-license model**
2 | # For individuals and businesses earning **under $1M per year**, this software is licensed under the **MIT License**
3 | # Businesses or organizations with **annual revenue of $1,000,000 or more** must obtain permission to use this software commercially.
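#
# (Editor's note, not part of the original file): this module converts raw audio
# into the 256-dimensional feature frames the model consumes, roughly 60 frames per
# second: 69 MFCC-derived coefficients (23 MFCCs plus deltas and delta-deltas)
# stacked with 187 autocorrelation coefficients per frame. A minimal usage sketch,
# assuming a local WAV file named "speech.wav" (a hypothetical path):
#
#   features, y = extract_audio_features("speech.wav")
#   if features is not None:
#       print(features.shape)  # (num_frames, 256)
#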
4 | 5 | # extract_features.py 6 | import io 7 | import librosa 8 | import numpy as np 9 | import scipy.signal 10 | 11 | 12 | def extract_audio_features(audio_input, sr=88200, from_bytes=False): 13 | try: 14 | if from_bytes: 15 | y, sr = load_audio_from_bytes(audio_input, sr) 16 | else: 17 | y, sr = load_and_preprocess_audio(audio_input, sr) 18 | except Exception as e: 19 | print(f"Loading as WAV failed: {e}\nFalling back to PCM loading.") 20 | y = load_pcm_audio_from_bytes(audio_input) 21 | 22 | frame_length = int(0.01667 * sr) # Frame length set to 0.01667 seconds (~60 fps) 23 | hop_length = frame_length // 2 # 2x overlap for smoother transitions 24 | min_frames = 9 # Minimum number of frames needed for delta calculation 25 | 26 | num_frames = (len(y) - frame_length) // hop_length + 1 27 | 28 | if num_frames < min_frames: 29 | print(f"Audio file is too short: {num_frames} frames, required: {min_frames} frames") 30 | return None, None 31 | 32 | combined_features = extract_and_combine_features(y, sr, frame_length, hop_length) 33 | 34 | return combined_features, y 35 | 36 | def extract_and_combine_features(y, sr, frame_length, hop_length, include_autocorr=True): 37 | 38 | all_features = [] 39 | mfcc_features = extract_mfcc_features(y, sr, frame_length, hop_length) 40 | all_features.append(mfcc_features) 41 | 42 | if include_autocorr: 43 | autocorr_features = extract_autocorrelation_features( 44 | y, sr, frame_length, hop_length 45 | ) 46 | all_features.append(autocorr_features) 47 | 48 | combined_features = np.hstack(all_features) 49 | 50 | return combined_features 51 | 52 | 53 | def extract_mfcc_features(y, sr, frame_length, hop_length, num_mfcc=23): 54 | mfcc_features = extract_overlapping_mfcc(y, sr, num_mfcc, frame_length, hop_length) 55 | reduced_mfcc_features = reduce_features(mfcc_features) 56 | return reduced_mfcc_features.T 57 | 58 | def cepstral_mean_variance_normalization(mfcc): 59 | mean = np.mean(mfcc, axis=1, keepdims=True) 60 | std = np.std(mfcc, axis=1, keepdims=True) 61 | return (mfcc - mean) / (std + 1e-10) 62 | 63 | 64 | def extract_overlapping_mfcc(chunk, sr, num_mfcc, frame_length, hop_length, include_deltas=True, include_cepstral=True, threshold=1e-5): 65 | mfcc = librosa.feature.mfcc(y=chunk, sr=sr, n_mfcc=num_mfcc, n_fft=frame_length, hop_length=hop_length) 66 | if include_cepstral: 67 | mfcc = cepstral_mean_variance_normalization(mfcc) 68 | 69 | if include_deltas: 70 | delta_mfcc = librosa.feature.delta(mfcc) 71 | delta2_mfcc = librosa.feature.delta(mfcc, order=2) 72 | combined_mfcc = np.vstack([mfcc, delta_mfcc, delta2_mfcc]) # Stack original MFCCs with deltas 73 | return combined_mfcc 74 | else: 75 | return mfcc 76 | 77 | 78 | def reduce_features(features): 79 | num_frames = features.shape[1] 80 | paired_frames = features[:, :num_frames // 2 * 2].reshape(features.shape[0], -1, 2) 81 | reduced_frames = paired_frames.mean(axis=2) 82 | 83 | if num_frames % 2 == 1: 84 | last_frame = features[:, -1].reshape(-1, 1) 85 | reduced_final_features = np.hstack((reduced_frames, last_frame)) 86 | else: 87 | reduced_final_features = reduced_frames 88 | 89 | return reduced_final_features 90 | 91 | 92 | 93 | def extract_overlapping_autocorr(y, sr, frame_length, hop_length, num_autocorr_coeff=187, pad_signal=True, padding_mode="reflect", trim_padded=False): 94 | if pad_signal: 95 | pad = frame_length // 2 96 | y_padded = np.pad(y, pad_width=pad, mode=padding_mode) 97 | else: 98 | y_padded = y 99 | 100 | frames = librosa.util.frame(y_padded, frame_length=frame_length, 
hop_length=hop_length) 101 | if pad_signal and trim_padded: 102 | num_frames = frames.shape[1] 103 | start_indices = np.arange(num_frames) * hop_length 104 | valid_idx = np.where((start_indices >= pad) & (start_indices + frame_length <= len(y) + pad))[0] 105 | frames = frames[:, valid_idx] 106 | 107 | frames = frames - np.mean(frames, axis=0, keepdims=True) 108 | hann_window = np.hanning(frame_length) 109 | windowed_frames = frames * hann_window[:, np.newaxis] 110 | 111 | autocorr_list = [] 112 | for frame in windowed_frames.T: 113 | full_corr = np.correlate(frame, frame, mode='full') 114 | mid = frame_length - 1 # Zero-lag index. 115 | # Extract `num_autocorr_coeff + 1` to include the first column initially 116 | wanted = full_corr[mid: mid + num_autocorr_coeff + 1] 117 | # Normalize by the zero-lag (energy) if nonzero. 118 | if wanted[0] != 0: 119 | wanted = wanted / wanted[0] 120 | autocorr_list.append(wanted) 121 | 122 | # Convert list to array and transpose so that shape is (num_autocorr_coeff + 1, num_valid_frames) 123 | autocorr_features = np.array(autocorr_list).T 124 | # Remove the first coefficient to avoid redundancy 125 | autocorr_features = autocorr_features[1:, :] 126 | 127 | autocorr_features = fix_edge_frames_autocorr(autocorr_features) 128 | 129 | return autocorr_features 130 | 131 | 132 | def fix_edge_frames_autocorr(autocorr_features, zero_threshold=1e-7): 133 | """If the first or last frame is near all-zero, replicate from adjacent frames.""" 134 | # Check first frame energy 135 | if np.all(np.abs(autocorr_features[:, 0]) < zero_threshold): 136 | autocorr_features[:, 0] = autocorr_features[:, 1] 137 | # Check last frame energy 138 | if np.all(np.abs(autocorr_features[:, -1]) < zero_threshold): 139 | autocorr_features[:, -1] = autocorr_features[:, -2] 140 | return autocorr_features 141 | 142 | def extract_autocorrelation_features( 143 | y, sr, frame_length, hop_length, include_deltas=False 144 | ): 145 | """ 146 | Extract autocorrelation features, optionally with deltas/delta-deltas, 147 | then align with the MFCC frame count, reduce, and handle first/last frames. 
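
    (Editor's note) With the defaults used in this repository, the returned array
    has shape (num_reduced_frames, 187), i.e. the autocorrelation block of the
    256-dimensional model input.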
148 | """ 149 | autocorr_features = extract_overlapping_autocorr( 150 | y, sr, frame_length, hop_length 151 | ) 152 | 153 | if include_deltas: 154 | autocorr_features = compute_autocorr_with_deltas(autocorr_features) 155 | 156 | autocorr_features_reduced = reduce_features(autocorr_features) 157 | 158 | return autocorr_features_reduced.T 159 | 160 | 161 | def compute_autocorr_with_deltas(autocorr_base): 162 | delta_ac = librosa.feature.delta(autocorr_base) 163 | delta2_ac = librosa.feature.delta(autocorr_base, order=2) 164 | combined_autocorr = np.vstack([autocorr_base, delta_ac, delta2_ac]) 165 | return combined_autocorr 166 | 167 | def load_and_preprocess_audio(audio_path, sr=88200): 168 | y, sr = load_audio(audio_path, sr) 169 | if sr != 88200: 170 | y = librosa.resample(y, orig_sr=sr, target_sr=88200) 171 | sr = 88200 172 | 173 | max_val = np.max(np.abs(y)) 174 | if max_val > 0: 175 | y = y / max_val 176 | 177 | return y, sr 178 | 179 | def load_audio(audio_path, sr=88200): 180 | y, sr = librosa.load(audio_path, sr=sr) 181 | print(f"Loaded audio file '{audio_path}' with sample rate {sr}") 182 | return y, sr 183 | 184 | def load_audio_from_bytes(audio_bytes, sr=88200): 185 | audio_file = io.BytesIO(audio_bytes) 186 | y, sr = librosa.load(audio_file, sr=sr) 187 | 188 | max_val = np.max(np.abs(y)) 189 | if max_val > 0: 190 | y = y / max_val 191 | 192 | return y, sr 193 | 194 | def load_audio_file_from_memory(audio_bytes, sr=88200): 195 | """Load audio from memory bytes.""" 196 | y, sr = librosa.load(io.BytesIO(audio_bytes), sr=sr) 197 | print(f"Loaded audio data with sample rate {sr}") 198 | 199 | max_val = np.max(np.abs(y)) 200 | if max_val > 0: 201 | y = y / max_val 202 | 203 | return y, sr 204 | 205 | 206 | 207 | 208 | def load_pcm_audio_from_bytes(audio_bytes, sr=22050, channels=1, sample_width=2): 209 | """ 210 | Load raw PCM bytes into a normalized numpy array and upsample to 88200 Hz. 211 | Assumes little-endian, 16-bit PCM data. 212 | """ 213 | # Determine the appropriate numpy dtype. 214 | if sample_width == 2: 215 | dtype = np.int16 216 | max_val = 32768.0 217 | else: 218 | raise ValueError("Unsupported sample width") 219 | 220 | # Convert bytes to numpy array. 221 | data = np.frombuffer(audio_bytes, dtype=dtype) 222 | 223 | # If stereo or more channels, reshape accordingly. 224 | if channels > 1: 225 | data = data.reshape(-1, channels) 226 | 227 | # Normalize the data to range [-1, 1] 228 | y = data.astype(np.float32) / max_val 229 | 230 | # Upsample the audio from the current sample rate to 88200 Hz. 231 | target_sr = 88200 232 | if sr != target_sr: 233 | # Calculate the number of samples in the resampled signal. 234 | num_samples = int(len(y) * target_sr / sr) 235 | if channels > 1: 236 | # Resample each channel separately. 
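            # (Editor's note: scipy.signal.resample is FFT-based resampling; for long
            # clips scipy.signal.resample_poly is a common alternative, but the
            # original FFT-based call is kept here unchanged.)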
237 | y_resampled = np.zeros((num_samples, channels), dtype=np.float32) 238 | for ch in range(channels): 239 | y_resampled[:, ch] = scipy.signal.resample(y[:, ch], num_samples) 240 | else: 241 | y_resampled = scipy.signal.resample(y, num_samples) 242 | y = y_resampled 243 | sr = target_sr 244 | 245 | return y 246 | -------------------------------------------------------------------------------- /utils/audio/processing/audio_processing.py: -------------------------------------------------------------------------------- 1 | # This software is licensed under a **dual-license model** 2 | # For individuals and businesses earning **under $1M per year**, this software is licensed under the **MIT License** 3 | # Businesses or organizations with **annual revenue of $1,000,000 or more** must obtain permission to use this software commercially. 4 | 5 | # audio_processing.py 6 | 7 | import numpy as np 8 | import torch 9 | from torch.cuda.amp import autocast 10 | 11 | def decode_audio_chunk(audio_chunk, model, device, config): 12 | use_half_precision = config.get("use_half_precision", True) 13 | dtype = torch.float16 if use_half_precision else torch.float32 14 | src_tensor = torch.tensor(audio_chunk, dtype=dtype).unsqueeze(0).to(device) 15 | 16 | with torch.no_grad(): 17 | if use_half_precision: 18 | 19 | with autocast(dtype=torch.float16): 20 | encoder_outputs = model.encoder(src_tensor) 21 | output_sequence = model.decoder(encoder_outputs) 22 | else: 23 | encoder_outputs = model.encoder(src_tensor) 24 | output_sequence = model.decoder(encoder_outputs) 25 | 26 | decoded_outputs = output_sequence.squeeze(0).cpu().numpy() 27 | return decoded_outputs 28 | 29 | 30 | def concatenate_outputs(all_decoded_outputs, num_frames): 31 | final_decoded_outputs = np.concatenate(all_decoded_outputs, axis=0) 32 | final_decoded_outputs = final_decoded_outputs[:num_frames] 33 | return final_decoded_outputs 34 | 35 | def ensure_2d(final_decoded_outputs): 36 | if final_decoded_outputs.ndim == 3: 37 | final_decoded_outputs = final_decoded_outputs.reshape(-1, final_decoded_outputs.shape[-1]) 38 | return final_decoded_outputs 39 | 40 | def pad_audio_chunk(audio_chunk, frame_length, num_features, pad_mode='replicate'): 41 | if audio_chunk.shape[0] < frame_length: 42 | pad_length = frame_length - audio_chunk.shape[0] 43 | 44 | if pad_mode == 'reflect': 45 | padding = np.pad( 46 | audio_chunk, 47 | pad_width=((0, pad_length), (0, 0)), 48 | mode='reflect' 49 | ) 50 | audio_chunk = np.vstack((audio_chunk, padding[-pad_length:, :num_features])) 51 | 52 | elif pad_mode == 'replicate': 53 | last_frame = audio_chunk[-1:] 54 | replication = np.tile(last_frame, (pad_length, 1)) 55 | audio_chunk = np.vstack((audio_chunk, replication)) 56 | 57 | else: 58 | raise ValueError(f"Unsupported pad_mode: {pad_mode}. 
Choose 'reflect' or 'replicate'.") 59 | 60 | return audio_chunk 61 | 62 | 63 | def blend_chunks(chunk1, chunk2, overlap): 64 | actual_overlap = min(overlap, len(chunk1), len(chunk2)) 65 | if actual_overlap == 0: 66 | return np.vstack((chunk1, chunk2)) 67 | 68 | blended_chunk = np.copy(chunk1) 69 | for i in range(actual_overlap): 70 | alpha = i / actual_overlap 71 | blended_chunk[-actual_overlap + i] = (1 - alpha) * chunk1[-actual_overlap + i] + alpha * chunk2[i] 72 | 73 | return np.vstack((blended_chunk, chunk2[actual_overlap:])) 74 | 75 | def process_audio_features(audio_features, model, device, config): 76 | frame_length = config['frame_size'] 77 | overlap = config.get('overlap', 32) 78 | num_features = audio_features.shape[1] 79 | num_frames = audio_features.shape[0] 80 | all_decoded_outputs = [] 81 | model.eval() 82 | 83 | start_idx = 0 84 | while start_idx < num_frames: 85 | end_idx = min(start_idx + frame_length, num_frames) 86 | audio_chunk = audio_features[start_idx:end_idx] 87 | audio_chunk = pad_audio_chunk(audio_chunk, frame_length, num_features) 88 | decoded_outputs = decode_audio_chunk(audio_chunk, model, device, config) 89 | decoded_outputs = decoded_outputs[:end_idx - start_idx] 90 | 91 | if all_decoded_outputs: 92 | last_chunk = all_decoded_outputs.pop() 93 | blended_chunk = blend_chunks(last_chunk, decoded_outputs, overlap) 94 | all_decoded_outputs.append(blended_chunk) 95 | else: 96 | all_decoded_outputs.append(decoded_outputs) 97 | 98 | start_idx += frame_length - overlap 99 | 100 | current_length = sum(len(chunk) for chunk in all_decoded_outputs) 101 | if current_length < num_frames: 102 | remaining_frames = num_frames - current_length 103 | final_chunk_start = num_frames - remaining_frames 104 | audio_chunk = audio_features[final_chunk_start:num_frames] 105 | audio_chunk = pad_audio_chunk(audio_chunk, frame_length, num_features) 106 | decoded_outputs = decode_audio_chunk(audio_chunk, model, device, config) 107 | all_decoded_outputs.append(decoded_outputs[:remaining_frames]) 108 | 109 | final_decoded_outputs = np.concatenate(all_decoded_outputs, axis=0)[:num_frames] 110 | final_decoded_outputs = ensure_2d(final_decoded_outputs) 111 | 112 | final_decoded_outputs[:, :61] /= 100 113 | 114 | ease_duration_frames = min(int(0.1 * 60), final_decoded_outputs.shape[0]) 115 | easing_factors = np.linspace(0, 1, ease_duration_frames)[:, None] 116 | final_decoded_outputs[:ease_duration_frames] *= easing_factors 117 | 118 | final_decoded_outputs = zero_columns(final_decoded_outputs) 119 | 120 | return final_decoded_outputs 121 | 122 | 123 | def zero_columns(data): 124 | columns_to_zero = [0, 1, 2, 3, 4, 7, 8, 9, 10, 11, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60] 125 | modified_data = np.copy(data) 126 | modified_data[:, columns_to_zero] = 0 127 | return modified_data 128 | -------------------------------------------------------------------------------- /utils/config.py: -------------------------------------------------------------------------------- 1 | config = { 2 | 'sr': 88200, 3 | 'frame_rate': 60, 4 | 'hidden_dim': 1024, 5 | 'n_layers': 8, 6 | 'num_heads': 16, 7 | 'dropout': 0.0, 8 | 'output_dim': 68, # if you trained your own, this should also be 61 9 | 'input_dim': 256, 10 | 'frame_size': 128, 11 | 'use_half_precision': False 12 | } 13 | -------------------------------------------------------------------------------- /utils/generate_face_shapes.py: -------------------------------------------------------------------------------- 1 | # This software is licensed under a 
**dual-license model** 2 | # For individuals and businesses earning **under $1M per year**, this software is licensed under the **MIT License** 3 | # Businesses or organizations with **annual revenue of $1,000,000 or more** must obtain permission to use this software commercially. 4 | 5 | # generate_face_shapes.py 6 | 7 | import numpy as np 8 | 9 | from utils.audio.extraction.extract_features import extract_audio_features 10 | from utils.audio.processing.audio_processing import process_audio_features 11 | 12 | def generate_facial_data_from_bytes(audio_bytes, model, device, config): 13 | 14 | audio_features, y = extract_audio_features(audio_bytes, from_bytes=True) 15 | 16 | if audio_features is None or y is None: 17 | return [], np.array([]) 18 | 19 | final_decoded_outputs = process_audio_features(audio_features, model, device, config) 20 | 21 | return final_decoded_outputs 22 | 23 | -------------------------------------------------------------------------------- /utils/model/model.py: -------------------------------------------------------------------------------- 1 | # This software is licensed under a **dual-license model** 2 | # For individuals and businesses earning **under $1M per year**, this software is licensed under the **MIT License** 3 | # Businesses or organizations with **annual revenue of $1,000,000 or more** must obtain permission to use this software commercially. 4 | 5 | import torch 6 | import torch.nn as nn 7 | import torch.nn.functional as F 8 | 9 | def load_model(model_path, config, device): 10 | device = torch.device(device) 11 | 12 | # Retrieve the half precision setting from the config 13 | use_half_precision = config.get('use_half_precision', True) 14 | 15 | # 🔥 NEW: Check for CUDA and cuDNN availability. 16 | # If half precision is requested but CUDA or cuDNN are not available, 17 | # fall back to full precision and update the config. 18 | if use_half_precision: 19 | if not (device.type == 'cuda' and torch.cuda.is_available() and torch.backends.cudnn.enabled): 20 | print("⚠ Half-precision requested but CUDA or cuDNN not available. 
Falling back to full precision.") 21 | use_half_precision = False 22 | config['use_half_precision'] = False # Update config to reflect the fallback 23 | 24 | hidden_dim = config['hidden_dim'] 25 | n_layers = config['n_layers'] 26 | num_heads = config['num_heads'] 27 | 28 | encoder = Encoder(config['input_dim'], hidden_dim, n_layers, num_heads) 29 | decoder = Decoder(config['output_dim'], hidden_dim, n_layers, num_heads) 30 | model = Seq2Seq(encoder, decoder, device).to(device) 31 | 32 | state_dict = torch.load(model_path, map_location=device) 33 | model.load_state_dict(state_dict, strict=True) 34 | 35 | # Convert the model to half precision if applicable 36 | if use_half_precision and device.type == 'cuda': 37 | model = model.to(torch.float16) 38 | print("⚡ Model converted to float16 (half-precision).") 39 | else: 40 | print("🚫 Half-precision not applied (CPU or unsupported GPU or False set in config).") 41 | 42 | model.eval() 43 | return model 44 | 45 | 46 | 47 | # ------------------------------------------------------------------------------------------- 48 | # Seq2Seq Model 49 | # ------------------------------------------------------------------------------------------- 50 | class Seq2Seq(nn.Module): 51 | def __init__(self, encoder, decoder, device): 52 | super(Seq2Seq, self).__init__() 53 | self.encoder = encoder 54 | self.decoder = decoder 55 | self.device = device 56 | 57 | def forward(self, src): 58 | encoder_outputs = self.encoder(src) 59 | output = self.decoder(encoder_outputs) 60 | return output 61 | 62 | # ------------------------------------------------------------------------------------------- 63 | # Rotary Positional Embedding (RoPE) for Local Attention 64 | # ------------------------------------------------------------------------------------------- 65 | def apply_rope_qk(q, k, use_local_positional_encoding=True): 66 | if not use_local_positional_encoding: 67 | return q, k # Return unmodified q, k if RoPE is disabled 68 | 69 | batch_size, num_heads, seq_len, head_dim = q.size() 70 | assert head_dim % 2 == 0, "head_dim must be even for RoPE" 71 | 72 | position = torch.arange(seq_len, dtype=torch.float, device=q.device).unsqueeze(1) # (seq_len, 1) 73 | dim_indices = torch.arange(0, head_dim, 2, dtype=torch.float, device=q.device) # (head_dim // 2) 74 | div_term = torch.exp(-torch.log(torch.tensor(10000.0)) * dim_indices / head_dim) 75 | 76 | angle = position * div_term # (seq_len, head_dim // 2) 77 | sin = torch.sin(angle).unsqueeze(0).unsqueeze(0) # (1, 1, seq_len, head_dim // 2) 78 | cos = torch.cos(angle).unsqueeze(0).unsqueeze(0) # (1, 1, seq_len, head_dim // 2) 79 | 80 | def rope_transform(x): 81 | x1, x2 = x[..., ::2], x[..., 1::2] # Split into even and odd parts 82 | x_rope_even = x1 * cos - x2 * sin 83 | x_rope_odd = x1 * sin + x2 * cos 84 | return torch.stack([x_rope_even, x_rope_odd], dim=-1).flatten(-2) 85 | 86 | q = rope_transform(q) 87 | k = rope_transform(k) 88 | return q, k 89 | 90 | 91 | # ------------------------------------------------------------------------------------------- 92 | # Multi-Head Attention with RoPE 93 | # ------------------------------------------------------------------------------------------- 94 | class MultiHeadAttention(nn.Module): 95 | def __init__(self, hidden_dim, num_heads, dropout=0.0): 96 | super(MultiHeadAttention, self).__init__() 97 | assert hidden_dim % num_heads == 0, "Hidden dimension must be divisible by the number of heads" 98 | self.num_heads = num_heads 99 | self.head_dim = hidden_dim // num_heads 100 | self.scaling = 
self.head_dim ** -0.5 101 | 102 | self.q_linear = nn.Linear(hidden_dim, hidden_dim) 103 | self.k_linear = nn.Linear(hidden_dim, hidden_dim) 104 | self.v_linear = nn.Linear(hidden_dim, hidden_dim) 105 | self.out_linear = nn.Linear(hidden_dim, hidden_dim) 106 | 107 | self.attn_dropout = nn.Dropout(dropout) 108 | self.resid_dropout = nn.Dropout(dropout) 109 | self.dropout = dropout 110 | 111 | self.flash = hasattr(torch.nn.functional, 'scaled_dot_product_attention') 112 | if not self.flash: 113 | print("WARNING: Flash Attention requires PyTorch >= 2.0") 114 | 115 | def forward(self, query, key, value, mask=None): 116 | batch_size = query.size(0) 117 | 118 | query = self.q_linear(query) 119 | key = self.k_linear(key) 120 | value = self.v_linear(value) 121 | 122 | # Reshape to (B, H, L, D) 123 | query = query.view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2) 124 | key = key.view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2) 125 | value = value.view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2) 126 | 127 | # Apply RoPE to Q and K (if enabled) 128 | query, key = apply_rope_qk(query, key) 129 | 130 | if self.flash: 131 | attn_output = torch.nn.functional.scaled_dot_product_attention( 132 | query, key, value, attn_mask=mask, dropout_p=self.dropout if self.training else 0) 133 | attn_weights = None 134 | else: 135 | scores = torch.matmul(query, key.transpose(-2, -1)) * self.scaling 136 | if mask is not None: 137 | scores = scores.masked_fill(mask == 0, float('-inf')) 138 | attn_weights = F.softmax(scores, dim=-1) 139 | attn_weights = self.attn_dropout(attn_weights) 140 | attn_output = torch.matmul(attn_weights, value) 141 | 142 | attn_output = attn_output.transpose(1, 2).contiguous().view(batch_size, -1, self.num_heads * self.head_dim) 143 | output = self.out_linear(attn_output) 144 | output = self.resid_dropout(output) 145 | 146 | return output, attn_weights 147 | 148 | # ------------------------------------------------------------------------------------------- 149 | # Feed-Forward Network 150 | # ------------------------------------------------------------------------------------------- 151 | class FeedForwardNetwork(nn.Module): 152 | def __init__(self, hidden_dim, dim_feedforward=2048, dropout=0.0): 153 | super(FeedForwardNetwork, self).__init__() 154 | self.linear1 = nn.Linear(hidden_dim, dim_feedforward) 155 | self.dropout = nn.Dropout(dropout) 156 | self.linear2 = nn.Linear(dim_feedforward, hidden_dim) 157 | 158 | def forward(self, x): 159 | x = self.linear1(x) 160 | x = F.relu(x) 161 | x = self.dropout(x) 162 | x = self.linear2(x) 163 | return x 164 | 165 | # ------------------------------------------------------------------------------------------- 166 | # Custom Transformer Encoder/Decoder 167 | # ------------------------------------------------------------------------------------------- 168 | class CustomTransformerEncoderLayer(nn.Module): 169 | def __init__(self, hidden_dim, num_heads, dropout=0.0): 170 | super(CustomTransformerEncoderLayer, self).__init__() 171 | self.self_attn = MultiHeadAttention(hidden_dim, num_heads, dropout) 172 | self.ffn = FeedForwardNetwork(hidden_dim, 4 * hidden_dim, dropout) 173 | self.norm1 = nn.LayerNorm(hidden_dim) 174 | self.norm2 = nn.LayerNorm(hidden_dim) 175 | self.dropout1 = nn.Dropout(dropout) 176 | self.dropout2 = nn.Dropout(dropout) 177 | 178 | def forward(self, src, mask=None): 179 | src2, _ = self.self_attn(src, src, src, mask) 180 | src = src + self.dropout1(src2) 181 | src = self.norm1(src) 182 
| 183 | src2 = self.ffn(src) 184 | src = src + self.dropout2(src2) 185 | src = self.norm2(src) 186 | return src 187 | 188 | class CustomTransformerDecoderLayer(nn.Module): 189 | def __init__(self, hidden_dim, num_heads, dropout=0.0): 190 | super(CustomTransformerDecoderLayer, self).__init__() 191 | self.self_attn = MultiHeadAttention(hidden_dim, num_heads, dropout) 192 | self.multihead_attn = MultiHeadAttention(hidden_dim, num_heads, dropout) 193 | self.ffn = FeedForwardNetwork(hidden_dim, 4 * hidden_dim, dropout) 194 | self.norm1 = nn.LayerNorm(hidden_dim) 195 | self.norm2 = nn.LayerNorm(hidden_dim) 196 | self.norm3 = nn.LayerNorm(hidden_dim) 197 | self.dropout1 = nn.Dropout(dropout) 198 | self.dropout2 = nn.Dropout(dropout) 199 | self.dropout3 = nn.Dropout(dropout) 200 | 201 | def forward(self, tgt, memory, tgt_mask=None, memory_mask=None): 202 | tgt2, _ = self.self_attn(tgt, tgt, tgt, tgt_mask) 203 | tgt = tgt + self.dropout1(tgt2) 204 | tgt = self.norm1(tgt) 205 | 206 | tgt2, _ = self.multihead_attn(tgt, memory, memory, memory_mask) 207 | tgt = tgt + self.dropout2(tgt2) 208 | tgt = self.norm2(tgt) 209 | 210 | tgt2 = self.ffn(tgt) 211 | tgt = tgt + self.dropout3(tgt2) 212 | tgt = self.norm3(tgt) 213 | return tgt 214 | 215 | # ------------------------------------------------------------------------------------------- 216 | # Encoder 217 | # ------------------------------------------------------------------------------------------- 218 | class Encoder(nn.Module): 219 | def __init__(self, input_dim, hidden_dim, n_layers, num_heads, dropout=0.0, use_norm=True): 220 | super(Encoder, self).__init__() 221 | self.embedding = nn.Linear(input_dim, hidden_dim) 222 | # CHANGED: Removed global positional encoding as RoPE is used in MHA. 223 | self.transformer_encoder = nn.ModuleList([ 224 | CustomTransformerEncoderLayer(hidden_dim, num_heads, dropout) for _ in range(n_layers) 225 | ]) 226 | self.layer_norm = nn.LayerNorm(hidden_dim) if use_norm else None 227 | 228 | def forward(self, x): 229 | x = self.embedding(x) 230 | # CHANGED: Global positional encoding removed. 231 | for layer in self.transformer_encoder: 232 | x = layer(x) 233 | if self.layer_norm: 234 | x = self.layer_norm(x) 235 | return x 236 | 237 | # ------------------------------------------------------------------------------------------- 238 | # Decoder 239 | # ------------------------------------------------------------------------------------------- 240 | class Decoder(nn.Module): 241 | def __init__(self, output_dim, hidden_dim, n_layers, num_heads, dropout=0.0, use_norm=True): 242 | super(Decoder, self).__init__() 243 | self.transformer_decoder = nn.ModuleList([ 244 | CustomTransformerDecoderLayer(hidden_dim, num_heads, dropout) for _ in range(n_layers) 245 | ]) 246 | self.fc_output = nn.Linear(hidden_dim, output_dim) 247 | self.layer_norm = nn.LayerNorm(hidden_dim) if use_norm else None 248 | 249 | def forward(self, encoder_outputs): 250 | x = encoder_outputs 251 | for layer in self.transformer_decoder: 252 | x = layer(x, encoder_outputs) 253 | if self.layer_norm: 254 | x = self.layer_norm(x) 255 | return self.fc_output(x) 256 | 257 | --------------------------------------------------------------------------------
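For reference, here is a minimal client-side sketch for exercising the `/audio_to_blendshapes` endpoint exposed by `neurosync_local_api.py`. It assumes the API is running locally on port 5000 (as configured in that script), that the `requests` package is installed, and that `speech.wav` is a placeholder path to a WAV file of your own; none of these names are part of the repository itself.

```python
# Minimal client sketch for the local API (assumes `pip install requests`;
# "speech.wav" below is a hypothetical file path).
import requests

API_URL = "http://127.0.0.1:5000/audio_to_blendshapes"  # host/port from neurosync_local_api.py

with open("speech.wav", "rb") as f:
    audio_bytes = f.read()

# The endpoint reads the raw request body (request.data), so post the bytes directly.
response = requests.post(API_URL, data=audio_bytes)
response.raise_for_status()

blendshapes = response.json()["blendshapes"]  # list of per-frame coefficient lists
print(f"Received {len(blendshapes)} frames of blendshape data")
```

Sending a WAV container (rather than raw PCM) follows the primary loading path in `extract_audio_features`; raw 16-bit PCM at 22050 Hz is handled by the fallback loader `load_pcm_audio_from_bytes`.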