diff --git a/piicatcher_spacy/detectors/spacy.py b/piicatcher_spacy/detectors/spacy.py index abe205e..4e01664 100644 --- a/piicatcher_spacy/detectors/spacy.py +++ b/piicatcher_spacy/detectors/spacy.py @@ -15,32 +15,35 @@ @register_detector class SpacyDetector(DatumDetector): pii_cls_map = { - 'FAC': Address, # Buildings, airports, highways, bridges, etc. - 'GPE': Address, # Countries, cities, states. - 'LOC': Address, # Non-GPE locations, mountain ranges, bodies of water. - 'PERSON': Person, # People, including fictional. - 'PER': Person, # Bug in french model - 'DATE': BirthDate, # Dates within the period 18 to 100 years ago. + "FAC": Address, # Buildings, airports, highways, bridges, etc. + "GPE": Address, # Countries, cities, states. + "LOC": Address, # Non-GPE locations, mountain ranges, bodies of water. + "PERSON": Person, # People, including fictional. + "PER": Person, # Bug in french model + "DATE": BirthDate, # Dates within the period 18 to 100 years ago. } - name = 'DatumSpacyDetector' + name = "DatumSpacyDetector" def __init__(self, model: str = "en_core_web_md"): super(SpacyDetector, self).__init__() # Fixes a warning message from transformers that is pulled in via spacy import os - os.environ['TOKENIZERS_PARALLELISM'] = 'false' + + os.environ["TOKENIZERS_PARALLELISM"] = "false" self.check_spacy_version() if not self.check_spacy_model(model): - raise ValueError("Unable to find spacy model '{}'. Is your language supported? " - "Check the list of models available here: " - "https://github.com/explosion/spacy-models ".format(self.model)) + raise ValueError( + "Unable to find spacy model '{}'. Is your language supported? " + "Check the list of models available here: " + "https://github.com/explosion/spacy-models ".format(self.model) + ) self.nlp = spacy.load(model) # If the model doesn't support named entity recognition - if 'ner' not in [step[0] for step in self.nlp.pipeline]: + if "ner" not in [step[0] for step in self.nlp.pipeline]: raise ValueError( "The spacy model '{}' doesn't support named entity recognition, " "please choose another model.".format(self.model) @@ -49,16 +52,28 @@ def __init__(self, model: str = "en_core_web_md"): @staticmethod def check_spacy_version() -> bool: """Ensure that the version of spaCy is v3.""" - spacy_version = spacy.__version__ # spacy_info.get('spaCy version', spacy_info.get('spacy_version', None)) + spacy_version = ( + spacy.__version__ + ) # spacy_info.get('spaCy version', spacy_info.get('spacy_version', None)) if spacy_version is None: - raise ImportError('Spacy v3 needs to be installed. Unable to detect spacy version.') + raise ImportError( + "Spacy v3 needs to be installed. Unable to detect spacy version." + ) try: - spacy_major = int(spacy_version.split('.')[0]) + spacy_major = int(spacy_version.split(".")[0]) except Exception: - raise ImportError('Spacy v3 needs to be installed. Spacy version {} is unknown.'.format(spacy_version)) + raise ImportError( + "Spacy v3 needs to be installed. Spacy version {} is unknown.".format( + spacy_version + ) + ) if spacy_major != 3: - raise ImportError('Spacy v3 needs to be installed. Detected version {}.'.format(spacy_version)) + raise ImportError( + "Spacy v3 needs to be installed. Detected version {}.".format( + spacy_version + ) + ) return True @@ -67,10 +82,12 @@ def check_spacy_model(model) -> bool: """Ensure that the spaCy model is installed.""" spacy_info = spacy.info() if isinstance(spacy_info, str): - raise ValueError('Unable to detect spacy models.') - models = list(spacy_info.get('pipelines', spacy_info.get('models', None)).keys()) + raise ValueError("Unable to detect spacy models.") + models = list( + spacy_info.get("pipelines", spacy_info.get("models", None)).keys() + ) if models is None: - raise ValueError('Unable to detect spacy models.') + raise ValueError("Unable to detect spacy models.") if model not in models: LOGGER.info("Downloading spacy model {}".format(model)) @@ -83,6 +100,12 @@ def check_spacy_model(model) -> bool: return model in models def detect(self, column: CatColumn, datum: str) -> Optional[PiiType]: + # if datum can be a list, like in postgres arrays, recurse, otherwise we crash in language.py from spacy. + if isinstance(datum, list): + for d in datum: + result = self.detect(column, d) + if result: + return result doc = self.nlp(datum) for ent in doc.ents: LOGGER.debug("Found %s", ent.label_)