-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathmlforce.i
More file actions
executable file
·156 lines (142 loc) · 6.06 KB
/
mlforce.i
File metadata and controls
executable file
·156 lines (142 loc) · 6.06 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
%module mlforce
%import(module="openmm") "swig/OpenMMSwigHeaders.i"
%include "swig/typemaps.i"
%include <std_string.i>
%include <std_vector.i>
%{
#include "PyTorchForce.h"
#include "OpenMM.h"
#include "OpenMMAmoeba.h"
#include "OpenMMDrude.h"
#include "openmm/RPMDIntegrator.h"
#include "openmm/RPMDMonteCarloBarostat.h"
%}
// Instantiate SWIG wrapper types for the std::vector specialisations used
// (directly or as element types) by the declarations in this file. The name in
// parentheses is the name the wrapped type receives in the target language.
// Each element vector is instantiated before the nested vector that contains it.
%template(vectori) std::vector<int>;
%template(vectorii) std::vector<std::vector<int> >;
%template(vectord) std::vector<double>;
%template(vectordd) std::vector<std::vector<double> >;
%template(vectorf) std::vector<float> ;
%template(vectorff) std::vector<std::vector<float> >;
namespace PyTorchPlugin {
// Wrapper declaration for PyTorchPlugin::PyTorchForce, a subclass of
// OpenMM::Force. Only the members listed here are exposed through SWIG.
// NOTE(review): the class itself is presumably defined in PyTorchForce.h; the
// meaning and units of each argument are not visible from this interface file,
// so confirm them against that header.
class PyTorchForce : public OpenMM::Force {
public:
    // Constructor. Containers are taken by value, scalars as plain
    // double/int.
    PyTorchForce(
        const std::string& file,
        std::vector<std::vector<double> > targetFeatures,
        const std::vector<int> particleIndices,
        const std::vector<double> signalForceWeights,
        double scale,
        int assignFreq,
        std::vector<std::vector<int> > restraintIndices,
        const std::vector<double> restraintDistances,
        double rmaxDelta,
        double restraintK,
        const std::vector<int> initialAssignment);

    // Read-only accessors. Most are named after the matching constructor
    // argument.
    const std::string& getFile() const;
    const double getScale() const;
    const int getAssignFreq() const;
    const std::vector<std::vector<double> > getTargetFeatures() const;
    const std::vector<int> getParticleIndices() const;
    const std::vector<double> getSignalForceWeights() const;
    const std::vector<int> getInitialAssignment() const;

    // Restraint accessors.
    // NOTE(review): getRestraintParams() presumably reports rmaxDelta and
    // restraintK — verify against the implementation.
    const std::vector<std::vector<int> > getRestraintIndices() const;
    const std::vector<double> getRestraintDistances() const;
    const std::vector<double> getRestraintParams() const;

    // Periodic-boundary-condition flag: setter and query.
    void setUsesPeriodicBoundaryConditions(bool periodic);
    bool usesPeriodicBoundaryConditions() const;

    // Global parameters: count, add one (name + default value), then read or
    // change the name and default value of the parameter at a given index.
    int getNumGlobalParameters() const;
    int addGlobalParameter(const std::string& name, double defaultValue);
    const std::string& getGlobalParameterName(int index) const;
    void setGlobalParameterName(int index, const std::string& name);
    double getGlobalParameterDefaultValue(int index) const;
    void setGlobalParameterDefaultValue(int index, double defaultValue);
};
// Wrapper declaration for PyTorchPlugin::PyTorchForceE2E, a subclass of
// OpenMM::Force. Only the members listed here are exposed through SWIG.
// NOTE(review): argument semantics (e.g. what scale/offset apply to) are not
// visible from this interface file — confirm against PyTorchForce.h.
class PyTorchForceE2E : public OpenMM::Force {
public:
    // Constructor.
    PyTorchForceE2E(
        const std::string& file,
        const std::vector<int> particleIndices,
        const std::vector<double> signalForceWeights,
        double scale,
        double offset,
        bool useLambda);

    // Read-only accessors. Most are named after the matching constructor
    // argument.
    const std::string& getFile() const;
    const double getScale() const;
    const double getOffset() const;
    const std::vector<int> getParticleIndices() const;
    const std::vector<double> getSignalForceWeights() const;
    // NOTE(review): presumably reflects the useLambda constructor flag.
    bool usesLambda() const;

    // Periodic-boundary-condition flag: setter and query.
    void setUsesPeriodicBoundaryConditions(bool periodic);
    bool usesPeriodicBoundaryConditions() const;

    // Global parameters: count, add one (name + default value), then read or
    // change the name and default value of the parameter at a given index.
    int getNumGlobalParameters() const;
    int addGlobalParameter(const std::string& name, double defaultValue);
    const std::string& getGlobalParameterName(int index) const;
    void setGlobalParameterName(int index, const std::string& name);
    double getGlobalParameterDefaultValue(int index) const;
    void setGlobalParameterDefaultValue(int index, double defaultValue);
};
// Wrapper declaration for PyTorchPlugin::PyTorchForceE2EDirect, a subclass of
// OpenMM::Force. Only the members listed here are exposed through SWIG.
// NOTE(review): the layout expected for edgeIndices (one inner vector per edge
// vs. per endpoint) is not visible from this interface file — confirm against
// PyTorchForce.h.
class PyTorchForceE2EDirect : public OpenMM::Force {
public:
    // Constructor.
    PyTorchForceE2EDirect(
        const std::string& file,
        const std::vector<int> particleIndices,
        const std::vector<double> signalForceWeights,
        double scale,
        const std::vector<int> atomTypes,
        const std::vector<std::vector<int> > edgeIndices,
        const std::vector<int> edgeTypes,
        bool useAttr);

    // Read-only accessors named after the matching constructor argument.
    const std::string& getFile() const;
    const double getScale() const;
    const std::vector<int> getParticleIndices() const;
    const std::vector<double> getSignalForceWeights() const;

    // Graph description accessors (atom types, edges, edge types, attribute
    // flag).
    const std::vector<int> getAtomTypes() const;
    const std::vector<std::vector<int> > getEdgeIndices() const;
    const std::vector<int> getEdgeTypes() const;
    const bool getUseAttr() const;

    // Periodic-boundary-condition flag: setter and query.
    void setUsesPeriodicBoundaryConditions(bool periodic);
    bool usesPeriodicBoundaryConditions() const;

    // Global parameters: count, add one (name + default value), then read or
    // change the name and default value of the parameter at a given index.
    int getNumGlobalParameters() const;
    int addGlobalParameter(const std::string& name, double defaultValue);
    const std::string& getGlobalParameterName(int index) const;
    void setGlobalParameterName(int index, const std::string& name);
    double getGlobalParameterDefaultValue(int index) const;
    void setGlobalParameterDefaultValue(int index, double defaultValue);
};
// Wrapper declaration for PyTorchPlugin::PyTorchForceE2EDiffConf, a subclass
// of OpenMM::Force. Only the members listed here are exposed through SWIG.
// NOTE(review): the constructor takes `atoms` and `bonds`, but the accessors
// declared below are getAtomTypes() and getEdgeIndices(). Presumably these
// return the same data; confirm the names against PyTorchForce.h, since SWIG
// wraps whatever is declared here.
class PyTorchForceE2EDiffConf : public OpenMM::Force {
public:
    // Constructor. The index lists are passed as vectors of int vectors, and
    // `encoding` as a vector of float vectors.
    PyTorchForceE2EDiffConf(
        const std::string& file,
        const std::vector<int> particleIndices,
        const std::vector<double> signalForceWeights,
        double scale,
        const std::vector<int> atoms,
        const std::vector<std::vector<int> > bonds,
        const std::vector<std::vector<int> > angles,
        const std::vector<std::vector<int> > propers,
        const std::vector<std::vector<int> > impropers,
        const std::vector<std::vector<int> > pairs,
        const std::vector<std::vector<int> > tetras,
        const std::vector<std::vector<int> > cistrans,
        const std::vector<std::vector<float> > encoding);

    // Read-only accessors named after the matching constructor argument.
    const std::string& getFile() const;
    const double getScale() const;
    const std::vector<int> getParticleIndices() const;
    const std::vector<double> getSignalForceWeights() const;

    // Topology/feature accessors.
    const std::vector<int> getAtomTypes() const;
    const std::vector<std::vector<int> > getEdgeIndices() const;
    const std::vector<std::vector<int> > getAngles() const;
    const std::vector<std::vector<int> > getPropers() const;
    const std::vector<std::vector<int> > getImpropers() const;
    const std::vector<std::vector<int> > getPairs() const;
    const std::vector<std::vector<int> > getTetras() const;
    const std::vector<std::vector<int> > getCisTrans() const;
    const std::vector<std::vector<float> > getEncoding() const;

    // Periodic-boundary-condition flag: setter and query.
    void setUsesPeriodicBoundaryConditions(bool periodic);
    bool usesPeriodicBoundaryConditions() const;

    // Global parameters: count, add one (name + default value), then read or
    // change the name and default value of the parameter at a given index.
    int getNumGlobalParameters() const;
    int addGlobalParameter(const std::string& name, double defaultValue);
    const std::string& getGlobalParameterName(int index) const;
    void setGlobalParameterName(int index, const std::string& name);
    double getGlobalParameterDefaultValue(int index) const;
    void setGlobalParameterDefaultValue(int index, double defaultValue);
};
}