-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathmain.py
26 lines (20 loc) · 784 Bytes
/
main.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
import hydra
from omegaconf import DictConfig
from src.models import BaseModel, RagModel
import warnings
@hydra.main(version_base=None, config_path="configs/base", config_name="parameters")
def main(cfg: DictConfig) -> None:
    """Ask one question to the base model and the RAG model and print both answers.

    The following keys are read from ``cfg``:

    * ``model_name`` -- passed to both ``BaseModel`` and ``RagModel``.
    * ``rag_chunk_size`` / ``rag_chunk_overlap`` -- passed to ``RagModel``.
    * ``question`` (optional) -- the prompt sent to both models. When the key
      is absent the built-in demo question is used, so existing configs keep
      working unchanged.

    Args:
        cfg: Configuration composed by Hydra from ``configs/base`` using the
            ``parameters`` config name (see the decorator arguments).
    """
    # Fallback prompt used when the config does not define ``question``.
    default_question = "What songs is Taylor Swift known for?"

    # Look the model name up once; both models are built from the same value.
    model_name = cfg.get("model_name")

    base_model = BaseModel(model_name=model_name)
    rag_model = RagModel(
        model_name=model_name,
        chunk_size=cfg.get("rag_chunk_size"),
        chunk_overlap=cfg.get("rag_chunk_overlap"),
    )

    question = cfg.get("question", default_question)

    print(f"Question: {question}")
    print(f"Base answer: {base_model.chat(question)}")
    print(f"Answer with RAG: {rag_model.chat(question)}")
if __name__ == "__main__":
    # Run the demo with every warning silenced; the previous filter state is
    # restored automatically when the context manager exits.
    with warnings.catch_warnings():
        warnings.filterwarnings("ignore")
        main()