How about adding a feature to pass the key when performing map on DatasetDict? #7356

Open
jp1924 opened this issue Jan 6, 2025 · 6 comments
Labels
enhancement New feature or request

Comments


jp1924 commented Jan 6, 2025

Feature request

Add a feature to pass the key of the DatasetDict when performing map

Motivation

I often preprocess using map on DatasetDict.
Sometimes, I need to preprocess train and valid data differently depending on the task.
So, I thought it would be nice to pass the key (like train, valid) when performing map on DatasetDict.

What do you think?

Your contribution

I can submit a pull request to add the feature to pass the key of the DatasetDict when performing map.

@jp1924 jp1924 added the enhancement New feature or request label Jan 6, 2025
Author

jp1924 commented Jan 13, 2025

@lhoestq
If it's okay with you, can I work on this?

Member

lhoestq commented Jan 13, 2025

Hi! Can you give an example of what it would look like to use this new feature?

Note that you can already do this today:

ds["train"] = ds["train"].map(process_train)
ds["test"] = ds["test"].map(process_test)

Author

jp1924 commented Jan 13, 2025

@lhoestq
Thanks for the response!
Let me clarify what I'm looking for with an example:

Currently, we need to write separate processing functions or call .map() separately:

# Current approach
def process_train(example):
    # Training-specific processing
    return example

def process_valid(example):
    # Validation-specific processing
    return example

ds["train"] = ds["train"].map(process_train)
ds["valid"] = ds["valid"].map(process_valid)

What I'm proposing is to have a single processing function that knows which split it's processing:

# Proposed feature
def process(example, split_key):
    if split_key == "train":
        # Training-specific processing
        pass
    elif split_key == "valid":
        # Validation-specific processing
        pass
    return example

# Using with_key=True to pass the split information
ds = ds.map(process, with_key=True)

This becomes particularly useful when:

  1. The processing logic is heavily shared between splits but needs minor adjustments
  2. You want to maintain the processing logic in one place for better maintainability
  3. The processing function is complex and you want to avoid duplicating code

So I wanted to request this feature to achieve this kind of functionality.
I've created a draft PR implementing this: https://github.com/huggingface/datasets/pull/7240/files
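
Until something like this exists in the library, the same behavior can be approximated with a small wrapper. The following is a minimal sketch, not the implementation from the PR: `map_with_split` is a hypothetical helper name, and `ToyDataset` is a hypothetical stand-in for a dataset split so the example runs without the library installed. It only assumes that each split exposes a `.map(function)` method applied per example.

```python
from typing import Any, Callable, Dict, List


class ToyDataset:
    """Hypothetical stand-in for a dataset split; only mimics a per-example .map()."""

    def __init__(self, rows: List[Dict[str, Any]]):
        self.rows = rows

    def map(self, function: Callable[[Dict[str, Any]], Dict[str, Any]]) -> "ToyDataset":
        # Apply the function to a copy of every example and return a new split.
        return ToyDataset([function(dict(row)) for row in self.rows])


def map_with_split(splits, function):
    """Hypothetical helper: call function(example, split_key) for every split."""
    result = {}
    for split_key, split in splits.items():
        # Bind split_key as a default argument so each split sees its own name.
        result[split_key] = split.map(
            lambda example, key=split_key: function(example, key)
        )
    return result


def process(example, split_key):
    # Shared processing for all splits.
    example["text"] = example["text"].strip()
    # Split-specific adjustment (illustrative only).
    if split_key == "train":
        example["text"] = example["text"].lower()
    return example


ds = {
    "train": ToyDataset([{"text": "  Hello "}]),
    "valid": ToyDataset([{"text": "  Hello "}]),
}
processed = map_with_split(ds, process)
```

With a real DatasetDict, the returned plain dict would presumably need to be wrapped back into a DatasetDict; the proposed flag would remove the need for this wrapper altogether.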

Member

lhoestq commented Jan 13, 2025

I see! I think it makes sense, and it's more readable than doing something like this:

from functools import partial
from datasets import DatasetDict

ds = DatasetDict({key: ds[key].map(partial(process, split_key=key)) for key in ds})

PS: you named the argument with_key in your example, but it might be even clearer if it's named with_split, no?

Author

jp1924 commented Jan 13, 2025

@lhoestq I agree.
It seems better to use with_split.
So can I open a PR with this change?

Member

lhoestq commented Jan 13, 2025

Sure !


2 participants