decryst/src/decr_sas.c

317 lines
9.1 KiB
C

#define _XOPEN_SOURCE 500 // snprintf().
#include <math.h>
#include <pthread.h>
#include <stdint.h>
#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#include <zmq.h>
#include <arpa/inet.h>
#include "rng.h"
#include "cryst.h"
#include "optim.h"
#include "netio.h"
#include "utils.h"
#include "int_netsa.h"
// Capacity of the buffers that hold worker endpoint strings ("inproc://<idx>").
#define ADDR_MAX 32
// Store an error description in the local `desc' and jump to a cleanup label.
// Expands to a braced block, so call sites are written without a trailing
// semicolon; the enclosing function must declare `char const *desc'.
#define ERR_GOTO(DESC, DEST) { desc = DESC; goto DEST; }
// Print `desc' (when non-NULL) to stderr together with the name of the
// enclosing function, then evaluate to RET; used as `return ERR_PROC(value);'.
#define ERR_PROC(RET) \
((void) (desc && fprintf (stderr, "E: %s in %s()\n", desc, __func__)), RET)
// State shared between the main thread and one worker thread.
struct worker_t {
// Crystal input passed to worker_mk(). workers_run() and workers_chk() also
// use a non-NULL value here as the "this thread was started" marker.
struct crin_t const *crin;
// `ctx': ZeroMQ context the worker creates its socket from.
// `buf': message buffer of sizes[CMD_CNT][0] bytes, allocated by workers_run()
// and freed by workers_chk().
void *ctx, *buf;
pthread_t thread;
// Message size table: [cmd][0] is the request size and [cmd][1] the reply
// size in bytes; row CMD_CNT holds the buffer size and the byte count copied
// for CMD_BEST.
unsigned const (*sizes)[2];
// Worker index; determines the endpoint "inproc://<idx>".
unsigned idx;
// Set to 1 by worker_main() after a clean shutdown, 0 otherwise (the array is
// calloc()ed in main()).
int ret;
};
// Format the in-process endpoint of worker `idx' into `buf' (capacity `len').
// Returns 1 if the complete string fit, 0 on truncation or output error.
static int addr_print (char buf[], unsigned len, unsigned idx) {
	int written = snprintf (buf, len, "inproc://%u", idx);
	if (written <= 0) return 0;
	// snprintf() reports the untruncated length; it must leave room for '\0'.
	return (unsigned) written < len;
}
static unsigned crin_dof (struct crin_t const *crin) {
rng_t *rng;
crystal *cr;
char const *desc = NULL;
unsigned dof = 0;
if (!(rng = rng_mk2 ())) ERR_GOTO ("failed rng_mk2()", rng_err)
if (!(cr = crin_eval (crin, rng))) ERR_GOTO ("failed crin_eval()", cr_err)
if (!(dof = cryst_dof (cr))) fprintf
(stderr, "E: dof == 0 in %s()\n", __func__);
cryst_fin (cr); cr_err:
free (rng); rng_err:
return ERR_PROC(dof);
}
// Receive one batched request from `client' into `buf_' and validate it.
// The batch must consist of exactly `n' equally sized sub-messages, each
// sizes[cmd][0] bytes long and each starting with the same command byte.
// On success stores that command in `*cmd_' and returns 1; returns 0 on a
// receive error, an unknown command, or a size / command mismatch.
static int recv_chk (
	void *client, void *buf_, uint8_t *cmd_,
	unsigned const sizes[CMD_CNT + 1][2], unsigned n
) {
	uint8_t const *bytes = buf_;
	// sizes[CMD_CNT][0] is the per-worker buffer size.
	int got = zmq_recv (client, buf_, n * sizes[CMD_CNT][0], 0);
	if (got < 1) return 0;
	uint8_t code = bytes[0];
	if (code >= CMD_CNT) return 0;
	unsigned stride = sizes[code][0];
	if (got != n * stride) return 0;
	// Every sub-message has to carry the command found in the first one.
	for (unsigned k = 1; k < n; ++k) {
		if (bytes[k * stride] != code) return 0;
	}
	*cmd_ = code;
	return 1;
}
// Inspect `n' consecutive replies of `size' bytes each in `buf_'.
// Returns 1 if the second byte (the status byte) of every reply is non-zero,
// 0 as soon as one reply has a zero status byte.
static int send_chk (void *buf_, unsigned const size, unsigned n) {
	uint8_t const *chunk = buf_;
	unsigned left = n;
	while (left > 0) {
		if (chunk[1] == 0) return 0;
		chunk += size;
		--left;
	}
	return 1;
}
// Build the per-worker objects: an RNG, a crystal evaluated from `crin', and
// an annealing computation object bound to both; then evaluate the crystal and
// acknowledge the result.
// Returns 1 on success. On failure returns 0 after releasing whatever had been
// created and printing the reason; the output pointers must not be used then.
static int worker_mk (
	rng_t **rng, crystal **cr, sa_comp **comp, struct crin_t const *crin
) {
	char const *desc = NULL;
	*rng = rng_mk2 ();
	if (!*rng) ERR_GOTO ("failed rng_mk2()", rng_err)
	*cr = crin_eval (crin, *rng);
	if (!*cr) ERR_GOTO ("failed crin_eval()", cr_err)
	*comp = sa_comp_mk (*cr, *rng);
	if (!*comp) ERR_GOTO ("failed sa_comp_mk()", comp_err)
	if (!cryst_eval (*cr)) ERR_GOTO ("failed cryst_eval()", retn_err)
	cryst_ack (*cr, 1);
	return 1;
	// Unwind in reverse order of construction.
retn_err:
	sa_comp_fin (*comp);
comp_err:
	cryst_fin (*cr);
cr_err:
	free (*rng);
rng_err:
	return ERR_PROC(0);
}
// Release the objects built by worker_mk(), newest first (computation object,
// crystal, RNG).
// Calling this seems to somehow trigger `-Wmaybe-uninitialized'?
// <https://gcc.gnu.org/bugzilla/show_bug.cgi?id=79768>.
static void worker_fin (rng_t *rng, crystal *cr, sa_comp *comp) {
sa_comp_fin (comp);
cryst_fin (cr);
free (rng);
}
// Thread start routine of one worker (argument: its `struct worker_t').
//
// The worker binds a ZMQ_REP socket to "inproc://<idx>", builds its private
// crystal / annealing state with worker_mk(), and then answers requests until
// CMD_FIN arrives or a socket operation fails. Each request sits in
// worker->buf: byte 0 is the command, byte 1 becomes the status of the reply
// (non-zero on success), and 32-bit words from index 1 on carry the payload
// in network byte order. The reply for command c is sizes[c][1] bytes long.
//
// If worker_mk() failed the worker keeps serving, but answers every request
// with status 0 so that the failure is reported through the protocol.
//
// Always returns NULL; the outcome is stored in worker->ret (1 only after a
// clean CMD_FIN shutdown with working state and a successful socket close).
//
// The parameter is `void *' so that the function has exactly the type
// pthread_create() calls it through; calling a function that takes
// `struct worker_t *' through a `void *(*)(void *)' pointer is undefined.
static void *worker_main (void *worker_) {
	struct worker_t *worker = worker_;
	rng_t *rng = NULL;
	crystal *cr = NULL;
	sa_comp *comp = NULL;
	void *client;
	char const *desc = NULL;
	char addr[ADDR_MAX];
	float const *scores = NULL, *best = NULL;
	// ctl[0..6] come with each CMD_COMP, ctl[7] is set by CMD_INIT.
	// n[0] and n[2] are set by CMD_INIT, n[1] by CMD_COMP, n[3] is an output
	// of sa_comp_get(). Zero-initialised so that a CMD_COMP arriving before
	// any CMD_INIT fails the `n[1] < n[2]' check instead of reading
	// indeterminate values.
	float ctl[8] = { 0 }, stat[3];
	uint8_t *cb = worker->buf;
	uint32_t *ub = worker->buf;
	unsigned const (*sizes)[2] = worker->sizes, *bbuf = NULL;
	unsigned n[4] = { 0 }, i;
	int flag = 0;
	if (!addr_print (addr, sizeof (addr), worker->idx)) ERR_GOTO
		("failed snprintf()", client_err)
	if (!(client = zmq_socket (worker->ctx, ZMQ_REP))) ERR_GOTO
		("failed zmq_socket()", client_err)
	if (zmq_bind (client, addr)) ERR_GOTO ("failed zmq_bind()", worker_err)
	// `flag' records whether rng / cr / comp are valid; see header comment.
	flag = worker_mk (&rng, &cr, &comp, worker->crin);
	if (flag) {
		scores = cryst_scores (cr);
		best = sa_comp_best (comp);
		bbuf = sa_comp_bbuf (comp);
	}
	while (1) {
		if (zmq_recv (
			client, worker->buf, sizes[CMD_CNT][0], 0
		) == -1) ERR_GOTO ("failed zmq_recv()", retn_err)
		if (flag) switch (cb[0]) {
		case CMD_COMP:
			// Payload: 7 floats (-> ctl[0..6]) followed by one count
			// (-> n[1]).
			memcpy (ctl, ub + 1, 7 * sizeof (uint32_t));
			nltohf2 (ctl, 7);
			n[1] = ntohl (ub[8]);
			cb[1] = arefinite (ctl, 7) && ctl[6] >= 0.0 && n[1] < n[2] &&
				sa_comp_get (comp, n, ctl, n + 3, stat);
			if (cb[1]) {
				// Reply: the 3 floats of `stat', then n[3].
				memcpy (ub + 1, stat, 3 * sizeof (uint32_t));
				hftonl2 (ub + 1, 3);
				ub[4] = htonl (n[3]);
			}
			break;
		case CMD_DUMP:
			// Reply: 3 current scores, then whatever cryst_dump() writes.
			cb[1] = 1;
			for (i = 0; i < 3; ++i) ub[i + 1] = hftonl (scores[i]);
			cryst_dump (cr, ub + 4);
			break;
		case CMD_BEST:
			// Reply: 3 best scores, then sizes[CMD_CNT][1] bytes taken from
			// the buffer returned by sa_comp_bbuf().
			cb[1] = 1;
			for (i = 0; i < 3; ++i) ub[i + 1] = hftonl (best[i]);
			memcpy (ub + 4, bbuf, sizes[CMD_CNT][1]);
			break;
		case CMD_LOAD:
			// Load a state from the payload; acknowledge it only if it also
			// evaluates successfully.
			if ((
				cb[1] = cryst_load (cr, ub + 1) && cryst_eval (cr)
			)) cryst_ack (cr, 1);
			break;
		case CMD_INIT:
			// Store the parameters that later CMD_COMP requests pass on to
			// sa_comp_get().
			ctl[7] = nltohf (ub[1]);
			n[0] = ntohl (ub[2]);
			n[2] = ntohl (ub[3]);
			cb[1] = isfinite (ctl[7]) && ctl[7] > 1.0 && n[0] && n[2];
			break;
		default:
			// Nothing to compute (this includes CMD_FIN).
			break;
		} else cb[1] = 0;
		if (zmq_send (
			client, worker->buf, sizes[cb[0]][1], 0
		) == -1) ERR_GOTO ("failed zmq_send()", retn_err)
		if (cb[0] == CMD_FIN) {
			// Without working state the shutdown does not count as clean.
			if (flag) break;
			else goto retn_err;
		}
	}
	worker->ret = 1;
retn_err:
	if (flag) { worker_fin (rng, cr, comp); }
worker_err:
	if (!mymq_close (client)) { worker->ret = 0; }
client_err:
	return ERR_PROC(NULL);
}
static struct hosts_t *workers_run (
void *ctx, struct crin_t const *crin, struct worker_t workers[],
unsigned const sizes[CMD_CNT + 1][2], unsigned n
) {
struct ihost_t *ihosts;
struct hosts_t *hosts;
char const *desc = NULL;
char addr[ADDR_MAX];
unsigned i;
for (i = 0; i < n; ++i) {
workers[i].crin = NULL;
if (!(workers[i].buf = malloc (sizes[CMD_CNT][0]))) ERR_GOTO
("failed malloc()", ihosts_err)
workers[i].crin = crin;
workers[i].ctx = ctx;
workers[i].sizes = sizes;
workers[i].idx = i;
if (pthread_create (
&workers[i].thread, NULL,
(void *(*) (void *)) &worker_main, workers + i
)) {
free (workers[i].buf);
workers[i].crin = NULL;
ERR_GOTO ("failed pthread_create()", ihosts_err)
}
}
if (!(ihosts = calloc (n, sizeof (struct ihost_t)))) ERR_GOTO
("failed calloc()", ihosts_err)
for (i = 0; i < n; ++i) {
if (!addr_print (addr, sizeof (addr), i)) ERR_GOTO
("failed snprintf()", retn_err)
if (!(ihosts[i].socket = zmq_socket (ctx, ZMQ_REQ))) ERR_GOTO
("failed zmq_socket()", retn_err)
if (zmq_connect (ihosts[i].socket, addr)) ERR_GOTO
("failed zmq_connect()", retn_err)
ihosts[i].n = 1;
}
hosts = hosts_mk (ihosts, n);
free (ihosts);
return hosts; retn_err:
for (i = 0; i < n; ++i) {
if (ihosts[i].socket) mymq_close (ihosts[i].socket);
else break;
}
free (ihosts); ihosts_err:
return ERR_PROC(NULL);
}
static int workers_chk (struct worker_t workers[], unsigned n) {
int ret = 1;
for (unsigned i = 0; i < n; ++i) {
if (!workers[i].crin) break;
if (pthread_join (workers[i].thread, NULL) || !workers[i].ret) ret = 0;
free (workers[i].buf);
}
return ret;
}
// Entry point: `decr_sas num_of_cores zmq_addr < cryst.cr'.
//
// Reads a crystal description from stdin, starts `num_of_cores' worker
// threads and serves a ZMQ_REP socket bound to `zmq_addr'. Every request is a
// batch of `n' equally sized sub-messages (one per worker) carrying the same
// command; the batch is scattered to the workers with vec_send(), their
// replies are gathered with vec_recv() and sent back as one message. The
// first request must be CMD_INIT; CMD_FIN ends the session and gets no reply.
//
// Returns 0 on a clean session and 1 on any error.
int main (int argc, char const *const *argv) {
	struct worker_t *workers;
	struct hosts_t *hosts;
	// NULL means "nothing to release"; see the cleanup section.
	struct crin_t *crin = NULL;
	void *ctx, *client, *buf;
	char const *desc = NULL;
	uint8_t cmd;
	// sizes[cmd] = { request bytes, reply bytes }. Entries that depend on the
	// crystal are filled in below; row CMD_CNT holds derived buffer sizes.
	unsigned sizes[CMD_CNT + 1][2] = {
		{ 2 * sizeof (uint8_t), 0 },
		{ sizeof (struct msg_init), 2 * sizeof (uint8_t) },
		{ sizeof (struct msg_comp), sizeof (struct msg_cret) },
		{ 2 * sizeof (uint8_t), 0 },
		{ 2 * sizeof (uint8_t), 0 },
		{ 0, 2 * sizeof (uint8_t) },
		{ 0 }
	}, n;
	// flags[0]: overall success. flags[1]: every worker accepted the
	// previous command.
	int flags[2] = { 0 };
	if (!(argc == 3 && scan2_unsigned (argv[1], &n))) {
		fprintf (stderr, "Usage: decr_sas num_of_cores zmq_addr < cryst.cr\n");
		goto workers_err;
	}
	if (!(workers = calloc (n, sizeof (struct worker_t)))) ERR_GOTO
		("failed calloc() for `workers'", workers_err)
	if (!(ctx = zmq_ctx_new ())) ERR_GOTO ("failed zmq_ctx_new()", ctx_err)
	if (!(client = zmq_socket (ctx, ZMQ_REP))) ERR_GOTO
		("failed zmq_socket()", client_err)
	if (zmq_bind (client, argv[2])) ERR_GOTO ("failed zmq_bind()", crin_err)
	if (!(crin = crin_read (stdin))) ERR_GOTO ("failed crin_read()", crin_err)
	// sizes[CMD_CNT][1] temporarily holds the value returned by crin_dof()
	// (a count of 32-bit words); it is converted to bytes further down.
	if (!(sizes[CMD_CNT][1] = crin_dof (crin))) ERR_GOTO ("dof == 0", hosts_err)
	sizes[CMD_DUMP][1] = sizes[CMD_BEST][1] =
		(sizes[CMD_CNT][1] + 4) * sizeof (uint32_t);
	sizes[CMD_LOAD][0] = (sizes[CMD_CNT][1] + 1) * sizeof (uint32_t);
	// Per-worker buffer size: the larger of a dump reply and a comp request.
	sizes[CMD_CNT][0] = sizes[CMD_DUMP][1] < sizes[CMD_COMP][0] ?
		sizes[CMD_COMP][0] : sizes[CMD_DUMP][1];
	sizes[CMD_CNT][1] *= sizeof (uint32_t);
	if (!(hosts = workers_run (ctx, crin, workers, sizes, n))) goto hosts_err;
	if (!(buf = calloc (n, sizes[CMD_CNT][0]))) ERR_GOTO
		("failed calloc() for `buf'", buf_err)
	if (!recv_chk (client, buf, &cmd, sizes, n) || cmd != CMD_INIT) ERR_GOTO
		("error with initial request", retn_err)
	if (
		!vec_send (hosts, buf, sizes[cmd][0]) ||
		!vec_recv (hosts, buf, sizes[cmd][1]) ||
		zmq_send (client, buf, n * sizes[cmd][1], 0) == -1
	) ERR_GOTO ("error with initial reply", retn_err)
	// Every worker has answered, hence has finished worker_mk() and no longer
	// reads `crin'.
	crin_fin (crin); crin = NULL;
	while (1) {
		// Did every worker report success for the command just answered?
		flags[1] = send_chk (buf, sizes[cmd][1], n);
		// After a failure the only command still accepted is CMD_FIN.
		if (
			!recv_chk (client, buf, &cmd, sizes, n) ||
			cmd >= CMD_CNT || cmd == CMD_INIT ||
			!(flags[1] || cmd == CMD_FIN)
		) ERR_GOTO ("error with request", retn_err)
		if (
			!vec_send (hosts, buf, sizes[cmd][0]) ||
			!vec_recv (hosts, buf, sizes[cmd][1])
		) ERR_GOTO ("error with reply", retn_err)
		if (cmd == CMD_FIN) break;
		if (zmq_send (client, buf, n * sizes[cmd][1], 0) == -1) ERR_GOTO
			("error with reply", retn_err)
	}
	flags[0] = 1;
retn_err:
	free (buf);
buf_err:
	hosts_fin (hosts);
hosts_err:
crin_err:
	if (!mymq_close (client)) { flags[0] = 0; }
client_err:
	if (zmq_ctx_destroy (ctx)) { flags[0] = 0; }
ctx_err:
	if (!workers_chk (workers, n)) flags[0] = 0;
	// On an early error the workers may still be inside worker_mk() reading
	// `crin', so it is released only after workers_chk() has joined them.
	if (crin) { crin_fin (crin); }
	free (workers);
workers_err:
	return ERR_PROC(!flags[0]);
}