-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathtrain.py
40 lines (29 loc) · 983 Bytes
/
train.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
import os
import pickle
import click
import mlflow
# Module-level MLflow configuration; these calls run as soon as the file is
# imported or executed, before any command-line parsing happens.
# Tracking data goes to a SQLite database file named mlflow.db (relative URI).
mlflow.set_tracking_uri("sqlite:///mlflow.db")
# Runs started below are recorded under this experiment name.
mlflow.set_experiment("nyc-taxi-experiment")
# Enable MLflow autologging in addition to the explicit log_metric call made
# in run_train.
mlflow.autolog()
from sklearn.ensemble import RandomForestRegressor
from sklearn.metrics import mean_squared_error
def load_pickle(filename: str):
    """Return the object deserialized from the pickle file at *filename*."""
    with open(filename, mode="rb") as source:
        payload = pickle.load(source)
    return payload
@click.command()
@click.option(
    "--data_path",
    default="./output",
    help="Location where the processed NYC taxi trip data was saved"
)
def run_train(data_path: str):
    """Train a random forest regressor and log its validation RMSE to MLflow."""
    # data_path is the directory holding train.pkl and val.pkl; each file is
    # unpickled and unpacked into a (features, targets) pair.
    with mlflow.start_run():
        X_train, y_train = load_pickle(os.path.join(data_path, "train.pkl"))
        X_val, y_val = load_pickle(os.path.join(data_path, "val.pkl"))

        # Fixed depth and seed, exactly as before, so repeated runs on the
        # same data produce the same model.
        rf = RandomForestRegressor(max_depth=10, random_state=0)
        rf.fit(X_train, y_train)
        y_pred = rf.predict(X_val)

        # Take the square root explicitly rather than passing squared=False:
        # that keyword was deprecated in scikit-learn 1.4 and removed in 1.6,
        # so the old call raises TypeError on current releases. For a single
        # target, sqrt(MSE) is the same value squared=False used to return.
        rmse = mean_squared_error(y_val, y_pred) ** 0.5
        mlflow.log_metric("rmse", rmse)
# Script entry point: invoke the click command only when this file is run
# directly, so importing the module does not start a training run.
if __name__ == '__main__':
    run_train()