HIVE的row_number函数类似于Oracle的ROW_NUMBER函数,用于在HIVE执行Map/Reduce作业的Reduce阶段为每行数据生成行号,一般与Sort By、Order By配合使用。
具体代码如下:
import org.apache.commons.lang.StringUtils; import org.apache.hadoop.hive.ql.exec.Description; import org.apache.hadoop.hive.ql.exec.UDFArgumentException; import org.apache.hadoop.hive.ql.exec.UDFArgumentLengthException; import org.apache.hadoop.hive.ql.metadata.HiveException; import org.apache.hadoop.hive.ql.udf.generic.GenericUDF; import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector; import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorUtils; import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorFactory; import org.apache.hadoop.io.LongWritable; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @Description(name = "row_number", value = "_FUNC_(a, [...]) - Assumes that incoming data is SORTed and DISTRIBUTEd according to the given columns, and then returns the row number for each row within the partition,") public class GenericUDFPartitionRowNumber extends GenericUDF { private Logger logger = LoggerFactory.getLogger(GenericUDFPartitionRowNumber.class); private LongWritable rowIndex = new LongWritable(0); private Object[] partitionColumnValues; private ObjectInspector[] objectInspectors; private int[] sortDirections; // holds +1 (for compare() > 0), 0 for unknown, -1 (for compare() < 0) /** * Takes the output of compare() and scales it to either, +1, 0 or -1. * * @param val * @return */ protected static int collapseToIndicator(int val) { if (val > 0) { return 1; } else if (val == 0) { return 0; } else { return -1; } } /** * Wraps Object.equals, but allows one or both arguments to be null. Note * that nullSafeEquals(null, null) == true. 
* * @param o1 * First object * @param o2 * Second object * @return */ protected static boolean nullSafeEquals(Object o1, Object o2) { if (o1 == null && o2 == null) { return true; } else if (o1 == null || o2 == null) { return false; } else { return (o1.equals(o2)); } } @Override public Object evaluate(DeferredObject[] arguments) throws HiveException { assert (arguments.length == partitionColumnValues.length); for (int i = 0; i < arguments.length; i++) { if (partitionColumnValues[i] == null) { partitionColumnValues[i] = ObjectInspectorUtils.copyToStandardObject(arguments[i].get(), objectInspectors[i]); } else if (!nullSafeEquals(arguments[i].get(), partitionColumnValues[i])) { // check sort directions. We know the elements aren't equal. int newDirection = collapseToIndicator(ObjectInspectorUtils.compare(arguments[i].get(), objectInspectors[i],partitionColumnValues[i], objectInspectors[i])); if (sortDirections[i] == 0) { // We don't already know what the sort direction should be sortDirections[i] = newDirection; } else if (sortDirections[i] != newDirection) { throw new HiveException( "Data in column: " + i + " does not appear to be consistently sorted, so partitionedRowNumber cannot be used."); } // reset everything (well, the remaining column values, because the previous ones haven't changed. for (int j = i; j < arguments.length; j++) { partitionColumnValues[j] = ObjectInspectorUtils.copyToStandardObject(arguments[j].get(),objectInspectors[j]); } rowIndex.set(1); return rowIndex; } } // partition columns are identical. Increment and continue. 
rowIndex.set(rowIndex.get() + 1); return rowIndex; } @Override public String getDisplayString(String[] children) { return "partitionedRowNumber(" + StringUtils.join(children, ", ") + ")"; } @Override public ObjectInspector initialize(ObjectInspector[] arguments) throws UDFArgumentException { logger.info("run times"); if (arguments.length == 0) { throw new UDFArgumentLengthException("The function partitionedRowNumber expects at least 1 argument."); } partitionColumnValues = new Object[arguments.length]; for (ObjectInspector oi : arguments) { if (ObjectInspectorUtils.isConstantObjectInspector(oi)) { throw new UDFArgumentException("No constant arguments should be passed to partitionedRowNumber."); } } objectInspectors = arguments; sortDirections = new int[arguments.length]; return PrimitiveObjectInspectorFactory.writableLongObjectInspector; } }
HIVE的0.11.0版本中提供了row_number函数,看了一下源码:
import java.util.ArrayList;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.apache.hadoop.hive.ql.exec.Description;
import org.apache.hadoop.hive.ql.exec.UDFArgumentTypeException;
import org.apache.hadoop.hive.ql.exec.WindowFunctionDescription;
import org.apache.hadoop.hive.ql.metadata.HiveException;
import org.apache.hadoop.hive.ql.parse.SemanticException;
import org.apache.hadoop.hive.ql.udf.generic.GenericUDAFEvaluator.AggregationBuffer;
import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorFactory;
import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorFactory;
import org.apache.hadoop.hive.serde2.typeinfo.TypeInfo;
import org.apache.hadoop.io.IntWritable;

/**
 * Built-in row_number() window function from Hive 0.11.0 (quoted upstream
 * source). Implemented as a UDAF resolver: for each partition it accumulates
 * one IntWritable per input row, numbered sequentially from 1, and
 * pivotResult = true tells the windowing framework to fan the returned list
 * back out to one value per row.
 */
@WindowFunctionDescription
(
    description = @Description(
        name = "row_number",
        value = "_FUNC_() - The ROW_NUMBER function assigns a unique number (sequentially, starting from 1, as defined by ORDER BY) to each row within the partition."
    ),
    supportsWindow = false,
    pivotResult = true
)
public class GenericUDAFRowNumber extends AbstractGenericUDAFResolver {

    static final Log LOG = LogFactory.getLog(GenericUDAFRowNumber.class.getName());

    /**
     * Returns the evaluator; row_number() takes no arguments, so any
     * parameter is rejected at semantic-analysis time.
     */
    @Override
    public GenericUDAFEvaluator getEvaluator(TypeInfo[] parameters) throws SemanticException {
        if (parameters.length != 0) {
            throw new UDFArgumentTypeException(parameters.length - 1, "No argument is expected.");
        }
        return new GenericUDAFRowNumberEvaluator();
    }

    /** Per-partition state: the list of row numbers handed out so far. */
    static class RowNumberBuffer implements AggregationBuffer {

        ArrayList<IntWritable> rowNums;  // one entry per row seen in this partition
        int nextRow;                     // next row number to assign (starts at 1)

        void init() {
            rowNums = new ArrayList<IntWritable>();
        }

        RowNumberBuffer() {
            init();
            nextRow = 1;
        }

        /** Appends the next sequential row number. */
        void incr() {
            rowNums.add(new IntWritable(nextRow++));
        }
    }

    public static class GenericUDAFRowNumberEvaluator extends GenericUDAFEvaluator {

        /**
         * Only COMPLETE mode is supported — the whole partition must be
         * processed in one place, so no partial aggregation/merge exists.
         * The result is a list of writable ints (one per row).
         */
        @Override
        public ObjectInspector init(Mode m, ObjectInspector[] parameters) throws HiveException {
            super.init(m, parameters);
            if (m != Mode.COMPLETE) {
                throw new HiveException("Only COMPLETE mode supported for row_number function");
            }
            return ObjectInspectorFactory.getStandardListObjectInspector(
                PrimitiveObjectInspectorFactory.writableIntObjectInspector);
        }

        @Override
        public AggregationBuffer getNewAggregationBuffer() throws HiveException {
            return new RowNumberBuffer();
        }

        @Override
        public void reset(AggregationBuffer agg) throws HiveException {
            ((RowNumberBuffer) agg).init();
        }

        /** Each incoming row simply receives the next sequential number. */
        @Override
        public void iterate(AggregationBuffer agg, Object[] parameters) throws HiveException {
            ((RowNumberBuffer) agg).incr();
        }

        // terminatePartial/merge are unreachable given the COMPLETE-only
        // check in init(); they fail loudly if ever called.
        @Override
        public Object terminatePartial(AggregationBuffer agg) throws HiveException {
            throw new HiveException("terminatePartial not supported");
        }

        @Override
        public void merge(AggregationBuffer agg, Object partial) throws HiveException {
            throw new HiveException("merge not supported");
        }

        @Override
        public Object terminate(AggregationBuffer agg) throws HiveException {
            return ((RowNumberBuffer) agg).rowNums;
        }
    }
}
内置的row_number函数需要结合窗口函数使用,例如:
select s, sum(f) over (partition by i), row_number() over () from over10k where s = 'tom allen' or s = 'bob steinbeck';
窗口函数是0.11.0版本新增的特性。