/* terms.c */

/*  File   : terms.c
    Author : Richard A. O'Keefe.
    Updated: 2010
    Purpose: Terms and Unification in C
*/
#include <ctype.h>
#include <stddef.h>
#include <stdlib.h>
#include <string.h>
#include <stdio.h>

/*  A term is
    - a VARIABLE
      + which may be UNBOUND
      + or BOUND to some term
    - or a NON-VARIABLE, which consists of
      a FUNCTION SYMBOL and a sequence of
      zero or more ARGUMENTS, which are terms.
      The number of arguments is called the ARITY.
      The combination of a function symbol and an arity,
      commonly written f/n, is called a FUNCTOR.

    A function symbol is an uninterpreted atom.

    Programming languages based on first order logic commonly
    admit numbers as a kind of non-variable term, in which case
    we distinguish between numbers and callable terms.
    An ATOMIC term is either a number or a function symbol with
    no arguments; a COMPOUND term is a term with 1 or more
    arguments.

    Issues:
      representing terms
      classifying terms
      dereferencing terms (returning something that is not a bound variable)
      testing whether an unbound variable occurs in a term
      reading terms
      writing terms
      testing whether two terms are identical
      binding (and trailing) a variable
      unifying two terms.

    Syntax:
      term
        : variable
        | atom
        | atom '(' term_sequence ')'
        | list
        ;
      term_sequence
        : term [ ',' term_sequence ]
        ;
      list
        : '[' [term_sequence ['|' term]] ']'
        ;
      variable
        : /[_A-Z][A-Za-z0-9_]* /
        ;
      atom
        : /[a-z][A-Za-z0-9_]* /
        | /'([^']|'')*'/
        ;

    Lists are just a handy abbreviation.
    [] is the same as '[]'.
    [t1,t2,...,tn] is the same as [t1,t2,...,tn|[]]
    [t1|tt] is the same as '.'(t1,tt)
    [t1,t2,...,tn|tt] is the same as '.'(t1,[t2,...,tn|tt])

*/

/* ---------------------------------------------------------------------

    The string hash table holds a unique copy of each string.

*/
/* Pointer handles used throughout the file: a ustr points at an
   interned string record, a term points at a term record
   (struct Term_Info, declared further down). */
typedef struct Unique_String *ustr;
typedef struct Term_Info *term;


/* One entry of the string table.  find_mem() allocates each entry with
   enough room after the fixed fields for the whole string plus a
   terminating NUL, so the declared size of `name` is only nominal. */
struct Unique_String {
    /* next entry in the same hash bucket */
    ustr            next;
    /* variable term cached by make_variable() for this name */
    term            variable;   /* only used while reading */
    /* atom term cached by make_atom() for this name */
    term            atom;       /* also used when reading */
    size_t          quotes;     /* how many extra ' does it need */
    /* length of the string in bytes, not counting the NUL */
    size_t          size;
    /* the characters themselves, followed by a NUL */
    char            name[4];
};

/* Geometry of the string table: a directory of at most DIRSIZE
   segments, each segment being an array of SEGSIZE bucket heads.
   SEGBITS and DIRBITS are the base-2 logarithms of those sizes. */
#define SEGBITS 9
#define SEGSIZE (1 << SEGBITS)
typedef ustr segment[SEGSIZE];
/* Average number of keys per bucket tolerated before find_mem()
   splits another bucket (max_load is kept at MAXLOAD * buckets). */
#define MAXLOAD 4
#define DIRBITS 9
#define DIRSIZE (1 << DIRBITS)
/* Number of segments that init_strtab() allocates up front;
   it must fit in the directory. */
#define INIT_SEGS 8
#if INIT_SEGS > DIRSIZE
#error oops
#endif

/* State of the incrementally growing hash table used by find_mem().
   Buckets are addressed by a number whose high bits select a directory
   slot and whose low SEGBITS bits select a bucket inside that segment. */
struct HASHTABLE {
    unsigned long p;            /* next bucket to split */
    unsigned long maxp;         /* upper bound on p during expansion */
    unsigned long max_load;     /* maximum load factor (in keys): */
                                /* n_segs*SEGSIZE*MAXLOAD */
    unsigned long n_keys;       /* number of keys */
    unsigned long n_segs;       /* number of segments */
    /* segment pointers; slots that have no segment yet hold 0 */
    segment      *directory[DIRSIZE];
};

/* The single table instance; set up by init_strtab(). */
static struct HASHTABLE string_table;

/*  Allocate `size` bytes with malloc().
    On failure, report the error with perror() and terminate the
    program with EXIT_FAILURE, so callers never see a null pointer.
    Returns the newly allocated (uninitialised) storage.
*/
static void *emalloc(size_t size) {
    void *result = malloc(size);
    if (result == 0) {
        perror("malloc");
        exit(EXIT_FAILURE);
    }
    /* The original fell off the end here, handing callers garbage. */
    return result;
}

/*  Put the string table into its initial empty state: INIT_SEGS freshly
    allocated segments whose buckets are all empty, and every remaining
    directory slot cleared.
*/
void init_strtab(void) {
    struct HASHTABLE * const table = &string_table;
    int slot;

    table->n_keys = 0;
    table->p = 0;
    table->n_segs = INIT_SEGS;
    table->maxp = table->n_segs << SEGBITS;
    table->max_load = MAXLOAD * table->maxp;

    slot = 0;
    /* the leading directory slots get real, emptied segments */
    while (slot < (int)table->n_segs) {
        segment * const fresh = emalloc(sizeof *fresh);
        int bucket = 0;

        while (bucket < SEGSIZE) {
            (*fresh)[bucket] = 0;
            bucket++;
        }
        table->directory[slot] = fresh;
        slot++;
    }
    /* the rest of the directory has no segment yet */
    while (slot < DIRSIZE) {
        table->directory[slot] = 0;
        slot++;
    }
}

/*  Hash the `len` bytes starting at `str`.
    Whole groups of four bytes are folded in with a multiply-by-31 step,
    each byte of a group shifted three bits further than the previous
    one; the 0..3 bytes left over are added, with the same per-position
    shifts, without a further multiplication.  No byte at or beyond
    str[len] is read.
*/
static unsigned long hash_string(char const *str, size_t len) {
    unsigned char const *bytes = (unsigned char const *)str;
    size_t groups = len / 4;            /* complete 4-byte groups */
    size_t const tail = len % 4;        /* bytes after the last group */
    unsigned long hash = 0;

    for (; groups != 0; groups--, bytes += 4) {
        hash = hash * 31
             + (bytes[0] + (bytes[1] << 3) + (bytes[2] << 6) + (bytes[3] << 9));
    }
    if (tail >= 3) hash += bytes[2] << 6;
    if (tail >= 2) hash += bytes[1] << 3;
    if (tail >= 1) hash += bytes[0];
    return hash;
}

/*  Work out how many quote characters are needed, beyond the name's own
    bytes, to print the `len`-byte name `str` as an atom.
    Returns 0 when the name is a lower-case letter followed only by
    letters, digits and underscores, so it can be printed bare.
    Otherwise returns 2 (the enclosing quotes) plus one for every
    apostrophe inside the name, since each of those gets doubled.

    Bytes are classified as unsigned char: handing a negative plain char
    to the <ctype.h> functions is undefined behaviour.
*/
static size_t count_quotes(char const *str, size_t len) {
    unsigned char const *s = (unsigned char const *)str;
    size_t i, q;

    if (len != 0 && islower(s[0])) {
        i = 1;
        while (i != len && (isalnum(s[i]) || s[i] == '_')) i++;
        if (i == len) return 0;
    }
    q = 2;
    for (i = 0; i != len; i++) if (s[i] == '\'') q++;
    return q;
}

ustr find_mem(char const *str, size_t len) {
    struct HASHTABLE * const h = &string_table;
    size_t const pad = ((0u-len) & ((sizeof (unsigned long)) - 1)) + 1;
    unsigned long const u0 = hash_string(str, len);

    {
        unsigned long const u1 = u0 & (h->maxp - 1);
        unsigned long const u  = u1 >= h->p ? u1 : u0 & ((h->maxp << 1) - 1);
        ustr * const b = &(*h->directory[u>>SEGBITS])[u & (SEGSIZE - 1)];
        ustr *p;        /* previous */
        ustr  c;        /* current  */

        for (p = b; (c = *p) != 0; p = &c->next) {
            if (c->size == len && 0 == memcmp(c->name, str, len)) {
                *p = c->next;
                c->next = *b;
                *b = c;
                return c;
            }
        }
        c = emalloc(offsetof(struct Unique_String, name) + len + pad);
        c->quotes = count_quotes(str, len);
        c->size = len;
        (void)memcpy(c->name, str, len);
        c->name[len] = '\0';
        c->next = *b;
        *b = c;
        if (++h->n_keys > h->max_load) {
            unsigned long const oldp = h->p;
            unsigned long const newp = oldp + h->maxp;

            if (newp < DIRSIZE * (unsigned long)SEGSIZE) {
                segment * const olds = h->directory[h->p >> SEGBITS];
                segment * news;

                if ((newp & (SEGSIZE-1)) == 0) {
                    int j;

                    news = emalloc(sizeof *news);
                    for (j = 0; j < SEGSIZE; j++) (*news)[j] = 0;
                    h->n_segs++;
                    h->directory[newp >> SEGBITS] = news;
                } else {
                    news = h->directory[newp >> SEGBITS];
                }
                h->max_load += MAXLOAD;
                {
                    unsigned long const mask = (h->maxp << 1) - 1;
                    ustr *prev = &(*olds)[oldp & (SEGSIZE-1)];
                    ustr *tail = &(*news)[newp & (SEGSIZE-1)];
                    ustr  curr;

                    while ((curr = *prev) != 0) {
                        unsigned long const slot =
                            mask & hash_string(curr->name, curr->size);

                        if (slot == newp) {
                            *tail = curr;
                            tail = &curr->next;
                            *prev = curr->next;
                        } else {
                            prev = &curr->next;
                        }
                    }
                    *tail = 0;
                }
                if (++h->p == h->maxp) {
                    h->maxp <<= 1;
                    h->p = 0;
                }
            }
        }
        return c;
    }
}


/*----------------------------------------------------------------------
    We'll represent a term by a pointer to a record containing
       - the arity, using -1 for an unbound variable or -2 for a bound one
       - if the term is a variable, a pointer for its value
       - if the term is not a variable, a pointer to a unique string
         for its function symbol, and 0 or more pointers for its
         arguments.

    We're not going to bother using hash consing for these terms.
    Function symbols do not change at run time, but terms do.
*/
/* A term record.  `arity` doubles as the tag: a negative value marks a
   variable (see the two constants below), zero or more is the number of
   entries in `arg`. */
struct Term_Info {
    int arity;
    union {
        /* variables: the bound value; the constructors in this file
           point an unbound variable at itself */
        term binding;
        /* non-variables: the function symbol */
        ustr name;
    } u;
    /* the arguments; storage for them is allocated along with the record */
    term arg[];
};

/* Values of `arity` that mark a variable. */
#define BOUND_VARIABLE (-2)
#define UNBOUND_VARIABLE (-1)

term make_anonymous_variable(void) {
    term r = emalloc(offsetof(struct Term_Info, arg[0]));
    r->arity = UNBOUND_VARIABLE;
    r->u.binding = r;
    return r;
}

/*  Return the variable term recorded for `name`, creating an unbound
    one and remembering it in name->variable on first use.
*/
term make_variable(ustr name) {
    term var = name->variable;

    if (var != 0) return var;

    var = emalloc(offsetof(struct Term_Info, arg));
    var->u.binding = var;
    var->arity = UNBOUND_VARIABLE;
    name->variable = var;
    return var;
}

/*  Return the atom (arity-0 term) whose function symbol is `name`.
    Atoms are shared purely for efficiency: the first request builds
    the record and stores it in name->atom, later requests reuse it.
*/
term make_atom(ustr name) {
    term atom = name->atom;

    if (atom != 0) return atom;

    atom = emalloc(offsetof(struct Term_Info, arg));
    atom->u.name = name;
    atom->arity = 0;
    name->atom = atom;
    return atom;
}

/*  Build a new term with function symbol `name` and `arity` arguments
    copied from arg[0..arity-1].  The result is always a fresh record;
    it is not shared the way atoms made by make_atom() are.
*/
term make_term(ustr name, int arity, term const arg[]) {
    term const node = emalloc(offsetof(struct Term_Info, arg[arity]));
    int k = 0;

    node->u.name = name;
    node->arity = arity;
    while (k < arity) {
        node->arg[k] = arg[k];
        k++;
    }
    return node;
}

/*----------------------------------------------------------------------

    Parsing

*/
/*  Skip layout (any character not above ' ') and %-to-end-of-line
    comments on standard input.
    Returns 1 if a significant character follows, having pushed it back
    for the next read; returns 0 at end of input.
*/
static int more_input(void) {
    for (;;) {
        int ch = getchar();

        if (ch == '%') {
            /* discard the comment up to the newline or end of input */
            do {
                ch = getchar();
            } while (ch >= 0 && ch != '\n');
        }
        if (ch < 0) return 0;
        if (ch > ' ') {
            ungetc(ch, stdin);
            return 1;
        }
    }
}

/*  Read and return the next significant character from standard input,
    skipping layout (any character not above ' ') and %-to-end-of-line
    comments.  The character is consumed.  Hitting end of input is
    fatal: a message is written to stderr and the program exits.
*/
static int next_non_blank(void) {
    for (;;) {
        int ch = getchar();

        if (ch == '%') {
            /* discard the comment up to the newline or end of input */
            do {
                ch = getchar();
            } while (ch >= 0 && ch != '\n');
        }
        if (ch < 0) {
            fprintf(stderr, "Unexpected end of file\n");
            exit(EXIT_FAILURE);
        }
        if (ch > ' ') return ch;
    }
}

/*  Fatal error: a variable name or atom does not fit in the reader's
    name buffer.  Writes a message to stderr and exits.
*/
static void name_buffer_overflow(void) {
    fputs("Variable name or atom too long.\n", stderr);
    exit(EXIT_FAILURE);
}

/*  Fatal error: a term or list has more elements than the reader's
    argument buffer can hold.  Writes a message to stderr and exits.
*/
static void term_buffer_overflow(void) {
    fputs("Term or list too long.\n", stderr);
    exit(EXIT_FAILURE);
}

/* The atom '[]' and the interned name '.'; both start out as 0 and are
   filled in by read_term() the first time it needs them.  dot_name is
   also what is_dotted_pair() compares against. */
static term nil_atom = 0;
static ustr dot_name = 0;

/*  Read one term from standard input and return it.

    Accepted forms (leading layout and % comments are skipped):
      - a variable: an upper-case letter or '_' followed by letters,
        digits and underscores; a lone '_' is a fresh anonymous variable,
        any other name always denotes the same variable;
      - a list: '[' ']' or '[' terms separated by ',' with an optional
        '|' tail ']', built from '.'/2 cells ending in '[]' or the tail;
      - an atom: a lower-case identifier or a 'quoted' name in which ''
        stands for one quote, optionally followed immediately by '(' and
        a comma-separated argument sequence ')'.

    Any syntax error, an over-long name (more than 1024 bytes), an
    over-long argument sequence or list (more than 100 elements) or an
    unexpected end of input writes a message to stderr and exits.  That
    includes a term starting with a character that cannot begin one:
    previously this was silently read as an atom with an empty name.
*/
term read_term(void) {
    int c;
    char name[1024];
    char *p = name;
    char * const e = name + sizeof name;
    term args[100];
    term *t = args;
    /* one past the last element: an element count, not a byte count */
    term * const x = args + sizeof args / sizeof args[0];

    c = next_non_blank();

    if (isupper(c) || c == '_') {
        /* variable */
        do {
            if (p == e) name_buffer_overflow();
            *p++ = (char)c;
            c = getchar();
        } while (isalnum(c) || c == '_');
        ungetc(c, stdin);
        if (p == name+1 && name[0] == '_') {
            return make_anonymous_variable();
        }
        return make_variable(find_mem(name, (size_t)(p-name)));
    }

    if (c == '[') {
        /* list */
        term r;

        if (nil_atom == 0) nil_atom = make_atom(find_mem("[]", 2));
        c = next_non_blank();
        if (c == ']') return nil_atom;

        ungetc(c, stdin);
        do {
            if (t == x) term_buffer_overflow();
            *t++ = read_term();
            c = next_non_blank();
        } while (c == ',');
        if (c == '|') {
            r = read_term();
            c = next_non_blank();
        } else {
            r = nil_atom;
        }
        if (c != ']') {
            fprintf(stderr, "Missing ]\n");
            exit(EXIT_FAILURE);
        }
        if (dot_name == 0) dot_name = find_mem(".", 1);
        /* wrap the elements around the tail, last element first */
        while (t != args) {
            term pair[2];

            pair[0] = *--t;
            pair[1] = r;
            r = make_term(dot_name, 2, pair);
        }
        return r;
    }

    if (islower(c)) {
        /* plain atom or function symbol */
        do {
            if (p == e) name_buffer_overflow();
            *p++ = (char)c;
            c = getchar();
        } while (isalnum(c) || c == '_');
    } else if (c == '\'') {
        /* quoted atom or function symbol */
        for (;;) {
            c = getchar();
            if (c < 0) {
                fprintf(stderr, "Unexpected EOF\n");
                exit(EXIT_FAILURE);
            }
            if (c == '\'') {
                /* '' is a literal quote, anything else ends the name */
                c = getchar();
                if (c != '\'') break;
            }
            if (p == e) name_buffer_overflow();
            *p++ = (char)c;
        }
    } else {
        fprintf(stderr, "A term cannot begin with '%c'\n", c);
        exit(EXIT_FAILURE);
    }

    /* here c is the character that followed the name */
    if (c != '(') {
        ungetc(c, stdin);
        return make_atom(find_mem(name, (size_t)(p-name)));
    }

    do {
        if (t == x) term_buffer_overflow();
        *t++ = read_term();
        c = next_non_blank();
    } while (c == ',');
    if (c != ')') {
        fprintf(stderr, "Missing )\n");
        exit(EXIT_FAILURE);
    }
    return make_term(find_mem(name, (size_t)(p-name)), (int)(t-args), args);
}


/*----------------------------------------------------------------------

    Classifying and decomposing terms

*/

/* Classification tests on a term pointer.  They look only at the record
   they are given and do not follow variable bindings themselves. */
#define is_bound_variable(t)   ((t)->arity == BOUND_VARIABLE)
#define is_unbound_variable(t) ((t)->arity == UNBOUND_VARIABLE)
#define is_variable(t)         ((t)->arity < 0)
#define is_not_variable(t)     ((t)->arity >= 0)
#define is_atom(t)             ((t)->arity == 0)
#define is_compound(t)         ((t)->arity > 0)
/* a two-argument term whose function symbol is the interned '.' */
#define is_dotted_pair(t)      ((t)->arity == 2 && (t)->u.name == dot_name)
/* the term a variable record points at */
#define variable_binding(t)    ((t)->u.binding)

/*  Follow a chain of bound variables starting at `t` and return the
    first term that is not a bound variable: an unbound variable, an
    atom or a compound term.
*/
term dereference(term t) {
    for (;;) {
        if (!is_bound_variable(t)) return t;
        t = variable_binding(t);
    }
}

/*  Return argument number `i` (counting from 0) of `t` after
    dereferencing it.  An index outside 0..arity-1, which is every index
    when `t` turns out to be a variable or an atom, is a programming
    error: a message is written to stderr and the program aborts.
*/
term argument(term t, int i) {
    term const target = dereference(t);

    if (0 <= i && i < target->arity) return target->arg[i];

    fprintf(stderr, "argument index out of range\n");
    abort();
}

/*----------------------------------------------------------------------

    Printing

*/

/*  Write `t` to standard output.
    An unbound variable is written as 'V' followed by the address of its
    record.  Anything else is written as its function symbol, enclosed
    in quotes with embedded quotes doubled when the symbol's quote count
    is non-zero, followed for a compound term by the parenthesised,
    comma-separated arguments.  Bound variables are followed to their
    values.  Lists get no special syntax: they appear as '.'(Head, Tail).
*/
void print_term(term t) {
    t = dereference(t);
    if (is_unbound_variable(t)) {
        /* %p requires a void *, not a pointer to struct */
        printf("V%p", (void *)t);
    } else {
        char const *p = t->u.name->name;
        size_t n = t->u.name->size;
        size_t i;

        if (t->u.name->quotes == 0) {
            for (i = 0; i != n; i++) putchar(p[i]);
        } else {
            putchar('\'');
            for (i = 0; i != n; i++) {
                if (p[i] == '\'') putchar('\'');
                putchar(p[i]);
            }
            putchar('\'');
        }
        if (is_compound(t)) {
            int a;

            for (a = 0; a < t->arity; a++) {
                fputs(a == 0 ? "(" : ", ", stdout);
                print_term(t->arg[a]);
            }
            putchar(')');
        }
    }
}

/* Disabled alternative driver: it just reads terms and prints each one
   back on a line of its own.  Not compiled. */
#if 0

int main(void) {
    init_strtab();
    while (more_input()) {
        print_term(read_term());
        printf("\n");
    }
}

#endif

/*----------------------------------------------------------------------

    Testing whether two terms are the same

    Two terms are the same if, after replacing bound variables
    by their values, either
        - both are the same unbound variable or
        - both have the same arity and function symbol,
          and corresponding arguments are the same.

*/
/*  Return 1 if `x` and `y` are identical once bound variables have been
    replaced by their values, 0 otherwise.  An unbound variable is
    identical only to itself; two non-variables are identical when they
    agree in arity and function symbol and their corresponding arguments
    are identical.  No variable is bound.
*/
int same(term x, term y) {
    term const a = dereference(x);
    term const b = dereference(y);
    int k;

    if (is_unbound_variable(a)) return a == b;
    if (is_unbound_variable(b)
     || a->arity != b->arity
     || a->u.name != b->u.name) {
        return 0;
    }
    for (k = 0; k != a->arity; k++) {
        if (!same(a->arg[k], b->arg[k])) return 0;
    }
    return 1;
}

/*----------------------------------------------------------------------

    Checking whether a term contains a variable.
    We don't want X = f(X) to succeed making X = f(f(f(f(f(f....
    This check is called the "occurs check".
    For efficiency, most Prolog systems omit it.
    Alain Colmerauer showed that this is sort of legitimate
    as long as you admit to using a non-standard version of
    equality.
*/

/*  The occurs check: return 1 if the unbound variable `v` appears
    anywhere inside `t` (following bindings), 0 if it does not.
    `t` itself counts, so contains(v, v) is 1.
*/
int contains(term t, term v) {
    term const here = dereference(t);
    int k;

    if (is_unbound_variable(here)) return here == v;

    for (k = 0; k < here->arity; k++) {
        if (contains(here->arg[k], v)) return 1;
    }
    return 0;
}

/*----------------------------------------------------------------------

    In the unification process, we try to make two terms the same
    by binding variables to values.  This is a side effect!
    If the match fails, we want to undo it.
    Indeed, in the context of a full system, we might be backtracking
    over alternative proofs, and would need to undo successful matches.
    So we maintain a "trail".

    In general, a trail is a history of inverse actions.
    Before performing a side effect E, you push -E on the trail,
    where -E is the action that will undo E.
    In our case, the only action is "bind a variable", so if we
    know what the variable is, we know everything we need.

    With hindsight, I would have done better to thread the trail
    through the variable terms themselves.  I was too heavily
    influenced by Prolog, which uses some cleverness to avoid
    trailing bindings in the first place whenever they are sure
    to disappear before we need to undo them.
*/

/* The trail: an array of the variables bound so far, oldest first. */
/* storage for the trail; 0 until bind() first allocates it */
term *trail_base = 0;
/* number of entries currently on the trail */
int   trail_top  = 0;
/* number of entries the current storage can hold */
int   trail_limit = 0;

/*  Bind the unbound variable `v` to the term `t`, first recording `v`
    on the trail so that unwind_to() can undo the binding later.
    The trail starts at 1000 entries and doubles whenever it is full;
    running out of memory is fatal (see emalloc()).
*/
static void bind(term v, term t) {
    if (trail_top == trail_limit) {
        int   new_limit = trail_top == 0 ? 1000 : trail_limit*2;
        term *new_trail = emalloc((size_t)new_limit * sizeof *trail_base);

        /* On the first growth trail_base is a null pointer, and memcpy()
           must not be given one even when the length is zero. */
        if (trail_top != 0) {
            memcpy(new_trail, trail_base,
                   (size_t)trail_top * sizeof *trail_base);
        }
        free(trail_base);
        trail_base  = new_trail;
        trail_limit = new_limit;
    }
    trail_base[trail_top++] = v;
    v->u.binding = t;
    v->arity = BOUND_VARIABLE;
}

static void unwind_to(int old_top) {
    while (trail_top != old_top) {
        term v = trail_base[--trail_top];
        v->u.binding = v;
        v->arity = UNBOUND_VARIABLE;
    }
}

/*----------------------------------------------------------------------

    Now for the big one:

        U N I F I C A T I O N !

    After all that preparation, it's simple.

    To unify two terms x and y:
        dereference them.
        if x is an unbound variable
            if y is x do nothing and succeed.
            if y is not x and y contains x, fail.
            otherwise bind x to y and succeed.
        if y is an unbound variable
            the same only the other way around.
        if x and y have different arities or function symbols, fail.
        otherwise, unify corresponding arguments,
        and fail if any of the recursive calls fails.
*/

/*  Try to make `x` and `y` identical by binding variables.
    Returns 1 on success and 0 on failure.  Bindings made before a
    failure was detected are NOT undone here; they stay on the trail for
    the caller to unwind.

    After dereferencing:
      - an unbound variable unifies with itself trivially, fails against
        a term that contains it (occurs check), and is otherwise bound
        to the other term;
      - two non-variables must agree in arity and function symbol, and
        their arguments must unify pairwise, left to right.
*/
static int unify_worker(term x, term y) {
    term const a = dereference(x);
    term const b = dereference(y);
    int k;

    if (is_unbound_variable(a)) {
        if (b == a) return 1;
        if (contains(b, a)) return 0;
        bind(a, b);
        return 1;
    }
    if (is_unbound_variable(b)) {
        /* a is not an unbound variable here, so it cannot be b itself */
        if (contains(a, b)) return 0;
        bind(b, a);
        return 1;
    }
    if (a->arity != b->arity || a->u.name != b->u.name) return 0;
    for (k = 0; k != a->arity; k++) {
        if (!unify_worker(a->arg[k], b->arg[k])) return 0;
    }
    return 1;
}

/*  Unify `x` with `y`.  Returns 1 and keeps the bindings on success;
    on failure returns 0 after undoing every binding the attempt made,
    leaving the trail as it was on entry.
*/
int unify(term x, term y) {
    int const mark = trail_top;
    int const matched = unify_worker(x, y);

    if (!matched) unwind_to(mark);
    return matched;
}


/*----------------------------------------------------------------------

    Testing.

*/

/*  Test driver.  Reads queries of the form

        term = term .

    from standard input until it is exhausted.  For each query it tries
    to unify the two terms and prints "YES: " followed by the unified
    left-hand term, or "NO".  All bindings are undone before the next
    query is read, whether or not the unification succeeded, which is
    why unify_worker() is called directly rather than unify().
    A malformed query is reported on stderr and ends the program.
*/
int main(void) {
    term x, y;
    int old_trail;

    init_strtab();
    while (more_input()) {
        x = read_term();
        if (next_non_blank() != '=') {
            fprintf(stderr, "Missing =\n");
            exit(EXIT_FAILURE);
        }
        y = read_term();
        if (next_non_blank() != '.') {
            fprintf(stderr, "Missing .\n");
            exit(EXIT_FAILURE);
        }
        old_trail = trail_top;
        if (unify_worker(x, y)) {
            printf("YES: ");
            print_term(x);
            printf(".\n\n");
        } else {
            printf("NO\n\n");
        }
        unwind_to(old_trail);
    }
    return 0;
}