-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathpipeline.py
52 lines (42 loc) · 1.63 KB
/
pipeline.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
from pathlib import Path
from haystack.nodes import PromptNode, PromptTemplate
def get_prompt(input: str, prompt_template: Path) -> str:
    """Render the template file at *prompt_template*, filling in *input*.

    The file is read as UTF-8 and passed through :meth:`str.format` with a
    single ``input`` keyword, so ``{input}`` placeholders are substituted and
    literal braces in the template must be doubled (``{{`` / ``}}``).

    Args:
        input: Text substituted for ``{input}`` in the template.
        prompt_template: Path to the template text file.

    Returns:
        The rendered prompt string.

    Raises:
        OSError: If the template file cannot be read.
        UnicodeDecodeError: If the file is not valid UTF-8.
        KeyError / IndexError / ValueError: Raised by ``str.format`` when the
            template references a placeholder other than ``input`` or has
            unbalanced braces.
    """
    # Decode explicitly: the default encoding follows the host locale, so the
    # same UTF-8 template could be decoded differently (or fail) elsewhere.
    template_text = prompt_template.read_text(encoding="utf-8")
    return template_text.format(input=input)
class GithubEventPrompter:
    """
    Query an OpenAI model, through a haystack ``PromptNode``, with a prompt
    built from a single GitHub event.

    :meth:`query` supplies values for the ``input``, ``event`` and
    ``username`` placeholders of the template.
    """

    def __init__(
        self,
        openai_key: str,
        prompt_template: Path,
        model_name: str = "text-davinci-003",
        max_length: int = 1000,
    ):
        """
        Args:
            openai_key: OpenAI API key handed to the ``PromptNode``.
            prompt_template: Path to the prompt template text file. It is
                read once, as UTF-8, at construction time.
            model_name: Passed to ``PromptNode`` as ``model_name_or_path``.
                The default keeps the previously hard-coded model.
            max_length: Passed to ``PromptNode`` as ``max_length``.
                The default keeps the previously hard-coded value.
        """
        self._prompt_node = PromptNode(
            model_name_or_path=model_name, api_key=openai_key, max_length=max_length
        )
        # Decode explicitly so the template reads the same regardless of the
        # host's locale encoding.
        self._github_template = PromptTemplate(
            name="github-events",
            prompt_text=prompt_template.read_text(encoding="utf-8"),
        )

    def query(self, event: str, input: str, username: str) -> str:
        """Render the GitHub-event template and send it to the model.

        Args:
            event: Value for the template's ``event`` placeholder.
            input: Value for the template's ``input`` placeholder.
            username: Value for the template's ``username`` placeholder.

        Returns:
            Whatever ``PromptNode.prompt`` returns, passed through unchanged.
        """
        # NOTE(review): in haystack 1.x ``PromptNode.prompt`` returns a list of
        # outputs rather than a single str. Confirm against the pinned version
        # and fix the ``-> str`` annotation (or index the result) accordingly.
        return self._prompt_node.prompt(
            prompt_template=self._github_template,
            input=input,
            event=event,
            username=username,
        )
class SummaryPrompter:
    """
    Query an OpenAI model, through a haystack ``PromptNode``, for a summary
    of a user's GitHub events.

    :meth:`query` supplies values for the ``events`` and ``username``
    placeholders of the template.
    """

    def __init__(
        self,
        openai_key: str,
        prompt_template: Path,
        model_name: str = "text-davinci-003",
        max_length: int = 1000,
    ):
        """
        Args:
            openai_key: OpenAI API key handed to the ``PromptNode``.
            prompt_template: Path to the summary prompt template text file.
                It is read once, as UTF-8, at construction time.
            model_name: Passed to ``PromptNode`` as ``model_name_or_path``.
                The default keeps the previously hard-coded model.
            max_length: Passed to ``PromptNode`` as ``max_length``.
                The default keeps the previously hard-coded value.
        """
        self._prompt_node = PromptNode(
            model_name_or_path=model_name, api_key=openai_key, max_length=max_length
        )
        # Decode explicitly so the template reads the same regardless of the
        # host's locale encoding.
        self._summary_template = PromptTemplate(
            name="github-summary",
            prompt_text=prompt_template.read_text(encoding="utf-8"),
        )

    def query(self, events: str, username: str) -> str:
        """Render the summary template and send it to the model.

        Args:
            events: Value for the template's ``events`` placeholder
                (presumably pre-rendered event text; confirm with callers).
            username: Value for the template's ``username`` placeholder.

        Returns:
            Whatever ``PromptNode.prompt`` returns, passed through unchanged.
        """
        # NOTE(review): in haystack 1.x ``PromptNode.prompt`` returns a list of
        # outputs rather than a single str. Confirm against the pinned version
        # and fix the ``-> str`` annotation (or index the result) accordingly.
        return self._prompt_node.prompt(
            prompt_template=self._summary_template, events=events, username=username
        )