Skip to content

Commit

Permalink
fix: Fix csv quoting that broke some datasets (#473)
Browse files Browse the repository at this point in the history
  • Loading branch information
robinholzi authored Jun 7, 2024
1 parent 4a9cd98 commit 2b80aca
Show file tree
Hide file tree
Showing 3 changed files with 28 additions and 9 deletions.
3 changes: 3 additions & 0 deletions modyn/config/schema/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,9 @@ class DatasetCsvFileWrapperConfig(_DatasetBaseFileWrapperConfig):
"""

separator: str = Field(",", description="The separator used in CSV files.")
quote: str = Field("\\0", description="The quote character used in CSV files.")
quoted_linebreaks: bool = Field(True, description="Whether linebreaks are quoted in CSV files.")

label_index: int = Field(
description=(
"Column index of the label. For columns 'width, 'height, 'age', 'label' you should set label_index to 3."
Expand Down
22 changes: 16 additions & 6 deletions modyn/storage/include/internal/file_wrapper/csv_file_wrapper.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,18 @@ class CsvFileWrapper : public FileWrapper {
separator_ = ',';
}

if (file_wrapper_config_["quote_char"]) {
quote_ = file_wrapper_config_["quote_char"].as<char>();
} else {
quote_ = '\0'; // effectively disables quoting
}

if (file_wrapper_config_["quoted_linebreaks"]) {
allow_quoted_linebreaks_ = file_wrapper_config_["quoted_linebreaks"].as<bool>();
} else {
allow_quoted_linebreaks_ = true;
}

bool ignore_first_line = false;
if (file_wrapper_config_["ignore_first_line"]) {
ignore_first_line = file_wrapper_config_["ignore_first_line"].as<bool>();
Expand All @@ -34,12 +46,8 @@ class CsvFileWrapper : public FileWrapper {
ASSERT(filesystem_wrapper_->exists(path), "The file does not exist.");

validate_file_extension();

label_params_ = rapidcsv::LabelParams(ignore_first_line ? 0 : -1);

stream_ = filesystem_wrapper_->get_stream(path);

doc_ = rapidcsv::Document(*stream_, label_params_, rapidcsv::SeparatorParams(separator_));
setup_document(path);
}

~CsvFileWrapper() override {
Expand All @@ -52,6 +60,7 @@ class CsvFileWrapper : public FileWrapper {
CsvFileWrapper(CsvFileWrapper&&) = default;
CsvFileWrapper& operator=(CsvFileWrapper&&) = default;

void setup_document(const std::string& path);
uint64_t get_number_of_samples() override;
int64_t get_label(uint64_t index) override;
std::vector<int64_t> get_all_labels() override;
Expand All @@ -64,7 +73,8 @@ class CsvFileWrapper : public FileWrapper {
FileWrapperType get_type() override;

private:
char separator_;
char separator_, quote_;
bool allow_quoted_linebreaks_ = true;
uint64_t label_index_;
rapidcsv::Document doc_;
rapidcsv::LabelParams label_params_;
Expand Down
12 changes: 9 additions & 3 deletions modyn/storage/src/internal/file_wrapper/csv_file_wrapper.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,14 @@

using namespace modyn::storage;

void CsvFileWrapper::setup_document(const std::string& path) {
stream_ = filesystem_wrapper_->get_stream(path);
auto sep_params = rapidcsv::SeparatorParams(separator_);
sep_params.mQuoteChar = quote_;
sep_params.mQuotedLinebreaks = allow_quoted_linebreaks_;
doc_ = rapidcsv::Document(*stream_, label_params_, sep_params);
}

void CsvFileWrapper::validate_file_extension() {
if (file_path_.substr(file_path_.find_last_of('.') + 1) != "csv") {
FAIL("The file extension must be .csv");
Expand Down Expand Up @@ -102,9 +110,7 @@ void CsvFileWrapper::set_file_path(const std::string& path) {
stream_->close();
}

stream_ = filesystem_wrapper_->get_stream(path);

doc_ = rapidcsv::Document(*stream_, label_params_, rapidcsv::SeparatorParams(separator_));
setup_document(path);
}

FileWrapperType CsvFileWrapper::get_type() { return FileWrapperType::CSV; }

0 comments on commit 2b80aca

Please sign in to comment.