forked from pytorch/pytorch
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathtest_cpp_extensions_stream_and_event.py
106 lines (92 loc) · 3.34 KB
/
test_cpp_extensions_stream_and_event.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
# Owner(s): ["module: mtia"]
import os
import shutil
import tempfile
import unittest

import torch
import torch.testing._internal.common_utils as common
import torch.utils.cpp_extension
from torch.testing._internal.common_utils import (
    IS_ARM64,
    IS_LINUX,
    skipIfTorchDynamo,
    TEST_CUDA,
    TEST_MPS,
    TEST_PRIVATEUSE1,
    TEST_XPU,
)
from torch.utils.cpp_extension import CUDA_HOME, ROCM_HOME
# define TEST_ROCM before changing TEST_CUDA
TEST_ROCM = TEST_CUDA and torch.version.hip is not None and ROCM_HOME is not None
TEST_CUDA = TEST_CUDA and CUDA_HOME is not None
# Since we use a fake MTIA device backend to test generic Stream/Event, device
# backends are mutually exclusive to each other.
# The test will be skipped if any of the following conditions are met:
@unittest.skipIf(
    IS_ARM64
    or not IS_LINUX
    or TEST_CUDA
    or TEST_XPU
    or TEST_MPS
    or TEST_PRIVATEUSE1
    or TEST_ROCM,
    "Only on linux platform and mutual exclusive to other backends",
)
@torch.testing._internal.common_utils.markDynamoStrictTest
class TestCppExtensionStreamAndEvent(common.TestCase):
    """Tests the generic ``torch.Stream`` / ``torch.Event`` APIs with C++ extensions.

    ``setUpClass`` compiles and loads ``cpp_extensions/mtia_extension.cpp``
    (the fake MTIA device guard implementation). The test then checks that a
    default-constructed stream/event targets MTIA and that event ids behave
    as expected across record/re-record.
    """

    # Return value of ``torch.utils.cpp_extension.load``. With
    # ``is_python_module=False`` the library is loaded for its side effects.
    module = None
    # Temporary build directory created in setUpClass and removed in
    # tearDownClass.
    build_dir = None

    def setUp(self):
        super().setUp()
        # cpp extensions use relative paths. Those paths are relative to this
        # file, so the working directory is changed for the duration of each
        # test and restored in tearDown.
        self.old_working_dir = os.getcwd()
        os.chdir(os.path.dirname(os.path.abspath(__file__)))

    def tearDown(self):
        try:
            super().tearDown()
        finally:
            # Always return to the working directory saved in setUp, even if
            # the parent tearDown raises.
            os.chdir(self.old_working_dir)

    @classmethod
    def tearDownClass(cls):
        try:
            torch.testing._internal.common_utils.remove_cpp_extensions_build_root()
            # The temp build directory is not covered by the call above.
            if cls.build_dir is not None:
                shutil.rmtree(cls.build_dir, ignore_errors=True)
                cls.build_dir = None
        finally:
            super().tearDownClass()

    @classmethod
    def setUpClass(cls):
        super().setUpClass()
        torch.testing._internal.common_utils.remove_cpp_extensions_build_root()
        cls.build_dir = tempfile.mkdtemp()
        test_dir = os.path.dirname(os.path.abspath(__file__))
        # Load the fake device guard impl.
        src = os.path.join(test_dir, "cpp_extensions", "mtia_extension.cpp")
        cls.module = torch.utils.cpp_extension.load(
            name="mtia_extension",
            sources=[src],
            build_directory=cls.build_dir,
            extra_include_paths=[
                # setUpClass runs before setUp's chdir, so a bare relative path
                # would resolve against the runner's cwd; anchor it to this file.
                os.path.join(test_dir, "cpp_extensions"),
                # Presumably deliberate: awkward paths that exercise quoting of
                # include flags in the generated build command.
                "path / with spaces in it",
                "path with quote'",
            ],
            is_python_module=False,
            verbose=True,
        )

    @skipIfTorchDynamo("Not a TorchDynamo suitable test")
    def test_stream_event(self):
        # With no device argument, Stream/Event should target the fake MTIA
        # backend. assertEqual is required here: assertTrue(x, y) would treat
        # y as the failure message and never compare the two values.
        s = torch.Stream()
        self.assertEqual(s.device_type, int(torch._C._autograd.DeviceType.MTIA))
        e = torch.Event()
        self.assertEqual(e.device.type, "mtia")
        # Should be nullptr by default
        self.assertEqual(e.event_id, 0)
        s.record_event(event=e)
        print(f"recorded event 1: {e}")
        self.assertNotEqual(e.event_id, 0)
        # Recording without an explicit event returns a new, distinct event.
        e2 = s.record_event()
        print(f"recorded event 2: {e2}")
        self.assertNotEqual(e2.event_id, 0)
        self.assertNotEqual(e2.event_id, e.event_id)
        e.synchronize()
        e2.synchronize()
        # Only checks that elapsed_time is callable; the value comes from the
        # fake backend and is not asserted.
        time_elapsed = e.elapsed_time(e2)
        print(f"time elapsed between e1 and e2: {time_elapsed}")
        # Re-recording an already-recorded event must keep its existing id.
        old_event_id = e.event_id
        e.record(stream=s)
        print(f"recorded event 1: {e}")
        self.assertEqual(e.event_id, old_event_id)
if __name__ == "__main__":
    # Script entry point: hand control to PyTorch's shared test runner.
    common.run_tests()