Skip to content

Commit

Permalink
Merge pull request #12 from tmplink/v1.6
Browse files Browse the repository at this point in the history
Improve performance
  • Loading branch information
tmplink authored Nov 30, 2024
2 parents bce6666 + e369e5f commit 920684e
Show file tree
Hide file tree
Showing 5 changed files with 107 additions and 36 deletions.
2 changes: 1 addition & 1 deletion build.sh
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ set -e

# 默认配置值
IMAGE_NAME="vxlink/nsfw_detector"
VERSION="v1.5"
VERSION="v1.6"
PUSH="false"
CACHE_DIR="${HOME}/.docker/nsfw_detector_cache"
CACHE_FROM=""
Expand Down
3 changes: 2 additions & 1 deletion config.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,8 @@
logging.basicConfig(
level=logging.INFO,
format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
datefmt='%Y-%m-%d %H:%M:%S'
datefmt='%Y-%m-%d %H:%M:%S',
encoding='utf-8'
)

# 配置 rarfile
Expand Down
2 changes: 1 addition & 1 deletion dockerfile
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ RUN apt-get update && apt-get install -y \
python3-pip \
curl \
unrar \
p7zip-full \
p7zip-full p7zip-rar \
python3-opencv \
libgl1-mesa-glx \
libglib2.0-0 \
Expand Down
30 changes: 24 additions & 6 deletions processors.py
Original file line number Diff line number Diff line change
Expand Up @@ -326,6 +326,12 @@ def process_archive(filepath, filename, depth=0, max_depth=100):
"""
temp_dir = None
try:
# 确保 filename 正确编码
encoded_filename = filename # 保存原始文件名
if isinstance(filename, bytes):
with ArchiveHandler(filepath) as temp_handler:
encoded_filename = temp_handler.__encode_filename(filename)

# 检查递归深度
if depth > max_depth:
logger.warning(f"达到最大递归深度 {max_depth}")
Expand All @@ -336,7 +342,7 @@ def process_archive(filepath, filename, depth=0, max_depth=100):

# 创建临时目录
temp_dir = tempfile.mkdtemp()
logger.info(f"处理压缩文件: {filename}, 深度: {depth}, 临时文件路径: {filepath}")
logger.info(f"处理压缩文件: {encoded_filename}, 深度: {depth}, 临时文件路径: {filepath}")

# 检查文件大小
file_size = os.path.getsize(filepath)
Expand All @@ -349,11 +355,16 @@ def process_archive(filepath, filename, depth=0, max_depth=100):
with ArchiveHandler(filepath) as handler:
# 获取文件列表
files = handler.list_files()

# 分离可直接处理的文件和嵌套压缩包
processable_files = []
nested_archives = []

for f in files:
# 确保文件名已正确编码
if isinstance(f, bytes):
f = handler.__encode_filename(f)

ext = os.path.splitext(f)[1].lower()
if ext in ARCHIVE_EXTENSIONS:
nested_archives.append(f)
Expand All @@ -374,6 +385,10 @@ def process_archive(filepath, filename, depth=0, max_depth=100):

for inner_filename in sorted_files:
try:
# 确保内部文件名已正确编码
if isinstance(inner_filename, bytes):
inner_filename = handler.__encode_filename(inner_filename)

content = handler.extract_file(inner_filename)
ext = os.path.splitext(inner_filename)[1].lower()

Expand Down Expand Up @@ -424,16 +439,20 @@ def process_archive(filepath, filename, depth=0, max_depth=100):
continue

if matched_content:
logger.info(f"在压缩包 {filename} 中发现匹配内容: {matched_content['matched_file']}")
logger.info(f"在压缩包 {encoded_filename} 中发现匹配内容: {matched_content['matched_file']}")
return {
'status': 'success',
'filename': filename,
'filename': encoded_filename,
'result': matched_content['result']
}

# 处理嵌套的压缩包
for nested_archive in nested_archives:
try:
# 确保嵌套压缩包文件名已正确编码
if isinstance(nested_archive, bytes):
nested_archive = handler.__encode_filename(nested_archive)

temp_nested = tempfile.NamedTemporaryFile(delete=False)
content = handler.extract_file(nested_archive)

Expand Down Expand Up @@ -465,10 +484,10 @@ def process_archive(filepath, filename, depth=0, max_depth=100):

# 如果所有文件都处理完还没有返回,返回最后一个结果
if last_result:
logger.info(f"处理压缩包 {filename} 完成,最后处理的文件: {last_result['matched_file']}")
logger.info(f"处理压缩包 {encoded_filename} 完成,最后处理的文件: {last_result['matched_file']}")
return {
'status': 'success',
'filename': filename,
'filename': encoded_filename,
'result': last_result['result']
}

Expand All @@ -483,7 +502,6 @@ def process_archive(filepath, filename, depth=0, max_depth=100):
'status': 'error',
'message': str(e)
}, 500

finally:
if temp_dir and os.path.exists(temp_dir):
try:
Expand Down
106 changes: 79 additions & 27 deletions utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,17 +62,63 @@ def _generate_temp_filename(self, original_filename):
ext = Path(original_filename).suffix
return f"{str(uuid.uuid4())}{ext}"

def __encode_filename(self, filename):
"""文件名编码处理"""
if isinstance(filename, str):
return filename

try:
decoded = filename.decode('utf-8')
return decoded
except UnicodeDecodeError as e:
return filename.decode('utf-8', errors='replace')

def _extract_rar_files(self, files_to_extract):
"""只解压需要处理的RAR文件到临时目录"""
if not self.temp_dir:
self.temp_dir = tempfile.mkdtemp()

try:
for filename in files_to_extract:
# 使用unrar命令行工具解压特定文件
extract_cmd = ['unrar', 'e', '-y', self.filepath, filename, self.temp_dir]

result = subprocess.run(
extract_cmd,
stdout=subprocess.PIPE,
stderr=subprocess.PIPE,
encoding='utf-8'
)

if result.returncode != 0:
logger.warning(f"解压RAR文件 {filename} 失败: {result.stderr}")
continue

# 获取解压后的文件路径
original_path = os.path.join(self.temp_dir, os.path.basename(filename))
if os.path.exists(original_path):
new_filename = self._generate_temp_filename(filename)
new_path = os.path.join(self.temp_dir, new_filename)
try:
os.link(original_path, new_path)
except OSError:
shutil.copy2(original_path, new_path)
self._extracted_files[filename] = new_path
os.unlink(original_path)

except Exception as e:
logger.error(f"RAR文件解压失败: {str(e)}")
raise

def _extract_7z_files(self, files_to_extract):
"""只解压需要处理的文件到临时目录"""
"""只解压需要处理的7z文件到临时目录"""
if not self.temp_dir:
self.temp_dir = tempfile.mkdtemp()

try:
for filename in files_to_extract:
# 为每个文件准备解压命令
extract_cmd = ['7z', 'e', self.filepath, '-o' + self.temp_dir, filename, '-y']

# 执行解压
result = subprocess.run(
extract_cmd,
stdout=subprocess.PIPE,
Expand All @@ -84,25 +130,19 @@ def _extract_7z_files(self, files_to_extract):
logger.warning(f"解压文件 {filename} 失败: {result.stderr}")
continue

# 获取解压后的文件路径
original_path = os.path.join(self.temp_dir, os.path.basename(filename))
if os.path.exists(original_path):
new_filename = self._generate_temp_filename(filename)
new_path = os.path.join(self.temp_dir, new_filename)
try:
# 尝试创建硬链接
os.link(original_path, new_path)
except OSError:
# 如果硬链接失败,则复制文件
shutil.copy2(original_path, new_path)
self._extracted_files[filename] = new_path
# 删除原始文件
os.unlink(original_path)

except Exception as e:
logger.error(f"7z文件解压失败: {str(e)}")
if self.temp_dir and os.path.exists(self.temp_dir):
shutil.rmtree(self.temp_dir)
raise

def __enter__(self):
Expand All @@ -112,6 +152,7 @@ def __enter__(self):
if self.archive.testzip() is not None:
raise zipfile.BadZipFile("ZIP文件损坏")
elif self.type == 'rar':
# 只打开文件以获取文件列表,不进行解压
self.archive = rarfile.RarFile(self.filepath)
if self.archive.needs_password():
raise Exception("RAR文件有密码保护")
Expand All @@ -134,10 +175,16 @@ def __exit__(self, exc_type, exc_val, exc_tb):

def list_files(self):
try:
files = []
if self.type == 'zip':
files = [f for f in self.archive.namelist() if not f.endswith('/')]
files = [self.__encode_filename(f) for f in self.archive.namelist()
if not f.endswith('/')]
elif self.type == 'rar':
files = [f.filename for f in self.archive.infolist() if not f.is_dir()]
files = [self.__encode_filename(f.filename) for f in self.archive.infolist()
if not f.is_dir()]
processable_files = [f for f in files if can_process_file(f)]
if processable_files:
self._extract_rar_files(processable_files)
elif self.type == '7z':
result = subprocess.run(
['7z', 'l', '-slt', self.filepath],
Expand Down Expand Up @@ -165,7 +212,6 @@ def list_files(self):
current_file = None
is_directory = False

# 对于7z文件,只解压需要处理的文件
processable_files = [f for f in files if can_process_file(f)]
if processable_files:
self._extract_7z_files(processable_files)
Expand All @@ -176,8 +222,6 @@ def list_files(self):
files = [base_name[:-3]]
else:
files = ['content']
else:
files = []

processable = [f for f in files if can_process_file(f)]
logger.info(f"找到 {len(processable)} 个可处理文件")
Expand All @@ -192,11 +236,12 @@ def get_file_info(self, filename):
if self.type == 'zip':
return self.archive.getinfo(filename).file_size
elif self.type == 'rar':
if filename in self._extracted_files:
return os.path.getsize(self._extracted_files[filename])
return self.archive.getinfo(filename).file_size
elif self.type == '7z':
if filename in self._extracted_files:
return os.path.getsize(self._extracted_files[filename])
# 如果文件未解压,运行7z l命令获取文件大小
result = subprocess.run(
['7z', 'l', '-slt', self.filepath, filename],
stdout=subprocess.PIPE,
Expand All @@ -221,24 +266,31 @@ def get_file_info(self, filename):

def extract_file(self, filename):
try:
base_name = os.path.basename(filename)
logger.info(f"正在检测文件: {base_name}")
encoded_filename = self.__encode_filename(filename)
logger.info(f"正在检测文件: {encoded_filename}")

if self.type == 'zip':
return self.archive.read(filename)
return self.archive.read(filename) # 使用原始 filename
elif self.type == 'rar':
return self.archive.read(filename)
if encoded_filename in self._extracted_files:
with open(self._extracted_files[encoded_filename], 'rb') as f:
return f.read()
if can_process_file(encoded_filename):
self._extract_rar_files([filename]) # 使用原始 filename
if encoded_filename in self._extracted_files:
with open(self._extracted_files[encoded_filename], 'rb') as f:
return f.read()
return self.archive.read(filename) # 使用原始 filename
elif self.type == '7z':
if filename in self._extracted_files:
with open(self._extracted_files[filename], 'rb') as f:
if encoded_filename in self._extracted_files:
with open(self._extracted_files[encoded_filename], 'rb') as f:
return f.read()
# 如果文件还未解压,则进行解压
if can_process_file(filename):
self._extract_7z_files([filename])
if filename in self._extracted_files:
with open(self._extracted_files[filename], 'rb') as f:
if can_process_file(encoded_filename):
self._extract_7z_files([filename]) # 使用原始 filename
if encoded_filename in self._extracted_files:
with open(self._extracted_files[encoded_filename], 'rb') as f:
return f.read()
raise Exception(f"文件 {filename} 未找到在提取列表中")
raise Exception(f"文件 {encoded_filename} 未找到在提取列表中")
elif self.type == 'gz':
return self.archive.read()
raise Exception("不支持的压缩格式")
Expand Down

0 comments on commit 920684e

Please sign in to comment.