Source code for renn.data.data_utils

""" Data utils """
import csv
import tensorflow as tf


def readfile(filename, parse_row):
    """Lazily read a csv file containing labeled data.

    This is a generator: the file is opened when iteration starts and is
    closed once the generator is exhausted or closed.

    Args:
        filename: path of the csv file to read.
        parse_row: callable that receives one csv row (a list of strings)
            and returns a (score, text) pair.

    Yields:
        One dict per csv row, of the form {'text': text, 'label': score}.
    """
    # newline='' is what the csv module asks for: it disables
    # universal-newline translation so the reader does its own line handling
    # and line breaks embedded inside quoted fields are kept verbatim.
    with open(filename, 'r', newline='') as f:
        for row in csv.reader(f):
            score, text = parse_row(row)
            yield {'text': text, 'label': score}
def column_parser(text_column):
    """Build a row parser for csv files containing labeled data.

    The returned parser takes the label from the zeroth element of the row
    (converted with int()) and the text from the element at index
    'text_column'.

    Args:
        text_column: index of the row element that holds the text.

    Returns:
        A function mapping a csv row to a (label, text) tuple.
    """

    def parse(row):
        # Convert the label first, then look up the text column.
        label = int(row[0])
        text = row[text_column]
        return label, text

    return parse
def sentiment_relabel(num_classes):
    """Return a function that relabels (initially five-class) sentiment labels.

    Used for subclassing the Yelp and Amazon datasets. The returned function
    is built from TensorFlow ops and maps an incoming label x as follows:

      * num_classes in (1, 2): x > 3 -> 1, x == 3 -> -1, anything else -> 0.
      * num_classes == 3: 1 -> 0, 3 -> 1, 5 -> 2, and both 2 and 4 -> -1.
      * num_classes == 5: x -> x - 1.

    For 1, 2 and 3 classes the result is an int64 tensor; for 5 classes it is
    whatever tf.subtract(x, 1) produces for the given x.

    NOTE(review): the -1 label presumably marks examples that callers are
    expected to discard -- confirm against the dataset pipeline.

    Args:
        num_classes: number of sentiment classes to keep; one of 1, 2, 3, 5.

    Returns:
        A function taking a label tensor and returning the relabeled tensor.

    Raises:
        ValueError: if num_classes is not one of 1, 2, 3 or 5.
    """
    # Reject unsupported values up front, before any tensor is created.
    if num_classes not in (1, 2, 3, 5):
        raise ValueError('Sentiment subclasses must be 1,2,3,5')

    if num_classes == 5:
        # Five-class labels only need shifting from 1..5 down to 0..4.
        return lambda x: tf.subtract(x, 1)

    label_type = tf.int64
    minus_one = tf.constant(-1, dtype=label_type)

    def equal(x, y):
        # 0/1 indicator (as label_type) of x == y.
        return tf.cast(tf.equal(x, y), label_type)

    if num_classes in (1, 2):

        def relabel_binary(x):
            above_three = tf.cast(tf.greater(x, 3), label_type)
            return above_three + equal(x, 3) * minus_one

        return relabel_binary

    # Only the three-class case remains.
    zero = tf.constant(0, dtype=label_type)
    one = tf.constant(1, dtype=label_type)
    two = tf.constant(2, dtype=label_type)

    def relabel_three_class(x):
        # Exactly one indicator fires for a label in 1..5, so the sum picks
        # out the single target label for that input.
        return (equal(x, 1) * zero + equal(x, 2) * minus_one +
                equal(x, 3) * one + equal(x, 4) * minus_one +
                equal(x, 5) * two)

    return relabel_three_class
# Row parsers keyed by dataset name. The number paired with each name is the
# index of the csv column that holds the text for that dataset.
PARSERS = {
    dataset: column_parser(text_column)
    for dataset, text_column in (('yelp', 1), ('dbpedia', 2), ('amazon', 2))
}