Skip to content

Commit

Permalink
Some readme improvements (CarperAI#44)
Browse files: browse the repository at this point in the history
  • Loading branch information
thedch authored Oct 19, 2022
1 parent 5b00cd9 commit b7d3c3f
Show file tree
Hide file tree
Showing 12 changed files with 69 additions and 32 deletions.
2 changes: 1 addition & 1 deletion .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -13,5 +13,5 @@ repos:
- repo: https://github.com/psf/black
rev: 22.10.0
hooks:
- id: black
- id: black
files: ^(trlx|examples|unittests|setup.py)/
4 changes: 2 additions & 2 deletions Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,9 @@ check_dirs := trlx/

style:
black $(check_dirs)
isort $(check_dirs)
isort $(check_dirs) # see pyproject.toml for isort config
flake8 $(check_dirs) --ignore=$(IGNORE_PEP)

quality:
isort --check-only $(check_dirs)
isort --check-only $(check_dirs) # see pyproject.toml for isort config
flake8 $(check_dirs)
52 changes: 35 additions & 17 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ Adding a task for RLHF training depends on the desired training method and pre-e
git clone https://github.com/CarperAI/trlx.git
cd trlx
pip install -e ".[dev]"
pre-commit install # see .pre-commit-config.yaml
```

## Example: How to add a task
Expand All @@ -46,35 +47,52 @@ accelerate config
```python
@register_datapipeline
class PPOPipeline(BasePipeline):
def __init__(self, tokenizer, config, prompt_dataset_path = None):
def __init__(self, tokenizer, config, prompt_dataset_path=None):
super().__init__()

ds = load_dataset('imdb', split='test')
ds = ds.rename_columns({'text': 'review', 'label': 'sentiment'})
ds = ds.filter(lambda x: len(x["review"])<500, batched=False)

self.tokens = [tokenizer(text,
truncation = True,
padding = 'max_length',
max_length = config.train.input_size,
return_tensors = "pt"
)['input_ids'].long().flatten() for text in ds['review']]
ds = load_dataset("imdb", split="test")
ds = ds.rename_columns({"text": "review", "label": "sentiment"})
ds = ds.filter(lambda x: len(x["review"]) < 500, batched=False)

self.tokens = [
tokenizer(
text,
truncation=True,
padding="max_length",
max_length=config.train.input_size,
return_tensors="pt",
)["input_ids"]
.long()
.flatten()
for text in ds["review"]
]
self.text = [tokenizer.decode(tokens.tolist()) for tokens in self.tokens]

def __getitem__(self, index : int) -> PromptElement:
def __getitem__(self, index: int) -> PromptElement:
return PromptElement(self.text[index], self.tokens[index])

def __len__(self) -> int:
return len(self.text)

def create_loader(self, batch_size : int, shuffle : bool, prep_fn : Callable = None, num_workers : int = 0) -> DataLoader:
#TODO(dahoas): Decide how to support varying sizes of prompts without having to tokenize on fly
def collate_fn(elems : Iterable[PromptElement]) -> PromptElement:
def create_loader(
self,
batch_size: int,
shuffle: bool,
prep_fn: Callable = None,
num_workers: int = 0,
) -> DataLoader:
# TODO(dahoas): Decide how to support varying sizes of prompts without having to tokenize on the fly
def collate_fn(elems: Iterable[PromptElement]) -> PromptElement:
return PromptBatch(
[elem.text for elem in elems], torch.stack([elem.tokens for elem in elems]) # Assumes token tensors all same size
[elem.text for elem in elems],
torch.stack(
[elem.tokens for elem in elems]
), # Assumes token tensors all same size
)

return DataLoader(self, batch_size, shuffle, collate_fn = collate_fn, num_workers = num_workers)
return DataLoader(
self, batch_size, shuffle, collate_fn=collate_fn, num_workers=num_workers
)
```

### Launch training
Expand Down
3 changes: 3 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,3 +1,6 @@
[build-system]
requires = ["setuptools"]
build-backend = "setuptools.build_meta"

[tool.isort]
multi_line_output = 3
1 change: 1 addition & 0 deletions setup.cfg
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ install_requires =
[options.extras_require]
dev =
black
isort
flake8
pre-commit
pytest
Expand Down
5 changes: 4 additions & 1 deletion trlx/model/accelerate_base_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,10 @@
from transformers import AutoConfig, AutoTokenizer

from trlx.data import BatchElement, RLElement
from trlx.data.accelerate_base_datatypes import AccelerateRLBatchElement, PromptBatch
from trlx.data.accelerate_base_datatypes import (
AccelerateRLBatchElement,
PromptBatch
)
from trlx.data.configs import TRLConfig
from trlx.model import BaseRLModel, register_model
from trlx.pipeline.accelerate_base_pipeline import AccelerateRolloutStorage
Expand Down
5 changes: 4 additions & 1 deletion trlx/model/accelerate_ilql_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,10 @@

from trlx.model import BaseRLModel, register_model
from trlx.model.nn.ilql_models import CausalLMWithValueHeads
from trlx.pipeline.offline_pipeline import OfflinePipeline, OfflineRolloutStorage
from trlx.pipeline.offline_pipeline import (
OfflinePipeline,
OfflineRolloutStorage
)
from trlx.utils import Clock, rampup_decay, safe_mkdir, topk_mask

WORLD_SIZE = int(os.environ.get("WORLD_SIZE", 1))
Expand Down
9 changes: 6 additions & 3 deletions trlx/model/accelerate_ppo_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,21 +2,24 @@
from abc import abstractmethod
from typing import Dict, Iterable, Tuple

import numpy as np
import torch
import torch.nn.functional as F
import wandb
from accelerate import Accelerator
from torch.utils.data import DataLoader
from torchtyping import TensorType
from tqdm import tqdm
from transformers import AutoConfig, AutoTokenizer
import numpy as np

import wandb
from trlx.data.accelerate_base_datatypes import PromptBatch
from trlx.data.configs import TRLConfig
from trlx.model import BaseRLModel, register_model
from trlx.model.accelerate_base_model import AccelerateRLModel
from trlx.model.nn.ppo_models import GPTHeadWithValueModel, GPTHydraHeadWithValueModel
from trlx.model.nn.ppo_models import (
GPTHeadWithValueModel,
GPTHydraHeadWithValueModel
)
from trlx.pipeline.ppo_pipeline import PPORolloutStorage
from trlx.utils import Clock, rampup_decay, safe_mkdir, topk_mask
from trlx.utils.modeling import clip_by_value, logprobs_from_logits, whiten
Expand Down
6 changes: 3 additions & 3 deletions trlx/model/nn/ppo_models.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
import inspect
from copy import deepcopy
from dataclasses import dataclass
from typing import Optional, Tuple, Union

Expand All @@ -15,11 +17,9 @@
GPTJModel,
PretrainedConfig,
PreTrainedModel,
top_k_top_p_filtering,
top_k_top_p_filtering
)
from transformers.modeling_outputs import ModelOutput
from copy import deepcopy
import inspect


# Cell
Expand Down
5 changes: 4 additions & 1 deletion trlx/orchestrator/offline_orchestrator.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,10 @@

from trlx.model import BaseRLModel
from trlx.orchestrator import Orchestrator, register_orchestrator
from trlx.pipeline.offline_pipeline import OfflinePipeline, OfflineRolloutStorage
from trlx.pipeline.offline_pipeline import (
OfflinePipeline,
OfflineRolloutStorage
)


@register_orchestrator
Expand Down
7 changes: 5 additions & 2 deletions trlx/orchestrator/ppo_orchestrator.py
Original file line number Diff line number Diff line change
@@ -1,18 +1,21 @@
from typing import Callable

import torch
import wandb
from tqdm import tqdm
from transformers import pipeline as tfpipeline

import wandb
from trlx.data.accelerate_base_datatypes import PromptBatch
from trlx.data.ppo_types import PPORLElement
from trlx.model import BaseRLModel
from trlx.model.nn.ppo_models import (
GPTHeadWithValueModel,
GPTHydraHeadWithValueModel
)
from trlx.orchestrator import Orchestrator, register_orchestrator
from trlx.pipeline.ppo_pipeline import PPOPipeline
from trlx.utils import Clock, chunk, flatten, sentiment_score
from trlx.utils.modeling import logprobs_from_logits
from trlx.model.nn.ppo_models import GPTHeadWithValueModel, GPTHydraHeadWithValueModel


@register_orchestrator
Expand Down
2 changes: 1 addition & 1 deletion trlx/pipeline/accelerate_base_pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
AccelerateRLBatchElement,
AccelerateRLElement,
PromptBatch,
PromptElement,
PromptElement
)
from trlx.pipeline import BasePipeline, BaseRolloutStore, register_datapipeline

Expand Down

0 comments on commit b7d3c3f

Please sign in to comment.