@@ -1030,19 +1030,38 @@ void TPCFastSpaceChargeCorrectionHelper::initInverse(std::vector<o2::gpu::TPCFas
10301030 LOGP (info, " Inverse tooks: {}s" , duration);
10311031}
10321032
1033- void TPCFastSpaceChargeCorrectionHelper::mergeCorrections (
1034- o2::gpu::TPCFastSpaceChargeCorrection& mainCorrection, float mainScale,
1035- const std::vector<std::pair<const o2::gpu::TPCFastSpaceChargeCorrection*, float >>& additionalCorrections, bool /* prn */ )
1033+ void TPCFastSpaceChargeCorrectionHelper::addCorrections (
1034+ o2::gpu::TPCFastSpaceChargeCorrection& mainCorrection, double mainScale,
1035+ const std::vector<std::pair<const o2::gpu::TPCFastSpaceChargeCorrection*, double >>& additionalCorrections)
10361036{
1037- // / merge several corrections
1037+ // / weighted add of several corrections
1038+ SectorScales mainSectorScale;
1039+ mainSectorScale.fill (mainScale);
1040+ std::vector<std::pair<const o2::gpu::TPCFastSpaceChargeCorrection*, SectorScales>> additionalSectorScales;
1041+ for (const auto & corr : additionalCorrections) {
1042+ SectorScales sectorScale;
1043+ sectorScale.fill (corr.second );
1044+ additionalSectorScales.emplace_back (corr.first , sectorScale);
1045+ }
1046+
1047+ addCorrections (mainCorrection, mainSectorScale, additionalSectorScales);
1048+ }
1049+
/// Weighted add of several corrections with one scale per sector.
/// mainScale[sector] scales the main correction; additionalCorrections[i].second[sector]
/// is the weight of the i-th additional correction in that sector.
/// NOTE(review): only parts of the body are visible here; the comments below cover the visible lines only.
1050+ void TPCFastSpaceChargeCorrectionHelper::addCorrections (
1051+ o2::gpu::TPCFastSpaceChargeCorrection& mainCorrection, SectorScales mainScale,
1052+ const std::vector<std::pair<const o2::gpu::TPCFastSpaceChargeCorrection*, SectorScales>>& additionalCorrections)
1053+ {
1054+ // / weighted add of several corrections
10381055
10391056 TStopwatch watch;
1040- LOG (info) << " fast space charge correction helper: Merge corrections" ;
1057+ LOG (info) << " fast space charge correction helper: Add corrections" ;
10411058
10421059 const auto & geo = mainCorrection.getGeometry ();
10431060
10441061 for (int sector = 0 ; sector < geo.getNumberOfSectors (); sector++) {
10451062
// scale of the main correction for the current sector
// NOTE(review): the value is stored in a float but later used to fill a double array (parscale);
// if SectorScales holds doubles this narrows the scale - confirm the element type.
1063+ float secMainScale = mainScale[sector];
1064+
// worker: thread iThread processes rows iThread, iThread + mNthreads, iThread + 2 * mNthreads, ...
10461065 auto myThread = [&](int iThread) {
10471066 for (int row = iThread; row < geo.getNumberOfRows (); row += mNthreads ) {
10481067 const auto & spline = mainCorrection.getSpline (sector, row);
@@ -1059,10 +1078,10 @@ void TPCFastSpaceChargeCorrectionHelper::mergeCorrections(
10591078
10601079 { // scale the main correction
10611080 for (int i = 0 ; i < 3 ; i++) {
1062- secRowInfo.maxCorr [i] *= mainScale ;
1063- secRowInfo.minCorr [i] *= mainScale ;
1081+ secRowInfo.maxCorr [i] *= secMainScale ;
1082+ secRowInfo.minCorr [i] *= secMainScale ;
10641083 }
// scale factors: the first three entries are the scale itself, the fourth is the scale squared
1065- double parscale[4 ] = {mainScale, mainScale, mainScale, mainScale * mainScale };
1084+ double parscale[4 ] = {secMainScale, secMainScale, secMainScale, secMainScale * secMainScale };
10661085 for (int iknot = 0 , ind = 0 ; iknot < spline.getNumberOfKnots (); iknot++) {
10671086 for (int ipar = 0 ; ipar < nKnotPar1d; ++ipar) {
10681087 for (int idim = 0 ; idim < 3 ; idim++, ind++) {
@@ -1093,7 +1112,7 @@ void TPCFastSpaceChargeCorrectionHelper::mergeCorrections(
10931112
10941113 for (int icorr = 0 ; icorr < additionalCorrections.size (); ++icorr) {
10951114 const auto & corr = *(additionalCorrections[icorr].first );
// weight of this additional correction in the current sector
1096- double scale = additionalCorrections[icorr].second ;
1115+ double scale = additionalCorrections[icorr].second [sector] ;
10971116 auto & linfo = corr.getSectorRowInfo (sector, row);
// NOTE(review): both the max and the min values of the additional correction are passed to
// updateMaxValues - presumably it tracks both bounds; verify against its definition.
10981117 secRowInfo.updateMaxValues (linfo.getMaxValues (), scale);
10991118 secRowInfo.updateMaxValues (linfo.getMinValues (), scale);
@@ -1169,7 +1188,93 @@ void TPCFastSpaceChargeCorrectionHelper::mergeCorrections(
11691188 }
11701189
11711190 } // sector
1172- float duration = watch.RealTime ();
1191+ double duration = watch.RealTime ();
// NOTE(review): the message still says "Merge" although this function adds corrections - confirm the wording.
1192+ LOGP (info, " Merge of corrections tooks: {}s" , duration);
1193+ }
1194+
1195+ void TPCFastSpaceChargeCorrectionHelper::mergeCorrections (o2::gpu::TPCFastSpaceChargeCorrection& destinationCorrection,
1196+ const o2::gpu::TPCFastSpaceChargeCorrection& sourceCorrection,
1197+ const std::vector<int >& sectors)
1198+ {
1199+ // / merge of two corrections sector-wise
1200+ TStopwatch watch;
1201+ LOG (info) << " fast space charge correction helper: Merge corrections" ;
1202+
1203+ const auto & geo = destinationCorrection.getGeometry ();
1204+
1205+ for (int sector : sectors) {
1206+ if (sector < 0 || sector >= geo.getNumberOfSectors ()) {
1207+ LOGP (fatal, " Invalid sector number {}. Valid range is [0, {})" , sector, geo.getNumberOfSectors ());
1208+ continue ;
1209+ }
1210+ auto myThread = [&](int iThread) {
1211+ for (int row = iThread; row < geo.getNumberOfRows (); row += mNthreads ) {
1212+
1213+ { // replace the direct correction
1214+ const auto & destSpline = destinationCorrection.getSpline (sector, row);
1215+ float * destSplineParameters = destinationCorrection.getCorrectionData (sector, row);
1216+ const auto & sourceSpline = sourceCorrection.getSpline (sector, row);
1217+ const float * sourceSplineParameters = sourceCorrection.getCorrectionData (sector, row);
1218+
1219+ // ensure the splines are compatible
1220+ if (destSpline.getGridX1 ().getNumberOfKnots () != sourceSpline.getGridX1 ().getNumberOfKnots () ||
1221+ destSpline.getGridX2 ().getNumberOfKnots () != sourceSpline.getGridX2 ().getNumberOfKnots ()) {
1222+ LOGP (error, " Splines for sector {} row {} are not compatible: number of knots in U or V direction do not match" , sector, row);
1223+ continue ;
1224+ }
1225+ // replace the destination correction with the source correction for this sector and row
1226+ memcpy (destSplineParameters, sourceSplineParameters, destSpline.getNumberOfParameters () * sizeof (float ));
1227+ }
1228+
1229+ { // replace the inverse correction X
1230+ const auto & destSpline = destinationCorrection.getSplineInvX (sector, row);
1231+ float * destSplineParameters = destinationCorrection.getCorrectionDataInvX (sector, row);
1232+ const auto & sourceSpline = sourceCorrection.getSplineInvX (sector, row);
1233+ const float * sourceSplineParameters = sourceCorrection.getCorrectionDataInvX (sector, row);
1234+ // ensure the splines are compatible
1235+ if (destSpline.getGridX1 ().getNumberOfKnots () != sourceSpline.getGridX1 ().getNumberOfKnots () ||
1236+ destSpline.getGridX2 ().getNumberOfKnots () != sourceSpline.getGridX2 ().getNumberOfKnots ()) {
1237+ LOGP (error, " Inverse X splines for sector {} row {} are not compatible: number of knots in U or V direction do not match" , sector, row);
1238+ continue ;
1239+ }
1240+ memcpy (destSplineParameters, sourceSplineParameters, destSpline.getNumberOfParameters () * sizeof (float ));
1241+ }
1242+
1243+ { // replace the inverse correction YZ
1244+ const auto & destSpline = destinationCorrection.getSplineInvYZ (sector, row);
1245+ float * destSplineParameters = destinationCorrection.getCorrectionDataInvYZ (sector, row);
1246+ const auto & sourceSpline = sourceCorrection.getSplineInvYZ (sector, row);
1247+ const float * sourceSplineParameters = sourceCorrection.getCorrectionDataInvYZ (sector, row);
1248+ // ensure the splines are compatible
1249+ if (destSpline.getGridX1 ().getNumberOfKnots () != sourceSpline.getGridX1 ().getNumberOfKnots () ||
1250+ destSpline.getGridX2 ().getNumberOfKnots () != sourceSpline.getGridX2 ().getNumberOfKnots ()) {
1251+ LOGP (error, " Inverse YZ splines for sector {} row {} are not compatible: number of knots in U or V direction do not match" , sector, row);
1252+ continue ;
1253+ }
1254+ memcpy (destSplineParameters, sourceSplineParameters, destSpline.getNumberOfParameters () * sizeof (float ));
1255+ }
1256+
1257+ // replace the sector row info
1258+ auto & destSecRowInfo = destinationCorrection.getSectorRowInfo (sector, row);
1259+ const auto & sourceSecRowInfo = sourceCorrection.getSectorRowInfo (sector, row);
1260+ destSecRowInfo = sourceSecRowInfo;
1261+ } // row
1262+ }; // thread
1263+
1264+ std::vector<std::thread> threads (mNthreads );
1265+
1266+ // run n threads
1267+ for (int i = 0 ; i < mNthreads ; i++) {
1268+ threads[i] = std::thread (myThread, i);
1269+ }
1270+
1271+ // wait for the threads to finish
1272+ for (auto & th : threads) {
1273+ th.join ();
1274+ }
1275+
1276+ } // sector
1277+ double duration = watch.RealTime ();
11731278 LOGP (info, " Merge of corrections tooks: {}s" , duration);
11741279}
11751280
0 commit comments