How Darknet Works

Tag: Deep Learning

Published on: 21 Sep 2021


Recently I came across Darknet, a lightweight neural network framework elegantly written in C and CUDA. As a guy interested in hardware, I often feel spoiled by PyTorch and wonder how things work under the hood. So I dug into the implementation of Darknet and summarized my findings in this blog post.

Image source: pjreddie/darknet

As an example use case, we run Darknet in detector mode with the YOLOv3-tiny network definition.

Darknet program interface

Usage: ./darknet detector train DATASET_DEF.data NETWORK_DEF.cfg CHECKPOINT_FILE

In main(), the first argument 'detector' dispatches to run_detector(argc, argv):

    .....
    } else if (0 == strcmp(argv[1], "lsd")){
        run_lsd(argc, argv);
    } else if (0 == strcmp(argv[1], "detector")){
        run_detector(argc, argv); // <----- We are here
    } else if (0 == strcmp(argv[1], "detect")){
        float thresh = find_float_arg(argc, argv, "-thresh", .5);
        char *filename = (argc > 4) ? argv[4]: 0;
        char *outfile = find_char_arg(argc, argv, "-out", 0);
        int fullscreen = find_arg(argc, argv, "-fullscreen");
        test_detector("cfg/coco.data", argv[2], argv[3], filename, thresh, .5, outfile, fullscreen);
    } else if (0 == strcmp(argv[1], "cifar")){
        run_cifar(argc, argv);
    } else if (0 == strcmp(argv[1], "go")){
    .....

Training the object detector

run_detector(argc, argv) additionally expects the following positional arguments, matching the usage line above:

argv[2]: mode [train/test/valid]
argv[3]: dataset definition [.data]
argv[4]: network definition [.cfg]
argv[5]: weights (optional)

plus optional command-line flags given as key-value pairs (e.g. -thresh 0.5).

These flags are looked up with calls such as the following; if a flag is absent, its default value is used:

char *prefix = find_char_arg(argc, argv, "-prefix", 0);
float thresh = find_float_arg(argc, argv, "-thresh", .5);
float hier_thresh = find_float_arg(argc, argv, "-hier", .5);
...
// In each parser call a default value is given as the last argument
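Darknet's find_*_arg helpers do more than read a value: they also delete the matched flag and its value from argv, which is why the positional indices (argv[2], argv[3], ...) work no matter where the flags appear on the command line. Below is a minimal sketch of that idea for the float case; lookup_float_arg and remove_arg are hypothetical names, not Darknet's API.

```c
#include <stdlib.h>
#include <string.h>

/* Shift argv[index+1..argc-1] one slot to the left and clear the last slot. */
static void remove_arg(int argc, char **argv, int index)
{
    for (int i = index; i < argc - 1; ++i) argv[i] = argv[i + 1];
    argv[argc - 1] = NULL;
}

/* Hypothetical helper modeled on find_float_arg: return the float that
 * follows `key` in argv, or `def` if the key is absent. The matched
 * key/value pair is removed from argv, so the positional arguments slide
 * into their expected slots. argc itself is not changed; removed slots
 * become NULL and are skipped on later lookups. */
float lookup_float_arg(int argc, char **argv, const char *key, float def)
{
    for (int i = 0; i < argc - 1; ++i) {
        if (argv[i] && argv[i + 1] && strcmp(argv[i], key) == 0) {
            def = (float)atof(argv[i + 1]);
            remove_arg(argc, argv, i); /* drop the key */
            remove_arg(argc, argv, i); /* drop the value, now at index i */
            break;
        }
    }
    return def;
}
```

For example, with ./darknet detector -thresh 0.25 train coco.data yolov3-tiny.cfg, the lookup returns 0.25 and afterwards argv[2] is "train", exactly as if the flag had been given last.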

Next, we can find a call to train_detector with the arguments parsed:

...
else if(0==strcmp(argv[2], "train")) train_detector(datacfg, cfg, weights, gpus, ngpus, clear);
...

Things get interesting now. Inside train_detector we can find two central parts of a DL framework, namely:

  • “Instantiation” of the NN from model cfg files
  • Restoring weights from previous checkpoints

Model instantiation

Before diving into the code, let's first take a look at the model definition files, which are located in the cfg directory.

We use yolov3-tiny.cfg as an example.

Every cfg file begins with a [net] section that defines parameters shared by the whole network:

[net]
# Testing
batch=1
subdivisions=1
# Training
# batch=64
# subdivisions=2
width=416
height=416
channels=3
momentum=0.9
decay=0.0005
angle=0
saturation = 1.5
exposure = 1.5
hue=.1
learning_rate=0.001
burn_in=1000
max_batches = 500200
policy=steps
steps=400000,450000
scales=.1,.1

This is followed by one section per layer:

[convolutional]
batch_normalize=1
filters=16
size=3
stride=1
pad=1
activation=leaky

[maxpool]
size=2
stride=2

...etc

Programmatically, the cfg file is first parsed into a list of lists. The first-level list has one entry per section of the network definition file, and each second-level list holds that section's key-value pairs.
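The line-level part of this parsing is simple enough to sketch. The function below is a hypothetical illustration, not Darknet's parser: it strips all whitespace (so "saturation = 1.5" from the [net] section still yields the key "saturation"), skips blank and comment lines, recognizes "[section]" headers, and splits everything else at the first '='.

```c
#include <ctype.h>
#include <string.h>

/* Remove all whitespace from s in place: "saturation = 1.5" -> "saturation=1.5". */
static void strip_spaces(char *s)
{
    char *dst = s;
    for (char *src = s; *src; ++src)
        if (!isspace((unsigned char)*src)) *dst++ = *src;
    *dst = '\0';
}

/* Hypothetical classifier for one line of a Darknet-style cfg file.
 * Returns 0 for blank/comment lines, 1 for a "[section]" header,
 * 2 for a "key=value" pair (key and value then point into the modified line),
 * and -1 for a line without '='. */
int parse_cfg_line(char *line, char **key, char **value)
{
    strip_spaces(line);
    if (line[0] == '\0' || line[0] == '#' || line[0] == ';') return 0;
    if (line[0] == '[') return 1;
    char *eq = strchr(line, '=');
    if (!eq) return -1;
    *eq = '\0';          /* terminate the key in place */
    *key = line;
    *value = eq + 1;
    return 2;
}
```

Section headers start a new first-level entry, and each key-value pair is appended to the current section's list.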

Since C has no built-in list type, and the number of entries is not known until runtime, the lists are implemented manually as doubly linked lists:

typedef struct node{
    void *val;
    struct node *next;
    struct node *prev;
} node;

typedef struct list{
    int size;
    node *front;
    node *back;
} list;
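To show how these two structs work together, here is a minimal O(1) append plus a cleanup routine. This is my own sketch; list_append and list_free_nodes are hypothetical names (Darknet ships its own list_insert).

```c
#include <stdlib.h>

typedef struct node{
    void *val;
    struct node *next;
    struct node *prev;
} node;

typedef struct list{
    int size;
    node *front;
    node *back;
} list;

/* Append val at the back of l. Returns 0 on success, -1 if malloc fails. */
int list_append(list *l, void *val)
{
    node *n = malloc(sizeof *n);
    if (!n) return -1;
    n->val = val;
    n->next = NULL;
    n->prev = l->back;               /* NULL when the list is empty */
    if (l->back) l->back->next = n;  /* link after the current tail */
    else l->front = n;               /* first element is also the front */
    l->back = n;
    ++l->size;
    return 0;
}

/* Free every node (but not the values they point to) and reset the list. */
void list_free_nodes(list *l)
{
    node *n = l->front;
    while (n) {
        node *next = n->next;
        free(n);
        n = next;
    }
    l->front = l->back = NULL;
    l->size = 0;
}
```

Keeping a back pointer is what makes appending each newly parsed section or option constant time, even though the final length is unknown while reading the file.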

The nested lists are then converted into a network struct: each section is parsed into a layer and stored as an element of the network's layer *layers array. This is how each layer is filled in:

while(n){
    params.index = count;
    fprintf(stderr, "%5d ", count);
    s = (section *)n->val;
    options = s->options;
    layer l = {0};
    LAYER_TYPE lt = string_to_layer_type(s->type);
    if(lt == CONVOLUTIONAL){
        l = parse_convolutional(options, params); // <---- Here
    }else if(lt == DECONVOLUTIONAL){
        l = parse_deconvolutional(options, params);
    }else if(lt == LOCAL){
        l = parse_local(options, params);
    }else if(lt == ACTIVE){
...
        net->layers[count] = l;
...

Here, params holds the input dimensions of the current layer, e.g. height, width, number of channels and the total number of inputs (h*w*c). After each layer is parsed, params is updated to that layer's output dimensions, which become the input of the next layer.

options holds the parameters of the layer itself, e.g. for a convolutional layer it contains keys such as filters, size and stride.
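The params bookkeeping can be sketched for convolutional layers as follows. shape and conv_output_shape are hypothetical names (Darknet keeps these values in size_params and in the layer struct), and I assume, as parse_convolutional appears to do, that pad=1 in the cfg means a padding of size/2 pixels. The size formula is the same one used for l.out_w in the forward pass.

```c
/* Hypothetical stand-in for the h/w/c fields that params carries between layers. */
typedef struct {
    int h, w, c;
} shape;

/* Output shape of a convolution with `filters` kernels of size x size.
 * pad is the padding in pixels (cfg "pad=1" is assumed to mean size/2). */
shape conv_output_shape(shape in, int filters, int size, int stride, int pad)
{
    shape out;
    out.h = (in.h + 2 * pad - size) / stride + 1;
    out.w = (in.w + 2 * pad - size) / stride + 1;
    out.c = filters;  /* one output channel per filter */
    return out;
}
```

For the first [convolutional] section of yolov3-tiny.cfg, conv_output_shape((shape){416, 416, 3}, 16, 3, 1, 1) gives 416x416x16, and that becomes params for the following [maxpool] section.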

Digging deeper, parse_convolutional(list *options, size_params params) and make_convolutional_layer show how each layer is actually constructed.

Here we can observe the following:

  • all layers share a single struct type, so the “base” layer type contains member variables for every possible kind of layer, which is an ugly consequence of C's lack of inheritance
  • the forward and backward calls are implemented as function pointers, i.e., l.forward and l.backward
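To make the second point concrete, here is a toy version of the pattern: a single layer struct type whose forward member points to the kind-specific implementation, and a loop that feeds each layer's output to the next. toy_layer, toy_network and the two layer kinds are hypothetical; they only mimic how Darknet chains its layers.

```c
/* Hypothetical, heavily trimmed layer/network structs: one struct type for
 * every layer kind, with behaviour selected at runtime via a function pointer. */
typedef struct toy_network toy_network;
typedef struct toy_layer toy_layer;

struct toy_layer {
    int outputs;                 /* number of output elements */
    float scale;                 /* only meaningful for scale layers */
    float *output;               /* output buffer owned by the caller */
    void (*forward)(toy_layer l, toy_network net);
};

struct toy_network {
    int n;                       /* number of layers */
    toy_layer *layers;
    float *input;                /* input of the layer currently running */
};

static void forward_scale(toy_layer l, toy_network net)
{
    for (int i = 0; i < l.outputs; ++i) l.output[i] = l.scale * net.input[i];
}

static void forward_relu(toy_layer l, toy_network net)
{
    for (int i = 0; i < l.outputs; ++i)
        l.output[i] = net.input[i] > 0 ? net.input[i] : 0.0f;
}

/* Call each layer through its forward pointer; its output becomes the next input. */
float *toy_forward_network(toy_network net)
{
    for (int i = 0; i < net.n; ++i) {
        toy_layer l = net.layers[i];
        l.forward(l, net);
        net.input = l.output;
    }
    return net.input;
}
```

The network loop never needs to know which kind of layer it is running; the dispatch happens entirely through l.forward.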

The forward call

At runtime, the forward and backward functions take two arguments: the layer and the network (l and net). Taking the convolutional layer as an example again, the core of its forward function looks like this:

int m = l.n/l.groups;
// number of filters per group
// l.n: number of output channels
// l.groups: for group conv
int k = l.size*l.size*l.c/l.groups;
// number of corresponding input elements for each output convolution (per group)
// l.size: kernel size
// l.c: number of input channels
int n = l.out_w*l.out_h;
// number of output pixels
// l.out_w is calculated by the following formula (same applies to h):
// (l.w + 2*l.pad - l.size) / l.stride + 1

for(i = 0; i < l.batch; ++i){
    for(j = 0; j < l.groups; ++j){
        float *a = l.weights + j*l.nweights/l.groups;
        // l.nweights = c/groups*n*size*size;
        // address of the current group's weights (an m x k matrix)
        float *b = net.workspace;
        // scratch buffer that receives the im2col matrix (k x n)
        float *c = l.output + (i*l.groups + j)*n*m;
        // here: n*m = filters * out_w * out_h / groups
        // is the number of output elements per group
        float *im =  net.input + (i*l.groups + j)*l.c/l.groups*l.h*l.w;
        // address of the input for this batch item and group

        if (l.size == 1) {
            // 1x1 kernel (stride 1, no padding): the input already
            // has the k x n layout, so im2col is skipped
            b = im;
        } else {
            im2col_cpu(im, l.c/l.groups, l.h, l.w, l.size, l.stride, l.pad, b);
        }
        gemm(0,0,m,n,k,1,a,k,b,n,1,c,n);
        // matrix multiplication, c += a*b; l.output is zeroed
        // earlier in the function, so this yields c = a*b
    }
}
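To put numbers on m, k and n, take the first [convolutional] section of yolov3-tiny.cfg shown earlier (416x416x3 input, filters=16, size=3, stride=1, pad=1, no groups). Assuming pad=1 yields one pixel of padding (size/2), the quantities above work out as:

```latex
\begin{aligned}
m &= \frac{\texttt{l.n}}{\texttt{l.groups}} = \frac{16}{1} = 16 \\
k &= \frac{\texttt{l.size}^2 \cdot \texttt{l.c}}{\texttt{l.groups}} = \frac{3^2 \cdot 3}{1} = 27 \\
\texttt{l.out\_w} = \texttt{l.out\_h} &= \frac{416 + 2 \cdot 1 - 3}{1} + 1 = 416 \\
n &= 416 \cdot 416 = 173056
\end{aligned}
```

So for each image, gemm multiplies a 16x27 weight matrix by a 27x173056 im2col matrix (4,672,512 floats of workspace), which amounts to 16 x 27 x 173056 = 74,760,192 multiply-adds for this single layer.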

Two function calls are worth noting here. First, im2col_cpu rearranges the input image im according to the kernel parameters (size, stride, padding) into a matrix with one row per kernel element (k rows) and one column per output pixel (n columns), duplicating input values that fall into several receptive fields. Then gemm multiplies the weight matrix by this matrix to produce the convolution output.

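im2col_cpu, together with the im2col_get_pixel helper it calls:

float im2col_get_pixel(float *im, int height, int width, int channels,
                        int row, int col, int channel, int pad)
{
    // convert coordinates in the (virtual) padded image to image coordinates
    row -= pad;
    col -= pad;
    // positions in the padding area read as zero
    if (row < 0 || col < 0 ||
        row >= height || col >= width) return 0;
    // CHW layout, row-major within each channel
    return im[col + width*(row + height*channel)];
}

void im2col_cpu(float* data_im,
     int channels,  int height,  int width,
     int ksize,  int stride, int pad, float* data_col)
{
    int c,h,w;
    // output spatial size, i.e. number of kernel positions per axis
    int height_col = (height + 2*pad - ksize) / stride + 1;
    int width_col = (width + 2*pad - ksize) / stride + 1;

    // one row of data_col per (input channel, kernel row, kernel column)
    int channels_col = channels * ksize * ksize;
    for (c = 0; c < channels_col; ++c) {
        int w_offset = c % ksize;           // kernel column
        int h_offset = (c / ksize) % ksize; // kernel row
        int c_im = c / ksize / ksize;       // input channel
        for (h = 0; h < height_col; ++h) {
            for (w = 0; w < width_col; ++w) {
                int im_row = h_offset + h * stride;
                int im_col = w_offset + w * stride;
                int col_index = (c * height_col + h) * width_col + w;
                data_col[col_index] = im2col_get_pixel(data_im, height, width, channels, im_row, im_col, c_im, pad);
            }
        }
    }
}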

Demo

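Assume we have a 2x2 RGB input image, stored channel by channel in row-major order, i.e. as 12 floats: R00 R01 R10 R11, then the four G values, then the four B values. With padding=1, each channel is conceptually surrounded by a one-pixel border of zeros (giving a 4x4 plane), but this padded image is never stored: im2col_get_pixel returns 0 for any position outside the original 2x2 image.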

We convolve it with a CxHxW = 3x3x3 kernel, using padding=1 and stride=1.

Plugging these values into im2col_cpu gives:

height_col = (2 + 2*1 - 3) / 1 + 1 = 2
width_col = (2 + 2*1 - 3) / 1 + 1 = 2
channels_col = 3 * 3 * 3 = 27

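As a result, data_col has channels_col = 27 rows, one per (input channel, kernel row, kernel column), and height_col * width_col = 4 columns, one per output pixel.

(data_col is stored as a row-major array of 27x4 = 108 elements; each column holds the values the kernel sees at that output position, copied from the original image, with zeros wherever the kernel overlaps the padding.)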
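To trace where individual entries of data_col come from, the hypothetical helper col_source below inverts the index arithmetic of im2col_cpu for a single element (row, col): it returns the flat CHW index of the source pixel, or -1 if the element falls into the padding. It is an illustration only, not part of Darknet, and assumes row < channels * ksize * ksize.

```c
/* Hypothetical helper: flat CHW input index copied into data_col[row][col]
 * by im2col_cpu, or -1 when that element lies in the zero padding. */
int col_source(int row, int col,
               int height, int width,
               int ksize, int stride, int pad)
{
    int width_col = (width + 2 * pad - ksize) / stride + 1;
    int w_offset = row % ksize;            /* kernel column */
    int h_offset = (row / ksize) % ksize;  /* kernel row */
    int c_im = row / ksize / ksize;        /* input channel */
    int h = col / width_col;               /* output row */
    int w = col % width_col;               /* output column */
    int im_row = h_offset + h * stride - pad;
    int im_col = w_offset + w * stride - pad;
    if (im_row < 0 || im_col < 0 || im_row >= height || im_col >= width)
        return -1;
    return (c_im * height + im_row) * width + im_col;
}
```

For this demo, row 13 is the center tap of the G channel (c_im = 1, h_offset = w_offset = 1) and reproduces G00 G01 G10 G11 unchanged, while row 0 only picks up R00, in the last column. Counting the -1 results gives 60 padding zeros, so the remaining 48 entries are the 12 input values, each copied into all 4 columns.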

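Finally, the matrix multiplication gemm(0,0,m,n,k,1,a,k,b,n,1,c,n); is called. It computes \(C = \alpha A B + \beta C\) with \(\alpha = \beta = 1\); since l.output has been zeroed beforehand, this gives us the result \(C = A \cdot B\).

The first two arguments are transpose flags for A and B and select one of four gemm variants. With both set to 0, the plain form gemm_nn is used, which we analyze here: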

void gemm_nn(int M, int N, int K, float ALPHA,
        float *A, int lda,
        float *B, int ldb,
        float *C, int ldc)
        // M: no. of output channels per group (m)
        // N: no. of output pixels, out_w * out_h (n)
        // K: no. of elements per receptive field, size * size * C / groups (k)
        // ALPHA: scalar multiplier for matrix A; 1 in the convolution call
        // A: weights (M x K), row-major; lda = K
        // B: input image processed with im2col (K x N), row-major; ldb = N
        // C: output matrix (M x N), row-major; ldc = N; results are accumulated
{
    int i;
    #pragma omp parallel for
    for(i = 0; i < M; ++i){
        // k and j are declared inside the loop so that each OpenMP thread
        // gets private copies (declared at the top, they would be shared)
        for(int k = 0; k < K; ++k){
            // i-k-j order: A[i][k] is loaded once, and the inner loop walks
            // B and C with unit stride, which is cache friendly
            register float A_PART = ALPHA*A[i*lda+k];
            for(int j = 0; j < N; ++j){
                C[i*ldc+j] += A_PART*B[k*ldb+j];
            }
        }
    }
}
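A convenient way to sanity-check im2col followed by gemm_nn is to compare it with a naive direct convolution. The hypothetical conv_direct below (not part of Darknet) handles one image with groups == 1, and uses the same layouts as the code above: CHW input, weights stored as [filters][channels][ksize][ksize] so that each filter is one K-element row of A, and output stored as [filters][out_h][out_w].

```c
/* Naive direct convolution for one image and one group; computes the same
 * values as im2col_cpu followed by gemm_nn. Positions outside the image
 * contribute zero, matching the implicit zero padding. */
void conv_direct(const float *im, int channels, int height, int width,
                 const float *weights, int filters, int ksize,
                 int stride, int pad, float *out)
{
    int out_h = (height + 2 * pad - ksize) / stride + 1;
    int out_w = (width + 2 * pad - ksize) / stride + 1;
    for (int f = 0; f < filters; ++f)
        for (int y = 0; y < out_h; ++y)
            for (int x = 0; x < out_w; ++x) {
                float sum = 0.0f;
                for (int c = 0; c < channels; ++c)
                    for (int i = 0; i < ksize; ++i)
                        for (int j = 0; j < ksize; ++j) {
                            int r = y * stride + i - pad;
                            int s = x * stride + j - pad;
                            if (r < 0 || s < 0 || r >= height || s >= width)
                                continue; /* zero padding */
                            sum += weights[((f * channels + c) * ksize + i) * ksize + j]
                                 * im[(c * height + r) * width + s];
                        }
                out[(f * out_h + y) * out_w + x] = sum;
            }
}
```

On the 2x2x3 demo with pixel values 1..12 and an all-ones kernel, every output is 78, because each 3x3 window covers the whole image. With a kernel that is 1 only at its center tap of the G channel (weight index 13, the same row index as in data_col), the output is simply the G channel.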

After this step, the forward pass of the convolutional layer is mostly done. The remaining, comparatively simple operations are either batch normalization (when batch_normalize=1) or simply adding the per-filter biases, followed by the non-linear activation.
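The bias-plus-activation tail can be sketched as below for a layer without batch normalization and with activation=leaky, as in the cfg excerpt earlier. Darknet's leaky activation uses a slope of 0.1 for negative inputs; the function name bias_leaky_inplace is hypothetical, and with batch_normalize=1 Darknet would additionally normalize and scale each filter's outputs before adding the bias (not shown here).

```c
/* Hypothetical sketch: add one bias per filter, then apply a leaky ReLU
 * with slope 0.1, in place. x has layout [filters][spatial], where
 * spatial = out_h * out_w. */
void bias_leaky_inplace(float *x, int filters, int spatial, const float *bias)
{
    for (int f = 0; f < filters; ++f) {
        for (int i = 0; i < spatial; ++i) {
            float y = x[f * spatial + i] + bias[f];
            x[f * spatial + i] = (y > 0) ? y : 0.1f * y;  /* leaky activation */
        }
    }
}
```

Both steps are simple element-wise passes over l.output, which is why the convolution itself (im2col plus gemm) dominates the cost of the layer.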

Conclusion

This post gave a quick tour of the Darknet framework by analyzing its core data structures and the forward pass of a convolutional layer. Because it is written entirely in C (with optional CUDA) and has few dependencies, Darknet can be adapted comparatively easily to embedded devices or other platforms with tight computing resources.


© Chengxin Wang. All rights reserved.