├── .gitignore
├── BIWI
│   ├── ForProcessing
│   │   ├── FaceData
│   │   │   ├── README.md
│   │   │   └── faces
│   │   │       ├── Extract_all.sh
│   │   │       ├── delete_vls.sh
│   │   │       ├── vertex_process.py
│   │   │       └── vl2csv_recursive.m
│   │   └── rest
│   │       ├── README.md
│   │       └── wav_process.py
│   ├── templates
│   │   ├── BIWI_topology.obj
│   │   └── README.md
│   ├── templates_orig.pkl
│   ├── templates_scaled.pkl
│   ├── vertices_npy
│   │   └── README.md
│   └── wav
│       └── README.md
├── Evaluation
│   ├── GroundTruth
│   │   └── README.md
│   ├── render_quant_evaluation.py
│   └── renders
│       ├── temp
│       │   ├── frames
│       │   │   └── README.md
│       │   └── meshes
│       │       └── README.md
│       ├── videos_no_audio
│       │   └── README.md
│       └── videos_with_audio
│           └── README.md
├── FaceXHuBERT.png
├── LICENSE
├── README.md
├── data_loader.py
├── demo
│   ├── render
│   │   ├── frames
│   │   │   └── README.md
│   │   ├── video_with_audio
│   │   │   └── README.md
│   │   └── video_wo_audio
│   │       └── README.md
│   ├── result
│   │   └── README.md
│   └── wav
│       ├── README.md
│       └── test.wav
├── environment.yml
├── faceXhubert.py
├── hubert
│   ├── activations.py
│   ├── configuration_hubert.py
│   ├── configuration_utils.py
│   ├── deepspeed.py
│   ├── dependency_versions_check.py
│   ├── dependency_versions_table.py
│   ├── file_utils.py
│   ├── generation_beam_search.py
│   ├── generation_logits_process.py
│   ├── generation_stopping_criteria.py
│   ├── generation_utils.py
│   ├── modeling_hubert.py
│   ├── modeling_outputs.py
│   ├── modeling_utils.py
│   └── utils
│       ├── logging.py
│       └── versions.py
├── index.html
├── main.py
├── modeling_hubert.py
├── page_assets
│   ├── bibtex.txt
│   └── paper.png
├── predict.py
├── pretrained_model
│   └── README.md
├── render_result.py
├── renders
│   └── render_folder
│       ├── temp
│       │   └── frames
│       │       └── README.md
│       ├── videos_no_audio
│       │   └── README.md
│       └── videos_with_audio
│           └── README.md
├── result
│   └── README.md
└── save
    └── README.md
/.gitignore:
--------------------------------------------------------------------------------
1 | # ignore the data files
2 | BIWI/wav/*.wav
3 |
4 | # ignore the following extensions
5 | *.npy
6 | *.pth
7 | *.mp4
8 |
--------------------------------------------------------------------------------
/BIWI/ForProcessing/FaceData/README.md:
--------------------------------------------------------------------------------
1 | Put "faces0*.tgz" archive files here.
--------------------------------------------------------------------------------
/BIWI/ForProcessing/FaceData/faces/Extract_all.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | # Extract every per-sequence .tar archive inside each subject directory into a
3 | # folder named after the archive (tar's --one-top-level behaviour).
4 | for d in */;
5 | do
6 |     echo "Processing ${d}";
7 |     cd "${d}";
8 |     echo "Starting extraction!";
9 |     for file in *; do
10 |         if [ -f "$file" ]; then
11 |             if [[ $file == *.tar ]]; then
12 |                 echo "$file"
13 |                 tar -xf "$file" --one-top-level
14 |             fi
15 |         fi
16 |     done
17 |     cd ..
18 | done
--------------------------------------------------------------------------------
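The shell script above needs a Unix-style shell (e.g. Git Bash on Windows). Below is a minimal cross-platform Python sketch with equivalent behaviour; it is a hypothetical helper, not part of the repository, and it approximates tar's `--one-top-level` by always extracting each archive into a folder named after it.

```
# extract_all.py (hypothetical helper, not part of the repo): extract every
# per-sequence .tar inside each subject folder into a folder named after the archive.
import os
import tarfile

for subject in sorted(os.listdir(".")):
    if not os.path.isdir(subject):
        continue
    for name in sorted(os.listdir(subject)):
        if name.endswith(".tar"):
            archive = os.path.join(subject, name)
            target = os.path.join(subject, name[:-len(".tar")])
            os.makedirs(target, exist_ok=True)
            print("extracting", archive)
            with tarfile.open(archive) as tf:
                tf.extractall(target)
```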
/BIWI/ForProcessing/FaceData/faces/delete_vls.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | # Recursively delete the raw .vl frame files (keep only the converted .csv files).
3 | find . -print0 | while IFS= read -r -d '' file
4 | do
5 |     if [ -f "$file" ]; then
6 |         if [[ $file == *.vl ]]; then
7 |             rm -v "$file"
8 |         fi
9 |     fi
10 | done
--------------------------------------------------------------------------------
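As with `Extract_all.sh`, the script above assumes a Unix-style shell. A minimal Python sketch with the same effect (delete every `.vl` file under the current directory), again a hypothetical helper rather than part of the repository:

```
# delete_vls.py (hypothetical helper): remove the raw .vl frame files once they
# have been converted to .csv, walking all subject/sequence folders recursively.
import os

for root, _, files in os.walk("."):
    for name in files:
        if name.endswith(".vl"):
            path = os.path.join(root, name)
            print("removing", path)
            os.remove(path)
```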
/BIWI/ForProcessing/FaceData/faces/vertex_process.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import os
3 | import pandas as pd
4 |
5 | input_path = "./"
6 | output_path = "../../../vertices_npy/"
7 | output_file = ""
8 |
9 | subjects = [ name for name in os.listdir(input_path) if os.path.isdir(os.path.join(input_path, name)) ]
10 |
11 | #print(subjects)
12 |
13 | for subject in subjects:
14 | seqs = [ name for name in os.listdir(input_path+"/"+subject) if os.path.isdir(os.path.join(input_path+"/"+subject, name)) ]
15 | # print(seqs)
16 | for seq in seqs:
17 | print("now getting subject ", subject, " and seq ", seq)
18 |         frame_files = sorted(os.listdir(input_path+"/"+subject+"/"+seq))  # sort so frames are stacked in temporal order
19 | if seq[0]=='e':
20 | new_seq = int(seq[1:3]) + 40
21 | output_file = subject+"_"+str(new_seq)+".npy"
22 | else:
23 | output_file = subject+"_"+seq+".npy"
24 | full_output_path = output_path+output_file
25 | empty_arr = np.empty((0,70110))
26 | for file in frame_files:
27 | if file.endswith(".csv"):
28 | # print(file)
29 | df=pd.read_csv(input_path+"/"+subject+"/"+seq+"/"+file, sep=',',header=None)
30 | df = df.iloc[1: , :]
31 | frame = df.to_numpy()
32 | frame = frame.flatten()
33 | frame = np.expand_dims(frame, axis=0)
34 | empty_arr = np.append(empty_arr,frame,axis=0)
35 | np.save(full_output_path, empty_arr)
36 | print("vertice shape: ", empty_arr.shape)
37 | print("saved to ", full_output_path)
38 |
39 | templates_orig = np.load("../../../templates_orig.pkl", allow_pickle=True)
40 | templates_scaled = np.load("../../../templates_scaled.pkl", allow_pickle=True)
41 |
42 |
43 | path = "../../../vertices_npy/"
44 |
45 | npys = [ name for name in os.listdir(path) if name.endswith('.npy') ]
46 |
47 | for npy in npys:
48 | file_path = path + npy
49 | subject = npy.split('_')[0]
50 | subject_template = templates_orig[subject]
51 |
52 | subject_x_mean = subject_template[:,0].mean()
53 | subject_x_max = subject_template[:,0].max()
54 | subject_x_min = subject_template[:,0].min()
55 |
56 | subject_y_mean = subject_template[:,1].mean()
57 | subject_y_max = subject_template[:,1].max()
58 | subject_y_min = subject_template[:,1].min()
59 |
60 | subject_z_mean = subject_template[:,2].mean()
61 | subject_z_max = subject_template[:,2].max()
62 | subject_z_min = subject_template[:,2].min()
63 |
64 | # print(subject)
65 | seq = np.load(file_path, allow_pickle=True).astype(float)
66 | # print(seq.shape)
67 | seq = np.reshape(seq,(-1,70110//3,3))
68 | for f in range(seq.shape[0]):
69 | frame = seq[f,:,:]
70 | X = (frame[:,0]-subject_x_mean)/(subject_x_max-subject_x_min)
71 | Y = (frame[:,1]-subject_y_mean)/(subject_y_max-subject_y_min)
72 | Z = (frame[:,2]-subject_z_mean)/(subject_z_max-subject_z_min)
73 | frame[:,0] = X
74 | frame[:,1] = Y
75 | frame[:,2] = Z
76 |
77 | seq[f,:,:] = frame
78 | seq = seq.reshape(seq.shape[0],seq.shape[1]*seq.shape[2])
79 | # print(seq.shape)
80 | np.save(file_path, seq)
--------------------------------------------------------------------------------
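A quick way to verify the output of `vertex_process.py` is to check the shape and value range of the saved arrays. The sketch below is a hypothetical helper (not part of the repository) and assumes it is run from the same `faces/` directory, so the relative output path matches the one used above.

```
# sanity_check_vertices.py (hypothetical helper): each processed sequence should be
# (num_frames, 70110), i.e. 23370 vertices * 3 coordinates per frame, and after the
# per-axis min-max scaling the coordinates should lie roughly within [-1, 1].
import os
import numpy as np

path = "../../../vertices_npy/"
for name in sorted(os.listdir(path)):
    if not name.endswith(".npy"):
        continue
    seq = np.load(path + name)
    assert seq.shape[1] == 70110, name
    frames = seq.reshape(-1, 70110 // 3, 3)
    print(name, frames.shape[0], "frames, value range:", frames.min(), frames.max())
```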
/BIWI/ForProcessing/FaceData/faces/vl2csv_recursive.m:
--------------------------------------------------------------------------------
1 | rootdir = pwd;
2 | filelist = dir(fullfile(rootdir, '**\*.*')); %get list of files and folders in any subfolder
3 | filelist = filelist(~[filelist.isdir]); %remove folders from list
4 |
5 | % display(filelist(55).folder);
6 | % path = filelist(55).folder;
7 | % fname = filelist(55).name;
8 | % outfname = '\' + string(fname(1:end-2)) + 'csv';
9 | % outfilepath = path+outfname;
10 | tic
11 | for i=1:size(filelist,1)
12 | currfile = filelist(i);
13 | path = currfile.folder;
14 | fname = currfile.name;
15 | [~, ~, ext] = fileparts(fname);
16 | if isequal(ext, '.vl')
17 | outfname = '\' + string(fname(1:end-2)) + 'csv';
18 | outfilepath = path+outfname;
19 | disp(outfilepath);
20 | infilepath = path+"\"+fname;
21 | fid = fopen(infilepath); % load the .vl file
22 | n_vertices = fread(fid, 1, 'ulong');
23 | vertices = fread(fid, [3, n_vertices] , 'float');
24 | disp("Loaded the .vl file! No. of Vs: " + n_vertices);
25 |
26 | fclose(fid);
27 | header = {'x', 'y', 'z'};
28 | textHeader = strjoin(header, ',');
29 | %write header to file
30 | fid2 = fopen(outfilepath,'w');
31 | fprintf(fid2,'%s\n', textHeader);
32 | fclose(fid2);
33 | dlmwrite(outfilepath, vertices.', '-append');
34 |
35 | end
36 | end
37 | toc
--------------------------------------------------------------------------------
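The MATLAB reader above implies a simple binary layout for the `.vl` files: a vertex count followed by `3 * n_vertices` float32 values (x, y, z per vertex). The Python sketch below mirrors that reading logic; it is not part of the repository, and the unsigned 32-bit count is an assumption based on MATLAB's `'ulong'` precision on Windows.

```
# read_vl.py (hypothetical helper): read one BIWI .vl frame into an (n_vertices, 3) array,
# matching the rows written to .csv by vl2csv_recursive.m.
import numpy as np

def read_vl(path):
    with open(path, "rb") as f:
        n_vertices = int(np.fromfile(f, dtype=np.uint32, count=1)[0])
        vertices = np.fromfile(f, dtype=np.float32, count=3 * n_vertices)
    return vertices.reshape(-1, 3)

# Example: print(read_vl("path/to/some_frame.vl").shape)
```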
/BIWI/ForProcessing/rest/README.md:
--------------------------------------------------------------------------------
1 | Put "rest.tgz" downloaded from the BIWI dataset here.
2 |
--------------------------------------------------------------------------------
/BIWI/ForProcessing/rest/wav_process.py:
--------------------------------------------------------------------------------
1 | import os
2 | import shutil
3 |
4 | wav_in = "./audio/"
5 | wav_out = "../../wav/"
6 |
7 | wavs = os.listdir(wav_in)
8 |
9 | for wav in wavs:
10 | input_path = wav_in + wav
11 | # print(input_path)
12 | if wav.endswith('_cut.wav'):
13 | # print(wav)
14 | filename = wav.split('_')
15 | sub = filename[0]
16 | seq = filename[1]
17 | # print(sub)
18 | # print(seq)
19 | if seq[0]=='e':
20 | # print(seq)
21 | new_seq = int(seq[1:3]) + 40
22 | # print(new_seq)
23 | output_path = wav_out + sub + "_" + str(new_seq) + ".wav"
24 | shutil.copy2(input_path, output_path)
25 | # os.rename(input_path, output_path)
26 | else:
27 | output_path = wav_out + sub + "_" + seq + ".wav"
28 | shutil.copy2(input_path, output_path)
--------------------------------------------------------------------------------
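A quick check after running `wav_process.py` (hypothetical helper, not part of the repository): every copied file should follow the `<subject>_<sequence>.wav` naming, with sequences 41-80 being the renamed emotional takes.

```
# count_wavs.py: list how many processed .wav files exist per subject in BIWI/wav/.
import os
from collections import Counter

wavs = [w for w in os.listdir("../../wav/") if w.endswith(".wav")]
per_subject = Counter(w.split("_")[0] for w in wavs)
print(len(wavs), "wav files in total")
print(dict(per_subject))
```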
/BIWI/templates/README.md:
--------------------------------------------------------------------------------
1 | Put the template .obj files in this folder.
2 | i.e. - F1.obj, F2.obj, F3.obj, F4.obj, F5.obj, F6.obj, F7.obj, F8.obj, M1.obj, M2.obj, M3.obj, M4.obj, M5.obj, M6.obj
--------------------------------------------------------------------------------
/BIWI/templates_orig.pkl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/galib360/FaceXHuBERT/f54f9a99282b6a3b0b99770cbc50cb7ea7f3746b/BIWI/templates_orig.pkl
--------------------------------------------------------------------------------
/BIWI/templates_scaled.pkl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/galib360/FaceXHuBERT/f54f9a99282b6a3b0b99770cbc50cb7ea7f3746b/BIWI/templates_scaled.pkl
--------------------------------------------------------------------------------
/BIWI/vertices_npy/README.md:
--------------------------------------------------------------------------------
1 | This folder will contain the vertex data for all the sequences in `.npy` format.
--------------------------------------------------------------------------------
/BIWI/wav/README.md:
--------------------------------------------------------------------------------
1 | This folder will contain the audio data for all the sequences in `.wav` format.
--------------------------------------------------------------------------------
/Evaluation/GroundTruth/README.md:
--------------------------------------------------------------------------------
1 | Put the ground-truth .npy sequences from your training test split in this folder. You can copy the test-split .npy files from the `BIWI/vertices_npy/` folder.
--------------------------------------------------------------------------------
/Evaluation/render_quant_evaluation.py:
--------------------------------------------------------------------------------
1 | import math
2 | import trimesh
3 | import numpy as np
4 | import cv2
5 | import io
6 | from PIL import Image
7 | import os
8 | import ffmpeg
9 | import gc
10 | import pyrender
11 | import pymeshlab as pmlab
12 |
13 | quantfilename = "quantitative_metric.txt"
14 | render_folder = "renders/"
15 | gt_folder = "GroundTruth/"
16 | pred_folder = "../result/"
17 | audio_folder = "../BIWI/wav/"
18 | video_woA_folder = render_folder + "videos_no_audio/"
19 | video_wA_folder = render_folder + "videos_with_audio/"
20 | meshes_folder = render_folder+ "temp/meshes/"
21 | frames_folder = render_folder+ "temp/frames/"
22 |
23 | mean_face_vertex_error = 0
24 |
25 | gt_seqs = os.listdir(gt_folder)
26 | pred_seqs = os.listdir(pred_folder)
27 |
28 | fps = 25
29 | fourcc = cv2.VideoWriter_fourcc(*'MP4V')
30 |
31 | cam = pyrender.PerspectiveCamera(yfov=np.pi / 3.0, aspectRatio=1.414)
32 | camera_pose = np.array([[1.0, 0, 0.0, 0.00],
33 | [0.0, -1.0, 0.0, 0.00],
34 | [0.0, 0.0, 1.0, -2.0],
35 | [0.0, 0.0, 0.0, 1.0]])
36 |
37 |
38 | light = pyrender.DirectionalLight(color=[1.0, 1.0, 1.0], intensity=5.0)
39 |
40 | r = pyrender.OffscreenRenderer(640, 480)
41 |
42 | print("Evaluation started")
43 | for gt_seq in gt_seqs:
44 | if gt_seq.endswith('.npy'):
45 | video_woA_path = video_woA_folder + gt_seq.split('.')[0] + '.mp4'
46 | video_wA_path = video_wA_folder + gt_seq.split('.')[0] + '.mp4'
47 | video = cv2.VideoWriter(video_woA_path, fourcc, fps, (640, 480))
48 | gt_seq_path = gt_folder + gt_seq
49 | pred_seq_path = pred_folder + gt_seq.split('.')[0] + "_condition_" + gt_seq.split('_')[0] + ".npy"
50 | print("Now evaluating sequence: ", gt_seq)
51 | subject_template_path = "../BIWI/templates/"+ gt_seq.split('_')[0] + ".obj"
52 | audio = gt_seq.split('.')[0].split('_')[0]+'_'+gt_seq.split('.')[0].split('_')[1]+'.wav'
53 | audio_path = audio_folder + audio
54 | # print(gt_seq_path)
55 | # print(pred_seq_path)
56 | # print(audio_path)
57 | # print(subject_template_path)
58 | # ref_mesh = trimesh.load_mesh(subject_template_path, process=False)
59 |
60 | gt_seq = np.load(gt_seq_path)
61 | pred_seq = np.load(pred_seq_path)
62 |
63 | if(gt_seq.shape[0]>pred_seq.shape[0]):
64 | gt_seq = gt_seq[:pred_seq.shape[0]]
65 |
66 | if(pred_seq.shape[0]>gt_seq.shape[0]):
67 | pred_seq = pred_seq[:gt_seq.shape[0]]
68 |
69 |
70 | gt_seq = np.reshape(gt_seq,(-1,70110//3,3))
71 | pred_seq = np.reshape(pred_seq,(-1,70110//3,3))
72 | sequence_mean_face_vertex_error = 0
73 |
74 |
75 | # seq = np.reshape(seq,(-1,70110//3,3))
76 | # ref_mesh.vertices = seq[0,:,:]
77 | # py_mesh = pyrender.Mesh.from_trimesh(ref_mesh)
78 | for f in range(pred_seq.shape[0]):
79 | ms = pmlab.MeshSet()
80 | ms.load_new_mesh(subject_template_path)
81 | template_mesh= ms.current_mesh()
82 |
83 | gt_mesh = pmlab.Mesh(gt_seq[f,:,:], template_mesh.face_matrix(), template_mesh.vertex_normal_matrix())
84 | ms.add_mesh(gt_mesh)
85 | # print(ms.current_mesh_id())
86 |
87 | pred_mesh = pmlab.Mesh(pred_seq[f,:,:], template_mesh.face_matrix(), template_mesh.vertex_normal_matrix())
88 | ms.add_mesh(pred_mesh)
89 | # print(ms.current_mesh_id())
90 |
91 | ms.apply_filter('distance_from_reference_mesh', measuremesh=2, refmesh=1, signeddist = False)
92 | ms.set_current_mesh(2)
93 | vertex_quality = ms.current_mesh().vertex_quality_array()
94 | #ms.apply_filter('colorize_by_vertex_quality', minval=ms.current_mesh().vertex_quality_array().min(), maxval=ms.current_mesh().vertex_quality_array().max(),zerosym=True)
95 | ms.apply_filter('quality_mapper_applier', minqualityval=ms.current_mesh().vertex_quality_array().min(), maxqualityval=ms.current_mesh().vertex_quality_array().max(),tfslist= 1)
96 | ms.save_current_mesh( meshes_folder + str(f) +".obj", save_vertex_color=True)
97 | sequence_mean_face_vertex_error = sequence_mean_face_vertex_error + vertex_quality.mean(axis=None)
98 | ms.set_current_mesh(2)
99 | ms.delete_current_mesh()
100 | ms.set_current_mesh(1)
101 | ms.delete_current_mesh()
102 | ms.set_current_mesh(0)
103 |
104 | render_mesh = trimesh.load_mesh(meshes_folder + str(f) +".obj", process=False)
105 | py_mesh = pyrender.Mesh.from_trimesh(render_mesh)
106 |
107 | # ref_mesh.vertices = seq[f,:,:]
108 | # py_mesh = pyrender.Mesh.from_trimesh(ref_mesh)
109 | scene = pyrender.Scene()
110 | scene.add(py_mesh)
111 |
112 | scene.add(cam, pose=camera_pose)
113 | scene.add(light, pose=camera_pose)
114 | color, _ = r.render(scene)
115 |
116 | output_frame = frames_folder + "frame" + str(f) + ".jpg"
117 | image_bgr = cv2.cvtColor(color, cv2.COLOR_RGB2BGR)
118 | cv2.imwrite(output_frame, image_bgr)
119 | frame = cv2.imread(output_frame)
120 | video.write(frame)
121 | video.release()
122 | sequence_mean_face_vertex_error = sequence_mean_face_vertex_error/pred_seq.shape[0]
123 | mean_face_vertex_error = mean_face_vertex_error + sequence_mean_face_vertex_error
124 |
125 | input_video = ffmpeg.input(video_woA_path)
126 | input_audio = ffmpeg.input(audio_path)
127 | ffmpeg.concat(input_video, input_audio, v=1, a=1).output(video_wA_path).run()
128 | del video
129 | gc.collect()
130 |
131 |
132 | mean_face_vertex_error = mean_face_vertex_error / len([s for s in gt_seqs if s.endswith('.npy')])  # average over .npy sequences only (gt_folder may also contain a README)
133 |
134 | file = open(quantfilename, "w")
135 |
136 | # convert the metric to a string and write it out (avoid shadowing the built-in str)
137 | metric_str = repr(mean_face_vertex_error)
138 | file.write("mean_face_vertex_error = " + metric_str + "\n")
139 |
140 | file.close()
141 | print("Done!")
--------------------------------------------------------------------------------
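For a quick cross-check of the reported metric, the sketch below (not part of the repository) approximates it with a plain per-vertex Euclidean distance in NumPy. Note that the script above measures a point-to-surface distance via pymeshlab, so the two numbers will differ slightly.

```
# mve_numpy_check.py (hypothetical helper): approximate mean face vertex error between
# a ground-truth and a predicted sequence, averaged over frames and vertices.
import numpy as np

def mean_face_vertex_error(gt_path, pred_path):
    gt = np.load(gt_path).reshape(-1, 70110 // 3, 3)
    pred = np.load(pred_path).reshape(-1, 70110 // 3, 3)
    n = min(gt.shape[0], pred.shape[0])                # clip to the shorter sequence
    dist = np.linalg.norm(gt[:n] - pred[:n], axis=2)   # (frames, vertices)
    return dist.mean()

# Example: print(mean_face_vertex_error("GroundTruth/F1_39.npy",
#                                        "../result/F1_39_condition_F1.npy"))
```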
/Evaluation/renders/temp/frames/README.md:
--------------------------------------------------------------------------------
1 | This folder will contain temporary frames for the rendered videos of the evaluation script.
--------------------------------------------------------------------------------
/Evaluation/renders/temp/meshes/README.md:
--------------------------------------------------------------------------------
1 | This folder will contain temporary face meshes used for rendering the evaluation renders.
--------------------------------------------------------------------------------
/Evaluation/renders/videos_no_audio/README.md:
--------------------------------------------------------------------------------
1 | This folder will contain the rendered videos without audio.
--------------------------------------------------------------------------------
/Evaluation/renders/videos_with_audio/README.md:
--------------------------------------------------------------------------------
1 | This folder will contain the rendered videos with audio.
--------------------------------------------------------------------------------
/FaceXHuBERT.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/galib360/FaceXHuBERT/f54f9a99282b6a3b0b99770cbc50cb7ea7f3746b/FaceXHuBERT.png
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | # Attribution-NonCommercial 4.0 International
2 |
3 | > *Creative Commons Corporation (“Creative Commons”) is not a law firm and does not provide legal services or legal advice. Distribution of Creative Commons public licenses does not create a lawyer-client or other relationship. Creative Commons makes its licenses and related information available on an “as-is” basis. Creative Commons gives no warranties regarding its licenses, any material licensed under their terms and conditions, or any related information. Creative Commons disclaims all liability for damages resulting from their use to the fullest extent possible.*
4 | >
5 | > ### Using Creative Commons Public Licenses
6 | >
7 | > Creative Commons public licenses provide a standard set of terms and conditions that creators and other rights holders may use to share original works of authorship and other material subject to copyright and certain other rights specified in the public license below. The following considerations are for informational purposes only, are not exhaustive, and do not form part of our licenses.
8 | >
9 | > * __Considerations for licensors:__ Our public licenses are intended for use by those authorized to give the public permission to use material in ways otherwise restricted by copyright and certain other rights. Our licenses are irrevocable. Licensors should read and understand the terms and conditions of the license they choose before applying it. Licensors should also secure all rights necessary before applying our licenses so that the public can reuse the material as expected. Licensors should clearly mark any material not subject to the license. This includes other CC-licensed material, or material used under an exception or limitation to copyright. [More considerations for licensors](http://wiki.creativecommons.org/Considerations_for_licensors_and_licensees#Considerations_for_licensors).
10 | >
11 | > * __Considerations for the public:__ By using one of our public licenses, a licensor grants the public permission to use the licensed material under specified terms and conditions. If the licensor’s permission is not necessary for any reason–for example, because of any applicable exception or limitation to copyright–then that use is not regulated by the license. Our licenses grant only permissions under copyright and certain other rights that a licensor has authority to grant. Use of the licensed material may still be restricted for other reasons, including because others have copyright or other rights in the material. A licensor may make special requests, such as asking that all changes be marked or described. Although not required by our licenses, you are encouraged to respect those requests where reasonable. [More considerations for the public](http://wiki.creativecommons.org/Considerations_for_licensors_and_licensees#Considerations_for_licensees).
12 |
13 | ## Creative Commons Attribution-NonCommercial 4.0 International Public License
14 |
15 | By exercising the Licensed Rights (defined below), You accept and agree to be bound by the terms and conditions of this Creative Commons Attribution-NonCommercial 4.0 International Public License ("Public License"). To the extent this Public License may be interpreted as a contract, You are granted the Licensed Rights in consideration of Your acceptance of these terms and conditions, and the Licensor grants You such rights in consideration of benefits the Licensor receives from making the Licensed Material available under these terms and conditions.
16 |
17 | ### Section 1 – Definitions.
18 |
19 | a. __Adapted Material__ means material subject to Copyright and Similar Rights that is derived from or based upon the Licensed Material and in which the Licensed Material is translated, altered, arranged, transformed, or otherwise modified in a manner requiring permission under the Copyright and Similar Rights held by the Licensor. For purposes of this Public License, where the Licensed Material is a musical work, performance, or sound recording, Adapted Material is always produced where the Licensed Material is synched in timed relation with a moving image.
20 |
21 | b. __Adapter's License__ means the license You apply to Your Copyright and Similar Rights in Your contributions to Adapted Material in accordance with the terms and conditions of this Public License.
22 |
23 | c. __Copyright and Similar Rights__ means copyright and/or similar rights closely related to copyright including, without limitation, performance, broadcast, sound recording, and Sui Generis Database Rights, without regard to how the rights are labeled or categorized. For purposes of this Public License, the rights specified in Section 2(b)(1)-(2) are not Copyright and Similar Rights.
24 |
25 | d. __Effective Technological Measures__ means those measures that, in the absence of proper authority, may not be circumvented under laws fulfilling obligations under Article 11 of the WIPO Copyright Treaty adopted on December 20, 1996, and/or similar international agreements.
26 |
27 | e. __Exceptions and Limitations__ means fair use, fair dealing, and/or any other exception or limitation to Copyright and Similar Rights that applies to Your use of the Licensed Material.
28 |
29 | f. __Licensed Material__ means the artistic or literary work, database, or other material to which the Licensor applied this Public License.
30 |
31 | g. __Licensed Rights__ means the rights granted to You subject to the terms and conditions of this Public License, which are limited to all Copyright and Similar Rights that apply to Your use of the Licensed Material and that the Licensor has authority to license.
32 |
33 | h. __Licensor__ means the individual(s) or entity(ies) granting rights under this Public License.
34 |
35 | i. __NonCommercial__ means not primarily intended for or directed towards commercial advantage or monetary compensation. For purposes of this Public License, the exchange of the Licensed Material for other material subject to Copyright and Similar Rights by digital file-sharing or similar means is NonCommercial provided there is no payment of monetary compensation in connection with the exchange.
36 |
37 | j. __Share__ means to provide material to the public by any means or process that requires permission under the Licensed Rights, such as reproduction, public display, public performance, distribution, dissemination, communication, or importation, and to make material available to the public including in ways that members of the public may access the material from a place and at a time individually chosen by them.
38 |
39 | k. __Sui Generis Database Rights__ means rights other than copyright resulting from Directive 96/9/EC of the European Parliament and of the Council of 11 March 1996 on the legal protection of databases, as amended and/or succeeded, as well as other essentially equivalent rights anywhere in the world.
40 |
41 | l. __You__ means the individual or entity exercising the Licensed Rights under this Public License. Your has a corresponding meaning.
42 |
43 | ### Section 2 – Scope.
44 |
45 | a. ___License grant.___
46 |
47 | 1. Subject to the terms and conditions of this Public License, the Licensor hereby grants You a worldwide, royalty-free, non-sublicensable, non-exclusive, irrevocable license to exercise the Licensed Rights in the Licensed Material to:
48 |
49 | A. reproduce and Share the Licensed Material, in whole or in part, for NonCommercial purposes only; and
50 |
51 | B. produce, reproduce, and Share Adapted Material for NonCommercial purposes only.
52 |
53 | 2. __Exceptions and Limitations.__ For the avoidance of doubt, where Exceptions and Limitations apply to Your use, this Public License does not apply, and You do not need to comply with its terms and conditions.
54 |
55 | 3. __Term.__ The term of this Public License is specified in Section 6(a).
56 |
57 | 4. __Media and formats; technical modifications allowed.__ The Licensor authorizes You to exercise the Licensed Rights in all media and formats whether now known or hereafter created, and to make technical modifications necessary to do so. The Licensor waives and/or agrees not to assert any right or authority to forbid You from making technical modifications necessary to exercise the Licensed Rights, including technical modifications necessary to circumvent Effective Technological Measures. For purposes of this Public License, simply making modifications authorized by this Section 2(a)(4) never produces Adapted Material.
58 |
59 | 5. __Downstream recipients.__
60 |
61 | A. __Offer from the Licensor – Licensed Material.__ Every recipient of the Licensed Material automatically receives an offer from the Licensor to exercise the Licensed Rights under the terms and conditions of this Public License.
62 |
63 | B. __No downstream restrictions.__ You may not offer or impose any additional or different terms or conditions on, or apply any Effective Technological Measures to, the Licensed Material if doing so restricts exercise of the Licensed Rights by any recipient of the Licensed Material.
64 |
65 | 6. __No endorsement.__ Nothing in this Public License constitutes or may be construed as permission to assert or imply that You are, or that Your use of the Licensed Material is, connected with, or sponsored, endorsed, or granted official status by, the Licensor or others designated to receive attribution as provided in Section 3(a)(1)(A)(i).
66 |
67 | b. ___Other rights.___
68 |
69 | 1. Moral rights, such as the right of integrity, are not licensed under this Public License, nor are publicity, privacy, and/or other similar personality rights; however, to the extent possible, the Licensor waives and/or agrees not to assert any such rights held by the Licensor to the limited extent necessary to allow You to exercise the Licensed Rights, but not otherwise.
70 |
71 | 2. Patent and trademark rights are not licensed under this Public License.
72 |
73 | 3. To the extent possible, the Licensor waives any right to collect royalties from You for the exercise of the Licensed Rights, whether directly or through a collecting society under any voluntary or waivable statutory or compulsory licensing scheme. In all other cases the Licensor expressly reserves any right to collect such royalties, including when the Licensed Material is used other than for NonCommercial purposes.
74 |
75 | ### Section 3 – License Conditions.
76 |
77 | Your exercise of the Licensed Rights is expressly made subject to the following conditions.
78 |
79 | a. ___Attribution.___
80 |
81 | 1. If You Share the Licensed Material (including in modified form), You must:
82 |
83 | A. retain the following if it is supplied by the Licensor with the Licensed Material:
84 |
85 | i. identification of the creator(s) of the Licensed Material and any others designated to receive attribution, in any reasonable manner requested by the Licensor (including by pseudonym if designated);
86 |
87 | ii. a copyright notice;
88 |
89 | iii. a notice that refers to this Public License;
90 |
91 | iv. a notice that refers to the disclaimer of warranties;
92 |
93 | v. a URI or hyperlink to the Licensed Material to the extent reasonably practicable;
94 |
95 | B. indicate if You modified the Licensed Material and retain an indication of any previous modifications; and
96 |
97 | C. indicate the Licensed Material is licensed under this Public License, and include the text of, or the URI or hyperlink to, this Public License.
98 |
99 | 2. You may satisfy the conditions in Section 3(a)(1) in any reasonable manner based on the medium, means, and context in which You Share the Licensed Material. For example, it may be reasonable to satisfy the conditions by providing a URI or hyperlink to a resource that includes the required information.
100 |
101 | 3. If requested by the Licensor, You must remove any of the information required by Section 3(a)(1)(A) to the extent reasonably practicable.
102 |
103 | 4. If You Share Adapted Material You produce, the Adapter's License You apply must not prevent recipients of the Adapted Material from complying with this Public License.
104 |
105 | ### Section 4 – Sui Generis Database Rights.
106 |
107 | Where the Licensed Rights include Sui Generis Database Rights that apply to Your use of the Licensed Material:
108 |
109 | a. for the avoidance of doubt, Section 2(a)(1) grants You the right to extract, reuse, reproduce, and Share all or a substantial portion of the contents of the database for NonCommercial purposes only;
110 |
111 | b. if You include all or a substantial portion of the database contents in a database in which You have Sui Generis Database Rights, then the database in which You have Sui Generis Database Rights (but not its individual contents) is Adapted Material; and
112 |
113 | c. You must comply with the conditions in Section 3(a) if You Share all or a substantial portion of the contents of the database.
114 |
115 | For the avoidance of doubt, this Section 4 supplements and does not replace Your obligations under this Public License where the Licensed Rights include other Copyright and Similar Rights.
116 |
117 | ### Section 5 – Disclaimer of Warranties and Limitation of Liability.
118 |
119 | a. __Unless otherwise separately undertaken by the Licensor, to the extent possible, the Licensor offers the Licensed Material as-is and as-available, and makes no representations or warranties of any kind concerning the Licensed Material, whether express, implied, statutory, or other. This includes, without limitation, warranties of title, merchantability, fitness for a particular purpose, non-infringement, absence of latent or other defects, accuracy, or the presence or absence of errors, whether or not known or discoverable. Where disclaimers of warranties are not allowed in full or in part, this disclaimer may not apply to You.__
120 |
121 | b. __To the extent possible, in no event will the Licensor be liable to You on any legal theory (including, without limitation, negligence) or otherwise for any direct, special, indirect, incidental, consequential, punitive, exemplary, or other losses, costs, expenses, or damages arising out of this Public License or use of the Licensed Material, even if the Licensor has been advised of the possibility of such losses, costs, expenses, or damages. Where a limitation of liability is not allowed in full or in part, this limitation may not apply to You.__
122 |
123 | c. The disclaimer of warranties and limitation of liability provided above shall be interpreted in a manner that, to the extent possible, most closely approximates an absolute disclaimer and waiver of all liability.
124 |
125 | ### Section 6 – Term and Termination.
126 |
127 | a. This Public License applies for the term of the Copyright and Similar Rights licensed here. However, if You fail to comply with this Public License, then Your rights under this Public License terminate automatically.
128 |
129 | b. Where Your right to use the Licensed Material has terminated under Section 6(a), it reinstates:
130 |
131 | 1. automatically as of the date the violation is cured, provided it is cured within 30 days of Your discovery of the violation; or
132 |
133 | 2. upon express reinstatement by the Licensor.
134 |
135 | For the avoidance of doubt, this Section 6(b) does not affect any right the Licensor may have to seek remedies for Your violations of this Public License.
136 |
137 | c. For the avoidance of doubt, the Licensor may also offer the Licensed Material under separate terms or conditions or stop distributing the Licensed Material at any time; however, doing so will not terminate this Public License.
138 |
139 | d. Sections 1, 5, 6, 7, and 8 survive termination of this Public License.
140 |
141 | ### Section 7 – Other Terms and Conditions.
142 |
143 | a. The Licensor shall not be bound by any additional or different terms or conditions communicated by You unless expressly agreed.
144 |
145 | b. Any arrangements, understandings, or agreements regarding the Licensed Material not stated herein are separate from and independent of the terms and conditions of this Public License.
146 |
147 | ### Section 8 – Interpretation.
148 |
149 | a. For the avoidance of doubt, this Public License does not, and shall not be interpreted to, reduce, limit, restrict, or impose conditions on any use of the Licensed Material that could lawfully be made without permission under this Public License.
150 |
151 | b. To the extent possible, if any provision of this Public License is deemed unenforceable, it shall be automatically reformed to the minimum extent necessary to make it enforceable. If the provision cannot be reformed, it shall be severed from this Public License without affecting the enforceability of the remaining terms and conditions.
152 |
153 | c. No term or condition of this Public License will be waived and no failure to comply consented to unless expressly agreed to by the Licensor.
154 |
155 | d. Nothing in this Public License constitutes or may be interpreted as a limitation upon, or waiver of, any privileges and immunities that apply to the Licensor or You, including from the legal processes of any jurisdiction or authority.
156 |
157 | > Creative Commons is not a party to its public licenses. Notwithstanding, Creative Commons may elect to apply one of its public licenses to material it publishes and in those instances will be considered the “Licensor.” Except for the limited purpose of indicating that material is shared under a Creative Commons public license or as otherwise permitted by the Creative Commons policies published at [creativecommons.org/policies](http://creativecommons.org/policies), Creative Commons does not authorize the use of the trademark “Creative Commons” or any other trademark or logo of Creative Commons without its prior written consent including, without limitation, in connection with any unauthorized modifications to any of its public licenses or any other arrangements, understandings, or agreements concerning use of licensed material. For the avoidance of doubt, this paragraph does not form part of the public licenses.
158 | >
159 | > Creative Commons may be contacted at creativecommons.org
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # FaceXHuBERT (ICMI '23)
2 | ### Code repository for the paper:
3 |
4 | >_FaceXHuBERT: Text-less Speech-driven E(X)pressive 3D Facial Animation Synthesis using Self-Supervised Speech Representation Learning_.
5 |
6 | > Authors: Kazi Injamamul Haque, Zerrin Yumak
7 |
8 | > [[Paper]](https://dl.acm.org/doi/pdf/10.1145/3577190.3614157) [[Project Page]](https://galib360.github.io/FaceXHuBERT/) [[Video]](https://www.youtube.com/watch?v=AkBhnNOxwE4&ab_channel=KaziInjamamulHaque)
9 |
10 |
11 | > This GitHub repository contains the PyTorch implementation of the work presented in the paper mentioned above. Given a raw audio file, FaceXHuBERT generates and renders expressive 3D facial animation. We recommend visiting the project website and watching the supplementary video.
12 |
13 |
14 |
15 |
16 |
17 | ## Environment
18 |
19 | - Windows (tested on Windows 10 and 11)
20 | - Python 3.8
21 | - PyTorch 1.10.1+cu113
22 |
23 | ## Dependencies
24 |
25 | - Check the required python packages and libraries in `environment.yml`.
26 | - [ffmpeg](https://ffmpeg.org/download.html) for Windows. [WikiHow link on how to install](https://www.wikihow.com/Install-FFmpeg-on-Windows)
27 |
28 | ## Get Started
29 |
30 | It is recommended to create a new Anaconda environment with Python 3.8. To do so, please follow the steps below in order:
31 | - Ensure that [CUDA](https://developer.nvidia.com/cuda-11-5-0-download-archive?target_os=Windows&target_arch=x86_64&target_version=10&target_type=exe_local) computing toolkit with appropriate [cudnn](https://developer.nvidia.com/rdp/cudnn-archive) (tested with CUDA 11.5) is properly installed in the system and the environment variable "CUDA_PATH" is set properly. First install CUDA, then download, extract, and copy the cudnn contents in the CUDA installation folder (e.g. C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA\v11.5).
32 | - Install [Anaconda](https://www.anaconda.com/products/distribution) for Windows.
33 | - Install ffmpeg. [WikiHow link on how to install ffmpeg](https://www.wikihow.com/Install-FFmpeg-on-Windows)
34 | - Clone this repository.
35 | - Open the Anaconda Prompt CLI and change into the cloned repository folder:
36 | ```
37 | cd <path-to-cloned-repository>
38 | ```
39 | - Then run the following commands in the Anaconda Prompt:
40 |
41 | ```
42 | conda env create --name FaceXHuBERT python=3.8 --file="environment.yml"
43 | conda activate FaceXHuBERT
44 | pip install torch==1.10.1+cu113 torchvision==0.11.2+cu113 torchaudio==0.10.1+cu113 -f https://download.pytorch.org/whl/torch_stable.html
45 |
46 | ```
47 | - Please make sure you run all the Python scripts below from the activated virtual environment's command line (in other words, your Python interpreter should be the one from the FaceXHuBERT environment you just created).
48 |
49 |
50 | ## Demo
51 |
52 | Download the pretrained model from [FaceXHuBERT model](https://mega.nz/file/L4BzEATa#HZ_BuV56yI4yQLQMhiml5rOLAMcxgjCEwAgcITD_09g). Put the pretrained model in the `pretrained_model/` folder.
53 |
54 | - Given a raw audio file (wav_path), animate the mesh and render the video by running the following command:
55 | ```
56 | python predict.py --subject F1 --condition F3 --wav_path "demo/wav/test.wav" --emotion 1
57 | ```
58 |
59 | Running `predict.py` will generate the rendered video in the `demo/render/video_with_audio/` folder. The predicted vertex data will be saved in the `demo/result/` folder. Try your own audio files (in .wav format) and other subjects and conditions (i.e. F1, F2, F3, F4, F5, F6, F7, F8, M1, M2, M3, M4, M5, M6).
60 |
61 | ## Data
62 | ### BIWI
63 |
64 | The [Biwi 3D Audiovisual Corpus of Affective Communication](https://data.vision.ee.ethz.ch/cvl/datasets/b3dac2.en.html) dataset is available upon request for research or academic purposes. You will need the following files from the dataset:
65 |
66 | - faces01.tgz, faces02.tgz, faces03.tgz, faces04.tgz, faces05.tgz and rest.tgz
67 | - Place all the faces0*.tgz archives in `BIWI/ForProcessing/FaceData/` folder
68 | - Place the rest.tgz archive in `BIWI/ForProcessing/rest/` folder
69 |
70 |
71 | #### Data Preparation and Data Pre-process
72 | Follow the steps below in the order they appear:
73 |
74 | - You will need [Matlab](https://mathworks.com/products/matlab.html) installed on your machine to prepare the data for pre-processing
75 | - Open the Anaconda Prompt CLI and activate the FaceXHuBERT environment in the directory `BIWI/ForProcessing/rest/`
76 | - Run the following command
77 | ```
78 | tar -xvzf rest.tgz
79 | ```
80 | - After extracting, you will see the `audio/` folder, which contains the input audio files (in .wav format) needed for network training
81 | - Run the `wav_process.py` script. This will process the `audio/` folder and copy the needed audio sequences, with proper names, to the `FaceXHuBERT/BIWI/wav/` folder for training
82 | ```
83 | python wav_process.py
84 | ```
85 | - Open the Anaconda Prompt CLI and activate the FaceXHuBERT environment in the directory `BIWI/ForProcessing/FaceData/`
86 | - Run the following command to extract all the archives, replacing `*` with 1-5 (one command per archive)
87 | ```
88 | tar -xvzf faces0*.tgz
89 | ```
90 | - After extracting, you will see a folder named `faces/`. Move all the .obj files from this folder (i.e. F1.obj-M6.obj) to the `FaceXHuBERT/BIWI/templates/` folder
91 | - Run the shell script `Extract_all.sh`. This will extract all the archives for all subjects and all sequences. You will get frame-by-frame vertex data in `frame_*.vl` binary file format
92 | - Run the Matlab script `vl2csv_recursive.m`. This will convert all the `.vl` files into `.csv` files
93 | - Run the `vertex_process.py` script. This will process the data and place the processed data in `FaceXHuBERT/BIWI/vertices_npy/` folder for network training
94 | ```
95 | python vertex_process.py
96 | ```
97 |
98 |
99 | ## Model Training
100 |
101 | ### Training and Testing
102 |
103 | - Train the model by running the following command:
104 |
105 | ```
106 | python main.py
107 | ```
108 | The predicted results for the test split will be saved in the `result/` folder. The trained models (saved at 25-epoch intervals) will be saved in the `save/` folder.
109 |
110 | ### Visualization
111 |
112 | - Run the following command to render the predicted test sequences stored in `result/`:
113 |
114 | ```
115 | python render_result.py
116 | ```
117 | The rendered videos will be saved in the `renders/render_folder/videos_with_audio/` folder.
118 |
119 | ### Evaluation
120 |
121 | - Put the ground truth test split sequences (.npy files) in `Evaluation/GroundTruth/` folder. (If you train the model using our dataset split, then the test-set is sequences 39,40,79,80 for all 14 subjects.)
122 | - Run the following command to run quantitative evaluation and render the evaluation results:
123 |
124 | ```
125 | cd Evaluation
126 | python render_quant_evaluation.py
127 | ```
128 | - The rendered videos will be saved in the `Evaluation/renders/videos_with_audio/` folder.
129 | - The computed Mean Face Vertex Error will be saved in `Evaluation/quantitative_metric.txt`
130 |
131 |
132 | ## Citation
133 |
134 | If you find this code useful for your work, please consider citing our paper:
135 | ```
136 | @inproceedings{FaceXHuBERT_Haque_ICMI23,
137 | author = {Haque, Kazi Injamamul and Yumak, Zerrin},
138 | title = {FaceXHuBERT: Text-less Speech-driven E(X)pressive 3D Facial Animation Synthesis Using Self-Supervised Speech Representation Learning},
139 | booktitle = {INTERNATIONAL CONFERENCE ON MULTIMODAL INTERACTION (ICMI ’23)},
140 | year = {2023},
141 | location = {Paris, France},
142 | numpages = {10},
143 | url = {https://doi.org/10.1145/3577190.3614157},
144 | doi = {10.1145/3577190.3614157},
145 | publisher = {ACM},
146 | address = {New York, NY, USA},
147 | }
148 | ```
149 |
150 | ## Acknowledgement
151 | We would like to thank the authors of FaceFormer for making their code available. Thanks to ETH Zurich CVL for providing us access to the _Biwi 3D Audiovisual Corpus_. The HuBERT implementation is borrowed from [Hugging Face](https://huggingface.co/).
152 |
153 | ## License
154 | This repository is released under [CC-BY-NC-4.0-International License](https://github.com/Gibberlings3/GitHub-Templates/blob/master/License-Templates/CC-BY-NC-4.0/LICENSE-CC-BY-NC-4.0.md)
--------------------------------------------------------------------------------
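The Demo section of the README above runs `predict.py` on a single audio file. Below is a minimal sketch (not part of the repository) for batch prediction over every `.wav` in `demo/wav/`, reusing the exact command-line flags documented there; the `--subject`, `--condition` and `--emotion` values are just example choices.

```
# batch_predict.py (hypothetical helper): run the documented predict.py command once
# per audio file in demo/wav/, from the repository root.
import os
import subprocess

for wav in sorted(os.listdir("demo/wav")):
    if wav.endswith(".wav"):
        subprocess.run(["python", "predict.py",
                        "--subject", "F1",
                        "--condition", "F3",
                        "--wav_path", os.path.join("demo/wav", wav),
                        "--emotion", "1"],
                       check=True)
```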
/data_loader.py:
--------------------------------------------------------------------------------
1 | import os
2 | import torch
3 | from collections import defaultdict
4 | from torch.utils import data
5 | import numpy as np
6 | import pickle
7 | from tqdm import tqdm
8 | from transformers import Wav2Vec2Processor
9 | import librosa
10 |
11 |
12 | class Dataset(data.Dataset):
13 | """Custom data.Dataset compatible with data.DataLoader."""
14 | def __init__(self, data, subjects_dict, data_type="train"):
15 | self.data = data
16 | self.len = len(self.data)
17 | self.subjects_dict = subjects_dict
18 | self.data_type = data_type
19 | self.one_hot_labels = np.eye(len(subjects_dict["train"]))
20 | self.emo_one_hot_labels = np.eye(2)
21 |
22 | def __getitem__(self, index):
23 | """Returns one data pair (source and target)."""
24 | file_name = self.data[index]["name"]
25 | audio = self.data[index]["audio"]
26 | vertice = self.data[index]["vertice"]
27 | template = self.data[index]["template"]
28 | sentence_id = int(file_name.split(".")[0].split("_")[1])
29 | if self.data_type == "train":
30 | subject = "_".join(file_name.split("_")[:-1])
31 | one_hot = self.one_hot_labels[self.subjects_dict["train"].index(subject)]
32 | else:
33 | one_hot = self.one_hot_labels
34 | if sentence_id > 40:
35 | emo_one_hot = self.emo_one_hot_labels[1]
36 | else:
37 | emo_one_hot = self.emo_one_hot_labels[0]
38 |
39 | return torch.FloatTensor(audio), vertice, torch.FloatTensor(template), torch.FloatTensor(one_hot), file_name, torch.FloatTensor(emo_one_hot)
40 |
41 | def __len__(self):
42 | return self.len
43 |
44 |
45 | def read_data(args):
46 | print("Loading data...")
47 | data = defaultdict(dict)
48 | train_data = []
49 | valid_data = []
50 | test_data = []
51 |
52 | audio_path = os.path.join(args.dataset, args.wav_path)
53 | vertices_path = os.path.join(args.dataset, args.vertices_path)
54 | processor = Wav2Vec2Processor.from_pretrained("facebook/hubert-xlarge-ls960-ft") # HuBERT uses the processor of Wav2Vec 2.0
55 |
56 | template_file = os.path.join(args.dataset, args.template_file)
57 | with open(template_file, 'rb') as fin:
58 | templates = pickle.load(fin,encoding='latin1')
59 |
60 | for r, ds, fs in os.walk(audio_path):
61 | for f in tqdm(fs):
62 | if f.endswith("wav"):
63 | wav_path = os.path.join(r,f)
64 | speech_array, sampling_rate = librosa.load(wav_path, sr=16000)
65 | input_values = processor(speech_array, return_tensors="pt", padding="longest", sampling_rate=sampling_rate).input_values
66 | key = f.replace("wav", "npy")
67 | data[key]["audio"] = input_values
68 | subject_id = "_".join(key.split("_")[:-1])
69 | temp = templates[subject_id]
70 | data[key]["name"] = f
71 | data[key]["template"] = temp.reshape((-1))
72 | vertice_path = os.path.join(vertices_path,f.replace("wav", "npy"))
73 | if not os.path.exists(vertice_path):
74 | del data[key]
75 | # print("Vertices Data Not Found! ", vertice_path)
76 | else:
77 | data[key]["vertice"] = vertice_path
78 |
79 | subjects_dict = {}
80 | subjects_dict["train"] = [i for i in args.train_subjects.split(" ")]
81 | subjects_dict["val"] = [i for i in args.val_subjects.split(" ")]
82 | subjects_dict["test"] = [i for i in args.test_subjects.split(" ")]
83 |
84 | splits = {'BIWI': {'train': list(range(1, 37)) + list(range(41, 77)), 'val': list(range(37, 39)) + list(range(77, 79)), 'test': list(range(39, 41)) + list(range(79, 81))}}
85 |
86 | for k, v in data.items():
87 | subject_id = "_".join(k.split("_")[:-1])
88 | sentence_id = int(k.split(".")[0][-2:])
89 | if subject_id in subjects_dict["train"] and sentence_id in splits[args.dataset]['train']:
90 | train_data.append(v)
91 | if subject_id in subjects_dict["val"] and sentence_id in splits[args.dataset]['val']:
92 | valid_data.append(v)
93 | if subject_id in subjects_dict["test"] and sentence_id in splits[args.dataset]['test']:
94 | test_data.append(v)
95 |
96 | print(len(train_data), len(valid_data), len(test_data))
97 | return train_data, valid_data, test_data, subjects_dict
98 |
99 |
100 | def get_dataloaders(args):
101 | dataset = {}
102 | train_data, valid_data, test_data, subjects_dict = read_data(args)
103 | train_data = Dataset(train_data, subjects_dict,"train")
104 | dataset["train"] = data.DataLoader(dataset=train_data, batch_size=1, shuffle=True)
105 | valid_data = Dataset(valid_data, subjects_dict,"val")
106 | dataset["valid"] = data.DataLoader(dataset=valid_data, batch_size=1, shuffle=False)
107 | test_data = Dataset(test_data, subjects_dict,"test")
108 | dataset["test"] = data.DataLoader(dataset=test_data, batch_size=1, shuffle=False)
109 | return dataset
110 |
111 |
112 | if __name__ == "__main__":
113 | get_dataloaders()
114 |
115 |
--------------------------------------------------------------------------------
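Below is a minimal usage sketch for `get_dataloaders()` (not part of the repository); the bare `get_dataloaders()` call in the `__main__` guard above would need such an args object. The attribute names mirror what `read_data()` accesses, while the paths and subject lists are placeholders, not the official configuration from `main.py`.

```
# loader_usage.py (hypothetical): build the BIWI dataloaders and inspect one training item.
from types import SimpleNamespace
from data_loader import get_dataloaders

args = SimpleNamespace(
    dataset="BIWI",                       # also used as the key into the split table
    wav_path="wav",
    vertices_path="vertices_npy",
    template_file="templates_scaled.pkl",
    train_subjects="F1 F2 F3",            # placeholder subject IDs
    val_subjects="F1 F2 F3",
    test_subjects="F1 F2 F3",
)
loaders = get_dataloaders(args)
audio, vertice_path, template, one_hot, file_name, emo_one_hot = next(iter(loaders["train"]))
print(file_name, audio.shape, template.shape, emo_one_hot)
```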
/demo/render/frames/README.md:
--------------------------------------------------------------------------------
1 | This folder will contain the temporary frames of the renders.
--------------------------------------------------------------------------------
/demo/render/video_with_audio/README.md:
--------------------------------------------------------------------------------
1 | This folder will contain rendered video with audio.
--------------------------------------------------------------------------------
/demo/render/video_wo_audio/README.md:
--------------------------------------------------------------------------------
1 | This folder will contain rendered video without audio.
--------------------------------------------------------------------------------
/demo/result/README.md:
--------------------------------------------------------------------------------
1 | This folder will contain the generated vertex data results in `.npy` format.
--------------------------------------------------------------------------------
/demo/wav/README.md:
--------------------------------------------------------------------------------
1 | Put your own audio files in .wav format in this folder to run the demo script `predict.py`.
--------------------------------------------------------------------------------
/demo/wav/test.wav:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/galib360/FaceXHuBERT/f54f9a99282b6a3b0b99770cbc50cb7ea7f3746b/demo/wav/test.wav
--------------------------------------------------------------------------------
/environment.yml:
--------------------------------------------------------------------------------
1 | name: FaceXHuBERT
2 | channels:
3 | - defaults
4 | - conda-forge
5 | dependencies:
6 | - argon2-cffi=20.1.0=py38h2bbff1b_1
7 | - async_generator=1.10=pyhd3eb1b0_0
8 | - attrs=21.2.0=pyhd3eb1b0_0
9 | - backcall=0.2.0=pyhd3eb1b0_0
10 | - bleach=4.0.0=pyhd3eb1b0_0
11 | - ca-certificates=2021.10.26=haa95532_2
12 | - certifi=2021.10.8=py38haa95532_2
13 | - cffi=1.15.0=py38h2bbff1b_0
14 | - colorama=0.4.4=pyhd3eb1b0_0
15 | - console_shortcut=0.1.1=4
16 | - debugpy=1.5.1=py38hd77b12b_0
17 | - defusedxml=0.7.1=pyhd3eb1b0_0
18 | - entrypoints=0.3=py38_0
19 | - importlib-metadata=4.8.2=py38haa95532_0
20 | - importlib_metadata=4.8.2=hd3eb1b0_0
21 | - ipykernel=6.4.1=py38haa95532_1
22 | - ipython=7.29.0=py38hd4e2768_0
23 | - ipython_genutils=0.2.0=pyhd3eb1b0_1
24 | - jedi=0.18.0=py38haa95532_1
25 | - jsonschema=3.2.0=pyhd3eb1b0_2
26 | - jupyter_client=7.1.0=pyhd3eb1b0_0
27 | - jupyter_core=4.9.1=py38haa95532_0
28 | - jupyterlab_pygments=0.1.2=py_0
29 | - m2w64-gcc-libgfortran=5.3.0=6
30 | - m2w64-gcc-libs=5.3.0=7
31 | - m2w64-gcc-libs-core=5.3.0=7
32 | - m2w64-gmp=6.1.0=2
33 | - m2w64-libwinpthread-git=5.0.0.4634.697f757=2
34 | - markupsafe=2.0.1=py38h2bbff1b_0
35 | - matplotlib-inline=0.1.2=pyhd3eb1b0_2
36 | - mistune=0.8.4=py38he774522_1000
37 | - msys2-conda-epoch=20160418=1
38 | - nbclient=0.5.3=pyhd3eb1b0_0
39 | - nest-asyncio=1.5.1=pyhd3eb1b0_0
40 | - notebook=6.4.6=py38haa95532_0
41 | - openssl=1.1.1m=h2bbff1b_0
42 | - packaging=21.3=pyhd3eb1b0_0
43 | - pandocfilters=1.4.3=py38haa95532_1
44 | - parso=0.8.3=pyhd3eb1b0_0
45 | - pickleshare=0.7.5=pyhd3eb1b0_1003
46 | - pip=21.2.2=py38haa95532_0
47 | - prometheus_client=0.12.0=pyhd3eb1b0_0
48 | - prompt-toolkit=3.0.20=pyhd3eb1b0_0
49 | - pycparser=2.21=pyhd3eb1b0_0
50 | - pygments=2.10.0=pyhd3eb1b0_0
51 | - pyparsing=3.0.4=pyhd3eb1b0_0
52 | - pyrsistent=0.18.0=py38h196d8e1_0
53 | - python=3.8.12=h6244533_0
54 | - python-dateutil=2.8.2=pyhd3eb1b0_0
55 | - pywin32=302=py38h827c3e9_1
56 | - pywinpty=0.5.7=py38_0
57 | - pyzmq=22.3.0=py38hd77b12b_2
58 | - send2trash=1.8.0=pyhd3eb1b0_1
59 | - sqlite=3.37.0=h2bbff1b_0
60 | - terminado=0.9.4=py38haa95532_0
61 | - testpath=0.5.0=pyhd3eb1b0_0
62 | - tornado=6.1=py38h2bbff1b_0
63 | - traitlets=5.1.1=pyhd3eb1b0_0
64 | - wcwidth=0.2.5=pyhd3eb1b0_0
65 | - webencodings=0.5.1=py38_1
66 | - wheel=0.37.0=pyhd3eb1b0_1
67 | - wincertstore=0.2=py38haa95532_2
68 | - winpty=0.4.3=4
69 | - zipp=3.6.0=pyhd3eb1b0_0
70 | - pip:
71 | - absl-py==0.15.0
72 | - ansiwrap==0.8.4
73 | - anyio==3.6.1
74 | - astunparse==1.6.3
75 | - babel==2.10.3
76 | - beautifulsoup4==4.11.1
77 | - biopython==1.79
78 | - cachetools==4.2.4
79 | - click==8.1.1
80 | - configparser==5.2.0
81 | - cycler==0.11.0
82 | - decorator==4.4.2
83 | - deprecation==2.1.0
84 | - et-xmlfile==1.1.0
85 | - fastjsonschema==2.15.3
86 | - ffmpeg-python==0.2.0
87 | - filelock==3.6.0
88 | - fonttools==4.28.5
89 | - freetype-py==2.2.0
90 | - future==0.18.2
91 | - fvcore==0.1.5.post20220305
92 | - google-auth==2.3.3
93 | - google-auth-oauthlib==0.4.6
94 | - google-pasta==0.2.0
95 | - grpcio==1.43.0
96 | - h5py==2.10.0
97 | - huggingface-hub==0.0.8
98 | - hyperpyyaml==1.0.1
99 | - imageio==2.16.1
100 | - imageio-ffmpeg==0.4.7
101 | - iopath==0.1.9
102 | - ipywidgets==7.7.0
103 | - jinja2==3.1.2
104 | - json5==0.9.8
105 | - jsonlines==3.0.0
106 | - jupyter-packaging==0.12.2
107 | - jupyter-server==1.17.1
108 | - jupyterlab==3.4.3
109 | - jupyterlab-server==2.14.0
110 | - jupyterlab-widgets==1.1.0
111 | - keras-preprocessing==1.1.2
112 | - kiwisolver==1.3.2
113 | - librosa==0.8.1
114 | - markdown==3.3.6
115 | - matplotlib==3.5.1
116 | - nbclassic==0.3.7
117 | - nbconvert==6.5.0
118 | - nbformat==5.4.0
119 | - networkx==2.7.1
120 | - notebook-shim==0.1.0
121 | - numpy==1.20.0
122 | - oauthlib==3.1.1
123 | - opencv-python==4.5.5.62
124 | - openpyxl==3.0.9
125 | - opt-einsum==3.3.0
126 | - pandas==1.3.5
127 | - pandas-stubs==1.2.0.57
128 | - pillow==9.0.0
129 | - portalocker==2.4.0
130 | - praatio==5.1.1
131 | - proglog==0.1.10
132 | - protobuf==3.19.3
133 | - pyasn1==0.4.8
134 | - pyasn1-modules==0.2.8
135 | - pyglet==1.5.26
136 | - pymeshlab==0.2.1
137 | - pymysql==1.0.2
138 | - pyopengl==3.1.0
139 | - pyrender==0.1.45
140 | - pytz==2021.3
141 | - pyyaml==6.0
142 | - regex==2022.3.15
143 | - requests==2.28.0
144 | - requests-oauthlib==1.3.0
145 | - rsa==4.8
146 | - ruamel-yaml==0.17.21
147 | - ruamel-yaml-clib==0.2.6
148 | - sacremoses==0.0.49
149 | - sentencepiece==0.1.96
150 | - setuptools==62.6.0
151 | - shutils==0.1.0
152 | - six==1.15.0
153 | - sniffio==1.2.0
154 | - soupsieve==2.3.2.post1
155 | - tabulate==0.8.9
156 | - termcolor==1.1.0
157 | - textwrap3==0.9.2
158 | - tinycss2==1.1.1
159 | - tokenizers==0.10.3
160 | - tomlkit==0.11.0
161 | - tqdm==4.63.0
162 | - transformers==4.7.0
163 | - trimesh==3.12.6
164 | - typing-extensions==4.0.1
165 | - urllib3==1.26.8
166 | - websocket-client==1.3.3
167 | - werkzeug==2.0.2
168 | - widgetsnbextension==3.6.0
169 | - yacs==0.1.8
170 |
171 |
--------------------------------------------------------------------------------
/faceXhubert.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | from hubert.modeling_hubert import HubertModel
4 | import torch.nn.functional as F
5 |
6 |
7 | def inputRepresentationAdjustment(audio_embedding_matrix, vertex_matrix, ifps, ofps):
8 |     if ifps % ofps == 0:  # audio feature rate is an integer multiple of the vertex frame rate
9 |         factor = -1 * (-ifps // ofps)  # ceiling division: ceil(ifps / ofps)
10 | if audio_embedding_matrix.shape[1] % 2 != 0:
11 | audio_embedding_matrix = audio_embedding_matrix[:, :audio_embedding_matrix.shape[1] - 1]
12 |
13 | if audio_embedding_matrix.shape[1] > vertex_matrix.shape[1] * 2:
14 | audio_embedding_matrix = audio_embedding_matrix[:, :vertex_matrix.shape[1] * 2]
15 |
16 | elif audio_embedding_matrix.shape[1] < vertex_matrix.shape[1] * 2:
17 | vertex_matrix = vertex_matrix[:, :audio_embedding_matrix.shape[1] // 2]
18 |     else:  # resample the audio features to factor * (vertex frame count) via linear interpolation
19 |         factor = -1 * (-ifps // ofps)  # ceiling division: ceil(ifps / ofps)
20 | audio_embedding_seq_len = vertex_matrix.shape[1] * factor
21 | audio_embedding_matrix = audio_embedding_matrix.transpose(1, 2)
22 | audio_embedding_matrix = F.interpolate(audio_embedding_matrix, size=audio_embedding_seq_len, align_corners=True, mode='linear')
23 | audio_embedding_matrix = audio_embedding_matrix.transpose(1, 2)
24 |
25 | frame_num = vertex_matrix.shape[1]
26 | audio_embedding_matrix = torch.reshape(audio_embedding_matrix, (1, audio_embedding_matrix.shape[1] // factor, audio_embedding_matrix.shape[2] * factor))
27 |
28 | return audio_embedding_matrix, vertex_matrix, frame_num
29 |
30 |
31 | class FaceXHuBERT(nn.Module):
32 | def __init__(self, args):
33 | super(FaceXHuBERT, self).__init__()
34 | """
35 | audio: (batch_size, raw_wav)
36 | template: (batch_size, V*3)
37 | vertice: (batch_size, seq_len, V*3)
38 | """
39 | self.dataset = args.dataset
40 | self.i_fps = args.input_fps # audio fps (input to the network)
41 | self.o_fps = args.output_fps # 4D Scan fps (output or target)
42 | self.gru_layer_dim = 2
43 | self.gru_hidden_dim = args.feature_dim
44 |
45 | # Audio Encoder
46 | self.audio_encoder = HubertModel.from_pretrained("facebook/hubert-base-ls960")
47 | self.audio_dim = self.audio_encoder.encoder.config.hidden_size
48 | self.audio_encoder.feature_extractor._freeze_parameters()
49 |
50 | frozen_layers = [0,1]
51 |
52 | for name, param in self.audio_encoder.named_parameters():
53 | if name.startswith("feature_projection"):
54 | param.requires_grad = False
55 | if name.startswith("encoder.layers"):
56 | layer = int(name.split(".")[2])
57 | if layer in frozen_layers:
58 | param.requires_grad = False
59 |
60 |         # Vertex Decoder
61 | # GRU module
62 | self.gru = nn.GRU(self.audio_dim*2, args.feature_dim, self.gru_layer_dim, batch_first=True, dropout=0.3)
63 | # Fully connected layer
64 | self.fc = nn.Linear(args.feature_dim, args.vertice_dim)
65 | nn.init.constant_(self.fc.weight, 0)
66 | nn.init.constant_(self.fc.bias, 0)
67 |
68 | # Subject embedding, S
69 | self.obj_vector = nn.Linear(len(args.train_subjects.split()), args.feature_dim, bias=False)
70 |
71 | # Emotion embedding, E
72 | self.emo_vector = nn.Linear(2, args.feature_dim, bias=False)
73 |
74 | def forward(self, audio, template, vertice, one_hot, emo_one_hot, criterion):
75 |
76 | template = template.unsqueeze(1)
77 | obj_embedding = self.obj_vector(one_hot)
78 | emo_embedding = self.emo_vector(emo_one_hot)
79 |
80 | hidden_states = audio
81 |
82 | hidden_states = self.audio_encoder(hidden_states).last_hidden_state
83 |
84 | hidden_states, vertice, frame_num = inputRepresentationAdjustment(hidden_states, vertice, self.i_fps, self.o_fps)
85 |
86 | hidden_states = hidden_states[:, :frame_num]
87 |
88 | h0 = torch.zeros(self.gru_layer_dim, hidden_states.shape[0], self.gru_hidden_dim).requires_grad_().cuda()
89 |
90 | vertice_out, _ = self.gru(hidden_states, h0)
91 | vertice_out = vertice_out * obj_embedding
92 | vertice_out = vertice_out * emo_embedding
93 |
94 | vertice_out = self.fc(vertice_out)
95 |
96 | vertice_out = vertice_out + template
97 |
98 | loss = criterion(vertice_out, vertice)
99 | loss = torch.mean(loss)
100 | return loss
101 |
102 | def predict(self, audio, template, one_hot, emo_one_hot):
103 | template = template.unsqueeze(1)
104 | obj_embedding = self.obj_vector(one_hot)
105 | emo_embedding = self.emo_vector(emo_one_hot)
106 | hidden_states = audio
107 | hidden_states = self.audio_encoder(hidden_states).last_hidden_state
108 |
109 | if hidden_states.shape[1] % 2 != 0:
110 | hidden_states = hidden_states[:, :hidden_states.shape[1]-1]
111 | hidden_states = torch.reshape(hidden_states, (1, hidden_states.shape[1] // 2, hidden_states.shape[2] * 2))
112 | h0 = torch.zeros(self.gru_layer_dim, hidden_states.shape[0], self.gru_hidden_dim).requires_grad_().cuda()
113 |
114 | vertice_out, _ = self.gru(hidden_states, h0)
115 | vertice_out = vertice_out * obj_embedding
116 | vertice_out = vertice_out * emo_embedding
117 |
118 | vertice_out = self.fc(vertice_out)
119 |
120 | vertice_out = vertice_out + template
121 |
122 | return vertice_out
123 |
--------------------------------------------------------------------------------
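The frame-rate adjustment in `inputRepresentationAdjustment` above is easiest to follow with concrete shapes. The sketch below replays the `ifps % ofps == 0` branch for the common case of 50 Hz audio features driving 25 fps vertex sequences (`factor = 2`); all tensor sizes are illustrative assumptions rather than values taken from this repository.

import torch

# Illustrative shapes: 1 clip, 120 audio feature frames (50 Hz), hidden size 768,
# and 60 vertex frames (25 fps) with an arbitrary flattened vertex dimension.
audio_feats = torch.randn(1, 120, 768)
vertices = torch.randn(1, 60, 70110)

factor = 2  # -1 * (-50 // 25), i.e. ceil(ifps / ofps)

# Drop a trailing frame if the audio length is odd, then make both sequences consistent:
# at most two audio frames per vertex frame, and vice versa.
if audio_feats.shape[1] % 2 != 0:
    audio_feats = audio_feats[:, :audio_feats.shape[1] - 1]
audio_feats = audio_feats[:, :vertices.shape[1] * factor]
vertices = vertices[:, :audio_feats.shape[1] // factor]

# Pair adjacent audio frames so each GRU input step corresponds to one vertex frame.
paired = torch.reshape(audio_feats, (1, audio_feats.shape[1] // factor, audio_feats.shape[2] * factor))
print(paired.shape)  # torch.Size([1, 60, 1536])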
/hubert/activations.py:
--------------------------------------------------------------------------------
1 | # Copyright 2020 The HuggingFace Team. All rights reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | import math
16 |
17 | import torch
18 | from packaging import version
19 | from torch import nn
20 |
21 | from .utils import logging
22 |
23 |
24 | logger = logging.get_logger(__name__)
25 |
26 |
27 | def _gelu_python(x):
28 | """
29 | Original Implementation of the GELU activation function in Google BERT repo when initially created. For
30 | information: OpenAI GPT's GELU is slightly different (and gives slightly different results): 0.5 * x * (1 +
31 | torch.tanh(math.sqrt(2 / math.pi) * (x + 0.044715 * torch.pow(x, 3)))) This is now written in C in nn.functional
32 | Also see the Gaussian Error Linear Units paper: https://arxiv.org/abs/1606.08415
33 | """
34 | return x * 0.5 * (1.0 + torch.erf(x / math.sqrt(2.0)))
35 |
36 |
37 | def gelu_new(x):
38 | """
39 | Implementation of the GELU activation function currently in Google BERT repo (identical to OpenAI GPT). Also see
40 | the Gaussian Error Linear Units paper: https://arxiv.org/abs/1606.08415
41 | """
42 | return 0.5 * x * (1.0 + torch.tanh(math.sqrt(2.0 / math.pi) * (x + 0.044715 * torch.pow(x, 3.0))))
43 |
44 |
45 | if version.parse(torch.__version__) < version.parse("1.4"):
46 | gelu = _gelu_python
47 | else:
48 | gelu = nn.functional.gelu
49 |
50 |
51 | def gelu_fast(x):
52 | return 0.5 * x * (1.0 + torch.tanh(x * 0.7978845608 * (1.0 + 0.044715 * x * x)))
53 |
54 |
55 | def quick_gelu(x):
56 | return x * torch.sigmoid(1.702 * x)
57 |
58 |
59 | def _silu_python(x):
60 | """
61 | See Gaussian Error Linear Units (Hendrycks et al., https://arxiv.org/abs/1606.08415) where the SiLU (Sigmoid Linear
62 | Unit) was originally introduced and coined, and see Sigmoid-Weighted Linear Units for Neural Network Function
63 | Approximation in Reinforcement Learning (Elfwing et al., https://arxiv.org/abs/1702.03118) and Swish: a Self-Gated
64 | Activation Function (Ramachandran et al., https://arxiv.org/abs/1710.05941v1) where the SiLU was experimented with
65 | later.
66 | """
67 | return x * torch.sigmoid(x)
68 |
69 |
70 | if version.parse(torch.__version__) < version.parse("1.7"):
71 | silu = _silu_python
72 | else:
73 | silu = nn.functional.silu
74 |
75 |
76 | def mish(x):
77 | return x * torch.tanh(nn.functional.softplus(x))
78 |
79 |
80 | def linear_act(x):
81 | return x
82 |
83 |
84 | ACT2FN = {
85 | "relu": nn.functional.relu,
86 | "silu": silu,
87 | "swish": silu,
88 | "gelu": gelu,
89 | "tanh": torch.tanh,
90 | "gelu_new": gelu_new,
91 | "gelu_fast": gelu_fast,
92 | "quick_gelu": quick_gelu,
93 | "mish": mish,
94 | "linear": linear_act,
95 | "sigmoid": torch.sigmoid,
96 | }
97 |
98 |
99 | def get_activation(activation_string):
100 | if activation_string in ACT2FN:
101 | return ACT2FN[activation_string]
102 | else:
103 | raise KeyError(f"function {activation_string} not found in ACT2FN mapping {list(ACT2FN.keys())}")
104 |
--------------------------------------------------------------------------------
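A quick usage sketch for the activation registry above (assuming the `hubert` package is importable from the repository root):

import torch
from hubert.activations import get_activation, ACT2FN

x = torch.linspace(-3.0, 3.0, steps=7)

gelu_new = get_activation("gelu_new")
print(gelu_new(x))              # tanh-based GELU approximation
print(ACT2FN["quick_gelu"](x))  # x * sigmoid(1.702 * x)

# Unknown names raise a KeyError that lists the supported keys.
try:
    get_activation("not_an_activation")
except KeyError as err:
    print(err)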
/hubert/configuration_hubert.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2021 The Fairseq Authors and The HuggingFace Inc. team. All rights reserved.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | """ Hubert model configuration """
16 |
17 | from .configuration_utils import PretrainedConfig
18 | from .utils import logging
19 |
20 |
21 | logger = logging.get_logger(__name__)
22 |
23 | HUBERT_PRETRAINED_CONFIG_ARCHIVE_MAP = {
24 | "facebook/hubert-base-ls960": "https://huggingface.co/facebook/hubert-base-ls960/resolve/main/config.json",
25 | # See all Hubert models at https://huggingface.co/models?filter=hubert
26 | }
27 |
28 |
29 | class HubertConfig(PretrainedConfig):
30 | r"""
31 | This is the configuration class to store the configuration of a :class:`~transformers.HubertModel`. It is used to
32 |     instantiate a Hubert model according to the specified arguments, defining the model architecture. Instantiating a
33 | configuration with the defaults will yield a similar configuration to that of the Hubert
34 |     `facebook/hubert-base-ls960 <https://huggingface.co/facebook/hubert-base-ls960>`__ architecture.
35 |
36 | Configuration objects inherit from :class:`~transformers.PretrainedConfig` and can be used to control the model
37 | outputs. Read the documentation from :class:`~transformers.PretrainedConfig` for more information.
38 |
39 |
40 | Args:
41 | vocab_size (:obj:`int`, `optional`, defaults to 32):
42 | Vocabulary size of the Hubert model. Defines the number of different tokens that can be represented by the
43 |             :obj:`inputs_ids` passed when calling :class:`~transformers.HubertModel`.
46 | hidden_size (:obj:`int`, `optional`, defaults to 768):
47 | Dimensionality of the encoder layers and the pooler layer.
48 | num_hidden_layers (:obj:`int`, `optional`, defaults to 12):
49 | Number of hidden layers in the Transformer encoder.
50 | num_attention_heads (:obj:`int`, `optional`, defaults to 12):
51 | Number of attention heads for each attention layer in the Transformer encoder.
52 | intermediate_size (:obj:`int`, `optional`, defaults to 3072):
53 | Dimensionality of the "intermediate" (i.e., feed-forward) layer in the Transformer encoder.
54 | hidden_act (:obj:`str` or :obj:`function`, `optional`, defaults to :obj:`"gelu"`):
55 | The non-linear activation function (function or string) in the encoder and pooler. If string,
56 | :obj:`"gelu"`, :obj:`"relu"`, :obj:`"selu"` and :obj:`"gelu_new"` are supported.
57 | hidden_dropout_prob (:obj:`float`, `optional`, defaults to 0.1):
58 |             The dropout probability for all fully connected layers in the embeddings, encoder, and pooler.
59 | attention_probs_dropout_prob (:obj:`float`, `optional`, defaults to 0.1):
60 | The dropout ratio for the attention probabilities.
61 | initializer_range (:obj:`float`, `optional`, defaults to 0.02):
62 | The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
63 | layer_norm_eps (:obj:`float`, `optional`, defaults to 1e-12):
64 | The epsilon used by the layer normalization layers.
65 | feat_extract_norm (:obj:`str`, `optional`, defaults to :obj:`"group"`):
66 | The norm to be applied to 1D convolutional layers in feature extractor. One of :obj:`"group"` for group
67 | normalization of only the first 1D convolutional layer or :obj:`"layer"` for layer normalization of all 1D
68 | convolutional layers.
69 | feat_extract_dropout (:obj:`float`, `optional`, defaults to 0.0):
70 |             The dropout probability for all 1D convolutional layers in the feature extractor.
71 |         feat_extract_activation (:obj:`str`, `optional`, defaults to :obj:`"gelu"`):
72 | The non-linear activation function (function or string) in the 1D convolutional layers of the feature
73 | extractor. If string, :obj:`"gelu"`, :obj:`"relu"`, :obj:`"selu"` and :obj:`"gelu_new"` are supported.
74 | conv_dim (:obj:`Tuple[int]`, `optional`, defaults to :obj:`(512, 512, 512, 512, 512, 512, 512)`):
75 | A tuple of integers defining the number of input and output channels of each 1D convolutional layer in the
76 | feature extractor. The length of `conv_dim` defines the number of 1D convolutional layers.
77 | conv_stride (:obj:`Tuple[int]`, `optional`, defaults to :obj:`(5, 2, 2, 2, 2, 2, 2)`):
78 | A tuple of integers defining the stride of each 1D convolutional layer in the feature extractor. The length
79 |             of `conv_stride` defines the number of convolutional layers and has to match the length of `conv_dim`.
80 | conv_kernel (:obj:`Tuple[int]`, `optional`, defaults to :obj:`(10, 3, 3, 3, 3, 3, 3)`):
81 | A tuple of integers defining the kernel size of each 1D convolutional layer in the feature extractor. The
82 |             length of `conv_kernel` defines the number of convolutional layers and has to match the length of
83 | `conv_dim`.
84 | conv_bias (:obj:`bool`, `optional`, defaults to :obj:`False`):
85 | Whether the 1D convolutional layers have a bias.
86 | num_conv_pos_embeddings (:obj:`int`, `optional`, defaults to 128):
87 | Number of convolutional positional embeddings. Defines the kernel size of 1D convolutional positional
88 | embeddings layer.
89 | num_conv_pos_embedding_groups (:obj:`int`, `optional`, defaults to 16):
90 | Number of groups of 1D convolutional positional embeddings layer.
91 | do_stable_layer_norm (:obj:`bool`, `optional`, defaults to :obj:`False`):
92 |             Whether to apply the `stable` layer norm architecture of the Transformer encoder. ``do_stable_layer_norm is
93 | True`` corresponds to applying layer norm before the attention layer, whereas ``do_stable_layer_norm is
94 | False`` corresponds to applying layer norm after the attention layer.
95 | apply_spec_augment (:obj:`bool`, `optional`, defaults to :obj:`True`):
96 | Whether to apply *SpecAugment* data augmentation to the outputs of the feature extractor. For reference see
97 | `SpecAugment: A Simple Data Augmentation Method for Automatic Speech Recognition
98 |             <https://arxiv.org/abs/1904.08779>`__.
99 | mask_time_prob (:obj:`float`, `optional`, defaults to 0.05):
100 |             Probability of each feature vector along the time axis to be chosen as the start of the vector span to be
101 | masked. Approximately ``mask_time_prob * sequence_length // mask_time_length`` feature vectors will be
102 | masked along the time axis. This is only relevant if ``apply_spec_augment is True``.
103 | mask_time_length (:obj:`int`, `optional`, defaults to 10):
104 | Length of vector span along the time axis.
105 | mask_feature_prob (:obj:`float`, `optional`, defaults to 0.0):
106 |             Probability of each feature vector along the feature axis to be chosen as the start of the vector span to
107 |             be masked. Approximately ``mask_feature_prob * hidden_size // mask_feature_length`` feature vectors will be
108 |             masked along the feature axis. This is only relevant if ``apply_spec_augment is True``.
109 | mask_feature_length (:obj:`int`, `optional`, defaults to 10):
110 | Length of vector span along the feature axis.
111 | ctc_loss_reduction (:obj:`str`, `optional`, defaults to :obj:`"sum"`):
112 | Specifies the reduction to apply to the output of ``torch.nn.CTCLoss``. Only relevant when training an
113 | instance of :class:`~transformers.HubertForCTC`.
114 | ctc_zero_infinity (:obj:`bool`, `optional`, defaults to :obj:`False`):
115 | Whether to zero infinite losses and the associated gradients of ``torch.nn.CTCLoss``. Infinite losses
116 | mainly occur when the inputs are too short to be aligned to the targets. Only relevant when training an
117 | instance of :class:`~transformers.HubertForCTC`.
118 | gradient_checkpointing (:obj:`bool`, `optional`, defaults to :obj:`False`):
119 | If True, use gradient checkpointing to save memory at the expense of slower backward pass.
120 |
121 | Example::
122 |
123 | >>> from transformers import HubertModel, HubertConfig
124 |
125 | >>> # Initializing a Hubert facebook/hubert-base-ls960 style configuration
126 | >>> configuration = HubertConfig()
127 |
128 | >>> # Initializing a model from the facebook/hubert-base-ls960 style configuration
129 | >>> model = HubertModel(configuration)
130 |
131 | >>> # Accessing the model configuration
132 | >>> configuration = model.config
133 | """
134 | model_type = "hubert"
135 |
136 | def __init__(
137 | self,
138 | vocab_size=32,
139 | hidden_size=768,
140 | num_hidden_layers=12,
141 | num_attention_heads=12,
142 | intermediate_size=3072,
143 | hidden_act="gelu",
144 | hidden_dropout=0.1,
145 | activation_dropout=0.1,
146 | attention_dropout=0.1,
147 | feat_proj_dropout=0.1,
148 | final_dropout=0.1,
149 | layerdrop=0.1,
150 | initializer_range=0.02,
151 | layer_norm_eps=1e-5,
152 | feat_extract_norm="group",
153 | feat_extract_activation="gelu",
154 | conv_dim=(512, 512, 512, 512, 512, 512, 512),
155 | conv_stride=(5, 2, 2, 2, 2, 2, 2),
156 | conv_kernel=(10, 3, 3, 3, 3, 2, 2),
157 | conv_bias=False,
158 | num_conv_pos_embeddings=128,
159 | num_conv_pos_embedding_groups=16,
160 | do_stable_layer_norm=False,
161 | apply_spec_augment=True,
162 | mask_time_prob=0.05,
163 | mask_time_length=10,
164 | mask_feature_prob=0.0,
165 | mask_feature_length=10,
166 | ctc_loss_reduction="sum",
167 | ctc_zero_infinity=False,
168 | gradient_checkpointing=False,
169 | pad_token_id=0,
170 | bos_token_id=1,
171 | eos_token_id=2,
172 | **kwargs
173 | ):
174 | super().__init__(**kwargs, pad_token_id=pad_token_id, bos_token_id=bos_token_id, eos_token_id=eos_token_id)
175 | self.hidden_size = hidden_size
176 | self.feat_extract_norm = feat_extract_norm
177 | self.feat_extract_activation = feat_extract_activation
178 | self.conv_dim = list(conv_dim)
179 | self.conv_stride = list(conv_stride)
180 | self.conv_kernel = list(conv_kernel)
181 | self.conv_bias = conv_bias
182 | self.num_conv_pos_embeddings = num_conv_pos_embeddings
183 | self.num_conv_pos_embedding_groups = num_conv_pos_embedding_groups
184 | self.num_feat_extract_layers = len(self.conv_dim)
185 | self.num_hidden_layers = num_hidden_layers
186 | self.intermediate_size = intermediate_size
187 | self.hidden_act = hidden_act
188 | self.num_attention_heads = num_attention_heads
189 | self.hidden_dropout = hidden_dropout
190 | self.attention_dropout = attention_dropout
191 | self.activation_dropout = activation_dropout
192 | self.feat_proj_dropout = feat_proj_dropout
193 | self.final_dropout = final_dropout
194 | self.layerdrop = layerdrop
195 | self.layer_norm_eps = layer_norm_eps
196 | self.initializer_range = initializer_range
197 | self.vocab_size = vocab_size
198 | self.do_stable_layer_norm = do_stable_layer_norm
199 | self.gradient_checkpointing = gradient_checkpointing
200 |
201 | if (
202 | (len(self.conv_stride) != self.num_feat_extract_layers)
203 | or (len(self.conv_kernel) != self.num_feat_extract_layers)
204 | or (len(self.conv_dim) != self.num_feat_extract_layers)
205 | ):
206 | raise ValueError(
207 |                 "Configuration for convolutional layers is incorrect. "
208 |                 "It is required that `len(config.conv_dim)` == `len(config.conv_stride)` == `len(config.conv_kernel)`, "
209 |                 f"but is `len(config.conv_dim) = {len(self.conv_dim)}`, `len(config.conv_stride) "
210 |                 f"= {len(self.conv_stride)}`, `len(config.conv_kernel) = {len(self.conv_kernel)}`."
211 | )
212 |
213 | # fine-tuning config parameters for SpecAugment: https://arxiv.org/abs/1904.08779
214 | self.apply_spec_augment = apply_spec_augment
215 | self.mask_time_prob = mask_time_prob
216 | self.mask_time_length = mask_time_length
217 | self.mask_feature_prob = mask_feature_prob
218 | self.mask_feature_length = mask_feature_length
219 |
220 | # ctc loss
221 | self.ctc_loss_reduction = ctc_loss_reduction
222 | self.ctc_zero_infinity = ctc_zero_infinity
223 |
--------------------------------------------------------------------------------
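The `conv_dim` / `conv_stride` / `conv_kernel` tuples above must have matching lengths; the check at the end of `__init__` raises otherwise. A minimal sketch, assuming the local `hubert` package is on the import path:

from hubert.configuration_hubert import HubertConfig

cfg = HubertConfig()  # defaults: 7 feature-extractor layers, hidden size 768
print(cfg.num_feat_extract_layers, cfg.hidden_size)

# Mismatched lengths trigger the ValueError defined above.
try:
    HubertConfig(conv_dim=(512, 512), conv_stride=(5, 2, 2), conv_kernel=(10, 3))
except ValueError as err:
    print(err)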
/hubert/deepspeed.py:
--------------------------------------------------------------------------------
1 | # Copyright 2020 The HuggingFace Team. All rights reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | """
15 | Integration with Deepspeed
16 | """
17 |
18 | import importlib.util
19 | import io
20 | import json
21 | import weakref
22 | from copy import deepcopy
23 | from functools import partialmethod
24 |
25 | from .dependency_versions_check import dep_version_check
26 | from .file_utils import is_torch_available
27 | from .utils import logging
28 |
29 |
30 | if is_torch_available():
31 | import torch
32 |
33 | logger = logging.get_logger(__name__)
34 |
35 |
36 | def is_deepspeed_available():
37 | return importlib.util.find_spec("deepspeed") is not None
38 |
39 |
40 | class HfDeepSpeedConfig:
41 | """
42 | This object contains a DeepSpeed configuration dictionary and can be quickly queried for things like zero stage.
43 |
44 | A ``weakref`` of this object is stored in the module's globals to be able to access the config from areas where
45 | things like the Trainer object is not available (e.g. ``from_pretrained`` and ``_get_resized_embeddings``).
46 | Therefore it's important that this object remains alive while the program is still running.
47 |
48 | :class:`~transformers.Trainer` uses the ``HfTrainerDeepSpeedConfig`` subclass instead. That subclass has logic to
49 | sync the configuration with values of :class:`~transformers.TrainingArguments` by replacing special placeholder
50 | values: ``"auto"``. Without this special logic the DeepSpeed configuration is not modified in any way.
51 |
52 | Args:
53 | config_file_or_dict (:obj:`Union[str, Dict]`) - path to DeepSpeed config file or dict.
54 |
55 | """
56 |
57 | def __init__(self, config_file_or_dict):
58 | # set global weakref object
59 | set_hf_deepspeed_config(self)
60 |
61 | dep_version_check("deepspeed")
62 |
63 | if isinstance(config_file_or_dict, dict):
64 | # Don't modify user's data should they want to reuse it (e.g. in tests), because once we
65 |             # modified it, it will not be accepted here again, since `auto` values would have been overridden
66 | config = deepcopy(config_file_or_dict)
67 | elif isinstance(config_file_or_dict, str):
68 | with io.open(config_file_or_dict, "r", encoding="utf-8") as f:
69 | config = json.load(f)
70 | else:
71 | raise ValueError("expecting either a path to a DeepSpeed config file or a pre-populated dict")
72 | self.config = config
73 |
74 | # zero stage - this is done as early as possible, before model is created, to allow
75 | # ``is_deepspeed_zero3_enabled`` query and getting to the early deepspeed config object
76 | # during ``zero.Init()`` which needs whether fp16 is enabled, dtype, etc.
77 | self._stage = self.get_value("zero_optimization.stage", -1)
78 |
79 | # offload
80 | self._offload = False
81 | if self.is_zero2() or self.is_zero3():
82 | offload_devices_valid = set(["cpu", "nvme"])
83 | offload_devices = set(
84 | [
85 | self.get_value("zero_optimization.offload_optimizer.device"),
86 | self.get_value("zero_optimization.offload_param.device"),
87 | ]
88 | )
89 | if len(offload_devices & offload_devices_valid) > 0:
90 | self._offload = True
91 |
92 | def find_config_node(self, ds_key_long):
93 | config = self.config
94 |
95 | # find the config node of interest if it exists
96 | nodes = ds_key_long.split(".")
97 | ds_key = nodes.pop()
98 | for node in nodes:
99 | config = config.get(node)
100 | if config is None:
101 | return None, ds_key
102 |
103 | return config, ds_key
104 |
105 | def get_value(self, ds_key_long, default=None):
106 | """
107 | Returns the set value or ``default`` if no value is set
108 | """
109 | config, ds_key = self.find_config_node(ds_key_long)
110 | if config is None:
111 | return default
112 | return config.get(ds_key, default)
113 |
114 | def is_true(self, ds_key_long):
115 | """
116 | Returns :obj:`True`/:obj:`False` only if the value is set, always :obj:`False` otherwise. So use this method to
117 | ask the very specific question of whether the value is set to :obj:`True` (and it's not set to :obj:`False` or
118 | isn't set).
119 |
120 | """
121 | value = self.get_value(ds_key_long)
122 | return False if value is None else bool(value)
123 |
124 | def is_false(self, ds_key_long):
125 | """
126 | Returns :obj:`True`/:obj:`False` only if the value is set, always :obj:`False` otherwise. So use this method to
127 | ask the very specific question of whether the value is set to :obj:`False` (and it's not set to :obj:`True` or
128 | isn't set).
129 | """
130 | value = self.get_value(ds_key_long)
131 | return False if value is None else not bool(value)
132 |
133 | def is_zero2(self):
134 | return self._stage == 2
135 |
136 | def is_zero3(self):
137 | return self._stage == 3
138 |
139 | def is_offload(self):
140 | return self._offload
141 |
142 |
143 | class HfTrainerDeepSpeedConfig(HfDeepSpeedConfig):
144 | """
145 | The ``HfTrainerDeepSpeedConfig`` object is meant to be created during ``TrainingArguments`` object creation and has
146 | the same lifespan as the latter.
147 | """
148 |
149 | def __init__(self, config_file_or_dict):
150 | super().__init__(config_file_or_dict)
151 | self._dtype = torch.float16
152 | self.mismatches = []
153 |
154 | def dtype(self):
155 | return self._dtype
156 |
157 | def fill_match(self, ds_key_long, hf_val, hf_key=None, must_match=True):
158 | """
159 | A utility method that massages the config file and can optionally verify that the values match.
160 |
161 | 1. Replace "auto" values with ``TrainingArguments`` value.
162 |
163 | 2. If it wasn't "auto" and ``must_match`` is true, then check that DS config matches Trainer
164 |         config values and, if mismatched, add the entry to ``self.mismatches`` - will raise during
165 | ``trainer_config_finalize`` for one or more mismatches.
166 |
167 | """
168 | config, ds_key = self.find_config_node(ds_key_long)
169 | if config is None:
170 | return
171 |
172 | if config.get(ds_key) == "auto":
173 | config[ds_key] = hf_val
174 | return
175 |
176 | if not must_match:
177 | return
178 |
179 | ds_val = config.get(ds_key)
180 | if ds_val is not None and ds_val != hf_val:
181 | self.mismatches.append(f"- ds {ds_key_long}={ds_val} vs hf {hf_key}={hf_val}")
182 |
183 | fill_only = partialmethod(fill_match, must_match=False)
184 |
185 | def trainer_config_process(self, args):
186 | """
187 | Adjust the config with ``TrainingArguments`` values. This stage is run during ``TrainingArguments`` object
188 | creation.
189 | """
190 | # DeepSpeed does:
191 | # train_batch_size = world_size * train_micro_batch_size_per_gpu * gradient_accumulation_steps
192 | train_batch_size = args.world_size * args.per_device_train_batch_size * args.gradient_accumulation_steps
193 | self.fill_match(
194 | "train_micro_batch_size_per_gpu", args.per_device_train_batch_size, "per_device_train_batch_size"
195 | )
196 | self.fill_match("gradient_accumulation_steps", args.gradient_accumulation_steps, "gradient_accumulation_steps")
197 | self.fill_match("train_batch_size", train_batch_size, "train_batch_size (calculated)")
198 | self.fill_match("gradient_clipping", args.max_grad_norm, "max_grad_norm")
199 |
200 | self.fill_match("optimizer.params.lr", args.learning_rate, "learning_rate")
201 | self.fill_match("optimizer.params.betas", [args.adam_beta1, args.adam_beta2], "adam_beta1+adam_beta2")
202 | self.fill_match("optimizer.params.eps", args.adam_epsilon, "adam_epsilon")
203 | self.fill_match("optimizer.params.weight_decay", args.weight_decay, "weight_decay")
204 |
205 | self.fill_only("scheduler.params.warmup_min_lr", 0) # not a trainer arg
206 | self.fill_match("scheduler.params.warmup_max_lr", args.learning_rate, "learning_rate")
207 | self.fill_match("scheduler.params.warmup_num_steps", args.warmup_steps, "warmup_steps")
208 | # total_num_steps - will get set in trainer_config_finalize
209 |
210 | # fp16
211 | if args.fp16:
212 | fp16_backend = "apex" if args.fp16_backend == "apex" else "amp"
213 | else:
214 | fp16_backend = None
215 |
216 | # amp: similar to the pytorch native amp - it has a bunch of optional params but we won't set
217 | # any here unless the user did the work
218 | self.fill_match("fp16.enabled", fp16_backend == "amp", "fp16+fp16_backend(amp)")
219 |
220 | # apex: delegates amp work to apex (which needs to be available), but it cannot be used with any
221 | # ZeRO features
222 | self.fill_match("amp.enabled", fp16_backend == "apex", "fp16+fp16_backend(apex)")
223 | self.fill_match("amp.opt_level", args.fp16_opt_level, "fp16_opt_level")
224 |
225 | # only if we have an explicit fp16.enabled = False then it's fp32, if it's True or this
226 | # whole config section is missing then the fallback is fp16
227 | if self.is_false("fp16.enabled"):
228 | self._dtype = torch.float32
229 | # later there will be other dtypes besides just fp16 and fp32
230 | # also not quite sure what dtype should be under apex, defaulting to fp16 for now
231 |
232 | def trainer_config_finalize(self, args, model, num_training_steps):
233 | """
234 | This stage is run after we have the model and know num_training_steps.
235 |
236 |         Now we can complete the configuration process.
237 | """
238 | # zero
239 | if self.is_zero3():
240 | # automatically assign the optimal config values based on model config
241 | hidden_size = model.config.hidden_size
242 | self.fill_only("zero_optimization.reduce_bucket_size", hidden_size * hidden_size)
243 | self.fill_only("zero_optimization.stage3_prefetch_bucket_size", 0.9 * hidden_size * hidden_size)
244 | self.fill_only("zero_optimization.stage3_param_persistence_threshold", 10 * hidden_size)
245 |
246 | # scheduler
247 | self.fill_match("scheduler.params.total_num_steps", num_training_steps, "num_training_steps (calculated)")
248 |
249 | if len(self.mismatches) > 0:
250 | mismatches = "\n".join(self.mismatches)
251 | raise ValueError(
252 | f"Please correct the following DeepSpeed config values that mismatch TrainingArguments values:\n{mismatches}\n"
253 | "The easiest method is to set these DeepSpeed config values to 'auto'."
254 | )
255 |
256 |
257 | # keep the config object global to be able to access it anywhere during TrainingArguments life-cycle
258 | _hf_deepspeed_config_weak_ref = None
259 |
260 |
261 | def set_hf_deepspeed_config(hf_deepspeed_config_obj):
262 | # this is a special weakref global object to allow us to get to Deepspeed config from APIs
263 | # that don't have an easy way to get to the Deepspeed config outside of the Trainer domain.
264 | global _hf_deepspeed_config_weak_ref
265 | # will go away automatically when HfDeepSpeedConfig is destroyed (when TrainingArguments is destroyed)
266 | _hf_deepspeed_config_weak_ref = weakref.ref(hf_deepspeed_config_obj)
267 |
268 |
269 | def is_deepspeed_zero3_enabled():
270 | if _hf_deepspeed_config_weak_ref is not None and _hf_deepspeed_config_weak_ref() is not None:
271 | return _hf_deepspeed_config_weak_ref().is_zero3()
272 | else:
273 | return False
274 |
275 |
276 | def deepspeed_config():
277 | if _hf_deepspeed_config_weak_ref is not None and _hf_deepspeed_config_weak_ref() is not None:
278 | return _hf_deepspeed_config_weak_ref().config
279 | else:
280 | return None
281 |
282 |
283 | def deepspeed_init(trainer, num_training_steps, resume_from_checkpoint=None):
284 | """
285 | Init DeepSpeed, after updating the DeepSpeed configuration with any relevant Trainer's args.
286 |
287 | If ``resume_from_checkpoint`` was passed then an attempt to resume from a previously saved checkpoint will be made.
288 |
289 | Args:
290 | trainer: Trainer object
291 | num_training_steps: per single gpu
292 | resume_from_checkpoint: path to a checkpoint if to resume from after normal DeepSpeedEngine load
293 |
294 | Returns: model, optimizer, lr_scheduler
295 |
296 | """
297 | import deepspeed
298 |
299 | model = trainer.model
300 |
301 | hf_deepspeed_config = trainer.args.hf_deepspeed_config
302 | hf_deepspeed_config.trainer_config_finalize(trainer.args, model, num_training_steps)
303 |
304 | # resume config update - some bits like `model` and `num_training_steps` only become available during train
305 | config = hf_deepspeed_config.config
306 |
307 | # Optimizer + Scheduler
308 | # Currently supported combos:
309 | # 1. DS scheduler + DS optimizer: Yes
310 | # 2. HF scheduler + HF optimizer: Yes
311 | # 3. DS scheduler + HF optimizer: Yes
312 | # 4. HF scheduler + DS optimizer: No
313 | #
314 | # Unless Offload is enabled in which case it's:
315 | # 1. DS scheduler + DS optimizer: Yes
316 | # 2. HF scheduler + HF optimizer: No
317 | # 3. DS scheduler + HF optimizer: No
318 | # 4. HF scheduler + DS optimizer: No
319 |
320 | optimizer = None
321 | if "optimizer" not in config:
322 | if hf_deepspeed_config.is_offload():
323 | raise ValueError("ZeRO Offload can only work with DeepSpeed optimizers")
324 |
325 | # ds supports Adam, OneBitAdam, and Lamb optimizers and can import other optimizers from torch.
326 | # But trainer uses AdamW by default.
327 | trainer.create_optimizer()
328 | optimizer = trainer.optimizer
329 | # To use other optimizers requires voiding warranty with: `zero_allow_untested_optimizer`
330 | config["zero_allow_untested_optimizer"] = True
331 |
332 | # DS schedulers (deepspeed/runtime/lr_schedules.py):
333 | #
334 | # DS name | --lr_scheduler_type | HF func | Notes
335 | # -------------| ---------------------|-----------------------------------|--------------------
336 | # LRRangeTest | na | na | LRRT
337 | # OneCycle | na | na | 1CLR
338 | # WarmupLR | constant_with_warmup | get_constant_schedule_with_warmup | w/ warmup_min_lr=0
339 | # WarmupDecayLR| linear | get_linear_schedule_with_warmup |
340 | lr_scheduler = None
341 | if "scheduler" not in config:
342 | if "optimizer" in config:
343 | # to make this option work, we need to init DS optimizer first, then init HS scheduler,
344 | # then pass the HS scheduler to DS init, which is not possible at the moment
345 | raise ValueError("At the moment HF scheduler + DeepSpeed optimizer combination is not possible")
346 | else:
347 | trainer.create_scheduler(num_training_steps=num_training_steps)
348 | lr_scheduler = trainer.lr_scheduler
349 |
350 | # keep for quick debug:
351 | # from pprint import pprint; pprint(config)
352 |
353 | model_parameters = filter(lambda p: p.requires_grad, model.parameters())
354 |
355 | model, optimizer, _, lr_scheduler = deepspeed.initialize(
356 | model=model,
357 | model_parameters=model_parameters,
358 | config_params=config,
359 | optimizer=optimizer,
360 | lr_scheduler=lr_scheduler,
361 | )
362 |
363 | if resume_from_checkpoint is not None:
364 |
365 | # it's possible that the user is trying to resume from model_path, which doesn't necessarily
366 | # contain a deepspeed checkpoint. e.g. examples just check if the dir exists and assume it's
367 | # a resume from a checkpoint and not just a local pretrained weight. So we check here if the
368 | # path contains what looks like a deepspeed checkpoint
369 | import glob
370 |
371 | deepspeed_checkpoint_dirs = sorted(glob.glob(f"{resume_from_checkpoint}/global_step*"))
372 |
373 | if len(deepspeed_checkpoint_dirs) > 0:
374 | logger.info(f"Attempting to resume from {resume_from_checkpoint}")
375 | # this magically updates self.optimizer and self.lr_scheduler
376 | load_path, _ = model.load_checkpoint(
377 | resume_from_checkpoint, load_optimizer_states=True, load_lr_scheduler_states=True
378 | )
379 | if load_path is None:
380 | raise ValueError(f"[deepspeed] failed to resume from checkpoint {resume_from_checkpoint}")
381 | else:
382 | logger.info(f"{resume_from_checkpoint} doesn't have deepspeed checkpoints, doing nothing")
383 |
384 | return model, optimizer, lr_scheduler
385 |
--------------------------------------------------------------------------------
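The dotted-key helpers (`find_config_node` / `get_value`) make it cheap to query nested DeepSpeed settings before the engine exists. A small sketch, assuming `deepspeed` is installed (the constructor runs dep_version_check("deepspeed")); the config values here are illustrative:

from hubert.deepspeed import HfDeepSpeedConfig, is_deepspeed_zero3_enabled

ds_config = {
    "zero_optimization": {"stage": 3, "offload_param": {"device": "cpu"}},
    "fp16": {"enabled": True},
}

hf_ds_cfg = HfDeepSpeedConfig(ds_config)  # also registers the module-level weakref
print(hf_ds_cfg.get_value("zero_optimization.stage"))            # 3
print(hf_ds_cfg.get_value("optimizer.params.lr", default=1e-4))  # key missing -> default
print(hf_ds_cfg.is_offload(), is_deepspeed_zero3_enabled())      # True True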
/hubert/dependency_versions_check.py:
--------------------------------------------------------------------------------
1 | # Copyright 2020 The HuggingFace Team. All rights reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | import sys
15 |
16 | from .dependency_versions_table import deps
17 | from .utils.versions import require_version, require_version_core
18 |
19 |
20 | # define which module versions we always want to check at run time
21 | # (usually the ones defined in `install_requires` in setup.py)
22 | #
23 | # order specific notes:
24 | # - tqdm must be checked before tokenizers
25 |
26 | pkgs_to_check_at_runtime = "python tqdm regex sacremoses requests packaging filelock numpy tokenizers".split()
27 | if sys.version_info < (3, 7):
28 | pkgs_to_check_at_runtime.append("dataclasses")
29 | if sys.version_info < (3, 8):
30 | pkgs_to_check_at_runtime.append("importlib_metadata")
31 |
32 | for pkg in pkgs_to_check_at_runtime:
33 | if pkg in deps:
34 | if pkg == "tokenizers":
35 | # must be loaded here, or else tqdm check may fail
36 | from .file_utils import is_tokenizers_available
37 |
38 | if not is_tokenizers_available():
39 | continue # not required, check version only if installed
40 |
41 | require_version_core(deps[pkg])
42 | else:
43 | raise ValueError(f"can't find {pkg} in {deps.keys()}, check dependency_versions_table.py")
44 |
45 |
46 | def dep_version_check(pkg, hint=None):
47 | require_version(deps[pkg], hint)
48 |
--------------------------------------------------------------------------------
/hubert/dependency_versions_table.py:
--------------------------------------------------------------------------------
1 | # THIS FILE HAS BEEN AUTOGENERATED. To update:
2 | # 1. modify the `_deps` dict in setup.py
3 | # 2. run `make deps_table_update`
4 | deps = {
5 | "Pillow": "Pillow",
6 | "black": "black==21.4b0",
7 | "cookiecutter": "cookiecutter==1.7.2",
8 | "dataclasses": "dataclasses",
9 | "datasets": "datasets",
10 | "deepspeed": "deepspeed>=0.4.0",
11 | "docutils": "docutils==0.16.0",
12 | "fairscale": "fairscale>0.3",
13 | "faiss-cpu": "faiss-cpu",
14 | "fastapi": "fastapi",
15 | "filelock": "filelock",
16 | "flake8": "flake8>=3.8.3",
17 | "flax": "flax>=0.3.4",
18 | "fugashi": "fugashi>=1.0",
19 | "huggingface-hub": "huggingface-hub==0.0.8",
20 | "importlib_metadata": "importlib_metadata",
21 | "ipadic": "ipadic>=1.0.0,<2.0",
22 | "isort": "isort>=5.5.4",
23 | "jax": "jax>=0.2.8",
24 | "jaxlib": "jaxlib>=0.1.65",
25 | "jieba": "jieba",
26 | "keras2onnx": "keras2onnx",
27 | "nltk": "nltk",
28 | "numpy": "numpy>=1.17",
29 | "onnxconverter-common": "onnxconverter-common",
30 | "onnxruntime-tools": "onnxruntime-tools>=1.4.2",
31 | "onnxruntime": "onnxruntime>=1.4.0",
32 | "optuna": "optuna",
33 | "packaging": "packaging",
34 | "parameterized": "parameterized",
35 | "protobuf": "protobuf",
36 | "psutil": "psutil",
37 | "pyyaml": "pyyaml",
38 | "pydantic": "pydantic",
39 | "pytest": "pytest",
40 | "pytest-sugar": "pytest-sugar",
41 | "pytest-xdist": "pytest-xdist",
42 | "python": "python>=3.6.0",
43 | "ray": "ray",
44 | "recommonmark": "recommonmark",
45 | "regex": "regex!=2019.12.17",
46 | "requests": "requests",
47 | "rouge-score": "rouge-score",
48 | "sacrebleu": "sacrebleu>=1.4.12",
49 | "sacremoses": "sacremoses",
50 | "sagemaker": "sagemaker>=2.31.0",
51 | "scikit-learn": "scikit-learn",
52 | "sentencepiece": "sentencepiece==0.1.91",
53 | "soundfile": "soundfile",
54 | "sphinx-copybutton": "sphinx-copybutton",
55 | "sphinx-markdown-tables": "sphinx-markdown-tables",
56 | "sphinx-rtd-theme": "sphinx-rtd-theme==0.4.3",
57 | "sphinx": "sphinx==3.2.1",
58 | "sphinxext-opengraph": "sphinxext-opengraph==0.4.1",
59 | "starlette": "starlette",
60 | "tensorflow-cpu": "tensorflow-cpu>=2.3",
61 | "tensorflow": "tensorflow>=2.3",
62 | "timeout-decorator": "timeout-decorator",
63 | "timm": "timm",
64 | "tokenizers": "tokenizers>=0.10.1,<0.11",
65 | "torch": "torch>=1.0",
66 | "torchaudio": "torchaudio",
67 | "tqdm": "tqdm>=4.27",
68 | "unidic": "unidic>=1.0.2",
69 | "unidic_lite": "unidic_lite>=1.0.7",
70 | "uvicorn": "uvicorn",
71 | }
72 |
--------------------------------------------------------------------------------
/hubert/generation_beam_search.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2020 The HuggingFace Inc. team
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | import warnings
17 | from abc import ABC, abstractmethod
18 | from collections import UserDict
19 | from typing import Optional, Tuple
20 |
21 | import torch
22 |
23 | from .file_utils import add_start_docstrings
24 |
25 |
26 | PROCESS_INPUTS_DOCSTRING = r"""
27 | Args:
28 | input_ids (:obj:`torch.LongTensor` of shape :obj:`(batch_size * num_beams, sequence_length)`):
29 | Indices of input sequence tokens in the vocabulary.
30 |
31 | Indices can be obtained using any class inheriting from :class:`~transformers.PreTrainedTokenizer`. See
32 | :meth:`transformers.PreTrainedTokenizer.encode` and :meth:`transformers.PreTrainedTokenizer.__call__` for
33 | details.
34 |
35 | `What are input IDs? <../glossary.html#input-ids>`__
36 | next_scores (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, 2 * num_beams)`):
37 | Current scores of the top :obj:`2 * num_beams` non-finished beam hypotheses.
38 | next_tokens (:obj:`torch.LongTensor` of shape :obj:`(batch_size, 2 * num_beams)`):
39 | :obj:`input_ids` of the tokens corresponding to the top :obj:`2 * num_beams` non-finished beam hypotheses.
40 | next_indices (:obj:`torch.LongTensor` of shape :obj:`(batch_size, 2 * num_beams)`):
41 | Beam indices indicating to which beam hypothesis the :obj:`next_tokens` correspond.
42 | pad_token_id (:obj:`int`, `optional`):
43 | The id of the `padding` token.
44 | eos_token_id (:obj:`int`, `optional`):
45 | The id of the `end-of-sequence` token.
46 |
47 | Return:
48 | :obj:`UserDict`: A dictionary composed of the fields as defined above:
49 |
50 | - **next_beam_scores** (:obj:`torch.FloatTensor` of shape :obj:`(batch_size * num_beams)`) -- Updated
51 | scores of all non-finished beams.
52 | - **next_beam_tokens** (:obj:`torch.FloatTensor` of shape :obj:`(batch_size * num_beams)`) -- Next tokens
53 | to be added to the non-finished beam_hypotheses.
54 | - **next_beam_indices** (:obj:`torch.FloatTensor` of shape :obj:`(batch_size * num_beams)`) -- Beam indices
55 | indicating to which beam the next tokens shall be added.
56 |
57 | """
58 |
59 | FINALIZE_INPUTS_DOCSTRING = r"""
60 | Args:
61 | input_ids (:obj:`torch.LongTensor` of shape :obj:`(batch_size * num_beams, sequence_length)`):
62 | Indices of input sequence tokens in the vocabulary.
63 |
64 | Indices can be obtained using any class inheriting from :class:`~transformers.PreTrainedTokenizer`. See
65 | :meth:`transformers.PreTrainedTokenizer.encode` and :meth:`transformers.PreTrainedTokenizer.__call__` for
66 | details.
67 |
68 | `What are input IDs? <../glossary.html#input-ids>`__
69 | final_beam_scores (:obj:`torch.FloatTensor` of shape :obj:`(batch_size * num_beams)`):
70 | The final scores of all non-finished beams.
71 | final_beam_tokens (:obj:`torch.FloatTensor` of shape :obj:`(batch_size * num_beams)`):
72 | The last tokens to be added to the non-finished beam_hypotheses.
73 | final_beam_indices (:obj:`torch.FloatTensor` of shape :obj:`(batch_size * num_beams)`):
74 | The beam indices indicating to which beam the :obj:`final_beam_tokens` shall be added.
75 | pad_token_id (:obj:`int`, `optional`):
76 | The id of the `padding` token.
77 | eos_token_id (:obj:`int`, `optional`):
78 | The id of the `end-of-sequence` token.
79 |
80 | Return:
81 | :obj:`torch.LongTensor` of shape :obj:`(batch_size * num_return_sequences, sequence_length)`: The generated
82 | sequences. The second dimension (sequence_length) is either equal to :obj:`max_length` or shorter if all
83 | batches finished early due to the :obj:`eos_token_id`.
84 |
85 | """
86 |
87 |
88 | class BeamScorer(ABC):
89 | """
90 | Abstract base class for all beam scorers that are used for :meth:`~transformers.PreTrainedModel.beam_search` and
91 | :meth:`~transformers.PreTrainedModel.beam_sample`.
92 | """
93 |
94 | @abstractmethod
95 | @add_start_docstrings(PROCESS_INPUTS_DOCSTRING)
96 | def process(
97 | self,
98 | input_ids: torch.LongTensor,
99 | next_scores: torch.FloatTensor,
100 | next_tokens: torch.LongTensor,
101 | next_indices: torch.LongTensor,
102 | **kwargs
103 | ) -> Tuple[torch.Tensor]:
104 | raise NotImplementedError("This is an abstract method.")
105 |
106 | @abstractmethod
107 | @add_start_docstrings(FINALIZE_INPUTS_DOCSTRING)
108 | def finalize(
109 | self,
110 | input_ids: torch.LongTensor,
111 | next_scores: torch.FloatTensor,
112 | next_tokens: torch.LongTensor,
113 | next_indices: torch.LongTensor,
114 | max_length: int,
115 | **kwargs
116 | ) -> torch.LongTensor:
117 | raise NotImplementedError("This is an abstract method.")
118 |
119 |
120 | class BeamSearchScorer(BeamScorer):
121 | r"""
122 | :class:`transformers.BeamScorer` implementing standard beam search decoding.
123 |
124 | Adapted in part from `Facebook's XLM beam search code
125 | `__.
126 |
127 | Reference for the diverse beam search algorithm and implementation `Ashwin Kalyan's DBS implementation
128 | `__
129 |
130 | Args:
131 | batch_size (:obj:`int`):
132 | Batch Size of :obj:`input_ids` for which standard beam search decoding is run in parallel.
133 | max_length (:obj:`int`):
134 | The maximum length of the sequence to be generated.
135 | num_beams (:obj:`int`):
136 | Number of beams for beam search.
137 | device (:obj:`torch.device`):
138 | Defines the device type (*e.g.*, :obj:`"cpu"` or :obj:`"cuda"`) on which this instance of
139 | :obj:`BeamSearchScorer` will be allocated.
140 | length_penalty (:obj:`float`, `optional`, defaults to 1.0):
141 | Exponential penalty to the length. 1.0 means no penalty. Set to values < 1.0 in order to encourage the
142 | model to generate shorter sequences, to a value > 1.0 in order to encourage the model to produce longer
143 | sequences.
144 | do_early_stopping (:obj:`bool`, `optional`, defaults to :obj:`False`):
145 | Whether to stop the beam search when at least ``num_beams`` sentences are finished per batch or not.
146 | num_beam_hyps_to_keep (:obj:`int`, `optional`, defaults to 1):
147 | The number of beam hypotheses that shall be returned upon calling
148 | :meth:`~transformer.BeamSearchScorer.finalize`.
149 | num_beam_groups (:obj:`int`):
150 | Number of groups to divide :obj:`num_beams` into in order to ensure diversity among different groups of
151 | beams. See `this paper `__ for more details.
152 | """
153 |
154 | def __init__(
155 | self,
156 | batch_size: int,
157 | num_beams: int,
158 | device: torch.device,
159 | length_penalty: Optional[float] = 1.0,
160 | do_early_stopping: Optional[bool] = False,
161 | num_beam_hyps_to_keep: Optional[int] = 1,
162 | num_beam_groups: Optional[int] = 1,
163 | **kwargs,
164 | ):
165 | self.num_beams = num_beams
166 | self.device = device
167 | self.length_penalty = length_penalty
168 | self.do_early_stopping = do_early_stopping
169 | self.num_beam_hyps_to_keep = num_beam_hyps_to_keep
170 | self.num_beam_groups = num_beam_groups
171 | self.group_size = self.num_beams // self.num_beam_groups
172 |
173 | self._is_init = False
174 | self._beam_hyps = [
175 | BeamHypotheses(
176 | num_beams=self.num_beams,
177 | length_penalty=self.length_penalty,
178 | early_stopping=self.do_early_stopping,
179 | )
180 | for _ in range(batch_size)
181 | ]
182 | self._done = torch.tensor([False for _ in range(batch_size)], dtype=torch.bool, device=self.device)
183 |
184 | if not isinstance(num_beams, int) or num_beams <= 1:
185 | raise ValueError(
186 | f"`num_beams` has to be an integer strictly greater than 1, but is {num_beams}. For `num_beams` == 1, one should make use of `greedy_search` instead."
187 | )
188 |
189 | if not isinstance(num_beam_groups, int) or (num_beam_groups > num_beams) or (num_beams % num_beam_groups != 0):
190 | raise ValueError(
191 |                 f"`num_beam_groups` has to be an integer smaller than or equal to `num_beams` and `num_beams` "
192 | f"has to be divisible by `num_beam_groups`, but is {num_beam_groups} with `num_beams` being {num_beams}."
193 | )
194 |
195 | if "max_length" in kwargs:
196 | warnings.warn(
197 |                 "Passing `max_length` to BeamSearchScorer is deprecated and has no effect. "
198 |                 "`max_length` should be passed directly to `beam_search(...)`, `beam_sample(...)`, "
199 |                 "or `group_beam_search(...)`."
200 | )
201 |
202 | @property
203 | def is_done(self) -> bool:
204 | return self._done.all()
205 |
206 | def process(
207 | self,
208 | input_ids: torch.LongTensor,
209 | next_scores: torch.FloatTensor,
210 | next_tokens: torch.LongTensor,
211 | next_indices: torch.LongTensor,
212 | pad_token_id: Optional[int] = None,
213 | eos_token_id: Optional[int] = None,
214 | ) -> Tuple[torch.Tensor]:
215 | cur_len = input_ids.shape[-1]
216 | batch_size = len(self._beam_hyps)
217 | assert batch_size == (input_ids.shape[0] // self.group_size)
218 |
219 | device = input_ids.device
220 | next_beam_scores = torch.zeros((batch_size, self.group_size), dtype=next_scores.dtype, device=device)
221 | next_beam_tokens = torch.zeros((batch_size, self.group_size), dtype=next_tokens.dtype, device=device)
222 | next_beam_indices = torch.zeros((batch_size, self.group_size), dtype=next_indices.dtype, device=device)
223 |
224 | for batch_idx, beam_hyp in enumerate(self._beam_hyps):
225 | if self._done[batch_idx]:
226 | assert (
227 | len(beam_hyp) >= self.num_beams
228 | ), f"Batch can only be done if at least {self.num_beams} beams have been generated"
229 | assert (
230 | eos_token_id is not None and pad_token_id is not None
231 | ), "generated beams >= num_beams -> eos_token_id and pad_token have to be defined"
232 | # pad the batch
233 | next_beam_scores[batch_idx, :] = 0
234 | next_beam_tokens[batch_idx, :] = pad_token_id
235 | next_beam_indices[batch_idx, :] = 0
236 | continue
237 |
238 | # next tokens for this sentence
239 | beam_idx = 0
240 | for beam_token_rank, (next_token, next_score, next_index) in enumerate(
241 | zip(next_tokens[batch_idx], next_scores[batch_idx], next_indices[batch_idx])
242 | ):
243 | batch_beam_idx = batch_idx * self.group_size + next_index
244 | # add to generated hypotheses if end of sentence
245 | if (eos_token_id is not None) and (next_token.item() == eos_token_id):
246 | # if beam_token does not belong to top num_beams tokens, it should not be added
247 | is_beam_token_worse_than_top_num_beams = beam_token_rank >= self.group_size
248 | if is_beam_token_worse_than_top_num_beams:
249 | continue
250 | beam_hyp.add(
251 | input_ids[batch_beam_idx].clone(),
252 | next_score.item(),
253 | )
254 | else:
255 | # add next predicted token since it is not eos_token
256 | next_beam_scores[batch_idx, beam_idx] = next_score
257 | next_beam_tokens[batch_idx, beam_idx] = next_token
258 | next_beam_indices[batch_idx, beam_idx] = batch_beam_idx
259 | beam_idx += 1
260 |
261 | # once the beam for next step is full, don't add more tokens to it.
262 | if beam_idx == self.group_size:
263 | break
264 |
265 | if beam_idx < self.group_size:
266 | raise ValueError(
267 | f"At most {self.group_size} tokens in {next_tokens[batch_idx]} can be equal to `eos_token_id: {eos_token_id}`. Make sure {next_tokens[batch_idx]} are corrected."
268 | )
269 |
270 | # Check if we are done so that we can save a pad step if all(done)
271 | self._done[batch_idx] = self._done[batch_idx] or beam_hyp.is_done(
272 | next_scores[batch_idx].max().item(), cur_len
273 | )
274 |
275 | return UserDict(
276 | {
277 | "next_beam_scores": next_beam_scores.view(-1),
278 | "next_beam_tokens": next_beam_tokens.view(-1),
279 | "next_beam_indices": next_beam_indices.view(-1),
280 | }
281 | )
282 |
283 | def finalize(
284 | self,
285 | input_ids: torch.LongTensor,
286 | final_beam_scores: torch.FloatTensor,
287 | final_beam_tokens: torch.LongTensor,
288 | final_beam_indices: torch.LongTensor,
289 | max_length: int,
290 | pad_token_id: Optional[int] = None,
291 | eos_token_id: Optional[int] = None,
292 | ) -> Tuple[torch.LongTensor]:
293 | batch_size = len(self._beam_hyps)
294 |
295 | # finalize all open beam hypotheses and add to generated hypotheses
296 | for batch_idx, beam_hyp in enumerate(self._beam_hyps):
297 | if self._done[batch_idx]:
298 | continue
299 |
300 | # all open beam hypotheses are added to the beam hypothesis
301 | # beam hypothesis class automatically keeps the best beams
302 | for beam_id in range(self.num_beams):
303 | batch_beam_idx = batch_idx * self.num_beams + beam_id
304 | final_score = final_beam_scores[batch_beam_idx].item()
305 | final_tokens = input_ids[batch_beam_idx]
306 | beam_hyp.add(final_tokens, final_score)
307 |
308 | # select the best hypotheses
309 | sent_lengths = input_ids.new(batch_size * self.num_beam_hyps_to_keep)
310 | best = []
311 | best_scores = torch.zeros(batch_size * self.num_beam_hyps_to_keep, device=self.device, dtype=torch.float32)
312 |
313 | # retrieve best hypotheses
314 | for i, beam_hyp in enumerate(self._beam_hyps):
315 | sorted_hyps = sorted(beam_hyp.beams, key=lambda x: x[0])
316 | for j in range(self.num_beam_hyps_to_keep):
317 | best_hyp_tuple = sorted_hyps.pop()
318 | best_score = best_hyp_tuple[0]
319 | best_hyp = best_hyp_tuple[1]
320 | sent_lengths[self.num_beam_hyps_to_keep * i + j] = len(best_hyp)
321 |
322 | # append to lists
323 | best.append(best_hyp)
324 | best_scores[i * self.num_beam_hyps_to_keep + j] = best_score
325 |
326 | # prepare for adding eos
327 | sent_max_len = min(sent_lengths.max().item() + 1, max_length)
328 | decoded: torch.LongTensor = input_ids.new(batch_size * self.num_beam_hyps_to_keep, sent_max_len)
329 | # shorter batches are padded if needed
330 | if sent_lengths.min().item() != sent_lengths.max().item():
331 | assert pad_token_id is not None, "`pad_token_id` has to be defined"
332 | decoded.fill_(pad_token_id)
333 |
334 | # fill with hypotheses and eos_token_id if the latter fits in
335 | for i, hypo in enumerate(best):
336 | decoded[i, : sent_lengths[i]] = hypo
337 | if sent_lengths[i] < max_length:
338 | decoded[i, sent_lengths[i]] = eos_token_id
339 | return UserDict(
340 | {
341 | "sequences": decoded,
342 | "sequence_scores": best_scores,
343 | }
344 | )
345 |
346 |
347 | class BeamHypotheses:
348 | def __init__(self, num_beams: int, length_penalty: float, early_stopping: bool):
349 | """
350 | Initialize n-best list of hypotheses.
351 | """
352 | self.length_penalty = length_penalty
353 | self.early_stopping = early_stopping
354 | self.num_beams = num_beams
355 | self.beams = []
356 | self.worst_score = 1e9
357 |
358 | def __len__(self):
359 | """
360 | Number of hypotheses in the list.
361 | """
362 | return len(self.beams)
363 |
364 | def add(self, hyp: torch.LongTensor, sum_logprobs: float):
365 | """
366 | Add a new hypothesis to the list.
367 | """
368 | score = sum_logprobs / (hyp.shape[-1] ** self.length_penalty)
369 | if len(self) < self.num_beams or score > self.worst_score:
370 | self.beams.append((score, hyp))
371 | if len(self) > self.num_beams:
372 | sorted_next_scores = sorted([(s, idx) for idx, (s, _) in enumerate(self.beams)])
373 | del self.beams[sorted_next_scores[0][1]]
374 | self.worst_score = sorted_next_scores[1][0]
375 | else:
376 | self.worst_score = min(score, self.worst_score)
377 |
378 | def is_done(self, best_sum_logprobs: float, cur_len: int) -> bool:
379 | """
380 |         If there are enough hypotheses and none of the hypotheses being generated can become better than the worst
381 |         one in the heap, then we are done with this sentence.
382 | """
383 |
384 | if len(self) < self.num_beams:
385 | return False
386 | elif self.early_stopping:
387 | return True
388 | else:
389 | cur_score = best_sum_logprobs / cur_len ** self.length_penalty
390 | ret = self.worst_score >= cur_score
391 | return ret
392 |
--------------------------------------------------------------------------------
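Usage sketch (not part of the repository): the `BeamHypotheses` container above keeps the `num_beams` best candidates, ranked by the length-penalized score `sum_logprobs / hyp.shape[-1] ** length_penalty`. The import path below assumes the `hubert` package is importable from the repo root.

```python
import torch

from hubert.generation_beam_search import BeamHypotheses

# Keep the 2 best hypotheses; each candidate is ranked by sum_logprobs / len(hyp) ** length_penalty.
hyps = BeamHypotheses(num_beams=2, length_penalty=1.0, early_stopping=False)

hyps.add(torch.tensor([5, 7, 9]), sum_logprobs=-2.1)  # score -0.70
hyps.add(torch.tensor([5, 7, 2]), sum_logprobs=-4.5)  # score -1.50
hyps.add(torch.tensor([5, 8, 1]), sum_logprobs=-1.8)  # score -0.60, evicts the -1.50 hypothesis

print(len(hyps))                                        # 2
print(hyps.is_done(best_sum_logprobs=-6.0, cur_len=3))  # True: -6.0 / 3 cannot beat the worst kept score
```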
/hubert/generation_logits_process.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2020 The HuggingFace Inc. team
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | import inspect
17 | import math
18 | from abc import ABC
19 | from typing import Callable, Iterable, List
20 |
21 | import numpy as np
22 | import torch
23 |
24 | from .file_utils import add_start_docstrings
25 | from .utils.logging import get_logger
26 |
27 |
28 | logger = get_logger(__name__)
29 |
30 |
31 | LOGITS_PROCESSOR_INPUTS_DOCSTRING = r"""
32 | Args:
33 | input_ids (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`):
34 | Indices of input sequence tokens in the vocabulary.
35 |
36 | Indices can be obtained using :class:`~transformers.BertTokenizer`. See
37 | :meth:`transformers.PreTrainedTokenizer.encode` and :meth:`transformers.PreTrainedTokenizer.__call__` for
38 | details.
39 |
40 | `What are input IDs? <../glossary.html#input-ids>`__
41 | scores (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, config.vocab_size)`):
42 |             Prediction scores of a language modeling head. These can be logits for each vocabulary token when not
43 |             using beam search, or the log softmax for each vocabulary token when using beam search.
44 | kwargs:
45 | Additional logits processor specific kwargs.
46 |
47 | Return:
48 | :obj:`torch.FloatTensor` of shape :obj:`(batch_size, config.vocab_size)`: The processed prediction scores.
49 |
50 | """
51 |
52 |
53 | class LogitsProcessor(ABC):
54 | """Abstract base class for all logit processors that can be applied during generation."""
55 |
56 | @add_start_docstrings(LOGITS_PROCESSOR_INPUTS_DOCSTRING)
57 | def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
58 | """Torch method for processing logits."""
59 | raise NotImplementedError(
60 | f"{self.__class__} is an abstract class. Only classes inheriting this class can be called."
61 | )
62 |
63 |
64 | class LogitsWarper(ABC):
65 | """Abstract base class for all logit warpers that can be applied during generation with multinomial sampling."""
66 |
67 | @add_start_docstrings(LOGITS_PROCESSOR_INPUTS_DOCSTRING)
68 | def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
69 | """Torch method for warping logits."""
70 | raise NotImplementedError(
71 | f"{self.__class__} is an abstract class. Only classes inheriting this class can be called."
72 | )
73 |
74 |
75 | class LogitsProcessorList(list):
76 | """
77 | This class can be used to create a list of :class:`~transformers.LogitsProcessor` or
78 | :class:`~transformers.LogitsWarper` to subsequently process a :obj:`scores` input tensor. This class inherits from
79 | list and adds a specific `__call__` method to apply each :class:`~transformers.LogitsProcessor` or
80 | :class:`~transformers.LogitsWarper` to the inputs.
81 | """
82 |
83 | @add_start_docstrings(LOGITS_PROCESSOR_INPUTS_DOCSTRING)
84 | def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> torch.FloatTensor:
85 | for processor in self:
86 | function_args = inspect.signature(processor.__call__).parameters
87 | if len(function_args) > 2:
88 | assert all(
89 | arg in kwargs for arg in list(function_args.keys())[2:]
90 | ), f"Make sure that all the required parameters: {list(function_args.keys())} for {processor.__class__} are passed to the logits processor."
91 | scores = processor(input_ids, scores, **kwargs)
92 | else:
93 | scores = processor(input_ids, scores)
94 | return scores
95 |
96 |
97 | class MinLengthLogitsProcessor(LogitsProcessor):
98 | r"""
99 | :class:`transformers.LogitsProcessor` enforcing a min-length by setting EOS probability to 0.
100 |
101 | Args:
102 | min_length (:obj:`int`):
103 | The minimum length below which the score of :obj:`eos_token_id` is set to :obj:`-float("Inf")`.
104 | eos_token_id (:obj:`int`):
105 | The id of the `end-of-sequence` token.
106 | """
107 |
108 | def __init__(self, min_length: int, eos_token_id: int):
109 | if not isinstance(min_length, int) or min_length < 0:
110 | raise ValueError(f"`min_length` has to be a positive integer, but is {min_length}")
111 |
112 | if not isinstance(eos_token_id, int) or eos_token_id < 0:
113 | raise ValueError(f"`eos_token_id` has to be a positive integer, but is {eos_token_id}")
114 |
115 | self.min_length = min_length
116 | self.eos_token_id = eos_token_id
117 |
118 | def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
119 | cur_len = input_ids.shape[-1]
120 | if cur_len < self.min_length:
121 | scores[:, self.eos_token_id] = -float("inf")
122 | return scores
123 |
124 |
125 | class TemperatureLogitsWarper(LogitsWarper):
126 | r"""
127 | :class:`transformers.LogitsWarper` for temperature (exponential scaling output probability distribution).
128 |
129 | Args:
130 | temperature (:obj:`float`):
131 | The value used to module the logits distribution.
132 | """
133 |
134 | def __init__(self, temperature: float):
135 | if not isinstance(temperature, float) or not (temperature > 0):
136 | raise ValueError(f"`temperature` has to be a strictly positive float, but is {temperature}")
137 |
138 | self.temperature = temperature
139 |
140 | def __call__(self, input_ids: torch.Tensor, scores: torch.Tensor) -> torch.Tensor:
141 | scores = scores / self.temperature
142 | return scores
143 |
144 |
145 | class RepetitionPenaltyLogitsProcessor(LogitsProcessor):
146 | r"""
147 | :class:`transformers.LogitsProcessor` enforcing an exponential penalty on repeated sequences.
148 |
149 | Args:
150 | repetition_penalty (:obj:`float`):
151 | The parameter for repetition penalty. 1.0 means no penalty. See `this paper
152 | `__ for more details.
153 | """
154 |
155 | def __init__(self, penalty: float):
156 | if not isinstance(penalty, float) or not (penalty > 0):
157 | raise ValueError(f"`penalty` has to be a strictly positive float, but is {penalty}")
158 |
159 | self.penalty = penalty
160 |
161 | def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
162 | score = torch.gather(scores, 1, input_ids)
163 |
164 | # if score < 0 then repetition penalty has to be multiplied to reduce the previous token probability
165 | score = torch.where(score < 0, score * self.penalty, score / self.penalty)
166 |
167 | scores.scatter_(1, input_ids, score)
168 | return scores
169 |
170 |
171 | class TopPLogitsWarper(LogitsWarper):
172 | """
173 |     :class:`transformers.LogitsWarper` that performs top-p filtering, i.e. restricting to the smallest set of most
174 |     probable tokens whose cumulative probability is at most :obj:`top_p`.
175 |
176 | Args:
177 | top_p (:obj:`float`):
178 | If set to < 1, only the most probable tokens with probabilities that add up to :obj:`top_p` or higher are
179 | kept for generation.
180 | filter_value (:obj:`float`, `optional`, defaults to :obj:`-float("Inf")`):
181 | All filtered values will be set to this float value.
182 | min_tokens_to_keep (:obj:`int`, `optional`, defaults to 1):
183 | Minimum number of tokens that cannot be filtered.
184 | """
185 |
186 | def __init__(self, top_p: float, filter_value: float = -float("Inf"), min_tokens_to_keep: int = 1):
187 | if not isinstance(top_p, float) or (top_p < 0 or top_p > 1.0):
188 | raise ValueError(f"`top_p` has to be a float > 0 and < 1, but is {top_p}")
189 |
190 | self.top_p = top_p
191 | self.filter_value = filter_value
192 | self.min_tokens_to_keep = min_tokens_to_keep
193 |
194 | def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
195 | sorted_logits, sorted_indices = torch.sort(scores, descending=True)
196 | cumulative_probs = sorted_logits.softmax(dim=-1).cumsum(dim=-1)
197 |
198 |         # Remove tokens whose cumulative probability is above the top_p threshold
199 | sorted_indices_to_remove = cumulative_probs > self.top_p
200 | if self.min_tokens_to_keep > 1:
201 | # Keep at least min_tokens_to_keep (set to min_tokens_to_keep-1 because we add the first one below)
202 | sorted_indices_to_remove[..., : self.min_tokens_to_keep - 1] = 0
203 | # Shift the indices to the right to keep also the first token above the threshold
204 | sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
205 | sorted_indices_to_remove[..., 0] = 0
206 |
207 | # scatter sorted tensors to original indexing
208 | indices_to_remove = sorted_indices_to_remove.scatter(1, sorted_indices, sorted_indices_to_remove)
209 | scores = scores.masked_fill(indices_to_remove, self.filter_value)
210 | return scores
211 |
212 |
213 | class TopKLogitsWarper(LogitsWarper):
214 | r"""
215 | :class:`transformers.LogitsWarper` that performs top-k, i.e. restricting to the k highest probability elements.
216 |
217 | Args:
218 | top_k (:obj:`int`):
219 | The number of highest probability vocabulary tokens to keep for top-k-filtering.
220 | filter_value (:obj:`float`, `optional`, defaults to :obj:`-float("Inf")`):
221 | All filtered values will be set to this float value.
222 | min_tokens_to_keep (:obj:`int`, `optional`, defaults to 1):
223 | Minimum number of tokens that cannot be filtered.
224 | """
225 |
226 | def __init__(self, top_k: int, filter_value: float = -float("Inf"), min_tokens_to_keep: int = 1):
227 | if not isinstance(top_k, int) or top_k <= 0:
228 | raise ValueError(f"`top_k` has to be a strictly positive integer, but is {top_k}")
229 |
230 | self.top_k = top_k
231 | self.filter_value = filter_value
232 | self.min_tokens_to_keep = min_tokens_to_keep
233 |
234 | def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
235 | top_k = min(max(self.top_k, self.min_tokens_to_keep), scores.size(-1)) # Safety check
236 | # Remove all tokens with a probability less than the last token of the top-k
237 | indices_to_remove = scores < torch.topk(scores, top_k)[0][..., -1, None]
238 | scores = scores.masked_fill(indices_to_remove, self.filter_value)
239 | return scores
240 |
241 |
242 | def _get_ngrams(ngram_size: int, prev_input_ids: torch.Tensor, num_hypos: int):
243 | generated_ngrams = [{} for _ in range(num_hypos)]
244 | for idx in range(num_hypos):
245 | gen_tokens = prev_input_ids[idx].tolist()
246 | generated_ngram = generated_ngrams[idx]
247 | for ngram in zip(*[gen_tokens[i:] for i in range(ngram_size)]):
248 | prev_ngram_tuple = tuple(ngram[:-1])
249 | generated_ngram[prev_ngram_tuple] = generated_ngram.get(prev_ngram_tuple, []) + [ngram[-1]]
250 | return generated_ngrams
251 |
252 |
253 | def _get_generated_ngrams(banned_ngrams, prev_input_ids, ngram_size, cur_len):
254 | # Before decoding the next token, prevent decoding of ngrams that have already appeared
255 | start_idx = cur_len + 1 - ngram_size
256 | ngram_idx = tuple(prev_input_ids[start_idx:cur_len].tolist())
257 | return banned_ngrams.get(ngram_idx, [])
258 |
259 |
260 | def _calc_banned_ngram_tokens(
261 | ngram_size: int, prev_input_ids: torch.Tensor, num_hypos: int, cur_len: int
262 | ) -> List[Iterable[int]]:
263 | """Copied from fairseq for no_repeat_ngram in beam_search"""
264 | if cur_len + 1 < ngram_size:
265 | # return no banned tokens if we haven't generated no_repeat_ngram_size tokens yet
266 | return [[] for _ in range(num_hypos)]
267 |
268 | generated_ngrams = _get_ngrams(ngram_size, prev_input_ids, num_hypos)
269 |
270 | banned_tokens = [
271 | _get_generated_ngrams(generated_ngrams[hypo_idx], prev_input_ids[hypo_idx], ngram_size, cur_len)
272 | for hypo_idx in range(num_hypos)
273 | ]
274 | return banned_tokens
275 |
276 |
277 | class NoRepeatNGramLogitsProcessor(LogitsProcessor):
278 | r"""
279 | :class:`transformers.LogitsProcessor` that enforces no repetition of n-grams. See `Fairseq
280 | `__.
281 |
282 | Args:
283 | ngram_size (:obj:`int`):
284 | All ngrams of size :obj:`ngram_size` can only occur once.
285 | """
286 |
287 | def __init__(self, ngram_size: int):
288 | if not isinstance(ngram_size, int) or ngram_size <= 0:
289 | raise ValueError(f"`ngram_size` has to be a strictly positive integer, but is {ngram_size}")
290 | self.ngram_size = ngram_size
291 |
292 | def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
293 | num_batch_hypotheses = scores.shape[0]
294 | cur_len = input_ids.shape[-1]
295 | banned_batch_tokens = _calc_banned_ngram_tokens(self.ngram_size, input_ids, num_batch_hypotheses, cur_len)
296 |
297 | for i, banned_tokens in enumerate(banned_batch_tokens):
298 | scores[i, banned_tokens] = -float("inf")
299 |
300 | return scores
301 |
302 |
303 | class EncoderNoRepeatNGramLogitsProcessor(LogitsProcessor):
304 | r"""
305 | :class:`transformers.LogitsProcessor` that enforces no repetition of encoder input ids n-grams for the decoder ids.
306 | See `ParlAI `__.
307 |
308 | Args:
309 | encoder_ngram_size (:obj:`int`):
310 | All ngrams of size :obj:`ngram_size` can only occur within the encoder input ids.
311 |         encoder_input_ids (:obj:`torch.LongTensor`):
312 | The encoder_input_ids that should not be repeated within the decoder ids.
313 | """
314 |
315 | def __init__(self, encoder_ngram_size: int, encoder_input_ids: torch.LongTensor):
316 | if not isinstance(encoder_ngram_size, int) or encoder_ngram_size <= 0:
317 | raise ValueError(
318 | f"`encoder_ngram_size` has to be a strictly positive integer, but is {encoder_ngram_size}"
319 | )
320 | self.ngram_size = encoder_ngram_size
321 | if len(encoder_input_ids.shape) == 1:
322 | encoder_input_ids = encoder_input_ids.unsqueeze(0)
323 | self.batch_size = encoder_input_ids.shape[0]
324 | self.generated_ngrams = _get_ngrams(encoder_ngram_size, encoder_input_ids, self.batch_size)
325 |
326 | def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
327 | # B x num_beams
328 | num_hypos = scores.shape[0]
329 | num_beams = num_hypos // self.batch_size
330 | cur_len = input_ids.shape[-1]
331 | banned_batch_tokens = [
332 | _get_generated_ngrams(
333 | self.generated_ngrams[hypo_idx // num_beams], input_ids[hypo_idx], self.ngram_size, cur_len
334 | )
335 | for hypo_idx in range(num_hypos)
336 | ]
337 |
338 | for i, banned_tokens in enumerate(banned_batch_tokens):
339 | scores[i, banned_tokens] = -float("inf")
340 |
341 | return scores
342 |
343 |
344 | class NoBadWordsLogitsProcessor(LogitsProcessor):
345 | """
346 | :class:`transformers.LogitsProcessor` that enforces that specified sequences will never be sampled.
347 |
348 | Args:
349 | bad_words_ids (:obj:`List[List[int]]`):
350 | List of list of token ids that are not allowed to be generated. In order to get the tokens of the words
351 | that should not appear in the generated text, use :obj:`tokenizer(bad_word,
352 | add_prefix_space=True).input_ids`.
353 | eos_token_id (:obj:`int`):
354 | The id of the `end-of-sequence` token.
355 | """
356 |
357 | def __init__(self, bad_words_ids: Iterable[Iterable[int]], eos_token_id: int):
358 |
359 | if not isinstance(bad_words_ids, List) or len(bad_words_ids) == 0:
360 | raise ValueError(f"`bad_words_ids` has to be a non-emtpy list, but is {bad_words_ids}.")
361 | if any(not isinstance(bad_word_ids, list) for bad_word_ids in bad_words_ids):
362 | raise ValueError(f"`bad_words_ids` has to be a list of lists, but is {bad_words_ids}.")
363 | if any(
364 | any((not isinstance(token_id, (int, np.integer)) or token_id < 0) for token_id in bad_word_ids)
365 | for bad_word_ids in bad_words_ids
366 | ):
367 | raise ValueError(
368 | f"Each list in `bad_words_ids` has to be a list of positive integers, but is {bad_words_ids}."
369 | )
370 |
371 | self.bad_words_ids = list(filter(lambda bad_token_seq: bad_token_seq != [eos_token_id], bad_words_ids))
372 |
373 | for banned_token_seq in self.bad_words_ids:
374 | assert len(banned_token_seq) > 0, f"Banned words token sequences {bad_words_ids} cannot have an empty list"
375 |
376 | def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
377 | banned_tokens = self._calc_banned_bad_words_ids(input_ids)
378 | scores = self._set_scores_to_inf_for_banned_tokens(scores, banned_tokens)
379 |
380 | return scores
381 |
382 | def _tokens_match(self, prev_tokens: torch.LongTensor, tokens: List[int]) -> bool:
383 | if len(tokens) == 0:
384 | # if bad word tokens is just one token always ban it
385 | return True
386 | elif len(tokens) > len(prev_tokens):
387 |             # if bad word tokens are longer than prev input_ids they can't be equal
388 | return False
389 | elif prev_tokens[-len(tokens) :].tolist() == tokens:
390 | # if tokens match
391 | return True
392 | else:
393 | return False
394 |
395 | def _calc_banned_bad_words_ids(self, prev_input_ids: Iterable[int]) -> Iterable[int]:
396 | banned_tokens = []
397 | for prev_input_ids_slice in prev_input_ids:
398 | banned_tokens_slice = []
399 | for banned_token_seq in self.bad_words_ids:
400 | if self._tokens_match(prev_input_ids_slice, banned_token_seq[:-1]) is False:
401 | # if tokens do not match continue
402 | continue
403 |
404 | banned_tokens_slice.append(banned_token_seq[-1])
405 |
406 | banned_tokens.append(banned_tokens_slice)
407 |
408 | return banned_tokens
409 |
410 |     def _set_scores_to_inf_for_banned_tokens(self, scores: torch.Tensor, banned_tokens: List[List[int]]) -> torch.Tensor:
411 |         """
412 |         Returns the scores with the banned token positions set to `-inf`. `banned_tokens` is expected to be a list
413 |         of lists of token ids to ban, one inner list per batch entry.
414 |
415 | Args:
416 | scores: logits distribution of shape (batch size, vocabulary size)
417 | banned_tokens: list of list of tokens to ban of length (batch_size)
418 | """
419 | banned_mask_list = []
420 | for idx, batch_banned_tokens in enumerate(banned_tokens):
421 | for token in batch_banned_tokens:
422 | # Eliminates invalid bad word IDs that are over the vocabulary size.
423 |                 if token < scores.shape[1]:
424 | banned_mask_list.append([idx, token])
425 | else:
426 | logger.error(
427 | f"An invalid bad word ID is defined: {token}. This ID is not contained in the"
428 | f"vocabulary, and is therefore ignored."
429 | )
430 | if not banned_mask_list:
431 | return scores
432 |
433 | banned_mask = torch.LongTensor(banned_mask_list)
434 | indices = torch.ones(len(banned_mask))
435 | # A sparse tensor is generated from a list of coordinates: [[0, 1], [0, 2], [2, 0]]. A conversion to dense tensor generates:
436 | # [ 0 1 1 ]
437 | # [ 0 0 0 ]
438 | # [ 1 0 0 ]
439 |
440 | banned_mask = (
441 | torch.sparse.LongTensor(banned_mask.t(), indices, scores.size()).to(scores.device).to_dense().bool()
442 | )
443 | scores = scores.masked_fill(banned_mask, -float("inf"))
444 | return scores
445 |
446 |
447 | class PrefixConstrainedLogitsProcessor(LogitsProcessor):
448 | r"""
449 | :class:`transformers.LogitsProcessor` that enforces constrained generation and is useful for prefix-conditioned
450 | constrained generation. See `Autoregressive Entity Retrieval `__ for more
451 | information.
452 |
453 | Args:
454 | prefix_allowed_tokens_fn: (:obj:`Callable[[int, torch.Tensor], List[int]]`):
455 | This function constraints the beam search to allowed tokens only at each step. This function takes 2
456 | arguments :obj:`inputs_ids` and the batch ID :obj:`batch_id`. It has to return a list with the allowed
457 | tokens for the next generation step conditioned on the previously generated tokens :obj:`inputs_ids` and
458 | the batch ID :obj:`batch_id`.
459 | """
460 |
461 | def __init__(self, prefix_allowed_tokens_fn: Callable[[int, torch.Tensor], List[int]], num_beams: int):
462 | self._prefix_allowed_tokens_fn = prefix_allowed_tokens_fn
463 | self._num_beams = num_beams
464 |
465 | def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
466 | mask = torch.full_like(scores, -math.inf)
467 | for batch_id, beam_sent in enumerate(input_ids.view(-1, self._num_beams, input_ids.shape[-1])):
468 | for beam_id, sent in enumerate(beam_sent):
469 | mask[batch_id * self._num_beams + beam_id, self._prefix_allowed_tokens_fn(batch_id, sent)] = 0
470 |
471 | return scores + mask
472 |
473 |
474 | class HammingDiversityLogitsProcessor(LogitsProcessor):
475 | r"""
476 | :class:`transformers.LogitsProcessor` that enforces diverse beam search. Note that this logits processor is only
477 | effective for :meth:`transformers.PreTrainedModel.group_beam_search`. See `Diverse Beam Search: Decoding Diverse
478 | Solutions from Neural Sequence Models `__ for more details.
479 |
480 | Args:
481 | diversity_penalty (:obj:`float`):
482 |             This value is subtracted from a beam's score if it generates a token that any beam from another group has
483 |             already generated at the same time step. Note that :obj:`diversity_penalty` is only effective if ``group beam search`` is enabled.
484 | num_beams (:obj:`int`):
485 | Number of beams used for group beam search. See `this paper `__ for
486 | more details.
487 | num_beam_groups (:obj:`int`):
488 | Number of groups to divide :obj:`num_beams` into in order to ensure diversity among different groups of
489 | beams. See `this paper `__ for more details.
490 | """
491 |
492 | def __init__(self, diversity_penalty: float, num_beams: int, num_beam_groups: int):
493 | if not isinstance(diversity_penalty, float) or (not diversity_penalty > 0.0):
494 | raise ValueError("`diversity_penalty` should be a float strictly larger than 0.")
495 | self._diversity_penalty = diversity_penalty
496 | if not isinstance(num_beams, int) or num_beams < 2:
497 | raise ValueError("`num_beams` should be an integer strictly larger than 1.")
498 | self._num_beams = num_beams
499 | if not isinstance(num_beam_groups, int) or num_beam_groups < 2:
500 | raise ValueError("`num_beam_groups` should be an integer strictly larger than 1.")
501 | if num_beam_groups > num_beams:
502 | raise ValueError("`beam_groups` has to be smaller or equal to `num_beams`.")
503 | self._num_sub_beams = num_beams // num_beam_groups
504 |
505 | def __call__(
506 | self,
507 | input_ids: torch.LongTensor,
508 | scores: torch.FloatTensor,
509 | current_tokens: torch.LongTensor,
510 | beam_group_idx: int,
511 | ) -> torch.FloatTensor:
512 | # hamming diversity: penalise using same token in current group which was used in previous groups at
513 | # the same time step
514 | batch_size = current_tokens.shape[0] // self._num_beams
515 | group_start_idx = beam_group_idx * self._num_sub_beams
516 | group_end_idx = min(group_start_idx + self._num_sub_beams, self._num_beams)
517 | group_size = group_end_idx - group_start_idx
518 | vocab_size = scores.shape[-1]
519 |
520 | if group_start_idx == 0:
521 | return scores
522 |
523 | for batch_idx in range(batch_size):
524 | # predicted tokens of last time step of previous groups
525 | previous_group_tokens = current_tokens[
526 | batch_idx * self._num_beams : batch_idx * self._num_beams + group_start_idx
527 | ]
528 | token_frequency = torch.bincount(previous_group_tokens, minlength=vocab_size).to(scores.device)
529 | scores[batch_idx * group_size : (batch_idx + 1) * group_size] -= self._diversity_penalty * token_frequency
530 |
531 | return scores
532 |
533 |
534 | class ForcedBOSTokenLogitsProcessor(LogitsProcessor):
535 | r"""
536 | :class:`~transformers.LogitsProcessor` that enforces the specified token as the first generated token.
537 |
538 | Args:
539 | bos_token_id (:obj:`int`):
540 | The id of the token to force as the first generated token.
541 | """
542 |
543 | def __init__(self, bos_token_id: int):
544 | self.bos_token_id = bos_token_id
545 |
546 | def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
547 | cur_len = input_ids.shape[-1]
548 | if cur_len == 1:
549 | num_tokens = scores.shape[1]
550 | scores[:, [i for i in range(num_tokens) if i != self.bos_token_id]] = -float("inf")
551 | scores[:, self.bos_token_id] = 0
552 | return scores
553 |
554 |
555 | class ForcedEOSTokenLogitsProcessor(LogitsProcessor):
556 | r"""
557 | :class:`~transformers.LogitsProcessor` that enforces the specified token as the last generated token when
558 | :obj:`max_length` is reached.
559 |
560 | Args:
561 | max_length (:obj:`int`):
562 | The maximum length of the sequence to be generated.
563 | eos_token_id (:obj:`int`):
564 | The id of the token to force as the last generated token when :obj:`max_length` is reached.
565 | """
566 |
567 | def __init__(self, max_length: int, eos_token_id: int):
568 | self.max_length = max_length
569 | self.eos_token_id = eos_token_id
570 |
571 | def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
572 | cur_len = input_ids.shape[-1]
573 | if cur_len == self.max_length - 1:
574 | num_tokens = scores.shape[1]
575 | scores[:, [i for i in range(num_tokens) if i != self.eos_token_id]] = -float("inf")
576 | scores[:, self.eos_token_id] = 0
577 | return scores
578 |
579 |
580 | class InfNanRemoveLogitsProcessor(LogitsProcessor):
581 | r"""
582 |     :class:`~transformers.LogitsProcessor` that removes all :obj:`nan` and :obj:`inf` values to keep the generation
583 |     method from failing. Note that this logits processor should only be used if necessary since it can slow down
584 |     the generation method.
585 | """
586 |
587 | def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
588 | # set all nan values to 0.0
589 | scores[scores != scores] = 0.0
590 |
591 | # set all inf values to max possible value
592 | scores[scores == float("inf")] = torch.finfo(scores.dtype).max
593 |
594 | return scores
595 |
--------------------------------------------------------------------------------
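Usage sketch (not part of the repository): the processors and warpers above are designed to be chained in a `LogitsProcessorList` and applied to the next-token logits at every generation step. The token ids and hyperparameters below are toy values; the import path assumes the `hubert` package is importable from the repo root.

```python
import torch

from hubert.generation_logits_process import (
    LogitsProcessorList,
    MinLengthLogitsProcessor,
    NoRepeatNGramLogitsProcessor,
    TemperatureLogitsWarper,
    TopKLogitsWarper,
)

vocab_size, eos_token_id = 10, 0
input_ids = torch.tensor([[4, 2, 4, 2]])   # (batch_size=1, seq_len=4) tokens generated so far
scores = torch.randn(1, vocab_size)        # raw logits for the next token

processors = LogitsProcessorList([
    MinLengthLogitsProcessor(min_length=8, eos_token_id=eos_token_id),  # mask EOS while the sequence is short
    NoRepeatNGramLogitsProcessor(ngram_size=2),                         # ban tokens that would repeat a bigram
])
warpers = LogitsProcessorList([
    TemperatureLogitsWarper(temperature=0.7),
    TopKLogitsWarper(top_k=5),
])

scores = processors(input_ids, scores)
scores = warpers(input_ids, scores)
print(scores[0, eos_token_id])  # tensor(-inf): EOS stays masked until min_length is reached
```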
/hubert/generation_stopping_criteria.py:
--------------------------------------------------------------------------------
1 | import time
2 | import warnings
3 | from abc import ABC
4 | from copy import deepcopy
5 | from typing import Optional
6 |
7 | import torch
8 |
9 | from .file_utils import add_start_docstrings
10 |
11 |
12 | STOPPING_CRITERIA_INPUTS_DOCSTRING = r"""
13 | Args:
14 | input_ids (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`):
15 | Indices of input sequence tokens in the vocabulary.
16 |
17 | Indices can be obtained using :class:`~transformers.BertTokenizer`. See
18 | :meth:`transformers.PreTrainedTokenizer.encode` and :meth:`transformers.PreTrainedTokenizer.__call__` for
19 | details.
20 |
21 | `What are input IDs? <../glossary.html#input-ids>`__
22 | scores (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, config.vocab_size)`):
23 | Prediction scores of a language modeling head. These can be scores for each vocabulary token before SoftMax
24 | or scores for each vocabulary token after SoftMax.
25 | kwargs:
26 | Additional stopping criteria specific kwargs.
27 |
28 | Return:
29 | :obj:`bool`. :obj:`False` indicates we should continue, :obj:`True` indicates we should stop.
30 |
31 | """
32 |
33 |
34 | class StoppingCriteria(ABC):
35 | """Abstract base class for all stopping criteria that can be applied during generation."""
36 |
37 | @add_start_docstrings(STOPPING_CRITERIA_INPUTS_DOCSTRING)
38 | def __call__(self, input_ids: torch.LongTensor, score: torch.FloatTensor, **kwargs) -> bool:
39 | raise NotImplementedError("StoppingCriteria needs to be subclassed")
40 |
41 |
42 | class MaxLengthCriteria(StoppingCriteria):
43 | """
44 | This class can be used to stop generation whenever the full generated number of tokens exceeds :obj:`max_length`.
45 |     Keep in mind that, for decoder-only transformers, this includes the initial prompt tokens.
46 |
47 | Args:
48 | max_length (:obj:`int`):
49 | The maximum length that the output sequence can have in number of tokens.
50 | """
51 |
52 | def __init__(self, max_length: int):
53 | self.max_length = max_length
54 |
55 | @add_start_docstrings(STOPPING_CRITERIA_INPUTS_DOCSTRING)
56 | def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
57 | return input_ids.shape[-1] >= self.max_length
58 |
59 |
60 | class MaxNewTokensCriteria(StoppingCriteria):
61 | """
62 | This class can be used to stop generation whenever the generated number of tokens exceeds :obj:`max_new_tokens`.
63 |     Keep in mind that, for decoder-only transformers, this does **not** include the initial prompt tokens. This is
64 |     very close to :obj:`MaxLengthCriteria` but ignores the number of initial tokens.
65 |
66 | Args:
67 | start_length (:obj:`int`):
68 | The number of initial tokens.
69 | max_new_tokens (:obj:`int`):
70 | The maximum number of tokens to generate.
71 | """
72 |
73 | def __init__(self, start_length: int, max_new_tokens: int):
74 | self.start_length = start_length
75 | self.max_new_tokens = max_new_tokens
76 | self.max_length = start_length + max_new_tokens
77 |
78 | @add_start_docstrings(STOPPING_CRITERIA_INPUTS_DOCSTRING)
79 | def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
80 | return input_ids.shape[-1] >= self.max_length
81 |
82 |
83 | class MaxTimeCriteria(StoppingCriteria):
84 | """
85 | This class can be used to stop generation whenever the full generation exceeds some amount of time. By default, the
86 |     time will start being counted when you initialize this criterion. You can override this by passing an
87 |     :obj:`initial_timestamp`.
88 |
89 | Args:
90 | max_time (:obj:`float`):
91 | The maximum allowed time in seconds for the generation.
92 |         initial_timestamp (:obj:`float`, `optional`, defaults to :obj:`time.time()`):
93 | The start of the generation allowed time.
94 | """
95 |
96 | def __init__(self, max_time: float, initial_timestamp: Optional[float] = None):
97 | self.max_time = max_time
98 | self.initial_timestamp = time.time() if initial_timestamp is None else initial_timestamp
99 |
100 | @add_start_docstrings(STOPPING_CRITERIA_INPUTS_DOCSTRING)
101 | def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
102 | return time.time() - self.initial_timestamp > self.max_time
103 |
104 |
105 | class StoppingCriteriaList(list):
106 | @add_start_docstrings(STOPPING_CRITERIA_INPUTS_DOCSTRING)
107 | def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
108 | return any(criteria(input_ids, scores) for criteria in self)
109 |
110 | @property
111 | def max_length(self) -> Optional[int]:
112 | for stopping_criterium in self:
113 | if isinstance(stopping_criterium, MaxLengthCriteria):
114 | return stopping_criterium.max_length
115 | elif isinstance(stopping_criterium, MaxNewTokensCriteria):
116 | return stopping_criterium.max_length
117 | return None
118 |
119 |
120 | def validate_stopping_criteria(stopping_criteria: StoppingCriteriaList, max_length: int) -> StoppingCriteriaList:
121 | stopping_max_length = stopping_criteria.max_length
122 | new_stopping_criteria = deepcopy(stopping_criteria)
123 | if stopping_max_length is not None and stopping_max_length != max_length:
124 | warnings.warn("You set different `max_length` for stopping criteria and `max_length` parameter", UserWarning)
125 | elif stopping_max_length is None:
126 | new_stopping_criteria.append(MaxLengthCriteria(max_length=max_length))
127 | return new_stopping_criteria
128 |
--------------------------------------------------------------------------------
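Usage sketch (not part of the repository): a `StoppingCriteriaList` stops generation as soon as any of its criteria returns `True`. The values below are illustrative; the import path assumes the `hubert` package is importable from the repo root.

```python
import torch

from hubert.generation_stopping_criteria import (
    MaxLengthCriteria,
    MaxTimeCriteria,
    StoppingCriteriaList,
    validate_stopping_criteria,
)

criteria = StoppingCriteriaList([
    MaxLengthCriteria(max_length=16),   # stop once 16 tokens (prompt included) are reached
    MaxTimeCriteria(max_time=30.0),     # or after 30 seconds of wall-clock time
])

input_ids = torch.ones(1, 16, dtype=torch.long)
scores = torch.zeros(1, 10)
print(criteria(input_ids, scores))      # True: the length criterion fires

# validate_stopping_criteria warns if the list's max_length disagrees with the max_length argument,
# and appends a MaxLengthCriteria if none is present.
criteria = validate_stopping_criteria(criteria, max_length=16)
```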
/hubert/utils/logging.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2020 Optuna, Hugging Face
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | """ Logging utilities. """
16 |
17 | import logging
18 | import os
19 | import sys
20 | import threading
21 | from logging import CRITICAL # NOQA
22 | from logging import DEBUG # NOQA
23 | from logging import ERROR # NOQA
24 | from logging import FATAL # NOQA
25 | from logging import INFO # NOQA
26 | from logging import NOTSET # NOQA
27 | from logging import WARN # NOQA
28 | from logging import WARNING # NOQA
29 | from typing import Optional
30 |
31 |
32 | _lock = threading.Lock()
33 | _default_handler: Optional[logging.Handler] = None
34 |
35 | log_levels = {
36 | "debug": logging.DEBUG,
37 | "info": logging.INFO,
38 | "warning": logging.WARNING,
39 | "error": logging.ERROR,
40 | "critical": logging.CRITICAL,
41 | }
42 |
43 | _default_log_level = logging.WARNING
44 |
45 |
46 | def _get_default_logging_level():
47 | """
48 | If TRANSFORMERS_VERBOSITY env var is set to one of the valid choices return that as the new default level. If it is
49 | not - fall back to ``_default_log_level``
50 | """
51 | env_level_str = os.getenv("TRANSFORMERS_VERBOSITY", None)
52 | if env_level_str:
53 | if env_level_str in log_levels:
54 | return log_levels[env_level_str]
55 | else:
56 | logging.getLogger().warning(
57 | f"Unknown option TRANSFORMERS_VERBOSITY={env_level_str}, "
58 | f"has to be one of: { ', '.join(log_levels.keys()) }"
59 | )
60 | return _default_log_level
61 |
62 |
63 | def _get_library_name() -> str:
64 |
65 | return __name__.split(".")[0]
66 |
67 |
68 | def _get_library_root_logger() -> logging.Logger:
69 |
70 | return logging.getLogger(_get_library_name())
71 |
72 |
73 | def _configure_library_root_logger() -> None:
74 |
75 | global _default_handler
76 |
77 | with _lock:
78 | if _default_handler:
79 | # This library has already configured the library root logger.
80 | return
81 | _default_handler = logging.StreamHandler() # Set sys.stderr as stream.
82 | _default_handler.flush = sys.stderr.flush
83 |
84 | # Apply our default configuration to the library root logger.
85 | library_root_logger = _get_library_root_logger()
86 | library_root_logger.addHandler(_default_handler)
87 | library_root_logger.setLevel(_get_default_logging_level())
88 | library_root_logger.propagate = False
89 |
90 |
91 | def _reset_library_root_logger() -> None:
92 |
93 | global _default_handler
94 |
95 | with _lock:
96 | if not _default_handler:
97 | return
98 |
99 | library_root_logger = _get_library_root_logger()
100 | library_root_logger.removeHandler(_default_handler)
101 | library_root_logger.setLevel(logging.NOTSET)
102 | _default_handler = None
103 |
104 |
105 | def get_logger(name: Optional[str] = None) -> logging.Logger:
106 | """
107 | Return a logger with the specified name.
108 |
109 | This function is not supposed to be directly accessed unless you are writing a custom transformers module.
110 | """
111 |
112 | if name is None:
113 | name = _get_library_name()
114 |
115 | _configure_library_root_logger()
116 | return logging.getLogger(name)
117 |
118 |
119 | def get_verbosity() -> int:
120 | """
121 | Return the current level for the 🤗 Transformers's root logger as an int.
122 |
123 | Returns:
124 | :obj:`int`: The logging level.
125 |
126 | .. note::
127 |
128 |         🤗 Transformers has the following logging levels:
129 |
130 | - 50: ``transformers.logging.CRITICAL`` or ``transformers.logging.FATAL``
131 | - 40: ``transformers.logging.ERROR``
132 | - 30: ``transformers.logging.WARNING`` or ``transformers.logging.WARN``
133 | - 20: ``transformers.logging.INFO``
134 | - 10: ``transformers.logging.DEBUG``
135 | """
136 |
137 | _configure_library_root_logger()
138 | return _get_library_root_logger().getEffectiveLevel()
139 |
140 |
141 | def set_verbosity(verbosity: int) -> None:
142 | """
143 | Set the verbosity level for the 🤗 Transformers's root logger.
144 |
145 | Args:
146 | verbosity (:obj:`int`):
147 | Logging level, e.g., one of:
148 |
149 | - ``transformers.logging.CRITICAL`` or ``transformers.logging.FATAL``
150 | - ``transformers.logging.ERROR``
151 | - ``transformers.logging.WARNING`` or ``transformers.logging.WARN``
152 | - ``transformers.logging.INFO``
153 | - ``transformers.logging.DEBUG``
154 | """
155 |
156 | _configure_library_root_logger()
157 | _get_library_root_logger().setLevel(verbosity)
158 |
159 |
160 | def set_verbosity_info():
161 | """Set the verbosity to the :obj:`INFO` level."""
162 | return set_verbosity(INFO)
163 |
164 |
165 | def set_verbosity_warning():
166 | """Set the verbosity to the :obj:`WARNING` level."""
167 | return set_verbosity(WARNING)
168 |
169 |
170 | def set_verbosity_debug():
171 | """Set the verbosity to the :obj:`DEBUG` level."""
172 | return set_verbosity(DEBUG)
173 |
174 |
175 | def set_verbosity_error():
176 | """Set the verbosity to the :obj:`ERROR` level."""
177 | return set_verbosity(ERROR)
178 |
179 |
180 | def disable_default_handler() -> None:
181 | """Disable the default handler of the HuggingFace Transformers's root logger."""
182 |
183 | _configure_library_root_logger()
184 |
185 | assert _default_handler is not None
186 | _get_library_root_logger().removeHandler(_default_handler)
187 |
188 |
189 | def enable_default_handler() -> None:
190 | """Enable the default handler of the HuggingFace Transformers's root logger."""
191 |
192 | _configure_library_root_logger()
193 |
194 | assert _default_handler is not None
195 | _get_library_root_logger().addHandler(_default_handler)
196 |
197 |
198 | def add_handler(handler: logging.Handler) -> None:
199 | """adds a handler to the HuggingFace Transformers's root logger."""
200 |
201 | _configure_library_root_logger()
202 |
203 | assert handler is not None
204 | _get_library_root_logger().addHandler(handler)
205 |
206 |
207 | def remove_handler(handler: logging.Handler) -> None:
208 | """removes given handler from the HuggingFace Transformers's root logger."""
209 |
210 | _configure_library_root_logger()
211 |
212 | assert handler is not None and handler not in _get_library_root_logger().handlers
213 | _get_library_root_logger().removeHandler(handler)
214 |
215 |
216 | def disable_propagation() -> None:
217 | """
218 | Disable propagation of the library log outputs. Note that log propagation is disabled by default.
219 | """
220 |
221 | _configure_library_root_logger()
222 | _get_library_root_logger().propagate = False
223 |
224 |
225 | def enable_propagation() -> None:
226 | """
227 | Enable propagation of the library log outputs. Please disable the HuggingFace Transformers's default handler to
228 | prevent double logging if the root logger has been configured.
229 | """
230 |
231 | _configure_library_root_logger()
232 | _get_library_root_logger().propagate = True
233 |
234 |
235 | def enable_explicit_format() -> None:
236 | """
237 | Enable explicit formatting for every HuggingFace Transformers's logger. The explicit formatter is as follows:
238 |
239 | ::
240 |
241 | [LEVELNAME|FILENAME|LINE NUMBER] TIME >> MESSAGE
242 |
243 | All handlers currently bound to the root logger are affected by this method.
244 | """
245 | handlers = _get_library_root_logger().handlers
246 |
247 | for handler in handlers:
248 | formatter = logging.Formatter("[%(levelname)s|%(filename)s:%(lineno)s] %(asctime)s >> %(message)s")
249 | handler.setFormatter(formatter)
250 |
251 |
252 | def reset_format() -> None:
253 | """
254 | Resets the formatting for HuggingFace Transformers's loggers.
255 |
256 | All handlers currently bound to the root logger are affected by this method.
257 | """
258 | handlers = _get_library_root_logger().handlers
259 |
260 | for handler in handlers:
261 | handler.setFormatter(None)
262 |
--------------------------------------------------------------------------------
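Usage sketch (not part of the repository): typical use of the vendored logging helpers, assuming the `hubert` package is importable from the repo root. The log messages are illustrative.

```python
from hubert.utils import logging

logging.set_verbosity_info()       # show INFO and above on the library root logger
logging.enable_explicit_format()   # "[LEVELNAME|FILENAME:LINE] TIME >> MESSAGE"

logger = logging.get_logger()      # defaults to the library root logger
logger.info("loading pretrained HuBERT weights")
logger.warning("CUDA not available, falling back to CPU")
```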
/hubert/utils/versions.py:
--------------------------------------------------------------------------------
1 | # Copyright 2020 The HuggingFace Team. All rights reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | """
15 | Utilities for working with package versions
16 | """
17 |
18 | import operator
19 | import re
20 | import sys
21 | from typing import Optional
22 |
23 | from packaging import version
24 |
25 |
26 | # The package importlib_metadata is in a different place, depending on the python version.
27 | if sys.version_info < (3, 8):
28 | import importlib_metadata
29 | else:
30 | import importlib.metadata as importlib_metadata
31 |
32 |
33 | ops = {
34 | "<": operator.lt,
35 | "<=": operator.le,
36 | "==": operator.eq,
37 | "!=": operator.ne,
38 | ">=": operator.ge,
39 | ">": operator.gt,
40 | }
41 |
42 |
43 | def _compare_versions(op, got_ver, want_ver, requirement, pkg, hint):
44 | if got_ver is None:
45 | raise ValueError("got_ver is None")
46 | if want_ver is None:
47 | raise ValueError("want_ver is None")
48 | if not ops[op](version.parse(got_ver), version.parse(want_ver)):
49 | raise ImportError(
50 | f"{requirement} is required for a normal functioning of this module, but found {pkg}=={got_ver}.{hint}"
51 | )
52 |
53 |
54 | def require_version(requirement: str, hint: Optional[str] = None) -> None:
55 | """
56 | Perform a runtime check of the dependency versions, using the exact same syntax used by pip.
57 |
58 | The installed module version comes from the `site-packages` dir via `importlib_metadata`.
59 |
60 | Args:
61 | requirement (:obj:`str`): pip style definition, e.g., "tokenizers==0.9.4", "tqdm>=4.27", "numpy"
62 | hint (:obj:`str`, `optional`): what suggestion to print in case of requirements not being met
63 |
64 | Example::
65 |
66 | require_version("pandas>1.1.2")
67 | require_version("numpy>1.18.5", "this is important to have for whatever reason")
68 |
69 | """
70 |
71 | hint = f"\n{hint}" if hint is not None else ""
72 |
73 | # non-versioned check
74 | if re.match(r"^[\w_\-\d]+$", requirement):
75 | pkg, op, want_ver = requirement, None, None
76 | else:
77 | match = re.findall(r"^([^!=<>\s]+)([\s!=<>]{1,2}.+)", requirement)
78 | if not match:
79 | raise ValueError(
80 | f"requirement needs to be in the pip package format, .e.g., package_a==1.23, or package_b>=1.23, but got {requirement}"
81 | )
82 | pkg, want_full = match[0]
83 | want_range = want_full.split(",") # there could be multiple requirements
84 | wanted = {}
85 | for w in want_range:
86 | match = re.findall(r"^([\s!=<>]{1,2})(.+)", w)
87 | if not match:
88 | raise ValueError(
89 | f"requirement needs to be in the pip package format, .e.g., package_a==1.23, or package_b>=1.23, but got {requirement}"
90 | )
91 | op, want_ver = match[0]
92 | wanted[op] = want_ver
93 | if op not in ops:
94 | raise ValueError(f"{requirement}: need one of {list(ops.keys())}, but got {op}")
95 |
96 | # special case
97 | if pkg == "python":
98 | got_ver = ".".join([str(x) for x in sys.version_info[:3]])
99 | for op, want_ver in wanted.items():
100 | _compare_versions(op, got_ver, want_ver, requirement, pkg, hint)
101 | return
102 |
103 | # check if any version is installed
104 | try:
105 | got_ver = importlib_metadata.version(pkg)
106 | except importlib_metadata.PackageNotFoundError:
107 | raise importlib_metadata.PackageNotFoundError(
108 | f"The '{requirement}' distribution was not found and is required by this application. {hint}"
109 | )
110 |
111 | # check that the right version is installed if version number or a range was provided
112 | if want_ver is not None:
113 | for op, want_ver in wanted.items():
114 | _compare_versions(op, got_ver, want_ver, requirement, pkg, hint)
115 |
116 |
117 | def require_version_core(requirement):
118 | """require_version wrapper which emits a core-specific hint on failure"""
119 | hint = "Try: pip install transformers -U or pip install -e '.[dev]' if you're working with git master"
120 | return require_version(requirement, hint)
121 |
--------------------------------------------------------------------------------
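Usage sketch (not part of the repository): runtime dependency checks with `require_version`, assuming the `hubert` package is importable from the repo root. The requirement strings and the hint text are illustrative.

```python
from hubert.utils.versions import require_version, require_version_core

require_version("torch>=1.6.0", "PyTorch is needed for training and inference")
require_version("numpy")                 # presence-only check, no version constraint
require_version_core("packaging>=20.0")  # same check, but with the core-specific hint on failure
```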
/index.html:
--------------------------------------------------------------------------------
121 | FaceXHuBERT
141 | FaceXHuBERT: Text-less Speech-driven E(X)pressive 3D Facial Animation Synthesis Using Self-Supervised Speech Representation Learning (ICMI '23)
142 |
208 | This paper presents FaceXHuBERT, a text-less speech-driven 3D facial animation generation method that captures personalized and subtle cues in speech (e.g. identity, emotion and hesitation). It is also very robust to background noise and can handle audio recorded in a variety of situations (e.g. multiple people speaking). Recent approaches employ end-to-end deep learning that takes both audio and text as input to generate facial animation for the whole face. However, the scarcity of publicly available expressive audio-3D facial animation datasets poses a major bottleneck. The resulting animations still have issues with accurate lip-synching, expressivity, person-specific information and generalizability. We effectively employ a self-supervised pretrained HuBERT model in the training process, which allows us to incorporate both lexical and non-lexical information in the audio without using a large lexicon. Additionally, guiding the training with a binary emotion condition and speaker identity distinguishes even the subtlest facial motions. We carried out extensive objective and subjective evaluation in comparison to ground truth and state-of-the-art work. A perceptual user study demonstrates that our approach produces animations judged more realistic than the state of the art 78% of the time. In addition, our method is 4 times faster, as it eliminates the use of complex sequential models such as transformers. We strongly recommend watching the supplementary video before reading the paper.
209 |