├── README.md
├── ala2
├── .DS_Store
├── 1_unbiased
│ ├── INPUTS.A
│ └── INPUTS.B
├── 2_training_model
│ ├── model.pt
│ └── training.png
└── 3_enhanced_sampling
│ ├── ala2.tpr
│ └── plumed.dat
├── aldol
├── .DS_Store
├── 1_unbiased
│ ├── INPUTS.P
│ ├── INPUTS.R
│ ├── inp.Aldol_pm6-P
│ └── inp.Aldol_pm6-R
├── 2_training_model
│ ├── full_model.pt
│ └── model.pt
└── 3_enhanced_sampling
│ ├── Aldol_pm6.inp
│ ├── Contacts.cpp
│ └── plumed.dat
└── code
├── PytorchModel.cpp
└── Tutorial - DeepLDA training.ipynb
/README.md:
--------------------------------------------------------------------------------
1 | # Data-driven collective variables for enhanced sampling
2 | #### Luigi Bonati, Valerio Rizzi, and Michele Parrinello, _J. Phys. Chem. Lett._ 11, 2998-3004 (2020).
3 |
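4 | [DOI: 10.1021/acs.jpclett.0c00535](https://doi.org/10.1021/acs.jpclett.0c00535)
5 | [arXiv:2002.06562](https://arxiv.org/abs/2002.06562)
6 | [PLUMED-NEST (plumID:20.004)](https://www.plumed-nest.org/eggs/20/004/)
7 | [Materials Cloud archive: 10.24435/materialscloud:2020.0035/v1](https://doi.org/10.24435/materialscloud:2020.0035/v1)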
8 |
9 | > [!IMPORTANT]
10 | > This repository is kept as supporting material for the manuscript, but it is no longer updated. Check out the [mlcolvar](https://mlcolvar.readthedocs.io) library for data-driven CVs, where you can find up-to-date tutorials and examples.
11 | >
12 | > [mlcolvar documentation](https://mlcolvar.readthedocs.io)
13 |
14 |
15 | This repository contains:
16 | 1. the code necessary to train and use a neural-network collective variable optimized with Fisher's discriminant analysis
17 | 2. the input files to reproduce the simulations reported in the paper
18 |
19 | #### Requirements
20 | - Pytorch and LibTorch (v == 1.4)
21 | - PLUMED2
22 |
23 | #### Tutorial
24 | Here you can find a [Google Colab notebook](https://colab.research.google.com/drive/1dG0ohT75R-UZAFMf_cbYPNQwBaOsVaAA) with the code and instructions for Deep-LDA CV training and export.
25 |
26 | #### Code and results availability
27 | The code and input files are also available on the [PLUMED-NEST](https://www.plumed-nest.org/eggs/20/004/) archive, while the results of the simulations are available in the [Materials Cloud repository](https://archive.materialscloud.org/2020.0035/v1).
28 |
29 | #### Errata
30 | There is a typo in the definition of
below eq. 7. The correct formula is
31 |
.
32 |
33 | #### Contact
34 | If you have comments or questions please send an email to luigi bonati [at] phys chem ethz ch .
35 |
--------------------------------------------------------------------------------
/ala2/.DS_Store:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/luigibonati/data-driven-CVs/0f2f50c66bd5e3f8bacc5fd07e7350e93ff6742d/ala2/.DS_Store
--------------------------------------------------------------------------------
/ala2/2_training_model/model.pt:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/luigibonati/data-driven-CVs/0f2f50c66bd5e3f8bacc5fd07e7350e93ff6742d/ala2/2_training_model/model.pt
--------------------------------------------------------------------------------
/ala2/2_training_model/training.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/luigibonati/data-driven-CVs/0f2f50c66bd5e3f8bacc5fd07e7350e93ff6742d/ala2/2_training_model/training.png
--------------------------------------------------------------------------------
/ala2/3_enhanced_sampling/ala2.tpr:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/luigibonati/data-driven-CVs/0f2f50c66bd5e3f8bacc5fd07e7350e93ff6742d/ala2/3_enhanced_sampling/ala2.tpr
--------------------------------------------------------------------------------
/ala2/3_enhanced_sampling/plumed.dat:
--------------------------------------------------------------------------------
1 | # vim:ft=plumed
2 |
3 | #DEFINE INPUTS: DISTANCES
4 | d1: DISTANCE ATOMS=2,5
5 | d2: DISTANCE ATOMS=2,6
6 | d3: DISTANCE ATOMS=2,7
7 | d4: DISTANCE ATOMS=2,9
8 | d5: DISTANCE ATOMS=2,11
9 | d6: DISTANCE ATOMS=2,15
10 | d7: DISTANCE ATOMS=2,16
11 | d8: DISTANCE ATOMS=2,17
12 | d9: DISTANCE ATOMS=2,19
13 | d10: DISTANCE ATOMS=5,6
14 | d11: DISTANCE ATOMS=5,7
15 | d12: DISTANCE ATOMS=5,9
16 | d13: DISTANCE ATOMS=5,11
17 | d14: DISTANCE ATOMS=5,15
18 | d15: DISTANCE ATOMS=5,16
19 | d16: DISTANCE ATOMS=5,17
20 | d17: DISTANCE ATOMS=5,19
21 | d18: DISTANCE ATOMS=6,7
22 | d19: DISTANCE ATOMS=6,9
23 | d20: DISTANCE ATOMS=6,11
24 | d21: DISTANCE ATOMS=6,15
25 | d22: DISTANCE ATOMS=6,16
26 | d23: DISTANCE ATOMS=6,17
27 | d24: DISTANCE ATOMS=6,19
28 | d25: DISTANCE ATOMS=7,9
29 | d26: DISTANCE ATOMS=7,11
30 | d27: DISTANCE ATOMS=7,15
31 | d28: DISTANCE ATOMS=7,16
32 | d29: DISTANCE ATOMS=7,17
33 | d30: DISTANCE ATOMS=7,19
34 | d31: DISTANCE ATOMS=9,11
35 | d32: DISTANCE ATOMS=9,15
36 | d33: DISTANCE ATOMS=9,16
37 | d34: DISTANCE ATOMS=9,17
38 | d35: DISTANCE ATOMS=9,19
39 | d36: DISTANCE ATOMS=11,15
40 | d37: DISTANCE ATOMS=11,16
41 | d38: DISTANCE ATOMS=11,17
42 | d39: DISTANCE ATOMS=11,19
43 | d40: DISTANCE ATOMS=15,16
44 | d41: DISTANCE ATOMS=15,17
45 | d42: DISTANCE ATOMS=15,19
46 | d43: DISTANCE ATOMS=16,17
47 | d44: DISTANCE ATOMS=16,19
48 | d45: DISTANCE ATOMS=17,19
49 |
50 | #LOAD PYTORCH MODEL
51 | deep: PYTORCH_MODEL MODEL=../2_training_model/model.pt ARG=d1,d2,d3,d4,d5,d6,d7,d8,d9,d10,d11,d12,d13,d14,d15,d16,d17,d18,d19,d20,d21,d22,d23,d24,d25,d26,d27,d28,d29,d30,d31,d32,d33,d34,d35,d36,d37,d38,d39,d40,d41,d42,d43,d44,d45
52 |
53 | #DEFINE OPES CALCULATION
54 | OPES_METAD ...
55 | LABEL=opes
56 | ARG=deep.node-0
57 | PACE=500
58 | SIGMA=0.025
59 | BARRIER=30
60 | ... OPES_METAD
61 |
62 | #MONITOR DIHEDRAL ANGLES
63 | phi: TORSION ATOMS=5,7,9,15
64 | psi: TORSION ATOMS=7,9,15,17
65 |
66 | #PRINT
67 | PRINT STRIDE=500 ARG=deep.node-0,phi,psi,opes.* FILE=COLVAR
68 |
69 | ENDPLUMED
70 |
71 | #alternative: WT-METAD CALCULATION
72 | METAD ...
73 | LABEL=meta
74 | ARG=deep.node-0
75 | PACE=500
76 | HEIGHT=1.25
77 | SIGMA=0.025
78 | BIASFACTOR=6.00
79 | TEMP=300.0
80 | GRID_MIN=-2
81 | GRID_MAX=2
82 | GRID_BIN=750
83 | CALC_RCT
84 | ... METAD
85 |
--------------------------------------------------------------------------------
/aldol/.DS_Store:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/luigibonati/data-driven-CVs/0f2f50c66bd5e3f8bacc5fd07e7350e93ff6742d/aldol/.DS_Store
--------------------------------------------------------------------------------
/aldol/1_unbiased/inp.Aldol_pm6-P:
--------------------------------------------------------------------------------
1 | &GLOBAL
2 | PROJECT Aldol_pm6_P
3 | RUN_TYPE MD
4 | &END GLOBAL
5 |
6 | #&EXT_RESTART
7 | #RESTART_FILE_NAME Aldol_pm6_P-1.restart
8 | #&END
9 |
10 | &MOTION
11 | &MD
12 | ENSEMBLE NVT
13 | STEPS 10000000
14 | TIMESTEP 0.5
15 | TEMPERATURE 300
16 | &THERMOSTAT
17 | &CSVR
18 | TIMECON 100
19 | &END
20 | &END THERMOSTAT
21 | &END MD
22 | &PRINT
23 | &VELOCITIES OFF
24 | &END
25 | &FORCES
26 | &EACH
27 | MD 200
28 | &END
29 | &END
30 | &TRAJECTORY
31 | &EACH
32 | MD 200
33 | &END
34 | &END
35 | &RESTART
36 | &EACH
37 | MD 200
38 | &END
39 | &END
40 | &RESTART_HISTORY
41 | &EACH
42 | MD 20000
43 | &END
44 | &END
45 | &END PRINT
46 |
47 | &FREE_ENERGY
48 | &METADYN
49 | USE_PLUMED .TRUE.
50 | PLUMED_INPUT_FILE ./plumed.dat
51 | &END METADYN
52 | &END FREE_ENERGY
53 |
54 | &END MOTION
55 |
56 | &FORCE_EVAL
57 | METHOD Quickstep
58 | &DFT
59 | &QS
60 | METHOD PM6
61 | &SE
62 | &END
63 | &END QS
64 | &SCF
65 | SCF_GUESS ATOMIC
66 | # SCF_GUESS RESTART
67 | &END SCF
68 | &END DFT
69 | &SUBSYS
70 | &CELL
71 | ABC 15.0 15.0 15.0
72 | PERIODIC NONE
73 | &END CELL
74 | &COORD
75 | C -1.7398459870 1.1338848810 2.2335126395
76 | H -2.5828441792 1.7963437143 2.0707686162
77 | H -2.1967426647 0.1002722800 2.2184518098
78 | C -0.9120671394 1.4445501695 1.0289301080
79 | H -1.3917437438 1.0910381640 0.1579462762
80 | O 0.1430652785 1.9656639332 0.9898175075
81 | H -2.0915615390 3.1244075320 3.6611573450
82 | C -1.0171291761 1.4112560790 3.5721405872
83 | H -0.0051832519 1.7402738743 3.4158576545
84 | H -1.0077374295 0.6475586503 4.3768979920
85 | O -1.7899346604 2.4130111648 4.2532004479
86 | &END COORD
87 | &END SUBSYS
88 | &END FORCE_EVAL
89 |
90 |
--------------------------------------------------------------------------------
/aldol/1_unbiased/inp.Aldol_pm6-R:
--------------------------------------------------------------------------------
1 | &GLOBAL
2 | PROJECT Aldol_pm6_R
3 | RUN_TYPE MD
4 | &END GLOBAL
5 |
6 | #&EXT_RESTART
7 | #RESTART_FILE_NAME Aldol_pm6_R-1.restart
8 | #&END
9 |
10 | &MOTION
11 | &MD
12 | ENSEMBLE NVT
13 | STEPS 10000000
14 | TIMESTEP 0.5
15 | TEMPERATURE 300
16 | &THERMOSTAT
17 | &CSVR
18 | TIMECON 100
19 | &END
20 | &END THERMOSTAT
21 | &END MD
22 | &PRINT
23 | &VELOCITIES OFF
24 | &END
25 | &FORCES
26 | &EACH
27 | MD 200
28 | &END
29 | &END
30 | &TRAJECTORY
31 | &EACH
32 | MD 200
33 | &END
34 | &END
35 | &RESTART
36 | &EACH
37 | MD 200
38 | &END
39 | &END
40 | &RESTART_HISTORY
41 | &EACH
42 | MD 20000
43 | &END
44 | &END
45 | &END PRINT
46 |
47 | &FREE_ENERGY
48 | &METADYN
49 | USE_PLUMED .TRUE.
50 | PLUMED_INPUT_FILE ./plumed.dat
51 | &END METADYN
52 | &END FREE_ENERGY
53 |
54 | &END MOTION
55 |
56 | &FORCE_EVAL
57 | METHOD Quickstep
58 | &DFT
59 | &QS
60 | METHOD PM6
61 | &SE
62 | &END
63 | &END QS
64 | &SCF
65 | SCF_GUESS ATOMIC
66 | # SCF_GUESS RESTART
67 | &END SCF
68 | &END DFT
69 | &SUBSYS
70 | &CELL
71 | ABC 15.0 15.0 15.0
72 | PERIODIC NONE
73 | &END CELL
74 | &COORD
75 | C -1.82559014 1.24990204 1.15842648
76 | H -1.29242639 0.32219712 1.15842648
77 | H -2.89559014 1.24990204 1.15842648
78 | C -1.15031583 2.42487933 1.15842648
79 | H -1.68347958 3.35258425 1.15842648
80 | O 0.27968417 2.42487933 1.15842648
81 | H 0.60013876 3.32981517 1.15840911
82 | C -1.75435327 0.62506239 4.36185501
83 | H -1.22118953 -0.30264253 4.36185501
84 | H -2.82435327 0.62506239 4.36185501
85 | O -1.12731284 1.71611273 4.36185501
86 | &END COORD
87 | &END SUBSYS
88 | &END FORCE_EVAL
89 |
90 |
--------------------------------------------------------------------------------
/aldol/2_training_model/full_model.pt:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/luigibonati/data-driven-CVs/0f2f50c66bd5e3f8bacc5fd07e7350e93ff6742d/aldol/2_training_model/full_model.pt
--------------------------------------------------------------------------------
/aldol/2_training_model/model.pt:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/luigibonati/data-driven-CVs/0f2f50c66bd5e3f8bacc5fd07e7350e93ff6742d/aldol/2_training_model/model.pt
--------------------------------------------------------------------------------
/aldol/3_enhanced_sampling/Aldol_pm6.inp:
--------------------------------------------------------------------------------
1 | &GLOBAL
2 | PROJECT Aldol_pm6_R
3 | RUN_TYPE MD
4 | PRINT_LEVEL SILENT
5 | &END GLOBAL
6 |
7 | !&EXT_RESTART
8 | !RESTART_FILE_NAME Aldol_pm6_R-1.restart
9 | !&END
10 |
11 | &MOTION
12 | &MD
13 | ENSEMBLE NVT
14 | STEPS 20000000
15 | TIMESTEP 0.5
16 | TEMPERATURE 300
17 | &THERMOSTAT
18 | &CSVR
19 | TIMECON 100
20 | &END
21 | &END THERMOSTAT
22 | &END MD
23 | &PRINT
24 | &VELOCITIES OFF
25 | &END
26 | &FORCES
27 | &EACH
28 | MD 4000
29 | &END
30 | &END
31 | &TRAJECTORY
32 | &EACH
33 | MD 20
34 | &END
35 | &END
36 | &RESTART
37 | &EACH
38 | MD 1000
39 | &END
40 | &END
41 | &RESTART_HISTORY
42 | &EACH
43 | MD 50000
44 | &END
45 | &END
46 | &END PRINT
47 |
48 | &FREE_ENERGY
49 | &METADYN
50 | USE_PLUMED .TRUE.
51 | PLUMED_INPUT_FILE ./plumed.dat
52 | &END METADYN
53 | &END FREE_ENERGY
54 |
55 | &END MOTION
56 |
57 | &FORCE_EVAL
58 | METHOD Quickstep
59 | &DFT
60 | &QS
61 | METHOD PM6
62 | &END QS
63 | &SCF
64 | SCF_GUESS ATOMIC
65 | # SCF_GUESS RESTART
66 | &END SCF
67 | &END DFT
68 | &SUBSYS
69 | &CELL
70 | ABC 15.0 15.0 15.0
71 | PERIODIC NONE
72 | &END CELL
73 | &COORD
74 | C -0.4380529066 1.0289817184 2.4218640806
75 | H -0.8258218827 -0.0042005318 2.5547137238
76 | H 0.6310002384 0.9815436318 2.6588469857
77 | C -0.7536249062 1.4649971834 0.9661402699
78 | H -0.4400493227 0.6859268302 0.2052006415
79 | O -1.3955297332 2.4535993157 0.7293392173
80 | H -2.5800939393 1.4446323568 4.1921888754
81 | C -0.9189573958 2.0559768551 3.4418190661
82 | H -1.5189990031 2.8768450679 3.0258842458
83 | H -0.1063157930 2.5344538387 3.9092283136
84 | O -1.7006634559 1.5032705498 4.4766189541
85 | &END COORD
86 | &END SUBSYS
87 | &END FORCE_EVAL
88 |
89 |
--------------------------------------------------------------------------------
/aldol/3_enhanced_sampling/Contacts.cpp:
--------------------------------------------------------------------------------
1 | /* +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
2 | Copyright (c) 2011-2018 The plumed team
3 | (see the PEOPLE file at the root of the distribution for a list of names)
4 |
5 | See http://www.plumed.org for more information.
6 |
7 | This file is part of plumed, version 2.
8 |
9 | plumed is free software: you can redistribute it and/or modify
10 | it under the terms of the GNU Lesser General Public License as published by
11 | the Free Software Foundation, either version 3 of the License, or
12 | (at your option) any later version.
13 |
14 | plumed is distributed in the hope that it will be useful,
15 | but WITHOUT ANY WARRANTY; without even the implied warranty of
16 | MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
17 | GNU Lesser General Public License for more details.
18 |
19 | You should have received a copy of the GNU Lesser General Public License
20 | along with plumed. If not, see .
21 | +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ */
22 | #include "core/Colvar.h"
23 | #include "core/ActionRegister.h"
24 | #include "tools/Pbc.h"
25 |
26 | #include
27 | #include
28 |
29 | #include
30 | #include
31 | #include
32 |
33 | #include "tools/Tools.h"
34 | #include "tools/SwitchingFunction.h"
35 | #include
36 |
37 | using namespace std;
38 |
39 | namespace PLMD {
40 | namespace colvar {
41 |
42 |
43 | class Contacts : public Colvar { // CONTACTS colvar: switching-function value of each atom pair, one output component per pair
44 | bool components; // per-pair components enabled (initialized true in the constructor)
45 | bool reorderon; // REORDER option (initialized false in the constructor)
46 | bool pbc; // initialized true in the constructor; how it is parsed/used is not visible in this listing
47 | SwitchingFunction switchingFunction; // configured in the constructor from SWITCH or from NN/MM/D_0/R_0
48 | bool twogroups; // false when pairs are taken inside GROUPA only
49 |
50 | public:
51 | static void registerKeywords( Keywords& keys );
52 | explicit Contacts(const ActionOptions&);
53 | //active methods:
54 | virtual void calculate();
55 | vector atomsa; // GROUPA atoms. NOTE(review): template argument stripped by extraction; presumably vector<AtomNumber> (filled by parseAtomList) - confirm
56 | vector atomsb; // GROUPB atoms (may be empty). NOTE(review): template argument stripped by extraction, as for atomsa
57 | int num_atomsa;
58 | int num_atomsb;
59 | int num_dist; // number of atom pairs, i.e. number of components c-0 ... c-(num_dist-1)
60 | };
61 |
62 | PLUMED_REGISTER_ACTION(Contacts,"CONTACTS")
63 |
64 | void Contacts::registerKeywords( Keywords& keys ) { // Declares the input keywords and output components of CONTACTS
65 | Colvar::registerKeywords( keys );
66 | keys.addFlag("REORDER",false,"reorder the list of contacts");
67 | keys.addFlag("COMPONENTS",true,"store one component per atom pair, label.c-0, label.c-1, ..., each holding the switching function of that pair distance"); // previous text (x, y, z components) was copied from DISTANCE and did not describe this action
68 | keys.add("atoms","GROUPA","First list of atoms");
69 | keys.add("atoms","GROUPB","Second list of atoms (if empty, N*(N-1)/2 pairs in GROUPA are counted)");
70 | keys.add("compulsory","NN","6","The n parameter of the switching function ");
71 | keys.add("compulsory","MM","0","The m parameter of the switching function; 0 implies 2*NN");
72 | keys.add("compulsory","D_0","0.0","The d_0 parameter of the switching function");
73 | keys.add("compulsory","R_0","The r_0 parameter of the switching function");
74 | keys.add("optional","SWITCH","This keyword is used if you want to employ an alternative to the continuous swiching function defined above. "
75 | "The following provides information on the \\ref switchingfunction that are available. "
76 | "When this keyword is present you no longer need the NN, MM, D_0 and R_0 keywords.");
77 | keys.addOutputComponent("c", "default", "one component c-k per atom pair k, set to the switching function of that pair distance");
78 | }
79 |
80 | Contacts::Contacts(const ActionOptions&ao): // Constructor: reads GROUPA/GROUPB, sizes the pair list and configures the switching function. NOTE(review): this listing lost all text between '<' and '>' during extraction and several source lines are fused; recover the original file before editing this block
81 | PLUMED_COLVAR_INIT(ao),
82 | components(true),
83 | reorderon(false),
84 | pbc(true),
85 | twogroups(true)
86 | {
87 | parseAtomList("GROUPA",atomsa); // first atom group
88 | parseAtomList("GROUPB",atomsb); // second atom group; may be empty (single-group branch below)
89 | num_atomsa = atomsa.size();
90 | num_atomsb = atomsb.size();
91 |
92 | vector atoms; // NOTE(review): template argument stripped by extraction; presumably vector<AtomNumber> - confirm
93 | for(unsigned int i=0; i0){ // NOTE(review): original lines 93-115 are fused here; the trailing "i0){" is presumably the remnant of a GROUPB-not-empty test
116 | num_dist = num_atomsa*num_atomsb; // two groups: one pair per (GROUPA, GROUPB) combination
117 | } else {
118 | num_dist = (num_atomsa*(num_atomsa-1))/2; // single group: N*(N-1)/2 distinct pairs inside GROUPA
119 | twogroups = false; // makes calculate() take the single-group branch
120 | }
121 |
122 | log.printf(" between two groups of %u and %u atoms\n",static_cast(num_atomsa),static_cast(num_atomsb)); // NOTE(review): static_cast target type stripped by extraction; presumably unsigned to match %u
123 | log.printf(" first group:\n");
124 | for(unsigned int i=0;i0){ // NOTE(review): original lines 124-141 are fused here; the trailing "i0){" presumably belongs to the SWITCH-keyword test
142 | switchingFunction.set(sw,errors);
143 | if( errors.length()!=0 ) error("problem reading SWITCH keyword : " + errors );
144 | } else {
145 | double r_0=-1.0, d_0; int nn, mm; // r_0 = -1 is a sentinel meaning "R_0 not given"
146 | parse("NN",nn); parse("MM",mm);
147 | parse("R_0",r_0); parse("D_0",d_0);
148 | if( r_0<0.0 ) error("you must set a value for R_0");
149 | switchingFunction.set(nn,mm,r_0,d_0);
150 | }
151 |
152 | //log.printf(" Tetrahedral order parameter calculated computing the angles with the atoms within %s\n",( switchingFunction.description() ).c_str() );
153 |
154 | if(components) {
155 | for(int i=0; i atom_pos=getPositions(); // NOTE(review): the end of the constructor and the start of calculate() (original lines 155-173) are fused on this line
174 | int N_atoms = atom_pos.size();
175 |
176 | std::vector dist(num_dist); // NOTE(review): element type stripped; presumably a vector type, since dist[ind] is scaled into setAtomsDerivatives below - confirm
177 | std::vector invdist(num_dist); // NOTE(review): element type stripped; presumably double - confirm
178 | std::vector distmod(num_dist); // NOTE(review): element type stripped; compared with '<' in the sort below, so a scalar
179 | double sw, df, df2; //switching functions and derivatives
180 |
181 | int ind = 0;
182 | if(twogroups) {
183 | for(int i=0; i indices(num_dist); // NOTE(review): original lines 183-210 are fused here; 'indices' is presumably vector<int>
211 | std::iota(begin(indices), end(indices), 0); // fill indices with 0, 1, ..., num_dist-1
212 | std::sort(begin(indices), end(indices),
213 | [&distmod](int lhs, int rhs) { return distmod[lhs] < distmod[rhs]; }); // order pair indices by increasing distmod
214 | //std::copy(begin(indices), end(indices), std::ostream_iterator(std::cout, " "));
215 | std::vector indicesnew(num_dist);
216 | for(int k=0; k A = {6.5, 2.2, 5.6, 7.1}; // NOTE(review): original lines 216-222 are fused here; the text from "A = {...}" down to the "*/" below appears to be a commented-out std::iota/std::sort example, not executed code
223 | std::array indices;
224 | std::iota(begin(indices), end(indices), 0); // initialize the index array
225 | // sort with custom comparator, that compares items not by their own value
226 | // but by the value of the corresponding entries in the original array.
227 | std::sort(begin(indices), end(indices),
228 | [&A](int lhs, int rhs) { return A[lhs] < A[rhs]; });
229 | std::copy(begin(indices), end(indices), std::ostream_iterator(std::cout, " "));
230 | */
231 |
232 | if(components) { // per-pair outputs are written only when components is set
233 | int ind = 0;
234 | if(twogroups) { // two-group branch: GROUPB atom j is addressed after the GROUPA atoms
235 | for(int i=0; iset(sw); // NOTE(review): original lines 235-257 are fused here (presumably including the reordered branch)
258 | }
259 | else{
260 | auto label = "c-"+std::to_string(ind); //NON REORDERED
261 |
262 | Value* val=getPntrToComponent(label);
263 | setAtomsDerivatives (val,i,-invdist[ind]*dist[ind]*df2); // derivative on the GROUPA atom of the pair
264 | setAtomsDerivatives (val,j+num_atomsa,invdist[ind]*dist[ind]*df2); // opposite derivative on the GROUPB atom (offset by num_atomsa)
265 | setBoxDerivativesNoPbc(val);
266 | val->set(sw); // component value = sw for this pair
267 | }
268 |
269 | ind = ind+1; // next pair
270 | }
271 | }
272 | } else{ // single-group branch: both atoms of the pair come from GROUPA
273 | for(int i=0; iset(sw); // NOTE(review): original lines 273-295 are fused here
296 | }
297 | else{
298 | auto label = "c-"+std::to_string(ind); //NON REORDERED
299 |
300 | Value* val=getPntrToComponent(label);
301 | setAtomsDerivatives (val,i,-invdist[ind]*dist[ind]*df2); // derivative on the first atom of the pair
302 | setAtomsDerivatives (val,j,invdist[ind]*dist[ind]*df2); // opposite derivative on the second atom (no offset: same group)
303 | setBoxDerivativesNoPbc(val);
304 | val->set(sw); // component value = sw for this pair
305 | }
306 |
307 | ind = ind+1; // next pair
308 | }
309 |
310 | }
311 | }
312 | }
313 |
314 | } // end of calculate()
315 |
316 | }
317 | }
318 |
319 |
320 |
321 |
--------------------------------------------------------------------------------
/aldol/3_enhanced_sampling/plumed.dat:
--------------------------------------------------------------------------------
1 | # vim:ft=plumed
2 |
3 | #UNITS
4 | UNITS LENGTH=A
5 |
6 | #LOAD FILES
7 | LOAD FILE=Contacts.cpp
8 |
9 | #DEFINE GROUP OF ATOMS
10 | C: GROUP ATOMS=1,4,8
11 | O: GROUP ATOMS=6,11
12 | H: GROUP ATOMS=2,3,5,7,9,10
13 |
14 | #DEFINE CONTACTS
15 | cc2: CONTACTS GROUPA=C SWITCH={RATIONAL D_0=0.0 R_0=1.7 NN=6 MM=8}
16 | oo2: CONTACTS GROUPA=O SWITCH={RATIONAL D_0=0.0 R_0=1.6 NN=6 MM=8}
17 | co2: CONTACTS GROUPA=C GROUPB=O SWITCH={RATIONAL D_0=0.0 R_0=1.6 NN=6 MM=8}
18 | ch2: CONTACTS GROUPA=C GROUPB=H SWITCH={RATIONAL D_0=0.0 R_0=1.2 NN=6 MM=8}
19 | oh2: CONTACTS GROUPA=O GROUPB=H SWITCH={RATIONAL D_0=0.0 R_0=1.2 NN=6 MM=8}
20 |
21 | #LOAD PYTORCH MODEL
22 | auto: PYTORCH_MODEL MODEL=../2_training_model/model.pt ARG=cc2.*,oo2.*,co2.*,ch2.*,oh2.*
23 | NNcube: MATHEVAL ARG=auto.node-0 FUNC=x+x^3 PERIODIC=NO
24 |
25 | #DEFINE WALLS
26 | com1: COM ATOMS=1-7
27 | com2: COM ATOMS=8-11
28 | dc1: DISTANCE ATOMS=com1,com2 NOPBC
29 | UPPER_WALLS ARG=dc1 AT=+5.0 KAPPA=150.0 EXP=2 LABEL=uwall_1
30 |
31 | nnuw: UPPER_WALLS ARG=NNcube AT=3.2 KAPPA=2000.0 EXP=2
32 | nnlw: LOWER_WALLS ARG=NNcube AT=-3.2 KAPPA=2000.0 EXP=2
33 |
34 | #OPES CALCULATION
35 | OPES_METAD ...
36 | LABEL=opes
37 | ARG=NNcube
38 | PACE=500
39 | BARRIER=160
40 | SIGMA=0.10
41 | TEMP=300
42 | ... OPES_METAD
43 |
44 | #MONITOR DISTANCES
45 | c1c8: DISTANCE ATOMS=1,8 NOPBC
46 | o6h7: DISTANCE ATOMS=6,7 NOPBC
47 | o11h7: DISTANCE ATOMS=11,7 NOPBC
48 |
49 | #PRINT
50 | PRINT STRIDE=20 FILE=COLVAR ARG=*
51 | FLUSH STRIDE=1000
52 |
--------------------------------------------------------------------------------
/code/PytorchModel.cpp:
--------------------------------------------------------------------------------
1 | /* +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
2 | Copyright (c) 2011-2018 The plumed team
3 | (see the PEOPLE file at the root of the distribution for a list of names)
4 |
5 | See http://www.plumed.org for more information.
6 |
7 | This file is part of plumed, version 2.
8 |
9 | plumed is free software: you can redistribute it and/or modify
10 | it under the terms of the GNU Lesser General Public License as published by
11 | the Free Software Foundation, either version 3 of the License, or
12 | (at your option) any later version.
13 |
14 | plumed is distributed in the hope that it will be useful,
15 | but WITHOUT ANY WARRANTY; without even the implied warranty of
16 | MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
17 | GNU Lesser General Public License for more details.
18 |
19 | You should have received a copy of the GNU Lesser General Public License
20 | along with plumed. If not, see .
21 | +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ */
22 | #include "function/Function.h"
23 | #include "core/ActionRegister.h"
24 |
25 | #include
26 | #include
27 |
28 | #include
29 |
30 | using namespace std;
31 |
32 | std::vector<float> tensor_to_vector(const torch::Tensor& x) { // Copies every element of x, in row-major order, into a flat std::vector<float>; x is assumed to hold float32 data (data<float>() requires it)
33 | const torch::Tensor c = x.contiguous(); return std::vector<float>(c.data<float>(), c.data<float>() + c.numel()); // contiguous() makes the raw-pointer walk valid even for strided views (no copy if x is already contiguous)
34 | }
35 |
36 | namespace PLMD {
37 | namespace function {
38 |
39 | //+PLUMEDOC FUNCTION PYTORCH MODEL
40 | /*
41 | Load a neural network trained with PyTorch and exported as a TorchScript model. The derivatives are computed with PyTorch's native backpropagation (autograd).
42 |
43 | \par Examples
44 | Define a model that takes as inputs two distances d1 and d2
45 |
46 | \plumedfile
47 | model: PYTORCH_MODEL MODEL=model.pt ARG=d1,d2
48 | \endplumedfile
49 |
50 | The N output nodes of the neural network are stored as the components "model.node-0", "model.node-1", ..., "model.node-(N-1)".
51 |
52 | */
53 | //+ENDPLUMEDOC
54 |
55 |
56 | // PLUMED function that evaluates a TorchScript network on its ARG values.
57 | class PytorchModel : public Function {
58 | // network deserialized from the MODEL file in the constructor
59 | torch::jit::script::Module _model;
60 | unsigned _n_in;  // number of network inputs (= number of ARG)
61 | unsigned _n_out; // number of network outputs (= number of node-j components)
62 | public:
63 | static void registerKeywords(Keywords& keys);
64 | explicit PytorchModel(const ActionOptions&);
65 | void calculate() override;
66 | };
67 |
68 | // expose the action in PLUMED input files as PYTORCH_MODEL
69 | PLUMED_REGISTER_ACTION(PytorchModel,"PYTORCH_MODEL")
70 |
71 | void PytorchModel::registerKeywords(Keywords& keys) { // Declares the input keywords and output components of PYTORCH_MODEL
72 | Function::registerKeywords(keys); // standard Function keywords
73 | keys.use("ARG"); // the network inputs, one value per argument
74 | keys.add("optional","MODEL","filename of the trained model"); // read in the constructor, which falls back to "model.pt"
75 | keys.addOutputComponent("node", "default", "NN outputs"); // created in the constructor as node-0 ... node-(N-1)
76 | }
77 |
78 | PytorchModel::PytorchModel(const ActionOptions&ao): // Reads MODEL, loads the TorchScript module, probes it once to count its outputs and creates one component node-j per output
79 | Action(ao),
80 | Function(ao)
81 | {
82 | //number of inputs of the model (one per ARG)
83 | _n_in=getNumberOfArguments();
84 |
85 | //parse model name (defaults to model.pt)
86 | std::string fname="model.pt";
87 | parse("MODEL",fname);
88 |
89 | //deserialize the model from file; any c10::Error raised while loading is reported with the file name
90 | try {
91 | _model = torch::jit::load(fname);
92 | }
93 | catch (const c10::Error&) {
94 | error("Cannot load Pytorch model from file: " + fname);
95 | }
96 |
97 | checkRead();
98 |
99 | //check the dimension of the output by evaluating a dummy (1, n_in) input
100 | log.printf("Checking output dimension:\n");
101 | std::vector<float> input_test (_n_in);
102 | torch::Tensor single_input = torch::tensor(input_test).view({1,_n_in});
103 | std::vector<torch::jit::IValue> inputs;
104 | inputs.push_back( single_input );
105 | torch::Tensor output = _model.forward( inputs ).toTensor();
106 | vector<float> cvs = tensor_to_vector (output);
107 | _n_out=cvs.size();
108 |
109 | //create components node-0 ... node-(n_out-1): non periodic, with derivatives
110 | for(unsigned j=0; j<_n_out; j++){
111 | string name_comp = "node-"+std::to_string(j);
112 | addComponentWithDerivatives( name_comp );
113 | componentIsNotPeriodic( name_comp );
114 | }
115 |
116 | //print log (printf needs c_str() for a std::string and %u for unsigned counts)
117 | log.printf("Pytorch Model Loaded: %s \n",fname.c_str());
118 | log.printf("Number of input: %u \n",_n_in);
119 | log.printf("Number of outputs: %u \n",_n_out);
120 |
121 | }
122 |
123 | void PytorchModel::calculate() { // Evaluates the network on the current ARG values, then sets each node-j value and its derivatives w.r.t. the arguments via autograd
124 |
125 | //retrieve arguments (stored as float: the network is assumed to take float32 inputs)
126 | vector<float> current_S(_n_in);
127 | for(unsigned i=0; i<_n_in; i++)
128 | current_S[i]=getArgument(i);
129 | //convert to a (1, n_in) tensor that tracks gradients w.r.t. the inputs
130 | torch::Tensor input_S = torch::tensor(current_S).view({1,_n_in});
131 | input_S.set_requires_grad(true);
132 | //convert to Ivalue
133 | std::vector<torch::jit::IValue> inputs;
134 | inputs.push_back( input_S );
135 | //calculate output
136 | torch::Tensor output = _model.forward( inputs ).toTensor();
137 | //set CV values
138 | vector<float> cvs = tensor_to_vector (output);
139 | for(unsigned j=0; j<_n_out; j++){
140 | string name_comp = "node-"+std::to_string(j);
141 | getPntrToComponent(name_comp)->set(cvs[j]);
142 | }
143 | //derivatives: one backward pass per output node, all through the same autograd graph
144 | for(unsigned j=0; j<_n_out; j++){
145 | //backpropagation; keep the graph alive for every node but the last, otherwise the second backward() fails
146 | output[0][j].backward(torch::Tensor(), /*keep_graph=*/ j+1<_n_out);
147 | //convert to vector
148 | vector<float> der = tensor_to_vector (input_S.grad() );
149 | string name_comp = "node-"+std::to_string(j);
150 | //set derivatives of component j
151 | for(unsigned i=0; i<_n_in; i++)
152 | setDerivative( getPntrToComponent(name_comp) ,i,der[i]);
153 | //reset gradients
154 | input_S.grad().zero_();
155 | //(required because backward() accumulates into input_S.grad() across nodes)
156 |
157 |
158 | }
159 |
160 | }
161 | }
162 | }
163 |
--------------------------------------------------------------------------------