├── .gitignore ├── BIWI ├── ForProcessing │ ├── FaceData │ │ ├── README.md │ │ └── faces │ │ │ ├── Extract_all.sh │ │ │ ├── delete_vls.sh │ │ │ ├── vertex_process.py │ │ │ └── vl2csv_recursive.m │ └── rest │ │ ├── README.md │ │ └── wav_process.py ├── templates │ ├── BIWI_topology.obj │ └── README.md ├── templates_orig.pkl ├── templates_scaled.pkl ├── vertices_npy │ └── README.md └── wav │ └── README.md ├── Evaluation ├── GroundTruth │ └── README.md ├── render_quant_evaluation.py └── renders │ ├── temp │ ├── frames │ │ └── README.md │ └── meshes │ │ └── README.md │ ├── videos_no_audio │ └── README.md │ └── videos_with_audio │ └── README.md ├── FaceXHuBERT.png ├── LICENSE ├── README.md ├── data_loader.py ├── demo ├── render │ ├── frames │ │ └── README.md │ ├── video_with_audio │ │ └── README.md │ └── video_wo_audio │ │ └── README.md ├── result │ └── README.md └── wav │ ├── README.md │ └── test.wav ├── environment.yml ├── faceXhubert.py ├── hubert ├── activations.py ├── configuration_hubert.py ├── configuration_utils.py ├── deepspeed.py ├── dependency_versions_check.py ├── dependency_versions_table.py ├── file_utils.py ├── generation_beam_search.py ├── generation_logits_process.py ├── generation_stopping_criteria.py ├── generation_utils.py ├── modeling_hubert.py ├── modeling_outputs.py ├── modeling_utils.py └── utils │ ├── logging.py │ └── versions.py ├── index.html ├── main.py ├── modeling_hubert.py ├── page_assets ├── bibtex.txt └── paper.png ├── predict.py ├── pretrained_model └── README.md ├── render_result.py ├── renders └── render_folder │ ├── temp │ └── frames │ │ └── README.md │ ├── videos_no_audio │ └── README.md │ └── videos_with_audio │ └── README.md ├── result └── README.md └── save └── README.md /.gitignore: -------------------------------------------------------------------------------- 1 | # ignore the data files 2 | BIWI/wav/*.wav 3 | 4 | # ignore the following extensions 5 | *.npy 6 | *.pth 7 | *.mp4 8 | -------------------------------------------------------------------------------- /BIWI/ForProcessing/FaceData/README.md: -------------------------------------------------------------------------------- 1 | Put "faces0*.tgz" archive files here. -------------------------------------------------------------------------------- /BIWI/ForProcessing/FaceData/faces/Extract_all.sh: -------------------------------------------------------------------------------- 1 | # for d in */; 2 | # do 3 | # echo "put ${d}"; 4 | # cd ${d}; 5 | # echo "starting extract!"; 6 | # for file in *; do 7 | # if [ -f "$file" ]; then 8 | # if [[ $file == *.tar ]]; then 9 | # echo "$file" 10 | # tar -xf $file --one-top-level 11 | # fi 12 | # fi 13 | # done 14 | # cd .. 15 | 16 | # done 17 | 18 | 19 | for d in */; 20 | do 21 | echo "put ${d}"; 22 | cd ${d}; 23 | echo "starting extract!"; 24 | for file in *; do 25 | if [ -f "$file" ]; then 26 | if [[ $file == *.tar ]]; then 27 | echo "$file" 28 | tar -xf $file --one-top-level 29 | fi 30 | fi 31 | done 32 | cd .. 33 | 34 | done -------------------------------------------------------------------------------- /BIWI/ForProcessing/FaceData/faces/delete_vls.sh: -------------------------------------------------------------------------------- 1 | # for d in */; 2 | # do 3 | # echo "put ${d}"; 4 | # cd ${d}; 5 | # for d2 in */; 6 | # cd ${d2}; 7 | # for file in *; do 8 | # if [ -f "$file" ]; then 9 | # if [[ $file == *.vl ]]; then 10 | # echo "$file" 11 | # #rm $file 12 | # fi 13 | # fi 14 | # done 15 | # cd .. 16 | # cd .. 
17 | 18 | # done 19 | 20 | 21 | # for file in *; do 22 | # if [ -f "$file" ]; then 23 | # echo "$file" 24 | # fi 25 | # done 26 | 27 | find . -print0 | while IFS= read -r -d '' file 28 | do 29 | if [ -f "$file" ]; then 30 | if [[ $file == *.vl ]]; then 31 | #echo "$file" 32 | rm -v $file 33 | fi 34 | fi 35 | done 36 | 37 | -------------------------------------------------------------------------------- /BIWI/ForProcessing/FaceData/faces/vertex_process.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import os 3 | import pandas as pd 4 | 5 | input_path = "./" 6 | output_path = "../../../vertices_npy/" 7 | output_file = "" 8 | 9 | subjects = [ name for name in os.listdir(input_path) if os.path.isdir(os.path.join(input_path, name)) ] 10 | 11 | #print(subjects) 12 | 13 | for subject in subjects: 14 | seqs = [ name for name in os.listdir(input_path+"/"+subject) if os.path.isdir(os.path.join(input_path+"/"+subject, name)) ] 15 | # print(seqs) 16 | for seq in seqs: 17 | print("now getting subject ", subject, " and seq ", seq) 18 | frame_files = os.listdir(input_path+"/"+subject+"/"+seq) 19 | if seq[0]=='e': 20 | new_seq = int(seq[1:3]) + 40 21 | output_file = subject+"_"+str(new_seq)+".npy" 22 | else: 23 | output_file = subject+"_"+seq+".npy" 24 | full_output_path = output_path+output_file 25 | empty_arr = np.empty((0,70110)) 26 | for file in frame_files: 27 | if file.endswith(".csv"): 28 | # print(file) 29 | df=pd.read_csv(input_path+"/"+subject+"/"+seq+"/"+file, sep=',',header=None) 30 | df = df.iloc[1: , :] 31 | frame = df.to_numpy() 32 | frame = frame.flatten() 33 | frame = np.expand_dims(frame, axis=0) 34 | empty_arr = np.append(empty_arr,frame,axis=0) 35 | np.save(full_output_path, empty_arr) 36 | print("vertice shape: ", empty_arr.shape) 37 | print("saved to ", full_output_path) 38 | 39 | templates_orig = np.load("../../../templates_orig.pkl", allow_pickle=True) 40 | templates_scaled = np.load("../../../templates_scaled.pkl", allow_pickle=True) 41 | 42 | 43 | path = "../../../vertices_npy/" 44 | 45 | npys = [ name for name in os.listdir(path) if name.endswith('.npy') ] 46 | 47 | for npy in npys: 48 | file_path = path + npy 49 | subject = npy.split('_')[0] 50 | subject_template = templates_orig[subject] 51 | 52 | subject_x_mean = subject_template[:,0].mean() 53 | subject_x_max = subject_template[:,0].max() 54 | subject_x_min = subject_template[:,0].min() 55 | 56 | subject_y_mean = subject_template[:,1].mean() 57 | subject_y_max = subject_template[:,1].max() 58 | subject_y_min = subject_template[:,1].min() 59 | 60 | subject_z_mean = subject_template[:,2].mean() 61 | subject_z_max = subject_template[:,2].max() 62 | subject_z_min = subject_template[:,2].min() 63 | 64 | # print(subject) 65 | seq = np.load(file_path, allow_pickle=True).astype(float) 66 | # print(seq.shape) 67 | seq = np.reshape(seq,(-1,70110//3,3)) 68 | for f in range(seq.shape[0]): 69 | frame = seq[f,:,:] 70 | X = (frame[:,0]-subject_x_mean)/(subject_x_max-subject_x_min) 71 | Y = (frame[:,1]-subject_y_mean)/(subject_y_max-subject_y_min) 72 | Z = (frame[:,2]-subject_z_mean)/(subject_z_max-subject_z_min) 73 | frame[:,0] = X 74 | frame[:,1] = Y 75 | frame[:,2] = Z 76 | 77 | seq[f,:,:] = frame 78 | seq = seq.reshape(seq.shape[0],seq.shape[1]*seq.shape[2]) 79 | # print(seq.shape) 80 | np.save(file_path, seq) -------------------------------------------------------------------------------- /BIWI/ForProcessing/FaceData/faces/vl2csv_recursive.m: 
-------------------------------------------------------------------------------- 1 | rootdir = pwd; 2 | filelist = dir(fullfile(rootdir, '**\*.*')); %get list of files and folders in any subfolder 3 | filelist = filelist(~[filelist.isdir]); %remove folders from list 4 | 5 | % display(filelist(55).folder); 6 | % path = filelist(55).folder; 7 | % fname = filelist(55).name; 8 | % outfname = '\' + string(fname(1:end-2)) + 'csv'; 9 | % outfilepath = path+outfname; 10 | tic 11 | for i=1:size(filelist,1) 12 | currfile = filelist(i); 13 | path = currfile.folder; 14 | fname = currfile.name; 15 | [~, ~, ext] = fileparts(fname); 16 | if isequal(ext, '.vl') 17 | outfname = '\' + string(fname(1:end-2)) + 'csv'; 18 | outfilepath = path+outfname; 19 | disp(outfilepath); 20 | infilepath = path+"\"+fname; 21 | fid = fopen(infilepath); % load the .vl file 22 | n_vertices = fread(fid, 1, 'ulong'); 23 | vertices = fread(fid, [3, n_vertices] , 'float'); 24 | disp("Loaded the .vl file! No. of Vs: " + n_vertices); 25 | 26 | fclose(fid); 27 | header = {'x', 'y', 'z'}; 28 | textHeader = strjoin(header, ','); 29 | %write header to file 30 | fid2 = fopen(outfilepath,'w'); 31 | fprintf(fid2,'%s\n', textHeader); 32 | fclose(fid2); 33 | dlmwrite(outfilepath, vertices.', '-append'); 34 | 35 | end 36 | end 37 | toc -------------------------------------------------------------------------------- /BIWI/ForProcessing/rest/README.md: -------------------------------------------------------------------------------- 1 | Put "rest.tgz" downloaded from the BIWI dataset here. 2 | -------------------------------------------------------------------------------- /BIWI/ForProcessing/rest/wav_process.py: -------------------------------------------------------------------------------- 1 | import os 2 | import shutil 3 | 4 | wav_in = "./audio/" 5 | wav_out = "../../wav/" 6 | 7 | wavs = os.listdir(wav_in) 8 | 9 | for wav in wavs: 10 | input_path = wav_in + wav 11 | # print(input_path) 12 | if wav.endswith('_cut.wav'): 13 | # print(wav) 14 | filename = wav.split('_') 15 | sub = filename[0] 16 | seq = filename[1] 17 | # print(sub) 18 | # print(seq) 19 | if seq[0]=='e': 20 | # print(seq) 21 | new_seq = int(seq[1:3]) + 40 22 | # print(new_seq) 23 | output_path = wav_out + sub + "_" + str(new_seq) + ".wav" 24 | shutil.copy2(input_path, output_path) 25 | # os.rename(input_path, output_path) 26 | else: 27 | output_path = wav_out + sub + "_" + seq + ".wav" 28 | shutil.copy2(input_path, output_path) -------------------------------------------------------------------------------- /BIWI/templates/README.md: -------------------------------------------------------------------------------- 1 | Put the template .obj files here in this folder. 2 | i.e. 
- F1.obj, F2.obj, F3.obj, F4.obj, F5.obj, F6.obj, F7.obj, F8.obj, M1.obj, M2.obj, M3.obj, M4.obj, M5.obj, M6.obj -------------------------------------------------------------------------------- /BIWI/templates_orig.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/galib360/FaceXHuBERT/f54f9a99282b6a3b0b99770cbc50cb7ea7f3746b/BIWI/templates_orig.pkl -------------------------------------------------------------------------------- /BIWI/templates_scaled.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/galib360/FaceXHuBERT/f54f9a99282b6a3b0b99770cbc50cb7ea7f3746b/BIWI/templates_scaled.pkl -------------------------------------------------------------------------------- /BIWI/vertices_npy/README.md: -------------------------------------------------------------------------------- 1 | This folder will contain the vertex data for all the sequences in `.npy` format. -------------------------------------------------------------------------------- /BIWI/wav/README.md: -------------------------------------------------------------------------------- 1 | This folder will contain the audio data for all the sequences in `.wav` format. -------------------------------------------------------------------------------- /Evaluation/GroundTruth/README.md: -------------------------------------------------------------------------------- 1 | Put the .npy sequences in this folder that are in your test split during training. You can copy and paste the test split .npy files from `BIWI/vertices_npy/` folder. -------------------------------------------------------------------------------- /Evaluation/render_quant_evaluation.py: -------------------------------------------------------------------------------- 1 | import math 2 | import trimesh 3 | import numpy as np 4 | import cv2 5 | import io 6 | from PIL import Image 7 | import os 8 | import ffmpeg 9 | import gc 10 | import pyrender 11 | import pymeshlab as pmlab 12 | 13 | quantfilename = "quantitative_metric.txt" 14 | render_folder = "renders/" 15 | gt_folder = "GroundTruth/" 16 | pred_folder = "../result/" 17 | audio_folder = "../BIWI/wav/" 18 | video_woA_folder = render_folder + "videos_no_audio/" 19 | video_wA_folder = render_folder + "videos_with_audio/" 20 | meshes_folder = render_folder+ "temp/meshes/" 21 | frames_folder = render_folder+ "temp/frames/" 22 | 23 | mean_face_vertex_error = 0 24 | 25 | gt_seqs = os.listdir(gt_folder) 26 | pred_seqs = os.listdir(pred_folder) 27 | 28 | fps = 25 29 | fourcc = cv2.VideoWriter_fourcc(*'MP4V') 30 | 31 | cam = pyrender.PerspectiveCamera(yfov=np.pi / 3.0, aspectRatio=1.414) 32 | camera_pose = np.array([[1.0, 0, 0.0, 0.00], 33 | [0.0, -1.0, 0.0, 0.00], 34 | [0.0, 0.0, 1.0, -2.0], 35 | [0.0, 0.0, 0.0, 1.0]]) 36 | 37 | 38 | light = pyrender.DirectionalLight(color=[1.0, 1.0, 1.0], intensity=5.0) 39 | 40 | r = pyrender.OffscreenRenderer(640, 480) 41 | 42 | print("Evaluation started") 43 | for gt_seq in gt_seqs: 44 | if gt_seq.endswith('.npy'): 45 | video_woA_path = video_woA_folder + gt_seq.split('.')[0] + '.mp4' 46 | video_wA_path = video_wA_folder + gt_seq.split('.')[0] + '.mp4' 47 | video = cv2.VideoWriter(video_woA_path, fourcc, fps, (640, 480)) 48 | gt_seq_path = gt_folder + gt_seq 49 | pred_seq_path = pred_folder + gt_seq.split('.')[0] + "_condition_" + gt_seq.split('_')[0] + ".npy" 50 | print("Now evaluating sequence: ", gt_seq) 51 | subject_template_path = "../BIWI/templates/"+ 
gt_seq.split('_')[0] + ".obj" 52 | audio = gt_seq.split('.')[0].split('_')[0]+'_'+gt_seq.split('.')[0].split('_')[1]+'.wav' 53 | audio_path = audio_folder + audio 54 | # print(gt_seq_path) 55 | # print(pred_seq_path) 56 | # print(audio_path) 57 | # print(subject_template_path) 58 | # ref_mesh = trimesh.load_mesh(subject_template_path, process=False) 59 | 60 | gt_seq = np.load(gt_seq_path) 61 | pred_seq = np.load(pred_seq_path) 62 | 63 | if(gt_seq.shape[0]>pred_seq.shape[0]): 64 | gt_seq = gt_seq[:pred_seq.shape[0]] 65 | 66 | if(pred_seq.shape[0]>gt_seq.shape[0]): 67 | pred_seq = pred_seq[:gt_seq.shape[0]] 68 | 69 | 70 | gt_seq = np.reshape(gt_seq,(-1,70110//3,3)) 71 | pred_seq = np.reshape(pred_seq,(-1,70110//3,3)) 72 | sequence_mean_face_vertex_error = 0 73 | 74 | 75 | # seq = np.reshape(seq,(-1,70110//3,3)) 76 | # ref_mesh.vertices = seq[0,:,:] 77 | # py_mesh = pyrender.Mesh.from_trimesh(ref_mesh) 78 | for f in range(pred_seq.shape[0]): 79 | ms = pmlab.MeshSet() 80 | ms.load_new_mesh(subject_template_path) 81 | template_mesh= ms.current_mesh() 82 | 83 | gt_mesh = pmlab.Mesh(gt_seq[f,:,:], template_mesh.face_matrix(), template_mesh.vertex_normal_matrix()) 84 | ms.add_mesh(gt_mesh) 85 | # print(ms.current_mesh_id()) 86 | 87 | pred_mesh = pmlab.Mesh(pred_seq[f,:,:], template_mesh.face_matrix(), template_mesh.vertex_normal_matrix()) 88 | ms.add_mesh(pred_mesh) 89 | # print(ms.current_mesh_id()) 90 | 91 | ms.apply_filter('distance_from_reference_mesh', measuremesh=2, refmesh=1, signeddist = False) 92 | ms.set_current_mesh(2) 93 | vertex_quality = ms.current_mesh().vertex_quality_array() 94 | #ms.apply_filter('colorize_by_vertex_quality', minval=ms.current_mesh().vertex_quality_array().min(), maxval=ms.current_mesh().vertex_quality_array().max(),zerosym=True) 95 | ms.apply_filter('quality_mapper_applier', minqualityval=ms.current_mesh().vertex_quality_array().min(), maxqualityval=ms.current_mesh().vertex_quality_array().max(),tfslist= 1) 96 | ms.save_current_mesh( meshes_folder + str(f) +".obj", save_vertex_color=True) 97 | sequence_mean_face_vertex_error = sequence_mean_face_vertex_error + vertex_quality.mean(axis=None) 98 | ms.set_current_mesh(2) 99 | ms.delete_current_mesh() 100 | ms.set_current_mesh(1) 101 | ms.delete_current_mesh() 102 | ms.set_current_mesh(0) 103 | 104 | render_mesh = trimesh.load_mesh(meshes_folder + str(f) +".obj", process=False) 105 | py_mesh = pyrender.Mesh.from_trimesh(render_mesh) 106 | 107 | # ref_mesh.vertices = seq[f,:,:] 108 | # py_mesh = pyrender.Mesh.from_trimesh(ref_mesh) 109 | scene = pyrender.Scene() 110 | scene.add(py_mesh) 111 | 112 | scene.add(cam, pose=camera_pose) 113 | scene.add(light, pose=camera_pose) 114 | color, _ = r.render(scene) 115 | 116 | output_frame = frames_folder + "frame" + str(f) + ".jpg" 117 | image_bgr = cv2.cvtColor(color, cv2.COLOR_RGB2BGR) 118 | cv2.imwrite(output_frame, image_bgr) 119 | frame = cv2.imread(output_frame) 120 | video.write(frame) 121 | video.release() 122 | sequence_mean_face_vertex_error = sequence_mean_face_vertex_error/pred_seq.shape[0] 123 | mean_face_vertex_error = mean_face_vertex_error + sequence_mean_face_vertex_error 124 | 125 | input_video = ffmpeg.input(video_woA_path) 126 | input_audio = ffmpeg.input(audio_path) 127 | ffmpeg.concat(input_video, input_audio, v=1, a=1).output(video_wA_path).run() 128 | del video 129 | gc.collect() 130 | 131 | 132 | mean_face_vertex_error = mean_face_vertex_error/len(gt_seqs) 133 | 134 | file = open(quantfilename, "w") 135 | 136 | #convert variable to string 137 | str = 
repr(mean_face_vertex_error) 138 | file.write("mean_face_vertex_error = " + str + "\n") 139 | 140 | file.close() 141 | print("Done!") -------------------------------------------------------------------------------- /Evaluation/renders/temp/frames/README.md: -------------------------------------------------------------------------------- 1 | This folder will contain temporary frames for the rendered videos of the evaluation script. -------------------------------------------------------------------------------- /Evaluation/renders/temp/meshes/README.md: -------------------------------------------------------------------------------- 1 | This folder will contain temporary face meshes used for rendering the evaluation renders. -------------------------------------------------------------------------------- /Evaluation/renders/videos_no_audio/README.md: -------------------------------------------------------------------------------- 1 | This folder will contain the rendered videos without audio. -------------------------------------------------------------------------------- /Evaluation/renders/videos_with_audio/README.md: -------------------------------------------------------------------------------- 1 | This folder will contain the rendered videos with audio. -------------------------------------------------------------------------------- /FaceXHuBERT.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/galib360/FaceXHuBERT/f54f9a99282b6a3b0b99770cbc50cb7ea7f3746b/FaceXHuBERT.png -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | # Attribution-NonCommercial 4.0 International 2 | 3 | > *Creative Commons Corporation (“Creative Commons”) is not a law firm and does not provide legal services or legal advice. Distribution of Creative Commons public licenses does not create a lawyer-client or other relationship. Creative Commons makes its licenses and related information available on an “as-is” basis. Creative Commons gives no warranties regarding its licenses, any material licensed under their terms and conditions, or any related information. Creative Commons disclaims all liability for damages resulting from their use to the fullest extent possible.* 4 | > 5 | > ### Using Creative Commons Public Licenses 6 | > 7 | > Creative Commons public licenses provide a standard set of terms and conditions that creators and other rights holders may use to share original works of authorship and other material subject to copyright and certain other rights specified in the public license below. The following considerations are for informational purposes only, are not exhaustive, and do not form part of our licenses. 8 | > 9 | > * __Considerations for licensors:__ Our public licenses are intended for use by those authorized to give the public permission to use material in ways otherwise restricted by copyright and certain other rights. Our licenses are irrevocable. Licensors should read and understand the terms and conditions of the license they choose before applying it. Licensors should also secure all rights necessary before applying our licenses so that the public can reuse the material as expected. Licensors should clearly mark any material not subject to the license. This includes other CC-licensed material, or material used under an exception or limitation to copyright. 
[More considerations for licensors](http://wiki.creativecommons.org/Considerations_for_licensors_and_licensees#Considerations_for_licensors). 10 | > 11 | > * __Considerations for the public:__ By using one of our public licenses, a licensor grants the public permission to use the licensed material under specified terms and conditions. If the licensor’s permission is not necessary for any reason–for example, because of any applicable exception or limitation to copyright–then that use is not regulated by the license. Our licenses grant only permissions under copyright and certain other rights that a licensor has authority to grant. Use of the licensed material may still be restricted for other reasons, including because others have copyright or other rights in the material. A licensor may make special requests, such as asking that all changes be marked or described. Although not required by our licenses, you are encouraged to respect those requests where reasonable. [More considerations for the public](http://wiki.creativecommons.org/Considerations_for_licensors_and_licensees#Considerations_for_licensees). 12 | 13 | ## Creative Commons Attribution-NonCommercial 4.0 International Public License 14 | 15 | By exercising the Licensed Rights (defined below), You accept and agree to be bound by the terms and conditions of this Creative Commons Attribution-NonCommercial 4.0 International Public License ("Public License"). To the extent this Public License may be interpreted as a contract, You are granted the Licensed Rights in consideration of Your acceptance of these terms and conditions, and the Licensor grants You such rights in consideration of benefits the Licensor receives from making the Licensed Material available under these terms and conditions. 16 | 17 | ### Section 1 – Definitions. 18 | 19 | a. __Adapted Material__ means material subject to Copyright and Similar Rights that is derived from or based upon the Licensed Material and in which the Licensed Material is translated, altered, arranged, transformed, or otherwise modified in a manner requiring permission under the Copyright and Similar Rights held by the Licensor. For purposes of this Public License, where the Licensed Material is a musical work, performance, or sound recording, Adapted Material is always produced where the Licensed Material is synched in timed relation with a moving image. 20 | 21 | b. __Adapter's License__ means the license You apply to Your Copyright and Similar Rights in Your contributions to Adapted Material in accordance with the terms and conditions of this Public License. 22 | 23 | c. __Copyright and Similar Rights__ means copyright and/or similar rights closely related to copyright including, without limitation, performance, broadcast, sound recording, and Sui Generis Database Rights, without regard to how the rights are labeled or categorized. For purposes of this Public License, the rights specified in Section 2(b)(1)-(2) are not Copyright and Similar Rights. 24 | 25 | d. __Effective Technological Measures__ means those measures that, in the absence of proper authority, may not be circumvented under laws fulfilling obligations under Article 11 of the WIPO Copyright Treaty adopted on December 20, 1996, and/or similar international agreements. 26 | 27 | e. __Exceptions and Limitations__ means fair use, fair dealing, and/or any other exception or limitation to Copyright and Similar Rights that applies to Your use of the Licensed Material. 28 | 29 | f. 
__Licensed Material__ means the artistic or literary work, database, or other material to which the Licensor applied this Public License. 30 | 31 | g. __Licensed Rights__ means the rights granted to You subject to the terms and conditions of this Public License, which are limited to all Copyright and Similar Rights that apply to Your use of the Licensed Material and that the Licensor has authority to license. 32 | 33 | h. __Licensor__ means the individual(s) or entity(ies) granting rights under this Public License. 34 | 35 | i. __NonCommercial__ means not primarily intended for or directed towards commercial advantage or monetary compensation. For purposes of this Public License, the exchange of the Licensed Material for other material subject to Copyright and Similar Rights by digital file-sharing or similar means is NonCommercial provided there is no payment of monetary compensation in connection with the exchange. 36 | 37 | j. __Share__ means to provide material to the public by any means or process that requires permission under the Licensed Rights, such as reproduction, public display, public performance, distribution, dissemination, communication, or importation, and to make material available to the public including in ways that members of the public may access the material from a place and at a time individually chosen by them. 38 | 39 | k. __Sui Generis Database Rights__ means rights other than copyright resulting from Directive 96/9/EC of the European Parliament and of the Council of 11 March 1996 on the legal protection of databases, as amended and/or succeeded, as well as other essentially equivalent rights anywhere in the world. 40 | 41 | l. __You__ means the individual or entity exercising the Licensed Rights under this Public License. Your has a corresponding meaning. 42 | 43 | ### Section 2 – Scope. 44 | 45 | a. ___License grant.___ 46 | 47 | 1. Subject to the terms and conditions of this Public License, the Licensor hereby grants You a worldwide, royalty-free, non-sublicensable, non-exclusive, irrevocable license to exercise the Licensed Rights in the Licensed Material to: 48 | 49 | A. reproduce and Share the Licensed Material, in whole or in part, for NonCommercial purposes only; and 50 | 51 | B. produce, reproduce, and Share Adapted Material for NonCommercial purposes only. 52 | 53 | 2. __Exceptions and Limitations.__ For the avoidance of doubt, where Exceptions and Limitations apply to Your use, this Public License does not apply, and You do not need to comply with its terms and conditions. 54 | 55 | 3. __Term.__ The term of this Public License is specified in Section 6(a). 56 | 57 | 4. __Media and formats; technical modifications allowed.__ The Licensor authorizes You to exercise the Licensed Rights in all media and formats whether now known or hereafter created, and to make technical modifications necessary to do so. The Licensor waives and/or agrees not to assert any right or authority to forbid You from making technical modifications necessary to exercise the Licensed Rights, including technical modifications necessary to circumvent Effective Technological Measures. For purposes of this Public License, simply making modifications authorized by this Section 2(a)(4) never produces Adapted Material. 58 | 59 | 5. __Downstream recipients.__ 60 | 61 | A. 
__Offer from the Licensor – Licensed Material.__ Every recipient of the Licensed Material automatically receives an offer from the Licensor to exercise the Licensed Rights under the terms and conditions of this Public License. 62 | 63 | B. __No downstream restrictions.__ You may not offer or impose any additional or different terms or conditions on, or apply any Effective Technological Measures to, the Licensed Material if doing so restricts exercise of the Licensed Rights by any recipient of the Licensed Material. 64 | 65 | 6. __No endorsement.__ Nothing in this Public License constitutes or may be construed as permission to assert or imply that You are, or that Your use of the Licensed Material is, connected with, or sponsored, endorsed, or granted official status by, the Licensor or others designated to receive attribution as provided in Section 3(a)(1)(A)(i). 66 | 67 | b. ___Other rights.___ 68 | 69 | 1. Moral rights, such as the right of integrity, are not licensed under this Public License, nor are publicity, privacy, and/or other similar personality rights; however, to the extent possible, the Licensor waives and/or agrees not to assert any such rights held by the Licensor to the limited extent necessary to allow You to exercise the Licensed Rights, but not otherwise. 70 | 71 | 2. Patent and trademark rights are not licensed under this Public License. 72 | 73 | 3. To the extent possible, the Licensor waives any right to collect royalties from You for the exercise of the Licensed Rights, whether directly or through a collecting society under any voluntary or waivable statutory or compulsory licensing scheme. In all other cases the Licensor expressly reserves any right to collect such royalties, including when the Licensed Material is used other than for NonCommercial purposes. 74 | 75 | ### Section 3 – License Conditions. 76 | 77 | Your exercise of the Licensed Rights is expressly made subject to the following conditions. 78 | 79 | a. ___Attribution.___ 80 | 81 | 1. If You Share the Licensed Material (including in modified form), You must: 82 | 83 | A. retain the following if it is supplied by the Licensor with the Licensed Material: 84 | 85 | i. identification of the creator(s) of the Licensed Material and any others designated to receive attribution, in any reasonable manner requested by the Licensor (including by pseudonym if designated); 86 | 87 | ii. a copyright notice; 88 | 89 | iii. a notice that refers to this Public License; 90 | 91 | iv. a notice that refers to the disclaimer of warranties; 92 | 93 | v. a URI or hyperlink to the Licensed Material to the extent reasonably practicable; 94 | 95 | B. indicate if You modified the Licensed Material and retain an indication of any previous modifications; and 96 | 97 | C. indicate the Licensed Material is licensed under this Public License, and include the text of, or the URI or hyperlink to, this Public License. 98 | 99 | 2. You may satisfy the conditions in Section 3(a)(1) in any reasonable manner based on the medium, means, and context in which You Share the Licensed Material. For example, it may be reasonable to satisfy the conditions by providing a URI or hyperlink to a resource that includes the required information. 100 | 101 | 3. If requested by the Licensor, You must remove any of the information required by Section 3(a)(1)(A) to the extent reasonably practicable. 102 | 103 | 4. 
If You Share Adapted Material You produce, the Adapter's License You apply must not prevent recipients of the Adapted Material from complying with this Public License. 104 | 105 | ### Section 4 – Sui Generis Database Rights. 106 | 107 | Where the Licensed Rights include Sui Generis Database Rights that apply to Your use of the Licensed Material: 108 | 109 | a. for the avoidance of doubt, Section 2(a)(1) grants You the right to extract, reuse, reproduce, and Share all or a substantial portion of the contents of the database for NonCommercial purposes only; 110 | 111 | b. if You include all or a substantial portion of the database contents in a database in which You have Sui Generis Database Rights, then the database in which You have Sui Generis Database Rights (but not its individual contents) is Adapted Material; and 112 | 113 | c. You must comply with the conditions in Section 3(a) if You Share all or a substantial portion of the contents of the database. 114 | 115 | For the avoidance of doubt, this Section 4 supplements and does not replace Your obligations under this Public License where the Licensed Rights include other Copyright and Similar Rights. 116 | 117 | ### Section 5 – Disclaimer of Warranties and Limitation of Liability. 118 | 119 | a. __Unless otherwise separately undertaken by the Licensor, to the extent possible, the Licensor offers the Licensed Material as-is and as-available, and makes no representations or warranties of any kind concerning the Licensed Material, whether express, implied, statutory, or other. This includes, without limitation, warranties of title, merchantability, fitness for a particular purpose, non-infringement, absence of latent or other defects, accuracy, or the presence or absence of errors, whether or not known or discoverable. Where disclaimers of warranties are not allowed in full or in part, this disclaimer may not apply to You.__ 120 | 121 | b. __To the extent possible, in no event will the Licensor be liable to You on any legal theory (including, without limitation, negligence) or otherwise for any direct, special, indirect, incidental, consequential, punitive, exemplary, or other losses, costs, expenses, or damages arising out of this Public License or use of the Licensed Material, even if the Licensor has been advised of the possibility of such losses, costs, expenses, or damages. Where a limitation of liability is not allowed in full or in part, this limitation may not apply to You.__ 122 | 123 | c. The disclaimer of warranties and limitation of liability provided above shall be interpreted in a manner that, to the extent possible, most closely approximates an absolute disclaimer and waiver of all liability. 124 | 125 | ### Section 6 – Term and Termination. 126 | 127 | a. This Public License applies for the term of the Copyright and Similar Rights licensed here. However, if You fail to comply with this Public License, then Your rights under this Public License terminate automatically. 128 | 129 | b. Where Your right to use the Licensed Material has terminated under Section 6(a), it reinstates: 130 | 131 | 1. automatically as of the date the violation is cured, provided it is cured within 30 days of Your discovery of the violation; or 132 | 133 | 2. upon express reinstatement by the Licensor. 134 | 135 | For the avoidance of doubt, this Section 6(b) does not affect any right the Licensor may have to seek remedies for Your violations of this Public License. 136 | 137 | c. 
For the avoidance of doubt, the Licensor may also offer the Licensed Material under separate terms or conditions or stop distributing the Licensed Material at any time; however, doing so will not terminate this Public License. 138 | 139 | d. Sections 1, 5, 6, 7, and 8 survive termination of this Public License. 140 | 141 | ### Section 7 – Other Terms and Conditions. 142 | 143 | a. The Licensor shall not be bound by any additional or different terms or conditions communicated by You unless expressly agreed. 144 | 145 | b. Any arrangements, understandings, or agreements regarding the Licensed Material not stated herein are separate from and independent of the terms and conditions of this Public License. 146 | 147 | ### Section 8 – Interpretation. 148 | 149 | a. For the avoidance of doubt, this Public License does not, and shall not be interpreted to, reduce, limit, restrict, or impose conditions on any use of the Licensed Material that could lawfully be made without permission under this Public License. 150 | 151 | b. To the extent possible, if any provision of this Public License is deemed unenforceable, it shall be automatically reformed to the minimum extent necessary to make it enforceable. If the provision cannot be reformed, it shall be severed from this Public License without affecting the enforceability of the remaining terms and conditions. 152 | 153 | c. No term or condition of this Public License will be waived and no failure to comply consented to unless expressly agreed to by the Licensor. 154 | 155 | d. Nothing in this Public License constitutes or may be interpreted as a limitation upon, or waiver of, any privileges and immunities that apply to the Licensor or You, including from the legal processes of any jurisdiction or authority. 156 | 157 | > Creative Commons is not a party to its public licenses. Notwithstanding, Creative Commons may elect to apply one of its public licenses to material it publishes and in those instances will be considered the “Licensor.” Except for the limited purpose of indicating that material is shared under a Creative Commons public license or as otherwise permitted by the Creative Commons policies published at [creativecommons.org/policies](http://creativecommons.org/policies), Creative Commons does not authorize the use of the trademark “Creative Commons” or any other trademark or logo of Creative Commons without its prior written consent including, without limitation, in connection with any unauthorized modifications to any of its public licenses or any other arrangements, understandings, or agreements concerning use of licensed material. For the avoidance of doubt, this paragraph does not form part of the public licenses. 158 | > 159 | > Creative Commons may be contacted at creativecommons.org -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # FaceXHuBERT (ICMI '23) 2 | ### Code repository for the paper: 3 | 4 | >_FaceXHuBERT: Text-less Speech-driven E(X)pressive 3D Facial Animation Synthesis using Self-Supervised Speech Representation Learning_. 
5 | 6 | > Authors: Kazi Injamamul Haque, Zerrin Yumak 7 | 8 | > [[Paper]](https://dl.acm.org/doi/pdf/10.1145/3577190.3614157) [[Project Page]](https://galib360.github.io/FaceXHuBERT/) [[Video]](https://www.youtube.com/watch?v=AkBhnNOxwE4&ab_channel=KaziInjamamulHaque) 9 | 10 | 11 | > This GitHub repository contains the PyTorch implementation of the work presented in the paper mentioned above. Given a raw audio file, FaceXHuBERT generates and renders expressive 3D facial animation. We recommend visiting the project website and watching the supplementary video. 12 | 13 |

14 | 15 |

16 | 17 | ## Environment 18 | 19 | - Windows (tested on Windows 10 and 11) 20 | - Python 3.8 21 | - PyTorch 1.10.1+cu113 22 | 23 | ## Dependencies 24 | 25 | - Check the required Python packages and libraries in `environment.yml`. 26 | - [ffmpeg](https://ffmpeg.org/download.html) for Windows. [WikiHow link on how to install](https://www.wikihow.com/Install-FFmpeg-on-Windows) 27 | 28 | ## Get Started 29 | 30 | It is recommended to create a new Anaconda environment with Python 3.8. To do so, please follow the steps below sequentially: 31 | - Ensure that the [CUDA](https://developer.nvidia.com/cuda-11-5-0-download-archive?target_os=Windows&target_arch=x86_64&target_version=10&target_type=exe_local) computing toolkit with the appropriate [cudnn](https://developer.nvidia.com/rdp/cudnn-archive) (tested with CUDA 11.5) is properly installed on the system and the environment variable "CUDA_PATH" is set correctly. First install CUDA, then download, extract, and copy the cudnn contents into the CUDA installation folder (e.g. C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA\v11.5). 32 | - Install [Anaconda](https://www.anaconda.com/products/distribution) for Windows. 33 | - Install ffmpeg. [WikiHow link on how to install ffmpeg](https://www.wikihow.com/Install-FFmpeg-on-Windows) 34 | - Clone this repository. 35 | - Open the Anaconda Prompt CLI. 36 | ``` 37 | cd 38 | ``` 39 | - Then run the following commands in the Anaconda Prompt 40 | 41 | ``` 42 | conda env create --name FaceXHuBERT python=3.8 --file="environment.yml" 43 | conda activate FaceXHuBERT 44 | pip install torch==1.10.1+cu113 torchvision==0.11.2+cu113 torchaudio==0.10.1+cu113 -f https://download.pytorch.org/whl/torch_stable.html 45 | 46 | ``` 47 | - Please make sure you run all the Python scripts below from the activated virtual environment's command line (in other words, your Python interpreter is the one from the FaceXHuBERT environment you just created). 48 | 49 | 50 | ## Demo 51 | 52 | Download the pretrained model from [FaceXHuBERT model](https://mega.nz/file/L4BzEATa#HZ_BuV56yI4yQLQMhiml5rOLAMcxgjCEwAgcITD_09g). Put the pretrained model under the `pretrained_model/` folder. 53 | 54 | - Given a raw audio file (wav_path), animate the mesh and render the video by running the following command: 55 | ``` 56 | python predict.py --subject F1 --condition F3 --wav_path "demo/wav/test.wav" --emotion 1 57 | ``` 58 | 59 | Running `predict.py` will generate the rendered videos in the `demo/render/video_with_audio/` folder. The prediction data will be saved in the `demo/result/` folder. Try it with your own audio files (in .wav format) and with other subjects and conditions (i.e. F1, F2, F3, F4, F5, F6, F7, F8, M1, M2, M3, M4, M5, M6). 60 | 61 | ## Data 62 | ### BIWI 63 | 64 | The [Biwi 3D Audiovisual Corpus of Affective Communication](https://data.vision.ee.ethz.ch/cvl/datasets/b3dac2.en.html) dataset is available upon request for research or academic purposes.
You will need the following files from the dataset: 65 | 66 | - faces01.tgz, faces02.tgz, faces03.tgz, faces04.tgz, faces05.tgz and rest.tgz 67 | - Place all the faces0*.tgz archives in the `BIWI/ForProcessing/FaceData/` folder 68 | - Place the rest.tgz archive in the `BIWI/ForProcessing/rest/` folder 69 | 70 | 71 | #### Data Preparation and Pre-processing 72 | Follow the steps below sequentially as they appear: 73 | 74 | - You will need [Matlab](https://mathworks.com/products/matlab.html) installed on your machine to prepare the data for pre-processing 75 | - Open the Anaconda Prompt CLI and activate the FaceXHuBERT env in the directory `BIWI/ForProcessing/rest/` 76 | - Run the following command 77 | ``` 78 | tar -xvzf rest.tgz 79 | ``` 80 | - After extracting, you will see the `audio/` folder that contains the input audio files needed for network training in .wav format 81 | - Run the `wav_process.py` script. This will process the `audio/` folder and copy the needed audio sequences with proper names to the `FaceXHuBERT/BIWI/wav/` folder for training 82 | ``` 83 | python wav_process.py 84 | ``` 85 | - Open the Anaconda Prompt CLI and activate the FaceXHuBERT env in the directory `BIWI/ForProcessing/FaceData/` 86 | - Run the following command to extract all the archives. Replace `*` with 1-5 (for the five archives) 87 | ``` 88 | tar -xvzf faces0*.tgz 89 | ``` 90 | - After extracting, you will see a folder named `faces/`. Move all the .obj files from this folder (i.e. F1.obj-M6.obj) to the `FaceXHuBERT/BIWI/templates/` folder 91 | - Run the shell script `Extract_all.sh`. This will extract all the archives for all subjects and all sequences. You will have frame-by-frame vertex data in the `frame_*.vl` binary file format 92 | - Run the Matlab script `vl2csv_recursive.m`. This will convert all the `.vl` files into `.csv` files 93 | - Run the `vertex_process.py` script. This will process the data and place the processed data in the `FaceXHuBERT/BIWI/vertices_npy/` folder for network training 94 | ``` 95 | python vertex_process.py 96 | ``` 97 | 98 | 99 | ## Model Training 100 | 101 | ### Training and Testing 102 | 103 | - Train the model by running the following command: 104 | 105 | ``` 106 | python main.py 107 | ``` 108 | The test split predicted results will be saved in the `result/` folder. The trained models (saved at 25-epoch intervals) will be saved in the `save/` folder. 109 | 110 | ### Visualization 111 | 112 | - Run the following command to render the predicted test sequences stored in `result/`: 113 | 114 | ``` 115 | python render_result.py 116 | ``` 117 | The rendered videos will be saved in the `renders/render_folder/videos_with_audio/` folder. 118 | 119 | ### Evaluation 120 | 121 | - Put the ground truth test split sequences (.npy files) in the `Evaluation/GroundTruth/` folder. (If you train the model using our dataset split, then the test set is sequences 39, 40, 79, 80 for all 14 subjects.) 122 | - Run the following command to run the quantitative evaluation and render the evaluation results: 123 | 124 | ``` 125 | cd Evaluation 126 | python render_quant_evaluation.py 127 | ``` 128 | - The rendered videos will be saved in the `Evaluation/renders/videos_with_audio/` folder.
129 | - The computed Mean Face Vertex Error will be saved in `Evaluation/quantitative_metric.txt` 130 | 131 | 132 | ## Citation 133 | 134 | If you find this code useful for your work, please be kind to consider citing our paper: 135 | ``` 136 | @inproceedings{FaceXHuBERT_Haque_ICMI23, 137 | author = {Haque, Kazi Injamamul and Yumak, Zerrin}, 138 | title = {FaceXHuBERT: Text-less Speech-driven E(X)pressive 3D Facial Animation Synthesis Using Self-Supervised Speech Representation Learning}, 139 | booktitle = {INTERNATIONAL CONFERENCE ON MULTIMODAL INTERACTION (ICMI ’23)}, 140 | year = {2023}, 141 | location = {Paris, France}, 142 | numpages = {10}, 143 | url = {https://doi.org/10.1145/3577190.3614157}, 144 | doi = {10.1145/3577190.3614157}, 145 | publisher = {ACM}, 146 | address = {New York, NY, USA}, 147 | } 148 | ``` 149 | 150 | ## Acknowledgement 151 | We would like to thank the authors of FaceFormer for making their code available. Thanks to ETH Zurich CVL for providing us access to the _Biwi 3D Audiovisual Corpus_. The HuBERT implementation is borrowed from [Hugging Face](https://huggingface.co/). 152 | 153 | ## License 154 | This repository is released under [CC-BY-NC-4.0-International License](https://github.com/Gibberlings3/GitHub-Templates/blob/master/License-Templates/CC-BY-NC-4.0/LICENSE-CC-BY-NC-4.0.md) -------------------------------------------------------------------------------- /data_loader.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | from collections import defaultdict 4 | from torch.utils import data 5 | import numpy as np 6 | import pickle 7 | from tqdm import tqdm 8 | from transformers import Wav2Vec2Processor 9 | import librosa 10 | 11 | 12 | class Dataset(data.Dataset): 13 | """Custom data.Dataset compatible with data.DataLoader.""" 14 | def __init__(self, data, subjects_dict, data_type="train"): 15 | self.data = data 16 | self.len = len(self.data) 17 | self.subjects_dict = subjects_dict 18 | self.data_type = data_type 19 | self.one_hot_labels = np.eye(len(subjects_dict["train"])) 20 | self.emo_one_hot_labels = np.eye(2) 21 | 22 | def __getitem__(self, index): 23 | """Returns one data pair (source and target).""" 24 | file_name = self.data[index]["name"] 25 | audio = self.data[index]["audio"] 26 | vertice = self.data[index]["vertice"] 27 | template = self.data[index]["template"] 28 | sentence_id = int(file_name.split(".")[0].split("_")[1]) 29 | if self.data_type == "train": 30 | subject = "_".join(file_name.split("_")[:-1]) 31 | one_hot = self.one_hot_labels[self.subjects_dict["train"].index(subject)] 32 | else: 33 | one_hot = self.one_hot_labels 34 | if sentence_id > 40: 35 | emo_one_hot = self.emo_one_hot_labels[1] 36 | else: 37 | emo_one_hot = self.emo_one_hot_labels[0] 38 | 39 | return torch.FloatTensor(audio), vertice, torch.FloatTensor(template), torch.FloatTensor(one_hot), file_name, torch.FloatTensor(emo_one_hot) 40 | 41 | def __len__(self): 42 | return self.len 43 | 44 | 45 | def read_data(args): 46 | print("Loading data...") 47 | data = defaultdict(dict) 48 | train_data = [] 49 | valid_data = [] 50 | test_data = [] 51 | 52 | audio_path = os.path.join(args.dataset, args.wav_path) 53 | vertices_path = os.path.join(args.dataset, args.vertices_path) 54 | processor = Wav2Vec2Processor.from_pretrained("facebook/hubert-xlarge-ls960-ft") # HuBERT uses the processor of Wav2Vec 2.0 55 | 56 | template_file = os.path.join(args.dataset, args.template_file) 57 | with open(template_file, 'rb') as 
fin: 58 | templates = pickle.load(fin,encoding='latin1') 59 | 60 | for r, ds, fs in os.walk(audio_path): 61 | for f in tqdm(fs): 62 | if f.endswith("wav"): 63 | wav_path = os.path.join(r,f) 64 | speech_array, sampling_rate = librosa.load(wav_path, sr=16000) 65 | input_values = processor(speech_array, return_tensors="pt", padding="longest", sampling_rate=sampling_rate).input_values 66 | key = f.replace("wav", "npy") 67 | data[key]["audio"] = input_values 68 | subject_id = "_".join(key.split("_")[:-1]) 69 | temp = templates[subject_id] 70 | data[key]["name"] = f 71 | data[key]["template"] = temp.reshape((-1)) 72 | vertice_path = os.path.join(vertices_path,f.replace("wav", "npy")) 73 | if not os.path.exists(vertice_path): 74 | del data[key] 75 | # print("Vertices Data Not Found! ", vertice_path) 76 | else: 77 | data[key]["vertice"] = vertice_path 78 | 79 | subjects_dict = {} 80 | subjects_dict["train"] = [i for i in args.train_subjects.split(" ")] 81 | subjects_dict["val"] = [i for i in args.val_subjects.split(" ")] 82 | subjects_dict["test"] = [i for i in args.test_subjects.split(" ")] 83 | 84 | splits = {'BIWI': {'train': list(range(1, 37)) + list(range(41, 77)), 'val': list(range(37, 39)) + list(range(77, 79)), 'test': list(range(39, 41)) + list(range(79, 81))}} 85 | 86 | for k, v in data.items(): 87 | subject_id = "_".join(k.split("_")[:-1]) 88 | sentence_id = int(k.split(".")[0][-2:]) 89 | if subject_id in subjects_dict["train"] and sentence_id in splits[args.dataset]['train']: 90 | train_data.append(v) 91 | if subject_id in subjects_dict["val"] and sentence_id in splits[args.dataset]['val']: 92 | valid_data.append(v) 93 | if subject_id in subjects_dict["test"] and sentence_id in splits[args.dataset]['test']: 94 | test_data.append(v) 95 | 96 | print(len(train_data), len(valid_data), len(test_data)) 97 | return train_data, valid_data, test_data, subjects_dict 98 | 99 | 100 | def get_dataloaders(args): 101 | dataset = {} 102 | train_data, valid_data, test_data, subjects_dict = read_data(args) 103 | train_data = Dataset(train_data, subjects_dict,"train") 104 | dataset["train"] = data.DataLoader(dataset=train_data, batch_size=1, shuffle=True) 105 | valid_data = Dataset(valid_data, subjects_dict,"val") 106 | dataset["valid"] = data.DataLoader(dataset=valid_data, batch_size=1, shuffle=False) 107 | test_data = Dataset(test_data, subjects_dict,"test") 108 | dataset["test"] = data.DataLoader(dataset=test_data, batch_size=1, shuffle=False) 109 | return dataset 110 | 111 | 112 | if __name__ == "__main__": 113 | get_dataloaders() 114 | 115 | -------------------------------------------------------------------------------- /demo/render/frames/README.md: -------------------------------------------------------------------------------- 1 | This folder will contain the temporary frames of the renders. -------------------------------------------------------------------------------- /demo/render/video_with_audio/README.md: -------------------------------------------------------------------------------- 1 | This folder will contain rendered video with audio. -------------------------------------------------------------------------------- /demo/render/video_wo_audio/README.md: -------------------------------------------------------------------------------- 1 | This folder will contain rendered video without audio. 
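Editor's note (not a file of this repository): the `demo/result/` and `BIWI/vertices_npy/` folders store facial animation as `.npy` arrays of shape (num_frames, 70110), i.e. 23370 vertices with 3 coordinates per frame, which is the layout assumed throughout `vertex_process.py`, `data_loader.py`, `faceXhubert.py` and the rendering scripts. Below is a minimal illustrative sketch of how such a sequence can be loaded and inspected; the file name shown is a hypothetical example.

```
import numpy as np

# Hypothetical example path; any output of predict.py (demo/result/) or of
# vertex_process.py (BIWI/vertices_npy/) follows the same array layout.
seq_path = "demo/result/test_condition_F3.npy"

seq = np.load(seq_path)                    # shape: (num_frames, 70110)
frames = seq.reshape(seq.shape[0], -1, 3)  # (num_frames, 23370 vertices, xyz)

print("number of frames:", frames.shape[0])
print("vertices per frame:", frames.shape[1])
print("first-frame centroid:", frames[0].mean(axis=0))
```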
-------------------------------------------------------------------------------- /demo/result/README.md: -------------------------------------------------------------------------------- 1 | This folder will contain generated vertex data result in `.npy` format. -------------------------------------------------------------------------------- /demo/wav/README.md: -------------------------------------------------------------------------------- 1 | Put your own audio files in .wav format in this folder to run the demo script `predict.py`. -------------------------------------------------------------------------------- /demo/wav/test.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/galib360/FaceXHuBERT/f54f9a99282b6a3b0b99770cbc50cb7ea7f3746b/demo/wav/test.wav -------------------------------------------------------------------------------- /environment.yml: -------------------------------------------------------------------------------- 1 | name: FaceXHuBERT 2 | channels: 3 | - defaults 4 | - conda-forge 5 | dependencies: 6 | - argon2-cffi=20.1.0=py38h2bbff1b_1 7 | - async_generator=1.10=pyhd3eb1b0_0 8 | - attrs=21.2.0=pyhd3eb1b0_0 9 | - backcall=0.2.0=pyhd3eb1b0_0 10 | - bleach=4.0.0=pyhd3eb1b0_0 11 | - ca-certificates=2021.10.26=haa95532_2 12 | - certifi=2021.10.8=py38haa95532_2 13 | - cffi=1.15.0=py38h2bbff1b_0 14 | - colorama=0.4.4=pyhd3eb1b0_0 15 | - console_shortcut=0.1.1=4 16 | - debugpy=1.5.1=py38hd77b12b_0 17 | - defusedxml=0.7.1=pyhd3eb1b0_0 18 | - entrypoints=0.3=py38_0 19 | - importlib-metadata=4.8.2=py38haa95532_0 20 | - importlib_metadata=4.8.2=hd3eb1b0_0 21 | - ipykernel=6.4.1=py38haa95532_1 22 | - ipython=7.29.0=py38hd4e2768_0 23 | - ipython_genutils=0.2.0=pyhd3eb1b0_1 24 | - jedi=0.18.0=py38haa95532_1 25 | - jsonschema=3.2.0=pyhd3eb1b0_2 26 | - jupyter_client=7.1.0=pyhd3eb1b0_0 27 | - jupyter_core=4.9.1=py38haa95532_0 28 | - jupyterlab_pygments=0.1.2=py_0 29 | - m2w64-gcc-libgfortran=5.3.0=6 30 | - m2w64-gcc-libs=5.3.0=7 31 | - m2w64-gcc-libs-core=5.3.0=7 32 | - m2w64-gmp=6.1.0=2 33 | - m2w64-libwinpthread-git=5.0.0.4634.697f757=2 34 | - markupsafe=2.0.1=py38h2bbff1b_0 35 | - matplotlib-inline=0.1.2=pyhd3eb1b0_2 36 | - mistune=0.8.4=py38he774522_1000 37 | - msys2-conda-epoch=20160418=1 38 | - nbclient=0.5.3=pyhd3eb1b0_0 39 | - nest-asyncio=1.5.1=pyhd3eb1b0_0 40 | - notebook=6.4.6=py38haa95532_0 41 | - openssl=1.1.1m=h2bbff1b_0 42 | - packaging=21.3=pyhd3eb1b0_0 43 | - pandocfilters=1.4.3=py38haa95532_1 44 | - parso=0.8.3=pyhd3eb1b0_0 45 | - pickleshare=0.7.5=pyhd3eb1b0_1003 46 | - pip=21.2.2=py38haa95532_0 47 | - prometheus_client=0.12.0=pyhd3eb1b0_0 48 | - prompt-toolkit=3.0.20=pyhd3eb1b0_0 49 | - pycparser=2.21=pyhd3eb1b0_0 50 | - pygments=2.10.0=pyhd3eb1b0_0 51 | - pyparsing=3.0.4=pyhd3eb1b0_0 52 | - pyrsistent=0.18.0=py38h196d8e1_0 53 | - python=3.8.12=h6244533_0 54 | - python-dateutil=2.8.2=pyhd3eb1b0_0 55 | - pywin32=302=py38h827c3e9_1 56 | - pywinpty=0.5.7=py38_0 57 | - pyzmq=22.3.0=py38hd77b12b_2 58 | - send2trash=1.8.0=pyhd3eb1b0_1 59 | - sqlite=3.37.0=h2bbff1b_0 60 | - terminado=0.9.4=py38haa95532_0 61 | - testpath=0.5.0=pyhd3eb1b0_0 62 | - tornado=6.1=py38h2bbff1b_0 63 | - traitlets=5.1.1=pyhd3eb1b0_0 64 | - wcwidth=0.2.5=pyhd3eb1b0_0 65 | - webencodings=0.5.1=py38_1 66 | - wheel=0.37.0=pyhd3eb1b0_1 67 | - wincertstore=0.2=py38haa95532_2 68 | - winpty=0.4.3=4 69 | - zipp=3.6.0=pyhd3eb1b0_0 70 | - pip: 71 | - absl-py==0.15.0 72 | - ansiwrap==0.8.4 73 | - anyio==3.6.1 74 | - astunparse==1.6.3 75 | - 
babel==2.10.3 76 | - beautifulsoup4==4.11.1 77 | - biopython==1.79 78 | - cachetools==4.2.4 79 | - click==8.1.1 80 | - configparser==5.2.0 81 | - cycler==0.11.0 82 | - decorator==4.4.2 83 | - deprecation==2.1.0 84 | - et-xmlfile==1.1.0 85 | - fastjsonschema==2.15.3 86 | - ffmpeg-python==0.2.0 87 | - filelock==3.6.0 88 | - fonttools==4.28.5 89 | - freetype-py==2.2.0 90 | - future==0.18.2 91 | - fvcore==0.1.5.post20220305 92 | - google-auth==2.3.3 93 | - google-auth-oauthlib==0.4.6 94 | - google-pasta==0.2.0 95 | - grpcio==1.43.0 96 | - h5py==2.10.0 97 | - huggingface-hub==0.0.8 98 | - hyperpyyaml==1.0.1 99 | - imageio==2.16.1 100 | - imageio-ffmpeg==0.4.7 101 | - iopath==0.1.9 102 | - ipywidgets==7.7.0 103 | - jinja2==3.1.2 104 | - json5==0.9.8 105 | - jsonlines==3.0.0 106 | - jupyter-packaging==0.12.2 107 | - jupyter-server==1.17.1 108 | - jupyterlab==3.4.3 109 | - jupyterlab-server==2.14.0 110 | - jupyterlab-widgets==1.1.0 111 | - keras-preprocessing==1.1.2 112 | - kiwisolver==1.3.2 113 | - librosa==0.8.1 114 | - markdown==3.3.6 115 | - matplotlib==3.5.1 116 | - nbclassic==0.3.7 117 | - nbconvert==6.5.0 118 | - nbformat==5.4.0 119 | - networkx==2.7.1 120 | - notebook-shim==0.1.0 121 | - numpy==1.20.0 122 | - oauthlib==3.1.1 123 | - opencv-python==4.5.5.62 124 | - openpyxl==3.0.9 125 | - opt-einsum==3.3.0 126 | - pandas==1.3.5 127 | - pandas-stubs==1.2.0.57 128 | - pillow==9.0.0 129 | - portalocker==2.4.0 130 | - praatio==5.1.1 131 | - proglog==0.1.10 132 | - protobuf==3.19.3 133 | - pyasn1==0.4.8 134 | - pyasn1-modules==0.2.8 135 | - pyglet==1.5.26 136 | - pymeshlab==0.2.1 137 | - pymysql==1.0.2 138 | - pyopengl==3.1.0 139 | - pyrender==0.1.45 140 | - pytz==2021.3 141 | - pyyaml==6.0 142 | - regex==2022.3.15 143 | - requests==2.28.0 144 | - requests-oauthlib==1.3.0 145 | - rsa==4.8 146 | - ruamel-yaml==0.17.21 147 | - ruamel-yaml-clib==0.2.6 148 | - sacremoses==0.0.49 149 | - sentencepiece==0.1.96 150 | - setuptools==62.6.0 151 | - shutils==0.1.0 152 | - six==1.15.0 153 | - sniffio==1.2.0 154 | - soupsieve==2.3.2.post1 155 | - tabulate==0.8.9 156 | - termcolor==1.1.0 157 | - textwrap3==0.9.2 158 | - tinycss2==1.1.1 159 | - tokenizers==0.10.3 160 | - tomlkit==0.11.0 161 | - tqdm==4.63.0 162 | - transformers==4.7.0 163 | - trimesh==3.12.6 164 | - typing-extensions==4.0.1 165 | - urllib3==1.26.8 166 | - websocket-client==1.3.3 167 | - werkzeug==2.0.2 168 | - widgetsnbextension==3.6.0 169 | - yacs==0.1.8 170 | 171 | -------------------------------------------------------------------------------- /faceXhubert.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from hubert.modeling_hubert import HubertModel 4 | import torch.nn.functional as F 5 | 6 | 7 | def inputRepresentationAdjustment(audio_embedding_matrix, vertex_matrix, ifps, ofps): 8 | if ifps % ofps == 0: 9 | factor = -1 * (-ifps // ofps) 10 | if audio_embedding_matrix.shape[1] % 2 != 0: 11 | audio_embedding_matrix = audio_embedding_matrix[:, :audio_embedding_matrix.shape[1] - 1] 12 | 13 | if audio_embedding_matrix.shape[1] > vertex_matrix.shape[1] * 2: 14 | audio_embedding_matrix = audio_embedding_matrix[:, :vertex_matrix.shape[1] * 2] 15 | 16 | elif audio_embedding_matrix.shape[1] < vertex_matrix.shape[1] * 2: 17 | vertex_matrix = vertex_matrix[:, :audio_embedding_matrix.shape[1] // 2] 18 | else: 19 | factor = -1 * (-ifps // ofps) 20 | audio_embedding_seq_len = vertex_matrix.shape[1] * factor 21 | audio_embedding_matrix = 
audio_embedding_matrix.transpose(1, 2) 22 | audio_embedding_matrix = F.interpolate(audio_embedding_matrix, size=audio_embedding_seq_len, align_corners=True, mode='linear') 23 | audio_embedding_matrix = audio_embedding_matrix.transpose(1, 2) 24 | 25 | frame_num = vertex_matrix.shape[1] 26 | audio_embedding_matrix = torch.reshape(audio_embedding_matrix, (1, audio_embedding_matrix.shape[1] // factor, audio_embedding_matrix.shape[2] * factor)) 27 | 28 | return audio_embedding_matrix, vertex_matrix, frame_num 29 | 30 | 31 | class FaceXHuBERT(nn.Module): 32 | def __init__(self, args): 33 | super(FaceXHuBERT, self).__init__() 34 | """ 35 | audio: (batch_size, raw_wav) 36 | template: (batch_size, V*3) 37 | vertice: (batch_size, seq_len, V*3) 38 | """ 39 | self.dataset = args.dataset 40 | self.i_fps = args.input_fps # audio fps (input to the network) 41 | self.o_fps = args.output_fps # 4D Scan fps (output or target) 42 | self.gru_layer_dim = 2 43 | self.gru_hidden_dim = args.feature_dim 44 | 45 | # Audio Encoder 46 | self.audio_encoder = HubertModel.from_pretrained("facebook/hubert-base-ls960") 47 | self.audio_dim = self.audio_encoder.encoder.config.hidden_size 48 | self.audio_encoder.feature_extractor._freeze_parameters() 49 | 50 | frozen_layers = [0,1] 51 | 52 | for name, param in self.audio_encoder.named_parameters(): 53 | if name.startswith("feature_projection"): 54 | param.requires_grad = False 55 | if name.startswith("encoder.layers"): 56 | layer = int(name.split(".")[2]) 57 | if layer in frozen_layers: 58 | param.requires_grad = False 59 | 60 | #Vertex Decoder 61 | # GRU module 62 | self.gru = nn.GRU(self.audio_dim*2, args.feature_dim, self.gru_layer_dim, batch_first=True, dropout=0.3) 63 | # Fully connected layer 64 | self.fc = nn.Linear(args.feature_dim, args.vertice_dim) 65 | nn.init.constant_(self.fc.weight, 0) 66 | nn.init.constant_(self.fc.bias, 0) 67 | 68 | # Subject embedding, S 69 | self.obj_vector = nn.Linear(len(args.train_subjects.split()), args.feature_dim, bias=False) 70 | 71 | # Emotion embedding, E 72 | self.emo_vector = nn.Linear(2, args.feature_dim, bias=False) 73 | 74 | def forward(self, audio, template, vertice, one_hot, emo_one_hot, criterion): 75 | 76 | template = template.unsqueeze(1) 77 | obj_embedding = self.obj_vector(one_hot) 78 | emo_embedding = self.emo_vector(emo_one_hot) 79 | 80 | hidden_states = audio 81 | 82 | hidden_states = self.audio_encoder(hidden_states).last_hidden_state 83 | 84 | hidden_states, vertice, frame_num = inputRepresentationAdjustment(hidden_states, vertice, self.i_fps, self.o_fps) 85 | 86 | hidden_states = hidden_states[:, :frame_num] 87 | 88 | h0 = torch.zeros(self.gru_layer_dim, hidden_states.shape[0], self.gru_hidden_dim).requires_grad_().cuda() 89 | 90 | vertice_out, _ = self.gru(hidden_states, h0) 91 | vertice_out = vertice_out * obj_embedding 92 | vertice_out = vertice_out * emo_embedding 93 | 94 | vertice_out = self.fc(vertice_out) 95 | 96 | vertice_out = vertice_out + template 97 | 98 | loss = criterion(vertice_out, vertice) 99 | loss = torch.mean(loss) 100 | return loss 101 | 102 | def predict(self, audio, template, one_hot, emo_one_hot): 103 | template = template.unsqueeze(1) 104 | obj_embedding = self.obj_vector(one_hot) 105 | emo_embedding = self.emo_vector(emo_one_hot) 106 | hidden_states = audio 107 | hidden_states = self.audio_encoder(hidden_states).last_hidden_state 108 | 109 | if hidden_states.shape[1] % 2 != 0: 110 | hidden_states = hidden_states[:, :hidden_states.shape[1]-1] 111 | hidden_states = 
torch.reshape(hidden_states, (1, hidden_states.shape[1] // 2, hidden_states.shape[2] * 2)) 112 | h0 = torch.zeros(self.gru_layer_dim, hidden_states.shape[0], self.gru_hidden_dim).requires_grad_().cuda() 113 | 114 | vertice_out, _ = self.gru(hidden_states, h0) 115 | vertice_out = vertice_out * obj_embedding 116 | vertice_out = vertice_out * emo_embedding 117 | 118 | vertice_out = self.fc(vertice_out) 119 | 120 | vertice_out = vertice_out + template 121 | 122 | return vertice_out 123 | -------------------------------------------------------------------------------- /hubert/activations.py: -------------------------------------------------------------------------------- 1 | # Copyright 2020 The HuggingFace Team. All rights reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | import math 16 | 17 | import torch 18 | from packaging import version 19 | from torch import nn 20 | 21 | from .utils import logging 22 | 23 | 24 | logger = logging.get_logger(__name__) 25 | 26 | 27 | def _gelu_python(x): 28 | """ 29 | Original Implementation of the GELU activation function in Google BERT repo when initially created. For 30 | information: OpenAI GPT's GELU is slightly different (and gives slightly different results): 0.5 * x * (1 + 31 | torch.tanh(math.sqrt(2 / math.pi) * (x + 0.044715 * torch.pow(x, 3)))) This is now written in C in nn.functional 32 | Also see the Gaussian Error Linear Units paper: https://arxiv.org/abs/1606.08415 33 | """ 34 | return x * 0.5 * (1.0 + torch.erf(x / math.sqrt(2.0))) 35 | 36 | 37 | def gelu_new(x): 38 | """ 39 | Implementation of the GELU activation function currently in Google BERT repo (identical to OpenAI GPT). Also see 40 | the Gaussian Error Linear Units paper: https://arxiv.org/abs/1606.08415 41 | """ 42 | return 0.5 * x * (1.0 + torch.tanh(math.sqrt(2.0 / math.pi) * (x + 0.044715 * torch.pow(x, 3.0)))) 43 | 44 | 45 | if version.parse(torch.__version__) < version.parse("1.4"): 46 | gelu = _gelu_python 47 | else: 48 | gelu = nn.functional.gelu 49 | 50 | 51 | def gelu_fast(x): 52 | return 0.5 * x * (1.0 + torch.tanh(x * 0.7978845608 * (1.0 + 0.044715 * x * x))) 53 | 54 | 55 | def quick_gelu(x): 56 | return x * torch.sigmoid(1.702 * x) 57 | 58 | 59 | def _silu_python(x): 60 | """ 61 | See Gaussian Error Linear Units (Hendrycks et al., https://arxiv.org/abs/1606.08415) where the SiLU (Sigmoid Linear 62 | Unit) was originally introduced and coined, and see Sigmoid-Weighted Linear Units for Neural Network Function 63 | Approximation in Reinforcement Learning (Elfwing et al., https://arxiv.org/abs/1702.03118) and Swish: a Self-Gated 64 | Activation Function (Ramachandran et al., https://arxiv.org/abs/1710.05941v1) where the SiLU was experimented with 65 | later. 
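Aside on the FaceXHuBERT.predict() path in faceXhubert.py above: the model pairs consecutive HuBERT frames (roughly 50 per second of 16 kHz audio) so that one feature row lines up with one 25 fps vertex frame. Below is a minimal shape-only sketch of that pairing, using a dummy tensor in place of the real encoder output (the tensor values and sizes are illustrative assumptions, not repo data):

import torch

# dummy stand-in for audio_encoder(...).last_hidden_state: ~2 s of audio at ~50 frames/sec
hidden_states = torch.randn(1, 101, 768)

# drop the trailing frame if the count is odd, exactly as predict() does
if hidden_states.shape[1] % 2 != 0:
    hidden_states = hidden_states[:, :hidden_states.shape[1] - 1]

# concatenate every pair of consecutive frames: (1, 100, 768) -> (1, 50, 1536)
paired = torch.reshape(hidden_states, (1, hidden_states.shape[1] // 2, hidden_states.shape[2] * 2))
print(paired.shape)  # torch.Size([1, 50, 1536]); 1536 matches the GRU input size audio_dim * 2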
66 | """ 67 | return x * torch.sigmoid(x) 68 | 69 | 70 | if version.parse(torch.__version__) < version.parse("1.7"): 71 | silu = _silu_python 72 | else: 73 | silu = nn.functional.silu 74 | 75 | 76 | def mish(x): 77 | return x * torch.tanh(nn.functional.softplus(x)) 78 | 79 | 80 | def linear_act(x): 81 | return x 82 | 83 | 84 | ACT2FN = { 85 | "relu": nn.functional.relu, 86 | "silu": silu, 87 | "swish": silu, 88 | "gelu": gelu, 89 | "tanh": torch.tanh, 90 | "gelu_new": gelu_new, 91 | "gelu_fast": gelu_fast, 92 | "quick_gelu": quick_gelu, 93 | "mish": mish, 94 | "linear": linear_act, 95 | "sigmoid": torch.sigmoid, 96 | } 97 | 98 | 99 | def get_activation(activation_string): 100 | if activation_string in ACT2FN: 101 | return ACT2FN[activation_string] 102 | else: 103 | raise KeyError(f"function {activation_string} not found in ACT2FN mapping {list(ACT2FN.keys())}") 104 | -------------------------------------------------------------------------------- /hubert/configuration_hubert.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2021 The Fairseq Authors and The HuggingFace Inc. team. All rights reserved. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | """ Hubert model configuration """ 16 | 17 | from .configuration_utils import PretrainedConfig 18 | from .utils import logging 19 | 20 | 21 | logger = logging.get_logger(__name__) 22 | 23 | HUBERT_PRETRAINED_CONFIG_ARCHIVE_MAP = { 24 | "facebook/hubert-base-ls960": "https://huggingface.co/facebook/hubert-base-ls960/resolve/main/config.json", 25 | # See all Hubert models at https://huggingface.co/models?filter=hubert 26 | } 27 | 28 | 29 | class HubertConfig(PretrainedConfig): 30 | r""" 31 | This is the configuration class to store the configuration of a :class:`~transformers.HubertModel`. It is used to 32 | instantiate an Hubert model according to the specified arguments, defining the model architecture. Instantiating a 33 | configuration with the defaults will yield a similar configuration to that of the Hubert 34 | `facebook/hubert-base-ls960 `__ architecture. 35 | 36 | Configuration objects inherit from :class:`~transformers.PretrainedConfig` and can be used to control the model 37 | outputs. Read the documentation from :class:`~transformers.PretrainedConfig` for more information. 38 | 39 | 40 | Args: 41 | vocab_size (:obj:`int`, `optional`, defaults to 32): 42 | Vocabulary size of the Hubert model. Defines the number of different tokens that can be represented by the 43 | :obj:`inputs_ids` passed when calling :class:`~transformers.HubertModel`. Vocabulary size of the model. 44 | Defines the different tokens that can be represented by the `inputs_ids` passed to the forward method of 45 | :class:`~transformers.HubertModel`. 46 | hidden_size (:obj:`int`, `optional`, defaults to 768): 47 | Dimensionality of the encoder layers and the pooler layer. 
48 | num_hidden_layers (:obj:`int`, `optional`, defaults to 12): 49 | Number of hidden layers in the Transformer encoder. 50 | num_attention_heads (:obj:`int`, `optional`, defaults to 12): 51 | Number of attention heads for each attention layer in the Transformer encoder. 52 | intermediate_size (:obj:`int`, `optional`, defaults to 3072): 53 | Dimensionality of the "intermediate" (i.e., feed-forward) layer in the Transformer encoder. 54 | hidden_act (:obj:`str` or :obj:`function`, `optional`, defaults to :obj:`"gelu"`): 55 | The non-linear activation function (function or string) in the encoder and pooler. If string, 56 | :obj:`"gelu"`, :obj:`"relu"`, :obj:`"selu"` and :obj:`"gelu_new"` are supported. 57 | hidden_dropout_prob (:obj:`float`, `optional`, defaults to 0.1): 58 | The dropout probabilitiy for all fully connected layers in the embeddings, encoder, and pooler. 59 | attention_probs_dropout_prob (:obj:`float`, `optional`, defaults to 0.1): 60 | The dropout ratio for the attention probabilities. 61 | initializer_range (:obj:`float`, `optional`, defaults to 0.02): 62 | The standard deviation of the truncated_normal_initializer for initializing all weight matrices. 63 | layer_norm_eps (:obj:`float`, `optional`, defaults to 1e-12): 64 | The epsilon used by the layer normalization layers. 65 | feat_extract_norm (:obj:`str`, `optional`, defaults to :obj:`"group"`): 66 | The norm to be applied to 1D convolutional layers in feature extractor. One of :obj:`"group"` for group 67 | normalization of only the first 1D convolutional layer or :obj:`"layer"` for layer normalization of all 1D 68 | convolutional layers. 69 | feat_extract_dropout (:obj:`float`, `optional`, defaults to 0.0): 70 | The dropout probabilitiy for all 1D convolutional layers in feature extractor. 71 | feat_extract_activation (:obj:`str, `optional`, defaults to :obj:`"gelu"`): 72 | The non-linear activation function (function or string) in the 1D convolutional layers of the feature 73 | extractor. If string, :obj:`"gelu"`, :obj:`"relu"`, :obj:`"selu"` and :obj:`"gelu_new"` are supported. 74 | conv_dim (:obj:`Tuple[int]`, `optional`, defaults to :obj:`(512, 512, 512, 512, 512, 512, 512)`): 75 | A tuple of integers defining the number of input and output channels of each 1D convolutional layer in the 76 | feature extractor. The length of `conv_dim` defines the number of 1D convolutional layers. 77 | conv_stride (:obj:`Tuple[int]`, `optional`, defaults to :obj:`(5, 2, 2, 2, 2, 2, 2)`): 78 | A tuple of integers defining the stride of each 1D convolutional layer in the feature extractor. The length 79 | of `conv_stride` defines the number of convolutional layers and has to match the the length of `conv_dim`. 80 | conv_kernel (:obj:`Tuple[int]`, `optional`, defaults to :obj:`(10, 3, 3, 3, 3, 3, 3)`): 81 | A tuple of integers defining the kernel size of each 1D convolutional layer in the feature extractor. The 82 | length of `conv_kernel` defines the number of convolutional layers and has to match the the length of 83 | `conv_dim`. 84 | conv_bias (:obj:`bool`, `optional`, defaults to :obj:`False`): 85 | Whether the 1D convolutional layers have a bias. 86 | num_conv_pos_embeddings (:obj:`int`, `optional`, defaults to 128): 87 | Number of convolutional positional embeddings. Defines the kernel size of 1D convolutional positional 88 | embeddings layer. 89 | num_conv_pos_embedding_groups (:obj:`int`, `optional`, defaults to 16): 90 | Number of groups of 1D convolutional positional embeddings layer. 
91 | do_stable_layer_norm (:obj:`bool`, `optional`, defaults to :obj:`False`): 92 | Whether do apply `stable` layer norm architecture of the Transformer encoder. ``do_stable_layer_norm is 93 | True`` corresponds to applying layer norm before the attention layer, whereas ``do_stable_layer_norm is 94 | False`` corresponds to applying layer norm after the attention layer. 95 | apply_spec_augment (:obj:`bool`, `optional`, defaults to :obj:`True`): 96 | Whether to apply *SpecAugment* data augmentation to the outputs of the feature extractor. For reference see 97 | `SpecAugment: A Simple Data Augmentation Method for Automatic Speech Recognition 98 | `__. 99 | mask_time_prob (:obj:`float`, `optional`, defaults to 0.05): 100 | Propability of each feature vector along the time axis to be chosen as the start of the vector span to be 101 | masked. Approximately ``mask_time_prob * sequence_length // mask_time_length`` feature vectors will be 102 | masked along the time axis. This is only relevant if ``apply_spec_augment is True``. 103 | mask_time_length (:obj:`int`, `optional`, defaults to 10): 104 | Length of vector span along the time axis. 105 | mask_feature_prob (:obj:`float`, `optional`, defaults to 0.0): 106 | Propability of each feature vector along the feature axis to be chosen as the start of the vector span to 107 | be masked. Approximately ``mask_time_prob * hidden_size // mask_time_length`` feature vectors will be 108 | masked along the time axis. This is only relevant if ``apply_spec_augment is True``. 109 | mask_feature_length (:obj:`int`, `optional`, defaults to 10): 110 | Length of vector span along the feature axis. 111 | ctc_loss_reduction (:obj:`str`, `optional`, defaults to :obj:`"sum"`): 112 | Specifies the reduction to apply to the output of ``torch.nn.CTCLoss``. Only relevant when training an 113 | instance of :class:`~transformers.HubertForCTC`. 114 | ctc_zero_infinity (:obj:`bool`, `optional`, defaults to :obj:`False`): 115 | Whether to zero infinite losses and the associated gradients of ``torch.nn.CTCLoss``. Infinite losses 116 | mainly occur when the inputs are too short to be aligned to the targets. Only relevant when training an 117 | instance of :class:`~transformers.HubertForCTC`. 118 | gradient_checkpointing (:obj:`bool`, `optional`, defaults to :obj:`False`): 119 | If True, use gradient checkpointing to save memory at the expense of slower backward pass. 
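A back-of-the-envelope check that ties the HubertConfig defaults above to the frame rates used in faceXhubert.py: the product of the default conv_stride values gives the waveform-to-feature downsampling factor, so 16 kHz audio yields about 50 feature frames per second, which is the 50 fps input rate that the model pairs down to 25 fps. A small sketch using the default values listed above:

conv_stride = (5, 2, 2, 2, 2, 2, 2)    # default strides of the 1D feature extractor

downsampling = 1
for stride in conv_stride:
    downsampling *= stride
print(downsampling)                    # 320 audio samples per output feature frame

sample_rate = 16_000                   # facebook/hubert-base-ls960 expects 16 kHz audio
print(sample_rate / downsampling)      # 50.0 feature frames/sec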
120 | 121 | Example:: 122 | 123 | >>> from transformers import HubertModel, HubertConfig 124 | 125 | >>> # Initializing a Hubert facebook/hubert-base-ls960 style configuration 126 | >>> configuration = HubertConfig() 127 | 128 | >>> # Initializing a model from the facebook/hubert-base-ls960 style configuration 129 | >>> model = HubertModel(configuration) 130 | 131 | >>> # Accessing the model configuration 132 | >>> configuration = model.config 133 | """ 134 | model_type = "hubert" 135 | 136 | def __init__( 137 | self, 138 | vocab_size=32, 139 | hidden_size=768, 140 | num_hidden_layers=12, 141 | num_attention_heads=12, 142 | intermediate_size=3072, 143 | hidden_act="gelu", 144 | hidden_dropout=0.1, 145 | activation_dropout=0.1, 146 | attention_dropout=0.1, 147 | feat_proj_dropout=0.1, 148 | final_dropout=0.1, 149 | layerdrop=0.1, 150 | initializer_range=0.02, 151 | layer_norm_eps=1e-5, 152 | feat_extract_norm="group", 153 | feat_extract_activation="gelu", 154 | conv_dim=(512, 512, 512, 512, 512, 512, 512), 155 | conv_stride=(5, 2, 2, 2, 2, 2, 2), 156 | conv_kernel=(10, 3, 3, 3, 3, 2, 2), 157 | conv_bias=False, 158 | num_conv_pos_embeddings=128, 159 | num_conv_pos_embedding_groups=16, 160 | do_stable_layer_norm=False, 161 | apply_spec_augment=True, 162 | mask_time_prob=0.05, 163 | mask_time_length=10, 164 | mask_feature_prob=0.0, 165 | mask_feature_length=10, 166 | ctc_loss_reduction="sum", 167 | ctc_zero_infinity=False, 168 | gradient_checkpointing=False, 169 | pad_token_id=0, 170 | bos_token_id=1, 171 | eos_token_id=2, 172 | **kwargs 173 | ): 174 | super().__init__(**kwargs, pad_token_id=pad_token_id, bos_token_id=bos_token_id, eos_token_id=eos_token_id) 175 | self.hidden_size = hidden_size 176 | self.feat_extract_norm = feat_extract_norm 177 | self.feat_extract_activation = feat_extract_activation 178 | self.conv_dim = list(conv_dim) 179 | self.conv_stride = list(conv_stride) 180 | self.conv_kernel = list(conv_kernel) 181 | self.conv_bias = conv_bias 182 | self.num_conv_pos_embeddings = num_conv_pos_embeddings 183 | self.num_conv_pos_embedding_groups = num_conv_pos_embedding_groups 184 | self.num_feat_extract_layers = len(self.conv_dim) 185 | self.num_hidden_layers = num_hidden_layers 186 | self.intermediate_size = intermediate_size 187 | self.hidden_act = hidden_act 188 | self.num_attention_heads = num_attention_heads 189 | self.hidden_dropout = hidden_dropout 190 | self.attention_dropout = attention_dropout 191 | self.activation_dropout = activation_dropout 192 | self.feat_proj_dropout = feat_proj_dropout 193 | self.final_dropout = final_dropout 194 | self.layerdrop = layerdrop 195 | self.layer_norm_eps = layer_norm_eps 196 | self.initializer_range = initializer_range 197 | self.vocab_size = vocab_size 198 | self.do_stable_layer_norm = do_stable_layer_norm 199 | self.gradient_checkpointing = gradient_checkpointing 200 | 201 | if ( 202 | (len(self.conv_stride) != self.num_feat_extract_layers) 203 | or (len(self.conv_kernel) != self.num_feat_extract_layers) 204 | or (len(self.conv_dim) != self.num_feat_extract_layers) 205 | ): 206 | raise ValueError( 207 | "Configuration for convolutional layers is incorrect." 208 | "It is required that `len(config.conv_dim)` == `len(config.conv_stride)` == `len(config.conv_kernel)`," 209 | f"but is `len(config.conv_dim) = {len(self.conv_dim)}`, `len(config.conv_stride)" 210 | f"= {len(self.conv_stride)}`, `len(config.conv_kernel) = {len(self.conv_kernel)}`." 
211 | ) 212 | 213 | # fine-tuning config parameters for SpecAugment: https://arxiv.org/abs/1904.08779 214 | self.apply_spec_augment = apply_spec_augment 215 | self.mask_time_prob = mask_time_prob 216 | self.mask_time_length = mask_time_length 217 | self.mask_feature_prob = mask_feature_prob 218 | self.mask_feature_length = mask_feature_length 219 | 220 | # ctc loss 221 | self.ctc_loss_reduction = ctc_loss_reduction 222 | self.ctc_zero_infinity = ctc_zero_infinity 223 | -------------------------------------------------------------------------------- /hubert/deepspeed.py: -------------------------------------------------------------------------------- 1 | # Copyright 2020 The HuggingFace Team. All rights reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | """ 15 | Integration with Deepspeed 16 | """ 17 | 18 | import importlib.util 19 | import io 20 | import json 21 | import weakref 22 | from copy import deepcopy 23 | from functools import partialmethod 24 | 25 | from .dependency_versions_check import dep_version_check 26 | from .file_utils import is_torch_available 27 | from .utils import logging 28 | 29 | 30 | if is_torch_available(): 31 | import torch 32 | 33 | logger = logging.get_logger(__name__) 34 | 35 | 36 | def is_deepspeed_available(): 37 | return importlib.util.find_spec("deepspeed") is not None 38 | 39 | 40 | class HfDeepSpeedConfig: 41 | """ 42 | This object contains a DeepSpeed configuration dictionary and can be quickly queried for things like zero stage. 43 | 44 | A ``weakref`` of this object is stored in the module's globals to be able to access the config from areas where 45 | things like the Trainer object is not available (e.g. ``from_pretrained`` and ``_get_resized_embeddings``). 46 | Therefore it's important that this object remains alive while the program is still running. 47 | 48 | :class:`~transformers.Trainer` uses the ``HfTrainerDeepSpeedConfig`` subclass instead. That subclass has logic to 49 | sync the configuration with values of :class:`~transformers.TrainingArguments` by replacing special placeholder 50 | values: ``"auto"``. Without this special logic the DeepSpeed configuration is not modified in any way. 51 | 52 | Args: 53 | config_file_or_dict (:obj:`Union[str, Dict]`) - path to DeepSpeed config file or dict. 54 | 55 | """ 56 | 57 | def __init__(self, config_file_or_dict): 58 | # set global weakref object 59 | set_hf_deepspeed_config(self) 60 | 61 | dep_version_check("deepspeed") 62 | 63 | if isinstance(config_file_or_dict, dict): 64 | # Don't modify user's data should they want to reuse it (e.g. 
in tests), because once we 65 | # modified it, it will not be accepted here again, since `auto` values would have been overriden 66 | config = deepcopy(config_file_or_dict) 67 | elif isinstance(config_file_or_dict, str): 68 | with io.open(config_file_or_dict, "r", encoding="utf-8") as f: 69 | config = json.load(f) 70 | else: 71 | raise ValueError("expecting either a path to a DeepSpeed config file or a pre-populated dict") 72 | self.config = config 73 | 74 | # zero stage - this is done as early as possible, before model is created, to allow 75 | # ``is_deepspeed_zero3_enabled`` query and getting to the early deepspeed config object 76 | # during ``zero.Init()`` which needs whether fp16 is enabled, dtype, etc. 77 | self._stage = self.get_value("zero_optimization.stage", -1) 78 | 79 | # offload 80 | self._offload = False 81 | if self.is_zero2() or self.is_zero3(): 82 | offload_devices_valid = set(["cpu", "nvme"]) 83 | offload_devices = set( 84 | [ 85 | self.get_value("zero_optimization.offload_optimizer.device"), 86 | self.get_value("zero_optimization.offload_param.device"), 87 | ] 88 | ) 89 | if len(offload_devices & offload_devices_valid) > 0: 90 | self._offload = True 91 | 92 | def find_config_node(self, ds_key_long): 93 | config = self.config 94 | 95 | # find the config node of interest if it exists 96 | nodes = ds_key_long.split(".") 97 | ds_key = nodes.pop() 98 | for node in nodes: 99 | config = config.get(node) 100 | if config is None: 101 | return None, ds_key 102 | 103 | return config, ds_key 104 | 105 | def get_value(self, ds_key_long, default=None): 106 | """ 107 | Returns the set value or ``default`` if no value is set 108 | """ 109 | config, ds_key = self.find_config_node(ds_key_long) 110 | if config is None: 111 | return default 112 | return config.get(ds_key, default) 113 | 114 | def is_true(self, ds_key_long): 115 | """ 116 | Returns :obj:`True`/:obj:`False` only if the value is set, always :obj:`False` otherwise. So use this method to 117 | ask the very specific question of whether the value is set to :obj:`True` (and it's not set to :obj:`False` or 118 | isn't set). 119 | 120 | """ 121 | value = self.get_value(ds_key_long) 122 | return False if value is None else bool(value) 123 | 124 | def is_false(self, ds_key_long): 125 | """ 126 | Returns :obj:`True`/:obj:`False` only if the value is set, always :obj:`False` otherwise. So use this method to 127 | ask the very specific question of whether the value is set to :obj:`False` (and it's not set to :obj:`True` or 128 | isn't set). 129 | """ 130 | value = self.get_value(ds_key_long) 131 | return False if value is None else not bool(value) 132 | 133 | def is_zero2(self): 134 | return self._stage == 2 135 | 136 | def is_zero3(self): 137 | return self._stage == 3 138 | 139 | def is_offload(self): 140 | return self._offload 141 | 142 | 143 | class HfTrainerDeepSpeedConfig(HfDeepSpeedConfig): 144 | """ 145 | The ``HfTrainerDeepSpeedConfig`` object is meant to be created during ``TrainingArguments`` object creation and has 146 | the same lifespan as the latter. 147 | """ 148 | 149 | def __init__(self, config_file_or_dict): 150 | super().__init__(config_file_or_dict) 151 | self._dtype = torch.float16 152 | self.mismatches = [] 153 | 154 | def dtype(self): 155 | return self._dtype 156 | 157 | def fill_match(self, ds_key_long, hf_val, hf_key=None, must_match=True): 158 | """ 159 | A utility method that massages the config file and can optionally verify that the values match. 160 | 161 | 1. 
Replace "auto" values with ``TrainingArguments`` value. 162 | 163 | 2. If it wasn't "auto" and ``must_match`` is true, then check that DS config matches Trainer 164 | config values and if mismatched add the entry to ``self.mismatched`` - will assert during 165 | ``trainer_config_finalize`` for one or more mismatches. 166 | 167 | """ 168 | config, ds_key = self.find_config_node(ds_key_long) 169 | if config is None: 170 | return 171 | 172 | if config.get(ds_key) == "auto": 173 | config[ds_key] = hf_val 174 | return 175 | 176 | if not must_match: 177 | return 178 | 179 | ds_val = config.get(ds_key) 180 | if ds_val is not None and ds_val != hf_val: 181 | self.mismatches.append(f"- ds {ds_key_long}={ds_val} vs hf {hf_key}={hf_val}") 182 | 183 | fill_only = partialmethod(fill_match, must_match=False) 184 | 185 | def trainer_config_process(self, args): 186 | """ 187 | Adjust the config with ``TrainingArguments`` values. This stage is run during ``TrainingArguments`` object 188 | creation. 189 | """ 190 | # DeepSpeed does: 191 | # train_batch_size = world_size * train_micro_batch_size_per_gpu * gradient_accumulation_steps 192 | train_batch_size = args.world_size * args.per_device_train_batch_size * args.gradient_accumulation_steps 193 | self.fill_match( 194 | "train_micro_batch_size_per_gpu", args.per_device_train_batch_size, "per_device_train_batch_size" 195 | ) 196 | self.fill_match("gradient_accumulation_steps", args.gradient_accumulation_steps, "gradient_accumulation_steps") 197 | self.fill_match("train_batch_size", train_batch_size, "train_batch_size (calculated)") 198 | self.fill_match("gradient_clipping", args.max_grad_norm, "max_grad_norm") 199 | 200 | self.fill_match("optimizer.params.lr", args.learning_rate, "learning_rate") 201 | self.fill_match("optimizer.params.betas", [args.adam_beta1, args.adam_beta2], "adam_beta1+adam_beta2") 202 | self.fill_match("optimizer.params.eps", args.adam_epsilon, "adam_epsilon") 203 | self.fill_match("optimizer.params.weight_decay", args.weight_decay, "weight_decay") 204 | 205 | self.fill_only("scheduler.params.warmup_min_lr", 0) # not a trainer arg 206 | self.fill_match("scheduler.params.warmup_max_lr", args.learning_rate, "learning_rate") 207 | self.fill_match("scheduler.params.warmup_num_steps", args.warmup_steps, "warmup_steps") 208 | # total_num_steps - will get set in trainer_config_finalize 209 | 210 | # fp16 211 | if args.fp16: 212 | fp16_backend = "apex" if args.fp16_backend == "apex" else "amp" 213 | else: 214 | fp16_backend = None 215 | 216 | # amp: similar to the pytorch native amp - it has a bunch of optional params but we won't set 217 | # any here unless the user did the work 218 | self.fill_match("fp16.enabled", fp16_backend == "amp", "fp16+fp16_backend(amp)") 219 | 220 | # apex: delegates amp work to apex (which needs to be available), but it cannot be used with any 221 | # ZeRO features 222 | self.fill_match("amp.enabled", fp16_backend == "apex", "fp16+fp16_backend(apex)") 223 | self.fill_match("amp.opt_level", args.fp16_opt_level, "fp16_opt_level") 224 | 225 | # only if we have an explicit fp16.enabled = False then it's fp32, if it's True or this 226 | # whole config section is missing then the fallback is fp16 227 | if self.is_false("fp16.enabled"): 228 | self._dtype = torch.float32 229 | # later there will be other dtypes besides just fp16 and fp32 230 | # also not quite sure what dtype should be under apex, defaulting to fp16 for now 231 | 232 | def trainer_config_finalize(self, args, model, num_training_steps): 233 | """ 234 | This 
stage is run after we have the model and know num_training_steps. 235 | 236 | Now we we can complete the configuration process. 237 | """ 238 | # zero 239 | if self.is_zero3(): 240 | # automatically assign the optimal config values based on model config 241 | hidden_size = model.config.hidden_size 242 | self.fill_only("zero_optimization.reduce_bucket_size", hidden_size * hidden_size) 243 | self.fill_only("zero_optimization.stage3_prefetch_bucket_size", 0.9 * hidden_size * hidden_size) 244 | self.fill_only("zero_optimization.stage3_param_persistence_threshold", 10 * hidden_size) 245 | 246 | # scheduler 247 | self.fill_match("scheduler.params.total_num_steps", num_training_steps, "num_training_steps (calculated)") 248 | 249 | if len(self.mismatches) > 0: 250 | mismatches = "\n".join(self.mismatches) 251 | raise ValueError( 252 | f"Please correct the following DeepSpeed config values that mismatch TrainingArguments values:\n{mismatches}\n" 253 | "The easiest method is to set these DeepSpeed config values to 'auto'." 254 | ) 255 | 256 | 257 | # keep the config object global to be able to access it anywhere during TrainingArguments life-cycle 258 | _hf_deepspeed_config_weak_ref = None 259 | 260 | 261 | def set_hf_deepspeed_config(hf_deepspeed_config_obj): 262 | # this is a special weakref global object to allow us to get to Deepspeed config from APIs 263 | # that don't have an easy way to get to the Deepspeed config outside of the Trainer domain. 264 | global _hf_deepspeed_config_weak_ref 265 | # will go away automatically when HfDeepSpeedConfig is destroyed (when TrainingArguments is destroyed) 266 | _hf_deepspeed_config_weak_ref = weakref.ref(hf_deepspeed_config_obj) 267 | 268 | 269 | def is_deepspeed_zero3_enabled(): 270 | if _hf_deepspeed_config_weak_ref is not None and _hf_deepspeed_config_weak_ref() is not None: 271 | return _hf_deepspeed_config_weak_ref().is_zero3() 272 | else: 273 | return False 274 | 275 | 276 | def deepspeed_config(): 277 | if _hf_deepspeed_config_weak_ref is not None and _hf_deepspeed_config_weak_ref() is not None: 278 | return _hf_deepspeed_config_weak_ref().config 279 | else: 280 | return None 281 | 282 | 283 | def deepspeed_init(trainer, num_training_steps, resume_from_checkpoint=None): 284 | """ 285 | Init DeepSpeed, after updating the DeepSpeed configuration with any relevant Trainer's args. 286 | 287 | If ``resume_from_checkpoint`` was passed then an attempt to resume from a previously saved checkpoint will be made. 288 | 289 | Args: 290 | trainer: Trainer object 291 | num_training_steps: per single gpu 292 | resume_from_checkpoint: path to a checkpoint if to resume from after normal DeepSpeedEngine load 293 | 294 | Returns: model, optimizer, lr_scheduler 295 | 296 | """ 297 | import deepspeed 298 | 299 | model = trainer.model 300 | 301 | hf_deepspeed_config = trainer.args.hf_deepspeed_config 302 | hf_deepspeed_config.trainer_config_finalize(trainer.args, model, num_training_steps) 303 | 304 | # resume config update - some bits like `model` and `num_training_steps` only become available during train 305 | config = hf_deepspeed_config.config 306 | 307 | # Optimizer + Scheduler 308 | # Currently supported combos: 309 | # 1. DS scheduler + DS optimizer: Yes 310 | # 2. HF scheduler + HF optimizer: Yes 311 | # 3. DS scheduler + HF optimizer: Yes 312 | # 4. HF scheduler + DS optimizer: No 313 | # 314 | # Unless Offload is enabled in which case it's: 315 | # 1. DS scheduler + DS optimizer: Yes 316 | # 2. HF scheduler + HF optimizer: No 317 | # 3. 
DS scheduler + HF optimizer: No 318 | # 4. HF scheduler + DS optimizer: No 319 | 320 | optimizer = None 321 | if "optimizer" not in config: 322 | if hf_deepspeed_config.is_offload(): 323 | raise ValueError("ZeRO Offload can only work with DeepSpeed optimizers") 324 | 325 | # ds supports Adam, OneBitAdam, and Lamb optimizers and can import other optimizers from torch. 326 | # But trainer uses AdamW by default. 327 | trainer.create_optimizer() 328 | optimizer = trainer.optimizer 329 | # To use other optimizers requires voiding warranty with: `zero_allow_untested_optimizer` 330 | config["zero_allow_untested_optimizer"] = True 331 | 332 | # DS schedulers (deepspeed/runtime/lr_schedules.py): 333 | # 334 | # DS name | --lr_scheduler_type | HF func | Notes 335 | # -------------| ---------------------|-----------------------------------|-------------------- 336 | # LRRangeTest | na | na | LRRT 337 | # OneCycle | na | na | 1CLR 338 | # WarmupLR | constant_with_warmup | get_constant_schedule_with_warmup | w/ warmup_min_lr=0 339 | # WarmupDecayLR| linear | get_linear_schedule_with_warmup | 340 | lr_scheduler = None 341 | if "scheduler" not in config: 342 | if "optimizer" in config: 343 | # to make this option work, we need to init DS optimizer first, then init HS scheduler, 344 | # then pass the HS scheduler to DS init, which is not possible at the moment 345 | raise ValueError("At the moment HF scheduler + DeepSpeed optimizer combination is not possible") 346 | else: 347 | trainer.create_scheduler(num_training_steps=num_training_steps) 348 | lr_scheduler = trainer.lr_scheduler 349 | 350 | # keep for quick debug: 351 | # from pprint import pprint; pprint(config) 352 | 353 | model_parameters = filter(lambda p: p.requires_grad, model.parameters()) 354 | 355 | model, optimizer, _, lr_scheduler = deepspeed.initialize( 356 | model=model, 357 | model_parameters=model_parameters, 358 | config_params=config, 359 | optimizer=optimizer, 360 | lr_scheduler=lr_scheduler, 361 | ) 362 | 363 | if resume_from_checkpoint is not None: 364 | 365 | # it's possible that the user is trying to resume from model_path, which doesn't necessarily 366 | # contain a deepspeed checkpoint. e.g. examples just check if the dir exists and assume it's 367 | # a resume from a checkpoint and not just a local pretrained weight. So we check here if the 368 | # path contains what looks like a deepspeed checkpoint 369 | import glob 370 | 371 | deepspeed_checkpoint_dirs = sorted(glob.glob(f"{resume_from_checkpoint}/global_step*")) 372 | 373 | if len(deepspeed_checkpoint_dirs) > 0: 374 | logger.info(f"Attempting to resume from {resume_from_checkpoint}") 375 | # this magically updates self.optimizer and self.lr_scheduler 376 | load_path, _ = model.load_checkpoint( 377 | resume_from_checkpoint, load_optimizer_states=True, load_lr_scheduler_states=True 378 | ) 379 | if load_path is None: 380 | raise ValueError(f"[deepspeed] failed to resume from checkpoint {resume_from_checkpoint}") 381 | else: 382 | logger.info(f"{resume_from_checkpoint} doesn't have deepspeed checkpoints, doing nothing") 383 | 384 | return model, optimizer, lr_scheduler 385 | -------------------------------------------------------------------------------- /hubert/dependency_versions_check.py: -------------------------------------------------------------------------------- 1 | # Copyright 2020 The HuggingFace Team. All rights reserved. 
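A standalone sketch of the dotted-path lookup performed by HfDeepSpeedConfig.find_config_node()/get_value() above; the real class additionally requires deepspeed to be installed, so this toy version only mirrors the lookup logic on an example ZeRO-3 config dict (both the helper and the dict are illustrative, not repo code):

# toy re-implementation for illustration only
def get_value(config: dict, ds_key_long: str, default=None):
    nodes = ds_key_long.split(".")        # e.g. "zero_optimization.stage" -> ["zero_optimization"], "stage"
    ds_key = nodes.pop()
    for node in nodes:
        config = config.get(node)
        if config is None:
            return default
    return config.get(ds_key, default)

ds_config = {"zero_optimization": {"stage": 3, "offload_param": {"device": "cpu"}}}
print(get_value(ds_config, "zero_optimization.stage", -1))                 # 3 -> ZeRO-3 enabled
print(get_value(ds_config, "zero_optimization.offload_optimizer.device"))  # None -> not configured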
2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | import sys 15 | 16 | from .dependency_versions_table import deps 17 | from .utils.versions import require_version, require_version_core 18 | 19 | 20 | # define which module versions we always want to check at run time 21 | # (usually the ones defined in `install_requires` in setup.py) 22 | # 23 | # order specific notes: 24 | # - tqdm must be checked before tokenizers 25 | 26 | pkgs_to_check_at_runtime = "python tqdm regex sacremoses requests packaging filelock numpy tokenizers".split() 27 | if sys.version_info < (3, 7): 28 | pkgs_to_check_at_runtime.append("dataclasses") 29 | if sys.version_info < (3, 8): 30 | pkgs_to_check_at_runtime.append("importlib_metadata") 31 | 32 | for pkg in pkgs_to_check_at_runtime: 33 | if pkg in deps: 34 | if pkg == "tokenizers": 35 | # must be loaded here, or else tqdm check may fail 36 | from .file_utils import is_tokenizers_available 37 | 38 | if not is_tokenizers_available(): 39 | continue # not required, check version only if installed 40 | 41 | require_version_core(deps[pkg]) 42 | else: 43 | raise ValueError(f"can't find {pkg} in {deps.keys()}, check dependency_versions_table.py") 44 | 45 | 46 | def dep_version_check(pkg, hint=None): 47 | require_version(deps[pkg], hint) 48 | -------------------------------------------------------------------------------- /hubert/dependency_versions_table.py: -------------------------------------------------------------------------------- 1 | # THIS FILE HAS BEEN AUTOGENERATED. To update: 2 | # 1. modify the `_deps` dict in setup.py 3 | # 2. 
run `make deps_table_update`` 4 | deps = { 5 | "Pillow": "Pillow", 6 | "black": "black==21.4b0", 7 | "cookiecutter": "cookiecutter==1.7.2", 8 | "dataclasses": "dataclasses", 9 | "datasets": "datasets", 10 | "deepspeed": "deepspeed>=0.4.0", 11 | "docutils": "docutils==0.16.0", 12 | "fairscale": "fairscale>0.3", 13 | "faiss-cpu": "faiss-cpu", 14 | "fastapi": "fastapi", 15 | "filelock": "filelock", 16 | "flake8": "flake8>=3.8.3", 17 | "flax": "flax>=0.3.4", 18 | "fugashi": "fugashi>=1.0", 19 | "huggingface-hub": "huggingface-hub==0.0.8", 20 | "importlib_metadata": "importlib_metadata", 21 | "ipadic": "ipadic>=1.0.0,<2.0", 22 | "isort": "isort>=5.5.4", 23 | "jax": "jax>=0.2.8", 24 | "jaxlib": "jaxlib>=0.1.65", 25 | "jieba": "jieba", 26 | "keras2onnx": "keras2onnx", 27 | "nltk": "nltk", 28 | "numpy": "numpy>=1.17", 29 | "onnxconverter-common": "onnxconverter-common", 30 | "onnxruntime-tools": "onnxruntime-tools>=1.4.2", 31 | "onnxruntime": "onnxruntime>=1.4.0", 32 | "optuna": "optuna", 33 | "packaging": "packaging", 34 | "parameterized": "parameterized", 35 | "protobuf": "protobuf", 36 | "psutil": "psutil", 37 | "pyyaml": "pyyaml", 38 | "pydantic": "pydantic", 39 | "pytest": "pytest", 40 | "pytest-sugar": "pytest-sugar", 41 | "pytest-xdist": "pytest-xdist", 42 | "python": "python>=3.6.0", 43 | "ray": "ray", 44 | "recommonmark": "recommonmark", 45 | "regex": "regex!=2019.12.17", 46 | "requests": "requests", 47 | "rouge-score": "rouge-score", 48 | "sacrebleu": "sacrebleu>=1.4.12", 49 | "sacremoses": "sacremoses", 50 | "sagemaker": "sagemaker>=2.31.0", 51 | "scikit-learn": "scikit-learn", 52 | "sentencepiece": "sentencepiece==0.1.91", 53 | "soundfile": "soundfile", 54 | "sphinx-copybutton": "sphinx-copybutton", 55 | "sphinx-markdown-tables": "sphinx-markdown-tables", 56 | "sphinx-rtd-theme": "sphinx-rtd-theme==0.4.3", 57 | "sphinx": "sphinx==3.2.1", 58 | "sphinxext-opengraph": "sphinxext-opengraph==0.4.1", 59 | "starlette": "starlette", 60 | "tensorflow-cpu": "tensorflow-cpu>=2.3", 61 | "tensorflow": "tensorflow>=2.3", 62 | "timeout-decorator": "timeout-decorator", 63 | "timm": "timm", 64 | "tokenizers": "tokenizers>=0.10.1,<0.11", 65 | "torch": "torch>=1.0", 66 | "torchaudio": "torchaudio", 67 | "tqdm": "tqdm>=4.27", 68 | "unidic": "unidic>=1.0.2", 69 | "unidic_lite": "unidic_lite>=1.0.7", 70 | "uvicorn": "uvicorn", 71 | } 72 | -------------------------------------------------------------------------------- /hubert/generation_beam_search.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2020 The HuggingFace Inc. team 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
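For illustration, a rough standalone equivalent of checking one of the pins from the deps table above against the installed environment; the repository's actual checks live in hubert/utils/versions.py, and this sketch only assumes the stdlib importlib.metadata (Python 3.8+) and the packaging library:

from importlib.metadata import version, PackageNotFoundError
from packaging.requirements import Requirement

def satisfies(pin: str) -> bool:         # illustrative helper, not part of the repo
    req = Requirement(pin)               # e.g. name="tqdm", specifier=">=4.27"
    try:
        installed = version(req.name)
    except PackageNotFoundError:
        return False                     # package missing entirely
    return installed in req.specifier    # SpecifierSet membership test

print(satisfies("tqdm>=4.27"))
print(satisfies("tokenizers>=0.10.1,<0.11"))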
15 | 16 | import warnings 17 | from abc import ABC, abstractmethod 18 | from collections import UserDict 19 | from typing import Optional, Tuple 20 | 21 | import torch 22 | 23 | from .file_utils import add_start_docstrings 24 | 25 | 26 | PROCESS_INPUTS_DOCSTRING = r""" 27 | Args: 28 | input_ids (:obj:`torch.LongTensor` of shape :obj:`(batch_size * num_beams, sequence_length)`): 29 | Indices of input sequence tokens in the vocabulary. 30 | 31 | Indices can be obtained using any class inheriting from :class:`~transformers.PreTrainedTokenizer`. See 32 | :meth:`transformers.PreTrainedTokenizer.encode` and :meth:`transformers.PreTrainedTokenizer.__call__` for 33 | details. 34 | 35 | `What are input IDs? <../glossary.html#input-ids>`__ 36 | next_scores (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, 2 * num_beams)`): 37 | Current scores of the top :obj:`2 * num_beams` non-finished beam hypotheses. 38 | next_tokens (:obj:`torch.LongTensor` of shape :obj:`(batch_size, 2 * num_beams)`): 39 | :obj:`input_ids` of the tokens corresponding to the top :obj:`2 * num_beams` non-finished beam hypotheses. 40 | next_indices (:obj:`torch.LongTensor` of shape :obj:`(batch_size, 2 * num_beams)`): 41 | Beam indices indicating to which beam hypothesis the :obj:`next_tokens` correspond. 42 | pad_token_id (:obj:`int`, `optional`): 43 | The id of the `padding` token. 44 | eos_token_id (:obj:`int`, `optional`): 45 | The id of the `end-of-sequence` token. 46 | 47 | Return: 48 | :obj:`UserDict`: A dictionary composed of the fields as defined above: 49 | 50 | - **next_beam_scores** (:obj:`torch.FloatTensor` of shape :obj:`(batch_size * num_beams)`) -- Updated 51 | scores of all non-finished beams. 52 | - **next_beam_tokens** (:obj:`torch.FloatTensor` of shape :obj:`(batch_size * num_beams)`) -- Next tokens 53 | to be added to the non-finished beam_hypotheses. 54 | - **next_beam_indices** (:obj:`torch.FloatTensor` of shape :obj:`(batch_size * num_beams)`) -- Beam indices 55 | indicating to which beam the next tokens shall be added. 56 | 57 | """ 58 | 59 | FINALIZE_INPUTS_DOCSTRING = r""" 60 | Args: 61 | input_ids (:obj:`torch.LongTensor` of shape :obj:`(batch_size * num_beams, sequence_length)`): 62 | Indices of input sequence tokens in the vocabulary. 63 | 64 | Indices can be obtained using any class inheriting from :class:`~transformers.PreTrainedTokenizer`. See 65 | :meth:`transformers.PreTrainedTokenizer.encode` and :meth:`transformers.PreTrainedTokenizer.__call__` for 66 | details. 67 | 68 | `What are input IDs? <../glossary.html#input-ids>`__ 69 | final_beam_scores (:obj:`torch.FloatTensor` of shape :obj:`(batch_size * num_beams)`): 70 | The final scores of all non-finished beams. 71 | final_beam_tokens (:obj:`torch.FloatTensor` of shape :obj:`(batch_size * num_beams)`): 72 | The last tokens to be added to the non-finished beam_hypotheses. 73 | final_beam_indices (:obj:`torch.FloatTensor` of shape :obj:`(batch_size * num_beams)`): 74 | The beam indices indicating to which beam the :obj:`final_beam_tokens` shall be added. 75 | pad_token_id (:obj:`int`, `optional`): 76 | The id of the `padding` token. 77 | eos_token_id (:obj:`int`, `optional`): 78 | The id of the `end-of-sequence` token. 79 | 80 | Return: 81 | :obj:`torch.LongTensor` of shape :obj:`(batch_size * num_return_sequences, sequence_length)`: The generated 82 | sequences. The second dimension (sequence_length) is either equal to :obj:`max_length` or shorter if all 83 | batches finished early due to the :obj:`eos_token_id`. 
84 | 85 | """ 86 | 87 | 88 | class BeamScorer(ABC): 89 | """ 90 | Abstract base class for all beam scorers that are used for :meth:`~transformers.PreTrainedModel.beam_search` and 91 | :meth:`~transformers.PreTrainedModel.beam_sample`. 92 | """ 93 | 94 | @abstractmethod 95 | @add_start_docstrings(PROCESS_INPUTS_DOCSTRING) 96 | def process( 97 | self, 98 | input_ids: torch.LongTensor, 99 | next_scores: torch.FloatTensor, 100 | next_tokens: torch.LongTensor, 101 | next_indices: torch.LongTensor, 102 | **kwargs 103 | ) -> Tuple[torch.Tensor]: 104 | raise NotImplementedError("This is an abstract method.") 105 | 106 | @abstractmethod 107 | @add_start_docstrings(FINALIZE_INPUTS_DOCSTRING) 108 | def finalize( 109 | self, 110 | input_ids: torch.LongTensor, 111 | next_scores: torch.FloatTensor, 112 | next_tokens: torch.LongTensor, 113 | next_indices: torch.LongTensor, 114 | max_length: int, 115 | **kwargs 116 | ) -> torch.LongTensor: 117 | raise NotImplementedError("This is an abstract method.") 118 | 119 | 120 | class BeamSearchScorer(BeamScorer): 121 | r""" 122 | :class:`transformers.BeamScorer` implementing standard beam search decoding. 123 | 124 | Adapted in part from `Facebook's XLM beam search code 125 | `__. 126 | 127 | Reference for the diverse beam search algorithm and implementation `Ashwin Kalyan's DBS implementation 128 | `__ 129 | 130 | Args: 131 | batch_size (:obj:`int`): 132 | Batch Size of :obj:`input_ids` for which standard beam search decoding is run in parallel. 133 | max_length (:obj:`int`): 134 | The maximum length of the sequence to be generated. 135 | num_beams (:obj:`int`): 136 | Number of beams for beam search. 137 | device (:obj:`torch.device`): 138 | Defines the device type (*e.g.*, :obj:`"cpu"` or :obj:`"cuda"`) on which this instance of 139 | :obj:`BeamSearchScorer` will be allocated. 140 | length_penalty (:obj:`float`, `optional`, defaults to 1.0): 141 | Exponential penalty to the length. 1.0 means no penalty. Set to values < 1.0 in order to encourage the 142 | model to generate shorter sequences, to a value > 1.0 in order to encourage the model to produce longer 143 | sequences. 144 | do_early_stopping (:obj:`bool`, `optional`, defaults to :obj:`False`): 145 | Whether to stop the beam search when at least ``num_beams`` sentences are finished per batch or not. 146 | num_beam_hyps_to_keep (:obj:`int`, `optional`, defaults to 1): 147 | The number of beam hypotheses that shall be returned upon calling 148 | :meth:`~transformer.BeamSearchScorer.finalize`. 149 | num_beam_groups (:obj:`int`): 150 | Number of groups to divide :obj:`num_beams` into in order to ensure diversity among different groups of 151 | beams. See `this paper `__ for more details. 
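A minimal construction sketch for the BeamSearchScorer described above (its __init__ follows below); it assumes torch is installed and the repository root is on PYTHONPATH, and it only exercises the bookkeeping that generate() would normally drive through process() and finalize():

import torch
from hubert.generation_beam_search import BeamSearchScorer

scorer = BeamSearchScorer(
    batch_size=2,
    num_beams=4,
    device=torch.device("cpu"),
    length_penalty=1.0,
)
print(scorer.group_size)      # 4, i.e. num_beams // num_beam_groups
print(bool(scorer.is_done))   # False, no batch has enough finished hypotheses yet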
152 | """ 153 | 154 | def __init__( 155 | self, 156 | batch_size: int, 157 | num_beams: int, 158 | device: torch.device, 159 | length_penalty: Optional[float] = 1.0, 160 | do_early_stopping: Optional[bool] = False, 161 | num_beam_hyps_to_keep: Optional[int] = 1, 162 | num_beam_groups: Optional[int] = 1, 163 | **kwargs, 164 | ): 165 | self.num_beams = num_beams 166 | self.device = device 167 | self.length_penalty = length_penalty 168 | self.do_early_stopping = do_early_stopping 169 | self.num_beam_hyps_to_keep = num_beam_hyps_to_keep 170 | self.num_beam_groups = num_beam_groups 171 | self.group_size = self.num_beams // self.num_beam_groups 172 | 173 | self._is_init = False 174 | self._beam_hyps = [ 175 | BeamHypotheses( 176 | num_beams=self.num_beams, 177 | length_penalty=self.length_penalty, 178 | early_stopping=self.do_early_stopping, 179 | ) 180 | for _ in range(batch_size) 181 | ] 182 | self._done = torch.tensor([False for _ in range(batch_size)], dtype=torch.bool, device=self.device) 183 | 184 | if not isinstance(num_beams, int) or num_beams <= 1: 185 | raise ValueError( 186 | f"`num_beams` has to be an integer strictly greater than 1, but is {num_beams}. For `num_beams` == 1, one should make use of `greedy_search` instead." 187 | ) 188 | 189 | if not isinstance(num_beam_groups, int) or (num_beam_groups > num_beams) or (num_beams % num_beam_groups != 0): 190 | raise ValueError( 191 | f"`num_beam_groups` has to be an integer smaller or equal than `num_beams` and `num_beams` " 192 | f"has to be divisible by `num_beam_groups`, but is {num_beam_groups} with `num_beams` being {num_beams}." 193 | ) 194 | 195 | if "max_length" in kwargs: 196 | warnings.warn( 197 | "Passing `max_length` to BeamSearchScorer is deprecated and has no effect." 198 | "`max_length` should be passed directly to `beam_search(...)`, `beam_sample(...)`" 199 | ",or `group_beam_search(...)`." 
200 | ) 201 | 202 | @property 203 | def is_done(self) -> bool: 204 | return self._done.all() 205 | 206 | def process( 207 | self, 208 | input_ids: torch.LongTensor, 209 | next_scores: torch.FloatTensor, 210 | next_tokens: torch.LongTensor, 211 | next_indices: torch.LongTensor, 212 | pad_token_id: Optional[int] = None, 213 | eos_token_id: Optional[int] = None, 214 | ) -> Tuple[torch.Tensor]: 215 | cur_len = input_ids.shape[-1] 216 | batch_size = len(self._beam_hyps) 217 | assert batch_size == (input_ids.shape[0] // self.group_size) 218 | 219 | device = input_ids.device 220 | next_beam_scores = torch.zeros((batch_size, self.group_size), dtype=next_scores.dtype, device=device) 221 | next_beam_tokens = torch.zeros((batch_size, self.group_size), dtype=next_tokens.dtype, device=device) 222 | next_beam_indices = torch.zeros((batch_size, self.group_size), dtype=next_indices.dtype, device=device) 223 | 224 | for batch_idx, beam_hyp in enumerate(self._beam_hyps): 225 | if self._done[batch_idx]: 226 | assert ( 227 | len(beam_hyp) >= self.num_beams 228 | ), f"Batch can only be done if at least {self.num_beams} beams have been generated" 229 | assert ( 230 | eos_token_id is not None and pad_token_id is not None 231 | ), "generated beams >= num_beams -> eos_token_id and pad_token have to be defined" 232 | # pad the batch 233 | next_beam_scores[batch_idx, :] = 0 234 | next_beam_tokens[batch_idx, :] = pad_token_id 235 | next_beam_indices[batch_idx, :] = 0 236 | continue 237 | 238 | # next tokens for this sentence 239 | beam_idx = 0 240 | for beam_token_rank, (next_token, next_score, next_index) in enumerate( 241 | zip(next_tokens[batch_idx], next_scores[batch_idx], next_indices[batch_idx]) 242 | ): 243 | batch_beam_idx = batch_idx * self.group_size + next_index 244 | # add to generated hypotheses if end of sentence 245 | if (eos_token_id is not None) and (next_token.item() == eos_token_id): 246 | # if beam_token does not belong to top num_beams tokens, it should not be added 247 | is_beam_token_worse_than_top_num_beams = beam_token_rank >= self.group_size 248 | if is_beam_token_worse_than_top_num_beams: 249 | continue 250 | beam_hyp.add( 251 | input_ids[batch_beam_idx].clone(), 252 | next_score.item(), 253 | ) 254 | else: 255 | # add next predicted token since it is not eos_token 256 | next_beam_scores[batch_idx, beam_idx] = next_score 257 | next_beam_tokens[batch_idx, beam_idx] = next_token 258 | next_beam_indices[batch_idx, beam_idx] = batch_beam_idx 259 | beam_idx += 1 260 | 261 | # once the beam for next step is full, don't add more tokens to it. 262 | if beam_idx == self.group_size: 263 | break 264 | 265 | if beam_idx < self.group_size: 266 | raise ValueError( 267 | f"At most {self.group_size} tokens in {next_tokens[batch_idx]} can be equal to `eos_token_id: {eos_token_id}`. Make sure {next_tokens[batch_idx]} are corrected." 
268 | ) 269 | 270 | # Check if we are done so that we can save a pad step if all(done) 271 | self._done[batch_idx] = self._done[batch_idx] or beam_hyp.is_done( 272 | next_scores[batch_idx].max().item(), cur_len 273 | ) 274 | 275 | return UserDict( 276 | { 277 | "next_beam_scores": next_beam_scores.view(-1), 278 | "next_beam_tokens": next_beam_tokens.view(-1), 279 | "next_beam_indices": next_beam_indices.view(-1), 280 | } 281 | ) 282 | 283 | def finalize( 284 | self, 285 | input_ids: torch.LongTensor, 286 | final_beam_scores: torch.FloatTensor, 287 | final_beam_tokens: torch.LongTensor, 288 | final_beam_indices: torch.LongTensor, 289 | max_length: int, 290 | pad_token_id: Optional[int] = None, 291 | eos_token_id: Optional[int] = None, 292 | ) -> Tuple[torch.LongTensor]: 293 | batch_size = len(self._beam_hyps) 294 | 295 | # finalize all open beam hypotheses and add to generated hypotheses 296 | for batch_idx, beam_hyp in enumerate(self._beam_hyps): 297 | if self._done[batch_idx]: 298 | continue 299 | 300 | # all open beam hypotheses are added to the beam hypothesis 301 | # beam hypothesis class automatically keeps the best beams 302 | for beam_id in range(self.num_beams): 303 | batch_beam_idx = batch_idx * self.num_beams + beam_id 304 | final_score = final_beam_scores[batch_beam_idx].item() 305 | final_tokens = input_ids[batch_beam_idx] 306 | beam_hyp.add(final_tokens, final_score) 307 | 308 | # select the best hypotheses 309 | sent_lengths = input_ids.new(batch_size * self.num_beam_hyps_to_keep) 310 | best = [] 311 | best_scores = torch.zeros(batch_size * self.num_beam_hyps_to_keep, device=self.device, dtype=torch.float32) 312 | 313 | # retrieve best hypotheses 314 | for i, beam_hyp in enumerate(self._beam_hyps): 315 | sorted_hyps = sorted(beam_hyp.beams, key=lambda x: x[0]) 316 | for j in range(self.num_beam_hyps_to_keep): 317 | best_hyp_tuple = sorted_hyps.pop() 318 | best_score = best_hyp_tuple[0] 319 | best_hyp = best_hyp_tuple[1] 320 | sent_lengths[self.num_beam_hyps_to_keep * i + j] = len(best_hyp) 321 | 322 | # append to lists 323 | best.append(best_hyp) 324 | best_scores[i * self.num_beam_hyps_to_keep + j] = best_score 325 | 326 | # prepare for adding eos 327 | sent_max_len = min(sent_lengths.max().item() + 1, max_length) 328 | decoded: torch.LongTensor = input_ids.new(batch_size * self.num_beam_hyps_to_keep, sent_max_len) 329 | # shorter batches are padded if needed 330 | if sent_lengths.min().item() != sent_lengths.max().item(): 331 | assert pad_token_id is not None, "`pad_token_id` has to be defined" 332 | decoded.fill_(pad_token_id) 333 | 334 | # fill with hypotheses and eos_token_id if the latter fits in 335 | for i, hypo in enumerate(best): 336 | decoded[i, : sent_lengths[i]] = hypo 337 | if sent_lengths[i] < max_length: 338 | decoded[i, sent_lengths[i]] = eos_token_id 339 | return UserDict( 340 | { 341 | "sequences": decoded, 342 | "sequence_scores": best_scores, 343 | } 344 | ) 345 | 346 | 347 | class BeamHypotheses: 348 | def __init__(self, num_beams: int, length_penalty: float, early_stopping: bool): 349 | """ 350 | Initialize n-best list of hypotheses. 351 | """ 352 | self.length_penalty = length_penalty 353 | self.early_stopping = early_stopping 354 | self.num_beams = num_beams 355 | self.beams = [] 356 | self.worst_score = 1e9 357 | 358 | def __len__(self): 359 | """ 360 | Number of hypotheses in the list. 361 | """ 362 | return len(self.beams) 363 | 364 | def add(self, hyp: torch.LongTensor, sum_logprobs: float): 365 | """ 366 | Add a new hypothesis to the list. 
367 | """ 368 | score = sum_logprobs / (hyp.shape[-1] ** self.length_penalty) 369 | if len(self) < self.num_beams or score > self.worst_score: 370 | self.beams.append((score, hyp)) 371 | if len(self) > self.num_beams: 372 | sorted_next_scores = sorted([(s, idx) for idx, (s, _) in enumerate(self.beams)]) 373 | del self.beams[sorted_next_scores[0][1]] 374 | self.worst_score = sorted_next_scores[1][0] 375 | else: 376 | self.worst_score = min(score, self.worst_score) 377 | 378 | def is_done(self, best_sum_logprobs: float, cur_len: int) -> bool: 379 | """ 380 | If there are enough hypotheses and that none of the hypotheses being generated can become better than the worst 381 | one in the heap, then we are done with this sentence. 382 | """ 383 | 384 | if len(self) < self.num_beams: 385 | return False 386 | elif self.early_stopping: 387 | return True 388 | else: 389 | cur_score = best_sum_logprobs / cur_len ** self.length_penalty 390 | ret = self.worst_score >= cur_score 391 | return ret 392 | -------------------------------------------------------------------------------- /hubert/generation_logits_process.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2020 The HuggingFace Inc. team 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | import inspect 17 | import math 18 | from abc import ABC 19 | from typing import Callable, Iterable, List 20 | 21 | import numpy as np 22 | import torch 23 | 24 | from .file_utils import add_start_docstrings 25 | from .utils.logging import get_logger 26 | 27 | 28 | logger = get_logger(__name__) 29 | 30 | 31 | LOGITS_PROCESSOR_INPUTS_DOCSTRING = r""" 32 | Args: 33 | input_ids (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`): 34 | Indices of input sequence tokens in the vocabulary. 35 | 36 | Indices can be obtained using :class:`~transformers.BertTokenizer`. See 37 | :meth:`transformers.PreTrainedTokenizer.encode` and :meth:`transformers.PreTrainedTokenizer.__call__` for 38 | details. 39 | 40 | `What are input IDs? <../glossary.html#input-ids>`__ 41 | scores (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, config.vocab_size)`): 42 | Prediction scores of a language modeling head. These can be logits for each vocabulary when not using beam 43 | search or log softmax for each vocabulary token when using beam search 44 | kwargs: 45 | Additional logits processor specific kwargs. 46 | 47 | Return: 48 | :obj:`torch.FloatTensor` of shape :obj:`(batch_size, config.vocab_size)`: The processed prediction scores. 
49 | 50 | """ 51 | 52 | 53 | class LogitsProcessor(ABC): 54 | """Abstract base class for all logit processors that can be applied during generation.""" 55 | 56 | @add_start_docstrings(LOGITS_PROCESSOR_INPUTS_DOCSTRING) 57 | def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor: 58 | """Torch method for processing logits.""" 59 | raise NotImplementedError( 60 | f"{self.__class__} is an abstract class. Only classes inheriting this class can be called." 61 | ) 62 | 63 | 64 | class LogitsWarper(ABC): 65 | """Abstract base class for all logit warpers that can be applied during generation with multinomial sampling.""" 66 | 67 | @add_start_docstrings(LOGITS_PROCESSOR_INPUTS_DOCSTRING) 68 | def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor: 69 | """Torch method for warping logits.""" 70 | raise NotImplementedError( 71 | f"{self.__class__} is an abstract class. Only classes inheriting this class can be called." 72 | ) 73 | 74 | 75 | class LogitsProcessorList(list): 76 | """ 77 | This class can be used to create a list of :class:`~transformers.LogitsProcessor` or 78 | :class:`~transformers.LogitsWarper` to subsequently process a :obj:`scores` input tensor. This class inherits from 79 | list and adds a specific `__call__` method to apply each :class:`~transformers.LogitsProcessor` or 80 | :class:`~transformers.LogitsWarper` to the inputs. 81 | """ 82 | 83 | @add_start_docstrings(LOGITS_PROCESSOR_INPUTS_DOCSTRING) 84 | def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> torch.FloatTensor: 85 | for processor in self: 86 | function_args = inspect.signature(processor.__call__).parameters 87 | if len(function_args) > 2: 88 | assert all( 89 | arg in kwargs for arg in list(function_args.keys())[2:] 90 | ), f"Make sure that all the required parameters: {list(function_args.keys())} for {processor.__class__} are passed to the logits processor." 91 | scores = processor(input_ids, scores, **kwargs) 92 | else: 93 | scores = processor(input_ids, scores) 94 | return scores 95 | 96 | 97 | class MinLengthLogitsProcessor(LogitsProcessor): 98 | r""" 99 | :class:`transformers.LogitsProcessor` enforcing a min-length by setting EOS probability to 0. 100 | 101 | Args: 102 | min_length (:obj:`int`): 103 | The minimum length below which the score of :obj:`eos_token_id` is set to :obj:`-float("Inf")`. 104 | eos_token_id (:obj:`int`): 105 | The id of the `end-of-sequence` token. 106 | """ 107 | 108 | def __init__(self, min_length: int, eos_token_id: int): 109 | if not isinstance(min_length, int) or min_length < 0: 110 | raise ValueError(f"`min_length` has to be a positive integer, but is {min_length}") 111 | 112 | if not isinstance(eos_token_id, int) or eos_token_id < 0: 113 | raise ValueError(f"`eos_token_id` has to be a positive integer, but is {eos_token_id}") 114 | 115 | self.min_length = min_length 116 | self.eos_token_id = eos_token_id 117 | 118 | def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor: 119 | cur_len = input_ids.shape[-1] 120 | if cur_len < self.min_length: 121 | scores[:, self.eos_token_id] = -float("inf") 122 | return scores 123 | 124 | 125 | class TemperatureLogitsWarper(LogitsWarper): 126 | r""" 127 | :class:`transformers.LogitsWarper` for temperature (exponential scaling output probability distribution). 128 | 129 | Args: 130 | temperature (:obj:`float`): 131 | The value used to module the logits distribution. 
132 | """ 133 | 134 | def __init__(self, temperature: float): 135 | if not isinstance(temperature, float) or not (temperature > 0): 136 | raise ValueError(f"`temperature` has to be a strictly positive float, but is {temperature}") 137 | 138 | self.temperature = temperature 139 | 140 | def __call__(self, input_ids: torch.Tensor, scores: torch.Tensor) -> torch.Tensor: 141 | scores = scores / self.temperature 142 | return scores 143 | 144 | 145 | class RepetitionPenaltyLogitsProcessor(LogitsProcessor): 146 | r""" 147 | :class:`transformers.LogitsProcessor` enforcing an exponential penalty on repeated sequences. 148 | 149 | Args: 150 | repetition_penalty (:obj:`float`): 151 | The parameter for repetition penalty. 1.0 means no penalty. See `this paper 152 | `__ for more details. 153 | """ 154 | 155 | def __init__(self, penalty: float): 156 | if not isinstance(penalty, float) or not (penalty > 0): 157 | raise ValueError(f"`penalty` has to be a strictly positive float, but is {penalty}") 158 | 159 | self.penalty = penalty 160 | 161 | def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor: 162 | score = torch.gather(scores, 1, input_ids) 163 | 164 | # if score < 0 then repetition penalty has to be multiplied to reduce the previous token probability 165 | score = torch.where(score < 0, score * self.penalty, score / self.penalty) 166 | 167 | scores.scatter_(1, input_ids, score) 168 | return scores 169 | 170 | 171 | class TopPLogitsWarper(LogitsWarper): 172 | """ 173 | :class:`transformers.LogitsWarper` that performs top-p, i.e. restricting to top tokens summing to prob_cut_off <= 174 | prob_cut_off. 175 | 176 | Args: 177 | top_p (:obj:`float`): 178 | If set to < 1, only the most probable tokens with probabilities that add up to :obj:`top_p` or higher are 179 | kept for generation. 180 | filter_value (:obj:`float`, `optional`, defaults to :obj:`-float("Inf")`): 181 | All filtered values will be set to this float value. 182 | min_tokens_to_keep (:obj:`int`, `optional`, defaults to 1): 183 | Minimum number of tokens that cannot be filtered. 
184 | """ 185 | 186 | def __init__(self, top_p: float, filter_value: float = -float("Inf"), min_tokens_to_keep: int = 1): 187 | if not isinstance(top_p, float) or (top_p < 0 or top_p > 1.0): 188 | raise ValueError(f"`top_p` has to be a float > 0 and < 1, but is {top_p}") 189 | 190 | self.top_p = top_p 191 | self.filter_value = filter_value 192 | self.min_tokens_to_keep = min_tokens_to_keep 193 | 194 | def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor: 195 | sorted_logits, sorted_indices = torch.sort(scores, descending=True) 196 | cumulative_probs = sorted_logits.softmax(dim=-1).cumsum(dim=-1) 197 | 198 | # Remove tokens with cumulative top_p above the threshold (token with 0 are kept) 199 | sorted_indices_to_remove = cumulative_probs > self.top_p 200 | if self.min_tokens_to_keep > 1: 201 | # Keep at least min_tokens_to_keep (set to min_tokens_to_keep-1 because we add the first one below) 202 | sorted_indices_to_remove[..., : self.min_tokens_to_keep - 1] = 0 203 | # Shift the indices to the right to keep also the first token above the threshold 204 | sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone() 205 | sorted_indices_to_remove[..., 0] = 0 206 | 207 | # scatter sorted tensors to original indexing 208 | indices_to_remove = sorted_indices_to_remove.scatter(1, sorted_indices, sorted_indices_to_remove) 209 | scores = scores.masked_fill(indices_to_remove, self.filter_value) 210 | return scores 211 | 212 | 213 | class TopKLogitsWarper(LogitsWarper): 214 | r""" 215 | :class:`transformers.LogitsWarper` that performs top-k, i.e. restricting to the k highest probability elements. 216 | 217 | Args: 218 | top_k (:obj:`int`): 219 | The number of highest probability vocabulary tokens to keep for top-k-filtering. 220 | filter_value (:obj:`float`, `optional`, defaults to :obj:`-float("Inf")`): 221 | All filtered values will be set to this float value. 222 | min_tokens_to_keep (:obj:`int`, `optional`, defaults to 1): 223 | Minimum number of tokens that cannot be filtered. 
224 | """ 225 | 226 | def __init__(self, top_k: int, filter_value: float = -float("Inf"), min_tokens_to_keep: int = 1): 227 | if not isinstance(top_k, int) or top_k <= 0: 228 | raise ValueError(f"`top_k` has to be a strictly positive integer, but is {top_k}") 229 | 230 | self.top_k = top_k 231 | self.filter_value = filter_value 232 | self.min_tokens_to_keep = min_tokens_to_keep 233 | 234 | def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor: 235 | top_k = min(max(self.top_k, self.min_tokens_to_keep), scores.size(-1)) # Safety check 236 | # Remove all tokens with a probability less than the last token of the top-k 237 | indices_to_remove = scores < torch.topk(scores, top_k)[0][..., -1, None] 238 | scores = scores.masked_fill(indices_to_remove, self.filter_value) 239 | return scores 240 | 241 | 242 | def _get_ngrams(ngram_size: int, prev_input_ids: torch.Tensor, num_hypos: int): 243 | generated_ngrams = [{} for _ in range(num_hypos)] 244 | for idx in range(num_hypos): 245 | gen_tokens = prev_input_ids[idx].tolist() 246 | generated_ngram = generated_ngrams[idx] 247 | for ngram in zip(*[gen_tokens[i:] for i in range(ngram_size)]): 248 | prev_ngram_tuple = tuple(ngram[:-1]) 249 | generated_ngram[prev_ngram_tuple] = generated_ngram.get(prev_ngram_tuple, []) + [ngram[-1]] 250 | return generated_ngrams 251 | 252 | 253 | def _get_generated_ngrams(banned_ngrams, prev_input_ids, ngram_size, cur_len): 254 | # Before decoding the next token, prevent decoding of ngrams that have already appeared 255 | start_idx = cur_len + 1 - ngram_size 256 | ngram_idx = tuple(prev_input_ids[start_idx:cur_len].tolist()) 257 | return banned_ngrams.get(ngram_idx, []) 258 | 259 | 260 | def _calc_banned_ngram_tokens( 261 | ngram_size: int, prev_input_ids: torch.Tensor, num_hypos: int, cur_len: int 262 | ) -> List[Iterable[int]]: 263 | """Copied from fairseq for no_repeat_ngram in beam_search""" 264 | if cur_len + 1 < ngram_size: 265 | # return no banned tokens if we haven't generated no_repeat_ngram_size tokens yet 266 | return [[] for _ in range(num_hypos)] 267 | 268 | generated_ngrams = _get_ngrams(ngram_size, prev_input_ids, num_hypos) 269 | 270 | banned_tokens = [ 271 | _get_generated_ngrams(generated_ngrams[hypo_idx], prev_input_ids[hypo_idx], ngram_size, cur_len) 272 | for hypo_idx in range(num_hypos) 273 | ] 274 | return banned_tokens 275 | 276 | 277 | class NoRepeatNGramLogitsProcessor(LogitsProcessor): 278 | r""" 279 | :class:`transformers.LogitsProcessor` that enforces no repetition of n-grams. See `Fairseq 280 | `__. 281 | 282 | Args: 283 | ngram_size (:obj:`int`): 284 | All ngrams of size :obj:`ngram_size` can only occur once. 
285 | """ 286 | 287 | def __init__(self, ngram_size: int): 288 | if not isinstance(ngram_size, int) or ngram_size <= 0: 289 | raise ValueError(f"`ngram_size` has to be a strictly positive integer, but is {ngram_size}") 290 | self.ngram_size = ngram_size 291 | 292 | def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor: 293 | num_batch_hypotheses = scores.shape[0] 294 | cur_len = input_ids.shape[-1] 295 | banned_batch_tokens = _calc_banned_ngram_tokens(self.ngram_size, input_ids, num_batch_hypotheses, cur_len) 296 | 297 | for i, banned_tokens in enumerate(banned_batch_tokens): 298 | scores[i, banned_tokens] = -float("inf") 299 | 300 | return scores 301 | 302 | 303 | class EncoderNoRepeatNGramLogitsProcessor(LogitsProcessor): 304 | r""" 305 | :class:`transformers.LogitsProcessor` that enforces no repetition of encoder input ids n-grams for the decoder ids. 306 | See `ParlAI `__. 307 | 308 | Args: 309 | encoder_ngram_size (:obj:`int`): 310 | All ngrams of size :obj:`ngram_size` can only occur within the encoder input ids. 311 | encoder_input_ids (:obj:`int`): 312 | The encoder_input_ids that should not be repeated within the decoder ids. 313 | """ 314 | 315 | def __init__(self, encoder_ngram_size: int, encoder_input_ids: torch.LongTensor): 316 | if not isinstance(encoder_ngram_size, int) or encoder_ngram_size <= 0: 317 | raise ValueError( 318 | f"`encoder_ngram_size` has to be a strictly positive integer, but is {encoder_ngram_size}" 319 | ) 320 | self.ngram_size = encoder_ngram_size 321 | if len(encoder_input_ids.shape) == 1: 322 | encoder_input_ids = encoder_input_ids.unsqueeze(0) 323 | self.batch_size = encoder_input_ids.shape[0] 324 | self.generated_ngrams = _get_ngrams(encoder_ngram_size, encoder_input_ids, self.batch_size) 325 | 326 | def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor: 327 | # B x num_beams 328 | num_hypos = scores.shape[0] 329 | num_beams = num_hypos // self.batch_size 330 | cur_len = input_ids.shape[-1] 331 | banned_batch_tokens = [ 332 | _get_generated_ngrams( 333 | self.generated_ngrams[hypo_idx // num_beams], input_ids[hypo_idx], self.ngram_size, cur_len 334 | ) 335 | for hypo_idx in range(num_hypos) 336 | ] 337 | 338 | for i, banned_tokens in enumerate(banned_batch_tokens): 339 | scores[i, banned_tokens] = -float("inf") 340 | 341 | return scores 342 | 343 | 344 | class NoBadWordsLogitsProcessor(LogitsProcessor): 345 | """ 346 | :class:`transformers.LogitsProcessor` that enforces that specified sequences will never be sampled. 347 | 348 | Args: 349 | bad_words_ids (:obj:`List[List[int]]`): 350 | List of list of token ids that are not allowed to be generated. In order to get the tokens of the words 351 | that should not appear in the generated text, use :obj:`tokenizer(bad_word, 352 | add_prefix_space=True).input_ids`. 353 | eos_token_id (:obj:`int`): 354 | The id of the `end-of-sequence` token. 
355 | """ 356 | 357 | def __init__(self, bad_words_ids: Iterable[Iterable[int]], eos_token_id: int): 358 | 359 | if not isinstance(bad_words_ids, List) or len(bad_words_ids) == 0: 360 | raise ValueError(f"`bad_words_ids` has to be a non-emtpy list, but is {bad_words_ids}.") 361 | if any(not isinstance(bad_word_ids, list) for bad_word_ids in bad_words_ids): 362 | raise ValueError(f"`bad_words_ids` has to be a list of lists, but is {bad_words_ids}.") 363 | if any( 364 | any((not isinstance(token_id, (int, np.integer)) or token_id < 0) for token_id in bad_word_ids) 365 | for bad_word_ids in bad_words_ids 366 | ): 367 | raise ValueError( 368 | f"Each list in `bad_words_ids` has to be a list of positive integers, but is {bad_words_ids}." 369 | ) 370 | 371 | self.bad_words_ids = list(filter(lambda bad_token_seq: bad_token_seq != [eos_token_id], bad_words_ids)) 372 | 373 | for banned_token_seq in self.bad_words_ids: 374 | assert len(banned_token_seq) > 0, f"Banned words token sequences {bad_words_ids} cannot have an empty list" 375 | 376 | def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor: 377 | banned_tokens = self._calc_banned_bad_words_ids(input_ids) 378 | scores = self._set_scores_to_inf_for_banned_tokens(scores, banned_tokens) 379 | 380 | return scores 381 | 382 | def _tokens_match(self, prev_tokens: torch.LongTensor, tokens: List[int]) -> bool: 383 | if len(tokens) == 0: 384 | # if bad word tokens is just one token always ban it 385 | return True 386 | elif len(tokens) > len(prev_tokens): 387 | # if bad word tokens are longer then prev input_ids they can't be equal 388 | return False 389 | elif prev_tokens[-len(tokens) :].tolist() == tokens: 390 | # if tokens match 391 | return True 392 | else: 393 | return False 394 | 395 | def _calc_banned_bad_words_ids(self, prev_input_ids: Iterable[int]) -> Iterable[int]: 396 | banned_tokens = [] 397 | for prev_input_ids_slice in prev_input_ids: 398 | banned_tokens_slice = [] 399 | for banned_token_seq in self.bad_words_ids: 400 | if self._tokens_match(prev_input_ids_slice, banned_token_seq[:-1]) is False: 401 | # if tokens do not match continue 402 | continue 403 | 404 | banned_tokens_slice.append(banned_token_seq[-1]) 405 | 406 | banned_tokens.append(banned_tokens_slice) 407 | 408 | return banned_tokens 409 | 410 | def _set_scores_to_inf_for_banned_tokens(self, scores: torch.Tensor, banned_tokens: List[List[int]]) -> None: 411 | """ 412 | Modifies the scores in place by setting the banned token positions to `-inf`. Banned token is expected to be a 413 | list of list of banned tokens to ban in the format [[batch index, vocabulary position],... 414 | 415 | Args: 416 | scores: logits distribution of shape (batch size, vocabulary size) 417 | banned_tokens: list of list of tokens to ban of length (batch_size) 418 | """ 419 | banned_mask_list = [] 420 | for idx, batch_banned_tokens in enumerate(banned_tokens): 421 | for token in batch_banned_tokens: 422 | # Eliminates invalid bad word IDs that are over the vocabulary size. 423 | if token <= scores.shape[1]: 424 | banned_mask_list.append([idx, token]) 425 | else: 426 | logger.error( 427 | f"An invalid bad word ID is defined: {token}. This ID is not contained in the" 428 | f"vocabulary, and is therefore ignored." 429 | ) 430 | if not banned_mask_list: 431 | return scores 432 | 433 | banned_mask = torch.LongTensor(banned_mask_list) 434 | indices = torch.ones(len(banned_mask)) 435 | # A sparse tensor is generated from a list of coordinates: [[0, 1], [0, 2], [2, 0]]. 
A conversion to dense tensor generates: 436 | # [ 0 1 1 ] 437 | # [ 0 0 0 ] 438 | # [ 1 0 0 ] 439 | 440 | banned_mask = ( 441 | torch.sparse.LongTensor(banned_mask.t(), indices, scores.size()).to(scores.device).to_dense().bool() 442 | ) 443 | scores = scores.masked_fill(banned_mask, -float("inf")) 444 | return scores 445 | 446 | 447 | class PrefixConstrainedLogitsProcessor(LogitsProcessor): 448 | r""" 449 | :class:`transformers.LogitsProcessor` that enforces constrained generation and is useful for prefix-conditioned 450 | constrained generation. See `Autoregressive Entity Retrieval `__ for more 451 | information. 452 | 453 | Args: 454 | prefix_allowed_tokens_fn: (:obj:`Callable[[int, torch.Tensor], List[int]]`): 455 | This function constraints the beam search to allowed tokens only at each step. This function takes 2 456 | arguments :obj:`inputs_ids` and the batch ID :obj:`batch_id`. It has to return a list with the allowed 457 | tokens for the next generation step conditioned on the previously generated tokens :obj:`inputs_ids` and 458 | the batch ID :obj:`batch_id`. 459 | """ 460 | 461 | def __init__(self, prefix_allowed_tokens_fn: Callable[[int, torch.Tensor], List[int]], num_beams: int): 462 | self._prefix_allowed_tokens_fn = prefix_allowed_tokens_fn 463 | self._num_beams = num_beams 464 | 465 | def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor: 466 | mask = torch.full_like(scores, -math.inf) 467 | for batch_id, beam_sent in enumerate(input_ids.view(-1, self._num_beams, input_ids.shape[-1])): 468 | for beam_id, sent in enumerate(beam_sent): 469 | mask[batch_id * self._num_beams + beam_id, self._prefix_allowed_tokens_fn(batch_id, sent)] = 0 470 | 471 | return scores + mask 472 | 473 | 474 | class HammingDiversityLogitsProcessor(LogitsProcessor): 475 | r""" 476 | :class:`transformers.LogitsProcessor` that enforces diverse beam search. Note that this logits processor is only 477 | effective for :meth:`transformers.PreTrainedModel.group_beam_search`. See `Diverse Beam Search: Decoding Diverse 478 | Solutions from Neural Sequence Models `__ for more details. 479 | 480 | Args: 481 | diversity_penalty (:obj:`float`): 482 | This value is subtracted from a beam's score if it generates a token same as any beam from other group at a 483 | particular time. Note that :obj:`diversity_penalty` is only effective if ``group beam search`` is enabled. 484 | num_beams (:obj:`int`): 485 | Number of beams used for group beam search. See `this paper `__ for 486 | more details. 487 | num_beam_groups (:obj:`int`): 488 | Number of groups to divide :obj:`num_beams` into in order to ensure diversity among different groups of 489 | beams. See `this paper `__ for more details. 
490 | """ 491 | 492 | def __init__(self, diversity_penalty: float, num_beams: int, num_beam_groups: int): 493 | if not isinstance(diversity_penalty, float) or (not diversity_penalty > 0.0): 494 | raise ValueError("`diversity_penalty` should be a float strictly larger than 0.") 495 | self._diversity_penalty = diversity_penalty 496 | if not isinstance(num_beams, int) or num_beams < 2: 497 | raise ValueError("`num_beams` should be an integer strictly larger than 1.") 498 | self._num_beams = num_beams 499 | if not isinstance(num_beam_groups, int) or num_beam_groups < 2: 500 | raise ValueError("`num_beam_groups` should be an integer strictly larger than 1.") 501 | if num_beam_groups > num_beams: 502 | raise ValueError("`beam_groups` has to be smaller or equal to `num_beams`.") 503 | self._num_sub_beams = num_beams // num_beam_groups 504 | 505 | def __call__( 506 | self, 507 | input_ids: torch.LongTensor, 508 | scores: torch.FloatTensor, 509 | current_tokens: torch.LongTensor, 510 | beam_group_idx: int, 511 | ) -> torch.FloatTensor: 512 | # hamming diversity: penalise using same token in current group which was used in previous groups at 513 | # the same time step 514 | batch_size = current_tokens.shape[0] // self._num_beams 515 | group_start_idx = beam_group_idx * self._num_sub_beams 516 | group_end_idx = min(group_start_idx + self._num_sub_beams, self._num_beams) 517 | group_size = group_end_idx - group_start_idx 518 | vocab_size = scores.shape[-1] 519 | 520 | if group_start_idx == 0: 521 | return scores 522 | 523 | for batch_idx in range(batch_size): 524 | # predicted tokens of last time step of previous groups 525 | previous_group_tokens = current_tokens[ 526 | batch_idx * self._num_beams : batch_idx * self._num_beams + group_start_idx 527 | ] 528 | token_frequency = torch.bincount(previous_group_tokens, minlength=vocab_size).to(scores.device) 529 | scores[batch_idx * group_size : (batch_idx + 1) * group_size] -= self._diversity_penalty * token_frequency 530 | 531 | return scores 532 | 533 | 534 | class ForcedBOSTokenLogitsProcessor(LogitsProcessor): 535 | r""" 536 | :class:`~transformers.LogitsProcessor` that enforces the specified token as the first generated token. 537 | 538 | Args: 539 | bos_token_id (:obj:`int`): 540 | The id of the token to force as the first generated token. 541 | """ 542 | 543 | def __init__(self, bos_token_id: int): 544 | self.bos_token_id = bos_token_id 545 | 546 | def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor: 547 | cur_len = input_ids.shape[-1] 548 | if cur_len == 1: 549 | num_tokens = scores.shape[1] 550 | scores[:, [i for i in range(num_tokens) if i != self.bos_token_id]] = -float("inf") 551 | scores[:, self.bos_token_id] = 0 552 | return scores 553 | 554 | 555 | class ForcedEOSTokenLogitsProcessor(LogitsProcessor): 556 | r""" 557 | :class:`~transformers.LogitsProcessor` that enforces the specified token as the last generated token when 558 | :obj:`max_length` is reached. 559 | 560 | Args: 561 | max_length (:obj:`int`): 562 | The maximum length of the sequence to be generated. 563 | eos_token_id (:obj:`int`): 564 | The id of the token to force as the last generated token when :obj:`max_length` is reached. 
565 | """ 566 | 567 | def __init__(self, max_length: int, eos_token_id: int): 568 | self.max_length = max_length 569 | self.eos_token_id = eos_token_id 570 | 571 | def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor: 572 | cur_len = input_ids.shape[-1] 573 | if cur_len == self.max_length - 1: 574 | num_tokens = scores.shape[1] 575 | scores[:, [i for i in range(num_tokens) if i != self.eos_token_id]] = -float("inf") 576 | scores[:, self.eos_token_id] = 0 577 | return scores 578 | 579 | 580 | class InfNanRemoveLogitsProcessor(LogitsProcessor): 581 | r""" 582 | :class:`~transformers.LogitsProcessor` that removes all :obj:`nan` and :obj:`inf` values to avoid the generation 583 | method to fail. Note that using the logits processor should only be used if necessary since it can slow down the 584 | generation method. :obj:`max_length` is reached. 585 | """ 586 | 587 | def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor: 588 | # set all nan values to 0.0 589 | scores[scores != scores] = 0.0 590 | 591 | # set all inf values to max possible value 592 | scores[scores == float("inf")] = torch.finfo(scores.dtype).max 593 | 594 | return scores 595 | -------------------------------------------------------------------------------- /hubert/generation_stopping_criteria.py: -------------------------------------------------------------------------------- 1 | import time 2 | import warnings 3 | from abc import ABC 4 | from copy import deepcopy 5 | from typing import Optional 6 | 7 | import torch 8 | 9 | from .file_utils import add_start_docstrings 10 | 11 | 12 | STOPPING_CRITERIA_INPUTS_DOCSTRING = r""" 13 | Args: 14 | input_ids (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`): 15 | Indices of input sequence tokens in the vocabulary. 16 | 17 | Indices can be obtained using :class:`~transformers.BertTokenizer`. See 18 | :meth:`transformers.PreTrainedTokenizer.encode` and :meth:`transformers.PreTrainedTokenizer.__call__` for 19 | details. 20 | 21 | `What are input IDs? <../glossary.html#input-ids>`__ 22 | scores (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, config.vocab_size)`): 23 | Prediction scores of a language modeling head. These can be scores for each vocabulary token before SoftMax 24 | or scores for each vocabulary token after SoftMax. 25 | kwargs: 26 | Additional stopping criteria specific kwargs. 27 | 28 | Return: 29 | :obj:`bool`. :obj:`False` indicates we should continue, :obj:`True` indicates we should stop. 30 | 31 | """ 32 | 33 | 34 | class StoppingCriteria(ABC): 35 | """Abstract base class for all stopping criteria that can be applied during generation.""" 36 | 37 | @add_start_docstrings(STOPPING_CRITERIA_INPUTS_DOCSTRING) 38 | def __call__(self, input_ids: torch.LongTensor, score: torch.FloatTensor, **kwargs) -> bool: 39 | raise NotImplementedError("StoppingCriteria needs to be subclassed") 40 | 41 | 42 | class MaxLengthCriteria(StoppingCriteria): 43 | """ 44 | This class can be used to stop generation whenever the full generated number of tokens exceeds :obj:`max_length`. 45 | Keep in mind for decoder-only type of transformers, this will include the initial prompted tokens. 46 | 47 | Args: 48 | max_length (:obj:`int`): 49 | The maximum length that the output sequence can have in number of tokens. 
50 | """ 51 | 52 | def __init__(self, max_length: int): 53 | self.max_length = max_length 54 | 55 | @add_start_docstrings(STOPPING_CRITERIA_INPUTS_DOCSTRING) 56 | def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool: 57 | return input_ids.shape[-1] >= self.max_length 58 | 59 | 60 | class MaxNewTokensCriteria(StoppingCriteria): 61 | """ 62 | This class can be used to stop generation whenever the generated number of tokens exceeds :obj:`max_new_tokens`. 63 | Keep in mind for decoder-only type of transformers, this will **not** include the initial prompted tokens. This is 64 | very close to :obj:`MaxLengthCriteria` but ignores the number of initial tokens. 65 | 66 | Args: 67 | start_length (:obj:`int`): 68 | The number of initial tokens. 69 | max_new_tokens (:obj:`int`): 70 | The maximum number of tokens to generate. 71 | """ 72 | 73 | def __init__(self, start_length: int, max_new_tokens: int): 74 | self.start_length = start_length 75 | self.max_new_tokens = max_new_tokens 76 | self.max_length = start_length + max_new_tokens 77 | 78 | @add_start_docstrings(STOPPING_CRITERIA_INPUTS_DOCSTRING) 79 | def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool: 80 | return input_ids.shape[-1] >= self.max_length 81 | 82 | 83 | class MaxTimeCriteria(StoppingCriteria): 84 | """ 85 | This class can be used to stop generation whenever the full generation exceeds some amount of time. By default, the 86 | time will start being counted when you initialize this function. You can override this by passing an 87 | :obj:`initial_time`. 88 | 89 | Args: 90 | max_time (:obj:`float`): 91 | The maximum allowed time in seconds for the generation. 92 | initial_time (:obj:`float`, `optional`, defaults to :obj:`time.time()`): 93 | The start of the generation allowed time. 
94 | """ 95 | 96 | def __init__(self, max_time: float, initial_timestamp: Optional[float] = None): 97 | self.max_time = max_time 98 | self.initial_timestamp = time.time() if initial_timestamp is None else initial_timestamp 99 | 100 | @add_start_docstrings(STOPPING_CRITERIA_INPUTS_DOCSTRING) 101 | def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool: 102 | return time.time() - self.initial_timestamp > self.max_time 103 | 104 | 105 | class StoppingCriteriaList(list): 106 | @add_start_docstrings(STOPPING_CRITERIA_INPUTS_DOCSTRING) 107 | def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool: 108 | return any(criteria(input_ids, scores) for criteria in self) 109 | 110 | @property 111 | def max_length(self) -> Optional[int]: 112 | for stopping_criterium in self: 113 | if isinstance(stopping_criterium, MaxLengthCriteria): 114 | return stopping_criterium.max_length 115 | elif isinstance(stopping_criterium, MaxNewTokensCriteria): 116 | return stopping_criterium.max_length 117 | return None 118 | 119 | 120 | def validate_stopping_criteria(stopping_criteria: StoppingCriteriaList, max_length: int) -> StoppingCriteriaList: 121 | stopping_max_length = stopping_criteria.max_length 122 | new_stopping_criteria = deepcopy(stopping_criteria) 123 | if stopping_max_length is not None and stopping_max_length != max_length: 124 | warnings.warn("You set different `max_length` for stopping criteria and `max_length` parameter", UserWarning) 125 | elif stopping_max_length is None: 126 | new_stopping_criteria.append(MaxLengthCriteria(max_length=max_length)) 127 | return new_stopping_criteria 128 | -------------------------------------------------------------------------------- /hubert/utils/logging.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2020 Optuna, Hugging Face 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | """ Logging utilities. """ 16 | 17 | import logging 18 | import os 19 | import sys 20 | import threading 21 | from logging import CRITICAL # NOQA 22 | from logging import DEBUG # NOQA 23 | from logging import ERROR # NOQA 24 | from logging import FATAL # NOQA 25 | from logging import INFO # NOQA 26 | from logging import NOTSET # NOQA 27 | from logging import WARN # NOQA 28 | from logging import WARNING # NOQA 29 | from typing import Optional 30 | 31 | 32 | _lock = threading.Lock() 33 | _default_handler: Optional[logging.Handler] = None 34 | 35 | log_levels = { 36 | "debug": logging.DEBUG, 37 | "info": logging.INFO, 38 | "warning": logging.WARNING, 39 | "error": logging.ERROR, 40 | "critical": logging.CRITICAL, 41 | } 42 | 43 | _default_log_level = logging.WARNING 44 | 45 | 46 | def _get_default_logging_level(): 47 | """ 48 | If TRANSFORMERS_VERBOSITY env var is set to one of the valid choices return that as the new default level. 
If it is 49 | not - fall back to ``_default_log_level`` 50 | """ 51 | env_level_str = os.getenv("TRANSFORMERS_VERBOSITY", None) 52 | if env_level_str: 53 | if env_level_str in log_levels: 54 | return log_levels[env_level_str] 55 | else: 56 | logging.getLogger().warning( 57 | f"Unknown option TRANSFORMERS_VERBOSITY={env_level_str}, " 58 | f"has to be one of: { ', '.join(log_levels.keys()) }" 59 | ) 60 | return _default_log_level 61 | 62 | 63 | def _get_library_name() -> str: 64 | 65 | return __name__.split(".")[0] 66 | 67 | 68 | def _get_library_root_logger() -> logging.Logger: 69 | 70 | return logging.getLogger(_get_library_name()) 71 | 72 | 73 | def _configure_library_root_logger() -> None: 74 | 75 | global _default_handler 76 | 77 | with _lock: 78 | if _default_handler: 79 | # This library has already configured the library root logger. 80 | return 81 | _default_handler = logging.StreamHandler() # Set sys.stderr as stream. 82 | _default_handler.flush = sys.stderr.flush 83 | 84 | # Apply our default configuration to the library root logger. 85 | library_root_logger = _get_library_root_logger() 86 | library_root_logger.addHandler(_default_handler) 87 | library_root_logger.setLevel(_get_default_logging_level()) 88 | library_root_logger.propagate = False 89 | 90 | 91 | def _reset_library_root_logger() -> None: 92 | 93 | global _default_handler 94 | 95 | with _lock: 96 | if not _default_handler: 97 | return 98 | 99 | library_root_logger = _get_library_root_logger() 100 | library_root_logger.removeHandler(_default_handler) 101 | library_root_logger.setLevel(logging.NOTSET) 102 | _default_handler = None 103 | 104 | 105 | def get_logger(name: Optional[str] = None) -> logging.Logger: 106 | """ 107 | Return a logger with the specified name. 108 | 109 | This function is not supposed to be directly accessed unless you are writing a custom transformers module. 110 | """ 111 | 112 | if name is None: 113 | name = _get_library_name() 114 | 115 | _configure_library_root_logger() 116 | return logging.getLogger(name) 117 | 118 | 119 | def get_verbosity() -> int: 120 | """ 121 | Return the current level for the 🤗 Transformers's root logger as an int. 122 | 123 | Returns: 124 | :obj:`int`: The logging level. 125 | 126 | .. note:: 127 | 128 | 🤗 Transformers has following logging levels: 129 | 130 | - 50: ``transformers.logging.CRITICAL`` or ``transformers.logging.FATAL`` 131 | - 40: ``transformers.logging.ERROR`` 132 | - 30: ``transformers.logging.WARNING`` or ``transformers.logging.WARN`` 133 | - 20: ``transformers.logging.INFO`` 134 | - 10: ``transformers.logging.DEBUG`` 135 | """ 136 | 137 | _configure_library_root_logger() 138 | return _get_library_root_logger().getEffectiveLevel() 139 | 140 | 141 | def set_verbosity(verbosity: int) -> None: 142 | """ 143 | Set the verbosity level for the 🤗 Transformers's root logger. 
144 | 145 | Args: 146 | verbosity (:obj:`int`): 147 | Logging level, e.g., one of: 148 | 149 | - ``transformers.logging.CRITICAL`` or ``transformers.logging.FATAL`` 150 | - ``transformers.logging.ERROR`` 151 | - ``transformers.logging.WARNING`` or ``transformers.logging.WARN`` 152 | - ``transformers.logging.INFO`` 153 | - ``transformers.logging.DEBUG`` 154 | """ 155 | 156 | _configure_library_root_logger() 157 | _get_library_root_logger().setLevel(verbosity) 158 | 159 | 160 | def set_verbosity_info(): 161 | """Set the verbosity to the :obj:`INFO` level.""" 162 | return set_verbosity(INFO) 163 | 164 | 165 | def set_verbosity_warning(): 166 | """Set the verbosity to the :obj:`WARNING` level.""" 167 | return set_verbosity(WARNING) 168 | 169 | 170 | def set_verbosity_debug(): 171 | """Set the verbosity to the :obj:`DEBUG` level.""" 172 | return set_verbosity(DEBUG) 173 | 174 | 175 | def set_verbosity_error(): 176 | """Set the verbosity to the :obj:`ERROR` level.""" 177 | return set_verbosity(ERROR) 178 | 179 | 180 | def disable_default_handler() -> None: 181 | """Disable the default handler of the HuggingFace Transformers's root logger.""" 182 | 183 | _configure_library_root_logger() 184 | 185 | assert _default_handler is not None 186 | _get_library_root_logger().removeHandler(_default_handler) 187 | 188 | 189 | def enable_default_handler() -> None: 190 | """Enable the default handler of the HuggingFace Transformers's root logger.""" 191 | 192 | _configure_library_root_logger() 193 | 194 | assert _default_handler is not None 195 | _get_library_root_logger().addHandler(_default_handler) 196 | 197 | 198 | def add_handler(handler: logging.Handler) -> None: 199 | """adds a handler to the HuggingFace Transformers's root logger.""" 200 | 201 | _configure_library_root_logger() 202 | 203 | assert handler is not None 204 | _get_library_root_logger().addHandler(handler) 205 | 206 | 207 | def remove_handler(handler: logging.Handler) -> None: 208 | """removes given handler from the HuggingFace Transformers's root logger.""" 209 | 210 | _configure_library_root_logger() 211 | 212 | assert handler is not None and handler not in _get_library_root_logger().handlers 213 | _get_library_root_logger().removeHandler(handler) 214 | 215 | 216 | def disable_propagation() -> None: 217 | """ 218 | Disable propagation of the library log outputs. Note that log propagation is disabled by default. 219 | """ 220 | 221 | _configure_library_root_logger() 222 | _get_library_root_logger().propagate = False 223 | 224 | 225 | def enable_propagation() -> None: 226 | """ 227 | Enable propagation of the library log outputs. Please disable the HuggingFace Transformers's default handler to 228 | prevent double logging if the root logger has been configured. 229 | """ 230 | 231 | _configure_library_root_logger() 232 | _get_library_root_logger().propagate = True 233 | 234 | 235 | def enable_explicit_format() -> None: 236 | """ 237 | Enable explicit formatting for every HuggingFace Transformers's logger. The explicit formatter is as follows: 238 | 239 | :: 240 | 241 | [LEVELNAME|FILENAME|LINE NUMBER] TIME >> MESSAGE 242 | 243 | All handlers currently bound to the root logger are affected by this method. 
244 | """ 245 | handlers = _get_library_root_logger().handlers 246 | 247 | for handler in handlers: 248 | formatter = logging.Formatter("[%(levelname)s|%(filename)s:%(lineno)s] %(asctime)s >> %(message)s") 249 | handler.setFormatter(formatter) 250 | 251 | 252 | def reset_format() -> None: 253 | """ 254 | Resets the formatting for HuggingFace Transformers's loggers. 255 | 256 | All handlers currently bound to the root logger are affected by this method. 257 | """ 258 | handlers = _get_library_root_logger().handlers 259 | 260 | for handler in handlers: 261 | handler.setFormatter(None) 262 | -------------------------------------------------------------------------------- /hubert/utils/versions.py: -------------------------------------------------------------------------------- 1 | # Copyright 2020 The HuggingFace Team. All rights reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | """ 15 | Utilities for working with package versions 16 | """ 17 | 18 | import operator 19 | import re 20 | import sys 21 | from typing import Optional 22 | 23 | from packaging import version 24 | 25 | 26 | # The package importlib_metadata is in a different place, depending on the python version. 27 | if sys.version_info < (3, 8): 28 | import importlib_metadata 29 | else: 30 | import importlib.metadata as importlib_metadata 31 | 32 | 33 | ops = { 34 | "<": operator.lt, 35 | "<=": operator.le, 36 | "==": operator.eq, 37 | "!=": operator.ne, 38 | ">=": operator.ge, 39 | ">": operator.gt, 40 | } 41 | 42 | 43 | def _compare_versions(op, got_ver, want_ver, requirement, pkg, hint): 44 | if got_ver is None: 45 | raise ValueError("got_ver is None") 46 | if want_ver is None: 47 | raise ValueError("want_ver is None") 48 | if not ops[op](version.parse(got_ver), version.parse(want_ver)): 49 | raise ImportError( 50 | f"{requirement} is required for a normal functioning of this module, but found {pkg}=={got_ver}.{hint}" 51 | ) 52 | 53 | 54 | def require_version(requirement: str, hint: Optional[str] = None) -> None: 55 | """ 56 | Perform a runtime check of the dependency versions, using the exact same syntax used by pip. 57 | 58 | The installed module version comes from the `site-packages` dir via `importlib_metadata`. 
59 | 60 | Args: 61 | requirement (:obj:`str`): pip style definition, e.g., "tokenizers==0.9.4", "tqdm>=4.27", "numpy" 62 | hint (:obj:`str`, `optional`): what suggestion to print in case of requirements not being met 63 | 64 | Example:: 65 | 66 | require_version("pandas>1.1.2") 67 | require_version("numpy>1.18.5", "this is important to have for whatever reason") 68 | 69 | """ 70 | 71 | hint = f"\n{hint}" if hint is not None else "" 72 | 73 | # non-versioned check 74 | if re.match(r"^[\w_\-\d]+$", requirement): 75 | pkg, op, want_ver = requirement, None, None 76 | else: 77 | match = re.findall(r"^([^!=<>\s]+)([\s!=<>]{1,2}.+)", requirement) 78 | if not match: 79 | raise ValueError( 80 | f"requirement needs to be in the pip package format, .e.g., package_a==1.23, or package_b>=1.23, but got {requirement}" 81 | ) 82 | pkg, want_full = match[0] 83 | want_range = want_full.split(",") # there could be multiple requirements 84 | wanted = {} 85 | for w in want_range: 86 | match = re.findall(r"^([\s!=<>]{1,2})(.+)", w) 87 | if not match: 88 | raise ValueError( 89 | f"requirement needs to be in the pip package format, .e.g., package_a==1.23, or package_b>=1.23, but got {requirement}" 90 | ) 91 | op, want_ver = match[0] 92 | wanted[op] = want_ver 93 | if op not in ops: 94 | raise ValueError(f"{requirement}: need one of {list(ops.keys())}, but got {op}") 95 | 96 | # special case 97 | if pkg == "python": 98 | got_ver = ".".join([str(x) for x in sys.version_info[:3]]) 99 | for op, want_ver in wanted.items(): 100 | _compare_versions(op, got_ver, want_ver, requirement, pkg, hint) 101 | return 102 | 103 | # check if any version is installed 104 | try: 105 | got_ver = importlib_metadata.version(pkg) 106 | except importlib_metadata.PackageNotFoundError: 107 | raise importlib_metadata.PackageNotFoundError( 108 | f"The '{requirement}' distribution was not found and is required by this application. {hint}" 109 | ) 110 | 111 | # check that the right version is installed if version number or a range was provided 112 | if want_ver is not None: 113 | for op, want_ver in wanted.items(): 114 | _compare_versions(op, got_ver, want_ver, requirement, pkg, hint) 115 | 116 | 117 | def require_version_core(requirement): 118 | """require_version wrapper which emits a core-specific hint on failure""" 119 | hint = "Try: pip install transformers -U or pip install -e '.[dev]' if you're working with git master" 120 | return require_version(requirement, hint) 121 | -------------------------------------------------------------------------------- /index.html: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 118 | 119 | 120 | 121 | FaceXHuBERT 122 | 123 | 124 | 125 | 126 | 127 | 128 | 129 | 136 | 137 | 138 | 139 |
(index.html body, source lines 140-309: HTML markup not preserved in this dump; recoverable page text follows)

FaceXHuBERT: Text-less Speech-driven E(X)pressive 3D Facial Animation Synthesis Using Self-Supervised Speech Representation Learning (ICMI '23)

Kazi Injamamul Haque    Zerrin Yumak

[Paper]    [GitHub]

Abstract

This paper presents FaceXHuBERT, a text-less speech-driven 3D facial animation generation method that allows to capture personalized and subtle cues in speech (e.g. identity, emotion and hesitation). It is also very robust to background noise and can handle audio recorded in a variety of situations (e.g. multiple people speaking). Recent approaches employ end-to-end deep learning taking into account both audio and text as input to generate facial animation for the whole face. However, scarcity of publicly available expressive audio-3D facial animation datasets poses a major bottleneck. The resulting animations still have issues regarding accurate lip-synching, expressivity, person-specific information and generalizability. We effectively employ self-supervised pretrained HuBERT model in the training process that allows us to incorporate both lexical and non-lexical information in the audio without using a large lexicon. Additionally, guiding the training with a binary emotion condition and speaker identity distinguishes the tiniest subtle facial motion. We carried out extensive objective and subjective evaluation in comparison to ground truth and state-of-the-art work. A perceptual user study demonstrates that our approach produces superior results with respect to the realism of the animation 78% of the time in comparison to the state-of-the-art. In addition, our method is 4 times faster eliminating the use of complex sequential models such as transformers. We strongly recommend watching the supplementary video before reading the paper.

Video

Methodology

Code

[GitHub]

Paper and Supplementary Material

K.I. Haque and Z. Yumak
FaceXHuBERT: Text-less Speech-driven E(X)pressive 3D Facial Animation Synthesis using Self-Supervised Speech Representation Learning
(Pre-print on ArXiv)
[Bibtex]

Acknowledgements

This template was originally made by Phillip Isola and Richard Zhang for a colorful ECCV project; the code can be found here.
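For orientation, a minimal sketch of how training is typically launched with main.py (shown later in this dump), assuming the BIWI wav/ and vertices_npy/ folders are already populated; the flag values below are simply that script's argparse defaults:

# train FaceXHuBERT on BIWI with main.py's defaults:
# 100 epochs, Huber loss, checkpoints saved to save/, test predictions written to result/
python main.py --dataset BIWI --max_epoch 100 --lr 0.0001 --feature_dim 256

The subject splits are controlled with --train_subjects, --val_subjects and --test_subjects, which all default to the 14 BIWI subjects F1-F8 and M1-M6.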
310 | 311 | 312 | 313 | -------------------------------------------------------------------------------- /main.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import argparse 3 | from tqdm import tqdm 4 | import os, shutil 5 | import matplotlib.pyplot as plt 6 | import torch 7 | import torch.nn as nn 8 | from data_loader import get_dataloaders 9 | from faceXhubert import FaceXHuBERT 10 | 11 | 12 | def plot_losses(train_losses, val_losses): 13 | plt.plot(train_losses, label="Training loss") 14 | plt.plot(val_losses, label="Validation loss") 15 | plt.legend() 16 | plt.title("Losses") 17 | plt.savefig("losses.png") 18 | plt.close() 19 | 20 | 21 | def trainer(args, train_loader, dev_loader, model, optimizer, criterion, epoch=100): 22 | train_losses = [] 23 | val_losses = [] 24 | 25 | save_path = os.path.join(args.save_path) 26 | if os.path.exists(save_path): 27 | shutil.rmtree(save_path) 28 | os.makedirs(save_path) 29 | 30 | train_subjects_list = [i for i in args.train_subjects.split(" ")] 31 | iteration = 0 32 | for e in range(epoch+1): 33 | loss_log = [] 34 | model.train() 35 | pbar = tqdm(enumerate(train_loader),total=len(train_loader)) 36 | optimizer.zero_grad() 37 | 38 | for i, (audio, vertice, template, one_hot, file_name, emo_one_hot) in pbar: 39 | iteration += 1 40 | vertice = str(vertice[0]) 41 | vertice = np.load(vertice,allow_pickle=True) 42 | vertice = vertice.astype(np.float32) 43 | vertice = torch.from_numpy(vertice) 44 | vertice = torch.unsqueeze(vertice,0) 45 | audio, vertice, template, one_hot, emo_one_hot = audio.to(device="cuda"), vertice.to(device="cuda"), template.to(device="cuda"), one_hot.to(device="cuda"), emo_one_hot.to(device="cuda") 46 | loss = model(audio, template, vertice, one_hot, emo_one_hot, criterion) 47 | 48 | loss.backward() 49 | loss_log.append(loss.item()) 50 | if i % args.gradient_accumulation_steps==0: 51 | optimizer.step() 52 | optimizer.zero_grad() 53 | del audio, vertice, template, one_hot, emo_one_hot 54 | torch.cuda.empty_cache() 55 | 56 | pbar.set_description("(Epoch {}, iteration {}) TRAIN LOSS:{:.8f}".format((e+1), iteration ,np.mean(loss_log))) 57 | 58 | train_losses.append(np.mean(loss_log)) 59 | 60 | valid_loss_log = [] 61 | model.eval() 62 | for audio, vertice, template, one_hot_all,file_name, emo_one_hot in dev_loader: 63 | # to gpu 64 | vertice = str(vertice[0]) 65 | vertice = np.load(vertice,allow_pickle=True) 66 | vertice = vertice.astype(np.float32) 67 | vertice = torch.from_numpy(vertice) 68 | vertice = torch.unsqueeze(vertice,0) 69 | audio, vertice, template, one_hot_all, emo_one_hot= audio.to(device="cuda"), vertice.to(device="cuda"), template.to(device="cuda"), one_hot_all.to(device="cuda"), emo_one_hot.to(device="cuda") 70 | train_subject = "_".join(file_name[0].split("_")[:-1]) 71 | if train_subject in train_subjects_list: 72 | condition_subject = train_subject 73 | iter = train_subjects_list.index(condition_subject) 74 | one_hot = one_hot_all[:,iter,:] 75 | loss = model(audio, template, vertice, one_hot, emo_one_hot, criterion) 76 | valid_loss_log.append(loss.item()) 77 | else: 78 | for iter in range(one_hot_all.shape[-1]): 79 | condition_subject = train_subjects_list[iter] 80 | one_hot = one_hot_all[:,iter,:] 81 | loss = model(audio, template, vertice, one_hot, emo_one_hot, criterion) 82 | valid_loss_log.append(loss.item()) 83 | 84 | current_loss = np.mean(valid_loss_log) 85 | 86 | val_losses.append(current_loss) 87 | if (e > 0 and e % 25 == 0) or e == 
args.max_epoch: 88 | torch.save(model.state_dict(), os.path.join(save_path,'{}_model.pth'.format(e))) 89 | 90 | print("epoch: {}, current loss:{:.8f}".format(e+1,current_loss)) 91 | 92 | plot_losses(train_losses, val_losses) 93 | 94 | return model 95 | 96 | 97 | @torch.no_grad() 98 | def test(args, model, test_loader,epoch): 99 | result_path = os.path.join(args.result_path) 100 | if os.path.exists(result_path): 101 | shutil.rmtree(result_path) 102 | os.makedirs(result_path) 103 | 104 | save_path = os.path.join(args.save_path) 105 | train_subjects_list = [i for i in args.train_subjects.split(" ")] 106 | 107 | model.load_state_dict(torch.load(os.path.join(save_path, '{}_model.pth'.format(epoch)))) 108 | model = model.to(torch.device("cuda")) 109 | model.eval() 110 | 111 | for audio, vertice, template, one_hot_all, file_name, emo_one_hot in test_loader: 112 | vertice = str(vertice[0]) 113 | vertice = np.load(vertice,allow_pickle=True) 114 | vertice = vertice.astype(np.float32) 115 | vertice = torch.from_numpy(vertice) 116 | vertice = torch.unsqueeze(vertice,0) 117 | audio, vertice, template, one_hot_all, emo_one_hot= audio.to(device="cuda"), vertice.to(device="cuda"), template.to(device="cuda"), one_hot_all.to(device="cuda"), emo_one_hot.to(device="cuda") 118 | train_subject = "_".join(file_name[0].split("_")[:-1]) 119 | if train_subject in train_subjects_list: 120 | condition_subject = train_subject 121 | iter = train_subjects_list.index(condition_subject) 122 | one_hot = one_hot_all[:,iter,:] 123 | prediction = model.predict(audio, template, one_hot, emo_one_hot) 124 | prediction = prediction.squeeze() 125 | np.save(os.path.join(result_path, file_name[0].split(".")[0]+"_condition_"+condition_subject+".npy"), prediction.detach().cpu().numpy()) 126 | else: 127 | for iter in range(one_hot_all.shape[-1]): 128 | condition_subject = train_subjects_list[iter] 129 | one_hot = one_hot_all[:,iter,:] 130 | prediction = model.predict(audio, template, one_hot, emo_one_hot) 131 | prediction = prediction.squeeze() 132 | np.save(os.path.join(result_path, file_name[0].split(".")[0]+"_condition_"+condition_subject+".npy"), prediction.detach().cpu().numpy()) 133 | 134 | 135 | def count_parameters(model): 136 | return sum(p.numel() for p in model.parameters() if p.requires_grad) 137 | 138 | 139 | def main(): 140 | parser = argparse.ArgumentParser(description='FaceXHuBERT: Text-less Speech-driven E(X)pressive 3D Facial Animation Synthesis using Self-Supervised Speech Representation Learning') 141 | parser.add_argument("--lr", type=float, default=0.0001, help='learning rate') 142 | parser.add_argument("--dataset", type=str, default="BIWI", help='Name of the dataset folder.
eg: BIWI') 143 | parser.add_argument("--vertice_dim", type=int, default=70110, help='number of vertices - 23370*3 for BIWI dataset') 144 | parser.add_argument("--feature_dim", type=int, default=256, help='GRU Vertex decoder hidden size') 145 | parser.add_argument("--wav_path", type=str, default= "wav", help='path of the audio signals') 146 | parser.add_argument("--vertices_path", type=str, default="vertices_npy", help='path of the ground truth') 147 | parser.add_argument("--gradient_accumulation_steps", type=int, default=1, help='gradient accumulation') 148 | parser.add_argument("--max_epoch", type=int, default=100, help='number of epochs') 149 | parser.add_argument("--device", type=str, default="cuda") 150 | parser.add_argument("--template_file", type=str, default="templates_scaled.pkl", help='path of the train subject templates') 151 | parser.add_argument("--save_path", type=str, default="save", help='path of the trained models') 152 | parser.add_argument("--result_path", type=str, default="result", help='path to the predictions') 153 | parser.add_argument("--train_subjects", type=str, default="F1 F2 F3 F4 F5 F6 F7 F8 M1 M2 M3 M4 M5 M6") 154 | parser.add_argument("--val_subjects", type=str, default="F1 F2 F3 F4 F5 F6 F7 F8 M1 M2 M3 M4 M5 M6") 155 | parser.add_argument("--test_subjects", type=str, default="F1 F2 F3 F4 F5 F6 F7 F8 M1 M2 M3 M4 M5 M6") 156 | parser.add_argument("--input_fps", type=int, default=50, help='HuBERT last hidden state produces 50 fps audio representation') 157 | parser.add_argument("--output_fps", type=int, default=25, help='fps of the visual data, BIWI was captured in 25 fps') 158 | args = parser.parse_args() 159 | 160 | model = FaceXHuBERT(args) 161 | print("model parameters: ", count_parameters(model)) 162 | 163 | assert torch.cuda.is_available() 164 | 165 | model = model.to(torch.device("cuda")) 166 | dataset = get_dataloaders(args) 167 | criterion = nn.HuberLoss() 168 | optimizer = torch.optim.Adam(filter(lambda p: p.requires_grad,model.parameters()), lr=args.lr) 169 | 170 | model = trainer(args, dataset["train"], dataset["valid"],model, optimizer, criterion, epoch=args.max_epoch) 171 | 172 | test(args, model, dataset["test"], epoch=args.max_epoch) 173 | 174 | print(model) 175 | 176 | 177 | if __name__ == "__main__": 178 | main() 179 | -------------------------------------------------------------------------------- /page_assets/bibtex.txt: -------------------------------------------------------------------------------- 1 | @inproceedings{facexhubert, 2 | author = {Haque, Kazi Injamamul and Yumak, Zerrin}, 3 | title = {FaceXHuBERT: Text-less Speech-driven E(X)pressive 3D Facial Animation Synthesis Using Self-Supervised Speech Representation Learning}, 4 | booktitle = {INTERNATIONAL CONFERENCE ON MULTIMODAL INTERACTION (ICMI ’23)}, 5 | year = {2023}, 6 | location = {Paris, France}, 7 | numpages = {10}, 8 | url = {https://doi.org/10.1145/3577190.3614157}, 9 | doi = {10.1145/3577190.3614157}, 10 | publisher = {ACM}, 11 | address = {New York, NY, USA}, 12 | } -------------------------------------------------------------------------------- /page_assets/paper.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/galib360/FaceXHuBERT/f54f9a99282b6a3b0b99770cbc50cb7ea7f3746b/page_assets/paper.png -------------------------------------------------------------------------------- /predict.py: -------------------------------------------------------------------------------- 1 | import librosa 2 | import 
argparse 3 | import pickle 4 | import torch 5 | import trimesh 6 | import numpy as np 7 | import cv2 8 | import os 9 | import ffmpeg 10 | import gc 11 | import pyrender 12 | from faceXhubert import FaceXHuBERT 13 | from transformers import Wav2Vec2Processor 14 | import time 15 | 16 | 17 | def test_model(args): 18 | if not os.path.exists(args.result_path): 19 | os.makedirs(args.result_path) 20 | 21 | model = FaceXHuBERT(args) 22 | model.load_state_dict(torch.load('pretrained_model/{}.pth'.format(args.model_name))) 23 | model = model.to(torch.device(args.device)) 24 | model.eval() 25 | 26 | template_file = os.path.join(args.dataset, args.template_path) 27 | with open(template_file, 'rb') as fin: 28 | templates = pickle.load(fin,encoding='latin1') 29 | 30 | train_subjects_list = [i for i in args.train_subjects.split(" ")] 31 | 32 | one_hot_labels = np.eye(len(train_subjects_list)) 33 | emo_one_hot_labels = np.eye(2) 34 | if args.emotion == 1: 35 | emo_one_hot = torch.FloatTensor(emo_one_hot_labels[1]).to(device=args.device) 36 | emo_label = "emotional" 37 | else: 38 | emo_one_hot = torch.FloatTensor(emo_one_hot_labels[0]).to(device=args.device) 39 | emo_label = "neutral" 40 | 41 | iter = train_subjects_list.index(args.condition) 42 | one_hot = one_hot_labels[iter] 43 | one_hot = np.reshape(one_hot,(-1,one_hot.shape[0])) 44 | one_hot = torch.FloatTensor(one_hot).to(device=args.device) 45 | 46 | temp = templates[args.subject] 47 | 48 | template = temp.reshape((-1)) 49 | template = np.reshape(template,(-1,template.shape[0])) 50 | template = torch.FloatTensor(template).to(device=args.device) 51 | 52 | wav_path = args.wav_path 53 | test_name = os.path.basename(wav_path).split(".")[0] 54 | start_time = time.time() 55 | speech_array, sampling_rate = librosa.load(os.path.join(wav_path), sr=16000) 56 | processor = Wav2Vec2Processor.from_pretrained("facebook/hubert-xlarge-ls960-ft") 57 | audio_feature = processor(speech_array, return_tensors="pt", padding="longest", sampling_rate=sampling_rate).input_values 58 | audio_feature = np.reshape(audio_feature,(-1,audio_feature.shape[0])) 59 | audio_feature = torch.FloatTensor(audio_feature).to(device=args.device) 60 | 61 | prediction = model.predict(audio_feature, template, one_hot, emo_one_hot) 62 | prediction = prediction.squeeze() 63 | elapsed = time.time() - start_time 64 | print("Inference time for ", prediction.shape[0], " frames is: ", elapsed, " seconds.") 65 | print("Inference time for 1 frame is: ", elapsed/prediction.shape[0], " seconds.") 66 | print("Inference time for 1 second of audio is: ", ((elapsed * 25) / prediction.shape[0]), " seconds.") 67 | out_file_name = test_name + "_" + emo_label + "_" + args.subject + "_Condition_" + args.condition 68 | np.save(os.path.join(args.result_path, out_file_name), prediction.detach().cpu().numpy()) 69 | 70 | 71 | def render(args): 72 | fps = args.fps 73 | fourcc = cv2.VideoWriter_fourcc(*'MP4V') 74 | render_path = "demo/render/" 75 | frames_folder = render_path + "frames/" 76 | video_woA_folder = render_path + "video_wo_audio/" 77 | video_wA_folder = render_path + "video_with_audio/" 78 | emo_label = "emotional" 79 | if args.emotion == 0: 80 | emo_label = "neutral" 81 | 82 | wav_path = args.wav_path 83 | test_name = os.path.basename(wav_path).split(".")[0] 84 | out_file_name = test_name + "_" + emo_label + "_" + args.subject + "_Condition_" + args.condition 85 | predicted_vertices_path = os.path.join(args.result_path,out_file_name+".npy") 86 | if args.dataset == "BIWI": 87 | template_file = 
os.path.join(args.dataset, args.render_template_path + "/BIWI_topology.obj") 88 | 89 | cam = pyrender.PerspectiveCamera(yfov=np.pi / 3.0, aspectRatio=1.414) 90 | camera_pose = np.array([[1.0, 0, 0.0, 0.00], 91 | [0.0, -1.0, 0.0, 0.00], 92 | [0.0, 0.0, 1.0, -1.6], 93 | [0.0, 0.0, 0.0, 1.0]]) 94 | 95 | light = pyrender.DirectionalLight(color=[1.0, 1.0, 1.0], intensity=10.0) 96 | 97 | # r = pyrender.OffscreenRenderer(640, 480) 98 | r = pyrender.OffscreenRenderer(1920, 1440) 99 | 100 | print("rendering the predicted sequence: ", test_name) 101 | 102 | video_woA_path = video_woA_folder + out_file_name + '.mp4' 103 | video_wA_path = video_wA_folder + out_file_name + '.mp4' 104 | # video = cv2.VideoWriter(video_woA_path, fourcc, fps, (640, 480)) 105 | video = cv2.VideoWriter(video_woA_path, fourcc, fps, (1920, 1440)) 106 | 107 | ref_mesh = trimesh.load_mesh(template_file, process=False) 108 | seq = np.load(predicted_vertices_path) 109 | seq = np.reshape(seq, (-1, 70110 // 3, 3)) 110 | ref_mesh.vertices = seq[0, :, :] 111 | py_mesh = pyrender.Mesh.from_trimesh(ref_mesh) 112 | 113 | for f in range(seq.shape[0]): 114 | ref_mesh.vertices = seq[f, :, :] 115 | py_mesh = pyrender.Mesh.from_trimesh(ref_mesh) 116 | scene = pyrender.Scene() 117 | scene.add(py_mesh) 118 | 119 | scene.add(cam, pose=camera_pose) 120 | scene.add(light, pose=camera_pose) 121 | color, _ = r.render(scene) 122 | 123 | output_frame = frames_folder + "frame" + str(f) + ".jpg" 124 | cv2.imwrite(output_frame, color) 125 | frame = cv2.imread(output_frame) 126 | video.write(frame) 127 | video.release() 128 | 129 | input_video = ffmpeg.input(video_woA_path) 130 | input_audio = ffmpeg.input(wav_path) 131 | 132 | ffmpeg.concat(input_video, input_audio, v=1, a=1).output(video_wA_path).run() 133 | del video, seq, ref_mesh 134 | gc.collect() 135 | 136 | 137 | def main(): 138 | parser = argparse.ArgumentParser(description='FaceXHuBERT: Text-less Speech-driven E(X)pressive 3D Facial Animation Synthesis using Self-Supervised Speech Representation Learning') 139 | parser.add_argument("--model_name", type=str, default="FaceXHuBERT") 140 | parser.add_argument("--dataset", type=str, default="BIWI", help='name of the dataset folder. 
eg: BIWI') 141 | parser.add_argument("--fps", type=float, default=25, help='frame rate - 25 for BIWI') 142 | parser.add_argument("--feature_dim", type=int, default=256, help='GRU Vertex Decoder hidden size') 143 | parser.add_argument("--vertice_dim", type=int, default=70110, help='number of vertices - 23370*3 for BIWI') 144 | parser.add_argument("--device", type=str, default="cuda") 145 | parser.add_argument("--train_subjects", type=str, default="F1 F2 F3 F4 F5 F6 F7 F8 M1 M2 M3 M4 M5 M6") 146 | parser.add_argument("--test_subjects", type=str, default="F1 F2 F3 F4 F5 F6 F7 F8 M1 M2 M3 M4 M5 M6") 147 | parser.add_argument("--wav_path", type=str, default="demo/wav/test.wav", help='path of the input audio signal in .wav format') 148 | parser.add_argument("--result_path", type=str, default="demo/result", help='path of the predictions in .npy format') 149 | parser.add_argument("--condition", type=str, default="M3", help='select a conditioning subject from train_subjects') 150 | parser.add_argument("--subject", type=str, default="M1", help='select a subject from test_subjects or train_subjects') 151 | parser.add_argument("--template_path", type=str, default="templates_scaled.pkl", help='path of the personalized templates') 152 | parser.add_argument("--render_template_path", type=str, default="templates", help='path of the mesh in BIWI topology') 153 | parser.add_argument("--input_fps", type=int, default=50, help='HuBERT last hidden state produces 50 fps audio representation') 154 | parser.add_argument("--output_fps", type=int, default=25, help='fps of the visual data, BIWI was captured in 25 fps') 155 | parser.add_argument("--emotion", type=int, default=1, help='style control for emotion, 1 for expressive animation, 0 for neutral animation') 156 | args = parser.parse_args() 157 | 158 | test_model(args) 159 | render(args) 160 | 161 | 162 | if __name__ == "__main__": 163 | main() 164 | -------------------------------------------------------------------------------- /pretrained_model/README.md: -------------------------------------------------------------------------------- 1 | Put the downloaded FaceXHuBERT pretrained model in this folder.
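Once the pretrained model is in place, `predict.py` loads it as `pretrained_model/<--model_name>.pth` (`FaceXHuBERT.pth` with the defaults above). Running, for example, `python predict.py --wav_path demo/wav/test.wav --subject M1 --condition M3 --emotion 1` should then write the predicted vertex sequence to `demo/result/test_emotional_M1_Condition_M3.npy`, following the output naming in `test_model()`. A minimal sketch for sanity-checking that output, assuming the default arguments and the 70110-dimensional BIWI vertex layout used throughout:

```python
import numpy as np

# File name follows the scheme in predict.py:
#   <wav name>_<emotional|neutral>_<--subject>_Condition_<--condition>.npy
pred_path = "demo/result/test_emotional_M1_Condition_M3.npy"

pred = np.load(pred_path)                  # expected shape: (frames, 70110) flattened vertices
verts = pred.reshape(-1, 70110 // 3, 3)    # (frames, 23370, 3) per-frame xyz
print("frames:", verts.shape[0])
print("duration at 25 fps: %.2f s" % (verts.shape[0] / 25.0))
```

The reshape mirrors the `np.reshape(seq, (-1, 70110 // 3, 3))` call in `render()`, so the frame count printed here should match the number of frames in the rendered video.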
-------------------------------------------------------------------------------- /render_result.py: -------------------------------------------------------------------------------- 1 | import trimesh 2 | import numpy as np 3 | import cv2 4 | import os 5 | import ffmpeg 6 | import gc 7 | import pyrender 8 | 9 | train_folder = "render_folder/" 10 | results_folder = "result/" 11 | audio_folder = "BIWI/wav/" 12 | video_woA_folder = "renders/"+ train_folder+ "videos_no_audio/" 13 | video_wA_folder = "renders/"+ train_folder+ "videos_with_audio/" 14 | frames_folder = "renders/"+ train_folder+ "temp/frames/" 15 | 16 | seqs = os.listdir(results_folder) 17 | 18 | fps = 25 19 | fourcc = cv2.VideoWriter_fourcc(*'MP4V') 20 | 21 | cam = pyrender.PerspectiveCamera(yfov=np.pi / 3.0, aspectRatio=1.414) 22 | camera_pose = np.array([[1.0, 0, 0.0, 0.00], 23 | [0.0, -1.0, 0.0, 0.00], 24 | [0.0, 0.0, 1.0, -1.6], 25 | [0.0, 0.0, 0.0, 1.0]]) 26 | 27 | 28 | light = pyrender.DirectionalLight(color=[1.0, 1.0, 1.0], intensity=10.0) 29 | 30 | r = pyrender.OffscreenRenderer(640, 480) 31 | 32 | for seq in seqs: 33 | if seq.endswith('.npy'): 34 | video_woA_path = video_woA_folder + seq.split('.')[0] + '.mp4' 35 | video_wA_path = video_wA_folder + seq.split('.')[0] + '.mp4' 36 | video = cv2.VideoWriter(video_woA_path, fourcc, fps, (640, 480)) 37 | seq_path = results_folder + seq 38 | subject_template_path = "BIWI/templates/"+ seq.split('_')[0] + ".obj" 39 | audio = seq.split('_')[0]+'_'+seq.split('_')[1]+'.wav' 40 | audio_path = audio_folder + audio 41 | ref_mesh = trimesh.load_mesh(subject_template_path, process=False) 42 | 43 | seq = np.load(seq_path) 44 | seq = np.reshape(seq,(-1,70110//3,3)) 45 | ref_mesh.vertices = seq[0,:,:] 46 | py_mesh = pyrender.Mesh.from_trimesh(ref_mesh) 47 | for f in range(seq.shape[0]): 48 | ref_mesh.vertices = seq[f,:,:] 49 | py_mesh = pyrender.Mesh.from_trimesh(ref_mesh) 50 | scene = pyrender.Scene() 51 | scene.add(py_mesh) 52 | 53 | scene.add(cam, pose=camera_pose) 54 | scene.add(light, pose=camera_pose) 55 | color, _ = r.render(scene) 56 | 57 | output_frame = frames_folder + "frame" + str(f) + ".jpg" 58 | cv2.imwrite(output_frame, color) 59 | frame = cv2.imread(output_frame) 60 | video.write(frame) 61 | video.release() 62 | 63 | input_video = ffmpeg.input(video_woA_path) 64 | input_audio = ffmpeg.input(audio_path) 65 | ffmpeg.concat(input_video, input_audio, v=1, a=1).output(video_wA_path).run() 66 | del video, seq, ref_mesh 67 | gc.collect() 68 | -------------------------------------------------------------------------------- /renders/render_folder/temp/frames/README.md: -------------------------------------------------------------------------------- 1 | This folder will contain the temporary frames of the renders. -------------------------------------------------------------------------------- /renders/render_folder/videos_no_audio/README.md: -------------------------------------------------------------------------------- 1 | This folder will contain rendered video without audio. -------------------------------------------------------------------------------- /renders/render_folder/videos_with_audio/README.md: -------------------------------------------------------------------------------- 1 | This folder will contain rendered video with audio. 
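The rendering script above (`render_result.py`) pairs every `result/*.npy` prediction with `BIWI/wav/<subject>_<sentence>.wav`, taking the subject and sentence from the first two underscore-separated tokens of the prediction's file name. A small, hypothetical pre-flight check (not part of the repository) to confirm those audio files exist before starting a long render:

```python
import os

results_folder = "result/"
audio_folder = "BIWI/wav/"   # same locations render_result.py reads from

# render_result.py derives the audio file from the first two
# underscore-separated tokens of each prediction's file name.
for name in os.listdir(results_folder):
    if not name.endswith(".npy"):
        continue
    stem = os.path.splitext(name)[0]
    subject, sentence = stem.split("_")[0], stem.split("_")[1]
    wav_path = os.path.join(audio_folder, subject + "_" + sentence + ".wav")
    if not os.path.isfile(wav_path):
        print("missing audio for", name, "->", wav_path)
```

This only checks file names; array shapes are still validated implicitly by the reshape inside the rendering loop.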
-------------------------------------------------------------------------------- /result/README.md: -------------------------------------------------------------------------------- 1 | After the training script finishes, this folder will contain the predictions for the test-split audio sequences in `.npy` format. Run `render_result.py` to render the results in this folder. -------------------------------------------------------------------------------- /save/README.md: -------------------------------------------------------------------------------- 1 | Trained models will be saved in this folder. --------------------------------------------------------------------------------
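To reuse a model trained by `main.py` with `predict.py`, the checkpoint has to end up where `predict.py` looks for it, i.e. `pretrained_model/<--model_name>.pth` (`pretrained_model/FaceXHuBERT.pth` by default). A minimal sketch of that hand-off; the file name under `save/` is an assumption and should be replaced with the checkpoint `main.py` actually produced:

```python
import os
import shutil

# The checkpoint name under save/ is an assumption -- use whatever main.py wrote.
trained_ckpt = "save/100_model.pth"            # hypothetical file name
target = "pretrained_model/FaceXHuBERT.pth"    # what predict.py loads by default

os.makedirs("pretrained_model", exist_ok=True)
shutil.copyfile(trained_ckpt, target)
print("copied", trained_ckpt, "->", target)
```

After the copy, running `predict.py` (optionally with `--model_name FaceXHuBERT`) picks up the new weights instead of the downloaded pretrained model.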