Skip to content

Commit

Permalink
re-doing contrib/semseg
Browse files Browse the repository at this point in the history
  • Loading branch information
samuelstevens committed Nov 26, 2024
1 parent 63c5684 commit 4a68769
Show file tree
Hide file tree
Showing 23 changed files with 1,140 additions and 1,503 deletions.
1 change: 1 addition & 0 deletions contrib/semseg/__main__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@

6 changes: 3 additions & 3 deletions contrib/semseg/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,11 +14,11 @@ class Train:
"""Linear layer learning rate."""
weight_decay: float = 1e-3
"""Weight decay for AdamW."""
n_epochs: int = 10
n_epochs: int = 100
"""Number of training epochs for linear layer."""
batch_size: int = 32
batch_size: int = 1024
"""Training batch size for linear layer."""
n_workers: int = 8
n_workers: int = 32
"""Number of dataloader workers."""
train_acts: saev.config.DataLoad = dataclasses.field(
default_factory=saev.config.DataLoad
Expand Down
22 changes: 10 additions & 12 deletions contrib/semseg/dashboard.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,23 +15,21 @@ def __():
if pkg_root not in sys.path:
sys.path.append(pkg_root)

import os
import csv
import einops

import marimo as mo
import sklearn.decomposition
import os

import numpy as np
import datasets
import altair as alt
import beartype
from PIL import Image, ImageDraw
from jaxtyping import jaxtyped, Int

import torch
import datasets
import einops
import marimo as mo
import matplotlib.pyplot as plt
import numpy as np
import polars as pl
import altair as alt
import sklearn.decomposition
import torch
from jaxtyping import Int, jaxtyped
from PIL import Image, ImageDraw

import saev.activations
import saev.config
Expand Down
78 changes: 45 additions & 33 deletions contrib/semseg/dashboard2.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,17 +14,17 @@ def __():

import random

import marimo as mo
import beartype

import contrib.semantic_seg.training
import marimo as mo
import numpy as np
import torch
from torchvision.transforms import v2
from jaxtyping import Int, UInt8, jaxtyped
from PIL import Image
from jaxtyping import jaxtyped, UInt8, Int
from torchvision.transforms import v2

import contrib.semantic_seg.training
import saev.config

return (
Image,
Int,
Expand All @@ -45,7 +45,9 @@ def __():

@app.cell
def __(contrib):
ckpt_fpath = "/home/stevens.994/projects/saev-live/checkpoints/faithfulness/model.pt"
ckpt_fpath = (
"/home/stevens.994/projects/saev-live/checkpoints/faithfulness/model.pt"
)
model = contrib.semantic_seg.training.load(ckpt_fpath)
model.eval()
return ckpt_fpath, model
Expand All @@ -68,15 +70,12 @@ def __(contrib, saev):
@app.cell
def __(v2):
def make_img_transform():
return v2.Compose(
[
v2.Resize(size=224, interpolation=v2.InterpolationMode.NEAREST),
v2.CenterCrop(size=(224, 224)),
# v2.ToImage(),
# v2.ToDtype(torch.uint8),
]
)

return v2.Compose([
v2.Resize(size=224, interpolation=v2.InterpolationMode.NEAREST),
v2.CenterCrop(size=(224, 224)),
# v2.ToImage(),
# v2.ToDtype(torch.uint8),
])

img_transform = make_img_transform()
return img_transform, make_img_transform
Expand All @@ -97,14 +96,14 @@ def make_colors(seed: int = 42) -> UInt8[np.ndarray, "n 3"]:
colors = np.array(colors, dtype=np.uint8)
return colors


@jaxtyped(typechecker=beartype.beartype)
def color_map(map: UInt8[np.ndarray, "width height"]) -> Image.Image:
colored = np.zeros((224, 224, 3), dtype=np.uint8)
for i, color in enumerate(make_colors()):
colored[map == i, :] = color

return Image.fromarray(colored)

return color_map, make_colors


Expand Down Expand Up @@ -177,11 +176,17 @@ def mean_iou(

if ignore_class is not None:
pred_one_hot = torch.cat(
(pred_one_hot[..., :ignore_class], pred_one_hot[..., ignore_class + 1 :]),
(
pred_one_hot[..., :ignore_class],
pred_one_hot[..., ignore_class + 1 :],
),
axis=-1,
)
true_one_hot = torch.cat(
(true_one_hot[..., :ignore_class], true_one_hot[..., ignore_class + 1 :]),
(
true_one_hot[..., :ignore_class],
true_one_hot[..., ignore_class + 1 :],
),
axis=-1,
)

Expand All @@ -193,7 +198,6 @@ def mean_iou(
# Handle division by zero
return ((intersection + eps) / (union + eps)).mean().item()


mean_iou(y_pred, y_true, 151)
return (mean_iou,)

Expand Down Expand Up @@ -281,7 +285,6 @@ def intersect_and_union(

return area_intersect, area_union, area_pred_label, area_label


@jaxtyped(typechecker=beartype.beartype)
def total_intersect_and_union(
results,
Expand Down Expand Up @@ -321,8 +324,15 @@ def total_intersect_and_union(
total_area_pred_label = np.zeros((num_labels,), dtype=np.float64)
total_area_label = np.zeros((num_labels,), dtype=np.float64)
for result, gt_seg_map in zip(results, gt_seg_maps):
area_intersect, area_union, area_pred_label, area_label = intersect_and_union(
result, gt_seg_map, num_labels, ignore_index, label_map, reduce_labels
area_intersect, area_union, area_pred_label, area_label = (
intersect_and_union(
result,
gt_seg_map,
num_labels,
ignore_index,
label_map,
reduce_labels,
)
)
total_area_intersect += area_intersect
total_area_union += area_union
Expand All @@ -335,7 +345,6 @@ def total_intersect_and_union(
total_area_label,
)


@jaxtyped(typechecker=beartype.beartype)
def mean_iou(
results,
Expand Down Expand Up @@ -376,10 +385,13 @@ def mean_iou(
- *per_category_iou* (`ndarray` of shape `(num_labels,)`):
Per category IoU.
"""
total_area_intersect, total_area_union, total_area_pred_label, total_area_label = (
total_intersect_and_union(
results, gt_seg_maps, num_labels, ignore_index, label_map, reduce_labels
)
(
total_area_intersect,
total_area_union,
total_area_pred_label,
total_area_label,
) = total_intersect_and_union(
results, gt_seg_maps, num_labels, ignore_index, label_map, reduce_labels
)

# compute metrics
Expand All @@ -396,20 +408,20 @@ def mean_iou(
metrics["per_category_accuracy"] = acc

if nan_to_num is not None:
metrics = dict(
{
metric: np.nan_to_num(metric_value, nan=nan_to_num)
for metric, metric_value in metrics.items()
}
)
metrics = dict({
metric: np.nan_to_num(metric_value, nan=nan_to_num)
for metric, metric_value in metrics.items()
})

return metrics

return intersect_and_union, mean_iou, total_intersect_and_union


@app.cell
def __():
import tensordict

return (tensordict,)


Expand Down
39 changes: 39 additions & 0 deletions contrib/semseg/training.py
Original file line number Diff line number Diff line change
@@ -1 +1,40 @@
import beartype
import torch

import saev.config

from . import config

n_classes = 151


@beartype.beartype
def main(cfg: config.Train):
train_dataset = Dataset(cfg.train_acts, cfg.train_imgs)
val_dataset = Dataset(cfg.val_acts, cfg.val_imgs)

model = torch.nn.Linear(train_dataset.d_vit, n_classes)
optim = torch.optim.AdamW(
model.parameters, lr=cfg.learning_rate, weight_decay=cfg.weight_decay
)

for epoch in range(cfg.n_epochs):
model.train()
for batch in train_dataloader:
breakpoint()

model.eval()
for batch in val_dataloader:
breakpoint()


@beartype.beartype
class Dataset(torch.utils.data.Dataset):
def __init__(
self, acts_cfg: saev.config.DataLoad, imgs_cfg: saev.config.Ade20kDataset
):
breakpoint()

@property
def d_vit(self) -> int:
breakpoint()
64 changes: 64 additions & 0 deletions docs/contrib/index.html
Original file line number Diff line number Diff line change
@@ -0,0 +1,64 @@
<!doctype html>
<html lang="en">
<head>
<meta charset="utf-8">
<meta name="viewport" content="width=device-width, initial-scale=1, minimum-scale=1">
<meta name="generator" content="pdoc3 0.11.1">
<title>contrib API documentation</title>
<meta name="description" content="">
<link rel="stylesheet" href="https://cdnjs.cloudflare.com/ajax/libs/10up-sanitize.css/13.0.0/sanitize.min.css" integrity="sha512-y1dtMcuvtTMJc1yPgEqF0ZjQbhnc/bFhyvIyVNb9Zk5mIGtqVaAB1Ttl28su8AvFMOY0EwRbAe+HCLqj6W7/KA==" crossorigin>
<link rel="stylesheet" href="https://cdnjs.cloudflare.com/ajax/libs/10up-sanitize.css/13.0.0/typography.min.css" integrity="sha512-Y1DYSb995BAfxobCkKepB1BqJJTPrOp3zPL74AWFugHHmmdcvO+C48WLrUOlhGMc0QG7AE3f7gmvvcrmX2fDoA==" crossorigin>
<link rel="stylesheet" href="https://cdnjs.cloudflare.com/ajax/libs/highlight.js/11.9.0/styles/default.min.css" crossorigin>
<style>:root{--highlight-color:#fe9}.flex{display:flex !important}body{line-height:1.5em}#content{padding:20px}#sidebar{padding:1.5em;overflow:hidden}#sidebar > *:last-child{margin-bottom:2cm}.http-server-breadcrumbs{font-size:130%;margin:0 0 15px 0}#footer{font-size:.75em;padding:5px 30px;border-top:1px solid #ddd;text-align:right}#footer p{margin:0 0 0 1em;display:inline-block}#footer p:last-child{margin-right:30px}h1,h2,h3,h4,h5{font-weight:300}h1{font-size:2.5em;line-height:1.1em}h2{font-size:1.75em;margin:2em 0 .50em 0}h3{font-size:1.4em;margin:1.6em 0 .7em 0}h4{margin:0;font-size:105%}h1:target,h2:target,h3:target,h4:target,h5:target,h6:target{background:var(--highlight-color);padding:.2em 0}a{color:#058;text-decoration:none;transition:color .2s ease-in-out}a:visited{color:#503}a:hover{color:#b62}.title code{font-weight:bold}h2[id^="header-"]{margin-top:2em}.ident{color:#900;font-weight:bold}pre code{font-size:.8em;line-height:1.4em;padding:1em;display:block}code{background:#f3f3f3;font-family:"DejaVu Sans Mono",monospace;padding:1px 4px;overflow-wrap:break-word}h1 code{background:transparent}pre{border-top:1px solid #ccc;border-bottom:1px solid #ccc;margin:1em 0}#http-server-module-list{display:flex;flex-flow:column}#http-server-module-list div{display:flex}#http-server-module-list dt{min-width:10%}#http-server-module-list p{margin-top:0}.toc ul,#index{list-style-type:none;margin:0;padding:0}#index code{background:transparent}#index h3{border-bottom:1px solid #ddd}#index ul{padding:0}#index h4{margin-top:.6em;font-weight:bold}@media (min-width:200ex){#index .two-column{column-count:2}}@media (min-width:300ex){#index .two-column{column-count:3}}dl{margin-bottom:2em}dl dl:last-child{margin-bottom:4em}dd{margin:0 0 1em 3em}#header-classes + dl > dd{margin-bottom:3em}dd dd{margin-left:2em}dd p{margin:10px 0}.name{background:#eee;font-size:.85em;padding:5px 10px;display:inline-block;min-width:40%}.name:hover{background:#e0e0e0}dt:target .name{background:var(--highlight-color)}.name > span:first-child{white-space:nowrap}.name.class > span:nth-child(2){margin-left:.4em}.inherited{color:#999;border-left:5px solid #eee;padding-left:1em}.inheritance em{font-style:normal;font-weight:bold}.desc h2{font-weight:400;font-size:1.25em}.desc h3{font-size:1em}.desc dt code{background:inherit}.source summary,.git-link-div{color:#666;text-align:right;font-weight:400;font-size:.8em;text-transform:uppercase}.source summary > *{white-space:nowrap;cursor:pointer}.git-link{color:inherit;margin-left:1em}.source pre{max-height:500px;overflow:auto;margin:0}.source pre code{font-size:12px;overflow:visible}.hlist{list-style:none}.hlist li{display:inline}.hlist li:after{content:',\2002'}.hlist li:last-child:after{content:none}.hlist .hlist{display:inline;padding-left:1em}img{max-width:100%}td{padding:0 .5em}.admonition{padding:.1em 1em;margin-bottom:1em}.admonition-title{font-weight:bold}.admonition.note,.admonition.info,.admonition.important{background:#aef}.admonition.todo,.admonition.versionadded,.admonition.tip,.admonition.hint{background:#dfd}.admonition.warning,.admonition.versionchanged,.admonition.deprecated{background:#fd4}.admonition.error,.admonition.danger,.admonition.caution{background:lightpink}</style>
<style media="screen and (min-width: 700px)">@media screen and (min-width:700px){#sidebar{width:30%;height:100vh;overflow:auto;position:sticky;top:0}#content{width:70%;max-width:100ch;padding:3em 4em;border-left:1px solid #ddd}pre code{font-size:1em}.name{font-size:1em}main{display:flex;flex-direction:row-reverse;justify-content:flex-end}.toc ul ul,#index ul ul{padding-left:1em}.toc > ul > li{margin-top:.5em}}</style>
<style media="print">@media print{#sidebar h1{page-break-before:always}.source{display:none}}@media print{*{background:transparent !important;color:#000 !important;box-shadow:none !important;text-shadow:none !important}a[href]:after{content:" (" attr(href) ")";font-size:90%}a[href][title]:after{content:none}abbr[title]:after{content:" (" attr(title) ")"}.ir a:after,a[href^="javascript:"]:after,a[href^="#"]:after{content:""}pre,blockquote{border:1px solid #999;page-break-inside:avoid}thead{display:table-header-group}tr,img{page-break-inside:avoid}img{max-width:100% !important}@page{margin:0.5cm}p,h2,h3{orphans:3;widows:3}h1,h2,h3,h4,h5,h6{page-break-after:avoid}}</style>
<script type="text/x-mathjax-config">MathJax.Hub.Config({ tex2jax: { inlineMath: [ ['$','$'], ["\\(","\\)"] ], processEscapes: true } });</script>
<script async src="https://cdnjs.cloudflare.com/ajax/libs/mathjax/2.7.7/latest.js?config=TeX-AMS_CHTML" integrity="sha256-kZafAc6mZvK3W3v1pHOcUix30OHQN6pU/NO2oFkqZVw=" crossorigin></script>
<script defer src="https://cdnjs.cloudflare.com/ajax/libs/highlight.js/11.9.0/highlight.min.js" integrity="sha512-D9gUyxqja7hBtkWpPWGt9wfbfaMGVt9gnyCvYa+jojwwPHLCzUm5i8rpk7vD7wNee9bA35eYIjobYPaQuKS1MQ==" crossorigin></script>
<script>window.addEventListener('DOMContentLoaded', () => {
hljs.configure({languages: ['bash', 'css', 'diff', 'graphql', 'ini', 'javascript', 'json', 'plaintext', 'python', 'python-repl', 'rust', 'shell', 'sql', 'typescript', 'xml', 'yaml']});
hljs.highlightAll();
})</script>
</head>
<body>
<main>
<article id="content">
<header>
<h1 class="title">Namespace <code>contrib</code></h1>
</header>
<section id="section-intro">
</section>
<section>
<h2 class="section-title" id="header-submodules">Sub-modules</h2>
<dl>
<dt><code class="name"><a title="contrib.semseg" href="semseg/index.html">contrib.semseg</a></code></dt>
<dd>
<div class="desc"></div>
</dd>
</dl>
</section>
<section>
</section>
<section>
</section>
<section>
</section>
</article>
<nav id="sidebar">
<div class="toc">
<ul></ul>
</div>
<ul id="index">
<li><h3><a href="#header-submodules">Sub-modules</a></h3>
<ul>
<li><code><a title="contrib.semseg" href="semseg/index.html">contrib.semseg</a></code></li>
</ul>
</li>
</ul>
</nav>
</main>
<footer id="footer">
<p>Generated by <a href="https://pdoc3.github.io/pdoc" title="pdoc: Python API documentation generator"><cite>pdoc</cite> 0.11.1</a>.</p>
</footer>
</body>
</html>
Loading

0 comments on commit 4a68769

Please sign in to comment.