By Gore, Libucha
Codenames is a team party game in which cards, each bearing a single word, are laid out on a table. Each team consists of two members: a speaker who is trying to get their teammate to guess certain words, and a listener who guesses based on the speaker’s clue. The speaker is allowed to say one word (the clue) and a number indicating how many words on the table the clue relates to. For the sake of our model, we fix this number at 2 in order to simplify the model.
This game is apt for RSA modeling, as there is leveled reasoning between the speaker and the listener.
To establish priors for our listener, we used Stanford’s GloVe project, which maps words to a 25-dimensional vector space.
Vectors have native support in WebPPL which allows our model to do mathematical operations on these objects.
Originally, we imported these from an external file; however, we only need a few vectors for this demonstration, so we will manually add these. We also define prior functions that sample randomly from our vectors object.
// GloVe word embeddings (25 components each), keyed by word.
// eagle, pig and chicken are the board words drawn by wordsPrior;
// farm, animal and bird are the candidate clues drawn by cluePrior.
var vectors = {eagle : Vector([-0.8186906894583743, -0.8443627918594182,
-0.04304780086785447, -0.8257634263841377,
-0.7607218950809542, 0.47786735164930183,
0.36942709316422206, 0.18560148725224498,
0.38625176009619944, 0.24384273963053932,
1.0355862286322068, -0.14170089242313555,
-0.17017960843359828, 0.27636172471279313,
-0.49477465481497807, -1.199206930890509,
1.0531720839078256, -0.5154875303291531,
0.30704269353337016, 1.5382356443196483,
-0.13215425501400774, 1.2222507503066664,
1.3819617662995949, -1.1579407453927437,
0.9439311306043343]),
pig: Vector([-0.9027808771549458, -1.4539105978263833,
0.5743098399154295, 1.3052815987119957,
-0.038556210348244704, -0.22144102997326148,
1.222050088622139, -0.027526643946408857,
-0.13265827668708097, 1.4799207507145387,
0.02371336629548181, -0.9405402658175948,
0.06556493358788004, -1.6556208133885402,
-0.44306373689318584, -0.475710035110901,
1.2435716830499404, -1.0677780309283533,
-0.03344465447168945, 0.16184568683816827,
0.8718035460475897, 2.082956682688621,
0.47430271385843953, -0.4479993650378608,
1.5192928553678355]),
chicken: Vector([-0.9555633717916903, -0.23467550895948608,
0.9081102168032618, 1.7681919864431317,
-0.3888166871286516, 1.2292398003323308,
1.0624961440319318, -0.3558803892040966,
-0.17024423658814317, 0.7046776782991592,
1.624196256505183, -1.1423231844008523,
-0.9490267652945451, -1.8004114037674281,
-0.026086280368055388, -2.089757256612839,
3.5660566372328693, -2.4178611093952225,
-0.7077960662621875, 0.9418434965990246,
0.438927172322575, 1.0891725023940724,
-0.1237861204326181, 0.7602054634506068,
3.0515580696224083]),
farm: Vector([-1.3191469349030631, -0.34747873883058705,
0.09525267994894762, 0.08014872654330456,
0.1179814806966339, -0.26926061020753783,
0.709033965954239, -0.6521777385143812,
1.0195239553589313, 0.7192612109870958,
1.1711460976059695, -1.0779079866249233,
-0.5443503049555966, 0.08523251153754875,
0.1455530206584687, -1.501097375488643,
1.1151234505440395, 0.0581591541412683,
-0.1102242123027589, 0.5253857581277014,
0.21780949510893402, 0.026030837039037417,
0.07282095318396448, -0.6093002665622598,
2.0466066458317336]),
animal: Vector([-0.08132595025854601, -1.8280616716238214,
-0.4241550049238374, 1.3405833261217683,
1.3635302219051426, 0.19656106954281044,
1.0553637657141577, 0.8640316722860499,
-0.34682275265131135, 0.27196141799987644,
0.9785603157742483, -3.1767493003780873,
-0.7566904249011203, -1.1935303767007424,
0.2523177522167622, 0.33414675815038736,
0.21147820292953767, 0.2089073521353749,
0.36413859545070626, -0.3145854077725169,
0.8470589609352164, 0.8914477422324714,
0.06602846837066885, -1.1974184866543685,
1.6807019645814638]),
bird: Vector([0.9317828273959833, -1.142927450658389,
1.1249556341704339, 0.7533022372085103,
0.039221572709652965, 0.5302815428039684,
1.1525754405638204, 0.5707370610821617,
0.01803607760778035, 0.9527229145321762,
1.0851468114908822, -0.4626041548552341,
-0.5371489443168416, -0.8343285842461913,
-0.09713481034287788, 0.8070233789520264,
0.21755780815430825, -0.6588132708557186,
-0.7963193188039507, 0.12395864485237663,
-0.18545774404118467, 1.311026289715281,
0.7764007851264465, -0.5776179488468618,
0.5640559901962993])
};
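Next, we define our meaning function. This involves computing the Euclidean distance between the clue and each of the words under consideration. We take the reciprocal of this distance and run it through a sigmoid function, which maps it to a probability (between 0.5 and 1, since the reciprocal of a distance is never negative), and then flip a coin with that probability. The clue is judged true of a set of words only if the flip comes up true for every word. This effectively favors words which are very close to the clue, while still giving some probability to very distant words.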
///fold: vectors
// GloVe word embeddings (25 components each), keyed by word.
// eagle, pig and chicken are the board words drawn by wordsPrior;
// farm, animal and bird are the candidate clues drawn by cluePrior.
var vectors = {eagle : Vector([-0.8186906894583743, -0.8443627918594182,
-0.04304780086785447, -0.8257634263841377,
-0.7607218950809542, 0.47786735164930183,
0.36942709316422206, 0.18560148725224498,
0.38625176009619944, 0.24384273963053932,
1.0355862286322068, -0.14170089242313555,
-0.17017960843359828, 0.27636172471279313,
-0.49477465481497807, -1.199206930890509,
1.0531720839078256, -0.5154875303291531,
0.30704269353337016, 1.5382356443196483,
-0.13215425501400774, 1.2222507503066664,
1.3819617662995949, -1.1579407453927437,
0.9439311306043343]),
pig: Vector([-0.9027808771549458, -1.4539105978263833,
0.5743098399154295, 1.3052815987119957,
-0.038556210348244704, -0.22144102997326148,
1.222050088622139, -0.027526643946408857,
-0.13265827668708097, 1.4799207507145387,
0.02371336629548181, -0.9405402658175948,
0.06556493358788004, -1.6556208133885402,
-0.44306373689318584, -0.475710035110901,
1.2435716830499404, -1.0677780309283533,
-0.03344465447168945, 0.16184568683816827,
0.8718035460475897, 2.082956682688621,
0.47430271385843953, -0.4479993650378608,
1.5192928553678355]),
chicken: Vector([-0.9555633717916903, -0.23467550895948608,
0.9081102168032618, 1.7681919864431317,
-0.3888166871286516, 1.2292398003323308,
1.0624961440319318, -0.3558803892040966,
-0.17024423658814317, 0.7046776782991592,
1.624196256505183, -1.1423231844008523,
-0.9490267652945451, -1.8004114037674281,
-0.026086280368055388, -2.089757256612839,
3.5660566372328693, -2.4178611093952225,
-0.7077960662621875, 0.9418434965990246,
0.438927172322575, 1.0891725023940724,
-0.1237861204326181, 0.7602054634506068,
3.0515580696224083]),
farm: Vector([-1.3191469349030631, -0.34747873883058705,
0.09525267994894762, 0.08014872654330456,
0.1179814806966339, -0.26926061020753783,
0.709033965954239, -0.6521777385143812,
1.0195239553589313, 0.7192612109870958,
1.1711460976059695, -1.0779079866249233,
-0.5443503049555966, 0.08523251153754875,
0.1455530206584687, -1.501097375488643,
1.1151234505440395, 0.0581591541412683,
-0.1102242123027589, 0.5253857581277014,
0.21780949510893402, 0.026030837039037417,
0.07282095318396448, -0.6093002665622598,
2.0466066458317336]),
animal: Vector([-0.08132595025854601, -1.8280616716238214,
-0.4241550049238374, 1.3405833261217683,
1.3635302219051426, 0.19656106954281044,
1.0553637657141577, 0.8640316722860499,
-0.34682275265131135, 0.27196141799987644,
0.9785603157742483, -3.1767493003780873,
-0.7566904249011203, -1.1935303767007424,
0.2523177522167622, 0.33414675815038736,
0.21147820292953767, 0.2089073521353749,
0.36413859545070626, -0.3145854077725169,
0.8470589609352164, 0.8914477422324714,
0.06602846837066885, -1.1974184866543685,
1.6807019645814638]),
bird: Vector([0.9317828273959833, -1.142927450658389,
1.1249556341704339, 0.7533022372085103,
0.039221572709652965, 0.5302815428039684,
1.1525754405638204, 0.5707370610821617,
0.01803607760778035, 0.9527229145321762,
1.0851468114908822, -0.4626041548552341,
-0.5371489443168416, -0.8343285842461913,
-0.09713481034287788, 0.8070233789520264,
0.21755780815430825, -0.6588132708557186,
-0.7963193188039507, 0.12395864485237663,
-0.18545774404118467, 1.311026289715281,
0.7764007851264465, -0.5776179488468618,
0.5640559901962993])
};
///
// Stochastic truth function: decides whether `clue` applies to every word in `words`.
//   clue  - key into `vectors` naming the clue word
//   words - array of keys into `vectors` naming the candidate board words
// Returns true only if an independent coin flip succeeds for each word.
var meaning = function(clue, words) {
  // Logistic squashing function.
  var sigmoid = function(x) {
    return 1 / (1 + Math.exp(-1 * x));
  };

  // Euclidean distance between two embedding vectors.
  var distance = function(vector1, vector2) {
    var pairs = zip(ad.tensor.toScalars(vector1), ad.tensor.toScalars(vector2));
    var squaredDiffs = map(function(pair) {
      var diff = pair[0] - pair[1];
      return diff * diff;
    }, pairs);
    return Math.sqrt(sum(squaredDiffs));
  };

  // One coin flip per word. The distance is non-negative, so its reciprocal is too,
  // and the flip probability therefore lies in [0.5, 1]: closer words get higher odds.
  var clueApplies = function(word) {
    var closeness = 1 / distance(vectors[clue], vectors[word]);
    return flip(sigmoid(closeness));
  };

  var outcomes = map(function(word) { return clueApplies(word); }, words);
  return all(function(outcome) { return outcome; }, outcomes);
};
Our literalListener samples a random subset over the words on the board. It then runs meaning over the subset and the passed clue. The value is then gated on the meaning function returning true. We also include prior functions in order to sample subsets as well as words from our corpus.
///fold: vectors
// GloVe word embeddings (25 components each), keyed by word.
// eagle, pig and chicken are the board words drawn by wordsPrior;
// farm, animal and bird are the candidate clues drawn by cluePrior.
var vectors = {eagle : Vector([-0.8186906894583743, -0.8443627918594182,
-0.04304780086785447, -0.8257634263841377,
-0.7607218950809542, 0.47786735164930183,
0.36942709316422206, 0.18560148725224498,
0.38625176009619944, 0.24384273963053932,
1.0355862286322068, -0.14170089242313555,
-0.17017960843359828, 0.27636172471279313,
-0.49477465481497807, -1.199206930890509,
1.0531720839078256, -0.5154875303291531,
0.30704269353337016, 1.5382356443196483,
-0.13215425501400774, 1.2222507503066664,
1.3819617662995949, -1.1579407453927437,
0.9439311306043343]),
pig: Vector([-0.9027808771549458, -1.4539105978263833,
0.5743098399154295, 1.3052815987119957,
-0.038556210348244704, -0.22144102997326148,
1.222050088622139, -0.027526643946408857,
-0.13265827668708097, 1.4799207507145387,
0.02371336629548181, -0.9405402658175948,
0.06556493358788004, -1.6556208133885402,
-0.44306373689318584, -0.475710035110901,
1.2435716830499404, -1.0677780309283533,
-0.03344465447168945, 0.16184568683816827,
0.8718035460475897, 2.082956682688621,
0.47430271385843953, -0.4479993650378608,
1.5192928553678355]),
chicken: Vector([-0.9555633717916903, -0.23467550895948608,
0.9081102168032618, 1.7681919864431317,
-0.3888166871286516, 1.2292398003323308,
1.0624961440319318, -0.3558803892040966,
-0.17024423658814317, 0.7046776782991592,
1.624196256505183, -1.1423231844008523,
-0.9490267652945451, -1.8004114037674281,
-0.026086280368055388, -2.089757256612839,
3.5660566372328693, -2.4178611093952225,
-0.7077960662621875, 0.9418434965990246,
0.438927172322575, 1.0891725023940724,
-0.1237861204326181, 0.7602054634506068,
3.0515580696224083]),
farm: Vector([-1.3191469349030631, -0.34747873883058705,
0.09525267994894762, 0.08014872654330456,
0.1179814806966339, -0.26926061020753783,
0.709033965954239, -0.6521777385143812,
1.0195239553589313, 0.7192612109870958,
1.1711460976059695, -1.0779079866249233,
-0.5443503049555966, 0.08523251153754875,
0.1455530206584687, -1.501097375488643,
1.1151234505440395, 0.0581591541412683,
-0.1102242123027589, 0.5253857581277014,
0.21780949510893402, 0.026030837039037417,
0.07282095318396448, -0.6093002665622598,
2.0466066458317336]),
animal: Vector([-0.08132595025854601, -1.8280616716238214,
-0.4241550049238374, 1.3405833261217683,
1.3635302219051426, 0.19656106954281044,
1.0553637657141577, 0.8640316722860499,
-0.34682275265131135, 0.27196141799987644,
0.9785603157742483, -3.1767493003780873,
-0.7566904249011203, -1.1935303767007424,
0.2523177522167622, 0.33414675815038736,
0.21147820292953767, 0.2089073521353749,
0.36413859545070626, -0.3145854077725169,
0.8470589609352164, 0.8914477422324714,
0.06602846837066885, -1.1974184866543685,
1.6807019645814638]),
bird: Vector([0.9317828273959833, -1.142927450658389,
1.1249556341704339, 0.7533022372085103,
0.039221572709652965, 0.5302815428039684,
1.1525754405638204, 0.5707370610821617,
0.01803607760778035, 0.9527229145321762,
1.0851468114908822, -0.4626041548552341,
-0.5371489443168416, -0.8343285842461913,
-0.09713481034287788, 0.8070233789520264,
0.21755780815430825, -0.6588132708557186,
-0.7963193188039507, 0.12395864485237663,
-0.18545774404118467, 1.311026289715281,
0.7764007851264465, -0.5776179488468618,
0.5640559901962993])
};
///
// Stochastic truth function: decides whether `clue` applies to every word in `words`.
//   clue  - key into `vectors` naming the clue word
//   words - array of keys into `vectors` naming the candidate board words
// Returns true only if an independent coin flip succeeds for each word.
var meaning = function(clue, words) {
  // Logistic squashing function.
  var sigmoid = function(x) {
    return 1 / (1 + Math.exp(-1 * x));
  };

  // Euclidean distance between two embedding vectors.
  var distance = function(vector1, vector2) {
    var pairs = zip(ad.tensor.toScalars(vector1), ad.tensor.toScalars(vector2));
    var squaredDiffs = map(function(pair) {
      var diff = pair[0] - pair[1];
      return diff * diff;
    }, pairs);
    return Math.sqrt(sum(squaredDiffs));
  };

  // One coin flip per word. The distance is non-negative, so its reciprocal is too,
  // and the flip probability therefore lies in [0.5, 1]: closer words get higher odds.
  var clueApplies = function(word) {
    var closeness = 1 / distance(vectors[clue], vectors[word]);
    return flip(sigmoid(closeness));
  };

  var outcomes = map(function(word) { return clueApplies(word); }, words);
  return all(function(outcome) { return outcome; }, outcomes);
};
// Prior over listener guesses: a uniform draw from the three unordered pairs
// that can be formed from the board words chicken, eagle and pig.
var wordsPrior = function() {
  return uniformDraw([
    ["chicken", "eagle"],
    ["eagle", "pig"],
    ["chicken", "pig"]
  ]);
};
// Prior over speaker utterances: a uniform draw from the three candidate clues.
var cluePrior = function() {
  var clues = ["farm", "animal", "bird"];
  return uniformDraw(clues);
};
// Literal listener: infers which pair of board words a clue refers to.
//   clue - the clue word (a key into `vectors`)
// Draws a pair from wordsPrior, conditions on meaning(clue, pair) coming out true,
// and returns the resulting distribution over pairs.
var literalListener = function(clue)
{
  // The explicit `return` is required: without it the function yields undefined
  // and callers have no distribution to display or score.
  return Infer(function()
  {
    var randomSubset = wordsPrior();
    var uttTruthVal = meaning(clue, randomSubset);
    condition(uttTruthVal);
    return randomSubset;
  });
};
// Render as a table whatever literalListener returns for the clue "farm"
// (intended: a distribution over word pairs).
viz.table(literalListener("farm"));
The speaker reasons about the best clue given a state in order to maximize the chance the literalListener guesses right, while reducing cost. We don’t include a cost in this model, as it’s expected to be constant. We then weight each clue by the literal listener’s log-probability (score) of the intended subset, scaled by an optimality parameter, alpha.
// Speaker: chooses a clue for the word pair it wants the listener to guess.
//   subset - the intended pair of board words
// Draws a clue from cluePrior and weights it by alpha times the literal listener's
// log-probability (score) of `subset` under that clue. `alpha` is read from the
// enclosing scope. Returns the resulting distribution over clues.
var speaker = function(subset)
{
  // The explicit `return` is required: without it the function yields undefined
  // and the pragmatic listener has no distribution to observe from.
  return Infer(function()
  {
    var clue = cluePrior();
    factor(alpha * literalListener(clue).score(subset));
    return clue;
  });
};
The pragmatic listener reasons about the best state given a clue. This is done by inferring what the speaker would say given each state in the subset, and then filtering based on the best result for speaker.
// Pragmatic listener: infers the intended word pair by reasoning about the speaker.
//   clue - the clue word that was heard
// Draws a pair from wordsPrior, asks what the speaker would say for that pair, and
// observes the heard clue under that speaker distribution. Returns the resulting
// distribution over pairs.
var pragmaticListener = function(clue)
{
  // The explicit `return` is required: without it the function yields undefined.
  return Infer(function()
  {
    var randomSubset = wordsPrior();
    var s1 = speaker(randomSubset);
    observe(s1, clue);
    return randomSubset;
  });
};
This is the model in full. The results are not exactly ideal; however, the example gives a good starting point for this model. We would expect “chicken” and “pig” to return the highest value for “farm”; however, this is not the case. We expect that this issue comes from the fact that these vectors are simply grouped very closely in the “animal” area, and that the differences in probabilities we are seeing are likely marginal when compared to the entire set of vectors.
The model is very similar to all the other models we covered in class. It almost exactly follows the basic RSA model from a high level. The big difference is the way we got the priors (the vectors), as well as the meaning function which involves vector math. The listeners/speaker, however, are almost direct ports of the first RSA model we covered.
///fold: vectors
// GloVe word embeddings (25 components each), keyed by word.
// eagle, pig and chicken are the board words drawn by wordsPrior;
// farm, animal and bird are the candidate clues drawn by cluePrior.
var vectors = {eagle : Vector([-0.8186906894583743, -0.8443627918594182,
-0.04304780086785447, -0.8257634263841377,
-0.7607218950809542, 0.47786735164930183,
0.36942709316422206, 0.18560148725224498,
0.38625176009619944, 0.24384273963053932,
1.0355862286322068, -0.14170089242313555,
-0.17017960843359828, 0.27636172471279313,
-0.49477465481497807, -1.199206930890509,
1.0531720839078256, -0.5154875303291531,
0.30704269353337016, 1.5382356443196483,
-0.13215425501400774, 1.2222507503066664,
1.3819617662995949, -1.1579407453927437,
0.9439311306043343]),
pig: Vector([-0.9027808771549458, -1.4539105978263833,
0.5743098399154295, 1.3052815987119957,
-0.038556210348244704, -0.22144102997326148,
1.222050088622139, -0.027526643946408857,
-0.13265827668708097, 1.4799207507145387,
0.02371336629548181, -0.9405402658175948,
0.06556493358788004, -1.6556208133885402,
-0.44306373689318584, -0.475710035110901,
1.2435716830499404, -1.0677780309283533,
-0.03344465447168945, 0.16184568683816827,
0.8718035460475897, 2.082956682688621,
0.47430271385843953, -0.4479993650378608,
1.5192928553678355]),
chicken: Vector([-0.9555633717916903, -0.23467550895948608,
0.9081102168032618, 1.7681919864431317,
-0.3888166871286516, 1.2292398003323308,
1.0624961440319318, -0.3558803892040966,
-0.17024423658814317, 0.7046776782991592,
1.624196256505183, -1.1423231844008523,
-0.9490267652945451, -1.8004114037674281,
-0.026086280368055388, -2.089757256612839,
3.5660566372328693, -2.4178611093952225,
-0.7077960662621875, 0.9418434965990246,
0.438927172322575, 1.0891725023940724,
-0.1237861204326181, 0.7602054634506068,
3.0515580696224083]),
farm: Vector([-1.3191469349030631, -0.34747873883058705,
0.09525267994894762, 0.08014872654330456,
0.1179814806966339, -0.26926061020753783,
0.709033965954239, -0.6521777385143812,
1.0195239553589313, 0.7192612109870958,
1.1711460976059695, -1.0779079866249233,
-0.5443503049555966, 0.08523251153754875,
0.1455530206584687, -1.501097375488643,
1.1151234505440395, 0.0581591541412683,
-0.1102242123027589, 0.5253857581277014,
0.21780949510893402, 0.026030837039037417,
0.07282095318396448, -0.6093002665622598,
2.0466066458317336]),
animal: Vector([-0.08132595025854601, -1.8280616716238214,
-0.4241550049238374, 1.3405833261217683,
1.3635302219051426, 0.19656106954281044,
1.0553637657141577, 0.8640316722860499,
-0.34682275265131135, 0.27196141799987644,
0.9785603157742483, -3.1767493003780873,
-0.7566904249011203, -1.1935303767007424,
0.2523177522167622, 0.33414675815038736,
0.21147820292953767, 0.2089073521353749,
0.36413859545070626, -0.3145854077725169,
0.8470589609352164, 0.8914477422324714,
0.06602846837066885, -1.1974184866543685,
1.6807019645814638]),
bird: Vector([0.9317828273959833, -1.142927450658389,
1.1249556341704339, 0.7533022372085103,
0.039221572709652965, 0.5302815428039684,
1.1525754405638204, 0.5707370610821617,
0.01803607760778035, 0.9527229145321762,
1.0851468114908822, -0.4626041548552341,
-0.5371489443168416, -0.8343285842461913,
-0.09713481034287788, 0.8070233789520264,
0.21755780815430825, -0.6588132708557186,
-0.7963193188039507, 0.12395864485237663,
-0.18545774404118467, 1.311026289715281,
0.7764007851264465, -0.5776179488468618,
0.5640559901962993])
};
///
// Stochastic truth function: decides whether `clue` applies to every word in `words`.
//   clue  - key into `vectors` naming the clue word
//   words - array of keys into `vectors` naming the candidate board words
// Returns true only if an independent coin flip succeeds for each word.
var meaning = function(clue, words) {
  // Logistic squashing function.
  var sigmoid = function(x) {
    return 1 / (1 + Math.exp(-1 * x));
  };

  // Euclidean distance between two embedding vectors.
  var distance = function(vector1, vector2) {
    var pairs = zip(ad.tensor.toScalars(vector1), ad.tensor.toScalars(vector2));
    var squaredDiffs = map(function(pair) {
      var diff = pair[0] - pair[1];
      return diff * diff;
    }, pairs);
    return Math.sqrt(sum(squaredDiffs));
  };

  // One coin flip per word. The distance is non-negative, so its reciprocal is too,
  // and the flip probability therefore lies in [0.5, 1]: closer words get higher odds.
  var clueApplies = function(word) {
    var closeness = 1 / distance(vectors[clue], vectors[word]);
    return flip(sigmoid(closeness));
  };

  var outcomes = map(function(word) { return clueApplies(word); }, words);
  return all(function(outcome) { return outcome; }, outcomes);
};
// Prior over listener guesses: a uniform draw from the three unordered pairs
// that can be formed from the board words chicken, eagle and pig.
var wordsPrior = function() {
  return uniformDraw([
    ["chicken", "eagle"],
    ["eagle", "pig"],
    ["chicken", "pig"]
  ]);
};
// Prior over speaker utterances: a uniform draw from the three candidate clues.
var cluePrior = function() {
  var clues = ["farm", "animal", "bird"];
  return uniformDraw(clues);
};
// Literal listener: infers which pair of board words a clue refers to.
//   clue - the clue word (a key into `vectors`)
// Draws a pair from wordsPrior, conditions on meaning(clue, pair) coming out true,
// and returns the resulting distribution over pairs.
var literalListener = function(clue)
{
  // The explicit `return` is required: without it the function yields undefined
  // and speaker's `.score(subset)` call has no distribution to work on.
  return Infer(function()
  {
    var randomSubset = wordsPrior();
    var uttTruthVal = meaning(clue, randomSubset);
    condition(uttTruthVal);
    return randomSubset;
  });
};
// Speaker optimality parameter: multiplies the literal-listener score inside speaker's factor call.
var alpha = 1;
// Speaker: chooses a clue for the word pair it wants the listener to guess.
//   subset - the intended pair of board words
// Draws a clue from cluePrior and weights it by alpha times the literal listener's
// log-probability (score) of `subset` under that clue. `alpha` is read from the
// enclosing scope. Returns the resulting distribution over clues.
var speaker = function(subset)
{
  // The explicit `return` is required: without it the function yields undefined
  // and the pragmatic listener has no distribution to observe from.
  return Infer(function()
  {
    var clue = cluePrior();
    factor(alpha * literalListener(clue).score(subset));
    return clue;
  });
};
// Pragmatic listener: infers the intended word pair by reasoning about the speaker.
//   clue - the clue word that was heard
// Draws a pair from wordsPrior, asks what the speaker would say for that pair, and
// observes the heard clue under that speaker distribution. Returns the resulting
// distribution over pairs.
var pragmaticListener = function(clue)
{
  // The explicit `return` is required: without it the function yields undefined
  // and there is nothing to pass on to viz.table.
  return Infer(function()
  {
    var randomSubset = wordsPrior();
    var s1 = speaker(randomSubset);
    observe(s1, clue);
    return randomSubset;
  });
};
// Render as a table whatever pragmaticListener returns for the clue "farm"
// (intended: a distribution over word pairs).
viz.table(pragmaticListener("farm"));