diff --git a/src/main/java/net/snowflake/ingest/streaming/internal/FlushService.java b/src/main/java/net/snowflake/ingest/streaming/internal/FlushService.java
index f0f87e889..5de3d82a3 100644
--- a/src/main/java/net/snowflake/ingest/streaming/internal/FlushService.java
+++ b/src/main/java/net/snowflake/ingest/streaming/internal/FlushService.java
@@ -12,12 +12,14 @@
 import com.codahale.metrics.Timer;
 import com.google.common.annotations.VisibleForTesting;
+
 import java.io.IOException;
 import java.lang.management.ManagementFactory;
 import java.security.InvalidAlgorithmParameterException;
 import java.security.InvalidKeyException;
 import java.security.NoSuchAlgorithmException;
 import java.util.ArrayList;
+import java.util.Collection;
 import java.util.Collections;
 import java.util.Iterator;
 import java.util.List;
 import java.util.Map;
@@ -38,6 +40,7 @@
 import javax.crypto.BadPaddingException;
 import javax.crypto.IllegalBlockSizeException;
 import javax.crypto.NoSuchPaddingException;
+
 import net.snowflake.client.jdbc.internal.google.common.util.concurrent.ThreadFactoryBuilder;
 import net.snowflake.ingest.utils.Constants;
 import net.snowflake.ingest.utils.ErrorCode;
@@ -83,13 +86,16 @@ List<List<ChannelData<T>>> getData() {
   private final SnowflakeStreamingIngestClientInternal<T> owningClient;
 
   // Thread to schedule the flush job
-  @VisibleForTesting ScheduledExecutorService flushWorker;
+  @VisibleForTesting
+  ScheduledExecutorService flushWorker;
 
   // Thread to register the blob
-  @VisibleForTesting ExecutorService registerWorker;
+  @VisibleForTesting
+  ExecutorService registerWorker;
 
   // Threads to build and upload the blob
-  @VisibleForTesting ExecutorService buildUploadWorkers;
+  @VisibleForTesting
+  ExecutorService buildUploadWorkers;
 
   // Reference to the channel cache
   private final ChannelCache<T> channelCache;
@@ -105,9 +111,11 @@ List<List<ChannelData<T>>> getData() {
    * blob is not 1. When max chunk in blob is 1, flush service ignores these variables and uses
    * table level last flush time and need flush flag. See {@link ChannelCache.FlushInfo}.
    */
-  @VisibleForTesting volatile long lastFlushTime;
+  @VisibleForTesting
+  volatile long lastFlushTime;
 
-  @VisibleForTesting volatile boolean isNeedFlush;
+  @VisibleForTesting
+  volatile boolean isNeedFlush;
 
   // Indicates whether it's running as part of the test
   private final boolean isTestMode;
@@ -122,10 +130,10 @@ List<List<ChannelData<T>>> getData() {
   /**
    * Default constructor
    *
-   * @param client the owning client
-   * @param cache the channel cache
+   * @param client         the owning client
+   * @param cache          the channel cache
    * @param storageManager the storage manager
-   * @param isTestMode whether the service is running in test mode
+   * @param isTestMode     whether the service is running in test mode
    */
   FlushService(
       SnowflakeStreamingIngestClientInternal<T> client,
@@ -164,8 +172,8 @@ private CompletableFuture<Void> statsFuture() {
   }
 
   /**
-   * @param isForce if true will flush regardless of other conditions
-   * @param tablesToFlush list of tables to flush
+   * @param isForce        if true will flush regardless of other conditions
+   * @param tablesToFlush  list of tables to flush
    * @param flushStartTime the time when the flush started
    * @return
    */
@@ -187,7 +195,9 @@ private CompletableFuture<Void> distributeFlush(
         this.flushWorker);
   }
 
-  /** If tracing is enabled, print always else, check if it needs flush or is forceful. */
+  /**
+   * If tracing is enabled, always log; otherwise log only when a flush is needed or forced.
+   */
   private void logFlushTask(boolean isForce, Set<String> tablesToFlush, long flushStartTime) {
     boolean isNeedFlush =
         this.owningClient.getParameterProvider().getMaxChunksInBlob() == 1
@@ -262,7 +272,7 @@ private CompletableFuture<Void> registerFuture() {
    *
    * @param isForce
    * @return Completable future that will return when the blobs are registered successfully, or null
-   *     if none of the conditions is met above
+   *         if none of the conditions is met above
    */
   CompletableFuture<Void> flush(boolean isForce) {
     final long flushStartTime = System.currentTimeMillis();
@@ -277,14 +287,14 @@ CompletableFuture<Void> flush(boolean isForce) {
                   key ->
                       isForce
                           || flushStartTime - this.channelCache.getLastFlushTime(key)
-                              >= flushingInterval
+                      >= flushingInterval
                           || this.channelCache.getNeedFlush(key))
               .collect(Collectors.toSet());
     } else {
       if (isForce
           || (!DISABLE_BACKGROUND_FLUSH
-              && !isTestMode()
-              && (this.isNeedFlush || flushStartTime - this.lastFlushTime >= flushingInterval))) {
+          && !isTestMode()
+          && (this.isNeedFlush || flushStartTime - this.lastFlushTime >= flushingInterval))) {
         tablesToFlush = this.channelCache.keySet();
       } else {
         tablesToFlush = null;
@@ -293,9 +303,9 @@ CompletableFuture<Void> flush(boolean isForce) {
     if (isForce
         || (!DISABLE_BACKGROUND_FLUSH
-            && !isTestMode()
-            && tablesToFlush != null
-            && !tablesToFlush.isEmpty())) {
+        && !isTestMode()
+        && tablesToFlush != null
+        && !tablesToFlush.isEmpty())) {
       return this.statsFuture()
           .thenCompose((v) -> this.distributeFlush(isForce, tablesToFlush, flushStartTime))
          .thenCompose((v) -> this.registerFuture());
    }
@@ -303,7 +313,9 @@ CompletableFuture<Void> flush(boolean isForce) {
     return this.statsFuture();
   }
 
-  /** Create the workers for each specific job */
+  /**
+   * Create the workers for each specific job
+   */
   private void createWorkers() {
     // Create thread for checking and scheduling flush job
     ThreadFactory flushThreadFactory =
@@ -377,6 +389,14 @@ private void createWorkers() {
             Runtime.getRuntime().availableProcessors());
   }
 
+  Map<String, Collection<SnowflakeStreamingIngestChannelInternal<T>>> getChannelsToFlush(Set<String> tablesToFlush) {
+    return this.channelCache.entrySet().stream()
+        .filter(e -> tablesToFlush.contains(e.getKey()))
+        .collect(Collectors.toMap(
+            Map.Entry::getKey,
+            e -> e.getValue().values()));
+  }
+
   /**
    * Distribute the flush tasks by iterating through all the channels in the channel cache and kick
    * off a build blob work when certain size has reached or we have reached the end
@@ -384,181 +404,158 @@ private void createWorkers() {
    * @param tablesToFlush list of tables to flush
    */
   void distributeFlushTasks(Set<String> tablesToFlush) {
-    Iterator<
-            Map.Entry<
-                String, ConcurrentHashMap<String, SnowflakeStreamingIngestChannelInternal<T>>>>
-        itr =
-            this.channelCache.entrySet().stream()
-                .filter(e -> tablesToFlush.contains(e.getKey()))
-                .iterator();
-    List<Pair<BlobData<T>, CompletableFuture<BlobMetadata>>> blobs = new ArrayList<>();
-    List<ChannelData<T>> leftoverChannelsDataPerTable = new ArrayList<>();
+    Map<String, Collection<SnowflakeStreamingIngestChannelInternal<T>>> channelsToFlushByFqTableName =
+        getChannelsToFlush(tablesToFlush);
+    Iterator<Map.Entry<String, Collection<SnowflakeStreamingIngestChannelInternal<T>>>> itr =
+        channelsToFlushByFqTableName.entrySet().iterator();
 
     // The API states that the number of available processors reported can change and therefore, we
     // should poll it occasionally.
     numProcessors = Runtime.getRuntime().availableProcessors();
-    while (itr.hasNext() || !leftoverChannelsDataPerTable.isEmpty()) {
-      List<List<ChannelData<T>>> blobData = new ArrayList<>();
-      float totalBufferSizeInBytes = 0F;
-
-      // Distribute work at table level, split the blob if reaching the blob size limit or the
-      // channel has different encryption key ids
-      while (itr.hasNext() || !leftoverChannelsDataPerTable.isEmpty()) {
-        List<ChannelData<T>> channelsDataPerTable = Collections.synchronizedList(new ArrayList<>());
-        if (!leftoverChannelsDataPerTable.isEmpty()) {
-          channelsDataPerTable.addAll(leftoverChannelsDataPerTable);
-          leftoverChannelsDataPerTable.clear();
-        } else if (blobData.size()
-            >= this.owningClient.getParameterProvider().getMaxChunksInBlob()) {
-          // Create a new blob if the current one already contains max allowed number of chunks
-          logger.logInfo(
-              "Max allowed number of chunks in the current blob reached. chunkCount={}"
-                  + " maxChunkCount={}",
-              blobData.size(),
-              this.owningClient.getParameterProvider().getMaxChunksInBlob());
-          break;
-        } else {
-          ConcurrentHashMap<String, SnowflakeStreamingIngestChannelInternal<T>> table =
-              itr.next().getValue();
-          // Use parallel stream since getData could be the performance bottleneck when we have a
-          // high number of channels
-          table.values().parallelStream()
-              .forEach(
-                  channel -> {
-                    if (channel.isValid()) {
-                      ChannelData<T> data = channel.getData();
-                      if (data != null) {
-                        channelsDataPerTable.add(data);
-                      }
-                    }
-                  });
-        }
-
-        if (!channelsDataPerTable.isEmpty()) {
-          int idx = 0;
-          float totalBufferSizePerTableInBytes = 0F;
-          while (idx < channelsDataPerTable.size()) {
-            ChannelData<T> channelData = channelsDataPerTable.get(idx);
-            // Stop processing the rest of channels when needed
-            if (idx > 0
-                && shouldStopProcessing(
-                    totalBufferSizeInBytes,
-                    totalBufferSizePerTableInBytes,
-                    channelData,
-                    channelsDataPerTable.get(idx - 1))) {
-              leftoverChannelsDataPerTable.addAll(
-                  channelsDataPerTable.subList(idx, channelsDataPerTable.size()));
-              logger.logInfo(
-                  "Creation of another blob is needed because of blob/chunk size limit or"
-                      + " different encryption ids or different schema, client={}, table={},"
-                      + " blobSize={}, chunkSize={}, nextChannelSize={}, encryptionId1={},"
-                      + " encryptionId2={}, schema1={}, schema2={}",
-                  this.owningClient.getName(),
-                  channelData.getChannelContext().getTableName(),
-                  totalBufferSizeInBytes,
-                  totalBufferSizePerTableInBytes,
-                  channelData.getBufferSize(),
-                  channelData.getChannelContext().getEncryptionKeyId(),
-                  channelsDataPerTable.get(idx - 1).getChannelContext().getEncryptionKeyId(),
-                  channelData.getColumnEps().keySet(),
-                  channelsDataPerTable.get(idx - 1).getColumnEps().keySet());
-              break;
-            }
-            totalBufferSizeInBytes += channelData.getBufferSize();
-            totalBufferSizePerTableInBytes += channelData.getBufferSize();
-            idx++;
-          }
-          // Add processed channels to the current blob, stop if we need to create a new blob
-          blobData.add(channelsDataPerTable.subList(0, idx));
-          if (idx != channelsDataPerTable.size()) {
-            break;
-          }
-        }
-      }
-
-      if (blobData.isEmpty()) {
-        continue;
-      }
-
-      // Kick off a build job
-
-      // Get the fully qualified table name from the first channel in the blob.
-      // This only matters when the client is in Iceberg mode. In Iceberg mode,
-      // all channels in the blob belong to the same table.
-      String fullyQualifiedTableName =
-          blobData.get(0).get(0).getChannelContext().getFullyQualifiedTableName();
-
-      final BlobPath blobPath = this.storageManager.generateBlobPath(fullyQualifiedTableName);
-
-      long flushStartMs = System.currentTimeMillis();
-      if (this.owningClient.flushLatency != null) {
-        latencyTimerContextMap.putIfAbsent(
-            blobPath.fileRegistrationPath, this.owningClient.flushLatency.time());
-      }
-
-      // Copy encryptionKeysPerTable from owning client
-      Map<FullyQualifiedTableName, EncryptionKey> encryptionKeysPerTable =
-          new ConcurrentHashMap<>();
-      this.owningClient
-          .getEncryptionKeysPerTable()
-          .forEach((k, v) -> encryptionKeysPerTable.put(k, new EncryptionKey(v)));
-
-      Supplier<BlobMetadata> supplier =
-          () -> {
-            try {
-              BlobMetadata blobMetadata =
-                  buildAndUpload(
-                      blobPath, blobData, fullyQualifiedTableName, encryptionKeysPerTable);
-              blobMetadata.getBlobStats().setFlushStartMs(flushStartMs);
-              return blobMetadata;
-            } catch (Throwable e) {
-              Throwable ex = e.getCause() == null ? e : e.getCause();
-              String errorMessage =
-                  String.format(
-                      "Building blob failed, client=%s, blob=%s, exception=%s,"
-                          + " detail=%s, trace=%s, all channels in the blob will be"
-                          + " invalidated",
-                      this.owningClient.getName(),
-                      blobPath.fileRegistrationPath,
-                      ex,
-                      ex.getMessage(),
-                      getStackTrace(ex));
-              logger.logError(errorMessage);
-              if (this.owningClient.getTelemetryService() != null) {
-                this.owningClient
-                    .getTelemetryService()
-                    .reportClientFailure(this.getClass().getSimpleName(), errorMessage);
-              }
-
-              if (e instanceof IOException) {
-                invalidateAllChannelsInBlob(blobData, errorMessage);
-                return null;
-              } else if (e instanceof NoSuchAlgorithmException) {
-                throw new SFException(e, ErrorCode.MD5_HASHING_NOT_AVAILABLE);
-              } else if (e instanceof InvalidAlgorithmParameterException
-                  | e instanceof NoSuchPaddingException
-                  | e instanceof IllegalBlockSizeException
-                  | e instanceof BadPaddingException
-                  | e instanceof InvalidKeyException) {
-                throw new SFException(e, ErrorCode.ENCRYPTION_FAILURE);
-              } else {
-                throw new SFException(e, ErrorCode.INTERNAL_ERROR, e.getMessage());
-              }
-            }
-          };
-
-      blobs.add(
-          new Pair<>(
-              new BlobData<>(blobPath.fileRegistrationPath, blobData),
-              CompletableFuture.supplyAsync(supplier, this.buildUploadWorkers)));
-
-      logger.logInfo(
-          "buildAndUpload task added for client={}, blob={}, buildUploadWorkers stats={}",
-          this.owningClient.getName(),
-          blobPath,
-          this.buildUploadWorkers.toString());
-    }
-
+    while (itr.hasNext()) {
+      Map.Entry<String, Collection<SnowflakeStreamingIngestChannelInternal<T>>> tableDataToFlush =
+          itr.next();
+      String tableName = tableDataToFlush.getKey();
+      Collection<SnowflakeStreamingIngestChannelInternal<T>> tableChannels =
+          tableDataToFlush.getValue();
+
+      // Use parallel stream since getData could be the performance bottleneck when we have a
+      // high number of channels
+      List<ChannelData<T>> channelsDataPerTable = Collections.synchronizedList(new ArrayList<>());
+      tableChannels.parallelStream()
+          .forEach(
+              channel -> {
+                if (channel.isValid()) {
+                  ChannelData<T> data = channel.getData();
+                  if (data != null) {
+                    channelsDataPerTable.add(data);
+                  }
+                }
+              });
+
+      List<List<ChannelData<T>>> chunkData = flushChunkData(tableName, channelsDataPerTable);
+      if (chunkData == null || chunkData.isEmpty()) {
+        continue;
+      }
+      uploadChunkData(tableName, chunkData);
+    }
+  }
+
+  private List<List<ChannelData<T>>> flushChunkData(
+      String fullyQualifiedTableName, List<ChannelData<T>> channelsDataPerTable) {
+    if (channelsDataPerTable.isEmpty()) {
+      return null;
+    }
+
+    List<ChannelData<T>> blob = new ArrayList<>();
+    List<List<ChannelData<T>>> blobData = new ArrayList<>();
+
+    float totalBufferSizeInBytes = 0F;
+    float totalBufferSizePerTableInBytes = 0F;
+
+    // Distribute work at table level, split the blob if reaching the blob size limit or the
+    // channel has different encryption key ids
+    for (int idx = 0; idx < channelsDataPerTable.size(); idx++) {
+      ChannelData<T> channelData = channelsDataPerTable.get(idx);
+      // Close the current chunk and start a new one when needed; the idx > 0 guard makes sure the
+      // first element never looks back at a previous one
+      if (idx > 0
+          && shouldStopProcessing(
+              totalBufferSizeInBytes,
+              totalBufferSizePerTableInBytes,
+              channelData,
+              channelsDataPerTable.get(idx - 1))) {
+        logger.logInfo(
+            "Creation of another blob is needed because of blob/chunk size limit or"
+                + " different encryption ids or different schema, client={}, table={},"
+                + " blobSize={}, chunkSize={}, nextChannelSize={}, encryptionId1={},"
+                + " encryptionId2={}, schema1={}, schema2={}",
+            this.owningClient.getName(),
+            channelData.getChannelContext().getTableName(),
+            totalBufferSizeInBytes,
+            totalBufferSizePerTableInBytes,
+            channelData.getBufferSize(),
+            channelData.getChannelContext().getEncryptionKeyId(),
+            channelsDataPerTable.get(idx - 1).getChannelContext().getEncryptionKeyId(),
+            channelData.getColumnEps().keySet(),
+            channelsDataPerTable.get(idx - 1).getColumnEps().keySet());
+        blobData.add(blob);
+        blob = new ArrayList<>();
+        totalBufferSizeInBytes = 0;
+        totalBufferSizePerTableInBytes = 0;
+      }
+      totalBufferSizeInBytes += channelData.getBufferSize();
+      totalBufferSizePerTableInBytes += channelData.getBufferSize();
+      blob.add(channelData);
+    }
+
+    blobData.add(blob);
+    return blobData;
+  }
+
+  void uploadChunkData(String fullyQualifiedTableName, List<List<ChannelData<T>>> chunkData) {
+    List<Pair<BlobData<T>, CompletableFuture<BlobMetadata>>> blobs = new ArrayList<>();
+
+    // Kick off a build job
+    final BlobPath blobPath = this.storageManager.generateBlobPath(fullyQualifiedTableName);
+
+    long flushStartMs = System.currentTimeMillis();
+    if (this.owningClient.flushLatency != null) {
+      latencyTimerContextMap.putIfAbsent(
+          blobPath.fileRegistrationPath, this.owningClient.flushLatency.time());
+    }
+
+    // Copy encryptionKeysPerTable from owning client
+    Map<FullyQualifiedTableName, EncryptionKey> encryptionKeysPerTable = new ConcurrentHashMap<>();
+    this.owningClient
+        .getEncryptionKeysPerTable()
+        .forEach((k, v) -> encryptionKeysPerTable.put(k, new EncryptionKey(v)));
+
+    Supplier<BlobMetadata> supplier =
+        () -> {
+          try {
+            BlobMetadata blobMetadata =
+                buildAndUpload(
+                    blobPath, chunkData, fullyQualifiedTableName, encryptionKeysPerTable);
+            blobMetadata.getBlobStats().setFlushStartMs(flushStartMs);
+            return blobMetadata;
+          } catch (Throwable e) {
+            Throwable ex = e.getCause() == null ? e : e.getCause();
+            String errorMessage =
+                String.format(
+                    "Building blob failed, client=%s, blob=%s, exception=%s,"
+                        + " detail=%s, trace=%s, all channels in the blob will be"
+                        + " invalidated",
+                    this.owningClient.getName(),
+                    blobPath.fileRegistrationPath,
+                    ex,
+                    ex.getMessage(),
+                    getStackTrace(ex));
+            logger.logError(errorMessage);
+            if (this.owningClient.getTelemetryService() != null) {
+              this.owningClient
+                  .getTelemetryService()
+                  .reportClientFailure(this.getClass().getSimpleName(), errorMessage);
+            }
+
+            if (e instanceof IOException) {
+              invalidateAllChannelsInBlob(chunkData, errorMessage);
+              return null;
+            } else if (e instanceof NoSuchAlgorithmException) {
+              throw new SFException(e, ErrorCode.MD5_HASHING_NOT_AVAILABLE);
+            } else if (e instanceof InvalidAlgorithmParameterException
+                | e instanceof NoSuchPaddingException
+                | e instanceof IllegalBlockSizeException
+                | e instanceof BadPaddingException
+                | e instanceof InvalidKeyException) {
+              throw new SFException(e, ErrorCode.ENCRYPTION_FAILURE);
+            } else {
+              throw new SFException(e, ErrorCode.INTERNAL_ERROR, e.getMessage());
+            }
+          }
+        };
+
+    blobs.add(
+        new Pair<>(
+            new BlobData<>(blobPath.fileRegistrationPath, chunkData),
+            CompletableFuture.supplyAsync(supplier, this.buildUploadWorkers)));
+
+    logger.logInfo(
+        "buildAndUpload task added for client={}, blob={}, buildUploadWorkers stats={}",
+        this.owningClient.getName(),
+        blobPath,
+        this.buildUploadWorkers.toString());
+
     // Add the flush task futures to the register service
     this.registerService.addBlobs(blobs);
   }
@@ -580,21 +577,21 @@ private boolean shouldStopProcessing(
       ChannelData<T> prev) {
     return totalBufferSizeInBytes + current.getBufferSize() > MAX_BLOB_SIZE_IN_BYTES
         || totalBufferSizePerTableInBytes + current.getBufferSize()
-            > this.owningClient.getParameterProvider().getMaxChunkSizeInBytes()
+        > this.owningClient.getParameterProvider().getMaxChunkSizeInBytes()
         || !Objects.equals(
-            current.getChannelContext().getEncryptionKeyId(),
-            prev.getChannelContext().getEncryptionKeyId())
+        current.getChannelContext().getEncryptionKeyId(),
+        prev.getChannelContext().getEncryptionKeyId())
         || !current.getColumnEps().keySet().equals(prev.getColumnEps().keySet());
   }
 
   /**
    * Builds and uploads blob to cloud storage.
    *
-   * @param blobPath Path of the destination blob in cloud storage
-   * @param blobData All the data for one blob. Assumes that all ChannelData in the inner List
-   *     belongs to the same table. Will error if this is not the case
+   * @param blobPath  Path of the destination blob in cloud storage
+   * @param blobData  All the data for one blob. Assumes that all ChannelData in the inner List
+   *                  belongs to the same table. Will error if this is not the case
    * @param fullyQualifiedTableName the table name of the first channel in the blob, only matters in
-   *     Iceberg mode
+   *         Iceberg mode
    * @return BlobMetadata for FlushService.upload
    */
   BlobMetadata buildAndUpload(
@@ -603,8 +600,8 @@ BlobMetadata buildAndUpload(
       String fullyQualifiedTableName,
       Map<FullyQualifiedTableName, EncryptionKey> encryptionKeysPerTable)
       throws IOException, NoSuchAlgorithmException, InvalidAlgorithmParameterException,
-          NoSuchPaddingException, IllegalBlockSizeException, BadPaddingException,
-          InvalidKeyException {
+      NoSuchPaddingException, IllegalBlockSizeException, BadPaddingException,
+      InvalidKeyException {
     Timer.Context buildContext = Utils.createTimerContext(this.owningClient.buildLatency);
 
     // Construct the blob along with the metadata of the blob
@@ -629,10 +626,10 @@ BlobMetadata buildAndUpload(
   /**
    * Upload a blob to Streaming Ingest dedicated stage
    *
-   * @param storage the storage to upload the blob
-   * @param blobPath full path of the blob
-   * @param blob blob data
-   * @param metadata a list of chunk metadata
+   * @param storage   the storage to upload the blob
+   * @param blobPath  full path of the blob
+   * @param blob      blob data
+   * @param metadata  a list of chunk metadata
    * @param blobStats an object to track latencies and other stats of the blob
    * @return BlobMetadata object used to create the register blob request
    */
@@ -695,9 +692,9 @@ void shutdown() throws InterruptedException {
     boolean isTerminated =
         this.flushWorker.awaitTermination(THREAD_SHUTDOWN_TIMEOUT_IN_SEC, TimeUnit.SECONDS)
             && this.registerWorker.awaitTermination(
-                THREAD_SHUTDOWN_TIMEOUT_IN_SEC, TimeUnit.SECONDS)
+            THREAD_SHUTDOWN_TIMEOUT_IN_SEC, TimeUnit.SECONDS)
             && this.buildUploadWorkers.awaitTermination(
-                THREAD_SHUTDOWN_TIMEOUT_IN_SEC, TimeUnit.SECONDS);
+            THREAD_SHUTDOWN_TIMEOUT_IN_SEC, TimeUnit.SECONDS);
 
     if (!isTerminated) {
       logger.logWarn("Tasks can't be terminated within the timeout, force shutdown now.");
@@ -759,7 +756,9 @@ boolean throttleDueToQueuedFlushTasks() {
     return throttleOnQueuedTasks;
   }
 
-  /** Get whether we're running under test mode */
+  /**
+   * Get whether we're running under test mode
+   */
   boolean isTestMode() {
     return this.isTestMode;
   }
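
For readers who want the chunk-splitting rule in isolation: below is a minimal, self-contained Java sketch of the boundary logic that flushChunkData delegates to shouldStopProcessing (a chunk is closed when the accumulated size would exceed the limit or the encryption key id changes; the schema/column-set comparison is omitted here for brevity). The Item record, the split method, and the 16 MiB limit are illustrative stand-ins, not SDK types or the SDK's configured values.

import java.util.ArrayList;
import java.util.List;

// Illustration only: mirrors the chunk-splitting rule used by flushChunkData/shouldStopProcessing,
// with hypothetical stand-in types instead of the SDK's ChannelData<T>.
public class ChunkSplitSketch {

  // Stand-in for one channel's flushed data: its serialized size and encryption key id.
  record Item(long sizeInBytes, String encryptionKeyId) {}

  // Assumed limit for this sketch; the SDK reads the real value from its parameter provider.
  static final long MAX_CHUNK_SIZE_IN_BYTES = 16L * 1024 * 1024;

  // Split one table's items into chunks, closing a chunk when adding the next item would exceed
  // the size limit or when the encryption key id differs from the previous item.
  static List<List<Item>> split(List<Item> itemsForOneTable) {
    List<List<Item>> chunks = new ArrayList<>();
    List<Item> current = new ArrayList<>();
    long currentSize = 0;
    Item prev = null;
    for (Item item : itemsForOneTable) {
      boolean boundary =
          prev != null
              && (currentSize + item.sizeInBytes() > MAX_CHUNK_SIZE_IN_BYTES
                  || !prev.encryptionKeyId().equals(item.encryptionKeyId()));
      if (boundary) {
        chunks.add(current);
        current = new ArrayList<>();
        currentSize = 0;
      }
      current.add(item);
      currentSize += item.sizeInBytes();
      prev = item;
    }
    if (!current.isEmpty()) {
      chunks.add(current);
    }
    return chunks;
  }

  public static void main(String[] args) {
    List<Item> items =
        List.of(
            new Item(10L << 20, "key1"),
            new Item(10L << 20, "key1"), // 10 MiB + 10 MiB > 16 MiB, so a new chunk starts here
            new Item(1L << 20, "key2")); // encryption key change also starts a new chunk
    System.out.println(split(items).size()); // prints 3
  }
}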