@@ -87,31 +87,71 @@ class BatchDataProcessor(connectionConfigsByName: Map[String, Map[String, String
8787 val recordStepName = s " ${task._2.name}_ ${s.name}"
8888 val stepRecords = trackRecordsPerStep(recordStepName)
8989 val startIndex = stepRecords.currentNumRecords
90- val endIndex = stepRecords.currentNumRecords + stepRecords.numRecordsPerBatch
90+
91+ // Calculate precise number of records for this batch to ensure exact total
92+ val adjustedTotalRecords = stepRecords.numTotalRecords / stepRecords.averagePerCol
93+ val remainingAdjustedRecords = adjustedTotalRecords - (stepRecords.currentNumRecords / stepRecords.averagePerCol)
94+
95+ val recordsToGenerate = if (remainingAdjustedRecords <= 0 ) {
96+ 0L
97+ } else if (stepRecords.remainder > 0 && batch <= stepRecords.remainder) {
98+ // First 'remainder' batches get base + 1 records
99+ Math .min(stepRecords.baseRecordsPerBatch + 1 , remainingAdjustedRecords)
100+ } else {
101+ // Remaining batches get base records
102+ Math .min(stepRecords.baseRecordsPerBatch, remainingAdjustedRecords)
103+ }
104+
105+ // Convert back to actual records (multiply by averagePerCol)
106+ val actualRecordsToGenerate = recordsToGenerate * stepRecords.averagePerCol
107+ val endIndex = startIndex + actualRecordsToGenerate
108+
109+ LOGGER .debug(s " Batch $batch: startIndex= $startIndex, endIndex= $endIndex, recordsToGenerate= $recordsToGenerate, " +
110+ s " actualRecordsToGenerate= $actualRecordsToGenerate, remainingAdjustedRecords= $remainingAdjustedRecords" )
91111
92112 val genDf = dataGeneratorFactory.generateDataForStep(s, task._1.dataSourceName, startIndex, endIndex)
93113 val initialDf = getUniqueGeneratedRecords(uniqueFieldUtil, dataSourceStepName, genDf, s)
94114 if (! initialDf.storageLevel.useMemory) initialDf.cache()
95115 genDf.unpersist()
96116
97- val initialRecordCount = if (flagsConfig.enableCount) initialDf.count() else stepRecords.numRecordsPerBatch
98- val targetNumRecords = stepRecords.numRecordsPerBatch * s.count.perField.map(_.averageCountPerField).getOrElse( 1L )
117+ val initialRecordCount = if (flagsConfig.enableCount) initialDf.count() else actualRecordsToGenerate
118+ val targetNumRecords = actualRecordsToGenerate
99119
100120 LOGGER .debug(s " Step record count for batch, batch= $batch, step-name= ${s.name}, " +
101- s " target-num-records= $targetNumRecords, actual-num-records= $initialRecordCount" )
121+ s " target-num-records= $targetNumRecords, actual-num-records= $initialRecordCount, records-to-generate= $recordsToGenerate " )
102122
103123 // if record count doesn't match expected record count, generate more data
104124 def generateAdditionalRecords (currentDf : DataFrame , currentRecordCount : Long ): (DataFrame , Long ) = {
125+ LOGGER .debug(s " Generating additional records for batch, batch= $batch, step-name= ${s.name}, " +
126+ s " current-record-count= $currentRecordCount, target-num-records= $targetNumRecords" )
127+
128+ if (currentRecordCount >= targetNumRecords) {
129+ LOGGER .debug(s " No additional records needed, current count meets or exceeds target " )
130+ return (currentDf, currentRecordCount)
131+ }
132+
105133 val additionalGenDf = dataGeneratorFactory
106134 .generateDataForStep(s, task._1.dataSourceName, stepRecords.currentNumRecords + currentRecordCount, endIndex)
107135 val additionalDf = getUniqueGeneratedRecords(uniqueFieldUtil, dataSourceStepName, additionalGenDf, s)
108136 if (! additionalDf.storageLevel.useMemory) additionalDf.cache()
109137 additionalGenDf.unpersist()
110- val newDf = currentDf.unionByName(additionalDf, true )
111- val newRecordCount = newDf.count()
138+ val additionalRecordCount = if (flagsConfig.enableCount) additionalDf.count() else 0
139+ LOGGER .debug(s " Additional records generated, additional-record-count= $additionalRecordCount" )
140+
141+ // Only union if we actually generated additional records
142+ val (newDf, newRecordCount) = if (additionalRecordCount > 0 ) {
143+ val unionDf = currentDf.union(additionalDf)
144+ val finalCount = unionDf.count()
145+ additionalDf.unpersist()
146+ (unionDf, finalCount)
147+ } else {
148+ // No additional records were generated, return current DataFrame as-is
149+ additionalDf.unpersist()
150+ (currentDf, currentRecordCount)
151+ }
152+
112153 LOGGER .debug(s " Generated more records for step, batch= $batch, step-name= ${s.name}, " +
113- s " new-num-records= ${additionalDf.count()}, actual-num-records= $newRecordCount" )
114- additionalDf.unpersist()
154+ s " new-num-records= $additionalRecordCount, actual-num-records= $newRecordCount, current-df-count= ${currentDf.count()}" )
115155 (newDf, newRecordCount)
116156 }
117157
@@ -142,7 +182,7 @@ class BatchDataProcessor(connectionConfigsByName: Map[String, Map[String, String
142182 s " target-num-records= $targetNumRecords, actual-num-records= $finalRecordCount" )
143183 }
144184
145- trackRecordsPerStep = trackRecordsPerStep ++ Map (recordStepName -> stepRecords.copy(currentNumRecords = finalRecordCount + stepRecords.currentNumRecords))
185+ trackRecordsPerStep = trackRecordsPerStep ++ Map (recordStepName -> stepRecords.copy(currentNumRecords = stepRecords.currentNumRecords + finalRecordCount ))
146186 (dataSourceStepName, finalDf)
147187 } else {
148188 LOGGER .debug(s " Step has both data generation and reference mode disabled, data-source= ${task._1.dataSourceName}, step-name= ${s.name}" )
0 commit comments