diff --git a/model/Dataset.py b/model/Dataset.py index c2ee740..e15af49 100644 --- a/model/Dataset.py +++ b/model/Dataset.py @@ -62,7 +62,7 @@ def __getitem__(self, idx): img = None matrix = pd.read_csv(csv_path, skiprows=1, header=None).values.astype(np.float32) - # 计算残基个数 + residue_count = np.count_nonzero(matrix[0, :]) + 1 return img, torch.tensor(matrix), torch.tensor(residue_count, dtype=torch.float32) @@ -101,4 +101,4 @@ def forward(self, x): out = torch.bmm(value, attention.permute(0, 2, 1)) out = out.view(batch_size, channels, height, width) - return out + x # skip connection \ No newline at end of file + return out + x # skip connection