From 9086f84aaacafa3e166f496883d7290a1615efda Mon Sep 17 00:00:00 2001 From: koguhnhyeok Date: Thu, 9 Nov 2023 22:22:49 +0900 Subject: [PATCH] Fixed a warning issue when accessing the internal storage of the torch_geometric dataset and method iteration error. --- graphgym/loader_pyg.py | 26 +++++++++++++------------- 1 file changed, 13 insertions(+), 13 deletions(-) diff --git a/graphgym/loader_pyg.py b/graphgym/loader_pyg.py index 7b7e6735..240f31b0 100644 --- a/graphgym/loader_pyg.py +++ b/graphgym/loader_pyg.py @@ -75,7 +75,7 @@ def load_pyg(name, dataset_dir): def set_dataset_attr(dataset, name, value, size): dataset._data_list = None - dataset.data[name] = value + dataset._data[name] = value if dataset.slices is not None: dataset.slices[name] = torch.tensor([0, size], dtype=torch.long) @@ -102,9 +102,9 @@ def load_ogb(name, dataset_dir): splits = dataset.get_idx_split() split_names = ['train_mask', 'val_mask', 'test_mask'] for i, key in enumerate(splits.keys()): - mask = index_to_mask(splits[key], size=dataset.data.y.shape[0]) + mask = index_to_mask(splits[key], size=dataset._data.y.shape[0]) set_dataset_attr(dataset, split_names[i], mask, len(mask)) - edge_index = to_undirected(dataset.data.edge_index) + edge_index = to_undirected(dataset._data.edge_index) set_dataset_attr(dataset, 'edge_index', edge_index, edge_index.shape[1]) @@ -127,7 +127,7 @@ def load_ogb(name, dataset_dir): dataset.transform = neg_sampling_transform else: id_neg = negative_sampling(edge_index=id, - num_nodes=dataset.data.num_nodes, + num_nodes=dataset._data.num_nodes, num_neg_samples=id.shape[1]) id_all = torch.cat([id, id_neg], dim=-1) label = create_link_label(id, id_neg) @@ -190,24 +190,24 @@ def set_dataset_info(dataset): # get dim_in and dim_out try: - cfg.share.dim_in = dataset.data.x.shape[1] + cfg.share.dim_in = dataset._data.x.shape[1] except Exception: cfg.share.dim_in = 1 try: if cfg.dataset.task_type == 'classification': - cfg.share.dim_out = torch.unique(dataset.data.y).shape[0] + cfg.share.dim_out = torch.unique(dataset._data.y).shape[0] else: - cfg.share.dim_out = dataset.data.y.shape[1] + cfg.share.dim_out = dataset._data.y.shape[1] except Exception: cfg.share.dim_out = 1 # count number of dataset splits cfg.share.num_splits = 1 - for key in dataset.data.keys: + for key in dataset._data.keys(): if 'val' in key: cfg.share.num_splits += 1 break - for key in dataset.data.keys: + for key in dataset._data.keys(): if 'test' in key: cfg.share.num_splits += 1 break @@ -297,14 +297,14 @@ def create_loader(): dataset = create_dataset() # train loader if cfg.dataset.task == 'graph': - id = dataset.data['train_graph_index'] + id = dataset._data['train_graph_index'] loaders = [ get_loader(dataset[id], cfg.train.sampler, cfg.train.batch_size, shuffle=True) ] - delattr(dataset.data, 'train_graph_index') + delattr(dataset._data, 'train_graph_index') else: loaders = [ get_loader(dataset, @@ -317,13 +317,13 @@ def create_loader(): for i in range(cfg.share.num_splits - 1): if cfg.dataset.task == 'graph': split_names = ['val_graph_index', 'test_graph_index'] - id = dataset.data[split_names[i]] + id = dataset._data[split_names[i]] loaders.append( get_loader(dataset[id], cfg.val.sampler, cfg.train.batch_size, shuffle=False)) - delattr(dataset.data, split_names[i]) + delattr(dataset._data, split_names[i]) else: loaders.append( get_loader(dataset,