PreTrained Models Module
albernar committed Feb 28, 2024
1 parent d905475 commit a53e520
Showing 11 changed files with 335 additions and 171 deletions.
1 change: 0 additions & 1 deletion .gitlab-ci.yml
@@ -49,7 +49,6 @@ run:

pages:
script:
- pip install sphinx sphinx-rtd-theme
- cd docs
- make clean autogen html
- mkdir ../public/
2 changes: 1 addition & 1 deletion ampligraph/__init__.py
@@ -14,7 +14,7 @@
tf.compat.v1.logging.set_verbosity(tf.compat.v1.logging.ERROR)

__version__ = '2.0.1'
__all__ = ['datasets', 'latent_features', 'discovery', 'evaluation', 'utils']
__all__ = ['datasets', 'latent_features', 'discovery', 'evaluation', 'utils', 'pretrained_models']

logging.config.fileConfig(
pkg_resources.resource_filename(__name__, "logger.conf"),
156 changes: 5 additions & 151 deletions ampligraph/datasets/datasets.py
@@ -14,6 +14,8 @@
from collections import namedtuple
from pathlib import Path

from ampligraph.utils.file_utils import _get_data_home, _fetch_file

import numpy as np
import pandas as pd

@@ -141,154 +143,6 @@ def _clean_data(X, return_idx=False):
return filtered_X


def _get_data_home(data_home=None):
"""Get the location of the dataset folder to use.
Automatically determine the dataset folder to use.
If ``data_home`` is provided, a check is performed to see if the path exists and creates one if it does not.
If ``data_home`` is `None` the ``AMPLIGRAPH_ENV_NAME`` dataset is used.
If ``AMPLIGRAPH_ENV_NAME`` is not set, the default environment ``~/ampligraph_datasets`` is used.
Parameters
----------
data_home: str
The path to the folder that contains the datasets.
Returns
-------
str
The path to the dataset directory
"""

if data_home is None:
data_home = os.environ.get(
AMPLIGRAPH_ENV_NAME, os.path.join("~", "ampligraph_datasets")
)
data_home = os.path.expanduser(data_home)
if not os.path.exists(data_home):
os.makedirs(data_home)
logger.debug("data_home is set to {}".format(data_home))
return data_home


def _md5(file_path):
md5hash = hashlib.md5()
chunk_size = 4096
with open(file_path, "rb") as f:
content_buffer = f.read(chunk_size)
while content_buffer:
md5hash.update(content_buffer)
content_buffer = f.read(chunk_size)
return md5hash.hexdigest()


def _unzip_dataset(remote, source, destination, check_md5hash=False):
"""Unzip a file from a source location to a destination.
Parameters
----------
source: str
The path to the zipped file
destination: str
The destination directory to unzip the files to.
"""

# TODO - add error checking
with zipfile.ZipFile(source, "r") as zip_ref:
logger.debug("Unzipping {} to {}".format(source, destination))
zip_ref.extractall(destination)
if check_md5hash:
for file_name, remote_checksum in [
[remote.train_name, remote.train_checksum],
[remote.valid_name, remote.valid_checksum],
[remote.test_name, remote.test_checksum],
[remote.test_human_name, remote.test_human_checksum],
[remote.test_human_ids_name, remote.test_human_ids_checksum],
[remote.mapper_name, remote.mapper_checksum],
[remote.valid_negatives_name, remote.valid_negatives_checksum],
[remote.test_negatives_name, remote.test_negatives_checksum],
]:
file_path = os.path.join(
destination, remote.dataset_name, file_name
)
checksum = _md5(file_path)
if checksum != remote_checksum:
os.remove(source)
msg = (
"{} has an md5 checksum of ({}) which is different from the expected ({}), "
"the file may be corrupted.".format(
file_path, checksum, remote_checksum
)
)
logger.error(msg)
raise IOError(msg)
os.remove(source)


def _fetch_remote_data(remote, download_dir, data_home, check_md5hash=False):
"""Download a remote dataset.
Parameters
----------
remote: DatasetMetadata
Named tuple containing remote dataset meta information: dataset name, dataset filename,
url, train filename, validation filename, test filename, train checksum, valid checksum, test checksum.
download_dir: str
The location to download the file to.
data_home: str
The location to save the dataset.
check_md5hash: bool
Whether to check the MD5 hash of the dataset file.
"""

file_path = "{}.zip".format(download_dir)
if not Path(file_path).exists():
urllib.request.urlretrieve(remote.url, file_path)
# TODO - add error checking
_unzip_dataset(remote, file_path, data_home, check_md5hash)


def _fetch_dataset(remote, data_home=None, check_md5hash=False):
"""Get a dataset.
Gets the directory of a dataset. If the dataset is not found it is downloaded automatically.
Parameters
----------
remote: DatasetMetadata
Named tuple containing remote datasets meta information: dataset name, dataset filename,
url, train filename, validation filename, test filename, train checksum, valid checksum, test checksum.
data_home: str
The location to save the dataset to.
check_md5hash: bool
Whether to check the MD5 hash of the dataset file.
Returns
------
str
The location of the dataset.
"""
data_home = _get_data_home(data_home)
dataset_dir = os.path.join(data_home, remote.dataset_name)
if not os.path.exists(dataset_dir):
if remote.url is None:
msg = "No dataset at {} and no url provided.".format(dataset_dir)
logger.error(msg)
raise Exception(msg)

_fetch_remote_data(remote, dataset_dir, data_home, check_md5hash)
return dataset_dir


def _add_reciprocal_relations(triples_df):
"""Add reciprocal relations to the triples
@@ -430,7 +284,7 @@ def _load_dataset(
dataset_metadata.url.rfind("/")
+ 1: dataset_metadata.url.rfind(".")
]
dataset_path = _fetch_dataset(dataset_metadata, data_home, check_md5hash)
dataset_path = _fetch_file(dataset_metadata, data_home, check_md5hash, file_type='datasets')
train = load_from_csv(
dataset_path,
dataset_metadata.train_name,
@@ -1231,7 +1085,7 @@ def load_from_rdf(
"""

logger.debug("Loading rdf data from {}.".format(file_name))
data_home = _get_data_home(data_home)
data_home = _get_data_home(data_home, file_type="datasets")
from rdflib import Graph

g = Graph()
@@ -1290,7 +1144,7 @@ def load_from_ntriples(
"""

logger.debug("Loading rdf ntriples from {}.".format(file_name))
data_home = _get_data_home(data_home)
data_home = _get_data_home(data_home, file_type="datasets")
df = pd.read_csv(
os.path.join(data_home, folder_name, file_name),
sep=" ",
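The download and caching helpers removed from datasets.py above (_get_data_home, _md5, _unzip_dataset, _fetch_remote_data, _fetch_dataset) are superseded by the shared ampligraph.utils.file_utils module imported at the top of the file, whose _get_data_home/_fetch_file pair now serves both datasets and pretrained models through a file_type argument. That module's body is not part of this diff; the snippet below is only a minimal sketch of what it plausibly looks like, inferred from the call sites above (the default folder layout and the omission of checksum handling are assumptions):

import os
import urllib.request
import zipfile

AMPLIGRAPH_ENV_NAME = "AMPLIGRAPH_DATA_HOME"


def _get_data_home(data_home=None, file_type="datasets"):
    # Resolve (and create if missing) the cache folder for the given file type.
    if data_home is None:
        data_home = os.environ.get(
            AMPLIGRAPH_ENV_NAME, os.path.join("~", "ampligraph_" + file_type)
        )
    data_home = os.path.expanduser(data_home)
    os.makedirs(data_home, exist_ok=True)
    return data_home


def _fetch_file(metadata, data_home=None, check_md5hash=False, file_type="datasets"):
    # Download the archive referenced by metadata.url into the cache folder (if it
    # is not already there), unzip it, and return the extracted directory.
    # Checksum verification is omitted in this sketch.
    data_home = _get_data_home(data_home, file_type=file_type)
    name = getattr(metadata, "dataset_name", None) or metadata.pretrained_model_name
    target_dir = os.path.join(data_home, name)
    if not os.path.exists(target_dir):
        zip_path = target_dir + ".zip"
        urllib.request.urlretrieve(metadata.url, zip_path)
        with zipfile.ZipFile(zip_path, "r") as zip_ref:
            zip_ref.extractall(data_home)
        os.remove(zip_path)
    return target_dir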
20 changes: 6 additions & 14 deletions ampligraph/latent_features/models/ScoringBasedEmbeddingModel.py
@@ -1345,9 +1345,7 @@ def get_emb_matrix_test(self, part_number=1, number_of_parts=1):
"""
if number_of_parts == 1:
if self.entities_subset.shape[0] != 0:
out = tf.nn.embedding_lookup(
self.encoding_layer.ent_emb, self.entities_subset
)
out = self.encoding_layer(self.entities_subset, type_of="e")
else:
out = self.encoding_layer.ent_emb
return out, 0, out.shape[0] - 1
@@ -1635,9 +1633,7 @@ def evaluate(
tf.int32, tf.int32, -1, -1, -2
)
if entities_subset is not None:
entities_subset = self.data_indexer.get_indexes(
entities_subset, "e"
)
entities_subset = self.get_indexes(entities_subset, "e")
self.entities_subset = tf.constant(entities_subset, dtype=tf.int32)
self.mapping_dict.insert(
self.entities_subset, tf.range(self.entities_subset.shape[0])
@@ -2253,7 +2249,7 @@ def get_embeddings(self, entities, embedding_type="e"):
"""

if embedding_type == "e":
lookup_concept = self.data_indexer.get_indexes(entities, "e")
lookup_concept = self.get_indexes(entities, "e")
if self.is_partitioned_training:
emb_out = []
with shelve.open(
@@ -2262,11 +2258,9 @@
for ent_id in lookup_concept:
emb_out.append(ent_emb[str(ent_id)])
else:
return tf.nn.embedding_lookup(
self.encoding_layer.ent_emb, lookup_concept
).numpy()
return self.encoding_layer(lookup_concept, type_of="e").numpy()
elif embedding_type == "r":
lookup_concept = self.data_indexer.get_indexes(entities, "r")
lookup_concept = self.get_indexes(entities, "r")
if self.is_partitioned_training:
emb_out = []
with shelve.open(
@@ -2275,9 +2269,7 @@
for rel_id in lookup_concept:
emb_out.append(rel_emb[str(rel_id)])
else:
return tf.nn.embedding_lookup(
self.encoding_layer.rel_emb, lookup_concept
).numpy()
return self.encoding_layer(lookup_concept, type_of="r").numpy()
else:
msg = "Invalid entity type: {}".format(embedding_type)
raise ValueError(msg)
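The edits in this file route embedding lookups through the encoding layer itself (self.encoding_layer(indices, type_of=...)) instead of calling tf.nn.embedding_lookup on its weight matrices directly, and resolve labels through the model-level self.get_indexes wrapper. A toy sketch of the lookup pattern those calls assume (the class and constructor below are illustrative, not AmpliGraph's actual encoding layer):

import tensorflow as tf


class ToyEncodingLayer(tf.keras.layers.Layer):
    def __init__(self, n_entities, n_relations, k):
        super().__init__()
        # One trainable matrix per concept type, mirroring ent_emb / rel_emb above.
        self.ent_emb = self.add_weight(name="ent_emb", shape=(n_entities, k),
                                       initializer="glorot_uniform")
        self.rel_emb = self.add_weight(name="rel_emb", shape=(n_relations, k),
                                       initializer="glorot_uniform")

    def call(self, indices, type_of="e"):
        # Dispatch the lookup to the entity or relation table based on type_of.
        table = self.ent_emb if type_of == "e" else self.rel_emb
        return tf.nn.embedding_lookup(table, indices)


layer = ToyEncodingLayer(n_entities=5, n_relations=2, k=4)
print(layer(tf.constant([0, 3]), type_of="e").shape)  # (2, 4)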
15 changes: 15 additions & 0 deletions ampligraph/pretrained_models/__init__.py
@@ -0,0 +1,15 @@
# Copyright 2019-2023 The AmpliGraph Authors. All Rights Reserved.
#
# This file is Licensed under the Apache License, Version 2.0.
# A copy of the Licence is available in LICENCE, or at:
#
# http://www.apache.org/licenses/LICENSE-2.0
#
"""Support for loading and managing pretrained models."""
from .pretrained_utils import (
load_pretrained_model
)

__all__ = [
"load_pretrained_model"
]
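Since the module is also registered in the package-level __all__ (see the ampligraph/__init__.py change above), the new entry point is importable without triggering any download. A minimal smoke test of the wiring:

import ampligraph
from ampligraph.pretrained_models import load_pretrained_model

# The module is exported at package level and the loader is directly importable.
print("pretrained_models" in ampligraph.__all__)  # True after the __init__.py change above
print(callable(load_pretrained_model))            # True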
103 changes: 103 additions & 0 deletions ampligraph/pretrained_models/pretrained_utils.py
@@ -0,0 +1,103 @@
# Copyright 2019-2023 The AmpliGraph Authors. All Rights Reserved.
#
# This file is Licensed under the Apache License, Version 2.0.
# A copy of the Licence is available in LICENCE, or at:
#
# http://www.apache.org/licenses/LICENSE-2.0
#
import logging
from collections import namedtuple
from ampligraph.utils.file_utils import _fetch_file
from ampligraph.utils.model_utils import restore_model

AMPLIGRAPH_ENV_NAME = "AMPLIGRAPH_DATA_HOME"

ModelMetadata = namedtuple(
"ModelMetadata",
[
"scoring_type",
"dataset",
"pretrained_model_name",
"url",
"model_checksum"
],
defaults=(None, None, None, None, None),
)

logger = logging.getLogger(__name__)
logger.setLevel(logging.DEBUG)


def load_pretrained_model(dataset, scoring_type, data_home=None):
"""
Function to load a pretrained model.
This function allows downloading and loading one of the AmpliGraph pre-trained
models on benchmark datasets.
Parameters
----------
dataset: str
Specify the dataset on which the pre-trained model was built. The possible
value is one of `["fb15k-237", "wn18rr", "yago310", "fb15k", "wn18"]`.
scoring_type: str
The scoring function used when training the model. The possible value is one of
`["TransE", "DistMult", "ComplEx", "HolE", "RotatE"]`.
data_home: str
The folder where the pre-trained model is downloaded and stored. If `None`, the
default AmpliGraph data folder is used (it can be set through the
``AMPLIGRAPH_DATA_HOME`` environment variable).
Returns
-------
model: ScoringBasedEmbeddingModel
The pre-trained :class:`~ampligraph.latent_features.ScoringBasedEmbeddingModel`.
Example
-------
>>> from ampligraph.datasets import load_fb15k_237
>>> from ampligraph.pretrained_models import load_pretrained_model
>>> from ampligraph.evaluation.metrics import mrr_score, hits_at_n_score, mr_score
>>>
>>> dataset = load_fb15k_237()
>>> model = load_pretrained_model(dataset="fb15k-237", scoring_type="ComplEx")
>>> ranks = model.evaluate(
>>> dataset['test'],
>>> corrupt_side='s,o',
>>> use_filter={'train': dataset['train'],
>>> 'valid': dataset['valid'],
>>> 'test': dataset['test']}
>>> )
>>> print(f"mr_score: {mr_score(ranks)}")
>>> print(f"mrr_score: {mrr_score(ranks)}")
>>> print(f"hits@1: {hits_at_n_score(ranks, 1)}")
>>> print(f"hits@10: {hits_at_n_score(ranks, 10)}")
"""
assert dataset in ["fb15k-237", "wn18rr", "yago310", "fb15k", "wn18"], \
"The dataset you specified is not one of the available ones! Try with one of " \
"the following: ['fb15k-237', 'wn18rr', 'yago310', 'fb15k', 'wn18']."
assert scoring_type in ["TransE", "DistMult", "ComplEx", "HolE", "RotatE"], \
"The scoring type you provided is not one of the available ones! Try with one of " \
"the following: ['TransE', 'DistMult', 'ComplEx', 'HolE', 'RotatE']."

model_name = scoring_type.upper()
dataset_name = dataset.upper()
pretrained_model_name = dataset_name + "_" + model_name
filename = pretrained_model_name + ".zip"
url = "https://ampligraph.s3.eu-west-1.amazonaws.com/pretrained-models-v2.0/" + filename

metadata = ModelMetadata(
scoring_type=scoring_type,
dataset=dataset,
pretrained_model_name=pretrained_model_name,
url=url
)

# Download the .zip archive and unzip it into the destination folder, so that
# the model is ready to be restored from there.
model_path = _fetch_file(metadata, data_home, file_type='models')

return restore_model(model_path)
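A possible end-to-end usage sketch combining the loader with the get_embeddings API shown in the ScoringBasedEmbeddingModel diff above (it downloads the FB15K-237 ComplEx archive on first use, so it assumes network access to the S3 bucket referenced in the code):

from ampligraph.datasets import load_fb15k_237
from ampligraph.pretrained_models import load_pretrained_model

# Restore the pretrained ComplEx model and inspect a few entity embeddings.
X = load_fb15k_237()
model = load_pretrained_model(dataset="fb15k-237", scoring_type="ComplEx")
ent_emb = model.get_embeddings(X["train"][:5, 0], embedding_type="e")
print(ent_emb.shape)  # (5, k) where k is the embedding dimension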




