diff --git a/com.oracle.truffle.r.nodes.builtin/src/com/oracle/truffle/r/nodes/builtin/base/Diag.java b/com.oracle.truffle.r.nodes.builtin/src/com/oracle/truffle/r/nodes/builtin/base/Diag.java
index ca1cda7f78336c92bba9dfc5fa216436a663e481..a3e832046227da221745e0023bed1b3cac7d97fe 100644
--- a/com.oracle.truffle.r.nodes.builtin/src/com/oracle/truffle/r/nodes/builtin/base/Diag.java
+++ b/com.oracle.truffle.r.nodes.builtin/src/com/oracle/truffle/r/nodes/builtin/base/Diag.java
@@ -14,26 +14,26 @@ package com.oracle.truffle.r.nodes.builtin.base;
 
 import static com.oracle.truffle.r.nodes.builtin.CastBuilder.Predef.asDoubleVector;
 import static com.oracle.truffle.r.nodes.builtin.CastBuilder.Predef.complexValue;
-import static com.oracle.truffle.r.nodes.builtin.CastBuilder.Predef.logicalValue;
-import static com.oracle.truffle.r.nodes.builtin.CastBuilder.Predef.nullValue;
 import static com.oracle.truffle.r.nodes.builtin.CastBuilder.Predef.gte0;
+import static com.oracle.truffle.r.nodes.builtin.CastBuilder.Predef.logicalValue;
 import static com.oracle.truffle.r.nodes.builtin.CastBuilder.Predef.notIntNA;
+import static com.oracle.truffle.r.nodes.builtin.CastBuilder.Predef.nullValue;
 import static com.oracle.truffle.r.runtime.builtins.RBehavior.PURE;
 import static com.oracle.truffle.r.runtime.builtins.RBuiltinKind.INTERNAL;
 
-import com.oracle.truffle.api.CompilerDirectives;
+import com.oracle.truffle.api.dsl.Cached;
 import com.oracle.truffle.api.dsl.Specialization;
 import com.oracle.truffle.r.nodes.builtin.RBuiltinNode;
 import com.oracle.truffle.r.runtime.RError.Message;
 import com.oracle.truffle.r.runtime.RRuntime;
 import com.oracle.truffle.r.runtime.builtins.RBuiltin;
-import com.oracle.truffle.r.runtime.data.RComplex;
 import com.oracle.truffle.r.runtime.data.RDataFactory;
+import com.oracle.truffle.r.runtime.data.RDataFactory.VectorFactory;
 import com.oracle.truffle.r.runtime.data.RNull;
-import com.oracle.truffle.r.runtime.data.model.RAbstractComplexVector;
-import com.oracle.truffle.r.runtime.data.model.RAbstractDoubleVector;
-import com.oracle.truffle.r.runtime.data.model.RAbstractLogicalVector;
 import com.oracle.truffle.r.runtime.data.model.RAbstractVector;
+import com.oracle.truffle.r.runtime.data.nodes.VectorAccess;
+import com.oracle.truffle.r.runtime.data.nodes.VectorAccess.RandomIterator;
+import com.oracle.truffle.r.runtime.data.nodes.VectorAccess.SequentialIterator;
 
 @RBuiltin(name = "diag", kind = INTERNAL, parameterNames = {"x", "nrow", "ncol"}, behavior = PURE)
 public abstract class Diag extends RBuiltinNode.Arg3 {
@@ -48,7 +48,6 @@ public abstract class Diag extends RBuiltinNode.Arg3 {
     private int checkX(RAbstractVector x, int nrow, int ncol) {
         int mn = (nrow < ncol) ? nrow : ncol;
         if (mn > 0 && x.getLength() == 0) {
-            CompilerDirectives.transferToInterpreter();
             throw error(Message.POSITIVE_LENGTH, "x");
         }
         return mn;
@@ -56,8 +55,8 @@ public abstract class Diag extends RBuiltinNode.Arg3 {
 
     @Specialization
     protected Object diag(@SuppressWarnings("unused") RNull x, int nrow, int ncol) {
-        if (nrow == 0 && ncol == 0) {
-            return RDataFactory.createDoubleVector(new double[]{}, true, new int[]{0, 0});
+        if (nrow == 0 || ncol == 0) {
+            return RDataFactory.createDoubleVector(new double[]{}, true, new int[]{nrow, ncol});
         } else {
             throw error(Message.X_NUMERIC);
         }
@@ -74,43 +73,31 @@ public abstract class Diag extends RBuiltinNode.Arg3 {
         return RDataFactory.createDoubleVector(data, !RRuntime.isNA(x), new int[]{nrow, ncol});
     }
 
-    @Specialization
-    protected Object diag(RAbstractComplexVector x, int nrow, int ncol) {
+    @Specialization(guards = "xAccess.supports(x)")
+    protected RAbstractVector diagCached(RAbstractVector x, int nrow, int ncol,
+                    @Cached("x.access()") VectorAccess xAccess,
+                    @Cached("createNew(xAccess.getType())") VectorAccess resultAccess,
+                    @Cached("create()") VectorFactory factory) {
         int mn = checkX(x, nrow, ncol);
 
-        double[] data = new double[nrow * ncol * 2];
-        int nx = x.getLength();
-        for (int j = 0; j < mn; j++) {
-            RComplex value = x.getDataAt(j % nx);
-            int index = j * (nrow + 1) * 2;
-            data[index] = value.getRealPart();
-            data[index + 1] = value.getImaginaryPart();
-        }
-        return RDataFactory.createComplexVector(data, x.isComplete(), new int[]{nrow, ncol});
-    }
-
-    @Specialization
-    protected Object diag(RAbstractDoubleVector source, int nrow, int ncol) {
-        int mn = checkX(source, nrow, ncol);
-
-        double[] data = new double[nrow * ncol];
-        int nx = source.getLength();
-        for (int j = 0; j < mn; j++) {
-            data[j * (nrow + 1)] = source.getDataAt(j % nx);
+        try (SequentialIterator xIter = xAccess.access(x)) {
+            RAbstractVector result = factory.createUninitializedVector(xAccess.getType(), nrow * ncol, new int[]{nrow, ncol}, null, null);
+            try (RandomIterator resultIter = resultAccess.randomAccess(result)) {
+                int resultIndex = 0;
+                for (int j = 0; j < mn; j++) {
+                    xAccess.nextWithWrap(xIter);
+                    resultAccess.setFromSameType(resultIter, resultIndex, xAccess, xIter);
+                    resultIndex += nrow + 1;
+                }
+            }
+            result.setComplete(x.isComplete());
+            return result;
         }
-        return RDataFactory.createDoubleVector(data, source.isComplete(), new int[]{nrow, ncol});
     }
 
-    @Specialization
-    protected Object diag(RAbstractLogicalVector source, int nrow, int ncol) {
-        int mn = checkX(source, nrow, ncol);
-
-        byte[] data = new byte[nrow * ncol];
-        int nx = source.getLength();
-        for (int j = 0; j < mn; j++) {
-            data[j * (nrow + 1)] = source.getDataAt(j % nx);
-        }
-        return RDataFactory.createLogicalVector(data, source.isComplete(), new int[]{nrow, ncol});
+    @Specialization(replaces = "diagCached")
+    protected RAbstractVector diagGeneric(RAbstractVector x, int nrow, int ncol,
+                    @Cached("create()") VectorFactory factory) {
+        return diagCached(x, nrow, ncol, x.slowPathAccess(), VectorAccess.createSlowPathNew(x.getRType()), factory);
     }
-
 }