Skip to content

Commit

Permalink
Fix colliding cuda context
Browse files Browse the repository at this point in the history
- Extract face_recognition and dlib into its own context
- Prevents initilizing the cuda context multiple times from dlib and pytorch
  • Loading branch information
derneuere committed Oct 31, 2023
1 parent 74f228d commit 695b820
Show file tree
Hide file tree
Showing 4 changed files with 90 additions and 13 deletions.
25 changes: 25 additions & 0 deletions api/face_recognition.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
import requests
import numpy as np


def get_face_encodings(image_path, known_face_locations):
json = {
"source": image_path,
"face_locations": known_face_locations,
}
face_encoding = requests.post(
"http://localhost:8005/face-encodings", json=json
).json()

face_encodings_list = face_encoding["encodings"]
face_encodings = [np.array(enc) for enc in face_encodings_list]

return face_encodings


def get_face_locations(image_path, model="hog"):
json = {"source": image_path, "model": model}
face_locations = requests.post(
"http://localhost:8005/face-locations", json=json
).json()
return face_locations["face_locations"]
23 changes: 10 additions & 13 deletions api/models/photo.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
from django.db.models import Q
from django.db.utils import IntegrityError
from api.im2txt.sample import Im2txt
from api.face_recognition import get_face_encodings, get_face_locations

import api.date_time_extractor as date_time_extractor
import api.models
Expand Down Expand Up @@ -702,20 +703,18 @@ def _extract_faces(self, second_try=False):
).count()
image_path = self.image_hash + "_" + str(idx_face) + ".jpg"

import face_recognition

face_encodings = face_recognition.face_encodings(
image, known_face_locations=[(top, right, bottom, left)]
)

face_encoding = get_face_encodings(
self.thumbnail_big.path,
known_face_locations=[(top, right, bottom, left)],
)[0]
face = api.models.face.Face(
image_path=image_path,
photo=self,
location_top=top,
location_right=right,
location_bottom=bottom,
location_left=left,
encoding=face_encodings[0].tobytes().hex(),
encoding=face_encoding.tobytes().hex(),
person=person,
cluster=unknown_cluster,
)
Expand All @@ -729,16 +728,14 @@ def _extract_faces(self, second_try=False):
logger.debug(f"Created face {face} from {self.main_file.path}")
return

import face_recognition

try:
image = np.array(PIL.Image.open(self.thumbnail_big.path))

face_locations = []
# Create
try:
face_locations = face_recognition.face_locations(
image, model=self.owner.face_recognition_model.lower()
face_locations = get_face_locations(
self.thumbnail_big.path, model=self.owner.face_recognition_model.lower()
)
except Exception as e:
logger.info(
Expand All @@ -747,8 +744,8 @@ def _extract_faces(self, second_try=False):
logger.info(e)

if len(face_locations) > 0:
face_encodings = face_recognition.face_encodings(
image, known_face_locations=face_locations
face_encodings = get_face_encodings(
self.thumbnail_big.path, known_face_locations=face_locations
)
for idx_face, face in enumerate(zip(face_encodings, face_locations)):
face_encoding = face[0]
Expand Down
Empty file.
55 changes: 55 additions & 0 deletions service/face_recognition/main.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,55 @@
import gevent
from flask import Flask, request
from gevent.pywsgi import WSGIServer
import face_recognition
import numpy as np
import PIL

app = Flask(__name__)


def log(message):
print("face_recognition: {}".format(message))


@app.route("/face-encodings", methods=["POST"])
def create_face_encodings():
try:
data = request.get_json()
source = data["source"]
face_locations = data["face_locations"]
except Exception:
return "", 400

image = np.array(PIL.Image.open(source))
face_encodings = face_recognition.face_encodings(
image,
known_face_locations=face_locations,
)
# Convert NumPy arrays to Python lists
face_encodings_list = [enc.tolist() for enc in face_encodings]
# Log number of face encodings
log(f"created face_encodings={len(face_encodings_list)}")
return {"encodings": face_encodings_list}, 201


@app.route("/face-locations", methods=["POST"])
def create_face_locations():
try:
data = request.get_json()
source = data["source"]
model = data["model"]
except Exception:
return "", 400

image = np.array(PIL.Image.open(source))
face_locations = face_recognition.face_locations(image, model=model)
log(f"created face_location={face_locations}")
return {"face_locations": face_locations}, 201


if __name__ == "__main__":
log("service starting")
server = WSGIServer(("0.0.0.0", 8005), app)
server_thread = gevent.spawn(server.serve_forever)
gevent.joinall([server_thread])

0 comments on commit 695b820

Please sign in to comment.