Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Torchserve #3

Open
wants to merge 4 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 4 additions & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -8,4 +8,7 @@ scripts/_nohup
*.log
*.out
*.pkl
/exp_logs/*
/exp_logs/*
model-store/
logs/
configs/*.json
142 changes: 142 additions & 0 deletions __civil_handler.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,142 @@
import torch
from ts.torch_handler.base_handler import BaseHandler

import os
import io

import zipfile
import joblib
import numpy as np

# Name of the archive (shipped inside the .mar model dir) that contains the
# model source code; it is extracted into model_dir during initialize().
MODELSZIP = "HARS.zip"


class CivilsHandler(BaseHandler):
    """TorchServe handler for HAR time-series models (TS_CoT / TS_SEA / TS2Vec).

    On initialization it extracts the bundled model source archive, loads the
    backbone encoder plus the evaluation-protocol head ("mlp", "linear",
    "knn", or other joblib-serialized protocols), and then serves requests
    whose bodies are ``torch.save``-serialized tensors.
    """

    def initialize(self, context):
        """Load model artifacts and prepare the inference environment.

        Args:
            context: TorchServe context; ``context.system_properties``
                supplies ``model_dir`` and ``gpu_id``.
        """
        properties = context.system_properties
        model_dir = properties.get("model_dir")
        gpu_id = properties.get("gpu_id")
        self.context = context
        # self.logger = get_logger()

        # Fix: the original assigned self.device twice (a CPU/CUDA guess
        # immediately overwritten by the gpu_id-aware value) and would build
        # "cuda:None" if gpu_id was missing while CUDA was available.
        if torch.cuda.is_available():
            device_index = gpu_id if gpu_id is not None else 0
            self.map_location = "cuda"
            self.device = torch.device(f"cuda:{device_index}")
            self.use_gpu = True
        else:
            self.map_location = "cpu"
            self.device = torch.device("cpu")
            self.use_gpu = False

        # Extract the model source code archive (re-extraction just
        # overwrites the same files, so this is safe to repeat).
        print(model_dir)
        with zipfile.ZipFile(os.path.join(model_dir, MODELSZIP), "r") as zip_ref:
            zip_ref.extractall(model_dir)

        # Deferred imports: these modules only exist after extraction above.
        from configs.LoadConfig import load_json_config
        from tasks import eval_protocols

        # Load model configuration and environment initialization.
        args = load_json_config(f"{model_dir}/config.json")
        self.args = args

        self.model = self.build_model(args)
        self.model.load(f"{model_dir}/model.pkl")

        if args.eval_protocol == "mlp":
            # num_cluster arrives as a list in the config; the MLP head
            # takes the first entry as the class count.
            args.num_cluster = args.num_cluster[0]
            self.task_proto = eval_protocols.Linear_probe(
                args.feat_dim, args.num_cluster
            )
            self.task_proto.load_state_dict(
                torch.load(f"{model_dir}/model_{args.eval_protocol}.pkl")["mlp_model"]
            )
        else:
            # Non-MLP protocols (e.g. linear / knn) are joblib-serialized
            # sklearn-style estimators.
            self.task_proto = joblib.load(f"{model_dir}/{args.eval_protocol}.joblib")

        import infer_datautils

        # Map each backbone type to its input-formatting function.
        trans_data_dict = {
            "TS_CoT": infer_datautils.format_ts2cot,
            "TS_SEA": infer_datautils.format_ts2sea,
            "TS2Vec": infer_datautils.format_ts2vec,
        }
        self.data_process_func = trans_data_dict[self.args.backbone_type]

        # Fix: mark the handler ready only after every artifact loaded; the
        # original set this flag first, before anything was actually loaded.
        self.initialized = True

    def build_model(self, args):
        """Construct the backbone model selected by ``args.backbone_type``.

        Args:
            args: loaded JSON config (attribute-style access).

        Returns:
            An un-loaded backbone instance (weights are loaded by the caller).

        Raises:
            ValueError: if ``args.backbone_type`` is not a known backbone.
        """
        # Deferred imports: the algorithm package only exists after the
        # source archive has been extracted in initialize().
        from algorithm.ts_sea import TS_SEA
        from algorithm.ts_cot import TS_CoT
        from algorithm.ts2vec import TS2Vec

        if args.backbone_type == "TS_CoT":
            return TS_CoT(
                input_dims=args.in_dims,
                output_dims=args.repr_dims,
                device=self.device,
                args=args,
            )
        elif args.backbone_type == "TS_SEA":
            return TS_SEA(
                input_dims=args.in_dims,
                output_dims=args.repr_dims,
                device=self.device,
                args=args,
            )
        elif args.backbone_type == "TS2Vec":
            config = dict(
                batch_size=args.batch_size,
                lr=args.lr,
                output_dims=args.repr_dims,
                max_train_length=args.max_train_length,
            )
            return TS2Vec(input_dims=args.in_dims, device=self.device, **config)
        else:
            raise ValueError("Unknown Backbone Type")

    def preprocess(self, req):
        """Deserialize the request body into a model-ready tensor.

        Args:
            req: list of TorchServe request dicts; the first entry's
                ``body`` holds a ``torch.save``-serialized tensor.

        Returns:
            The tensor transformed by the backbone-specific formatter.
        """
        print("data preprocessing")
        # print(req)
        print(len(req))
        binary_data = req[0].get("body")

        # SECURITY: torch.load uses pickle under the hood, so deserializing
        # an untrusted request body can execute arbitrary code. Only expose
        # this endpoint to trusted clients/networks.
        buffer = io.BytesIO(binary_data)
        data = torch.load(buffer)

        print("接收到的张量形状:", data.shape)

        return self.data_process_func(data)

    def inference(self, processed_data):
        """Encode the input and apply the evaluation-protocol head.

        Returns:
            Tuple ``(scores, labels)`` as plain Python lists, where
            ``scores`` are per-class probabilities/scores and ``labels``
            are the argmax predictions.
        """
        data_repr = self.model.encode_online(processed_data)

        if self.args.eval_protocol == "mlp":
            # encode_online appears to return a numpy array here — the MLP
            # head needs a torch tensor. TODO confirm against encode_online.
            data_repr = torch.from_numpy(data_repr)
            test_pred = self.task_proto(data_repr)
            test_pred = torch.nn.functional.softmax(test_pred, dim=1)
            y_score = torch.argmax(test_pred, dim=1)
            return test_pred.tolist(), y_score.tolist()
        elif self.args.eval_protocol in ["linear", "knn"]:
            y_score = self.task_proto.predict_proba(data_repr)
        else:
            y_score = self.task_proto.score(data_repr)

        return y_score.tolist(), y_score.argmax(axis=1).tolist()

    def postprocess(self, inference_output):
        """Wrap (scores, labels) into the single-element response list
        TorchServe expects for a one-request batch."""
        y_score, y_label = inference_output
        return [{"predictions": y_label, "scores": y_score}]
68 changes: 68 additions & 0 deletions _test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,68 @@
# from civil_handler222 import CivilsHandler
# from ts.torch_handler.unit_tests.test_utils.mock_context import MockContext
# if __name__ == '__main__':

# # =====和torch-model-archiver类似=====
# context = MockContext(
# model_pt_file="exp_logs/civil/RoadBank/TS_CoT/20241101_103432/model.pkl",
# model_dir="model-store",
# model_file="model.py",
# gpu_id="0",
# model_name="model",
# )
# handler = CivilsHandler()
# print(context.manifest)
# print(context.system_properties)
# handler.initialize(context)


import torch
import requests
import io
import random

# TorchServe prediction endpoint for the HAR model.
# url = "http://127.0.0.1:28080/predictions/ts2vec-har"
url = "http://192.168.5.246:19080/predictions/ts2vec-har"

# Load the HAR test split; each request sends a random 5-sample batch
# serialized with torch.save as an application/octet-stream body.
dataset_root = "/workspace/CA-TCC/data"
data_path = f"{dataset_root}/HAR/"
train_ = torch.load(data_path + "test.pt")
for idx in range(0, len(train_["samples"]), 5):
    # Draw 5 random samples (note: independent of idx).
    indices = random.sample(range(len(train_["samples"])), 5)
    ratio_data = train_["samples"][indices]

    # Serialize the batch once into an in-memory stream.
    # Fix: the original serialized the identical tensor twice into two
    # separate BytesIO buffers and computed an unused file_size.
    buffer = io.BytesIO()
    torch.save(ratio_data, buffer)
    buffer.seek(0)  # rewind to the start of the stream

    # Binary-stream request header.
    headers = {"Content-Type": "application/octet-stream"}

    try:
        response = requests.post(url, headers=headers, data=buffer.getvalue())
        if response.status_code == 200:
            print("请求成功,响应内容:")
            print(response.json())
        else:
            print(f"请求失败,状态码:{response.status_code}")
            print(response.text)
    except requests.exceptions.RequestException as e:
        print(f"请求发生错误:{e}")

    # Smoke-test a single batch only; remove this break to sweep the dataset.
    break
Loading