From 7f06626a226bd64207c61ead3f23e6809172e21c Mon Sep 17 00:00:00 2001 From: Balanagireddy M Date: Mon, 26 Jun 2023 21:34:59 -0700 Subject: [PATCH] Added method comments --- download.py | 6 ++++++ train_retriever.py | 8 ++++++++ web_demo.py | 3 ++- 3 files changed, 16 insertions(+), 1 deletion(-) diff --git a/download.py b/download.py index 73e593a..f96a2d5 100644 --- a/download.py +++ b/download.py @@ -8,6 +8,7 @@ sess = requests.Session() def parse_args(): + """ Parse arguments """ parser = argparse.ArgumentParser() parser.add_argument('--link', '-l', type=str, required=True, help='Share link of Tsinghua Cloud') parser.add_argument('--password', '-p', type=str, default='', help='Password of the share link') @@ -16,6 +17,7 @@ def parse_args(): return parser.parse_args() def get_share_key(url): + """ Get share key from share link """ prefix = 'https://cloud.tsinghua.edu.cn/d/' if not url.startswith(prefix): raise ValueError('Share link of Tsinghua Cloud should start with {}'.format(prefix)) @@ -26,6 +28,7 @@ def get_share_key(url): def dfs_search_files(share_key: str, path="/"): + """ DFS search all files in the share link """ global sess filelist = [] print('https://cloud.tsinghua.edu.cn/api/v2.1/share-links/{}/dirents/?path={}'.format(share_key, path)) @@ -40,6 +43,7 @@ def dfs_search_files(share_key: str, path="/"): return filelist def download_single_file(url: str, fname: str): + """ Download single file """ global sess resp = sess.get(url, stream=True) total = int(resp.headers.get('content-length', 0)) @@ -58,6 +62,7 @@ def download_single_file(url: str, fname: str): bar.update(size) def download(url, save_dir): + """ Download all files in the share link """ share_key = get_share_key(url) print("Searching for files to be downloaded...") @@ -101,6 +106,7 @@ def download(url, save_dir): return flag def make_data(sample): + """ Make data for training """ src = "" for ix, ref in enumerate(sample['references']): src += "Reference [%d]: %s\\" % (ix+1, ref) diff --git a/train_retriever.py b/train_retriever.py index 1347e8f..c0b8563 100644 --- a/train_retriever.py +++ b/train_retriever.py @@ -8,7 +8,10 @@ from torch.utils.data.distributed import DistributedSampler class QuestionReferenceDensity(torch.nn.Module): + """ Question Reference Density Model """ + def __init__(self) -> None: + """ Initialize the model """ super().__init__() self.question_encoder = AutoModel.from_pretrained("facebook/contriever-msmarco") self.reference_encoder = AutoModel.from_pretrained("facebook/contriever-msmarco") @@ -17,12 +20,14 @@ def __init__(self) -> None: print("Number of parameter: %.2fM" % (total / 1e6)) def mean_pooling(self, token_embeddings, mask): + """ Mean Pooling """ token_embeddings = token_embeddings.masked_fill(~mask[..., None].bool(), 0.) sentence_embeddings = token_embeddings.sum(dim=1) / mask.sum(dim=1)[..., None] return sentence_embeddings def forward(self, question, pos, neg): + """ Forward """ global args q = self.question_encoder(**question) @@ -41,6 +46,7 @@ def forward(self, question, pos, neg): @staticmethod def loss(l_pos, l_neg): + """ Loss """ return torch.nn.functional.cross_entropy(torch.cat([l_pos, l_neg], dim=1), torch.arange(0, len(l_pos), dtype=torch.long, device=args.device)) @staticmethod @@ -53,6 +59,7 @@ def acc(l_pos, l_neg): class WarmupLinearScheduler(torch.optim.lr_scheduler.LambdaLR): + """ Linear warmup and then linear decay. """ def __init__(self, optimizer, warmup, total, ratio, last_epoch=-1): self.warmup = warmup self.total = total @@ -104,6 +111,7 @@ def save(name): model.reference_encoder.save_pretrained(os.path.join(log_dir, name, "reference_encoder")) def train(max_epoch = 10, eval_step = 200, save_step = 400, print_step = 50): + """ Train the model """ step = 0 for epoch in range(0, max_epoch): print("EPOCH %d"%epoch) diff --git a/web_demo.py b/web_demo.py index b1e2b80..e1a82e6 100644 --- a/web_demo.py +++ b/web_demo.py @@ -39,7 +39,8 @@ """ -def query(query: str): +def query(query: str): + """ Query the model """ refs = [] answer = "Loading ..."