Author: shawkins
Date: 2012-03-28 21:43:03 -0400 (Wed, 28 Mar 2012)
New Revision: 3956
Added:
trunk/engine/src/main/java/org/teiid/query/function/aggregate/SingleArgumentAggregateFunction.java
Modified:
trunk/api/src/main/java/org/teiid/metadata/AggregateAttributes.java
trunk/engine/src/main/java/org/teiid/common/buffer/BufferManager.java
trunk/engine/src/main/java/org/teiid/common/buffer/impl/BufferManagerImpl.java
trunk/engine/src/main/java/org/teiid/dqp/internal/datamgr/ExecutionContextImpl.java
trunk/engine/src/main/java/org/teiid/dqp/internal/process/MetaDataProcessor.java
trunk/engine/src/main/java/org/teiid/query/QueryPlugin.java
trunk/engine/src/main/java/org/teiid/query/function/aggregate/AggregateFunction.java
trunk/engine/src/main/java/org/teiid/query/function/aggregate/ArrayAgg.java
trunk/engine/src/main/java/org/teiid/query/function/aggregate/Avg.java
trunk/engine/src/main/java/org/teiid/query/function/aggregate/ConstantFunction.java
trunk/engine/src/main/java/org/teiid/query/function/aggregate/Count.java
trunk/engine/src/main/java/org/teiid/query/function/aggregate/Max.java
trunk/engine/src/main/java/org/teiid/query/function/aggregate/Min.java
trunk/engine/src/main/java/org/teiid/query/function/aggregate/RankingFunction.java
trunk/engine/src/main/java/org/teiid/query/function/aggregate/StatsFunction.java
trunk/engine/src/main/java/org/teiid/query/function/aggregate/Sum.java
trunk/engine/src/main/java/org/teiid/query/function/aggregate/TextAgg.java
trunk/engine/src/main/java/org/teiid/query/function/aggregate/XMLAgg.java
trunk/engine/src/main/java/org/teiid/query/optimizer/relational/rules/CapabilitiesUtil.java
trunk/engine/src/main/java/org/teiid/query/optimizer/relational/rules/RuleAssignOutputElements.java
trunk/engine/src/main/java/org/teiid/query/optimizer/relational/rules/RulePushAggregates.java
trunk/engine/src/main/java/org/teiid/query/processor/relational/GroupingNode.java
trunk/engine/src/main/java/org/teiid/query/processor/relational/SortingFilter.java
trunk/engine/src/main/java/org/teiid/query/processor/relational/WindowFunctionProjectNode.java
trunk/engine/src/main/java/org/teiid/query/rewriter/QueryRewriter.java
trunk/engine/src/main/java/org/teiid/query/sql/LanguageObject.java
trunk/engine/src/main/java/org/teiid/query/sql/navigator/PreOrPostOrderNavigator.java
trunk/engine/src/main/java/org/teiid/query/sql/symbol/AggregateSymbol.java
trunk/engine/src/main/java/org/teiid/query/sql/symbol/Function.java
trunk/engine/src/main/java/org/teiid/query/validator/ValidationVisitor.java
trunk/engine/src/main/resources/org/teiid/query/i18n.properties
trunk/engine/src/test/java/org/teiid/query/processor/relational/TestDuplicateFilter.java
trunk/engine/src/test/java/org/teiid/query/processor/relational/TestGroupingNode.java
Log:
TEIID-1560 adding support for n-ary aggregate functions
Modified: trunk/api/src/main/java/org/teiid/metadata/AggregateAttributes.java
===================================================================
--- trunk/api/src/main/java/org/teiid/metadata/AggregateAttributes.java 2012-03-28
15:29:46 UTC (rev 3955)
+++ trunk/api/src/main/java/org/teiid/metadata/AggregateAttributes.java 2012-03-29
01:43:03 UTC (rev 3956)
@@ -33,8 +33,8 @@
private boolean allowsDistinct;
private boolean windowable;
- private boolean decomposable;
- private boolean respectsNulls;
+ private boolean usesAllRows = true;
+ private boolean respectsNulls = true;
private boolean allowsOrderBy;
public boolean allowsOrderBy() {
@@ -61,14 +61,6 @@
this.windowable = windowable;
}
- public boolean isDecomposable() {
- return decomposable;
- }
-
- public void setDecomposable(boolean decomposable) {
- this.decomposable = decomposable;
- }
-
public boolean respectsNulls() {
return respectsNulls;
}
@@ -77,4 +69,12 @@
this.respectsNulls = respectsNulls;
}
+ public void setUsesAllRows(boolean usesAllRows) {
+ this.usesAllRows = usesAllRows;
+ }
+
+ public boolean usesAllRows() {
+ return this.usesAllRows;
+ }
+
}
Modified: trunk/engine/src/main/java/org/teiid/common/buffer/BufferManager.java
===================================================================
--- trunk/engine/src/main/java/org/teiid/common/buffer/BufferManager.java 2012-03-28
15:29:46 UTC (rev 3955)
+++ trunk/engine/src/main/java/org/teiid/common/buffer/BufferManager.java 2012-03-29
01:43:03 UTC (rev 3956)
@@ -122,7 +122,7 @@
*/
int getSchemaSize(List<? extends Expression> elements);
- STree createSTree(final List elements, String groupName, int keyLength);
+ STree createSTree(List<? extends Expression> elements, String groupName, int
keyLength);
void addTupleBuffer(TupleBuffer tb);
Modified: trunk/engine/src/main/java/org/teiid/common/buffer/impl/BufferManagerImpl.java
===================================================================
---
trunk/engine/src/main/java/org/teiid/common/buffer/impl/BufferManagerImpl.java 2012-03-28
15:29:46 UTC (rev 3955)
+++
trunk/engine/src/main/java/org/teiid/common/buffer/impl/BufferManagerImpl.java 2012-03-29
01:43:03 UTC (rev 3956)
@@ -482,7 +482,7 @@
return tupleBuffer;
}
- public STree createSTree(final List elements, String groupName, int keyLength) {
+ public STree createSTree(final List<? extends Expression> elements, String
groupName, int keyLength) {
Long newID = this.tsId.getAndIncrement();
int[] lobIndexes = LobManager.getLobIndexes(elements);
Class<?>[] types = getTypeClasses(elements);
@@ -503,7 +503,7 @@
return new STree(keyManager, bm, new ListNestedSortComparator(compareIndexes),
getProcessorBatchSize(elements.subList(0, keyLength)), getProcessorBatchSize(elements),
keyLength, lobManager);
}
- private static Class<?>[] getTypeClasses(final List elements) {
+ private static Class<?>[] getTypeClasses(final List<? extends Expression>
elements) {
Class<?>[] types = new Class[elements.size()];
for (ListIterator<? extends Expression> i = elements.listIterator();
i.hasNext();) {
Expression expr = i.next();
Modified:
trunk/engine/src/main/java/org/teiid/dqp/internal/datamgr/ExecutionContextImpl.java
===================================================================
---
trunk/engine/src/main/java/org/teiid/dqp/internal/datamgr/ExecutionContextImpl.java 2012-03-28
15:29:46 UTC (rev 3955)
+++
trunk/engine/src/main/java/org/teiid/dqp/internal/datamgr/ExecutionContextImpl.java 2012-03-29
01:43:03 UTC (rev 3956)
@@ -259,8 +259,7 @@
@Override
public String getConnectionID() {
- // TODO Auto-generated method stub
- return null;
+ return getConnectionId();
}
@Override
Modified:
trunk/engine/src/main/java/org/teiid/dqp/internal/process/MetaDataProcessor.java
===================================================================
---
trunk/engine/src/main/java/org/teiid/dqp/internal/process/MetaDataProcessor.java 2012-03-28
15:29:46 UTC (rev 3955)
+++
trunk/engine/src/main/java/org/teiid/dqp/internal/process/MetaDataProcessor.java 2012-03-29
01:43:03 UTC (rev 3956)
@@ -58,7 +58,6 @@
import org.teiid.query.sql.lang.SPParameter;
import org.teiid.query.sql.lang.StoredProcedure;
import org.teiid.query.sql.symbol.AggregateSymbol;
-import org.teiid.query.sql.symbol.AggregateSymbol.Type;
import org.teiid.query.sql.symbol.AliasSymbol;
import org.teiid.query.sql.symbol.ElementSymbol;
import org.teiid.query.sql.symbol.Expression;
@@ -67,6 +66,7 @@
import org.teiid.query.sql.symbol.Reference;
import org.teiid.query.sql.symbol.Symbol;
import org.teiid.query.sql.symbol.WindowFunction;
+import org.teiid.query.sql.symbol.AggregateSymbol.Type;
import org.teiid.query.sql.util.SymbolMap;
import org.teiid.query.sql.visitor.ReferenceCollectorVisitor;
import org.teiid.query.tempdata.TempTableStore;
@@ -341,9 +341,9 @@
private Map createAggregateMetadata(String shortColumnName,
AggregateSymbol symbol) throws
QueryMetadataException, TeiidComponentException {
- Expression expression = symbol.getExpression();
Type function = symbol.getAggregateFunction();
if(function == Type.MIN || function == Type.MAX){
+ Expression expression = symbol.getArg(0);
if(expression instanceof ElementSymbol) {
return createColumnMetadata(shortColumnName, expression);
}
Modified: trunk/engine/src/main/java/org/teiid/query/QueryPlugin.java
===================================================================
--- trunk/engine/src/main/java/org/teiid/query/QueryPlugin.java 2012-03-28 15:29:46 UTC
(rev 3955)
+++ trunk/engine/src/main/java/org/teiid/query/QueryPlugin.java 2012-03-29 01:43:03 UTC
(rev 3956)
@@ -468,8 +468,6 @@
TEIID30422,
TEIID30423,
TEIID30424,
- TEIID30425,
- TEIID30426,
TEIID30427,
TEIID30428,
TEIID30429,
Modified:
trunk/engine/src/main/java/org/teiid/query/function/aggregate/AggregateFunction.java
===================================================================
---
trunk/engine/src/main/java/org/teiid/query/function/aggregate/AggregateFunction.java 2012-03-28
15:29:46 UTC (rev 3955)
+++
trunk/engine/src/main/java/org/teiid/query/function/aggregate/AggregateFunction.java 2012-03-29
01:43:03 UTC (rev 3956)
@@ -38,11 +38,11 @@
*/
public abstract class AggregateFunction {
- private int expressionIndex = -1;
+ protected int[] argIndexes;
private int conditionIndex = -1;
- public void setExpressionIndex(int expressionIndex) {
- this.expressionIndex = expressionIndex;
+ public void setArgIndexes(int[] argIndexes) {
+ this.argIndexes = argIndexes;
}
public void setConditionIndex(int conditionIndex) {
@@ -53,9 +53,13 @@
* Called to initialize the function. In the future this may expand
* with additional information.
* @param dataType Data type of element being aggregated
- * @param inputType
+ * @param inputTypes
*/
- public void initialize(Class<?> dataType, Class<?> inputType) {}
+ public void initialize(Class<?> dataType, Class<?>[] inputTypes) {}
+
+ public int[] getArgIndexes() {
+ return argIndexes;
+ }
/**
* Called to reset the state of the function.
@@ -66,14 +70,14 @@
if (conditionIndex != -1 && !Boolean.TRUE.equals(tuple.get(conditionIndex)))
{
return;
}
- if (expressionIndex == -1) {
- addInputDirect(null, tuple);
- return;
+ if (!respectsNull()) {
+ for (int i = 0; i < argIndexes.length; i++) {
+ if (tuple.get(argIndexes[i]) == null) {
+ return;
+ }
+ }
}
- Object input = tuple.get(expressionIndex);
- if (input != null || respectsNull()) {
- addInputDirect(input, tuple);
- }
+ addInputDirect(tuple);
}
public boolean respectsNull() {
@@ -82,11 +86,10 @@
/**
* Called for the element value in every row of a group.
- * @param input Input value, may be null
* @param tuple
* @throws TeiidProcessingException
*/
- public abstract void addInputDirect(Object input, List<?> tuple) throws
TeiidComponentException, TeiidProcessingException;
+ public abstract void addInputDirect(List<?> tuple) throws
TeiidComponentException, TeiidProcessingException;
/**
* Called after all values have been processed to get the result.
Modified: trunk/engine/src/main/java/org/teiid/query/function/aggregate/ArrayAgg.java
===================================================================
--- trunk/engine/src/main/java/org/teiid/query/function/aggregate/ArrayAgg.java 2012-03-28
15:29:46 UTC (rev 3955)
+++ trunk/engine/src/main/java/org/teiid/query/function/aggregate/ArrayAgg.java 2012-03-29
01:43:03 UTC (rev 3956)
@@ -30,7 +30,7 @@
import org.teiid.core.TeiidProcessingException;
import org.teiid.query.util.CommandContext;
-public class ArrayAgg extends AggregateFunction {
+public class ArrayAgg extends SingleArgumentAggregateFunction {
private ArrayList<Object> result;
private CommandContext context;
Modified: trunk/engine/src/main/java/org/teiid/query/function/aggregate/Avg.java
===================================================================
--- trunk/engine/src/main/java/org/teiid/query/function/aggregate/Avg.java 2012-03-28
15:29:46 UTC (rev 3955)
+++ trunk/engine/src/main/java/org/teiid/query/function/aggregate/Avg.java 2012-03-29
01:43:03 UTC (rev 3956)
@@ -60,7 +60,7 @@
}
/**
- * @see org.teiid.query.function.aggregate.AggregateFunction#addInputDirect(Object,
List)
+ * @see org.teiid.query.function.aggregate.AggregateFunction#addInputDirect(List)
*/
public void addInputDirect(Object input, List<?> tuple)
throws FunctionExecutionException, ExpressionEvaluationException,
TeiidComponentException {
Modified:
trunk/engine/src/main/java/org/teiid/query/function/aggregate/ConstantFunction.java
===================================================================
---
trunk/engine/src/main/java/org/teiid/query/function/aggregate/ConstantFunction.java 2012-03-28
15:29:46 UTC (rev 3955)
+++
trunk/engine/src/main/java/org/teiid/query/function/aggregate/ConstantFunction.java 2012-03-29
01:43:03 UTC (rev 3956)
@@ -31,7 +31,7 @@
/**
*/
-public class ConstantFunction extends AggregateFunction {
+public class ConstantFunction extends SingleArgumentAggregateFunction {
private Object value;
@@ -45,7 +45,7 @@
}
/**
- * @see org.teiid.query.function.aggregate.AggregateFunction#addInputDirect(Object,
List)
+ * @see org.teiid.query.function.aggregate.AggregateFunction#addInputDirect(List)
*/
public void addInputDirect(Object input, List<?> tuple)
throws FunctionExecutionException, ExpressionEvaluationException,
TeiidComponentException {
Modified: trunk/engine/src/main/java/org/teiid/query/function/aggregate/Count.java
===================================================================
--- trunk/engine/src/main/java/org/teiid/query/function/aggregate/Count.java 2012-03-28
15:29:46 UTC (rev 3955)
+++ trunk/engine/src/main/java/org/teiid/query/function/aggregate/Count.java 2012-03-29
01:43:03 UTC (rev 3956)
@@ -24,6 +24,9 @@
import java.util.List;
+import org.teiid.core.TeiidComponentException;
+import org.teiid.core.TeiidProcessingException;
+
/**
* Just a simple COUNT() implementation that counts every non-null row it sees.
*/
@@ -34,11 +37,10 @@
public void reset() {
count = 0;
}
-
- /**
- * @see org.teiid.query.function.aggregate.AggregateFunction#addInputDirect(Object,
List)
- */
- public void addInputDirect(Object input, List<?> tuple) {
+
+ @Override
+ public void addInputDirect(List<?> tuple)
+ throws TeiidComponentException, TeiidProcessingException {
count++;
}
Modified: trunk/engine/src/main/java/org/teiid/query/function/aggregate/Max.java
===================================================================
--- trunk/engine/src/main/java/org/teiid/query/function/aggregate/Max.java 2012-03-28
15:29:46 UTC (rev 3955)
+++ trunk/engine/src/main/java/org/teiid/query/function/aggregate/Max.java 2012-03-29
01:43:03 UTC (rev 3956)
@@ -27,12 +27,12 @@
import org.teiid.api.exception.query.ExpressionEvaluationException;
import org.teiid.api.exception.query.FunctionExecutionException;
import org.teiid.core.TeiidComponentException;
-import org.teiid.query.QueryPlugin;
+import org.teiid.query.sql.symbol.Constant;
/**
*/
-public class Max extends AggregateFunction {
+public class Max extends SingleArgumentAggregateFunction {
private Object maxValue;
@@ -41,7 +41,7 @@
}
/**
- * @see org.teiid.query.function.aggregate.AggregateFunction#addInputDirect(Object,
List)
+ * @see org.teiid.query.function.aggregate.AggregateFunction#addInputDirect(List)
*/
public void addInputDirect(Object value, List<?> tuple)
throws FunctionExecutionException, ExpressionEvaluationException,
TeiidComponentException {
@@ -49,14 +49,10 @@
if(maxValue == null) {
maxValue = value;
} else {
- if(value instanceof Comparable) {
- Comparable valueComp = (Comparable) value;
+ Comparable valueComp = (Comparable) value;
- if(valueComp.compareTo(maxValue) > 0) {
- maxValue = valueComp;
- }
- } else {
- throw new FunctionExecutionException(QueryPlugin.Event.TEIID30425,
QueryPlugin.Util.gs(QueryPlugin.Event.TEIID30425, "MAX",
value.getClass().getName()));//$NON-NLS-1$
+ if (Constant.COMPARATOR.compare(valueComp, maxValue) > 0) {
+ maxValue = valueComp;
}
}
}
Modified: trunk/engine/src/main/java/org/teiid/query/function/aggregate/Min.java
===================================================================
--- trunk/engine/src/main/java/org/teiid/query/function/aggregate/Min.java 2012-03-28
15:29:46 UTC (rev 3955)
+++ trunk/engine/src/main/java/org/teiid/query/function/aggregate/Min.java 2012-03-29
01:43:03 UTC (rev 3956)
@@ -27,12 +27,12 @@
import org.teiid.api.exception.query.ExpressionEvaluationException;
import org.teiid.api.exception.query.FunctionExecutionException;
import org.teiid.core.TeiidComponentException;
-import org.teiid.query.QueryPlugin;
+import org.teiid.query.sql.symbol.Constant;
/**
*/
-public class Min extends AggregateFunction {
+public class Min extends SingleArgumentAggregateFunction {
private Object minValue;
@@ -41,7 +41,7 @@
}
/**
- * @see org.teiid.query.function.aggregate.AggregateFunction#addInputDirect(Object,
List)
+ * @see org.teiid.query.function.aggregate.AggregateFunction#addInputDirect(List)
*/
public void addInputDirect(Object value, List<?> tuple)
throws FunctionExecutionException, ExpressionEvaluationException,
TeiidComponentException {
@@ -49,14 +49,10 @@
if(minValue == null) {
minValue = value;
} else {
- if(value instanceof Comparable) {
- Comparable valueComp = (Comparable) value;
+ Comparable valueComp = (Comparable) value;
- if(valueComp.compareTo(minValue) < 0) {
- minValue = valueComp;
- }
- } else {
- throw new FunctionExecutionException(QueryPlugin.Event.TEIID30426,
QueryPlugin.Util.gs(QueryPlugin.Event.TEIID30426, "MIN",
value.getClass().getName())); //$NON-NLS-1$
+ if(Constant.COMPARATOR.compare(valueComp, minValue) < 0) {
+ minValue = valueComp;
}
}
}
Modified:
trunk/engine/src/main/java/org/teiid/query/function/aggregate/RankingFunction.java
===================================================================
---
trunk/engine/src/main/java/org/teiid/query/function/aggregate/RankingFunction.java 2012-03-28
15:29:46 UTC (rev 3955)
+++
trunk/engine/src/main/java/org/teiid/query/function/aggregate/RankingFunction.java 2012-03-29
01:43:03 UTC (rev 3956)
@@ -49,7 +49,7 @@
}
@Override
- public void addInputDirect(Object input, List<?> tuple)
+ public void addInputDirect(List<?> tuple)
throws FunctionExecutionException, ExpressionEvaluationException,
TeiidComponentException {
if (type == Type.RANK) {
Added:
trunk/engine/src/main/java/org/teiid/query/function/aggregate/SingleArgumentAggregateFunction.java
===================================================================
---
trunk/engine/src/main/java/org/teiid/query/function/aggregate/SingleArgumentAggregateFunction.java
(rev 0)
+++
trunk/engine/src/main/java/org/teiid/query/function/aggregate/SingleArgumentAggregateFunction.java 2012-03-29
01:43:03 UTC (rev 3956)
@@ -0,0 +1,52 @@
+/*
+ * JBoss, Home of Professional Open Source.
+ * See the COPYRIGHT.txt file distributed with this work for information
+ * regarding copyright ownership. Some portions may be licensed
+ * to Red Hat, Inc. under one or more contributor license agreements.
+ *
+ * This library is free software; you can redistribute it and/or
+ * modify it under the terms of the GNU Lesser General Public
+ * License as published by the Free Software Foundation; either
+ * version 2.1 of the License, or (at your option) any later version.
+ *
+ * This library is distributed in the hope that it will be useful,
+ * but WITHOUT ANY WARRANTY; without even the implied warranty of
+ * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
+ * Lesser General Public License for more details.
+ *
+ * You should have received a copy of the GNU Lesser General Public
+ * License along with this library; if not, write to the Free Software
+ * Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA
+ * 02110-1301 USA.
+ */
+
+package org.teiid.query.function.aggregate;
+
+import java.util.List;
+
+import org.teiid.core.TeiidComponentException;
+import org.teiid.core.TeiidProcessingException;
+
+public abstract class SingleArgumentAggregateFunction extends AggregateFunction { // adapter: maps the tuple/array-based AggregateFunction calls onto single-value overloads
+
+ @Override
+ public void addInputDirect(List<?> tuple)
+ throws TeiidComponentException, TeiidProcessingException {
+ addInputDirect(tuple.get(argIndexes[0]), tuple); // only the first argument index is read; the full tuple is still passed along
+ }
+
+ public void initialize(java.lang.Class<?> dataType, java.lang.Class<?>[]
inputTypes) { // array overload: forwards only inputTypes[0] to the single-type overload below
+ initialize(dataType, inputTypes[0]);
+ }
+
+ /**
+ * @param dataType passed through unchanged from the array overload; ignored by this empty default body
+ * @param inputType the first element of the input type array; ignored by this empty default body
+ */
+ public void initialize(java.lang.Class<?> dataType, java.lang.Class<?>
inputType) {
+
+ }
+
+ public abstract void addInputDirect(Object input, List<?> tuple) // input is tuple.get(argIndexes[0]) when reached through addInputDirect(List)
+ throws TeiidProcessingException, TeiidComponentException;
+}
Property changes on:
trunk/engine/src/main/java/org/teiid/query/function/aggregate/SingleArgumentAggregateFunction.java
___________________________________________________________________
Added: svn:mime-type
+ text/plain
Modified:
trunk/engine/src/main/java/org/teiid/query/function/aggregate/StatsFunction.java
===================================================================
---
trunk/engine/src/main/java/org/teiid/query/function/aggregate/StatsFunction.java 2012-03-28
15:29:46 UTC (rev 3955)
+++
trunk/engine/src/main/java/org/teiid/query/function/aggregate/StatsFunction.java 2012-03-29
01:43:03 UTC (rev 3956)
@@ -29,7 +29,7 @@
import org.teiid.core.TeiidComponentException;
import org.teiid.query.sql.symbol.AggregateSymbol.Type;
-public class StatsFunction extends AggregateFunction {
+public class StatsFunction extends SingleArgumentAggregateFunction {
private double sum = 0;
private double sumSq = 0;
Modified: trunk/engine/src/main/java/org/teiid/query/function/aggregate/Sum.java
===================================================================
--- trunk/engine/src/main/java/org/teiid/query/function/aggregate/Sum.java 2012-03-28
15:29:46 UTC (rev 3955)
+++ trunk/engine/src/main/java/org/teiid/query/function/aggregate/Sum.java 2012-03-29
01:43:03 UTC (rev 3956)
@@ -37,7 +37,7 @@
* of a column. The type of the result varies depending on the type
* of the input {@link AggregateSymbol}
*/
-public class Sum extends AggregateFunction {
+public class Sum extends SingleArgumentAggregateFunction {
// Various possible accumulators, depending on type
protected static final int LONG = 0;
@@ -85,7 +85,7 @@
}
/**
- * @see org.teiid.query.function.aggregate.AggregateFunction#addInputDirect(Object,
List)
+ * @see org.teiid.query.function.aggregate.AggregateFunction#addInputDirect(List)
*/
public void addInputDirect(Object input, List<?> tuple)
throws FunctionExecutionException, ExpressionEvaluationException,
TeiidComponentException {
Modified: trunk/engine/src/main/java/org/teiid/query/function/aggregate/TextAgg.java
===================================================================
--- trunk/engine/src/main/java/org/teiid/query/function/aggregate/TextAgg.java 2012-03-28
15:29:46 UTC (rev 3955)
+++ trunk/engine/src/main/java/org/teiid/query/function/aggregate/TextAgg.java 2012-03-29
01:43:03 UTC (rev 3956)
@@ -31,8 +31,8 @@
import javax.sql.rowset.serial.SerialBlob;
import org.teiid.common.buffer.FileStore;
+import org.teiid.common.buffer.FileStoreInputStreamFactory;
import org.teiid.common.buffer.FileStore.FileStoreOutputStream;
-import org.teiid.common.buffer.FileStoreInputStreamFactory;
import org.teiid.core.TeiidComponentException;
import org.teiid.core.TeiidProcessingException;
import org.teiid.core.types.BlobImpl;
@@ -47,7 +47,7 @@
/**
* Aggregates Text entries
*/
-public class TextAgg extends AggregateFunction {
+public class TextAgg extends SingleArgumentAggregateFunction {
private FileStoreInputStreamFactory result;
private CommandContext context;
@@ -87,7 +87,7 @@
/**
* @throws TeiidProcessingException
* @throws TeiidComponentException
- * @see org.teiid.query.function.aggregate.AggregateFunction#addInputDirect(Object,
List)
+ * @see org.teiid.query.function.aggregate.AggregateFunction#addInputDirect(List)
*/
public void addInputDirect(Object input, List<?> tuple) throws
TeiidComponentException, TeiidProcessingException {
try {
Modified: trunk/engine/src/main/java/org/teiid/query/function/aggregate/XMLAgg.java
===================================================================
--- trunk/engine/src/main/java/org/teiid/query/function/aggregate/XMLAgg.java 2012-03-28
15:29:46 UTC (rev 3955)
+++ trunk/engine/src/main/java/org/teiid/query/function/aggregate/XMLAgg.java 2012-03-29
01:43:03 UTC (rev 3956)
@@ -33,7 +33,7 @@
/**
* Aggregates XML entries
*/
-public class XMLAgg extends AggregateFunction {
+public class XMLAgg extends SingleArgumentAggregateFunction {
private XMLType result;
private XmlConcat concat;
@@ -51,7 +51,7 @@
/**
* @throws TeiidProcessingException
* @throws TeiidComponentException
- * @see org.teiid.query.function.aggregate.AggregateFunction#addInputDirect(Object,
List)
+ * @see org.teiid.query.function.aggregate.AggregateFunction#addInputDirect(List)
*/
public void addInputDirect(Object input, List<?> tuple) throws
TeiidComponentException, TeiidProcessingException {
if (concat == null) {
Modified:
trunk/engine/src/main/java/org/teiid/query/optimizer/relational/rules/CapabilitiesUtil.java
===================================================================
---
trunk/engine/src/main/java/org/teiid/query/optimizer/relational/rules/CapabilitiesUtil.java 2012-03-28
15:29:46 UTC (rev 3955)
+++
trunk/engine/src/main/java/org/teiid/query/optimizer/relational/rules/CapabilitiesUtil.java 2012-03-29
01:43:03 UTC (rev 3956)
@@ -151,7 +151,7 @@
Type func = aggregate.getAggregateFunction();
switch (func) {
case COUNT:
- if(aggregate.getExpression() == null) {
+ if(aggregate.getArgs().length == 0) {
if(! caps.supportsCapability(Capability.QUERY_AGGREGATES_COUNT_STAR)) {
return false;
}
Modified:
trunk/engine/src/main/java/org/teiid/query/optimizer/relational/rules/RuleAssignOutputElements.java
===================================================================
---
trunk/engine/src/main/java/org/teiid/query/optimizer/relational/rules/RuleAssignOutputElements.java 2012-03-28
15:29:46 UTC (rev 3955)
+++
trunk/engine/src/main/java/org/teiid/query/optimizer/relational/rules/RuleAssignOutputElements.java 2012-03-29
01:43:03 UTC (rev 3956)
@@ -580,9 +580,9 @@
Expression ex = symbolMap.getMappedExpression((ElementSymbol) outputSymbol);
if(ex instanceof AggregateSymbol) {
AggregateSymbol agg = (AggregateSymbol)ex;
- Expression aggExpr = agg.getExpression();
- if(aggExpr != null) {
- ElementCollectorVisitor.getElements(aggExpr, requiredSymbols);
+ Expression[] aggExprs = agg.getArgs();
+ for (Expression expression : aggExprs) {
+ ElementCollectorVisitor.getElements(expression, requiredSymbols);
}
OrderBy orderBy = agg.getOrderBy();
if(orderBy != null) {
Modified:
trunk/engine/src/main/java/org/teiid/query/optimizer/relational/rules/RulePushAggregates.java
===================================================================
---
trunk/engine/src/main/java/org/teiid/query/optimizer/relational/rules/RulePushAggregates.java 2012-03-28
15:29:46 UTC (rev 3955)
+++
trunk/engine/src/main/java/org/teiid/query/optimizer/relational/rules/RulePushAggregates.java 2012-03-29
01:43:03 UTC (rev 3956)
@@ -22,19 +22,7 @@
package org.teiid.query.optimizer.relational.rules;
-import java.util.ArrayList;
-import java.util.Arrays;
-import java.util.Collection;
-import java.util.Collections;
-import java.util.HashMap;
-import java.util.HashSet;
-import java.util.Iterator;
-import java.util.LinkedHashMap;
-import java.util.LinkedHashSet;
-import java.util.LinkedList;
-import java.util.List;
-import java.util.Map;
-import java.util.Set;
+import java.util.*;
import org.teiid.api.exception.query.QueryMetadataException;
import org.teiid.api.exception.query.QueryPlannerException;
@@ -55,10 +43,10 @@
import org.teiid.query.optimizer.relational.RelationalPlanner;
import org.teiid.query.optimizer.relational.RuleStack;
import org.teiid.query.optimizer.relational.plantree.NodeConstants;
-import org.teiid.query.optimizer.relational.plantree.NodeConstants.Info;
import org.teiid.query.optimizer.relational.plantree.NodeEditor;
import org.teiid.query.optimizer.relational.plantree.NodeFactory;
import org.teiid.query.optimizer.relational.plantree.PlanNode;
+import org.teiid.query.optimizer.relational.plantree.NodeConstants.Info;
import org.teiid.query.resolver.util.ResolverUtil;
import org.teiid.query.resolver.util.ResolverVisitor;
import org.teiid.query.rewriter.QueryRewriter;
@@ -70,17 +58,8 @@
import org.teiid.query.sql.lang.OrderBy;
import org.teiid.query.sql.lang.Select;
import org.teiid.query.sql.lang.SetQuery.Operation;
-import org.teiid.query.sql.symbol.AggregateSymbol;
+import org.teiid.query.sql.symbol.*;
import org.teiid.query.sql.symbol.AggregateSymbol.Type;
-import org.teiid.query.sql.symbol.AliasSymbol;
-import org.teiid.query.sql.symbol.Constant;
-import org.teiid.query.sql.symbol.ElementSymbol;
-import org.teiid.query.sql.symbol.Expression;
-import org.teiid.query.sql.symbol.ExpressionSymbol;
-import org.teiid.query.sql.symbol.Function;
-import org.teiid.query.sql.symbol.GroupSymbol;
-import org.teiid.query.sql.symbol.SearchedCaseExpression;
-import org.teiid.query.sql.symbol.Symbol;
import org.teiid.query.sql.util.SymbolMap;
import org.teiid.query.sql.visitor.AggregateSymbolCollectorVisitor;
import org.teiid.query.sql.visitor.ElementCollectorVisitor;
@@ -420,16 +399,17 @@
for (AggregateSymbol agg : aggregates) {
agg = (AggregateSymbol)agg.clone();
if (agg.getAggregateFunction() == Type.COUNT) {
- if (agg.getExpression() == null) {
+ if (agg.getArgs().length == 0) {
allSymbols.addSymbol(new ExpressionSymbol("stagedAgg", new
Constant(1))); //$NON-NLS-1$
} else {
- SearchedCaseExpression count = new SearchedCaseExpression(Arrays.asList(new
IsNullCriteria(agg.getExpression())), Arrays.asList(new Constant(Integer.valueOf(0))));
+ SearchedCaseExpression count = new SearchedCaseExpression(Arrays.asList(new
IsNullCriteria(agg.getArg(0))), Arrays.asList(new Constant(Integer.valueOf(0))));
count.setElseExpression(new Constant(Integer.valueOf(1)));
count.setType(DataTypeManager.DefaultDataClasses.INTEGER);
allSymbols.addSymbol(new ExpressionSymbol("stagedAgg", count));
//$NON-NLS-1$
}
} else { //min, max, sum
- Expression ex = agg.getExpression();
+ assert agg.getArgs().length == 1; //prior canStage should ensure this is true
+ Expression ex = agg.getArg(0);
ex = ResolverUtil.convertExpression(ex,
DataTypeManager.getDataTypeName(agg.getType()), metadata);
allSymbols.addSymbol(new ExpressionSymbol("stagedAgg", ex));
//$NON-NLS-1$
}
@@ -669,7 +649,7 @@
if (stagedGroupingSymbols.isEmpty()) {
// if the source has no rows we need to insert a select node with criteria
count(*)>0
PlanNode selectNode = NodeFactory.getNewNode(NodeConstants.Types.SELECT);
- AggregateSymbol count = new AggregateSymbol(NonReserved.COUNT, false, null);
//$NON-NLS-1$
+ AggregateSymbol count = new AggregateSymbol(NonReserved.COUNT, false, null);
aggregates.add(count); //consider the count aggregate for the push down call below
selectNode.setProperty(NodeConstants.Info.SELECT_CRITERIA, new
CompareCriteria(count, CompareCriteria.GT,
new
Constant(new Integer(0))));
@@ -702,10 +682,10 @@
//remove any aggregates that are computed over a group by column
for (final Iterator<AggregateSymbol> iterator = aggregates.iterator();
iterator.hasNext();) {
final AggregateSymbol symbol = iterator.next();
- Expression expr = symbol.getExpression();
- if (expr == null) {
- continue;
+ if (symbol.getArgs().length != 1) {
+ continue;
}
+ Expression expr = symbol.getArg(0);
if (stagedGroupingSymbols.contains(expr)) {
iterator.remove();
}
@@ -796,12 +776,12 @@
return result;
}
for (T aggregateSymbol : expressions) {
- if (aggs && ((AggregateSymbol)aggregateSymbol).getExpression() == null)
{
- return null; //count(*) is not yet handled. a general approach would be
count(*) => count(r.col) * count(l.col), but the logic here assumes a simpler initial
mapping
+ if (aggs) {
+ AggregateSymbol as = (AggregateSymbol)aggregateSymbol;
+ if ((!as.canStage() && as.isCardinalityDependent()) ||
(as.getAggregateFunction() == Type.COUNT && as.getArgs().length == 0)) {
+ return null; //count(*) is not yet handled. a general approach would be
count(*) => count(r.col) * count(l.col), but the logic here assumes a simpler initial
mapping
+ }
}
- if (aggs && !((AggregateSymbol)aggregateSymbol).canStage()) {
- continue;
- }
Set<GroupSymbol> groups =
GroupsUsedByElementsVisitor.getGroups(aggregateSymbol);
if (groups.isEmpty()) {
continue;
@@ -862,7 +842,7 @@
Type aggFunction = partitionAgg.getAggregateFunction();
if (aggFunction == Type.COUNT) {
//COUNT(x) -> CONVERT(SUM(COUNT(x)), INTEGER)
- AggregateSymbol newAgg = new AggregateSymbol(NonReserved.SUM, false,
partitionAgg); //$NON-NLS-1$
+ AggregateSymbol newAgg = new AggregateSymbol(NonReserved.SUM, false,
partitionAgg);
// Build conversion function to convert SUM (which returns LONG) back to
INTEGER
Function convertFunc = new Function(FunctionLibrary.CONVERT, new
Expression[] {newAgg, new
Constant(DataTypeManager.getDataTypeName(partitionAgg.getType()))});
ResolverVisitor.resolveLanguageObject(convertFunc, metadata);
@@ -871,11 +851,11 @@
nestedAggregates.add(partitionAgg);
} else if (aggFunction == Type.AVG) {
//AVG(x) -> SUM(SUM(x)) / SUM(COUNT(x))
- AggregateSymbol countAgg = new AggregateSymbol(NonReserved.COUNT, false,
partitionAgg.getExpression()); //$NON-NLS-1$
- AggregateSymbol sumAgg = new AggregateSymbol(NonReserved.SUM, false,
partitionAgg.getExpression()); //$NON-NLS-1$
+ AggregateSymbol countAgg = new AggregateSymbol(NonReserved.COUNT, false,
partitionAgg.getArg(0));
+ AggregateSymbol sumAgg = new AggregateSymbol(NonReserved.SUM, false,
partitionAgg.getArg(0));
- AggregateSymbol sumSumAgg = new AggregateSymbol(NonReserved.SUM, false,
sumAgg); //$NON-NLS-1$
- AggregateSymbol sumCountAgg = new AggregateSymbol(NonReserved.SUM, false,
countAgg); //$NON-NLS-1$
+ AggregateSymbol sumSumAgg = new AggregateSymbol(NonReserved.SUM, false,
sumAgg);
+ AggregateSymbol sumCountAgg = new AggregateSymbol(NonReserved.SUM, false,
countAgg);
Expression convertedSum = new Function(FunctionLibrary.CONVERT, new
Expression[] {sumSumAgg, new
Constant(DataTypeManager.getDataTypeName(partitionAgg.getType()))});
Expression convertCount = new Function(FunctionLibrary.CONVERT, new
Expression[] {sumCountAgg, new
Constant(DataTypeManager.getDataTypeName(partitionAgg.getType()))});
@@ -888,13 +868,13 @@
nestedAggregates.add(sumAgg);
} else if (partitionAgg.isEnhancedNumeric()) {
//e.g. STDDEV_SAMP := CASE WHEN COUNT(X) > 1 THEN SQRT((SUM(X^2) -
SUM(X)^2/COUNT(X))/(COUNT(X) - 1))
- AggregateSymbol countAgg = new AggregateSymbol(NonReserved.COUNT, false,
partitionAgg.getExpression()); //$NON-NLS-1$
- AggregateSymbol sumAgg = new AggregateSymbol(NonReserved.SUM, false,
partitionAgg.getExpression()); //$NON-NLS-1$
- AggregateSymbol sumSqAgg = new AggregateSymbol(NonReserved.SUM, false,
new Function(SourceSystemFunctions.POWER, new Expression[] {partitionAgg.getExpression(),
new Constant(2)})); //$NON-NLS-1$
+ AggregateSymbol countAgg = new AggregateSymbol(NonReserved.COUNT, false,
partitionAgg.getArg(0));
+ AggregateSymbol sumAgg = new AggregateSymbol(NonReserved.SUM, false,
partitionAgg.getArg(0));
+ AggregateSymbol sumSqAgg = new AggregateSymbol(NonReserved.SUM, false,
new Function(SourceSystemFunctions.POWER, new Expression[] {partitionAgg.getArg(0), new
Constant(2)}));
- AggregateSymbol sumSumAgg = new AggregateSymbol(NonReserved.SUM, false,
sumAgg); //$NON-NLS-1$
- AggregateSymbol sumCountAgg = new AggregateSymbol(NonReserved.SUM, false,
countAgg); //$NON-NLS-1$
- AggregateSymbol sumSumSqAgg = new AggregateSymbol(NonReserved.SUM, false,
sumSqAgg); //$NON-NLS-1$
+ AggregateSymbol sumSumAgg = new AggregateSymbol(NonReserved.SUM, false,
sumAgg);
+ AggregateSymbol sumCountAgg = new AggregateSymbol(NonReserved.SUM, false,
countAgg);
+ AggregateSymbol sumSumSqAgg = new AggregateSymbol(NonReserved.SUM, false,
sumSqAgg);
Expression convertedSum = new Function(FunctionLibrary.CONVERT, new
Expression[] {sumSumAgg, new Constant(DataTypeManager.DefaultDataTypes.DOUBLE)});
@@ -926,7 +906,7 @@
nestedAggregates.add(sumSqAgg);
} else {
//AGG(X) -> AGG(AGG(X))
- newExpression = new AggregateSymbol(aggFunction.name(), false,
partitionAgg); //$NON-NLS-1$
+ newExpression = new AggregateSymbol(aggFunction.name(), false,
partitionAgg);
nestedAggregates.add(partitionAgg);
}
Modified:
trunk/engine/src/main/java/org/teiid/query/processor/relational/GroupingNode.java
===================================================================
---
trunk/engine/src/main/java/org/teiid/query/processor/relational/GroupingNode.java 2012-03-28
15:29:46 UTC (rev 3955)
+++
trunk/engine/src/main/java/org/teiid/query/processor/relational/GroupingNode.java 2012-03-29
01:43:03 UTC (rev 3956)
@@ -26,6 +26,7 @@
import java.util.ArrayList;
import java.util.Arrays;
+import java.util.LinkedHashMap;
import java.util.List;
import java.util.ListIterator;
import java.util.Map;
@@ -93,7 +94,7 @@
// Collection phase
private int phase = COLLECTION;
private Map elementMap; // Map of incoming symbol to index in
source elements
- private List<Expression> collectedExpressions; // Collected
Expressions
+ private LinkedHashMap<Expression, Integer> collectedExpressions; //
Collected Expressions
private int distinctCols = -1;
// Sort phase
@@ -158,22 +159,19 @@
// Incoming elements and lookup map for evaluating expressions
List<? extends Expression> sourceElements =
this.getChildren()[0].getElements();
this.elementMap = createLookupMap(sourceElements);
-
+ this.collectedExpressions = new LinkedHashMap<Expression, Integer>();
// List should contain all grouping columns / expressions as we need those for
sorting
if(this.orderBy != null) {
- this.collectedExpressions = new
ArrayList<Expression>(this.orderBy.size() + getElements().size());
for (OrderByItem item : this.orderBy) {
Expression ex = SymbolMap.getExpression(item.getSymbol());
- this.collectedExpressions.add(ex);
+ getIndex(ex, this.collectedExpressions);
}
if (removeDuplicates) {
for (Expression ses : sourceElements) {
- collectExpression(SymbolMap.getExpression(ses));
+ getIndex(ses, collectedExpressions);
}
distinctCols = collectedExpressions.size();
}
- } else {
- this.collectedExpressions = new
ArrayList<Expression>(getElements().size());
}
// Construct aggregate function state accumulators
@@ -184,100 +182,117 @@
symbol = outputMapping.getMappedExpression((ElementSymbol)symbol);
}
Class<?> outputType = symbol.getType();
- Class<?> inputType = symbol.getType();
if(symbol instanceof AggregateSymbol) {
- AggregateSymbol aggSymbol = (AggregateSymbol) symbol;
- if(aggSymbol.getExpression() == null) {
- functions[i] = new Count();
- } else {
- Expression ex = aggSymbol.getExpression();
- inputType = ex.getType();
- int index = collectExpression(ex);
- Type function = aggSymbol.getAggregateFunction();
- switch (function) {
- case COUNT:
- functions[i] = new Count();
- break;
- case SUM:
- functions[i] = new Sum();
- break;
- case AVG:
- functions[i] = new Avg();
- break;
- case MIN:
- functions[i] = new Min();
- break;
- case MAX:
- functions[i] = new Max();
- break;
- case XMLAGG:
- functions[i] = new XMLAgg(context);
- break;
- case ARRAY_AGG:
- functions[i] = new ArrayAgg(context);
- break;
- case TEXTAGG:
- functions[i] = new TextAgg(context, (TextLine)ex);
- break;
- default:
- functions[i] = new StatsFunction(function);
- }
-
- if(aggSymbol.isDistinct()) {
- functions[i] = handleDistinct(functions[i], inputType,
getBufferManager(), getConnectionID());
- } else if (aggSymbol.getOrderBy() != null) { //handle the xmlagg
case
- int[] orderIndecies = new
int[aggSymbol.getOrderBy().getOrderByItems().size()];
- List<OrderByItem> orderByItems = new
ArrayList<OrderByItem>(orderIndecies.length);
- List<ElementSymbol> schema = new
ArrayList<ElementSymbol>(orderIndecies.length + 1);
- ElementSymbol element = new ElementSymbol("val");
//$NON-NLS-1$
- element.setType(inputType);
- schema.add(element);
- for (ListIterator<OrderByItem> iterator =
aggSymbol.getOrderBy().getOrderByItems().listIterator(); iterator.hasNext();) {
- OrderByItem item = iterator.next();
- orderIndecies[iterator.previousIndex()] =
collectExpression(item.getSymbol());
- element = new
ElementSymbol(String.valueOf(iterator.previousIndex()));
- element.setType(item.getSymbol().getType());
- schema.add(element);
- OrderByItem newItem = item.clone();
- newItem.setSymbol(element);
- orderByItems.add(newItem);
- }
- SortingFilter filter = new SortingFilter(functions[i],
getBufferManager(), getConnectionID(), false);
- filter.setIndecies(orderIndecies);
- filter.setElements(schema);
- filter.setSortItems(orderByItems);
- functions[i] = filter;
- }
- functions[i].setExpressionIndex(index);
- }
- if (aggSymbol.getCondition() != null) {
-
functions[i].setConditionIndex(collectExpression(aggSymbol.getCondition()));
- }
+ AggregateSymbol aggSymbol = (AggregateSymbol) symbol;
+ functions[i] = initAccumulator(context, aggSymbol, this,
this.collectedExpressions);
} else {
functions[i] = new ConstantFunction();
-
functions[i].setExpressionIndex(this.collectedExpressions.indexOf(symbol));
+ functions[i].setArgIndexes(new int[]
{this.collectedExpressions.get(symbol)});
+ functions[i].initialize(outputType, new
Class<?>[]{symbol.getType()});
}
- functions[i].initialize(outputType, inputType);
}
}
+
+ static Integer getIndex(Expression ex, LinkedHashMap<Expression, Integer>
expressionIndexes) {
+ Integer index = expressionIndexes.get(ex);
+ if (index == null) {
+ index = expressionIndexes.size();
+ expressionIndexes.put(ex, index);
+ }
+ return index;
+ }
- static SortingFilter handleDistinct(AggregateFunction af, Class<?> inputType,
BufferManager bm, String cid) {
- SortingFilter filter = new SortingFilter(af, bm, cid, true);
- ElementSymbol element = new ElementSymbol("val"); //$NON-NLS-1$
- element.setType(inputType);
- filter.setElements(Arrays.asList(element));
- return filter;
+ static AggregateFunction initAccumulator(CommandContext context,
+ AggregateSymbol aggSymbol, RelationalNode node, LinkedHashMap<Expression,
Integer> expressionIndexes) {
+ int[] argIndexes = new int[aggSymbol.getArgs().length];
+ AggregateFunction result = null;
+ Expression[] args = aggSymbol.getArgs();
+ Class<?>[] inputTypes = new Class[args.length];
+ for (int j = 0; j < args.length; j++) {
+ inputTypes[j] = args[j].getType();
+ argIndexes[j] = getIndex(args[j], expressionIndexes);
+ }
+ Type function = aggSymbol.getAggregateFunction();
+ switch (function) {
+ case RANK:
+ case DENSE_RANK:
+ result = new RankingFunction(function);
+ break;
+ case ROW_NUMBER: //same as count(*)
+ case COUNT:
+ result = new Count();
+ break;
+ case SUM:
+ result = new Sum();
+ break;
+ case AVG:
+ result = new Avg();
+ break;
+ case MIN:
+ result = new Min();
+ break;
+ case MAX:
+ result = new Max();
+ break;
+ case XMLAGG:
+ result = new XMLAgg(context);
+ break;
+ case ARRAY_AGG:
+ result = new ArrayAgg(context);
+ break;
+ case TEXTAGG:
+ result = new TextAgg(context, (TextLine)args[0]);
+ break;
+ default:
+ result = new StatsFunction(function);
+ }
+ if(aggSymbol.isDistinct()) {
+ SortingFilter filter = new SortingFilter(result, node.getBufferManager(),
node.getConnectionID(), true);
+ List<ElementSymbol> elements = createSortSchema(result, inputTypes);
+ filter.setElements(elements);
+ result = filter;
+ } else if (aggSymbol.getOrderBy() != null) {
+ int numOrderByItems = aggSymbol.getOrderBy().getOrderByItems().size();
+ List<OrderByItem> orderByItems = new
ArrayList<OrderByItem>(numOrderByItems);
+ List<ElementSymbol> schema = createSortSchema(result, inputTypes);
+ argIndexes = Arrays.copyOf(argIndexes, argIndexes.length + numOrderByItems);
+ for (ListIterator<OrderByItem> iterator =
aggSymbol.getOrderBy().getOrderByItems().listIterator(); iterator.hasNext();) {
+ OrderByItem item = iterator.next();
+ argIndexes[args.length + iterator.previousIndex()] = getIndex(item.getSymbol(),
expressionIndexes);
+ ElementSymbol element = new ElementSymbol(String.valueOf(iterator.previousIndex()));
+ element.setType(item.getSymbol().getType());
+ schema.add(element);
+ OrderByItem newItem = item.clone();
+ newItem.setSymbol(element);
+ orderByItems.add(newItem);
+ }
+ SortingFilter filter = new SortingFilter(result, node.getBufferManager(),
node.getConnectionID(), false);
+ filter.setElements(schema);
+ filter.setSortItems(orderByItems);
+ result = filter;
+ }
+ result.setArgIndexes(argIndexes);
+ if (aggSymbol.getCondition() != null) {
+ result.setConditionIndex(getIndex(aggSymbol.getCondition(), expressionIndexes));
+ }
+ result.initialize(aggSymbol.getType(), inputTypes);
+ return result;
}
- private int collectExpression(Expression ex) {
- int index = this.collectedExpressions.indexOf(ex);
- if(index == -1) {
- index = this.collectedExpressions.size();
- this.collectedExpressions.add(ex);
+ private static List<ElementSymbol> createSortSchema(AggregateFunction af,
+ Class<?>[] inputTypes) {
+ List<ElementSymbol> elements = new
ArrayList<ElementSymbol>(inputTypes.length);
+ int[] filteredArgIndexes = new int[inputTypes.length];
+ for (int i = 0; i < inputTypes.length; i++) {
+ ElementSymbol element = new ElementSymbol("val" + i); //$NON-NLS-1$
+ element.setType(inputTypes[i]);
+ elements.add(element);
+ filteredArgIndexes[i] = i;
}
- return index;
- }
-
+ af.setArgIndexes(filteredArgIndexes);
+ return elements;
+ }
+
AggregateFunction[] getFunctions() {
return functions;
}
@@ -306,7 +321,7 @@
public TupleSource getCollectionTupleSource() {
final RelationalNode sourceNode = this.getChildren()[0];
- return new ProjectingTupleSource(sourceNode, eval, collectedExpressions);
+ return new ProjectingTupleSource(sourceNode, eval, new
ArrayList<Expression>(collectedExpressions.keySet()));
}
private void collectionPhase() {
@@ -337,7 +352,7 @@
}
this.indexes = Arrays.copyOf(sortIndexes, orderBy.size());
this.sortUtility = new SortUtility(getCollectionTupleSource(),
removeDuplicates?Mode.DUP_REMOVE_SORT:Mode.SORT, getBufferManager(),
- getConnectionID(), collectedExpressions, sortTypes, nullOrdering,
sortIndexes);
+ getConnectionID(), new
ArrayList<Expression>(collectedExpressions.keySet()), sortTypes, nullOrdering,
sortIndexes);
this.phase = SORT;
}
}
Modified:
trunk/engine/src/main/java/org/teiid/query/processor/relational/SortingFilter.java
===================================================================
---
trunk/engine/src/main/java/org/teiid/query/processor/relational/SortingFilter.java 2012-03-28
15:29:46 UTC (rev 3955)
+++
trunk/engine/src/main/java/org/teiid/query/processor/relational/SortingFilter.java 2012-03-29
01:43:03 UTC (rev 3956)
@@ -25,8 +25,6 @@
import java.util.ArrayList;
import java.util.List;
-import org.teiid.api.exception.query.ExpressionEvaluationException;
-import org.teiid.api.exception.query.FunctionExecutionException;
import org.teiid.common.buffer.BufferManager;
import org.teiid.common.buffer.TupleBuffer;
import org.teiid.common.buffer.TupleSource;
@@ -36,12 +34,12 @@
import org.teiid.query.function.aggregate.AggregateFunction;
import org.teiid.query.processor.relational.SortUtility.Mode;
import org.teiid.query.sql.lang.OrderByItem;
+import org.teiid.query.sql.symbol.ElementSymbol;
/**
*/
public class SortingFilter extends AggregateFunction {
- private static final int[] NO_INDECIES = new int[0];
// Initial setup - can be reused
private AggregateFunction proxy;
private BufferManager mgr;
@@ -49,11 +47,9 @@
private boolean removeDuplicates;
// Derived and static - can be reused
- private List elements;
+ private List<ElementSymbol> elements;
private List<OrderByItem> sortItems;
- private int[] indecies = NO_INDECIES;
-
// Temporary state - should be reset
private TupleBuffer collectionBuffer;
private SortUtility sortUtility;
@@ -70,27 +66,21 @@
this.removeDuplicates = removeDuplicates;
}
- public List getElements() {
+ public List<ElementSymbol> getElements() {
return elements;
}
- public void setElements(List elements) {
+ public void setElements(List<ElementSymbol> elements) {
this.elements = elements;
}
- public void setIndecies(int[] indecies) {
- this.indecies = indecies;
- }
-
public void setSortItems(List<OrderByItem> sortItems) {
this.sortItems = sortItems;
}
- /**
- * @see org.teiid.query.function.aggregate.AggregateFunction#initialize(String,
Class)
- */
- public void initialize(Class<?> dataType, Class<?> inputType) {
- this.proxy.initialize(dataType, inputType);
+ @Override
+ public void initialize(java.lang.Class<?> dataType, java.lang.Class<?>[]
inputTypes) {
+ this.proxy.initialize(dataType, inputTypes);
}
public void reset() {
@@ -107,17 +97,16 @@
}
@Override
- public void addInputDirect(Object input, List<?> tuple)
- throws FunctionExecutionException, ExpressionEvaluationException,
- TeiidComponentException, TeiidProcessingException {
+ public void addInputDirect(List<?> tuple)
+ throws TeiidComponentException, TeiidProcessingException {
if(collectionBuffer == null) {
collectionBuffer = mgr.createTupleBuffer(elements, groupName,
TupleSourceType.PROCESSOR);
collectionBuffer.setForwardOnly(true);
}
- List<Object> row = new ArrayList<Object>(1 + indecies.length);
- row.add(input);
- for (int i = 0; i < indecies.length; i++) {
- row.add(tuple.get(indecies[i]));
+ List<Object> row = new ArrayList<Object>(argIndexes.length);
+ //TODO remove overlap
+ for (int i = 0; i < argIndexes.length; i++) {
+ row.add(tuple.get(argIndexes[i]));
}
this.collectionBuffer.addTuple(row);
}
@@ -142,11 +131,12 @@
// Add all input to proxy
TupleSource sortedSource = sorted.createIndexedTupleSource();
while(true) {
- List tuple = sortedSource.nextTuple();
+ List<?> tuple = sortedSource.nextTuple();
if(tuple == null) {
break;
}
- this.proxy.addInputDirect(tuple.get(0), null);
+ //TODO should possibly remove the order by columns from this tuple
+ this.proxy.addInputDirect(tuple);
}
} finally {
sorted.remove();
Modified:
trunk/engine/src/main/java/org/teiid/query/processor/relational/WindowFunctionProjectNode.java
===================================================================
---
trunk/engine/src/main/java/org/teiid/query/processor/relational/WindowFunctionProjectNode.java 2012-03-28
15:29:46 UTC (rev 3955)
+++
trunk/engine/src/main/java/org/teiid/query/processor/relational/WindowFunctionProjectNode.java 2012-03-29
01:43:03 UTC (rev 3956)
@@ -45,25 +45,13 @@
import org.teiid.language.SortSpecification.NullOrdering;
import org.teiid.query.eval.Evaluator;
import org.teiid.query.function.aggregate.AggregateFunction;
-import org.teiid.query.function.aggregate.ArrayAgg;
-import org.teiid.query.function.aggregate.Avg;
-import org.teiid.query.function.aggregate.Count;
-import org.teiid.query.function.aggregate.Max;
-import org.teiid.query.function.aggregate.Min;
-import org.teiid.query.function.aggregate.RankingFunction;
-import org.teiid.query.function.aggregate.StatsFunction;
-import org.teiid.query.function.aggregate.Sum;
-import org.teiid.query.function.aggregate.TextAgg;
-import org.teiid.query.function.aggregate.XMLAgg;
import org.teiid.query.processor.ProcessorDataManager;
import org.teiid.query.processor.relational.GroupingNode.ProjectingTupleSource;
import org.teiid.query.processor.relational.SortUtility.Mode;
import org.teiid.query.sql.lang.OrderBy;
import org.teiid.query.sql.lang.OrderByItem;
-import org.teiid.query.sql.symbol.AggregateSymbol;
import org.teiid.query.sql.symbol.ElementSymbol;
import org.teiid.query.sql.symbol.Expression;
-import org.teiid.query.sql.symbol.TextLine;
import org.teiid.query.sql.symbol.WindowFunction;
import org.teiid.query.sql.symbol.WindowSpecification;
import org.teiid.query.sql.symbol.AggregateSymbol.Type;
@@ -83,8 +71,6 @@
private static class WindowFunctionInfo {
WindowFunction function;
- int expressionIndex = -1;
- int conditionIndex = -1;
int outputIndex;
}
@@ -99,9 +85,9 @@
private LinkedHashMap<WindowSpecification, WindowSpecificationInfo> windows = new
LinkedHashMap<WindowSpecification, WindowSpecificationInfo>();
private LinkedHashMap<Expression, Integer> expressionIndexes;
- private LinkedHashMap<Integer, Integer> passThrough = new
LinkedHashMap<Integer, Integer>();
+ private List<int[]> passThrough = new ArrayList<int[]>();
- private Map elementMap;
+ private Map<Expression, Integer> elementMap;
//processing state
private Phase phase = Phase.COLLECT;
@@ -171,7 +157,7 @@
public void init() {
expressionIndexes = new LinkedHashMap<Expression, Integer>();
for (int i = 0; i < getElements().size(); i++) {
- Expression ex = SymbolMap.getExpression((Expression) getElements().get(i));
+ Expression ex = SymbolMap.getExpression(getElements().get(i));
if (ex instanceof WindowFunction) {
WindowFunction wf = (WindowFunction)ex;
WindowSpecification ws = wf.getWindowSpecification();
@@ -181,7 +167,7 @@
windows.put(wf.getWindowSpecification(), wsi);
if (ws.getPartition() != null) {
for (Expression ex1 : ws.getPartition()) {
- Integer index = getIndex(ex1);
+ Integer index = GroupingNode.getIndex(ex1, expressionIndexes);
wsi.groupIndexes.add(index);
wsi.orderType.add(OrderBy.ASC);
wsi.nullOrderings.add(null);
@@ -190,7 +176,7 @@
if (ws.getOrderBy() != null) {
for (OrderByItem item : ws.getOrderBy().getOrderByItems()) {
Expression ex1 = SymbolMap.getExpression(item.getSymbol());
- Integer index = getIndex(ex1);
+ Integer index = GroupingNode.getIndex(ex1, expressionIndexes);
wsi.sortIndexes.add(index);
wsi.orderType.add(item.isAscending());
wsi.nullOrderings.add(item.getNullOrdering());
@@ -199,13 +185,17 @@
}
WindowFunctionInfo wfi = new WindowFunctionInfo();
wfi.function = wf;
- ex = wf.getFunction().getExpression();
- if (ex != null) {
- wfi.expressionIndex = getIndex(ex);
+ //collect the agg expressions
+ for (Expression e : wf.getFunction().getArgs()) {
+ GroupingNode.getIndex(e, expressionIndexes);
}
+ if (wf.getFunction().getOrderBy() != null) {
+ for (OrderByItem item : wf.getFunction().getOrderBy().getOrderByItems()) {
+ GroupingNode.getIndex(item.getSymbol(), expressionIndexes);
+ }
+ }
if (wf.getFunction().getCondition() != null) {
- ex = wf.getFunction().getCondition();
- wfi.conditionIndex = getIndex(ex);
+ GroupingNode.getIndex(wf.getFunction().getCondition(), expressionIndexes);
}
wfi.outputIndex = i;
if (wf.getFunction().getAggregateFunction() == Type.ROW_NUMBER) {
@@ -214,8 +204,8 @@
wsi.functions.add(wfi);
}
} else {
- int index = getIndex(ex);
- passThrough.put(i, index);
+ int index = GroupingNode.getIndex(ex, expressionIndexes);
+ passThrough.add(new int[] {i, index});
}
}
}
@@ -249,8 +239,8 @@
for (int i = 0; i < size; i++) {
outputRow.add(null);
}
- for (Map.Entry<Integer, Integer> entry : passThrough.entrySet()) {
- outputRow.set(entry.getKey(), tuple.get(entry.getValue()));
+ for (int[] entry : passThrough) {
+ outputRow.set(entry[0], tuple.get(entry[1]));
}
List<Map.Entry<WindowSpecification, WindowSpecificationInfo>> specs = new
ArrayList<Map.Entry<WindowSpecification,WindowSpecificationInfo>>(windows.entrySet());
for (int specIndex = 0; specIndex < specs.size(); specIndex++) {
@@ -398,7 +388,6 @@
}
/**
- * TODO: consolidate with {@link GroupingNode}
* @param functions
* @param specIndex
* @param rowValues
@@ -409,77 +398,19 @@
if (functions.isEmpty()) {
return aggs;
}
- //initialize the function accumulators
- List<ElementSymbol> elements = new
ArrayList<ElementSymbol>(functions.size() + 1);
- ElementSymbol key = new ElementSymbol("id"); //$NON-NLS-1$
- key.setType(DataTypeManager.DefaultDataClasses.INTEGER);
- elements.add(key);
-
- CommandContext context = this.getContext();
+ List<ElementSymbol> elements = new
ArrayList<ElementSymbol>(functions.size());
for (WindowFunctionInfo wfi : functions) {
- AggregateSymbol aggSymbol = wfi.function.getFunction();
- Class<?> outputType = aggSymbol.getType();
+ aggs.add(GroupingNode.initAccumulator(this.getContext(), wfi.function.getFunction(),
this, expressionIndexes));
+ Class<?> outputType = wfi.function.getType();
ElementSymbol value = new ElementSymbol("val"); //$NON-NLS-1$
value.setType(outputType);
elements.add(value);
- Class<?> inputType = aggSymbol.getType();
- if (aggSymbol.getExpression() != null) {
- inputType = aggSymbol.getExpression().getType();
- }
- Type function = aggSymbol.getAggregateFunction();
- AggregateFunction af = null;
- switch (function) {
- case RANK:
- case DENSE_RANK:
- af = new RankingFunction(function);
- break;
- case ROW_NUMBER: //same as count(*)
- case COUNT:
- af = new Count();
- break;
- case SUM:
- af = new Sum();
- break;
- case AVG:
- af = new Avg();
- break;
- case MIN:
- af = new Min();
- break;
- case MAX:
- af = new Max();
- break;
- case XMLAGG:
- af = new XMLAgg(context);
- break;
- case ARRAY_AGG:
- af = new ArrayAgg(context);
- break;
- case TEXTAGG:
- af = new TextAgg(context, (TextLine)aggSymbol.getExpression());
- break;
- default:
- af = new StatsFunction(function);
- }
-
- if(aggSymbol.isDistinct()) {
- af = GroupingNode.handleDistinct(af, inputType, getBufferManager(),
getConnectionID());
- }
-
- af.setExpressionIndex(wfi.expressionIndex);
- af.setConditionIndex(wfi.conditionIndex);
- af.initialize(outputType, inputType);
- aggs.add(af);
}
-
- if (!aggs.isEmpty()) {
- if (!rowValues) {
- valueMapping[specIndex] = this.getBufferManager().createSTree(elements,
this.getConnectionID(), 1);
- } else {
- rowValueMapping[specIndex] = this.getBufferManager().createSTree(elements,
this.getConnectionID(), 1);
- }
+ if (!rowValues) {
+ valueMapping[specIndex] = this.getBufferManager().createSTree(elements,
this.getConnectionID(), 1);
+ } else {
+ rowValueMapping[specIndex] = this.getBufferManager().createSTree(elements,
this.getConnectionID(), 1);
}
-
return aggs;
}
@@ -531,20 +462,11 @@
inputTs = null;
}
- private Integer getIndex(Expression ex) {
- Integer index = expressionIndexes.get(ex);
- if (index == null) {
- index = expressionIndexes.size();
- expressionIndexes.put(ex, index);
- }
- return index;
- }
-
@Override
public void initialize(CommandContext context, BufferManager bufferManager,
ProcessorDataManager dataMgr) {
super.initialize(context, bufferManager, dataMgr);
- List sourceElements = this.getChildren()[0].getElements();
+ List<? extends Expression> sourceElements = this.getChildren()[0].getElements();
this.elementMap = createLookupMap(sourceElements);
}
Modified: trunk/engine/src/main/java/org/teiid/query/rewriter/QueryRewriter.java
===================================================================
--- trunk/engine/src/main/java/org/teiid/query/rewriter/QueryRewriter.java 2012-03-28
15:29:46 UTC (rev 3955)
+++ trunk/engine/src/main/java/org/teiid/query/rewriter/QueryRewriter.java 2012-03-29
01:43:03 UTC (rev 3956)
@@ -2002,13 +2002,13 @@
if (expression.isDistinct()) {
expression.setDistinct(false);
}
- if (rewriteAggs && expression.getExpression() != null &&
EvaluatableVisitor.willBecomeConstant(expression.getExpression())) {
- return expression.getExpression();
+ if (rewriteAggs && expression.getArg(0) != null &&
EvaluatableVisitor.willBecomeConstant(expression.getArg(0))) {
+ return expression.getArg(0);
}
}
- if (expression.getExpression() != null && expression.getCondition() != null
&& !expression.respectsNulls()) {
+ if (expression.getArgs().length == 1 && expression.getCondition() != null
&& !expression.respectsNulls()) {
Expression cond = expression.getCondition();
- Expression ex = expression.getExpression();
+ Expression ex = expression.getArg(0);
if (!(cond instanceof Criteria)) {
cond = new ExpressionCriteria(cond);
}
Modified: trunk/engine/src/main/java/org/teiid/query/sql/LanguageObject.java
===================================================================
--- trunk/engine/src/main/java/org/teiid/query/sql/LanguageObject.java 2012-03-28 15:29:46
UTC (rev 3955)
+++ trunk/engine/src/main/java/org/teiid/query/sql/LanguageObject.java 2012-03-29 01:43:03
UTC (rev 3956)
@@ -23,6 +23,7 @@
package org.teiid.query.sql;
import java.util.ArrayList;
+import java.util.Arrays;
import java.util.Collection;
/**
@@ -57,6 +58,19 @@
}
return result;
}
+
+ @SuppressWarnings("unchecked")
+ public static <T extends LanguageObject> T[] deepClone(T[] collection) {
+ if (collection == null) {
+ return null;
+ }
+ T[] copy = Arrays.copyOf(collection, collection.length);
+ for (int i = 0; i < copy.length; i++) {
+ LanguageObject t = copy[i];
+ copy[i] = (T) t.clone();
+ }
+ return copy;
+ }
}
Modified:
trunk/engine/src/main/java/org/teiid/query/sql/navigator/PreOrPostOrderNavigator.java
===================================================================
---
trunk/engine/src/main/java/org/teiid/query/sql/navigator/PreOrPostOrderNavigator.java 2012-03-28
15:29:46 UTC (rev 3955)
+++
trunk/engine/src/main/java/org/teiid/query/sql/navigator/PreOrPostOrderNavigator.java 2012-03-29
01:43:03 UTC (rev 3956)
@@ -26,86 +26,9 @@
import org.teiid.query.sql.LanguageObject;
import org.teiid.query.sql.LanguageVisitor;
-import org.teiid.query.sql.lang.AlterProcedure;
-import org.teiid.query.sql.lang.AlterTrigger;
-import org.teiid.query.sql.lang.AlterView;
-import org.teiid.query.sql.lang.ArrayTable;
-import org.teiid.query.sql.lang.BatchedUpdateCommand;
-import org.teiid.query.sql.lang.BetweenCriteria;
-import org.teiid.query.sql.lang.CompareCriteria;
-import org.teiid.query.sql.lang.CompoundCriteria;
-import org.teiid.query.sql.lang.Create;
-import org.teiid.query.sql.lang.Delete;
-import org.teiid.query.sql.lang.DependentSetCriteria;
-import org.teiid.query.sql.lang.Drop;
-import org.teiid.query.sql.lang.DynamicCommand;
-import org.teiid.query.sql.lang.ExistsCriteria;
-import org.teiid.query.sql.lang.ExpressionCriteria;
-import org.teiid.query.sql.lang.From;
-import org.teiid.query.sql.lang.GroupBy;
-import org.teiid.query.sql.lang.Insert;
-import org.teiid.query.sql.lang.Into;
-import org.teiid.query.sql.lang.IsNullCriteria;
-import org.teiid.query.sql.lang.JoinPredicate;
-import org.teiid.query.sql.lang.JoinType;
-import org.teiid.query.sql.lang.Limit;
-import org.teiid.query.sql.lang.MatchCriteria;
-import org.teiid.query.sql.lang.NotCriteria;
-import org.teiid.query.sql.lang.Option;
-import org.teiid.query.sql.lang.OrderBy;
-import org.teiid.query.sql.lang.OrderByItem;
-import org.teiid.query.sql.lang.Query;
-import org.teiid.query.sql.lang.SPParameter;
-import org.teiid.query.sql.lang.Select;
-import org.teiid.query.sql.lang.SetClause;
-import org.teiid.query.sql.lang.SetClauseList;
-import org.teiid.query.sql.lang.SetCriteria;
-import org.teiid.query.sql.lang.SetQuery;
-import org.teiid.query.sql.lang.StoredProcedure;
-import org.teiid.query.sql.lang.SubqueryCompareCriteria;
-import org.teiid.query.sql.lang.SubqueryFromClause;
-import org.teiid.query.sql.lang.SubquerySetCriteria;
-import org.teiid.query.sql.lang.TextTable;
-import org.teiid.query.sql.lang.UnaryFromClause;
-import org.teiid.query.sql.lang.Update;
-import org.teiid.query.sql.lang.WithQueryCommand;
-import org.teiid.query.sql.lang.XMLTable;
-import org.teiid.query.sql.proc.AssignmentStatement;
-import org.teiid.query.sql.proc.Block;
-import org.teiid.query.sql.proc.BranchingStatement;
-import org.teiid.query.sql.proc.CommandStatement;
-import org.teiid.query.sql.proc.CreateProcedureCommand;
-import org.teiid.query.sql.proc.DeclareStatement;
-import org.teiid.query.sql.proc.IfStatement;
-import org.teiid.query.sql.proc.LoopStatement;
-import org.teiid.query.sql.proc.RaiseErrorStatement;
-import org.teiid.query.sql.proc.TriggerAction;
-import org.teiid.query.sql.proc.WhileStatement;
-import org.teiid.query.sql.symbol.AggregateSymbol;
-import org.teiid.query.sql.symbol.AliasSymbol;
-import org.teiid.query.sql.symbol.CaseExpression;
-import org.teiid.query.sql.symbol.Constant;
-import org.teiid.query.sql.symbol.DerivedColumn;
-import org.teiid.query.sql.symbol.ElementSymbol;
-import org.teiid.query.sql.symbol.Expression;
-import org.teiid.query.sql.symbol.ExpressionSymbol;
-import org.teiid.query.sql.symbol.Function;
-import org.teiid.query.sql.symbol.GroupSymbol;
-import org.teiid.query.sql.symbol.MultipleElementSymbol;
-import org.teiid.query.sql.symbol.QueryString;
-import org.teiid.query.sql.symbol.Reference;
-import org.teiid.query.sql.symbol.ScalarSubquery;
-import org.teiid.query.sql.symbol.SearchedCaseExpression;
-import org.teiid.query.sql.symbol.TextLine;
-import org.teiid.query.sql.symbol.WindowFunction;
-import org.teiid.query.sql.symbol.WindowSpecification;
-import org.teiid.query.sql.symbol.XMLAttributes;
-import org.teiid.query.sql.symbol.XMLElement;
-import org.teiid.query.sql.symbol.XMLForest;
-import org.teiid.query.sql.symbol.XMLNamespaces;
-import org.teiid.query.sql.symbol.XMLParse;
-import org.teiid.query.sql.symbol.XMLQuery;
-import org.teiid.query.sql.symbol.XMLSerialize;
+import org.teiid.query.sql.lang.*;
+import org.teiid.query.sql.proc.*;
+import org.teiid.query.sql.symbol.*;
@@ -140,7 +63,12 @@
public void visit(AggregateSymbol obj) {
preVisitVisitor(obj);
- visitNode(obj.getExpression());
+ Expression[] args = obj.getArgs();
+ if(args != null) {
+ for(int i=0; i<args.length; i++) {
+ visitNode(args[i]);
+ }
+ }
visitNode(obj.getOrderBy());
visitNode(obj.getCondition());
postVisitVisitor(obj);
Modified: trunk/engine/src/main/java/org/teiid/query/sql/symbol/AggregateSymbol.java
===================================================================
--- trunk/engine/src/main/java/org/teiid/query/sql/symbol/AggregateSymbol.java 2012-03-28 15:29:46 UTC (rev 3955)
+++ trunk/engine/src/main/java/org/teiid/query/sql/symbol/AggregateSymbol.java 2012-03-29 01:43:03 UTC (rev 3956)
@@ -25,11 +25,13 @@
import java.util.Collection;
import java.util.HashMap;
import java.util.Map;
+import java.util.TreeMap;
import org.teiid.core.types.DataTypeManager;
import org.teiid.core.util.EquivalenceUtil;
import org.teiid.core.util.HashCodeUtil;
import org.teiid.query.parser.SQLParserUtil;
+import org.teiid.query.sql.LanguageObject;
import org.teiid.query.sql.LanguageVisitor;
import org.teiid.query.sql.lang.OrderBy;
@@ -44,6 +46,8 @@
*/
public class AggregateSymbol extends Function implements DerivedExpression {
+ private static final Expression[] EMPTY_ARGS = new Expression[0];
+
public enum Type {
COUNT,
SUM,
@@ -65,6 +69,17 @@
ROW_NUMBER,
USER_DEFINED;
}
+
+ private static final Map<String, Type> nameMap = new TreeMap<String, Type>(String.CASE_INSENSITIVE_ORDER);
+
+ static {
+ for (Type t : Type.values()) {
+ if (t == Type.USER_DEFINED) {
+ continue;
+ }
+ nameMap.put(t.name(), t);
+ }
+ }
private Type aggregate;
private boolean distinct;
@@ -104,8 +119,8 @@
* @param canonicalName
* @since 4.3
*/
- protected AggregateSymbol(String name, Type aggregateFunction, boolean isDistinct, Expression expression) {
- super(name, expression == null?new Expression[0]:new Expression[] {expression});
+ protected AggregateSymbol(String name, Type aggregateFunction, boolean isDistinct, Expression[] args) {
+ super(name, args);
this.aggregate = aggregateFunction;
this.distinct = isDistinct;
}
@@ -117,10 +132,9 @@
* @param expression Contained expression
*/
public AggregateSymbol(String aggregateFunction, boolean isDistinct, Expression expression) {
- super(aggregateFunction, expression == null?new Expression[0]:new Expression[] {expression});
- try {
- this.aggregate = Type.valueOf(aggregateFunction.toUpperCase());
- } catch (IllegalArgumentException e) {
+ super(aggregateFunction, expression == null?EMPTY_ARGS:new Expression[] {expression});
+ this.aggregate = nameMap.get(aggregateFunction);
+ if (this.aggregate == null) {
this.aggregate = Type.USER_DEFINED;
}
this.distinct = isDistinct;
@@ -172,10 +186,10 @@
case COUNT:
return COUNT_TYPE;
case SUM:
- Class<?> expressionType = this.getExpression().getType();
+ Class<?> expressionType = this.getArg(0).getType();
return SUM_TYPES.get(expressionType);
case AVG:
- expressionType = this.getExpression().getType();
+ expressionType = this.getArg(0).getType();
return AVG_TYPES.get(expressionType);
case ARRAY_AGG:
return DataTypeManager.DefaultDataClasses.OBJECT;
@@ -191,7 +205,7 @@
if (isAnalytical()) {
return DataTypeManager.DefaultDataClasses.INTEGER;
}
- return this.getExpression().getType();
+ return this.getArg(0).getType();
}
public boolean isAnalytical() {
@@ -233,12 +247,7 @@
* Return a deep copy of this object
*/
public Object clone() {
- AggregateSymbol copy = null;
- if(getExpression() != null) {
- copy = new AggregateSymbol(getName(), getAggregateFunction(), isDistinct(), (Expression) getExpression().clone());
- } else {
- copy = new AggregateSymbol(getName(), getAggregateFunction(), isDistinct(), null);
- }
+ AggregateSymbol copy = new AggregateSymbol(getName(), getAggregateFunction(), isDistinct(), LanguageObject.Util.deepClone(getArgs()));
if (orderBy != null) {
copy.setOrderBy(orderBy.clone());
}
@@ -246,6 +255,8 @@
copy.setCondition((Expression) condition.clone());
}
copy.isWindowed = this.isWindowed;
+ copy.setType(getType());
+ copy.setFunctionDescriptor(getFunctionDescriptor());
return copy;
}
@@ -254,7 +265,7 @@
*/
public int hashCode() {
int hasCode = HashCodeUtil.hashCode(aggregate.hashCode(), distinct);
- return HashCodeUtil.hashCode(hasCode, this.getExpression());
+ return HashCodeUtil.hashCode(hasCode, super.hashCode());
}
/**
@@ -267,10 +278,10 @@
AggregateSymbol other = (AggregateSymbol)obj;
- return this.aggregate.equals(other.aggregate)
+ return super.equals(obj)
+ && this.aggregate.equals(other.aggregate)
&& this.distinct == other.distinct
&& this.isWindowed == other.isWindowed
- && EquivalenceUtil.areEqual(this.getExpression(), other.getExpression())
&& EquivalenceUtil.areEqual(this.condition, other.condition)
&& EquivalenceUtil.areEqual(this.getOrderBy(), other.getOrderBy());
}
@@ -286,6 +297,8 @@
case SOME:
case EVERY:
return false;
+ case USER_DEFINED:
+ return getFunctionDescriptor().getMethod().getAggregateAttributes().usesAllRows();
}
return true;
}
@@ -323,6 +336,8 @@
return false;
case XMLAGG:
return orderBy == null;
+ case USER_DEFINED:
+ return false;
}
return true;
}
@@ -335,15 +350,4 @@
this.isWindowed = isWindowed;
}
- @Deprecated
- public Expression getExpression() {
- if (this.getArgs().length == 0) {
- return null;
- }
- if (this.getArgs().length > 1) {
- throw new AssertionError("getExpression should not be used with a non-unary aggregate"); //$NON-NLS-1$
- }
- return this.getArg(0);
- }
-
}
Modified: trunk/engine/src/main/java/org/teiid/query/sql/symbol/Function.java
===================================================================
--- trunk/engine/src/main/java/org/teiid/query/sql/symbol/Function.java 2012-03-28 15:29:46 UTC (rev 3955)
+++ trunk/engine/src/main/java/org/teiid/query/sql/symbol/Function.java 2012-03-29 01:43:03 UTC (rev 3956)
@@ -26,6 +26,7 @@
import org.teiid.core.util.EquivalenceUtil;
import org.teiid.core.util.HashCodeUtil;
import org.teiid.query.function.FunctionDescriptor;
+import org.teiid.query.sql.LanguageObject;
import org.teiid.query.sql.LanguageVisitor;
import org.teiid.query.sql.visitor.SQLStringVisitor;
@@ -224,13 +225,7 @@
* @return Deep copy of the object
*/
public Object clone() {
- Expression[] copyArgs = new Expression[args.length];
- for(int i=0; i<args.length; i++) {
- if(args[i] != null) {
- copyArgs[i] = (Expression) args[i].clone();
- }
- }
-
+ Expression[] copyArgs = LanguageObject.Util.deepClone(this.args);
Function copy = new Function(getName(), copyArgs);
copy.setType(getType());
copy.setFunctionDescriptor(getFunctionDescriptor());
Modified: trunk/engine/src/main/java/org/teiid/query/validator/ValidationVisitor.java
===================================================================
--- trunk/engine/src/main/java/org/teiid/query/validator/ValidationVisitor.java 2012-03-28 15:29:46 UTC (rev 3955)
+++ trunk/engine/src/main/java/org/teiid/query/validator/ValidationVisitor.java 2012-03-29 01:43:03 UTC (rev 3956)
@@ -987,9 +987,11 @@
Expression condition = obj.getCondition();
validateNoSubqueriesOrOuterReferences(condition);
}
- Expression aggExp = obj.getExpression();
-
- validateNoNestedAggs(aggExp);
+ Expression[] aggExps = obj.getArgs();
+
+ for (Expression expression : aggExps) {
+ validateNoNestedAggs(expression);
+ }
validateNoNestedAggs(obj.getOrderBy());
validateNoNestedAggs(obj.getCondition());
@@ -998,17 +1000,17 @@
if((aggregateFunction == Type.SUM || aggregateFunction == Type.AVG) && obj.getType() == null) {
handleValidationError(QueryPlugin.Util.getString("ERR.015.012.0041", new Object[] {aggregateFunction, obj}), obj); //$NON-NLS-1$
} else if (obj.getType() != DataTypeManager.DefaultDataClasses.NULL) {
- if (aggregateFunction == Type.XMLAGG && aggExp.getType() != DataTypeManager.DefaultDataClasses.XML) {
+ if (aggregateFunction == Type.XMLAGG && aggExps[0].getType() != DataTypeManager.DefaultDataClasses.XML) {
handleValidationError(QueryPlugin.Util.getString("AggregateValidationVisitor.non_xml", new Object[] {aggregateFunction, obj}), obj); //$NON-NLS-1$
- } else if (obj.isBoolean() && aggExp.getType() != DataTypeManager.DefaultDataClasses.BOOLEAN) {
+ } else if (obj.isBoolean() && aggExps[0].getType() != DataTypeManager.DefaultDataClasses.BOOLEAN) {
handleValidationError(QueryPlugin.Util.getString("AggregateValidationVisitor.non_boolean", new Object[] {aggregateFunction, obj}), obj); //$NON-NLS-1$
}
}
- if((obj.isDistinct() || aggregateFunction == Type.MIN || aggregateFunction == Type.MAX) && DataTypeManager.isNonComparable(DataTypeManager.getDataTypeName(aggExp.getType()))) {
+ if((obj.isDistinct() || aggregateFunction == Type.MIN || aggregateFunction == Type.MAX) && DataTypeManager.isNonComparable(DataTypeManager.getDataTypeName(aggExps[0].getType()))) {
handleValidationError(QueryPlugin.Util.getString("AggregateValidationVisitor.non_comparable", new Object[] {aggregateFunction, obj}), obj); //$NON-NLS-1$
}
if(obj.isEnhancedNumeric()) {
- if (!Number.class.isAssignableFrom(aggExp.getType())) {
+ if (!Number.class.isAssignableFrom(aggExps[0].getType())) {
handleValidationError(QueryPlugin.Util.getString("ERR.015.012.0041", new Object[] {aggregateFunction, obj}), obj); //$NON-NLS-1$
}
if (obj.isDistinct()) {
@@ -1018,7 +1020,7 @@
if (obj.getAggregateFunction() != Type.TEXTAGG) {
return;
}
- TextLine tl = (TextLine)obj.getExpression();
+ TextLine tl = (TextLine)aggExps[0];
if (tl.isIncludeHeader()) {
validateDerivedColumnNames(obj, tl.getExpressions());
}
Modified: trunk/engine/src/main/resources/org/teiid/query/i18n.properties
===================================================================
--- trunk/engine/src/main/resources/org/teiid/query/i18n.properties 2012-03-28 15:29:46 UTC (rev 3955)
+++ trunk/engine/src/main/resources/org/teiid/query/i18n.properties 2012-03-29 01:43:03 UTC (rev 3956)
@@ -1001,8 +1001,6 @@
TEIID30410=Parse Exception occurs for executing: {0} {1}
TEIID30413=Unable to evaluate {0}: expected Properties for command payload but got object of type {1}
TEIID30416=Expected a java.sql.Array, or java array type, but got: {0}
-TEIID30425=Unable to compute aggregate function {0} on data of type {1}
-TEIID30426=Unable to compute aggregate function {0} on data of type {1}
TEIID30431={0} has invalid character: {1}
TEIID30449=Invalid escape sequence "{0}" with escape character "{1}"
TEIID30451=Unable to evaluate {0}: {1}
Modified: trunk/engine/src/test/java/org/teiid/query/processor/relational/TestDuplicateFilter.java
===================================================================
--- trunk/engine/src/test/java/org/teiid/query/processor/relational/TestDuplicateFilter.java 2012-03-28 15:29:46 UTC (rev 3955)
+++ trunk/engine/src/test/java/org/teiid/query/processor/relational/TestDuplicateFilter.java 2012-03-29 01:43:03 UTC (rev 3956)
@@ -22,62 +22,54 @@
package org.teiid.query.processor.relational;
+import static org.junit.Assert.*;
+
import java.util.Arrays;
+import org.junit.Test;
import org.teiid.common.buffer.BufferManager;
import org.teiid.common.buffer.BufferManagerFactory;
import org.teiid.core.TeiidComponentException;
import org.teiid.core.TeiidProcessingException;
import org.teiid.core.types.DataTypeManager;
import org.teiid.query.function.aggregate.Count;
-import org.teiid.query.processor.relational.SortingFilter;
import org.teiid.query.sql.symbol.ElementSymbol;
-import junit.framework.TestCase;
-
-
/**
*/
-public class TestDuplicateFilter extends TestCase {
+public class TestDuplicateFilter {
- /**
- * Constructor for TestDuplicateFilter.
- * @param arg0
- */
- public TestDuplicateFilter(String arg0) {
- super(arg0);
- }
-
- public void helpTestDuplicateFilter(Object[] input, Class dataType, int expected) throws TeiidComponentException, TeiidProcessingException {
+ public void helpTestDuplicateFilter(Object[] input, Class<?> dataType, int expected) throws TeiidComponentException, TeiidProcessingException {
BufferManager mgr = BufferManagerFactory.getStandaloneBufferManager();
SortingFilter filter = new SortingFilter(new Count(), mgr, "test", true); //$NON-NLS-1$
- filter.initialize(dataType, dataType);
+ filter.initialize(dataType, new Class[] {dataType});
ElementSymbol element = new ElementSymbol("val"); //$NON-NLS-1$
element.setType(dataType);
filter.setElements(Arrays.asList(element));
+ filter.setArgIndexes(new int[] {0});
filter.reset();
// Add inputs
for(int i=0; i<input.length; i++) {
- filter.addInputDirect(input[i], null);
+ filter.addInputDirect(Arrays.asList(input[i]));
}
Integer actual = (Integer) filter.getResult();
assertEquals("Did not get expected number of results", expected, actual.intValue()); //$NON-NLS-1$
}
- public void testNoInputs() throws Exception {
+ @Test public void testNoInputs() throws Exception {
helpTestDuplicateFilter(new Object[0], DataTypeManager.DefaultDataClasses.STRING, 0);
}
- public void testSmall() throws Exception {
+ @Test public void testSmall() throws Exception {
Object[] input = new Object[] { "a", "b", "a", "c", "a", "c", "c", "f" }; //$NON-NLS-1$ //$NON-NLS-2$ //$NON-NLS-3$ //$NON-NLS-4$ //$NON-NLS-5$ //$NON-NLS-6$ //$NON-NLS-7$ //$NON-NLS-8$
helpTestDuplicateFilter(input, DataTypeManager.DefaultDataClasses.STRING, 4);
}
- public void testBig() throws Exception {
+ @Test public void testBig() throws Exception {
int NUM_VALUES = 10000;
int NUM_OUTPUT = 200;
Object[] input = new Object[NUM_VALUES];
Modified: trunk/engine/src/test/java/org/teiid/query/processor/relational/TestGroupingNode.java
===================================================================
--- trunk/engine/src/test/java/org/teiid/query/processor/relational/TestGroupingNode.java 2012-03-28 15:29:46 UTC (rev 3955)
+++ trunk/engine/src/test/java/org/teiid/query/processor/relational/TestGroupingNode.java 2012-03-29 01:43:03 UTC (rev 3956)
@@ -174,7 +174,7 @@
AggregateFunction[] functions = node.getFunctions();
AggregateFunction countDist = functions[5];
SortingFilter dup = (SortingFilter)countDist;
- assertEquals(DataTypeManager.DefaultDataClasses.INTEGER, ((ElementSymbol)dup.getElements().get(0)).getType());
+ assertEquals(DataTypeManager.DefaultDataClasses.INTEGER, dup.getElements().get(0).getType());
}
@Test public void test2() throws Exception {