diff --git a/mdns.c b/mdns.c index 6c91bf2..cfdd501 100644 --- a/mdns.c +++ b/mdns.c @@ -47,11 +47,6 @@ struct name_comp { // ----- label functions ----- -// compares 2 names -static inline int cmp_nlabel(const uint8_t *L1, const uint8_t *L2) { - return strcmp((char *) L1, (char *) L2); -} - // duplicates a name inline uint8_t *dup_nlabel(const uint8_t *n) { assert(n[0] <= 63); // prevent mis-use @@ -463,6 +458,23 @@ struct rr_entry *rr_entry_find(struct rr_list *rr_list, uint8_t *name, uint16_t return NULL; } +// looks for a matching entry in rr_list +// if entry is a PTR, we need to check if the PTR target also matches +struct rr_entry *rr_entry_match(struct rr_list *rr_list, struct rr_entry *entry) { + struct rr_list *rr = rr_list; + for (; rr; rr = rr->next) { + if (rr->e->type == entry->type && cmp_nlabel(rr->e->name, entry->name) == 0) { + if (entry->type != RR_PTR) { + return rr->e; + } else if (cmp_nlabel(MDNS_RR_GET_PTR_NAME(entry), MDNS_RR_GET_PTR_NAME(rr->e)) == 0) { + // if it's a PTR, we need to make sure PTR target also matches + return rr->e; + } + } + } + return NULL; +} + void rr_group_destroy(struct rr_group *group) { struct rr_group *g = group; diff --git a/mdns.h b/mdns.h index 210bc26..332d721 100644 --- a/mdns.h +++ b/mdns.h @@ -31,6 +31,7 @@ #include #include +#include #define MALLOC_ZERO_STRUCT(x, type) \ x = malloc(sizeof(struct type)); \ @@ -131,6 +132,9 @@ struct rr_group { #define MDNS_FLAG_GET_RCODE(x) (x & 0x0F) #define MDNS_FLAG_GET_OPCODE(x) ((x >> 11) & 0x0F) +// gets the PTR target name, either from "name" member or "entry" member +#define MDNS_RR_GET_PTR_NAME(rr) (rr->data.PTR.name != NULL ? rr->data.PTR.name : rr->data.PTR.entry->name) + struct mdns_pkt { uint16_t id; // transaction ID uint16_t flags; @@ -154,6 +158,7 @@ void mdns_pkt_destroy(struct mdns_pkt *p); void rr_group_destroy(struct rr_group *group); struct rr_group *rr_group_find(struct rr_group *g, uint8_t *name); struct rr_entry *rr_entry_find(struct rr_list *rr_list, uint8_t *name, uint16_t type); +struct rr_entry *rr_entry_match(struct rr_list *rr_list, struct rr_entry *entry); void rr_group_add(struct rr_group **group, struct rr_entry *rr); int rr_list_count(struct rr_list *rr); @@ -175,4 +180,9 @@ uint8_t *dup_label(const uint8_t *label); uint8_t *dup_nlabel(const uint8_t *n); uint8_t *join_nlabel(const uint8_t *n1, const uint8_t *n2); +// compares 2 names +static inline int cmp_nlabel(const uint8_t *L1, const uint8_t *L2) { + return strcmp((char *) L1, (char *) L2); +} + #endif /*!__MDNS_H__*/ diff --git a/mdnsd.c b/mdnsd.c index b671ac1..37ead2b 100644 --- a/mdnsd.c +++ b/mdnsd.c @@ -157,7 +157,6 @@ static ssize_t send_packet(int fd, const void *data, size_t len) { // type can be RR_ANY, which populates all entries EXCEPT RR_NSEC static int populate_answers(struct mdnsd *svr, struct rr_list **rr_head, uint8_t *name, enum rr_type type) { int num_ans = 0; - struct rr_entry *rr; // check if we have the records pthread_mutex_lock(&svr->data_lock); @@ -167,22 +166,18 @@ static int populate_answers(struct mdnsd *svr, struct rr_list **rr_head, uint8_t return num_ans; } - // include all records? - if (type == RR_ANY) { - struct rr_list *n = ans_grp->rr; - for (; n; n = n->next) { - // exclude NSEC - if (n->e->type == RR_NSEC) - continue; + // decide which records should go into answers + struct rr_list *n = ans_grp->rr; + for (; n; n = n->next) { + // exclude NSEC for RR_ANY + if (type == RR_ANY && n->e->type == RR_NSEC) + continue; + if (type == n->e->type && cmp_nlabel(name, n->e->name) == 0) { num_ans += rr_list_append(rr_head, n->e); } - } else { - // match record type - rr = rr_entry_find(ans_grp->rr, name, type); - if (rr) - num_ans += rr_list_append(rr_head, rr); } + pthread_mutex_unlock(&svr->data_lock); return num_ans; @@ -197,10 +192,7 @@ static void add_related_rr(struct mdnsd *svr, struct rr_list *list, struct mdns_ case RR_PTR: // target host A, AAAA records reply->num_add_rr += populate_answers(svr, &reply->rr_add, - (ans->data.PTR.name ? - ans->data.PTR.name : - ans->data.PTR.entry->name), - RR_ANY); + MDNS_RR_GET_PTR_NAME(ans), RR_ANY); break; case RR_SRV: @@ -266,6 +258,8 @@ static int process_mdns_pkt(struct mdnsd *svr, struct mdns_pkt *pkt, struct mdns struct rr_list *qnl = pkt->rr_qn; for (i = 0; i < pkt->num_qn; i++, qnl = qnl->next) { struct rr_entry *qn = qnl->e; + int num_ans_added = 0; + char *namestr = nlabel_to_str(qn->name); DEBUG_PRINTF("qn #%d: type 0x%02x %s - ", i, qn->type, namestr); free(namestr); @@ -276,17 +270,36 @@ static int process_mdns_pkt(struct mdnsd *svr, struct mdns_pkt *pkt, struct mdns continue; } - // see if it is in the answers - if (rr_entry_find(pkt->rr_ans, qn->name, qn->type)) { - DEBUG_PRINTF("our record is already in their answers\n"); - continue; + num_ans_added = populate_answers(svr, &reply->rr_ans, qn->name, qn->type); + reply->num_ans_rr += num_ans_added; + + DEBUG_PRINTF("added %d answers\n", num_ans_added); + } + + // remove our replies if they were already in their answers + struct rr_list *ans = NULL, *prev_ans = NULL; + for (ans = reply->rr_ans; ans; ) { + struct rr_list *next_ans = ans->next; + + if (rr_entry_match(pkt->rr_ans, ans->e)) { + // check if list item is head + if (prev_ans == NULL) + reply->rr_ans = ans->next; + else + prev_ans->next = ans->next; + free(ans); + + ans = prev_ans; + + // adjust answer count + reply->num_ans_rr--; } - reply->num_ans_rr += populate_answers(svr, &reply->rr_ans, qn->name, qn->type); - - DEBUG_PRINTF("adding %d answers\n", reply->num_ans_rr); + prev_ans = ans; + ans = next_ans; } + // see if we can match additional records for answers add_related_rr(svr, reply->rr_ans, reply);