Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

RPE metal #2049

Open
wants to merge 9 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion libs/lczero-common
Submodule lczero-common updated 1 files
+21 −0 proto/net.proto
15 changes: 15 additions & 0 deletions src/neural/backends/metal/mps/NetworkGraph.h
Original file line number Diff line number Diff line change
Expand Up @@ -132,12 +132,27 @@ static MPSImageFeatureChannelFormat fcFormat = MPSImageFeatureChannelFormatFloat
alpha:(float)alpha
label:(NSString * __nonnull)label;

// Builds the relative position encoding (RPE) term for one attention input.
//
// The raw weights are loaded as a [15 * 15, depth * heads] variable, transposed, multiplied by
// the factorizer map tensor and reshaped to [depth, heads, queries, keys]. The result is then
// permuted according to `type` and matrix-multiplied with `tensor`.
//
// tensor    - Attention-side input. Per the implementation's shape comments it is expected to be
//             [B, H, Q, D] for RPE-Q, [B, H, K, D] for RPE-K and [B, H, Q, K] for RPE-V.
// mapTensor - Factorizer tensor, normally the one returned by -getRpeMapTensor.
// weights   - Pointer to depth * heads * 15 * 15 floats. The bytes are wrapped without copying.
// depth     - Per-head depth (D).
// heads     - Number of attention heads (H).
// queries   - Number of query positions (Q) used when reshaping the factorized weights.
// keys      - Number of key positions (K) used when reshaping the factorized weights.
// type      - 0 selects RPE-Q, 1 selects RPE-K, 2 selects RPE-V. Other values apply no
//             permutation to the factorized weights.
// label     - Prefix used for the names of all graph nodes created here.
//
// Returns, per the implementation's shape comments, a [B, H, Q, K] tensor for RPE-Q and RPE-K
// and a [B, H, Q, D] tensor for RPE-V.
-(nonnull MPSGraphTensor *) relativePositionEncodingWithTensor:(MPSGraphTensor * __nonnull)tensor
mapTensor:(MPSGraphTensor * __nonnull)rpeMapTensor
weights:(float * __nonnull)rpeWeights
depth:(NSUInteger)depth
heads:(NSUInteger)heads
queries:(NSUInteger)queries
keys:(NSUInteger)keys
type:(NSUInteger)type
label:(NSString * __nonnull)label;

// Returns the RPE weights factorizer: a [15 * 15, 64 * 64] float tensor named "rpe_factor".
// Rows index pairs of coordinate offsets (each offset in -7..7, hence 15 * 15) and columns index
// ordered pairs of board squares (64 * 64). An entry is 1.0 where the column's square pair has
// the row's offsets. The tensor is built lazily on first use and cached for later calls.
-(nonnull MPSGraphTensor *) getRpeMapTensor;

// Builds the scaled multi-head attention sub-graph from the given queries, keys and values.
//
// rpeQ / rpeK - Optional relative position encoding weights. When non-nil, the corresponding
//               RPE term is added to the query-key attention logits before they are scaled.
// rpeV        - Optional relative position encoding weights. When non-nil, the RPE-V term
//               computed from the attention weights is added to the attention * values output.
// smolgen / smolgenActivation - Optional; presumably configure the smolgen attention bias
//               (NOTE(review): confirm against the implementation).
// label       - Prefix used for the names of the graph nodes created here.
//
// NOTE(review): each non-nil RPE pointer is assumed to reference depth * heads * 15 * 15 floats,
// matching what -relativePositionEncodingWithTensor:... reads; confirm against the weight loader.
-(nonnull MPSGraphTensor *) scaledMHAMatmulWithQueries:(MPSGraphTensor * __nonnull)queries
withKeys:(MPSGraphTensor * __nonnull)keys
withValues:(MPSGraphTensor * __nonnull)values
heads:(NSUInteger)heads
parent:(MPSGraphTensor * __nonnull)parent
smolgen:(lczero::MultiHeadWeights::Smolgen * __nullable)smolgen
rpeQ:(float * __nullable)rpeQ
rpeK:(float * __nullable)rpeK
rpeV:(float * __nullable)rpeV
smolgenActivation:(NSString * __nullable)smolgenActivation
label:(NSString * __nonnull)label;

Expand Down
192 changes: 185 additions & 7 deletions src/neural/backends/metal/mps/NetworkGraph.mm
Original file line number Diff line number Diff line change
Expand Up @@ -534,6 +534,9 @@ -(nonnull MPSGraphTensor *) addEncoderLayerWithParent:(MPSGraphTensor * __nonnul
heads:heads
parent:parent
smolgen:encoder.mha.has_smolgen ? &encoder.mha.smolgen : nil
rpeQ:encoder.mha.rpe_q.size() > 0 ? encoder.mha.rpe_q.data() : nil
rpeK:encoder.mha.rpe_k.size() > 0 ? encoder.mha.rpe_k.data() : nil
rpeV:encoder.mha.rpe_v.size() > 0 ? encoder.mha.rpe_v.data() : nil
smolgenActivation:smolgenActivation
label:[NSString stringWithFormat:@"%@/mha", label]];

Expand Down Expand Up @@ -746,12 +749,135 @@ -(nonnull MPSGraphTensor *) transposeChannelsWithTensor:(MPSGraphTensor * __nonn
name:[NSString stringWithFormat:@"%@/reshape", label]];
}

-(nonnull MPSGraphTensor *) relativePositionEncodingWithTensor:(MPSGraphTensor * __nonnull)tensor
                                                     mapTensor:(MPSGraphTensor * __nonnull)rpeMapTensor
                                                       weights:(float * __nonnull)rpeWeights
                                                         depth:(NSUInteger)depth
                                                         heads:(NSUInteger)heads
                                                       queries:(NSUInteger)queries
                                                          keys:(NSUInteger)keys
                                                          type:(NSUInteger)type
                                                         label:(NSString * __nonnull)label
{
    // Computes the relative position encoding term for `tensor`.
    //   type 0 (RPE-Q): [B, H, Q, D] -> [B, H, Q, K]
    //   type 1 (RPE-K): [B, H, K, D] -> [B, H, Q, K]
    //   type 2 (RPE-V): [B, H, Q, K] -> [B, H, Q, D]
    // Every node created here is named "<label>/<suffix>".

    // Local shorthand: swap two axes of `input` and name the node "<label>/<suffix>".
    MPSGraphTensor * (^swapAxes)(MPSGraphTensor *, NSUInteger, NSUInteger, NSString *) =
        ^MPSGraphTensor * (MPSGraphTensor * input, NSUInteger first, NSUInteger second, NSString * suffix) {
            return [self transposeTensor:input
                               dimension:first
                           withDimension:second
                                    name:[NSString stringWithFormat:@"%@/%@", label, suffix]];
        };

    // --- Expand the compact 15x15 weights to a dense per-square-pair table. ---
    // The caller keeps ownership of the weight buffer; it is wrapped, not copied.
    NSData * weightBytes = [NSData dataWithBytesNoCopy:(void *)rpeWeights
                                               length:depth * heads * 15 * 15 * sizeof(float)
                                         freeWhenDone:NO];

    // Weights are stored transposed, so load as [225, D * H] and flip to [D * H, 225].
    MPSGraphTensor * encoding = [self variableWithData:weightBytes
                                                 shape:@[@(15 * 15), @(depth * heads)]
                                              dataType:MPSDataTypeFloat32
                                                  name:[NSString stringWithFormat:@"%@/weights", label]];
    encoding = swapAxes(encoding, 0, 1, @"transpose");

    // [D * H, 225] x [225, Q * K] -> [D * H, Q * K] -> [D, H, Q, K].
    encoding = [self matrixMultiplicationWithPrimaryTensor:encoding
                                           secondaryTensor:rpeMapTensor
                                                      name:[NSString stringWithFormat:@"%@/factorize_matmul", label]];
    encoding = [self reshapeTensor:encoding
                         withShape:@[@(depth), @(heads), @(queries), @(keys)]
                              name:[NSString stringWithFormat:@"%@/reshape", label]];

    // --- Arrange the encoding so the contracted axis sits where the matmul needs it. ---
    switch (type) {
        case 0:
            // RPE-Q: [D, H, Q, K] -> [H, Q, D, K]
            encoding = swapAxes(encoding, 0, 1, @"transpose_1");
            encoding = swapAxes(encoding, 1, 2, @"transpose_2");
            break;
        case 1:
            // RPE-K: [D, H, Q, K] -> [H, K, D, Q]
            encoding = swapAxes(encoding, 2, 3, @"transpose_1");
            encoding = swapAxes(encoding, 0, 1, @"transpose_2");
            encoding = swapAxes(encoding, 1, 2, @"transpose_3");
            break;
        case 2:
            // RPE-V: [D, H, Q, K] -> [H, Q, K, D]
            encoding = swapAxes(encoding, 0, 1, @"transpose_1");
            encoding = swapAxes(encoding, 1, 2, @"transpose_2");
            encoding = swapAxes(encoding, 2, 3, @"transpose_3");
            break;
        default:
            // Unknown type: the encoding is used as-is, in [D, H, Q, K] layout.
            break;
    }

    // --- Move the batch axis inward: [N, a, b, c] -> [a, b, N, c]. ---
    MPSGraphTensor * product = swapAxes(tensor, 0, 1, @"a_transpose_1");
    product = swapAxes(product, 1, 2, @"a_transpose_2");

    // --- Batched matmul: [a, b, N, c] x [a, b, c, d] -> [a, b, N, d]. ---
    product = [self matrixMultiplicationWithPrimaryTensor:product
                                          secondaryTensor:encoding
                                                     name:[NSString stringWithFormat:@"%@/rpe/matmul", label]];

    // --- Restore the batch axis to the front: [a, b, N, d] -> [N, a, b, d]. ---
    product = swapAxes(product, 1, 2, @"a_transpose_4");
    product = swapAxes(product, 0, 1, @"a_transpose_5");

    // RPE-K comes out as [B, H, K, Q]; flip the last two axes to get [B, H, Q, K].
    if (type == 1) {
        return swapAxes(product, 2, 3, @"rpe/transpose_6");
    }

    // RPE-Q yields [B, H, Q, K]; RPE-V yields [B, H, Q, D].
    return product;
}

-(nonnull MPSGraphTensor *) getRpeMapTensor
{
    // Returns the RPE weights factorizer tensor for this graph, creating it on first use.
    //
    // The tensor is a [15 * 15, 64 * 64] 0/1 matrix. Rows index pairs of coordinate offsets
    // (each in -7..7, shifted to 0..14); columns index ordered pairs of board squares. Multiplying
    // [N, 225] RPE weights by it expands them to one value per square pair.
    //
    // The tensor is created through `self`, so it is cached per receiver rather than in a single
    // process-wide static: a tensor built by one graph instance is never handed to another.
    // Keys are held weakly and compared by pointer so the cache does not retain graphs itself.
    static NSMapTable<id, MPSGraphTensor *> * rpeMapTensors = nil;
    static dispatch_once_t onceToken;
    dispatch_once(&onceToken, ^{
        rpeMapTensors = [[NSMapTable alloc] initWithKeyOptions:(NSPointerFunctionsWeakMemory |
                                                                NSPointerFunctionsObjectPointerPersonality)
                                                  valueOptions:NSPointerFunctionsStrongMemory
                                                      capacity:1];
    });

    MPSGraphTensor * rpeMapTensor = nil;

    // The cache is shared by all instances, so lock on the cache itself rather than on `self`.
    @synchronized (rpeMapTensors) {
        rpeMapTensor = [rpeMapTensors objectForKey:self];
        if (rpeMapTensor == nil) {
            const NSInteger side = 8;               // Squares per board edge.
            const NSInteger offsets = 2 * side - 1; // Distinct offsets along one axis: -7..7.
            const NSInteger squares = side * side;  // 64 squares.
            const NSInteger rows = offsets * offsets;
            const NSInteger cols = squares * squares;

            std::vector<float> rpeMap(rows * cols, 0.0f);

            // Signed arithmetic: (i - k) and (j - l) are negative for half of the pairs.
            for (NSInteger i = 0; i < side; i++) {
                for (NSInteger j = 0; j < side; j++) {
                    for (NSInteger k = 0; k < side; k++) {
                        for (NSInteger l = 0; l < side; l++) {
                            const NSInteger row = offsets * (i - k + (side - 1)) + (j - l + (side - 1));
                            const NSInteger col = squares * (i * side + j) + k * side + l;
                            rpeMap[row * cols + col] = 1.0f;
                        }
                    }
                }
            }

            // Copy the bytes: rpeMap is a local whose storage is freed when this scope ends,
            // so the NSData must own its contents instead of aliasing the vector.
            NSData * rpeMapData = [NSData dataWithBytes:rpeMap.data()
                                                 length:rpeMap.size() * sizeof(float)];

            rpeMapTensor = [self variableWithData:rpeMapData
                                            shape:@[@(rows), @(cols)]
                                         dataType:MPSDataTypeFloat32
                                             name:@"rpe_factor"];

            [rpeMapTensors setObject:rpeMapTensor forKey:self];
        }
    }
    return rpeMapTensor;
}

-(nonnull MPSGraphTensor *) scaledMHAMatmulWithQueries:(MPSGraphTensor * __nonnull)queries
withKeys:(MPSGraphTensor * __nonnull)keys
withValues:(MPSGraphTensor * __nonnull)values
heads:(NSUInteger)heads
parent:(MPSGraphTensor * __nonnull)parent
smolgen:(lczero::MultiHeadWeights::Smolgen * __nullable)smolgen
rpeQ:(float * __nullable)rpeQ
rpeK:(float * __nullable)rpeK
rpeV:(float * __nullable)rpeV
smolgenActivation:(NSString * __nullable)smolgenActivation
label:(NSString * __nonnull)label
{
Expand All @@ -769,10 +895,45 @@ -(nonnull MPSGraphTensor *) scaledMHAMatmulWithQueries:(MPSGraphTensor * __nonnu
values = [self transposeTensor:values dimension:1 withDimension:2 name:[NSString stringWithFormat:@"%@/transpose_v", label]];

// Scaled attention matmul.
keys = [self transposeTensor:keys dimension:2 withDimension:3 name:[NSString stringWithFormat:@"%@/transpose_k_2", label]];
MPSGraphTensor * transposedKeys = [self transposeTensor:keys dimension:2 withDimension:3 name:[NSString stringWithFormat:@"%@/transpose_k_2", label]];
MPSGraphTensor * attn = [self matrixMultiplicationWithPrimaryTensor:queries
secondaryTensor:keys
secondaryTensor:transposedKeys
name:[NSString stringWithFormat:@"%@/matmul_qk", label]];

if (rpeQ != nil || rpeK != nil) {
MPSGraphTensor * rpeMapTensor = [self getRpeMapTensor];

// Apply the RPELogits to each of Q and K.
if (rpeQ != nil) {
MPSGraphTensor * rpeQTensor = [self relativePositionEncodingWithTensor:queries
mapTensor:rpeMapTensor
weights:rpeQ
depth:depth
heads:heads
queries:64
keys:64
type:0 // Q-type
label:[NSString stringWithFormat:@"%@/rpeQ", label]];
attn = [self additionWithPrimaryTensor:attn
secondaryTensor:rpeQTensor
name:[NSString stringWithFormat:@"%@/rpeQ_add", label]];
}
if (rpeK != nil) {
MPSGraphTensor * rpeKTensor = [self relativePositionEncodingWithTensor:keys
mapTensor:rpeMapTensor
weights:rpeK
depth:depth
heads:heads
queries:64
keys:64
type:1 // K-type
label:[NSString stringWithFormat:@"%@/rpeK", label]];
attn = [self additionWithPrimaryTensor:attn
secondaryTensor:rpeKTensor
name:[NSString stringWithFormat:@"%@/rpeK_add", label]];
}
}

attn = [self divisionWithPrimaryTensor:attn
secondaryTensor:[self constantWithScalar:sqrt(depth)
shape:@[@1]
Expand Down Expand Up @@ -849,13 +1010,30 @@ -(nonnull MPSGraphTensor *) scaledMHAMatmulWithQueries:(MPSGraphTensor * __nonnu
attn = [self applyActivationWithTensor:attn activation:@"softmax" label:label];

// matmul(scaled_attention_weights, v).
attn = [self matrixMultiplicationWithPrimaryTensor:attn
secondaryTensor:values
name:[NSString stringWithFormat:@"%@/matmul_v", label]];
MPSGraphTensor * output = [self matrixMultiplicationWithPrimaryTensor:attn
secondaryTensor:values
name:[NSString stringWithFormat:@"%@/matmul_v", label]];

if (rpeV != nil) {
MPSGraphTensor * rpeMapTensor = [self getRpeMapTensor];
// output = output + RPEValue(head_depth, name=name+'/rpe_v')(attention_weights)
MPSGraphTensor * rpeVTensor = [self relativePositionEncodingWithTensor:attn
mapTensor:rpeMapTensor
weights:rpeV
depth:depth
heads:heads
queries:64
keys:64
type:2 // V-type
label:[NSString stringWithFormat:@"%@/rpeV", label]];
output = [self additionWithPrimaryTensor:output
secondaryTensor:rpeVTensor
name:[NSString stringWithFormat:@"%@/rpeV_add", label]];
}

attn = [self transposeTensor:attn dimension:1 withDimension:2 name:[NSString stringWithFormat:@"%@/transpose_a", label]];
output = [self transposeTensor:output dimension:1 withDimension:2 name:[NSString stringWithFormat:@"%@/transpose_a", label]];

return [self reshapeTensor:attn withShape:@[@(-1), @64, @(dmodel)] name:[NSString stringWithFormat:@"%@/reshape_a", label]];
return [self reshapeTensor:output withShape:@[@(-1), @64, @(dmodel)] name:[NSString stringWithFormat:@"%@/reshape_a", label]];
}

-(nonnull MPSGraphTensor *) scaledQKMatmulWithQueries:(MPSGraphTensor * __nonnull)queries
Expand Down
8 changes: 5 additions & 3 deletions src/neural/backends/metal/network_metal.cc
Original file line number Diff line number Diff line change
Expand Up @@ -160,9 +160,11 @@ MetalNetwork::MetalNetwork(const WeightsFile& file, const OptionsDict& options)
"' does not exist in this net.");
}

auto embedding = static_cast<InputEmbedding>(file.format().network_format().input_embedding());
builder_->build(kInputPlanes, weights, embedding, attn_body, attn_policy_, conv_policy_,
wdl_, moves_left_, activations, policy_head, value_head);
auto embedding = static_cast<InputEmbedding>(
file.format().network_format().input_embedding());
builder_->build(kInputPlanes, weights, embedding, attn_body, attn_policy_,
conv_policy_, wdl_, moves_left_, activations, policy_head,
value_head);
}

void MetalNetwork::forwardEval(InputsOutputs* io, int batchSize) {
Expand Down
5 changes: 4 additions & 1 deletion src/neural/network_legacy.cc
Original file line number Diff line number Diff line change
Expand Up @@ -142,7 +142,10 @@ BaseWeights::MHA::MHA(const pblczero::Weights::MHA& mha)
dense_w(LayerAdapter(mha.dense_w()).as_vector()),
dense_b(LayerAdapter(mha.dense_b()).as_vector()),
smolgen(Smolgen(mha.smolgen())),
has_smolgen(mha.has_smolgen()) {}
has_smolgen(mha.has_smolgen()),
rpe_q(LayerAdapter(mha.rpe_q()).as_vector()),
rpe_k(LayerAdapter(mha.rpe_k()).as_vector()),
rpe_v(LayerAdapter(mha.rpe_v()).as_vector()) {}

BaseWeights::FFN::FFN(const pblczero::Weights::FFN& ffn)
: dense1_w(LayerAdapter(ffn.dense1_w()).as_vector()),
Expand Down
3 changes: 3 additions & 0 deletions src/neural/network_legacy.h
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,9 @@ struct BaseWeights {
Vec dense_b;
Smolgen smolgen;
bool has_smolgen;
Vec rpe_q;
Vec rpe_k;
Vec rpe_v;
};

struct FFN {
Expand Down