Skip to content

Commit 64e0ae4

Browse files
committed
TPC splines: add a possibility to merge specific sectors
1 parent cd9df99 commit 64e0ae4

2 files changed

Lines changed: 139 additions & 14 deletions

File tree

Detectors/TPC/calibration/include/TPCCalibration/TPCFastSpaceChargeCorrectionHelper.h

Lines changed: 24 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,9 @@ using namespace o2::gpu;
4141

4242
class TPCFastSpaceChargeCorrectionHelper
4343
{
44+
public:
45+
using SectorScales = std::array<double, TPCFastTransformGeo::getNumberOfSectors()>;
46+
4447
public:
4548
/// _____________ Constructors / destructors __________________________
4649

@@ -115,15 +118,32 @@ class TPCFastSpaceChargeCorrectionHelper
115118
/// initialise inverse transformation from linear combination of several input corrections
116119
void initInverse(std::vector<o2::gpu::TPCFastSpaceChargeCorrection*>& corrections, const std::vector<float>& scaling, bool prn);
117120

118-
/// merge several corrections
121+
/// weighted add of several corrections
119122
/// \param mainCorrection main correction
120123
/// \param scale scaling factor for the main correction
121124
/// \param additionalCorrections vector of pairs of additional corrections and their scaling factors
122-
/// \param prn printout flag
123125
/// \return main correction merged with additional corrections
126+
void addCorrections(
127+
o2::gpu::TPCFastSpaceChargeCorrection& mainCorrection, double scale,
128+
const std::vector<std::pair<const o2::gpu::TPCFastSpaceChargeCorrection*, double>>& additionalCorrections);
129+
130+
/// weighted add of several corrections with sector-dependent scaling factors
131+
/// \param mainCorrection main correction
132+
/// \param scale scaling factor for the main correction
133+
/// \param additionalCorrections vector of pairs of additional corrections and their sector-dependent scaling factors
134+
/// \return main correction merged with additional corrections
135+
void addCorrections(
136+
o2::gpu::TPCFastSpaceChargeCorrection& mainCorrection, SectorScales scale,
137+
const std::vector<std::pair<const o2::gpu::TPCFastSpaceChargeCorrection*, SectorScales>>& additionalCorrections);
138+
139+
/// merge of two corrections sector-wise
140+
/// \param destinationCorrection main correction to which the source correction will be added
141+
/// \param sourceCorrection correction to be added to the main correction
142+
/// \param sectors vector of sector indices for which the correction will be added
143+
/// \return main correction merged with the source correction
124144
void mergeCorrections(
125-
o2::gpu::TPCFastSpaceChargeCorrection& mainCorrection, float scale,
126-
const std::vector<std::pair<const o2::gpu::TPCFastSpaceChargeCorrection*, float>>& additionalCorrections, bool prn);
145+
o2::gpu::TPCFastSpaceChargeCorrection& destinationCorrection, const o2::gpu::TPCFastSpaceChargeCorrection& sourceCorrection,
146+
const std::vector<int>& sectors);
127147

128148
/// how far the voxel mean is allowed to be outside of the voxel (1.1 means 10%)
129149
void setVoxelMeanValidityRange(double range)

Detectors/TPC/calibration/src/TPCFastSpaceChargeCorrectionHelper.cxx

Lines changed: 115 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1030,19 +1030,38 @@ void TPCFastSpaceChargeCorrectionHelper::initInverse(std::vector<o2::gpu::TPCFas
10301030
LOGP(info, "Inverse tooks: {}s", duration);
10311031
}
10321032

1033-
void TPCFastSpaceChargeCorrectionHelper::mergeCorrections(
1034-
o2::gpu::TPCFastSpaceChargeCorrection& mainCorrection, float mainScale,
1035-
const std::vector<std::pair<const o2::gpu::TPCFastSpaceChargeCorrection*, float>>& additionalCorrections, bool /*prn*/)
1033+
void TPCFastSpaceChargeCorrectionHelper::addCorrections(
1034+
o2::gpu::TPCFastSpaceChargeCorrection& mainCorrection, double mainScale,
1035+
const std::vector<std::pair<const o2::gpu::TPCFastSpaceChargeCorrection*, double>>& additionalCorrections)
10361036
{
1037-
/// merge several corrections
1037+
/// weighted add of several corrections
1038+
SectorScales mainSectorScale;
1039+
mainSectorScale.fill(mainScale);
1040+
std::vector<std::pair<const o2::gpu::TPCFastSpaceChargeCorrection*, SectorScales>> additionalSectorScales;
1041+
for (const auto& corr : additionalCorrections) {
1042+
SectorScales sectorScale;
1043+
sectorScale.fill(corr.second);
1044+
additionalSectorScales.emplace_back(corr.first, sectorScale);
1045+
}
1046+
1047+
addCorrections(mainCorrection, mainSectorScale, additionalSectorScales);
1048+
}
1049+
1050+
void TPCFastSpaceChargeCorrectionHelper::addCorrections(
1051+
o2::gpu::TPCFastSpaceChargeCorrection& mainCorrection, SectorScales mainScale,
1052+
const std::vector<std::pair<const o2::gpu::TPCFastSpaceChargeCorrection*, SectorScales>>& additionalCorrections)
1053+
{
1054+
/// weighted add of several corrections
10381055

10391056
TStopwatch watch;
1040-
LOG(info) << "fast space charge correction helper: Merge corrections";
1057+
LOG(info) << "fast space charge correction helper: Add corrections";
10411058

10421059
const auto& geo = mainCorrection.getGeometry();
10431060

10441061
for (int sector = 0; sector < geo.getNumberOfSectors(); sector++) {
10451062

1063+
float secMainScale = mainScale[sector];
1064+
10461065
auto myThread = [&](int iThread) {
10471066
for (int row = iThread; row < geo.getNumberOfRows(); row += mNthreads) {
10481067
const auto& spline = mainCorrection.getSpline(sector, row);
@@ -1059,10 +1078,10 @@ void TPCFastSpaceChargeCorrectionHelper::mergeCorrections(
10591078

10601079
{ // scale the main correction
10611080
for (int i = 0; i < 3; i++) {
1062-
secRowInfo.maxCorr[i] *= mainScale;
1063-
secRowInfo.minCorr[i] *= mainScale;
1081+
secRowInfo.maxCorr[i] *= secMainScale;
1082+
secRowInfo.minCorr[i] *= secMainScale;
10641083
}
1065-
double parscale[4] = {mainScale, mainScale, mainScale, mainScale * mainScale};
1084+
double parscale[4] = {secMainScale, secMainScale, secMainScale, secMainScale * secMainScale};
10661085
for (int iknot = 0, ind = 0; iknot < spline.getNumberOfKnots(); iknot++) {
10671086
for (int ipar = 0; ipar < nKnotPar1d; ++ipar) {
10681087
for (int idim = 0; idim < 3; idim++, ind++) {
@@ -1093,7 +1112,7 @@ void TPCFastSpaceChargeCorrectionHelper::mergeCorrections(
10931112

10941113
for (int icorr = 0; icorr < additionalCorrections.size(); ++icorr) {
10951114
const auto& corr = *(additionalCorrections[icorr].first);
1096-
double scale = additionalCorrections[icorr].second;
1115+
double scale = additionalCorrections[icorr].second[sector];
10971116
auto& linfo = corr.getSectorRowInfo(sector, row);
10981117
secRowInfo.updateMaxValues(linfo.getMaxValues(), scale);
10991118
secRowInfo.updateMaxValues(linfo.getMinValues(), scale);
@@ -1169,7 +1188,93 @@ void TPCFastSpaceChargeCorrectionHelper::mergeCorrections(
11691188
}
11701189

11711190
} // sector
1172-
float duration = watch.RealTime();
1191+
double duration = watch.RealTime();
1192+
LOGP(info, "Merge of corrections tooks: {}s", duration);
1193+
}
1194+
1195+
void TPCFastSpaceChargeCorrectionHelper::mergeCorrections(o2::gpu::TPCFastSpaceChargeCorrection& destinationCorrection,
1196+
const o2::gpu::TPCFastSpaceChargeCorrection& sourceCorrection,
1197+
const std::vector<int>& sectors)
1198+
{
1199+
/// merge of two corrections sector-wise
1200+
TStopwatch watch;
1201+
LOG(info) << "fast space charge correction helper: Merge corrections";
1202+
1203+
const auto& geo = destinationCorrection.getGeometry();
1204+
1205+
for (int sector : sectors) {
1206+
if (sector < 0 || sector >= geo.getNumberOfSectors()) {
1207+
LOGP(fatal, "Invalid sector number {}. Valid range is [0, {})", sector, geo.getNumberOfSectors());
1208+
continue;
1209+
}
1210+
auto myThread = [&](int iThread) {
1211+
for (int row = iThread; row < geo.getNumberOfRows(); row += mNthreads) {
1212+
1213+
{ // replace the direct correction
1214+
const auto& destSpline = destinationCorrection.getSpline(sector, row);
1215+
float* destSplineParameters = destinationCorrection.getCorrectionData(sector, row);
1216+
const auto& sourceSpline = sourceCorrection.getSpline(sector, row);
1217+
const float* sourceSplineParameters = sourceCorrection.getCorrectionData(sector, row);
1218+
1219+
// ensure the splines are compatible
1220+
if (destSpline.getGridX1().getNumberOfKnots() != sourceSpline.getGridX1().getNumberOfKnots() ||
1221+
destSpline.getGridX2().getNumberOfKnots() != sourceSpline.getGridX2().getNumberOfKnots()) {
1222+
LOGP(error, "Splines for sector {} row {} are not compatible: number of knots in U or V direction do not match", sector, row);
1223+
continue;
1224+
}
1225+
// replace the destination correction with the source correction for this sector and row
1226+
memcpy(destSplineParameters, sourceSplineParameters, destSpline.getNumberOfParameters() * sizeof(float));
1227+
}
1228+
1229+
{ // replace the inverse correction X
1230+
const auto& destSpline = destinationCorrection.getSplineInvX(sector, row);
1231+
float* destSplineParameters = destinationCorrection.getCorrectionDataInvX(sector, row);
1232+
const auto& sourceSpline = sourceCorrection.getSplineInvX(sector, row);
1233+
const float* sourceSplineParameters = sourceCorrection.getCorrectionDataInvX(sector, row);
1234+
// ensure the splines are compatible
1235+
if (destSpline.getGridX1().getNumberOfKnots() != sourceSpline.getGridX1().getNumberOfKnots() ||
1236+
destSpline.getGridX2().getNumberOfKnots() != sourceSpline.getGridX2().getNumberOfKnots()) {
1237+
LOGP(error, "Inverse X splines for sector {} row {} are not compatible: number of knots in U or V direction do not match", sector, row);
1238+
continue;
1239+
}
1240+
memcpy(destSplineParameters, sourceSplineParameters, destSpline.getNumberOfParameters() * sizeof(float));
1241+
}
1242+
1243+
{ // replace the inverse correction YZ
1244+
const auto& destSpline = destinationCorrection.getSplineInvYZ(sector, row);
1245+
float* destSplineParameters = destinationCorrection.getCorrectionDataInvYZ(sector, row);
1246+
const auto& sourceSpline = sourceCorrection.getSplineInvYZ(sector, row);
1247+
const float* sourceSplineParameters = sourceCorrection.getCorrectionDataInvYZ(sector, row);
1248+
// ensure the splines are compatible
1249+
if (destSpline.getGridX1().getNumberOfKnots() != sourceSpline.getGridX1().getNumberOfKnots() ||
1250+
destSpline.getGridX2().getNumberOfKnots() != sourceSpline.getGridX2().getNumberOfKnots()) {
1251+
LOGP(error, "Inverse YZ splines for sector {} row {} are not compatible: number of knots in U or V direction do not match", sector, row);
1252+
continue;
1253+
}
1254+
memcpy(destSplineParameters, sourceSplineParameters, destSpline.getNumberOfParameters() * sizeof(float));
1255+
}
1256+
1257+
// replace the sector row info
1258+
auto& destSecRowInfo = destinationCorrection.getSectorRowInfo(sector, row);
1259+
const auto& sourceSecRowInfo = sourceCorrection.getSectorRowInfo(sector, row);
1260+
destSecRowInfo = sourceSecRowInfo;
1261+
} // row
1262+
}; // thread
1263+
1264+
std::vector<std::thread> threads(mNthreads);
1265+
1266+
// run n threads
1267+
for (int i = 0; i < mNthreads; i++) {
1268+
threads[i] = std::thread(myThread, i);
1269+
}
1270+
1271+
// wait for the threads to finish
1272+
for (auto& th : threads) {
1273+
th.join();
1274+
}
1275+
1276+
} // sector
1277+
double duration = watch.RealTime();
11731278
LOGP(info, "Merge of corrections tooks: {}s", duration);
11741279
}
11751280

0 commit comments

Comments
 (0)