-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathDocumentAnalyzer.py
103 lines (86 loc) · 3.74 KB
/
DocumentAnalyzer.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
# import libraries
import os
from azure.ai.documentintelligence import DocumentIntelligenceClient
from azure.ai.documentintelligence.models import (AnalyzeDocumentRequest,
AnalyzeResult)
from azure.core.credentials import AzureKeyCredential
class DocumentAnalyzer:
    """Wrapper around an Azure Document Intelligence client using the ``prebuilt-layout`` model.

    Settings are read from environment variables when the object is
    constructed (see ``_get_env``). Documents can be analyzed from a blob in
    one of two storage containers (``analyze_blob``) or from a local file
    (``analyze_document``).
    """

    def __init__(self):
        """Load settings from the environment and create the client.

        Raises:
            Exception: if a required environment variable is missing.
        """
        self._get_env()
        self.client = DocumentIntelligenceClient(
            endpoint=self.endpoint, credential=AzureKeyCredential(self.key)
        )

    def _get_env(self):
        """Copy required environment variables onto ``self``.

        Sets ``key``, ``endpoint``, ``kb_container_name``,
        ``rag_container_name`` and ``storage_account`` from DOCUMENTAI_KEY,
        DOCUMENTAI_ENDPOINT, KB_CONTAINER_NAME, RAG_CONTAINER_NAME and
        STORAGE_ACCOUNT respectively.

        Raises:
            Exception: if any variable is unset. The original ``KeyError``,
                which names the missing variable, is chained as ``__cause__``.
        """
        try:
            self.key = os.environ["DOCUMENTAI_KEY"]
            self.endpoint = os.environ["DOCUMENTAI_ENDPOINT"]
            self.kb_container_name = os.environ["KB_CONTAINER_NAME"]
            self.rag_container_name = os.environ["RAG_CONTAINER_NAME"]
            self.storage_account = os.environ["STORAGE_ACCOUNT"]
        except KeyError as err:
            raise Exception("Env Variables Not Found!") from err

    def _get_words(self, page, line):
        """Return the words on *page* whose span lies fully inside one of *line*'s spans.

        A page with no ``words`` (``None``) yields an empty list.
        """
        return [word for word in page.words or () if self._in_span(word, line.spans)]

    def _in_span(self, word, spans):
        """Return True if *word*'s span is contained in at least one span in *spans*.

        Containment is inclusive at both ends:
        ``span.offset <= word start`` and ``word end <= span end``.
        """
        start = word.span.offset
        end = start + word.span.length
        return any(
            span.offset <= start and end <= span.offset + span.length
            for span in spans
        )

    def _get_blob_url(self, blob_key, container_name):
        """Build the HTTPS URL of *blob_key* in *container_name* on the configured account.

        Surrounding whitespace is stripped from *blob_key*.
        NOTE(review): the key is inserted verbatim, not percent-encoded.
        Names containing characters such as spaces, '#' or '?' may need
        escaping by the caller; confirm against actual blob names.
        """
        blob_key = blob_key.strip()
        return f"https://{self.storage_account}.blob.core.windows.net/{container_name}/{blob_key}"

    def _analyze(self, request):
        """Submit *request* to the ``prebuilt-layout`` model and return its ``AnalyzeResult``.

        Also prints whether the result reports handwritten content.
        """
        poller = self.client.begin_analyze_document("prebuilt-layout", request)
        result: AnalyzeResult = poller.result()
        self._report_handwriting(result)
        return result

    @staticmethod
    def _report_handwriting(result):
        """Print whether any style in *result* is marked as handwritten."""
        if result.styles and any(style.is_handwritten for style in result.styles):
            print("Document contains handwritten content")
        else:
            print("Document does not contain handwritten content")

    @staticmethod
    def _iter_lines(result):
        """Yield every detected line of *result*, page by page; pages without lines are skipped."""
        for page in result.pages or ():
            yield from page.lines or ()

    def analyze_blob(self, blob_key: str, key: str) -> str:
        """Analyze a blob with ``prebuilt-layout`` and return its text.

        Args:
            blob_key: Blob name inside the selected container. Surrounding
                whitespace is stripped when building the URL.
            key: ``"rag"`` selects RAG_CONTAINER_NAME; ``"kb"`` selects
                KB_CONTAINER_NAME.

        Returns:
            The content of every detected line, in page order, each followed
            by a single space. Non-empty output therefore ends with a space.

        Raises:
            Exception: if ``blob_key`` is empty/blank, or ``key`` is not
                ``"rag"`` or ``"kb"``.
        """
        # Blank keys are rejected too: _get_blob_url strips whitespace, so "  "
        # would otherwise point at the container URL instead of a blob.
        if not blob_key or not blob_key.strip():
            raise Exception("Blob Key Cannot be ''.")
        if key not in ("rag", "kb"):
            raise Exception("Choose key 'rag' or 'kb'.")
        container_name = (
            self.rag_container_name if key == "rag" else self.kb_container_name
        )
        result = self._analyze(
            AnalyzeDocumentRequest(
                url_source=self._get_blob_url(blob_key, container_name)
            )
        )
        # join() keeps text assembly linear in document size (repeated += is quadratic).
        return "".join(line.content + " " for line in self._iter_lines(result))

    def analyze_document(self, file_path: str) -> None:
        """Analyze a local file with ``prebuilt-layout`` and print each detected line.

        The whole file is read into memory and sent as bytes.

        Args:
            file_path: Path of the file to analyze.

        Raises:
            Exception: if ``file_path`` is empty or ``None``.
            OSError: if the file cannot be opened or read.
        """
        if not file_path:
            raise Exception("Invalid file path")
        with open(file_path, "rb") as f:
            bytes_source = f.read()
        result = self._analyze(AnalyzeDocumentRequest(bytes_source=bytes_source))
        for line in self._iter_lines(result):
            print(line.content)