-
Notifications
You must be signed in to change notification settings - Fork 648
/
Copy pathpyproject.toml
200 lines (181 loc) · 4.83 KB
/
pyproject.toml
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
[project]
name = "pytorch-forecasting"
readme = "README.md"  # Markdown files are supported
version = "1.2.0"  # is being replaced automatically
authors = [
    { name = "Jan Beitner" },
]
# Keep in sync with the "Programming Language :: Python :: 3.x" classifiers below.
requires-python = ">=3.9,<3.13"
classifiers = [
    "Intended Audience :: Developers",
    "Intended Audience :: Science/Research",
    "Programming Language :: Python :: 3",
    "Programming Language :: Python :: 3.9",
    "Programming Language :: Python :: 3.10",
    "Programming Language :: Python :: 3.11",
    "Programming Language :: Python :: 3.12",
    "Topic :: Scientific/Engineering",
    "Topic :: Scientific/Engineering :: Mathematics",
    "Topic :: Scientific/Engineering :: Artificial Intelligence",
    "Topic :: Software Development",
    "Topic :: Software Development :: Libraries",
    "Topic :: Software Development :: Libraries :: Python Modules",
    "License :: OSI Approved :: MIT License",
]
description = "Forecasting timeseries with PyTorch - dataloaders, normalizers, metrics and models"
# Core runtime dependencies. Every entry caps the next major version with a
# strict `<` bound (numpy previously used `<=3.0.0`, which admitted 3.0.0 itself).
dependencies = [
    "numpy <3.0.0",
    "torch >=2.0.0,!=2.0.1,<3.0.0",
    "lightning >=2.0.0,<3.0.0",
    "scipy >=1.8,<2.0",
    "pandas >=1.3.0,<3.0.0",
    "scikit-learn >=1.2,<2.0",
]
[project.optional-dependencies]
# There are the following dependency sets:
# - all_extras - all soft dependencies
# - granular dependency sets:
#     - tuning - dependencies for tuning hyperparameters via optuna
#     - mqf2 - dependencies for the multivariate quantile loss
#     - graph - dependencies for graph-based forecasting
# - dev - the developer dependency set, for contributors to pytorch-forecasting
# - docs - dependencies for building the documentation
# - CI related: e.g., dev, github-actions. Not intended for end users of
#   pytorch-forecasting.
#
# Soft dependencies are not required for the core functionality of
# pytorch-forecasting, but are needed by optional features, e.g. optuna for
# hyperparameter tuning or cpflows for the multivariate quantile loss.
#
# All soft dependencies can be installed via
# "pip install pytorch-forecasting[all_extras]"
all_extras = [
    "cpflows",
    "matplotlib",
    "optuna >=3.1.0,<5.0.0",
    "optuna-integration",
    "pytorch_optimizer >=2.5.1,<4.0.0",
    "statsmodels",
]
tuning = [
    "optuna >=3.1.0,<5.0.0",
    "optuna-integration",
    "statsmodels",
]
mqf2 = ["cpflows"]
# The graph set is not currently used within pytorch-forecasting, but is kept
# for future development, as it has already been released.
graph = ["networkx"]
# dev - the developer dependency set, for contributors to pytorch-forecasting
dev = [
    "pydocstyle >=6.1.1,<7.0.0",
    # checks and make tools
    "pre-commit >=3.2.0,<5.0.0",
    "invoke",
    "mypy",
    "pylint",
    "ruff",
    # pytest
    "pytest",
    "pytest-xdist",
    "pytest-cov",
    "pytest-sugar",
    "coverage",
    "pyarrow",
    # jupyter notebook
    "ipykernel",
    "nbconvert",
    "black[jupyter]",
    # documentation
    "sphinx",
    "pydata-sphinx-theme",
    "nbsphinx",
    "recommonmark",
    "ipywidgets>=8.0.1,<9.0.0",
    "pytest-dotenv>=0.5.2,<1.0.0",
    "tensorboard>=2.12.1,<3.0.0",
    "pandoc>=2.3,<3.0.0",
]
# docs - dependencies for building the documentation
docs = [
    "sphinx>3.2",
    "pydata-sphinx-theme",
    "nbsphinx",
    "pandoc",
    "nbconvert",
    "recommonmark",
    "docutils",
]
github-actions = ["pytest-github-actions-annotate-failures"]
[tool.setuptools.packages.find]
# Exclude patterns are matched against full dotted package names: "build_tools"
# excludes only the top-level package, "build_tools.*" its sub-packages.
exclude = ["build_tools", "build_tools.*"]
[build-system]
# PEP 517 build requirements and backend.
requires = ["setuptools>=70.0.0"]
build-backend = "setuptools.build_meta"
[tool.ruff]
# Lowest supported interpreter, matching requires-python ">=3.9".
target-version = "py39"
# Same limit as [tool.black].
line-length = 88
# Paths ruff never inspects (build output, virtualenvs, VCS metadata, notebooks
# under docs/source/tutorials/).
exclude = [
    "docs/build/",
    "node_modules/",
    ".eggs/",
    "versioneer.py",
    "venv/",
    ".venv/",
    ".git/",
    ".history/",
    "docs/source/tutorials/",
]
[tool.ruff.lint]
select = [
    "E",   # pycodestyle errors
    "F",   # pyflakes
    "W",   # pycodestyle warnings
    "C4",  # flake8-comprehensions, https://pypi.org/project/flake8-comprehensions
    "S",   # flake8-bandit
]
extend-select = [
    "I",  # isort
]
# `ignore` replaces the deprecated `extend-ignore` option; with an explicit
# `select` above, the two are interchangeable.
ignore = [
    "E203",  # space before : (needed for how black formats slicing)
    "E402",  # module level import not at top of file
    "E731",  # do not assign a lambda expression, use a def
    "E741",  # ignore not easy to read variables like i l I etc.
    "C406",  # Unnecessary list literal - rewrite as a dict literal.
    "C408",  # Unnecessary dict call - rewrite as a literal.
    "C409",  # Unnecessary list passed to tuple() - rewrite as a tuple literal.
    "F401",  # unused imports
    "S101",  # use of assert
]
[tool.ruff.lint.isort]
# The package itself is sorted into the first-party section.
known-first-party = ["pytorch_forecasting"]
force-sort-within-sections = true
combine-as-imports = true
[tool.ruff.lint.per-file-ignores]
# Temporary: overly long lines in this module are being fixed in #1746;
# remove this entry once that change is merged.
"pytorch_forecasting/data/timeseries.py" = ["E501"]
[tool.black]
line-length = 88
include = '\.pyi?$'
# The value spans several lines, so black compiles it as a verbose regex:
# whitespace and `#` comments inside it are ignored. Directory names inside the
# `/( ... )/` group are anchored by slashes, so `venv` there does not also
# match names such as `myvenv`. (Previously this list contained the typo
# `venve/` and an unescaped `.venv/` that matched any character before "venv".)
exclude = '''
(
  /(
      \.eggs         # exclude a few common directories in the
    | \.git          # root of the project
    | \.hg
    | \.mypy_cache
    | \.tox
    | \.venv
    | venv
    | _build
    | buck-out
    | build
    | dist
  )/
  | docs/build/
  | node_modules/
)
'''
[tool.nbqa.mutate]
# Tools whose fixes nbQA may write back into notebooks.
black = 1
ruff = 1

[tool.nbqa.exclude]
# TODO: remove this once the tutorial notebooks are fixed.
ruff = "docs/source/tutorials/"