From 9566d8820ce64e0b5dd17934ab4b779695ce79d6 Mon Sep 17 00:00:00 2001 From: Dongxu Li Date: Mon, 19 Sep 2022 17:31:05 +0000 Subject: [PATCH] update app. --- app/classification.py | 12 ++++++------ app/main.py | 2 +- app/multimodal_search.py | 19 ++++++++++++------- 3 files changed, 19 insertions(+), 14 deletions(-) diff --git a/app/classification.py b/app/classification.py index 8be3c940a..e08434d24 100644 --- a/app/classification.py +++ b/app/classification.py @@ -139,10 +139,10 @@ def app(): with torch.no_grad(): image_features = feature_extractor.extract_features( sample, mode="image" - ).image_features[:, 0] + ).image_embeds_proj[:, 0] text_features = feature_extractor.extract_features( sample, mode="text" - ).text_features[:, 0] + ).text_embeds_proj[:, 0] sims = (image_features @ text_features.t())[ 0 ] / feature_extractor.temp @@ -173,10 +173,10 @@ def app(): # with torch.no_grad(): # image_features = feature_extractor.extract_features( # sample, mode="image" - # ).image_features[:, 0] + # ).image_embeds_proj[:, 0] # text_features = feature_extractor.extract_features( # sample, mode="text" - # ).text_features[:, 0] + # ).text_embeds_proj[:, 0] # st.write(image_features.shape) # st.write(text_features.shape) @@ -206,8 +206,8 @@ def app(): with torch.no_grad(): clip_features = model.extract_features(sample) - image_features = clip_features.image_features - text_features = clip_features.text_features + image_features = clip_features.image_embeds_proj + text_features = clip_features.text_embeds_proj image_features /= image_features.norm(dim=-1, keepdim=True) text_features /= text_features.norm(dim=-1, keepdim=True) diff --git a/app/main.py b/app/main.py index 360b5ce14..108c46f8c 100644 --- a/app/main.py +++ b/app/main.py @@ -17,7 +17,7 @@ app = MultiPage() app.add_page("Image Description Generation", caption.app) - # app.add_page("Multimodal Search", ms.app) + app.add_page("Multimodal Search", ms.app) app.add_page("Visual Question Answering", vqa.app) app.add_page("Image Text Matching", itm.app) app.add_page("Text Localization", tl.app) diff --git a/app/multimodal_search.py b/app/multimodal_search.py index 10419fb76..318986f12 100644 --- a/app/multimodal_search.py +++ b/app/multimodal_search.py @@ -32,12 +32,17 @@ allow_output_mutation=True, ) def load_feat(): - path2feat = torch.load( - os.path.join( - os.path.dirname(__file__), - "resources/path2feat_coco_train2014.pth", - ) - ) + from lavis.common.utils import download_url + + dirname = os.path.join(os.path.dirname(__file__), "assets") + filename = "path2feat_coco_train2014.pth" + filepath = os.path.join(dirname, filename) + url = "https://storage.googleapis.com/sfr-vision-language-research/LAVIS/assets/path2feat_coco_train2014.pth" + + if not os.path.exists(filepath): + download_url(url=url, root=dirname, filename="path2feat_coco_train2014.pth") + + path2feat = torch.load(filepath) paths = sorted(path2feat.keys()) all_img_feats = torch.stack([path2feat[k] for k in paths], dim=0).to(device) @@ -98,7 +103,7 @@ def app(): with torch.no_grad(): text_feature = feature_extractor.extract_features( sample, mode="text" - ).text_features[0, 0] + ).text_embeds_proj[0, 0] path2feat, paths, all_img_feats = load_feat() all_img_feats.to(device)