update GDAS
This commit is contained in:
		| @@ -62,6 +62,8 @@ CUDA_VISIBLE_DEVICES=0 bash ./scripts-search/NAS-Bench-201/train-a-net.sh '|nor_ | |||||||
| `|nor_conv_3x3~0|+|nor_conv_3x3~0|nor_conv_3x3~1|+|skip_connect~0|skip_connect~1|skip_connect~2|` represents the structure of a searched architecture. My codes will automatically print it during the searching procedure. | `|nor_conv_3x3~0|+|nor_conv_3x3~0|nor_conv_3x3~1|+|skip_connect~0|skip_connect~1|skip_connect~2|` represents the structure of a searched architecture. My codes will automatically print it during the searching procedure. | ||||||
|  |  | ||||||
|  |  | ||||||
|  | **Tensorflow codes for GDAS are in experimental state**, which locates at `exps-tf`. | ||||||
|  |  | ||||||
| # Citation | # Citation | ||||||
|  |  | ||||||
| If you find that this project helps your research, please consider citing the following paper: | If you find that this project helps your research, please consider citing the following paper: | ||||||
|   | |||||||
| @@ -1,4 +1,9 @@ | |||||||
|  | # [D-X-Y] | ||||||
|  | # Run GDAS | ||||||
| # CUDA_VISIBLE_DEVICES=0 python exps-tf/GDAS.py | # CUDA_VISIBLE_DEVICES=0 python exps-tf/GDAS.py | ||||||
|  | # Run DARTS | ||||||
|  | # CUDA_VISIBLE_DEVICES=0 python exps-tf/GDAS.py --tau_max -1 --tau_min -1 --epochs 50 | ||||||
|  | # | ||||||
| import os, sys, math, time, random, argparse | import os, sys, math, time, random, argparse | ||||||
| import tensorflow as tf | import tensorflow as tf | ||||||
| from pathlib import Path | from pathlib import Path | ||||||
|   | |||||||
| @@ -7,10 +7,10 @@ from copy import deepcopy | |||||||
| from ..cell_operations import OPS | from ..cell_operations import OPS | ||||||
|  |  | ||||||
|  |  | ||||||
| class SearchCell(tf.keras.layers.Layer): | class NAS201SearchCell(tf.keras.layers.Layer): | ||||||
|  |  | ||||||
|   def __init__(self, C_in, C_out, stride, max_nodes, op_names, affine=False): |   def __init__(self, C_in, C_out, stride, max_nodes, op_names, affine=False): | ||||||
|     super(SearchCell, self).__init__() |     super(NAS201SearchCell, self).__init__() | ||||||
|  |  | ||||||
|     self.op_names  = deepcopy(op_names) |     self.op_names  = deepcopy(op_names) | ||||||
|     self.max_nodes = max_nodes |     self.max_nodes = max_nodes | ||||||
|   | |||||||
| @@ -5,7 +5,7 @@ import tensorflow as tf | |||||||
| import numpy as np | import numpy as np | ||||||
| from copy import deepcopy | from copy import deepcopy | ||||||
| from ..cell_operations import ResNetBasicblock | from ..cell_operations import ResNetBasicblock | ||||||
| from .search_cells     import SearchCell | from .search_cells     import NAS201SearchCell as SearchCell | ||||||
|  |  | ||||||
|  |  | ||||||
| def sample_gumbel(shape, eps=1e-20): | def sample_gumbel(shape, eps=1e-20): | ||||||
|   | |||||||
		Reference in New Issue
	
	Block a user