Skip to content

Commit

Permalink
Add initial GitHub workflows (CarperAI#43)
Browse files Browse the repository at this point in the history
* Add initial GitHub workflows

* Temporarily remove `F82` undefined-name check

* Format for `black` style check
  • Loading branch information
jon-tow authored Oct 19, 2022
1 parent b7d3c3f commit b2e523f
Show file tree
Hide file tree
Showing 11 changed files with 66 additions and 27 deletions.
8 changes: 4 additions & 4 deletions .github/ISSUE_TEMPLATE/bug_report.yml
Original file line number Diff line number Diff line change
@@ -1,17 +1,17 @@
---
name: 🐛 Bug Report
description: Report a bug or unexpected behavior to help us improve trlX
description: Report a bug or unexpected behavior to help us improve trlX
labels:
- bug

body:
- type: markdown
attributes:
value: >
#### Before submitting your bug report, please check to see that the
#### Before submitting your bug report, please check to see that the
issue hasn't already been reported and/or fixed in a latest version.
[Search Issues][Issue Search].
If you're asking a question or seeking support, please consider creating a
new [GitHub discussion][Discussions] or heading over to CarperAI's
[Discord server][CarperAI Discord].
Expand Down Expand Up @@ -62,4 +62,4 @@ body:
- type: markdown
attributes:
value: >
Thanks for contributing 🐠!
Thanks for contributing 🐠!
2 changes: 1 addition & 1 deletion .github/ISSUE_TEMPLATE/documentation.yml
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ body:
attributes:
label: 📚 The doc issue
description: >
Please provide a clear and concise description of what content in https://trlx.readthedocs.io/en/latest/index.html is an issue.
Please provide a clear and concise description of what content in https://trlx.readthedocs.io/en/latest/index.html is an issue.
validations:
required: true

Expand Down
41 changes: 41 additions & 0 deletions .github/workflows/build.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
name: Build

on:
push:
branches: [ master ]
pull_request:
branches: [ master ]

jobs:
build:

runs-on: ubuntu-latest

steps:
- uses: actions/checkout@v2

- name: Set up Python 3.9
uses: actions/setup-python@v2
with:
python-version: 3.9.13
cache: 'pip'

- name: Install dependencies
run: |
python -m pip install --upgrade pip
pip install -e .[dev]
- name: Lint with flake8
run: |
# Stop the build if there are Python syntax errors or undefined names
flake8 . --count --select=E9,F63,F7 --show-source --statistics
# Exit-zero treats all errors as warnings. The GitHub editor is 127 chars wide
flake8 . --count --exit-zero --max-complexity=10 --max-line-length=127 --statistics
- name: Run tests
run: |
pytest -vv --cov=trlx/ tests/
- name: Upload coverage to Codecov
run: |
bash <(curl -s https://codecov.io/bash) -t $CODECOV_TOKEN
13 changes: 13 additions & 0 deletions .github/workflows/code_quality.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
name: Code Quality

on: [pull_request]

jobs:
code-quality:
runs-on: ubuntu-20.04
steps:
- uses: actions/checkout@v2
- uses: actions/setup-python@v2
with:
python-version: 3.9
- uses: pre-commit/[email protected]
5 changes: 1 addition & 4 deletions trlx/model/accelerate_base_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,10 +11,7 @@
from transformers import AutoConfig, AutoTokenizer

from trlx.data import BatchElement, RLElement
from trlx.data.accelerate_base_datatypes import (
AccelerateRLBatchElement,
PromptBatch
)
from trlx.data.accelerate_base_datatypes import AccelerateRLBatchElement, PromptBatch
from trlx.data.configs import TRLConfig
from trlx.model import BaseRLModel, register_model
from trlx.pipeline.accelerate_base_pipeline import AccelerateRolloutStorage
Expand Down
5 changes: 1 addition & 4 deletions trlx/model/accelerate_ilql_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,10 +11,7 @@

from trlx.model import BaseRLModel, register_model
from trlx.model.nn.ilql_models import CausalLMWithValueHeads
from trlx.pipeline.offline_pipeline import (
OfflinePipeline,
OfflineRolloutStorage
)
from trlx.pipeline.offline_pipeline import OfflinePipeline, OfflineRolloutStorage
from trlx.utils import Clock, rampup_decay, safe_mkdir, topk_mask

WORLD_SIZE = int(os.environ.get("WORLD_SIZE", 1))
Expand Down
5 changes: 1 addition & 4 deletions trlx/model/accelerate_ppo_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,10 +16,7 @@
from trlx.data.configs import TRLConfig
from trlx.model import BaseRLModel, register_model
from trlx.model.accelerate_base_model import AccelerateRLModel
from trlx.model.nn.ppo_models import (
GPTHeadWithValueModel,
GPTHydraHeadWithValueModel
)
from trlx.model.nn.ppo_models import GPTHeadWithValueModel, GPTHydraHeadWithValueModel
from trlx.pipeline.ppo_pipeline import PPORolloutStorage
from trlx.utils import Clock, rampup_decay, safe_mkdir, topk_mask
from trlx.utils.modeling import clip_by_value, logprobs_from_logits, whiten
Expand Down
2 changes: 1 addition & 1 deletion trlx/model/nn/ppo_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
GPTJModel,
PretrainedConfig,
PreTrainedModel,
top_k_top_p_filtering
top_k_top_p_filtering,
)
from transformers.modeling_outputs import ModelOutput

Expand Down
5 changes: 1 addition & 4 deletions trlx/orchestrator/offline_orchestrator.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,7 @@

from trlx.model import BaseRLModel
from trlx.orchestrator import Orchestrator, register_orchestrator
from trlx.pipeline.offline_pipeline import (
OfflinePipeline,
OfflineRolloutStorage
)
from trlx.pipeline.offline_pipeline import OfflinePipeline, OfflineRolloutStorage


@register_orchestrator
Expand Down
5 changes: 1 addition & 4 deletions trlx/orchestrator/ppo_orchestrator.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,10 +8,7 @@
from trlx.data.accelerate_base_datatypes import PromptBatch
from trlx.data.ppo_types import PPORLElement
from trlx.model import BaseRLModel
from trlx.model.nn.ppo_models import (
GPTHeadWithValueModel,
GPTHydraHeadWithValueModel
)
from trlx.model.nn.ppo_models import GPTHeadWithValueModel, GPTHydraHeadWithValueModel
from trlx.orchestrator import Orchestrator, register_orchestrator
from trlx.pipeline.ppo_pipeline import PPOPipeline
from trlx.utils import Clock, chunk, flatten, sentiment_score
Expand Down
2 changes: 1 addition & 1 deletion trlx/pipeline/accelerate_base_pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
AccelerateRLBatchElement,
AccelerateRLElement,
PromptBatch,
PromptElement
PromptElement,
)
from trlx.pipeline import BasePipeline, BaseRolloutStore, register_datapipeline

Expand Down

0 comments on commit b2e523f

Please sign in to comment.