Skip to content

Commit 39d6b12

Browse files
committed
Add python interface
1 parent 65b239a commit 39d6b12

13 files changed

Lines changed: 227 additions & 38 deletions

.gitmodules

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
[submodule "extern/pybind11"]
2+
path = extern/pybind11
3+
url = https://github.com/pybind/pybind11

Makefile

Lines changed: 27 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -2,9 +2,13 @@
22

33
NVCC=nvcc
44
CXX=g++
5+
#To compile the python wrapper
6+
PYTHON3=python3
7+
#Pybind is cloned as a submodule to this location
8+
PYBIND_INCLUDE=extern/pybind11/include
59

610
#Uncomment for a GPU enabled library
7-
CUDA_ENABLED=-DCUDA_ENABLED
11+
#CUDA_ENABLED=-DCUDA_ENABLED
812

913
#Uncomment to compile in double precision mode
1014
#DOUBLE_PRECISION=-DDOUBLE_PRECISION
@@ -15,31 +19,44 @@ NVCCLDFLAGS= -lcublas
1519
LDFLAGS= -llapacke -lcblas
1620

1721
LIBNAME=liblanczos.so
22+
PYTHON_MODULE_NAME=Lanczos
1823

1924

20-
CXXFLAGS=-fPIC -w -O3 -g -std=c++14 $(INCLUDEFLAGS) $(DOUBLE_PRECISION) $(CUDA_ENABLED)
25+
CXXFLAGS=-fPIC -w -O3 -g -std=c++14 $(INCLUDEFLAGS) $(DOUBLE_PRECISION)
2126
NVCCFLAGS=-ccbin=$(CXX) -Xcompiler "$(CXXFLAGS)" -std=c++14 -O3 $(INCLUDEFLAGS) $(DOUBLE_PRECISION) $(CUDA_ENABLED)
2227

28+
PYTHON_LIBRARY_NAME=python/$(PYTHON_MODULE_NAME)$(shell $(PYTHON3)-config --extension-suffix)
29+
2330
ifndef CUDA_ENABLED
2431
COMPILER=$(CXX)
25-
CXXFLAGS:=$(CXXFLAGS) -xc++
32+
CXXFLAGS_BOTH:=$(CXXFLAGS) -xc++
33+
LDFLAGS_BOTH:=$(LDFLAGS)
2634
else
2735
COMPILER=$(NVCC)
28-
LDFLAGS:=$(LDFLAGS) $(NVCCLDFLAGS)
29-
CXXFLAGS:=$(NVCCFLAGS) -I$(CUDA_ROOT)/include
36+
LDFLAGS_BOTH:=$(LDFLAGS) $(NVCCLDFLAGS)
37+
CXXFLAGS_BOTH:=$(NVCCFLAGS) -I$(CUDA_ROOT)/include $(CUDA_ENABLED)
3038
endif
3139

32-
all: shared $(patsubst %.cu, %, $(wildcard *.cu)) $(patsubst %.cpp, %, $(wildcard *.cpp))
40+
all: shared python $(patsubst %.cu, %, $(wildcard *.cu)) $(patsubst %.cpp, %, $(wildcard *.cpp)) Makefile
3341

3442
$(LIBNAME): $(wildcard include/*.cu)
35-
$(COMPILER) -DSHARED_LIBRARY_COMPILATION -shared $(CXXFLAGS) $^ -o $@ $(LDFLAGS)
43+
$(COMPILER) -DSHARED_LIBRARY_COMPILATION -shared $(CXXFLAGS_BOTH) $^ -o $@ $(LDFLAGS_BOTH)
44+
45+
46+
shared: $(LIBNAME) Makefile
47+
3648

49+
python: $(PYTHON_LIBRARY_NAME) Makefile
50+
# -DLANCZOS_PYTHON_NAME=$(PYTHON_MODULE_NAME)
3751

38-
shared: $(LIBNAME)
52+
$(PYTHON_LIBRARY_NAME): python/python_wrapper.cpp python/lanczos_trampoline.o
53+
$(CXX) $(CXXFLAGS) `$(PYTHON3)-config --includes` -I $(PYBIND_INCLUDE) -shared $^ -o $@ $(LDFLAGS)
3954

55+
python/lanczos_trampoline.o: python/lanczos_trampoline.cpp Makefile
56+
$(CXX) $(CXXFLAGS) -c $< -o $@
4057

4158
%: %.cu Makefile
42-
$(COMPILER) $(CXXFLAGS) $< -o $@ $(LDFLAGS)
59+
$(COMPILER) $(CXXFLAGS_BOTH) $< -o $@ $(LDFLAGS_BOTH)
4360

4461

4562

@@ -48,4 +65,4 @@ shared: $(LIBNAME)
4865
#%.clean:
4966
#rm -f $(@:.clean=.so)
5067
clean:
51-
rm -rf include/*.o $(LIBNAME) example
68+
rm -rf include/*.o python/*.o $(LIBNAME) example $(PYTHON_LIBRARY_NAME)

README.md

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -78,6 +78,14 @@ Note, however, that the heavy-weight of this solver comes from the Matrix-vector
7878
7979
See the Makefile for further instructions.
8080
81+
## Python interface
82+
83+
The python/ folder contains a python wrapper to the solver. A class defining the matrix-vector product can be written directly in python and provided to the solver.
84+
See python/example.py for more information.
85+
86+
The root folder's Makefile will try to compile the python library as well. It expects pybind11 to be placed under the extern/ folder. Pybind11 is included as a submodule, so make sure to clone this repository with --recursive.
87+
Note that the python wrapper can only be compiled in CPU mode.
88+
8189
## References:
8290
8391
[1] Krylov subspace methods for computing hydrodynamic interactions in Brownian dynamics simulations J. Chem. Phys. 137, 064106 (2012); doi: 10.1063/1.4742347

example.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -23,10 +23,10 @@ struct DiagonalMatrix: public lanczos::MatrixDot{
2323
int size;
2424
DiagonalMatrix(int size): size(size){}
2525

26-
void operator()(real* v, real* Mv){
26+
void dot(real* v, real* Mv) override{
2727
//an example diagonal matrix
2828
for(int i=0; i<size; i++){
29-
Mv[i] = (2+i/10.0)*v[i];
29+
Mv[i] = (2+i/10.0)*v[i]*2;
3030
}
3131
}
3232

example.cu

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@ struct DiagonalMatrix: public lanczos::MatrixDot{
2323
int size;
2424
DiagonalMatrix(int size): size(size){}
2525

26-
void operator()(real* v, real* Mv){
26+
void dot(real* v, real* Mv) override{
2727
//An example diagonal matrix
2828
for(int i=0; i<size; i++){
2929
Mv[i] = (2+i/10.0)*v[i];

extern/pybind11

Submodule pybind11 added at 9ec1128

include/LanczosAlgorithm.cu

Lines changed: 19 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,10 @@ References:
99
#include<string.h>
1010
#include"utils/lapack_and_blas_defines.h"
1111
#include<stdexcept>
12+
#ifdef CUDA_ENABLED
13+
#include"utils/debugTools.h"
14+
#endif
15+
1216
namespace lanczos{
1317

1418
Solver::Solver(real tolerance):
@@ -23,6 +27,18 @@ namespace lanczos{
2327
#endif
2428
}
2529

30+
Solver::~Solver(){
31+
#ifdef CUDA_ENABLED
32+
CublasSafeCall(cublasDestroy(cublas_handle));
33+
#endif
34+
}
35+
36+
real* Solver::getV(int N){
37+
if(N != this->N) numElementsChanged(N);
38+
return detail::getRawPointer(V);
39+
}
40+
41+
2642
void Solver::numElementsChanged(int newN){
2743
this-> N = newN;
2844
try{
@@ -45,7 +61,7 @@ namespace lanczos{
4561
this->max_iter += inc;
4662
}
4763

48-
int Solver::solve(MatrixDot *dot, real *Bz, real*z, int N){
64+
int Solver::solve(MatrixDot *dot, real *Bz, const real*z, int N){
4965
//Handles the case of the number of elements changing since last call
5066
if(N != this->N){
5167
real * d_V = detail::getRawPointer(V);
@@ -103,7 +119,8 @@ namespace lanczos{
103119
real* d_V = detail::getRawPointer(V);
104120
real * d_w = detail::getRawPointer(w);
105121
/*w = D·vi*/
106-
dot->operator()(d_V+N*i, d_w);
122+
dot->setSize(N);
123+
dot->dot(d_V+N*i, d_w);
107124
if(i>0){
108125
/*w = w-h[i-1][i]·vi*/
109126
real alpha = -hsup[i-1];

include/LanczosAlgorithm.h

Lines changed: 7 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -24,49 +24,33 @@ Some notes:
2424
#include<vector>
2525
#include<memory>
2626
#include"utils/device_container.h"
27-
28-
#ifdef CUDA_ENABLED
29-
#include"utils/debugTools.h"
30-
#endif
27+
#include"utils/MatrixDot.h"
3128
namespace lanczos{
3229

33-
struct MatrixDot{
34-
35-
virtual void operator()(real* Mv, real*v) = 0;
36-
37-
};
38-
3930
struct Solver{
4031
Solver(real tolerance = 1e-3);
4132

42-
~Solver(){
43-
#ifdef CUDA_ENABLED
44-
CublasSafeCall(cublasDestroy(cublas_handle));
45-
#endif
46-
}
47-
33+
~Solver();
34+
4835
//Given a MatrixDot functor that computes a product M·v (where M is handled by the functor), computes Bv = sqrt(M)·v
4936
//Returns the number of iterations performed
5037
//B = sqrt(M)
51-
int solve(MatrixDot *dot, real *Bv, real* v, int N);
38+
int solve(MatrixDot *dot, real *Bv, const real* v, int N);
5239

5340
//Overload for a shared_ptr
54-
int solve(std::shared_ptr<MatrixDot> dot, real *Bv, real* v, int N){
41+
int solve(std::shared_ptr<MatrixDot> dot, real *Bv, const real* v, int N){
5542
return this->solve(dot.get(), Bv, v, N);
5643
}
5744

5845
//Overload for an instance
5946
template<class SomeDot>
60-
int solve(SomeDot &dot, real *Bv, real* v, int N){
47+
int solve(SomeDot &dot, real *Bv, const real* v, int N){
6148
MatrixDot* ptr = static_cast<MatrixDot*>(&dot);
6249
return this->solve(ptr, Bv, v, N);
6350
}
6451

6552
//You can use this array as input to the solve operation, which will save some memory
66-
real * getV(int N){
67-
if(N != this->N) numElementsChanged(N);
68-
return detail::getRawPointer(V);
69-
}
53+
real * getV(int N);
7054

7155
#ifdef CUDA_ENABLED
7256
//The solver will use this cuda stream when possible

include/utils/MatrixDot.h

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,13 @@
1+
#ifndef LANCZOS_MATRIX_DOT_H
2+
#define LANCZOS_MATRIX_DOT_H
3+
#include"defines.h"
4+
namespace lanczos{
5+
6+
struct MatrixDot{
7+
void setSize(int newsize){this->m_size = newsize;}
8+
virtual void dot(real* v, real*Mv) = 0;
9+
protected:
10+
int m_size;
11+
};
12+
}
13+
#endif

python/example.py

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,38 @@
1+
#Raul P. Pelaez 2022. Usage example for the Lanczos solver's Python interface
2+
#A class that computes the dot product of a matrix, M, and an arbitrary vector, v, must be written to use the solver (see DiagonalMatrix below).
3+
#The class must inherit from Lanczos.MatrixDot and provide a function called "dot" that given an arbitrary vector, v, returns the product Mv.
4+
#When provided with an instance of this class, the function "solve" in Lanczos.Solver will return the product sqrt(M)v
5+
#IMPORTANT: Remember to use the same numerical precision here and when compiling the library (see the Makefile for more info)
6+
#Try help(Lanczos)
7+
8+
import Lanczos
9+
import numpy as np
10+
11+
#Lanczos provides the precision it was compiled in via this function.
12+
precision = np.float32 if Lanczos.getPrecision() else np.float64;
13+
14+
15+
# A simple class that computes the product of a diagonal matrix (2*I) by the input vector
16+
class DiagonalMatrix(Lanczos.MatrixDot):
17+
18+
def dot(self, v):
19+
# size=v.size()
20+
Mv = v*2.0
21+
return Mv
22+
23+
#Create the solver and provide a tolerance
24+
solver = Lanczos.Solver(tolerance=1e-3)
25+
26+
#Let us compute the result of sqrt(2*I)*v, where v=[1,1,1....1] and I the identity matrix
27+
#The result vector will be filled with sqrt(2)
28+
size = 1000000
29+
result = np.zeros(size, precision);
30+
v = np.ones(size, precision);
31+
32+
dotProduct = DiagonalMatrix()
33+
#The solve function fills the result vector with sqrt(M)*v and returns the number of iterations required to do so.
34+
numiter = solver.solve(dotProduct, result,v, size)
35+
36+
print("Done after "+ str(numiter) + " iterations.")
37+
print("Result vector (should be filled with ~sqrt(2)="+str(np.sqrt(2))+"):")
38+
print(result)

0 commit comments

Comments
 (0)