Skip to content
This repository has been archived by the owner on Nov 3, 2023. It is now read-only.

Commit

Permalink
local human agent (#110)
Browse files Browse the repository at this point in the history
* local human

* Space
  • Loading branch information
jaseweston authored May 30, 2017

Verified

This commit was signed with the committer’s verified signature.
RobinMalfait Robin Malfait
1 parent 2e9037c commit 3023f0c
Showing 2 changed files with 68 additions and 34 deletions.
27 changes: 27 additions & 0 deletions parlai/agents/local_human/agents.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
# Copyright (c) 2017-present, Facebook, Inc.
# All rights reserved.
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree. An additional grant
# of patent rights can be found in the PATENTS file in the same directory.
"""Agent does gets the local keyboard input in the act() function.
Example: python examples/eval_model.py -m local_human -t babi:Task1k:1 -dt valid
"""

from parlai.core.agents import Agent
from parlai.core.worlds import display_messages

class LocalHumanAgent(Agent):

def __init__(self, opt, shared=None):
super().__init__(opt)
self.id = 'localHuman'

def observe(self, msg):
print(display_messages([msg]))

def act(self):
obs = self.observation
reply = {}
reply['id'] = self.getID()
reply['text'] = input("Enter Your Reply: ")
return reply
75 changes: 41 additions & 34 deletions parlai/core/worlds.py
Original file line number Diff line number Diff line change
@@ -61,6 +61,46 @@ def validate(observation):
else:
raise RuntimeError('Must return dictionary from act().')

def display_messages(msgs):
"""Returns a string describing the set of messages provided"""
lines = []
episode_done = False
for index, msg in enumerate(msgs):
if msg is None:
continue
if msg.get('episode_done', False):
episode_done = True
# Possibly indent the text (for the second speaker, if two).
space = ''
if len(msgs) == 2 and index == 1:
space = ' '
if msg.get('reward', None) is not None:
lines.append(space + '[reward: {r}]'.format(r=msg['reward']))
if msg.get('text', ''):
ID = '[' + msg['id'] + ']: ' if 'id' in msg else ''
lines.append(space + ID + msg['text'])
if msg.get('labels', False):
lines.append(space + ('[labels: {}]'.format(
'|'.join(msg['labels']))))
if msg.get('label_candidates', False):
cand_len = len(msg['label_candidates'])
if cand_len <= 10:
lines.append(space + ('[cands: {}]'.format(
'|'.join(msg['label_candidates']))))
else:
# select five label_candidates from the candidate set,
# can't slice in because it's a set
cand_iter = iter(msg['label_candidates'])
display_cands = (next(cand_iter) for _ in range(5))
# print those cands plus how many cands remain
lines.append(space + ('[cands: {}{}]'.format(
'|'.join(display_cands),
'| ...and {} more'.format(cand_len - 5)
)))
if episode_done:
lines.append('- - - - - - - - - - - - - - - - - - - - -')
return '\n'.join(lines)


class World(object):
"""Empty parent providing null definitions of API functions for Worlds.
@@ -91,40 +131,7 @@ def display(self):
By default, display the messages between the agents."""
if not hasattr(self, 'acts'):
return ''
lines = []
for index, msg in enumerate(self.acts):
if msg is None:
continue
# Possibly indent the text (for the second speaker, if two).
space = ''
if len(self.acts) == 2 and index == 1:
space = ' '
if msg.get('reward', None) is not None:
lines.append(space + '[reward: {r}]'.format(r=msg['reward']))
if msg.get('text', ''):
ID = '[' + msg['id'] + ']: ' if 'id' in msg else ''
lines.append(space + ID + msg['text'])
if msg.get('labels', False):
lines.append(space + ('[labels: {}]'.format(
'|'.join(msg['labels']))))
if msg.get('label_candidates', False):
cand_len = len(msg['label_candidates'])
if cand_len <= 10:
lines.append(space + ('[cands: {}]'.format(
'|'.join(msg['label_candidates']))))
else:
# select five label_candidates from the candidate set,
# can't slice in because it's a set
cand_iter = iter(msg['label_candidates'])
display_cands = (next(cand_iter) for _ in range(5))
# print those cands plus how many cands remain
lines.append(space + ('[cands: {}{}]'.format(
'|'.join(display_cands),
'| ...and {} more'.format(cand_len - 5)
)))
if self.episode_done():
lines.append('- - - - - - - - - - - - - - - - - - - - -')
return '\n'.join(lines)
return display_messages(self.acts)

def episode_done(self):
"""Whether the episode is done or not. """

0 comments on commit 3023f0c

Please sign in to comment.