Convert directory fbcode/fbpcs to use the Ruff Formatter #2421

Open · wants to merge 1 commit into base: main
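For context: converting a directory to the Ruff formatter is normally a single pass of `ruff format` over the tree. The sketch below shows one way to run that from Python; it is an illustration only, since the exact command used to produce this diff is not recorded in the pull request, and it assumes the `ruff` executable is installed (for example via `pip install ruff`) and available on PATH.

# Minimal sketch: reformat the fbpcs directory with Ruff (illustrative, not the recorded command).
import subprocess

def reformat_directory(target: str = "fbpcs") -> None:
    # Report files that would be reformatted, without modifying them.
    subprocess.run(["ruff", "format", "--check", target], check=False)
    # Rewrite the files in place using the Ruff formatter.
    subprocess.run(["ruff", "format", target], check=True)

if __name__ == "__main__":
    reformat_directory()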
Changes from all commits
7 changes: 4 additions & 3 deletions fbpcs/bolt/read_config.py
@@ -23,11 +23,12 @@
from fbpcs.utils.config_yaml.config_yaml_dict import ConfigYamlDict


-def parse_bolt_config(config: Dict[str, Any], logger: logging.Logger) -> Tuple[
+def parse_bolt_config(
+    config: Dict[str, Any], logger: logging.Logger
+) -> Tuple[
BoltRunner[BoltPCSCreateInstanceArgs, BoltPCSCreateInstanceArgs],
List[BoltJob[BoltPCSCreateInstanceArgs, BoltPCSCreateInstanceArgs]],
]:

# create runner
runner_config = config["runner"]
runner = create_bolt_runner(runner_config=runner_config, logger=logger)
@@ -79,7 +80,7 @@ def create_bolt_runner(


def create_job_list(
-    job_config_list: Dict[str, Any]
+    job_config_list: Dict[str, Any],
) -> List[BoltJob[BoltPCSCreateInstanceArgs, BoltPCSCreateInstanceArgs]]:
bolt_job_list = []
for job_name, job_config in job_config_list.items():
13 changes: 6 additions & 7 deletions fbpcs/common/service/graphapi_trace_logging_service.py
@@ -88,12 +88,12 @@ def _flush_msg_queue(self, msg_queue, flush_size: int = FLUSH_CHUNK_SIZE) -> None:
continue

aggregate_msg["component"] += f"{AGGREGATE_DELIMITER}{msg['component']}"
-aggregate_msg[
-    "checkpoint_name"
-] += f"{AGGREGATE_DELIMITER}{msg['checkpoint_name']}"
-aggregate_msg[
-    "checkpoint_data"
-] += f"{AGGREGATE_DELIMITER}{msg['checkpoint_data']}"
+aggregate_msg["checkpoint_name"] += (
+    f"{AGGREGATE_DELIMITER}{msg['checkpoint_name']}"
+)
+aggregate_msg["checkpoint_data"] += (
+    f"{AGGREGATE_DELIMITER}{msg['checkpoint_data']}"
+)

if aggregate_msg:
self._post_request(params=aggregate_msg)
@@ -106,7 +106,6 @@ def _write_checkpoint_impl(
status: CheckpointStatus,
checkpoint_data: Optional[Dict[str, str]] = None,
) -> None:

checkpoint_data = checkpoint_data or {}
component = checkpoint_data.pop("component", DEFAULT_COMPONENT_NAME)
scrubbed_checkpoint_data = {}
1 change: 0 additions & 1 deletion fbpcs/common/service/test/test_trace_logging_registry.py
@@ -12,7 +12,6 @@


class DummyRegistry(RegistryFactory[str]):

_REGISTRY: Dict[str, str] = {}

@classmethod
(file path not captured)
@@ -310,7 +310,6 @@ def _test_immutable_helper(
setattr(intance_base_obj, test_field, change_vals)

def _init_event_hook(self) -> None:

##########################
# update hooks: initialize name when id is initialized, they are both immutable
# initialize org when user is initialized, they are both immutable
@@ -356,7 +355,6 @@ def _init_event_hook(self) -> None:
)

def _update_event_hook(self) -> None:

##########################
# update hooks: update output_path and storage when input_path is changed
##########################
@@ -423,7 +421,6 @@ def _update_event_hook(self) -> None:
self.assertEqual(self.obj_1.highest_pressure, 70)

def _delete_event_hook(self) -> None:

##########################
# frozen hooks: frozen location when region is deleted
##########################
1 change: 0 additions & 1 deletion fbpcs/common/tests/test_stage_state_instance.py
@@ -58,7 +58,6 @@ def test_stop_containers(self, mock_onedocker_svc) -> None:
"Subtest with container_stoppable: {container_stoppable}",
container_stoppable=container_stoppable,
):

mock_onedocker_svc.reset_mock()
if container_stoppable:
mock_onedocker_svc.stop_containers = MagicMock(
(file path not captured)
@@ -45,7 +45,6 @@ def build_args(
metric_path: Optional[str] = None,
run_id: Optional[str] = None,
) -> str:

cmd_ls = []

if server_endpoint:
(file path not captured)
@@ -13,7 +13,6 @@


def get_access_token(hostname: str, client_id: str, client_secret: str) -> str:

url = f"https://{hostname}/clients/token"

payload = f"client_id={client_id}&client_secret={client_secret}&grant_type=client_credentials"
(file path not captured)
@@ -58,7 +58,6 @@ def lambda_handler(
output = []
##### NOTE: this script assume the schema is correct, no missing items
for record in event["records"]:

row = {}
recordId = record["recordId"]
row["recordId"] = recordId
(file path not captured)
@@ -19,7 +19,6 @@


class AwsDeploymentHelper:

# policy_arn is fixed string. So defining it as a macro.
POLICY_ARN = "arn:aws:iam::{}:policy/{}"
IAM_POLICIES_DIRECTORY = "iam_policies"
@@ -282,7 +281,6 @@ def list_access_keys(self, user_name: str) -> List[str]:
def read_json_file(
self, file_name: str, policy_params: PolicyParams, read_mode: str = "r"
) -> Dict[str, Any]:

# this can be replaced with a json file which is written in deploy.sh
interpolation_data = {
"REGION": self.region,
@@ -326,7 +324,6 @@ def read_json_file(
return json_data

def create_user_workflow(self, user_name: str) -> None:

self.log.info(
f"""Cli to create user is triggered. Following actions will be performed
1. User {user_name} will be created
(file path not captured)
@@ -271,7 +271,6 @@ def test_list_users(self) -> None:
self.aws_deployment_helper.iam.list_users.assert_called_once()

def test_create_access_key(self) -> None:

self.aws_deployment_helper.iam.create_access_key.return_value = {
"AccessKey": {"AccessKeyId": 1, "SecretAccessKey": 2}
}
(file path not captured)
@@ -23,7 +23,6 @@

# Define Lambda function
def lambda_handler(event, context):

### should be one single upload (it should be larger than 1 by default)
if len(event["Records"]) >= 2:
logger.info("multiple csv uploaded. please upload only one csv at a time")
(file path not captured)
@@ -49,7 +49,6 @@ def __init__(
logger_name: Optional[str] = None,
s3_bucket_name: Optional[str] = None,
) -> None:

aws_access_key_id = aws_access_key_id or os.environ.get("AWS_ACCESS_KEY_ID")
aws_secret_access_key = aws_secret_access_key or os.environ.get(
"AWS_SECRET_ACCESS_KEY"
(file path not captured)
@@ -291,7 +291,6 @@ def test_ensure_folder_exists(self) -> None:
)

def test_get_kinesis_firehose_streams(self) -> None:

kinesis_firehose_stream_name = "test_stream"
mock_return = {"stream_name": kinesis_firehose_stream_name}
self.aws_container_logs.kinesis_client.describe_delivery_stream.return_value = (
@@ -377,9 +376,7 @@ def test_get_latest_cloudwatch_log(self) -> None:
mock_streams = []
mock_response = {"logStreams": mock_streams}
self.aws_container_logs.cloudwatch_client.describe_log_streams.reset_mock()
-self.aws_container_logs.cloudwatch_client.describe_log_streams.return_value = (
-    mock_response
-)
+self.aws_container_logs.cloudwatch_client.describe_log_streams.return_value = mock_response
expected = ""
self.assertEqual(
expected,
@@ -391,9 +388,7 @@ def test_get_latest_cloudwatch_log(self) -> None:
with self.subTest("EmptyResponse"):
mock_response = {}
self.aws_container_logs.cloudwatch_client.describe_log_streams.reset_mock()
-self.aws_container_logs.cloudwatch_client.describe_log_streams.return_value = (
-    mock_response
-)
+self.aws_container_logs.cloudwatch_client.describe_log_streams.return_value = mock_response
expected = ""
self.assertEqual(
expected,
(file path not captured)
@@ -90,11 +90,11 @@ def test_upload_logs_to_s3_from_cloudwatch(self) -> None:
self.aws_container_logs.cloudwatch_client.describe_log_groups.reset_mock(
side_effect=True
)
-getattr(
-    self.aws_container_logs.s3_client, s3_endpoint
-).side_effect = ClientError(
-    error_response={"Error": {"Code": error_code}},
-    operation_name=s3_endpoint,
+getattr(self.aws_container_logs.s3_client, s3_endpoint).side_effect = (
+    ClientError(
+        error_response={"Error": {"Code": error_code}},
+        operation_name=s3_endpoint,
+    )
)
with self.assertRaisesRegex(Exception, exc_regex):
self.aws_container_logs.upload_logs_to_s3_from_cloudwatch(
(file path not captured)
@@ -17,7 +17,6 @@ def setUp(self) -> None:
self.utils = Utils()

def test_create_file(self) -> None:

fake_file_path = "fake/file/path"
content_list = ["This is test string"]
with patch(
1 change: 0 additions & 1 deletion fbpcs/infra/logging_service/log_analyzer/log_analyzer.py
@@ -207,7 +207,6 @@ def _parse_one_line(
log_line: str,
parsing_state: Optional[ParsingState],
) -> Optional[ParsingState]:

if line_num == 1:
context = self._parse_line_context(log_line)
self.run_study.first_log = log_line
1 change: 0 additions & 1 deletion fbpcs/infra/restore_run_state/restore_state.py
@@ -79,7 +79,6 @@ def run(self, argv: Optional[List[str]] = None) -> None:
self.logger.info(f"Downloaded run state to {dest_folder}")

def _copy_files(self, run_data_path: str, dest_folder: str) -> None:

# DataPath is like s3://fb-pc-data-nov07test1-vwxz/query-results/fbpcs_instances_638479584559395_1/
splits = self._split_path(run_data_path)
if splits is None:
1 change: 0 additions & 1 deletion fbpcs/pc_pre_validation/pc_pre_validation_cli.py
@@ -29,7 +29,6 @@
[--tee-local-file-path=<local-file-path>]
"""


from typing import cast, List, Optional as OptionalType

from docopt import docopt
18 changes: 8 additions & 10 deletions fbpcs/pl_coordinator/pl_study_runner.py
@@ -159,7 +159,6 @@ async def run_study_async(
bolt_hooks: Optional[Dict[BoltHookKey, List[BoltHook[BoltHookArgs]]]] = None,
stage_timeout_override: Optional[int] = None,
) -> BoltSummary:

# Create a GraphApiTraceLoggingService specific for this study_id
client: BoltGraphAPIClient[BoltPLGraphAPICreateInstanceArgs] = BoltGraphAPIClient(
config=config,
@@ -367,7 +366,6 @@ async def _validate_access_to_instance(
instance_id: str,
run_id: str,
) -> None:

tries = 0
while tries < CREATE_INSTANCE_TRIES:
tries += 1
@@ -834,14 +832,14 @@ async def _create_new_instances(
for objective_id in cell_obj_instances[cell_id]:
# Create new instance for cell_obj pairs which has no valid instance.
if "instance_id" not in cell_obj_instances[cell_id][objective_id]:
cell_obj_instances[cell_id][objective_id]["instance_id"] = (
await _create_instance_retry(
client, study_id, cell_id, objective_id, run_id, logger
)
)
cell_obj_instances[cell_id][objective_id][
STATUS
] = PrivateComputationInstanceStatus.CREATED.value
"instance_id"
] = await _create_instance_retry(
client, study_id, cell_id, objective_id, run_id, logger
)
cell_obj_instances[cell_id][objective_id][STATUS] = (
PrivateComputationInstanceStatus.CREATED.value
)

instance_id = cell_obj_instances[cell_id][objective_id]["instance_id"]
is_pl_timestamp_validation_enabled = await client.has_feature(
@@ -902,7 +900,7 @@

@bolt_checkpoint(dump_return_val=True, component=LOG_COMPONENT)
def _instance_to_input_path(
-    cell_obj_instance: Dict[str, Dict[str, Dict[str, Any]]]
+    cell_obj_instance: Dict[str, Dict[str, Dict[str, Any]]],
) -> Dict[str, Dict[str, str]]:
instance_input_path = {}
for cell_id in cell_obj_instance:
1 change: 0 additions & 1 deletion fbpcs/private_computation/entity/infra_config.py
@@ -14,7 +14,6 @@

# this import statument can avoid circular import
if TYPE_CHECKING:

from fbpcs.private_computation.stage_flows.private_computation_base_stage_flow import (
PrivateComputationBaseStageFlow,
)
6 changes: 3 additions & 3 deletions fbpcs/private_computation/entity/pc_infra_config.py
@@ -61,9 +61,9 @@ def build_full_config(cls, yml_config: Dict[str, Any]) -> Dict[str, Any]:
# can handle more than 1 override
for dep_key, dep_value in overrides.items():
if dep_key in yml_config["private_computation"]["dependency"]:
yml_config["private_computation"]["dependency"][
dep_key
] = dep_value
yml_config["private_computation"]["dependency"][dep_key] = (
dep_value
)

elif dep_key in yml_config["mpc"]["dependency"]:
yml_config["mpc"]["dependency"][dep_key] = dep_value
1 change: 0 additions & 1 deletion fbpcs/private_computation/entity/pc_infra_config_data.py
@@ -17,7 +17,6 @@ class PrivateComputationInfraConfigData:


class PrivateComputationInfraConfigInfo(Enum):

CONTAINER_SERVICE = PrivateComputationInfraConfigData(
"fbpcp.service.container_aws.AWSContainerService",
{"region", "cluster", "subnets"},
1 change: 0 additions & 1 deletion fbpcs/private_computation/entity/pcs_feature.py
@@ -11,7 +11,6 @@


class PCSFeature(Enum):

PCS_DUMMY = "pcs_dummy_feature"
PRIVATE_LIFT_PCF2_RELEASE = "private_lift_pcf2_release"
PC_COORDINATED_RETRY = "private_computation_coordinated_retry"
3 changes: 0 additions & 3 deletions fbpcs/private_computation/pc_attribution_runner.py
@@ -161,7 +161,6 @@ async def run_attribution_async(
bolt_hooks: Optional[Dict[BoltHookKey, List[BoltHook[BoltHookArgs]]]] = None,
stage_timeout_override: Optional[int] = None,
) -> BoltSummary:

## Step 1: Validation. Function arguments and for private attribution run.
# obtain the values in the dataset info vector.
client: BoltGraphAPIClient[BoltPAGraphAPICreateInstanceArgs] = BoltGraphAPIClient(
@@ -251,7 +250,6 @@ async def _run_attribution_async_helper(
bolt_hooks: Optional[Dict[BoltHookKey, List[BoltHook[BoltHookArgs]]]],
stage_timeout_override: Optional[int],
) -> BoltSummary:

try:
datasets_info = _get_attribution_dataset_info(client, dataset_id, logger)
except GraphAPIGenericException as err:
@@ -605,7 +603,6 @@ def get_runnable_timestamps(
graphapi_version: Optional[str] = None,
graphapi_domain: Optional[str] = None,
) -> Iterable[str]:

client: BoltGraphAPIClient[BoltPAGraphAPICreateInstanceArgs] = BoltGraphAPIClient(
config=config,
logger=logger,
5 changes: 0 additions & 5 deletions fbpcs/private_computation/service/utils.py
@@ -98,7 +98,6 @@ def transform_file_path(file_path: str, aws_region: Optional[str] = None) -> str:
rf"https://[sS]3\.{region_regex_pattern}+\.amazonaws\.com/{bucket_name_regex_pattern}+/{key_pattern}+",
file_path,
):

# Extract Bucket, Key, and Region
key_name_search = re.search(
rf"https://[sS]3\.{region_regex_pattern}+\.amazonaws\.com/{bucket_name_regex_pattern}+/",
@@ -131,9 +130,7 @@ def transform_file_path(file_path: str, aws_region: Optional[str] = None) -> str:

# Check if it matches the s3 style access format, s3://bucket-name/key-name
if re.search(rf"[sS]3://{bucket_name_regex_pattern}+/{key_pattern}+", file_path):

if aws_region is not None:

# Extract Bucket, Key
bucket_name_search = re.search(r"[sS]3://", file_path)
key_name_search = re.search(
@@ -144,7 +141,6 @@ def transform_file_path(file_path: str, aws_region: Optional[str] = None) -> str:

# Check for not None rather than extracting on search, to keep pyre happy
if key_name_search and bucket_name_search:

bucket = file_path[
bucket_name_search.span()[1] : key_name_search.span()[1] - 1
]
@@ -289,7 +285,6 @@ def generate_env_vars_dicts_list(
server_hostnames: Optional[List[str]] = None,
server_private_key_ref_provider: Optional[PrivateKeyReferenceProvider] = None,
) -> List[Dict[str, str]]:

_validate_env_vars_length(
num_containers=num_containers,
server_ip_addresses=server_ip_addresses,
(file path not captured)
@@ -153,9 +153,9 @@ def main() -> None:
json_output = json.dumps(instance_dict)
elif path == LIFT_PC_PATH:
instance_dict = json.loads(json_output)
instance_dict["infra_config"][
"invalid_parameter_to_exclude"
] = "This instance value should be excluded."
instance_dict["infra_config"]["invalid_parameter_to_exclude"] = (
"This instance value should be excluded."
)
instance_dict["infra_config"]["instances"][0][
"invalid_parameter_to_exclude"
] = "This instance value should be excluded."