This commit is contained in:
HamsterMimi
2023-05-04 13:09:03 +08:00
commit 189df25fd3
207 changed files with 242887 additions and 0 deletions

8
.idea/.gitignore generated vendored Normal file
View File

@@ -0,0 +1,8 @@
# Default ignored files
/shelf/
/workspace.xml
# Editor-based HTTP Client requests
/httpRequests/
# Datasource local storage ignored files
/dataSources/
/dataSources.local.xml

8
.idea/MeCo.iml generated Normal file
View File

@@ -0,0 +1,8 @@
<?xml version="1.0" encoding="UTF-8"?>
<module type="PYTHON_MODULE" version="4">
<component name="NewModuleRootManager">
<content url="file://$MODULE_DIR$" />
<orderEntry type="inheritedJdk" />
<orderEntry type="sourceFolder" forTests="false" />
</component>
</module>

126
.idea/deployment.xml generated Normal file
View File

@@ -0,0 +1,126 @@
<?xml version="1.0" encoding="UTF-8"?>
<project version="4">
<component name="PublishConfigData" remoteFilesAllowedToDisappearOnAutoupload="false">
<serverData>
<paths name="ubuntu@172.16.214.100:7712 password">
<serverdata>
<mappings>
<mapping local="$PROJECT_DIR$" web="/" />
</mappings>
</serverdata>
</paths>
<paths name="ubuntu@172.16.214.100:7712 password (10)">
<serverdata>
<mappings>
<mapping local="$PROJECT_DIR$" web="/" />
</mappings>
</serverdata>
</paths>
<paths name="ubuntu@172.16.214.100:7712 password (11)">
<serverdata>
<mappings>
<mapping local="$PROJECT_DIR$" web="/" />
</mappings>
</serverdata>
</paths>
<paths name="ubuntu@172.16.214.100:7712 password (2)">
<serverdata>
<mappings>
<mapping local="$PROJECT_DIR$" web="/" />
</mappings>
</serverdata>
</paths>
<paths name="ubuntu@172.16.214.100:7712 password (3)">
<serverdata>
<mappings>
<mapping local="$PROJECT_DIR$" web="/" />
</mappings>
</serverdata>
</paths>
<paths name="ubuntu@172.16.214.100:7712 password (4)">
<serverdata>
<mappings>
<mapping local="$PROJECT_DIR$" web="/" />
</mappings>
</serverdata>
</paths>
<paths name="ubuntu@172.16.214.100:7712 password (5)">
<serverdata>
<mappings>
<mapping local="$PROJECT_DIR$" web="/" />
</mappings>
</serverdata>
</paths>
<paths name="ubuntu@172.16.214.100:7712 password (6)">
<serverdata>
<mappings>
<mapping local="$PROJECT_DIR$" web="/" />
</mappings>
</serverdata>
</paths>
<paths name="ubuntu@172.16.214.100:7712 password (7)">
<serverdata>
<mappings>
<mapping local="$PROJECT_DIR$" web="/" />
</mappings>
</serverdata>
</paths>
<paths name="ubuntu@172.16.214.100:7712 password (8)">
<serverdata>
<mappings>
<mapping local="$PROJECT_DIR$" web="/" />
</mappings>
</serverdata>
</paths>
<paths name="ubuntu@172.16.214.100:7712 password (9)">
<serverdata>
<mappings>
<mapping local="$PROJECT_DIR$" web="/" />
</mappings>
</serverdata>
</paths>
<paths name="ubuntu@172.16.214.99:7712 password">
<serverdata>
<mappings>
<mapping local="$PROJECT_DIR$" web="/" />
</mappings>
</serverdata>
</paths>
<paths name="ubuntu@172.16.214.99:7712 password (2)">
<serverdata>
<mappings>
<mapping local="$PROJECT_DIR$" web="/" />
</mappings>
</serverdata>
</paths>
<paths name="ubuntu@172.16.214.99:7712 password (3)">
<serverdata>
<mappings>
<mapping local="$PROJECT_DIR$" web="/" />
</mappings>
</serverdata>
</paths>
<paths name="ubuntu@172.16.214.99:7712 password (4)">
<serverdata>
<mappings>
<mapping local="$PROJECT_DIR$" web="/" />
</mappings>
</serverdata>
</paths>
<paths name="ubuntu@172.16.214.99:7712 password (5)">
<serverdata>
<mappings>
<mapping local="$PROJECT_DIR$" web="/" />
</mappings>
</serverdata>
</paths>
<paths name="ubuntu@172.16.214.99:7712 password (6)">
<serverdata>
<mappings>
<mapping local="$PROJECT_DIR$" web="/" />
</mappings>
</serverdata>
</paths>
</serverData>
</component>
</project>

View File

@@ -0,0 +1,26 @@
<component name="InspectionProjectProfileManager">
<profile version="1.0">
<option name="myName" value="Project Default" />
<inspection_tool class="PyPackageRequirementsInspection" enabled="true" level="WARNING" enabled_by_default="true">
<option name="ignoredPackages">
<value>
<list size="6">
<item index="0" class="java.lang.String" itemvalue="netifaces" />
<item index="1" class="java.lang.String" itemvalue="scikit-learn" />
<item index="2" class="java.lang.String" itemvalue="torch" />
<item index="3" class="java.lang.String" itemvalue="Auto-PyTorch" />
<item index="4" class="java.lang.String" itemvalue="torchvision" />
<item index="5" class="java.lang.String" itemvalue="tensorflow-gpu" />
</list>
</value>
</option>
</inspection_tool>
<inspection_tool class="PyUnresolvedReferencesInspection" enabled="true" level="WARNING" enabled_by_default="true">
<option name="ignoredIdentifiers">
<list>
<option value="random.random.choices" />
</list>
</option>
</inspection_tool>
</profile>
</component>

View File

@@ -0,0 +1,6 @@
<component name="InspectionProjectProfileManager">
<settings>
<option name="USE_PROJECT_PROFILE" value="false" />
<version value="1.0" />
</settings>
</component>

4
.idea/misc.xml generated Normal file
View File

@@ -0,0 +1,4 @@
<?xml version="1.0" encoding="UTF-8"?>
<project version="4">
<component name="ProjectRootManager" version="2" project-jdk-name="Remote Python 3.8.16 (sftp://ubuntu@172.16.214.100:7712/jty/anaconda3/envs/meco/bin/python3.8)" project-jdk-type="Python SDK" />
</project>

8
.idea/modules.xml generated Normal file
View File

@@ -0,0 +1,8 @@
<?xml version="1.0" encoding="UTF-8"?>
<project version="4">
<component name="ProjectModuleManager">
<modules>
<module fileurl="file://$PROJECT_DIR$/.idea/MeCo.iml" filepath="$PROJECT_DIR$/.idea/MeCo.iml" />
</modules>
</component>
</project>

6
.idea/vcs.xml generated Normal file
View File

@@ -0,0 +1,6 @@
<?xml version="1.0" encoding="UTF-8"?>
<project version="4">
<component name="VcsDirectoryMappings">
<mapping directory="" vcs="Git" />
</component>
</project>

21
LICENSE Normal file
View File

@@ -0,0 +1,21 @@
MIT License
Copyright (c) 2021 zerocostptnas
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.

175
Layers/layers.py Normal file
View File

@@ -0,0 +1,175 @@
import math
import torch
import torch.nn as nn
from torch.nn import functional as F
from torch.nn import init
from torch.nn.parameter import Parameter
from torch.nn.modules.utils import _pair
class Linear(nn.Linear):
def __init__(self, in_features, out_features, bias=True):
super(Linear, self).__init__(in_features, out_features, bias)
self.register_buffer('weight_mask', torch.ones(self.weight.shape))
self.register_buffer('score', torch.zeros(self.weight.shape))
if self.bias is not None:
self.register_buffer('bias_mask', torch.ones(self.bias.shape))
def forward(self, input):
W = self.weight_mask * self.weight
if self.bias is not None:
b = self.bias_mask * self.bias
else:
b = self.bias
return F.linear(input, W, b)
class Conv2d(nn.Conv2d):
def __init__(self, in_channels, out_channels, kernel_size, stride=1,
padding=0, dilation=1, groups=1,
bias=True, padding_mode='zeros'):
super(Conv2d, self).__init__(
in_channels, out_channels, kernel_size, stride, padding,
dilation, groups, bias, padding_mode)
self.register_buffer('weight_mask', torch.ones(self.weight.shape))
self.register_buffer('score', torch.zeros(self.weight.shape))
if self.bias is not None:
self.register_buffer('bias_mask', torch.ones(self.bias.shape))
def _conv_forward(self, input, weight, bias):
if self.padding_mode != 'zeros':
return F.conv2d(F.pad(input, self._padding_repeated_twice, mode=self.padding_mode),
weight, bias, self.stride,
_pair(0), self.dilation, self.groups)
return F.conv2d(input, weight, bias, self.stride,
self.padding, self.dilation, self.groups)
def forward(self, input):
W = self.weight_mask * self.weight
if self.bias is not None:
b = self.bias_mask * self.bias
else:
b = self.bias
return self._conv_forward(input, W, b)
class BatchNorm1d(nn.BatchNorm1d):
def __init__(self, num_features, eps=1e-5, momentum=0.1, affine=True,
track_running_stats=True):
super(BatchNorm1d, self).__init__(
num_features, eps, momentum, affine, track_running_stats)
if self.affine:
self.register_buffer('weight_mask', torch.ones(self.weight.shape))
self.register_buffer('bias_mask', torch.ones(self.bias.shape))
self.register_buffer('score', torch.zeros(self.weight.shape))
def forward(self, input):
self._check_input_dim(input)
# exponential_average_factor is set to self.momentum
# (when it is available) only so that it gets updated
# in ONNX graph when this node is exported to ONNX.
if self.momentum is None:
exponential_average_factor = 0.0
else:
exponential_average_factor = self.momentum
if self.training and self.track_running_stats:
# TODO: if statement only here to tell the jit to skip emitting this when it is None
if self.num_batches_tracked is not None:
self.num_batches_tracked = self.num_batches_tracked + 1
if self.momentum is None: # use cumulative moving average
exponential_average_factor = 1.0 / float(self.num_batches_tracked)
else: # use exponential moving average
exponential_average_factor = self.momentum
if self.affine:
W = self.weight_mask * self.weight
b = self.bias_mask * self.bias
else:
W = self.weight
b = self.bias
return F.batch_norm(
input, self.running_mean, self.running_var, W, b,
self.training or not self.track_running_stats,
exponential_average_factor, self.eps)
class BatchNorm2d(nn.BatchNorm2d):
def __init__(self, num_features, eps=1e-5, momentum=0.1, affine=True,
track_running_stats=True):
super(BatchNorm2d, self).__init__(
num_features, eps, momentum, affine, track_running_stats)
if self.affine:
self.register_buffer('weight_mask', torch.ones(self.weight.shape))
self.register_buffer('bias_mask', torch.ones(self.bias.shape))
self.register_buffer('score', torch.zeros(self.weight.shape))
def forward(self, input):
self._check_input_dim(input)
# exponential_average_factor is set to self.momentum
# (when it is available) only so that it gets updated
# in ONNX graph when this node is exported to ONNX.
if self.momentum is None:
exponential_average_factor = 0.0
else:
exponential_average_factor = self.momentum
if self.training and self.track_running_stats:
# TODO: if statement only here to tell the jit to skip emitting this when it is None
if self.num_batches_tracked is not None:
self.num_batches_tracked = self.num_batches_tracked + 1
if self.momentum is None: # use cumulative moving average
exponential_average_factor = 1.0 / float(self.num_batches_tracked)
else: # use exponential moving average
exponential_average_factor = self.momentum
if self.affine:
W = self.weight_mask * self.weight
b = self.bias_mask * self.bias
else:
W = self.weight
b = self.bias
return F.batch_norm(
input, self.running_mean, self.running_var, W, b,
self.training or not self.track_running_stats,
exponential_average_factor, self.eps)
class Identity1d(nn.Module):
def __init__(self, num_features):
super(Identity1d, self).__init__()
self.num_features = num_features
self.weight = Parameter(torch.Tensor(num_features))
self.bias = None
self.register_buffer('weight_mask', torch.ones(self.weight.shape))
self.reset_parameters()
self.register_buffer('score', torch.zeros(self.weight.shape))
def reset_parameters(self):
init.ones_(self.weight)
def forward(self, input):
W = self.weight_mask * self.weight
return input * W
class Identity2d(nn.Module):
def __init__(self, num_features):
super(Identity2d, self).__init__()
self.num_features = num_features
self.weight = Parameter(torch.Tensor(num_features, 1, 1))
self.bias = None
self.register_buffer('weight_mask', torch.ones(self.weight.shape))
self.reset_parameters()
self.register_buffer('score', torch.zeros(self.weight.shape))
def reset_parameters(self):
init.ones_(self.weight)
def forward(self, input):
W = self.weight_mask * self.weight
return input * W
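These layers keep their pruning state in registered buffers: `weight_mask` (and `bias_mask`) zero out pruned parameters at forward time, while `score` can hold a per-weight saliency. A minimal sketch of pruning one of the masked `Conv2d` layers defined above by thresholding `score` (the random scores and the 50% keep ratio are illustrative only; it assumes the repo root is on `PYTHONPATH`):
```
import torch
from Layers.layers import Conv2d  # the masked Conv2d defined above

conv = Conv2d(3, 8, kernel_size=3, padding=1, bias=False)

# Fill `score` with an arbitrary saliency and keep roughly the top half of the weights.
conv.score.uniform_()
threshold = conv.score.flatten().kthvalue(conv.score.numel() // 2).values
conv.weight_mask.copy_((conv.score > threshold).float())

x = torch.randn(2, 3, 32, 32)
y = conv(x)  # forward computes weight_mask * weight, so pruned weights contribute nothing
print(y.shape, conv.weight_mask.mean().item())  # roughly half of the weights remain active
```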

143
README.md Normal file
View File

@@ -0,0 +1,143 @@
# Zero-Cost Operation Scoring in Differentiable Architecture Search (Zero-Cost-PT)
Official implementation of our AAAI 2023 submission:
"**Zero-Cost Operation Scoring in Differentiable Architecture Search**".
## Installation
```
Python >= 3.6
PyTorch >= 1.7.1
torchvision == 0.8.2
tensorboard == 2.4.1
scipy == 1.5.2
gpustat
```
## Usage/Examples
### Experiments on NAS-Bench-201
Scripts for reproducing our experiments can be found under the ```exp_scripts/``` folder.
#### 1. Prepare NAS-Bench-201 Data
1. Download the NAS-Bench-201 checkpoint from [NAS-Bench-201-v1_0-e61699.pth](https://drive.google.com/file/d/1SKW0Cu0u8-gb18zDpaAGi0f74UdXeGKs/view) and place it under the ```./data``` folder.
#### 2. Prepare NAS-Bench-201 and Zero-Cost-NAS API
i. Install the NAS-Bench-201 API via `pip`
```
pip install nas-bench-201
```
ii. Install the Zero-Cost-NAS API
Clone the code repository from [Zero-Cost-NAS](https://github.com/SamsungLabs/zero-cost-nas). Go to the root directory of the cloned repo and run:
```
pip install .
```
#### 3. Run Zero-Cost-PT on NAS-Bench-201
You can run our Zero-Cost-PT with the following script:
```
bash zerocostpt_nb201_pipeline.sh --seed [SEED]
```
You can specify the random seed with ```--seed``` for reproducibility. In our experiments we use random seeds 0, 1, 2, 3.
You can also run with different zero-cost proxies by specifying ```--metrics```, and a different edge discretization order with ```--edge_decision```. The number of search iterations (N in our paper) is controlled by ```--pool_size```, while the number of validation iterations (V in our paper) is set by ```--validate_rounds```. Please see Section 4.2 of our paper for more information on these parameters.
For example, a typical experiment setting could be:
```--pool_size 10 --edge_decision random --validate_rounds 100 --metrics jacob --seed 0```
### Experiments on NAS-Bench-1shot1
Scripts for reproducing our experiments on NAS-Bench-1shot1 can be found under the ```nasbench1shot1``` folder.
#### 1. Prepare NAS-Bench-101 Data
1. Download the NAS-Bench-1shot1 dataset from [nasbench_full.tfrecord](https://storage.googleapis.com/nasbench/nasbench_full.tfrecord) and place it under the ```./data``` folder.
#### 2. Prepare NAS-Bench-101 API
Please refer to the original [NAS-Bench-101](https://github.com/google-research/nasbench) repository for details on API installation.
#### 3. Run Zero-Cost-PT on NAS-Bench-1shot1
You can reproduce our Zero-Cost-PT results with the following steps:
```
cd nasbench1shot1/optimizers/darts/
```
i. Run the following script to search for architectures with Zero-Cost-PT in the different sub-search-spaces of NAS-Bench-1shot1:
```
python network_proposal.py --seed [SEED] --search_space [SEARCH_SPACE]
```
NAS-Bench-1shot1 contains 3 sub-search-spaces; select one with a value from [1, 2, 3].
ii. Evaluate the final searched model:
```
python post_validate.py --seed [SEED] --search_exp_path [PATH_to_LAST_STEP_LOG_FOLDER]
```
### Experiments on NAS-Bench-Macro
Scripts for reproducing our experiments on NAS-Bench-Macro can be found under the ```nasbenchmacro``` folder.
#### 1. Prepare NAS-Bench-Macro Data
1. Download the NAS-Bench-Macro dataset from [nas-bench-macro_cifar10.json](https://github.com/xiusu/NAS-Bench-Macro/tree/master/data/nas-bench-macro_cifar10.json) and place it under the ```./data``` folder.
#### 2. Run Zero-Cost-PT on NAS-Bench-Macro
You can reproduce our Zero-Cost-PT results with the following steps:
```
cd nasbenchmacro/
```
i. Run the following script to search for architectures with Zero-Cost-PT on NAS-Bench-Macro:
```
python network_proposal.py --seed [SEED]
```
### Experiments on DARTS-like Spaces
Scripts for reproducing our experiments can be found under the ```exp_scripts/``` folder; the Zero-Cost-NAS API is also required.
#### 1. For DARTS CNN space
Run the following script to search for architectures with Zero-Cost-PT and to train the searched architecture directly (with the same random seed):
```
bash zerocostpt_darts_pipeline.sh --seed [SEED]
```
Our default parameter settings are:
```--pool_size 10 --edge_decision random --validate_rounds 100 --metrics jacob```
#### 2. For DARTS subspaces S1-S4
On CIFAR-10 use the following script:
```
bash zerocostpt_darts_pipeline.sh --seed [SEED] --space [s1-s4]
```
On CIFAR-100 and SVHN, use the following scripts:
```
bash zerocostpt_darts_pipeline_svhn.sh --seed [SEED] --space [s1-s4]
bash zerocostpt_darts_pipeline_c100.sh --seed [SEED] --space [s1-s4]
```
#### 3. Directly train the searched architectures reported in our paper
For reproducibility, we also provide training scripts to evaluate all architectures reported in our paper. For an architecture specified by ```[genotype_name]```, run one of the following scripts to train it:
```
bash eval.sh --arch [genotype_name] # for DARTS C10
bash eval-c100.sh --arch [genotype_name] # for DARTS C100
bash eval-svhn.sh --arch [genotype_name] # for DARTS SVHN
```
The model genotypes are provided in ```sota/cnn/genotypes.py```. For instance, genotype `init_pt_s5_C10_0_100_N10` specifies the architecture searched by Zero-Cost-PT (with default settings as explained above) on the DARTS CNN space (S5), using 10 search iterations (N=10), 100 validation iterations (V=100), and random seed 0.
### Other Experiments Reported in Appendix
We also provide code to reproduce the experimental results reported in the appendix, e.g. genotypes for the maximum-param and random-sampling baselines, and Zero-Cost-PT on MobileNet-like spaces.
## Reference
Our code (Zero-Cost-PT) is based on [darts-pt](https://github.com/ruocwang/darts-pt) and [Zero-Cost-NAS](https://github.com/SamsungLabs/zero-cost-nas). Our experiments on NAS-Bench-1shot1 are based on [nasbench-1shot1](https://github.com/automl/nasbench-1shot1).

47
Scorers/scorer.py Normal file
View File

@@ -0,0 +1,47 @@
import torch
import numpy as np
class Jocab_Scorer:
def __init__(self, gpu):
self.gpu = gpu
print('Jacob score init')
def score(self, model, input, target):
batch_size = input.shape[0]
model.K = torch.zeros(batch_size, batch_size).cuda()
input = input.cuda()
with torch.no_grad():
model(input)
score = self.hooklogdet(model.K.cpu().numpy())
#print(score)
return score
def setup_hooks(self, model, batch_size):
# initialize the score matrix and move the model to the GPU
model = model.to(torch.device('cuda', self.gpu))
model.eval()
model.K = torch.zeros(batch_size, batch_size).cuda()
def counting_forward_hook(module, inp, out):
try:
# if not module.visited_backwards:
# return
if isinstance(inp, tuple):
inp = inp[0]
inp = inp.view(inp.size(0), -1)
x = (inp > 0).float()
K = x @ x.t()
K2 = (1.-x) @ (1.-x.t())
model.K = model.K + K + K2
except:
pass
for name, module in model.named_modules():
if 'ReLU' in str(type(module)):
module.register_forward_hook(counting_forward_hook)
#module.register_backward_hook(counting_backward_hook)
def hooklogdet(self, K, labels=None):
s, ld = np.linalg.slogdet(K)
return ld
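`Jocab_Scorer` implements a NASWOT-style proxy: the forward hooks accumulate a batch-by-batch kernel `K` from the binary ReLU activation patterns, and `hooklogdet` returns `logdet(K)`, so more diverse activation patterns yield a higher score. A minimal usage sketch (the toy `nn.Sequential` model is illustrative only; the class hard-codes `.cuda()`, so a GPU is assumed and the repo root is on `PYTHONPATH`):
```
import torch
import torch.nn as nn
from Scorers.scorer import Jocab_Scorer

# Illustrative network with ReLU activations for the hooks to attach to.
model = nn.Sequential(
    nn.Conv2d(3, 16, 3, padding=1), nn.ReLU(),
    nn.Conv2d(16, 16, 3, padding=1), nn.ReLU(),
    nn.AdaptiveAvgPool2d(1), nn.Flatten(), nn.Linear(16, 10))

scorer = Jocab_Scorer(gpu=0)
x = torch.randn(16, 3, 32, 32)           # one mini-batch of dummy inputs
y = torch.randint(0, 10, (16,))          # labels are accepted but unused by this proxy

scorer.setup_hooks(model, batch_size=x.shape[0])  # moves the model to cuda:0 and registers ReLU hooks
score = scorer.score(model, x, y)
print(score)
```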

39
exp_scripts/eval-c100.sh Normal file
View File

@@ -0,0 +1,39 @@
#!/bin/bash
script_name=`basename "$0"`
id=${script_name%.*}
dataset=${dataset:-cifar100}
seed=${seed:-2}
gpu=${gpu:-"auto"}
arch=${arch:-"none"}
batch_size=${batch_size:-96}
learning_rate=${learning_rate:-0.025}
resume_expid=${resume_expid:-'none'}
resume_epoch=${resume_epoch:-0}
while [ $# -gt 0 ]; do
if [[ $1 == *"--"* ]]; then
param="${1/--/}"
declare $param="$2"
# echo $1 $2 // Optional to see the parameter:value result
fi
shift
done
echo 'id:' $id
echo 'seed:' $seed
echo 'dataset:' $dataset
echo 'gpu:' $gpu
echo 'arch:' $arch
echo 'batch_size:' $batch_size
echo 'learning_rate:' $learning_rate
cd ../sota/cnn
python train.py \
--arch $arch \
--dataset $dataset \
--auxiliary --cutout \
--seed $seed --save $id --gpu $gpu \
--batch_size $batch_size --learning_rate $learning_rate \
--resume_expid $resume_expid --resume_epoch $resume_epoch \
--init_channels 16 --layers 8 \

38
exp_scripts/eval.sh Normal file
View File

@@ -0,0 +1,38 @@
#!/bin/bash
script_name=`basename "$0"`
id=${script_name%.*}
dataset=${dataset:-cifar10}
seed=${seed:-2}
gpu=${gpu:-"auto"}
arch=${arch:-"none"}
batch_size=${batch_size:-96}
learning_rate=${learning_rate:-0.025}
resume_expid=${resume_expid:-'none'}
resume_epoch=${resume_epoch:-0}
while [ $# -gt 0 ]; do
if [[ $1 == *"--"* ]]; then
param="${1/--/}"
declare $param="$2"
# echo $1 $2 // Optional to see the parameter:value result
fi
shift
done
echo 'id:' $id
echo 'seed:' $seed
echo 'dataset:' $dataset
echo 'gpu:' $gpu
echo 'arch:' $arch
echo 'batch_size:' $batch_size
echo 'learning_rate:' $learning_rate
cd ../sota/cnn
python3 train.py \
--arch $arch \
--dataset $dataset \
--auxiliary --cutout \
--seed $seed --save $id --gpu $gpu \
--batch_size $batch_size --learning_rate $learning_rate \
--resume_expid $resume_expid --resume_epoch $resume_epoch \

View File

@@ -0,0 +1,51 @@
#!/bin/bash
script_name=`basename "$0"`
id=${script_name%.*}
dataset=${dataset:-cifar10}
seed=${seed:-1}
gpu=${gpu:-"0"}
pool_size=${pool_size:-10}
space=${space:-s5}
metric=${metric:-'jacob'}
edge_decision=${edge_decision:-'random'}
validate_rounds=${validate_rounds:-100}
learning_rate=${learning_rate:-0.025}
while [ $# -gt 0 ]; do
if [[ $1 == *"--"* ]]; then
param="${1/--/}"
declare $param="$2"
# echo $1 $2 // Optional to see the parameter:value result
fi
shift
done
echo 'id:' $id 'seed:' $seed 'dataset:' $dataset 'space:' $space
echo 'proj crit:' $metric
echo 'gpu:' $gpu
cd ../sota/cnn
python3 networks_proposal.py \
--search_space $space --dataset $dataset \
--batch_size 64 \
--seed $seed --save $id --gpu $gpu \
--edge_decision $edge_decision \
--proj_crit_normal $metric --proj_crit_reduce $metric --proj_crit_edge $metric \
--pool_size $pool_size\
cd ../../zerocostnas/
python3 post_validate.py\
--ckpt_path ../experiments/sota/$dataset-search-$id-$space-$seed-$pool_size-$metric\
--save $id --seed $seed --gpu $gpu\
--batch_size 64\
--edge_decision $edge_decision --proj_crit $metric \
--validate_rounds $validate_rounds\
cd ../sota/cnn
python3 train.py \
--seed $seed --gpu $gpu --save $id \
--arch ../../experiments/sota/$space-valid-$id-$seed-$pool_size-$validate_rounds-$metric\
--dataset $dataset \
--auxiliary --cutout \
--batch_size 96 --learning_rate $learning_rate \
--init_channels 36 --layers 20\
--from_dir\

View File

@@ -0,0 +1,50 @@
#!/bin/bash
script_name=`basename "$0"`
id=${script_name%.*}
dataset=${dataset:-cifar100}
seed=${seed:-2}
gpu=${gpu:-"auto"}
pool_size=${pool_size:-100}
space=${space:-s5}
metric=${metric:-'jacob'}
edge_decision=${edge_decision:-'random'}
validate_rounds=${validate_rounds:-100}
learning_rate=${learning_rate:-0.025}
while [ $# -gt 0 ]; do
if [[ $1 == *"--"* ]]; then
param="${1/--/}"
declare $param="$2"
# echo $1 $2 // Optional to see the parameter:value result
fi
shift
done
echo 'id:' $id 'seed:' $seed 'dataset:' $dataset 'space:' $space
echo 'proj crit:' $metric
echo 'gpu:' $gpu
cd ../sota/cnn
python3 networks_proposal.py \
--search_space $space --dataset $dataset --batch_size 64 \
--seed $seed --save $id --gpu $gpu \
--edge_decision $edge_decision \
--proj_crit_normal $metric --proj_crit_reduce $metric --proj_crit_edge $metric \
--pool_size $pool_size\
cd ../zerocostnas/
python3 post_validate.py\
--ckpt_path ../experiments/sota/$dataset-search-$id-$space-$seed-$pool_size-$metric\
--save $id --seed $seed --gpu $gpu\
--edge_decision $edge_decision --proj_crit $metric \
--batch_size 64\
--validate_rounds $validate_rounds\
cd ../sota/cnn
python3 train.py \
--seed $seed --gpu $gpu --save $id\
--arch ../../experiments/sota/$space-valid-$id-$seed-$pool_size-$validate_rounds-$dataset-$metric\
--dataset $dataset \
--auxiliary --cutout \
--batch_size 96 --learning_rate $learning_rate \
--init_channels 16 --layers 20 \
--from_dir\

View File

@@ -0,0 +1,44 @@
#!/bin/bash
script_name=`basename "$0"`
id=${script_name%.*}
dataset=${dataset:-imagenet}
seed=${seed:-2}
gpu=${gpu:-"auto"}
pool_size=${pool_size:-10}
space=${space:-s5}
metric=${metric:-'jacob'}
edge_decision=${edge_decision:-'random'}
validate_rounds=${validate_rounds:-100}
learning_rate=${learning_rate:-0.025}
data=${data:-''}
while [ $# -gt 0 ]; do
if [[ $1 == *"--"* ]]; then
param="${1/--/}"
declare $param="$2"
# echo $1 $2 // Optional to see the parameter:value result
fi
shift
done
echo 'id:' $id 'seed:' $seed 'dataset:' $dataset 'space:' $space
echo 'proj crit:' $metric
echo 'gpu:' $gpu
cd ../sota/cnn
python3 networks_proposal.py \
--search_space $space --dataset $dataset \
--seed $seed --save $id --gpu $gpu \
--edge_decision $edge_decision \
--proj_crit_normal $metric --proj_crit_reduce $metric --proj_crit_edge $metric \
--pool_size $pool_size\
--data $data\
cd ../../zerocostnas/
python3 post_validate.py\
--ckpt_path ../experiments/sota/$dataset-search-$id-$space-$seed-$pool_size\
--save $id --seed $seed --gpu $gpu\
--edge_decision $edge_decision --proj_crit $metric \
--batch_size 64\
--validate_rounds $validate_rounds\
--data $data\

File diff suppressed because one or more lines are too long

File diff suppressed because one or more lines are too long

File diff suppressed because one or more lines are too long

File diff suppressed because it is too large Load Diff

Binary file not shown.

Binary file not shown.

File diff suppressed because it is too large Load Diff

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

File diff suppressed because one or more lines are too long

View File

@@ -0,0 +1,96 @@
edge\op,none,skip_connect,nor_conv_1x1,nor_conv_3x3,avg_pool_3x3,
acc,77.36±22.55,83.81±10.64,86.38±9.1, 87.32±9.17,81.02±11.85,
discrete acc darts,83.42,84.1,72,76.35,39.66,
dartspt,85.43,17.02,78.13,59.09,85.34,
zc pt,3455.233646,3449.898772,3449.538363,3441.815563,3461.179476,3
discrete zc,3331.007285,3445.489455,3366.877065,3437.551079,3423.180255,
alpha,0.0758,0.7742,0.05,0.0761,0.024,
best-acc,94.15,94.18,94.44,94.68,93.86,
alpha-60,0.1387,0.4758,0.1296,0.181,0.0748,
tenaspt,38.5,48.0,31.0,6.0,37.5,
synflow,1.9286723850908796e+31,7.990282869734622e+30,1.2421187150331997e+30,9.438907569335487e+26,8.191504786187086e+30,
synflow_disc,4.639162000716631e+21, 1.4975281050055959e+26, 4.2221622054263117e+30, 1.9475517523688712e+36, 1.5075022033622535e+26,
best_nwot,1702.1967536035393,1773.1779654806287, 1793.8140278364453, 1792.8682630835763, 1761.1262357119376,
best_synflow,5.784248799475683e+39, 1.4769546208886953e+44, 6.658953754065702e+49, 5.1987025703231504e+39, 1.9928388494681343e+35,
zc-pt-post,3067.0476, 3055.9404, 3059.8901, 3060.4536, 3073.5583,
zc-disc-post,2942.267, 3068.6416, 3009.5847, 3028.1794, 3031.4248,
,,,,,,
acc,80.03 ±19.38,83.11 ±12.81,85.23 ±11.0,85.99±11.1,81.52±13.06,
discrete acc darts,85.12,83.39,76.72,81.34,84.38,
dartspt,85.52,36.1,84.39,80.95,85.49,
zc pt,3452.145851,3448.696318,3441.809174,3440.652631,3453.739943,3
discrete zc,3429.074707,3435.750274,3407.872847,3434.584512,3421.442414,
alpha,0.0629,0.813,0.039,0.0579,0.0269,
best-acc,94.24,94.16,94.49,94.68,94.09,
alpha-60,0.1236,0.5535,0.11,0.1249,0.088,
tenaspt,7.0,55.0,10.0,15.0,39.0,
synflow,3.116079880492518e+30,2.5018418732419554e+30,1.4274537256246266e+30,3.138287824323275e+29,2.5693894962958226e+30,
synflow_disc,5.615386425664938e+28, 2.340336657109326e+29, 1.9258305801684058e+30, 3.012759514473006e+32, 2.2897138361934977e+29,
best_nwot,1765.3743820515451, 1770.8436009213751, 1791.917305624048, 1793.8140278364453, 1763.877253730585,
best_synflow,1.9424580089849912e+49, 2.764587447411338e+49, 6.658953754065702e+49, 2.0353792445711388e+49, 1.4435653786128956e+49,
zc-pt-post,3067.0476, 3058.9197, 3048.8745, 3051.2664, 3066.668,
zc-disc-post,3020.0203, 3052.1936, 3026.2217, 3022.0935, 3029.2,
,,,,,,
acc,82.9±14.68,82.44 ±14.25,84.05 ±13.19,84.49 ±13.21,81.98 ±14.54,
discrete acc darts,85.96,85.18,54.02,78.41,84.88,
dartspt,85.51,80.29,81.86,77.68,85.32,
zc pt,3446.521007,3447.612434,3435.455206,3436.396744,3449.275466,2
discrete zc,3428.795464,3423.361285,3440.925616,3437.286935,3416.891544,
alpha,0.3339,0.4742,0.061,0.0774,0.0534,
best-acc,94.25,94.43,94.49,94.68,94.19,
alpha-60,0.2403,0.3297,0.1495,0.1748,0.1056,
tenaspt,31.5,10.0,30.0,16.5,36.5,
synflow,1.0312338471669537e+31,4.9191575008661263e+30,1.4241158958667068e+30,1.0282498082879338e+28,5.038622256524752e+30,
synflow_disc,1.6980829611704765e+25, 3.3199508659283994e+27, 3.3825056097270114e+30, 1.2059727722928161e+35, 3.279653417965715e+27,
best_nwot,1764.51075805859,1764.116749555202, 1793.8140278364453, 1792.8239766388833, 1764.1848313456592,
best_synflow,8.376122028137071e+41, 1.0615041036082487e+45, 6.658953754065702e+49, 8.399427750574918e+41, 2.5270360875229e+39,
zc-pt-post,3067.0476, 3068.708, 3056.3506, 3047.9695, 3071.3577,
zc-disc-post,3044.023, 3033.0627, 3032.825, 3052.0688, 3023.2302,
,,,,,,
acc,74.02 ±26.1,85.17 ±7.59,87.3 ±2.48,88.28 ±2.06,81.38 ±8.91,
discrete acc darts,66.18,85.38,78.8,81.59,82.89,
dartspt,85.49,9.86,81.79,59.18,85.48,
zc pt,3453.805194,3435.985406,3444.044047,3445.595326,3447.067855,1
discrete zc,3408.990502,3464.050741,3359.888463,3382.1755,3431.805571,
alpha,0.0267,0.8163,0.0471,0.0904,0.0194,
best-acc,94.16,94.68,94.03,94.04,93.85,
alpha-60,0.0636,0.6513,0.0826,0.1335,0.0691,
tenaspt,34.0,44.0,53.5,23.0,30.0,
synflow,2.0042808467776213e+30,1.9513599734483263e+30,1.5188352495143643e+30,7.704103863066581e+29,1.9536326167605112e+30,
synflow_disc,4.3050000047616484e+29, 7.635399455155384e+29, 1.5949429556375966e+30, 1.4519088590209463e+31, 7.345232988374157e+29,
best_nwot,1766.5481959337162,1769.1683503033412, 1793.8140278364453, 1792.8682630835763, 1765.1445530390838,
best_synflow,5.90523769961745e+49, 6.344766865099622e+49, 6.571181309028854e+49, 6.57509920946309e+49, 6.658953754065702e+49,
zc-pt-post,3067.0476, 3032.6658, 3058.9646, 3059.2861, 3047.1965,
zc-disc-post,2975.976, 3130.7397, 3008.5625, 3009.341, 3086.3398,
,,,,,,
acc,80.14 ±19.52,83.05 ±12.74,85.09 ±11.17,85.7 ±11.18,81.89 ±12.98,
discrete acc darts,86.44,84.75,80.23,80.46,80.13,
dartspt,85.45,51.15,78.84,64.64,85.14,
zc pt,3451.055723,3449.796894,3442.625354,3441.131751,3453.311493,3
discrete zc,3433.98773,3435.573458,3424.470031,3431.143217,3423.153213,
alpha,0.0857,0.7082,0.0716,0.0946,0.0399,
best-acc,94.29,94.18,94.56,94.68,94.23,
alpha-60,0.1183,0.48,0.1305,0.1732,0.0979,
tenaspt,32.0,32.5,36.5,32.0,52.0,
synflow,3.165975343348193e+30,2.4302742111297496e+30,1.4853908452542004e+30,2.868307126123347e+29,2.6891361283692336e+30,
synflow_disc,5.5202846896598e+28, 2.4896852024898197e+29, 2.1810394989246777e+30, 2.9482018739806336e+32, 2.2732178076450144e+29,
best_nwot, 1752.024924623228,1793.8140278364453, 1786.3402409418215, 1785.0294182838636, 1781.9741301640186,
best_synflow,1.8865959738805548e+49, 2.593134717306188e+49, 6.658953754065702e+49, 2.021273089103704e+49, 1.6187260144154453e+49,
zc-pt-post,3067.0476, 3060.9983, 3057.1006, 3054.3428, 3066.2087,
zc-disc-post,3037.8726, 3055.4219, 3027.6638, 3024.3271, 3037.8108,
,,,,,,
acc,77.61 ±22.16,83.43 ±11.34,86.18 ±9.08,86.95 ±9.02,81.74 ±11.79,
discrete acc darts,86.28,82.69,77.13,76.8,81.99,
dartspt,85.54,32.43,81.04,72.75,85.51,
zc pt,3450.967554,3448.211459,3440.79926,3443.240243,3452.989921,2
discrete zc,3434.421701,3437.661196,3418.572637,3397.51709,3424.166157,
alpha,0.1554,0.7029,0.0538,0.0598,0.028,
best-acc,94.05,94.16,94.68,94.56,94.1,
alpha-60,0.1648,0.4853,0.1223,0.1397,0.088,
tenaspt,38.5,16.0,20.0,17.0,27.5,
synflow,1.9460309216168614e+31,8.014786854561015e+30,1.1851807660289746e+30,8.96867143875011e+26,7.75842932776677e+30,
synflow_disc,4.777733726551756e+21, 1.4572459237815469e+26, 3.8590321292364994e+30, 1.8898449210848245e+36, 1.5222938895812008e+26,
best_nwot, 1761.8789642379636,1769.103803678444, 1793.8140278364453, 1792.8239766388833, 1761.9145207476113,
best_synflow,5.776473195639679e+39, 1.4672616553030765e+44, 6.658953754065702e+49, 5.480408193532999e+39, 1.9606567871518125e+35,
zc-pt-post,3067.0476, 3063.1135, 3058.818, 3064.5405, 3070.7593,
zc-disc-post,3061.6133, 3063.294, 3038.05, 3012.938, 3042.5535,

15614
notebooks_201/parse_log.ipynb Normal file

File diff suppressed because one or more lines are too long

File diff suppressed because one or more lines are too long

136
pycls/core/benchmark.py Normal file
View File

@@ -0,0 +1,136 @@
#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
"""Benchmarking functions."""
import pycls.core.logging as logging
import pycls.datasets.loader as loader
import torch
from pycls.core.config import cfg
from pycls.core.timer import Timer
logger = logging.get_logger(__name__)
@torch.no_grad()
def compute_time_eval(model):
"""Computes precise model forward test time using dummy data."""
# Use eval mode
model.eval()
# Generate a dummy mini-batch and copy data to GPU
im_size, batch_size = cfg.TRAIN.IM_SIZE, int(cfg.TEST.BATCH_SIZE / cfg.NUM_GPUS)
if cfg.TASK == "jig":
inputs = torch.rand(batch_size, cfg.JIGSAW_GRID ** 2, cfg.MODEL.INPUT_CHANNELS, im_size, im_size).cuda(non_blocking=False)
else:
inputs = torch.zeros(batch_size, cfg.MODEL.INPUT_CHANNELS, im_size, im_size).cuda(non_blocking=False)
# Compute precise forward pass time
timer = Timer()
total_iter = cfg.PREC_TIME.NUM_ITER + cfg.PREC_TIME.WARMUP_ITER
for cur_iter in range(total_iter):
# Reset the timers after the warmup phase
if cur_iter == cfg.PREC_TIME.WARMUP_ITER:
timer.reset()
# Forward
timer.tic()
model(inputs)
torch.cuda.synchronize()
timer.toc()
return timer.average_time
def compute_time_train(model, loss_fun):
"""Computes precise model forward + backward time using dummy data."""
# Use train mode
model.train()
# Generate a dummy mini-batch and copy data to GPU
im_size, batch_size = cfg.TRAIN.IM_SIZE, int(cfg.TRAIN.BATCH_SIZE / cfg.NUM_GPUS)
if cfg.TASK == "jig":
inputs = torch.rand(batch_size, cfg.JIGSAW_GRID ** 2, cfg.MODEL.INPUT_CHANNELS, im_size, im_size).cuda(non_blocking=False)
else:
inputs = torch.rand(batch_size, cfg.MODEL.INPUT_CHANNELS, im_size, im_size).cuda(non_blocking=False)
if cfg.TASK in ['col', 'seg']:
labels = torch.zeros(batch_size, im_size, im_size, dtype=torch.int64).cuda(non_blocking=False)
else:
labels = torch.zeros(batch_size, dtype=torch.int64).cuda(non_blocking=False)
# Cache BatchNorm2D running stats
bns = [m for m in model.modules() if isinstance(m, torch.nn.BatchNorm2d)]
bn_stats = [[bn.running_mean.clone(), bn.running_var.clone()] for bn in bns]
# Compute precise forward backward pass time
fw_timer, bw_timer = Timer(), Timer()
total_iter = cfg.PREC_TIME.NUM_ITER + cfg.PREC_TIME.WARMUP_ITER
for cur_iter in range(total_iter):
# Reset the timers after the warmup phase
if cur_iter == cfg.PREC_TIME.WARMUP_ITER:
fw_timer.reset()
bw_timer.reset()
# Forward
fw_timer.tic()
preds = model(inputs)
if isinstance(preds, tuple):
loss = loss_fun(preds[0], labels) + cfg.NAS.AUX_WEIGHT * loss_fun(preds[1], labels)
preds = preds[0]
else:
loss = loss_fun(preds, labels)
torch.cuda.synchronize()
fw_timer.toc()
# Backward
bw_timer.tic()
loss.backward()
torch.cuda.synchronize()
bw_timer.toc()
# Restore BatchNorm2D running stats
for bn, (mean, var) in zip(bns, bn_stats):
bn.running_mean, bn.running_var = mean, var
return fw_timer.average_time, bw_timer.average_time
def compute_time_loader(data_loader):
"""Computes loader time."""
timer = Timer()
loader.shuffle(data_loader, 0)
data_loader_iterator = iter(data_loader)
total_iter = cfg.PREC_TIME.NUM_ITER + cfg.PREC_TIME.WARMUP_ITER
total_iter = min(total_iter, len(data_loader))
for cur_iter in range(total_iter):
if cur_iter == cfg.PREC_TIME.WARMUP_ITER:
timer.reset()
timer.tic()
next(data_loader_iterator)
timer.toc()
return timer.average_time
def compute_time_full(model, loss_fun, train_loader, test_loader):
"""Times model and data loader."""
logger.info("Computing model and loader timings...")
# Compute timings
test_fw_time = compute_time_eval(model)
train_fw_time, train_bw_time = compute_time_train(model, loss_fun)
train_fw_bw_time = train_fw_time + train_bw_time
train_loader_time = compute_time_loader(train_loader)
# Output iter timing
iter_times = {
"test_fw_time": test_fw_time,
"train_fw_time": train_fw_time,
"train_bw_time": train_bw_time,
"train_fw_bw_time": train_fw_bw_time,
"train_loader_time": train_loader_time,
}
logger.info(logging.dump_log_data(iter_times, "iter_times"))
# Output epoch timing
epoch_times = {
"test_fw_time": test_fw_time * len(test_loader),
"train_fw_time": train_fw_time * len(train_loader),
"train_bw_time": train_bw_time * len(train_loader),
"train_fw_bw_time": train_fw_bw_time * len(train_loader),
"train_loader_time": train_loader_time * len(train_loader),
}
logger.info(logging.dump_log_data(epoch_times, "epoch_times"))
# Compute data loader overhead (assuming DATA_LOADER.NUM_WORKERS>1)
overhead = max(0, train_loader_time - train_fw_bw_time) / train_fw_bw_time
logger.info("Overhead of data loader is {:.2f}%".format(overhead * 100))

88
pycls/core/builders.py Normal file
View File

@@ -0,0 +1,88 @@
#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
"""Model and loss construction functions."""
import torch
from pycls.core.config import cfg
from pycls.models.anynet import AnyNet
from pycls.models.effnet import EffNet
from pycls.models.regnet import RegNet
from pycls.models.resnet import ResNet
from pycls.models.nas.nas import NAS
from pycls.models.nas.nas_search import NAS_Search
from pycls.models.nas_bench.model_builder import NAS_Bench
class LabelSmoothedCrossEntropyLoss(torch.nn.Module):
"""CrossEntropyLoss with label smoothing."""
def __init__(self):
super(LabelSmoothedCrossEntropyLoss, self).__init__()
self.eps = cfg.MODEL.LABEL_SMOOTHING_EPS
self.num_classes = cfg.MODEL.NUM_CLASSES
def forward(self, logits, target):
pred = logits.log_softmax(dim=-1)
with torch.no_grad():
target_dist = torch.ones_like(pred) * self.eps / (self.num_classes - 1)
target_dist.scatter_(-1, target.unsqueeze(-1), 1 - self.eps)
return (-target_dist * pred).sum(dim=-1).mean()
# Supported models
_models = {
"anynet": AnyNet,
"effnet": EffNet,
"resnet": ResNet,
"regnet": RegNet,
"nas": NAS,
"nas_search": NAS_Search,
"nas_bench": NAS_Bench,
}
# Supported loss functions
_loss_funs = {
"cross_entropy": torch.nn.CrossEntropyLoss,
"label_smoothed_cross_entropy": LabelSmoothedCrossEntropyLoss,
}
def get_model():
"""Gets the model class specified in the config."""
err_str = "Model type '{}' not supported"
assert cfg.MODEL.TYPE in _models.keys(), err_str.format(cfg.MODEL.TYPE)
return _models[cfg.MODEL.TYPE]
def get_loss_fun():
"""Gets the loss function class specified in the config."""
err_str = "Loss function type '{}' not supported"
assert cfg.MODEL.LOSS_FUN in _loss_funs.keys(), err_str.format(cfg.MODEL.LOSS_FUN)
return _loss_funs[cfg.MODEL.LOSS_FUN]
def build_model():
"""Builds the model."""
return get_model()()
def build_loss_fun():
"""Build the loss function."""
if cfg.TASK == "seg":
return get_loss_fun()(ignore_index=255)
else:
return get_loss_fun()()
def register_model(name, ctor):
"""Registers a model dynamically."""
_models[name] = ctor
def register_loss_fun(name, ctor):
"""Registers a loss function dynamically."""
_loss_funs[name] = ctor
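`_models` and `_loss_funs` are simple registries keyed by `cfg.MODEL.TYPE` and `cfg.MODEL.LOSS_FUN`, and `register_model`/`register_loss_fun` let external code plug in new constructors without editing this file. A hedged sketch of registering a custom model (the `TinyNet` class and the `"tiny"` key are made up for illustration, and it assumes the config has not been frozen):
```
import torch.nn as nn
import pycls.core.builders as builders
from pycls.core.config import cfg

class TinyNet(nn.Module):
    """Illustrative model; any constructor taking no arguments fits the registry."""
    def __init__(self):
        super().__init__()
        self.net = nn.Sequential(nn.Flatten(), nn.Linear(3 * 32 * 32, cfg.MODEL.NUM_CLASSES))
    def forward(self, x):
        return self.net(x)

builders.register_model("tiny", TinyNet)  # add the constructor to the _models registry
cfg.MODEL.TYPE = "tiny"                   # select it through the config
model = builders.build_model()            # -> TinyNet()
loss_fun = builders.build_loss_fun()      # default LOSS_FUN -> torch.nn.CrossEntropyLoss()
```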

98
pycls/core/checkpoint.py Normal file
View File

@@ -0,0 +1,98 @@
#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
"""Functions that handle saving and loading of checkpoints."""
import os
import pycls.core.distributed as dist
import torch
from pycls.core.config import cfg
# Common prefix for checkpoint file names
_NAME_PREFIX = "model_epoch_"
# Checkpoints directory name
_DIR_NAME = "checkpoints"
def get_checkpoint_dir():
"""Retrieves the location for storing checkpoints."""
return os.path.join(cfg.OUT_DIR, _DIR_NAME)
def get_checkpoint(epoch):
"""Retrieves the path to a checkpoint file."""
name = "{}{:04d}.pyth".format(_NAME_PREFIX, epoch)
return os.path.join(get_checkpoint_dir(), name)
def get_last_checkpoint():
"""Retrieves the most recent checkpoint (highest epoch number)."""
checkpoint_dir = get_checkpoint_dir()
# Checkpoint file names are in lexicographic order
checkpoints = [f for f in os.listdir(checkpoint_dir) if _NAME_PREFIX in f]
last_checkpoint_name = sorted(checkpoints)[-1]
return os.path.join(checkpoint_dir, last_checkpoint_name)
def has_checkpoint():
"""Determines if there are checkpoints available."""
checkpoint_dir = get_checkpoint_dir()
if not os.path.exists(checkpoint_dir):
return False
return any(_NAME_PREFIX in f for f in os.listdir(checkpoint_dir))
def save_checkpoint(model, optimizer, epoch):
"""Saves a checkpoint."""
# Save checkpoints only from the master process
if not dist.is_master_proc():
return
# Ensure that the checkpoint dir exists
os.makedirs(get_checkpoint_dir(), exist_ok=True)
# Omit the DDP wrapper in the multi-gpu setting
sd = model.module.state_dict() if cfg.NUM_GPUS > 1 else model.state_dict()
# Record the state
if isinstance(optimizer, list):
checkpoint = {
"epoch": epoch,
"model_state": sd,
"optimizer_w_state": optimizer[0].state_dict(),
"optimizer_a_state": optimizer[1].state_dict(),
"cfg": cfg.dump(),
}
else:
checkpoint = {
"epoch": epoch,
"model_state": sd,
"optimizer_state": optimizer.state_dict(),
"cfg": cfg.dump(),
}
# Write the checkpoint
checkpoint_file = get_checkpoint(epoch + 1)
torch.save(checkpoint, checkpoint_file)
return checkpoint_file
def load_checkpoint(checkpoint_file, model, optimizer=None):
"""Loads the checkpoint from the given file."""
err_str = "Checkpoint '{}' not found"
assert os.path.exists(checkpoint_file), err_str.format(checkpoint_file)
# Load the checkpoint on CPU to avoid GPU mem spike
checkpoint = torch.load(checkpoint_file, map_location="cpu")
# Account for the DDP wrapper in the multi-gpu setting
ms = model.module if cfg.NUM_GPUS > 1 else model
ms.load_state_dict(checkpoint["model_state"])
# Load the optimizer state (commonly not done when fine-tuning)
if optimizer:
if isinstance(optimizer, list):
optimizer[0].load_state_dict(checkpoint["optimizer_w_state"])
optimizer[1].load_state_dict(checkpoint["optimizer_a_state"])
else:
optimizer.load_state_dict(checkpoint["optimizer_state"])
return checkpoint["epoch"]
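A typical resume flow built from these helpers checks `has_checkpoint()`, loads the most recent file, and continues from the stored epoch. A minimal sketch (the model and optimizer are placeholders; checkpoints go under `cfg.OUT_DIR/checkpoints` as defined above, and a single-process run is assumed so `dist.is_master_proc()` holds):
```
import torch
import pycls.core.checkpoint as cp

# Placeholder model/optimizer for illustration only.
model = torch.nn.Linear(10, 2)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

start_epoch = 0
if cp.has_checkpoint():
    last_file = cp.get_last_checkpoint()
    start_epoch = cp.load_checkpoint(last_file, model, optimizer) + 1  # resume after the saved epoch

for epoch in range(start_epoch, 5):
    # ... run one training epoch here ...
    cp.save_checkpoint(model, optimizer, epoch)  # writes model_epoch_{epoch+1:04d}.pyth
```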

500
pycls/core/config.py Normal file
View File

@@ -0,0 +1,500 @@
#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
"""Configuration file (powered by YACS)."""
import argparse
import os
import sys
from pycls.core.io import cache_url
from yacs.config import CfgNode as CfgNode
# Global config object
_C = CfgNode()
# Example usage:
# from core.config import cfg
cfg = _C
# ------------------------------------------------------------------------------------ #
# Model options
# ------------------------------------------------------------------------------------ #
_C.MODEL = CfgNode()
# Model type
_C.MODEL.TYPE = ""
# Number of weight layers
_C.MODEL.DEPTH = 0
# Number of input channels
_C.MODEL.INPUT_CHANNELS = 3
# Number of classes
_C.MODEL.NUM_CLASSES = 10
# Loss function (see pycls/core/builders.py for options)
_C.MODEL.LOSS_FUN = "cross_entropy"
# Label smoothing eps
_C.MODEL.LABEL_SMOOTHING_EPS = 0.0
# ASPP channels
_C.MODEL.ASPP_CHANNELS = 256
# ASPP dilation rates
_C.MODEL.ASPP_RATES = [6, 12, 18]
# ------------------------------------------------------------------------------------ #
# ResNet options
# ------------------------------------------------------------------------------------ #
_C.RESNET = CfgNode()
# Transformation function (see pycls/models/resnet.py for options)
_C.RESNET.TRANS_FUN = "basic_transform"
# Number of groups to use (1 -> ResNet; > 1 -> ResNeXt)
_C.RESNET.NUM_GROUPS = 1
# Width of each group (64 -> ResNet; 4 -> ResNeXt)
_C.RESNET.WIDTH_PER_GROUP = 64
# Apply stride to 1x1 conv (True -> MSRA; False -> fb.torch)
_C.RESNET.STRIDE_1X1 = True
# ------------------------------------------------------------------------------------ #
# AnyNet options
# ------------------------------------------------------------------------------------ #
_C.ANYNET = CfgNode()
# Stem type
_C.ANYNET.STEM_TYPE = "simple_stem_in"
# Stem width
_C.ANYNET.STEM_W = 32
# Block type
_C.ANYNET.BLOCK_TYPE = "res_bottleneck_block"
# Depth for each stage (number of blocks in the stage)
_C.ANYNET.DEPTHS = []
# Width for each stage (width of each block in the stage)
_C.ANYNET.WIDTHS = []
# Strides for each stage (applies to the first block of each stage)
_C.ANYNET.STRIDES = []
# Bottleneck multipliers for each stage (applies to bottleneck block)
_C.ANYNET.BOT_MULS = []
# Group widths for each stage (applies to bottleneck block)
_C.ANYNET.GROUP_WS = []
# Whether SE is enabled for res_bottleneck_block
_C.ANYNET.SE_ON = False
# SE ratio
_C.ANYNET.SE_R = 0.25
# ------------------------------------------------------------------------------------ #
# RegNet options
# ------------------------------------------------------------------------------------ #
_C.REGNET = CfgNode()
# Stem type
_C.REGNET.STEM_TYPE = "simple_stem_in"
# Stem width
_C.REGNET.STEM_W = 32
# Block type
_C.REGNET.BLOCK_TYPE = "res_bottleneck_block"
# Stride of each stage
_C.REGNET.STRIDE = 2
# Squeeze-and-Excitation (RegNetY)
_C.REGNET.SE_ON = False
_C.REGNET.SE_R = 0.25
# Depth
_C.REGNET.DEPTH = 10
# Initial width
_C.REGNET.W0 = 32
# Slope
_C.REGNET.WA = 5.0
# Quantization
_C.REGNET.WM = 2.5
# Group width
_C.REGNET.GROUP_W = 16
# Bottleneck multiplier (bm = 1 / b from the paper)
_C.REGNET.BOT_MUL = 1.0
# ------------------------------------------------------------------------------------ #
# EfficientNet options
# ------------------------------------------------------------------------------------ #
_C.EN = CfgNode()
# Stem width
_C.EN.STEM_W = 32
# Depth for each stage (number of blocks in the stage)
_C.EN.DEPTHS = []
# Width for each stage (width of each block in the stage)
_C.EN.WIDTHS = []
# Expansion ratios for MBConv blocks in each stage
_C.EN.EXP_RATIOS = []
# Squeeze-and-Excitation (SE) ratio
_C.EN.SE_R = 0.25
# Strides for each stage (applies to the first block of each stage)
_C.EN.STRIDES = []
# Kernel sizes for each stage
_C.EN.KERNELS = []
# Head width
_C.EN.HEAD_W = 1280
# Drop connect ratio
_C.EN.DC_RATIO = 0.0
# Dropout ratio
_C.EN.DROPOUT_RATIO = 0.0
# ---------------------------------------------------------------------------- #
# NAS options
# ---------------------------------------------------------------------------- #
_C.NAS = CfgNode()
# Cell genotype
_C.NAS.GENOTYPE = 'nas'
# Custom genotype
_C.NAS.CUSTOM_GENOTYPE = []
# Base NAS width
_C.NAS.WIDTH = 16
# Total number of cells
_C.NAS.DEPTH = 20
# Auxiliary heads
_C.NAS.AUX = False
# Weight for auxiliary heads
_C.NAS.AUX_WEIGHT = 0.4
# Drop path probability
_C.NAS.DROP_PROB = 0.0
# Matrix in NAS Bench
_C.NAS.MATRIX = []
# Operations in NAS Bench
_C.NAS.OPS = []
# Number of stacks in NAS Bench
_C.NAS.NUM_STACKS = 3
# Number of modules per stack in NAS Bench
_C.NAS.NUM_MODULES_PER_STACK = 3
# ------------------------------------------------------------------------------------ #
# Batch norm options
# ------------------------------------------------------------------------------------ #
_C.BN = CfgNode()
# BN epsilon
_C.BN.EPS = 1e-5
# BN momentum (BN momentum in PyTorch = 1 - BN momentum in Caffe2)
_C.BN.MOM = 0.1
# Precise BN stats
_C.BN.USE_PRECISE_STATS = False
_C.BN.NUM_SAMPLES_PRECISE = 1024
# Initialize the gamma of the final BN of each block to zero
_C.BN.ZERO_INIT_FINAL_GAMMA = False
# Use a different weight decay for BN layers
_C.BN.USE_CUSTOM_WEIGHT_DECAY = False
_C.BN.CUSTOM_WEIGHT_DECAY = 0.0
# ------------------------------------------------------------------------------------ #
# Optimizer options
# ------------------------------------------------------------------------------------ #
_C.OPTIM = CfgNode()
# Base learning rate
_C.OPTIM.BASE_LR = 0.1
# Learning rate policy; select from {'cos', 'exp', 'steps'}
_C.OPTIM.LR_POLICY = "cos"
# Exponential decay factor
_C.OPTIM.GAMMA = 0.1
# Steps for 'steps' policy (in epochs)
_C.OPTIM.STEPS = []
# Learning rate multiplier for 'steps' policy
_C.OPTIM.LR_MULT = 0.1
# Maximal number of epochs
_C.OPTIM.MAX_EPOCH = 200
# Momentum
_C.OPTIM.MOMENTUM = 0.9
# Momentum dampening
_C.OPTIM.DAMPENING = 0.0
# Nesterov momentum
_C.OPTIM.NESTEROV = True
# L2 regularization
_C.OPTIM.WEIGHT_DECAY = 5e-4
# Start the warm up from OPTIM.BASE_LR * OPTIM.WARMUP_FACTOR
_C.OPTIM.WARMUP_FACTOR = 0.1
# Gradually warm up the OPTIM.BASE_LR over this number of epochs
_C.OPTIM.WARMUP_EPOCHS = 0
# Update the learning rate per iter
_C.OPTIM.ITER_LR = False
# Base learning rate for arch
_C.OPTIM.ARCH_BASE_LR = 0.0003
# L2 regularization for arch
_C.OPTIM.ARCH_WEIGHT_DECAY = 0.001
# Optimizer for arch
_C.OPTIM.ARCH_OPTIM = 'adam'
# Epoch to start optimizing arch
_C.OPTIM.ARCH_EPOCH = 0.0
# ------------------------------------------------------------------------------------ #
# Training options
# ------------------------------------------------------------------------------------ #
_C.TRAIN = CfgNode()
# Dataset and split
_C.TRAIN.DATASET = ""
_C.TRAIN.SPLIT = "train"
# Total mini-batch size
_C.TRAIN.BATCH_SIZE = 128
# Image size
_C.TRAIN.IM_SIZE = 224
# Evaluate model on test data every eval period epochs
_C.TRAIN.EVAL_PERIOD = 1
# Save model checkpoint every checkpoint period epochs
_C.TRAIN.CHECKPOINT_PERIOD = 1
# Resume training from the latest checkpoint in the output directory
_C.TRAIN.AUTO_RESUME = True
# Weights to start training from
_C.TRAIN.WEIGHTS = ""
# Percentage of gray images in jig
_C.TRAIN.GRAY_PERCENTAGE = 0.0
# Portion to create trainA/trainB split
_C.TRAIN.PORTION = 1.0
# ------------------------------------------------------------------------------------ #
# Testing options
# ------------------------------------------------------------------------------------ #
_C.TEST = CfgNode()
# Dataset and split
_C.TEST.DATASET = ""
_C.TEST.SPLIT = "val"
# Total mini-batch size
_C.TEST.BATCH_SIZE = 200
# Image size
_C.TEST.IM_SIZE = 256
# Weights to use for testing
_C.TEST.WEIGHTS = ""
# ------------------------------------------------------------------------------------ #
# Common train/test data loader options
# ------------------------------------------------------------------------------------ #
_C.DATA_LOADER = CfgNode()
# Number of data loader workers per process
_C.DATA_LOADER.NUM_WORKERS = 8
# Load data to pinned host memory
_C.DATA_LOADER.PIN_MEMORY = True
# ------------------------------------------------------------------------------------ #
# Memory options
# ------------------------------------------------------------------------------------ #
_C.MEM = CfgNode()
# Perform ReLU inplace
_C.MEM.RELU_INPLACE = True
# ------------------------------------------------------------------------------------ #
# CUDNN options
# ------------------------------------------------------------------------------------ #
_C.CUDNN = CfgNode()
# Perform benchmarking to select the fastest CUDNN algorithms to use
# Note that this may increase the memory usage and will likely not result
# in overall speedups when variable size inputs are used (e.g. COCO training)
_C.CUDNN.BENCHMARK = True
# ------------------------------------------------------------------------------------ #
# Precise timing options
# ------------------------------------------------------------------------------------ #
_C.PREC_TIME = CfgNode()
# Number of iterations to warm up the caches
_C.PREC_TIME.WARMUP_ITER = 3
# Number of iterations to compute avg time
_C.PREC_TIME.NUM_ITER = 30
# ------------------------------------------------------------------------------------ #
# Misc options
# ------------------------------------------------------------------------------------ #
# Number of GPUs to use (applies to both training and testing)
_C.NUM_GPUS = 1
# Task (cls, seg, rot, col, jig)
_C.TASK = "cls"
# Grid in Jigsaw (2, 3); no effect if TASK is not jig
_C.JIGSAW_GRID = 3
# Output directory
_C.OUT_DIR = "/tmp"
# Config destination (in OUT_DIR)
_C.CFG_DEST = "config.yaml"
# Note that non-determinism may still be present due to non-deterministic
# operator implementations in GPU operator libraries
_C.RNG_SEED = 1
# Log destination ('stdout' or 'file')
_C.LOG_DEST = "stdout"
# Log period in iters
_C.LOG_PERIOD = 10
# Distributed backend
_C.DIST_BACKEND = "nccl"
# Hostname and port for initializing multi-process groups
_C.HOST = "localhost"
_C.PORT = 10001
# Models weights referred to by URL are downloaded to this local cache
_C.DOWNLOAD_CACHE = "/tmp/pycls-download-cache"
# ------------------------------------------------------------------------------------ #
# Deprecated keys
# ------------------------------------------------------------------------------------ #
_C.register_deprecated_key("PREC_TIME.BATCH_SIZE")
_C.register_deprecated_key("PREC_TIME.ENABLED")
def assert_and_infer_cfg(cache_urls=True):
"""Checks config values invariants."""
err_str = "The first lr step must start at 0"
assert not _C.OPTIM.STEPS or _C.OPTIM.STEPS[0] == 0, err_str
data_splits = ["train", "val", "test"]
err_str = "Data split '{}' not supported"
assert _C.TRAIN.SPLIT in data_splits, err_str.format(_C.TRAIN.SPLIT)
assert _C.TEST.SPLIT in data_splits, err_str.format(_C.TEST.SPLIT)
err_str = "Mini-batch size should be a multiple of NUM_GPUS."
assert _C.TRAIN.BATCH_SIZE % _C.NUM_GPUS == 0, err_str
assert _C.TEST.BATCH_SIZE % _C.NUM_GPUS == 0, err_str
err_str = "Precise BN stats computation not verified for > 1 GPU"
assert not _C.BN.USE_PRECISE_STATS or _C.NUM_GPUS == 1, err_str
err_str = "Log destination '{}' not supported"
assert _C.LOG_DEST in ["stdout", "file"], err_str.format(_C.LOG_DEST)
if cache_urls:
cache_cfg_urls()
def cache_cfg_urls():
"""Download URLs in config, cache them, and rewrite cfg to use cached file."""
_C.TRAIN.WEIGHTS = cache_url(_C.TRAIN.WEIGHTS, _C.DOWNLOAD_CACHE)
_C.TEST.WEIGHTS = cache_url(_C.TEST.WEIGHTS, _C.DOWNLOAD_CACHE)
def dump_cfg():
"""Dumps the config to the output directory."""
cfg_file = os.path.join(_C.OUT_DIR, _C.CFG_DEST)
with open(cfg_file, "w") as f:
_C.dump(stream=f)
def load_cfg(out_dir, cfg_dest="config.yaml"):
"""Loads config from specified output directory."""
cfg_file = os.path.join(out_dir, cfg_dest)
_C.merge_from_file(cfg_file)
def load_cfg_fom_args(description="Config file options."):
"""Load config from command line arguments and set any specified options."""
parser = argparse.ArgumentParser(description=description)
help_s = "Config file location"
parser.add_argument("--cfg", dest="cfg_file", help=help_s, required=True, type=str)
help_s = "See pycls/core/config.py for all options"
parser.add_argument("opts", help=help_s, default=None, nargs=argparse.REMAINDER)
if len(sys.argv) == 1:
parser.print_help()
sys.exit(1)
args = parser.parse_args()
_C.merge_from_file(args.cfg_file)
_C.merge_from_list(args.opts)
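A minimal usage sketch of the helpers above, in the style of a pycls tool script (the YAML path and overrides in the comment are illustrative, and this assumes the package from this commit is importable):

import pycls.core.config as config
from pycls.core.config import cfg

# Run as: python tool.py --cfg configs/example.yaml OPTIM.BASE_LR 0.05 NUM_GPUS 2
config.load_cfg_fom_args("Example tool.")
config.assert_and_infer_cfg()  # check invariants and cache URL-based weights
cfg.freeze()                   # guard against accidental mutation during training
config.dump_cfg()              # write the resolved config to cfg.OUT_DIR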

157
pycls/core/distributed.py Normal file

@@ -0,0 +1,157 @@
#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
"""Distributed helpers."""
import multiprocessing
import os
import signal
import threading
import traceback
import torch
from pycls.core.config import cfg
def is_master_proc():
"""Determines if the current process is the master process.
Master process is responsible for logging, writing and loading checkpoints. In
the multi GPU setting, we assign the master role to the rank 0 process. When
training using a single GPU, there is a single process which is considered master.
"""
return cfg.NUM_GPUS == 1 or torch.distributed.get_rank() == 0
def init_process_group(proc_rank, world_size):
"""Initializes the default process group."""
# Set the GPU to use
torch.cuda.set_device(proc_rank)
# Initialize the process group
torch.distributed.init_process_group(
backend=cfg.DIST_BACKEND,
init_method="tcp://{}:{}".format(cfg.HOST, cfg.PORT),
world_size=world_size,
rank=proc_rank,
)
def destroy_process_group():
"""Destroys the default process group."""
torch.distributed.destroy_process_group()
def scaled_all_reduce(tensors):
"""Performs the scaled all_reduce operation on the provided tensors.
The input tensors are modified in-place. Currently supports only the sum
reduction operator. The reduced values are scaled by the inverse size of the
process group (equivalent to cfg.NUM_GPUS).
"""
# There is no need for reduction in the single-proc case
if cfg.NUM_GPUS == 1:
return tensors
# Queue the reductions
reductions = []
for tensor in tensors:
reduction = torch.distributed.all_reduce(tensor, async_op=True)
reductions.append(reduction)
# Wait for reductions to finish
for reduction in reductions:
reduction.wait()
# Scale the results
for tensor in tensors:
tensor.mul_(1.0 / cfg.NUM_GPUS)
return tensors
class ChildException(Exception):
"""Wraps an exception from a child process."""
def __init__(self, child_trace):
super(ChildException, self).__init__(child_trace)
class ErrorHandler(object):
"""Multiprocessing error handler (based on fairseq's).
Listens for errors in child processes and propagates the tracebacks to the parent.
"""
def __init__(self, error_queue):
# Shared error queue
self.error_queue = error_queue
# Children processes sharing the error queue
self.children_pids = []
# Start a thread listening to errors
self.error_listener = threading.Thread(target=self.listen, daemon=True)
self.error_listener.start()
# Register the signal handler
signal.signal(signal.SIGUSR1, self.signal_handler)
def add_child(self, pid):
"""Registers a child process."""
self.children_pids.append(pid)
def listen(self):
"""Listens for errors in the error queue."""
# Wait until there is an error in the queue
child_trace = self.error_queue.get()
# Put the error back for the signal handler
self.error_queue.put(child_trace)
# Invoke the signal handler
os.kill(os.getpid(), signal.SIGUSR1)
def signal_handler(self, _sig_num, _stack_frame):
"""Signal handler."""
# Kill children processes
for pid in self.children_pids:
os.kill(pid, signal.SIGINT)
# Propagate the error from the child process
raise ChildException(self.error_queue.get())
def run(proc_rank, world_size, error_queue, fun, fun_args, fun_kwargs):
"""Runs a function from a child process."""
try:
# Initialize the process group
init_process_group(proc_rank, world_size)
# Run the function
fun(*fun_args, **fun_kwargs)
except KeyboardInterrupt:
# Killed by the parent process
pass
except Exception:
# Propagate exception to the parent process
error_queue.put(traceback.format_exc())
finally:
# Destroy the process group
destroy_process_group()
def multi_proc_run(num_proc, fun, fun_args=(), fun_kwargs=None):
"""Runs a function in a multi-proc setting (unless num_proc == 1)."""
# There is no need for multi-proc in the single-proc case
fun_kwargs = fun_kwargs if fun_kwargs else {}
if num_proc == 1:
fun(*fun_args, **fun_kwargs)
return
# Handle errors from training subprocesses
error_queue = multiprocessing.SimpleQueue()
error_handler = ErrorHandler(error_queue)
# Run each training subprocess
ps = []
for i in range(num_proc):
p_i = multiprocessing.Process(
target=run, args=(i, num_proc, error_queue, fun, fun_args, fun_kwargs)
)
ps.append(p_i)
p_i.start()
error_handler.add_child(p_i.pid)
# Wait for each subprocess to finish
for p in ps:
p.join()
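A hedged sketch of how multi_proc_run is meant to be driven; the worker function and message below are illustrative:

import pycls.core.distributed as dist
from pycls.core.config import cfg

def _demo(msg):
    # run() has already initialized the default process group in each child,
    # so rank-dependent logic can rely on is_master_proc().
    if dist.is_master_proc():
        print("master says:", msg)

if __name__ == "__main__":
    # With cfg.NUM_GPUS == 1 this calls _demo directly; otherwise it spawns
    # one process per GPU and connects them via cfg.HOST/cfg.PORT (NCCL).
    dist.multi_proc_run(num_proc=cfg.NUM_GPUS, fun=_demo, fun_args=("hello",))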

77
pycls/core/io.py Normal file

@@ -0,0 +1,77 @@
#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
"""IO utilities (adapted from Detectron)"""
import logging
import os
import re
import sys
from urllib import request as urlrequest
logger = logging.getLogger(__name__)
_PYCLS_BASE_URL = "https://dl.fbaipublicfiles.com/pycls"
def cache_url(url_or_file, cache_dir):
"""Download the file specified by the URL to the cache_dir and return the path to
the cached file. If the argument is not a URL, simply return it as is.
"""
is_url = re.match(r"^(?:http)s?://", url_or_file, re.IGNORECASE) is not None
if not is_url:
return url_or_file
url = url_or_file
err_str = "pycls only automatically caches URLs in the pycls S3 bucket: {}"
assert url.startswith(_PYCLS_BASE_URL), err_str.format(_PYCLS_BASE_URL)
cache_file_path = url.replace(_PYCLS_BASE_URL, cache_dir)
if os.path.exists(cache_file_path):
return cache_file_path
cache_file_dir = os.path.dirname(cache_file_path)
if not os.path.exists(cache_file_dir):
os.makedirs(cache_file_dir)
logger.info("Downloading remote file {} to {}".format(url, cache_file_path))
download_url(url, cache_file_path)
return cache_file_path
def _progress_bar(count, total):
"""Report download progress. Credit:
https://stackoverflow.com/questions/3173320/text-progress-bar-in-the-console/27871113
"""
bar_len = 60
filled_len = int(round(bar_len * count / float(total)))
percents = round(100.0 * count / float(total), 1)
bar = "=" * filled_len + "-" * (bar_len - filled_len)
sys.stdout.write(
" [{}] {}% of {:.1f}MB file \r".format(bar, percents, total / 1024 / 1024)
)
sys.stdout.flush()
if count >= total:
sys.stdout.write("\n")
def download_url(url, dst_file_path, chunk_size=8192, progress_hook=_progress_bar):
"""Download url and write it to dst_file_path. Credit:
https://stackoverflow.com/questions/2028517/python-urllib2-progress-hook
"""
req = urlrequest.Request(url)
response = urlrequest.urlopen(req)
total_size = response.info().get("Content-Length").strip()
total_size = int(total_size)
bytes_so_far = 0
with open(dst_file_path, "wb") as f:
while 1:
chunk = response.read(chunk_size)
bytes_so_far += len(chunk)
if not chunk:
break
if progress_hook:
progress_hook(bytes_so_far, total_size)
f.write(chunk)
return bytes_so_far
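A quick sketch of cache_url behavior (the paths are illustrative; only URLs under the pycls bucket are actually downloaded):

from pycls.core.io import cache_url

# Plain filesystem paths pass through untouched, with no network access.
assert cache_url("/data/weights.pyth", "/tmp/pycls-download-cache") == "/data/weights.pyth"
# A bucket URL such as https://dl.fbaipublicfiles.com/pycls/<...>/model.pyth
# would instead be mirrored into the cache dir and its local path returned.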

138
pycls/core/logging.py Normal file

@@ -0,0 +1,138 @@
#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
"""Logging."""
import builtins
import decimal
import logging
import os
import sys
import pycls.core.distributed as dist
import simplejson
from pycls.core.config import cfg
# Show filename and line number in logs
_FORMAT = "[%(filename)s: %(lineno)3d]: %(message)s"
# Log file name (for cfg.LOG_DEST = 'file')
_LOG_FILE = "stdout.log"
# Data output with dump_log_data(data, data_type) will be tagged w/ this
_TAG = "json_stats: "
# Data output with dump_log_data(data, data_type) will have data[_TYPE]=data_type
_TYPE = "_type"
def _suppress_print():
"""Suppresses printing from the current process."""
def ignore(*_objects, _sep=" ", _end="\n", _file=sys.stdout, _flush=False):
pass
builtins.print = ignore
def setup_logging():
"""Sets up the logging."""
# Enable logging only for the master process
if dist.is_master_proc():
# Clear the root logger to prevent any existing logging config
# (e.g. set by another module) from messing with our setup
logging.root.handlers = []
# Construct logging configuration
logging_config = {"level": logging.INFO, "format": _FORMAT}
# Log either to stdout or to a file
if cfg.LOG_DEST == "stdout":
logging_config["stream"] = sys.stdout
else:
logging_config["filename"] = os.path.join(cfg.OUT_DIR, _LOG_FILE)
# Configure logging
logging.basicConfig(**logging_config)
else:
_suppress_print()
def get_logger(name):
"""Retrieves the logger."""
return logging.getLogger(name)
def dump_log_data(data, data_type, prec=4):
"""Covert data (a dictionary) into tagged json string for logging."""
data[_TYPE] = data_type
data = float_to_decimal(data, prec)
data_json = simplejson.dumps(data, sort_keys=True, use_decimal=True)
return "{:s}{:s}".format(_TAG, data_json)
def float_to_decimal(data, prec=4):
"""Convert floats to decimals which allows for fixed width json."""
if isinstance(data, dict):
return {k: float_to_decimal(v, prec) for k, v in data.items()}
if isinstance(data, float):
return decimal.Decimal(("{:." + str(prec) + "f}").format(data))
else:
return data
def get_log_files(log_dir, name_filter="", log_file=_LOG_FILE):
"""Get all log files in directory containing subdirs of trained models."""
names = [n for n in sorted(os.listdir(log_dir)) if name_filter in n]
files = [os.path.join(log_dir, n, log_file) for n in names]
f_n_ps = [(f, n) for (f, n) in zip(files, names) if os.path.exists(f)]
files, names = zip(*f_n_ps) if f_n_ps else ([], [])
return files, names
def load_log_data(log_file, data_types_to_skip=()):
"""Loads log data into a dictionary of the form data[data_type][metric][index]."""
# Load log_file
assert os.path.exists(log_file), "Log file not found: {}".format(log_file)
with open(log_file, "r") as f:
lines = f.readlines()
# Extract and parse lines that start with _TAG and have a type specified
lines = [l[l.find(_TAG) + len(_TAG) :] for l in lines if _TAG in l]
lines = [simplejson.loads(l) for l in lines]
lines = [l for l in lines if _TYPE in l and not l[_TYPE] in data_types_to_skip]
# Generate data structure accessed by data[data_type][index][metric]
data_types = [l[_TYPE] for l in lines]
data = {t: [] for t in data_types}
for t, line in zip(data_types, lines):
del line[_TYPE]
data[t].append(line)
# Generate data structure accessed by data[data_type][metric][index]
for t in data:
metrics = sorted(data[t][0].keys())
err_str = "Inconsistent metrics in log for _type={}: {}".format(t, metrics)
assert all(sorted(d.keys()) == metrics for d in data[t]), err_str
data[t] = {m: [d[m] for d in data[t]] for m in metrics}
return data
def sort_log_data(data):
"""Sort each data[data_type][metric] by epoch or keep only first instance."""
for t in data:
if "epoch" in data[t]:
assert "epoch_ind" not in data[t] and "epoch_max" not in data[t]
data[t]["epoch_ind"] = [int(e.split("/")[0]) for e in data[t]["epoch"]]
data[t]["epoch_max"] = [int(e.split("/")[1]) for e in data[t]["epoch"]]
epoch = data[t]["epoch_ind"]
if "iter" in data[t]:
assert "iter_ind" not in data[t] and "iter_max" not in data[t]
data[t]["iter_ind"] = [int(i.split("/")[0]) for i in data[t]["iter"]]
data[t]["iter_max"] = [int(i.split("/")[1]) for i in data[t]["iter"]]
itr = zip(epoch, data[t]["iter_ind"], data[t]["iter_max"])
epoch = [e + (i_ind - 1) / i_max for e, i_ind, i_max in itr]
for m in data[t]:
data[t][m] = [v for _, v in sorted(zip(epoch, data[t][m]))]
else:
data[t] = {m: d[0] for m, d in data[t].items()}
return data
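For reference, a small sketch of the tagged-JSON line format these helpers produce (the stats values are made up):

import pycls.core.logging as logging

logging.setup_logging()
logger = logging.get_logger(__name__)
stats = {"epoch": "1/200", "top1_err": 63.2187, "loss": 2.71828}
# Emits: [<file>: <line>]: json_stats: {"_type": "train_epoch", "epoch": "1/200", ...}
logger.info(logging.dump_log_data(stats, "train_epoch"))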

435
pycls/core/meters.py Normal file

@@ -0,0 +1,435 @@
#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
"""Meters."""
from collections import deque
import numpy as np
import pycls.core.logging as logging
import torch
from pycls.core.config import cfg
from pycls.core.timer import Timer
logger = logging.get_logger(__name__)
def time_string(seconds):
"""Converts time in seconds to a fixed-width string format."""
days, rem = divmod(int(seconds), 24 * 3600)
hrs, rem = divmod(rem, 3600)
mins, secs = divmod(rem, 60)
return "{0:02},{1:02}:{2:02}:{3:02}".format(days, hrs, mins, secs)
def inter_union(preds, labels, num_classes):
"""Computes per-class intersection and union areas (for mIoU) on a mini-batch."""
_, preds = torch.max(preds, 1)
preds = preds.type(torch.uint8) + 1
labels = labels.type(torch.uint8) + 1
preds = preds * (labels > 0).type(torch.uint8)
inter = preds * (preds == labels).type(torch.uint8)
area_inter = torch.histc(inter.type(torch.int64), bins=num_classes, min=1, max=num_classes)
area_preds = torch.histc(preds.type(torch.int64), bins=num_classes, min=1, max=num_classes)
area_labels = torch.histc(labels.type(torch.int64), bins=num_classes, min=1, max=num_classes)
area_union = area_preds + area_labels - area_inter
return [area_inter.type(torch.float64) / labels.size(0), area_union.type(torch.float64) / labels.size(0)]
def topk_errors(preds, labels, ks):
"""Computes the top-k error for each k."""
err_str = "Batch dim of predictions and labels must match"
assert preds.size(0) == labels.size(0), err_str
# Find the top max_k predictions for each sample
_top_max_k_vals, top_max_k_inds = torch.topk(
preds, max(ks), dim=1, largest=True, sorted=True
)
# (batch_size, max_k) -> (max_k, batch_size)
top_max_k_inds = top_max_k_inds.t()
# (batch_size, ) -> (max_k, batch_size)
rep_max_k_labels = labels.view(1, -1).expand_as(top_max_k_inds)
# (i, j) = 1 if top i-th prediction for the j-th sample is correct
top_max_k_correct = top_max_k_inds.eq(rep_max_k_labels)
# Compute the number of topk correct predictions for each k
topks_correct = [top_max_k_correct[:k, :].reshape(-1).float().sum() for k in ks]
return [(1.0 - x / preds.size(0)) * 100.0 for x in topks_correct]
def gpu_mem_usage():
"""Computes the GPU memory usage for the current device (MB)."""
mem_usage_bytes = torch.cuda.max_memory_allocated()
return mem_usage_bytes / 1024 / 1024
class ScalarMeter(object):
"""Measures a scalar value (adapted from Detectron)."""
def __init__(self, window_size):
self.deque = deque(maxlen=window_size)
self.total = 0.0
self.count = 0
def reset(self):
self.deque.clear()
self.total = 0.0
self.count = 0
def add_value(self, value):
self.deque.append(value)
self.count += 1
self.total += value
def get_win_median(self):
return np.median(self.deque)
def get_win_avg(self):
return np.mean(self.deque)
def get_global_avg(self):
return self.total / self.count
class TrainMeter(object):
"""Measures training stats."""
def __init__(self, epoch_iters):
self.epoch_iters = epoch_iters
self.max_iter = cfg.OPTIM.MAX_EPOCH * epoch_iters
self.iter_timer = Timer()
self.loss = ScalarMeter(cfg.LOG_PERIOD)
self.loss_total = 0.0
self.lr = None
# Current minibatch errors (smoothed over a window)
self.mb_top1_err = ScalarMeter(cfg.LOG_PERIOD)
self.mb_top5_err = ScalarMeter(cfg.LOG_PERIOD)
# Number of misclassified examples
self.num_top1_mis = 0
self.num_top5_mis = 0
self.num_samples = 0
def reset(self, timer=False):
if timer:
self.iter_timer.reset()
self.loss.reset()
self.loss_total = 0.0
self.lr = None
self.mb_top1_err.reset()
self.mb_top5_err.reset()
self.num_top1_mis = 0
self.num_top5_mis = 0
self.num_samples = 0
def iter_tic(self):
self.iter_timer.tic()
def iter_toc(self):
self.iter_timer.toc()
def update_stats(self, top1_err, top5_err, loss, lr, mb_size):
# Current minibatch stats
self.mb_top1_err.add_value(top1_err)
self.mb_top5_err.add_value(top5_err)
self.loss.add_value(loss)
self.lr = lr
# Aggregate stats
self.num_top1_mis += top1_err * mb_size
self.num_top5_mis += top5_err * mb_size
self.loss_total += loss * mb_size
self.num_samples += mb_size
def get_iter_stats(self, cur_epoch, cur_iter):
cur_iter_total = cur_epoch * self.epoch_iters + cur_iter + 1
eta_sec = self.iter_timer.average_time * (self.max_iter - cur_iter_total)
mem_usage = gpu_mem_usage()
stats = {
"epoch": "{}/{}".format(cur_epoch + 1, cfg.OPTIM.MAX_EPOCH),
"iter": "{}/{}".format(cur_iter + 1, self.epoch_iters),
"time_avg": self.iter_timer.average_time,
"time_diff": self.iter_timer.diff,
"eta": time_string(eta_sec),
"top1_err": self.mb_top1_err.get_win_median(),
"top5_err": self.mb_top5_err.get_win_median(),
"loss": self.loss.get_win_median(),
"lr": self.lr,
"mem": int(np.ceil(mem_usage)),
}
return stats
def log_iter_stats(self, cur_epoch, cur_iter):
if (cur_iter + 1) % cfg.LOG_PERIOD != 0:
return
stats = self.get_iter_stats(cur_epoch, cur_iter)
logger.info(logging.dump_log_data(stats, "train_iter"))
def get_epoch_stats(self, cur_epoch):
cur_iter_total = (cur_epoch + 1) * self.epoch_iters
eta_sec = self.iter_timer.average_time * (self.max_iter - cur_iter_total)
mem_usage = gpu_mem_usage()
top1_err = self.num_top1_mis / self.num_samples
top5_err = self.num_top5_mis / self.num_samples
avg_loss = self.loss_total / self.num_samples
stats = {
"epoch": "{}/{}".format(cur_epoch + 1, cfg.OPTIM.MAX_EPOCH),
"time_avg": self.iter_timer.average_time,
"eta": time_string(eta_sec),
"top1_err": top1_err,
"top5_err": top5_err,
"loss": avg_loss,
"lr": self.lr,
"mem": int(np.ceil(mem_usage)),
}
return stats
def log_epoch_stats(self, cur_epoch):
stats = self.get_epoch_stats(cur_epoch)
logger.info(logging.dump_log_data(stats, "train_epoch"))
class TestMeter(object):
"""Measures testing stats."""
def __init__(self, max_iter):
self.max_iter = max_iter
self.iter_timer = Timer()
# Current minibatch errors (smoothed over a window)
self.mb_top1_err = ScalarMeter(cfg.LOG_PERIOD)
self.mb_top5_err = ScalarMeter(cfg.LOG_PERIOD)
# Min errors (over the full test set)
self.min_top1_err = 100.0
self.min_top5_err = 100.0
# Number of misclassified examples
self.num_top1_mis = 0
self.num_top5_mis = 0
self.num_samples = 0
def reset(self, min_errs=False):
if min_errs:
self.min_top1_err = 100.0
self.min_top5_err = 100.0
self.iter_timer.reset()
self.mb_top1_err.reset()
self.mb_top5_err.reset()
self.num_top1_mis = 0
self.num_top5_mis = 0
self.num_samples = 0
def iter_tic(self):
self.iter_timer.tic()
def iter_toc(self):
self.iter_timer.toc()
def update_stats(self, top1_err, top5_err, mb_size):
self.mb_top1_err.add_value(top1_err)
self.mb_top5_err.add_value(top5_err)
self.num_top1_mis += top1_err * mb_size
self.num_top5_mis += top5_err * mb_size
self.num_samples += mb_size
def get_iter_stats(self, cur_epoch, cur_iter):
mem_usage = gpu_mem_usage()
iter_stats = {
"epoch": "{}/{}".format(cur_epoch + 1, cfg.OPTIM.MAX_EPOCH),
"iter": "{}/{}".format(cur_iter + 1, self.max_iter),
"time_avg": self.iter_timer.average_time,
"time_diff": self.iter_timer.diff,
"top1_err": self.mb_top1_err.get_win_median(),
"top5_err": self.mb_top5_err.get_win_median(),
"mem": int(np.ceil(mem_usage)),
}
return iter_stats
def log_iter_stats(self, cur_epoch, cur_iter):
if (cur_iter + 1) % cfg.LOG_PERIOD != 0:
return
stats = self.get_iter_stats(cur_epoch, cur_iter)
logger.info(logging.dump_log_data(stats, "test_iter"))
def get_epoch_stats(self, cur_epoch):
top1_err = self.num_top1_mis / self.num_samples
top5_err = self.num_top5_mis / self.num_samples
self.min_top1_err = min(self.min_top1_err, top1_err)
self.min_top5_err = min(self.min_top5_err, top5_err)
mem_usage = gpu_mem_usage()
stats = {
"epoch": "{}/{}".format(cur_epoch + 1, cfg.OPTIM.MAX_EPOCH),
"time_avg": self.iter_timer.average_time,
"top1_err": top1_err,
"top5_err": top5_err,
"min_top1_err": self.min_top1_err,
"min_top5_err": self.min_top5_err,
"mem": int(np.ceil(mem_usage)),
}
return stats
def log_epoch_stats(self, cur_epoch):
stats = self.get_epoch_stats(cur_epoch)
logger.info(logging.dump_log_data(stats, "test_epoch"))
class TrainMeterIoU(object):
"""Measures training stats."""
def __init__(self, epoch_iters):
self.epoch_iters = epoch_iters
self.max_iter = cfg.OPTIM.MAX_EPOCH * epoch_iters
self.iter_timer = Timer()
self.loss = ScalarMeter(cfg.LOG_PERIOD)
self.loss_total = 0.0
self.lr = None
self.mb_miou = ScalarMeter(cfg.LOG_PERIOD)
self.num_inter = np.zeros(cfg.MODEL.NUM_CLASSES)
self.num_union = np.zeros(cfg.MODEL.NUM_CLASSES)
self.num_samples = 0
def reset(self, timer=False):
if timer:
self.iter_timer.reset()
self.loss.reset()
self.loss_total = 0.0
self.lr = None
self.mb_miou.reset()
self.num_inter = np.zeros(cfg.MODEL.NUM_CLASSES)
self.num_union = np.zeros(cfg.MODEL.NUM_CLASSES)
self.num_samples = 0
def iter_tic(self):
self.iter_timer.tic()
def iter_toc(self):
self.iter_timer.toc()
def update_stats(self, inter, union, loss, lr, mb_size):
# Current minibatch stats
self.mb_miou.add_value((inter / (union + 1e-10)).mean())
self.loss.add_value(loss)
self.lr = lr
# Aggregate stats
self.num_inter += inter * mb_size
self.num_union += union * mb_size
self.loss_total += loss * mb_size
self.num_samples += mb_size
def get_iter_stats(self, cur_epoch, cur_iter):
cur_iter_total = cur_epoch * self.epoch_iters + cur_iter + 1
eta_sec = self.iter_timer.average_time * (self.max_iter - cur_iter_total)
mem_usage = gpu_mem_usage()
stats = {
"epoch": "{}/{}".format(cur_epoch + 1, cfg.OPTIM.MAX_EPOCH),
"iter": "{}/{}".format(cur_iter + 1, self.epoch_iters),
"time_avg": self.iter_timer.average_time,
"time_diff": self.iter_timer.diff,
"eta": time_string(eta_sec),
"miou": self.mb_miou.get_win_median(),
"loss": self.loss.get_win_median(),
"lr": self.lr,
"mem": int(np.ceil(mem_usage)),
}
return stats
def log_iter_stats(self, cur_epoch, cur_iter):
if (cur_iter + 1) % cfg.LOG_PERIOD != 0:
return
stats = self.get_iter_stats(cur_epoch, cur_iter)
logger.info(logging.dump_log_data(stats, "train_iter"))
def get_epoch_stats(self, cur_epoch):
cur_iter_total = (cur_epoch + 1) * self.epoch_iters
eta_sec = self.iter_timer.average_time * (self.max_iter - cur_iter_total)
mem_usage = gpu_mem_usage()
miou = (self.num_inter / (self.num_union + 1e-10)).mean()
avg_loss = self.loss_total / self.num_samples
stats = {
"epoch": "{}/{}".format(cur_epoch + 1, cfg.OPTIM.MAX_EPOCH),
"time_avg": self.iter_timer.average_time,
"eta": time_string(eta_sec),
"miou": miou,
"loss": avg_loss,
"lr": self.lr,
"mem": int(np.ceil(mem_usage)),
}
return stats
def log_epoch_stats(self, cur_epoch):
stats = self.get_epoch_stats(cur_epoch)
logger.info(logging.dump_log_data(stats, "train_epoch"))
class TestMeterIoU(object):
"""Measures testing stats."""
def __init__(self, max_iter):
self.max_iter = max_iter
self.iter_timer = Timer()
self.mb_miou = ScalarMeter(cfg.LOG_PERIOD)
self.max_miou = 0.0
self.num_inter = np.zeros(cfg.MODEL.NUM_CLASSES)
self.num_union = np.zeros(cfg.MODEL.NUM_CLASSES)
self.num_samples = 0
def reset(self, min_errs=False):
if min_errs:
self.max_miou = 0.0
self.iter_timer.reset()
self.mb_miou.reset()
self.num_inter = np.zeros(cfg.MODEL.NUM_CLASSES)
self.num_union = np.zeros(cfg.MODEL.NUM_CLASSES)
self.num_samples = 0
def iter_tic(self):
self.iter_timer.tic()
def iter_toc(self):
self.iter_timer.toc()
def update_stats(self, inter, union, mb_size):
self.mb_miou.add_value((inter / (union + 1e-10)).mean())
self.num_inter += inter * mb_size
self.num_union += union * mb_size
self.num_samples += mb_size
def get_iter_stats(self, cur_epoch, cur_iter):
mem_usage = gpu_mem_usage()
iter_stats = {
"epoch": "{}/{}".format(cur_epoch + 1, cfg.OPTIM.MAX_EPOCH),
"iter": "{}/{}".format(cur_iter + 1, self.max_iter),
"time_avg": self.iter_timer.average_time,
"time_diff": self.iter_timer.diff,
"miou": self.mb_miou.get_win_median(),
"mem": int(np.ceil(mem_usage)),
}
return iter_stats
def log_iter_stats(self, cur_epoch, cur_iter):
if (cur_iter + 1) % cfg.LOG_PERIOD != 0:
return
stats = self.get_iter_stats(cur_epoch, cur_iter)
logger.info(logging.dump_log_data(stats, "test_iter"))
def get_epoch_stats(self, cur_epoch):
miou = (self.num_inter / (self.num_union + 1e-10)).mean()
self.max_miou = max(self.max_miou, miou)
mem_usage = gpu_mem_usage()
stats = {
"epoch": "{}/{}".format(cur_epoch + 1, cfg.OPTIM.MAX_EPOCH),
"time_avg": self.iter_timer.average_time,
"miou": miou,
"max_miou": self.max_miou,
"mem": int(np.ceil(mem_usage)),
}
return stats
def log_epoch_stats(self, cur_epoch):
stats = self.get_epoch_stats(cur_epoch)
logger.info(logging.dump_log_data(stats, "test_epoch"))
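A toy check of topk_errors on a two-sample batch (values chosen by hand to avoid ties):

import torch
from pycls.core.meters import topk_errors

preds = torch.tensor([[0.1, 0.7, 0.2], [0.2, 0.1, 0.7]])  # 2 samples, 3 classes
labels = torch.tensor([1, 0])  # sample 1 is wrong at top-1 but right at top-2
top1_err, top2_err = topk_errors(preds, labels, ks=[1, 2])
print(top1_err.item(), top2_err.item())  # 50.0 0.0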

129
pycls/core/net.py Normal file

@@ -0,0 +1,129 @@
#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
"""Functions for manipulating networks."""
import itertools
import math
import torch
import torch.nn as nn
from pycls.core.config import cfg
def init_weights(m):
"""Performs ResNet-style weight initialization."""
if isinstance(m, nn.Conv2d):
# Note that there is no bias due to BN
fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
m.weight.data.normal_(mean=0.0, std=math.sqrt(2.0 / fan_out))
elif isinstance(m, nn.BatchNorm2d):
zero_init_gamma = cfg.BN.ZERO_INIT_FINAL_GAMMA
zero_init_gamma = hasattr(m, "final_bn") and m.final_bn and zero_init_gamma
m.weight.data.fill_(0.0 if zero_init_gamma else 1.0)
m.bias.data.zero_()
elif isinstance(m, nn.Linear):
m.weight.data.normal_(mean=0.0, std=0.01)
m.bias.data.zero_()
@torch.no_grad()
def compute_precise_bn_stats(model, loader):
"""Computes precise BN stats on training data."""
# Compute the number of minibatches to use
num_iter = min(cfg.BN.NUM_SAMPLES_PRECISE // loader.batch_size, len(loader))
# Retrieve the BN layers
bns = [m for m in model.modules() if isinstance(m, torch.nn.BatchNorm2d)]
# Initialize stats storage
mus = [torch.zeros_like(bn.running_mean) for bn in bns]
sqs = [torch.zeros_like(bn.running_var) for bn in bns]
# Remember momentum values
moms = [bn.momentum for bn in bns]
# Disable momentum
for bn in bns:
bn.momentum = 1.0
# Accumulate the stats across the data samples
for inputs, _labels in itertools.islice(loader, num_iter):
model(inputs.cuda())
# Accumulate the stats for each BN layer
for i, bn in enumerate(bns):
m, v = bn.running_mean, bn.running_var
sqs[i] += (v + m * m) / num_iter
mus[i] += m / num_iter
# Set the stats and restore momentum values
for i, bn in enumerate(bns):
bn.running_var = sqs[i] - mus[i] * mus[i]
bn.running_mean = mus[i]
bn.momentum = moms[i]
def reset_bn_stats(model):
"""Resets running BN stats."""
for m in model.modules():
if isinstance(m, torch.nn.BatchNorm2d):
m.reset_running_stats()
def complexity_conv2d(cx, w_in, w_out, k, stride, padding, groups=1, bias=False):
"""Accumulates complexity of Conv2D into cx = (h, w, flops, params, acts)."""
h, w, flops, params, acts = cx["h"], cx["w"], cx["flops"], cx["params"], cx["acts"]
h = (h + 2 * padding - k) // stride + 1
w = (w + 2 * padding - k) // stride + 1
flops += k * k * w_in * w_out * h * w // groups
params += k * k * w_in * w_out // groups
flops += w_out if bias else 0
params += w_out if bias else 0
acts += w_out * h * w
return {"h": h, "w": w, "flops": flops, "params": params, "acts": acts}
def complexity_batchnorm2d(cx, w_in):
"""Accumulates complexity of BatchNorm2D into cx = (h, w, flops, params, acts)."""
h, w, flops, params, acts = cx["h"], cx["w"], cx["flops"], cx["params"], cx["acts"]
params += 2 * w_in
return {"h": h, "w": w, "flops": flops, "params": params, "acts": acts}
def complexity_maxpool2d(cx, k, stride, padding):
"""Accumulates complexity of MaxPool2d into cx = (h, w, flops, params, acts)."""
h, w, flops, params, acts = cx["h"], cx["w"], cx["flops"], cx["params"], cx["acts"]
h = (h + 2 * padding - k) // stride + 1
w = (w + 2 * padding - k) // stride + 1
return {"h": h, "w": w, "flops": flops, "params": params, "acts": acts}
def complexity(model):
"""Compute model complexity (model can be model instance or model class)."""
size = cfg.TRAIN.IM_SIZE
cx = {"h": size, "w": size, "flops": 0, "params": 0, "acts": 0}
cx = model.complexity(cx)
return {"flops": cx["flops"], "params": cx["params"], "acts": cx["acts"]}
def drop_connect(x, drop_ratio):
"""Drop connect (adapted from DARTS)."""
keep_ratio = 1.0 - drop_ratio
mask = torch.empty([x.shape[0], 1, 1, 1], dtype=x.dtype, device=x.device)
mask.bernoulli_(keep_ratio)
x.div_(keep_ratio)
x.mul_(mask)
return x
def get_flat_weights(model):
"""Gets all model weights as a single flat vector."""
return torch.cat([p.data.view(-1, 1) for p in model.parameters()], 0)
def set_flat_weights(model, flat_weights):
"""Sets all model weights from a single flat vector."""
k = 0
for p in model.parameters():
n = p.data.numel()
p.data.copy_(flat_weights[k : (k + n)].view_as(p.data))
k += n
assert k == flat_weights.numel()
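A worked example of the complexity bookkeeping for one 3x3 conv plus BN on a 32x32 feature map (16 -> 32 channels; numbers are illustrative):

from pycls.core.net import complexity_batchnorm2d, complexity_conv2d

cx = {"h": 32, "w": 32, "flops": 0, "params": 0, "acts": 0}
cx = complexity_conv2d(cx, w_in=16, w_out=32, k=3, stride=1, padding=1)
cx = complexity_batchnorm2d(cx, w_in=32)
# Conv: params = 3*3*16*32 = 4608, flops = 4608 * 32 * 32 = 4718592, acts = 32*32*32
# BN adds 2*32 = 64 params; the spatial size stays 32x32.
print(cx)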

95
pycls/core/optimizer.py Normal file

@@ -0,0 +1,95 @@
#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
"""Optimizer."""
import numpy as np
import torch
from pycls.core.config import cfg
def construct_optimizer(model):
"""Constructs the optimizer.
Note that the momentum update in PyTorch differs from the one in Caffe2.
In particular,
Caffe2:
V := mu * V + lr * g
p := p - V
PyTorch:
V := mu * V + g
p := p - lr * V
where V is the velocity, mu is the momentum factor, lr is the learning rate,
g is the gradient and p are the parameters.
Since V is defined independently of the learning rate in PyTorch,
when the learning rate is changed there is no need to perform the
momentum correction by scaling V (unlike in the Caffe2 case).
"""
if cfg.BN.USE_CUSTOM_WEIGHT_DECAY:
# Apply different weight decay to Batchnorm and non-batchnorm parameters.
p_bn = [p for n, p in model.named_parameters() if "bn" in n]
p_non_bn = [p for n, p in model.named_parameters() if "bn" not in n]
optim_params = [
{"params": p_bn, "weight_decay": cfg.BN.CUSTOM_WEIGHT_DECAY},
{"params": p_non_bn, "weight_decay": cfg.OPTIM.WEIGHT_DECAY},
]
else:
optim_params = model.parameters()
return torch.optim.SGD(
optim_params,
lr=cfg.OPTIM.BASE_LR,
momentum=cfg.OPTIM.MOMENTUM,
weight_decay=cfg.OPTIM.WEIGHT_DECAY,
dampening=cfg.OPTIM.DAMPENING,
nesterov=cfg.OPTIM.NESTEROV,
)
def lr_fun_steps(cur_epoch):
"""Steps schedule (cfg.OPTIM.LR_POLICY = 'steps')."""
ind = [i for i, s in enumerate(cfg.OPTIM.STEPS) if cur_epoch >= s][-1]
return cfg.OPTIM.BASE_LR * (cfg.OPTIM.LR_MULT ** ind)
def lr_fun_exp(cur_epoch):
"""Exponential schedule (cfg.OPTIM.LR_POLICY = 'exp')."""
return cfg.OPTIM.BASE_LR * (cfg.OPTIM.GAMMA ** cur_epoch)
def lr_fun_cos(cur_epoch):
"""Cosine schedule (cfg.OPTIM.LR_POLICY = 'cos')."""
base_lr, max_epoch = cfg.OPTIM.BASE_LR, cfg.OPTIM.MAX_EPOCH
return 0.5 * base_lr * (1.0 + np.cos(np.pi * cur_epoch / max_epoch))
def get_lr_fun():
"""Retrieves the specified lr policy function"""
lr_fun = "lr_fun_" + cfg.OPTIM.LR_POLICY
if lr_fun not in globals():
raise NotImplementedError("Unknown LR policy: " + cfg.OPTIM.LR_POLICY)
return globals()[lr_fun]
def get_epoch_lr(cur_epoch):
"""Retrieves the lr for the given epoch according to the policy."""
lr = get_lr_fun()(cur_epoch)
# Linear warmup
if cur_epoch < cfg.OPTIM.WARMUP_EPOCHS:
alpha = cur_epoch / cfg.OPTIM.WARMUP_EPOCHS
warmup_factor = cfg.OPTIM.WARMUP_FACTOR * (1.0 - alpha) + alpha
lr *= warmup_factor
return lr
def set_lr(optimizer, new_lr):
"""Sets the optimizer lr to the specified value."""
for param_group in optimizer.param_groups:
param_group["lr"] = new_lr

132
pycls/core/plotting.py Normal file

@@ -0,0 +1,132 @@
#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
"""Plotting functions."""
import colorlover as cl
import matplotlib.pyplot as plt
import plotly.graph_objs as go
import plotly.offline as offline
import pycls.core.logging as logging
def get_plot_colors(max_colors, color_format="pyplot"):
"""Generate colors for plotting."""
colors = cl.scales["11"]["qual"]["Paired"]
if max_colors > len(colors):
colors = cl.to_rgb(cl.interp(colors, max_colors))
if color_format == "pyplot":
return [[j / 255.0 for j in c] for c in cl.to_numeric(colors)]
return colors
def prepare_plot_data(log_files, names, metric="top1_err"):
"""Load logs and extract data for plotting error curves."""
plot_data = []
for file, name in zip(log_files, names):
d, data = {}, logging.sort_log_data(logging.load_log_data(file))
for phase in ["train", "test"]:
x = data[phase + "_epoch"]["epoch_ind"]
y = data[phase + "_epoch"][metric]
d["x_" + phase], d["y_" + phase] = x, y
d[phase + "_label"] = "[{:5.2f}] ".format(min(y) if y else 0) + name
plot_data.append(d)
assert len(plot_data) > 0, "No data to plot"
return plot_data
def plot_error_curves_plotly(log_files, names, filename, metric="top1_err"):
"""Plot error curves using plotly and save to file."""
plot_data = prepare_plot_data(log_files, names, metric)
colors = get_plot_colors(len(plot_data), "plotly")
# Prepare data for plots (3 sets, train duplicated w and w/o legend)
data = []
for i, d in enumerate(plot_data):
s = str(i)
line_train = {"color": colors[i], "dash": "dashdot", "width": 1.5}
line_test = {"color": colors[i], "dash": "solid", "width": 1.5}
data.append(
go.Scatter(
x=d["x_train"],
y=d["y_train"],
mode="lines",
name=d["train_label"],
line=line_train,
legendgroup=s,
visible=True,
showlegend=False,
)
)
data.append(
go.Scatter(
x=d["x_test"],
y=d["y_test"],
mode="lines",
name=d["test_label"],
line=line_test,
legendgroup=s,
visible=True,
showlegend=True,
)
)
data.append(
go.Scatter(
x=d["x_train"],
y=d["y_train"],
mode="lines",
name=d["train_label"],
line=line_train,
legendgroup=s,
visible=False,
showlegend=True,
)
)
# Prepare layout w ability to toggle 'all', 'train', 'test'
titlefont = {"size": 18, "color": "#7f7f7f"}
vis = [[True, True, False], [False, False, True], [False, True, False]]
buttons = zip(["all", "train", "test"], [[{"visible": v}] for v in vis])
buttons = [{"label": b, "args": v, "method": "update"} for b, v in buttons]
layout = go.Layout(
title=metric + " vs. epoch<br>[dash=train, solid=test]",
xaxis={"title": "epoch", "titlefont": titlefont},
yaxis={"title": metric, "titlefont": titlefont},
showlegend=True,
hoverlabel={"namelength": -1},
updatemenus=[
{
"buttons": buttons,
"direction": "down",
"showactive": True,
"x": 1.02,
"xanchor": "left",
"y": 1.08,
"yanchor": "top",
}
],
)
# Create plotly plot
offline.plot({"data": data, "layout": layout}, filename=filename)
def plot_error_curves_pyplot(log_files, names, filename=None, metric="top1_err"):
"""Plot error curves using matplotlib.pyplot and save to file."""
plot_data = prepare_plot_data(log_files, names, metric)
colors = get_plot_colors(len(names))
for ind, d in enumerate(plot_data):
c, lbl = colors[ind], d["test_label"]
plt.plot(d["x_train"], d["y_train"], "--", c=c, alpha=0.8)
plt.plot(d["x_test"], d["y_test"], "-", c=c, alpha=0.8, label=lbl)
plt.title(metric + " vs. epoch\n[dash=train, solid=test]", fontsize=14)
plt.xlabel("epoch", fontsize=14)
plt.ylabel(metric, fontsize=14)
plt.grid(alpha=0.4)
plt.legend()
if filename:
plt.savefig(filename)
plt.clf()
else:
plt.show()
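A hedged usage sketch (the experiment directory, filter string, and output name are illustrative; this relies on the colorlover/plotly/matplotlib dependencies imported above):

import pycls.core.logging as logging
import pycls.core.plotting as plotting

log_files, names = logging.get_log_files("/path/to/experiments", name_filter="resnet")
plotting.plot_error_curves_pyplot(log_files, names, filename="curves.png", metric="top1_err")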

39
pycls/core/timer.py Normal file

@@ -0,0 +1,39 @@
#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
"""Timer."""
import time
class Timer(object):
"""A simple timer (adapted from Detectron)."""
def __init__(self):
self.total_time = None
self.calls = None
self.start_time = None
self.diff = None
self.average_time = None
self.reset()
def tic(self):
# using time.time as time.clock does not normalize for multithreading
self.start_time = time.time()
def toc(self):
self.diff = time.time() - self.start_time
self.total_time += self.diff
self.calls += 1
self.average_time = self.total_time / self.calls
def reset(self):
self.total_time = 0.0
self.calls = 0
self.start_time = 0.0
self.diff = 0.0
self.average_time = 0.0
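Tiny usage sketch: time a few stand-in iterations and read the running average.

import time
from pycls.core.timer import Timer

timer = Timer()
for _ in range(3):
    timer.tic()
    time.sleep(0.01)  # stand-in for a training step
    timer.toc()
print("avg iter time: {:.4f}s".format(timer.average_time))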

419
pycls/core/trainer.py Normal file

@@ -0,0 +1,419 @@
#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
"""Tools for training and testing a model."""
import os
from thop import profile
import numpy as np
import pycls.core.benchmark as benchmark
import pycls.core.builders as builders
import pycls.core.checkpoint as checkpoint
import pycls.core.config as config
import pycls.core.distributed as dist
import pycls.core.logging as logging
import pycls.core.meters as meters
import pycls.core.net as net
import pycls.core.optimizer as optim
import pycls.datasets.loader as loader
import torch
import torch.nn.functional as F
from pycls.core.config import cfg
logger = logging.get_logger(__name__)
def setup_env():
"""Sets up environment for training or testing."""
if dist.is_master_proc():
# Ensure that the output dir exists
os.makedirs(cfg.OUT_DIR, exist_ok=True)
# Save the config
config.dump_cfg()
# Setup logging
logging.setup_logging()
# Log the config as both human readable and as a json
logger.info("Config:\n{}".format(cfg))
logger.info(logging.dump_log_data(cfg, "cfg"))
# Fix the RNG seeds (see RNG comment in core/config.py for discussion)
np.random.seed(cfg.RNG_SEED)
torch.manual_seed(cfg.RNG_SEED)
# Configure the CUDNN backend
torch.backends.cudnn.benchmark = cfg.CUDNN.BENCHMARK
def setup_model():
"""Sets up a model for training or testing and log the results."""
# Build the model
model = builders.build_model()
logger.info("Model:\n{}".format(model))
# Log model complexity
# logger.info(logging.dump_log_data(net.complexity(model), "complexity"))
if cfg.TASK == "seg" and cfg.TRAIN.DATASET == "cityscapes":
h, w = 1025, 2049
else:
h, w = cfg.TRAIN.IM_SIZE, cfg.TRAIN.IM_SIZE
if cfg.TASK == "jig":
x = torch.randn(1, cfg.JIGSAW_GRID ** 2, cfg.MODEL.INPUT_CHANNELS, h, w)
else:
x = torch.randn(1, cfg.MODEL.INPUT_CHANNELS, h, w)
macs, params = profile(model, inputs=(x, ), verbose=False)
logger.info("Params: {:,}".format(params))
logger.info("Flops: {:,}".format(macs))
# Transfer the model to the current GPU device
err_str = "Cannot use more GPU devices than available"
assert cfg.NUM_GPUS <= torch.cuda.device_count(), err_str
cur_device = torch.cuda.current_device()
model = model.cuda(device=cur_device)
# Use multi-process data parallel model in the multi-gpu setting
if cfg.NUM_GPUS > 1:
# Make model replica operate on the current device
model = torch.nn.parallel.DistributedDataParallel(
module=model, device_ids=[cur_device], output_device=cur_device
)
# Set complexity function to be module's complexity function
# model.complexity = model.module.complexity
return model
def train_epoch(train_loader, model, loss_fun, optimizer, train_meter, cur_epoch):
"""Performs one epoch of training."""
# Update drop path prob for NAS
if cfg.MODEL.TYPE == "nas":
m = model.module if cfg.NUM_GPUS > 1 else model
m.set_drop_path_prob(cfg.NAS.DROP_PROB * cur_epoch / cfg.OPTIM.MAX_EPOCH)
# Shuffle the data
loader.shuffle(train_loader, cur_epoch)
# Update the learning rate per epoch
if not cfg.OPTIM.ITER_LR:
lr = optim.get_epoch_lr(cur_epoch)
optim.set_lr(optimizer, lr)
# Enable training mode
model.train()
train_meter.iter_tic()
for cur_iter, (inputs, labels) in enumerate(train_loader):
# Update the learning rate per iter
if cfg.OPTIM.ITER_LR:
lr = optim.get_epoch_lr(cur_epoch + cur_iter / len(train_loader))
optim.set_lr(optimizer, lr)
# Transfer the data to the current GPU device
inputs, labels = inputs.cuda(), labels.cuda(non_blocking=True)
# Perform the forward pass
preds = model(inputs)
# Compute the loss
if isinstance(preds, tuple):
loss = loss_fun(preds[0], labels) + cfg.NAS.AUX_WEIGHT * loss_fun(preds[1], labels)
preds = preds[0]
else:
loss = loss_fun(preds, labels)
# Perform the backward pass
optimizer.zero_grad()
loss.backward()
# Update the parameters
optimizer.step()
# Compute the errors
if cfg.TASK == "col":
preds = preds.permute(0, 2, 3, 1)
preds = preds.reshape(-1, preds.size(3))
labels = labels.reshape(-1)
mb_size = inputs.size(0) * inputs.size(2) * inputs.size(3) * cfg.NUM_GPUS
else:
mb_size = inputs.size(0) * cfg.NUM_GPUS
if cfg.TASK == "seg":
# top1_err is in fact inter; top5_err is in fact union
top1_err, top5_err = meters.inter_union(preds, labels, cfg.MODEL.NUM_CLASSES)
else:
ks = [1, min(5, cfg.MODEL.NUM_CLASSES)] # rot only has 4 classes
top1_err, top5_err = meters.topk_errors(preds, labels, ks)
# Combine the stats across the GPUs (no reduction if 1 GPU used)
loss, top1_err, top5_err = dist.scaled_all_reduce([loss, top1_err, top5_err])
# Copy the stats from GPU to CPU (sync point)
loss = loss.item()
if cfg.TASK == "seg":
top1_err, top5_err = top1_err.cpu().numpy(), top5_err.cpu().numpy()
else:
top1_err, top5_err = top1_err.item(), top5_err.item()
train_meter.iter_toc()
# Update and log stats
train_meter.update_stats(top1_err, top5_err, loss, lr, mb_size)
train_meter.log_iter_stats(cur_epoch, cur_iter)
train_meter.iter_tic()
# Log epoch stats
train_meter.log_epoch_stats(cur_epoch)
train_meter.reset()
def search_epoch(train_loader, model, loss_fun, optimizer, train_meter, cur_epoch):
"""Performs one epoch of differentiable architecture search."""
m = model.module if cfg.NUM_GPUS > 1 else model
# Shuffle the data
loader.shuffle(train_loader[0], cur_epoch)
loader.shuffle(train_loader[1], cur_epoch)
# Update the learning rate per epoch
if not cfg.OPTIM.ITER_LR:
lr = optim.get_epoch_lr(cur_epoch)
optim.set_lr(optimizer[0], lr)
# Enable training mode
model.train()
train_meter.iter_tic()
trainB_iter = iter(train_loader[1])
for cur_iter, (inputs, labels) in enumerate(train_loader[0]):
# Update the learning rate per iter
if cfg.OPTIM.ITER_LR:
lr = optim.get_epoch_lr(cur_epoch + cur_iter / len(train_loader[0]))
optim.set_lr(optimizer[0], lr)
# Transfer the data to the current GPU device
inputs, labels = inputs.cuda(), labels.cuda(non_blocking=True)
# Update architecture
if cur_epoch + cur_iter / len(train_loader[0]) >= cfg.OPTIM.ARCH_EPOCH:
try:
inputsB, labelsB = next(trainB_iter)
except StopIteration:
trainB_iter = iter(train_loader[1])
inputsB, labelsB = next(trainB_iter)
inputsB, labelsB = inputsB.cuda(), labelsB.cuda(non_blocking=True)
optimizer[1].zero_grad()
loss = m._loss(inputsB, labelsB)
loss.backward()
optimizer[1].step()
# Perform the forward pass
preds = model(inputs)
# Compute the loss
loss = loss_fun(preds, labels)
# Perform the backward pass
optimizer[0].zero_grad()
loss.backward()
torch.nn.utils.clip_grad_norm_(model.parameters(), 5.0)
# Update the parameters
optimizer[0].step()
# Compute the errors
if cfg.TASK == "col":
preds = preds.permute(0, 2, 3, 1)
preds = preds.reshape(-1, preds.size(3))
labels = labels.reshape(-1)
mb_size = inputs.size(0) * inputs.size(2) * inputs.size(3) * cfg.NUM_GPUS
else:
mb_size = inputs.size(0) * cfg.NUM_GPUS
if cfg.TASK == "seg":
# top1_err is in fact inter; top5_err is in fact union
top1_err, top5_err = meters.inter_union(preds, labels, cfg.MODEL.NUM_CLASSES)
else:
ks = [1, min(5, cfg.MODEL.NUM_CLASSES)] # rot only has 4 classes
top1_err, top5_err = meters.topk_errors(preds, labels, ks)
# Combine the stats across the GPUs (no reduction if 1 GPU used)
loss, top1_err, top5_err = dist.scaled_all_reduce([loss, top1_err, top5_err])
# Copy the stats from GPU to CPU (sync point)
loss = loss.item()
if cfg.TASK == "seg":
top1_err, top5_err = top1_err.cpu().numpy(), top5_err.cpu().numpy()
else:
top1_err, top5_err = top1_err.item(), top5_err.item()
train_meter.iter_toc()
# Update and log stats
train_meter.update_stats(top1_err, top5_err, loss, lr, mb_size)
train_meter.log_iter_stats(cur_epoch, cur_iter)
train_meter.iter_tic()
# Log epoch stats
train_meter.log_epoch_stats(cur_epoch)
train_meter.reset()
# Log genotype
genotype = m.genotype()
logger.info("genotype = %s", genotype)
logger.info(F.softmax(m.net_.alphas_normal, dim=-1))
logger.info(F.softmax(m.net_.alphas_reduce, dim=-1))
@torch.no_grad()
def test_epoch(test_loader, model, test_meter, cur_epoch):
"""Evaluates the model on the test set."""
# Enable eval mode
model.eval()
test_meter.iter_tic()
for cur_iter, (inputs, labels) in enumerate(test_loader):
# Transfer the data to the current GPU device
inputs, labels = inputs.cuda(), labels.cuda(non_blocking=True)
# Compute the predictions
preds = model(inputs)
# Compute the errors
if cfg.TASK == "col":
preds = preds.permute(0, 2, 3, 1)
preds = preds.reshape(-1, preds.size(3))
labels = labels.reshape(-1)
mb_size = inputs.size(0) * inputs.size(2) * inputs.size(3) * cfg.NUM_GPUS
else:
mb_size = inputs.size(0) * cfg.NUM_GPUS
if cfg.TASK == "seg":
# top1_err is in fact inter; top5_err is in fact union
top1_err, top5_err = meters.inter_union(preds, labels, cfg.MODEL.NUM_CLASSES)
else:
ks = [1, min(5, cfg.MODEL.NUM_CLASSES)] # rot only has 4 classes
top1_err, top5_err = meters.topk_errors(preds, labels, ks)
# Combine the errors across the GPUs (no reduction if 1 GPU used)
top1_err, top5_err = dist.scaled_all_reduce([top1_err, top5_err])
# Copy the errors from GPU to CPU (sync point)
if cfg.TASK == "seg":
top1_err, top5_err = top1_err.cpu().numpy(), top5_err.cpu().numpy()
else:
top1_err, top5_err = top1_err.item(), top5_err.item()
test_meter.iter_toc()
# Update and log stats
test_meter.update_stats(top1_err, top5_err, mb_size)
test_meter.log_iter_stats(cur_epoch, cur_iter)
test_meter.iter_tic()
# Log epoch stats
test_meter.log_epoch_stats(cur_epoch)
test_meter.reset()
def train_model():
"""Trains the model."""
# Setup training/testing environment
setup_env()
# Construct the model, loss_fun, and optimizer
model = setup_model()
loss_fun = builders.build_loss_fun().cuda()
if "search" in cfg.MODEL.TYPE:
params_w = [v for k, v in model.named_parameters() if "alphas" not in k]
params_a = [v for k, v in model.named_parameters() if "alphas" in k]
optimizer_w = torch.optim.SGD(
params=params_w,
lr=cfg.OPTIM.BASE_LR,
momentum=cfg.OPTIM.MOMENTUM,
weight_decay=cfg.OPTIM.WEIGHT_DECAY,
dampening=cfg.OPTIM.DAMPENING,
nesterov=cfg.OPTIM.NESTEROV
)
if cfg.OPTIM.ARCH_OPTIM == "adam":
optimizer_a = torch.optim.Adam(
params=params_a,
lr=cfg.OPTIM.ARCH_BASE_LR,
betas=(0.5, 0.999),
weight_decay=cfg.OPTIM.ARCH_WEIGHT_DECAY
)
elif cfg.OPTIM.ARCH_OPTIM == "sgd":
optimizer_a = torch.optim.SGD(
params=params_a,
lr=cfg.OPTIM.ARCH_BASE_LR,
momentum=cfg.OPTIM.MOMENTUM,
weight_decay=cfg.OPTIM.ARCH_WEIGHT_DECAY,
dampening=cfg.OPTIM.DAMPENING,
nesterov=cfg.OPTIM.NESTEROV
)
optimizer = [optimizer_w, optimizer_a]
else:
optimizer = optim.construct_optimizer(model)
# Load checkpoint or initial weights
start_epoch = 0
if cfg.TRAIN.AUTO_RESUME and checkpoint.has_checkpoint():
last_checkpoint = checkpoint.get_last_checkpoint()
checkpoint_epoch = checkpoint.load_checkpoint(last_checkpoint, model, optimizer)
logger.info("Loaded checkpoint from: {}".format(last_checkpoint))
start_epoch = checkpoint_epoch + 1
elif cfg.TRAIN.WEIGHTS:
checkpoint.load_checkpoint(cfg.TRAIN.WEIGHTS, model)
logger.info("Loaded initial weights from: {}".format(cfg.TRAIN.WEIGHTS))
# Create data loaders and meters
if cfg.TRAIN.PORTION < 1:
if "search" in cfg.MODEL.TYPE:
train_loader = [loader._construct_loader(
dataset_name=cfg.TRAIN.DATASET,
split=cfg.TRAIN.SPLIT,
batch_size=int(cfg.TRAIN.BATCH_SIZE / cfg.NUM_GPUS),
shuffle=True,
drop_last=True,
portion=cfg.TRAIN.PORTION,
side="l"
),
loader._construct_loader(
dataset_name=cfg.TRAIN.DATASET,
split=cfg.TRAIN.SPLIT,
batch_size=int(cfg.TRAIN.BATCH_SIZE / cfg.NUM_GPUS),
shuffle=True,
drop_last=True,
portion=cfg.TRAIN.PORTION,
side="r"
)]
else:
train_loader = loader._construct_loader(
dataset_name=cfg.TRAIN.DATASET,
split=cfg.TRAIN.SPLIT,
batch_size=int(cfg.TRAIN.BATCH_SIZE / cfg.NUM_GPUS),
shuffle=True,
drop_last=True,
portion=cfg.TRAIN.PORTION,
side="l"
)
test_loader = loader._construct_loader(
dataset_name=cfg.TRAIN.DATASET,
split=cfg.TRAIN.SPLIT,
batch_size=int(cfg.TRAIN.BATCH_SIZE / cfg.NUM_GPUS),
shuffle=False,
drop_last=False,
portion=cfg.TRAIN.PORTION,
side="r"
)
else:
train_loader = loader.construct_train_loader()
test_loader = loader.construct_test_loader()
train_meter_type = meters.TrainMeterIoU if cfg.TASK == "seg" else meters.TrainMeter
test_meter_type = meters.TestMeterIoU if cfg.TASK == "seg" else meters.TestMeter
l = train_loader[0] if isinstance(train_loader, list) else train_loader
train_meter = train_meter_type(len(l))
test_meter = test_meter_type(len(test_loader))
# Compute model and loader timings
if start_epoch == 0 and cfg.PREC_TIME.NUM_ITER > 0:
l = train_loader[0] if isinstance(train_loader, list) else train_loader
benchmark.compute_time_full(model, loss_fun, l, test_loader)
# Perform the training loop
logger.info("Start epoch: {}".format(start_epoch + 1))
for cur_epoch in range(start_epoch, cfg.OPTIM.MAX_EPOCH):
# Train for one epoch
f = search_epoch if "search" in cfg.MODEL.TYPE else train_epoch
f(train_loader, model, loss_fun, optimizer, train_meter, cur_epoch)
# Compute precise BN stats
if cfg.BN.USE_PRECISE_STATS:
net.compute_precise_bn_stats(model, train_loader)
# Save a checkpoint
if (cur_epoch + 1) % cfg.TRAIN.CHECKPOINT_PERIOD == 0:
checkpoint_file = checkpoint.save_checkpoint(model, optimizer, cur_epoch)
logger.info("Wrote checkpoint to: {}".format(checkpoint_file))
# Evaluate the model
next_epoch = cur_epoch + 1
if next_epoch % cfg.TRAIN.EVAL_PERIOD == 0 or next_epoch == cfg.OPTIM.MAX_EPOCH:
test_epoch(test_loader, model, test_meter, cur_epoch)
def test_model():
"""Evaluates a trained model."""
# Setup training/testing environment
setup_env()
# Construct the model
model = setup_model()
# Load model weights
checkpoint.load_checkpoint(cfg.TEST.WEIGHTS, model)
logger.info("Loaded model weights from: {}".format(cfg.TEST.WEIGHTS))
# Create data loaders and meters
test_loader = loader.construct_test_loader()
test_meter = meters.TestMeter(len(test_loader))
# Evaluate the model
test_epoch(test_loader, model, test_meter, 0)
def time_model():
"""Times model and data loader."""
# Setup training/testing environment
setup_env()
# Construct the model and loss_fun
model = setup_model()
loss_fun = builders.build_loss_fun().cuda()
# Create data loaders
train_loader = loader.construct_train_loader()
test_loader = loader.construct_test_loader()
# Compute model and loader timings
benchmark.compute_time_full(model, loss_fun, train_loader, test_loader)
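For context, a hedged sketch of an entry point that drives these functions, in the style of the pycls tool scripts (the description string is illustrative):

import pycls.core.config as config
import pycls.core.distributed as dist
import pycls.core.trainer as trainer
from pycls.core.config import cfg

def main():
    config.load_cfg_fom_args("Train a classification model.")
    config.assert_and_infer_cfg()
    cfg.freeze()
    # Runs train_model in-process for 1 GPU, or spawns one process per GPU.
    dist.multi_proc_run(num_proc=cfg.NUM_GPUS, fun=trainer.train_model)

if __name__ == "__main__":
    main()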

108
pycls/models/common.py Normal file

@@ -0,0 +1,108 @@
#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import torch
import torch.nn as nn
from pycls.core.config import cfg
def Preprocess(x):
"""Folds jigsaw tiles into the batch dimension when cfg.TASK is 'jig'."""
if cfg.TASK == 'jig':
assert len(x.shape) == 5, 'Wrong tensor dimension for jigsaw'
assert x.shape[1] == cfg.JIGSAW_GRID ** 2, 'Wrong grid for jigsaw'
x = x.view([x.shape[0] * x.shape[1], x.shape[2], x.shape[3], x.shape[4]])
return x
class Classifier(nn.Module):
def __init__(self, channels, num_classes):
super(Classifier, self).__init__()
if cfg.TASK == 'jig':
self.jig_sq = cfg.JIGSAW_GRID ** 2
self.pooling = nn.AdaptiveAvgPool2d(1)
self.classifier = nn.Linear(channels * self.jig_sq, num_classes)
elif cfg.TASK == 'col':
self.classifier = nn.Conv2d(channels, num_classes, kernel_size=1, stride=1)
elif cfg.TASK == 'seg':
self.classifier = ASPP(channels, cfg.MODEL.ASPP_CHANNELS, num_classes, cfg.MODEL.ASPP_RATES)
else:
self.pooling = nn.AdaptiveAvgPool2d(1)
self.classifier = nn.Linear(channels, num_classes)
def forward(self, x, shape):
if cfg.TASK == 'jig':
x = self.pooling(x)
x = x.view([x.shape[0] // self.jig_sq, x.shape[1] * self.jig_sq, x.shape[2], x.shape[3]])
x = self.classifier(x.view(x.size(0), -1))
elif cfg.TASK in ['col', 'seg']:
x = self.classifier(x)
x = nn.Upsample(shape, mode='bilinear', align_corners=True)(x)
else:
x = self.pooling(x)
x = self.classifier(x.view(x.size(0), -1))
return x
class ASPP(nn.Module):
def __init__(self, in_channels, out_channels, num_classes, rates):
super(ASPP, self).__init__()
assert len(rates) in [1, 3]
self.rates = rates
self.global_pooling = nn.AdaptiveAvgPool2d(1)
self.aspp1 = nn.Sequential(
nn.Conv2d(in_channels, out_channels, 1, bias=False),
nn.BatchNorm2d(out_channels),
nn.ReLU(inplace=True)
)
self.aspp2 = nn.Sequential(
nn.Conv2d(in_channels, out_channels, 3, dilation=rates[0],
padding=rates[0], bias=False),
nn.BatchNorm2d(out_channels),
nn.ReLU(inplace=True)
)
if len(self.rates) == 3:
self.aspp3 = nn.Sequential(
nn.Conv2d(in_channels, out_channels, 3, dilation=rates[1],
padding=rates[1], bias=False),
nn.BatchNorm2d(out_channels),
nn.ReLU(inplace=True)
)
self.aspp4 = nn.Sequential(
nn.Conv2d(in_channels, out_channels, 3, dilation=rates[2],
padding=rates[2], bias=False),
nn.BatchNorm2d(out_channels),
nn.ReLU(inplace=True)
)
self.aspp5 = nn.Sequential(
nn.Conv2d(in_channels, out_channels, 1, bias=False),
nn.BatchNorm2d(out_channels),
nn.ReLU(inplace=True)
)
self.classifier = nn.Sequential(
nn.Conv2d(out_channels * (len(rates) + 2), out_channels, 1,
bias=False),
nn.BatchNorm2d(out_channels),
nn.ReLU(inplace=True),
nn.Conv2d(out_channels, num_classes, 1)
)
def forward(self, x):
x1 = self.aspp1(x)
x2 = self.aspp2(x)
x5 = self.global_pooling(x)
x5 = self.aspp5(x5)
x5 = nn.Upsample((x.shape[2], x.shape[3]), mode='bilinear',
align_corners=True)(x5)
if len(self.rates) == 3:
x3 = self.aspp3(x)
x4 = self.aspp4(x)
x = torch.cat((x1, x2, x3, x4, x5), 1)
else:
x = torch.cat((x1, x2, x5), 1)
x = self.classifier(x)
return x
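A quick sanity check of the ASPP head defined above; the channel counts, class count, and dilation rates below are arbitrary assumptions rather than the repository's config values:
import torch

# Build the ASPP head with illustrative sizes and verify that it returns
# one score map per class at the input spatial resolution.
aspp = ASPP(in_channels=64, out_channels=32, num_classes=19, rates=[6, 12, 18])
x = torch.randn(2, 64, 32, 32)
out = aspp(x)
print(out.shape)  # expected: torch.Size([2, 19, 32, 32])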

634
pycls/models/nas/genotypes.py Normal file
View File

@@ -0,0 +1,634 @@
#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
"""NAS genotypes (adopted from DARTS)."""
from collections import namedtuple
Genotype = namedtuple('Genotype', 'normal normal_concat reduce reduce_concat')
# NASNet ops
NASNET_OPS = [
'skip_connect',
'conv_3x1_1x3',
'conv_7x1_1x7',
'dil_conv_3x3',
'avg_pool_3x3',
'max_pool_3x3',
'max_pool_5x5',
'max_pool_7x7',
'conv_1x1',
'conv_3x3',
'sep_conv_3x3',
'sep_conv_5x5',
'sep_conv_7x7',
]
# ENAS ops
ENAS_OPS = [
'skip_connect',
'sep_conv_3x3',
'sep_conv_5x5',
'avg_pool_3x3',
'max_pool_3x3',
]
# AmoebaNet ops
AMOEBA_OPS = [
'skip_connect',
'sep_conv_3x3',
'sep_conv_5x5',
'sep_conv_7x7',
'avg_pool_3x3',
'max_pool_3x3',
'dil_sep_conv_3x3',
'conv_7x1_1x7',
]
# NAO ops
NAO_OPS = [
'skip_connect',
'conv_1x1',
'conv_3x3',
'conv_3x1_1x3',
'conv_7x1_1x7',
'max_pool_2x2',
'max_pool_3x3',
'max_pool_5x5',
'avg_pool_2x2',
'avg_pool_3x3',
'avg_pool_5x5',
]
# PNAS ops
PNAS_OPS = [
'sep_conv_3x3',
'sep_conv_5x5',
'sep_conv_7x7',
'conv_7x1_1x7',
'skip_connect',
'avg_pool_3x3',
'max_pool_3x3',
'dil_conv_3x3',
]
# DARTS ops
DARTS_OPS = [
'none',
'max_pool_3x3',
'avg_pool_3x3',
'skip_connect',
'sep_conv_3x3',
'sep_conv_5x5',
'dil_conv_3x3',
'dil_conv_5x5',
]
NASNet = Genotype(
normal=[
('sep_conv_5x5', 1),
('sep_conv_3x3', 0),
('sep_conv_5x5', 0),
('sep_conv_3x3', 0),
('avg_pool_3x3', 1),
('skip_connect', 0),
('avg_pool_3x3', 0),
('avg_pool_3x3', 0),
('sep_conv_3x3', 1),
('skip_connect', 1),
],
normal_concat=[2, 3, 4, 5, 6],
reduce=[
('sep_conv_5x5', 1),
('sep_conv_7x7', 0),
('max_pool_3x3', 1),
('sep_conv_7x7', 0),
('avg_pool_3x3', 1),
('sep_conv_5x5', 0),
('skip_connect', 3),
('avg_pool_3x3', 2),
('sep_conv_3x3', 2),
('max_pool_3x3', 1),
],
reduce_concat=[4, 5, 6],
)
PNASNet = Genotype(
normal=[
('sep_conv_5x5', 0),
('max_pool_3x3', 0),
('sep_conv_7x7', 1),
('max_pool_3x3', 1),
('sep_conv_5x5', 1),
('sep_conv_3x3', 1),
('sep_conv_3x3', 4),
('max_pool_3x3', 1),
('sep_conv_3x3', 0),
('skip_connect', 1),
],
normal_concat=[2, 3, 4, 5, 6],
reduce=[
('sep_conv_5x5', 0),
('max_pool_3x3', 0),
('sep_conv_7x7', 1),
('max_pool_3x3', 1),
('sep_conv_5x5', 1),
('sep_conv_3x3', 1),
('sep_conv_3x3', 4),
('max_pool_3x3', 1),
('sep_conv_3x3', 0),
('skip_connect', 1),
],
reduce_concat=[2, 3, 4, 5, 6],
)
AmoebaNet = Genotype(
normal=[
('avg_pool_3x3', 0),
('max_pool_3x3', 1),
('sep_conv_3x3', 0),
('sep_conv_5x5', 2),
('sep_conv_3x3', 0),
('avg_pool_3x3', 3),
('sep_conv_3x3', 1),
('skip_connect', 1),
('skip_connect', 0),
('avg_pool_3x3', 1),
],
normal_concat=[4, 5, 6],
reduce=[
('avg_pool_3x3', 0),
('sep_conv_3x3', 1),
('max_pool_3x3', 0),
('sep_conv_7x7', 2),
('sep_conv_7x7', 0),
('avg_pool_3x3', 1),
('max_pool_3x3', 0),
('max_pool_3x3', 1),
('conv_7x1_1x7', 0),
('sep_conv_3x3', 5),
],
reduce_concat=[3, 4, 6]
)
DARTS_V1 = Genotype(
normal=[
('sep_conv_3x3', 1),
('sep_conv_3x3', 0),
('skip_connect', 0),
('sep_conv_3x3', 1),
('skip_connect', 0),
('sep_conv_3x3', 1),
('sep_conv_3x3', 0),
('skip_connect', 2)
],
normal_concat=[2, 3, 4, 5],
reduce=[
('max_pool_3x3', 0),
('max_pool_3x3', 1),
('skip_connect', 2),
('max_pool_3x3', 0),
('max_pool_3x3', 0),
('skip_connect', 2),
('skip_connect', 2),
('avg_pool_3x3', 0)
],
reduce_concat=[2, 3, 4, 5]
)
DARTS_V2 = Genotype(
normal=[
('sep_conv_3x3', 0),
('sep_conv_3x3', 1),
('sep_conv_3x3', 0),
('sep_conv_3x3', 1),
('sep_conv_3x3', 1),
('skip_connect', 0),
('skip_connect', 0),
('dil_conv_3x3', 2)
],
normal_concat=[2, 3, 4, 5],
reduce=[
('max_pool_3x3', 0),
('max_pool_3x3', 1),
('skip_connect', 2),
('max_pool_3x3', 1),
('max_pool_3x3', 0),
('skip_connect', 2),
('skip_connect', 2),
('max_pool_3x3', 1)
],
reduce_concat=[2, 3, 4, 5]
)
PDARTS = Genotype(
normal=[
('skip_connect', 0),
('dil_conv_3x3', 1),
('skip_connect', 0),
('sep_conv_3x3', 1),
('sep_conv_3x3', 1),
('sep_conv_3x3', 3),
('sep_conv_3x3', 0),
('dil_conv_5x5', 4)
],
normal_concat=range(2, 6),
reduce=[
('avg_pool_3x3', 0),
('sep_conv_5x5', 1),
('sep_conv_3x3', 0),
('dil_conv_5x5', 2),
('max_pool_3x3', 0),
('dil_conv_3x3', 1),
('dil_conv_3x3', 1),
('dil_conv_5x5', 3)
],
reduce_concat=range(2, 6)
)
PCDARTS_C10 = Genotype(
normal=[
('sep_conv_3x3', 1),
('skip_connect', 0),
('sep_conv_3x3', 0),
('dil_conv_3x3', 1),
('sep_conv_5x5', 0),
('sep_conv_3x3', 1),
('avg_pool_3x3', 0),
('dil_conv_3x3', 1)
],
normal_concat=range(2, 6),
reduce=[
('sep_conv_5x5', 1),
('max_pool_3x3', 0),
('sep_conv_5x5', 1),
('sep_conv_5x5', 2),
('sep_conv_3x3', 0),
('sep_conv_3x3', 3),
('sep_conv_3x3', 1),
('sep_conv_3x3', 2)
],
reduce_concat=range(2, 6)
)
PCDARTS_IN1K = Genotype(
normal=[
('skip_connect', 1),
('sep_conv_3x3', 0),
('sep_conv_3x3', 0),
('skip_connect', 1),
('sep_conv_3x3', 1),
('sep_conv_3x3', 3),
('sep_conv_3x3', 1),
('dil_conv_5x5', 4)
],
normal_concat=range(2, 6),
reduce=[
('sep_conv_3x3', 0),
('skip_connect', 1),
('dil_conv_5x5', 2),
('max_pool_3x3', 1),
('sep_conv_3x3', 2),
('sep_conv_3x3', 1),
('sep_conv_5x5', 0),
('sep_conv_3x3', 3)
],
reduce_concat=range(2, 6)
)
UNNAS_IMAGENET_CLS = Genotype(
normal=[
('sep_conv_3x3', 1),
('sep_conv_3x3', 0),
('sep_conv_3x3', 1),
('sep_conv_3x3', 0),
('sep_conv_3x3', 0),
('sep_conv_3x3', 2),
('sep_conv_5x5', 1),
('sep_conv_3x3', 0)
],
normal_concat=range(2, 6),
reduce=[
('max_pool_3x3', 0),
('skip_connect', 1),
('max_pool_3x3', 0),
('dil_conv_5x5', 2),
('max_pool_3x3', 0),
('sep_conv_3x3', 2),
('sep_conv_3x3', 4),
('dil_conv_5x5', 3)
],
reduce_concat=range(2, 6)
)
UNNAS_IMAGENET_ROT = Genotype(
normal=[
('sep_conv_3x3', 0),
('sep_conv_3x3', 1),
('sep_conv_3x3', 0),
('sep_conv_3x3', 1),
('sep_conv_3x3', 1),
('sep_conv_3x3', 0),
('sep_conv_3x3', 0),
('sep_conv_3x3', 1)
],
normal_concat=range(2, 6),
reduce=[
('sep_conv_3x3', 1),
('sep_conv_3x3', 0),
('sep_conv_3x3', 2),
('sep_conv_3x3', 1),
('sep_conv_3x3', 1),
('sep_conv_3x3', 2),
('sep_conv_3x3', 4),
('sep_conv_5x5', 2)
],
reduce_concat=range(2, 6)
)
UNNAS_IMAGENET_COL = Genotype(
normal=[
('skip_connect', 0),
('sep_conv_3x3', 1),
('sep_conv_3x3', 1),
('skip_connect', 0),
('sep_conv_3x3', 0),
('sep_conv_3x3', 3),
('sep_conv_3x3', 0),
('sep_conv_3x3', 2)
],
normal_concat=range(2, 6),
reduce=[
('max_pool_3x3', 0),
('sep_conv_3x3', 1),
('max_pool_3x3', 0),
('sep_conv_3x3', 1),
('max_pool_3x3', 0),
('sep_conv_5x5', 3),
('max_pool_3x3', 0),
('sep_conv_3x3', 4)
],
reduce_concat=range(2, 6)
)
UNNAS_IMAGENET_JIG = Genotype(
normal=[
('sep_conv_3x3', 0),
('sep_conv_3x3', 1),
('sep_conv_3x3', 0),
('sep_conv_3x3', 1),
('sep_conv_3x3', 0),
('sep_conv_3x3', 3),
('sep_conv_3x3', 1),
('sep_conv_5x5', 0)
],
normal_concat=range(2, 6),
reduce=[
('sep_conv_5x5', 0),
('sep_conv_3x3', 1),
('sep_conv_3x3', 0),
('sep_conv_3x3', 1),
('sep_conv_3x3', 0),
('sep_conv_3x3', 1),
('sep_conv_5x5', 0),
('sep_conv_3x3', 1)
],
reduce_concat=range(2, 6)
)
UNNAS_IMAGENET22K_CLS = Genotype(
normal=[
('sep_conv_3x3', 1),
('skip_connect', 0),
('sep_conv_3x3', 1),
('sep_conv_3x3', 2),
('sep_conv_3x3', 1),
('sep_conv_3x3', 2),
('sep_conv_3x3', 1),
('sep_conv_3x3', 0)
],
normal_concat=range(2, 6),
reduce=[
('max_pool_3x3', 0),
('max_pool_3x3', 1),
('dil_conv_5x5', 2),
('max_pool_3x3', 0),
('dil_conv_5x5', 3),
('dil_conv_5x5', 2),
('dil_conv_5x5', 4),
('dil_conv_5x5', 3)
],
reduce_concat=range(2, 6)
)
UNNAS_IMAGENET22K_ROT = Genotype(
normal=[
('sep_conv_3x3', 0),
('sep_conv_3x3', 1),
('sep_conv_3x3', 0),
('sep_conv_3x3', 1),
('sep_conv_3x3', 0),
('sep_conv_3x3', 1),
('sep_conv_3x3', 0),
('sep_conv_3x3', 1)
],
normal_concat=range(2, 6),
reduce=[
('max_pool_3x3', 0),
('sep_conv_5x5', 1),
('dil_conv_5x5', 2),
('sep_conv_5x5', 0),
('dil_conv_5x5', 3),
('sep_conv_3x3', 2),
('sep_conv_3x3', 4),
('sep_conv_3x3', 3)
],
reduce_concat=range(2, 6)
)
UNNAS_IMAGENET22K_COL = Genotype(
normal=[
('sep_conv_3x3', 0),
('sep_conv_3x3', 1),
('sep_conv_3x3', 0),
('sep_conv_3x3', 1),
('sep_conv_3x3', 2),
('sep_conv_3x3', 1),
('sep_conv_3x3', 3),
('sep_conv_3x3', 0)
],
normal_concat=range(2, 6),
reduce=[
('max_pool_3x3', 0),
('skip_connect', 1),
('dil_conv_5x5', 2),
('sep_conv_3x3', 0),
('sep_conv_3x3', 3),
('sep_conv_3x3', 0),
('sep_conv_3x3', 4),
('sep_conv_5x5', 1)
],
reduce_concat=range(2, 6)
)
UNNAS_IMAGENET22K_JIG = Genotype(
normal=[
('sep_conv_3x3', 0),
('sep_conv_3x3', 1),
('sep_conv_3x3', 0),
('sep_conv_3x3', 1),
('sep_conv_3x3', 1),
('sep_conv_3x3', 0),
('sep_conv_3x3', 0),
('sep_conv_3x3', 4)
],
normal_concat=range(2, 6),
reduce=[
('sep_conv_5x5', 0),
('skip_connect', 1),
('sep_conv_5x5', 0),
('sep_conv_3x3', 2),
('sep_conv_5x5', 0),
('sep_conv_5x5', 3),
('sep_conv_5x5', 0),
('sep_conv_5x5', 4)
],
reduce_concat=range(2, 6)
)
UNNAS_CITYSCAPES_SEG = Genotype(
normal=[
('skip_connect', 0),
('sep_conv_5x5', 1),
('sep_conv_3x3', 1),
('sep_conv_3x3', 0),
('sep_conv_3x3', 0),
('sep_conv_3x3', 1),
('sep_conv_3x3', 0),
('sep_conv_3x3', 1)
],
normal_concat=range(2, 6),
reduce=[
('sep_conv_3x3', 0),
('avg_pool_3x3', 1),
('avg_pool_3x3', 1),
('sep_conv_5x5', 0),
('sep_conv_3x3', 2),
('sep_conv_5x5', 0),
('sep_conv_3x3', 4),
('sep_conv_5x5', 2)
],
reduce_concat=range(2, 6)
)
UNNAS_CITYSCAPES_ROT = Genotype(
normal=[
('sep_conv_3x3', 1),
('sep_conv_3x3', 0),
('sep_conv_3x3', 2),
('sep_conv_3x3', 1),
('sep_conv_3x3', 0),
('sep_conv_3x3', 3),
('sep_conv_3x3', 1),
('sep_conv_3x3', 0)
],
normal_concat=range(2, 6),
reduce=[
('max_pool_3x3', 0),
('sep_conv_5x5', 1),
('sep_conv_5x5', 2),
('sep_conv_5x5', 1),
('sep_conv_5x5', 3),
('dil_conv_5x5', 2),
('sep_conv_5x5', 2),
('sep_conv_5x5', 0)
],
reduce_concat=range(2, 6)
)
UNNAS_CITYSCAPES_COL = Genotype(
normal=[
('dil_conv_3x3', 1),
('sep_conv_3x3', 0),
('skip_connect', 0),
('sep_conv_5x5', 2),
('dil_conv_3x3', 3),
('skip_connect', 0),
('skip_connect', 0),
('sep_conv_3x3', 1)
],
normal_concat=range(2, 6),
reduce=[
('avg_pool_3x3', 1),
('avg_pool_3x3', 0),
('avg_pool_3x3', 1),
('avg_pool_3x3', 0),
('avg_pool_3x3', 1),
('avg_pool_3x3', 0),
('avg_pool_3x3', 1),
('skip_connect', 4)
],
reduce_concat=range(2, 6)
)
UNNAS_CITYSCAPES_JIG = Genotype(
normal=[
('dil_conv_5x5', 1),
('sep_conv_5x5', 0),
('sep_conv_3x3', 0),
('sep_conv_3x3', 1),
('sep_conv_3x3', 0),
('sep_conv_3x3', 2),
('sep_conv_3x3', 0),
('dil_conv_5x5', 1)
],
normal_concat=range(2, 6),
reduce=[
('avg_pool_3x3', 0),
('skip_connect', 1),
('dil_conv_5x5', 1),
('dil_conv_5x5', 2),
('dil_conv_5x5', 2),
('dil_conv_5x5', 0),
('dil_conv_5x5', 3),
('dil_conv_5x5', 2)
],
reduce_concat=range(2, 6)
)
# Supported genotypes
GENOTYPES = {
'nas': NASNet,
'pnas': PNASNet,
'amoeba': AmoebaNet,
'darts_v1': DARTS_V1,
'darts_v2': DARTS_V2,
'pdarts': PDARTS,
'pcdarts_c10': PCDARTS_C10,
'pcdarts_in1k': PCDARTS_IN1K,
'unnas_imagenet_cls': UNNAS_IMAGENET_CLS,
'unnas_imagenet_rot': UNNAS_IMAGENET_ROT,
'unnas_imagenet_col': UNNAS_IMAGENET_COL,
'unnas_imagenet_jig': UNNAS_IMAGENET_JIG,
'unnas_imagenet22k_cls': UNNAS_IMAGENET22K_CLS,
'unnas_imagenet22k_rot': UNNAS_IMAGENET22K_ROT,
'unnas_imagenet22k_col': UNNAS_IMAGENET22K_COL,
'unnas_imagenet22k_jig': UNNAS_IMAGENET22K_JIG,
'unnas_cityscapes_seg': UNNAS_CITYSCAPES_SEG,
'unnas_cityscapes_rot': UNNAS_CITYSCAPES_ROT,
'unnas_cityscapes_col': UNNAS_CITYSCAPES_COL,
'unnas_cityscapes_jig': UNNAS_CITYSCAPES_JIG,
'custom': None,
}
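Each genotype above pairs an op name with the index of the state it consumes, and the concat field lists which intermediate nodes are concatenated to form the cell output. An illustrative lookup into the registry:
# Inspect one of the predefined genotypes (illustrative usage only).
genotype = GENOTYPES['darts_v2']
print(genotype.normal_concat)      # [2, 3, 4, 5]
for op_name, input_node in genotype.normal:
    print(op_name, input_node)     # e.g. sep_conv_3x3 0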

337
pycls/models/nas/nas.py Normal file
View File

@@ -0,0 +1,337 @@
#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
"""NAS network (adopted from DARTS)."""
from torch.autograd import Variable
import torch
import torch.nn as nn
import pycls.core.logging as logging
from pycls.core.config import cfg
from pycls.models.common import Preprocess
from pycls.models.common import Classifier
from pycls.models.nas.genotypes import GENOTYPES
from pycls.models.nas.genotypes import Genotype
from pycls.models.nas.operations import FactorizedReduce
from pycls.models.nas.operations import OPS
from pycls.models.nas.operations import ReLUConvBN
from pycls.models.nas.operations import Identity
logger = logging.get_logger(__name__)
def drop_path(x, drop_prob):
"""Drop path (ported from DARTS)."""
if drop_prob > 0.:
keep_prob = 1.-drop_prob
mask = Variable(
torch.cuda.FloatTensor(x.size(0), 1, 1, 1).bernoulli_(keep_prob)
)
x.div_(keep_prob)
x.mul_(mask)
return x
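# Note (illustrative): drop_path zeroes each example in the batch with
# probability drop_prob and rescales the survivors by 1 / keep_prob, so with
# drop_prob = 0.2 the expected activation magnitude is unchanged in training.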
class Cell(nn.Module):
"""NAS cell (ported from DARTS)."""
def __init__(self, genotype, C_prev_prev, C_prev, C, reduction, reduction_prev):
super(Cell, self).__init__()
logger.info('{}, {}, {}'.format(C_prev_prev, C_prev, C))
if reduction_prev:
self.preprocess0 = FactorizedReduce(C_prev_prev, C)
else:
self.preprocess0 = ReLUConvBN(C_prev_prev, C, 1, 1, 0)
self.preprocess1 = ReLUConvBN(C_prev, C, 1, 1, 0)
if reduction:
op_names, indices = zip(*genotype.reduce)
concat = genotype.reduce_concat
else:
op_names, indices = zip(*genotype.normal)
concat = genotype.normal_concat
self._compile(C, op_names, indices, concat, reduction)
def _compile(self, C, op_names, indices, concat, reduction):
assert len(op_names) == len(indices)
self._steps = len(op_names) // 2
self._concat = concat
self.multiplier = len(concat)
self._ops = nn.ModuleList()
for name, index in zip(op_names, indices):
stride = 2 if reduction and index < 2 else 1
op = OPS[name](C, stride, True)
self._ops += [op]
self._indices = indices
def forward(self, s0, s1, drop_prob):
s0 = self.preprocess0(s0)
s1 = self.preprocess1(s1)
states = [s0, s1]
for i in range(self._steps):
h1 = states[self._indices[2*i]]
h2 = states[self._indices[2*i+1]]
op1 = self._ops[2*i]
op2 = self._ops[2*i+1]
h1 = op1(h1)
h2 = op2(h2)
if self.training and drop_prob > 0.:
if not isinstance(op1, Identity):
h1 = drop_path(h1, drop_prob)
if not isinstance(op2, Identity):
h2 = drop_path(h2, drop_prob)
s = h1 + h2
states += [s]
return torch.cat([states[i] for i in self._concat], dim=1)
class AuxiliaryHeadCIFAR(nn.Module):
def __init__(self, C, num_classes):
"""assuming input size 8x8"""
super(AuxiliaryHeadCIFAR, self).__init__()
self.features = nn.Sequential(
nn.ReLU(inplace=True),
nn.AvgPool2d(5, stride=3, padding=0, count_include_pad=False), # image size = 2 x 2
nn.Conv2d(C, 128, 1, bias=False),
nn.BatchNorm2d(128),
nn.ReLU(inplace=True),
nn.Conv2d(128, 768, 2, bias=False),
nn.BatchNorm2d(768),
nn.ReLU(inplace=True)
)
self.classifier = nn.Linear(768, num_classes)
def forward(self, x):
x = self.features(x)
x = self.classifier(x.view(x.size(0),-1))
return x
class AuxiliaryHeadImageNet(nn.Module):
def __init__(self, C, num_classes):
"""assuming input size 14x14"""
super(AuxiliaryHeadImageNet, self).__init__()
self.features = nn.Sequential(
nn.ReLU(inplace=True),
nn.AvgPool2d(5, stride=2, padding=0, count_include_pad=False),
nn.Conv2d(C, 128, 1, bias=False),
nn.BatchNorm2d(128),
nn.ReLU(inplace=True),
nn.Conv2d(128, 768, 2, bias=False),
# NOTE: This batchnorm was omitted in my earlier implementation due to a typo.
# Commenting it out for consistency with the experiments in the paper.
# nn.BatchNorm2d(768),
nn.ReLU(inplace=True)
)
self.classifier = nn.Linear(768, num_classes)
def forward(self, x):
x = self.features(x)
x = self.classifier(x.view(x.size(0),-1))
return x
class NetworkCIFAR(nn.Module):
"""CIFAR network (ported from DARTS)."""
def __init__(self, C, num_classes, layers, auxiliary, genotype):
super(NetworkCIFAR, self).__init__()
self._layers = layers
self._auxiliary = auxiliary
stem_multiplier = 3
C_curr = stem_multiplier*C
self.stem = nn.Sequential(
nn.Conv2d(cfg.MODEL.INPUT_CHANNELS, C_curr, 3, padding=1, bias=False),
nn.BatchNorm2d(C_curr)
)
C_prev_prev, C_prev, C_curr = C_curr, C_curr, C
self.cells = nn.ModuleList()
reduction_prev = False
for i in range(layers):
if i in [layers//3, 2*layers//3]:
C_curr *= 2
reduction = True
else:
reduction = False
cell = Cell(genotype, C_prev_prev, C_prev, C_curr, reduction, reduction_prev)
reduction_prev = reduction
self.cells += [cell]
C_prev_prev, C_prev = C_prev, cell.multiplier*C_curr
if i == 2*layers//3:
C_to_auxiliary = C_prev
if auxiliary:
self.auxiliary_head = AuxiliaryHeadCIFAR(C_to_auxiliary, num_classes)
self.classifier = Classifier(C_prev, num_classes)
def forward(self, input):
input = Preprocess(input)
logits_aux = None
s0 = s1 = self.stem(input)
for i, cell in enumerate(self.cells):
s0, s1 = s1, cell(s0, s1, self.drop_path_prob)
if i == 2*self._layers//3:
if self._auxiliary and self.training:
logits_aux = self.auxiliary_head(s1)
logits = self.classifier(s1, input.shape[2:])
if self._auxiliary and self.training:
return logits, logits_aux
return logits
def _loss(self, input, target, return_logits=False):
logits = self(input)
loss = self._criterion(logits, target)
return (loss, logits) if return_logits else loss
def step(self, input, target, args, shared=None, return_grad=False):
Lt, logit_t = self._loss(input, target, return_logits=True)
Lt.backward()
if args.grad_clip != 0:
nn.utils.clip_grad_norm_(self.get_weights(), args.grad_clip)
self.optimizer.step()
if return_grad:
grad = torch.nn.utils.parameters_to_vector([p.grad for p in self.get_weights()])
return logit_t, Lt, grad
else:
return logit_t, Lt
class NetworkImageNet(nn.Module):
"""ImageNet network (ported from DARTS)."""
def __init__(self, C, num_classes, layers, auxiliary, genotype):
super(NetworkImageNet, self).__init__()
self._layers = layers
self._auxiliary = auxiliary
self.stem0 = nn.Sequential(
nn.Conv2d(cfg.MODEL.INPUT_CHANNELS, C // 2, kernel_size=3, stride=2, padding=1, bias=False),
nn.BatchNorm2d(C // 2),
nn.ReLU(inplace=True),
nn.Conv2d(C // 2, C, 3, stride=2, padding=1, bias=False),
nn.BatchNorm2d(C),
)
self.stem1 = nn.Sequential(
nn.ReLU(inplace=True),
nn.Conv2d(C, C, 3, stride=2, padding=1, bias=False),
nn.BatchNorm2d(C),
)
C_prev_prev, C_prev, C_curr = C, C, C
self.cells = nn.ModuleList()
reduction_prev = True
reduction_layers = [layers//3] if cfg.TASK == 'seg' else [layers//3, 2*layers//3]
for i in range(layers):
if i in reduction_layers:
C_curr *= 2
reduction = True
else:
reduction = False
cell = Cell(genotype, C_prev_prev, C_prev, C_curr, reduction, reduction_prev)
reduction_prev = reduction
self.cells += [cell]
C_prev_prev, C_prev = C_prev, cell.multiplier * C_curr
if i == 2 * layers // 3:
C_to_auxiliary = C_prev
if auxiliary:
self.auxiliary_head = AuxiliaryHeadImageNet(C_to_auxiliary, num_classes)
self.classifier = Classifier(C_prev, num_classes)
def forward(self, input):
input = Preprocess(input)
logits_aux = None
s0 = self.stem0(input)
s1 = self.stem1(s0)
for i, cell in enumerate(self.cells):
s0, s1 = s1, cell(s0, s1, self.drop_path_prob)
if i == 2 * self._layers // 3:
if self._auxiliary and self.training:
logits_aux = self.auxiliary_head(s1)
logits = self.classifier(s1, input.shape[2:])
if self._auxiliary and self.training:
return logits, logits_aux
return logits
def _loss(self, input, target, return_logits=False):
logits = self(input)
loss = self._criterion(logits, target)
return (loss, logits) if return_logits else loss
def step(self, input, target, args, shared=None, return_grad=False):
Lt, logit_t = self._loss(input, target, return_logits=True)
Lt.backward()
if args.grad_clip != 0:
nn.utils.clip_grad_norm_(self.get_weights(), args.grad_clip)
self.optimizer.step()
if return_grad:
grad = torch.nn.utils.parameters_to_vector([p.grad for p in self.get_weights()])
return logit_t, Lt, grad
else:
return logit_t, Lt
class NAS(nn.Module):
"""NAS net wrapper (delegates to nets from DARTS)."""
def __init__(self):
assert cfg.TRAIN.DATASET in ['cifar10', 'imagenet', 'cityscapes'], \
'Training on {} is not supported'.format(cfg.TRAIN.DATASET)
assert cfg.TEST.DATASET in ['cifar10', 'imagenet', 'cityscapes'], \
'Testing on {} is not supported'.format(cfg.TEST.DATASET)
assert cfg.NAS.GENOTYPE in GENOTYPES, \
'Genotype {} not supported'.format(cfg.NAS.GENOTYPE)
super(NAS, self).__init__()
logger.info('Constructing NAS: {}'.format(cfg.NAS))
# Use a custom or predefined genotype
if cfg.NAS.GENOTYPE == 'custom':
genotype = Genotype(
normal=cfg.NAS.CUSTOM_GENOTYPE[0],
normal_concat=cfg.NAS.CUSTOM_GENOTYPE[1],
reduce=cfg.NAS.CUSTOM_GENOTYPE[2],
reduce_concat=cfg.NAS.CUSTOM_GENOTYPE[3],
)
else:
genotype = GENOTYPES[cfg.NAS.GENOTYPE]
# Determine the network constructor for dataset
if 'cifar' in cfg.TRAIN.DATASET:
net_ctor = NetworkCIFAR
else:
net_ctor = NetworkImageNet
# Construct the network
self.net_ = net_ctor(
C=cfg.NAS.WIDTH,
num_classes=cfg.MODEL.NUM_CLASSES,
layers=cfg.NAS.DEPTH,
auxiliary=cfg.NAS.AUX,
genotype=genotype
)
# Drop path probability (set / annealed based on epoch)
self.net_.drop_path_prob = 0.0
def set_drop_path_prob(self, drop_path_prob):
self.net_.drop_path_prob = drop_path_prob
def forward(self, x):
return self.net_.forward(x)
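A hedged usage sketch for the wrapper above: cfg must already specify the dataset, genotype, width, and depth, and the drop path probability is typically annealed linearly over training as in DARTS (the 0.2 ceiling and epoch count below are assumptions, not repository defaults):
# Illustrative only: assumes cfg has been loaded and frozen elsewhere.
model = NAS().cuda()
max_epoch, drop_prob_max = 600, 0.2
for cur_epoch in range(max_epoch):
    model.set_drop_path_prob(drop_prob_max * cur_epoch / max_epoch)
    # ... run one training epoch on `model` ...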

219
pycls/models/nas/operations.py Normal file
View File

@@ -0,0 +1,219 @@
#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
"""NAS ops (adopted from DARTS)."""
import torch
import torch.nn as nn
from torch.autograd import Variable
OPS = {
'none': lambda C, stride, affine:
Zero(stride),
'noise': lambda C, stride, affine: NoiseOp(stride, 0., 1.),
'avg_pool_2x2': lambda C, stride, affine:
nn.AvgPool2d(2, stride=stride, padding=0, count_include_pad=False),
'avg_pool_3x3': lambda C, stride, affine:
nn.AvgPool2d(3, stride=stride, padding=1, count_include_pad=False),
'avg_pool_5x5': lambda C, stride, affine:
nn.AvgPool2d(5, stride=stride, padding=2, count_include_pad=False),
'max_pool_2x2': lambda C, stride, affine:
nn.MaxPool2d(2, stride=stride, padding=0),
'max_pool_3x3': lambda C, stride, affine:
nn.MaxPool2d(3, stride=stride, padding=1),
'max_pool_5x5': lambda C, stride, affine:
nn.MaxPool2d(5, stride=stride, padding=2),
'max_pool_7x7': lambda C, stride, affine:
nn.MaxPool2d(7, stride=stride, padding=3),
'skip_connect': lambda C, stride, affine:
Identity() if stride == 1 else FactorizedReduce(C, C, affine=affine),
'conv_1x1': lambda C, stride, affine:
nn.Sequential(
nn.ReLU(inplace=False),
nn.Conv2d(C, C, 1, stride=stride, padding=0, bias=False),
nn.BatchNorm2d(C, affine=affine)
),
'conv_3x3': lambda C, stride, affine:
nn.Sequential(
nn.ReLU(inplace=False),
nn.Conv2d(C, C, 3, stride=stride, padding=1, bias=False),
nn.BatchNorm2d(C, affine=affine)
),
'sep_conv_3x3': lambda C, stride, affine:
SepConv(C, C, 3, stride, 1, affine=affine),
'sep_conv_5x5': lambda C, stride, affine:
SepConv(C, C, 5, stride, 2, affine=affine),
'sep_conv_7x7': lambda C, stride, affine:
SepConv(C, C, 7, stride, 3, affine=affine),
'dil_conv_3x3': lambda C, stride, affine:
DilConv(C, C, 3, stride, 2, 2, affine=affine),
'dil_conv_5x5': lambda C, stride, affine:
DilConv(C, C, 5, stride, 4, 2, affine=affine),
'dil_sep_conv_3x3': lambda C, stride, affine:
DilSepConv(C, C, 3, stride, 2, 2, affine=affine),
'conv_3x1_1x3': lambda C, stride, affine:
nn.Sequential(
nn.ReLU(inplace=False),
nn.Conv2d(C, C, (1,3), stride=(1, stride), padding=(0, 1), bias=False),
nn.Conv2d(C, C, (3,1), stride=(stride, 1), padding=(1, 0), bias=False),
nn.BatchNorm2d(C, affine=affine)
),
'conv_7x1_1x7': lambda C, stride, affine:
nn.Sequential(
nn.ReLU(inplace=False),
nn.Conv2d(C, C, (1,7), stride=(1, stride), padding=(0, 3), bias=False),
nn.Conv2d(C, C, (7,1), stride=(stride, 1), padding=(3, 0), bias=False),
nn.BatchNorm2d(C, affine=affine)
),
}
class NoiseOp(nn.Module):
def __init__(self, stride, mean, std):
super(NoiseOp, self).__init__()
self.stride = stride
self.mean = mean
self.std = std
def forward(self, x, block_input=False):
if block_input:
x = x*0
if self.stride != 1:
x_new = x[:,:,::self.stride,::self.stride]
else:
x_new = x
noise = Variable(x_new.data.new(x_new.size()).normal_(self.mean, self.std))
return noise
class ReLUConvBN(nn.Module):
def __init__(self, C_in, C_out, kernel_size, stride, padding, affine=True):
super(ReLUConvBN, self).__init__()
self.op = nn.Sequential(
nn.ReLU(inplace=False),
nn.Conv2d(
C_in, C_out, kernel_size, stride=stride,
padding=padding, bias=False
),
nn.BatchNorm2d(C_out, affine=affine)
)
def forward(self, x):
return self.op(x)
class DilConv(nn.Module):
def __init__(
self, C_in, C_out, kernel_size, stride, padding, dilation, affine=True
):
super(DilConv, self).__init__()
self.op = nn.Sequential(
nn.ReLU(inplace=False),
nn.Conv2d(
C_in, C_in, kernel_size=kernel_size, stride=stride,
padding=padding, dilation=dilation, groups=C_in, bias=False
),
nn.Conv2d(C_in, C_out, kernel_size=1, padding=0, bias=False),
nn.BatchNorm2d(C_out, affine=affine),
)
def forward(self, x):
return self.op(x)
class SepConv(nn.Module):
def __init__(self, C_in, C_out, kernel_size, stride, padding, affine=True):
super(SepConv, self).__init__()
self.op = nn.Sequential(
nn.ReLU(inplace=False),
nn.Conv2d(
C_in, C_in, kernel_size=kernel_size, stride=stride,
padding=padding, groups=C_in, bias=False
),
nn.Conv2d(C_in, C_in, kernel_size=1, padding=0, bias=False),
nn.BatchNorm2d(C_in, affine=affine),
nn.ReLU(inplace=False),
nn.Conv2d(
C_in, C_in, kernel_size=kernel_size, stride=1,
padding=padding, groups=C_in, bias=False
),
nn.Conv2d(C_in, C_out, kernel_size=1, padding=0, bias=False),
nn.BatchNorm2d(C_out, affine=affine),
)
def forward(self, x):
return self.op(x)
class DilSepConv(nn.Module):
def __init__(
self, C_in, C_out, kernel_size, stride, padding, dilation, affine=True
):
super(DilSepConv, self).__init__()
self.op = nn.Sequential(
nn.ReLU(inplace=False),
nn.Conv2d(
C_in, C_in, kernel_size=kernel_size, stride=stride,
padding=padding, dilation=dilation, groups=C_in, bias=False
),
nn.Conv2d(C_in, C_in, kernel_size=1, padding=0, bias=False),
nn.BatchNorm2d(C_in, affine=affine),
nn.ReLU(inplace=False),
nn.Conv2d(
C_in, C_in, kernel_size=kernel_size, stride=1,
padding=padding, dilation=dilation, groups=C_in, bias=False
),
nn.Conv2d(C_in, C_out, kernel_size=1, padding=0, bias=False),
nn.BatchNorm2d(C_out, affine=affine),
)
def forward(self, x):
return self.op(x)
class Identity(nn.Module):
def __init__(self):
super(Identity, self).__init__()
def forward(self, x):
return x
class Zero(nn.Module):
def __init__(self, stride):
super(Zero, self).__init__()
self.stride = stride
def forward(self, x):
if self.stride == 1:
return x.mul(0.)
return x[:,:,::self.stride,::self.stride].mul(0.)
class FactorizedReduce(nn.Module):
def __init__(self, C_in, C_out, affine=True):
super(FactorizedReduce, self).__init__()
assert C_out % 2 == 0
self.relu = nn.ReLU(inplace=False)
self.conv_1 = nn.Conv2d(C_in, C_out // 2, 1, stride=2, padding=0, bias=False)
self.conv_2 = nn.Conv2d(C_in, C_out // 2, 1, stride=2, padding=0, bias=False)
self.bn = nn.BatchNorm2d(C_out, affine=affine)
self.pad = nn.ConstantPad2d((0, 1, 0, 1), 0)
def forward(self, x):
x = self.relu(x)
y = self.pad(x)
out = torch.cat([self.conv_1(x), self.conv_2(y[:,:,1:,1:])], dim=1)
out = self.bn(out)
return out
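To illustrate the ops above with arbitrary sizes: FactorizedReduce halves the spatial resolution with two stride-2 1x1 convolutions over offset views and concatenates the results, and any OPS entry is instantiated the same way from its name:
import torch

# Illustrative shape checks for the ops above (channel sizes are assumptions).
fr = FactorizedReduce(C_in=16, C_out=32)
x = torch.randn(2, 16, 8, 8)
print(fr(x).shape)                     # torch.Size([2, 32, 4, 4])

op = OPS['sep_conv_5x5'](16, 2, True)  # C=16, stride=2, affine=True
print(op(x).shape)                     # torch.Size([2, 16, 4, 4])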

169
sota/cnn/genotypes.py Normal file
View File

@@ -0,0 +1,169 @@
from collections import namedtuple
Genotype = namedtuple('Genotype', 'normal normal_concat reduce reduce_concat')
PRIMITIVES = [
'none',
'noise',
'max_pool_3x3',
'avg_pool_3x3',
'skip_connect',
'sep_conv_3x3',
'sep_conv_5x5',
'dil_conv_3x3',
'dil_conv_5x5'
]
######## S1-S4 Space ########
#### cifar10 s1 - s4
init_pt_s1_C10_0 = Genotype(normal=[["dil_conv_3x3", 0], ["dil_conv_5x5", 1], ["sep_conv_3x3", 1], ["dil_conv_3x3", 2], ["sep_conv_3x3", 1], ["sep_conv_3x3", 2], ["sep_conv_3x3", 0], ["dil_conv_3x3", 4]], normal_concat=range(2, 6), reduce=[["avg_pool_3x3", 0], ["dil_conv_3x3", 1], ["avg_pool_3x3", 1], ["skip_connect", 2], ["sep_conv_3x3", 1], ["dil_conv_3x3", 2], ["dil_conv_5x5", 2], ["dil_conv_5x5", 4]], reduce_concat=range(2, 6))
init_pt_s1_C10_2 = Genotype(normal=[["skip_connect", 0], ["dil_conv_5x5", 1], ["sep_conv_3x3", 1], ["dil_conv_3x3", 2], ["sep_conv_3x3", 1], ["sep_conv_3x3", 2], ["sep_conv_3x3", 0], ["dil_conv_5x5", 4]], normal_concat=range(2, 6), reduce=[["max_pool_3x3", 0], ["dil_conv_3x3", 1], ["max_pool_3x3", 0], ["avg_pool_3x3", 1], ["sep_conv_3x3", 1], ["dil_conv_5x5", 3], ["dil_conv_5x5", 3], ["dil_conv_5x5", 4]], reduce_concat=range(2, 6))
init_pt_s2_C10_0 = Genotype(normal=[["sep_conv_3x3", 0], ["skip_connect", 1], ["sep_conv_3x3", 1], ["sep_conv_3x3", 2], ["sep_conv_3x3", 0], ["sep_conv_3x3", 1], ["sep_conv_3x3", 1], ["sep_conv_3x3", 2]], normal_concat=range(2, 6), reduce=[["sep_conv_3x3", 0], ["sep_conv_3x3", 1], ["skip_connect", 0], ["skip_connect", 1], ["sep_conv_3x3", 0], ["sep_conv_3x3", 1], ["sep_conv_3x3", 0], ["sep_conv_3x3", 1]], reduce_concat=range(2, 6))
init_pt_s2_C10_2 = Genotype(normal=[["sep_conv_3x3", 0], ["sep_conv_3x3", 1], ["sep_conv_3x3", 1], ["sep_conv_3x3", 2], ["skip_connect", 0], ["sep_conv_3x3", 3], ["sep_conv_3x3", 1], ["sep_conv_3x3", 4]], normal_concat=range(2, 6), reduce=[["sep_conv_3x3", 0], ["sep_conv_3x3", 1], ["sep_conv_3x3", 0], ["skip_connect", 1], ["sep_conv_3x3", 0], ["sep_conv_3x3", 1], ["skip_connect", 0], ["skip_connect", 1]], reduce_concat=range(2, 6))
init_pt_s3_C10_0 = Genotype(normal=[["sep_conv_3x3", 0], ["skip_connect", 1], ["sep_conv_3x3", 1], ["sep_conv_3x3", 2], ["sep_conv_3x3", 0], ["sep_conv_3x3", 1], ["sep_conv_3x3", 1], ["sep_conv_3x3", 2]], normal_concat=range(2, 6), reduce=[["sep_conv_3x3", 0], ["sep_conv_3x3", 1], ["skip_connect", 0], ["skip_connect", 1], ["sep_conv_3x3", 0], ["sep_conv_3x3", 1], ["sep_conv_3x3", 0], ["sep_conv_3x3", 1]], reduce_concat=range(2, 6))
init_pt_s3_C10_2 = Genotype(normal=[["sep_conv_3x3", 0], ["sep_conv_3x3", 1], ["sep_conv_3x3", 1], ["sep_conv_3x3", 2], ["skip_connect", 0], ["sep_conv_3x3", 3], ["sep_conv_3x3", 1], ["sep_conv_3x3", 4]], normal_concat=range(2, 6), reduce=[["sep_conv_3x3", 0], ["sep_conv_3x3", 1], ["sep_conv_3x3", 0], ["sep_conv_3x3", 1], ["sep_conv_3x3", 0], ["sep_conv_3x3", 1], ["sep_conv_3x3", 0], ["sep_conv_3x3", 1]], reduce_concat=range(2, 6))
init_pt_s4_C10_0 = Genotype(normal=[["sep_conv_3x3", 0], ["noise", 1], ["noise", 1], ["sep_conv_3x3", 2], ["sep_conv_3x3", 1], ["sep_conv_3x3", 2], ["sep_conv_3x3", 2], ["sep_conv_3x3", 3]], normal_concat=range(2, 6), reduce=[["sep_conv_3x3", 0], ["sep_conv_3x3", 1], ["sep_conv_3x3", 0], ["sep_conv_3x3", 1], ["sep_conv_3x3", 0], ["sep_conv_3x3", 1], ["sep_conv_3x3", 0], ["sep_conv_3x3", 1]], reduce_concat=range(2, 6))
init_pt_s4_C10_2 = Genotype(normal=[["sep_conv_3x3", 0], ["noise", 1], ["sep_conv_3x3", 1], ["noise", 2], ["sep_conv_3x3", 2], ["sep_conv_3x3", 3], ["sep_conv_3x3", 2], ["sep_conv_3x3", 3]], normal_concat=range(2, 6), reduce=[["sep_conv_3x3", 0], ["sep_conv_3x3", 1], ["sep_conv_3x3", 1], ["noise", 2], ["sep_conv_3x3", 0], ["sep_conv_3x3", 1], ["sep_conv_3x3", 0], ["sep_conv_3x3", 1]], reduce_concat=range(2, 6))
#### cifar100 s1 - s4
init_pt_s1_C100_0 = Genotype(normal=[["dil_conv_3x3", 0], ["dil_conv_5x5", 1], ["sep_conv_3x3", 1], ["dil_conv_3x3", 2], ["sep_conv_3x3", 1], ["sep_conv_3x3", 2], ["sep_conv_3x3", 0], ["dil_conv_3x3", 4]], normal_concat=range(2, 6), reduce=[["avg_pool_3x3", 0], ["dil_conv_3x3", 1], ["avg_pool_3x3", 0], ["dil_conv_5x5", 2], ["sep_conv_3x3", 1], ["dil_conv_3x3", 2], ["avg_pool_3x3", 1], ["dil_conv_5x5", 2]], reduce_concat=range(2, 6))
init_pt_s1_C100_2 = Genotype(normal=[["dil_conv_3x3", 0], ["dil_conv_5x5", 1], ["dil_conv_5x5", 0], ["sep_conv_3x3", 1], ["sep_conv_3x3", 1], ["sep_conv_3x3", 2], ["sep_conv_3x3", 0], ["dil_conv_3x3", 4]], normal_concat=range(2, 6), reduce=[["avg_pool_3x3", 0], ["dil_conv_3x3", 1], ["avg_pool_3x3", 1], ["dil_conv_5x5", 2], ["sep_conv_3x3", 1], ["dil_conv_5x5", 3], ["dil_conv_5x5", 3], ["dil_conv_5x5", 4]], reduce_concat=range(2, 6))
init_pt_s2_C100_0 = Genotype(normal=[["skip_connect", 0], ["sep_conv_3x3", 1], ["sep_conv_3x3", 0], ["sep_conv_3x3", 2], ["sep_conv_3x3", 1], ["sep_conv_3x3", 2], ["sep_conv_3x3", 1], ["sep_conv_3x3", 4]], normal_concat=range(2, 6), reduce=[["sep_conv_3x3", 0], ["sep_conv_3x3", 1], ["skip_connect", 0], ["skip_connect", 1], ["sep_conv_3x3", 0], ["sep_conv_3x3", 1], ["sep_conv_3x3", 0], ["sep_conv_3x3", 1]], reduce_concat=range(2, 6))
init_pt_s2_C100_2 = Genotype(normal=[["skip_connect", 0], ["sep_conv_3x3", 1], ["sep_conv_3x3", 0], ["sep_conv_3x3", 1], ["sep_conv_3x3", 0], ["sep_conv_3x3", 2], ["sep_conv_3x3", 1], ["sep_conv_3x3", 3]], normal_concat=range(2, 6), reduce=[["skip_connect", 0], ["sep_conv_3x3", 1], ["sep_conv_3x3", 0], ["skip_connect", 1], ["sep_conv_3x3", 0], ["sep_conv_3x3", 1], ["sep_conv_3x3", 0], ["sep_conv_3x3", 1]], reduce_concat=range(2, 6))
init_pt_s3_C100_0 = Genotype(normal=[["skip_connect", 0], ["sep_conv_3x3", 1], ["sep_conv_3x3", 0], ["sep_conv_3x3", 2], ["sep_conv_3x3", 0], ["sep_conv_3x3", 1], ["sep_conv_3x3", 0], ["sep_conv_3x3", 4]], normal_concat=range(2, 6), reduce=[["sep_conv_3x3", 0], ["sep_conv_3x3", 1], ["sep_conv_3x3", 0], ["skip_connect", 1], ["skip_connect", 0], ["skip_connect", 1], ["sep_conv_3x3", 0], ["sep_conv_3x3", 1]], reduce_concat=range(2, 6))
init_pt_s3_C100_2 = Genotype(normal=[["sep_conv_3x3", 0], ["sep_conv_3x3", 1], ["skip_connect", 0], ["sep_conv_3x3", 2], ["sep_conv_3x3", 1], ["sep_conv_3x3", 3], ["sep_conv_3x3", 1], ["sep_conv_3x3", 3]], normal_concat=range(2, 6), reduce=[["skip_connect", 0], ["skip_connect", 1], ["sep_conv_3x3", 0], ["skip_connect", 1], ["sep_conv_3x3", 0], ["sep_conv_3x3", 1], ["sep_conv_3x3", 0], ["sep_conv_3x3", 1]], reduce_concat=range(2, 6))
init_pt_s4_C100_0 = Genotype(normal=[["sep_conv_3x3", 0], ["noise", 1], ["sep_conv_3x3", 1], ["noise", 2], ["sep_conv_3x3", 2], ["sep_conv_3x3", 3], ["sep_conv_3x3", 2], ["sep_conv_3x3", 3]], normal_concat=range(2, 6), reduce=[["sep_conv_3x3", 0], ["sep_conv_3x3", 1], ["sep_conv_3x3", 0], ["sep_conv_3x3", 1], ["sep_conv_3x3", 0], ["sep_conv_3x3", 1], ["sep_conv_3x3", 0], ["sep_conv_3x3", 1]], reduce_concat=range(2, 6))
init_pt_s4_C100_2 = Genotype(normal=[["noise", 0], ["sep_conv_3x3", 1], ["noise", 0], ["sep_conv_3x3", 2], ["sep_conv_3x3", 0], ["sep_conv_3x3", 3], ["sep_conv_3x3", 2], ["sep_conv_3x3", 3]], normal_concat=range(2, 6), reduce=[["sep_conv_3x3", 0], ["sep_conv_3x3", 1], ["sep_conv_3x3", 0], ["sep_conv_3x3", 1], ["sep_conv_3x3", 0], ["sep_conv_3x3", 1], ["sep_conv_3x3", 0], ["sep_conv_3x3", 1]], reduce_concat=range(2, 6))
#### svhn s1 - s4
init_pt_s1_svhn_0 = Genotype(normal=[["dil_conv_3x3", 0], ["dil_conv_5x5", 1], ["sep_conv_3x3", 1], ["dil_conv_3x3", 2], ["sep_conv_3x3", 1], ["sep_conv_3x3", 2], ["sep_conv_3x3", 0], ["dil_conv_3x3", 3]], normal_concat=range(2, 6), reduce=[["avg_pool_3x3", 0], ["dil_conv_3x3", 1], ["avg_pool_3x3", 1], ["dil_conv_5x5", 2], ["sep_conv_3x3", 1], ["dil_conv_5x5", 3], ["dil_conv_5x5", 2], ["dil_conv_5x5", 4]], reduce_concat=range(2, 6))
init_pt_s1_svhn_2 = Genotype(normal=[["dil_conv_3x3", 0], ["dil_conv_5x5", 1], ["dil_conv_5x5", 0], ["sep_conv_3x3", 1], ["sep_conv_3x3", 1], ["sep_conv_3x3", 2], ["sep_conv_3x3", 0], ["dil_conv_3x3", 3]], normal_concat=range(2, 6), reduce=[["max_pool_3x3", 0], ["dil_conv_3x3", 1], ["max_pool_3x3", 0], ["dil_conv_5x5", 2], ["sep_conv_3x3", 1], ["dil_conv_5x5", 3], ["avg_pool_3x3", 0], ["dil_conv_5x5", 3]], reduce_concat=range(2, 6))
init_pt_s2_svhn_0 = Genotype(normal=[["sep_conv_3x3", 0], ["sep_conv_3x3", 1], ["sep_conv_3x3", 0], ["sep_conv_3x3", 2], ["sep_conv_3x3", 0], ["sep_conv_3x3", 1], ["sep_conv_3x3", 1], ["sep_conv_3x3", 2]], normal_concat=range(2, 6), reduce=[["sep_conv_3x3", 0], ["sep_conv_3x3", 1], ["sep_conv_3x3", 0], ["sep_conv_3x3", 1], ["skip_connect", 0], ["sep_conv_3x3", 1], ["sep_conv_3x3", 0], ["sep_conv_3x3", 1]], reduce_concat=range(2, 6))
init_pt_s2_svhn_2 = Genotype(normal=[["skip_connect", 0], ["sep_conv_3x3", 1], ["sep_conv_3x3", 1], ["sep_conv_3x3", 2], ["sep_conv_3x3", 0], ["sep_conv_3x3", 1], ["sep_conv_3x3", 1], ["sep_conv_3x3", 2]], normal_concat=range(2, 6), reduce=[["skip_connect", 0], ["sep_conv_3x3", 1], ["sep_conv_3x3", 0], ["sep_conv_3x3", 1], ["sep_conv_3x3", 0], ["skip_connect", 1], ["sep_conv_3x3", 0], ["sep_conv_3x3", 1]], reduce_concat=range(2, 6))
init_pt_s3_svhn_0 = Genotype(normal=[["sep_conv_3x3", 0], ["sep_conv_3x3", 1], ["sep_conv_3x3", 1], ["sep_conv_3x3", 2], ["sep_conv_3x3", 0], ["sep_conv_3x3", 1], ["sep_conv_3x3", 1], ["sep_conv_3x3", 2]], normal_concat=range(2, 6), reduce=[["sep_conv_3x3", 0], ["sep_conv_3x3", 1], ["sep_conv_3x3", 0], ["sep_conv_3x3", 1], ["sep_conv_3x3", 0], ["sep_conv_3x3", 1], ["sep_conv_3x3", 0], ["sep_conv_3x3", 1]], reduce_concat=range(2, 6))
init_pt_s3_svhn_2 = Genotype(normal=[["skip_connect", 0], ["sep_conv_3x3", 1], ["sep_conv_3x3", 1], ["sep_conv_3x3", 2], ["sep_conv_3x3", 0], ["sep_conv_3x3", 2], ["sep_conv_3x3", 1], ["sep_conv_3x3", 4]], normal_concat=range(2, 6), reduce=[["sep_conv_3x3", 0], ["sep_conv_3x3", 1], ["sep_conv_3x3", 0], ["skip_connect", 1], ["sep_conv_3x3", 0], ["sep_conv_3x3", 1], ["sep_conv_3x3", 0], ["sep_conv_3x3", 1]], reduce_concat=range(2, 6))
init_pt_s4_svhn_0 = Genotype(normal=[["noise", 0], ["sep_conv_3x3", 1], ["sep_conv_3x3", 0], ["noise", 2], ["sep_conv_3x3", 2], ["sep_conv_3x3", 3], ["sep_conv_3x3", 2], ["sep_conv_3x3", 3]], normal_concat=range(2, 6), reduce=[["sep_conv_3x3", 0], ["sep_conv_3x3", 1], ["sep_conv_3x3", 0], ["sep_conv_3x3", 1], ["sep_conv_3x3", 1], ["noise", 3], ["sep_conv_3x3", 0], ["sep_conv_3x3", 1]], reduce_concat=range(2, 6))
init_pt_s4_svhn_2 = Genotype(normal=[["sep_conv_3x3", 0], ["noise", 1], ["noise", 0], ["sep_conv_3x3", 2], ["sep_conv_3x3", 2], ["sep_conv_3x3", 3], ["sep_conv_3x3", 1], ["sep_conv_3x3", 2]], normal_concat=range(2, 6), reduce=[["sep_conv_3x3", 0], ["sep_conv_3x3", 1], ["sep_conv_3x3", 0], ["sep_conv_3x3", 1], ["sep_conv_3x3", 0], ["sep_conv_3x3", 1], ["sep_conv_3x3", 0], ["sep_conv_3x3", 1]], reduce_concat=range(2, 6))
######## DARTS Space ########
####init-100-N10
init_pt_s5_C10_0_100_N10 = Genotype(normal=[["sep_conv_5x5", 0], ["sep_conv_5x5", 1], ["sep_conv_3x3", 1], ["sep_conv_5x5", 2], ["sep_conv_3x3", 0], ["sep_conv_3x3", 2], ["sep_conv_5x5", 2], ["sep_conv_3x3", 4]], normal_concat=range(2, 6), reduce=[["sep_conv_5x5", 0], ["sep_conv_3x3", 1], ["sep_conv_5x5", 0], ["sep_conv_3x3", 1], ["sep_conv_5x5", 1], ["sep_conv_3x3", 3], ["sep_conv_3x3", 0], ["sep_conv_5x5", 1]], reduce_concat=range(2, 6))
init_pt_s5_C10_1_100_N10 = Genotype(normal=[["skip_connect", 0], ["sep_conv_5x5", 1], ["sep_conv_3x3", 1], ["sep_conv_5x5", 2], ["sep_conv_5x5", 1], ["sep_conv_5x5", 2], ["sep_conv_3x3", 0], ["sep_conv_5x5", 4]], normal_concat=range(2, 6), reduce=[["sep_conv_3x3", 0], ["sep_conv_5x5", 1], ["skip_connect", 0], ["sep_conv_5x5", 1], ["sep_conv_3x3", 0], ["sep_conv_5x5", 1], ["sep_conv_5x5", 0], ["sep_conv_3x3", 1]], reduce_concat=range(2, 6))
init_pt_s5_C10_2_100_N10 = Genotype(normal=[["skip_connect", 0], ["sep_conv_5x5", 1], ["sep_conv_5x5", 0], ["sep_conv_5x5", 2], ["sep_conv_5x5", 0], ["sep_conv_5x5", 2], ["sep_conv_3x3", 0], ["sep_conv_5x5", 3]], normal_concat=range(2, 6), reduce=[["sep_conv_5x5", 0], ["sep_conv_3x3", 1], ["sep_conv_3x3", 0], ["skip_connect", 1], ["sep_conv_5x5", 0], ["sep_conv_5x5", 2], ["sep_conv_3x3", 0], ["sep_conv_5x5", 1]], reduce_concat=range(2, 6))
init_pt_s5_C10_3_100_N10 = Genotype(normal=[["skip_connect", 0], ["sep_conv_5x5", 1], ["sep_conv_3x3", 0], ["sep_conv_5x5", 1], ["sep_conv_3x3", 0], ["sep_conv_5x5", 3], ["sep_conv_3x3", 0], ["sep_conv_3x3", 1]], normal_concat=range(2, 6), reduce=[["sep_conv_5x5", 0], ["sep_conv_5x5", 1], ["sep_conv_3x3", 0], ["sep_conv_5x5", 1], ["sep_conv_5x5", 0], ["sep_conv_3x3", 1], ["sep_conv_3x3", 0], ["sep_conv_5x5", 1]], reduce_concat=range(2, 6))
####global op greedy
global_pt_s5_C10_0_100_N10 = Genotype(normal=[["sep_conv_5x5", 0], ["sep_conv_5x5", 1], ["sep_conv_3x3", 1], ["sep_conv_5x5", 2], ["sep_conv_3x3", 0], ["sep_conv_5x5", 1], ["sep_conv_3x3", 0], ["sep_conv_3x3", 1]], normal_concat=range(2, 6), reduce=[["sep_conv_5x5", 0], ["sep_conv_3x3", 1], ["sep_conv_5x5", 0], ["sep_conv_3x3", 1], ["sep_conv_3x3", 0], ["sep_conv_5x5", 1], ["sep_conv_5x5", 1], ["sep_conv_5x5", 4]], reduce_concat=range(2, 6))
global_pt_s5_C10_1_100_N10 = Genotype(normal=[["skip_connect", 0], ["sep_conv_5x5", 1], ["sep_conv_3x3", 1], ["sep_conv_5x5", 2], ["sep_conv_5x5", 1], ["sep_conv_5x5", 2], ["sep_conv_3x3", 0], ["sep_conv_3x3", 4]], normal_concat=range(2, 6), reduce=[["sep_conv_3x3", 0], ["sep_conv_5x5", 1], ["sep_conv_5x5", 1], ["sep_conv_3x3", 2], ["sep_conv_3x3", 1], ["sep_conv_5x5", 3], ["sep_conv_5x5", 0], ["sep_conv_3x3", 1]], reduce_concat=range(2, 6))
global_pt_s5_C10_2_100_N10 = Genotype(normal=[["skip_connect", 0], ["sep_conv_5x5", 1], ["sep_conv_5x5", 0], ["sep_conv_5x5", 2], ["sep_conv_5x5", 0], ["sep_conv_5x5", 3], ["sep_conv_3x3", 0], ["sep_conv_3x3", 3]], normal_concat=range(2, 6), reduce=[["sep_conv_3x3", 0], ["sep_conv_3x3", 1], ["skip_connect", 0], ["skip_connect", 1], ["sep_conv_5x5", 0], ["dil_conv_5x5", 1], ["sep_conv_3x3", 0], ["sep_conv_5x5", 1]], reduce_concat=range(2, 6))
global_pt_s5_C10_3_100_N10 = Genotype(normal=[["skip_connect", 0], ["sep_conv_5x5", 1], ["sep_conv_3x3", 0], ["sep_conv_5x5", 1], ["sep_conv_3x3", 0], ["sep_conv_3x3", 2], ["sep_conv_3x3", 0], ["sep_conv_3x3", 3]], normal_concat=range(2, 6), reduce=[["dil_conv_3x3", 0], ["sep_conv_3x3", 1], ["sep_conv_3x3", 0], ["sep_conv_5x5", 1], ["sep_conv_5x5", 0], ["sep_conv_3x3", 1], ["sep_conv_3x3", 0], ["sep_conv_3x3", 1]], reduce_concat=range(2, 6))
####2500_sample
sample_2500_0 = Genotype(normal=[["dil_conv_5x5", 0], ["dil_conv_5x5", 1], ["dil_conv_5x5", 0], ["skip_connect", 1], ["sep_conv_3x3", 2], ["sep_conv_3x3", 3], ["sep_conv_5x5", 2], ["dil_conv_3x3", 4]], normal_concat=range(2, 6), reduce=[["dil_conv_3x3", 0], ["dil_conv_3x3", 1], ["dil_conv_5x5", 1], ["sep_conv_5x5", 2], ["skip_connect", 0], ["sep_conv_3x3", 3], ["sep_conv_5x5", 0], ["dil_conv_5x5", 3]], reduce_concat=range(2, 6))
sample_2500_1 = Genotype(normal=[["skip_connect", 0], ["skip_connect", 1], ["sep_conv_3x3", 0], ["sep_conv_3x3", 2], ["dil_conv_5x5", 1], ["sep_conv_5x5", 2], ["sep_conv_3x3", 1], ["dil_conv_3x3", 2]], normal_concat=range(2, 6), reduce=[["sep_conv_3x3", 0], ["dil_conv_3x3", 1], ["sep_conv_5x5", 0], ["avg_pool_3x3", 2], ["dil_conv_5x5", 1], ["dil_conv_5x5", 2], ["dil_conv_5x5", 0], ["sep_conv_3x3", 4]], reduce_concat=range(2, 6))
sample_2500_2 = Genotype(normal=[["dil_conv_5x5", 0], ["dil_conv_3x3", 1], ["sep_conv_5x5", 1], ["sep_conv_5x5", 2], ["sep_conv_5x5", 0], ["sep_conv_3x3", 2], ["sep_conv_3x3", 2], ["sep_conv_5x5", 3]], normal_concat=range(2, 6), reduce=[["dil_conv_3x3", 0], ["avg_pool_3x3", 1], ["sep_conv_5x5", 0], ["skip_connect", 1], ["sep_conv_3x3", 0], ["dil_conv_5x5", 2], ["dil_conv_5x5", 0], ["dil_conv_5x5", 2]], reduce_concat=range(2, 6))
sample_2500_3 = Genotype(normal=[["sep_conv_3x3", 0], ["dil_conv_3x3", 1], ["sep_conv_5x5", 1], ["dil_conv_3x3", 2], ["sep_conv_5x5", 0], ["sep_conv_3x3", 2], ["dil_conv_3x3", 0], ["dil_conv_3x3", 1]], normal_concat=range(2, 6), reduce=[["skip_connect", 0], ["dil_conv_3x3", 1], ["dil_conv_5x5", 0], ["max_pool_3x3", 1], ["avg_pool_3x3", 0], ["max_pool_3x3", 1], ["avg_pool_3x3", 1], ["skip_connect", 3]], reduce_concat=range(2, 6))
####20000_sample
sample_20000_0 = Genotype(normal=[["skip_connect", 0], ["dil_conv_3x3", 1], ["sep_conv_5x5", 1], ["skip_connect", 2], ["dil_conv_5x5", 0], ["sep_conv_3x3", 3], ["sep_conv_5x5", 2], ["dil_conv_5x5", 4]], normal_concat=range(2, 6), reduce=[["skip_connect", 0], ["sep_conv_5x5", 1], ["dil_conv_5x5", 0], ["skip_connect", 1], ["sep_conv_3x3", 0], ["sep_conv_3x3", 2], ["dil_conv_5x5", 0], ["sep_conv_5x5", 4]], reduce_concat=range(2, 6))
sample_20000_1 = Genotype(normal=[["skip_connect", 0], ["skip_connect", 1], ["sep_conv_3x3", 0], ["sep_conv_3x3", 2], ["dil_conv_5x5", 1], ["sep_conv_5x5", 2], ["sep_conv_3x3", 1], ["dil_conv_3x3", 2]], normal_concat=range(2, 6), reduce=[["sep_conv_3x3", 0], ["dil_conv_3x3", 1], ["sep_conv_5x5", 0], ["avg_pool_3x3", 2], ["dil_conv_5x5", 1], ["dil_conv_5x5", 2], ["dil_conv_5x5", 0], ["sep_conv_3x3", 4]], reduce_concat=range(2, 6))
sample_20000_2 = Genotype(normal=[["skip_connect", 0], ["sep_conv_5x5", 1], ["sep_conv_5x5", 0], ["sep_conv_5x5", 2], ["dil_conv_5x5", 0], ["sep_conv_3x3", 3], ["sep_conv_5x5", 1], ["sep_conv_3x3", 3]], normal_concat=range(2, 6), reduce=[["dil_conv_5x5", 0], ["dil_conv_3x3", 1], ["skip_connect", 0], ["max_pool_3x3", 1], ["sep_conv_5x5", 0], ["dil_conv_5x5", 1], ["sep_conv_3x3", 1], ["dil_conv_3x3", 3]], reduce_concat=range(2, 6))
sample_20000_3 = Genotype(normal=[["sep_conv_5x5", 0], ["sep_conv_3x3", 1], ["skip_connect", 0], ["dil_conv_3x3", 2], ["dil_conv_3x3", 1], ["sep_conv_5x5", 3], ["sep_conv_5x5", 0], ["sep_conv_5x5", 1]], normal_concat=range(2, 6), reduce=[["sep_conv_5x5", 0], ["dil_conv_5x5", 1], ["dil_conv_5x5", 0], ["dil_conv_5x5", 2], ["sep_conv_3x3", 0], ["sep_conv_5x5", 2], ["dil_conv_3x3", 1], ["sep_conv_3x3", 2]], reduce_concat=range(2, 6))
####50000_sample
sample_50000_0 = Genotype(normal=[["skip_connect", 0], ["sep_conv_5x5", 1], ["sep_conv_3x3", 1], ["sep_conv_3x3", 2], ["sep_conv_3x3", 0], ["skip_connect", 2], ["sep_conv_3x3", 0], ["sep_conv_3x3", 2]], normal_concat=range(2, 6), reduce=[["max_pool_3x3", 0], ["sep_conv_5x5", 1], ["avg_pool_3x3", 0], ["dil_conv_3x3", 1], ["sep_conv_5x5", 0], ["dil_conv_3x3", 1], ["dil_conv_5x5", 0], ["max_pool_3x3", 1]], reduce_concat=range(2, 6))
sample_50000_1 = Genotype(normal=[["dil_conv_3x3", 0], ["sep_conv_3x3", 1], ["sep_conv_5x5", 0], ["sep_conv_3x3", 1], ["sep_conv_5x5", 1], ["sep_conv_5x5", 3], ["sep_conv_5x5", 1], ["dil_conv_3x3", 2]], normal_concat=range(2, 6), reduce=[["dil_conv_3x3", 0], ["dil_conv_5x5", 1], ["sep_conv_3x3", 0], ["skip_connect", 1], ["sep_conv_5x5", 0], ["max_pool_3x3", 1], ["dil_conv_3x3", 1], ["dil_conv_5x5", 2]], reduce_concat=range(2, 6))
sample_50000_2 = Genotype(normal=[["skip_connect", 0], ["sep_conv_5x5", 1], ["sep_conv_5x5", 0], ["sep_conv_5x5", 2], ["dil_conv_5x5", 0], ["sep_conv_3x3", 3], ["sep_conv_5x5", 1], ["sep_conv_3x3", 3]], normal_concat=range(2, 6), reduce=[["dil_conv_5x5", 0], ["dil_conv_3x3", 1], ["skip_connect", 0], ["max_pool_3x3", 1], ["sep_conv_5x5", 0], ["dil_conv_5x5", 1], ["sep_conv_3x3", 1], ["dil_conv_3x3", 3]], reduce_concat=range(2, 6))
sample_50000_3 = Genotype(normal=[["sep_conv_5x5", 0], ["sep_conv_3x3", 1], ["skip_connect", 0], ["dil_conv_3x3", 2], ["dil_conv_3x3", 1], ["sep_conv_5x5", 3], ["sep_conv_5x5", 0], ["sep_conv_5x5", 1]], normal_concat=range(2, 6), reduce=[["sep_conv_5x5", 0], ["dil_conv_5x5", 1], ["dil_conv_5x5", 0], ["dil_conv_5x5", 2], ["sep_conv_3x3", 0], ["sep_conv_5x5", 2], ["dil_conv_3x3", 1], ["sep_conv_3x3", 2]], reduce_concat=range(2, 6))
#### random
random_max_0 = Genotype(normal=[["sep_conv_5x5", 0], ["sep_conv_5x5", 1], ["sep_conv_5x5", 0], ["sep_conv_5x5", 1], ["sep_conv_5x5", 0], ["sep_conv_5x5", 3], ["sep_conv_5x5", 1], ["sep_conv_5x5", 4]], normal_concat=range(2, 6), reduce=[["sep_conv_5x5", 0], ["sep_conv_5x5", 1], ["sep_conv_5x5", 0], ["sep_conv_5x5", 1], ["sep_conv_5x5", 0], ["sep_conv_5x5", 1], ["sep_conv_5x5", 0], ["sep_conv_5x5", 1]], reduce_concat=range(2, 6))
random_max_1 = Genotype(normal=[["sep_conv_5x5", 0], ["sep_conv_5x5", 1], ["sep_conv_5x5", 0], ["sep_conv_5x5", 2], ["sep_conv_5x5", 0], ["sep_conv_5x5", 2], ["sep_conv_5x5", 1], ["sep_conv_5x5", 4]], normal_concat=range(2, 6), reduce=[["sep_conv_5x5", 0], ["sep_conv_5x5", 1], ["sep_conv_5x5", 0], ["sep_conv_5x5", 1], ["sep_conv_5x5", 0], ["sep_conv_5x5", 1], ["sep_conv_5x5", 0], ["sep_conv_5x5", 1]], reduce_concat=range(2, 6))
random_max_2 = Genotype(normal=[["sep_conv_5x5", 0], ["sep_conv_5x5", 1], ["sep_conv_5x5", 0], ["sep_conv_5x5", 2], ["sep_conv_5x5", 0], ["sep_conv_5x5", 3], ["sep_conv_5x5", 1], ["sep_conv_5x5", 4]], normal_concat=range(2, 6), reduce=[["sep_conv_5x5", 0], ["sep_conv_5x5", 1], ["sep_conv_5x5", 0], ["sep_conv_5x5", 1], ["sep_conv_5x5", 0], ["sep_conv_5x5", 1], ["sep_conv_5x5", 0], ["sep_conv_5x5", 1]], reduce_concat=range(2, 6))
random_max_3 = Genotype(normal=[["sep_conv_5x5", 0], ["sep_conv_5x5", 1], ["sep_conv_5x5", 0], ["sep_conv_5x5", 1], ["sep_conv_5x5", 0], ["sep_conv_5x5", 2], ["sep_conv_5x5", 1], ["sep_conv_5x5", 4]], normal_concat=range(2, 6), reduce=[["sep_conv_5x5", 0], ["sep_conv_5x5", 1], ["sep_conv_5x5", 0], ["sep_conv_5x5", 1], ["sep_conv_5x5", 0], ["sep_conv_5x5", 1], ["sep_conv_5x5", 0], ["sep_conv_5x5", 1]], reduce_concat=range(2, 6))
#### ImageNet-1k
init_pt_s5_in_0_100_N10=Genotype(normal=[["sep_conv_5x5", 0], ["sep_conv_5x5", 1], ["sep_conv_3x3", 1], ["sep_conv_3x3", 2], ["sep_conv_5x5", 2], ["sep_conv_5x5", 3], ["sep_conv_3x3", 0], ["sep_conv_5x5", 3]], normal_concat=range(2, 6), reduce=[["sep_conv_3x3", 0], ["sep_conv_5x5", 1], ["sep_conv_5x5", 0], ["dil_conv_5x5", 1], ["sep_conv_3x3", 0], ["sep_conv_5x5", 1], ["skip_connect", 0], ["sep_conv_5x5", 1]], reduce_concat=range(2, 6))
init_pt_s5_in_1_100_N10=Genotype(normal=[["skip_connect", 0], ["sep_conv_3x3", 1], ["sep_conv_5x5", 0], ["sep_conv_5x5", 1], ["sep_conv_5x5", 0], ["sep_conv_3x3", 3], ["sep_conv_3x3", 3], ["sep_conv_5x5", 4]], normal_concat=range(2, 6), reduce=[["avg_pool_3x3", 0], ["sep_conv_3x3", 1], ["sep_conv_5x5", 0], ["sep_conv_5x5", 1], ["sep_conv_5x5", 0], ["sep_conv_3x3", 1], ["sep_conv_5x5", 0], ["sep_conv_5x5", 1]], reduce_concat=range(2, 6))
init_pt_s5_in_2_100_N10=Genotype(normal=[["sep_conv_3x3", 0], ["sep_conv_5x5", 1], ["sep_conv_3x3", 0], ["sep_conv_5x5", 1], ["sep_conv_3x3", 0], ["sep_conv_3x3", 3], ["sep_conv_3x3", 3], ["sep_conv_5x5", 4]], normal_concat=range(2, 6), reduce=[["sep_conv_5x5", 0], ["sep_conv_5x5", 1], ["sep_conv_5x5", 0], ["sep_conv_5x5", 1], ["skip_connect", 0], ["dil_conv_5x5", 1], ["sep_conv_3x3", 0], ["sep_conv_5x5", 1]], reduce_concat=range(2, 6))
init_pt_s5_in_3_100_N10=Genotype(normal=[["sep_conv_5x5", 0], ["sep_conv_3x3", 1], ["sep_conv_5x5", 0], ["sep_conv_3x3", 2], ["sep_conv_5x5", 2], ["sep_conv_3x3", 3], ["sep_conv_5x5", 0], ["sep_conv_5x5", 3]], normal_concat=range(2, 6), reduce=[["dil_conv_3x3", 0], ["skip_connect", 1], ["sep_conv_5x5", 0], ["sep_conv_5x5", 1], ["sep_conv_5x5", 0], ["sep_conv_5x5", 1], ["sep_conv_5x5", 0], ["sep_conv_3x3", 1]], reduce_concat=range(2, 6))
####N1
init_pt_s5_C10_0_N1 = Genotype(normal=[["sep_conv_5x5", 0], ["sep_conv_5x5", 1], ["sep_conv_3x3", 1], ["sep_conv_5x5", 2], ["sep_conv_3x3", 0], ["sep_conv_5x5", 1], ["sep_conv_3x3", 0], ["sep_conv_3x3", 1]], normal_concat=range(2, 6), reduce=[["skip_connect", 0], ["sep_conv_5x5", 1], ["sep_conv_5x5", 0], ["sep_conv_3x3", 1], ["sep_conv_5x5", 0], ["sep_conv_5x5", 1], ["sep_conv_5x5", 0], ["sep_conv_3x3", 1]], reduce_concat=range(2, 6))
init_pt_s5_C10_1_N1 = Genotype(normal=[["sep_conv_5x5", 0], ["sep_conv_5x5", 1], ["sep_conv_5x5", 1], ["sep_conv_5x5", 2], ["sep_conv_3x3", 1], ["sep_conv_3x3", 2], ["sep_conv_3x3", 1], ["sep_conv_5x5", 2]], normal_concat=range(2, 6), reduce=[["skip_connect", 0], ["sep_conv_5x5", 1], ["skip_connect", 0], ["sep_conv_5x5", 1], ["sep_conv_5x5", 0], ["sep_conv_3x3", 1], ["sep_conv_5x5", 0], ["sep_conv_3x3", 1]], reduce_concat=range(2, 6))
init_pt_s5_C10_2_N1 = Genotype(normal=[["skip_connect", 0], ["sep_conv_5x5", 1], ["sep_conv_5x5", 0], ["sep_conv_5x5", 2], ["sep_conv_5x5", 0], ["sep_conv_5x5", 2], ["sep_conv_3x3", 0], ["sep_conv_5x5", 3]], normal_concat=range(2, 6), reduce=[["sep_conv_5x5", 0], ["sep_conv_3x3", 1], ["sep_conv_3x3", 0], ["skip_connect", 1], ["sep_conv_5x5", 0], ["sep_conv_5x5", 1], ["sep_conv_3x3", 0], ["sep_conv_5x5", 1]], reduce_concat=range(2, 6))
init_pt_s5_C10_3_N1 = Genotype(normal=[["skip_connect", 0], ["sep_conv_5x5", 1], ["sep_conv_3x3", 0], ["sep_conv_5x5", 1], ["sep_conv_3x3", 0], ["sep_conv_5x5", 3], ["sep_conv_3x3", 0], ["sep_conv_3x3", 1]], normal_concat=range(2, 6), reduce=[["sep_conv_5x5", 0], ["sep_conv_5x5", 1], ["sep_conv_3x3", 0], ["sep_conv_5x5", 1], ["sep_conv_5x5", 0], ["sep_conv_3x3", 1], ["sep_conv_3x3", 0], ["sep_conv_5x5", 1]], reduce_concat=range(2, 6))
####N5
#####V1
init_pt_s5_C10_0_1_N5 = Genotype(normal=[["sep_conv_5x5", 0], ["sep_conv_5x5", 1], ["sep_conv_3x3", 1], ["sep_conv_5x5", 2], ["sep_conv_3x3", 0], ["sep_conv_5x5", 1], ["sep_conv_3x3", 0], ["sep_conv_3x3", 1]], normal_concat=range(2, 6), reduce=[["skip_connect", 0], ["sep_conv_5x5", 1], ["sep_conv_5x5", 0], ["sep_conv_3x3", 1], ["sep_conv_5x5", 0], ["sep_conv_5x5", 1], ["sep_conv_5x5", 0], ["sep_conv_3x3", 1]], reduce_concat=range(2, 6))
init_pt_s5_C10_1_1_N5 = Genotype(normal=[["skip_connect", 0], ["sep_conv_5x5", 1], ["sep_conv_5x5", 1], ["sep_conv_5x5", 2], ["sep_conv_5x5", 1], ["sep_conv_5x5", 2], ["sep_conv_3x3", 1], ["sep_conv_5x5", 4]], normal_concat=range(2, 6), reduce=[["sep_conv_3x3", 0], ["sep_conv_5x5", 1], ["sep_conv_3x3", 0], ["sep_conv_5x5", 1], ["sep_conv_3x3", 0], ["sep_conv_3x3", 3], ["sep_conv_5x5", 0], ["sep_conv_3x3", 1]], reduce_concat=range(2, 6))
init_pt_s5_C10_2_1_N5 = Genotype(normal=[["sep_conv_5x5", 0], ["sep_conv_3x3", 1], ["sep_conv_5x5", 1], ["sep_conv_3x3", 2], ["sep_conv_5x5", 2], ["sep_conv_3x3", 3], ["sep_conv_3x3", 0], ["sep_conv_3x3", 3]], normal_concat=range(2, 6), reduce=[["sep_conv_5x5", 0], ["skip_connect", 1], ["sep_conv_3x3", 0], ["dil_conv_3x3", 1], ["sep_conv_5x5", 0], ["sep_conv_5x5", 1], ["sep_conv_3x3", 0], ["sep_conv_3x3", 1]], reduce_concat=range(2, 6))
init_pt_s5_C10_3_1_N5 = Genotype(normal=[["skip_connect", 0], ["sep_conv_5x5", 1], ["sep_conv_3x3", 0], ["sep_conv_5x5", 1], ["sep_conv_3x3", 0], ["sep_conv_5x5", 3], ["sep_conv_3x3", 0], ["sep_conv_3x3", 1]], normal_concat=range(2, 6), reduce=[["sep_conv_5x5", 0], ["sep_conv_5x5", 1], ["sep_conv_3x3", 0], ["sep_conv_5x5", 1], ["sep_conv_5x5", 0], ["sep_conv_3x3", 1], ["sep_conv_3x3", 0], ["sep_conv_5x5", 1]], reduce_concat=range(2, 6))
#####V10
init_pt_s5_C10_0_10_N5 = Genotype(normal=[["sep_conv_5x5", 0], ["sep_conv_5x5", 1], ["sep_conv_3x3", 1], ["sep_conv_5x5", 2], ["sep_conv_3x3", 0], ["sep_conv_3x3", 2], ["sep_conv_5x5", 1], ["sep_conv_5x5", 2]], normal_concat=range(2, 6), reduce=[["sep_conv_3x3", 0], ["sep_conv_3x3", 1], ["sep_conv_5x5", 0], ["sep_conv_3x3", 1], ["sep_conv_3x3", 0], ["sep_conv_5x5", 1], ["sep_conv_3x3", 0], ["sep_conv_5x5", 1]], reduce_concat=range(2, 6))
init_pt_s5_C10_1_10_N5 = Genotype(normal=[["sep_conv_5x5", 0], ["sep_conv_5x5", 1], ["sep_conv_3x3", 1], ["sep_conv_5x5", 2], ["sep_conv_3x3", 1], ["sep_conv_3x3", 2], ["sep_conv_3x3", 1], ["sep_conv_5x5", 2]], normal_concat=range(2, 6), reduce=[["sep_conv_3x3", 0], ["sep_conv_3x3", 1], ["sep_conv_3x3", 0], ["sep_conv_3x3", 1], ["sep_conv_3x3", 1], ["sep_conv_5x5", 3], ["sep_conv_5x5", 1], ["sep_conv_5x5", 4]], reduce_concat=range(2, 6))
init_pt_s5_C10_2_10_N5 = Genotype(normal=[["sep_conv_5x5", 0], ["sep_conv_5x5", 1], ["sep_conv_5x5", 1], ["sep_conv_3x3", 2], ["sep_conv_5x5", 2], ["sep_conv_3x3", 3], ["sep_conv_3x3", 0], ["sep_conv_3x3", 1]], normal_concat=range(2, 6), reduce=[["sep_conv_3x3", 0], ["sep_conv_5x5", 1], ["sep_conv_3x3", 1], ["sep_conv_5x5", 2], ["sep_conv_5x5", 0], ["sep_conv_3x3", 1], ["sep_conv_3x3", 0], ["skip_connect", 1]], reduce_concat=range(2, 6))
init_pt_s5_C10_3_10_N5 = Genotype(normal=[["skip_connect", 0], ["sep_conv_5x5", 1], ["sep_conv_3x3", 0], ["sep_conv_5x5", 1], ["sep_conv_3x3", 0], ["sep_conv_5x5", 3], ["sep_conv_3x3", 0], ["sep_conv_3x3", 1]], normal_concat=range(2, 6), reduce=[["sep_conv_5x5", 0], ["sep_conv_5x5", 1], ["sep_conv_3x3", 0], ["sep_conv_5x5", 1], ["sep_conv_5x5", 0], ["sep_conv_3x3", 1], ["sep_conv_3x3", 0], ["sep_conv_5x5", 1]], reduce_concat=range(2, 6))
#####V100
init_pt_s5_C10_0_100_N5 = Genotype(normal=[["sep_conv_5x5", 0], ["sep_conv_5x5", 1], ["sep_conv_3x3", 1], ["sep_conv_5x5", 2], ["sep_conv_5x5", 1], ["sep_conv_3x3", 2], ["sep_conv_5x5", 2], ["sep_conv_3x3", 4]], normal_concat=range(2, 6), reduce=[["sep_conv_5x5", 0], ["sep_conv_3x3", 1], ["sep_conv_5x5", 0], ["sep_conv_3x3", 1], ["sep_conv_5x5", 1], ["sep_conv_3x3", 3], ["sep_conv_3x3", 0], ["sep_conv_5x5", 1]], reduce_concat=range(2, 6))
init_pt_s5_C10_1_100_N5 = Genotype(normal=[["skip_connect", 0], ["sep_conv_5x5", 1], ["sep_conv_5x5", 1], ["sep_conv_5x5", 2], ["sep_conv_5x5", 1], ["sep_conv_5x5", 2], ["sep_conv_3x3", 1], ["sep_conv_5x5", 4]], normal_concat=range(2, 6), reduce=[["sep_conv_3x3", 0], ["sep_conv_5x5", 1], ["sep_conv_3x3", 0], ["sep_conv_5x5", 1], ["sep_conv_3x3", 0], ["sep_conv_3x3", 3], ["sep_conv_5x5", 0], ["sep_conv_3x3", 1]], reduce_concat=range(2, 6))
init_pt_s5_C10_2_100_N5 = Genotype(normal=[["sep_conv_5x5", 0], ["sep_conv_5x5", 1], ["sep_conv_5x5", 0], ["sep_conv_5x5", 1], ["sep_conv_5x5", 2], ["sep_conv_3x3", 3], ["sep_conv_3x3", 0], ["sep_conv_3x3", 1]], normal_concat=range(2, 6), reduce=[["sep_conv_3x3", 0], ["sep_conv_3x3", 1], ["sep_conv_5x5", 1], ["sep_conv_3x3", 2], ["sep_conv_3x3", 1], ["sep_conv_5x5", 3], ["sep_conv_3x3", 0], ["sep_conv_5x5", 1]], reduce_concat=range(2, 6))
init_pt_s5_C10_3_100_N5 = Genotype(normal=[["skip_connect", 0], ["sep_conv_5x5", 1], ["sep_conv_3x3", 0], ["sep_conv_5x5", 2], ["sep_conv_3x3", 0], ["sep_conv_3x3", 2], ["sep_conv_3x3", 0], ["sep_conv_3x3", 2]], normal_concat=range(2, 6), reduce=[["sep_conv_5x5", 0], ["sep_conv_5x5", 1], ["sep_conv_3x3", 0], ["sep_conv_5x5", 1], ["sep_conv_5x5", 0], ["sep_conv_3x3", 1], ["sep_conv_3x3", 0], ["sep_conv_5x5", 1]], reduce_concat=range(2, 6))
####N10
#####V1
init_pt_s5_C10_0_1_N10 = Genotype(normal=[["sep_conv_5x5", 0], ["sep_conv_5x5", 1], ["sep_conv_3x3", 1], ["sep_conv_5x5", 2], ["sep_conv_3x3", 0], ["sep_conv_5x5", 1], ["sep_conv_5x5", 2], ["sep_conv_3x3", 4]], normal_concat=range(2, 6), reduce=[["sep_conv_5x5", 0], ["sep_conv_3x3", 1], ["sep_conv_5x5", 0], ["sep_conv_3x3", 1], ["sep_conv_5x5", 0], ["sep_conv_5x5", 1], ["sep_conv_3x3", 0], ["sep_conv_5x5", 1]], reduce_concat=range(2, 6))
init_pt_s5_C10_1_1_N10 = Genotype(normal=[["sep_conv_5x5", 0], ["sep_conv_5x5", 1], ["sep_conv_3x3", 1], ["sep_conv_5x5", 2], ["sep_conv_3x3", 1], ["sep_conv_5x5", 2], ["sep_conv_3x3", 0], ["sep_conv_3x3", 1]], normal_concat=range(2, 6), reduce=[["sep_conv_5x5", 0], ["sep_conv_3x3", 1], ["sep_conv_5x5", 0], ["sep_conv_5x5", 1], ["sep_conv_5x5", 0], ["sep_conv_3x3", 1], ["sep_conv_3x3", 0], ["sep_conv_5x5", 1]], reduce_concat=range(2, 6))
init_pt_s5_C10_2_1_N10 = Genotype(normal=[["sep_conv_5x5", 0], ["sep_conv_3x3", 1], ["sep_conv_5x5", 1], ["sep_conv_3x3", 2], ["sep_conv_5x5", 2], ["sep_conv_3x3", 3], ["sep_conv_3x3", 0], ["sep_conv_3x3", 3]], normal_concat=range(2, 6), reduce=[["sep_conv_5x5", 0], ["skip_connect", 1], ["sep_conv_3x3", 0], ["dil_conv_3x3", 1], ["sep_conv_5x5", 0], ["sep_conv_5x5", 1], ["sep_conv_3x3", 0], ["sep_conv_3x3", 1]], reduce_concat=range(2, 6))
init_pt_s5_C10_3_1_N10 = Genotype(normal=[["sep_conv_5x5", 0], ["sep_conv_3x3", 1], ["sep_conv_3x3", 0], ["sep_conv_5x5", 1], ["sep_conv_3x3", 0], ["sep_conv_5x5", 1], ["sep_conv_3x3", 3], ["sep_conv_5x5", 4]], normal_concat=range(2, 6), reduce=[["sep_conv_5x5", 0], ["sep_conv_5x5", 1], ["sep_conv_3x3", 0], ["sep_conv_5x5", 1], ["sep_conv_5x5", 0], ["sep_conv_3x3", 1], ["sep_conv_5x5", 0], ["sep_conv_5x5", 1]], reduce_concat=range(2, 6))
#####V10
init_pt_s5_C10_0_10_N10 = Genotype(normal=[["sep_conv_5x5", 0], ["sep_conv_5x5", 1], ["sep_conv_3x3", 1], ["sep_conv_5x5", 2], ["sep_conv_3x3", 0], ["sep_conv_5x5", 1], ["sep_conv_5x5", 2], ["sep_conv_3x3", 4]], normal_concat=range(2, 6), reduce=[["sep_conv_5x5", 0], ["sep_conv_3x3", 1], ["sep_conv_5x5", 0], ["sep_conv_3x3", 1], ["sep_conv_5x5", 0], ["sep_conv_5x5", 1], ["sep_conv_3x3", 0], ["sep_conv_5x5", 1]], reduce_concat=range(2, 6))
init_pt_s5_C10_1_10_N10 = Genotype(normal=[["sep_conv_5x5", 0], ["sep_conv_5x5", 1], ["sep_conv_3x3", 1], ["sep_conv_5x5", 2], ["sep_conv_3x3", 1], ["sep_conv_5x5", 2], ["sep_conv_3x3", 0], ["sep_conv_3x3", 1]], normal_concat=range(2, 6), reduce=[["sep_conv_5x5", 0], ["sep_conv_3x3", 1], ["sep_conv_5x5", 0], ["sep_conv_5x5", 1], ["sep_conv_5x5", 0], ["sep_conv_3x3", 1], ["sep_conv_3x3", 0], ["sep_conv_5x5", 1]], reduce_concat=range(2, 6))
init_pt_s5_C10_2_10_N10 = Genotype(normal=[["sep_conv_5x5", 0], ["sep_conv_5x5", 1], ["sep_conv_3x3", 0], ["sep_conv_3x3", 2], ["sep_conv_5x5", 0], ["sep_conv_5x5", 2], ["sep_conv_3x3", 0], ["sep_conv_5x5", 1]], normal_concat=range(2, 6), reduce=[["sep_conv_3x3", 0], ["sep_conv_5x5", 1], ["skip_connect", 0], ["skip_connect", 1], ["sep_conv_5x5", 0], ["sep_conv_3x3", 1], ["sep_conv_3x3", 0], ["sep_conv_3x3", 1]], reduce_concat=range(2, 6))
init_pt_s5_C10_3_10_N10 = Genotype(normal=[["sep_conv_5x5", 0], ["sep_conv_5x5", 1], ["sep_conv_3x3", 0], ["sep_conv_5x5", 1], ["sep_conv_5x5", 1], ["sep_conv_3x3", 3], ["sep_conv_5x5", 0], ["sep_conv_3x3", 1]], normal_concat=range(2, 6), reduce=[["dil_conv_3x3", 0], ["sep_conv_3x3", 1], ["sep_conv_3x3", 0], ["sep_conv_5x5", 1], ["sep_conv_5x5", 0], ["sep_conv_3x3", 1], ["sep_conv_3x3", 0], ["sep_conv_5x5", 1]], reduce_concat=range(2, 6))
#fisher
cf10_fisher = Genotype(normal=[["avg_pool_3x3", 0], ["avg_pool_3x3", 1], ["avg_pool_3x3", 0], ["dil_conv_3x3", 1],["avg_pool_3x3", 0], ["skip_connect", 2],["sep_conv_5x5", 0], ["dil_conv_3x3", 3]], normal_concat=range(2, 6), reduce= [["sep_conv_3x3", 0], ["sep_conv_5x5", 1], ["max_pool_3x3", 0], ["max_pool_3x3", 2], ["sep_conv_3x3", 0], ["dil_conv_5x5", 3], ["sep_conv_5x5", 0], ["sep_conv_3x3", 2]], reduce_concat=range(2, 6))
#grasp
cf10_grasp = Genotype(normal=[["avg_pool_3x3", 0], ["avg_pool_3x3", 1], ["skip_connect", 0], ["sep_conv_5x5", 1], ["dil_conv_3x3", 1], ["sep_conv_3x3", 2], ["sep_conv_5x5", 2], ["sep_conv_3x3", 3]], normal_concat=range(2, 6), reduce= [["sep_conv_3x3", 0], ["skip_connect", 1], ["avg_pool_3x3", 0], ["skip_connect", 1], ["sep_conv_5x5", 0], ["skip_connect", 1], ["max_pool_3x3", 1], ["sep_conv_3x3", 3]], reduce_concat=range(2, 6))
#jacob_cov
cf10_jacob_cov = Genotype(normal=[["max_pool_3x3", 0], ["dil_conv_3x3", 1], ["dil_conv_3x3", 0], ["sep_conv_3x3", 2], ["dil_conv_3x3", 0], ["sep_conv_3x3", 3], ["sep_conv_5x5", 0], ["dil_conv_3x3", 3]], normal_concat=range(2, 6), reduce= [["sep_conv_3x3", 0], ["max_pool_3x3", 1], ["max_pool_3x3", 0], ["avg_pool_3x3", 1], ["sep_conv_3x3", 1], ["sep_conv_5x5", 3], ["dil_conv_3x3", 1], ["sep_conv_3x3", 4]], reduce_concat=range(2, 6))
#meco
cf10_meco = Genotype(normal=[["dil_conv_5x5", 0], ["sep_conv_5x5", 1], ["sep_conv_5x5", 0], ["skip_connect", 1], ["dil_conv_5x5", 0], ["sep_conv_3x3", 1], ["sep_conv_3x3", 2], ["sep_conv_5x5", 3]], normal_concat=range(2, 6), reduce= [["dil_conv_5x5", 0], ["sep_conv_5x5", 1], ["sep_conv_5x5", 0], ["sep_conv_5x5", 1], ["dil_conv_5x5", 0], ["dil_conv_5x5", 1], ["dil_conv_5x5", 1], ["sep_conv_3x3", 4]], reduce_concat=range(2, 6))
#synflow
cf10_synflow = Genotype(normal= [["sep_conv_5x5", 0], ["sep_conv_5x5", 1], ["sep_conv_5x5", 1], ["sep_conv_5x5", 2], ["sep_conv_5x5", 0], ["sep_conv_5x5", 1], ["sep_conv_5x5", 0], ["sep_conv_5x5", 1]], normal_concat=range(2, 6), reduce= [["sep_conv_5x5", 0], ["sep_conv_5x5", 1], ["sep_conv_5x5", 0], ["sep_conv_5x5", 1], ["sep_conv_5x5", 0], ["sep_conv_5x5", 1], ["sep_conv_5x5", 0], ["sep_conv_5x5", 1]], reduce_concat=range(2, 6))
#zico
cf10_zico= Genotype(normal= [["sep_conv_3x3", 0], ["sep_conv_3x3", 1], ["sep_conv_3x3", 1], ["sep_conv_5x5", 2], ["sep_conv_3x3", 0], ["sep_conv_5x5", 1], ["sep_conv_5x5", 0], ["sep_conv_3x3", 1]], normal_concat=range(2, 6), reduce= [["sep_conv_5x5", 0], ["sep_conv_5x5", 1], ["sep_conv_3x3", 0], ["sep_conv_3x3", 1], ["sep_conv_5x5", 0], ["sep_conv_5x5", 1], ["skip_connect", 0], ["sep_conv_3x3", 2]], reduce_concat=range(2, 6))
#snip
cf10_snip = Genotype(normal= [["sep_conv_3x3", 0], ["avg_pool_3x3", 1], ["dil_conv_5x5", 0], ["sep_conv_5x5", 1], ["dil_conv_3x3", 1], ["sep_conv_3x3", 3], ["sep_conv_3x3", 2], ["sep_conv_5x5", 3]], normal_concat=range(2, 6), reduce= [["sep_conv_5x5", 0], ["sep_conv_5x5", 1], ["avg_pool_3x3", 0], ["sep_conv_5x5", 1], ["sep_conv_3x3", 0], ["skip_connect", 1], ["dil_conv_3x3", 0], ["sep_conv_3x3", 4]], reduce_concat=range(2, 6))
#fisher
cf100_fisher = Genotype(normal= [["sep_conv_3x3", 0], ["max_pool_3x3", 1], ["sep_conv_5x5", 0], ["max_pool_3x3", 1], ["dil_conv_3x3", 1], ["skip_connect", 3], ["dil_conv_5x5", 0], ["skip_connect", 1]], normal_concat=range(2, 6), reduce= [["sep_conv_5x5", 0], ["sep_conv_3x3", 1], ["sep_conv_3x3", 1], ["dil_conv_3x3", 2], ["sep_conv_5x5", 0], ["sep_conv_3x3", 1], ["max_pool_3x3", 1], ["sep_conv_3x3", 4]] , reduce_concat=range(2, 6))
#grasp
cf100_grasp= Genotype(normal= [["max_pool_3x3", 0], ["avg_pool_3x3", 1], ["avg_pool_3x3", 0], ["sep_conv_3x3", 1], ["sep_conv_5x5", 2], ["sep_conv_3x3", 3], ["avg_pool_3x3", 0], ["sep_conv_3x3", 4]] , normal_concat=range(2, 6), reduce= [["max_pool_3x3", 0], ["sep_conv_3x3", 1], ["dil_conv_3x3", 0], ["dil_conv_3x3", 2], ["skip_connect", 0], ["dil_conv_3x3", 1], ["dil_conv_3x3", 1], ["sep_conv_3x3", 2]] , reduce_concat=range(2, 6))
#jacob_cov
cf100_jacob_cov = Genotype(normal= [["max_pool_3x3", 0], ["avg_pool_3x3", 1], ["dil_conv_3x3", 0], ["dil_conv_5x5", 1], ["avg_pool_3x3", 0], ["avg_pool_3x3", 3], ["dil_conv_5x5", 1], ["dil_conv_5x5", 4]], normal_concat=range(2, 6), reduce= [["skip_connect", 0], ["sep_conv_5x5", 1], ["avg_pool_3x3", 0], ["skip_connect", 2], ["sep_conv_5x5", 0], ["sep_conv_3x3", 1], ["dil_conv_3x3", 0], ["dil_conv_5x5", 1]] , reduce_concat=range(2, 6))
#meco
cf100_meco = Genotype(normal= [["dil_conv_3x3", 0], ["sep_conv_3x3", 1], ["sep_conv_5x5", 0], ["dil_conv_5x5", 2], ["sep_conv_5x5", 2], ["sep_conv_3x3", 3], ["dil_conv_5x5", 0], ["sep_conv_3x3", 2]], normal_concat=range(2, 6), reduce= [["avg_pool_3x3", 0], ["sep_conv_3x3", 1], ["sep_conv_5x5", 0], ["dil_conv_5x5", 2], ["sep_conv_3x3", 0], ["sep_conv_3x3", 3], ["dil_conv_3x3", 0], ["sep_conv_3x3", 1]] , reduce_concat=range(2, 6))
#snip
cf100_snip = Genotype(normal= [["sep_conv_5x5", 0], ["skip_connect", 1], ["sep_conv_3x3", 1], ["sep_conv_5x5", 2], ["skip_connect", 0], ["sep_conv_3x3", 2], ["dil_conv_3x3", 0], ["max_pool_3x3", 3]], normal_concat=range(2, 6), reduce= [["dil_conv_3x3", 0], ["sep_conv_5x5", 1], ["sep_conv_3x3", 0], ["skip_connect", 2], ["skip_connect", 0], ["skip_connect", 2], ["dil_conv_5x5", 1], ["sep_conv_5x5", 2]] , reduce_concat=range(2, 6))
#synflow
cf100_synflow = Genotype(normal= [["sep_conv_5x5", 0], ["sep_conv_5x5", 1], ["sep_conv_5x5", 1], ["sep_conv_5x5", 2], ["sep_conv_5x5", 0], ["sep_conv_5x5", 1], ["sep_conv_5x5", 0], ["sep_conv_5x5", 1]] , normal_concat=range(2, 6), reduce= [["sep_conv_5x5", 0], ["sep_conv_5x5", 1], ["sep_conv_5x5", 0], ["sep_conv_5x5", 1], ["sep_conv_5x5", 0], ["sep_conv_5x5", 1], ["sep_conv_5x5", 0], ["sep_conv_5x5", 1]] , reduce_concat=range(2, 6))
#zico
cf100_zico = Genotype(normal= [["sep_conv_5x5", 0], ["sep_conv_5x5", 1], ["sep_conv_3x3", 0], ["sep_conv_3x3", 1], ["sep_conv_5x5", 0], ["sep_conv_5x5", 1], ["sep_conv_3x3", 0], ["sep_conv_3x3", 3]], normal_concat=range(2, 6), reduce= [["sep_conv_5x5", 0], ["sep_conv_3x3", 1], ["sep_conv_5x5", 0], ["dil_conv_5x5", 1], ["sep_conv_5x5", 0], ["sep_conv_3x3", 2], ["sep_conv_3x3", 0], ["sep_conv_3x3", 1]] , reduce_concat=range(2, 6))
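All of the constants above share one structure; a minimal sketch of the Genotype namedtuple they assume (presumably defined near the top of this file, as in DARTS), plus a sanity check on an illustrative instance:

from collections import namedtuple

Genotype = namedtuple('Genotype', 'normal normal_concat reduce reduce_concat')

# each cell lists 8 (op_name, input_node) pairs: two incoming edges for each of the
# 4 intermediate nodes; nodes 2..5 are concatenated to form the cell output
example = Genotype(normal=[["sep_conv_3x3", 0], ["sep_conv_3x3", 1]] * 4, normal_concat=range(2, 6), reduce=[["sep_conv_5x5", 0], ["sep_conv_5x5", 1]] * 4, reduce_concat=range(2, 6))
assert len(example.normal) == 8 and list(example.normal_concat) == [2, 3, 4, 5]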

40
sota/cnn/hdf5.py Normal file
View File

@@ -0,0 +1,40 @@
import h5py
import numpy as np
from PIL import Image
import torch
from torch.utils.data import Dataset, DataLoader
class H5Dataset(Dataset):
def __init__(self, h5_path, transform=None):
self.h5_path = h5_path
self.h5_file = None
with h5py.File(h5_path, 'r') as f: self.length = len(f)  # open briefly just to count records
self.transform = transform
def __getitem__(self, index):
# loading in __getitem__ allows us to use multiple processes for data loading,
# because hdf5 file handles aren't picklable and so can't be shared across processes
# https://discuss.pytorch.org/t/hdf5-a-data-format-for-pytorch/40379
# https://discuss.pytorch.org/t/dataloader-when-num-worker-0-there-is-bug/25643/16
# TODO: possibly look at __getstate__ and __setstate__ as a more elegant solution
if self.h5_file is None:
self.h5_file = h5py.File(self.h5_path, 'r', libver="latest", swmr=True)
record = self.h5_file[str(index)]
if self.transform:
x = Image.fromarray(record['data'][()])
x = self.transform(x)
else:
x = torch.from_numpy(record['data'][()])
y = record['target'][()]
y = torch.from_numpy(np.asarray(y))
return (x,y)
def __len__(self):
return self.length
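The lazy file opening in __getitem__ above is what makes this dataset safe to use with num_workers > 0; a minimal usage sketch (the .h5 path and the transform are placeholders, not files shipped with this commit):

import torchvision.transforms as transforms
from torch.utils.data import DataLoader

transform = transforms.Compose([transforms.RandomHorizontalFlip(), transforms.ToTensor()])
dataset = H5Dataset('imagenet-train-256.h5', transform=transform)  # hypothetical path
loader = DataLoader(dataset, batch_size=64, shuffle=True, num_workers=4, pin_memory=True)
images, targets = next(iter(loader))  # each worker opens its own h5py.File on first access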

336
sota/cnn/init_projection.py Normal file
View File

@@ -0,0 +1,336 @@
import sys
sys.path.insert(0, '../../')
import numpy as np
import torch
import logging
import torch.utils
from copy import deepcopy
from foresight.pruners import *
torch.set_printoptions(precision=4, sci_mode=False)
def sample_op(model, input, target, args, cell_type, selected_eid=None):
''' operation '''
#### macros
num_edges, num_ops = model.num_edges, model.num_ops
candidate_flags = model.candidate_flags[cell_type]
proj_crit = args.proj_crit[cell_type]
#### select an edge
if selected_eid is None:
remain_eids = torch.nonzero(candidate_flags).cpu().numpy().T[0]
selected_eid = np.random.choice(remain_eids, size=1)[0]
logging.info('selected edge: %d %s', selected_eid, cell_type)
select_opid = np.random.choice(np.array(range(num_ops)), size=1)[0]
return selected_eid, select_opid
def project_op(model, input, target, args, cell_type, proj_queue=None, selected_eid=None):
''' operation '''
#### macros
num_edges, num_ops = model.num_edges, model.num_ops
candidate_flags = model.candidate_flags[cell_type]
proj_crit = args.proj_crit[cell_type]
#### select an edge
if selected_eid is None:
remain_eids = torch.nonzero(candidate_flags).cpu().numpy().T[0]
# print(num_edges, num_ops, remain_eids)
if args.edge_decision == "random":
selected_eid = np.random.choice(remain_eids, size=1)[0]
logging.info('selected edge: %d %s', selected_eid, cell_type)
elif args.edge_decision == 'reverse':
selected_eid = remain_eids[-1]
logging.info('selected edge: %d %s', selected_eid, cell_type)
else:
selected_eid = remain_eids[0]
logging.info('selected edge: %d %s', selected_eid, cell_type)
#### select the best operation
if proj_crit == 'jacob':
crit_idx = 3
compare = lambda x, y: x < y
else:
crit_idx = 0
compare = lambda x, y: x < y
if args.dataset == 'cifar100':
n_classes = 100
elif args.dataset == 'imagenet16-120':
n_classes = 120
else:
n_classes = 10
best_opid = 0
crit_extrema = None
crit_list = []
op_ids = []
for opid in range(num_ops):
## projection
weights = model.get_projected_weights(cell_type)
proj_mask = torch.ones_like(weights[selected_eid])
proj_mask[opid] = 0
weights[selected_eid] = weights[selected_eid] * proj_mask
# ## proj evaluation
# with torch.no_grad():
# valid_stats = Jocab_Score(model, cell_type, input, target, weights=weights)
# crit = valid_stats
# crit_list.append(crit)
# if crit_extrema is None or compare(crit, crit_extrema):
# crit_extrema = crit
# best_opid = opid
## proj evaluation
if proj_crit == 'jacob':
crit = Jocab_Score(model,cell_type, input, target, weights=weights)
else:
cache_weight = model.proj_weights[cell_type][selected_eid]
cache_flag = model.candidate_flags[cell_type][selected_eid]
for idx in range(num_ops):
if idx == opid:
model.proj_weights[cell_type][selected_eid][opid] = 0
else:
model.proj_weights[cell_type][selected_eid][idx] = 1.0 / num_ops
model.candidate_flags[cell_type][selected_eid] = False
# print(model.get_projected_weights())
if proj_crit == 'comb':
synflow = predictive.find_measures(model,
proj_queue,
('random', 1, n_classes),
torch.device("cuda"),
measure_names=['synflow'])
var = predictive.find_measures(model,
proj_queue,
('random', 1, n_classes),
torch.device("cuda"),
measure_names=['var'])
# print(synflow, var)
comb = np.log(synflow['synflow'] + 1) / (var['var'] + 0.1)
measures = {'comb': comb}
else:
measures = predictive.find_measures(model,
proj_queue,
('random', 1, n_classes),
torch.device("cuda"),
measure_names=[proj_crit])
# print(measures)
for idx in range(num_ops):
model.proj_weights[cell_type][selected_eid][idx] = 0
model.candidate_flags[cell_type][selected_eid] = cache_flag
crit = measures[proj_crit]
crit_list.append(crit)
op_ids.append(opid)
best_opid = op_ids[np.nanargmin(crit_list)]
#### project
logging.info('best opid: %d', best_opid)
logging.info(crit_list)
return selected_eid, best_opid
def project_global_op(model, input, target, args, infer, cell_type, selected_eid=None):
''' operation '''
#### macros
num_edges, num_ops = model.num_edges, model.num_ops
candidate_flags = model.candidate_flags[cell_type]
proj_crit = args.proj_crit[cell_type]
remain_eids = torch.nonzero(candidate_flags).cpu().numpy().T[0]
#### select the best operation
if proj_crit == 'jacob':
crit_idx = 3
compare = lambda x, y: x < y
best_opid = 0
crit_extrema = None
best_eid = None
for eid in remain_eids:
for opid in range(num_ops):
## projection
weights = model.get_projected_weights(cell_type)
proj_mask = torch.ones_like(weights[eid])
proj_mask[opid] = 0
weights[eid] = weights[eid] * proj_mask
## proj evaluation
#weights_dict = {cell_type:weights}
with torch.no_grad():
valid_stats = Jocab_Score(model, cell_type, input, target, weights=weights)
crit = valid_stats
if crit_extrema is None or compare(crit, crit_extrema):
crit_extrema = crit
best_opid = opid
best_eid = eid
#### project
logging.info('best opid: %d', best_opid)
#logging.info(crit_list)
return best_eid, best_opid
def sample_edge(model, input, target, args, cell_type, selected_eid=None):
''' topology '''
#### macros
candidate_flags = model.candidate_flags_edge[cell_type]
proj_crit = args.proj_crit[cell_type]
#### select a node
remain_nids = torch.nonzero(candidate_flags).cpu().numpy().T[0]
selected_nid = np.random.choice(remain_nids, size=1)[0]
logging.info('selected node: %d %s', selected_nid, cell_type)
eids = deepcopy(model.nid2eids[selected_nid])
while len(eids) > 2:
elected_eid = np.random.choice(eids, size=1)[0]
eids.remove(elected_eid)
return selected_nid, eids
def project_edge(model, input, target, args, cell_type):
''' topology '''
#### macros
candidate_flags = model.candidate_flags_edge[cell_type]
proj_crit = args.proj_crit[cell_type]
#### select a node
remain_nids = torch.nonzero(candidate_flags).cpu().numpy().T[0]
if args.edge_decision == "random":
selected_nid = np.random.choice(remain_nids, size=1)[0]
logging.info('selected node: %d %s', selected_nid, cell_type)
elif args.edge_decision == 'reverse':
selected_nid = remain_nids[-1]
logging.info('selected node: %d %s', selected_nid, cell_type)
else:
selected_nid = np.random.choice(remain_nids, size=1)[0]
logging.info('selected node: %d %s', selected_nid, cell_type)
#### select top2 edges
if proj_crit == 'jacob':
crit_idx = 3
compare = lambda x, y: x < y
else:
crit_idx = 3
compare = lambda x, y: x < y
eids = deepcopy(model.nid2eids[selected_nid])
crit_list = []
while len(eids) > 2:
eid_todel = None
crit_extrema = None
for eid in eids:
weights = model.get_projected_weights(cell_type)
weights[eid].data.fill_(0)
## proj evaluation
with torch.no_grad():
valid_stats = Jocab_Score(model, cell_type, input, target, weights=weights)
crit = valid_stats
crit_list.append(crit)
if crit_extrema is None or not compare(crit, crit_extrema): # find out bad edges
crit_extrema = crit
eid_todel = eid
eids.remove(eid_todel)
#### project
logging.info('top2 edges: (%d, %d)', eids[0], eids[1])
#logging.info(crit_list)
return selected_nid, eids
def pt_project(train_queue, model, args):
model.eval()
#### macros
num_projs = model.num_edges + len(model.nid2eids.keys())
args.proj_crit = {'normal':args.proj_crit_normal, 'reduce':args.proj_crit_reduce}
proj_queue = train_queue
epoch = 0
for step, (input, target) in enumerate(proj_queue):
if epoch < model.num_edges:
logging.info('project op')
if args.edge_decision == 'global_op_greedy':
selected_eid_normal, best_opid_normal = project_global_op(model, input, target, args, cell_type='normal')
elif args.edge_decision == 'sample':
selected_eid_normal, best_opid_normal = sample_op(model, input, target, args, cell_type='normal')
else:
selected_eid_normal, best_opid_normal = project_op(model, input, target, args, proj_queue=proj_queue, cell_type='normal')
model.project_op(selected_eid_normal, best_opid_normal, cell_type='normal')
if args.edge_decision == 'global_op_greedy':
selected_eid_reduce, best_opid_reduce = project_global_op(model, input, target, args, cell_type='reduce')
elif args.edge_decision == 'sample':
selected_eid_reduce, best_opid_reduce = sample_op(model, input, target, args, cell_type='reduce')
else:
selected_eid_reduce, best_opid_reduce = project_op(model, input, target, args, proj_queue=proj_queue, cell_type='reduce')
model.project_op(selected_eid_reduce, best_opid_reduce, cell_type='reduce')
else:
logging.info('project edge')
if args.edge_decision == 'sample':
selected_nid_normal, eids_normal = sample_edge(model, input, target, args, cell_type='normal')
model.project_edge(selected_nid_normal, eids_normal, cell_type='normal')
selected_nid_reduce, eids_reduce = sample_edge(model, input, target, args, cell_type='reduce')
model.project_edge(selected_nid_reduce, eids_reduce, cell_type='reduce')
else:
selected_nid_normal, eids_normal = project_edge(model, input, target, args, cell_type='normal')
model.project_edge(selected_nid_normal, eids_normal, cell_type='normal')
selected_nid_reduce, eids_reduce = project_edge(model, input, target, args, cell_type='reduce')
model.project_edge(selected_nid_reduce, eids_reduce, cell_type='reduce')
epoch+=1
if epoch == num_projs:
break
return
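# A reading aid for the schedule above (assuming the standard DARTS cell with 4 intermediate
# nodes): num_edges = 14 and nid2eids holds the 3 nodes that still need pruning, so
# num_projs = 14 + 3 = 17; the loop spends one mini-batch per decision, first 14 operation
# projections and then 3 top-2 edge projections, applied to both the normal and reduce cell.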
def Jocab_Score(ori_model, cell_type, input, target, weights=None):
model = deepcopy(ori_model)
model.eval()
if cell_type == 'reduce':
model.proj_weights['reduce'] = weights
model.proj_weights['normal'] = model.get_projected_weights('normal')
else:
model.proj_weights['normal'] = weights
model.proj_weights['reduce'] = model.get_projected_weights('reduce')
batch_size = input.shape[0]
model.K = torch.zeros(batch_size, batch_size).cuda()
def counting_forward_hook(module, inp, out):
try:
if isinstance(inp, tuple):
inp = inp[0]
inp = inp.view(inp.size(0), -1)
x = (inp > 0).float()
K = x @ x.t()
K2 = (1.-x) @ (1.-x.t())
model.K = model.K + K + K2
except:
pass
for name, module in model.named_modules():
if 'ReLU' in str(type(module)):
module.register_forward_hook(counting_forward_hook)
input = input.cuda()
model(input, using_proj=True)
score = hooklogdet(model.K.cpu().numpy())
del model
return score
def hooklogdet(K, labels=None):
s, ld = np.linalg.slogdet(K)
return ld
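Jocab_Score above is essentially the NASWOT-style kernel score: every ReLU records a binary activation pattern per sample, K accumulates how often pairs of samples agree on those patterns, and hooklogdet returns log|K|, which tends to be larger when the batch is mapped to well-separated activation codes. A self-contained sketch of that computation (toy network and sizes, nothing from this repo):

import numpy as np
import torch
import torch.nn as nn

net = nn.Sequential(nn.Linear(8, 16), nn.ReLU(), nn.Linear(16, 16), nn.ReLU())
x = torch.randn(4, 8)  # batch of 4 samples
K = torch.zeros(4, 4)

def hook(module, inp, out):
    global K
    code = (inp[0].view(inp[0].size(0), -1) > 0).float()  # binary pattern at the ReLU input
    K = K + code @ code.t() + (1. - code) @ (1. - code).t()  # agreements on active + inactive units

for m in net.modules():
    if isinstance(m, nn.ReLU):
        m.register_forward_hook(hook)

net(x)
sign, logdet = np.linalg.slogdet(K.numpy())
print(logdet)  # higher -> more distinguishable activation patterns across the batch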

133
sota/cnn/model.py Normal file
View File

@@ -0,0 +1,133 @@
import torch
import torch.nn as nn
from sota.cnn.operations import *
import sys
sys.path.insert(0, '../../')
from nasbench201.utils import drop_path
class Cell(nn.Module):
def __init__(self, genotype, C_prev_prev, C_prev, C, reduction, reduction_prev):
super(Cell, self).__init__()
if reduction_prev:
self.preprocess0 = FactorizedReduce(C_prev_prev, C)
else:
self.preprocess0 = ReLUConvBN(C_prev_prev, C, 1, 1, 0)
self.preprocess1 = ReLUConvBN(C_prev, C, 1, 1, 0)
if reduction:
op_names, indices = zip(*genotype.reduce)
concat = genotype.reduce_concat
else:
op_names, indices = zip(*genotype.normal)
concat = genotype.normal_concat
self._compile(C, op_names, indices, concat, reduction)
def _compile(self, C, op_names, indices, concat, reduction):
assert len(op_names) == len(indices)
self._steps = len(op_names) // 2
self._concat = concat
self.multiplier = len(concat)
self._ops = nn.ModuleList()
for name, index in zip(op_names, indices):
stride = 2 if reduction and index < 2 else 1
op = OPS[name](C, stride, True)
self._ops += [op]
self._indices = indices
def forward(self, s0, s1, drop_prob):
s0 = self.preprocess0(s0)
s1 = self.preprocess1(s1)
states = [s0, s1]
for i in range(self._steps):
h1 = states[self._indices[2*i]]
h2 = states[self._indices[2*i+1]]
op1 = self._ops[2*i]
op2 = self._ops[2*i+1]
h1 = op1(h1)
h2 = op2(h2)
if self.training and drop_prob > 0.:
if not isinstance(op1, Identity):
h1 = drop_path(h1, drop_prob)
if not isinstance(op2, Identity):
h2 = drop_path(h2, drop_prob)
s = h1 + h2
states += [s]
return torch.cat([states[i] for i in self._concat], dim=1)
class AuxiliaryHead(nn.Module):
def __init__(self, C, num_classes):
"""assuming input size 8x8"""
super(AuxiliaryHead, self).__init__()
self.features = nn.Sequential(
nn.ReLU(inplace=True),
# image size = 2 x 2
nn.AvgPool2d(5, stride=3, padding=0, count_include_pad=False),
nn.Conv2d(C, 128, 1, bias=False),
nn.BatchNorm2d(128),
nn.ReLU(inplace=True),
nn.Conv2d(128, 768, 2, bias=False),
nn.BatchNorm2d(768),
nn.ReLU(inplace=True)
)
self.classifier = nn.Linear(768, num_classes)
def forward(self, x):
x = self.features(x)
x = self.classifier(x.view(x.size(0), -1))
return x
class Network(nn.Module):
def __init__(self, C, num_classes, layers, auxiliary, genotype):
super(Network, self).__init__()
self._layers = layers
self._auxiliary = auxiliary
stem_multiplier = 3
C_curr = stem_multiplier*C
self.stem = nn.Sequential(
nn.Conv2d(3, C_curr, 3, padding=1, bias=False),
nn.BatchNorm2d(C_curr)
)
C_prev_prev, C_prev, C_curr = C_curr, C_curr, C
self.cells = nn.ModuleList()
reduction_prev = False
for i in range(layers):
if i in [layers//3, 2*layers//3]:
C_curr *= 2
reduction = True
else:
reduction = False
cell = Cell(genotype, C_prev_prev, C_prev, C_curr, reduction, reduction_prev)
reduction_prev = reduction
self.cells += [cell]
C_prev_prev, C_prev = C_prev, cell.multiplier*C_curr
if i == 2*layers//3:
C_to_auxiliary = C_prev
if auxiliary:
self.auxiliary_head = AuxiliaryHead(C_to_auxiliary, num_classes)
self.global_pooling = nn.AdaptiveAvgPool2d(1)
self.classifier = nn.Linear(C_prev, num_classes)
def forward(self, input):
logits_aux = None
s0 = s1 = self.stem(input)
for i, cell in enumerate(self.cells):
s0, s1 = s1, cell(s0, s1, self.drop_path_prob)
if i == 2*self._layers//3:
if self._auxiliary and self.training:
logits_aux = self.auxiliary_head(s1)
out = self.global_pooling(s1)
logits = self.classifier(out.view(out.size(0), -1))
return logits, logits_aux
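A hedged usage sketch for the evaluation network above (channel and layer counts follow the usual DARTS CIFAR setting; note that forward() reads self.drop_path_prob, which __init__ never sets, so callers must assign it before the first forward pass):

import torch
from sota.cnn.model import Network        # the class defined above
from sota.cnn.genotypes import cf10_meco  # any Genotype constant from this commit

net = Network(C=36, num_classes=10, layers=20, auxiliary=True, genotype=cf10_meco)
net.drop_path_prob = 0.0                  # training scripts usually update this per epoch
logits, logits_aux = net(torch.randn(2, 3, 32, 32))
print(logits.shape)                       # torch.Size([2, 10])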

150
sota/cnn/model_imagenet.py Normal file
View File

@@ -0,0 +1,150 @@
import sys
sys.path.insert(0, '../../')
import torch
import torch.nn as nn
from torch.autograd import Variable  # needed by drop_path below
# from optimizers.darts.operations import *
from sota.cnn.operations import *
#from optimizers.darts.utils import drop_path
def drop_path(x, drop_prob):
if drop_prob > 0.:
keep_prob = 1.-drop_prob
mask = Variable(torch.cuda.FloatTensor(x.size(0), 1, 1, 1).bernoulli_(keep_prob))
x.div_(keep_prob)
x.mul_(mask)
return x
class Cell(nn.Module):
def __init__(self, genotype, C_prev_prev, C_prev, C, reduction, reduction_prev):
super(Cell, self).__init__()
print(C_prev_prev, C_prev, C)
if reduction_prev:
self.preprocess0 = FactorizedReduce(C_prev_prev, C)
else:
self.preprocess0 = ReLUConvBN(C_prev_prev, C, 1, 1, 0)
self.preprocess1 = ReLUConvBN(C_prev, C, 1, 1, 0)
if reduction:
op_names, indices = zip(*genotype.reduce)
concat = genotype.reduce_concat
else:
op_names, indices = zip(*genotype.normal)
concat = genotype.normal_concat
self._compile(C, op_names, indices, concat, reduction)
def _compile(self, C, op_names, indices, concat, reduction):
assert len(op_names) == len(indices)
self._steps = len(op_names) // 2
self._concat = concat
self.multiplier = len(concat)
self._ops = nn.ModuleList()
for name, index in zip(op_names, indices):
stride = 2 if reduction and index < 2 else 1
op = OPS[name](C, stride, True)
self._ops += [op]
self._indices = indices
def forward(self, s0, s1, drop_prob):
s0 = self.preprocess0(s0)
s1 = self.preprocess1(s1)
states = [s0, s1]
for i in range(self._steps):
h1 = states[self._indices[2 * i]]
h2 = states[self._indices[2 * i + 1]]
op1 = self._ops[2 * i]
op2 = self._ops[2 * i + 1]
h1 = op1(h1)
h2 = op2(h2)
if self.training and drop_prob > 0.:
if not isinstance(op1, Identity):
h1 = drop_path(h1, drop_prob)
if not isinstance(op2, Identity):
h2 = drop_path(h2, drop_prob)
s = h1 + h2
states += [s]
return torch.cat([states[i] for i in self._concat], dim=1)
class AuxiliaryHeadImageNet(nn.Module):
def __init__(self, C, num_classes):
"""assuming input size 14x14"""
super(AuxiliaryHeadImageNet, self).__init__()
self.features = nn.Sequential(
nn.ReLU(inplace=True),
nn.AvgPool2d(5, stride=2, padding=0, count_include_pad=False),
nn.Conv2d(C, 128, 1, bias=False),
nn.BatchNorm2d(128),
nn.ReLU(inplace=True),
nn.Conv2d(128, 768, 2, bias=False),
# NOTE: This batchnorm was omitted in my earlier implementation due to a typo.
# Commenting it out for consistency with the experiments in the paper.
# nn.BatchNorm2d(768),
nn.ReLU(inplace=True)
)
self.classifier = nn.Linear(768, num_classes)
def forward(self, x):
x = self.features(x)
x = self.classifier(x.view(x.size(0), -1))
return x
class NetworkImageNet(nn.Module):
def __init__(self, C, num_classes, layers, auxiliary, genotype):
super(NetworkImageNet, self).__init__()
self._layers = layers
self._auxiliary = auxiliary
self.drop_path_prob = 0.0
self.stem0 = nn.Sequential(
nn.Conv2d(3, C // 2, kernel_size=3, stride=2, padding=1, bias=False),
nn.BatchNorm2d(C // 2),
nn.ReLU(inplace=True),
nn.Conv2d(C // 2, C, 3, stride=2, padding=1, bias=False),
nn.BatchNorm2d(C),
)
self.stem1 = nn.Sequential(
nn.ReLU(inplace=True),
nn.Conv2d(C, C, 3, stride=2, padding=1, bias=False),
nn.BatchNorm2d(C),
)
C_prev_prev, C_prev, C_curr = C, C, C
self.cells = nn.ModuleList()
reduction_prev = True
for i in range(layers):
if i in [layers // 3, 2 * layers // 3]:
C_curr *= 2
reduction = True
else:
reduction = False
cell = Cell(genotype, C_prev_prev, C_prev, C_curr, reduction, reduction_prev)
reduction_prev = reduction
self.cells += [cell]
C_prev_prev, C_prev = C_prev, cell.multiplier * C_curr
if i == 2 * layers // 3:
C_to_auxiliary = C_prev
if auxiliary:
self.auxiliary_head = AuxiliaryHeadImageNet(C_to_auxiliary, num_classes)
self.global_pooling = nn.AvgPool2d(7)
self.classifier = nn.Linear(C_prev, num_classes)
def forward(self, input):
logits_aux = None
s0 = self.stem0(input)
s1 = self.stem1(s0)
for i, cell in enumerate(self.cells):
s0, s1 = s1, cell(s0, s1, self.drop_path_prob)
if i == 2 * self._layers // 3:
if self._auxiliary and self.training:
logits_aux = self.auxiliary_head(s1)
out = self.global_pooling(s1)
logits = self.classifier(out.view(out.size(0), -1))
return logits, logits_aux
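The drop_path above zeroes entire samples and rescales the survivors by 1/keep_prob so the expected activation is unchanged; a small numeric check of that identity (independent of this commit):

import torch

torch.manual_seed(0)
x = torch.ones(10000, 1, 1, 1)
drop_prob, keep_prob = 0.2, 0.8
mask = torch.empty(x.size(0), 1, 1, 1).bernoulli_(keep_prob)
out = x / keep_prob * mask
print(out.mean().item())  # ~1.0: scaling by 1/keep_prob preserves the expectation of x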

288
sota/cnn/model_search.py Normal file
View File

@@ -0,0 +1,288 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable
from sota.cnn.operations import *
from sota.cnn.genotypes import Genotype
import sys
sys.path.insert(0, '../../')
from nasbench201.utils import drop_path
class MixedOp(nn.Module):
def __init__(self, C, stride, PRIMITIVES):
super(MixedOp, self).__init__()
self._ops = nn.ModuleList()
for primitive in PRIMITIVES:
op = OPS[primitive](C, stride, False)
if 'pool' in primitive:
op = nn.Sequential(op, nn.BatchNorm2d(C, affine=False))
self._ops.append(op)
def forward(self, x, weights):
ret = sum(w * op(x, block_input=True) if w == 0 else w * op(x) for w, op in zip(weights, self._ops) if w != 0)
return ret
class Cell(nn.Module):
def __init__(self, steps, multiplier, C_prev_prev, C_prev, C, reduction, reduction_prev):
super(Cell, self).__init__()
self.reduction = reduction
self.primitives = self.PRIMITIVES['primitives_reduct' if reduction else 'primitives_normal']
if reduction_prev:
self.preprocess0 = FactorizedReduce(C_prev_prev, C, affine=False)
else:
self.preprocess0 = ReLUConvBN(C_prev_prev, C, 1, 1, 0, affine=False)
self.preprocess1 = ReLUConvBN(C_prev, C, 1, 1, 0, affine=False)
self._steps = steps
self._multiplier = multiplier
self._ops = nn.ModuleList()
self._bns = nn.ModuleList()
edge_index = 0
for i in range(self._steps):
for j in range(2+i):
stride = 2 if reduction and j < 2 else 1
op = MixedOp(C, stride, self.primitives[edge_index])
self._ops.append(op)
edge_index += 1
def forward(self, s0, s1, weights, drop_prob=0.):
s0 = self.preprocess0(s0)
s1 = self.preprocess1(s1)
states = [s0, s1]
offset = 0
for i in range(self._steps):
if drop_prob > 0. and self.training:
s = sum(drop_path(self._ops[offset+j](h, weights[offset+j]), drop_prob) for j, h in enumerate(states))
else:
s = sum(self._ops[offset+j](h, weights[offset+j]) for j, h in enumerate(states))
offset += len(states)
states.append(s)
return torch.cat(states[-self._multiplier:], dim=1)
class Network(nn.Module):
def __init__(self, C, num_classes, layers, criterion, primitives, args,
steps=4, multiplier=4, stem_multiplier=3, drop_path_prob=0, nettype='cifar'):
super(Network, self).__init__()
#### original code
self._C = C
self._num_classes = num_classes
self._layers = layers
self._criterion = criterion
self._steps = steps
self._multiplier = multiplier
self.drop_path_prob = drop_path_prob
self.nettype = nettype
nn.Module.PRIMITIVES = primitives; self.op_names = primitives  # stored on the nn.Module class so Cell can read self.PRIMITIVES without an extra argument
C_curr = stem_multiplier*C
if self.nettype == 'cifar':
self.stem = nn.Sequential(
nn.Conv2d(3, C_curr, 3, padding=1, bias=False),
nn.BatchNorm2d(C_curr)
)
else:
self.stem0 = nn.Sequential(
nn.Conv2d(3, C_curr // 2, kernel_size=3, stride=2, padding=1, bias=False),
nn.BatchNorm2d(C_curr // 2),
nn.ReLU(inplace=True),
nn.Conv2d(C_curr // 2, C_curr, 3, stride=2, padding=1, bias=False),
nn.BatchNorm2d(C_curr),
)
self.stem1 = nn.Sequential(
nn.ReLU(inplace=True),
nn.Conv2d(C_curr, C_curr, 3, stride=2, padding=1, bias=False),
nn.BatchNorm2d(C_curr),
)
C_prev_prev, C_prev, C_curr = C_curr, C_curr, C
self.cells = nn.ModuleList()
if self.nettype == 'cifar':
reduction_prev = False
else:
reduction_prev = True
for i in range(layers):
if i in [layers//3, 2*layers//3]:
C_curr *= 2
reduction = True
else:
reduction = False
cell = Cell(steps, multiplier, C_prev_prev, C_prev, C_curr, reduction, reduction_prev)
reduction_prev = reduction
self.cells += [cell]
C_prev_prev, C_prev = C_prev, multiplier*C_curr
self.global_pooling = nn.AdaptiveAvgPool2d(1)
self.classifier = nn.Linear(C_prev, num_classes)
self._initialize_alphas()
#### optimizer
self._args = args
self.optimizer = torch.optim.SGD(
self.get_weights(),
args.learning_rate,
momentum=args.momentum,
weight_decay=args.weight_decay,
nesterov= args.nesterov)
def reset_optimizer(self, lr, momentum, weight_decay):
del self.optimizer
self.optimizer = torch.optim.SGD(
self.get_weights(),
lr,
momentum=momentum,
weight_decay=weight_decay)
def _loss(self, input, target, return_logits=False):
logits = self(input)
loss = self._criterion(logits, target)
return (loss, logits) if return_logits else loss
def _initialize_alphas(self):
k = sum(1 for i in range(self._steps) for n in range(2+i))
num_ops = len(self.PRIMITIVES['primitives_normal'][0])
self.num_edges = k
self.num_ops = num_ops
self.alphas_normal = self._initialize_alphas_numpy(k, num_ops)
self.alphas_reduce = self._initialize_alphas_numpy(k, num_ops)
self._arch_parameters = [ # must be in this order!
self.alphas_normal,
self.alphas_reduce,
]
def _initialize_alphas_numpy(self, k, num_ops):
''' init from specified arch '''
return Variable(1e-3*torch.randn(k, num_ops).cuda(), requires_grad=True)
def forward(self, input):
weights = self.get_softmax()
weights_normal = weights['normal']
weights_reduce = weights['reduce']
if self.nettype == 'cifar':
s0 = s1 = self.stem(input)
else:
print('imagenet')
s0 = self.stem0(input)
s1 = self.stem1(s0)
for i, cell in enumerate(self.cells):
if cell.reduction:
weights = weights_reduce
else:
weights = weights_normal
s0, s1 = s1, cell(s0, s1, weights, self.drop_path_prob)
out = self.global_pooling(s1)
logits = self.classifier(out.view(out.size(0),-1))
return logits
def step(self, input, target, args, shared=None):
assert shared is None, 'gradient sharing disabled'
Lt, logit_t = self._loss(input, target, return_logits=True)
Lt.backward()
nn.utils.clip_grad_norm_(self.get_weights(), args.grad_clip)
self.optimizer.step()
return logit_t, Lt
#### utils
def set_arch_parameters(self, new_alphas):
for alpha, new_alpha in zip(self.arch_parameters(), new_alphas):
alpha.data.copy_(new_alpha.data)
def get_softmax(self):
weights_normal = F.softmax(self.alphas_normal, dim=-1)
weights_reduce = F.softmax(self.alphas_reduce, dim=-1)
return {'normal':weights_normal, 'reduce':weights_reduce}
def printing(self, logging, option='all'):
weights = self.get_softmax()
if option in ['all', 'normal']:
weights_normal = weights['normal']
logging.info(weights_normal)
if option in ['all', 'reduce']:
weights_reduce = weights['reduce']
logging.info(weights_reduce)
def arch_parameters(self):
return self._arch_parameters
def get_weights(self):
return self.parameters()
def new(self):
model_new = Network(self._C, self._num_classes, self._layers, self._criterion, self.PRIMITIVES, self._args,\
drop_path_prob=self.drop_path_prob).cuda()
for x, y in zip(model_new.arch_parameters(), self.arch_parameters()):
x.data.copy_(y.data)
return model_new
def clip(self):
for p in self.arch_parameters():
for line in p:
max_index = line.argmax()
line.data.clamp_(0, 1)
if line.sum() == 0.0:
line.data[max_index] = 1.0
line.data.div_(line.sum())
def genotype(self):
def _parse(weights, normal=True):
PRIMITIVES = self.PRIMITIVES['primitives_normal' if normal else 'primitives_reduct'] ## two are equal for Darts space
gene = []
n = 2
start = 0
for i in range(self._steps):
end = start + n
W = weights[start:end].copy()
try:
edges = sorted(range(i + 2), key=lambda x: -max(W[x][k] for k in range(len(W[x])) if k != PRIMITIVES[x].index('none')))[:2]
except ValueError: # This error happens when the 'none' op is not present in the ops
edges = sorted(range(i + 2), key=lambda x: -max(W[x][k] for k in range(len(W[x]))))[:2]
for j in edges:
k_best = None
for k in range(len(W[j])):
if 'none' in PRIMITIVES[j]:
if k != PRIMITIVES[j].index('none'):
if k_best is None or W[j][k] > W[j][k_best]:
k_best = k
else:
if k_best is None or W[j][k] > W[j][k_best]:
k_best = k
gene.append((PRIMITIVES[start+j][k_best], j))
start = end
n += 1
return gene
gene_normal = _parse(F.softmax(self.alphas_normal, dim=-1).data.cpu().numpy(), True)
gene_reduce = _parse(F.softmax(self.alphas_reduce, dim=-1).data.cpu().numpy(), False)
concat = range(2+self._steps-self._multiplier, self._steps+2)
genotype = Genotype(
normal=gene_normal, normal_concat=concat,
reduce=gene_reduce, reduce_concat=concat
)
return genotype
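The _parse routine above implements the usual DARTS discretisation: for intermediate node i it ranks the i+2 incoming edges by their strongest non-'none' weight, keeps the top two, and records the best non-'none' operation on each. A toy sketch of that rule for one node with three candidate edges (the op list is illustrative, not this repo's search space):

import numpy as np

PRIMITIVES = ['none', 'skip_connect', 'sep_conv_3x3']
W = np.array([[0.1, 0.2, 0.7],   # edge from node 0
              [0.6, 0.3, 0.1],   # edge from node 1
              [0.2, 0.5, 0.3]])  # edge from node 2
best_non_none = [max(W[e][k] for k in range(len(PRIMITIVES)) if k != PRIMITIVES.index('none')) for e in range(3)]
edges = sorted(range(3), key=lambda e: -best_non_none[e])[:2]
gene = [(PRIMITIVES[int(np.argmax([w if k != 0 else -1 for k, w in enumerate(W[e])]))], e) for e in edges]
print(edges, gene)  # [0, 2] [('sep_conv_3x3', 0), ('skip_connect', 2)]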

213
sota/cnn/model_search_darts_proj.py Normal file
View File

@@ -0,0 +1,213 @@
import torch
from copy import deepcopy
from sota.cnn.operations import *
from sota.cnn.genotypes import Genotype
import sys
sys.path.insert(0, '../../')
from sota.cnn.model_search import Network
class DartsNetworkProj(Network):
def __init__(self, C, num_classes, layers, criterion, primitives, args,
steps=4, multiplier=4, stem_multiplier=3, drop_path_prob=0.0):
super(DartsNetworkProj, self).__init__(C, num_classes, layers, criterion, primitives, args,
steps=steps, multiplier=multiplier, stem_multiplier=stem_multiplier, drop_path_prob=drop_path_prob)
self._initialize_flags()
self._initialize_proj_weights()
self._initialize_topology_dicts()
#### proj flags
def _initialize_topology_dicts(self):
self.nid2eids = {0:[2,3,4], 1:[5,6,7,8], 2:[9,10,11,12,13]}
self.nid2selected_eids = {
'normal': {0:[],1:[],2:[]},
'reduce': {0:[],1:[],2:[]},
}
def _initialize_flags(self):
self.candidate_flags = {
'normal':torch.tensor(self.num_edges * [True], requires_grad=False, dtype=torch.bool).cuda(),
'reduce':torch.tensor(self.num_edges * [True], requires_grad=False, dtype=torch.bool).cuda(),
} # must be in this order
self.candidate_flags_edge = {
'normal': torch.tensor(3 * [True], requires_grad=False, dtype=torch.bool).cuda(),
'reduce': torch.tensor(3 * [True], requires_grad=False, dtype=torch.bool).cuda(),
}
def _initialize_proj_weights(self):
''' data structures used for proj '''
if isinstance(self.alphas_normal, list):
alphas_normal = torch.stack(self.alphas_normal, dim=0)
alphas_reduce = torch.stack(self.alphas_reduce, dim=0)
else:
alphas_normal = self.alphas_normal
alphas_reduce = self.alphas_reduce
self.proj_weights = { # for hard/soft assignment after project
'normal': torch.zeros_like(alphas_normal),
'reduce': torch.zeros_like(alphas_reduce),
}
#### proj function
def project_op(self, eid, opid, cell_type):
self.proj_weights[cell_type][eid][opid] = 1 ## hard by default
self.candidate_flags[cell_type][eid] = False
def project_edge(self, nid, eids, cell_type):
for eid in self.nid2eids[nid]:
if eid not in eids: # not top2
self.proj_weights[cell_type][eid].data.fill_(0)
self.nid2selected_eids[cell_type][nid] = deepcopy(eids)
self.candidate_flags_edge[cell_type][nid] = False
#### critical function
def get_projected_weights(self, cell_type):
''' used in forward and genotype '''
weights = self.get_softmax()[cell_type]
## proj op
for eid in range(self.num_edges):
if not self.candidate_flags[cell_type][eid]:
weights[eid].data.copy_(self.proj_weights[cell_type][eid])
## proj edge
for nid in self.nid2eids:
if not self.candidate_flags_edge[cell_type][nid]: ## projected node
for eid in self.nid2eids[nid]:
if eid not in self.nid2selected_eids[cell_type][nid]:
weights[eid].data.copy_(self.proj_weights[cell_type][eid])
return weights
def get_all_projected_weights(self, cell_type):
weights = self.get_softmax()[cell_type]
for eid in range(self.num_edges):
weights[eid].data.copy_(self.proj_weights[cell_type][eid])
for nid in self.nid2eids:
for eid in self.nid2eids[nid]:
weights[eid].data.copy_(self.proj_weights[cell_type][eid])
return weights
def forward(self, input, weights_dict=None, using_proj=False):
if using_proj:
weights_normal = self.get_all_projected_weights('normal')
weights_reduce = self.get_all_projected_weights('reduce')
else:
if weights_dict is None or 'normal' not in weights_dict:
weights_normal = self.get_projected_weights('normal')
else:
weights_normal = weights_dict['normal']
if weights_dict is None or 'reduce' not in weights_dict:
weights_reduce = self.get_projected_weights('reduce')
else:
weights_reduce = weights_dict['reduce']
s0 = s1 = self.stem(input)
for i, cell in enumerate(self.cells):
if cell.reduction:
weights = weights_reduce
else:
weights = weights_normal
s0, s1 = s1, cell(s0, s1, weights, self.drop_path_prob)
out = self.global_pooling(s1)
logits = self.classifier(out.view(out.size(0),-1))
return logits
def reset_arch_parameters(self):
self._initialize_flags()
self._initialize_proj_weights()
self._initialize_topology_dicts()
#### utils
def printing(self, logging, option='all'):
weights_normal = self.get_projected_weights('normal')
weights_reduce = self.get_projected_weights('reduce')
if option in ['all', 'normal']:
logging.info('\n%s', weights_normal)
if option in ['all', 'reduce']:
logging.info('\n%s', weights_reduce)
def genotype(self):
def _parse(weights, normal=True):
PRIMITIVES = self.PRIMITIVES['primitives_normal' if normal else 'primitives_reduct']
gene = []
n = 2
start = 0
for i in range(self._steps):
end = start + n
W = weights[start:end].copy()
try:
edges = sorted(range(i + 2), key=lambda x: -max(W[x][k] for k in range(len(W[x])) if k != PRIMITIVES[x].index('none')))[:2]
except ValueError:
edges = sorted(range(i + 2), key=lambda x: -max(W[x][k] for k in range(len(W[x]))))[:2]
for j in edges:
k_best = None
for k in range(len(W[j])):
if 'none' in PRIMITIVES[j]:
if k != PRIMITIVES[j].index('none'):
if k_best is None or W[j][k] > W[j][k_best]:
k_best = k
else:
if k_best is None or W[j][k] > W[j][k_best]:
k_best = k
gene.append((PRIMITIVES[start+j][k_best], j))
start = end
n += 1
return gene
weights_normal = self.get_projected_weights('normal')
weights_reduce = self.get_projected_weights('reduce')
gene_normal = _parse(weights_normal.data.cpu().numpy(), True)
gene_reduce = _parse(weights_reduce.data.cpu().numpy(), False)
concat = range(2+self._steps-self._multiplier, self._steps+2)
genotype = Genotype(
normal=gene_normal, normal_concat=concat,
reduce=gene_reduce, reduce_concat=concat
)
return genotype
def get_state_dict(self, epoch, architect, scheduler):
model_state_dict = {
'epoch': epoch, ## no +1 because we are saving before projection / at the beginning of an epoch
'state_dict': self.state_dict(),
'alpha': self.arch_parameters(),
'optimizer': self.optimizer.state_dict(),
'arch_optimizer': architect.optimizer.state_dict(),
'scheduler': scheduler.state_dict(),
#### projection
'nid2eids': self.nid2eids,
'nid2selected_eids': self.nid2selected_eids,
'candidate_flags': self.candidate_flags,
'candidate_flags_edge': self.candidate_flags_edge,
'proj_weights': self.proj_weights,
}
return model_state_dict
def set_state_dict(self, architect, scheduler, checkpoint):
#### common
self.load_state_dict(checkpoint['state_dict'])
self.set_arch_parameters(checkpoint['alpha'])
self.optimizer.load_state_dict(checkpoint['optimizer'])
architect.optimizer.load_state_dict(checkpoint['arch_optimizer'])
scheduler.load_state_dict(checkpoint['scheduler'])
#### projection
self.nid2eids = checkpoint['nid2eids']
self.nid2selected_eids = checkpoint['nid2selected_eids']
self.candidate_flags = checkpoint['candidate_flags']
self.candidate_flags_edge = checkpoint['candidate_flags_edge']
self.proj_weights = checkpoint['proj_weights']
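The projection state above reduces to two masks per cell type: a one-hot row in proj_weights for every edge whose operation has been decided (candidate_flags False), and zeroed rows for incoming edges that did not make a node's top-2 (candidate_flags_edge False). A standalone sketch of how get_projected_weights composes the op mask with the softmax weights, using made-up sizes:

import torch
import torch.nn.functional as F

num_edges, num_ops = 5, 3
alphas = torch.randn(num_edges, num_ops)
proj_weights = torch.zeros(num_edges, num_ops)
candidate_flags = torch.ones(num_edges, dtype=torch.bool)

# project op 2 on edge 0 (what project_op does)
proj_weights[0, 2] = 1.0
candidate_flags[0] = False

weights = F.softmax(alphas, dim=-1)
for eid in range(num_edges):
    if not candidate_flags[eid]:
        weights[eid] = proj_weights[eid]  # hard one-hot replaces the softmax row
print(weights[0])  # tensor([0., 0., 1.])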

214
sota/cnn/model_search_imagenet_proj.py Normal file
View File

@@ -0,0 +1,214 @@
import torch
from copy import deepcopy
import torch.nn as nn
from sota.cnn.operations import *
from sota.cnn.genotypes import Genotype
import sys
sys.path.insert(0, '../../')
from sota.cnn.model_search import Network
class ImageNetNetworkProj(Network):
def __init__(self, C, num_classes, layers, criterion, primitives, args,
steps=4, multiplier=4, stem_multiplier=3, drop_path_prob=0.0, nettype='imagenet'):
super(ImageNetNetworkProj, self).__init__(C, num_classes, layers, criterion, primitives, args,
steps=steps, multiplier=multiplier, stem_multiplier=stem_multiplier, drop_path_prob=drop_path_prob, nettype=nettype)
self._initialize_flags()
self._initialize_proj_weights()
self._initialize_topology_dicts()
#### proj flags
def _initialize_topology_dicts(self):
self.nid2eids = {0:[2,3,4], 1:[5,6,7,8], 2:[9,10,11,12,13]}
self.nid2selected_eids = {
'normal': {0:[],1:[],2:[]},
'reduce': {0:[],1:[],2:[]},
}
def _initialize_flags(self):
self.candidate_flags = {
'normal':torch.tensor(self.num_edges * [True], requires_grad=False, dtype=torch.bool).cuda(),
'reduce':torch.tensor(self.num_edges * [True], requires_grad=False, dtype=torch.bool).cuda(),
} # must be in this order
self.candidate_flags_edge = {
'normal': torch.tensor(3 * [True], requires_grad=False, dtype=torch.bool).cuda(),
'reduce': torch.tensor(3 * [True], requires_grad=False, dtype=torch.bool).cuda(),
}
def _initialize_proj_weights(self):
''' data structures used for proj '''
if isinstance(self.alphas_normal, list):
alphas_normal = torch.stack(self.alphas_normal, dim=0)
alphas_reduce = torch.stack(self.alphas_reduce, dim=0)
else:
alphas_normal = self.alphas_normal
alphas_reduce = self.alphas_reduce
self.proj_weights = { # for hard/soft assignment after project
'normal': torch.zeros_like(alphas_normal),
'reduce': torch.zeros_like(alphas_reduce),
}
#### proj function
def project_op(self, eid, opid, cell_type):
self.proj_weights[cell_type][eid][opid] = 1 ## hard by default
self.candidate_flags[cell_type][eid] = False
def project_edge(self, nid, eids, cell_type):
for eid in self.nid2eids[nid]:
if eid not in eids: # not top2
self.proj_weights[cell_type][eid].data.fill_(0)
self.nid2selected_eids[cell_type][nid] = deepcopy(eids)
self.candidate_flags_edge[cell_type][nid] = False
#### critical function
def get_projected_weights(self, cell_type):
''' used in forward and genotype '''
weights = self.get_softmax()[cell_type]
## proj op
for eid in range(self.num_edges):
if not self.candidate_flags[cell_type][eid]:
weights[eid].data.copy_(self.proj_weights[cell_type][eid])
## proj edge
for nid in self.nid2eids:
if not self.candidate_flags_edge[cell_type][nid]: ## projected node
for eid in self.nid2eids[nid]:
if eid not in self.nid2selected_eids[cell_type][nid]:
weights[eid].data.copy_(self.proj_weights[cell_type][eid])
return weights
def get_all_projected_weights(self, cell_type):
weights = self.get_softmax()[cell_type]
for eid in range(self.num_edges):
weights[eid].data.copy_(self.proj_weights[cell_type][eid])
for nid in self.nid2eids:
for eid in self.nid2eids[nid]:
weights[eid].data.copy_(self.proj_weights[cell_type][eid])
return weights
def forward(self, input, weights_dict=None, using_proj=False):
if using_proj:
weights_normal = self.get_all_projected_weights('normal')
weights_reduce = self.get_all_projected_weights('reduce')
else:
if weights_dict is None or 'normal' not in weights_dict:
weights_normal = self.get_projected_weights('normal')
else:
weights_normal = weights_dict['normal']
if weights_dict is None or 'reduce' not in weights_dict:
weights_reduce = self.get_projected_weights('reduce')
else:
weights_reduce = weights_dict['reduce']
s0 = self.stem0(input)
s1 = self.stem1(s0)
for i, cell in enumerate(self.cells):
if cell.reduction:
weights = weights_reduce
else:
weights = weights_normal
s0, s1 = s1, cell(s0, s1, weights, self.drop_path_prob)
out = self.global_pooling(s1)
logits = self.classifier(out.view(out.size(0),-1))
return logits
def reset_arch_parameters(self):
self._initialize_flags()
self._initialize_proj_weights()
self._initialize_topology_dicts()
#### utils
def printing(self, logging, option='all'):
weights_normal = self.get_projected_weights('normal')
weights_reduce = self.get_projected_weights('reduce')
if option in ['all', 'normal']:
logging.info('\n%s', weights_normal)
if option in ['all', 'reduce']:
logging.info('\n%s', weights_reduce)
def genotype(self):
def _parse(weights, normal=True):
PRIMITIVES = self.PRIMITIVES['primitives_normal' if normal else 'primitives_reduct']
gene = []
n = 2
start = 0
for i in range(self._steps):
end = start + n
W = weights[start:end].copy()
try:
edges = sorted(range(i + 2), key=lambda x: -max(W[x][k] for k in range(len(W[x])) if k != PRIMITIVES[x].index('none')))[:2]
except ValueError:
edges = sorted(range(i + 2), key=lambda x: -max(W[x][k] for k in range(len(W[x]))))[:2]
for j in edges:
k_best = None
for k in range(len(W[j])):
if 'none' in PRIMITIVES[j]:
if k != PRIMITIVES[j].index('none'):
if k_best is None or W[j][k] > W[j][k_best]:
k_best = k
else:
if k_best is None or W[j][k] > W[j][k_best]:
k_best = k
gene.append((PRIMITIVES[start+j][k_best], j))
start = end
n += 1
return gene
weights_normal = self.get_projected_weights('normal')
weights_reduce = self.get_projected_weights('reduce')
gene_normal = _parse(weights_normal.data.cpu().numpy(), True)
gene_reduce = _parse(weights_reduce.data.cpu().numpy(), False)
concat = range(2+self._steps-self._multiplier, self._steps+2)
genotype = Genotype(
normal=gene_normal, normal_concat=concat,
reduce=gene_reduce, reduce_concat=concat
)
return genotype
def get_state_dict(self, epoch, architect, scheduler):
model_state_dict = {
'epoch': epoch, ## no +1 because we are saving before projection / at the beginning of an epoch
'state_dict': self.state_dict(),
'alpha': self.arch_parameters(),
'optimizer': self.optimizer.state_dict(),
'arch_optimizer': architect.optimizer.state_dict(),
'scheduler': scheduler.state_dict(),
#### projection
'nid2eids': self.nid2eids,
'nid2selected_eids': self.nid2selected_eids,
'candidate_flags': self.candidate_flags,
'candidate_flags_edge': self.candidate_flags_edge,
'proj_weights': self.proj_weights,
}
return model_state_dict
def set_state_dict(self, architect, scheduler, checkpoint):
#### common
self.load_state_dict(checkpoint['state_dict'])
self.set_arch_parameters(checkpoint['alpha'])
self.optimizer.load_state_dict(checkpoint['optimizer'])
architect.optimizer.load_state_dict(checkpoint['arch_optimizer'])
scheduler.load_state_dict(checkpoint['scheduler'])
#### projection
self.nid2eids = checkpoint['nid2eids']
self.nid2selected_eids = checkpoint['nid2selected_eids']
self.candidate_flags = checkpoint['candidate_flags']
self.candidate_flags_edge = checkpoint['candidate_flags_edge']
self.proj_weights = checkpoint['proj_weights']

View File

@@ -0,0 +1,236 @@
import os
import sys
sys.path.insert(0, '../../')
import time
import glob
import random
import numpy as np
import torch
import shutil
import nasbench201.utils as ig_utils
import logging
import argparse
import torch.nn as nn
import torch.utils
import torchvision.datasets as dset
import torch.backends.cudnn as cudnn
import torchvision.transforms as transforms
import json
import copy
from sota.cnn.model_search import Network as DartsNetwork
from sota.cnn.model_search_darts_proj import DartsNetworkProj
from sota.cnn.model_search_imagenet_proj import ImageNetNetworkProj
# from optimizers.darts.architect import Architect as DartsArchitect
from nasbench201.architect_ig import Architect
from sota.cnn.spaces import spaces_dict
from foresight.pruners import *
from torch.utils.tensorboard import SummaryWriter
from sota.cnn.init_projection import pt_project
from hdf5 import H5Dataset
torch.set_printoptions(precision=4, sci_mode=False)
parser = argparse.ArgumentParser("sota")
parser.add_argument('--data', type=str, default='../../data', help='location of the data corpus')
parser.add_argument('--dataset', type=str, default='cifar10', help='choose dataset')
parser.add_argument('--batch_size', type=int, default=64, help='batch size')
parser.add_argument('--train_portion', type=float, default=0.5, help='portion of training data')
parser.add_argument('--cutout', action='store_true', default=False, help='use cutout')
parser.add_argument('--cutout_length', type=int, default=16, help='cutout length')
parser.add_argument('--cutout_prob', type=float, default=1.0, help='cutout probability')
parser.add_argument('--seed', type=int, default=666, help='random seed')
#model opt related config
parser.add_argument('--learning_rate', type=float, default=0.025, help='init learning rate')
parser.add_argument('--learning_rate_min', type=float, default=0.001, help='min learning rate')
parser.add_argument('--momentum', type=float, default=0.9, help='momentum')
parser.add_argument('--nesterov', action='store_true', default=True, help='use Nesterov momentum for SGD')
parser.add_argument('--weight_decay', type=float, default=3e-4, help='weight decay')
parser.add_argument('--grad_clip', type=float, default=5, help='gradient clipping')
#system config
parser.add_argument('--gpu', type=str, default='0', help='gpu device id')
parser.add_argument('--save', type=str, default='exp', help='experiment name')
parser.add_argument('--save_path', type=str, default='../../experiments/sota', help='directory to save experiments')
#search space config
parser.add_argument('--init_channels', type=int, default=16, help='num of init channels')
parser.add_argument('--layers', type=int, default=8, help='total number of layers')
parser.add_argument('--search_space', type=str, default='s5', help='search space to choose from')
parser.add_argument('--pool_size', type=int, default=10, help='number of models to propose')
## projection
parser.add_argument('--edge_decision', type=str, default='random', choices=['random','reverse', 'order', 'global_op_greedy', 'global_op_once', 'global_edge_greedy', 'global_edge_once', 'sample'], help='used for both proj_op and proj_edge')
parser.add_argument('--proj_crit_normal', type=str, default='meco', choices=['loss', 'acc', 'jacob', 'comb', 'synflow', 'snip', 'fisher', 'var', 'cor', 'norm', 'grad_norm', 'grasp', 'jacob_cov', 'meco', 'zico'])
parser.add_argument('--proj_crit_reduce', type=str, default='meco', choices=['loss', 'acc', 'jacob', 'comb', 'synflow', 'snip', 'fisher', 'var', 'cor', 'norm', 'grad_norm', 'grasp', 'jacob_cov', 'meco', 'zico'])
parser.add_argument('--proj_crit_edge', type=str, default='meco', choices=['loss', 'acc', 'jacob', 'comb', 'synflow', 'snip', 'fisher', 'var', 'cor', 'norm', 'grad_norm', 'grasp', 'jacob_cov', 'meco', 'zico'])
parser.add_argument('--proj_mode_edge', type=str, default='reg', choices=['reg'],
help='edge projection evaluation mode, reg: one edge at a time')
args = parser.parse_args()
#### args augment
expid = args.save
args.save = '{}/{}-search-{}-{}-{}-{}-{}'.format(args.save_path,
args.dataset, args.save, args.search_space, args.seed, args.pool_size, args.proj_crit_normal)
if not args.edge_decision == 'random':
args.save += '-' + args.edge_decision
scripts_to_save = glob.glob('*.py') + glob.glob('../../nasbench201/architect*.py') + glob.glob('../../optimizers/darts/architect.py')
if os.path.exists(args.save):
if input("WARNING: {} exists, override?[y/n]".format(args.save)) == 'y':
print('proceed to override saving directory')
shutil.rmtree(args.save)
else:
exit(0)
ig_utils.create_exp_dir(args.save, scripts_to_save=scripts_to_save)
#### logging
log_format = '%(asctime)s %(message)s'
logging.basicConfig(stream=sys.stdout, level=logging.INFO,
format=log_format, datefmt='%m/%d %I:%M:%S %p')
log_file = 'log.txt'
log_path = os.path.join(args.save, log_file)
logging.info('======> log filename: %s', log_file)
if os.path.exists(log_path):
if input("WARNING: {} exists, override?[y/n]".format(log_file)) == 'y':
print('proceed to override log file')
else:
exit(0)
fh = logging.FileHandler(log_path, mode='w')
fh.setFormatter(logging.Formatter(log_format))
logging.getLogger().addHandler(fh)
writer = SummaryWriter(args.save + '/runs')
if args.dataset == 'cifar100':
n_classes = 100
elif args.dataset == 'imagenet':
n_classes = 1000
else:
n_classes = 10
def main():
torch.set_num_threads(3)
if not torch.cuda.is_available():
logging.info('no gpu device available')
sys.exit(1)
np.random.seed(args.seed)
gpu = ig_utils.pick_gpu_lowest_memory() if args.gpu == 'auto' else int(args.gpu)
torch.cuda.set_device(gpu)
cudnn.benchmark = True
torch.manual_seed(args.seed)
cudnn.enabled = True
torch.cuda.manual_seed(args.seed)
logging.info('gpu device = %d' % gpu)
logging.info("args = %s", args)
#### model
criterion = nn.CrossEntropyLoss()
criterion = criterion.cuda()
## darts
if args.dataset == 'imagenet':
model = ImageNetNetworkProj(args.init_channels, n_classes, args.layers, criterion, spaces_dict[args.search_space], args)
else:
model = DartsNetworkProj(args.init_channels, n_classes, args.layers, criterion, spaces_dict[args.search_space], args)
model = model.cuda()
logging.info("param size = %fMB", ig_utils.count_parameters_in_MB(model))
#### data
if args.dataset == 'imagenet':
normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
train_transform = transforms.Compose([
transforms.RandomResizedCrop(224),
transforms.RandomHorizontalFlip(),
transforms.ColorJitter(
brightness=0.4,
contrast=0.4,
saturation=0.4,
hue=0.2),
transforms.ToTensor(),
normalize,
])
#for test
#from nasbench201.DownsampledImageNet import ImageNet16
# train_data = dset.CIFAR10(root=args.data, train=True, download=True, transform=train_transform)
# n_classes = 10
train_data = H5Dataset(os.path.join(args.data, 'imagenet-train-256.h5'), transform=train_transform)
#valid_data = H5Dataset(os.path.join(args.data, 'imagenet-val-256.h5'), transform=test_transform)
train_queue = torch.utils.data.DataLoader(
train_data, batch_size=args.batch_size, shuffle=True, pin_memory=True, num_workers=4)
else:
if args.dataset == 'cifar10':
train_transform, valid_transform = ig_utils._data_transforms_cifar10(args)
train_data = dset.CIFAR10(root=args.data, train=True, download=True, transform=train_transform)
valid_data = dset.CIFAR10(root=args.data, train=False, download=True, transform=valid_transform)
elif args.dataset == 'cifar100':
train_transform, valid_transform = ig_utils._data_transforms_cifar100(args)
train_data = dset.CIFAR100(root=args.data, train=True, download=True, transform=train_transform)
valid_data = dset.CIFAR100(root=args.data, train=False, download=True, transform=valid_transform)
elif args.dataset == 'svhn':
train_transform, valid_transform = ig_utils._data_transforms_svhn(args)
train_data = dset.SVHN(root=args.data, split='train', download=True, transform=train_transform)
valid_data = dset.SVHN(root=args.data, split='test', download=True, transform=valid_transform)
num_train = len(train_data)
indices = list(range(num_train))
split = int(np.floor(args.train_portion * num_train))
train_queue = torch.utils.data.DataLoader(
train_data, batch_size=args.batch_size,
sampler=torch.utils.data.sampler.SubsetRandomSampler(indices[:split]),
pin_memory=True)
valid_queue = torch.utils.data.DataLoader(
train_data, batch_size=args.batch_size,
sampler=torch.utils.data.sampler.SubsetRandomSampler(indices[split:num_train]),
pin_memory=True)
# for x, y in train_queue:
# from torchvision import transforms
# unloader = transforms.ToPILImage()
# image = x.cpu().clone() # clone the tensor
# image = image.squeeze(0) # remove the fake batch dimension
# image = unloader(image)
# image.save('example.jpg')
# print(x.size())
# exit()
#### projection
networks_pool={}
networks_pool['search_space'] = args.search_space
networks_pool['dataset'] = args.dataset
networks_pool['networks'] = []
for i in range(args.pool_size):
network_info={}
logging.info('SEARCHING MODEL {}'.format(i + 1))
pt_project(train_queue, model, args)
## logging
num_params = ig_utils.count_parameters_in_Compact(model)
genotype = model.genotype()
json_data = {}
json_data['normal'] = genotype.normal
json_data['normal_concat'] = [x for x in genotype.normal_concat]
json_data['reduce'] = genotype.reduce
json_data['reduce_concat'] = [x for x in genotype.reduce_concat]
json_string = json.dumps(json_data)
logging.info(json_string)
network_info['id'] = str(i)
network_info['genotype'] = json_string
networks_pool['networks'].append(network_info)
model.reset_arch_parameters()
with open(os.path.join(args.save,'networks_pool.json'), 'w') as save_file:
json.dump(networks_pool, save_file)
if __name__ == '__main__':
main()
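The loop above serializes each searched cell into networks_pool.json under args.save, one JSON-encoded genotype string per network. As a rough sketch of how a downstream script could read that file back into Genotype namedtuples (train.py's load_network_pool reads a file with the same structure; the directory path below is only an example):
import json
import os
from collections import namedtuple

Genotype = namedtuple('Genotype', 'normal normal_concat reduce reduce_concat')

def read_networks_pool(save_dir):
    # save_dir: the experiment directory the search script wrote to (hypothetical path)
    with open(os.path.join(save_dir, 'networks_pool.json')) as f:
        pool = json.load(f)
    genotypes = []
    for net in pool['networks']:
        cfg = json.loads(net['genotype'])  # each entry is a JSON-encoded genotype string
        genotypes.append(Genotype(normal=cfg['normal'],
                                  normal_concat=cfg['normal_concat'],
                                  reduce=cfg['reduce'],
                                  reduce_concat=cfg['reduce_concat']))
    return genotypes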

181
sota/cnn/operations.py Normal file
View File

@@ -0,0 +1,181 @@
import torch
import torch.nn as nn
from torch.autograd import Variable
OPS = {
'noise': lambda C, stride, affine: NoiseOp(stride, 0., 1.),
'none': lambda C, stride, affine: Zero(stride),
'avg_pool_3x3': lambda C, stride, affine: nn.AvgPool2d(3, stride=stride, padding=1, count_include_pad=False),
'max_pool_3x3' : lambda C, stride, affine: nn.MaxPool2d(3, stride=stride, padding=1),
'skip_connect': lambda C, stride, affine: Identity() if stride == 1 else FactorizedReduce(C, C, affine=affine),
'sep_conv_3x3': lambda C, stride, affine: SepConv(C, C, 3, stride, 1, affine=affine),
'sep_conv_5x5': lambda C, stride, affine: SepConv(C, C, 5, stride, 2, affine=affine),
'sep_conv_7x7': lambda C, stride, affine: SepConv(C, C, 7, stride, 3, affine=affine),
'dil_conv_3x3': lambda C, stride, affine: DilConv(C, C, 3, stride, 2, 2, affine=affine),
'dil_conv_5x5': lambda C, stride, affine: DilConv(C, C, 5, stride, 4, 2, affine=affine),
'conv_7x1_1x7': lambda C, stride, affine: nn.Sequential(
nn.ReLU(inplace=False),
nn.Conv2d(C, C, (1, 7), stride=(1, stride), padding=(0, 3), bias=False),
nn.Conv2d(C, C, (7, 1), stride=(stride, 1), padding=(3, 0), bias=False),
nn.BatchNorm2d(C, affine=affine)
),
'sep_conv_3x3_skip': lambda C, stride, affine: SepConvSkip(C, C, 3, stride, 1, affine=affine),
'sep_conv_5x5_skip': lambda C, stride, affine: SepConvSkip(C, C, 5, stride, 2, affine=affine),
'dil_conv_3x3_skip': lambda C, stride, affine: DilConvSkip(C, C, 3, stride, 2, 2, affine=affine),
'dil_conv_5x5_skip': lambda C, stride, affine: DilConvSkip(C, C, 5, stride, 4, 2, affine=affine),
}
class NoiseOp(nn.Module):
def __init__(self, stride, mean, std):
super(NoiseOp, self).__init__()
self.stride = stride
self.mean = mean
self.std = std
def forward(self, x, block_input=False):
if block_input:
x = x*0
if self.stride != 1:
x_new = x[:,:,::self.stride,::self.stride]
else:
x_new = x
noise = Variable(x_new.data.new(x_new.size()).normal_(self.mean, self.std))
return noise
class ReLUConvBN(nn.Module):
def __init__(self, C_in, C_out, kernel_size, stride, padding, affine=True):
super(ReLUConvBN, self).__init__()
self.op = nn.Sequential(
nn.ReLU(inplace=False),
nn.Conv2d(C_in, C_out, kernel_size, stride=stride, padding=padding, bias=False),
nn.BatchNorm2d(C_out, affine=affine)
)
def forward(self, x, block_input=False):
if block_input:
x = x*0
return self.op(x)
class DilConv(nn.Module):
def __init__(self, C_in, C_out, kernel_size, stride, padding, dilation, affine=True):
super(DilConv, self).__init__()
self.op = nn.Sequential(
nn.ReLU(inplace=False),
nn.Conv2d(C_in, C_in, kernel_size=kernel_size, stride=stride, padding=padding, dilation=dilation,
groups=C_in, bias=False),
nn.Conv2d(C_in, C_out, kernel_size=1, padding=0, bias=False),
nn.BatchNorm2d(C_out, affine=affine),
)
def forward(self, x, block_input=False):
if block_input:
x = x*0
return self.op(x)
class SepConv(nn.Module):
def __init__(self, C_in, C_out, kernel_size, stride, padding, affine=True):
super(SepConv, self).__init__()
self.op = nn.Sequential(
nn.ReLU(inplace=False),
nn.Conv2d(C_in, C_in, kernel_size=kernel_size, stride=stride, padding=padding, groups=C_in, bias=False),
nn.Conv2d(C_in, C_in, kernel_size=1, padding=0, bias=False),
nn.BatchNorm2d(C_in, affine=affine),
nn.ReLU(inplace=False),
nn.Conv2d(C_in, C_in, kernel_size=kernel_size, stride=1, padding=padding, groups=C_in, bias=False),
nn.Conv2d(C_in, C_out, kernel_size=1, padding=0, bias=False),
nn.BatchNorm2d(C_out, affine=affine),
)
def forward(self, x, block_input=False):
if block_input:
x = x*0
return self.op(x)
class Identity(nn.Module):
def __init__(self):
super(Identity, self).__init__()
def forward(self, x, block_input=False):
if block_input:
x = x*0
return x
class Zero(nn.Module):
def __init__(self, stride):
super(Zero, self).__init__()
self.stride = stride
def forward(self, x, block_input=False):
if block_input:
x = x*0
if self.stride == 1:
return x.mul(0.)
return x[:, :, ::self.stride, ::self.stride].mul(0.)
class FactorizedReduce(nn.Module):
def __init__(self, C_in, C_out, affine=True):
super(FactorizedReduce, self).__init__()
assert C_out % 2 == 0
self.relu = nn.ReLU(inplace=False)
self.conv_1 = nn.Conv2d(C_in, C_out // 2, 1, stride=2, padding=0, bias=False)
self.conv_2 = nn.Conv2d(C_in, C_out // 2, 1, stride=2, padding=0, bias=False)
self.bn = nn.BatchNorm2d(C_out, affine=affine)
def forward(self, x, block_input=False):
if block_input:
x = x*0
x = self.relu(x)
out = torch.cat([self.conv_1(x), self.conv_2(x[:, :, 1:, 1:])], dim=1)
out = self.bn(out)
return out
#### operations with skip
class DilConvSkip(nn.Module):
def __init__(self, C_in, C_out, kernel_size, stride, padding, dilation, affine=True):
super(DilConvSkip, self).__init__()
self.op = nn.Sequential(
nn.ReLU(inplace=False),
nn.Conv2d(C_in, C_in, kernel_size=kernel_size, stride=stride, padding=padding, dilation=dilation,
groups=C_in, bias=False),
nn.Conv2d(C_in, C_out, kernel_size=1, padding=0, bias=False),
nn.BatchNorm2d(C_out, affine=affine),
)
def forward(self, x, block_input=False):
if block_input:
x = x*0
return self.op(x) + x
class SepConvSkip(nn.Module):
def __init__(self, C_in, C_out, kernel_size, stride, padding, affine=True):
super(SepConvSkip, self).__init__()
self.op = nn.Sequential(
nn.ReLU(inplace=False),
nn.Conv2d(C_in, C_in, kernel_size=kernel_size, stride=stride, padding=padding, groups=C_in, bias=False),
nn.Conv2d(C_in, C_in, kernel_size=1, padding=0, bias=False),
nn.BatchNorm2d(C_in, affine=affine),
nn.ReLU(inplace=False),
nn.Conv2d(C_in, C_in, kernel_size=kernel_size, stride=1, padding=padding, groups=C_in, bias=False),
nn.Conv2d(C_in, C_out, kernel_size=1, padding=0, bias=False),
nn.BatchNorm2d(C_out, affine=affine),
)
def forward(self, x, block_input=False):
if block_input:
x = x*0
return self.op(x) + x
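For orientation, each entry in the OPS table above is a factory keyed by primitive name that returns an nn.Module for a given channel count and stride; all ops preserve the channel count, and the extra block_input flag zeroes the input before the op runs. A minimal usage sketch, assuming the repository root is on sys.path so sota.cnn.operations is importable:
import torch
from sota.cnn.operations import OPS

C, stride = 16, 1
op = OPS['sep_conv_3x3'](C, stride, affine=False)  # stacked depthwise-separable 3x3 conv block
x = torch.randn(2, C, 32, 32)                      # toy CIFAR-sized feature map
print(op(x).shape)                                 # torch.Size([2, 16, 32, 32]); stride=1 keeps the spatial size
print(op(x, block_input=True).abs().sum())         # tensor(0.) -- zero input gives zero output (no biases, affine=False)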

248
sota/cnn/projection.py Normal file
View File

@@ -0,0 +1,248 @@
import os
import sys
sys.path.insert(0, '../../')
import numpy as np
import torch
import nasbench201.utils as ig_utils
import logging
import torch.utils
from copy import deepcopy
torch.set_printoptions(precision=4, sci_mode=False)
def project_op(model, proj_queue, args, infer, cell_type, selected_eid=None):
''' operation '''
#### macros
num_edges, num_ops = model.num_edges, model.num_ops
candidate_flags = model.candidate_flags[cell_type]
proj_crit = args.proj_crit[cell_type]
#### select an edge
if selected_eid is None:
remain_eids = torch.nonzero(candidate_flags).cpu().numpy().T[0]
if args.edge_decision == "random":
selected_eid = np.random.choice(remain_eids, size=1)[0]
logging.info('selected edge: %d %s', selected_eid, cell_type)
#### select the best operation
if proj_crit == 'loss':
crit_idx = 1
compare = lambda x, y: x > y
elif proj_crit == 'acc':
crit_idx = 0
compare = lambda x, y: x < y
best_opid = 0
crit_extrema = None
for opid in range(num_ops):
## projection
weights = model.get_projected_weights(cell_type)
proj_mask = torch.ones_like(weights[selected_eid])
proj_mask[opid] = 0
weights[selected_eid] = weights[selected_eid] * proj_mask
## proj evaluation
weights_dict = {cell_type:weights}
valid_stats = infer(proj_queue, model, log=False, _eval=False, weights_dict=weights_dict)
crit = valid_stats[crit_idx]
if crit_extrema is None or compare(crit, crit_extrema):
crit_extrema = crit
best_opid = opid
logging.info('valid_acc %f', valid_stats[0])
logging.info('valid_loss %f', valid_stats[1])
#### project
logging.info('best opid: %d', best_opid)
return selected_eid, best_opid
def project_edge(model, proj_queue, args, infer, cell_type):
''' topology '''
#### macros
candidate_flags = model.candidate_flags_edge[cell_type]
proj_crit = args.proj_crit[cell_type]
#### select an edge
remain_nids = torch.nonzero(candidate_flags).cpu().numpy().T[0]
if args.edge_decision == "random":
selected_nid = np.random.choice(remain_nids, size=1)[0]
logging.info('selected node: %d %s', selected_nid, cell_type)
#### select top2 edges
if proj_crit == 'loss':
crit_idx = 1
compare = lambda x, y: x > y
elif proj_crit == 'acc':
crit_idx = 0
compare = lambda x, y: x < y
eids = deepcopy(model.nid2eids[selected_nid])
while len(eids) > 2:
eid_todel = None
crit_extrema = None
for eid in eids:
weights = model.get_projected_weights(cell_type)
weights[eid].data.fill_(0)
weights_dict = {cell_type:weights}
## proj evaluation
valid_stats = infer(proj_queue, model, log=False, _eval=False, weights_dict=weights_dict)
crit = valid_stats[crit_idx]
if crit_extrema is None or not compare(crit, crit_extrema): # find out bad edges
crit_extrema = crit
eid_todel = eid
logging.info('valid_acc %f', valid_stats[0])
logging.info('valid_loss %f', valid_stats[1])
eids.remove(eid_todel)
#### project
logging.info('top2 edges: (%d, %d)', eids[0], eids[1])
return selected_nid, eids
def pt_project(train_queue, valid_queue, model, architect, optimizer,
epoch, args, infer, perturb_alpha, epsilon_alpha):
model.train()
model.printing(logging)
train_acc, train_obj = infer(train_queue, model, log=False)
logging.info('train_acc %f', train_acc)
logging.info('train_loss %f', train_obj)
valid_acc, valid_obj = infer(valid_queue, model, log=False)
logging.info('valid_acc %f', valid_acc)
logging.info('valid_loss %f', valid_obj)
objs = ig_utils.AvgrageMeter()
top1 = ig_utils.AvgrageMeter()
top5 = ig_utils.AvgrageMeter()
#### macros
num_projs = model.num_edges + len(model.nid2eids.keys()) - 1 ## -1 because we project at both epoch 0 and -1
tune_epochs = args.proj_intv * num_projs + 1
proj_intv = args.proj_intv
args.proj_crit = {'normal':args.proj_crit_normal, 'reduce':args.proj_crit_reduce}
proj_queue = valid_queue
#### reset optimizer
model.reset_optimizer(args.learning_rate / 10, args.momentum, args.weight_decay)
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
model.optimizer, float(tune_epochs), eta_min=args.learning_rate_min)
#### load proj checkpoints
start_epoch = 0
if args.dev_resume_epoch >= 0:
filename = os.path.join(args.dev_resume_checkpoint_dir, 'checkpoint_{}.pth.tar'.format(args.dev_resume_epoch))
if os.path.isfile(filename):
logging.info("=> loading projection checkpoint '{}'".format(filename))
checkpoint = torch.load(filename, map_location='cpu')
start_epoch = checkpoint['epoch']
model.set_state_dict(architect, scheduler, checkpoint)
model.set_arch_parameters(checkpoint['alpha'])
scheduler.load_state_dict(checkpoint['scheduler'])
model.optimizer.load_state_dict(checkpoint['optimizer']) # optimizer
else:
logging.info("=> no checkpoint found at '{}'".format(filename))
exit(0)
#### projecting and tuning
for epoch in range(start_epoch, tune_epochs):
logging.info('epoch %d', epoch)
## project
if epoch % proj_intv == 0 or epoch == tune_epochs - 1:
## saving every projection
save_state_dict = model.get_state_dict(epoch, architect, scheduler)
ig_utils.save_checkpoint(save_state_dict, False, args.dev_save_checkpoint_dir, per_epoch=True)
if epoch < proj_intv * model.num_edges:
logging.info('project op')
selected_eid_normal, best_opid_normal = project_op(model, proj_queue, args, infer, cell_type='normal')
model.project_op(selected_eid_normal, best_opid_normal, cell_type='normal')
selected_eid_reduce, best_opid_reduce = project_op(model, proj_queue, args, infer, cell_type='reduce')
model.project_op(selected_eid_reduce, best_opid_reduce, cell_type='reduce')
model.printing(logging)
else:
logging.info('project edge')
selected_nid_normal, eids_normal = project_edge(model, proj_queue, args, infer, cell_type='normal')
model.project_edge(selected_nid_normal, eids_normal, cell_type='normal')
selected_nid_reduce, eids_reduce = project_edge(model, proj_queue, args, infer, cell_type='reduce')
model.project_edge(selected_nid_reduce, eids_reduce, cell_type='reduce')
model.printing(logging)
## tune
for step, (input, target) in enumerate(train_queue):
model.train()
n = input.size(0)
## fetch data
input = input.cuda()
target = target.cuda(non_blocking=True)
input_search, target_search = next(iter(valid_queue))
input_search = input_search.cuda()
target_search = target_search.cuda(non_blocking=True)
## train alpha
optimizer.zero_grad(); architect.optimizer.zero_grad()
architect.step(input, target, input_search, target_search,
return_logits=True)
## sdarts
if perturb_alpha:
# transform arch_parameters to prob (for perturbation)
model.softmax_arch_parameters()
optimizer.zero_grad(); architect.optimizer.zero_grad()
perturb_alpha(model, input, target, epsilon_alpha)
## train weight
optimizer.zero_grad(); architect.optimizer.zero_grad()
logits, loss = model.step(input, target, args)
## sdarts
if perturb_alpha:
## restore alpha to unperturbed arch_parameters
model.restore_arch_parameters()
## logging
prec1, prec5 = ig_utils.accuracy(logits, target, topk=(1, 5))
objs.update(loss.data, n)
top1.update(prec1.data, n)
top5.update(prec5.data, n)
if step % args.report_freq == 0:
logging.info('train %03d %e %f %f', step, objs.avg, top1.avg, top5.avg)
if args.fast:
break
## one epoch end
model.printing(logging)
train_acc, train_obj = infer(train_queue, model, log=False)
logging.info('train_acc %f', train_acc)
logging.info('train_loss %f', train_obj)
valid_acc, valid_obj = infer(valid_queue, model, log=False)
logging.info('valid_acc %f', valid_acc)
logging.info('valid_loss %f', valid_obj)
logging.info('projection finished')
model.printing(logging)
num_params = ig_utils.count_parameters_in_Compact(model)
genotype = model.genotype()
logging.info('param size = %f', num_params)
logging.info('genotype = %s', genotype)
return
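The schedule inside pt_project is worth spelling out: one operation is projected per cell type every proj_intv epochs until all edges are discretized, after which each node's inputs are pruned down to its top-2 edges. A self-contained sketch of that schedule for the standard DARTS cell (14 edges, 4 intermediate nodes; proj_intv = 10 is just an example value):
num_edges = 14    # edges in a DARTS cell
num_nodes = 4     # intermediate nodes; each ends up keeping its top-2 input edges
proj_intv = 10    # tuning epochs between consecutive projection steps (example value)

# mirrors: num_projs = model.num_edges + len(model.nid2eids) - 1
num_projs = num_edges + num_nodes - 1
tune_epochs = proj_intv * num_projs + 1   # 171 epochs here; projections land on epochs 0, 10, ..., 170

for epoch in range(tune_epochs):
    if epoch % proj_intv == 0 or epoch == tune_epochs - 1:
        phase = 'op' if epoch < proj_intv * num_edges else 'edge'
        print(f'epoch {epoch:3d}: project one {phase} decision per cell type')
In this example that yields 14 operation steps and 4 edge steps per cell type, with the last step coinciding with the final tuning epoch, which is why num_projs subtracts 1.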

103
sota/cnn/spaces.py Normal file
View File

@@ -0,0 +1,103 @@
from collections import OrderedDict
primitives_1 = OrderedDict([('primitives_normal', [['skip_connect',
'dil_conv_3x3'],
['skip_connect',
'dil_conv_5x5'],
['skip_connect',
'dil_conv_5x5'],
['skip_connect',
'sep_conv_3x3'],
['skip_connect',
'dil_conv_3x3'],
['max_pool_3x3',
'skip_connect'],
['skip_connect',
'sep_conv_3x3'],
['skip_connect',
'sep_conv_3x3'],
['skip_connect',
'dil_conv_3x3'],
['skip_connect',
'sep_conv_3x3'],
['max_pool_3x3',
'skip_connect'],
['skip_connect',
'dil_conv_3x3'],
['dil_conv_3x3',
'dil_conv_5x5'],
['dil_conv_3x3',
'dil_conv_5x5']]),
('primitives_reduct', [['max_pool_3x3',
'avg_pool_3x3'],
['max_pool_3x3',
'dil_conv_3x3'],
['max_pool_3x3',
'avg_pool_3x3'],
['max_pool_3x3',
'avg_pool_3x3'],
['skip_connect',
'dil_conv_5x5'],
['max_pool_3x3',
'avg_pool_3x3'],
['max_pool_3x3',
'sep_conv_3x3'],
['skip_connect',
'dil_conv_3x3'],
['skip_connect',
'dil_conv_5x5'],
['max_pool_3x3',
'avg_pool_3x3'],
['max_pool_3x3',
'avg_pool_3x3'],
['skip_connect',
'dil_conv_5x5'],
['skip_connect',
'dil_conv_5x5'],
['skip_connect',
'dil_conv_5x5']])])
primitives_2 = OrderedDict([('primitives_normal', 14 * [['skip_connect',
'sep_conv_3x3']]),
('primitives_reduct', 14 * [['skip_connect',
'sep_conv_3x3']])])
primitives_3 = OrderedDict([('primitives_normal', 14 * [['none',
'skip_connect',
'sep_conv_3x3']]),
('primitives_reduct', 14 * [['none',
'skip_connect',
'sep_conv_3x3']])])
primitives_4 = OrderedDict([('primitives_normal', 14 * [['noise',
'sep_conv_3x3']]),
('primitives_reduct', 14 * [['noise',
'sep_conv_3x3']])])
PRIMITIVES = [
#'none', #0
'max_pool_3x3', # 0
'avg_pool_3x3', # 1
'skip_connect', # 2
'sep_conv_3x3', # 3
'sep_conv_5x5', # 4
'dil_conv_3x3', # 5
'dil_conv_5x5' # 6
]
primitives_5 = OrderedDict([('primitives_normal', 14 * [PRIMITIVES]),
('primitives_reduct', 14 * [PRIMITIVES])])
primitives_6 = OrderedDict([('primitives_normal', 14 * [['sep_conv_5x5']]),
('primitives_reduct', 14 * [['sep_conv_5x5']])])
spaces_dict = {
's1': primitives_1,
's2': primitives_2,
's3': primitives_3,
's4': primitives_4,
's5': primitives_5, # DARTS Space
's6': primitives_6,
}
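All spaces above follow the DARTS convention of 14 edges per cell, each carrying its own candidate-op list; s5 (the default in the search script) places the full 7-op PRIMITIVES list on every edge. A quick inspection sketch, again assuming the repository root is on sys.path:
from sota.cnn.spaces import spaces_dict

space = spaces_dict['s5']                 # DARTS space, the default for --search_space
normal = space['primitives_normal']
print(len(normal), 'edges,', len(normal[0]), 'candidate ops per edge')   # 14 edges, 7 ops
print(normal[0])                          # ['max_pool_3x3', 'avg_pool_3x3', ..., 'dil_conv_5x5']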

309
sota/cnn/train.py Normal file
View File

@@ -0,0 +1,309 @@
import os
import random
import sys
sys.path.insert(0, '../../')
import glob
import numpy as np
import torch
import nasbench201.utils as ig_utils
import logging
import argparse
import shutil
import torch.nn as nn
import torch.utils
import torchvision.datasets as dset
import torch.backends.cudnn as cudnn
import json
from sota.cnn.model import Network
import sota.cnn.genotypes as genotypes  # needed by the eval("genotypes.%s" % args.arch) branch below
from torch.utils.tensorboard import SummaryWriter
from collections import namedtuple
parser = argparse.ArgumentParser("cifar")
parser.add_argument('--data', type=str, default='../../data',
help='location of the data corpus')
parser.add_argument('--dataset', type=str, default='cifar10', help='choose dataset')
parser.add_argument('--batch_size', type=int, default=96, help='batch size')
parser.add_argument('--learning_rate', type=float, default=0.025, help='init learning rate')
parser.add_argument('--momentum', type=float, default=0.9, help='momentum')
parser.add_argument('--weight_decay', type=float, default=3e-4, help='weight decay')
parser.add_argument('--report_freq', type=float, default=50, help='report frequency')
parser.add_argument('--gpu', type=str, default='auto', help='gpu device id')
parser.add_argument('--epochs', type=int, default=600, help='num of training epochs')
parser.add_argument('--init_channels', type=int, default=36, help='num of init channels')
parser.add_argument('--layers', type=int, default=20, help='total number of layers')
parser.add_argument('--model_path', type=str, default='saved_models', help='path to save the model')
parser.add_argument('--auxiliary', action='store_true', default=True, help='use auxiliary tower')
parser.add_argument('--auxiliary_weight', type=float, default=0.4, help='weight for auxiliary loss')
parser.add_argument('--cutout', action='store_true', default=True, help='use cutout')
parser.add_argument('--cutout_length', type=int, default=16, help='cutout length')
parser.add_argument('--cutout_prob', type=float, default=1.0, help='cutout probability')
parser.add_argument('--drop_path_prob', type=float, default=0.2, help='drop path probability')
parser.add_argument('--save', type=str, default='exp', help='experiment name')
parser.add_argument('--seed', type=int, default=0, help='random seed')
parser.add_argument('--arch', type=str, default='c100_s4_pgd', help='which architecture to use')
parser.add_argument('--grad_clip', type=float, default=5, help='gradient clipping')
#### common
parser.add_argument('--resume_epoch', type=int, default=0, help="load ckpt, start training at resume_epoch")
parser.add_argument('--ckpt_interval', type=int, default=50, help="interval (epoch) for saving checkpoints")
parser.add_argument('--resume_expid', type=str, default='', help="full expid to resume from, name == ckpt folder name")
parser.add_argument('--fast', action='store_true', default=False, help="fast mode for debugging")
parser.add_argument('--queue', action='store_true', default=False, help="queueing for gpu")
parser.add_argument('--from_dir', action='store_true', default=True, help="load architecture from a directory")
args = parser.parse_args()
def load_network_pool(ckpt_path):
with open(os.path.join(ckpt_path, 'best_networks.json'), 'r') as save_file:
networks_pool = json.load(save_file)
return networks_pool['networks']
Genotype = namedtuple('Genotype', 'normal normal_concat reduce reduce_concat')
#### args augment
expid = args.save
print(args.from_dir)
if args.from_dir:
id_name = os.path.split(args.arch)[1]
# print('aaaaaaa', args.arch)
args.arch = load_network_pool(args.arch)
args.save = '../../experiments/sota/{}/eval/{}-{}-{}'.format(
args.dataset, args.save, id_name, args.seed)
else:
args.save = '../../experiments/sota/{}/eval/{}-{}-{}'.format(
args.dataset, args.save, args.arch, args.seed)
if args.cutout:
args.save += '-cutout-' + str(args.cutout_length) + '-' + str(args.cutout_prob)
if args.auxiliary:
args.save += '-auxiliary-' + str(args.auxiliary_weight)
#### logging
if args.resume_epoch > 0: # do not delete dir if resume:
args.save = '../../experiments/sota/{}/{}'.format(args.dataset, args.resume_expid)
assert os.path.exists(args.save), 'resume but {} does not exist!'.format(args.save)
else:
scripts_to_save = glob.glob('*.py')
if os.path.exists(args.save):
if input("WARNING: {} exists, override?[y/n]".format(args.save)) == 'y':
print('proceed to override saving directory')
shutil.rmtree(args.save)
else:
exit(0)
ig_utils.create_exp_dir(args.save, scripts_to_save=scripts_to_save)
log_format = '%(asctime)s %(message)s'
logging.basicConfig(stream=sys.stdout, level=logging.INFO,
format=log_format, datefmt='%m/%d %I:%M:%S %p')
log_file = 'log_resume_{}.txt'.format(args.resume_epoch) if args.resume_epoch > 0 else 'log.txt'
fh = logging.FileHandler(os.path.join(args.save, log_file), mode='w')
fh.setFormatter(logging.Formatter(log_format))
logging.getLogger().addHandler(fh)
writer = SummaryWriter(args.save + '/runs')
if args.dataset == 'cifar100':
n_classes = 100
else:
n_classes = 10
def seed_torch(seed=0):
random.seed(seed)
np.random.seed(seed)
os.environ['PYTHONHASHSEED'] = str(seed)
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)
cudnn.deterministic = True
cudnn.benchmark = False
def main():
torch.set_num_threads(3)
if not torch.cuda.is_available():
logging.info('no gpu device available')
sys.exit(1)
#### gpu queueing
if args.queue:
ig_utils.queue_gpu()
gpu = ig_utils.pick_gpu_lowest_memory() if args.gpu == 'auto' else int(args.gpu)
torch.cuda.set_device(gpu)
cudnn.enabled = True
seed_torch(args.seed)
logging.info('gpu device = %d' % gpu)
logging.info("args = %s", args)
if args.from_dir:
genotype_config = json.loads(args.arch)
genotype = Genotype(normal=genotype_config['normal'], normal_concat=genotype_config['normal_concat'],
reduce=genotype_config['reduce'], reduce_concat=genotype_config['reduce_concat'])
else:
genotype = eval("genotypes.%s" % args.arch)
model = Network(args.init_channels, n_classes, args.layers, args.auxiliary, genotype)
model = model.cuda()
logging.info("param size = %fMB", ig_utils.count_parameters_in_MB(model))
criterion = nn.CrossEntropyLoss()
criterion = criterion.cuda()
optimizer = torch.optim.SGD(
model.parameters(),
args.learning_rate,
momentum=args.momentum,
weight_decay=args.weight_decay
)
if args.dataset == 'cifar10':
train_transform, valid_transform = ig_utils._data_transforms_cifar10(args)
train_data = dset.CIFAR10(root=args.data, train=True, download=True, transform=train_transform)
valid_data = dset.CIFAR10(root=args.data, train=False, download=True, transform=valid_transform)
elif args.dataset == 'cifar100':
train_transform, valid_transform = ig_utils._data_transforms_cifar100(args)
train_data = dset.CIFAR100(root=args.data, train=True, download=True, transform=train_transform)
valid_data = dset.CIFAR100(root=args.data, train=False, download=True, transform=valid_transform)
elif args.dataset == 'svhn':
train_transform, valid_transform = ig_utils._data_transforms_svhn(args)
train_data = dset.SVHN(root=args.data, split='train', download=True, transform=train_transform)
valid_data = dset.SVHN(root=args.data, split='test', download=True, transform=valid_transform)
train_queue = torch.utils.data.DataLoader(
train_data, batch_size=args.batch_size, shuffle=True, pin_memory=True, num_workers=0)
valid_queue = torch.utils.data.DataLoader(
valid_data, batch_size=args.batch_size, shuffle=False, pin_memory=True, num_workers=0)
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
optimizer, float(args.epochs),
# eta_min=1e-4
)
#### resume
start_epoch = 0
if args.resume_epoch > 0:
logging.info('loading checkpoint from {}'.format(expid))
filename = os.path.join(args.save, 'checkpoint_{}.pth.tar'.format(args.resume_epoch))
if os.path.isfile(filename):
print("=> loading checkpoint '{}'".format(filename))
checkpoint = torch.load(filename, map_location='cpu')
resume_epoch = checkpoint['epoch'] # epoch
model.load_state_dict(checkpoint['state_dict']) # model
scheduler.load_state_dict(checkpoint['scheduler'])
optimizer.load_state_dict(checkpoint['optimizer']) # optimizer
start_epoch = args.resume_epoch
print("=> loaded checkpoint '{}' (epoch {})".format(filename, resume_epoch))
else:
print("=> no checkpoint found at '{}'".format(filename))
#### main training
best_valid_acc = 0
for epoch in range(start_epoch, args.epochs):
lr = scheduler.get_lr()[0]
if args.cutout:
train_transform.transforms[-1].cutout_prob = args.cutout_prob
logging.info('epoch %d lr %e cutout_prob %e', epoch, lr,
train_transform.transforms[-1].cutout_prob)
else:
logging.info('epoch %d lr %e', epoch, lr)
model.drop_path_prob = args.drop_path_prob * epoch / args.epochs
train_acc, train_obj = train(train_queue, model, criterion, optimizer)
logging.info('train_acc %f', train_acc)
writer.add_scalar('Acc/train', train_acc, epoch)
writer.add_scalar('Obj/train', train_obj, epoch)
## scheduler
scheduler.step()
valid_acc, valid_obj = infer(valid_queue, model, criterion)
logging.info('valid_acc %f', valid_acc)
writer.add_scalar('Acc/valid', valid_acc, epoch)
writer.add_scalar('Obj/valid', valid_obj, epoch)
## checkpoint
if (epoch + 1) % args.ckpt_interval == 0:
save_state_dict = {
'epoch': epoch + 1,
'state_dict': model.state_dict(),
'optimizer': optimizer.state_dict(),
'scheduler': scheduler.state_dict(),
}
ig_utils.save_checkpoint(save_state_dict, False, args.save, per_epoch=True)
best_valid_acc = max(best_valid_acc, valid_acc)
logging.info('best valid_acc %f', best_valid_acc)
writer.close()
def train(train_queue, model, criterion, optimizer):
objs = ig_utils.AvgrageMeter()
top1 = ig_utils.AvgrageMeter()
top5 = ig_utils.AvgrageMeter()
model.train()
for step, (input, target) in enumerate(train_queue):
input = input.cuda()
target = target.cuda(non_blocking=True)
optimizer.zero_grad()
logits, logits_aux = model(input)
loss = criterion(logits, target)
if args.auxiliary:
loss_aux = criterion(logits_aux, target)
loss += args.auxiliary_weight * loss_aux
loss.backward()
nn.utils.clip_grad_norm_(model.parameters(), args.grad_clip)
optimizer.step()
prec1, prec5 = ig_utils.accuracy(logits, target, topk=(1, 5))
n = input.size(0)
objs.update(loss.data, n)
top1.update(prec1.data, n)
top5.update(prec5.data, n)
if step % args.report_freq == 0:
logging.info('train %03d %e %f %f', step, objs.avg, top1.avg, top5.avg)
if args.fast:
logging.info('//// WARNING: FAST MODE')
break
return top1.avg, objs.avg
def infer(valid_queue, model, criterion):
objs = ig_utils.AvgrageMeter()
top1 = ig_utils.AvgrageMeter()
top5 = ig_utils.AvgrageMeter()
model.eval()
with torch.no_grad():
for step, (input, target) in enumerate(valid_queue):
input = input.cuda()
target = target.cuda(non_blocking=True)
logits, _ = model(input)
loss = criterion(logits, target)
prec1, prec5 = ig_utils.accuracy(logits, target, topk=(1, 5))
n = input.size(0)
objs.update(loss.data, n)
top1.update(prec1.data, n)
top5.update(prec5.data, n)
if step % args.report_freq == 0:
logging.info('valid %03d %e %f %f', step, objs.avg, top1.avg, top5.avg)
if args.fast:
logging.info('//// WARNING: FAST MODE')
break
return top1.avg, objs.avg
if __name__ == '__main__':
main()
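Two per-epoch schedules drive the evaluation run above: cosine learning-rate annealing over args.epochs (eta_min left at 0) and a linear drop-path ramp from 0 to args.drop_path_prob. A small sketch of the values they take with the script's defaults (0.025 initial LR, 600 epochs, 0.2 drop path); the closed form below approximates what CosineAnnealingLR produces:
import math

lr_init, epochs, drop_path_max = 0.025, 600, 0.2   # defaults of train.py

for epoch in (0, 150, 300, 450, 599):
    lr = 0.5 * lr_init * (1 + math.cos(math.pi * epoch / epochs))   # cosine annealing with eta_min = 0
    drop_path_prob = drop_path_max * epoch / epochs                  # linear ramp applied before train()
    print(f'epoch {epoch:3d}: lr={lr:.5f}  drop_path_prob={drop_path_prob:.3f}')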

254
sota/cnn/train_imagenet.py Normal file
View File

@@ -0,0 +1,254 @@
from torch.utils.tensorboard import SummaryWriter
import argparse
import glob
import logging
import sys
sys.path.insert(0, '../../')
import time
import random
import numpy as np
import os
import torch
import torch.backends.cudnn as cudnn
import torch.nn as nn
import torch.utils
import torchvision.datasets as dset
import torchvision.transforms as transforms
from torch.autograd import Variable
import nasbench201.utils as utils
from sota.cnn.model_imagenet import NetworkImageNet as Network
import sota.cnn.genotypes as genotypes
from sota.cnn.hdf5 import H5Dataset
parser = argparse.ArgumentParser("imagenet")
parser.add_argument('--data', type=str, default='../../data', help='location of the data corpus')
parser.add_argument('--batch_size', type=int, default=128, help='batch size')
parser.add_argument('--learning_rate', type=float, default=0.1, help='init learning rate')
parser.add_argument('--momentum', type=float, default=0.9, help='momentum')
parser.add_argument('--weight_decay', type=float, default=3e-5, help='weight decay')
parser.add_argument('--report_freq', type=float, default=100, help='report frequency')
parser.add_argument('--gpu', type=int, default=0, help='gpu device id')
parser.add_argument('--epochs', type=int, default=250, help='num of training epochs')
parser.add_argument('--init_channels', type=int, default=48, help='num of init channels')
parser.add_argument('--layers', type=int, default=14, help='total number of layers')
parser.add_argument('--auxiliary', action='store_true', default=False, help='use auxiliary tower')
parser.add_argument('--auxiliary_weight', type=float, default=0.4, help='weight for auxiliary loss')
parser.add_argument('--drop_path_prob', type=float, default=0, help='drop path probability')
parser.add_argument('--save', type=str, default='EXP', help='experiment name')
parser.add_argument('--seed', type=int, default=0, help='random seed')
parser.add_argument('--arch', type=str, default='c10_s3_pgd', help='which architecture to use')
parser.add_argument('--grad_clip', type=float, default=5., help='gradient clipping')
parser.add_argument('--label_smooth', type=float, default=0.1, help='label smoothing')
parser.add_argument('--gamma', type=float, default=0.97, help='learning rate decay')
parser.add_argument('--decay_period', type=int, default=1, help='epochs between two learning rate decays')
parser.add_argument('--parallel', action='store_true', default=False, help='data parallelism')
parser.add_argument('--load', action='store_true', default=False, help='whether to load a checkpoint and continue training')
args = parser.parse_args()
args.save = '../../experiments/sota/imagenet/eval/{}-{}-{}-{}'.format(
args.save, time.strftime("%Y%m%d-%H%M%S"), args.arch, args.seed)
if args.auxiliary:
args.save += '-auxiliary-' + str(args.auxiliary_weight)
args.save += '-' + str(np.random.randint(10000))
utils.create_exp_dir(args.save, scripts_to_save=glob.glob('*.py'))
log_format = '%(asctime)s %(message)s'
logging.basicConfig(stream=sys.stdout, level=logging.INFO,
format=log_format, datefmt='%m/%d %I:%M:%S %p')
fh = logging.FileHandler(os.path.join(args.save, 'log.txt'))
fh.setFormatter(logging.Formatter(log_format))
logging.getLogger().addHandler(fh)
writer = SummaryWriter(args.save + '/runs')
CLASSES = 1000
class CrossEntropyLabelSmooth(nn.Module):
def __init__(self, num_classes, epsilon):
super(CrossEntropyLabelSmooth, self).__init__()
self.num_classes = num_classes
self.epsilon = epsilon
self.logsoftmax = nn.LogSoftmax(dim=1)
def forward(self, inputs, targets):
log_probs = self.logsoftmax(inputs)
targets = torch.zeros_like(log_probs).scatter_(1, targets.unsqueeze(1), 1)
targets = (1 - self.epsilon) * targets + self.epsilon / self.num_classes
loss = (-targets * log_probs).mean(0).sum()
return loss
def seed_torch(seed=0):
random.seed(seed)
np.random.seed(seed)
os.environ['PYTHONHASHSEED'] = str(seed)
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)
cudnn.deterministic = True
cudnn.benchmark = False
def main():
if not torch.cuda.is_available():
logging.info('no gpu device available')
sys.exit(1)
torch.cuda.set_device(args.gpu)
cudnn.enabled = True
seed_torch(args.seed)
logging.info('gpu device = %d' % args.gpu)
logging.info("args = %s", args)
genotype = eval("genotypes.%s" % args.arch)
model = Network(args.init_channels, CLASSES, args.layers, args.auxiliary, genotype)
if args.parallel:
model = nn.DataParallel(model).cuda()
else:
model = model.cuda()
logging.info("param size = %fMB", utils.count_parameters_in_MB(model))
criterion = nn.CrossEntropyLoss()
criterion = criterion.cuda()
criterion_smooth = CrossEntropyLabelSmooth(CLASSES, args.label_smooth)
criterion_smooth = criterion_smooth.cuda()
optimizer = torch.optim.SGD(
model.parameters(),
args.learning_rate,
momentum=args.momentum,
weight_decay=args.weight_decay
)
normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
train_transform = transforms.Compose([
transforms.RandomResizedCrop(224),
transforms.RandomHorizontalFlip(),
transforms.ColorJitter(
brightness=0.4,
contrast=0.4,
saturation=0.4,
hue=0.2),
transforms.ToTensor(),
normalize,
])
test_transform = transforms.Compose([
transforms.Resize(256),
transforms.CenterCrop(224),
transforms.ToTensor(),
normalize,
])
train_data = H5Dataset(os.path.join(args.data, 'imagenet-train-256.h5'), transform=train_transform)
valid_data = H5Dataset(os.path.join(args.data, 'imagenet-val-256.h5'), transform=test_transform)
train_queue = torch.utils.data.DataLoader(
train_data, batch_size=args.batch_size, shuffle=True, pin_memory=True, num_workers=4)
valid_queue = torch.utils.data.DataLoader(
valid_data, batch_size=args.batch_size, shuffle=False, pin_memory=True, num_workers=4)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, args.decay_period, gamma=args.gamma)
if args.load:
model, optimizer, start_epoch, best_acc_top1 = utils.load_checkpoint(
model, optimizer, '../../experiments/sota/imagenet/eval/EXP-20200210-143540-c10_s3_pgd-0-auxiliary-0.4-2753')
else:
best_acc_top1 = 0
start_epoch = 0
for epoch in range(start_epoch, args.epochs):
logging.info('epoch %d lr %e', epoch, scheduler.get_lr()[0])
model.drop_path_prob = args.drop_path_prob * epoch / args.epochs
train_acc, train_obj = train(train_queue, model, criterion_smooth, optimizer)
logging.info('train_acc %f', train_acc)
writer.add_scalar('Acc/train', train_acc, epoch)
writer.add_scalar('Obj/train', train_obj, epoch)
scheduler.step()
valid_acc_top1, valid_acc_top5, valid_obj = infer(valid_queue, model, criterion)
logging.info('valid_acc_top1 %f', valid_acc_top1)
logging.info('valid_acc_top5 %f', valid_acc_top5)
writer.add_scalar('Acc/valid_top1', valid_acc_top1, epoch)
writer.add_scalar('Acc/valid_top5', valid_acc_top5, epoch)
is_best = False
if valid_acc_top1 > best_acc_top1:
best_acc_top1 = valid_acc_top1
is_best = True
utils.save_checkpoint({
'epoch': epoch + 1,
'state_dict': model.state_dict(),
'best_acc_top1': best_acc_top1,
'optimizer': optimizer.state_dict(),
}, is_best, args.save)
def train(train_queue, model, criterion, optimizer):
objs = utils.AvgrageMeter()
top1 = utils.AvgrageMeter()
top5 = utils.AvgrageMeter()
model.train()
for step, (input, target) in enumerate(train_queue):
input = input.cuda()
target = target.cuda(non_blocking=True)
optimizer.zero_grad()
logits, logits_aux = model(input)
loss = criterion(logits, target)
if args.auxiliary:
loss_aux = criterion(logits_aux, target)
loss += args.auxiliary_weight * loss_aux
loss.backward()
nn.utils.clip_grad_norm_(model.parameters(), args.grad_clip)
optimizer.step()
prec1, prec5 = utils.accuracy(logits, target, topk=(1, 5))
n = input.size(0)
objs.update(loss.data, n)
top1.update(prec1.data, n)
top5.update(prec5.data, n)
if step % args.report_freq == 0:
logging.info('train %03d %e %f %f', step, objs.avg, top1.avg, top5.avg)
return top1.avg, objs.avg
def infer(valid_queue, model, criterion):
objs = utils.AvgrageMeter()
top1 = utils.AvgrageMeter()
top5 = utils.AvgrageMeter()
model.eval()
with torch.no_grad():
for step, (input, target) in enumerate(valid_queue):
input = input.cuda()
target = target.cuda(non_blocking=True)
logits, _ = model(input)
loss = criterion(logits, target)
prec1, prec5 = utils.accuracy(logits, target, topk=(1, 5))
n = input.size(0)
objs.update(loss.data, n)
top1.update(prec1.data, n)
top5.update(prec5.data, n)
if step % args.report_freq == 0:
logging.info('valid %03d %e %f %f', step, objs.avg, top1.avg, top5.avg)
return top1.avg, top5.avg, objs.avg
if __name__ == '__main__':
main()
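CrossEntropyLabelSmooth above replaces the one-hot target with a mixture: the true class receives 1 - epsilon + epsilon/K and every other class epsilon/K, after which the usual cross-entropy is taken against the log-softmax. A tiny numeric sketch of the smoothed target for epsilon = 0.1 and K = 5 classes:
import torch

eps, num_classes = 0.1, 5
onehot = torch.zeros(num_classes)
onehot[2] = 1.0                                   # true class index 2
smoothed = (1 - eps) * onehot + eps / num_classes
print(smoothed)        # tensor([0.0200, 0.0200, 0.9200, 0.0200, 0.0200])
print(smoothed.sum())  # tensor(1.) -- still a valid probability distribution
Recent PyTorch versions expose the same smoothing via nn.CrossEntropyLoss(label_smoothing=0.1), which should agree with the class above up to reduction details.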

67
sota/cnn/visualize.py Normal file
View File

@@ -0,0 +1,67 @@
import sys
import genotypes
from graphviz import Digraph
def plot(genotype, filename, mode=''):
g = Digraph(
format='pdf',
edge_attr=dict(fontsize='40', fontname="times"),
node_attr=dict(style='filled', shape='rect', align='center', fontsize='40', height='0.5', width='0.5',
penwidth='2', fontname="times"),
engine='dot')
g.body.extend(['rankdir=LR'])
# g.body.extend(['ratio=0.15'])
# g.view()
g.node("c_{k-2}", fillcolor='darkseagreen2')
g.node("c_{k-1}", fillcolor='darkseagreen2')
assert len(genotype) % 2 == 0
steps = len(genotype) // 2
for i in range(steps):
g.node(str(i), fillcolor='lightblue')
for i in range(steps):
for k in [2 * i, 2 * i + 1]:
op, j = genotype[k]
if j == 0:
u = "c_{k-2}"
elif j == 1:
u = "c_{k-1}"
else:
u = str(j - 2)
v = str(i)
if mode == 'cue' and op != 'skip_connect' and op != 'noise':
g.edge(u, v, label=op, fillcolor='gray', color='red', fontcolor='red')
else:
g.edge(u, v, label=op, fillcolor="gray")
g.node("c_{k}", fillcolor='palegoldenrod')
for i in range(steps):
g.edge(str(i), "c_{k}", fillcolor="gray")
g.render(filename, view=False)
if __name__ == '__main__':
if len(sys.argv) != 2:
print("usage:\n python {} ARCH_NAME".format(sys.argv[0]))
sys.exit(1)
genotype_name = sys.argv[1]
try:
genotype = eval('genotypes.{}'.format(genotype_name))
# print(genotype)
except AttributeError:
print("{} is not specified in genotypes.py".format(genotype_name))
sys.exit(1)
mode = 'cue'
path = '../../figs/genotypes/cnn_{}/'.format(mode)
# print(genotype.normal)
plot(genotype.normal, path + genotype_name + "_normal", mode=mode)
plot(genotype.reduce, path + genotype_name + "_reduce", mode=mode)
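plot() expects a DARTS-style genotype: a flat list of (op, input_index) pairs, two per intermediate node, where indices 0 and 1 denote c_{k-2} and c_{k-1} and index j >= 2 denotes intermediate node j - 2. A toy, purely hypothetical genotype illustrating that structure (graphviz must be installed, and the import assumes the sketch is run from sota/cnn):
from collections import namedtuple
from visualize import plot     # same-directory import, mirroring the "import genotypes" above

Genotype = namedtuple('Genotype', 'normal normal_concat reduce reduce_concat')

toy = Genotype(
    normal=[('sep_conv_3x3', 0), ('sep_conv_3x3', 1),   # node 0: c_{k-2}, c_{k-1}
            ('skip_connect', 0), ('dil_conv_3x3', 2),   # node 1: c_{k-2}, node 0
            ('sep_conv_5x5', 1), ('max_pool_3x3', 3),   # node 2: c_{k-1}, node 1
            ('skip_connect', 2), ('dil_conv_5x5', 4)],  # node 3: node 0, node 2
    normal_concat=range(2, 6),
    reduce=[('max_pool_3x3', 0), ('max_pool_3x3', 1)] * 4,
    reduce_concat=range(2, 6))

plot(toy.normal, 'toy_normal', mode='cue')   # writes toy_normal.pdf via graphviz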

144
sota/cnn/visualize_full.py Normal file
View File

@@ -0,0 +1,144 @@
import sys
import genotypes
import numpy as np
from graphviz import Digraph
supernet_dict = {
0: ('c_{k-2}', '0'),
1: ('c_{k-1}', '0'),
2: ('c_{k-2}', '1'),
3: ('c_{k-1}', '1'),
4: ('0', '1'),
5: ('c_{k-2}', '2'),
6: ('c_{k-1}', '2'),
7: ('0', '2'),
8: ('1', '2'),
9: ('c_{k-2}', '3'),
10: ('c_{k-1}', '3'),
11: ('0', '3'),
12: ('1', '3'),
13: ('2', '3'),
}
steps = 4
def plot_space(primitives, filename):
g = Digraph(
format='pdf',
edge_attr=dict(fontsize='20', fontname="times"),
node_attr=dict(style='filled', shape='rect', align='center', fontsize='20', height='0.5', width='0.5', penwidth='2', fontname="times"),
engine='dot')
g.body.extend(['rankdir=LR'])
g.body.extend(['ratio=50.0'])
g.node("c_{k-2}", fillcolor='darkseagreen2')
g.node("c_{k-1}", fillcolor='darkseagreen2')
steps = 4
for i in range(steps):
g.node(str(i), fillcolor='lightblue')
n = 2
start = 0
nodes_indx = ["c_{k-2}", "c_{k-1}"]
for i in range(steps):
end = start + n
p = primitives[start:end]
v = str(i)
for node, prim in zip(nodes_indx, p):
u = node
for op in prim:
g.edge(u, v, label=op, fillcolor="gray")
start = end
n += 1
nodes_indx.append(v)
g.node("c_{k}", fillcolor='palegoldenrod')
for i in range(steps):
g.edge(str(i), "c_{k}", fillcolor="gray")
g.render(filename, view=False)
def plot(genotype, filename):
g = Digraph(
format='pdf',
edge_attr=dict(fontsize='100', fontname="times"),
node_attr=dict(style='filled', shape='rect', align='center', fontsize='100', height='0.5', width='0.5', penwidth='2', fontname="times"),
engine='dot')
g.body.extend(['rankdir=LR'])
g.body.extend(['ratio=0.3'])
g.node("c_{k-2}", fillcolor='darkseagreen2')
g.node("c_{k-1}", fillcolor='darkseagreen2')
num_edges = len(genotype)
for i in range(steps):
g.node(str(i), fillcolor='lightblue')
for eid in range(num_edges):
op = genotype[eid]
u, v = supernet_dict[eid]
if op != 'skip_connect':
g.edge(u, v, label=op, fillcolor="gray", color='red', fontcolor='red')
else:
g.edge(u, v, label=op, fillcolor="gray")
g.node("c_{k}", fillcolor='palegoldenrod')
for i in range(steps):
g.edge(str(i), "c_{k}", fillcolor="gray")
g.render(filename, view=False)
# def plot(genotype, filename):
# g = Digraph(
# format='pdf',
# edge_attr=dict(fontsize='100', fontname="times", penwidth='3'),
# node_attr=dict(style='filled', shape='rect', align='center', fontsize='100', height='0.5', width='0.5',
# penwidth='2', fontname="times"),
# engine='dot')
# g.body.extend(['rankdir=LR'])
# g.node("c_{k-2}", fillcolor='darkseagreen2')
# g.node("c_{k-1}", fillcolor='darkseagreen2')
# num_edges = len(genotype)
# for i in range(steps):
# g.node(str(i), fillcolor='lightblue')
# for eid in range(num_edges):
# op = genotype[eid]
# u, v = supernet_dict[eid]
# if op != 'skip_connect':
# g.edge(u, v, label=op, fillcolor="gray", color='red', fontcolor='red')
# else:
# g.edge(u, v, label=op, fillcolor="gray")
# g.node("c_{k}", fillcolor='palegoldenrod')
# for i in range(steps):
# g.edge(str(i), "c_{k}", fillcolor="gray")
# g.render(filename, view=False)
if __name__ == '__main__':
#### visualize the supernet ####
if len(sys.argv) != 2:
print("usage:\n python {} ARCH_NAME".format(sys.argv[0]))
sys.exit(1)
genotype_name = sys.argv[1]
assert 'supernet' in genotype_name, 'this script only supports supernet visualization'
try:
genotype = eval('genotypes.{}'.format(genotype_name))
except AttributeError:
print("{} is not specified in genotypes.py".format(genotype_name))
sys.exit(1)
path = '../../figs/genotypes/cnn_supernet_cue/'
plot(genotype.normal, path + genotype_name + "_normal")
plot(genotype.reduce, path + genotype_name + "_reduce")

972
toy_model.ipynb Normal file

File diff suppressed because one or more lines are too long

View File

@@ -0,0 +1,54 @@
GitPython was originally written by Michael Trier.
GitPython 0.2 was partially (re)written by Sebastian Thiel, based on 0.1.6 and git-dulwich.
Contributors are:
-Michael Trier <mtrier _at_ gmail.com>
-Alan Briolat
-Florian Apolloner <florian _at_ apolloner.eu>
-David Aguilar <davvid _at_ gmail.com>
-Jelmer Vernooij <jelmer _at_ samba.org>
-Steve Frécinaux <code _at_ istique.net>
-Kai Lautaportti <kai _at_ lautaportti.fi>
-Paul Sowden <paul _at_ idontsmoke.co.uk>
-Sebastian Thiel <byronimo _at_ gmail.com>
-Jonathan Chu <jonathan.chu _at_ me.com>
-Vincent Driessen <me _at_ nvie.com>
-Phil Elson <pelson _dot_ pub _at_ gmail.com>
-Bernard `Guyzmo` Pratz <guyzmo+gitpython+pub@m0g.net>
-Timothy B. Hartman <tbhartman _at_ gmail.com>
-Konstantin Popov <konstantin.popov.89 _at_ yandex.ru>
-Peter Jones <pjones _at_ redhat.com>
-Anson Mansfield <anson.mansfield _at_ gmail.com>
-Ken Odegard <ken.odegard _at_ gmail.com>
-Alexis Horgix Chotard
-Piotr Babij <piotr.babij _at_ gmail.com>
-Mikuláš Poul <mikulaspoul _at_ gmail.com>
-Charles Bouchard-Légaré <cblegare.atl _at_ ntis.ca>
-Yaroslav Halchenko <debian _at_ onerussian.com>
-Tim Swast <swast _at_ google.com>
-William Luc Ritchie
-David Host <hostdm _at_ outlook.com>
-A. Jesse Jiryu Davis <jesse _at_ emptysquare.net>
-Steven Whitman <ninloot _at_ gmail.com>
-Stefan Stancu <stefan.stancu _at_ gmail.com>
-César Izurieta <cesar _at_ caih.org>
-Arthur Milchior <arthur _at_ milchior.fr>
-Anil Khatri <anil.soccer.khatri _at_ gmail.com>
-JJ Graham <thetwoj _at_ gmail.com>
-Ben Thayer <ben _at_ benthayer.com>
-Dries Kennes <admin _at_ dries007.net>
-Pratik Anurag <panurag247365 _at_ gmail.com>
-Harmon <harmon.public _at_ gmail.com>
-Liam Beguin <liambeguin _at_ gmail.com>
-Ram Rachum <ram _at_ rachum.com>
-Alba Mendez <me _at_ alba.sh>
-Robert Westman <robert _at_ byteflux.io>
-Hugo van Kemenade
-Hiroki Tokunaga <tokusan441 _at_ gmail.com>
-Julien Mauroy <pro.julien.mauroy _at_ gmail.com>
-Patrick Gerard
-Luke Twist <itsluketwist@gmail.com>
-Joseph Hale <me _at_ jhale.dev>
-Santos Gallegos <stsewd _at_ proton.me>
Portions derived from other open source works and are clearly marked.

View File

@@ -0,0 +1,30 @@
Copyright (C) 2008, 2009 Michael Trier and contributors
All rights reserved.
Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions
are met:
* Redistributions of source code must retain the above copyright
notice, this list of conditions and the following disclaimer.
* Redistributions in binary form must reproduce the above copyright
notice, this list of conditions and the following disclaimer in the
documentation and/or other materials provided with the distribution.
* Neither the name of the GitPython project nor the names of
its contributors may be used to endorse or promote products derived
from this software without specific prior written permission.
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
"AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED
TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

View File

@@ -0,0 +1,34 @@
Metadata-Version: 2.1
Name: GitPython
Version: 3.1.31
Summary: GitPython is a Python library used to interact with Git repositories
Home-page: https://github.com/gitpython-developers/GitPython
Author: Sebastian Thiel, Michael Trier
Author-email: byronimo@gmail.com, mtrier@gmail.com
License: BSD
Platform: UNKNOWN
Classifier: Development Status :: 5 - Production/Stable
Classifier: Environment :: Console
Classifier: Intended Audience :: Developers
Classifier: License :: OSI Approved :: BSD License
Classifier: Operating System :: OS Independent
Classifier: Operating System :: POSIX
Classifier: Operating System :: Microsoft :: Windows
Classifier: Operating System :: MacOS :: MacOS X
Classifier: Typing :: Typed
Classifier: Programming Language :: Python
Classifier: Programming Language :: Python :: 3
Classifier: Programming Language :: Python :: 3.7
Classifier: Programming Language :: Python :: 3.8
Classifier: Programming Language :: Python :: 3.9
Classifier: Programming Language :: Python :: 3.10
Classifier: Programming Language :: Python :: 3.11
Requires-Python: >=3.7
Description-Content-Type: text/markdown
License-File: LICENSE
License-File: AUTHORS
Requires-Dist: gitdb (<5,>=4.0.1)
Requires-Dist: typing-extensions (>=3.7.4.3) ; python_version < "3.8"
GitPython is a Python library used to interact with Git repositories

View File

@@ -0,0 +1,44 @@
git/__init__.py,sha256=O2tZaGpLYVQiK9lN3NucvyEoZcSFig13tAB6d2TTTL0,2342
git/cmd.py,sha256=i4IyhmCTP-72NPO5aVeWhDT6_jLmA1C2qzhsS7G2UVw,53712
git/compat.py,sha256=3wWLkD9QrZvLiV6NtNxJILwGrLE2nw_SoLqaTEPH364,2256
git/config.py,sha256=PO6qicfkKwRFlKJr9AUuDrWV0rimlmb5S2wIVLlOd7w,34581
git/db.py,sha256=dEs2Bn-iDuHyero9afw8mrXHrLE7_CDExv943iWU9WI,2244
git/diff.py,sha256=DOWd26Dk7FqnKt79zpniv19muBzdYa949TcQPqVbZGg,23434
git/exc.py,sha256=ys5ZYuvzvNN3TfcB5R_bUNRy3OEvURS5pJMdfy0Iws4,6446
git/py.typed,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
git/remote.py,sha256=H88bonpIjnfozWScpQIFIccE7Soq2hfHO9ldnRCmGUY,45069
git/types.py,sha256=bA4El-NC7YNwQ9jNtkbWgT0QmmAfVs4PVSwBOE_D1Bo,3020
git/util.py,sha256=j5cjyeFibLs4Ed_ErkePf6sx1VWb95OQ4GlJUWgq6PU,39874
git/index/__init__.py,sha256=43ovvVNocVRNiQd4fLqvUMuGGmwhBQ9SsiQ46vkvk1E,89
git/index/base.py,sha256=5GnqwmhLNF9f12hUq4rQyOvqzxPF1Fdc0QOETT5r010,57523
git/index/fun.py,sha256=Y41IGlu8XqnradQXFjTGMISM45m8J256bTKs4xWR4qY,16406
git/index/typ.py,sha256=QnyWeqzU7_xnyiwOki5W633Jp9g5COqEf6B4PeW3hK8,6252
git/index/util.py,sha256=ISsWZjGiflooNr6XtElP4AhWUxQOakouvgeXC2PEepI,3475
git/objects/__init__.py,sha256=NW8HBfdZvBYe9W6IjMWafSj_DVlV2REmmrpWKrHkGVw,692
git/objects/base.py,sha256=N2NTL9hLwgKqY-VQiar8Hvn4a41Y8o_Tmi_SR0mGAS8,7857
git/objects/blob.py,sha256=FIbZTYniJ7nLsdrHuwhagFVs9tYoUIyXodRaHYLaQqs,986
git/objects/commit.py,sha256=ji9ityweewpr12mHh9w3s3ubouYNNCRTBr-LBrjrPbs,27304
git/objects/fun.py,sha256=SV3_G_jnEb_wEa5doF6AapX58StH3OGBxCAKeMyuA0I,8612
git/objects/tag.py,sha256=ZXOLK_lV9E5G2aDl5t0hYDN2hhIhGF23HILHBnZgRX0,3840
git/objects/tree.py,sha256=cSQbt3nn3cIrbVrBasB1wm2r-vzotYWhka1yDjOHf-k,14230
git/objects/util.py,sha256=M8h53ueOV32nXE6XcnKhCHzXznT7pi8JpEEGgCNicXo,22275
git/objects/submodule/__init__.py,sha256=OsMeiex7cG6ev2f35IaJ5csH-eXchSoNKCt4HXUG5Ws,93
git/objects/submodule/base.py,sha256=R4jTjBJyMjFOfDAYwsA6Q3Lt6qeFYERPE4PABACW6GE,61539
git/objects/submodule/root.py,sha256=Ev_RnGzv4hi3UqEFMHuSR-uGR7kYpwOgwZFUG31X-Hc,19568
git/objects/submodule/util.py,sha256=u2zQGFWBmryqET0XWf9BuiY1OOgWB8YCU3Wz0xdp4E4,3380
git/refs/__init__.py,sha256=PMF97jMUcivbCCEJnl2zTs-YtECNFp8rL8GHK8AitXU,203
git/refs/head.py,sha256=rZ4LbFd05Gs9sAuSU5VQRDmJZfrwMwWtBpLlmiUQ-Zg,9756
git/refs/log.py,sha256=Z8X9_ZGZrVTWz9p_-fk1N3m47G-HTRPwozoZBDd70DI,11892
git/refs/reference.py,sha256=DUx7QvYqTBeVxG53ntPfKCp3wuJyDBRIZcPCy1OD22s,5414
git/refs/remote.py,sha256=E63Bh5ig1GYrk6FE46iNtS5P6ZgODyPXot8eJw-mxts,2556
git/refs/symbolic.py,sha256=XwfeYr1Zp-fuHAoGuVAXKk4EYlsuUMVu99OjJWuWDTQ,29967
git/refs/tag.py,sha256=FNoCZ3BdDl2i5kD3si2P9hoXU9rDAZ_YK0Rn84TmKT8,4419
git/repo/__init__.py,sha256=XMpdeowJRtTEd80jAcrKSQfMu2JZGMfPlpuIYHG2ZCk,80
git/repo/base.py,sha256=uD4EL2AWUMSCHCqIk7voXoZ2iChaf5VJ1t1Abr7Zk10,54937
git/repo/fun.py,sha256=VTRODXAb_x8bazkSd8g-Pkk8M2iLVK4kPoKQY9HXjZc,12962
GitPython-3.1.31.dist-info/AUTHORS,sha256=0F09KKrRmwH3zJ4gqo1tJMVlalC9bSunDNKlRvR6q2c,2158
GitPython-3.1.31.dist-info/LICENSE,sha256=_WV__CzvY9JceMq3gI1BTdA6KC5jiTSR_RHDL5i-Z_s,1521
GitPython-3.1.31.dist-info/METADATA,sha256=zFy5SrG7Ur2UItx3seZXELCST9LBEX72wZa7Y7z7FSY,1340
GitPython-3.1.31.dist-info/WHEEL,sha256=ewwEueio1C2XeHTvT17n8dZUJgOvyCWCt0WVNLClP9o,92
GitPython-3.1.31.dist-info/top_level.txt,sha256=0hzDuIp8obv624V3GmbqsagBWkk8ohtGU-Bc1PmTT0o,4
GitPython-3.1.31.dist-info/RECORD,,

View File

@@ -0,0 +1,5 @@
Wheel-Version: 1.0
Generator: bdist_wheel (0.37.0)
Root-Is-Purelib: true
Tag: py3-none-any

View File

@@ -0,0 +1 @@
gitdb<5,>=4.0.1

View File

@@ -0,0 +1,92 @@
# __init__.py
# Copyright (C) 2008, 2009 Michael Trier (mtrier@gmail.com) and contributors
#
# This module is part of GitPython and is released under
# the BSD License: http://www.opensource.org/licenses/bsd-license.php
# flake8: noqa
# @PydevCodeAnalysisIgnore
from git.exc import * # @NoMove @IgnorePep8
import inspect
import os
import sys
import os.path as osp
from typing import Optional
from git.types import PathLike
__version__ = '3.1.31'
# { Initialization
def _init_externals() -> None:
"""Initialize external projects by putting them into the path"""
if __version__ == '3.1.31' and "PYOXIDIZER" not in os.environ:
sys.path.insert(1, osp.join(osp.dirname(__file__), "ext", "gitdb"))
try:
import gitdb
except ImportError as e:
raise ImportError("'gitdb' could not be found in your PYTHONPATH") from e
# END verify import
# } END initialization
#################
_init_externals()
#################
# { Imports
try:
from git.config import GitConfigParser # @NoMove @IgnorePep8
from git.objects import * # @NoMove @IgnorePep8
from git.refs import * # @NoMove @IgnorePep8
from git.diff import * # @NoMove @IgnorePep8
from git.db import * # @NoMove @IgnorePep8
from git.cmd import Git # @NoMove @IgnorePep8
from git.repo import Repo # @NoMove @IgnorePep8
from git.remote import * # @NoMove @IgnorePep8
from git.index import * # @NoMove @IgnorePep8
from git.util import ( # @NoMove @IgnorePep8
LockFile,
BlockingLockFile,
Stats,
Actor,
rmtree,
)
except GitError as exc:
raise ImportError("%s: %s" % (exc.__class__.__name__, exc)) from exc
# } END imports
__all__ = [name for name, obj in locals().items() if not (name.startswith("_") or inspect.ismodule(obj))]
# { Initialize git executable path
GIT_OK = None
def refresh(path: Optional[PathLike] = None) -> None:
"""Convenience method for setting the git executable path."""
global GIT_OK
GIT_OK = False
if not Git.refresh(path=path):
return
if not FetchInfo.refresh():
return
GIT_OK = True
# } END initialize git executable path
#################
try:
refresh()
except Exception as exc:
raise ImportError("Failed to initialize: {0}".format(exc)) from exc
#################

File diff suppressed because it is too large

View File

@@ -0,0 +1,104 @@
# -*- coding: utf-8 -*-
# compat.py
# Copyright (C) 2008, 2009 Michael Trier (mtrier@gmail.com) and contributors
#
# This module is part of GitPython and is released under
# the BSD License: http://www.opensource.org/licenses/bsd-license.php
"""utilities to help provide compatibility with python 3"""
# flake8: noqa
import locale
import os
import sys
from gitdb.utils.encoding import (
force_bytes, # @UnusedImport
force_text, # @UnusedImport
)
# typing --------------------------------------------------------------------
from typing import (
Any,
AnyStr,
Dict,
IO,
Optional,
Tuple,
Type,
Union,
overload,
)
# ---------------------------------------------------------------------------
is_win: bool = os.name == "nt"
is_posix = os.name == "posix"
is_darwin = os.name == "darwin"
defenc = sys.getfilesystemencoding()
@overload
def safe_decode(s: None) -> None:
...
@overload
def safe_decode(s: AnyStr) -> str:
...
def safe_decode(s: Union[AnyStr, None]) -> Optional[str]:
"""Safely decodes a binary string to unicode"""
if isinstance(s, str):
return s
elif isinstance(s, bytes):
return s.decode(defenc, "surrogateescape")
elif s is None:
return None
else:
raise TypeError("Expected bytes or text, but got %r" % (s,))
@overload
def safe_encode(s: None) -> None:
...
@overload
def safe_encode(s: AnyStr) -> bytes:
...
def safe_encode(s: Optional[AnyStr]) -> Optional[bytes]:
"""Safely encodes a binary string to unicode"""
if isinstance(s, str):
return s.encode(defenc)
elif isinstance(s, bytes):
return s
elif s is None:
return None
else:
raise TypeError("Expected bytes or text, but got %r" % (s,))
@overload
def win_encode(s: None) -> None:
...
@overload
def win_encode(s: AnyStr) -> bytes:
...
def win_encode(s: Optional[AnyStr]) -> Optional[bytes]:
"""Encode unicodes for process arguments on Windows."""
if isinstance(s, str):
return s.encode(locale.getpreferredencoding(False))
elif isinstance(s, bytes):
return s
elif s is not None:
raise TypeError("Expected bytes or text, but got %r" % (s,))
return None
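# --- Illustrative aside (not part of the committed file) ---------------------
# A minimal sketch of the codec helpers defined above: bytes and text round-trip
# through the filesystem encoding, and None passes through unchanged.
def _example_safe_codecs() -> None:
    assert safe_decode(b"hello") == "hello"
    assert safe_encode("hello") == b"hello"
    assert safe_decode(None) is None
    # bytes that cannot be decoded in defenc come back via 'surrogateescape'
    # instead of raising, see safe_decode() above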

View File

@@ -0,0 +1,897 @@
# config.py
# Copyright (C) 2008, 2009 Michael Trier (mtrier@gmail.com) and contributors
#
# This module is part of GitPython and is released under
# the BSD License: http://www.opensource.org/licenses/bsd-license.php
"""Module containing module parser implementation able to properly read and write
configuration files"""
import sys
import abc
from functools import wraps
import inspect
from io import BufferedReader, IOBase
import logging
import os
import re
import fnmatch
from git.compat import (
defenc,
force_text,
is_win,
)
from git.util import LockFile
import os.path as osp
import configparser as cp
# typing-------------------------------------------------------
from typing import (
Any,
Callable,
Generic,
IO,
List,
Dict,
Sequence,
TYPE_CHECKING,
Tuple,
TypeVar,
Union,
cast,
)
from git.types import Lit_config_levels, ConfigLevels_Tup, PathLike, assert_never, _T
if TYPE_CHECKING:
from git.repo.base import Repo
from io import BytesIO
T_ConfigParser = TypeVar("T_ConfigParser", bound="GitConfigParser")
T_OMD_value = TypeVar("T_OMD_value", str, bytes, int, float, bool)
if sys.version_info[:3] < (3, 7, 2):
# typing.Ordereddict not added until py 3.7.2
from collections import OrderedDict
OrderedDict_OMD = OrderedDict
else:
from typing import OrderedDict
OrderedDict_OMD = OrderedDict[str, List[T_OMD_value]] # type: ignore[assignment, misc]
# -------------------------------------------------------------
__all__ = ("GitConfigParser", "SectionConstraint")
log = logging.getLogger("git.config")
log.addHandler(logging.NullHandler())
# invariants
# represents the configuration level of a configuration file
CONFIG_LEVELS: ConfigLevels_Tup = ("system", "user", "global", "repository")
# Section pattern to detect conditional includes.
# https://git-scm.com/docs/git-config#_conditional_includes
CONDITIONAL_INCLUDE_REGEXP = re.compile(r"(?<=includeIf )\"(gitdir|gitdir/i|onbranch):(.+)\"")
class MetaParserBuilder(abc.ABCMeta): # noqa: B024
"""Utility class wrapping base-class methods into decorators that assure read-only properties"""
def __new__(cls, name: str, bases: Tuple, clsdict: Dict[str, Any]) -> "MetaParserBuilder":
"""
Equip all base-class methods with a needs_values decorator, and all non-const methods
with a set_dirty_and_flush_changes decorator in addition to that."""
kmm = "_mutating_methods_"
if kmm in clsdict:
mutating_methods = clsdict[kmm]
for base in bases:
methods = (t for t in inspect.getmembers(base, inspect.isroutine) if not t[0].startswith("_"))
for name, method in methods:
if name in clsdict:
continue
method_with_values = needs_values(method)
if name in mutating_methods:
method_with_values = set_dirty_and_flush_changes(method_with_values)
# END mutating methods handling
clsdict[name] = method_with_values
# END for each name/method pair
# END for each base
# END if mutating methods configuration is set
new_type = super(MetaParserBuilder, cls).__new__(cls, name, bases, clsdict)
return new_type
def needs_values(func: Callable[..., _T]) -> Callable[..., _T]:
"""Returns method assuring we read values (on demand) before we try to access them"""
@wraps(func)
def assure_data_present(self: "GitConfigParser", *args: Any, **kwargs: Any) -> _T:
self.read()
return func(self, *args, **kwargs)
# END wrapper method
return assure_data_present
def set_dirty_and_flush_changes(non_const_func: Callable[..., _T]) -> Callable[..., _T]:
"""Return method that checks whether given non constant function may be called.
If so, the instance will be set dirty.
Additionally, we flush the changes right to disk"""
def flush_changes(self: "GitConfigParser", *args: Any, **kwargs: Any) -> _T:
rval = non_const_func(self, *args, **kwargs)
self._dirty = True
self.write()
return rval
# END wrapper method
flush_changes.__name__ = non_const_func.__name__
return flush_changes
class SectionConstraint(Generic[T_ConfigParser]):
"""Constrains a ConfigParser to only option commands which are constrained to
always use the section we have been initialized with.
It supports all ConfigParser methods that operate on an option.
:note:
If used as a context manager, will release the wrapped ConfigParser."""
__slots__ = ("_config", "_section_name")
_valid_attrs_ = (
"get_value",
"set_value",
"get",
"set",
"getint",
"getfloat",
"getboolean",
"has_option",
"remove_section",
"remove_option",
"options",
)
def __init__(self, config: T_ConfigParser, section: str) -> None:
self._config = config
self._section_name = section
def __del__(self) -> None:
# Yes, for some reason, we have to call it explicitly for it to work in PY3 !
        # Apparently __del__ doesn't get called anymore if the refcount becomes 0
# Ridiculous ... .
self._config.release()
def __getattr__(self, attr: str) -> Any:
if attr in self._valid_attrs_:
return lambda *args, **kwargs: self._call_config(attr, *args, **kwargs)
return super(SectionConstraint, self).__getattribute__(attr)
def _call_config(self, method: str, *args: Any, **kwargs: Any) -> Any:
"""Call the configuration at the given method which must take a section name
as first argument"""
return getattr(self._config, method)(self._section_name, *args, **kwargs)
@property
def config(self) -> T_ConfigParser:
"""return: Configparser instance we constrain"""
return self._config
def release(self) -> None:
"""Equivalent to GitConfigParser.release(), which is called on our underlying parser instance"""
return self._config.release()
def __enter__(self) -> "SectionConstraint[T_ConfigParser]":
self._config.__enter__()
return self
def __exit__(self, exception_type: str, exception_value: str, traceback: str) -> None:
self._config.__exit__(exception_type, exception_value, traceback)
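# --- Illustrative aside (not part of the committed file) ---------------------
# A minimal sketch of how SectionConstraint narrows a parser to one section:
# every option-level call is forwarded with the section name prepended. The
# config path used below is hypothetical.
def _example_section_constraint() -> None:
    parser = GitConfigParser("/tmp/example.gitconfig", read_only=True)
    user = SectionConstraint(parser, "user")
    print(user.get_value("email", default="unknown@example.com"))
    user.release()  # releases the wrapped parser as well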
class _OMD(OrderedDict_OMD):
"""Ordered multi-dict."""
def __setitem__(self, key: str, value: _T) -> None:
super(_OMD, self).__setitem__(key, [value])
def add(self, key: str, value: Any) -> None:
if key not in self:
super(_OMD, self).__setitem__(key, [value])
return None
super(_OMD, self).__getitem__(key).append(value)
def setall(self, key: str, values: List[_T]) -> None:
super(_OMD, self).__setitem__(key, values)
def __getitem__(self, key: str) -> Any:
return super(_OMD, self).__getitem__(key)[-1]
def getlast(self, key: str) -> Any:
return super(_OMD, self).__getitem__(key)[-1]
def setlast(self, key: str, value: Any) -> None:
if key not in self:
super(_OMD, self).__setitem__(key, [value])
return
prior = super(_OMD, self).__getitem__(key)
prior[-1] = value
def get(self, key: str, default: Union[_T, None] = None) -> Union[_T, None]:
return super(_OMD, self).get(key, [default])[-1]
def getall(self, key: str) -> List[_T]:
return super(_OMD, self).__getitem__(key)
def items(self) -> List[Tuple[str, _T]]: # type: ignore[override]
"""List of (key, last value for key)."""
return [(k, self[k]) for k in self]
def items_all(self) -> List[Tuple[str, List[_T]]]:
"""List of (key, list of values for key)."""
return [(k, self.getall(k)) for k in self]
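# --- Illustrative aside (not part of the committed file) ---------------------
# A minimal sketch of the ordered multi-dict semantics above: item access
# returns the last value for a key, while getall()/items_all() keep them all.
def _example_omd() -> None:
    omd = _OMD()
    omd["fetch"] = "+refs/heads/*:refs/remotes/origin/*"
    omd.add("fetch", "+refs/tags/*:refs/tags/*")
    assert omd["fetch"] == "+refs/tags/*:refs/tags/*"  # last value wins
    assert len(omd.getall("fetch")) == 2  # but every value is preserved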
def get_config_path(config_level: Lit_config_levels) -> str:
    # We do not support an absolute path to the gitconfig on Windows,
    # use the global config instead
if is_win and config_level == "system":
config_level = "global"
if config_level == "system":
return "/etc/gitconfig"
elif config_level == "user":
config_home = os.environ.get("XDG_CONFIG_HOME") or osp.join(os.environ.get("HOME", "~"), ".config")
return osp.normpath(osp.expanduser(osp.join(config_home, "git", "config")))
elif config_level == "global":
return osp.normpath(osp.expanduser("~/.gitconfig"))
elif config_level == "repository":
raise ValueError("No repo to get repository configuration from. Use Repo._get_config_path")
else:
        # Should not reach here. Will raise ValueError if it does. Static typing will warn about missing elifs
assert_never(
config_level, # type: ignore[unreachable]
ValueError(f"Invalid configuration level: {config_level!r}"),
)
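# --- Illustrative aside (not part of the committed file) ---------------------
# A minimal sketch of the level-to-path mapping above; "repository" is the only
# level that needs a Repo instance and therefore raises here.
def _example_config_paths() -> None:
    for level in ("system", "user", "global"):
        print(level, "->", get_config_path(level))
    try:
        get_config_path("repository")
    except ValueError as err:
        print("repository level needs Repo._get_config_path:", err)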
class GitConfigParser(cp.RawConfigParser, metaclass=MetaParserBuilder):
"""Implements specifics required to read git style configuration files.
This variation behaves much like the git.config command such that the configuration
will be read on demand based on the filepath given during initialization.
The changes will automatically be written once the instance goes out of scope, but
can be triggered manually as well.
The configuration file will be locked if you intend to change values preventing other
instances to write concurrently.
:note:
The config is case-sensitive even when queried, hence section and option names
must match perfectly.
If used as a context manager, will release the locked file."""
# { Configuration
# The lock type determines the type of lock to use in new configuration readers.
# They must be compatible to the LockFile interface.
# A suitable alternative would be the BlockingLockFile
t_lock = LockFile
re_comment = re.compile(r"^\s*[#;]")
# } END configuration
optvalueonly_source = r"\s*(?P<option>[^:=\s][^:=]*)"
OPTVALUEONLY = re.compile(optvalueonly_source)
OPTCRE = re.compile(optvalueonly_source + r"\s*(?P<vi>[:=])\s*" + r"(?P<value>.*)$")
del optvalueonly_source
# list of RawConfigParser methods able to change the instance
_mutating_methods_ = ("add_section", "remove_section", "remove_option", "set")
def __init__(
self,
file_or_files: Union[None, PathLike, "BytesIO", Sequence[Union[PathLike, "BytesIO"]]] = None,
read_only: bool = True,
merge_includes: bool = True,
config_level: Union[Lit_config_levels, None] = None,
repo: Union["Repo", None] = None,
) -> None:
"""Initialize a configuration reader to read the given file_or_files and to
possibly allow changes to it by setting read_only False
:param file_or_files:
A single file path or file objects or multiple of these
:param read_only:
            If True, the ConfigParser may only read the data, but not change it.
If False, only a single file path or file object may be given. We will write back the changes
when they happen, or when the ConfigParser is released. This will not happen if other
configuration files have been included
:param merge_includes: if True, we will read files mentioned in [include] sections and merge their
contents into ours. This makes it impossible to write back an individual configuration file.
Thus, if you want to modify a single configuration file, turn this off to leave the original
dataset unaltered when reading it.
:param repo: Reference to repository to use if [includeIf] sections are found in configuration files.
"""
cp.RawConfigParser.__init__(self, dict_type=_OMD)
self._dict: Callable[..., _OMD] # type: ignore # mypy/typeshed bug?
self._defaults: _OMD
self._sections: _OMD # type: ignore # mypy/typeshed bug?
# Used in python 3, needs to stay in sync with sections for underlying implementation to work
if not hasattr(self, "_proxies"):
self._proxies = self._dict()
if file_or_files is not None:
self._file_or_files: Union[PathLike, "BytesIO", Sequence[Union[PathLike, "BytesIO"]]] = file_or_files
else:
if config_level is None:
if read_only:
self._file_or_files = [
get_config_path(cast(Lit_config_levels, f)) for f in CONFIG_LEVELS if f != "repository"
]
else:
raise ValueError("No configuration level or configuration files specified")
else:
self._file_or_files = [get_config_path(config_level)]
self._read_only = read_only
self._dirty = False
self._is_initialized = False
self._merge_includes = merge_includes
self._repo = repo
self._lock: Union["LockFile", None] = None
self._acquire_lock()
def _acquire_lock(self) -> None:
if not self._read_only:
if not self._lock:
if isinstance(self._file_or_files, (str, os.PathLike)):
file_or_files = self._file_or_files
elif isinstance(self._file_or_files, (tuple, list, Sequence)):
raise ValueError(
"Write-ConfigParsers can operate on a single file only, multiple files have been passed"
)
else:
file_or_files = self._file_or_files.name
# END get filename from handle/stream
# initialize lock base - we want to write
self._lock = self.t_lock(file_or_files)
# END lock check
self._lock._obtain_lock()
# END read-only check
def __del__(self) -> None:
"""Write pending changes if required and release locks"""
# NOTE: only consistent in PY2
self.release()
def __enter__(self) -> "GitConfigParser":
self._acquire_lock()
return self
def __exit__(self, *args: Any) -> None:
self.release()
def release(self) -> None:
"""Flush changes and release the configuration write lock. This instance must not be used anymore afterwards.
In Python 3, it's required to explicitly release locks and flush changes, as __del__ is not called
deterministically anymore."""
# checking for the lock here makes sure we do not raise during write()
        # in case an invalid parser was created that could not get a lock
if self.read_only or (self._lock and not self._lock._has_lock()):
return
try:
try:
self.write()
except IOError:
log.error("Exception during destruction of GitConfigParser", exc_info=True)
except ReferenceError:
# This happens in PY3 ... and usually means that some state cannot be written
# as the sections dict cannot be iterated
            # Usually when shutting down the interpreter, don't know how to fix this
pass
finally:
if self._lock is not None:
self._lock._release_lock()
def optionxform(self, optionstr: str) -> str:
"""Do not transform options in any way when writing"""
return optionstr
def _read(self, fp: Union[BufferedReader, IO[bytes]], fpname: str) -> None:
"""A direct copy of the py2.4 version of the super class's _read method
to assure it uses ordered dicts. Had to change one line to make it work.
        Future versions have this fixed, but in fact it's quite embarrassing for the
        guys not to have done it right in the first place!
Removed big comments to make it more compact.
Made sure it ignores initial whitespace as git uses tabs"""
cursect = None # None, or a dictionary
optname = None
lineno = 0
is_multi_line = False
e = None # None, or an exception
def string_decode(v: str) -> str:
if v[-1] == "\\":
v = v[:-1]
# end cut trailing escapes to prevent decode error
return v.encode(defenc).decode("unicode_escape")
# end
# end
while True:
# we assume to read binary !
line = fp.readline().decode(defenc)
if not line:
break
lineno = lineno + 1
# comment or blank line?
if line.strip() == "" or self.re_comment.match(line):
continue
if line.split(None, 1)[0].lower() == "rem" and line[0] in "rR":
# no leading whitespace
continue
# is it a section header?
mo = self.SECTCRE.match(line.strip())
if not is_multi_line and mo:
sectname: str = mo.group("header").strip()
if sectname in self._sections:
cursect = self._sections[sectname]
elif sectname == cp.DEFAULTSECT:
cursect = self._defaults
else:
cursect = self._dict((("__name__", sectname),))
self._sections[sectname] = cursect
self._proxies[sectname] = None
# So sections can't start with a continuation line
optname = None
# no section header in the file?
elif cursect is None:
raise cp.MissingSectionHeaderError(fpname, lineno, line)
# an option line?
elif not is_multi_line:
mo = self.OPTCRE.match(line)
if mo:
# We might just have handled the last line, which could contain a quotation we want to remove
optname, vi, optval = mo.group("option", "vi", "value")
if vi in ("=", ":") and ";" in optval and not optval.strip().startswith('"'):
pos = optval.find(";")
if pos != -1 and optval[pos - 1].isspace():
optval = optval[:pos]
optval = optval.strip()
if optval == '""':
optval = ""
# end handle empty string
optname = self.optionxform(optname.rstrip())
if len(optval) > 1 and optval[0] == '"' and optval[-1] != '"':
is_multi_line = True
optval = string_decode(optval[1:])
# end handle multi-line
# preserves multiple values for duplicate optnames
cursect.add(optname, optval)
else:
# check if it's an option with no value - it's just ignored by git
if not self.OPTVALUEONLY.match(line):
if not e:
e = cp.ParsingError(fpname)
e.append(lineno, repr(line))
continue
else:
line = line.rstrip()
if line.endswith('"'):
is_multi_line = False
line = line[:-1]
# end handle quotations
optval = cursect.getlast(optname)
cursect.setlast(optname, optval + string_decode(line))
# END parse section or option
# END while reading
# if any parsing errors occurred, raise an exception
if e:
raise e
def _has_includes(self) -> Union[bool, int]:
return self._merge_includes and len(self._included_paths())
def _included_paths(self) -> List[Tuple[str, str]]:
"""Return List all paths that must be included to configuration
as Tuples of (option, value).
"""
paths = []
for section in self.sections():
if section == "include":
paths += self.items(section)
match = CONDITIONAL_INCLUDE_REGEXP.search(section)
if match is None or self._repo is None:
continue
keyword = match.group(1)
value = match.group(2).strip()
if keyword in ["gitdir", "gitdir/i"]:
value = osp.expanduser(value)
if not any(value.startswith(s) for s in ["./", "/"]):
value = "**/" + value
if value.endswith("/"):
value += "**"
# Ensure that glob is always case insensitive if required.
if keyword.endswith("/i"):
value = re.sub(
r"[a-zA-Z]",
lambda m: "[{}{}]".format(m.group().lower(), m.group().upper()),
value,
)
if self._repo.git_dir:
if fnmatch.fnmatchcase(str(self._repo.git_dir), value):
paths += self.items(section)
elif keyword == "onbranch":
try:
branch_name = self._repo.active_branch.name
except TypeError:
# Ignore section if active branch cannot be retrieved.
continue
if fnmatch.fnmatchcase(branch_name, value):
paths += self.items(section)
return paths
def read(self) -> None: # type: ignore[override]
"""Reads the data stored in the files we have been initialized with. It will
ignore files that cannot be read, possibly leaving an empty configuration
:return: Nothing
:raise IOError: if a file cannot be handled"""
if self._is_initialized:
return None
self._is_initialized = True
files_to_read: List[Union[PathLike, IO]] = [""]
if isinstance(self._file_or_files, (str, os.PathLike)):
# for str or Path, as str is a type of Sequence
files_to_read = [self._file_or_files]
elif not isinstance(self._file_or_files, (tuple, list, Sequence)):
# could merge with above isinstance once runtime type known
files_to_read = [self._file_or_files]
else: # for lists or tuples
files_to_read = list(self._file_or_files)
# end assure we have a copy of the paths to handle
seen = set(files_to_read)
num_read_include_files = 0
while files_to_read:
file_path = files_to_read.pop(0)
file_ok = False
if hasattr(file_path, "seek"):
                # must be a file-object
file_path = cast(IO[bytes], file_path) # replace with assert to narrow type, once sure
self._read(file_path, file_path.name)
else:
# assume a path if it is not a file-object
file_path = cast(PathLike, file_path)
try:
with open(file_path, "rb") as fp:
file_ok = True
self._read(fp, fp.name)
except IOError:
continue
# Read includes and append those that we didn't handle yet
# We expect all paths to be normalized and absolute (and will assure that is the case)
if self._has_includes():
for _, include_path in self._included_paths():
if include_path.startswith("~"):
include_path = osp.expanduser(include_path)
if not osp.isabs(include_path):
if not file_ok:
continue
# end ignore relative paths if we don't know the configuration file path
file_path = cast(PathLike, file_path)
assert osp.isabs(file_path), "Need absolute paths to be sure our cycle checks will work"
include_path = osp.join(osp.dirname(file_path), include_path)
# end make include path absolute
include_path = osp.normpath(include_path)
if include_path in seen or not os.access(include_path, os.R_OK):
continue
seen.add(include_path)
# insert included file to the top to be considered first
files_to_read.insert(0, include_path)
num_read_include_files += 1
# each include path in configuration file
# end handle includes
# END for each file object to read
# If there was no file included, we can safely write back (potentially) the configuration file
        # without altering its meaning
if num_read_include_files == 0:
self._merge_includes = False
# end
def _write(self, fp: IO) -> None:
"""Write an .ini-format representation of the configuration state in
git compatible format"""
def write_section(name: str, section_dict: _OMD) -> None:
fp.write(("[%s]\n" % name).encode(defenc))
values: Sequence[str] # runtime only gets str in tests, but should be whatever _OMD stores
v: str
for (key, values) in section_dict.items_all():
if key == "__name__":
continue
for v in values:
fp.write(("\t%s = %s\n" % (key, self._value_to_string(v).replace("\n", "\n\t"))).encode(defenc))
# END if key is not __name__
# END section writing
if self._defaults:
write_section(cp.DEFAULTSECT, self._defaults)
value: _OMD
for name, value in self._sections.items():
write_section(name, value)
def items(self, section_name: str) -> List[Tuple[str, str]]: # type: ignore[override]
""":return: list((option, value), ...) pairs of all items in the given section"""
return [(k, v) for k, v in super(GitConfigParser, self).items(section_name) if k != "__name__"]
def items_all(self, section_name: str) -> List[Tuple[str, List[str]]]:
""":return: list((option, [values...]), ...) pairs of all items in the given section"""
rv = _OMD(self._defaults)
for k, vs in self._sections[section_name].items_all():
if k == "__name__":
continue
if k in rv and rv.getall(k) == vs:
continue
for v in vs:
rv.add(k, v)
return rv.items_all()
@needs_values
def write(self) -> None:
"""Write changes to our file, if there are changes at all
:raise IOError: if this is a read-only writer instance or if we could not obtain
a file lock"""
self._assure_writable("write")
if not self._dirty:
return None
if isinstance(self._file_or_files, (list, tuple)):
raise AssertionError(
"Cannot write back if there is not exactly a single file to write to, have %i files"
% len(self._file_or_files)
)
# end assert multiple files
if self._has_includes():
log.debug(
"Skipping write-back of configuration file as include files were merged in."
+ "Set merge_includes=False to prevent this."
)
return None
# end
fp = self._file_or_files
# we have a physical file on disk, so get a lock
is_file_lock = isinstance(fp, (str, os.PathLike, IOBase)) # can't use Pathlike until 3.5 dropped
if is_file_lock and self._lock is not None: # else raise Error?
self._lock._obtain_lock()
if not hasattr(fp, "seek"):
fp = cast(PathLike, fp)
with open(fp, "wb") as fp_open:
self._write(fp_open)
else:
fp = cast("BytesIO", fp)
fp.seek(0)
# make sure we do not overwrite into an existing file
if hasattr(fp, "truncate"):
fp.truncate()
self._write(fp)
def _assure_writable(self, method_name: str) -> None:
if self.read_only:
raise IOError("Cannot execute non-constant method %s.%s" % (self, method_name))
def add_section(self, section: str) -> None:
"""Assures added options will stay in order"""
return super(GitConfigParser, self).add_section(section)
@property
def read_only(self) -> bool:
""":return: True if this instance may change the configuration file"""
return self._read_only
def get_value(
self,
section: str,
option: str,
default: Union[int, float, str, bool, None] = None,
) -> Union[int, float, str, bool]:
# can default or return type include bool?
"""Get an option's value.
If multiple values are specified for this option in the section, the
last one specified is returned.
:param default:
If not None, the given default value will be returned in case
the option did not exist
:return: a properly typed value, either int, float or string
:raise TypeError: in case the value could not be understood
Otherwise the exceptions known to the ConfigParser will be raised."""
try:
valuestr = self.get(section, option)
except Exception:
if default is not None:
return default
raise
return self._string_to_value(valuestr)
def get_values(
self,
section: str,
option: str,
default: Union[int, float, str, bool, None] = None,
) -> List[Union[int, float, str, bool]]:
"""Get an option's values.
If multiple values are specified for this option in the section, all are
returned.
:param default:
If not None, a list containing the given default value will be
returned in case the option did not exist
:return: a list of properly typed values, either int, float or string
:raise TypeError: in case the value could not be understood
Otherwise the exceptions known to the ConfigParser will be raised."""
try:
self.sections()
lst = self._sections[section].getall(option)
except Exception:
if default is not None:
return [default]
raise
return [self._string_to_value(valuestr) for valuestr in lst]
def _string_to_value(self, valuestr: str) -> Union[int, float, str, bool]:
types = (int, float)
for numtype in types:
try:
val = numtype(valuestr)
# truncated value ?
if val != float(valuestr):
continue
return val
except (ValueError, TypeError):
continue
# END for each numeric type
# try boolean values as git uses them
vl = valuestr.lower()
if vl == "false":
return False
if vl == "true":
return True
if not isinstance(valuestr, str):
raise TypeError(
"Invalid value type: only int, long, float and str are allowed",
valuestr,
)
return valuestr
def _value_to_string(self, value: Union[str, bytes, int, float, bool]) -> str:
if isinstance(value, (int, float, bool)):
return str(value)
return force_text(value)
@needs_values
@set_dirty_and_flush_changes
def set_value(self, section: str, option: str, value: Union[str, bytes, int, float, bool]) -> "GitConfigParser":
"""Sets the given option in section to the given value.
It will create the section if required, and will not throw as opposed to the default
ConfigParser 'set' method.
:param section: Name of the section in which the option resides or should reside
:param option: Name of the options whose value to set
:param value: Value to set the option to. It must be a string or convertible
to a string
:return: this instance"""
if not self.has_section(section):
self.add_section(section)
self.set(section, option, self._value_to_string(value))
return self
@needs_values
@set_dirty_and_flush_changes
def add_value(self, section: str, option: str, value: Union[str, bytes, int, float, bool]) -> "GitConfigParser":
"""Adds a value for the given option in section.
It will create the section if required, and will not throw as opposed to the default
ConfigParser 'set' method. The value becomes the new value of the option as returned
        by 'get_value', and appends to the list of values returned by 'get_values'.
:param section: Name of the section in which the option resides or should reside
:param option: Name of the option
:param value: Value to add to option. It must be a string or convertible
to a string
:return: this instance"""
if not self.has_section(section):
self.add_section(section)
self._sections[section].add(option, self._value_to_string(value))
return self
def rename_section(self, section: str, new_name: str) -> "GitConfigParser":
"""rename the given section to new_name
        :raise ValueError: if section doesn't exist
:raise ValueError: if a section with new_name does already exist
:return: this instance
"""
if not self.has_section(section):
raise ValueError("Source section '%s' doesn't exist" % section)
if self.has_section(new_name):
raise ValueError("Destination section '%s' already exists" % new_name)
super(GitConfigParser, self).add_section(new_name)
new_section = self._sections[new_name]
for k, vs in self.items_all(section):
new_section.setall(k, vs)
# end for each value to copy
# This call writes back the changes, which is why we don't have the respective decorator
self.remove_section(section)
return self
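# --- Illustrative aside (not part of the committed file) ---------------------
# A minimal sketch of end-to-end GitConfigParser use: a read-only parser that
# aggregates the system/user/global files, and a writer used as a context
# manager so locking and write-back happen deterministically. The file path
# used for writing is hypothetical.
def _example_git_config_parser() -> None:
    reader = GitConfigParser(read_only=True)
    print(reader.get_value("user", "email", default="unknown@example.com"))
    reader.release()

    with GitConfigParser("/tmp/example.gitconfig", read_only=False) as writer:
        writer.set_value("core", "filemode", True)
        writer.add_value('remote "origin"', "fetch", "+refs/heads/*:refs/remotes/origin/*")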

View File

@@ -0,0 +1,63 @@
"""Module with our own gitdb implementation - it uses the git command"""
from git.util import bin_to_hex, hex_to_bin
from gitdb.base import OInfo, OStream
from gitdb.db import GitDB # @UnusedImport
from gitdb.db import LooseObjectDB
from gitdb.exc import BadObject
from git.exc import GitCommandError
# typing-------------------------------------------------
from typing import TYPE_CHECKING
from git.types import PathLike
if TYPE_CHECKING:
from git.cmd import Git
# --------------------------------------------------------
__all__ = ("GitCmdObjectDB", "GitDB")
class GitCmdObjectDB(LooseObjectDB):
"""A database representing the default git object store, which includes loose
objects, pack files and an alternates file
It will create objects only in the loose object database.
    :note: for now, we use the git command to do all the lookup, just until we
        have packs and the other implementations
"""
def __init__(self, root_path: PathLike, git: "Git") -> None:
"""Initialize this instance with the root and a git command"""
super(GitCmdObjectDB, self).__init__(root_path)
self._git = git
def info(self, binsha: bytes) -> OInfo:
hexsha, typename, size = self._git.get_object_header(bin_to_hex(binsha))
return OInfo(hex_to_bin(hexsha), typename, size)
def stream(self, binsha: bytes) -> OStream:
"""For now, all lookup is done by git itself"""
hexsha, typename, size, stream = self._git.stream_object_data(bin_to_hex(binsha))
return OStream(hex_to_bin(hexsha), typename, size, stream)
# { Interface
def partial_to_complete_sha_hex(self, partial_hexsha: str) -> bytes:
""":return: Full binary 20 byte sha from the given partial hexsha
:raise AmbiguousObjectName:
:raise BadObject:
:note: currently we only raise BadObject as git does not communicate
AmbiguousObjects separately"""
try:
hexsha, _typename, _size = self._git.get_object_header(partial_hexsha)
return hex_to_bin(hexsha)
except (GitCommandError, ValueError) as e:
raise BadObject(partial_hexsha) from e
# END handle exceptions
# } END interface
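# --- Illustrative aside (not part of the committed file) ---------------------
# A minimal sketch of querying the object database directly, assuming an
# existing repository at a hypothetical path; an abbreviated sha is resolved
# first and then its header is read through the git command.
def _example_object_db() -> None:
    import os.path as osp
    from git import Repo

    repo = Repo("/path/to/some/repo")  # hypothetical working copy
    db = GitCmdObjectDB(osp.join(repo.git_dir, "objects"), repo.git)
    binsha = db.partial_to_complete_sha_hex(repo.head.commit.hexsha[:10])
    info = db.info(binsha)
    print(info.type, info.size)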

View File

@@ -0,0 +1,662 @@
# diff.py
# Copyright (C) 2008, 2009 Michael Trier (mtrier@gmail.com) and contributors
#
# This module is part of GitPython and is released under
# the BSD License: http://www.opensource.org/licenses/bsd-license.php
import re
from git.cmd import handle_process_output
from git.compat import defenc
from git.util import finalize_process, hex_to_bin
from .objects.blob import Blob
from .objects.util import mode_str_to_int
# typing ------------------------------------------------------------------
from typing import (
Any,
Iterator,
List,
Match,
Optional,
Tuple,
Type,
TypeVar,
Union,
TYPE_CHECKING,
cast,
)
from git.types import PathLike, Literal
if TYPE_CHECKING:
from .objects.tree import Tree
from .objects import Commit
from git.repo.base import Repo
from git.objects.base import IndexObject
from subprocess import Popen
from git import Git
Lit_change_type = Literal["A", "D", "C", "M", "R", "T", "U"]
# def is_change_type(inp: str) -> TypeGuard[Lit_change_type]:
# # return True
# return inp in ['A', 'D', 'C', 'M', 'R', 'T', 'U']
# ------------------------------------------------------------------------
__all__ = ("Diffable", "DiffIndex", "Diff", "NULL_TREE")
# Special object to compare against the empty tree in diffs
NULL_TREE = object()
_octal_byte_re = re.compile(b"\\\\([0-9]{3})")
def _octal_repl(matchobj: Match) -> bytes:
value = matchobj.group(1)
value = int(value, 8)
value = bytes(bytearray((value,)))
return value
def decode_path(path: bytes, has_ab_prefix: bool = True) -> Optional[bytes]:
if path == b"/dev/null":
return None
if path.startswith(b'"') and path.endswith(b'"'):
path = path[1:-1].replace(b"\\n", b"\n").replace(b"\\t", b"\t").replace(b'\\"', b'"').replace(b"\\\\", b"\\")
path = _octal_byte_re.sub(_octal_repl, path)
if has_ab_prefix:
assert path.startswith(b"a/") or path.startswith(b"b/")
path = path[2:]
return path
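# --- Illustrative aside (not part of the committed file) ---------------------
# A minimal sketch of decode_path() above: git quotes unusual paths and escapes
# raw bytes as backslash-octal sequences, both of which are undone here.
def _example_decode_path() -> None:
    assert decode_path(b"a/plain.txt") == b"plain.txt"  # strips the a/ prefix
    assert decode_path(b'"a/sp\\303\\244t.txt"') == b"sp\xc3\xa4t.txt"  # unquote + octal
    assert decode_path(b"/dev/null") is None  # the absent side of an add/delete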
class Diffable(object):
"""Common interface for all object that can be diffed against another object of compatible type.
:note:
Subclasses require a repo member as it is the case for Object instances, for practical
reasons we do not derive from Object."""
__slots__ = ()
# standin indicating you want to diff against the index
class Index(object):
pass
def _process_diff_args(
self, args: List[Union[str, "Diffable", Type["Diffable.Index"], object]]
) -> List[Union[str, "Diffable", Type["Diffable.Index"], object]]:
"""
:return:
possibly altered version of the given args list.
Method is called right before git command execution.
Subclasses can use it to alter the behaviour of the superclass"""
return args
def diff(
self,
other: Union[Type["Index"], "Tree", "Commit", None, str, object] = Index,
paths: Union[PathLike, List[PathLike], Tuple[PathLike, ...], None] = None,
create_patch: bool = False,
**kwargs: Any,
) -> "DiffIndex":
"""Creates diffs between two items being trees, trees and index or an
index and the working tree. It will detect renames automatically.
:param other:
Is the item to compare us with.
If None, we will be compared to the working tree.
If Treeish, it will be compared against the respective tree
If Index ( type ), it will be compared against the index.
If git.NULL_TREE, it will compare against the empty tree.
            It defaults to Index to assure the method will not fail by default
on bare repositories.
:param paths:
is a list of paths or a single path to limit the diff to.
            Only diffs touching at least one of the given paths will be included.
:param create_patch:
If True, the returned Diff contains a detailed patch that if applied
makes the self to other. Patches are somewhat costly as blobs have to be read
and diffed.
:param kwargs:
Additional arguments passed to git-diff, such as
R=True to swap both sides of the diff.
:return: git.DiffIndex
:note:
            On a bare repository, 'other' needs to be provided as Index or as
            Tree/Commit, or a git command error will occur
args: List[Union[PathLike, Diffable, Type["Diffable.Index"], object]] = []
args.append("--abbrev=40") # we need full shas
args.append("--full-index") # get full index paths, not only filenames
# remove default '-M' arg (check for renames) if user is overriding it
if not any(x in kwargs for x in ('find_renames', 'no_renames', 'M')):
args.append("-M")
if create_patch:
args.append("-p")
else:
args.append("--raw")
args.append("-z")
# in any way, assure we don't see colored output,
# fixes https://github.com/gitpython-developers/GitPython/issues/172
args.append("--no-color")
if paths is not None and not isinstance(paths, (tuple, list)):
paths = [paths]
if hasattr(self, "Has_Repo"):
self.repo: "Repo" = self.repo
diff_cmd = self.repo.git.diff
if other is self.Index:
args.insert(0, "--cached")
elif other is NULL_TREE:
args.insert(0, "-r") # recursive diff-tree
args.insert(0, "--root")
diff_cmd = self.repo.git.diff_tree
elif other is not None:
args.insert(0, "-r") # recursive diff-tree
args.insert(0, other)
diff_cmd = self.repo.git.diff_tree
args.insert(0, self)
# paths is list here or None
if paths:
args.append("--")
args.extend(paths)
# END paths handling
kwargs["as_process"] = True
proc = diff_cmd(*self._process_diff_args(args), **kwargs)
diff_method = Diff._index_from_patch_format if create_patch else Diff._index_from_raw_format
index = diff_method(self.repo, proc)
proc.wait()
return index
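# --- Illustrative aside (not part of the committed file) ---------------------
# A minimal sketch of Diffable.diff() from the caller's side, using a
# hypothetical repository path: commits, the index and the working tree are
# all compared through the same interface.
def _example_diffable_usage() -> None:
    from git import Repo

    repo = Repo("/path/to/some/repo")  # hypothetical
    commit = repo.head.commit
    staged = commit.diff()  # commit vs. index (the default Diffable.Index)
    unstaged = commit.diff(None)  # commit vs. working tree
    patched = commit.diff("HEAD~1", create_patch=True)  # includes patch text
    print(len(staged), len(unstaged), len(patched))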
T_Diff = TypeVar("T_Diff", bound="Diff")
class DiffIndex(List[T_Diff]):
"""Implements an Index for diffs, allowing a list of Diffs to be queried by
the diff properties.
The class improves the diff handling convenience"""
# change type invariant identifying possible ways a blob can have changed
# A = Added
# D = Deleted
# R = Renamed
# M = Modified
# T = Changed in the type
change_type = ("A", "C", "D", "R", "M", "T")
def iter_change_type(self, change_type: Lit_change_type) -> Iterator[T_Diff]:
"""
:return:
iterator yielding Diff instances that match the given change_type
:param change_type:
Member of DiffIndex.change_type, namely:
* 'A' for added paths
* 'D' for deleted paths
* 'R' for renamed paths
* 'M' for paths with modified data
* 'T' for changed in the type paths
"""
if change_type not in self.change_type:
raise ValueError("Invalid change type: %s" % change_type)
for diffidx in self:
if diffidx.change_type == change_type:
yield diffidx
elif change_type == "A" and diffidx.new_file:
yield diffidx
elif change_type == "D" and diffidx.deleted_file:
yield diffidx
elif change_type == "C" and diffidx.copied_file:
yield diffidx
elif change_type == "R" and diffidx.renamed:
yield diffidx
elif change_type == "M" and diffidx.a_blob and diffidx.b_blob and diffidx.a_blob != diffidx.b_blob:
yield diffidx
# END for each diff
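# --- Illustrative aside (not part of the committed file) ---------------------
# A minimal sketch of filtering a DiffIndex by change type, e.g. to report
# renames and additions between two commits of a hypothetical repository.
def _example_iter_change_type() -> None:
    from git import Repo

    repo = Repo("/path/to/some/repo")  # hypothetical
    diff_index = repo.commit("HEAD~5").diff("HEAD")
    for diff in diff_index.iter_change_type("R"):
        print("renamed:", diff.rename_from, "->", diff.rename_to)
    for diff in diff_index.iter_change_type("A"):
        print("added:", diff.b_path)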
class Diff(object):
"""A Diff contains diff information between two Trees.
It contains two sides a and b of the diff, members are prefixed with
"a" and "b" respectively to inidcate that.
Diffs keep information about the changed blob objects, the file mode, renames,
deletions and new files.
There are a few cases where None has to be expected as member variable value:
``New File``::
a_mode is None
a_blob is None
a_path is None
``Deleted File``::
b_mode is None
b_blob is None
b_path is None
``Working Tree Blobs``
When comparing to working trees, the working tree blob will have a null hexsha
as a corresponding object does not yet exist. The mode will be null as well.
But the path will be available though.
If it is listed in a diff the working tree version of the file must
be different to the version in the index or tree, and hence has been modified."""
# precompiled regex
re_header = re.compile(
rb"""
^diff[ ]--git
[ ](?P<a_path_fallback>"?[ab]/.+?"?)[ ](?P<b_path_fallback>"?[ab]/.+?"?)\n
(?:^old[ ]mode[ ](?P<old_mode>\d+)\n
^new[ ]mode[ ](?P<new_mode>\d+)(?:\n|$))?
(?:^similarity[ ]index[ ]\d+%\n
^rename[ ]from[ ](?P<rename_from>.*)\n
^rename[ ]to[ ](?P<rename_to>.*)(?:\n|$))?
(?:^new[ ]file[ ]mode[ ](?P<new_file_mode>.+)(?:\n|$))?
(?:^deleted[ ]file[ ]mode[ ](?P<deleted_file_mode>.+)(?:\n|$))?
(?:^similarity[ ]index[ ]\d+%\n
^copy[ ]from[ ].*\n
^copy[ ]to[ ](?P<copied_file_name>.*)(?:\n|$))?
(?:^index[ ](?P<a_blob_id>[0-9A-Fa-f]+)
\.\.(?P<b_blob_id>[0-9A-Fa-f]+)[ ]?(?P<b_mode>.+)?(?:\n|$))?
(?:^---[ ](?P<a_path>[^\t\n\r\f\v]*)[\t\r\f\v]*(?:\n|$))?
(?:^\+\+\+[ ](?P<b_path>[^\t\n\r\f\v]*)[\t\r\f\v]*(?:\n|$))?
""",
re.VERBOSE | re.MULTILINE,
)
# can be used for comparisons
NULL_HEX_SHA = "0" * 40
NULL_BIN_SHA = b"\0" * 20
__slots__ = (
"a_blob",
"b_blob",
"a_mode",
"b_mode",
"a_rawpath",
"b_rawpath",
"new_file",
"deleted_file",
"copied_file",
"raw_rename_from",
"raw_rename_to",
"diff",
"change_type",
"score",
)
def __init__(
self,
repo: "Repo",
a_rawpath: Optional[bytes],
b_rawpath: Optional[bytes],
a_blob_id: Union[str, bytes, None],
b_blob_id: Union[str, bytes, None],
a_mode: Union[bytes, str, None],
b_mode: Union[bytes, str, None],
new_file: bool,
deleted_file: bool,
copied_file: bool,
raw_rename_from: Optional[bytes],
raw_rename_to: Optional[bytes],
diff: Union[str, bytes, None],
change_type: Optional[Lit_change_type],
score: Optional[int],
) -> None:
assert a_rawpath is None or isinstance(a_rawpath, bytes)
assert b_rawpath is None or isinstance(b_rawpath, bytes)
self.a_rawpath = a_rawpath
self.b_rawpath = b_rawpath
self.a_mode = mode_str_to_int(a_mode) if a_mode else None
self.b_mode = mode_str_to_int(b_mode) if b_mode else None
# Determine whether this diff references a submodule, if it does then
# we need to overwrite "repo" to the corresponding submodule's repo instead
if repo and a_rawpath:
for submodule in repo.submodules:
if submodule.path == a_rawpath.decode(defenc, "replace"):
if submodule.module_exists():
repo = submodule.module()
break
self.a_blob: Union["IndexObject", None]
if a_blob_id is None or a_blob_id == self.NULL_HEX_SHA:
self.a_blob = None
else:
self.a_blob = Blob(repo, hex_to_bin(a_blob_id), mode=self.a_mode, path=self.a_path)
self.b_blob: Union["IndexObject", None]
if b_blob_id is None or b_blob_id == self.NULL_HEX_SHA:
self.b_blob = None
else:
self.b_blob = Blob(repo, hex_to_bin(b_blob_id), mode=self.b_mode, path=self.b_path)
self.new_file: bool = new_file
self.deleted_file: bool = deleted_file
self.copied_file: bool = copied_file
# be clear and use None instead of empty strings
assert raw_rename_from is None or isinstance(raw_rename_from, bytes)
assert raw_rename_to is None or isinstance(raw_rename_to, bytes)
self.raw_rename_from = raw_rename_from or None
self.raw_rename_to = raw_rename_to or None
self.diff = diff
self.change_type: Union[Lit_change_type, None] = change_type
self.score = score
def __eq__(self, other: object) -> bool:
for name in self.__slots__:
if getattr(self, name) != getattr(other, name):
return False
# END for each name
return True
def __ne__(self, other: object) -> bool:
return not (self == other)
def __hash__(self) -> int:
return hash(tuple(getattr(self, n) for n in self.__slots__))
def __str__(self) -> str:
h: str = "%s"
if self.a_blob:
h %= self.a_blob.path
elif self.b_blob:
h %= self.b_blob.path
msg: str = ""
line = None # temp line
line_length = 0 # line length
for b, n in zip((self.a_blob, self.b_blob), ("lhs", "rhs")):
if b:
line = "\n%s: %o | %s" % (n, b.mode, b.hexsha)
else:
line = "\n%s: None" % n
# END if blob is not None
line_length = max(len(line), line_length)
msg += line
# END for each blob
# add headline
h += "\n" + "=" * line_length
if self.deleted_file:
msg += "\nfile deleted in rhs"
if self.new_file:
msg += "\nfile added in rhs"
if self.copied_file:
msg += "\nfile %r copied from %r" % (self.b_path, self.a_path)
if self.rename_from:
msg += "\nfile renamed from %r" % self.rename_from
if self.rename_to:
msg += "\nfile renamed to %r" % self.rename_to
if self.diff:
msg += "\n---"
try:
msg += self.diff.decode(defenc) if isinstance(self.diff, bytes) else self.diff
except UnicodeDecodeError:
msg += "OMITTED BINARY DATA"
# end handle encoding
msg += "\n---"
# END diff info
# Python2 silliness: have to assure we convert our likely to be unicode object to a string with the
# right encoding. Otherwise it tries to convert it using ascii, which may fail ungracefully
res = h + msg
# end
return res
@property
def a_path(self) -> Optional[str]:
return self.a_rawpath.decode(defenc, "replace") if self.a_rawpath else None
@property
def b_path(self) -> Optional[str]:
return self.b_rawpath.decode(defenc, "replace") if self.b_rawpath else None
@property
def rename_from(self) -> Optional[str]:
return self.raw_rename_from.decode(defenc, "replace") if self.raw_rename_from else None
@property
def rename_to(self) -> Optional[str]:
return self.raw_rename_to.decode(defenc, "replace") if self.raw_rename_to else None
@property
def renamed(self) -> bool:
""":returns: True if the blob of our diff has been renamed
:note: This property is deprecated, please use ``renamed_file`` instead.
"""
return self.renamed_file
@property
def renamed_file(self) -> bool:
""":returns: True if the blob of our diff has been renamed"""
return self.rename_from != self.rename_to
@classmethod
def _pick_best_path(cls, path_match: bytes, rename_match: bytes, path_fallback_match: bytes) -> Optional[bytes]:
if path_match:
return decode_path(path_match)
if rename_match:
return decode_path(rename_match, has_ab_prefix=False)
if path_fallback_match:
return decode_path(path_fallback_match)
return None
@classmethod
def _index_from_patch_format(cls, repo: "Repo", proc: Union["Popen", "Git.AutoInterrupt"]) -> DiffIndex:
"""Create a new DiffIndex from the given text which must be in patch format
:param repo: is the repository we are operating on - it is required
:param stream: result of 'git diff' as a stream (supporting file protocol)
:return: git.DiffIndex"""
## FIXME: Here SLURPING raw, need to re-phrase header-regexes linewise.
text_list: List[bytes] = []
handle_process_output(proc, text_list.append, None, finalize_process, decode_streams=False)
# for now, we have to bake the stream
text = b"".join(text_list)
index: "DiffIndex" = DiffIndex()
previous_header: Union[Match[bytes], None] = None
header: Union[Match[bytes], None] = None
a_path, b_path = None, None # for mypy
a_mode, b_mode = None, None # for mypy
for _header in cls.re_header.finditer(text):
(
a_path_fallback,
b_path_fallback,
old_mode,
new_mode,
rename_from,
rename_to,
new_file_mode,
deleted_file_mode,
copied_file_name,
a_blob_id,
b_blob_id,
b_mode,
a_path,
b_path,
) = _header.groups()
new_file, deleted_file, copied_file = (
bool(new_file_mode),
bool(deleted_file_mode),
bool(copied_file_name),
)
a_path = cls._pick_best_path(a_path, rename_from, a_path_fallback)
b_path = cls._pick_best_path(b_path, rename_to, b_path_fallback)
# Our only means to find the actual text is to see what has not been matched by our regex,
# and then retro-actively assign it to our index
if previous_header is not None:
index[-1].diff = text[previous_header.end() : _header.start()]
# end assign actual diff
# Make sure the mode is set if the path is set. Otherwise the resulting blob is invalid
# We just use the one mode we should have parsed
a_mode = old_mode or deleted_file_mode or (a_path and (b_mode or new_mode or new_file_mode))
b_mode = b_mode or new_mode or new_file_mode or (b_path and a_mode)
index.append(
Diff(
repo,
a_path,
b_path,
a_blob_id and a_blob_id.decode(defenc),
b_blob_id and b_blob_id.decode(defenc),
a_mode and a_mode.decode(defenc),
b_mode and b_mode.decode(defenc),
new_file,
deleted_file,
copied_file,
rename_from,
rename_to,
None,
None,
None,
)
)
previous_header = _header
header = _header
# end for each header we parse
if index and header:
index[-1].diff = text[header.end() :]
# end assign last diff
return index
@staticmethod
def _handle_diff_line(lines_bytes: bytes, repo: "Repo", index: DiffIndex) -> None:
lines = lines_bytes.decode(defenc)
# Discard everything before the first colon, and the colon itself.
_, _, lines = lines.partition(":")
for line in lines.split("\x00:"):
if not line:
# The line data is empty, skip
continue
meta, _, path = line.partition("\x00")
path = path.rstrip("\x00")
a_blob_id: Optional[str]
b_blob_id: Optional[str]
old_mode, new_mode, a_blob_id, b_blob_id, _change_type = meta.split(None, 4)
# Change type can be R100
# R: status letter
# 100: score (in case of copy and rename)
# assert is_change_type(_change_type[0]), f"Unexpected value for change_type received: {_change_type[0]}"
change_type: Lit_change_type = cast(Lit_change_type, _change_type[0])
score_str = "".join(_change_type[1:])
score = int(score_str) if score_str.isdigit() else None
path = path.strip()
a_path = path.encode(defenc)
b_path = path.encode(defenc)
deleted_file = False
new_file = False
copied_file = False
rename_from = None
rename_to = None
            # NOTE: We cannot conclude the change type from the existence of a blob,
            # as diffs against the working tree do not have blobs yet
if change_type == "D":
b_blob_id = None # Optional[str]
deleted_file = True
elif change_type == "A":
a_blob_id = None
new_file = True
elif change_type == "C":
copied_file = True
a_path_str, b_path_str = path.split("\x00", 1)
a_path = a_path_str.encode(defenc)
b_path = b_path_str.encode(defenc)
elif change_type == "R":
a_path_str, b_path_str = path.split("\x00", 1)
a_path = a_path_str.encode(defenc)
b_path = b_path_str.encode(defenc)
rename_from, rename_to = a_path, b_path
elif change_type == "T":
# Nothing to do
pass
# END add/remove handling
diff = Diff(
repo,
a_path,
b_path,
a_blob_id,
b_blob_id,
old_mode,
new_mode,
new_file,
deleted_file,
copied_file,
rename_from,
rename_to,
"",
change_type,
score,
)
index.append(diff)
@classmethod
def _index_from_raw_format(cls, repo: "Repo", proc: "Popen") -> "DiffIndex":
"""Create a new DiffIndex from the given stream which must be in raw format.
:return: git.DiffIndex"""
# handles
# :100644 100644 687099101... 37c5e30c8... M .gitignore
index: "DiffIndex" = DiffIndex()
handle_process_output(
proc,
lambda byt: cls._handle_diff_line(byt, repo, index),
None,
finalize_process,
decode_streams=False,
)
return index

View File

@@ -0,0 +1,186 @@
# exc.py
# Copyright (C) 2008, 2009 Michael Trier (mtrier@gmail.com) and contributors
#
# This module is part of GitPython and is released under
# the BSD License: http://www.opensource.org/licenses/bsd-license.php
""" Module containing all exceptions thrown throughout the git package, """
from gitdb.exc import BadName # NOQA @UnusedWildImport skipcq: PYL-W0401, PYL-W0614
from gitdb.exc import * # NOQA @UnusedWildImport skipcq: PYL-W0401, PYL-W0614
from git.compat import safe_decode
from git.util import remove_password_if_present
# typing ----------------------------------------------------
from typing import List, Sequence, Tuple, Union, TYPE_CHECKING
from git.types import PathLike
if TYPE_CHECKING:
from git.repo.base import Repo
# ------------------------------------------------------------------
class GitError(Exception):
"""Base class for all package exceptions"""
class InvalidGitRepositoryError(GitError):
"""Thrown if the given repository appears to have an invalid format."""
class WorkTreeRepositoryUnsupported(InvalidGitRepositoryError):
"""Thrown to indicate we can't handle work tree repositories"""
class NoSuchPathError(GitError, OSError):
"""Thrown if a path could not be access by the system."""
class UnsafeProtocolError(GitError):
"""Thrown if unsafe protocols are passed without being explicitly allowed."""
class UnsafeOptionError(GitError):
"""Thrown if unsafe options are passed without being explicitly allowed."""
class CommandError(GitError):
"""Base class for exceptions thrown at every stage of `Popen()` execution.
:param command:
A non-empty list of argv comprising the command-line.
"""
    #: A unicode print-format with 2 `%s` placeholders, for `<cmdline>` and the rest,
#: e.g.
#: "'%s' failed%s"
_msg = "Cmd('%s') failed%s"
def __init__(
self,
command: Union[List[str], Tuple[str, ...], str],
status: Union[str, int, None, Exception] = None,
stderr: Union[bytes, str, None] = None,
stdout: Union[bytes, str, None] = None,
) -> None:
if not isinstance(command, (tuple, list)):
command = command.split()
self.command = remove_password_if_present(command)
self.status = status
if status:
if isinstance(status, Exception):
status = "%s('%s')" % (type(status).__name__, safe_decode(str(status)))
else:
try:
status = "exit code(%s)" % int(status)
except (ValueError, TypeError):
s = safe_decode(str(status))
status = "'%s'" % s if isinstance(status, str) else s
self._cmd = safe_decode(self.command[0])
self._cmdline = " ".join(safe_decode(i) for i in self.command)
self._cause = status and " due to: %s" % status or "!"
stdout_decode = safe_decode(stdout)
stderr_decode = safe_decode(stderr)
self.stdout = stdout_decode and "\n stdout: '%s'" % stdout_decode or ""
self.stderr = stderr_decode and "\n stderr: '%s'" % stderr_decode or ""
def __str__(self) -> str:
return (self._msg + "\n cmdline: %s%s%s") % (
self._cmd,
self._cause,
self._cmdline,
self.stdout,
self.stderr,
)
class GitCommandNotFound(CommandError):
"""Thrown if we cannot find the `git` executable in the PATH or at the path given by
the GIT_PYTHON_GIT_EXECUTABLE environment variable"""
def __init__(self, command: Union[List[str], Tuple[str], str], cause: Union[str, Exception]) -> None:
super(GitCommandNotFound, self).__init__(command, cause)
self._msg = "Cmd('%s') not found%s"
class GitCommandError(CommandError):
"""Thrown if execution of the git command fails with non-zero status code."""
def __init__(
self,
command: Union[List[str], Tuple[str, ...], str],
status: Union[str, int, None, Exception] = None,
stderr: Union[bytes, str, None] = None,
stdout: Union[bytes, str, None] = None,
) -> None:
super(GitCommandError, self).__init__(command, status, stderr, stdout)
class CheckoutError(GitError):
"""Thrown if a file could not be checked out from the index as it contained
changes.
The .failed_files attribute contains a list of relative paths that failed
to be checked out as they contained changes that did not exist in the index.
The .failed_reasons attribute contains a string informing about the actual
cause of the issue.
The .valid_files attribute contains a list of relative paths to files that
were checked out successfully and hence match the version stored in the
index"""
def __init__(
self,
message: str,
failed_files: Sequence[PathLike],
valid_files: Sequence[PathLike],
failed_reasons: List[str],
) -> None:
Exception.__init__(self, message)
self.failed_files = failed_files
self.failed_reasons = failed_reasons
self.valid_files = valid_files
def __str__(self) -> str:
return Exception.__str__(self) + ":%s" % self.failed_files
class CacheError(GitError):
"""Base for all errors related to the git index, which is called cache internally"""
class UnmergedEntriesError(CacheError):
"""Thrown if an operation cannot proceed as there are still unmerged
entries in the cache"""
class HookExecutionError(CommandError):
"""Thrown if a hook exits with a non-zero exit code. It provides access to the exit code and the string returned
via standard output"""
def __init__(
self,
command: Union[List[str], Tuple[str, ...], str],
status: Union[str, int, None, Exception],
stderr: Union[bytes, str, None] = None,
stdout: Union[bytes, str, None] = None,
) -> None:
super(HookExecutionError, self).__init__(command, status, stderr, stdout)
self._msg = "Hook('%s') failed%s"
class RepositoryDirtyError(GitError):
"""Thrown whenever an operation on a repository fails as it has uncommitted changes that would be overwritten"""
def __init__(self, repo: "Repo", message: str) -> None:
self.repo = repo
self.message = message
def __str__(self) -> str:
return "Operation cannot be performed on %r: %s" % (self.repo, self.message)

View File

@@ -0,0 +1,4 @@
"""Initialize the index package"""
# flake8: noqa
from .base import *
from .typ import *

File diff suppressed because it is too large

View File

@@ -0,0 +1,444 @@
# Contains standalone functions to accompany the index implementation and make it
# more versatile
# NOTE: Autodoc hates it if this is a docstring
from io import BytesIO
from pathlib import Path
import os
from stat import (
S_IFDIR,
S_IFLNK,
S_ISLNK,
S_ISDIR,
S_IFMT,
S_IFREG,
S_IXUSR,
)
import subprocess
from git.cmd import PROC_CREATIONFLAGS, handle_process_output
from git.compat import (
defenc,
force_text,
force_bytes,
is_posix,
is_win,
safe_decode,
)
from git.exc import UnmergedEntriesError, HookExecutionError
from git.objects.fun import (
tree_to_stream,
traverse_tree_recursive,
traverse_trees_recursive,
)
from git.util import IndexFileSHA1Writer, finalize_process
from gitdb.base import IStream
from gitdb.typ import str_tree_type
import os.path as osp
from .typ import BaseIndexEntry, IndexEntry, CE_NAMEMASK, CE_STAGESHIFT
from .util import pack, unpack
# typing -----------------------------------------------------------------------------
from typing import Dict, IO, List, Sequence, TYPE_CHECKING, Tuple, Type, Union, cast
from git.types import PathLike
if TYPE_CHECKING:
from .base import IndexFile
from git.db import GitCmdObjectDB
from git.objects.tree import TreeCacheTup
# from git.objects.fun import EntryTupOrNone
# ------------------------------------------------------------------------------------
S_IFGITLINK = S_IFLNK | S_IFDIR # a submodule
CE_NAMEMASK_INV = ~CE_NAMEMASK
__all__ = (
"write_cache",
"read_cache",
"write_tree_from_cache",
"entry_key",
"stat_mode_to_index_mode",
"S_IFGITLINK",
"run_commit_hook",
"hook_path",
)
def hook_path(name: str, git_dir: PathLike) -> str:
""":return: path to the given named hook in the given git repository directory"""
return osp.join(git_dir, "hooks", name)
def _has_file_extension(path):
return osp.splitext(path)[1]
def run_commit_hook(name: str, index: "IndexFile", *args: str) -> None:
"""Run the commit hook of the given name. Silently ignores hooks that do not exist.
:param name: name of hook, like 'pre-commit'
:param index: IndexFile instance
:param args: arguments passed to hook file
:raises HookExecutionError:"""
hp = hook_path(name, index.repo.git_dir)
if not os.access(hp, os.X_OK):
return None
env = os.environ.copy()
env["GIT_INDEX_FILE"] = safe_decode(str(index.path))
env["GIT_EDITOR"] = ":"
cmd = [hp]
try:
if is_win and not _has_file_extension(hp):
# Windows only uses extensions to determine how to open files
# (doesn't understand shebangs). Try using bash to run the hook.
relative_hp = Path(hp).relative_to(index.repo.working_dir).as_posix()
cmd = ["bash.exe", relative_hp]
cmd = subprocess.Popen(
cmd + list(args),
env=env,
stdout=subprocess.PIPE,
stderr=subprocess.PIPE,
cwd=index.repo.working_dir,
close_fds=is_posix,
creationflags=PROC_CREATIONFLAGS,
)
except Exception as ex:
raise HookExecutionError(hp, ex) from ex
else:
stdout_list: List[str] = []
stderr_list: List[str] = []
handle_process_output(cmd, stdout_list.append, stderr_list.append, finalize_process)
stdout = "".join(stdout_list)
stderr = "".join(stderr_list)
if cmd.returncode != 0:
stdout = force_text(stdout, defenc)
stderr = force_text(stderr, defenc)
raise HookExecutionError(hp, cmd.returncode, stderr, stdout)
# end handle return code
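# A minimal usage sketch, assuming a repository at the hypothetical path
# "/tmp/example-repo": resolve a hook path and run the hook by hand.
def _demo_run_commit_hook() -> None:
    from git.repo import Repo  # local import to avoid an import cycle at module load

    repo = Repo("/tmp/example-repo")
    print("pre-commit hook expected at:", hook_path("pre-commit", repo.git_dir))
    # Silently returns if the hook is missing or not executable; raises
    # HookExecutionError if the hook exits with a non-zero status.
    run_commit_hook("pre-commit", repo.index)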
def stat_mode_to_index_mode(mode: int) -> int:
"""Convert the given mode from a stat call to the corresponding index mode
and return it"""
if S_ISLNK(mode): # symlinks
return S_IFLNK
if S_ISDIR(mode) or S_IFMT(mode) == S_IFGITLINK: # submodules
return S_IFGITLINK
return S_IFREG | (mode & S_IXUSR and 0o755 or 0o644) # blobs with or without executable bit
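# A small sketch of how a filesystem mode from os.stat() maps onto an index mode;
# the path used here is only an assumption for illustration.
def _demo_stat_mode_to_index_mode() -> None:
    mode = os.stat("/etc/hostname").st_mode
    print(oct(stat_mode_to_index_mode(mode)))  # e.g. 0o100644 for a regular, non-executable file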
def write_cache(
entries: Sequence[Union[BaseIndexEntry, "IndexEntry"]],
stream: IO[bytes],
extension_data: Union[None, bytes] = None,
ShaStreamCls: Type[IndexFileSHA1Writer] = IndexFileSHA1Writer,
) -> None:
"""Write the cache represented by entries to a stream
:param entries: **sorted** list of entries
:param stream: stream to wrap into the ShaStreamCls - it is used for
final output.
:param ShaStreamCls: Type to use when writing to the stream. It produces a sha
while writing to it, before the data is passed on to the wrapped stream
:param extension_data: any kind of data to write as a trailer, it must begin with
a 4 byte identifier, followed by its size ( 4 bytes )"""
# wrap the stream into a compatible writer
stream_sha = ShaStreamCls(stream)
tell = stream_sha.tell
write = stream_sha.write
# header
version = 2
write(b"DIRC")
write(pack(">LL", version, len(entries)))
# body
for entry in entries:
beginoffset = tell()
write(entry.ctime_bytes) # ctime
write(entry.mtime_bytes) # mtime
path_str = str(entry.path)
path: bytes = force_bytes(path_str, encoding=defenc)
plen = len(path) & CE_NAMEMASK # path length
assert plen == len(path), "Path %s too long to fit into index" % entry.path
flags = plen | (entry.flags & CE_NAMEMASK_INV) # clear possible previous values
write(
pack(
">LLLLLL20sH",
entry.dev,
entry.inode,
entry.mode,
entry.uid,
entry.gid,
entry.size,
entry.binsha,
flags,
)
)
write(path)
real_size = (tell() - beginoffset + 8) & ~7
write(b"\0" * ((beginoffset + real_size) - tell()))
# END for each entry
# write previously cached extensions data
if extension_data is not None:
stream_sha.write(extension_data)
# write the sha over the content
stream_sha.write_sha()
def read_header(stream: IO[bytes]) -> Tuple[int, int]:
"""Return tuple(version_long, num_entries) from the given stream"""
type_id = stream.read(4)
if type_id != b"DIRC":
raise AssertionError("Invalid index file header: %r" % type_id)
unpacked = cast(Tuple[int, int], unpack(">LL", stream.read(4 * 2)))
version, num_entries = unpacked
# TODO: handle version 3: extended data, see read-cache.c
assert version in (1, 2)
return version, num_entries
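# Sketch: read_header() applied to a hand-built, empty version-2 index header.
def _demo_read_header() -> None:
    header = BytesIO(b"DIRC" + pack(">LL", 2, 0))
    assert read_header(header) == (2, 0)  # version 2, zero entries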
def entry_key(*entry: Union[BaseIndexEntry, PathLike, int]) -> Tuple[PathLike, int]:
""":return: Key suitable to be used for the index.entries dictionary
:param entry: One instance of type BaseIndexEntry or the path and the stage"""
# def is_entry_key_tup(entry_key: Tuple) -> TypeGuard[Tuple[PathLike, int]]:
# return isinstance(entry_key, tuple) and len(entry_key) == 2
if len(entry) == 1:
entry_first = entry[0]
assert isinstance(entry_first, BaseIndexEntry)
return (entry_first.path, entry_first.stage)
else:
# assert is_entry_key_tup(entry)
entry = cast(Tuple[PathLike, int], entry)
return entry
# END handle entry
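# Sketch: the same dictionary key can be derived from a (path, stage) pair or
# from an existing BaseIndexEntry; the path and sha below are made up.
def _demo_entry_key() -> None:
    entry = BaseIndexEntry((0o100644, b"\0" * 20, 0, "src/module.py"))
    assert entry_key(entry) == entry_key("src/module.py", 0) == ("src/module.py", 0)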
def read_cache(
stream: IO[bytes],
) -> Tuple[int, Dict[Tuple[PathLike, int], "IndexEntry"], bytes, bytes]:
"""Read a cache file from the given stream
:return: tuple(version, entries_dict, extension_data, content_sha)
* version is the integer version number
* entries dict is a dictionary which maps a (path, stage) key to an IndexEntry instance
* extension_data is '' or 4 bytes of type + 4 bytes of size + size bytes
* content_sha is a 20 byte sha on all cache file contents"""
version, num_entries = read_header(stream)
count = 0
entries: Dict[Tuple[PathLike, int], "IndexEntry"] = {}
read = stream.read
tell = stream.tell
while count < num_entries:
beginoffset = tell()
ctime = unpack(">8s", read(8))[0]
mtime = unpack(">8s", read(8))[0]
(dev, ino, mode, uid, gid, size, sha, flags) = unpack(">LLLLLL20sH", read(20 + 4 * 6 + 2))
path_size = flags & CE_NAMEMASK
path = read(path_size).decode(defenc)
real_size = (tell() - beginoffset + 8) & ~7
read((beginoffset + real_size) - tell())
entry = IndexEntry((mode, sha, flags, path, ctime, mtime, dev, ino, uid, gid, size))
# entry_key would be the method to use, but we save the effort
entries[(path, entry.stage)] = entry
count += 1
# END for each entry
# the footer contains extension data and a sha on the content so far
# Keep the extension footer, and verify we have a sha at the end
# Extension data format is:
# 4 bytes ID
# 4 bytes length of chunk
# repeated 0 - N times
extension_data = stream.read(~0)
assert (
len(extension_data) > 19
), "Index Footer was not at least a sha on content as it was only %i bytes in size" % len(extension_data)
content_sha = extension_data[-20:]
# truncate the sha in the end as we will dynamically create it anyway
extension_data = extension_data[:-20]
return (version, entries, extension_data, content_sha)
def write_tree_from_cache(
entries: List[IndexEntry], odb: "GitCmdObjectDB", sl: slice, si: int = 0
) -> Tuple[bytes, List["TreeCacheTup"]]:
"""Create a tree from the given sorted list of entries and put the respective
trees into the given object database
:param entries: **sorted** list of IndexEntries
:param odb: object database to store the trees in
:param si: start index at which we should start creating subtrees
:param sl: slice indicating the range we should process on the entries list
:return: tuple(binsha, list(tree_entry, ...)) a tuple of a sha and a list of
tree entries, each being a tuple of binsha, mode, name"""
tree_items: List["TreeCacheTup"] = []
ci = sl.start
end = sl.stop
while ci < end:
entry = entries[ci]
if entry.stage != 0:
raise UnmergedEntriesError(entry)
# END abort on unmerged
ci += 1
rbound = entry.path.find("/", si)
if rbound == -1:
# it's not a tree
tree_items.append((entry.binsha, entry.mode, entry.path[si:]))
else:
# find common base range
base = entry.path[si:rbound]
xi = ci
while xi < end:
oentry = entries[xi]
orbound = oentry.path.find("/", si)
if orbound == -1 or oentry.path[si:orbound] != base:
break
# END abort on base mismatch
xi += 1
# END find common base
# enter recursion
# ci - 1 as we want to count our current item as well
sha, _tree_entry_list = write_tree_from_cache(entries, odb, slice(ci - 1, xi), rbound + 1)
tree_items.append((sha, S_IFDIR, base))
# skip ahead
ci = xi
# END handle bounds
# END for each entry
# finally create the tree
sio = BytesIO()
tree_to_stream(tree_items, sio.write) # writes to stream as bytes, but doesn't change tree_items
sio.seek(0)
istream = odb.store(IStream(str_tree_type, len(sio.getvalue()), sio))
return (istream.binsha, tree_items)
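# Sketch (assuming a non-bare repository at a hypothetical path with a fully
# merged index): write the current index entries out as tree objects, roughly
# what IndexFile.write_tree() does internally.
def _demo_write_tree_from_cache() -> None:
    from git.repo import Repo

    repo = Repo("/tmp/example-repo")
    entries = sorted(repo.index.entries.values(), key=lambda e: (e.path, e.stage))
    binsha, items = write_tree_from_cache(entries, repo.odb, slice(0, len(entries)))
    print("root tree", binsha.hex(), "holds", len(items), "top-level entries")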
def _tree_entry_to_baseindexentry(tree_entry: "TreeCacheTup", stage: int) -> BaseIndexEntry:
return BaseIndexEntry((tree_entry[1], tree_entry[0], stage << CE_STAGESHIFT, tree_entry[2]))
def aggressive_tree_merge(odb: "GitCmdObjectDB", tree_shas: Sequence[bytes]) -> List[BaseIndexEntry]:
"""
:return: list of BaseIndexEntries representing the aggressive merge of the given
trees. All valid entries are on stage 0, while conflicting ones are left
on stage 1, 2 or 3, where stage 1 corresponds to the common ancestor tree,
2 to our tree and 3 to 'their' tree.
:param tree_shas: 1, 2 or 3 trees as identified by their binary 20 byte shas
If 1 or 2 are given, the entries will effectively correspond to the last given tree
If 3 are given, a 3 way merge is performed"""
out: List[BaseIndexEntry] = []
# one and two way is the same for us, as we don't have to handle an existing
# index, instead
if len(tree_shas) in (1, 2):
for entry in traverse_tree_recursive(odb, tree_shas[-1], ""):
out.append(_tree_entry_to_baseindexentry(entry, 0))
# END for each entry
return out
# END handle single tree
if len(tree_shas) > 3:
raise ValueError("Cannot handle %i trees at once" % len(tree_shas))
# three trees
for base, ours, theirs in traverse_trees_recursive(odb, tree_shas, ""):
if base is not None:
# base version exists
if ours is not None:
# ours exists
if theirs is not None:
# it exists in all branches; if it was changed in both
# it's a conflict, otherwise we take the changed version
# This should be the most common branch, so it comes first
if (base[0] != ours[0] and base[0] != theirs[0] and ours[0] != theirs[0]) or (
base[1] != ours[1] and base[1] != theirs[1] and ours[1] != theirs[1]
):
# changed by both
out.append(_tree_entry_to_baseindexentry(base, 1))
out.append(_tree_entry_to_baseindexentry(ours, 2))
out.append(_tree_entry_to_baseindexentry(theirs, 3))
elif base[0] != ours[0] or base[1] != ours[1]:
# only we changed it
out.append(_tree_entry_to_baseindexentry(ours, 0))
else:
# either nobody changed it, or they did. In either
# case, use theirs
out.append(_tree_entry_to_baseindexentry(theirs, 0))
# END handle modification
else:
if ours[0] != base[0] or ours[1] != base[1]:
# they deleted it, we changed it, conflict
out.append(_tree_entry_to_baseindexentry(base, 1))
out.append(_tree_entry_to_baseindexentry(ours, 2))
# else:
# we didn't change it, ignore
# pass
# END handle our change
# END handle theirs
else:
if theirs is None:
# deleted in both, it's fine - it's out
pass
else:
if theirs[0] != base[0] or theirs[1] != base[1]:
# deleted in ours, changed theirs, conflict
out.append(_tree_entry_to_baseindexentry(base, 1))
out.append(_tree_entry_to_baseindexentry(theirs, 3))
# END theirs changed
# else:
# theirs didn't change
# pass
# END handle theirs
# END handle ours
else:
# all three can't be None
if ours is None:
# added in their branch
assert theirs is not None
out.append(_tree_entry_to_baseindexentry(theirs, 0))
elif theirs is None:
# added in our branch
out.append(_tree_entry_to_baseindexentry(ours, 0))
else:
# both have it, except for the base, see whether it changed
if ours[0] != theirs[0] or ours[1] != theirs[1]:
out.append(_tree_entry_to_baseindexentry(ours, 2))
out.append(_tree_entry_to_baseindexentry(theirs, 3))
else:
# it was added the same in both
out.append(_tree_entry_to_baseindexentry(ours, 0))
# END handle two items
# END handle heads
# END handle base exists
# END for each entries tuple
return out
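# Sketch of a tree-level three-way merge; the repository path and the revision
# names used for the base/ours/theirs trees are assumptions for illustration.
def _demo_aggressive_tree_merge() -> None:
    from git.repo import Repo

    repo = Repo("/tmp/example-repo")
    shas = [repo.commit(rev).tree.binsha for rev in ("merge-base", "main", "feature")]
    for entry in aggressive_tree_merge(repo.odb, shas):
        print(entry.stage, entry.path)  # stage 0 merged cleanly, stages 1/2/3 mark conflicts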

View File

@@ -0,0 +1,191 @@
"""Module with additional types used by the index"""
from binascii import b2a_hex
from pathlib import Path
from .util import pack, unpack
from git.objects import Blob
# typing ----------------------------------------------------------------------
from typing import NamedTuple, Sequence, TYPE_CHECKING, Tuple, Union, cast, List
from git.types import PathLike
if TYPE_CHECKING:
from git.repo import Repo
StageType = int
# ---------------------------------------------------------------------------------
__all__ = ("BlobFilter", "BaseIndexEntry", "IndexEntry", "StageType")
# { Invariants
CE_NAMEMASK = 0x0FFF
CE_STAGEMASK = 0x3000
CE_EXTENDED = 0x4000
CE_VALID = 0x8000
CE_STAGESHIFT = 12
# } END invariants
class BlobFilter(object):
"""
Predicate to be used by iter_blobs, allowing it to return only those blobs which
match the given list of directories or files.
The paths are given relative to the repository.
"""
__slots__ = "paths"
def __init__(self, paths: Sequence[PathLike]) -> None:
"""
:param paths:
tuple or list of paths which are either pointing to directories or
to files relative to the current repository
"""
self.paths = paths
def __call__(self, stage_blob: Tuple[StageType, Blob]) -> bool:
blob_pathlike: PathLike = stage_blob[1].path
blob_path: Path = blob_pathlike if isinstance(blob_pathlike, Path) else Path(blob_pathlike)
for pathlike in self.paths:
path: Path = pathlike if isinstance(pathlike, Path) else Path(pathlike)
# TODO: Change to use `PosixPath.is_relative_to` once Python 3.8 is no longer supported.
filter_parts: List[str] = path.parts
blob_parts: List[str] = blob_path.parts
if len(filter_parts) > len(blob_parts):
continue
if all(i == j for i, j in zip(filter_parts, blob_parts)):
return True
return False
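# Sketch: restricting IndexFile.iter_blobs() to blobs below one directory; the
# repository path and the "docs" prefix are assumptions.
def _demo_blob_filter() -> None:
    from git.repo import Repo

    repo = Repo("/tmp/example-repo")
    for stage, blob in repo.index.iter_blobs(BlobFilter(["docs"])):
        print(stage, blob.path)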
class BaseIndexEntryHelper(NamedTuple):
"""Typed namedtuple to provide named attribute access for BaseIndexEntry.
Needed to allow overriding __new__ in child class to preserve backwards compat."""
mode: int
binsha: bytes
flags: int
path: PathLike
ctime_bytes: bytes = pack(">LL", 0, 0)
mtime_bytes: bytes = pack(">LL", 0, 0)
dev: int = 0
inode: int = 0
uid: int = 0
gid: int = 0
size: int = 0
class BaseIndexEntry(BaseIndexEntryHelper):
"""Small Brother of an index entry which can be created to describe changes
done to the index in which case plenty of additional information is not required.
As the first 4 data members match exactly to the IndexEntry type, methods
expecting a BaseIndexEntry can also handle full IndexEntries even if they
use numeric indices for performance reasons.
"""
def __new__(
cls,
inp_tuple: Union[
Tuple[int, bytes, int, PathLike],
Tuple[int, bytes, int, PathLike, bytes, bytes, int, int, int, int, int],
],
) -> "BaseIndexEntry":
"""Override __new__ to allow construction from a tuple for backwards compatibility"""
return super().__new__(cls, *inp_tuple)
def __str__(self) -> str:
return "%o %s %i\t%s" % (self.mode, self.hexsha, self.stage, self.path)
def __repr__(self) -> str:
return "(%o, %s, %i, %s)" % (self.mode, self.hexsha, self.stage, self.path)
@property
def hexsha(self) -> str:
"""hex version of our sha"""
return b2a_hex(self.binsha).decode("ascii")
@property
def stage(self) -> int:
"""Stage of the entry, either:
* 0 = default stage
* 1 = stage before a merge or common ancestor entry in case of a 3 way merge
* 2 = stage of entries from the 'left' side of the merge
* 3 = stage of entries from the right side of the merge
:note: For more information, see http://www.kernel.org/pub/software/scm/git/docs/git-read-tree.html
"""
return (self.flags & CE_STAGEMASK) >> CE_STAGESHIFT
@classmethod
def from_blob(cls, blob: Blob, stage: int = 0) -> "BaseIndexEntry":
""":return: Fully equipped BaseIndexEntry at the given stage"""
return cls((blob.mode, blob.binsha, stage << CE_STAGESHIFT, blob.path))
def to_blob(self, repo: "Repo") -> Blob:
""":return: Blob using the information of this index entry"""
return Blob(repo, self.binsha, self.mode, self.path)
class IndexEntry(BaseIndexEntry):
"""Allows convenient access to IndexEntry data without completely unpacking it.
Attributes usually accessed often are cached in the tuple whereas others are
unpacked on demand.
See the properties for a mapping between names and tuple indices."""
@property
def ctime(self) -> Tuple[int, int]:
"""
:return:
Tuple(int_time_seconds_since_epoch, int_nano_seconds) of the
file's creation time"""
return cast(Tuple[int, int], unpack(">LL", self.ctime_bytes))
@property
def mtime(self) -> Tuple[int, int]:
"""See ctime property, but returns modification time"""
return cast(Tuple[int, int], unpack(">LL", self.mtime_bytes))
@classmethod
def from_base(cls, base: "BaseIndexEntry") -> "IndexEntry":
"""
:return:
Minimal entry as created from the given BaseIndexEntry instance.
Missing values will be set to null-like values
:param base: Instance of type BaseIndexEntry"""
time = pack(">LL", 0, 0)
return IndexEntry((base.mode, base.binsha, base.flags, base.path, time, time, 0, 0, 0, 0, 0))
@classmethod
def from_blob(cls, blob: Blob, stage: int = 0) -> "IndexEntry":
""":return: Minimal entry resembling the given blob object"""
time = pack(">LL", 0, 0)
return IndexEntry(
(
blob.mode,
blob.binsha,
stage << CE_STAGESHIFT,
blob.path,
time,
time,
0,
0,
0,
0,
blob.size,
)
)
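# Sketch: building an IndexEntry from a blob of the current HEAD tree; the
# repository path and the file name are assumptions.
def _demo_index_entry_from_blob() -> None:
    from git.repo import Repo

    repo = Repo("/tmp/example-repo")
    blob = repo.head.commit.tree["README.md"]
    entry = IndexEntry.from_blob(blob)
    print(entry.hexsha, "%o" % entry.mode, entry.stage, entry.path)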

View File

@@ -0,0 +1,119 @@
"""Module containing index utilities"""
from functools import wraps
import os
import struct
import tempfile
from git.compat import is_win
import os.path as osp
# typing ----------------------------------------------------------------------
from typing import Any, Callable, TYPE_CHECKING
from git.types import PathLike, _T
if TYPE_CHECKING:
from git.index import IndexFile
# ---------------------------------------------------------------------------------
__all__ = ("TemporaryFileSwap", "post_clear_cache", "default_index", "git_working_dir")
# { Aliases
pack = struct.pack
unpack = struct.unpack
# } END aliases
class TemporaryFileSwap(object):
"""Utility class moving a file to a temporary location within the same directory
and moving it back to its original location on object deletion."""
__slots__ = ("file_path", "tmp_file_path")
def __init__(self, file_path: PathLike) -> None:
self.file_path = file_path
self.tmp_file_path = str(self.file_path) + tempfile.mktemp("", "", "")
# it may be that the source does not exist
try:
os.rename(self.file_path, self.tmp_file_path)
except OSError:
pass
def __del__(self) -> None:
if osp.isfile(self.tmp_file_path):
if is_win and osp.exists(self.file_path):
os.remove(self.file_path)
os.rename(self.tmp_file_path, self.file_path)
# END temp file exists
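# Sketch mirroring how the class is used elsewhere in the package: move the
# index file aside, do some work, then let object deletion restore it. The
# path is an assumption.
def _demo_temporary_file_swap() -> None:
    swap = TemporaryFileSwap("/tmp/example-repo/.git/index")
    try:
        pass  # operate while the original index file is out of the way
    finally:
        del swap  # __del__ moves the original file back into place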
# { Decorators
def post_clear_cache(func: Callable[..., _T]) -> Callable[..., _T]:
"""Decorator for functions that alter the index using the git command. This would
invalidate our possibly existing entries dictionary which is why it must be
deleted to allow it to be lazily reread later.
:note:
This decorator will not be required once all functions are implemented
natively which in fact is possible, but probably not feasible performance wise.
"""
@wraps(func)
def post_clear_cache_if_not_raised(self: "IndexFile", *args: Any, **kwargs: Any) -> _T:
rval = func(self, *args, **kwargs)
self._delete_entries_cache()
return rval
# END wrapper method
return post_clear_cache_if_not_raised
def default_index(func: Callable[..., _T]) -> Callable[..., _T]:
"""Decorator assuring the wrapped method may only run if we are the default
repository index. This is as we rely on git commands that operate
on that index only."""
@wraps(func)
def check_default_index(self: "IndexFile", *args: Any, **kwargs: Any) -> _T:
if self._file_path != self._index_path():
raise AssertionError(
"Cannot call %r on indices that do not represent the default git index" % func.__name__
)
return func(self, *args, **kwargs)
# END wrapper method
return check_default_index
def git_working_dir(func: Callable[..., _T]) -> Callable[..., _T]:
"""Decorator which changes the current working dir to the one of the git
repository in order to assure relative paths are handled correctly"""
@wraps(func)
def set_git_working_dir(self: "IndexFile", *args: Any, **kwargs: Any) -> _T:
cur_wd = os.getcwd()
os.chdir(str(self.repo.working_tree_dir))
try:
return func(self, *args, **kwargs)
finally:
os.chdir(cur_wd)
# END handle working dir
# END wrapper
return set_git_working_dir
# } END decorators

View File

@@ -0,0 +1,24 @@
"""
Import all submodules main classes into the package space
"""
# flake8: noqa
import inspect
from .base import *
from .blob import *
from .commit import *
from .submodule import util as smutil
from .submodule.base import *
from .submodule.root import *
from .tag import *
from .tree import *
# Fix import dependency - add IndexObject to the util module, so that it can be
# imported by the submodule.base
smutil.IndexObject = IndexObject # type: ignore[attr-defined]
smutil.Object = Object # type: ignore[attr-defined]
del smutil
# must come after submodule was made available
__all__ = [name for name, obj in locals().items() if not (name.startswith("_") or inspect.ismodule(obj))]

View File

@@ -0,0 +1,224 @@
# base.py
# Copyright (C) 2008, 2009 Michael Trier (mtrier@gmail.com) and contributors
#
# This module is part of GitPython and is released under
# the BSD License: http://www.opensource.org/licenses/bsd-license.php
from git.exc import WorkTreeRepositoryUnsupported
from git.util import LazyMixin, join_path_native, stream_copy, bin_to_hex
import gitdb.typ as dbtyp
import os.path as osp
from .util import get_object_type_by_name
# typing ------------------------------------------------------------------
from typing import Any, TYPE_CHECKING, Union
from git.types import PathLike, Commit_ish, Lit_commit_ish
if TYPE_CHECKING:
from git.repo import Repo
from gitdb.base import OStream
from .tree import Tree
from .blob import Blob
from .submodule.base import Submodule
from git.refs.reference import Reference
IndexObjUnion = Union["Tree", "Blob", "Submodule"]
# --------------------------------------------------------------------------
_assertion_msg_format = "Created object %r whose python type %r disagrees with the actual git object type %r"
__all__ = ("Object", "IndexObject")
class Object(LazyMixin):
"""Implements an Object which may be Blobs, Trees, Commits and Tags"""
NULL_HEX_SHA = "0" * 40
NULL_BIN_SHA = b"\0" * 20
TYPES = (
dbtyp.str_blob_type,
dbtyp.str_tree_type,
dbtyp.str_commit_type,
dbtyp.str_tag_type,
)
__slots__ = ("repo", "binsha", "size")
type: Union[Lit_commit_ish, None] = None
def __init__(self, repo: "Repo", binsha: bytes):
"""Initialize an object by identifying it by its binary sha.
All keyword arguments will be set on demand if None.
:param repo: repository this object is located in
:param binsha: 20 byte SHA1"""
super(Object, self).__init__()
self.repo = repo
self.binsha = binsha
assert len(binsha) == 20, "Require 20 byte binary sha, got %r, len = %i" % (
binsha,
len(binsha),
)
@classmethod
def new(cls, repo: "Repo", id: Union[str, "Reference"]) -> Commit_ish:
"""
:return: New Object instance of a type appropriate to the object type behind
id. The id of the newly created object will be a binsha even though
the input id may have been a Reference or Rev-Spec
:param id: reference, rev-spec, or hexsha
:note: This cannot be a __new__ method as it would always call __init__
with the input id which is not necessarily a binsha."""
return repo.rev_parse(str(id))
@classmethod
def new_from_sha(cls, repo: "Repo", sha1: bytes) -> Commit_ish:
"""
:return: new object instance of a type appropriate to represent the given
binary sha1
:param sha1: 20 byte binary sha1"""
if sha1 == cls.NULL_BIN_SHA:
# the NULL binsha is always the root commit
return get_object_type_by_name(b"commit")(repo, sha1)
# END handle special case
oinfo = repo.odb.info(sha1)
inst = get_object_type_by_name(oinfo.type)(repo, oinfo.binsha)
inst.size = oinfo.size
return inst
def _set_cache_(self, attr: str) -> None:
"""Retrieve object information"""
if attr == "size":
oinfo = self.repo.odb.info(self.binsha)
self.size = oinfo.size # type: int
# assert oinfo.type == self.type, _assertion_msg_format % (self.binsha, oinfo.type, self.type)
else:
super(Object, self)._set_cache_(attr)
def __eq__(self, other: Any) -> bool:
""":return: True if the objects have the same SHA1"""
if not hasattr(other, "binsha"):
return False
return self.binsha == other.binsha
def __ne__(self, other: Any) -> bool:
""":return: True if the objects do not have the same SHA1"""
if not hasattr(other, "binsha"):
return True
return self.binsha != other.binsha
def __hash__(self) -> int:
""":return: Hash of our id allowing objects to be used in dicts and sets"""
return hash(self.binsha)
def __str__(self) -> str:
""":return: string of our SHA1 as understood by all git commands"""
return self.hexsha
def __repr__(self) -> str:
""":return: string with pythonic representation of our object"""
return '<git.%s "%s">' % (self.__class__.__name__, self.hexsha)
@property
def hexsha(self) -> str:
""":return: 40 byte hex version of our 20 byte binary sha"""
# b2a_hex produces bytes
return bin_to_hex(self.binsha).decode("ascii")
@property
def data_stream(self) -> "OStream":
""":return: File Object compatible stream to the uncompressed raw data of the object
:note: returned streams must be read in order"""
return self.repo.odb.stream(self.binsha)
def stream_data(self, ostream: "OStream") -> "Object":
"""Writes our data directly to the given output stream
:param ostream: File object compatible stream object.
:return: self"""
istream = self.repo.odb.stream(self.binsha)
stream_copy(istream, ostream)
return self
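# Sketch: basic object introspection on the current HEAD commit; the repository
# path is an assumption.
def _demo_object_basics() -> None:
    from git.repo import Repo

    repo = Repo("/tmp/example-repo")
    head = repo.head.commit
    print(head.hexsha, head.type, head.size)
    raw = head.data_stream.read()  # uncompressed raw object data
    print("raw commit data is", len(raw), "bytes")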
class IndexObject(Object):
"""Base for all objects that can be part of the index file , namely Tree, Blob and
SubModule objects"""
__slots__ = ("path", "mode")
# for compatibility with iterable lists
_id_attribute_ = "path"
def __init__(
self,
repo: "Repo",
binsha: bytes,
mode: Union[None, int] = None,
path: Union[None, PathLike] = None,
) -> None:
"""Initialize a newly instanced IndexObject
:param repo: is the Repo we are located in
:param binsha: 20 byte sha1
:param mode:
is the stat compatible file mode as int, use the stat module
to evaluate the information
:param path:
is the path to the file in the file system, relative to the git repository root, i.e.
file.ext or folder/other.ext
:note:
Path may not be set if the index object has been created directly, as it cannot
be retrieved without knowing the parent tree."""
super(IndexObject, self).__init__(repo, binsha)
if mode is not None:
self.mode = mode
if path is not None:
self.path = path
def __hash__(self) -> int:
"""
:return:
Hash of our path as index items are uniquely identifiable by path, not
by their data !"""
return hash(self.path)
def _set_cache_(self, attr: str) -> None:
if attr in IndexObject.__slots__:
# they cannot be retrieved later on (not without searching for them)
raise AttributeError(
"Attribute '%s' unset: path and mode attributes must have been set during %s object creation"
% (attr, type(self).__name__)
)
else:
super(IndexObject, self)._set_cache_(attr)
# END handle slot attribute
@property
def name(self) -> str:
""":return: Name portion of the path, effectively being the basename"""
return osp.basename(self.path)
@property
def abspath(self) -> PathLike:
"""
:return:
Absolute path to this index object in the file system (as opposed to the
.path field, which is a path relative to the git repository).
The returned path will be native to the system and contains '\' on Windows."""
if self.repo.working_tree_dir is not None:
return join_path_native(self.repo.working_tree_dir, self.path)
else:
raise WorkTreeRepositoryUnsupported("Working_tree_dir was None or empty")

View File

@@ -0,0 +1,36 @@
# blob.py
# Copyright (C) 2008, 2009 Michael Trier (mtrier@gmail.com) and contributors
#
# This module is part of GitPython and is released under
# the BSD License: http://www.opensource.org/licenses/bsd-license.php
from mimetypes import guess_type
from . import base
from git.types import Literal
__all__ = ("Blob",)
class Blob(base.IndexObject):
"""A Blob encapsulates a git blob object"""
DEFAULT_MIME_TYPE = "text/plain"
type: Literal["blob"] = "blob"
# valid blob modes
executable_mode = 0o100755
file_mode = 0o100644
link_mode = 0o120000
__slots__ = ()
@property
def mime_type(self) -> str:
"""
:return: String describing the mime type of this file (based on the filename)
:note: Defaults to 'text/plain' in case the actual file type is unknown."""
guesses = None
if self.path:
guesses = guess_type(str(self.path))
return guesses and guesses[0] or self.DEFAULT_MIME_TYPE
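# Sketch: fetching a blob from the HEAD tree and asking for its guessed mime
# type; the repository path and file name are assumptions.
def _demo_blob_mime_type() -> None:
    from git.repo import Repo

    repo = Repo("/tmp/example-repo")
    blob = repo.head.commit.tree["README.md"]
    print(blob.mime_type)  # falls back to "text/plain" when the type is unknown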

View File

@@ -0,0 +1,762 @@
# commit.py
# Copyright (C) 2008, 2009 Michael Trier (mtrier@gmail.com) and contributors
#
# This module is part of GitPython and is released under
# the BSD License: http://www.opensource.org/licenses/bsd-license.php
import datetime
import re
from subprocess import Popen, PIPE
from gitdb import IStream
from git.util import hex_to_bin, Actor, Stats, finalize_process
from git.diff import Diffable
from git.cmd import Git
from .tree import Tree
from . import base
from .util import (
Serializable,
TraversableIterableObj,
parse_date,
altz_to_utctz_str,
parse_actor_and_date,
from_timestamp,
)
from time import time, daylight, altzone, timezone, localtime
import os
from io import BytesIO
import logging
# typing ------------------------------------------------------------------
from typing import (
Any,
IO,
Iterator,
List,
Sequence,
Tuple,
Union,
TYPE_CHECKING,
cast,
Dict,
)
from git.types import PathLike, Literal
if TYPE_CHECKING:
from git.repo import Repo
from git.refs import SymbolicReference
# ------------------------------------------------------------------------
log = logging.getLogger("git.objects.commit")
log.addHandler(logging.NullHandler())
__all__ = ("Commit",)
class Commit(base.Object, TraversableIterableObj, Diffable, Serializable):
"""Wraps a git Commit object.
This class will act lazily on some of its attributes and will query the
value on demand only if it involves calling the git binary."""
# ENVIRONMENT VARIABLES
# read when creating new commits
env_author_date = "GIT_AUTHOR_DATE"
env_committer_date = "GIT_COMMITTER_DATE"
# CONFIGURATION KEYS
conf_encoding = "i18n.commitencoding"
# INVARIANTS
default_encoding = "UTF-8"
# object configuration
type: Literal["commit"] = "commit"
__slots__ = (
"tree",
"author",
"authored_date",
"author_tz_offset",
"committer",
"committed_date",
"committer_tz_offset",
"message",
"parents",
"encoding",
"gpgsig",
)
_id_attribute_ = "hexsha"
def __init__(
self,
repo: "Repo",
binsha: bytes,
tree: Union[Tree, None] = None,
author: Union[Actor, None] = None,
authored_date: Union[int, None] = None,
author_tz_offset: Union[None, float] = None,
committer: Union[Actor, None] = None,
committed_date: Union[int, None] = None,
committer_tz_offset: Union[None, float] = None,
message: Union[str, bytes, None] = None,
parents: Union[Sequence["Commit"], None] = None,
encoding: Union[str, None] = None,
gpgsig: Union[str, None] = None,
) -> None:
"""Instantiate a new Commit. All keyword arguments taking None as default will
be implicitly set on first query.
:param binsha: 20 byte sha1
:param parents: tuple( Commit, ... )
is a tuple of commit ids or actual Commits
:param tree: Tree object
:param author: Actor
is the author Actor object
:param authored_date: int_seconds_since_epoch
is the authored DateTime - use time.gmtime() to convert it into a
different format
:param author_tz_offset: int_seconds_west_of_utc
is the timezone that the authored_date is in
:param committer: Actor
is the committer Actor object
:param committed_date: int_seconds_since_epoch
is the committed DateTime - use time.gmtime() to convert it into a
different format
:param committer_tz_offset: int_seconds_west_of_utc
is the timezone that the committed_date is in
:param message: string
is the commit message
:param encoding: string
encoding of the message, defaults to UTF-8
:param parents:
List or tuple of Commit objects which are our parent(s) in the commit
dependency graph
:return: git.Commit
:note:
Timezone information is in the same format and in the same sign
as what time.altzone returns. The sign is inverted compared to git's
UTC timezone."""
super(Commit, self).__init__(repo, binsha)
self.binsha = binsha
if tree is not None:
assert isinstance(tree, Tree), "Tree needs to be a Tree instance, was %s" % type(tree)
if tree is not None:
self.tree = tree
if author is not None:
self.author = author
if authored_date is not None:
self.authored_date = authored_date
if author_tz_offset is not None:
self.author_tz_offset = author_tz_offset
if committer is not None:
self.committer = committer
if committed_date is not None:
self.committed_date = committed_date
if committer_tz_offset is not None:
self.committer_tz_offset = committer_tz_offset
if message is not None:
self.message = message
if parents is not None:
self.parents = parents
if encoding is not None:
self.encoding = encoding
if gpgsig is not None:
self.gpgsig = gpgsig
@classmethod
def _get_intermediate_items(cls, commit: "Commit") -> Tuple["Commit", ...]:
return tuple(commit.parents)
@classmethod
def _calculate_sha_(cls, repo: "Repo", commit: "Commit") -> bytes:
"""Calculate the sha of a commit.
:param repo: Repo object the commit should be part of
:param commit: Commit object for which to generate the sha
"""
stream = BytesIO()
commit._serialize(stream)
streamlen = stream.tell()
stream.seek(0)
istream = repo.odb.store(IStream(cls.type, streamlen, stream))
return istream.binsha
def replace(self, **kwargs: Any) -> "Commit":
"""Create new commit object from existing commit object.
Any values provided as keyword arguments will replace the
corresponding attribute in the new object.
"""
attrs = {k: getattr(self, k) for k in self.__slots__}
for attrname in kwargs:
if attrname not in self.__slots__:
raise ValueError("invalid attribute name")
attrs.update(kwargs)
new_commit = self.__class__(self.repo, self.NULL_BIN_SHA, **attrs)
new_commit.binsha = self._calculate_sha_(self.repo, new_commit)
return new_commit
def _set_cache_(self, attr: str) -> None:
if attr in Commit.__slots__:
# read the data in a chunk, it's faster - then provide a file wrapper
_binsha, _typename, self.size, stream = self.repo.odb.stream(self.binsha)
self._deserialize(BytesIO(stream.read()))
else:
super(Commit, self)._set_cache_(attr)
# END handle attrs
@property
def authored_datetime(self) -> datetime.datetime:
return from_timestamp(self.authored_date, self.author_tz_offset)
@property
def committed_datetime(self) -> datetime.datetime:
return from_timestamp(self.committed_date, self.committer_tz_offset)
@property
def summary(self) -> Union[str, bytes]:
""":return: First line of the commit message"""
if isinstance(self.message, str):
return self.message.split("\n", 1)[0]
else:
return self.message.split(b"\n", 1)[0]
def count(self, paths: Union[PathLike, Sequence[PathLike]] = "", **kwargs: Any) -> int:
"""Count the number of commits reachable from this commit
:param paths:
is an optional path or a list of paths restricting the return value
to commits actually containing the paths
:param kwargs:
Additional options to be passed to git-rev-list. They must not alter
the output style of the command, or parsing will yield incorrect results
:return: int defining the number of reachable commits"""
# yes, it makes a difference whether empty paths are given or not in our case
# as the empty paths version will ignore merge commits for some reason.
if paths:
return len(self.repo.git.rev_list(self.hexsha, "--", paths, **kwargs).splitlines())
return len(self.repo.git.rev_list(self.hexsha, **kwargs).splitlines())
@property
def name_rev(self) -> str:
"""
:return:
String describing the commits hex sha based on the closest Reference.
Mostly useful for UI purposes"""
return self.repo.git.name_rev(self)
@classmethod
def iter_items(
cls,
repo: "Repo",
rev: Union[str, "Commit", "SymbolicReference"], # type: ignore
paths: Union[PathLike, Sequence[PathLike]] = "",
**kwargs: Any,
) -> Iterator["Commit"]:
"""Find all commits matching the given criteria.
:param repo: is the Repo
:param rev: revision specifier, see git-rev-parse for viable options
:param paths:
is an optional path or list of paths, if set only Commits that include the path
or paths will be considered
:param kwargs:
optional keyword arguments to git rev-list where
``max_count`` is the maximum number of commits to fetch
``skip`` is the number of commits to skip
``since`` all commits since i.e. '1970-01-01'
:return: iterator yielding Commit items"""
if "pretty" in kwargs:
raise ValueError("--pretty cannot be used as parsing expects single sha's only")
# END handle pretty
# use -- in any case, to prevent possibility of ambiguous arguments
# see https://github.com/gitpython-developers/GitPython/issues/264
args_list: List[PathLike] = ["--"]
if paths:
paths_tup: Tuple[PathLike, ...]
if isinstance(paths, (str, os.PathLike)):
paths_tup = (paths,)
else:
paths_tup = tuple(paths)
args_list.extend(paths_tup)
# END if paths
proc = repo.git.rev_list(rev, args_list, as_process=True, **kwargs)
return cls._iter_from_process_or_stream(repo, proc)
def iter_parents(self, paths: Union[PathLike, Sequence[PathLike]] = "", **kwargs: Any) -> Iterator["Commit"]:
"""Iterate _all_ parents of this commit.
:param paths:
Optional path or list of paths limiting the Commits to those that
contain at least one of the paths
:param kwargs: All arguments allowed by git-rev-list
:return: Iterator yielding Commit objects which are parents of self"""
# skip ourselves
skip = kwargs.get("skip", 1)
if skip == 0: # skip ourselves
skip = 1
kwargs["skip"] = skip
return self.iter_items(self.repo, self, paths, **kwargs)
@property
def stats(self) -> Stats:
"""Create a git stat from changes between this commit and its first parent
or from all changes done if this is the very first commit.
:return: git.Stats"""
if not self.parents:
text = self.repo.git.diff_tree(self.hexsha, "--", numstat=True, no_renames=True, root=True)
text2 = ""
for line in text.splitlines()[1:]:
(insertions, deletions, filename) = line.split("\t")
text2 += "%s\t%s\t%s\n" % (insertions, deletions, filename)
text = text2
else:
text = self.repo.git.diff(self.parents[0].hexsha, self.hexsha, "--", numstat=True, no_renames=True)
return Stats._list_from_string(self.repo, text)
@property
def trailers(self) -> Dict:
"""Get the trailers of the message as dictionary
Git messages can contain trailer information that is similar to RFC 822
e-mail headers (see: https://git-scm.com/docs/git-interpret-trailers).
This function calls ``git interpret-trailers --parse`` on the message
to extract the trailer information. The key-value pairs are stripped of
leading and trailing whitespace before they are saved into a dictionary.
Valid message with trailer:
.. code-block::
Subject line
some body information
another information
key1: value1
key2 : value 2 with inner spaces
The resulting dictionary will look like this:
.. code-block::
{
"key1": "value1",
"key2": "value 2 with inner spaces"
}
:return: Dictionary containing whitespace stripped trailer information
"""
d = {}
cmd = ["git", "interpret-trailers", "--parse"]
proc: Git.AutoInterrupt = self.repo.git.execute(cmd, as_process=True, istream=PIPE) # type: ignore
trailer: str = proc.communicate(str(self.message).encode())[0].decode()
if trailer.endswith("\n"):
trailer = trailer[0:-1]
if trailer != "":
for line in trailer.split("\n"):
key, value = line.split(":", 1)
d[key.strip()] = value.strip()
return d
@classmethod
def _iter_from_process_or_stream(cls, repo: "Repo", proc_or_stream: Union[Popen, IO]) -> Iterator["Commit"]:
"""Parse out commit information into a list of Commit objects
We expect one line per commit, and parse the actual commit information directly
from our lightning fast object database
:param proc: git-rev-list process instance - one sha per line
:return: iterator returning Commit objects"""
# def is_proc(inp) -> TypeGuard[Popen]:
# return hasattr(proc_or_stream, 'wait') and not hasattr(proc_or_stream, 'readline')
# def is_stream(inp) -> TypeGuard[IO]:
# return hasattr(proc_or_stream, 'readline')
if hasattr(proc_or_stream, "wait"):
proc_or_stream = cast(Popen, proc_or_stream)
if proc_or_stream.stdout is not None:
stream = proc_or_stream.stdout
elif hasattr(proc_or_stream, "readline"):
proc_or_stream = cast(IO, proc_or_stream)
stream = proc_or_stream
readline = stream.readline
while True:
line = readline()
if not line:
break
hexsha = line.strip()
if len(hexsha) > 40:
# split additional information, as returned by bisect for instance
hexsha, _ = line.split(None, 1)
# END handle extra info
assert len(hexsha) == 40, "Invalid line: %s" % hexsha
yield cls(repo, hex_to_bin(hexsha))
# END for each line in stream
# TODO: Review this - it seems process handling got a bit out of control
# due to many developers trying to fix the open file handles issue
if hasattr(proc_or_stream, "wait"):
proc_or_stream = cast(Popen, proc_or_stream)
finalize_process(proc_or_stream)
@classmethod
def create_from_tree(
cls,
repo: "Repo",
tree: Union[Tree, str],
message: str,
parent_commits: Union[None, List["Commit"]] = None,
head: bool = False,
author: Union[None, Actor] = None,
committer: Union[None, Actor] = None,
author_date: Union[None, str, datetime.datetime] = None,
commit_date: Union[None, str, datetime.datetime] = None,
) -> "Commit":
"""Commit the given tree, creating a commit object.
:param repo: Repo object the commit should be part of
:param tree: Tree object or hex or bin sha
the tree of the new commit
:param message: Commit message. It may be an empty string if no message is provided.
It will be converted to a string, in any case.
:param parent_commits:
Optional Commit objects to use as parents for the new commit.
If empty list, the commit will have no parents at all and become
a root commit.
If None, the current head commit will be the parent of the
new commit object
:param head:
If True, the HEAD will be advanced to the new commit automatically.
Else the HEAD will remain pointing on the previous commit. This could
lead to undesired results when diffing files.
:param author: The name of the author, optional. If unset, the repository
configuration is used to obtain this value.
:param committer: The name of the committer, optional. If unset, the
repository configuration is used to obtain this value.
:param author_date: The timestamp for the author field
:param commit_date: The timestamp for the committer field
:return: Commit object representing the new commit
:note:
Additional information about the committer and Author are taken from the
environment or from the git configuration, see git-commit-tree for
more information"""
if parent_commits is None:
try:
parent_commits = [repo.head.commit]
except ValueError:
# empty repositories have no head commit
parent_commits = []
# END handle parent commits
else:
for p in parent_commits:
if not isinstance(p, cls):
raise ValueError(f"Parent commit '{p!r}' must be of type {cls}")
# end check parent commit types
# END if parent commits are unset
# retrieve all additional information, create a commit object, and
# serialize it
# Generally:
# * Environment variables override configuration values
# * Sensible defaults are set according to the git documentation
# COMMITTER AND AUTHOR INFO
cr = repo.config_reader()
env = os.environ
committer = committer or Actor.committer(cr)
author = author or Actor.author(cr)
# PARSE THE DATES
unix_time = int(time())
is_dst = daylight and localtime().tm_isdst > 0
offset = altzone if is_dst else timezone
author_date_str = env.get(cls.env_author_date, "")
if author_date:
author_time, author_offset = parse_date(author_date)
elif author_date_str:
author_time, author_offset = parse_date(author_date_str)
else:
author_time, author_offset = unix_time, offset
# END set author time
committer_date_str = env.get(cls.env_committer_date, "")
if commit_date:
committer_time, committer_offset = parse_date(commit_date)
elif committer_date_str:
committer_time, committer_offset = parse_date(committer_date_str)
else:
committer_time, committer_offset = unix_time, offset
# END set committer time
# assume utf8 encoding
enc_section, enc_option = cls.conf_encoding.split(".")
conf_encoding = cr.get_value(enc_section, enc_option, cls.default_encoding)
if not isinstance(conf_encoding, str):
raise TypeError("conf_encoding could not be coerced to str")
# if the tree is no object, make sure we create one - otherwise
# the created commit object is invalid
if isinstance(tree, str):
tree = repo.tree(tree)
# END tree conversion
# CREATE NEW COMMIT
new_commit = cls(
repo,
cls.NULL_BIN_SHA,
tree,
author,
author_time,
author_offset,
committer,
committer_time,
committer_offset,
message,
parent_commits,
conf_encoding,
)
new_commit.binsha = cls._calculate_sha_(repo, new_commit)
if head:
# need late import here, importing git at the very beginning throws
# as well ...
import git.refs
try:
repo.head.set_commit(new_commit, logmsg=message)
except ValueError:
# head is not yet set to the ref our HEAD points to
# Happens on first commit
master = git.refs.Head.create(
repo,
repo.head.ref,
new_commit,
logmsg="commit (initial): %s" % message,
)
repo.head.set_reference(master, logmsg="commit: Switching to %s" % master)
# END handle empty repositories
# END advance head handling
return new_commit
# { Serializable Implementation
def _serialize(self, stream: BytesIO) -> "Commit":
write = stream.write
write(("tree %s\n" % self.tree).encode("ascii"))
for p in self.parents:
write(("parent %s\n" % p).encode("ascii"))
a = self.author
aname = a.name
c = self.committer
fmt = "%s %s <%s> %s %s\n"
write(
(
fmt
% (
"author",
aname,
a.email,
self.authored_date,
altz_to_utctz_str(self.author_tz_offset),
)
).encode(self.encoding)
)
# encode committer
aname = c.name
write(
(
fmt
% (
"committer",
aname,
c.email,
self.committed_date,
altz_to_utctz_str(self.committer_tz_offset),
)
).encode(self.encoding)
)
if self.encoding != self.default_encoding:
write(("encoding %s\n" % self.encoding).encode("ascii"))
try:
if self.__getattribute__("gpgsig"):
write(b"gpgsig")
for sigline in self.gpgsig.rstrip("\n").split("\n"):
write((" " + sigline + "\n").encode("ascii"))
except AttributeError:
pass
write(b"\n")
# write plain bytes, be sure its encoded according to our encoding
if isinstance(self.message, str):
write(self.message.encode(self.encoding))
else:
write(self.message)
# END handle encoding
return self
def _deserialize(self, stream: BytesIO) -> "Commit":
"""
:param stream: plain data stream containing the serialized commit object,
as produced by _serialize()
"""
readline = stream.readline
self.tree = Tree(self.repo, hex_to_bin(readline().split()[1]), Tree.tree_id << 12, "")
self.parents = []
next_line = None
while True:
parent_line = readline()
if not parent_line.startswith(b"parent"):
next_line = parent_line
break
# END abort reading parents
self.parents.append(type(self)(self.repo, hex_to_bin(parent_line.split()[-1].decode("ascii"))))
# END for each parent line
self.parents = tuple(self.parents)
# we don't know actual author encoding before we have parsed it, so keep the lines around
author_line = next_line
committer_line = readline()
# we might run into one or more mergetag blocks, skip those for now
next_line = readline()
while next_line.startswith(b"mergetag "):
next_line = readline()
while next_line.startswith(b" "):
next_line = readline()
# end skip mergetags
# now we can have the encoding line, or an empty line followed by the optional
# message.
self.encoding = self.default_encoding
self.gpgsig = ""
# read headers
enc = next_line
buf = enc.strip()
while buf:
if buf[0:9] == b"encoding ":
self.encoding = buf[buf.find(b" ") + 1 :].decode(self.encoding, "ignore")
elif buf[0:7] == b"gpgsig ":
sig = buf[buf.find(b" ") + 1 :] + b"\n"
is_next_header = False
while True:
sigbuf = readline()
if not sigbuf:
break
if sigbuf[0:1] != b" ":
buf = sigbuf.strip()
is_next_header = True
break
sig += sigbuf[1:]
# end read all signature
self.gpgsig = sig.rstrip(b"\n").decode(self.encoding, "ignore")
if is_next_header:
continue
buf = readline().strip()
# decode the authors name
try:
(
self.author,
self.authored_date,
self.author_tz_offset,
) = parse_actor_and_date(author_line.decode(self.encoding, "replace"))
except UnicodeDecodeError:
log.error(
"Failed to decode author line '%s' using encoding %s",
author_line,
self.encoding,
exc_info=True,
)
try:
(
self.committer,
self.committed_date,
self.committer_tz_offset,
) = parse_actor_and_date(committer_line.decode(self.encoding, "replace"))
except UnicodeDecodeError:
log.error(
"Failed to decode committer line '%s' using encoding %s",
committer_line,
self.encoding,
exc_info=True,
)
# END handle author's encoding
# a stream from our data simply gives us the plain message
# The end of our message stream is marked with a newline that we strip
self.message = stream.read()
try:
self.message = self.message.decode(self.encoding, "replace")
except UnicodeDecodeError:
log.error(
"Failed to decode message '%s' using encoding %s",
self.message,
self.encoding,
exc_info=True,
)
# END exception handling
return self
# } END serializable implementation
@property
def co_authors(self) -> List[Actor]:
"""
Search the commit message for any co-authors of this commit.
Details on co-authors: https://github.blog/2018-01-29-commit-together-with-co-authors/
:return: List of co-authors for this commit (as Actor objects).
"""
co_authors = []
if self.message:
results = re.findall(
r"^Co-authored-by: (.*) <(.*?)>$",
self.message,
re.MULTILINE,
)
for author in results:
co_authors.append(Actor(*author))
return co_authors
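# Sketch tying a few of the methods above together; the repository path and the
# commit message used below are assumptions.
def _demo_commit_usage() -> None:
    from git.repo import Repo

    repo = Repo("/tmp/example-repo")
    head = repo.head.commit
    print(head.summary, head.authored_datetime, head.co_authors)
    print(head.trailers)  # parsed via `git interpret-trailers --parse`
    snapshot = Commit.create_from_tree(
        repo, head.tree, "snapshot of the current tree", parent_commits=[head], head=False
    )
    print("created commit object", snapshot.hexsha)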

View File

@@ -0,0 +1,254 @@
"""Module with functions which are supposed to be as fast as possible"""
from stat import S_ISDIR
from git.compat import safe_decode, defenc
# typing ----------------------------------------------
from typing import (
Callable,
List,
MutableSequence,
Sequence,
Tuple,
TYPE_CHECKING,
Union,
overload,
)
if TYPE_CHECKING:
from _typeshed import ReadableBuffer
from git import GitCmdObjectDB
EntryTup = Tuple[bytes, int, str] # same as TreeCacheTup in tree.py
EntryTupOrNone = Union[EntryTup, None]
# ---------------------------------------------------
__all__ = (
"tree_to_stream",
"tree_entries_from_data",
"traverse_trees_recursive",
"traverse_tree_recursive",
)
def tree_to_stream(entries: Sequence[EntryTup], write: Callable[["ReadableBuffer"], Union[int, None]]) -> None:
"""Write the give list of entries into a stream using its write method
:param entries: **sorted** list of tuples with (binsha, mode, name)
:param write: write method which takes a data string"""
ord_zero = ord("0")
bit_mask = 7 # 3 bits set
for binsha, mode, name in entries:
mode_str = b""
for i in range(6):
mode_str = bytes([((mode >> (i * 3)) & bit_mask) + ord_zero]) + mode_str
# END for each of the 6 octal values
# git slices away the first octal if it's zero
if mode_str[0] == ord_zero:
mode_str = mode_str[1:]
# END save a byte
# here it comes: if the name is actually unicode, the replacement below
# will not work as the binsha is not part of the ascii unicode encoding -
# hence we must convert to an utf8 string for it to work properly.
# According to my tests, this is exactly what git does, that is it just
# takes the input literally, which appears to be utf8 on linux.
if isinstance(name, str):
name_bytes = name.encode(defenc)
else:
name_bytes = name # type: ignore[unreachable] # check runtime types - is always str?
write(b"".join((mode_str, b" ", name_bytes, b"\0", binsha)))
# END for each item
def tree_entries_from_data(data: bytes) -> List[EntryTup]:
"""Reads the binary representation of a tree and returns tuples of Tree items
:param data: data block with tree data (as bytes)
:return: list(tuple(binsha, mode, tree_relative_path), ...)"""
ord_zero = ord("0")
space_ord = ord(" ")
len_data = len(data)
i = 0
out = []
while i < len_data:
mode = 0
# read mode
# Some git versions truncate the leading 0, some don't
# The type will be extracted from the mode later
while data[i] != space_ord:
# move existing mode integer up one level being 3 bits
# and add the actual ordinal value of the character
mode = (mode << 3) + (data[i] - ord_zero)
i += 1
# END while reading mode
# byte is space now, skip it
i += 1
# parse name, it is NULL separated
ns = i
while data[i] != 0:
i += 1
# END while not reached NULL
# default encoding for strings in git is utf8
# Only use the respective unicode object if the byte stream was encoded
name_bytes = data[ns:i]
name = safe_decode(name_bytes)
# byte is NULL, get next 20
i += 1
sha = data[i : i + 20]
i = i + 20
out.append((sha, mode, name))
# END for each byte in data stream
return out
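# Sketch: serialize a tiny, hand-made tree and parse it back; the entry values
# are made up for illustration.
def _demo_tree_round_trip() -> None:
    from io import BytesIO

    entries = [(b"\x11" * 20, 0o100644, "a.txt"), (b"\x22" * 20, 0o040000, "sub")]
    buf = BytesIO()
    tree_to_stream(entries, buf.write)
    assert tree_entries_from_data(buf.getvalue()) == entries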
def _find_by_name(tree_data: MutableSequence[EntryTupOrNone], name: str, is_dir: bool, start_at: int) -> EntryTupOrNone:
"""return data entry matching the given name and tree mode
or None.
Before the item is returned, the respective data item is set to
None in the tree_data list to mark it done"""
try:
item = tree_data[start_at]
if item and item[2] == name and S_ISDIR(item[1]) == is_dir:
tree_data[start_at] = None
return item
except IndexError:
pass
# END exception handling
for index, item in enumerate(tree_data):
if item and item[2] == name and S_ISDIR(item[1]) == is_dir:
tree_data[index] = None
return item
# END if item matches
# END for each item
return None
@overload
def _to_full_path(item: None, path_prefix: str) -> None:
...
@overload
def _to_full_path(item: EntryTup, path_prefix: str) -> EntryTup:
...
def _to_full_path(item: EntryTupOrNone, path_prefix: str) -> EntryTupOrNone:
"""Rebuild entry with given path prefix"""
if not item:
return item
return (item[0], item[1], path_prefix + item[2])
def traverse_trees_recursive(
odb: "GitCmdObjectDB", tree_shas: Sequence[Union[bytes, None]], path_prefix: str
) -> List[Tuple[EntryTupOrNone, ...]]:
"""
:return: list of list with entries according to the given binary tree-shas.
The result is encoded in a list
of n tuple|None per blob/commit, (n == len(tree_shas)), where
* [0] == 20 byte sha
* [1] == mode as int
* [2] == path relative to working tree root
The entry tuple is None if the respective blob/commit did not
exist in the given tree.
:param tree_shas: iterable of shas pointing to trees. All trees must
be on the same level. A tree-sha may be None, in which case its entries are returned as None
:param path_prefix: a prefix to be added to the returned paths on this level,
set it '' for the first iteration
:note: The ordering of the returned items will be partially lost"""
trees_data: List[List[EntryTupOrNone]] = []
nt = len(tree_shas)
for tree_sha in tree_shas:
if tree_sha is None:
data: List[EntryTupOrNone] = []
else:
# make new list for typing as list invariant
data = list(tree_entries_from_data(odb.stream(tree_sha).read()))
# END handle muted trees
trees_data.append(data)
# END for each sha to get data for
out: List[Tuple[EntryTupOrNone, ...]] = []
# find all matching entries and recursively process them together if the match
# is a tree. If the match is a non-tree item, put it into the result.
# Processed items will be set None
for ti, tree_data in enumerate(trees_data):
for ii, item in enumerate(tree_data):
if not item:
continue
# END skip already done items
entries: List[EntryTupOrNone]
entries = [None for _ in range(nt)]
entries[ti] = item
_sha, mode, name = item
is_dir = S_ISDIR(mode) # type mode bits
# find this item in all other tree data items
# wrap around, but stop one before our current index, hence
# ti+nt, not ti+1+nt
for tio in range(ti + 1, ti + nt):
tio = tio % nt
entries[tio] = _find_by_name(trees_data[tio], name, is_dir, ii)
# END for each other item data
# if we are a directory, enter recursion
if is_dir:
out.extend(
traverse_trees_recursive(
odb,
[((ei and ei[0]) or None) for ei in entries],
path_prefix + name + "/",
)
)
else:
out.append(tuple(_to_full_path(e, path_prefix) for e in entries))
# END handle recursion
# finally mark it done
tree_data[ii] = None
# END for each item
# we are done with one tree, set all its data empty
del tree_data[:]
# END for each tree_data chunk
return out
def traverse_tree_recursive(odb: "GitCmdObjectDB", tree_sha: bytes, path_prefix: str) -> List[EntryTup]:
"""
:return: list of entries of the tree pointed to by the binary tree_sha. An entry
has the following format:
* [0] 20 byte sha
* [1] mode as int
* [2] path relative to the repository
:param path_prefix: prefix to prepend to the front of all returned paths"""
entries = []
data = tree_entries_from_data(odb.stream(tree_sha).read())
# unpacking/packing is faster than accessing individual items
for sha, mode, name in data:
if S_ISDIR(mode):
entries.extend(traverse_tree_recursive(odb, sha, path_prefix + name + "/"))
else:
entries.append((sha, mode, path_prefix + name))
# END for each item
return entries
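# Sketch: flatten the HEAD tree into (binsha, mode, path) tuples; the repository
# path is an assumption.
def _demo_traverse_tree_recursive() -> None:
    from git.repo import Repo

    repo = Repo("/tmp/example-repo")
    for sha, mode, path in traverse_tree_recursive(repo.odb, repo.head.commit.tree.binsha, ""):
        print("%06o %s %s" % (mode, sha.hex(), path))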

View File

@@ -0,0 +1,2 @@
# NOTE: Cannot import anything here as the top-level __init__ has to handle
# our dependencies

File diff suppressed because it is too large

View File

@@ -0,0 +1,426 @@
from .base import Submodule, UpdateProgress
from .util import find_first_remote_branch
from git.exc import InvalidGitRepositoryError
import git
import logging
# typing -------------------------------------------------------------------
from typing import TYPE_CHECKING, Union
from git.types import Commit_ish
if TYPE_CHECKING:
from git.repo import Repo
from git.util import IterableList
# ----------------------------------------------------------------------------
__all__ = ["RootModule", "RootUpdateProgress"]
log = logging.getLogger("git.objects.submodule.root")
log.addHandler(logging.NullHandler())
class RootUpdateProgress(UpdateProgress):
"""Utility class which adds more opcodes to the UpdateProgress"""
REMOVE, PATHCHANGE, BRANCHCHANGE, URLCHANGE = [
1 << x for x in range(UpdateProgress._num_op_codes, UpdateProgress._num_op_codes + 4)
]
_num_op_codes = UpdateProgress._num_op_codes + 4
__slots__ = ()
BEGIN = RootUpdateProgress.BEGIN
END = RootUpdateProgress.END
REMOVE = RootUpdateProgress.REMOVE
BRANCHCHANGE = RootUpdateProgress.BRANCHCHANGE
URLCHANGE = RootUpdateProgress.URLCHANGE
PATHCHANGE = RootUpdateProgress.PATHCHANGE
class RootModule(Submodule):
"""A (virtual) Root of all submodules in the given repository. It can be used
to more easily traverse all submodules of the master repository"""
__slots__ = ()
k_root_name = "__ROOT__"
def __init__(self, repo: "Repo"):
# repo, binsha, mode=None, path=None, name = None, parent_commit=None, url=None, ref=None)
super(RootModule, self).__init__(
repo,
binsha=self.NULL_BIN_SHA,
mode=self.k_default_mode,
path="",
name=self.k_root_name,
parent_commit=repo.head.commit,
url="",
branch_path=git.Head.to_full_path(self.k_head_default),
)
def _clear_cache(self) -> None:
"""May not do anything"""
pass
# { Interface
def update(
self,
previous_commit: Union[Commit_ish, None] = None, # type: ignore[override]
recursive: bool = True,
force_remove: bool = False,
init: bool = True,
to_latest_revision: bool = False,
progress: Union[None, "RootUpdateProgress"] = None,
dry_run: bool = False,
force_reset: bool = False,
keep_going: bool = False,
) -> "RootModule":
"""Update the submodules of this repository to the current HEAD commit.
This method behaves smartly by detecting changes to the path of a submodule's
repository, as well as changes to the to-be-checked-out commit or the branch to be
checked out. This works as long as the submodule's ID does not change.
Additionally it will detect addition and removal of submodules, which will be handled
gracefully.
:param previous_commit: If set to a commit-ish, the commit we should use
as the previous commit the HEAD pointed to before it was set to the commit it points to now.
If None, it defaults to HEAD@{1}.
:param recursive: if True, the children of submodules will be updated as well
using the same technique
:param force_remove: If submodules have been deleted, they will be forcibly removed.
Otherwise the update may fail if a submodule's repository cannot be deleted as
changes have been made to it (see Submodule.update() for more information)
:param init: If we encounter a new module which would need to be initialized, then do it.
:param to_latest_revision: If True, instead of checking out the revision pointed to
by this submodule's sha, the checked out tracking branch will be merged with the
latest remote branch fetched from the repository's origin.
Unless force_reset is specified, a local tracking branch will never be reset into its past, therefore
the remote branch must be in the future for this to have an effect.
:param force_reset: if True, submodules may checkout or reset their branch even if the repository has
pending changes that would be overwritten, or if the local tracking branch is in the future of the
remote tracking branch and would be reset into its past.
:param progress: RootUpdateProgress instance or None if no progress should be sent
:param dry_run: if True, operations will not actually be performed. Progress messages
will change accordingly to indicate the WOULD DO state of the operation.
:param keep_going: if True, we will ignore but log all errors, and keep going recursively.
Unless dry_run is set as well, keep_going could cause subsequent/inherited errors you wouldn't see
otherwise.
In conjunction with dry_run, it can be useful to anticipate all errors when updating submodules
:return: self"""
if self.repo.bare:
raise InvalidGitRepositoryError("Cannot update submodules in bare repositories")
# END handle bare
if progress is None:
progress = RootUpdateProgress()
# END assure progress is set
prefix = ""
if dry_run:
prefix = "DRY-RUN: "
repo = self.repo
try:
# SETUP BASE COMMIT
###################
cur_commit = repo.head.commit
if previous_commit is None:
try:
previous_commit = repo.commit(repo.head.log_entry(-1).oldhexsha)
if previous_commit.binsha == previous_commit.NULL_BIN_SHA:
raise IndexError
# END handle initial commit
except IndexError:
# in new repositories, there is no previous commit
previous_commit = cur_commit
# END exception handling
else:
previous_commit = repo.commit(previous_commit) # obtain commit object
# END handle previous commit
psms: "IterableList[Submodule]" = self.list_items(repo, parent_commit=previous_commit)
sms: "IterableList[Submodule]" = self.list_items(repo)
spsms = set(psms)
ssms = set(sms)
# HANDLE REMOVALS
###################
rrsm = spsms - ssms
len_rrsm = len(rrsm)
for i, rsm in enumerate(rrsm):
op = REMOVE
if i == 0:
op |= BEGIN
# END handle begin
# fake it into thinking it's at the current commit to allow deletion
# of previous module. Trigger the cache to be updated before that
progress.update(
op,
i,
len_rrsm,
prefix + "Removing submodule %r at %s" % (rsm.name, rsm.abspath),
)
rsm._parent_commit = repo.head.commit
rsm.remove(
configuration=False,
module=True,
force=force_remove,
dry_run=dry_run,
)
if i == len_rrsm - 1:
op |= END
# END handle end
progress.update(op, i, len_rrsm, prefix + "Done removing submodule %r" % rsm.name)
# END for each removed submodule
# HANDLE PATH RENAMES
#####################
# url changes + branch changes
csms = spsms & ssms
len_csms = len(csms)
for i, csm in enumerate(csms):
psm: "Submodule" = psms[csm.name]
sm: "Submodule" = sms[csm.name]
# PATH CHANGES
##############
if sm.path != psm.path and psm.module_exists():
progress.update(
BEGIN | PATHCHANGE,
i,
len_csms,
prefix + "Moving repository of submodule %r from %s to %s" % (sm.name, psm.abspath, sm.abspath),
)
# move the module to the new path
if not dry_run:
psm.move(sm.path, module=True, configuration=False)
# END handle dry_run
progress.update(
END | PATHCHANGE,
i,
len_csms,
prefix + "Done moving repository of submodule %r" % sm.name,
)
# END handle path changes
if sm.module_exists():
# HANDLE URL CHANGE
###################
if sm.url != psm.url:
# Add the new remote, remove the old one
# This way, if the url just changes, the commits will not
# have to be re-retrieved
nn = "__new_origin__"
smm = sm.module()
rmts = smm.remotes
# don't do anything if the url we are looking for is already in place
if len([r for r in rmts if r.url == sm.url]) == 0:
progress.update(
BEGIN | URLCHANGE,
i,
len_csms,
prefix + "Changing url of submodule %r from %s to %s" % (sm.name, psm.url, sm.url),
)
if not dry_run:
assert nn not in [r.name for r in rmts]
smr = smm.create_remote(nn, sm.url)
smr.fetch(progress=progress)
# If we have a tracking branch, it should be available
# in the new remote as well.
if len([r for r in smr.refs if r.remote_head == sm.branch_name]) == 0:
raise ValueError(
"Submodule branch named %r was not available in new submodule remote at %r"
% (sm.branch_name, sm.url)
)
# END head is not detached
# now delete the changed one
rmt_for_deletion = None
for remote in rmts:
if remote.url == psm.url:
rmt_for_deletion = remote
break
# END if urls match
# END for each remote
# if we didn't find a matching remote, but have exactly one,
# we can safely use this one
if rmt_for_deletion is None:
if len(rmts) == 1:
rmt_for_deletion = rmts[0]
else:
# if we have not found any remote with the original url
# we may not have a name. This is a special case,
# and it's okay to fail here
# Alternatively we could just generate a unique name and leave all
# existing ones in place
raise InvalidGitRepositoryError(
"Couldn't find original remote-repo at url %r" % psm.url
)
# END handle one single remote
# END handle check we found a remote
orig_name = rmt_for_deletion.name
smm.delete_remote(rmt_for_deletion)
# NOTE: Currently we leave tags from the deleted remotes
# as well as separate tracking branches in the possibly totally
# changed repository ( someone could have changed the url to
# another project ). At some point, one might want to clean
# it up, but the danger is high to remove stuff the user
# has added explicitly
# rename the new remote back to what it was
smr.rename(orig_name)
# early on, we verified that our current tracking branch
# exists in the remote. Now we have to assure that the
# sha we point to is still contained in the new remote
# tracking branch.
smsha = sm.binsha
found = False
rref = smr.refs[self.branch_name]
for c in rref.commit.traverse():
if c.binsha == smsha:
found = True
break
# END traverse all commits in search for sha
# END for each commit
if not found:
# adjust our internal binsha to use the one of the remote
# this way, it will be checked out in the next step
# This will change the submodule relative to us, so
# the user will be able to commit the change easily
log.warning(
"Current sha %s was not contained in the tracking\
branch at the new remote, setting it the the remote's tracking branch",
sm.hexsha,
)
sm.binsha = rref.commit.binsha
# END reset binsha
# NOTE: All checkout is performed by the base implementation of update
# END handle dry_run
progress.update(
END | URLCHANGE,
i,
len_csms,
prefix + "Done adjusting url of submodule %r" % (sm.name),
)
# END skip remote handling if new url already exists in module
# END handle url
# HANDLE PATH CHANGES
#####################
if sm.branch_path != psm.branch_path:
# finally, create a new tracking branch which tracks the
# new remote branch
progress.update(
BEGIN | BRANCHCHANGE,
i,
len_csms,
prefix
+ "Changing branch of submodule %r from %s to %s"
% (sm.name, psm.branch_path, sm.branch_path),
)
if not dry_run:
smm = sm.module()
smmr = smm.remotes
# As the branch might not exist yet, we will have to fetch all remotes to be sure ... .
for remote in smmr:
remote.fetch(progress=progress)
# end for each remote
try:
tbr = git.Head.create(
smm,
sm.branch_name,
logmsg="branch: Created from HEAD",
)
except OSError:
# ... or reuse the existing one
tbr = git.Head(smm, sm.branch_path)
# END assure tracking branch exists
tbr.set_tracking_branch(find_first_remote_branch(smmr, sm.branch_name))
# NOTE: All head-resetting is done in the base implementation of update
# but we will have to checkout the new branch here. As it still points to the currently
# checkout out commit, we don't do any harm.
# As we don't want to update working-tree or index, changing the ref is all there is to do
smm.head.reference = tbr
# END handle dry_run
progress.update(
END | BRANCHCHANGE,
i,
len_csms,
prefix + "Done changing branch of submodule %r" % sm.name,
)
# END handle branch
# END handle
# END for each common submodule
except Exception as err:
if not keep_going:
raise
log.error(str(err))
# end handle keep_going
# FINALLY UPDATE ALL ACTUAL SUBMODULES
######################################
for sm in sms:
# update the submodule using the default method
sm.update(
recursive=False,
init=init,
to_latest_revision=to_latest_revision,
progress=progress,
dry_run=dry_run,
force=force_reset,
keep_going=keep_going,
)
# update recursively depth first - question is which inconsistent
# state will be better in case it fails somewhere. Defective branch
# or defective depth. The RootSubmodule type will never process itself,
# which was done in the previous expression
if recursive:
# the module would exist by now if we are not in dry_run mode
if sm.module_exists():
type(self)(sm.module()).update(
recursive=True,
force_remove=force_remove,
init=init,
to_latest_revision=to_latest_revision,
progress=progress,
dry_run=dry_run,
force_reset=force_reset,
keep_going=keep_going,
)
# END handle dry_run
# END handle recursive
# END for each submodule to update
return self
def module(self) -> "Repo":
""":return: the actual repository containing the submodules"""
return self.repo
# } END interface
# } END classes
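# Usage sketch, for illustration: updating all submodules of a repository through
# RootModule, roughly what Repo.submodule_update() does. The repository path is an
# assumed placeholder; dry_run=True only reports what WOULD be done.
def _example_root_update(repo_path: str = "/path/to/repo") -> None:
    from git import Repo

    repo = Repo(repo_path)
    RootModule(repo).update(recursive=True, init=True, dry_run=True, keep_going=True)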

View File

@@ -0,0 +1,118 @@
import git
from git.exc import InvalidGitRepositoryError
from git.config import GitConfigParser
from io import BytesIO
import weakref
# typing -----------------------------------------------------------------------
from typing import Any, Sequence, TYPE_CHECKING, Union
from git.types import PathLike
if TYPE_CHECKING:
from .base import Submodule
from weakref import ReferenceType
from git.repo import Repo
from git.refs import Head
from git import Remote
from git.refs import RemoteReference
__all__ = (
"sm_section",
"sm_name",
"mkhead",
"find_first_remote_branch",
"SubmoduleConfigParser",
)
# { Utilities
def sm_section(name: str) -> str:
""":return: section title used in .gitmodules configuration file"""
return f'submodule "{name}"'
def sm_name(section: str) -> str:
""":return: name of the submodule as parsed from the section name"""
section = section.strip()
return section[11:-1]
def mkhead(repo: "Repo", path: PathLike) -> "Head":
""":return: New branch/head instance"""
return git.Head(repo, git.Head.to_full_path(path))
def find_first_remote_branch(remotes: Sequence["Remote"], branch_name: str) -> "RemoteReference":
"""Find the remote branch matching the name of the given branch or raise InvalidGitRepositoryError"""
for remote in remotes:
try:
return remote.refs[branch_name]
except IndexError:
continue
# END exception handling
# END for remote
raise InvalidGitRepositoryError("Didn't find remote branch '%r' in any of the given remotes" % branch_name)
# } END utilities
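# Usage sketch, for illustration: the section helpers are simple inverses of each
# other for .gitmodules section titles; the submodule name is an assumed placeholder.
def _example_section_roundtrip() -> None:
    section = sm_section("extern/gitdb")       # -> 'submodule "extern/gitdb"'
    assert sm_name(section) == "extern/gitdb"  # strips the 'submodule "..."' wrapper again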
# { Classes
class SubmoduleConfigParser(GitConfigParser):
"""
Catches calls to _write, and updates the .gitmodules blob in the index
with the new data, if we have written into a stream. Otherwise it will
add the local file to the index to make it correspond with the working tree.
Additionally, the cache must be cleared
Please note that no mutating method will work in bare mode
"""
def __init__(self, *args: Any, **kwargs: Any) -> None:
self._smref: Union["ReferenceType[Submodule]", None] = None
self._index = None
self._auto_write = True
super(SubmoduleConfigParser, self).__init__(*args, **kwargs)
# { Interface
def set_submodule(self, submodule: "Submodule") -> None:
"""Set this instance's submodule. It must be called before
the first write operation begins"""
self._smref = weakref.ref(submodule)
def flush_to_index(self) -> None:
"""Flush changes in our configuration file to the index"""
assert self._smref is not None
# should always have a file here
assert not isinstance(self._file_or_files, BytesIO)
sm = self._smref()
if sm is not None:
index = self._index
if index is None:
index = sm.repo.index
# END handle index
index.add([sm.k_modules_file], write=self._auto_write)
sm._clear_cache()
# END handle weakref
# } END interface
# { Overridden Methods
def write(self) -> None: # type: ignore[override]
rval: None = super(SubmoduleConfigParser, self).write()
self.flush_to_index()
return rval
# END overridden methods
# } END classes

View File

@@ -0,0 +1,107 @@
# objects.py
# Copyright (C) 2008, 2009 Michael Trier (mtrier@gmail.com) and contributors
#
# This module is part of GitPython and is released under
# the BSD License: http://www.opensource.org/licenses/bsd-license.php
""" Module containing all object based types. """
from . import base
from .util import get_object_type_by_name, parse_actor_and_date
from ..util import hex_to_bin
from ..compat import defenc
from typing import List, TYPE_CHECKING, Union
from git.types import Literal
if TYPE_CHECKING:
from git.repo import Repo
from git.util import Actor
from .commit import Commit
from .blob import Blob
from .tree import Tree
__all__ = ("TagObject",)
class TagObject(base.Object):
"""Non-Lightweight tag carrying additional information about an object we are pointing to."""
type: Literal["tag"] = "tag"
__slots__ = (
"object",
"tag",
"tagger",
"tagged_date",
"tagger_tz_offset",
"message",
)
def __init__(
self,
repo: "Repo",
binsha: bytes,
object: Union[None, base.Object] = None,
tag: Union[None, str] = None,
tagger: Union[None, "Actor"] = None,
tagged_date: Union[int, None] = None,
tagger_tz_offset: Union[int, None] = None,
message: Union[str, None] = None,
) -> None: # @ReservedAssignment
"""Initialize a tag object with additional data
:param repo: repository this object is located in
:param binsha: 20 byte SHA1
:param object: Object instance of object we are pointing to
:param tag: name of this tag
:param tagger: Actor identifying the tagger
:param tagged_date: int_seconds_since_epoch
is the DateTime of the tag creation - use time.gmtime to convert
it into a different format
:param tagger_tz_offset: int_seconds_west_of_utc is the timezone that the
tagged_date is in, in a format similar to time.altzone
super(TagObject, self).__init__(repo, binsha)
if object is not None:
self.object: Union["Commit", "Blob", "Tree", "TagObject"] = object
if tag is not None:
self.tag = tag
if tagger is not None:
self.tagger = tagger
if tagged_date is not None:
self.tagged_date = tagged_date
if tagger_tz_offset is not None:
self.tagger_tz_offset = tagger_tz_offset
if message is not None:
self.message = message
def _set_cache_(self, attr: str) -> None:
"""Cache all our attributes at once"""
if attr in TagObject.__slots__:
ostream = self.repo.odb.stream(self.binsha)
lines: List[str] = ostream.read().decode(defenc, "replace").splitlines()
_obj, hexsha = lines[0].split(" ")
_type_token, type_name = lines[1].split(" ")
object_type = get_object_type_by_name(type_name.encode("ascii"))
self.object = object_type(self.repo, hex_to_bin(hexsha))
self.tag = lines[2][4:] # tag <tag name>
if len(lines) > 3:
tagger_info = lines[3] # tagger <actor> <date>
(
self.tagger,
self.tagged_date,
self.tagger_tz_offset,
) = parse_actor_and_date(tagger_info)
# line 4 is empty - it could mark the beginning of the next header;
# if there really is no message, it would not exist. Otherwise
# a newline separates the header from the message
if len(lines) > 5:
self.message = "\n".join(lines[5:])
else:
self.message = ""
# END check our attributes
else:
super(TagObject, self)._set_cache_(attr)
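# Usage sketch, for illustration: reading the metadata of an annotated tag. The
# repository path and tag name are assumed placeholders; for a lightweight tag,
# TagReference.tag is None instead of a TagObject.
def _example_read_tag(repo_path: str = "/path/to/repo", tag_name: str = "v1.0") -> None:
    from git import Repo

    repo = Repo(repo_path)
    tagobj = repo.tags[tag_name].tag  # TagObject for annotated tags only
    if tagobj is not None:
        print(tagobj.tag, tagobj.tagger, tagobj.tagged_date, tagobj.tagger_tz_offset)
        print(tagobj.message)
        print("points to:", tagobj.object.hexsha)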

View File

@@ -0,0 +1,424 @@
# tree.py
# Copyright (C) 2008, 2009 Michael Trier (mtrier@gmail.com) and contributors
#
# This module is part of GitPython and is released under
# the BSD License: http://www.opensource.org/licenses/bsd-license.php
from git.util import IterableList, join_path
import git.diff as git_diff
from git.util import to_bin_sha
from . import util
from .base import IndexObject, IndexObjUnion
from .blob import Blob
from .submodule.base import Submodule
from .fun import tree_entries_from_data, tree_to_stream
# typing -------------------------------------------------
from typing import (
Any,
Callable,
Dict,
Iterable,
Iterator,
List,
Tuple,
Type,
Union,
cast,
TYPE_CHECKING,
)
from git.types import PathLike, Literal
if TYPE_CHECKING:
from git.repo import Repo
from io import BytesIO
TreeCacheTup = Tuple[bytes, int, str]
TraversedTreeTup = Union[Tuple[Union["Tree", None], IndexObjUnion, Tuple["Submodule", "Submodule"]]]
# def is_tree_cache(inp: Tuple[bytes, int, str]) -> TypeGuard[TreeCacheTup]:
# return isinstance(inp[0], bytes) and isinstance(inp[1], int) and isinstance(inp[2], str)
# --------------------------------------------------------
cmp: Callable[[str, str], int] = lambda a, b: (a > b) - (a < b)
__all__ = ("TreeModifier", "Tree")
def git_cmp(t1: TreeCacheTup, t2: TreeCacheTup) -> int:
a, b = t1[2], t2[2]
# assert isinstance(a, str) and isinstance(b, str)
len_a, len_b = len(a), len(b)
min_len = min(len_a, len_b)
min_cmp = cmp(a[:min_len], b[:min_len])
if min_cmp:
return min_cmp
return len_a - len_b
def merge_sort(a: List[TreeCacheTup], cmp: Callable[[TreeCacheTup, TreeCacheTup], int]) -> None:
if len(a) < 2:
return None
mid = len(a) // 2
lefthalf = a[:mid]
righthalf = a[mid:]
merge_sort(lefthalf, cmp)
merge_sort(righthalf, cmp)
i = 0
j = 0
k = 0
while i < len(lefthalf) and j < len(righthalf):
if cmp(lefthalf[i], righthalf[j]) <= 0:
a[k] = lefthalf[i]
i = i + 1
else:
a[k] = righthalf[j]
j = j + 1
k = k + 1
while i < len(lefthalf):
a[k] = lefthalf[i]
i = i + 1
k = k + 1
while j < len(righthalf):
a[k] = righthalf[j]
j = j + 1
k = k + 1
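# Usage sketch, for illustration: git_cmp orders raw tree entries (binsha, mode, name)
# the way serialization expects. The shas and modes below are made-up placeholders.
def _example_git_sort() -> None:
    entries = [
        (b"\x00" * 20, 0o100644, "b.txt"),
        (b"\x00" * 20, 0o040000, "a"),
        (b"\x00" * 20, 0o100644, "a.txt"),
    ]
    merge_sort(entries, git_cmp)  # sorts in place using the git name ordering
    print([name for _sha, _mode, name in entries])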
class TreeModifier(object):
"""A utility class providing methods to alter the underlying cache in a list-like fashion.
Once all adjustments are complete, the _cache, which really is a reference to
the cache of a tree, will be sorted, assuring it will be in a serializable state."""
__slots__ = "_cache"
def __init__(self, cache: List[TreeCacheTup]) -> None:
self._cache = cache
def _index_by_name(self, name: str) -> int:
""":return: index of an item with name, or -1 if not found"""
for i, t in enumerate(self._cache):
if t[2] == name:
return i
# END found item
# END for each item in cache
return -1
# { Interface
def set_done(self) -> "TreeModifier":
"""Call this method once you are done modifying the tree information.
It may be called several times, but be aware that each call will cause
a sort operation
:return: self"""
merge_sort(self._cache, git_cmp)
return self
# } END interface
# { Mutators
def add(self, sha: bytes, mode: int, name: str, force: bool = False) -> "TreeModifier":
"""Add the given item to the tree. If an item with the given name already
exists, nothing will be done, but a ValueError will be raised if the
sha and mode of the existing item do not match the one you add, unless
force is True
:param sha: The 20 or 40 byte sha of the item to add
:param mode: int representing the stat compatible mode of the item
:param force: If True, an item with your name and information will overwrite
any existing item with the same name, no matter which information it has
:return: self"""
if "/" in name:
raise ValueError("Name must not contain '/' characters")
if (mode >> 12) not in Tree._map_id_to_type:
raise ValueError("Invalid object type according to mode %o" % mode)
sha = to_bin_sha(sha)
index = self._index_by_name(name)
item = (sha, mode, name)
# assert is_tree_cache(item)
if index == -1:
self._cache.append(item)
else:
if force:
self._cache[index] = item
else:
ex_item = self._cache[index]
if ex_item[0] != sha or ex_item[1] != mode:
raise ValueError("Item %r existed with different properties" % name)
# END handle mismatch
# END handle force
# END handle name exists
return self
def add_unchecked(self, binsha: bytes, mode: int, name: str) -> None:
"""Add the given item to the tree, its correctness is assumed, which
puts the caller into responsibility to assure the input is correct.
For more information on the parameters, see ``add``
:param binsha: 20 byte binary sha"""
assert isinstance(binsha, bytes) and isinstance(mode, int) and isinstance(name, str)
tree_cache = (binsha, mode, name)
self._cache.append(tree_cache)
def __delitem__(self, name: str) -> None:
"""Deletes an item with the given name if it exists"""
index = self._index_by_name(name)
if index > -1:
del self._cache[index]
# } END mutators
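# Usage sketch, for illustration: duplicating an existing blob entry under a new name
# via the TreeModifier returned by Tree.cache. The repository path and the new name
# are assumed placeholders, and the change only affects the in-memory cache.
def _example_tree_modifier(repo_path: str = "/path/to/repo") -> None:
    from git import Repo

    repo = Repo(repo_path)
    tree = repo.head.commit.tree
    blob = tree.blobs[0]                  # reuse a known-good sha and mode (assumes one blob exists)
    mod = tree.cache                      # TreeModifier over this tree's cache
    mod.add(blob.binsha, blob.mode, "copy_of_" + blob.name)
    mod.set_done()                        # re-sort so the tree stays serializable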
class Tree(IndexObject, git_diff.Diffable, util.Traversable, util.Serializable):
"""Tree objects represent an ordered list of Blobs and other Trees.
``Tree as a list``::
Access a specific blob using the
tree['filename'] notation.
You may as well access by index
blob = tree[0]
"""
type: Literal["tree"] = "tree"
__slots__ = "_cache"
# actual integer ids for comparison
commit_id = 0o16 # equals stat.S_IFDIR | stat.S_IFLNK - a directory link
blob_id = 0o10
symlink_id = 0o12
tree_id = 0o04
_map_id_to_type: Dict[int, Type[IndexObjUnion]] = {
commit_id: Submodule,
blob_id: Blob,
symlink_id: Blob
# tree id added once Tree is defined
}
def __init__(
self,
repo: "Repo",
binsha: bytes,
mode: int = tree_id << 12,
path: Union[PathLike, None] = None,
):
super(Tree, self).__init__(repo, binsha, mode, path)
@classmethod
def _get_intermediate_items(
cls,
index_object: IndexObjUnion,
) -> Union[Tuple["Tree", ...], Tuple[()]]:
if index_object.type == "tree":
return tuple(index_object._iter_convert_to_object(index_object._cache))
return ()
def _set_cache_(self, attr: str) -> None:
if attr == "_cache":
# Set the data when we need it
ostream = self.repo.odb.stream(self.binsha)
self._cache: List[TreeCacheTup] = tree_entries_from_data(ostream.read())
else:
super(Tree, self)._set_cache_(attr)
# END handle attribute
def _iter_convert_to_object(self, iterable: Iterable[TreeCacheTup]) -> Iterator[IndexObjUnion]:
"""Iterable yields tuples of (binsha, mode, name), which will be converted
to the respective object representation"""
for binsha, mode, name in iterable:
path = join_path(self.path, name)
try:
yield self._map_id_to_type[mode >> 12](self.repo, binsha, mode, path)
except KeyError as e:
raise TypeError("Unknown mode %o found in tree data for path '%s'" % (mode, path)) from e
# END for each item
def join(self, file: str) -> IndexObjUnion:
"""Find the named object in this tree's contents
:return: ``git.Blob`` or ``git.Tree`` or ``git.Submodule``
:raise KeyError: if given file or tree does not exist in tree"""
msg = "Blob or Tree named %r not found"
if "/" in file:
tree = self
item = self
tokens = file.split("/")
for i, token in enumerate(tokens):
item = tree[token]
if item.type == "tree":
tree = item
else:
# safety assertion - blobs are at the end of the path
if i != len(tokens) - 1:
raise KeyError(msg % file)
return item
# END handle item type
# END for each token of split path
if item == self:
raise KeyError(msg % file)
return item
else:
for info in self._cache:
if info[2] == file: # [2] == name
return self._map_id_to_type[info[1] >> 12](
self.repo, info[0], info[1], join_path(self.path, info[2])
)
# END for each obj
raise KeyError(msg % file)
# END handle long paths
def __truediv__(self, file: str) -> IndexObjUnion:
"""For PY3 only"""
return self.join(file)
@property
def trees(self) -> List["Tree"]:
""":return: list(Tree, ...) list of trees directly below this tree"""
return [i for i in self if i.type == "tree"]
@property
def blobs(self) -> List[Blob]:
""":return: list(Blob, ...) list of blobs directly below this tree"""
return [i for i in self if i.type == "blob"]
@property
def cache(self) -> TreeModifier:
"""
:return: An object allowing to modify the internal cache. This can be used
to change the tree's contents. When done, make sure you call ``set_done``
on the tree modifier, or serialization behaviour will be incorrect.
See the ``TreeModifier`` for more information on how to alter the cache"""
return TreeModifier(self._cache)
def traverse(
self, # type: ignore[override]
predicate: Callable[[Union[IndexObjUnion, TraversedTreeTup], int], bool] = lambda i, d: True,
prune: Callable[[Union[IndexObjUnion, TraversedTreeTup], int], bool] = lambda i, d: False,
depth: int = -1,
branch_first: bool = True,
visit_once: bool = False,
ignore_self: int = 1,
as_edge: bool = False,
) -> Union[Iterator[IndexObjUnion], Iterator[TraversedTreeTup]]:
"""For documentation, see util.Traversable._traverse()
Trees are set to visit_once = False to gain more performance in the traversal"""
# """
# # To typecheck instead of using cast.
# import itertools
# def is_tree_traversed(inp: Tuple) -> TypeGuard[Tuple[Iterator[Union['Tree', 'Blob', 'Submodule']]]]:
# return all(isinstance(x, (Blob, Tree, Submodule)) for x in inp[1])
# ret = super(Tree, self).traverse(predicate, prune, depth, branch_first, visit_once, ignore_self)
# ret_tup = itertools.tee(ret, 2)
# assert is_tree_traversed(ret_tup), f"Type is {[type(x) for x in list(ret_tup[0])]}"
# return ret_tup[0]"""
return cast(
Union[Iterator[IndexObjUnion], Iterator[TraversedTreeTup]],
super(Tree, self)._traverse(
predicate,
prune,
depth, # type: ignore
branch_first,
visit_once,
ignore_self,
),
)
def list_traverse(self, *args: Any, **kwargs: Any) -> IterableList[IndexObjUnion]:
"""
:return: IterableList with the results of the traversal as produced by
traverse()
Tree -> IterableList[Union['Submodule', 'Tree', 'Blob']]
"""
return super(Tree, self)._list_traverse(*args, **kwargs)
# List protocol
def __getslice__(self, i: int, j: int) -> List[IndexObjUnion]:
return list(self._iter_convert_to_object(self._cache[i:j]))
def __iter__(self) -> Iterator[IndexObjUnion]:
return self._iter_convert_to_object(self._cache)
def __len__(self) -> int:
return len(self._cache)
def __getitem__(self, item: Union[str, int, slice]) -> IndexObjUnion:
if isinstance(item, int):
info = self._cache[item]
return self._map_id_to_type[info[1] >> 12](self.repo, info[0], info[1], join_path(self.path, info[2]))
if isinstance(item, str):
# compatibility
return self.join(item)
# END index is basestring
raise TypeError("Invalid index type: %r" % item)
def __contains__(self, item: Union[IndexObjUnion, PathLike]) -> bool:
if isinstance(item, IndexObject):
for info in self._cache:
if item.binsha == info[0]:
return True
# END compare sha
# END for each entry
# END handle item is index object
# compatibility
# treat item as repo-relative path
else:
path = self.path
for info in self._cache:
if item == join_path(path, info[2]):
return True
# END for each item
return False
def __reversed__(self) -> Iterator[IndexObjUnion]:
return reversed(self._iter_convert_to_object(self._cache)) # type: ignore
def _serialize(self, stream: "BytesIO") -> "Tree":
"""Serialize this tree into the stream. Please note that we will assume
our tree data to be in a sorted state. If this is not the case, serialization
will not generate a correct tree representation, as the entries are assumed to be
sorted by the algorithms operating on them"""
tree_to_stream(self._cache, stream.write)
return self
def _deserialize(self, stream: "BytesIO") -> "Tree":
self._cache = tree_entries_from_data(stream.read())
return self
# END tree
# finalize map definition
Tree._map_id_to_type[Tree.tree_id] = Tree
#
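# Usage sketch, for illustration: the list/dict-like protocol of Tree. The repository
# path and the file name are assumed placeholders.
def _example_tree_access(repo_path: str = "/path/to/repo") -> None:
    from git import Repo

    repo = Repo(repo_path)
    tree = repo.head.commit.tree
    print(len(tree), tree[0].name)        # index access through the cache
    readme = tree / "README.md"           # same as tree.join("README.md"); raises KeyError if missing
    print(readme.hexsha)
    for item in tree.traverse():          # blobs, trees and submodules, recursively
        print(item.type, item.path)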

View File

@@ -0,0 +1,637 @@
# util.py
# Copyright (C) 2008, 2009 Michael Trier (mtrier@gmail.com) and contributors
#
# This module is part of GitPython and is released under
# the BSD License: http://www.opensource.org/licenses/bsd-license.php
"""Module for general utility functions"""
# flake8: noqa F401
from abc import ABC, abstractmethod
import warnings
from git.util import IterableList, IterableObj, Actor
import re
from collections import deque
from string import digits
import time
import calendar
from datetime import datetime, timedelta, tzinfo
# typing ------------------------------------------------------------
from typing import (
Any,
Callable,
Deque,
Iterator,
Generic,
NamedTuple,
overload,
Sequence, # NOQA: F401
TYPE_CHECKING,
Tuple,
Type,
TypeVar,
Union,
cast,
)
from git.types import Has_id_attribute, Literal, _T # NOQA: F401
if TYPE_CHECKING:
from io import BytesIO, StringIO
from .commit import Commit
from .blob import Blob
from .tag import TagObject
from .tree import Tree, TraversedTreeTup
from subprocess import Popen
from .submodule.base import Submodule
from git.types import Protocol, runtime_checkable
else:
# Protocol = Generic[_T] # Needed for typing bug #572?
Protocol = ABC
def runtime_checkable(f):
return f
class TraverseNT(NamedTuple):
depth: int
item: Union["Traversable", "Blob"]
src: Union["Traversable", None]
T_TIobj = TypeVar("T_TIobj", bound="TraversableIterableObj") # for TraversableIterableObj.traverse()
TraversedTup = Union[
Tuple[Union["Traversable", None], "Traversable"], # for commit, submodule
"TraversedTreeTup",
] # for tree.traverse()
# --------------------------------------------------------------------
__all__ = (
"get_object_type_by_name",
"parse_date",
"parse_actor_and_date",
"ProcessStreamAdapter",
"Traversable",
"altz_to_utctz_str",
"utctz_to_altz",
"verify_utctz",
"Actor",
"tzoffset",
"utc",
)
ZERO = timedelta(0)
# { Functions
def mode_str_to_int(modestr: Union[bytes, str]) -> int:
"""
:param modestr: string like 755 or 644 or 100644 - only the last 6 chars will be used
:return:
int identifying a mode compatible with the mode ids of the
stat module regarding the rwx permissions for user, group and other,
special flags and file system flags, i.e. whether it is a symlink
for example."""
mode = 0
for iteration, char in enumerate(reversed(modestr[-6:])):
char = cast(Union[str, int], char)
mode += int(char) << iteration * 3
# END for each char
return mode
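# Usage sketch, for illustration: mode strings as found in tree data map onto the
# octal modes used by the stat module.
def _example_mode_str_to_int() -> None:
    assert mode_str_to_int("100644") == 0o100644  # regular file
    assert mode_str_to_int("40000") == 0o040000   # directory entry
    assert mode_str_to_int("120000") == 0o120000  # symlink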
def get_object_type_by_name(
object_type_name: bytes,
) -> Union[Type["Commit"], Type["TagObject"], Type["Tree"], Type["Blob"]]:
"""
:return: type suitable to handle the given object type name.
Use the type to create new instances.
:param object_type_name: Member of TYPES
:raise ValueError: In case object_type_name is unknown"""
if object_type_name == b"commit":
from . import commit
return commit.Commit
elif object_type_name == b"tag":
from . import tag
return tag.TagObject
elif object_type_name == b"blob":
from . import blob
return blob.Blob
elif object_type_name == b"tree":
from . import tree
return tree.Tree
else:
raise ValueError("Cannot handle unknown object type: %s" % object_type_name.decode())
def utctz_to_altz(utctz: str) -> int:
"""Convert a git timezone offset into a timezone offset west of
UTC in seconds (compatible with time.altzone).
:param utctz: git utc timezone string, e.g. +0200
"""
int_utctz = int(utctz)
seconds = ((abs(int_utctz) // 100) * 3600 + (abs(int_utctz) % 100) * 60)
return seconds if int_utctz < 0 else -seconds
def altz_to_utctz_str(altz: int) -> str:
"""Convert a timezone offset west of UTC in seconds into a git timezone offset string
:param altz: timezone offset in seconds west of UTC
"""
hours = abs(altz) // 3600
minutes = (abs(altz) % 3600) // 60
sign = "-" if altz >= 60 else "+"
return "{}{:02}{:02}".format(sign, hours, minutes)
def verify_utctz(offset: str) -> str:
""":raise ValueError: if offset is incorrect
:return: offset"""
fmt_exc = ValueError("Invalid timezone offset format: %s" % offset)
if len(offset) != 5:
raise fmt_exc
if offset[0] not in "+-":
raise fmt_exc
if offset[1] not in digits or offset[2] not in digits or offset[3] not in digits or offset[4] not in digits:
raise fmt_exc
# END for each char
return offset
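# Usage sketch, for illustration: the git offset string and the seconds-west-of-UTC
# value round-trip through these helpers.
def _example_timezone_roundtrip() -> None:
    altz = utctz_to_altz(verify_utctz("+0200"))  # 2 hours east of UTC
    assert altz == -7200                         # negative == east of UTC, like time.altzone
    assert altz_to_utctz_str(altz) == "+0200"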
class tzoffset(tzinfo):
def __init__(self, secs_west_of_utc: float, name: Union[None, str] = None) -> None:
self._offset = timedelta(seconds=-secs_west_of_utc)
self._name = name or "fixed"
def __reduce__(self) -> Tuple[Type["tzoffset"], Tuple[float, str]]:
return tzoffset, (-self._offset.total_seconds(), self._name)
def utcoffset(self, dt: Union[datetime, None]) -> timedelta:
return self._offset
def tzname(self, dt: Union[datetime, None]) -> str:
return self._name
def dst(self, dt: Union[datetime, None]) -> timedelta:
return ZERO
utc = tzoffset(0, "UTC")
def from_timestamp(timestamp: float, tz_offset: float) -> datetime:
"""Converts a timestamp + tz_offset into an aware datetime instance."""
utc_dt = datetime.fromtimestamp(timestamp, utc)
try:
local_dt = utc_dt.astimezone(tzoffset(tz_offset))
return local_dt
except ValueError:
return utc_dt
def parse_date(string_date: Union[str, datetime]) -> Tuple[int, int]:
"""
Parse the given date as one of the following
* aware datetime instance
* Git internal format: timestamp offset
* RFC 2822: Thu, 07 Apr 2005 22:13:13 +0200.
* ISO 8601 2005-04-07T22:13:13
The T can be a space as well
:return: Tuple(int(timestamp_UTC), int(offset)), both in seconds since epoch
:raise ValueError: If the format could not be understood
:note: Date can also be YYYY.MM.DD, MM/DD/YYYY and DD.MM.YYYY.
"""
if isinstance(string_date, datetime):
if string_date.tzinfo:
utcoffset = cast(timedelta, string_date.utcoffset())  # typeguard, since tzinfo is not None
offset = -int(utcoffset.total_seconds())
return int(string_date.astimezone(utc).timestamp()), offset
else:
raise ValueError(f"string_date datetime object without tzinfo, {string_date}")
# git time
try:
if string_date.count(" ") == 1 and string_date.rfind(":") == -1:
timestamp, offset_str = string_date.split()
if timestamp.startswith("@"):
timestamp = timestamp[1:]
timestamp_int = int(timestamp)
return timestamp_int, utctz_to_altz(verify_utctz(offset_str))
else:
offset_str = "+0000" # local time by default
if string_date[-5] in "-+":
offset_str = verify_utctz(string_date[-5:])
string_date = string_date[:-6] # skip space as well
# END split timezone info
offset = utctz_to_altz(offset_str)
# now figure out the date and time portion - split time
date_formats = []
splitter = -1
if "," in string_date:
date_formats.append("%a, %d %b %Y")
splitter = string_date.rfind(" ")
else:
# iso plus additional
date_formats.append("%Y-%m-%d")
date_formats.append("%Y.%m.%d")
date_formats.append("%m/%d/%Y")
date_formats.append("%d.%m.%Y")
splitter = string_date.rfind("T")
if splitter == -1:
splitter = string_date.rfind(" ")
# END handle 'T' and ' '
# END handle rfc or iso
assert splitter > -1
# split date and time
time_part = string_date[splitter + 1 :] # skip space
date_part = string_date[:splitter]
# parse time
tstruct = time.strptime(time_part, "%H:%M:%S")
for fmt in date_formats:
try:
dtstruct = time.strptime(date_part, fmt)
utctime = calendar.timegm(
(
dtstruct.tm_year,
dtstruct.tm_mon,
dtstruct.tm_mday,
tstruct.tm_hour,
tstruct.tm_min,
tstruct.tm_sec,
dtstruct.tm_wday,
dtstruct.tm_yday,
tstruct.tm_isdst,
)
)
return int(utctime), offset
except ValueError:
continue
# END exception handling
# END for each fmt
# still here ? fail
raise ValueError("no format matched")
# END handle format
except Exception as e:
raise ValueError(f"Unsupported date format or type: {string_date}, type={type(string_date)}") from e
# END handle exceptions
# precompiled regex
_re_actor_epoch = re.compile(r"^.+? (.*) (\d+) ([+-]\d+).*$")
_re_only_actor = re.compile(r"^.+? (.*)$")
def parse_actor_and_date(line: str) -> Tuple[Actor, int, int]:
"""Parse out the actor (author or committer) info from a line like::
author Tom Preston-Werner <tom@mojombo.com> 1191999972 -0700
:return: [Actor, int_seconds_since_epoch, int_timezone_offset]"""
actor, epoch, offset = "", "0", "0"
m = _re_actor_epoch.search(line)
if m:
actor, epoch, offset = m.groups()
else:
m = _re_only_actor.search(line)
actor = m.group(1) if m else line or ""
return (Actor._from_string(actor), int(epoch), utctz_to_altz(offset))
# } END functions
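# Usage sketch, for illustration: the parsers accept git's internal
# "<timestamp> <offset>" format as well as RFC 2822 / ISO 8601 dates, and actor
# lines as they appear in commit and tag objects. The sample values are made up.
def _example_date_parsing() -> None:
    assert parse_date("1191999972 -0700") == (1191999972, 25200)
    stamp, offset = parse_date("2005-04-07T22:13:13 +0200")
    print(stamp, offset)  # seconds since epoch, seconds west of UTC
    actor, epoch, tz = parse_actor_and_date("author A U Thor <author@example.com> 1191999972 -0700")
    print(actor.name, actor.email, epoch, tz)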
# { Classes
class ProcessStreamAdapter(object):
"""Class wireing all calls to the contained Process instance.
Use this type to hide the underlying process to provide access only to a specified
stream. The process is usually wrapped into an AutoInterrupt class to kill
it if the instance goes out of scope."""
__slots__ = ("_proc", "_stream")
def __init__(self, process: "Popen", stream_name: str) -> None:
self._proc = process
self._stream: StringIO = getattr(process, stream_name) # guessed type
def __getattr__(self, attr: str) -> Any:
return getattr(self._stream, attr)
@runtime_checkable
class Traversable(Protocol):
"""Simple interface to perform depth-first or breadth-first traversals
into one direction.
Subclasses only need to implement one function.
Instances of the Subclass must be hashable
Defined subclasses = [Commit, Tree, SubModule]
"""
__slots__ = ()
@classmethod
@abstractmethod
def _get_intermediate_items(cls, item: Any) -> Sequence["Traversable"]:
"""
Returns:
Tuple of items connected to the given item.
Must be implemented in subclass
class Commit:: (cls, Commit) -> Tuple[Commit, ...]
class Submodule:: (cls, Submodule) -> IterableList[Submodule]
class Tree:: (cls, Tree) -> Tuple[Tree, ...]
"""
raise NotImplementedError("To be implemented in subclass")
@abstractmethod
def list_traverse(self, *args: Any, **kwargs: Any) -> Any:
""" """
warnings.warn(
"list_traverse() method should only be called from subclasses."
"Calling from Traversable abstract class will raise NotImplementedError in 3.1.20"
"Builtin sublclasses are 'Submodule', 'Tree' and 'Commit",
DeprecationWarning,
stacklevel=2,
)
return self._list_traverse(*args, **kwargs)
def _list_traverse(
self, as_edge: bool = False, *args: Any, **kwargs: Any
) -> IterableList[Union["Commit", "Submodule", "Tree", "Blob"]]:
"""
:return: IterableList with the results of the traversal as produced by
traverse()
Commit -> IterableList['Commit']
Submodule -> IterableList['Submodule']
Tree -> IterableList[Union['Submodule', 'Tree', 'Blob']]
"""
# Commit and Submodule have id.__attribute__ as IterableObj
# Tree has id.__attribute__ inherited from IndexObject
if isinstance(self, Has_id_attribute):
id = self._id_attribute_
else:
id = "" # shouldn't reach here, unless Traversable subclass created with no _id_attribute_
# could add _id_attribute_ to Traversable, or make all Traversable also Iterable?
if not as_edge:
out: IterableList[Union["Commit", "Submodule", "Tree", "Blob"]] = IterableList(id)
out.extend(self.traverse(as_edge=as_edge, *args, **kwargs))
return out
# overloads in subclasses (mypy doesn't allow typing self: subclass)
# Union[IterableList['Commit'], IterableList['Submodule'], IterableList[Union['Submodule', 'Tree', 'Blob']]]
else:
# Raise deprecationwarning, doesn't make sense to use this
out_list: IterableList = IterableList(self.traverse(*args, **kwargs))
return out_list
@abstractmethod
def traverse(self, *args: Any, **kwargs: Any) -> Any:
""" """
warnings.warn(
"traverse() method should only be called from subclasses."
"Calling from Traversable abstract class will raise NotImplementedError in 3.1.20"
"Builtin sublclasses are 'Submodule', 'Tree' and 'Commit",
DeprecationWarning,
stacklevel=2,
)
return self._traverse(*args, **kwargs)
def _traverse(
self,
predicate: Callable[[Union["Traversable", "Blob", TraversedTup], int], bool] = lambda i, d: True,
prune: Callable[[Union["Traversable", "Blob", TraversedTup], int], bool] = lambda i, d: False,
depth: int = -1,
branch_first: bool = True,
visit_once: bool = True,
ignore_self: int = 1,
as_edge: bool = False,
) -> Union[Iterator[Union["Traversable", "Blob"]], Iterator[TraversedTup]]:
""":return: iterator yielding of items found when traversing self
:param predicate: f(i,d) returns False if item i at depth d should not be included in the result
:param prune:
f(i,d) return True if the search should stop at item i at depth d.
Item i will not be returned.
:param depth:
define at which level the iteration should not go deeper
if -1, there is no limit
if 0, you would effectively only get self, the root of the iteration
i.e. if 1, you would only get the first level of predecessors/successors
:param branch_first:
if True, items will be returned branch first, otherwise depth first
:param visit_once:
if True, items will only be returned once, although they might be encountered
several times. Loops are prevented that way.
:param ignore_self:
if True, self will be ignored and automatically pruned from
the result. Otherwise it will be the first item to be returned.
If as_edge is True, the source of the first edge is None
:param as_edge:
if True, return a pair of items, first being the source, second the
destination, i.e. tuple(src, dest) with the edge spanning from
source to destination"""
"""
Commit -> Iterator[Union[Commit, Tuple[Commit, Commit]]
Submodule -> Iterator[Submodule, Tuple[Submodule, Submodule]]
Tree -> Iterator[Union[Blob, Tree, Submodule,
Tuple[Union[Submodule, Tree], Union[Blob, Tree, Submodule]]]
ignore_self=True is_edge=True -> Iterator[item]
ignore_self=True is_edge=False --> Iterator[item]
ignore_self=False is_edge=True -> Iterator[item] | Iterator[Tuple[src, item]]
ignore_self=False is_edge=False -> Iterator[Tuple[src, item]]"""
visited = set()
stack: Deque[TraverseNT] = deque()
stack.append(TraverseNT(0, self, None)) # self is always depth level 0
def addToStack(
stack: Deque[TraverseNT],
src_item: "Traversable",
branch_first: bool,
depth: int,
) -> None:
lst = self._get_intermediate_items(src_item)  # src_item is the item whose successors we expand
if not lst: # empty list
return None
if branch_first:
stack.extendleft(TraverseNT(depth, i, src_item) for i in lst)
else:
reviter = (TraverseNT(depth, lst[i], src_item) for i in range(len(lst) - 1, -1, -1))
stack.extend(reviter)
# END addToStack local method
while stack:
d, item, src = stack.pop() # depth of item, item, item_source
if visit_once and item in visited:
continue
if visit_once:
visited.add(item)
rval: Union[TraversedTup, "Traversable", "Blob"]
if as_edge:  # if as_edge, return (src, item) unless src is None (e.g. for the first item)
rval = (src, item)
else:
rval = item
if prune(rval, d):
continue
skipStartItem = ignore_self and (item is self)
if not skipStartItem and predicate(rval, d):
yield rval
# only continue to next level if this is appropriate !
nd = d + 1
if depth > -1 and nd > depth:
continue
addToStack(stack, item, branch_first, nd)
# END for each item on work stack
@runtime_checkable
class Serializable(Protocol):
"""Defines methods to serialize and deserialize objects from and into a data stream"""
__slots__ = ()
# @abstractmethod
def _serialize(self, stream: "BytesIO") -> "Serializable":
"""Serialize the data of this object into the given data stream
:note: a serialized object would ``_deserialize`` into the same object
:param stream: a file-like object
:return: self"""
raise NotImplementedError("To be implemented in subclass")
# @abstractmethod
def _deserialize(self, stream: "BytesIO") -> "Serializable":
"""Deserialize all information regarding this object from the stream
:param stream: a file-like object
:return: self"""
raise NotImplementedError("To be implemented in subclass")
class TraversableIterableObj(IterableObj, Traversable):
__slots__ = ()
TIobj_tuple = Tuple[Union[T_TIobj, None], T_TIobj]
def list_traverse(self: T_TIobj, *args: Any, **kwargs: Any) -> IterableList[T_TIobj]:
return super(TraversableIterableObj, self)._list_traverse(*args, **kwargs)
@overload # type: ignore
def traverse(self: T_TIobj) -> Iterator[T_TIobj]:
...
@overload
def traverse(
self: T_TIobj,
predicate: Callable[[Union[T_TIobj, Tuple[Union[T_TIobj, None], T_TIobj]], int], bool],
prune: Callable[[Union[T_TIobj, Tuple[Union[T_TIobj, None], T_TIobj]], int], bool],
depth: int,
branch_first: bool,
visit_once: bool,
ignore_self: Literal[True],
as_edge: Literal[False],
) -> Iterator[T_TIobj]:
...
@overload
def traverse(
self: T_TIobj,
predicate: Callable[[Union[T_TIobj, Tuple[Union[T_TIobj, None], T_TIobj]], int], bool],
prune: Callable[[Union[T_TIobj, Tuple[Union[T_TIobj, None], T_TIobj]], int], bool],
depth: int,
branch_first: bool,
visit_once: bool,
ignore_self: Literal[False],
as_edge: Literal[True],
) -> Iterator[Tuple[Union[T_TIobj, None], T_TIobj]]:
...
@overload
def traverse(
self: T_TIobj,
predicate: Callable[[Union[T_TIobj, TIobj_tuple], int], bool],
prune: Callable[[Union[T_TIobj, TIobj_tuple], int], bool],
depth: int,
branch_first: bool,
visit_once: bool,
ignore_self: Literal[True],
as_edge: Literal[True],
) -> Iterator[Tuple[T_TIobj, T_TIobj]]:
...
def traverse(
self: T_TIobj,
predicate: Callable[[Union[T_TIobj, TIobj_tuple], int], bool] = lambda i, d: True,
prune: Callable[[Union[T_TIobj, TIobj_tuple], int], bool] = lambda i, d: False,
depth: int = -1,
branch_first: bool = True,
visit_once: bool = True,
ignore_self: int = 1,
as_edge: bool = False,
) -> Union[Iterator[T_TIobj], Iterator[Tuple[T_TIobj, T_TIobj]], Iterator[TIobj_tuple]]:
"""For documentation, see util.Traversable._traverse()"""
"""
# To typecheck instead of using cast.
import itertools
from git.types import TypeGuard
def is_commit_traversed(inp: Tuple) -> TypeGuard[Tuple[Iterator[Tuple['Commit', 'Commit']]]]:
for x in inp[1]:
if not isinstance(x, tuple) and len(x) != 2:
if all(isinstance(inner, Commit) for inner in x):
continue
return True
ret = super(Commit, self).traverse(predicate, prune, depth, branch_first, visit_once, ignore_self, as_edge)
ret_tup = itertools.tee(ret, 2)
assert is_commit_traversed(ret_tup), f"{[type(x) for x in list(ret_tup[0])]}"
return ret_tup[0]
"""
return cast(
Union[Iterator[T_TIobj], Iterator[Tuple[Union[None, T_TIobj], T_TIobj]]],
super(TraversableIterableObj, self)._traverse(
predicate, prune, depth, branch_first, visit_once, ignore_self, as_edge # type: ignore
),
)
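# Usage sketch, for illustration: commit history traversal through the Traversable
# machinery above. The repository path is an assumed placeholder; Commit is a
# TraversableIterableObj whose intermediate items are its parent commits.
def _example_commit_traversal(repo_path: str = "/path/to/repo") -> None:
    from git import Repo

    repo = Repo(repo_path)
    head = repo.head.commit
    for commit in head.traverse(depth=2, branch_first=False):
        print(commit.hexsha)              # ancestors up to two levels deep, depth first
    for src, dest in head.traverse(as_edge=True, ignore_self=0):
        print(src, "->", dest.hexsha)     # src is None for the starting commit itself
        break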

View File

@@ -0,0 +1,9 @@
# flake8: noqa
# import all modules in order, fix the names they require
from .symbolic import *
from .reference import *
from .head import *
from .tag import *
from .remote import *
from .log import *

View File

@@ -0,0 +1,277 @@
from git.config import GitConfigParser, SectionConstraint
from git.util import join_path
from git.exc import GitCommandError
from .symbolic import SymbolicReference
from .reference import Reference
# typing ----------------------------------------------------
from typing import Any, Sequence, Union, TYPE_CHECKING
from git.types import PathLike, Commit_ish
if TYPE_CHECKING:
from git.repo import Repo
from git.objects import Commit
from git.refs import RemoteReference
# -------------------------------------------------------------------
__all__ = ["HEAD", "Head"]
def strip_quotes(string: str) -> str:
if string.startswith('"') and string.endswith('"'):
return string[1:-1]
return string
class HEAD(SymbolicReference):
"""Special case of a Symbolic Reference as it represents the repository's
HEAD reference."""
_HEAD_NAME = "HEAD"
_ORIG_HEAD_NAME = "ORIG_HEAD"
__slots__ = ()
def __init__(self, repo: "Repo", path: PathLike = _HEAD_NAME):
if path != self._HEAD_NAME:
raise ValueError("HEAD instance must point to %r, got %r" % (self._HEAD_NAME, path))
super(HEAD, self).__init__(repo, path)
self.commit: "Commit"
def orig_head(self) -> SymbolicReference:
"""
:return: SymbolicReference pointing at the ORIG_HEAD, which is maintained
to contain the previous value of HEAD"""
return SymbolicReference(self.repo, self._ORIG_HEAD_NAME)
def reset(
self,
commit: Union[Commit_ish, SymbolicReference, str] = "HEAD",
index: bool = True,
working_tree: bool = False,
paths: Union[PathLike, Sequence[PathLike], None] = None,
**kwargs: Any,
) -> "HEAD":
"""Reset our HEAD to the given commit optionally synchronizing
the index and working tree. The reference we refer to will be set to
commit as well.
:param commit:
Commit object, Reference Object or string identifying a revision we
should reset HEAD to.
:param index:
If True, the index will be set to match the given commit. Otherwise
it will not be touched.
:param working_tree:
If True, the working tree will be forcefully adjusted to match the given
commit, possibly overwriting uncommitted changes without warning.
If working_tree is True, index must be true as well
:param paths:
Single path or list of paths relative to the git root directory
that are to be reset. This allows one to partially reset individual files.
:param kwargs:
Additional arguments passed to git-reset.
:return: self"""
mode: Union[str, None]
mode = "--soft"
if index:
mode = "--mixed"
# it appears, some git-versions declare mixed and paths deprecated
# see http://github.com/Byron/GitPython/issues#issue/2
if paths:
mode = None
# END special case
# END handle index
if working_tree:
mode = "--hard"
if not index:
raise ValueError("Cannot reset the working tree if the index is not reset as well")
# END working tree handling
try:
self.repo.git.reset(mode, commit, "--", paths, **kwargs)
except GitCommandError as e:
# git nowadays may use 1 as status to indicate there are still unstaged
# modifications after the reset
if e.status != 1:
raise
# END handle exception
return self
class Head(Reference):
"""A Head is a named reference to a Commit. Every Head instance contains a name
and a Commit object.
Examples::
>>> repo = Repo("/path/to/repo")
>>> head = repo.heads[0]
>>> head.name
'master'
>>> head.commit
<git.Commit "1c09f116cbc2cb4100fb6935bb162daa4723f455">
>>> head.commit.hexsha
'1c09f116cbc2cb4100fb6935bb162daa4723f455'"""
_common_path_default = "refs/heads"
k_config_remote = "remote"
k_config_remote_ref = "merge" # branch to merge from remote
@classmethod
def delete(cls, repo: "Repo", *heads: "Union[Head, str]", force: bool = False, **kwargs: Any) -> None:
"""Delete the given heads
:param force:
If True, the heads will be deleted even if they are not yet merged into
the main development stream.
Default False"""
flag = "-d"
if force:
flag = "-D"
repo.git.branch(flag, *heads)
def set_tracking_branch(self, remote_reference: Union["RemoteReference", None]) -> "Head":
"""
Configure this branch to track the given remote reference. This will alter
this branch's configuration accordingly.
:param remote_reference: The remote reference to track or None to untrack
any references
:return: self"""
from .remote import RemoteReference
if remote_reference is not None and not isinstance(remote_reference, RemoteReference):
raise ValueError("Incorrect parameter type: %r" % remote_reference)
# END handle type
with self.config_writer() as writer:
if remote_reference is None:
writer.remove_option(self.k_config_remote)
writer.remove_option(self.k_config_remote_ref)
if len(writer.options()) == 0:
writer.remove_section()
else:
writer.set_value(self.k_config_remote, remote_reference.remote_name)
writer.set_value(
self.k_config_remote_ref,
Head.to_full_path(remote_reference.remote_head),
)
return self
def tracking_branch(self) -> Union["RemoteReference", None]:
"""
:return: The remote_reference we are tracking, or None if we are
not a tracking branch"""
from .remote import RemoteReference
reader = self.config_reader()
if reader.has_option(self.k_config_remote) and reader.has_option(self.k_config_remote_ref):
ref = Head(
self.repo,
Head.to_full_path(strip_quotes(reader.get_value(self.k_config_remote_ref))),
)
remote_refpath = RemoteReference.to_full_path(join_path(reader.get_value(self.k_config_remote), ref.name))
return RemoteReference(self.repo, remote_refpath)
# END handle have tracking branch
# we are not a tracking branch
return None
def rename(self, new_path: PathLike, force: bool = False) -> "Head":
"""Rename self to a new path
:param new_path:
Either a simple name or a path, i.e. new_name or features/new_name.
The prefix refs/heads is implied
:param force:
If True, the rename will succeed even if a head with the target name
already exists.
:return: self
:note: respects the ref log as git commands are used"""
flag = "-m"
if force:
flag = "-M"
self.repo.git.branch(flag, self, new_path)
self.path = "%s/%s" % (self._common_path_default, new_path)
return self
def checkout(self, force: bool = False, **kwargs: Any) -> Union["HEAD", "Head"]:
"""Checkout this head by setting the HEAD to this reference, by updating the index
to reflect the tree we point to and by updating the working tree to reflect
the latest index.
The command will fail if changed working tree files would be overwritten.
:param force:
If True, changes to the index and the working tree will be discarded.
If False, GitCommandError will be raised in that situation.
:param kwargs:
Additional keyword arguments to be passed to git checkout, i.e.
b='new_branch' to create a new branch at the given spot.
:return:
The active branch after the checkout operation, usually self unless
a new branch has been created.
If there is no active branch, as the HEAD is now detached, the HEAD
reference will be returned instead.
:note:
By default it is only allowed to checkout heads - everything else
will leave the HEAD detached which is allowed and possible, but remains
a special state that some tools might not be able to handle."""
kwargs["f"] = force
if kwargs["f"] is False:
kwargs.pop("f")
self.repo.git.checkout(self, **kwargs)
if self.repo.head.is_detached:
return self.repo.head
else:
return self.repo.active_branch
# { Configuration
def _config_parser(self, read_only: bool) -> SectionConstraint[GitConfigParser]:
if read_only:
parser = self.repo.config_reader()
else:
parser = self.repo.config_writer()
# END handle parser instance
return SectionConstraint(parser, 'branch "%s"' % self.name)
def config_reader(self) -> SectionConstraint[GitConfigParser]:
"""
:return: A configuration parser instance constrained to only read
this instance's values"""
return self._config_parser(read_only=True)
def config_writer(self) -> SectionConstraint[GitConfigParser]:
"""
:return: A configuration writer instance with read- and write access
to options of this head"""
return self._config_parser(read_only=False)
# } END configuration
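# Usage sketch, for illustration: typical branch handling with Head. The repository
# path, branch name and remote index are assumed placeholders.
def _example_branch_handling(repo_path: str = "/path/to/repo") -> None:
    from git import Repo

    repo = Repo(repo_path)
    head = repo.create_head("feature/demo")        # new branch at the current HEAD commit
    if repo.remotes:                               # track some remote branch, if one exists
        head.set_tracking_branch(repo.remotes[0].refs[0])
        print(head.tracking_branch())
    head.rename("feature/demo-renamed")
    # repo.head.reset("HEAD~1", index=True)        # would move HEAD and the index one commit back
    Head.delete(repo, head, force=True)            # remove the demo branch again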

Some files were not shown because too many files have changed in this diff