Skip to content

Commit

Permalink
code formatting
Browse files Browse the repository at this point in the history
  • Loading branch information
KeplerC committed Apr 5, 2024
1 parent 470a71b commit 1538df0
Show file tree
Hide file tree
Showing 3 changed files with 16 additions and 14 deletions.
13 changes: 10 additions & 3 deletions examples/pytorch/main.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
import torch

import fog_rtx
import torch

dataset = fog_rtx.dataset.Dataset(
name="demo_ds",
path="/tmp",
Expand All @@ -14,5 +16,10 @@
pytorch_ds = dataset.pytorch_dataset_builder()

# get samples from the dataset
for data in torch.utils.data.DataLoader(pytorch_ds, batch_size=2, collate_fn=lambda x: x, sampler = torch.utils.data.RandomSampler(pytorch_ds)):
print(data)
for data in torch.utils.data.DataLoader(
pytorch_ds,
batch_size=2,
collate_fn=lambda x: x,
sampler=torch.utils.data.RandomSampler(pytorch_ds),
):
print(data)
9 changes: 2 additions & 7 deletions fog_rtx/database/db_connector.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,15 +3,10 @@

import pandas as pd # type: ignore
import sqlalchemy # type: ignore
from sqlalchemy import MetaData # type: ignore
from sqlalchemy import Table # type: ignore
from sqlalchemy import text # type: ignore
from sqlalchemy import (
Column,
Integer,
MetaData, # type: ignore
create_engine,
inspect,
)
from sqlalchemy import Column, Integer, create_engine, inspect
from sqlalchemy.orm import declarative_base, sessionmaker # type: ignore

from fog_rtx.database.utils import type_py2sql # type: ignore
Expand Down
8 changes: 4 additions & 4 deletions fog_rtx/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -158,11 +158,11 @@ def __getitem__(self, idx):
# For simplicity, let's assume we're just returning the episode
return episode

# Assume we use get_metadata_as_pandas_df to retrieve episodes metadata
# Assume we use get_metadata_as_pandas_df to retrieve episodes metadata
metadata_df = self.get_metadata_as_pandas_df()
episodes = self.read_by(metadata_df)

# Initialize the PyTorch dataset with the episodes and features
pytorch_dataset = PyTorchDataset(episodes, self.features)
return pytorch_dataset

return pytorch_dataset

0 comments on commit 1538df0

Please sign in to comment.