Spark 3.5: Fix broadcasting specs in RewriteTablePath (#11982)
manuzhang authored Jan 24, 2025
1 parent 67c52b5 commit d693f83
Showing 2 changed files with 54 additions and 45 deletions.
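
Summary of the change, as a consolidated sketch: instead of broadcasting the table and the partition-spec map (specsById) as two separate Spark broadcasts, the action now keeps one lazily created broadcast of a SerializableTableWithSize copy of the table, and executor-side code reads the specs back out of that single broadcast. The snippet below only illustrates the pattern taken by the hunks that follow; the wrapper class BroadcastTableSketch, its constructor, and the specFor helper are hypothetical, and it assumes a plain JavaSparkContext in place of the action's sparkContext() accessor.

import java.util.Map;
import org.apache.iceberg.PartitionSpec;
import org.apache.iceberg.Table;
import org.apache.iceberg.spark.source.SerializableTableWithSize;
import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.broadcast.Broadcast;

class BroadcastTableSketch {
  private final JavaSparkContext sparkContext;
  private final Table table;
  private Broadcast<Table> tableBroadcast = null;

  BroadcastTableSketch(JavaSparkContext sparkContext, Table table) {
    this.sparkContext = sparkContext;
    this.table = table;
  }

  // Driver side: lazily broadcast one serializable copy of the table so every
  // Spark task reuses the same broadcast instead of receiving its own copy.
  Broadcast<Table> tableBroadcast() {
    if (tableBroadcast == null) {
      this.tableBroadcast = sparkContext.broadcast(SerializableTableWithSize.copyOf(table));
    }
    return tableBroadcast;
  }

  // Executor side: partition specs are read back from the single broadcast table,
  // so no separate Broadcast<Map<Integer, PartitionSpec>> is needed.
  static PartitionSpec specFor(Broadcast<Table> tableBroadcast, int specId) {
    Map<Integer, PartitionSpec> specsById = tableBroadcast.getValue().specs();
    return specsById.get(specId);
  }
}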
First changed file (RewriteTablePathSparkAction.java):
@@ -36,7 +36,6 @@
import org.apache.iceberg.RewriteTablePathUtil.PositionDeleteReaderWriter;
import org.apache.iceberg.RewriteTablePathUtil.RewriteResult;
import org.apache.iceberg.Schema;
import org.apache.iceberg.SerializableTable;
import org.apache.iceberg.Snapshot;
import org.apache.iceberg.StaticTableOperations;
import org.apache.iceberg.StructLike;
@@ -63,10 +62,12 @@
import org.apache.iceberg.io.OutputFile;
import org.apache.iceberg.orc.ORC;
import org.apache.iceberg.parquet.Parquet;
import org.apache.iceberg.relocated.com.google.common.annotations.VisibleForTesting;
import org.apache.iceberg.relocated.com.google.common.base.Preconditions;
import org.apache.iceberg.relocated.com.google.common.collect.Lists;
import org.apache.iceberg.relocated.com.google.common.collect.Sets;
import org.apache.iceberg.spark.JobGroupInfo;
import org.apache.iceberg.spark.source.SerializableTableWithSize;
import org.apache.iceberg.util.Pair;
import org.apache.spark.api.java.function.ForeachFunction;
import org.apache.spark.api.java.function.MapFunction;
@@ -96,6 +97,7 @@ public class RewriteTablePathSparkAction extends BaseSparkAction<RewriteTablePat
private String stagingDir;

private final Table table;
private Broadcast<Table> tableBroadcast = null;

RewriteTablePathSparkAction(SparkSession spark, Table table) {
super(spark);
@@ -457,18 +459,13 @@ private RewriteContentFileResult rewriteManifests(
Dataset<ManifestFile> manifestDS =
spark().createDataset(Lists.newArrayList(toRewrite), manifestFileEncoder);

Broadcast<Table> serializableTable = sparkContext().broadcast(SerializableTable.copyOf(table));
Broadcast<Map<Integer, PartitionSpec>> specsById =
sparkContext().broadcast(tableMetadata.specsById());

return manifestDS
.repartition(toRewrite.size())
.map(
toManifests(
serializableTable,
tableBroadcast(),
stagingDir,
tableMetadata.formatVersion(),
specsById,
sourcePrefix,
targetPrefix),
Encoders.bean(RewriteContentFileResult.class))
@@ -478,10 +475,9 @@ private RewriteContentFileResult rewriteManifests(
}

private static MapFunction<ManifestFile, RewriteContentFileResult> toManifests(
Broadcast<Table> tableBroadcast,
Broadcast<Table> table,
String stagingLocation,
int format,
Broadcast<Map<Integer, PartitionSpec>> specsById,
String sourcePrefix,
String targetPrefix) {

@@ -491,24 +487,12 @@ private static MapFunction<ManifestFile, RewriteContentFileResult> toManifests(
case DATA:
result.appendDataFile(
writeDataManifest(
manifestFile,
tableBroadcast,
stagingLocation,
format,
specsById,
sourcePrefix,
targetPrefix));
manifestFile, table, stagingLocation, format, sourcePrefix, targetPrefix));
break;
case DELETES:
result.appendDeleteFile(
writeDeleteManifest(
manifestFile,
tableBroadcast,
stagingLocation,
format,
specsById,
sourcePrefix,
targetPrefix));
manifestFile, table, stagingLocation, format, sourcePrefix, targetPrefix));
break;
default:
throw new UnsupportedOperationException(
@@ -520,17 +504,16 @@ private static MapFunction<ManifestFile, RewriteContentFileResult> toManifests(

private static RewriteResult<DataFile> writeDataManifest(
ManifestFile manifestFile,
Broadcast<Table> tableBroadcast,
Broadcast<Table> table,
String stagingLocation,
int format,
Broadcast<Map<Integer, PartitionSpec>> specsByIdBroadcast,
String sourcePrefix,
String targetPrefix) {
try {
String stagingPath = RewriteTablePathUtil.stagingPath(manifestFile.path(), stagingLocation);
FileIO io = tableBroadcast.getValue().io();
FileIO io = table.getValue().io();
OutputFile outputFile = io.newOutputFile(stagingPath);
Map<Integer, PartitionSpec> specsById = specsByIdBroadcast.getValue();
Map<Integer, PartitionSpec> specsById = table.getValue().specs();
return RewriteTablePathUtil.rewriteDataManifest(
manifestFile, outputFile, io, format, specsById, sourcePrefix, targetPrefix);
} catch (IOException e) {
@@ -540,17 +523,16 @@ private static RewriteResult<DataFile> writeDataManifest(

private static RewriteResult<DeleteFile> writeDeleteManifest(
ManifestFile manifestFile,
Broadcast<Table> tableBroadcast,
Broadcast<Table> table,
String stagingLocation,
int format,
Broadcast<Map<Integer, PartitionSpec>> specsByIdBroadcast,
String sourcePrefix,
String targetPrefix) {
try {
String stagingPath = RewriteTablePathUtil.stagingPath(manifestFile.path(), stagingLocation);
FileIO io = tableBroadcast.getValue().io();
FileIO io = table.getValue().io();
OutputFile outputFile = io.newOutputFile(stagingPath);
Map<Integer, PartitionSpec> specsById = specsByIdBroadcast.getValue();
Map<Integer, PartitionSpec> specsById = table.getValue().specs();
return RewriteTablePathUtil.rewriteDeleteManifest(
manifestFile,
outputFile,
@@ -574,21 +556,12 @@ private void rewritePositionDeletes(TableMetadata metadata, Set<DeleteFile> toRe
Dataset<DeleteFile> deleteFileDs =
spark().createDataset(Lists.newArrayList(toRewrite), deleteFileEncoder);

Broadcast<Table> serializableTable = sparkContext().broadcast(SerializableTable.copyOf(table));
Broadcast<Map<Integer, PartitionSpec>> specsById =
sparkContext().broadcast(metadata.specsById());

PositionDeleteReaderWriter posDeleteReaderWriter = new SparkPositionDeleteReaderWriter();
deleteFileDs
.repartition(toRewrite.size())
.foreach(
rewritePositionDelete(
serializableTable,
specsById,
sourcePrefix,
targetPrefix,
stagingDir,
posDeleteReaderWriter));
tableBroadcast(), sourcePrefix, targetPrefix, stagingDir, posDeleteReaderWriter));
}

private static class SparkPositionDeleteReaderWriter implements PositionDeleteReaderWriter {
@@ -611,17 +584,16 @@ public PositionDeleteWriter<Record> writer(
}

private ForeachFunction<DeleteFile> rewritePositionDelete(
Broadcast<Table> tableBroadcast,
Broadcast<Map<Integer, PartitionSpec>> specsById,
Broadcast<Table> tableArg,
String sourcePrefixArg,
String targetPrefixArg,
String stagingLocationArg,
PositionDeleteReaderWriter posDeleteReaderWriter) {
return deleteFile -> {
FileIO io = tableBroadcast.getValue().io();
FileIO io = tableArg.getValue().io();
String newPath = RewriteTablePathUtil.stagingPath(deleteFile.location(), stagingLocationArg);
OutputFile outputFile = io.newOutputFile(newPath);
PartitionSpec spec = specsById.getValue().get(deleteFile.specId());
PartitionSpec spec = tableArg.getValue().specs().get(deleteFile.specId());
RewriteTablePathUtil.rewritePositionDeleteFile(
deleteFile,
outputFile,
@@ -730,4 +702,13 @@ private String getMetadataLocation(Table tbl) {
!metadataDir.isEmpty(), "Failed to get the metadata file root directory");
return metadataDir;
}

@VisibleForTesting
Broadcast<Table> tableBroadcast() {
if (tableBroadcast == null) {
this.tableBroadcast = sparkContext().broadcast(SerializableTableWithSize.copyOf(table));
}

return tableBroadcast;
}
}
Second changed file (the action's test suite):
@@ -64,10 +64,16 @@
import org.apache.iceberg.spark.source.ThreeColumnRecord;
import org.apache.iceberg.types.Types;
import org.apache.iceberg.util.Pair;
import org.apache.spark.SparkEnv;
import org.apache.spark.broadcast.Broadcast;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Encoder;
import org.apache.spark.sql.Encoders;
import org.apache.spark.sql.Row;
import org.apache.spark.storage.BlockId;
import org.apache.spark.storage.BlockInfoManager;
import org.apache.spark.storage.BlockManager;
import org.apache.spark.storage.BroadcastBlockId;
import org.junit.jupiter.api.AfterEach;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
@@ -910,6 +916,17 @@ public void testDeleteFrom() throws Exception {
assertEquals("Rows must match", originalData, copiedData);
}

@Test
public void testKryoDeserializeBroadcastValues() {
sparkContext.getConf().set("spark.serializer", "org.apache.spark.serializer.KryoSerializer");
RewriteTablePathSparkAction action =
(RewriteTablePathSparkAction) actions().rewriteTablePath(table);
Broadcast<Table> tableBroadcast = action.tableBroadcast();
// force deserializing broadcast values
removeBroadcastValuesFromLocalBlockManager(tableBroadcast.id());
assertThat(tableBroadcast.getValue().uuid()).isEqualTo(table.uuid());
}

protected void checkFileNum(
int versionFileCount,
int manifestListCount,
@@ -1049,4 +1066,15 @@ private PositionDelete<GenericRecord> positionDelete(
posDelete.set(path, position, nested);
return posDelete;
}

private void removeBroadcastValuesFromLocalBlockManager(long id) {
BlockId blockId = new BroadcastBlockId(id, "");
SparkEnv env = SparkEnv.get();
env.broadcastManager().cachedValues().clear();
BlockManager blockManager = env.blockManager();
BlockInfoManager blockInfoManager = blockManager.blockInfoManager();
blockInfoManager.lockForWriting(blockId, true);
blockInfoManager.removeBlock(blockId);
blockManager.memoryStore().remove(blockId);
}
}
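
The new test exercises the fix by forcing Spark to re-deserialize the broadcast value with Kryo: it clears the driver-local caches that hold the already-deserialized value and the serialized block, so the next getValue() has to re-read and deserialize the broadcast. Below is a commented restatement of that eviction sequence; every call mirrors the test hunk above, and since BroadcastManager, BlockInfoManager, and MemoryStore are Spark internals rather than public API, treat this as test-only scaffolding whose exact signatures may shift between Spark versions.

import org.apache.spark.SparkEnv;
import org.apache.spark.storage.BlockId;
import org.apache.spark.storage.BlockInfoManager;
import org.apache.spark.storage.BlockManager;
import org.apache.spark.storage.BroadcastBlockId;

final class BroadcastEvictionSketch {
  private BroadcastEvictionSketch() {}

  // Drop every locally cached form of a broadcast so the next getValue()
  // must re-fetch and deserialize the broadcast blocks.
  static void evictBroadcast(long broadcastId) {
    BlockId blockId = new BroadcastBlockId(broadcastId, "");
    SparkEnv env = SparkEnv.get();
    // 1. Clear the BroadcastManager cache of already-deserialized values.
    env.broadcastManager().cachedValues().clear();
    // 2. Remove the serialized block from the local BlockManager; the write
    //    lock has to be taken before the block can be removed.
    BlockManager blockManager = env.blockManager();
    BlockInfoManager blockInfoManager = blockManager.blockInfoManager();
    blockInfoManager.lockForWriting(blockId, true);
    blockInfoManager.removeBlock(blockId);
    blockManager.memoryStore().remove(blockId);
  }
}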
