upload
.idea/.gitignore (new file, generated, vendored, 8 lines)
@@ -0,0 +1,8 @@
# Default ignored files
/shelf/
/workspace.xml
# Editor-based HTTP Client requests
/httpRequests/
# Datasource local storage ignored files
/dataSources/
/dataSources.local.xml
.idea/MeCo.iml (new file, generated, 8 lines)
@@ -0,0 +1,8 @@
<?xml version="1.0" encoding="UTF-8"?>
<module type="PYTHON_MODULE" version="4">
  <component name="NewModuleRootManager">
    <content url="file://$MODULE_DIR$" />
    <orderEntry type="inheritedJdk" />
    <orderEntry type="sourceFolder" forTests="false" />
  </component>
</module>
.idea/deployment.xml (new file, generated, 126 lines)
@@ -0,0 +1,126 @@
<?xml version="1.0" encoding="UTF-8"?>
<project version="4">
  <component name="PublishConfigData" remoteFilesAllowedToDisappearOnAutoupload="false">
    <serverData>
      <paths name="ubuntu@172.16.214.100:7712 password">
        <serverdata>
          <mappings>
            <mapping local="$PROJECT_DIR$" web="/" />
          </mappings>
        </serverdata>
      </paths>
      <paths name="ubuntu@172.16.214.100:7712 password (10)">
        <serverdata>
          <mappings>
            <mapping local="$PROJECT_DIR$" web="/" />
          </mappings>
        </serverdata>
      </paths>
      <paths name="ubuntu@172.16.214.100:7712 password (11)">
        <serverdata>
          <mappings>
            <mapping local="$PROJECT_DIR$" web="/" />
          </mappings>
        </serverdata>
      </paths>
      <paths name="ubuntu@172.16.214.100:7712 password (2)">
        <serverdata>
          <mappings>
            <mapping local="$PROJECT_DIR$" web="/" />
          </mappings>
        </serverdata>
      </paths>
      <paths name="ubuntu@172.16.214.100:7712 password (3)">
        <serverdata>
          <mappings>
            <mapping local="$PROJECT_DIR$" web="/" />
          </mappings>
        </serverdata>
      </paths>
      <paths name="ubuntu@172.16.214.100:7712 password (4)">
        <serverdata>
          <mappings>
            <mapping local="$PROJECT_DIR$" web="/" />
          </mappings>
        </serverdata>
      </paths>
      <paths name="ubuntu@172.16.214.100:7712 password (5)">
        <serverdata>
          <mappings>
            <mapping local="$PROJECT_DIR$" web="/" />
          </mappings>
        </serverdata>
      </paths>
      <paths name="ubuntu@172.16.214.100:7712 password (6)">
        <serverdata>
          <mappings>
            <mapping local="$PROJECT_DIR$" web="/" />
          </mappings>
        </serverdata>
      </paths>
      <paths name="ubuntu@172.16.214.100:7712 password (7)">
        <serverdata>
          <mappings>
            <mapping local="$PROJECT_DIR$" web="/" />
          </mappings>
        </serverdata>
      </paths>
      <paths name="ubuntu@172.16.214.100:7712 password (8)">
        <serverdata>
          <mappings>
            <mapping local="$PROJECT_DIR$" web="/" />
          </mappings>
        </serverdata>
      </paths>
      <paths name="ubuntu@172.16.214.100:7712 password (9)">
        <serverdata>
          <mappings>
            <mapping local="$PROJECT_DIR$" web="/" />
          </mappings>
        </serverdata>
      </paths>
      <paths name="ubuntu@172.16.214.99:7712 password">
        <serverdata>
          <mappings>
            <mapping local="$PROJECT_DIR$" web="/" />
          </mappings>
        </serverdata>
      </paths>
      <paths name="ubuntu@172.16.214.99:7712 password (2)">
        <serverdata>
          <mappings>
            <mapping local="$PROJECT_DIR$" web="/" />
          </mappings>
        </serverdata>
      </paths>
      <paths name="ubuntu@172.16.214.99:7712 password (3)">
        <serverdata>
          <mappings>
            <mapping local="$PROJECT_DIR$" web="/" />
          </mappings>
        </serverdata>
      </paths>
      <paths name="ubuntu@172.16.214.99:7712 password (4)">
        <serverdata>
          <mappings>
            <mapping local="$PROJECT_DIR$" web="/" />
          </mappings>
        </serverdata>
      </paths>
      <paths name="ubuntu@172.16.214.99:7712 password (5)">
        <serverdata>
          <mappings>
            <mapping local="$PROJECT_DIR$" web="/" />
          </mappings>
        </serverdata>
      </paths>
      <paths name="ubuntu@172.16.214.99:7712 password (6)">
        <serverdata>
          <mappings>
            <mapping local="$PROJECT_DIR$" web="/" />
          </mappings>
        </serverdata>
      </paths>
    </serverData>
  </component>
</project>
.idea/inspectionProfiles/Project_Default.xml (new file, generated, 26 lines)
@@ -0,0 +1,26 @@
<component name="InspectionProjectProfileManager">
  <profile version="1.0">
    <option name="myName" value="Project Default" />
    <inspection_tool class="PyPackageRequirementsInspection" enabled="true" level="WARNING" enabled_by_default="true">
      <option name="ignoredPackages">
        <value>
          <list size="6">
            <item index="0" class="java.lang.String" itemvalue="netifaces" />
            <item index="1" class="java.lang.String" itemvalue="scikit-learn" />
            <item index="2" class="java.lang.String" itemvalue="torch" />
            <item index="3" class="java.lang.String" itemvalue="Auto-PyTorch" />
            <item index="4" class="java.lang.String" itemvalue="torchvision" />
            <item index="5" class="java.lang.String" itemvalue="tensorflow-gpu" />
          </list>
        </value>
      </option>
    </inspection_tool>
    <inspection_tool class="PyUnresolvedReferencesInspection" enabled="true" level="WARNING" enabled_by_default="true">
      <option name="ignoredIdentifiers">
        <list>
          <option value="random.random.choices" />
        </list>
      </option>
    </inspection_tool>
  </profile>
</component>
.idea/inspectionProfiles/profiles_settings.xml (new file, generated, 6 lines)
@@ -0,0 +1,6 @@
<component name="InspectionProjectProfileManager">
  <settings>
    <option name="USE_PROJECT_PROFILE" value="false" />
    <version value="1.0" />
  </settings>
</component>
.idea/misc.xml (new file, generated, 4 lines)
@@ -0,0 +1,4 @@
<?xml version="1.0" encoding="UTF-8"?>
<project version="4">
  <component name="ProjectRootManager" version="2" project-jdk-name="Remote Python 3.8.16 (sftp://ubuntu@172.16.214.100:7712/jty/anaconda3/envs/meco/bin/python3.8)" project-jdk-type="Python SDK" />
</project>
.idea/modules.xml (new file, generated, 8 lines)
@@ -0,0 +1,8 @@
<?xml version="1.0" encoding="UTF-8"?>
<project version="4">
  <component name="ProjectModuleManager">
    <modules>
      <module fileurl="file://$PROJECT_DIR$/.idea/MeCo.iml" filepath="$PROJECT_DIR$/.idea/MeCo.iml" />
    </modules>
  </component>
</project>
.idea/vcs.xml (new file, generated, 6 lines)
@@ -0,0 +1,6 @@
<?xml version="1.0" encoding="UTF-8"?>
<project version="4">
  <component name="VcsDirectoryMappings">
    <mapping directory="" vcs="Git" />
  </component>
</project>
LICENSE (new file, 21 lines)
@@ -0,0 +1,21 @@
MIT License

Copyright (c) 2021 zerocostptnas

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
Layers/layers.py (new file, 175 lines)
@@ -0,0 +1,175 @@
import math
import torch
import torch.nn as nn
from torch.nn import functional as F
from torch.nn import init
from torch.nn.parameter import Parameter
from torch.nn.modules.utils import _pair


class Linear(nn.Linear):
    def __init__(self, in_features, out_features, bias=True):
        super(Linear, self).__init__(in_features, out_features, bias)
        self.register_buffer('weight_mask', torch.ones(self.weight.shape))
        self.register_buffer('score', torch.zeros(self.weight.shape))
        if self.bias is not None:
            self.register_buffer('bias_mask', torch.ones(self.bias.shape))

    def forward(self, input):
        W = self.weight_mask * self.weight
        if self.bias is not None:
            b = self.bias_mask * self.bias
        else:
            b = self.bias
        return F.linear(input, W, b)


class Conv2d(nn.Conv2d):
    def __init__(self, in_channels, out_channels, kernel_size, stride=1,
                 padding=0, dilation=1, groups=1,
                 bias=True, padding_mode='zeros'):
        super(Conv2d, self).__init__(
            in_channels, out_channels, kernel_size, stride, padding,
            dilation, groups, bias, padding_mode)
        self.register_buffer('weight_mask', torch.ones(self.weight.shape))
        self.register_buffer('score', torch.zeros(self.weight.shape))
        if self.bias is not None:
            self.register_buffer('bias_mask', torch.ones(self.bias.shape))

    def _conv_forward(self, input, weight, bias):
        if self.padding_mode != 'zeros':
            return F.conv2d(F.pad(input, self._padding_repeated_twice, mode=self.padding_mode),
                            weight, bias, self.stride,
                            _pair(0), self.dilation, self.groups)
        return F.conv2d(input, weight, bias, self.stride,
                        self.padding, self.dilation, self.groups)

    def forward(self, input):
        W = self.weight_mask * self.weight
        if self.bias is not None:
            b = self.bias_mask * self.bias
        else:
            b = self.bias
        return self._conv_forward(input, W, b)


class BatchNorm1d(nn.BatchNorm1d):
    def __init__(self, num_features, eps=1e-5, momentum=0.1, affine=True,
                 track_running_stats=True):
        super(BatchNorm1d, self).__init__(
            num_features, eps, momentum, affine, track_running_stats)
        if self.affine:
            self.register_buffer('weight_mask', torch.ones(self.weight.shape))
            self.register_buffer('bias_mask', torch.ones(self.bias.shape))
            self.register_buffer('score', torch.zeros(self.weight.shape))

    def forward(self, input):
        self._check_input_dim(input)

        # exponential_average_factor is set to self.momentum
        # (when it is available) only so that it gets updated
        # in ONNX graph when this node is exported to ONNX.
        if self.momentum is None:
            exponential_average_factor = 0.0
        else:
            exponential_average_factor = self.momentum

        if self.training and self.track_running_stats:
            # TODO: if statement only here to tell the jit to skip emitting this when it is None
            if self.num_batches_tracked is not None:
                self.num_batches_tracked = self.num_batches_tracked + 1
                if self.momentum is None:  # use cumulative moving average
                    exponential_average_factor = 1.0 / float(self.num_batches_tracked)
                else:  # use exponential moving average
                    exponential_average_factor = self.momentum
        if self.affine:
            W = self.weight_mask * self.weight
            b = self.bias_mask * self.bias
        else:
            W = self.weight
            b = self.bias

        return F.batch_norm(
            input, self.running_mean, self.running_var, W, b,
            self.training or not self.track_running_stats,
            exponential_average_factor, self.eps)


class BatchNorm2d(nn.BatchNorm2d):
    def __init__(self, num_features, eps=1e-5, momentum=0.1, affine=True,
                 track_running_stats=True):
        super(BatchNorm2d, self).__init__(
            num_features, eps, momentum, affine, track_running_stats)
        if self.affine:
            self.register_buffer('weight_mask', torch.ones(self.weight.shape))
            self.register_buffer('bias_mask', torch.ones(self.bias.shape))
            self.register_buffer('score', torch.zeros(self.weight.shape))

    def forward(self, input):
        self._check_input_dim(input)

        # exponential_average_factor is set to self.momentum
        # (when it is available) only so that it gets updated
        # in ONNX graph when this node is exported to ONNX.
        if self.momentum is None:
            exponential_average_factor = 0.0
        else:
            exponential_average_factor = self.momentum

        if self.training and self.track_running_stats:
            # TODO: if statement only here to tell the jit to skip emitting this when it is None
            if self.num_batches_tracked is not None:
                self.num_batches_tracked = self.num_batches_tracked + 1
                if self.momentum is None:  # use cumulative moving average
                    exponential_average_factor = 1.0 / float(self.num_batches_tracked)
                else:  # use exponential moving average
                    exponential_average_factor = self.momentum
        if self.affine:
            W = self.weight_mask * self.weight
            b = self.bias_mask * self.bias
        else:
            W = self.weight
            b = self.bias

        return F.batch_norm(
            input, self.running_mean, self.running_var, W, b,
            self.training or not self.track_running_stats,
            exponential_average_factor, self.eps)


class Identity1d(nn.Module):
    def __init__(self, num_features):
        super(Identity1d, self).__init__()
        self.num_features = num_features
        self.weight = Parameter(torch.Tensor(num_features))
        self.bias = None
        self.register_buffer('weight_mask', torch.ones(self.weight.shape))
        self.reset_parameters()
        self.register_buffer('score', torch.zeros(self.weight.shape))

    def reset_parameters(self):
        init.ones_(self.weight)

    def forward(self, input):
        W = self.weight_mask * self.weight
        return input * W


class Identity2d(nn.Module):
    def __init__(self, num_features):
        super(Identity2d, self).__init__()
        self.num_features = num_features
        self.weight = Parameter(torch.Tensor(num_features, 1, 1))
        self.bias = None
        self.register_buffer('weight_mask', torch.ones(self.weight.shape))
        self.reset_parameters()
        self.register_buffer('score', torch.zeros(self.weight.shape))

    def reset_parameters(self):
        init.ones_(self.weight)

    def forward(self, input):
        W = self.weight_mask * self.weight
        return input * W
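The masked layers above all follow the same pattern: a `weight_mask` (and, where applicable, `bias_mask`) buffer is registered next to the parameters and multiplied in on every forward pass, so operations can be pruned or scored by editing the mask without touching the weights themselves. A minimal usage sketch (the `Layers.layers` import path is assumed from the repository layout):

```
import torch
from Layers.layers import Conv2d  # assumed import path

conv = Conv2d(in_channels=3, out_channels=8, kernel_size=3, padding=1)

# Prune the last four output filters by zeroing their mask entries;
# the underlying weights are left untouched.
with torch.no_grad():
    conv.weight_mask[4:] = 0.0

y = conv(torch.randn(2, 3, 32, 32))
print(y.shape)                      # torch.Size([2, 8, 32, 32])
print(int(conv.weight_mask.sum()))  # 4 * 3 * 3 * 3 = 108 weights still active
```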
README.md (new file, 143 lines)
@@ -0,0 +1,143 @@
# Zero-Cost Operation Scoring in Differentiable Architecture Search (Zero-Cost-PT)

Official implementation of our AAAI 2023 submission:
"**Zero-Cost Operation Scoring in Differentiable Architecture Search**".

## Installation
```
Python >= 3.6
PyTorch >= 1.7.1
torchvision == 0.8.2
tensorboard == 2.4.1
scipy == 1.5.2
gpustat
```

## Usage/Examples

### Experiments on NAS-Bench-201
Scripts for reproducing our experiments can be found under the ```exp_scripts/``` folder.

#### 1. Prepare NAS-Bench-201 Data
1. Download the NAS-Bench-201 checkpoint from [NAS-Bench-201-v1_0-e61699.pth](https://drive.google.com/file/d/1SKW0Cu0u8-gb18zDpaAGi0f74UdXeGKs/view) and place it under the ```./data``` folder.

#### 2. Prepare the NAS-Bench-201 and Zero-Cost-NAS APIs
i. Install the NAS-Bench-201 API via `pip`:
```
pip install nas-bench-201
```
ii. Install the Zero-Cost-NAS API

Clone the code repository from [Zero-Cost-NAS](https://github.com/SamsungLabs/zero-cost-nas), go to the root directory of the cloned repo, and run:
```
pip install .
```

#### 3. Run Zero-Cost-PT on NAS-Bench-201

You can run Zero-Cost-PT with the following script:
```
bash zerocostpt_nb201_pipeline.sh --seed [SEED]
```
You can specify random seeds with ```--seed``` for reproducibility. In our experiments we use random seeds 0, 1, 2 and 3.

You can also run with different zero-cost proxies by specifying ```--metrics```, and a different edge discretization order with ```--edge_decision```. The number of search iterations (N in our paper) is controlled by ```--pool_size```, while the number of validation iterations (V in our paper) is set by ```--validate_rounds```. Please see Section 4.2 in our paper for more information on these parameters.

For example, a typical experiment setting could be:

```--pool_size 10 --edge_decision random --validate_rounds 100 --metrics jacob --seed 0```

### Experiments on NAS-Bench-1shot1
Scripts for reproducing our experiments on NAS-Bench-1shot1 can be found under the ```nasbench1shot1``` folder.

#### 1. Prepare NAS-Bench-101 Data
1. Download the NAS-Bench-101 dataset from [nasbench_full.tfrecord](https://storage.googleapis.com/nasbench/nasbench_full.tfrecord) and place it under the ```./data``` folder.

#### 2. Prepare the NAS-Bench-101 API
Please refer to the original [NAS-Bench-101](https://github.com/google-research/nasbench) repository for details of the API installation.

#### 3. Run Zero-Cost-PT on NAS-Bench-1shot1

You can reproduce our Zero-Cost-PT results as follows:
```
cd nasbench1shot1/optimizers/darts/
```
i. Run the following script to search for architectures with Zero-Cost-PT in a given sub-search-space of NAS-Bench-1shot1:
```
python network_proposal.py --seed [SEED] --search_space [SEARCH_SPACE]
```
NAS-Bench-1shot1 contains 3 sub-search-spaces; select one with [SEARCH_SPACE] from [1, 2, 3].

ii. Evaluate the final searched model:
```
python post_validate.py --seed [SEED] --search_exp_path [PATH_to_LAST_STEP_LOG_FOLDER]
```

### Experiments on NAS-Bench-Macro
Scripts for reproducing our experiments on NAS-Bench-Macro can be found under the ```nasbenchmacro``` folder.

#### 1. Prepare NAS-Bench-Macro Data
1. Download the NAS-Bench-Macro dataset from [nas-bench-macro_cifar10.json](https://github.com/xiusu/NAS-Bench-Macro/tree/master/data/nas-bench-macro_cifar10.json) and place it under the ```./data``` folder.

#### 2. Run Zero-Cost-PT on NAS-Bench-Macro

You can reproduce our Zero-Cost-PT results as follows:
```
cd nasbenchmacro/
```
i. Run the following script to search for architectures with Zero-Cost-PT on NAS-Bench-Macro:
```
python network_proposal.py --seed [SEED]
```

### Experiments on DARTS-like Spaces
Scripts for reproducing our experiments can be found under the ```exp_scripts/``` folder; the Zero-Cost-NAS API is also needed.

#### 1. For the DARTS CNN space

Run the following script to search architectures with Zero-Cost-PT and train the searched architecture directly (with the same random seed):
```
bash zerocostpt_darts_pipeline.sh --seed [SEED]
```
Our default parameter settings are:

```--pool_size 10 --edge_decision random --validate_rounds 100 --metrics jacob```

#### 2. For DARTS subspaces S1-S4

On CIFAR-10, use the following script:

```
bash zerocostpt_darts_pipeline.sh --seed [SEED] --space [s1-s4]
```

On CIFAR-100 and SVHN, use the following scripts:

```
bash zerocostpt_darts_pipeline_svhn.sh --seed [SEED] --space [s1-s4]
bash zerocostpt_darts_pipeline_c100.sh --seed [SEED] --space [s1-s4]
```

#### 3. Directly train the searched architectures reported in our paper

For reproducibility we also provide training scripts to evaluate all of the architectures reported in our paper. For an architecture specified by ```[genotype_name]```, run the following scripts to train it:

```
bash eval.sh --arch [genotype_name]       # for DARTS C10
bash eval-c100.sh --arch [genotype_name]  # for DARTS C100
bash eval-svhn.sh --arch [genotype_name]  # for DARTS SVHN
```

The model genotypes are provided in ```sota/cnn/genotypes.py```. For instance, genotype `init_pt_s5_C10_0_100_N10` specifies the architecture searched by Zero-Cost-PT (with the default settings explained above) on the DARTS CNN space (S5), using 10 search iterations (N=10), 100 validation iterations (V=100), and random seed 0.

### Other Experiments Reported in the Appendix
We also provide code to reproduce the experimental results reported in the appendix, e.g. genotypes for the maximum-param and random-sampling baselines, and Zero-Cost-PT for MobileNet-like spaces.

## Reference
Our code (Zero-Cost-PT) is based on [darts-pt](https://github.com/ruocwang/darts-pt) and [Zero-Cost-NAS](https://github.com/SamsungLabs/zero-cost-nas). Our experiments on NAS-Bench-1shot1 build on [nasbench-1shot1](https://github.com/automl/nasbench-1shot1).
Scorers/scorer.py (new file, 47 lines)
@@ -0,0 +1,47 @@
import torch
import numpy as np


class Jocab_Scorer:
    def __init__(self, gpu):
        self.gpu = gpu
        print('Jacob score init')

    def score(self, model, input, target):
        batch_size = input.shape[0]
        model.K = torch.zeros(batch_size, batch_size).cuda()

        input = input.cuda()
        with torch.no_grad():
            model(input)
        score = self.hooklogdet(model.K.cpu().numpy())

        # print(score)
        return score

    def setup_hooks(self, model, batch_size):
        # initialize score
        model = model.to(torch.device('cuda', self.gpu))
        model.eval()
        model.K = torch.zeros(batch_size, batch_size).cuda()

        def counting_forward_hook(module, inp, out):
            try:
                # if not module.visited_backwards:
                #     return
                if isinstance(inp, tuple):
                    inp = inp[0]
                inp = inp.view(inp.size(0), -1)
                x = (inp > 0).float()
                K = x @ x.t()
                K2 = (1.-x) @ (1.-x.t())
                model.K = model.K + K + K2
            except:
                pass

        for name, module in model.named_modules():
            if 'ReLU' in str(type(module)):
                module.register_forward_hook(counting_forward_hook)
                # module.register_backward_hook(counting_backward_hook)

    def hooklogdet(self, K, labels=None):
        s, ld = np.linalg.slogdet(K)
        return ld
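The scorer implements the NASWOT-style "jacob" proxy: each ReLU hook turns its input into a binary activation pattern, accumulates the pairwise agreement kernel K over the mini-batch, and the score is the log-determinant of K. A rough driving sketch (requires a CUDA device, since the class moves the model and inputs to the GPU as written; the import path is assumed from the repository layout):

```
import torch
import torch.nn as nn
from Scorers.scorer import Jocab_Scorer  # assumed import path

model = nn.Sequential(
    nn.Conv2d(3, 16, 3, padding=1), nn.ReLU(),
    nn.AdaptiveAvgPool2d(1), nn.Flatten(), nn.Linear(16, 10),
)

scorer = Jocab_Scorer(gpu=0)
batch = torch.randn(32, 3, 32, 32)
labels = torch.randint(0, 10, (32,))

scorer.setup_hooks(model, batch_size=32)    # registers the ReLU forward hooks
score = scorer.score(model, batch, labels)  # log-determinant of the 32x32 kernel
print(score)
```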
exp_scripts/eval-c100.sh (new file, 39 lines)
@@ -0,0 +1,39 @@
#!/bin/bash
script_name=`basename "$0"`
id=${script_name%.*}
dataset=${dataset:-cifar100}
seed=${seed:-2}
gpu=${gpu:-"auto"}
arch=${arch:-"none"}
batch_size=${batch_size:-96}
learning_rate=${learning_rate:-0.025}
resume_expid=${resume_expid:-'none'}
resume_epoch=${resume_epoch:-0}

while [ $# -gt 0 ]; do
    if [[ $1 == *"--"* ]]; then
        param="${1/--/}"
        declare $param="$2"
        # echo $1 $2 // Optional to see the parameter:value result
    fi
    shift
done

echo 'id:' $id
echo 'seed:' $seed
echo 'dataset:' $dataset
echo 'gpu:' $gpu
echo 'arch:' $arch
echo 'batch_size:' $batch_size
echo 'learning_rate:' $learning_rate


cd ../sota/cnn
python train.py \
    --arch $arch \
    --dataset $dataset \
    --auxiliary --cutout \
    --seed $seed --save $id --gpu $gpu \
    --batch_size $batch_size --learning_rate $learning_rate \
    --resume_expid $resume_expid --resume_epoch $resume_epoch \
    --init_channels 16 --layers 8 \
exp_scripts/eval.sh (new file, 38 lines)
@@ -0,0 +1,38 @@
#!/bin/bash
script_name=`basename "$0"`
id=${script_name%.*}
dataset=${dataset:-cifar10}
seed=${seed:-2}
gpu=${gpu:-"auto"}
arch=${arch:-"none"}
batch_size=${batch_size:-96}
learning_rate=${learning_rate:-0.025}
resume_expid=${resume_expid:-'none'}
resume_epoch=${resume_epoch:-0}

while [ $# -gt 0 ]; do
    if [[ $1 == *"--"* ]]; then
        param="${1/--/}"
        declare $param="$2"
        # echo $1 $2 // Optional to see the parameter:value result
    fi
    shift
done

echo 'id:' $id
echo 'seed:' $seed
echo 'dataset:' $dataset
echo 'gpu:' $gpu
echo 'arch:' $arch
echo 'batch_size:' $batch_size
echo 'learning_rate:' $learning_rate


cd ../sota/cnn
python3 train.py \
    --arch $arch \
    --dataset $dataset \
    --auxiliary --cutout \
    --seed $seed --save $id --gpu $gpu \
    --batch_size $batch_size --learning_rate $learning_rate \
    --resume_expid $resume_expid --resume_epoch $resume_epoch \
exp_scripts/zerocostpt_darts_pipeline.sh (new file, 51 lines)
@@ -0,0 +1,51 @@
#!/bin/bash
script_name=`basename "$0"`
id=${script_name%.*}
dataset=${dataset:-cifar10}
seed=${seed:-1}
gpu=${gpu:-"0"}
pool_size=${pool_size:-10}
space=${space:-s5}
metric=${metric:-'jacob'}
edge_decision=${edge_decision:-'random'}
validate_rounds=${validate_rounds:-100}
learning_rate=${learning_rate:-0.025}
while [ $# -gt 0 ]; do
    if [[ $1 == *"--"* ]]; then
        param="${1/--/}"
        declare $param="$2"
        # echo $1 $2 // Optional to see the parameter:value result
    fi
    shift
done

echo 'id:' $id 'seed:' $seed 'dataset:' $dataset 'space:' $space
echo 'proj crit:' $metric
echo 'gpu:' $gpu

cd ../sota/cnn
python3 networks_proposal.py \
    --search_space $space --dataset $dataset \
    --batch_size 64 \
    --seed $seed --save $id --gpu $gpu \
    --edge_decision $edge_decision \
    --proj_crit_normal $metric --proj_crit_reduce $metric --proj_crit_edge $metric \
    --pool_size $pool_size\

cd ../../zerocostnas/
python3 post_validate.py\
    --ckpt_path ../experiments/sota/$dataset-search-$id-$space-$seed-$pool_size-$metric\
    --save $id --seed $seed --gpu $gpu\
    --batch_size 64\
    --edge_decision $edge_decision --proj_crit $metric \
    --validate_rounds $validate_rounds\

cd ../sota/cnn
python3 train.py \
    --seed $seed --gpu $gpu --save $id \
    --arch ../../experiments/sota/$space-valid-$id-$seed-$pool_size-$validate_rounds-$metric\
    --dataset $dataset \
    --auxiliary --cutout \
    --batch_size 96 --learning_rate $learning_rate \
    --init_channels 36 --layers 20\
    --from_dir\
exp_scripts/zerocostpt_darts_pipeline_c100.sh (new file, 50 lines)
@@ -0,0 +1,50 @@
#!/bin/bash
script_name=`basename "$0"`
id=${script_name%.*}
dataset=${dataset:-cifar100}
seed=${seed:-2}
gpu=${gpu:-"auto"}
pool_size=${pool_size:-100}
space=${space:-s5}
metric=${metric:-'jacob'}
edge_decision=${edge_decision:-'random'}
validate_rounds=${validate_rounds:-100}
learning_rate=${learning_rate:-0.025}
while [ $# -gt 0 ]; do
    if [[ $1 == *"--"* ]]; then
        param="${1/--/}"
        declare $param="$2"
        # echo $1 $2 // Optional to see the parameter:value result
    fi
    shift
done

echo 'id:' $id 'seed:' $seed 'dataset:' $dataset 'space:' $space
echo 'proj crit:' $metric
echo 'gpu:' $gpu

cd ../sota/cnn
python3 networks_proposal.py \
    --search_space $space --dataset $dataset --batch_size 64 \
    --seed $seed --save $id --gpu $gpu \
    --edge_decision $edge_decision \
    --proj_crit_normal $metric --proj_crit_reduce $metric --proj_crit_edge $metric \
    --pool_size $pool_size\

cd ../zerocostnas/
python3 post_validate.py\
    --ckpt_path ../experiments/sota/$dataset-search-$id-$space-$seed-$pool_size-$metric\
    --save $id --seed $seed --gpu $gpu\
    --edge_decision $edge_decision --proj_crit $metric \
    --batch_size 64\
    --validate_rounds $validate_rounds\

cd ../sota/cnn
python3 train.py \
    --seed $seed --gpu $gpu --save $id\
    --arch ../../experiments/sota/$space-valid-$id-$seed-$pool_size-$validate_rounds-$dataset-$metric\
    --dataset $dataset \
    --auxiliary --cutout \
    --batch_size 96 --learning_rate $learning_rate \
    --init_channels 16 --layers 20 \
    --from_dir\
exp_scripts/zerocostpt_darts_pipeline_imagenet.sh (new file, 44 lines)
@@ -0,0 +1,44 @@
#!/bin/bash
script_name=`basename "$0"`
id=${script_name%.*}
dataset=${dataset:-imagenet}
seed=${seed:-2}
gpu=${gpu:-"auto"}
pool_size=${pool_size:-10}
space=${space:-s5}
metric=${metric:-'jacob'}
edge_decision=${edge_decision:-'random'}
validate_rounds=${validate_rounds:-100}
learning_rate=${learning_rate:-0.025}
data=${data:-''}
while [ $# -gt 0 ]; do
    if [[ $1 == *"--"* ]]; then
        param="${1/--/}"
        declare $param="$2"
        # echo $1 $2 // Optional to see the parameter:value result
    fi
    shift
done

echo 'id:' $id 'seed:' $seed 'dataset:' $dataset 'space:' $space
echo 'proj crit:' $metric
echo 'gpu:' $gpu

cd ../sota/cnn
python3 networks_proposal.py \
    --search_space $space --dataset $dataset \
    --seed $seed --save $id --gpu $gpu \
    --edge_decision $edge_decision \
    --proj_crit_normal $metric --proj_crit_reduce $metric --proj_crit_edge $metric \
    --pool_size $pool_size\
    --data $data\

cd ../../zerocostnas/
python3 post_validate.py\
    --ckpt_path ../experiments/sota/$dataset-search-$id-$space-$seed-$pool_size\
    --save $id --seed $seed --gpu $gpu\
    --edge_decision $edge_decision --proj_crit $metric \
    --batch_size 64\
    --validate_rounds $validate_rounds\
    --data $data\
notebooks_201/.ipynb_checkpoints/op_strength-checkpoint.ipynb (new file, 1599 lines; diff suppressed: lines too long)
notebooks_201/.ipynb_checkpoints/parse_log-checkpoint.ipynb (new file, 15614 lines; diff suppressed: lines too long)
notebooks_201/N1000T1_zero_cost.ipynb (new file, 122733 lines; diff suppressed: too large)
notebooks_201/metric_correlation.pdf (new binary file, not shown)
notebooks_201/metric_correlation_vert.pdf (new binary file, not shown)
notebooks_201/nb201-avg-accuracy.ipynb (new file, 40707 lines; diff suppressed: too large)
notebooks_201/nb2_test_acc_cf10.p (new binary file, not shown)
notebooks_201/op_correl_time.pdf (new binary file, not shown)
notebooks_201/op_correl_time_nwot.pdf (new binary file, not shown)
notebooks_201/op_correl_time_synflow.pdf (new binary file, not shown)
notebooks_201/op_strength.ipynb (new file, 1612 lines; diff suppressed: lines too long)
notebooks_201/op_strength_src.csv (new file, 96 lines)
@@ -0,0 +1,96 @@
edge\op,none,skip_connect,nor_conv_1x1,nor_conv_3x3,avg_pool_3x3,
acc,77.36±22.55,83.81±10.64,86.38±9.1, 87.32±9.17,81.02±11.85,
discrete acc darts,83.42,84.1,72,76.35,39.66,
dartspt,85.43,17.02,78.13,59.09,85.34,
zc pt,3455.233646,3449.898772,3449.538363,3441.815563,3461.179476,3
discrete zc,3331.007285,3445.489455,3366.877065,3437.551079,3423.180255,
alpha,0.0758,0.7742,0.05,0.0761,0.024,
best-acc,94.15,94.18,94.44,94.68,93.86,
alpha-60,0.1387,0.4758,0.1296,0.181,0.0748,
tenaspt,38.5,48.0,31.0,6.0,37.5,
synflow,1.9286723850908796e+31,7.990282869734622e+30,1.2421187150331997e+30,9.438907569335487e+26,8.191504786187086e+30,
synflow_disc,4.639162000716631e+21, 1.4975281050055959e+26, 4.2221622054263117e+30, 1.9475517523688712e+36, 1.5075022033622535e+26,
best_nwot,1702.1967536035393,1773.1779654806287, 1793.8140278364453, 1792.8682630835763, 1761.1262357119376,
best_synflow,5.784248799475683e+39, 1.4769546208886953e+44, 6.658953754065702e+49, 5.1987025703231504e+39, 1.9928388494681343e+35,
zc-pt-post,3067.0476, 3055.9404, 3059.8901, 3060.4536, 3073.5583,
zc-disc-post,2942.267, 3068.6416, 3009.5847, 3028.1794, 3031.4248,
,,,,,,
acc,80.03 ±19.38,83.11 ±12.81,85.23 ±11.0,85.99±11.1,81.52±13.06,
discrete acc darts,85.12,83.39,76.72,81.34,84.38,
dartspt,85.52,36.1,84.39,80.95,85.49,
zc pt,3452.145851,3448.696318,3441.809174,3440.652631,3453.739943,3
discrete zc,3429.074707,3435.750274,3407.872847,3434.584512,3421.442414,
alpha,0.0629,0.813,0.039,0.0579,0.0269,
best-acc,94.24,94.16,94.49,94.68,94.09,
alpha-60,0.1236,0.5535,0.11,0.1249,0.088,
tenaspt,7.0,55.0,10.0,15.0,39.0,
synflow,3.116079880492518e+30,2.5018418732419554e+30,1.4274537256246266e+30,3.138287824323275e+29,2.5693894962958226e+30,
synflow_disc,5.615386425664938e+28, 2.340336657109326e+29, 1.9258305801684058e+30, 3.012759514473006e+32, 2.2897138361934977e+29,
best_nwot,1765.3743820515451, 1770.8436009213751, 1791.917305624048, 1793.8140278364453, 1763.877253730585,
best_synflow,1.9424580089849912e+49, 2.764587447411338e+49, 6.658953754065702e+49, 2.0353792445711388e+49, 1.4435653786128956e+49,
zc-pt-post,3067.0476, 3058.9197, 3048.8745, 3051.2664, 3066.668,
zc-disc-post,3020.0203, 3052.1936, 3026.2217, 3022.0935, 3029.2,
,,,,,,
acc,82.9±14.68,82.44 ±14.25,84.05 ±13.19,84.49 ±13.21,81.98 ±14.54,
discrete acc darts,85.96,85.18,54.02,78.41,84.88,
dartspt,85.51,80.29,81.86,77.68,85.32,
zc pt,3446.521007,3447.612434,3435.455206,3436.396744,3449.275466,2
discrete zc,3428.795464,3423.361285,3440.925616,3437.286935,3416.891544,
alpha,0.3339,0.4742,0.061,0.0774,0.0534,
best-acc,94.25,94.43,94.49,94.68,94.19,
alpha-60,0.2403,0.3297,0.1495,0.1748,0.1056,
tenaspt,31.5,10.0,30.0,16.5,36.5,
synflow,1.0312338471669537e+31,4.9191575008661263e+30,1.4241158958667068e+30,1.0282498082879338e+28,5.038622256524752e+30,
synflow_disc,1.6980829611704765e+25, 3.3199508659283994e+27, 3.3825056097270114e+30, 1.2059727722928161e+35, 3.279653417965715e+27,
best_nwot,1764.51075805859,1764.116749555202, 1793.8140278364453, 1792.8239766388833, 1764.1848313456592,
best_synflow,8.376122028137071e+41, 1.0615041036082487e+45, 6.658953754065702e+49, 8.399427750574918e+41, 2.5270360875229e+39,
zc-pt-post,3067.0476, 3068.708, 3056.3506, 3047.9695, 3071.3577,
zc-disc-post,3044.023, 3033.0627, 3032.825, 3052.0688, 3023.2302,
,,,,,,
acc,74.02 ±26.1,85.17 ±7.59,87.3 ±2.48,88.28 ±2.06,81.38 ±8.91,
discrete acc darts,66.18,85.38,78.8,81.59,82.89,
dartspt,85.49,9.86,81.79,59.18,85.48,
zc pt,3453.805194,3435.985406,3444.044047,3445.595326,3447.067855,1
discrete zc,3408.990502,3464.050741,3359.888463,3382.1755,3431.805571,
alpha,0.0267,0.8163,0.0471,0.0904,0.0194,
best-acc,94.16,94.68,94.03,94.04,93.85,
alpha-60,0.0636,0.6513,0.0826,0.1335,0.0691,
tenaspt,34.0,44.0,53.5,23.0,30.0,
synflow,2.0042808467776213e+30,1.9513599734483263e+30,1.5188352495143643e+30,7.704103863066581e+29,1.9536326167605112e+30,
synflow_disc,4.3050000047616484e+29, 7.635399455155384e+29, 1.5949429556375966e+30, 1.4519088590209463e+31, 7.345232988374157e+29,
best_nwot,1766.5481959337162,1769.1683503033412, 1793.8140278364453, 1792.8682630835763, 1765.1445530390838,
best_synflow,5.90523769961745e+49, 6.344766865099622e+49, 6.571181309028854e+49, 6.57509920946309e+49, 6.658953754065702e+49,
zc-pt-post,3067.0476, 3032.6658, 3058.9646, 3059.2861, 3047.1965,
zc-disc-post,2975.976, 3130.7397, 3008.5625, 3009.341, 3086.3398,
,,,,,,
acc,80.14 ±19.52,83.05 ±12.74,85.09 ±11.17,85.7 ±11.18,81.89 ±12.98,
discrete acc darts,86.44,84.75,80.23,80.46,80.13,
dartspt,85.45,51.15,78.84,64.64,85.14,
zc pt,3451.055723,3449.796894,3442.625354,3441.131751,3453.311493,3
discrete zc,3433.98773,3435.573458,3424.470031,3431.143217,3423.153213,
alpha,0.0857,0.7082,0.0716,0.0946,0.0399,
best-acc,94.29,94.18,94.56,94.68,94.23,
alpha-60,0.1183,0.48,0.1305,0.1732,0.0979,
tenaspt,32.0,32.5,36.5,32.0,52.0,
synflow,3.165975343348193e+30,2.4302742111297496e+30,1.4853908452542004e+30,2.868307126123347e+29,2.6891361283692336e+30,
synflow_disc,5.5202846896598e+28, 2.4896852024898197e+29, 2.1810394989246777e+30, 2.9482018739806336e+32, 2.2732178076450144e+29,
best_nwot, 1752.024924623228,1793.8140278364453, 1786.3402409418215, 1785.0294182838636, 1781.9741301640186,
best_synflow,1.8865959738805548e+49, 2.593134717306188e+49, 6.658953754065702e+49, 2.021273089103704e+49, 1.6187260144154453e+49,
zc-pt-post,3067.0476, 3060.9983, 3057.1006, 3054.3428, 3066.2087,
zc-disc-post,3037.8726, 3055.4219, 3027.6638, 3024.3271, 3037.8108,
,,,,,,
acc,77.61 ±22.16,83.43 ±11.34,86.18 ±9.08,86.95 ±9.02,81.74 ±11.79,
discrete acc darts,86.28,82.69,77.13,76.8,81.99,
dartspt,85.54,32.43,81.04,72.75,85.51,
zc pt,3450.967554,3448.211459,3440.79926,3443.240243,3452.989921,2
discrete zc,3434.421701,3437.661196,3418.572637,3397.51709,3424.166157,
alpha,0.1554,0.7029,0.0538,0.0598,0.028,
best-acc,94.05,94.16,94.68,94.56,94.1,
alpha-60,0.1648,0.4853,0.1223,0.1397,0.088,
tenaspt,38.5,16.0,20.0,17.0,27.5,
synflow,1.9460309216168614e+31,8.014786854561015e+30,1.1851807660289746e+30,8.96867143875011e+26,7.75842932776677e+30,
synflow_disc,4.777733726551756e+21, 1.4572459237815469e+26, 3.8590321292364994e+30, 1.8898449210848245e+36, 1.5222938895812008e+26,
best_nwot, 1761.8789642379636,1769.103803678444, 1793.8140278364453, 1792.8239766388833, 1761.9145207476113,
best_synflow,5.776473195639679e+39, 1.4672616553030765e+44, 6.658953754065702e+49, 5.480408193532999e+39, 1.9606567871518125e+35,
zc-pt-post,3067.0476, 3063.1135, 3058.818, 3064.5405, 3070.7593,
zc-disc-post,3061.6133, 3063.294, 3038.05, 3012.938, 3042.5535,
notebooks_201/parse_log.ipynb (new file, 15614 lines; diff suppressed: lines too long)
notebooks_201/parse_log_iterative.ipynb (new file, 3629 lines; diff suppressed: lines too long)
pycls/core/benchmark.py (new file, 136 lines)
@@ -0,0 +1,136 @@
#!/usr/bin/env python3

# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

"""Benchmarking functions."""

import pycls.core.logging as logging
import pycls.datasets.loader as loader
import torch
from pycls.core.config import cfg
from pycls.core.timer import Timer


logger = logging.get_logger(__name__)


@torch.no_grad()
def compute_time_eval(model):
    """Computes precise model forward test time using dummy data."""
    # Use eval mode
    model.eval()
    # Generate a dummy mini-batch and copy data to GPU
    im_size, batch_size = cfg.TRAIN.IM_SIZE, int(cfg.TEST.BATCH_SIZE / cfg.NUM_GPUS)
    if cfg.TASK == "jig":
        inputs = torch.rand(batch_size, cfg.JIGSAW_GRID ** 2, cfg.MODEL.INPUT_CHANNELS, im_size, im_size).cuda(non_blocking=False)
    else:
        inputs = torch.zeros(batch_size, cfg.MODEL.INPUT_CHANNELS, im_size, im_size).cuda(non_blocking=False)
    # Compute precise forward pass time
    timer = Timer()
    total_iter = cfg.PREC_TIME.NUM_ITER + cfg.PREC_TIME.WARMUP_ITER
    for cur_iter in range(total_iter):
        # Reset the timers after the warmup phase
        if cur_iter == cfg.PREC_TIME.WARMUP_ITER:
            timer.reset()
        # Forward
        timer.tic()
        model(inputs)
        torch.cuda.synchronize()
        timer.toc()
    return timer.average_time


def compute_time_train(model, loss_fun):
    """Computes precise model forward + backward time using dummy data."""
    # Use train mode
    model.train()
    # Generate a dummy mini-batch and copy data to GPU
    im_size, batch_size = cfg.TRAIN.IM_SIZE, int(cfg.TRAIN.BATCH_SIZE / cfg.NUM_GPUS)
    if cfg.TASK == "jig":
        inputs = torch.rand(batch_size, cfg.JIGSAW_GRID ** 2, cfg.MODEL.INPUT_CHANNELS, im_size, im_size).cuda(non_blocking=False)
    else:
        inputs = torch.rand(batch_size, cfg.MODEL.INPUT_CHANNELS, im_size, im_size).cuda(non_blocking=False)
    if cfg.TASK in ['col', 'seg']:
        labels = torch.zeros(batch_size, im_size, im_size, dtype=torch.int64).cuda(non_blocking=False)
    else:
        labels = torch.zeros(batch_size, dtype=torch.int64).cuda(non_blocking=False)
    # Cache BatchNorm2D running stats
    bns = [m for m in model.modules() if isinstance(m, torch.nn.BatchNorm2d)]
    bn_stats = [[bn.running_mean.clone(), bn.running_var.clone()] for bn in bns]
    # Compute precise forward backward pass time
    fw_timer, bw_timer = Timer(), Timer()
    total_iter = cfg.PREC_TIME.NUM_ITER + cfg.PREC_TIME.WARMUP_ITER
    for cur_iter in range(total_iter):
        # Reset the timers after the warmup phase
        if cur_iter == cfg.PREC_TIME.WARMUP_ITER:
            fw_timer.reset()
            bw_timer.reset()
        # Forward
        fw_timer.tic()
        preds = model(inputs)
        if isinstance(preds, tuple):
            loss = loss_fun(preds[0], labels) + cfg.NAS.AUX_WEIGHT * loss_fun(preds[1], labels)
            preds = preds[0]
        else:
            loss = loss_fun(preds, labels)
        torch.cuda.synchronize()
        fw_timer.toc()
        # Backward
        bw_timer.tic()
        loss.backward()
        torch.cuda.synchronize()
        bw_timer.toc()
    # Restore BatchNorm2D running stats
    for bn, (mean, var) in zip(bns, bn_stats):
        bn.running_mean, bn.running_var = mean, var
    return fw_timer.average_time, bw_timer.average_time


def compute_time_loader(data_loader):
    """Computes loader time."""
    timer = Timer()
    loader.shuffle(data_loader, 0)
    data_loader_iterator = iter(data_loader)
    total_iter = cfg.PREC_TIME.NUM_ITER + cfg.PREC_TIME.WARMUP_ITER
    total_iter = min(total_iter, len(data_loader))
    for cur_iter in range(total_iter):
        if cur_iter == cfg.PREC_TIME.WARMUP_ITER:
            timer.reset()
        timer.tic()
        next(data_loader_iterator)
        timer.toc()
    return timer.average_time


def compute_time_full(model, loss_fun, train_loader, test_loader):
    """Times model and data loader."""
    logger.info("Computing model and loader timings...")
    # Compute timings
    test_fw_time = compute_time_eval(model)
    train_fw_time, train_bw_time = compute_time_train(model, loss_fun)
    train_fw_bw_time = train_fw_time + train_bw_time
    train_loader_time = compute_time_loader(train_loader)
    # Output iter timing
    iter_times = {
        "test_fw_time": test_fw_time,
        "train_fw_time": train_fw_time,
        "train_bw_time": train_bw_time,
        "train_fw_bw_time": train_fw_bw_time,
        "train_loader_time": train_loader_time,
    }
    logger.info(logging.dump_log_data(iter_times, "iter_times"))
    # Output epoch timing
    epoch_times = {
        "test_fw_time": test_fw_time * len(test_loader),
        "train_fw_time": train_fw_time * len(train_loader),
        "train_bw_time": train_bw_time * len(train_loader),
        "train_fw_bw_time": train_fw_bw_time * len(train_loader),
        "train_loader_time": train_loader_time * len(train_loader),
    }
    logger.info(logging.dump_log_data(epoch_times, "epoch_times"))
    # Compute data loader overhead (assuming DATA_LOADER.NUM_WORKERS>1)
    overhead = max(0, train_loader_time - train_fw_bw_time) / train_fw_bw_time
    logger.info("Overhead of data loader is {:.2f}%".format(overhead * 100))
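Both timing functions follow the same measurement discipline: run a few warmup iterations, reset the timer, then time each pass with an explicit torch.cuda.synchronize() so asynchronous CUDA kernels have actually finished before the clock stops. A stripped-down version of that pattern, independent of the pycls config and Timer classes (a sketch, not part of the repository):

```
import time
import torch
import torch.nn as nn

def average_forward_time(model, inputs, warmup=3, iters=10):
    """Warm up, then return the average per-iteration forward time in seconds."""
    model.eval()
    times = []
    with torch.no_grad():
        for i in range(warmup + iters):
            start = time.time()
            model(inputs)
            if inputs.is_cuda:
                torch.cuda.synchronize()  # wait for queued kernels before stopping the clock
            if i >= warmup:               # discard warmup iterations
                times.append(time.time() - start)
    return sum(times) / len(times)

net = nn.Sequential(nn.Conv2d(3, 8, 3, padding=1), nn.ReLU(), nn.Flatten(), nn.Linear(8 * 32 * 32, 10))
print(average_forward_time(net, torch.zeros(4, 3, 32, 32)))
```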
pycls/core/builders.py (new file, 88 lines)
@@ -0,0 +1,88 @@
#!/usr/bin/env python3

# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

"""Model and loss construction functions."""

import torch
from pycls.core.config import cfg
from pycls.models.anynet import AnyNet
from pycls.models.effnet import EffNet
from pycls.models.regnet import RegNet
from pycls.models.resnet import ResNet
from pycls.models.nas.nas import NAS
from pycls.models.nas.nas_search import NAS_Search
from pycls.models.nas_bench.model_builder import NAS_Bench


class LabelSmoothedCrossEntropyLoss(torch.nn.Module):
    """CrossEntropyLoss with label smoothing."""
    def __init__(self):
        super(LabelSmoothedCrossEntropyLoss, self).__init__()
        self.eps = cfg.MODEL.LABEL_SMOOTHING_EPS
        self.num_classes = cfg.MODEL.NUM_CLASSES

    def forward(self, logits, target):
        pred = logits.log_softmax(dim=-1)
        with torch.no_grad():
            target_dist = torch.ones_like(pred) * self.eps / (self.num_classes - 1)
            target_dist.scatter_(-1, target.unsqueeze(-1), 1 - self.eps)
        return (-target_dist * pred).sum(dim=-1).mean()


# Supported models
_models = {
    "anynet": AnyNet,
    "effnet": EffNet,
    "resnet": ResNet,
    "regnet": RegNet,
    "nas": NAS,
    "nas_search": NAS_Search,
    "nas_bench": NAS_Bench,
}

# Supported loss functions
_loss_funs = {
    "cross_entropy": torch.nn.CrossEntropyLoss,
    "label_smoothed_cross_entropy": LabelSmoothedCrossEntropyLoss,
}


def get_model():
    """Gets the model class specified in the config."""
    err_str = "Model type '{}' not supported"
    assert cfg.MODEL.TYPE in _models.keys(), err_str.format(cfg.MODEL.TYPE)
    return _models[cfg.MODEL.TYPE]


def get_loss_fun():
    """Gets the loss function class specified in the config."""
    err_str = "Loss function type '{}' not supported"
    assert cfg.MODEL.LOSS_FUN in _loss_funs.keys(), err_str.format(cfg.TRAIN.LOSS)
    return _loss_funs[cfg.MODEL.LOSS_FUN]


def build_model():
    """Builds the model."""
    return get_model()()


def build_loss_fun():
    """Build the loss function."""
    if cfg.TASK == "seg":
        return get_loss_fun()(ignore_index=255)
    else:
        return get_loss_fun()()


def register_model(name, ctor):
    """Registers a model dynamically."""
    _models[name] = ctor


def register_loss_fun(name, ctor):
    """Registers a loss function dynamically."""
    _loss_funs[name] = ctor
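LabelSmoothedCrossEntropyLoss spreads eps of the probability mass uniformly over the non-target classes and keeps 1 - eps on the target, then takes the cross entropy against that soft distribution. A self-contained check of that construction (eps and the class count here are arbitrary illustration values, not values from the config):

```
import torch

eps, num_classes = 0.1, 10
logits = torch.randn(4, num_classes)
target = torch.tensor([0, 3, 7, 9])

pred = logits.log_softmax(dim=-1)
target_dist = torch.ones_like(pred) * eps / (num_classes - 1)
target_dist.scatter_(-1, target.unsqueeze(-1), 1 - eps)

loss = (-target_dist * pred).sum(dim=-1).mean()
print(target_dist.sum(dim=-1))  # every row sums to 1.0
print(loss)                     # reduces to plain cross entropy when eps = 0
```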
pycls/core/checkpoint.py (new file, 98 lines)
@@ -0,0 +1,98 @@
#!/usr/bin/env python3

# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

"""Functions that handle saving and loading of checkpoints."""

import os

import pycls.core.distributed as dist
import torch
from pycls.core.config import cfg


# Common prefix for checkpoint file names
_NAME_PREFIX = "model_epoch_"
# Checkpoints directory name
_DIR_NAME = "checkpoints"


def get_checkpoint_dir():
    """Retrieves the location for storing checkpoints."""
    return os.path.join(cfg.OUT_DIR, _DIR_NAME)


def get_checkpoint(epoch):
    """Retrieves the path to a checkpoint file."""
    name = "{}{:04d}.pyth".format(_NAME_PREFIX, epoch)
    return os.path.join(get_checkpoint_dir(), name)


def get_last_checkpoint():
    """Retrieves the most recent checkpoint (highest epoch number)."""
    checkpoint_dir = get_checkpoint_dir()
    # Checkpoint file names are in lexicographic order
    checkpoints = [f for f in os.listdir(checkpoint_dir) if _NAME_PREFIX in f]
    last_checkpoint_name = sorted(checkpoints)[-1]
    return os.path.join(checkpoint_dir, last_checkpoint_name)


def has_checkpoint():
    """Determines if there are checkpoints available."""
    checkpoint_dir = get_checkpoint_dir()
    if not os.path.exists(checkpoint_dir):
        return False
    return any(_NAME_PREFIX in f for f in os.listdir(checkpoint_dir))


def save_checkpoint(model, optimizer, epoch):
    """Saves a checkpoint."""
    # Save checkpoints only from the master process
    if not dist.is_master_proc():
        return
    # Ensure that the checkpoint dir exists
    os.makedirs(get_checkpoint_dir(), exist_ok=True)
    # Omit the DDP wrapper in the multi-gpu setting
    sd = model.module.state_dict() if cfg.NUM_GPUS > 1 else model.state_dict()
    # Record the state
    if isinstance(optimizer, list):
        checkpoint = {
            "epoch": epoch,
            "model_state": sd,
            "optimizer_w_state": optimizer[0].state_dict(),
            "optimizer_a_state": optimizer[1].state_dict(),
            "cfg": cfg.dump(),
        }
    else:
        checkpoint = {
            "epoch": epoch,
            "model_state": sd,
            "optimizer_state": optimizer.state_dict(),
            "cfg": cfg.dump(),
        }
    # Write the checkpoint
    checkpoint_file = get_checkpoint(epoch + 1)
    torch.save(checkpoint, checkpoint_file)
    return checkpoint_file


def load_checkpoint(checkpoint_file, model, optimizer=None):
    """Loads the checkpoint from the given file."""
    err_str = "Checkpoint '{}' not found"
    assert os.path.exists(checkpoint_file), err_str.format(checkpoint_file)
    # Load the checkpoint on CPU to avoid GPU mem spike
    checkpoint = torch.load(checkpoint_file, map_location="cpu")
    # Account for the DDP wrapper in the multi-gpu setting
    ms = model.module if cfg.NUM_GPUS > 1 else model
    ms.load_state_dict(checkpoint["model_state"])
    # Load the optimizer state (commonly not done when fine-tuning)
    if optimizer:
        if isinstance(optimizer, list):
            optimizer[0].load_state_dict(checkpoint["optimizer_w_state"])
            optimizer[1].load_state_dict(checkpoint["optimizer_a_state"])
        else:
            optimizer.load_state_dict(checkpoint["optimizer_state"])
    return checkpoint["epoch"]
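Checkpoints are written as zero-padded model_epoch_XXXX.pyth files under the configured output directory's checkpoints folder, which is why get_last_checkpoint() can recover the newest one with a plain lexicographic sort. A small standalone illustration of that naming scheme (a sketch, not repository code):

```
_NAME_PREFIX = "model_epoch_"

def checkpoint_name(epoch):
    # mirrors get_checkpoint(): zero-padding keeps lexicographic order == epoch order
    return "{}{:04d}.pyth".format(_NAME_PREFIX, epoch)

names = [checkpoint_name(e) for e in (1, 12, 7)]
print(sorted(names)[-1])  # model_epoch_0012.pyth -> the most recent checkpoint
```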
pycls/core/config.py (new file, 500 lines; diff shown only in part)
@@ -0,0 +1,500 @@
#!/usr/bin/env python3

# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

"""Configuration file (powered by YACS)."""

import argparse
import os
import sys

from pycls.core.io import cache_url
from yacs.config import CfgNode as CfgNode


# Global config object
_C = CfgNode()

# Example usage:
# from core.config import cfg
cfg = _C


# ------------------------------------------------------------------------------------ #
# Model options
# ------------------------------------------------------------------------------------ #
_C.MODEL = CfgNode()

# Model type
_C.MODEL.TYPE = ""

# Number of weight layers
_C.MODEL.DEPTH = 0

# Number of input channels
_C.MODEL.INPUT_CHANNELS = 3

# Number of classes
_C.MODEL.NUM_CLASSES = 10

# Loss function (see pycls/core/builders.py for options)
_C.MODEL.LOSS_FUN = "cross_entropy"

# Label smoothing eps
_C.MODEL.LABEL_SMOOTHING_EPS = 0.0

# ASPP channels
_C.MODEL.ASPP_CHANNELS = 256

# ASPP dilation rates
_C.MODEL.ASPP_RATES = [6, 12, 18]


# ------------------------------------------------------------------------------------ #
# ResNet options
# ------------------------------------------------------------------------------------ #
_C.RESNET = CfgNode()

# Transformation function (see pycls/models/resnet.py for options)
_C.RESNET.TRANS_FUN = "basic_transform"

# Number of groups to use (1 -> ResNet; > 1 -> ResNeXt)
_C.RESNET.NUM_GROUPS = 1

# Width of each group (64 -> ResNet; 4 -> ResNeXt)
_C.RESNET.WIDTH_PER_GROUP = 64

# Apply stride to 1x1 conv (True -> MSRA; False -> fb.torch)
_C.RESNET.STRIDE_1X1 = True


# ------------------------------------------------------------------------------------ #
# AnyNet options
# ------------------------------------------------------------------------------------ #
_C.ANYNET = CfgNode()

# Stem type
_C.ANYNET.STEM_TYPE = "simple_stem_in"

# Stem width
_C.ANYNET.STEM_W = 32

# Block type
_C.ANYNET.BLOCK_TYPE = "res_bottleneck_block"

# Depth for each stage (number of blocks in the stage)
_C.ANYNET.DEPTHS = []

# Width for each stage (width of each block in the stage)
_C.ANYNET.WIDTHS = []

# Strides for each stage (applies to the first block of each stage)
_C.ANYNET.STRIDES = []

# Bottleneck multipliers for each stage (applies to bottleneck block)
_C.ANYNET.BOT_MULS = []

# Group widths for each stage (applies to bottleneck block)
_C.ANYNET.GROUP_WS = []

# Whether SE is enabled for res_bottleneck_block
_C.ANYNET.SE_ON = False

# SE ratio
_C.ANYNET.SE_R = 0.25


# ------------------------------------------------------------------------------------ #
# RegNet options
# ------------------------------------------------------------------------------------ #
_C.REGNET = CfgNode()

# Stem type
_C.REGNET.STEM_TYPE = "simple_stem_in"

# Stem width
_C.REGNET.STEM_W = 32

# Block type
_C.REGNET.BLOCK_TYPE = "res_bottleneck_block"

# Stride of each stage
_C.REGNET.STRIDE = 2

# Squeeze-and-Excitation (RegNetY)
_C.REGNET.SE_ON = False
_C.REGNET.SE_R = 0.25

# Depth
_C.REGNET.DEPTH = 10

# Initial width
_C.REGNET.W0 = 32

# Slope
_C.REGNET.WA = 5.0

# Quantization
_C.REGNET.WM = 2.5

# Group width
_C.REGNET.GROUP_W = 16

# Bottleneck multiplier (bm = 1 / b from the paper)
_C.REGNET.BOT_MUL = 1.0


# ------------------------------------------------------------------------------------ #
# EfficientNet options
# ------------------------------------------------------------------------------------ #
_C.EN = CfgNode()

# Stem width
_C.EN.STEM_W = 32

# Depth for each stage (number of blocks in the stage)
_C.EN.DEPTHS = []

# Width for each stage (width of each block in the stage)
_C.EN.WIDTHS = []

# Expansion ratios for MBConv blocks in each stage
_C.EN.EXP_RATIOS = []

# Squeeze-and-Excitation (SE) ratio
_C.EN.SE_R = 0.25

# Strides for each stage (applies to the first block of each stage)
_C.EN.STRIDES = []

# Kernel sizes for each stage
_C.EN.KERNELS = []

# Head width
_C.EN.HEAD_W = 1280

# Drop connect ratio
_C.EN.DC_RATIO = 0.0

# Dropout ratio
_C.EN.DROPOUT_RATIO = 0.0


# ---------------------------------------------------------------------------- #
# NAS options
# ---------------------------------------------------------------------------- #
_C.NAS = CfgNode()

# Cell genotype
_C.NAS.GENOTYPE = 'nas'

# Custom genotype
_C.NAS.CUSTOM_GENOTYPE = []

# Base NAS width
_C.NAS.WIDTH = 16

# Total number of cells
_C.NAS.DEPTH = 20

# Auxiliary heads
_C.NAS.AUX = False

# Weight for auxiliary heads
_C.NAS.AUX_WEIGHT = 0.4

# Drop path probability
_C.NAS.DROP_PROB = 0.0

# Matrix in NAS Bench
_C.NAS.MATRIX = []

# Operations in NAS Bench
_C.NAS.OPS = []

# Number of stacks in NAS Bench
_C.NAS.NUM_STACKS = 3

# Number of modules per stack in NAS Bench
_C.NAS.NUM_MODULES_PER_STACK = 3


# ------------------------------------------------------------------------------------ #
|
||||
# Batch norm options
|
||||
# ------------------------------------------------------------------------------------ #
|
||||
_C.BN = CfgNode()
|
||||
|
||||
# BN epsilon
|
||||
_C.BN.EPS = 1e-5
|
||||
|
||||
# BN momentum (BN momentum in PyTorch = 1 - BN momentum in Caffe2)
|
||||
_C.BN.MOM = 0.1
|
||||
|
||||
# Precise BN stats
|
||||
_C.BN.USE_PRECISE_STATS = False
|
||||
_C.BN.NUM_SAMPLES_PRECISE = 1024
|
||||
|
||||
# Initialize the gamma of the final BN of each block to zero
|
||||
_C.BN.ZERO_INIT_FINAL_GAMMA = False
|
||||
|
||||
# Use a different weight decay for BN layers
|
||||
_C.BN.USE_CUSTOM_WEIGHT_DECAY = False
|
||||
_C.BN.CUSTOM_WEIGHT_DECAY = 0.0
|
||||
|
||||
|
||||
# ------------------------------------------------------------------------------------ #
|
||||
# Optimizer options
|
||||
# ------------------------------------------------------------------------------------ #
|
||||
_C.OPTIM = CfgNode()
|
||||
|
||||
# Base learning rate
|
||||
_C.OPTIM.BASE_LR = 0.1
|
||||
|
||||
# Learning rate policy (select from {'cos', 'exp', 'steps'})
|
||||
_C.OPTIM.LR_POLICY = "cos"
|
||||
|
||||
# Exponential decay factor
|
||||
_C.OPTIM.GAMMA = 0.1
|
||||
|
||||
# Steps for 'steps' policy (in epochs)
|
||||
_C.OPTIM.STEPS = []
|
||||
|
||||
# Learning rate multiplier for 'steps' policy
|
||||
_C.OPTIM.LR_MULT = 0.1
|
||||
|
||||
# Maximal number of epochs
|
||||
_C.OPTIM.MAX_EPOCH = 200
|
||||
|
||||
# Momentum
|
||||
_C.OPTIM.MOMENTUM = 0.9
|
||||
|
||||
# Momentum dampening
|
||||
_C.OPTIM.DAMPENING = 0.0
|
||||
|
||||
# Nesterov momentum
|
||||
_C.OPTIM.NESTEROV = True
|
||||
|
||||
# L2 regularization
|
||||
_C.OPTIM.WEIGHT_DECAY = 5e-4
|
||||
|
||||
# Start the warm up from OPTIM.BASE_LR * OPTIM.WARMUP_FACTOR
|
||||
_C.OPTIM.WARMUP_FACTOR = 0.1
|
||||
|
||||
# Gradually warm up the OPTIM.BASE_LR over this number of epochs
|
||||
_C.OPTIM.WARMUP_EPOCHS = 0
|
||||
|
||||
# Update the learning rate per iter
|
||||
_C.OPTIM.ITER_LR = False
|
||||
|
||||
# Base learning rate for arch
|
||||
_C.OPTIM.ARCH_BASE_LR = 0.0003
|
||||
|
||||
# L2 regularization for arch
|
||||
_C.OPTIM.ARCH_WEIGHT_DECAY = 0.001
|
||||
|
||||
# Optimizer for arch
|
||||
_C.OPTIM.ARCH_OPTIM = 'adam'
|
||||
|
||||
# Epoch to start optimizing arch
|
||||
_C.OPTIM.ARCH_EPOCH = 0.0
|
||||
|
||||
|
||||
# ------------------------------------------------------------------------------------ #
|
||||
# Training options
|
||||
# ------------------------------------------------------------------------------------ #
|
||||
_C.TRAIN = CfgNode()
|
||||
|
||||
# Dataset and split
|
||||
_C.TRAIN.DATASET = ""
|
||||
_C.TRAIN.SPLIT = "train"
|
||||
|
||||
# Total mini-batch size
|
||||
_C.TRAIN.BATCH_SIZE = 128
|
||||
|
||||
# Image size
|
||||
_C.TRAIN.IM_SIZE = 224
|
||||
|
||||
# Evaluate model on test data every eval period epochs
|
||||
_C.TRAIN.EVAL_PERIOD = 1
|
||||
|
||||
# Save model checkpoint every checkpoint period epochs
|
||||
_C.TRAIN.CHECKPOINT_PERIOD = 1
|
||||
|
||||
# Resume training from the latest checkpoint in the output directory
|
||||
_C.TRAIN.AUTO_RESUME = True
|
||||
|
||||
# Weights to start training from
|
||||
_C.TRAIN.WEIGHTS = ""
|
||||
|
||||
# Percentage of grayscale images in the jigsaw (jig) task
|
||||
_C.TRAIN.GRAY_PERCENTAGE = 0.0
|
||||
|
||||
# Fraction of the training set assigned to trainA when creating the trainA/trainB split
|
||||
_C.TRAIN.PORTION = 1.0
|
||||
|
||||
|
||||
# ------------------------------------------------------------------------------------ #
|
||||
# Testing options
|
||||
# ------------------------------------------------------------------------------------ #
|
||||
_C.TEST = CfgNode()
|
||||
|
||||
# Dataset and split
|
||||
_C.TEST.DATASET = ""
|
||||
_C.TEST.SPLIT = "val"
|
||||
|
||||
# Total mini-batch size
|
||||
_C.TEST.BATCH_SIZE = 200
|
||||
|
||||
# Image size
|
||||
_C.TEST.IM_SIZE = 256
|
||||
|
||||
# Weights to use for testing
|
||||
_C.TEST.WEIGHTS = ""
|
||||
|
||||
|
||||
# ------------------------------------------------------------------------------------ #
|
||||
# Common train/test data loader options
|
||||
# ------------------------------------------------------------------------------------ #
|
||||
_C.DATA_LOADER = CfgNode()
|
||||
|
||||
# Number of data loader workers per process
|
||||
_C.DATA_LOADER.NUM_WORKERS = 8
|
||||
|
||||
# Load data to pinned host memory
|
||||
_C.DATA_LOADER.PIN_MEMORY = True
|
||||
|
||||
|
||||
# ------------------------------------------------------------------------------------ #
|
||||
# Memory options
|
||||
# ------------------------------------------------------------------------------------ #
|
||||
_C.MEM = CfgNode()
|
||||
|
||||
# Perform ReLU inplace
|
||||
_C.MEM.RELU_INPLACE = True
|
||||
|
||||
|
||||
# ------------------------------------------------------------------------------------ #
|
||||
# CUDNN options
|
||||
# ------------------------------------------------------------------------------------ #
|
||||
_C.CUDNN = CfgNode()
|
||||
|
||||
# Perform benchmarking to select the fastest CUDNN algorithms to use
|
||||
# Note that this may increase the memory usage and will likely not result
|
||||
# in overall speedups when variable size inputs are used (e.g. COCO training)
|
||||
_C.CUDNN.BENCHMARK = True
|
||||
|
||||
|
||||
# ------------------------------------------------------------------------------------ #
|
||||
# Precise timing options
|
||||
# ------------------------------------------------------------------------------------ #
|
||||
_C.PREC_TIME = CfgNode()
|
||||
|
||||
# Number of iterations to warm up the caches
|
||||
_C.PREC_TIME.WARMUP_ITER = 3
|
||||
|
||||
# Number of iterations to compute avg time
|
||||
_C.PREC_TIME.NUM_ITER = 30
|
||||
|
||||
|
||||
# ------------------------------------------------------------------------------------ #
|
||||
# Misc options
|
||||
# ------------------------------------------------------------------------------------ #
|
||||
|
||||
# Number of GPUs to use (applies to both training and testing)
|
||||
_C.NUM_GPUS = 1
|
||||
|
||||
# Task (cls: classification, seg: segmentation, rot: rotation, col: colorization, jig: jigsaw)
|
||||
_C.TASK = "cls"
|
||||
|
||||
# Grid in Jigsaw (2, 3); no effect if TASK is not jig
|
||||
_C.JIGSAW_GRID = 3
|
||||
|
||||
# Output directory
|
||||
_C.OUT_DIR = "/tmp"
|
||||
|
||||
# Config destination (in OUT_DIR)
|
||||
_C.CFG_DEST = "config.yaml"
|
||||
|
||||
# Note that non-determinism may still be present due to non-deterministic
|
||||
# operator implementations in GPU operator libraries
|
||||
_C.RNG_SEED = 1
|
||||
|
||||
# Log destination ('stdout' or 'file')
|
||||
_C.LOG_DEST = "stdout"
|
||||
|
||||
# Log period in iters
|
||||
_C.LOG_PERIOD = 10
|
||||
|
||||
# Distributed backend
|
||||
_C.DIST_BACKEND = "nccl"
|
||||
|
||||
# Hostname and port for initializing multi-process groups
|
||||
_C.HOST = "localhost"
|
||||
_C.PORT = 10001
|
||||
|
||||
# Models weights referred to by URL are downloaded to this local cache
|
||||
_C.DOWNLOAD_CACHE = "/tmp/pycls-download-cache"
|
||||
|
||||
|
||||
# ------------------------------------------------------------------------------------ #
|
||||
# Deprecated keys
|
||||
# ------------------------------------------------------------------------------------ #
|
||||
|
||||
_C.register_deprecated_key("PREC_TIME.BATCH_SIZE")
|
||||
_C.register_deprecated_key("PREC_TIME.ENABLED")
|
||||
|
||||
|
||||
def assert_and_infer_cfg(cache_urls=True):
|
||||
"""Checks config values invariants."""
|
||||
err_str = "The first lr step must start at 0"
|
||||
assert not _C.OPTIM.STEPS or _C.OPTIM.STEPS[0] == 0, err_str
|
||||
data_splits = ["train", "val", "test"]
|
||||
err_str = "Data split '{}' not supported"
|
||||
assert _C.TRAIN.SPLIT in data_splits, err_str.format(_C.TRAIN.SPLIT)
|
||||
assert _C.TEST.SPLIT in data_splits, err_str.format(_C.TEST.SPLIT)
|
||||
err_str = "Mini-batch size should be a multiple of NUM_GPUS."
|
||||
assert _C.TRAIN.BATCH_SIZE % _C.NUM_GPUS == 0, err_str
|
||||
assert _C.TEST.BATCH_SIZE % _C.NUM_GPUS == 0, err_str
|
||||
err_str = "Precise BN stats computation not verified for > 1 GPU"
|
||||
assert not _C.BN.USE_PRECISE_STATS or _C.NUM_GPUS == 1, err_str
|
||||
err_str = "Log destination '{}' not supported"
|
||||
assert _C.LOG_DEST in ["stdout", "file"], err_str.format(_C.LOG_DEST)
|
||||
if cache_urls:
|
||||
cache_cfg_urls()
|
||||
|
||||
|
||||
def cache_cfg_urls():
|
||||
"""Download URLs in config, cache them, and rewrite cfg to use cached file."""
|
||||
_C.TRAIN.WEIGHTS = cache_url(_C.TRAIN.WEIGHTS, _C.DOWNLOAD_CACHE)
|
||||
_C.TEST.WEIGHTS = cache_url(_C.TEST.WEIGHTS, _C.DOWNLOAD_CACHE)
|
||||
|
||||
|
||||
def dump_cfg():
|
||||
"""Dumps the config to the output directory."""
|
||||
cfg_file = os.path.join(_C.OUT_DIR, _C.CFG_DEST)
|
||||
with open(cfg_file, "w") as f:
|
||||
_C.dump(stream=f)
|
||||
|
||||
|
||||
def load_cfg(out_dir, cfg_dest="config.yaml"):
|
||||
"""Loads config from specified output directory."""
|
||||
cfg_file = os.path.join(out_dir, cfg_dest)
|
||||
_C.merge_from_file(cfg_file)
|
||||
|
||||
|
||||
def load_cfg_fom_args(description="Config file options."):
|
||||
"""Load config from command line arguments and set any specified options."""
|
||||
parser = argparse.ArgumentParser(description=description)
|
||||
help_s = "Config file location"
|
||||
parser.add_argument("--cfg", dest="cfg_file", help=help_s, required=True, type=str)
|
||||
help_s = "See pycls/core/config.py for all options"
|
||||
parser.add_argument("opts", help=help_s, default=None, nargs=argparse.REMAINDER)
|
||||
if len(sys.argv) == 1:
|
||||
parser.print_help()
|
||||
sys.exit(1)
|
||||
args = parser.parse_args()
|
||||
_C.merge_from_file(args.cfg_file)
|
||||
_C.merge_from_list(args.opts)
|
||||
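A short sketch of how the shared cfg object is typically driven (the override values are illustrative; this is not part of the commit):

from pycls.core.config import cfg
import pycls.core.config as config

# Override options programmatically, the same mechanism load_cfg_fom_args uses
# for the --cfg file and the trailing KEY VALUE pairs on the command line.
cfg.merge_from_list(["OPTIM.BASE_LR", 0.05, "TRAIN.BATCH_SIZE", 256])
config.assert_and_infer_cfg(cache_urls=False)
print(cfg.OPTIM.BASE_LR, cfg.TRAIN.BATCH_SIZE)  # -> 0.05 256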
157
pycls/core/distributed.py
Normal file
@@ -0,0 +1,157 @@
#!/usr/bin/env python3
|
||||
|
||||
# Copyright (c) Facebook, Inc. and its affiliates.
|
||||
#
|
||||
# This source code is licensed under the MIT license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
"""Distributed helpers."""
|
||||
|
||||
import multiprocessing
|
||||
import os
|
||||
import signal
|
||||
import threading
|
||||
import traceback
|
||||
|
||||
import torch
|
||||
from pycls.core.config import cfg
|
||||
|
||||
|
||||
def is_master_proc():
|
||||
"""Determines if the current process is the master process.
|
||||
|
||||
Master process is responsible for logging, writing and loading checkpoints. In
|
||||
the multi GPU setting, we assign the master role to the rank 0 process. When
|
||||
training using a single GPU, there is a single process which is considered master.
|
||||
"""
|
||||
return cfg.NUM_GPUS == 1 or torch.distributed.get_rank() == 0
|
||||
|
||||
|
||||
def init_process_group(proc_rank, world_size):
|
||||
"""Initializes the default process group."""
|
||||
# Set the GPU to use
|
||||
torch.cuda.set_device(proc_rank)
|
||||
# Initialize the process group
|
||||
torch.distributed.init_process_group(
|
||||
backend=cfg.DIST_BACKEND,
|
||||
init_method="tcp://{}:{}".format(cfg.HOST, cfg.PORT),
|
||||
world_size=world_size,
|
||||
rank=proc_rank,
|
||||
)
|
||||
|
||||
|
||||
def destroy_process_group():
|
||||
"""Destroys the default process group."""
|
||||
torch.distributed.destroy_process_group()
|
||||
|
||||
|
||||
def scaled_all_reduce(tensors):
|
||||
"""Performs the scaled all_reduce operation on the provided tensors.
|
||||
|
||||
The input tensors are modified in-place. Currently supports only the sum
|
||||
reduction operator. The reduced values are scaled by the inverse size of the
|
||||
process group (equivalent to cfg.NUM_GPUS).
|
||||
"""
|
||||
# There is no need for reduction in the single-proc case
|
||||
if cfg.NUM_GPUS == 1:
|
||||
return tensors
|
||||
# Queue the reductions
|
||||
reductions = []
|
||||
for tensor in tensors:
|
||||
reduction = torch.distributed.all_reduce(tensor, async_op=True)
|
||||
reductions.append(reduction)
|
||||
# Wait for reductions to finish
|
||||
for reduction in reductions:
|
||||
reduction.wait()
|
||||
# Scale the results
|
||||
for tensor in tensors:
|
||||
tensor.mul_(1.0 / cfg.NUM_GPUS)
|
||||
return tensors
|
||||
|
||||
|
||||
class ChildException(Exception):
|
||||
"""Wraps an exception from a child process."""
|
||||
|
||||
def __init__(self, child_trace):
|
||||
super(ChildException, self).__init__(child_trace)
|
||||
|
||||
|
||||
class ErrorHandler(object):
|
||||
"""Multiprocessing error handler (based on fairseq's).
|
||||
|
||||
Listens for errors in child processes and propagates the tracebacks to the parent.
|
||||
"""
|
||||
|
||||
def __init__(self, error_queue):
|
||||
# Shared error queue
|
||||
self.error_queue = error_queue
|
||||
# Children processes sharing the error queue
|
||||
self.children_pids = []
|
||||
# Start a thread listening to errors
|
||||
self.error_listener = threading.Thread(target=self.listen, daemon=True)
|
||||
self.error_listener.start()
|
||||
# Register the signal handler
|
||||
signal.signal(signal.SIGUSR1, self.signal_handler)
|
||||
|
||||
def add_child(self, pid):
|
||||
"""Registers a child process."""
|
||||
self.children_pids.append(pid)
|
||||
|
||||
def listen(self):
|
||||
"""Listens for errors in the error queue."""
|
||||
# Wait until there is an error in the queue
|
||||
child_trace = self.error_queue.get()
|
||||
# Put the error back for the signal handler
|
||||
self.error_queue.put(child_trace)
|
||||
# Invoke the signal handler
|
||||
os.kill(os.getpid(), signal.SIGUSR1)
|
||||
|
||||
def signal_handler(self, _sig_num, _stack_frame):
|
||||
"""Signal handler."""
|
||||
# Kill children processes
|
||||
for pid in self.children_pids:
|
||||
os.kill(pid, signal.SIGINT)
|
||||
# Propagate the error from the child process
|
||||
raise ChildException(self.error_queue.get())
|
||||
|
||||
|
||||
def run(proc_rank, world_size, error_queue, fun, fun_args, fun_kwargs):
|
||||
"""Runs a function from a child process."""
|
||||
try:
|
||||
# Initialize the process group
|
||||
init_process_group(proc_rank, world_size)
|
||||
# Run the function
|
||||
fun(*fun_args, **fun_kwargs)
|
||||
except KeyboardInterrupt:
|
||||
# Killed by the parent process
|
||||
pass
|
||||
except Exception:
|
||||
# Propagate exception to the parent process
|
||||
error_queue.put(traceback.format_exc())
|
||||
finally:
|
||||
# Destroy the process group
|
||||
destroy_process_group()
|
||||
|
||||
|
||||
def multi_proc_run(num_proc, fun, fun_args=(), fun_kwargs=None):
|
||||
"""Runs a function in a multi-proc setting (unless num_proc == 1)."""
|
||||
# There is no need for multi-proc in the single-proc case
|
||||
fun_kwargs = fun_kwargs if fun_kwargs else {}
|
||||
if num_proc == 1:
|
||||
fun(*fun_args, **fun_kwargs)
|
||||
return
|
||||
# Handle errors from training subprocesses
|
||||
error_queue = multiprocessing.SimpleQueue()
|
||||
error_handler = ErrorHandler(error_queue)
|
||||
# Run each training subprocess
|
||||
ps = []
|
||||
for i in range(num_proc):
|
||||
p_i = multiprocessing.Process(
|
||||
target=run, args=(i, num_proc, error_queue, fun, fun_args, fun_kwargs)
|
||||
)
|
||||
ps.append(p_i)
|
||||
p_i.start()
|
||||
error_handler.add_child(p_i.pid)
|
||||
# Wait for each subprocess to finish
|
||||
for p in ps:
|
||||
p.join()
|
||||
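A sketch of launching training through the helpers above (train_model is a placeholder for the caller's entry point; not part of this commit):

from pycls.core.config import cfg
import pycls.core.distributed as dist

def train_model():
    # ... build the model and loaders, then run the training loop ...
    pass

if __name__ == "__main__":
    # Runs inline for NUM_GPUS == 1, otherwise spawns one process per GPU and
    # propagates any child traceback through ErrorHandler.
    dist.multi_proc_run(num_proc=cfg.NUM_GPUS, fun=train_model)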
77
pycls/core/io.py
Normal file
@@ -0,0 +1,77 @@
#!/usr/bin/env python3
|
||||
|
||||
# Copyright (c) Facebook, Inc. and its affiliates.
|
||||
#
|
||||
# This source code is licensed under the MIT license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
"""IO utilities (adapted from Detectron)"""
|
||||
|
||||
import logging
|
||||
import os
|
||||
import re
|
||||
import sys
|
||||
from urllib import request as urlrequest
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
_PYCLS_BASE_URL = "https://dl.fbaipublicfiles.com/pycls"
|
||||
|
||||
|
||||
def cache_url(url_or_file, cache_dir):
|
||||
"""Download the file specified by the URL to the cache_dir and return the path to
|
||||
the cached file. If the argument is not a URL, simply return it as is.
|
||||
"""
|
||||
is_url = re.match(r"^(?:http)s?://", url_or_file, re.IGNORECASE) is not None
|
||||
if not is_url:
|
||||
return url_or_file
|
||||
url = url_or_file
|
||||
err_str = "pycls only automatically caches URLs in the pycls S3 bucket: {}"
|
||||
assert url.startswith(_PYCLS_BASE_URL), err_str.format(_PYCLS_BASE_URL)
|
||||
cache_file_path = url.replace(_PYCLS_BASE_URL, cache_dir)
|
||||
if os.path.exists(cache_file_path):
|
||||
return cache_file_path
|
||||
cache_file_dir = os.path.dirname(cache_file_path)
|
||||
if not os.path.exists(cache_file_dir):
|
||||
os.makedirs(cache_file_dir)
|
||||
logger.info("Downloading remote file {} to {}".format(url, cache_file_path))
|
||||
download_url(url, cache_file_path)
|
||||
return cache_file_path
|
||||
|
||||
|
||||
def _progress_bar(count, total):
|
||||
"""Report download progress. Credit:
|
||||
https://stackoverflow.com/questions/3173320/text-progress-bar-in-the-console/27871113
|
||||
"""
|
||||
bar_len = 60
|
||||
filled_len = int(round(bar_len * count / float(total)))
|
||||
percents = round(100.0 * count / float(total), 1)
|
||||
bar = "=" * filled_len + "-" * (bar_len - filled_len)
|
||||
sys.stdout.write(
|
||||
" [{}] {}% of {:.1f}MB file \r".format(bar, percents, total / 1024 / 1024)
|
||||
)
|
||||
sys.stdout.flush()
|
||||
if count >= total:
|
||||
sys.stdout.write("\n")
|
||||
|
||||
|
||||
def download_url(url, dst_file_path, chunk_size=8192, progress_hook=_progress_bar):
|
||||
"""Download url and write it to dst_file_path. Credit:
|
||||
https://stackoverflow.com/questions/2028517/python-urllib2-progress-hook
|
||||
"""
|
||||
req = urlrequest.Request(url)
|
||||
response = urlrequest.urlopen(req)
|
||||
total_size = response.info().get("Content-Length").strip()
|
||||
total_size = int(total_size)
|
||||
bytes_so_far = 0
|
||||
with open(dst_file_path, "wb") as f:
|
||||
while 1:
|
||||
chunk = response.read(chunk_size)
|
||||
bytes_so_far += len(chunk)
|
||||
if not chunk:
|
||||
break
|
||||
if progress_hook:
|
||||
progress_hook(bytes_so_far, total_size)
|
||||
f.write(chunk)
|
||||
return bytes_so_far
|
||||
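A sketch of the cache_url contract (the paths are illustrative; not part of this commit):

from pycls.core.io import cache_url

# Non-URL arguments pass through unchanged.
assert cache_url("/path/to/weights.pyth", "/tmp/pycls-download-cache") == "/path/to/weights.pyth"

# A URL under the pycls bucket would be downloaded once into the cache directory
# and the local cached path returned, e.g.:
# cache_url("https://dl.fbaipublicfiles.com/pycls/<model>.pyth", cfg.DOWNLOAD_CACHE)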
138
pycls/core/logging.py
Normal file
@@ -0,0 +1,138 @@
#!/usr/bin/env python3
|
||||
|
||||
# Copyright (c) Facebook, Inc. and its affiliates.
|
||||
#
|
||||
# This source code is licensed under the MIT license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
"""Logging."""
|
||||
|
||||
import builtins
|
||||
import decimal
|
||||
import logging
|
||||
import os
|
||||
import sys
|
||||
|
||||
import pycls.core.distributed as dist
|
||||
import simplejson
|
||||
from pycls.core.config import cfg
|
||||
|
||||
|
||||
# Show filename and line number in logs
|
||||
_FORMAT = "[%(filename)s: %(lineno)3d]: %(message)s"
|
||||
|
||||
# Log file name (for cfg.LOG_DEST = 'file')
|
||||
_LOG_FILE = "stdout.log"
|
||||
|
||||
# Data output with dump_log_data(data, data_type) will be tagged w/ this
|
||||
_TAG = "json_stats: "
|
||||
|
||||
# Data output with dump_log_data(data, data_type) will have data[_TYPE]=data_type
|
||||
_TYPE = "_type"
|
||||
|
||||
|
||||
def _suppress_print():
|
||||
"""Suppresses printing from the current process."""
|
||||
|
||||
def ignore(*_objects, _sep=" ", _end="\n", _file=sys.stdout, _flush=False):
|
||||
pass
|
||||
|
||||
builtins.print = ignore
|
||||
|
||||
|
||||
def setup_logging():
|
||||
"""Sets up the logging."""
|
||||
# Enable logging only for the master process
|
||||
if dist.is_master_proc():
|
||||
# Clear the root logger to prevent any existing logging config
|
||||
# (e.g. set by another module) from messing with our setup
|
||||
logging.root.handlers = []
|
||||
# Construct logging configuration
|
||||
logging_config = {"level": logging.INFO, "format": _FORMAT}
|
||||
# Log either to stdout or to a file
|
||||
if cfg.LOG_DEST == "stdout":
|
||||
logging_config["stream"] = sys.stdout
|
||||
else:
|
||||
logging_config["filename"] = os.path.join(cfg.OUT_DIR, _LOG_FILE)
|
||||
# Configure logging
|
||||
logging.basicConfig(**logging_config)
|
||||
else:
|
||||
_suppress_print()
|
||||
|
||||
|
||||
def get_logger(name):
|
||||
"""Retrieves the logger."""
|
||||
return logging.getLogger(name)
|
||||
|
||||
|
||||
def dump_log_data(data, data_type, prec=4):
|
||||
"""Covert data (a dictionary) into tagged json string for logging."""
|
||||
data[_TYPE] = data_type
|
||||
data = float_to_decimal(data, prec)
|
||||
data_json = simplejson.dumps(data, sort_keys=True, use_decimal=True)
|
||||
return "{:s}{:s}".format(_TAG, data_json)
|
||||
|
||||
|
||||
def float_to_decimal(data, prec=4):
|
||||
"""Convert floats to decimals which allows for fixed width json."""
|
||||
if isinstance(data, dict):
|
||||
return {k: float_to_decimal(v, prec) for k, v in data.items()}
|
||||
if isinstance(data, float):
|
||||
return decimal.Decimal(("{:." + str(prec) + "f}").format(data))
|
||||
else:
|
||||
return data
|
||||
|
||||
|
||||
def get_log_files(log_dir, name_filter="", log_file=_LOG_FILE):
|
||||
"""Get all log files in directory containing subdirs of trained models."""
|
||||
names = [n for n in sorted(os.listdir(log_dir)) if name_filter in n]
|
||||
files = [os.path.join(log_dir, n, log_file) for n in names]
|
||||
f_n_ps = [(f, n) for (f, n) in zip(files, names) if os.path.exists(f)]
|
||||
files, names = zip(*f_n_ps) if f_n_ps else ([], [])
|
||||
return files, names
|
||||
|
||||
|
||||
def load_log_data(log_file, data_types_to_skip=()):
|
||||
"""Loads log data into a dictionary of the form data[data_type][metric][index]."""
|
||||
# Load log_file
|
||||
assert os.path.exists(log_file), "Log file not found: {}".format(log_file)
|
||||
with open(log_file, "r") as f:
|
||||
lines = f.readlines()
|
||||
# Extract and parse lines that start with _TAG and have a type specified
|
||||
lines = [l[l.find(_TAG) + len(_TAG) :] for l in lines if _TAG in l]
|
||||
lines = [simplejson.loads(l) for l in lines]
|
||||
lines = [l for l in lines if _TYPE in l and not l[_TYPE] in data_types_to_skip]
|
||||
# Generate data structure accessed by data[data_type][index][metric]
|
||||
data_types = [l[_TYPE] for l in lines]
|
||||
data = {t: [] for t in data_types}
|
||||
for t, line in zip(data_types, lines):
|
||||
del line[_TYPE]
|
||||
data[t].append(line)
|
||||
# Generate data structure accessed by data[data_type][metric][index]
|
||||
for t in data:
|
||||
metrics = sorted(data[t][0].keys())
|
||||
err_str = "Inconsistent metrics in log for _type={}: {}".format(t, metrics)
|
||||
assert all(sorted(d.keys()) == metrics for d in data[t]), err_str
|
||||
data[t] = {m: [d[m] for d in data[t]] for m in metrics}
|
||||
return data
|
||||
|
||||
|
||||
def sort_log_data(data):
|
||||
"""Sort each data[data_type][metric] by epoch or keep only first instance."""
|
||||
for t in data:
|
||||
if "epoch" in data[t]:
|
||||
assert "epoch_ind" not in data[t] and "epoch_max" not in data[t]
|
||||
data[t]["epoch_ind"] = [int(e.split("/")[0]) for e in data[t]["epoch"]]
|
||||
data[t]["epoch_max"] = [int(e.split("/")[1]) for e in data[t]["epoch"]]
|
||||
epoch = data[t]["epoch_ind"]
|
||||
if "iter" in data[t]:
|
||||
assert "iter_ind" not in data[t] and "iter_max" not in data[t]
|
||||
data[t]["iter_ind"] = [int(i.split("/")[0]) for i in data[t]["iter"]]
|
||||
data[t]["iter_max"] = [int(i.split("/")[1]) for i in data[t]["iter"]]
|
||||
itr = zip(epoch, data[t]["iter_ind"], data[t]["iter_max"])
|
||||
epoch = [e + (i_ind - 1) / i_max for e, i_ind, i_max in itr]
|
||||
for m in data[t]:
|
||||
data[t][m] = [v for _, v in sorted(zip(epoch, data[t][m]))]
|
||||
else:
|
||||
data[t] = {m: d[0] for m, d in data[t].items()}
|
||||
return data
|
||||
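A sketch of the tagged-json round trip used by the meters (the stats values are illustrative; not part of this commit):

import pycls.core.logging as logging

logger = logging.get_logger(__name__)
stats = {"epoch": "1/200", "top1_err": 12.3456789, "loss": 0.987654}
line = logging.dump_log_data(stats, "test_epoch", prec=4)
# line == 'json_stats: {"_type": "test_epoch", "epoch": "1/200", "loss": 0.9877, "top1_err": 12.3457}'
logger.info(line)  # load_log_data later recovers data["test_epoch"]["top1_err"] == [12.3457]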
435
pycls/core/meters.py
Normal file
@@ -0,0 +1,435 @@
#!/usr/bin/env python3
|
||||
|
||||
# Copyright (c) Facebook, Inc. and its affiliates.
|
||||
#
|
||||
# This source code is licensed under the MIT license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
"""Meters."""
|
||||
|
||||
from collections import deque
|
||||
|
||||
import numpy as np
|
||||
import pycls.core.logging as logging
|
||||
import torch
|
||||
from pycls.core.config import cfg
|
||||
from pycls.core.timer import Timer
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
|
||||
def time_string(seconds):
|
||||
"""Converts time in seconds to a fixed-width string format."""
|
||||
days, rem = divmod(int(seconds), 24 * 3600)
|
||||
hrs, rem = divmod(rem, 3600)
|
||||
mins, secs = divmod(rem, 60)
|
||||
return "{0:02},{1:02}:{2:02}:{3:02}".format(days, hrs, mins, secs)
|
||||
|
||||
|
||||
def inter_union(preds, labels, num_classes):
    """Computes per-class intersection and union areas (normalized by batch size) for mIoU."""
|
||||
_, preds = torch.max(preds, 1)
|
||||
preds = preds.type(torch.uint8) + 1
|
||||
labels = labels.type(torch.uint8) + 1
|
||||
preds = preds * (labels > 0).type(torch.uint8)
|
||||
|
||||
inter = preds * (preds == labels).type(torch.uint8)
|
||||
area_inter = torch.histc(inter.type(torch.int64), bins=num_classes, min=1, max=num_classes)
|
||||
area_preds = torch.histc(preds.type(torch.int64), bins=num_classes, min=1, max=num_classes)
|
||||
area_labels = torch.histc(labels.type(torch.int64), bins=num_classes, min=1, max=num_classes)
|
||||
area_union = area_preds + area_labels - area_inter
|
||||
|
||||
return [area_inter.type(torch.float64) / labels.size(0), area_union.type(torch.float64) / labels.size(0)]
|
||||
|
||||
|
||||
def topk_errors(preds, labels, ks):
|
||||
"""Computes the top-k error for each k."""
|
||||
err_str = "Batch dim of predictions and labels must match"
|
||||
assert preds.size(0) == labels.size(0), err_str
|
||||
# Find the top max_k predictions for each sample
|
||||
_top_max_k_vals, top_max_k_inds = torch.topk(
|
||||
preds, max(ks), dim=1, largest=True, sorted=True
|
||||
)
|
||||
# (batch_size, max_k) -> (max_k, batch_size)
|
||||
top_max_k_inds = top_max_k_inds.t()
|
||||
# (batch_size, ) -> (max_k, batch_size)
|
||||
rep_max_k_labels = labels.view(1, -1).expand_as(top_max_k_inds)
|
||||
# (i, j) = 1 if top i-th prediction for the j-th sample is correct
|
||||
top_max_k_correct = top_max_k_inds.eq(rep_max_k_labels)
|
||||
# Compute the number of topk correct predictions for each k
|
||||
topks_correct = [top_max_k_correct[:k, :].view(-1).float().sum() for k in ks]
|
||||
return [(1.0 - x / preds.size(0)) * 100.0 for x in topks_correct]
|
||||
|
||||
|
||||
def gpu_mem_usage():
|
||||
"""Computes the GPU memory usage for the current device (MB)."""
|
||||
mem_usage_bytes = torch.cuda.max_memory_allocated()
|
||||
return mem_usage_bytes / 1024 / 1024
|
||||
|
||||
|
||||
class ScalarMeter(object):
|
||||
"""Measures a scalar value (adapted from Detectron)."""
|
||||
|
||||
def __init__(self, window_size):
|
||||
self.deque = deque(maxlen=window_size)
|
||||
self.total = 0.0
|
||||
self.count = 0
|
||||
|
||||
def reset(self):
|
||||
self.deque.clear()
|
||||
self.total = 0.0
|
||||
self.count = 0
|
||||
|
||||
def add_value(self, value):
|
||||
self.deque.append(value)
|
||||
self.count += 1
|
||||
self.total += value
|
||||
|
||||
def get_win_median(self):
|
||||
return np.median(self.deque)
|
||||
|
||||
def get_win_avg(self):
|
||||
return np.mean(self.deque)
|
||||
|
||||
def get_global_avg(self):
|
||||
return self.total / self.count
|
||||
|
||||
|
||||
class TrainMeter(object):
|
||||
"""Measures training stats."""
|
||||
|
||||
def __init__(self, epoch_iters):
|
||||
self.epoch_iters = epoch_iters
|
||||
self.max_iter = cfg.OPTIM.MAX_EPOCH * epoch_iters
|
||||
self.iter_timer = Timer()
|
||||
self.loss = ScalarMeter(cfg.LOG_PERIOD)
|
||||
self.loss_total = 0.0
|
||||
self.lr = None
|
||||
# Current minibatch errors (smoothed over a window)
|
||||
self.mb_top1_err = ScalarMeter(cfg.LOG_PERIOD)
|
||||
self.mb_top5_err = ScalarMeter(cfg.LOG_PERIOD)
|
||||
# Number of misclassified examples
|
||||
self.num_top1_mis = 0
|
||||
self.num_top5_mis = 0
|
||||
self.num_samples = 0
|
||||
|
||||
def reset(self, timer=False):
|
||||
if timer:
|
||||
self.iter_timer.reset()
|
||||
self.loss.reset()
|
||||
self.loss_total = 0.0
|
||||
self.lr = None
|
||||
self.mb_top1_err.reset()
|
||||
self.mb_top5_err.reset()
|
||||
self.num_top1_mis = 0
|
||||
self.num_top5_mis = 0
|
||||
self.num_samples = 0
|
||||
|
||||
def iter_tic(self):
|
||||
self.iter_timer.tic()
|
||||
|
||||
def iter_toc(self):
|
||||
self.iter_timer.toc()
|
||||
|
||||
def update_stats(self, top1_err, top5_err, loss, lr, mb_size):
|
||||
# Current minibatch stats
|
||||
self.mb_top1_err.add_value(top1_err)
|
||||
self.mb_top5_err.add_value(top5_err)
|
||||
self.loss.add_value(loss)
|
||||
self.lr = lr
|
||||
# Aggregate stats
|
||||
self.num_top1_mis += top1_err * mb_size
|
||||
self.num_top5_mis += top5_err * mb_size
|
||||
self.loss_total += loss * mb_size
|
||||
self.num_samples += mb_size
|
||||
|
||||
def get_iter_stats(self, cur_epoch, cur_iter):
|
||||
cur_iter_total = cur_epoch * self.epoch_iters + cur_iter + 1
|
||||
eta_sec = self.iter_timer.average_time * (self.max_iter - cur_iter_total)
|
||||
mem_usage = gpu_mem_usage()
|
||||
stats = {
|
||||
"epoch": "{}/{}".format(cur_epoch + 1, cfg.OPTIM.MAX_EPOCH),
|
||||
"iter": "{}/{}".format(cur_iter + 1, self.epoch_iters),
|
||||
"time_avg": self.iter_timer.average_time,
|
||||
"time_diff": self.iter_timer.diff,
|
||||
"eta": time_string(eta_sec),
|
||||
"top1_err": self.mb_top1_err.get_win_median(),
|
||||
"top5_err": self.mb_top5_err.get_win_median(),
|
||||
"loss": self.loss.get_win_median(),
|
||||
"lr": self.lr,
|
||||
"mem": int(np.ceil(mem_usage)),
|
||||
}
|
||||
return stats
|
||||
|
||||
def log_iter_stats(self, cur_epoch, cur_iter):
|
||||
if (cur_iter + 1) % cfg.LOG_PERIOD != 0:
|
||||
return
|
||||
stats = self.get_iter_stats(cur_epoch, cur_iter)
|
||||
logger.info(logging.dump_log_data(stats, "train_iter"))
|
||||
|
||||
def get_epoch_stats(self, cur_epoch):
|
||||
cur_iter_total = (cur_epoch + 1) * self.epoch_iters
|
||||
eta_sec = self.iter_timer.average_time * (self.max_iter - cur_iter_total)
|
||||
mem_usage = gpu_mem_usage()
|
||||
top1_err = self.num_top1_mis / self.num_samples
|
||||
top5_err = self.num_top5_mis / self.num_samples
|
||||
avg_loss = self.loss_total / self.num_samples
|
||||
stats = {
|
||||
"epoch": "{}/{}".format(cur_epoch + 1, cfg.OPTIM.MAX_EPOCH),
|
||||
"time_avg": self.iter_timer.average_time,
|
||||
"eta": time_string(eta_sec),
|
||||
"top1_err": top1_err,
|
||||
"top5_err": top5_err,
|
||||
"loss": avg_loss,
|
||||
"lr": self.lr,
|
||||
"mem": int(np.ceil(mem_usage)),
|
||||
}
|
||||
return stats
|
||||
|
||||
def log_epoch_stats(self, cur_epoch):
|
||||
stats = self.get_epoch_stats(cur_epoch)
|
||||
logger.info(logging.dump_log_data(stats, "train_epoch"))
|
||||
|
||||
|
||||
class TestMeter(object):
|
||||
"""Measures testing stats."""
|
||||
|
||||
def __init__(self, max_iter):
|
||||
self.max_iter = max_iter
|
||||
self.iter_timer = Timer()
|
||||
# Current minibatch errors (smoothed over a window)
|
||||
self.mb_top1_err = ScalarMeter(cfg.LOG_PERIOD)
|
||||
self.mb_top5_err = ScalarMeter(cfg.LOG_PERIOD)
|
||||
# Min errors (over the full test set)
|
||||
self.min_top1_err = 100.0
|
||||
self.min_top5_err = 100.0
|
||||
# Number of misclassified examples
|
||||
self.num_top1_mis = 0
|
||||
self.num_top5_mis = 0
|
||||
self.num_samples = 0
|
||||
|
||||
def reset(self, min_errs=False):
|
||||
if min_errs:
|
||||
self.min_top1_err = 100.0
|
||||
self.min_top5_err = 100.0
|
||||
self.iter_timer.reset()
|
||||
self.mb_top1_err.reset()
|
||||
self.mb_top5_err.reset()
|
||||
self.num_top1_mis = 0
|
||||
self.num_top5_mis = 0
|
||||
self.num_samples = 0
|
||||
|
||||
def iter_tic(self):
|
||||
self.iter_timer.tic()
|
||||
|
||||
def iter_toc(self):
|
||||
self.iter_timer.toc()
|
||||
|
||||
def update_stats(self, top1_err, top5_err, mb_size):
|
||||
self.mb_top1_err.add_value(top1_err)
|
||||
self.mb_top5_err.add_value(top5_err)
|
||||
self.num_top1_mis += top1_err * mb_size
|
||||
self.num_top5_mis += top5_err * mb_size
|
||||
self.num_samples += mb_size
|
||||
|
||||
def get_iter_stats(self, cur_epoch, cur_iter):
|
||||
mem_usage = gpu_mem_usage()
|
||||
iter_stats = {
|
||||
"epoch": "{}/{}".format(cur_epoch + 1, cfg.OPTIM.MAX_EPOCH),
|
||||
"iter": "{}/{}".format(cur_iter + 1, self.max_iter),
|
||||
"time_avg": self.iter_timer.average_time,
|
||||
"time_diff": self.iter_timer.diff,
|
||||
"top1_err": self.mb_top1_err.get_win_median(),
|
||||
"top5_err": self.mb_top5_err.get_win_median(),
|
||||
"mem": int(np.ceil(mem_usage)),
|
||||
}
|
||||
return iter_stats
|
||||
|
||||
def log_iter_stats(self, cur_epoch, cur_iter):
|
||||
if (cur_iter + 1) % cfg.LOG_PERIOD != 0:
|
||||
return
|
||||
stats = self.get_iter_stats(cur_epoch, cur_iter)
|
||||
logger.info(logging.dump_log_data(stats, "test_iter"))
|
||||
|
||||
def get_epoch_stats(self, cur_epoch):
|
||||
top1_err = self.num_top1_mis / self.num_samples
|
||||
top5_err = self.num_top5_mis / self.num_samples
|
||||
self.min_top1_err = min(self.min_top1_err, top1_err)
|
||||
self.min_top5_err = min(self.min_top5_err, top5_err)
|
||||
mem_usage = gpu_mem_usage()
|
||||
stats = {
|
||||
"epoch": "{}/{}".format(cur_epoch + 1, cfg.OPTIM.MAX_EPOCH),
|
||||
"time_avg": self.iter_timer.average_time,
|
||||
"top1_err": top1_err,
|
||||
"top5_err": top5_err,
|
||||
"min_top1_err": self.min_top1_err,
|
||||
"min_top5_err": self.min_top5_err,
|
||||
"mem": int(np.ceil(mem_usage)),
|
||||
}
|
||||
return stats
|
||||
|
||||
def log_epoch_stats(self, cur_epoch):
|
||||
stats = self.get_epoch_stats(cur_epoch)
|
||||
logger.info(logging.dump_log_data(stats, "test_epoch"))
|
||||
|
||||
|
||||
class TrainMeterIoU(object):
|
||||
"""Measures training stats."""
|
||||
|
||||
def __init__(self, epoch_iters):
|
||||
self.epoch_iters = epoch_iters
|
||||
self.max_iter = cfg.OPTIM.MAX_EPOCH * epoch_iters
|
||||
self.iter_timer = Timer()
|
||||
self.loss = ScalarMeter(cfg.LOG_PERIOD)
|
||||
self.loss_total = 0.0
|
||||
self.lr = None
|
||||
|
||||
self.mb_miou = ScalarMeter(cfg.LOG_PERIOD)
|
||||
|
||||
self.num_inter = np.zeros(cfg.MODEL.NUM_CLASSES)
|
||||
self.num_union = np.zeros(cfg.MODEL.NUM_CLASSES)
|
||||
self.num_samples = 0
|
||||
|
||||
def reset(self, timer=False):
|
||||
if timer:
|
||||
self.iter_timer.reset()
|
||||
self.loss.reset()
|
||||
self.loss_total = 0.0
|
||||
self.lr = None
|
||||
self.mb_miou.reset()
|
||||
self.num_inter = np.zeros(cfg.MODEL.NUM_CLASSES)
|
||||
self.num_union = np.zeros(cfg.MODEL.NUM_CLASSES)
|
||||
self.num_samples = 0
|
||||
|
||||
def iter_tic(self):
|
||||
self.iter_timer.tic()
|
||||
|
||||
def iter_toc(self):
|
||||
self.iter_timer.toc()
|
||||
|
||||
def update_stats(self, inter, union, loss, lr, mb_size):
|
||||
# Current minibatch stats
|
||||
self.mb_miou.add_value((inter / (union + 1e-10)).mean())
|
||||
self.loss.add_value(loss)
|
||||
self.lr = lr
|
||||
# Aggregate stats
|
||||
self.num_inter += inter * mb_size
|
||||
self.num_union += union * mb_size
|
||||
self.loss_total += loss * mb_size
|
||||
self.num_samples += mb_size
|
||||
|
||||
def get_iter_stats(self, cur_epoch, cur_iter):
|
||||
cur_iter_total = cur_epoch * self.epoch_iters + cur_iter + 1
|
||||
eta_sec = self.iter_timer.average_time * (self.max_iter - cur_iter_total)
|
||||
mem_usage = gpu_mem_usage()
|
||||
stats = {
|
||||
"epoch": "{}/{}".format(cur_epoch + 1, cfg.OPTIM.MAX_EPOCH),
|
||||
"iter": "{}/{}".format(cur_iter + 1, self.epoch_iters),
|
||||
"time_avg": self.iter_timer.average_time,
|
||||
"time_diff": self.iter_timer.diff,
|
||||
"eta": time_string(eta_sec),
|
||||
"miou": self.mb_miou.get_win_median(),
|
||||
"loss": self.loss.get_win_median(),
|
||||
"lr": self.lr,
|
||||
"mem": int(np.ceil(mem_usage)),
|
||||
}
|
||||
return stats
|
||||
|
||||
def log_iter_stats(self, cur_epoch, cur_iter):
|
||||
if (cur_iter + 1) % cfg.LOG_PERIOD != 0:
|
||||
return
|
||||
stats = self.get_iter_stats(cur_epoch, cur_iter)
|
||||
logger.info(logging.dump_log_data(stats, "train_iter"))
|
||||
|
||||
def get_epoch_stats(self, cur_epoch):
|
||||
cur_iter_total = (cur_epoch + 1) * self.epoch_iters
|
||||
eta_sec = self.iter_timer.average_time * (self.max_iter - cur_iter_total)
|
||||
mem_usage = gpu_mem_usage()
|
||||
miou = (self.num_inter / (self.num_union + 1e-10)).mean()
|
||||
avg_loss = self.loss_total / self.num_samples
|
||||
stats = {
|
||||
"epoch": "{}/{}".format(cur_epoch + 1, cfg.OPTIM.MAX_EPOCH),
|
||||
"time_avg": self.iter_timer.average_time,
|
||||
"eta": time_string(eta_sec),
|
||||
"miou": miou,
|
||||
"loss": avg_loss,
|
||||
"lr": self.lr,
|
||||
"mem": int(np.ceil(mem_usage)),
|
||||
}
|
||||
return stats
|
||||
|
||||
def log_epoch_stats(self, cur_epoch):
|
||||
stats = self.get_epoch_stats(cur_epoch)
|
||||
logger.info(logging.dump_log_data(stats, "train_epoch"))
|
||||
|
||||
|
||||
class TestMeterIoU(object):
|
||||
"""Measures testing stats."""
|
||||
|
||||
def __init__(self, max_iter):
|
||||
self.max_iter = max_iter
|
||||
self.iter_timer = Timer()
|
||||
|
||||
self.mb_miou = ScalarMeter(cfg.LOG_PERIOD)
|
||||
|
||||
self.max_miou = 0.0
|
||||
|
||||
self.num_inter = np.zeros(cfg.MODEL.NUM_CLASSES)
|
||||
self.num_union = np.zeros(cfg.MODEL.NUM_CLASSES)
|
||||
self.num_samples = 0
|
||||
|
||||
def reset(self, min_errs=False):
|
||||
if min_errs:
|
||||
self.max_miou = 0.0
|
||||
self.iter_timer.reset()
|
||||
self.mb_miou.reset()
|
||||
self.num_inter = np.zeros(cfg.MODEL.NUM_CLASSES)
|
||||
self.num_union = np.zeros(cfg.MODEL.NUM_CLASSES)
|
||||
self.num_samples = 0
|
||||
|
||||
def iter_tic(self):
|
||||
self.iter_timer.tic()
|
||||
|
||||
def iter_toc(self):
|
||||
self.iter_timer.toc()
|
||||
|
||||
def update_stats(self, inter, union, mb_size):
|
||||
self.mb_miou.add_value((inter / (union + 1e-10)).mean())
|
||||
self.num_inter += inter * mb_size
|
||||
self.num_union += union * mb_size
|
||||
self.num_samples += mb_size
|
||||
|
||||
def get_iter_stats(self, cur_epoch, cur_iter):
|
||||
mem_usage = gpu_mem_usage()
|
||||
iter_stats = {
|
||||
"epoch": "{}/{}".format(cur_epoch + 1, cfg.OPTIM.MAX_EPOCH),
|
||||
"iter": "{}/{}".format(cur_iter + 1, self.max_iter),
|
||||
"time_avg": self.iter_timer.average_time,
|
||||
"time_diff": self.iter_timer.diff,
|
||||
"miou": self.mb_miou.get_win_median(),
|
||||
"mem": int(np.ceil(mem_usage)),
|
||||
}
|
||||
return iter_stats
|
||||
|
||||
def log_iter_stats(self, cur_epoch, cur_iter):
|
||||
if (cur_iter + 1) % cfg.LOG_PERIOD != 0:
|
||||
return
|
||||
stats = self.get_iter_stats(cur_epoch, cur_iter)
|
||||
logger.info(logging.dump_log_data(stats, "test_iter"))
|
||||
|
||||
def get_epoch_stats(self, cur_epoch):
|
||||
miou = (self.num_inter / (self.num_union + 1e-10)).mean()
|
||||
self.max_miou = max(self.max_miou, miou)
|
||||
mem_usage = gpu_mem_usage()
|
||||
stats = {
|
||||
"epoch": "{}/{}".format(cur_epoch + 1, cfg.OPTIM.MAX_EPOCH),
|
||||
"time_avg": self.iter_timer.average_time,
|
||||
"miou": miou,
|
||||
"max_miou": self.max_miou,
|
||||
"mem": int(np.ceil(mem_usage)),
|
||||
}
|
||||
return stats
|
||||
|
||||
def log_epoch_stats(self, cur_epoch):
|
||||
stats = self.get_epoch_stats(cur_epoch)
|
||||
logger.info(logging.dump_log_data(stats, "test_epoch"))
|
||||
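A sketch of how topk_errors feeds a TrainMeter (shapes and values are illustrative; this is not part of the commit):

import torch
from pycls.core.meters import TrainMeter, topk_errors

preds = torch.randn(128, 10)            # mini-batch of logits
labels = torch.randint(0, 10, (128,))
top1_err, top5_err = topk_errors(preds, labels, [1, 5])   # percentages in [0, 100]
meter = TrainMeter(epoch_iters=391)
meter.update_stats(top1_err.item(), top5_err.item(), loss=2.3, lr=0.1, mb_size=128)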
129
pycls/core/net.py
Normal file
@@ -0,0 +1,129 @@
#!/usr/bin/env python3
|
||||
|
||||
# Copyright (c) Facebook, Inc. and its affiliates.
|
||||
#
|
||||
# This source code is licensed under the MIT license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
"""Functions for manipulating networks."""
|
||||
|
||||
import itertools
|
||||
import math
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from pycls.core.config import cfg
|
||||
|
||||
|
||||
def init_weights(m):
|
||||
"""Performs ResNet-style weight initialization."""
|
||||
if isinstance(m, nn.Conv2d):
|
||||
# Note that there is no bias due to BN
|
||||
fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
|
||||
m.weight.data.normal_(mean=0.0, std=math.sqrt(2.0 / fan_out))
|
||||
elif isinstance(m, nn.BatchNorm2d):
|
||||
zero_init_gamma = cfg.BN.ZERO_INIT_FINAL_GAMMA
|
||||
zero_init_gamma = hasattr(m, "final_bn") and m.final_bn and zero_init_gamma
|
||||
m.weight.data.fill_(0.0 if zero_init_gamma else 1.0)
|
||||
m.bias.data.zero_()
|
||||
elif isinstance(m, nn.Linear):
|
||||
m.weight.data.normal_(mean=0.0, std=0.01)
|
||||
m.bias.data.zero_()
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def compute_precise_bn_stats(model, loader):
|
||||
"""Computes precise BN stats on training data."""
|
||||
# Compute the number of minibatches to use
|
||||
num_iter = min(cfg.BN.NUM_SAMPLES_PRECISE // loader.batch_size, len(loader))
|
||||
# Retrieve the BN layers
|
||||
bns = [m for m in model.modules() if isinstance(m, torch.nn.BatchNorm2d)]
|
||||
# Initialize stats storage
|
||||
mus = [torch.zeros_like(bn.running_mean) for bn in bns]
|
||||
sqs = [torch.zeros_like(bn.running_var) for bn in bns]
|
||||
# Remember momentum values
|
||||
moms = [bn.momentum for bn in bns]
|
||||
# Disable momentum
|
||||
for bn in bns:
|
||||
bn.momentum = 1.0
|
||||
# Accumulate the stats across the data samples
|
||||
for inputs, _labels in itertools.islice(loader, num_iter):
|
||||
model(inputs.cuda())
|
||||
# Accumulate the stats for each BN layer
|
||||
for i, bn in enumerate(bns):
|
||||
m, v = bn.running_mean, bn.running_var
|
||||
sqs[i] += (v + m * m) / num_iter
|
||||
mus[i] += m / num_iter
|
||||
# Set the stats and restore momentum values
|
||||
for i, bn in enumerate(bns):
|
||||
bn.running_var = sqs[i] - mus[i] * mus[i]
|
||||
bn.running_mean = mus[i]
|
||||
bn.momentum = moms[i]
|
||||
|
||||
|
||||
def reset_bn_stats(model):
|
||||
"""Resets running BN stats."""
|
||||
for m in model.modules():
|
||||
if isinstance(m, torch.nn.BatchNorm2d):
|
||||
m.reset_running_stats()
|
||||
|
||||
|
||||
def complexity_conv2d(cx, w_in, w_out, k, stride, padding, groups=1, bias=False):
|
||||
"""Accumulates complexity of Conv2D into cx = (h, w, flops, params, acts)."""
|
||||
h, w, flops, params, acts = cx["h"], cx["w"], cx["flops"], cx["params"], cx["acts"]
|
||||
h = (h + 2 * padding - k) // stride + 1
|
||||
w = (w + 2 * padding - k) // stride + 1
|
||||
flops += k * k * w_in * w_out * h * w // groups
|
||||
params += k * k * w_in * w_out // groups
|
||||
flops += w_out if bias else 0
|
||||
params += w_out if bias else 0
|
||||
acts += w_out * h * w
|
||||
return {"h": h, "w": w, "flops": flops, "params": params, "acts": acts}
|
||||
|
||||
|
||||
def complexity_batchnorm2d(cx, w_in):
|
||||
"""Accumulates complexity of BatchNorm2D into cx = (h, w, flops, params, acts)."""
|
||||
h, w, flops, params, acts = cx["h"], cx["w"], cx["flops"], cx["params"], cx["acts"]
|
||||
params += 2 * w_in
|
||||
return {"h": h, "w": w, "flops": flops, "params": params, "acts": acts}
|
||||
|
||||
|
||||
def complexity_maxpool2d(cx, k, stride, padding):
|
||||
"""Accumulates complexity of MaxPool2d into cx = (h, w, flops, params, acts)."""
|
||||
h, w, flops, params, acts = cx["h"], cx["w"], cx["flops"], cx["params"], cx["acts"]
|
||||
h = (h + 2 * padding - k) // stride + 1
|
||||
w = (w + 2 * padding - k) // stride + 1
|
||||
return {"h": h, "w": w, "flops": flops, "params": params, "acts": acts}
|
||||
|
||||
|
||||
def complexity(model):
|
||||
"""Compute model complexity (model can be model instance or model class)."""
|
||||
size = cfg.TRAIN.IM_SIZE
|
||||
cx = {"h": size, "w": size, "flops": 0, "params": 0, "acts": 0}
|
||||
cx = model.complexity(cx)
|
||||
return {"flops": cx["flops"], "params": cx["params"], "acts": cx["acts"]}
|
||||
|
||||
|
||||
def drop_connect(x, drop_ratio):
|
||||
"""Drop connect (adapted from DARTS)."""
|
||||
keep_ratio = 1.0 - drop_ratio
|
||||
mask = torch.empty([x.shape[0], 1, 1, 1], dtype=x.dtype, device=x.device)
|
||||
mask.bernoulli_(keep_ratio)
|
||||
x.div_(keep_ratio)
|
||||
x.mul_(mask)
|
||||
return x
|
||||
|
||||
|
||||
def get_flat_weights(model):
|
||||
"""Gets all model weights as a single flat vector."""
|
||||
return torch.cat([p.data.view(-1, 1) for p in model.parameters()], 0)
|
||||
|
||||
|
||||
def set_flat_weights(model, flat_weights):
|
||||
"""Sets all model weights from a single flat vector."""
|
||||
k = 0
|
||||
for p in model.parameters():
|
||||
n = p.data.numel()
|
||||
p.data.copy_(flat_weights[k : (k + n)].view_as(p.data))
|
||||
k += n
|
||||
assert k == flat_weights.numel()
|
||||
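A worked example of the complexity accounting above for a single 3x3 convolution followed by BN (the input size and channel counts are illustrative):

from pycls.core.net import complexity_batchnorm2d, complexity_conv2d

cx = {"h": 32, "w": 32, "flops": 0, "params": 0, "acts": 0}
cx = complexity_conv2d(cx, w_in=3, w_out=32, k=3, stride=1, padding=1)
# h = w = (32 + 2*1 - 3) // 1 + 1 = 32
# flops = 3*3*3*32*32*32 = 884,736; params = 3*3*3*32 = 864; acts = 32*32*32 = 32,768
cx = complexity_batchnorm2d(cx, w_in=32)  # adds 2 * 32 = 64 params
assert cx["params"] == 864 + 64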
95
pycls/core/optimizer.py
Normal file
@@ -0,0 +1,95 @@
#!/usr/bin/env python3
|
||||
|
||||
# Copyright (c) Facebook, Inc. and its affiliates.
|
||||
#
|
||||
# This source code is licensed under the MIT license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
"""Optimizer."""
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from pycls.core.config import cfg
|
||||
|
||||
|
||||
def construct_optimizer(model):
|
||||
"""Constructs the optimizer.
|
||||
|
||||
Note that the momentum update in PyTorch differs from the one in Caffe2.
|
||||
In particular,
|
||||
|
||||
Caffe2:
|
||||
V := mu * V + lr * g
|
||||
p := p - V
|
||||
|
||||
PyTorch:
|
||||
V := mu * V + g
|
||||
p := p - lr * V
|
||||
|
||||
where V is the velocity, mu is the momentum factor, lr is the learning rate,
|
||||
g is the gradient and p are the parameters.
|
||||
|
||||
Since V is defined independently of the learning rate in PyTorch,
|
||||
when the learning rate is changed there is no need to perform the
|
||||
momentum correction by scaling V (unlike in the Caffe2 case).
|
||||
"""
|
||||
if cfg.BN.USE_CUSTOM_WEIGHT_DECAY:
|
||||
# Apply different weight decay to Batchnorm and non-batchnorm parameters.
|
||||
p_bn = [p for n, p in model.named_parameters() if "bn" in n]
|
||||
p_non_bn = [p for n, p in model.named_parameters() if "bn" not in n]
|
||||
optim_params = [
|
||||
{"params": p_bn, "weight_decay": cfg.BN.CUSTOM_WEIGHT_DECAY},
|
||||
{"params": p_non_bn, "weight_decay": cfg.OPTIM.WEIGHT_DECAY},
|
||||
]
|
||||
else:
|
||||
optim_params = model.parameters()
|
||||
return torch.optim.SGD(
|
||||
optim_params,
|
||||
lr=cfg.OPTIM.BASE_LR,
|
||||
momentum=cfg.OPTIM.MOMENTUM,
|
||||
weight_decay=cfg.OPTIM.WEIGHT_DECAY,
|
||||
dampening=cfg.OPTIM.DAMPENING,
|
||||
nesterov=cfg.OPTIM.NESTEROV,
|
||||
)
|
||||
|
||||
|
||||
def lr_fun_steps(cur_epoch):
|
||||
"""Steps schedule (cfg.OPTIM.LR_POLICY = 'steps')."""
|
||||
ind = [i for i, s in enumerate(cfg.OPTIM.STEPS) if cur_epoch >= s][-1]
|
||||
return cfg.OPTIM.BASE_LR * (cfg.OPTIM.LR_MULT ** ind)
|
||||
|
||||
|
||||
def lr_fun_exp(cur_epoch):
|
||||
"""Exponential schedule (cfg.OPTIM.LR_POLICY = 'exp')."""
|
||||
return cfg.OPTIM.BASE_LR * (cfg.OPTIM.GAMMA ** cur_epoch)
|
||||
|
||||
|
||||
def lr_fun_cos(cur_epoch):
|
||||
"""Cosine schedule (cfg.OPTIM.LR_POLICY = 'cos')."""
|
||||
base_lr, max_epoch = cfg.OPTIM.BASE_LR, cfg.OPTIM.MAX_EPOCH
|
||||
return 0.5 * base_lr * (1.0 + np.cos(np.pi * cur_epoch / max_epoch))
|
||||
|
||||
|
||||
def get_lr_fun():
|
||||
"""Retrieves the specified lr policy function"""
|
||||
lr_fun = "lr_fun_" + cfg.OPTIM.LR_POLICY
|
||||
if lr_fun not in globals():
|
||||
raise NotImplementedError("Unknown LR policy: " + cfg.OPTIM.LR_POLICY)
|
||||
return globals()[lr_fun]
|
||||
|
||||
|
||||
def get_epoch_lr(cur_epoch):
|
||||
"""Retrieves the lr for the given epoch according to the policy."""
|
||||
lr = get_lr_fun()(cur_epoch)
|
||||
# Linear warmup
|
||||
if cur_epoch < cfg.OPTIM.WARMUP_EPOCHS:
|
||||
alpha = cur_epoch / cfg.OPTIM.WARMUP_EPOCHS
|
||||
warmup_factor = cfg.OPTIM.WARMUP_FACTOR * (1.0 - alpha) + alpha
|
||||
lr *= warmup_factor
|
||||
return lr
|
||||
|
||||
|
||||
def set_lr(optimizer, new_lr):
|
||||
"""Sets the optimizer lr to the specified value."""
|
||||
for param_group in optimizer.param_groups:
|
||||
param_group["lr"] = new_lr
|
||||
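A worked example of the schedule helpers above, combining the cosine policy with a 5-epoch linear warmup (the warmup settings are illustrative; BASE_LR=0.1 and MAX_EPOCH=200 are the defaults):

from pycls.core.config import cfg
import pycls.core.optimizer as optim

cfg.OPTIM.LR_POLICY = "cos"
cfg.OPTIM.WARMUP_EPOCHS = 5
cfg.OPTIM.WARMUP_FACTOR = 0.1

# Epoch 0: cosine gives 0.5 * 0.1 * (1 + cos(0)) = 0.1; warmup scales it by 0.1 -> 0.01
print(optim.get_epoch_lr(0))    # ~0.01
# Epoch 100: 0.5 * 0.1 * (1 + cos(pi * 100 / 200)) = 0.05 (warmup no longer applies)
print(optim.get_epoch_lr(100))  # ~0.05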
132
pycls/core/plotting.py
Normal file
@@ -0,0 +1,132 @@
#!/usr/bin/env python3
|
||||
|
||||
# Copyright (c) Facebook, Inc. and its affiliates.
|
||||
#
|
||||
# This source code is licensed under the MIT license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
"""Plotting functions."""
|
||||
|
||||
import colorlover as cl
|
||||
import matplotlib.pyplot as plt
|
||||
import plotly.graph_objs as go
|
||||
import plotly.offline as offline
|
||||
import pycls.core.logging as logging
|
||||
|
||||
|
||||
def get_plot_colors(max_colors, color_format="pyplot"):
|
||||
"""Generate colors for plotting."""
|
||||
colors = cl.scales["11"]["qual"]["Paired"]
|
||||
if max_colors > len(colors):
|
||||
colors = cl.to_rgb(cl.interp(colors, max_colors))
|
||||
if color_format == "pyplot":
|
||||
return [[j / 255.0 for j in c] for c in cl.to_numeric(colors)]
|
||||
return colors
|
||||
|
||||
|
||||
def prepare_plot_data(log_files, names, metric="top1_err"):
|
||||
"""Load logs and extract data for plotting error curves."""
|
||||
plot_data = []
|
||||
for file, name in zip(log_files, names):
|
||||
d, data = {}, logging.sort_log_data(logging.load_log_data(file))
|
||||
for phase in ["train", "test"]:
|
||||
x = data[phase + "_epoch"]["epoch_ind"]
|
||||
y = data[phase + "_epoch"][metric]
|
||||
d["x_" + phase], d["y_" + phase] = x, y
|
||||
d[phase + "_label"] = "[{:5.2f}] ".format(min(y) if y else 0) + name
|
||||
plot_data.append(d)
|
||||
assert len(plot_data) > 0, "No data to plot"
|
||||
return plot_data
|
||||
|
||||
|
||||
def plot_error_curves_plotly(log_files, names, filename, metric="top1_err"):
|
||||
"""Plot error curves using plotly and save to file."""
|
||||
plot_data = prepare_plot_data(log_files, names, metric)
|
||||
colors = get_plot_colors(len(plot_data), "plotly")
|
||||
# Prepare data for plots (3 sets, train duplicated w and w/o legend)
|
||||
data = []
|
||||
for i, d in enumerate(plot_data):
|
||||
s = str(i)
|
||||
line_train = {"color": colors[i], "dash": "dashdot", "width": 1.5}
|
||||
line_test = {"color": colors[i], "dash": "solid", "width": 1.5}
|
||||
data.append(
|
||||
go.Scatter(
|
||||
x=d["x_train"],
|
||||
y=d["y_train"],
|
||||
mode="lines",
|
||||
name=d["train_label"],
|
||||
line=line_train,
|
||||
legendgroup=s,
|
||||
visible=True,
|
||||
showlegend=False,
|
||||
)
|
||||
)
|
||||
data.append(
|
||||
go.Scatter(
|
||||
x=d["x_test"],
|
||||
y=d["y_test"],
|
||||
mode="lines",
|
||||
name=d["test_label"],
|
||||
line=line_test,
|
||||
legendgroup=s,
|
||||
visible=True,
|
||||
showlegend=True,
|
||||
)
|
||||
)
|
||||
data.append(
|
||||
go.Scatter(
|
||||
x=d["x_train"],
|
||||
y=d["y_train"],
|
||||
mode="lines",
|
||||
name=d["train_label"],
|
||||
line=line_train,
|
||||
legendgroup=s,
|
||||
visible=False,
|
||||
showlegend=True,
|
||||
)
|
||||
)
|
||||
# Prepare layout w ability to toggle 'all', 'train', 'test'
|
||||
titlefont = {"size": 18, "color": "#7f7f7f"}
|
||||
vis = [[True, True, False], [False, False, True], [False, True, False]]
|
||||
buttons = zip(["all", "train", "test"], [[{"visible": v}] for v in vis])
|
||||
buttons = [{"label": b, "args": v, "method": "update"} for b, v in buttons]
|
||||
layout = go.Layout(
|
||||
title=metric + " vs. epoch<br>[dash=train, solid=test]",
|
||||
xaxis={"title": "epoch", "titlefont": titlefont},
|
||||
yaxis={"title": metric, "titlefont": titlefont},
|
||||
showlegend=True,
|
||||
hoverlabel={"namelength": -1},
|
||||
updatemenus=[
|
||||
{
|
||||
"buttons": buttons,
|
||||
"direction": "down",
|
||||
"showactive": True,
|
||||
"x": 1.02,
|
||||
"xanchor": "left",
|
||||
"y": 1.08,
|
||||
"yanchor": "top",
|
||||
}
|
||||
],
|
||||
)
|
||||
# Create plotly plot
|
||||
offline.plot({"data": data, "layout": layout}, filename=filename)
|
||||
|
||||
|
||||
def plot_error_curves_pyplot(log_files, names, filename=None, metric="top1_err"):
|
||||
"""Plot error curves using matplotlib.pyplot and save to file."""
|
||||
plot_data = prepare_plot_data(log_files, names, metric)
|
||||
colors = get_plot_colors(len(names))
|
||||
for ind, d in enumerate(plot_data):
|
||||
c, lbl = colors[ind], d["test_label"]
|
||||
plt.plot(d["x_train"], d["y_train"], "--", c=c, alpha=0.8)
|
||||
plt.plot(d["x_test"], d["y_test"], "-", c=c, alpha=0.8, label=lbl)
|
||||
plt.title(metric + " vs. epoch\n[dash=train, solid=test]", fontsize=14)
|
||||
plt.xlabel("epoch", fontsize=14)
|
||||
plt.ylabel(metric, fontsize=14)
|
||||
plt.grid(alpha=0.4)
|
||||
plt.legend()
|
||||
if filename:
|
||||
plt.savefig(filename)
|
||||
plt.clf()
|
||||
else:
|
||||
plt.show()
|
||||
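A sketch of how the plotting helpers are combined with get_log_files (the experiments directory and output filename are illustrative; not part of this commit):

import pycls.core.logging as logging
import pycls.core.plotting as plotting

log_files, names = logging.get_log_files("/tmp/experiments")  # subdirs containing stdout.log
plotting.plot_error_curves_pyplot(log_files, names, filename="/tmp/top1_err.png")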
39
pycls/core/timer.py
Normal file
@@ -0,0 +1,39 @@
#!/usr/bin/env python3
|
||||
|
||||
# Copyright (c) Facebook, Inc. and its affiliates.
|
||||
#
|
||||
# This source code is licensed under the MIT license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
"""Timer."""
|
||||
|
||||
import time
|
||||
|
||||
|
||||
class Timer(object):
|
||||
"""A simple timer (adapted from Detectron)."""
|
||||
|
||||
def __init__(self):
|
||||
self.total_time = None
|
||||
self.calls = None
|
||||
self.start_time = None
|
||||
self.diff = None
|
||||
self.average_time = None
|
||||
self.reset()
|
||||
|
||||
def tic(self):
|
||||
# using time.time as time.clock does not normalize for multithreading
|
||||
self.start_time = time.time()
|
||||
|
||||
def toc(self):
|
||||
self.diff = time.time() - self.start_time
|
||||
self.total_time += self.diff
|
||||
self.calls += 1
|
||||
self.average_time = self.total_time / self.calls
|
||||
|
||||
def reset(self):
|
||||
self.total_time = 0.0
|
||||
self.calls = 0
|
||||
self.start_time = 0.0
|
||||
self.diff = 0.0
|
||||
self.average_time = 0.0
|
||||
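A short sketch of how the Timer above is typically driven: tic() marks the start of the measured region, toc() accumulates the totals. The sleep is a stand-in for a real training iteration and the printed timing is illustrative only.

import time

from pycls.core.timer import Timer

timer = Timer()
for _ in range(3):
    timer.tic()        # mark the start of the measured region
    time.sleep(0.01)   # stand-in for one training iteration
    timer.toc()        # accumulate total_time, calls, and average_time
print("average iter time: {:.4f}s over {} calls".format(timer.average_time, timer.calls))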
419
pycls/core/trainer.py
Normal file
@@ -0,0 +1,419 @@
|
||||
#!/usr/bin/env python3
|
||||
|
||||
# Copyright (c) Facebook, Inc. and its affiliates.
|
||||
#
|
||||
# This source code is licensed under the MIT license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
"""Tools for training and testing a model."""
|
||||
|
||||
import os
|
||||
from thop import profile
|
||||
|
||||
import numpy as np
|
||||
import pycls.core.benchmark as benchmark
|
||||
import pycls.core.builders as builders
|
||||
import pycls.core.checkpoint as checkpoint
|
||||
import pycls.core.config as config
|
||||
import pycls.core.distributed as dist
|
||||
import pycls.core.logging as logging
|
||||
import pycls.core.meters as meters
|
||||
import pycls.core.net as net
|
||||
import pycls.core.optimizer as optim
|
||||
import pycls.datasets.loader as loader
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from pycls.core.config import cfg
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
|
||||
def setup_env():
|
||||
"""Sets up environment for training or testing."""
|
||||
if dist.is_master_proc():
|
||||
# Ensure that the output dir exists
|
||||
os.makedirs(cfg.OUT_DIR, exist_ok=True)
|
||||
# Save the config
|
||||
config.dump_cfg()
|
||||
# Setup logging
|
||||
logging.setup_logging()
|
||||
# Log the config as both human readable and as a json
|
||||
logger.info("Config:\n{}".format(cfg))
|
||||
logger.info(logging.dump_log_data(cfg, "cfg"))
|
||||
# Fix the RNG seeds (see RNG comment in core/config.py for discussion)
|
||||
np.random.seed(cfg.RNG_SEED)
|
||||
torch.manual_seed(cfg.RNG_SEED)
|
||||
# Configure the CUDNN backend
|
||||
torch.backends.cudnn.benchmark = cfg.CUDNN.BENCHMARK
|
||||
|
||||
|
||||
def setup_model():
|
||||
"""Sets up a model for training or testing and log the results."""
|
||||
# Build the model
|
||||
model = builders.build_model()
|
||||
logger.info("Model:\n{}".format(model))
|
||||
# Log model complexity
|
||||
# logger.info(logging.dump_log_data(net.complexity(model), "complexity"))
|
||||
if cfg.TASK == "seg" and cfg.TRAIN.DATASET == "cityscapes":
|
||||
h, w = 1025, 2049
|
||||
else:
|
||||
h, w = cfg.TRAIN.IM_SIZE, cfg.TRAIN.IM_SIZE
|
||||
if cfg.TASK == "jig":
|
||||
x = torch.randn(1, cfg.JIGSAW_GRID ** 2, cfg.MODEL.INPUT_CHANNELS, h, w)
|
||||
else:
|
||||
x = torch.randn(1, cfg.MODEL.INPUT_CHANNELS, h, w)
|
||||
macs, params = profile(model, inputs=(x, ), verbose=False)
|
||||
logger.info("Params: {:,}".format(params))
|
||||
logger.info("Flops: {:,}".format(macs))
|
||||
# Transfer the model to the current GPU device
|
||||
err_str = "Cannot use more GPU devices than available"
|
||||
assert cfg.NUM_GPUS <= torch.cuda.device_count(), err_str
|
||||
cur_device = torch.cuda.current_device()
|
||||
model = model.cuda(device=cur_device)
|
||||
# Use multi-process data parallel model in the multi-gpu setting
|
||||
if cfg.NUM_GPUS > 1:
|
||||
# Make model replica operate on the current device
|
||||
model = torch.nn.parallel.DistributedDataParallel(
|
||||
module=model, device_ids=[cur_device], output_device=cur_device
|
||||
)
|
||||
# Set complexity function to be module's complexity function
|
||||
# model.complexity = model.module.complexity
|
||||
return model
|
||||
|
||||
|
||||
def train_epoch(train_loader, model, loss_fun, optimizer, train_meter, cur_epoch):
|
||||
"""Performs one epoch of training."""
|
||||
# Update drop path prob for NAS
|
||||
if cfg.MODEL.TYPE == "nas":
|
||||
m = model.module if cfg.NUM_GPUS > 1 else model
|
||||
m.set_drop_path_prob(cfg.NAS.DROP_PROB * cur_epoch / cfg.OPTIM.MAX_EPOCH)
|
||||
# Shuffle the data
|
||||
loader.shuffle(train_loader, cur_epoch)
|
||||
# Update the learning rate per epoch
|
||||
if not cfg.OPTIM.ITER_LR:
|
||||
lr = optim.get_epoch_lr(cur_epoch)
|
||||
optim.set_lr(optimizer, lr)
|
||||
# Enable training mode
|
||||
model.train()
|
||||
train_meter.iter_tic()
|
||||
for cur_iter, (inputs, labels) in enumerate(train_loader):
|
||||
# Update the learning rate per iter
|
||||
if cfg.OPTIM.ITER_LR:
|
||||
lr = optim.get_epoch_lr(cur_epoch + cur_iter / len(train_loader))
|
||||
optim.set_lr(optimizer, lr)
|
||||
# Transfer the data to the current GPU device
|
||||
inputs, labels = inputs.cuda(), labels.cuda(non_blocking=True)
|
||||
# Perform the forward pass
|
||||
preds = model(inputs)
|
||||
# Compute the loss
|
||||
if isinstance(preds, tuple):
|
||||
loss = loss_fun(preds[0], labels) + cfg.NAS.AUX_WEIGHT * loss_fun(preds[1], labels)
|
||||
preds = preds[0]
|
||||
else:
|
||||
loss = loss_fun(preds, labels)
|
||||
# Perform the backward pass
|
||||
optimizer.zero_grad()
|
||||
loss.backward()
|
||||
# Update the parameters
|
||||
optimizer.step()
|
||||
# Compute the errors
|
||||
if cfg.TASK == "col":
|
||||
preds = preds.permute(0, 2, 3, 1)
|
||||
preds = preds.reshape(-1, preds.size(3))
|
||||
labels = labels.reshape(-1)
|
||||
mb_size = inputs.size(0) * inputs.size(2) * inputs.size(3) * cfg.NUM_GPUS
|
||||
else:
|
||||
mb_size = inputs.size(0) * cfg.NUM_GPUS
|
||||
if cfg.TASK == "seg":
|
||||
# top1_err is in fact inter; top5_err is in fact union
|
||||
top1_err, top5_err = meters.inter_union(preds, labels, cfg.MODEL.NUM_CLASSES)
|
||||
else:
|
||||
ks = [1, min(5, cfg.MODEL.NUM_CLASSES)] # rot only has 4 classes
|
||||
top1_err, top5_err = meters.topk_errors(preds, labels, ks)
|
||||
# Combine the stats across the GPUs (no reduction if 1 GPU used)
|
||||
loss, top1_err, top5_err = dist.scaled_all_reduce([loss, top1_err, top5_err])
|
||||
# Copy the stats from GPU to CPU (sync point)
|
||||
loss = loss.item()
|
||||
if cfg.TASK == "seg":
|
||||
top1_err, top5_err = top1_err.cpu().numpy(), top5_err.cpu().numpy()
|
||||
else:
|
||||
top1_err, top5_err = top1_err.item(), top5_err.item()
|
||||
train_meter.iter_toc()
|
||||
# Update and log stats
|
||||
train_meter.update_stats(top1_err, top5_err, loss, lr, mb_size)
|
||||
train_meter.log_iter_stats(cur_epoch, cur_iter)
|
||||
train_meter.iter_tic()
|
||||
# Log epoch stats
|
||||
train_meter.log_epoch_stats(cur_epoch)
|
||||
train_meter.reset()
|
||||
|
||||
|
||||
def search_epoch(train_loader, model, loss_fun, optimizer, train_meter, cur_epoch):
|
||||
"""Performs one epoch of differentiable architecture search."""
|
||||
m = model.module if cfg.NUM_GPUS > 1 else model
|
||||
# Shuffle the data
|
||||
loader.shuffle(train_loader[0], cur_epoch)
|
||||
loader.shuffle(train_loader[1], cur_epoch)
|
||||
# Update the learning rate per epoch
|
||||
if not cfg.OPTIM.ITER_LR:
|
||||
lr = optim.get_epoch_lr(cur_epoch)
|
||||
optim.set_lr(optimizer[0], lr)
|
||||
# Enable training mode
|
||||
model.train()
|
||||
train_meter.iter_tic()
|
||||
trainB_iter = iter(train_loader[1])
|
||||
for cur_iter, (inputs, labels) in enumerate(train_loader[0]):
|
||||
# Update the learning rate per iter
|
||||
if cfg.OPTIM.ITER_LR:
|
||||
lr = optim.get_epoch_lr(cur_epoch + cur_iter / len(train_loader[0]))
|
||||
optim.set_lr(optimizer[0], lr)
|
||||
# Transfer the data to the current GPU device
|
||||
inputs, labels = inputs.cuda(), labels.cuda(non_blocking=True)
|
||||
# Update architecture
|
||||
if cur_epoch + cur_iter / len(train_loader[0]) >= cfg.OPTIM.ARCH_EPOCH:
|
||||
try:
|
||||
inputsB, labelsB = next(trainB_iter)
|
||||
except StopIteration:
|
||||
trainB_iter = iter(train_loader[1])
|
||||
inputsB, labelsB = next(trainB_iter)
|
||||
inputsB, labelsB = inputsB.cuda(), labelsB.cuda(non_blocking=True)
|
||||
optimizer[1].zero_grad()
|
||||
loss = m._loss(inputsB, labelsB)
|
||||
loss.backward()
|
||||
optimizer[1].step()
|
||||
# Perform the forward pass
|
||||
preds = model(inputs)
|
||||
# Compute the loss
|
||||
loss = loss_fun(preds, labels)
|
||||
# Perform the backward pass
|
||||
optimizer[0].zero_grad()
|
||||
loss.backward()
|
||||
torch.nn.utils.clip_grad_norm_(model.parameters(), 5.0)
|
||||
# Update the parameters
|
||||
optimizer[0].step()
|
||||
# Compute the errors
|
||||
if cfg.TASK == "col":
|
||||
preds = preds.permute(0, 2, 3, 1)
|
||||
preds = preds.reshape(-1, preds.size(3))
|
||||
labels = labels.reshape(-1)
|
||||
mb_size = inputs.size(0) * inputs.size(2) * inputs.size(3) * cfg.NUM_GPUS
|
||||
else:
|
||||
mb_size = inputs.size(0) * cfg.NUM_GPUS
|
||||
if cfg.TASK == "seg":
|
||||
# top1_err is in fact inter; top5_err is in fact union
|
||||
top1_err, top5_err = meters.inter_union(preds, labels, cfg.MODEL.NUM_CLASSES)
|
||||
else:
|
||||
ks = [1, min(5, cfg.MODEL.NUM_CLASSES)] # rot only has 4 classes
|
||||
top1_err, top5_err = meters.topk_errors(preds, labels, ks)
|
||||
# Combine the stats across the GPUs (no reduction if 1 GPU used)
|
||||
loss, top1_err, top5_err = dist.scaled_all_reduce([loss, top1_err, top5_err])
|
||||
# Copy the stats from GPU to CPU (sync point)
|
||||
loss = loss.item()
|
||||
if cfg.TASK == "seg":
|
||||
top1_err, top5_err = top1_err.cpu().numpy(), top5_err.cpu().numpy()
|
||||
else:
|
||||
top1_err, top5_err = top1_err.item(), top5_err.item()
|
||||
train_meter.iter_toc()
|
||||
# Update and log stats
|
||||
train_meter.update_stats(top1_err, top5_err, loss, lr, mb_size)
|
||||
train_meter.log_iter_stats(cur_epoch, cur_iter)
|
||||
train_meter.iter_tic()
|
||||
# Log epoch stats
|
||||
train_meter.log_epoch_stats(cur_epoch)
|
||||
train_meter.reset()
|
||||
# Log genotype
|
||||
genotype = m.genotype()
|
||||
logger.info("genotype = %s", genotype)
|
||||
logger.info(F.softmax(m.net_.alphas_normal, dim=-1))
|
||||
logger.info(F.softmax(m.net_.alphas_reduce, dim=-1))
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def test_epoch(test_loader, model, test_meter, cur_epoch):
|
||||
"""Evaluates the model on the test set."""
|
||||
# Enable eval mode
|
||||
model.eval()
|
||||
test_meter.iter_tic()
|
||||
for cur_iter, (inputs, labels) in enumerate(test_loader):
|
||||
# Transfer the data to the current GPU device
|
||||
inputs, labels = inputs.cuda(), labels.cuda(non_blocking=True)
|
||||
# Compute the predictions
|
||||
preds = model(inputs)
|
||||
# Compute the errors
|
||||
if cfg.TASK == "col":
|
||||
preds = preds.permute(0, 2, 3, 1)
|
||||
preds = preds.reshape(-1, preds.size(3))
|
||||
labels = labels.reshape(-1)
|
||||
mb_size = inputs.size(0) * inputs.size(2) * inputs.size(3) * cfg.NUM_GPUS
|
||||
else:
|
||||
mb_size = inputs.size(0) * cfg.NUM_GPUS
|
||||
if cfg.TASK == "seg":
|
||||
# top1_err is in fact inter; top5_err is in fact union
|
||||
top1_err, top5_err = meters.inter_union(preds, labels, cfg.MODEL.NUM_CLASSES)
|
||||
else:
|
||||
ks = [1, min(5, cfg.MODEL.NUM_CLASSES)] # rot only has 4 classes
|
||||
top1_err, top5_err = meters.topk_errors(preds, labels, ks)
|
||||
# Combine the errors across the GPUs (no reduction if 1 GPU used)
|
||||
top1_err, top5_err = dist.scaled_all_reduce([top1_err, top5_err])
|
||||
# Copy the errors from GPU to CPU (sync point)
|
||||
if cfg.TASK == "seg":
|
||||
top1_err, top5_err = top1_err.cpu().numpy(), top5_err.cpu().numpy()
|
||||
else:
|
||||
top1_err, top5_err = top1_err.item(), top5_err.item()
|
||||
test_meter.iter_toc()
|
||||
# Update and log stats
|
||||
test_meter.update_stats(top1_err, top5_err, mb_size)
|
||||
test_meter.log_iter_stats(cur_epoch, cur_iter)
|
||||
test_meter.iter_tic()
|
||||
# Log epoch stats
|
||||
test_meter.log_epoch_stats(cur_epoch)
|
||||
test_meter.reset()
|
||||
|
||||
|
||||
def train_model():
|
||||
"""Trains the model."""
|
||||
# Setup training/testing environment
|
||||
setup_env()
|
||||
# Construct the model, loss_fun, and optimizer
|
||||
model = setup_model()
|
||||
loss_fun = builders.build_loss_fun().cuda()
|
||||
if "search" in cfg.MODEL.TYPE:
|
||||
params_w = [v for k, v in model.named_parameters() if "alphas" not in k]
|
||||
params_a = [v for k, v in model.named_parameters() if "alphas" in k]
|
||||
optimizer_w = torch.optim.SGD(
|
||||
params=params_w,
|
||||
lr=cfg.OPTIM.BASE_LR,
|
||||
momentum=cfg.OPTIM.MOMENTUM,
|
||||
weight_decay=cfg.OPTIM.WEIGHT_DECAY,
|
||||
dampening=cfg.OPTIM.DAMPENING,
|
||||
nesterov=cfg.OPTIM.NESTEROV
|
||||
)
|
||||
if cfg.OPTIM.ARCH_OPTIM == "adam":
|
||||
optimizer_a = torch.optim.Adam(
|
||||
params=params_a,
|
||||
lr=cfg.OPTIM.ARCH_BASE_LR,
|
||||
betas=(0.5, 0.999),
|
||||
weight_decay=cfg.OPTIM.ARCH_WEIGHT_DECAY
|
||||
)
|
||||
elif cfg.OPTIM.ARCH_OPTIM == "sgd":
|
||||
optimizer_a = torch.optim.SGD(
|
||||
params=params_a,
|
||||
lr=cfg.OPTIM.ARCH_BASE_LR,
|
||||
momentum=cfg.OPTIM.MOMENTUM,
|
||||
weight_decay=cfg.OPTIM.ARCH_WEIGHT_DECAY,
|
||||
dampening=cfg.OPTIM.DAMPENING,
|
||||
nesterov=cfg.OPTIM.NESTEROV
|
||||
)
|
||||
optimizer = [optimizer_w, optimizer_a]
|
||||
else:
|
||||
optimizer = optim.construct_optimizer(model)
|
||||
# Load checkpoint or initial weights
|
||||
start_epoch = 0
|
||||
if cfg.TRAIN.AUTO_RESUME and checkpoint.has_checkpoint():
|
||||
last_checkpoint = checkpoint.get_last_checkpoint()
|
||||
checkpoint_epoch = checkpoint.load_checkpoint(last_checkpoint, model, optimizer)
|
||||
logger.info("Loaded checkpoint from: {}".format(last_checkpoint))
|
||||
start_epoch = checkpoint_epoch + 1
|
||||
elif cfg.TRAIN.WEIGHTS:
|
||||
checkpoint.load_checkpoint(cfg.TRAIN.WEIGHTS, model)
|
||||
logger.info("Loaded initial weights from: {}".format(cfg.TRAIN.WEIGHTS))
|
||||
# Create data loaders and meters
|
||||
if cfg.TRAIN.PORTION < 1:
|
||||
if "search" in cfg.MODEL.TYPE:
|
||||
train_loader = [loader._construct_loader(
|
||||
dataset_name=cfg.TRAIN.DATASET,
|
||||
split=cfg.TRAIN.SPLIT,
|
||||
batch_size=int(cfg.TRAIN.BATCH_SIZE / cfg.NUM_GPUS),
|
||||
shuffle=True,
|
||||
drop_last=True,
|
||||
portion=cfg.TRAIN.PORTION,
|
||||
side="l"
|
||||
),
|
||||
loader._construct_loader(
|
||||
dataset_name=cfg.TRAIN.DATASET,
|
||||
split=cfg.TRAIN.SPLIT,
|
||||
batch_size=int(cfg.TRAIN.BATCH_SIZE / cfg.NUM_GPUS),
|
||||
shuffle=True,
|
||||
drop_last=True,
|
||||
portion=cfg.TRAIN.PORTION,
|
||||
side="r"
|
||||
)]
|
||||
else:
|
||||
train_loader = loader._construct_loader(
|
||||
dataset_name=cfg.TRAIN.DATASET,
|
||||
split=cfg.TRAIN.SPLIT,
|
||||
batch_size=int(cfg.TRAIN.BATCH_SIZE / cfg.NUM_GPUS),
|
||||
shuffle=True,
|
||||
drop_last=True,
|
||||
portion=cfg.TRAIN.PORTION,
|
||||
side="l"
|
||||
)
|
||||
test_loader = loader._construct_loader(
|
||||
dataset_name=cfg.TRAIN.DATASET,
|
||||
split=cfg.TRAIN.SPLIT,
|
||||
batch_size=int(cfg.TRAIN.BATCH_SIZE / cfg.NUM_GPUS),
|
||||
shuffle=False,
|
||||
drop_last=False,
|
||||
portion=cfg.TRAIN.PORTION,
|
||||
side="r"
|
||||
)
|
||||
else:
|
||||
train_loader = loader.construct_train_loader()
|
||||
test_loader = loader.construct_test_loader()
|
||||
train_meter_type = meters.TrainMeterIoU if cfg.TASK == "seg" else meters.TrainMeter
|
||||
test_meter_type = meters.TestMeterIoU if cfg.TASK == "seg" else meters.TestMeter
|
||||
l = train_loader[0] if isinstance(train_loader, list) else train_loader
|
||||
train_meter = train_meter_type(len(l))
|
||||
test_meter = test_meter_type(len(test_loader))
|
||||
# Compute model and loader timings
|
||||
if start_epoch == 0 and cfg.PREC_TIME.NUM_ITER > 0:
|
||||
l = train_loader[0] if isinstance(train_loader, list) else train_loader
|
||||
benchmark.compute_time_full(model, loss_fun, l, test_loader)
|
||||
# Perform the training loop
|
||||
logger.info("Start epoch: {}".format(start_epoch + 1))
|
||||
for cur_epoch in range(start_epoch, cfg.OPTIM.MAX_EPOCH):
|
||||
# Train for one epoch
|
||||
f = search_epoch if "search" in cfg.MODEL.TYPE else train_epoch
|
||||
f(train_loader, model, loss_fun, optimizer, train_meter, cur_epoch)
|
||||
# Compute precise BN stats
|
||||
if cfg.BN.USE_PRECISE_STATS:
|
||||
net.compute_precise_bn_stats(model, train_loader)
|
||||
# Save a checkpoint
|
||||
if (cur_epoch + 1) % cfg.TRAIN.CHECKPOINT_PERIOD == 0:
|
||||
checkpoint_file = checkpoint.save_checkpoint(model, optimizer, cur_epoch)
|
||||
logger.info("Wrote checkpoint to: {}".format(checkpoint_file))
|
||||
# Evaluate the model
|
||||
next_epoch = cur_epoch + 1
|
||||
if next_epoch % cfg.TRAIN.EVAL_PERIOD == 0 or next_epoch == cfg.OPTIM.MAX_EPOCH:
|
||||
test_epoch(test_loader, model, test_meter, cur_epoch)
|
||||
|
||||
|
||||
def test_model():
|
||||
"""Evaluates a trained model."""
|
||||
# Setup training/testing environment
|
||||
setup_env()
|
||||
# Construct the model
|
||||
model = setup_model()
|
||||
# Load model weights
|
||||
checkpoint.load_checkpoint(cfg.TEST.WEIGHTS, model)
|
||||
logger.info("Loaded model weights from: {}".format(cfg.TEST.WEIGHTS))
|
||||
# Create data loaders and meters
|
||||
test_loader = loader.construct_test_loader()
|
||||
test_meter = meters.TestMeter(len(test_loader))
|
||||
# Evaluate the model
|
||||
test_epoch(test_loader, model, test_meter, 0)
|
||||
|
||||
|
||||
def time_model():
|
||||
"""Times model and data loader."""
|
||||
# Setup training/testing environment
|
||||
setup_env()
|
||||
# Construct the model and loss_fun
|
||||
model = setup_model()
|
||||
loss_fun = builders.build_loss_fun().cuda()
|
||||
# Create data loaders
|
||||
train_loader = loader.construct_train_loader()
|
||||
test_loader = loader.construct_test_loader()
|
||||
# Compute model and loader timings
|
||||
benchmark.compute_time_full(model, loss_fun, train_loader, test_loader)
|
||||
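search_epoch above alternates one architecture step on a held-out batch (optimizer[1]) with one weight step on the training batch (optimizer[0], with gradient clipping). A self-contained sketch of that alternation, with toy tensors standing in for the real model, loaders, and config values:

import torch

# Toy stand-ins: w plays the role of the network weights, a of the architecture alphas.
w = torch.nn.Parameter(torch.randn(4))
a = torch.nn.Parameter(torch.zeros(4))
opt_w = torch.optim.SGD([w], lr=0.1)     # role of optimizer[0] (weights)
opt_a = torch.optim.Adam([a], lr=0.01)   # role of optimizer[1] (architecture)

def loss_fn(params, batch):
    # A toy objective in which both w and a participate
    return ((params * torch.softmax(a, dim=-1) - batch) ** 2).mean()

for step in range(10):
    batch_train, batch_valid = torch.randn(4), torch.randn(4)
    # 1) architecture step on the held-out ("B") batch
    opt_a.zero_grad()
    loss_fn(w.detach(), batch_valid).backward()
    opt_a.step()
    # 2) weight step on the training ("A") batch, with gradient clipping as above
    opt_w.zero_grad()
    loss_fn(w, batch_train).backward()
    torch.nn.utils.clip_grad_norm_([w], 5.0)
    opt_w.step()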
108
pycls/models/common.py
Normal file
@@ -0,0 +1,108 @@
#!/usr/bin/env python3

# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import torch
import torch.nn as nn

from pycls.core.config import cfg


def Preprocess(x):
    if cfg.TASK == 'jig':
        assert len(x.shape) == 5, 'Wrong tensor dimension for jigsaw'
        assert x.shape[1] == cfg.JIGSAW_GRID ** 2, 'Wrong grid for jigsaw'
        x = x.view([x.shape[0] * x.shape[1], x.shape[2], x.shape[3], x.shape[4]])
    return x


class Classifier(nn.Module):
    def __init__(self, channels, num_classes):
        super(Classifier, self).__init__()
        if cfg.TASK == 'jig':
            self.jig_sq = cfg.JIGSAW_GRID ** 2
            self.pooling = nn.AdaptiveAvgPool2d(1)
            self.classifier = nn.Linear(channels * self.jig_sq, num_classes)
        elif cfg.TASK == 'col':
            self.classifier = nn.Conv2d(channels, num_classes, kernel_size=1, stride=1)
        elif cfg.TASK == 'seg':
            self.classifier = ASPP(channels, cfg.MODEL.ASPP_CHANNELS, num_classes, cfg.MODEL.ASPP_RATES)
        else:
            self.pooling = nn.AdaptiveAvgPool2d(1)
            self.classifier = nn.Linear(channels, num_classes)

    def forward(self, x, shape):
        if cfg.TASK == 'jig':
            x = self.pooling(x)
            x = x.view([x.shape[0] // self.jig_sq, x.shape[1] * self.jig_sq, x.shape[2], x.shape[3]])
            x = self.classifier(x.view(x.size(0), -1))
        elif cfg.TASK in ['col', 'seg']:
            x = self.classifier(x)
            x = nn.Upsample(shape, mode='bilinear', align_corners=True)(x)
        else:
            x = self.pooling(x)
            x = self.classifier(x.view(x.size(0), -1))
        return x


class ASPP(nn.Module):
    def __init__(self, in_channels, out_channels, num_classes, rates):
        super(ASPP, self).__init__()
        assert len(rates) in [1, 3]
        self.rates = rates
        self.global_pooling = nn.AdaptiveAvgPool2d(1)
        self.aspp1 = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, 1, bias=False),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True)
        )
        self.aspp2 = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, 3, dilation=rates[0],
                      padding=rates[0], bias=False),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True)
        )
        if len(self.rates) == 3:
            self.aspp3 = nn.Sequential(
                nn.Conv2d(in_channels, out_channels, 3, dilation=rates[1],
                          padding=rates[1], bias=False),
                nn.BatchNorm2d(out_channels),
                nn.ReLU(inplace=True)
            )
            self.aspp4 = nn.Sequential(
                nn.Conv2d(in_channels, out_channels, 3, dilation=rates[2],
                          padding=rates[2], bias=False),
                nn.BatchNorm2d(out_channels),
                nn.ReLU(inplace=True)
            )
        self.aspp5 = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, 1, bias=False),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True)
        )
        self.classifier = nn.Sequential(
            nn.Conv2d(out_channels * (len(rates) + 2), out_channels, 1,
                      bias=False),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_channels, num_classes, 1)
        )

    def forward(self, x):
        x1 = self.aspp1(x)
        x2 = self.aspp2(x)
        x5 = self.global_pooling(x)
        x5 = self.aspp5(x5)
        x5 = nn.Upsample((x.shape[2], x.shape[3]), mode='bilinear',
                         align_corners=True)(x5)
        if len(self.rates) == 3:
            x3 = self.aspp3(x)
            x4 = self.aspp4(x)
            x = torch.cat((x1, x2, x3, x4, x5), 1)
        else:
            x = torch.cat((x1, x2, x5), 1)
        x = self.classifier(x)
        return x
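A quick shape check for the ASPP head above (it is cfg-independent, unlike Classifier). With three dilation rates, five parallel branches are concatenated before the 1x1 classifier; the channel counts and input size below are arbitrary toy values.

import torch

from pycls.models.common import ASPP

aspp = ASPP(in_channels=64, out_channels=32, num_classes=19, rates=[6, 12, 18])
x = torch.randn(2, 64, 33, 33)   # toy feature map (N, C, H, W)
y = aspp(x)
print(y.shape)                   # expected: torch.Size([2, 19, 33, 33])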
634
pycls/models/nas/genotypes.py
Normal file
@@ -0,0 +1,634 @@
|
||||
#!/usr/bin/env python3
|
||||
|
||||
# Copyright (c) Facebook, Inc. and its affiliates.
|
||||
#
|
||||
# This source code is licensed under the MIT license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
"""NAS genotypes (adopted from DARTS)."""
|
||||
|
||||
from collections import namedtuple
|
||||
|
||||
|
||||
Genotype = namedtuple('Genotype', 'normal normal_concat reduce reduce_concat')
|
||||
|
||||
|
||||
# NASNet ops
|
||||
NASNET_OPS = [
|
||||
'skip_connect',
|
||||
'conv_3x1_1x3',
|
||||
'conv_7x1_1x7',
|
||||
'dil_conv_3x3',
|
||||
'avg_pool_3x3',
|
||||
'max_pool_3x3',
|
||||
'max_pool_5x5',
|
||||
'max_pool_7x7',
|
||||
'conv_1x1',
|
||||
'conv_3x3',
|
||||
'sep_conv_3x3',
|
||||
'sep_conv_5x5',
|
||||
'sep_conv_7x7',
|
||||
]
|
||||
|
||||
# ENAS ops
|
||||
ENAS_OPS = [
|
||||
'skip_connect',
|
||||
'sep_conv_3x3',
|
||||
'sep_conv_5x5',
|
||||
'avg_pool_3x3',
|
||||
'max_pool_3x3',
|
||||
]
|
||||
|
||||
# AmoebaNet ops
|
||||
AMOEBA_OPS = [
|
||||
'skip_connect',
|
||||
'sep_conv_3x3',
|
||||
'sep_conv_5x5',
|
||||
'sep_conv_7x7',
|
||||
'avg_pool_3x3',
|
||||
'max_pool_3x3',
|
||||
'dil_sep_conv_3x3',
|
||||
'conv_7x1_1x7',
|
||||
]
|
||||
|
||||
# NAO ops
|
||||
NAO_OPS = [
|
||||
'skip_connect',
|
||||
'conv_1x1',
|
||||
'conv_3x3',
|
||||
'conv_3x1_1x3',
|
||||
'conv_7x1_1x7',
|
||||
'max_pool_2x2',
|
||||
'max_pool_3x3',
|
||||
'max_pool_5x5',
|
||||
'avg_pool_2x2',
|
||||
'avg_pool_3x3',
|
||||
'avg_pool_5x5',
|
||||
]
|
||||
|
||||
# PNAS ops
|
||||
PNAS_OPS = [
|
||||
'sep_conv_3x3',
|
||||
'sep_conv_5x5',
|
||||
'sep_conv_7x7',
|
||||
'conv_7x1_1x7',
|
||||
'skip_connect',
|
||||
'avg_pool_3x3',
|
||||
'max_pool_3x3',
|
||||
'dil_conv_3x3',
|
||||
]
|
||||
|
||||
# DARTS ops
|
||||
DARTS_OPS = [
|
||||
'none',
|
||||
'max_pool_3x3',
|
||||
'avg_pool_3x3',
|
||||
'skip_connect',
|
||||
'sep_conv_3x3',
|
||||
'sep_conv_5x5',
|
||||
'dil_conv_3x3',
|
||||
'dil_conv_5x5',
|
||||
]
|
||||
|
||||
|
||||
NASNet = Genotype(
|
||||
normal=[
|
||||
('sep_conv_5x5', 1),
|
||||
('sep_conv_3x3', 0),
|
||||
('sep_conv_5x5', 0),
|
||||
('sep_conv_3x3', 0),
|
||||
('avg_pool_3x3', 1),
|
||||
('skip_connect', 0),
|
||||
('avg_pool_3x3', 0),
|
||||
('avg_pool_3x3', 0),
|
||||
('sep_conv_3x3', 1),
|
||||
('skip_connect', 1),
|
||||
],
|
||||
normal_concat=[2, 3, 4, 5, 6],
|
||||
reduce=[
|
||||
('sep_conv_5x5', 1),
|
||||
('sep_conv_7x7', 0),
|
||||
('max_pool_3x3', 1),
|
||||
('sep_conv_7x7', 0),
|
||||
('avg_pool_3x3', 1),
|
||||
('sep_conv_5x5', 0),
|
||||
('skip_connect', 3),
|
||||
('avg_pool_3x3', 2),
|
||||
('sep_conv_3x3', 2),
|
||||
('max_pool_3x3', 1),
|
||||
],
|
||||
reduce_concat=[4, 5, 6],
|
||||
)
|
||||
|
||||
|
||||
PNASNet = Genotype(
|
||||
normal=[
|
||||
('sep_conv_5x5', 0),
|
||||
('max_pool_3x3', 0),
|
||||
('sep_conv_7x7', 1),
|
||||
('max_pool_3x3', 1),
|
||||
('sep_conv_5x5', 1),
|
||||
('sep_conv_3x3', 1),
|
||||
('sep_conv_3x3', 4),
|
||||
('max_pool_3x3', 1),
|
||||
('sep_conv_3x3', 0),
|
||||
('skip_connect', 1),
|
||||
],
|
||||
normal_concat=[2, 3, 4, 5, 6],
|
||||
reduce=[
|
||||
('sep_conv_5x5', 0),
|
||||
('max_pool_3x3', 0),
|
||||
('sep_conv_7x7', 1),
|
||||
('max_pool_3x3', 1),
|
||||
('sep_conv_5x5', 1),
|
||||
('sep_conv_3x3', 1),
|
||||
('sep_conv_3x3', 4),
|
||||
('max_pool_3x3', 1),
|
||||
('sep_conv_3x3', 0),
|
||||
('skip_connect', 1),
|
||||
],
|
||||
reduce_concat=[2, 3, 4, 5, 6],
|
||||
)
|
||||
|
||||
|
||||
AmoebaNet = Genotype(
|
||||
normal=[
|
||||
('avg_pool_3x3', 0),
|
||||
('max_pool_3x3', 1),
|
||||
('sep_conv_3x3', 0),
|
||||
('sep_conv_5x5', 2),
|
||||
('sep_conv_3x3', 0),
|
||||
('avg_pool_3x3', 3),
|
||||
('sep_conv_3x3', 1),
|
||||
('skip_connect', 1),
|
||||
('skip_connect', 0),
|
||||
('avg_pool_3x3', 1),
|
||||
],
|
||||
normal_concat=[4, 5, 6],
|
||||
reduce=[
|
||||
('avg_pool_3x3', 0),
|
||||
('sep_conv_3x3', 1),
|
||||
('max_pool_3x3', 0),
|
||||
('sep_conv_7x7', 2),
|
||||
('sep_conv_7x7', 0),
|
||||
('avg_pool_3x3', 1),
|
||||
('max_pool_3x3', 0),
|
||||
('max_pool_3x3', 1),
|
||||
('conv_7x1_1x7', 0),
|
||||
('sep_conv_3x3', 5),
|
||||
],
|
||||
reduce_concat=[3, 4, 6]
|
||||
)
|
||||
|
||||
|
||||
DARTS_V1 = Genotype(
|
||||
normal=[
|
||||
('sep_conv_3x3', 1),
|
||||
('sep_conv_3x3', 0),
|
||||
('skip_connect', 0),
|
||||
('sep_conv_3x3', 1),
|
||||
('skip_connect', 0),
|
||||
('sep_conv_3x3', 1),
|
||||
('sep_conv_3x3', 0),
|
||||
('skip_connect', 2)
|
||||
],
|
||||
normal_concat=[2, 3, 4, 5],
|
||||
reduce=[
|
||||
('max_pool_3x3', 0),
|
||||
('max_pool_3x3', 1),
|
||||
('skip_connect', 2),
|
||||
('max_pool_3x3', 0),
|
||||
('max_pool_3x3', 0),
|
||||
('skip_connect', 2),
|
||||
('skip_connect', 2),
|
||||
('avg_pool_3x3', 0)
|
||||
],
|
||||
reduce_concat=[2, 3, 4, 5]
|
||||
)
|
||||
|
||||
|
||||
DARTS_V2 = Genotype(
|
||||
normal=[
|
||||
('sep_conv_3x3', 0),
|
||||
('sep_conv_3x3', 1),
|
||||
('sep_conv_3x3', 0),
|
||||
('sep_conv_3x3', 1),
|
||||
('sep_conv_3x3', 1),
|
||||
('skip_connect', 0),
|
||||
('skip_connect', 0),
|
||||
('dil_conv_3x3', 2)
|
||||
],
|
||||
normal_concat=[2, 3, 4, 5],
|
||||
reduce=[
|
||||
('max_pool_3x3', 0),
|
||||
('max_pool_3x3', 1),
|
||||
('skip_connect', 2),
|
||||
('max_pool_3x3', 1),
|
||||
('max_pool_3x3', 0),
|
||||
('skip_connect', 2),
|
||||
('skip_connect', 2),
|
||||
('max_pool_3x3', 1)
|
||||
],
|
||||
reduce_concat=[2, 3, 4, 5]
|
||||
)
|
||||
|
||||
PDARTS = Genotype(
|
||||
normal=[
|
||||
('skip_connect', 0),
|
||||
('dil_conv_3x3', 1),
|
||||
('skip_connect', 0),
|
||||
('sep_conv_3x3', 1),
|
||||
('sep_conv_3x3', 1),
|
||||
('sep_conv_3x3', 3),
|
||||
('sep_conv_3x3', 0),
|
||||
('dil_conv_5x5', 4)
|
||||
],
|
||||
normal_concat=range(2, 6),
|
||||
reduce=[
|
||||
('avg_pool_3x3', 0),
|
||||
('sep_conv_5x5', 1),
|
||||
('sep_conv_3x3', 0),
|
||||
('dil_conv_5x5', 2),
|
||||
('max_pool_3x3', 0),
|
||||
('dil_conv_3x3', 1),
|
||||
('dil_conv_3x3', 1),
|
||||
('dil_conv_5x5', 3)
|
||||
],
|
||||
reduce_concat=range(2, 6)
|
||||
)
|
||||
|
||||
PCDARTS_C10 = Genotype(
|
||||
normal=[
|
||||
('sep_conv_3x3', 1),
|
||||
('skip_connect', 0),
|
||||
('sep_conv_3x3', 0),
|
||||
('dil_conv_3x3', 1),
|
||||
('sep_conv_5x5', 0),
|
||||
('sep_conv_3x3', 1),
|
||||
('avg_pool_3x3', 0),
|
||||
('dil_conv_3x3', 1)
|
||||
],
|
||||
normal_concat=range(2, 6),
|
||||
reduce=[
|
||||
('sep_conv_5x5', 1),
|
||||
('max_pool_3x3', 0),
|
||||
('sep_conv_5x5', 1),
|
||||
('sep_conv_5x5', 2),
|
||||
('sep_conv_3x3', 0),
|
||||
('sep_conv_3x3', 3),
|
||||
('sep_conv_3x3', 1),
|
||||
('sep_conv_3x3', 2)
|
||||
],
|
||||
reduce_concat=range(2, 6)
|
||||
)
|
||||
|
||||
PCDARTS_IN1K = Genotype(
|
||||
normal=[
|
||||
('skip_connect', 1),
|
||||
('sep_conv_3x3', 0),
|
||||
('sep_conv_3x3', 0),
|
||||
('skip_connect', 1),
|
||||
('sep_conv_3x3', 1),
|
||||
('sep_conv_3x3', 3),
|
||||
('sep_conv_3x3', 1),
|
||||
('dil_conv_5x5', 4)
|
||||
],
|
||||
normal_concat=range(2, 6),
|
||||
reduce=[
|
||||
('sep_conv_3x3', 0),
|
||||
('skip_connect', 1),
|
||||
('dil_conv_5x5', 2),
|
||||
('max_pool_3x3', 1),
|
||||
('sep_conv_3x3', 2),
|
||||
('sep_conv_3x3', 1),
|
||||
('sep_conv_5x5', 0),
|
||||
('sep_conv_3x3', 3)
|
||||
],
|
||||
reduce_concat=range(2, 6)
|
||||
)
|
||||
|
||||
UNNAS_IMAGENET_CLS = Genotype(
|
||||
normal=[
|
||||
('sep_conv_3x3', 1),
|
||||
('sep_conv_3x3', 0),
|
||||
('sep_conv_3x3', 1),
|
||||
('sep_conv_3x3', 0),
|
||||
('sep_conv_3x3', 0),
|
||||
('sep_conv_3x3', 2),
|
||||
('sep_conv_5x5', 1),
|
||||
('sep_conv_3x3', 0)
|
||||
],
|
||||
normal_concat=range(2, 6),
|
||||
reduce=[
|
||||
('max_pool_3x3', 0),
|
||||
('skip_connect', 1),
|
||||
('max_pool_3x3', 0),
|
||||
('dil_conv_5x5', 2),
|
||||
('max_pool_3x3', 0),
|
||||
('sep_conv_3x3', 2),
|
||||
('sep_conv_3x3', 4),
|
||||
('dil_conv_5x5', 3)
|
||||
],
|
||||
reduce_concat=range(2, 6)
|
||||
)
|
||||
|
||||
UNNAS_IMAGENET_ROT = Genotype(
|
||||
normal=[
|
||||
('sep_conv_3x3', 0),
|
||||
('sep_conv_3x3', 1),
|
||||
('sep_conv_3x3', 0),
|
||||
('sep_conv_3x3', 1),
|
||||
('sep_conv_3x3', 1),
|
||||
('sep_conv_3x3', 0),
|
||||
('sep_conv_3x3', 0),
|
||||
('sep_conv_3x3', 1)
|
||||
],
|
||||
normal_concat=range(2, 6),
|
||||
reduce=[
|
||||
('sep_conv_3x3', 1),
|
||||
('sep_conv_3x3', 0),
|
||||
('sep_conv_3x3', 2),
|
||||
('sep_conv_3x3', 1),
|
||||
('sep_conv_3x3', 1),
|
||||
('sep_conv_3x3', 2),
|
||||
('sep_conv_3x3', 4),
|
||||
('sep_conv_5x5', 2)
|
||||
],
|
||||
reduce_concat=range(2, 6)
|
||||
)
|
||||
|
||||
UNNAS_IMAGENET_COL = Genotype(
|
||||
normal=[
|
||||
('skip_connect', 0),
|
||||
('sep_conv_3x3', 1),
|
||||
('sep_conv_3x3', 1),
|
||||
('skip_connect', 0),
|
||||
('sep_conv_3x3', 0),
|
||||
('sep_conv_3x3', 3),
|
||||
('sep_conv_3x3', 0),
|
||||
('sep_conv_3x3', 2)
|
||||
],
|
||||
normal_concat=range(2, 6),
|
||||
reduce=[
|
||||
('max_pool_3x3', 0),
|
||||
('sep_conv_3x3', 1),
|
||||
('max_pool_3x3', 0),
|
||||
('sep_conv_3x3', 1),
|
||||
('max_pool_3x3', 0),
|
||||
('sep_conv_5x5', 3),
|
||||
('max_pool_3x3', 0),
|
||||
('sep_conv_3x3', 4)
|
||||
],
|
||||
reduce_concat=range(2, 6)
|
||||
)
|
||||
|
||||
UNNAS_IMAGENET_JIG = Genotype(
|
||||
normal=[
|
||||
('sep_conv_3x3', 0),
|
||||
('sep_conv_3x3', 1),
|
||||
('sep_conv_3x3', 0),
|
||||
('sep_conv_3x3', 1),
|
||||
('sep_conv_3x3', 0),
|
||||
('sep_conv_3x3', 3),
|
||||
('sep_conv_3x3', 1),
|
||||
('sep_conv_5x5', 0)
|
||||
],
|
||||
normal_concat=range(2, 6),
|
||||
reduce=[
|
||||
('sep_conv_5x5', 0),
|
||||
('sep_conv_3x3', 1),
|
||||
('sep_conv_3x3', 0),
|
||||
('sep_conv_3x3', 1),
|
||||
('sep_conv_3x3', 0),
|
||||
('sep_conv_3x3', 1),
|
||||
('sep_conv_5x5', 0),
|
||||
('sep_conv_3x3', 1)
|
||||
],
|
||||
reduce_concat=range(2, 6)
|
||||
)
|
||||
|
||||
UNNAS_IMAGENET22K_CLS = Genotype(
|
||||
normal=[
|
||||
('sep_conv_3x3', 1),
|
||||
('skip_connect', 0),
|
||||
('sep_conv_3x3', 1),
|
||||
('sep_conv_3x3', 2),
|
||||
('sep_conv_3x3', 1),
|
||||
('sep_conv_3x3', 2),
|
||||
('sep_conv_3x3', 1),
|
||||
('sep_conv_3x3', 0)
|
||||
],
|
||||
normal_concat=range(2, 6),
|
||||
reduce=[
|
||||
('max_pool_3x3', 0),
|
||||
('max_pool_3x3', 1),
|
||||
('dil_conv_5x5', 2),
|
||||
('max_pool_3x3', 0),
|
||||
('dil_conv_5x5', 3),
|
||||
('dil_conv_5x5', 2),
|
||||
('dil_conv_5x5', 4),
|
||||
('dil_conv_5x5', 3)
|
||||
],
|
||||
reduce_concat=range(2, 6)
|
||||
)
|
||||
|
||||
UNNAS_IMAGENET22K_ROT = Genotype(
|
||||
normal=[
|
||||
('sep_conv_3x3', 0),
|
||||
('sep_conv_3x3', 1),
|
||||
('sep_conv_3x3', 0),
|
||||
('sep_conv_3x3', 1),
|
||||
('sep_conv_3x3', 0),
|
||||
('sep_conv_3x3', 1),
|
||||
('sep_conv_3x3', 0),
|
||||
('sep_conv_3x3', 1)
|
||||
],
|
||||
normal_concat=range(2, 6),
|
||||
reduce=[
|
||||
('max_pool_3x3', 0),
|
||||
('sep_conv_5x5', 1),
|
||||
('dil_conv_5x5', 2),
|
||||
('sep_conv_5x5', 0),
|
||||
('dil_conv_5x5', 3),
|
||||
('sep_conv_3x3', 2),
|
||||
('sep_conv_3x3', 4),
|
||||
('sep_conv_3x3', 3)
|
||||
],
|
||||
reduce_concat=range(2, 6)
|
||||
)
|
||||
|
||||
UNNAS_IMAGENET22K_COL = Genotype(
|
||||
normal=[
|
||||
('sep_conv_3x3', 0),
|
||||
('sep_conv_3x3', 1),
|
||||
('sep_conv_3x3', 0),
|
||||
('sep_conv_3x3', 1),
|
||||
('sep_conv_3x3', 2),
|
||||
('sep_conv_3x3', 1),
|
||||
('sep_conv_3x3', 3),
|
||||
('sep_conv_3x3', 0)
|
||||
],
|
||||
normal_concat=range(2, 6),
|
||||
reduce=[
|
||||
('max_pool_3x3', 0),
|
||||
('skip_connect', 1),
|
||||
('dil_conv_5x5', 2),
|
||||
('sep_conv_3x3', 0),
|
||||
('sep_conv_3x3', 3),
|
||||
('sep_conv_3x3', 0),
|
||||
('sep_conv_3x3', 4),
|
||||
('sep_conv_5x5', 1)
|
||||
],
|
||||
reduce_concat=range(2, 6)
|
||||
)
|
||||
|
||||
UNNAS_IMAGENET22K_JIG = Genotype(
|
||||
normal=[
|
||||
('sep_conv_3x3', 0),
|
||||
('sep_conv_3x3', 1),
|
||||
('sep_conv_3x3', 0),
|
||||
('sep_conv_3x3', 1),
|
||||
('sep_conv_3x3', 1),
|
||||
('sep_conv_3x3', 0),
|
||||
('sep_conv_3x3', 0),
|
||||
('sep_conv_3x3', 4)
|
||||
],
|
||||
normal_concat=range(2, 6),
|
||||
reduce=[
|
||||
('sep_conv_5x5', 0),
|
||||
('skip_connect', 1),
|
||||
('sep_conv_5x5', 0),
|
||||
('sep_conv_3x3', 2),
|
||||
('sep_conv_5x5', 0),
|
||||
('sep_conv_5x5', 3),
|
||||
('sep_conv_5x5', 0),
|
||||
('sep_conv_5x5', 4)
|
||||
],
|
||||
reduce_concat=range(2, 6)
|
||||
)
|
||||
|
||||
UNNAS_CITYSCAPES_SEG = Genotype(
|
||||
normal=[
|
||||
('skip_connect', 0),
|
||||
('sep_conv_5x5', 1),
|
||||
('sep_conv_3x3', 1),
|
||||
('sep_conv_3x3', 0),
|
||||
('sep_conv_3x3', 0),
|
||||
('sep_conv_3x3', 1),
|
||||
('sep_conv_3x3', 0),
|
||||
('sep_conv_3x3', 1)
|
||||
],
|
||||
normal_concat=range(2, 6),
|
||||
reduce=[
|
||||
('sep_conv_3x3', 0),
|
||||
('avg_pool_3x3', 1),
|
||||
('avg_pool_3x3', 1),
|
||||
('sep_conv_5x5', 0),
|
||||
('sep_conv_3x3', 2),
|
||||
('sep_conv_5x5', 0),
|
||||
('sep_conv_3x3', 4),
|
||||
('sep_conv_5x5', 2)
|
||||
],
|
||||
reduce_concat=range(2, 6)
|
||||
)
|
||||
|
||||
UNNAS_CITYSCAPES_ROT = Genotype(
|
||||
normal=[
|
||||
('sep_conv_3x3', 1),
|
||||
('sep_conv_3x3', 0),
|
||||
('sep_conv_3x3', 2),
|
||||
('sep_conv_3x3', 1),
|
||||
('sep_conv_3x3', 0),
|
||||
('sep_conv_3x3', 3),
|
||||
('sep_conv_3x3', 1),
|
||||
('sep_conv_3x3', 0)
|
||||
],
|
||||
normal_concat=range(2, 6),
|
||||
reduce=[
|
||||
('max_pool_3x3', 0),
|
||||
('sep_conv_5x5', 1),
|
||||
('sep_conv_5x5', 2),
|
||||
('sep_conv_5x5', 1),
|
||||
('sep_conv_5x5', 3),
|
||||
('dil_conv_5x5', 2),
|
||||
('sep_conv_5x5', 2),
|
||||
('sep_conv_5x5', 0)
|
||||
],
|
||||
reduce_concat=range(2, 6)
|
||||
)
|
||||
|
||||
UNNAS_CITYSCAPES_COL = Genotype(
|
||||
normal=[
|
||||
('dil_conv_3x3', 1),
|
||||
('sep_conv_3x3', 0),
|
||||
('skip_connect', 0),
|
||||
('sep_conv_5x5', 2),
|
||||
('dil_conv_3x3', 3),
|
||||
('skip_connect', 0),
|
||||
('skip_connect', 0),
|
||||
('sep_conv_3x3', 1)
|
||||
],
|
||||
normal_concat=range(2, 6),
|
||||
reduce=[
|
||||
('avg_pool_3x3', 1),
|
||||
('avg_pool_3x3', 0),
|
||||
('avg_pool_3x3', 1),
|
||||
('avg_pool_3x3', 0),
|
||||
('avg_pool_3x3', 1),
|
||||
('avg_pool_3x3', 0),
|
||||
('avg_pool_3x3', 1),
|
||||
('skip_connect', 4)
|
||||
],
|
||||
reduce_concat=range(2, 6)
|
||||
)
|
||||
|
||||
UNNAS_CITYSCAPES_JIG = Genotype(
|
||||
normal=[
|
||||
('dil_conv_5x5', 1),
|
||||
('sep_conv_5x5', 0),
|
||||
('sep_conv_3x3', 0),
|
||||
('sep_conv_3x3', 1),
|
||||
('sep_conv_3x3', 0),
|
||||
('sep_conv_3x3', 2),
|
||||
('sep_conv_3x3', 0),
|
||||
('dil_conv_5x5', 1)
|
||||
],
|
||||
normal_concat=range(2, 6),
|
||||
reduce=[
|
||||
('avg_pool_3x3', 0),
|
||||
('skip_connect', 1),
|
||||
('dil_conv_5x5', 1),
|
||||
('dil_conv_5x5', 2),
|
||||
('dil_conv_5x5', 2),
|
||||
('dil_conv_5x5', 0),
|
||||
('dil_conv_5x5', 3),
|
||||
('dil_conv_5x5', 2)
|
||||
],
|
||||
reduce_concat=range(2, 6)
|
||||
)
|
||||
|
||||
|
||||
# Supported genotypes
|
||||
GENOTYPES = {
|
||||
'nas': NASNet,
|
||||
'pnas': PNASNet,
|
||||
'amoeba': AmoebaNet,
|
||||
'darts_v1': DARTS_V1,
|
||||
'darts_v2': DARTS_V2,
|
||||
'pdarts': PDARTS,
|
||||
'pcdarts_c10': PCDARTS_C10,
|
||||
'pcdarts_in1k': PCDARTS_IN1K,
|
||||
'unnas_imagenet_cls': UNNAS_IMAGENET_CLS,
|
||||
'unnas_imagenet_rot': UNNAS_IMAGENET_ROT,
|
||||
'unnas_imagenet_col': UNNAS_IMAGENET_COL,
|
||||
'unnas_imagenet_jig': UNNAS_IMAGENET_JIG,
|
||||
'unnas_imagenet22k_cls': UNNAS_IMAGENET22K_CLS,
|
||||
'unnas_imagenet22k_rot': UNNAS_IMAGENET22K_ROT,
|
||||
'unnas_imagenet22k_col': UNNAS_IMAGENET22K_COL,
|
||||
'unnas_imagenet22k_jig': UNNAS_IMAGENET22K_JIG,
|
||||
'unnas_cityscapes_seg': UNNAS_CITYSCAPES_SEG,
|
||||
'unnas_cityscapes_rot': UNNAS_CITYSCAPES_ROT,
|
||||
'unnas_cityscapes_col': UNNAS_CITYSCAPES_COL,
|
||||
'unnas_cityscapes_jig': UNNAS_CITYSCAPES_JIG,
|
||||
'custom': None,
|
||||
}
|
||||
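The GENOTYPES dictionary above maps config names to Genotype namedtuples, where each cell is a list of (op_name, input_node) pairs. A small inspection sketch, assuming the import path shown in this file:

from collections import Counter

from pycls.models.nas.genotypes import GENOTYPES

geno = GENOTYPES['unnas_imagenet_cls']
print(Counter(op for op, _ in geno.normal))   # which ops the normal cell selected
print(Counter(op for op, _ in geno.reduce))   # which ops the reduction cell selected
print(list(geno.normal_concat))               # nodes concatenated to form the cell output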
337
pycls/models/nas/nas.py
Normal file
@@ -0,0 +1,337 @@
|
||||
#!/usr/bin/env python3
|
||||
|
||||
# Copyright (c) Facebook, Inc. and its affiliates.
|
||||
#
|
||||
# This source code is licensed under the MIT license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
"""NAS network (adopted from DARTS)."""
|
||||
|
||||
from torch.autograd import Variable
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
import pycls.core.logging as logging
|
||||
|
||||
from pycls.core.config import cfg
|
||||
from pycls.models.common import Preprocess
|
||||
from pycls.models.common import Classifier
|
||||
from pycls.models.nas.genotypes import GENOTYPES
|
||||
from pycls.models.nas.genotypes import Genotype
|
||||
from pycls.models.nas.operations import FactorizedReduce
|
||||
from pycls.models.nas.operations import OPS
|
||||
from pycls.models.nas.operations import ReLUConvBN
|
||||
from pycls.models.nas.operations import Identity
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
|
||||
def drop_path(x, drop_prob):
|
||||
"""Drop path (ported from DARTS)."""
|
||||
if drop_prob > 0.:
|
||||
keep_prob = 1.-drop_prob
|
||||
mask = Variable(
|
||||
torch.cuda.FloatTensor(x.size(0), 1, 1, 1).bernoulli_(keep_prob)
|
||||
)
|
||||
x.div_(keep_prob)
|
||||
x.mul_(mask)
|
||||
return x
|
||||
|
||||
|
||||
class Cell(nn.Module):
|
||||
"""NAS cell (ported from DARTS)."""
|
||||
|
||||
def __init__(self, genotype, C_prev_prev, C_prev, C, reduction, reduction_prev):
|
||||
super(Cell, self).__init__()
|
||||
logger.info('{}, {}, {}'.format(C_prev_prev, C_prev, C))
|
||||
|
||||
if reduction_prev:
|
||||
self.preprocess0 = FactorizedReduce(C_prev_prev, C)
|
||||
else:
|
||||
self.preprocess0 = ReLUConvBN(C_prev_prev, C, 1, 1, 0)
|
||||
self.preprocess1 = ReLUConvBN(C_prev, C, 1, 1, 0)
|
||||
|
||||
if reduction:
|
||||
op_names, indices = zip(*genotype.reduce)
|
||||
concat = genotype.reduce_concat
|
||||
else:
|
||||
op_names, indices = zip(*genotype.normal)
|
||||
concat = genotype.normal_concat
|
||||
self._compile(C, op_names, indices, concat, reduction)
|
||||
|
||||
def _compile(self, C, op_names, indices, concat, reduction):
|
||||
assert len(op_names) == len(indices)
|
||||
self._steps = len(op_names) // 2
|
||||
self._concat = concat
|
||||
self.multiplier = len(concat)
|
||||
|
||||
self._ops = nn.ModuleList()
|
||||
for name, index in zip(op_names, indices):
|
||||
stride = 2 if reduction and index < 2 else 1
|
||||
op = OPS[name](C, stride, True)
|
||||
self._ops += [op]
|
||||
self._indices = indices
|
||||
|
||||
def forward(self, s0, s1, drop_prob):
|
||||
s0 = self.preprocess0(s0)
|
||||
s1 = self.preprocess1(s1)
|
||||
|
||||
states = [s0, s1]
|
||||
for i in range(self._steps):
|
||||
h1 = states[self._indices[2*i]]
|
||||
h2 = states[self._indices[2*i+1]]
|
||||
op1 = self._ops[2*i]
|
||||
op2 = self._ops[2*i+1]
|
||||
h1 = op1(h1)
|
||||
h2 = op2(h2)
|
||||
if self.training and drop_prob > 0.:
|
||||
if not isinstance(op1, Identity):
|
||||
h1 = drop_path(h1, drop_prob)
|
||||
if not isinstance(op2, Identity):
|
||||
h2 = drop_path(h2, drop_prob)
|
||||
s = h1 + h2
|
||||
states += [s]
|
||||
return torch.cat([states[i] for i in self._concat], dim=1)
|
||||
|
||||
|
||||
class AuxiliaryHeadCIFAR(nn.Module):
|
||||
|
||||
def __init__(self, C, num_classes):
|
||||
"""assuming input size 8x8"""
|
||||
super(AuxiliaryHeadCIFAR, self).__init__()
|
||||
self.features = nn.Sequential(
|
||||
nn.ReLU(inplace=True),
|
||||
nn.AvgPool2d(5, stride=3, padding=0, count_include_pad=False), # image size = 2 x 2
|
||||
nn.Conv2d(C, 128, 1, bias=False),
|
||||
nn.BatchNorm2d(128),
|
||||
nn.ReLU(inplace=True),
|
||||
nn.Conv2d(128, 768, 2, bias=False),
|
||||
nn.BatchNorm2d(768),
|
||||
nn.ReLU(inplace=True)
|
||||
)
|
||||
self.classifier = nn.Linear(768, num_classes)
|
||||
|
||||
def forward(self, x):
|
||||
x = self.features(x)
|
||||
x = self.classifier(x.view(x.size(0),-1))
|
||||
return x
|
||||
|
||||
|
||||
class AuxiliaryHeadImageNet(nn.Module):
|
||||
|
||||
def __init__(self, C, num_classes):
|
||||
"""assuming input size 14x14"""
|
||||
super(AuxiliaryHeadImageNet, self).__init__()
|
||||
self.features = nn.Sequential(
|
||||
nn.ReLU(inplace=True),
|
||||
nn.AvgPool2d(5, stride=2, padding=0, count_include_pad=False),
|
||||
nn.Conv2d(C, 128, 1, bias=False),
|
||||
nn.BatchNorm2d(128),
|
||||
nn.ReLU(inplace=True),
|
||||
nn.Conv2d(128, 768, 2, bias=False),
|
||||
# NOTE: This batchnorm was omitted in my earlier implementation due to a typo.
|
||||
# Commenting it out for consistency with the experiments in the paper.
|
||||
# nn.BatchNorm2d(768),
|
||||
nn.ReLU(inplace=True)
|
||||
)
|
||||
self.classifier = nn.Linear(768, num_classes)
|
||||
|
||||
def forward(self, x):
|
||||
x = self.features(x)
|
||||
x = self.classifier(x.view(x.size(0),-1))
|
||||
return x
|
||||
|
||||
|
||||
class NetworkCIFAR(nn.Module):
|
||||
"""CIFAR network (ported from DARTS)."""
|
||||
|
||||
def __init__(self, C, num_classes, layers, auxiliary, genotype):
|
||||
super(NetworkCIFAR, self).__init__()
|
||||
self._layers = layers
|
||||
self._auxiliary = auxiliary
|
||||
|
||||
stem_multiplier = 3
|
||||
C_curr = stem_multiplier*C
|
||||
self.stem = nn.Sequential(
|
||||
nn.Conv2d(cfg.MODEL.INPUT_CHANNELS, C_curr, 3, padding=1, bias=False),
|
||||
nn.BatchNorm2d(C_curr)
|
||||
)
|
||||
|
||||
C_prev_prev, C_prev, C_curr = C_curr, C_curr, C
|
||||
self.cells = nn.ModuleList()
|
||||
reduction_prev = False
|
||||
for i in range(layers):
|
||||
if i in [layers//3, 2*layers//3]:
|
||||
C_curr *= 2
|
||||
reduction = True
|
||||
else:
|
||||
reduction = False
|
||||
cell = Cell(genotype, C_prev_prev, C_prev, C_curr, reduction, reduction_prev)
|
||||
reduction_prev = reduction
|
||||
self.cells += [cell]
|
||||
C_prev_prev, C_prev = C_prev, cell.multiplier*C_curr
|
||||
if i == 2*layers//3:
|
||||
C_to_auxiliary = C_prev
|
||||
|
||||
if auxiliary:
|
||||
self.auxiliary_head = AuxiliaryHeadCIFAR(C_to_auxiliary, num_classes)
|
||||
self.classifier = Classifier(C_prev, num_classes)
|
||||
|
||||
def forward(self, input):
|
||||
input = Preprocess(input)
|
||||
logits_aux = None
|
||||
s0 = s1 = self.stem(input)
|
||||
for i, cell in enumerate(self.cells):
|
||||
s0, s1 = s1, cell(s0, s1, self.drop_path_prob)
|
||||
if i == 2*self._layers//3:
|
||||
if self._auxiliary and self.training:
|
||||
logits_aux = self.auxiliary_head(s1)
|
||||
logits = self.classifier(s1, input.shape[2:])
|
||||
if self._auxiliary and self.training:
|
||||
return logits, logits_aux
|
||||
return logits
|
||||
|
||||
def _loss(self, input, target, return_logits=False):
|
||||
logits = self(input)
|
||||
loss = self._criterion(logits, target)
|
||||
|
||||
return (loss, logits) if return_logits else loss
|
||||
|
||||
def step(self, input, target, args, shared=None, return_grad=False):
|
||||
Lt, logit_t = self._loss(input, target, return_logits=True)
|
||||
Lt.backward()
|
||||
if args.grad_clip != 0:
|
||||
nn.utils.clip_grad_norm_(self.get_weights(), args.grad_clip)
|
||||
self.optimizer.step()
|
||||
|
||||
if return_grad:
|
||||
grad = torch.nn.utils.parameters_to_vector([p.grad for p in self.get_weights()])
|
||||
return logit_t, Lt, grad
|
||||
else:
|
||||
return logit_t, Lt
|
||||
|
||||
|
||||
class NetworkImageNet(nn.Module):
|
||||
"""ImageNet network (ported from DARTS)."""
|
||||
|
||||
def __init__(self, C, num_classes, layers, auxiliary, genotype):
|
||||
super(NetworkImageNet, self).__init__()
|
||||
self._layers = layers
|
||||
self._auxiliary = auxiliary
|
||||
|
||||
self.stem0 = nn.Sequential(
|
||||
nn.Conv2d(cfg.MODEL.INPUT_CHANNELS, C // 2, kernel_size=3, stride=2, padding=1, bias=False),
|
||||
nn.BatchNorm2d(C // 2),
|
||||
nn.ReLU(inplace=True),
|
||||
nn.Conv2d(C // 2, C, 3, stride=2, padding=1, bias=False),
|
||||
nn.BatchNorm2d(C),
|
||||
)
|
||||
|
||||
self.stem1 = nn.Sequential(
|
||||
nn.ReLU(inplace=True),
|
||||
nn.Conv2d(C, C, 3, stride=2, padding=1, bias=False),
|
||||
nn.BatchNorm2d(C),
|
||||
)
|
||||
|
||||
C_prev_prev, C_prev, C_curr = C, C, C
|
||||
|
||||
self.cells = nn.ModuleList()
|
||||
reduction_prev = True
|
||||
reduction_layers = [layers//3] if cfg.TASK == 'seg' else [layers//3, 2*layers//3]
|
||||
for i in range(layers):
|
||||
if i in reduction_layers:
|
||||
C_curr *= 2
|
||||
reduction = True
|
||||
else:
|
||||
reduction = False
|
||||
cell = Cell(genotype, C_prev_prev, C_prev, C_curr, reduction, reduction_prev)
|
||||
reduction_prev = reduction
|
||||
self.cells += [cell]
|
||||
C_prev_prev, C_prev = C_prev, cell.multiplier * C_curr
|
||||
if i == 2 * layers // 3:
|
||||
C_to_auxiliary = C_prev
|
||||
|
||||
if auxiliary:
|
||||
self.auxiliary_head = AuxiliaryHeadImageNet(C_to_auxiliary, num_classes)
|
||||
self.classifier = Classifier(C_prev, num_classes)
|
||||
|
||||
def forward(self, input):
|
||||
input = Preprocess(input)
|
||||
logits_aux = None
|
||||
s0 = self.stem0(input)
|
||||
s1 = self.stem1(s0)
|
||||
for i, cell in enumerate(self.cells):
|
||||
s0, s1 = s1, cell(s0, s1, self.drop_path_prob)
|
||||
if i == 2 * self._layers // 3:
|
||||
if self._auxiliary and self.training:
|
||||
logits_aux = self.auxiliary_head(s1)
|
||||
logits = self.classifier(s1, input.shape[2:])
|
||||
if self._auxiliary and self.training:
|
||||
return logits, logits_aux
|
||||
return logits
|
||||
|
||||
def _loss(self, input, target, return_logits=False):
|
||||
logits = self(input)
|
||||
loss = self._criterion(logits, target)
|
||||
|
||||
return (loss, logits) if return_logits else loss
|
||||
|
||||
def step(self, input, target, args, shared=None, return_grad=False):
|
||||
Lt, logit_t = self._loss(input, target, return_logits=True)
|
||||
Lt.backward()
|
||||
if args.grad_clip != 0:
|
||||
nn.utils.clip_grad_norm_(self.get_weights(), args.grad_clip)
|
||||
self.optimizer.step()
|
||||
|
||||
if return_grad:
|
||||
grad = torch.nn.utils.parameters_to_vector([p.grad for p in self.get_weights()])
|
||||
return logit_t, Lt, grad
|
||||
else:
|
||||
return logit_t, Lt
|
||||
|
||||
|
||||
class NAS(nn.Module):
|
||||
"""NAS net wrapper (delegates to nets from DARTS)."""
|
||||
|
||||
def __init__(self):
|
||||
assert cfg.TRAIN.DATASET in ['cifar10', 'imagenet', 'cityscapes'], \
|
||||
'Training on {} is not supported'.format(cfg.TRAIN.DATASET)
|
||||
assert cfg.TEST.DATASET in ['cifar10', 'imagenet', 'cityscapes'], \
|
||||
'Testing on {} is not supported'.format(cfg.TEST.DATASET)
|
||||
assert cfg.NAS.GENOTYPE in GENOTYPES, \
|
||||
'Genotype {} not supported'.format(cfg.NAS.GENOTYPE)
|
||||
super(NAS, self).__init__()
|
||||
logger.info('Constructing NAS: {}'.format(cfg.NAS))
|
||||
# Use a custom or predefined genotype
|
||||
if cfg.NAS.GENOTYPE == 'custom':
|
||||
genotype = Genotype(
|
||||
normal=cfg.NAS.CUSTOM_GENOTYPE[0],
|
||||
normal_concat=cfg.NAS.CUSTOM_GENOTYPE[1],
|
||||
reduce=cfg.NAS.CUSTOM_GENOTYPE[2],
|
||||
reduce_concat=cfg.NAS.CUSTOM_GENOTYPE[3],
|
||||
)
|
||||
else:
|
||||
genotype = GENOTYPES[cfg.NAS.GENOTYPE]
|
||||
# Determine the network constructor for dataset
|
||||
if 'cifar' in cfg.TRAIN.DATASET:
|
||||
net_ctor = NetworkCIFAR
|
||||
else:
|
||||
net_ctor = NetworkImageNet
|
||||
# Construct the network
|
||||
self.net_ = net_ctor(
|
||||
C=cfg.NAS.WIDTH,
|
||||
num_classes=cfg.MODEL.NUM_CLASSES,
|
||||
layers=cfg.NAS.DEPTH,
|
||||
auxiliary=cfg.NAS.AUX,
|
||||
genotype=genotype
|
||||
)
|
||||
# Drop path probability (set / annealed based on epoch)
|
||||
self.net_.drop_path_prob = 0.0
|
||||
|
||||
def set_drop_path_prob(self, drop_path_prob):
|
||||
self.net_.drop_path_prob = drop_path_prob
|
||||
|
||||
def forward(self, x):
|
||||
return self.net_.forward(x)
|
||||
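drop_path above zeroes whole samples with a per-sample Bernoulli mask and rescales the survivors by 1/keep_prob so the expected activation is preserved. A CPU-friendly sketch of the same idea, written out-of-place and without Variable/torch.cuda, for illustration only:

import torch

def drop_path_sketch(x, drop_prob):
    # Same idea as drop_path above, but out-of-place and device-agnostic
    if drop_prob > 0.0:
        keep_prob = 1.0 - drop_prob
        mask = torch.empty(x.size(0), 1, 1, 1, device=x.device).bernoulli_(keep_prob)
        x = x.div(keep_prob) * mask
    return x

x = torch.ones(8, 3, 4, 4)
y = drop_path_sketch(x, drop_prob=0.25)
print(y[:, 0, 0, 0])   # dropped samples are 0, survivors are scaled to 1 / 0.75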
219
pycls/models/nas/operations.py
Normal file
@@ -0,0 +1,219 @@
|
||||
#!/usr/bin/env python3
|
||||
|
||||
# Copyright (c) Facebook, Inc. and its affiliates.
|
||||
#
|
||||
# This source code is licensed under the MIT license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
|
||||
"""NAS ops (adopted from DARTS)."""
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from torch.autograd import Variable
|
||||
|
||||
OPS = {
|
||||
'none': lambda C, stride, affine:
|
||||
Zero(stride),
|
||||
'noise': lambda C, stride, affine: NoiseOp(stride, 0., 1.),
|
||||
'avg_pool_2x2': lambda C, stride, affine:
|
||||
nn.AvgPool2d(2, stride=stride, padding=0, count_include_pad=False),
|
||||
'avg_pool_3x3': lambda C, stride, affine:
|
||||
nn.AvgPool2d(3, stride=stride, padding=1, count_include_pad=False),
|
||||
'avg_pool_5x5': lambda C, stride, affine:
|
||||
nn.AvgPool2d(5, stride=stride, padding=2, count_include_pad=False),
|
||||
'max_pool_2x2': lambda C, stride, affine:
|
||||
nn.MaxPool2d(2, stride=stride, padding=0),
|
||||
'max_pool_3x3': lambda C, stride, affine:
|
||||
nn.MaxPool2d(3, stride=stride, padding=1),
|
||||
'max_pool_5x5': lambda C, stride, affine:
|
||||
nn.MaxPool2d(5, stride=stride, padding=2),
|
||||
'max_pool_7x7': lambda C, stride, affine:
|
||||
nn.MaxPool2d(7, stride=stride, padding=3),
|
||||
'skip_connect': lambda C, stride, affine:
|
||||
Identity() if stride == 1 else FactorizedReduce(C, C, affine=affine),
|
||||
'conv_1x1': lambda C, stride, affine:
|
||||
nn.Sequential(
|
||||
nn.ReLU(inplace=False),
|
||||
nn.Conv2d(C, C, 1, stride=stride, padding=0, bias=False),
|
||||
nn.BatchNorm2d(C, affine=affine)
|
||||
),
|
||||
'conv_3x3': lambda C, stride, affine:
|
||||
nn.Sequential(
|
||||
nn.ReLU(inplace=False),
|
||||
nn.Conv2d(C, C, 3, stride=stride, padding=1, bias=False),
|
||||
nn.BatchNorm2d(C, affine=affine)
|
||||
),
|
||||
'sep_conv_3x3': lambda C, stride, affine:
|
||||
SepConv(C, C, 3, stride, 1, affine=affine),
|
||||
'sep_conv_5x5': lambda C, stride, affine:
|
||||
SepConv(C, C, 5, stride, 2, affine=affine),
|
||||
'sep_conv_7x7': lambda C, stride, affine:
|
||||
SepConv(C, C, 7, stride, 3, affine=affine),
|
||||
'dil_conv_3x3': lambda C, stride, affine:
|
||||
DilConv(C, C, 3, stride, 2, 2, affine=affine),
|
||||
'dil_conv_5x5': lambda C, stride, affine:
|
||||
DilConv(C, C, 5, stride, 4, 2, affine=affine),
|
||||
'dil_sep_conv_3x3': lambda C, stride, affine:
|
||||
DilSepConv(C, C, 3, stride, 2, 2, affine=affine),
|
||||
'conv_3x1_1x3': lambda C, stride, affine:
|
||||
nn.Sequential(
|
||||
nn.ReLU(inplace=False),
|
||||
nn.Conv2d(C, C, (1,3), stride=(1, stride), padding=(0, 1), bias=False),
|
||||
nn.Conv2d(C, C, (3,1), stride=(stride, 1), padding=(1, 0), bias=False),
|
||||
nn.BatchNorm2d(C, affine=affine)
|
||||
),
|
||||
'conv_7x1_1x7': lambda C, stride, affine:
|
||||
nn.Sequential(
|
||||
nn.ReLU(inplace=False),
|
||||
nn.Conv2d(C, C, (1,7), stride=(1, stride), padding=(0, 3), bias=False),
|
||||
nn.Conv2d(C, C, (7,1), stride=(stride, 1), padding=(3, 0), bias=False),
|
||||
nn.BatchNorm2d(C, affine=affine)
|
||||
),
|
||||
}
|
||||
|
||||
class NoiseOp(nn.Module):
|
||||
def __init__(self, stride, mean, std):
|
||||
super(NoiseOp, self).__init__()
|
||||
self.stride = stride
|
||||
self.mean = mean
|
||||
self.std = std
|
||||
|
||||
def forward(self, x, block_input=False):
|
||||
if block_input:
|
||||
x = x*0
|
||||
if self.stride != 1:
|
||||
x_new = x[:,:,::self.stride,::self.stride]
|
||||
else:
|
||||
x_new = x
|
||||
noise = Variable(x_new.data.new(x_new.size()).normal_(self.mean, self.std))
|
||||
|
||||
return noise
|
||||
|
||||
class ReLUConvBN(nn.Module):
|
||||
|
||||
def __init__(self, C_in, C_out, kernel_size, stride, padding, affine=True):
|
||||
super(ReLUConvBN, self).__init__()
|
||||
self.op = nn.Sequential(
|
||||
nn.ReLU(inplace=False),
|
||||
nn.Conv2d(
|
||||
C_in, C_out, kernel_size, stride=stride,
|
||||
padding=padding, bias=False
|
||||
),
|
||||
nn.BatchNorm2d(C_out, affine=affine)
|
||||
)
|
||||
|
||||
def forward(self, x):
|
||||
return self.op(x)
|
||||
|
||||
|
||||
class DilConv(nn.Module):
|
||||
|
||||
def __init__(
|
||||
self, C_in, C_out, kernel_size, stride, padding, dilation, affine=True
|
||||
):
|
||||
super(DilConv, self).__init__()
|
||||
self.op = nn.Sequential(
|
||||
nn.ReLU(inplace=False),
|
||||
nn.Conv2d(
|
||||
C_in, C_in, kernel_size=kernel_size, stride=stride,
|
||||
padding=padding, dilation=dilation, groups=C_in, bias=False
|
||||
),
|
||||
nn.Conv2d(C_in, C_out, kernel_size=1, padding=0, bias=False),
|
||||
nn.BatchNorm2d(C_out, affine=affine),
|
||||
)
|
||||
|
||||
def forward(self, x):
|
||||
return self.op(x)
|
||||
|
||||
|
||||
class SepConv(nn.Module):
|
||||
|
||||
def __init__(self, C_in, C_out, kernel_size, stride, padding, affine=True):
|
||||
super(SepConv, self).__init__()
|
||||
self.op = nn.Sequential(
|
||||
nn.ReLU(inplace=False),
|
||||
nn.Conv2d(
|
||||
C_in, C_in, kernel_size=kernel_size, stride=stride,
|
||||
padding=padding, groups=C_in, bias=False
|
||||
),
|
||||
nn.Conv2d(C_in, C_in, kernel_size=1, padding=0, bias=False),
|
||||
nn.BatchNorm2d(C_in, affine=affine),
|
||||
nn.ReLU(inplace=False),
|
||||
nn.Conv2d(
|
||||
C_in, C_in, kernel_size=kernel_size, stride=1,
|
||||
padding=padding, groups=C_in, bias=False
|
||||
),
|
||||
nn.Conv2d(C_in, C_out, kernel_size=1, padding=0, bias=False),
|
||||
nn.BatchNorm2d(C_out, affine=affine),
|
||||
)
|
||||
|
||||
def forward(self, x):
|
||||
return self.op(x)
|
||||
|
||||
|
||||
class DilSepConv(nn.Module):
|
||||
|
||||
def __init__(
|
||||
self, C_in, C_out, kernel_size, stride, padding, dilation, affine=True
|
||||
):
|
||||
super(DilSepConv, self).__init__()
|
||||
self.op = nn.Sequential(
|
||||
nn.ReLU(inplace=False),
|
||||
nn.Conv2d(
|
||||
C_in, C_in, kernel_size=kernel_size, stride=stride,
|
||||
padding=padding, dilation=dilation, groups=C_in, bias=False
|
||||
),
|
||||
nn.Conv2d(C_in, C_in, kernel_size=1, padding=0, bias=False),
|
||||
nn.BatchNorm2d(C_in, affine=affine),
|
||||
nn.ReLU(inplace=False),
|
||||
nn.Conv2d(
|
||||
C_in, C_in, kernel_size=kernel_size, stride=1,
|
||||
padding=padding, dilation=dilation, groups=C_in, bias=False
|
||||
),
|
||||
nn.Conv2d(C_in, C_out, kernel_size=1, padding=0, bias=False),
|
||||
nn.BatchNorm2d(C_out, affine=affine),
|
||||
)
|
||||
|
||||
def forward(self, x):
|
||||
return self.op(x)
|
||||
|
||||
|
||||
class Identity(nn.Module):
|
||||
|
||||
def __init__(self):
|
||||
super(Identity, self).__init__()
|
||||
|
||||
def forward(self, x):
|
||||
return x
|
||||
|
||||
|
||||
class Zero(nn.Module):
|
||||
|
||||
def __init__(self, stride):
|
||||
super(Zero, self).__init__()
|
||||
self.stride = stride
|
||||
|
||||
def forward(self, x):
|
||||
if self.stride == 1:
|
||||
return x.mul(0.)
|
||||
return x[:,:,::self.stride,::self.stride].mul(0.)
|
||||
|
||||
|
||||
class FactorizedReduce(nn.Module):
|
||||
|
||||
def __init__(self, C_in, C_out, affine=True):
|
||||
super(FactorizedReduce, self).__init__()
|
||||
assert C_out % 2 == 0
|
||||
self.relu = nn.ReLU(inplace=False)
|
||||
self.conv_1 = nn.Conv2d(C_in, C_out // 2, 1, stride=2, padding=0, bias=False)
|
||||
self.conv_2 = nn.Conv2d(C_in, C_out // 2, 1, stride=2, padding=0, bias=False)
|
||||
self.bn = nn.BatchNorm2d(C_out, affine=affine)
|
||||
self.pad = nn.ConstantPad2d((0, 1, 0, 1), 0)
|
||||
|
||||
def forward(self, x):
|
||||
x = self.relu(x)
|
||||
y = self.pad(x)
|
||||
out = torch.cat([self.conv_1(x), self.conv_2(y[:,:,1:,1:])], dim=1)
|
||||
out = self.bn(out)
|
||||
return out
|
||||
169
sota/cnn/genotypes.py
Normal file
@@ -0,0 +1,169 @@
|
||||
from collections import namedtuple
|
||||
|
||||
Genotype = namedtuple('Genotype', 'normal normal_concat reduce reduce_concat')
|
||||
|
||||
PRIMITIVES = [
|
||||
'none',
|
||||
'noise',
|
||||
'max_pool_3x3',
|
||||
'avg_pool_3x3',
|
||||
'skip_connect',
|
||||
'sep_conv_3x3',
|
||||
'sep_conv_5x5',
|
||||
'dil_conv_3x3',
|
||||
'dil_conv_5x5'
|
||||
]
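# Encoding note (added for clarity): each cell has two input nodes (0, 1) and four
# intermediate nodes (2-5). The `normal`/`reduce` lists hold two [op_name, input_node]
# pairs per intermediate node, in order, and `normal_concat`/`reduce_concat` give the
# nodes whose outputs are concatenated to form the cell output.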
|
||||
|
||||
|
||||
######## S1-S4 Space ########
|
||||
#### cifar10 s1 - s4
|
||||
|
||||
init_pt_s1_C10_0 = Genotype(normal=[["dil_conv_3x3", 0], ["dil_conv_5x5", 1], ["sep_conv_3x3", 1], ["dil_conv_3x3", 2], ["sep_conv_3x3", 1], ["sep_conv_3x3", 2], ["sep_conv_3x3", 0], ["dil_conv_3x3", 4]], normal_concat=range(2, 6), reduce=[["avg_pool_3x3", 0], ["dil_conv_3x3", 1], ["avg_pool_3x3", 1], ["skip_connect", 2], ["sep_conv_3x3", 1], ["dil_conv_3x3", 2], ["dil_conv_5x5", 2], ["dil_conv_5x5", 4]], reduce_concat=range(2, 6))
|
||||
init_pt_s1_C10_2 = Genotype(normal=[["skip_connect", 0], ["dil_conv_5x5", 1], ["sep_conv_3x3", 1], ["dil_conv_3x3", 2], ["sep_conv_3x3", 1], ["sep_conv_3x3", 2], ["sep_conv_3x3", 0], ["dil_conv_5x5", 4]], normal_concat=range(2, 6), reduce=[["max_pool_3x3", 0], ["dil_conv_3x3", 1], ["max_pool_3x3", 0], ["avg_pool_3x3", 1], ["sep_conv_3x3", 1], ["dil_conv_5x5", 3], ["dil_conv_5x5", 3], ["dil_conv_5x5", 4]], reduce_concat=range(2, 6))
|
||||
init_pt_s2_C10_0 = Genotype(normal=[["sep_conv_3x3", 0], ["skip_connect", 1], ["sep_conv_3x3", 1], ["sep_conv_3x3", 2], ["sep_conv_3x3", 0], ["sep_conv_3x3", 1], ["sep_conv_3x3", 1], ["sep_conv_3x3", 2]], normal_concat=range(2, 6), reduce=[["sep_conv_3x3", 0], ["sep_conv_3x3", 1], ["skip_connect", 0], ["skip_connect", 1], ["sep_conv_3x3", 0], ["sep_conv_3x3", 1], ["sep_conv_3x3", 0], ["sep_conv_3x3", 1]], reduce_concat=range(2, 6))
|
||||
init_pt_s2_C10_2 = Genotype(normal=[["sep_conv_3x3", 0], ["sep_conv_3x3", 1], ["sep_conv_3x3", 1], ["sep_conv_3x3", 2], ["skip_connect", 0], ["sep_conv_3x3", 3], ["sep_conv_3x3", 1], ["sep_conv_3x3", 4]], normal_concat=range(2, 6), reduce=[["sep_conv_3x3", 0], ["sep_conv_3x3", 1], ["sep_conv_3x3", 0], ["skip_connect", 1], ["sep_conv_3x3", 0], ["sep_conv_3x3", 1], ["skip_connect", 0], ["skip_connect", 1]], reduce_concat=range(2, 6))
|
||||
init_pt_s3_C10_0 = Genotype(normal=[["sep_conv_3x3", 0], ["skip_connect", 1], ["sep_conv_3x3", 1], ["sep_conv_3x3", 2], ["sep_conv_3x3", 0], ["sep_conv_3x3", 1], ["sep_conv_3x3", 1], ["sep_conv_3x3", 2]], normal_concat=range(2, 6), reduce=[["sep_conv_3x3", 0], ["sep_conv_3x3", 1], ["skip_connect", 0], ["skip_connect", 1], ["sep_conv_3x3", 0], ["sep_conv_3x3", 1], ["sep_conv_3x3", 0], ["sep_conv_3x3", 1]], reduce_concat=range(2, 6))
|
||||
init_pt_s3_C10_2 = Genotype(normal=[["sep_conv_3x3", 0], ["sep_conv_3x3", 1], ["sep_conv_3x3", 1], ["sep_conv_3x3", 2], ["skip_connect", 0], ["sep_conv_3x3", 3], ["sep_conv_3x3", 1], ["sep_conv_3x3", 4]], normal_concat=range(2, 6), reduce=[["sep_conv_3x3", 0], ["sep_conv_3x3", 1], ["sep_conv_3x3", 0], ["sep_conv_3x3", 1], ["sep_conv_3x3", 0], ["sep_conv_3x3", 1], ["sep_conv_3x3", 0], ["sep_conv_3x3", 1]], reduce_concat=range(2, 6))
|
||||
init_pt_s4_C10_0 = Genotype(normal=[["sep_conv_3x3", 0], ["noise", 1], ["noise", 1], ["sep_conv_3x3", 2], ["sep_conv_3x3", 1], ["sep_conv_3x3", 2], ["sep_conv_3x3", 2], ["sep_conv_3x3", 3]], normal_concat=range(2, 6), reduce=[["sep_conv_3x3", 0], ["sep_conv_3x3", 1], ["sep_conv_3x3", 0], ["sep_conv_3x3", 1], ["sep_conv_3x3", 0], ["sep_conv_3x3", 1], ["sep_conv_3x3", 0], ["sep_conv_3x3", 1]], reduce_concat=range(2, 6))
|
||||
init_pt_s4_C10_2 = Genotype(normal=[["sep_conv_3x3", 0], ["noise", 1], ["sep_conv_3x3", 1], ["noise", 2], ["sep_conv_3x3", 2], ["sep_conv_3x3", 3], ["sep_conv_3x3", 2], ["sep_conv_3x3", 3]], normal_concat=range(2, 6), reduce=[["sep_conv_3x3", 0], ["sep_conv_3x3", 1], ["sep_conv_3x3", 1], ["noise", 2], ["sep_conv_3x3", 0], ["sep_conv_3x3", 1], ["sep_conv_3x3", 0], ["sep_conv_3x3", 1]], reduce_concat=range(2, 6))
|
||||
|
||||
|
||||
#### cifar100 s1 - s4
|
||||
init_pt_s1_C100_0 = Genotype(normal=[["dil_conv_3x3", 0], ["dil_conv_5x5", 1], ["sep_conv_3x3", 1], ["dil_conv_3x3", 2], ["sep_conv_3x3", 1], ["sep_conv_3x3", 2], ["sep_conv_3x3", 0], ["dil_conv_3x3", 4]], normal_concat=range(2, 6), reduce=[["avg_pool_3x3", 0], ["dil_conv_3x3", 1], ["avg_pool_3x3", 0], ["dil_conv_5x5", 2], ["sep_conv_3x3", 1], ["dil_conv_3x3", 2], ["avg_pool_3x3", 1], ["dil_conv_5x5", 2]], reduce_concat=range(2, 6))
|
||||
init_pt_s1_C100_2 = Genotype(normal=[["dil_conv_3x3", 0], ["dil_conv_5x5", 1], ["dil_conv_5x5", 0], ["sep_conv_3x3", 1], ["sep_conv_3x3", 1], ["sep_conv_3x3", 2], ["sep_conv_3x3", 0], ["dil_conv_3x3", 4]], normal_concat=range(2, 6), reduce=[["avg_pool_3x3", 0], ["dil_conv_3x3", 1], ["avg_pool_3x3", 1], ["dil_conv_5x5", 2], ["sep_conv_3x3", 1], ["dil_conv_5x5", 3], ["dil_conv_5x5", 3], ["dil_conv_5x5", 4]], reduce_concat=range(2, 6))
|
||||
init_pt_s2_C100_0 = Genotype(normal=[["skip_connect", 0], ["sep_conv_3x3", 1], ["sep_conv_3x3", 0], ["sep_conv_3x3", 2], ["sep_conv_3x3", 1], ["sep_conv_3x3", 2], ["sep_conv_3x3", 1], ["sep_conv_3x3", 4]], normal_concat=range(2, 6), reduce=[["sep_conv_3x3", 0], ["sep_conv_3x3", 1], ["skip_connect", 0], ["skip_connect", 1], ["sep_conv_3x3", 0], ["sep_conv_3x3", 1], ["sep_conv_3x3", 0], ["sep_conv_3x3", 1]], reduce_concat=range(2, 6))
|
||||
init_pt_s2_C100_2 = Genotype(normal=[["skip_connect", 0], ["sep_conv_3x3", 1], ["sep_conv_3x3", 0], ["sep_conv_3x3", 1], ["sep_conv_3x3", 0], ["sep_conv_3x3", 2], ["sep_conv_3x3", 1], ["sep_conv_3x3", 3]], normal_concat=range(2, 6), reduce=[["skip_connect", 0], ["sep_conv_3x3", 1], ["sep_conv_3x3", 0], ["skip_connect", 1], ["sep_conv_3x3", 0], ["sep_conv_3x3", 1], ["sep_conv_3x3", 0], ["sep_conv_3x3", 1]], reduce_concat=range(2, 6))
|
||||
init_pt_s3_C100_0 = Genotype(normal=[["skip_connect", 0], ["sep_conv_3x3", 1], ["sep_conv_3x3", 0], ["sep_conv_3x3", 2], ["sep_conv_3x3", 0], ["sep_conv_3x3", 1], ["sep_conv_3x3", 0], ["sep_conv_3x3", 4]], normal_concat=range(2, 6), reduce=[["sep_conv_3x3", 0], ["sep_conv_3x3", 1], ["sep_conv_3x3", 0], ["skip_connect", 1], ["skip_connect", 0], ["skip_connect", 1], ["sep_conv_3x3", 0], ["sep_conv_3x3", 1]], reduce_concat=range(2, 6))
|
||||
init_pt_s3_C100_2 = Genotype(normal=[["sep_conv_3x3", 0], ["sep_conv_3x3", 1], ["skip_connect", 0], ["sep_conv_3x3", 2], ["sep_conv_3x3", 1], ["sep_conv_3x3", 3], ["sep_conv_3x3", 1], ["sep_conv_3x3", 3]], normal_concat=range(2, 6), reduce=[["skip_connect", 0], ["skip_connect", 1], ["sep_conv_3x3", 0], ["skip_connect", 1], ["sep_conv_3x3", 0], ["sep_conv_3x3", 1], ["sep_conv_3x3", 0], ["sep_conv_3x3", 1]], reduce_concat=range(2, 6))
|
||||
init_pt_s4_C100_0 = Genotype(normal=[["sep_conv_3x3", 0], ["noise", 1], ["sep_conv_3x3", 1], ["noise", 2], ["sep_conv_3x3", 2], ["sep_conv_3x3", 3], ["sep_conv_3x3", 2], ["sep_conv_3x3", 3]], normal_concat=range(2, 6), reduce=[["sep_conv_3x3", 0], ["sep_conv_3x3", 1], ["sep_conv_3x3", 0], ["sep_conv_3x3", 1], ["sep_conv_3x3", 0], ["sep_conv_3x3", 1], ["sep_conv_3x3", 0], ["sep_conv_3x3", 1]], reduce_concat=range(2, 6))
|
||||
init_pt_s4_C100_2 = Genotype(normal=[["noise", 0], ["sep_conv_3x3", 1], ["noise", 0], ["sep_conv_3x3", 2], ["sep_conv_3x3", 0], ["sep_conv_3x3", 3], ["sep_conv_3x3", 2], ["sep_conv_3x3", 3]], normal_concat=range(2, 6), reduce=[["sep_conv_3x3", 0], ["sep_conv_3x3", 1], ["sep_conv_3x3", 0], ["sep_conv_3x3", 1], ["sep_conv_3x3", 0], ["sep_conv_3x3", 1], ["sep_conv_3x3", 0], ["sep_conv_3x3", 1]], reduce_concat=range(2, 6))
|
||||
|
||||
#### svhn s1 - s4
|
||||
init_pt_s1_svhn_0 = Genotype(normal=[["dil_conv_3x3", 0], ["dil_conv_5x5", 1], ["sep_conv_3x3", 1], ["dil_conv_3x3", 2], ["sep_conv_3x3", 1], ["sep_conv_3x3", 2], ["sep_conv_3x3", 0], ["dil_conv_3x3", 3]], normal_concat=range(2, 6), reduce=[["avg_pool_3x3", 0], ["dil_conv_3x3", 1], ["avg_pool_3x3", 1], ["dil_conv_5x5", 2], ["sep_conv_3x3", 1], ["dil_conv_5x5", 3], ["dil_conv_5x5", 2], ["dil_conv_5x5", 4]], reduce_concat=range(2, 6))
|
||||
init_pt_s1_svhn_2 = Genotype(normal=[["dil_conv_3x3", 0], ["dil_conv_5x5", 1], ["dil_conv_5x5", 0], ["sep_conv_3x3", 1], ["sep_conv_3x3", 1], ["sep_conv_3x3", 2], ["sep_conv_3x3", 0], ["dil_conv_3x3", 3]], normal_concat=range(2, 6), reduce=[["max_pool_3x3", 0], ["dil_conv_3x3", 1], ["max_pool_3x3", 0], ["dil_conv_5x5", 2], ["sep_conv_3x3", 1], ["dil_conv_5x5", 3], ["avg_pool_3x3", 0], ["dil_conv_5x5", 3]], reduce_concat=range(2, 6))
|
||||
init_pt_s2_svhn_0 = Genotype(normal=[["sep_conv_3x3", 0], ["sep_conv_3x3", 1], ["sep_conv_3x3", 0], ["sep_conv_3x3", 2], ["sep_conv_3x3", 0], ["sep_conv_3x3", 1], ["sep_conv_3x3", 1], ["sep_conv_3x3", 2]], normal_concat=range(2, 6), reduce=[["sep_conv_3x3", 0], ["sep_conv_3x3", 1], ["sep_conv_3x3", 0], ["sep_conv_3x3", 1], ["skip_connect", 0], ["sep_conv_3x3", 1], ["sep_conv_3x3", 0], ["sep_conv_3x3", 1]], reduce_concat=range(2, 6))
|
||||
init_pt_s2_svhn_2 = Genotype(normal=[["skip_connect", 0], ["sep_conv_3x3", 1], ["sep_conv_3x3", 1], ["sep_conv_3x3", 2], ["sep_conv_3x3", 0], ["sep_conv_3x3", 1], ["sep_conv_3x3", 1], ["sep_conv_3x3", 2]], normal_concat=range(2, 6), reduce=[["skip_connect", 0], ["sep_conv_3x3", 1], ["sep_conv_3x3", 0], ["sep_conv_3x3", 1], ["sep_conv_3x3", 0], ["skip_connect", 1], ["sep_conv_3x3", 0], ["sep_conv_3x3", 1]], reduce_concat=range(2, 6))
|
||||
init_pt_s3_svhn_0 = Genotype(normal=[["sep_conv_3x3", 0], ["sep_conv_3x3", 1], ["sep_conv_3x3", 1], ["sep_conv_3x3", 2], ["sep_conv_3x3", 0], ["sep_conv_3x3", 1], ["sep_conv_3x3", 1], ["sep_conv_3x3", 2]], normal_concat=range(2, 6), reduce=[["sep_conv_3x3", 0], ["sep_conv_3x3", 1], ["sep_conv_3x3", 0], ["sep_conv_3x3", 1], ["sep_conv_3x3", 0], ["sep_conv_3x3", 1], ["sep_conv_3x3", 0], ["sep_conv_3x3", 1]], reduce_concat=range(2, 6))
|
||||
init_pt_s3_svhn_2 = Genotype(normal=[["skip_connect", 0], ["sep_conv_3x3", 1], ["sep_conv_3x3", 1], ["sep_conv_3x3", 2], ["sep_conv_3x3", 0], ["sep_conv_3x3", 2], ["sep_conv_3x3", 1], ["sep_conv_3x3", 4]], normal_concat=range(2, 6), reduce=[["sep_conv_3x3", 0], ["sep_conv_3x3", 1], ["sep_conv_3x3", 0], ["skip_connect", 1], ["sep_conv_3x3", 0], ["sep_conv_3x3", 1], ["sep_conv_3x3", 0], ["sep_conv_3x3", 1]], reduce_concat=range(2, 6))
|
||||
init_pt_s4_svhn_0 = Genotype(normal=[["noise", 0], ["sep_conv_3x3", 1], ["sep_conv_3x3", 0], ["noise", 2], ["sep_conv_3x3", 2], ["sep_conv_3x3", 3], ["sep_conv_3x3", 2], ["sep_conv_3x3", 3]], normal_concat=range(2, 6), reduce=[["sep_conv_3x3", 0], ["sep_conv_3x3", 1], ["sep_conv_3x3", 0], ["sep_conv_3x3", 1], ["sep_conv_3x3", 1], ["noise", 3], ["sep_conv_3x3", 0], ["sep_conv_3x3", 1]], reduce_concat=range(2, 6))
|
||||
init_pt_s4_svhn_2 = Genotype(normal=[["sep_conv_3x3", 0], ["noise", 1], ["noise", 0], ["sep_conv_3x3", 2], ["sep_conv_3x3", 2], ["sep_conv_3x3", 3], ["sep_conv_3x3", 1], ["sep_conv_3x3", 2]], normal_concat=range(2, 6), reduce=[["sep_conv_3x3", 0], ["sep_conv_3x3", 1], ["sep_conv_3x3", 0], ["sep_conv_3x3", 1], ["sep_conv_3x3", 0], ["sep_conv_3x3", 1], ["sep_conv_3x3", 0], ["sep_conv_3x3", 1]], reduce_concat=range(2, 6))
|
||||
|
||||
|
||||
|
||||
######## DARTS Space ########
|
||||
|
||||
####init-100-N10
|
||||
init_pt_s5_C10_0_100_N10 = Genotype(normal=[["sep_conv_5x5", 0], ["sep_conv_5x5", 1], ["sep_conv_3x3", 1], ["sep_conv_5x5", 2], ["sep_conv_3x3", 0], ["sep_conv_3x3", 2], ["sep_conv_5x5", 2], ["sep_conv_3x3", 4]], normal_concat=range(2, 6), reduce=[["sep_conv_5x5", 0], ["sep_conv_3x3", 1], ["sep_conv_5x5", 0], ["sep_conv_3x3", 1], ["sep_conv_5x5", 1], ["sep_conv_3x3", 3], ["sep_conv_3x3", 0], ["sep_conv_5x5", 1]], reduce_concat=range(2, 6))
|
||||
init_pt_s5_C10_1_100_N10 = Genotype(normal=[["skip_connect", 0], ["sep_conv_5x5", 1], ["sep_conv_3x3", 1], ["sep_conv_5x5", 2], ["sep_conv_5x5", 1], ["sep_conv_5x5", 2], ["sep_conv_3x3", 0], ["sep_conv_5x5", 4]], normal_concat=range(2, 6), reduce=[["sep_conv_3x3", 0], ["sep_conv_5x5", 1], ["skip_connect", 0], ["sep_conv_5x5", 1], ["sep_conv_3x3", 0], ["sep_conv_5x5", 1], ["sep_conv_5x5", 0], ["sep_conv_3x3", 1]], reduce_concat=range(2, 6))
|
||||
init_pt_s5_C10_2_100_N10 = Genotype(normal=[["skip_connect", 0], ["sep_conv_5x5", 1], ["sep_conv_5x5", 0], ["sep_conv_5x5", 2], ["sep_conv_5x5", 0], ["sep_conv_5x5", 2], ["sep_conv_3x3", 0], ["sep_conv_5x5", 3]], normal_concat=range(2, 6), reduce=[["sep_conv_5x5", 0], ["sep_conv_3x3", 1], ["sep_conv_3x3", 0], ["skip_connect", 1], ["sep_conv_5x5", 0], ["sep_conv_5x5", 2], ["sep_conv_3x3", 0], ["sep_conv_5x5", 1]], reduce_concat=range(2, 6))
|
||||
init_pt_s5_C10_3_100_N10 = Genotype(normal=[["skip_connect", 0], ["sep_conv_5x5", 1], ["sep_conv_3x3", 0], ["sep_conv_5x5", 1], ["sep_conv_3x3", 0], ["sep_conv_5x5", 3], ["sep_conv_3x3", 0], ["sep_conv_3x3", 1]], normal_concat=range(2, 6), reduce=[["sep_conv_5x5", 0], ["sep_conv_5x5", 1], ["sep_conv_3x3", 0], ["sep_conv_5x5", 1], ["sep_conv_5x5", 0], ["sep_conv_3x3", 1], ["sep_conv_3x3", 0], ["sep_conv_5x5", 1]], reduce_concat=range(2, 6))
|
||||
|
||||
#### global op greedy
|
||||
global_pt_s5_C10_0_100_N10 = Genotype(normal=[["sep_conv_5x5", 0], ["sep_conv_5x5", 1], ["sep_conv_3x3", 1], ["sep_conv_5x5", 2], ["sep_conv_3x3", 0], ["sep_conv_5x5", 1], ["sep_conv_3x3", 0], ["sep_conv_3x3", 1]], normal_concat=range(2, 6), reduce=[["sep_conv_5x5", 0], ["sep_conv_3x3", 1], ["sep_conv_5x5", 0], ["sep_conv_3x3", 1], ["sep_conv_3x3", 0], ["sep_conv_5x5", 1], ["sep_conv_5x5", 1], ["sep_conv_5x5", 4]], reduce_concat=range(2, 6))
|
||||
global_pt_s5_C10_1_100_N10 = Genotype(normal=[["skip_connect", 0], ["sep_conv_5x5", 1], ["sep_conv_3x3", 1], ["sep_conv_5x5", 2], ["sep_conv_5x5", 1], ["sep_conv_5x5", 2], ["sep_conv_3x3", 0], ["sep_conv_3x3", 4]], normal_concat=range(2, 6), reduce=[["sep_conv_3x3", 0], ["sep_conv_5x5", 1], ["sep_conv_5x5", 1], ["sep_conv_3x3", 2], ["sep_conv_3x3", 1], ["sep_conv_5x5", 3], ["sep_conv_5x5", 0], ["sep_conv_3x3", 1]], reduce_concat=range(2, 6))
|
||||
global_pt_s5_C10_2_100_N10 = Genotype(normal=[["skip_connect", 0], ["sep_conv_5x5", 1], ["sep_conv_5x5", 0], ["sep_conv_5x5", 2], ["sep_conv_5x5", 0], ["sep_conv_5x5", 3], ["sep_conv_3x3", 0], ["sep_conv_3x3", 3]], normal_concat=range(2, 6), reduce=[["sep_conv_3x3", 0], ["sep_conv_3x3", 1], ["skip_connect", 0], ["skip_connect", 1], ["sep_conv_5x5", 0], ["dil_conv_5x5", 1], ["sep_conv_3x3", 0], ["sep_conv_5x5", 1]], reduce_concat=range(2, 6))
|
||||
global_pt_s5_C10_3_100_N10 = Genotype(normal=[["skip_connect", 0], ["sep_conv_5x5", 1], ["sep_conv_3x3", 0], ["sep_conv_5x5", 1], ["sep_conv_3x3", 0], ["sep_conv_3x3", 2], ["sep_conv_3x3", 0], ["sep_conv_3x3", 3]], normal_concat=range(2, 6), reduce=[["dil_conv_3x3", 0], ["sep_conv_3x3", 1], ["sep_conv_3x3", 0], ["sep_conv_5x5", 1], ["sep_conv_5x5", 0], ["sep_conv_3x3", 1], ["sep_conv_3x3", 0], ["sep_conv_3x3", 1]], reduce_concat=range(2, 6))
|
||||
|
||||
####2500_sample
|
||||
sample_2500_0 = Genotype(normal=[["dil_conv_5x5", 0], ["dil_conv_5x5", 1], ["dil_conv_5x5", 0], ["skip_connect", 1], ["sep_conv_3x3", 2], ["sep_conv_3x3", 3], ["sep_conv_5x5", 2], ["dil_conv_3x3", 4]], normal_concat=range(2, 6), reduce=[["dil_conv_3x3", 0], ["dil_conv_3x3", 1], ["dil_conv_5x5", 1], ["sep_conv_5x5", 2], ["skip_connect", 0], ["sep_conv_3x3", 3], ["sep_conv_5x5", 0], ["dil_conv_5x5", 3]], reduce_concat=range(2, 6))
|
||||
sample_2500_1 = Genotype(normal=[["skip_connect", 0], ["skip_connect", 1], ["sep_conv_3x3", 0], ["sep_conv_3x3", 2], ["dil_conv_5x5", 1], ["sep_conv_5x5", 2], ["sep_conv_3x3", 1], ["dil_conv_3x3", 2]], normal_concat=range(2, 6), reduce=[["sep_conv_3x3", 0], ["dil_conv_3x3", 1], ["sep_conv_5x5", 0], ["avg_pool_3x3", 2], ["dil_conv_5x5", 1], ["dil_conv_5x5", 2], ["dil_conv_5x5", 0], ["sep_conv_3x3", 4]], reduce_concat=range(2, 6))
|
||||
sample_2500_2 = Genotype(normal=[["dil_conv_5x5", 0], ["dil_conv_3x3", 1], ["sep_conv_5x5", 1], ["sep_conv_5x5", 2], ["sep_conv_5x5", 0], ["sep_conv_3x3", 2], ["sep_conv_3x3", 2], ["sep_conv_5x5", 3]], normal_concat=range(2, 6), reduce=[["dil_conv_3x3", 0], ["avg_pool_3x3", 1], ["sep_conv_5x5", 0], ["skip_connect", 1], ["sep_conv_3x3", 0], ["dil_conv_5x5", 2], ["dil_conv_5x5", 0], ["dil_conv_5x5", 2]], reduce_concat=range(2, 6))
|
||||
sample_2500_3 = Genotype(normal=[["sep_conv_3x3", 0], ["dil_conv_3x3", 1], ["sep_conv_5x5", 1], ["dil_conv_3x3", 2], ["sep_conv_5x5", 0], ["sep_conv_3x3", 2], ["dil_conv_3x3", 0], ["dil_conv_3x3", 1]], normal_concat=range(2, 6), reduce=[["skip_connect", 0], ["dil_conv_3x3", 1], ["dil_conv_5x5", 0], ["max_pool_3x3", 1], ["avg_pool_3x3", 0], ["max_pool_3x3", 1], ["avg_pool_3x3", 1], ["skip_connect", 3]], reduce_concat=range(2, 6))
|
||||
|
||||
|
||||
####20000_sample
|
||||
sample_20000_0 = Genotype(normal=[["skip_connect", 0], ["dil_conv_3x3", 1], ["sep_conv_5x5", 1], ["skip_connect", 2], ["dil_conv_5x5", 0], ["sep_conv_3x3", 3], ["sep_conv_5x5", 2], ["dil_conv_5x5", 4]], normal_concat=range(2, 6), reduce=[["skip_connect", 0], ["sep_conv_5x5", 1], ["dil_conv_5x5", 0], ["skip_connect", 1], ["sep_conv_3x3", 0], ["sep_conv_3x3", 2], ["dil_conv_5x5", 0], ["sep_conv_5x5", 4]], reduce_concat=range(2, 6))
|
||||
sample_20000_1 = Genotype(normal=[["skip_connect", 0], ["skip_connect", 1], ["sep_conv_3x3", 0], ["sep_conv_3x3", 2], ["dil_conv_5x5", 1], ["sep_conv_5x5", 2], ["sep_conv_3x3", 1], ["dil_conv_3x3", 2]], normal_concat=range(2, 6), reduce=[["sep_conv_3x3", 0], ["dil_conv_3x3", 1], ["sep_conv_5x5", 0], ["avg_pool_3x3", 2], ["dil_conv_5x5", 1], ["dil_conv_5x5", 2], ["dil_conv_5x5", 0], ["sep_conv_3x3", 4]], reduce_concat=range(2, 6))
|
||||
sample_20000_2 = Genotype(normal=[["skip_connect", 0], ["sep_conv_5x5", 1], ["sep_conv_5x5", 0], ["sep_conv_5x5", 2], ["dil_conv_5x5", 0], ["sep_conv_3x3", 3], ["sep_conv_5x5", 1], ["sep_conv_3x3", 3]], normal_concat=range(2, 6), reduce=[["dil_conv_5x5", 0], ["dil_conv_3x3", 1], ["skip_connect", 0], ["max_pool_3x3", 1], ["sep_conv_5x5", 0], ["dil_conv_5x5", 1], ["sep_conv_3x3", 1], ["dil_conv_3x3", 3]], reduce_concat=range(2, 6))
|
||||
sample_20000_3 = Genotype(normal=[["sep_conv_5x5", 0], ["sep_conv_3x3", 1], ["skip_connect", 0], ["dil_conv_3x3", 2], ["dil_conv_3x3", 1], ["sep_conv_5x5", 3], ["sep_conv_5x5", 0], ["sep_conv_5x5", 1]], normal_concat=range(2, 6), reduce=[["sep_conv_5x5", 0], ["dil_conv_5x5", 1], ["dil_conv_5x5", 0], ["dil_conv_5x5", 2], ["sep_conv_3x3", 0], ["sep_conv_5x5", 2], ["dil_conv_3x3", 1], ["sep_conv_3x3", 2]], reduce_concat=range(2, 6))
|
||||
|
||||
####50000_sample
|
||||
sample_50000_0 = Genotype(normal=[["skip_connect", 0], ["sep_conv_5x5", 1], ["sep_conv_3x3", 1], ["sep_conv_3x3", 2], ["sep_conv_3x3", 0], ["skip_connect", 2], ["sep_conv_3x3", 0], ["sep_conv_3x3", 2]], normal_concat=range(2, 6), reduce=[["max_pool_3x3", 0], ["sep_conv_5x5", 1], ["avg_pool_3x3", 0], ["dil_conv_3x3", 1], ["sep_conv_5x5", 0], ["dil_conv_3x3", 1], ["dil_conv_5x5", 0], ["max_pool_3x3", 1]], reduce_concat=range(2, 6))
|
||||
sample_50000_1 = Genotype(normal=[["dil_conv_3x3", 0], ["sep_conv_3x3", 1], ["sep_conv_5x5", 0], ["sep_conv_3x3", 1], ["sep_conv_5x5", 1], ["sep_conv_5x5", 3], ["sep_conv_5x5", 1], ["dil_conv_3x3", 2]], normal_concat=range(2, 6), reduce=[["dil_conv_3x3", 0], ["dil_conv_5x5", 1], ["sep_conv_3x3", 0], ["skip_connect", 1], ["sep_conv_5x5", 0], ["max_pool_3x3", 1], ["dil_conv_3x3", 1], ["dil_conv_5x5", 2]], reduce_concat=range(2, 6))
|
||||
sample_50000_2 = Genotype(normal=[["skip_connect", 0], ["sep_conv_5x5", 1], ["sep_conv_5x5", 0], ["sep_conv_5x5", 2], ["dil_conv_5x5", 0], ["sep_conv_3x3", 3], ["sep_conv_5x5", 1], ["sep_conv_3x3", 3]], normal_concat=range(2, 6), reduce=[["dil_conv_5x5", 0], ["dil_conv_3x3", 1], ["skip_connect", 0], ["max_pool_3x3", 1], ["sep_conv_5x5", 0], ["dil_conv_5x5", 1], ["sep_conv_3x3", 1], ["dil_conv_3x3", 3]], reduce_concat=range(2, 6))
|
||||
sample_50000_3 = Genotype(normal=[["sep_conv_5x5", 0], ["sep_conv_3x3", 1], ["skip_connect", 0], ["dil_conv_3x3", 2], ["dil_conv_3x3", 1], ["sep_conv_5x5", 3], ["sep_conv_5x5", 0], ["sep_conv_5x5", 1]], normal_concat=range(2, 6), reduce=[["sep_conv_5x5", 0], ["dil_conv_5x5", 1], ["dil_conv_5x5", 0], ["dil_conv_5x5", 2], ["sep_conv_3x3", 0], ["sep_conv_5x5", 2], ["dil_conv_3x3", 1], ["sep_conv_3x3", 2]], reduce_concat=range(2, 6))
|
||||
|
||||
#### random
|
||||
random_max_0 = Genotype(normal=[["sep_conv_5x5", 0], ["sep_conv_5x5", 1], ["sep_conv_5x5", 0], ["sep_conv_5x5", 1], ["sep_conv_5x5", 0], ["sep_conv_5x5", 3], ["sep_conv_5x5", 1], ["sep_conv_5x5", 4]], normal_concat=range(2, 6), reduce=[["sep_conv_5x5", 0], ["sep_conv_5x5", 1], ["sep_conv_5x5", 0], ["sep_conv_5x5", 1], ["sep_conv_5x5", 0], ["sep_conv_5x5", 1], ["sep_conv_5x5", 0], ["sep_conv_5x5", 1]], reduce_concat=range(2, 6))
|
||||
random_max_1 = Genotype(normal=[["sep_conv_5x5", 0], ["sep_conv_5x5", 1], ["sep_conv_5x5", 0], ["sep_conv_5x5", 2], ["sep_conv_5x5", 0], ["sep_conv_5x5", 2], ["sep_conv_5x5", 1], ["sep_conv_5x5", 4]], normal_concat=range(2, 6), reduce=[["sep_conv_5x5", 0], ["sep_conv_5x5", 1], ["sep_conv_5x5", 0], ["sep_conv_5x5", 1], ["sep_conv_5x5", 0], ["sep_conv_5x5", 1], ["sep_conv_5x5", 0], ["sep_conv_5x5", 1]], reduce_concat=range(2, 6))
|
||||
random_max_2 = Genotype(normal=[["sep_conv_5x5", 0], ["sep_conv_5x5", 1], ["sep_conv_5x5", 0], ["sep_conv_5x5", 2], ["sep_conv_5x5", 0], ["sep_conv_5x5", 3], ["sep_conv_5x5", 1], ["sep_conv_5x5", 4]], normal_concat=range(2, 6), reduce=[["sep_conv_5x5", 0], ["sep_conv_5x5", 1], ["sep_conv_5x5", 0], ["sep_conv_5x5", 1], ["sep_conv_5x5", 0], ["sep_conv_5x5", 1], ["sep_conv_5x5", 0], ["sep_conv_5x5", 1]], reduce_concat=range(2, 6))
|
||||
random_max_3 = Genotype(normal=[["sep_conv_5x5", 0], ["sep_conv_5x5", 1], ["sep_conv_5x5", 0], ["sep_conv_5x5", 1], ["sep_conv_5x5", 0], ["sep_conv_5x5", 2], ["sep_conv_5x5", 1], ["sep_conv_5x5", 4]], normal_concat=range(2, 6), reduce=[["sep_conv_5x5", 0], ["sep_conv_5x5", 1], ["sep_conv_5x5", 0], ["sep_conv_5x5", 1], ["sep_conv_5x5", 0], ["sep_conv_5x5", 1], ["sep_conv_5x5", 0], ["sep_conv_5x5", 1]], reduce_concat=range(2, 6))
|
||||
|
||||
#### ImageNet-1k
|
||||
init_pt_s5_in_0_100_N10=Genotype(normal=[["sep_conv_5x5", 0], ["sep_conv_5x5", 1], ["sep_conv_3x3", 1], ["sep_conv_3x3", 2], ["sep_conv_5x5", 2], ["sep_conv_5x5", 3], ["sep_conv_3x3", 0], ["sep_conv_5x5", 3]], normal_concat=range(2, 6), reduce=[["sep_conv_3x3", 0], ["sep_conv_5x5", 1], ["sep_conv_5x5", 0], ["dil_conv_5x5", 1], ["sep_conv_3x3", 0], ["sep_conv_5x5", 1], ["skip_connect", 0], ["sep_conv_5x5", 1]], reduce_concat=range(2, 6))
|
||||
init_pt_s5_in_1_100_N10=Genotype(normal=[["skip_connect", 0], ["sep_conv_3x3", 1], ["sep_conv_5x5", 0], ["sep_conv_5x5", 1], ["sep_conv_5x5", 0], ["sep_conv_3x3", 3], ["sep_conv_3x3", 3], ["sep_conv_5x5", 4]], normal_concat=range(2, 6), reduce=[["avg_pool_3x3", 0], ["sep_conv_3x3", 1], ["sep_conv_5x5", 0], ["sep_conv_5x5", 1], ["sep_conv_5x5", 0], ["sep_conv_3x3", 1], ["sep_conv_5x5", 0], ["sep_conv_5x5", 1]], reduce_concat=range(2, 6))
|
||||
init_pt_s5_in_2_100_N10=Genotype(normal=[["sep_conv_3x3", 0], ["sep_conv_5x5", 1], ["sep_conv_3x3", 0], ["sep_conv_5x5", 1], ["sep_conv_3x3", 0], ["sep_conv_3x3", 3], ["sep_conv_3x3", 3], ["sep_conv_5x5", 4]], normal_concat=range(2, 6), reduce=[["sep_conv_5x5", 0], ["sep_conv_5x5", 1], ["sep_conv_5x5", 0], ["sep_conv_5x5", 1], ["skip_connect", 0], ["dil_conv_5x5", 1], ["sep_conv_3x3", 0], ["sep_conv_5x5", 1]], reduce_concat=range(2, 6))
|
||||
init_pt_s5_in_3_100_N10=Genotype(normal=[["sep_conv_5x5", 0], ["sep_conv_3x3", 1], ["sep_conv_5x5", 0], ["sep_conv_3x3", 2], ["sep_conv_5x5", 2], ["sep_conv_3x3", 3], ["sep_conv_5x5", 0], ["sep_conv_5x5", 3]], normal_concat=range(2, 6), reduce=[["dil_conv_3x3", 0], ["skip_connect", 1], ["sep_conv_5x5", 0], ["sep_conv_5x5", 1], ["sep_conv_5x5", 0], ["sep_conv_5x5", 1], ["sep_conv_5x5", 0], ["sep_conv_3x3", 1]], reduce_concat=range(2, 6))
|
||||
|
||||
|
||||
####N1
|
||||
init_pt_s5_C10_0_N1 = Genotype(normal=[["sep_conv_5x5", 0], ["sep_conv_5x5", 1], ["sep_conv_3x3", 1], ["sep_conv_5x5", 2], ["sep_conv_3x3", 0], ["sep_conv_5x5", 1], ["sep_conv_3x3", 0], ["sep_conv_3x3", 1]], normal_concat=range(2, 6), reduce=[["skip_connect", 0], ["sep_conv_5x5", 1], ["sep_conv_5x5", 0], ["sep_conv_3x3", 1], ["sep_conv_5x5", 0], ["sep_conv_5x5", 1], ["sep_conv_5x5", 0], ["sep_conv_3x3", 1]], reduce_concat=range(2, 6))
|
||||
init_pt_s5_C10_1_N1 = Genotype(normal=[["sep_conv_5x5", 0], ["sep_conv_5x5", 1], ["sep_conv_5x5", 1], ["sep_conv_5x5", 2], ["sep_conv_3x3", 1], ["sep_conv_3x3", 2], ["sep_conv_3x3", 1], ["sep_conv_5x5", 2]], normal_concat=range(2, 6), reduce=[["skip_connect", 0], ["sep_conv_5x5", 1], ["skip_connect", 0], ["sep_conv_5x5", 1], ["sep_conv_5x5", 0], ["sep_conv_3x3", 1], ["sep_conv_5x5", 0], ["sep_conv_3x3", 1]], reduce_concat=range(2, 6))
|
||||
init_pt_s5_C10_2_N1 = Genotype(normal=[["skip_connect", 0], ["sep_conv_5x5", 1], ["sep_conv_5x5", 0], ["sep_conv_5x5", 2], ["sep_conv_5x5", 0], ["sep_conv_5x5", 2], ["sep_conv_3x3", 0], ["sep_conv_5x5", 3]], normal_concat=range(2, 6), reduce=[["sep_conv_5x5", 0], ["sep_conv_3x3", 1], ["sep_conv_3x3", 0], ["skip_connect", 1], ["sep_conv_5x5", 0], ["sep_conv_5x5", 1], ["sep_conv_3x3", 0], ["sep_conv_5x5", 1]], reduce_concat=range(2, 6))
|
||||
init_pt_s5_C10_3_N1 = Genotype(normal=[["skip_connect", 0], ["sep_conv_5x5", 1], ["sep_conv_3x3", 0], ["sep_conv_5x5", 1], ["sep_conv_3x3", 0], ["sep_conv_5x5", 3], ["sep_conv_3x3", 0], ["sep_conv_3x3", 1]], normal_concat=range(2, 6), reduce=[["sep_conv_5x5", 0], ["sep_conv_5x5", 1], ["sep_conv_3x3", 0], ["sep_conv_5x5", 1], ["sep_conv_5x5", 0], ["sep_conv_3x3", 1], ["sep_conv_3x3", 0], ["sep_conv_5x5", 1]], reduce_concat=range(2, 6))
|
||||
|
||||
####N5
|
||||
|
||||
#####V1
|
||||
init_pt_s5_C10_0_1_N5 = Genotype(normal=[["sep_conv_5x5", 0], ["sep_conv_5x5", 1], ["sep_conv_3x3", 1], ["sep_conv_5x5", 2], ["sep_conv_3x3", 0], ["sep_conv_5x5", 1], ["sep_conv_3x3", 0], ["sep_conv_3x3", 1]], normal_concat=range(2, 6), reduce=[["skip_connect", 0], ["sep_conv_5x5", 1], ["sep_conv_5x5", 0], ["sep_conv_3x3", 1], ["sep_conv_5x5", 0], ["sep_conv_5x5", 1], ["sep_conv_5x5", 0], ["sep_conv_3x3", 1]], reduce_concat=range(2, 6))
|
||||
init_pt_s5_C10_1_1_N5 = Genotype(normal=[["skip_connect", 0], ["sep_conv_5x5", 1], ["sep_conv_5x5", 1], ["sep_conv_5x5", 2], ["sep_conv_5x5", 1], ["sep_conv_5x5", 2], ["sep_conv_3x3", 1], ["sep_conv_5x5", 4]], normal_concat=range(2, 6), reduce=[["sep_conv_3x3", 0], ["sep_conv_5x5", 1], ["sep_conv_3x3", 0], ["sep_conv_5x5", 1], ["sep_conv_3x3", 0], ["sep_conv_3x3", 3], ["sep_conv_5x5", 0], ["sep_conv_3x3", 1]], reduce_concat=range(2, 6))
|
||||
init_pt_s5_C10_2_1_N5 = Genotype(normal=[["sep_conv_5x5", 0], ["sep_conv_3x3", 1], ["sep_conv_5x5", 1], ["sep_conv_3x3", 2], ["sep_conv_5x5", 2], ["sep_conv_3x3", 3], ["sep_conv_3x3", 0], ["sep_conv_3x3", 3]], normal_concat=range(2, 6), reduce=[["sep_conv_5x5", 0], ["skip_connect", 1], ["sep_conv_3x3", 0], ["dil_conv_3x3", 1], ["sep_conv_5x5", 0], ["sep_conv_5x5", 1], ["sep_conv_3x3", 0], ["sep_conv_3x3", 1]], reduce_concat=range(2, 6))
|
||||
init_pt_s5_C10_3_1_N5 = Genotype(normal=[["skip_connect", 0], ["sep_conv_5x5", 1], ["sep_conv_3x3", 0], ["sep_conv_5x5", 1], ["sep_conv_3x3", 0], ["sep_conv_5x5", 3], ["sep_conv_3x3", 0], ["sep_conv_3x3", 1]], normal_concat=range(2, 6), reduce=[["sep_conv_5x5", 0], ["sep_conv_5x5", 1], ["sep_conv_3x3", 0], ["sep_conv_5x5", 1], ["sep_conv_5x5", 0], ["sep_conv_3x3", 1], ["sep_conv_3x3", 0], ["sep_conv_5x5", 1]], reduce_concat=range(2, 6))
|
||||
|
||||
#####V10
|
||||
init_pt_s5_C10_0_10_N5 = Genotype(normal=[["sep_conv_5x5", 0], ["sep_conv_5x5", 1], ["sep_conv_3x3", 1], ["sep_conv_5x5", 2], ["sep_conv_3x3", 0], ["sep_conv_3x3", 2], ["sep_conv_5x5", 1], ["sep_conv_5x5", 2]], normal_concat=range(2, 6), reduce=[["sep_conv_3x3", 0], ["sep_conv_3x3", 1], ["sep_conv_5x5", 0], ["sep_conv_3x3", 1], ["sep_conv_3x3", 0], ["sep_conv_5x5", 1], ["sep_conv_3x3", 0], ["sep_conv_5x5", 1]], reduce_concat=range(2, 6))
|
||||
init_pt_s5_C10_1_10_N5 = Genotype(normal=[["sep_conv_5x5", 0], ["sep_conv_5x5", 1], ["sep_conv_3x3", 1], ["sep_conv_5x5", 2], ["sep_conv_3x3", 1], ["sep_conv_3x3", 2], ["sep_conv_3x3", 1], ["sep_conv_5x5", 2]], normal_concat=range(2, 6), reduce=[["sep_conv_3x3", 0], ["sep_conv_3x3", 1], ["sep_conv_3x3", 0], ["sep_conv_3x3", 1], ["sep_conv_3x3", 1], ["sep_conv_5x5", 3], ["sep_conv_5x5", 1], ["sep_conv_5x5", 4]], reduce_concat=range(2, 6))
|
||||
init_pt_s5_C10_2_10_N5 = Genotype(normal=[["sep_conv_5x5", 0], ["sep_conv_5x5", 1], ["sep_conv_5x5", 1], ["sep_conv_3x3", 2], ["sep_conv_5x5", 2], ["sep_conv_3x3", 3], ["sep_conv_3x3", 0], ["sep_conv_3x3", 1]], normal_concat=range(2, 6), reduce=[["sep_conv_3x3", 0], ["sep_conv_5x5", 1], ["sep_conv_3x3", 1], ["sep_conv_5x5", 2], ["sep_conv_5x5", 0], ["sep_conv_3x3", 1], ["sep_conv_3x3", 0], ["skip_connect", 1]], reduce_concat=range(2, 6))
|
||||
init_pt_s5_C10_3_10_N5 = Genotype(normal=[["skip_connect", 0], ["sep_conv_5x5", 1], ["sep_conv_3x3", 0], ["sep_conv_5x5", 1], ["sep_conv_3x3", 0], ["sep_conv_5x5", 3], ["sep_conv_3x3", 0], ["sep_conv_3x3", 1]], normal_concat=range(2, 6), reduce=[["sep_conv_5x5", 0], ["sep_conv_5x5", 1], ["sep_conv_3x3", 0], ["sep_conv_5x5", 1], ["sep_conv_5x5", 0], ["sep_conv_3x3", 1], ["sep_conv_3x3", 0], ["sep_conv_5x5", 1]], reduce_concat=range(2, 6))
|
||||
|
||||
#####V100
|
||||
init_pt_s5_C10_0_100_N5 = Genotype(normal=[["sep_conv_5x5", 0], ["sep_conv_5x5", 1], ["sep_conv_3x3", 1], ["sep_conv_5x5", 2], ["sep_conv_5x5", 1], ["sep_conv_3x3", 2], ["sep_conv_5x5", 2], ["sep_conv_3x3", 4]], normal_concat=range(2, 6), reduce=[["sep_conv_5x5", 0], ["sep_conv_3x3", 1], ["sep_conv_5x5", 0], ["sep_conv_3x3", 1], ["sep_conv_5x5", 1], ["sep_conv_3x3", 3], ["sep_conv_3x3", 0], ["sep_conv_5x5", 1]], reduce_concat=range(2, 6))
|
||||
init_pt_s5_C10_1_100_N5 = Genotype(normal=[["skip_connect", 0], ["sep_conv_5x5", 1], ["sep_conv_5x5", 1], ["sep_conv_5x5", 2], ["sep_conv_5x5", 1], ["sep_conv_5x5", 2], ["sep_conv_3x3", 1], ["sep_conv_5x5", 4]], normal_concat=range(2, 6), reduce=[["sep_conv_3x3", 0], ["sep_conv_5x5", 1], ["sep_conv_3x3", 0], ["sep_conv_5x5", 1], ["sep_conv_3x3", 0], ["sep_conv_3x3", 3], ["sep_conv_5x5", 0], ["sep_conv_3x3", 1]], reduce_concat=range(2, 6))
|
||||
init_pt_s5_C10_2_100_N5 = Genotype(normal=[["sep_conv_5x5", 0], ["sep_conv_5x5", 1], ["sep_conv_5x5", 0], ["sep_conv_5x5", 1], ["sep_conv_5x5", 2], ["sep_conv_3x3", 3], ["sep_conv_3x3", 0], ["sep_conv_3x3", 1]], normal_concat=range(2, 6), reduce=[["sep_conv_3x3", 0], ["sep_conv_3x3", 1], ["sep_conv_5x5", 1], ["sep_conv_3x3", 2], ["sep_conv_3x3", 1], ["sep_conv_5x5", 3], ["sep_conv_3x3", 0], ["sep_conv_5x5", 1]], reduce_concat=range(2, 6))
|
||||
init_pt_s5_C10_3_100_N5 = Genotype(normal=[["skip_connect", 0], ["sep_conv_5x5", 1], ["sep_conv_3x3", 0], ["sep_conv_5x5", 2], ["sep_conv_3x3", 0], ["sep_conv_3x3", 2], ["sep_conv_3x3", 0], ["sep_conv_3x3", 2]], normal_concat=range(2, 6), reduce=[["sep_conv_5x5", 0], ["sep_conv_5x5", 1], ["sep_conv_3x3", 0], ["sep_conv_5x5", 1], ["sep_conv_5x5", 0], ["sep_conv_3x3", 1], ["sep_conv_3x3", 0], ["sep_conv_5x5", 1]], reduce_concat=range(2, 6))
|
||||
|
||||
####N10
|
||||
|
||||
#####V1
|
||||
init_pt_s5_C10_0_1_N10 = Genotype(normal=[["sep_conv_5x5", 0], ["sep_conv_5x5", 1], ["sep_conv_3x3", 1], ["sep_conv_5x5", 2], ["sep_conv_3x3", 0], ["sep_conv_5x5", 1], ["sep_conv_5x5", 2], ["sep_conv_3x3", 4]], normal_concat=range(2, 6), reduce=[["sep_conv_5x5", 0], ["sep_conv_3x3", 1], ["sep_conv_5x5", 0], ["sep_conv_3x3", 1], ["sep_conv_5x5", 0], ["sep_conv_5x5", 1], ["sep_conv_3x3", 0], ["sep_conv_5x5", 1]], reduce_concat=range(2, 6))
|
||||
init_pt_s5_C10_1_1_N10 = Genotype(normal=[["sep_conv_5x5", 0], ["sep_conv_5x5", 1], ["sep_conv_3x3", 1], ["sep_conv_5x5", 2], ["sep_conv_3x3", 1], ["sep_conv_5x5", 2], ["sep_conv_3x3", 0], ["sep_conv_3x3", 1]], normal_concat=range(2, 6), reduce=[["sep_conv_5x5", 0], ["sep_conv_3x3", 1], ["sep_conv_5x5", 0], ["sep_conv_5x5", 1], ["sep_conv_5x5", 0], ["sep_conv_3x3", 1], ["sep_conv_3x3", 0], ["sep_conv_5x5", 1]], reduce_concat=range(2, 6))
|
||||
init_pt_s5_C10_2_1_N10 = Genotype(normal=[["sep_conv_5x5", 0], ["sep_conv_3x3", 1], ["sep_conv_5x5", 1], ["sep_conv_3x3", 2], ["sep_conv_5x5", 2], ["sep_conv_3x3", 3], ["sep_conv_3x3", 0], ["sep_conv_3x3", 3]], normal_concat=range(2, 6), reduce=[["sep_conv_5x5", 0], ["skip_connect", 1], ["sep_conv_3x3", 0], ["dil_conv_3x3", 1], ["sep_conv_5x5", 0], ["sep_conv_5x5", 1], ["sep_conv_3x3", 0], ["sep_conv_3x3", 1]], reduce_concat=range(2, 6))
|
||||
init_pt_s5_C10_3_1_N10 = Genotype(normal=[["sep_conv_5x5", 0], ["sep_conv_3x3", 1], ["sep_conv_3x3", 0], ["sep_conv_5x5", 1], ["sep_conv_3x3", 0], ["sep_conv_5x5", 1], ["sep_conv_3x3", 3], ["sep_conv_5x5", 4]], normal_concat=range(2, 6), reduce=[["sep_conv_5x5", 0], ["sep_conv_5x5", 1], ["sep_conv_3x3", 0], ["sep_conv_5x5", 1], ["sep_conv_5x5", 0], ["sep_conv_3x3", 1], ["sep_conv_5x5", 0], ["sep_conv_5x5", 1]], reduce_concat=range(2, 6))
|
||||
|
||||
#####V10
|
||||
init_pt_s5_C10_0_10_N10 = Genotype(normal=[["sep_conv_5x5", 0], ["sep_conv_5x5", 1], ["sep_conv_3x3", 1], ["sep_conv_5x5", 2], ["sep_conv_3x3", 0], ["sep_conv_5x5", 1], ["sep_conv_5x5", 2], ["sep_conv_3x3", 4]], normal_concat=range(2, 6), reduce=[["sep_conv_5x5", 0], ["sep_conv_3x3", 1], ["sep_conv_5x5", 0], ["sep_conv_3x3", 1], ["sep_conv_5x5", 0], ["sep_conv_5x5", 1], ["sep_conv_3x3", 0], ["sep_conv_5x5", 1]], reduce_concat=range(2, 6))
|
||||
init_pt_s5_C10_1_10_N10 = Genotype(normal=[["sep_conv_5x5", 0], ["sep_conv_5x5", 1], ["sep_conv_3x3", 1], ["sep_conv_5x5", 2], ["sep_conv_3x3", 1], ["sep_conv_5x5", 2], ["sep_conv_3x3", 0], ["sep_conv_3x3", 1]], normal_concat=range(2, 6), reduce=[["sep_conv_5x5", 0], ["sep_conv_3x3", 1], ["sep_conv_5x5", 0], ["sep_conv_5x5", 1], ["sep_conv_5x5", 0], ["sep_conv_3x3", 1], ["sep_conv_3x3", 0], ["sep_conv_5x5", 1]], reduce_concat=range(2, 6))
|
||||
init_pt_s5_C10_2_10_N10 = Genotype(normal=[["sep_conv_5x5", 0], ["sep_conv_5x5", 1], ["sep_conv_3x3", 0], ["sep_conv_3x3", 2], ["sep_conv_5x5", 0], ["sep_conv_5x5", 2], ["sep_conv_3x3", 0], ["sep_conv_5x5", 1]], normal_concat=range(2, 6), reduce=[["sep_conv_3x3", 0], ["sep_conv_5x5", 1], ["skip_connect", 0], ["skip_connect", 1], ["sep_conv_5x5", 0], ["sep_conv_3x3", 1], ["sep_conv_3x3", 0], ["sep_conv_3x3", 1]], reduce_concat=range(2, 6))
|
||||
init_pt_s5_C10_3_10_N10 = Genotype(normal=[["sep_conv_5x5", 0], ["sep_conv_5x5", 1], ["sep_conv_3x3", 0], ["sep_conv_5x5", 1], ["sep_conv_5x5", 1], ["sep_conv_3x3", 3], ["sep_conv_5x5", 0], ["sep_conv_3x3", 1]], normal_concat=range(2, 6), reduce=[["dil_conv_3x3", 0], ["sep_conv_3x3", 1], ["sep_conv_3x3", 0], ["sep_conv_5x5", 1], ["sep_conv_5x5", 0], ["sep_conv_3x3", 1], ["sep_conv_3x3", 0], ["sep_conv_5x5", 1]], reduce_concat=range(2, 6))
|
||||
|
||||
#fisher
|
||||
cf10_fisher = Genotype(normal=[["avg_pool_3x3", 0], ["avg_pool_3x3", 1], ["avg_pool_3x3", 0], ["dil_conv_3x3", 1],["avg_pool_3x3", 0], ["skip_connect", 2],["sep_conv_5x5", 0], ["dil_conv_3x3", 3]], normal_concat=range(2, 6), reduce= [["sep_conv_3x3", 0], ["sep_conv_5x5", 1], ["max_pool_3x3", 0], ["max_pool_3x3", 2], ["sep_conv_3x3", 0], ["dil_conv_5x5", 3], ["sep_conv_5x5", 0], ["sep_conv_3x3", 2]], reduce_concat=range(2, 6))
|
||||
#grasp
|
||||
cf10_grasp = Genotype(normal=[["avg_pool_3x3", 0], ["avg_pool_3x3", 1], ["skip_connect", 0], ["sep_conv_5x5", 1], ["dil_conv_3x3", 1], ["sep_conv_3x3", 2], ["sep_conv_5x5", 2], ["sep_conv_3x3", 3]], normal_concat=range(2, 6), reduce= [["sep_conv_3x3", 0], ["skip_connect", 1], ["avg_pool_3x3", 0], ["skip_connect", 1], ["sep_conv_5x5", 0], ["skip_connect", 1], ["max_pool_3x3", 1], ["sep_conv_3x3", 3]], reduce_concat=range(2, 6))
|
||||
#jacob_cov
|
||||
cf10_jacob_cov = Genotype(normal=[["max_pool_3x3", 0], ["dil_conv_3x3", 1], ["dil_conv_3x3", 0], ["sep_conv_3x3", 2], ["dil_conv_3x3", 0], ["sep_conv_3x3", 3], ["sep_conv_5x5", 0], ["dil_conv_3x3", 3]], normal_concat=range(2, 6), reduce= [["sep_conv_3x3", 0], ["max_pool_3x3", 1], ["max_pool_3x3", 0], ["avg_pool_3x3", 1], ["sep_conv_3x3", 1], ["sep_conv_5x5", 3], ["dil_conv_3x3", 1], ["sep_conv_3x3", 4]], reduce_concat=range(2, 6))
|
||||
#meco
|
||||
cf10_meco = Genotype(normal=[["dil_conv_5x5", 0], ["sep_conv_5x5", 1], ["sep_conv_5x5", 0], ["skip_connect", 1], ["dil_conv_5x5", 0], ["sep_conv_3x3", 1], ["sep_conv_3x3", 2], ["sep_conv_5x5", 3]], normal_concat=range(2, 6), reduce= [["dil_conv_5x5", 0], ["sep_conv_5x5", 1], ["sep_conv_5x5", 0], ["sep_conv_5x5", 1], ["dil_conv_5x5", 0], ["dil_conv_5x5", 1], ["dil_conv_5x5", 1], ["sep_conv_3x3", 4]], reduce_concat=range(2, 6))
|
||||
#synflow
|
||||
cf10_synflow = Genotype(normal= [["sep_conv_5x5", 0], ["sep_conv_5x5", 1], ["sep_conv_5x5", 1], ["sep_conv_5x5", 2], ["sep_conv_5x5", 0], ["sep_conv_5x5", 1], ["sep_conv_5x5", 0], ["sep_conv_5x5", 1]], normal_concat=range(2, 6), reduce= [["sep_conv_5x5", 0], ["sep_conv_5x5", 1], ["sep_conv_5x5", 0], ["sep_conv_5x5", 1], ["sep_conv_5x5", 0], ["sep_conv_5x5", 1], ["sep_conv_5x5", 0], ["sep_conv_5x5", 1]], reduce_concat=range(2, 6))
|
||||
#zico
|
||||
cf10_zico= Genotype(normal= [["sep_conv_3x3", 0], ["sep_conv_3x3", 1], ["sep_conv_3x3", 1], ["sep_conv_5x5", 2], ["sep_conv_3x3", 0], ["sep_conv_5x5", 1], ["sep_conv_5x5", 0], ["sep_conv_3x3", 1]], normal_concat=range(2, 6), reduce= [["sep_conv_5x5", 0], ["sep_conv_5x5", 1], ["sep_conv_3x3", 0], ["sep_conv_3x3", 1], ["sep_conv_5x5", 0], ["sep_conv_5x5", 1], ["skip_connect", 0], ["sep_conv_3x3", 2]], reduce_concat=range(2, 6))
|
||||
#snip
|
||||
cf10_snip = Genotype(normal= [["sep_conv_3x3", 0], ["avg_pool_3x3", 1], ["dil_conv_5x5", 0], ["sep_conv_5x5", 1], ["dil_conv_3x3", 1], ["sep_conv_3x3", 3], ["sep_conv_3x3", 2], ["sep_conv_5x5", 3]], normal_concat=range(2, 6), reduce= [["sep_conv_5x5", 0], ["sep_conv_5x5", 1], ["avg_pool_3x3", 0], ["sep_conv_5x5", 1], ["sep_conv_3x3", 0], ["skip_connect", 1], ["dil_conv_3x3", 0], ["sep_conv_3x3", 4]], reduce_concat=range(2, 6))
|
||||
|
||||
|
||||
#fisher
|
||||
cf100_fisher = Genotype(normal= [["sep_conv_3x3", 0], ["max_pool_3x3", 1], ["sep_conv_5x5", 0], ["max_pool_3x3", 1], ["dil_conv_3x3", 1], ["skip_connect", 3], ["dil_conv_5x5", 0], ["skip_connect", 1]], normal_concat=range(2, 6), reduce= [["sep_conv_5x5", 0], ["sep_conv_3x3", 1], ["sep_conv_3x3", 1], ["dil_conv_3x3", 2], ["sep_conv_5x5", 0], ["sep_conv_3x3", 1], ["max_pool_3x3", 1], ["sep_conv_3x3", 4]] , reduce_concat=range(2, 6))
|
||||
#grasp
|
||||
cf100_grasp= Genotype(normal= [["max_pool_3x3", 0], ["avg_pool_3x3", 1], ["avg_pool_3x3", 0], ["sep_conv_3x3", 1], ["sep_conv_5x5", 2], ["sep_conv_3x3", 3], ["avg_pool_3x3", 0], ["sep_conv_3x3", 4]] , normal_concat=range(2, 6), reduce= [["max_pool_3x3", 0], ["sep_conv_3x3", 1], ["dil_conv_3x3", 0], ["dil_conv_3x3", 2], ["skip_connect", 0], ["dil_conv_3x3", 1], ["dil_conv_3x3", 1], ["sep_conv_3x3", 2]] , reduce_concat=range(2, 6))
|
||||
#jacob_cov
|
||||
cf100_jacob_cov = Genotype(normal= [["max_pool_3x3", 0], ["avg_pool_3x3", 1], ["dil_conv_3x3", 0], ["dil_conv_5x5", 1], ["avg_pool_3x3", 0], ["avg_pool_3x3", 3], ["dil_conv_5x5", 1], ["dil_conv_5x5", 4]], normal_concat=range(2, 6), reduce= [["skip_connect", 0], ["sep_conv_5x5", 1], ["avg_pool_3x3", 0], ["skip_connect", 2], ["sep_conv_5x5", 0], ["sep_conv_3x3", 1], ["dil_conv_3x3", 0], ["dil_conv_5x5", 1]] , reduce_concat=range(2, 6))
|
||||
#meco
|
||||
cf100_meco = Genotype(normal= [["dil_conv_3x3", 0], ["sep_conv_3x3", 1], ["sep_conv_5x5", 0], ["dil_conv_5x5", 2], ["sep_conv_5x5", 2], ["sep_conv_3x3", 3], ["dil_conv_5x5", 0], ["sep_conv_3x3", 2]], normal_concat=range(2, 6), reduce= [["avg_pool_3x3", 0], ["sep_conv_3x3", 1], ["sep_conv_5x5", 0], ["dil_conv_5x5", 2], ["sep_conv_3x3", 0], ["sep_conv_3x3", 3], ["dil_conv_3x3", 0], ["sep_conv_3x3", 1]] , reduce_concat=range(2, 6))
|
||||
#snip
|
||||
cf100_snip = Genotype(normal= [["sep_conv_5x5", 0], ["skip_connect", 1], ["sep_conv_3x3", 1], ["sep_conv_5x5", 2], ["skip_connect", 0], ["sep_conv_3x3", 2], ["dil_conv_3x3", 0], ["max_pool_3x3", 3]], normal_concat=range(2, 6), reduce= [["dil_conv_3x3", 0], ["sep_conv_5x5", 1], ["sep_conv_3x3", 0], ["skip_connect", 2], ["skip_connect", 0], ["skip_connect", 2], ["dil_conv_5x5", 1], ["sep_conv_5x5", 2]] , reduce_concat=range(2, 6))
|
||||
#synflow
|
||||
cf100_synflow = Genotype(normal= [["sep_conv_5x5", 0], ["sep_conv_5x5", 1], ["sep_conv_5x5", 1], ["sep_conv_5x5", 2], ["sep_conv_5x5", 0], ["sep_conv_5x5", 1], ["sep_conv_5x5", 0], ["sep_conv_5x5", 1]] , normal_concat=range(2, 6), reduce= [["sep_conv_5x5", 0], ["sep_conv_5x5", 1], ["sep_conv_5x5", 0], ["sep_conv_5x5", 1], ["sep_conv_5x5", 0], ["sep_conv_5x5", 1], ["sep_conv_5x5", 0], ["sep_conv_5x5", 1]] , reduce_concat=range(2, 6))
|
||||
#zico
|
||||
cf100_zico = Genotype(normal= [["sep_conv_5x5", 0], ["sep_conv_5x5", 1], ["sep_conv_3x3", 0], ["sep_conv_3x3", 1], ["sep_conv_5x5", 0], ["sep_conv_5x5", 1], ["sep_conv_3x3", 0], ["sep_conv_3x3", 3]], normal_concat=range(2, 6), reduce= [["sep_conv_5x5", 0], ["sep_conv_3x3", 1], ["sep_conv_5x5", 0], ["dil_conv_5x5", 1], ["sep_conv_5x5", 0], ["sep_conv_3x3", 2], ["sep_conv_3x3", 0], ["sep_conv_3x3", 1]] , reduce_concat=range(2, 6))
|
||||
|
||||
40
sota/cnn/hdf5.py
Normal file
@@ -0,0 +1,40 @@
|
||||
import h5py
|
||||
import numpy as np
|
||||
from PIL import Image
|
||||
|
||||
|
||||
import torch
|
||||
from torch.utils.data import Dataset, DataLoader
|
||||
|
||||
class H5Dataset(Dataset):
|
||||
def __init__(self, h5_path, transform=None):
|
||||
self.h5_path = h5_path
|
||||
self.h5_file = None
|
||||
self.length = len(h5py.File(h5_path, 'r'))
|
||||
self.transform = transform
|
||||
|
||||
def __getitem__(self, index):
|
||||
|
||||
# Opening the file lazily in __getitem__ lets us use multiple worker processes for data loading,
|
||||
# because h5py file handles aren't picklable and therefore can't be shared across processes.
|
||||
# https://discuss.pytorch.org/t/hdf5-a-data-format-for-pytorch/40379
|
||||
# https://discuss.pytorch.org/t/dataloader-when-num-worker-0-there-is-bug/25643/16
|
||||
# TODO: possibly look at __getstate__ and __setstate__ as a more elegant solution
|
||||
if self.h5_file is None:
|
||||
self.h5_file = h5py.File(self.h5_path, 'r', libver="latest", swmr=True)
|
||||
|
||||
record = self.h5_file[str(index)]
|
||||
|
||||
if self.transform:
|
||||
x = Image.fromarray(record['data'][()])
|
||||
x = self.transform(x)
|
||||
else:
|
||||
x = torch.from_numpy(record['data'][()])
|
||||
|
||||
y = record['target'][()]
|
||||
y = torch.from_numpy(np.asarray(y))
|
||||
|
||||
return (x,y)
|
||||
|
||||
def __len__(self):
|
||||
return self.length
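# Minimal usage sketch (illustrative; the file path and transform are assumptions):
#   dataset = H5Dataset('train.h5', transform=my_transform)
#   loader = DataLoader(dataset, batch_size=128, num_workers=4, shuffle=True)
# The lazy open in __getitem__ above is what makes num_workers > 0 safe here.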
|
||||
336
sota/cnn/init_projection.py
Normal file
@@ -0,0 +1,336 @@
|
||||
import sys
|
||||
sys.path.insert(0, '../../')
|
||||
import numpy as np
|
||||
import torch
|
||||
import logging
|
||||
import torch.utils
|
||||
from copy import deepcopy
|
||||
from foresight.pruners import *
|
||||
|
||||
torch.set_printoptions(precision=4, sci_mode=False)
|
||||
|
||||
def sample_op(model, input, target, args, cell_type, selected_eid=None):
|
||||
''' operation '''
|
||||
#### macros
|
||||
num_edges, num_ops = model.num_edges, model.num_ops
|
||||
candidate_flags = model.candidate_flags[cell_type]
|
||||
proj_crit = args.proj_crit[cell_type]
|
||||
|
||||
#### select an edge
|
||||
if selected_eid is None:
|
||||
remain_eids = torch.nonzero(candidate_flags).cpu().numpy().T[0]
|
||||
selected_eid = np.random.choice(remain_eids, size=1)[0]
|
||||
logging.info('selected edge: %d %s', selected_eid, cell_type)
|
||||
|
||||
select_opid = np.random.choice(np.array(range(num_ops)), size=1)[0]
|
||||
return selected_eid, select_opid
|
||||
|
||||
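# project_op: discretize one edge. Each candidate operation on the selected edge is
# masked out in turn, a zero-cost proxy (the Jacobian-covariance score or a foresight
# measure such as synflow/var/comb) is evaluated on the perturbed supernet, and the
# operation whose removal yields the lowest score (np.nanargmin below) -- i.e. the
# one the score depends on most -- is kept.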
def project_op(model, input, target, args, cell_type, proj_queue=None, selected_eid=None):
|
||||
''' operation '''
|
||||
#### macros
|
||||
num_edges, num_ops = model.num_edges, model.num_ops
|
||||
candidate_flags = model.candidate_flags[cell_type]
|
||||
proj_crit = args.proj_crit[cell_type]
|
||||
|
||||
#### select an edge
|
||||
if selected_eid is None:
|
||||
remain_eids = torch.nonzero(candidate_flags).cpu().numpy().T[0]
|
||||
# print(num_edges, num_ops, remain_eids)
|
||||
if args.edge_decision == "random":
|
||||
selected_eid = np.random.choice(remain_eids, size=1)[0]
|
||||
logging.info('selected edge: %d %s', selected_eid, cell_type)
|
||||
elif args.edge_decision == 'reverse':
|
||||
selected_eid = remain_eids[-1]
|
||||
logging.info('selected edge: %d %s', selected_eid, cell_type)
|
||||
else:
|
||||
selected_eid = remain_eids[0]
|
||||
logging.info('selected edge: %d %s', selected_eid, cell_type)
|
||||
|
||||
#### select the best operation
|
||||
if proj_crit == 'jacob':
|
||||
crit_idx = 3
|
||||
compare = lambda x, y: x < y
|
||||
else:
|
||||
crit_idx = 0
|
||||
compare = lambda x, y: x < y
|
||||
|
||||
if args.dataset == 'cifar100':
|
||||
n_classes = 100
|
||||
elif args.dataset == 'imagenet16-120':
|
||||
n_classes = 120
|
||||
else:
|
||||
n_classes = 10
|
||||
|
||||
best_opid = 0
|
||||
crit_extrema = None
|
||||
crit_list = []
|
||||
op_ids = []
|
||||
for opid in range(num_ops):
|
||||
## projection
|
||||
weights = model.get_projected_weights(cell_type)
|
||||
proj_mask = torch.ones_like(weights[selected_eid])
|
||||
proj_mask[opid] = 0
|
||||
weights[selected_eid] = weights[selected_eid] * proj_mask
|
||||
|
||||
# ## proj evaluation
|
||||
# with torch.no_grad():
|
||||
# valid_stats = Jocab_Score(model, cell_type, input, target, weights=weights)
|
||||
# crit = valid_stats
|
||||
# crit_list.append(crit)
|
||||
# if crit_extrema is None or compare(crit, crit_extrema):
|
||||
# crit_extrema = crit
|
||||
# best_opid = opid
|
||||
|
||||
## proj evaluation
|
||||
if proj_crit == 'jacob':
|
||||
crit = Jocab_Score(model,cell_type, input, target, weights=weights)
|
||||
else:
|
||||
cache_weight = model.proj_weights[cell_type][selected_eid]
|
||||
cache_flag = model.candidate_flags[cell_type][selected_eid]
|
||||
|
||||
for idx in range(num_ops):
|
||||
if idx == opid:
|
||||
model.proj_weights[cell_type][selected_eid][opid] = 0
|
||||
else:
|
||||
model.proj_weights[cell_type][selected_eid][idx] = 1.0 / num_ops
|
||||
|
||||
model.candidate_flags[cell_type][selected_eid] = False
|
||||
# print(model.get_projected_weights())
|
||||
if proj_crit == 'comb':
|
||||
synflow = predictive.find_measures(model,
|
||||
proj_queue,
|
||||
('random', 1, n_classes),
|
||||
torch.device("cuda"),
|
||||
measure_names=['synflow'])
|
||||
var = predictive.find_measures(model,
|
||||
proj_queue,
|
||||
('random', 1, n_classes),
|
||||
torch.device("cuda"),
|
||||
measure_names=['var'])
|
||||
# print(synflow, var)
|
||||
comb = np.log(synflow['synflow'] + 1) / (var['var'] + 0.1)
|
||||
measures = {'comb': comb}
|
||||
else:
|
||||
measures = predictive.find_measures(model,
|
||||
proj_queue,
|
||||
('random', 1, n_classes),
|
||||
torch.device("cuda"),
|
||||
measure_names=[proj_crit])
|
||||
|
||||
# print(measures)
|
||||
for idx in range(num_ops):
|
||||
model.proj_weights[cell_type][selected_eid][idx] = 0
|
||||
model.candidate_flags[cell_type][selected_eid] = cache_flag
|
||||
crit = measures[proj_crit]
|
||||
|
||||
crit_list.append(crit)
|
||||
op_ids.append(opid)
|
||||
|
||||
best_opid = op_ids[np.nanargmin(crit_list)]
|
||||
|
||||
|
||||
|
||||
#### project
|
||||
logging.info('best opid: %d', best_opid)
|
||||
logging.info(crit_list)
|
||||
return selected_eid, best_opid
|
||||
|
||||
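# project_global_op: greedy variant that scans every remaining edge and every candidate
# operation, scoring each masked-out (edge, op) pair with the Jacobian-covariance score
# and keeping the pair with the lowest score.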
def project_global_op(model, input, target, args, infer, cell_type, selected_eid=None):
|
||||
''' operation '''
|
||||
#### macros
|
||||
num_edges, num_ops = model.num_edges, model.num_ops
|
||||
candidate_flags = model.candidate_flags[cell_type]
|
||||
proj_crit = args.proj_crit[cell_type]
|
||||
|
||||
remain_eids = torch.nonzero(candidate_flags).cpu().numpy().T[0]
|
||||
|
||||
#### select the best operation
|
||||
if proj_crit == 'jacob':
|
||||
crit_idx = 3
|
||||
compare = lambda x, y: x < y
|
||||
|
||||
best_opid = 0
|
||||
crit_extrema = None
|
||||
best_eid = None
|
||||
for eid in remain_eids:
|
||||
for opid in range(num_ops):
|
||||
## projection
|
||||
weights = model.get_projected_weights(cell_type)
|
||||
proj_mask = torch.ones_like(weights[eid])
|
||||
proj_mask[opid] = 0
|
||||
weights[eid] = weights[eid] * proj_mask
|
||||
|
||||
## proj evaluation
|
||||
|
||||
#weights_dict = {cell_type:weights}
|
||||
with torch.no_grad():
|
||||
valid_stats = Jocab_Score(model, cell_type, input, target, weights=weights)
|
||||
crit = valid_stats
|
||||
if crit_extrema is None or compare(crit, crit_extrema):
|
||||
crit_extrema = crit
|
||||
best_opid = opid
|
||||
best_eid = eid
|
||||
|
||||
#### project
|
||||
logging.info('best opid: %d', best_opid)
|
||||
#logging.info(crit_list)
|
||||
return best_eid, best_opid
|
||||
|
||||
def sample_edge(model, input, target, args, cell_type, selected_eid=None):
|
||||
''' topology '''
|
||||
#### macros
|
||||
candidate_flags = model.candidate_flags_edge[cell_type]
|
||||
proj_crit = args.proj_crit[cell_type]
|
||||
|
||||
#### select a node
|
||||
remain_nids = torch.nonzero(candidate_flags).cpu().numpy().T[0]
|
||||
selected_nid = np.random.choice(remain_nids, size=1)[0]
|
||||
logging.info('selected node: %d %s', selected_nid, cell_type)
|
||||
|
||||
eids = deepcopy(model.nid2eids[selected_nid])
|
||||
|
||||
while len(eids) > 2:
|
||||
elected_eid = np.random.choice(eids, size=1)[0]
|
||||
eids.remove(elected_eid)
|
||||
|
||||
return selected_nid, eids
|
||||
|
||||
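# project_edge: discretize the topology of one intermediate node. Edges feeding the
# selected node are dropped one at a time -- always the edge whose removal leaves the
# highest score, i.e. the least important one -- until only two input edges remain.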
def project_edge(model, input, target, args, cell_type):
|
||||
''' topology '''
|
||||
#### macros
|
||||
candidate_flags = model.candidate_flags_edge[cell_type]
|
||||
proj_crit = args.proj_crit[cell_type]
|
||||
|
||||
#### select a node
|
||||
remain_nids = torch.nonzero(candidate_flags).cpu().numpy().T[0]
|
||||
if args.edge_decision == "random":
|
||||
selected_nid = np.random.choice(remain_nids, size=1)[0]
|
||||
logging.info('selected node: %d %s', selected_nid, cell_type)
|
||||
elif args.edge_decision == 'reverse':
|
||||
selected_nid = remain_nids[-1]
|
||||
logging.info('selected node: %d %s', selected_nid, cell_type)
|
||||
else:
|
||||
selected_nid = np.random.choice(remain_nids, size=1)[0]
|
||||
logging.info('selected node: %d %s', selected_nid, cell_type)
|
||||
|
||||
#### select top2 edges
|
||||
if proj_crit == 'jacob':
|
||||
crit_idx = 3
|
||||
compare = lambda x, y: x < y
|
||||
else:
|
||||
crit_idx = 3
|
||||
compare = lambda x, y: x < y
|
||||
|
||||
eids = deepcopy(model.nid2eids[selected_nid])
|
||||
crit_list = []
|
||||
while len(eids) > 2:
|
||||
eid_todel = None
|
||||
crit_extrema = None
|
||||
for eid in eids:
|
||||
weights = model.get_projected_weights(cell_type)
|
||||
weights[eid].data.fill_(0)
|
||||
|
||||
## proj evaluation
|
||||
with torch.no_grad():
|
||||
valid_stats = Jocab_Score(model, cell_type, input, target, weights=weights)
|
||||
crit = valid_stats
|
||||
|
||||
crit_list.append(crit)
|
||||
if crit_extrema is None or not compare(crit, crit_extrema): # the edge whose removal hurts the score the least is the "bad" edge to delete
|
||||
crit_extrema = crit
|
||||
eid_todel = eid
|
||||
|
||||
eids.remove(eid_todel)
|
||||
|
||||
#### project
|
||||
logging.info('top2 edges: (%d, %d)', eids[0], eids[1])
|
||||
#logging.info(crit_list)
|
||||
return selected_nid, eids
|
||||
|
||||
|
||||
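# pt_project: one projection decision per batch from the queue. The first
# model.num_edges iterations fix an operation on one edge of both the normal and the
# reduce cell; the remaining iterations prune each intermediate node's inputs down to
# two edges; the loop stops after num_projs decisions.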
def pt_project(train_queue, model, args):
|
||||
model.eval()
|
||||
|
||||
#### macros
|
||||
num_projs = model.num_edges + len(model.nid2eids.keys())
|
||||
args.proj_crit = {'normal':args.proj_crit_normal, 'reduce':args.proj_crit_reduce}
|
||||
proj_queue = train_queue
|
||||
|
||||
epoch = 0
|
||||
for step, (input, target) in enumerate(proj_queue):
|
||||
if epoch < model.num_edges:
|
||||
logging.info('project op')
|
||||
|
||||
if args.edge_decision == 'global_op_greedy':
|
||||
selected_eid_normal, best_opid_normal = project_global_op(model, input, target, args, cell_type='normal')
|
||||
elif args.edge_decision == 'sample':
|
||||
selected_eid_normal, best_opid_normal = sample_op(model, input, target, args, cell_type='normal')
|
||||
else:
|
||||
selected_eid_normal, best_opid_normal = project_op(model, input, target, args, proj_queue=proj_queue, cell_type='normal')
|
||||
model.project_op(selected_eid_normal, best_opid_normal, cell_type='normal')
|
||||
if args.edge_decision == 'global_op_greedy':
|
||||
selected_eid_reduce, best_opid_reduce = project_global_op(model, input, target, args, cell_type='reduce')
|
||||
elif args.edge_decision == 'sample':
|
||||
selected_eid_reduce, best_opid_reduce = sample_op(model, input, target, args, cell_type='reduce')
|
||||
else:
|
||||
selected_eid_reduce, best_opid_reduce = project_op(model, input, target, args, proj_queue=proj_queue, cell_type='reduce')
|
||||
model.project_op(selected_eid_reduce, best_opid_reduce, cell_type='reduce')
|
||||
|
||||
else:
|
||||
logging.info('project edge')
|
||||
if args.edge_decision == 'sample':
|
||||
selected_nid_normal, eids_normal = sample_edge(model, input, target, args, cell_type='normal')
|
||||
model.project_edge(selected_nid_normal, eids_normal, cell_type='normal')
|
||||
selected_nid_reduce, eids_reduce = sample_edge(model, input, target, args, cell_type='reduce')
|
||||
model.project_edge(selected_nid_reduce, eids_reduce, cell_type='reduce')
|
||||
else:
|
||||
selected_nid_normal, eids_normal = project_edge(model, input, target, args, cell_type='normal')
|
||||
model.project_edge(selected_nid_normal, eids_normal, cell_type='normal')
|
||||
selected_nid_reduce, eids_reduce = project_edge(model, input, target, args, cell_type='reduce')
|
||||
model.project_edge(selected_nid_reduce, eids_reduce, cell_type='reduce')
|
||||
epoch+=1
|
||||
|
||||
if epoch == num_projs:
|
||||
break
|
||||
|
||||
return
|
||||
|
||||
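# Jocab_Score ('Jocab' is a misspelling of 'Jacob' kept for compatibility with the
# call sites above) computes a NASWOT-style score for a candidate discretization:
# the given projected weights are installed for one cell type, forward hooks
# accumulate the binary ReLU activation patterns of a batch into the kernel K
# (agreements of active units plus agreements of inactive units), and the
# log-determinant of K is returned; higher is generally taken to indicate a more
# expressive architecture.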
def Jocab_Score(ori_model, cell_type, input, target, weights=None):
|
||||
model = deepcopy(ori_model)
|
||||
model.eval()
|
||||
if cell_type == 'reduce':
|
||||
model.proj_weights['reduce'] = weights
|
||||
model.proj_weights['normal'] = model.get_projected_weights('normal')
|
||||
else:
|
||||
model.proj_weights['normal'] = weights
|
||||
model.proj_weights['reduce'] = model.get_projected_weights('reduce')
|
||||
|
||||
batch_size = input.shape[0]
|
||||
model.K = torch.zeros(batch_size, batch_size).cuda()
|
||||
def counting_forward_hook(module, inp, out):
|
||||
try:
|
||||
if isinstance(inp, tuple):
|
||||
inp = inp[0]
|
||||
inp = inp.view(inp.size(0), -1)
|
||||
x = (inp > 0).float()
|
||||
K = x @ x.t()
|
||||
K2 = (1.-x) @ (1.-x.t())
|
||||
model.K = model.K + K + K2
|
||||
except:
|
||||
pass
|
||||
|
||||
for name, module in model.named_modules():
|
||||
if 'ReLU' in str(type(module)):
|
||||
module.register_forward_hook(counting_forward_hook)
|
||||
|
||||
input = input.cuda()
|
||||
|
||||
model(input, using_proj=True)
|
||||
score = hooklogdet(model.K.cpu().numpy())
|
||||
|
||||
del model
|
||||
return score
|
||||
|
||||
def hooklogdet(K, labels=None):
|
||||
s, ld = np.linalg.slogdet(K)
|
||||
return ld
|
||||
133
sota/cnn/model.py
Normal file
@@ -0,0 +1,133 @@
|
||||
import torch
import torch.nn as nn

from sota.cnn.operations import *
import sys
sys.path.insert(0, '../../')
from nasbench201.utils import drop_path


class Cell(nn.Module):

    def __init__(self, genotype, C_prev_prev, C_prev, C, reduction, reduction_prev):
        super(Cell, self).__init__()
        if reduction_prev:
            self.preprocess0 = FactorizedReduce(C_prev_prev, C)
        else:
            self.preprocess0 = ReLUConvBN(C_prev_prev, C, 1, 1, 0)
        self.preprocess1 = ReLUConvBN(C_prev, C, 1, 1, 0)

        if reduction:
            op_names, indices = zip(*genotype.reduce)
            concat = genotype.reduce_concat
        else:
            op_names, indices = zip(*genotype.normal)
            concat = genotype.normal_concat
        self._compile(C, op_names, indices, concat, reduction)

    def _compile(self, C, op_names, indices, concat, reduction):
        assert len(op_names) == len(indices)
        self._steps = len(op_names) // 2
        self._concat = concat
        self.multiplier = len(concat)

        self._ops = nn.ModuleList()
        for name, index in zip(op_names, indices):
            stride = 2 if reduction and index < 2 else 1
            op = OPS[name](C, stride, True)
            self._ops += [op]
        self._indices = indices

    def forward(self, s0, s1, drop_prob):
        s0 = self.preprocess0(s0)
        s1 = self.preprocess1(s1)

        states = [s0, s1]
        for i in range(self._steps):
            h1 = states[self._indices[2*i]]
            h2 = states[self._indices[2*i+1]]
            op1 = self._ops[2*i]
            op2 = self._ops[2*i+1]
            h1 = op1(h1)
            h2 = op2(h2)
            if self.training and drop_prob > 0.:
                if not isinstance(op1, Identity):
                    h1 = drop_path(h1, drop_prob)
                if not isinstance(op2, Identity):
                    h2 = drop_path(h2, drop_prob)
            s = h1 + h2
            states += [s]
        return torch.cat([states[i] for i in self._concat], dim=1)


class AuxiliaryHead(nn.Module):

    def __init__(self, C, num_classes):
        """assuming input size 8x8"""
        super(AuxiliaryHead, self).__init__()
        self.features = nn.Sequential(
            nn.ReLU(inplace=True),
            # image size = 2 x 2
            nn.AvgPool2d(5, stride=3, padding=0, count_include_pad=False),
            nn.Conv2d(C, 128, 1, bias=False),
            nn.BatchNorm2d(128),
            nn.ReLU(inplace=True),
            nn.Conv2d(128, 768, 2, bias=False),
            nn.BatchNorm2d(768),
            nn.ReLU(inplace=True)
        )
        self.classifier = nn.Linear(768, num_classes)

    def forward(self, x):
        x = self.features(x)
        x = self.classifier(x.view(x.size(0), -1))
        return x


class Network(nn.Module):

    def __init__(self, C, num_classes, layers, auxiliary, genotype):
        super(Network, self).__init__()
        self._layers = layers
        self._auxiliary = auxiliary

        stem_multiplier = 3
        C_curr = stem_multiplier*C
        self.stem = nn.Sequential(
            nn.Conv2d(3, C_curr, 3, padding=1, bias=False),
            nn.BatchNorm2d(C_curr)
        )

        C_prev_prev, C_prev, C_curr = C_curr, C_curr, C
        self.cells = nn.ModuleList()
        reduction_prev = False
        for i in range(layers):
            if i in [layers//3, 2*layers//3]:
                C_curr *= 2
                reduction = True
            else:
                reduction = False
            cell = Cell(genotype, C_prev_prev, C_prev, C_curr, reduction, reduction_prev)
            reduction_prev = reduction
            self.cells += [cell]
            C_prev_prev, C_prev = C_prev, cell.multiplier*C_curr
            if i == 2*layers//3:
                C_to_auxiliary = C_prev

        if auxiliary:
            self.auxiliary_head = AuxiliaryHead(C_to_auxiliary, num_classes)
        self.global_pooling = nn.AdaptiveAvgPool2d(1)
        self.classifier = nn.Linear(C_prev, num_classes)

    def forward(self, input):
        logits_aux = None
        s0 = s1 = self.stem(input)
        for i, cell in enumerate(self.cells):
            s0, s1 = s1, cell(s0, s1, self.drop_path_prob)
            if i == 2*self._layers//3:
                if self._auxiliary and self.training:
                    logits_aux = self.auxiliary_head(s1)
        out = self.global_pooling(s1)
        logits = self.classifier(out.view(out.size(0), -1))
        return logits, logits_aux
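For orientation, Cell._compile above consumes a DARTS-style genotype: a list of (op_name, input_index) pairs, two per intermediate node, plus the node indices to concatenate. The sketch below illustrates that data shape; the fields mirror sota/cnn/genotypes.Genotype, but the concrete operations listed are placeholders, not a searched result.

# Illustrative only: the genotype structure consumed by Cell above.
from collections import namedtuple

Genotype = namedtuple('Genotype', 'normal normal_concat reduce reduce_concat')

g = Genotype(
    normal=[('sep_conv_3x3', 0), ('sep_conv_3x3', 1),
            ('skip_connect', 0), ('sep_conv_3x3', 1),
            ('sep_conv_3x3', 1), ('skip_connect', 0),
            ('skip_connect', 0), ('dil_conv_3x3', 2)],
    normal_concat=[2, 3, 4, 5],
    reduce=[('max_pool_3x3', 0), ('max_pool_3x3', 1),
            ('skip_connect', 2), ('max_pool_3x3', 1),
            ('max_pool_3x3', 0), ('skip_connect', 2),
            ('skip_connect', 2), ('max_pool_3x3', 1)],
    reduce_concat=[2, 3, 4, 5])

op_names, indices = zip(*g.normal)   # exactly what Cell.__init__ does for a normal cell
steps = len(op_names) // 2           # 4 intermediate nodes, two incoming ops each
print(op_names[:2], indices[:2], steps)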
150  sota/cnn/model_imagenet.py  Normal file
@@ -0,0 +1,150 @@
import sys
|
||||
sys.path.insert(0, '../../')
|
||||
# from optimizers.darts.operations import *
|
||||
from sota.cnn.operations import *
|
||||
#from optimizers.darts.utils import drop_path
|
||||
|
||||
def drop_path(x, drop_prob):
|
||||
if drop_prob > 0.:
|
||||
keep_prob = 1.-drop_prob
|
||||
mask = Variable(torch.cuda.FloatTensor(x.size(0), 1, 1, 1).bernoulli_(keep_prob))
|
||||
x.div_(keep_prob)
|
||||
x.mul_(mask)
|
||||
return x
|
||||
|
||||
class Cell(nn.Module):
|
||||
|
||||
def __init__(self, genotype, C_prev_prev, C_prev, C, reduction, reduction_prev):
|
||||
super(Cell, self).__init__()
|
||||
print(C_prev_prev, C_prev, C)
|
||||
|
||||
if reduction_prev:
|
||||
self.preprocess0 = FactorizedReduce(C_prev_prev, C)
|
||||
else:
|
||||
self.preprocess0 = ReLUConvBN(C_prev_prev, C, 1, 1, 0)
|
||||
self.preprocess1 = ReLUConvBN(C_prev, C, 1, 1, 0)
|
||||
|
||||
if reduction:
|
||||
op_names, indices = zip(*genotype.reduce)
|
||||
concat = genotype.reduce_concat
|
||||
else:
|
||||
op_names, indices = zip(*genotype.normal)
|
||||
concat = genotype.normal_concat
|
||||
self._compile(C, op_names, indices, concat, reduction)
|
||||
|
||||
def _compile(self, C, op_names, indices, concat, reduction):
|
||||
assert len(op_names) == len(indices)
|
||||
self._steps = len(op_names) // 2
|
||||
self._concat = concat
|
||||
self.multiplier = len(concat)
|
||||
|
||||
self._ops = nn.ModuleList()
|
||||
for name, index in zip(op_names, indices):
|
||||
stride = 2 if reduction and index < 2 else 1
|
||||
op = OPS[name](C, stride, True)
|
||||
self._ops += [op]
|
||||
self._indices = indices
|
||||
|
||||
def forward(self, s0, s1, drop_prob):
|
||||
s0 = self.preprocess0(s0)
|
||||
s1 = self.preprocess1(s1)
|
||||
|
||||
states = [s0, s1]
|
||||
for i in range(self._steps):
|
||||
h1 = states[self._indices[2 * i]]
|
||||
h2 = states[self._indices[2 * i + 1]]
|
||||
op1 = self._ops[2 * i]
|
||||
op2 = self._ops[2 * i + 1]
|
||||
h1 = op1(h1)
|
||||
h2 = op2(h2)
|
||||
if self.training and drop_prob > 0.:
|
||||
if not isinstance(op1, Identity):
|
||||
h1 = drop_path(h1, drop_prob)
|
||||
if not isinstance(op2, Identity):
|
||||
h2 = drop_path(h2, drop_prob)
|
||||
s = h1 + h2
|
||||
states += [s]
|
||||
return torch.cat([states[i] for i in self._concat], dim=1)
|
||||
|
||||
|
||||
class AuxiliaryHeadImageNet(nn.Module):
|
||||
|
||||
def __init__(self, C, num_classes):
|
||||
"""assuming input size 14x14"""
|
||||
super(AuxiliaryHeadImageNet, self).__init__()
|
||||
self.features = nn.Sequential(
|
||||
nn.ReLU(inplace=True),
|
||||
nn.AvgPool2d(5, stride=2, padding=0, count_include_pad=False),
|
||||
nn.Conv2d(C, 128, 1, bias=False),
|
||||
nn.BatchNorm2d(128),
|
||||
nn.ReLU(inplace=True),
|
||||
nn.Conv2d(128, 768, 2, bias=False),
|
||||
# NOTE: This batchnorm was omitted in my earlier implementation due to a typo.
|
||||
# Commenting it out for consistency with the experiments in the paper.
|
||||
# nn.BatchNorm2d(768),
|
||||
nn.ReLU(inplace=True)
|
||||
)
|
||||
self.classifier = nn.Linear(768, num_classes)
|
||||
|
||||
def forward(self, x):
|
||||
x = self.features(x)
|
||||
x = self.classifier(x.view(x.size(0), -1))
|
||||
return x
|
||||
|
||||
|
||||
class NetworkImageNet(nn.Module):
|
||||
|
||||
def __init__(self, C, num_classes, layers, auxiliary, genotype):
|
||||
super(NetworkImageNet, self).__init__()
|
||||
self._layers = layers
|
||||
self._auxiliary = auxiliary
|
||||
self.drop_path_prob = 0.0
|
||||
|
||||
self.stem0 = nn.Sequential(
|
||||
nn.Conv2d(3, C // 2, kernel_size=3, stride=2, padding=1, bias=False),
|
||||
nn.BatchNorm2d(C // 2),
|
||||
nn.ReLU(inplace=True),
|
||||
nn.Conv2d(C // 2, C, 3, stride=2, padding=1, bias=False),
|
||||
nn.BatchNorm2d(C),
|
||||
)
|
||||
|
||||
self.stem1 = nn.Sequential(
|
||||
nn.ReLU(inplace=True),
|
||||
nn.Conv2d(C, C, 3, stride=2, padding=1, bias=False),
|
||||
nn.BatchNorm2d(C),
|
||||
)
|
||||
|
||||
C_prev_prev, C_prev, C_curr = C, C, C
|
||||
|
||||
self.cells = nn.ModuleList()
|
||||
reduction_prev = True
|
||||
for i in range(layers):
|
||||
if i in [layers // 3, 2 * layers // 3]:
|
||||
C_curr *= 2
|
||||
reduction = True
|
||||
else:
|
||||
reduction = False
|
||||
cell = Cell(genotype, C_prev_prev, C_prev, C_curr, reduction, reduction_prev)
|
||||
reduction_prev = reduction
|
||||
self.cells += [cell]
|
||||
C_prev_prev, C_prev = C_prev, cell.multiplier * C_curr
|
||||
if i == 2 * layers // 3:
|
||||
C_to_auxiliary = C_prev
|
||||
|
||||
if auxiliary:
|
||||
self.auxiliary_head = AuxiliaryHeadImageNet(C_to_auxiliary, num_classes)
|
||||
self.global_pooling = nn.AvgPool2d(7)
|
||||
self.classifier = nn.Linear(C_prev, num_classes)
|
||||
|
||||
def forward(self, input):
|
||||
logits_aux = None
|
||||
s0 = self.stem0(input)
|
||||
s1 = self.stem1(s0)
|
||||
for i, cell in enumerate(self.cells):
|
||||
s0, s1 = s1, cell(s0, s1, self.drop_path_prob)
|
||||
if i == 2 * self._layers // 3:
|
||||
if self._auxiliary and self.training:
|
||||
logits_aux = self.auxiliary_head(s1)
|
||||
out = self.global_pooling(s1)
|
||||
logits = self.classifier(out.view(out.size(0), -1))
|
||||
return logits, logits_aux
|
||||
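The drop_path helper defined at the top of this file scales surviving samples by 1/keep_prob so the expected activation is unchanged. A quick standalone check of that behaviour (a reference re-implementation for illustration, not the repository function):

# Expected-value check of drop_path-style regularization.
import torch

def drop_path_ref(x, drop_prob):
    keep_prob = 1. - drop_prob
    mask = torch.empty(x.size(0), 1, 1, 1).bernoulli_(keep_prob)   # keep/drop whole samples
    return x / keep_prob * mask                                    # rescale survivors

x = torch.ones(10000, 3, 1, 1)
print(drop_path_ref(x, 0.2).mean())   # close to 1.0 in expectation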
288  sota/cnn/model_search.py  Normal file
@@ -0,0 +1,288 @@
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from torch.autograd import Variable
|
||||
|
||||
from sota.cnn.operations import *
|
||||
from sota.cnn.genotypes import Genotype
|
||||
import sys
|
||||
sys.path.insert(0, '../../')
|
||||
from nasbench201.utils import drop_path
|
||||
|
||||
|
||||
class MixedOp(nn.Module):
|
||||
def __init__(self, C, stride, PRIMITIVES):
|
||||
super(MixedOp, self).__init__()
|
||||
self._ops = nn.ModuleList()
|
||||
for primitive in PRIMITIVES:
|
||||
op = OPS[primitive](C, stride, False)
|
||||
if 'pool' in primitive:
|
||||
op = nn.Sequential(op, nn.BatchNorm2d(C, affine=False))
|
||||
self._ops.append(op)
|
||||
|
||||
def forward(self, x, weights):
|
||||
ret = sum(w * op(x, block_input=True) if w == 0 else w * op(x) for w, op in zip(weights, self._ops) if w != 0)
|
||||
return ret
|
||||
|
||||
|
||||
class Cell(nn.Module):
|
||||
|
||||
def __init__(self, steps, multiplier, C_prev_prev, C_prev, C, reduction, reduction_prev):
|
||||
super(Cell, self).__init__()
|
||||
self.reduction = reduction
|
||||
self.primitives = self.PRIMITIVES['primitives_reduct' if reduction else 'primitives_normal']
|
||||
|
||||
if reduction_prev:
|
||||
self.preprocess0 = FactorizedReduce(C_prev_prev, C, affine=False)
|
||||
else:
|
||||
self.preprocess0 = ReLUConvBN(C_prev_prev, C, 1, 1, 0, affine=False)
|
||||
|
||||
self.preprocess1 = ReLUConvBN(C_prev, C, 1, 1, 0, affine=False)
|
||||
self._steps = steps
|
||||
self._multiplier = multiplier
|
||||
|
||||
self._ops = nn.ModuleList()
|
||||
self._bns = nn.ModuleList()
|
||||
|
||||
edge_index = 0
|
||||
|
||||
for i in range(self._steps):
|
||||
for j in range(2+i):
|
||||
stride = 2 if reduction and j < 2 else 1
|
||||
op = MixedOp(C, stride, self.primitives[edge_index])
|
||||
self._ops.append(op)
|
||||
edge_index += 1
|
||||
|
||||
def forward(self, s0, s1, weights, drop_prob=0.):
|
||||
s0 = self.preprocess0(s0)
|
||||
s1 = self.preprocess1(s1)
|
||||
|
||||
states = [s0, s1]
|
||||
offset = 0
|
||||
for i in range(self._steps):
|
||||
if drop_prob > 0. and self.training:
|
||||
s = sum(drop_path(self._ops[offset+j](h, weights[offset+j]), drop_prob) for j, h in enumerate(states))
|
||||
else:
|
||||
s = sum(self._ops[offset+j](h, weights[offset+j]) for j, h in enumerate(states))
|
||||
offset += len(states)
|
||||
states.append(s)
|
||||
|
||||
return torch.cat(states[-self._multiplier:], dim=1)
|
||||
|
||||
|
||||
class Network(nn.Module):
|
||||
def __init__(self, C, num_classes, layers, criterion, primitives, args,
|
||||
steps=4, multiplier=4, stem_multiplier=3, drop_path_prob=0, nettype='cifar'):
|
||||
super(Network, self).__init__()
|
||||
#### original code
|
||||
self._C = C
|
||||
self._num_classes = num_classes
|
||||
self._layers = layers
|
||||
self._criterion = criterion
|
||||
self._steps = steps
|
||||
self._multiplier = multiplier
|
||||
self.drop_path_prob = drop_path_prob
|
||||
self.nettype = nettype
|
||||
|
||||
nn.Module.PRIMITIVES = primitives; self.op_names = primitives
|
||||
|
||||
C_curr = stem_multiplier*C
|
||||
if self.nettype == 'cifar':
|
||||
self.stem = nn.Sequential(
|
||||
nn.Conv2d(3, C_curr, 3, padding=1, bias=False),
|
||||
nn.BatchNorm2d(C_curr)
|
||||
)
|
||||
else:
|
||||
self.stem0 = nn.Sequential(
|
||||
nn.Conv2d(3, C_curr // 2, kernel_size=3, stride=2, padding=1, bias=False),
|
||||
nn.BatchNorm2d(C_curr // 2),
|
||||
nn.ReLU(inplace=True),
|
||||
nn.Conv2d(C_curr // 2, C_curr, 3, stride=2, padding=1, bias=False),
|
||||
nn.BatchNorm2d(C_curr),
|
||||
)
|
||||
|
||||
self.stem1 = nn.Sequential(
|
||||
nn.ReLU(inplace=True),
|
||||
nn.Conv2d(C_curr, C_curr, 3, stride=2, padding=1, bias=False),
|
||||
nn.BatchNorm2d(C_curr),
|
||||
)
|
||||
|
||||
C_prev_prev, C_prev, C_curr = C_curr, C_curr, C
|
||||
self.cells = nn.ModuleList()
|
||||
if self.nettype == 'cifar':
|
||||
reduction_prev = False
|
||||
else:
|
||||
reduction_prev = True
|
||||
for i in range(layers):
|
||||
if i in [layers//3, 2*layers//3]:
|
||||
C_curr *= 2
|
||||
reduction = True
|
||||
else:
|
||||
reduction = False
|
||||
cell = Cell(steps, multiplier, C_prev_prev, C_prev, C_curr, reduction, reduction_prev)
|
||||
|
||||
reduction_prev = reduction
|
||||
self.cells += [cell]
|
||||
C_prev_prev, C_prev = C_prev, multiplier*C_curr
|
||||
|
||||
self.global_pooling = nn.AdaptiveAvgPool2d(1)
|
||||
self.classifier = nn.Linear(C_prev, num_classes)
|
||||
|
||||
self._initialize_alphas()
|
||||
|
||||
#### optimizer
|
||||
self._args = args
|
||||
self.optimizer = torch.optim.SGD(
|
||||
self.get_weights(),
|
||||
args.learning_rate,
|
||||
momentum=args.momentum,
|
||||
weight_decay=args.weight_decay,
|
||||
nesterov= args.nesterov)
|
||||
|
||||
|
||||
def reset_optimizer(self, lr, momentum, weight_decay):
|
||||
del self.optimizer
|
||||
self.optimizer = torch.optim.SGD(
|
||||
self.get_weights(),
|
||||
lr,
|
||||
momentum=momentum,
|
||||
weight_decay=weight_decay)
|
||||
|
||||
def _loss(self, input, target, return_logits=False):
|
||||
logits = self(input)
|
||||
loss = self._criterion(logits, target)
|
||||
return (loss, logits) if return_logits else loss
|
||||
|
||||
def _initialize_alphas(self):
|
||||
k = sum(1 for i in range(self._steps) for n in range(2+i))
|
||||
num_ops = len(self.PRIMITIVES['primitives_normal'][0])
|
||||
self.num_edges = k
|
||||
self.num_ops = num_ops
|
||||
|
||||
self.alphas_normal = self._initialize_alphas_numpy(k, num_ops)
|
||||
self.alphas_reduce = self._initialize_alphas_numpy(k, num_ops)
|
||||
self._arch_parameters = [ # must be in this order!
|
||||
self.alphas_normal,
|
||||
self.alphas_reduce,
|
||||
]
|
||||
|
||||
def _initialize_alphas_numpy(self, k, num_ops):
|
||||
''' init from specified arch '''
|
||||
return Variable(1e-3*torch.randn(k, num_ops).cuda(), requires_grad=True)
|
||||
|
||||
def forward(self, input):
|
||||
weights = self.get_softmax()
|
||||
weights_normal = weights['normal']
|
||||
weights_reduce = weights['reduce']
|
||||
|
||||
if self.nettype == 'cifar':
|
||||
s0 = s1 = self.stem(input)
|
||||
else:
|
||||
print('imagetnet')
|
||||
s0 = self.stem0(input)
|
||||
s1 = self.stem1(s0)
|
||||
for i, cell in enumerate(self.cells):
|
||||
if cell.reduction:
|
||||
weights = weights_reduce
|
||||
else:
|
||||
weights = weights_normal
|
||||
|
||||
s0, s1 = s1, cell(s0, s1, weights, self.drop_path_prob)
|
||||
|
||||
out = self.global_pooling(s1)
|
||||
logits = self.classifier(out.view(out.size(0),-1))
|
||||
|
||||
return logits
|
||||
|
||||
def step(self, input, target, args, shared=None):
|
||||
assert shared is None, 'gradient sharing disabled'
|
||||
|
||||
Lt, logit_t = self._loss(input, target, return_logits=True)
|
||||
Lt.backward()
|
||||
|
||||
nn.utils.clip_grad_norm_(self.get_weights(), args.grad_clip)
|
||||
self.optimizer.step()
|
||||
|
||||
return logit_t, Lt
|
||||
|
||||
#### utils
|
||||
def set_arch_parameters(self, new_alphas):
|
||||
for alpha, new_alpha in zip(self.arch_parameters(), new_alphas):
|
||||
alpha.data.copy_(new_alpha.data)
|
||||
|
||||
def get_softmax(self):
|
||||
weights_normal = F.softmax(self.alphas_normal, dim=-1)
|
||||
weights_reduce = F.softmax(self.alphas_reduce, dim=-1)
|
||||
return {'normal':weights_normal, 'reduce':weights_reduce}
|
||||
|
||||
def printing(self, logging, option='all'):
|
||||
weights = self.get_softmax()
|
||||
if option in ['all', 'normal']:
|
||||
weights_normal = weights['normal']
|
||||
logging.info(weights_normal)
|
||||
if option in ['all', 'reduce']:
|
||||
weights_reduce = weights['reduce']
|
||||
logging.info(weights_reduce)
|
||||
|
||||
def arch_parameters(self):
|
||||
return self._arch_parameters
|
||||
|
||||
def get_weights(self):
|
||||
return self.parameters()
|
||||
|
||||
def new(self):
|
||||
model_new = Network(self._C, self._num_classes, self._layers, self._criterion, self.PRIMITIVES, self._args,\
|
||||
drop_path_prob=self.drop_path_prob).cuda()
|
||||
for x, y in zip(model_new.arch_parameters(), self.arch_parameters()):
|
||||
x.data.copy_(y.data)
|
||||
return model_new
|
||||
|
||||
def clip(self):
|
||||
for p in self.arch_parameters():
|
||||
for line in p:
|
||||
max_index = line.argmax()
|
||||
line.data.clamp_(0, 1)
|
||||
if line.sum() == 0.0:
|
||||
line.data[max_index] = 1.0
|
||||
line.data.div_(line.sum())
|
||||
|
||||
def genotype(self):
|
||||
def _parse(weights, normal=True):
|
||||
PRIMITIVES = self.PRIMITIVES['primitives_normal' if normal else 'primitives_reduct'] ## two are equal for Darts space
|
||||
|
||||
gene = []
|
||||
n = 2
|
||||
start = 0
|
||||
for i in range(self._steps):
|
||||
end = start + n
|
||||
W = weights[start:end].copy()
|
||||
|
||||
try:
|
||||
edges = sorted(range(i + 2), key=lambda x: -max(W[x][k] for k in range(len(W[x])) if k != PRIMITIVES[x].index('none')))[:2]
|
||||
except ValueError: # This error happens when the 'none' op is not present in the ops
|
||||
edges = sorted(range(i + 2), key=lambda x: -max(W[x][k] for k in range(len(W[x]))))[:2]
|
||||
|
||||
for j in edges:
|
||||
k_best = None
|
||||
for k in range(len(W[j])):
|
||||
if 'none' in PRIMITIVES[j]:
|
||||
if k != PRIMITIVES[j].index('none'):
|
||||
if k_best is None or W[j][k] > W[j][k_best]:
|
||||
k_best = k
|
||||
else:
|
||||
if k_best is None or W[j][k] > W[j][k_best]:
|
||||
k_best = k
|
||||
gene.append((PRIMITIVES[start+j][k_best], j))
|
||||
start = end
|
||||
n += 1
|
||||
return gene
|
||||
|
||||
gene_normal = _parse(F.softmax(self.alphas_normal, dim=-1).data.cpu().numpy(), True)
|
||||
gene_reduce = _parse(F.softmax(self.alphas_reduce, dim=-1).data.cpu().numpy(), False)
|
||||
|
||||
concat = range(2+self._steps-self._multiplier, self._steps+2)
|
||||
genotype = Genotype(
|
||||
normal=gene_normal, normal_concat=concat,
|
||||
reduce=gene_reduce, reduce_concat=concat
|
||||
)
|
||||
return genotype
|
||||
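As a pointer for readers unfamiliar with weight-sharing search: MixedOp.forward above implements the usual continuous relaxation, where each edge outputs a weighted sum of candidate operations and the weights come from a softmax over architecture parameters. A minimal sketch under assumed shapes (not the repository class):

# One relaxed edge: softmax-weighted sum over candidate ops.
import torch
import torch.nn as nn
import torch.nn.functional as F

ops = nn.ModuleList([
    nn.Identity(),
    nn.Conv2d(8, 8, 3, padding=1, bias=False),
    nn.AvgPool2d(3, stride=1, padding=1),
])
alpha = torch.randn(len(ops), requires_grad=True)   # architecture parameters for one edge

x = torch.randn(2, 8, 16, 16)
w = F.softmax(alpha, dim=-1)
y = sum(wi * op(x) for wi, op in zip(w, ops))       # differentiable w.r.t. alpha and op weights
print(y.shape)                                      # torch.Size([2, 8, 16, 16])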
213  sota/cnn/model_search_darts_proj.py  Normal file
@@ -0,0 +1,213 @@
import torch
|
||||
from copy import deepcopy
|
||||
|
||||
from sota.cnn.operations import *
|
||||
from sota.cnn.genotypes import Genotype
|
||||
import sys
|
||||
sys.path.insert(0, '../../')
|
||||
from sota.cnn.model_search import Network
|
||||
|
||||
class DartsNetworkProj(Network):
|
||||
def __init__(self, C, num_classes, layers, criterion, primitives, args,
|
||||
steps=4, multiplier=4, stem_multiplier=3, drop_path_prob=0.0):
|
||||
super(DartsNetworkProj, self).__init__(C, num_classes, layers, criterion, primitives, args,
|
||||
steps=steps, multiplier=multiplier, stem_multiplier=stem_multiplier, drop_path_prob=drop_path_prob)
|
||||
|
||||
self._initialize_flags()
|
||||
self._initialize_proj_weights()
|
||||
self._initialize_topology_dicts()
|
||||
|
||||
#### proj flags
|
||||
def _initialize_topology_dicts(self):
|
||||
self.nid2eids = {0:[2,3,4], 1:[5,6,7,8], 2:[9,10,11,12,13]}
|
||||
self.nid2selected_eids = {
|
||||
'normal': {0:[],1:[],2:[]},
|
||||
'reduce': {0:[],1:[],2:[]},
|
||||
}
|
||||
|
||||
def _initialize_flags(self):
|
||||
self.candidate_flags = {
|
||||
'normal':torch.tensor(self.num_edges * [True], requires_grad=False, dtype=torch.bool).cuda(),
|
||||
'reduce':torch.tensor(self.num_edges * [True], requires_grad=False, dtype=torch.bool).cuda(),
|
||||
} # must be in this order
|
||||
self.candidate_flags_edge = {
|
||||
'normal': torch.tensor(3 * [True], requires_grad=False, dtype=torch.bool).cuda(),
|
||||
'reduce': torch.tensor(3 * [True], requires_grad=False, dtype=torch.bool).cuda(),
|
||||
}
|
||||
|
||||
def _initialize_proj_weights(self):
|
||||
''' data structures used for proj '''
|
||||
if isinstance(self.alphas_normal, list):
|
||||
alphas_normal = torch.stack(self.alphas_normal, dim=0)
|
||||
alphas_reduce = torch.stack(self.alphas_reduce, dim=0)
|
||||
else:
|
||||
alphas_normal = self.alphas_normal
|
||||
alphas_reduce = self.alphas_reduce
|
||||
|
||||
self.proj_weights = { # for hard/soft assignment after project
|
||||
'normal': torch.zeros_like(alphas_normal),
|
||||
'reduce': torch.zeros_like(alphas_reduce),
|
||||
}
|
||||
|
||||
#### proj function
|
||||
def project_op(self, eid, opid, cell_type):
|
||||
self.proj_weights[cell_type][eid][opid] = 1 ## hard by default
|
||||
self.candidate_flags[cell_type][eid] = False
|
||||
|
||||
def project_edge(self, nid, eids, cell_type):
|
||||
for eid in self.nid2eids[nid]:
|
||||
if eid not in eids: # not top2
|
||||
self.proj_weights[cell_type][eid].data.fill_(0)
|
||||
self.nid2selected_eids[cell_type][nid] = deepcopy(eids)
|
||||
self.candidate_flags_edge[cell_type][nid] = False
|
||||
|
||||
#### critical function
|
||||
def get_projected_weights(self, cell_type):
|
||||
''' used in forward and genotype '''
|
||||
weights = self.get_softmax()[cell_type]
|
||||
|
||||
## proj op
|
||||
for eid in range(self.num_edges):
|
||||
if not self.candidate_flags[cell_type][eid]:
|
||||
weights[eid].data.copy_(self.proj_weights[cell_type][eid])
|
||||
|
||||
## proj edge
|
||||
for nid in self.nid2eids:
|
||||
if not self.candidate_flags_edge[cell_type][nid]: ## projected node
|
||||
for eid in self.nid2eids[nid]:
|
||||
if eid not in self.nid2selected_eids[cell_type][nid]:
|
||||
weights[eid].data.copy_(self.proj_weights[cell_type][eid])
|
||||
|
||||
return weights
|
||||
|
||||
def get_all_projected_weights(self, cell_type):
|
||||
weights = self.get_softmax()[cell_type]
|
||||
|
||||
for eid in range(self.num_edges):
|
||||
weights[eid].data.copy_(self.proj_weights[cell_type][eid])
|
||||
|
||||
for nid in self.nid2eids:
|
||||
for eid in self.nid2eids[nid]:
|
||||
weights[eid].data.copy_(self.proj_weights[cell_type][eid])
|
||||
|
||||
return weights
|
||||
|
||||
def forward(self, input, weights_dict=None, using_proj=False):
|
||||
if using_proj:
|
||||
weights_normal = self.get_all_projected_weights('normal')
|
||||
weights_reduce = self.get_all_projected_weights('reduce')
|
||||
else:
|
||||
if weights_dict is None or 'normal' not in weights_dict:
|
||||
weights_normal = self.get_projected_weights('normal')
|
||||
else:
|
||||
weights_normal = weights_dict['normal']
|
||||
if weights_dict is None or 'reduce' not in weights_dict:
|
||||
weights_reduce = self.get_projected_weights('reduce')
|
||||
else:
|
||||
weights_reduce = weights_dict['reduce']
|
||||
|
||||
|
||||
|
||||
s0 = s1 = self.stem(input)
|
||||
for i, cell in enumerate(self.cells):
|
||||
if cell.reduction:
|
||||
weights = weights_reduce
|
||||
else:
|
||||
weights = weights_normal
|
||||
|
||||
s0, s1 = s1, cell(s0, s1, weights, self.drop_path_prob)
|
||||
|
||||
out = self.global_pooling(s1)
|
||||
logits = self.classifier(out.view(out.size(0),-1))
|
||||
|
||||
return logits
|
||||
|
||||
def reset_arch_parameters(self):
|
||||
self._initialize_flags()
|
||||
self._initialize_proj_weights()
|
||||
self._initialize_topology_dicts()
|
||||
|
||||
#### utils
|
||||
def printing(self, logging, option='all'):
|
||||
weights_normal = self.get_projected_weights('normal')
|
||||
weights_reduce = self.get_projected_weights('reduce')
|
||||
|
||||
if option in ['all', 'normal']:
|
||||
logging.info('\n%s', weights_normal)
|
||||
if option in ['all', 'reduce']:
|
||||
logging.info('\n%s', weights_reduce)
|
||||
|
||||
def genotype(self):
|
||||
def _parse(weights, normal=True):
|
||||
PRIMITIVES = self.PRIMITIVES['primitives_normal' if normal else 'primitives_reduct']
|
||||
|
||||
gene = []
|
||||
n = 2
|
||||
start = 0
|
||||
for i in range(self._steps):
|
||||
end = start + n
|
||||
W = weights[start:end].copy()
|
||||
|
||||
try:
|
||||
edges = sorted(range(i + 2), key=lambda x: -max(W[x][k] for k in range(len(W[x])) if k != PRIMITIVES[x].index('none')))[:2]
|
||||
except ValueError:
|
||||
edges = sorted(range(i + 2), key=lambda x: -max(W[x][k] for k in range(len(W[x]))))[:2]
|
||||
|
||||
for j in edges:
|
||||
k_best = None
|
||||
for k in range(len(W[j])):
|
||||
if 'none' in PRIMITIVES[j]:
|
||||
if k != PRIMITIVES[j].index('none'):
|
||||
if k_best is None or W[j][k] > W[j][k_best]:
|
||||
k_best = k
|
||||
else:
|
||||
if k_best is None or W[j][k] > W[j][k_best]:
|
||||
k_best = k
|
||||
gene.append((PRIMITIVES[start+j][k_best], j))
|
||||
start = end
|
||||
n += 1
|
||||
return gene
|
||||
|
||||
weights_normal = self.get_projected_weights('normal')
|
||||
weights_reduce = self.get_projected_weights('reduce')
|
||||
gene_normal = _parse(weights_normal.data.cpu().numpy(), True)
|
||||
gene_reduce = _parse(weights_reduce.data.cpu().numpy(), False)
|
||||
|
||||
concat = range(2+self._steps-self._multiplier, self._steps+2)
|
||||
genotype = Genotype(
|
||||
normal=gene_normal, normal_concat=concat,
|
||||
reduce=gene_reduce, reduce_concat=concat
|
||||
)
|
||||
return genotype
|
||||
|
||||
def get_state_dict(self, epoch, architect, scheduler):
|
||||
model_state_dict = {
|
||||
'epoch': epoch, ## no +1 because we are saving before projection / at the beginning of an epoch
|
||||
'state_dict': self.state_dict(),
|
||||
'alpha': self.arch_parameters(),
|
||||
'optimizer': self.optimizer.state_dict(),
|
||||
'arch_optimizer': architect.optimizer.state_dict(),
|
||||
'scheduler': scheduler.state_dict(),
|
||||
#### projection
|
||||
'nid2eids': self.nid2eids,
|
||||
'nid2selected_eids': self.nid2selected_eids,
|
||||
'candidate_flags': self.candidate_flags,
|
||||
'candidate_flags_edge': self.candidate_flags_edge,
|
||||
'proj_weights': self.proj_weights,
|
||||
}
|
||||
return model_state_dict
|
||||
|
||||
def set_state_dict(self, architect, scheduler, checkpoint):
|
||||
#### common
|
||||
self.load_state_dict(checkpoint['state_dict'])
|
||||
self.set_arch_parameters(checkpoint['alpha'])
|
||||
self.optimizer.load_state_dict(checkpoint['optimizer'])
|
||||
architect.optimizer.load_state_dict(checkpoint['arch_optimizer'])
|
||||
scheduler.load_state_dict(checkpoint['scheduler'])
|
||||
|
||||
#### projection
|
||||
self.nid2eids = checkpoint['nid2eids']
|
||||
self.nid2selected_eids = checkpoint['nid2selected_eids']
|
||||
self.candidate_flags = checkpoint['candidate_flags']
|
||||
self.candidate_flags_edge = checkpoint['candidate_flags_edge']
|
||||
self.proj_weights = checkpoint['proj_weights']
|
||||
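To summarize the projection bookkeeping above: project_op replaces a decided edge's softmax row with a one-hot row, project_edge zeroes the rows of edges pruned from a decided node, and get_projected_weights overlays those hard rows on the relaxed weights for everything else. A toy sketch of that overlay (tensor sizes and ids are made up for illustration):

# Soft rows for undecided edges, hard rows for projected ones.
import torch

num_edges, num_ops = 14, 7
soft = torch.softmax(torch.randn(num_edges, num_ops), dim=-1)   # relaxed architecture weights
proj = torch.zeros_like(soft)                                   # hard assignments filled in over time
decided = torch.zeros(num_edges, dtype=torch.bool)

# project_op(eid=3, opid=5): edge 3 is fixed to operation 5 (one-hot row)
proj[3, 5] = 1.0
decided[3] = True

# project_edge for a node owning edges 5..8 that keeps eids=[5, 6]: drop edges 7 and 8
for eid in (7, 8):
    proj[eid].zero_()
    decided[eid] = True

effective = torch.where(decided.unsqueeze(1), proj, soft)
print(effective[3], effective[7])   # one-hot row vs. zeroed row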
214  sota/cnn/model_search_imagenet_proj.py  Normal file
@@ -0,0 +1,214 @@
import torch
|
||||
from copy import deepcopy
|
||||
import torch.nn as nn
|
||||
from sota.cnn.operations import *
|
||||
from sota.cnn.genotypes import Genotype
|
||||
import sys
|
||||
sys.path.insert(0, '../../')
|
||||
from sota.cnn.model_search import Network
|
||||
|
||||
class ImageNetNetworkProj(Network):
|
||||
def __init__(self, C, num_classes, layers, criterion, primitives, args,
|
||||
steps=4, multiplier=4, stem_multiplier=3, drop_path_prob=0.0, nettype='imagenet'):
|
||||
super(ImageNetNetworkProj, self).__init__(C, num_classes, layers, criterion, primitives, args,
|
||||
steps=steps, multiplier=multiplier, stem_multiplier=stem_multiplier, drop_path_prob=drop_path_prob, nettype=nettype)
|
||||
|
||||
self._initialize_flags()
|
||||
self._initialize_proj_weights()
|
||||
self._initialize_topology_dicts()
|
||||
|
||||
#### proj flags
|
||||
def _initialize_topology_dicts(self):
|
||||
self.nid2eids = {0:[2,3,4], 1:[5,6,7,8], 2:[9,10,11,12,13]}
|
||||
self.nid2selected_eids = {
|
||||
'normal': {0:[],1:[],2:[]},
|
||||
'reduce': {0:[],1:[],2:[]},
|
||||
}
|
||||
|
||||
def _initialize_flags(self):
|
||||
self.candidate_flags = {
|
||||
'normal':torch.tensor(self.num_edges * [True], requires_grad=False, dtype=torch.bool).cuda(),
|
||||
'reduce':torch.tensor(self.num_edges * [True], requires_grad=False, dtype=torch.bool).cuda(),
|
||||
} # must be in this order
|
||||
self.candidate_flags_edge = {
|
||||
'normal': torch.tensor(3 * [True], requires_grad=False, dtype=torch.bool).cuda(),
|
||||
'reduce': torch.tensor(3 * [True], requires_grad=False, dtype=torch.bool).cuda(),
|
||||
}
|
||||
|
||||
def _initialize_proj_weights(self):
|
||||
''' data structures used for proj '''
|
||||
if isinstance(self.alphas_normal, list):
|
||||
alphas_normal = torch.stack(self.alphas_normal, dim=0)
|
||||
alphas_reduce = torch.stack(self.alphas_reduce, dim=0)
|
||||
else:
|
||||
alphas_normal = self.alphas_normal
|
||||
alphas_reduce = self.alphas_reduce
|
||||
|
||||
self.proj_weights = { # for hard/soft assignment after project
|
||||
'normal': torch.zeros_like(alphas_normal),
|
||||
'reduce': torch.zeros_like(alphas_reduce),
|
||||
}
|
||||
|
||||
#### proj function
|
||||
def project_op(self, eid, opid, cell_type):
|
||||
self.proj_weights[cell_type][eid][opid] = 1 ## hard by default
|
||||
self.candidate_flags[cell_type][eid] = False
|
||||
|
||||
def project_edge(self, nid, eids, cell_type):
|
||||
for eid in self.nid2eids[nid]:
|
||||
if eid not in eids: # not top2
|
||||
self.proj_weights[cell_type][eid].data.fill_(0)
|
||||
self.nid2selected_eids[cell_type][nid] = deepcopy(eids)
|
||||
self.candidate_flags_edge[cell_type][nid] = False
|
||||
|
||||
#### critical function
|
||||
def get_projected_weights(self, cell_type):
|
||||
''' used in forward and genotype '''
|
||||
weights = self.get_softmax()[cell_type]
|
||||
|
||||
## proj op
|
||||
for eid in range(self.num_edges):
|
||||
if not self.candidate_flags[cell_type][eid]:
|
||||
weights[eid].data.copy_(self.proj_weights[cell_type][eid])
|
||||
|
||||
## proj edge
|
||||
for nid in self.nid2eids:
|
||||
if not self.candidate_flags_edge[cell_type][nid]: ## projected node
|
||||
for eid in self.nid2eids[nid]:
|
||||
if eid not in self.nid2selected_eids[cell_type][nid]:
|
||||
weights[eid].data.copy_(self.proj_weights[cell_type][eid])
|
||||
|
||||
return weights
|
||||
|
||||
def get_all_projected_weights(self, cell_type):
|
||||
weights = self.get_softmax()[cell_type]
|
||||
|
||||
for eid in range(self.num_edges):
|
||||
weights[eid].data.copy_(self.proj_weights[cell_type][eid])
|
||||
|
||||
for nid in self.nid2eids:
|
||||
for eid in self.nid2eids[nid]:
|
||||
weights[eid].data.copy_(self.proj_weights[cell_type][eid])
|
||||
|
||||
return weights
|
||||
|
||||
def forward(self, input, weights_dict=None, using_proj=False):
|
||||
if using_proj:
|
||||
weights_normal = self.get_all_projected_weights('normal')
|
||||
weights_reduce = self.get_all_projected_weights('reduce')
|
||||
else:
|
||||
if weights_dict is None or 'normal' not in weights_dict:
|
||||
weights_normal = self.get_projected_weights('normal')
|
||||
else:
|
||||
weights_normal = weights_dict['normal']
|
||||
if weights_dict is None or 'reduce' not in weights_dict:
|
||||
weights_reduce = self.get_projected_weights('reduce')
|
||||
else:
|
||||
weights_reduce = weights_dict['reduce']
|
||||
|
||||
|
||||
|
||||
s0 = self.stem0(input)
|
||||
s1 = self.stem1(s0)
|
||||
for i, cell in enumerate(self.cells):
|
||||
if cell.reduction:
|
||||
weights = weights_reduce
|
||||
else:
|
||||
weights = weights_normal
|
||||
|
||||
s0, s1 = s1, cell(s0, s1, weights, self.drop_path_prob)
|
||||
|
||||
out = self.global_pooling(s1)
|
||||
logits = self.classifier(out.view(out.size(0),-1))
|
||||
|
||||
return logits
|
||||
|
||||
def reset_arch_parameters(self):
|
||||
self._initialize_flags()
|
||||
self._initialize_proj_weights()
|
||||
self._initialize_topology_dicts()
|
||||
|
||||
#### utils
|
||||
def printing(self, logging, option='all'):
|
||||
weights_normal = self.get_projected_weights('normal')
|
||||
weights_reduce = self.get_projected_weights('reduce')
|
||||
|
||||
if option in ['all', 'normal']:
|
||||
logging.info('\n%s', weights_normal)
|
||||
if option in ['all', 'reduce']:
|
||||
logging.info('\n%s', weights_reduce)
|
||||
|
||||
def genotype(self):
|
||||
def _parse(weights, normal=True):
|
||||
PRIMITIVES = self.PRIMITIVES['primitives_normal' if normal else 'primitives_reduct']
|
||||
|
||||
gene = []
|
||||
n = 2
|
||||
start = 0
|
||||
for i in range(self._steps):
|
||||
end = start + n
|
||||
W = weights[start:end].copy()
|
||||
|
||||
try:
|
||||
edges = sorted(range(i + 2), key=lambda x: -max(W[x][k] for k in range(len(W[x])) if k != PRIMITIVES[x].index('none')))[:2]
|
||||
except ValueError:
|
||||
edges = sorted(range(i + 2), key=lambda x: -max(W[x][k] for k in range(len(W[x]))))[:2]
|
||||
|
||||
for j in edges:
|
||||
k_best = None
|
||||
for k in range(len(W[j])):
|
||||
if 'none' in PRIMITIVES[j]:
|
||||
if k != PRIMITIVES[j].index('none'):
|
||||
if k_best is None or W[j][k] > W[j][k_best]:
|
||||
k_best = k
|
||||
else:
|
||||
if k_best is None or W[j][k] > W[j][k_best]:
|
||||
k_best = k
|
||||
gene.append((PRIMITIVES[start+j][k_best], j))
|
||||
start = end
|
||||
n += 1
|
||||
return gene
|
||||
|
||||
weights_normal = self.get_projected_weights('normal')
|
||||
weights_reduce = self.get_projected_weights('reduce')
|
||||
gene_normal = _parse(weights_normal.data.cpu().numpy(), True)
|
||||
gene_reduce = _parse(weights_reduce.data.cpu().numpy(), False)
|
||||
|
||||
concat = range(2+self._steps-self._multiplier, self._steps+2)
|
||||
genotype = Genotype(
|
||||
normal=gene_normal, normal_concat=concat,
|
||||
reduce=gene_reduce, reduce_concat=concat
|
||||
)
|
||||
return genotype
|
||||
|
||||
def get_state_dict(self, epoch, architect, scheduler):
|
||||
model_state_dict = {
|
||||
'epoch': epoch, ## no +1 because we are saving before projection / at the beginning of an epoch
|
||||
'state_dict': self.state_dict(),
|
||||
'alpha': self.arch_parameters(),
|
||||
'optimizer': self.optimizer.state_dict(),
|
||||
'arch_optimizer': architect.optimizer.state_dict(),
|
||||
'scheduler': scheduler.state_dict(),
|
||||
#### projection
|
||||
'nid2eids': self.nid2eids,
|
||||
'nid2selected_eids': self.nid2selected_eids,
|
||||
'candidate_flags': self.candidate_flags,
|
||||
'candidate_flags_edge': self.candidate_flags_edge,
|
||||
'proj_weights': self.proj_weights,
|
||||
}
|
||||
return model_state_dict
|
||||
|
||||
def set_state_dict(self, architect, scheduler, checkpoint):
|
||||
#### common
|
||||
self.load_state_dict(checkpoint['state_dict'])
|
||||
self.set_arch_parameters(checkpoint['alpha'])
|
||||
self.optimizer.load_state_dict(checkpoint['optimizer'])
|
||||
architect.optimizer.load_state_dict(checkpoint['arch_optimizer'])
|
||||
scheduler.load_state_dict(checkpoint['scheduler'])
|
||||
|
||||
#### projection
|
||||
self.nid2eids = checkpoint['nid2eids']
|
||||
self.nid2selected_eids = checkpoint['nid2selected_eids']
|
||||
self.candidate_flags = checkpoint['candidate_flags']
|
||||
self.candidate_flags_edge = checkpoint['candidate_flags_edge']
|
||||
self.proj_weights = checkpoint['proj_weights']
|
||||
236  sota/cnn/networks_proposal.py  Normal file
@@ -0,0 +1,236 @@
import os
|
||||
import sys
|
||||
sys.path.insert(0, '../../')
|
||||
import time
|
||||
import glob
|
||||
import random
|
||||
import numpy as np
|
||||
import torch
|
||||
import shutil
|
||||
import nasbench201.utils as ig_utils
|
||||
import logging
|
||||
import argparse
|
||||
import torch.nn as nn
|
||||
import torch.utils
|
||||
import torchvision.datasets as dset
|
||||
import torch.backends.cudnn as cudnn
|
||||
import torchvision.transforms as transforms
|
||||
import json
|
||||
import copy
|
||||
|
||||
from sota.cnn.model_search import Network as DartsNetwork
|
||||
from sota.cnn.model_search_darts_proj import DartsNetworkProj
|
||||
from sota.cnn.model_search_imagenet_proj import ImageNetNetworkProj
|
||||
# from optimizers.darts.architect import Architect as DartsArchitect
|
||||
from nasbench201.architect_ig import Architect
|
||||
from sota.cnn.spaces import spaces_dict
|
||||
from foresight.pruners import *
|
||||
|
||||
from torch.utils.tensorboard import SummaryWriter
|
||||
from sota.cnn.init_projection import pt_project
|
||||
from hdf5 import H5Dataset
|
||||
|
||||
torch.set_printoptions(precision=4, sci_mode=False)
|
||||
|
||||
parser = argparse.ArgumentParser("sota")
|
||||
parser.add_argument('--data', type=str, default='../../data',help='location of the data corpus')
|
||||
parser.add_argument('--dataset', type=str, default='cifar10', help='choose dataset')
|
||||
parser.add_argument('--batch_size', type=int, default=64, help='batch size')
|
||||
parser.add_argument('--train_portion', type=float, default=0.5, help='portion of training data')
|
||||
parser.add_argument('--cutout', action='store_true', default=False, help='use cutout')
|
||||
parser.add_argument('--cutout_length', type=int, default=16, help='cutout length')
|
||||
parser.add_argument('--cutout_prob', type=float, default=1.0, help='cutout probability')
|
||||
parser.add_argument('--seed', type=int, default=666, help='random seed')
|
||||
|
||||
#model opt related config
|
||||
parser.add_argument('--learning_rate', type=float, default=0.025, help='init learning rate')
|
||||
parser.add_argument('--learning_rate_min', type=float, default=0.001, help='min learning rate')
|
||||
parser.add_argument('--momentum', type=float, default=0.9, help='momentum')
|
||||
parser.add_argument('--nesterov', action='store_true', default=True, help='using nestrov momentum for SGD')
|
||||
parser.add_argument('--weight_decay', type=float, default=3e-4, help='weight decay')
|
||||
parser.add_argument('--grad_clip', type=float, default=5, help='gradient clipping')
|
||||
|
||||
#system config
|
||||
parser.add_argument('--gpu', type=str, default='0', help='gpu device id')
|
||||
parser.add_argument('--save', type=str, default='exp', help='experiment name')
|
||||
parser.add_argument('--save_path', type=str, default='../../experiments/sota', help='experiment name')
|
||||
#search sapce config
|
||||
parser.add_argument('--init_channels', type=int, default=16, help='num of init channels')
|
||||
parser.add_argument('--layers', type=int, default=8, help='total number of layers')
|
||||
parser.add_argument('--search_space', type=str, default='s5', help='searching space to choose from')
|
||||
parser.add_argument('--pool_size', type=int, default=10, help='number of model to proposed')
|
||||
|
||||
## projection
|
||||
parser.add_argument('--edge_decision', type=str, default='random', choices=['random','reverse', 'order', 'global_op_greedy', 'global_op_once', 'global_edge_greedy', 'global_edge_once', 'sample'], help='used for both proj_op and proj_edge')
|
||||
parser.add_argument('--proj_crit_normal', type=str, default='meco', choices=['loss', 'acc', 'jacob', 'comb', 'synflow', 'snip', 'fisher', 'var', 'cor', 'norm', 'grad_norm', 'grasp', 'jacob_cov', 'meco', 'zico'])
|
||||
parser.add_argument('--proj_crit_reduce', type=str, default='meco', choices=['loss', 'acc', 'jacob', 'comb', 'synflow', 'snip', 'fisher', 'var', 'cor', 'norm', 'grad_norm', 'grasp', 'jacob_cov', 'meco', 'zico'])
|
||||
parser.add_argument('--proj_crit_edge', type=str, default='meco', choices=['loss', 'acc', 'jacob', 'comb', 'synflow', 'snip', 'fisher', 'var', 'cor', 'norm', 'grad_norm', 'grasp', 'jacob_cov', 'meco', 'zico'])
|
||||
parser.add_argument('--proj_mode_edge', type=str, default='reg', choices=['reg'],
|
||||
help='edge projection evaluation mode, reg: one edge at a time')
|
||||
args = parser.parse_args()
|
||||
|
||||
#### args augment
|
||||
|
||||
expid = args.save
|
||||
args.save = '{}/{}-search-{}-{}-{}-{}-{}'.format(args.save_path,
|
||||
args.dataset, args.save, args.search_space, args.seed, args.pool_size, args.proj_crit_normal)
|
||||
|
||||
if not args.edge_decision == 'random':
|
||||
args.save += '-' + args.edge_decision
|
||||
|
||||
scripts_to_save = glob.glob('*.py') + glob.glob('../../nasbench201/architect*.py') + glob.glob('../../optimizers/darts/architect.py')
|
||||
if os.path.exists(args.save):
|
||||
if input("WARNING: {} exists, override?[y/n]".format(args.save)) == 'y':
|
||||
print('proceed to override saving directory')
|
||||
shutil.rmtree(args.save)
|
||||
else:
|
||||
exit(0)
|
||||
ig_utils.create_exp_dir(args.save, scripts_to_save=scripts_to_save)
|
||||
|
||||
#### logging
|
||||
log_format = '%(asctime)s %(message)s'
|
||||
logging.basicConfig(stream=sys.stdout, level=logging.INFO,
|
||||
format=log_format, datefmt='%m/%d %I:%M:%S %p')
|
||||
log_file = 'log.txt'
|
||||
log_path = os.path.join(args.save, log_file)
|
||||
logging.info('======> log filename: %s', log_file)
|
||||
|
||||
if os.path.exists(log_path):
|
||||
if input("WARNING: {} exists, override?[y/n]".format(log_file)) == 'y':
|
||||
print('proceed to override log file directory')
|
||||
else:
|
||||
exit(0)
|
||||
|
||||
fh = logging.FileHandler(log_path, mode='w')
|
||||
fh.setFormatter(logging.Formatter(log_format))
|
||||
logging.getLogger().addHandler(fh)
|
||||
writer = SummaryWriter(args.save + '/runs')
|
||||
|
||||
if args.dataset == 'cifar100':
|
||||
n_classes = 100
|
||||
elif args.dataset == 'imagenet':
|
||||
n_classes = 1000
|
||||
else:
|
||||
n_classes = 10
|
||||
|
||||
def main():
|
||||
torch.set_num_threads(3)
|
||||
if not torch.cuda.is_available():
|
||||
logging.info('no gpu device available')
|
||||
sys.exit(1)
|
||||
|
||||
np.random.seed(args.seed)
|
||||
gpu = ig_utils.pick_gpu_lowest_memory() if args.gpu == 'auto' else int(args.gpu)
|
||||
torch.cuda.set_device(gpu)
|
||||
cudnn.benchmark = True
|
||||
torch.manual_seed(args.seed)
|
||||
cudnn.enabled = True
|
||||
torch.cuda.manual_seed(args.seed)
|
||||
logging.info('gpu device = %d' % gpu)
|
||||
logging.info("args = %s", args)
|
||||
|
||||
#### model
|
||||
criterion = nn.CrossEntropyLoss()
|
||||
criterion = criterion.cuda()
|
||||
|
||||
## darts
|
||||
if args.dataset == 'imagenet':
|
||||
model = ImageNetNetworkProj(args.init_channels, n_classes, args.layers, criterion, spaces_dict[args.search_space], args)
|
||||
else:
|
||||
model = DartsNetworkProj(args.init_channels, n_classes, args.layers, criterion, spaces_dict[args.search_space], args)
|
||||
model = model.cuda()
|
||||
logging.info("param size = %fMB", ig_utils.count_parameters_in_MB(model))
|
||||
|
||||
#### data
|
||||
if args.dataset == 'imagenet':
|
||||
normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
|
||||
train_transform = transforms.Compose([
|
||||
transforms.RandomResizedCrop(224),
|
||||
transforms.RandomHorizontalFlip(),
|
||||
transforms.ColorJitter(
|
||||
brightness=0.4,
|
||||
contrast=0.4,
|
||||
saturation=0.4,
|
||||
hue=0.2),
|
||||
transforms.ToTensor(),
|
||||
normalize,
|
||||
])
|
||||
#for test
|
||||
#from nasbench201.DownsampledImageNet import ImageNet16
|
||||
# train_data = dset.CIFAR10(root=args.data, train=True, download=True, transform=train_transform)
|
||||
# n_classes = 10
|
||||
train_data = H5Dataset(os.path.join(args.data, 'imagenet-train-256.h5'), transform=train_transform)
|
||||
#valid_data = H5Dataset(os.path.join(args.data, 'imagenet-val-256.h5'), transform=test_transform)
|
||||
|
||||
train_queue = torch.utils.data.DataLoader(
|
||||
train_data, batch_size=args.batch_size, shuffle=True, pin_memory=True, num_workers=4)
|
||||
|
||||
else:
|
||||
if args.dataset == 'cifar10':
|
||||
train_transform, valid_transform = ig_utils._data_transforms_cifar10(args)
|
||||
train_data = dset.CIFAR10(root=args.data, train=True, download=True, transform=train_transform)
|
||||
valid_data = dset.CIFAR10(root=args.data, train=False, download=True, transform=valid_transform)
|
||||
elif args.dataset == 'cifar100':
|
||||
train_transform, valid_transform = ig_utils._data_transforms_cifar100(args)
|
||||
train_data = dset.CIFAR100(root=args.data, train=True, download=True, transform=train_transform)
|
||||
valid_data = dset.CIFAR100(root=args.data, train=False, download=True, transform=valid_transform)
|
||||
elif args.dataset == 'svhn':
|
||||
train_transform, valid_transform = ig_utils._data_transforms_svhn(args)
|
||||
train_data = dset.SVHN(root=args.data, split='train', download=True, transform=train_transform)
|
||||
valid_data = dset.SVHN(root=args.data, split='test', download=True, transform=valid_transform)
|
||||
|
||||
num_train = len(train_data)
|
||||
indices = list(range(num_train))
|
||||
split = int(np.floor(args.train_portion * num_train))
|
||||
|
||||
train_queue = torch.utils.data.DataLoader(
|
||||
train_data, batch_size=args.batch_size,
|
||||
sampler=torch.utils.data.sampler.SubsetRandomSampler(indices[:split]),
|
||||
pin_memory=True)
|
||||
|
||||
valid_queue = torch.utils.data.DataLoader(
|
||||
train_data, batch_size=args.batch_size,
|
||||
sampler=torch.utils.data.sampler.SubsetRandomSampler(indices[split:num_train]),
|
||||
pin_memory=True)
|
||||
# for x, y in train_queue:
|
||||
# from torchvision import transforms
|
||||
# unloader = transforms.ToPILImage()
|
||||
# image = x.cpu().clone() # clone the tensor
|
||||
# image = image.squeeze(0) # remove the fake batch dimension
|
||||
# image = unloader(image)
|
||||
# image.save('example.jpg')
|
||||
|
||||
# print(x.size())
|
||||
# exit()
|
||||
|
||||
|
||||
#### projection
|
||||
networks_pool={}
|
||||
networks_pool['search_space'] = args.search_space
|
||||
networks_pool['dataset'] = args.dataset
|
||||
networks_pool['networks'] = []
|
||||
for i in range(args.pool_size):
|
||||
network_info={}
|
||||
logging.info('{} MODEL HAS SEARCHED'.format(i+1))
|
||||
pt_project(train_queue, model, args)
|
||||
|
||||
## logging
|
||||
num_params = ig_utils.count_parameters_in_Compact(model)
|
||||
genotype = model.genotype()
|
||||
json_data = {}
|
||||
json_data['normal'] = genotype.normal
|
||||
json_data['normal_concat'] = [x for x in genotype.normal_concat]
|
||||
json_data['reduce'] = genotype.reduce
|
||||
json_data['reduce_concat'] = [x for x in genotype.reduce_concat]
|
||||
json_string = json.dumps(json_data)
|
||||
logging.info(json_string)
|
||||
network_info['id'] = str(i)
|
||||
network_info['genotype'] = json_string
|
||||
networks_pool['networks'].append(network_info)
|
||||
model.reset_arch_parameters()
|
||||
|
||||
with open(os.path.join(args.save,'networks_pool.json'), 'w') as save_file:
|
||||
json.dump(networks_pool, save_file)
|
||||
|
||||
if __name__ == '__main__':
|
||||
main()
|
||||
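For reference, the search loop in main() writes one JSON file describing the proposed pool; each entry stores the genotype string logged per model. The values below are illustrative placeholders showing the layout, not a real search result:

# Approximate shape of networks_pool.json written above (placeholder values).
example_pool = {
    "search_space": "s5",
    "dataset": "cifar10",
    "networks": [
        {"id": "0",
         "genotype": '{"normal": [["sep_conv_3x3", 0], ["skip_connect", 1]], '
                     '"normal_concat": [2, 3, 4, 5], '
                     '"reduce": [["max_pool_3x3", 0], ["skip_connect", 1]], '
                     '"reduce_concat": [2, 3, 4, 5]}'},
    ],
}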
181  sota/cnn/operations.py  Normal file
@@ -0,0 +1,181 @@
import torch
|
||||
import torch.nn as nn
|
||||
from torch.autograd import Variable
|
||||
|
||||
OPS = {
|
||||
'noise': lambda C, stride, affine: NoiseOp(stride, 0., 1.),
|
||||
'none': lambda C, stride, affine: Zero(stride),
|
||||
'avg_pool_3x3': lambda C, stride, affine: nn.AvgPool2d(3, stride=stride, padding=1, count_include_pad=False),
|
||||
'max_pool_3x3' : lambda C, stride, affine: nn.MaxPool2d(3, stride=stride, padding=1),
|
||||
'skip_connect': lambda C, stride, affine: Identity() if stride == 1 else FactorizedReduce(C, C, affine=affine),
|
||||
'sep_conv_3x3': lambda C, stride, affine: SepConv(C, C, 3, stride, 1, affine=affine),
|
||||
'sep_conv_5x5': lambda C, stride, affine: SepConv(C, C, 5, stride, 2, affine=affine),
|
||||
'sep_conv_7x7': lambda C, stride, affine: SepConv(C, C, 7, stride, 3, affine=affine),
|
||||
'dil_conv_3x3': lambda C, stride, affine: DilConv(C, C, 3, stride, 2, 2, affine=affine),
|
||||
'dil_conv_5x5': lambda C, stride, affine: DilConv(C, C, 5, stride, 4, 2, affine=affine),
|
||||
'conv_7x1_1x7': lambda C, stride, affine: nn.Sequential(
|
||||
nn.ReLU(inplace=False),
|
||||
nn.Conv2d(C, C, (1, 7), stride=(1, stride), padding=(0, 3), bias=False),
|
||||
nn.Conv2d(C, C, (7, 1), stride=(stride, 1), padding=(3, 0), bias=False),
|
||||
nn.BatchNorm2d(C, affine=affine)
|
||||
),
|
||||
'sep_conv_3x3_skip': lambda C, stride, affine: SepConvSkip(C, C, 3, stride, 1, affine=affine),
|
||||
'sep_conv_5x5_skip': lambda C, stride, affine: SepConvSkip(C, C, 5, stride, 2, affine=affine),
|
||||
'dil_conv_3x3_skip': lambda C, stride, affine: DilConvSkip(C, C, 3, stride, 2, 2, affine=affine),
|
||||
'dil_conv_5x5_skip': lambda C, stride, affine: DilConvSkip(C, C, 5, stride, 4, 2, affine=affine),
|
||||
}
|
||||
|
||||
|
||||
class NoiseOp(nn.Module):
|
||||
def __init__(self, stride, mean, std):
|
||||
super(NoiseOp, self).__init__()
|
||||
self.stride = stride
|
||||
self.mean = mean
|
||||
self.std = std
|
||||
|
||||
def forward(self, x, block_input=False):
|
||||
if block_input:
|
||||
x = x*0
|
||||
if self.stride != 1:
|
||||
x_new = x[:,:,::self.stride,::self.stride]
|
||||
else:
|
||||
x_new = x
|
||||
noise = Variable(x_new.data.new(x_new.size()).normal_(self.mean, self.std))
|
||||
|
||||
return noise
|
||||
|
||||
|
||||
class ReLUConvBN(nn.Module):
|
||||
|
||||
def __init__(self, C_in, C_out, kernel_size, stride, padding, affine=True):
|
||||
super(ReLUConvBN, self).__init__()
|
||||
self.op = nn.Sequential(
|
||||
nn.ReLU(inplace=False),
|
||||
nn.Conv2d(C_in, C_out, kernel_size, stride=stride, padding=padding, bias=False),
|
||||
nn.BatchNorm2d(C_out, affine=affine)
|
||||
)
|
||||
|
||||
def forward(self, x, block_input=False):
|
||||
if block_input:
|
||||
x = x*0
|
||||
return self.op(x)
|
||||
|
||||
|
||||
class DilConv(nn.Module):
|
||||
|
||||
def __init__(self, C_in, C_out, kernel_size, stride, padding, dilation, affine=True):
|
||||
super(DilConv, self).__init__()
|
||||
self.op = nn.Sequential(
|
||||
nn.ReLU(inplace=False),
|
||||
nn.Conv2d(C_in, C_in, kernel_size=kernel_size, stride=stride, padding=padding, dilation=dilation,
|
||||
groups=C_in, bias=False),
|
||||
nn.Conv2d(C_in, C_out, kernel_size=1, padding=0, bias=False),
|
||||
nn.BatchNorm2d(C_out, affine=affine),
|
||||
)
|
||||
|
||||
def forward(self, x, block_input=False):
|
||||
if block_input:
|
||||
x = x*0
|
||||
return self.op(x)
|
||||
|
||||
|
||||
class SepConv(nn.Module):
|
||||
def __init__(self, C_in, C_out, kernel_size, stride, padding, affine=True):
|
||||
super(SepConv, self).__init__()
|
||||
self.op = nn.Sequential(
|
||||
nn.ReLU(inplace=False),
|
||||
nn.Conv2d(C_in, C_in, kernel_size=kernel_size, stride=stride, padding=padding, groups=C_in, bias=False),
|
||||
nn.Conv2d(C_in, C_in, kernel_size=1, padding=0, bias=False),
|
||||
nn.BatchNorm2d(C_in, affine=affine),
|
||||
nn.ReLU(inplace=False),
|
||||
nn.Conv2d(C_in, C_in, kernel_size=kernel_size, stride=1, padding=padding, groups=C_in, bias=False),
|
||||
nn.Conv2d(C_in, C_out, kernel_size=1, padding=0, bias=False),
|
||||
nn.BatchNorm2d(C_out, affine=affine),
|
||||
)
|
||||
|
||||
def forward(self, x, block_input=False):
|
||||
if block_input:
|
||||
x = x*0
|
||||
return self.op(x)
|
||||
|
||||
|
||||
class Identity(nn.Module):
|
||||
|
||||
def __init__(self):
|
||||
super(Identity, self).__init__()
|
||||
|
||||
def forward(self, x, block_input=False):
|
||||
if block_input:
|
||||
x = x*0
|
||||
return x
|
||||
|
||||
|
||||
class Zero(nn.Module):
|
||||
|
||||
def __init__(self, stride):
|
||||
super(Zero, self).__init__()
|
||||
self.stride = stride
|
||||
|
||||
def forward(self, x, block_input=False):
|
||||
if block_input:
|
||||
x = x*0
|
||||
if self.stride == 1:
|
||||
return x.mul(0.)
|
||||
return x[:, :, ::self.stride, ::self.stride].mul(0.)
|
||||
|
||||
|
||||
class FactorizedReduce(nn.Module):
|
||||
|
||||
def __init__(self, C_in, C_out, affine=True):
|
||||
super(FactorizedReduce, self).__init__()
|
||||
assert C_out % 2 == 0
|
||||
self.relu = nn.ReLU(inplace=False)
|
||||
self.conv_1 = nn.Conv2d(C_in, C_out // 2, 1, stride=2, padding=0, bias=False)
|
||||
self.conv_2 = nn.Conv2d(C_in, C_out // 2, 1, stride=2, padding=0, bias=False)
|
||||
self.bn = nn.BatchNorm2d(C_out, affine=affine)
|
||||
|
||||
def forward(self, x, block_input=False):
|
||||
if block_input:
|
||||
x = x*0
|
||||
x = self.relu(x)
|
||||
out = torch.cat([self.conv_1(x), self.conv_2(x[:, :, 1:, 1:])], dim=1)
|
||||
out = self.bn(out)
|
||||
return out
|
||||
|
||||
|
||||
#### operations with skip
|
||||
class DilConvSkip(nn.Module):
|
||||
def __init__(self, C_in, C_out, kernel_size, stride, padding, dilation, affine=True):
|
||||
super(DilConvSkip, self).__init__()
|
||||
self.op = nn.Sequential(
|
||||
nn.ReLU(inplace=False),
|
||||
nn.Conv2d(C_in, C_in, kernel_size=kernel_size, stride=stride, padding=padding, dilation=dilation,
|
||||
groups=C_in, bias=False),
|
||||
nn.Conv2d(C_in, C_out, kernel_size=1, padding=0, bias=False),
|
||||
nn.BatchNorm2d(C_out, affine=affine),
|
||||
)
|
||||
|
||||
def forward(self, x, block_input=False):
|
||||
if block_input:
|
||||
x = x*0
|
||||
return self.op(x) + x
|
||||
|
||||
|
||||
class SepConvSkip(nn.Module):
|
||||
def __init__(self, C_in, C_out, kernel_size, stride, padding, affine=True):
|
||||
super(SepConvSkip, self).__init__()
|
||||
self.op = nn.Sequential(
|
||||
nn.ReLU(inplace=False),
|
||||
nn.Conv2d(C_in, C_in, kernel_size=kernel_size, stride=stride, padding=padding, groups=C_in, bias=False),
|
||||
nn.Conv2d(C_in, C_in, kernel_size=1, padding=0, bias=False),
|
||||
nn.BatchNorm2d(C_in, affine=affine),
|
||||
nn.ReLU(inplace=False),
|
||||
nn.Conv2d(C_in, C_in, kernel_size=kernel_size, stride=1, padding=padding, groups=C_in, bias=False),
|
||||
nn.Conv2d(C_in, C_out, kernel_size=1, padding=0, bias=False),
|
||||
nn.BatchNorm2d(C_out, affine=affine),
|
||||
)
|
||||
|
||||
def forward(self, x, block_input=False):
|
||||
if block_input:
|
||||
x = x*0
|
||||
return self.op(x) + x
|
||||
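All operations in this file accept a block_input flag: when it is set, the op sees a zeroed input, so an edge's contribution can be suppressed without removing its module from the graph. A small self-contained sketch of the same pattern (ToyOp is a hypothetical stand-in, not a repository class):

# The block_input convention used by the ops above.
import torch
import torch.nn as nn

class ToyOp(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv = nn.Conv2d(4, 4, 3, padding=1)

    def forward(self, x, block_input=False):
        if block_input:
            x = x * 0                      # same pattern as the operations defined above
        return self.conv(x)

op = ToyOp()
x = torch.randn(1, 4, 8, 8)
print(torch.allclose(op(x, block_input=True), op(torch.zeros_like(x))))   # True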
248  sota/cnn/projection.py  Normal file
@@ -0,0 +1,248 @@
import os
|
||||
import sys
|
||||
sys.path.insert(0, '../../')
|
||||
import numpy as np
|
||||
import torch
|
||||
import nasbench201.utils as ig_utils
|
||||
import logging
|
||||
import torch.utils
|
||||
|
||||
from copy import deepcopy
|
||||
|
||||
torch.set_printoptions(precision=4, sci_mode=False)
|
||||
|
||||
|
||||
def project_op(model, proj_queue, args, infer, cell_type, selected_eid=None):
|
||||
''' operation '''
|
||||
#### macros
|
||||
num_edges, num_ops = model.num_edges, model.num_ops
|
||||
candidate_flags = model.candidate_flags[cell_type]
|
||||
proj_crit = args.proj_crit[cell_type]
|
||||
|
||||
#### select an edge
|
||||
if selected_eid is None:
|
||||
remain_eids = torch.nonzero(candidate_flags).cpu().numpy().T[0]
|
||||
if args.edge_decision == "random":
|
||||
selected_eid = np.random.choice(remain_eids, size=1)[0]
|
||||
logging.info('selected edge: %d %s', selected_eid, cell_type)
|
||||
|
||||
#### select the best operation
|
||||
if proj_crit == 'loss':
|
||||
crit_idx = 1
|
||||
compare = lambda x, y: x > y
|
||||
elif proj_crit == 'acc':
|
||||
crit_idx = 0
|
||||
compare = lambda x, y: x < y
|
||||
|
||||
best_opid = 0
|
||||
crit_extrema = None
|
||||
for opid in range(num_ops):
|
||||
## projection
|
||||
weights = model.get_projected_weights(cell_type)
|
||||
proj_mask = torch.ones_like(weights[selected_eid])
|
||||
proj_mask[opid] = 0
|
||||
weights[selected_eid] = weights[selected_eid] * proj_mask
|
||||
|
||||
## proj evaluation
|
||||
weights_dict = {cell_type:weights}
|
||||
valid_stats = infer(proj_queue, model, log=False, _eval=False, weights_dict=weights_dict)
|
||||
crit = valid_stats[crit_idx]
|
||||
|
||||
if crit_extrema is None or compare(crit, crit_extrema):
|
||||
crit_extrema = crit
|
||||
best_opid = opid
|
||||
logging.info('valid_acc %f', valid_stats[0])
|
||||
logging.info('valid_loss %f', valid_stats[1])
|
||||
|
||||
#### project
|
||||
logging.info('best opid: %d', best_opid)
|
||||
return selected_eid, best_opid
|
||||
|
||||
|
||||
def project_edge(model, proj_queue, args, infer, cell_type):
|
||||
''' topology '''
|
||||
#### macros
|
||||
candidate_flags = model.candidate_flags_edge[cell_type]
|
||||
proj_crit = args.proj_crit[cell_type]
|
||||
|
||||
#### select an edge
|
||||
remain_nids = torch.nonzero(candidate_flags).cpu().numpy().T[0]
|
||||
if args.edge_decision == "random":
|
||||
selected_nid = np.random.choice(remain_nids, size=1)[0]
|
||||
logging.info('selected node: %d %s', selected_nid, cell_type)
|
||||
|
||||
#### select top2 edges
|
||||
if proj_crit == 'loss':
|
||||
crit_idx = 1
|
||||
compare = lambda x, y: x > y
|
||||
elif proj_crit == 'acc':
|
||||
crit_idx = 0
|
||||
compare = lambda x, y: x < y
|
||||
|
||||
eids = deepcopy(model.nid2eids[selected_nid])
|
||||
while len(eids) > 2:
|
||||
eid_todel = None
|
||||
crit_extrema = None
|
||||
for eid in eids:
|
||||
weights = model.get_projected_weights(cell_type)
|
||||
weights[eid].data.fill_(0)
|
||||
weights_dict = {cell_type:weights}
|
||||
|
||||
## proj evaluation
|
||||
valid_stats = infer(proj_queue, model, log=False, _eval=False, weights_dict=weights_dict)
|
||||
crit = valid_stats[crit_idx]
|
||||
|
||||
if crit_extrema is None or not compare(crit, crit_extrema): # find out bad edges
|
||||
crit_extrema = crit
|
||||
eid_todel = eid
|
||||
logging.info('valid_acc %f', valid_stats[0])
|
||||
logging.info('valid_loss %f', valid_stats[1])
|
||||
eids.remove(eid_todel)
|
||||
|
||||
#### project
|
||||
logging.info('top2 edges: (%d, %d)', eids[0], eids[1])
|
||||
return selected_nid, eids
|
||||
|
||||
|
||||
def pt_project(train_queue, valid_queue, model, architect, optimizer,
|
||||
epoch, args, infer, perturb_alpha, epsilon_alpha):
|
||||
model.train()
|
||||
model.printing(logging)
|
||||
|
||||
train_acc, train_obj = infer(train_queue, model, log=False)
|
||||
logging.info('train_acc %f', train_acc)
|
||||
logging.info('train_loss %f', train_obj)
|
||||
|
||||
valid_acc, valid_obj = infer(valid_queue, model, log=False)
|
||||
logging.info('valid_acc %f', valid_acc)
|
||||
logging.info('valid_loss %f', valid_obj)
|
||||
|
||||
objs = ig_utils.AvgrageMeter()
|
||||
top1 = ig_utils.AvgrageMeter()
|
||||
top5 = ig_utils.AvgrageMeter()
|
||||
|
||||
|
||||
#### macros
|
||||
num_projs = model.num_edges + len(model.nid2eids.keys()) - 1 ## -1 because we project at both epoch 0 and the final epoch
|
||||
tune_epochs = args.proj_intv * num_projs + 1
|
||||
proj_intv = args.proj_intv
|
||||
args.proj_crit = {'normal':args.proj_crit_normal, 'reduce':args.proj_crit_reduce}
|
||||
proj_queue = valid_queue
|
||||
|
||||
|
||||
#### reset optimizer
|
||||
model.reset_optimizer(args.learning_rate / 10, args.momentum, args.weight_decay)
|
||||
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
|
||||
model.optimizer, float(tune_epochs), eta_min=args.learning_rate_min)
|
||||
|
||||
|
||||
#### load proj checkpoints
|
||||
start_epoch = 0
|
||||
if args.dev_resume_epoch >= 0:
|
||||
filename = os.path.join(args.dev_resume_checkpoint_dir, 'checkpoint_{}.pth.tar'.format(args.dev_resume_epoch))
|
||||
if os.path.isfile(filename):
|
||||
logging.info("=> loading projection checkpoint '{}'".format(filename))
|
||||
checkpoint = torch.load(filename, map_location='cpu')
|
||||
start_epoch = checkpoint['epoch']
|
||||
model.set_state_dict(architect, scheduler, checkpoint)
|
||||
model.set_arch_parameters(checkpoint['alpha'])
|
||||
scheduler.load_state_dict(checkpoint['scheduler'])
|
||||
model.optimizer.load_state_dict(checkpoint['optimizer']) # optimizer
|
||||
else:
|
||||
logging.info("=> no checkpoint found at '{}'".format(filename))
|
||||
exit(0)
|
||||
|
||||
|
||||
#### projecting and tuning
|
||||
for epoch in range(start_epoch, tune_epochs):
|
||||
logging.info('epoch %d', epoch)
|
||||
|
||||
## project
|
||||
if epoch % proj_intv == 0 or epoch == tune_epochs - 1:
|
||||
## saving every projection
|
||||
save_state_dict = model.get_state_dict(epoch, architect, scheduler)
|
||||
ig_utils.save_checkpoint(save_state_dict, False, args.dev_save_checkpoint_dir, per_epoch=True)
|
||||
|
||||
if epoch < proj_intv * model.num_edges:
|
||||
logging.info('project op')
|
||||
|
||||
selected_eid_normal, best_opid_normal = project_op(model, proj_queue, args, infer, cell_type='normal')
|
||||
model.project_op(selected_eid_normal, best_opid_normal, cell_type='normal')
|
||||
selected_eid_reduce, best_opid_reduce = project_op(model, proj_queue, args, infer, cell_type='reduce')
|
||||
model.project_op(selected_eid_reduce, best_opid_reduce, cell_type='reduce')
|
||||
|
||||
model.printing(logging)
|
||||
else:
|
||||
logging.info('project edge')
|
||||
|
||||
selected_nid_normal, eids_normal = project_edge(model, proj_queue, args, infer, cell_type='normal')
|
||||
model.project_edge(selected_nid_normal, eids_normal, cell_type='normal')
|
||||
selected_nid_reduce, eids_reduce = project_edge(model, proj_queue, args, infer, cell_type='reduce')
|
||||
model.project_edge(selected_nid_reduce, eids_reduce, cell_type='reduce')
|
||||
|
||||
model.printing(logging)
|
||||
|
||||
## tune
|
||||
for step, (input, target) in enumerate(train_queue):
|
||||
model.train()
|
||||
n = input.size(0)
|
||||
|
||||
## fetch data
|
||||
input = input.cuda()
|
||||
target = target.cuda(non_blocking=True)
|
||||
input_search, target_search = next(iter(valid_queue))
|
||||
input_search = input_search.cuda()
|
||||
target_search = target_search.cuda(non_blocking=True)
|
||||
|
||||
## train alpha
|
||||
optimizer.zero_grad(); architect.optimizer.zero_grad()
|
||||
architect.step(input, target, input_search, target_search,
|
||||
return_logits=True)
|
||||
|
||||
## sdarts
|
||||
if perturb_alpha:
|
||||
# transform arch_parameters to prob (for perturbation)
|
||||
model.softmax_arch_parameters()
|
||||
optimizer.zero_grad(); architect.optimizer.zero_grad()
|
||||
perturb_alpha(model, input, target, epsilon_alpha)
|
||||
|
||||
## train weight
|
||||
optimizer.zero_grad(); architect.optimizer.zero_grad()
|
||||
logits, loss = model.step(input, target, args)
|
||||
|
||||
## sdarts
|
||||
if perturb_alpha:
|
||||
## restore alpha to unperturbed arch_parameters
|
||||
model.restore_arch_parameters()
|
||||
|
||||
## logging
|
||||
prec1, prec5 = ig_utils.accuracy(logits, target, topk=(1, 5))
|
||||
objs.update(loss.data, n)
|
||||
top1.update(prec1.data, n)
|
||||
top5.update(prec5.data, n)
|
||||
if step % args.report_freq == 0:
|
||||
logging.info('train %03d %e %f %f', step, objs.avg, top1.avg, top5.avg)
|
||||
|
||||
if args.fast:
|
||||
break
|
||||
|
||||
## one epoch end
|
||||
model.printing(logging)
|
||||
|
||||
train_acc, train_obj = infer(train_queue, model, log=False)
|
||||
logging.info('train_acc %f', train_acc)
|
||||
logging.info('train_loss %f', train_obj)
|
||||
|
||||
valid_acc, valid_obj = infer(valid_queue, model, log=False)
|
||||
logging.info('valid_acc %f', valid_acc)
|
||||
logging.info('valid_loss %f', valid_obj)
|
||||
|
||||
|
||||
logging.info('projection finished')
|
||||
model.printing(logging)
|
||||
num_params = ig_utils.count_parameters_in_Compact(model)
|
||||
genotype = model.genotype()
|
||||
logging.info('param size = %f', num_params)
|
||||
logging.info('genotype = %s', genotype)
|
||||
|
||||
return
|
||||
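Reviewer note, not part of the commit: project_op above keeps the operation whose removal hurts the chosen validation criterion the most. A self-contained toy version of that selection rule, with evaluate_without standing in for the masked infer(...) call:

def select_best_op(num_ops, evaluate_without, proj_crit='loss'):
    # 'loss': removal raising the loss the most marks the most important op
    # 'acc' : removal lowering the accuracy the most marks the most important op
    compare = (lambda new, best: new > best) if proj_crit == 'loss' else (lambda new, best: new < best)
    best_opid, crit_extrema = 0, None
    for opid in range(num_ops):
        crit = evaluate_without(opid)          # stand-in for infer(...) with op `opid` masked out
        if crit_extrema is None or compare(crit, crit_extrema):
            crit_extrema, best_opid = crit, opid
    return best_opid

# toy check: op 2 is selected because dropping it raises the loss the most
fake_losses = {0: 0.9, 1: 1.1, 2: 2.3, 3: 1.0}
assert select_best_op(4, lambda opid: fake_losses[opid], proj_crit='loss') == 2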
103
sota/cnn/spaces.py
Normal file
@@ -0,0 +1,103 @@
|
||||
from collections import OrderedDict
|
||||
|
||||
|
||||
|
||||
|
||||
primitives_1 = OrderedDict([('primitives_normal', [['skip_connect',
|
||||
'dil_conv_3x3'],
|
||||
['skip_connect',
|
||||
'dil_conv_5x5'],
|
||||
['skip_connect',
|
||||
'dil_conv_5x5'],
|
||||
['skip_connect',
|
||||
'sep_conv_3x3'],
|
||||
['skip_connect',
|
||||
'dil_conv_3x3'],
|
||||
['max_pool_3x3',
|
||||
'skip_connect'],
|
||||
['skip_connect',
|
||||
'sep_conv_3x3'],
|
||||
['skip_connect',
|
||||
'sep_conv_3x3'],
|
||||
['skip_connect',
|
||||
'dil_conv_3x3'],
|
||||
['skip_connect',
|
||||
'sep_conv_3x3'],
|
||||
['max_pool_3x3',
|
||||
'skip_connect'],
|
||||
['skip_connect',
|
||||
'dil_conv_3x3'],
|
||||
['dil_conv_3x3',
|
||||
'dil_conv_5x5'],
|
||||
['dil_conv_3x3',
|
||||
'dil_conv_5x5']]),
|
||||
('primitives_reduct', [['max_pool_3x3',
|
||||
'avg_pool_3x3'],
|
||||
['max_pool_3x3',
|
||||
'dil_conv_3x3'],
|
||||
['max_pool_3x3',
|
||||
'avg_pool_3x3'],
|
||||
['max_pool_3x3',
|
||||
'avg_pool_3x3'],
|
||||
['skip_connect',
|
||||
'dil_conv_5x5'],
|
||||
['max_pool_3x3',
|
||||
'avg_pool_3x3'],
|
||||
['max_pool_3x3',
|
||||
'sep_conv_3x3'],
|
||||
['skip_connect',
|
||||
'dil_conv_3x3'],
|
||||
['skip_connect',
|
||||
'dil_conv_5x5'],
|
||||
['max_pool_3x3',
|
||||
'avg_pool_3x3'],
|
||||
['max_pool_3x3',
|
||||
'avg_pool_3x3'],
|
||||
['skip_connect',
|
||||
'dil_conv_5x5'],
|
||||
['skip_connect',
|
||||
'dil_conv_5x5'],
|
||||
['skip_connect',
|
||||
'dil_conv_5x5']])])
|
||||
|
||||
primitives_2 = OrderedDict([('primitives_normal', 14 * [['skip_connect',
|
||||
'sep_conv_3x3']]),
|
||||
('primitives_reduct', 14 * [['skip_connect',
|
||||
'sep_conv_3x3']])])
|
||||
|
||||
primitives_3 = OrderedDict([('primitives_normal', 14 * [['none',
|
||||
'skip_connect',
|
||||
'sep_conv_3x3']]),
|
||||
('primitives_reduct', 14 * [['none',
|
||||
'skip_connect',
|
||||
'sep_conv_3x3']])])
|
||||
|
||||
primitives_4 = OrderedDict([('primitives_normal', 14 * [['noise',
|
||||
'sep_conv_3x3']]),
|
||||
('primitives_reduct', 14 * [['noise',
|
||||
'sep_conv_3x3']])])
|
||||
|
||||
PRIMITIVES = [
|
||||
#'none', #0
|
||||
'max_pool_3x3', # 0
|
||||
'avg_pool_3x3', # 1
|
||||
'skip_connect', # 2
|
||||
'sep_conv_3x3', # 3
|
||||
'sep_conv_5x5', # 4
|
||||
'dil_conv_3x3', # 5
|
||||
'dil_conv_5x5' # 6
|
||||
]
|
||||
|
||||
primitives_5 = OrderedDict([('primitives_normal', 14 * [PRIMITIVES]),
|
||||
('primitives_reduct', 14 * [PRIMITIVES])])
|
||||
|
||||
primitives_6 = OrderedDict([('primitives_normal', 14 * [['sep_conv_5x5']]),
|
||||
('primitives_reduct', 14 * [['sep_conv_5x5']])])
|
||||
spaces_dict = {
|
||||
's1': primitives_1,
|
||||
's2': primitives_2,
|
||||
's3': primitives_3,
|
||||
's4': primitives_4,
|
||||
's5': primitives_5, # DARTS Space
|
||||
's6': primitives_6,
|
||||
}
|
||||
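Reviewer note, not part of the commit: a small sketch of how these space definitions are typically consumed; the keys mirror spaces_dict above, and the import path is assumed from this repo layout.

# e.g. from sota.cnn.spaces import spaces_dict
space = spaces_dict['s5']                  # the full DARTS space built from PRIMITIVES
normal_prims = space['primitives_normal']  # 14 edges, one candidate-op list per edge
print(len(normal_prims), normal_prims[0])  # 14, ['max_pool_3x3', 'avg_pool_3x3', ...]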
309
sota/cnn/train.py
Normal file
@@ -0,0 +1,309 @@
|
||||
import os
|
||||
import random
|
||||
import sys
|
||||
|
||||
sys.path.insert(0, '../../')
|
||||
import glob
|
||||
import numpy as np
|
||||
import torch
|
||||
import nasbench201.utils as ig_utils
|
||||
import logging
|
||||
import argparse
|
||||
import shutil
|
||||
import torch.nn as nn
|
||||
import torch.utils
|
||||
import torchvision.datasets as dset
|
||||
import torch.backends.cudnn as cudnn
|
||||
import json
|
||||
from sota.cnn.model import Network
import sota.cnn.genotypes as genotypes
|
||||
from torch.utils.tensorboard import SummaryWriter
|
||||
from collections import namedtuple
|
||||
|
||||
parser = argparse.ArgumentParser("cifar")
|
||||
parser.add_argument('--data', type=str, default='../../data',
|
||||
help='location of the data corpus')
|
||||
parser.add_argument('--dataset', type=str, default='cifar10', help='choose dataset')
|
||||
parser.add_argument('--batch_size', type=int, default=96, help='batch size')
|
||||
parser.add_argument('--learning_rate', type=float, default=0.025, help='init learning rate')
|
||||
parser.add_argument('--momentum', type=float, default=0.9, help='momentum')
|
||||
parser.add_argument('--weight_decay', type=float, default=3e-4, help='weight decay')
|
||||
parser.add_argument('--report_freq', type=float, default=50, help='report frequency')
|
||||
parser.add_argument('--gpu', type=str, default='auto', help='gpu device id')
|
||||
parser.add_argument('--epochs', type=int, default=600, help='num of training epochs')
|
||||
parser.add_argument('--init_channels', type=int, default=36, help='num of init channels')
|
||||
parser.add_argument('--layers', type=int, default=20, help='total number of layers')
|
||||
parser.add_argument('--model_path', type=str, default='saved_models', help='path to save the model')
|
||||
parser.add_argument('--auxiliary', action='store_true', default=True, help='use auxiliary tower')
|
||||
parser.add_argument('--auxiliary_weight', type=float, default=0.4, help='weight for auxiliary loss')
|
||||
parser.add_argument('--cutout', action='store_true', default=True, help='use cutout')
|
||||
parser.add_argument('--cutout_length', type=int, default=16, help='cutout length')
|
||||
parser.add_argument('--cutout_prob', type=float, default=1.0, help='cutout probability')
|
||||
parser.add_argument('--drop_path_prob', type=float, default=0.2, help='drop path probability')
|
||||
parser.add_argument('--save', type=str, default='exp', help='experiment name')
|
||||
parser.add_argument('--seed', type=int, default=0, help='random seed')
|
||||
parser.add_argument('--arch', type=str, default='c100_s4_pgd', help='which architecture to use')
|
||||
parser.add_argument('--grad_clip', type=float, default=5, help='gradient clipping')
|
||||
#### common
|
||||
parser.add_argument('--resume_epoch', type=int, default=0, help="load ckpt, start training at resume_epoch")
|
||||
parser.add_argument('--ckpt_interval', type=int, default=50, help="interval (epoch) for saving checkpoints")
|
||||
parser.add_argument('--resume_expid', type=str, default='', help="full expid to resume from, name == ckpt folder name")
|
||||
parser.add_argument('--fast', action='store_true', default=False, help="fast mode for debugging")
|
||||
parser.add_argument('--queue', action='store_true', default=False, help="queueing for gpu")
|
||||
|
||||
parser.add_argument('--from_dir', action='store_true', default=True, help="arch load form dir")
|
||||
args = parser.parse_args()
|
||||
|
||||
|
||||
def load_network_pool(ckpt_path):
|
||||
with open(os.path.join(ckpt_path, 'best_networks.json'), 'r') as save_file:
|
||||
networks_pool = json.load(save_file)
|
||||
return networks_pool['networks']
|
||||
|
||||
|
||||
Genotype = namedtuple('Genotype', 'normal normal_concat reduce reduce_concat')
|
||||
#### args augment
|
||||
expid = args.save
|
||||
|
||||
print(args.from_dir)
|
||||
if args.from_dir:
|
||||
id_name = os.path.split(args.arch)[1]
|
||||
# print('aaaaaaa', args.arch)
|
||||
args.arch = load_network_pool(args.arch)
|
||||
args.save = '../../experiments/sota/{}/eval/{}-{}-{}'.format(
|
||||
args.dataset, args.save, id_name, args.seed)
|
||||
else:
|
||||
args.save = '../../experiments/sota/{}/eval/{}-{}-{}'.format(
|
||||
args.dataset, args.save, args.arch, args.seed)
|
||||
if args.cutout:
|
||||
args.save += '-cutout-' + str(args.cutout_length) + '-' + str(args.cutout_prob)
|
||||
if args.auxiliary:
|
||||
args.save += '-auxiliary-' + str(args.auxiliary_weight)
|
||||
|
||||
#### logging
|
||||
if args.resume_epoch > 0: # do not delete dir if resume:
|
||||
args.save = '../../experiments/sota/{}/{}'.format(args.dataset, args.resume_expid)
|
||||
assert os.path.exists(args.save), 'resume but {} does not exist!'.format(args.save)
|
||||
else:
|
||||
scripts_to_save = glob.glob('*.py')
|
||||
if os.path.exists(args.save):
|
||||
if input("WARNING: {} exists, override?[y/n]".format(args.save)) == 'y':
|
||||
print('proceed to override saving directory')
|
||||
shutil.rmtree(args.save)
|
||||
else:
|
||||
exit(0)
|
||||
ig_utils.create_exp_dir(args.save, scripts_to_save=scripts_to_save)
|
||||
|
||||
log_format = '%(asctime)s %(message)s'
|
||||
logging.basicConfig(stream=sys.stdout, level=logging.INFO,
|
||||
format=log_format, datefmt='%m/%d %I:%M:%S %p')
|
||||
log_file = 'log_resume_{}.txt'.format(args.resume_epoch) if args.resume_epoch > 0 else 'log.txt'
|
||||
fh = logging.FileHandler(os.path.join(args.save, log_file), mode='w')
|
||||
fh.setFormatter(logging.Formatter(log_format))
|
||||
logging.getLogger().addHandler(fh)
|
||||
writer = SummaryWriter(args.save + '/runs')
|
||||
|
||||
if args.dataset == 'cifar100':
|
||||
n_classes = 100
|
||||
else:
|
||||
n_classes = 10
|
||||
|
||||
|
||||
def seed_torch(seed=0):
|
||||
random.seed(seed)
|
||||
np.random.seed(seed)
|
||||
os.environ['PYTHONHASHSEED'] = str(seed)
|
||||
torch.manual_seed(seed)
|
||||
torch.cuda.manual_seed(seed)
|
||||
cudnn.deterministic = True
|
||||
cudnn.benchmark = False
|
||||
|
||||
|
||||
def main():
|
||||
torch.set_num_threads(3)
|
||||
if not torch.cuda.is_available():
|
||||
logging.info('no gpu device available')
|
||||
sys.exit(1)
|
||||
|
||||
#### gpu queueing
|
||||
if args.queue:
|
||||
ig_utils.queue_gpu()
|
||||
|
||||
gpu = ig_utils.pick_gpu_lowest_memory() if args.gpu == 'auto' else int(args.gpu)
|
||||
torch.cuda.set_device(gpu)
|
||||
cudnn.enabled = True
|
||||
seed_torch(args.seed)
|
||||
|
||||
logging.info('gpu device = %d' % gpu)
|
||||
logging.info("args = %s", args)
|
||||
|
||||
if args.from_dir:
|
||||
genotype_config = json.loads(args.arch)
|
||||
genotype = Genotype(normal=genotype_config['normal'], normal_concat=genotype_config['normal_concat'],
|
||||
reduce=genotype_config['reduce'], reduce_concat=genotype_config['reduce_concat'])
|
||||
else:
|
||||
genotype = eval("genotypes.%s" % args.arch)
|
||||
|
||||
model = Network(args.init_channels, n_classes, args.layers, args.auxiliary, genotype)
|
||||
model = model.cuda()
|
||||
|
||||
logging.info("param size = %fMB", ig_utils.count_parameters_in_MB(model))
|
||||
|
||||
criterion = nn.CrossEntropyLoss()
|
||||
criterion = criterion.cuda()
|
||||
optimizer = torch.optim.SGD(
|
||||
model.parameters(),
|
||||
args.learning_rate,
|
||||
momentum=args.momentum,
|
||||
weight_decay=args.weight_decay
|
||||
)
|
||||
|
||||
if args.dataset == 'cifar10':
|
||||
train_transform, valid_transform = ig_utils._data_transforms_cifar10(args)
|
||||
train_data = dset.CIFAR10(root=args.data, train=True, download=True, transform=train_transform)
|
||||
valid_data = dset.CIFAR10(root=args.data, train=False, download=True, transform=valid_transform)
|
||||
elif args.dataset == 'cifar100':
|
||||
train_transform, valid_transform = ig_utils._data_transforms_cifar100(args)
|
||||
train_data = dset.CIFAR100(root=args.data, train=True, download=True, transform=train_transform)
|
||||
valid_data = dset.CIFAR100(root=args.data, train=False, download=True, transform=valid_transform)
|
||||
elif args.dataset == 'svhn':
|
||||
train_transform, valid_transform = ig_utils._data_transforms_svhn(args)
|
||||
train_data = dset.SVHN(root=args.data, split='train', download=True, transform=train_transform)
|
||||
valid_data = dset.SVHN(root=args.data, split='test', download=True, transform=valid_transform)
|
||||
|
||||
train_queue = torch.utils.data.DataLoader(
|
||||
train_data, batch_size=args.batch_size, shuffle=True, pin_memory=True, num_workers=0)
|
||||
|
||||
valid_queue = torch.utils.data.DataLoader(
|
||||
valid_data, batch_size=args.batch_size, shuffle=False, pin_memory=True, num_workers=0)
|
||||
|
||||
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
|
||||
optimizer, float(args.epochs),
|
||||
# eta_min=1e-4
|
||||
)
|
||||
|
||||
#### resume
|
||||
start_epoch = 0
|
||||
if args.resume_epoch > 0:
|
||||
logging.info('loading checkpoint from {}'.format(expid))
|
||||
filename = os.path.join(args.save, 'checkpoint_{}.pth.tar'.format(args.resume_epoch))
|
||||
|
||||
if os.path.isfile(filename):
|
||||
print("=> loading checkpoint '{}'".format(filename))
|
||||
checkpoint = torch.load(filename, map_location='cpu')
|
||||
resume_epoch = checkpoint['epoch'] # epoch
|
||||
model.load_state_dict(checkpoint['state_dict']) # model
|
||||
scheduler.load_state_dict(checkpoint['scheduler'])
|
||||
optimizer.load_state_dict(checkpoint['optimizer']) # optimizer
|
||||
start_epoch = args.resume_epoch
|
||||
print("=> loaded checkpoint '{}' (epoch {})".format(filename, resume_epoch))
|
||||
else:
|
||||
print("=> no checkpoint found at '{}'".format(filename))
|
||||
|
||||
#### main training
|
||||
best_valid_acc = 0
|
||||
for epoch in range(start_epoch, args.epochs):
|
||||
lr = scheduler.get_lr()[0]
|
||||
if args.cutout:
|
||||
train_transform.transforms[-1].cutout_prob = args.cutout_prob
|
||||
logging.info('epoch %d lr %e cutout_prob %e', epoch, lr,
|
||||
train_transform.transforms[-1].cutout_prob)
|
||||
else:
|
||||
logging.info('epoch %d lr %e', epoch, lr)
|
||||
model.drop_path_prob = args.drop_path_prob * epoch / args.epochs
|
||||
|
||||
train_acc, train_obj = train(train_queue, model, criterion, optimizer)
|
||||
logging.info('train_acc %f', train_acc)
|
||||
writer.add_scalar('Acc/train', train_acc, epoch)
|
||||
writer.add_scalar('Obj/train', train_obj, epoch)
|
||||
|
||||
## scheduler
|
||||
scheduler.step()
|
||||
|
||||
valid_acc, valid_obj = infer(valid_queue, model, criterion)
|
||||
logging.info('valid_acc %f', valid_acc)
|
||||
writer.add_scalar('Acc/valid', valid_acc, epoch)
|
||||
writer.add_scalar('Obj/valid', valid_obj, epoch)
|
||||
|
||||
## checkpoint
|
||||
if (epoch + 1) % args.ckpt_interval == 0:
|
||||
save_state_dict = {
|
||||
'epoch': epoch + 1,
|
||||
'state_dict': model.state_dict(),
|
||||
'optimizer': optimizer.state_dict(),
|
||||
'scheduler': scheduler.state_dict(),
|
||||
}
|
||||
ig_utils.save_checkpoint(save_state_dict, False, args.save, per_epoch=True)
|
||||
|
||||
best_valid_acc = max(best_valid_acc, valid_acc)
|
||||
logging.info('best valid_acc %f', best_valid_acc)
|
||||
writer.close()
|
||||
|
||||
|
||||
def train(train_queue, model, criterion, optimizer):
|
||||
objs = ig_utils.AvgrageMeter()
|
||||
top1 = ig_utils.AvgrageMeter()
|
||||
top5 = ig_utils.AvgrageMeter()
|
||||
model.train()
|
||||
|
||||
for step, (input, target) in enumerate(train_queue):
|
||||
input = input.cuda()
|
||||
target = target.cuda(non_blocking=True)
|
||||
|
||||
optimizer.zero_grad()
|
||||
logits, logits_aux = model(input)
|
||||
loss = criterion(logits, target)
|
||||
if args.auxiliary:
|
||||
loss_aux = criterion(logits_aux, target)
|
||||
loss += args.auxiliary_weight * loss_aux
|
||||
loss.backward()
|
||||
nn.utils.clip_grad_norm_(model.parameters(), args.grad_clip)
|
||||
optimizer.step()
|
||||
|
||||
prec1, prec5 = ig_utils.accuracy(logits, target, topk=(1, 5))
|
||||
n = input.size(0)
|
||||
objs.update(loss.data, n)
|
||||
top1.update(prec1.data, n)
|
||||
top5.update(prec5.data, n)
|
||||
|
||||
if step % args.report_freq == 0:
|
||||
logging.info('train %03d %e %f %f', step, objs.avg, top1.avg, top5.avg)
|
||||
|
||||
if args.fast:
|
||||
logging.info('//// WARNING: FAST MODE')
|
||||
break
|
||||
|
||||
return top1.avg, objs.avg
|
||||
|
||||
|
||||
def infer(valid_queue, model, criterion):
|
||||
objs = ig_utils.AvgrageMeter()
|
||||
top1 = ig_utils.AvgrageMeter()
|
||||
top5 = ig_utils.AvgrageMeter()
|
||||
model.eval()
|
||||
|
||||
with torch.no_grad():
|
||||
for step, (input, target) in enumerate(valid_queue):
|
||||
input = input.cuda()
|
||||
target = target.cuda(non_blocking=True)
|
||||
|
||||
logits, _ = model(input)
|
||||
loss = criterion(logits, target)
|
||||
|
||||
prec1, prec5 = ig_utils.accuracy(logits, target, topk=(1, 5))
|
||||
n = input.size(0)
|
||||
objs.update(loss.data, n)
|
||||
top1.update(prec1.data, n)
|
||||
top5.update(prec5.data, n)
|
||||
|
||||
if step % args.report_freq == 0:
|
||||
logging.info('valid %03d %e %f %f', step, objs.avg, top1.avg, top5.avg)
|
||||
|
||||
if args.fast:
|
||||
logging.info('//// WARNING: FAST MODE')
|
||||
break
|
||||
|
||||
return top1.avg, objs.avg
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
main()
|
||||
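Reviewer note, not part of the commit: when --from_dir is set, main() above rebuilds a Genotype from the JSON stored in best_networks.json. A self-contained sketch of that decoding step, using a hypothetical two-edge architecture string:

import json
from collections import namedtuple

Genotype = namedtuple('Genotype', 'normal normal_concat reduce reduce_concat')

# hypothetical JSON in the shape load_network_pool()/main() expect
arch_json = ('{"normal": [["sep_conv_3x3", 0], ["skip_connect", 1]], "normal_concat": [2, 3, 4, 5], '
             '"reduce": [["max_pool_3x3", 0], ["skip_connect", 1]], "reduce_concat": [2, 3, 4, 5]}')
cfg = json.loads(arch_json)
genotype = Genotype(normal=cfg['normal'], normal_concat=cfg['normal_concat'],
                    reduce=cfg['reduce'], reduce_concat=cfg['reduce_concat'])
print(genotype.normal_concat)   # [2, 3, 4, 5]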
254
sota/cnn/train_imagenet.py
Normal file
@@ -0,0 +1,254 @@
|
||||
from torch.utils.tensorboard import SummaryWriter
|
||||
import argparse
|
||||
import glob
|
||||
import logging
|
||||
import sys
|
||||
sys.path.insert(0, '../../')
|
||||
import time
|
||||
import random
|
||||
import numpy as np
|
||||
import os
|
||||
import torch
|
||||
import torch.backends.cudnn as cudnn
|
||||
import torch.nn as nn
|
||||
import torch.utils
|
||||
import torchvision.datasets as dset
|
||||
import torchvision.transforms as transforms
|
||||
from torch.autograd import Variable
|
||||
|
||||
import nasbench201.utils as utils
|
||||
from sota.cnn.model_imagenet import NetworkImageNet as Network
|
||||
import sota.cnn.genotypes as genotypes
|
||||
from sota.cnn.hdf5 import H5Dataset
|
||||
|
||||
parser = argparse.ArgumentParser("imagenet")
|
||||
parser.add_argument('--data', type=str, default='../../data', help='location of the data corpus')
|
||||
parser.add_argument('--batch_size', type=int, default=128, help='batch size')
|
||||
parser.add_argument('--learning_rate', type=float, default=0.1, help='init learning rate')
|
||||
parser.add_argument('--momentum', type=float, default=0.9, help='momentum')
|
||||
parser.add_argument('--weight_decay', type=float, default=3e-5, help='weight decay')
|
||||
parser.add_argument('--report_freq', type=float, default=100, help='report frequency')
|
||||
parser.add_argument('--gpu', type=int, default=0, help='gpu device id')
|
||||
parser.add_argument('--epochs', type=int, default=250, help='num of training epochs')
|
||||
parser.add_argument('--init_channels', type=int, default=48, help='num of init channels')
|
||||
parser.add_argument('--layers', type=int, default=14, help='total number of layers')
|
||||
parser.add_argument('--auxiliary', action='store_true', default=False, help='use auxiliary tower')
|
||||
parser.add_argument('--auxiliary_weight', type=float, default=0.4, help='weight for auxiliary loss')
|
||||
parser.add_argument('--drop_path_prob', type=float, default=0, help='drop path probability')
|
||||
parser.add_argument('--save', type=str, default='EXP', help='experiment name')
|
||||
parser.add_argument('--seed', type=int, default=0, help='random_ws seed')
|
||||
parser.add_argument('--arch', type=str, default='c10_s3_pgd', help='which architecture to use')
|
||||
parser.add_argument('--grad_clip', type=float, default=5., help='gradient clipping')
|
||||
parser.add_argument('--label_smooth', type=float, default=0.1, help='label smoothing')
|
||||
parser.add_argument('--gamma', type=float, default=0.97, help='learning rate decay')
|
||||
parser.add_argument('--decay_period', type=int, default=1, help='epochs between two learning rate decays')
|
||||
parser.add_argument('--parallel', action='store_true', default=False, help='darts parallelism')
|
||||
parser.add_argument('--load', action='store_true', default=False, help='whether load checkpoint for continue training')
|
||||
args = parser.parse_args()
|
||||
|
||||
args.save = '../../experiments/sota/imagenet/eval/{}-{}-{}-{}'.format(
|
||||
args.save, time.strftime("%Y%m%d-%H%M%S"), args.arch, args.seed)
|
||||
if args.auxiliary:
|
||||
args.save += '-auxiliary-' + str(args.auxiliary_weight)
|
||||
args.save += '-' + str(np.random.randint(10000))
|
||||
utils.create_exp_dir(args.save, scripts_to_save=glob.glob('*.py'))
|
||||
|
||||
log_format = '%(asctime)s %(message)s'
|
||||
logging.basicConfig(stream=sys.stdout, level=logging.INFO,
|
||||
format=log_format, datefmt='%m/%d %I:%M:%S %p')
|
||||
fh = logging.FileHandler(os.path.join(args.save, 'log.txt'))
|
||||
fh.setFormatter(logging.Formatter(log_format))
|
||||
logging.getLogger().addHandler(fh)
|
||||
writer = SummaryWriter(args.save + '/runs')
|
||||
|
||||
|
||||
CLASSES = 1000
|
||||
|
||||
|
||||
class CrossEntropyLabelSmooth(nn.Module):
|
||||
|
||||
def __init__(self, num_classes, epsilon):
|
||||
super(CrossEntropyLabelSmooth, self).__init__()
|
||||
self.num_classes = num_classes
|
||||
self.epsilon = epsilon
|
||||
self.logsoftmax = nn.LogSoftmax(dim=1)
|
||||
|
||||
def forward(self, inputs, targets):
|
||||
log_probs = self.logsoftmax(inputs)
|
||||
targets = torch.zeros_like(log_probs).scatter_(1, targets.unsqueeze(1), 1)
|
||||
targets = (1 - self.epsilon) * targets + self.epsilon / self.num_classes
|
||||
loss = (-targets * log_probs).mean(0).sum()
|
||||
return loss
|
||||
|
||||
def seed_torch(seed=0):
|
||||
random.seed(seed)
|
||||
np.random.seed(seed)
|
||||
os.environ['PYTHONHASHSEED'] = str(seed)
|
||||
torch.manual_seed(seed)
|
||||
torch.cuda.manual_seed(seed)
|
||||
cudnn.deterministic = True
|
||||
cudnn.benchmark = False
|
||||
|
||||
def main():
|
||||
if not torch.cuda.is_available():
|
||||
logging.info('no gpu device available')
|
||||
sys.exit(1)
|
||||
|
||||
torch.cuda.set_device(args.gpu)
|
||||
cudnn.enabled = True
|
||||
seed_torch(args.seed)
|
||||
|
||||
logging.info('gpu device = %d' % args.gpu)
|
||||
logging.info("args = %s", args)
|
||||
|
||||
genotype = eval("genotypes.%s" % args.arch)
|
||||
model = Network(args.init_channels, CLASSES, args.layers, args.auxiliary, genotype)
|
||||
|
||||
if args.parallel:
|
||||
model = nn.DataParallel(model).cuda()
|
||||
else:
|
||||
model = model.cuda()
|
||||
|
||||
|
||||
logging.info("param size = %fMB", utils.count_parameters_in_MB(model))
|
||||
|
||||
criterion = nn.CrossEntropyLoss()
|
||||
criterion = criterion.cuda()
|
||||
criterion_smooth = CrossEntropyLabelSmooth(CLASSES, args.label_smooth)
|
||||
criterion_smooth = criterion_smooth.cuda()
|
||||
|
||||
optimizer = torch.optim.SGD(
|
||||
model.parameters(),
|
||||
args.learning_rate,
|
||||
momentum=args.momentum,
|
||||
weight_decay=args.weight_decay
|
||||
)
|
||||
|
||||
normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
|
||||
train_transform = transforms.Compose([
|
||||
transforms.RandomResizedCrop(224),
|
||||
transforms.RandomHorizontalFlip(),
|
||||
transforms.ColorJitter(
|
||||
brightness=0.4,
|
||||
contrast=0.4,
|
||||
saturation=0.4,
|
||||
hue=0.2),
|
||||
transforms.ToTensor(),
|
||||
normalize,
|
||||
])
|
||||
|
||||
test_transform = transforms.Compose([
|
||||
transforms.Resize(256),
|
||||
transforms.CenterCrop(224),
|
||||
transforms.ToTensor(),
|
||||
normalize,
|
||||
])
|
||||
|
||||
train_data = H5Dataset(os.path.join(args.data, 'imagenet-train-256.h5'), transform=train_transform)
|
||||
valid_data = H5Dataset(os.path.join(args.data, 'imagenet-val-256.h5'), transform=test_transform)
|
||||
|
||||
train_queue = torch.utils.data.DataLoader(
|
||||
train_data, batch_size=args.batch_size, shuffle=True, pin_memory=True, num_workers=4)
|
||||
|
||||
valid_queue = torch.utils.data.DataLoader(
|
||||
valid_data, batch_size=args.batch_size, shuffle=False, pin_memory=True, num_workers=4)
|
||||
|
||||
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, args.decay_period, gamma=args.gamma)
|
||||
|
||||
if args.load:
|
||||
model, optimizer, start_epoch, best_acc_top1 = utils.load_checkpoint(
|
||||
model, optimizer, '../../experiments/sota/imagenet/eval/EXP-20200210-143540-c10_s3_pgd-0-auxiliary-0.4-2753')
|
||||
else:
|
||||
best_acc_top1 = 0
|
||||
start_epoch = 0
|
||||
|
||||
for epoch in range(start_epoch, args.epochs):
|
||||
logging.info('epoch %d lr %e', epoch, scheduler.get_lr()[0])
|
||||
model.drop_path_prob = args.drop_path_prob * epoch / args.epochs
|
||||
|
||||
train_acc, train_obj = train(train_queue, model, criterion_smooth, optimizer)
|
||||
logging.info('train_acc %f', train_acc)
|
||||
writer.add_scalar('Acc/train', train_acc, epoch)
|
||||
writer.add_scalar('Obj/train', train_obj, epoch)
|
||||
scheduler.step()
|
||||
|
||||
valid_acc_top1, valid_acc_top5, valid_obj = infer(valid_queue, model, criterion)
|
||||
logging.info('valid_acc_top1 %f', valid_acc_top1)
|
||||
logging.info('valid_acc_top5 %f', valid_acc_top5)
|
||||
writer.add_scalar('Acc/valid_top1', valid_acc_top1, epoch)
|
||||
writer.add_scalar('Acc/valid_top5', valid_acc_top5, epoch)
|
||||
|
||||
is_best = False
|
||||
if valid_acc_top1 > best_acc_top1:
|
||||
best_acc_top1 = valid_acc_top1
|
||||
is_best = True
|
||||
|
||||
utils.save_checkpoint({
|
||||
'epoch': epoch + 1,
|
||||
'state_dict': model.state_dict(),
|
||||
'best_acc_top1': best_acc_top1,
|
||||
'optimizer': optimizer.state_dict(),
|
||||
}, is_best, args.save)
|
||||
|
||||
|
||||
def train(train_queue, model, criterion, optimizer):
|
||||
objs = utils.AvgrageMeter()
|
||||
top1 = utils.AvgrageMeter()
|
||||
top5 = utils.AvgrageMeter()
|
||||
model.train()
|
||||
|
||||
for step, (input, target) in enumerate(train_queue):
|
||||
input = input.cuda()
|
||||
target = target.cuda(non_blocking=True)
|
||||
|
||||
optimizer.zero_grad()
|
||||
logits, logits_aux = model(input)
|
||||
loss = criterion(logits, target)
|
||||
if args.auxiliary:
|
||||
loss_aux = criterion(logits_aux, target)
|
||||
loss += args.auxiliary_weight * loss_aux
|
||||
|
||||
loss.backward()
|
||||
nn.utils.clip_grad_norm_(model.parameters(), args.grad_clip)
|
||||
optimizer.step()
|
||||
|
||||
prec1, prec5 = utils.accuracy(logits, target, topk=(1, 5))
|
||||
n = input.size(0)
|
||||
objs.update(loss.data, n)
|
||||
top1.update(prec1.data, n)
|
||||
top5.update(prec5.data, n)
|
||||
|
||||
if step % args.report_freq == 0:
|
||||
logging.info('train %03d %e %f %f', step, objs.avg, top1.avg, top5.avg)
|
||||
|
||||
return top1.avg, objs.avg
|
||||
|
||||
|
||||
def infer(valid_queue, model, criterion):
|
||||
objs = utils.AvgrageMeter()
|
||||
top1 = utils.AvgrageMeter()
|
||||
top5 = utils.AvgrageMeter()
|
||||
model.eval()
|
||||
|
||||
with torch.no_grad():
|
||||
for step, (input, target) in enumerate(valid_queue):
|
||||
input = input.cuda()
|
||||
target = target.cuda(non_blocking=True)
|
||||
|
||||
logits, _ = model(input)
|
||||
loss = criterion(logits, target)
|
||||
|
||||
prec1, prec5 = utils.accuracy(logits, target, topk=(1, 5))
|
||||
n = input.size(0)
|
||||
objs.update(loss.data, n)
|
||||
top1.update(prec1.data, n)
|
||||
top5.update(prec5.data, n)
|
||||
|
||||
if step % args.report_freq == 0:
|
||||
logging.info('valid %03d %e %f %f', step, objs.avg, top1.avg, top5.avg)
|
||||
|
||||
return top1.avg, top5.avg, objs.avg
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
main()
|
||||
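Reviewer note, not part of the commit: CrossEntropyLabelSmooth above mixes the one-hot target with a uniform distribution, which, to the best of my knowledge, matches the label_smoothing option built into nn.CrossEntropyLoss since PyTorch 1.10. A quick sanity sketch assuming the class above is in scope:

import torch
import torch.nn as nn

torch.manual_seed(0)
logits = torch.randn(4, 1000)
targets = torch.randint(0, 1000, (4,))

loss_custom = CrossEntropyLabelSmooth(num_classes=1000, epsilon=0.1)(logits, targets)
loss_builtin = nn.CrossEntropyLoss(label_smoothing=0.1)(logits, targets)  # PyTorch >= 1.10
print(loss_custom.item(), loss_builtin.item())  # expected to agree up to floating-point error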
67
sota/cnn/visualize.py
Normal file
@@ -0,0 +1,67 @@
|
||||
import sys
|
||||
import genotypes
|
||||
from graphviz import Digraph
|
||||
|
||||
|
||||
def plot(genotype, filename, mode=''):
|
||||
g = Digraph(
|
||||
format='pdf',
|
||||
edge_attr=dict(fontsize='40', fontname="times"),
|
||||
node_attr=dict(style='filled', shape='rect', align='center', fontsize='40', height='0.5', width='0.5',
|
||||
penwidth='2', fontname="times"),
|
||||
engine='dot')
|
||||
|
||||
g.body.extend(['rankdir=LR'])
|
||||
|
||||
# g.body.extend(['ratio=0.15'])
|
||||
# g.view()
|
||||
|
||||
g.node("c_{k-2}", fillcolor='darkseagreen2')
|
||||
g.node("c_{k-1}", fillcolor='darkseagreen2')
|
||||
assert len(genotype) % 2 == 0
|
||||
steps = len(genotype) // 2
|
||||
|
||||
for i in range(steps):
|
||||
g.node(str(i), fillcolor='lightblue')
|
||||
|
||||
for i in range(steps):
|
||||
for k in [2 * i, 2 * i + 1]:
|
||||
op, j = genotype[k]
|
||||
if j == 0:
|
||||
u = "c_{k-2}"
|
||||
elif j == 1:
|
||||
u = "c_{k-1}"
|
||||
else:
|
||||
u = str(j - 2)
|
||||
v = str(i)
|
||||
|
||||
if mode == 'cue' and op != 'skip_connect' and op != 'noise':
|
||||
g.edge(u, v, label=op, fillcolor='gray', color='red', fontcolor='red')
|
||||
else:
|
||||
g.edge(u, v, label=op, fillcolor="gray")
|
||||
|
||||
g.node("c_{k}", fillcolor='palegoldenrod')
|
||||
for i in range(steps):
|
||||
g.edge(str(i), "c_{k}", fillcolor="gray")
|
||||
|
||||
g.render(filename, view=False)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
if len(sys.argv) != 2:
|
||||
print("usage:\n python {} ARCH_NAME".format(sys.argv[0]))
|
||||
sys.exit(1)
|
||||
|
||||
genotype_name = sys.argv[1]
|
||||
try:
|
||||
genotype = eval('genotypes.{}'.format(genotype_name))
|
||||
# print(genotype)
|
||||
except AttributeError:
|
||||
print("{} is not specified in genotypes.py".format(genotype_name))
|
||||
sys.exit(1)
|
||||
|
||||
mode = 'cue'
|
||||
path = '../../figs/genotypes/cnn_{}/'.format(mode)
|
||||
# print(genotype.normal)
|
||||
plot(genotype.normal, path + genotype_name + "_normal", mode=mode)
|
||||
plot(genotype.reduce, path + genotype_name + "_reduce", mode=mode)
|
||||
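Reviewer note, not part of the commit: plot() above takes a flat list of (op, input_index) pairs, two per intermediate node, where indices 0 and 1 mean c_{k-2} and c_{k-1} and larger indices refer to earlier intermediate nodes. A toy invocation with a made-up genotype and output path:

# assumes plot() from this file is in scope and graphviz is installed
toy_normal = [('sep_conv_3x3', 0), ('skip_connect', 1),
              ('sep_conv_3x3', 0), ('dil_conv_3x3', 2),
              ('skip_connect', 0), ('sep_conv_5x5', 1),
              ('max_pool_3x3', 1), ('skip_connect', 4)]
plot(toy_normal, '/tmp/toy_normal', mode='cue')   # renders /tmp/toy_normal.pdf, non-skip ops drawn in red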
144
sota/cnn/visualize_full.py
Normal file
@@ -0,0 +1,144 @@
|
||||
import sys
|
||||
import genotypes
|
||||
import numpy as np
|
||||
from graphviz import Digraph
|
||||
|
||||
|
||||
supernet_dict = {
|
||||
0: ('c_{k-2}', '0'),
|
||||
1: ('c_{k-1}', '0'),
|
||||
2: ('c_{k-2}', '1'),
|
||||
3: ('c_{k-1}', '1'),
|
||||
4: ('0', '1'),
|
||||
5: ('c_{k-2}', '2'),
|
||||
6: ('c_{k-1}', '2'),
|
||||
7: ('0', '2'),
|
||||
8: ('1', '2'),
|
||||
9: ('c_{k-2}', '3'),
|
||||
10: ('c_{k-1}', '3'),
|
||||
11: ('0', '3'),
|
||||
12: ('1', '3'),
|
||||
13: ('2', '3'),
|
||||
}
|
||||
steps = 4
|
||||
|
||||
def plot_space(primitives, filename):
|
||||
g = Digraph(
|
||||
format='pdf',
|
||||
edge_attr=dict(fontsize='20', fontname="times"),
|
||||
node_attr=dict(style='filled', shape='rect', align='center', fontsize='20', height='0.5', width='0.5', penwidth='2', fontname="times"),
|
||||
engine='dot')
|
||||
g.body.extend(['rankdir=LR'])
|
||||
g.body.extend(['ratio=50.0'])
|
||||
|
||||
g.node("c_{k-2}", fillcolor='darkseagreen2')
|
||||
g.node("c_{k-1}", fillcolor='darkseagreen2')
|
||||
|
||||
steps = 4
|
||||
|
||||
for i in range(steps):
|
||||
g.node(str(i), fillcolor='lightblue')
|
||||
|
||||
n = 2
|
||||
start = 0
|
||||
nodes_indx = ["c_{k-2}", "c_{k-1}"]
|
||||
for i in range(steps):
|
||||
end = start + n
|
||||
p = primitives[start:end]
|
||||
v = str(i)
|
||||
for node, prim in zip(nodes_indx, p):
|
||||
u = node
|
||||
for op in prim:
|
||||
g.edge(u, v, label=op, fillcolor="gray")
|
||||
|
||||
start = end
|
||||
n += 1
|
||||
nodes_indx.append(v)
|
||||
|
||||
g.node("c_{k}", fillcolor='palegoldenrod')
|
||||
for i in range(steps):
|
||||
g.edge(str(i), "c_{k}", fillcolor="gray")
|
||||
|
||||
g.render(filename, view=False)
|
||||
|
||||
|
||||
def plot(genotype, filename):
|
||||
g = Digraph(
|
||||
format='pdf',
|
||||
edge_attr=dict(fontsize='100', fontname="times"),
|
||||
node_attr=dict(style='filled', shape='rect', align='center', fontsize='100', height='0.5', width='0.5', penwidth='2', fontname="times"),
|
||||
engine='dot')
|
||||
g.body.extend(['rankdir=LR'])
|
||||
g.body.extend(['ratio=0.3'])
|
||||
|
||||
g.node("c_{k-2}", fillcolor='darkseagreen2')
|
||||
g.node("c_{k-1}", fillcolor='darkseagreen2')
|
||||
num_edges = len(genotype)
|
||||
|
||||
for i in range(steps):
|
||||
g.node(str(i), fillcolor='lightblue')
|
||||
|
||||
for eid in range(num_edges):
|
||||
op = genotype[eid]
|
||||
u, v = supernet_dict[eid]
|
||||
if op != 'skip_connect':
|
||||
g.edge(u, v, label=op, fillcolor="gray", color='red', fontcolor='red')
|
||||
else:
|
||||
g.edge(u, v, label=op, fillcolor="gray")
|
||||
|
||||
g.node("c_{k}", fillcolor='palegoldenrod')
|
||||
for i in range(steps):
|
||||
g.edge(str(i), "c_{k}", fillcolor="gray")
|
||||
|
||||
g.render(filename, view=False)
|
||||
|
||||
|
||||
|
||||
# def plot(genotype, filename):
|
||||
# g = Digraph(
|
||||
# format='pdf',
|
||||
# edge_attr=dict(fontsize='100', fontname="times", penwidth='3'),
|
||||
# node_attr=dict(style='filled', shape='rect', align='center', fontsize='100', height='0.5', width='0.5',
|
||||
# penwidth='2', fontname="times"),
|
||||
# engine='dot')
|
||||
# g.body.extend(['rankdir=LR'])
|
||||
|
||||
# g.node("c_{k-2}", fillcolor='darkseagreen2')
|
||||
# g.node("c_{k-1}", fillcolor='darkseagreen2')
|
||||
# num_edges = len(genotype)
|
||||
|
||||
# for i in range(steps):
|
||||
# g.node(str(i), fillcolor='lightblue')
|
||||
|
||||
# for eid in range(num_edges):
|
||||
# op = genotype[eid]
|
||||
# u, v = supernet_dict[eid]
|
||||
# if op != 'skip_connect':
|
||||
# g.edge(u, v, label=op, fillcolor="gray", color='red', fontcolor='red')
|
||||
# else:
|
||||
# g.edge(u, v, label=op, fillcolor="gray")
|
||||
|
||||
# g.node("c_{k}", fillcolor='palegoldenrod')
|
||||
# for i in range(steps):
|
||||
# g.edge(str(i), "c_{k}", fillcolor="gray")
|
||||
|
||||
# g.render(filename, view=False)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
#### visualize the supernet ####
|
||||
if len(sys.argv) != 2:
|
||||
print("usage:\n python {} ARCH_NAME".format(sys.argv[0]))
|
||||
sys.exit(1)
|
||||
|
||||
genotype_name = sys.argv[1]
|
||||
assert 'supernet' in genotype_name, 'this script only supports supernet visualization'
|
||||
try:
|
||||
genotype = eval('genotypes.{}'.format(genotype_name))
|
||||
except AttributeError:
|
||||
print("{} is not specified in genotypes.py".format(genotype_name))
|
||||
sys.exit(1)
|
||||
|
||||
path = '../../figs/genotypes/cnn_supernet_cue/'
|
||||
plot(genotype.normal, path + genotype_name + "_normal")
|
||||
plot(genotype.reduce, path + genotype_name + "_reduce")
|
||||
972
toy_model.ipynb
Normal file
File diff suppressed because one or more lines are too long
@@ -0,0 +1,54 @@
|
||||
GitPython was originally written by Michael Trier.
|
||||
GitPython 0.2 was partially (re)written by Sebastian Thiel, based on 0.1.6 and git-dulwich.
|
||||
|
||||
Contributors are:
|
||||
|
||||
-Michael Trier <mtrier _at_ gmail.com>
|
||||
-Alan Briolat
|
||||
-Florian Apolloner <florian _at_ apolloner.eu>
|
||||
-David Aguilar <davvid _at_ gmail.com>
|
||||
-Jelmer Vernooij <jelmer _at_ samba.org>
|
||||
-Steve Frécinaux <code _at_ istique.net>
|
||||
-Kai Lautaportti <kai _at_ lautaportti.fi>
|
||||
-Paul Sowden <paul _at_ idontsmoke.co.uk>
|
||||
-Sebastian Thiel <byronimo _at_ gmail.com>
|
||||
-Jonathan Chu <jonathan.chu _at_ me.com>
|
||||
-Vincent Driessen <me _at_ nvie.com>
|
||||
-Phil Elson <pelson _dot_ pub _at_ gmail.com>
|
||||
-Bernard `Guyzmo` Pratz <guyzmo+gitpython+pub@m0g.net>
|
||||
-Timothy B. Hartman <tbhartman _at_ gmail.com>
|
||||
-Konstantin Popov <konstantin.popov.89 _at_ yandex.ru>
|
||||
-Peter Jones <pjones _at_ redhat.com>
|
||||
-Anson Mansfield <anson.mansfield _at_ gmail.com>
|
||||
-Ken Odegard <ken.odegard _at_ gmail.com>
|
||||
-Alexis Horgix Chotard
|
||||
-Piotr Babij <piotr.babij _at_ gmail.com>
|
||||
-Mikuláš Poul <mikulaspoul _at_ gmail.com>
|
||||
-Charles Bouchard-Légaré <cblegare.atl _at_ ntis.ca>
|
||||
-Yaroslav Halchenko <debian _at_ onerussian.com>
|
||||
-Tim Swast <swast _at_ google.com>
|
||||
-William Luc Ritchie
|
||||
-David Host <hostdm _at_ outlook.com>
|
||||
-A. Jesse Jiryu Davis <jesse _at_ emptysquare.net>
|
||||
-Steven Whitman <ninloot _at_ gmail.com>
|
||||
-Stefan Stancu <stefan.stancu _at_ gmail.com>
|
||||
-César Izurieta <cesar _at_ caih.org>
|
||||
-Arthur Milchior <arthur _at_ milchior.fr>
|
||||
-Anil Khatri <anil.soccer.khatri _at_ gmail.com>
|
||||
-JJ Graham <thetwoj _at_ gmail.com>
|
||||
-Ben Thayer <ben _at_ benthayer.com>
|
||||
-Dries Kennes <admin _at_ dries007.net>
|
||||
-Pratik Anurag <panurag247365 _at_ gmail.com>
|
||||
-Harmon <harmon.public _at_ gmail.com>
|
||||
-Liam Beguin <liambeguin _at_ gmail.com>
|
||||
-Ram Rachum <ram _at_ rachum.com>
|
||||
-Alba Mendez <me _at_ alba.sh>
|
||||
-Robert Westman <robert _at_ byteflux.io>
|
||||
-Hugo van Kemenade
|
||||
-Hiroki Tokunaga <tokusan441 _at_ gmail.com>
|
||||
-Julien Mauroy <pro.julien.mauroy _at_ gmail.com>
|
||||
-Patrick Gerard
|
||||
-Luke Twist <itsluketwist@gmail.com>
|
||||
-Joseph Hale <me _at_ jhale.dev>
|
||||
-Santos Gallegos <stsewd _at_ proton.me>
|
||||
Portions derived from other open source works and are clearly marked.
|
||||
@@ -0,0 +1,30 @@
|
||||
Copyright (C) 2008, 2009 Michael Trier and contributors
|
||||
All rights reserved.
|
||||
|
||||
Redistribution and use in source and binary forms, with or without
|
||||
modification, are permitted provided that the following conditions
|
||||
are met:
|
||||
|
||||
* Redistributions of source code must retain the above copyright
|
||||
notice, this list of conditions and the following disclaimer.
|
||||
|
||||
* Redistributions in binary form must reproduce the above copyright
|
||||
notice, this list of conditions and the following disclaimer in the
|
||||
documentation and/or other materials provided with the distribution.
|
||||
|
||||
* Neither the name of the GitPython project nor the names of
|
||||
its contributors may be used to endorse or promote products derived
|
||||
from this software without specific prior written permission.
|
||||
|
||||
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
|
||||
"AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
|
||||
LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
|
||||
A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
|
||||
OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
|
||||
SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED
|
||||
TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
|
||||
PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
|
||||
LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
|
||||
NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
|
||||
SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
|
||||
@@ -0,0 +1,34 @@
|
||||
Metadata-Version: 2.1
|
||||
Name: GitPython
|
||||
Version: 3.1.31
|
||||
Summary: GitPython is a Python library used to interact with Git repositories
|
||||
Home-page: https://github.com/gitpython-developers/GitPython
|
||||
Author: Sebastian Thiel, Michael Trier
|
||||
Author-email: byronimo@gmail.com, mtrier@gmail.com
|
||||
License: BSD
|
||||
Platform: UNKNOWN
|
||||
Classifier: Development Status :: 5 - Production/Stable
|
||||
Classifier: Environment :: Console
|
||||
Classifier: Intended Audience :: Developers
|
||||
Classifier: License :: OSI Approved :: BSD License
|
||||
Classifier: Operating System :: OS Independent
|
||||
Classifier: Operating System :: POSIX
|
||||
Classifier: Operating System :: Microsoft :: Windows
|
||||
Classifier: Operating System :: MacOS :: MacOS X
|
||||
Classifier: Typing :: Typed
|
||||
Classifier: Programming Language :: Python
|
||||
Classifier: Programming Language :: Python :: 3
|
||||
Classifier: Programming Language :: Python :: 3.7
|
||||
Classifier: Programming Language :: Python :: 3.8
|
||||
Classifier: Programming Language :: Python :: 3.9
|
||||
Classifier: Programming Language :: Python :: 3.10
|
||||
Classifier: Programming Language :: Python :: 3.11
|
||||
Requires-Python: >=3.7
|
||||
Description-Content-Type: text/markdown
|
||||
License-File: LICENSE
|
||||
License-File: AUTHORS
|
||||
Requires-Dist: gitdb (<5,>=4.0.1)
|
||||
Requires-Dist: typing-extensions (>=3.7.4.3) ; python_version < "3.8"
|
||||
|
||||
GitPython is a Python library used to interact with Git repositories
|
||||
|
||||
@@ -0,0 +1,44 @@
|
||||
git/__init__.py,sha256=O2tZaGpLYVQiK9lN3NucvyEoZcSFig13tAB6d2TTTL0,2342
|
||||
git/cmd.py,sha256=i4IyhmCTP-72NPO5aVeWhDT6_jLmA1C2qzhsS7G2UVw,53712
|
||||
git/compat.py,sha256=3wWLkD9QrZvLiV6NtNxJILwGrLE2nw_SoLqaTEPH364,2256
|
||||
git/config.py,sha256=PO6qicfkKwRFlKJr9AUuDrWV0rimlmb5S2wIVLlOd7w,34581
|
||||
git/db.py,sha256=dEs2Bn-iDuHyero9afw8mrXHrLE7_CDExv943iWU9WI,2244
|
||||
git/diff.py,sha256=DOWd26Dk7FqnKt79zpniv19muBzdYa949TcQPqVbZGg,23434
|
||||
git/exc.py,sha256=ys5ZYuvzvNN3TfcB5R_bUNRy3OEvURS5pJMdfy0Iws4,6446
|
||||
git/py.typed,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
||||
git/remote.py,sha256=H88bonpIjnfozWScpQIFIccE7Soq2hfHO9ldnRCmGUY,45069
|
||||
git/types.py,sha256=bA4El-NC7YNwQ9jNtkbWgT0QmmAfVs4PVSwBOE_D1Bo,3020
|
||||
git/util.py,sha256=j5cjyeFibLs4Ed_ErkePf6sx1VWb95OQ4GlJUWgq6PU,39874
|
||||
git/index/__init__.py,sha256=43ovvVNocVRNiQd4fLqvUMuGGmwhBQ9SsiQ46vkvk1E,89
|
||||
git/index/base.py,sha256=5GnqwmhLNF9f12hUq4rQyOvqzxPF1Fdc0QOETT5r010,57523
|
||||
git/index/fun.py,sha256=Y41IGlu8XqnradQXFjTGMISM45m8J256bTKs4xWR4qY,16406
|
||||
git/index/typ.py,sha256=QnyWeqzU7_xnyiwOki5W633Jp9g5COqEf6B4PeW3hK8,6252
|
||||
git/index/util.py,sha256=ISsWZjGiflooNr6XtElP4AhWUxQOakouvgeXC2PEepI,3475
|
||||
git/objects/__init__.py,sha256=NW8HBfdZvBYe9W6IjMWafSj_DVlV2REmmrpWKrHkGVw,692
|
||||
git/objects/base.py,sha256=N2NTL9hLwgKqY-VQiar8Hvn4a41Y8o_Tmi_SR0mGAS8,7857
|
||||
git/objects/blob.py,sha256=FIbZTYniJ7nLsdrHuwhagFVs9tYoUIyXodRaHYLaQqs,986
|
||||
git/objects/commit.py,sha256=ji9ityweewpr12mHh9w3s3ubouYNNCRTBr-LBrjrPbs,27304
|
||||
git/objects/fun.py,sha256=SV3_G_jnEb_wEa5doF6AapX58StH3OGBxCAKeMyuA0I,8612
|
||||
git/objects/tag.py,sha256=ZXOLK_lV9E5G2aDl5t0hYDN2hhIhGF23HILHBnZgRX0,3840
|
||||
git/objects/tree.py,sha256=cSQbt3nn3cIrbVrBasB1wm2r-vzotYWhka1yDjOHf-k,14230
|
||||
git/objects/util.py,sha256=M8h53ueOV32nXE6XcnKhCHzXznT7pi8JpEEGgCNicXo,22275
|
||||
git/objects/submodule/__init__.py,sha256=OsMeiex7cG6ev2f35IaJ5csH-eXchSoNKCt4HXUG5Ws,93
|
||||
git/objects/submodule/base.py,sha256=R4jTjBJyMjFOfDAYwsA6Q3Lt6qeFYERPE4PABACW6GE,61539
|
||||
git/objects/submodule/root.py,sha256=Ev_RnGzv4hi3UqEFMHuSR-uGR7kYpwOgwZFUG31X-Hc,19568
|
||||
git/objects/submodule/util.py,sha256=u2zQGFWBmryqET0XWf9BuiY1OOgWB8YCU3Wz0xdp4E4,3380
|
||||
git/refs/__init__.py,sha256=PMF97jMUcivbCCEJnl2zTs-YtECNFp8rL8GHK8AitXU,203
|
||||
git/refs/head.py,sha256=rZ4LbFd05Gs9sAuSU5VQRDmJZfrwMwWtBpLlmiUQ-Zg,9756
|
||||
git/refs/log.py,sha256=Z8X9_ZGZrVTWz9p_-fk1N3m47G-HTRPwozoZBDd70DI,11892
|
||||
git/refs/reference.py,sha256=DUx7QvYqTBeVxG53ntPfKCp3wuJyDBRIZcPCy1OD22s,5414
|
||||
git/refs/remote.py,sha256=E63Bh5ig1GYrk6FE46iNtS5P6ZgODyPXot8eJw-mxts,2556
|
||||
git/refs/symbolic.py,sha256=XwfeYr1Zp-fuHAoGuVAXKk4EYlsuUMVu99OjJWuWDTQ,29967
|
||||
git/refs/tag.py,sha256=FNoCZ3BdDl2i5kD3si2P9hoXU9rDAZ_YK0Rn84TmKT8,4419
|
||||
git/repo/__init__.py,sha256=XMpdeowJRtTEd80jAcrKSQfMu2JZGMfPlpuIYHG2ZCk,80
|
||||
git/repo/base.py,sha256=uD4EL2AWUMSCHCqIk7voXoZ2iChaf5VJ1t1Abr7Zk10,54937
|
||||
git/repo/fun.py,sha256=VTRODXAb_x8bazkSd8g-Pkk8M2iLVK4kPoKQY9HXjZc,12962
|
||||
GitPython-3.1.31.dist-info/AUTHORS,sha256=0F09KKrRmwH3zJ4gqo1tJMVlalC9bSunDNKlRvR6q2c,2158
|
||||
GitPython-3.1.31.dist-info/LICENSE,sha256=_WV__CzvY9JceMq3gI1BTdA6KC5jiTSR_RHDL5i-Z_s,1521
|
||||
GitPython-3.1.31.dist-info/METADATA,sha256=zFy5SrG7Ur2UItx3seZXELCST9LBEX72wZa7Y7z7FSY,1340
|
||||
GitPython-3.1.31.dist-info/WHEEL,sha256=ewwEueio1C2XeHTvT17n8dZUJgOvyCWCt0WVNLClP9o,92
|
||||
GitPython-3.1.31.dist-info/top_level.txt,sha256=0hzDuIp8obv624V3GmbqsagBWkk8ohtGU-Bc1PmTT0o,4
|
||||
GitPython-3.1.31.dist-info/RECORD,,
|
||||
@@ -0,0 +1,5 @@
Wheel-Version: 1.0
Generator: bdist_wheel (0.37.0)
Root-Is-Purelib: true
Tag: py3-none-any

@@ -0,0 +1 @@
gitdb<5,>=4.0.1
@@ -0,0 +1 @@
git
@@ -0,0 +1,92 @@
|
||||
# __init__.py
|
||||
# Copyright (C) 2008, 2009 Michael Trier (mtrier@gmail.com) and contributors
|
||||
#
|
||||
# This module is part of GitPython and is released under
|
||||
# the BSD License: http://www.opensource.org/licenses/bsd-license.php
|
||||
# flake8: noqa
|
||||
# @PydevCodeAnalysisIgnore
|
||||
from git.exc import * # @NoMove @IgnorePep8
|
||||
import inspect
|
||||
import os
|
||||
import sys
|
||||
import os.path as osp
|
||||
|
||||
from typing import Optional
|
||||
from git.types import PathLike
|
||||
|
||||
__version__ = '3.1.31'
|
||||
|
||||
|
||||
# { Initialization
|
||||
def _init_externals() -> None:
|
||||
"""Initialize external projects by putting them into the path"""
|
||||
if __version__ == '3.1.31' and "PYOXIDIZER" not in os.environ:
|
||||
sys.path.insert(1, osp.join(osp.dirname(__file__), "ext", "gitdb"))
|
||||
|
||||
try:
|
||||
import gitdb
|
||||
except ImportError as e:
|
||||
raise ImportError("'gitdb' could not be found in your PYTHONPATH") from e
|
||||
# END verify import
|
||||
|
||||
|
||||
# } END initialization
|
||||
|
||||
|
||||
#################
|
||||
_init_externals()
|
||||
#################
|
||||
|
||||
# { Imports
|
||||
|
||||
try:
|
||||
from git.config import GitConfigParser # @NoMove @IgnorePep8
|
||||
from git.objects import * # @NoMove @IgnorePep8
|
||||
from git.refs import * # @NoMove @IgnorePep8
|
||||
from git.diff import * # @NoMove @IgnorePep8
|
||||
from git.db import * # @NoMove @IgnorePep8
|
||||
from git.cmd import Git # @NoMove @IgnorePep8
|
||||
from git.repo import Repo # @NoMove @IgnorePep8
|
||||
from git.remote import * # @NoMove @IgnorePep8
|
||||
from git.index import * # @NoMove @IgnorePep8
|
||||
from git.util import ( # @NoMove @IgnorePep8
|
||||
LockFile,
|
||||
BlockingLockFile,
|
||||
Stats,
|
||||
Actor,
|
||||
rmtree,
|
||||
)
|
||||
except GitError as exc:
|
||||
raise ImportError("%s: %s" % (exc.__class__.__name__, exc)) from exc
|
||||
|
||||
# } END imports
|
||||
|
||||
__all__ = [name for name, obj in locals().items() if not (name.startswith("_") or inspect.ismodule(obj))]
|
||||
|
||||
|
||||
# { Initialize git executable path
|
||||
GIT_OK = None
|
||||
|
||||
|
||||
def refresh(path: Optional[PathLike] = None) -> None:
|
||||
"""Convenience method for setting the git executable path."""
|
||||
global GIT_OK
|
||||
GIT_OK = False
|
||||
|
||||
if not Git.refresh(path=path):
|
||||
return
|
||||
if not FetchInfo.refresh():
|
||||
return
|
||||
|
||||
GIT_OK = True
|
||||
|
||||
|
||||
# } END initialize git executable path
|
||||
|
||||
|
||||
#################
|
||||
try:
|
||||
refresh()
|
||||
except Exception as exc:
|
||||
raise ImportError("Failed to initialize: {0}".format(exc)) from exc
|
||||
#################
|
||||
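Reviewer note, not part of the commit: the vendored GitPython egg above is consumed through the top-level git package; a minimal usage sketch against any existing working tree with a git binary on PATH:

import git

repo = git.Repo('.')                    # open the current working tree
head = repo.head.commit
print(head.hexsha[:8], head.summary)    # abbreviated sha and commit subject
print(git.GIT_OK)                       # True once refresh() above has located the git executable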
1417
zero-cost-nas/.eggs/GitPython-3.1.31-py3.8.egg/git/cmd.py
Normal file
File diff suppressed because it is too large
Load Diff
104
zero-cost-nas/.eggs/GitPython-3.1.31-py3.8.egg/git/compat.py
Normal file
@@ -0,0 +1,104 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
# config.py
|
||||
# Copyright (C) 2008, 2009 Michael Trier (mtrier@gmail.com) and contributors
|
||||
#
|
||||
# This module is part of GitPython and is released under
|
||||
# the BSD License: http://www.opensource.org/licenses/bsd-license.php
|
||||
"""utilities to help provide compatibility with python 3"""
|
||||
# flake8: noqa
|
||||
|
||||
import locale
|
||||
import os
|
||||
import sys
|
||||
|
||||
from gitdb.utils.encoding import (
|
||||
force_bytes, # @UnusedImport
|
||||
force_text, # @UnusedImport
|
||||
)
|
||||
|
||||
# typing --------------------------------------------------------------------
|
||||
|
||||
from typing import (
|
||||
Any,
|
||||
AnyStr,
|
||||
Dict,
|
||||
IO,
|
||||
Optional,
|
||||
Tuple,
|
||||
Type,
|
||||
Union,
|
||||
overload,
|
||||
)
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
is_win: bool = os.name == "nt"
|
||||
is_posix = os.name == "posix"
|
||||
is_darwin = os.name == "darwin"
|
||||
defenc = sys.getfilesystemencoding()
|
||||
|
||||
|
||||
@overload
|
||||
def safe_decode(s: None) -> None:
|
||||
...
|
||||
|
||||
|
||||
@overload
|
||||
def safe_decode(s: AnyStr) -> str:
|
||||
...
|
||||
|
||||
|
||||
def safe_decode(s: Union[AnyStr, None]) -> Optional[str]:
|
||||
"""Safely decodes a binary string to unicode"""
|
||||
if isinstance(s, str):
|
||||
return s
|
||||
elif isinstance(s, bytes):
|
||||
return s.decode(defenc, "surrogateescape")
|
||||
elif s is None:
|
||||
return None
|
||||
else:
|
||||
raise TypeError("Expected bytes or text, but got %r" % (s,))
|
||||
|
||||
|
||||
@overload
|
||||
def safe_encode(s: None) -> None:
|
||||
...
|
||||
|
||||
|
||||
@overload
|
||||
def safe_encode(s: AnyStr) -> bytes:
|
||||
...
|
||||
|
||||
|
||||
def safe_encode(s: Optional[AnyStr]) -> Optional[bytes]:
|
||||
"""Safely encodes a binary string to unicode"""
|
||||
if isinstance(s, str):
|
||||
return s.encode(defenc)
|
||||
elif isinstance(s, bytes):
|
||||
return s
|
||||
elif s is None:
|
||||
return None
|
||||
else:
|
||||
raise TypeError("Expected bytes or text, but got %r" % (s,))
|
||||
|
||||
|
||||
@overload
|
||||
def win_encode(s: None) -> None:
|
||||
...
|
||||
|
||||
|
||||
@overload
|
||||
def win_encode(s: AnyStr) -> bytes:
|
||||
...
|
||||
|
||||
|
||||
def win_encode(s: Optional[AnyStr]) -> Optional[bytes]:
|
||||
"""Encode unicodes for process arguments on Windows."""
|
||||
if isinstance(s, str):
|
||||
return s.encode(locale.getpreferredencoding(False))
|
||||
elif isinstance(s, bytes):
|
||||
return s
|
||||
elif s is not None:
|
||||
raise TypeError("Expected bytes or text, but got %r" % (s,))
|
||||
return None
|
||||
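Reviewer note, not part of the commit: a short sketch of the safe_decode/safe_encode helpers defined above; they pass None through and otherwise use the filesystem encoding.

# assumes safe_decode / safe_encode from this module are in scope
raw = b'hello'
text = safe_decode(raw)           # bytes -> str via sys.getfilesystemencoding(), surrogateescape on errors
assert safe_encode(text) == raw   # str -> bytes round-trips for plain ASCII
assert safe_decode(None) is None and safe_encode(None) is None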
897
zero-cost-nas/.eggs/GitPython-3.1.31-py3.8.egg/git/config.py
Normal file
@@ -0,0 +1,897 @@
|
||||
# config.py
|
||||
# Copyright (C) 2008, 2009 Michael Trier (mtrier@gmail.com) and contributors
|
||||
#
|
||||
# This module is part of GitPython and is released under
|
||||
# the BSD License: http://www.opensource.org/licenses/bsd-license.php
|
||||
"""Module containing module parser implementation able to properly read and write
|
||||
configuration files"""
|
||||
|
||||
import sys
|
||||
import abc
|
||||
from functools import wraps
|
||||
import inspect
|
||||
from io import BufferedReader, IOBase
|
||||
import logging
|
||||
import os
|
||||
import re
|
||||
import fnmatch
|
||||
|
||||
from git.compat import (
|
||||
defenc,
|
||||
force_text,
|
||||
is_win,
|
||||
)
|
||||
|
||||
from git.util import LockFile
|
||||
|
||||
import os.path as osp
|
||||
|
||||
import configparser as cp
|
||||
|
||||
# typing-------------------------------------------------------
|
||||
|
||||
from typing import (
|
||||
Any,
|
||||
Callable,
|
||||
Generic,
|
||||
IO,
|
||||
List,
|
||||
Dict,
|
||||
Sequence,
|
||||
TYPE_CHECKING,
|
||||
Tuple,
|
||||
TypeVar,
|
||||
Union,
|
||||
cast,
|
||||
)
|
||||
|
||||
from git.types import Lit_config_levels, ConfigLevels_Tup, PathLike, assert_never, _T
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from git.repo.base import Repo
|
||||
from io import BytesIO
|
||||
|
||||
T_ConfigParser = TypeVar("T_ConfigParser", bound="GitConfigParser")
|
||||
T_OMD_value = TypeVar("T_OMD_value", str, bytes, int, float, bool)
|
||||
|
||||
if sys.version_info[:3] < (3, 7, 2):
|
||||
# typing.Ordereddict not added until py 3.7.2
|
||||
from collections import OrderedDict
|
||||
|
||||
OrderedDict_OMD = OrderedDict
|
||||
else:
|
||||
from typing import OrderedDict
|
||||
|
||||
OrderedDict_OMD = OrderedDict[str, List[T_OMD_value]] # type: ignore[assignment, misc]
|
||||
|
||||
# -------------------------------------------------------------
|
||||
|
||||
__all__ = ("GitConfigParser", "SectionConstraint")
|
||||
|
||||
|
||||
log = logging.getLogger("git.config")
|
||||
log.addHandler(logging.NullHandler())
|
||||
|
||||
# invariants
|
||||
# represents the configuration level of a configuration file
|
||||
|
||||
|
||||
CONFIG_LEVELS: ConfigLevels_Tup = ("system", "user", "global", "repository")
|
||||
|
||||
|
||||
# Section pattern to detect conditional includes.
|
||||
# https://git-scm.com/docs/git-config#_conditional_includes
|
||||
CONDITIONAL_INCLUDE_REGEXP = re.compile(r"(?<=includeIf )\"(gitdir|gitdir/i|onbranch):(.+)\"")
|
||||
|
||||
|
||||
class MetaParserBuilder(abc.ABCMeta): # noqa: B024
|
||||
"""Utility class wrapping base-class methods into decorators that assure read-only properties"""
|
||||
|
||||
def __new__(cls, name: str, bases: Tuple, clsdict: Dict[str, Any]) -> "MetaParserBuilder":
|
||||
"""
|
||||
Equip all base-class methods with a needs_values decorator, and all non-const methods
|
||||
with a set_dirty_and_flush_changes decorator in addition to that."""
|
||||
kmm = "_mutating_methods_"
|
||||
if kmm in clsdict:
|
||||
mutating_methods = clsdict[kmm]
|
||||
for base in bases:
|
||||
methods = (t for t in inspect.getmembers(base, inspect.isroutine) if not t[0].startswith("_"))
|
||||
for name, method in methods:
|
||||
if name in clsdict:
|
||||
continue
|
||||
method_with_values = needs_values(method)
|
||||
if name in mutating_methods:
|
||||
method_with_values = set_dirty_and_flush_changes(method_with_values)
|
||||
# END mutating methods handling
|
||||
|
||||
clsdict[name] = method_with_values
|
||||
# END for each name/method pair
|
||||
# END for each base
|
||||
# END if mutating methods configuration is set
|
||||
|
||||
new_type = super(MetaParserBuilder, cls).__new__(cls, name, bases, clsdict)
|
||||
return new_type
|
||||
|
||||
|
||||
def needs_values(func: Callable[..., _T]) -> Callable[..., _T]:
|
||||
"""Returns method assuring we read values (on demand) before we try to access them"""
|
||||
|
||||
@wraps(func)
|
||||
def assure_data_present(self: "GitConfigParser", *args: Any, **kwargs: Any) -> _T:
|
||||
self.read()
|
||||
return func(self, *args, **kwargs)
|
||||
|
||||
# END wrapper method
|
||||
return assure_data_present
|
||||
|
||||
|
||||
def set_dirty_and_flush_changes(non_const_func: Callable[..., _T]) -> Callable[..., _T]:
|
||||
"""Return method that checks whether given non constant function may be called.
|
||||
If so, the instance will be set dirty.
|
||||
Additionally, we flush the changes right to disk"""
|
||||
|
||||
def flush_changes(self: "GitConfigParser", *args: Any, **kwargs: Any) -> _T:
|
||||
rval = non_const_func(self, *args, **kwargs)
|
||||
self._dirty = True
|
||||
self.write()
|
||||
return rval
|
||||
|
||||
# END wrapper method
|
||||
flush_changes.__name__ = non_const_func.__name__
|
||||
return flush_changes
|
||||
|
||||
|
||||
class SectionConstraint(Generic[T_ConfigParser]):
|
||||
|
||||
"""Constrains a ConfigParser to only option commands which are constrained to
|
||||
always use the section we have been initialized with.
|
||||
|
||||
It supports all ConfigParser methods that operate on an option.
|
||||
|
||||
:note:
|
||||
If used as a context manager, will release the wrapped ConfigParser."""
|
||||
|
||||
__slots__ = ("_config", "_section_name")
|
||||
_valid_attrs_ = (
|
||||
"get_value",
|
||||
"set_value",
|
||||
"get",
|
||||
"set",
|
||||
"getint",
|
||||
"getfloat",
|
||||
"getboolean",
|
||||
"has_option",
|
||||
"remove_section",
|
||||
"remove_option",
|
||||
"options",
|
||||
)
|
||||
|
||||
def __init__(self, config: T_ConfigParser, section: str) -> None:
|
||||
self._config = config
|
||||
self._section_name = section
|
||||
|
||||
def __del__(self) -> None:
|
||||
# Yes, for some reason, we have to call it explicitly for it to work in PY3 !
|
||||
# Apparently __del__ doesn't get called anymore if refcount becomes 0
|
||||
# Ridiculous ... .
|
||||
self._config.release()
|
||||
|
||||
def __getattr__(self, attr: str) -> Any:
|
||||
if attr in self._valid_attrs_:
|
||||
return lambda *args, **kwargs: self._call_config(attr, *args, **kwargs)
|
||||
return super(SectionConstraint, self).__getattribute__(attr)
|
||||
|
||||
def _call_config(self, method: str, *args: Any, **kwargs: Any) -> Any:
|
||||
"""Call the configuration at the given method which must take a section name
|
||||
as first argument"""
|
||||
return getattr(self._config, method)(self._section_name, *args, **kwargs)
|
||||
|
||||
@property
|
||||
def config(self) -> T_ConfigParser:
|
||||
"""return: Configparser instance we constrain"""
|
||||
return self._config
|
||||
|
||||
def release(self) -> None:
|
||||
"""Equivalent to GitConfigParser.release(), which is called on our underlying parser instance"""
|
||||
return self._config.release()
|
||||
|
||||
def __enter__(self) -> "SectionConstraint[T_ConfigParser]":
|
||||
self._config.__enter__()
|
||||
return self
|
||||
|
||||
def __exit__(self, exception_type: str, exception_value: str, traceback: str) -> None:
|
||||
self._config.__exit__(exception_type, exception_value, traceback)
|
||||
|
||||
|
||||
class _OMD(OrderedDict_OMD):
|
||||
"""Ordered multi-dict."""
|
||||
|
||||
def __setitem__(self, key: str, value: _T) -> None:
|
||||
super(_OMD, self).__setitem__(key, [value])
|
||||
|
||||
def add(self, key: str, value: Any) -> None:
|
||||
if key not in self:
|
||||
super(_OMD, self).__setitem__(key, [value])
|
||||
return None
|
||||
super(_OMD, self).__getitem__(key).append(value)
|
||||
|
||||
def setall(self, key: str, values: List[_T]) -> None:
|
||||
super(_OMD, self).__setitem__(key, values)
|
||||
|
||||
def __getitem__(self, key: str) -> Any:
|
||||
return super(_OMD, self).__getitem__(key)[-1]
|
||||
|
||||
def getlast(self, key: str) -> Any:
|
||||
return super(_OMD, self).__getitem__(key)[-1]
|
||||
|
||||
def setlast(self, key: str, value: Any) -> None:
|
||||
if key not in self:
|
||||
super(_OMD, self).__setitem__(key, [value])
|
||||
return
|
||||
|
||||
prior = super(_OMD, self).__getitem__(key)
|
||||
prior[-1] = value
|
||||
|
||||
def get(self, key: str, default: Union[_T, None] = None) -> Union[_T, None]:
|
||||
return super(_OMD, self).get(key, [default])[-1]
|
||||
|
||||
def getall(self, key: str) -> List[_T]:
|
||||
return super(_OMD, self).__getitem__(key)
|
||||
|
||||
def items(self) -> List[Tuple[str, _T]]: # type: ignore[override]
|
||||
"""List of (key, last value for key)."""
|
||||
return [(k, self[k]) for k in self]
|
||||
|
||||
def items_all(self) -> List[Tuple[str, List[_T]]]:
|
||||
"""List of (key, list of values for key)."""
|
||||
return [(k, self.getall(k)) for k in self]
|
||||
|
||||
|
||||
def get_config_path(config_level: Lit_config_levels) -> str:
|
||||
|
||||
# we do not support an absolute path of the gitconfig on Windows,
|
||||
# use the global config instead
|
||||
if is_win and config_level == "system":
|
||||
config_level = "global"
|
||||
|
||||
if config_level == "system":
|
||||
return "/etc/gitconfig"
|
||||
elif config_level == "user":
|
||||
config_home = os.environ.get("XDG_CONFIG_HOME") or osp.join(os.environ.get("HOME", "~"), ".config")
|
||||
return osp.normpath(osp.expanduser(osp.join(config_home, "git", "config")))
|
||||
elif config_level == "global":
|
||||
return osp.normpath(osp.expanduser("~/.gitconfig"))
|
||||
elif config_level == "repository":
|
||||
raise ValueError("No repo to get repository configuration from. Use Repo._get_config_path")
|
||||
else:
|
||||
# Should not reach here. Will raise ValueError if it does. Static typing will warn about missing elifs
|
||||
assert_never(
|
||||
config_level, # type: ignore[unreachable]
|
||||
ValueError(f"Invalid configuration level: {config_level!r}"),
|
||||
)
|
||||
|
||||
|
||||
class GitConfigParser(cp.RawConfigParser, metaclass=MetaParserBuilder):
|
||||
|
||||
"""Implements specifics required to read git style configuration files.
|
||||
|
||||
This variation behaves much like the git.config command such that the configuration
|
||||
will be read on demand based on the filepath given during initialization.
|
||||
|
||||
The changes will automatically be written once the instance goes out of scope, but
|
||||
can be triggered manually as well.
|
||||
|
||||
The configuration file will be locked if you intend to change values preventing other
|
||||
instances from writing concurrently.
|
||||
|
||||
:note:
|
||||
The config is case-sensitive even when queried, hence section and option names
|
||||
must match perfectly.
|
||||
If used as a context manager, will release the locked file."""
|
||||
|
||||
# { Configuration
|
||||
# The lock type determines the type of lock to use in new configuration readers.
|
||||
# They must be compatible to the LockFile interface.
|
||||
# A suitable alternative would be the BlockingLockFile
|
||||
t_lock = LockFile
|
||||
re_comment = re.compile(r"^\s*[#;]")
|
||||
|
||||
# } END configuration
|
||||
|
||||
optvalueonly_source = r"\s*(?P<option>[^:=\s][^:=]*)"
|
||||
|
||||
OPTVALUEONLY = re.compile(optvalueonly_source)
|
||||
|
||||
OPTCRE = re.compile(optvalueonly_source + r"\s*(?P<vi>[:=])\s*" + r"(?P<value>.*)$")
|
||||
|
||||
del optvalueonly_source
|
||||
|
||||
# list of RawConfigParser methods able to change the instance
|
||||
_mutating_methods_ = ("add_section", "remove_section", "remove_option", "set")
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
file_or_files: Union[None, PathLike, "BytesIO", Sequence[Union[PathLike, "BytesIO"]]] = None,
|
||||
read_only: bool = True,
|
||||
merge_includes: bool = True,
|
||||
config_level: Union[Lit_config_levels, None] = None,
|
||||
repo: Union["Repo", None] = None,
|
||||
) -> None:
|
||||
"""Initialize a configuration reader to read the given file_or_files and to
|
||||
possibly allow changes to it by setting read_only False
|
||||
|
||||
:param file_or_files:
|
||||
A single file path or file objects or multiple of these
|
||||
|
||||
:param read_only:
|
||||
If True, the ConfigParser may only read the data, but not change it.
|
||||
If False, only a single file path or file object may be given. We will write back the changes
|
||||
when they happen, or when the ConfigParser is released. This will not happen if other
|
||||
configuration files have been included
|
||||
:param merge_includes: if True, we will read files mentioned in [include] sections and merge their
|
||||
contents into ours. This makes it impossible to write back an individual configuration file.
|
||||
Thus, if you want to modify a single configuration file, turn this off to leave the original
|
||||
dataset unaltered when reading it.
|
||||
:param repo: Reference to repository to use if [includeIf] sections are found in configuration files.
|
||||
|
||||
"""
|
||||
cp.RawConfigParser.__init__(self, dict_type=_OMD)
|
||||
self._dict: Callable[..., _OMD] # type: ignore # mypy/typeshed bug?
|
||||
self._defaults: _OMD
|
||||
self._sections: _OMD # type: ignore # mypy/typeshed bug?
|
||||
|
||||
# Used in python 3, needs to stay in sync with sections for underlying implementation to work
|
||||
if not hasattr(self, "_proxies"):
|
||||
self._proxies = self._dict()
|
||||
|
||||
if file_or_files is not None:
|
||||
self._file_or_files: Union[PathLike, "BytesIO", Sequence[Union[PathLike, "BytesIO"]]] = file_or_files
|
||||
else:
|
||||
if config_level is None:
|
||||
if read_only:
|
||||
self._file_or_files = [
|
||||
get_config_path(cast(Lit_config_levels, f)) for f in CONFIG_LEVELS if f != "repository"
|
||||
]
|
||||
else:
|
||||
raise ValueError("No configuration level or configuration files specified")
|
||||
else:
|
||||
self._file_or_files = [get_config_path(config_level)]
|
||||
|
||||
self._read_only = read_only
|
||||
self._dirty = False
|
||||
self._is_initialized = False
|
||||
self._merge_includes = merge_includes
|
||||
self._repo = repo
|
||||
self._lock: Union["LockFile", None] = None
|
||||
self._acquire_lock()
|
||||
|
||||
def _acquire_lock(self) -> None:
|
||||
if not self._read_only:
|
||||
if not self._lock:
|
||||
if isinstance(self._file_or_files, (str, os.PathLike)):
|
||||
file_or_files = self._file_or_files
|
||||
elif isinstance(self._file_or_files, (tuple, list, Sequence)):
|
||||
raise ValueError(
|
||||
"Write-ConfigParsers can operate on a single file only, multiple files have been passed"
|
||||
)
|
||||
else:
|
||||
file_or_files = self._file_or_files.name
|
||||
|
||||
# END get filename from handle/stream
|
||||
# initialize lock base - we want to write
|
||||
self._lock = self.t_lock(file_or_files)
|
||||
# END lock check
|
||||
|
||||
self._lock._obtain_lock()
|
||||
# END read-only check
|
||||
|
||||
def __del__(self) -> None:
|
||||
"""Write pending changes if required and release locks"""
|
||||
# NOTE: only consistent in PY2
|
||||
self.release()
|
||||
|
||||
def __enter__(self) -> "GitConfigParser":
|
||||
self._acquire_lock()
|
||||
return self
|
||||
|
||||
def __exit__(self, *args: Any) -> None:
|
||||
self.release()
|
||||
|
||||
def release(self) -> None:
|
||||
"""Flush changes and release the configuration write lock. This instance must not be used anymore afterwards.
|
||||
In Python 3, it's required to explicitly release locks and flush changes, as __del__ is not called
|
||||
deterministically anymore."""
|
||||
# checking for the lock here makes sure we do not raise during write()
|
||||
# in case an invalid parser was created who could not get a lock
|
||||
if self.read_only or (self._lock and not self._lock._has_lock()):
|
||||
return
|
||||
|
||||
try:
|
||||
try:
|
||||
self.write()
|
||||
except IOError:
|
||||
log.error("Exception during destruction of GitConfigParser", exc_info=True)
|
||||
except ReferenceError:
|
||||
# This happens in PY3 ... and usually means that some state cannot be written
|
||||
# as the sections dict cannot be iterated
|
||||
# Usually when shutting down the interpreter, don't know how to fix this
|
||||
pass
|
||||
finally:
|
||||
if self._lock is not None:
|
||||
self._lock._release_lock()
|
||||
|
||||
def optionxform(self, optionstr: str) -> str:
|
||||
"""Do not transform options in any way when writing"""
|
||||
return optionstr
|
||||
|
||||
def _read(self, fp: Union[BufferedReader, IO[bytes]], fpname: str) -> None:
|
||||
"""A direct copy of the py2.4 version of the super class's _read method
|
||||
to assure it uses ordered dicts. Had to change one line to make it work.
|
||||
|
||||
Future versions have this fixed, but in fact it's quite embarrassing for the
|
||||
guys not to have done it right in the first place !
|
||||
|
||||
Removed big comments to make it more compact.
|
||||
|
||||
Made sure it ignores initial whitespace as git uses tabs"""
|
||||
cursect = None # None, or a dictionary
|
||||
optname = None
|
||||
lineno = 0
|
||||
is_multi_line = False
|
||||
e = None # None, or an exception
|
||||
|
||||
def string_decode(v: str) -> str:
|
||||
if v[-1] == "\\":
|
||||
v = v[:-1]
|
||||
# end cut trailing escapes to prevent decode error
|
||||
|
||||
return v.encode(defenc).decode("unicode_escape")
|
||||
# end
|
||||
|
||||
# end
|
||||
|
||||
while True:
|
||||
# we assume to read binary !
|
||||
line = fp.readline().decode(defenc)
|
||||
if not line:
|
||||
break
|
||||
lineno = lineno + 1
|
||||
# comment or blank line?
|
||||
if line.strip() == "" or self.re_comment.match(line):
|
||||
continue
|
||||
if line.split(None, 1)[0].lower() == "rem" and line[0] in "rR":
|
||||
# no leading whitespace
|
||||
continue
|
||||
|
||||
# is it a section header?
|
||||
mo = self.SECTCRE.match(line.strip())
|
||||
if not is_multi_line and mo:
|
||||
sectname: str = mo.group("header").strip()
|
||||
if sectname in self._sections:
|
||||
cursect = self._sections[sectname]
|
||||
elif sectname == cp.DEFAULTSECT:
|
||||
cursect = self._defaults
|
||||
else:
|
||||
cursect = self._dict((("__name__", sectname),))
|
||||
self._sections[sectname] = cursect
|
||||
self._proxies[sectname] = None
|
||||
# So sections can't start with a continuation line
|
||||
optname = None
|
||||
# no section header in the file?
|
||||
elif cursect is None:
|
||||
raise cp.MissingSectionHeaderError(fpname, lineno, line)
|
||||
# an option line?
|
||||
elif not is_multi_line:
|
||||
mo = self.OPTCRE.match(line)
|
||||
if mo:
|
||||
# We might just have handled the last line, which could contain a quotation we want to remove
|
||||
optname, vi, optval = mo.group("option", "vi", "value")
|
||||
if vi in ("=", ":") and ";" in optval and not optval.strip().startswith('"'):
|
||||
pos = optval.find(";")
|
||||
if pos != -1 and optval[pos - 1].isspace():
|
||||
optval = optval[:pos]
|
||||
optval = optval.strip()
|
||||
if optval == '""':
|
||||
optval = ""
|
||||
# end handle empty string
|
||||
optname = self.optionxform(optname.rstrip())
|
||||
if len(optval) > 1 and optval[0] == '"' and optval[-1] != '"':
|
||||
is_multi_line = True
|
||||
optval = string_decode(optval[1:])
|
||||
# end handle multi-line
|
||||
# preserves multiple values for duplicate optnames
|
||||
cursect.add(optname, optval)
|
||||
else:
|
||||
# check if it's an option with no value - it's just ignored by git
|
||||
if not self.OPTVALUEONLY.match(line):
|
||||
if not e:
|
||||
e = cp.ParsingError(fpname)
|
||||
e.append(lineno, repr(line))
|
||||
continue
|
||||
else:
|
||||
line = line.rstrip()
|
||||
if line.endswith('"'):
|
||||
is_multi_line = False
|
||||
line = line[:-1]
|
||||
# end handle quotations
|
||||
optval = cursect.getlast(optname)
|
||||
cursect.setlast(optname, optval + string_decode(line))
|
||||
# END parse section or option
|
||||
# END while reading
|
||||
|
||||
# if any parsing errors occurred, raise an exception
|
||||
if e:
|
||||
raise e
|
||||
|
||||
def _has_includes(self) -> Union[bool, int]:
|
||||
return self._merge_includes and len(self._included_paths())
|
||||
|
||||
def _included_paths(self) -> List[Tuple[str, str]]:
|
||||
"""Return List all paths that must be included to configuration
|
||||
as tuples of (option, value).
|
||||
"""
|
||||
paths = []
|
||||
|
||||
for section in self.sections():
|
||||
if section == "include":
|
||||
paths += self.items(section)
|
||||
|
||||
match = CONDITIONAL_INCLUDE_REGEXP.search(section)
|
||||
if match is None or self._repo is None:
|
||||
continue
|
||||
|
||||
keyword = match.group(1)
|
||||
value = match.group(2).strip()
|
||||
|
||||
if keyword in ["gitdir", "gitdir/i"]:
|
||||
value = osp.expanduser(value)
|
||||
|
||||
if not any(value.startswith(s) for s in ["./", "/"]):
|
||||
value = "**/" + value
|
||||
if value.endswith("/"):
|
||||
value += "**"
|
||||
|
||||
# Ensure that glob is always case insensitive if required.
|
||||
if keyword.endswith("/i"):
|
||||
value = re.sub(
|
||||
r"[a-zA-Z]",
|
||||
lambda m: "[{}{}]".format(m.group().lower(), m.group().upper()),
|
||||
value,
|
||||
)
|
||||
if self._repo.git_dir:
|
||||
if fnmatch.fnmatchcase(str(self._repo.git_dir), value):
|
||||
paths += self.items(section)
|
||||
|
||||
elif keyword == "onbranch":
|
||||
try:
|
||||
branch_name = self._repo.active_branch.name
|
||||
except TypeError:
|
||||
# Ignore section if active branch cannot be retrieved.
|
||||
continue
|
||||
|
||||
if fnmatch.fnmatchcase(branch_name, value):
|
||||
paths += self.items(section)
|
||||
|
||||
return paths
|
||||
|
||||
def read(self) -> None: # type: ignore[override]
|
||||
"""Reads the data stored in the files we have been initialized with. It will
|
||||
ignore files that cannot be read, possibly leaving an empty configuration
|
||||
|
||||
:return: Nothing
|
||||
:raise IOError: if a file cannot be handled"""
|
||||
if self._is_initialized:
|
||||
return None
|
||||
self._is_initialized = True
|
||||
|
||||
files_to_read: List[Union[PathLike, IO]] = [""]
|
||||
if isinstance(self._file_or_files, (str, os.PathLike)):
|
||||
# for str or Path, as str is a type of Sequence
|
||||
files_to_read = [self._file_or_files]
|
||||
elif not isinstance(self._file_or_files, (tuple, list, Sequence)):
|
||||
# could merge with above isinstance once runtime type known
|
||||
files_to_read = [self._file_or_files]
|
||||
else: # for lists or tuples
|
||||
files_to_read = list(self._file_or_files)
|
||||
# end assure we have a copy of the paths to handle
|
||||
|
||||
seen = set(files_to_read)
|
||||
num_read_include_files = 0
|
||||
while files_to_read:
|
||||
file_path = files_to_read.pop(0)
|
||||
file_ok = False
|
||||
|
||||
if hasattr(file_path, "seek"):
|
||||
# must be a file-like object
|
||||
file_path = cast(IO[bytes], file_path) # replace with assert to narrow type, once sure
|
||||
self._read(file_path, file_path.name)
|
||||
else:
|
||||
# assume a path if it is not a file-object
|
||||
file_path = cast(PathLike, file_path)
|
||||
try:
|
||||
with open(file_path, "rb") as fp:
|
||||
file_ok = True
|
||||
self._read(fp, fp.name)
|
||||
except IOError:
|
||||
continue
|
||||
|
||||
# Read includes and append those that we didn't handle yet
|
||||
# We expect all paths to be normalized and absolute (and will assure that is the case)
|
||||
if self._has_includes():
|
||||
for _, include_path in self._included_paths():
|
||||
if include_path.startswith("~"):
|
||||
include_path = osp.expanduser(include_path)
|
||||
if not osp.isabs(include_path):
|
||||
if not file_ok:
|
||||
continue
|
||||
# end ignore relative paths if we don't know the configuration file path
|
||||
file_path = cast(PathLike, file_path)
|
||||
assert osp.isabs(file_path), "Need absolute paths to be sure our cycle checks will work"
|
||||
include_path = osp.join(osp.dirname(file_path), include_path)
|
||||
# end make include path absolute
|
||||
include_path = osp.normpath(include_path)
|
||||
if include_path in seen or not os.access(include_path, os.R_OK):
|
||||
continue
|
||||
seen.add(include_path)
|
||||
# insert included file to the top to be considered first
|
||||
files_to_read.insert(0, include_path)
|
||||
num_read_include_files += 1
|
||||
# each include path in configuration file
|
||||
# end handle includes
|
||||
# END for each file object to read
|
||||
|
||||
# If there was no file included, we can safely write back (potentially) the configuration file
|
||||
# without altering its meaning
|
||||
if num_read_include_files == 0:
|
||||
self._merge_includes = False
|
||||
# end
|
||||
|
||||
def _write(self, fp: IO) -> None:
|
||||
"""Write an .ini-format representation of the configuration state in
|
||||
git compatible format"""
|
||||
|
||||
def write_section(name: str, section_dict: _OMD) -> None:
|
||||
fp.write(("[%s]\n" % name).encode(defenc))
|
||||
|
||||
values: Sequence[str] # runtime only gets str in tests, but should be whatever _OMD stores
|
||||
v: str
|
||||
for (key, values) in section_dict.items_all():
|
||||
if key == "__name__":
|
||||
continue
|
||||
|
||||
for v in values:
|
||||
fp.write(("\t%s = %s\n" % (key, self._value_to_string(v).replace("\n", "\n\t"))).encode(defenc))
|
||||
# END if key is not __name__
|
||||
|
||||
# END section writing
|
||||
|
||||
if self._defaults:
|
||||
write_section(cp.DEFAULTSECT, self._defaults)
|
||||
value: _OMD
|
||||
|
||||
for name, value in self._sections.items():
|
||||
write_section(name, value)
|
||||
|
||||
def items(self, section_name: str) -> List[Tuple[str, str]]: # type: ignore[override]
|
||||
""":return: list((option, value), ...) pairs of all items in the given section"""
|
||||
return [(k, v) for k, v in super(GitConfigParser, self).items(section_name) if k != "__name__"]
|
||||
|
||||
def items_all(self, section_name: str) -> List[Tuple[str, List[str]]]:
|
||||
""":return: list((option, [values...]), ...) pairs of all items in the given section"""
|
||||
rv = _OMD(self._defaults)
|
||||
|
||||
for k, vs in self._sections[section_name].items_all():
|
||||
if k == "__name__":
|
||||
continue
|
||||
|
||||
if k in rv and rv.getall(k) == vs:
|
||||
continue
|
||||
|
||||
for v in vs:
|
||||
rv.add(k, v)
|
||||
|
||||
return rv.items_all()
|
||||
|
||||
@needs_values
|
||||
def write(self) -> None:
|
||||
"""Write changes to our file, if there are changes at all
|
||||
|
||||
:raise IOError: if this is a read-only writer instance or if we could not obtain
|
||||
a file lock"""
|
||||
self._assure_writable("write")
|
||||
if not self._dirty:
|
||||
return None
|
||||
|
||||
if isinstance(self._file_or_files, (list, tuple)):
|
||||
raise AssertionError(
|
||||
"Cannot write back if there is not exactly a single file to write to, have %i files"
|
||||
% len(self._file_or_files)
|
||||
)
|
||||
# end assert multiple files
|
||||
|
||||
if self._has_includes():
|
||||
log.debug(
|
||||
"Skipping write-back of configuration file as include files were merged in."
|
||||
+ "Set merge_includes=False to prevent this."
|
||||
)
|
||||
return None
|
||||
# end
|
||||
|
||||
fp = self._file_or_files
|
||||
|
||||
# we have a physical file on disk, so get a lock
|
||||
is_file_lock = isinstance(fp, (str, os.PathLike, IOBase)) # can't use Pathlike until 3.5 dropped
|
||||
if is_file_lock and self._lock is not None: # else raise Error?
|
||||
self._lock._obtain_lock()
|
||||
|
||||
if not hasattr(fp, "seek"):
|
||||
fp = cast(PathLike, fp)
|
||||
with open(fp, "wb") as fp_open:
|
||||
self._write(fp_open)
|
||||
else:
|
||||
fp = cast("BytesIO", fp)
|
||||
fp.seek(0)
|
||||
# make sure we do not overwrite into an existing file
|
||||
if hasattr(fp, "truncate"):
|
||||
fp.truncate()
|
||||
self._write(fp)
|
||||
|
||||
def _assure_writable(self, method_name: str) -> None:
|
||||
if self.read_only:
|
||||
raise IOError("Cannot execute non-constant method %s.%s" % (self, method_name))
|
||||
|
||||
def add_section(self, section: str) -> None:
|
||||
"""Assures added options will stay in order"""
|
||||
return super(GitConfigParser, self).add_section(section)
|
||||
|
||||
@property
|
||||
def read_only(self) -> bool:
|
||||
""":return: True if this instance may change the configuration file"""
|
||||
return self._read_only
|
||||
|
||||
def get_value(
|
||||
self,
|
||||
section: str,
|
||||
option: str,
|
||||
default: Union[int, float, str, bool, None] = None,
|
||||
) -> Union[int, float, str, bool]:
|
||||
# can default or return type include bool?
|
||||
"""Get an option's value.
|
||||
|
||||
If multiple values are specified for this option in the section, the
|
||||
last one specified is returned.
|
||||
|
||||
:param default:
|
||||
If not None, the given default value will be returned in case
|
||||
the option did not exist
|
||||
:return: a properly typed value, either int, float or string
|
||||
|
||||
:raise TypeError: in case the value could not be understood
|
||||
Otherwise the exceptions known to the ConfigParser will be raised."""
|
||||
try:
|
||||
valuestr = self.get(section, option)
|
||||
except Exception:
|
||||
if default is not None:
|
||||
return default
|
||||
raise
|
||||
|
||||
return self._string_to_value(valuestr)
|
||||
|
||||
def get_values(
|
||||
self,
|
||||
section: str,
|
||||
option: str,
|
||||
default: Union[int, float, str, bool, None] = None,
|
||||
) -> List[Union[int, float, str, bool]]:
|
||||
"""Get an option's values.
|
||||
|
||||
If multiple values are specified for this option in the section, all are
|
||||
returned.
|
||||
|
||||
:param default:
|
||||
If not None, a list containing the given default value will be
|
||||
returned in case the option did not exist
|
||||
:return: a list of properly typed values, either int, float or string
|
||||
|
||||
:raise TypeError: in case the value could not be understood
|
||||
Otherwise the exceptions known to the ConfigParser will be raised."""
|
||||
try:
|
||||
self.sections()
|
||||
lst = self._sections[section].getall(option)
|
||||
except Exception:
|
||||
if default is not None:
|
||||
return [default]
|
||||
raise
|
||||
|
||||
return [self._string_to_value(valuestr) for valuestr in lst]
|
||||
|
||||
def _string_to_value(self, valuestr: str) -> Union[int, float, str, bool]:
|
||||
types = (int, float)
|
||||
for numtype in types:
|
||||
try:
|
||||
val = numtype(valuestr)
|
||||
# truncated value ?
|
||||
if val != float(valuestr):
|
||||
continue
|
||||
return val
|
||||
except (ValueError, TypeError):
|
||||
continue
|
||||
# END for each numeric type
|
||||
|
||||
# try boolean values as git uses them
|
||||
vl = valuestr.lower()
|
||||
if vl == "false":
|
||||
return False
|
||||
if vl == "true":
|
||||
return True
|
||||
|
||||
if not isinstance(valuestr, str):
|
||||
raise TypeError(
|
||||
"Invalid value type: only int, long, float and str are allowed",
|
||||
valuestr,
|
||||
)
|
||||
|
||||
return valuestr
|
||||
|
||||
def _value_to_string(self, value: Union[str, bytes, int, float, bool]) -> str:
|
||||
if isinstance(value, (int, float, bool)):
|
||||
return str(value)
|
||||
return force_text(value)
|
||||
|
||||
@needs_values
|
||||
@set_dirty_and_flush_changes
|
||||
def set_value(self, section: str, option: str, value: Union[str, bytes, int, float, bool]) -> "GitConfigParser":
|
||||
"""Sets the given option in section to the given value.
|
||||
It will create the section if required, and will not throw as opposed to the default
|
||||
ConfigParser 'set' method.
|
||||
|
||||
:param section: Name of the section in which the option resides or should reside
|
||||
:param option: Name of the options whose value to set
|
||||
|
||||
:param value: Value to set the option to. It must be a string or convertible
|
||||
to a string
|
||||
:return: this instance"""
|
||||
if not self.has_section(section):
|
||||
self.add_section(section)
|
||||
self.set(section, option, self._value_to_string(value))
|
||||
return self
|
||||
|
||||
@needs_values
|
||||
@set_dirty_and_flush_changes
|
||||
def add_value(self, section: str, option: str, value: Union[str, bytes, int, float, bool]) -> "GitConfigParser":
|
||||
"""Adds a value for the given option in section.
|
||||
It will create the section if required, and will not throw as opposed to the default
|
||||
ConfigParser 'set' method. The value becomes the new value of the option as returned
|
||||
by 'get_value', and appends to the list of values returned by 'get_values'.
|
||||
|
||||
:param section: Name of the section in which the option resides or should reside
|
||||
:param option: Name of the option
|
||||
|
||||
:param value: Value to add to option. It must be a string or convertible
|
||||
to a string
|
||||
:return: this instance"""
|
||||
if not self.has_section(section):
|
||||
self.add_section(section)
|
||||
self._sections[section].add(option, self._value_to_string(value))
|
||||
return self
|
||||
|
||||
def rename_section(self, section: str, new_name: str) -> "GitConfigParser":
|
||||
"""rename the given section to new_name
|
||||
:raise ValueError: if section doesn't exist
|
||||
:raise ValueError: if a section with new_name does already exist
|
||||
:return: this instance
|
||||
"""
|
||||
if not self.has_section(section):
|
||||
raise ValueError("Source section '%s' doesn't exist" % section)
|
||||
if self.has_section(new_name):
|
||||
raise ValueError("Destination section '%s' already exists" % new_name)
|
||||
|
||||
super(GitConfigParser, self).add_section(new_name)
|
||||
new_section = self._sections[new_name]
|
||||
for k, vs in self.items_all(section):
|
||||
new_section.setall(k, vs)
|
||||
# end for each value to copy
|
||||
|
||||
# This call writes back the changes, which is why we don't have the respective decorator
|
||||
self.remove_section(section)
|
||||
return self
|
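A short usage sketch for the GitConfigParser defined above (not part of the vendored file; the config file name is hypothetical). With read_only=False the parser locks the file and, via set_dirty_and_flush_changes, flushes every change straight back to disk:

from git.config import GitConfigParser

# Writable parser over a single (hypothetical) file; each mutating call writes back immediately.
with GitConfigParser("example.gitconfig", read_only=False) as writer:
    writer.set_value("user", "name", "Jane Doe")
    writer.add_value("core", "hooksPath", ".hooks")

# Read-only parser: get_value returns the last value, get_values returns all of them.
with GitConfigParser("example.gitconfig", read_only=True) as reader:
    assert reader.get_value("user", "name") == "Jane Doe"
    assert reader.get_values("core", "hooksPath") == [".hooks"]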
||||
63
zero-cost-nas/.eggs/GitPython-3.1.31-py3.8.egg/git/db.py
Normal file
@@ -0,0 +1,63 @@
|
||||
"""Module with our own gitdb implementation - it uses the git command"""
|
||||
from git.util import bin_to_hex, hex_to_bin
|
||||
from gitdb.base import OInfo, OStream
|
||||
from gitdb.db import GitDB # @UnusedImport
|
||||
from gitdb.db import LooseObjectDB
|
||||
|
||||
from gitdb.exc import BadObject
|
||||
from git.exc import GitCommandError
|
||||
|
||||
# typing-------------------------------------------------
|
||||
|
||||
from typing import TYPE_CHECKING
|
||||
from git.types import PathLike
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from git.cmd import Git
|
||||
|
||||
|
||||
# --------------------------------------------------------
|
||||
|
||||
__all__ = ("GitCmdObjectDB", "GitDB")
|
||||
|
||||
|
||||
class GitCmdObjectDB(LooseObjectDB):
|
||||
|
||||
"""A database representing the default git object store, which includes loose
|
||||
objects, pack files and an alternates file
|
||||
|
||||
It will create objects only in the loose object database.
|
||||
:note: for now, we use the git command to do all the lookup, just until we
|
||||
have packs and the other implementations
|
||||
"""
|
||||
|
||||
def __init__(self, root_path: PathLike, git: "Git") -> None:
|
||||
"""Initialize this instance with the root and a git command"""
|
||||
super(GitCmdObjectDB, self).__init__(root_path)
|
||||
self._git = git
|
||||
|
||||
def info(self, binsha: bytes) -> OInfo:
|
||||
hexsha, typename, size = self._git.get_object_header(bin_to_hex(binsha))
|
||||
return OInfo(hex_to_bin(hexsha), typename, size)
|
||||
|
||||
def stream(self, binsha: bytes) -> OStream:
|
||||
"""For now, all lookup is done by git itself"""
|
||||
hexsha, typename, size, stream = self._git.stream_object_data(bin_to_hex(binsha))
|
||||
return OStream(hex_to_bin(hexsha), typename, size, stream)
|
||||
|
||||
# { Interface
|
||||
|
||||
def partial_to_complete_sha_hex(self, partial_hexsha: str) -> bytes:
|
||||
""":return: Full binary 20 byte sha from the given partial hexsha
|
||||
:raise AmbiguousObjectName:
|
||||
:raise BadObject:
|
||||
:note: currently we only raise BadObject as git does not communicate
|
||||
AmbiguousObjects separately"""
|
||||
try:
|
||||
hexsha, _typename, _size = self._git.get_object_header(partial_hexsha)
|
||||
return hex_to_bin(hexsha)
|
||||
except (GitCommandError, ValueError) as e:
|
||||
raise BadObject(partial_hexsha) from e
|
||||
# END handle exceptions
|
||||
|
||||
# } END interface
|
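A hedged sketch of how the object database above is usually reached (not part of the vendored file; the repository path is hypothetical). Repo wires up a GitCmdObjectDB by default:

from git import Repo
from git.util import bin_to_hex

repo = Repo("/path/to/some/repo")         # hypothetical repository path
odb = repo.odb                            # a GitCmdObjectDB unless configured otherwise
info = odb.info(repo.head.commit.binsha)  # OInfo carrying binsha, type and size
print(bin_to_hex(info.binsha), info.type, info.size)

stream = odb.stream(repo.head.commit.binsha)
data = stream.read()                      # raw commit object payload, read via the git command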
||||
662
zero-cost-nas/.eggs/GitPython-3.1.31-py3.8.egg/git/diff.py
Normal file
@@ -0,0 +1,662 @@
|
||||
# diff.py
|
||||
# Copyright (C) 2008, 2009 Michael Trier (mtrier@gmail.com) and contributors
|
||||
#
|
||||
# This module is part of GitPython and is released under
|
||||
# the BSD License: http://www.opensource.org/licenses/bsd-license.php
|
||||
|
||||
import re
|
||||
from git.cmd import handle_process_output
|
||||
from git.compat import defenc
|
||||
from git.util import finalize_process, hex_to_bin
|
||||
|
||||
from .objects.blob import Blob
|
||||
from .objects.util import mode_str_to_int
|
||||
|
||||
|
||||
# typing ------------------------------------------------------------------
|
||||
|
||||
from typing import (
|
||||
Any,
|
||||
Iterator,
|
||||
List,
|
||||
Match,
|
||||
Optional,
|
||||
Tuple,
|
||||
Type,
|
||||
TypeVar,
|
||||
Union,
|
||||
TYPE_CHECKING,
|
||||
cast,
|
||||
)
|
||||
from git.types import PathLike, Literal
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from .objects.tree import Tree
|
||||
from .objects import Commit
|
||||
from git.repo.base import Repo
|
||||
from git.objects.base import IndexObject
|
||||
from subprocess import Popen
|
||||
from git import Git
|
||||
|
||||
Lit_change_type = Literal["A", "D", "C", "M", "R", "T", "U"]
|
||||
|
||||
|
||||
# def is_change_type(inp: str) -> TypeGuard[Lit_change_type]:
|
||||
# # return True
|
||||
# return inp in ['A', 'D', 'C', 'M', 'R', 'T', 'U']
|
||||
|
||||
# ------------------------------------------------------------------------
|
||||
|
||||
|
||||
__all__ = ("Diffable", "DiffIndex", "Diff", "NULL_TREE")
|
||||
|
||||
# Special object to compare against the empty tree in diffs
|
||||
NULL_TREE = object()
|
||||
|
||||
_octal_byte_re = re.compile(b"\\\\([0-9]{3})")
|
||||
|
||||
|
||||
def _octal_repl(matchobj: Match) -> bytes:
|
||||
value = matchobj.group(1)
|
||||
value = int(value, 8)
|
||||
value = bytes(bytearray((value,)))
|
||||
return value
|
||||
|
||||
|
||||
def decode_path(path: bytes, has_ab_prefix: bool = True) -> Optional[bytes]:
|
||||
if path == b"/dev/null":
|
||||
return None
|
||||
|
||||
if path.startswith(b'"') and path.endswith(b'"'):
|
||||
path = path[1:-1].replace(b"\\n", b"\n").replace(b"\\t", b"\t").replace(b'\\"', b'"').replace(b"\\\\", b"\\")
|
||||
|
||||
path = _octal_byte_re.sub(_octal_repl, path)
|
||||
|
||||
if has_ab_prefix:
|
||||
assert path.startswith(b"a/") or path.startswith(b"b/")
|
||||
path = path[2:]
|
||||
|
||||
return path
|
||||
|
||||
|
||||
class Diffable(object):
|
||||
|
||||
"""Common interface for all object that can be diffed against another object of compatible type.
|
||||
|
||||
:note:
|
||||
Subclasses require a repo member as it is the case for Object instances, for practical
|
||||
reasons we do not derive from Object."""
|
||||
|
||||
__slots__ = ()
|
||||
|
||||
# standin indicating you want to diff against the index
|
||||
class Index(object):
|
||||
pass
|
||||
|
||||
def _process_diff_args(
|
||||
self, args: List[Union[str, "Diffable", Type["Diffable.Index"], object]]
|
||||
) -> List[Union[str, "Diffable", Type["Diffable.Index"], object]]:
|
||||
"""
|
||||
:return:
|
||||
possibly altered version of the given args list.
|
||||
Method is called right before git command execution.
|
||||
Subclasses can use it to alter the behaviour of the superclass"""
|
||||
return args
|
||||
|
||||
def diff(
|
||||
self,
|
||||
other: Union[Type["Index"], "Tree", "Commit", None, str, object] = Index,
|
||||
paths: Union[PathLike, List[PathLike], Tuple[PathLike, ...], None] = None,
|
||||
create_patch: bool = False,
|
||||
**kwargs: Any,
|
||||
) -> "DiffIndex":
|
||||
"""Creates diffs between two items being trees, trees and index or an
|
||||
index and the working tree. It will detect renames automatically.
|
||||
|
||||
:param other:
|
||||
Is the item to compare us with.
|
||||
If None, we will be compared to the working tree.
|
||||
If Treeish, it will be compared against the respective tree
|
||||
If Index ( type ), it will be compared against the index.
|
||||
If git.NULL_TREE, it will compare against the empty tree.
|
||||
It defaults to Index to assure the method will not by-default fail
|
||||
on bare repositories.
|
||||
|
||||
:param paths:
|
||||
is a list of paths or a single path to limit the diff to.
|
||||
Only diffs touching at least one of the given paths will be included.
|
||||
|
||||
:param create_patch:
|
||||
If True, the returned Diff contains a detailed patch that if applied
|
||||
transforms self into other. Patches are somewhat costly as blobs have to be read
|
||||
and diffed.
|
||||
|
||||
:param kwargs:
|
||||
Additional arguments passed to git-diff, such as
|
||||
R=True to swap both sides of the diff.
|
||||
|
||||
:return: git.DiffIndex
|
||||
|
||||
:note:
|
||||
On a bare repository, 'other' needs to be provided as Index or as
|
||||
Tree/Commit, or a git command error will occur"""
|
||||
args: List[Union[PathLike, Diffable, Type["Diffable.Index"], object]] = []
|
||||
args.append("--abbrev=40") # we need full shas
|
||||
args.append("--full-index") # get full index paths, not only filenames
|
||||
|
||||
# remove default '-M' arg (check for renames) if user is overriding it
|
||||
if not any(x in kwargs for x in ('find_renames', 'no_renames', 'M')):
|
||||
args.append("-M")
|
||||
|
||||
if create_patch:
|
||||
args.append("-p")
|
||||
else:
|
||||
args.append("--raw")
|
||||
args.append("-z")
|
||||
|
||||
# in any way, assure we don't see colored output,
|
||||
# fixes https://github.com/gitpython-developers/GitPython/issues/172
|
||||
args.append("--no-color")
|
||||
|
||||
if paths is not None and not isinstance(paths, (tuple, list)):
|
||||
paths = [paths]
|
||||
|
||||
if hasattr(self, "Has_Repo"):
|
||||
self.repo: "Repo" = self.repo
|
||||
|
||||
diff_cmd = self.repo.git.diff
|
||||
if other is self.Index:
|
||||
args.insert(0, "--cached")
|
||||
elif other is NULL_TREE:
|
||||
args.insert(0, "-r") # recursive diff-tree
|
||||
args.insert(0, "--root")
|
||||
diff_cmd = self.repo.git.diff_tree
|
||||
elif other is not None:
|
||||
args.insert(0, "-r") # recursive diff-tree
|
||||
args.insert(0, other)
|
||||
diff_cmd = self.repo.git.diff_tree
|
||||
|
||||
args.insert(0, self)
|
||||
|
||||
# paths is list here or None
|
||||
if paths:
|
||||
args.append("--")
|
||||
args.extend(paths)
|
||||
# END paths handling
|
||||
|
||||
kwargs["as_process"] = True
|
||||
proc = diff_cmd(*self._process_diff_args(args), **kwargs)
|
||||
|
||||
diff_method = Diff._index_from_patch_format if create_patch else Diff._index_from_raw_format
|
||||
index = diff_method(self.repo, proc)
|
||||
|
||||
proc.wait()
|
||||
return index
|
||||
|
||||
|
||||
T_Diff = TypeVar("T_Diff", bound="Diff")
|
||||
|
||||
|
||||
class DiffIndex(List[T_Diff]):
|
||||
|
||||
"""Implements an Index for diffs, allowing a list of Diffs to be queried by
|
||||
the diff properties.
|
||||
|
||||
The class improves the diff handling convenience"""
|
||||
|
||||
# change type invariant identifying possible ways a blob can have changed
|
||||
# A = Added
|
||||
# D = Deleted
|
||||
# R = Renamed
|
||||
# M = Modified
|
||||
# T = Changed in the type
|
||||
change_type = ("A", "C", "D", "R", "M", "T")
|
||||
|
||||
def iter_change_type(self, change_type: Lit_change_type) -> Iterator[T_Diff]:
|
||||
"""
|
||||
:return:
|
||||
iterator yielding Diff instances that match the given change_type
|
||||
|
||||
:param change_type:
|
||||
Member of DiffIndex.change_type, namely:
|
||||
|
||||
* 'A' for added paths
|
||||
* 'D' for deleted paths
|
||||
* 'R' for renamed paths
|
||||
* 'M' for paths with modified data
|
||||
* 'T' for changed in the type paths
|
||||
"""
|
||||
if change_type not in self.change_type:
|
||||
raise ValueError("Invalid change type: %s" % change_type)
|
||||
|
||||
for diffidx in self:
|
||||
if diffidx.change_type == change_type:
|
||||
yield diffidx
|
||||
elif change_type == "A" and diffidx.new_file:
|
||||
yield diffidx
|
||||
elif change_type == "D" and diffidx.deleted_file:
|
||||
yield diffidx
|
||||
elif change_type == "C" and diffidx.copied_file:
|
||||
yield diffidx
|
||||
elif change_type == "R" and diffidx.renamed:
|
||||
yield diffidx
|
||||
elif change_type == "M" and diffidx.a_blob and diffidx.b_blob and diffidx.a_blob != diffidx.b_blob:
|
||||
yield diffidx
|
||||
# END for each diff
|
||||
|
||||
|
||||
class Diff(object):
|
||||
|
||||
"""A Diff contains diff information between two Trees.
|
||||
|
||||
It contains two sides a and b of the diff, members are prefixed with
|
||||
"a" and "b" respectively to inidcate that.
|
||||
|
||||
Diffs keep information about the changed blob objects, the file mode, renames,
|
||||
deletions and new files.
|
||||
|
||||
There are a few cases where None has to be expected as member variable value:
|
||||
|
||||
``New File``::
|
||||
|
||||
a_mode is None
|
||||
a_blob is None
|
||||
a_path is None
|
||||
|
||||
``Deleted File``::
|
||||
|
||||
b_mode is None
|
||||
b_blob is None
|
||||
b_path is None
|
||||
|
||||
``Working Tree Blobs``
|
||||
|
||||
When comparing to working trees, the working tree blob will have a null hexsha
|
||||
as a corresponding object does not yet exist. The mode will be null as well.
|
||||
But the path will be available.
|
||||
If it is listed in a diff the working tree version of the file must
|
||||
be different to the version in the index or tree, and hence has been modified."""
|
||||
|
||||
# precompiled regex
|
||||
re_header = re.compile(
|
||||
rb"""
|
||||
^diff[ ]--git
|
||||
[ ](?P<a_path_fallback>"?[ab]/.+?"?)[ ](?P<b_path_fallback>"?[ab]/.+?"?)\n
|
||||
(?:^old[ ]mode[ ](?P<old_mode>\d+)\n
|
||||
^new[ ]mode[ ](?P<new_mode>\d+)(?:\n|$))?
|
||||
(?:^similarity[ ]index[ ]\d+%\n
|
||||
^rename[ ]from[ ](?P<rename_from>.*)\n
|
||||
^rename[ ]to[ ](?P<rename_to>.*)(?:\n|$))?
|
||||
(?:^new[ ]file[ ]mode[ ](?P<new_file_mode>.+)(?:\n|$))?
|
||||
(?:^deleted[ ]file[ ]mode[ ](?P<deleted_file_mode>.+)(?:\n|$))?
|
||||
(?:^similarity[ ]index[ ]\d+%\n
|
||||
^copy[ ]from[ ].*\n
|
||||
^copy[ ]to[ ](?P<copied_file_name>.*)(?:\n|$))?
|
||||
(?:^index[ ](?P<a_blob_id>[0-9A-Fa-f]+)
|
||||
\.\.(?P<b_blob_id>[0-9A-Fa-f]+)[ ]?(?P<b_mode>.+)?(?:\n|$))?
|
||||
(?:^---[ ](?P<a_path>[^\t\n\r\f\v]*)[\t\r\f\v]*(?:\n|$))?
|
||||
(?:^\+\+\+[ ](?P<b_path>[^\t\n\r\f\v]*)[\t\r\f\v]*(?:\n|$))?
|
||||
""",
|
||||
re.VERBOSE | re.MULTILINE,
|
||||
)
|
||||
# can be used for comparisons
|
||||
NULL_HEX_SHA = "0" * 40
|
||||
NULL_BIN_SHA = b"\0" * 20
|
||||
|
||||
__slots__ = (
|
||||
"a_blob",
|
||||
"b_blob",
|
||||
"a_mode",
|
||||
"b_mode",
|
||||
"a_rawpath",
|
||||
"b_rawpath",
|
||||
"new_file",
|
||||
"deleted_file",
|
||||
"copied_file",
|
||||
"raw_rename_from",
|
||||
"raw_rename_to",
|
||||
"diff",
|
||||
"change_type",
|
||||
"score",
|
||||
)
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
repo: "Repo",
|
||||
a_rawpath: Optional[bytes],
|
||||
b_rawpath: Optional[bytes],
|
||||
a_blob_id: Union[str, bytes, None],
|
||||
b_blob_id: Union[str, bytes, None],
|
||||
a_mode: Union[bytes, str, None],
|
||||
b_mode: Union[bytes, str, None],
|
||||
new_file: bool,
|
||||
deleted_file: bool,
|
||||
copied_file: bool,
|
||||
raw_rename_from: Optional[bytes],
|
||||
raw_rename_to: Optional[bytes],
|
||||
diff: Union[str, bytes, None],
|
||||
change_type: Optional[Lit_change_type],
|
||||
score: Optional[int],
|
||||
) -> None:
|
||||
|
||||
assert a_rawpath is None or isinstance(a_rawpath, bytes)
|
||||
assert b_rawpath is None or isinstance(b_rawpath, bytes)
|
||||
self.a_rawpath = a_rawpath
|
||||
self.b_rawpath = b_rawpath
|
||||
|
||||
self.a_mode = mode_str_to_int(a_mode) if a_mode else None
|
||||
self.b_mode = mode_str_to_int(b_mode) if b_mode else None
|
||||
|
||||
# Determine whether this diff references a submodule, if it does then
|
||||
# we need to overwrite "repo" to the corresponding submodule's repo instead
|
||||
if repo and a_rawpath:
|
||||
for submodule in repo.submodules:
|
||||
if submodule.path == a_rawpath.decode(defenc, "replace"):
|
||||
if submodule.module_exists():
|
||||
repo = submodule.module()
|
||||
break
|
||||
|
||||
self.a_blob: Union["IndexObject", None]
|
||||
if a_blob_id is None or a_blob_id == self.NULL_HEX_SHA:
|
||||
self.a_blob = None
|
||||
else:
|
||||
self.a_blob = Blob(repo, hex_to_bin(a_blob_id), mode=self.a_mode, path=self.a_path)
|
||||
|
||||
self.b_blob: Union["IndexObject", None]
|
||||
if b_blob_id is None or b_blob_id == self.NULL_HEX_SHA:
|
||||
self.b_blob = None
|
||||
else:
|
||||
self.b_blob = Blob(repo, hex_to_bin(b_blob_id), mode=self.b_mode, path=self.b_path)
|
||||
|
||||
self.new_file: bool = new_file
|
||||
self.deleted_file: bool = deleted_file
|
||||
self.copied_file: bool = copied_file
|
||||
|
||||
# be clear and use None instead of empty strings
|
||||
assert raw_rename_from is None or isinstance(raw_rename_from, bytes)
|
||||
assert raw_rename_to is None or isinstance(raw_rename_to, bytes)
|
||||
self.raw_rename_from = raw_rename_from or None
|
||||
self.raw_rename_to = raw_rename_to or None
|
||||
|
||||
self.diff = diff
|
||||
self.change_type: Union[Lit_change_type, None] = change_type
|
||||
self.score = score
|
||||
|
||||
def __eq__(self, other: object) -> bool:
|
||||
for name in self.__slots__:
|
||||
if getattr(self, name) != getattr(other, name):
|
||||
return False
|
||||
# END for each name
|
||||
return True
|
||||
|
||||
def __ne__(self, other: object) -> bool:
|
||||
return not (self == other)
|
||||
|
||||
def __hash__(self) -> int:
|
||||
return hash(tuple(getattr(self, n) for n in self.__slots__))
|
||||
|
||||
def __str__(self) -> str:
|
||||
h: str = "%s"
|
||||
if self.a_blob:
|
||||
h %= self.a_blob.path
|
||||
elif self.b_blob:
|
||||
h %= self.b_blob.path
|
||||
|
||||
msg: str = ""
|
||||
line = None # temp line
|
||||
line_length = 0 # line length
|
||||
for b, n in zip((self.a_blob, self.b_blob), ("lhs", "rhs")):
|
||||
if b:
|
||||
line = "\n%s: %o | %s" % (n, b.mode, b.hexsha)
|
||||
else:
|
||||
line = "\n%s: None" % n
|
||||
# END if blob is not None
|
||||
line_length = max(len(line), line_length)
|
||||
msg += line
|
||||
# END for each blob
|
||||
|
||||
# add headline
|
||||
h += "\n" + "=" * line_length
|
||||
|
||||
if self.deleted_file:
|
||||
msg += "\nfile deleted in rhs"
|
||||
if self.new_file:
|
||||
msg += "\nfile added in rhs"
|
||||
if self.copied_file:
|
||||
msg += "\nfile %r copied from %r" % (self.b_path, self.a_path)
|
||||
if self.rename_from:
|
||||
msg += "\nfile renamed from %r" % self.rename_from
|
||||
if self.rename_to:
|
||||
msg += "\nfile renamed to %r" % self.rename_to
|
||||
if self.diff:
|
||||
msg += "\n---"
|
||||
try:
|
||||
msg += self.diff.decode(defenc) if isinstance(self.diff, bytes) else self.diff
|
||||
except UnicodeDecodeError:
|
||||
msg += "OMITTED BINARY DATA"
|
||||
# end handle encoding
|
||||
msg += "\n---"
|
||||
# END diff info
|
||||
|
||||
# Python2 silliness: have to assure we convert our likely to be unicode object to a string with the
|
||||
# right encoding. Otherwise it tries to convert it using ascii, which may fail ungracefully
|
||||
res = h + msg
|
||||
# end
|
||||
return res
|
||||
|
||||
@property
|
||||
def a_path(self) -> Optional[str]:
|
||||
return self.a_rawpath.decode(defenc, "replace") if self.a_rawpath else None
|
||||
|
||||
@property
|
||||
def b_path(self) -> Optional[str]:
|
||||
return self.b_rawpath.decode(defenc, "replace") if self.b_rawpath else None
|
||||
|
||||
@property
|
||||
def rename_from(self) -> Optional[str]:
|
||||
return self.raw_rename_from.decode(defenc, "replace") if self.raw_rename_from else None
|
||||
|
||||
@property
|
||||
def rename_to(self) -> Optional[str]:
|
||||
return self.raw_rename_to.decode(defenc, "replace") if self.raw_rename_to else None
|
||||
|
||||
@property
|
||||
def renamed(self) -> bool:
|
||||
""":returns: True if the blob of our diff has been renamed
|
||||
:note: This property is deprecated, please use ``renamed_file`` instead.
|
||||
"""
|
||||
return self.renamed_file
|
||||
|
||||
@property
|
||||
def renamed_file(self) -> bool:
|
||||
""":returns: True if the blob of our diff has been renamed"""
|
||||
return self.rename_from != self.rename_to
|
||||
|
||||
@classmethod
|
||||
def _pick_best_path(cls, path_match: bytes, rename_match: bytes, path_fallback_match: bytes) -> Optional[bytes]:
|
||||
if path_match:
|
||||
return decode_path(path_match)
|
||||
|
||||
if rename_match:
|
||||
return decode_path(rename_match, has_ab_prefix=False)
|
||||
|
||||
if path_fallback_match:
|
||||
return decode_path(path_fallback_match)
|
||||
|
||||
return None
|
||||
|
||||
@classmethod
|
||||
def _index_from_patch_format(cls, repo: "Repo", proc: Union["Popen", "Git.AutoInterrupt"]) -> DiffIndex:
|
||||
"""Create a new DiffIndex from the given text which must be in patch format
|
||||
:param repo: is the repository we are operating on - it is required
|
||||
:param stream: result of 'git diff' as a stream (supporting file protocol)
|
||||
:return: git.DiffIndex"""
|
||||
|
||||
## FIXME: Here SLURPING raw, need to re-phrase header-regexes linewise.
|
||||
text_list: List[bytes] = []
|
||||
handle_process_output(proc, text_list.append, None, finalize_process, decode_streams=False)
|
||||
|
||||
# for now, we have to bake the stream
|
||||
text = b"".join(text_list)
|
||||
index: "DiffIndex" = DiffIndex()
|
||||
previous_header: Union[Match[bytes], None] = None
|
||||
header: Union[Match[bytes], None] = None
|
||||
a_path, b_path = None, None # for mypy
|
||||
a_mode, b_mode = None, None # for mypy
|
||||
for _header in cls.re_header.finditer(text):
|
||||
(
|
||||
a_path_fallback,
|
||||
b_path_fallback,
|
||||
old_mode,
|
||||
new_mode,
|
||||
rename_from,
|
||||
rename_to,
|
||||
new_file_mode,
|
||||
deleted_file_mode,
|
||||
copied_file_name,
                    a_blob_id,
                    b_blob_id,
                    b_mode,
                    a_path,
                    b_path,
                ) = _header.groups()

                new_file, deleted_file, copied_file = (
                    bool(new_file_mode),
                    bool(deleted_file_mode),
                    bool(copied_file_name),
                )

                a_path = cls._pick_best_path(a_path, rename_from, a_path_fallback)
                b_path = cls._pick_best_path(b_path, rename_to, b_path_fallback)

                # Our only means to find the actual text is to see what has not been matched by our regex,
                # and then retroactively assign it to our index
                if previous_header is not None:
                    index[-1].diff = text[previous_header.end() : _header.start()]
                # end assign actual diff

                # Make sure the mode is set if the path is set. Otherwise the resulting blob is invalid.
                # We just use the one mode we should have parsed.
                a_mode = old_mode or deleted_file_mode or (a_path and (b_mode or new_mode or new_file_mode))
                b_mode = b_mode or new_mode or new_file_mode or (b_path and a_mode)
                index.append(
                    Diff(
                        repo,
                        a_path,
                        b_path,
                        a_blob_id and a_blob_id.decode(defenc),
                        b_blob_id and b_blob_id.decode(defenc),
                        a_mode and a_mode.decode(defenc),
                        b_mode and b_mode.decode(defenc),
                        new_file,
                        deleted_file,
                        copied_file,
                        rename_from,
                        rename_to,
                        None,
                        None,
                        None,
                    )
                )

                previous_header = _header
                header = _header
            # end for each header we parse
            if index and header:
                index[-1].diff = text[header.end() :]
            # end assign last diff

            return index

        @staticmethod
        def _handle_diff_line(lines_bytes: bytes, repo: "Repo", index: DiffIndex) -> None:
            lines = lines_bytes.decode(defenc)

            # Discard everything before the first colon, and the colon itself.
            _, _, lines = lines.partition(":")

            for line in lines.split("\x00:"):
                if not line:
                    # The line data is empty, skip
                    continue
                meta, _, path = line.partition("\x00")
                path = path.rstrip("\x00")
                a_blob_id: Optional[str]
                b_blob_id: Optional[str]
                old_mode, new_mode, a_blob_id, b_blob_id, _change_type = meta.split(None, 4)
                # Change type can be R100
                # R: status letter
                # 100: score (in case of copy and rename)
                # assert is_change_type(_change_type[0]), f"Unexpected value for change_type received: {_change_type[0]}"
                change_type: Lit_change_type = cast(Lit_change_type, _change_type[0])
                score_str = "".join(_change_type[1:])
                score = int(score_str) if score_str.isdigit() else None
                path = path.strip()
                a_path = path.encode(defenc)
                b_path = path.encode(defenc)
                deleted_file = False
                new_file = False
                copied_file = False
                rename_from = None
                rename_to = None

                # NOTE: We cannot conclude the change type from the existence of a blob,
                # as diffs against the working tree do not have blobs yet.
                if change_type == "D":
                    b_blob_id = None  # Optional[str]
                    deleted_file = True
                elif change_type == "A":
                    a_blob_id = None
                    new_file = True
                elif change_type == "C":
                    copied_file = True
                    a_path_str, b_path_str = path.split("\x00", 1)
                    a_path = a_path_str.encode(defenc)
                    b_path = b_path_str.encode(defenc)
                elif change_type == "R":
                    a_path_str, b_path_str = path.split("\x00", 1)
                    a_path = a_path_str.encode(defenc)
                    b_path = b_path_str.encode(defenc)
                    rename_from, rename_to = a_path, b_path
                elif change_type == "T":
                    # Nothing to do
                    pass
                # END add/remove handling

                diff = Diff(
                    repo,
                    a_path,
                    b_path,
                    a_blob_id,
                    b_blob_id,
                    old_mode,
                    new_mode,
                    new_file,
                    deleted_file,
                    copied_file,
                    rename_from,
                    rename_to,
                    "",
                    change_type,
                    score,
                )
                index.append(diff)

        @classmethod
        def _index_from_raw_format(cls, repo: "Repo", proc: "Popen") -> "DiffIndex":
            """Create a new DiffIndex from the given stream which must be in raw format.
            :return: git.DiffIndex"""
            # handles
            # :100644 100644 687099101... 37c5e30c8... M    .gitignore

            index: "DiffIndex" = DiffIndex()
            handle_process_output(
                proc,
                lambda byt: cls._handle_diff_line(byt, repo, index),
                None,
                finalize_process,
                decode_streams=False,
            )

            return index
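A minimal usage sketch (not part of the vendored sources; the repository path is an assumption) showing how the raw-format parsing above surfaces through the public diff API when no patch text is requested:

from git import Repo

repo = Repo(".")                                               # assumed repository
diff_index = repo.head.commit.diff(None, create_patch=False)   # raw format, parsed by _index_from_raw_format
for diff in diff_index:
    # change_type is the status letter (A/C/D/M/R/T), score the rename/copy score
    print(diff.change_type, diff.a_path, "->", diff.b_path, diff.score)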
186
zero-cost-nas/.eggs/GitPython-3.1.31-py3.8.egg/git/exc.py
Normal file
@@ -0,0 +1,186 @@
|
||||
# exc.py
|
||||
# Copyright (C) 2008, 2009 Michael Trier (mtrier@gmail.com) and contributors
|
||||
#
|
||||
# This module is part of GitPython and is released under
|
||||
# the BSD License: http://www.opensource.org/licenses/bsd-license.php
|
||||
""" Module containing all exceptions thrown throughout the git package, """
|
||||
|
||||
from gitdb.exc import BadName # NOQA @UnusedWildImport skipcq: PYL-W0401, PYL-W0614
|
||||
from gitdb.exc import * # NOQA @UnusedWildImport skipcq: PYL-W0401, PYL-W0614
|
||||
from git.compat import safe_decode
|
||||
from git.util import remove_password_if_present
|
||||
|
||||
# typing ----------------------------------------------------
|
||||
|
||||
from typing import List, Sequence, Tuple, Union, TYPE_CHECKING
|
||||
from git.types import PathLike
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from git.repo.base import Repo
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
|
||||
class GitError(Exception):
|
||||
"""Base class for all package exceptions"""
|
||||
|
||||
|
||||
class InvalidGitRepositoryError(GitError):
|
||||
"""Thrown if the given repository appears to have an invalid format."""
|
||||
|
||||
|
||||
class WorkTreeRepositoryUnsupported(InvalidGitRepositoryError):
|
||||
"""Thrown to indicate we can't handle work tree repositories"""
|
||||
|
||||
|
||||
class NoSuchPathError(GitError, OSError):
|
||||
"""Thrown if a path could not be access by the system."""
|
||||
|
||||
|
||||
class UnsafeProtocolError(GitError):
|
||||
"""Thrown if unsafe protocols are passed without being explicitly allowed."""
|
||||
|
||||
|
||||
class UnsafeOptionError(GitError):
|
||||
"""Thrown if unsafe options are passed without being explicitly allowed."""
|
||||
|
||||
|
||||
class CommandError(GitError):
|
||||
"""Base class for exceptions thrown at every stage of `Popen()` execution.
|
||||
|
||||
:param command:
|
||||
A non-empty list of argv comprising the command-line.
|
||||
"""
|
||||
|
||||
#: A unicode print-format with 2 `%s` for `<cmdline>` and the rest,
|
||||
#: e.g.
|
||||
#: "'%s' failed%s"
|
||||
_msg = "Cmd('%s') failed%s"
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
command: Union[List[str], Tuple[str, ...], str],
|
||||
status: Union[str, int, None, Exception] = None,
|
||||
stderr: Union[bytes, str, None] = None,
|
||||
stdout: Union[bytes, str, None] = None,
|
||||
) -> None:
|
||||
if not isinstance(command, (tuple, list)):
|
||||
command = command.split()
|
||||
self.command = remove_password_if_present(command)
|
||||
self.status = status
|
||||
if status:
|
||||
if isinstance(status, Exception):
|
||||
status = "%s('%s')" % (type(status).__name__, safe_decode(str(status)))
|
||||
else:
|
||||
try:
|
||||
status = "exit code(%s)" % int(status)
|
||||
except (ValueError, TypeError):
|
||||
s = safe_decode(str(status))
|
||||
status = "'%s'" % s if isinstance(status, str) else s
|
||||
|
||||
self._cmd = safe_decode(self.command[0])
|
||||
self._cmdline = " ".join(safe_decode(i) for i in self.command)
|
||||
self._cause = status and " due to: %s" % status or "!"
|
||||
stdout_decode = safe_decode(stdout)
|
||||
stderr_decode = safe_decode(stderr)
|
||||
self.stdout = stdout_decode and "\n stdout: '%s'" % stdout_decode or ""
|
||||
self.stderr = stderr_decode and "\n stderr: '%s'" % stderr_decode or ""
|
||||
|
||||
def __str__(self) -> str:
|
||||
return (self._msg + "\n cmdline: %s%s%s") % (
|
||||
self._cmd,
|
||||
self._cause,
|
||||
self._cmdline,
|
||||
self.stdout,
|
||||
self.stderr,
|
||||
)
|
||||
|
||||
|
||||
class GitCommandNotFound(CommandError):
|
||||
"""Thrown if we cannot find the `git` executable in the PATH or at the path given by
|
||||
the GIT_PYTHON_GIT_EXECUTABLE environment variable"""
|
||||
|
||||
def __init__(self, command: Union[List[str], Tuple[str], str], cause: Union[str, Exception]) -> None:
|
||||
super(GitCommandNotFound, self).__init__(command, cause)
|
||||
self._msg = "Cmd('%s') not found%s"
|
||||
|
||||
|
||||
class GitCommandError(CommandError):
|
||||
"""Thrown if execution of the git command fails with non-zero status code."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
command: Union[List[str], Tuple[str, ...], str],
|
||||
status: Union[str, int, None, Exception] = None,
|
||||
stderr: Union[bytes, str, None] = None,
|
||||
stdout: Union[bytes, str, None] = None,
|
||||
) -> None:
|
||||
super(GitCommandError, self).__init__(command, status, stderr, stdout)
|
||||
|
||||
|
||||
class CheckoutError(GitError):
|
||||
"""Thrown if a file could not be checked out from the index as it contained
|
||||
changes.
|
||||
|
||||
The .failed_files attribute contains a list of relative paths that failed
|
||||
to be checked out as they contained changes that did not exist in the index.
|
||||
|
||||
The .failed_reasons attribute contains a list of strings informing about the actual
|
||||
cause of the issue.
|
||||
|
||||
The .valid_files attribute contains a list of relative paths to files that
|
||||
were checked out successfully and hence match the version stored in the
|
||||
index"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
message: str,
|
||||
failed_files: Sequence[PathLike],
|
||||
valid_files: Sequence[PathLike],
|
||||
failed_reasons: List[str],
|
||||
) -> None:
|
||||
|
||||
Exception.__init__(self, message)
|
||||
self.failed_files = failed_files
|
||||
self.failed_reasons = failed_reasons
|
||||
self.valid_files = valid_files
|
||||
|
||||
def __str__(self) -> str:
|
||||
return Exception.__str__(self) + ":%s" % self.failed_files
|
||||
|
||||
|
||||
class CacheError(GitError):
|
||||
|
||||
"""Base for all errors related to the git index, which is called cache internally"""
|
||||
|
||||
|
||||
class UnmergedEntriesError(CacheError):
|
||||
"""Thrown if an operation cannot proceed as there are still unmerged
|
||||
entries in the cache"""
|
||||
|
||||
|
||||
class HookExecutionError(CommandError):
|
||||
"""Thrown if a hook exits with a non-zero exit code. It provides access to the exit code and the string returned
|
||||
via standard output"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
command: Union[List[str], Tuple[str, ...], str],
|
||||
status: Union[str, int, None, Exception],
|
||||
stderr: Union[bytes, str, None] = None,
|
||||
stdout: Union[bytes, str, None] = None,
|
||||
) -> None:
|
||||
|
||||
super(HookExecutionError, self).__init__(command, status, stderr, stdout)
|
||||
self._msg = "Hook('%s') failed%s"
|
||||
|
||||
|
||||
class RepositoryDirtyError(GitError):
|
||||
"""Thrown whenever an operation on a repository fails as it has uncommitted changes that would be overwritten"""
|
||||
|
||||
def __init__(self, repo: "Repo", message: str) -> None:
|
||||
self.repo = repo
|
||||
self.message = message
|
||||
|
||||
def __str__(self) -> str:
|
||||
return "Operation cannot be performed on %r: %s" % (self.repo, self.message)
|
||||
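A minimal usage sketch (not part of the vendored sources; repository path and branch name are assumptions) of catching GitCommandError and reading the fields populated by CommandError.__init__ above:

from git import Repo, GitCommandError

repo = Repo(".")                              # assumed repository
try:
    repo.git.checkout("no-such-branch")       # hypothetical branch name
except GitCommandError as err:
    print(err.command)                        # sanitized argv list (passwords removed)
    print(err.status)                         # exit status as passed in, usually an int
    print(err.stderr)                         # formatted stderr output, if any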
@@ -0,0 +1,4 @@
"""Initialize the index package"""
# flake8: noqa
from .base import *
from .typ import *
1401
zero-cost-nas/.eggs/GitPython-3.1.31-py3.8.egg/git/index/base.py
Normal file
File diff suppressed because it is too large
444
zero-cost-nas/.eggs/GitPython-3.1.31-py3.8.egg/git/index/fun.py
Normal file
@@ -0,0 +1,444 @@
|
||||
# Contains standalone functions to accompany the index implementation and make it
|
||||
# more versatile
|
||||
# NOTE: Autodoc hates it if this is a docstring
|
||||
|
||||
from io import BytesIO
|
||||
from pathlib import Path
|
||||
import os
|
||||
from stat import (
|
||||
S_IFDIR,
|
||||
S_IFLNK,
|
||||
S_ISLNK,
|
||||
S_ISDIR,
|
||||
S_IFMT,
|
||||
S_IFREG,
|
||||
S_IXUSR,
|
||||
)
|
||||
import subprocess
|
||||
|
||||
from git.cmd import PROC_CREATIONFLAGS, handle_process_output
|
||||
from git.compat import (
|
||||
defenc,
|
||||
force_text,
|
||||
force_bytes,
|
||||
is_posix,
|
||||
is_win,
|
||||
safe_decode,
|
||||
)
|
||||
from git.exc import UnmergedEntriesError, HookExecutionError
|
||||
from git.objects.fun import (
|
||||
tree_to_stream,
|
||||
traverse_tree_recursive,
|
||||
traverse_trees_recursive,
|
||||
)
|
||||
from git.util import IndexFileSHA1Writer, finalize_process
|
||||
from gitdb.base import IStream
|
||||
from gitdb.typ import str_tree_type
|
||||
|
||||
import os.path as osp
|
||||
|
||||
from .typ import BaseIndexEntry, IndexEntry, CE_NAMEMASK, CE_STAGESHIFT
|
||||
from .util import pack, unpack
|
||||
|
||||
# typing -----------------------------------------------------------------------------
|
||||
|
||||
from typing import Dict, IO, List, Sequence, TYPE_CHECKING, Tuple, Type, Union, cast
|
||||
|
||||
from git.types import PathLike
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from .base import IndexFile
|
||||
from git.db import GitCmdObjectDB
|
||||
from git.objects.tree import TreeCacheTup
|
||||
|
||||
# from git.objects.fun import EntryTupOrNone
|
||||
|
||||
# ------------------------------------------------------------------------------------
|
||||
|
||||
|
||||
S_IFGITLINK = S_IFLNK | S_IFDIR # a submodule
|
||||
CE_NAMEMASK_INV = ~CE_NAMEMASK
|
||||
|
||||
__all__ = (
|
||||
"write_cache",
|
||||
"read_cache",
|
||||
"write_tree_from_cache",
|
||||
"entry_key",
|
||||
"stat_mode_to_index_mode",
|
||||
"S_IFGITLINK",
|
||||
"run_commit_hook",
|
||||
"hook_path",
|
||||
)
|
||||
|
||||
|
||||
def hook_path(name: str, git_dir: PathLike) -> str:
|
||||
""":return: path to the given named hook in the given git repository directory"""
|
||||
return osp.join(git_dir, "hooks", name)
|
||||
|
||||
|
||||
def _has_file_extension(path):
|
||||
return osp.splitext(path)[1]
|
||||
|
||||
|
||||
def run_commit_hook(name: str, index: "IndexFile", *args: str) -> None:
|
||||
"""Run the commit hook of the given name. Silently ignores hooks that do not exist.
|
||||
|
||||
:param name: name of hook, like 'pre-commit'
|
||||
:param index: IndexFile instance
|
||||
:param args: arguments passed to hook file
|
||||
:raises HookExecutionError:"""
|
||||
hp = hook_path(name, index.repo.git_dir)
|
||||
if not os.access(hp, os.X_OK):
|
||||
return None
|
||||
|
||||
env = os.environ.copy()
|
||||
env["GIT_INDEX_FILE"] = safe_decode(str(index.path))
|
||||
env["GIT_EDITOR"] = ":"
|
||||
cmd = [hp]
|
||||
try:
|
||||
if is_win and not _has_file_extension(hp):
|
||||
# Windows only uses extensions to determine how to open files
|
||||
# (doesn't understand shebangs). Try using bash to run the hook.
|
||||
relative_hp = Path(hp).relative_to(index.repo.working_dir).as_posix()
|
||||
cmd = ["bash.exe", relative_hp]
|
||||
|
||||
cmd = subprocess.Popen(
|
||||
cmd + list(args),
|
||||
env=env,
|
||||
stdout=subprocess.PIPE,
|
||||
stderr=subprocess.PIPE,
|
||||
cwd=index.repo.working_dir,
|
||||
close_fds=is_posix,
|
||||
creationflags=PROC_CREATIONFLAGS,
|
||||
)
|
||||
except Exception as ex:
|
||||
raise HookExecutionError(hp, ex) from ex
|
||||
else:
|
||||
stdout_list: List[str] = []
|
||||
stderr_list: List[str] = []
|
||||
handle_process_output(cmd, stdout_list.append, stderr_list.append, finalize_process)
|
||||
stdout = "".join(stdout_list)
|
||||
stderr = "".join(stderr_list)
|
||||
if cmd.returncode != 0:
|
||||
stdout = force_text(stdout, defenc)
|
||||
stderr = force_text(stderr, defenc)
|
||||
raise HookExecutionError(hp, cmd.returncode, stderr, stdout)
|
||||
# end handle return code
|
||||
|
||||
|
||||
def stat_mode_to_index_mode(mode: int) -> int:
|
||||
"""Convert the given mode from a stat call to the corresponding index mode
|
||||
and return it"""
|
||||
if S_ISLNK(mode): # symlinks
|
||||
return S_IFLNK
|
||||
if S_ISDIR(mode) or S_IFMT(mode) == S_IFGITLINK: # submodules
|
||||
return S_IFGITLINK
|
||||
return S_IFREG | (mode & S_IXUSR and 0o755 or 0o644) # blobs with or without executable bit
|
||||
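A minimal usage sketch (not part of the vendored sources; the file name is hypothetical) showing how stat_mode_to_index_mode collapses filesystem modes to the three modes git records for index entries:

import os
from git.index.fun import stat_mode_to_index_mode

st = os.lstat("README.md")                          # hypothetical file
print(oct(stat_mode_to_index_mode(st.st_mode)))     # 0o100644, 0o100755 (executable) or 0o120000 (symlink)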
|
||||
|
||||
def write_cache(
|
||||
entries: Sequence[Union[BaseIndexEntry, "IndexEntry"]],
|
||||
stream: IO[bytes],
|
||||
extension_data: Union[None, bytes] = None,
|
||||
ShaStreamCls: Type[IndexFileSHA1Writer] = IndexFileSHA1Writer,
|
||||
) -> None:
|
||||
"""Write the cache represented by entries to a stream
|
||||
|
||||
:param entries: **sorted** list of entries
|
||||
:param stream: stream to wrap into the AdapterStreamCls - it is used for
|
||||
final output.
|
||||
|
||||
:param ShaStreamCls: Type to use when writing to the stream. It produces a sha
|
||||
while writing to it, before the data is passed on to the wrapped stream
|
||||
|
||||
:param extension_data: any kind of data to write as a trailer, it must begin
|
||||
a 4 byte identifier, followed by its size ( 4 bytes )"""
|
||||
# wrap the stream into a compatible writer
|
||||
stream_sha = ShaStreamCls(stream)
|
||||
|
||||
tell = stream_sha.tell
|
||||
write = stream_sha.write
|
||||
|
||||
# header
|
||||
version = 2
|
||||
write(b"DIRC")
|
||||
write(pack(">LL", version, len(entries)))
|
||||
|
||||
# body
|
||||
for entry in entries:
|
||||
beginoffset = tell()
|
||||
write(entry.ctime_bytes) # ctime
|
||||
write(entry.mtime_bytes) # mtime
|
||||
path_str = str(entry.path)
|
||||
path: bytes = force_bytes(path_str, encoding=defenc)
|
||||
plen = len(path) & CE_NAMEMASK # path length
|
||||
assert plen == len(path), "Path %s too long to fit into index" % entry.path
|
||||
flags = plen | (entry.flags & CE_NAMEMASK_INV) # clear possible previous values
|
||||
write(
|
||||
pack(
|
||||
">LLLLLL20sH",
|
||||
entry.dev,
|
||||
entry.inode,
|
||||
entry.mode,
|
||||
entry.uid,
|
||||
entry.gid,
|
||||
entry.size,
|
||||
entry.binsha,
|
||||
flags,
|
||||
)
|
||||
)
|
||||
write(path)
|
||||
real_size = (tell() - beginoffset + 8) & ~7
|
||||
write(b"\0" * ((beginoffset + real_size) - tell()))
|
||||
# END for each entry
|
||||
|
||||
# write previously cached extensions data
|
||||
if extension_data is not None:
|
||||
stream_sha.write(extension_data)
|
||||
|
||||
# write the sha over the content
|
||||
stream_sha.write_sha()
|
||||
|
||||
|
||||
def read_header(stream: IO[bytes]) -> Tuple[int, int]:
|
||||
"""Return tuple(version_long, num_entries) from the given stream"""
|
||||
type_id = stream.read(4)
|
||||
if type_id != b"DIRC":
|
||||
raise AssertionError("Invalid index file header: %r" % type_id)
|
||||
unpacked = cast(Tuple[int, int], unpack(">LL", stream.read(4 * 2)))
|
||||
version, num_entries = unpacked
|
||||
|
||||
# TODO: handle version 3: extended data, see read-cache.c
|
||||
assert version in (1, 2)
|
||||
return version, num_entries
|
||||
|
||||
|
||||
def entry_key(*entry: Union[BaseIndexEntry, PathLike, int]) -> Tuple[PathLike, int]:
|
||||
""":return: Key suitable to be used for the index.entries dictionary
|
||||
:param entry: One instance of type BaseIndexEntry or the path and the stage"""
|
||||
|
||||
# def is_entry_key_tup(entry_key: Tuple) -> TypeGuard[Tuple[PathLike, int]]:
|
||||
# return isinstance(entry_key, tuple) and len(entry_key) == 2
|
||||
|
||||
if len(entry) == 1:
|
||||
entry_first = entry[0]
|
||||
assert isinstance(entry_first, BaseIndexEntry)
|
||||
return (entry_first.path, entry_first.stage)
|
||||
else:
|
||||
# assert is_entry_key_tup(entry)
|
||||
entry = cast(Tuple[PathLike, int], entry)
|
||||
return entry
|
||||
# END handle entry
|
||||
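A minimal usage sketch (not part of the vendored sources; the path and entry are illustrative) of the two supported call forms of entry_key, both yielding the (path, stage) key used by IndexFile.entries:

from git.index.fun import entry_key
from git.index.typ import BaseIndexEntry

key_a = entry_key("dir/file.txt", 0)                               # path and stage given explicitly
entry = BaseIndexEntry((0o100644, b"\0" * 20, 0, "dir/file.txt"))  # illustrative entry at stage 0
key_b = entry_key(entry)                                           # single BaseIndexEntry argument
assert key_a == key_b == ("dir/file.txt", 0)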
|
||||
|
||||
def read_cache(
|
||||
stream: IO[bytes],
|
||||
) -> Tuple[int, Dict[Tuple[PathLike, int], "IndexEntry"], bytes, bytes]:
|
||||
"""Read a cache file from the given stream
|
||||
|
||||
:return: tuple(version, entries_dict, extension_data, content_sha)
|
||||
|
||||
* version is the integer version number
|
||||
* entries dict is a dictionary which maps IndexEntry instances to a path at a stage
|
||||
* extension_data is '' or 4 bytes of type + 4 bytes of size + size bytes
|
||||
* content_sha is a 20 byte sha on all cache file contents"""
|
||||
version, num_entries = read_header(stream)
|
||||
count = 0
|
||||
entries: Dict[Tuple[PathLike, int], "IndexEntry"] = {}
|
||||
|
||||
read = stream.read
|
||||
tell = stream.tell
|
||||
while count < num_entries:
|
||||
beginoffset = tell()
|
||||
ctime = unpack(">8s", read(8))[0]
|
||||
mtime = unpack(">8s", read(8))[0]
|
||||
(dev, ino, mode, uid, gid, size, sha, flags) = unpack(">LLLLLL20sH", read(20 + 4 * 6 + 2))
|
||||
path_size = flags & CE_NAMEMASK
|
||||
path = read(path_size).decode(defenc)
|
||||
|
||||
real_size = (tell() - beginoffset + 8) & ~7
|
||||
read((beginoffset + real_size) - tell())
|
||||
entry = IndexEntry((mode, sha, flags, path, ctime, mtime, dev, ino, uid, gid, size))
|
||||
# entry_key would be the method to use, but we save the effort
|
||||
entries[(path, entry.stage)] = entry
|
||||
count += 1
|
||||
# END for each entry
|
||||
|
||||
# the footer contains extension data and a sha on the content so far
|
||||
# Keep the extension footer, and verify we have a sha at the end
|
||||
# Extension data format is:
|
||||
# 4 bytes ID
|
||||
# 4 bytes length of chunk
|
||||
# repeated 0 - N times
|
||||
extension_data = stream.read(~0)
|
||||
assert (
|
||||
len(extension_data) > 19
|
||||
), "Index Footer was not at least a sha on content as it was only %i bytes in size" % len(extension_data)
|
||||
|
||||
content_sha = extension_data[-20:]
|
||||
|
||||
# truncate the sha in the end as we will dynamically create it anyway
|
||||
extension_data = extension_data[:-20]
|
||||
|
||||
return (version, entries, extension_data, content_sha)
|
||||
|
||||
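A minimal usage sketch (not part of the vendored sources; assumes the current directory is a repository whose index uses version 1 or 2) of parsing an index file on disk with read_cache:

from git.index.fun import read_cache

with open(".git/index", "rb") as fp:                 # assumed repository index
    version, entries, extension_data, sha = read_cache(fp)
print(version, len(entries))
for (path, stage), entry in list(entries.items())[:3]:
    print(stage, entry.hexsha, path)                 # entries are keyed by (path, stage)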
|
||||
def write_tree_from_cache(
|
||||
entries: List[IndexEntry], odb: "GitCmdObjectDB", sl: slice, si: int = 0
|
||||
) -> Tuple[bytes, List["TreeCacheTup"]]:
|
||||
"""Create a tree from the given sorted list of entries and put the respective
|
||||
trees into the given object database
|
||||
|
||||
:param entries: **sorted** list of IndexEntries
|
||||
:param odb: object database to store the trees in
|
||||
:param si: start index at which we should start creating subtrees
|
||||
:param sl: slice indicating the range we should process on the entries list
|
||||
:return: tuple(binsha, list(tree_entry, ...)) a tuple of a sha and a list of
|
||||
tree entries being a tuple of hexsha, mode, name"""
|
||||
tree_items: List["TreeCacheTup"] = []
|
||||
|
||||
ci = sl.start
|
||||
end = sl.stop
|
||||
while ci < end:
|
||||
entry = entries[ci]
|
||||
if entry.stage != 0:
|
||||
raise UnmergedEntriesError(entry)
|
||||
# END abort on unmerged
|
||||
ci += 1
|
||||
rbound = entry.path.find("/", si)
|
||||
if rbound == -1:
|
||||
# its not a tree
|
||||
tree_items.append((entry.binsha, entry.mode, entry.path[si:]))
|
||||
else:
|
||||
# find common base range
|
||||
base = entry.path[si:rbound]
|
||||
xi = ci
|
||||
while xi < end:
|
||||
oentry = entries[xi]
|
||||
orbound = oentry.path.find("/", si)
|
||||
if orbound == -1 or oentry.path[si:orbound] != base:
|
||||
break
|
||||
# END abort on base mismatch
|
||||
xi += 1
|
||||
# END find common base
|
||||
|
||||
# enter recursion
|
||||
# ci - 1 as we want to count our current item as well
|
||||
sha, _tree_entry_list = write_tree_from_cache(entries, odb, slice(ci - 1, xi), rbound + 1)
|
||||
tree_items.append((sha, S_IFDIR, base))
|
||||
|
||||
# skip ahead
|
||||
ci = xi
|
||||
# END handle bounds
|
||||
# END for each entry
|
||||
|
||||
# finally create the tree
|
||||
sio = BytesIO()
|
||||
tree_to_stream(tree_items, sio.write) # writes to stream as bytes, but doesn't change tree_items
|
||||
sio.seek(0)
|
||||
|
||||
istream = odb.store(IStream(str_tree_type, len(sio.getvalue()), sio))
|
||||
return (istream.binsha, tree_items)
|
||||
|
||||
|
||||
def _tree_entry_to_baseindexentry(tree_entry: "TreeCacheTup", stage: int) -> BaseIndexEntry:
|
||||
return BaseIndexEntry((tree_entry[1], tree_entry[0], stage << CE_STAGESHIFT, tree_entry[2]))
|
||||
|
||||
|
||||
def aggressive_tree_merge(odb: "GitCmdObjectDB", tree_shas: Sequence[bytes]) -> List[BaseIndexEntry]:
|
||||
"""
|
||||
:return: list of BaseIndexEntries representing the aggressive merge of the given
|
||||
trees. All valid entries are on stage 0, whereas the conflicting ones are left
|
||||
on stage 1, 2 or 3, whereas stage 1 corresponds to the common ancestor tree,
|
||||
2 to our tree and 3 to 'their' tree.
|
||||
:param tree_shas: 1, 2 or 3 trees as identified by their binary 20 byte shas
|
||||
If 1 or two, the entries will effectively correspond to the last given tree
|
||||
If 3 are given, a 3 way merge is performed"""
|
||||
out: List[BaseIndexEntry] = []
|
||||
|
||||
# one and two way is the same for us, as we don't have to handle an existing
|
||||
# index, instead
|
||||
if len(tree_shas) in (1, 2):
|
||||
for entry in traverse_tree_recursive(odb, tree_shas[-1], ""):
|
||||
out.append(_tree_entry_to_baseindexentry(entry, 0))
|
||||
# END for each entry
|
||||
return out
|
||||
# END handle single tree
|
||||
|
||||
if len(tree_shas) > 3:
|
||||
raise ValueError("Cannot handle %i trees at once" % len(tree_shas))
|
||||
|
||||
# three trees
|
||||
for base, ours, theirs in traverse_trees_recursive(odb, tree_shas, ""):
|
||||
if base is not None:
|
||||
# base version exists
|
||||
if ours is not None:
|
||||
# ours exists
|
||||
if theirs is not None:
|
||||
# it exists in all branches, if it was changed in both
|
||||
# its a conflict, otherwise we take the changed version
|
||||
# This should be the most common branch, so it comes first
|
||||
if (base[0] != ours[0] and base[0] != theirs[0] and ours[0] != theirs[0]) or (
|
||||
base[1] != ours[1] and base[1] != theirs[1] and ours[1] != theirs[1]
|
||||
):
|
||||
# changed by both
|
||||
out.append(_tree_entry_to_baseindexentry(base, 1))
|
||||
out.append(_tree_entry_to_baseindexentry(ours, 2))
|
||||
out.append(_tree_entry_to_baseindexentry(theirs, 3))
|
||||
elif base[0] != ours[0] or base[1] != ours[1]:
|
||||
# only we changed it
|
||||
out.append(_tree_entry_to_baseindexentry(ours, 0))
|
||||
else:
|
||||
# either nobody changed it, or they did. In either
|
||||
# case, use theirs
|
||||
out.append(_tree_entry_to_baseindexentry(theirs, 0))
|
||||
# END handle modification
|
||||
else:
|
||||
|
||||
if ours[0] != base[0] or ours[1] != base[1]:
|
||||
# they deleted it, we changed it, conflict
|
||||
out.append(_tree_entry_to_baseindexentry(base, 1))
|
||||
out.append(_tree_entry_to_baseindexentry(ours, 2))
|
||||
# else:
|
||||
# we didn't change it, ignore
|
||||
# pass
|
||||
# END handle our change
|
||||
# END handle theirs
|
||||
else:
|
||||
if theirs is None:
|
||||
# deleted in both, it's fine - it's out
|
||||
pass
|
||||
else:
|
||||
if theirs[0] != base[0] or theirs[1] != base[1]:
|
||||
# deleted in ours, changed theirs, conflict
|
||||
out.append(_tree_entry_to_baseindexentry(base, 1))
|
||||
out.append(_tree_entry_to_baseindexentry(theirs, 3))
|
||||
# END theirs changed
|
||||
# else:
|
||||
# theirs didn't change
|
||||
# pass
|
||||
# END handle theirs
|
||||
# END handle ours
|
||||
else:
|
||||
# all three can't be None
|
||||
if ours is None:
|
||||
# added in their branch
|
||||
assert theirs is not None
|
||||
out.append(_tree_entry_to_baseindexentry(theirs, 0))
|
||||
elif theirs is None:
|
||||
# added in our branch
|
||||
out.append(_tree_entry_to_baseindexentry(ours, 0))
|
||||
else:
|
||||
# both have it, except for the base, see whether it changed
|
||||
if ours[0] != theirs[0] or ours[1] != theirs[1]:
|
||||
out.append(_tree_entry_to_baseindexentry(ours, 2))
|
||||
out.append(_tree_entry_to_baseindexentry(theirs, 3))
|
||||
else:
|
||||
# it was added the same in both
|
||||
out.append(_tree_entry_to_baseindexentry(ours, 0))
|
||||
# END handle two items
|
||||
# END handle heads
|
||||
# END handle base exists
|
||||
# END for each entries tuple
|
||||
|
||||
return out
|
||||
191
zero-cost-nas/.eggs/GitPython-3.1.31-py3.8.egg/git/index/typ.py
Normal file
@@ -0,0 +1,191 @@
|
||||
"""Module with additional types used by the index"""
|
||||
|
||||
from binascii import b2a_hex
|
||||
from pathlib import Path
|
||||
|
||||
from .util import pack, unpack
|
||||
from git.objects import Blob
|
||||
|
||||
|
||||
# typing ----------------------------------------------------------------------
|
||||
|
||||
from typing import NamedTuple, Sequence, TYPE_CHECKING, Tuple, Union, cast, List
|
||||
|
||||
from git.types import PathLike
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from git.repo import Repo
|
||||
|
||||
StageType = int
|
||||
|
||||
# ---------------------------------------------------------------------------------
|
||||
|
||||
__all__ = ("BlobFilter", "BaseIndexEntry", "IndexEntry", "StageType")
|
||||
|
||||
# { Invariants
|
||||
CE_NAMEMASK = 0x0FFF
|
||||
CE_STAGEMASK = 0x3000
|
||||
CE_EXTENDED = 0x4000
|
||||
CE_VALID = 0x8000
|
||||
CE_STAGESHIFT = 12
|
||||
|
||||
# } END invariants
|
||||
|
||||
|
||||
class BlobFilter(object):
|
||||
|
||||
"""
|
||||
Predicate to be used by iter_blobs, allowing to return only blobs which
match the given list of directories or files.
|
||||
|
||||
The given paths are given relative to the repository.
|
||||
"""
|
||||
|
||||
__slots__ = "paths"
|
||||
|
||||
def __init__(self, paths: Sequence[PathLike]) -> None:
|
||||
"""
|
||||
:param paths:
|
||||
tuple or list of paths which are either pointing to directories or
|
||||
to files relative to the current repository
|
||||
"""
|
||||
self.paths = paths
|
||||
|
||||
def __call__(self, stage_blob: Tuple[StageType, Blob]) -> bool:
|
||||
blob_pathlike: PathLike = stage_blob[1].path
|
||||
blob_path: Path = blob_pathlike if isinstance(blob_pathlike, Path) else Path(blob_pathlike)
|
||||
for pathlike in self.paths:
|
||||
path: Path = pathlike if isinstance(pathlike, Path) else Path(pathlike)
|
||||
# TODO: Change to use `PosixPath.is_relative_to` once Python 3.8 is no longer supported.
|
||||
filter_parts: List[str] = path.parts
|
||||
blob_parts: List[str] = blob_path.parts
|
||||
if len(filter_parts) > len(blob_parts):
|
||||
continue
|
||||
if all(i == j for i, j in zip(filter_parts, blob_parts)):
|
||||
return True
|
||||
return False
|
||||
|
||||
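A minimal usage sketch (not part of the vendored sources; repository and filter paths are assumptions) of restricting IndexFile.iter_blobs with BlobFilter:

from git import Repo
from git.index.typ import BlobFilter

repo = Repo(".")                                 # assumed repository
keep = BlobFilter(["src", "docs/conf.py"])       # hypothetical paths, relative to the repo root
for stage, blob in repo.index.iter_blobs(keep):
    print(stage, blob.path)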
|
||||
class BaseIndexEntryHelper(NamedTuple):
|
||||
"""Typed namedtuple to provide named attribute access for BaseIndexEntry.
|
||||
Needed to allow overriding __new__ in child class to preserve backwards compat."""
|
||||
|
||||
mode: int
|
||||
binsha: bytes
|
||||
flags: int
|
||||
path: PathLike
|
||||
ctime_bytes: bytes = pack(">LL", 0, 0)
|
||||
mtime_bytes: bytes = pack(">LL", 0, 0)
|
||||
dev: int = 0
|
||||
inode: int = 0
|
||||
uid: int = 0
|
||||
gid: int = 0
|
||||
size: int = 0
|
||||
|
||||
|
||||
class BaseIndexEntry(BaseIndexEntryHelper):
|
||||
|
||||
"""Small Brother of an index entry which can be created to describe changes
|
||||
done to the index in which case plenty of additional information is not required.
|
||||
|
||||
As the first 4 data members match exactly to the IndexEntry type, methods
|
||||
expecting a BaseIndexEntry can also handle full IndexEntries even if they
|
||||
use numeric indices for performance reasons.
|
||||
"""
|
||||
|
||||
def __new__(
|
||||
cls,
|
||||
inp_tuple: Union[
|
||||
Tuple[int, bytes, int, PathLike],
|
||||
Tuple[int, bytes, int, PathLike, bytes, bytes, int, int, int, int, int],
|
||||
],
|
||||
) -> "BaseIndexEntry":
|
||||
"""Override __new__ to allow construction from a tuple for backwards compatibility"""
|
||||
return super().__new__(cls, *inp_tuple)
|
||||
|
||||
def __str__(self) -> str:
|
||||
return "%o %s %i\t%s" % (self.mode, self.hexsha, self.stage, self.path)
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return "(%o, %s, %i, %s)" % (self.mode, self.hexsha, self.stage, self.path)
|
||||
|
||||
@property
|
||||
def hexsha(self) -> str:
|
||||
"""hex version of our sha"""
|
||||
return b2a_hex(self.binsha).decode("ascii")
|
||||
|
||||
@property
|
||||
def stage(self) -> int:
|
||||
"""Stage of the entry, either:
|
||||
|
||||
* 0 = default stage
|
||||
* 1 = stage before a merge or common ancestor entry in case of a 3 way merge
|
||||
* 2 = stage of entries from the 'left' side of the merge
|
||||
* 3 = stage of entries from the right side of the merge
|
||||
|
||||
:note: For more information, see http://www.kernel.org/pub/software/scm/git/docs/git-read-tree.html
|
||||
"""
|
||||
return (self.flags & CE_STAGEMASK) >> CE_STAGESHIFT
|
||||
|
||||
@classmethod
|
||||
def from_blob(cls, blob: Blob, stage: int = 0) -> "BaseIndexEntry":
|
||||
""":return: Fully equipped BaseIndexEntry at the given stage"""
|
||||
return cls((blob.mode, blob.binsha, stage << CE_STAGESHIFT, blob.path))
|
||||
|
||||
def to_blob(self, repo: "Repo") -> Blob:
|
||||
""":return: Blob using the information of this index entry"""
|
||||
return Blob(repo, self.binsha, self.mode, self.path)
|
||||
|
||||
|
||||
class IndexEntry(BaseIndexEntry):
|
||||
|
||||
"""Allows convenient access to IndexEntry data without completely unpacking it.
|
||||
|
||||
Attributes usually accessed often are cached in the tuple, whereas others are
|
||||
unpacked on demand.
|
||||
|
||||
See the properties for a mapping between names and tuple indices."""
|
||||
|
||||
@property
|
||||
def ctime(self) -> Tuple[int, int]:
|
||||
"""
|
||||
:return:
|
||||
Tuple(int_time_seconds_since_epoch, int_nano_seconds) of the
|
||||
file's creation time"""
|
||||
return cast(Tuple[int, int], unpack(">LL", self.ctime_bytes))
|
||||
|
||||
@property
|
||||
def mtime(self) -> Tuple[int, int]:
|
||||
"""See ctime property, but returns modification time"""
|
||||
return cast(Tuple[int, int], unpack(">LL", self.mtime_bytes))
|
||||
|
||||
@classmethod
|
||||
def from_base(cls, base: "BaseIndexEntry") -> "IndexEntry":
|
||||
"""
|
||||
:return:
|
||||
Minimal entry as created from the given BaseIndexEntry instance.
|
||||
Missing values will be set to null-like values
|
||||
|
||||
:param base: Instance of type BaseIndexEntry"""
|
||||
time = pack(">LL", 0, 0)
|
||||
return IndexEntry((base.mode, base.binsha, base.flags, base.path, time, time, 0, 0, 0, 0, 0))
|
||||
|
||||
@classmethod
|
||||
def from_blob(cls, blob: Blob, stage: int = 0) -> "IndexEntry":
|
||||
""":return: Minimal entry resembling the given blob object"""
|
||||
time = pack(">LL", 0, 0)
|
||||
return IndexEntry(
|
||||
(
|
||||
blob.mode,
|
||||
blob.binsha,
|
||||
stage << CE_STAGESHIFT,
|
||||
blob.path,
|
||||
time,
|
||||
time,
|
||||
0,
|
||||
0,
|
||||
0,
|
||||
0,
|
||||
blob.size,
|
||||
)
|
||||
)
|
||||
119
zero-cost-nas/.eggs/GitPython-3.1.31-py3.8.egg/git/index/util.py
Normal file
@@ -0,0 +1,119 @@
|
||||
"""Module containing index utilities"""
|
||||
from functools import wraps
|
||||
import os
|
||||
import struct
|
||||
import tempfile
|
||||
|
||||
from git.compat import is_win
|
||||
|
||||
import os.path as osp
|
||||
|
||||
|
||||
# typing ----------------------------------------------------------------------
|
||||
|
||||
from typing import Any, Callable, TYPE_CHECKING
|
||||
|
||||
from git.types import PathLike, _T
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from git.index import IndexFile
|
||||
|
||||
# ---------------------------------------------------------------------------------
|
||||
|
||||
|
||||
__all__ = ("TemporaryFileSwap", "post_clear_cache", "default_index", "git_working_dir")
|
||||
|
||||
# { Aliases
|
||||
pack = struct.pack
|
||||
unpack = struct.unpack
|
||||
|
||||
|
||||
# } END aliases
|
||||
|
||||
|
||||
class TemporaryFileSwap(object):
|
||||
|
||||
"""Utility class moving a file to a temporary location within the same directory
|
||||
and moving it back on to where on object deletion."""
|
||||
|
||||
__slots__ = ("file_path", "tmp_file_path")
|
||||
|
||||
def __init__(self, file_path: PathLike) -> None:
|
||||
self.file_path = file_path
|
||||
self.tmp_file_path = str(self.file_path) + tempfile.mktemp("", "", "")
|
||||
# it may be that the source does not exist
|
||||
try:
|
||||
os.rename(self.file_path, self.tmp_file_path)
|
||||
except OSError:
|
||||
pass
|
||||
|
||||
def __del__(self) -> None:
|
||||
if osp.isfile(self.tmp_file_path):
|
||||
if is_win and osp.exists(self.file_path):
|
||||
os.remove(self.file_path)
|
||||
os.rename(self.tmp_file_path, self.file_path)
|
||||
# END temp file exists
|
||||
|
||||
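A minimal usage sketch (not part of the vendored sources; the path is hypothetical) of TemporaryFileSwap, which moves the file aside for the lifetime of the object and restores it when the object is deleted:

from git.index.util import TemporaryFileSwap

swap = TemporaryFileSwap(".git/index")   # hypothetical path; the file is moved aside here
try:
    ...                                  # work with a fresh .git/index in the meantime
finally:
    del swap                             # __del__ moves the original file back into place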
|
||||
# { Decorators
|
||||
|
||||
|
||||
def post_clear_cache(func: Callable[..., _T]) -> Callable[..., _T]:
|
||||
"""Decorator for functions that alter the index using the git command. This would
|
||||
invalidate our possibly existing entries dictionary which is why it must be
|
||||
deleted to allow it to be lazily reread later.
|
||||
|
||||
:note:
|
||||
This decorator will not be required once all functions are implemented
|
||||
natively which in fact is possible, but probably not feasible performance wise.
|
||||
"""
|
||||
|
||||
@wraps(func)
|
||||
def post_clear_cache_if_not_raised(self: "IndexFile", *args: Any, **kwargs: Any) -> _T:
|
||||
rval = func(self, *args, **kwargs)
|
||||
self._delete_entries_cache()
|
||||
return rval
|
||||
|
||||
# END wrapper method
|
||||
|
||||
return post_clear_cache_if_not_raised
|
||||
|
||||
|
||||
def default_index(func: Callable[..., _T]) -> Callable[..., _T]:
|
||||
"""Decorator assuring the wrapped method may only run if we are the default
|
||||
repository index. This is as we rely on git commands that operate
|
||||
on that index only."""
|
||||
|
||||
@wraps(func)
|
||||
def check_default_index(self: "IndexFile", *args: Any, **kwargs: Any) -> _T:
|
||||
if self._file_path != self._index_path():
|
||||
raise AssertionError(
|
||||
"Cannot call %r on indices that do not represent the default git index" % func.__name__
|
||||
)
|
||||
return func(self, *args, **kwargs)
|
||||
|
||||
# END wrapper method
|
||||
|
||||
return check_default_index
|
||||
|
||||
|
||||
def git_working_dir(func: Callable[..., _T]) -> Callable[..., _T]:
|
||||
"""Decorator which changes the current working dir to the one of the git
|
||||
repository in order to assure relative paths are handled correctly"""
|
||||
|
||||
@wraps(func)
|
||||
def set_git_working_dir(self: "IndexFile", *args: Any, **kwargs: Any) -> _T:
|
||||
cur_wd = os.getcwd()
|
||||
os.chdir(str(self.repo.working_tree_dir))
|
||||
try:
|
||||
return func(self, *args, **kwargs)
|
||||
finally:
|
||||
os.chdir(cur_wd)
|
||||
# END handle working dir
|
||||
|
||||
# END wrapper
|
||||
|
||||
return set_git_working_dir
|
||||
|
||||
|
||||
# } END decorators
|
||||
@@ -0,0 +1,24 @@
"""
Import all submodules main classes into the package space
"""
# flake8: noqa
import inspect

from .base import *
from .blob import *
from .commit import *
from .submodule import util as smutil
from .submodule.base import *
from .submodule.root import *
from .tag import *
from .tree import *

# Fix import dependency - add IndexObject to the util module, so that it can be
# imported by the submodule.base
smutil.IndexObject = IndexObject  # type: ignore[attr-defined]
smutil.Object = Object  # type: ignore[attr-defined]
del smutil

# must come after submodule was made available

__all__ = [name for name, obj in locals().items() if not (name.startswith("_") or inspect.ismodule(obj))]
@@ -0,0 +1,224 @@
|
||||
# base.py
|
||||
# Copyright (C) 2008, 2009 Michael Trier (mtrier@gmail.com) and contributors
|
||||
#
|
||||
# This module is part of GitPython and is released under
|
||||
# the BSD License: http://www.opensource.org/licenses/bsd-license.php
|
||||
|
||||
from git.exc import WorkTreeRepositoryUnsupported
|
||||
from git.util import LazyMixin, join_path_native, stream_copy, bin_to_hex
|
||||
|
||||
import gitdb.typ as dbtyp
|
||||
import os.path as osp
|
||||
|
||||
from .util import get_object_type_by_name
|
||||
|
||||
|
||||
# typing ------------------------------------------------------------------
|
||||
|
||||
from typing import Any, TYPE_CHECKING, Union
|
||||
|
||||
from git.types import PathLike, Commit_ish, Lit_commit_ish
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from git.repo import Repo
|
||||
from gitdb.base import OStream
|
||||
from .tree import Tree
|
||||
from .blob import Blob
|
||||
from .submodule.base import Submodule
|
||||
from git.refs.reference import Reference
|
||||
|
||||
IndexObjUnion = Union["Tree", "Blob", "Submodule"]
|
||||
|
||||
# --------------------------------------------------------------------------
|
||||
|
||||
|
||||
_assertion_msg_format = "Created object %r whose python type %r disagrees with the actual git object type %r"
|
||||
|
||||
__all__ = ("Object", "IndexObject")
|
||||
|
||||
|
||||
class Object(LazyMixin):
|
||||
|
||||
"""Implements an Object which may be Blobs, Trees, Commits and Tags"""
|
||||
|
||||
NULL_HEX_SHA = "0" * 40
|
||||
NULL_BIN_SHA = b"\0" * 20
|
||||
|
||||
TYPES = (
|
||||
dbtyp.str_blob_type,
|
||||
dbtyp.str_tree_type,
|
||||
dbtyp.str_commit_type,
|
||||
dbtyp.str_tag_type,
|
||||
)
|
||||
__slots__ = ("repo", "binsha", "size")
|
||||
type: Union[Lit_commit_ish, None] = None
|
||||
|
||||
def __init__(self, repo: "Repo", binsha: bytes):
|
||||
"""Initialize an object by identifying it by its binary sha.
|
||||
All keyword arguments will be set on demand if None.
|
||||
|
||||
:param repo: repository this object is located in
|
||||
|
||||
:param binsha: 20 byte SHA1"""
|
||||
super(Object, self).__init__()
|
||||
self.repo = repo
|
||||
self.binsha = binsha
|
||||
assert len(binsha) == 20, "Require 20 byte binary sha, got %r, len = %i" % (
|
||||
binsha,
|
||||
len(binsha),
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def new(cls, repo: "Repo", id: Union[str, "Reference"]) -> Commit_ish:
|
||||
"""
|
||||
:return: New Object instance of a type appropriate to the object type behind
|
||||
id. The id of the newly created object will be a binsha even though
|
||||
the input id may have been a Reference or Rev-Spec
|
||||
|
||||
:param id: reference, rev-spec, or hexsha
|
||||
|
||||
:note: This cannot be a __new__ method as it would always call __init__
|
||||
with the input id which is not necessarily a binsha."""
|
||||
return repo.rev_parse(str(id))
|
||||
|
||||
@classmethod
|
||||
def new_from_sha(cls, repo: "Repo", sha1: bytes) -> Commit_ish:
|
||||
"""
|
||||
:return: new object instance of a type appropriate to represent the given
|
||||
binary sha1
|
||||
:param sha1: 20 byte binary sha1"""
|
||||
if sha1 == cls.NULL_BIN_SHA:
|
||||
# the NULL binsha is always the root commit
|
||||
return get_object_type_by_name(b"commit")(repo, sha1)
|
||||
# END handle special case
|
||||
oinfo = repo.odb.info(sha1)
|
||||
inst = get_object_type_by_name(oinfo.type)(repo, oinfo.binsha)
|
||||
inst.size = oinfo.size
|
||||
return inst
|
||||
|
||||
def _set_cache_(self, attr: str) -> None:
|
||||
"""Retrieve object information"""
|
||||
if attr == "size":
|
||||
oinfo = self.repo.odb.info(self.binsha)
|
||||
self.size = oinfo.size # type: int
|
||||
# assert oinfo.type == self.type, _assertion_msg_format % (self.binsha, oinfo.type, self.type)
|
||||
else:
|
||||
super(Object, self)._set_cache_(attr)
|
||||
|
||||
def __eq__(self, other: Any) -> bool:
|
||||
""":return: True if the objects have the same SHA1"""
|
||||
if not hasattr(other, "binsha"):
|
||||
return False
|
||||
return self.binsha == other.binsha
|
||||
|
||||
def __ne__(self, other: Any) -> bool:
|
||||
""":return: True if the objects do not have the same SHA1"""
|
||||
if not hasattr(other, "binsha"):
|
||||
return True
|
||||
return self.binsha != other.binsha
|
||||
|
||||
def __hash__(self) -> int:
|
||||
""":return: Hash of our id allowing objects to be used in dicts and sets"""
|
||||
return hash(self.binsha)
|
||||
|
||||
def __str__(self) -> str:
|
||||
""":return: string of our SHA1 as understood by all git commands"""
|
||||
return self.hexsha
|
||||
|
||||
def __repr__(self) -> str:
|
||||
""":return: string with pythonic representation of our object"""
|
||||
return '<git.%s "%s">' % (self.__class__.__name__, self.hexsha)
|
||||
|
||||
@property
|
||||
def hexsha(self) -> str:
|
||||
""":return: 40 byte hex version of our 20 byte binary sha"""
|
||||
# b2a_hex produces bytes
|
||||
return bin_to_hex(self.binsha).decode("ascii")
|
||||
|
||||
@property
|
||||
def data_stream(self) -> "OStream":
|
||||
""":return: File Object compatible stream to the uncompressed raw data of the object
|
||||
:note: returned streams must be read in order"""
|
||||
return self.repo.odb.stream(self.binsha)
|
||||
|
||||
def stream_data(self, ostream: "OStream") -> "Object":
|
||||
"""Writes our data directly to the given output stream
|
||||
|
||||
:param ostream: File object compatible stream object.
|
||||
:return: self"""
|
||||
istream = self.repo.odb.stream(self.binsha)
|
||||
stream_copy(istream, ostream)
|
||||
return self
|
||||
|
||||
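A minimal usage sketch (not part of the vendored sources; repository and file path are assumptions) of reading an object's raw data via the data_stream property or copying it with stream_data:

import io
from git import Repo

repo = Repo(".")                              # assumed repository
blob = repo.head.commit.tree / "README.md"    # hypothetical path inside the tree
print(blob.data_stream.read(64))              # first bytes of the uncompressed content

buf = io.BytesIO()
blob.stream_data(buf)                         # copy the full content into buf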
|
||||
class IndexObject(Object):
|
||||
|
||||
"""Base for all objects that can be part of the index file , namely Tree, Blob and
|
||||
SubModule objects"""
|
||||
|
||||
__slots__ = ("path", "mode")
|
||||
|
||||
# for compatibility with iterable lists
|
||||
_id_attribute_ = "path"
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
repo: "Repo",
|
||||
binsha: bytes,
|
||||
mode: Union[None, int] = None,
|
||||
path: Union[None, PathLike] = None,
|
||||
) -> None:
|
||||
"""Initialize a newly instanced IndexObject
|
||||
|
||||
:param repo: is the Repo we are located in
|
||||
:param binsha: 20 byte sha1
|
||||
:param mode:
|
||||
is the stat compatible file mode as int, use the stat module
|
||||
to evaluate the information
|
||||
:param path:
|
||||
is the path to the file in the file system, relative to the git repository root, i.e.
|
||||
file.ext or folder/other.ext
|
||||
:note:
|
||||
Path may not be set if the index object has been created directly, as it cannot
|
||||
be retrieved without knowing the parent tree."""
|
||||
super(IndexObject, self).__init__(repo, binsha)
|
||||
if mode is not None:
|
||||
self.mode = mode
|
||||
if path is not None:
|
||||
self.path = path
|
||||
|
||||
def __hash__(self) -> int:
|
||||
"""
|
||||
:return:
|
||||
Hash of our path as index items are uniquely identifiable by path, not
|
||||
by their data !"""
|
||||
return hash(self.path)
|
||||
|
||||
def _set_cache_(self, attr: str) -> None:
|
||||
if attr in IndexObject.__slots__:
|
||||
# they cannot be retrieved later on (not without searching for them)
|
||||
raise AttributeError(
|
||||
"Attribute '%s' unset: path and mode attributes must have been set during %s object creation"
|
||||
% (attr, type(self).__name__)
|
||||
)
|
||||
else:
|
||||
super(IndexObject, self)._set_cache_(attr)
|
||||
# END handle slot attribute
|
||||
|
||||
@property
|
||||
def name(self) -> str:
|
||||
""":return: Name portion of the path, effectively being the basename"""
|
||||
return osp.basename(self.path)
|
||||
|
||||
@property
|
||||
def abspath(self) -> PathLike:
|
||||
"""
|
||||
:return:
|
||||
Absolute path to this index object in the file system ( as opposed to the
|
||||
.path field which is a path relative to the git repository ).
|
||||
|
||||
The returned path will be native to the system and contains '\' on windows."""
|
||||
if self.repo.working_tree_dir is not None:
|
||||
return join_path_native(self.repo.working_tree_dir, self.path)
|
||||
else:
|
||||
raise WorkTreeRepositoryUnsupported("Working_tree_dir was None or empty")
|
||||
@@ -0,0 +1,36 @@
# blob.py
# Copyright (C) 2008, 2009 Michael Trier (mtrier@gmail.com) and contributors
#
# This module is part of GitPython and is released under
# the BSD License: http://www.opensource.org/licenses/bsd-license.php
from mimetypes import guess_type
from . import base

from git.types import Literal

__all__ = ("Blob",)


class Blob(base.IndexObject):

    """A Blob encapsulates a git blob object"""

    DEFAULT_MIME_TYPE = "text/plain"
    type: Literal["blob"] = "blob"

    # valid blob modes
    executable_mode = 0o100755
    file_mode = 0o100644
    link_mode = 0o120000

    __slots__ = ()

    @property
    def mime_type(self) -> str:
        """
        :return: String describing the mime type of this file (based on the filename)
        :note: Defaults to 'text/plain' in case the actual file type is unknown."""
        guesses = None
        if self.path:
            guesses = guess_type(str(self.path))
        return guesses and guesses[0] or self.DEFAULT_MIME_TYPE
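A minimal usage sketch (not part of the vendored sources; repository and file path are assumptions) of Blob.mime_type, which only inspects the file name and falls back to DEFAULT_MIME_TYPE when nothing can be guessed:

from git import Repo

repo = Repo(".")                              # assumed repository
blob = repo.head.commit.tree / "setup.py"     # hypothetical path inside the tree
print(blob.mime_type)                         # e.g. "text/x-python"; unknown names give "text/plain"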
@@ -0,0 +1,762 @@
|
||||
# commit.py
|
||||
# Copyright (C) 2008, 2009 Michael Trier (mtrier@gmail.com) and contributors
|
||||
#
|
||||
# This module is part of GitPython and is released under
|
||||
# the BSD License: http://www.opensource.org/licenses/bsd-license.php
|
||||
import datetime
|
||||
import re
|
||||
from subprocess import Popen, PIPE
|
||||
from gitdb import IStream
|
||||
from git.util import hex_to_bin, Actor, Stats, finalize_process
|
||||
from git.diff import Diffable
|
||||
from git.cmd import Git
|
||||
|
||||
from .tree import Tree
|
||||
from . import base
|
||||
from .util import (
|
||||
Serializable,
|
||||
TraversableIterableObj,
|
||||
parse_date,
|
||||
altz_to_utctz_str,
|
||||
parse_actor_and_date,
|
||||
from_timestamp,
|
||||
)
|
||||
|
||||
from time import time, daylight, altzone, timezone, localtime
|
||||
import os
|
||||
from io import BytesIO
|
||||
import logging
|
||||
|
||||
|
||||
# typing ------------------------------------------------------------------
|
||||
|
||||
from typing import (
|
||||
Any,
|
||||
IO,
|
||||
Iterator,
|
||||
List,
|
||||
Sequence,
|
||||
Tuple,
|
||||
Union,
|
||||
TYPE_CHECKING,
|
||||
cast,
|
||||
Dict,
|
||||
)
|
||||
|
||||
from git.types import PathLike, Literal
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from git.repo import Repo
|
||||
from git.refs import SymbolicReference
|
||||
|
||||
# ------------------------------------------------------------------------
|
||||
|
||||
log = logging.getLogger("git.objects.commit")
|
||||
log.addHandler(logging.NullHandler())
|
||||
|
||||
__all__ = ("Commit",)
|
||||
|
||||
|
||||
class Commit(base.Object, TraversableIterableObj, Diffable, Serializable):
|
||||
|
||||
"""Wraps a git Commit object.
|
||||
|
||||
This class will act lazily on some of its attributes and will query the
|
||||
value on demand only if it involves calling the git binary."""
|
||||
|
||||
# ENVIRONMENT VARIABLES
|
||||
# read when creating new commits
|
||||
env_author_date = "GIT_AUTHOR_DATE"
|
||||
env_committer_date = "GIT_COMMITTER_DATE"
|
||||
|
||||
# CONFIGURATION KEYS
|
||||
conf_encoding = "i18n.commitencoding"
|
||||
|
||||
# INVARIANTS
|
||||
default_encoding = "UTF-8"
|
||||
|
||||
# object configuration
|
||||
type: Literal["commit"] = "commit"
|
||||
__slots__ = (
|
||||
"tree",
|
||||
"author",
|
||||
"authored_date",
|
||||
"author_tz_offset",
|
||||
"committer",
|
||||
"committed_date",
|
||||
"committer_tz_offset",
|
||||
"message",
|
||||
"parents",
|
||||
"encoding",
|
||||
"gpgsig",
|
||||
)
|
||||
_id_attribute_ = "hexsha"
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
repo: "Repo",
|
||||
binsha: bytes,
|
||||
tree: Union[Tree, None] = None,
|
||||
author: Union[Actor, None] = None,
|
||||
authored_date: Union[int, None] = None,
|
||||
author_tz_offset: Union[None, float] = None,
|
||||
committer: Union[Actor, None] = None,
|
||||
committed_date: Union[int, None] = None,
|
||||
committer_tz_offset: Union[None, float] = None,
|
||||
message: Union[str, bytes, None] = None,
|
||||
parents: Union[Sequence["Commit"], None] = None,
|
||||
encoding: Union[str, None] = None,
|
||||
gpgsig: Union[str, None] = None,
|
||||
) -> None:
|
||||
"""Instantiate a new Commit. All keyword arguments taking None as default will
|
||||
be implicitly set on first query.
|
||||
|
||||
:param binsha: 20 byte sha1
|
||||
:param parents: tuple( Commit, ... )
|
||||
is a tuple of commit ids or actual Commits
|
||||
:param tree: Tree object
|
||||
:param author: Actor
|
||||
is the author Actor object
|
||||
:param authored_date: int_seconds_since_epoch
|
||||
is the authored DateTime - use time.gmtime() to convert it into a
|
||||
different format
|
||||
:param author_tz_offset: int_seconds_west_of_utc
|
||||
is the timezone that the authored_date is in
|
||||
:param committer: Actor
|
||||
is the committer string
|
||||
:param committed_date: int_seconds_since_epoch
|
||||
is the committed DateTime - use time.gmtime() to convert it into a
|
||||
different format
|
||||
:param committer_tz_offset: int_seconds_west_of_utc
|
||||
is the timezone that the committed_date is in
|
||||
:param message: string
|
||||
is the commit message
|
||||
:param encoding: string
|
||||
encoding of the message, defaults to UTF-8
|
||||
:param parents:
|
||||
List or tuple of Commit objects which are our parent(s) in the commit
|
||||
dependency graph
|
||||
:return: git.Commit
|
||||
|
||||
:note:
|
||||
Timezone information is in the same format and in the same sign
|
||||
as what time.altzone returns. The sign is inverted compared to git's
|
||||
UTC timezone."""
|
||||
super(Commit, self).__init__(repo, binsha)
|
||||
self.binsha = binsha
|
||||
if tree is not None:
|
||||
assert isinstance(tree, Tree), "Tree needs to be a Tree instance, was %s" % type(tree)
|
||||
if tree is not None:
|
||||
self.tree = tree
|
||||
if author is not None:
|
||||
self.author = author
|
||||
if authored_date is not None:
|
||||
self.authored_date = authored_date
|
||||
if author_tz_offset is not None:
|
||||
self.author_tz_offset = author_tz_offset
|
||||
if committer is not None:
|
||||
self.committer = committer
|
||||
if committed_date is not None:
|
||||
self.committed_date = committed_date
|
||||
if committer_tz_offset is not None:
|
||||
self.committer_tz_offset = committer_tz_offset
|
||||
if message is not None:
|
||||
self.message = message
|
||||
if parents is not None:
|
||||
self.parents = parents
|
||||
if encoding is not None:
|
||||
self.encoding = encoding
|
||||
if gpgsig is not None:
|
||||
self.gpgsig = gpgsig
|
||||
|
||||
@classmethod
|
||||
def _get_intermediate_items(cls, commit: "Commit") -> Tuple["Commit", ...]:
|
||||
return tuple(commit.parents)
|
||||
|
||||
@classmethod
|
||||
def _calculate_sha_(cls, repo: "Repo", commit: "Commit") -> bytes:
|
||||
"""Calculate the sha of a commit.
|
||||
|
||||
:param repo: Repo object the commit should be part of
|
||||
:param commit: Commit object for which to generate the sha
|
||||
"""
|
||||
|
||||
stream = BytesIO()
|
||||
commit._serialize(stream)
|
||||
streamlen = stream.tell()
|
||||
stream.seek(0)
|
||||
|
||||
istream = repo.odb.store(IStream(cls.type, streamlen, stream))
|
||||
return istream.binsha
|
||||
|
||||
def replace(self, **kwargs: Any) -> "Commit":
|
||||
"""Create new commit object from existing commit object.
|
||||
|
||||
Any values provided as keyword arguments will replace the
|
||||
corresponding attribute in the new object.
|
||||
"""
|
||||
|
||||
attrs = {k: getattr(self, k) for k in self.__slots__}
|
||||
|
||||
for attrname in kwargs:
|
||||
if attrname not in self.__slots__:
|
||||
raise ValueError("invalid attribute name")
|
||||
|
||||
attrs.update(kwargs)
|
||||
new_commit = self.__class__(self.repo, self.NULL_BIN_SHA, **attrs)
|
||||
new_commit.binsha = self._calculate_sha_(self.repo, new_commit)
|
||||
|
||||
return new_commit
|
||||
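A minimal usage sketch (not part of the vendored sources; the repository path is an assumption) of Commit.replace, which builds a new Commit object with selected attributes swapped and its sha recomputed; no branch or ref is updated:

from git import Repo

repo = Repo(".")                                   # assumed repository
head = repo.head.commit
reworded = head.replace(message="reworded commit message")
print(head.hexsha != reworded.hexsha)              # True: the changed message changes the sha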
|
||||
def _set_cache_(self, attr: str) -> None:
|
||||
if attr in Commit.__slots__:
|
||||
# read the data in a chunk, it's faster - then provide a file wrapper
|
||||
_binsha, _typename, self.size, stream = self.repo.odb.stream(self.binsha)
|
||||
self._deserialize(BytesIO(stream.read()))
|
||||
else:
|
||||
super(Commit, self)._set_cache_(attr)
|
||||
# END handle attrs
|
||||
|
||||
@property
|
||||
def authored_datetime(self) -> datetime.datetime:
|
||||
return from_timestamp(self.authored_date, self.author_tz_offset)
|
||||
|
||||
@property
|
||||
def committed_datetime(self) -> datetime.datetime:
|
||||
return from_timestamp(self.committed_date, self.committer_tz_offset)
|
||||
|
||||
@property
|
||||
def summary(self) -> Union[str, bytes]:
|
||||
""":return: First line of the commit message"""
|
||||
if isinstance(self.message, str):
|
||||
return self.message.split("\n", 1)[0]
|
||||
else:
|
||||
return self.message.split(b"\n", 1)[0]
|
||||
|
||||
def count(self, paths: Union[PathLike, Sequence[PathLike]] = "", **kwargs: Any) -> int:
|
||||
"""Count the number of commits reachable from this commit
|
||||
|
||||
:param paths:
|
||||
is an optional path or a list of paths restricting the return value
|
||||
to commits actually containing the paths
|
||||
|
||||
:param kwargs:
|
||||
Additional options to be passed to git-rev-list. They must not alter
|
||||
the output style of the command, or parsing will yield incorrect results
|
||||
:return: int defining the number of reachable commits"""
|
||||
# yes, it makes a difference whether empty paths are given or not in our case
|
||||
# as the empty paths version will ignore merge commits for some reason.
|
||||
if paths:
|
||||
return len(self.repo.git.rev_list(self.hexsha, "--", paths, **kwargs).splitlines())
|
||||
return len(self.repo.git.rev_list(self.hexsha, **kwargs).splitlines())
|
||||
|
||||
@property
|
||||
def name_rev(self) -> str:
|
||||
"""
|
||||
:return:
|
||||
String describing the commit's hex sha based on the closest Reference.
|
||||
Mostly useful for UI purposes"""
|
||||
return self.repo.git.name_rev(self)
|
||||
|
||||
@classmethod
|
||||
def iter_items(
|
||||
cls,
|
||||
repo: "Repo",
|
||||
rev: Union[str, "Commit", "SymbolicReference"], # type: ignore
|
||||
paths: Union[PathLike, Sequence[PathLike]] = "",
|
||||
**kwargs: Any,
|
||||
) -> Iterator["Commit"]:
|
||||
"""Find all commits matching the given criteria.
|
||||
|
||||
:param repo: is the Repo
|
||||
:param rev: revision specifier, see git-rev-parse for viable options
|
||||
:param paths:
|
||||
is an optional path or list of paths, if set only Commits that include the path
|
||||
or paths will be considered
|
||||
:param kwargs:
|
||||
optional keyword arguments to git rev-list where
|
||||
``max_count`` is the maximum number of commits to fetch
|
||||
``skip`` is the number of commits to skip
|
||||
``since`` all commits since i.e. '1970-01-01'
|
||||
:return: iterator yielding Commit items"""
|
||||
if "pretty" in kwargs:
|
||||
raise ValueError("--pretty cannot be used as parsing expects single sha's only")
|
||||
# END handle pretty
|
||||
|
||||
# use -- in any case, to prevent possibility of ambiguous arguments
|
||||
# see https://github.com/gitpython-developers/GitPython/issues/264
|
||||
|
||||
args_list: List[PathLike] = ["--"]
|
||||
|
||||
if paths:
|
||||
paths_tup: Tuple[PathLike, ...]
|
||||
if isinstance(paths, (str, os.PathLike)):
|
||||
paths_tup = (paths,)
|
||||
else:
|
||||
paths_tup = tuple(paths)
|
||||
|
||||
args_list.extend(paths_tup)
|
||||
# END if paths
|
||||
|
||||
proc = repo.git.rev_list(rev, args_list, as_process=True, **kwargs)
|
||||
return cls._iter_from_process_or_stream(repo, proc)
|
||||
|
||||
def iter_parents(self, paths: Union[PathLike, Sequence[PathLike]] = "", **kwargs: Any) -> Iterator["Commit"]:
|
||||
"""Iterate _all_ parents of this commit.
|
||||
|
||||
:param paths:
|
||||
Optional path or list of paths limiting the Commits to those that
|
||||
contain at least one of the paths
|
||||
:param kwargs: All arguments allowed by git-rev-list
|
||||
:return: Iterator yielding Commit objects which are parents of self"""
|
||||
# skip ourselves
|
||||
skip = kwargs.get("skip", 1)
|
||||
if skip == 0: # skip ourselves
|
||||
skip = 1
|
||||
kwargs["skip"] = skip
|
||||
|
||||
return self.iter_items(self.repo, self, paths, **kwargs)
|
||||
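# Illustrative usage sketch (not part of the original module); the kwargs are
# simply forwarded to git rev-list as described above:
#
#   for ancestor in repo.head.commit.iter_parents(max_count=5):
#       print(ancestor.hexsha)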
|
||||
@property
|
||||
def stats(self) -> Stats:
|
||||
"""Create a git stat from changes between this commit and its first parent
|
||||
or from all changes done if this is the very first commit.
|
||||
|
||||
:return: git.Stats"""
|
||||
if not self.parents:
|
||||
text = self.repo.git.diff_tree(self.hexsha, "--", numstat=True, no_renames=True, root=True)
|
||||
text2 = ""
|
||||
for line in text.splitlines()[1:]:
|
||||
(insertions, deletions, filename) = line.split("\t")
|
||||
text2 += "%s\t%s\t%s\n" % (insertions, deletions, filename)
|
||||
text = text2
|
||||
else:
|
||||
text = self.repo.git.diff(self.parents[0].hexsha, self.hexsha, "--", numstat=True, no_renames=True)
|
||||
return Stats._list_from_string(self.repo, text)
|
||||
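# Illustrative usage sketch (not part of the original module):
#
#   s = repo.head.commit.stats
#   s.total   # dict with 'insertions', 'deletions', 'lines' and 'files' counts
#   s.files   # per-file dict with the same counters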
|
||||
@property
|
||||
def trailers(self) -> Dict:
|
||||
"""Get the trailers of the message as dictionary
|
||||
|
||||
Git messages can contain trailer information that is similar to RFC 822
|
||||
e-mail headers (see: https://git-scm.com/docs/git-interpret-trailers).
|
||||
|
||||
This function calls ``git interpret-trailers --parse`` on the message
|
||||
to extract the trailer information. The key value pairs are stripped of
|
||||
leading and trailing whitespace before they are saved into a dictionary.
|
||||
|
||||
Valid message with trailer:
|
||||
|
||||
.. code-block::
|
||||
|
||||
Subject line
|
||||
|
||||
some body information
|
||||
|
||||
another information
|
||||
|
||||
key1: value1
|
||||
key2 : value 2 with inner spaces
|
||||
|
||||
The resulting dictionary will look like this:
|
||||
|
||||
.. code-block::
|
||||
|
||||
{
|
||||
"key1": "value1",
|
||||
"key2": "value 2 with inner spaces"
|
||||
}
|
||||
|
||||
:return: Dictionary containing whitespace stripped trailer information
|
||||
|
||||
"""
|
||||
d = {}
|
||||
cmd = ["git", "interpret-trailers", "--parse"]
|
||||
proc: Git.AutoInterrupt = self.repo.git.execute(cmd, as_process=True, istream=PIPE) # type: ignore
|
||||
trailer: str = proc.communicate(str(self.message).encode())[0].decode()
|
||||
if trailer.endswith("\n"):
|
||||
trailer = trailer[0:-1]
|
||||
if trailer != "":
|
||||
for line in trailer.split("\n"):
|
||||
key, value = line.split(":", 1)
|
||||
d[key.strip()] = value.strip()
|
||||
return d
|
||||
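# Illustrative usage sketch (not part of the original module); the trailer key
# shown is hypothetical and only present if the commit message contains it:
#
#   repo.head.commit.trailers
#   # -> {"Signed-off-by": "A U Thor <a@example.com>"}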
|
||||
@classmethod
|
||||
def _iter_from_process_or_stream(cls, repo: "Repo", proc_or_stream: Union[Popen, IO]) -> Iterator["Commit"]:
|
||||
"""Parse out commit information into a list of Commit objects
|
||||
We expect one line per commit, and parse the actual commit information directly
|
||||
from our lightning-fast object database
|
||||
|
||||
:param proc_or_stream: git-rev-list process instance or stream - one sha per line
|
||||
:return: iterator returning Commit objects"""
|
||||
|
||||
# def is_proc(inp) -> TypeGuard[Popen]:
|
||||
# return hasattr(proc_or_stream, 'wait') and not hasattr(proc_or_stream, 'readline')
|
||||
|
||||
# def is_stream(inp) -> TypeGuard[IO]:
|
||||
# return hasattr(proc_or_stream, 'readline')
|
||||
|
||||
if hasattr(proc_or_stream, "wait"):
|
||||
proc_or_stream = cast(Popen, proc_or_stream)
|
||||
if proc_or_stream.stdout is not None:
|
||||
stream = proc_or_stream.stdout
|
||||
elif hasattr(proc_or_stream, "readline"):
|
||||
proc_or_stream = cast(IO, proc_or_stream)
|
||||
stream = proc_or_stream
|
||||
|
||||
readline = stream.readline
|
||||
while True:
|
||||
line = readline()
|
||||
if not line:
|
||||
break
|
||||
hexsha = line.strip()
|
||||
if len(hexsha) > 40:
|
||||
# split additional information, as returned by bisect for instance
|
||||
hexsha, _ = line.split(None, 1)
|
||||
# END handle extra info
|
||||
|
||||
assert len(hexsha) == 40, "Invalid line: %s" % hexsha
|
||||
yield cls(repo, hex_to_bin(hexsha))
|
||||
# END for each line in stream
|
||||
# TODO: Review this - it seems process handling got a bit out of control
|
||||
# due to many developers trying to fix the open file handles issue
|
||||
if hasattr(proc_or_stream, "wait"):
|
||||
proc_or_stream = cast(Popen, proc_or_stream)
|
||||
finalize_process(proc_or_stream)
|
||||
|
||||
@classmethod
|
||||
def create_from_tree(
|
||||
cls,
|
||||
repo: "Repo",
|
||||
tree: Union[Tree, str],
|
||||
message: str,
|
||||
parent_commits: Union[None, List["Commit"]] = None,
|
||||
head: bool = False,
|
||||
author: Union[None, Actor] = None,
|
||||
committer: Union[None, Actor] = None,
|
||||
author_date: Union[None, str, datetime.datetime] = None,
|
||||
commit_date: Union[None, str, datetime.datetime] = None,
|
||||
) -> "Commit":
|
||||
"""Commit the given tree, creating a commit object.
|
||||
|
||||
:param repo: Repo object the commit should be part of
|
||||
:param tree: Tree object or hex or bin sha
|
||||
the tree of the new commit
|
||||
:param message: Commit message. It may be an empty string if no message is provided.
|
||||
It will be converted to a string in any case.
|
||||
:param parent_commits:
|
||||
Optional Commit objects to use as parents for the new commit.
|
||||
If empty list, the commit will have no parents at all and become
|
||||
a root commit.
|
||||
If None, the current head commit will be the parent of the
|
||||
new commit object
|
||||
:param head:
|
||||
If True, the HEAD will be advanced to the new commit automatically.
|
||||
Else the HEAD will remain pointing on the previous commit. This could
|
||||
lead to undesired results when diffing files.
|
||||
:param author: The name of the author, optional. If unset, the repository
|
||||
configuration is used to obtain this value.
|
||||
:param committer: The name of the committer, optional. If unset, the
|
||||
repository configuration is used to obtain this value.
|
||||
:param author_date: The timestamp for the author field
|
||||
:param commit_date: The timestamp for the committer field
|
||||
|
||||
:return: Commit object representing the new commit
|
||||
|
||||
:note:
|
||||
Additional information about the committer and author is taken from the
|
||||
environment or from the git configuration, see git-commit-tree for
|
||||
more information"""
|
||||
if parent_commits is None:
|
||||
try:
|
||||
parent_commits = [repo.head.commit]
|
||||
except ValueError:
|
||||
# empty repositories have no head commit
|
||||
parent_commits = []
|
||||
# END handle parent commits
|
||||
else:
|
||||
for p in parent_commits:
|
||||
if not isinstance(p, cls):
|
||||
raise ValueError(f"Parent commit '{p!r}' must be of type {cls}")
|
||||
# end check parent commit types
|
||||
# END if parent commits are unset
|
||||
|
||||
# retrieve all additional information, create a commit object, and
|
||||
# serialize it
|
||||
# Generally:
|
||||
# * Environment variables override configuration values
|
||||
# * Sensible defaults are set according to the git documentation
|
||||
|
||||
# COMMITTER AND AUTHOR INFO
|
||||
cr = repo.config_reader()
|
||||
env = os.environ
|
||||
|
||||
committer = committer or Actor.committer(cr)
|
||||
author = author or Actor.author(cr)
|
||||
|
||||
# PARSE THE DATES
|
||||
unix_time = int(time())
|
||||
is_dst = daylight and localtime().tm_isdst > 0
|
||||
offset = altzone if is_dst else timezone
|
||||
|
||||
author_date_str = env.get(cls.env_author_date, "")
|
||||
if author_date:
|
||||
author_time, author_offset = parse_date(author_date)
|
||||
elif author_date_str:
|
||||
author_time, author_offset = parse_date(author_date_str)
|
||||
else:
|
||||
author_time, author_offset = unix_time, offset
|
||||
# END set author time
|
||||
|
||||
committer_date_str = env.get(cls.env_committer_date, "")
|
||||
if commit_date:
|
||||
committer_time, committer_offset = parse_date(commit_date)
|
||||
elif committer_date_str:
|
||||
committer_time, committer_offset = parse_date(committer_date_str)
|
||||
else:
|
||||
committer_time, committer_offset = unix_time, offset
|
||||
# END set committer time
|
||||
|
||||
# assume utf8 encoding
|
||||
enc_section, enc_option = cls.conf_encoding.split(".")
|
||||
conf_encoding = cr.get_value(enc_section, enc_option, cls.default_encoding)
|
||||
if not isinstance(conf_encoding, str):
|
||||
raise TypeError("conf_encoding could not be coerced to str")
|
||||
|
||||
# if the tree is not an object, make sure we create one - otherwise
|
||||
# the created commit object is invalid
|
||||
if isinstance(tree, str):
|
||||
tree = repo.tree(tree)
|
||||
# END tree conversion
|
||||
|
||||
# CREATE NEW COMMIT
|
||||
new_commit = cls(
|
||||
repo,
|
||||
cls.NULL_BIN_SHA,
|
||||
tree,
|
||||
author,
|
||||
author_time,
|
||||
author_offset,
|
||||
committer,
|
||||
committer_time,
|
||||
committer_offset,
|
||||
message,
|
||||
parent_commits,
|
||||
conf_encoding,
|
||||
)
|
||||
|
||||
new_commit.binsha = cls._calculate_sha_(repo, new_commit)
|
||||
|
||||
if head:
|
||||
# need a late import here, as importing git at the very beginning throws
|
||||
# as well ...
|
||||
import git.refs
|
||||
|
||||
try:
|
||||
repo.head.set_commit(new_commit, logmsg=message)
|
||||
except ValueError:
|
||||
# head is not yet set to the ref our HEAD points to
|
||||
# Happens on first commit
|
||||
master = git.refs.Head.create(
|
||||
repo,
|
||||
repo.head.ref,
|
||||
new_commit,
|
||||
logmsg="commit (initial): %s" % message,
|
||||
)
|
||||
repo.head.set_reference(master, logmsg="commit: Switching to %s" % master)
|
||||
# END handle empty repositories
|
||||
# END advance head handling
|
||||
|
||||
return new_commit
|
||||
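# Illustrative usage sketch (not part of the original module); the message is
# an arbitrary example and head=False leaves HEAD untouched:
#
#   tree = repo.head.commit.tree
#   new_commit = Commit.create_from_tree(repo, tree, "re-commit current tree", head=False)
#   assert new_commit.tree == tree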
|
||||
# { Serializable Implementation
|
||||
|
||||
def _serialize(self, stream: BytesIO) -> "Commit":
|
||||
write = stream.write
|
||||
write(("tree %s\n" % self.tree).encode("ascii"))
|
||||
for p in self.parents:
|
||||
write(("parent %s\n" % p).encode("ascii"))
|
||||
|
||||
a = self.author
|
||||
aname = a.name
|
||||
c = self.committer
|
||||
fmt = "%s %s <%s> %s %s\n"
|
||||
write(
|
||||
(
|
||||
fmt
|
||||
% (
|
||||
"author",
|
||||
aname,
|
||||
a.email,
|
||||
self.authored_date,
|
||||
altz_to_utctz_str(self.author_tz_offset),
|
||||
)
|
||||
).encode(self.encoding)
|
||||
)
|
||||
|
||||
# encode committer
|
||||
aname = c.name
|
||||
write(
|
||||
(
|
||||
fmt
|
||||
% (
|
||||
"committer",
|
||||
aname,
|
||||
c.email,
|
||||
self.committed_date,
|
||||
altz_to_utctz_str(self.committer_tz_offset),
|
||||
)
|
||||
).encode(self.encoding)
|
||||
)
|
||||
|
||||
if self.encoding != self.default_encoding:
|
||||
write(("encoding %s\n" % self.encoding).encode("ascii"))
|
||||
|
||||
try:
|
||||
if self.__getattribute__("gpgsig"):
|
||||
write(b"gpgsig")
|
||||
for sigline in self.gpgsig.rstrip("\n").split("\n"):
|
||||
write((" " + sigline + "\n").encode("ascii"))
|
||||
except AttributeError:
|
||||
pass
|
||||
|
||||
write(b"\n")
|
||||
|
||||
# write plain bytes, be sure it's encoded according to our encoding
|
||||
if isinstance(self.message, str):
|
||||
write(self.message.encode(self.encoding))
|
||||
else:
|
||||
write(self.message)
|
||||
# END handle encoding
|
||||
return self
|
||||
|
||||
def _deserialize(self, stream: BytesIO) -> "Commit":
|
||||
"""
|
||||
:param stream: a plain data stream containing the serialized commit, as read from our object database
|
||||
"""
|
||||
readline = stream.readline
|
||||
self.tree = Tree(self.repo, hex_to_bin(readline().split()[1]), Tree.tree_id << 12, "")
|
||||
|
||||
self.parents = []
|
||||
next_line = None
|
||||
while True:
|
||||
parent_line = readline()
|
||||
if not parent_line.startswith(b"parent"):
|
||||
next_line = parent_line
|
||||
break
|
||||
# END abort reading parents
|
||||
self.parents.append(type(self)(self.repo, hex_to_bin(parent_line.split()[-1].decode("ascii"))))
|
||||
# END for each parent line
|
||||
self.parents = tuple(self.parents)
|
||||
|
||||
# we don't know the actual author encoding before we have parsed it, so keep the lines around
|
||||
author_line = next_line
|
||||
committer_line = readline()
|
||||
|
||||
# we might run into one or more mergetag blocks, skip those for now
|
||||
next_line = readline()
|
||||
while next_line.startswith(b"mergetag "):
|
||||
next_line = readline()
|
||||
while next_line.startswith(b" "):
|
||||
next_line = readline()
|
||||
# end skip mergetags
|
||||
|
||||
# now we can have the encoding line, or an empty line followed by the optional
|
||||
# message.
|
||||
self.encoding = self.default_encoding
|
||||
self.gpgsig = ""
|
||||
|
||||
# read headers
|
||||
enc = next_line
|
||||
buf = enc.strip()
|
||||
while buf:
|
||||
if buf[0:9] == b"encoding ":  # b"encoding " is 9 bytes long
|
||||
self.encoding = buf[buf.find(b" ") + 1 :].decode(self.encoding, "ignore")
|
||||
elif buf[0:7] == b"gpgsig ":
|
||||
sig = buf[buf.find(b" ") + 1 :] + b"\n"
|
||||
is_next_header = False
|
||||
while True:
|
||||
sigbuf = readline()
|
||||
if not sigbuf:
|
||||
break
|
||||
if sigbuf[0:1] != b" ":
|
||||
buf = sigbuf.strip()
|
||||
is_next_header = True
|
||||
break
|
||||
sig += sigbuf[1:]
|
||||
# end read all signature
|
||||
self.gpgsig = sig.rstrip(b"\n").decode(self.encoding, "ignore")
|
||||
if is_next_header:
|
||||
continue
|
||||
buf = readline().strip()
|
||||
# decode the authors name
|
||||
|
||||
try:
|
||||
(
|
||||
self.author,
|
||||
self.authored_date,
|
||||
self.author_tz_offset,
|
||||
) = parse_actor_and_date(author_line.decode(self.encoding, "replace"))
|
||||
except UnicodeDecodeError:
|
||||
log.error(
|
||||
"Failed to decode author line '%s' using encoding %s",
|
||||
author_line,
|
||||
self.encoding,
|
||||
exc_info=True,
|
||||
)
|
||||
|
||||
try:
|
||||
(
|
||||
self.committer,
|
||||
self.committed_date,
|
||||
self.committer_tz_offset,
|
||||
) = parse_actor_and_date(committer_line.decode(self.encoding, "replace"))
|
||||
except UnicodeDecodeError:
|
||||
log.error(
|
||||
"Failed to decode committer line '%s' using encoding %s",
|
||||
committer_line,
|
||||
self.encoding,
|
||||
exc_info=True,
|
||||
)
|
||||
# END handle author's encoding
|
||||
|
||||
# a stream from our data simply gives us the plain message
|
||||
# The end of our message stream is marked with a newline that we strip
|
||||
self.message = stream.read()
|
||||
try:
|
||||
self.message = self.message.decode(self.encoding, "replace")
|
||||
except UnicodeDecodeError:
|
||||
log.error(
|
||||
"Failed to decode message '%s' using encoding %s",
|
||||
self.message,
|
||||
self.encoding,
|
||||
exc_info=True,
|
||||
)
|
||||
# END exception handling
|
||||
|
||||
return self
|
||||
|
||||
# } END serializable implementation
|
||||
|
||||
@property
|
||||
def co_authors(self) -> List[Actor]:
|
||||
"""
|
||||
Search the commit message for any co-authors of this commit.
|
||||
Details on co-authors: https://github.blog/2018-01-29-commit-together-with-co-authors/
|
||||
|
||||
:return: List of co-authors for this commit (as Actor objects).
|
||||
"""
|
||||
co_authors = []
|
||||
|
||||
if self.message:
|
||||
results = re.findall(
|
||||
r"^Co-authored-by: (.*) <(.*?)>$",
|
||||
self.message,
|
||||
re.MULTILINE,
|
||||
)
|
||||
for author in results:
|
||||
co_authors.append(Actor(*author))
|
||||
|
||||
return co_authors
|
||||
@@ -0,0 +1,254 @@
|
||||
"""Module with functions which are supposed to be as fast as possible"""
|
||||
from stat import S_ISDIR
|
||||
|
||||
|
||||
from git.compat import safe_decode, defenc
|
||||
|
||||
# typing ----------------------------------------------
|
||||
|
||||
from typing import (
|
||||
Callable,
|
||||
List,
|
||||
MutableSequence,
|
||||
Sequence,
|
||||
Tuple,
|
||||
TYPE_CHECKING,
|
||||
Union,
|
||||
overload,
|
||||
)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from _typeshed import ReadableBuffer
|
||||
from git import GitCmdObjectDB
|
||||
|
||||
EntryTup = Tuple[bytes, int, str] # same as TreeCacheTup in tree.py
|
||||
EntryTupOrNone = Union[EntryTup, None]
|
||||
|
||||
# ---------------------------------------------------
|
||||
|
||||
|
||||
__all__ = (
|
||||
"tree_to_stream",
|
||||
"tree_entries_from_data",
|
||||
"traverse_trees_recursive",
|
||||
"traverse_tree_recursive",
|
||||
)
|
||||
|
||||
|
||||
def tree_to_stream(entries: Sequence[EntryTup], write: Callable[["ReadableBuffer"], Union[int, None]]) -> None:
|
||||
"""Write the give list of entries into a stream using its write method
|
||||
|
||||
:param entries: **sorted** list of tuples with (binsha, mode, name)
|
||||
:param write: write method which takes a data string"""
|
||||
ord_zero = ord("0")
|
||||
bit_mask = 7 # 3 bits set
|
||||
|
||||
for binsha, mode, name in entries:
|
||||
mode_str = b""
|
||||
for i in range(6):
|
||||
mode_str = bytes([((mode >> (i * 3)) & bit_mask) + ord_zero]) + mode_str
|
||||
# END for each octal digit
|
||||
|
||||
# git slices away the first octal digit if it's zero
|
||||
if mode_str[0] == ord_zero:
|
||||
mode_str = mode_str[1:]
|
||||
# END save a byte
|
||||
|
||||
# here it comes: if the name is actually unicode, the replacement below
|
||||
# will not work as the binsha is not part of the ascii unicode encoding -
|
||||
# hence we must convert to a utf8 string for it to work properly.
|
||||
# According to my tests, this is exactly what git does, that is, it just
|
||||
# takes the input literally, which appears to be utf8 on linux.
|
||||
if isinstance(name, str):
|
||||
name_bytes = name.encode(defenc)
|
||||
else:
|
||||
name_bytes = name # type: ignore[unreachable] # check runtime types - is always str?
|
||||
write(b"".join((mode_str, b" ", name_bytes, b"\0", binsha)))
|
||||
# END for each item
|
||||
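# Illustrative note (not part of the original module) on the mode encoding
# performed above: a regular file mode 0o100644 is written as the six octal
# digits b"100644", while a tree mode 0o040000 loses its leading zero and is
# stored as b"40000", matching what git itself writes.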
|
||||
|
||||
def tree_entries_from_data(data: bytes) -> List[EntryTup]:
|
||||
"""Reads the binary representation of a tree and returns tuples of Tree items
|
||||
|
||||
:param data: data block with tree data (as bytes)
|
||||
:return: list(tuple(binsha, mode, tree_relative_path), ...)"""
|
||||
ord_zero = ord("0")
|
||||
space_ord = ord(" ")
|
||||
len_data = len(data)
|
||||
i = 0
|
||||
out = []
|
||||
while i < len_data:
|
||||
mode = 0
|
||||
|
||||
# read mode
|
||||
# Some git versions truncate the leading 0, some don't
|
||||
# The type will be extracted from the mode later
|
||||
while data[i] != space_ord:
|
||||
# shift the existing mode integer up by one octal digit (3 bits)
|
||||
# and add the actual ordinal value of the character
|
||||
mode = (mode << 3) + (data[i] - ord_zero)
|
||||
i += 1
|
||||
# END while reading mode
|
||||
|
||||
# byte is space now, skip it
|
||||
i += 1
|
||||
|
||||
# parse name, it is NULL separated
|
||||
|
||||
ns = i
|
||||
while data[i] != 0:
|
||||
i += 1
|
||||
# END while not reached NULL
|
||||
|
||||
# default encoding for strings in git is utf8
|
||||
# Only use the respective unicode object if the byte stream was encoded
|
||||
name_bytes = data[ns:i]
|
||||
name = safe_decode(name_bytes)
|
||||
|
||||
# byte is NULL, get next 20
|
||||
i += 1
|
||||
sha = data[i : i + 20]
|
||||
i = i + 20
|
||||
out.append((sha, mode, name))
|
||||
# END for each byte in data stream
|
||||
return out
|
||||
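# Illustrative usage sketch (not part of the original module), reading the raw
# tree data through the repository's object database:
#
#   raw = repo.odb.stream(repo.head.commit.tree.binsha).read()
#   entries = tree_entries_from_data(raw)   # [(binsha, mode, name), ...]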
|
||||
|
||||
def _find_by_name(tree_data: MutableSequence[EntryTupOrNone], name: str, is_dir: bool, start_at: int) -> EntryTupOrNone:
|
||||
"""return data entry matching the given name and tree mode
|
||||
or None.
|
||||
Before the item is returned, the respective data item is set
|
||||
to None in the tree_data list to mark it as done"""
|
||||
|
||||
try:
|
||||
item = tree_data[start_at]
|
||||
if item and item[2] == name and S_ISDIR(item[1]) == is_dir:
|
||||
tree_data[start_at] = None
|
||||
return item
|
||||
except IndexError:
|
||||
pass
|
||||
# END exception handling
|
||||
for index, item in enumerate(tree_data):
|
||||
if item and item[2] == name and S_ISDIR(item[1]) == is_dir:
|
||||
tree_data[index] = None
|
||||
return item
|
||||
# END if item matches
|
||||
# END for each item
|
||||
return None
|
||||
|
||||
|
||||
@overload
|
||||
def _to_full_path(item: None, path_prefix: str) -> None:
|
||||
...
|
||||
|
||||
|
||||
@overload
|
||||
def _to_full_path(item: EntryTup, path_prefix: str) -> EntryTup:
|
||||
...
|
||||
|
||||
|
||||
def _to_full_path(item: EntryTupOrNone, path_prefix: str) -> EntryTupOrNone:
|
||||
"""Rebuild entry with given path prefix"""
|
||||
if not item:
|
||||
return item
|
||||
return (item[0], item[1], path_prefix + item[2])
|
||||
|
||||
|
||||
def traverse_trees_recursive(
|
||||
odb: "GitCmdObjectDB", tree_shas: Sequence[Union[bytes, None]], path_prefix: str
|
||||
) -> List[Tuple[EntryTupOrNone, ...]]:
|
||||
"""
|
||||
:return: list of list with entries according to the given binary tree-shas.
|
||||
The result is encoded in a list
|
||||
of n tuple|None per blob/commit, (n == len(tree_shas)), where
|
||||
* [0] == 20 byte sha
|
||||
* [1] == mode as int
|
||||
* [2] == path relative to working tree root
|
||||
The entry tuple is None if the respective blob/commit did not
|
||||
exist in the given tree.
|
||||
:param tree_shas: iterable of shas pointing to trees. All trees must
|
||||
be on the same level. A tree-sha may be None, in which case a None entry will be returned at its position in each result tuple
|
||||
:param path_prefix: a prefix to be added to the returned paths on this level,
|
||||
set it to '' for the first iteration
|
||||
:note: The ordering of the returned items will be partially lost"""
|
||||
trees_data: List[List[EntryTupOrNone]] = []
|
||||
|
||||
nt = len(tree_shas)
|
||||
for tree_sha in tree_shas:
|
||||
if tree_sha is None:
|
||||
data: List[EntryTupOrNone] = []
|
||||
else:
|
||||
# make a new list for typing purposes, as List is invariant
|
||||
data = list(tree_entries_from_data(odb.stream(tree_sha).read()))
|
||||
# END handle muted trees
|
||||
trees_data.append(data)
|
||||
# END for each sha to get data for
|
||||
|
||||
out: List[Tuple[EntryTupOrNone, ...]] = []
|
||||
|
||||
# find all matching entries and recursively process them together if the match
|
||||
# is a tree. If the match is a non-tree item, put it into the result.
|
||||
# Processed items will be set None
|
||||
for ti, tree_data in enumerate(trees_data):
|
||||
|
||||
for ii, item in enumerate(tree_data):
|
||||
if not item:
|
||||
continue
|
||||
# END skip already done items
|
||||
entries: List[EntryTupOrNone]
|
||||
entries = [None for _ in range(nt)]
|
||||
entries[ti] = item
|
||||
_sha, mode, name = item
|
||||
is_dir = S_ISDIR(mode) # type mode bits
|
||||
|
||||
# find this item in all other tree data items
|
||||
# wrap around, but stop one before our current index, hence
|
||||
# ti+nt, not ti+1+nt
|
||||
for tio in range(ti + 1, ti + nt):
|
||||
tio = tio % nt
|
||||
entries[tio] = _find_by_name(trees_data[tio], name, is_dir, ii)
|
||||
|
||||
# END for each other item data
|
||||
# if we are a directory, enter recursion
|
||||
if is_dir:
|
||||
out.extend(
|
||||
traverse_trees_recursive(
|
||||
odb,
|
||||
[((ei and ei[0]) or None) for ei in entries],
|
||||
path_prefix + name + "/",
|
||||
)
|
||||
)
|
||||
else:
|
||||
out.append(tuple(_to_full_path(e, path_prefix) for e in entries))
|
||||
|
||||
# END handle recursion
|
||||
# finally mark it done
|
||||
tree_data[ii] = None
|
||||
# END for each item
|
||||
|
||||
# we are done with one tree, set all its data empty
|
||||
del tree_data[:]
|
||||
# END for each tree_data chunk
|
||||
return out
|
||||
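# Illustrative usage sketch (not part of the original module); "HEAD~1" is a
# hypothetical revision used to compare two tree versions:
#
#   old = repo.commit("HEAD~1").tree.binsha
#   new = repo.head.commit.tree.binsha
#   pairs = traverse_trees_recursive(repo.odb, [old, new], "")
#   # each element is a tuple with the matching entry (or None) per input tree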
|
||||
|
||||
def traverse_tree_recursive(odb: "GitCmdObjectDB", tree_sha: bytes, path_prefix: str) -> List[EntryTup]:
|
||||
"""
|
||||
:return: list of entries of the tree pointed to by the binary tree_sha. An entry
|
||||
has the following format:
|
||||
* [0] 20 byte sha
|
||||
* [1] mode as int
|
||||
* [2] path relative to the repository
|
||||
:param path_prefix: prefix to prepend to the front of all returned paths"""
|
||||
entries = []
|
||||
data = tree_entries_from_data(odb.stream(tree_sha).read())
|
||||
|
||||
# unpacking/packing is faster than accessing individual items
|
||||
for sha, mode, name in data:
|
||||
if S_ISDIR(mode):
|
||||
entries.extend(traverse_tree_recursive(odb, sha, path_prefix + name + "/"))
|
||||
else:
|
||||
entries.append((sha, mode, path_prefix + name))
|
||||
# END for each item
|
||||
|
||||
return entries
|
||||
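# Illustrative usage sketch (not part of the original module):
#
#   flat = traverse_tree_recursive(repo.odb, repo.head.commit.tree.binsha, "")
#   # -> [(binsha, mode, "path/to/entry"), ...] for every non-tree entry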
@@ -0,0 +1,2 @@
|
||||
# NOTE: Cannot import anything here as the top-level _init_ has to handle
|
||||
# our dependencies
|
||||
File diff suppressed because it is too large
@@ -0,0 +1,426 @@
|
||||
from .base import Submodule, UpdateProgress
|
||||
from .util import find_first_remote_branch
|
||||
from git.exc import InvalidGitRepositoryError
|
||||
import git
|
||||
|
||||
import logging
|
||||
|
||||
# typing -------------------------------------------------------------------
|
||||
|
||||
from typing import TYPE_CHECKING, Union
|
||||
|
||||
from git.types import Commit_ish
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from git.repo import Repo
|
||||
from git.util import IterableList
|
||||
|
||||
# ----------------------------------------------------------------------------
|
||||
|
||||
__all__ = ["RootModule", "RootUpdateProgress"]
|
||||
|
||||
log = logging.getLogger("git.objects.submodule.root")
|
||||
log.addHandler(logging.NullHandler())
|
||||
|
||||
|
||||
class RootUpdateProgress(UpdateProgress):
|
||||
"""Utility class which adds more opcodes to the UpdateProgress"""
|
||||
|
||||
REMOVE, PATHCHANGE, BRANCHCHANGE, URLCHANGE = [
|
||||
1 << x for x in range(UpdateProgress._num_op_codes, UpdateProgress._num_op_codes + 4)
|
||||
]
|
||||
_num_op_codes = UpdateProgress._num_op_codes + 4
|
||||
|
||||
__slots__ = ()
|
||||
|
||||
|
||||
BEGIN = RootUpdateProgress.BEGIN
|
||||
END = RootUpdateProgress.END
|
||||
REMOVE = RootUpdateProgress.REMOVE
|
||||
BRANCHCHANGE = RootUpdateProgress.BRANCHCHANGE
|
||||
URLCHANGE = RootUpdateProgress.URLCHANGE
|
||||
PATHCHANGE = RootUpdateProgress.PATHCHANGE
|
||||
|
||||
|
||||
class RootModule(Submodule):
|
||||
|
||||
"""A (virtual) Root of all submodules in the given repository. It can be used
|
||||
to more easily traverse all submodules of the master repository"""
|
||||
|
||||
__slots__ = ()
|
||||
|
||||
k_root_name = "__ROOT__"
|
||||
|
||||
def __init__(self, repo: "Repo"):
|
||||
# repo, binsha, mode=None, path=None, name = None, parent_commit=None, url=None, ref=None)
|
||||
super(RootModule, self).__init__(
|
||||
repo,
|
||||
binsha=self.NULL_BIN_SHA,
|
||||
mode=self.k_default_mode,
|
||||
path="",
|
||||
name=self.k_root_name,
|
||||
parent_commit=repo.head.commit,
|
||||
url="",
|
||||
branch_path=git.Head.to_full_path(self.k_head_default),
|
||||
)
|
||||
|
||||
def _clear_cache(self) -> None:
|
||||
"""May not do anything"""
|
||||
pass
|
||||
|
||||
# { Interface
|
||||
|
||||
def update(
|
||||
self,
|
||||
previous_commit: Union[Commit_ish, None] = None, # type: ignore[override]
|
||||
recursive: bool = True,
|
||||
force_remove: bool = False,
|
||||
init: bool = True,
|
||||
to_latest_revision: bool = False,
|
||||
progress: Union[None, "RootUpdateProgress"] = None,
|
||||
dry_run: bool = False,
|
||||
force_reset: bool = False,
|
||||
keep_going: bool = False,
|
||||
) -> "RootModule":
|
||||
"""Update the submodules of this repository to the current HEAD commit.
|
||||
This method behaves smartly by determining changes of the path of a submodule's
|
||||
repository, next to changes to the to-be-checked-out commit or the branch to be
|
||||
checked out. This works if the submodule's ID does not change.
|
||||
Additionally it will detect addition and removal of submodules, which will be handled
|
||||
gracefully.
|
||||
|
||||
:param previous_commit: If set to a commit'ish, the commit we should use
|
||||
as the previous commit the HEAD pointed to before it was set to the commit it points to now.
|
||||
If None, it defaults to HEAD@{1}.
|
||||
:param recursive: if True, the children of submodules will be updated as well
|
||||
using the same technique
|
||||
:param force_remove: If submodules have been deleted, they will be forcibly removed.
|
||||
Otherwise the update may fail if a submodule's repository cannot be deleted as
|
||||
changes have been made to it (see Submodule.update() for more information)
|
||||
:param init: If we encounter a new module which would need to be initialized, then do it.
|
||||
:param to_latest_revision: If True, instead of checking out the revision pointed to
|
||||
by this submodule's sha, the checked out tracking branch will be merged with the
|
||||
latest remote branch fetched from the repository's origin.
|
||||
Unless force_reset is specified, a local tracking branch will never be reset into its past, therefore
|
||||
the remote branch must be in the future for this to have an effect.
|
||||
:param force_reset: if True, submodules may checkout or reset their branch even if the repository has
|
||||
pending changes that would be overwritten, or if the local tracking branch is in the future of the
|
||||
remote tracking branch and would be reset into its past.
|
||||
:param progress: RootUpdateProgress instance or None if no progress should be sent
|
||||
:param dry_run: if True, operations will not actually be performed. Progress messages
|
||||
will change accordingly to indicate the WOULD DO state of the operation.
|
||||
:param keep_going: if True, we will ignore but log all errors, and keep going recursively.
|
||||
Unless dry_run is set as well, keep_going could cause subsequent/inherited errors you wouldn't see
|
||||
otherwise.
|
||||
In conjunction with dry_run, it can be useful to anticipate all errors when updating submodules
|
||||
:return: self"""
|
||||
if self.repo.bare:
|
||||
raise InvalidGitRepositoryError("Cannot update submodules in bare repositories")
|
||||
# END handle bare
|
||||
|
||||
if progress is None:
|
||||
progress = RootUpdateProgress()
|
||||
# END assure progress is set
|
||||
|
||||
prefix = ""
|
||||
if dry_run:
|
||||
prefix = "DRY-RUN: "
|
||||
|
||||
repo = self.repo
|
||||
|
||||
try:
|
||||
# SETUP BASE COMMIT
|
||||
###################
|
||||
cur_commit = repo.head.commit
|
||||
if previous_commit is None:
|
||||
try:
|
||||
previous_commit = repo.commit(repo.head.log_entry(-1).oldhexsha)
|
||||
if previous_commit.binsha == previous_commit.NULL_BIN_SHA:
|
||||
raise IndexError
|
||||
# END handle initial commit
|
||||
except IndexError:
|
||||
# in new repositories, there is no previous commit
|
||||
previous_commit = cur_commit
|
||||
# END exception handling
|
||||
else:
|
||||
previous_commit = repo.commit(previous_commit) # obtain commit object
|
||||
# END handle previous commit
|
||||
|
||||
psms: "IterableList[Submodule]" = self.list_items(repo, parent_commit=previous_commit)
|
||||
sms: "IterableList[Submodule]" = self.list_items(repo)
|
||||
spsms = set(psms)
|
||||
ssms = set(sms)
|
||||
|
||||
# HANDLE REMOVALS
|
||||
###################
|
||||
rrsm = spsms - ssms
|
||||
len_rrsm = len(rrsm)
|
||||
|
||||
for i, rsm in enumerate(rrsm):
|
||||
op = REMOVE
|
||||
if i == 0:
|
||||
op |= BEGIN
|
||||
# END handle begin
|
||||
|
||||
# fake it into thinking it's at the current commit to allow deletion
|
||||
# of previous module. Trigger the cache to be updated before that
|
||||
progress.update(
|
||||
op,
|
||||
i,
|
||||
len_rrsm,
|
||||
prefix + "Removing submodule %r at %s" % (rsm.name, rsm.abspath),
|
||||
)
|
||||
rsm._parent_commit = repo.head.commit
|
||||
rsm.remove(
|
||||
configuration=False,
|
||||
module=True,
|
||||
force=force_remove,
|
||||
dry_run=dry_run,
|
||||
)
|
||||
|
||||
if i == len_rrsm - 1:
|
||||
op |= END
|
||||
# END handle end
|
||||
progress.update(op, i, len_rrsm, prefix + "Done removing submodule %r" % rsm.name)
|
||||
# END for each removed submodule
|
||||
|
||||
# HANDLE PATH RENAMES
|
||||
#####################
|
||||
# url changes + branch changes
|
||||
csms = spsms & ssms
|
||||
len_csms = len(csms)
|
||||
for i, csm in enumerate(csms):
|
||||
psm: "Submodule" = psms[csm.name]
|
||||
sm: "Submodule" = sms[csm.name]
|
||||
|
||||
# PATH CHANGES
|
||||
##############
|
||||
if sm.path != psm.path and psm.module_exists():
|
||||
progress.update(
|
||||
BEGIN | PATHCHANGE,
|
||||
i,
|
||||
len_csms,
|
||||
prefix + "Moving repository of submodule %r from %s to %s" % (sm.name, psm.abspath, sm.abspath),
|
||||
)
|
||||
# move the module to the new path
|
||||
if not dry_run:
|
||||
psm.move(sm.path, module=True, configuration=False)
|
||||
# END handle dry_run
|
||||
progress.update(
|
||||
END | PATHCHANGE,
|
||||
i,
|
||||
len_csms,
|
||||
prefix + "Done moving repository of submodule %r" % sm.name,
|
||||
)
|
||||
# END handle path changes
|
||||
|
||||
if sm.module_exists():
|
||||
# HANDLE URL CHANGE
|
||||
###################
|
||||
if sm.url != psm.url:
|
||||
# Add the new remote, remove the old one
|
||||
# This way, if the url just changes, the commits will not
|
||||
# have to be re-retrieved
|
||||
nn = "__new_origin__"
|
||||
smm = sm.module()
|
||||
rmts = smm.remotes
|
||||
|
||||
# don't do anything if we already have the url we search in place
|
||||
if len([r for r in rmts if r.url == sm.url]) == 0:
|
||||
progress.update(
|
||||
BEGIN | URLCHANGE,
|
||||
i,
|
||||
len_csms,
|
||||
prefix + "Changing url of submodule %r from %s to %s" % (sm.name, psm.url, sm.url),
|
||||
)
|
||||
|
||||
if not dry_run:
|
||||
assert nn not in [r.name for r in rmts]
|
||||
smr = smm.create_remote(nn, sm.url)
|
||||
smr.fetch(progress=progress)
|
||||
|
||||
# If we have a tracking branch, it should be available
|
||||
# in the new remote as well.
|
||||
if len([r for r in smr.refs if r.remote_head == sm.branch_name]) == 0:
|
||||
raise ValueError(
|
||||
"Submodule branch named %r was not available in new submodule remote at %r"
|
||||
% (sm.branch_name, sm.url)
|
||||
)
|
||||
# END head is not detached
|
||||
|
||||
# now delete the changed one
|
||||
rmt_for_deletion = None
|
||||
for remote in rmts:
|
||||
if remote.url == psm.url:
|
||||
rmt_for_deletion = remote
|
||||
break
|
||||
# END if urls match
|
||||
# END for each remote
|
||||
|
||||
# if we didn't find a matching remote, but have exactly one,
|
||||
# we can safely use this one
|
||||
if rmt_for_deletion is None:
|
||||
if len(rmts) == 1:
|
||||
rmt_for_deletion = rmts[0]
|
||||
else:
|
||||
# if we have not found any remote with the original url
|
||||
# we may not have a name. This is a special case,
|
||||
# and it's okay to fail here
|
||||
# Alternatively we could just generate a unique name and leave all
|
||||
# existing ones in place
|
||||
raise InvalidGitRepositoryError(
|
||||
"Couldn't find original remote-repo at url %r" % psm.url
|
||||
)
|
||||
# END handle one single remote
|
||||
# END handle check we found a remote
|
||||
|
||||
orig_name = rmt_for_deletion.name
|
||||
smm.delete_remote(rmt_for_deletion)
|
||||
# NOTE: Currently we leave tags from the deleted remotes
|
||||
# as well as separate tracking branches in the possibly totally
|
||||
# changed repository ( someone could have changed the url to
|
||||
# another project ). At some point, one might want to clean
|
||||
# it up, but the danger is high to remove stuff the user
|
||||
# has added explicitly
|
||||
|
||||
# rename the new remote back to what it was
|
||||
smr.rename(orig_name)
|
||||
|
||||
# early on, we verified that our current tracking branch
|
||||
# exists in the remote. Now we have to assure that the
|
||||
# sha we point to is still contained in the new remote
|
||||
# tracking branch.
|
||||
smsha = sm.binsha
|
||||
found = False
|
||||
rref = smr.refs[self.branch_name]
|
||||
for c in rref.commit.traverse():
|
||||
if c.binsha == smsha:
|
||||
found = True
|
||||
break
|
||||
# END traverse all commits in search for sha
|
||||
# END for each commit
|
||||
|
||||
if not found:
|
||||
# adjust our internal binsha to use the one of the remote
|
||||
# this way, it will be checked out in the next step
|
||||
# This will change the submodule relative to us, so
|
||||
# the user will be able to commit the change easily
|
||||
log.warning(
|
||||
"Current sha %s was not contained in the tracking\
|
||||
branch at the new remote, setting it to the remote's tracking branch",
|
||||
sm.hexsha,
|
||||
)
|
||||
sm.binsha = rref.commit.binsha
|
||||
# END reset binsha
|
||||
|
||||
# NOTE: All checkout is performed by the base implementation of update
|
||||
# END handle dry_run
|
||||
progress.update(
|
||||
END | URLCHANGE,
|
||||
i,
|
||||
len_csms,
|
||||
prefix + "Done adjusting url of submodule %r" % (sm.name),
|
||||
)
|
||||
# END skip remote handling if new url already exists in module
|
||||
# END handle url
|
||||
|
||||
# HANDLE PATH CHANGES
|
||||
#####################
|
||||
if sm.branch_path != psm.branch_path:
|
||||
# finally, create a new tracking branch which tracks the
|
||||
# new remote branch
|
||||
progress.update(
|
||||
BEGIN | BRANCHCHANGE,
|
||||
i,
|
||||
len_csms,
|
||||
prefix
|
||||
+ "Changing branch of submodule %r from %s to %s"
|
||||
% (sm.name, psm.branch_path, sm.branch_path),
|
||||
)
|
||||
if not dry_run:
|
||||
smm = sm.module()
|
||||
smmr = smm.remotes
|
||||
# As the branch might not exist yet, we will have to fetch all remotes to be sure.
|
||||
for remote in smmr:
|
||||
remote.fetch(progress=progress)
|
||||
# end for each remote
|
||||
|
||||
try:
|
||||
tbr = git.Head.create(
|
||||
smm,
|
||||
sm.branch_name,
|
||||
logmsg="branch: Created from HEAD",
|
||||
)
|
||||
except OSError:
|
||||
# ... or reuse the existing one
|
||||
tbr = git.Head(smm, sm.branch_path)
|
||||
# END assure tracking branch exists
|
||||
|
||||
tbr.set_tracking_branch(find_first_remote_branch(smmr, sm.branch_name))
|
||||
# NOTE: All head-resetting is done in the base implementation of update
|
||||
# but we will have to checkout the new branch here. As it still points to the currently
|
||||
# checked-out commit, we don't do any harm.
|
||||
# As we don't want to update working-tree or index, changing the ref is all there is to do
|
||||
smm.head.reference = tbr
|
||||
# END handle dry_run
|
||||
|
||||
progress.update(
|
||||
END | BRANCHCHANGE,
|
||||
i,
|
||||
len_csms,
|
||||
prefix + "Done changing branch of submodule %r" % sm.name,
|
||||
)
|
||||
# END handle branch
|
||||
# END handle
|
||||
# END for each common submodule
|
||||
except Exception as err:
|
||||
if not keep_going:
|
||||
raise
|
||||
log.error(str(err))
|
||||
# end handle keep_going
|
||||
|
||||
# FINALLY UPDATE ALL ACTUAL SUBMODULES
|
||||
######################################
|
||||
for sm in sms:
|
||||
# update the submodule using the default method
|
||||
sm.update(
|
||||
recursive=False,
|
||||
init=init,
|
||||
to_latest_revision=to_latest_revision,
|
||||
progress=progress,
|
||||
dry_run=dry_run,
|
||||
force=force_reset,
|
||||
keep_going=keep_going,
|
||||
)
|
||||
|
||||
# update recursively depth first - question is which inconsistent
|
||||
# state will be better in case it fails somewhere. Defective branch
|
||||
# or defective depth. The RootSubmodule type will never process itself,
|
||||
# which was done in the previous expression
|
||||
if recursive:
|
||||
# the module would exist by now if we are not in dry_run mode
|
||||
if sm.module_exists():
|
||||
type(self)(sm.module()).update(
|
||||
recursive=True,
|
||||
force_remove=force_remove,
|
||||
init=init,
|
||||
to_latest_revision=to_latest_revision,
|
||||
progress=progress,
|
||||
dry_run=dry_run,
|
||||
force_reset=force_reset,
|
||||
keep_going=keep_going,
|
||||
)
|
||||
# END handle dry_run
|
||||
# END handle recursive
|
||||
# END for each submodule to update
|
||||
|
||||
return self
|
||||
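# Illustrative usage sketch (not part of the original module); a dry run only
# reports what would be done. Repo.submodule_update() forwards to this method:
#
#   RootModule(repo).update(recursive=True, dry_run=True)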
|
||||
def module(self) -> "Repo":
|
||||
""":return: the actual repository containing the submodules"""
|
||||
return self.repo
|
||||
|
||||
# } END interface
|
||||
|
||||
|
||||
# } END classes
|
||||
@@ -0,0 +1,118 @@
|
||||
import git
|
||||
from git.exc import InvalidGitRepositoryError
|
||||
from git.config import GitConfigParser
|
||||
from io import BytesIO
|
||||
import weakref
|
||||
|
||||
|
||||
# typing -----------------------------------------------------------------------
|
||||
|
||||
from typing import Any, Sequence, TYPE_CHECKING, Union
|
||||
|
||||
from git.types import PathLike
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from .base import Submodule
|
||||
from weakref import ReferenceType
|
||||
from git.repo import Repo
|
||||
from git.refs import Head
|
||||
from git import Remote
|
||||
from git.refs import RemoteReference
|
||||
|
||||
|
||||
__all__ = (
|
||||
"sm_section",
|
||||
"sm_name",
|
||||
"mkhead",
|
||||
"find_first_remote_branch",
|
||||
"SubmoduleConfigParser",
|
||||
)
|
||||
|
||||
# { Utilities
|
||||
|
||||
|
||||
def sm_section(name: str) -> str:
|
||||
""":return: section title used in .gitmodules configuration file"""
|
||||
return f'submodule "{name}"'
|
||||
|
||||
|
||||
def sm_name(section: str) -> str:
|
||||
""":return: name of the submodule as parsed from the section name"""
|
||||
section = section.strip()
|
||||
return section[11:-1]
|
||||
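# Illustrative round trip (not part of the original module); the submodule
# name is hypothetical:
#
#   sm_section("extern/gitdb")            # -> 'submodule "extern/gitdb"'
#   sm_name('submodule "extern/gitdb"')   # -> 'extern/gitdb'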
|
||||
|
||||
def mkhead(repo: "Repo", path: PathLike) -> "Head":
|
||||
""":return: New branch/head instance"""
|
||||
return git.Head(repo, git.Head.to_full_path(path))
|
||||
|
||||
|
||||
def find_first_remote_branch(remotes: Sequence["Remote"], branch_name: str) -> "RemoteReference":
|
||||
"""Find the remote branch matching the name of the given branch or raise InvalidGitRepositoryError"""
|
||||
for remote in remotes:
|
||||
try:
|
||||
return remote.refs[branch_name]
|
||||
except IndexError:
|
||||
continue
|
||||
# END exception handling
|
||||
# END for remote
|
||||
raise InvalidGitRepositoryError("Didn't find remote branch '%r' in any of the given remotes" % branch_name)
|
||||
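# Illustrative usage sketch (not part of the original module); "main" is a
# hypothetical branch name:
#
#   rref = find_first_remote_branch(repo.remotes, "main")
#   # -> RemoteReference, or InvalidGitRepositoryError if no remote has it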
|
||||
|
||||
# } END utilities
|
||||
|
||||
|
||||
# { Classes
|
||||
|
||||
|
||||
class SubmoduleConfigParser(GitConfigParser):
|
||||
|
||||
"""
|
||||
Catches calls to _write, and updates the .gitmodules blob in the index
|
||||
with the new data, if we have written into a stream. Otherwise it will
|
||||
add the local file to the index to make it correspond with the working tree.
|
||||
Additionally, the cache must be cleared
|
||||
|
||||
Please note that no mutating method will work in bare mode
|
||||
"""
|
||||
|
||||
def __init__(self, *args: Any, **kwargs: Any) -> None:
|
||||
self._smref: Union["ReferenceType[Submodule]", None] = None
|
||||
self._index = None
|
||||
self._auto_write = True
|
||||
super(SubmoduleConfigParser, self).__init__(*args, **kwargs)
|
||||
|
||||
# { Interface
|
||||
def set_submodule(self, submodule: "Submodule") -> None:
|
||||
"""Set this instance's submodule. It must be called before
|
||||
the first write operation begins"""
|
||||
self._smref = weakref.ref(submodule)
|
||||
|
||||
def flush_to_index(self) -> None:
|
||||
"""Flush changes in our configuration file to the index"""
|
||||
assert self._smref is not None
|
||||
# should always have a file here
|
||||
assert not isinstance(self._file_or_files, BytesIO)
|
||||
|
||||
sm = self._smref()
|
||||
if sm is not None:
|
||||
index = self._index
|
||||
if index is None:
|
||||
index = sm.repo.index
|
||||
# END handle index
|
||||
index.add([sm.k_modules_file], write=self._auto_write)
|
||||
sm._clear_cache()
|
||||
# END handle weakref
|
||||
|
||||
# } END interface
|
||||
|
||||
# { Overridden Methods
|
||||
def write(self) -> None: # type: ignore[override]
|
||||
rval: None = super(SubmoduleConfigParser, self).write()
|
||||
self.flush_to_index()
|
||||
return rval
|
||||
|
||||
# END overridden methods
|
||||
|
||||
|
||||
# } END classes
|
||||
@@ -0,0 +1,107 @@
|
||||
# objects.py
|
||||
# Copyright (C) 2008, 2009 Michael Trier (mtrier@gmail.com) and contributors
|
||||
#
|
||||
# This module is part of GitPython and is released under
|
||||
# the BSD License: http://www.opensource.org/licenses/bsd-license.php
|
||||
""" Module containing all object based types. """
|
||||
from . import base
|
||||
from .util import get_object_type_by_name, parse_actor_and_date
|
||||
from ..util import hex_to_bin
|
||||
from ..compat import defenc
|
||||
|
||||
from typing import List, TYPE_CHECKING, Union
|
||||
|
||||
from git.types import Literal
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from git.repo import Repo
|
||||
from git.util import Actor
|
||||
from .commit import Commit
|
||||
from .blob import Blob
|
||||
from .tree import Tree
|
||||
|
||||
__all__ = ("TagObject",)
|
||||
|
||||
|
||||
class TagObject(base.Object):
|
||||
|
||||
"""Non-Lightweight tag carrying additional information about an object we are pointing to."""
|
||||
|
||||
type: Literal["tag"] = "tag"
|
||||
__slots__ = (
|
||||
"object",
|
||||
"tag",
|
||||
"tagger",
|
||||
"tagged_date",
|
||||
"tagger_tz_offset",
|
||||
"message",
|
||||
)
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
repo: "Repo",
|
||||
binsha: bytes,
|
||||
object: Union[None, base.Object] = None,
|
||||
tag: Union[None, str] = None,
|
||||
tagger: Union[None, "Actor"] = None,
|
||||
tagged_date: Union[int, None] = None,
|
||||
tagger_tz_offset: Union[int, None] = None,
|
||||
message: Union[str, None] = None,
|
||||
) -> None: # @ReservedAssignment
|
||||
"""Initialize a tag object with additional data
|
||||
|
||||
:param repo: repository this object is located in
|
||||
:param binsha: 20 byte SHA1
|
||||
:param object: Object instance of object we are pointing to
|
||||
:param tag: name of this tag
|
||||
:param tagger: Actor identifying the tagger
|
||||
:param tagged_date: int_seconds_since_epoch
|
||||
is the DateTime of the tag creation - use time.gmtime to convert
|
||||
it into a different format
|
||||
:param tagger_tz_offset: int_seconds_west_of_utc is the timezone that the
|
||||
tagged_date is in, in a format similar to time.altzone"""
|
||||
super(TagObject, self).__init__(repo, binsha)
|
||||
if object is not None:
|
||||
self.object: Union["Commit", "Blob", "Tree", "TagObject"] = object
|
||||
if tag is not None:
|
||||
self.tag = tag
|
||||
if tagger is not None:
|
||||
self.tagger = tagger
|
||||
if tagged_date is not None:
|
||||
self.tagged_date = tagged_date
|
||||
if tagger_tz_offset is not None:
|
||||
self.tagger_tz_offset = tagger_tz_offset
|
||||
if message is not None:
|
||||
self.message = message
|
||||
|
||||
def _set_cache_(self, attr: str) -> None:
|
||||
"""Cache all our attributes at once"""
|
||||
if attr in TagObject.__slots__:
|
||||
ostream = self.repo.odb.stream(self.binsha)
|
||||
lines: List[str] = ostream.read().decode(defenc, "replace").splitlines()
|
||||
|
||||
_obj, hexsha = lines[0].split(" ")
|
||||
_type_token, type_name = lines[1].split(" ")
|
||||
object_type = get_object_type_by_name(type_name.encode("ascii"))
|
||||
self.object = object_type(self.repo, hex_to_bin(hexsha))
|
||||
|
||||
self.tag = lines[2][4:] # tag <tag name>
|
||||
|
||||
if len(lines) > 3:
|
||||
tagger_info = lines[3] # tagger <actor> <date>
|
||||
(
|
||||
self.tagger,
|
||||
self.tagged_date,
|
||||
self.tagger_tz_offset,
|
||||
) = parse_actor_and_date(tagger_info)
|
||||
|
||||
# line 4 is empty - it could mark the beginning of the next header; in
# case there really is no message it would not exist. Otherwise a
# newline separates the headers from the message.
|
||||
if len(lines) > 5:
|
||||
self.message = "\n".join(lines[5:])
|
||||
else:
|
||||
self.message = ""
|
||||
# END check our attributes
|
||||
else:
|
||||
super(TagObject, self)._set_cache_(attr)
|
||||
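# Illustrative usage sketch (not part of the original module); "v1.0" is a
# hypothetical tag name:
#
#   tagref = repo.tags["v1.0"]
#   tagobj = tagref.tag            # TagObject for annotated tags, else None
#   if tagobj is not None:
#       print(tagobj.tag, tagobj.tagger, tagobj.message)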
@@ -0,0 +1,424 @@
|
||||
# tree.py
|
||||
# Copyright (C) 2008, 2009 Michael Trier (mtrier@gmail.com) and contributors
|
||||
#
|
||||
# This module is part of GitPython and is released under
|
||||
# the BSD License: http://www.opensource.org/licenses/bsd-license.php
|
||||
|
||||
from git.util import IterableList, join_path
|
||||
import git.diff as git_diff
|
||||
from git.util import to_bin_sha
|
||||
|
||||
from . import util
|
||||
from .base import IndexObject, IndexObjUnion
|
||||
from .blob import Blob
|
||||
from .submodule.base import Submodule
|
||||
|
||||
from .fun import tree_entries_from_data, tree_to_stream
|
||||
|
||||
|
||||
# typing -------------------------------------------------
|
||||
|
||||
from typing import (
|
||||
Any,
|
||||
Callable,
|
||||
Dict,
|
||||
Iterable,
|
||||
Iterator,
|
||||
List,
|
||||
Tuple,
|
||||
Type,
|
||||
Union,
|
||||
cast,
|
||||
TYPE_CHECKING,
|
||||
)
|
||||
|
||||
from git.types import PathLike, Literal
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from git.repo import Repo
|
||||
from io import BytesIO
|
||||
|
||||
TreeCacheTup = Tuple[bytes, int, str]
|
||||
|
||||
TraversedTreeTup = Union[Tuple[Union["Tree", None], IndexObjUnion, Tuple["Submodule", "Submodule"]]]
|
||||
|
||||
|
||||
# def is_tree_cache(inp: Tuple[bytes, int, str]) -> TypeGuard[TreeCacheTup]:
|
||||
# return isinstance(inp[0], bytes) and isinstance(inp[1], int) and isinstance([inp], str)
|
||||
|
||||
# --------------------------------------------------------
|
||||
|
||||
|
||||
cmp: Callable[[str, str], int] = lambda a, b: (a > b) - (a < b)
|
||||
|
||||
__all__ = ("TreeModifier", "Tree")
|
||||
|
||||
|
||||
def git_cmp(t1: TreeCacheTup, t2: TreeCacheTup) -> int:
|
||||
a, b = t1[2], t2[2]
|
||||
# assert isinstance(a, str) and isinstance(b, str)
|
||||
len_a, len_b = len(a), len(b)
|
||||
min_len = min(len_a, len_b)
|
||||
min_cmp = cmp(a[:min_len], b[:min_len])
|
||||
|
||||
if min_cmp:
|
||||
return min_cmp
|
||||
|
||||
return len_a - len_b
|
||||
|
||||
|
||||
def merge_sort(a: List[TreeCacheTup], cmp: Callable[[TreeCacheTup, TreeCacheTup], int]) -> None:
|
||||
if len(a) < 2:
|
||||
return None
|
||||
|
||||
mid = len(a) // 2
|
||||
lefthalf = a[:mid]
|
||||
righthalf = a[mid:]
|
||||
|
||||
merge_sort(lefthalf, cmp)
|
||||
merge_sort(righthalf, cmp)
|
||||
|
||||
i = 0
|
||||
j = 0
|
||||
k = 0
|
||||
|
||||
while i < len(lefthalf) and j < len(righthalf):
|
||||
if cmp(lefthalf[i], righthalf[j]) <= 0:
|
||||
a[k] = lefthalf[i]
|
||||
i = i + 1
|
||||
else:
|
||||
a[k] = righthalf[j]
|
||||
j = j + 1
|
||||
k = k + 1
|
||||
|
||||
while i < len(lefthalf):
|
||||
a[k] = lefthalf[i]
|
||||
i = i + 1
|
||||
k = k + 1
|
||||
|
||||
while j < len(righthalf):
|
||||
a[k] = righthalf[j]
|
||||
j = j + 1
|
||||
k = k + 1
|
||||
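# Illustrative sketch (not part of the original module) of the ordering that
# git_cmp and merge_sort establish; the entries below are hypothetical:
#
#   entries = [(b"\x00" * 20, 0o100644, "b"), (b"\x00" * 20, 0o100644, "a")]
#   merge_sort(entries, git_cmp)
#   [e[2] for e in entries]   # -> ['a', 'b']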
|
||||
|
||||
class TreeModifier(object):
|
||||
|
||||
"""A utility class providing methods to alter the underlying cache in a list-like fashion.
|
||||
|
||||
Once all adjustments are complete, the _cache, which really is a reference to
|
||||
the cache of a tree, will be sorted, assuring it will be in a serializable state"""
|
||||
|
||||
__slots__ = "_cache"
|
||||
|
||||
def __init__(self, cache: List[TreeCacheTup]) -> None:
|
||||
self._cache = cache
|
||||
|
||||
def _index_by_name(self, name: str) -> int:
|
||||
""":return: index of an item with name, or -1 if not found"""
|
||||
for i, t in enumerate(self._cache):
|
||||
if t[2] == name:
|
||||
return i
|
||||
# END found item
|
||||
# END for each item in cache
|
||||
return -1
|
||||
|
||||
# { Interface
|
||||
def set_done(self) -> "TreeModifier":
|
||||
"""Call this method once you are done modifying the tree information.
|
||||
It may be called several times, but be aware that each call will cause
|
||||
a sort operation
|
||||
|
||||
:return: self"""
|
||||
merge_sort(self._cache, git_cmp)
|
||||
return self
|
||||
|
||||
# } END interface
|
||||
|
||||
# { Mutators
|
||||
def add(self, sha: bytes, mode: int, name: str, force: bool = False) -> "TreeModifier":
|
||||
"""Add the given item to the tree. If an item with the given name already
|
||||
exists, nothing will be done, but a ValueError will be raised if the
|
||||
sha and mode of the existing item do not match the one you add, unless
|
||||
force is True
|
||||
|
||||
:param sha: The 20 or 40 byte sha of the item to add
|
||||
:param mode: int representing the stat compatible mode of the item
|
||||
:param force: If True, an item with your name and information will overwrite
|
||||
any existing item with the same name, no matter which information it has
|
||||
:return: self"""
|
||||
if "/" in name:
|
||||
raise ValueError("Name must not contain '/' characters")
|
||||
if (mode >> 12) not in Tree._map_id_to_type:
|
||||
raise ValueError("Invalid object type according to mode %o" % mode)
|
||||
|
||||
sha = to_bin_sha(sha)
|
||||
index = self._index_by_name(name)
|
||||
|
||||
item = (sha, mode, name)
|
||||
# assert is_tree_cache(item)
|
||||
|
||||
if index == -1:
|
||||
self._cache.append(item)
|
||||
else:
|
||||
if force:
|
||||
self._cache[index] = item
|
||||
else:
|
||||
ex_item = self._cache[index]
|
||||
if ex_item[0] != sha or ex_item[1] != mode:
|
||||
raise ValueError("Item %r existed with different properties" % name)
|
||||
# END handle mismatch
|
||||
# END handle force
|
||||
# END handle name exists
|
||||
return self
|
||||
|
||||
def add_unchecked(self, binsha: bytes, mode: int, name: str) -> None:
|
||||
"""Add the given item to the tree, its correctness is assumed, which
|
||||
puts the caller in charge of assuring the input is correct.
|
||||
For more information on the parameters, see ``add``
|
||||
|
||||
:param binsha: 20 byte binary sha"""
|
||||
assert isinstance(binsha, bytes) and isinstance(mode, int) and isinstance(name, str)
|
||||
tree_cache = (binsha, mode, name)
|
||||
|
||||
self._cache.append(tree_cache)
|
||||
|
||||
def __delitem__(self, name: str) -> None:
|
||||
"""Deletes an item with the given name if it exists"""
|
||||
index = self._index_by_name(name)
|
||||
if index > -1:
|
||||
del self._cache[index]
|
||||
|
||||
# } END mutators
|
||||
|
||||
|
||||
class Tree(IndexObject, git_diff.Diffable, util.Traversable, util.Serializable):
|
||||
|
||||
"""Tree objects represent an ordered list of Blobs and other Trees.
|
||||
|
||||
``Tree as a list``::
|
||||
|
||||
Access a specific blob using the
|
||||
tree['filename'] notation.
|
||||
|
||||
You may as well access by index
|
||||
blob = tree[0]
|
||||
"""
|
||||
|
||||
type: Literal["tree"] = "tree"
|
||||
__slots__ = "_cache"
|
||||
|
||||
# actual integer ids for comparison
|
||||
commit_id = 0o16 # equals stat.S_IFDIR | stat.S_IFLNK - a directory link
|
||||
blob_id = 0o10
|
||||
symlink_id = 0o12
|
||||
tree_id = 0o04
|
||||
|
||||
_map_id_to_type: Dict[int, Type[IndexObjUnion]] = {
|
||||
commit_id: Submodule,
|
||||
blob_id: Blob,
|
||||
symlink_id: Blob
|
||||
# tree id added once Tree is defined
|
||||
}
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
repo: "Repo",
|
||||
binsha: bytes,
|
||||
mode: int = tree_id << 12,
|
||||
path: Union[PathLike, None] = None,
|
||||
):
|
||||
super(Tree, self).__init__(repo, binsha, mode, path)
|
||||
|
||||
@classmethod
|
||||
def _get_intermediate_items(
|
||||
cls,
|
||||
index_object: IndexObjUnion,
|
||||
) -> Union[Tuple["Tree", ...], Tuple[()]]:
|
||||
if index_object.type == "tree":
|
||||
return tuple(index_object._iter_convert_to_object(index_object._cache))
|
||||
return ()
|
||||
|
||||
def _set_cache_(self, attr: str) -> None:
|
||||
if attr == "_cache":
|
||||
# Set the data when we need it
|
||||
ostream = self.repo.odb.stream(self.binsha)
|
||||
self._cache: List[TreeCacheTup] = tree_entries_from_data(ostream.read())
|
||||
else:
|
||||
super(Tree, self)._set_cache_(attr)
|
||||
# END handle attribute
|
||||
|
||||
def _iter_convert_to_object(self, iterable: Iterable[TreeCacheTup]) -> Iterator[IndexObjUnion]:
|
||||
"""Iterable yields tuples of (binsha, mode, name), which will be converted
|
||||
to the respective object representation"""
|
||||
for binsha, mode, name in iterable:
|
||||
path = join_path(self.path, name)
|
||||
try:
|
||||
yield self._map_id_to_type[mode >> 12](self.repo, binsha, mode, path)
|
||||
except KeyError as e:
|
||||
raise TypeError("Unknown mode %o found in tree data for path '%s'" % (mode, path)) from e
|
||||
# END for each item
|
||||
|
||||
def join(self, file: str) -> IndexObjUnion:
|
||||
"""Find the named object in this tree's contents
|
||||
|
||||
:return: ``git.Blob`` or ``git.Tree`` or ``git.Submodule``
|
||||
:raise KeyError: if given file or tree does not exist in tree"""
|
||||
msg = "Blob or Tree named %r not found"
|
||||
if "/" in file:
|
||||
tree = self
|
||||
item = self
|
||||
tokens = file.split("/")
|
||||
for i, token in enumerate(tokens):
|
||||
item = tree[token]
|
||||
if item.type == "tree":
|
||||
tree = item
|
||||
else:
|
||||
# safety assertion - blobs are at the end of the path
|
||||
if i != len(tokens) - 1:
|
||||
raise KeyError(msg % file)
|
||||
return item
|
||||
# END handle item type
|
||||
# END for each token of split path
|
||||
if item == self:
|
||||
raise KeyError(msg % file)
|
||||
return item
|
||||
else:
|
||||
for info in self._cache:
|
||||
if info[2] == file: # [2] == name
|
||||
return self._map_id_to_type[info[1] >> 12](
|
||||
self.repo, info[0], info[1], join_path(self.path, info[2])
|
||||
)
|
||||
# END for each obj
|
||||
raise KeyError(msg % file)
|
||||
# END handle long paths
|
||||
|
||||
def __truediv__(self, file: str) -> IndexObjUnion:
|
||||
"""For PY3 only"""
|
||||
return self.join(file)
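# Editor's note: a minimal usage sketch of the path-style lookup above; it is not
# part of the upstream GitPython source. The repository path and file names below
# are placeholders and will raise KeyError if they do not exist.
def _example_path_lookup(repo: "Repo") -> None:
    tree = repo.head.commit.tree
    readme = tree / "README.md"         # equivalent to tree.join("README.md")
    nested = tree / "docs/index.rst"    # slashes walk into subtrees
    print(readme.hexsha, nested.type)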
|
||||
|
||||
@property
|
||||
def trees(self) -> List["Tree"]:
|
||||
""":return: list(Tree, ...) list of trees directly below this tree"""
|
||||
return [i for i in self if i.type == "tree"]
|
||||
|
||||
@property
|
||||
def blobs(self) -> List[Blob]:
|
||||
""":return: list(Blob, ...) list of blobs directly below this tree"""
|
||||
return [i for i in self if i.type == "blob"]
|
||||
|
||||
@property
|
||||
def cache(self) -> TreeModifier:
|
||||
"""
|
||||
:return: An object that allows modifying the internal cache. This can be used
to change the tree's contents. When done, make sure you call ``set_done``
on the tree modifier, or serialization behaviour will be incorrect.
See ``TreeModifier`` for more information on how to alter the cache
|
||||
return TreeModifier(self._cache)
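# Editor's note: a hedged sketch of the TreeModifier workflow described above, not
# part of the upstream source. ``blob`` stands for any existing Blob whose binsha
# and mode should be registered under a new (placeholder) name; see the ``add``
# docstring for the exact accepted sha forms.
def _example_modify_cache(tree: "Tree", blob: Blob) -> None:
    mod = tree.cache
    mod.add(blob.binsha, blob.mode, "renamed_copy.txt")  # validated insertion
    mod.set_done()  # re-sorts the cache so later serialization stays correct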
|
||||
|
||||
def traverse(
|
||||
self, # type: ignore[override]
|
||||
predicate: Callable[[Union[IndexObjUnion, TraversedTreeTup], int], bool] = lambda i, d: True,
|
||||
prune: Callable[[Union[IndexObjUnion, TraversedTreeTup], int], bool] = lambda i, d: False,
|
||||
depth: int = -1,
|
||||
branch_first: bool = True,
|
||||
visit_once: bool = False,
|
||||
ignore_self: int = 1,
|
||||
as_edge: bool = False,
|
||||
) -> Union[Iterator[IndexObjUnion], Iterator[TraversedTreeTup]]:
|
||||
"""For documentation, see util.Traversable._traverse()
|
||||
Trees are set to visit_once = False to gain more performance in the traversal"""
|
||||
|
||||
# """
|
||||
# # To typecheck instead of using cast.
|
||||
# import itertools
|
||||
# def is_tree_traversed(inp: Tuple) -> TypeGuard[Tuple[Iterator[Union['Tree', 'Blob', 'Submodule']]]]:
|
||||
# return all(isinstance(x, (Blob, Tree, Submodule)) for x in inp[1])
|
||||
|
||||
# ret = super(Tree, self).traverse(predicate, prune, depth, branch_first, visit_once, ignore_self)
|
||||
# ret_tup = itertools.tee(ret, 2)
|
||||
# assert is_tree_traversed(ret_tup), f"Type is {[type(x) for x in list(ret_tup[0])]}"
|
||||
# return ret_tup[0]"""
|
||||
return cast(
|
||||
Union[Iterator[IndexObjUnion], Iterator[TraversedTreeTup]],
|
||||
super(Tree, self)._traverse(
|
||||
predicate,
|
||||
prune,
|
||||
depth, # type: ignore
|
||||
branch_first,
|
||||
visit_once,
|
||||
ignore_self,
|
||||
),
|
||||
)
|
||||
|
||||
def list_traverse(self, *args: Any, **kwargs: Any) -> IterableList[IndexObjUnion]:
|
||||
"""
|
||||
:return: IterableList with the results of the traversal as produced by
|
||||
traverse()
|
||||
Tree -> IterableList[Union['Submodule', 'Tree', 'Blob']]
|
||||
"""
|
||||
return super(Tree, self)._list_traverse(*args, **kwargs)
|
||||
|
||||
# List protocol
|
||||
|
||||
def __getslice__(self, i: int, j: int) -> List[IndexObjUnion]:
|
||||
return list(self._iter_convert_to_object(self._cache[i:j]))
|
||||
|
||||
def __iter__(self) -> Iterator[IndexObjUnion]:
|
||||
return self._iter_convert_to_object(self._cache)
|
||||
|
||||
def __len__(self) -> int:
|
||||
return len(self._cache)
|
||||
|
||||
def __getitem__(self, item: Union[str, int, slice]) -> IndexObjUnion:
|
||||
if isinstance(item, int):
|
||||
info = self._cache[item]
|
||||
return self._map_id_to_type[info[1] >> 12](self.repo, info[0], info[1], join_path(self.path, info[2]))
|
||||
|
||||
if isinstance(item, str):
|
||||
# compatibility
|
||||
return self.join(item)
|
||||
# END index is basestring
|
||||
|
||||
raise TypeError("Invalid index type: %r" % item)
|
||||
|
||||
def __contains__(self, item: Union[IndexObjUnion, PathLike]) -> bool:
|
||||
if isinstance(item, IndexObject):
|
||||
for info in self._cache:
|
||||
if item.binsha == info[0]:
|
||||
return True
|
||||
# END compare sha
|
||||
# END for each entry
|
||||
# END handle item is index object
|
||||
# compatibility
|
||||
|
||||
# treat item as repo-relative path
|
||||
else:
|
||||
path = self.path
|
||||
for info in self._cache:
|
||||
if item == join_path(path, info[2]):
|
||||
return True
|
||||
# END for each item
|
||||
return False
|
||||
|
||||
def __reversed__(self) -> Iterator[IndexObjUnion]:
|
||||
return reversed(self._iter_convert_to_object(self._cache)) # type: ignore
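# Editor's note: a small sketch of the list protocol implemented above (editor's
# illustration, not upstream code). The entry name "setup.py" is a placeholder.
def _example_list_protocol(tree: "Tree") -> None:
    first = tree[0]                # index access converts the cache entry lazily
    print(len(tree), first.name)
    if "setup.py" in tree:         # membership compares repo-relative paths
        print(tree["setup.py"].hexsha)
    for entry in tree:             # iteration yields Blob/Tree/Submodule objects
        print(entry.type, entry.path)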
|
||||
|
||||
def _serialize(self, stream: "BytesIO") -> "Tree":
|
||||
"""Serialize this tree into the stream. Please note that we will assume
|
||||
our tree data to be in a sorted state. If this is not the case, serialization
|
||||
will not generate a correct tree representation as these are assumed to be sorted
|
||||
by algorithms"""
|
||||
tree_to_stream(self._cache, stream.write)
|
||||
return self
|
||||
|
||||
def _deserialize(self, stream: "BytesIO") -> "Tree":
|
||||
self._cache = tree_entries_from_data(stream.read())
|
||||
return self
|
||||
|
||||
|
||||
# END tree
|
||||
|
||||
# finalize map definition
|
||||
Tree._map_id_to_type[Tree.tree_id] = Tree
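# Editor's note: an illustrative recursive traversal over the finished Tree type
# (editor's sketch, not upstream code). It walks a commit's root tree and collects
# the paths of all blobs at most two levels below the root.
def _example_traverse(repo: "Repo") -> list:
    root = repo.head.commit.tree
    return [
        item.path
        for item in root.traverse(predicate=lambda item, depth: item.type == "blob", depth=2)
    ]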
|
||||
#
|
||||
@@ -0,0 +1,637 @@
|
||||
# util.py
|
||||
# Copyright (C) 2008, 2009 Michael Trier (mtrier@gmail.com) and contributors
|
||||
#
|
||||
# This module is part of GitPython and is released under
|
||||
# the BSD License: http://www.opensource.org/licenses/bsd-license.php
|
||||
"""Module for general utility functions"""
|
||||
# flake8: noqa F401
|
||||
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
import warnings
|
||||
from git.util import IterableList, IterableObj, Actor
|
||||
|
||||
import re
|
||||
from collections import deque
|
||||
|
||||
from string import digits
|
||||
import time
|
||||
import calendar
|
||||
from datetime import datetime, timedelta, tzinfo
|
||||
|
||||
# typing ------------------------------------------------------------
|
||||
from typing import (
|
||||
Any,
|
||||
Callable,
|
||||
Deque,
|
||||
Iterator,
|
||||
Generic,
|
||||
NamedTuple,
|
||||
overload,
|
||||
Sequence, # NOQA: F401
|
||||
TYPE_CHECKING,
|
||||
Tuple,
|
||||
Type,
|
||||
TypeVar,
|
||||
Union,
|
||||
cast,
|
||||
)
|
||||
|
||||
from git.types import Has_id_attribute, Literal, _T # NOQA: F401
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from io import BytesIO, StringIO
|
||||
from .commit import Commit
|
||||
from .blob import Blob
|
||||
from .tag import TagObject
|
||||
from .tree import Tree, TraversedTreeTup
|
||||
from subprocess import Popen
|
||||
from .submodule.base import Submodule
|
||||
from git.types import Protocol, runtime_checkable
|
||||
else:
|
||||
# Protocol = Generic[_T] # Needed for typing bug #572?
|
||||
Protocol = ABC
|
||||
|
||||
def runtime_checkable(f):
|
||||
return f
|
||||
|
||||
|
||||
class TraverseNT(NamedTuple):
|
||||
depth: int
|
||||
item: Union["Traversable", "Blob"]
|
||||
src: Union["Traversable", None]
|
||||
|
||||
|
||||
T_TIobj = TypeVar("T_TIobj", bound="TraversableIterableObj") # for TraversableIterableObj.traverse()
|
||||
|
||||
TraversedTup = Union[
|
||||
Tuple[Union["Traversable", None], "Traversable"], # for commit, submodule
|
||||
"TraversedTreeTup",
|
||||
] # for tree.traverse()
|
||||
|
||||
# --------------------------------------------------------------------
|
||||
|
||||
__all__ = (
|
||||
"get_object_type_by_name",
|
||||
"parse_date",
|
||||
"parse_actor_and_date",
|
||||
"ProcessStreamAdapter",
|
||||
"Traversable",
|
||||
"altz_to_utctz_str",
|
||||
"utctz_to_altz",
|
||||
"verify_utctz",
|
||||
"Actor",
|
||||
"tzoffset",
|
||||
"utc",
|
||||
)
|
||||
|
||||
ZERO = timedelta(0)
|
||||
|
||||
# { Functions
|
||||
|
||||
|
||||
def mode_str_to_int(modestr: Union[bytes, str]) -> int:
|
||||
"""
|
||||
:param modestr: string like 755 or 644 or 100644 - only the last 6 chars will be used
|
||||
:return:
|
||||
Integer mode, compatible with the mode ids of the stat module, covering
the rwx permissions for user, group and other, as well as special flags
and file system flags, i.e. whether it is a symlink, for example."""
|
||||
mode = 0
|
||||
for iteration, char in enumerate(reversed(modestr[-6:])):
|
||||
char = cast(Union[str, int], char)
|
||||
mode += int(char) << iteration * 3
|
||||
# END for each char
|
||||
return mode
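# Editor's note: a quick self-contained check of the conversion above (editor's
# illustration, not part of the upstream module). Each octal digit of the string
# becomes three bits of the resulting integer mode.
def _example_mode_str_to_int() -> None:
    assert mode_str_to_int("100644") == 0o100644  # regular file entry
    assert mode_str_to_int("644") == 0o644        # only the permission bits
    assert mode_str_to_int("120000") == 0o120000  # symlink entry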
|
||||
|
||||
|
||||
def get_object_type_by_name(
|
||||
object_type_name: bytes,
|
||||
) -> Union[Type["Commit"], Type["TagObject"], Type["Tree"], Type["Blob"]]:
|
||||
"""
|
||||
:return: type suitable to handle the given object type name.
|
||||
Use the type to create new instances.
|
||||
|
||||
:param object_type_name: Member of TYPES
|
||||
|
||||
:raise ValueError: In case object_type_name is unknown"""
|
||||
if object_type_name == b"commit":
|
||||
from . import commit
|
||||
|
||||
return commit.Commit
|
||||
elif object_type_name == b"tag":
|
||||
from . import tag
|
||||
|
||||
return tag.TagObject
|
||||
elif object_type_name == b"blob":
|
||||
from . import blob
|
||||
|
||||
return blob.Blob
|
||||
elif object_type_name == b"tree":
|
||||
from . import tree
|
||||
|
||||
return tree.Tree
|
||||
else:
|
||||
raise ValueError("Cannot handle unknown object type: %s" % object_type_name.decode())
|
||||
|
||||
|
||||
def utctz_to_altz(utctz: str) -> int:
|
||||
"""Convert a git timezone offset into a timezone offset west of
|
||||
UTC in seconds (compatible with time.altzone).
|
||||
|
||||
:param utctz: git utc timezone string, i.e. +0200
|
||||
"""
|
||||
int_utctz = int(utctz)
|
||||
seconds = ((abs(int_utctz) // 100) * 3600 + (abs(int_utctz) % 100) * 60)
|
||||
return seconds if int_utctz < 0 else -seconds
|
||||
|
||||
|
||||
def altz_to_utctz_str(altz: int) -> str:
|
||||
"""Convert a timezone offset west of UTC in seconds into a git timezone offset string
|
||||
|
||||
:param altz: timezone offset in seconds west of UTC
|
||||
"""
|
||||
hours = abs(altz) // 3600
|
||||
minutes = (abs(altz) % 3600) // 60
|
||||
sign = "-" if altz >= 60 else "+"
|
||||
return "{}{:02}{:02}".format(sign, hours, minutes)
|
||||
|
||||
|
||||
def verify_utctz(offset: str) -> str:
|
||||
""":raise ValueError: if offset is incorrect
|
||||
:return: offset"""
|
||||
fmt_exc = ValueError("Invalid timezone offset format: %s" % offset)
|
||||
if len(offset) != 5:
|
||||
raise fmt_exc
|
||||
if offset[0] not in "+-":
|
||||
raise fmt_exc
|
||||
if offset[1] not in digits or offset[2] not in digits or offset[3] not in digits or offset[4] not in digits:
|
||||
raise fmt_exc
|
||||
# END for each char
|
||||
return offset
|
||||
|
||||
|
||||
class tzoffset(tzinfo):
|
||||
def __init__(self, secs_west_of_utc: float, name: Union[None, str] = None) -> None:
|
||||
self._offset = timedelta(seconds=-secs_west_of_utc)
|
||||
self._name = name or "fixed"
|
||||
|
||||
def __reduce__(self) -> Tuple[Type["tzoffset"], Tuple[float, str]]:
|
||||
return tzoffset, (-self._offset.total_seconds(), self._name)
|
||||
|
||||
def utcoffset(self, dt: Union[datetime, None]) -> timedelta:
|
||||
return self._offset
|
||||
|
||||
def tzname(self, dt: Union[datetime, None]) -> str:
|
||||
return self._name
|
||||
|
||||
def dst(self, dt: Union[datetime, None]) -> timedelta:
|
||||
return ZERO
|
||||
|
||||
|
||||
utc = tzoffset(0, "UTC")
|
||||
|
||||
|
||||
def from_timestamp(timestamp: float, tz_offset: float) -> datetime:
|
||||
"""Converts a timestamp + tz_offset into an aware datetime instance."""
|
||||
utc_dt = datetime.fromtimestamp(timestamp, utc)
|
||||
try:
|
||||
local_dt = utc_dt.astimezone(tzoffset(tz_offset))
|
||||
return local_dt
|
||||
except ValueError:
|
||||
return utc_dt
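# Editor's note: a small illustration of from_timestamp (editor's sketch, not
# upstream code). The timestamp/offset pair matches the parse_actor_and_date
# example further below.
def _example_from_timestamp() -> None:
    dt = from_timestamp(1191999972, 25200)       # 7 hours west of UTC
    assert dt.utcoffset() == timedelta(hours=-7)
    assert int(dt.timestamp()) == 1191999972     # the instant itself is unchanged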
|
||||
|
||||
|
||||
def parse_date(string_date: Union[str, datetime]) -> Tuple[int, int]:
|
||||
"""
|
||||
Parse the given date as one of the following
|
||||
|
||||
* aware datetime instance
|
||||
* Git internal format: timestamp offset
|
||||
* RFC 2822: Thu, 07 Apr 2005 22:13:13 +0200.
|
||||
* ISO 8601 2005-04-07T22:13:13
|
||||
The T can be a space as well
|
||||
|
||||
:return: Tuple(int(timestamp_UTC), int(offset)), both in seconds since epoch
|
||||
:raise ValueError: If the format could not be understood
|
||||
:note: Date can also be YYYY.MM.DD, MM/DD/YYYY and DD.MM.YYYY.
|
||||
"""
|
||||
if isinstance(string_date, datetime):
|
||||
if string_date.tzinfo:
|
||||
utcoffset = cast(timedelta, string_date.utcoffset())  # typeguard, since tzinfo is not None
|
||||
offset = -int(utcoffset.total_seconds())
|
||||
return int(string_date.astimezone(utc).timestamp()), offset
|
||||
else:
|
||||
raise ValueError(f"string_date datetime object without tzinfo, {string_date}")
|
||||
|
||||
# git time
|
||||
try:
|
||||
if string_date.count(" ") == 1 and string_date.rfind(":") == -1:
|
||||
timestamp, offset_str = string_date.split()
|
||||
if timestamp.startswith("@"):
|
||||
timestamp = timestamp[1:]
|
||||
timestamp_int = int(timestamp)
|
||||
return timestamp_int, utctz_to_altz(verify_utctz(offset_str))
|
||||
else:
|
||||
offset_str = "+0000" # local time by default
|
||||
if string_date[-5] in "-+":
|
||||
offset_str = verify_utctz(string_date[-5:])
|
||||
string_date = string_date[:-6] # skip space as well
|
||||
# END split timezone info
|
||||
offset = utctz_to_altz(offset_str)
|
||||
|
||||
# now figure out the date and time portion - split time
|
||||
date_formats = []
|
||||
splitter = -1
|
||||
if "," in string_date:
|
||||
date_formats.append("%a, %d %b %Y")
|
||||
splitter = string_date.rfind(" ")
|
||||
else:
|
||||
# iso plus additional
|
||||
date_formats.append("%Y-%m-%d")
|
||||
date_formats.append("%Y.%m.%d")
|
||||
date_formats.append("%m/%d/%Y")
|
||||
date_formats.append("%d.%m.%Y")
|
||||
|
||||
splitter = string_date.rfind("T")
|
||||
if splitter == -1:
|
||||
splitter = string_date.rfind(" ")
|
||||
# END handle 'T' and ' '
|
||||
# END handle rfc or iso
|
||||
|
||||
assert splitter > -1
|
||||
|
||||
# split date and time
|
||||
time_part = string_date[splitter + 1 :] # skip space
|
||||
date_part = string_date[:splitter]
|
||||
|
||||
# parse time
|
||||
tstruct = time.strptime(time_part, "%H:%M:%S")
|
||||
|
||||
for fmt in date_formats:
|
||||
try:
|
||||
dtstruct = time.strptime(date_part, fmt)
|
||||
utctime = calendar.timegm(
|
||||
(
|
||||
dtstruct.tm_year,
|
||||
dtstruct.tm_mon,
|
||||
dtstruct.tm_mday,
|
||||
tstruct.tm_hour,
|
||||
tstruct.tm_min,
|
||||
tstruct.tm_sec,
|
||||
dtstruct.tm_wday,
|
||||
dtstruct.tm_yday,
|
||||
tstruct.tm_isdst,
|
||||
)
|
||||
)
|
||||
return int(utctime), offset
|
||||
except ValueError:
|
||||
continue
|
||||
# END exception handling
|
||||
# END for each fmt
|
||||
|
||||
# still here ? fail
|
||||
raise ValueError("no format matched")
|
||||
# END handle format
|
||||
except Exception as e:
|
||||
raise ValueError(f"Unsupported date format or type: {string_date}, type={type(string_date)}") from e
|
||||
# END handle exceptions
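# Editor's note: a few concrete calls for the formats listed in the docstring
# (editor's illustration, not upstream code). The returned offset is in seconds
# west of UTC, matching utctz_to_altz above.
def _example_parse_date() -> None:
    assert parse_date("1191999972 -0700") == (1191999972, 25200)  # git internal format
    ts, offset = parse_date("2005-04-07T22:13:13 +0200")          # ISO 8601
    assert offset == -7200                                        # 2 hours east of UTC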
|
||||
|
||||
|
||||
# precompiled regex
|
||||
_re_actor_epoch = re.compile(r"^.+? (.*) (\d+) ([+-]\d+).*$")
|
||||
_re_only_actor = re.compile(r"^.+? (.*)$")
|
||||
|
||||
|
||||
def parse_actor_and_date(line: str) -> Tuple[Actor, int, int]:
|
||||
"""Parse out the actor (author or committer) info from a line like::
|
||||
|
||||
author Tom Preston-Werner <tom@mojombo.com> 1191999972 -0700
|
||||
|
||||
:return: [Actor, int_seconds_since_epoch, int_timezone_offset]"""
|
||||
actor, epoch, offset = "", "0", "0"
|
||||
m = _re_actor_epoch.search(line)
|
||||
if m:
|
||||
actor, epoch, offset = m.groups()
|
||||
else:
|
||||
m = _re_only_actor.search(line)
|
||||
actor = m.group(1) if m else line or ""
|
||||
return (Actor._from_string(actor), int(epoch), utctz_to_altz(offset))
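# Editor's note: illustration of the parser above using the sample line from its
# docstring (editor's sketch, not upstream code).
def _example_parse_actor_and_date() -> None:
    actor, epoch, offset = parse_actor_and_date(
        "author Tom Preston-Werner <tom@mojombo.com> 1191999972 -0700"
    )
    assert (actor.name, actor.email) == ("Tom Preston-Werner", "tom@mojombo.com")
    assert (epoch, offset) == (1191999972, 25200)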
|
||||
|
||||
|
||||
# } END functions
|
||||
|
||||
|
||||
# { Classes
|
||||
|
||||
|
||||
class ProcessStreamAdapter(object):
|
||||
|
||||
"""Class wireing all calls to the contained Process instance.
|
||||
|
||||
Use this type to hide the underlying process to provide access only to a specified
|
||||
stream. The process is usually wrapped into an AutoInterrupt class to kill
|
||||
it if the instance goes out of scope."""
|
||||
|
||||
__slots__ = ("_proc", "_stream")
|
||||
|
||||
def __init__(self, process: "Popen", stream_name: str) -> None:
|
||||
self._proc = process
|
||||
self._stream: StringIO = getattr(process, stream_name) # guessed type
|
||||
|
||||
def __getattr__(self, attr: str) -> Any:
|
||||
return getattr(self._stream, attr)
|
||||
|
||||
|
||||
@runtime_checkable
|
||||
class Traversable(Protocol):
|
||||
|
||||
"""Simple interface to perform depth-first or breadth-first traversals
|
||||
into one direction.
|
||||
Subclasses only need to implement one function.
|
||||
Instances of the Subclass must be hashable
|
||||
|
||||
Defined subclasses = [Commit, Tree, SubModule]
|
||||
"""
|
||||
|
||||
__slots__ = ()
|
||||
|
||||
@classmethod
|
||||
@abstractmethod
|
||||
def _get_intermediate_items(cls, item: Any) -> Sequence["Traversable"]:
|
||||
"""
|
||||
Returns:
|
||||
Tuple of items connected to the given item.
|
||||
Must be implemented in subclass
|
||||
|
||||
class Commit:: (cls, Commit) -> Tuple[Commit, ...]
|
||||
class Submodule:: (cls, Submodule) -> IterableList[Submodule]
|
||||
class Tree:: (cls, Tree) -> Tuple[Tree, ...]
|
||||
"""
|
||||
raise NotImplementedError("To be implemented in subclass")
|
||||
|
||||
@abstractmethod
|
||||
def list_traverse(self, *args: Any, **kwargs: Any) -> Any:
|
||||
""" """
|
||||
warnings.warn(
|
||||
"list_traverse() method should only be called from subclasses."
|
||||
"Calling from Traversable abstract class will raise NotImplementedError in 3.1.20"
|
||||
"Builtin sublclasses are 'Submodule', 'Tree' and 'Commit",
|
||||
DeprecationWarning,
|
||||
stacklevel=2,
|
||||
)
|
||||
return self._list_traverse(*args, **kwargs)
|
||||
|
||||
def _list_traverse(
|
||||
self, as_edge: bool = False, *args: Any, **kwargs: Any
|
||||
) -> IterableList[Union["Commit", "Submodule", "Tree", "Blob"]]:
|
||||
"""
|
||||
:return: IterableList with the results of the traversal as produced by
|
||||
traverse()
|
||||
Commit -> IterableList['Commit']
|
||||
Submodule -> IterableList['Submodule']
|
||||
Tree -> IterableList[Union['Submodule', 'Tree', 'Blob']]
|
||||
"""
|
||||
# Commit and Submodule have id.__attribute__ as IterableObj
|
||||
# Tree has id.__attribute__ inherited from IndexObject
|
||||
if isinstance(self, Has_id_attribute):
|
||||
id = self._id_attribute_
|
||||
else:
|
||||
id = "" # shouldn't reach here, unless Traversable subclass created with no _id_attribute_
|
||||
# could add _id_attribute_ to Traversable, or make all Traversable also Iterable?
|
||||
|
||||
if not as_edge:
|
||||
out: IterableList[Union["Commit", "Submodule", "Tree", "Blob"]] = IterableList(id)
|
||||
out.extend(self.traverse(as_edge=as_edge, *args, **kwargs))
|
||||
return out
|
||||
# overloads in subclasses (mypy doesn't allow typing self: subclass)
|
||||
# Union[IterableList['Commit'], IterableList['Submodule'], IterableList[Union['Submodule', 'Tree', 'Blob']]]
|
||||
else:
|
||||
# Raise deprecationwarning, doesn't make sense to use this
|
||||
out_list: IterableList = IterableList(self.traverse(*args, **kwargs))
|
||||
return out_list
|
||||
|
||||
@abstractmethod
|
||||
def traverse(self, *args: Any, **kwargs: Any) -> Any:
|
||||
""" """
|
||||
warnings.warn(
|
||||
"traverse() method should only be called from subclasses."
|
||||
"Calling from Traversable abstract class will raise NotImplementedError in 3.1.20"
|
||||
"Builtin sublclasses are 'Submodule', 'Tree' and 'Commit",
|
||||
DeprecationWarning,
|
||||
stacklevel=2,
|
||||
)
|
||||
return self._traverse(*args, **kwargs)
|
||||
|
||||
def _traverse(
|
||||
self,
|
||||
predicate: Callable[[Union["Traversable", "Blob", TraversedTup], int], bool] = lambda i, d: True,
|
||||
prune: Callable[[Union["Traversable", "Blob", TraversedTup], int], bool] = lambda i, d: False,
|
||||
depth: int = -1,
|
||||
branch_first: bool = True,
|
||||
visit_once: bool = True,
|
||||
ignore_self: int = 1,
|
||||
as_edge: bool = False,
|
||||
) -> Union[Iterator[Union["Traversable", "Blob"]], Iterator[TraversedTup]]:
|
||||
""":return: iterator yielding of items found when traversing self
|
||||
:param predicate: f(i,d) returns False if item i at depth d should not be included in the result
|
||||
|
||||
:param prune:
|
||||
f(i,d) return True if the search should stop at item i at depth d.
|
||||
Item i will not be returned.
|
||||
|
||||
:param depth:
|
||||
define at which level the iteration should not go deeper
|
||||
if -1, there is no limit
|
||||
if 0, you would effectively only get self, the root of the iteration
|
||||
i.e. if 1, you would only get the first level of predecessors/successors
|
||||
|
||||
:param branch_first:
|
||||
if True, items will be returned branch first, otherwise depth first
|
||||
|
||||
:param visit_once:
|
||||
if True, items will only be returned once, although they might be encountered
|
||||
several times. Loops are prevented that way.
|
||||
|
||||
:param ignore_self:
|
||||
if True, self will be ignored and automatically pruned from
|
||||
the result. Otherwise it will be the first item to be returned.
|
||||
If as_edge is True, the source of the first edge is None
|
||||
|
||||
:param as_edge:
|
||||
if True, return a pair of items, first being the source, second the
|
||||
destination, i.e. tuple(src, dest) with the edge spanning from
|
||||
source to destination"""
|
||||
|
||||
"""
|
||||
Commit -> Iterator[Union[Commit, Tuple[Commit, Commit]]
|
||||
Submodule -> Iterator[Submodule, Tuple[Submodule, Submodule]]
|
||||
Tree -> Iterator[Union[Blob, Tree, Submodule,
|
||||
Tuple[Union[Submodule, Tree], Union[Blob, Tree, Submodule]]]
|
||||
|
||||
ignore_self=True is_edge=True -> Iterator[item]
|
||||
ignore_self=True is_edge=False --> Iterator[item]
|
||||
ignore_self=False is_edge=True -> Iterator[item] | Iterator[Tuple[src, item]]
|
||||
ignore_self=False is_edge=False -> Iterator[Tuple[src, item]]"""
|
||||
|
||||
visited = set()
|
||||
stack: Deque[TraverseNT] = deque()
|
||||
stack.append(TraverseNT(0, self, None)) # self is always depth level 0
|
||||
|
||||
def addToStack(
|
||||
stack: Deque[TraverseNT],
|
||||
src_item: "Traversable",
|
||||
branch_first: bool,
|
||||
depth: int,
|
||||
) -> None:
|
||||
lst = self._get_intermediate_items(src_item)  # expand the passed source item rather than relying on the enclosing loop variable
|
||||
if not lst: # empty list
|
||||
return None
|
||||
if branch_first:
|
||||
stack.extendleft(TraverseNT(depth, i, src_item) for i in lst)
|
||||
else:
|
||||
reviter = (TraverseNT(depth, lst[i], src_item) for i in range(len(lst) - 1, -1, -1))
|
||||
stack.extend(reviter)
|
||||
|
||||
# END addToStack local method
|
||||
|
||||
while stack:
|
||||
d, item, src = stack.pop() # depth of item, item, item_source
|
||||
|
||||
if visit_once and item in visited:
|
||||
continue
|
||||
|
||||
if visit_once:
|
||||
visited.add(item)
|
||||
|
||||
rval: Union[TraversedTup, "Traversable", "Blob"]
|
||||
if as_edge:  # if as_edge, return (src, item); src is None for the first item
|
||||
rval = (src, item)
|
||||
else:
|
||||
rval = item
|
||||
|
||||
if prune(rval, d):
|
||||
continue
|
||||
|
||||
skipStartItem = ignore_self and (item is self)
|
||||
if not skipStartItem and predicate(rval, d):
|
||||
yield rval
|
||||
|
||||
# only continue to next level if this is appropriate !
|
||||
nd = d + 1
|
||||
if depth > -1 and nd > depth:
|
||||
continue
|
||||
|
||||
addToStack(stack, item, branch_first, nd)
|
||||
# END for each item on work stack
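# Editor's note: the generator above is normally consumed through the public
# traverse() of a subclass; a minimal sketch (not upstream code) over a commit's
# ancestry, limited to two levels of parents.
def _example_commit_ancestry(commit: "Commit") -> list:
    return [c.hexsha for c in commit.traverse(depth=2)]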
|
||||
|
||||
|
||||
@runtime_checkable
|
||||
class Serializable(Protocol):
|
||||
|
||||
"""Defines methods to serialize and deserialize objects from and into a data stream"""
|
||||
|
||||
__slots__ = ()
|
||||
|
||||
# @abstractmethod
|
||||
def _serialize(self, stream: "BytesIO") -> "Serializable":
|
||||
"""Serialize the data of this object into the given data stream
|
||||
:note: a serialized object would ``_deserialize`` into the same object
|
||||
:param stream: a file-like object
|
||||
:return: self"""
|
||||
raise NotImplementedError("To be implemented in subclass")
|
||||
|
||||
# @abstractmethod
|
||||
def _deserialize(self, stream: "BytesIO") -> "Serializable":
|
||||
"""Deserialize all information regarding this object from the stream
|
||||
:param stream: a file-like object
|
||||
:return: self"""
|
||||
raise NotImplementedError("To be implemented in subclass")
|
||||
|
||||
|
||||
class TraversableIterableObj(IterableObj, Traversable):
|
||||
__slots__ = ()
|
||||
|
||||
TIobj_tuple = Tuple[Union[T_TIobj, None], T_TIobj]
|
||||
|
||||
def list_traverse(self: T_TIobj, *args: Any, **kwargs: Any) -> IterableList[T_TIobj]:
|
||||
return super(TraversableIterableObj, self)._list_traverse(*args, **kwargs)
|
||||
|
||||
@overload # type: ignore
|
||||
def traverse(self: T_TIobj) -> Iterator[T_TIobj]:
|
||||
...
|
||||
|
||||
@overload
|
||||
def traverse(
|
||||
self: T_TIobj,
|
||||
predicate: Callable[[Union[T_TIobj, Tuple[Union[T_TIobj, None], T_TIobj]], int], bool],
|
||||
prune: Callable[[Union[T_TIobj, Tuple[Union[T_TIobj, None], T_TIobj]], int], bool],
|
||||
depth: int,
|
||||
branch_first: bool,
|
||||
visit_once: bool,
|
||||
ignore_self: Literal[True],
|
||||
as_edge: Literal[False],
|
||||
) -> Iterator[T_TIobj]:
|
||||
...
|
||||
|
||||
@overload
|
||||
def traverse(
|
||||
self: T_TIobj,
|
||||
predicate: Callable[[Union[T_TIobj, Tuple[Union[T_TIobj, None], T_TIobj]], int], bool],
|
||||
prune: Callable[[Union[T_TIobj, Tuple[Union[T_TIobj, None], T_TIobj]], int], bool],
|
||||
depth: int,
|
||||
branch_first: bool,
|
||||
visit_once: bool,
|
||||
ignore_self: Literal[False],
|
||||
as_edge: Literal[True],
|
||||
) -> Iterator[Tuple[Union[T_TIobj, None], T_TIobj]]:
|
||||
...
|
||||
|
||||
@overload
|
||||
def traverse(
|
||||
self: T_TIobj,
|
||||
predicate: Callable[[Union[T_TIobj, TIobj_tuple], int], bool],
|
||||
prune: Callable[[Union[T_TIobj, TIobj_tuple], int], bool],
|
||||
depth: int,
|
||||
branch_first: bool,
|
||||
visit_once: bool,
|
||||
ignore_self: Literal[True],
|
||||
as_edge: Literal[True],
|
||||
) -> Iterator[Tuple[T_TIobj, T_TIobj]]:
|
||||
...
|
||||
|
||||
def traverse(
|
||||
self: T_TIobj,
|
||||
predicate: Callable[[Union[T_TIobj, TIobj_tuple], int], bool] = lambda i, d: True,
|
||||
prune: Callable[[Union[T_TIobj, TIobj_tuple], int], bool] = lambda i, d: False,
|
||||
depth: int = -1,
|
||||
branch_first: bool = True,
|
||||
visit_once: bool = True,
|
||||
ignore_self: int = 1,
|
||||
as_edge: bool = False,
|
||||
) -> Union[Iterator[T_TIobj], Iterator[Tuple[T_TIobj, T_TIobj]], Iterator[TIobj_tuple]]:
|
||||
"""For documentation, see util.Traversable._traverse()"""
|
||||
|
||||
"""
|
||||
# To typecheck instead of using cast.
|
||||
import itertools
|
||||
from git.types import TypeGuard
|
||||
def is_commit_traversed(inp: Tuple) -> TypeGuard[Tuple[Iterator[Tuple['Commit', 'Commit']]]]:
|
||||
for x in inp[1]:
|
||||
if not isinstance(x, tuple) and len(x) != 2:
|
||||
if all(isinstance(inner, Commit) for inner in x):
|
||||
continue
|
||||
return True
|
||||
|
||||
ret = super(Commit, self).traverse(predicate, prune, depth, branch_first, visit_once, ignore_self, as_edge)
|
||||
ret_tup = itertools.tee(ret, 2)
|
||||
assert is_commit_traversed(ret_tup), f"{[type(x) for x in list(ret_tup[0])]}"
|
||||
return ret_tup[0]
|
||||
"""
|
||||
return cast(
|
||||
Union[Iterator[T_TIobj], Iterator[Tuple[Union[None, T_TIobj], T_TIobj]]],
|
||||
super(TraversableIterableObj, self)._traverse(
|
||||
predicate, prune, depth, branch_first, visit_once, ignore_self, as_edge # type: ignore
|
||||
),
|
||||
)
|
||||
@@ -0,0 +1,9 @@
|
||||
# flake8: noqa
|
||||
# import all modules in order, fix the names they require
|
||||
from .symbolic import *
|
||||
from .reference import *
|
||||
from .head import *
|
||||
from .tag import *
|
||||
from .remote import *
|
||||
|
||||
from .log import *
|
||||
277
zero-cost-nas/.eggs/GitPython-3.1.31-py3.8.egg/git/refs/head.py
Normal file
@@ -0,0 +1,277 @@
|
||||
from git.config import GitConfigParser, SectionConstraint
|
||||
from git.util import join_path
|
||||
from git.exc import GitCommandError
|
||||
|
||||
from .symbolic import SymbolicReference
|
||||
from .reference import Reference
|
||||
|
||||
# typing -----------------------------------------------------
|
||||
|
||||
from typing import Any, Sequence, Union, TYPE_CHECKING
|
||||
|
||||
from git.types import PathLike, Commit_ish
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from git.repo import Repo
|
||||
from git.objects import Commit
|
||||
from git.refs import RemoteReference
|
||||
|
||||
# -------------------------------------------------------------------
|
||||
|
||||
__all__ = ["HEAD", "Head"]
|
||||
|
||||
|
||||
def strip_quotes(string: str) -> str:
|
||||
if string.startswith('"') and string.endswith('"'):
|
||||
return string[1:-1]
|
||||
return string
|
||||
|
||||
|
||||
class HEAD(SymbolicReference):
|
||||
|
||||
"""Special case of a Symbolic Reference as it represents the repository's
|
||||
HEAD reference."""
|
||||
|
||||
_HEAD_NAME = "HEAD"
|
||||
_ORIG_HEAD_NAME = "ORIG_HEAD"
|
||||
__slots__ = ()
|
||||
|
||||
def __init__(self, repo: "Repo", path: PathLike = _HEAD_NAME):
|
||||
if path != self._HEAD_NAME:
|
||||
raise ValueError("HEAD instance must point to %r, got %r" % (self._HEAD_NAME, path))
|
||||
super(HEAD, self).__init__(repo, path)
|
||||
self.commit: "Commit"
|
||||
|
||||
def orig_head(self) -> SymbolicReference:
|
||||
"""
|
||||
:return: SymbolicReference pointing at the ORIG_HEAD, which is maintained
|
||||
to contain the previous value of HEAD"""
|
||||
return SymbolicReference(self.repo, self._ORIG_HEAD_NAME)
|
||||
|
||||
def reset(
|
||||
self,
|
||||
commit: Union[Commit_ish, SymbolicReference, str] = "HEAD",
|
||||
index: bool = True,
|
||||
working_tree: bool = False,
|
||||
paths: Union[PathLike, Sequence[PathLike], None] = None,
|
||||
**kwargs: Any,
|
||||
) -> "HEAD":
|
||||
"""Reset our HEAD to the given commit optionally synchronizing
|
||||
the index and working tree. The reference we refer to will be set to
|
||||
commit as well.
|
||||
|
||||
:param commit:
|
||||
Commit object, Reference Object or string identifying a revision we
|
||||
should reset HEAD to.
|
||||
|
||||
:param index:
|
||||
If True, the index will be set to match the given commit. Otherwise
|
||||
it will not be touched.
|
||||
|
||||
:param working_tree:
|
||||
If True, the working tree will be forcefully adjusted to match the given
|
||||
commit, possibly overwriting uncommitted changes without warning.
|
||||
If working_tree is True, index must be true as well
|
||||
|
||||
:param paths:
|
||||
Single path or list of paths relative to the git root directory
|
||||
that are to be reset. This allows partially resetting individual files.
|
||||
|
||||
:param kwargs:
|
||||
Additional arguments passed to git-reset.
|
||||
|
||||
:return: self"""
|
||||
mode: Union[str, None]
|
||||
mode = "--soft"
|
||||
if index:
|
||||
mode = "--mixed"
|
||||
|
||||
# it appears some git versions declare mixed and paths deprecated
|
||||
# see http://github.com/Byron/GitPython/issues#issue/2
|
||||
if paths:
|
||||
mode = None
|
||||
# END special case
|
||||
# END handle index
|
||||
|
||||
if working_tree:
|
||||
mode = "--hard"
|
||||
if not index:
|
||||
raise ValueError("Cannot reset the working tree if the index is not reset as well")
|
||||
|
||||
# END working tree handling
|
||||
|
||||
try:
|
||||
self.repo.git.reset(mode, commit, "--", paths, **kwargs)
|
||||
except GitCommandError as e:
|
||||
# git nowadays may use 1 as status to indicate there are still unstaged
|
||||
# modifications after the reset
|
||||
if e.status != 1:
|
||||
raise
|
||||
# END handle exception
|
||||
|
||||
return self
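# Editor's note: an illustrative call of reset() (editor's sketch, not part of the
# upstream source). It moves HEAD and the index one commit back while leaving the
# working tree untouched, i.e. the equivalent of ``git reset --mixed HEAD~1``.
def _example_mixed_reset(repo: "Repo") -> None:
    repo.head.reset("HEAD~1", index=True, working_tree=False)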
|
||||
|
||||
|
||||
class Head(Reference):
|
||||
|
||||
"""A Head is a named reference to a Commit. Every Head instance contains a name
|
||||
and a Commit object.
|
||||
|
||||
Examples::
|
||||
|
||||
>>> repo = Repo("/path/to/repo")
|
||||
>>> head = repo.heads[0]
|
||||
|
||||
>>> head.name
|
||||
'master'
|
||||
|
||||
>>> head.commit
|
||||
<git.Commit "1c09f116cbc2cb4100fb6935bb162daa4723f455">
|
||||
|
||||
>>> head.commit.hexsha
|
||||
'1c09f116cbc2cb4100fb6935bb162daa4723f455'"""
|
||||
|
||||
_common_path_default = "refs/heads"
|
||||
k_config_remote = "remote"
|
||||
k_config_remote_ref = "merge" # branch to merge from remote
|
||||
|
||||
@classmethod
|
||||
def delete(cls, repo: "Repo", *heads: "Union[Head, str]", force: bool = False, **kwargs: Any) -> None:
|
||||
"""Delete the given heads
|
||||
|
||||
:param force:
|
||||
If True, the heads will be deleted even if they are not yet merged into
|
||||
the main development stream.
|
||||
Default False"""
|
||||
flag = "-d"
|
||||
if force:
|
||||
flag = "-D"
|
||||
repo.git.branch(flag, *heads)
|
||||
|
||||
def set_tracking_branch(self, remote_reference: Union["RemoteReference", None]) -> "Head":
|
||||
"""
|
||||
Configure this branch to track the given remote reference. This will alter
|
||||
this branch's configuration accordingly.
|
||||
|
||||
:param remote_reference: The remote reference to track or None to untrack
|
||||
any references
|
||||
:return: self"""
|
||||
from .remote import RemoteReference
|
||||
|
||||
if remote_reference is not None and not isinstance(remote_reference, RemoteReference):
|
||||
raise ValueError("Incorrect parameter type: %r" % remote_reference)
|
||||
# END handle type
|
||||
|
||||
with self.config_writer() as writer:
|
||||
if remote_reference is None:
|
||||
writer.remove_option(self.k_config_remote)
|
||||
writer.remove_option(self.k_config_remote_ref)
|
||||
if len(writer.options()) == 0:
|
||||
writer.remove_section()
|
||||
else:
|
||||
writer.set_value(self.k_config_remote, remote_reference.remote_name)
|
||||
writer.set_value(
|
||||
self.k_config_remote_ref,
|
||||
Head.to_full_path(remote_reference.remote_head),
|
||||
)
|
||||
|
||||
return self
|
||||
|
||||
def tracking_branch(self) -> Union["RemoteReference", None]:
|
||||
"""
|
||||
:return: The remote_reference we are tracking, or None if we are
|
||||
not a tracking branch"""
|
||||
from .remote import RemoteReference
|
||||
|
||||
reader = self.config_reader()
|
||||
if reader.has_option(self.k_config_remote) and reader.has_option(self.k_config_remote_ref):
|
||||
ref = Head(
|
||||
self.repo,
|
||||
Head.to_full_path(strip_quotes(reader.get_value(self.k_config_remote_ref))),
|
||||
)
|
||||
remote_refpath = RemoteReference.to_full_path(join_path(reader.get_value(self.k_config_remote), ref.name))
|
||||
return RemoteReference(self.repo, remote_refpath)
|
||||
# END handle have tracking branch
|
||||
|
||||
# we are not a tracking branch
|
||||
return None
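# Editor's note: a sketch of wiring up a tracking branch with the two methods
# above (not upstream code). The remote name "origin" and the branch names are
# placeholders for whatever exists in the repository at hand.
def _example_track_remote(repo: "Repo") -> None:
    head = repo.heads["feature"]
    head.set_tracking_branch(repo.remotes.origin.refs["main"])
    assert head.tracking_branch() is not None
    head.set_tracking_branch(None)        # untrack again
    assert head.tracking_branch() is None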
|
||||
|
||||
def rename(self, new_path: PathLike, force: bool = False) -> "Head":
|
||||
"""Rename self to a new path
|
||||
|
||||
:param new_path:
|
||||
Either a simple name or a path, i.e. new_name or features/new_name.
|
||||
The prefix refs/heads is implied
|
||||
|
||||
:param force:
|
||||
If True, the rename will succeed even if a head with the target name
|
||||
already exists.
|
||||
|
||||
:return: self
|
||||
:note: respects the ref log as git commands are used"""
|
||||
flag = "-m"
|
||||
if force:
|
||||
flag = "-M"
|
||||
|
||||
self.repo.git.branch(flag, self, new_path)
|
||||
self.path = "%s/%s" % (self._common_path_default, new_path)
|
||||
return self
|
||||
|
||||
def checkout(self, force: bool = False, **kwargs: Any) -> Union["HEAD", "Head"]:
|
||||
"""Checkout this head by setting the HEAD to this reference, by updating the index
|
||||
to reflect the tree we point to and by updating the working tree to reflect
|
||||
the latest index.
|
||||
|
||||
The command will fail if changed working tree files would be overwritten.
|
||||
|
||||
:param force:
|
||||
If True, changes to the index and the working tree will be discarded.
|
||||
If False, GitCommandError will be raised in that situation.
|
||||
|
||||
:param kwargs:
|
||||
Additional keyword arguments to be passed to git checkout, i.e.
|
||||
b='new_branch' to create a new branch at the given spot.
|
||||
|
||||
:return:
|
||||
The active branch after the checkout operation, usually self unless
|
||||
a new branch has been created.
|
||||
If there is no active branch, as the HEAD is now detached, the HEAD
|
||||
reference will be returned instead.
|
||||
|
||||
:note:
|
||||
By default it is only allowed to checkout heads - everything else
|
||||
will leave the HEAD detached which is allowed and possible, but remains
|
||||
a special state that some tools might not be able to handle."""
|
||||
kwargs["f"] = force
|
||||
if kwargs["f"] is False:
|
||||
kwargs.pop("f")
|
||||
|
||||
self.repo.git.checkout(self, **kwargs)
|
||||
if self.repo.head.is_detached:
|
||||
return self.repo.head
|
||||
else:
|
||||
return self.repo.active_branch
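# Editor's note: an illustrative checkout flow (editor's sketch, not upstream
# code); the branch name is a placeholder.
def _example_checkout_new_branch(repo: "Repo") -> Union["HEAD", "Head"]:
    branch = repo.create_head("feature/docs")  # new head at the current commit
    return branch.checkout()                   # usually returns the branch itself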
|
||||
|
||||
# { Configuration
|
||||
def _config_parser(self, read_only: bool) -> SectionConstraint[GitConfigParser]:
|
||||
if read_only:
|
||||
parser = self.repo.config_reader()
|
||||
else:
|
||||
parser = self.repo.config_writer()
|
||||
# END handle parser instance
|
||||
|
||||
return SectionConstraint(parser, 'branch "%s"' % self.name)
|
||||
|
||||
def config_reader(self) -> SectionConstraint[GitConfigParser]:
|
||||
"""
|
||||
:return: A configuration parser instance constrained to only read
|
||||
this instance's values"""
|
||||
return self._config_parser(read_only=True)
|
||||
|
||||
def config_writer(self) -> SectionConstraint[GitConfigParser]:
|
||||
"""
|
||||
:return: A configuration writer instance with read- and write access
|
||||
to options of this head"""
|
||||
return self._config_parser(read_only=False)
|
||||
|
||||
# } END configuration
|
||||
Some files were not shown because too many files have changed in this diff.