Version 0.2.0 (#5)
creafz authored Dec 24, 2020
1 parent c6650d6 commit 68de6fc
Showing 41 changed files with 796 additions and 94 deletions.
13 changes: 12 additions & 1 deletion .github/workflows/ci.yml
@@ -19,8 +19,19 @@ jobs:
python-version: ${{ matrix.python-version }}
- name: Update pip
run: python -m pip install --upgrade pip

+      - name: Install PyTorch on Linux and Windows
+        if: >
+          matrix.operating-system == 'ubuntu-latest' ||
+          matrix.operating-system == 'windows-latest'
+        run: >
+          pip install torch==1.7.1+cpu torchvision==0.8.2+cpu
+          -f https://download.pytorch.org/whl/torch_stable.html
+      - name: Install PyTorch on MacOS
+        if: matrix.operating-system == 'macos-latest'
+        run: pip install torch==1.7.1 torchvision==0.8.2
- name: Install dependencies
-        run: pip install -f https://download.pytorch.org/whl/torch_stable.html .[tests]
+        run: pip install .[tests]
- name: Install linters
run: pip install pydocstyle flake8 flake8-docstrings mypy
- name: Run PyTest
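The PyTorch install is split by operating system because the `+cpu`-suffixed wheels on the PyTorch index are published for Linux and Windows only, while the stock macOS wheels are CPU-only to begin with. A sketch of reproducing the CI environment locally with the same pinned versions:

    # Linux / Windows: CPU-only builds come from the PyTorch wheel index.
    pip install torch==1.7.1+cpu torchvision==0.8.2+cpu \
        -f https://download.pytorch.org/whl/torch_stable.html

    # macOS: the default wheels are already CPU-only.
    pip install torch==1.7.1 torchvision==0.8.2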
23 changes: 23 additions & 0 deletions .github/workflows/publish_docker_image.yml
@@ -0,0 +1,23 @@
name: Publish Docker image
on:
release:
types: [published]

jobs:
push_to_registry:
name: Push Docker image to GitHub Packages
runs-on: ubuntu-latest
steps:
- name: Check out the repo
uses: actions/checkout@v2
- name: Login to Github Container Registry
run: echo ${{ secrets.CR_PAT }} | docker login ghcr.io -u ${{ github.actor }} --password-stdin
- name: Build image
run: |
docker build
-f docker/Dockerfile
--tag ghcr.io/albumentations-team/autoalbument:${{ github.event.release.tag_name }}
--tag ghcr.io/albumentations-team/autoalbument:latest
.
- name: Push image
run: docker push ghcr.io/albumentations-team/autoalbument
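Once a release fires this workflow, the image is available from GitHub Container Registry under the release tag and `latest`. A usage sketch (`0.2.0` stands in for whatever tag string the release actually used):

    docker pull ghcr.io/albumentations-team/autoalbument:latest
    docker pull ghcr.io/albumentations-team/autoalbument:0.2.0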
2 changes: 1 addition & 1 deletion autoalbument/__init__.py
@@ -1 +1 @@
__version__ = "0.1.1"
__version__ = "0.2.0"
4 changes: 2 additions & 2 deletions autoalbument/cli/search.py
@@ -8,7 +8,7 @@

from autoalbument.config.faster_autoaugment import FasterAutoAugmentSearchConfig
from autoalbument.config.validation import validate_cfg
-from autoalbument.faster_autoaugment.search import get_faa_seacher
+from autoalbument.faster_autoaugment.search import get_faa_searcher
from autoalbument.utils.hydra import get_config_dir

OmegaConf.register_resolver("config_dir", get_config_dir)
@@ -37,5 +37,5 @@ def main(cfg):
print(get_prettified_cfg(cfg))
cwd = os.getcwd()
print(f"Working directory: {cwd}")
-    faa_searcher = get_faa_seacher(cfg)
+    faa_searcher = get_faa_searcher(cfg)
faa_searcher.search()
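For context on the `config_dir` resolver registered above: `OmegaConf.register_resolver` binds a named function that configs can invoke through `${name:arg}` interpolation. A minimal standalone sketch with a made-up resolver (unrelated to AutoAlbument's):

    from omegaconf import OmegaConf

    # Hypothetical resolver for illustration only.
    OmegaConf.register_resolver("upper", lambda s: s.upper())
    cfg = OmegaConf.create({"name": "${upper:autoalbument}"})
    print(cfg.name)  # AUTOALBUMENT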
4 changes: 4 additions & 0 deletions autoalbument/cli/templates/classification/search.yaml.tmpl
@@ -32,7 +32,11 @@ policy_model:
# Settings for Classification Model that is used for two purposes:
# 1. As a model that performs classification of input images.
# 2. As a Discriminator for Policy Model.

classification_model:
+  # By default, AutoAlbument uses an instance of `autoalbument.faster_autoaugment.models.ClassificationModel`
+  # as a classification model. This model takes three parameters: `num_classes`, `architecture` and `pretrained`.
+  _target_: autoalbument.faster_autoaugment.models.ClassificationModel

# Number of classes in the dataset. The dataset implementation should return an integer in the range
# [0, num_classes - 1] as a class label of an image.
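Exposing `_target_` in the template means a search config can swap in a custom discriminator without modifying AutoAlbument itself; Hydra passes the sibling keys to the target's constructor. A hypothetical override (`my_project.models.MyClassifier` is a placeholder, not part of the library):

    classification_model:
      # Any class with the same forward contract (class logits plus a
      # discriminator output) can be targeted here.
      _target_: my_project.models.MyClassifier
      num_classes: 10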
@@ -45,6 +45,9 @@ policy_model:
# 1. As a model that performs classification of input images.
# 2. As a Discriminator for Policy Model.
classification_model:
+  # By default, AutoAlbument uses an instance of `autoalbument.faster_autoaugment.models.ClassificationModel`
+  # as a classification model. This model takes three parameters: `num_classes`, `architecture` and `pretrained`.
+  _target_: autoalbument.faster_autoaugment.models.ClassificationModel

# Number of classes in the dataset. The dataset implementation should return an integer in the range
# [0, num_classes - 1] as a class label of an image.
@@ -33,6 +33,10 @@ policy_model:
# 1. As a model that performs semantic segmentation of input images.
# 2. As a Discriminator for Policy Model.
semantic_segmentation_model:
+  # By default, AutoAlbument uses an instance of `autoalbument.faster_autoaugment.models.SemanticSegmentationModel` as
+  # a semantic segmentation model.
+  # This model takes four parameters: `num_classes`, `architecture`, `encoder_architecture` and `pretrained`.
+  _target_: autoalbument.faster_autoaugment.models.SemanticSegmentationModel

# The number of classes in the dataset. The dataset implementation should return a mask as a NumPy array with
# the shape [height, width, num_classes]. In a case of binary segmentation you can set `num_classes` to 1.
@@ -55,6 +59,7 @@ semantic_segmentation_model:
# refer to https://github.com/qubvel/segmentation_models.pytorch#encoders-
pretrained: True


data:
# Class for the PyTorch Dataset and arguments to it. AutoAlbument will create an object of this class using
# the `instantiate` method from Hydra - https://hydra.cc/docs/next/patterns/instantiate_objects/overview/.
@@ -43,6 +43,10 @@ policy_model:
# 1. As a model that performs semantic segmentation of input images.
# 2. As a Discriminator for Policy Model.
semantic_segmentation_model:
+  # By default, AutoAlbument uses an instance of `autoalbument.faster_autoaugment.models.SemanticSegmentationModel` as
+  # a semantic segmentation model.
+  # This model takes four parameters: `num_classes`, `architecture`, `encoder_architecture` and `pretrained`.
+  _target_: autoalbument.faster_autoaugment.models.SemanticSegmentationModel

# The number of classes in the dataset. The dataset implementation should return a mask as a NumPy array with
# the shape [height, width, num_classes]. In a case of binary segmentation you can set `num_classes` to 1.
6 changes: 4 additions & 2 deletions autoalbument/config/faster_autoaugment.py
@@ -46,14 +46,16 @@ class PolicyModelConfig:

@dataclass
class ClassificationModelConfig:
-    num_classes: int = MISSING
+    _target_: str = "autoalbument.faster_autoaugment.models.ClassificationModel"
+    num_classes: Optional[int] = None
architecture: str = "resnet18"
pretrained: bool = False


@dataclass
class SemanticSegmentationModelConfig:
-    num_classes: int = MISSING
+    _target_: str = "autoalbument.faster_autoaugment.models.SemanticSegmentationModel"
+    num_classes: Optional[int] = None
architecture: str = "Unet"
encoder_architecture: str = "resnet18"
pretrained: bool = False
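These `_target_` fields are what Hydra's `instantiate` consumes, which is presumably also why `num_classes` can now default to `None` instead of `MISSING`: a custom target may not need it. A minimal sketch of the mechanism outside a full search run:

    from hydra.utils import instantiate
    from omegaconf import OmegaConf

    from autoalbument.config.faster_autoaugment import ClassificationModelConfig

    # instantiate() imports the `_target_` class and passes the remaining
    # fields (num_classes, architecture, pretrained) as keyword arguments.
    cfg = OmegaConf.structured(ClassificationModelConfig(num_classes=10))
    model = instantiate(cfg)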
59 changes: 59 additions & 0 deletions autoalbument/faster_autoaugment/models.py
@@ -0,0 +1,59 @@
from typing import Tuple

import segmentation_models_pytorch as smp
import timm
from torch import nn, Tensor
from torch.nn import Flatten


class BaseDiscriminator(nn.Module):
def __init__(self, *args, **kwargs):
super().__init__()

def forward(self, input: Tensor) -> Tuple[Tensor, Tensor]:
raise NotImplementedError


class ClassificationModel(BaseDiscriminator):
def __init__(self, architecture, pretrained, num_classes):
super().__init__()
self.base_model = timm.create_model(architecture, pretrained=pretrained)
self.base_model.reset_classifier(num_classes)
self.classifier = self.base_model.get_classifier()
num_features = self.classifier.in_features
self.discriminator = nn.Sequential(
nn.Linear(num_features, num_features), nn.ReLU(), nn.Linear(num_features, 1)
)

def forward(self, input: Tensor) -> Tuple[Tensor, Tensor]:
x = self.base_model.forward_features(input)
x = self.base_model.global_pool(x).flatten(1)
return self.classifier(x), self.discriminator(x).view(-1)


class SemanticSegmentationModel(BaseDiscriminator):
def __init__(self, architecture, encoder_architecture, num_classes, pretrained):
super().__init__()
model = getattr(smp, architecture)

self.base_model = model(
encoder_architecture, encoder_weights=self._get_encoder_weights(pretrained), classes=num_classes
)
num_features = self.base_model.encoder.out_channels[-1]
self.base_model.classification_head = nn.Sequential(
nn.AdaptiveAvgPool2d(1),
Flatten(),
nn.Linear(num_features, num_features),
nn.ReLU(),
nn.Linear(num_features, 1),
)

@staticmethod
def _get_encoder_weights(pretrained):
if isinstance(pretrained, bool):
return "imagenet" if pretrained else None
return pretrained

def forward(self, input: Tensor) -> Tuple[Tensor, Tensor]:
mask, discriminator_output = self.base_model(input)
return mask, discriminator_output.view(-1)
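A quick shape check of the two new models (a sketch; `pretrained=False` avoids weight downloads, and both forwards return a task output plus one discriminator score per image):

    import torch

    from autoalbument.faster_autoaugment.models import ClassificationModel, SemanticSegmentationModel

    clf = ClassificationModel(architecture="resnet18", pretrained=False, num_classes=10)
    logits, scores = clf(torch.randn(2, 3, 224, 224))
    print(logits.shape, scores.shape)  # torch.Size([2, 10]) torch.Size([2])

    seg = SemanticSegmentationModel(architecture="Unet", encoder_architecture="resnet18", num_classes=1, pretrained=False)
    mask, scores = seg(torch.randn(2, 3, 256, 256))
    print(mask.shape, scores.shape)  # torch.Size([2, 1, 256, 256]) torch.Size([2])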
80 changes: 15 additions & 65 deletions autoalbument/faster_autoaugment/search.py
@@ -10,69 +10,22 @@
from typing import Tuple

import albumentations as A
-import timm
import torch
from albumentations.pytorch import ToTensorV2
+from hydra.utils import instantiate
from torch import Tensor, nn
-from torch.nn import Flatten
from torch.utils.tensorboard import SummaryWriter
from tqdm import tqdm

from autoalbument.faster_autoaugment.metrics import get_average_parameter_change
+from autoalbument.faster_autoaugment.models import ClassificationModel, SemanticSegmentationModel
from autoalbument.faster_autoaugment.utils import MAX_VALUES_BY_INPUT_DTYPE, get_dataset_cls, MetricTracker, set_seed
from autoalbument.utils.files import symlink
from autoalbument.faster_autoaugment.policy import Policy
-import segmentation_models_pytorch as smp

log = logging.getLogger(__name__)


-class ClassificationDiscriminator(nn.Module):
-    def __init__(self, architecture, pretrained, num_classes):
-        super().__init__()
-        self.base_model = timm.create_model(architecture, pretrained=pretrained)
-        self.base_model.reset_classifier(num_classes)
-        self.classifier = self.base_model.get_classifier()
-        num_features = self.classifier.in_features
-        self.discriminator = nn.Sequential(
-            nn.Linear(num_features, num_features), nn.ReLU(), nn.Linear(num_features, 1)
-        )
-
-    def forward(self, input: Tensor) -> Tuple[Tensor, Tensor]:
-        x = self.base_model.forward_features(input)
-        x = self.base_model.global_pool(x).flatten(1)
-        return self.classifier(x), self.discriminator(x).view(-1)
-
-
-class SegmentationDiscriminator(nn.Module):
-    def __init__(self, architecture, encoder_architecture, num_classes, pretrained):
-        super().__init__()
-        model = getattr(smp, architecture)
-
-        self.base_model = model(
-            encoder_architecture, encoder_weights=self._get_encoder_weights(pretrained), classes=num_classes
-        )
-        num_features = self.base_model.encoder.out_channels[-1]
-        self.base_model.classification_head = nn.Sequential(
-            nn.AdaptiveAvgPool2d(1),
-            Flatten(),
-            nn.Linear(num_features, num_features),
-            nn.ReLU(),
-            nn.Linear(num_features, 1),
-        )
-
-    @staticmethod
-    def _get_encoder_weights(pretrained):
-        if isinstance(pretrained, bool):
-            return "imagenet" if pretrained else None
-        return pretrained
-
-    def forward(self, input: Tensor) -> Tuple[Tensor, Tensor]:
-        mask, discriminator_output = self.base_model(input)
-        return mask, discriminator_output.view(-1)


class FasterAutoAugmentBase:
def __init__(self, cfg):
self.cfg = cfg
@@ -187,9 +140,15 @@ def create_optimizers(self):
"policy": policy_optimizer,
}

-    def create_main_model(self):
+    def get_main_model_cfg(self):
raise NotImplementedError

+    def create_main_model(self):
+        model_cfg = self.get_main_model_cfg()
+        main_model = instantiate(model_cfg)
+        main_model = main_model.to(self.cfg.device)
+        return main_model

def create_policy_model(self):
policy_model_cfg = self.cfg.policy_model
normalization_cfg = self.cfg.data.normalization
@@ -239,6 +198,7 @@ def wgan_loss(
loss = self.cfg.policy_model.task_factor * self.loss(output, n_target)
loss.backward(retain_graph=True)
d_n_loss = n_output.mean()

d_n_loss.backward(-ones)

with torch.no_grad():
@@ -247,6 +207,7 @@

_, a_output = self.models["main"](augmented)
d_a_loss = a_output.mean()

d_a_loss.backward(ones)
gp = self.cfg.policy_model.gp_factor * self.gradient_penalty(n_input, augmented)
gp.backward()
@@ -392,12 +353,8 @@ def search(self):


class FAAClassification(FasterAutoAugmentBase):
-    def create_main_model(self):
-        model_cfg = self.cfg.classification_model
-        main_model = ClassificationDiscriminator(
-            model_cfg.architecture, num_classes=model_cfg.num_classes, pretrained=model_cfg.pretrained
-        ).to(self.cfg.device)
-        return main_model
+    def get_main_model_cfg(self):
+        return self.cfg.classification_model

def create_loss(self):
return nn.CrossEntropyLoss().to(self.cfg.device)
@@ -408,15 +365,8 @@ def policy_forward_for_policy_train(self, a_input, a_target):


class FAASemanticSegmentation(FasterAutoAugmentBase):
-    def create_main_model(self):
-        model_cfg = self.cfg.semantic_segmentation_model
-        main_model = SegmentationDiscriminator(
-            model_cfg.architecture,
-            encoder_architecture=model_cfg.encoder_architecture,
-            num_classes=model_cfg.num_classes,
-            pretrained=model_cfg.pretrained,
-        ).to(self.cfg.device)
-        return main_model
+    def get_main_model_cfg(self):
+        return self.cfg.semantic_segmentation_model

def create_loss(self):
return nn.BCEWithLogitsLoss().to(self.cfg.device)
@@ -426,7 +376,7 @@ def policy_forward_for_policy_train(self, a_input, a_target):
return output["image_batch"], output["mask_batch"]


-def get_faa_seacher(cfg):
+def get_faa_searcher(cfg):
task = cfg.task
if task == "semantic_segmentation":
return FAASemanticSegmentation(cfg)
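With the factory renamed, programmatic use is a few lines; a sketch (the CLI entry point additionally validates the config and registers the `config_dir` resolver before reaching this code, and `search.yaml` stands for any generated search config):

    from omegaconf import OmegaConf

    from autoalbument.faster_autoaugment.search import get_faa_searcher

    cfg = OmegaConf.load("search.yaml")
    searcher = get_faa_searcher(cfg)  # FAAClassification or FAASemanticSegmentation, per cfg.task
    searcher.search()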
18 changes: 18 additions & 0 deletions docker/Dockerfile
@@ -0,0 +1,18 @@
FROM pytorch/pytorch:1.7.0-cuda11.0-cudnn8-runtime

RUN apt-get update && apt-get install -y \
libgl1-mesa-glx \
libglib2.0-0 \
&& rm -rf /var/lib/apt/lists/*

RUN useradd --create-home --shell /bin/bash --no-log-init autoalbument
USER autoalbument
ENV PATH="/home/autoalbument/.local/bin:${PATH}"
WORKDIR /opt/autoalbument
COPY . .
RUN pip install --no-cache-dir .
COPY docker/entrypoint.sh entrypoint.sh

WORKDIR /autoalbument

ENTRYPOINT ["/opt/autoalbument/entrypoint.sh"]
3 changes: 3 additions & 0 deletions docker/entrypoint.sh
@@ -0,0 +1,3 @@
#!/bin/bash

autoalbument-search --config-dir /config "$@"
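Putting the Dockerfile and entrypoint together, a local build-and-run presumably looks like this (paths illustrative; the entrypoint runs `autoalbument-search --config-dir /config`, so the mounted directory should hold `search.yaml` and, typically, `dataset.py`):

    # Build from the repository root.
    docker build -f docker/Dockerfile -t autoalbument .

    # Arguments after the image name are forwarded to autoalbument-search.
    docker run --rm -v /path/to/experiment/config:/config autoalbument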
4 changes: 2 additions & 2 deletions examples/cifar10/dataset.py
@@ -2,8 +2,8 @@


class SearchDataset(torchvision.datasets.CIFAR10):
-    def __init__(self, transform=None):
-        super().__init__(root="~/data/cifar10", train=True, download=True, transform=transform)
+    def __init__(self, root="~/data/cifar10", train=True, download=True, transform=None):
+        super().__init__(root=root, train=train, download=download, transform=transform)

def __getitem__(self, index):
image, label = self.data[index], self.targets[index]
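Because the constructor arguments are now forwarded, a Hydra config can relocate the dataset without code edits. Assuming the `data.dataset` node layout used by the generated search configs, an override might read:

    data:
      dataset:
        _target_: dataset.SearchDataset
        root: /data/cifar10   # replaces the ~/data/cifar10 default
        download: true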