-
Notifications
You must be signed in to change notification settings - Fork 36
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Seg example #42
base: master
Are you sure you want to change the base?
Seg example #42
Conversation
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks!
Please see my comments inline.
Please also add a README file like in other examples and a short test (running just on several samples) if possible.
@@ -27,12 +27,16 @@ | |||
from scipy.ndimage.filters import gaussian_filter | |||
from scipy.ndimage.interpolation import map_coordinates | |||
from torch import Tensor | |||
import elasticdeform as ed |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
It's a new dependency - add it to requirements.txt
If there are more, please add them too.
@@ -265,23 +274,13 @@ def aug_op_elastic_transform(aug_input: Tensor, alpha: float = 1, sigma: float = | |||
:param channels: which channels to apply the augmentation | |||
:return distorted image | |||
""" |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Why did you change it? the previous implementation didn't work?
I see that you are not using channels - can you support it as well?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I find it more straightforward and easy to use.
dx2 = gaussian_filter((random_state.rand(*shape) * 2 - 1), sigma, mode="constant", cval=0) * alpha | ||
# convert back to torch tensor | ||
aug_input = [numpy.array(t) for t in aug_input] | ||
aug_input_d = ed.deform_random_grid(aug_input, sigma=7, points=3, axis=[(1, 2), (1,2)]) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Don't use fixed values - use the function arguments
random crop a (3d) tensor and resize it to a given size | ||
:param crop_size: float <= 1.0 - the fraction to crop from the original tensor for each dim | ||
:param x_off: float <= 1.0 - the x-offset to take | ||
:param y_off: |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
add doc
|
||
|
||
def aug_op_random_crop_and_resize(aug_input: Tensor, | ||
out_size, |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
missing type annotation - you have more though - can you please go over the new code and add type annotations?
# :param padding: Optional, padding size | ||
# """ | ||
|
||
# self.input_data = input_data |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
remove commented out code
image = imread(image_fn) | ||
|
||
# ====================================================================== | ||
# TODO - change type to float if input image and to int it mask image |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
?
image = image.astype('float32') | ||
# ===================================================================== | ||
|
||
# img_path = self.input_data + str(inner_image_desc) + '.jpg' |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
remove commented out code and unused functions
fuse_examples/segmentation/unet.py
Outdated
@@ -0,0 +1,113 @@ | |||
|
|||
""" Full assembly of the parts to form the complete network """ |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Where did tou copy this code from?
What is the license? Should we add the link to the original code?
@@ -0,0 +1,584 @@ | |||
import logging |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I guess that this is old code that you should delete, right?
… of the unet code
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks!
Some more minor comments.
I guess that you still plan to add readme with info, right?
What about a short unittest? can you come up with something relatively short?
|
||
aug_tensor = F.interpolate(aug_tensor, out_size) | ||
|
||
# else: |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
else throw error?
pred_name, | ||
target_name, | ||
pred_name: str = None, | ||
target_name: str = None, | ||
filter_func: Optional[Callable]=None, | ||
class_weights=None, |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
type annotation
target_name, | ||
def __init__(self, | ||
pred_name: str = None, | ||
target_name: str = None, | ||
filter_func: Optional[Callable] = None, | ||
class_weights=None, |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
type annotation
@@ -105,7 +89,7 @@ | |||
'force_reset_model_dir': True, # If True will reset model dir automatically - otherwise will prompt 'are you sure' message. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Let's not assume anything about ROOT and DATA_ROOT.
Please document what each one is, and ask the user to fill it in
target_name='data.gt.gt_global') | ||
} | ||
|
||
model = model.cuda() | ||
# model = model.cuda() |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
remove
########################################## | ||
main_out_path = '../siim_data' | ||
dataset_path = '../siim/' | ||
out_size_list = [256, 512] |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
function argument?
out_size_list = [256, 512] | ||
|
||
|
||
def rle2mask(rles, width, height): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
missing type annotations and documentation in this file
|
||
|
||
|
||
if __name__ == '__main__': |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Do you require to call this script?
If yes, why isn't it part of the processor (instead of reading and saving back to the disk?_
FuseUtilsGPU.choose_and_enable_multiple_gpus(NUM_GPUS, force_gpus=force_gpus) | ||
|
||
RUNNING_MODES = ['train', 'infer', 'eval'] # Options: 'train', 'infer', 'eval' | ||
# RUNNING_MODES = ['eval'] # Options: 'train', 'infer', 'eval' |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
remove
@@ -0,0 +1,117 @@ | |||
import torch |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
can you move it to fuse/models?
…from rle encoding
* remove fuse1 data package * remove dataset from manager * convert mnist to fuse2 style * add fuse data package * adjust mnist runner * imaging extension Co-authored-by: moshiko <[email protected]> Co-authored-by: Alex Golts <[email protected]>
…o hello_world_unittest
…e nit to the user inside the notebook
Unittest for hello world notebook
segmentation example using the eval package api